InvoiceAgenticAI / state.py
PARTHASAKHAPAUL
Restructure project and sync local as source of truth
9107a5d
"""State models and enumerations"""
# TODO: Define state models
from __future__ import annotations
import uuid # extra import
from typing import Dict, List, Optional, Any, Literal
from pydantic import BaseModel, Field, root_validator
from datetime import datetime
from enum import Enum
class ProcessingStatus(str, Enum):
    """Overall status of a processing run (``InvoiceProcessingState.overall_status``).

    Members subclass ``str``, so they compare equal to their raw values,
    e.g. ``ProcessingStatus.FAILED == "failed"``.
    """

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class RiskLevel(str, Enum):
    """Coarse risk bucket stored in ``RiskAssessment.risk_level``.

    Declared from least to most severe; values are plain lowercase strings.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class PaymentStatus(str, Enum):
    """Payment progress tracked in ``PaymentDecision.status``.

    ``NOT_STARTED`` is the model default; values are plain lowercase strings.
    """

    NOT_STARTED = "not_started"
    PENDING_APPROVAL = "pending_approval"
    APPROVED = "approved"
    SCHEDULED = "scheduled"
    PAID = "paid"
    FAILED = "failed"
class ItemDetail(BaseModel):
    """A single invoice line item.

    ``quantity``, ``rate`` and ``amount`` are required and must be
    non-negative; identifiers, name and category are optional.
    """

    item_id: Optional[str] = Field(default=None)
    item_name: Optional[str] = Field(default=None)
    quantity: int = Field(default=..., ge=0)
    rate: float = Field(default=..., ge=0.0)
    # NOTE(review): amount is not cross-checked against quantity * rate here.
    amount: float = Field(default=..., ge=0.0)
    category: Optional[str] = Field(default=None)
class InvoiceData(BaseModel):
    """Fields extracted from an invoice document.

    Every field is optional and defaults to ``None``, except ``currency``
    which defaults to ``"USD"``. ``item_details`` is an untyped list: its
    elements are not validated as :class:`ItemDetail`.
    """

    invoice_number: Optional[str] = Field(default=None)
    order_id: Optional[str] = Field(default=None)
    file_name: Optional[str] = Field(default=None)
    customer_name: Optional[str] = Field(default=None)
    invoice_date: Optional[datetime] = Field(default=None)
    due_date: Optional[datetime] = Field(default=None)
    currency: Optional[str] = Field(default="USD")
    total: Optional[float] = Field(default=None)
    raw_text: Optional[str] = Field(default=None)
    item_details: Optional[list] = Field(default=None)
    extraction_confidence: Optional[float] = Field(default=None)
class ValidationStatus(str, Enum):
    """Outcome recorded in ``ValidationResult.validation_status``.

    ``InvoiceProcessingState.requires_escalation`` treats ``INVALID`` as an
    escalation trigger.
    """

    NOT_STARTED = "not_started"
    VALID = "valid"
    INVALID = "invalid"
    PARTIAL_MATCH = "partial_match"
    MISSING_PO = "missing_po"
class ValidationResult(BaseModel):
    """Result of the validation step (``InvoiceProcessingState.validation_result``).

    All match flags start ``False``, ``validation_status`` starts at
    ``ValidationStatus.NOT_STARTED`` and the optional fields start ``None``.
    """

    po_found: bool = Field(default=False)
    quantity_match: bool = Field(default=False)
    rate_match: bool = Field(default=False)
    amount_match: bool = Field(default=False)
    validation_status: ValidationStatus = Field(default=ValidationStatus.NOT_STARTED)
    validation_result: Optional[str] = Field(default=None)
    # The field name is misspelled ("discrepencies"). It is part of the
    # model's public/serialised shape, so it is deliberately kept as-is.
    discrepencies: List[str] = Field(default_factory=list)
    confidence_score: Optional[float] = Field(default=None)
    po_data: Optional[Dict[str, Any]] = Field(default=None)
class RiskAssessment(BaseModel):
    """Risk evaluation for an invoice (``InvoiceProcessingState.risk_assessment``).

    ``risk_score`` is constrained to the closed range [0.0, 1.0] and
    defaults to 0.0; ``risk_level`` defaults to ``RiskLevel.LOW``.
    """

    risk_score: float = Field(default=0.0, ge=0.0, le=1.0)
    risk_level: RiskLevel = Field(default=RiskLevel.LOW)
    signals: List[str] = Field(default_factory=list)
    vendor_status: Optional[str] = Field(default=None)
    compliance_violations: List[str] = Field(default_factory=list)
class PaymentDecision(BaseModel):
    """Payment outcome for an invoice (``InvoiceProcessingState.payment_decision``).

    Attributes:
        decision: One of ``"auto_pay"``, ``"manual_approval"``, ``"hold"``,
            ``"reject"``, or ``None`` if no decision has been made yet.
        status: Payment progress; defaults to ``PaymentStatus.NOT_STARTED``.
        scheduled_date: When the payment is scheduled, if known.
        transaction_id: Identifier of the executed payment, if any.
        attempts: Number of payment attempts made so far.
        reason: Free-text justification for the decision.
    """

    # Explicit ``= None``: in Pydantic v2 an ``Optional[...]`` annotation
    # without a default is *required*, which made ``PaymentDecision()`` fail
    # even though every other field has a default.
    decision: Optional[Literal["auto_pay", "manual_approval", "hold", "reject"]] = None
    status: PaymentStatus = PaymentStatus.NOT_STARTED
    scheduled_date: Optional[datetime] = None
    transaction_id: Optional[str] = None
    attempts: int = 0
    reason: Optional[str] = None
class AuditTrail(BaseModel):
    """A single audit-log entry, as appended to ``InvoiceProcessingState.audit_trail``.

    Attributes:
        process_id: Id of the processing run the entry belongs to, if set.
        timestamp: Creation time from ``datetime.utcnow`` (naive UTC).
        agent_name: Name of the agent that recorded the entry.
        action: Short description of the recorded action.
        details: Arbitrary structured data attached to the entry.
        status: Processing status at the time of the entry, if provided.
    """

    process_id: Optional[str] = None
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    agent_name: str
    action: str
    details: Dict[str, Any] = Field(default_factory=dict)
    # ``InvoiceProcessingState.add_audit_entry`` passes ``status=``; without
    # a declared field, Pydantic's default ``extra="ignore"`` silently
    # discarded it. Optional so existing constructions remain valid.
    status: Optional[ProcessingStatus] = None

    class Config:
        arbitrary_types_allowed = True
class AgentMetrics(BaseModel):
    """Aggregated counters for one agent, maintained by ``add_agent_metric``.

    Counters start at zero; the derived/timestamp fields start as ``None``
    until the first update populates them.
    """

    processed_count: int = Field(default=0)
    avg_latency_ms: Optional[float] = Field(default=None)
    last_run_at: Optional[datetime] = Field(default=None)
    errors: int = Field(default=0)
    success_rate: Optional[float] = Field(default=None)
class InvoiceProcessingState(BaseModel):
    """State carried through the invoice-processing workflow.

    Groups identifiers, per-agent outputs, audit/metrics data, escalation
    flags and timestamps. The helper methods mutate the instance in place
    and refresh ``updated_at``.
    """

    # Core identifiers
    # Default id has the form "proc_YYYYmmdd_HHMMSS_<6 hex chars>" (UTC clock).
    process_id: str = Field(
        default_factory=lambda: f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    )
    file_name: Optional[str] = None

    # Processing status
    overall_status: ProcessingStatus = ProcessingStatus.PENDING
    current_agent: Optional[str] = None
    # NOTE(review): WORKFLOW_CONFIGS is keyed by lowercase names ("standard",
    # ...), so this capitalised default will not match a direct key lookup.
    workflow_type: str = "Standard"

    # Agent outputs
    invoice_data: Optional[InvoiceData] = None
    validation_result: Optional[ValidationResult] = None
    validation_status: Optional[str] = None
    risk_assessment: Optional[RiskAssessment] = None
    payment_decision: Optional[PaymentDecision] = None
    approval_chain: Optional[List[Dict[str, Any]]] = None

    # Audit and tracking
    agent_name: Optional[str] = None
    audit_trail: List[AuditTrail] = Field(default_factory=list)
    # Values are AgentMetrics (from add_agent_metric) or plain attribute
    # objects written by update_agent_metrics; both methods accept either.
    agent_metrics: Dict[str, AgentMetrics] = Field(default_factory=dict)
    compliance_report: Optional[Dict[str, Any]] = None
    audit_summary: Optional[Dict[str, Any]] = None
    reportable_events: Optional[List[Dict[str, Any]]] = None

    # Escalation
    escalation_required: bool = False
    human_review_required: bool = False
    escalation_details: Optional[str] = None
    escalation_reason: Optional[str] = None

    # Workflow control
    retry_count: int = 0
    completed_agents: List[str] = Field(default_factory=list)

    # Timestamps (naive UTC, from datetime.utcnow)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    # Convenience methods
    def add_audit_entry(
        self,
        agent_name: str,
        action: str,
        status: Optional[ProcessingStatus] = None,
        details: Optional[Dict[str, Any]] = None,
        process_id: Optional[str] = None,
    ) -> None:
        """Append an :class:`AuditTrail` entry and refresh ``updated_at``.

        Args:
            agent_name: Name of the agent recording the entry.
            action: Short description of what happened.
            status: Status to record; defaults to ``overall_status``.
            details: Extra structured data; defaults to an empty dict.
            process_id: Id to tag the entry with; defaults to this state's
                ``process_id``.
        """
        entry = AuditTrail(
            agent_name=agent_name,
            action=action,
            # Stored only if AuditTrail declares a ``status`` field; Pydantic
            # ignores unknown keyword arguments by default.
            status=status or self.overall_status,
            details=details or {},
            process_id=process_id if process_id is not None else self.process_id,
        )
        self.audit_trail.append(entry)
        self.updated_at = datetime.utcnow()

    def add_agent_metric(self, agent: str, processed: int = 0, latency_ms: Optional[float] = None, errors: int = 0) -> None:
        """Accumulate throughput/error/latency counters for ``agent``.

        Args:
            agent: Agent name used as the ``agent_metrics`` key.
            processed: Number of items to add to ``processed_count``.
            latency_ms: Latency sample to fold into ``avg_latency_ms``.
            errors: Number of errors to add to ``errors``.
        """
        metrics = self.agent_metrics.get(agent) or AgentMetrics()
        # update_agent_metrics may have stored a plain attribute object that
        # lacks these fields, so read them with defaults instead of ``+=``.
        metrics.processed_count = getattr(metrics, "processed_count", 0) + processed
        metrics.errors = getattr(metrics, "errors", 0) + errors
        if latency_ms is not None:
            previous = getattr(metrics, "avg_latency_ms", None)
            # Blends 50/50 with the previous value (an exponentially weighted
            # average), not the arithmetic mean of all samples.
            metrics.avg_latency_ms = latency_ms if previous is None else (previous + latency_ms) / 2.0
        metrics.last_run_at = datetime.utcnow()
        if metrics.processed_count > 0:
            # Clamped at 0.0 in case errors exceed processed items.
            metrics.success_rate = max(0.0, 1.0 - (metrics.errors / metrics.processed_count))
        self.agent_metrics[agent] = metrics
        self.updated_at = datetime.utcnow()

    def update_agent_metrics(self, agent_name: str, success: bool, duration_ms: float) -> None:
        """
        Record one execution of ``agent_name`` in ``agent_metrics``.

        Expected structure aligns with test_9_agent_metrics: the stored object
        exposes ``executions``, ``successes`` (mirrored as ``success_count``),
        ``failure_count``, ``total_duration_ms`` and ``avg_duration_ms``
        (mean duration, rounded to 2 decimals). Because AgentMetrics declares
        none of these fields, a plain ``types.SimpleNamespace`` is stored.

        Args:
            agent_name: Agent name used as the ``agent_metrics`` key.
            success: Whether this execution succeeded.
            duration_ms: Duration of this execution in milliseconds.
        """
        import types

        if self.agent_metrics is None:
            self.agent_metrics = {}
        metrics = self.agent_metrics.get(agent_name)
        if not hasattr(metrics, "executions"):
            # No execution counters yet (None, a dict, or an AgentMetrics).
            # Carry over AgentMetrics fields so add_agent_metric data is kept
            # instead of crashing on the missing attributes below.
            carried = metrics.model_dump() if isinstance(metrics, AgentMetrics) else {}
            metrics = types.SimpleNamespace(
                **carried,
                executions=0,
                successes=0,
                success_count=0,
                failure_count=0,
                total_duration_ms=0.0,
                avg_duration_ms=0.0,
            )
        metrics.executions += 1
        if success:
            metrics.successes += 1
        else:
            metrics.failure_count += 1
        metrics.success_count = metrics.successes
        metrics.total_duration_ms += duration_ms
        metrics.avg_duration_ms = round(metrics.total_duration_ms / metrics.executions, 2)
        self.agent_metrics[agent_name] = metrics
        self.updated_at = datetime.utcnow()

    def mark_agent_completed(self, agent: str) -> None:
        """Add ``agent`` to ``completed_agents`` once; always refreshes ``updated_at``."""
        if agent not in self.completed_agents:
            self.completed_agents.append(agent)
        self.updated_at = datetime.utcnow()

    def requires_escalation(self, risk_threshold: float = 0.6, confidence_threshold: float = 0.7) -> bool:
        """Return True when the current results warrant escalation.

        Escalates if validation is ``ValidationStatus.INVALID``, if
        ``risk_score >= risk_threshold``, or if a validation confidence score
        is present and below ``confidence_threshold``. Missing validation or
        risk results never trigger escalation on their own.
        """
        if self.validation_result and self.validation_result.validation_status == ValidationStatus.INVALID:
            return True
        if self.risk_assessment and self.risk_assessment.risk_score >= risk_threshold:
            return True
        if self.validation_result and self.validation_result.confidence_score is not None and self.validation_result.confidence_score < confidence_threshold:
            return True
        return False

    def to_dict(self) -> Dict[str, Any]:
        """Return ``model_dump()`` of the state.

        NOTE(review): entries written by update_agent_metrics are not
        AgentMetrics instances; Pydantic may emit serialization warnings.
        """
        return self.model_dump()

    @root_validator(pre=True)
    def ensure_timestamps(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill missing or ``None`` ``created_at`` / ``updated_at`` before validation.

        Works on a shallow copy so the caller's input mapping is not mutated.
        Non-dict input (possible in Pydantic v2 "before" mode) is passed
        through unchanged for Pydantic to handle.
        """
        if not isinstance(values, dict):
            return values
        values = dict(values)
        # One clock read so both timestamps agree on a freshly created state.
        now = datetime.utcnow()
        for key in ("created_at", "updated_at"):
            if values.get(key) is None:
                values[key] = now
        return values
class WorkflowConfig(BaseModel):
    """Tunable parameters for one named workflow preset.

    Only ``name`` is required. ``tolerance_percent`` is expressed in percent
    (5.0 means 5%); presumably it bounds validation mismatches — confirm
    against the consuming agent.
    """

    name: str
    auto_approve_threshold: float = Field(default=0.3)
    auto_approve_amount_limit: Optional[float] = Field(default=None)
    tolerance_percent: float = Field(default=5.0)
    escalation_rules: Dict[str, Any] = Field(default_factory=dict)
# Built-in workflow presets, keyed by each preset's ``name``.
WORKFLOW_CONFIGS: Dict[str, WorkflowConfig] = {
    config.name: config
    for config in (
        WorkflowConfig(
            name="standard",
            auto_approve_threshold=0.3,
            auto_approve_amount_limit=10000.0,
            tolerance_percent=5.0,
            escalation_rules={"sla_hours": 24},
        ),
        WorkflowConfig(
            name="high_value",
            auto_approve_threshold=0.1,
            auto_approve_amount_limit=5000.0,
            tolerance_percent=2.0,
            escalation_rules={"require_cfo": True, "sla_hours": 12},
        ),
        WorkflowConfig(
            name="expedited",
            auto_approve_threshold=0.5,
            auto_approve_amount_limit=5000.0,
            tolerance_percent=10.0,
            escalation_rules={"skip_manual_for_low_risk": True},
        ),
    )
}