Spaces:
Sleeping
Sleeping
File size: 9,914 Bytes
2b44e69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
"""State models and enumerations"""
# TODO: Define state models
from __future__ import annotations
import uuid  # used by the InvoiceProcessingState.process_id default factory
from typing import Dict, List, Optional, Any, Literal
from pydantic import BaseModel, Field, root_validator
from datetime import datetime
from enum import Enum
class ProcessingStatus(str, Enum):
    """Overall lifecycle status of an invoice-processing run.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase string value (e.g. ``ProcessingStatus.PENDING == "pending"``).
    """

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class RiskLevel(str, Enum):
    """Categorical risk band attached to a RiskAssessment.

    Members are declared in the order LOW, MEDIUM, HIGH, CRITICAL and compare
    equal to their lowercase string values.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class PaymentStatus(str, Enum):
    """Status of the payment step recorded on a PaymentDecision.

    String-valued so members compare equal to (and serialize as) their
    lowercase values.
    """

    NOT_STARTED = "not_started"
    PENDING_APPROVAL = "pending_approval"
    APPROVED = "approved"
    SCHEDULED = "scheduled"
    PAID = "paid"
    FAILED = "failed"
class ItemDetail(BaseModel):
    """A single line item on an invoice.

    ``quantity``, ``rate`` and ``amount`` are required and validated as
    non-negative; the identifying/descriptive fields are optional.
    """

    item_id: Optional[str] = None
    item_name: Optional[str] = None
    # Required, must be >= 0.
    quantity: int = Field(..., ge=0)
    # Required unit price, must be >= 0.
    rate: float = Field(..., ge=0.0)
    # Required line total, must be >= 0.
    # NOTE(review): not cross-checked against quantity * rate here.
    amount: float = Field(..., ge=0.0)
    category: Optional[str] = None
class InvoiceData(BaseModel):
    """Fields extracted from an invoice document.

    Every field is optional so partially-extracted invoices can still be
    represented; ``currency`` defaults to ``"USD"``.
    """

    invoice_number: Optional[str] = None
    order_id: Optional[str] = None
    file_name: Optional[str] = None
    customer_name: Optional[str] = None
    invoice_date: Optional[datetime] = None
    due_date: Optional[datetime] = None
    currency: Optional[str] = "USD"
    total: Optional[float] = None
    # Full extracted text of the document, if kept.
    raw_text: Optional[str] = None
    # Untyped list: entries are not validated as ItemDetail.
    # NOTE(review): presumably dicts shaped like ItemDetail — confirm with the extractor.
    item_details: Optional[list] = None
    extraction_confidence: Optional[float] = None
class ValidationStatus(str, Enum):
    """Outcome of validating an invoice against its purchase order.

    String-valued; ``NOT_STARTED`` is the default used by ValidationResult.
    """

    NOT_STARTED = "not_started"
    VALID = "valid"
    INVALID = "invalid"
    PARTIAL_MATCH = "partial_match"
    MISSING_PO = "missing_po"
class ValidationResult(BaseModel):
    """Result of matching an invoice against purchase-order data.

    The boolean flags record which individual checks matched; all default to
    False until set.
    """

    po_found: bool = False
    quantity_match: bool = False
    rate_match: bool = False
    amount_match: bool = False
    validation_status: ValidationStatus = ValidationStatus.NOT_STARTED
    # Free-form summary of the validation outcome.
    validation_result: Optional[str] = None
    # Spelling ("discrepencies") is part of the public field name — do not
    # rename without migrating callers and serialized data.
    discrepencies: List[str] = Field(default_factory=list)
    confidence_score: Optional[float] = None
    # Raw PO record used for the comparison, if any.
    po_data: Optional[Dict[str, Any]] = None
class RiskAssessment(BaseModel):
    """Risk evaluation produced for an invoice.

    ``risk_score`` is validated to lie in the closed interval [0.0, 1.0].
    """

    risk_score: float = Field(0.0, ge=0.0, le=1.0)
    risk_level: RiskLevel = RiskLevel.LOW
    # Human-readable risk indicators that contributed to the score.
    signals: List[str] = Field(default_factory=list)
    vendor_status: Optional[str] = None
    compliance_violations: List[str] = Field(default_factory=list)
class PaymentDecision(BaseModel):
    """Payment routing decision and payment-step status for an invoice.

    All fields have defaults, so ``PaymentDecision()`` is a valid
    "nothing decided yet" value.
    """

    # One of the four routing outcomes, or None when no decision is recorded.
    # The explicit ``= None`` is required: under pydantic v2 an ``Optional[...]``
    # annotation without a default is still a *required* field, which made
    # ``PaymentDecision()`` fail validation.
    decision: Optional[Literal["auto_pay", "manual_approval", "hold", "reject"]] = None
    status: PaymentStatus = PaymentStatus.NOT_STARTED
    scheduled_date: Optional[datetime] = None
    transaction_id: Optional[str] = None
    # presumably the number of payment attempts made so far — confirm with caller
    attempts: int = 0
    reason: Optional[str] = None
class AuditTrail(BaseModel):
    """A single audit-log entry recorded against a processing run."""

    process_id: Optional[str] = None
    # Naive UTC time (datetime.utcnow) captured when the entry is created.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    agent_name: str
    action: str
    details: Dict[str, Any] = Field(default_factory=dict)
    # Processing status at the time of the entry. InvoiceProcessingState's
    # add_audit_entry passes ``status=``, but no such field was declared, so
    # the value was discarded. Typed loosely (a ProcessingStatus member or a
    # plain string) so callers passing free-form strings keep validating.
    status: Optional[Any] = None

    class Config:
        arbitrary_types_allowed = True
class AgentMetrics(BaseModel):
    """Running counters for one agent, as maintained by add_agent_metric.

    Counters start at zero; the derived/optional fields stay None until the
    first update supplies data for them.
    """

    processed_count: int = 0
    avg_latency_ms: Optional[float] = None
    last_run_at: Optional[datetime] = None
    errors: int = 0
    # Fraction in [0, 1] derived from errors / processed_count when set.
    success_rate: Optional[float] = None
class InvoiceProcessingState(BaseModel):
    """Shared state for one invoice as it moves through the agent workflow.

    Bundles identifiers, per-agent outputs, audit and metrics bookkeeping,
    escalation flags and timestamps. The helper methods mutate the instance
    in place and refresh ``updated_at``.
    """

    # Core identifiers
    # Default id has the form "proc_YYYYMMDD_HHMMSS_<6 hex chars>" (UTC clock).
    process_id: str = Field(
        default_factory=lambda: f"proc_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
    )
    file_name: Optional[str] = None

    # Processing status
    overall_status: ProcessingStatus = ProcessingStatus.PENDING
    current_agent: Optional[str] = None
    # NOTE(review): WORKFLOW_CONFIGS is keyed in lowercase ("standard"), so a
    # direct lookup with this default would miss — confirm callers normalise case.
    workflow_type: str = "Standard"

    # Agent outputs
    invoice_data: Optional[InvoiceData] = None
    validation_result: Optional[ValidationResult] = None
    # NOTE(review): overlaps validation_result.validation_status — keep in sync.
    validation_status: Optional[str] = None
    risk_assessment: Optional[RiskAssessment] = None
    payment_decision: Optional[PaymentDecision] = None
    approval_chain: Optional[List[Dict[str, Any]]] = None

    # Audit and tracking
    agent_name: Optional[str] = None
    audit_trail: List[AuditTrail] = Field(default_factory=list)
    # add_agent_metric stores AgentMetrics here, while update_agent_metrics
    # stores plain attribute objects (executions/successes/...) under the same
    # kind of key; each method replaces the other's shape rather than crash.
    agent_metrics: Dict[str, AgentMetrics] = Field(default_factory=dict)
    compliance_report: Optional[Dict[str, Any]] = None
    audit_summary: Optional[Dict[str, Any]] = None
    reportable_events: Optional[List[Dict[str, Any]]] = None

    # Escalation
    escalation_required: bool = False
    human_review_required: bool = False
    escalation_details: Optional[str] = None
    escalation_reason: Optional[str] = None

    # Workflow control
    retry_count: int = 0
    completed_agents: List[str] = Field(default_factory=list)

    # Timestamps (naive UTC, from datetime.utcnow)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    # ------------------------------------------------------------------
    # Convenience methods
    # ------------------------------------------------------------------
    def add_audit_entry(
        self,
        agent_name: str,
        action: str,
        status: Optional[ProcessingStatus] = None,
        details: Optional[Dict[str, Any]] = None,
        process_id: Optional[str] = None,
    ) -> None:
        """Append an AuditTrail entry and refresh ``updated_at``.

        Args:
            agent_name: Agent recording the entry.
            action: Short description of what happened.
            status: Status to record; defaults to the current ``overall_status``.
            details: Extra structured data; defaults to an empty dict.
            process_id: Id stamped on the entry; defaults to this state's own
                ``process_id`` (previously it was left as None).
        """
        entry = AuditTrail(
            agent_name=agent_name,
            action=action,
            # NOTE(review): only stored if AuditTrail declares a ``status``
            # field; otherwise pydantic ignores this keyword.
            status=status or self.overall_status,
            details=details or {},
            process_id=process_id or self.process_id,
        )
        self.audit_trail.append(entry)
        self.updated_at = datetime.utcnow()

    def add_agent_metric(
        self,
        agent: str,
        processed: int = 0,
        latency_ms: Optional[float] = None,
        errors: int = 0,
    ) -> None:
        """Accumulate counters for ``agent`` into an AgentMetrics entry.

        Args:
            agent: Key under which the metrics are stored.
            processed: Amount added to ``processed_count``.
            latency_ms: Optional latency sample folded into ``avg_latency_ms``.
            errors: Amount added to ``errors``.
        """
        if self.agent_metrics is None:
            self.agent_metrics = {}

        metrics = self.agent_metrics.get(agent)
        # update_agent_metrics() may have stored a different object shape under
        # this key; start a fresh AgentMetrics instead of raising AttributeError.
        if not isinstance(metrics, AgentMetrics):
            metrics = AgentMetrics()

        metrics.processed_count += processed
        metrics.errors += errors

        if latency_ms is not None:
            if metrics.avg_latency_ms is None:
                metrics.avg_latency_ms = latency_ms
            else:
                # NOTE(review): not an arithmetic mean of all samples — each new
                # sample gets weight 1/2 (exponentially weighted). A true mean
                # would need a per-agent sample count.
                metrics.avg_latency_ms = (metrics.avg_latency_ms + latency_ms) / 2.0

        metrics.last_run_at = datetime.utcnow()
        if metrics.processed_count > 0:
            metrics.success_rate = max(0.0, 1.0 - metrics.errors / metrics.processed_count)

        self.agent_metrics[agent] = metrics
        self.updated_at = datetime.utcnow()

    def update_agent_metrics(self, agent_name: str, success: bool, duration_ms: float) -> None:
        """Record one execution of ``agent_name`` and its duration.

        Expected structure aligns with test_9_agent_metrics: attributes like
        executions, success_count, failure_count. The stored object exposes
        ``executions``, ``successes`` (mirrored as ``success_count``),
        ``failure_count``, ``total_duration_ms`` and ``avg_duration_ms``
        (rounded to 2 decimals).

        Args:
            agent_name: Key under which the metrics are stored.
            success: Whether this execution succeeded.
            duration_ms: Duration of this execution in milliseconds.
        """
        import types  # local: only this method needs SimpleNamespace

        if self.agent_metrics is None:
            self.agent_metrics = {}

        metrics = self.agent_metrics.get(agent_name)
        # Anything without execution counters — None, a plain dict, or an
        # AgentMetrics written by add_agent_metric — is replaced with a fresh
        # counter object (the previous check crashed on AgentMetrics).
        if not hasattr(metrics, "executions"):
            metrics = types.SimpleNamespace(
                executions=0,
                successes=0,
                success_count=0,
                failure_count=0,
                total_duration_ms=0.0,
                avg_duration_ms=0.0,
            )

        metrics.executions += 1
        if success:
            metrics.successes += 1
        else:
            metrics.failure_count += 1
        # Expose the success counter under the name the docstring promises too.
        metrics.success_count = metrics.successes
        metrics.total_duration_ms += duration_ms
        metrics.avg_duration_ms = round(metrics.total_duration_ms / metrics.executions, 2)

        self.agent_metrics[agent_name] = metrics
        self.updated_at = datetime.utcnow()

    def mark_agent_completed(self, agent: str) -> None:
        """Add ``agent`` to ``completed_agents`` once (no duplicates)."""
        if agent not in self.completed_agents:
            self.completed_agents.append(agent)
        self.updated_at = datetime.utcnow()

    def requires_escalation(self, risk_threshold: float = 0.6, confidence_threshold: float = 0.7) -> bool:
        """Return True if the current outputs warrant escalation.

        True when any of the following holds:
        * the validation status is ``ValidationStatus.INVALID``;
        * ``risk_assessment.risk_score >= risk_threshold``;
        * a validation confidence score exists and is ``< confidence_threshold``.
        Missing outputs (None) never trigger escalation on their own.
        """
        if self.validation_result and self.validation_result.validation_status == ValidationStatus.INVALID:
            return True
        if self.risk_assessment and self.risk_assessment.risk_score >= risk_threshold:
            return True
        if (
            self.validation_result
            and self.validation_result.confidence_score is not None
            and self.validation_result.confidence_score < confidence_threshold
        ):
            return True
        return False

    def to_dict(self) -> Dict[str, Any]:
        """Return the state as a plain dict via pydantic's ``model_dump()``."""
        return self.model_dump()

    @root_validator(pre=True)
    def ensure_timestamps(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in missing or None ``created_at``/``updated_at`` before validation.

        Works on a copy so a dict handed to ``model_validate()`` is not mutated;
        non-dict input is passed through untouched for pydantic to handle.
        """
        if not isinstance(values, dict):
            return values
        values = dict(values)
        now = datetime.utcnow()
        if values.get("created_at") is None:
            values["created_at"] = now
        if values.get("updated_at") is None:
            values["updated_at"] = now
        return values
class WorkflowConfig(BaseModel):
    """Tunable parameters for one named processing workflow.

    Only ``name`` is required; the remaining fields fall back to the
    defaults below.
    """

    name: str
    # presumably the risk-score ceiling for auto-approval — confirm with the payment agent
    auto_approve_threshold: float = 0.3
    # None means no amount cap is configured.
    auto_approve_amount_limit: Optional[float] = None
    # presumably the allowed % deviation when matching against the PO — confirm
    tolerance_percent: float = 5.0
    # Free-form, workflow-specific escalation settings.
    escalation_rules: Dict[str, Any] = Field(default_factory=dict)
# Built-in workflow presets, keyed by each preset's own ``name`` so the key
# and the config can never drift apart.
WORKFLOW_CONFIGS: Dict[str, WorkflowConfig] = {
    preset.name: preset
    for preset in (
        WorkflowConfig(
            name="standard",
            auto_approve_threshold=0.3,
            auto_approve_amount_limit=10000.0,
            tolerance_percent=5.0,
            escalation_rules={"sla_hours": 24},
        ),
        WorkflowConfig(
            name="high_value",
            auto_approve_threshold=0.1,
            auto_approve_amount_limit=5000.0,
            tolerance_percent=2.0,
            escalation_rules={"require_cfo": True, "sla_hours": 12},
        ),
        WorkflowConfig(
            name="expedited",
            auto_approve_threshold=0.5,
            auto_approve_amount_limit=5000.0,
            tolerance_percent=10.0,
            escalation_rules={"skip_manual_for_low_risk": True},
        ),
    )
}
|