"""Document Agent for Invoice Processing.

Extracts raw text from an invoice PDF (PyMuPDF first, pdfplumber as a
fallback), asks a Gemini model to turn that text into structured JSON, and
maps the JSON onto the project's ``InvoiceData`` / ``ItemDetail`` models.
"""

import json
import os
import re
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional

import fitz  # PyMuPDF
import google.generativeai as genai
import pdfplumber
from dotenv import load_dotenv

from agents.base_agent import BaseAgent
from state import (
    InvoiceProcessingState,
    InvoiceData,
    ItemDetail,
    ProcessingStatus,
    ValidationStatus,
)
from utils.logger import StructuredLogger

load_dotenv()

logger = StructuredLogger("DocumentAgent")

# Leading ```lang fence / trailing ``` fence that LLMs often wrap JSON in.
_CODE_FENCE_RE = re.compile(r"^```[a-zA-Z]*\n|```$")

# Date layouts tried, in order, by parse_date_safe().
_DATE_FORMATS = ("%b %d %Y", "%b %d, %Y", "%Y-%m-%d", "%d-%b-%Y")

# Money-like tokens: optional "$", thousands separators, optional cents.
_AMOUNT_RE = re.compile(r"\$?\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?")

# "due date" as it appears on an invoice ("Due Date", "due_date", "duedate").
_DUE_DATE_RE = re.compile(r"due[\s_]*date")


def safe_json_parse(result_text: str):
    """Parse JSON out of a model response.

    Markdown code fences are stripped first. If the remainder is still not
    valid JSON, the span from the first ``{`` to the last ``}`` is parsed
    instead.

    Raises:
        json.JSONDecodeError: if no parseable JSON object can be found.
    """
    cleaned = _CODE_FENCE_RE.sub("", result_text.strip())
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fallback for responses that wrap the JSON object in prose.
        start, end = cleaned.find("{"), cleaned.rfind("}") + 1
        if 0 <= start < end:
            return json.loads(cleaned[start:end])
        raise


def to_float(value):
    """Coerce *value* to ``float``, returning ``0.0`` when it cannot be parsed.

    Numbers are converted directly; strings have ``,`` and ``$`` removed
    before conversion. Any other type (including ``None``) yields ``0.0``.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace(",", "").replace("$", "").strip())
        except (ValueError, TypeError):
            return 0.0
    return 0.0


def parse_date_safe(date_str):
    """Parse *date_str* into a ``date`` using the known invoice formats.

    Returns ``None`` for empty input, non-string input, or text that matches
    none of the supported formats.
    """
    if not date_str or not isinstance(date_str, str):
        return None
    candidate = date_str.strip()
    for fmt in _DATE_FORMATS:
        try:
            return datetime.strptime(candidate, fmt).date()
        except ValueError:
            continue
    return None


class APIKeyBalancer:
    """Pick the least-errored, least-used API key and persist the counters.

    NOTE(review): the counters are keyed by the raw API key, so SAVE_FILE
    contains the keys in plain text. Consider keying by a hash instead; not
    changed here because it would orphan existing stats files.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        # os.getenv() returns None for unset variables; an unset key must
        # never be selectable, otherwise the agent is configured with None.
        self.keys = [key for key in keys if key]
        self.usage = defaultdict(int)
        self.errors = defaultdict(int)
        self.load()

    def load(self):
        """Load persisted counters; a missing or unreadable file is ignored."""
        if not os.path.exists(self.SAVE_FILE):
            return
        try:
            with open(self.SAVE_FILE, encoding="utf-8") as stats_file:
                data = json.load(stats_file)
        except (OSError, json.JSONDecodeError) as exc:
            # A corrupt stats file must not prevent the module from importing.
            logger.logger.error(
                f"[APIKeyBalancer] Could not read {self.SAVE_FILE}: {exc}"
            )
            return
        if not isinstance(data, dict):
            return
        self.usage.update(data.get("usage") or {})
        self.errors.update(data.get("errors") or {})

    def save(self):
        """Persist the counters to SAVE_FILE (best effort)."""
        try:
            with open(self.SAVE_FILE, "w", encoding="utf-8") as stats_file:
                json.dump({"usage": self.usage, "errors": self.errors}, stats_file)
        except OSError as exc:
            # Losing statistics is preferable to failing key selection.
            logger.logger.error(
                f"[APIKeyBalancer] Could not write {self.SAVE_FILE}: {exc}"
            )

    def get_best_key(self):
        """Return the key with the fewest errors, then the fewest uses.

        The chosen key's usage counter is incremented and saved. Returns
        ``None`` when no key is configured.
        """
        if not self.keys:
            logger.logger.error("[APIKeyBalancer] No Gemini API key is configured")
            return None
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record a failed call for *key*; falsy keys are ignored."""
        if not key:
            return
        self.errors[key] += 1
        self.save()


# GEMINI_API_KEY_4 and GEMINI_API_KEY_7 were commented out in the original
# key list and are deliberately not included here.
balancer = APIKeyBalancer([
    os.getenv("GEMINI_API_KEY_1"),
    os.getenv("GEMINI_API_KEY_2"),
    os.getenv("GEMINI_API_KEY_3"),
    os.getenv("GEMINI_API_KEY_5"),
    os.getenv("GEMINI_API_KEY_6"),
])


class DocumentAgent(BaseAgent):
    """Agent responsible for document processing and invoice data extraction"""

    # Contribution of each signal to the extraction confidence (sums to 1.0).
    _CONFIDENCE_WEIGHTS = {
        "invoice_number": 0.1,
        "order_id": 0.05,
        "customer_name": 0.1,
        "due_date": 0.05,
        "ship_to": 0.05,
        "item_details": 0.25,
        "total_consistency": 0.25,
        "currency_detected": 0.05,
        "text_match_bonus": 0.1,
    }

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Select an API key via the module-level balancer and build the model."""
        super().__init__("document_agent", config)
        self.logger = StructuredLogger("DocumentAgent")
        # The key is a secret: it is never printed or logged.
        self.api_key = balancer.get_best_key()
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-2.5-flash")

    def generate(self, prompt):
        """Send *prompt* to the Gemini model and return the raw response.

        Any exception is reported to the key balancer against this agent's
        key and then re-raised unchanged.
        """
        try:
            return self.model.generate_content(prompt)
        except Exception:
            balancer.report_error(self.api_key)
            self.logger.logger.exception(
                "[DocumentAgent] Gemini generate_content call failed"
            )
            raise

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Return True when ``state.file_name`` names an existing file."""
        if not state.file_name or not os.path.exists(state.file_name):
            self.logger.logger.error(f"[Document Agent] Missing or invalid file: {state.file_name}")
            return False
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Return True when invoice data exists and its total is positive."""
        return bool(state.invoice_data and state.invoice_data.total > 0)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Extract invoice data from ``state.file_name`` into *state*.

        On success ``state.invoice_data`` is populated and the status is set
        to IN_PROGRESS. On any failure the status is set to FAILED and the
        state is returned; exceptions are not propagated to the caller.
        """
        self.logger.logger.info(f"Executing Document Agent for file: {state.file_name}")

        if not self._validate_preconditions(state, workflow_type):
            state.overall_status = ProcessingStatus.FAILED
            self._log_decision(state, "Extraction Failed", "Preconditions not met", confidence=0.0)
            # Stop here: without a readable file there is nothing to extract,
            # and continuing would send an empty prompt to the model.
            return state

        try:
            raw_text = await self._extract_text_from_pdf(state.file_name)
            if not raw_text.strip():
                # Both extractors produced nothing; asking the model to parse
                # empty text would only yield invented data.
                raise ValueError("No text could be extracted from the PDF")

            invoice_data = await self._parse_invoice_with_ai(raw_text)
            invoice_data = await self._enhance_invoice_data(invoice_data, raw_text)
            invoice_data.file_name = state.file_name

            state.invoice_data = invoice_data
            state.overall_status = ProcessingStatus.IN_PROGRESS
            state.current_agent = self.agent_name
            state.updated_at = datetime.utcnow()

            # Recomputed after enhancement because customer_name may have changed.
            confidence = self._calculate_extraction_confidence(invoice_data, raw_text)
            state.invoice_data.extraction_confidence = confidence

            self._log_decision(
                state,
                "Extraction Successful",
                "PDF text successfully extracted and parsed by AI",
                confidence,
                state.process_id,
            )
            return state
        except Exception as e:
            self.logger.logger.exception(f"[Document Agent] Extraction failed: {e}")
            state.overall_status = ProcessingStatus.FAILED
            self._should_escalate(state, reason=str(e))
            return state

    async def _extract_text_from_pdf(self, file_name: str) -> str:
        """Return the text of *file_name*, or ``""`` if every extractor fails.

        PyMuPDF is tried first. If it raises or yields fewer than five
        non-blank characters, pdfplumber is used instead, starting from an
        empty buffer so partial PyMuPDF output is not duplicated.
        """
        try:
            self.logger.logger.info("[DocumentAgent] Extracting text using PyMuPDF...")
            with fitz.open(file_name) as doc:
                text = "".join(page.get_text() for page in doc)
            if len(text.strip()) < 5:
                raise ValueError("PyMuPDF extraction too short, switching to PDFPlumber")
            return text
        except Exception as exc:
            self.logger.logger.info(f"[DocumentAgent] Fallback to PDFPlumber... ({exc})")

        try:
            with pdfplumber.open(file_name) as pdf:
                return "".join(page.extract_text() or "" for page in pdf.pages)
        except Exception as exc:
            self.logger.logger.error(f"[DocumentAgent] PDFPlumber failed :{exc}")
            return ""

    async def _parse_invoice_with_ai(self, text: str) -> InvoiceData:
        """Ask Gemini for structured invoice JSON and map it to ``InvoiceData``.

        Only the first 8000 characters of *text* are sent to the model.

        Raises:
            json.JSONDecodeError: if the response contains no parseable JSON.
            ValueError: if the parsed JSON is not an object.
        """
        self.logger.logger.info("[DocumentAgent] Parsing invoice data using Gemini AI...")

        prompt = f"""
        Extract structured invoice information as JSON with fields:
        invoice_number, order_id, customer_name, due_date, ship_to, ship_mode,
        subtotal, discount, shipping_cost, total, and item_details (item_name, quantity, rate, amount).
        Important Note: If an item description continues on multiple lines, combine them into one item_name.
        Check intelligently that if at all there will be more than one item then it should have more numbers.
        So extract by verifying that is there only one item or more than one.
        Input Text:
        {text[:8000]}
        """
        response = self.generate(prompt)
        data = safe_json_parse(response.text.strip())
        if not isinstance(data, dict):
            raise ValueError("AI response is not a JSON object")

        items = []
        for item in data.get("item_details") or []:
            if not isinstance(item, dict):
                continue  # skip malformed entries instead of failing the invoice
            raw_quantity = item.get("quantity")
            items.append(ItemDetail(
                item_name=item.get("item_name"),
                # A missing/null quantity defaults to 1; unparseable text to 0.0.
                quantity=1.0 if raw_quantity is None else to_float(raw_quantity),
                rate=to_float(item.get("rate", 0.0)),
                amount=to_float(item.get("amount", 0.0)),
            ))

        invoice_data = InvoiceData(
            invoice_number=data.get("invoice_number"),
            order_id=data.get("order_id"),
            customer_name=data.get("customer_name"),
            due_date=parse_date_safe(data.get("due_date")),
            ship_to=data.get("ship_to"),
            ship_mode=data.get("ship_mode"),
            subtotal=to_float(data.get("subtotal", 0.0)),
            discount=to_float(data.get("discount", 0.0)),
            shipping_cost=to_float(data.get("shipping_cost", 0.0)),
            total=to_float(data.get("total", 0.0)),
            item_details=items,
            raw_text=text,
        )

        confidence = self._calculate_extraction_confidence(invoice_data, text)
        invoice_data.extraction_confidence = confidence
        self.logger.logger.info("AI output successfully parsed into JSON format")
        return invoice_data

    async def _enhance_invoice_data(self, invoice_data: InvoiceData, raw_text: str) -> InvoiceData:
        """Fill a missing customer name from the line after "Invoice To"."""
        if invoice_data.customer_name:
            return invoice_data

        lines = raw_text.split("\n")
        for index, line in enumerate(lines):
            if "Invoice To" in line:
                # The marker may be the last line, in which case no name follows.
                if index + 1 < len(lines):
                    invoice_data.customer_name = lines[index + 1].strip()
                break
        return invoice_data

    def _categorize_item(self, item_name: str) -> str:
        """Ask Gemini for the category of *item_name* and return it.

        Raises:
            json.JSONDecodeError: if the response contains no parseable JSON.
            KeyError: if the parsed JSON has no ``category`` entry.
        """
        name = (item_name or "").lower()
        prompt = f"""
        Extract the category of the Item from the item details very intelligently so that we can get the category in which the item belongs to very efficiently:
        Example: "Electronics", "Furniture", "Software", etc.....
        Input Text- The item is given below (provide the category in JSON format like -- category: 'extracted category') ----> {name}
        """
        response = self.generate(prompt)
        category = safe_json_parse(response.text.strip())
        return category["category"]

    def _calculate_extraction_confidence(self, invoice_data: InvoiceData, raw_text: str) -> float:
        """
        Intelligent confidence scoring for extracted invoice data.
        Combines presence, consistency, and numeric sanity checks.

        Returns the confidence as a percentage (0-99). As a side effect the
        same score is stored on ``invoice_data.extraction_confidence`` as a
        0-0.99 fraction.
        NOTE(review): both callers in this module overwrite that attribute
        with the returned percentage - confirm which scale consumers expect.
        """
        weight = self._CONFIDENCE_WEIGHTS
        raw_text = raw_text or ""
        text_lower = raw_text.lower()
        score = 0.0

        # Presence-based confidence
        if invoice_data.invoice_number:
            score += weight["invoice_number"]
        if invoice_data.order_id:
            score += weight["order_id"]
        if invoice_data.customer_name:
            score += weight["customer_name"]
        # Reward agreement between the extracted due date and the document:
        # either both mention one, or neither does.
        due_date_in_text = bool(_DUE_DATE_RE.search(text_lower))
        if bool(invoice_data.due_date) == due_date_in_text:
            score += weight["due_date"]
        if invoice_data.item_details:
            score += weight["item_details"]

        # Currency detection: symbols in the raw text, ISO codes case-insensitively.
        has_symbol = any(symbol in raw_text for symbol in ("$", "₹", "€"))
        has_code = any(code in text_lower for code in ("usd", "inr", "eur"))
        if has_symbol or has_code:
            score += weight["currency_detected"]

        # Numeric consistency: the largest amount in the text should be the total.
        numbers = [
            float(match.replace(",", "").replace("$", "").strip())
            for match in _AMOUNT_RE.findall(raw_text)
            if match
        ]
        if len(numbers) >= 3 and invoice_data.total:
            diff = abs(max(numbers) - invoice_data.total)
            if diff < 5:  # minor difference allowed
                score += weight["total_consistency"]
            elif diff < 50:
                score += weight["total_consistency"] * 0.5

        # Textual verification: extracted identifiers should appear verbatim.
        identifiers = (
            invoice_data.customer_name,
            invoice_data.order_id,
            invoice_data.invoice_number,
        )
        hits = sum(
            1 for field in identifiers
            if field and str(field).lower() in text_lower
        )
        if hits >= 2:
            score += weight["text_match_bonus"]

        # Penalty for empty critical fields
        if (
            not invoice_data.total
            or not invoice_data.customer_name
            or not invoice_data.invoice_number
        ):
            score *= 0.8

        # Clamp and finalize
        final_conf = round(min(score, 0.99), 2)
        invoice_data.extraction_confidence = final_conf
        # Rounded so 0.57 yields 57.0 rather than 56.99999999999999.
        return round(final_conf * 100.0, 2)

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform intelligent health diagnostics for the Document Agent.
        Collects operational, performance, and API connectivity metrics.

        Reads ``self.metrics`` (keys ``processed``, ``avg_latency_ms``,
        ``errors``, ``last_run_at``) and ``self.api_key``.
        """
        executions = 0
        success_rate = 0.0
        avg_duration = 0.0
        failures = 0
        last_run = None

        # 1. Live metrics, when the agent has recorded any.
        if self.metrics:
            executions = self.metrics["processed"] or 0
            avg_duration = self.metrics["avg_latency_ms"] or 0.0
            failures = self.metrics["errors"] or 0
            last_run = self.metrics["last_run_at"]
            if executions > 0:
                success_rate = (executions - failures) / executions

        # 2. API connectivity check (presence of a key only; no network call).
        gemini_ok = bool(self.api_key)
        api_status = "🟢 Active" if gemini_ok else "🔴 Missing Key"

        # 3. Health logic: the later check overrides the earlier one.
        overall_status = "🟢 Healthy"
        if not gemini_ok or failures > 3:
            overall_status = "🟠 Degraded"
        if executions > 0 and success_rate < 0.5:
            overall_status = "🔴 Unhealthy"

        # 4. Extended agent diagnostics
        metrics_data = {
            "Agent": "Document Agent 🧾",
            "Executions": executions,
            "Success Rate (%)": round(success_rate * 100, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": failures,
            "API Status": api_status,
            "Last Run": str(last_run) if last_run else "Not applicable",
            "Overall Health": overall_status,
        }

        self.logger.logger.info(f"[HealthCheck] Document Agent metrics: {metrics_data}")
        return metrics_data