import os
import re
from typing import Annotated, List, TypedDict, Union
from typing_extensions import TypedDict
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv

# Load environment variables from a local .env file (get_llm reads GOOGLE_API_KEY).
load_dotenv()

# --- CONFIGURATION ---
CHROMA_PATH = "chroma_db"

# Lazy-loaded singletons
_embeddings = None
_vector_store = None
_llm = None  # NOTE(review): unused in this file; chat models are cached in _llm_cache.


def get_resources():
    """Return the shared Chroma vector store, creating it on first use.

    The HuggingFace embedding model ("all-MiniLM-L6-v2", on CPU) and the store
    persisted under CHROMA_PATH are each created once and reused afterwards.
    """
    global _embeddings, _vector_store
    if _embeddings is None:
        _embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2", model_kwargs={'device': 'cpu'}
        )
    if _vector_store is None:
        _vector_store = Chroma(
            collection_name="socratic_knowledge",
            embedding_function=_embeddings,
            persist_directory=CHROMA_PATH,
        )
    return _vector_store


def get_text_content(content: Union[str, List[Union[str, dict]]]) -> str:
    """Flatten a LangChain message ``content`` value into plain text.

    Args:
        content: Either a string, or a list of content parts. String parts and
            the ``"text"`` field of dict parts are concatenated in order; other
            parts (e.g. image blocks without ``"text"``) are skipped.

    Returns:
        The concatenated text. Any other value is converted with ``str()``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and "text" in part:
                pieces.append(part["text"] or "")
        return "".join(pieces)
    return str(content)


# --- STATE DEFINITION ---
class LearnerState(TypedDict):
    """State shared by every node of the tutoring graph."""

    # The Annotated metadata is only a label (not a reducer function), so the
    # nodes below return the full list: existing history plus the new message.
    messages: Annotated[List[BaseMessage], "The chat history"]
    context: str  # retrieved knowledge-base text for the latest message ("" if none)
    hint_level: int  # Socratic hint level; learner_node keeps it within 1..5
    safety_status: str  # "PASS" or "BLOCK", written by safety_node
    current_topic: str
    grade: str
    subject: str
    selected_topic: str
    status: str  # ACTIVE or COMPLETED


_llm_cache = {}


def get_llm(temperature=0.2, max_tokens=None):
    """Return a cached Gemini chat model for the given sampling configuration.

    One client is created per distinct ``(temperature, max_tokens)`` pair and
    reused on later calls.

    Args:
        temperature: Sampling temperature passed to the model.
        max_tokens: Maximum output tokens (``None`` leaves it unset).
    """
    cache_key = (temperature, max_tokens)
    llm = _llm_cache.get(cache_key)
    if llm is None:
        llm = ChatGoogleGenerativeAI(
            model="gemini-3.1-flash-lite",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=max_tokens,
        )
        _llm_cache[cache_key] = llm
    return llm


# --- NODES ---

# Short fillers that are always allowed without an LLM call.
_SAFE_WORDS = frozenset(
    {"hi", "hello", "hey", "thanks", "thank you", "ok", "yes", "no", "help"}
)
# Obvious jailbreak/toxic keywords. They are matched only at the start of a word,
# so "skill" no longer triggers "kill" (inflections such as "killing" still match).
_UNSAFE_KEYWORDS = ("ignore all", "system prompt", "hack", "bomb", "kill", "porn")
_UNSAFE_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(kw) for kw in _UNSAFE_KEYWORDS) + r")"
)


def safety_node(state: LearnerState):
    """Classify the latest user message as "PASS" or "BLOCK".

    Cheap checks run first: short fillers and single-word messages pass, and
    messages containing an unsafe keyword are blocked. Only the remaining
    messages are sent to a 2-token LLM classifier.

    Returns:
        ``{"safety_status": "PASS" | "BLOCK"}``.
    """
    print("\n[GRAPH DEBUG] Entering safety_node (high-speed hybrid)...")
    last_msg = get_text_content(state['messages'][-1].content).strip()
    last_msg_lower = last_msg.lower()

    # Fast Path 1: Instant PASS for common short fillers (and any one-word message).
    if last_msg_lower in _SAFE_WORDS or len(last_msg.split()) < 2:
        return {"safety_status": "PASS"}

    # Fast Path 2: Instant BLOCK for obvious jailbreak/toxic keywords.
    if _UNSAFE_PATTERN.search(last_msg_lower):
        print("[GRAPH DEBUG] Safety Check Result: BLOCK (Keyword match)")
        return {"safety_status": "BLOCK"}

    # Fast Path 3: Minimal LLM check for complex queries.
    # Limit to 2 tokens for maximum speed (enough for PASS/BLOCK).
    llm = get_llm(temperature=0.0, max_tokens=2)
    prompt = f"Is this query safe and for school? Query: '{last_msg}'. Reply PASS or BLOCK only."
    try:
        response = llm.invoke(prompt)
    except Exception as e:
        # Fail-open: an LLM/API error lets the message through.
        # NOTE(review): confirm fail-open is the intended policy for this filter.
        print(f"[GRAPH DEBUG] Safety Node Error: {e}")
        return {"safety_status": "PASS"}

    result = get_text_content(response.content).strip().upper()
    status = "BLOCK" if "BLOCK" in result else "PASS"
    print(f"[GRAPH DEBUG] Safety Check Result: {status}")
    return {"safety_status": status}


_BLOCKED_WARNING = (
    "⚠️ **Safety Warning:** This query has been flagged as off-topic or inappropriate "
    "for this educational session. I am here to help you learn your school subjects—"
    "please try asking a question related to the current topic!"
)


def blocked_node(state: LearnerState):
    """Append a fixed safety warning to the chat history for a blocked message."""
    warning = AIMessage(content=_BLOCKED_WARNING)
    return {"messages": state['messages'] + [warning]}


def retriever_node(state: LearnerState):
    """Fetch up to two knowledge-base chunks relevant to the latest user message.

    The search is restricted to documents whose ``grade`` and ``subject``
    metadata match the session, to avoid cross-subject results.

    Returns:
        ``{"context": text}``, where chunks are joined with ``"\\n---\\n"``.
        The context is ``""`` for one-word messages, when grade or subject is
        missing, or when the vector-store lookup fails.
    """
    last_msg = get_text_content(state['messages'][-1].content)

    # Skip RAG for one-word filler.
    if len(last_msg.split()) < 2:
        return {"context": ""}

    grade = state.get('grade')
    subject = state.get('subject')
    if grade is None or subject is None:
        # A None metadata value is not a usable filter; skip instead of erroring.
        print("[RAG DEBUG] Missing grade/subject; skipping retrieval.")
        return {"context": ""}

    try:
        vector_store = get_resources()

        # Metadata Filtering: Only search within the selected Grade and Subject.
        # ChromaDB requires $and for multiple conditions.
        search_filter = {
            "$and": [
                {"grade": grade},
                {"subject": subject},
            ]
        }

        results = vector_store.similarity_search(last_msg, k=2, filter=search_filter)
        context = "\n---\n".join(r.page_content for r in results)

        # DEBUG LOG for the terminal
        print(f"\n[RAG DEBUG] Query: '{last_msg}'")
        print(f"[RAG DEBUG] Found {len(results)} relevant chunks.")
        if results:
            print(f"[RAG DEBUG] Top Chunk Source: {results[0].metadata.get('source', 'Unknown')}")

        return {"context": context}
    except Exception as e:
        # Best-effort retrieval: tutoring continues without context on failure.
        print(f"[RAG ERROR] {e}")
        return {"context": ""}


# "[Hint Level]: 3", "[Hint Level]: L3", "[Hint Level]: 3/5" -> 3 (first number on that line).
_HINT_LEVEL_RE = re.compile(r"\[Hint Level\]:[^\d\n]*(\d+)")
# Compressed format, e.g. "[L3]".
_SHORT_LEVEL_RE = re.compile(r"\[L(\d)\]")
# One leading bracketed metadata field with an optional trailing comma, e.g. "[PASS], ".
_HEADER_FIELD_RE = re.compile(r"\[[^\[\]\n]*\]\s*,?\s*")


def _parse_hint_level(content: str, default: int) -> int:
    """Return the hint level reported by the model (or *default*), clamped to 1..5."""
    match = _HINT_LEVEL_RE.search(content) or _SHORT_LEVEL_RE.search(content)
    level = int(match.group(1)) if match else default
    return max(1, min(5, level))


def _clean_response(content: str) -> str:
    """Strip the bracketed metadata prefix from the model output, keeping the reply text."""
    if "[Response]:" in content:
        cleaned = content.split("[Response]:")[-1].strip()
    else:
        cleaned = content.strip()
        stripped_header = False
        # Drop up to three leading metadata fields ("[PASS], [ACTIVE], [L2], ..."),
        # never removing the final bracketed field, which may be the reply itself.
        for _ in range(3):
            match = _HEADER_FIELD_RE.match(cleaned)
            if not match or not cleaned[match.end():].strip():
                break
            cleaned = cleaned[match.end():].strip()
            stripped_header = True
        # After a header, the reply may itself be bracketed: "[Actual Response]".
        if stripped_header and cleaned.startswith("[") and cleaned.endswith("]"):
            cleaned = cleaned[1:-1].strip()
    # Never emit an empty AI message; fall back to the raw model text.
    return cleaned or content.strip()


def learner_node(state: LearnerState):
    """Produce the next Socratic tutoring reply.

    Builds a system instruction from grade, subject, topic, hint level and the
    retrieved context, then sends it with the last 8 chat messages to the LLM.
    The model is asked to prefix its reply with bracketed metadata. The reported
    hint level is parsed into state, and the metadata is stripped before the
    reply is appended to the history.

    Returns:
        A dict with the updated ``messages``, ``hint_level`` (1..5) and ``status``.
    """
    print("[GRAPH DEBUG] Entering learner_node...")
    # Budget 500 tokens for the response to keep it fast
    llm = get_llm(temperature=0.2, max_tokens=500)

    selected_topic = state.get('selected_topic', 'General')
    hint_level = state.get('hint_level', 1)
    context = state.get('context', '')

    # Condensed pedagogical instructions for faster processing
    system_instruction = f"""Socratic {state.get('grade')} {state.get('subject')} Tutor. Topic: {selected_topic}. Hint Level: {hint_level}/5.
STRATEGY:
- L1: High-level memory nudge. No facts.
- L2: Point to specific concept or context part.
- L3: Partial scaffold or 'fill-in-blank' prompt.
- L4: Explain core logic, but student concludes.
- L5: Full explanation only if stuck.
RULES: No direct answers L1-L3. Use Context. If correct, say "Good work!", reset to L1, and ask what's next.
CONTEXT: {context}
FORMAT: [Safety Status], [Status], [Hint Level], [Response]."""

    # Pass the instruction as a literal SystemMessage: a plain template string
    # would treat any "{...}" in the context or state values as a prompt variable.
    chat_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=system_instruction),
        MessagesPlaceholder(variable_name="messages"),
    ])

    # Use last 8 messages for context (balanced for speed vs awareness).
    history = list(state['messages'][-8:])

    # Google API requires alternating roles and often fails if it starts with an
    # AIMessage after the SystemMessage, so prepend a placeholder human turn.
    if history and history[0].type == "ai":
        history.insert(0, HumanMessage(content="[Conversation Started]"))

    chain = chat_prompt | llm
    response = chain.invoke({"messages": history})
    print("[GRAPH DEBUG] Tutoring response received.")

    content = get_text_content(response.content)
    new_level = _parse_hint_level(content, hint_level)

    # The session is deliberately kept ACTIVE even if the model reports
    # COMPLETED, so the student can ask a final follow-up.
    status = "ACTIVE"

    clean_msg = AIMessage(content=_clean_response(content))
    return {"messages": state['messages'] + [clean_msg], "hint_level": new_level, "status": status}


def route_next(state: LearnerState):
    """Always finish the graph run. (Not referenced by create_learner_graph.)"""
    return END


def route_safety(state: LearnerState):
    """Route to "blocked" when safety_node flagged the message, else to "safe"."""
    if state.get("safety_status") == "BLOCK":
        return "blocked"
    return "safe"


def create_learner_graph():
    """Build and compile the tutoring graph.

    The flow is ``safety`` -> (``blocked`` | ``retrieve`` -> ``learner``) -> END.
    """
    workflow = StateGraph(LearnerState)

    workflow.add_node("safety", safety_node)
    workflow.add_node("blocked", blocked_node)
    workflow.add_node("retrieve", retriever_node)
    workflow.add_node("learner", learner_node)

    # Start with the safety check
    workflow.set_entry_point("safety")

    # Conditional routing based on safety result
    workflow.add_conditional_edges(
        "safety",
        route_safety,
        {
            "blocked": "blocked",
            "safe": "retrieve",
        },
    )

    # Normal flow if safe
    workflow.add_edge("retrieve", "learner")
    workflow.add_edge("learner", END)

    # End after showing the warning
    workflow.add_edge("blocked", END)

    return workflow.compile()


learner_app = create_learner_graph()