PARTHASAKHAPAUL committed on
Commit
2b44e69
·
1 Parent(s): f64784b

Initial Commit 0.1

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Project/.dockerignore +12 -0
  2. Project/.gitignore +12 -0
  3. Project/agents/audit_agent.py +693 -0
  4. Project/agents/base_agent.py +221 -0
  5. Project/agents/document_agent.py +411 -0
  6. Project/agents/escalation_agent.py +315 -0
  7. Project/agents/forecast_agent.py +253 -0
  8. Project/agents/insights_agent.py +107 -0
  9. Project/agents/payment_agent.py +348 -0
  10. Project/agents/risk_agent.py +644 -0
  11. Project/agents/smart_explainer_agent.py +220 -0
  12. Project/agents/validation_agent.py +357 -0
  13. Project/bounding_box.py +138 -0
  14. Project/data/annotated_invoice.pdf +0 -0
  15. Project/data/invoices/Invoice-01.pdf +0 -0
  16. Project/data/invoices/Invoice-02.pdf +0 -0
  17. Project/data/invoices/Invoice-03.pdf +0 -0
  18. Project/data/invoices/Invoice-04.pdf +0 -0
  19. Project/data/invoices/Invoice-05.pdf +0 -0
  20. Project/data/invoices/Invoice-06.pdf +0 -0
  21. Project/data/invoices/Invoice-07.pdf +0 -0
  22. Project/data/invoices/Invoice-08.pdf +0 -0
  23. Project/data/invoices/Invoice-09.pdf +0 -0
  24. Project/data/invoices/Invoice-10.pdf +0 -0
  25. Project/data/invoices/Invoice-11.pdf +0 -0
  26. Project/data/invoices/Invoice-12.pdf +0 -0
  27. Project/data/invoices/Invoice-13.pdf +0 -0
  28. Project/data/invoices/Invoice-14.pdf +0 -0
  29. Project/data/invoices/Invoice-15.pdf +0 -0
  30. Project/data/invoices/Invoice-16.pdf +0 -0
  31. Project/data/invoices/Invoice-17.pdf +0 -0
  32. Project/data/invoices/Invoice-18.pdf +0 -0
  33. Project/data/invoices/Invoice-19.pdf +0 -0
  34. Project/data/invoices/Invoice-20.pdf +0 -0
  35. Project/data/invoices/Invoice-21.pdf +0 -0
  36. Project/data/invoices/Invoice-22.pdf +0 -0
  37. Project/data/invoices/Invoice-23.pdf +0 -0
  38. Project/data/invoices/Invoice-24.pdf +0 -0
  39. Project/data/invoices/Invoice-25.pdf +0 -0
  40. Project/data/invoices/Invoice-26.pdf +0 -0
  41. Project/data/invoices/Invoice-27.pdf +0 -0
  42. Project/data/invoices/Invoice-28.pdf +0 -0
  43. Project/data/invoices/Invoice-29.pdf +0 -0
  44. Project/data/invoices/Invoice-30.pdf +0 -0
  45. Project/data/invoices/Invoice-31.pdf +0 -0
  46. Project/data/invoices/Invoice-32.pdf +0 -0
  47. Project/data/invoices/Invoice-33.pdf +0 -0
  48. Project/data/invoices/Invoice-34.pdf +0 -0
  49. Project/data/invoices/Invoice-35.pdf +0 -0
  50. Project/data/invoices/Invoice-36.pdf +0 -0
Project/.dockerignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .env
6
+ .git
7
+ .gitignore
8
+ myenv/
9
+ venv/
10
+ .env/
11
+ .venv/
12
+ tests.py
Project/.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ myenv
2
+ .env
3
+ key_stats.json
4
+ tests.py
5
+ Dockerfile
6
+ run.py
7
+ tests.py
8
+ __pycache__/
9
+ *.pyc
10
+ logs/audit
11
+ nodes
12
+ output/escalations
Project/agents/audit_agent.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Audit Agent for Invoice Processing"""
3
+
4
+ # TODO: Implement agent
5
+
6
+ import os
7
+ import json
8
+ import pandas as pd
9
+ from typing import Dict, Any, List, Optional
10
+ from datetime import datetime, timedelta
11
+ import google.generativeai as genai
12
+ from dotenv import load_dotenv
13
+ import time
14
+ from statistics import mean
15
+
16
+ from agents.base_agent import BaseAgent
17
+ from state import (
18
+ InvoiceProcessingState, ProcessingStatus, PaymentStatus,
19
+ ValidationStatus, RiskLevel
20
+ )
21
+ from utils.logger import StructuredLogger
22
+
23
+ load_dotenv()
24
+
25
+
26
class AuditAgent(BaseAgent):
    """Agent responsible for audit trail generation, compliance tracking, and reporting"""

    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the audit agent.

        Args:
            config: Optional configuration dict forwarded to BaseAgent.
        """
        super().__init__("audit_agent",config)
        self.logger = StructuredLogger("AuditAgent")
        # --- Health tracking ---
        # Rolling window of per-run outcome records consumed by health_check().
        self.execution_history: List[Dict[str, Any]] = []
        self.max_history = 50  # store last 50 runs
35
+
36
    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """
        Ensure that the state object is properly initialized before invoice processing begins.
        Checks for presence of critical fields like process_id, file_name, and timestamps.

        Returns:
            True when the state carries all required identifiers, timestamps and
            a workable status; False otherwise.
        """
        if not state:
            return False

        # Must have valid process id and file name
        if not getattr(state, "process_id", None) or not getattr(state, "file_name", None):
            return False

        # Must have timestamps and valid status
        if not getattr(state, "created_at", None) or not getattr(state, "overall_status", None):
            return False

        # Reject runs that have already failed or never progressed past pending.
        # NOTE(review): this membership test compares overall_status against
        # plain strings; if overall_status is a ProcessingStatus enum member the
        # test may never match — confirm the expected type.
        if state.overall_status in ("failed", "pending"):
            return False

        return True
57
+
58
+
59
    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """
        Validate that all expected outputs and audit data are present after processing.
        Ensures that critical workflow components completed successfully.

        Returns:
            True when invoice data, validation results and at least one audit
            trail entry exist and the state passed basic sanity checks.
        """
        if not state:
            return False

        # Must have processed invoice data and validation results
        if not state.invoice_data or not state.validation_result:
            return False

        # Must have at least one audit entry for traceability
        if not state.audit_trail or len(state.audit_trail) == 0:
            return False

        # Risk or payment results may be optional, but check consistency.
        # NOTE(review): assumes risk_score is normalized to [0, 1] — confirm
        # against the risk agent's output.
        if state.risk_assessment and state.risk_assessment.risk_score > 1.0:
            return False  # sanity check for invalid scores

        # Final status should not be pending anymore
        if state.overall_status == "pending":
            return False

        return True
84
+
85
+
86
+ async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
87
+ """Main audit generation workflow"""
88
+ self.logger.logger.info("Starting audit trail generation")
89
+ start_time = time.time()
90
+ success = False
91
+ try:
92
+ if not self._validate_preconditions(state, workflow_type):
93
+ self.logger.logger.error("Preconditions not met for audit generation")
94
+ state.overall_status = ProcessingStatus.FAILED
95
+ self._log_decision(state, "Audit Failed", "Preconditions not met", confidence=0.0)
96
+ return state
97
+
98
+ audit_record = await self._generate_audit_record(state)
99
+ print("audit_record---------", audit_record)
100
+ compliance_results = await self._perform_compliance_checks(state,audit_record)
101
+ print("compliance_results---------", compliance_results)
102
+ audit_summary = await self._generate_audit_summary(state,audit_record,compliance_results)
103
+ print("audit_summary---------", audit_summary)
104
+ await self._save_audit_records(state,audit_record,audit_summary,compliance_results)
105
+
106
+ reportable_events = await self._identify_reportable_events(state,audit_record)
107
+ print("reportable_events---------", reportable_events)
108
+
109
+ await self._generate_audit_alerts(state,reportable_events)
110
+
111
+ state.audit_trail = audit_record.get("audit_trail",[])
112
+ print("state.audit_trail---------", state.audit_trail)
113
+ state.compliance_report = compliance_results
114
+ state.current_agent = "audit_agent"
115
+ state.overall_status = "completed"
116
+
117
+ self.logger.logger.info("Audit trail and compliance generated successfully")
118
+ success = True
119
+ self._log_decision(
120
+ state,
121
+ "Auditing Successful",
122
+ "Auditing Processed",
123
+ 100.0,
124
+ state.process_id
125
+ )
126
+ state.audit_trail[-1]
127
+ return state
128
+
129
+ except Exception as e:
130
+ self.logger.logger.error(f"Audit agent execution failed: {e}")
131
+ state.overall_status = ProcessingStatus.FAILED
132
+ return state
133
+
134
+ finally:
135
+ duration_ms = round((time.time() - start_time) * 1000, 2)
136
+ self._record_execution(success, duration_ms, state)
137
+
138
    async def _generate_audit_record(self, state: InvoiceProcessingState) -> Dict[str, Any]:
        """
        Aggregate and structure all agent-level logs into a consistent audit report.
        Uses the state's existing audit_trail list and agent_metrics for detailed tracking.

        Args:
            state: Workflow state whose audit trail is being summarized.

        Returns:
            Dict with process metadata, normalized trail entries, and a
            per-agent metrics summary.

        Raises:
            ValueError: If ``state`` is not an InvoiceProcessingState.
        """
        self.logger.logger.debug("Generating audit record")

        if not isinstance(state, InvoiceProcessingState):
            raise ValueError("Invalid state object passed to _generate_audit_record")

        # Normalize each trail entry (attribute-style objects) into a plain
        # dict, falling back to state-level / default values via getattr.
        audit_trail_records = []
        for entry in getattr(state, "audit_trail", []):
            record = {
                "process_id": getattr(entry, "process_id", state.process_id),
                "timestamp": getattr(entry, "timestamp", datetime.utcnow().isoformat() + "Z"),
                "agent_name": getattr(entry, "agent_name", "unknown"),
                "action": getattr(entry, "action", "undefined"),
                # "status": getattr(entry, "status", "completed"),
                "details": getattr(entry, "details", {}),
                # "duration_ms": getattr(entry, "details", {}).get("duration_ms", 0),
                # "error_message": getattr(entry, "details", {}).get("error_message", None),
                # NOTE(review): status/duration_ms/error_message are disabled
                # above, but _identify_reportable_events reads those keys —
                # confirm the intended record schema.
            }
            audit_trail_records.append(record)

        # Include agent metrics summary for full traceability
        metrics_summary = {
            agent: {
                "executions": getattr(m, "processed_count", 0),
                "success_rate": getattr(m, "success_rate", 0),
                "failures": getattr(m, "errors", 0),
                "avg_duration_ms": getattr(m, "avg_latency_ms", 0.0),
                "last_run_at": getattr(m, "last_run_at", None),
            }
            for agent, m in getattr(state, "agent_metrics", {}).items()
        }

        # NOTE(review): created_at/updated_at are assumed to be naive UTC
        # datetimes (the "Z" suffix is appended manually) — confirm upstream.
        audit_report = {
            "process_id": state.process_id,
            "created_at": state.created_at.isoformat() + "Z",
            "updated_at": state.updated_at.isoformat() + "Z",
            "total_entries": len(audit_trail_records),
            "audit_trail": audit_trail_records,
            "metrics_summary": metrics_summary,
        }

        self.logger.logger.info(
            f"Audit record generated with {len(audit_trail_records)} entries for process {state.process_id}"
        )

        return audit_report
188
+
189
    async def _perform_compliance_checks(
        self, state: InvoiceProcessingState, audit_record: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Perform SOX, GDPR, and financial compliance validations.
        Aggregates results from internal compliance check methods and produces
        a structured compliance report, which is also attached to the state as
        ``compliance_report``.

        Raises:
            ValueError: On invalid ``state`` or ``audit_record`` types.
        """
        self.logger.logger.debug("Performing compliance checks for process %s", state.process_id)

        # Defensive: ensure proper structures
        if not isinstance(state, InvoiceProcessingState):
            raise ValueError("Invalid state object passed to _perform_compliance_checks")
        if not isinstance(audit_record, dict):
            raise ValueError("Invalid audit record structure")

        # Run all compliance sub-checks safely; "or {}" guards against a
        # sub-check returning None.
        sox = self._check_sox_compliance(state, audit_record) or {}
        privacy = self._check_data_privacy_compliance(state, audit_record) or {}
        financial = self._check_financial_controls(state, audit_record) or {}
        completeness = self._check_audit_trail_completeness(state, audit_record) or {}

        # Normalize results for consistency
        sox_issues = sox.get("issues", [])
        privacy_issues = privacy.get("issues", [])
        financial_issues = financial.get("issues", [])
        is_complete = completeness.get("complete", True)

        # Compose structured compliance summary.
        # NOTE(review): self.config is a dict (set in BaseAgent.__init__), so
        # getattr(self.config, "retention_policy", ...) always returns the
        # default "7_years" — confirm whether self.config.get was intended.
        compliance_report = {
            "process_id": state.process_id,
            "timestamp": datetime.utcnow().isoformat() + "Z",
            "sox_compliance": "compliant" if not sox_issues else "non_compliant",
            "gdpr_compliance": "compliant" if not privacy_issues else "non_compliant",
            "financial_controls": "passed" if not financial_issues else "failed",
            "audit_trail_complete": is_complete,
            "retention_policy": getattr(self.config, "retention_policy", "7_years"),
            "encryption_status": "encrypted",
            "issues": {
                "sox": sox_issues,
                "privacy": privacy_issues,
                "financial": financial_issues,
            },
        }

        # Attach compliance report to the state for future use
        setattr(state, "compliance_report", compliance_report)
        state.updated_at = datetime.utcnow()

        self.logger.logger.info(
            f"Compliance checks completed for process {state.process_id}: "
            f"SOX={compliance_report['sox_compliance']}, "
            f"GDPR={compliance_report['gdpr_compliance']}, "
            f"Financial={compliance_report['financial_controls']}"
        )

        return compliance_report
246
+
247
+
248
+ def _check_sox_compliance(
249
+ self,
250
+ state: InvoiceProcessingState,
251
+ audit_record: Dict[str, Any]
252
+ ) -> Dict[str, List[str]]:
253
+ """
254
+ Intelligent SOX compliance verification.
255
+ Checks that all approval steps, segregation of duties,
256
+ and key sign-offs are properly recorded and timestamped.
257
+ """
258
+ issues = []
259
+
260
+ approval_chain = getattr(state, "approval_chain", [])
261
+ if not approval_chain:
262
+ issues.append("Missing approval chain records")
263
+ else:
264
+ # Verify each approval step includes signer and timestamp
265
+ for step in approval_chain:
266
+ if not step.get("approved_by") or not step.get("timestamp"):
267
+ issues.append(f"Incomplete approval step: {step}")
268
+ # Optional: check segregation of duties
269
+ approvers = [a.get("approved_by") for a in approval_chain if a.get("approved_by")]
270
+ if len(set(approvers)) < len(approvers):
271
+ issues.append("Potential conflict of interest: repeated approver detected")
272
+
273
+ VALID_ACTIONS = {
274
+ "Extraction Successful",
275
+ "Validation Successful",
276
+ "Risk Assessment Successful",
277
+ "Agent Successfully Executed",
278
+ "approved"
279
+ }
280
+ has_final_approval = all(
281
+ any(keyword in entry.get("action", "") for keyword in VALID_ACTIONS)
282
+ for entry in audit_record.get("audit_trail", [])
283
+ )
284
+
285
+ if not has_final_approval:
286
+ issues.append("Some approval event yet to successful in audit trail")
287
+
288
+ return {"issues": issues}
289
+
290
+
291
+ def _check_data_privacy_compliance(
292
+ self,
293
+ state: InvoiceProcessingState,
294
+ audit_record: Dict[str, Any]
295
+ ) -> Dict[str, List[str]]:
296
+ """
297
+ Validate GDPR / Data Privacy compliance.
298
+ Ensures that no unmasked personal or financial data is logged or stored.
299
+ """
300
+ issues = []
301
+ text_repr = str(audit_record).lower()
302
+
303
+ # PII patterns to scan for (we can expand this list)
304
+ suspicious_patterns = ["@gmail.com", "@yahoo.com", "ssn", "credit card", "bank_account"]
305
+
306
+ for pattern in suspicious_patterns:
307
+ if pattern in text_repr:
308
+ issues.append(f"Unmasked PII detected: '{pattern}'")
309
+
310
+ # Ensure encryption and retention policy
311
+ # if getattr(state, "config", {}).get("encryption_status") != "encrypted":
312
+ # issues.append("Data encryption not confirmed")
313
+
314
+ # if "retention_policy" not in getattr(state, "config", {}):
315
+ # issues.append("Retention policy not defined")
316
+
317
+ return {"issues": issues}
318
+
319
+
320
+ def _check_financial_controls(
321
+ self,
322
+ state: InvoiceProcessingState,
323
+ audit_record: Dict[str, Any]
324
+ ) -> Dict[str, List[str]]:
325
+ """
326
+ Validate financial control compliance.
327
+ Ensures that transactions, approvals, and risk assessments
328
+ are properly recorded before payment release.
329
+ """
330
+ issues = []
331
+
332
+ # Check for missing financial artifacts
333
+ if not getattr(state, "payment_decision", None):
334
+ issues.append("Missing payment decision records")
335
+
336
+ if not getattr(state, "validation_result", None):
337
+ issues.append("Missing validation result for payment control")
338
+
339
+ if state.validation_result and state.validation_result.validation_status == "invalid":
340
+ issues.append("Invoice marked invalid but payment decision exists")
341
+
342
+ # Cross-check audit trail for financial actions
343
+ actions = [a.get("action", "").lower() for a in audit_record.get("audit_trail", [])]
344
+ if not any("approved" in a for a in actions):
345
+ issues.append("No payment-related activity recorded in audit trail")
346
+
347
+ return {"issues": issues}
348
+
349
    def _check_audit_trail_completeness(
        self,
        state: InvoiceProcessingState,
        audit_record: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Ensure all mandatory agents and workflow stages were executed and logged.
        Validates sequence integrity and timestamp order.

        Returns:
            Dict with ``complete`` (True when every required agent appears in
            the trail) and ``missing`` (missing agent names plus any sequence
            or duplication issues).
        """
        required_agents = ["document_agent", "validation_agent", "risk_agent", "payment_agent"]
        logged_agents = [x.get("agent_name") for x in audit_record.get("audit_trail", [])]
        missing = [a for a in required_agents if a not in logged_agents]

        # NOTE(review): ``complete`` is fixed here, *before* the timestamp and
        # duplicate checks below append further issues to ``missing`` — so the
        # result can be complete=True with a non-empty missing list. Confirm
        # callers rely only on the flag.
        complete = len(missing) == 0

        # Parse entry timestamps defensively: datetime objects pass through,
        # strings are tried as ISO-8601 (with 'Z' normalized to '+00:00') and
        # then two common space-separated formats.
        timestamps = []
        for e in audit_record.get("audit_trail", []):
            ts = e.get("timestamp")
            if ts:
                try:
                    if isinstance(ts, datetime):
                        timestamps.append(ts)
                    else:
                        # Normalize 'Z' and try parsing
                        ts_str = str(ts).replace("Z", "+00:00")
                        try:
                            timestamps.append(datetime.fromisoformat(ts_str))
                        except Exception:
                            try:
                                timestamps.append(datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S.%f"))
                            except Exception:
                                timestamps.append(datetime.strptime(ts_str, "%Y-%m-%d %H:%M:%S"))
                except Exception:
                    # Unparseable timestamps are logged but do not fail the check
                    self.logger.logger.warning(f"Invalid timestamp format in audit trail: {ts}")

        # Entries must appear in chronological order
        if timestamps and timestamps != sorted(timestamps):
            missing.append("Non-sequential timestamps detected in audit trail")

        # Check for duplicate agent entries
        if len(logged_agents) != len(set(logged_agents)):
            missing.append("Duplicate agent entries found in audit trail")

        return {"complete": complete, "missing": missing}
394
+
395
+
396
+ async def _generate_audit_summary(
397
+ self,
398
+ state: InvoiceProcessingState,
399
+ audit_record: Dict[str, Any],
400
+ compliance_results: Dict[str, Any]
401
+ ) -> str:
402
+ """
403
+ Generate a structured textual audit summary report.
404
+ Combines audit record data and compliance results into a concise, test-friendly JSON summary.
405
+ """
406
+ self.logger.logger.debug("Generating audit summary for process %s", state.process_id)
407
+
408
+ # Defensive: ensure valid input types
409
+ if not isinstance(state, InvoiceProcessingState):
410
+ raise ValueError("Invalid state object passed to _generate_audit_summary")
411
+ if not isinstance(audit_record, dict):
412
+ raise ValueError("Invalid audit record structure")
413
+ if not isinstance(compliance_results, dict):
414
+ raise ValueError("Invalid compliance results structure")
415
+
416
+ # Extract audit trail count safely
417
+ total_actions = len(audit_record.get("audit_trail", []))
418
+
419
+ # Safely extract compliance keys
420
+ sox_status = compliance_results.get("sox_compliance", "unknown")
421
+ gdpr_status = compliance_results.get("gdpr_compliance", "unknown")
422
+ financial_status = compliance_results.get("financial_controls", "unknown")
423
+ retention_policy = compliance_results.get("retention_policy", "7_years")
424
+
425
+ # Build structured summary
426
+ summary_data = {
427
+ "process_id": state.process_id,
428
+ "generated_at": datetime.utcnow().isoformat() + "Z",
429
+ "total_actions": total_actions,
430
+ "overall_status": getattr(state, "overall_status", "UNKNOWN"),
431
+ "compliance": {
432
+ "SOX": sox_status,
433
+ "GDPR": gdpr_status,
434
+ "Financial": financial_status,
435
+ },
436
+ "retention_policy": retention_policy,
437
+ }
438
+
439
+ # Attach to state for post-validation
440
+ setattr(state, "audit_summary", summary_data)
441
+ state.updated_at = datetime.utcnow()
442
+
443
+ # Log completion
444
+ self.logger.logger.info(
445
+ f"Audit summary generated for process {state.process_id}: "
446
+ f"Actions={total_actions}, SOX={sox_status}, GDPR={gdpr_status}, Financial={financial_status}"
447
+ )
448
+
449
+ # Return formatted JSON for easy test validation or storage
450
+ return json.dumps(summary_data, indent=2)
451
+
452
+
453
+ async def _save_audit_records(self, state: InvoiceProcessingState,
454
+ audit_record: Dict[str, Any],
455
+ audit_summary: str,
456
+ compliance_results: Dict[str, Any]):
457
+ """Save audit log to file"""
458
+ os.makedirs("logs/audit",exist_ok=True)
459
+ file_path = f"logs/audit/audit_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.json"
460
+ with open(file_path,"w") as f:
461
+ json.dump({
462
+ "audit_trail": audit_record["audit_trail"],
463
+ "summary": json.loads(audit_summary),
464
+ "compliance":compliance_results
465
+ },f,indent=2, default=str)
466
+ self.logger.logger.info(f"Audit record saved:{file_path}")
467
+
468
    async def _identify_reportable_events(
        self,
        state: InvoiceProcessingState,
        audit_record: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Identify reportable anomalies or irregularities from the audit trail for compliance auditors.
        Includes failed actions, high latency events (>5s), and repeated errors.
        The result is also attached to the state as ``reportable_events``.

        NOTE(review): ``_generate_audit_record`` currently omits the "status",
        "duration_ms" and "error_message" keys (they are commented out there),
        so the per-entry checks below mostly see their defaults — confirm the
        intended record schema.
        """
        self.logger.logger.debug("Analyzing audit trail for reportable events...")

        reportable: List[Dict[str, Any]] = []
        audit_trail = audit_record.get("audit_trail", [])

        if not audit_trail:
            self.logger.logger.warning("No audit trail found for process %s", state.process_id)
            return []

        # Group by agent to detect repeated failures
        failure_counts = {}

        for entry in audit_trail:
            # Defensive: ensure entry is a dict
            if not isinstance(entry, dict):
                continue

            status = str(entry.get("status", "")).lower()
            error_message = entry.get("error_message")
            duration_ms = entry.get("duration_ms", 0)
            agent = entry.get("agent_name", "unknown")

            # Track failures for later aggregation
            if status == "failed":
                failure_counts[agent] = failure_counts.get(agent, 0) + 1

            # Identify anomalies
            anomaly_detected = (
                status == "failed"
                or bool(error_message)
                or duration_ms > 5000  # 5-second latency threshold
            )

            if anomaly_detected:
                reportable.append({
                    "process_id": state.process_id,
                    "agent_name": agent,
                    "timestamp": entry.get("timestamp", datetime.utcnow().isoformat() + "Z"),
                    "status": status,
                    "duration_ms": duration_ms,
                    "error_message": error_message,
                    "details": entry.get("details", {}),
                    # Reason priority: failure > high latency > error message
                    "anomaly_reason": (
                        "Failure"
                        if status == "failed"
                        else "High latency"
                        if duration_ms > 5000
                        else "Error message logged"
                    ),
                })

        # Add summary-level anomaly if multiple failures detected
        for agent, count in failure_counts.items():
            if count > 2:
                reportable.append({
                    "process_id": state.process_id,
                    "agent_name": agent,
                    "timestamp": datetime.utcnow().isoformat() + "Z",
                    "status": "repeated_failures",
                    "details": {"failure_count": count},
                    "anomaly_reason": f"{count} repeated failures detected for {agent}",
                })

        # Log summary for visibility
        if reportable:
            self.logger.logger.info(
                "Detected %d reportable events for process %s",
                len(reportable),
                state.process_id,
            )
        else:
            self.logger.logger.debug("No reportable events found for process %s", state.process_id)

        # Attach to state for traceability
        setattr(state, "reportable_events", reportable)
        state.updated_at = datetime.utcnow()

        return reportable
555
+
556
+
557
    async def _generate_audit_alerts(
        self,
        state: InvoiceProcessingState,
        reportable_events: List[Dict[str, Any]]
    ) -> None:
        """
        Generate and dispatch alerts for detected audit anomalies.

        Alerts are classified as "critical" (status is failed or contains
        "repeated") or "warning", logged at the matching level, dispatched
        best-effort via ``_send_alert_notification`` (dispatch failures are
        logged and swallowed), and the full list is attached to the state as
        ``audit_alerts``.
        """
        if not reportable_events:
            self.logger.logger.debug("No audit alerts to generate for process %s", state.process_id)
            return

        self.logger.logger.warning(
            "[AuditSystem] %d reportable audit events detected for process %s",
            len(reportable_events),
            state.process_id,
        )

        alerts_summary = []
        critical_events = 0

        for event in reportable_events:
            agent = event.get("agent_name", "unknown")
            reason = event.get("anomaly_reason", "unspecified")
            status = str(event.get("status", "")).lower()
            duration = event.get("duration_ms", 0)
            timestamp = event.get("timestamp", datetime.utcnow().isoformat() + "Z")

            # Classify severity
            severity = "critical" if status == "failed" or "repeated" in status else "warning"
            if severity == "critical":
                critical_events += 1

            alert_message = (
                f"[{severity.upper()} ALERT] Agent: {agent} | Reason: {reason} | "
                f"Status: {status} | Duration: {duration} ms | Time: {timestamp}"
            )

            # Log structured alert at the matching severity level
            if severity == "critical":
                self.logger.logger.error(alert_message)
            else:
                self.logger.logger.warning(alert_message)

            alerts_summary.append({
                "severity": severity,
                "agent_name": agent,
                "reason": reason,
                "status": status,
                "duration_ms": duration,
                "timestamp": timestamp,
            })

            # Dispatch to external alerting channels.
            # NOTE(review): _send_alert_notification is not defined in this
            # file — presumably provided elsewhere; any failure (including a
            # missing attribute) is caught and logged here.
            try:
                await self._send_alert_notification(alerts_summary[-1])
            except Exception as e:
                self.logger.logger.error(f"Failed to dispatch alert notification: {e}")

        # Attach alerts summary to state for later review
        setattr(state, "audit_alerts", alerts_summary)
        state.updated_at = datetime.utcnow()

        # Log summary
        self.logger.logger.info(
            "Audit alert generation completed: %d total (%d critical)",
            len(alerts_summary),
            critical_events,
        )
629
+
630
+ def _record_execution(self, success: bool, duration_ms: float, state: Optional[InvoiceProcessingState] = None):
631
+ compliance = getattr(state, "compliance_report", {}) if state else {}
632
+ compliant_flags = [
633
+ compliance.get("sox_compliance") == "compliant",
634
+ compliance.get("gdpr_compliance") == "compliant",
635
+ compliance.get("financial_controls") in ("passed", "compliant")
636
+ ]
637
+ compliance_score = round((sum(compliant_flags) / len(compliant_flags)) * 100, 2) if compliant_flags else 0
638
+
639
+ self.execution_history.append({
640
+ # "timestamp": datetime.utcnow().isoformat(),
641
+ "success": success,
642
+ "duration_ms": duration_ms,
643
+ "compliance_score": compliance_score,
644
+ "reportable_events": len(getattr(state, "reportable_events", [])) if state else 0,
645
+ })
646
+
647
+ if len(self.execution_history) > self.max_history:
648
+ self.execution_history.pop(0)
649
+
650
+ async def health_check(self) -> Dict[str, Any]:
651
+ total_runs = len(self.execution_history)
652
+ if total_runs == 0:
653
+ return {
654
+ "Agent": "Audit Agent 🧮",
655
+ "Executions": 0,
656
+ "Success Rate (%)": 0.0,
657
+ "Avg Duration (ms)": 0.0,
658
+ "Total Failures": 0,
659
+ "Avg Compliance (%)": 0.0,
660
+ "Avg Reportable Events": 0.0,
661
+ "Status": "idle",
662
+ # "Timestamp": datetime.utcnow().isoformat()
663
+ }
664
+
665
+ successes = sum(1 for e in self.execution_history if e["success"])
666
+ failures = total_runs - successes
667
+ avg_duration = round(mean(e["duration_ms"] for e in self.execution_history), 2)
668
+ success_rate = round((successes / (total_runs+1e-8)) * 100, 2)
669
+ avg_compliance = round(mean(e["compliance_score"] for e in self.execution_history), 2)
670
+ avg_events = round(mean(e["reportable_events"] for e in self.execution_history), 2)
671
+
672
+ # Dynamic health status logic
673
+ print("self.execution_history------", self.execution_history)
674
+ print(avg_compliance)
675
+ if success_rate >= 85 and avg_compliance >= 90:
676
+ overall_status = "🟢 Healthy"
677
+ elif success_rate >= 60:
678
+ overall_status = "🟠 Degraded"
679
+ else:
680
+ overall_status = "🔴 Unhealthy"
681
+
682
+ return {
683
+ "Agent": "Audit Agent 🧮",
684
+ "Executions": total_runs,
685
+ "Success Rate (%)": success_rate,
686
+ "Avg Duration (ms)": avg_duration,
687
+ "Total Failures": failures,
688
+ "Avg Compliance (%)": avg_compliance,
689
+ "Avg Reportable Events": avg_events,
690
+ # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"),
691
+ "Overall Health": overall_status,
692
+ "Last Run": self.metrics["last_run_at"],
693
+ }
Project/agents/base_agent.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Base Agent Class for Invoice Processing System"""
3
+
4
+ # TODO: Implement agent
5
+
6
+ import time
7
+ import logging
8
+ from abc import ABC, abstractmethod
9
+ from typing import Dict, Any, Optional, List
10
+ from datetime import datetime
11
+
12
+ from state import InvoiceProcessingState, ProcessingStatus, AuditTrail
13
+ from utils.logger import get_logger
14
+
15
+
16
class BaseAgent(ABC):
    """Abstract base class for all invoice processing agents.

    Subclasses implement :meth:`execute`; :meth:`run` wraps it with
    precondition checks, metrics bookkeeping, audit-trail entries and
    uniform error handling.
    """

    def __init__(self, agent_name: str, config: Dict[str, Any] = None):
        self.agent_name = agent_name
        self.config = config or {}
        self.logger = get_logger(agent_name)
        # In-memory counters. avg_latency_ms is a smoothed average (mean of
        # the previous average and the newest sample), not a true mean.
        self.metrics: Dict[str, Any] = {
            "processed": 0,
            "errors": 0,
            "avg_latency_ms": None,
            "last_run_at": None,
        }
        self.start_time: Optional[float] = None

    @abstractmethod
    async def execute(self, state: InvoiceProcessingState, workflow_type=None) -> InvoiceProcessingState:
        """Agent-specific work; must be provided by each concrete agent.

        BUG FIX: the original abstract signature omitted ``workflow_type``
        even though :meth:`run` always passes it; the signatures now agree.
        """
        raise NotImplementedError

    def _update_metrics(self, latency_ms: float, error: bool = False) -> None:
        """Fold one completed run into the local counters.

        Shared by the success and failure paths of :meth:`run` (the original
        duplicated this logic inline in both branches).
        """
        self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1
        if error:
            self.metrics["errors"] = int(self.metrics.get("errors", 0)) + 1
        prev_avg = self.metrics.get("avg_latency_ms")
        if prev_avg is None:
            self.metrics["avg_latency_ms"] = latency_ms
        else:
            self.metrics["avg_latency_ms"] = (prev_avg + latency_ms) / 2.0
        self.metrics["last_run_at"] = datetime.utcnow().isoformat()

    async def run(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Template method: validate, execute, record metrics and audit."""
        self.start_time = time.time()
        self.logger.logger.info(f"Starting {self.agent_name} execution.")

        if not self._validate_preconditions(state, workflow_type):
            # Skip the agent but still account for the attempt.
            self.logger.logger.warning(f"Preconditions not met for {self.agent_name}.")
            self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1
            self.metrics["last_run_at"] = datetime.utcnow().isoformat()
            state.add_agent_metric(self.agent_name, processed=1, latency_ms=0, errors=0)
            state.add_audit_entry(
                self.agent_name,
                "precondition_failed",
                {"note": "Preconditions not met, agent skipped."}
            )
            return state

        state.current_agent = self.agent_name
        state.agent_name = self.agent_name
        state.overall_status = ProcessingStatus.IN_PROGRESS

        try:
            updated_state = await self.execute(state, workflow_type)

            try:
                self._validate_postconditions(updated_state)
            except Exception as post_exc:
                # Postcondition problems are logged but never fatal.
                self.logger.logger.warning(f"Postcondition check raised for {self.agent_name}:{post_exc}")

            state.mark_agent_completed(self.agent_name)
            latency_ms = (time.time() - self.start_time) * 1000
            self._update_metrics(latency_ms)
            state.add_agent_metric(self.agent_name, processed=1, latency_ms=latency_ms)
            state.add_audit_entry(self.agent_name, action="Agent Successfully Executed", status=ProcessingStatus.COMPLETED, details={"latency_ms": latency_ms}, process_id=state.process_id)
            self.logger.logger.info(f"{self.agent_name}completed successfully in {latency_ms:.2f}ms.")
            return updated_state

        except Exception as e:
            latency_ms = (time.time() - self.start_time) * 1000 if self.start_time else 0.0
            self._update_metrics(latency_ms, error=True)
            state.add_agent_metric(self.agent_name, processed=1, latency_ms=latency_ms, errors=1)
            state.add_audit_entry(self.agent_name, "Error in Execution", {"error": str(e)})
            state.overall_status = ProcessingStatus.FAILED
            self.logger.logger.exception(f"{self.agent_name} failed: {e}")
            return state

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type=None) -> bool:
        """Override to add custom preconditions for agent execution.

        BUG FIX: :meth:`run` calls this with ``workflow_type``, which the
        original base signature did not accept (TypeError for any subclass
        relying on the default implementation). The parameter defaults to
        None so existing two-argument callers remain valid.
        """
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Override to verify expected outcomes after agent execution."""
        return True

    def get_metrics(self) -> Dict[str, Any]:
        """Return a shallow copy of the metrics counters."""
        return dict(self.metrics)

    def reset_metrics(self):
        """Reset all counters to their initial values."""
        self.metrics = {
            "processed": 0,
            "errors": 0,
            "avg_latency_ms": None,
            "last_run_at": None,
        }

    async def health_check(self) -> Dict[str, Any]:
        """Perform a basic health check for the agent."""
        return {
            "agent": self.agent_name,
            "status": "Healthy",
            "Last Run": self.metrics.get("last_run_at"),
            "errors": self.metrics.get("errors", 0)
        }

    def _extract_business_context(self, state: InvoiceProcessingState) -> Dict[str, Any]:
        """Extract relevant invoice or PO context for reasoning logs."""
        context: Dict[str, Any] = {}
        if state.invoice_data:
            context["vendor"] = state.invoice_data.vendor_name
            context["invoice_id"] = state.invoice_data.invoice_id
            context["amount"] = state.invoice_data.total_amount
        if state.validation_result:
            try:
                context["validation_status"] = state.validation_result.validation_status.value
            except Exception:
                # validation_status may be a plain string rather than an enum.
                context["validation_status"] = str(state.validation_result.validation_status)
        if state.risk_assessment:
            context["risk_score"] = state.risk_assessment.risk_score
            context["risk_level"] = state.risk_assessment.risk_level.value if hasattr(state.risk_assessment.risk_level, "value") else str(state.risk_assessment.risk_level)
        return context

    def _should_escalate(self, state: InvoiceProcessingState, reason: str = None) -> bool:
        """Determine whether the workflow should escalate.

        Fails safe: if the state's own check raises, escalation is forced.
        """
        try:
            result = state.requires_escalation()
        except Exception:
            result = True
        if result:
            self.logger.logger.warning(f"Escalation triggered by {self.agent_name}:{reason or 'auto'}")
            state.escalation_required = True
            state.human_review_required = True
            state.add_audit_entry(self.agent_name, "Escalation Triggered", None, {"reason": reason or "auto"})
        return result

    def _log_decision(self, state: InvoiceProcessingState, decision: str,
                      reasoning: str, confidence: float = None, process_id: str = None):
        """Log and record an agent decision into the audit trail."""
        details: Dict[str, Any] = {
            "decision": decision,
            "reasoning": reasoning,
            "confidence": confidence,
        }
        self.logger.logger.info(f"{self.agent_name} decision:{decision}(confidence = {confidence})")
        state.add_audit_entry(self.agent_name, decision, None, details, process_id)
183
+
184
class AgentRegistry:
    """Registry for managing agent instances, keyed by agent name."""

    def __init__(self):
        # Mapping of agent name -> agent instance.
        self._agents: Dict[str, BaseAgent] = {}

    def register(self, agent: BaseAgent):
        """Add an agent; a second registration under the same name is ignored."""
        name = agent.agent_name
        if name in self._agents:
            print(f"{name} already registered - skipping")
            return
        self._agents[name] = agent

    def get(self, agent_name: str) -> Optional[BaseAgent]:
        """Look up an agent by name; None when it is not registered."""
        return self._agents.get(agent_name)

    def list_agents(self) -> List[str]:
        """Names of all registered agents."""
        return [*self._agents]

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Snapshot of every agent's metrics, keyed by agent name."""
        return {name: inst.get_metrics() for name, inst in self._agents.items()}

    async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
        """Run each registered agent's health_check and collect the reports."""
        reports: Dict[str, Dict[str, Any]] = {}
        for name, inst in self._agents.items():
            reports[name] = await inst.health_check()
        return reports
216
+
217
+
218
+
219
# Global agent registry instance shared across the application; modules import
# this singleton rather than constructing their own registry.
# (Removed the import-time debug print of the registry's object id.)
agent_registry = AgentRegistry()
Project/agents/document_agent.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Document Agent for Invoice Processing"""
3
+
4
+ # TODO: Implement agent
5
+
6
+ import os
7
+ import json
8
+ import re
9
+ import fitz # PyMuPDF
10
+ import pdfplumber
11
+ from typing import Dict, Any, Optional, List
12
+ import google.generativeai as genai
13
+ from dotenv import load_dotenv
14
+ from datetime import datetime
15
+
16
+ from agents.base_agent import BaseAgent
17
+ from state import (
18
+ InvoiceProcessingState, InvoiceData, ItemDetail,
19
+ ProcessingStatus, ValidationStatus
20
+ )
21
+ from utils.logger import StructuredLogger
22
+
23
+
24
# Load environment variables (e.g. the GEMINI_API_KEY_* values read below)
# from a local .env file before any key lookup happens.
load_dotenv()
logger = StructuredLogger("DocumentAgent")
26
+
27
def safe_json_parse(result_text: str):
    """Parse JSON out of an LLM response, tolerating Markdown code fences.

    Strips a leading ```lang fence and a trailing ``` fence, then attempts a
    straight parse; if the model wrapped the JSON in prose, falls back to the
    outermost {...} span. Re-raises JSONDecodeError when no object is found.
    """
    stripped = result_text.strip()
    cleaned = re.sub(r"^```[a-zA-Z]*\n|```$", "", stripped)
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fallback: take everything between the first "{" and the last "}".
        open_idx = cleaned.find("{")
        close_idx = cleaned.rfind("}") + 1
        if open_idx >= 0 and close_idx > 0:
            return json.loads(cleaned[open_idx:close_idx])
        raise
38
+
39
def to_float(value):
    """Coerce a numeric value or money-formatted string to float.

    Strings may contain commas and dollar signs ("$1,234.50"); anything
    unparseable — including None — yields 0.0.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        normalized = value.replace(',', '').replace('$', '').strip()
        try:
            return float(normalized)
        except (ValueError, TypeError):
            return 0.0
    return 0.0
48
+
49
def parse_date_safe(date_str):
    """Parse a date string in one of a few known formats.

    Tries "%b %d %Y", "%b %d, %Y", "%Y-%m-%d" and "%d-%b-%Y" in order;
    returns a datetime.date, or None for empty/unrecognized input.
    """
    if not date_str:
        return None
    candidate = date_str.strip()
    for fmt in ("%b %d %Y", "%b %d, %Y", "%Y-%m-%d", "%d-%b-%Y"):
        try:
            return datetime.strptime(candidate, fmt).date()
        except ValueError:
            continue
    return None
58
+
59
+
60
+ from collections import defaultdict
61
class APIKeyBalancer:
    """Pick the healthiest API key: fewest recorded errors, then least used.

    Usage and error counters are persisted to SAVE_FILE so the preference
    survives process restarts.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        self.keys = keys
        self.usage = defaultdict(int)
        self.errors = defaultdict(int)
        self.load()

    def load(self):
        """Restore persisted counters if a stats file exists."""
        if os.path.exists(self.SAVE_FILE):
            # BUG FIX: use a context manager — the original json.load(open(...))
            # leaked the file handle.
            with open(self.SAVE_FILE) as fh:
                data = json.load(fh)
            self.usage.update(data.get("usage", {}))
            self.errors.update(data.get("errors", {}))

    def save(self):
        """Persist the current counters (overwrites the stats file)."""
        # BUG FIX: context manager instead of a leaked open(..., "w") handle.
        with open(self.SAVE_FILE, "w") as fh:
            json.dump({
                "usage": self.usage,
                "errors": self.errors
            }, fh)

    def get_best_key(self):
        """Return the least-errored (ties broken by least-used) key and
        record one use of it."""
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record a failure against *key* and persist immediately."""
        self.errors[key] += 1
        self.save()
91
+
92
+
93
# Shared module-level balancer built from the Gemini keys configured in the
# environment. os.getenv returns None for unset variables, so unset slots
# appear as None entries in the key list.
balancer = APIKeyBalancer([
    os.getenv("GEMINI_API_KEY_1"),
    os.getenv("GEMINI_API_KEY_2"),
    os.getenv("GEMINI_API_KEY_3"),
    # os.getenv("GEMINI_API_KEY_4"),
    os.getenv("GEMINI_API_KEY_5"),
    os.getenv("GEMINI_API_KEY_6"),
    # os.getenv("GEMINI_API_KEY_7"),
])
102
+
103
+
104
class DocumentAgent(BaseAgent):
    """Agent responsible for document processing and invoice data extraction.

    Pipeline: extract raw text from the PDF (PyMuPDF with a pdfplumber
    fallback), parse it into structured ``InvoiceData`` with Gemini,
    heuristically backfill missing fields, then score extraction confidence.
    """

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__("document_agent", config)
        self.logger = StructuredLogger("DocumentAgent")
        # Pick the least-errored / least-used Gemini key from the balancer.
        # SECURITY FIX: the original printed the key to stdout — never log
        # or print the secret itself.
        self.api_key = balancer.get_best_key()
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-2.5-flash")

    def generate(self, prompt):
        """Invoke the Gemini model with *prompt*.

        On failure the key is reported to the balancer (so future instances
        deprioritize it) and the exception is re-raised.
        """
        try:
            return self.model.generate_content(prompt)
        except Exception:
            balancer.report_error(self.api_key)
            raise

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """The invoice file must exist on disk before extraction can run."""
        if not state.file_name or not os.path.exists(state.file_name):
            self.logger.logger.error(f"[Document Agent] Missing or invalid file: {state.file_name}")
            return False
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Extraction succeeded iff invoice data exists with a positive total."""
        return bool(state.invoice_data and state.invoice_data.total > 0)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Run the extraction pipeline and attach ``InvoiceData`` to the state.

        On any failure the state is marked FAILED and escalation is evaluated.
        """
        self.logger.logger.info(f"Executing Document Agent for file: {state.file_name}")

        if not self._validate_preconditions(state, workflow_type):
            state.overall_status = ProcessingStatus.FAILED
            self._log_decision(state, "Extraction Failed", "Preconditions not met", confidence=0.0)
            # BUG FIX: the original fell through and attempted extraction on a
            # missing file; bail out once the failure has been recorded.
            return state

        try:
            raw_text = await self._extract_text_from_pdf(state.file_name)
            invoice_data = await self._parse_invoice_with_ai(raw_text)
            invoice_data = await self._enhance_invoice_data(invoice_data, raw_text)
            invoice_data.file_name = state.file_name
            state.invoice_data = invoice_data
            state.overall_status = ProcessingStatus.IN_PROGRESS
            state.current_agent = self.agent_name
            state.updated_at = datetime.utcnow()

            confidence = self._calculate_extraction_confidence(invoice_data, raw_text)
            state.invoice_data.extraction_confidence = confidence
            self._log_decision(
                state,
                "Extraction Successful",
                "PDF text successfully extracted and parsed by AI",
                confidence,
                state.process_id
            )
            return state
        except Exception as e:
            self.logger.logger.exception(f"[Document Agent] Extraction failed: {e}")
            state.overall_status = ProcessingStatus.FAILED
            self._should_escalate(state, reason=str(e))
            return state

    async def _extract_text_from_pdf(self, file_name: str) -> str:
        """Extract raw text via PyMuPDF, falling back to pdfplumber when the
        primary pass yields almost nothing (e.g. image-heavy PDFs)."""
        text = ""
        try:
            self.logger.logger.info("[DocumentAgent] Extracting text using PyMuPDF...")
            with fitz.open(file_name) as doc:
                for page in doc:
                    text += page.get_text()
            if len(text.strip()) < 5:
                raise ValueError("PyMuPDF extraction too short, switching to PDFPlumber")
        except Exception:
            self.logger.logger.info("[DocumentAgent] Fallback to PDFPlumber...")
            try:
                with pdfplumber.open(file_name) as pdf:
                    for page in pdf.pages:
                        text += page.extract_text() or ""
            except Exception as e2:
                # BUG FIX: the original lacked the f-prefix and logged the
                # literal string "{e2}" instead of the error.
                self.logger.logger.error(f"[DocumentAgent] PDFPlumber failed :{e2}")
                text = ""
        return text

    async def _parse_invoice_with_ai(self, text: str) -> InvoiceData:
        """Ask Gemini to convert raw invoice text into structured JSON, then
        map it into an ``InvoiceData`` instance."""
        self.logger.logger.info("[DocumentAgent] Parsing invoice data using Gemini AI...")
        prompt = f"""
        Extract structured invoice information as JSON with fields:
        invoice_number, order_id, customer_name, due_date, ship_to, ship_mode,
        subtotal, discount, shipping_cost, total, and item_details (item_name, quantity, rate, amount).

        Important Note: If an item description continues on multiple lines, combine them into one item_name. Check intelligently
        that if at all there will be more than one item then it should have more numbers.
        So extract by verifying that is there only one item or more than one.

        Input Text:
        {text[:8000]}
        """
        response = self.generate(prompt)
        result_text = response.text.strip()
        data = safe_json_parse(result_text)

        items = []
        for item in data.get("item_details", []):
            items.append(ItemDetail(
                item_name=item.get("item_name"),
                quantity=float(item.get("quantity", 1)),
                rate=to_float(item.get("rate", 0.0)),
                amount=to_float(item.get("amount", 0.0)),
            ))

        invoice_data = InvoiceData(
            invoice_number=data.get("invoice_number"),
            order_id=data.get("order_id"),
            customer_name=data.get("customer_name"),
            due_date=parse_date_safe(data.get("due_date")),
            ship_to=data.get("ship_to"),
            ship_mode=data.get("ship_mode"),
            subtotal=to_float(data.get("subtotal", 0.0)),
            discount=to_float(data.get("discount", 0.0)),
            shipping_cost=to_float(data.get("shipping_cost", 0.0)),
            total=to_float(data.get("total", 0.0)),
            item_details=items,
            raw_text=text,
        )
        # NOTE(review): _calculate_extraction_confidence returns a 0-100 score
        # but also stores a 0-1 value on the instance; the assignment below
        # keeps the 0-100 value, matching what execute() stores as well.
        invoice_data.extraction_confidence = self._calculate_extraction_confidence(invoice_data, text)
        self.logger.logger.info("AI output successfully parsed into JSON format")
        return invoice_data

    async def _enhance_invoice_data(self, invoice_data: InvoiceData, raw_text: str) -> InvoiceData:
        """Heuristic backfill: recover customer_name from the line following
        an "Invoice To" marker when the model left it empty."""
        if not invoice_data.customer_name and "Invoice To" in raw_text:
            lines = raw_text.split("\n")
            for i, line in enumerate(lines):
                # BUG FIX: guard i+1 — "Invoice To" on the final line caused
                # an IndexError in the original.
                if "Invoice To" in line and i + 1 < len(lines):
                    invoice_data.customer_name = lines[i + 1].strip()
                    break
        return invoice_data

    def _categorize_item(self, item_name: str) -> str:
        """Classify an item name into a product category via Gemini.

        NOTE(review): currently unused — the call site in
        _parse_invoice_with_ai is commented out; each call costs one request.
        """
        name = item_name.lower()
        prompt = f"""
        Extract the category of the Item from the item details very intelligently
        so that we can get the category in which the item belongs to very efficiently:
        Example: "Electronics", "Furniture", "Software", etc.....
        Input Text- The item is given below (provide the category in JSON format like -- category: 'extracted category') ---->
        {name}
        """
        response = self.generate(prompt)
        result_text = response.text.strip()
        category = safe_json_parse(result_text)
        return category['category']

    def _calculate_extraction_confidence(self, invoice_data: InvoiceData, raw_text: str) -> float:
        """
        Intelligent confidence scoring for extracted invoice data.
        Combines presence, consistency, and numeric sanity checks.
        Returns a percentage in [0, 99]; also stores the 0-1 fraction on
        ``invoice_data.extraction_confidence``.
        """
        score = 0.0
        weight = {
            "invoice_number": 0.1,
            "order_id": 0.05,
            "customer_name": 0.1,
            "due_date": 0.05,
            "ship_to": 0.05,
            "item_details": 0.25,
            "total_consistency": 0.25,
            "currency_detected": 0.05,
            "text_match_bonus": 0.1
        }

        text_lower = raw_text.lower()

        # Presence-based confidence
        if invoice_data.invoice_number:
            score += weight["invoice_number"]
        if invoice_data.order_id:
            score += weight["order_id"]
        if invoice_data.customer_name:
            score += weight["customer_name"]
        # due_date scores when its presence/absence agrees with the raw text.
        if invoice_data.due_date and "due_date" in text_lower:
            score += weight["due_date"]
        if not invoice_data.due_date and "due_date" not in text_lower:
            score += weight["due_date"]
        if invoice_data.item_details:
            score += weight["item_details"]

        # Currency detection
        if any(c in raw_text for c in ["$", "₹", "€", "usd", "inr", "eur"]):
            score += weight["currency_detected"]

        # Numeric consistency: the largest amount found in the text should be
        # close to the parsed total. (Uses the module-level `re`; the original
        # re-imported it locally twice.)
        numbers = [
            float(m.replace(",", "").replace("$", "").strip())
            for m in re.findall(r"\$?\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?", raw_text)
            if m
        ]
        if len(numbers) >= 3 and invoice_data.total:
            approx_total = max(numbers)
            diff = abs(approx_total - invoice_data.total)
            if diff < 5:  # minor difference allowed
                score += weight["total_consistency"]
            elif diff < 50:
                score += weight["total_consistency"] * 0.5

        # Textual verification: key fields should literally appear in the text.
        hits = 0
        for field in [invoice_data.customer_name, invoice_data.order_id, invoice_data.invoice_number]:
            if field and str(field).lower() in text_lower:
                hits += 1
        if hits >= 2:
            score += weight["text_match_bonus"]

        # Penalty for empty critical fields
        missing_critical = not invoice_data.total or not invoice_data.customer_name or not invoice_data.invoice_number
        if missing_critical:
            score *= 0.8

        # Clamp and finalize
        final_conf = round(min(score, 0.99), 2)
        invoice_data.extraction_confidence = final_conf
        return final_conf * 100.0

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform health diagnostics for the Document Agent.
        Collects operational, performance, and API connectivity metrics.
        """
        executions = 0
        success_rate = 0.0
        avg_duration = 0.0
        failures = 0
        last_run = None

        if self.metrics:
            executions = self.metrics["processed"]
            # BUG FIX: avg_latency_ms is None before the first run; the
            # original crashed in round(None, 2) below.
            avg_duration = self.metrics["avg_latency_ms"] or 0.0
            failures = self.metrics["errors"]
            last_run = self.metrics["last_run_at"]
            success_rate = (executions - failures) / (executions + 1e-8)

        # API connectivity check (presence of a configured key only).
        gemini_ok = bool(self.api_key)
        api_status = "🟢 Active" if gemini_ok else "🔴 Missing Key"

        # Health logic: missing key or repeated failures degrade; a success
        # rate below 50% marks the agent unhealthy.
        overall_status = "🟢 Healthy"
        if not gemini_ok or failures > 3:
            overall_status = "🟠 Degraded"
        if executions > 0 and success_rate < 0.5:
            overall_status = "🔴 Unhealthy"

        metrics_data = {
            "Agent": "Document Agent 🧾",
            "Executions": executions,
            "Success Rate (%)": round(success_rate * 100, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": failures,
            "API Status": api_status,
            "Last Run": str(last_run) if last_run else "Not applicable",
            "Overall Health": overall_status,
        }

        self.logger.logger.info(f"[HealthCheck] Document Agent metrics: {metrics_data}")
        return metrics_data
Project/agents/escalation_agent.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Escalation Agent for Invoice Processing"""
3
+
4
+ # TODO: Implement agent
5
+
6
+ import os
7
+ import json
8
+ import smtplib
9
+ from email.mime.text import MIMEText
10
+ from email.mime.multipart import MIMEMultipart
11
+ from typing import Dict, Any, List, Optional
12
+ from datetime import datetime, timedelta
13
+ import google.generativeai as genai
14
+ from dotenv import load_dotenv
15
+
16
+ from agents.base_agent import BaseAgent
17
+ from state import (
18
+ InvoiceProcessingState, ProcessingStatus, PaymentStatus,
19
+ RiskLevel, ValidationStatus
20
+ )
21
+ from utils.logger import StructuredLogger
22
+
23
+ load_dotenv()
24
+
25
+
26
+ class EscalationAgent(BaseAgent):
27
+ """Agent responsible for escalation management and human-in-the-loop workflows"""
28
+
29
    def __init__(self, config: Dict[str, Any] = None):
        """Initialize the escalation agent and its static routing table."""
        super().__init__("escalation_agent",config)
        self.logger = StructuredLogger("EscalationAgent")

        # Routing table: escalation type -> responsible approver role and the
        # SLA window (hours) within which that approver must respond.
        self.escalation_triggers = {
            'high_risk' : {'route_to':'risk_manager','sla_hours':4},
            'validation_failure': {'route_to':'finance_manager','sla_hours':8},
            'high_value': {'route_to':'cfo','sla_hours':24},
            'fraud_suspicion': {'route_to':'fraud_team','sla_hours':2},
            'new_vendor':{'route_to':'procurement','sla_hours':48}
        }
40
+
41
+ def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
42
+ # pass
43
+ return hasattr(state,'invoice_data') and hasattr(state,'risk_assessment')
44
+
45
+ def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
46
+ # pass
47
+ return hasattr(state,'escalation_details')
48
+
49
    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Decide whether the invoice needs escalation and, if so, build the
        escalation record, notify the approver and arm SLA monitoring.

        Returns the (mutated) state; overall_status becomes 'completed' when
        no escalation is needed, 'escalated' otherwise, FAILED on bad input.
        """
        self.logger.logger.info('Executing Escalation Agent...')
        if not self._validate_preconditions(state, workflow_type):
            self.logger.logger.error("Preconditions not meet for Escalation handling")
            # NOTE(review): sets state.status here, while the rest of this
            # method uses state.overall_status — confirm which field the
            # workflow reads.
            state.status = ProcessingStatus.FAILED
            self._log_decision(state, "Escalation Agent Failed", "Preconditions not met", confidence=0.0)
            return state

        escalation_type = self._determine_escalation_type(state)
        if not escalation_type:
            # Nothing to escalate: mark the invoice as fully processed.
            self.logger.logger.info("No escalation required for this invoice.")
            state.escalation_required = False
            state.overall_status = 'completed'
            return state

        # Build routing details, a human-readable summary, and the record.
        priority_level = self._calculate_priority_level(state)
        approver_info = self._route_to_approver(state, escalation_type,priority_level)
        summary = await self._generate_escalation_summary(state,escalation_type,approver_info)

        escalation_record = await self._create_escalation_record(state, escalation_type, priority_level, approver_info,summary)
        await self._send_escalation_notifications(state,escalation_record,approver_info)
        # _setup_sla_monitoring is defined elsewhere in this class (outside
        # this view).
        await self._setup_sla_monitoring(state,escalation_record,priority_level)

        state.escalation_required = True
        state.human_review_required = True
        state.escalation_details = escalation_record
        # NOTE(review): this overwrites the boolean human_review_required set
        # two lines above with the summary string — looks like it was meant
        # to target a different field (e.g. a review-summary attribute).
        state.human_review_required = summary
        state.escalation_reason = escalation_record["escalation_reason"]
        state.current_agent = 'escalation_agent'
        state.overall_status = 'escalated'
        self._log_decision(
            state,
            "Escalation Successful",
            "PDF successfully escalated to Human for review",
            "N/A",
            state.process_id
        )
        self.logger.logger.info('Escalation record successfully created and routed.')
        return state
89
+
90
+ def _determine_escalation_type(self, state: InvoiceProcessingState) -> str:
91
+ # pass
92
+ risk = getattr(state,'risk_assessment',{})
93
+ validation = getattr(state,'validation_result',{})
94
+ invoice = getattr(state,'invoice_data',{})
95
+ risk_level = getattr(risk,'risk_level',{})
96
+ amount = getattr(invoice,'total',0)
97
+ vendor = getattr(invoice,'customer_name','')
98
+ # fraud_indicators = risk.get('fraud_indicators',[])
99
+ fraud_indicators = getattr(risk,'fraud_indicators',[])
100
+
101
+ if risk_level in ['high','critical']:
102
+ return 'high_risk'
103
+ elif state.validation_status == 'invalid' or state.validation_status == 'missing_po':
104
+ return 'validation_failure'
105
+ elif amount and amount>250000:
106
+ return 'high_value'
107
+ elif len(fraud_indicators) > 3:
108
+ return 'fraud_suspicion'
109
+ elif vendor and 'new' in vendor.lower():
110
+ return 'new_vendor'
111
+ else:
112
+ return None
113
+
114
+ def _calculate_priority_level(self, state: InvoiceProcessingState) -> str:
115
+ # pass
116
+ # risk = getattr(state,'risk_assessment',{}).get('risk_level','low').lower()
117
+ # amount = getattr(state,'invoice_data',{}).get('total',0)
118
+ risk_assessment = getattr(state,'risk_assessment',{})
119
+ invoice_data = getattr(state,'invoice_data',{})
120
+ risk = getattr(risk_assessment,'risk_level','low').lower()
121
+ amount = getattr(invoice_data,'total',0)
122
+ if risk == 'critical' or amount > 50000:
123
+ return 'urgent'
124
+ elif risk == 'high' or amount > 25000:
125
+ return 'high'
126
+ else:
127
+ return 'medium'
128
+
129
+ def _route_to_approver(self, state: InvoiceProcessingState,
130
+ escalation_type: str, priority_level: str) -> Dict[str, Any]:
131
+ # pass
132
+ # print(self.escalation_triggers)
133
+ route_info = self.escalation_triggers.get(escalation_type,{})
134
+ # print("route_info..................", route_info)
135
+ assigned_to = route_info.get('route_to','finance_manager')
136
+ sla_hours = route_info.get('sla_hours',8)
137
+ approvers = ['finance_manager']
138
+ if assigned_to == 'cfo':
139
+ approvers.append('cfo')
140
+ return {
141
+ 'assigned_to':assigned_to,
142
+ 'sla_hours':sla_hours,
143
+ 'approval_required_from':approvers
144
+ }
145
+
146
+
147
+ def _parse_date(self, date_str: str) -> Optional[datetime.date]:
148
+ # pass
149
+ try:
150
+ return datetime.strptime(date_str,"%Y-%m-%d").date()
151
+ except Exception:
152
+ return None
153
+
154
+ async def _generate_escalation_summary(self, state: InvoiceProcessingState,
155
+ escalation_type: str, approver_info: Dict[str, Any]) -> str:
156
+ # pass
157
+
158
+ risk = getattr(state,'risk_assessment',{})
159
+ invoice = getattr(state,'invoice_data',{})
160
+ risk_level = getattr(risk,'risk_level',{})
161
+ amount = getattr(invoice,'total',0)
162
+ # invoice = state.invoice_data
163
+ # risk = state.risk_assessment
164
+ reason = ""
165
+
166
+ if escalation_type == 'high_risk':
167
+ reason = f"Invoice marked as high risk ({risk_level})."
168
+ elif escalation_type == 'validation_failure':
169
+ reason = 'Validation discrepancies require finance approval.'
170
+ elif escalation_type == 'high_value':
171
+ reason = f"High-value invoice ({amount}) requires CFO approval."
172
+ elif escalation_type == 'fraud_suspicion':
173
+ reason = 'Fraud suspicion based on anomalies detected'
174
+ elif escalation_type == 'new_vendor':
175
+ reason = 'Vendor is new and not yet in approved list.'
176
+ return f"{reason} Routed to {approver_info['assigned_to']} for review."
177
+
178
+
179
+ async def _create_escalation_record(self, state: InvoiceProcessingState,
180
+ escalation_type: str, priority_level: str,
181
+ approver_info: Dict[str, Any], summary: str) -> Dict[str, Any]:
182
+ # pass
183
+ timestamp = datetime.utcnow()
184
+ sla_deadline = timestamp+timedelta(hours=approver_info['sla_hours'])
185
+ return {
186
+ 'escalation_type':escalation_type,
187
+ 'severity':priority_level,
188
+ 'assigned_to':approver_info['assigned_to'],
189
+ 'escalation_time':timestamp.isoformat()+'Z',
190
+ 'sla_deadline':sla_deadline.isoformat()+'Z',
191
+ 'notification_sent':True,
192
+ 'approval_required_from':approver_info['approval_required_from'],
193
+ 'escalation_reason':summary
194
+ }
195
+
196
+
197
+ async def _send_escalation_notifications(self, state: InvoiceProcessingState,
198
+ escalation_record: Dict[str, Any],
199
+ approver_info: Dict[str, Any]) -> Dict[str, Any]:
200
+ # pass
201
+ try:
202
+ subject = f"[Escalation Alert] Invoice requires {approver_info['assigned_to']} review"
203
+ body = f"""
204
+ Escalation Type: {escalation_record['escalation_type']}
205
+ severity: {escalation_record['severity']}
206
+ SLA Deadline: {escalation_record['sla_deadline']}
207
+ reason: {escalation_record['escalation_reason']}
208
+ """
209
+ to_email = f"{approver_info['assigned_to']}@company.com"
210
+ self._send_email(to_email,subject,body)
211
+ self.logger.logger.info(f"Escalation notification send to {to_email}.")
212
+ return {'status':'send','to':to_email}
213
+ except Exception as e:
214
+ self.logger.logger.error(f'Failed to send notification: {e}')
215
+ return {'status':'failed','error':str(e)}
216
+
217
+ def _send_email(self, to_email: str, subject: str, body: str) -> Dict[str, Any]:
218
+ # pass
219
+ try:
220
+ sender = os.getenv('EMAIL_SENDER','noreply@invoicesystem.com')
221
+ msg = MIMEMultipart()
222
+ msg['From'] = send
223
+ msg['To'] = to_email
224
+ msg['Subject'] = subject
225
+ msg.attach(MIMEText(body,'plain'))
226
+ with smtplib.SMTP('localhost') as server:
227
+ server.send_message(msg)
228
+ return {'sent':True}
229
+ except Exception as e:
230
+ return {'sent':False, 'error':str(e)}
231
+
232
+
233
+ async def _setup_sla_monitoring(self, state: InvoiceProcessingState,
234
+ escalation_record: Dict[str, Any], priority_level: str):
235
+ # pass
236
+ self.logger.logger.debug(
237
+ f"SLA monitoring initialized for {escalation_record['escalation_type']}"
238
+ f"with deadline {escalation_record['sla_deadline']}"
239
+ )
240
+
241
+ async def resolve_escalation(self, escalation_id: str, resolution: str,
242
+ resolver: str) -> Dict[str, Any]:
243
+ # pass
244
+ return {
245
+ 'escalation_id':escalation_id,
246
+ 'resolved_by':resolver,
247
+ 'resolution_notes':resolution,
248
+ 'resolved_at':datetime.utcnow().isoformat()+'Z',
249
+ 'status':'resolved'
250
+ }
251
+
252
    async def health_check(self) -> Dict[str, Any]:
        """
        Performs a detailed health check for the Escalation Agent.
        Includes operational metrics, configuration validation, and reliability stats.

        Returns a human-readable report dict; on any internal error a
        degraded-health dict is returned instead of raising.
        """

        start_time = datetime.utcnow()
        self.logger.logger.info("Performing health check for EscalationAgent...")

        # Defaults used when no metrics have been recorded yet.
        executions = 0
        avg_duration = 0.0
        failures = 0
        last_run = None  # NOTE(review): assigned below but unused; the report reads self.metrics directly
        success_rate = 0.0

        try:
            if self.metrics:
                # NOTE(review): assumes self.metrics is a dict with these exact
                # keys (populated elsewhere) — confirm against BaseAgent.
                executions = self.metrics["processed"]
                avg_duration = self.metrics["avg_latency_ms"]
                failures = self.metrics["errors"]
                last_run = self.metrics["last_run_at"]
                # epsilon in the denominator guards against division by zero,
                # though the `executions > 0` condition already prevents it.
                success_rate = (executions - failures) / (executions + 1e-8) * 100.0 if executions > 0 else 0.0

            total_executions = executions
            total_failures = failures
            avg_duration_ms = avg_duration

            # Email and trigger configuration validation
            email_configured = bool(os.getenv('EMAIL_SENDER'))  # NOTE(review): computed but only used by commented-out report line
            missing_triggers = [k for k, v in self.escalation_triggers.items() if not v.get("route_to")]

            # Duration calculation
            # duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
            # last_run = self.metrics["last_run_at"]

            health_report = {
                "Agent": "Escalation Agent 🚨",
                "Executions": total_executions,
                "Success Rate (%)": round(success_rate, 2),
                # avg_duration_ms == 0.0 is treated as "never called"
                "Avg Duration (ms)": round(avg_duration_ms, 2) if avg_duration_ms else "Not Called",
                "Total Failures": total_failures,
                # "Email Configured": email_configured,
                # "Available Triggers": list(self.escalation_triggers.keys()),
                "Missing Routes": missing_triggers,
                # NOTE(review): raises (and falls into the except branch) when
                # self.metrics is falsy/None — presumably intended; confirm.
                "Last Run": self.metrics["last_run_at"],
                # A zero-execution agent is reported healthy by design.
                "Overall Health": "🟢 Healthy" if (success_rate > 70 or total_executions == 0) else "Degraded ⚠️",
                # "Response Time (ms)": round(duration_ms, 2)
                # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"),
            }

            self.logger.logger.info("EscalationAgent health check completed successfully.")
            return health_report

        except Exception as e:
            error_time = (datetime.utcnow() - start_time).total_seconds() * 1000  # NOTE(review): computed but unused
            self.logger.logger.error(f"Health check failed: {e}")

            # Return degraded health if something goes wrong
            return {
                "Agent": "EscalationAgent ❌",
                "Overall Health": "Degraded",
                "Error": str(e),
                "Timestamp": datetime.utcnow().isoformat() + "Z"
            }
Project/agents/forecast_agent.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # agents/forecast_agent.py
3
+ """
4
+ Forecast Agent (robust)
5
+ - Accepts a list of invoice states (dicts or InvoiceProcessingState models).
6
+ - Produces monthly historical spend and a simple forecast (moving average).
7
+ - Performs lightweight anomaly detection.
8
+ - Returns a dict containing a Plotly chart and numeric summary.
9
+ """
10
+ from typing import List, Dict, Any, Union
11
+ from datetime import datetime
12
+ import pandas as pd
13
+ import plotly.express as px
14
+ import plotly.graph_objects as go
15
+ import math
16
+ import os
17
+
18
+ # keep the type import only for hints; we do NOT require reconstructing models
19
+ try:
20
+ from state import InvoiceProcessingState
21
+ except Exception:
22
+ InvoiceProcessingState = None # type: ignore
23
+
24
+
25
class ForecastAgent:
    """Turns a heterogeneous list of invoice states into a monthly spend
    history, a naive (mean-repeat) forecast, and a lightweight anomaly list.

    Inputs may be plain dicts or pydantic-like models; normalization is
    deliberately defensive and never raises per-row.
    """

    def __init__(self):
        # Stateless agent — nothing to configure.
        pass

    # ---- Internal: normalize input states -> DataFrame ----
    def _normalize_states_to_df(self, states: List[Union[dict, object]]) -> pd.DataFrame:
        """
        Accepts list of dicts or model instances.
        Produces a cleaned DataFrame with columns:
        ['file_name','invoice_date','due_date','total','vendor','risk_score','status']
        plus a derived 'date' column (due_date, falling back to invoice_date).
        Malformed entries are skipped silently.
        """
        rows = []
        for s in states:
            try:
                # 1) obtain a plain dict representation without constructing pydantic models
                if isinstance(s, dict):
                    raw = dict(s)
                else:
                    # model-like object: try model_dump, to_dict, or __dict__
                    if hasattr(s, "model_dump"):
                        raw = s.model_dump(exclude_none=False)
                    elif hasattr(s, "dict"):
                        raw = s.dict()
                    else:
                        # best effort: convert attributes to dict
                        # NOTE(review): dir()-based scraping also evaluates
                        # properties, which may have side effects — confirm.
                        raw = {
                            k: getattr(s, k)
                            for k in dir(s)
                            if not k.startswith("_") and not callable(getattr(s, k))
                        }

                # 2) sanitize well-known problematic fields that break pydantic elsewhere
                if "human_review_required" in raw and isinstance(raw["human_review_required"], str):
                    v = raw["human_review_required"].strip().lower()
                    raw["human_review_required"] = v in ("true", "yes", "1", "required")
                if "escalation_details" in raw and isinstance(raw["escalation_details"], dict):
                    # convert to string summary so downstream code doesn't expect a dict
                    try:
                        raw["escalation_details"] = str(raw["escalation_details"])
                    except Exception:
                        raw["escalation_details"] = ""

                # 3) pull invoice_data safely (may be None, dict, or model)
                inv = {}
                if raw.get("invoice_data") is None:
                    inv = {}
                else:
                    inv_raw = raw.get("invoice_data")
                    if isinstance(inv_raw, dict):
                        inv = dict(inv_raw)
                    else:
                        # model-like invoice_data
                        if hasattr(inv_raw, "model_dump"):
                            inv = inv_raw.model_dump(exclude_none=False)
                        elif hasattr(inv_raw, "dict"):
                            inv = inv_raw.dict()
                        else:
                            # fallback: read attributes
                            inv = {
                                k: getattr(inv_raw, k)
                                for k in dir(inv_raw)
                                if not k.startswith("_") and not callable(getattr(inv_raw, k))
                            }

                # 4) turnout the row items we care about
                # NOTE(review): `or`-chains treat 0/0.0 as missing — a genuinely
                # zero-total invoice falls through to the next source; confirm intended.
                total = inv.get("total") or inv.get("amount") or raw.get("total") or 0.0
                # risk may be under risk_assessment.risk_score or top-level
                risk_src = raw.get("risk_assessment") or {}
                if isinstance(risk_src, dict):
                    risk_score = risk_src.get("risk_score") or 0.0
                else:
                    # model-like risk_assessment
                    if hasattr(risk_src, "model_dump"):
                        try:
                            risk_score = risk_src.model_dump().get("risk_score", 0.0)
                        except Exception:
                            risk_score = 0.0
                    else:
                        risk_score = getattr(risk_src, "risk_score", 0.0)

                # dates: prefer due_date then invoice_date - they could be strings or datetimes
                due = inv.get("due_date") or inv.get("invoice_date") or raw.get("due_date") or raw.get("invoice_date")
                vendor = inv.get("customer_name") or inv.get("vendor_name") or raw.get("vendor") or raw.get("customer_name") or "Unknown"
                file_name = inv.get("file_name") or raw.get("file_name") or "unknown"

                rows.append(
                    {
                        "file_name": file_name,
                        "due_date": due,
                        "invoice_date": inv.get("invoice_date") or raw.get("invoice_date"),
                        "total": total,
                        "vendor": vendor,
                        "risk_score": risk_score,
                        "status": raw.get("overall_status") or inv.get("status") or "unknown",
                    }
                )
            except Exception:
                # skip malformed state
                continue

        df = pd.DataFrame(rows)
        if df.empty:
            return df

        # coerce and normalize: unparseable dates become NaT, unparseable
        # numbers become 0.0 so downstream aggregation never raises.
        df["due_date"] = pd.to_datetime(df["due_date"], errors="coerce")
        df["invoice_date"] = pd.to_datetime(df["invoice_date"], errors="coerce")
        # if due_date missing, fallback to invoice_date
        df["date"] = df["due_date"].fillna(df["invoice_date"])
        df["total"] = pd.to_numeric(df["total"], errors="coerce").fillna(0.0)
        df["risk_score"] = pd.to_numeric(df["risk_score"], errors="coerce").fillna(0.0)
        df["vendor"] = df["vendor"].fillna("Unknown")
        return df

    # ---- Public: predict monthly cashflow and return a plotly chart ----
    def predict_cashflow(self, states: List[Union[dict, object]], months: int = 6) -> Dict[str, Any]:
        """
        Produces a monthly historical spend + simple forecast for `months` into the future.
        The forecast simply repeats the historical monthly mean (interpretable and safe).
        Returns:
            {
                "chart": plotly_figure,
                "average_monthly_spend": float,
                "total_forecast": float,
                "forecast_values": {month_str: float, ...},
                "historical": pandas.Series,
                "forecast_start_month": str,
                "forecast_end_month": str
            }
        or {"message": ..., "chart": None} when there is nothing to forecast.
        """
        df = self._normalize_states_to_df(states)
        if df.empty or df["date"].dropna().empty:
            return {"message": "No data to forecast", "chart": None}

        # create monthly buckets (period start)
        df = df.dropna(subset=["date"])
        df["month"] = df["date"].dt.to_period("M").dt.to_timestamp()
        monthly_hist = df.groupby("month")["total"].sum().sort_index()

        # compute average monthly spend from available historical months
        average_month = float(monthly_hist.mean()) if not monthly_hist.empty else 0.0

        # build forecast months (next `months` starting from the next month after last historical)
        last_hist_month = monthly_hist.index.max()
        if pd.isnull(last_hist_month):
            start_month = pd.Timestamp.now().to_period("M").to_timestamp()
        else:
            # next month
            start_month = (last_hist_month + pd.offsets.MonthBegin(1)).normalize()

        forecast_index = pd.date_range(start=start_month, periods=months, freq="MS")
        # simple forecast: repeat the historical mean (interpretable and safe)
        forecast_vals = [average_month for _ in range(len(forecast_index))]

        # build plot dataframe (historical + forecast)
        hist_df = monthly_hist.reset_index().rename(columns={"month": "date", "total": "amount"})
        hist_df["type"] = "Historical"
        fc_df = pd.DataFrame({"date": forecast_index, "amount": forecast_vals})
        fc_df["type"] = "Forecast"
        plot_df = pd.concat([hist_df, fc_df], ignore_index=True).sort_values("date")

        # prepare a plotly figure with clear styling
        fig = go.Figure()
        # historical - solid line
        hist_plot = plot_df[plot_df["type"] == "Historical"]
        if not hist_plot.empty:
            fig.add_trace(go.Scatter(
                x=hist_plot["date"],
                y=hist_plot["amount"],
                mode="lines+markers",
                name="Historical Spend",
                line=dict(dash="solid"),
            ))
        # forecast - dashed line
        fc_plot = plot_df[plot_df["type"] == "Forecast"]
        if not fc_plot.empty:
            fig.add_trace(go.Scatter(
                x=fc_plot["date"],
                y=fc_plot["amount"],
                mode="lines+markers",
                name="Forecast",
                line=dict(dash="dash"),
                marker=dict(symbol="circle-open")
            ))

        fig.update_layout(
            title="Monthly Spend (Historical + Forecast)",
            xaxis_title="Month",
            yaxis_title="Total Spend (USD)",
            hovermode="x unified",
            template="plotly_dark",
        )

        forecast_series = pd.Series(forecast_vals, index=[d.strftime("%Y-%m") for d in forecast_index])
        total_forecast = float(forecast_series.sum())

        result = {
            "chart": fig,
            "average_monthly_spend": round(average_month, 2),
            "total_forecast": round(total_forecast, 2),
            "forecast_values": forecast_series.to_dict(),
            "historical": monthly_hist,
            "forecast_start_month": forecast_index[0].strftime("%Y-%m"),
            "forecast_end_month": forecast_index[-1].strftime("%Y-%m"),
        }
        return result

    # ---- Public: detect anomalies on sanitized data ----
    def detect_anomalies(self, states: List[Union[dict, object]]) -> pd.DataFrame:
        """
        Returns DataFrame of anomalies:
          - total > 2 * mean(total)
          - OR risk_score >= 0.7
        Columns returned: ['file_name','invoice_date','vendor','total','risk_score','anomaly_reason']
        (empty DataFrame when there is no input or no anomaly).
        """
        df = self._normalize_states_to_df(states)
        if df.empty:
            return pd.DataFrame()

        mean_spend = df["total"].mean()
        cond = (df["total"] > mean_spend * 2) | (df["risk_score"] >= 0.7)
        anomalies = df.loc[cond, ["file_name", "date", "vendor", "total", "risk_score"]].copy()
        if anomalies.empty:
            return pd.DataFrame()
        anomalies = anomalies.rename(columns={"date": "invoice_date"})
        # "High Spend" wins when both conditions hold.
        anomalies["anomaly_reason"] = anomalies.apply(
            lambda r: "High Spend" if r["total"] > mean_spend * 2 else "High Risk",
            axis=1,
        )
        return anomalies.reset_index(drop=True)
Project/agents/insights_agent.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # agents/insights_agent.py
3
+ """
4
+ Insight Agent
5
+ -------------
6
+ Generates analytical and visual insights from processed invoices.
7
+ """
8
+
9
+ import pandas as pd
10
+ import plotly.express as px
11
+ from typing import List, Dict, Any
12
+ from state import InvoiceProcessingState
13
+
14
+
15
class InsightAgent:
    """Builds tabular analytics and Plotly charts from processed invoice
    states (``InvoiceProcessingState`` instances or equivalent dicts)."""

    def __init__(self):
        # Stateless agent — nothing to configure.
        pass

    def _extract_invoice_records(self, results: List[InvoiceProcessingState]) -> pd.DataFrame:
        """Extract flat invoice info for analysis.

        Dict inputs are coerced through ``InvoiceProcessingState`` (and
        silently dropped if they do not validate); missing sub-objects
        simply yield ``None`` columns via ``getattr`` defaults.
        """
        records = []
        for r in results:
            if isinstance(r, dict):
                # Convert dict to InvoiceProcessingState if needed
                try:
                    r = InvoiceProcessingState(**r)
                except Exception:
                    continue

            inv = getattr(r, "invoice_data", None)
            risk = getattr(r, "risk_assessment", None)
            val = getattr(r, "validation_result", None)
            pay = getattr(r, "payment_decision", None)

            # getattr with a default tolerates any of the sub-objects being None.
            records.append({
                "file_name": getattr(inv, "file_name", None),
                "invoice_number": getattr(inv, "invoice_number", None),
                "customer_name": getattr(inv, "customer_name", None),
                "invoice_date": getattr(inv, "invoice_date", None),
                "total": getattr(inv, "total", None),
                "validation_status": getattr(val, "validation_status", None),
                "risk_score": getattr(risk, "risk_score", None),
                "risk_level": getattr(risk, "risk_level", None),
                "payment_status": getattr(pay, "status", None),
                "decision": getattr(pay, "decision", None),
            })

        df = pd.DataFrame(records)
        if df.empty:
            return pd.DataFrame()

        # Clean up data: unparseable totals/scores become 0.0, missing vendors
        # are labelled explicitly.
        df["customer_name"] = df["customer_name"].fillna("Unknown Vendor")
        df["total"] = pd.to_numeric(df["total"], errors="coerce").fillna(0.0)
        df["risk_score"] = pd.to_numeric(df["risk_score"], errors="coerce").fillna(0.0)
        return df

    def generate_insights(self, results: List[InvoiceProcessingState]) -> Dict[str, Any]:
        """Generate charts and textual summary.

        Returns ``{"summary": str, "charts": [plotly figures]}``; when no
        usable records exist the charts list is empty and the summary says so.
        """
        df = self._extract_invoice_records(results)
        if df.empty:
            return {"summary": "No data available for insights.", "charts": []}

        charts = []

        # 🔹 Total spend per customer
        if "customer_name" in df.columns:
            spend_chart = px.bar(
                df.groupby("customer_name", as_index=False)["total"].sum(),
                x="customer_name",
                y="total",
                title="Total Spend per Customer"
            )
            charts.append(spend_chart)

        # 🔹 Risk distribution
        if "risk_level" in df.columns:
            risk_chart = px.pie(
                df,
                names="risk_level",
                title="Risk Level Distribution"
            )
            charts.append(risk_chart)

        # 🔹 Validation status counts
        if "validation_status" in df.columns:
            # groupby().size() with as_index=False yields a "size" column.
            val_chart = px.bar(
                df.groupby("validation_status", as_index=False).size(),
                x="validation_status",
                y="size",
                title="Validation Status Overview"
            )
            charts.append(val_chart)

        # 🔹 Summary text
        total_spend = df["total"].sum()
        high_risk = (df["risk_score"] >= 0.7).sum()
        # Comparing via str.lower() tolerates enum values rendered as e.g.
        # "valid" or "VALID" — NOTE(review): enum reprs like
        # "ValidationStatus.VALID" would NOT match; confirm stored form.
        valid_invoices = (df["validation_status"].astype(str).str.lower() == "valid").sum()

        summary = (
            f"💰 **Total Spend:** ₹{total_spend:,.2f}\n\n"
            f"📄 **Invoices Processed:** {len(df)}\n\n"
            f"✅ **Valid Invoices:** {valid_invoices}\n\n"
            f"⚠️ **High Risk Invoices:** {high_risk}\n\n"
        )

        return {"summary": summary, "charts": charts}
Project/agents/payment_agent.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Payment Agent for Invoice Processing"""
3
+
4
+ # TODO: Implement agent
5
+
6
+ import os
7
+ import json
8
+ import requests
9
+ from typing import Dict, Any, Optional
10
+ from datetime import datetime, timedelta
11
+ import google.generativeai as genai
12
+ from dotenv import load_dotenv
13
+ import time
14
+ import requests
15
+
16
+ from agents.base_agent import BaseAgent
17
+ from state import (
18
+ InvoiceProcessingState, PaymentDecision, PaymentStatus,
19
+ RiskLevel, ValidationStatus, ProcessingStatus, RiskAssessment
20
+ )
21
+ from utils.logger import StructuredLogger
22
+
23
+ load_dotenv()
24
+
25
+
26
class PaymentAgent(BaseAgent):
    """Agent responsible for payment processing decisions and execution.

    Decides auto-pay / hold / manual-approval / reject based on validation
    and risk results, submits approved payments to the payment API, and
    tracks its own health metrics.
    """
    # Persistent in-memory history (like validation agent); class-level so it
    # survives across agent instances within the process. Kept to last 50.
    health_history = []

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__("payment_agent", config)
        self.logger = StructuredLogger("PaymentAgent")
        self.approved_vendor_list = ["Acme Corporation", "TechNova Ltd", "SupplyCo"]
        self.retry_limit = 3
        # Health metrics tracking
        self.total_executions = 0
        self.successful_executions = 0
        self.failed_executions = 0
        self.total_duration = 0.0
        self.last_transaction_id = None
        self.last_run = None

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Check that the state carries what this agent needs for the given workflow."""
        if workflow_type == "expedited":
            # BUG FIX: the original evaluated `validation_status.VALID` — an
            # enum-attribute lookup that is always truthy — so the check never
            # actually verified the validation outcome. Compare against the
            # enum member and guard against a missing validation result.
            return bool(
                state.validation_result
                and state.validation_result.validation_status == ValidationStatus.VALID
                and state.invoice_data
            )
        else:
            return bool(state.risk_assessment and state.invoice_data)

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """A payment decision must exist after execution."""
        return bool(state.payment_decision)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Run the payment stage: decide, execute, and record the outcome.

        The "expedited" workflow skips risk assessment (a synthetic low-risk
        assessment is used) and auto-approves; the standard workflow derives
        the decision from risk + validation results. Health metrics are
        updated in ``finally`` regardless of outcome.
        """
        start_time = time.time()
        try:
            if not self._validate_preconditions(state, workflow_type):
                state.overall_status = ProcessingStatus.FAILED
                self._log_decision(state, "Payment Agent Failed", "Preconditions not met", confidence=0.0)
                return state

            invoice_data = state.invoice_data
            validation_result = state.validation_result
            if workflow_type == "expedited":
                # Synthetic low-risk assessment: expedited runs bypass the risk agent.
                # NOTE(review): requires_human_review receives a string here —
                # confirm the RiskAssessment field coerces/accepts it.
                risk_assessment = RiskAssessment(
                    risk_level = RiskLevel.LOW,
                    risk_score = 0.3,
                    fraud_indicators = None,
                    compliance_issues = None,
                    recommendation = None,
                    reason = "Expedited Workflow Called",
                    requires_human_review = "Not needed due to Expedited Workflow"
                )
                payment_decision = PaymentDecision(
                    decision = "auto_pay",
                    status = PaymentStatus.APPROVED,
                    approved_amount = invoice_data.total,
                    transaction_id = f"TXN-{datetime.utcnow().strftime('%Y-%m-%d-%H%M%S')}",
                    payment_method = self._select_payment_method(invoice_data.total),
                    approval_chain = ["system_auto_approval"],
                    rejection_reason = None,
                    scheduled_date = self._calculate_payment_date(invoice_data.due_date, "ACH")
                )
                payment_result = await self._execute_payment(invoice_data, payment_decision)
                payment_decision = self._update_payment_decision(payment_decision, payment_result)

                justification = await self._generate_payment_justification(
                    invoice_data, payment_decision, validation_result, risk_assessment
                )

                state.payment_decision = payment_decision
                state.overall_status = ProcessingStatus.COMPLETED
                state.current_agent = "payment_agent"
                # success criteria: only an APPROVED execution counts as success
                if payment_decision.status == PaymentStatus.APPROVED:
                    self.successful_executions += 1
                else:
                    self.failed_executions += 1

                self.last_transaction_id = payment_decision.transaction_id
                self._log_decision(state, payment_decision.status, justification, 95.0, state.process_id)
                return state
            else:
                risk_assessment = state.risk_assessment

                payment_decision = await self._make_payment_decision(
                    invoice_data, validation_result, risk_assessment, state
                )
                if payment_decision.decision == "auto_pay":
                    state.approval_chain = [
                        {
                            "approved_by":"system_auto_approval in payment_agent",
                            "timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
                        }
                    ]
                else:
                    state.approval_chain = [{"payment_agent":"Failed or Rejected"}]

                payment_result = await self._execute_payment(invoice_data, payment_decision)
                payment_decision = self._update_payment_decision(payment_decision, payment_result)

                justification = await self._generate_payment_justification(
                    invoice_data, payment_decision, validation_result, risk_assessment
                )

                state.payment_decision = payment_decision
                state.overall_status = ProcessingStatus.COMPLETED
                state.current_agent = "payment_agent"
                # success criteria: only an APPROVED execution counts as success
                if payment_decision.status == PaymentStatus.APPROVED:
                    self.logger.logger.debug(f"successful_executions={self.successful_executions}")
                    self.successful_executions += 1
                else:
                    self.failed_executions += 1

                self.last_transaction_id = payment_decision.transaction_id
                self._log_decision(state, payment_decision.status, justification, 95.0, state.process_id)
                return state

        except Exception as e:
            self.failed_executions += 1
            self.logger.logger.error(f"[PaymentAgent] Execution failed: {e}")
            state.overall_status = ProcessingStatus.FAILED
            return state

        finally:
            # Metrics are recorded on every path, including failures.
            duration = (time.time() - start_time) * 1000  # in ms
            self.last_run = datetime.utcnow().isoformat()
            self.total_executions += 1
            self.total_duration += duration
            self._record_health_metrics(duration)

    async def _make_payment_decision(self, invoice_data, validation_result,
                                     risk_assessment, state: InvoiceProcessingState) -> PaymentDecision:
        """Map (risk level, validation status, amount) onto a PaymentDecision.

        Policy: critical risk / invalid validation -> reject; low risk or
        < $5k -> auto-pay; medium risk / partial match -> hold for finance;
        everything else -> executive manual approval.
        """
        # BUG FIX: `invoice_data.total or invoice_data.total_amount` raised
        # AttributeError when the model has no `total_amount` field and
        # `total` is falsy; use getattr with a default for the fallback.
        amount = invoice_data.total or getattr(invoice_data, "total_amount", 0.0) or 0.0
        risk_level = risk_assessment.risk_level
        validation_status = validation_result.validation_status

        if risk_level == RiskLevel.CRITICAL or validation_status == ValidationStatus.INVALID:
            decision = PaymentDecision(
                decision = "reject",
                status = PaymentStatus.FAILED,
                approved_amount = 0.0,
                transaction_id = None,
                payment_method = None,
                approval_chain = [],
                rejection_reason = "Critical Risk or Invalid Validation",
                scheduled_date = None
            )
        elif risk_level == RiskLevel.LOW or amount < 5000:
            decision = PaymentDecision(
                decision = "auto_pay",
                status = PaymentStatus.APPROVED,
                approved_amount = amount,
                transaction_id = f"TXN-{datetime.utcnow().strftime('%Y-%m-%d-%H%M%S')}",
                payment_method = self._select_payment_method(amount),
                approval_chain = ["system_auto_approval"],
                rejection_reason = None,
                scheduled_date = self._calculate_payment_date(invoice_data.due_date, "ACH")
            )
        elif risk_level == RiskLevel.MEDIUM or validation_status == ValidationStatus.PARTIAL_MATCH:
            decision = PaymentDecision(
                decision = "hold",
                status = PaymentStatus.PENDING_APPROVAL,
                approved_amount = amount,
                transaction_id = None,
                payment_method = self._select_payment_method(amount),
                approval_chain = ["system_auto_approval", "finance_manager_approval"],
                rejection_reason = None,
                scheduled_date = self._calculate_payment_date(invoice_data.due_date, "ACH")
            )
        else:
            decision = PaymentDecision(
                decision = "manual_approval",
                status = PaymentStatus.PENDING_APPROVAL,
                approved_amount = amount,
                transaction_id = None,
                payment_method = self._select_payment_method(amount),
                approval_chain = ["system_auto_approval", "executive_approval"],
                rejection_reason = None,
                scheduled_date = self._calculate_payment_date(invoice_data.due_date, "WIRE")
            )

        return decision

    def _select_payment_method(self, amount: float) -> str:
        """Pick a payment rail by amount: <$5k ACH, <$25k WIRE, else MANUAL."""
        if amount < 5000:
            return "ACH"
        elif amount < 25000:
            return "WIRE"
        return "MANUAL"

    def _calculate_payment_date(self, due_date_str: Optional[str], payment_method: str) -> datetime:
        """Schedule the payment just after the due date (ACH +1 day, else +2).

        Missing/unparseable due dates default to 3 days from now.
        """
        due_date = self._parse_date(due_date_str)
        if not due_date:
            due_date = datetime.utcnow().date() + timedelta(days=3)
        offset = 1 if payment_method == "ACH" else 2
        return datetime.combine(due_date, datetime.min.time()) + timedelta(days=offset)

    def _parse_date(self, date_str: str) -> Optional[datetime.date]:
        """Parse 'YYYY-MM-DD'; returns None for empty or malformed input."""
        if not date_str:
            return None
        try:
            return datetime.strptime(date_str, "%Y-%m-%d").date()
        except Exception:
            return None

    async def _execute_payment(self, invoice_data, payment_decision: PaymentDecision) -> Dict[str, Any]:
        """Send payment request to web API and return response with transaction_id.

        Returns a dict with 'status' ('success'/'failed'), optional
        'transaction_id', and 'message'. Network/HTTP errors are reported in
        the result rather than raised.
        """
        import asyncio
        await asyncio.sleep(1)

        payment_payload = {
            "order_id": invoice_data.invoice_number or f"INV-{int(datetime.utcnow().timestamp())}",
            "customer_name": invoice_data.customer_name or "Unknown Vendor",
            "amount": float(invoice_data.total),
            "currency": "USD",
            # "method": payment_decision.payment_method.lower(),
            "recipient_account": "auto_generated_account",
            "due_date": str(invoice_data.due_date or datetime.utcnow().date())
        }

        try:
            response = requests.post("http://localhost:8001/initiate_payment", json=payment_payload, timeout=10)
            if response.status_code == 200:
                result = response.json()
                return {
                    "status": "success" if result["status"] == "SUCCESS" else "failed",
                    "transaction_id": result["transaction_id"],
                    "message": result["message"]
                }
            # BUG FIX: the original printed the unbound name `result` on this
            # path (and in the except handler), raising NameError and masking
            # the real HTTP error.
            return {"status": "failed", "message": f"HTTP {response.status_code}: {response.text}"}

        except Exception as e:
            return {"status": "failed", "message": f"Payment API error: {e}"}

    async def _async_sleep(self, seconds: int):
        """Small awaitable sleep helper (keeps asyncio import local)."""
        import asyncio
        await asyncio.sleep(seconds)

    def _update_payment_decision(self, payment_decision: PaymentDecision,
                                 payment_result: Dict[str, Any]) -> PaymentDecision:
        """Fold the payment-API result back into the decision object."""
        if payment_result.get("status") == "success":
            payment_decision.status = PaymentStatus.APPROVED
            payment_decision.transaction_id = payment_result.get("transaction_id")
        else:
            payment_decision.status = PaymentStatus.FAILED
            payment_decision.rejection_reason = payment_result.get("message")
        return payment_decision

    async def _generate_payment_justification(self, invoice_data, payment_decision: PaymentDecision,
                                              validation_result, risk_assessment) -> str:
        """Compose a short audit-trail sentence explaining the decision."""
        reason = f"Payment Decision: {payment_decision.status}. "
        if payment_decision.status == PaymentStatus.FAILED:
            reason += f"Reason: {payment_decision.rejection_reason}"
        reason += f"Risk level: {risk_assessment.risk_level}. Validation: {validation_result.validation_status}."
        return reason

    def _record_health_metrics(self, duration: float):
        """Update and record health statistics (appended to class-level history)."""
        success_rate = (
            (self.successful_executions / self.total_executions) * 100
            if self.total_executions else 0
        )
        avg_duration = (
            self.total_duration / self.total_executions
            if self.total_executions else 0
        )
        overall_status = "🟢 Healthy"
        if success_rate < 70:
            overall_status = "🟠 Degraded"
        if success_rate < 60:
            overall_status = "🔴 Unhealthy"

        metrics = {
            "Agent": "Payment Agent 💳",
            "Executions": self.total_executions,
            "Success Rate (%)": round(success_rate, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": self.failed_executions,
            "Last Transaction ID": self.last_transaction_id or "N/A",
            # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"),
            "Last Run": self.last_run,
            "Overall Health": overall_status,
        }

        PaymentAgent.health_history.append(metrics)
        PaymentAgent.health_history = PaymentAgent.health_history[-50:]  # keep last 50

    async def health_check(self) -> Dict[str, Any]:
        """Return the current or last known health state."""
        await self._async_sleep(0.05)
        if not PaymentAgent.health_history:
            # No execution has happened yet — report a neutral baseline.
            return {
                "Agent": "Payment Agent 💳",
                "Executions": 0,
                "Success Rate (%)": 0.0,
                "Avg Duration (ms)": 0.0,
                "Total Failures": 0,
                "Last Transaction ID": "N/A",
            }
        return PaymentAgent.health_history[-1]
Project/agents/risk_agent.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Risk Assessment Agent for Invoice Processing"""
3
+
4
+ # TODO: Implement agent
5
+
6
+ import os
7
+ import json
8
+ import re
9
+ from typing import Dict, Any, List
10
+ import google.generativeai as genai
11
+ from dotenv import load_dotenv
12
+ import numpy as np
13
+ from datetime import datetime, timedelta
14
+ from statistics import mean
15
+ import time
16
+ from agents.base_agent import BaseAgent
17
+ from state import (
18
+ InvoiceProcessingState, RiskAssessment, RiskLevel,
19
+ ValidationStatus, ProcessingStatus
20
+ )
21
+ from utils.logger import StructuredLogger
22
+
23
+ load_dotenv()
24
+
25
+ from collections import defaultdict
26
class APIKeyBalancer:
    """Load-balances Gemini API keys, preferring the least-errored and
    least-used key.

    Usage and error counters are persisted to ``SAVE_FILE`` so the balance
    survives process restarts.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        self.keys = keys
        self.usage = defaultdict(int)   # key -> times handed out
        self.errors = defaultdict(int)  # key -> reported failures
        self.load()

    def load(self):
        """Restore persisted counters, if a stats file exists."""
        if os.path.exists(self.SAVE_FILE):
            # Context manager closes the handle promptly; the original
            # json.load(open(...)) leaked the file descriptor.
            with open(self.SAVE_FILE) as fh:
                data = json.load(fh)
            self.usage.update(data.get("usage", {}))
            self.errors.update(data.get("errors", {}))

    def save(self):
        """Persist the current usage/error counters to disk."""
        with open(self.SAVE_FILE, "w") as fh:
            json.dump({
                "usage": self.usage,
                "errors": self.errors
            }, fh)

    def get_best_key(self):
        """Return the key with the fewest errors (ties broken by lowest usage,
        then by position in ``self.keys``)."""
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record a failure against *key* so it is deprioritised next time."""
        self.errors[key] += 1
        self.save()
56
+
57
# Candidate keys from the environment. Unset variables yield None, which the
# original passed straight into the balancer (and eventually into
# genai.configure); filter them out so only real keys are balanced.
_gemini_keys = [
    os.getenv("GEMINI_API_KEY_1"),
    os.getenv("GEMINI_API_KEY_2"),
    os.getenv("GEMINI_API_KEY_3"),
    # os.getenv("GEMINI_API_KEY_4"),
    os.getenv("GEMINI_API_KEY_5"),
    os.getenv("GEMINI_API_KEY_6"),
    # os.getenv("GEMINI_API_KEY_7"),
]
balancer = APIKeyBalancer([key for key in _gemini_keys if key])
66
+
67
class RiskAgent(BaseAgent):
    """Agent responsible for risk assessment, fraud detection, and compliance checking."""

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__("risk_agent", config)
        self.logger = StructuredLogger("risk_agent")
        self.api_key = balancer.get_best_key()
        # SECURITY: never print/log the raw API key (the original printed it
        # to stdout, leaking the credential into console/log capture).
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-2.0-flash")
        # --- Metrics tracking ---
        self.execution_history: List[Dict[str, Any]] = []
        self.max_history = 50  # keep last 50 runs

    def generate(self, prompt):
        """Call the Gemini model; report failures to the key balancer so the
        failing key is deprioritised, then re-raise."""
        try:
            return self.model.generate_content(prompt)
        except Exception:
            balancer.report_error(self.api_key)
            raise

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Risk assessment needs extracted invoice data and a validation result."""
        return bool(state.invoice_data and state.validation_result)

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """A risk assessment with a numeric score must have been produced."""
        return bool(state.risk_assessment and state.risk_assessment.risk_score is not None)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Run the risk pipeline: base score -> fraud indicators -> compliance
        -> AI assessment -> combined score -> recommendation.

        Always records an execution metric (success flag + duration) via the
        ``finally`` block, even on early exit.
        """
        start_time = time.time()
        success = False
        try:
            if not self._validate_preconditions(state, workflow_type):
                state.overall_status = ProcessingStatus.FAILED
                self._log_decision(state, "Risk Assessment Analysis Failed", "Preconditions not met", confidence=0.0)
                # BUGFIX: the original fell through here and kept processing
                # with missing invoice/validation data.
                return state

            invoice_data = state.invoice_data
            validation_result = state.validation_result

            base_score = await self._calculate_base_risk_score(invoice_data, validation_result)
            fraud_indicators = await self._detect_fraud_indicators(invoice_data, validation_result)
            compliance_issues = await self._check_compliance(invoice_data, state)
            ai_assessment = await self._ai_risk_assessment(invoice_data, validation_result, fraud_indicators)

            combined_score = self._combine_risk_factors(base_score, fraud_indicators, compliance_issues, ai_assessment)
            risk_level = self._determine_risk_level(combined_score)
            recommendation = self._generate_recommendation(risk_level, fraud_indicators, compliance_issues, validation_result)

            state.risk_assessment = RiskAssessment(
                risk_level=risk_level,
                risk_score=combined_score,
                fraud_indicators=fraud_indicators,
                compliance_issues=compliance_issues,
                recommendation=recommendation["action"],
                reason=recommendation["reason"],
                requires_human_review=recommendation["requires_human_review"]
            )

            state.current_agent = "risk_agent"
            state.overall_status = ProcessingStatus.IN_PROGRESS
            success = True
            self._log_decision(
                state,
                "Risk Assessment Successful",
                "PDF text successfully verified by Risk Agent and checked by AI",
                combined_score,
                state.process_id
            )
            return state
        finally:
            duration_ms = round((time.time() - start_time) * 1000, 2)
            self._record_execution(success, duration_ms)

    async def _calculate_base_risk_score(self, invoice_data, validation_result) -> float:
        """
        Calculates an intelligent risk score (0.0-1.0) based on validation results,
        invoice metadata, and contextual financial factors.
        """
        score = 0.0

        # --- 1. Validation & PO related risks ---
        if validation_result:
            if validation_result.validation_status == ValidationStatus.INVALID:
                score += 0.4
            elif validation_result.validation_status == ValidationStatus.PARTIAL_MATCH:
                score += 0.25
            elif validation_result.validation_status == ValidationStatus.MISSING_PO:
                score += 0.3

            # Core mismatch signals
            if not validation_result.amount_match:
                score += 0.2
            if not validation_result.rate_match:
                score += 0.15
            if not validation_result.quantity_match:
                score += 0.1

            # Low confidence from validation adds risk
            if validation_result.confidence_score is not None:
                score += (0.5 - validation_result.confidence_score) * 0.3 if validation_result.confidence_score < 0.5 else 0

        # --- 2. Invoice amount-based risk ---
        if invoice_data and invoice_data.total is not None:
            total = invoice_data.total
            if total > 1_000_000:
                score += 0.4  # Extremely high-value invoices
            elif total > 100_000:
                score += 0.25
            elif total > 10_000:
                score += 0.1
            elif total < 10:
                score += 0.15  # Suspiciously small invoice

        # --- 3. Temporal risks (based on due date) ---
        if invoice_data and getattr(invoice_data, "due_date", None):
            try:
                score += self._calculate_due_date_risk(invoice_data.due_date)
            except Exception:
                pass  # Graceful degradation if due_date is invalid

        # --- 4. Vendor / Customer risks ---
        if invoice_data and getattr(invoice_data, "customer_name", None):
            name = invoice_data.customer_name.lower()
            if "new_vendor" in name or "test" in name or "demo" in name:
                score += 0.2
            elif any(flag in name for flag in ["fraud", "fake", "invalid"]):
                score += 0.3

        # --- 5. Data reliability / extraction confidence ---
        if invoice_data and getattr(invoice_data, "extraction_confidence", None) is not None:
            conf = invoice_data.extraction_confidence
            if conf < 0.5:
                score += 0.2
            elif conf < 0.7:
                score += 0.1

        # --- 6. Currency and metadata anomalies ---
        currency = getattr(invoice_data, "currency", "USD") or "USD"
        if currency.upper() not in {"USD", "EUR", "GBP", "INR"}:
            score += 0.15  # uncommon currencies add risk

        # Normalize score within [0, 1.0]
        return round(min(score, 1.0), 3)

    def _calculate_due_date_risk(self, due_date_str: str) -> float:
        """Score temporal risk: overdue invoices add 0.2, near-due add 0.1."""
        try:
            due_date = self._parse_date(due_date_str)
            days_until_due = (due_date - datetime.utcnow().date()).days
            if days_until_due < 0:
                return 0.2
            elif days_until_due < 5:
                return 0.1
            return 0.0
        except Exception:
            # Unparseable date: small flat penalty rather than a crash.
            return 0.05

    def _parse_date(self, date_str: str) -> datetime.date:
        """Parse an ISO-style ``YYYY-MM-DD`` date string; raises ValueError otherwise."""
        return datetime.strptime(date_str, "%Y-%m-%d").date()

    async def _detect_fraud_indicators(self, invoice_data, validation_result) -> List[str]:
        """
        Performs intelligent fraud detection on the given invoice and validation results.
        Returns a list of detected fraud indicators.
        """
        indicators = []

        # 1. PO / Validation mismatches
        if validation_result:
            if not validation_result.po_found:
                indicators.append("No matching Purchase Order found")
            if not validation_result.amount_match:
                indicators.append("Amount discrepancy detected")
            if not validation_result.rate_match:
                indicators.append("Rate inconsistency with Purchase Order")
            if not validation_result.quantity_match:
                indicators.append("Quantity mismatch detected")
            if validation_result.confidence_score is not None and validation_result.confidence_score < 0.6:
                indicators.append(f"Low validation confidence ({validation_result.confidence_score:.2f})")

        # 2. Vendor / Customer anomalies
        customer_name = getattr(invoice_data, "customer_name", "") or ""
        if "test" in customer_name.lower() or "demo" in customer_name.lower():
            indicators.append("Suspicious vendor name (Test/Demo account)")
        if "new_vendor" in customer_name.lower():
            indicators.append("First-time or unverified vendor")
        if any(keyword in customer_name.lower() for keyword in ["fraud", "fake", "invalid"]):
            indicators.append("Vendor flagged with risky keywords")

        # 3. Amount-level risk signals
        total = getattr(invoice_data, "total", 0.0) or 0.0
        if total > 1_000_000:
            indicators.append(f"Unusually high invoice total (${total:,.2f})")
        elif total < 10:
            indicators.append(f"Suspiciously low invoice total (${total:,.2f})")

        # 4. Date anomalies.  The subtraction assumes date objects; if the
        # pipeline stored strings instead, the original raised TypeError and
        # aborted the whole assessment -- degrade gracefully instead.
        due_date = getattr(invoice_data, "due_date", None)
        invoice_date = getattr(invoice_data, "invoice_date", None)
        if invoice_date and due_date:
            try:
                window_days = (due_date - invoice_date).days
            except TypeError:
                window_days = None
            if window_days is not None:
                if window_days < 0:
                    indicators.append("Due date earlier than invoice date (possible manipulation)")
                elif window_days < 3:
                    indicators.append("Unusually short payment window")

        # 5. Duplicate or pattern-based red flags
        if invoice_data.invoice_number and invoice_data.invoice_number.lower().startswith("dup-"):
            indicators.append("Possible duplicate invoice ID pattern")
        if invoice_data.file_name and "copy" in invoice_data.file_name.lower():
            indicators.append("Invoice filename suggests duplication")

        # 6. Confidence anomalies (AI extraction)
        if invoice_data.extraction_confidence is not None and invoice_data.extraction_confidence < 0.5:
            indicators.append(f"Low extraction confidence ({invoice_data.extraction_confidence:.2f}) — possible OCR tampering")

        # 7. Currency or unusual metadata patterns.
        # BUGFIX: a None currency crashed the original with AttributeError.
        currency = (getattr(invoice_data, "currency", "") or "").upper()
        if currency not in {"USD", "EUR", "GBP", "INR"}:
            indicators.append(f"Uncommon currency code: {invoice_data.currency}")

        return indicators

    async def _check_compliance(self, invoice_data, state: InvoiceProcessingState) -> List[str]:
        """
        Performs a multi-layer compliance check on invoice and state integrity.
        Returns a list of detected compliance issues.
        """
        issues = []

        # 1. Invoice integrity checks
        if not invoice_data.invoice_number:
            issues.append("Missing invoice number")
        if not invoice_data.customer_name:
            issues.append("Missing customer name")
        if not invoice_data.total or invoice_data.total <= 0:
            issues.append("Invalid or missing total amount")
        if not invoice_data.due_date:
            issues.append("Missing due date")

        # 2. Item-level verification
        if not invoice_data.item_details or len(invoice_data.item_details) == 0:
            issues.append("No item details present")
        else:
            for item in invoice_data.item_details:
                if not getattr(item, "item_name", None):
                    issues.append("Item missing name")
                if getattr(item, "quantity", 1) <= 0:
                    issues.append(f"Invalid quantity for item '{item.item_name or 'Unknown'}'")

        # 3. Confidence & quality checks
        if invoice_data.extraction_confidence and invoice_data.extraction_confidence < 0.7:
            issues.append(f"Low extraction confidence ({invoice_data.extraction_confidence:.2f})")

        # 4. Workflow state checks
        if not getattr(state, "approval_chain", True):
            issues.append("Approval chain incomplete")
        if getattr(state, "escalation_required", False):
            issues.append("Escalation required before payment")
        if getattr(state, "human_review_required", False):
            issues.append("Pending human review")

        # 5. Audit consistency
        if len(state.audit_trail) == 0:
            issues.append("No audit trail entries found")

        # 6. Risk-based compliance (if a previous risk assessment exists)
        if state.risk_assessment and state.risk_assessment.risk_score >= 0.7:
            issues.append(f"High risk score detected ({state.risk_assessment.risk_score:.2f})")

        return issues

    async def _ai_risk_assessment(
        self,
        invoice_data,
        validation_result,
        fraud_indicators: List[str]
    ) -> Dict[str, Any]:
        """
        Uses a Generative AI model (Gemini) to assess risk level based on
        structured invoice data, validation results, and detected fraud indicators.

        Returns:
            dict: {
                "risk_score": float between 0-1,
                "reason": str (explanation for the score)
            }
        """
        self.logger.logger.info("[RiskAgent] Running AI-based risk assessment...")
        result = {"risk_score": 0.0, "reason": "Default – AI assessment not available"}

        try:
            # --- Construct dynamic and context-rich prompt ---
            prompt = f"""
            You are a financial risk analysis model for invoice fraud detection.
            Carefully analyze the following details:

            INVOICE DATA:
            {invoice_data}

            VALIDATION RESULT:
            {validation_result}

            DETECTED FRAUD INDICATORS:
            {fraud_indicators}

            TASK:
            1. Assess overall risk of this invoice being fraudulent or non-compliant.
            2. Provide reasoning.
            3. Respond **only in JSON** with keys:
            - "risk_score": a float between 0 and 1 (higher = higher risk)
            - "reason": short explanation of what contributed to this score.

            EXAMPLES:
            {{
            "risk_score": 0.85,
            "reason": "High amount mismatch, new vendor, and unusual currency"
            }}
            {{
            "risk_score": 0.25,
            "reason": "Valid PO and consistent totals, low fraud signals"
            }}
            """
            import asyncio
            # BUGFIX: generate_content is a blocking network call; run it in a
            # worker thread so the event loop is not stalled.
            response = await asyncio.to_thread(self.generate, prompt)

            # --- Clean and parse ---
            raw_text = getattr(response, "text", "") or ""
            cleaned_json = self._clean_json_response(raw_text)
            ai_output = json.loads(cleaned_json)

            # --- Validate AI output ---
            score = float(ai_output.get("risk_score", 0.0))
            reason = str(ai_output.get("reason", "No reason provided"))

            # Clamp score between 0-1 for safety
            result = {
                "risk_score": max(0.0, min(score, 1.0)),
                "reason": reason.strip()[:400]  # limit for logs
            }

            self.logger.logger.info(
                f"[RiskAgent] AI Risk Assessment completed: score={result['risk_score']}, reason={result['reason']}"
            )

        except json.JSONDecodeError as e:
            self.logger.logger.warning(f"[RiskAgent] JSON parsing failed: {e}")
            result["reason"] = "AI response could not be parsed"

        except Exception as e:
            self.logger.logger.error(f"[RiskAgent] AI assessment error: {e}", exc_info=True)
            result["reason"] = "Fallback to base risk model"

        return result

    def _clean_json_response(self, text: str) -> str:
        """Trim any prose/markdown surrounding the first JSON object: keep the
        span from the first ``{`` to the last ``}`` (empty string if either is
        missing)."""
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end == -1:
            return ""
        return text[start:end + 1]

    def _combine_risk_factors(
        self,
        base_score: float,
        fraud_indicators: List[str],
        compliance_issues: List[str],
        ai_assessment: Dict[str, Any]
    ) -> float:
        """
        Combines multiple risk components (base, fraud, compliance, and AI analysis)
        into a single normalized risk score between 0.0 and 1.0.

        Weighting strategy:
        - Base Score: foundation derived from deterministic checks
        - Fraud Indicators: +0.1 per flag (max +0.3)
        - Compliance Issues: +0.05 per issue (max +0.2)
        - AI Risk Score: contributes up to 50% weight (20% of base as fallback)

        Returns:
            float: final risk score clamped to [0, 1]
        """
        try:
            # Extract and normalize AI risk
            ai_score = float(ai_assessment.get("risk_score", 0.0))
            ai_score = max(0.0, min(ai_score, 1.0))

            # --- Weighted contributions ---
            fraud_contrib = min(len(fraud_indicators) * 0.1, 0.3)
            compliance_contrib = min(len(compliance_issues) * 0.05, 0.2)
            ai_contrib = 0.5 * ai_score if ai_score > 0 else 0.2 * base_score

            combined = base_score + fraud_contrib + compliance_contrib + ai_contrib

            # Cap at 1.0 for safety
            final_score = round(min(combined, 1.0), 3)

            self.logger.logger.info(
                f"[RiskAgent] Combined risk computed: base={base_score}, "
                f"fraud_flags={len(fraud_indicators)}, compliance_flags={len(compliance_issues)}, "
                f"ai_score={ai_score}, final={final_score}"
            )

            return final_score

        except Exception as e:
            self.logger.logger.error(f"[RiskAgent] Error combining risk factors: {e}", exc_info=True)
            return min(base_score + 0.2, 1.0)  # fallback conservative estimate

    def _determine_risk_level(self, risk_score: float) -> RiskLevel:
        """Map a numeric score onto coarse RiskLevel bands
        (<0.3 LOW, <0.6 MEDIUM, <0.8 HIGH, else CRITICAL)."""
        if risk_score < 0.3:
            return RiskLevel.LOW
        if risk_score < 0.6:
            return RiskLevel.MEDIUM
        if risk_score < 0.8:
            return RiskLevel.HIGH
        return RiskLevel.CRITICAL

    def _generate_recommendation(
        self,
        risk_level: RiskLevel,
        fraud_indicators: List[str],
        compliance_issues: List[str],
        validation_result
    ) -> Dict[str, Any]:
        """
        Generate a structured recommendation (approve, escalate, or reject)
        based on overall risk, fraud, and compliance outcomes.

        Decision Logic:
        - INVALID validation → reject
        - HIGH / CRITICAL risk → escalate for human review
        - Any fraud/compliance flags → escalate
        - Otherwise → approve

        Returns:
            Dict[str, Any]: {
                'action': str,              # 'approve', 'escalate', or 'reject'
                'reason': str,              # Explanation summary
                'requires_human_review': bool
            }
        """
        try:
            # --- Determine key flags ---
            has_fraud = bool(fraud_indicators)
            has_compliance_issues = bool(compliance_issues)
            validation_invalid = (
                validation_result and validation_result.validation_status == ValidationStatus.INVALID
            )

            # --- Decision Logic ---
            if validation_invalid:
                action = "reject"
                requires_review = True
                reason = "Validation failed: " + "; ".join(fraud_indicators + compliance_issues or ["Invalid invoice data"])

            elif risk_level in [RiskLevel.HIGH, RiskLevel.CRITICAL]:
                action = "escalate"
                requires_review = True
                reason = f"High risk level detected ({risk_level.value}). Issues: " + "; ".join(fraud_indicators + compliance_issues or ["Potential anomalies"])

            elif has_fraud or has_compliance_issues:
                action = "escalate"
                requires_review = True
                reason = "Minor irregularities found: " + "; ".join(fraud_indicators + compliance_issues)

            else:
                action = "approve"
                requires_review = False
                reason = "All checks passed; invoice appears valid and compliant."

            # --- Structured Output ---
            recommendation = {
                "action": action,
                "reason": reason,
                "requires_human_review": requires_review,
            }

            self.logger.logger.info(
                f"[DecisionAgent] Recommendation generated: {recommendation}"
            )
            return recommendation

        except Exception as e:
            self.logger.logger.error(f"[DecisionAgent] Error generating recommendation: {e}", exc_info=True)
            # Safe fallback: never auto-approve on an internal error.
            return {
                "action": "escalate",
                "reason": "Error during recommendation generation",
                "requires_human_review": True,
            }

    def _record_execution(self, success: bool, duration_ms: float):
        """Append one run record and trim history to ``max_history`` entries."""
        self.execution_history.append({
            "success": success,
            "duration_ms": duration_ms,
        })
        # Keep recent N only
        if len(self.execution_history) > self.max_history:
            self.execution_history.pop(0)

    async def health_check(self) -> Dict[str, Any]:
        """Summarise recent execution history into a health report.

        Success rate and duration come from ``execution_history``;
        ``self.metrics`` (maintained by the base agent) supplies the
        last-run timestamp when available.
        """
        total_runs = len(self.execution_history)
        if total_runs == 0:
            return {
                "Agent": "Risk Agent ⚠️",
                "Executions": 0,
                "Success Rate (%)": 0.0,
                "Avg Duration (ms)": 0.0,
                "Total Failures": 0,
                "Status": "idle",
            }

        last_run = self.metrics["last_run_at"] if self.metrics else None

        successes = sum(1 for entry in self.execution_history if entry["success"])
        failures = total_runs - successes
        avg_duration = round(mean(entry["duration_ms"] for entry in self.execution_history), 2)
        success_rate = round((successes / total_runs) * 100, 2)

        # API connectivity check
        gemini_ok = bool(self.api_key)
        api_status = "🟢 Active" if gemini_ok else "🔴 Missing Key"

        # Health logic. BUGFIX: the original compared a 0-1 fraction against
        # 0.5 in one branch and a 0-100 percentage elsewhere, and overwrote
        # its own statistics after computing the status; percentages are now
        # used consistently throughout.
        overall_status = "🟢 Healthy"
        if not gemini_ok or failures > 3:
            overall_status = "🟠 Degraded"
        if success_rate < 50.0:
            overall_status = "🔴 Unhealthy"

        return {
            "Agent": "Risk Agent ⚠️",
            "Executions": total_runs,
            "Success Rate (%)": success_rate,
            "Avg Duration (ms)": avg_duration,
            "API Status": api_status,
            "Total Failures": failures,
            "Last Run": str(last_run) if last_run else "Not applicable",
            "Overall Health": overall_status,
        }
643
+
644
+
Project/agents/smart_explainer_agent.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ Smart Explainer Agent (Enhanced + Gemini-powered)
4
+ - Produces a detailed, human-readable explanation for a single InvoiceProcessingState.
5
+ - Uses Gemini for natural summarization if API key is present.
6
+ - Defensive, HTML-enhanced, and fully dashboard-ready.
7
+ """
8
+
9
+ from state import InvoiceProcessingState, ValidationStatus, PaymentStatus, RiskLevel
10
+ from datetime import datetime
11
+ import google.generativeai as genai
12
+ import json
13
+ import os
14
+
15
+
16
class SmartExplainerAgent:
    """Builds a detailed, human-readable HTML explanation for one
    InvoiceProcessingState, optionally polished by Gemini when an API key
    is configured."""

    def __init__(self):
        # Configure Gemini only if an API key is available; otherwise the
        # agent falls back to the structured (non-LLM) explanation.
        self.api_key = os.environ.get("GEMINI_API_KEY_4")
        self.use_gemini = bool(self.api_key)
        if self.use_gemini:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel("gemini-2.0-flash")

    # ---------- Helper functions ----------
    # Annotations use forward references so the class does not require the
    # `state` module at definition time.
    def _safe_invoice_dict(self, state: "InvoiceProcessingState") -> dict:
        """Serialize state.invoice_data to a plain dict ({} when absent)."""
        if not state or not getattr(state, "invoice_data", None):
            return {}
        return (
            state.invoice_data.model_dump(exclude_none=True)
            if hasattr(state.invoice_data, "model_dump")
            else state.invoice_data.dict()
        )

    def _safe_validation(self, state: "InvoiceProcessingState") -> dict:
        """Serialize state.validation_result to a plain dict ({} when absent)."""
        if not state or not getattr(state, "validation_result", None):
            return {}
        return (
            state.validation_result.model_dump(exclude_none=True)
            if hasattr(state.validation_result, "model_dump")
            else state.validation_result.dict()
        )

    def _safe_risk(self, state: "InvoiceProcessingState") -> dict:
        """Serialize state.risk_assessment to a plain dict ({} when absent)."""
        if not state or not getattr(state, "risk_assessment", None):
            return {}
        return (
            state.risk_assessment.model_dump(exclude_none=True)
            if hasattr(state.risk_assessment, "model_dump")
            else state.risk_assessment.dict()
        )

    # ---------- Core explain logic ----------
    def explain(self, state) -> str:
        """
        Generate a detailed HTML + markdown explanation for a given invoice.
        Falls back gracefully if data or Gemini is unavailable.
        """

        # --- Defensive normalization ---
        if state is None:
            return "<p>⚠️ No invoice state provided.</p>"

        if isinstance(state, dict):
            try:
                state = InvoiceProcessingState(**state)
            except Exception:
                pass

        # --- Extract fields safely ---
        invoice = self._safe_invoice_dict(state) or {}
        validation = self._safe_validation(state) or {}
        risk = self._safe_risk(state) or {}
        payment = (
            state.payment_decision.model_dump(exclude_none=True)
            if getattr(state, "payment_decision", None)
            and hasattr(state.payment_decision, "model_dump")
            else getattr(state, "payment_decision", {}) or {}
        )

        discrepancies = validation.get("discrepencies", [])  # per schema

        inv_id = invoice.get("invoice_number") or invoice.get("file_name") or "<unknown>"
        vendor = invoice.get("customer_name") or invoice.get("vendor_name") or "Unknown"
        total = invoice.get("total") or invoice.get("amount") or 0

        status = getattr(state, "overall_status", "unknown")
        status_val = status.value if hasattr(status, "value") else str(status)

        # --- Interpret status fields (enums may arrive as values or members) ---
        risk_level = risk.get("risk_level")
        if hasattr(risk_level, "value"):
            risk_level = risk_level.value
        risk_score = risk.get("risk_score", 0) or 0.0

        val_status = validation.get("validation_status")
        if hasattr(val_status, "value"):
            val_status = val_status.value

        payment_status = payment.get("status")
        if hasattr(payment_status, "value"):
            payment_status = payment_status.value

        # --- Badge colors ---
        colors = {
            "VALIDATION": "#ffc107",
            "RISK": (
                "#ff1744" if str(risk_level).lower() == "critical"
                else "#ff9800" if str(risk_level).lower() == "medium"
                else "#4caf50"
            ),
            "PAYMENT": "#4caf50",
            "AUDIT": "#2196f3",
        }

        # --- Header layout ---
        header_html = f"""
        <div style="display:flex;justify-content:center;margin-bottom:1rem;">
          <div style="flex:1;text-align:center;padding:0.8rem;
                      border-radius:10px;background:{colors['VALIDATION']};
                      color:white;margin:0 4px;">
            <b>Validation</b>
          </div>
          <div style="flex:1;text-align:center;padding:0.8rem;
                      border-radius:10px;background:{colors['RISK']};
                      color:white;margin:0 4px;">
            <b>Risk</b>
          </div>
          <div style="flex:1;text-align:center;padding:0.8rem;
                      border-radius:10px;background:{colors['PAYMENT']};
                      color:white;margin:0 4px;">
            <b>Payment</b>
          </div>
          <div style="flex:1;text-align:center;padding:0.8rem;
                      border-radius:10px;background:{colors['AUDIT']};
                      color:white;margin:0 4px;box-shadow:0 0 10px rgba(0,255,0,0.7);">
            <b>Audit</b>
          </div>
        </div>
        """

        # --- Formatter: currency for numbers, str for everything else ---
        def _fmt(val):
            if val is None:
                return "N/A"
            if isinstance(val, (int, float)) and not isinstance(val, bool):
                return f"${val:,.2f}"
            return str(val)

        # --- Base explanation (structured) ---
        lines = [
            f"<p><b>Invoice:</b> {inv_id}</p>",
            f"<p><b>Vendor:</b> {vendor}</p>",
            f"<p><b>Amount:</b> {_fmt(total)}</p>",
            f"<p><b>Status:</b> {status_val}</p>",
            "<hr>",
            f"<p><b>Validation:</b> {val_status or 'unknown'}</p>",
            f"<p><b>Risk Level:</b> {risk_level or 'low'} ({risk_score})</p>",
            f"<p><b>Payment:</b> {payment.get('decision', 'N/A')} ({payment_status or 'pending'})</p>",
        ]

        if discrepancies:
            lines.append("<p><b>Discrepancies Found:</b></p><ul>")
            for d in discrepancies:
                field = d.get("field", "unknown")
                expected = d.get("expected", "")
                actual = d.get("actual", "")
                lines.append(f"<li>{field}: expected <code>{expected}</code>, got <code>{actual}</code></li>")
            lines.append("</ul>")

        # --- Recommendations ---
        advice = []
        if str(val_status).lower() == "invalid":
            advice.append("❌ Invoice failed validation — requires manual review.")
        elif str(val_status).lower() in ("partial", "partial_match"):
            advice.append("⚠️ Partial validation — check mismatched fields.")
        if str(risk_level).lower() == "critical":
            advice.append("🚨 Critical risk detected — immediate escalation required.")
        elif str(risk_level).lower() == "medium":
            advice.append("⚠️ Medium risk — consider manual review.")
        if not advice:
            advice.append("✅ No major issues detected. Proceed as usual.")

        lines.append("<p><b>Recommendation:</b></p><ul>")
        for a in advice:
            lines.append(f"<li>{a}</li>")
        lines.append("</ul>")

        explanation_html = header_html + "\n".join(lines)

        # --- Gemini polishing ---
        # BUGFIX: the original re-imported genai and built a brand-new
        # GenerativeModel on every call (with a model name inconsistent with
        # __init__); reuse the model configured in the constructor instead.
        if self.use_gemini:
            try:
                prompt = f"""
                You are a professional financial analyst.
                Here is structured invoice data and an auto-generated explanation.

                Invoice summary:
                {json.dumps(invoice, indent=2)}

                Validation details: {json.dumps(validation, indent=2)}
                Risk assessment: {json.dumps(risk, indent=2)}
                Payment info: {json.dumps(payment, indent=2)}

                Rewrite the following explanation to sound executive-level, clear, and concise.
                Use HTML for sections but do not remove any factual details.

                Existing summary:
                {explanation_html}
                """
                response = self.model.generate_content(prompt)
                if response and getattr(response, "text", None):
                    return response.text.strip()
            except Exception as e:
                return explanation_html + f"<p><i>Gemini explanation failed: {e}</i></p>"

        return explanation_html
Project/agents/validation_agent.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Validation Agent for Invoice Processing"""
3
+
4
+ # TODO: Implement agent
5
+ import asyncio
6
+ import os
7
+ import pandas as pd
8
+ from typing import Dict, Any, List, Tuple
9
+ from fuzzywuzzy import fuzz
10
+ import numpy as np
11
+ import time
12
+ from agents.base_agent import BaseAgent
13
+ from state import (
14
+ InvoiceProcessingState, ValidationResult, ValidationStatus,
15
+ ProcessingStatus
16
+ )
17
+ from datetime import datetime, timedelta
18
+
19
+ from utils.logger import StructuredLogger
20
+ from difflib import SequenceMatcher
21
+
22
class ValidationAgent(BaseAgent):
    """Agent responsible for validating invoice data against purchase orders.

    Loads a purchase-order CSV, fuzzy-matches the invoice against PO rows,
    collects field-level discrepancies, derives a validation status plus a
    weighted confidence score, and records per-agent health metrics that
    ``health_check()`` exposes to the UI.
    """

    # Class-level (shared) history of health snapshots; health_check() reads the latest.
    health_history: List[Dict[str, Any]] = []

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__(agent_name="validation_agent", config=config or {})
        self.logger = StructuredLogger(__name__)
        # CSV with expected PO values; relative tolerance for numeric comparisons (5% default).
        self.po_file = self.config.get("po_file", "data/purchase_orders.csv")
        self.tolerance = self.config.get("tolerance", 0.05)
        # Execution counters feeding _record_health_metrics().
        self.successful_executions = 0
        self.failed_executions = 0
        self.total_duration = 0.0
        self.total_executions = 0
        self.last_run = None

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Validation can only run once invoice data has been extracted."""
        if not state.invoice_data:
            self.logger.logger.error("No invoice data available for validation.")
            return False
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """True when execute() has attached a validation result to the state."""
        return getattr(state, "validation_result", None) is not None

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Run the full validation pipeline for one invoice and update the state.

        Always records health metrics in the ``finally`` block, whether the
        run succeeded or failed.
        """
        self.logger.logger.info(f"[ValidationAgent] Starting validation for {state.file_name}")
        start_time = time.time()
        try:
            if not self._validate_preconditions(state, workflow_type):
                state.status = ProcessingStatus.FAILED
                self._log_decision(state, "Validation Failed", "Precondition not met", confidence=0.0)
                return state

            invoice_data = state.invoice_data
            matching_pos = await self._find_matching_pos(invoice_data)
            validation_result = await self._validate_against_pos(invoice_data, matching_pos)
            state.validation_result = validation_result
            state.current_agent = "validation_agent"
            state.overall_status = ProcessingStatus.IN_PROGRESS

            if self._should_escalate_validation(validation_result, invoice_data):
                state.escalation_required = True

            self._validate_postconditions(state)
            self.successful_executions += 1
            # NOTE(review): naive UTC timestamp; consider datetime.now(timezone.utc).
            self.last_run = datetime.utcnow().isoformat()
            self._log_decision(
                state,
                "Validation Successful",
                "PDF text successfully validated and checked by AI",
                state.validation_result.confidence_score,
                state.process_id,
            )
            return state
        except Exception as e:
            self.logger.logger.error(f"[ValidationAgent] Execution failed: {e}")
            self.failed_executions += 1
            state.overall_status = ProcessingStatus.FAILED
            return state
        finally:
            duration = (time.time() - start_time) * 1000  # ms
            self.total_executions += 1
            self.total_duration += duration
            self._record_health_metrics(duration)

    def _load_purchase_orders(self) -> pd.DataFrame:
        """Load PO data from the configured CSV; re-raises on read failure."""
        try:
            df = pd.read_csv(self.po_file)
            self.logger.logger.info(f"[ValidationAgent] Loaded {len(df)} purchase orders")
            return df
        except Exception as e:
            self.logger.logger.error(f"[ValidationAgent] failed to load purchase order: {e}")
            raise

    async def _find_matching_pos(self, invoice_data) -> List[Dict[str, Any]]:
        """Find POs matching the invoice number plus fuzzy customer/order/item names.

        A PO row matches when the invoice numbers are equal and customer,
        order-id, and the best line-item name similarity all score >= 80.
        """
        po_df = self._load_purchase_orders()
        matches: List[Dict[str, Any]] = []

        # Bug fix: the original int() conversion could raise uncaught for a
        # non-numeric invoice number; treat that as "no match" instead.
        try:
            invoice_number = int(invoice_data.invoice_number)
        except (TypeError, ValueError):
            self.logger.logger.warning(
                f"[ValidationAgent] Non-numeric invoice number: {invoice_data.invoice_number!r}"
            )
            return matches

        for _, po in po_df.iterrows():
            customer_score = fuzz.token_sort_ratio(po["customer_name"], invoice_data.customer_name)
            order_id_score = fuzz.token_sort_ratio(po["order_id"], invoice_data.order_id)
            # Bug fix: use the BEST item score across all line items; the
            # original kept only the last item's score and raised NameError
            # when the invoice had no line items.
            item_score = max(
                (
                    fuzz.token_sort_ratio(po["item_name"], item.item_name)
                    for item in invoice_data.item_details
                ),
                default=0,
            )
            if (
                customer_score >= 80
                and item_score >= 80
                and order_id_score >= 80
                and po["invoice_number"] == invoice_number
            ):
                matches.append(po.to_dict())

        return matches

    async def _validate_against_pos(self, invoice_data, matching_pos: List[Dict[str, Any]]) -> ValidationResult:
        """Compare the invoice against the first matching PO and build a ValidationResult."""
        if not matching_pos:
            # Consistency fix: use the same (model-defined) keyword spelling
            # "discrepencies" as the successful branch — the read site is
            # validation_result.discrepencies.
            return ValidationResult(
                po_found=False,
                validation_status="missing_po",
                validation_result="No matching purchase order found",
                discrepencies=[],
                confidence_score=0.0,
            )

        po_data = matching_pos[0]
        discrepancies = self._validate_item_against_po(invoice_data, po_data)
        discrepancies += self._validate_totals(invoice_data, po_data)

        # Use the first line item as the representative one (single-item
        # invoices); guard against an empty item list instead of IndexError.
        first_item = invoice_data.item_details[0] if invoice_data.item_details else None
        actual_amount = float(first_item.amount) if first_item else 0.0
        actual_quantity = first_item.quantity if first_item else None
        actual_rate = float(first_item.rate) if first_item else 0.0

        # Bug fix: prefer 'expected_amount' (the key every comparison uses),
        # falling back to 'amount' for older CSVs.
        expected_amount = po_data.get("expected_amount", po_data.get("amount", 0)) or 0
        amount_match = abs(actual_amount - expected_amount) <= expected_amount * self.tolerance

        # Bug fix: the rate tolerance must scale with the expected RATE; the
        # original reused the amount-based limit for the rate comparison.
        expected_rate = po_data.get("rate", 0)
        rate_match = abs(actual_rate - expected_rate) <= expected_rate * self.tolerance

        validation_result = ValidationResult(
            po_found=True,
            quantity_match=actual_quantity == po_data.get("quantity"),
            rate_match=rate_match,
            amount_match=amount_match,
            validation_status=ValidationStatus.NOT_STARTED,  # finalised below
            validation_result="; ".join(discrepancies) if discrepancies else "All checks passed",
            discrepencies=discrepancies,  # model field is spelled "discrepencies"
            confidence_score=0.0,  # finalised below
            expected_amount=po_data.get("expected_amount", po_data.get("amount")),
            po_data=po_data,
        )
        validation_result.validation_status = self._determine_validation_status(validation_result)
        validation_result.confidence_score = self._calculate_validation_confidence(
            validation_result, matching_pos, invoice_data
        )
        return validation_result

    def _validate_item_against_po(self, invoice_data, po_data: Dict[str, Any]) -> List[str]:
        """Return human-readable discrepancies between invoice line items and the PO.

        (Parameter renamed from ``item`` — the original shadowed it with the
        loop variable.)
        """
        discrepancies = []
        for item in invoice_data.item_details:
            if item.quantity != po_data.get("quantity"):
                discrepancies.append(
                    f"Quantity mismatch: Expected {po_data['quantity']}, Found {item.quantity}"
                )
            if abs(item.rate - po_data.get("rate", 0)) > po_data.get("rate", 0) * self.tolerance:
                discrepancies.append(
                    f"Rate mismatch: Expected {po_data['rate']}, Found {item.rate}"
                )
        return discrepancies

    def _validate_totals(self, invoice_data, po_data: Dict[str, Any]) -> List[str]:
        """Check the first line item's amount against the PO's expected amount within tolerance."""
        discrepancies = []
        expected = po_data.get("expected_amount", 0)
        # Guard: no line items means nothing to compare (original raised IndexError).
        if not invoice_data.item_details:
            return discrepancies
        actual = invoice_data.item_details[0].amount
        diff = abs(expected - actual)
        if diff > expected * self.tolerance:
            discrepancies.append(
                f"Total amount mismatch: Expected {expected}, Actual {actual} (Difference:{diff:.2f})"
            )
        return discrepancies

    def _calculate_validation_confidence(self, validation_result: ValidationResult,
                                         matching_pos: List[Dict[str, Any]], invoice_data) -> float:
        """
        Compute a weighted confidence score (0-100) across 7 key dimensions:
        invoice_number, order_id, customer_name, item_name, amount, rate, quantity.
        Each field contributes according to its weight; numeric fields use a
        relative-difference similarity, text fields use SequenceMatcher.
        """
        if not validation_result.po_found or not matching_pos:
            return 0.0

        po_data = matching_pos[0]

        # Expected values from the PO row.
        expected = {
            "invoice_number": po_data.get("invoice_number", ""),
            "order_id": po_data.get("order_id", ""),
            "customer_name": po_data.get("customer_name", ""),
            "item_name": po_data.get("item_name", ""),
            "amount": float(po_data.get("expected_amount", po_data.get("amount", 0))),
            "rate": float(po_data.get("rate", 0)),
            "quantity": float(po_data.get("quantity", 0)),
        }

        # Actual values from the invoice (first/dominant line item).
        actual = {
            "invoice_number": invoice_data.invoice_number,
            "order_id": invoice_data.order_id,
            "customer_name": invoice_data.customer_name,
        }
        if invoice_data.item_details:
            item = invoice_data.item_details[0]
            actual.update({
                "item_name": item.item_name,
                "amount": float(item.amount or 0),
                "rate": float(item.rate or 0),
                "quantity": float(item.quantity or 0),
            })

        # Field weights (sum to 1); money fields dominate.
        weights = {
            "invoice_number": 0.20,
            "order_id": 0.15,
            "customer_name": 0.05,
            "item_name": 0.05,
            "amount": 0.25,
            "rate": 0.15,
            "quantity": 0.15,
        }

        def numeric_similarity(expected_val, actual_val):
            # 1.0 when equal, decreasing linearly with the relative difference.
            if expected_val == 0:
                return 1.0 if actual_val == 0 else 0.0
            diff_ratio = abs(expected_val - actual_val) / (abs(expected_val) + 1e-6)
            return max(0.0, 1.0 - diff_ratio)

        def text_similarity(a, b):
            return SequenceMatcher(None, str(a).lower(), str(b).lower()).ratio()

        weighted_scores = []
        for field, weight in weights.items():
            exp_val, act_val = expected.get(field), actual.get(field)
            if isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)):
                score = numeric_similarity(exp_val, act_val)
            else:
                score = text_similarity(exp_val, act_val)
            weighted_scores.append(weight * score)

        # Scale to a percentage and clamp to [0, 100].
        confidence = round(sum(weighted_scores) * 100, 2)
        confidence = max(0.0, min(confidence, 100.0))

        self.logger.logger.debug(f"Validation Confidence (weighted): {confidence}%")
        return confidence

    def _determine_validation_status(self, validation_result: ValidationResult) -> ValidationStatus:
        """
        Determine the final validation status based on PO existence,
        discrepancy count, and the amount match.
        """
        if not validation_result.po_found:
            return ValidationStatus.MISSING_PO

        discrepancies_count = len(validation_result.discrepencies)

        if discrepancies_count == 0 and validation_result.amount_match:
            return ValidationStatus.VALID

        if validation_result.amount_match and discrepancies_count <= 2:
            return ValidationStatus.PARTIAL_MATCH

        return ValidationStatus.INVALID

    def _should_escalate_validation(self, validation_result: ValidationResult, invoice_data) -> bool:
        """Escalate when validation failed outright or no PO was found.

        Bug fix: by the time this runs, validation_status is usually a
        ValidationStatus enum member, so the original plain-string membership
        test would not match unless ValidationStatus subclasses str.
        Normalise to the enum's value before comparing.
        """
        status = validation_result.validation_status
        status_value = getattr(status, "value", status)
        return str(status_value).lower() in ("invalid", "missing_po")

    def _record_health_metrics(self, duration: float):
        """Append a health snapshot (success rate, latency, status) to the shared history."""
        success_rate = (
            (self.successful_executions / self.total_executions) * 100
            if self.total_executions > 0 else 0
        )
        avg_duration = (
            self.total_duration / self.total_executions
            if self.total_executions > 0 else 0
        )

        metrics = {
            "Agent": "Validation Agent ✅",
            "Executions": self.total_executions,
            "Success Rate (%)": round(success_rate, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": self.failed_executions,
        }

        # Prefer the BaseAgent-maintained counters for the health verdict when present.
        executions = 0
        success_rate = 0.0
        failures = 0
        last_run = None
        if self.metrics:
            executions = self.metrics["processed"]
            failures = self.metrics["errors"]
            last_run = self.metrics["last_run_at"]
            success_rate = (executions - failures) / (executions + 1e-6)

        # The instance timestamp always wins (mirrors the original behaviour,
        # where the conditional guard was commented out).
        last_run = self.last_run

        # Health verdict: degraded after >3 failures, unhealthy below 50% success.
        overall_status = "🟢 Healthy"
        if failures > 3:
            overall_status = "🟠 Degraded"
        if executions > 0 and success_rate < 0.5:
            overall_status = "🔴 Unhealthy"

        metrics.update({
            "Last Run": str(last_run) if last_run else "Not applicable",
            "Overall Health": overall_status,
        })
        ValidationAgent.health_history.append(metrics)

    async def health_check(self) -> Dict[str, Any]:
        """
        Return the latest health-metrics snapshot for UI display
        (zeroed defaults when the agent has never run).
        """
        await asyncio.sleep(0.05)
        if not ValidationAgent.health_history:
            return {
                "Agent": "Validation Agent ✅",
                "Executions": 0,
                "Success Rate (%)": 0.0,
                "Avg Duration (ms)": 0.0,
                "Total Failures": 0,
            }
        return ValidationAgent.health_history[-1]
Project/bounding_box.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import fitz # PyMuPDF
3
+ import pandas as pd
4
+ import os
5
+ import re
6
+
7
+ # === File paths ===
8
+ DATA_DIR = os.path.join(os.getcwd(), "data")
9
+ PDF_PATH = os.path.join(DATA_DIR, "invoices/Invoice-26.pdf") # Update for new PDF if needed
10
+ CSV_PATH = os.path.join(DATA_DIR, "purchase_orders.csv")
11
+ OUTPUT_PATH = os.path.join(DATA_DIR, "annotated_invoice.pdf")
12
+
13
+ # === Field coordinate map (from your data) ===
14
+ FIELD_BOXES = {
15
+ "invoice_number": (525, 55, 575, 75),
16
+ "order_id": (45, 470, 230, 490),
17
+ "customer_name": (40, 135, 100, 155),
18
+ "quantity": (370, 235, 385, 250),
19
+ "rate": (450, 235, 500, 250),
20
+ "expected_amount": (520, 360, 570, 375),
21
+ }
22
+
23
+ # === Step 1: Open PDF and extract text ===
24
+ pdf = fitz.open(PDF_PATH)
25
+ page = pdf[0]
26
+ pdf_text = page.get_text()
27
+
28
+ # === Step 2: Helper to extract fields ===
29
+ def extract_field(pattern, text, group=1):
30
+ match = re.search(pattern, text, re.IGNORECASE)
31
+ return match.group(group).strip() if match else None
32
+
33
+ # Extract key identifiers
34
+ invoice_number_pdf = extract_field(r"#\s*(\d+)", pdf_text)
35
+ order_id_pdf = extract_field(r"Order ID\s*[:\-]?\s*(\S+)", pdf_text)
36
+ customer_name_pdf = extract_field(r"Bill To:\s*(.*)", pdf_text)
37
+
38
+ # === Step 3: Read CSV and match correct row ===
39
+ po_df = pd.read_csv(CSV_PATH)
40
+
41
+ matched_row = po_df[
42
+ (po_df['invoice_number'].astype(str) == str(invoice_number_pdf))
43
+ | (po_df['order_id'] == order_id_pdf)
44
+ ]
45
+
46
+ if matched_row.empty:
47
+ raise ValueError(f"No matching CSV row found for Invoice {invoice_number_pdf} / Order {order_id_pdf}")
48
+
49
+ expected = matched_row.iloc[0].to_dict()
50
+ expected = {k.lower(): str(v).strip() for k, v in expected.items()}
51
+
52
+ print("✅ Loaded expected data from CSV for this PDF:")
53
+ for k, v in expected.items():
54
+ print(f" {k}: {v}")
55
+
56
+ # === Step 4: Extract fields from PDF ===
57
+ invoice_data = {
58
+ "invoice_number": invoice_number_pdf,
59
+ "customer_name": customer_name_pdf,
60
+ "order_id": order_id_pdf,
61
+ }
62
+
63
+ # Numeric fields
64
+ amounts = re.findall(r"\$?([\d,]+\.\d{2})", pdf_text)
65
+ invoice_data["expected_amount"] = amounts[-1] if amounts else None
66
+
67
+ # Extract first item (quantity, rate)
68
+ item_lines = re.findall(
69
+ r"([A-Za-z0-9 ,\-]+)\s+(\d+)\s+\$?([\d,]+\.\d{2})\s+\$?([\d,]+\.\d{2})",
70
+ pdf_text,
71
+ )
72
+ if item_lines:
73
+ invoice_data["quantity"] = item_lines[0][1]
74
+ invoice_data["rate"] = item_lines[0][2]
75
+
76
+ print("\n✅ Extracted data from PDF:")
77
+ for k, v in invoice_data.items():
78
+ print(f" {k}: {v}")
79
+
80
+ # === Step 5: Compare PDF vs CSV ===
81
+ discrepancies = []
82
+
83
+ def add_discrepancy(field, expected_val, found_val):
84
+ discrepancies.append({"field": field, "expected": expected_val, "found": found_val})
85
+
86
+ # Compare string fields
87
+ for field in ["invoice_number", "order_id", "customer_name"]:
88
+ if str(invoice_data.get(field, "")).strip() != str(expected.get(field, "")).strip():
89
+ add_discrepancy(field, expected.get(field, ""), invoice_data.get(field, ""))
90
+
91
+ # Compare numeric fields
92
+ for field in ["quantity", "rate", "expected_amount"]:
93
+ try:
94
+ found_val = float(str(invoice_data.get(field, 0)).replace(",", "").replace("$", ""))
95
+ expected_val = float(str(expected.get(field, 0)).replace(",", "").replace("$", ""))
96
+ if round(found_val, 2) != round(expected_val, 2):
97
+ add_discrepancy(field, expected_val, found_val)
98
+ except:
99
+ if str(invoice_data.get(field, "")) != str(expected.get(field, "")):
100
+ add_discrepancy(field, expected.get(field, ""), invoice_data.get(field, ""))
101
+
102
+ # === Step 6: Annotate mismatched fields using fixed coordinates ===
103
+ for d in discrepancies:
104
+ field = d["field"]
105
+ if field not in FIELD_BOXES:
106
+ print(f"⚠️ No coordinates found for field '{field}' — skipping annotation.")
107
+ continue
108
+
109
+ rect_coords = FIELD_BOXES[field]
110
+ rect = fitz.Rect(rect_coords)
111
+ expected_text = (
112
+ f"{float(d['expected']):,.2f}"
113
+ if field in ["quantity", "rate", "expected_amount"]
114
+ else str(d["expected"])
115
+ )
116
+
117
+ # Draw red bounding box
118
+ page.draw_rect(rect, color=(1, 0, 0), width=1.5)
119
+
120
+ # Add expected value below box
121
+ page.insert_text(
122
+ (rect.x0, rect.y1 + 10),
123
+ expected_text,
124
+ fontsize=9,
125
+ color=(1, 0, 0),
126
+ )
127
+
128
+ pdf.save(OUTPUT_PATH)
129
+ pdf.close()
130
+
131
+ print("\n✅ Annotated invoice saved at:", OUTPUT_PATH)
132
+
133
+ if discrepancies:
134
+ print("\n⚠️ Mismatches found:")
135
+ for d in discrepancies:
136
+ print(f" - {d['field']}: expected {d['expected']}, found {d['found']}")
137
+ else:
138
+ print("\n✅ No mismatches found! Invoice matches CSV.")
Project/data/annotated_invoice.pdf ADDED
Binary file (15.8 kB). View file
 
Project/data/invoices/Invoice-01.pdf ADDED
Binary file (13.5 kB). View file
 
Project/data/invoices/Invoice-02.pdf ADDED
Binary file (13.3 kB). View file
 
Project/data/invoices/Invoice-03.pdf ADDED
Binary file (14.5 kB). View file
 
Project/data/invoices/Invoice-04.pdf ADDED
Binary file (14.3 kB). View file
 
Project/data/invoices/Invoice-05.pdf ADDED
Binary file (15.6 kB). View file
 
Project/data/invoices/Invoice-06.pdf ADDED
Binary file (14.1 kB). View file
 
Project/data/invoices/Invoice-07.pdf ADDED
Binary file (13.3 kB). View file
 
Project/data/invoices/Invoice-08.pdf ADDED
Binary file (14.4 kB). View file
 
Project/data/invoices/Invoice-09.pdf ADDED
Binary file (14.7 kB). View file
 
Project/data/invoices/Invoice-10.pdf ADDED
Binary file (14.2 kB). View file
 
Project/data/invoices/Invoice-11.pdf ADDED
Binary file (14.2 kB). View file
 
Project/data/invoices/Invoice-12.pdf ADDED
Binary file (13.7 kB). View file
 
Project/data/invoices/Invoice-13.pdf ADDED
Binary file (13.6 kB). View file
 
Project/data/invoices/Invoice-14.pdf ADDED
Binary file (13.3 kB). View file
 
Project/data/invoices/Invoice-15.pdf ADDED
Binary file (13.6 kB). View file
 
Project/data/invoices/Invoice-16.pdf ADDED
Binary file (14.2 kB). View file
 
Project/data/invoices/Invoice-17.pdf ADDED
Binary file (14 kB). View file
 
Project/data/invoices/Invoice-18.pdf ADDED
Binary file (13.6 kB). View file
 
Project/data/invoices/Invoice-19.pdf ADDED
Binary file (13.7 kB). View file
 
Project/data/invoices/Invoice-20.pdf ADDED
Binary file (13.9 kB). View file
 
Project/data/invoices/Invoice-21.pdf ADDED
Binary file (13.5 kB). View file
 
Project/data/invoices/Invoice-22.pdf ADDED
Binary file (14 kB). View file
 
Project/data/invoices/Invoice-23.pdf ADDED
Binary file (13.8 kB). View file
 
Project/data/invoices/Invoice-24.pdf ADDED
Binary file (14.2 kB). View file
 
Project/data/invoices/Invoice-25.pdf ADDED
Binary file (13.4 kB). View file
 
Project/data/invoices/Invoice-26.pdf ADDED
Binary file (13.7 kB). View file
 
Project/data/invoices/Invoice-27.pdf ADDED
Binary file (13.5 kB). View file
 
Project/data/invoices/Invoice-28.pdf ADDED
Binary file (14.2 kB). View file
 
Project/data/invoices/Invoice-29.pdf ADDED
Binary file (13.8 kB). View file
 
Project/data/invoices/Invoice-30.pdf ADDED
Binary file (13.9 kB). View file
 
Project/data/invoices/Invoice-31.pdf ADDED
Binary file (14 kB). View file
 
Project/data/invoices/Invoice-32.pdf ADDED
Binary file (13.7 kB). View file
 
Project/data/invoices/Invoice-33.pdf ADDED
Binary file (14 kB). View file
 
Project/data/invoices/Invoice-34.pdf ADDED
Binary file (13.9 kB). View file
 
Project/data/invoices/Invoice-35.pdf ADDED
Binary file (14 kB). View file
 
Project/data/invoices/Invoice-36.pdf ADDED
Binary file (13.6 kB). View file