InvoiceAgenticAI / agents /document_agent.py
PARTHASAKHAPAUL
Restructure project and sync local as source of truth
9107a5d
"""Document Agent for Invoice Processing"""
# Extracts text from invoice PDFs and parses it into InvoiceData using Gemini.
import os
import json
import re
import fitz # PyMuPDF
import pdfplumber
from typing import Dict, Any, Optional, List
import google.generativeai as genai
from dotenv import load_dotenv
from datetime import datetime
from agents.base_agent import BaseAgent
from state import (
InvoiceProcessingState, InvoiceData, ItemDetail,
ProcessingStatus, ValidationStatus
)
from utils.logger import StructuredLogger
# Load variables from a .env file into the process environment so the
# os.getenv() lookups further down in this module can see them.
load_dotenv()
# Module-level structured logger for this file.
logger = StructuredLogger("DocumentAgent")
def safe_json_parse(result_text: str):
    """Decode JSON from model output that may be fenced or wrapped in prose.

    A leading Markdown code fence (with optional language tag) and a trailing
    fence are removed first. If the remainder is not valid JSON, the span from
    the first ``{`` to the last ``}`` is decoded instead.

    Args:
        result_text: Raw text returned by the model.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If neither the cleaned text nor the brace-delimited
            span is valid JSON.
    """
    body = re.sub(r"^```[a-zA-Z]*\n|```$", "", result_text.strip())
    try:
        return json.loads(body)
    except json.JSONDecodeError:
        first_brace = body.find("{")
        last_brace = body.rfind("}")
        # Without both braces there is nothing to salvage: re-raise as-is.
        if first_brace < 0 or last_brace < 0:
            raise
        return json.loads(body[first_brace:last_brace + 1])
def to_float(value):
    """Convert a numeric-looking value to ``float``, defaulting to ``0.0``.

    Numbers are converted directly. Strings have thousands separators and the
    currency symbols ``$``, ``₹`` and ``€`` removed before conversion, so
    amounts such as ``"₹1,200.50"`` parse instead of silently becoming ``0.0``.

    Args:
        value: An int, float, string, or anything else (e.g. ``None``).

    Returns:
        The parsed float, or ``0.0`` when the value is not a number or a
        parseable string.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        # One pass removes every separator/currency character.
        cleaned = value.translate(str.maketrans("", "", ",$₹€")).strip()
        try:
            return float(cleaned)
        except ValueError:
            return 0.0
    return 0.0
def parse_date_safe(date_str):
    """Parse a date string in one of the supported formats, never raising.

    Supported layouts: ``Jan 05 2023``, ``Jan 05, 2023``, ``2023-01-05``,
    ``05-Jan-2023``, plus the same first two with a full month name.

    Args:
        date_str: The candidate date text. Non-string values (``None``,
            numbers, dicts from malformed AI output) are treated as missing.

    Returns:
        A ``datetime.date`` on success, otherwise ``None``.
    """
    # A non-string would fail on .strip(); "safe" means returning None instead.
    if not isinstance(date_str, str):
        return None
    candidate = date_str.strip()
    if not candidate:
        return None
    formats = (
        "%b %d %Y",
        "%b %d, %Y",
        "%Y-%m-%d",
        "%d-%b-%Y",
        "%B %d %Y",
        "%B %d, %Y",
    )
    for fmt in formats:
        try:
            return datetime.strptime(candidate, fmt).date()
        except ValueError:
            continue
    return None
from collections import defaultdict
class APIKeyBalancer:
    """Choose among several API keys by error and usage counts.

    Counters are kept in memory and written to ``SAVE_FILE`` (a JSON file in
    the current working directory) after every change, then reloaded on
    construction.

    NOTE(review): the JSON file uses the raw API keys as dictionary keys, so
    the secrets end up on disk in plaintext. Consider persisting a hash or an
    index instead, and keep the file out of version control.
    """

    SAVE_FILE = "key_stats.json"

    def __init__(self, keys):
        """Store the usable keys and load any persisted counters.

        Args:
            keys: Iterable of API keys. Falsy entries (``None`` from unset
                environment variables, empty strings) are dropped so they can
                never be selected as the "best" key.
        """
        self.keys = [k for k in keys if k]
        self.usage = defaultdict(int)
        self.errors = defaultdict(int)
        self.load()

    def load(self):
        """Merge counters from ``SAVE_FILE`` if it exists and is readable.

        An unreadable or corrupt file is ignored (with a warning) rather than
        raised, because this runs at module import time.
        """
        if not os.path.exists(self.SAVE_FILE):
            return
        try:
            with open(self.SAVE_FILE, encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            import logging

            logging.getLogger(__name__).warning(
                "Ignoring unreadable key stats file %s: %s", self.SAVE_FILE, exc
            )
            return
        if not isinstance(data, dict):
            return
        self.usage.update(data.get("usage") or {})
        self.errors.update(data.get("errors") or {})

    def save(self):
        """Write the current usage and error counters to ``SAVE_FILE``."""
        payload = {"usage": dict(self.usage), "errors": dict(self.errors)}
        with open(self.SAVE_FILE, "w", encoding="utf-8") as fh:
            json.dump(payload, fh)

    def get_best_key(self):
        """Return the key with the fewest errors, breaking ties by usage.

        The chosen key's usage counter is incremented and persisted.

        Returns:
            The selected key, or ``None`` when no usable key was configured.
        """
        if not self.keys:
            return None
        best_key = min(self.keys, key=lambda k: (self.errors[k], self.usage[k]))
        self.usage[best_key] += 1
        self.save()
        return best_key

    def report_error(self, key):
        """Record one failure against ``key`` and persist the counters."""
        self.errors[key] += 1
        self.save()
# Environment variables that may hold a Gemini API key. Slots 4 and 7 are
# currently disabled.
_GEMINI_KEY_ENV_VARS = (
    "GEMINI_API_KEY_1",
    "GEMINI_API_KEY_2",
    "GEMINI_API_KEY_3",
    "GEMINI_API_KEY_5",
    "GEMINI_API_KEY_6",
)

# Shared module-level balancer. Each name is resolved with os.getenv, so an
# unset variable contributes None to the list passed in.
balancer = APIKeyBalancer([os.getenv(env_name) for env_name in _GEMINI_KEY_ENV_VARS])
class DocumentAgent(BaseAgent):
    """Agent responsible for document processing and invoice data extraction.

    ``execute`` runs the pipeline: extract text from the PDF named by
    ``state.file_name``, ask Gemini to turn that text into JSON, build an
    ``InvoiceData`` from it, patch missing fields from the raw text, and score
    the extraction confidence.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Configure the Gemini client with the key chosen by the module balancer.

        Args:
            config: Optional agent configuration, forwarded to ``BaseAgent``.
        """
        super().__init__("document_agent", config)
        self.logger = StructuredLogger("DocumentAgent")
        self.api_key = balancer.get_best_key()
        # The key is a secret: never print or log it.
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel("gemini-2.5-flash")

    def generate(self, prompt):
        """Send ``prompt`` to the Gemini model and return its response.

        Any failure is recorded against the active API key in the balancer and
        then re-raised unchanged.
        """
        try:
            return self.model.generate_content(prompt)
        except Exception:
            balancer.report_error(self.api_key)
            self.logger.logger.error(
                "[DocumentAgent] Gemini request failed; error recorded for the active API key"
            )
            raise

    def _validate_preconditions(self, state: InvoiceProcessingState, workflow_type) -> bool:
        """Return True when ``state.file_name`` is set and exists on disk."""
        if not state.file_name or not os.path.exists(state.file_name):
            self.logger.logger.error(f"[Document Agent] Missing or invalid file: {state.file_name}")
            return False
        return True

    def _validate_postconditions(self, state: InvoiceProcessingState) -> bool:
        """Return True when invoice data exists and its total is positive."""
        return bool(state.invoice_data and state.invoice_data.total > 0)

    async def execute(self, state: InvoiceProcessingState, workflow_type) -> InvoiceProcessingState:
        """Extract invoice data from the file in ``state`` and store it on ``state``.

        Args:
            state: Processing state; ``file_name`` is read, ``invoice_data``,
                ``overall_status``, ``current_agent`` and ``updated_at`` are written.
            workflow_type: Passed through to the precondition check.

        Returns:
            The same ``state`` object. On any failure ``overall_status`` is set
            to ``ProcessingStatus.FAILED``; exceptions are not propagated.
        """
        self.logger.logger.info(f"Executing Document Agent for file: {state.file_name}")
        if not self._validate_preconditions(state, workflow_type):
            state.overall_status = ProcessingStatus.FAILED
            self._log_decision(state, "Extraction Failed", "Preconditions not met", confidence=0.0)
            # Stop here: without a readable file the AI would be called with
            # empty text and the FAILED status would be overwritten.
            return state
        try:
            raw_text = await self._extract_text_from_pdf(state.file_name)
            invoice_data = await self._parse_invoice_with_ai(raw_text)
            invoice_data = await self._enhance_invoice_data(invoice_data, raw_text)
            invoice_data.file_name = state.file_name
            state.invoice_data = invoice_data
            state.overall_status = ProcessingStatus.IN_PROGRESS
            state.current_agent = self.agent_name
            # Naive UTC timestamp, kept as-is so it stays comparable with
            # other timestamps written elsewhere in the project.
            state.updated_at = datetime.utcnow()
            # Re-score after enhancement, since fields may have been filled in.
            confidence = self._calculate_extraction_confidence(invoice_data, raw_text)
            state.invoice_data.extraction_confidence = confidence
            self._log_decision(
                state,
                "Extraction Successful",
                "PDF text successfully extracted and parsed by AI",
                confidence,
                state.process_id
            )
            return state
        except Exception as e:
            self.logger.logger.exception(f"[Document Agent] Extraction failed: {e}")
            state.overall_status = ProcessingStatus.FAILED
            self._should_escalate(state, reason=str(e))
            return state

    async def _extract_text_from_pdf(self, file_name: str) -> str:
        """Return the text of a PDF, trying PyMuPDF first and pdfplumber second.

        The fallback is used when PyMuPDF raises or yields fewer than five
        non-blank characters. It starts from scratch, so partial PyMuPDF output
        is never mixed into (and duplicated in) the pdfplumber result.

        Returns:
            The extracted text, or ``""`` when both extractors fail.
        """
        try:
            self.logger.logger.info("[DocumentAgent] Extracting text using PyMuPDF...")
            with fitz.open(file_name) as doc:
                text = "".join(page.get_text() for page in doc)
            if len(text.strip()) < 5:
                raise ValueError("PyMuPDF extraction too short, switching to PDFPlumber")
            return text
        except Exception as e:
            self.logger.logger.info(f"[DocumentAgent] Fallback to PDFPlumber... ({e})")
        try:
            with pdfplumber.open(file_name) as pdf:
                return "".join(page.extract_text() or "" for page in pdf.pages)
        except Exception as e2:
            self.logger.logger.error(f"[DocumentAgent] PDFPlumber failed :{e2}")
            return ""

    async def _parse_invoice_with_ai(self, text: str) -> InvoiceData:
        """Ask Gemini for structured invoice fields and build an ``InvoiceData``.

        Only the first 8000 characters of ``text`` are sent to the model. The
        invoice text is not echoed to stdout because it may contain customer data.

        Raises:
            ValueError: If the model's JSON is not an object.
            json.JSONDecodeError: If the model output cannot be decoded.
        """
        self.logger.logger.info("[DocumentAgent] Parsing invoice data using Gemini AI...")
        prompt = f"""
Extract structured invoice information as JSON with fields:
invoice_number, order_id, customer_name, due_date, ship_to, ship_mode,
subtotal, discount, shipping_cost, total, and item_details (item_name, quantity, rate, amount).
Important Note: If an item description continues on multiple lines, combine them into one item_name. Check intelligently
that if at all there will be more than one item then it should have more numbers.
So extract by verifying that is there only one item or more than one.
Input Text:
{text[:8000]}
"""
        response = self.generate(prompt)
        data = safe_json_parse(response.text.strip())
        if not isinstance(data, dict):
            raise ValueError("AI response is not a JSON object")

        items = []
        # "item_details" may be absent or JSON null; both mean "no items".
        for item in data.get("item_details") or []:
            if not isinstance(item, dict):
                continue
            raw_quantity = item.get("quantity")
            items.append(ItemDetail(
                item_name=item.get("item_name"),
                # A missing quantity defaults to 1; formatted strings such as
                # "1,000" go through to_float instead of crashing float().
                quantity=1.0 if raw_quantity is None else to_float(raw_quantity),
                rate=to_float(item.get("rate", 0.0)),
                amount=to_float(item.get("amount", 0.0)),
            ))

        invoice_data = InvoiceData(
            invoice_number=data.get("invoice_number"),
            order_id=data.get("order_id"),
            customer_name=data.get("customer_name"),
            due_date=parse_date_safe(data.get("due_date")),
            ship_to=data.get("ship_to"),
            ship_mode=data.get("ship_mode"),
            subtotal=to_float(data.get("subtotal", 0.0)),
            discount=to_float(data.get("discount", 0.0)),
            shipping_cost=to_float(data.get("shipping_cost", 0.0)),
            total=to_float(data.get("total", 0.0)),
            item_details=items,
            raw_text=text,
        )
        invoice_data.extraction_confidence = self._calculate_extraction_confidence(invoice_data, text)
        self.logger.logger.info("AI output successfully parsed into JSON format")
        return invoice_data

    async def _enhance_invoice_data(self, invoice_data: InvoiceData, raw_text: str) -> InvoiceData:
        """Fill a missing customer name from the line after an "Invoice To" label.

        The first non-blank line following the label is used. If the label is
        the last line of the text, the invoice is returned unchanged.
        """
        if invoice_data.customer_name or "Invoice To" not in raw_text:
            return invoice_data
        lines = raw_text.split("\n")
        for i, line in enumerate(lines):
            if "Invoice To" not in line:
                continue
            following = (candidate.strip() for candidate in lines[i + 1:])
            name = next((candidate for candidate in following if candidate), None)
            if name:
                invoice_data.customer_name = name
            break
        return invoice_data

    def _categorize_item(self, item_name: str) -> str:
        """Ask Gemini for the category of an item and return it.

        Raises:
            KeyError: If the model's JSON has no ``category`` field.
        """
        name = (item_name or "").lower()
        prompt = f"""
Extract the category of the Item from the item details very intelligently
so that we can get the category in which the item belongs to very efficiently:
Example: "Electronics", "Furniture", "Software", etc.....
Input Text- The item is given below (provide the category in JSON format like -- category: 'extracted category') ---->
{name}
"""
        response = self.generate(prompt)
        category = safe_json_parse(response.text.strip())
        return category['category']

    def _calculate_extraction_confidence(self, invoice_data: InvoiceData, raw_text: str) -> float:
        """
        Intelligent confidence scoring for extracted invoice data.
        Combines presence, consistency, and numeric sanity checks.

        Returns a percentage in the range 0-99 and stores the same value in
        ``invoice_data.extraction_confidence``, so the stored and returned
        values are always on the same scale.
        """
        weight = {
            "invoice_number": 0.1,
            "order_id": 0.05,
            "customer_name": 0.1,
            "due_date": 0.05,
            "ship_to": 0.05,
            "item_details": 0.25,
            "total_consistency": 0.25,
            "currency_detected": 0.05,
            "text_match_bonus": 0.1
        }
        raw_text = raw_text or ""
        text_lower = raw_text.lower()
        score = 0.0

        # Presence-based confidence
        for field_name in ("invoice_number", "order_id", "customer_name", "ship_to"):
            if getattr(invoice_data, field_name):
                score += weight[field_name]

        # Credit when extraction agrees with the document: a due date was found
        # and the text mentions one, or neither. Invoice text says "due date";
        # the underscore form is kept for backward compatibility.
        due_label_present = "due date" in text_lower or "due_date" in text_lower
        if bool(invoice_data.due_date) == due_label_present:
            score += weight["due_date"]

        if invoice_data.item_details:
            score += weight["item_details"]

        # Currency detection (case-insensitive so "USD"/"INR"/"EUR" also count)
        if any(marker in text_lower for marker in ("$", "₹", "€", "usd", "inr", "eur")):
            score += weight["currency_detected"]

        # Numeric consistency: the largest amount in the text should be close
        # to the extracted total.
        amounts = [
            float(match.replace(",", "").replace("$", "").strip())
            for match in re.findall(r"\$?\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?", raw_text)
        ]
        if len(amounts) >= 3 and invoice_data.total:
            diff = abs(max(amounts) - invoice_data.total)
            if diff < 5:  # minor difference allowed
                score += weight["total_consistency"]
            elif diff < 50:
                score += weight["total_consistency"] * 0.5

        # Textual verification: extracted identifiers should appear in the text
        identifiers = (invoice_data.customer_name, invoice_data.order_id, invoice_data.invoice_number)
        hits = sum(1 for value in identifiers if value and str(value).lower() in text_lower)
        if hits >= 2:
            score += weight["text_match_bonus"]

        # Penalty for empty critical fields
        if not (invoice_data.total and invoice_data.customer_name and invoice_data.invoice_number):
            score *= 0.8

        # Clamp to 0.99, keep two-decimal granularity, then express as a
        # percentage; the outer round removes float noise such as 56.999999.
        confidence_pct = round(round(min(score, 0.99), 2) * 100.0, 2)
        invoice_data.extraction_confidence = confidence_pct
        return confidence_pct

    async def health_check(self) -> Dict[str, Any]:
        """
        Perform intelligent health diagnostics for the Document Agent.
        Collects operational, performance, and API connectivity metrics.

        Reads ``self.metrics`` (keys ``processed``, ``avg_latency_ms``,
        ``errors``, ``last_run_at``) and ``self.api_key``; no network call is
        made, so "API Status" only reflects whether a key is configured.
        """
        executions = 0
        success_rate = 0.0
        avg_duration = 0.0
        failures = 0
        last_run = None

        # 1. Live metrics, when the agent has recorded any
        if self.metrics:
            executions = self.metrics["processed"]
            avg_duration = self.metrics["avg_latency_ms"]
            failures = self.metrics["errors"]
            last_run = self.metrics["last_run_at"]
            success_rate = (executions - failures) / executions if executions else 0.0

        # 2. API connectivity check (key presence only)
        gemini_ok = bool(self.api_key)
        api_status = "🟢 Active" if gemini_ok else "🔴 Missing Key"

        # 3. Health logic
        overall_status = "🟢 Healthy"
        if not gemini_ok or failures > 3:
            overall_status = "🟠 Degraded"
        if executions > 0 and success_rate < 0.5:
            overall_status = "🔴 Unhealthy"

        # 4. Extended agent diagnostics
        metrics_data = {
            "Agent": "Document Agent 🧾",
            "Executions": executions,
            "Success Rate (%)": round(success_rate * 100, 2),
            "Avg Duration (ms)": round(avg_duration, 2),
            "Total Failures": failures,
            "API Status": api_status,
            "Last Run": str(last_run) if last_run else "Not applicable",
            "Overall Health": overall_status,
        }
        self.logger.logger.info(f"[HealthCheck] Document Agent metrics: {metrics_data}")
        return metrics_data