"""AI Tutor service using Gemini 2.5 Flash with PostgreSQL pgvector for RAG."""
import json
import os
from typing import List, Optional

from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_postgres import PGVector
from pypdf import PdfReader

from app.config import settings
from app.schemas.tutor import SourceDocument, GeneratedMCQ
class AITutorService:
    """Enhanced AI Tutor with Gemini 2.5 Flash and PostgreSQL pgvector for RAG.

    Embeddings and the vector store are created lazily via ``initialize()``
    because loading the sentence-transformer model is slow and needs runtime
    settings; the RAG methods call ``initialize()`` themselves.
    """

    def __init__(self):
        # Set up lazily in initialize(); None until then.
        self.embeddings: Optional[HuggingFaceEmbeddings] = None
        self.vector_store: Optional[PGVector] = None
        # Chunking configuration used when indexing lesson content.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        self._initialized = False
        self._llm = None  # Cache LLM instance (default temperature 0.3)

    async def initialize(self):
        """Initialize HuggingFace embeddings and pgvector store (idempotent)."""
        if self._initialized:
            return
        print("🔄 Loading embedding model (this may take a moment on first run)...")
        # Use HuggingFace sentence-transformers for embeddings
        self.embeddings = HuggingFaceEmbeddings(
            model_name=settings.EMBEDDING_MODEL,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )
        print("✅ Embedding model loaded!")
        # PGVector needs a sync-driver URL without query parameters, so strip
        # the "+asyncpg" suffix and anything after "?".
        clean_db_url = settings.DATABASE_URL.replace("+asyncpg", "").split("?")[0]
        self.vector_store = PGVector(
            embeddings=self.embeddings,
            collection_name="documents",
            connection=clean_db_url,  # psycopg2 format without query params
            use_jsonb=True,
        )
        # Pre-initialize the default LLM so the first request is faster.
        self._llm = self._get_llm(temperature=0.3)
        self._initialized = True

    def _get_llm(self, temperature: float = 0.3) -> ChatGoogleGenerativeAI:
        """Get a Gemini LLM instance with the given sampling temperature."""
        return ChatGoogleGenerativeAI(
            model=settings.LLM_MODEL,
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=temperature,
        )

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove ```json / ``` markdown fences the model sometimes adds.

        The prompts ask for bare JSON, but Gemini occasionally wraps its
        answer in a fenced code block anyway. Shared by generate_summary()
        and generate_mcqs() (previously duplicated inline in both).
        """
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]  # Remove ```json
        elif text.startswith("```"):
            text = text[3:]  # Remove ```
        if text.endswith("```"):
            text = text[:-3]  # Remove trailing ```
        return text.strip()

    async def ask_question(
        self,
        question: str,
        lesson_context: Optional[str] = None,
        module_context: Optional[str] = None,
    ) -> tuple[str, List[SourceDocument]]:
        """Answer a question using RAG with lesson/module context.

        Returns a tuple of (markdown answer, deduplicated source documents).
        """
        await self.initialize()
        # Restrict retrieval to a specific lesson if given, else to a module.
        filter_dict = {}
        if lesson_context:
            filter_dict["lesson_id"] = lesson_context
        elif module_context:
            filter_dict["module_id"] = module_context
        # Retrieve relevant documents
        try:
            if filter_dict:
                docs = self.vector_store.similarity_search(
                    question, k=4, filter=filter_dict
                )
            else:
                docs = self.vector_store.similarity_search(question, k=4)
        except Exception as e:
            print(f"⚠️ Vector search error: {e}")
            # If vector search fails (e.g., no documents yet), provide general answer
            docs = []
        # Format context
        context = self._format_context(docs)
        sources = self._format_sources(docs)
        # Generate answer
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an AI tutor helping students learn. Answer questions based on the provided context.
INSTRUCTIONS:
- Be clear, educational, and helpful
- Use examples when appropriate
- If the context doesn't have the answer, say so honestly
- Format with markdown for readability
CONTEXT:
{context}"""),
            ("human", "{question}"),
        ])
        llm = self._llm or self._get_llm(temperature=0.3)
        chain = prompt | llm
        response = await chain.ainvoke({"context": context, "question": question})
        return response.content, sources

    async def generate_summary(self, content: str, title: str) -> tuple[str, List[str]]:
        """Generate a summary with key points from lesson content.

        Returns (summary, key_points). If the model's output cannot be parsed
        as a JSON object, falls back to (fence-stripped raw text, []).
        """
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert educator. Create a concise summary and extract key learning points.
Return ONLY a valid JSON object (no markdown formatting, no code blocks) with:
- "summary": A 2-3 paragraph summary of the content
- "key_points": A list of 5-7 bullet points highlighting the most important concepts
Be educational and clear. Focus on what students need to remember.
IMPORTANT: Return ONLY the JSON object, do NOT wrap it in ```json code blocks."""),
            ("human", "Summarize this lesson titled '{title}':\n\n{content}"),
        ])
        llm = self._get_llm(temperature=0.2)
        chain = prompt | llm
        response = await chain.ainvoke({"title": title, "content": content})
        # Cleaned into its own name (previously this shadowed the `content`
        # parameter with the response text).
        cleaned = self._strip_code_fences(response.content)
        try:
            result = json.loads(cleaned)
        except json.JSONDecodeError as e:
            print(f"⚠️ JSON decode error: {e}")
            print(f"⚠️ Raw response: {response.content[:500]}")
            # Fallback: return the cleaned text as summary (previously the
            # raw, possibly fence-wrapped response was returned).
            return cleaned, []
        if not isinstance(result, dict):
            # Valid JSON of the wrong shape (e.g. a bare array) — previously
            # this raised an uncaught AttributeError on result.get().
            return cleaned, []
        return result.get("summary", ""), result.get("key_points", [])

    async def generate_mcqs(
        self,
        content: str,
        num_questions: int = 5,
        difficulty: str = "medium",
    ) -> List[GeneratedMCQ]:
        """Generate MCQ questions from content using Gemini.

        Returns an empty list if the model output cannot be parsed.
        """
        difficulty_instruction = {
            "easy": "Focus on basic recall and understanding",
            "medium": "Include application and analysis questions",
            "hard": "Focus on synthesis, evaluation, and edge cases",
        }.get(difficulty, "Include a mix of question difficulties")
        prompt = ChatPromptTemplate.from_messages([
            ("system", f"""You are an expert test creator. Generate {num_questions} multiple-choice questions based on the provided content.
DIFFICULTY LEVEL: {difficulty}
INSTRUCTION: {difficulty_instruction}
Return ONLY a valid JSON array (no markdown formatting, no code blocks) where each question has:
- "question": The question text
- "options": Array of 4 options (A, B, C, D)
- "correct_answer": The correct option letter and text (e.g., "A. Photosynthesis")
- "explanation": Brief explanation of why this is correct
Make questions educational and test real understanding, not just memorization.
IMPORTANT: Return ONLY the JSON array, do NOT wrap it in ```json code blocks."""),
            ("human", "{content}"),
        ])
        llm = self._get_llm(temperature=0.4)
        chain = prompt | llm
        response = await chain.ainvoke({"content": content})
        try:
            questions_data = json.loads(self._strip_code_fences(response.content))
            if isinstance(questions_data, dict):
                # Some responses wrap the array, e.g. {"questions": [...]}.
                questions_data = questions_data.get("questions", [])
            return [
                GeneratedMCQ(
                    question=q["question"],
                    options=q["options"],
                    correct_answer=q["correct_answer"],
                    explanation=q["explanation"],
                )
                for q in questions_data
            ]
        # TypeError added: covers non-list payloads / non-dict items, which
        # previously escaped the handler.
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            print(f"⚠️ Quiz generation error: {e}")
            print(f"⚠️ Raw response: {response.content[:500]}")
            return []

    async def generate_hint(
        self,
        question: str,
        context: Optional[str] = None,
        hint_level: int = 1,
    ) -> tuple[str, bool]:
        """Generate progressive hints for a question.

        hint_level 1-3 escalates from subtle to near-answer; unknown levels
        fall back to level 1. Returns (hint_text, more_hints_available).
        """
        hint_styles = {
            1: "Give a subtle hint that points the student in the right direction without revealing the answer. Be Socratic — ask guiding questions.",
            2: "Give a moderate hint that narrows down the possibilities. Mention relevant concepts without stating the answer directly.",
            3: "Give a strong hint that almost reveals the answer. Provide key information but still require the student to make the final connection.",
        }
        style = hint_styles.get(hint_level, hint_styles[1])
        prompt = ChatPromptTemplate.from_messages([
            ("system", f"""You are a patient tutor helping a student who is stuck.
HINT STYLE: {style}
Your goal is to help the student learn, not just give them the answer.
{f"Use this context if helpful: {context}" if context else ""}"""),
            ("human", "I need help with this: {question}"),
        ])
        llm = self._get_llm(temperature=0.5)
        chain = prompt | llm
        response = await chain.ainvoke({"question": question})
        has_more_hints = hint_level < 3
        return response.content, has_more_hints

    def _format_context(self, documents: List[Document]) -> str:
        """Format retrieved documents into a single context string."""
        if not documents:
            return "No relevant content found."
        parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "Unknown")
            parts.append(f"[Source {i}: {source}]\n{doc.page_content}")
        return "\n\n---\n\n".join(parts)

    def _format_sources(self, documents: List[Document]) -> List[SourceDocument]:
        """Convert documents to a SourceDocument list, deduplicated by source."""
        sources = []
        seen = set()
        for doc in documents:
            key = doc.metadata.get("source", "")
            if key not in seen:
                seen.add(key)
                sources.append(
                    SourceDocument(
                        content=doc.page_content[:300],  # preview only
                        source=doc.metadata.get("source", "Unknown"),
                        page=doc.metadata.get("page"),
                    )
                )
        return sources

    async def index_lesson_content(
        self,
        lesson_id: str,
        module_id: str,
        title: str,
        content_text: Optional[str] = None,
        content_url: Optional[str] = None,
        content_type: str = "markdown",
    ) -> bool:
        """Index lesson content into the vector store.

        Supports inline markdown text or a previously uploaded PDF (looked up
        in settings.UPLOAD_DIR by the URL's basename). Returns True on
        success, False when there is nothing usable to index or on error.
        """
        await self.initialize()
        # Extract content based on type
        text_to_index = ""
        if content_type == "markdown" and content_text:
            text_to_index = content_text
        elif content_type == "pdf" and content_url:
            # Extract text from PDF file
            file_path = os.path.join(settings.UPLOAD_DIR, os.path.basename(content_url))
            if os.path.exists(file_path):
                try:
                    reader = PdfReader(file_path)
                    pdf_parts = []
                    # Extract up to 50 pages for full indexing
                    for page in reader.pages[:50]:
                        text = page.extract_text()
                        if text:
                            pdf_parts.append(text)
                    text_to_index = "\n\n".join(pdf_parts)
                except Exception as e:
                    print(f"⚠️ PDF extraction error for {file_path}: {e}")
                    return False
        # Require a minimal amount of real text before indexing.
        if not text_to_index or len(text_to_index.strip()) < 50:
            print(f"⚠️ No content to index for lesson {lesson_id}")
            return False
        # Split text into chunks and tag each with retrieval-filter metadata.
        documents = []
        chunks = self.text_splitter.split_text(text_to_index)
        for i, chunk in enumerate(chunks):
            doc = Document(
                page_content=chunk,
                metadata={
                    "source": title,
                    "lesson_id": lesson_id,
                    "module_id": module_id,
                    "chunk_index": i,
                    "content_type": content_type,
                }
            )
            documents.append(doc)
        # Add documents to vector store
        try:
            self.vector_store.add_documents(documents)
            print(f"✅ Indexed {len(documents)} chunks for lesson '{title}'")
            return True
        except Exception as e:
            print(f"⚠️ Failed to index lesson {lesson_id}: {e}")
            return False
# Module-level singleton shared by the API layer; heavy setup is deferred
# until the first awaited initialize() call, so constructing it here is cheap.
tutor_service = AITutorService()