Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Annotated, List, TypedDict, Union | |
| from typing_extensions import TypedDict | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_chroma import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langgraph.graph import StateGraph, END | |
| from dotenv import load_dotenv | |
# Read a local .env file into the process environment before anything
# below consults os.getenv (e.g. GOOGLE_API_KEY in get_llm).
load_dotenv()

# --- CONFIGURATION ---
# Directory passed to Chroma as its persist_directory.
CHROMA_PATH = "chroma_db"

# Lazy-loaded singletons, all starting out unset. get_resources() fills in
# the first two on demand; _llm is not assigned anywhere else in this file.
_embeddings = _vector_store = _llm = None
def get_resources():
    """Return the shared Chroma vector store, building it on first use.

    The embedding model and the store are kept in module-level singletons
    so that repeated calls reuse the same objects instead of reloading them.

    Returns:
        The ``Chroma`` instance bound to the ``socratic_knowledge``
        collection persisted under ``CHROMA_PATH``.
    """
    global _embeddings, _vector_store

    if _embeddings is None:
        # Sentence-embedding model, pinned to CPU.
        _embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
        )

    if _vector_store is None:
        store_options = {
            "collection_name": "socratic_knowledge",
            "embedding_function": _embeddings,
            "persist_directory": CHROMA_PATH,
        }
        _vector_store = Chroma(**store_options)

    return _vector_store
def get_text_content(content: Union[str, List[Union[str, dict]]]) -> str:
    """Flatten a chat-message ``content`` payload into plain text.

    Args:
        content: Either a plain string, or a list of content parts. Each
            part may be a bare string or a dict; only dicts carrying a
            string ``"text"`` entry contribute to the result.

    Returns:
        The concatenated text. Parts without usable text (e.g. image
        blocks) are skipped. Any other input type is converted with
        ``str()``.
    """
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                # Bare string parts are valid content and must not be dropped.
                pieces.append(part)
            elif isinstance(part, dict):
                text = part.get("text")
                # Guard against {"text": None}, which would break str.join.
                if isinstance(text, str):
                    pieces.append(text)
        return "".join(pieces)

    return str(content)
| # --- STATE DEFINITION --- | |
class LearnerState(TypedDict):
    """State dictionary passed between the nodes of the learner graph."""

    # Full conversation so far. The Annotated metadata here is only a
    # descriptive string; the nodes in this file return the complete list
    # (old messages plus the new one) whenever they update it.
    messages: Annotated[List[BaseMessage], "The chat history"]
    # Retrieved knowledge-base text, written by retriever_node and embedded
    # in the tutor's system prompt ("" when retrieval is skipped or fails).
    context: str
    # Current hint level; learner_node clamps it to the range 1..5.
    hint_level: int
    # "PASS" or "BLOCK", written by safety_node and read by route_safety.
    safety_status: str
    # Not read or written by any node in this file.
    current_topic: str
    # Grade and subject: used as the retrieval metadata filter and in the
    # tutor's system prompt.
    grade: str
    subject: str
    # Topic named in the system prompt (learner_node defaults to 'General').
    selected_topic: str
    status: str  # ACTIVE or COMPLETED (learner_node currently only writes "ACTIVE")
# Chat-model instances, keyed by their (temperature, max_tokens) settings.
_llm_cache = {}


def get_llm(temperature=0.2, max_tokens=None):
    """Return a cached Gemini chat model for the given sampling settings.

    Args:
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Forwarded as ``max_output_tokens``; ``None`` leaves it
            unset.

    Returns:
        The ``ChatGoogleGenerativeAI`` instance for this configuration.
        One instance is created per distinct ``(temperature, max_tokens)``
        pair and reused on later calls.
    """
    settings = (temperature, max_tokens)
    try:
        return _llm_cache[settings]
    except KeyError:
        pass

    # First request for this configuration: build the client and remember it.
    model = ChatGoogleGenerativeAI(
        model="gemini-3.1-flash-lite",
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        temperature=temperature,
        max_output_tokens=max_tokens,
    )
    _llm_cache[settings] = model
    return model
| # --- NODES --- | |
def safety_node(state: LearnerState):
    """Classify the latest message as safe ("PASS") or unsafe ("BLOCK").

    Three stages, cheapest first:

    1. Keyword blocklist -> immediate BLOCK.
    2. Known filler phrases and empty/one-word messages -> immediate PASS.
    3. A tiny LLM classification for everything else.

    The blocklist runs *before* the short-message pass so that a single
    unsafe word (e.g. "bomb") can no longer slip through the one-word
    fast path.

    Args:
        state: Graph state; only the last entry of ``state['messages']``
            is inspected.

    Returns:
        ``{"safety_status": "PASS"}`` or ``{"safety_status": "BLOCK"}``.
    """
    import re  # function-scope import: `re` is not imported at module level

    print("\n[GRAPH DEBUG] Entering safety_node (high-speed hybrid)...")
    last_msg = get_text_content(state['messages'][-1].content).strip()
    last_msg_lower = last_msg.lower()

    # Fast Path 1: instant BLOCK for obvious jailbreak/toxic keywords.
    # The leading \b anchors each keyword to the start of a word, so
    # "killing" and "hacker" still match while "skill" and "shack" do not.
    unsafe_pattern = r"\b(?:ignore\s+all|system\s+prompt|hack|bomb|kill|porn)"
    if re.search(unsafe_pattern, last_msg_lower):
        print("[GRAPH DEBUG] Safety Check Result: BLOCK (Keyword match)")
        return {"safety_status": "BLOCK"}

    # Fast Path 2: instant PASS for common fillers and empty/one-word input.
    safe_words = {"hi", "hello", "hey", "thanks", "thank you", "ok", "yes", "no", "help"}
    if last_msg_lower in safe_words or len(last_msg.split()) < 2:
        return {"safety_status": "PASS"}

    # Fast Path 3: minimal LLM check for everything else.
    # Output is capped at 2 tokens to keep the round trip short.
    llm = get_llm(temperature=0.0, max_tokens=2)
    # NOTE(review): the user's text is interpolated verbatim into the
    # classifier prompt, so a crafted message can try to steer the verdict.
    prompt = f"Is this query safe and for school? Query: '{last_msg}'. Reply PASS or BLOCK only."
    try:
        response = llm.invoke(prompt)
        result = get_text_content(response.content).strip().upper()
        # Anything that is not an explicit BLOCK is treated as PASS.
        status = "BLOCK" if "BLOCK" in result else "PASS"
        print(f"[GRAPH DEBUG] Safety Check Result: {status}")
        return {"safety_status": status}
    except Exception as e:
        # Deliberately fail open: a classifier outage should not lock
        # learners out of the tutor.
        print(f"[GRAPH DEBUG] Safety Node Error: {e}")
        return {"safety_status": "PASS"}
def blocked_node(state: LearnerState):
    """Reply to a blocked query with a fixed safety warning.

    Returns:
        A state update whose ``messages`` is the existing history followed
        by the warning as an ``AIMessage``.
    """
    notice_text = "⚠️ **Safety Warning:** This query has been flagged as off-topic or inappropriate for this educational session. I am here to help you learn your school subjects—please try asking a question related to the current topic!"
    history = state['messages']
    return {"messages": history + [AIMessage(content=notice_text)]}
def retriever_node(state: LearnerState):
    """Fetch knowledge-base passages relevant to the latest message.

    Args:
        state: Graph state. Reads the last entry of ``state['messages']``
            plus the optional ``grade`` and ``subject`` fields.

    Returns:
        ``{"context": text}`` where ``text`` is the top matching chunks
        joined by ``"\\n---\\n"``, or ``""`` when retrieval is skipped
        (empty or one-word message) or fails.
    """
    last_msg = get_text_content(state['messages'][-1].content)
    # Skip RAG for empty or single-word messages (greetings, "ok", ...).
    if len(last_msg.split()) < 2:
        return {"context": ""}
    try:
        vector_store = get_resources()

        # Metadata filtering: restrict the search to the selected grade and
        # subject to save tokens and prevent cross-talk between subjects.
        # Only fields that are actually set are included, because a None
        # value is not a valid filter value and would make every search
        # fail (and be swallowed below, leaving the context empty).
        conditions = [
            {field: state[field]}
            for field in ("grade", "subject")
            if state.get(field)
        ]
        if len(conditions) > 1:
            # ChromaDB requires $and to combine multiple conditions.
            search_filter = {"$and": conditions}
        elif conditions:
            search_filter = conditions[0]
        else:
            search_filter = None

        results = vector_store.similarity_search(
            last_msg,
            k=2,
            filter=search_filter,
        )
        context = "\n---\n".join(r.page_content for r in results)

        # Debug log for the terminal.
        print(f"\n[RAG DEBUG] Query: '{last_msg}'")
        print(f"[RAG DEBUG] Found {len(results)} relevant chunks.")
        if results:
            print(f"[RAG DEBUG] Top Chunk Source: {results[0].metadata.get('source', 'Unknown')}")
        return {"context": context}
    except Exception as e:
        # Best effort: the tutor can still answer without retrieved context.
        print(f"[RAG ERROR] {e}")
        return {"context": ""}
def _parse_hint_level(content, fallback):
    """Return the hint level reported in *content*, clamped to 1..5."""
    import re  # function-scope import: `re` is not imported at module level

    # Old format: "[Hint Level]: 2" (also tolerates "L2" or "2/5").
    # Only the first number after the tag is used, so "2/5" yields 2
    # rather than 25.
    match = re.search(r"\[Hint Level\]:[^\d\n]*(\d+)", content)
    if match is None:
        # Compressed format: "[L1]", "[L2]", ...
        match = re.search(r"\[L(\d)\]", content)
    level = int(match.group(1)) if match else fallback
    return max(1, min(5, level))


def _clean_response(content):
    """Strip the bracketed status header so only the tutor's reply remains."""
    import re  # function-scope import: `re` is not imported at module level

    if "[Response]:" in content:
        return content.split("[Response]:")[-1].strip()

    # Header forms handled here, anchored at the start of the text so that
    # brackets inside the reply itself are never treated as header tags:
    #   [Safe], [Active], [L1], reply   (three tags, commas optional)
    #   [L1], reply                     (one to three comma-terminated tags)
    header = re.match(
        r"\s*(?:(?:\[[^\[\]\n]*\]\s*,?\s*){3}|(?:\[[^\[\]\n]*\]\s*,\s*){1,3})(?=\S)",
        content,
    )
    if header is None:
        return content

    reply = content[header.end():].strip()
    # The model sometimes wraps the reply itself: "[Actual Response]".
    wrapped = re.fullmatch(r"\[([^\[\]]*)\]", reply)
    return wrapped.group(1).strip() if wrapped else reply


def learner_node(state: LearnerState):
    """Generate the tutor's next Socratic reply and update the hint level.

    Builds a system prompt from the grade, subject, topic, hint level and
    retrieved context, sends it with the recent chat history to the LLM,
    then parses the structured reply.

    Args:
        state: Graph state. Reads ``messages``, ``context``, ``hint_level``,
            ``grade``, ``subject`` and ``selected_topic``.

    Returns:
        A state update with the cleaned reply appended to ``messages``, the
        (possibly changed) ``hint_level`` clamped to 1..5, and ``status``.
    """
    print("[GRAPH DEBUG] Entering learner_node...")
    # Budget 500 tokens for the response to keep it fast.
    llm = get_llm(temperature=0.2, max_tokens=500)
    selected_topic = state.get('selected_topic', 'General')
    current_level = state.get('hint_level', 1)
    context = state.get('context', '')

    # Condensed pedagogical instructions for faster processing.
    system_instruction = f"""Socratic {state.get('grade')} {state.get('subject')} Tutor. Topic: {selected_topic}. Hint Level: {current_level}/5.
STRATEGY:
- L1: High-level memory nudge. No facts.
- L2: Point to specific concept or context part.
- L3: Partial scaffold or 'fill-in-blank' prompt.
- L4: Explain core logic, but student concludes.
- L5: Full explanation only if stuck.
RULES: No direct answers L1-L3. Use Context. If correct, say "Good work!", reset to L1, and ask what's next.
CONTEXT: {context}
FORMAT: [Safety Status], [Status], [Hint Level], [Response]."""

    # The instruction is passed as a ready-made SystemMessage rather than a
    # ("system", text) tuple: tuple entries are parsed as templates, so any
    # "{" or "}" in the retrieved context (math, code, JSON) would be
    # mistaken for a template variable and break the call.
    chat_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=system_instruction),
        MessagesPlaceholder(variable_name="messages"),
    ])

    # Use the last 8 messages for context (balanced for speed vs awareness).
    # Slicing copies the list, so the insert below never touches the state.
    history = state['messages'][-8:]
    # Google API requires alternating roles and often fails if the sequence
    # starts with an AIMessage right after the SystemMessage; prepend a
    # placeholder HumanMessage in that case.
    if history and history[0].type == "ai":
        history.insert(0, HumanMessage(content="[Conversation Started]"))

    chain = chat_prompt | llm
    response = chain.invoke({"messages": history})
    print("[GRAPH DEBUG] Tutoring response received.")
    content = get_text_content(response.content)

    new_level = _parse_hint_level(content, current_level)
    # The session stays ACTIVE even when the model reports COMPLETED, so the
    # learner can still send a final follow-up.
    status = "ACTIVE"

    # Clean the response for the UI.
    clean_msg = AIMessage(content=_clean_response(content))
    return {"messages": state['messages'] + [clean_msg], "hint_level": new_level, "status": status}
def route_next(state: LearnerState):
    """Routing hook that always ends the run, regardless of *state*."""
    next_step = END
    return next_step
def route_safety(state: LearnerState):
    """Choose the branch to follow after the safety check.

    Returns:
        ``"blocked"`` when ``safety_status`` is exactly ``"BLOCK"``;
        ``"safe"`` for any other value, including a missing key.
    """
    is_blocked = state.get("safety_status") == "BLOCK"
    return "blocked" if is_blocked else "safe"
def create_learner_graph():
    """Assemble and compile the tutoring graph.

    Flow: ``safety`` -> (``blocked`` -> END) or
    (``retrieve`` -> ``learner`` -> END).

    Returns:
        The compiled graph.
    """
    graph = StateGraph(LearnerState)

    node_table = (
        ("safety", safety_node),
        ("blocked", blocked_node),
        ("retrieve", retriever_node),
        ("learner", learner_node),
    )
    for node_name, handler in node_table:
        graph.add_node(node_name, handler)

    # Every run starts with the safety check.
    graph.set_entry_point("safety")

    # Branch on the safety verdict.
    branch_targets = {"blocked": "blocked", "safe": "retrieve"}
    graph.add_conditional_edges("safety", route_safety, branch_targets)

    # Normal flow when the message is safe.
    graph.add_edge("retrieve", "learner")

    # Both the tutor's reply and the safety warning finish the run.
    for terminal_node in ("learner", "blocked"):
        graph.add_edge(terminal_node, END)

    return graph.compile()


# Compiled once at import time.
learner_app = create_learner_graph()