File size: 3,928 Bytes
f74e17e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
097a95c
 
 
 
f74e17e
 
097a95c
 
f74e17e
 
 
 
 
 
 
 
 
 
 
097a95c
 
f74e17e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# src/schema.py

import logging
from datetime import date as DateType, datetime
from decimal import Decimal, InvalidOperation
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator

# --- 1. Line Item Schema ---
class LineItem(BaseModel):
    """One extracted invoice line.

    ``unit_price`` and ``total`` are normalised to exactly two decimal places
    by ``validate_precision`` before Pydantic's type and ``ge=0`` checks run.
    """
    description: str
    quantity: int = Field(default=1, ge=1)
    unit_price: Optional[Decimal] = Field(default=None, ge=0)
    # Pydantic does not run validators on defaults, so the default itself must
    # already be a two-place Decimal; a bare ``0`` would stay an ``int``.
    total: Decimal = Field(default=Decimal('0.00'), ge=0)

    @field_validator('unit_price', 'total', mode='before')
    @classmethod
    def validate_precision(cls, v):
        """Ensure exactly 2 decimal places for currency.

        ``None`` is passed through unchanged (only ``unit_price`` accepts it;
        ``total`` is not optional). Input that does not parse as a finite
        number -- unparseable text, NaN, infinities, or values too large to
        quantize -- falls back to ``Decimal('0.00')``.
        """
        if v is None:
            return None
        try:
            amount = Decimal(str(v))
        except (InvalidOperation, ValueError, TypeError):
            return Decimal('0.00')
        # NaN survives quantize() without signalling and would then fail
        # Pydantic's finite-number check; treat it like other unparseable input.
        if not amount.is_finite():
            return Decimal('0.00')
        try:
            # Rounds with the active decimal context (ROUND_HALF_EVEN by default).
            return amount.quantize(Decimal('0.01'))
        except InvalidOperation:
            # The quantized result needs more digits than the context allows.
            return Decimal('0.00')

# --- 2. Invoice Schema ---
class InvoiceData(BaseModel):
    """
    Strict Data Contract for Invoice Extraction.

    Field validators normalise raw extracted values (dates, money) and map
    unusable input to ``None`` instead of rejecting the whole record; the
    model validator then cross-checks ``total_amount`` against the line items.
    """
    # Core Fields
    receipt_number: Optional[str] = Field(default=None, description="Unique ID")

    # ``DateType`` is ``datetime.date`` imported under an alias so the type
    # name does not collide with this ``date`` field.
    date: Optional[DateType] = Field(default=None, description="Invoice Date")

    # Financials
    total_amount: Optional[Decimal] = Field(default=None, ge=0)

    # Entities
    vendor: Optional[str] = None
    address: Optional[str] = None
    bill_to: Optional[Union[str, Dict]] = None

    # Nested Items
    items: List[LineItem] = Field(default_factory=list)

    # --- METADATA ---
    validation_status: str = Field(default="unknown", description="passed/failed")
    validation_errors: List[str] = Field(default_factory=list, description="List of validation failure messages")
    semantic_hash: Optional[str] = Field(default=None, description="Unique fingerprint of the invoice content")

    # --- VALIDATORS ---

    @field_validator('date', mode='before')
    @classmethod
    def clean_date(cls, v):
        """Parse the invoice date and reject implausible values.

        Accepts ``date``/``datetime`` objects or strings. Strings are stripped
        and tried against the formats below in order; day-first formats come
        before month-first ones, so an ambiguous ``"05/01/2020"`` is read as
        5 January 2020. ISO-8601 timestamps (e.g. ``"2020-01-05T10:30:00"``)
        are accepted as a final fallback.

        Returns:
            The parsed date, or ``None`` when the input is empty, of an
            unsupported type, unparseable, in the future, or earlier than
            1 January of the year thirty years before the current year.
        """
        if not v:
            return None

        parsed_date = None
        if isinstance(v, datetime):
            # datetime subclasses date, but ordering a datetime against a
            # plain date raises TypeError -- keep only the calendar date.
            parsed_date = v.date()
        elif isinstance(v, DateType):
            parsed_date = v
        elif isinstance(v, str):
            text = v.strip()
            for fmt in (
                "%d/%m/%Y", "%Y-%m-%d", "%d-%m-%Y", "%d.%m.%Y",
                "%m/%d/%Y", "%m-%d-%Y",
            ):
                try:
                    parsed_date = datetime.strptime(text, fmt).date()
                    break
                except ValueError:
                    continue
            else:
                # No explicit format matched; try a full ISO-8601 timestamp.
                try:
                    parsed_date = datetime.fromisoformat(text).date()
                except ValueError:
                    return None

        if parsed_date is None:
            return None

        today = DateType.today()
        min_date = DateType(today.year - 30, 1, 1)
        if parsed_date > today or parsed_date < min_date:
            return None
        return parsed_date

    @field_validator('total_amount', mode='before')
    @classmethod
    def validate_money(cls, v):
        """Coerce the invoice total to a Decimal with exactly 2 decimal places.

        Returns ``None`` (unknown) for ``None`` and for anything that does not
        parse as a finite number, including NaN and infinities.
        """
        if v is None:
            return None
        try:
            amount = Decimal(str(v))
            # NaN passes quantize() silently and would then fail Pydantic's
            # finite-number check; map it to "unknown" like other bad input.
            if not amount.is_finite():
                return None
            return amount.quantize(Decimal('0.01'))
        except (InvalidOperation, ValueError):
            return None

    @model_validator(mode='after')
    def validate_math(self):
        """Warn when the line-item totals do not add up to ``total_amount``.

        Skipped when there are no items or the total is unknown. A difference
        greater than 0.05 is logged as a warning; the model is returned
        unchanged, so a mismatch never rejects the invoice.
        """
        if not self.items or self.total_amount is None:
            return self

        line_items_sum = sum(item.total for item in self.items)
        if abs(self.total_amount - line_items_sum) > Decimal('0.05'):
            # Log instead of print so library output never pollutes stdout.
            logging.getLogger(__name__).warning(
                "⚠️ Validation Warning: Total %s != Sum of items %s",
                self.total_amount,
                line_items_sum,
            )

        return self