File size: 22,059 Bytes
2b44e69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
"""LangGraph workflow orchestrator"""
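# NOTE: process_invoice() runs the agents directly; the LangGraph graph built by
# _create_workflow_graph() is compiled in __init__ but is not invoked by process_invoice().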

import asyncio
import uuid # extra import
from typing import Dict, Any, List, Optional, Literal, Tuple
from datetime import datetime
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

from state import (
    InvoiceProcessingState, ProcessingStatus, ValidationStatus,
    RiskLevel, PaymentStatus, WORKFLOW_CONFIGS
)
from agents.base_agent import agent_registry
from agents.document_agent import DocumentAgent
from agents.validation_agent import ValidationAgent
from agents.risk_agent import RiskAgent
from agents.payment_agent import PaymentAgent
from agents.audit_agent import AuditAgent
from agents.escalation_agent import EscalationAgent
from utils.logger import StructuredLogger


class InvoiceProcessingGraph:
    """Orchestrate the invoice-processing agents for single invoices and batches.

    The agents (document, validation, risk, payment, audit, escalation) are
    registered in the shared ``agent_registry`` on construction. Two execution
    paths exist:

    * ``process_invoice`` / ``process_batch`` drive the agents step by step,
      branching on the workflow type (``standard``, ``high_value``,
      ``expedited``).
    * ``_create_workflow_graph`` builds a LangGraph ``StateGraph`` using the
      same ``_route_after_*`` rules. It is compiled into
      ``self.compiled_graph``, which is ``None`` when construction fails.

    Every intermediate state is written to an in-memory
    ``process_id -> state`` store that backs ``get_workflow_status``.
    """

    def __init__(self, config: Dict[str, Any] = None):
        self.logger = StructuredLogger("InvoiceProcessingGraph")
        self.config = config or {}
        # Simple in-memory store for process states (process_id -> InvoiceProcessingState)
        self._process_store: Dict[str, InvoiceProcessingState] = {}
        # Register and initialize agents
        self._initialize_agents()
        # Stays None if building or compiling the graph fails below.
        self.compiled_graph = None
        try:
            self.graph = self._create_workflow_graph()
            self.compiled_graph = self.graph.compile(checkpointer=MemorySaver())
            self.logger.logger.info("InvoiceProcessingGraph initialized successfully with compiled graph.")
        except Exception as e:
            # Best effort: process_invoice() does not depend on the graph.
            self.logger.logger.warning(f"Failed to fully build graph nodes: {e} — exposing empty StateGraph")
            self.graph = StateGraph("invoice_processing_graph_fallback")

    def _initialize_agents(self):
        """Instantiate every agent and register it in the global ``agent_registry``.

        Fresh instances are registered on each call. Whether an existing entry
        is replaced depends on ``agent_registry.register`` (not visible here).
        """
        agents = [
            DocumentAgent(),
            ValidationAgent(),
            RiskAgent(),
            PaymentAgent(),
            AuditAgent(),
            EscalationAgent(),
        ]
        for agent in agents:
            agent_registry.register(agent)
        self.logger.logger.info(f"Registered agents: {agent_registry.list_agents()}")

    def _create_workflow_graph(self) -> StateGraph:
        """Build the (uncompiled) LangGraph workflow with conditional routing.

        Each agent node runs its agent and returns the updated state. The
        ``_route_after_*`` methods are attached as conditional edges, so only
        the branch they select runs next. The happy path is
        document -> validation -> risk -> payment -> audit -> end. The
        exception path is escalation -> human_review -> end.

        Raises:
            Any error from node/edge registration, so that ``__init__`` can
            log it and fall back instead of exposing a half-built graph.
        """
        graph = StateGraph(InvoiceProcessingState)

        def _agent_node(step):
            # Adapt an agent-step coroutine to LangGraph's single-argument
            # node signature. The workflow type travels on the state itself.
            async def _node(state: InvoiceProcessingState) -> InvoiceProcessingState:
                return await step(state, getattr(state, "workflow_type", None))
            return _node

        async def node_end(state: InvoiceProcessingState) -> InvoiceProcessingState:
            invoice_ref = getattr(state, "invoice_id", None) or state.process_id
            self.logger.logger.info(f"Invoice {invoice_ref} completed at {state.updated_at}")
            return state

        # REGISTER NODES
        steps = {
            "document": self._document_agent_node,
            "validation": self._validation_agent_node,
            "risk": self._risk_agent_node,
            "payment": self._payment_agent_node,
            "audit": self._audit_agent_node,
            "escalation": self._escalation_agent_node,
            "human_review": self._human_review_node,
        }
        for name, step in steps.items():
            graph.add_node(name, _agent_node(step))
        graph.add_node("end", node_end)

        # ROUTING: each router's possible return values map to same-named nodes.
        graph.set_entry_point("document")
        routes = {
            "document": (self._route_after_document, ("validation", "escalation", "end")),
            "validation": (self._route_after_validation, ("risk", "escalation", "end")),
            "risk": (self._route_after_risk, ("payment", "escalation", "human_review", "end")),
            "payment": (self._route_after_payment, ("audit", "escalation", "end")),
            "audit": (self._route_after_audit, ("escalation", "end")),
        }
        for source, (router, targets) in routes.items():
            graph.add_conditional_edges(source, router, {target: target for target in targets})

        # Exception path and termination are unconditional.
        graph.add_edge("escalation", "human_review")
        graph.add_edge("human_review", "end")
        graph.add_edge("end", END)

        self.logger.logger.info("Conditional workflow graph built successfully.")
        return graph

    async def _run_agent(self, name: str, factory, state: InvoiceProcessingState,
                         workflow_type: Optional[str]) -> InvoiceProcessingState:
        """Run the registry agent ``name`` on ``state`` and return its result.

        If the registry has no agent under ``name``, ``factory()`` is created
        and registered first. When ``workflow_type`` is ``None`` it is read
        from ``state.workflow_type``.
        """
        agent = agent_registry.get(name)
        if not agent:
            agent = factory()
            agent_registry.register(agent)
        if workflow_type is None:
            workflow_type = getattr(state, "workflow_type", None)
        self.logger.logger.debug(f"Running {name} (registry id {id(agent_registry)}): {agent}")
        return await agent.run(state, workflow_type)

    async def _document_agent_node(self, state: InvoiceProcessingState,
                                   workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the document-extraction agent."""
        return await self._run_agent("document_agent", DocumentAgent, state, workflow_type)

    async def _validation_agent_node(self, state: InvoiceProcessingState,
                                     workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the validation agent."""
        return await self._run_agent("validation_agent", ValidationAgent, state, workflow_type)

    async def _risk_agent_node(self, state: InvoiceProcessingState,
                               workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the risk-assessment agent."""
        return await self._run_agent("risk_agent", RiskAgent, state, workflow_type)

    async def _payment_agent_node(self, state: InvoiceProcessingState,
                                  workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the payment agent."""
        return await self._run_agent("payment_agent", PaymentAgent, state, workflow_type)

    async def _audit_agent_node(self, state: InvoiceProcessingState,
                                workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the audit agent."""
        return await self._run_agent("audit_agent", AuditAgent, state, workflow_type)

    async def _escalation_agent_node(self, state: InvoiceProcessingState,
                                     workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Run the escalation agent."""
        return await self._run_agent("escalation_agent", EscalationAgent, state, workflow_type)

    async def _human_review_node(self, state: InvoiceProcessingState,
                                 workflow_type: Optional[str] = None) -> InvoiceProcessingState:
        """Human-review step. It currently reuses the escalation agent."""
        return await self._run_agent("escalation_agent", EscalationAgent, state, workflow_type)

    @staticmethod
    def _normalize_enum_value(value: Any) -> str:
        """Return an enum member or raw value as a lowercase string for comparison."""
        return str(getattr(value, "value", value)).lower()

    def _route_after_document(self, state: InvoiceProcessingState) -> Literal["validation", "escalation", "end"]:
        """Escalate when extraction produced no data or confidence below 0.6."""
        if not state.invoice_data:
            return "escalation"
        conf = getattr(state.invoice_data, "extraction_confidence", None)
        if conf is not None and conf < 0.6:
            return "escalation"
        return "validation"

    def _route_after_validation(self, state: InvoiceProcessingState) -> Literal["risk", "escalation", "end"]:
        """Escalate on a missing result, NO_MATCH, or PARTIAL_MATCH with an amount mismatch."""
        vr = state.validation_result
        if not vr:
            return "escalation"
        try:
            status = vr.validation_status
            # ValidationStatus may arrive as the enum member or as its raw string value.
            if isinstance(status, ValidationStatus):
                status_val = status
            elif isinstance(status, str):
                status_val = ValidationStatus(status)
            else:
                status_val = None
            if status_val == ValidationStatus.NO_MATCH or (
                status_val == ValidationStatus.PARTIAL_MATCH and not vr.amount_match
            ):
                return "escalation"
        except Exception:
            # Unrecognised status or missing fields: escalate only if discrepancies were recorded.
            if getattr(vr, "discrepancies", None):
                return "escalation"
        return "risk"

    def _route_after_risk(self, state: InvoiceProcessingState) -> Literal["payment", "escalation", "human_review", "end"]:
        """Route CRITICAL risk to human review and HIGH to escalation; anything else goes to payment.

        The comparison is case-insensitive and accepts enum members or raw values.
        A missing assessment escalates. A missing ``risk_level`` on an existing
        assessment falls through to payment.
        """
        ra = state.risk_assessment
        if not ra:
            return "escalation"
        level = self._normalize_enum_value(getattr(ra, "risk_level", None))
        try:
            if level == self._normalize_enum_value(RiskLevel.CRITICAL):
                return "human_review"
            if level == self._normalize_enum_value(RiskLevel.HIGH):
                return "escalation"
        except Exception:
            pass
        return "payment"

    def _route_after_payment(self, state: InvoiceProcessingState) -> Literal["audit", "escalation", "end"]:
        """Audit approved, scheduled or pending-approval payments; escalate any other status.

        A missing payment decision escalates. If the status cannot be read,
        routing falls back to audit.
        """
        pd = getattr(state, "payment_decision", None)
        if not pd:
            return "escalation"
        try:
            status_val = self._normalize_enum_value(pd.payment_status)
            auditable = {
                self._normalize_enum_value(PaymentStatus.APPROVED),
                self._normalize_enum_value(PaymentStatus.SCHEDULED),
                self._normalize_enum_value(PaymentStatus.PENDING_APPROVAL),
            }
            return "audit" if status_val in auditable else "escalation"
        except Exception:
            return "audit"

    def _route_after_audit(self, state: InvoiceProcessingState) -> Literal["escalation", "end"]:
        """Escalate when the compliance report lists any issue; otherwise finish.

        Only dict reports are inspected. Their ``"issues"`` entry may be a dict
        (any truthy value counts), a list/tuple/set (any truthy item counts),
        or a scalar (its truthiness counts).
        """
        cr = getattr(state, "compliance_report", None)
        if not cr or not isinstance(cr, dict):
            return "end"
        issues = cr.get("issues", {})
        if isinstance(issues, dict):
            has_issues = any(issues.values())
        elif isinstance(issues, (list, tuple, set)):
            has_issues = any(issues)
        else:
            has_issues = bool(issues)
        return "escalation" if has_issues else "end"

    async def _handle_escalation_chain(self, state: "InvoiceProcessingState", workflow_type):
        """Run escalation, then human review, and mark the process COMPLETED.

        The state is stored after each step.
        """
        state = await self._escalation_agent_node(state, workflow_type)
        self._process_store[state.process_id] = state
        state = await self._human_review_node(state, workflow_type)
        state.overall_status = ProcessingStatus.COMPLETED
        self._process_store[state.process_id] = state
        return state

    async def _run_step(self, step, state: InvoiceProcessingState, workflow_type: str,
                        worked_agents: List[str], *, record: bool = True) -> InvoiceProcessingState:
        """Run one agent step, store the resulting state, and optionally record the agent name."""
        state = await step(state, workflow_type)
        self._process_store[state.process_id] = state
        if record:
            worked_agents.append(state.agent_name)
        return state

    async def _escalate(self, state: InvoiceProcessingState, workflow_type: str,
                        worked_agents: List[str]) -> Tuple[InvoiceProcessingState, List[str]]:
        """Run the escalation chain and build the ``(state, worked_agents)`` return value."""
        state = await self._handle_escalation_chain(state, workflow_type)
        worked_agents.append(state.agent_name)
        return state, worked_agents

    async def process_invoice(self, file_name: str, workflow_type: str = "standard",
                              config: Dict[str, Any] = None) -> Tuple[InvoiceProcessingState, List[str]]:
        """Run one invoice file through the agent pipeline.

        Every workflow starts with document extraction, then branches:

        * ``expedited``: validation -> payment, with audit only when the
          payment decision is ``"auto-pay"``.
        * ``high_value``: validation (twice) -> risk -> audit -> mandatory
          human review.
        * anything else (``standard``): validation -> risk -> payment -> audit.

        Whenever a routing check selects escalation (or human review after
        risk), the escalation chain runs and its result is returned
        immediately.

        Args:
            file_name: Invoice file to process.
            workflow_type: Workflow selector (see above).
            config: Currently unused; accepted for API compatibility.

        Returns:
            ``(state, worked_agents)``: the final state and the
            ``agent_name`` recorded after each step. The final state is
            ``COMPLETED``, or ``FAILED`` if any step raised. Errors are
            logged, not re-raised.
        """
        process_id = f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
        initial_state = InvoiceProcessingState(
            process_id=process_id,
            file_name=file_name,
            overall_status=ProcessingStatus.PENDING,
            current_agent=None,
            workflow_type=workflow_type,
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
        )

        self._process_store[process_id] = initial_state
        start_ts = datetime.utcnow()
        state = initial_state
        worked_agents: List[str] = []
        try:
            # STEP 1: document extraction (all workflow types)
            state = await self._run_step(self._document_agent_node, state, workflow_type, worked_agents)
            if self._route_after_document(state) == "escalation":
                return await self._escalate(state, workflow_type, worked_agents)

            if workflow_type == "expedited":
                # Fast lane: validation, then straight to payment; audit only for auto-pay.
                state = await self._run_step(self._validation_agent_node, state, workflow_type, worked_agents)
                if self._route_after_validation(state) == "escalation":
                    return await self._escalate(state, workflow_type, worked_agents)

                state = await self._run_step(self._payment_agent_node, state, workflow_type, worked_agents)
                # payment_decision may be None/missing; do not crash the whole run on it.
                payment_decision = getattr(state, "payment_decision", None)
                if getattr(payment_decision, "decision", None) == "auto-pay":
                    state = await self._run_step(self._audit_agent_node, state, workflow_type, worked_agents)

            elif workflow_type == "high_value":
                # Validation runs twice for accuracy; only the second pass is recorded.
                state = await self._run_step(self._validation_agent_node, state, workflow_type,
                                             worked_agents, record=False)
                state = await self._run_step(self._validation_agent_node, state, workflow_type, worked_agents)
                if self._route_after_validation(state) == "escalation":
                    return await self._escalate(state, workflow_type, worked_agents)

                state = await self._run_step(self._risk_agent_node, state, workflow_type, worked_agents)
                if self._route_after_risk(state) in ("escalation", "human_review"):
                    return await self._escalate(state, workflow_type, worked_agents)

                state = await self._run_step(self._audit_agent_node, state, workflow_type, worked_agents)
                # Mandatory human review for high-value invoices
                state = await self._run_step(self._human_review_node, state, workflow_type, worked_agents)

            else:  # STANDARD workflow
                state = await self._run_step(self._validation_agent_node, state, workflow_type, worked_agents)
                if self._route_after_validation(state) == "escalation":
                    return await self._escalate(state, workflow_type, worked_agents)

                state = await self._run_step(self._risk_agent_node, state, workflow_type, worked_agents)
                if self._route_after_risk(state) in ("escalation", "human_review"):
                    return await self._escalate(state, workflow_type, worked_agents)

                state = await self._run_step(self._payment_agent_node, state, workflow_type, worked_agents)
                if self._route_after_payment(state) == "escalation":
                    return await self._escalate(state, workflow_type, worked_agents)

                state = await self._run_step(self._audit_agent_node, state, workflow_type, worked_agents)

            # Success completion
            state.overall_status = ProcessingStatus.COMPLETED
            state.updated_at = datetime.utcnow()
            elapsed = (datetime.utcnow() - start_ts).total_seconds()
            self.logger.logger.info(f"Process {process_id} ({workflow_type}) completed in {elapsed:.2f}s")
            self._process_store[process_id] = state
            return state, worked_agents

        except Exception as e:
            self.logger.logger.exception(f"Error processing invoice {file_name}: {e}")
            state.overall_status = ProcessingStatus.FAILED
            state.updated_at = datetime.utcnow()
            self._process_store[process_id] = state
            return state, worked_agents

    async def process_batch(self, file_names: List[str], workflow_type: str = "standard",
                            max_concurrent: int = 5) -> List[Dict[str, Any]]:
        """Process several files concurrently, with at most ``max_concurrent`` running at once.

        Returns:
            One ``{"state": ..., "worked_agents": [...]}`` dict per input file,
            in input order.

        Raises:
            ValueError: if ``max_concurrent`` is less than 1 (a zero-sized
                semaphore would otherwise block forever).
        """
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be at least 1")
        sem = asyncio.Semaphore(max_concurrent)

        async def _worker(fn: str):
            async with sem:
                return await self.process_invoice(fn, workflow_type=workflow_type)

        tasks = [asyncio.create_task(_worker(f)) for f in file_names]
        completed = await asyncio.gather(*tasks)
        return [{"state": state, "worked_agents": worked_agents} for state, worked_agents in completed]

    async def get_workflow_status(self, process_id: str) -> Optional[Dict[str, Any]]:
        """Return the latest stored status for ``process_id``, or ``None`` if it is unknown."""
        state = self._process_store.get(process_id)
        if not state:
            return None
        return {
            "process_id": process_id,
            "status": state.overall_status,
            "updated_at": getattr(state, "updated_at", None),
            "state": state.model_dump(),
        }

    async def health_check(self) -> Dict[str, Any]:
        """Report orchestrator health together with ``agent_registry.health_check_all()``."""
        agents_health = await agent_registry.health_check_all()
        return {"orchestrator": "Healthy", "timestamp": datetime.utcnow().isoformat(), "agent": agents_health}

    def _extract_final_state(self, result, initial_state: InvoiceProcessingState) -> InvoiceProcessingState:
        """Compatibility helper: returns ``result`` unchanged (``initial_state`` is ignored)."""
        return result


# Lazily created, process-wide orchestrator instance returned by get_workflow().
invoice_workflow: Optional[InvoiceProcessingGraph] = None

def get_workflow(config: Dict[str, Any] = None) -> InvoiceProcessingGraph:
    """Return the shared InvoiceProcessingGraph, creating it on the first call.

    ``config`` is only used when the instance is first created; later calls
    return the existing instance unchanged.
    """
    global invoice_workflow
    if invoice_workflow is not None:
        return invoice_workflow
    invoice_workflow = InvoiceProcessingGraph(config=config)
    return invoice_workflow