File size: 8,939 Bytes
2b44e69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222

"""Base Agent Class for Invoice Processing System"""

# TODO: Implement agent

import time
import logging
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from datetime import datetime

from state import InvoiceProcessingState, ProcessingStatus, AuditTrail
from utils.logger import get_logger


class BaseAgent(ABC):
    """Abstract base class for all invoice processing agents.

    Subclasses implement :meth:`execute`; callers invoke :meth:`run`, which
    wraps ``execute`` with precondition/postcondition checks, per-agent
    metrics, audit-trail entries and error handling.
    """

    def __init__(self, agent_name: str, config: Optional[Dict[str, Any]] = None):
        self.agent_name = agent_name
        self.config = config or {}
        self.logger = get_logger(agent_name)
        self.metrics: Dict[str, Any] = self._fresh_metrics()
        # Number of timed runs folded into metrics["avg_latency_ms"].
        self._latency_samples = 0
        # Wall-clock (epoch seconds) start of the most recent run. Latency is
        # measured with a per-call monotonic clock inside run(), not from this.
        self.start_time: Optional[float] = None

    @staticmethod
    def _fresh_metrics() -> Dict[str, Any]:
        """Return a zeroed metrics dictionary."""
        return {
            "processed": 0,
            "errors": 0,
            "avg_latency_ms": None,
            "last_run_at": None,
        }

    @abstractmethod
    async def execute(
        self, state: InvoiceProcessingState, workflow_type: Any = None
    ) -> InvoiceProcessingState:
        """Run the agent's core logic and return the (possibly updated) state.

        :meth:`run` calls this as ``execute(state, workflow_type)``, so
        overrides must accept both arguments.
        """
        raise NotImplementedError

    async def run(
        self, state: InvoiceProcessingState, workflow_type: Any = None
    ) -> InvoiceProcessingState:
        """Execute this agent against *state* with full bookkeeping.

        Args:
            state: The shared processing state.
            workflow_type: Forwarded unchanged to ``_validate_preconditions``
                and ``execute``.

        Returns:
            On success, the state returned by ``execute`` (or *state* if
            ``execute`` returned ``None``). When preconditions fail, or when
            execution raises, *state* itself.

        Exceptions raised by ``execute`` or by the success bookkeeping are
        logged, recorded on *state*, and swallowed; in that case
        ``state.overall_status`` is set to ``ProcessingStatus.FAILED``.
        """
        # Use a per-call start time: a single agent instance may serve
        # concurrent runs, so latency must not be derived from self.start_time.
        started = time.perf_counter()
        self.start_time = time.time()
        self.logger.logger.info(f"Starting {self.agent_name} execution.")

        if not self._validate_preconditions(state, workflow_type):
            self.logger.logger.warning(f"Preconditions not met for {self.agent_name}.")
            # Counted as processed, but untimed: the latency average is unchanged.
            self._update_metrics(latency_ms=None)
            state.add_agent_metric(self.agent_name, processed=1, latency_ms=0, errors=0)
            state.add_audit_entry(
                self.agent_name,
                action="precondition_failed",
                status=None,
                details={"note": "Preconditions not met, agent skipped."},
            )
            return state

        state.current_agent = self.agent_name
        state.agent_name = self.agent_name
        state.overall_status = ProcessingStatus.IN_PROGRESS

        try:
            result = await self.execute(state, workflow_type)
            # Fall back to the input state if execute() forgot to return one.
            updated_state = state if result is None else result

            try:
                if self._validate_postconditions(updated_state) is False:
                    self.logger.logger.warning(
                        f"Postconditions not satisfied for {self.agent_name}."
                    )
            except Exception as post_exc:
                # Best-effort check: a faulty validator must not fail the run.
                self.logger.logger.warning(
                    f"Postcondition check raised for {self.agent_name}: {post_exc}"
                )

            # Record completion on the state we actually return to the caller.
            updated_state.mark_agent_completed(self.agent_name)
            latency_ms = (time.perf_counter() - started) * 1000
            self._update_metrics(latency_ms=latency_ms)
            self.logger.logger.debug(
                f"Agent: {self.agent_name} | id: {id(self)} | "
                f"last_run_at: {self.metrics['last_run_at']}"
            )
            updated_state.add_agent_metric(
                self.agent_name, processed=1, latency_ms=latency_ms
            )
            updated_state.add_audit_entry(
                self.agent_name,
                action="Agent Successfully Executed",
                status=ProcessingStatus.COMPLETED,
                details={"latency_ms": latency_ms},
                process_id=updated_state.process_id,
            )
            self.logger.logger.info(
                f"{self.agent_name} completed successfully in {latency_ms:.2f}ms."
            )
            return updated_state

        except Exception as e:
            latency_ms = (time.perf_counter() - started) * 1000
            self._update_metrics(latency_ms=latency_ms, error=True)
            state.add_agent_metric(
                self.agent_name, processed=1, latency_ms=latency_ms, errors=1
            )
            state.add_audit_entry(
                self.agent_name,
                action="Error in Execution",
                status=ProcessingStatus.FAILED,
                details={"error": str(e)},
            )
            state.overall_status = ProcessingStatus.FAILED
            self.logger.logger.exception(f"{self.agent_name} failed: {e}")
            return state

    def _update_metrics(self, latency_ms: Optional[float] = None, error: bool = False) -> None:
        """Fold one run into ``self.metrics``.

        Args:
            latency_ms: Duration of the run in milliseconds. ``None`` marks an
                untimed run (e.g. skipped on failed preconditions) and leaves
                ``avg_latency_ms`` unchanged.
            error: Whether the run failed; increments ``errors`` when true.
        """
        self.metrics["processed"] = int(self.metrics.get("processed", 0)) + 1
        if error:
            self.metrics["errors"] = int(self.metrics.get("errors", 0)) + 1
        if latency_ms is not None:
            prev_avg = self.metrics.get("avg_latency_ms")
            if prev_avg is None:
                self._latency_samples = 1
                self.metrics["avg_latency_ms"] = latency_ms
            else:
                # Incremental arithmetic mean over all timed runs.
                samples = max(getattr(self, "_latency_samples", 0), 1) + 1
                self._latency_samples = samples
                self.metrics["avg_latency_ms"] = prev_avg + (latency_ms - prev_avg) / samples
        self.metrics["last_run_at"] = datetime.utcnow().isoformat()

    def _validate_preconditions(
        self, state: InvoiceProcessingState, workflow_type: Any = None
    ) -> bool:
        """Return whether the agent may run; override to add custom checks.

        :meth:`run` calls this with ``(state, workflow_type)``. The base
        implementation always returns ``True``.
        """
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Return whether execution produced the expected outcome; override to verify.

        :meth:`run` logs a warning when this returns ``False`` or raises. The
        base implementation always returns ``True``.
        """
        return True

    def get_metrics(self) -> Dict[str, Any]:
        """Return a shallow copy of this agent's metrics."""
        return dict(self.metrics)

    def reset_metrics(self):
        """Reset all counters and the latency average to their initial values."""
        self.metrics = self._fresh_metrics()
        self._latency_samples = 0

    async def health_check(self) -> Dict[str, Any]:
        """Perform a basic health check for the agent.

        Always reports ``"Healthy"``, together with the last run time and the
        error count from ``self.metrics``.
        """
        return {
            "agent": self.agent_name,
            "status": "Healthy",
            "Last Run": self.metrics.get("last_run_at"),
            "errors": self.metrics.get("errors", 0),
        }

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Return ``value.value`` for enum-like objects, otherwise ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    def _extract_business_context(self, state: InvoiceProcessingState) -> Dict[str, Any]:
        """Extract relevant invoice/validation/risk context for reasoning logs.

        Only sections present (truthy) on *state* contribute keys.
        """
        context: Dict[str, Any] = {}
        invoice = state.invoice_data
        if invoice:
            context["vendor"] = invoice.vendor_name
            context["invoice_id"] = invoice.invoice_id
            context["amount"] = invoice.total_amount
        validation = state.validation_result
        if validation:
            context["validation_status"] = self._enum_value(validation.validation_status)
        risk = state.risk_assessment
        if risk:
            context["risk_score"] = risk.risk_score
            context["risk_level"] = self._enum_value(risk.risk_level)
        return context

    def _should_escalate(self, state: InvoiceProcessingState, reason: Optional[str] = None) -> bool:
        """Decide whether the workflow should escalate, and flag *state* if so.

        Delegates to ``state.requires_escalation()``. If that call raises, the
        method escalates anyway (fail-safe). On escalation it sets
        ``escalation_required`` and ``human_review_required`` on *state* and
        appends an audit entry.

        Returns:
            Whether escalation was triggered.
        """
        try:
            result = state.requires_escalation()
        except Exception as exc:
            self.logger.logger.warning(
                f"requires_escalation() raised for {self.agent_name}; escalating: {exc}"
            )
            result = True
        if result:
            self.logger.logger.warning(
                f"Escalation triggered by {self.agent_name}: {reason or 'auto'}"
            )
            state.escalation_required = True
            state.human_review_required = True
            state.add_audit_entry(
                self.agent_name,
                action="Escalation Triggered",
                status=None,
                details={"reason": reason or "auto"},
            )
        return result

    def _log_decision(self, state: InvoiceProcessingState, decision: str,
                      reasoning: str, confidence: Optional[float] = None,
                      process_id: Optional[str] = None):
        """Log an agent decision and record it in the state's audit trail.

        Args:
            state: State receiving the audit entry.
            decision: Short decision label; also used as the audit action.
            reasoning: Free-text justification stored in the entry details.
            confidence: Optional confidence value stored in the details.
            process_id: Forwarded to ``state.add_audit_entry``.
        """
        details: Dict[str, Any] = {
            "decision": decision,
            "reasoning": reasoning,
            "confidence": confidence,
        }
        self.logger.logger.info(
            f"{self.agent_name} decision: {decision} (confidence = {confidence})"
        )
        state.add_audit_entry(
            self.agent_name,
            action=decision,
            status=None,
            details=details,
            process_id=process_id,
        )

class AgentRegistry:
    """Registry for managing agent instances, keyed by ``agent_name``."""

    _logger = logging.getLogger(__name__)

    def __init__(self):
        # agent_name -> agent instance; the first registration for a name wins.
        self._agents: Dict[str, "BaseAgent"] = {}

    def register(self, agent: "BaseAgent") -> None:
        """Register *agent* under its ``agent_name``.

        If an agent with the same name is already registered, the existing one
        is kept, the new one is ignored, and a warning is logged.
        """
        name = agent.agent_name
        if name in self._agents:
            self._logger.warning("%s already registered - skipping", name)
            return
        self._agents[name] = agent

    def get(self, agent_name: str) -> Optional["BaseAgent"]:
        """Return the agent registered as *agent_name*, or ``None``."""
        return self._agents.get(agent_name)

    def list_agents(self) -> List[str]:
        """Return registered agent names in registration order."""
        return list(self._agents)

    def get_all_metrics(self) -> Dict[str, Dict[str, Any]]:
        """Return ``{agent_name: agent.get_metrics()}`` for every agent."""
        return {name: agent.get_metrics() for name, agent in self._agents.items()}

    async def health_check_all(self) -> Dict[str, Dict[str, Any]]:
        """Run each agent's ``health_check`` and collect the results by name.

        A health check that raises is reported as ``"Unhealthy"`` with its error
        message, instead of aborting the whole report.
        """
        results: Dict[str, Dict[str, Any]] = {}
        # Snapshot: registrations made while we await must not break iteration.
        for name, agent in list(self._agents.items()):
            try:
                results[name] = await agent.health_check()
            except Exception as exc:
                self._logger.exception("Health check failed for %s", name)
                results[name] = {"agent": name, "status": "Unhealthy", "error": str(exc)}
        return results



# Global agent registry instance (module-level singleton).
agent_registry = AgentRegistry()
# Debug aid for spotting duplicate registries (e.g. the module imported under
# two names); logged rather than printed so importing has no stdout side effect.
logging.getLogger(__name__).debug("Registry instance ID in base: %s", id(agent_registry))