SocraticAI / socratic_graph.py
Deployer
Initial deployment commit with Git LFS tracking
a10a6c0
Raw
History Blame Contribute Delete
9.52 kB
import os
from typing import Annotated, List, TypedDict, Union
from typing_extensions import TypedDict
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv
# Load environment variables at import time so that os.getenv("GOOGLE_API_KEY")
# in get_llm() can see values supplied through a .env file.
load_dotenv()
# --- CONFIGURATION ---
# Directory handed to Chroma as persist_directory in get_resources().
CHROMA_PATH = "chroma_db"
# Lazy-loaded singletons: both start as None and are populated by the first
# call to get_resources().
_embeddings = None
_vector_store = None
# NOTE(review): _llm is not read or assigned anywhere else in the code visible
# here; get_llm() stores its instances in _llm_cache instead. Confirm whether
# this name can be retired.
_llm = None
def get_resources():
    """Return the shared Chroma vector store, creating it on first use.

    The embedding model and the vector store are module-level singletons:
    each is built only while its global is still None, so repeated calls
    reuse the same objects.

    Returns:
        The Chroma store for the "socratic_knowledge" collection persisted
        under CHROMA_PATH.
    """
    global _embeddings, _vector_store

    if _embeddings is None:
        embedding_options = {
            "model_name": "all-MiniLM-L6-v2",
            "model_kwargs": {'device': 'cpu'},
        }
        _embeddings = HuggingFaceEmbeddings(**embedding_options)

    if _vector_store is None:
        _vector_store = Chroma(
            persist_directory=CHROMA_PATH,
            embedding_function=_embeddings,
            collection_name="socratic_knowledge",
        )

    return _vector_store
def get_text_content(content: Union[str, List[Union[str, dict]]]) -> str:
    """Flatten message content into a single plain-text string.

    Message content may be a plain string or a list of parts, where each part
    is either a bare string or a dict block such as
    ``{"type": "text", "text": "..."}``.

    Args:
        content: The ``content`` attribute of a chat message.

    Returns:
        The content itself when it is a string; the concatenation of every
        textual part when it is a list (bare strings and the ``"text"`` value
        of dict parts, in order); otherwise ``str(content)``.
    """
    if isinstance(content, str):
        return content

    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                # Bare string parts are legal list members and must be kept.
                pieces.append(part)
            elif isinstance(part, dict):
                text = part.get("text")
                # Skip non-text blocks and non-string "text" values so that
                # join() can never fail on a None or numeric payload.
                if isinstance(text, str):
                    pieces.append(text)
        return "".join(pieces)

    return str(content)
# --- STATE DEFINITION ---
class LearnerState(TypedDict):
    """State dictionary passed between the nodes of the learner graph.

    Keys:
        messages: Chat history. The nodes in this file write it back as the
            full previous list plus their new message.
        context: Retrieved knowledge text; set by retriever_node and embedded
            into the tutor's system prompt by learner_node.
        hint_level: Current Socratic hint level; learner_node clamps the value
            it writes back to the range 1..5.
        safety_status: "PASS" or "BLOCK", set by safety_node and read by
            route_safety.
        current_topic: Not read or written by the nodes visible in this file.
        grade: Grade used for the retrieval filter and the tutor prompt.
        subject: Subject used for the retrieval filter and the tutor prompt.
        selected_topic: Topic named in the tutor prompt ("General" if absent).
        status: ACTIVE or COMPLETED.
    """

    messages: Annotated[List[BaseMessage], "The chat history"]
    context: str
    hint_level: int
    safety_status: str
    current_topic: str
    grade: str
    subject: str
    selected_topic: str
    status: str
# Chat model instances keyed by their (temperature, max_tokens) configuration.
_llm_cache = {}


def get_llm(temperature=0.2, max_tokens=None):
    """Return a cached Gemini chat model for the requested configuration.

    Args:
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Forwarded as ``max_output_tokens``; None leaves it unset.

    Returns:
        The ChatGoogleGenerativeAI instance for this (temperature, max_tokens)
        pair, created on the first request and reused afterwards.
    """
    settings = (temperature, max_tokens)

    client = _llm_cache.get(settings)
    if client is None:
        # First request for this configuration: build it and remember it.
        client = ChatGoogleGenerativeAI(
            model="gemini-3.1-flash-lite",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        _llm_cache[settings] = client

    return client
# --- NODES ---
def safety_node(state: LearnerState):
    """Classify the latest message as safe ("PASS") or unsafe ("BLOCK").

    Checks run from cheapest to most expensive:

    1. Keyword blocklist. This runs FIRST so that a one-word message such as
       "bomb" cannot slip through the short-message fast path below.
    2. Instant PASS for known filler words and any message under two words.
    3. A minimal LLM classification for everything else.

    Args:
        state: Graph state; only the last entry of ``state['messages']`` is
            inspected.

    Returns:
        A partial state update ``{"safety_status": "PASS" | "BLOCK"}``.
    """
    import re  # function-scope import: the module does not import re at top level

    print("\n[GRAPH DEBUG] Entering safety_node (high-speed hybrid)...")
    last_msg = get_text_content(state['messages'][-1].content).strip()
    last_msg_lower = last_msg.lower()

    # Fast Path 1: Instant BLOCK for obvious jailbreak/toxic keywords.
    # A keyword must begin at a word boundary, so "killing" or "hacker" still
    # match while innocent words that merely contain a keyword ("skill",
    # "shack") no longer trigger a block.
    unsafe_keywords = ("ignore all", "system prompt", "hack", "bomb", "kill", "porn")
    unsafe_pattern = r"\b(?:" + "|".join(re.escape(kw) for kw in unsafe_keywords) + r")"
    if re.search(unsafe_pattern, last_msg_lower):
        print("[GRAPH DEBUG] Safety Check Result: BLOCK (Keyword match)")
        return {"safety_status": "BLOCK"}

    # Fast Path 2: Instant PASS for common short fillers
    safe_words = {"hi", "hello", "hey", "thanks", "thank you", "ok", "yes", "no", "help"}
    if last_msg_lower in safe_words or len(last_msg.split()) < 2:
        return {"safety_status": "PASS"}

    # Fast Path 3: Minimal LLM check for complex queries
    # Limit to 2 tokens for maximum speed
    llm = get_llm(temperature=0.0, max_tokens=2)
    prompt = f"Is this query safe and for school? Query: '{last_msg}'. Reply PASS or BLOCK only."
    try:
        response = llm.invoke(prompt)
        result = get_text_content(response.content).strip().upper()
        status = "BLOCK" if "BLOCK" in result else "PASS"
        print(f"[GRAPH DEBUG] Safety Check Result: {status}")
        return {"safety_status": status}
    except Exception as e:
        # Deliberate best-effort: if the classifier call fails, the session
        # continues (fail-open) rather than blocking the learner.
        print(f"[GRAPH DEBUG] Safety Node Error: {e}")
        return {"safety_status": "PASS"}
def blocked_node(state: LearnerState):
    """Append a fixed safety warning to the chat history.

    Args:
        state: Graph state; ``state['messages']`` is read but not mutated.

    Returns:
        A partial state update whose ``messages`` is the previous history
        followed by the warning as an AIMessage.
    """
    warning_text = "⚠️ **Safety Warning:** This query has been flagged as off-topic or inappropriate for this educational session. I am here to help you learn your school subjects—please try asking a question related to the current topic!"
    updated_history = state['messages'] + [AIMessage(content=warning_text)]
    return {"messages": updated_history}
def retriever_node(state: LearnerState):
    """Fetch knowledge-base context for the learner's latest message.

    Args:
        state: Graph state; reads the last message plus ``grade`` and
            ``subject``.

    Returns:
        ``{"context": text}`` where text is the page content of up to two
        matching chunks joined by a "---" separator line, or an empty string
        when the message is under two words or retrieval raises.
    """
    query = get_text_content(state['messages'][-1].content)

    # Messages with fewer than two words are treated as filler: no retrieval.
    if len(query.split()) < 2:
        return {"context": ""}

    try:
        store = get_resources()

        # Restrict the search to the selected grade AND subject so chunks from
        # other subjects are not pulled in; the two conditions are combined
        # under Chroma's $and operator.
        grade_clause = {"grade": state.get('grade')}
        subject_clause = {"subject": state.get('subject')}
        hits = store.similarity_search(
            query,
            k=2,
            filter={"$and": [grade_clause, subject_clause]},
        )

        chunk_texts = [hit.page_content for hit in hits]
        joined_context = "\n---\n".join(chunk_texts)

        # Terminal diagnostics
        print(f"\n[RAG DEBUG] Query: '{query}'")
        print(f"[RAG DEBUG] Found {len(hits)} relevant chunks.")
        if hits:
            print(f"[RAG DEBUG] Top Chunk Source: {hits[0].metadata.get('source', 'Unknown')}")

        return {"context": joined_context}
    except Exception as e:
        # Retrieval is best-effort: on any failure the tutor runs without context.
        print(f"[RAG ERROR] {e}")
        return {"context": ""}
def _parse_hint_level(content: str, current_level: int) -> int:
    """Read the hint level reported in the model output, clamped to 1..5.

    Two formats are recognised: ``[Hint Level]: <n>`` (only the first number
    after the marker on that line is used, so "2/5" yields 2 rather than 25)
    and the compact ``[L<n>]`` tag. When neither is found the current level
    is kept.
    """
    import re  # function-scope import: the module does not import re at top level

    match = (
        re.search(r"\[Hint Level\]:[^\d\n]*(\d+)", content)
        or re.search(r"\[L(\d)\]", content)
    )
    level = int(match.group(1)) if match else current_level
    return max(1, min(5, level))


def _clean_tutor_response(content: str) -> str:
    """Strip the bracketed metadata prefix so only the tutor's reply remains."""
    if "[Response]:" in content:
        cleaned = content.split("[Response]:")[-1].strip()
    elif "]," in content:
        # Handle [Safe], [Active], [L1], [Actual Response]
        cleaned = content.split("],")[-1].strip().strip("[]")
    elif content.startswith("[") and content.count("]") >= 3:
        # Handle cases where commas might be missing but brackets are present
        cleaned = content.split("]")[-1].strip().strip("[")
    else:
        cleaned = content
    # Never hand the UI an empty message: fall back to the raw model text.
    return cleaned or content


def learner_node(state: LearnerState):
    """Generate the tutor's next Socratic reply and update the hint level.

    Args:
        state: Graph state; reads ``messages``, ``context``, ``hint_level``,
            ``grade``, ``subject`` and ``selected_topic``.

    Returns:
        A partial state update with the cleaned reply appended to
        ``messages``, the parsed ``hint_level`` (clamped to 1..5) and
        ``status``, which is always "ACTIVE".
    """
    print("[GRAPH DEBUG] Entering learner_node...")
    # Budget 500 tokens for the response to keep it fast
    llm = get_llm(temperature=0.2, max_tokens=500)

    selected_topic = state.get('selected_topic', 'General')

    # Condensed pedagogical instructions for faster processing
    system_instruction = f"""Socratic {state.get('grade')} {state.get('subject')} Tutor. Topic: {selected_topic}. Hint Level: {state['hint_level']}/5.
STRATEGY:
- L1: High-level memory nudge. No facts.
- L2: Point to specific concept or context part.
- L3: Partial scaffold or 'fill-in-blank' prompt.
- L4: Explain core logic, but student concludes.
- L5: Full explanation only if stuck.
RULES: No direct answers L1-L3. Use Context. If correct, say "Good work!", reset to L1, and ask what's next.
CONTEXT: {state['context']}
FORMAT: [Safety Status], [Status], [Hint Level], [Response]."""

    # The instruction is already fully rendered, so it is passed as a concrete
    # SystemMessage. Passing it as a ("system", text) tuple would treat it as
    # a template, and any "{" or "}" in the retrieved context (math, code,
    # LaTeX) would be read as a missing template variable.
    chat_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=system_instruction),
        MessagesPlaceholder(variable_name="messages"),
    ])

    # Use last 8 messages for context (balanced for speed vs awareness).
    # Slicing produces a new list, so the insert below never touches state.
    history = state['messages'][-8:]

    # Google API requires alternating roles and often fails if it starts with an AIMessage after the SystemMessage.
    # Add a dummy HumanMessage if needed to ensure the sequence is valid.
    if history and history[0].type == "ai":
        history.insert(0, HumanMessage(content="[Conversation Started]"))

    chain = chat_prompt | llm
    response = chain.invoke({"messages": history})
    print("[GRAPH DEBUG] Tutoring response received.")

    content = get_text_content(response.content)
    new_level = _parse_hint_level(content, state['hint_level'])

    # The session stays ACTIVE even when the model reports completion, so the
    # learner can still send a final follow-up.
    status = "ACTIVE"

    # Clean the response for the UI
    clean_msg = AIMessage(content=_clean_tutor_response(content))

    return {"messages": state['messages'] + [clean_msg], "hint_level": new_level, "status": status}
def route_next(state: LearnerState):
    """Terminal router: always end the run, whatever the state holds.

    NOTE(review): create_learner_graph() does not register this function in
    the code visible here; confirm whether it is still needed.
    """
    return END
def route_safety(state: LearnerState):
    """Pick the branch that follows the safety check.

    Returns:
        "blocked" when ``safety_status`` is exactly "BLOCK"; "safe" for any
        other value, including a missing key.
    """
    return "blocked" if state.get("safety_status") == "BLOCK" else "safe"
def create_learner_graph():
    """Assemble and compile the tutoring workflow.

    Flow: safety -> (blocked -> END) or (retrieve -> learner -> END).

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph = StateGraph(LearnerState)

    # Register every node under its routing name.
    node_handlers = (
        ("safety", safety_node),
        ("blocked", blocked_node),
        ("retrieve", retriever_node),
        ("learner", learner_node),
    )
    for node_name, handler in node_handlers:
        graph.add_node(node_name, handler)

    # Every run starts with the safety check.
    graph.set_entry_point("safety")

    # Branch on the safety verdict returned by route_safety.
    safety_routes = {
        "blocked": "blocked",
        "safe": "retrieve",
    }
    graph.add_conditional_edges("safety", route_safety, safety_routes)

    # Safe path: retrieve -> learner -> END. Blocked path: warning -> END.
    fixed_edges = (
        ("retrieve", "learner"),
        ("learner", END),
        ("blocked", END),
    )
    for source, target in fixed_edges:
        graph.add_edge(source, target)

    return graph.compile()
# Build and compile the graph once, at import time; this module-level name
# holds the compiled workflow returned by create_learner_graph().
learner_app = create_learner_graph()