Spaces:
Running
Running
File size: 8,939 Bytes
2b44e69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
"""Base Agent Class for Invoice Processing System"""
# TODO: Implement agent
import time
import logging
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from datetime import datetime
from state import InvoiceProcessingState, ProcessingStatus, AuditTrail
from utils.logger import get_logger
class BaseAgent(ABC):
    """Abstract base class for all invoice processing agents.

    Concrete agents implement :meth:`execute`. Callers invoke :meth:`run`,
    which wraps ``execute`` with precondition/postcondition checks, status
    updates on the state, per-agent metrics and audit-trail entries.
    """

    def __init__(self, agent_name: str, config: Optional[Dict[str, Any]] = None):
        """Create an agent.

        Args:
            agent_name: identifier used in logs, metrics and audit entries.
            config: optional agent-specific settings; defaults to ``{}``.
        """
        self.agent_name = agent_name
        self.config = config or {}
        # NOTE(review): get_logger appears to return a wrapper object; this
        # class logs through its ``.logger`` attribute.
        self.logger = get_logger(agent_name)
        self.metrics: Dict[str, Any] = self._empty_metrics()
        # Number of runs that contributed a sample to avg_latency_ms
        # (precondition-skipped runs do not).
        self._latency_samples = 0
        # Wall-clock start of the most recent run(); kept for backward
        # compatibility. Latency is measured from a local perf_counter reading
        # so overlapping run() calls on one instance don't clobber each other.
        self.start_time: Optional[float] = None

    @staticmethod
    def _empty_metrics() -> Dict[str, Any]:
        """Return a fresh metrics dict with all counters cleared."""
        return {
            "processed": 0,
            "errors": 0,
            "avg_latency_ms": None,
            "last_run_at": None,
        }

    @staticmethod
    def _elapsed_ms(started: float) -> float:
        """Milliseconds elapsed since ``started`` (a ``time.perf_counter()`` value)."""
        return (time.perf_counter() - started) * 1000.0

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for enum-like objects, otherwise ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    @abstractmethod
    async def execute(
        self, state: InvoiceProcessingState, workflow_type: Any = None
    ) -> InvoiceProcessingState:
        """Perform the agent's work and return the resulting state.

        :meth:`run` calls this as ``execute(state, workflow_type)``, so
        concrete agents must accept the second argument.
        """
        raise NotImplementedError

    async def run(
        self, state: InvoiceProcessingState, workflow_type: Any = None
    ) -> InvoiceProcessingState:
        """Run the agent with validation, metrics and audit bookkeeping.

        Args:
            state: the shared processing state.
            workflow_type: forwarded to ``_validate_preconditions`` and
                ``execute``.

        Returns:
            The state returned by ``execute`` on success (or the input state if
            ``execute`` returned ``None``). If preconditions fail, the input
            state is returned untouched apart from metric/audit entries. If
            ``execute`` or the success bookkeeping raises, the exception is
            logged and the input state is returned with ``overall_status`` set
            to ``FAILED``.
        """
        self.start_time = time.time()
        started = time.perf_counter()
        self.logger.logger.info(f"Starting {self.agent_name} execution.")

        if not self._validate_preconditions(state, workflow_type):
            return self._handle_precondition_failure(state)

        state.current_agent = self.agent_name
        state.agent_name = self.agent_name
        state.overall_status = ProcessingStatus.IN_PROGRESS

        try:
            updated_state = await self.execute(state, workflow_type)
            if updated_state is None:
                # Tolerate agents that mutate in place and forget to return.
                updated_state = state
            self._check_postconditions(updated_state)
            latency_ms = self._elapsed_ms(started)
            # Record success on the state object we are about to return.
            updated_state.mark_agent_completed(self.agent_name)
            updated_state.add_agent_metric(
                self.agent_name, processed=1, latency_ms=latency_ms
            )
            updated_state.add_audit_entry(
                self.agent_name,
                action="Agent Successfully Executed",
                status=ProcessingStatus.COMPLETED,
                details={"latency_ms": latency_ms},
                process_id=updated_state.process_id,
            )
        except Exception as exc:
            return self._handle_failure(state, exc, self._elapsed_ms(started))

        # Updated once, after all state bookkeeping succeeded, so a failure
        # above is counted exactly once (as an error) rather than twice.
        self._update_metrics(latency_ms=latency_ms)
        self.logger.logger.info(
            f"{self.agent_name} completed successfully in {latency_ms:.2f}ms."
        )
        return updated_state

    def _handle_precondition_failure(
        self, state: InvoiceProcessingState
    ) -> InvoiceProcessingState:
        """Record a skipped run (no latency sample) and return ``state`` unchanged."""
        self.logger.logger.warning(f"Preconditions not met for {self.agent_name}.")
        self._update_metrics()
        state.add_agent_metric(self.agent_name, processed=1, latency_ms=0, errors=0)
        state.add_audit_entry(
            self.agent_name,
            action="precondition_failed",
            status=None,
            details={"note": "Preconditions not met, agent skipped."},
            process_id=state.process_id,
        )
        return state

    def _handle_failure(
        self, state: InvoiceProcessingState, exc: Exception, latency_ms: float
    ) -> InvoiceProcessingState:
        """Record a failed run on metrics and ``state``; mark the state FAILED.

        Must be called from inside an ``except`` block so ``logger.exception``
        can attach the active traceback.
        """
        # Log first so the traceback is captured even if bookkeeping below fails.
        self.logger.logger.exception(f"{self.agent_name} failed: {exc}")
        self._update_metrics(latency_ms=latency_ms, error=True)
        state.add_agent_metric(
            self.agent_name, processed=1, latency_ms=latency_ms, errors=1
        )
        state.add_audit_entry(
            self.agent_name,
            action="Error in Execution",
            status=ProcessingStatus.FAILED,
            details={"error": str(exc)},
            process_id=state.process_id,
        )
        state.overall_status = ProcessingStatus.FAILED
        return state

    def _check_postconditions(self, state: InvoiceProcessingState) -> None:
        """Run ``_validate_postconditions`` and log, never raise, on failure.

        A ``False`` result or a raised exception is logged as a warning.
        ``None`` (an override without an explicit return) is treated as a pass.
        """
        try:
            ok = self._validate_postconditions(state)
        except Exception as post_exc:
            self.logger.logger.warning(
                f"Postcondition check raised for {self.agent_name}:{post_exc}"
            )
            return
        if ok is False:
            self.logger.logger.warning(f"Postconditions not met for {self.agent_name}.")

    def _update_metrics(self, latency_ms: Optional[float] = None, error: bool = False) -> None:
        """Account for one run in ``self.metrics``.

        Args:
            latency_ms: run duration; ``None`` for runs that did not execute
                (they count as processed but do not affect the average).
            error: whether the run failed.

        ``avg_latency_ms`` is the arithmetic mean of all latency samples,
        maintained incrementally.
        """
        self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1
        if error:
            self.metrics["errors"] = int(self.metrics.get("errors", 0)) + 1
        if latency_ms is not None:
            prev_avg = self.metrics.get("avg_latency_ms")
            if prev_avg is None:
                samples = 1
                new_avg = latency_ms
            else:
                samples = getattr(self, "_latency_samples", 1) + 1
                new_avg = prev_avg + (latency_ms - prev_avg) / samples
            self.metrics["avg_latency_ms"] = new_avg
            self._latency_samples = samples
        # NOTE(review): naive UTC timestamp, as before; datetime.utcnow() is
        # deprecated in 3.12+ — switch to an aware datetime only if the rest of
        # the system does too.
        self.metrics["last_run_at"] = datetime.utcnow().isoformat()

    def _validate_preconditions(
        self, state: InvoiceProcessingState, workflow_type: Any = None
    ) -> bool:
        """Override to add custom preconditions; returning False skips the agent.

        ``run`` passes ``workflow_type`` as the second argument.
        """
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Override to verify expected outcomes after ``execute``.

        ``run`` only logs a warning when this returns False or raises.
        """
        return True

    def get_metrics(self) -> Dict[str, Any]:
        """Return a shallow copy of this agent's metrics."""
        return dict(self.metrics)

    def reset_metrics(self) -> None:
        """Clear all metrics, including the latency sample count."""
        self.metrics = self._empty_metrics()
        self._latency_samples = 0

    async def health_check(self) -> Dict[str, Any]:
        """Return a basic health report (always ``"Healthy"``) with last run and error count."""
        return {
            "agent": self.agent_name,
            "status": "Healthy",
            "Last Run": self.metrics.get("last_run_at"),
            "errors": self.metrics.get("errors", 0),
        }

    def _extract_business_context(self, state: InvoiceProcessingState) -> Dict[str, Any]:
        """Extract invoice, validation and risk context for reasoning logs.

        Only sections present (truthy) on ``state`` are included.
        """
        context: Dict[str, Any] = {}
        invoice = state.invoice_data
        if invoice:
            context["vendor"] = invoice.vendor_name
            context["invoice_id"] = invoice.invoice_id
            context["amount"] = invoice.total_amount
        if state.validation_result:
            context["validation_status"] = self._enum_value(
                state.validation_result.validation_status
            )
        risk = state.risk_assessment
        if risk:
            context["risk_score"] = risk.risk_score
            context["risk_level"] = self._enum_value(risk.risk_level)
        return context

    def _should_escalate(self, state: InvoiceProcessingState, reason: str = None) -> bool:
        """Decide via ``state.requires_escalation()`` whether to escalate.

        ``reason`` is only used for the log and audit entry. If the check
        itself raises, escalation is assumed (fail safe). On escalation the
        state's escalation and human-review flags are set and an audit entry
        is added.
        """
        try:
            result = state.requires_escalation()
        except Exception as exc:
            self.logger.logger.warning(
                f"requires_escalation() failed for {self.agent_name}: {exc}; escalating"
            )
            result = True
        if result:
            self.logger.logger.warning(
                f"Escalation triggered by {self.agent_name}:{reason or 'auto'}"
            )
            state.escalation_required = True
            state.human_review_required = True
            state.add_audit_entry(
                self.agent_name,
                action="Escalation Triggered",
                status=None,
                details={"reason": reason or "auto"},
            )
        return bool(result)

    def _log_decision(self, state: InvoiceProcessingState, decision: str,
                      reasoning: str, confidence: float = None, process_id: str = None):
        """Log an agent decision and record it in the state's audit trail.

        The audit entry's action is ``decision``; its details carry the
        decision, reasoning and confidence.
        """
        details: Dict[str, Any] = {
            "decision": decision,
            "reasoning": reasoning,
            "confidence": confidence,
        }
        self.logger.logger.info(f"{self.agent_name} decision:{decision}(confidence = {confidence})")
        state.add_audit_entry(
            self.agent_name,
            action=decision,
            status=None,
            details=details,
            process_id=process_id,
        )
class AgentRegistry:
    """Registry for managing agent instances, keyed by ``agent_name``."""

    def __init__(self):
        self._agents: Dict[str, "BaseAgent"] = {}

    def register(self, agent: "BaseAgent") -> None:
        """Register ``agent`` under its ``agent_name``.

        The first registration for a name wins. Later ones are ignored with a
        warning, so an already-registered instance is never replaced.
        """
        name = agent.agent_name
        if name in self._agents:
            logging.getLogger(__name__).warning("%s already registered - skipping", name)
            return
        self._agents[name] = agent

    def get(self, agent_name: str) -> Optional["BaseAgent"]:
        """Return the agent registered as ``agent_name``, or None."""
        return self._agents.get(agent_name)

    def list_agents(self) -> List[str]:
        """Return registered agent names in registration order."""
        return list(self._agents)

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Return ``{agent_name: agent.get_metrics()}`` for every registered agent."""
        return {name: agent.get_metrics() for name, agent in self._agents.items()}

    async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
        """Run each agent's health check sequentially, in registration order.

        An agent whose check raises is reported as ``"Unhealthy"`` with the
        error text, instead of aborting the whole sweep.
        """
        results: Dict[str, Dict[str, Any]] = {}
        # Snapshot: a register() during an await must not break iteration.
        for name, agent in list(self._agents.items()):
            try:
                results[name] = await agent.health_check()
            except Exception as exc:
                logging.getLogger(__name__).exception("Health check failed for %s", name)
                results[name] = {"agent": name, "status": "Unhealthy", "error": str(exc)}
        return results
# Global agent registry instance
agent_registry = AgentRegistry()
# Debug aid for confirming that importers share one registry instance; logged
# rather than printed so importing this module writes nothing to stdout.
logging.getLogger(__name__).debug("Registry instance ID in base: %s", id(agent_registry))
|