Vaibuzzz's picture
Upload folder using huggingface_hub
faa05c3 verified
"""
Financial Document Extraction Pipeline — Groq Edition.
Replaces the local llama-cpp-python / GGUF inference with a
server-side Groq API call to llama-3.3-70b-versatile.
Features:
- Zero GPU / VRAM requirements — pure API call
- Sub-2-second inference via Groq's LPU hardware
- JSON parsing with multi-strategy fallback
- Pydantic v2 schema validation (unchanged)
- Auditor math check (flags line-item/total mismatches without re-prompting) plus retry on failed attempts
Usage:
from src.extractor import FinancialDocExtractor
extractor = FinancialDocExtractor()
result = extractor.extract("raw document markdown here...")
print(result.json_output) # Raw JSON dict
print(result.validated) # Pydantic model (or None)
print(result.flags) # List of anomaly flags
"""
import os
import re
import json
from typing import Optional, Tuple
from dataclasses import dataclass, field
from groq import Groq, RateLimitError, APITimeoutError
from dotenv import load_dotenv
from src.schema import DocumentExtraction, AnomalyFlag
# Load .env for local development; on HF Spaces the secret is injected automatically.
# Runs once at import time, i.e. before FinancialDocExtractor.__init__ reads GROQ_API_KEY.
load_dotenv()
# System message sent as the first chat turn by FinancialDocExtractor._generate.
# It defines the target JSON schema, the extraction overrides, and the anomaly-flag
# categories the model is asked to emit.
# NOTE(review): the flag category names and the "low|medium|high" severities below
# presumably have to match what src.schema.AnomalyFlag accepts — confirm before editing.
SYSTEM_PROMPT = """You are a Senior Financial Auditor and Data Extraction Expert. Your task is to transform raw document text into a high-precision, structured JSON audit report.
### 1. DATA EXTRACTION HIERARCHY
Extract data into this exact JSON schema:
{
"common": {
"document_type": "invoice|purchase_order|receipt|bank_statement",
"date": "YYYY-MM-DD or null",
"issuer": {"name": "string", "address": "string or null"},
"recipient": {"name": "string", "address": "string or null"},
"total_amount": number,
"currency": "ISO Code (e.g., USD, INR)"
},
"line_items": [
{"description": "string", "quantity": number, "unit_price": number, "amount": number}
],
"type_specific": {
// invoice_number, po_number, receipt_number, or account_number
},
"flags": [
{"category": "string", "field": "string", "severity": "low|medium|high", "description": "string"}
],
"confidence_score": number
}
### 2. CRITICAL EXTRACTION OVERRIDES
- **THE FLOATING DATE RULE:** Scan the first 10 lines of the provided text. If you find a date string (e.g., 04/06/2026 or 06-Apr-2026) that is not explicitly labeled, assume it is the PRIMARY document date. Do not return null if a date exists at the top of the page.
- **ENTITY MERGING:** Financial documents often separate the Company Name and the Contact Name. You must merge them.
* Example: If you see "Company: Phasellus i" and "Name: Carmita Hammel", the issuer name MUST be "Phasellus i (Attn: Carmita Hammel)".
* Never use "Company" as a placeholder; find the actual business name.
- **NUMERIC PRECISION:** If a "Total" is found at the bottom of the page (e.g., 7598), use that as the `total_amount`. Do not calculate a new total based on line items; report the document's stated total and use 'flags' to report discrepancies.
### 3. THE AUDITOR'S ANOMALY ENGINE
Every document must be audited for the following:
- **arithmetic_error:** Mandatory check. If (Sum of line_items) != total_amount, flag as HIGH severity. DO NOT invent or modify line items to force the math to work. Keep the original extracted data exactly as it appears and simply set this flag.
- **missing_field:** Flag if the expected reference number (Invoice #, PO #) or Date is missing.
- **business_logic:** Flag "Round Number" totals (e.g., exactly $5,000.00) or unusual tax rates.
- **format_anomaly:** Flag if dates are in the future or if quantities are negative (unless marked as 'Credit' or 'Refund').
### 4. DATE PERSISTENCE
- Never drop or forget the `common.date` field during the extraction. If you found it, keep it.
### 5. OUTPUT CONSTRAINTS
- Return ONLY minified JSON.
- No markdown formatting (no ```json blocks).
- No preamble or "Here is your JSON" conversational text.
- If a field is truly missing, use null."""
# Reference assistant reply used as the few-shot turn in FinancialDocExtractor._generate.
# The sample invoice's line items (500.00 + 750.00) sum exactly to the stated total
# and every expected field is present, so the correct answer carries NO flags.
# Showing an arithmetic_error here (as an earlier version did) contradicted the
# system prompt's rule and taught the model to raise false-positive math flags.
FEW_SHOT_EXAMPLE = """{
  "common": {
    "document_type": "invoice",
    "date": "2025-03-15",
    "issuer": {"name": "Acme Corp", "address": "123 Main St, Springfield"},
    "recipient": {"name": "Widget Inc", "address": "456 Oak Ave, Shelbyville"},
    "total_amount": 1250.00,
    "currency": "USD"
  },
  "line_items": [
    {"description": "Widget A", "quantity": 10, "unit_price": 50.00, "amount": 500.00},
    {"description": "Widget B", "quantity": 5, "unit_price": 150.00, "amount": 750.00}
  ],
  "type_specific": {
    "invoice_number": "INV-2025-0042",
    "due_date": "2025-04-14",
    "payment_terms": "Net 30",
    "tax_amount": 0.00,
    "subtotal": 1250.00
  },
  "flags": [],
  "confidence_score": 0.92
}"""
@dataclass
class ExtractionResult:
    """Outcome of a FinancialDocExtractor run.

    The pipeline fills these attributes in as it goes: the raw model text,
    the parsed JSON dict, the Pydantic-validated model (only when the
    schema check passes), the anomaly flags, and bookkeeping about attempts
    and the most recent error.
    """

    raw_output: str = field(default="")  # latest raw model response text
    json_output: Optional[dict] = field(default=None)  # parsed JSON object, if any
    validated: Optional[DocumentExtraction] = field(default=None)
    flags: list = field(default_factory=list)  # anomaly flags
    is_valid_json: bool = field(default=False)
    is_schema_compliant: bool = field(default=False)
    attempts: int = field(default=0)  # number of attempts made so far
    error: Optional[str] = field(default=None)  # most recent error message

    @property
    def success(self) -> bool:
        """Whether parseable JSON was obtained (the schema check may still have failed)."""
        return not (self.json_output is None)

    @property
    def has_anomalies(self) -> bool:
        """Whether at least one anomaly flag was recorded."""
        return len(self.flags) != 0
class FinancialDocExtractor:
    """
    Production inference pipeline using the Groq API.

    Features:
    - Uses llama-3.3-70b-versatile (default) via Groq LPU for ~1-2s inference
    - Few-shot prompting for consistent JSON structure
    - JSON parsing with multi-strategy fallback
    - Auditor math check: flags line-item / total mismatches without re-prompting
    - Retries on unparseable or schema-invalid model output
    - Pydantic v2 schema validation
    """

    def __init__(
        self,
        model: str = "llama-3.3-70b-versatile",
        max_tokens: int = 4096,
        temperature: float = 0.1,
        max_retries: int = 3,
        timeout: float = 30.0,
    ):
        """
        Create the Groq client.

        Args:
            model: Groq model identifier.
            max_tokens: Upper bound on generated tokens per completion.
            temperature: Sampling temperature.
            max_retries: Number of generate/parse/validate attempts in extract().
            timeout: Per-request timeout in seconds passed to the Groq API call.

        Raises:
            EnvironmentError: If GROQ_API_KEY is not set in the environment.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.max_retries = max_retries
        self.timeout = timeout
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise EnvironmentError(
                "GROQ_API_KEY is not set. "
                "Add it as a Secret in your HF Space settings, "
                "or create a .env file with GROQ_API_KEY=gsk_..."
            )
        self._client = Groq(api_key=api_key)
        print(f" ✅ Groq client initialised → model: {self.model}")

    def _generate(self, text: str) -> str:
        """
        Call the Groq API with a few-shot prompt.

        Args:
            text: Document text to embed in the final user turn.

        Returns:
            The stripped response text ("" if the API returned no content).
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            # --- Few-shot example ---
            {
                "role": "user",
                "content": (
                    "Extract structured data from this financial document:\n\n---\n"
                    "Invoice #INV-2025-0042\nDate: 2025-03-15\n"
                    "From: Acme Corp, 123 Main St, Springfield\n"
                    "To: Widget Inc, 456 Oak Ave, Shelbyville\n\n"
                    "Widget A 10 x $50.00 = $500.00\n"
                    "Widget B 5 x $150.00 = $750.00\n\n"
                    "Total: $1,250.00\nDue: 2025-04-14 | Terms: Net 30\n---"
                ),
            },
            {"role": "assistant", "content": FEW_SHOT_EXAMPLE},
            # --- Actual document ---
            {
                "role": "user",
                "content": f"Extract structured data from this financial document:\n\n---\n{text}\n---",
            },
        ]
        print(f" [Groq] Sending {len(text):,} chars to {self.model}...")
        response = self._client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            timeout=self.timeout,  # Hard timeout — prevents hung workers
        )
        # message.content can be None (empty completion). Normalise to "" so the
        # caller reports "Could not extract valid JSON" instead of an AttributeError.
        content = (response.choices[0].message.content or "").strip()
        print(f" [Groq] Response received ({len(content):,} chars).")
        return content

    # ------------------------------------------------------------------
    # JSON parsing — 4-strategy fallback
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_json(text: str) -> Optional[dict]:
        """
        Try to extract a JSON object from model output.

        Strategies, in order:
          1. Direct parse
          2. Strip markdown code fences
          3. Regex extraction of the first-'{'-to-last-'}' span
          4. Decode the JSON value starting at the first '{' (string-aware,
             so braces inside string values don't confuse it; trailing text
             after the object is ignored)

        Only a JSON *object* is accepted. Arrays and scalars are rejected so
        callers can rely on dict methods such as ``.get``.

        Returns:
            The parsed dict, or None if no strategy produced one.
        """

        def _loads_dict(candidate: str) -> Optional[dict]:
            # Parse candidate and keep it only if it is a JSON object.
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                return None
            return parsed if isinstance(parsed, dict) else None

        # Strategy 1: Direct parse
        parsed = _loads_dict(text)
        if parsed is not None:
            return parsed

        # Strategy 2: Strip markdown code fences
        cleaned = re.sub(r'^```(?:json)?\s*', '', text, flags=re.MULTILINE)
        cleaned = re.sub(r'\s*```\s*$', '', cleaned, flags=re.MULTILINE)
        parsed = _loads_dict(cleaned.strip())
        if parsed is not None:
            return parsed

        # Strategy 3: Regex for JSON object (greedy: first '{' to last '}')
        json_match = re.search(r'\{[\s\S]*\}', text)
        if json_match:
            parsed = _loads_dict(json_match.group())
            if parsed is not None:
                return parsed

        # Strategy 4: decode exactly one JSON value starting at the first '{'.
        start = text.find('{')
        if start != -1:
            try:
                obj, _end = json.JSONDecoder().raw_decode(text, start)
            except json.JSONDecodeError:
                return None
            if isinstance(obj, dict):
                return obj
        return None

    # ------------------------------------------------------------------
    # Schema validation
    # ------------------------------------------------------------------
    @staticmethod
    def _validate_schema(json_data: dict) -> Tuple[Optional["DocumentExtraction"], Optional[str]]:
        """
        Validate JSON against the Pydantic schema.

        Returns:
            (model, None) on success, or (None, error message) on failure.
        """
        try:
            validated = DocumentExtraction(**json_data)
        except Exception as e:  # validation errors are reported, not raised
            return None, str(e)
        return validated, None

    # ------------------------------------------------------------------
    # Auditor (line-item math check)
    # ------------------------------------------------------------------
    @staticmethod
    def _audit_totals(validated: "DocumentExtraction") -> None:
        """
        Append an arithmetic_error flag when line items don't sum to the total.

        The extracted data is left untouched and no re-prompt is issued, which
        protects first-pass accuracy on dates and entities. Mutates
        ``validated.flags`` in place. A flag is added only if the model did not
        already report an arithmetic_error.
        """
        total = validated.common.total_amount
        if not validated.line_items or total is None:
            return
        computed_total = sum(item.amount or 0.0 for item in validated.line_items)
        # Allow small floating-point variance, but fail on >1.0 mismatch.
        if abs(computed_total - total) > 1.0:
            print(
                f" [Auditor] Math mismatch detected "
                f"(Sum: {computed_total} vs Total: {total}). "
                f"Preserving first-pass data and injecting arithmetic_error flag."
            )
            if not any(f.category == "arithmetic_error" for f in validated.flags):
                validated.flags.append(
                    AnomalyFlag(
                        category="arithmetic_error",
                        field="total_amount",
                        severity="high",
                        description=f"Calculated line item sum ({computed_total:.2f}) does not match stated total_amount ({total:.2f}).",
                    )
                )

    # ------------------------------------------------------------------
    # Main extraction loop
    # ------------------------------------------------------------------
    def extract(self, text: str) -> "ExtractionResult":
        """
        Extract structured data from financial document text.

        Runs up to ``max_retries`` attempts. Each attempt is generate →
        parse JSON → validate schema. For a schema-compliant result, the
        Auditor compares the line-item sum with the stated total. If they
        differ by more than 1.0, it appends an arithmetic_error flag without
        re-prompting.

        Unparseable output, schema failures and unexpected errors trigger
        another attempt. Groq rate-limit and timeout errors end the run
        immediately.

        Args:
            text: Raw document text / Markdown (from Docling).

        Returns:
            ExtractionResult. ``error`` is None on full success. Otherwise it
            holds the most recent failure, and ``json_output`` may still carry
            the JSON from an attempt that failed schema validation.
        """
        result = ExtractionResult()
        for attempt in range(1, self.max_retries + 1):
            result.attempts = attempt
            try:
                # Generate
                raw_output = self._generate(text)
                result.raw_output = raw_output

                # Parse JSON (always a dict or None)
                json_data = self._extract_json(raw_output)
                if json_data is None:
                    result.error = f"Attempt {attempt}: Could not extract valid JSON"
                    continue

                result.json_output = json_data
                result.is_valid_json = True
                # `"flags": null` must not leave None here, because
                # ExtractionResult.has_anomalies calls len() on it.
                result.flags = json_data.get("flags") or []

                # Validate against Pydantic schema
                validated, val_error = self._validate_schema(json_data)
                if validated is None:
                    result.error = f"Attempt {attempt}: Schema validation failed: {val_error}"
                    # If JSON is valid but schema failed, still return partial success
                    if attempt == self.max_retries:
                        return result
                    continue

                self._audit_totals(validated)
                result.validated = validated
                result.is_schema_compliant = True
                result.error = None  # drop errors left over from earlier failed attempts
                result.flags = [
                    f.model_dump() if hasattr(f, 'model_dump') else f
                    for f in validated.flags
                ]
                return result  # Success — no need to retry
            except RateLimitError:
                result.error = f"Attempt {attempt}: Groq rate limit exceeded. Please wait 60s and retry."
                print(f" [Error] {result.error}")
                # Re-issuing immediately only burns more quota. NOTE(review): the
                # Groq SDK is documented to retry 429s with backoff itself before
                # raising — confirm for the pinned version.
                break
            except APITimeoutError:
                result.error = (
                    f"Attempt {attempt}: Groq API timed out after {self.timeout:g}s. "
                    "The service may be busy."
                )
                print(f" [Error] {result.error}")
                # Every further attempt could block a worker for another full timeout.
                break
            except Exception as e:
                result.error = f"Attempt {attempt}: {str(e)}"
                print(f" [Error] {result.error}")
        if result.attempts == 0:
            result.error = "No extraction attempted: max_retries must be at least 1."
        return result

    def extract_from_pdf(self, pdf_path: str) -> "ExtractionResult":
        """
        Extract from a PDF file (uses Docling internally).

        Failures, including a missing PDF-reader dependency, are reported via
        ``ExtractionResult.error`` rather than raised.

        Args:
            pdf_path: Path to PDF file.

        Returns:
            ExtractionResult.
        """
        try:
            # Imported lazily so the PDF stack is only needed for PDF input.
            from src.pdf_reader import extract_text_from_pdf

            text = extract_text_from_pdf(pdf_path)
            return self.extract(text)
        except Exception as e:
            result = ExtractionResult()
            result.error = f"PDF extraction failed: {str(e)}"
            return result