File size: 9,522 Bytes
a10a6c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import os
import re
from typing import Annotated, List, TypedDict, Union

from typing_extensions import TypedDict
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv

# Load environment variables (python-dotenv); get_llm() later reads
# GOOGLE_API_KEY from the process environment.
load_dotenv()

# --- CONFIGURATION ---
# Handed to Chroma as `persist_directory` in get_resources().
# NOTE(review): relative path - presumably resolved against the working
# directory of the process; confirm for deployed environments.
CHROMA_PATH = "chroma_db"

# Lazy-loaded singletons
# _embeddings and _vector_store start empty and are filled in by
# get_resources() the first time it is called.
_embeddings = None
_vector_store = None
# NOTE(review): _llm is never read or assigned by the code visible in this
# module; get_llm() keeps its instances in _llm_cache instead.
_llm = None

def get_resources():
    """Return the shared Chroma vector store, creating it on first use.

    The embedding model and the vector store are each built at most once and
    kept in the module-level ``_embeddings`` / ``_vector_store`` globals, so
    repeated calls return the same store object.

    Returns:
        The ``Chroma`` store for the "socratic_knowledge" collection persisted
        under ``CHROMA_PATH``.
    """
    global _embeddings, _vector_store

    if _embeddings is None:
        _embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
        )

    if _vector_store is None:
        store_options = {
            "collection_name": "socratic_knowledge",
            "embedding_function": _embeddings,
            "persist_directory": CHROMA_PATH,
        }
        _vector_store = Chroma(**store_options)

    return _vector_store

def get_text_content(content: Union[str, List[Union[str, dict]]]) -> str:
    """Flatten message content into a single plain string.

    Message content may be a plain string or a list of parts, where each part
    is either a bare string or a dict such as ``{"type": "text", "text": ...}``.

    Args:
        content: The ``.content`` value of a chat message.

    Returns:
        * ``content`` itself when it is already a string;
        * for a list, the concatenation of every string part and every dict
          part's string ``"text"`` value (non-text parts are skipped);
        * ``str(content)`` for any other type.
    """
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                # Bare string parts are valid content and must not be dropped.
                pieces.append(part)
            elif isinstance(part, dict):
                text = part.get("text")
                # Skip parts without text (images, tool calls) and guard
                # against a non-string "text" value that would break join().
                if isinstance(text, str):
                    pieces.append(text)
        return "".join(pieces)

    return str(content)

# --- STATE DEFINITION ---
class LearnerState(TypedDict):
    """State dictionary handed from node to node in the learner graph.

    Every node receives the whole state and returns a partial dict containing
    only the keys it updates.
    """
    # Conversation so far; the nodes treat the last entry as the learner's
    # latest message.
    messages: Annotated[List[BaseMessage], "The chat history"]
    # Retrieved chunk text joined with "\n---\n" by retriever_node; "" when
    # retrieval is skipped or fails.
    context: str
    # Current hint level; learner_node clamps the value it writes to 1-5.
    hint_level: int
    # "PASS" or "BLOCK", written by safety_node and read by route_safety.
    safety_status: str
    # NOTE(review): not read or written by any node visible in this module.
    current_topic: str
    # grade and subject are used by retriever_node as metadata filters and by
    # learner_node in the tutor's system prompt.
    grade: str
    subject: str
    # Topic named in the system prompt; learner_node falls back to 'General'
    # when the key is absent.
    selected_topic: str
    status: str # ACTIVE or COMPLETED

# Chat model clients created so far, keyed by (temperature, max_tokens).
_llm_cache = {}


def get_llm(temperature=0.2, max_tokens=None):
    """Return a chat model client for the given settings, reusing earlier ones.

    Args:
        temperature: Sampling temperature passed to the model.
        max_tokens: Passed as ``max_output_tokens``; ``None`` leaves it unset.

    Returns:
        The ``ChatGoogleGenerativeAI`` instance cached for this exact
        ``(temperature, max_tokens)`` pair, built on the first request.
    """
    settings = (temperature, max_tokens)

    # Reuse the client if these exact settings were requested before.
    if settings in _llm_cache:
        return _llm_cache[settings]

    client = ChatGoogleGenerativeAI(
        model="gemini-3.1-flash-lite",
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        temperature=temperature,
        max_output_tokens=max_tokens,
    )
    _llm_cache[settings] = client
    return client

# --- NODES ---

# Jailbreak/toxic terms that trigger an immediate BLOCK. Each term is anchored
# to the start of a word, so inflections such as "hacking" or "bombs" still
# match while ordinary vocabulary that merely contains a term ("skills",
# "shack") does not.
_UNSAFE_TERMS = re.compile(r"\b(?:ignore\s+all|system\s+prompt|hack|bomb|kill|porn)")


def safety_node(state: LearnerState):
    """Classify the learner's latest message as PASS or BLOCK.

    Checks run cheapest first:

    1. Keyword screen - any unsafe term blocks immediately.
    2. Shortcut - known filler words and one-word messages pass immediately.
    3. LLM check - everything else is judged by a 2-token model call.

    The keyword screen runs before the shortcut on purpose: otherwise a
    one-word message consisting of an unsafe term would be waved through.

    Args:
        state: Graph state; only the last entry of ``state['messages']`` is read.

    Returns:
        A partial state update ``{"safety_status": "PASS"}`` or
        ``{"safety_status": "BLOCK"}``.
    """
    print("\n[GRAPH DEBUG] Entering safety_node (high-speed hybrid)...")
    last_msg = get_text_content(state['messages'][-1].content).strip()
    last_msg_lower = last_msg.lower()

    # Fast Path 1: Instant BLOCK for obvious jailbreak/toxic keywords
    if _UNSAFE_TERMS.search(last_msg_lower):
        print("[GRAPH DEBUG] Safety Check Result: BLOCK (Keyword match)")
        return {"safety_status": "BLOCK"}

    # Fast Path 2: Instant PASS for common short fillers and one-word messages
    safe_words = {"hi", "hello", "hey", "thanks", "thank you", "ok", "yes", "no", "help"}
    if last_msg_lower in safe_words or len(last_msg.split()) < 2:
        return {"safety_status": "PASS"}

    # Slow Path: Minimal LLM check for complex queries
    # Limit to 2 tokens for maximum speed
    llm = get_llm(temperature=0.0, max_tokens=2)
    prompt = f"Is this query safe and for school? Query: '{last_msg}'. Reply PASS or BLOCK only."

    try:
        response = llm.invoke(prompt)
        result = get_text_content(response.content).strip().upper()
        # Anything that is not an explicit BLOCK counts as PASS.
        status = "BLOCK" if "BLOCK" in result else "PASS"
        print(f"[GRAPH DEBUG] Safety Check Result: {status}")
        return {"safety_status": status}
    except Exception as e:
        # NOTE(review): errors fail open (PASS), presumably so that a model
        # outage does not stop tutoring - confirm this is the intended policy.
        print(f"[GRAPH DEBUG] Safety Node Error: {e}")
        return {"safety_status": "PASS"}

def blocked_node(state: LearnerState):
    """Answer a flagged message with a fixed safety warning.

    Args:
        state: Graph state; ``state['messages']`` is the conversation so far.

    Returns:
        A partial state update whose ``"messages"`` is the existing history
        followed by one ``AIMessage`` carrying the warning text.
    """
    warning_text = "⚠️ **Safety Warning:** This query has been flagged as off-topic or inappropriate for this educational session. I am here to help you learn your school subjects—please try asking a question related to the current topic!"
    notice = AIMessage(content=warning_text)
    updated_history = state['messages'] + [notice]
    return {"messages": updated_history}

def retriever_node(state: LearnerState):
    """Look up knowledge-base context for the learner's latest message.

    Args:
        state: Graph state; reads the last message plus ``grade`` and
            ``subject``.

    Returns:
        ``{"context": text}`` where ``text`` is the content of up to two
        matching chunks joined by ``"\\n---\\n"``. The context is ``""`` when
        the message has fewer than two words or when retrieval raises.
    """
    query = get_text_content(state['messages'][-1].content)

    # Messages of fewer than two words (greetings, "ok", ...) skip retrieval.
    word_count = len(query.split())
    if word_count < 2:
        return {"context": ""}

    try:
        store = get_resources()

        # Restrict the search to the selected grade and subject so chunks from
        # other subjects are not returned. Chroma needs an explicit $and to
        # combine more than one condition.
        grade_clause = {"grade": state.get('grade')}
        subject_clause = {"subject": state.get('subject')}
        metadata_filter = {"$and": [grade_clause, subject_clause]}

        hits = store.similarity_search(query, k=2, filter=metadata_filter)
        context = "\n---\n".join(hit.page_content for hit in hits)

        # Terminal diagnostics
        print(f"\n[RAG DEBUG] Query: '{query}'")
        print(f"[RAG DEBUG] Found {len(hits)} relevant chunks.")
        if hits:
            top_source = hits[0].metadata.get('source', 'Unknown')
            print(f"[RAG DEBUG] Top Chunk Source: {top_source}")
    except Exception as e:
        # Best effort: tutoring continues without context if retrieval fails.
        print(f"[RAG ERROR] {e}")
        return {"context": ""}

    return {"context": context}

# Hint level written after an explicit label on the same line, e.g.
# "[Hint Level]: 2", "[Hint Level]: L3" or "[Hint Level]: 2/5".
_HINT_LABEL_RE = re.compile(r"\[Hint Level\]:[^\d\n]*(\d+)")
# Compact tag form, e.g. "[L2]".
_HINT_TAG_RE = re.compile(r"\[L(\d)\]")
# The three metadata tags that open the compact reply format:
# "[Safe], [Active], [L1], ...". Commas between the tags are optional.
_LEADING_TAGS_RE = re.compile(r"\s*(?:\[[^\[\]]*\]\s*,?\s*){3}")
# A reply that is itself wrapped in one pair of brackets: "[Actual Response]".
_BRACKETED_RE = re.compile(r"\[([^\[\]]*)\]")


def _parse_hint_level(content: str, current: int) -> int:
    """Return the hint level announced in *content* (or *current*), clamped to 1-5."""
    match = _HINT_LABEL_RE.search(content) or _HINT_TAG_RE.search(content)
    level = int(match.group(1)) if match else current
    return max(1, min(5, level))


def _extract_response_text(content: str) -> str:
    """Strip the metadata tags from a raw model reply, keeping the learner-facing text."""
    if "[Response]:" in content:
        cleaned = content.split("[Response]:", 1)[1].strip()
    else:
        tags = _LEADING_TAGS_RE.match(content)
        if tags is None:
            return content
        # Only the leading tags are removed, so brackets inside the reply
        # itself (lists, intervals, array indices) are left intact.
        cleaned = content[tags.end():].strip()
        wrapped = _BRACKETED_RE.fullmatch(cleaned)
        if wrapped:
            cleaned = wrapped.group(1).strip()
    # Never hand the UI an empty message; fall back to the raw reply.
    return cleaned or content


def learner_node(state: LearnerState):
    """Generate the tutor's next Socratic reply and update the hint level.

    Builds a system prompt from the grade, subject, topic, hint level and
    retrieved context, sends it with the last 8 messages to the model, then
    parses the reply's metadata tags.

    Args:
        state: Graph state. Reads ``messages``, ``context``, ``hint_level``
            (treated as 1 when missing), ``grade``, ``subject`` and
            ``selected_topic`` (``'General'`` when missing).

    Returns:
        A partial state update with:
        * ``"messages"``: the history plus one ``AIMessage`` holding the
          cleaned reply;
        * ``"hint_level"``: the level announced by the model, clamped to 1-5,
          or the current level if none was announced;
        * ``"status"``: always ``"ACTIVE"``.
    """
    print("[GRAPH DEBUG] Entering learner_node...")
    # Budget 500 tokens for the response to keep it fast
    llm = get_llm(temperature=0.2, max_tokens=500)

    selected_topic = state.get('selected_topic', 'General')
    current_level = state.get('hint_level') or 1
    context = state.get('context') or ""

    # Condensed pedagogical instructions for faster processing
    system_instruction = f"""Socratic {state.get('grade')} {state.get('subject')} Tutor. Topic: {selected_topic}. Hint Level: {current_level}/5.
STRATEGY:
- L1: High-level memory nudge. No facts.
- L2: Point to specific concept or context part.
- L3: Partial scaffold or 'fill-in-blank' prompt.
- L4: Explain core logic, but student concludes.
- L5: Full explanation only if stuck.
RULES: No direct answers L1-L3. Use Context. If correct, say "Good work!", reset to L1, and ask what's next.
CONTEXT: {context}
FORMAT: [Safety Status], [Status], [Hint Level], [Response]."""

    # The instruction is already fully rendered, so it goes in as a literal
    # SystemMessage. Passing it as a ("system", text) template would treat any
    # "{" or "}" in the retrieved context as a template variable and fail.
    chat_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=system_instruction),
        MessagesPlaceholder(variable_name="messages"),
    ])

    # Use last 8 messages for context (balanced for speed vs awareness).
    # The slice is a copy, so the insert below never touches the stored state.
    history = list(state['messages'][-8:])

    # Google API requires alternating roles and often fails if it starts with an AIMessage after the SystemMessage.
    # Add a dummy HumanMessage if needed to ensure the sequence is valid.
    if history and history[0].type == "ai":
        history.insert(0, HumanMessage(content="[Conversation Started]"))

    chain = chat_prompt | llm
    response = chain.invoke({"messages": history})
    print("[GRAPH DEBUG] Tutoring response received.")
    content = get_text_content(response.content)

    new_level = _parse_hint_level(content, current_level)

    # A "COMPLETED" marker in the reply is deliberately not propagated: the
    # session stays ACTIVE so the learner can still send a final follow-up.
    status = "ACTIVE"

    # Clean the response for the UI
    clean_msg = AIMessage(content=_extract_response_text(content))
    return {"messages": state['messages'] + [clean_msg], "hint_level": new_level, "status": status}

def route_next(state: LearnerState):
    """Return ``END`` unconditionally; ``state`` is not inspected.

    NOTE(review): no edge in ``create_learner_graph`` references this router
    in the visible part of the module - confirm whether it is still needed.
    """
    return END

def route_safety(state: LearnerState):
    """Pick the branch to follow after the safety check.

    Args:
        state: Graph state; only ``safety_status`` is read.

    Returns:
        ``"blocked"`` when ``safety_status`` is exactly ``"BLOCK"``; otherwise
        ``"safe"`` (including when the key is missing).
    """
    was_blocked = state.get("safety_status") == "BLOCK"
    return "blocked" if was_blocked else "safe"

def create_learner_graph():
    """Assemble and compile the tutoring graph.

    Flow: ``safety`` is the entry point. A blocked message goes to ``blocked``
    and ends; a safe one goes through ``retrieve`` and ``learner``, then ends.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph = StateGraph(LearnerState)

    # Register every node under the name the edges below refer to.
    node_table = (
        ("safety", safety_node),
        ("blocked", blocked_node),
        ("retrieve", retriever_node),
        ("learner", learner_node),
    )
    for node_name, handler in node_table:
        graph.add_node(node_name, handler)

    # Every run begins with the safety check.
    graph.set_entry_point("safety")

    # route_safety returns "blocked" or "safe"; map each to its next node.
    safety_routes = {"blocked": "blocked", "safe": "retrieve"}
    graph.add_conditional_edges("safety", route_safety, safety_routes)

    # Safe path: retrieve -> learner -> END. Blocked path: warning -> END.
    fixed_edges = (
        ("retrieve", "learner"),
        ("learner", END),
        ("blocked", END),
    )
    for source, target in fixed_edges:
        graph.add_edge(source, target)

    return graph.compile()

# Build and compile the graph once, at import time; `learner_app` holds the
# compiled graph returned by create_learner_graph().
learner_app = create_learner_graph()