# InvoiceAgenticAI / graph.py
# PARTHASAKHAPAUL
# Restructure project and sync local as source of truth
# 9107a5d
"""LangGraph workflow orchestrator"""
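# TODO: process_invoice() still calls the agents directly; the compiled LangGraph workflow (self.compiled_graph) is built but not yet used to execute invoices.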
import asyncio
import uuid  # used to build unique process ids
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph

from state import (
    InvoiceProcessingState, ProcessingStatus, ValidationStatus,
    RiskLevel, PaymentStatus, WORKFLOW_CONFIGS
)
from agents.base_agent import agent_registry
from agents.document_agent import DocumentAgent
from agents.validation_agent import ValidationAgent
from agents.risk_agent import RiskAgent
from agents.payment_agent import PaymentAgent
from agents.audit_agent import AuditAgent
from agents.escalation_agent import EscalationAgent
from utils.logger import StructuredLogger
class InvoiceProcessingGraph:
    """
    Orchestrator for the multi-agent invoice-processing pipeline.

    Two execution models live side by side:

    * ``process_invoice`` / ``process_batch`` drive the agents imperatively,
      choosing the path from ``workflow_type`` and using the
      ``_route_after_*`` methods to decide when to escalate.
    * ``_create_workflow_graph`` builds a LangGraph ``StateGraph`` (exposed as
      ``self.graph`` / ``self.compiled_graph``) whose conditional edges use
      the same ``_route_after_*`` methods along the standard path.

    Every intermediate state is kept in a plain in-memory dict
    (``process_id -> state``) that backs ``get_workflow_status``.
    """

    # Extraction confidence below this value routes a document to escalation.
    # Can be overridden with config["min_extraction_confidence"].
    DEFAULT_MIN_EXTRACTION_CONFIDENCE = 0.6

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Register the agents and build/compile the LangGraph workflow.

        Args:
            config: Optional orchestrator settings, stored as ``self.config``.
                ``"min_extraction_confidence"`` tunes the document-routing
                threshold (default ``DEFAULT_MIN_EXTRACTION_CONFIDENCE``).

        Graph construction errors are logged, not raised. In that case
        ``self.compiled_graph`` is ``None`` and ``self.graph`` is an empty
        ``StateGraph`` (or ``None`` if even that cannot be built).
        """
        self.logger = StructuredLogger("InvoiceProcessingGraph")
        self.config = config or {}
        # Simple in-memory store for process states (process_id -> InvoiceProcessingState)
        self._process_store: Dict[str, InvoiceProcessingState] = {}
        # Register and initialize agents
        self._initialize_agents()
        # Always define the attribute so callers can test it after a build failure.
        self.compiled_graph = None
        try:
            self.graph = self._create_workflow_graph()
            self.compiled_graph = self.graph.compile(checkpointer=MemorySaver())
            self.logger.logger.info("InvoiceProcessingGraph initialized successfully with compiled graph.")
        except Exception as e:
            self.logger.logger.warning(f"Failed to fully build graph nodes: {e} — exposing empty StateGraph")
            try:
                # StateGraph expects a state schema type, not a name string.
                self.graph = StateGraph(InvoiceProcessingState)
            except Exception:
                self.graph = None

    def _initialize_agents(self):
        """Instantiate every agent and register it in the global ``agent_registry``."""
        # Re-registering replaces any existing instance under the same name
        # (assumption about agent_registry.register - TODO confirm).
        agents = [
            DocumentAgent(),
            ValidationAgent(),
            RiskAgent(),
            PaymentAgent(),
            AuditAgent(),
            EscalationAgent(),
        ]
        for agent in agents:
            agent_registry.register(agent)
        self.logger.logger.info(f"Registered agents: {agent_registry.list_agents()}")

    def _create_workflow_graph(self) -> StateGraph:
        """
        Build the LangGraph ``StateGraph`` for invoice processing.

        Each agent becomes a node. After the document, validation, risk,
        payment and audit nodes, the matching ``_route_after_*`` method picks
        the next node through ``add_conditional_edges``. Escalation always
        flows into human review, and every path finishes in the ``"end"`` node
        and then ``END``.

        Errors from LangGraph propagate to the caller (``__init__`` handles
        them).
        """
        graph = StateGraph(InvoiceProcessingState)

        def as_node(runner: Callable[..., Any]) -> Callable[[InvoiceProcessingState], Any]:
            # Graph nodes receive only the state; the agent-node helpers read
            # the workflow type from the state when none is passed.
            async def node(state: InvoiceProcessingState) -> InvoiceProcessingState:
                return await runner(state)
            return node

        async def node_end(state: InvoiceProcessingState) -> InvoiceProcessingState:
            invoice_ref = getattr(state, "invoice_id", None) or getattr(state, "process_id", None)
            self.logger.logger.info(f"Invoice {invoice_ref} completed at {getattr(state, 'updated_at', None)}")
            return state

        # NODE DEFINITIONS
        graph.add_node("document", as_node(self._document_agent_node))
        graph.add_node("validation", as_node(self._validation_agent_node))
        graph.add_node("risk", as_node(self._risk_agent_node))
        graph.add_node("payment", as_node(self._payment_agent_node))
        graph.add_node("audit", as_node(self._audit_agent_node))
        graph.add_node("escalation", as_node(self._escalation_agent_node))
        graph.add_node("human_review", as_node(self._human_review_node))
        graph.add_node("end", node_end)
        graph.set_entry_point("document")

        # CONDITIONAL ROUTING - router return values map 1:1 to node names.
        graph.add_conditional_edges(
            "document", self._route_after_document,
            {"validation": "validation", "escalation": "escalation", "end": "end"},
        )
        graph.add_conditional_edges(
            "validation", self._route_after_validation,
            {"risk": "risk", "escalation": "escalation", "end": "end"},
        )
        graph.add_conditional_edges(
            "risk", self._route_after_risk,
            {"payment": "payment", "escalation": "escalation",
             "human_review": "human_review", "end": "end"},
        )
        graph.add_conditional_edges(
            "payment", self._route_after_payment,
            {"audit": "audit", "escalation": "escalation", "end": "end"},
        )
        graph.add_conditional_edges(
            "audit", self._route_after_audit,
            {"escalation": "escalation", "end": "end"},
        )

        # FIXED EDGES - exception flow and termination.
        graph.add_edge("escalation", "human_review")
        graph.add_edge("human_review", "end")
        graph.add_edge("end", END)

        self.logger.logger.info("Conditional workflow graph built successfully.")
        return graph

    # ------------------------------------------------------------------ #
    # Agent nodes
    # ------------------------------------------------------------------ #
    def _get_or_register_agent(self, name: str, factory: Callable[[], Any]) -> Any:
        """Return the registered agent ``name``; if it is missing, create it with ``factory`` and register it."""
        agent = agent_registry.get(name)
        if not agent:
            agent = factory()
            agent_registry.register(agent)
        return agent

    async def _run_agent(self, name: str, factory: Callable[[], Any],
                         state: InvoiceProcessingState,
                         workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """
        Run agent ``name`` on ``state`` and return the state it produces.

        When ``workflow_type`` is ``None`` (e.g. when called from a graph
        node), it is taken from ``state.workflow_type``, falling back to
        ``"standard"``.
        """
        agent = self._get_or_register_agent(name, factory)
        if workflow_type is None:
            workflow_type = getattr(state, "workflow_type", None) or "standard"
        return await agent.run(state, workflow_type)

    async def _document_agent_node(self, state: InvoiceProcessingState,
                                   workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the document-extraction agent."""
        return await self._run_agent("document_agent", DocumentAgent, state, workflow_type)

    async def _validation_agent_node(self, state: InvoiceProcessingState,
                                     workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the validation agent."""
        return await self._run_agent("validation_agent", ValidationAgent, state, workflow_type)

    async def _risk_agent_node(self, state: InvoiceProcessingState,
                               workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the risk-assessment agent."""
        return await self._run_agent("risk_agent", RiskAgent, state, workflow_type)

    async def _payment_agent_node(self, state: InvoiceProcessingState,
                                  workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the payment agent."""
        return await self._run_agent("payment_agent", PaymentAgent, state, workflow_type)

    async def _audit_agent_node(self, state: InvoiceProcessingState,
                                workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the audit agent."""
        return await self._run_agent("audit_agent", AuditAgent, state, workflow_type)

    async def _escalation_agent_node(self, state: InvoiceProcessingState,
                                     workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the escalation agent."""
        return await self._run_agent("escalation_agent", EscalationAgent, state, workflow_type)

    async def _human_review_node(self, state: InvoiceProcessingState,
                                 workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Human-review step; there is no dedicated agent, so the escalation agent is reused."""
        return await self._run_agent("escalation_agent", EscalationAgent, state, workflow_type)

    # ------------------------------------------------------------------ #
    # Routing
    # ------------------------------------------------------------------ #
    def _route_after_document(self, state: InvoiceProcessingState) -> Literal["validation", "escalation", "end"]:
        """Escalate when extraction produced no data or low confidence; otherwise validate."""
        data = state.invoice_data
        if not data:
            return "escalation"
        # invoice_data may be a model or a plain dict.
        if isinstance(data, dict):
            conf = data.get("extraction_confidence")
        else:
            conf = getattr(data, "extraction_confidence", None)
        threshold = self.config.get("min_extraction_confidence", self.DEFAULT_MIN_EXTRACTION_CONFIDENCE)
        if conf is not None and conf < threshold:
            return "escalation"
        return "validation"

    def _route_after_validation(self, state: InvoiceProcessingState) -> Literal["risk", "escalation", "end"]:
        """
        Escalate on a missing result, ``NO_MATCH``, or ``PARTIAL_MATCH`` with a
        mismatched amount; otherwise continue to risk assessment.
        """
        vr = state.validation_result
        if not vr:
            return "escalation"
        try:
            status = vr.validation_status
            # validation_status may be the enum or its raw string value.
            if isinstance(status, ValidationStatus):
                status_val = status
            elif isinstance(status, str):
                status_val = ValidationStatus(status)
            else:
                status_val = None
            if status_val == ValidationStatus.NO_MATCH or (
                status_val == ValidationStatus.PARTIAL_MATCH and not vr.amount_match
            ):
                return "escalation"
        except (ValueError, AttributeError):
            # Unknown status value or missing fields: fall back to discrepancies.
            if getattr(vr, "discrepancies", None):
                return "escalation"
        return "risk"

    def _route_after_risk(self, state: InvoiceProcessingState) -> Literal["payment", "escalation", "human_review", "end"]:
        """
        CRITICAL risk -> human review, HIGH -> escalation, anything else -> payment.

        A missing assessment or a missing risk level escalates instead of
        paying.
        """
        ra = state.risk_assessment
        if not ra:
            return "escalation"
        rl = getattr(ra, "risk_level", None)
        if rl is None:
            # Without a risk level there is no basis for payment: fail closed.
            return "escalation"
        # Accept the enum or a raw string; compare case-insensitively.
        level = str(getattr(rl, "value", rl)).lower()
        try:
            critical = str(RiskLevel.CRITICAL.value).lower()
            high = str(RiskLevel.HIGH.value).lower()
        except AttributeError:
            # Keep the historical default if the enum lacks these members.
            return "payment"
        if level == critical:
            return "human_review"
        if level == high:
            return "escalation"
        return "payment"

    def _route_after_payment(self, state: InvoiceProcessingState) -> Literal["audit", "escalation", "end"]:
        """
        Audit when the payment status is APPROVED, SCHEDULED or PENDING_APPROVAL.

        Other statuses escalate, and so does a missing decision. A decision
        whose status cannot be read still goes to audit (historical
        behavior).
        """
        pd = getattr(state, "payment_decision", None)
        if not pd:
            return "escalation"
        try:
            status = pd.payment_status
            # Accept enum or str
            status_val = status if isinstance(status, str) else getattr(status, "value", str(status))
            proceed = (
                PaymentStatus.APPROVED.value,
                PaymentStatus.SCHEDULED.value,
                PaymentStatus.PENDING_APPROVAL.value,
            )
        except AttributeError:
            return "audit"
        return "audit" if status_val in proceed else "escalation"

    def _route_after_audit(self, state: InvoiceProcessingState) -> Literal["escalation", "end"]:
        """Escalate when the compliance report lists any truthy issue; otherwise end."""
        cr = getattr(state, "compliance_report", None)
        if not cr:
            return "end"
        # The report may be a dict or an object exposing an ``issues`` attribute.
        issues = cr.get("issues", {}) if isinstance(cr, dict) else getattr(cr, "issues", None)
        if isinstance(issues, dict):
            has_issues = any(issues.values())
        elif isinstance(issues, (list, tuple, set)):
            has_issues = any(issues)
        else:
            has_issues = bool(issues)
        return "escalation" if has_issues else "end"

    # ------------------------------------------------------------------ #
    # Imperative execution
    # ------------------------------------------------------------------ #
    async def _handle_escalation_chain(self, state: "InvoiceProcessingState", workflow_type):
        """
        Run escalation, then human review, and mark the process COMPLETED.

        The state is persisted after each step. Note that human review reuses
        the escalation agent, so that agent runs twice in this chain.
        """
        state = await self._escalation_agent_node(state, workflow_type)
        self._process_store[state.process_id] = state
        state = await self._human_review_node(state, workflow_type)
        state.overall_status = ProcessingStatus.COMPLETED
        state.updated_at = datetime.utcnow()
        self._process_store[state.process_id] = state
        return state

    async def process_invoice(self, file_name: str, workflow_type: str = "standard",
                              config: Dict[str, Any] = None
                              ) -> Tuple[InvoiceProcessingState, List[str]]:
        """
        Orchestrate processing for a single invoice file.

        Workflow types:
          * ``"expedited"``: document -> validation -> payment. Risk is skipped,
            and audit runs only when the payment decision is ``"auto-pay"``.
          * ``"high_value"``: document -> validation (run twice) -> risk ->
            audit -> mandatory human review. There is no payment step.
          * anything else (``"standard"``): document -> validation -> risk ->
            payment -> audit.

        An ``"escalation"`` route (or ``"human_review"`` after risk) stops the
        normal path and runs the escalation -> human-review chain instead.

        Args:
            file_name: Invoice file to process.
            workflow_type: Workflow selector (see above).
            config: Currently unused; accepted for API compatibility.

        Returns:
            ``(final_state, worked_agents)``. ``worked_agents`` holds the
            ``agent_name`` read after each tracked step. Exceptions are not
            raised: the state is marked ``FAILED`` and returned.
        """
        process_id = f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
        now = datetime.utcnow()
        initial_state = InvoiceProcessingState(
            process_id=process_id,
            file_name=file_name,
            overall_status=ProcessingStatus.PENDING,
            current_agent=None,
            workflow_type=workflow_type,
            created_at=now,
            updated_at=now,
        )
        self._process_store[process_id] = initial_state
        start_ts = datetime.utcnow()
        state = initial_state
        worked_agents: List[str] = []

        async def step(runner: Callable[..., Any], track: bool = True) -> None:
            # Run one agent, persist its output and (optionally) record its name.
            nonlocal state
            state = await runner(state, workflow_type)
            self._process_store[process_id] = state
            if track:
                worked_agents.append(state.agent_name)

        async def escalate() -> Tuple[InvoiceProcessingState, List[str]]:
            # Divert into the escalation chain and produce the early-exit result.
            nonlocal state
            state = await self._handle_escalation_chain(state, workflow_type)
            worked_agents.append(state.agent_name)
            return state, worked_agents

        try:
            # STEP 1: document extraction (shared by all workflows)
            await step(self._document_agent_node)
            if self._route_after_document(state) == "escalation":
                return await escalate()

            if workflow_type == "expedited":
                # Fast lane: validation always runs; risk is skipped entirely.
                await step(self._validation_agent_node)
                if self._route_after_validation(state) == "escalation":
                    return await escalate()
                # Directly go to payment, minimal audit
                await step(self._payment_agent_node)
                # payment_decision may be missing, None, a dict or an object.
                pd = getattr(state, "payment_decision", None)
                decision = pd.get("decision") if isinstance(pd, dict) else getattr(pd, "decision", None)
                if decision == "auto-pay":
                    await step(self._audit_agent_node)

            elif workflow_type == "high_value":
                # Validation twice for accuracy; only the second run is tracked.
                await step(self._validation_agent_node, track=False)
                await step(self._validation_agent_node)
                if self._route_after_validation(state) == "escalation":
                    return await escalate()
                await step(self._risk_agent_node)
                if self._route_after_risk(state) in ("escalation", "human_review"):
                    return await escalate()
                await step(self._audit_agent_node)
                # Mandatory human review for high-value invoices
                await step(self._human_review_node)

            else:  # STANDARD workflow
                await step(self._validation_agent_node)
                if self._route_after_validation(state) == "escalation":
                    return await escalate()
                await step(self._risk_agent_node)
                if self._route_after_risk(state) in ("escalation", "human_review"):
                    return await escalate()
                await step(self._payment_agent_node)
                if self._route_after_payment(state) == "escalation":
                    return await escalate()
                await step(self._audit_agent_node)

            # Success completion
            state.overall_status = ProcessingStatus.COMPLETED
            state.updated_at = datetime.utcnow()
            elapsed = (datetime.utcnow() - start_ts).total_seconds()
            self.logger.logger.info(f"Process {process_id} ({workflow_type}) completed in {elapsed:.2f}s")
            self._process_store[process_id] = state
            return state, worked_agents
        except Exception as e:
            # Top-level boundary: log with traceback and report failure via state.
            self.logger.logger.exception(f"Error processing invoice {file_name}: {e}")
            state.overall_status = ProcessingStatus.FAILED
            state.updated_at = datetime.utcnow()
            self._process_store[process_id] = state
            return state, worked_agents

    async def process_batch(self, file_names: List[str], workflow_type: str = "standard",
                            max_concurrent: int = 5) -> List[Dict[str, Any]]:
        """
        Process many invoice files concurrently, at most ``max_concurrent`` at a time.

        Args:
            file_names: Files to process.
            workflow_type: Workflow applied to every file.
            max_concurrent: Concurrency limit. Values below 1 are treated as 1,
                because a zero-permit semaphore would block forever.

        Returns:
            One ``{"state": ..., "worked_agents": [...]}`` dict per file, in
            input order.
        """
        sem = asyncio.Semaphore(max(1, max_concurrent))

        async def _worker(fn: str):
            async with sem:
                return await self.process_invoice(fn, workflow_type=workflow_type)

        tasks = [asyncio.create_task(_worker(f)) for f in file_names]
        completed = await asyncio.gather(*tasks)
        return [
            {"state": state, "worked_agents": worked_agents}
            for state, worked_agents in completed
        ]

    async def get_workflow_status(self, process_id: str) -> Optional[Dict[str, Any]]:
        """Return a status snapshot for ``process_id``, or ``None`` if it is unknown."""
        state = self._process_store.get(process_id)
        if not state:
            return None
        return {"process_id": process_id, "status": state.overall_status,
                "updated_at": getattr(state, "updated_at", None), "state": state.model_dump()}

    async def health_check(self) -> Dict[str, Any]:
        """Combine the agents' health-check results with the orchestrator's own status."""
        agents_health = await agent_registry.health_check_all()
        return {"orchestrator": "Healthy", "timestamp": datetime.utcnow().isoformat(), "agent": agents_health}

    def _extract_final_state(self, result, initial_state: InvoiceProcessingState) -> InvoiceProcessingState:
        """Compatibility helper: returns ``result`` unchanged (``initial_state`` is ignored)."""
        return result
# Lazily-created, module-wide orchestrator instance (populated by get_workflow).
invoice_workflow: Optional[InvoiceProcessingGraph] = None


def get_workflow(config: Dict[str, Any] = None) -> InvoiceProcessingGraph:
    """
    Return the shared InvoiceProcessingGraph, creating it on the first call.

    ``config`` is only used when the instance is first constructed; later
    calls return the existing instance and ignore ``config``.
    """
    global invoice_workflow
    if invoice_workflow is not None:
        return invoice_workflow
    invoice_workflow = InvoiceProcessingGraph(config=config)
    return invoice_workflow