# IW2025's picture
# Upload 30 files
# 93fe96e verified
import gradio as gr
import os
from pathlib import Path
import fitz # PyMuPDF
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from transformers import pipeline
import torch
import base64
from PIL import Image
import io
import re
import time
# --- Local Test Version ---
class LocalCurriculumChatbot:
    """Curriculum Q&A assistant over a directory of PDF slide decks.

    Extracts per-page text with PyMuPDF, indexes the pages in a Chroma
    vector store, and answers queries either via a tiny local LLM or
    (in fast mode) with template-based answers plus rendered slide images.
    """

    def __init__(self, slides_dir="Slides", fast_mode=True):
        self.pdf_pages = {}       # {filename: {page_num: text}}
        self.pdf_files = {}       # {filename: path}
        self.chunks = []          # one text chunk per non-empty PDF page
        self.chunk_metadata = []  # parallel to self.chunks
        self.vector_db = None
        self.embeddings = None
        self.llm = None
        self.response_cache = {}  # Simple cache: query -> 4-tuple response
        self.fast_mode = fast_mode  # Skip LLM for faster responses
        self._process_pdfs(slides_dir)
        self._build_vector_db()
        if not fast_mode:
            self._setup_llm()
        else:
            print("πŸš€ Fast mode enabled - LLM disabled for instant responses")

    def _process_pdfs(self, slides_dir):
        """Extract text from every PDF page and record chunks + metadata."""
        slides_path = Path(slides_dir)
        # sorted() so chunk/index order is deterministic across runs
        # (Path.glob order is filesystem-dependent).
        pdf_files = sorted(slides_path.glob("*.pdf"))
        for pdf_file in pdf_files:
            self.pdf_files[pdf_file.name] = str(pdf_file)
            doc = fitz.open(str(pdf_file))
            try:
                pages = {}
                for page_num in range(len(doc)):
                    text = doc[page_num].get_text()
                    if text.strip():
                        # Page numbers are stored 1-based for display.
                        pages[page_num + 1] = text.strip()
            finally:
                doc.close()  # always release the document handle
            self.pdf_pages[pdf_file.name] = pages
            # Add each page as a chunk
            for page_num, text in pages.items():
                self.chunks.append(text)
                self.chunk_metadata.append({
                    "filename": pdf_file.name,
                    "page_number": page_num,
                })

    def _build_vector_db(self):
        """Embed all page chunks and persist them in a local Chroma store."""
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.vector_db = Chroma.from_texts(
            texts=self.chunks,
            embedding=self.embeddings,
            metadatas=self.chunk_metadata,
            persist_directory="./chroma_db",
        )

    def _setup_llm(self):
        """Load a small local generation model; fall back to fast mode on failure."""
        try:
            # Use a very small, fast model for local testing
            model_name = "distilgpt2"  # Much smaller and faster
            pipe = pipeline(
                "text-generation",
                model=model_name,
                max_new_tokens=50,  # Very short for speed
                temperature=0.3,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                device_map="auto" if torch.cuda.is_available() else None,
                # Performance optimizations
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                low_cpu_mem_usage=True,
            )
            self.llm = pipe
            print("βœ… Local model loaded successfully!")
        except Exception as e:
            print(f"Warning: Could not load local model: {e}")
            print("Falling back to fast mode...")
            self.llm = None

    def get_pdf_page_image(self, pdf_path, page_num):
        """Render a single PDF page (1-based) to an RGB PIL image.

        Returns None when the page is out of range or rendering fails.
        """
        try:
            doc = fitz.open(pdf_path)
            try:
                # FIX: previously the document handle leaked if any call
                # between open() and close() raised; try/finally guarantees
                # release, and removes the duplicated close() branches.
                if page_num > len(doc):
                    return None
                page = doc[page_num - 1]
                mat = fitz.Matrix(1.5, 1.5)  # 1.5x zoom for readable previews
                pix = page.get_pixmap(matrix=mat)
                img = Image.open(io.BytesIO(pix.tobytes("png")))
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                return img
            finally:
                doc.close()
        except Exception as e:
            print(f"Error rendering PDF page: {str(e)}")
            return None

    def chat(self, query):
        """Fast chat function optimized for local testing.

        Returns a 4-tuple (answer, None, None, relevant_slides) where
        relevant_slides is a list of (PIL image, label) pairs.
        """
        start_time = time.time()
        # Check cache first for faster responses
        if query in self.response_cache:
            print(f"βœ… Using cached response (took {time.time() - start_time:.2f}s)")
            return self.response_cache[query]
        # First, try to find relevant curriculum content
        results = self.vector_db.similarity_search(query, k=3)  # Reduced for speed
        # Check if query is curriculum-related
        curriculum_relevance_score = 0
        if results:
            curriculum_relevance_score = len([r for r in results if r.page_content.strip()])
        # Debug: Print what we found
        print(f"Query: {query}")
        print(f"Found {len(results)} relevant results in {time.time() - start_time:.2f}s")
        # Fast mode - skip LLM processing
        best_slide_content = ""
        best_result = None
        if curriculum_relevance_score > 0:
            best_result = results[0]
            best_slide_content = results[0].page_content
        # Generate simple answer without LLM
        if curriculum_relevance_score > 0:
            slide_info = f"πŸ“„ **Slide Reference:** {best_result.metadata['filename']} - Page {best_result.metadata['page_number']}"
            if "loops" in query.lower():
                answer = f"{slide_info}\n\n**Slide Content:**\n{best_slide_content}\n\n**What are loops?**\n\nLoops are programming constructs that solve the problem of repetition. Instead of writing hundreds of print statements, loops allow you to accomplish the same task with just a few lines of code.\n\n**Key benefits:**\nβ€’ Efficiency: Reduce repetitive code\nβ€’ Scalability: Handle large ranges easily\nβ€’ Maintainability: Easier to modify and debug"
            else:
                answer = f"{slide_info}\n\n**Slide Content:**\n{best_slide_content}\n\nThis slide contains relevant information about your question."
        else:
            answer = "I couldn't find relevant content in the curriculum for this question. Please try rephrasing or ask about a different programming topic."
        # Get relevant slides
        relevant_slides = []
        if curriculum_relevance_score > 0:
            filename = best_result.metadata["filename"]
            page_number = best_result.metadata["page_number"]
            if filename in self.pdf_files:
                pdf_path = self.pdf_files[filename]
                doc = fitz.open(pdf_path)
                total_pages = len(doc)
                doc.close()
                # Get the target page and neighboring pages
                start_page = max(1, page_number - 1)
                end_page = min(total_pages, page_number + 1)
                for page_num in range(start_page, end_page + 1):
                    img = self.get_pdf_page_image(pdf_path, page_num)
                    if img:
                        # FIX: labels previously showed the literal "(unknown)"
                        # instead of the slide deck's filename.
                        if page_num == page_number:
                            label = f"πŸ“Œ {filename} - Page {page_num} (Most Relevant)"
                        else:
                            label = f"{filename} - Page {page_num}"
                        relevant_slides.append((img, label))
        else:
            # Show a few slides from different PDFs
            for filename, pages in list(self.pdf_pages.items())[:2]:
                for page_num in list(pages.keys())[:1]:
                    img = self.get_pdf_page_image(self.pdf_files[filename], page_num)
                    if img:
                        relevant_slides.append((img, f"{filename} - Page {page_num}"))
        # Cache the response
        self.response_cache[query] = (answer, None, None, relevant_slides)
        # Limit cache size: drop the oldest entry (dict preserves insertion order)
        if len(self.response_cache) > 20:
            oldest_key = next(iter(self.response_cache))
            del self.response_cache[oldest_key]
        total_time = time.time() - start_time
        print(f"βœ… Response generated in {total_time:.2f} seconds")
        return answer, None, None, relevant_slides
# --- Local Test UI ---
print("πŸš€ Starting Local Test Version...")
chatbot = LocalCurriculumChatbot(fast_mode=True)


def local_chat(query):
    """Adapt chatbot.chat()'s 4-tuple to the (answer, slides) pair the UI needs."""
    result = chatbot.chat(query)
    return result[0], result[3]
# Simple test function
def test_performance():
    """Time a handful of representative queries and print basic stats."""
    print("\nπŸ§ͺ Performance Test:")
    sample_queries = (
        "What are loops?",
        "How do variables work?",
        "Explain functions",
        "What is programming?",
    )
    for query in sample_queries:
        print(f"\nTesting: '{query}'")
        started = time.time()
        answer, slides = local_chat(query)
        elapsed = time.time() - started
        print(f"Response time: {elapsed:.2f} seconds")
        print(f"Answer length: {len(answer)} characters")
        print(f"Slides found: {len(slides)}")
# Run performance test
if __name__ == "__main__":
    # Quick smoke/perf check, run only when the file is executed directly.
    test_performance()
# Start Gradio interface
# Two-column layout: question + answer on the left, slide previews on the right.
with gr.Blocks(title="Local Curriculum Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ§ͺ Local Test - Curriculum Assistant")
    gr.Markdown("**Testing performance optimizations**")
    with gr.Row():
        with gr.Column(scale=1):
            question_box = gr.Textbox(
                label="Question",
                placeholder="e.g., What are loops?",
                lines=2,
            )
            submit_btn = gr.Button("πŸš€ Test", variant="primary")
            answer_md = gr.Markdown(label="Response")
        with gr.Column(scale=1):
            slide_gallery = gr.Gallery(
                label="Slides",
                columns=1,
                rows=2,
                height="400px",
                object_fit="contain",
            )
    # Both the button click and Enter in the textbox trigger the same handler.
    submit_btn.click(fn=local_chat, inputs=question_box, outputs=[answer_md, slide_gallery])
    question_box.submit(fn=local_chat, inputs=question_box, outputs=[answer_md, slide_gallery])

print("\n🌐 Starting local server...")
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)