Spaces:
Running
Running
| """State models and enumerations""" | |
| # TODO: Define state models | |
| from __future__ import annotations | |
| import uuid # extra import | |
| from typing import Dict, List, Optional, Any, Literal | |
| from pydantic import BaseModel, Field, root_validator | |
| from datetime import datetime | |
| from enum import Enum | |
| class ProcessingStatus(str, Enum): | |
| PENDING = "pending" | |
| IN_PROGRESS = "in_progress" | |
| COMPLETED = "completed" | |
| FAILED = "failed" | |
| CANCELLED = "cancelled" | |
| class RiskLevel(str, Enum): | |
| LOW = "low" | |
| MEDIUM = "medium" | |
| HIGH = "high" | |
| CRITICAL = "critical" | |
| class PaymentStatus(str, Enum): | |
| NOT_STARTED = "not_started" | |
| PENDING_APPROVAL = "pending_approval" | |
| APPROVED = "approved" | |
| SCHEDULED = "scheduled" | |
| PAID = "paid" | |
| FAILED = "failed" | |
| class ItemDetail(BaseModel): | |
| item_id: Optional[str] = None | |
| item_name: Optional[str] = None | |
| # description: Optional[str] = None | |
| quantity: int = Field(..., ge=0) | |
| rate: float = Field(..., ge=0.0) | |
| # total: Optional[float] = None | |
| # unit: Optional[str] = None | |
| amount: float = Field(..., ge=0.0) | |
| category: Optional[str] = None | |
| class InvoiceData(BaseModel): | |
| invoice_number: Optional[str] = None | |
| order_id: Optional[str] = None | |
| file_name: Optional[str] = None | |
| customer_name: Optional[str] = None | |
| invoice_date: Optional[datetime] = None | |
| due_date: Optional[datetime] = None | |
| currency: Optional[str] = "USD" | |
| total: Optional[float] = None | |
| # line_items: List[ItemDetail] = Field(default_factory=list) | |
| raw_text: Optional[str] = None | |
| item_details: Optional[list] = None | |
| # confidence_scores: Dict[str, float] = Field(default_factory=dict) | |
| extraction_confidence: Optional[float] = None | |
| class ValidationStatus(str, Enum): | |
| NOT_STARTED = "not_started" | |
| VALID = "valid" | |
| INVALID = "invalid" | |
| PARTIAL_MATCH = "partial_match" | |
| MISSING_PO = "missing_po" | |
| class ValidationResult(BaseModel): | |
| po_found: bool = False | |
| quantity_match: bool = False | |
| rate_match: bool = False | |
| amount_match: bool = False | |
| validation_status: ValidationStatus = ValidationStatus.NOT_STARTED | |
| validation_result: Optional[str] = None | |
| discrepencies: List[str] = Field(default_factory=list) | |
| confidence_score: Optional[float] = None | |
| # expected_amount: Optional[float] = None | |
| po_data: Optional[Dict[str, Any]] = None | |
| class RiskAssessment(BaseModel): | |
| risk_score: float = Field(0.0, ge=0.0, le=1.0) | |
| risk_level: RiskLevel = RiskLevel.LOW | |
| signals: List[str] = Field(default_factory=list) | |
| vendor_status: Optional[str] = None | |
| compliance_violations: List[str] = Field(default_factory=list) | |
| class PaymentDecision(BaseModel): | |
| decision: Optional[Literal["auto_pay", "manual_approval", "hold", "reject"]] | |
| status: PaymentStatus = PaymentStatus.NOT_STARTED | |
| scheduled_date: Optional[datetime] = None | |
| transaction_id: Optional[str] = None | |
| attempts: int = 0 | |
| reason: Optional[str] = None | |
| class AuditTrail(BaseModel): | |
| process_id: Optional[str] = None | |
| timestamp: datetime = Field(default_factory=datetime.utcnow) | |
| agent_name: str | |
| action: str | |
| details: Dict[str, Any] = Field(default_factory=dict) | |
| class Config: | |
| arbitrary_types_allowed = True | |
| class AgentMetrics(BaseModel): | |
| processed_count: int = 0 | |
| avg_latency_ms: Optional[float] = None | |
| last_run_at: Optional[datetime] = None | |
| errors: int = 0 | |
| success_rate: Optional[float] = None | |
| class InvoiceProcessingState(BaseModel): | |
| # Core identifiers | |
| process_id: str = Field( | |
| default_factory=lambda: f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" | |
| ) | |
| file_name: Optional[str] = None | |
| # Processing status | |
| overall_status: ProcessingStatus = ProcessingStatus.PENDING | |
| current_agent: Optional[str] = None | |
| workflow_type: str = "Standard" | |
| # Agent Outputs | |
| invoice_data: Optional[InvoiceData] = None | |
| validation_result: Optional[ValidationResult] = None | |
| validation_status: Optional[str] = None | |
| risk_assessment: Optional[RiskAssessment] = None | |
| payment_decision: Optional[PaymentDecision] = None | |
| approval_chain: Optional[List[Dict[str, Any]]] = None | |
| # Audit and Tracking | |
| agent_name: Optional[str] = None | |
| audit_trail: List[AuditTrail] = Field(default_factory=list) | |
| agent_metrics: Dict[str, AgentMetrics] = Field(default_factory=dict) | |
| compliance_report: Optional[Dict[str, Any]] = None | |
| audit_summary: Optional[Dict[str, Any]] = None | |
| reportable_events: Optional[List[Dict[str, Any]]] = None | |
| # Escalation | |
| escalation_required: bool = False | |
| human_review_required: bool = False | |
| escalation_details: Optional[str] = None | |
| escalation_reason: Optional[str] = None | |
| # Workflow Control | |
| retry_count: int = 0 | |
| completed_agents: List[str] = Field(default_factory=list) | |
| # Timestamps | |
| created_at: datetime = Field(default_factory=datetime.utcnow) | |
| updated_at: datetime = Field(default_factory=datetime.utcnow) | |
| # Convenience Methods | |
| def add_audit_entry(self, agent_name: str, action: str, status: Optional[ProcessingStatus] = None, details: Optional[Dict[str, Any]] = None, process_id: Optional[str] = None) -> None: | |
| entry = AuditTrail(agent_name=agent_name, action=action, status = status or self.overall_status, details=details or {}, process_id=process_id) | |
| print("entry.....", entry) | |
| print("self.audit_trail...", self.audit_trail) | |
| self.audit_trail.append(entry) | |
| self.updated_at = datetime.utcnow() | |
| def add_agent_metric(self, agent: str, processed: int = 0, latency_ms: Optional[float] = None, errors: int = 0) -> None: | |
| metrics = self.agent_metrics.get(agent) or AgentMetrics() | |
| metrics.processed_count += processed | |
| metrics.errors += errors | |
| if latency_ms is not None: | |
| if metrics.avg_latency_ms is None: | |
| metrics.avg_latency_ms = latency_ms | |
| else: | |
| metrics.avg_latency_ms = (metrics.avg_latency_ms + latency_ms) / 2.0 | |
| metrics.last_run_at = datetime.utcnow() | |
| if metrics.processed_count > 0: | |
| metrics.success_rate = max(0.0, 1.0 - (metrics.errors / max(1, metrics.processed_count))) | |
| self.agent_metrics[agent] = metrics | |
| self.updated_at = datetime.utcnow() | |
| def update_agent_metrics(self, agent_name: str, success: bool, duration_ms: float): | |
| """ | |
| Update or create performance metrics for an agent. | |
| Expected structure aligns with test_9_agent_metrics: attributes like executions, success_count, failure_count. | |
| """ | |
| # Ensure agent_metrics dict exists | |
| if self.agent_metrics is None: | |
| self.agent_metrics = {} | |
| # Get or initialize metrics object | |
| metrics = self.agent_metrics.get(agent_name) | |
| # If existing metrics is a dict or None, replace with a new AgentMetrics-like object | |
| if isinstance(metrics, dict) or metrics is None: | |
| metrics = type("DynamicMetrics", (), {})() | |
| metrics.executions = 0 | |
| metrics.successes = 0 | |
| metrics.failure_count = 0 | |
| metrics.total_duration_ms = 0.0 | |
| metrics.avg_duration_ms = 0.0 | |
| # Update fields | |
| metrics.executions += 1 | |
| if success: | |
| metrics.successes += 1 | |
| else: | |
| metrics.failure_count += 1 | |
| metrics.total_duration_ms += duration_ms | |
| metrics.avg_duration_ms = round(metrics.total_duration_ms / metrics.executions, 2) | |
| # Save back | |
| self.agent_metrics[agent_name] = metrics | |
| self.updated_at = datetime.utcnow() | |
| def mark_agent_completed(self, agent: str) -> None: | |
| if agent not in self.completed_agents: | |
| self.completed_agents.append(agent) | |
| self.updated_at = datetime.utcnow() | |
| def requires_escalation(self, risk_threshold: float = 0.6, confidence_threshold: float = 0.7) -> bool: | |
| if self.validation_result and self.validation_result.validation_status == ValidationStatus.INVALID: | |
| return True | |
| if self.risk_assessment and self.risk_assessment.risk_score >= risk_threshold: | |
| return True | |
| if self.validation_result and self.validation_result.confidence_score is not None and self.validation_result.confidence_score < confidence_threshold: | |
| return True | |
| return False | |
| def to_dict(self) -> Dict[str, Any]: | |
| return self.model_dump() | |
| def ensure_timestamps(cls, values: Dict[str, Any]) -> Dict[str, Any]: | |
| if "created_at" not in values or values.get("created_at") is None: | |
| values["created_at"] = datetime.utcnow() | |
| if "updated_at" not in values or values.get("updated_at") is None: | |
| values["updated_at"] = datetime.utcnow() | |
| return values | |
| class WorkflowConfig(BaseModel): | |
| name: str | |
| auto_approve_threshold: float = 0.3 | |
| auto_approve_amount_limit: Optional[float] = None | |
| tolerance_percent: float = 5.0 | |
| escalation_rules: Dict[str, Any] = Field(default_factory=dict) | |
| WORKFLOW_CONFIGS: Dict[str, WorkflowConfig] = { | |
| "standard": WorkflowConfig( | |
| name="standard", | |
| auto_approve_threshold=0.3, | |
| auto_approve_amount_limit=10000.0, | |
| tolerance_percent=5.0, | |
| escalation_rules={"sla_hours": 24}, | |
| ), | |
| "high_value": WorkflowConfig( | |
| name="high_value", | |
| auto_approve_threshold=0.1, | |
| auto_approve_amount_limit=5000.0, | |
| tolerance_percent=2.0, | |
| escalation_rules={"require_cfo": True, "sla_hours": 12}, | |
| ), | |
| "expedited": WorkflowConfig( | |
| name="expedited", | |
| auto_approve_threshold=0.5, | |
| auto_approve_amount_limit=5000.0, | |
| tolerance_percent=10.0, | |
| escalation_rules={"skip_manual_for_low_risk": True}, | |
| ), | |
| } | |