File size: 15,554 Bytes
10ff0db
abf3ab0
10ff0db
abf3ab0
6594725
abf3ab0
 
 
 
 
 
 
10ff0db
 
 
 
abf3ab0
 
 
10ff0db
abf3ab0
10ff0db
 
 
 
 
 
 
 
 
abf3ab0
 
 
10ff0db
 
 
abf3ab0
 
 
 
431c55b
4212baa
431c55b
4212baa
 
 
431c55b
4212baa
 
 
 
 
 
 
 
 
 
431c55b
4212baa
 
 
 
431c55b
4212baa
 
431c55b
 
 
 
 
 
 
 
 
faa05c3
431c55b
 
 
 
faa05c3
 
 
 
431c55b
 
 
 
10ff0db
 
abf3ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10ff0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf3ab0
 
10ff0db
abf3ab0
 
 
 
 
10ff0db
abf3ab0
10ff0db
 
6594725
abf3ab0
10ff0db
 
 
abf3ab0
 
10ff0db
 
abf3ab0
 
 
 
 
 
 
10ff0db
abf3ab0
 
 
10ff0db
 
abf3ab0
 
10ff0db
 
 
abf3ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10ff0db
abf3ab0
 
 
 
10ff0db
abf3ab0
10ff0db
abf3ab0
10ff0db
abf3ab0
 
 
 
 
 
 
 
 
10ff0db
 
 
 
abf3ab0
10ff0db
 
 
 
abf3ab0
10ff0db
 
 
 
 
 
abf3ab0
10ff0db
 
 
 
 
 
 
abf3ab0
10ff0db
 
 
 
 
 
 
abf3ab0
10ff0db
 
 
 
 
 
 
 
 
 
 
 
 
 
abf3ab0
10ff0db
abf3ab0
 
 
 
 
10ff0db
 
 
 
 
 
 
 
abf3ab0
 
 
 
 
10ff0db
 
 
abf3ab0
10ff0db
 
 
abf3ab0
10ff0db
abf3ab0
10ff0db
abf3ab0
 
10ff0db
 
 
 
abf3ab0
10ff0db
 
abf3ab0
10ff0db
 
 
 
abf3ab0
10ff0db
 
abf3ab0
10ff0db
 
 
abf3ab0
10ff0db
 
abf3ab0
10ff0db
 
abf3ab0
10ff0db
 
abf3ab0
10ff0db
 
 
 
abf3ab0
 
10ff0db
faa05c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abf3ab0
faa05c3
10ff0db
 
 
 
 
 
 
 
 
 
abf3ab0
 
10ff0db
 
abf3ab0
 
 
08926a8
abf3ab0
 
08926a8
10ff0db
 
08926a8
abf3ab0
10ff0db
abf3ab0
10ff0db
 
abf3ab0
 
10ff0db
 
abf3ab0
10ff0db
 
 
 
abf3ab0
10ff0db
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
"""
Financial Document Extraction Pipeline β€” Groq Edition.

Replaces the local llama-cpp-python / GGUF inference with a
server-side Groq API call to llama-3.3-70b-versatile.

Features:
- Zero GPU / VRAM requirements β€” pure API call
- Sub-2-second inference via Groq's LPU hardware
- JSON parsing with multi-strategy fallback
- Pydantic v2 schema validation (unchanged)
- Self-correcting Auditor loop (math check + retry)

Usage:
    from src.extractor import FinancialDocExtractor

    extractor = FinancialDocExtractor()
    result = extractor.extract("raw document markdown here...")

    print(result.json_output)      # Raw JSON dict
    print(result.validated)        # Pydantic model (or None)
    print(result.flags)            # List of anomaly flags
"""

import os
import re
import json
from typing import Optional, Tuple
from dataclasses import dataclass, field

from groq import Groq, RateLimitError, APITimeoutError
from dotenv import load_dotenv

from src.schema import DocumentExtraction, AnomalyFlag


# Load .env for local development; on HF Spaces the secret is injected automatically.
# NOTE: this runs at import time — any GROQ_API_KEY in a local .env becomes visible
# to FinancialDocExtractor.__init__ via os.environ.
load_dotenv()


# System prompt sent on every request (see FinancialDocExtractor._generate).
# It defines the target JSON schema, extraction overrides (date heuristics,
# entity merging, stated-total precedence), the anomaly taxonomy used by the
# Auditor, and the minified-JSON-only output contract. Runtime string: any
# wording change here alters model behaviour, so edit with care.
SYSTEM_PROMPT = """You are a Senior Financial Auditor and Data Extraction Expert. Your task is to transform raw document text into a high-precision, structured JSON audit report.

### 1. DATA EXTRACTION HIERARCHY
Extract data into this exact JSON schema:
{
  "common": {
    "document_type": "invoice|purchase_order|receipt|bank_statement",
    "date": "YYYY-MM-DD or null",
    "issuer": {"name": "string", "address": "string or null"},
    "recipient": {"name": "string", "address": "string or null"},
    "total_amount": number,
    "currency": "ISO Code (e.g., USD, INR)"
  },
  "line_items": [
    {"description": "string", "quantity": number, "unit_price": number, "amount": number}
  ],
  "type_specific": {
     // invoice_number, po_number, receipt_number, or account_number
  },
  "flags": [
    {"category": "string", "field": "string", "severity": "low|medium|high", "description": "string"}
  ],
  "confidence_score": number
}

### 2. CRITICAL EXTRACTION OVERRIDES
- **THE FLOATING DATE RULE:** Scan the first 10 lines of the provided text. If you find a date string (e.g., 04/06/2026 or 06-Apr-2026) that is not explicitly labeled, assume it is the PRIMARY document date. Do not return null if a date exists at the top of the page.
- **ENTITY MERGING:** Financial documents often separate the Company Name and the Contact Name. You must merge them. 
  * Example: If you see "Company: Phasellus i" and "Name: Carmita Hammel", the issuer name MUST be "Phasellus i (Attn: Carmita Hammel)".
  * Never use "Company" as a placeholder; find the actual business name.
- **NUMERIC PRECISION:** If a "Total" is found at the bottom of the page (e.g., 7598), use that as the `total_amount`. Do not calculate a new total based on line items; report the document's stated total and use 'flags' to report discrepancies.

### 3. THE AUDITOR'S ANOMALY ENGINE
Every document must be audited for the following:
- **arithmetic_error:** Mandatory check. If (Sum of line_items) != total_amount, flag as HIGH severity. DO NOT invent or modify line items to force the math to work. Keep the original extracted data exactly as it appears and simply set this flag.
- **missing_field:** Flag if the expected reference number (Invoice #, PO #) or Date is missing.
- **business_logic:** Flag "Round Number" totals (e.g., exactly $5,000.00) or unusual tax rates.
- **format_anomaly:** Flag if dates are in the future or if quantities are negative (unless marked as 'Credit' or 'Refund').

### 4. DATE PERSISTENCE
- Never drop or forget the `common.date` field during the extraction. If you found it, keep it.

### 5. OUTPUT CONSTRAINTS
- Return ONLY minified JSON. 
- No markdown formatting (no ```json blocks). 
- No preamble or "Here is your JSON" conversational text.
- If a field is truly missing, use null."""


# Canned assistant reply used as the one-shot example in _generate(); it anchors
# the model to the exact JSON shape expected by src.schema.DocumentExtraction.
# Runtime string — must stay byte-identical to how it is sent to the API.
# NOTE(review): the example flag below is labelled "arithmetic_error" / high even
# though its description says the line-item sum MATCHES the total — this may teach
# the model to over-flag; consider a business_logic example instead. TODO confirm
# against observed model output before changing.
FEW_SHOT_EXAMPLE = """{
  "common": {
    "document_type": "invoice",
    "date": "2025-03-15",
    "issuer": {"name": "Acme Corp", "address": "123 Main St, Springfield"},
    "recipient": {"name": "Widget Inc", "address": "456 Oak Ave, Shelbyville"},
    "total_amount": 1250.00,
    "currency": "USD"
  },
  "line_items": [
    {"description": "Widget A", "quantity": 10, "unit_price": 50.00, "amount": 500.00},
    {"description": "Widget B", "quantity": 5, "unit_price": 150.00, "amount": 750.00}
  ],
  "type_specific": {
    "invoice_number": "INV-2025-0042",
    "due_date": "2025-04-14",
    "payment_terms": "Net 30",
    "tax_amount": 0.00,
    "subtotal": 1250.00
  },
  "flags": [
    {
      "category": "arithmetic_error",
      "field": "total_amount",
      "severity": "high",
      "description": "Line item sum (1250.00) matches total, but no tax is applied despite taxable items."
    }
  ],
  "confidence_score": 0.92
}"""


@dataclass
class ExtractionResult:
    """
    Result from the extraction pipeline.

    Progressively populated by FinancialDocExtractor.extract(): raw model text
    first, then parsed JSON, then (if it validates) the Pydantic model. A result
    can therefore be a "partial success" — valid JSON but not schema-compliant.
    """
    # Raw text returned by the model (last attempt).
    raw_output: str = ""
    # Parsed JSON dict, or None if no parsing strategy succeeded.
    json_output: Optional[dict] = None
    # Pydantic-validated model, or None if schema validation failed.
    validated: Optional[DocumentExtraction] = None
    # Anomaly flags as plain dicts (from the JSON, or dumped from the validated model).
    flags: list = field(default_factory=list)
    # True once json_output was successfully parsed.
    is_valid_json: bool = False
    # True once `validated` passed DocumentExtraction validation.
    is_schema_compliant: bool = False
    # Number of generate/parse/validate attempts consumed (1-based).
    attempts: int = 0
    # Human-readable description of the last failure, if any.
    error: Optional[str] = None

    @property
    def success(self) -> bool:
        """True if JSON was parsed, even when schema validation failed."""
        return self.json_output is not None

    @property
    def has_anomalies(self) -> bool:
        """True if at least one anomaly flag was reported or injected."""
        return len(self.flags) > 0


class FinancialDocExtractor:
    """
    Production inference pipeline using the Groq API.

    Features:
    - Uses llama-3.1-70b-versatile via Groq LPU for ~1-2s inference
    - Few-shot prompting for consistent JSON structure
    - JSON parsing with multi-strategy fallback
    - Self-correcting Auditor loop (math verification + retry)
    - Pydantic v2 schema validation
    """

    def __init__(
        self,
        model: str = "llama-3.3-70b-versatile",
        max_tokens: int = 4096,
        temperature: float = 0.1,
        max_retries: int = 3,
    ):
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.max_retries = max_retries

        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise EnvironmentError(
                "GROQ_API_KEY is not set. "
                "Add it as a Secret in your HF Space settings, "
                "or create a .env file with GROQ_API_KEY=gsk_..."
            )
        self._client = Groq(api_key=api_key)
        print(f"  βœ… Groq client initialised  β†’  model: {self.model}")

    def _generate(self, text: str) -> str:
        """
        Call the Groq API with a few-shot prompt.
        Returns the raw string response from the model.
        """
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            # --- Few-shot example ---
            {
                "role": "user",
                "content": (
                    "Extract structured data from this financial document:\n\n---\n"
                    "Invoice #INV-2025-0042\nDate: 2025-03-15\n"
                    "From: Acme Corp, 123 Main St, Springfield\n"
                    "To: Widget Inc, 456 Oak Ave, Shelbyville\n\n"
                    "Widget A   10 x $50.00  = $500.00\n"
                    "Widget B    5 x $150.00 = $750.00\n\n"
                    "Total: $1,250.00\nDue: 2025-04-14  |  Terms: Net 30\n---"
                ),
            },
            {"role": "assistant", "content": FEW_SHOT_EXAMPLE},
            # --- Actual document ---
            {
                "role": "user",
                "content": f"Extract structured data from this financial document:\n\n---\n{text}\n---",
            },
        ]

        print(f"  [Groq] Sending {len(text):,} chars to {self.model}...")
        response = self._client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            timeout=30.0,  # Hard timeout β€” prevents hung workers
        )

        content = response.choices[0].message.content.strip()
        print(f"  [Groq] Response received ({len(content):,} chars).")
        return content

    # ------------------------------------------------------------------
    # JSON parsing β€” 4-strategy fallback (unchanged from original)
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_json(text: str) -> Optional[dict]:
        """
        Try to extract valid JSON from model output.

        Strategy:
        1. Direct parse
        2. Strip markdown code fences
        3. Regex extraction of JSON object
        4. Find first '{' to last matching '}'
        """
        # Strategy 1: Direct parse
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # Strategy 2: Strip markdown code fences
        cleaned = re.sub(r'^```(?:json)?\s*', '', text, flags=re.MULTILINE)
        cleaned = re.sub(r'\s*```\s*$', '', cleaned, flags=re.MULTILINE)
        try:
            return json.loads(cleaned.strip())
        except json.JSONDecodeError:
            pass

        # Strategy 3: Regex for JSON object
        json_match = re.search(r'\{[\s\S]*\}', text)
        if json_match:
            try:
                return json.loads(json_match.group())
            except json.JSONDecodeError:
                pass

        # Strategy 4: Find matching braces
        start = text.find('{')
        if start != -1:
            depth = 0
            for i in range(start, len(text)):
                if text[i] == '{':
                    depth += 1
                elif text[i] == '}':
                    depth -= 1
                    if depth == 0:
                        try:
                            return json.loads(text[start:i+1])
                        except json.JSONDecodeError:
                            break

        return None

    # ------------------------------------------------------------------
    # Schema validation
    # ------------------------------------------------------------------

    @staticmethod
    def _validate_schema(json_data: dict) -> Tuple[Optional[DocumentExtraction], Optional[str]]:
        """Validate JSON against Pydantic schema."""
        try:
            validated = DocumentExtraction(**json_data)
            return validated, None
        except Exception as e:
            return None, str(e)

    # ------------------------------------------------------------------
    # Main extraction loop (with Auditor self-correction)
    # ------------------------------------------------------------------

    def extract(self, text: str) -> ExtractionResult:
        """
        Extract structured data from financial document text.

        Runs inference with retry logic:
        - Up to max_retries attempts
        - Each attempt: generate β†’ parse JSON β†’ validate schema
        - Auditor checks line-item math and triggers self-correction
        - Returns best result found

        Args:
            text: Raw document text / Markdown (from Docling).

        Returns:
            ExtractionResult with parsed data and validation status.
        """
        result = ExtractionResult()

        for attempt in range(1, self.max_retries + 1):
            result.attempts = attempt

            try:
                # Generate
                raw_output = self._generate(text)
                result.raw_output = raw_output

                # Parse JSON
                json_data = self._extract_json(raw_output)

                if json_data is None:
                    result.error = f"Attempt {attempt}: Could not extract valid JSON"
                    continue

                result.json_output = json_data
                result.is_valid_json = True

                # Extract flags
                result.flags = json_data.get("flags", [])

                # Validate against Pydantic schema
                validated, val_error = self._validate_schema(json_data)

                if validated:
                    # --- AUDITOR AGENT (SELF-CORRECTION LOOP) ---
                    if validated.line_items and validated.common.total_amount is not None:
                        computed_total = sum(item.amount or 0.0 for item in validated.line_items)

                        # Allow small floating-point variance, but fail on >1.0 mismatch
                        if abs(computed_total - validated.common.total_amount) > 1.0:
                            print(
                                f"  [Auditor] Math mismatch detected "
                                f"(Sum: {computed_total} vs Total: {validated.common.total_amount}). "
                                f"Preserving first-pass data and injecting arithmetic_error flag."
                            )
                            # We deliberately do NOT trigger self-correction (retry) here to 
                            # protect first-pass accuracy on dates and entities.
                            
                            # Programmatically ensure the flag is present if the LLM missed it
                            if not any(f.category == "arithmetic_error" for f in validated.flags):
                                error_flag = AnomalyFlag(
                                    category="arithmetic_error",
                                    field="total_amount",
                                    severity="high",
                                    description=f"Calculated line item sum ({computed_total:.2f}) does not match stated total_amount ({validated.common.total_amount:.2f})."
                                )
                                validated.flags.append(error_flag)

                    result.validated = validated
                    result.is_schema_compliant = True
                    result.flags = [
                        f.model_dump() if hasattr(f, 'model_dump') else f
                        for f in validated.flags
                    ]
                    return result  # Success β€” no need to retry
                else:
                    result.error = f"Attempt {attempt}: Schema validation failed: {val_error}"

                # If JSON is valid but schema failed, still return partial success
                if attempt == self.max_retries:
                    return result

            except RateLimitError:
                result.error = f"Attempt {attempt}: Groq rate limit exceeded. Please wait 60s and retry."
                print(f"  [Error] {result.error}")
            except APITimeoutError:
                result.error = f"Attempt {attempt}: Groq API timed out after 30s. The service may be busy."
                print(f"  [Error] {result.error}")
            except Exception as e:
                result.error = f"Attempt {attempt}: {str(e)}"
                print(f"  [Error] {result.error}")

        return result

    def extract_from_pdf(self, pdf_path: str) -> ExtractionResult:
        """
        Extract from a PDF file (uses Docling internally).

        Args:
            pdf_path: Path to PDF file.

        Returns:
            ExtractionResult.
        """
        from src.pdf_reader import extract_text_from_pdf

        try:
            text = extract_text_from_pdf(pdf_path)
            return self.extract(text)
        except Exception as e:
            result = ExtractionResult()
            result.error = f"PDF extraction failed: {str(e)}"
            return result