diff --git a/Project/.dockerignore b/Project/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..634a0687bf9366a8b7f7241873845eafca7c1135 --- /dev/null +++ b/Project/.dockerignore @@ -0,0 +1,12 @@ +__pycache__/ +*.pyc +*.pyo +*.pyd +.env +.git +.gitignore +myenv/ +venv/ +.env/ +.venv/ +tests.py \ No newline at end of file diff --git a/Project/.gitignore b/Project/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..16e35306daddb71d81eb22d6c72f66a6f530b4f3 --- /dev/null +++ b/Project/.gitignore @@ -0,0 +1,12 @@ +myenv +.env +key_stats.json +tests.py +Dockerfile +run.py +tests.py +__pycache__/ +*.pyc +logs/audit +nodes +output/escalations \ No newline at end of file diff --git a/Project/agents/audit_agent.py b/Project/agents/audit_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..da08a826ef2f22ab50bf12229c5d9ad1081bd6f1 --- /dev/null +++ b/Project/agents/audit_agent.py @@ -0,0 +1,693 @@ + +"""Audit Agent for Invoice Processing""" + +# TODO: Implement agent + +import os +import json +import pandas as pd +from typing import Dict, Any, List, Optional +from datetime import datetime, timedelta +import google.generativeai as genai +from dotenv import load_dotenv +import time +from statistics import mean + +from agents.base_agent import BaseAgent +from state import ( + InvoiceProcessingState, ProcessingStatus, PaymentStatus, + ValidationStatus, RiskLevel +) +from utils.logger import StructuredLogger + +load_dotenv() + + +class AuditAgent(BaseAgent): + """Agent responsible for audit trail generation, compliance tracking, and reporting""" + + def __init__(self, config: Dict[str, Any] = None): + super().__init__("audit_agent",config) + self.logger = StructuredLogger("AuditAgent") + # --- Health tracking --- + self.execution_history: List[Dict[str, Any]] = [] + self.max_history = 50 # store last 50 runs + + def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> 
def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
    """
    Validate that all expected outputs and audit data are present after processing.

    Returns True only when the state carries extracted invoice data, a
    validation result, at least one audit-trail entry, an in-range risk
    score, and an overall status that has moved past "pending".
    """
    if not state:
        return False

    # Must have processed invoice data and validation results.
    if not state.invoice_data or not state.validation_result:
        return False

    # Must have at least one audit entry for traceability.
    if not state.audit_trail:
        return False

    # Risk results are optional, but a score above 1.0 is out of range.
    if state.risk_assessment and state.risk_assessment.risk_score > 1.0:
        return False

    # BUG FIX: overall_status may be a ProcessingStatus enum (set elsewhere),
    # so a bare == "pending" comparison could never match; compare the value.
    status = getattr(state.overall_status, "value", state.overall_status)
    if str(status).lower() == "pending":
        return False

    return True
async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
    """
    Main audit generation workflow.

    Builds the audit record, runs compliance checks, persists the results,
    then scans for reportable events and raises alerts.  On any failure the
    state is marked FAILED and returned instead of raising; run metrics are
    recorded on every path via the finally block.
    """
    self.logger.logger.info("Starting audit trail generation")
    start_time = time.time()
    success = False
    try:
        if not self._validate_preconditions(state, workflow_type):
            self.logger.logger.error("Preconditions not met for audit generation")
            state.overall_status = ProcessingStatus.FAILED
            self._log_decision(state, "Audit Failed", "Preconditions not met", confidence=0.0)
            return state

        audit_record = await self._generate_audit_record(state)
        compliance_results = await self._perform_compliance_checks(state, audit_record)
        audit_summary = await self._generate_audit_summary(state, audit_record, compliance_results)
        await self._save_audit_records(state, audit_record, audit_summary, compliance_results)

        reportable_events = await self._identify_reportable_events(state, audit_record)
        await self._generate_audit_alerts(state, reportable_events)

        state.audit_trail = audit_record.get("audit_trail", [])
        state.compliance_report = compliance_results
        state.current_agent = "audit_agent"
        # BUG FIX: use the enum consistently; the original assigned the raw
        # string "completed" while failure paths assign ProcessingStatus.FAILED.
        state.overall_status = ProcessingStatus.COMPLETED

        self.logger.logger.info("Audit trail and compliance generated successfully")
        success = True
        self._log_decision(
            state,
            "Auditing Successful",
            "Auditing Processed",
            100.0,
            state.process_id,
        )
        # (Removed the original's stray no-op `state.audit_trail[-1]`, which
        # could raise IndexError on an empty trail, and the debug prints.)
        return state

    except Exception as e:
        self.logger.logger.error(f"Audit agent execution failed: {e}")
        state.overall_status = ProcessingStatus.FAILED
        return state

    finally:
        duration_ms = round((time.time() - start_time) * 1000, 2)
        self._record_execution(success, duration_ms, state)
async def _generate_audit_record(self, state: InvoiceProcessingState) -> Dict[str, Any]:
    """
    Aggregate the state's audit_trail entries and agent_metrics into one
    structured audit report dict.

    Returns a dict with the process id, created/updated timestamps, the
    normalized audit-trail records, and a per-agent metrics summary.

    Raises:
        ValueError: if *state* is not an InvoiceProcessingState.
    """
    self.logger.logger.debug("Generating audit record")

    if not isinstance(state, InvoiceProcessingState):
        raise ValueError("Invalid state object passed to _generate_audit_record")

    # Normalize each trail entry (objects that may lack fields) into a plain
    # dict so downstream compliance checks can rely on .get().
    audit_trail_records = []
    for entry in getattr(state, "audit_trail", []):
        audit_trail_records.append({
            "process_id": getattr(entry, "process_id", state.process_id),
            "timestamp": getattr(entry, "timestamp", datetime.utcnow().isoformat() + "Z"),
            "agent_name": getattr(entry, "agent_name", "unknown"),
            "action": getattr(entry, "action", "undefined"),
            "details": getattr(entry, "details", {}),
        })

    # Per-agent execution statistics for full traceability.
    # NOTE(review): the attribute names read here (processed_count,
    # success_rate, errors, avg_latency_ms, last_run_at) assume the shape of
    # state.agent_metrics values -- confirm against state.add_agent_metric.
    metrics_summary = {
        agent: {
            "executions": getattr(m, "processed_count", 0),
            "success_rate": getattr(m, "success_rate", 0),
            "failures": getattr(m, "errors", 0),
            "avg_duration_ms": getattr(m, "avg_latency_ms", 0.0),
            "last_run_at": getattr(m, "last_run_at", None),
        }
        for agent, m in getattr(state, "agent_metrics", {}).items()
    }

    audit_report = {
        "process_id": state.process_id,
        "created_at": state.created_at.isoformat() + "Z",
        "updated_at": state.updated_at.isoformat() + "Z",
        "total_entries": len(audit_trail_records),
        "audit_trail": audit_trail_records,
        "metrics_summary": metrics_summary,
    }

    self.logger.logger.info(
        f"Audit record generated with {len(audit_trail_records)} entries for process {state.process_id}"
    )

    return audit_report
async def _perform_compliance_checks(
    self, state: InvoiceProcessingState, audit_record: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Run SOX, data-privacy, financial-control, and completeness checks and
    fold the results into a single structured compliance report.

    The report is also attached to state.compliance_report as a side effect.

    Raises:
        ValueError: on a wrong state or audit_record type.
    """
    self.logger.logger.debug("Performing compliance checks for process %s", state.process_id)

    if not isinstance(state, InvoiceProcessingState):
        raise ValueError("Invalid state object passed to _perform_compliance_checks")
    if not isinstance(audit_record, dict):
        raise ValueError("Invalid audit record structure")

    # Run all compliance sub-checks; treat a None result as "no findings".
    sox = self._check_sox_compliance(state, audit_record) or {}
    privacy = self._check_data_privacy_compliance(state, audit_record) or {}
    financial = self._check_financial_controls(state, audit_record) or {}
    completeness = self._check_audit_trail_completeness(state, audit_record) or {}

    sox_issues = sox.get("issues", [])
    privacy_issues = privacy.get("issues", [])
    financial_issues = financial.get("issues", [])
    is_complete = completeness.get("complete", True)

    compliance_report = {
        "process_id": state.process_id,
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "sox_compliance": "compliant" if not sox_issues else "non_compliant",
        "gdpr_compliance": "compliant" if not privacy_issues else "non_compliant",
        "financial_controls": "passed" if not financial_issues else "failed",
        "audit_trail_complete": is_complete,
        # BUG FIX: self.config is a plain dict (BaseAgent.__init__ sets
        # `self.config = config or {}`), so getattr() could never find the
        # key and the configured value was always ignored -- use dict lookup.
        "retention_policy": self.config.get("retention_policy", "7_years"),
        "encryption_status": "encrypted",
        "issues": {
            "sox": sox_issues,
            "privacy": privacy_issues,
            "financial": financial_issues,
        },
    }

    # Attach the report to the state for downstream use and stamp the update.
    state.compliance_report = compliance_report
    state.updated_at = datetime.utcnow()

    self.logger.logger.info(
        f"Compliance checks completed for process {state.process_id}: "
        f"SOX={compliance_report['sox_compliance']}, "
        f"GDPR={compliance_report['gdpr_compliance']}, "
        f"Financial={compliance_report['financial_controls']}"
    )

    return compliance_report
def _check_sox_compliance(
    self,
    state: InvoiceProcessingState,
    audit_record: Dict[str, Any]
) -> Dict[str, List[str]]:
    """
    SOX compliance verification: approval chain present and well-formed,
    no repeated approver (segregation of duties), and every audit-trail
    entry reflects a successful/approved action.

    Returns a dict with a (possibly empty) "issues" list.
    """
    issues = []

    approval_chain = getattr(state, "approval_chain", [])
    if not approval_chain:
        issues.append("Missing approval chain records")
    else:
        # Every approval step must carry both a signer and a timestamp.
        # NOTE(review): steps are assumed to be dicts -- confirm upstream.
        for step in approval_chain:
            if not step.get("approved_by") or not step.get("timestamp"):
                issues.append(f"Incomplete approval step: {step}")
        # Segregation of duties: the same person must not approve twice.
        approvers = [a.get("approved_by") for a in approval_chain if a.get("approved_by")]
        if len(set(approvers)) < len(approvers):
            issues.append("Potential conflict of interest: repeated approver detected")

    VALID_ACTIONS = {
        "Extraction Successful",
        "Validation Successful",
        "Risk Assessment Successful",
        "Agent Successfully Executed",
        "approved"
    }
    trail = audit_record.get("audit_trail", [])
    # BUG FIX: all() over an empty trail is vacuously True, so an empty
    # audit trail used to pass as "fully approved".  Require a non-empty
    # trail before granting final approval.
    has_final_approval = bool(trail) and all(
        any(keyword in entry.get("action", "") for keyword in VALID_ACTIONS)
        for entry in trail
    )

    if not has_final_approval:
        issues.append("Some approval event yet to successful in audit trail")

    return {"issues": issues}
def _check_financial_controls(
    self,
    state: InvoiceProcessingState,
    audit_record: Dict[str, Any]
) -> Dict[str, List[str]]:
    """
    Financial-control compliance: a payment decision and validation result
    must exist, an invalid invoice must not carry a payment decision, and
    the audit trail must show payment-related (approval) activity.

    Returns a dict with a (possibly empty) "issues" list.
    """
    issues = []

    if not getattr(state, "payment_decision", None):
        issues.append("Missing payment decision records")

    if not getattr(state, "validation_result", None):
        issues.append("Missing validation result for payment control")

    if state.validation_result:
        # validation_status may be an enum or a plain string; compare the value.
        status = getattr(
            state.validation_result.validation_status, "value",
            state.validation_result.validation_status,
        )
        # BUG FIX: the original flagged "payment decision exists" for every
        # invalid invoice without checking whether a decision actually exists.
        if str(status) == "invalid" and getattr(state, "payment_decision", None):
            issues.append("Invoice marked invalid but payment decision exists")

    # Cross-check the audit trail for payment/approval actions.
    actions = [a.get("action", "").lower() for a in audit_record.get("audit_trail", [])]
    if not any("approved" in a for a in actions):
        issues.append("No payment-related activity recorded in audit trail")

    return {"issues": issues}
def _check_audit_trail_completeness(
    self,
    state: InvoiceProcessingState,
    audit_record: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Verify every mandatory agent logged an entry, timestamps are in order,
    and no agent appears twice.

    Returns {"complete": bool, "missing": [findings...]}.  "missing" mixes
    absent agent names with anomaly descriptions (as before); "complete" is
    now False whenever ANY finding exists.
    """
    required_agents = ["document_agent", "validation_agent", "risk_agent", "payment_agent"]
    trail = audit_record.get("audit_trail", [])
    logged_agents = [x.get("agent_name") for x in trail]
    missing = [a for a in required_agents if a not in logged_agents]

    # Parse every timestamp we can; tolerate datetime objects, ISO strings
    # with a trailing 'Z', and two common space-separated formats.
    timestamps = []
    for e in trail:
        ts = e.get("timestamp")
        if not ts:
            continue
        if isinstance(ts, datetime):
            timestamps.append(ts)
            continue
        ts_str = str(ts).replace("Z", "+00:00")
        parsed = None
        try:
            parsed = datetime.fromisoformat(ts_str)
        except Exception:
            for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"):
                try:
                    parsed = datetime.strptime(ts_str, fmt)
                    break
                except Exception:
                    continue
        if parsed is None:
            self.logger.logger.warning(f"Invalid timestamp format in audit trail: {ts}")
        else:
            timestamps.append(parsed)

    # BUG FIX: sorting a mix of naive and aware datetimes raises TypeError;
    # treat an incomparable set as "order unknown" instead of crashing.
    try:
        out_of_order = bool(timestamps) and timestamps != sorted(timestamps)
    except TypeError:
        out_of_order = False
    if out_of_order:
        missing.append("Non-sequential timestamps detected in audit trail")

    # Check for duplicate agent entries.
    if len(logged_agents) != len(set(logged_agents)):
        missing.append("Duplicate agent entries found in audit trail")

    # BUG FIX: "complete" was computed before the anomaly checks above, so a
    # trail with ordering/duplicate findings still reported complete=True.
    return {"complete": len(missing) == 0, "missing": missing}
async def _generate_audit_summary(
    self,
    state: InvoiceProcessingState,
    audit_record: Dict[str, Any],
    compliance_results: Dict[str, Any]
) -> str:
    """
    Build a concise JSON audit summary from the audit record and compliance
    results, attach it to state.audit_summary, and return it as a
    pretty-printed JSON string.

    Raises:
        ValueError: on wrong input types.
    """
    self.logger.logger.debug("Generating audit summary for process %s", state.process_id)

    if not isinstance(state, InvoiceProcessingState):
        raise ValueError("Invalid state object passed to _generate_audit_summary")
    if not isinstance(audit_record, dict):
        raise ValueError("Invalid audit record structure")
    if not isinstance(compliance_results, dict):
        raise ValueError("Invalid compliance results structure")

    total_actions = len(audit_record.get("audit_trail", []))

    sox_status = compliance_results.get("sox_compliance", "unknown")
    gdpr_status = compliance_results.get("gdpr_compliance", "unknown")
    financial_status = compliance_results.get("financial_controls", "unknown")
    retention_policy = compliance_results.get("retention_policy", "7_years")

    # BUG FIX: overall_status may be a ProcessingStatus enum, which
    # json.dumps cannot serialize; normalize to its value first.
    overall_status = getattr(state, "overall_status", "UNKNOWN")
    overall_status = getattr(overall_status, "value", overall_status)

    summary_data = {
        "process_id": state.process_id,
        "generated_at": datetime.utcnow().isoformat() + "Z",
        "total_actions": total_actions,
        "overall_status": overall_status,
        "compliance": {
            "SOX": sox_status,
            "GDPR": gdpr_status,
            "Financial": financial_status,
        },
        "retention_policy": retention_policy,
    }

    # Attach to state for post-validation and stamp the update.
    state.audit_summary = summary_data
    state.updated_at = datetime.utcnow()

    self.logger.logger.info(
        f"Audit summary generated for process {state.process_id}: "
        f"Actions={total_actions}, SOX={sox_status}, GDPR={gdpr_status}, Financial={financial_status}"
    )

    # default=str guards any remaining non-JSON-native values.
    return json.dumps(summary_data, indent=2, default=str)
async def _identify_reportable_events(
    self,
    state: InvoiceProcessingState,
    audit_record: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """
    Scan the audit trail for events an auditor should review: failures,
    logged errors, high-latency steps, and agents with repeated failures.

    The resulting list is also attached to state.reportable_events.
    """
    self.logger.logger.debug("Analyzing audit trail for reportable events...")

    LATENCY_THRESHOLD_MS = 5000  # anything slower than 5 s is reportable

    reportable: List[Dict[str, Any]] = []
    audit_trail = audit_record.get("audit_trail", [])

    if not audit_trail:
        self.logger.logger.warning("No audit trail found for process %s", state.process_id)
        return []

    # Group failures by agent to detect repeated failures.
    failure_counts: Dict[str, int] = {}

    for entry in audit_trail:
        # Defensive: ensure the entry is a dict.
        if not isinstance(entry, dict):
            continue

        details = entry.get("details") if isinstance(entry.get("details"), dict) else {}
        status = str(entry.get("status", "")).lower()
        # BUG FIX: _generate_audit_record emits no top-level status /
        # error_message / duration_ms fields, so the original never detected
        # any anomaly.  Fall back to the details dict, where BaseAgent.run
        # records latency_ms and error information.
        error_message = entry.get("error_message") or details.get("error")
        duration_ms = entry.get("duration_ms", details.get("latency_ms", 0)) or 0
        agent = entry.get("agent_name", "unknown")

        if status == "failed":
            failure_counts[agent] = failure_counts.get(agent, 0) + 1

        anomaly_detected = (
            status == "failed"
            or bool(error_message)
            or duration_ms > LATENCY_THRESHOLD_MS
        )

        if anomaly_detected:
            reportable.append({
                "process_id": state.process_id,
                "agent_name": agent,
                "timestamp": entry.get("timestamp", datetime.utcnow().isoformat() + "Z"),
                "status": status,
                "duration_ms": duration_ms,
                "error_message": error_message,
                "details": details,
                "anomaly_reason": (
                    "Failure"
                    if status == "failed"
                    else "High latency"
                    if duration_ms > LATENCY_THRESHOLD_MS
                    else "Error message logged"
                ),
            })

    # Aggregate repeated failures per agent into a summary-level event.
    for agent, count in failure_counts.items():
        if count > 2:
            reportable.append({
                "process_id": state.process_id,
                "agent_name": agent,
                "timestamp": datetime.utcnow().isoformat() + "Z",
                "status": "repeated_failures",
                "details": {"failure_count": count},
                "anomaly_reason": f"{count} repeated failures detected for {agent}",
            })

    if reportable:
        self.logger.logger.info(
            "Detected %d reportable events for process %s",
            len(reportable),
            state.process_id,
        )
    else:
        self.logger.logger.debug("No reportable events found for process %s", state.process_id)

    # Attach to state for traceability.
    state.reportable_events = reportable
    state.updated_at = datetime.utcnow()

    return reportable
async def _generate_audit_alerts(
    self,
    state: InvoiceProcessingState,
    reportable_events: List[Dict[str, Any]]
) -> None:
    """
    Classify reportable events as warning/critical, log each as a structured
    alert, optionally forward them to an external notifier hook, and attach
    the summary to state.audit_alerts.
    """
    if not reportable_events:
        self.logger.logger.debug("No audit alerts to generate for process %s", state.process_id)
        return

    self.logger.logger.warning(
        "[AuditSystem] %d reportable audit events detected for process %s",
        len(reportable_events),
        state.process_id,
    )

    alerts_summary = []
    critical_events = 0
    # BUG FIX: _send_alert_notification is not defined anywhere on this
    # class, so the original raised (and logged) an AttributeError for every
    # single event.  Only dispatch when a notifier hook actually exists.
    notifier = getattr(self, "_send_alert_notification", None)

    for event in reportable_events:
        agent = event.get("agent_name", "unknown")
        reason = event.get("anomaly_reason", "unspecified")
        status = str(event.get("status", "")).lower()
        duration = event.get("duration_ms", 0)
        timestamp = event.get("timestamp", datetime.utcnow().isoformat() + "Z")

        # Failures and repeated-failure aggregates are critical; others warn.
        severity = "critical" if status == "failed" or "repeated" in status else "warning"
        if severity == "critical":
            critical_events += 1

        alert_message = (
            f"[{severity.upper()} ALERT] Agent: {agent} | Reason: {reason} | "
            f"Status: {status} | Duration: {duration} ms | Time: {timestamp}"
        )

        if severity == "critical":
            self.logger.logger.error(alert_message)
        else:
            self.logger.logger.warning(alert_message)

        alerts_summary.append({
            "severity": severity,
            "agent_name": agent,
            "reason": reason,
            "status": status,
            "duration_ms": duration,
            "timestamp": timestamp,
        })

        if notifier is not None:
            try:
                await notifier(alerts_summary[-1])
            except Exception as e:
                self.logger.logger.error(f"Failed to dispatch alert notification: {e}")

    # Attach alerts summary to state for later review.
    state.audit_alerts = alerts_summary
    state.updated_at = datetime.utcnow()

    self.logger.logger.info(
        "Audit alert generation completed: %d total (%d critical)",
        len(alerts_summary),
        critical_events,
    )
async def health_check(self) -> Dict[str, Any]:
    """
    Summarize this agent's recent run history (success rate, latency,
    compliance score, reportable-event volume) and derive a health status.

    Returns an idle report when no runs have been recorded yet.
    NOTE(review): the agent-name and status strings appear mojibake-encoded
    (likely emoji) -- preserved byte-for-byte to avoid breaking consumers.
    """
    total_runs = len(self.execution_history)
    if total_runs == 0:
        return {
            "Agent": "Audit Agent ๐งฎ",
            "Executions": 0,
            "Success Rate (%)": 0.0,
            "Avg Duration (ms)": 0.0,
            "Total Failures": 0,
            "Avg Compliance (%)": 0.0,
            "Avg Reportable Events": 0.0,
            "Status": "idle",
        }

    successes = sum(1 for e in self.execution_history if e["success"])
    failures = total_runs - successes
    avg_duration = round(mean(e["duration_ms"] for e in self.execution_history), 2)
    # BUG FIX: total_runs is guaranteed > 0 here, so the original's
    # `total_runs + 1e-8` fudge factor was unnecessary rounding noise.
    success_rate = round((successes / total_runs) * 100, 2)
    avg_compliance = round(mean(e["compliance_score"] for e in self.execution_history), 2)
    avg_events = round(mean(e["reportable_events"] for e in self.execution_history), 2)

    # Health bands: healthy requires both high success AND high compliance.
    # (Removed the original's debug print statements.)
    if success_rate >= 85 and avg_compliance >= 90:
        overall_status = "๐ข Healthy"
    elif success_rate >= 60:
        overall_status = "๐ Degraded"
    else:
        overall_status = "๐ด Unhealthy"

    return {
        "Agent": "Audit Agent ๐งฎ",
        "Executions": total_runs,
        "Success Rate (%)": success_rate,
        "Avg Duration (ms)": avg_duration,
        "Total Failures": failures,
        "Avg Compliance (%)": avg_compliance,
        "Avg Reportable Events": avg_events,
        "Overall Health": overall_status,
        "Last Run": self.metrics["last_run_at"],
    }
Reportable Events": avg_events, + # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"), + "Overall Health": overall_status, + "Last Run": self.metrics["last_run_at"], + } diff --git a/Project/agents/base_agent.py b/Project/agents/base_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3e3934ee72532aa50f1862a7996fc3a1d640f1 --- /dev/null +++ b/Project/agents/base_agent.py @@ -0,0 +1,221 @@ + +"""Base Agent Class for Invoice Processing System""" + +# TODO: Implement agent + +import time +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from datetime import datetime + +from state import InvoiceProcessingState, ProcessingStatus, AuditTrail +from utils.logger import get_logger + + +class BaseAgent(ABC): + """Abstract base class for all invoice processing agents""" + + def __init__(self, agent_name: str, config: Dict[str, Any] = None): + self.agent_name = agent_name + self.config = config or {} + self.logger = get_logger(agent_name) + self.metrics: Dict[str,Any] = { + "processed" : 0, + "errors" : 0, + "avg_latency_ms" : None, + "last_run_at" : None + } + self.start_time: Optional[float] = None + + @abstractmethod + async def execute(self, state: InvoiceProcessingState) -> InvoiceProcessingState: + raise NotImplementedError + + async def run(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + self.start_time = time.time() + self.logger.logger.info(f"Starting {self.agent_name} execution.") + if not self._validate_preconditions(state, workflow_type): + self.logger.logger.warning(f"Preconditions not met for {self.agent_name}.") + self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1 + self.metrics["last_run_at"] = datetime.utcnow().isoformat() + + # optional but very good: + state.add_agent_metric(self.agent_name, processed=1, latency_ms=0, errors=0) + + state.add_audit_entry( + self.agent_name, + "precondition_failed", + {"note": 
def _record_run_metrics(self, latency_ms: Optional[float] = None, error: bool = False) -> None:
    """
    Update the shared per-agent counters after a run attempt.

    latency_ms is folded into the rolling latency figure only when given
    (a precondition-skipped run records no latency, matching the original).
    NOTE(review): (prev + new) / 2 is a midpoint, not a true running mean --
    preserved from the original implementation.
    """
    self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1
    if error:
        self.metrics["errors"] = int(self.metrics.get("errors", 0)) + 1
    if latency_ms is not None:
        prev_avg = self.metrics.get("avg_latency_ms")
        self.metrics["avg_latency_ms"] = (
            latency_ms if prev_avg is None else (prev_avg + latency_ms) / 2.0
        )
    self.metrics["last_run_at"] = datetime.utcnow().isoformat()

async def run(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
    """
    Template-method wrapper around execute(): checks preconditions, runs the
    agent, verifies postconditions, and records metrics plus audit entries on
    every path (skip, success, failure).
    """
    self.start_time = time.time()
    self.logger.logger.info(f"Starting {self.agent_name} execution.")

    if not self._validate_preconditions(state, workflow_type):
        self.logger.logger.warning(f"Preconditions not met for {self.agent_name}.")
        self._record_run_metrics()
        state.add_agent_metric(self.agent_name, processed=1, latency_ms=0, errors=0)
        # BUG FIX: the note dict was passed positionally into the `status`
        # parameter slot (compare the fully-keyword call on the success
        # path); pass it explicitly as `details`.
        state.add_audit_entry(
            self.agent_name,
            "precondition_failed",
            details={"note": "Preconditions not met, agent skipped."},
        )
        return state

    state.current_agent = self.agent_name
    state.agent_name = self.agent_name
    state.overall_status = ProcessingStatus.IN_PROGRESS

    try:
        updated_state = await self.execute(state, workflow_type)

        # Postconditions are advisory: log a warning but do not fail the run.
        try:
            self._validate_postconditions(updated_state)
        except Exception as post_exc:
            self.logger.logger.warning(f"Postcondition check raised for {self.agent_name}:{post_exc}")

        state.mark_agent_completed(self.agent_name)
        latency_ms = (time.time() - self.start_time) * 1000
        self._record_run_metrics(latency_ms)

        # (Removed the original's debug print statements.)
        state.add_agent_metric(self.agent_name, processed=1, latency_ms=latency_ms)
        state.add_audit_entry(
            self.agent_name,
            action="Agent Successfully Executed",
            status=ProcessingStatus.COMPLETED,
            details={"latency_ms": latency_ms},
            process_id=state.process_id,
        )

        self.logger.logger.info(f"{self.agent_name}completed successfully in {latency_ms:.2f}ms.")
        return updated_state

    except Exception as e:
        latency_ms = (time.time() - self.start_time) * 1000 if self.start_time else 0.0
        self._record_run_metrics(latency_ms, error=True)
        state.add_agent_metric(self.agent_name, processed=1, latency_ms=latency_ms, errors=1)
        # BUG FIX: same positional-argument mismatch as the skip path above.
        state.add_audit_entry(self.agent_name, "Error in Execution", details={"error": str(e)})
        state.overall_status = ProcessingStatus.FAILED
        self.logger.logger.exception(f"{self.agent_name} failed: {e}")
        return state
class AgentRegistry:
    """Registry for managing agent instances, keyed by agent_name."""

    def __init__(self):
        # BUG FIX: quote the BaseAgent annotations as forward references so
        # the class definition does not eagerly evaluate the name.  Also
        # removed the leftover "# pass" placeholder comments throughout.
        self._agents: Dict[str, "BaseAgent"] = {}

    def register(self, agent: "BaseAgent"):
        """Register an agent; duplicate names are skipped with a notice."""
        if agent.agent_name in self._agents:
            print(f"{agent.agent_name} already registered - skipping")
            return
        self._agents[agent.agent_name] = agent

    def get(self, agent_name: str) -> Optional["BaseAgent"]:
        """Return the registered agent, or None if unknown."""
        return self._agents.get(agent_name)

    def list_agents(self) -> List[str]:
        """Names of all registered agents."""
        return list(self._agents.keys())

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Snapshot of every registered agent's metrics dict."""
        return {name: agent.get_metrics() for name, agent in self._agents.items()}

    async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
        """Run health_check() on every registered agent sequentially."""
        result: Dict[str, Dict[str, Any]] = {}
        for name, agent in self._agents.items():
            result[name] = await agent.health_check()
        return result
def safe_json_parse(result_text: str):
    """
    Parse JSON out of an AI response that may be wrapped in Markdown fences
    or surrounding prose.

    Tries the fenced-code-stripped text first, then falls back to the
    outermost {...} span.  Raises json.JSONDecodeError when no valid JSON
    object can be recovered.
    """
    # Remove Markdown code fences (```json ... ```) if present.
    cleaned = re.sub(r"^```[a-zA-Z]*\n|```$", "", result_text.strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fallback: the model wrapped JSON in prose -- take the outer braces.
        start, end = cleaned.find("{"), cleaned.rfind("}") + 1
        # BUG FIX: the original tested `end > 0`, so a "}" occurring before
        # the first "{" produced an empty slice and an uncaught error.
        if start >= 0 and end > start:
            return json.loads(cleaned[start:end])
        raise


def to_float(value):
    """
    Coerce a numeric-looking value to float.

    Accepts ints/floats and strings with currency symbols or thousands
    separators ("$1,234.56"); anything unparseable yields 0.0.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace(',', '').replace('$', '').strip())
        except (ValueError, TypeError):
            return 0.0
    return 0.0


def parse_date_safe(date_str):
    """
    Parse a date string in any of several common invoice formats.

    Returns a datetime.date, or None when the input is empty or matches no
    known format.  Format list extended (backward-compatibly) with full
    month names and a slash-separated ISO variant.
    """
    if not date_str:
        return None
    formats = (
        "%b %d %Y", "%b %d, %Y", "%Y-%m-%d", "%d-%b-%Y",
        "%B %d %Y", "%B %d, %Y", "%Y/%m/%d", "%d %b %Y", "%d %B %Y",
    )
    for fmt in formats:
        try:
            return datetime.strptime(date_str.strip(), fmt).date()
        except ValueError:
            continue
    return None
defaultdict(int) + self.load() + + def load(self): + if os.path.exists(self.SAVE_FILE): + data = json.load(open(self.SAVE_FILE)) + self.usage.update(data.get("usage", {})) + self.errors.update(data.get("errors", {})) + + def save(self): + json.dump({ + "usage": self.usage, + "errors": self.errors + }, open(self.SAVE_FILE, "w")) + + def get_best_key(self): + # choose least used or least errored key + best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k])) + self.usage[best_key] += 1 + self.save() + return best_key + + def report_error(self, key): + self.errors[key] += 1 + self.save() + + +balancer = APIKeyBalancer([ + os.getenv("GEMINI_API_KEY_1"), + os.getenv("GEMINI_API_KEY_2"), + os.getenv("GEMINI_API_KEY_3"), + # os.getenv("GEMINI_API_KEY_4"), + os.getenv("GEMINI_API_KEY_5"), + os.getenv("GEMINI_API_KEY_6"), + # os.getenv("GEMINI_API_KEY_7"), +]) + + +class DocumentAgent(BaseAgent): + """Agent responsible for document processing and invoice data extraction""" + + def __init__(self, config: Dict[str, Any] = None): + # pass + super().__init__("document_agent", config) + self.logger = StructuredLogger("DocumentAgent") + self.api_key = balancer.get_best_key() + print("self.api_key..........", self.api_key) + + genai.configure(api_key=self.api_key) + # genai.configure(api_key=os.getenv("GEMINI_API_KEY_7")) + self.model = genai.GenerativeModel("gemini-2.5-flash") + + def generate(self, prompt): + try: + print("generate called") + response = self.model.generate_content(prompt) + print("response....", response) + return response + except Exception as e: + print("errrororrrooroor") + balancer.report_error(self.api_key) + print(balancer.keys) + print(balancer.usage) + print(balancer.errors) + raise + + def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool: + # pass + if not state.file_name or not os.path.exists(state.file_name): + self.logger.logger.error(f"[Document Agent] Missing or invalid file: {state.file_name}") + 
return False + return True + + def _validate_postconditions(self, state: InvoiceProcessingState) -> bool: + # pass + return bool(state.invoice_data and state.invoice_data.total > 0) + + async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + # pass + # file_name = state.file_name + self.logger.logger.info(f"Executing Document Agent for file: {state.file_name}") + + if not self._validate_preconditions(state, workflow_type): + state.overall_status = ProcessingStatus.FAILED + self._log_decision(state, "Extraction Failed", "Preconditions not met", confidence=0.0) + + try: + raw_text = await self._extract_text_from_pdf(state.file_name) + invoice_data = await self._parse_invoice_with_ai(raw_text) + invoice_data = await self._enhance_invoice_data(invoice_data, raw_text) + invoice_data.file_name = state.file_name + state.invoice_data = invoice_data + state.overall_status = ProcessingStatus.IN_PROGRESS + state.current_agent = self.agent_name + state.updated_at = datetime.utcnow() + + confidence = self._calculate_extraction_confidence(invoice_data, raw_text) + state.invoice_data.extraction_confidence = confidence + self._log_decision( + state, + "Extraction Successful", + "PDF text successfully extracted and parsed by AI", + confidence, + state.process_id + ) + return state + except Exception as e: + self.logger.logger.exception(f"[Document Agent] Extraction failed: {e}") + state.overall_status = ProcessingStatus.FAILED + self._should_escalate(state, reason=str(e)) + return state + + + async def _extract_text_from_pdf(self, file_name: str) -> str: + # pass + text = "" + try: + self.logger.logger.info("[DocumentAgent] Extracting text using PyMuPDF...") + with fitz.open(file_name) as doc: + for page in doc: + text += page.get_text() + if len(text.strip()) < 5: + raise ValueError("PyMuPDF extraction too short, switching to PDFPlumber") + except Exception as e: + self.logger.logger.info("[DocumentAgent] Fallback to PDFPlumber...") + 
try: + with pdfplumber.open(file_name) as pdf: + for page in pdf.pages: + text += page.extract_text() or "" + except Exception as e2: + self.logger.logger.error("[DocumentAgent] PDFPlumber failed :{e2}") + text = "" + return text + + async def _parse_invoice_with_ai(self, text: str) -> InvoiceData: + # pass + self.logger.logger.info("[DocumentAgent] Parsing invoice data using Gemini AI...") + print("text-----------", text) + prompt = f""" + Extract structured invoice information as JSON with fields: + invoice_number, order_id, customer_name, due_date, ship_to, ship_mode, + subtotal, discount, shipping_cost, total, and item_details (item_name, quantity, rate, amount). + + Important Note: If an item description continues on multiple lines, combine them into one item_name. Check intelligently + that if at all there will be more than one item then it should have more numbers. + So extract by verifying that is there only one item or more than one. + + Input Text: + {text[:8000]} + """ + response = self.generate(prompt) + result_text = response.text.strip() + data = safe_json_parse(result_text) + print("----------------------------------text-----------------------------------",text) + print("result text::::::::::::::::::::::::::::",data) + # try: + # data = json.loads(result_text) + # except Exception as e: + # self.logger.logger.warning("AI output not valid JSON, retrying with fallback parse.") + # data = json.loads(result_text[result_text.find('{'): result_text.rfind('}')+1]) + items = [] + for item in data.get("item_details", []): + items.append(ItemDetail( + item_name=item.get("item_name"), + quantity=float(item.get("quantity", 1)), + rate=to_float(item.get("rate", 0.0)), + amount=to_float(item.get("amount", 0.0)), + # category=self._categorize_item(item.get("item_name", "Unknown")), + )) + + invoice_data = InvoiceData( + invoice_number=data.get("invoice_number"), + order_id=data.get("order_id"), + customer_name=data.get("customer_name"), + 
due_date=parse_date_safe(data.get("due_date")), + ship_to=data.get("ship_to"), + ship_mode=data.get("ship_mode"), + subtotal=to_float(data.get("subtotal", 0.0)), + discount=to_float(data.get("discount", 0.0)), + shipping_cost=to_float(data.get("shipping_cost", 0.0)), + total=to_float(data.get("total", 0.0)), + item_details=items, + raw_text=text, + ) + confidence = self._calculate_extraction_confidence(invoice_data, text) + invoice_data.extraction_confidence = confidence + self.logger.logger.info("AI output successfully parsed into JSON format") + return invoice_data + + + async def _enhance_invoice_data(self, invoice_data: InvoiceData, raw_text: str) -> InvoiceData: + # pass + if not invoice_data.customer_name: + if "Invoice To" in raw_text: + lines = raw_text.split("\n") + for i, line in enumerate(lines): + if "Invoice To" in line: + invoice_data.customer_name = lines[i+1].strip() + break + return invoice_data + + def _categorize_item(self, item_name: str) -> str: + # pass + name = item_name.lower() + prompt = f""" + Extract the category of the Item from the item details very intelligently + so that we can get the category in which the item belongs to very efficiently: + Example: "Electronics", "Furniture", "Software", etc..... + Input Text- The item is given below (provide the category in JSON format like -- category: 'extracted category') ----> + {name} + """ + response = self.generate(prompt) + result_text = response.text.strip() + category = safe_json_parse(result_text) + print(category['category']) + return category['category'] + + def _calculate_extraction_confidence(self, invoice_data: InvoiceData, raw_text: str) -> float: + """ + Intelligent confidence scoring for extracted invoice data. + Combines presence, consistency, and numeric sanity checks. 
+ """ + score = 0.0 + weight = { + "invoice_number": 0.1, + "order_id": 0.05, + "customer_name": 0.1, + "due_date": 0.05, + "ship_to": 0.05, + "item_details": 0.25, + "total_consistency": 0.25, + "currency_detected": 0.05, + "text_match_bonus": 0.1 + } + + text_lower = raw_text.lower() + + # Presence-based confidence + if invoice_data.invoice_number: + score += weight["invoice_number"] + if invoice_data.order_id: + score += weight["order_id"] + if invoice_data.customer_name: + score += weight["customer_name"] + if invoice_data.due_date and "due_date" in text_lower: + score += weight["due_date"] + if not invoice_data.due_date and "due_date" not in text_lower: + score += weight["due_date"] + if invoice_data.item_details: + score += weight["item_details"] + + # Currency detection + if any(c in raw_text for c in ["$", "โน", "โฌ", "usd", "inr", "eur"]): + score += weight["currency_detected"] + + # Numeric Consistency: subtotal + shipping โ total + def _extract_amounts(pattern): + import re + matches = re.findall(pattern, raw_text) + return [float(m.replace(",", "").replace("$", "").strip()) for m in matches if m] + + import re + numbers = _extract_amounts(r"\$?\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?") + if len(numbers) >= 3 and invoice_data.total: + approx_total = max(numbers) + diff = abs(approx_total - invoice_data.total) + if diff < 5: # minor difference allowed + score += weight["total_consistency"] + elif diff < 50: + score += weight["total_consistency"] * 0.5 + + # Textual verification + hits = 0 + for field in [invoice_data.customer_name, invoice_data.order_id, invoice_data.invoice_number]: + if field and str(field).lower() in text_lower: + hits += 1 + if hits >= 2: + score += weight["text_match_bonus"] + + # Penalty for empty critical fields + missing_critical = not invoice_data.total or not invoice_data.customer_name or not invoice_data.invoice_number + if missing_critical: + score *= 0.8 + + # Clamp and finalize + final_conf = round(min(score, 0.99), 2) + 
invoice_data.extraction_confidence = final_conf + return final_conf * 100.0 + + + async def health_check(self) -> Dict[str, Any]: + """ + Perform intelligent health diagnostics for the Document Agent. + Collects operational, performance, and API connectivity metrics. + """ + from datetime import datetime + + metrics_data = {} + executions = 0 + success_rate = 0.0 + avg_duration = 0.0 + failures = 0 + last_run = None + # latency_trend = None + + # 1. Try to get live metrics from state + print("(self.state)-------",self.metrics) + # print("self.state.agent_metrics-------", self.state.agent_metrics) + if self.metrics: + executions = self.metrics["processed"] + avg_duration = self.metrics["avg_latency_ms"] + failures = self.metrics["errors"] + last_run = self.metrics["last_run_at"] + success_rate = (executions - failures) / (executions+1e-8) + + # print(executions, avg_duration, failures, last_run, success_rate) + # latency_trend = getattr(m, "total_duration_ms", None) + + # 2. API connectivity check + gemini_ok = bool(self.api_key) + # print("self.api---", self.api_key) + # print("geminiokkkkkk", gemini_ok) + api_status = "๐ข Active" if gemini_ok else "๐ด Missing Key" + + # 3. Health logic + overall_status = "๐ข Healthy" + if not gemini_ok or failures > 3: + overall_status = "๐ Degraded" + if executions > 0 and success_rate < 0.5: + overall_status = "๐ด Unhealthy" + + # 4. 
Extended agent diagnostics + metrics_data = { + "Agent": "Document Agent ๐งพ", + "Executions": executions, + "Success Rate (%)": round(success_rate * 100, 2), + "Avg Duration (ms)": round(avg_duration, 2), + "Total Failures": failures, + "API Status": api_status, + "Last Run": str(last_run) if last_run else "Not applicable", + "Overall Health": overall_status, + # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"), + } + + self.logger.logger.info(f"[HealthCheck] Document Agent metrics: {metrics_data}") + return metrics_data diff --git a/Project/agents/escalation_agent.py b/Project/agents/escalation_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9139682708472f4f2c1c623cceede9771b394d64 --- /dev/null +++ b/Project/agents/escalation_agent.py @@ -0,0 +1,315 @@ + +"""Escalation Agent for Invoice Processing""" + +# TODO: Implement agent + +import os +import json +import smtplib +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart +from typing import Dict, Any, List, Optional +from datetime import datetime, timedelta +import google.generativeai as genai +from dotenv import load_dotenv + +from agents.base_agent import BaseAgent +from state import ( + InvoiceProcessingState, ProcessingStatus, PaymentStatus, + RiskLevel, ValidationStatus +) +from utils.logger import StructuredLogger + +load_dotenv() + + +class EscalationAgent(BaseAgent): + """Agent responsible for escalation management and human-in-the-loop workflows""" + + def __init__(self, config: Dict[str, Any] = None): + super().__init__("escalation_agent",config) + self.logger = StructuredLogger("EscalationAgent") + + self.escalation_triggers = { + 'high_risk' : {'route_to':'risk_manager','sla_hours':4}, + 'validation_failure': {'route_to':'finance_manager','sla_hours':8}, + 'high_value': {'route_to':'cfo','sla_hours':24}, + 'fraud_suspicion': {'route_to':'fraud_team','sla_hours':2}, + 'new_vendor':{'route_to':'procurement','sla_hours':48} + 
} + + def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool: + # pass + return hasattr(state,'invoice_data') and hasattr(state,'risk_assessment') + + def _validate_postconditions(self, state: InvoiceProcessingState) -> bool: + # pass + return hasattr(state,'escalation_details') + + async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + # pass + self.logger.logger.info('Executing Escalation Agent...') + if not self._validate_preconditions(state, workflow_type): + self.logger.logger.error("Preconditions not meet for Escalation handling") + state.status = ProcessingStatus.FAILED + self._log_decision(state, "Escalation Agent Failed", "Preconditions not met", confidence=0.0) + return state + + escalation_type = self._determine_escalation_type(state) + if not escalation_type: + self.logger.logger.info("No escalation required for this invoice.") + state.escalation_required = False + state.overall_status = 'completed' + return state + + priority_level = self._calculate_priority_level(state) + approver_info = self._route_to_approver(state, escalation_type,priority_level) + summary = await self._generate_escalation_summary(state,escalation_type,approver_info) + + escalation_record = await self._create_escalation_record(state, escalation_type, priority_level, approver_info,summary) + await self._send_escalation_notifications(state,escalation_record,approver_info) + await self._setup_sla_monitoring(state,escalation_record,priority_level) + + state.escalation_required = True + state.human_review_required = True + state.escalation_details = escalation_record + state.human_review_required = summary + state.escalation_reason = escalation_record["escalation_reason"] + state.current_agent = 'escalation_agent' + state.overall_status = 'escalated' + self._log_decision( + state, + "Escalation Successful", + "PDF successfully escalated to Human for review", + "N/A", + state.process_id + ) + 
self.logger.logger.info('Escalation record successfully created and routed.') + return state + + def _determine_escalation_type(self, state: InvoiceProcessingState) -> str: + # pass + risk = getattr(state,'risk_assessment',{}) + validation = getattr(state,'validation_result',{}) + invoice = getattr(state,'invoice_data',{}) + risk_level = getattr(risk,'risk_level',{}) + amount = getattr(invoice,'total',0) + vendor = getattr(invoice,'customer_name','') + # fraud_indicators = risk.get('fraud_indicators',[]) + fraud_indicators = getattr(risk,'fraud_indicators',[]) + + if risk_level in ['high','critical']: + return 'high_risk' + elif state.validation_status == 'invalid' or state.validation_status == 'missing_po': + return 'validation_failure' + elif amount and amount>250000: + return 'high_value' + elif len(fraud_indicators) > 3: + return 'fraud_suspicion' + elif vendor and 'new' in vendor.lower(): + return 'new_vendor' + else: + return None + + def _calculate_priority_level(self, state: InvoiceProcessingState) -> str: + # pass + # risk = getattr(state,'risk_assessment',{}).get('risk_level','low').lower() + # amount = getattr(state,'invoice_data',{}).get('total',0) + risk_assessment = getattr(state,'risk_assessment',{}) + invoice_data = getattr(state,'invoice_data',{}) + risk = getattr(risk_assessment,'risk_level','low').lower() + amount = getattr(invoice_data,'total',0) + if risk == 'critical' or amount > 50000: + return 'urgent' + elif risk == 'high' or amount > 25000: + return 'high' + else: + return 'medium' + + def _route_to_approver(self, state: InvoiceProcessingState, + escalation_type: str, priority_level: str) -> Dict[str, Any]: + # pass + # print(self.escalation_triggers) + route_info = self.escalation_triggers.get(escalation_type,{}) + # print("route_info..................", route_info) + assigned_to = route_info.get('route_to','finance_manager') + sla_hours = route_info.get('sla_hours',8) + approvers = ['finance_manager'] + if assigned_to == 'cfo': + 
approvers.append('cfo') + return { + 'assigned_to':assigned_to, + 'sla_hours':sla_hours, + 'approval_required_from':approvers + } + + + def _parse_date(self, date_str: str) -> Optional[datetime.date]: + # pass + try: + return datetime.strptime(date_str,"%Y-%m-%d").date() + except Exception: + return None + + async def _generate_escalation_summary(self, state: InvoiceProcessingState, + escalation_type: str, approver_info: Dict[str, Any]) -> str: + # pass + + risk = getattr(state,'risk_assessment',{}) + invoice = getattr(state,'invoice_data',{}) + risk_level = getattr(risk,'risk_level',{}) + amount = getattr(invoice,'total',0) + # invoice = state.invoice_data + # risk = state.risk_assessment + reason = "" + + if escalation_type == 'high_risk': + reason = f"Invoice marked as high risk ({risk_level})." + elif escalation_type == 'validation_failure': + reason = 'Validation discrepancies require finance approval.' + elif escalation_type == 'high_value': + reason = f"High-value invoice ({amount}) requires CFO approval." + elif escalation_type == 'fraud_suspicion': + reason = 'Fraud suspicion based on anomalies detected' + elif escalation_type == 'new_vendor': + reason = 'Vendor is new and not yet in approved list.' + return f"{reason} Routed to {approver_info['assigned_to']} for review." 
+ + + async def _create_escalation_record(self, state: InvoiceProcessingState, + escalation_type: str, priority_level: str, + approver_info: Dict[str, Any], summary: str) -> Dict[str, Any]: + # pass + timestamp = datetime.utcnow() + sla_deadline = timestamp+timedelta(hours=approver_info['sla_hours']) + return { + 'escalation_type':escalation_type, + 'severity':priority_level, + 'assigned_to':approver_info['assigned_to'], + 'escalation_time':timestamp.isoformat()+'Z', + 'sla_deadline':sla_deadline.isoformat()+'Z', + 'notification_sent':True, + 'approval_required_from':approver_info['approval_required_from'], + 'escalation_reason':summary + } + + + async def _send_escalation_notifications(self, state: InvoiceProcessingState, + escalation_record: Dict[str, Any], + approver_info: Dict[str, Any]) -> Dict[str, Any]: + # pass + try: + subject = f"[Escalation Alert] Invoice requires {approver_info['assigned_to']} review" + body = f""" + Escalation Type: {escalation_record['escalation_type']} + severity: {escalation_record['severity']} + SLA Deadline: {escalation_record['sla_deadline']} + reason: {escalation_record['escalation_reason']} + """ + to_email = f"{approver_info['assigned_to']}@company.com" + self._send_email(to_email,subject,body) + self.logger.logger.info(f"Escalation notification send to {to_email}.") + return {'status':'send','to':to_email} + except Exception as e: + self.logger.logger.error(f'Failed to send notification: {e}') + return {'status':'failed','error':str(e)} + + def _send_email(self, to_email: str, subject: str, body: str) -> Dict[str, Any]: + # pass + try: + sender = os.getenv('EMAIL_SENDER','noreply@invoicesystem.com') + msg = MIMEMultipart() + msg['From'] = send + msg['To'] = to_email + msg['Subject'] = subject + msg.attach(MIMEText(body,'plain')) + with smtplib.SMTP('localhost') as server: + server.send_message(msg) + return {'sent':True} + except Exception as e: + return {'sent':False, 'error':str(e)} + + + async def 
_setup_sla_monitoring(self, state: InvoiceProcessingState, + escalation_record: Dict[str, Any], priority_level: str): + # pass + self.logger.logger.debug( + f"SLA monitoring initialized for {escalation_record['escalation_type']}" + f"with deadline {escalation_record['sla_deadline']}" + ) + + async def resolve_escalation(self, escalation_id: str, resolution: str, + resolver: str) -> Dict[str, Any]: + # pass + return { + 'escalation_id':escalation_id, + 'resolved_by':resolver, + 'resolution_notes':resolution, + 'resolved_at':datetime.utcnow().isoformat()+'Z', + 'status':'resolved' + } + + async def health_check(self) -> Dict[str, Any]: + """ + Performs a detailed health check for the Escalation Agent. + Includes operational metrics, configuration validation, and reliability stats. + """ + + start_time = datetime.utcnow() + self.logger.logger.info("Performing health check for EscalationAgent...") + + executions = 0 + avg_duration = 0.0 + failures = 0 + last_run = None + success_rate = 0.0 + + try: + if self.metrics: + executions = self.metrics["processed"] + avg_duration = self.metrics["avg_latency_ms"] + failures = self.metrics["errors"] + last_run = self.metrics["last_run_at"] + success_rate = (executions - failures) / (executions + 1e-8) * 100.0 if executions > 0 else 0.0 + + total_executions = executions + total_failures = failures + avg_duration_ms = avg_duration + + # Email and trigger configuration validation + email_configured = bool(os.getenv('EMAIL_SENDER')) + missing_triggers = [k for k, v in self.escalation_triggers.items() if not v.get("route_to")] + + # Duration calculation + # duration_ms = (datetime.utcnow() - start_time).total_seconds() * 1000 + # last_run = self.metrics["last_run_at"] + + health_report = { + "Agent": "Escalation Agent ๐จ", + "Executions": total_executions, + "Success Rate (%)": round(success_rate, 2), + "Avg Duration (ms)": round(avg_duration_ms, 2) if avg_duration_ms else "Not Called", + "Total Failures": total_failures, + # "Email 
Configured": email_configured, + # "Available Triggers": list(self.escalation_triggers.keys()), + "Missing Routes": missing_triggers, + "Last Run": self.metrics["last_run_at"], + "Overall Health": "๐ข Healthy" if (success_rate > 70 or total_executions == 0) else "Degraded โ ๏ธ", + # "Response Time (ms)": round(duration_ms, 2) + # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"), + } + + self.logger.logger.info("EscalationAgent health check completed successfully.") + return health_report + + except Exception as e: + error_time = (datetime.utcnow() - start_time).total_seconds() * 1000 + self.logger.logger.error(f"Health check failed: {e}") + + # Return degraded health if something goes wrong + return { + "Agent": "EscalationAgent โ", + "Overall Health": "Degraded", + "Error": str(e), + "Timestamp": datetime.utcnow().isoformat() + "Z" + } diff --git a/Project/agents/forecast_agent.py b/Project/agents/forecast_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..79090e35c45f444f535da4cd5d0923540ee5e8b8 --- /dev/null +++ b/Project/agents/forecast_agent.py @@ -0,0 +1,253 @@ + +# agents/forecast_agent.py +""" +Forecast Agent (robust) +- Accepts a list of invoice states (dicts or InvoiceProcessingState models). +- Produces monthly historical spend and a simple forecast (moving average). +- Performs lightweight anomaly detection. +- Returns a dict containing a Plotly chart and numeric summary. 
+""" +from typing import List, Dict, Any, Union +from datetime import datetime +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import math +import os + +# keep the type import only for hints; we do NOT require reconstructing models +try: + from state import InvoiceProcessingState +except Exception: + InvoiceProcessingState = None # type: ignore + + +class ForecastAgent: + def __init__(self): + pass + + # ---- Internal: normalize input states -> DataFrame ---- + def _normalize_states_to_df(self, states: List[Union[dict, object]]) -> pd.DataFrame: + """ + Accepts list of dicts or model instances. + Produces a cleaned DataFrame with columns: + ['file_name','invoice_date','due_date','total','vendor','risk_score','status'] + """ + rows = [] + for s in states: + try: + # 1) obtain a plain dict representation without constructing pydantic models + if isinstance(s, dict): + raw = dict(s) + else: + # model-like object: try model_dump, to_dict, or __dict__ + if hasattr(s, "model_dump"): + raw = s.model_dump(exclude_none=False) + elif hasattr(s, "dict"): + raw = s.dict() + else: + # best effort: convert attributes to dict + raw = { + k: getattr(s, k) + for k in dir(s) + if not k.startswith("_") and not callable(getattr(s, k)) + } + + # 2) sanitize well-known problematic fields that break pydantic elsewhere + if "human_review_required" in raw and isinstance(raw["human_review_required"], str): + v = raw["human_review_required"].strip().lower() + raw["human_review_required"] = v in ("true", "yes", "1", "required") + if "escalation_details" in raw and isinstance(raw["escalation_details"], dict): + # convert to string summary so downstream code doesn't expect a dict + try: + raw["escalation_details"] = str(raw["escalation_details"]) + except Exception: + raw["escalation_details"] = "" + + # 3) pull invoice_data safely (may be None, dict, or model) + inv = {} + if raw.get("invoice_data") is None: + inv = {} + else: + inv_raw = 
raw.get("invoice_data") + if isinstance(inv_raw, dict): + inv = dict(inv_raw) + else: + # model-like invoice_data + if hasattr(inv_raw, "model_dump"): + inv = inv_raw.model_dump(exclude_none=False) + elif hasattr(inv_raw, "dict"): + inv = inv_raw.dict() + else: + # fallback: read attributes + inv = { + k: getattr(inv_raw, k) + for k in dir(inv_raw) + if not k.startswith("_") and not callable(getattr(inv_raw, k)) + } + + # 4) turnout the row items we care about + total = inv.get("total") or inv.get("amount") or raw.get("total") or 0.0 + # risk may be under risk_assessment.risk_score or top-level + risk_src = raw.get("risk_assessment") or {} + if isinstance(risk_src, dict): + risk_score = risk_src.get("risk_score") or 0.0 + else: + # model-like risk_assessment + if hasattr(risk_src, "model_dump"): + try: + risk_score = risk_src.model_dump().get("risk_score", 0.0) + except Exception: + risk_score = 0.0 + else: + risk_score = getattr(risk_src, "risk_score", 0.0) + + # dates: prefer due_date then invoice_date - they could be strings or datetimes + due = inv.get("due_date") or inv.get("invoice_date") or raw.get("due_date") or raw.get("invoice_date") + vendor = inv.get("customer_name") or inv.get("vendor_name") or raw.get("vendor") or raw.get("customer_name") or "Unknown" + file_name = inv.get("file_name") or raw.get("file_name") or "unknown" + + rows.append( + { + "file_name": file_name, + "due_date": due, + "invoice_date": inv.get("invoice_date") or raw.get("invoice_date"), + "total": total, + "vendor": vendor, + "risk_score": risk_score, + "status": raw.get("overall_status") or inv.get("status") or "unknown", + } + ) + except Exception: + # skip malformed state + continue + + df = pd.DataFrame(rows) + if df.empty: + return df + + # coerce and normalize + df["due_date"] = pd.to_datetime(df["due_date"], errors="coerce") + df["invoice_date"] = pd.to_datetime(df["invoice_date"], errors="coerce") + # if due_date missing, fallback to invoice_date + df["date"] = 
df["due_date"].fillna(df["invoice_date"]) + df["total"] = pd.to_numeric(df["total"], errors="coerce").fillna(0.0) + df["risk_score"] = pd.to_numeric(df["risk_score"], errors="coerce").fillna(0.0) + df["vendor"] = df["vendor"].fillna("Unknown") + return df + + # ---- Public: predict monthly cashflow and return a plotly chart ---- + def predict_cashflow(self, states: List[Union[dict, object]], months: int = 6) -> Dict[str, Any]: + """ + Produces a monthly historical spend + simple forecast for `months` into the future. + Returns: + { + "chart": plotly_figure, + "average_monthly_spend": float, + "total_forecast": float, + "forecast_values": {month_str: float, ...}, + "historical": pandas.Series, + "forecast_start_month": str, + "forecast_end_month": str + } + """ + df = self._normalize_states_to_df(states) + if df.empty or df["date"].dropna().empty: + return {"message": "No data to forecast", "chart": None} + + # create monthly buckets (period start) + df = df.dropna(subset=["date"]) + df["month"] = df["date"].dt.to_period("M").dt.to_timestamp() + monthly_hist = df.groupby("month")["total"].sum().sort_index() + + # compute average monthly spend from available historical months + average_month = float(monthly_hist.mean()) if not monthly_hist.empty else 0.0 + + # build forecast months (next `months` starting from the next month after last historical) + last_hist_month = monthly_hist.index.max() + if pd.isnull(last_hist_month): + start_month = pd.Timestamp.now().to_period("M").to_timestamp() + else: + # next month + start_month = (last_hist_month + pd.offsets.MonthBegin(1)).normalize() + + forecast_index = pd.date_range(start=start_month, periods=months, freq="MS") + # simple forecast: repeat the historical mean (interpretable and safe) + forecast_vals = [average_month for _ in range(len(forecast_index))] + + # build plot dataframe (historical + forecast) + hist_df = monthly_hist.reset_index().rename(columns={"month": "date", "total": "amount"}) + hist_df["type"] = 
"Historical" + fc_df = pd.DataFrame({"date": forecast_index, "amount": forecast_vals}) + fc_df["type"] = "Forecast" + plot_df = pd.concat([hist_df, fc_df], ignore_index=True).sort_values("date") + + # prepare a plotly figure with clear styling + fig = go.Figure() + # historical - solid line + hist_plot = plot_df[plot_df["type"] == "Historical"] + if not hist_plot.empty: + fig.add_trace(go.Scatter( + x=hist_plot["date"], + y=hist_plot["amount"], + mode="lines+markers", + name="Historical Spend", + line=dict(dash="solid"), + )) + # forecast - dashed line + fc_plot = plot_df[plot_df["type"] == "Forecast"] + if not fc_plot.empty: + fig.add_trace(go.Scatter( + x=fc_plot["date"], + y=fc_plot["amount"], + mode="lines+markers", + name="Forecast", + line=dict(dash="dash"), + marker=dict(symbol="circle-open") + )) + + fig.update_layout( + title="Monthly Spend (Historical + Forecast)", + xaxis_title="Month", + yaxis_title="Total Spend (USD)", + hovermode="x unified", + template="plotly_dark", + ) + + forecast_series = pd.Series(forecast_vals, index=[d.strftime("%Y-%m") for d in forecast_index]) + total_forecast = float(forecast_series.sum()) + + result = { + "chart": fig, + "average_monthly_spend": round(average_month, 2), + "total_forecast": round(total_forecast, 2), + "forecast_values": forecast_series.to_dict(), + "historical": monthly_hist, + "forecast_start_month": forecast_index[0].strftime("%Y-%m"), + "forecast_end_month": forecast_index[-1].strftime("%Y-%m"), + } + return result + + # ---- Public: detect anomalies on sanitized data ---- + def detect_anomalies(self, states: List[Union[dict, object]]) -> pd.DataFrame: + """ + Returns DataFrame of anomalies: + - total > 2 * mean(total) + - OR risk_score >= 0.7 + Columns returned: ['file_name','date','vendor','total','risk_score','anomaly_reason'] + """ + df = self._normalize_states_to_df(states) + if df.empty: + return pd.DataFrame() + + mean_spend = df["total"].mean() + cond = (df["total"] > mean_spend * 2) | 
(df["risk_score"] >= 0.7) + anomalies = df.loc[cond, ["file_name", "date", "vendor", "total", "risk_score"]].copy() + if anomalies.empty: + return pd.DataFrame() + anomalies = anomalies.rename(columns={"date": "invoice_date"}) + anomalies["anomaly_reason"] = anomalies.apply( + lambda r: "High Spend" if r["total"] > mean_spend * 2 else "High Risk", + axis=1, + ) + return anomalies.reset_index(drop=True) diff --git a/Project/agents/insights_agent.py b/Project/agents/insights_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..1d621cfa2a48266d5c16ab443e9339ef0a76fff0 --- /dev/null +++ b/Project/agents/insights_agent.py @@ -0,0 +1,107 @@ + +# agents/insights_agent.py +""" +Insight Agent +------------- +Generates analytical and visual insights from processed invoices. +""" + +import pandas as pd +import plotly.express as px +from typing import List, Dict, Any +from state import InvoiceProcessingState + + +class InsightAgent: + def __init__(self): + pass + + def _extract_invoice_records(self, results: List[InvoiceProcessingState]) -> pd.DataFrame: + """Extract flat invoice info for analysis""" + records = [] + for r in results: + if isinstance(r, dict): + # Convert dict to InvoiceProcessingState if needed + try: + r = InvoiceProcessingState(**r) + except Exception: + continue + + inv = getattr(r, "invoice_data", None) + risk = getattr(r, "risk_assessment", None) + val = getattr(r, "validation_result", None) + pay = getattr(r, "payment_decision", None) + + records.append({ + "file_name": getattr(inv, "file_name", None), + "invoice_number": getattr(inv, "invoice_number", None), + "customer_name": getattr(inv, "customer_name", None), + "invoice_date": getattr(inv, "invoice_date", None), + "total": getattr(inv, "total", None), + "validation_status": getattr(val, "validation_status", None), + "risk_score": getattr(risk, "risk_score", None), + "risk_level": getattr(risk, "risk_level", None), + "payment_status": getattr(pay, "status", None), + 
"decision": getattr(pay, "decision", None), + }) + + df = pd.DataFrame(records) + if df.empty: + return pd.DataFrame() + + # Clean up data + df["customer_name"] = df["customer_name"].fillna("Unknown Vendor") + df["total"] = pd.to_numeric(df["total"], errors="coerce").fillna(0.0) + df["risk_score"] = pd.to_numeric(df["risk_score"], errors="coerce").fillna(0.0) + return df + + def generate_insights(self, results: List[InvoiceProcessingState]) -> Dict[str, Any]: + """Generate charts and textual summary.""" + df = self._extract_invoice_records(results) + if df.empty: + return {"summary": "No data available for insights.", "charts": []} + + charts = [] + + # ๐น Total spend per customer + if "customer_name" in df.columns: + spend_chart = px.bar( + df.groupby("customer_name", as_index=False)["total"].sum(), + x="customer_name", + y="total", + title="Total Spend per Customer" + ) + charts.append(spend_chart) + + # ๐น Risk distribution + if "risk_level" in df.columns: + risk_chart = px.pie( + df, + names="risk_level", + title="Risk Level Distribution" + ) + charts.append(risk_chart) + + # ๐น Validation status counts + if "validation_status" in df.columns: + val_chart = px.bar( + df.groupby("validation_status", as_index=False).size(), + x="validation_status", + y="size", + title="Validation Status Overview" + ) + charts.append(val_chart) + + # ๐น Summary text + total_spend = df["total"].sum() + high_risk = (df["risk_score"] >= 0.7).sum() + valid_invoices = (df["validation_status"].astype(str).str.lower() == "valid").sum() + + summary = ( + f"๐ฐ **Total Spend:** โน{total_spend:,.2f}\n\n" + f"๐ **Invoices Processed:** {len(df)}\n\n" + f"โ **Valid Invoices:** {valid_invoices}\n\n" + f"โ ๏ธ **High Risk Invoices:** {high_risk}\n\n" + ) + + return {"summary": summary, "charts": charts} diff --git a/Project/agents/payment_agent.py b/Project/agents/payment_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ae2e506bb93d18570ebc8f35972bad5d878c04 --- 
class PaymentAgent(BaseAgent):
    """Agent responsible for payment processing decisions and execution.

    Decides whether an invoice is auto-paid, held, escalated, or rejected;
    submits approved payments to the local payment API; and tracks
    per-agent health metrics for reporting.
    """

    # Persistent in-memory health history shared across instances (last 50 runs).
    health_history = []

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__("payment_agent", config)
        self.logger = StructuredLogger("PaymentAgent")
        self.approved_vendor_list = ["Acme Corporation", "TechNova Ltd", "SupplyCo"]
        self.retry_limit = 3
        # Health metrics tracking
        self.total_executions = 0
        self.successful_executions = 0
        self.failed_executions = 0
        self.total_duration = 0.0  # cumulative execution time in ms
        self.last_transaction_id = None
        self.last_run = None

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Check the state carries enough data for this workflow type.

        Expedited workflows skip the risk agent, so they require a VALID
        validation result instead of a risk assessment.
        """
        if workflow_type == "expedited":
            # BUGFIX: previous code read `validation_status.VALID` (enum-member
            # attribute access), which is always truthy and crashed when
            # validation_result was None. Compare against the enum member.
            return bool(
                state.invoice_data
                and state.validation_result
                and state.validation_result.validation_status == ValidationStatus.VALID
            )
        return bool(state.risk_assessment and state.invoice_data)

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """A payment decision must exist after a successful run."""
        return bool(state.payment_decision)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Run the payment workflow and return the updated state.

        Expedited workflows are auto-approved with a synthetic low-risk
        assessment; standard workflows derive the decision from the risk
        assessment produced upstream. Health metrics are recorded in all
        cases (including failures) via the finally block.
        """
        start_time = time.time()
        try:
            if not self._validate_preconditions(state, workflow_type):
                state.overall_status = ProcessingStatus.FAILED
                self._log_decision(state, "Payment Agent Failed", "Preconditions not met", confidence=0.0)
                return state

            invoice_data = state.invoice_data
            validation_result = state.validation_result

            if workflow_type == "expedited":
                # Fabricate a low-risk assessment and an auto-pay decision.
                risk_assessment = RiskAssessment(
                    risk_level=RiskLevel.LOW,
                    risk_score=0.3,
                    fraud_indicators=None,
                    compliance_issues=None,
                    recommendation=None,
                    reason="Expedited Workflow Called",
                    requires_human_review="Not needed due to Expedited Workflow",
                )
                payment_decision = PaymentDecision(
                    decision="auto_pay",
                    status=PaymentStatus.APPROVED,
                    approved_amount=invoice_data.total,
                    transaction_id=f"TXN-{datetime.utcnow().strftime('%Y-%m-%d-%H%M%S')}",
                    payment_method=self._select_payment_method(invoice_data.total),
                    approval_chain=["system_auto_approval"],
                    rejection_reason=None,
                    scheduled_date=self._calculate_payment_date(invoice_data.due_date, "ACH"),
                )
            else:
                risk_assessment = state.risk_assessment
                payment_decision = await self._make_payment_decision(
                    invoice_data, validation_result, risk_assessment, state
                )
                if payment_decision.decision == "auto_pay":
                    state.approval_chain = [
                        {
                            "approved_by": "system_auto_approval in payment_agent",
                            "timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"),
                        }
                    ]
                else:
                    state.approval_chain = [{"payment_agent": "Failed or Rejected"}]

            # Shared tail: execute, record outcome, log (was duplicated in
            # both branches).
            return await self._finalize_payment(
                state, invoice_data, payment_decision, validation_result, risk_assessment
            )

        except Exception as e:
            self.failed_executions += 1
            self.logger.logger.error(f"[PaymentAgent] Execution failed: {e}")
            state.overall_status = ProcessingStatus.FAILED
            return state

        finally:
            duration = (time.time() - start_time) * 1000  # ms
            self.last_run = datetime.utcnow().isoformat()
            self.total_executions += 1
            self.total_duration += duration
            self._record_health_metrics(duration)

    async def _finalize_payment(self, state: InvoiceProcessingState, invoice_data,
                                payment_decision: PaymentDecision,
                                validation_result, risk_assessment) -> InvoiceProcessingState:
        """Execute the payment, fold the result into the state, and log it."""
        payment_result = await self._execute_payment(invoice_data, payment_decision)
        payment_decision = self._update_payment_decision(payment_decision, payment_result)

        justification = await self._generate_payment_justification(
            invoice_data, payment_decision, validation_result, risk_assessment
        )

        state.payment_decision = payment_decision
        state.overall_status = ProcessingStatus.COMPLETED
        state.current_agent = "payment_agent"

        if payment_decision.status == PaymentStatus.APPROVED:
            self.successful_executions += 1
        else:
            self.failed_executions += 1

        self.last_transaction_id = payment_decision.transaction_id
        self._log_decision(state, payment_decision.status, justification, 95.0, state.process_id)
        return state

    async def _make_payment_decision(self, invoice_data, validation_result,
                                     risk_assessment, state: InvoiceProcessingState) -> PaymentDecision:
        """Map risk level, validation status, and amount to a payment decision.

        reject          -> critical risk or invalid validation
        auto_pay        -> low risk or small amount (< 5000)
        hold            -> medium risk or partial PO match (finance approval)
        manual_approval -> everything else (executive approval)
        """
        # BUGFIX: direct `invoice_data.total_amount` access raised
        # AttributeError for models without that attribute when total was falsy.
        amount = invoice_data.total or getattr(invoice_data, "total_amount", None) or 0.0
        risk_level = risk_assessment.risk_level
        validation_status = validation_result.validation_status

        if risk_level == RiskLevel.CRITICAL or validation_status == ValidationStatus.INVALID:
            decision = PaymentDecision(
                decision="reject",
                status=PaymentStatus.FAILED,
                approved_amount=0.0,
                transaction_id=None,
                payment_method=None,
                approval_chain=[],
                rejection_reason="Critical Risk or Invalid Validation",
                scheduled_date=None,
            )
        elif risk_level == RiskLevel.LOW or amount < 5000:
            decision = PaymentDecision(
                decision="auto_pay",
                status=PaymentStatus.APPROVED,
                approved_amount=amount,
                transaction_id=f"TXN-{datetime.utcnow().strftime('%Y-%m-%d-%H%M%S')}",
                payment_method=self._select_payment_method(amount),
                approval_chain=["system_auto_approval"],
                rejection_reason=None,
                scheduled_date=self._calculate_payment_date(invoice_data.due_date, "ACH"),
            )
        elif risk_level == RiskLevel.MEDIUM or validation_status == ValidationStatus.PARTIAL_MATCH:
            decision = PaymentDecision(
                decision="hold",
                status=PaymentStatus.PENDING_APPROVAL,
                approved_amount=amount,
                transaction_id=None,
                payment_method=self._select_payment_method(amount),
                approval_chain=["system_auto_approval", "finance_manager_approval"],
                rejection_reason=None,
                scheduled_date=self._calculate_payment_date(invoice_data.due_date, "ACH"),
            )
        else:
            decision = PaymentDecision(
                decision="manual_approval",
                status=PaymentStatus.PENDING_APPROVAL,
                approved_amount=amount,
                transaction_id=None,
                payment_method=self._select_payment_method(amount),
                approval_chain=["system_auto_approval", "executive_approval"],
                rejection_reason=None,
                scheduled_date=self._calculate_payment_date(invoice_data.due_date, "WIRE"),
            )

        return decision

    def _select_payment_method(self, amount: float) -> str:
        """Choose a rail by amount: ACH < 5000 <= WIRE < 25000 <= MANUAL."""
        if amount < 5000:
            return "ACH"
        elif amount < 25000:
            return "WIRE"
        return "MANUAL"

    def _calculate_payment_date(self, due_date_str: Optional[str], payment_method: str) -> datetime:
        """Schedule payment shortly after the due date (ACH +1 day, others +2).

        Falls back to 3 days from now when the due date is missing/unparseable.
        """
        due_date = self._parse_date(due_date_str)
        if not due_date:
            due_date = datetime.utcnow().date() + timedelta(days=3)
        offset = 1 if payment_method == "ACH" else 2
        return datetime.combine(due_date, datetime.min.time()) + timedelta(days=offset)

    def _parse_date(self, date_str: str) -> Optional["datetime.date"]:
        """Parse a %Y-%m-%d string; return None for missing/invalid input."""
        if not date_str:
            return None
        try:
            return datetime.strptime(date_str, "%Y-%m-%d").date()
        except Exception:
            return None

    async def _execute_payment(self, invoice_data, payment_decision: PaymentDecision) -> Dict[str, Any]:
        """Send the payment request to the payment web API.

        Returns:
            dict with at least 'status' ("success"/"failed") and 'message';
            on success it also carries 'transaction_id'.
        """
        await self._async_sleep(1)  # simulate processing latency

        payment_payload = {
            "order_id": invoice_data.invoice_number or f"INV-{int(datetime.utcnow().timestamp())}",
            "customer_name": invoice_data.customer_name or "Unknown Vendor",
            "amount": float(invoice_data.total),
            "currency": "USD",
            "recipient_account": "auto_generated_account",
            "due_date": str(invoice_data.due_date or datetime.utcnow().date()),
        }

        try:
            response = requests.post(
                "http://localhost:8001/initiate_payment", json=payment_payload, timeout=10
            )
            if response.status_code == 200:
                result = response.json()
                self.logger.logger.info(f"[PaymentAgent] Payment API response: {result}")
                return {
                    "status": "success" if result.get("status") == "SUCCESS" else "failed",
                    "transaction_id": result.get("transaction_id"),
                    "message": result.get("message"),
                }
            # BUGFIX: the non-200 branch and the except handler both printed
            # an unbound `result` variable, raising NameError and masking the
            # real API error.
            return {"status": "failed", "message": f"HTTP {response.status_code}: {response.text}"}

        except Exception as e:
            return {"status": "failed", "message": f"Payment API error: {e}"}

    async def _async_sleep(self, seconds: int):
        """Yield control for `seconds` without blocking the event loop."""
        import asyncio
        await asyncio.sleep(seconds)

    def _update_payment_decision(self, payment_decision: PaymentDecision,
                                 payment_result: Dict[str, Any]) -> PaymentDecision:
        """Fold the API execution result back into the decision object."""
        if payment_result.get("status") == "success":
            payment_decision.status = PaymentStatus.APPROVED
            payment_decision.transaction_id = payment_result.get("transaction_id")
        else:
            payment_decision.status = PaymentStatus.FAILED
            payment_decision.rejection_reason = payment_result.get("message")
        return payment_decision

    async def _generate_payment_justification(self, invoice_data, payment_decision: PaymentDecision,
                                              validation_result, risk_assessment) -> str:
        """Compose a one-line human-readable justification for the decision."""
        reason = f"Payment Decision: {payment_decision.status}. "
        if payment_decision.status == PaymentStatus.FAILED:
            # BUGFIX: missing separator previously glued the rejection reason
            # to the risk-level sentence.
            reason += f"Reason: {payment_decision.rejection_reason}. "
        reason += f"Risk level: {risk_assessment.risk_level}. Validation: {validation_result.validation_status}."
        return reason

    def _record_health_metrics(self, duration: float):
        """Update and record health statistics (kept to the last 50 entries)."""
        success_rate = (
            (self.successful_executions / self.total_executions) * 100
            if self.total_executions else 0
        )
        avg_duration = (
            self.total_duration / self.total_executions
            if self.total_executions else 0
        )
        overall_status = "๐ข Healthy"
        if success_rate < 70:
            overall_status = "๐ Degraded"
        if success_rate < 60:
            overall_status = "๐ด Unhealthy"

        metrics = {
            "Agent": "Payment Agent ๐ณ",
            "Executions": self.total_executions,
            "Success Rate (%)": round(success_rate, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": self.failed_executions,
            "Last Transaction ID": self.last_transaction_id or "N/A",
            "Last Run": self.last_run,
            "Overall Health": overall_status,
        }

        PaymentAgent.health_history.append(metrics)
        PaymentAgent.health_history = PaymentAgent.health_history[-50:]  # keep last 50

    async def health_check(self) -> Dict[str, Any]:
        """Return the current or last known health state."""
        await self._async_sleep(0.05)
        if not PaymentAgent.health_history:
            return {
                "Agent": "Payment Agent ๐ณ",
                "Executions": 0,
                "Success Rate (%)": 0.0,
                "Avg Duration (ms)": 0.0,
                "Total Failures": 0,
                "Last Transaction ID": "N/A",
            }
        return PaymentAgent.health_history[-1]
class APIKeyBalancer:
    """Pick the API key with the fewest errors (then fewest uses).

    Usage and error counters are persisted to SAVE_FILE so preference
    survives process restarts.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        # BUGFIX: unset env vars produce None entries; drop them so None can
        # never be handed to genai.configure().
        self.keys = [k for k in keys if k]
        self.usage = defaultdict(int)
        self.errors = defaultdict(int)
        self.load()

    def load(self):
        """Restore persisted counters, tolerating a missing/corrupt file."""
        if not os.path.exists(self.SAVE_FILE):
            return
        try:
            # BUGFIX: json.load(open(...)) leaked the file handle.
            with open(self.SAVE_FILE) as fh:
                data = json.load(fh)
        except (OSError, json.JSONDecodeError):
            return  # corrupt or unreadable stats: start fresh
        self.usage.update(data.get("usage", {}))
        self.errors.update(data.get("errors", {}))

    def save(self):
        """Persist current counters to disk (handle closed via `with`)."""
        with open(self.SAVE_FILE, "w") as fh:
            json.dump({
                "usage": self.usage,
                "errors": self.errors,
            }, fh)

    def get_best_key(self):
        """Return the least errored (then least used) key and count the use."""
        if not self.keys:
            return None  # no configured keys at all
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record a failure against `key` so it is deprioritized."""
        self.errors[key] += 1
        self.save()
self.model.generate_content(prompt) + return response + except Exception as e: + balancer.report_error(self.api_key) + raise + + def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool: + return bool(state.invoice_data and state.validation_result) + + def _validate_postconditions(self, state: InvoiceProcessingState) -> bool: + return bool(state.risk_assessment and state.risk_assessment.risk_score is not None) + + async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + start_time = time.time() + success = False + try: + if not self._validate_preconditions(state, workflow_type): + state.overall_status = ProcessingStatus.FAILED + self._log_decision(state, "Risk Assessment Analysis Failed", "Preconditions not met", confidence=0.0) + + invoice_data = state.invoice_data + validation_result = state.validation_result + + base_score = await self._calculate_base_risk_score(invoice_data, validation_result) + print("base_score:",base_score) + fraud_indicators = await self._detect_fraud_indicators(invoice_data, validation_result) + print("fraud_indicators:",fraud_indicators) + compliance_issues = await self._check_compliance(invoice_data, state) + print("compliance_issues:",compliance_issues) + ai_assessment = await self._ai_risk_assessment(invoice_data, validation_result, fraud_indicators) + print("ai_assessment:",ai_assessment) + + combined_score = self._combine_risk_factors(base_score, fraud_indicators, compliance_issues, ai_assessment) + print("combined_score:",combined_score) + + risk_level = self._determine_risk_level(combined_score) + print("risk_level:",risk_level) + + recommendation = self._generate_recommendation(risk_level, fraud_indicators, compliance_issues, validation_result) + print("recommendation:", recommendation) + state.risk_assessment = RiskAssessment( + risk_level = risk_level, + risk_score = combined_score, + fraud_indicators = fraud_indicators, + compliance_issues = 
compliance_issues, + recommendation = recommendation["action"], + reason = recommendation["reason"], + requires_human_review = recommendation["requires_human_review"] + ) + + state.current_agent = "risk_agent" + state.overall_status = ProcessingStatus.IN_PROGRESS + success = True + self._log_decision( + state, + "Risk Assessment Successful", + "PDF text successfully verified by Risk Agent and checked by AI", + combined_score, + state.process_id + ) + return state + finally: + duration_ms = round((time.time() - start_time) * 1000, 2) + self._record_execution(success, duration_ms) + + async def _calculate_base_risk_score(self, invoice_data, validation_result) -> float: + """ + Calculates an intelligent risk score (0.0โ1.0) based on validation results, + invoice metadata, and contextual financial factors. + """ + score = 0.0 + + # --- 1. Validation & PO related risks --- + if validation_result: + if validation_result.validation_status == ValidationStatus.INVALID: + score += 0.4 + elif validation_result.validation_status == ValidationStatus.PARTIAL_MATCH: + score += 0.25 + elif validation_result.validation_status == ValidationStatus.MISSING_PO: + score += 0.3 + + # Core mismatch signals + if not validation_result.amount_match: + score += 0.2 + if not validation_result.rate_match: + score += 0.15 + if not validation_result.quantity_match: + score += 0.1 + + # Low confidence from validation adds risk + if validation_result.confidence_score is not None: + score += (0.5 - validation_result.confidence_score) * 0.3 if validation_result.confidence_score < 0.5 else 0 + + # --- 2. Invoice amount-based risk --- + if invoice_data and invoice_data.total is not None: + total = invoice_data.total + if total > 1_000_000: + score += 0.4 # Extremely high-value invoices + elif total > 100_000: + score += 0.25 + elif total > 10_000: + score += 0.1 + elif total < 10: + score += 0.15 # Suspiciously small invoice + + # --- 3. 
Temporal risks (based on due date) --- + if invoice_data and getattr(invoice_data, "due_date", None): + try: + score += self._calculate_due_date_risk(invoice_data.due_date) + except Exception: + pass # Graceful degradation if due_date is invalid + + # --- 4. Vendor / Customer risks --- + if invoice_data and getattr(invoice_data, "customer_name", None): + name = invoice_data.customer_name.lower() + if "new_vendor" in name or "test" in name or "demo" in name: + score += 0.2 + elif any(flag in name for flag in ["fraud", "fake", "invalid"]): + score += 0.3 + + # --- 5. Data reliability / extraction confidence --- + if invoice_data and getattr(invoice_data, "extraction_confidence", None) is not None: + conf = invoice_data.extraction_confidence + if conf < 0.5: + score += 0.2 + elif conf < 0.7: + score += 0.1 + + # --- 6. Currency and metadata anomalies --- + currency = getattr(invoice_data, "currency", "USD") or "USD" + if currency.upper() not in {"USD", "EUR", "GBP", "INR"}: + score += 0.15 # uncommon currencies add risk + + # Normalize score within [0, 1.0] + return round(min(score, 1.0), 3) + + def _calculate_due_date_risk(self, due_date_str: str) -> float: + try: + due_date = self._parse_date(due_date_str) + days_until_due = (due_date - datetime.utcnow().date()).days + if days_until_due < 0: + return 0.2 + elif days_until_due < 5: + return 0.1 + return 0.0 + except Exception: + return 0.05 + + def _parse_date(self, date_str: str) -> datetime.date: + return datetime.strptime(date_str,"%Y-%m-%d").date() + + async def _detect_fraud_indicators(self, invoice_data, validation_result) -> List[str]: + """ + Performs intelligent fraud detection on the given invoice and validation results. + Returns a list of detected fraud indicators. + """ + indicators = [] + + # 1. 
PO / Validation mismatches + if validation_result: + if not validation_result.po_found: + indicators.append("No matching Purchase Order found") + if not validation_result.amount_match: + indicators.append("Amount discrepancy detected") + if not validation_result.rate_match: + indicators.append("Rate inconsistency with Purchase Order") + if not validation_result.quantity_match: + indicators.append("Quantity mismatch detected") + if validation_result.confidence_score is not None and validation_result.confidence_score < 0.6: + indicators.append(f"Low validation confidence ({validation_result.confidence_score:.2f})") + + # 2. Vendor / Customer anomalies + customer_name = getattr(invoice_data, "customer_name", "") or "" + if "test" in customer_name.lower() or "demo" in customer_name.lower(): + indicators.append("Suspicious vendor name (Test/Demo account)") + if "new_vendor" in customer_name.lower(): + indicators.append("First-time or unverified vendor") + if any(keyword in customer_name.lower() for keyword in ["fraud", "fake", "invalid"]): + indicators.append("Vendor flagged with risky keywords") + + # 3. Amount-level risk signals + total = getattr(invoice_data, "total", 0.0) or 0.0 + if total > 1_000_000: + indicators.append(f"Unusually high invoice total (${total:,.2f})") + elif total < 10: + indicators.append(f"Suspiciously low invoice total (${total:,.2f})") + + # 4. Date anomalies + due_date = getattr(invoice_data, "due_date", None) + invoice_date = getattr(invoice_data, "invoice_date", None) + if invoice_date and due_date and (due_date - invoice_date).days < 0: + indicators.append("Due date earlier than invoice date (possible manipulation)") + elif invoice_date and due_date and (due_date - invoice_date).days < 3: + indicators.append("Unusually short payment window") + + # 5. 
Duplicate or pattern-based red flags + if invoice_data.invoice_number and invoice_data.invoice_number.lower().startswith("dup-"): + indicators.append("Possible duplicate invoice ID pattern") + if invoice_data.file_name and "copy" in invoice_data.file_name.lower(): + indicators.append("Invoice filename suggests duplication") + + # 6. Confidence anomalies (AI extraction) + if invoice_data.extraction_confidence is not None and invoice_data.extraction_confidence < 0.5: + indicators.append(f"Low extraction confidence ({invoice_data.extraction_confidence:.2f}) โ possible OCR tampering") + + # 7. Currency or unusual metadata patterns + if getattr(invoice_data, "currency", "").upper() not in {"USD", "EUR", "GBP", "INR"}: + indicators.append(f"Uncommon currency code: {invoice_data.currency}") + + return indicators + + + async def _check_compliance(self, invoice_data, state: InvoiceProcessingState) -> List[str]: + """ + Performs a multi-layer compliance check on invoice and state integrity. + Returns a list of detected compliance issues. + """ + issues = [] + + # 1. Invoice integrity checks + if not invoice_data.invoice_number: + issues.append("Missing invoice number") + if not invoice_data.customer_name: + issues.append("Missing customer name") + if not invoice_data.total or invoice_data.total <= 0: + issues.append("Invalid or missing total amount") + if not invoice_data.due_date: + issues.append("Missing due date") + + # 2. Item-level verification + if not invoice_data.item_details or len(invoice_data.item_details) == 0: + issues.append("No item details present") + else: + for item in invoice_data.item_details: + if not getattr(item, "item_name", None): + issues.append("Item missing name") + if getattr(item, "quantity", 1) <= 0: + issues.append(f"Invalid quantity for item '{item.item_name or 'Unknown'}'") + + # 3. 
    async def _ai_risk_assessment(
        self,
        invoice_data,
        validation_result,
        fraud_indicators: List[str]
    ) -> Dict[str, Any]:
        """
        Ask the Gemini model for a fraud-risk judgement on this invoice.

        The prompt embeds the raw invoice data, the validation result, and the
        already-detected fraud indicators, and requests a JSON-only reply.
        The parsed score is clamped to [0, 1]; on any parse or API failure the
        default result is returned with an explanatory reason, so callers can
        always rely on both keys being present.

        Returns:
            dict: {
                "risk_score": float in [0, 1],
                "reason": str (explanation, truncated to 400 chars)
            }
        """
        self.logger.logger.info("[RiskAgent] Running AI-based risk assessment...")
        # model_name = "gemini-2.5-flash"
        # Fallback result used whenever the model call or parsing fails.
        result = {"risk_score": 0.0, "reason": "Default โ AI assessment not available"}

        try:
            # Initialize model
            # model = genai.GenerativeModel(model_name)

            # --- Construct dynamic and context-rich prompt ---
            # NOTE(review): invoice_data / validation_result are interpolated via
            # their default repr — presumably informative enough for the model;
            # confirm the state models define useful __str__/__repr__.
            prompt = f"""
            You are a financial risk analysis model for invoice fraud detection.
            Carefully analyze the following details:

            INVOICE DATA:
            {invoice_data}

            VALIDATION RESULT:
            {validation_result}

            DETECTED FRAUD INDICATORS:
            {fraud_indicators}

            TASK:
            1. Assess overall risk of this invoice being fraudulent or non-compliant.
            2. Provide reasoning.
            3. Respond **only in JSON** with keys:
            - "risk_score": a float between 0 and 1 (higher = higher risk)
            - "reason": short explanation of what contributed to this score.

            EXAMPLES:
            {{
                "risk_score": 0.85,
                "reason": "High amount mismatch, new vendor, and unusual currency"
            }}
            {{
                "risk_score": 0.25,
                "reason": "Valid PO and consistent totals, low fraud signals"
            }}
            """
            import asyncio
            # --- Model call ---
            # self.generate() reports errors to the key balancer and re-raises;
            # failures land in the generic except below.
            response = self.generate(prompt)
            # response = await asyncio.to_thread(model.generate_content, prompt)

            # --- Clean and parse ---
            # Strip any non-JSON prefix/suffix the model wrapped around the object.
            raw_text = getattr(response, "text", "") or ""
            cleaned_json = self._clean_json_response(raw_text)
            ai_output = json.loads(cleaned_json)

            # --- Validate AI output ---
            score = float(ai_output.get("risk_score", 0.0))
            reason = str(ai_output.get("reason", "No reason provided"))

            # Clamp score between 0โ1 for safety
            result = {
                "risk_score": max(0.0, min(score, 1.0)),
                "reason": reason.strip()[:400]  # limit for logs
            }

            self.logger.logger.info(
                f"[RiskAgent] AI Risk Assessment completed: score={result['risk_score']}, reason={result['reason']}"
            )

        except json.JSONDecodeError as e:
            # Model replied with something that is not valid JSON.
            self.logger.logger.warning(f"[RiskAgent] JSON parsing failed: {e}")
            result["reason"] = "AI response could not be parsed"

        except Exception as e:
            # API/transport failure: deterministic scoring still proceeds upstream.
            self.logger.logger.error(f"[RiskAgent] AI assessment error: {e}", exc_info=True)
            result["reason"] = "Fallback to base risk model"

        return result
compliance_issues: List[str], + ai_assessment: Dict[str, Any] + ) -> float: + """ + Combines multiple risk components (base, fraud, compliance, and AI analysis) + into a single normalized risk score between 0.0 and 1.0. + + Weighting strategy: + - Base Score: foundation derived from deterministic checks + - Fraud Indicators: +0.1 per flag (max +0.3) + - Compliance Issues: +0.05 per issue (max +0.2) + - AI Risk Score: contributes 40โ50% of total weight + + Returns: + float: final risk score clamped to [0, 1] + """ + try: + # Extract and normalize AI risk + ai_score = float(ai_assessment.get("risk_score", 0.0)) + ai_score = max(0.0, min(ai_score, 1.0)) + + # --- Weighted contributions --- + fraud_contrib = min(len(fraud_indicators) * 0.1, 0.3) + compliance_contrib = min(len(compliance_issues) * 0.05, 0.2) + ai_contrib = 0.5 * ai_score if ai_score > 0 else 0.2 * base_score + + combined = base_score + fraud_contrib + compliance_contrib + ai_contrib + + # Cap at 1.0 for safety + final_score = round(min(combined, 1.0), 3) + + self.logger.logger.info( + f"[RiskAgent] Combined risk computed: base={base_score}, " + f"fraud_flags={len(fraud_indicators)}, compliance_flags={len(compliance_issues)}, " + f"ai_score={ai_score}, final={final_score}" + ) + + return final_score + + except Exception as e: + self.logger.logger.error(f"[RiskAgent] Error combining risk factors: {e}", exc_info=True) + return min(base_score + 0.2, 1.0) # fallback conservative estimate + + + def _determine_risk_level(self, risk_score: float) -> RiskLevel: + if risk_score<0.3: + return RiskLevel.LOW + elif risk_score<0.6: + return RiskLevel.MEDIUM + elif risk_score<0.8: + return RiskLevel.HIGH + return RiskLevel.CRITICAL + + def _generate_recommendation( + self, + risk_level: RiskLevel, + fraud_indicators: List[str], + compliance_issues: List[str], + validation_result + ) -> Dict[str, Any]: + """ + Generate a structured recommendation (approve, escalate, or reject) + based on overall risk, fraud, and 
compliance outcomes. + + Decision Logic: + - HIGH / CRITICAL risk โ escalate for human review + - INVALID validation โ reject + - Medium risk with minor issues โ escalate + - Otherwise โ approve + + Returns: + Dict[str, Any]: { + 'action': str, # 'approve', 'escalate', or 'reject' + 'reason': str, # Explanation summary + 'requires_human_review': bool + } + """ + try: + # --- Determine key flags --- + has_fraud = bool(fraud_indicators) + has_compliance_issues = bool(compliance_issues) + validation_invalid = ( + validation_result and validation_result.validation_status == ValidationStatus.INVALID + ) + + # --- Decision Logic --- + if validation_invalid: + action = "reject" + requires_review = True + reason = "Validation failed: " + "; ".join(fraud_indicators + compliance_issues or ["Invalid invoice data"]) + + elif risk_level in [RiskLevel.HIGH, RiskLevel.CRITICAL]: + action = "escalate" + requires_review = True + reason = f"High risk level detected ({risk_level.value}). Issues: " + "; ".join(fraud_indicators + compliance_issues or ["Potential anomalies"]) + + elif has_fraud or has_compliance_issues: + action = "escalate" + requires_review = True + reason = "Minor irregularities found: " + "; ".join(fraud_indicators + compliance_issues) + + else: + action = "approve" + requires_review = False + reason = "All checks passed; invoice appears valid and compliant." 
+ + # --- Structured Output --- + recommendation = { + "action": action, + "reason": reason, + "requires_human_review": requires_review, + } + + self.logger.logger.info( + f"[DecisionAgent] Recommendation generated: {recommendation}" + ) + return recommendation + + except Exception as e: + self.logger.logger.error(f"[DecisionAgent] Error generating recommendation: {e}", exc_info=True) + # Safe fallback + return { + "action": "escalate", + "reason": "Error during recommendation generation", + "requires_human_review": True, + } + + + def _record_execution(self, success: bool, duration_ms: float): + self.execution_history.append({ + # "timestamp": datetime.utcnow().isoformat(), + "success": success, + "duration_ms": duration_ms, + }) + # Keep recent N only + if len(self.execution_history) > self.max_history: + self.execution_history.pop(0) + + async def health_check(self) -> Dict[str, Any]: + total_runs = len(self.execution_history) + if total_runs == 0: + return { + "Agent": "Risk Agent โ ๏ธ", + "Executions": 0, + "Success Rate (%)": 0.0, + "Avg Duration (ms)": 0.0, + "Total Failures": 0, + "Status": "idle", + # "Timestamp": datetime.utcnow().isoformat() + } + metrics_data = {} + executions = 0 + success_rate = 0.0 + avg_duration = 0.0 + failures = 0 + last_run = None + + # 1. Try to get live metrics from state + # print("(self.state)-------",self.metrics) + # print("self.state.agent_metrics-------", self.state.agent_metrics) + if self.metrics: + executions = self.metrics["processed"] + avg_duration = self.metrics["avg_latency_ms"] + failures = self.metrics["errors"] + last_run = self.metrics["last_run_at"] + success_rate = (executions - failures) / (executions+1e-8) + + # 2. API connectivity check + gemini_ok = bool(self.api_key) + api_status = "๐ข Active" if gemini_ok else "๐ด Missing Key" + + # 3. 
Health logic + overall_status = "๐ข Healthy" + if not gemini_ok or failures > 3: + overall_status = "๐ Degraded" + if executions > 0 and success_rate < 0.5: + overall_status = "๐ด Unhealthy" + + successes = sum(1 for e in self.execution_history if e["success"]) + failures = total_runs - successes + avg_duration = round(mean(e["duration_ms"] for e in self.execution_history), 2) + success_rate = round((successes / (total_runs+1e-8)) * 100, 2) + + return { + "Agent": "Risk Agent โ ๏ธ", + "Executions": total_runs, + "Success Rate (%)": success_rate, + "Avg Duration (ms)": avg_duration, + "API Status": api_status, + "Total Failures": failures, + "Last Run": str(last_run) if last_run else "Not applicable", + # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"), + "Overall Health": overall_status, + } + + diff --git a/Project/agents/smart_explainer_agent.py b/Project/agents/smart_explainer_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..5a0bc7193734f34fef6ba14c9afdfa0d929f1208 --- /dev/null +++ b/Project/agents/smart_explainer_agent.py @@ -0,0 +1,220 @@ + +""" +Smart Explainer Agent (Enhanced + Gemini-powered) +- Produces a detailed, human-readable explanation for a single InvoiceProcessingState. +- Uses Gemini for natural summarization if API key is present. +- Defensive, HTML-enhanced, and fully dashboard-ready. 
+""" + +from state import InvoiceProcessingState, ValidationStatus, PaymentStatus, RiskLevel +from datetime import datetime +import google.generativeai as genai +import json +import os + + +class SmartExplainerAgent: + def __init__(self): + # Configure Gemini only if available + self.api_key = os.environ.get("GEMINI_API_KEY_4") + self.use_gemini = bool(self.api_key) + if self.use_gemini: + genai.configure(api_key=self.api_key) + self.model = genai.GenerativeModel("gemini-2.0-flash") + + # ---------- Helper functions ---------- + def _safe_invoice_dict(self, state: InvoiceProcessingState) -> dict: + if not state or not getattr(state, "invoice_data", None): + return {} + return ( + state.invoice_data.model_dump(exclude_none=True) + if hasattr(state.invoice_data, "model_dump") + else state.invoice_data.dict() + ) + + def _safe_validation(self, state: InvoiceProcessingState) -> dict: + if not state or not getattr(state, "validation_result", None): + return {} + return ( + state.validation_result.model_dump(exclude_none=True) + if hasattr(state.validation_result, "model_dump") + else state.validation_result.dict() + ) + + def _safe_risk(self, state: InvoiceProcessingState) -> dict: + if not state or not getattr(state, "risk_assessment", None): + return {} + return ( + state.risk_assessment.model_dump(exclude_none=True) + if hasattr(state.risk_assessment, "model_dump") + else state.risk_assessment.dict() + ) + + # ---------- Core explain logic ---------- + def explain(self, state) -> str: + """ + Generate a detailed HTML + markdown explanation for a given invoice. + Falls back gracefully if data or Gemini is unavailable. + """ + + # --- Defensive normalization --- + if state is None: + return "
โ ๏ธ No invoice state provided.
" + + if isinstance(state, dict): + try: + state = InvoiceProcessingState(**state) + except Exception: + pass + + # --- Extract fields safely --- + invoice = self._safe_invoice_dict(state) or {} + validation = self._safe_validation(state) or {} + risk = self._safe_risk(state) or {} + payment = ( + state.payment_decision.model_dump(exclude_none=True) + if getattr(state, "payment_decision", None) + and hasattr(state.payment_decision, "model_dump") + else getattr(state, "payment_decision", {}) or {} + ) + + discrepancies = validation.get("discrepencies", []) # per schema + + inv_id = invoice.get("invoice_number") or invoice.get("file_name") or "Invoice: {inv_id}
", + f"Vendor: {vendor}
", + f"Amount: {_fmt(total)}
", + f"Status: {status_val}
", + "Validation: {val_status or 'unknown'}
", + f"Risk Level: {risk_level or 'low'} ({risk_score})
", + f"Payment: {payment.get('decision', 'N/A')} ({payment_status or 'pending'})
", + ] + + if discrepancies: + lines.append("Discrepancies Found:
{expected}, got {actual}Recommendation:
Gemini explanation failed: {e}
" + + return explanation_html diff --git a/Project/agents/validation_agent.py b/Project/agents/validation_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..a2838cec0f91c9013ed09b773021a9439a364b58 --- /dev/null +++ b/Project/agents/validation_agent.py @@ -0,0 +1,357 @@ + +"""Validation Agent for Invoice Processing""" + +# TODO: Implement agent +import asyncio +import os +import pandas as pd +from typing import Dict, Any, List, Tuple +from fuzzywuzzy import fuzz +import numpy as np +import time +from agents.base_agent import BaseAgent +from state import ( + InvoiceProcessingState, ValidationResult, ValidationStatus, + ProcessingStatus +) +from datetime import datetime, timedelta + +from utils.logger import StructuredLogger +from difflib import SequenceMatcher + +class ValidationAgent(BaseAgent): + """Agent responsible for validating invoice data against purchase orders""" + + health_history: List[Dict[str, Any]] = [] # global history for metrics + + def __init__(self, config: Dict[str, Any] = None): + # pass + super().__init__(agent_name="validation_agent",config=config or {}) + self.logger = StructuredLogger(__name__) + self.po_file = self.config.get("po_file","data/purchase_orders.csv") + self.tolerance = self.config.get("tolerance",0.05) + self.successful_executions = 0 + self.failed_executions = 0 + self.total_duration = 0.0 + self.total_executions = 0 + self.last_run = None + # self.match_threshold = self.config.get("match_threshold",80) + + def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool: + # pass + if not state.invoice_data: + self.logger.logger.error("No invoice data available for validation.") + return False + return True + + def _validate_postconditions(self, state: InvoiceProcessingState) -> bool: + # pass + return hasattr(state,'validation_result') and state.validation_result is not None + + async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: 
+ # pass + self.logger.logger.info(f"[ValidationAgent] Starting validation for {state.file_name}") + start_time = time.time() + try: + if not self._validate_preconditions(state, workflow_type): + state.status = ProcessingStatus.FAILED + self._log_decision(state,"Validation Failed","Precondition not met",confidence = 0.0) + return state + invoice_data = state.invoice_data + matching_pos = await self._find_matching_pos(invoice_data) + validation_result = await self._validate_against_pos(invoice_data,matching_pos) + state.validation_result = validation_result + state.current_agent = "validation_agent" + state.overall_status = ProcessingStatus.IN_PROGRESS + + if self._should_escalate_validation(validation_result, invoice_data): + state.escalation_required = True + self._validate_postconditions(state) + self.successful_executions += 1 + self.last_run = datetime.utcnow().isoformat() + # print("ValidationResult().confidence_score", state.validation_result.confidence_score) + self._log_decision( + state, + "Validation Successful", + "PDF text successfully validated and checked by AI", + state.validation_result.confidence_score, + state.process_id + ) + return state + except Exception as e: + self.logger.logger.error(f"[ValidationAgent] Execution failed: {e}") + self.failed_executions += 1 + state.overall_status = ProcessingStatus.FAILED + return state + + finally: + duration = (time.time() - start_time) * 1000 # ms + self.total_executions += 1 + self.total_duration += duration + self._record_health_metrics(duration) + + def _load_purchase_orders(self) -> pd.DataFrame: + # pass + """load po data from csv""" + try: + df = pd.read_csv(self.po_file) + self.logger.logger.info(f"[ValidationAgent] Loaded {len(df)} purchase orders") + return df + except Exception as e: + self.logger.logger.error(f"[ValidationAgent] failed to load purchase order: {e}") + raise + + async def _find_matching_pos(self, invoice_data) -> List[Dict[str, Any]]: + """find POs matching invoice order_id or 
fuzzy customer/items""" + po_df = self._load_purchase_orders() + matches = [] + for _,po in po_df.iterrows(): + customer_score = fuzz.token_sort_ratio(po["customer_name"], invoice_data.customer_name) + order_id_score = fuzz.token_sort_ratio(po["order_id"], invoice_data.order_id) + for item in invoice_data.item_details: + item_score = fuzz.token_sort_ratio(po["item_name"],item.item_name) + print(f"Compairing PO item {po['item_name']} with invoice item {item.item_name}: score = {item_score}") + + if (customer_score >= 80) and (item_score >=80) and (order_id_score >=80) and (po['invoice_number'] == int(invoice_data.invoice_number)): + matches.append(po.to_dict()) + + print("matches.....", matches) + return matches + + + async def _validate_against_pos(self, invoice_data, matching_pos: List[Dict[str, Any]]) -> ValidationResult: + # pass + + if not matching_pos: + return ValidationResult(po_found=False, validation_status='missing_po',validation_result='No matching purchase order found', + discrepancies = [], + confidence_score = 0.0) + po_data = matching_pos[0] + discrepancies = self._validate_item_against_po(invoice_data,po_data) + discrepancies += self._validate_totals(invoice_data,po_data) + actual_amount = [item.amount for item in invoice_data.item_details][0] + actual_quantity = [item.quantity for item in invoice_data.item_details][0] + actual_rate = [item.rate for item in invoice_data.item_details][0] + amount_diff = abs(actual_amount - po_data.get('expected_amount',0)) + tolerance_limit = po_data.get('expected_amount',0)*self.tolerance + amount_match = amount_diff <= tolerance_limit + + validation_result = ValidationResult( + po_found=True, + quantity_match=actual_quantity == po_data.get('quantity'), + rate_match=abs(actual_rate - po_data.get('rate', 0)) <= tolerance_limit, + amount_match=amount_match, + validation_status=ValidationStatus.NOT_STARTED, # temporary + validation_result="; ".join(discrepancies) if discrepancies else "All checks passed", + 
discrepencies=discrepancies, + confidence_score=0.0, # temporary + expected_amount=po_data.get('amount'), + po_data=po_data + ) + validation_result.validation_status = self._determine_validation_status(validation_result) + validation_result.confidence_score = self._calculate_validation_confidence(validation_result, matching_pos, invoice_data) + return validation_result + + def _validate_item_against_po(self, item, po_data: Dict[str, Any]) -> List[str]: + # pass + # print("itemmmmmmmmm", item.item_details.quantity) + print("po_-------------", po_data) + discrepancies = [] + for item in item.item_details: + if item.quantity != po_data.get('quantity'): + discrepancies.append(f"Quantity mismatch: Expected {po_data['quantity']}, Found {item.quantity}") + if abs(item.rate - po_data.get('rate',0)) > po_data.get('rate',0)*self.tolerance: + discrepancies.append(f"Rate mismatch: Expected {po_data['rate']}, Found {item.rate}") + return discrepancies + + def _validate_totals(self, invoice_data, po_data: Dict[str, Any]) -> List[str]: + # pass + discrepancies = [] + expected = po_data.get('expected_amount',0) + actual = [item.amount for item in invoice_data.item_details][0] + diff = abs(expected-actual) + if diff > expected*self.tolerance: + discrepancies.append(f"Total amount mismatch: Expected {expected}, Actual {actual} (Difference:{diff:.2f})") + return discrepancies + + def _calculate_validation_confidence(self, validation_result: ValidationResult, + matching_pos: List[Dict[str, Any]], invoice_data) -> float: + """ + Compute an intelligent, weighted confidence score across 7 key dimensions: + invoice_number, order_id, customer_name, item_name, amount, rate, quantity. + Each field contributes based on importance. 
+ """ + + if not validation_result.po_found or not matching_pos: + return 0.0 + + po_data = matching_pos[0] + + # Extract PO (expected) values + expected = { + "invoice_number": po_data.get("invoice_number", ""), + "order_id": po_data.get("order_id", ""), + "customer_name": po_data.get("customer_name", ""), + "item_name": po_data.get("item_name", ""), + "amount": float(po_data.get("expected_amount", po_data.get("amount", 0))), + "rate": float(po_data.get("rate", 0)), + "quantity": float(po_data.get("quantity", 0)) + } + + # Extract actual (from invoice) + actual = { + "invoice_number": invoice_data.invoice_number, + "order_id": invoice_data.order_id, + "customer_name": invoice_data.customer_name, + } + + # Handle line-item level (assuming single dominant item) + if invoice_data.item_details: + item = invoice_data.item_details[0] + actual.update({ + "item_name": item.item_name, + "amount": float(item.amount or 0), + "rate": float(item.rate or 0), + "quantity": float(item.quantity or 0) + }) + + # Define weights intelligently (sum = 1) + weights = { + "invoice_number": 0.20, + "order_id": 0.15, + "customer_name": 0.05, + "item_name": 0.05, + "amount": 0.25, + "rate": 0.15, + "quantity": 0.15 + } + + # --- Similarity functions --- + def numeric_similarity(expected_val, actual_val): + if expected_val == 0: + return 1.0 if actual_val == 0 else 0.0 + diff_ratio = abs(expected_val - actual_val) / (abs(expected_val) + 1e-6) + return max(0.0, 1.0 - diff_ratio) + + def text_similarity(a, b): + return SequenceMatcher(None, str(a).lower(), str(b).lower()).ratio() + + # --- Compute weighted similarities --- + weighted_scores = [] + for field, weight in weights.items(): + exp_val, act_val = expected.get(field), actual.get(field) + + if isinstance(exp_val, (int, float)) and isinstance(act_val, (int, float)): + score = numeric_similarity(exp_val, act_val) + else: + score = text_similarity(exp_val, act_val) + + weighted_scores.append(weight * score) + + # Combine to final 
confidence + confidence = sum(weighted_scores) + confidence = round(confidence * 100, 2) # convert to % + confidence = max(0.0, min(confidence, 100.0)) # clamp 0โ100 + + self.logger.logger.debug(f"Validation Confidence (weighted): {confidence}%") + return confidence + + + + def _determine_validation_status(self, validation_result: ValidationResult) -> ValidationStatus: + """ + Determine the final validation status based on PO existence, discrepancies, and amount match. + """ + if not validation_result.po_found: + return ValidationStatus.MISSING_PO + + discrepancies_count = len(validation_result.discrepencies) + + if discrepancies_count == 0 and validation_result.amount_match: + return ValidationStatus.VALID + + if validation_result.amount_match and discrepancies_count <= 2: + return ValidationStatus.PARTIAL_MATCH + + return ValidationStatus.INVALID + + + def _should_escalate_validation(self, validation_result: ValidationResult, invoice_data) -> bool: + # pass + return validation_result.validation_status in ['invalid','missing_po'] + + def _record_health_metrics(self, duration: float): + """Record the health metrics after each execution""" + success_rate = ( + (self.successful_executions / self.total_executions) * 100 + if self.total_executions > 0 else 0 + ) + avg_duration = ( + self.total_duration / self.total_executions + if self.total_executions > 0 else 0 + ) + + metrics = { + "Agent": "Validation Agent โ ", + "Executions": self.total_executions, + "Success Rate (%)": round(success_rate, 2), + "Avg Duration (ms)": round(avg_duration, 2), + "Total Failures": self.failed_executions, + # "Timestamp": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"), + } + metrics_data = {} + executions = 0 + success_rate = 0.0 + avg_duration = 0.0 + failures = 0 + last_run = None + + if self.metrics: + print("self.metrics from validation agent", self.metrics) + executions = self.metrics["processed"] + print("executions.....", executions) + avg_duration = 
self.metrics["avg_latency_ms"] + failures = self.metrics["errors"] + last_run = self.metrics["last_run_at"] + print("last_run.....", last_run) + success_rate = (executions - failures) / (executions + 1e-6) + + # if last_run == None: + last_run = self.last_run + + # 3. Health logic + overall_status = "๐ข Healthy" + if failures > 3: + overall_status = "๐ Degraded" + if executions > 0 and success_rate < 0.5: + overall_status = "๐ด Unhealthy" + + print("metrics from val---....1", metrics) + + metrics.update({ + "Last Run": str(last_run) if last_run else "Not applicable", + "Overall Health": overall_status, + }) + print("metrics from val---....", metrics) + # maintain up to last 50 records + ValidationAgent.health_history.append(metrics) + # ValidationAgent.health_history = ValidationAgent.health_history[-50:] + + async def health_check(self) -> Dict[str, Any]: + """ + Returns the health metrics summary for UI display. + """ + await asyncio.sleep(0.05) + if not ValidationAgent.health_history: + return { + "Agent": "Validation Agent โ ", + "Executions": 0, + "Success Rate (%)": 0.0, + "Avg Duration (ms)": 0.0, + "Total Failures": 0, + } + + + latest = ValidationAgent.health_history[-1] + print("latest.....", latest) + return latest diff --git a/Project/bounding_box.py b/Project/bounding_box.py new file mode 100644 index 0000000000000000000000000000000000000000..f3ff8f97245c57e4b47f7b2de1db24f10cc43243 --- /dev/null +++ b/Project/bounding_box.py @@ -0,0 +1,138 @@ + +import fitz # PyMuPDF +import pandas as pd +import os +import re + +# === File paths === +DATA_DIR = os.path.join(os.getcwd(), "data") +PDF_PATH = os.path.join(DATA_DIR, "invoices/Invoice-26.pdf") # Update for new PDF if needed +CSV_PATH = os.path.join(DATA_DIR, "purchase_orders.csv") +OUTPUT_PATH = os.path.join(DATA_DIR, "annotated_invoice.pdf") + +# === Field coordinate map (from your data) === +FIELD_BOXES = { + "invoice_number": (525, 55, 575, 75), + "order_id": (45, 470, 230, 490), + "customer_name": (40, 
135, 100, 155), + "quantity": (370, 235, 385, 250), + "rate": (450, 235, 500, 250), + "expected_amount": (520, 360, 570, 375), +} + +# === Step 1: Open PDF and extract text === +pdf = fitz.open(PDF_PATH) +page = pdf[0] +pdf_text = page.get_text() + +# === Step 2: Helper to extract fields === +def extract_field(pattern, text, group=1): + match = re.search(pattern, text, re.IGNORECASE) + return match.group(group).strip() if match else None + +# Extract key identifiers +invoice_number_pdf = extract_field(r"#\s*(\d+)", pdf_text) +order_id_pdf = extract_field(r"Order ID\s*[:\-]?\s*(\S+)", pdf_text) +customer_name_pdf = extract_field(r"Bill To:\s*(.*)", pdf_text) + +# === Step 3: Read CSV and match correct row === +po_df = pd.read_csv(CSV_PATH) + +matched_row = po_df[ + (po_df['invoice_number'].astype(str) == str(invoice_number_pdf)) + | (po_df['order_id'] == order_id_pdf) +] + +if matched_row.empty: + raise ValueError(f"No matching CSV row found for Invoice {invoice_number_pdf} / Order {order_id_pdf}") + +expected = matched_row.iloc[0].to_dict() +expected = {k.lower(): str(v).strip() for k, v in expected.items()} + +print("โ Loaded expected data from CSV for this PDF:") +for k, v in expected.items(): + print(f" {k}: {v}") + +# === Step 4: Extract fields from PDF === +invoice_data = { + "invoice_number": invoice_number_pdf, + "customer_name": customer_name_pdf, + "order_id": order_id_pdf, +} + +# Numeric fields +amounts = re.findall(r"\$?([\d,]+\.\d{2})", pdf_text) +invoice_data["expected_amount"] = amounts[-1] if amounts else None + +# Extract first item (quantity, rate) +item_lines = re.findall( + r"([A-Za-z0-9 ,\-]+)\s+(\d+)\s+\$?([\d,]+\.\d{2})\s+\$?([\d,]+\.\d{2})", + pdf_text, +) +if item_lines: + invoice_data["quantity"] = item_lines[0][1] + invoice_data["rate"] = item_lines[0][2] + +print("\nโ Extracted data from PDF:") +for k, v in invoice_data.items(): + print(f" {k}: {v}") + +# === Step 5: Compare PDF vs CSV === +discrepancies = [] + +def 
add_discrepancy(field, expected_val, found_val): + discrepancies.append({"field": field, "expected": expected_val, "found": found_val}) + +# Compare string fields +for field in ["invoice_number", "order_id", "customer_name"]: + if str(invoice_data.get(field, "")).strip() != str(expected.get(field, "")).strip(): + add_discrepancy(field, expected.get(field, ""), invoice_data.get(field, "")) + +# Compare numeric fields +for field in ["quantity", "rate", "expected_amount"]: + try: + found_val = float(str(invoice_data.get(field, 0)).replace(",", "").replace("$", "")) + expected_val = float(str(expected.get(field, 0)).replace(",", "").replace("$", "")) + if round(found_val, 2) != round(expected_val, 2): + add_discrepancy(field, expected_val, found_val) + except: + if str(invoice_data.get(field, "")) != str(expected.get(field, "")): + add_discrepancy(field, expected.get(field, ""), invoice_data.get(field, "")) + +# === Step 6: Annotate mismatched fields using fixed coordinates === +for d in discrepancies: + field = d["field"] + if field not in FIELD_BOXES: + print(f"โ ๏ธ No coordinates found for field '{field}' โ skipping annotation.") + continue + + rect_coords = FIELD_BOXES[field] + rect = fitz.Rect(rect_coords) + expected_text = ( + f"{float(d['expected']):,.2f}" + if field in ["quantity", "rate", "expected_amount"] + else str(d["expected"]) + ) + + # Draw red bounding box + page.draw_rect(rect, color=(1, 0, 0), width=1.5) + + # Add expected value below box + page.insert_text( + (rect.x0, rect.y1 + 10), + expected_text, + fontsize=9, + color=(1, 0, 0), + ) + +pdf.save(OUTPUT_PATH) +pdf.close() + +print("\nโ Annotated invoice saved at:", OUTPUT_PATH) + +if discrepancies: + print("\nโ ๏ธ Mismatches found:") + for d in discrepancies: + print(f" - {d['field']}: expected {d['expected']}, found {d['found']}") +else: + print("\nโ No mismatches found! 
Invoice matches CSV.") diff --git a/Project/data/annotated_invoice.pdf b/Project/data/annotated_invoice.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a20e05135c4536da162a516d493e6e18128009b5 Binary files /dev/null and b/Project/data/annotated_invoice.pdf differ diff --git a/Project/data/invoices/Invoice-01.pdf b/Project/data/invoices/Invoice-01.pdf new file mode 100644 index 0000000000000000000000000000000000000000..ecfdf68222f1e900880bc15e51eaf214fa276b8c Binary files /dev/null and b/Project/data/invoices/Invoice-01.pdf differ diff --git a/Project/data/invoices/Invoice-02.pdf b/Project/data/invoices/Invoice-02.pdf new file mode 100644 index 0000000000000000000000000000000000000000..af39020d574e8806478b1ca8eae374c7d46898f8 Binary files /dev/null and b/Project/data/invoices/Invoice-02.pdf differ diff --git a/Project/data/invoices/Invoice-03.pdf b/Project/data/invoices/Invoice-03.pdf new file mode 100644 index 0000000000000000000000000000000000000000..5603fb506eb5ec9dbc834ef01ca66dd96fb5e19d Binary files /dev/null and b/Project/data/invoices/Invoice-03.pdf differ diff --git a/Project/data/invoices/Invoice-04.pdf b/Project/data/invoices/Invoice-04.pdf new file mode 100644 index 0000000000000000000000000000000000000000..7feeced3584889b6bf57fbf9bbc6bc8204b4ac1a Binary files /dev/null and b/Project/data/invoices/Invoice-04.pdf differ diff --git a/Project/data/invoices/Invoice-05.pdf b/Project/data/invoices/Invoice-05.pdf new file mode 100644 index 0000000000000000000000000000000000000000..adaf933183c519f0a456db7cca2a1ca51fedc118 Binary files /dev/null and b/Project/data/invoices/Invoice-05.pdf differ diff --git a/Project/data/invoices/Invoice-06.pdf b/Project/data/invoices/Invoice-06.pdf new file mode 100644 index 0000000000000000000000000000000000000000..26f48dd42d5b837187a204a5f4e5220d2f13b803 Binary files /dev/null and b/Project/data/invoices/Invoice-06.pdf differ diff --git a/Project/data/invoices/Invoice-07.pdf 
b/Project/data/invoices/Invoice-07.pdf new file mode 100644 index 0000000000000000000000000000000000000000..ab4f2f05613c5d30a131e8cfda681ee46cbf99e5 Binary files /dev/null and b/Project/data/invoices/Invoice-07.pdf differ diff --git a/Project/data/invoices/Invoice-08.pdf b/Project/data/invoices/Invoice-08.pdf new file mode 100644 index 0000000000000000000000000000000000000000..ad40433bbb70cbf264ffa55c816473d6bdd23ba6 Binary files /dev/null and b/Project/data/invoices/Invoice-08.pdf differ diff --git a/Project/data/invoices/Invoice-09.pdf b/Project/data/invoices/Invoice-09.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e6948d8bae8968946feaae6c85d25cbef1b4063c Binary files /dev/null and b/Project/data/invoices/Invoice-09.pdf differ diff --git a/Project/data/invoices/Invoice-10.pdf b/Project/data/invoices/Invoice-10.pdf new file mode 100644 index 0000000000000000000000000000000000000000..ced3999cee25b765cd9e8ada1811f916eb8b81e5 Binary files /dev/null and b/Project/data/invoices/Invoice-10.pdf differ diff --git a/Project/data/invoices/Invoice-11.pdf b/Project/data/invoices/Invoice-11.pdf new file mode 100644 index 0000000000000000000000000000000000000000..9385825ca388da1438aa39bdddb1dd64bcc54cf1 Binary files /dev/null and b/Project/data/invoices/Invoice-11.pdf differ diff --git a/Project/data/invoices/Invoice-12.pdf b/Project/data/invoices/Invoice-12.pdf new file mode 100644 index 0000000000000000000000000000000000000000..17d155ea17c3bba1824a0d2160e2dfbea3562e64 Binary files /dev/null and b/Project/data/invoices/Invoice-12.pdf differ diff --git a/Project/data/invoices/Invoice-13.pdf b/Project/data/invoices/Invoice-13.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a621cffdf971687b23a08cffddd689aef8bbd689 Binary files /dev/null and b/Project/data/invoices/Invoice-13.pdf differ diff --git a/Project/data/invoices/Invoice-14.pdf b/Project/data/invoices/Invoice-14.pdf new file mode 100644 index 
0000000000000000000000000000000000000000..75e2d36460d64abb542313c664fad631b0bfca7c Binary files /dev/null and b/Project/data/invoices/Invoice-14.pdf differ diff --git a/Project/data/invoices/Invoice-15.pdf b/Project/data/invoices/Invoice-15.pdf new file mode 100644 index 0000000000000000000000000000000000000000..28789650578358643d112260521d6e43ecee7504 Binary files /dev/null and b/Project/data/invoices/Invoice-15.pdf differ diff --git a/Project/data/invoices/Invoice-16.pdf b/Project/data/invoices/Invoice-16.pdf new file mode 100644 index 0000000000000000000000000000000000000000..06f36701897c8771759e617b78befd2921a6ad74 Binary files /dev/null and b/Project/data/invoices/Invoice-16.pdf differ diff --git a/Project/data/invoices/Invoice-17.pdf b/Project/data/invoices/Invoice-17.pdf new file mode 100644 index 0000000000000000000000000000000000000000..71617e61b831fa7a074fb8381f46cec5710f8a9b Binary files /dev/null and b/Project/data/invoices/Invoice-17.pdf differ diff --git a/Project/data/invoices/Invoice-18.pdf b/Project/data/invoices/Invoice-18.pdf new file mode 100644 index 0000000000000000000000000000000000000000..39e3d0ace78d0b82593cb6b3588bfd5b3a3678af Binary files /dev/null and b/Project/data/invoices/Invoice-18.pdf differ diff --git a/Project/data/invoices/Invoice-19.pdf b/Project/data/invoices/Invoice-19.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d73458b574c054b574be69524a7e0b3b710000df Binary files /dev/null and b/Project/data/invoices/Invoice-19.pdf differ diff --git a/Project/data/invoices/Invoice-20.pdf b/Project/data/invoices/Invoice-20.pdf new file mode 100644 index 0000000000000000000000000000000000000000..dc2ecdc80c912aaca920cccdbf47a5f2d58fe700 Binary files /dev/null and b/Project/data/invoices/Invoice-20.pdf differ diff --git a/Project/data/invoices/Invoice-21.pdf b/Project/data/invoices/Invoice-21.pdf new file mode 100644 index 0000000000000000000000000000000000000000..65d3410764a8142ff633219d8855d91c4943925d Binary files 
/dev/null and b/Project/data/invoices/Invoice-21.pdf differ diff --git a/Project/data/invoices/Invoice-22.pdf b/Project/data/invoices/Invoice-22.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e466db6fd8d551db07a3a3bc79664342d95c73b6 Binary files /dev/null and b/Project/data/invoices/Invoice-22.pdf differ diff --git a/Project/data/invoices/Invoice-23.pdf b/Project/data/invoices/Invoice-23.pdf new file mode 100644 index 0000000000000000000000000000000000000000..63591097449a7710ddd5a1f62f3233ab968f1593 Binary files /dev/null and b/Project/data/invoices/Invoice-23.pdf differ diff --git a/Project/data/invoices/Invoice-24.pdf b/Project/data/invoices/Invoice-24.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c273d2281738262282320df6a0a3de0b1600e865 Binary files /dev/null and b/Project/data/invoices/Invoice-24.pdf differ diff --git a/Project/data/invoices/Invoice-25.pdf b/Project/data/invoices/Invoice-25.pdf new file mode 100644 index 0000000000000000000000000000000000000000..35a6e9d60acc756a73d0bbd4087f8958c0100cb7 Binary files /dev/null and b/Project/data/invoices/Invoice-25.pdf differ diff --git a/Project/data/invoices/Invoice-26.pdf b/Project/data/invoices/Invoice-26.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d3ec4bf4986ca498f45b1730f7fcd875cae1e160 Binary files /dev/null and b/Project/data/invoices/Invoice-26.pdf differ diff --git a/Project/data/invoices/Invoice-27.pdf b/Project/data/invoices/Invoice-27.pdf new file mode 100644 index 0000000000000000000000000000000000000000..ce9a2aea421990762b0a48a557bd2bc0a4f4fee1 Binary files /dev/null and b/Project/data/invoices/Invoice-27.pdf differ diff --git a/Project/data/invoices/Invoice-28.pdf b/Project/data/invoices/Invoice-28.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4e6200886dbf354840302735c39124ec77e177c8 Binary files /dev/null and b/Project/data/invoices/Invoice-28.pdf differ diff --git 
a/Project/data/invoices/Invoice-29.pdf b/Project/data/invoices/Invoice-29.pdf new file mode 100644 index 0000000000000000000000000000000000000000..09d0d0d9d8f6d4b8f82b486ce66b6b40b659601d Binary files /dev/null and b/Project/data/invoices/Invoice-29.pdf differ diff --git a/Project/data/invoices/Invoice-30.pdf b/Project/data/invoices/Invoice-30.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c1c6e7b027cb027268c97e838e18c439399b06b7 Binary files /dev/null and b/Project/data/invoices/Invoice-30.pdf differ diff --git a/Project/data/invoices/Invoice-31.pdf b/Project/data/invoices/Invoice-31.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4a9030fe15ee4bc2720edd2ff1ecce8c6d885c65 Binary files /dev/null and b/Project/data/invoices/Invoice-31.pdf differ diff --git a/Project/data/invoices/Invoice-32.pdf b/Project/data/invoices/Invoice-32.pdf new file mode 100644 index 0000000000000000000000000000000000000000..f5f569c4b55b39f462f2168a6ba21d0cf8739f76 Binary files /dev/null and b/Project/data/invoices/Invoice-32.pdf differ diff --git a/Project/data/invoices/Invoice-33.pdf b/Project/data/invoices/Invoice-33.pdf new file mode 100644 index 0000000000000000000000000000000000000000..92736aa3e768a0c28e833e93e32f10dde085a561 Binary files /dev/null and b/Project/data/invoices/Invoice-33.pdf differ diff --git a/Project/data/invoices/Invoice-34.pdf b/Project/data/invoices/Invoice-34.pdf new file mode 100644 index 0000000000000000000000000000000000000000..894fa7f945db3d963abd91f7463717e38533a1a1 Binary files /dev/null and b/Project/data/invoices/Invoice-34.pdf differ diff --git a/Project/data/invoices/Invoice-35.pdf b/Project/data/invoices/Invoice-35.pdf new file mode 100644 index 0000000000000000000000000000000000000000..9bfc25e9eef27c1efe1863dc7e8a863620f9b5e5 Binary files /dev/null and b/Project/data/invoices/Invoice-35.pdf differ diff --git a/Project/data/invoices/Invoice-36.pdf b/Project/data/invoices/Invoice-36.pdf new file mode 
100644 index 0000000000000000000000000000000000000000..b450493b58abae5521bc423b6875a0d15ee8d7ac Binary files /dev/null and b/Project/data/invoices/Invoice-36.pdf differ diff --git a/Project/data/invoices/Invoice-37.pdf b/Project/data/invoices/Invoice-37.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c782b45a4be1dbdf5a3dc4f132537a3e56b0b28b Binary files /dev/null and b/Project/data/invoices/Invoice-37.pdf differ diff --git a/Project/data/invoices/Invoice-38.pdf b/Project/data/invoices/Invoice-38.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a688a8442c69c534f0b409c9781ee5b2bc348d53 Binary files /dev/null and b/Project/data/invoices/Invoice-38.pdf differ diff --git a/Project/data/invoices/Invoice-39.pdf b/Project/data/invoices/Invoice-39.pdf new file mode 100644 index 0000000000000000000000000000000000000000..0f40061c0296c2e90c6316436bd10573020f6f8d Binary files /dev/null and b/Project/data/invoices/Invoice-39.pdf differ diff --git a/Project/data/invoices/Invoice-40.pdf b/Project/data/invoices/Invoice-40.pdf new file mode 100644 index 0000000000000000000000000000000000000000..c5edbede858a675e3e45e144d590a10378ce8c24 Binary files /dev/null and b/Project/data/invoices/Invoice-40.pdf differ diff --git a/Project/data/invoices/Invoice-41.pdf b/Project/data/invoices/Invoice-41.pdf new file mode 100644 index 0000000000000000000000000000000000000000..04c75ac6b1998098274cd4964911b6f116d523fc Binary files /dev/null and b/Project/data/invoices/Invoice-41.pdf differ diff --git a/Project/data/invoices/Invoice-42.pdf b/Project/data/invoices/Invoice-42.pdf new file mode 100644 index 0000000000000000000000000000000000000000..8f81787e0a617a41b201d99a2aad37b5f6232747 Binary files /dev/null and b/Project/data/invoices/Invoice-42.pdf differ diff --git a/Project/data/invoices/Invoice-43.pdf b/Project/data/invoices/Invoice-43.pdf new file mode 100644 index 0000000000000000000000000000000000000000..3c0ef0c6fe119a042f0009a8c95fcf6f568c6d08 
Binary files /dev/null and b/Project/data/invoices/Invoice-43.pdf differ diff --git a/Project/data/invoices/Invoice-44.pdf b/Project/data/invoices/Invoice-44.pdf new file mode 100644 index 0000000000000000000000000000000000000000..864221ce79c3fed877686aaa45edf41828f99c4a Binary files /dev/null and b/Project/data/invoices/Invoice-44.pdf differ diff --git a/Project/data/invoices/Invoice-45.pdf b/Project/data/invoices/Invoice-45.pdf new file mode 100644 index 0000000000000000000000000000000000000000..27c2b02e52b9f4b0a97d106fae7e1e0daa3d0928 Binary files /dev/null and b/Project/data/invoices/Invoice-45.pdf differ diff --git a/Project/data/invoices/Invoice-46.pdf b/Project/data/invoices/Invoice-46.pdf new file mode 100644 index 0000000000000000000000000000000000000000..43129ec8ef92695e52c8fb4cc778712b7fbd3b69 Binary files /dev/null and b/Project/data/invoices/Invoice-46.pdf differ diff --git a/Project/data/purchase_orders.csv b/Project/data/purchase_orders.csv new file mode 100644 index 0000000000000000000000000000000000000000..89f7c57deecc26d33a8caf9e96cfb5c6cb71e665 --- /dev/null +++ b/Project/data/purchase_orders.csv @@ -0,0 +1,26 @@ +invoice_number,order_id,customer_name,item_name,quantity,rate,expected_amount +14021,ES-2025-BE11335139-41340,Bill Eplett,"Canon Wireless Fax, Laser Copiers, Technology, TEC-CO-3710",5,1893.30,9466.5 +6459,MX-2025-DK1298539-41339,Darren Koutras,"KitchenAid Stove, Silver Appliances, Office Supplies, OFF-AP-4966",4,1903.0,7612.0 +24450,IN-2025-GM146807-41338,Greg Matthias,"Hon Executive Leather Armchair, Adjustable Chairs, Furniture, FUR-CH-4654",1,409.24,409.24 +14130,ES-2025-LT1711045-41340,Liz Thompson,"Safco 3-Shelf Cabinet, Mobile Bookcases, Furniture, FUR-BO-5746",8,1079.57,8636.56 +24429,IN-2025-MZ173357-41340,Maria Zettner,"Hon Rocking Chair, Black Chairs, Furniture, FUR-CH-4682",4,461.48,1845.94 +10963,ES-2025-SW202458-41340,Scot Wooten,"Hewlett Fax Machine, Color Copiers, Technology, TEC-CO-4575",3,1285.44,3856.32 
+22091,IN-2025-GB145307-41338,George Bell,"Safco Classic Bookcase, Metal Bookcases, Furniture, FUR-BO-5760",9,2756.94,24812.46 +36552,CA-2025-AH10195140-41338,Alan Haines,"36X48 HARDFLOOR CHAIRMAT Furnishings, Furniture, FUR-FU-2864",2,33.57,67.14 +36551,CA-2025-AH10195140-41338,Alan Haines,"Situations Contoured Folding Chairs, 4/Set Chairs, Furniture, FUR-CH-6016",2,99.37,198.74 +9063,MX-2025-PO1885051-41338,Patrick O'Brill,"Novimex Executive Leather Armchair, Adjustable Chairs, Furniture, FUR-CH-5378",2,607.36,1214.72 +28557,IN-2025-BT1153058-41331,Bradley Talbott,"Barricks Conference Table, Fully Assembled Tables, Furniture, FUR-TA-3344",3,5451.3,16353.9 +11001,ES-2025-RA19285120-41335,Ralph Arnett,"StarTech Inkjet, Wireless Machines, Technology, TEC-MA-6142",3,815.1,2445.31 +11000,ES-2025-RA19285120-41335,Ralph Arnett,"Canon Fax and Copier, Laser Copiers, Technology, TEC-CO-3685",11,2104.74,23152.14 +18509,ES-2025-JC15340139-41334,Jasper Cacioppo,"Cisco Audio Dock, Full Size Phones, Technology, TEC-PH-3785",4,733.44,2933.76 +49650,SG-2025-EH4005111-41331,Erica Hernandez,"Brother Fax Machine, High-Speed Copiers, Technology, TEC-CO-3597",2,633.48,1266.96 +22988,IN-2025-RE194507-41329,Richard Eichhorn,"Safco Library with Doors, Metal Bookcases, Furniture, FUR-BO-5785",2,700.65,1401.3 +22999,IN-2025-BP1123058-41329,Benjamin Patterson,"Office Star Executive Leather Armchair, Red Chairs, Furniture, FUR-CH-5443",4,1878.72,7514.88 +22711,IN-2025-DL1333058-41328,Denise Leinenbach,"Samsung Smart Phone, Full Size Phones, Technology, TEC-PH-5840",5,3187.2,15936.0 +11502,ES-2025-PB19210139-41327,Phillip Breyer,"Enermax Keyboard, Erganomic Accessories, Technology, TEC-AC-4156",6,487.62,2925.72 +49674,UP-2025-AH10030137-41325,Aaron Hawkins,"Hon Rocking Chair, Black Chairs, Furniture, FUR-CH-4682",8,1025.52,8204.16 +35513,CA-2025-AB10600140-41322,Ann Blume,"Smead Alpha-Z Color-Coded Second Alphabetical Labels and Starter Set Labels, Office Supplies, OFF-LA-6022",3,9.24,27.72 
+20145,ES-2025-BD11560139-41327,Brendan Dodson,"Hon Executive Leather Armchair, Black Chairs, Furniture, FUR-CH-4655",5,2285.70,11428.50 +18517,ES-2025-FO1430545-41322,Frank Olsen,"Belkin Numeric Keypad, Bluetooth Accessories, Technology, TEC-AC-3396",10,579.60,5796.00 +28863,IN-2025-ME1801027-41322,Michelle Ellison,"Office Star Executive Leather Armchair, Adjustable Chairs, Furniture, FUR-CH-5441",1,930.00,930.00 +23981,IN-2025-NG1843058-41319,Nathan Gelder,"SAFCO Chairmat, Set of Two Chairs, Furniture, FUR-CH-5759",4,248.28,993.12 \ No newline at end of file diff --git a/Project/graph.py b/Project/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..f609e3500b6ff6100054e444a00eb8bd89e8abb8 --- /dev/null +++ b/Project/graph.py @@ -0,0 +1,487 @@ +"""LangGraph workflow orchestrator""" +# TODO: Implement graph workflow + +import asyncio +import uuid # extra import +from typing import Dict, Any, List, Optional, Literal +from datetime import datetime +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.memory import MemorySaver + +from state import ( + InvoiceProcessingState, ProcessingStatus, ValidationStatus, + RiskLevel, PaymentStatus, WORKFLOW_CONFIGS +) +from agents.base_agent import agent_registry +from agents.document_agent import DocumentAgent +from agents.validation_agent import ValidationAgent +from agents.risk_agent import RiskAgent +from agents.payment_agent import PaymentAgent +from agents.audit_agent import AuditAgent +from agents.escalation_agent import EscalationAgent +from utils.logger import StructuredLogger + + +class InvoiceProcessingGraph: + """Graph orchestrator""" + + def __init__(self, config: Dict[str, Any] = None): + self.logger = StructuredLogger("InvoiceProcessingGraph") + self.config = config or {} + #Simple in-memory store for process states (process_id -> InvoiceProcessingState) + self._process_store: Dict[str, InvoiceProcessingState] = {} + #Register and initialize agents + 
self._initialize_agents() + try: + self.graph = self._create_workflow_graph() + self.compiled_graph = self.graph.compile(checkpointer=MemorySaver()) + self.logger.logger.info("InvoiceProcessingGraph initialized successfully with compiled graph.") + except Exception as e: + self.logger.logger.warning(f"Failed to fully build graph nodes: {e} โ exposing empty StateGraph") + self.graph = StateGraph("invoice_processing_graph_fallback") + + def _initialize_agents(self): + """Instantiate and register agent instances in the global registry""" + #create agent instances (idempotent - replace if already registered) + agents = [ + DocumentAgent(), + ValidationAgent(), + RiskAgent(), + PaymentAgent(), + AuditAgent(), + EscalationAgent(), + ] + for agent in agents: + agent_registry.register(agent) + self.logger.logger.info(f"Registered agents: {agent_registry.list_agents()}") + + def _create_workflow_graph(self) -> StateGraph: + """ + Build a LangGraph StateGraph with conditional routing: + Each node executes its corresponding agent and determines + the next node based on runtime logic (risk, validation, etc.) 
+ """ + + graph = StateGraph("invoice_processing_graph") + + # NODE DEFINITIONS + async def node_document(state: InvoiceProcessingState): + state = await self._document_agent_node(state) + next_node = self._route_after_document(state) + return next_node, state + + async def node_validation(state: InvoiceProcessingState): + state = await self._validation_agent_node(state) + next_node = self._route_after_validation(state) + return next_node, state + + async def node_risk(state: InvoiceProcessingState): + state = await self._risk_agent_node(state) + next_node = self._route_after_risk(state) + return next_node, state + + async def node_payment(state: InvoiceProcessingState): + state = await self._payment_agent_node(state) + next_node = self._route_after_payment(state) + return next_node, state + + async def node_audit(state: InvoiceProcessingState): + state = await self._audit_agent_node(state) + next_node = self._route_after_audit(state) + return next_node, state + + async def node_escalation(state: InvoiceProcessingState): + state = await self._escalation_agent_node(state) + next_node = self._route_after_escalation(state) + return next_node, state + + async def node_human_review(state: InvoiceProcessingState): + state = await self._human_review_node(state) + next_node = self._route_after_human_review(state) + return next_node, state + + async def node_end(state: InvoiceProcessingState): + self.logger.logger.info(f"Invoice {state.invoice_id} completed at {state.updated_at}") + return "end", state + + # REGISTER NODES + for name, func in { + "document": node_document, + "validation": node_validation, + "risk": node_risk, + "payment": node_payment, + "audit": node_audit, + "escalation": node_escalation, + "human_review": node_human_review, + "end": node_end, + }.items(): + try: + graph.add_node(name, func) + except Exception: + # fallback if add_node signature differs + setattr(graph, name, func) + + # ADD EDGES (DEFAULT PATHS) + try: + graph.add_edge("document", 
"validation") + graph.add_edge("validation", "risk") + graph.add_edge("risk", "payment") + graph.add_edge("payment", "audit") + graph.add_edge("audit", "end") + # Alternative / exception flows + graph.add_edge("document", "escalation") + graph.add_edge("validation", "escalation") + graph.add_edge("risk", "escalation") + graph.add_edge("escalation", "human_review") + graph.add_edge("human_review", "end") + + graph.set_entry_point("document") + except Exception as ex: + self.logger.logger.warning(f"Edge registration failed: {ex}") + + self.logger.logger.info("Conditional workflow graph built successfully.") + return graph + + + async def _document_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + agent: DocumentAgent = agent_registry.get("document_agent") + print("agent from doc", agent) + if not agent: + agent = DocumentAgent() + agent_registry.register(agent) + print("Registry instance ID in graph:", id(agent_registry)) + + return await agent.run(state, workflow_type) + + async def _validation_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + agent: ValidationAgent = agent_registry.get("validation_agent") + print("agent from val", agent) + if not agent: + agent = ValidationAgent() + agent_registry.register(agent) + return await agent.run(state, workflow_type) + + async def _risk_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + agent: RiskAgent = agent_registry.get("risk_agent") + if not agent: + agent = RiskAgent() + agent_registry.register(agent) + return await agent.run(state, workflow_type) + + async def _payment_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + agent: PaymentAgent = agent_registry.get("payment_agent") + if not agent: + agent = PaymentAgent() + agent_registry.register(agent) + return await agent.run(state, workflow_type) + + async def _audit_agent_node(self, state: 
InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + agent: AuditAgent = agent_registry.get("audit_agent") + if not agent: + agent = AuditAgent() + agent_registry.register(agent) + return await agent.run(state, workflow_type) + + async def _escalation_agent_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + agent: EscalationAgent = agent_registry.get("escalation_agent") + if not agent: + agent = EscalationAgent() + agent_registry.register(agent) + return await agent.run(state, workflow_type) + + async def _human_review_node(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState: + #Reusing escalation agent's human-in-the-loop or simply marking for manual review + agent: EscalationAgent = agent_registry.get("escalation_agent") + if not agent: + agent = EscalationAgent() + agent_registry.register(agent) + return await agent.run(state, workflow_type) + + def _route_after_document(self, state: InvoiceProcessingState) -> Literal["validation", "escalation", "end"]: + """Route decision after document extraction""" + #if extraction yielded no invoice_data or low confidence -> escalate + if not state.invoice_data: + return "escalation" + #if extraction confidence exists and is low -> escalate + conf = getattr(state.invoice_data, "extraction_confidence", None) + if conf is not None and conf<0.6: + return "escalation" + return "validation" + + + def _route_after_validation(self, state: InvoiceProcessingState) -> Literal["risk", "escalation", "end"]: + """Route decision after document validation""" + vr = state.validation_result + if not vr: + return "escalation" + #if missing PO or invalid -> escalate + try: + status = vr.validation_status + #ValidationStatus maybe enum or str + if isinstance(status,ValidationStatus): + status_val = status + else: + status_val = ValidationStatus(status) if isinstance(status,str) else None + if status_val == ValidationStatus.NO_MATCH or status_val == 
ValidationStatus.PARTIAL_MATCH and (not vr.amount_match): + return "escalation" + except Exception: + #fallback: if discrepancies exist -> escalation + if vr and getattr(vr,"discrepancies",None): + return "escalation" + return "risk" + + def _route_after_risk(self, state: InvoiceProcessingState) -> Literal["payment", "escalation", "human_review", "end"]: + """Route decision after risk assessment""" + ra = state.risk_assessment + if not ra: + return "escalation" + #ra.risk_level is an enum RiskLevel + rl = getattr(ra,"risk_level",None) + #handle strings or enums + rl_val = rl.value if hasattr(rl,"value") else str(rl).lower() + try: + if rl_val in (RiskLevel.CRITICAL.value, RiskLevel.HIGH.value): + #For critical-> human review; for high->escalate + if rl_val == RiskLevel.CRITICAL.value: + return "human_review" + return "escalation" + else: + #low or medium -> payment + return "payment" + except Exception: + return "payment" + + def _route_after_payment(self, state: InvoiceProcessingState) -> Literal["audit", "escalation", "end"]: + pd = getattr(state,"payment_decision",None) + if not pd: + return "escalation" + #If approved (or scheduled) -> audit + try: + status = pd.payment_status + #Accept enum or str + status_val = status if isinstance(status,str) else getattr(status,"value",str(status)) + if status_val in (PaymentStatus.APPROVED.value, PaymentStatus.SCHEDULED.value, PaymentStatus.PENDING_APPROVAL.value): + return "audit" + else: + return "escalation" + except Exception: + return "audit" + + def _route_after_audit(self, state: InvoiceProcessingState) -> Literal["escalation", "end"]: + cr = getattr(state, "compliance_report",None) + if not cr: + return "end" + #If any compliance issues ->escalate + issues = cr.get("issues",{}) if isinstance(cr, dict) else {} + has_issues = any(issues.get(k) for k in issues) + return "escalation" if has_issues else "end" + + async def _handle_escalation_chain(self, state: "InvoiceProcessingState", workflow_type): + """Common 
handler for escalation โ human review โ complete""" + state = await self._escalation_agent_node(state, workflow_type) + self._process_store[state.process_id] = state + state = await self._human_review_node(state, workflow_type) + state.overall_status = ProcessingStatus.COMPLETED + self._process_store[state.process_id] = state + return state + + async def process_invoice(self, file_name: str, workflow_type: str = "standard", + config: Dict[str, Any] = None) -> InvoiceProcessingState: + """ + Orchestrate processing for a single invoice file. + Supports 3 workflow types: standard, high_value, and expedited. + """ + process_id = f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" + initial_state = InvoiceProcessingState( + process_id=process_id, + file_name=file_name, + overall_status=ProcessingStatus.PENDING, + current_agent=None, + workflow_type=workflow_type, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + ) + + self._process_store[process_id] = initial_state + start_ts = datetime.utcnow() + state = initial_state + worked_agents = [] + try: + # STEP 1๏ธ Document Extraction + state = await self._document_agent_node(state, workflow_type) + self._process_store[process_id] = state + route = self._route_after_document(state) + print("state agent anme ::::::::::::::", state.agent_name) + worked_agents.append(state.agent_name) + if route == "escalation": + state = await self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # ---- Workflow branching ---- + if workflow_type == "expedited": + # Fast lane - skip validation if AI confidence is high + # if getattr(state, "extraction_confidence", 0.0) < 0.85: + state = await self._validation_agent_node(state, workflow_type) + self._process_store[process_id] = state + worked_agents.append(state.agent_name) + route = self._route_after_validation(state) + if route == "escalation": + state = await 
self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # Directly go to Payment, minimal audit + state = await self._payment_agent_node(state, workflow_type) + worked_agents.append(state.agent_name) + self._process_store[process_id] = state + if getattr(state, "payment_decision", {}).decision == "auto-pay": + state = await self._audit_agent_node(state, workflow_type) + worked_agents.append(state.agent_name) + self._process_store[process_id] = state + + elif workflow_type == "high_value": + # 2๏ธ Validation (twice for accuracy) + state = await self._validation_agent_node(state, workflow_type) + self._process_store[process_id] = state + state = await self._validation_agent_node(state, workflow_type) + worked_agents.append(state.agent_name) + route = self._route_after_validation(state) + if route == "escalation": + state = await self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # 3๏ธ Risk + state = await self._risk_agent_node(state, workflow_type) + worked_agents.append(state.agent_name) + self._process_store[process_id] = state + route = self._route_after_risk(state) + if route in ["escalation", "human_review"]: + state = await self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # 4 Audit + state = await self._audit_agent_node(state, workflow_type) + worked_agents.append(state.agent_name) + self._process_store[process_id] = state + + # 5 Mandatory human review for high-value invoices + state = await self._human_review_node(state, workflow_type) + worked_agents.append(state.agent_name) + self._process_store[process_id] = state + + else: # STANDARD workflow + # 2๏ธ Validation + state = await self._validation_agent_node(state, workflow_type) + self._process_store[process_id] = state + worked_agents.append(state.agent_name) + route = 
self._route_after_validation(state) + if route == "escalation": + state = await self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # 3๏ธ Risk + state = await self._risk_agent_node(state, workflow_type) + self._process_store[process_id] = state + worked_agents.append(state.agent_name) + route = self._route_after_risk(state) + if route in ["escalation", "human_review"]: + state = await self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # 4๏ธ Payment + state = await self._payment_agent_node(state, workflow_type) + self._process_store[process_id] = state + worked_agents.append(state.agent_name) + route = self._route_after_payment(state) + if route == "escalation": + state = await self._handle_escalation_chain(state, workflow_type) + worked_agents.append(state.agent_name) + return state, worked_agents + + # 5๏ธ Audit + state = await self._audit_agent_node(state, workflow_type) + worked_agents.append(state.agent_name) + self._process_store[process_id] = state + + # Success completion + state.overall_status = ProcessingStatus.COMPLETED + state.updated_at = datetime.utcnow() + elapsed = (datetime.utcnow() - start_ts).total_seconds() + self.logger.logger.info(f"Process {process_id} ({workflow_type}) completed in {elapsed:.2f}s") + self._process_store[process_id] = state + # print("from graph worked agents::::", worked_agents) + return state, worked_agents + + except Exception as e: + self.logger.logger.exception(f"Error processing invoice {file_name}: {e}") + state.overall_status = ProcessingStatus.FAILED + self._process_store[process_id] = state + return state, worked_agents + + + # async def process_batch(self, file_names: List[str], workflow_type: str = "standard", + # max_concurrent: int = 5) -> List[InvoiceProcessingState]: + # """Process a batch of files with limit concurrency""" + # sem = 
asyncio.Semaphore(max_concurrent) + # results: List[InvoiceProcessingState] = [] + + # async def _worker(fn: str): + # async with sem: + # return await self.process_invoice(fn, workflow_type=workflow_type) + + # tasks = [asyncio.create_task(_worker(f)) for f in file_names] + # completed = await asyncio.gather(*tasks) + # for st in completed: + # results.append(st) + # return results + async def process_batch(self, file_names: List[str], workflow_type: str = "standard", + max_concurrent: int = 5): + sem = asyncio.Semaphore(max_concurrent) + results = [] # will store: {"state": ..., "worked_agents": [...]} + + async def _worker(fn: str): + async with sem: + return await self.process_invoice(fn, workflow_type=workflow_type) + + tasks = [asyncio.create_task(_worker(f)) for f in file_names] + completed = await asyncio.gather(*tasks) + + for result in completed: + state, worked_agents = result # unpack the tuple + results.append({ + "state": state, + "worked_agents": worked_agents + }) + + return results + + + async def get_workflow_status(self, process_id: str) -> Optional[Dict[str, Any]]: + """Return the stored workflow status dictionary for a given process_id""" + state = self._process_store.get(process_id) + if not state: + return None + return {"process_id":process_id, "status":state.overall_status, "updated_at": getattr(state,"updated_at",None), "state":state.model_dump()} + + async def health_check(self) -> Dict[str, Any]: + """Aggregate health check across agents and the orchestrator itself""" + agents_health = await agent_registry.health_check_all() + return {"orchestrator":"Healthy","timestamp":datetime.utcnow().isoformat(),"agent":agents_health} + + def _extract_final_state(self, result, initial_state: InvoiceProcessingState) -> InvoiceProcessingState: + """Compatibility helper (returns invoice processing state)""" + return result + + +invoice_workflow: Optional[InvoiceProcessingGraph] = None + +def get_workflow(config: Dict[str, Any] = None) -> 
InvoiceProcessingGraph: + global invoice_workflow + if invoice_workflow is None: + invoice_workflow = InvoiceProcessingGraph(config=config) + return invoice_workflow \ No newline at end of file diff --git a/Project/main.py b/Project/main.py new file mode 100644 index 0000000000000000000000000000000000000000..b56ba8f116087343b3f6b107e0c1c5dd94660f5a --- /dev/null +++ b/Project/main.py @@ -0,0 +1,1312 @@ +"""Main Streamlit UI for invoice processing""" +# TODO: Build Streamlit dashboard +import os +import asyncio +import pandas as pd +import streamlit as st +from datetime import datetime +import plotly.express as px +import plotly.graph_objects as go +from typing import Dict, Any, List +from enum import Enum +import fitz # PyMuPDF +import re + +from graph import get_workflow +from state import InvoiceProcessingState, ProcessingStatus, ValidationStatus, RiskLevel, PaymentStatus +from utils.logger import setup_logging, get_logger + +import json +import google.generativeai as genai +from agents.smart_explainer_agent import SmartExplainerAgent +from agents.insights_agent import InsightAgent +from agents.forecast_agent import ForecastAgent + + +# Logging Setup +setup_logging() +logger = get_logger("InvoiceProcessingApp") + +def make_arrow_safe(df: pd.DataFrame) -> pd.DataFrame: + """ + Convert any DataFrame to be Streamlit/Arrow compatible: + - Converts Enums to string values + - Replaces None/NaN with 'Not applicable' + - Ensures all columns are strings (avoids Arrow conversion errors) + - Capitalizes column headers + """ + if df.empty: + return df + + # Convert Enums to strings + df = df.applymap(lambda x: x.value if isinstance(x, Enum) else x) + + # Replace None/NaN and make all values string + df = df.fillna("Not applicable").astype(str) + + # Capitalize column names nicely + df.columns = [col.capitalize() for col in df.columns] + return df + +import ast +import re +def parse_escalation_details(s): + if isinstance(s, dict): + return s + if not isinstance(s, str) or not 
s.strip(): + return {} + + # Convert datetime.datetime(YYYY,MM,DD,HH,MM,SS) โ "YYYY-MM-DD HH:MM:SS" + def repl(match): + parts = match.group(1).split(',') + parts = [p.strip() for p in parts] + # convert to ISO style + return f"'{parts[0]}-{parts[1]}-{parts[2]} {parts[3]}:{parts[4]}:{parts[5]}'" + + s_clean = re.sub(r"datetime\.datetime\((.*?)\)", repl, s) + + try: + return ast.literal_eval(s_clean) + except: + return {} + +def serialize_state(state): + # Pydantic v2 + if hasattr(state, "model_dump"): + return state.model_dump() + + # Pydantic v1 fallback + if hasattr(state, "dict"): + return state.dict() + + # Normal python object + if hasattr(state, "__dict__"): + return state.__dict__ + + # Already a dict + if isinstance(state, dict): + return state + + # string, int, None, etc + return {"value": state} + + +class InvoiceProcessingApp: + """Main application class for AI Invoice Processing Dashboard""" + + def __init__(self): + self.workflow = None + self.initialize_session_state() + self.initialize_workflow() + self.smart_explainer = SmartExplainerAgent() + self.insights = InsightAgent() + self.forecast = ForecastAgent() + self.gemini_api_key = os.getenv("GEMINI_API_KEY_7") + + # INITIALIZATION + def initialize_session_state(self): + if "selected_files" not in st.session_state: + st.session_state.selected_files = [] + if "results" not in st.session_state: + st.session_state.results = [] + if "last_run" not in st.session_state: + st.session_state.last_run = None + if "workflow_type" not in st.session_state: + st.session_state.workflow_type = "standard" + if "max_concurrent" not in st.session_state: + st.session_state.max_concurrent = 1 + if "annotated_pdfs" not in st.session_state: + st.session_state.annotated_pdfs = {} + # if "priority_level" not in st.session_state: + # st.session_state.priority_level = 1 + + def initialize_workflow(self): + try: + self.workflow = get_workflow() + logger.info("Workflow initialized successfully.") + except Exception as e: + 
logger.exception("Workflow initialization failed: %s", e) + st.error("Failed to initialize workflow. Check logs for details.") + + # SIDEBAR + HEADER + def render_header(self): + st.markdown( + """ +AI-Powered Invoice Processing with Intelligent Agent Workflows
+