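# NOTE: the next three lines are page-header residue from the repository
# viewer (avatar alt text, commit message, commit hash), kept as comments
# so the module parses.
# GSoumyajit2005's picture
# refactor: remove obsolete OCR test file and enhance address extraction logic
# 097a95c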
# src/schema.py
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import List, Optional, Union, Dict
from decimal import Decimal, InvalidOperation
from datetime import date as DateType, datetime
# --- 1. Line Item Schema ---
class LineItem(BaseModel):
    """One row of an invoice.

    ``unit_price`` and ``total`` are normalised to two decimal places by
    ``validate_precision`` before the ``ge=0`` constraints are applied.
    """

    description: str
    # At least one unit; a single unit is assumed when none is given.
    quantity: int = Field(default=1, ge=1)
    unit_price: Optional[Decimal] = Field(default=None, ge=0)
    # The default is written as an already-quantized Decimal: 'before'
    # validators are not applied to defaults (unless validate_default is
    # enabled), so a bare ``0`` would leave an int on the model while
    # explicitly supplied totals come back as two-decimal Decimals.
    total: Decimal = Field(default=Decimal('0.00'), ge=0)

    @field_validator('unit_price', 'total', mode='before')
    @classmethod
    def validate_precision(cls, v):
        """Coerce a raw monetary value to a ``Decimal`` with two decimal places.

        Args:
            v: Raw input for ``unit_price`` or ``total``.

        Returns:
            ``None`` when ``v`` is ``None``; otherwise ``v`` converted through
            ``str`` to ``Decimal`` and quantized to ``0.01``. Input that cannot
            be converted falls back to ``Decimal('0.00')`` (best-effort: a bad
            amount does not reject the whole line).
        """
        if v is None:
            return None
        try:
            # Going through str() avoids binary-float artefacts
            # (Decimal(0.1) is not Decimal('0.1')).
            return Decimal(str(v)).quantize(Decimal('0.01'))
        except (InvalidOperation, ValueError, TypeError):
            return Decimal('0.00')
# --- 2. Invoice Schema ---
class InvoiceData(BaseModel):
    """
    Strict Data Contract for Invoice Extraction.

    Field validators are lenient: a date or total that cannot be parsed is
    stored as ``None`` rather than raising. The model-level check compares
    ``total_amount`` with the sum of the line-item totals and only prints a
    warning when they disagree.
    """

    # Core Fields
    receipt_number: Optional[str] = Field(default=None, description="Unique ID")
    date: Optional[DateType] = Field(default=None, description="Invoice Date")

    # Financials
    total_amount: Optional[Decimal] = Field(default=None, ge=0)

    # Entities
    vendor: Optional[str] = None
    address: Optional[str] = None
    bill_to: Optional[Union[str, Dict]] = None

    # Nested Items
    items: List[LineItem] = Field(default_factory=list)

    # --- METADATA ---
    validation_status: str = Field(default="unknown", description="passed/failed")
    validation_errors: List[str] = Field(default_factory=list, description="List of validation failure messages")
    semantic_hash: Optional[str] = Field(default=None, description="Unique fingerprint of the invoice content")

    # --- VALIDATORS ---
    @field_validator('date', mode='before')
    @classmethod
    def clean_date(cls, v):
        """Parse the raw invoice date and reject implausible values.

        Args:
            v: Raw input; a string in one of the supported formats, a
                ``date``/``datetime`` instance, or a falsy value.

        Returns:
            A ``date`` that is not in the future and not earlier than
            1 January of the year 30 years before the current one;
            ``None`` for falsy, unparseable, or out-of-range input.
        """
        if not v:
            return None

        parsed_date = v
        if isinstance(v, str):
            text = v.strip()  # tolerate surrounding whitespace
            # Tried in order, first match wins: day-first formats come before
            # month-first ones, so an ambiguous "05/01/2020" is read as
            # 5 January.
            for fmt in (
                "%d/%m/%Y", "%Y-%m-%d", "%d-%m-%Y", "%d.%m.%Y",
                "%m/%d/%Y", "%m-%d-%Y",
            ):
                try:
                    parsed_date = datetime.strptime(text, fmt).date()
                    break
                except ValueError:
                    continue
        elif isinstance(v, datetime):
            # datetime is a subclass of date, but ordering a datetime against
            # a plain date raises TypeError; narrow to the calendar date
            # before the range checks below.
            parsed_date = v.date()

        # A string that matched no format is still a str here.
        if not isinstance(parsed_date, DateType):
            return None

        today = datetime.now().date()
        if parsed_date > today:
            return None

        # The stdlib date class is imported as DateType; the plain name
        # `date` is used by the model field.
        min_date = DateType(today.year - 30, 1, 1)
        if parsed_date < min_date:
            return None

        return parsed_date

    @field_validator('total_amount', mode='before')
    @classmethod
    def validate_money(cls, v):
        """Coerce the raw total to a two-decimal ``Decimal``.

        Returns:
            ``None`` when ``v`` is ``None`` or cannot be converted; otherwise
            ``v`` converted through ``str`` to ``Decimal`` and quantized to
            ``0.01``.
        """
        if v is None:
            return None
        try:
            return Decimal(str(v)).quantize(Decimal('0.01'))
        except (InvalidOperation, ValueError, TypeError):
            return None

    @model_validator(mode='after')
    def validate_math(self):
        """Warn when the invoice total disagrees with the line-item totals.

        Skipped when there are no items or no total. A difference of more
        than 0.05 prints a warning; the model itself is returned unchanged
        in every case.
        """
        if not self.items or self.total_amount is None:
            return self

        line_items_sum = sum(item.total for item in self.items)
        diff = abs(self.total_amount - line_items_sum)

        if diff > Decimal('0.05'):
            print(f"⚠️ Validation Warning: Total {self.total_amount} != Sum of items {line_items_sum}")

        return self