Spaces:
Running
Running
| """LangGraph workflow orchestrator""" | |
| # TODO: Implement graph workflow | |
| import asyncio | |
| import uuid # extra import | |
| from typing import Dict, Any, List, Optional, Literal | |
| from datetime import datetime | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from state import ( | |
| InvoiceProcessingState, ProcessingStatus, ValidationStatus, | |
| RiskLevel, PaymentStatus, WORKFLOW_CONFIGS | |
| ) | |
| from agents.base_agent import agent_registry | |
| from agents.document_agent import DocumentAgent | |
| from agents.validation_agent import ValidationAgent | |
| from agents.risk_agent import RiskAgent | |
| from agents.payment_agent import PaymentAgent | |
| from agents.audit_agent import AuditAgent | |
| from agents.escalation_agent import EscalationAgent | |
| from utils.logger import StructuredLogger | |
| class InvoiceProcessingGraph: | |
| """Graph orchestrator""" | |
| def __init__(self, config: Dict[str, Any] = None): | |
| self.logger = StructuredLogger("InvoiceProcessingGraph") | |
| self.config = config or {} | |
| #Simple in-memory store for process states (process_id -> InvoiceProcessingState) | |
| self._process_store: Dict[str, InvoiceProcessingState] = {} | |
| #Register and initialize agents | |
| self._initialize_agents() | |
| try: | |
| self.graph = self._create_workflow_graph() | |
| self.compiled_graph = self.graph.compile(checkpointer=MemorySaver()) | |
| self.logger.logger.info("InvoiceProcessingGraph initialized successfully with compiled graph.") | |
| except Exception as e: | |
| self.logger.logger.warning(f"Failed to fully build graph nodes: {e} — exposing empty StateGraph") | |
| self.graph = StateGraph("invoice_processing_graph_fallback") | |
| def _initialize_agents(self): | |
| """Instantiate and register agent instances in the global registry""" | |
| #create agent instances (idempotent - replace if already registered) | |
| agents = [ | |
| DocumentAgent(), | |
| ValidationAgent(), | |
| RiskAgent(), | |
| PaymentAgent(), | |
| AuditAgent(), | |
| EscalationAgent(), | |
| ] | |
| for agent in agents: | |
| agent_registry.register(agent) | |
| self.logger.logger.info(f"Registered agents: {agent_registry.list_agents()}") | |
| def _create_workflow_graph(self) -> StateGraph: | |
| """ | |
| Build a LangGraph StateGraph with conditional routing: | |
| Each node executes its corresponding agent and determines | |
| the next node based on runtime logic (risk, validation, etc.) | |
| """ | |
| graph = StateGraph("invoice_processing_graph") | |
| # NODE DEFINITIONS | |
| async def node_document(state: InvoiceProcessingState): | |
| state = await self._document_agent_node(state) | |
| next_node = self._route_after_document(state) | |
| return next_node, state | |
| async def node_validation(state: InvoiceProcessingState): | |
| state = await self._validation_agent_node(state) | |
| next_node = self._route_after_validation(state) | |
| return next_node, state | |
| async def node_risk(state: InvoiceProcessingState): | |
| state = await self._risk_agent_node(state) | |
| next_node = self._route_after_risk(state) | |
| return next_node, state | |
| async def node_payment(state: InvoiceProcessingState): | |
| state = await self._payment_agent_node(state) | |
| next_node = self._route_after_payment(state) | |
| return next_node, state | |
| async def node_audit(state: InvoiceProcessingState): | |
| state = await self._audit_agent_node(state) | |
| next_node = self._route_after_audit(state) | |
| return next_node, state | |
| async def node_escalation(state: InvoiceProcessingState): | |
| state = await self._escalation_agent_node(state) | |
| next_node = self._route_after_escalation(state) | |
| return next_node, state | |
| async def node_human_review(state: InvoiceProcessingState): | |
| state = await self._human_review_node(state) | |
| next_node = self._route_after_human_review(state) | |
| return next_node, state | |
| async def node_end(state: InvoiceProcessingState): | |
| self.logger.logger.info(f"Invoice {state.invoice_id} completed at {state.updated_at}") | |
| return "end", state | |
| # REGISTER NODES | |
| for name, func in { | |
| "document": node_document, | |
| "validation": node_validation, | |
| "risk": node_risk, | |
| "payment": node_payment, | |
| "audit": node_audit, | |
| "escalation": node_escalation, | |
| "human_review": node_human_review, | |
| "end": node_end, | |
| }.items(): | |
| try: | |
| graph.add_node(name, func) | |
| except Exception: | |
| # fallback if add_node signature differs | |
| setattr(graph, name, func) | |
| # ADD EDGES (DEFAULT PATHS) | |
| try: | |
| graph.add_edge("document", "validation") | |
| graph.add_edge("validation", "risk") | |
| graph.add_edge("risk", "payment") | |
| graph.add_edge("payment", "audit") | |
| graph.add_edge("audit", "end") | |
| # Alternative / exception flows | |
| graph.add_edge("document", "escalation") | |
| graph.add_edge("validation", "escalation") | |
| graph.add_edge("risk", "escalation") | |
| graph.add_edge("escalation", "human_review") | |
| graph.add_edge("human_review", "end") | |
| graph.set_entry_point("document") | |
| except Exception as ex: | |
| self.logger.logger.warning(f"Edge registration failed: {ex}") | |
| self.logger.logger.info("Conditional workflow graph built successfully.") | |
| return graph | |
| async def _document_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| agent: DocumentAgent = agent_registry.get("document_agent") | |
| print("agent from doc", agent) | |
| if not agent: | |
| agent = DocumentAgent() | |
| agent_registry.register(agent) | |
| print("Registry instance ID in graph:", id(agent_registry)) | |
| return await agent.run(state, workflow_type) | |
| async def _validation_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| agent: ValidationAgent = agent_registry.get("validation_agent") | |
| print("agent from val", agent) | |
| if not agent: | |
| agent = ValidationAgent() | |
| agent_registry.register(agent) | |
| return await agent.run(state, workflow_type) | |
| async def _risk_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| agent: RiskAgent = agent_registry.get("risk_agent") | |
| if not agent: | |
| agent = RiskAgent() | |
| agent_registry.register(agent) | |
| return await agent.run(state, workflow_type) | |
| async def _payment_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| agent: PaymentAgent = agent_registry.get("payment_agent") | |
| if not agent: | |
| agent = PaymentAgent() | |
| agent_registry.register(agent) | |
| return await agent.run(state, workflow_type) | |
| async def _audit_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| agent: AuditAgent = agent_registry.get("audit_agent") | |
| if not agent: | |
| agent = AuditAgent() | |
| agent_registry.register(agent) | |
| return await agent.run(state, workflow_type) | |
| async def _escalation_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| agent: EscalationAgent = agent_registry.get("escalation_agent") | |
| if not agent: | |
| agent = EscalationAgent() | |
| agent_registry.register(agent) | |
| return await agent.run(state, workflow_type) | |
| async def _human_review_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: | |
| #Reusing escalation agent's human-in-the-loop or simply marking for manual review | |
| agent: EscalationAgent = agent_registry.get("escalation_agent") | |
| if not agent: | |
| agent = EscalationAgent() | |
| agent_registry.register(agent) | |
| return await agent.run(state, workflow_type) | |
| def _route_after_document(self, state: InvoiceProcessingState) -> Literal["validation", "escalation", "end"]: | |
| """Route decision after document extraction""" | |
| #if extraction yielded no invoice_data or low confidence -> escalate | |
| if not state.invoice_data: | |
| return "escalation" | |
| #if extraction confidence exists and is low -> escalate | |
| conf = getattr(state.invoice_data, "extraction_confidence", None) | |
| if conf is not None and conf<0.6: | |
| return "escalation" | |
| return "validation" | |
| def _route_after_validation(self, state: InvoiceProcessingState) -> Literal["risk", "escalation", "end"]: | |
| """Route decision after document validation""" | |
| vr = state.validation_result | |
| if not vr: | |
| return "escalation" | |
| #if missing PO or invalid -> escalate | |
| try: | |
| status = vr.validation_status | |
| #ValidationStatus maybe enum or str | |
| if isinstance(status,ValidationStatus): | |
| status_val = status | |
| else: | |
| status_val = ValidationStatus(status) if isinstance(status,str) else None | |
| if status_val == ValidationStatus.NO_MATCH or status_val == ValidationStatus.PARTIAL_MATCH and (not vr.amount_match): | |
| return "escalation" | |
| except Exception: | |
| #fallback: if discrepancies exist -> escalation | |
| if vr and getattr(vr,"discrepancies",None): | |
| return "escalation" | |
| return "risk" | |
| def _route_after_risk(self, state: InvoiceProcessingState) -> Literal["payment", "escalation", "human_review", "end"]: | |
| """Route decision after risk assessment""" | |
| ra = state.risk_assessment | |
| if not ra: | |
| return "escalation" | |
| #ra.risk_level is an enum RiskLevel | |
| rl = getattr(ra,"risk_level",None) | |
| #handle strings or enums | |
| rl_val = rl.value if hasattr(rl,"value") else str(rl).lower() | |
| try: | |
| if rl_val in (RiskLevel.CRITICAL.value, RiskLevel.HIGH.value): | |
| #For critical-> human review; for high->escalate | |
| if rl_val == RiskLevel.CRITICAL.value: | |
| return "human_review" | |
| return "escalation" | |
| else: | |
| #low or medium -> payment | |
| return "payment" | |
| except Exception: | |
| return "payment" | |
| def _route_after_payment(self, state: InvoiceProcessingState) -> Literal["audit", "escalation", "end"]: | |
| pd = getattr(state,"payment_decision",None) | |
| if not pd: | |
| return "escalation" | |
| #If approved (or scheduled) -> audit | |
| try: | |
| status = pd.payment_status | |
| #Accept enum or str | |
| status_val = status if isinstance(status,str) else getattr(status,"value",str(status)) | |
| if status_val in (PaymentStatus.APPROVED.value, PaymentStatus.SCHEDULED.value, PaymentStatus.PENDING_APPROVAL.value): | |
| return "audit" | |
| else: | |
| return "escalation" | |
| except Exception: | |
| return "audit" | |
| def _route_after_audit(self, state: InvoiceProcessingState) -> Literal["escalation", "end"]: | |
| cr = getattr(state, "compliance_report",None) | |
| if not cr: | |
| return "end" | |
| #If any compliance issues ->escalate | |
| issues = cr.get("issues",{}) if isinstance(cr, dict) else {} | |
| has_issues = any(issues.get(k) for k in issues) | |
| return "escalation" if has_issues else "end" | |
| async def _handle_escalation_chain(self, state: "InvoiceProcessingState", workflow_type): | |
| """Common handler for escalation → human review → complete""" | |
| state = await self._escalation_agent_node(state, workflow_type) | |
| self._process_store[state.process_id] = state | |
| state = await self._human_review_node(state, workflow_type) | |
| state.overall_status = ProcessingStatus.COMPLETED | |
| self._process_store[state.process_id] = state | |
| return state | |
| async def process_invoice(self, file_name: str, workflow_type: str = "standard", | |
| config: Dict[str, Any] = None) -> InvoiceProcessingState: | |
| """ | |
| Orchestrate processing for a single invoice file. | |
| Supports 3 workflow types: standard, high_value, and expedited. | |
| """ | |
| process_id = f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" | |
| initial_state = InvoiceProcessingState( | |
| process_id=process_id, | |
| file_name=file_name, | |
| overall_status=ProcessingStatus.PENDING, | |
| current_agent=None, | |
| workflow_type=workflow_type, | |
| created_at=datetime.utcnow(), | |
| updated_at=datetime.utcnow(), | |
| ) | |
| self._process_store[process_id] = initial_state | |
| start_ts = datetime.utcnow() | |
| state = initial_state | |
| worked_agents = [] | |
| try: | |
| # STEP 1️ Document Extraction | |
| state = await self._document_agent_node(state, workflow_type) | |
| self._process_store[process_id] = state | |
| route = self._route_after_document(state) | |
| print("state agent anme ::::::::::::::", state.agent_name) | |
| worked_agents.append(state.agent_name) | |
| if route == "escalation": | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # ---- Workflow branching ---- | |
| if workflow_type == "expedited": | |
| # Fast lane - skip validation if AI confidence is high | |
| # if getattr(state, "extraction_confidence", 0.0) < 0.85: | |
| state = await self._validation_agent_node(state, workflow_type) | |
| self._process_store[process_id] = state | |
| worked_agents.append(state.agent_name) | |
| route = self._route_after_validation(state) | |
| if route == "escalation": | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # Directly go to Payment, minimal audit | |
| state = await self._payment_agent_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| self._process_store[process_id] = state | |
| if getattr(state, "payment_decision", {}).decision == "auto-pay": | |
| state = await self._audit_agent_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| self._process_store[process_id] = state | |
| elif workflow_type == "high_value": | |
| # 2️ Validation (twice for accuracy) | |
| state = await self._validation_agent_node(state, workflow_type) | |
| self._process_store[process_id] = state | |
| state = await self._validation_agent_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| route = self._route_after_validation(state) | |
| if route == "escalation": | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # 3️ Risk | |
| state = await self._risk_agent_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| self._process_store[process_id] = state | |
| route = self._route_after_risk(state) | |
| if route in ["escalation", "human_review"]: | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # 4 Audit | |
| state = await self._audit_agent_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| self._process_store[process_id] = state | |
| # 5 Mandatory human review for high-value invoices | |
| state = await self._human_review_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| self._process_store[process_id] = state | |
| else: # STANDARD workflow | |
| # 2️ Validation | |
| state = await self._validation_agent_node(state, workflow_type) | |
| self._process_store[process_id] = state | |
| worked_agents.append(state.agent_name) | |
| route = self._route_after_validation(state) | |
| if route == "escalation": | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # 3️ Risk | |
| state = await self._risk_agent_node(state, workflow_type) | |
| self._process_store[process_id] = state | |
| worked_agents.append(state.agent_name) | |
| route = self._route_after_risk(state) | |
| if route in ["escalation", "human_review"]: | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # 4️ Payment | |
| state = await self._payment_agent_node(state, workflow_type) | |
| self._process_store[process_id] = state | |
| worked_agents.append(state.agent_name) | |
| route = self._route_after_payment(state) | |
| if route == "escalation": | |
| state = await self._handle_escalation_chain(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| return state, worked_agents | |
| # 5️ Audit | |
| state = await self._audit_agent_node(state, workflow_type) | |
| worked_agents.append(state.agent_name) | |
| self._process_store[process_id] = state | |
| # Success completion | |
| state.overall_status = ProcessingStatus.COMPLETED | |
| state.updated_at = datetime.utcnow() | |
| elapsed = (datetime.utcnow() - start_ts).total_seconds() | |
| self.logger.logger.info(f"Process {process_id} ({workflow_type}) completed in {elapsed:.2f}s") | |
| self._process_store[process_id] = state | |
| # print("from graph worked agents::::", worked_agents) | |
| return state, worked_agents | |
| except Exception as e: | |
| self.logger.logger.exception(f"Error processing invoice {file_name}: {e}") | |
| state.overall_status = ProcessingStatus.FAILED | |
| self._process_store[process_id] = state | |
| return state, worked_agents | |
| # async def process_batch(self, file_names: List[str], workflow_type: str = "standard", | |
| # max_concurrent: int = 5) -> List[InvoiceProcessingState]: | |
| # """Process a batch of files with limit concurrency""" | |
| # sem = asyncio.Semaphore(max_concurrent) | |
| # results: List[InvoiceProcessingState] = [] | |
| # async def _worker(fn: str): | |
| # async with sem: | |
| # return await self.process_invoice(fn, workflow_type=workflow_type) | |
| # tasks = [asyncio.create_task(_worker(f)) for f in file_names] | |
| # completed = await asyncio.gather(*tasks) | |
| # for st in completed: | |
| # results.append(st) | |
| # return results | |
| async def process_batch(self, file_names: List[str], workflow_type: str = "standard", | |
| max_concurrent: int = 5): | |
| sem = asyncio.Semaphore(max_concurrent) | |
| results = [] # will store: {"state": ..., "worked_agents": [...]} | |
| async def _worker(fn: str): | |
| async with sem: | |
| return await self.process_invoice(fn, workflow_type=workflow_type) | |
| tasks = [asyncio.create_task(_worker(f)) for f in file_names] | |
| completed = await asyncio.gather(*tasks) | |
| for result in completed: | |
| state, worked_agents = result # unpack the tuple | |
| results.append({ | |
| "state": state, | |
| "worked_agents": worked_agents | |
| }) | |
| return results | |
| async def get_workflow_status(self, process_id: str) -> Optional[Dict[str, Any]]: | |
| """Return the stored workflow status dictionary for a given process_id""" | |
| state = self._process_store.get(process_id) | |
| if not state: | |
| return None | |
| return {"process_id":process_id, "status":state.overall_status, "updated_at": getattr(state,"updated_at",None), "state":state.model_dump()} | |
| async def health_check(self) -> Dict[str, Any]: | |
| """Aggregate health check across agents and the orchestrator itself""" | |
| agents_health = await agent_registry.health_check_all() | |
| return {"orchestrator":"Healthy","timestamp":datetime.utcnow().isoformat(),"agent":agents_health} | |
| def _extract_final_state(self, result, initial_state: InvoiceProcessingState) -> InvoiceProcessingState: | |
| """Compatibility helper (returns invoice processing state)""" | |
| return result | |
| invoice_workflow: Optional[InvoiceProcessingGraph] = None | |
| def get_workflow(config: Dict[str, Any] = None) -> InvoiceProcessingGraph: | |
| global invoice_workflow | |
| if invoice_workflow is None: | |
| invoice_workflow = InvoiceProcessingGraph(config=config) | |
| return invoice_workflow |