# Provenance (repository page header, kept as a comment so the module parses):
#   author: krishbaresha
#   commit: dbe9dba - "Initial commit: MehfoozAI All-in-One"
from typing import TypedDict, List, Optional
from langgraph.graph import StateGraph, END
from src.agents.intake import extract_details
from src.agents.fir_drafter import draft_fir
from src.agents.router import route_report
from src.agents.safety_mapper import map_to_safety_zones
from loguru import logger
# ─── State ──────────────────────────────────────────────────────────
class AgentState(TypedDict):
    """Shared state dictionary handed from one pipeline node to the next.

    Each node receives the current state and returns a partial dict whose
    keys are drawn from the fields declared here.
    """

    # Raw text handed to the pipeline by its caller.
    input: str

    # Intake output: whatever ``extract_details`` returned, or ``{}`` when
    # intake failed.
    details: Optional[dict]

    # FIR drafting output: the full ``draft_fir`` result, plus two values
    # copied out of it (its "fir_text_english" and "ppc_sections" entries).
    fir_result: Optional[dict]
    fir_draft: Optional[str]
    ppc_sections: Optional[List[str]]

    # Results returned by ``route_report`` and ``map_to_safety_zones``.
    routing: Optional[dict]
    safety_zone: Optional[dict]

    # Progress marker written by the most recent node, e.g. "started",
    # "fir_drafted", "error_routing", "completed".
    status: str

    # Exception messages reported by nodes.
    errors: List[str]
# ─── Nodes ──────────────────────────────────────────────────────────
async def intake_node(state: AgentState) -> dict:
    """Agent 1/4: extract structured details from the raw report text.

    Args:
        state: Current pipeline state; ``state["input"]`` is passed to
            ``extract_details``.

    Returns:
        A partial state update. On success it carries the extracted
        ``details``, status ``"details_extracted"`` and an empty ``errors``
        list. If extraction raises, ``details`` is ``{}``, status is
        ``"error_intake"`` and the exception text is appended to any errors
        already present in ``state``.
    """
    logger.info("🧠 [Agent 1/4] Intake — extracting details...")
    try:
        details = await extract_details(state["input"])
    except Exception as e:
        logger.error(f"Intake error: {e}")
        # ``errors`` is a plain list field, so the returned list replaces the
        # stored one; carry earlier entries forward instead of dropping them.
        earlier_errors = state.get("errors") or []
        return {
            "errors": [*earlier_errors, str(e)],
            "status": "error_intake",
            "details": {},
        }
    return {"details": details, "status": "details_extracted", "errors": []}
async def fir_drafting_node(state: AgentState) -> dict:
    """Agent 2/4: draft the FIR document from the extracted details.

    Args:
        state: Current pipeline state; ``state["details"]`` is passed to
            ``draft_fir``.

    Returns:
        A partial state update. On success: the full drafting result under
        ``fir_result``, its ``"fir_text_english"`` entry under ``fir_draft``,
        its ``"ppc_sections"`` entry (default ``[]``) under ``ppc_sections``,
        and status ``"fir_drafted"``. On any exception: status
        ``"error_fir"`` and ``errors`` extended with the exception text.
    """
    logger.info("📄 [Agent 2/4] FIR — drafting legal document...")
    try:
        result = await draft_fir(state["details"])
        # Kept inside the try so a result without ``.get`` is reported as a
        # drafting error rather than escaping the node.
        return {
            "fir_result": result,
            "fir_draft": result.get("fir_text_english"),
            "ppc_sections": result.get("ppc_sections", []),
            "status": "fir_drafted",
        }
    except Exception as e:
        logger.error(f"FIR error: {e}")
        # ``errors`` is a plain list field, so the returned list replaces the
        # stored one; append so failures from earlier nodes are not lost.
        earlier_errors = state.get("errors") or []
        return {"errors": [*earlier_errors, str(e)], "status": "error_fir"}
async def routing_node(state: AgentState) -> dict:
    """Agent 3/4: route the report using the extracted details.

    Args:
        state: Current pipeline state; ``state["details"]`` is passed to
            ``route_report``.

    Returns:
        A partial state update. On success: the ``route_report`` result under
        ``routing`` and status ``"routed"``. On any exception: status
        ``"error_routing"`` and ``errors`` extended with the exception text.
    """
    logger.info("🚔 [Agent 3/4] Routing — finding nearest authority...")
    try:
        routing = await route_report(state["details"])
    except Exception as e:
        logger.error(f"Routing error: {e}")
        # ``errors`` is a plain list field, so the returned list replaces the
        # stored one; append so failures from earlier nodes are not lost.
        earlier_errors = state.get("errors") or []
        return {"errors": [*earlier_errors, str(e)], "status": "error_routing"}
    return {"routing": routing, "status": "routed"}
async def safety_mapper_node(state: AgentState) -> dict:
    """Agent 4/4: classify the report into a safety zone.

    Args:
        state: Current pipeline state; ``state["details"]`` is passed to
            ``map_to_safety_zones``.

    Returns:
        A partial state update whose status is always ``"completed"``. On
        success it carries the mapper result under ``safety_zone``; on any
        exception it carries ``errors`` extended with the exception text.
    """
    logger.info("🗺️ [Agent 4/4] Safety Mapper — classifying danger zone...")
    try:
        zone = await map_to_safety_zones(state["details"])
    except Exception as e:
        logger.error(f"Safety mapper error: {e}")
        # ``errors`` is a plain list field, so the returned list replaces the
        # stored one; append so failures from earlier nodes are not lost.
        earlier_errors = state.get("errors") or []
        # Non-fatal: a mapper failure is recorded but the run still counts
        # as completed.
        return {"errors": [*earlier_errors, str(e)], "status": "completed"}
    return {"safety_zone": zone, "status": "completed"}
# ─── Graph ──────────────────────────────────────────────────────────
# Assemble the pipeline: register every agent as a node, then wire the nodes
# into a single straight line that starts at "intake" and finishes at END.
workflow = StateGraph(AgentState)

_PIPELINE_STAGES = (
    ("intake", intake_node),
    ("fir_drafter", fir_drafting_node),
    ("router", routing_node),
    ("safety_mapper", safety_mapper_node),
)
for _stage_name, _stage_node in _PIPELINE_STAGES:
    workflow.add_node(_stage_name, _stage_node)

workflow.set_entry_point("intake")

# Every edge is unconditional, so each stage runs even when an earlier stage
# reported an error in the state.
_stage_order = [name for name, _ in _PIPELINE_STAGES] + [END]
for _upstream, _downstream in zip(_stage_order, _stage_order[1:]):
    workflow.add_edge(_upstream, _downstream)

app_graph = workflow.compile()
async def run_pipeline(user_input: str) -> AgentState:
    """Run the compiled agent graph on one report and return its final state.

    Args:
        user_input: Text stored under the ``"input"`` key of the seed state.

    Returns:
        Whatever ``app_graph.ainvoke`` produces for the seed state.
    """
    # Seed state: only the input is known; every agent output starts empty.
    seed = AgentState(
        input=user_input,
        details=None,
        fir_result=None,
        fir_draft=None,
        ppc_sections=None,
        routing=None,
        safety_zone=None,
        status="started",
        errors=[],
    )
    return await app_graph.ainvoke(seed)