# Multimodal Math Mentor — LangGraph pipeline definition
# (Initial commit 3c25c17 by Amit-kr26)
from __future__ import annotations
from datetime import datetime
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from agents.state import MathMentorState
from agents.guardrail_agent import guardrail_node
from agents.parser_agent import parser_node
from agents.router_agent import router_node
from agents.solver_agent import solver_node
from agents.verifier_agent import verifier_node
from agents.explainer_agent import explainer_node
from config import settings
from input_handlers.image_handler import handle_image_input
from input_handlers.audio_handler import handle_audio_input
from input_handlers.text_handler import handle_text_input
from rag.retriever import retrieve as rag_retrieve
def extract_input_node(state: MathMentorState) -> dict:
    """Run the handler matching the input modality and record the extraction.

    Reads ``input_type`` ("image", "audio", anything else is treated as text)
    and ``raw_input`` from the state.  Non-text extractions whose confidence
    falls below the configured OCR threshold are flagged for human review.
    """
    kind = state.get("input_type", "text")
    payload = state.get("raw_input", "")

    # Modality dispatch; text is the default handler.
    handlers = {
        "image": handle_image_input,
        "audio": handle_audio_input,
    }
    outcome = handlers.get(kind, handle_text_input)(payload)

    low_confidence = outcome["confidence"] < settings.ocr_confidence_threshold
    review_needed = low_confidence and kind != "text"

    trace_entry = {
        "agent": "extractor",
        "action": "extracted",
        "summary": f"Type: {kind}, confidence: {outcome['confidence']:.2f}",
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "extracted_text": outcome["text"],
        "extraction_confidence": outcome["confidence"],
        "needs_human_review": review_needed,
        "agent_trace": state.get("agent_trace", []) + [trace_entry],
    }
def retrieve_context_node(state: MathMentorState) -> dict:
    """Fetch RAG context chunks for the problem, prefixing the topic when known."""
    problem = state.get("parsed_problem", {})
    base_query = problem.get("problem_text", state.get("extracted_text", ""))
    topic = state.get("problem_topic", "")

    search_query = base_query if not topic else f"{topic}: {base_query}"
    chunks = rag_retrieve(search_query)

    trace_entry = {
        "agent": "retriever",
        "action": "retrieved",
        "summary": f"Found {len(chunks)} relevant chunks",
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "retrieved_chunks": chunks,
        "agent_trace": state.get("agent_trace", []) + [trace_entry],
    }
def retrieve_memory_node(state: MathMentorState) -> dict:
    """Look up similar previously solved problems in long-term memory.

    Best effort: any failure (memory backend missing, lookup error) degrades
    to an empty result instead of aborting the pipeline.
    """
    try:
        from memory.retriever import find_similar

        problem = state.get("parsed_problem", {})
        query = problem.get("problem_text", state.get("extracted_text", ""))
        similar = find_similar(query, top_k=3)
    except Exception:
        similar = []

    trace_entry = {
        "agent": "memory_retriever",
        "action": "retrieved",
        "summary": f"Found {len(similar)} similar past problems",
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "similar_past_problems": similar,
        "agent_trace": state.get("agent_trace", []) + [trace_entry],
    }
def save_to_memory_node(state: MathMentorState) -> dict:
    """Persist the finished problem and solution to long-term memory.

    Best effort: persistence failures (missing backend, write error) are
    swallowed so the graph always completes.  Fix over the original: the
    trace summary now reports whether the save actually happened instead of
    unconditionally claiming success after a silent failure.
    """
    saved = False
    try:
        from memory.store import save_record

        save_record(
            input_type=state.get("input_type", "text"),
            extracted_text=state.get("extracted_text", ""),
            parsed_problem=state.get("parsed_problem", {}),
            topic=state.get("problem_topic", ""),
            retrieved_chunks=[c.get("source", "") for c in state.get("retrieved_chunks", [])],
            solution=state.get("solution", ""),
            solution_steps=state.get("solution_steps", []),
            verification=state.get("verification_result", {}),
            explanation=state.get("explanation", ""),
        )
        saved = True
    except Exception:
        # Memory backend may be unavailable; never fail the run over this.
        saved = False

    summary = (
        "Problem and solution saved to memory"
        if saved
        else "Memory save failed or unavailable"
    )
    return {
        "agent_trace": state.get("agent_trace", [])
        + [
            {
                "agent": "memory_saver",
                "action": "saved",
                "summary": summary,
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }
def should_review_extraction(state: MathMentorState) -> str:
    """Route to human review when the extractor flagged low confidence."""
    flagged = state.get("needs_human_review", False)
    return "hitl_extraction" if flagged else "guardrail"
def should_review_after_guardrail(state: MathMentorState) -> str:
    """End the run when the guardrail rejected the input; otherwise parse."""
    return "parse" if state.get("is_valid_input", True) else END
def should_review_parse(state: MathMentorState) -> str:
    """Ask a human for clarification when parsing was ambiguous; else route."""
    parsed = state.get("parsed_problem", {})
    ambiguous = parsed.get("needs_clarification", False) or state.get(
        "needs_human_review", False
    )
    return "hitl_clarification" if ambiguous else "route"
def should_review_verification(state: MathMentorState) -> str:
    """Decide what follows verification: retry solving, human review, or explain."""
    result = state.get("verification_result", {})
    correct = result.get("is_correct", False)
    score = result.get("confidence", 0)
    attempts = state.get("solver_retries", 0)

    # Wrong answers are re-solved until the retry budget is exhausted.
    if not correct and attempts < settings.max_solver_retries:
        return "solve"
    # Still wrong, or verified with low confidence: escalate to a human.
    if not correct or score < settings.verifier_confidence_threshold:
        return "hitl_verification"
    return "explain"
def hitl_extraction_node(state: MathMentorState) -> dict:
    """Apply the human's corrected extraction text and clear the review flag.

    Falls back to the automatic extraction when no (non-empty) edit exists.
    """
    edited = state.get("human_edited_text")
    final_text = edited if edited else state.get("extracted_text", "")
    return {
        "extracted_text": final_text,
        "needs_human_review": False,
        "human_approved": True,
    }
def hitl_clarification_node(state: MathMentorState) -> dict:
    """Fold the human's clarification back into the extracted text.

    Mirrors ``hitl_extraction_node``: a non-empty human edit wins, otherwise
    the existing extraction is kept, and the review flag is cleared.
    """
    clarified = state.get("human_edited_text")
    final_text = clarified if clarified else state.get("extracted_text", "")
    return {
        "extracted_text": final_text,
        "needs_human_review": False,
        "human_approved": True,
    }
def hitl_verification_node(state: MathMentorState) -> dict:
    """Record human sign-off on a verification result that was flagged."""
    return {"needs_human_review": False, "human_approved": True}
def build_graph():
    """Assemble and compile the Math Mentor LangGraph pipeline.

    Flow: extract -> (optional HITL) -> guardrail -> parse -> (optional HITL)
    -> route -> RAG retrieval -> memory retrieval -> solve -> verify ->
    (retry / HITL / explain) -> save to memory -> END.  Execution pauses
    before every HITL node so a human can review or edit the state.
    """
    graph = StateGraph(MathMentorState)

    # Node registry: (graph name, node callable).
    for node_name, node_fn in [
        ("extract", extract_input_node),
        ("hitl_extraction", hitl_extraction_node),
        ("guardrail", guardrail_node),
        ("parse", parser_node),
        ("hitl_clarification", hitl_clarification_node),
        ("route", router_node),
        ("retrieve_context", retrieve_context_node),
        ("retrieve_memory", retrieve_memory_node),
        ("solve", solver_node),
        ("verify", verifier_node),
        ("hitl_verification", hitl_verification_node),
        ("explain", explainer_node),
        ("save_memory", save_to_memory_node),
    ]:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("extract")

    # Branch points, driven by the should_review_* predicates.
    graph.add_conditional_edges(
        "extract",
        should_review_extraction,
        {"hitl_extraction": "hitl_extraction", "guardrail": "guardrail"},
    )
    graph.add_conditional_edges(
        "guardrail",
        should_review_after_guardrail,
        {END: END, "parse": "parse"},
    )
    graph.add_conditional_edges(
        "parse",
        should_review_parse,
        {"hitl_clarification": "hitl_clarification", "route": "route"},
    )
    graph.add_conditional_edges(
        "verify",
        should_review_verification,
        {"solve": "solve", "hitl_verification": "hitl_verification", "explain": "explain"},
    )

    # Unconditional transitions.
    for src, dst in [
        ("hitl_extraction", "guardrail"),
        ("hitl_clarification", "parse"),
        ("route", "retrieve_context"),
        ("retrieve_context", "retrieve_memory"),
        ("retrieve_memory", "solve"),
        ("solve", "verify"),
        ("hitl_verification", "explain"),
        ("explain", "save_memory"),
        ("save_memory", END),
    ]:
        graph.add_edge(src, dst)

    # Checkpointing plus interrupt_before enables the human-in-the-loop pauses.
    return graph.compile(
        checkpointer=MemorySaver(),
        interrupt_before=["hitl_extraction", "hitl_clarification", "hitl_verification"],
    )
# Module-level compiled graph; built once at import time for the app to use.
app_graph = build_graph()