"""
Pydantic models for financial document extraction.

Option C: Common core + type-specific extensions.
"""

from enum import Enum
from typing import Dict, List, Literal, Optional, Type

from pydantic import BaseModel, Field


class AnomalyCategory(str, Enum):
    """Categories of anomalies the model can detect."""

    ARITHMETIC = "arithmetic_error"
    MISSING_FIELD = "missing_field"
    FORMAT = "format_anomaly"
    BUSINESS_LOGIC = "business_logic"
    CROSS_FIELD = "cross_field"


class Severity(str, Enum):
    """Severity levels for detected anomalies."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


class Party(BaseModel):
    """Represents an entity (vendor, buyer, etc.)."""

    name: Optional[str] = None
    address: Optional[str] = None


class CommonFields(BaseModel):
    """Fields shared across all financial document types."""

    # Selects which type-specific extension model applies
    # (see TYPE_SPECIFIC_MODELS below).
    document_type: Literal["invoice", "purchase_order", "receipt", "bank_statement"]
    # Kept as free-form text; no date parsing/normalization is done here.
    date: Optional[str] = None
    issuer: Optional[Party] = None
    recipient: Optional[Party] = None
    total_amount: Optional[float] = None
    currency: Optional[str] = "USD"


class LineItem(BaseModel):
    """A single line item in a financial document."""

    description: str
    quantity: Optional[float] = None
    unit_price: Optional[float] = None
    amount: Optional[float] = None


# === Type-Specific Extensions ===


class InvoiceFields(BaseModel):
    """Fields specific to invoices."""

    invoice_number: Optional[str] = None
    due_date: Optional[str] = None
    payment_terms: Optional[str] = None
    tax_amount: Optional[float] = None
    subtotal: Optional[float] = None


class PurchaseOrderFields(BaseModel):
    """Fields specific to purchase orders."""

    po_number: Optional[str] = None
    delivery_date: Optional[str] = None
    shipping_address: Optional[str] = None
    referenced_invoice: Optional[str] = None


class ReceiptFields(BaseModel):
    """Fields specific to receipts."""

    receipt_number: Optional[str] = None
    payment_method: Optional[str] = None
    store_location: Optional[str] = None
    cashier: Optional[str] = None


class BankStatementFields(BaseModel):
    """Fields specific to bank statements."""

    account_number: Optional[str] = None
    statement_period: Optional[str] = None
    opening_balance: Optional[float] = None
    closing_balance: Optional[float] = None


# Registry linking each ``CommonFields.document_type`` value to the extension
# model that describes its ``DocumentExtraction.type_specific`` payload.
TYPE_SPECIFIC_MODELS: Dict[str, Type[BaseModel]] = {
    "invoice": InvoiceFields,
    "purchase_order": PurchaseOrderFields,
    "receipt": ReceiptFields,
    "bank_statement": BankStatementFields,
}


class AnomalyFlag(BaseModel):
    """A single anomaly detected in the document."""

    category: AnomalyCategory
    # Name of the document field the anomaly refers to.
    field: str
    severity: Severity
    description: str


class DocumentExtraction(BaseModel):
    """
    Top-level extraction result.

    Schema Option C: Common core + type-specific extensions. ``common`` holds
    the fields every document shares; ``type_specific`` holds the extra fields
    for ``common.document_type`` as a plain dict, which ``typed_fields()``
    validates against the matching extension model on demand.
    """

    common: CommonFields
    # Pydantic copies mutable defaults per instance, so ``[]``/``{}`` are not
    # shared between models here (unlike plain function defaults).
    line_items: Optional[List[LineItem]] = []
    type_specific: dict = {}  # Flexible dict to handle all doc types
    flags: List[AnomalyFlag] = []
    confidence_score: float = Field(default=0.95, ge=0.0, le=1.0)

    def has_anomalies(self) -> bool:
        """Return True if any anomalies were detected."""
        return bool(self.flags)

    def flags_with_severity(self, severity: Severity) -> List[AnomalyFlag]:
        """Return the anomalies whose severity equals *severity*.

        ``Severity`` is a ``str`` enum, so the plain string value
        (e.g. ``"medium"``) compares equal to the member as well.
        """
        return [flag for flag in self.flags if flag.severity == severity]

    def high_severity_flags(self) -> List[AnomalyFlag]:
        """Return only high-severity anomalies."""
        return self.flags_with_severity(Severity.HIGH)

    def typed_fields(self) -> Optional[BaseModel]:
        """Validate ``type_specific`` against this document type's extension model.

        The model is looked up in ``TYPE_SPECIFIC_MODELS`` by
        ``common.document_type`` and built from the ``type_specific`` dict
        (its keys must be strings). Keys the model does not declare are
        dropped under pydantic's default ``extra`` handling.

        Returns:
            An ``InvoiceFields``, ``PurchaseOrderFields``, ``ReceiptFields`` or
            ``BankStatementFields`` instance, or ``None`` if no extension model
            is registered for the document type.

        Raises:
            pydantic.ValidationError: if a value in ``type_specific`` does not
                validate against the chosen model's field types.
        """
        model = TYPE_SPECIFIC_MODELS.get(self.common.document_type)
        if model is None:
            return None
        return model(**self.type_specific)