Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from datetime import datetime | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langgraph.graph import END, StateGraph | |
| from agents.state import MathMentorState | |
| from agents.guardrail_agent import guardrail_node | |
| from agents.parser_agent import parser_node | |
| from agents.router_agent import router_node | |
| from agents.solver_agent import solver_node | |
| from agents.verifier_agent import verifier_node | |
| from agents.explainer_agent import explainer_node | |
| from config import settings | |
| from input_handlers.image_handler import handle_image_input | |
| from input_handlers.audio_handler import handle_audio_input | |
| from input_handlers.text_handler import handle_text_input | |
| from rag.retriever import retrieve as rag_retrieve | |
def extract_input_node(state: MathMentorState) -> dict:
    """Turn the raw user input into plain text using the matching handler.

    Dispatches on ``state["input_type"]``: ``"image"`` and ``"audio"`` use their
    dedicated handlers, anything else goes through the text handler. A non-text
    extraction whose confidence is below ``settings.ocr_confidence_threshold``
    is flagged for human review.

    Returns:
        Partial state update with ``extracted_text``, ``extraction_confidence``,
        ``needs_human_review`` and the agent trace extended by one
        ``"extractor"`` entry.
    """
    kind = state.get("input_type", "text")
    payload = state.get("raw_input", "")

    if kind == "image":
        extraction = handle_image_input(payload)
    elif kind == "audio":
        extraction = handle_audio_input(payload)
    else:
        extraction = handle_text_input(payload)

    score = extraction["confidence"]
    # Input typed as "text" is never sent for review, whatever its score.
    # NOTE(review): an unrecognised input_type is extracted by the text handler
    # but can still be flagged here, because it is not literally "text".
    flag_for_review = score < settings.ocr_confidence_threshold and kind != "text"

    trace_entry = {
        "agent": "extractor",
        "action": "extracted",
        "summary": f"Type: {kind}, confidence: {score:.2f}",
        "timestamp": datetime.now().isoformat(),
    }
    return {
        "extracted_text": extraction["text"],
        "extraction_confidence": score,
        "needs_human_review": flag_for_review,
        "agent_trace": state.get("agent_trace", []) + [trace_entry],
    }
def retrieve_context_node(state: MathMentorState) -> dict:
    """Fetch knowledge-base chunks relevant to the current problem.

    The query is the parser's ``problem_text``. When that is missing or empty,
    the raw ``extracted_text`` is used instead. A non-empty ``problem_topic``
    is prefixed as ``"<topic>: <query>"``.

    Returns:
        Partial state update with ``retrieved_chunks`` (never ``None``) and the
        agent trace extended by one ``"retriever"`` entry.
    """
    # ``or`` (rather than a .get default) also covers keys that are present
    # but hold None or an empty value.
    parsed = state.get("parsed_problem") or {}
    query = parsed.get("problem_text") or state.get("extracted_text") or ""
    topic = state.get("problem_topic") or ""
    search_query = f"{topic}: {query}" if topic else query

    # Normalise a None result so len() below (and consumers) cannot fail.
    chunks = rag_retrieve(search_query) or []

    # NOTE(review): the trace is returned as old list + new entry; if
    # MathMentorState declares an additive reducer on agent_trace this would
    # duplicate entries — confirm against agents.state.
    return {
        "retrieved_chunks": chunks,
        "agent_trace": state.get("agent_trace", [])
        + [
            {
                "agent": "retriever",
                "action": "retrieved",
                "summary": f"Found {len(chunks)} relevant chunks",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }
def retrieve_memory_node(state: MathMentorState) -> dict:
    """Look up up to three previously solved problems similar to this one.

    The lookup is best-effort. If importing ``memory.retriever`` fails or
    ``find_similar`` raises, the graph continues with an empty list, and the
    failure is logged instead of being silently discarded.

    Returns:
        Partial state update with ``similar_past_problems`` (never ``None``)
        and the agent trace extended by one ``"memory_retriever"`` entry.
    """
    import logging

    # Same query derivation as the context retriever: prefer the parsed
    # problem text, fall back to the raw extraction when it is missing/empty.
    parsed = state.get("parsed_problem") or {}
    query = parsed.get("problem_text") or state.get("extracted_text") or ""

    try:
        # Imported lazily so a missing/broken memory backend only disables
        # this lookup instead of breaking module import.
        from memory.retriever import find_similar

        similar = find_similar(query, top_k=3) or []
    except Exception:
        logging.getLogger(__name__).warning(
            "Similar-problem lookup failed; continuing without memory",
            exc_info=True,
        )
        similar = []

    return {
        "similar_past_problems": similar,
        "agent_trace": state.get("agent_trace", [])
        + [
            {
                "agent": "memory_retriever",
                "action": "retrieved",
                "summary": f"Found {len(similar)} similar past problems",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }
def save_to_memory_node(state: MathMentorState) -> dict:
    """Persist the finished problem, its solution and explanation to memory.

    Saving is best-effort: any failure (including importing ``memory.store``)
    never aborts the graph. It is logged and recorded in the agent trace as
    ``"save_failed"``; previously the trace claimed a successful save
    regardless of the outcome.

    Returns:
        Partial state update containing only the agent trace, extended by one
        ``"memory_saver"`` entry describing whether the save succeeded.
    """
    import logging

    try:
        # Imported lazily so a missing/broken memory backend only disables
        # persistence instead of breaking module import.
        from memory.store import save_record

        save_record(
            input_type=state.get("input_type", "text"),
            extracted_text=state.get("extracted_text", ""),
            parsed_problem=state.get("parsed_problem", {}),
            topic=state.get("problem_topic", ""),
            # Only each chunk's "source" is stored, not the chunk content.
            retrieved_chunks=[
                c.get("source", "") for c in state.get("retrieved_chunks") or []
            ],
            solution=state.get("solution", ""),
            solution_steps=state.get("solution_steps", []),
            verification=state.get("verification_result", {}),
            explanation=state.get("explanation", ""),
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Failed to save problem to memory", exc_info=True
        )
        action, summary = "save_failed", "Failed to save problem to memory"
    else:
        action, summary = "saved", "Problem and solution saved to memory"

    return {
        "agent_trace": state.get("agent_trace", [])
        + [
            {
                "agent": "memory_saver",
                "action": action,
                "summary": summary,
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }
def should_review_extraction(state: MathMentorState) -> str:
    """Route after extraction: to human review when flagged, else to the guardrail."""
    flagged = state.get("needs_human_review", False)
    return "hitl_extraction" if flagged else "guardrail"
def should_review_after_guardrail(state: MathMentorState) -> str:
    """Route after the guardrail: end the graph on invalid input, else parse.

    A missing ``is_valid_input`` key is treated as valid.
    """
    accepted = state.get("is_valid_input", True)
    return "parse" if accepted else END
def should_review_parse(state: MathMentorState) -> str:
    """Route after parsing.

    Go to human clarification when the parser asked for clarification or the
    state is still flagged for human review; otherwise continue to routing.
    """
    problem = state.get("parsed_problem", {})
    wants_clarification = problem.get("needs_clarification", False)
    if not (wants_clarification or state.get("needs_human_review", False)):
        return "route"
    return "hitl_clarification"
def should_review_verification(state: MathMentorState) -> str:
    """Decide the next step after the verifier has checked a solution.

    - Incorrect, and ``solver_retries`` is still below
      ``settings.max_solver_retries`` -> ``"solve"`` (retry).
    - Otherwise, incorrect or confidence below
      ``settings.verifier_confidence_threshold`` -> ``"hitl_verification"``.
    - Correct and confident enough -> ``"explain"``.

    A missing verification result counts as incorrect with confidence 0.
    """
    report = state.get("verification_result", {})
    verified = report.get("is_correct", False)
    score = report.get("confidence", 0)
    attempts = state.get("solver_retries", 0)

    # NOTE(review): nothing in this router increments solver_retries; presumably
    # solver_node does. Confirm, otherwise the solve->verify loop is only
    # stopped by LangGraph's recursion limit.
    if not verified and attempts < settings.max_solver_retries:
        return "solve"
    needs_human = score < settings.verifier_confidence_threshold or not verified
    return "hitl_verification" if needs_human else "explain"
def _human_review_update(state: MathMentorState, stage: str) -> dict:
    """Build the state update shared by every human-in-the-loop node.

    Clears the review flag and marks the step as human-approved. It also
    appends a ``"human_reviewer"`` trace entry, so HITL steps appear in the
    agent trace like every other node's work.
    """
    return {
        "needs_human_review": False,
        "human_approved": True,
        "agent_trace": state.get("agent_trace", [])
        + [
            {
                "agent": "human_reviewer",
                "action": "reviewed",
                "summary": f"Human review completed ({stage})",
                "timestamp": datetime.now().isoformat(),
            }
        ],
    }


def _apply_human_text_edit(state: MathMentorState, stage: str) -> dict:
    """HITL update that also adopts the reviewer's corrected text, if any."""
    update = _human_review_update(state, stage)
    # A missing or empty edit means the reviewer accepted the current text.
    update["extracted_text"] = state.get("human_edited_text") or state.get(
        "extracted_text", ""
    )
    return update


def hitl_extraction_node(state: MathMentorState) -> dict:
    """Resume after a human reviewed the extracted text (graph is interrupted before this node)."""
    return _apply_human_text_edit(state, "extraction")


def hitl_clarification_node(state: MathMentorState) -> dict:
    """Resume after a human clarified the problem text; the graph then re-parses it."""
    return _apply_human_text_edit(state, "clarification")


def hitl_verification_node(state: MathMentorState) -> dict:
    """Resume after a human reviewed an unverified or low-confidence solution."""
    return _human_review_update(state, "verification")
def build_graph():
    """Assemble and compile the MathMentor LangGraph workflow.

    Pipeline: extract -> guardrail -> parse -> route -> retrieve_context ->
    retrieve_memory -> solve -> verify -> explain -> save_memory -> END.
    It has optional human-review detours after extraction, parsing and
    verification, and a verify -> solve retry loop. The compiled graph uses an
    in-memory ``MemorySaver`` checkpointer and interrupts before each HITL
    node, so a human can edit state before it runs.
    """

    def _same_name(*targets):
        # Conditional-edge path maps here always map a route key to the
        # node of the same name.
        return {target: target for target in targets}

    workflow = StateGraph(MathMentorState)

    node_table = (
        ("extract", extract_input_node),
        ("hitl_extraction", hitl_extraction_node),
        ("guardrail", guardrail_node),
        ("parse", parser_node),
        ("hitl_clarification", hitl_clarification_node),
        ("route", router_node),
        ("retrieve_context", retrieve_context_node),
        ("retrieve_memory", retrieve_memory_node),
        ("solve", solver_node),
        ("verify", verifier_node),
        ("hitl_verification", hitl_verification_node),
        ("explain", explainer_node),
        ("save_memory", save_to_memory_node),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("extract")

    # Extraction, with optional human correction, then the guardrail.
    workflow.add_conditional_edges(
        "extract", should_review_extraction, _same_name("hitl_extraction", "guardrail")
    )
    workflow.add_edge("hitl_extraction", "guardrail")

    # Invalid input ends the run; valid input goes on to parsing.
    workflow.add_conditional_edges(
        "guardrail", should_review_after_guardrail, _same_name(END, "parse")
    )

    # Parsing loops through human clarification until it is unambiguous.
    workflow.add_conditional_edges(
        "parse", should_review_parse, _same_name("hitl_clarification", "route")
    )
    workflow.add_edge("hitl_clarification", "parse")

    for source, target in (
        ("route", "retrieve_context"),
        ("retrieve_context", "retrieve_memory"),
        ("retrieve_memory", "solve"),
        ("solve", "verify"),
    ):
        workflow.add_edge(source, target)

    # Verification may retry the solver, ask a human, or accept the answer.
    workflow.add_conditional_edges(
        "verify",
        should_review_verification,
        _same_name("solve", "hitl_verification", "explain"),
    )

    for source, target in (
        ("hitl_verification", "explain"),
        ("explain", "save_memory"),
        ("save_memory", END),
    ):
        workflow.add_edge(source, target)

    return workflow.compile(
        checkpointer=MemorySaver(),
        interrupt_before=["hitl_extraction", "hitl_clarification", "hitl_verification"],
    )
# Module-level compiled graph, built once at import time; every importer of
# this module shares this instance (and its MemorySaver checkpointer).
app_graph = build_graph()