"""LangGraph workflow wiring for the math-mentor agents.

Defines the pipeline-specific nodes (input extraction, RAG / memory retrieval,
memory persistence, human-in-the-loop resume points), the routing functions
between them, and the compiled graph.
"""
from __future__ import annotations

from datetime import datetime

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from agents.state import MathMentorState
from agents.guardrail_agent import guardrail_node
from agents.parser_agent import parser_node
from agents.router_agent import router_node
from agents.solver_agent import solver_node
from agents.verifier_agent import verifier_node
from agents.explainer_agent import explainer_node
from config import settings
from input_handlers.image_handler import handle_image_input
from input_handlers.audio_handler import handle_audio_input
from input_handlers.text_handler import handle_text_input
from rag.retriever import retrieve as rag_retrieve


def extract_input_node(state: MathMentorState) -> dict:
    """Turn the raw user input into text using the handler for its input type.

    ``state["input_type"]`` selects the handler: "image" -> image handler,
    "audio" -> audio handler, anything else (or missing) -> text handler.
    Non-text extractions whose confidence is below
    ``settings.ocr_confidence_threshold`` are flagged for human review.

    Returns:
        Partial state update with ``extracted_text``,
        ``extraction_confidence``, ``needs_human_review`` and the
        ``agent_trace`` extended by one "extractor" entry.
    """
    # `or` (not a .get default) also normalises a key that is present but None,
    # so such input is treated - and reported - consistently as text.
    input_type = state.get("input_type") or "text"
    raw = state.get("raw_input", "")

    handlers = {"image": handle_image_input, "audio": handle_audio_input}
    result = handlers.get(input_type, handle_text_input)(raw)
    confidence = result["confidence"]

    # Typed text is taken verbatim and never needs review; only OCR/ASR output
    # can be unreliable. NOTE(review): audio reuses the OCR threshold - confirm
    # that is intended.
    needs_review = input_type != "text" and confidence < settings.ocr_confidence_threshold

    return {
        "extracted_text": result["text"],
        "extraction_confidence": confidence,
        "needs_human_review": needs_review,
        "agent_trace": (state.get("agent_trace") or []) + [
            {
                "agent": "extractor",
                "action": "extracted",
                "summary": f"Type: {input_type}, confidence: {confidence:.2f}",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }


def retrieve_context_node(state: MathMentorState) -> dict:
    """Fetch knowledge-base chunks relevant to the current problem.

    The search query is the parsed ``problem_text``, falling back to the raw
    ``extracted_text`` when parsing produced nothing. It is prefixed with
    ``"<topic>: "`` when ``problem_topic`` is set.

    Returns:
        Partial state update with ``retrieved_chunks`` and the
        ``agent_trace`` extended by one "retriever" entry.
    """
    parsed = state.get("parsed_problem") or {}
    # An empty problem_text must not yield an empty RAG query.
    query = parsed.get("problem_text") or state.get("extracted_text") or ""
    topic = state.get("problem_topic") or ""
    search_query = f"{topic}: {query}" if topic else query

    chunks = rag_retrieve(search_query) or []

    return {
        "retrieved_chunks": chunks,
        "agent_trace": (state.get("agent_trace") or []) + [
            {
                "agent": "retriever",
                "action": "retrieved",
                "summary": f"Found {len(chunks)} relevant chunks",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }
def retrieve_memory_node(state: MathMentorState) -> dict:
    """Look up previously solved problems similar to the current one.

    The query is the parsed ``problem_text``, falling back to the raw
    ``extracted_text`` when parsing produced nothing. The lookup is
    best-effort: any failure, including ``memory.retriever`` not being
    importable, is logged and treated as "no similar problems".

    Returns:
        Partial state update with ``similar_past_problems`` and the
        ``agent_trace`` extended by one "memory_retriever" entry.
    """
    import logging

    parsed = state.get("parsed_problem") or {}
    query = parsed.get("problem_text") or state.get("extracted_text") or ""
    try:
        # Lazy import: a missing memory backend must not break the solve.
        from memory.retriever import find_similar

        similar = find_similar(query, top_k=3) or []
    except Exception:
        logging.getLogger(__name__).exception("Similar-problem lookup failed")
        similar = []

    return {
        "similar_past_problems": similar,
        "agent_trace": (state.get("agent_trace") or []) + [
            {
                "agent": "memory_retriever",
                "action": "retrieved",
                "summary": f"Found {len(similar)} similar past problems",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }


def save_to_memory_node(state: MathMentorState) -> dict:
    """Persist the finished problem, solution and verification to memory.

    Saving is best-effort and never aborts the graph. Failures are logged,
    and the trace entry records whether the save succeeded instead of always
    claiming success.

    Returns:
        Partial state update with the ``agent_trace`` extended by one
        "memory_saver" entry (action "saved" or "save_failed").
    """
    import logging

    try:
        # Lazy import: a missing memory backend must not break the run.
        from memory.store import save_record

        save_record(
            input_type=state.get("input_type", "text"),
            extracted_text=state.get("extracted_text", ""),
            parsed_problem=state.get("parsed_problem", {}),
            topic=state.get("problem_topic", ""),
            # Only the chunk sources are stored, not the chunk bodies.
            retrieved_chunks=[c.get("source", "") for c in state.get("retrieved_chunks") or []],
            solution=state.get("solution", ""),
            solution_steps=state.get("solution_steps", []),
            verification=state.get("verification_result", {}),
            explanation=state.get("explanation", ""),
        )
    except Exception:
        logging.getLogger(__name__).exception("Failed to save problem to memory")
        action, summary = "save_failed", "Failed to save problem and solution to memory"
    else:
        action, summary = "saved", "Problem and solution saved to memory"

    return {
        "agent_trace": (state.get("agent_trace") or []) + [
            {
                "agent": "memory_saver",
                "action": action,
                "summary": summary,
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }


def should_review_extraction(state: MathMentorState) -> str:
    """Route to human review if extraction was flagged, otherwise to the guardrail."""
    if state.get("needs_human_review", False):
        return "hitl_extraction"
    return "guardrail"


def should_review_after_guardrail(state: MathMentorState) -> str:
    """End the run when the guardrail rejected the input, otherwise go to parsing."""
    if not state.get("is_valid_input", True):
        return END
    return "parse"


def should_review_parse(state: MathMentorState) -> str:
    """Route to clarification if the parser or an earlier step asks for a human."""
    parsed = state.get("parsed_problem") or {}
    if parsed.get("needs_clarification", False) or state.get("needs_human_review", False):
        return "hitl_clarification"
    return "route"


def should_review_verification(state: MathMentorState) -> str:
    """Choose the step after verification.

    - Incorrect and retries remain -> "solve" (retry).
    - Still incorrect after retries, or correct but with confidence below
      ``settings.verifier_confidence_threshold`` -> "hitl_verification".
    - Otherwise -> "explain".
    """
    verification = state.get("verification_result") or {}
    confidence = verification.get("confidence") or 0
    is_correct = verification.get("is_correct", False)
    # NOTE(review): nothing in this module increments ``solver_retries``. This
    # assumes another node (presumably solver_node) does; otherwise the
    # solve/verify loop only stops at LangGraph's recursion limit.
    retries = state.get("solver_retries") or 0

    if not is_correct and retries < settings.max_solver_retries:
        return "solve"
    if not is_correct or confidence < settings.verifier_confidence_threshold:
        return "hitl_verification"
    return "explain"


def _apply_human_text(state: MathMentorState) -> dict:
    """Adopt the reviewer's edited text (if any) and mark the review as approved."""
    text = state.get("human_edited_text") or state.get("extracted_text", "")
    return {
        "extracted_text": text,
        "needs_human_review": False,
        "human_approved": True,
    }


def hitl_extraction_node(state: MathMentorState) -> dict:
    """Resume point after a human reviewed the extracted text.

    The compiled graph interrupts before this node. The reviewer's correction
    is read from ``human_edited_text``; if none is given, the original
    extraction is kept.
    """
    return _apply_human_text(state)


def hitl_clarification_node(state: MathMentorState) -> dict:
    """Resume point after a human clarified an ambiguous problem.

    The compiled graph interrupts before this node. The clarified text is read
    from ``human_edited_text``; if none is given, the original extraction is
    kept. The graph then re-parses.
    """
    return _apply_human_text(state)


def hitl_verification_node(state: MathMentorState) -> dict:
    """Resume point after a human reviewed the solution; marks it approved."""
    return {
        "needs_human_review": False,
        "human_approved": True,
    }


def build_graph():
    """Assemble and compile the math-mentor workflow.

    Main path: extract -> guardrail -> parse -> route -> retrieve_context ->
    retrieve_memory -> solve -> verify -> explain -> save_memory -> END.

    Detours:
    - Human review after extraction.
    - Human clarification after parsing (loops back to parse).
    - A verify -> solve retry loop.
    - Human review after verification.
    - An early END when the guardrail rejects the input.

    The graph is compiled with an in-memory ``MemorySaver`` checkpointer and
    interrupts before every ``hitl_*`` node, so execution pauses there until
    the caller resumes the thread.
    """
    graph = StateGraph(MathMentorState)

    graph.add_node("extract", extract_input_node)
    graph.add_node("hitl_extraction", hitl_extraction_node)
    graph.add_node("guardrail", guardrail_node)
    graph.add_node("parse", parser_node)
    graph.add_node("hitl_clarification", hitl_clarification_node)
    graph.add_node("route", router_node)
    graph.add_node("retrieve_context", retrieve_context_node)
    graph.add_node("retrieve_memory", retrieve_memory_node)
    graph.add_node("solve", solver_node)
    graph.add_node("verify", verifier_node)
    graph.add_node("hitl_verification", hitl_verification_node)
    graph.add_node("explain", explainer_node)
    graph.add_node("save_memory", save_to_memory_node)

    graph.set_entry_point("extract")

    graph.add_conditional_edges("extract", should_review_extraction, {
        "hitl_extraction": "hitl_extraction",
        "guardrail": "guardrail",
    })
    graph.add_edge("hitl_extraction", "guardrail")

    graph.add_conditional_edges("guardrail", should_review_after_guardrail, {
        END: END,
        "parse": "parse",
    })

    graph.add_conditional_edges("parse", should_review_parse, {
        "hitl_clarification": "hitl_clarification",
        "route": "route",
    })
    # After clarification the (possibly edited) text is parsed again.
    graph.add_edge("hitl_clarification", "parse")

    graph.add_edge("route", "retrieve_context")
    graph.add_edge("retrieve_context", "retrieve_memory")
    graph.add_edge("retrieve_memory", "solve")
    graph.add_edge("solve", "verify")

    graph.add_conditional_edges("verify", should_review_verification, {
        "solve": "solve",
        "hitl_verification": "hitl_verification",
        "explain": "explain",
    })
    graph.add_edge("hitl_verification", "explain")
    graph.add_edge("explain", "save_memory")
    graph.add_edge("save_memory", END)

    checkpointer = MemorySaver()
    compiled = graph.compile(
        checkpointer=checkpointer,
        # Pause before each human-in-the-loop node so a person can act.
        interrupt_before=["hitl_extraction", "hitl_clarification", "hitl_verification"],
    )
    return compiled


# Compiled graph, built once at import time.
app_graph = build_graph()