Spaces:
Sleeping
Sleeping
| # src/schema.py | |
| from pydantic import BaseModel, Field, field_validator, model_validator | |
| from typing import List, Optional, Union, Dict | |
| from decimal import Decimal, InvalidOperation | |
| from datetime import date as DateType, datetime | |
| # --- 1. Line Item Schema --- | |
class LineItem(BaseModel):
    """One line of an invoice: a description, a quantity and its amounts.

    ``unit_price`` and ``total`` are passed through ``validate_precision``
    before pydantic's own type and ``ge=0`` checks run.
    """

    description: str
    quantity: int = Field(default=1, ge=1)
    unit_price: Optional[Decimal] = Field(default=None, ge=0)
    # The default is a real Decimal: pydantic does not validate defaults
    # unless asked to, so an int default would be stored as an int.
    total: Decimal = Field(default=Decimal("0.00"), ge=0)

    @field_validator("unit_price", "total", mode="before")
    @classmethod
    def validate_precision(cls, v):
        """Ensure exactly 2 decimal places for currency.

        Args:
            v: Raw incoming value (number, string, Decimal or None).

        Returns:
            ``None`` when ``v`` is ``None``; otherwise ``v`` converted to a
            ``Decimal`` quantized to two places. Values that cannot be
            converted fall back to ``Decimal('0.00')`` instead of raising.
        """
        if v is None:
            return None
        try:
            # Go through str() so floats such as 0.1 do not carry their
            # binary representation error into the Decimal.
            return Decimal(str(v)).quantize(Decimal("0.01"))
        except (InvalidOperation, ValueError, TypeError):
            return Decimal("0.00")
| # --- 2. Invoice Schema --- | |
class InvoiceData(BaseModel):
    """
    Strict Data Contract for Invoice Extraction.

    Field-level validators normalise the raw ``date`` and ``total_amount``
    values before pydantic's own checks; a model-level validator then
    compares ``total_amount`` with the sum of the line-item totals and
    prints a warning when they disagree by more than 0.05.
    """

    # Core Fields
    receipt_number: Optional[str] = Field(default=None, description="Unique ID")
    date: Optional[DateType] = Field(default=None, description="Invoice Date")

    # Financials
    total_amount: Optional[Decimal] = Field(default=None, ge=0)

    # Entities
    vendor: Optional[str] = None
    address: Optional[str] = None
    bill_to: Optional[Union[str, Dict]] = None

    # Nested Items
    items: List[LineItem] = Field(default_factory=list)

    # --- METADATA ---
    validation_status: str = Field(default="unknown", description="passed/failed")
    validation_errors: List[str] = Field(
        default_factory=list, description="List of validation failure messages"
    )
    semantic_hash: Optional[str] = Field(
        default=None, description="Unique fingerprint of the invoice content"
    )

    # --- VALIDATORS ---
    @field_validator("date", mode="before")
    @classmethod
    def clean_date(cls, v):
        """Logic: Handle None, parse formats, then validate range.

        Args:
            v: Raw incoming value; a string, ``date``, ``datetime`` or None.

        Returns:
            A ``date`` when ``v`` can be interpreted as one that is not in
            the future and not earlier than 1 January of the year 30 years
            before the current one; otherwise ``None``.
        """
        if not v:
            return None

        parsed_date = v
        if isinstance(v, str):
            text = v.strip()
            parsed_date = None
            # Day-first formats are listed before month-first ones, so an
            # ambiguous value such as 05/01/2020 is read as 5 January.
            for fmt in (
                "%d/%m/%Y", "%Y-%m-%d", "%d-%m-%Y", "%d.%m.%Y",
                "%m/%d/%Y", "%m-%d-%Y",
            ):
                try:
                    parsed_date = datetime.strptime(text, fmt).date()
                    break
                except ValueError:
                    continue
        elif isinstance(v, datetime):
            # datetime is a subclass of date but cannot be order-compared
            # with a plain date; reduce it to its calendar date first.
            parsed_date = v.date()

        if not isinstance(parsed_date, DateType):
            return None

        today = datetime.now().date()
        min_date = DateType(today.year - 30, 1, 1)
        if parsed_date > today or parsed_date < min_date:
            return None
        return parsed_date

    @field_validator("total_amount", mode="before")
    @classmethod
    def validate_money(cls, v):
        """Coerce the invoice total to a two-decimal-place ``Decimal``.

        Args:
            v: Raw incoming value (number, string, Decimal or None).

        Returns:
            ``None`` when ``v`` is ``None`` or cannot be converted;
            otherwise the value quantized to two decimal places.
        """
        if v is None:
            return None
        try:
            return Decimal(str(v)).quantize(Decimal("0.01"))
        except (InvalidOperation, ValueError, TypeError):
            return None

    @model_validator(mode="after")
    def validate_math(self):
        """Warn when the invoice total disagrees with the line-item sum.

        Skipped when there are no items or no total. The check is advisory
        only: it prints a warning and never rejects the model.

        Returns:
            The model instance, unchanged.
        """
        if not self.items or self.total_amount is None:
            return self

        line_items_sum = sum((item.total for item in self.items), Decimal("0"))
        diff = abs(self.total_amount - line_items_sum)
        # Allow up to 0.05 of drift before warning.
        if diff > Decimal("0.05"):
            print(f"⚠️ Validation Warning: Total {self.total_amount} != Sum of items {line_items_sum}")
        return self