File size: 15,826 Bytes
2b44e69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412

"""Document Agent for Invoice Processing"""

# TODO: Implement agent

import hashlib
import json
import math
import os
import re
from datetime import datetime
from typing import Dict, Any, Optional, List

import fitz  # PyMuPDF
import google.generativeai as genai
import pdfplumber
from dotenv import load_dotenv

from agents.base_agent import BaseAgent
from state import (
    InvoiceProcessingState, InvoiceData, ItemDetail,
    ProcessingStatus, ValidationStatus
)
from utils.logger import StructuredLogger


# Load environment variables from a local .env file (python-dotenv). This has
# to run before the module-level APIKeyBalancer further down reads the
# GEMINI_API_KEY_* variables via os.getenv.
load_dotenv()
# Module-level structured logger for this file.
logger = StructuredLogger("DocumentAgent")

def safe_json_parse(result_text: str):
    """Decode the JSON payload of an LLM reply.

    A surrounding Markdown code fence (```` ```json ... ``` ````) is removed
    first. If the remainder is still not valid JSON, the span from the first
    ``{`` to the last ``}`` is decoded instead; when either brace is missing
    the original ``json.JSONDecodeError`` is re-raised.
    """
    body = re.sub(r"^```[a-zA-Z]*\n|```$", "", result_text.strip())
    try:
        return json.loads(body)
    except json.JSONDecodeError:
        # The model sometimes wraps the object in prose; retry on the brace span.
        first_brace = body.find("{")
        last_brace = body.rfind("}")
        if first_brace < 0 or last_brace < 0:
            raise
        return json.loads(body[first_brace:last_brace + 1])

def to_float(value):
    """Coerce an extracted numeric/monetary value to ``float``.

    Numbers are converted directly. Strings have thousands separators and the
    currency symbols ``$``, ``₹``, ``€`` and ``£`` removed before parsing.
    Anything that cannot be parsed -- ``None``, other types, text such as
    ``"N/A"`` -- and any non-finite result (``nan``/``inf``) yields ``0.0``,
    so a bad value can never poison downstream totals.
    """
    if isinstance(value, (int, float)):
        result = float(value)
    elif isinstance(value, str):
        cleaned = value.replace(",", "")
        for symbol in ("$", "₹", "€", "£"):
            cleaned = cleaned.replace(symbol, "")
        try:
            result = float(cleaned.strip())
        except ValueError:
            return 0.0
    else:
        return 0.0
    # float() happily accepts "nan"/"inf"; treat those as unparseable too.
    return result if math.isfinite(result) else 0.0

def parse_date_safe(date_str, formats=None):
    """Parse a model-extracted date string into a ``datetime.date``.

    Args:
        date_str: Text such as ``"Mar 06 2012"`` or ``"2012-03-06"``. Falsy
            and non-string values yield ``None`` instead of raising.
        formats: Optional iterable of ``strptime`` formats tried in order.
            By default the historic four formats are tried first (so inputs
            that parsed before still parse identically), followed by other
            unambiguous layouts such as full month names.

    Returns:
        The date from the first matching format, or ``None`` if none match.
    """
    if not date_str or not isinstance(date_str, str):
        return None
    if formats is None:
        formats = (
            "%b %d %Y", "%b %d, %Y", "%Y-%m-%d", "%d-%b-%Y",
            # Extra layouts; none is day/month ambiguous.
            "%B %d %Y", "%B %d, %Y", "%d %b %Y", "%d %B %Y", "%Y/%m/%d",
        )
    candidate = date_str.strip()
    for fmt in formats:
        try:
            return datetime.strptime(candidate, fmt).date()
        except ValueError:
            continue
    return None


from collections import defaultdict
class APIKeyBalancer:
    """Hand out API keys from a pool, preferring the key with the fewest
    recorded errors and, among those, the fewest uses.

    Usage and error counters are persisted to ``SAVE_FILE`` so the balancing
    survives restarts. On disk, counters are stored under a SHA-256
    fingerprint of each key rather than the key itself, so the stats file
    never contains secrets. Files written by older versions (keyed by the raw
    key) are still read. In memory, ``usage``/``errors`` remain keyed by the
    raw key.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        # Unset environment variables arrive as None (or ""); drop them so a
        # missing key is never handed out while a real one is available.
        self.keys = [k for k in keys if k]
        self.usage = defaultdict(int)
        self.errors = defaultdict(int)
        self.load()

    @staticmethod
    def _fingerprint(key):
        """Return a stable, non-reversible label for *key* that is safe to store."""
        return hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]

    def load(self):
        """Restore persisted counters for the keys currently in the pool.

        Each key is looked up by fingerprint first, then by its raw value
        (legacy format). A missing, unreadable or malformed file leaves the
        counters at zero: the stats only steer load balancing, so losing them
        must not prevent the module from importing.
        """
        try:
            with open(self.SAVE_FILE, encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            return
        if not isinstance(data, dict):
            return
        for section, counters in (("usage", self.usage), ("errors", self.errors)):
            stored = data.get(section)
            if not isinstance(stored, dict):
                continue
            for key in self.keys:
                for label in (self._fingerprint(key), key):
                    if label in stored:
                        counters[key] = stored[label]
                        break

    def save(self):
        """Write the counters to ``SAVE_FILE``, keyed by fingerprint (no secrets)."""
        snapshot = {
            "usage": {self._fingerprint(k): self.usage[k] for k in self.keys},
            "errors": {self._fingerprint(k): self.errors[k] for k in self.keys},
        }
        with open(self.SAVE_FILE, "w", encoding="utf-8") as fh:
            json.dump(snapshot, fh)

    def get_best_key(self):
        """Return the least-errored, then least-used key and record one use.

        Returns ``None`` when the pool is empty (no key configured).
        """
        if not self.keys:
            return None
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record one failure against *key* so it is de-prioritised next time."""
        self.errors[key] += 1
        self.save()

        
# Shared pool of Gemini API keys, read from the environment in this order.
# GEMINI_API_KEY_4 and GEMINI_API_KEY_7 are currently commented out.
balancer = APIKeyBalancer([
    os.getenv(env_var)
    for env_var in (
        "GEMINI_API_KEY_1",
        "GEMINI_API_KEY_2",
        "GEMINI_API_KEY_3",
        # "GEMINI_API_KEY_4",
        "GEMINI_API_KEY_5",
        "GEMINI_API_KEY_6",
        # "GEMINI_API_KEY_7",
    )
])


class DocumentAgent(BaseAgent):
    """Agent responsible for document processing and invoice data extraction.

    ``execute`` extracts the raw text of the PDF named by ``state.file_name``,
    asks Gemini to structure it as JSON, maps that JSON onto ``InvoiceData`` /
    ``ItemDetail``, applies small rule-based fix-ups and attaches a heuristic
    extraction-confidence score (a percentage, 0-99).
    """

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__("document_agent", config)
        self.logger = StructuredLogger("DocumentAgent")
        # Take the least-errored / least-used key from the shared pool. The key
        # is a secret, so it is never printed or logged.
        self.api_key = balancer.get_best_key()
        if not self.api_key:
            self.logger.logger.warning("[DocumentAgent] No Gemini API key configured")

        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-2.5-flash")

    def generate(self, prompt):
        """Send *prompt* to the Gemini model and return its raw response.

        On failure the error is charged to this agent's API key in the shared
        balancer (so later agents prefer other keys) and the exception is
        re-raised unchanged.
        """
        try:
            return self.model.generate_content(prompt)
        except Exception as exc:
            balancer.report_error(self.api_key)
            # Do not dump balancer.keys/usage/errors: they contain raw API keys.
            self.logger.logger.warning(f"[DocumentAgent] Gemini request failed: {exc}")
            raise

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Return True if ``state.file_name`` is set and exists on disk; log an error otherwise."""
        if not state.file_name or not os.path.exists(state.file_name):
            self.logger.logger.error(f"[Document Agent] Missing or invalid file: {state.file_name}")
            return False
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Return True if extraction produced invoice data with a positive total."""
        return bool(state.invoice_data and state.invoice_data.total > 0)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Extract invoice data for ``state.file_name`` and store it on *state*.

        On success ``state.invoice_data`` is populated (with its extraction
        confidence) and ``overall_status`` becomes IN_PROGRESS. On any failure
        ``overall_status`` becomes FAILED; the method never raises and always
        returns *state*.
        """
        self.logger.logger.info(f"Executing Document Agent for file: {state.file_name}")

        if not self._validate_preconditions(state, workflow_type):
            state.overall_status = ProcessingStatus.FAILED
            self._log_decision(state, "Extraction Failed", "Preconditions not met", confidence=0.0)
            # Stop here: without a readable file there is nothing to extract,
            # and continuing would spend a Gemini call on empty text and could
            # overwrite the FAILED status.
            return state

        try:
            raw_text = await self._extract_text_from_pdf(state.file_name)
            invoice_data = await self._parse_invoice_with_ai(raw_text)
            invoice_data = await self._enhance_invoice_data(invoice_data, raw_text)
            invoice_data.file_name = state.file_name
            state.invoice_data = invoice_data
            state.overall_status = ProcessingStatus.IN_PROGRESS
            state.current_agent = self.agent_name
            state.updated_at = datetime.utcnow()

            # Re-score after enhancement, which may have filled customer_name.
            confidence = self._calculate_extraction_confidence(invoice_data, raw_text)
            state.invoice_data.extraction_confidence = confidence
            self._log_decision(
                state,
                "Extraction Successful",
                "PDF text successfully extracted and parsed by AI",
                confidence,
                state.process_id
            )
            return state
        except Exception as e:
            self.logger.logger.exception(f"[Document Agent] Extraction failed: {e}")
            state.overall_status = ProcessingStatus.FAILED
            self._should_escalate(state, reason=str(e))
            return state

    async def _extract_text_from_pdf(self, file_name: str) -> str:
        """Return the concatenated text of every page in *file_name*.

        PyMuPDF is tried first. If it raises or yields fewer than 5 non-blank
        characters, the text is re-extracted from scratch with pdfplumber.
        Returns "" if pdfplumber fails as well.
        """
        text = ""
        try:
            self.logger.logger.info("[DocumentAgent] Extracting text using PyMuPDF...")
            with fitz.open(file_name) as doc:
                for page in doc:
                    text += page.get_text()
            if len(text.strip()) < 5:
                raise ValueError("PyMuPDF extraction too short, switching to PDFPlumber")
        except Exception as e:
            self.logger.logger.info("[DocumentAgent] Fallback to PDFPlumber...")
            self.logger.logger.warning(f"[DocumentAgent] PyMuPDF extraction unusable: {e}")
            # Discard any partial PyMuPDF output so pages are not duplicated.
            text = ""
            try:
                with pdfplumber.open(file_name) as pdf:
                    for page in pdf.pages:
                        text += page.extract_text() or ""
            except Exception as e2:
                self.logger.logger.error(f"[DocumentAgent] PDFPlumber failed: {e2}")
                text = ""
        return text

    async def _parse_invoice_with_ai(self, text: str) -> InvoiceData:
        """Ask Gemini to structure *text* as invoice JSON and map it onto ``InvoiceData``.

        Only the first 8000 characters of *text* are sent. Monetary fields go
        through ``to_float``; a missing or null item quantity means one unit.

        Raises:
            ValueError: if the reply contains no JSON, or the JSON is not an object.
            Exception: whatever the Gemini client raises (see ``generate``).
        """
        self.logger.logger.info("[DocumentAgent] Parsing invoice data using Gemini AI...")
        prompt = f"""
        Extract structured invoice information as JSON with fields:
        invoice_number, order_id, customer_name, due_date, ship_to, ship_mode,
        subtotal, discount, shipping_cost, total, and item_details (item_name, quantity, rate, amount).

        Important Note: If an item description continues on multiple lines, combine them into one item_name. Check intelligently
        that if at all there will be more than one item then it should have more numbers.
        So extract by verifying that is there only one item or more than one.

        Input Text:
        {text[:8000]}
        """
        response = self.generate(prompt)
        data = safe_json_parse(response.text.strip())
        if not isinstance(data, dict):
            raise ValueError(f"Expected a JSON object from the model, got {type(data).__name__}")

        items = []
        # `or []` also covers an explicit null from the model.
        for item in data.get("item_details") or []:
            if not isinstance(item, dict):
                continue
            quantity = item.get("quantity")
            items.append(ItemDetail(
                item_name=item.get("item_name"),
                # Missing/null quantity means a single unit; anything else goes
                # through to_float so "1,000" or "2 " parse instead of raising.
                quantity=1.0 if quantity in (None, "") else to_float(quantity),
                rate=to_float(item.get("rate", 0.0)),
                amount=to_float(item.get("amount", 0.0)),
            ))

        invoice_data = InvoiceData(
            invoice_number=data.get("invoice_number"),
            order_id=data.get("order_id"),
            customer_name=data.get("customer_name"),
            due_date=parse_date_safe(data.get("due_date")),
            ship_to=data.get("ship_to"),
            ship_mode=data.get("ship_mode"),
            subtotal=to_float(data.get("subtotal", 0.0)),
            discount=to_float(data.get("discount", 0.0)),
            shipping_cost=to_float(data.get("shipping_cost", 0.0)),
            total=to_float(data.get("total", 0.0)),
            item_details=items,
            raw_text=text,
        )
        confidence = self._calculate_extraction_confidence(invoice_data, text)
        invoice_data.extraction_confidence = confidence
        self.logger.logger.info("AI output successfully parsed into JSON format")
        return invoice_data

    async def _enhance_invoice_data(self, invoice_data: InvoiceData, raw_text: str) -> InvoiceData:
        """Fill gaps the model left using simple rules on *raw_text*.

        Currently: when no customer name was extracted, use the first
        non-blank line following the first line that contains "Invoice To".
        """
        if invoice_data.customer_name or not raw_text:
            return invoice_data
        lines = raw_text.split("\n")
        for i, line in enumerate(lines):
            if "Invoice To" in line:
                # Guard against "Invoice To" being the last line, and skip blanks.
                following = (ln.strip() for ln in lines[i + 1:] if ln.strip())
                invoice_data.customer_name = next(following, invoice_data.customer_name)
                break
        return invoice_data

    def _categorize_item(self, item_name: str) -> str:
        """Ask Gemini for a product category for *item_name*.

        Returns "Unknown" when the reply is not parseable JSON or carries no
        "category" value.
        """
        name = (item_name or "").lower()
        prompt = f"""
        Extract the category of the Item from the item details very intelligently
        so that we can get the category in which the item belongs to very efficiently:
        Example: "Electronics", "Furniture", "Software", etc.....
        Input Text- The item is given below (provide the category in JSON format like -- category: 'extracted category') ---->
        {name}
        """
        response = self.generate(prompt)
        try:
            parsed = safe_json_parse(response.text.strip())
        except ValueError:  # json.JSONDecodeError is a ValueError
            return "Unknown"
        if isinstance(parsed, dict) and parsed.get("category"):
            return str(parsed["category"])
        return "Unknown"

    def _calculate_extraction_confidence(self, invoice_data: InvoiceData, raw_text: str) -> float:
        """
        Intelligent confidence scoring for extracted invoice data.
        Combines presence, consistency, and numeric sanity checks.

        Returns a percentage in [0, 99] (whole number) and also stores that
        same value on ``invoice_data.extraction_confidence``.
        """
        weight = {
            "invoice_number": 0.1,
            "order_id": 0.05,
            "customer_name": 0.1,
            "due_date": 0.05,
            "ship_to": 0.05,
            "item_details": 0.25,
            "total_consistency": 0.25,
            "currency_detected": 0.05,
            "text_match_bonus": 0.1
        }
        raw_text = raw_text or ""
        text_lower = raw_text.lower()
        score = 0.0

        # Presence-based confidence
        if invoice_data.invoice_number:
            score += weight["invoice_number"]
        if invoice_data.order_id:
            score += weight["order_id"]
        if invoice_data.customer_name:
            score += weight["customer_name"]
        if invoice_data.ship_to:
            score += weight["ship_to"]
        if invoice_data.item_details:
            score += weight["item_details"]

        # Due date: reward agreement between "a due date was extracted" and
        # "the document mentions one". Documents say "Due Date", not "due_date".
        mentions_due_date = "due date" in text_lower or "due_date" in text_lower
        if bool(invoice_data.due_date) == mentions_due_date:
            score += weight["due_date"]

        # Currency detection; codes are matched case-insensitively (e.g. "USD").
        if any(c in text_lower for c in ("$", "₹", "€", "usd", "inr", "eur")):
            score += weight["currency_detected"]

        # Numeric sanity: compare the extracted total with the largest
        # amount-like number in the text (presumably the grand total).
        numbers = [
            float(m.replace(",", "").replace("$", "").strip())
            for m in re.findall(r"\$?\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?", raw_text)
            if m
        ]
        if len(numbers) >= 3 and invoice_data.total:
            diff = abs(max(numbers) - invoice_data.total)
            if diff < 5:  # minor difference allowed
                score += weight["total_consistency"]
            elif diff < 50:
                score += weight["total_consistency"] * 0.5

        # Textual verification: extracted identifiers should appear verbatim.
        hits = sum(
            1
            for field in (invoice_data.customer_name, invoice_data.order_id, invoice_data.invoice_number)
            if field and str(field).lower() in text_lower
        )
        if hits >= 2:
            score += weight["text_match_bonus"]

        # Penalty for empty critical fields
        if not invoice_data.total or not invoice_data.customer_name or not invoice_data.invoice_number:
            score *= 0.8

        # Clamp and convert to a percentage; rounding after scaling avoids
        # float noise such as 56.99999999999999.
        final_conf = round(min(score, 0.99) * 100.0, 0)
        invoice_data.extraction_confidence = final_conf
        return final_conf

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform intelligent health diagnostics for the Document Agent.
        Collects operational, performance, and API connectivity metrics.

        Reads ``self.metrics`` (populated elsewhere, presumably by BaseAgent --
        confirm) and returns a display-oriented dict of health indicators.
        """
        executions = 0
        success_rate = 0.0
        avg_duration = 0.0
        failures = 0
        last_run = None

        # 1. Live metrics, if any have been recorded yet.
        if self.metrics:
            executions = self.metrics["processed"]
            avg_duration = self.metrics["avg_latency_ms"]
            failures = self.metrics["errors"]
            last_run = self.metrics["last_run_at"]
            # The epsilon keeps this defined when executions == 0.
            success_rate = (executions - failures) / (executions + 1e-8)

        # 2. API connectivity check (only verifies that a key was configured).
        gemini_ok = bool(self.api_key)
        api_status = "🟢 Active" if gemini_ok else "🔴 Missing Key"

        # 3. Health logic
        overall_status = "🟢 Healthy"
        if not gemini_ok or failures > 3:
            overall_status = "🟠 Degraded"
        if executions > 0 and success_rate < 0.5:
            overall_status = "🔴 Unhealthy"

        # 4. Extended agent diagnostics
        metrics_data = {
            "Agent": "Document Agent 🧾",
            "Executions": executions,
            "Success Rate (%)": round(success_rate * 100, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": failures,
            "API Status": api_status,
            "Last Run": str(last_run) if last_run else "Not applicable",
            "Overall Health": overall_status,
        }

        self.logger.logger.info(f"[HealthCheck] Document Agent metrics: {metrics_data}")
        return metrics_data