InvoiceAgenticAI / agents /base_agent.py
PARTHASAKHAPAUL
Restructure project and sync local as source of truth
9107a5d
"""Base Agent Class for Invoice Processing System"""
# TODO: Implement agent
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from state import AuditTrail, InvoiceProcessingState, ProcessingStatus
from utils.logger import get_logger
class BaseAgent(ABC):
    """Abstract base class for all invoice processing agents.

    Subclasses implement :meth:`execute`. Callers invoke :meth:`run`, which
    wraps ``execute`` with:

    - precondition and postcondition hooks;
    - per-instance metrics (see :meth:`get_metrics`);
    - agent metrics and audit-trail entries recorded on the state;
    - error capture: exceptions raised inside the wrapped section are
      recorded on the state instead of propagating.
    """

    def __init__(self, agent_name: str, config: Optional[Dict[str, Any]] = None):
        """Create an agent.

        Args:
            agent_name: Name used for logging, metrics and audit entries.
            config: Optional agent configuration; defaults to an empty dict.
        """
        self.agent_name = agent_name
        self.config = config or {}
        self.logger = get_logger(agent_name)
        self.reset_metrics()
        # Start time of the most recent run(). Kept for backward compatibility only:
        # run() measures latency from a per-call local start time, so concurrent
        # runs of the same instance do not corrupt each other's timings.
        self.start_time: Optional[float] = None

    @abstractmethod
    async def execute(
        self, state: "InvoiceProcessingState", workflow_type: Any = None
    ) -> "InvoiceProcessingState":
        """Perform the agent's work. :meth:`run` calls this as ``execute(state, workflow_type)``."""
        raise NotImplementedError

    async def run(
        self, state: "InvoiceProcessingState", workflow_type: Any = None
    ) -> "InvoiceProcessingState":
        """Run :meth:`execute` with precondition checks, metrics, auditing and error capture.

        Args:
            state: Processing state passed to the hooks and to ``execute``.
            workflow_type: Forwarded to :meth:`_validate_preconditions` and
                :meth:`execute`. Its meaning is defined by the workflow and
                subclasses, which are not shown here.

        Returns:
            - the state returned by ``execute`` (or *state* if ``execute``
              returned ``None``) on success;
            - *state* unchanged apart from bookkeeping when the preconditions
              fail;
            - *state* marked ``FAILED`` when an exception is captured.
        """
        started = time.time()
        self.start_time = started
        self.logger.logger.info(f"Starting {self.agent_name} execution.")

        if not self._validate_preconditions(state, workflow_type):
            self.logger.logger.warning(f"Preconditions not met for {self.agent_name}.")
            # Skipped runs count as processed but do not contribute a latency sample.
            self._update_metrics(latency_ms=None)
            state.add_agent_metric(self.agent_name, processed=1, latency_ms=0, errors=0)
            # Pass by keyword: add_audit_entry's third parameter is `status`, not `details`.
            state.add_audit_entry(
                self.agent_name,
                action="precondition_failed",
                status=None,
                details={"note": "Preconditions not met, agent skipped."},
                process_id=getattr(state, "process_id", None),
            )
            return state

        state.current_agent = self.agent_name
        state.agent_name = self.agent_name
        state.overall_status = ProcessingStatus.IN_PROGRESS

        try:
            updated_state = await self.execute(state, workflow_type)
            if updated_state is None:
                # Tolerate execute() implementations that mutate in place and return nothing.
                updated_state = state

            try:
                post_ok = self._validate_postconditions(updated_state)
            except Exception as post_exc:
                self.logger.logger.warning(f"Postcondition check raised for {self.agent_name}:{post_exc}")
            else:
                if post_ok is False:
                    self.logger.logger.warning(f"Postconditions not met for {self.agent_name}.")

            # Record completion on the state execute() returned, which may be a new object.
            updated_state.mark_agent_completed(self.agent_name)
            latency_ms = (time.time() - started) * 1000
            updated_state.add_agent_metric(self.agent_name, processed=1, latency_ms=latency_ms)
            updated_state.add_audit_entry(
                self.agent_name,
                action="Agent Successfully Executed",
                status=ProcessingStatus.COMPLETED,
                details={"latency_ms": latency_ms},
                process_id=updated_state.process_id,
            )
            # Update instance metrics only after the state bookkeeping succeeded, so
            # a failure above is counted once (by the except branch), not twice.
            self._update_metrics(latency_ms=latency_ms)
            self.logger.logger.debug(
                f"Agent: {self.agent_name} | "
                f"id: {id(self)} | "
                f"last_run_at: {self.metrics['last_run_at']}"
            )
            self.logger.logger.info(f"{self.agent_name} completed successfully in {latency_ms:.2f}ms.")
            return updated_state
        except Exception as e:
            latency_ms = (time.time() - started) * 1000
            self._update_metrics(latency_ms=latency_ms, error=True)
            state.add_agent_metric(self.agent_name, processed=1, latency_ms=latency_ms, errors=1)
            # Pass by keyword: add_audit_entry's third parameter is `status`, not `details`.
            state.add_audit_entry(
                self.agent_name,
                action="Error in Execution",
                status=ProcessingStatus.FAILED,
                details={"error": str(e)},
                process_id=getattr(state, "process_id", None),
            )
            state.overall_status = ProcessingStatus.FAILED
            self.logger.logger.exception(f"{self.agent_name} failed: {e}")
            return state

    def _validate_preconditions(
        self, state: "InvoiceProcessingState", workflow_type: Any = None
    ) -> bool:
        """Override to add custom preconditions; return False to skip execution.

        :meth:`run` calls this as ``_validate_preconditions(state, workflow_type)``.
        """
        return True

    def _validate_postconditions(self, state: "InvoiceProcessingState") -> bool:
        """Override to verify expected outcomes after execution.

        A ``False`` return or a raised exception is logged as a warning by
        :meth:`run`. Neither fails the run.
        """
        return True

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        The format matches ``datetime.utcnow().isoformat()`` (no UTC offset
        suffix) without calling the deprecated ``utcnow``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _update_metrics(self, latency_ms: Optional[float] = None, error: bool = False) -> None:
        """Fold one run into ``self.metrics``.

        Args:
            latency_ms: Run latency to include in ``avg_latency_ms``. ``None``
                for a skipped run, which is counted as processed only.
            error: Whether the run failed; increments ``errors``.
        """
        self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1
        if error:
            self.metrics["errors"] = int(self.metrics.get("errors", 0)) + 1
        if latency_ms is not None:
            prev_avg = self.metrics.get("avg_latency_ms")
            # Incremental arithmetic mean over all latency samples since the last reset.
            samples = 1 if prev_avg is None else getattr(self, "_latency_samples", 0) + 1
            self._latency_samples = samples
            if prev_avg is None:
                self.metrics["avg_latency_ms"] = float(latency_ms)
            else:
                self.metrics["avg_latency_ms"] = prev_avg + (latency_ms - prev_avg) / samples
        self.metrics["last_run_at"] = self._utc_now_iso()

    def get_metrics(self) -> Dict[str, Any]:
        """Return a shallow copy of this agent's metrics."""
        return dict(self.metrics)

    def reset_metrics(self):
        """Reset the metrics returned by :meth:`get_metrics` to their initial values."""
        self.metrics: Dict[str, Any] = {
            "processed": 0,
            "errors": 0,
            "avg_latency_ms": None,
            "last_run_at": None,
        }
        # Number of latency samples folded into avg_latency_ms (skipped runs excluded).
        self._latency_samples = 0

    async def health_check(self) -> Dict[str, Any]:
        """Report a basic health status for the agent.

        The status is always ``"Healthy"``. The result also includes the last
        run time and the error count from the metrics.
        """
        return {
            "agent": self.agent_name,
            "status": "Healthy",
            "Last Run": self.metrics.get("last_run_at"),
            "errors": self.metrics.get("errors", 0),
        }

    def _extract_business_context(self, state: "InvoiceProcessingState") -> Dict[str, Any]:
        """Extract invoice, validation and risk context for reasoning logs.

        Only the sections that are present (truthy) on *state* contribute keys.
        Enum-like fields are rendered via ``.value`` when they have one, and via
        ``str()`` otherwise.
        """
        def as_text(value: Any) -> Any:
            return value.value if hasattr(value, "value") else str(value)

        context: Dict[str, Any] = {}
        invoice = state.invoice_data
        if invoice:
            context["vendor"] = invoice.vendor_name
            context["invoice_id"] = invoice.invoice_id
            context["amount"] = invoice.total_amount
        validation = state.validation_result
        if validation:
            context["validation_status"] = as_text(validation.validation_status)
        risk = state.risk_assessment
        if risk:
            context["risk_score"] = risk.risk_score
            context["risk_level"] = as_text(risk.risk_level)
        return context

    def _should_escalate(self, state: "InvoiceProcessingState", reason: Optional[str] = None) -> bool:
        """Decide whether the workflow should escalate, and flag the state if so.

        Delegates to ``state.requires_escalation()``. If that check raises, the
        method escalates rather than proceeding silently. On escalation it:

        - sets ``escalation_required`` and ``human_review_required`` on *state*;
        - adds an audit entry recording *reason* (``"auto"`` when omitted).
        """
        try:
            result = bool(state.requires_escalation())
        except Exception as exc:
            # Fail safe: an error in the check itself escalates instead of letting the invoice through.
            self.logger.logger.warning(
                f"Escalation check failed for {self.agent_name}: {exc}; escalating by default"
            )
            result = True
        if result:
            self.logger.logger.warning(f"Escalation triggered by {self.agent_name}:{reason or 'auto'}")
            state.escalation_required = True
            state.human_review_required = True
            state.add_audit_entry(self.agent_name, "Escalation Triggered", None, {"reason": reason or "auto"})
        return result

    def _log_decision(
        self,
        state: "InvoiceProcessingState",
        decision: str,
        reasoning: str,
        confidence: Optional[float] = None,
        process_id: Optional[str] = None,
    ):
        """Log an agent decision and record it in the state's audit trail.

        The audit entry uses *decision* as its action, with the decision,
        reasoning and confidence as details.
        """
        details: Dict[str, Any] = {
            "decision": decision,
            "reasoning": reasoning,
            "confidence": confidence,
        }
        self.logger.logger.info(f"{self.agent_name} decision:{decision}(confidence = {confidence})")
        state.add_audit_entry(self.agent_name, decision, None, details, process_id)
class AgentRegistry:
    """Registry mapping agent names to agent instances.

    Registration is first-wins. Registering a second agent under a name that
    is already taken logs a warning and keeps the original agent.
    """

    def __init__(self):
        self._agents: Dict[str, "BaseAgent"] = {}

    def register(self, agent: "BaseAgent"):
        """Register *agent* under ``agent.agent_name`` unless that name is already taken."""
        if agent.agent_name in self._agents:
            logging.getLogger(__name__).warning("%s already registered - skipping", agent.agent_name)
            return
        self._agents[agent.agent_name] = agent

    def get(self, agent_name: str) -> Optional["BaseAgent"]:
        """Return the agent registered as *agent_name*, or ``None``."""
        return self._agents.get(agent_name)

    def list_agents(self) -> List[str]:
        """Return registered agent names in registration order."""
        return list(self._agents.keys())

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Return each registered agent's ``get_metrics()`` keyed by name."""
        return {name: agent.get_metrics() for name, agent in self._agents.items()}

    async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
        """Run every agent's ``health_check()`` and collect the results by name.

        If an agent's check raises, the exception is logged and that agent is
        reported as ``{"agent": name, "status": "Unhealthy", "error": ...}``.
        The other agents are still checked.
        """
        result: Dict[str, Dict[str, Any]] = {}
        # Iterate a snapshot: registrations may happen while we are suspended in an await.
        for name, agent in list(self._agents.items()):
            try:
                result[name] = await agent.health_check()
            except Exception as exc:
                logging.getLogger(__name__).exception("Health check failed for %s", name)
                result[name] = {"agent": name, "status": "Unhealthy", "error": str(exc)}
        return result
# Global agent registry instance, created once when this module is first imported.
agent_registry = AgentRegistry()
# Debug aid (presumably meant to detect this module being imported twice under
# different paths, which would create two separate registries): log the
# instance id at debug level instead of printing it on every import.
logging.getLogger(__name__).debug("Registry instance ID in base: %s", id(agent_registry))