| """ |
| Pydantic models for financial document extraction. |
| Option C: Common core + type-specific extensions. |
| """ |
|
|
from enum import Enum
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
|
|
|
|
class AnomalyCategory(str, Enum):
    """Categories of anomalies the model can detect."""

    # Because the class mixes in ``str``, every member compares equal to its
    # raw value, e.g. ``AnomalyCategory.FORMAT == "format_anomaly"``.
    # NOTE(review): format()/f-string rendering of str-mixin enum members
    # changed in Python 3.11 (it now yields "AnomalyCategory.FORMAT" rather
    # than the value); use ``.value`` wherever the raw string is required.
    ARITHMETIC = "arithmetic_error"
    MISSING_FIELD = "missing_field"
    FORMAT = "format_anomaly"
    BUSINESS_LOGIC = "business_logic"
    CROSS_FIELD = "cross_field"
|
|
|
|
class Severity(str, Enum):
    """Severity levels for detected anomalies."""

    # Members are declared lowest-to-highest; iteration follows this order.
    # The ``str`` mix-in makes each member equal to its raw value
    # (``Severity.HIGH == "high"``). No ordering operators are defined, so
    # ``<``/``>`` between members compare the underlying strings, not rank.
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
|
|
|
|
class Party(BaseModel):
    """Represents an entity (vendor, buyer, etc.)."""

    # Both attributes default to None, so ``Party()`` with no arguments is a
    # valid (empty) instance. The address is a single free-form string; it is
    # not split into components here.
    name: Optional[str] = Field(default=None)
    address: Optional[str] = Field(default=None)
|
|
|
|
class CommonFields(BaseModel):
    """Fields shared across all financial document types."""

    # The only required attribute; restricted to these four literal values.
    document_type: Literal[
        "invoice", "purchase_order", "receipt", "bank_statement"
    ]
    # Kept as a raw string: no date parsing or format validation happens here.
    date: Optional[str] = Field(default=None)
    issuer: Optional[Party] = Field(default=None)
    recipient: Optional[Party] = Field(default=None)
    # NOTE(review): binary floats cannot represent most decimal amounts
    # exactly; consider Decimal if exact arithmetic checks are needed.
    total_amount: Optional[float] = Field(default=None)
    # Falls back to "USD" when omitted; an explicit None is still accepted.
    currency: Optional[str] = Field(default="USD")
|
|
|
|
class LineItem(BaseModel):
    """A single line item in a financial document."""

    # ``description`` is required; the numeric attributes may be omitted.
    description: str
    quantity: Optional[float] = Field(default=None)
    unit_price: Optional[float] = Field(default=None)
    # Stored as given: this model does not recompute ``amount`` or check it
    # against ``quantity * unit_price``.
    amount: Optional[float] = Field(default=None)
|
|
|
|
| |
|
|
class InvoiceFields(BaseModel):
    """Fields specific to invoices."""

    # Every attribute is optional and defaults to None. This model performs no
    # cross-field validation (e.g. between ``subtotal`` and ``tax_amount``).
    invoice_number: Optional[str] = Field(default=None)
    due_date: Optional[str] = Field(default=None)  # raw string, not parsed
    payment_terms: Optional[str] = Field(default=None)
    tax_amount: Optional[float] = Field(default=None)
    subtotal: Optional[float] = Field(default=None)
|
|
|
|
class PurchaseOrderFields(BaseModel):
    """Fields specific to purchase orders."""

    # Every attribute is optional and defaults to None.
    po_number: Optional[str] = Field(default=None)
    delivery_date: Optional[str] = Field(default=None)  # raw string, not parsed
    shipping_address: Optional[str] = Field(default=None)
    # NOTE(review): presumably an invoice identifier linked to this PO;
    # nothing here validates that it matches any invoice_number.
    referenced_invoice: Optional[str] = Field(default=None)
|
|
|
|
class ReceiptFields(BaseModel):
    """Fields specific to receipts."""

    # Every attribute is an optional free-form string defaulting to None;
    # no value set is enforced (e.g. for ``payment_method``).
    receipt_number: Optional[str] = Field(default=None)
    payment_method: Optional[str] = Field(default=None)
    store_location: Optional[str] = Field(default=None)
    cashier: Optional[str] = Field(default=None)
|
|
|
|
class BankStatementFields(BaseModel):
    """Fields specific to bank statements."""

    # Every attribute is optional and defaults to None.
    account_number: Optional[str] = Field(default=None)
    # Single free-form string; its format (e.g. how start/end are written)
    # is not enforced here.
    statement_period: Optional[str] = Field(default=None)
    # No check ties the two balances together in this model.
    opening_balance: Optional[float] = Field(default=None)
    closing_balance: Optional[float] = Field(default=None)
|
|
|
|
class AnomalyFlag(BaseModel):
    """A single anomaly detected in the document."""

    # All four attributes are required (none has a default).
    category: AnomalyCategory
    # NOTE(review): presumably names the affected field (e.g. "total_amount");
    # no naming/path convention is enforced here.
    field: str
    severity: Severity
    # Free-form string; content is not constrained.
    description: str
|
|
|
|
class DocumentExtraction(BaseModel):
    """
    Top-level extraction result.
    Schema Option C: Common core + type-specific extensions.
    """

    # Fields shared by every document type. ``common.document_type`` also
    # selects which extension model ``typed_fields()`` applies to
    # ``type_specific``.
    common: CommonFields
    # pydantic copies non-hashable defaults for each new instance, so these
    # literal ``[]`` / ``{}`` defaults are not shared between models.
    line_items: Optional[List[LineItem]] = []
    # Raw extension payload, stored without validation. Call ``typed_fields()``
    # to obtain it validated against the matching *Fields model.
    type_specific: dict = {}
    flags: List[AnomalyFlag] = []
    # Constrained to the closed range [0.0, 1.0].
    # NOTE(review): defaults to 0.95 when no score is supplied, which reports
    # high confidence by default — confirm this is intended.
    confidence_score: float = Field(ge=0.0, le=1.0, default=0.95)

    def has_anomalies(self) -> bool:
        """Return True if at least one anomaly flag was recorded."""
        return bool(self.flags)

    def high_severity_flags(self) -> List[AnomalyFlag]:
        """Return the flags whose severity is ``Severity.HIGH``, in their original order."""
        return [flag for flag in self.flags if flag.severity == Severity.HIGH]

    def typed_fields(self) -> Union[
        InvoiceFields, PurchaseOrderFields, ReceiptFields, BankStatementFields
    ]:
        """Validate ``type_specific`` against the model for this document type.

        The target model is chosen from ``common.document_type``:
        invoice -> InvoiceFields, purchase_order -> PurchaseOrderFields,
        receipt -> ReceiptFields, bank_statement -> BankStatementFields.
        The stored ``type_specific`` dict is left unchanged; a new model
        instance is returned on every call.

        Returns:
            An instance of the matching *Fields model. Keys in
            ``type_specific`` that the model does not declare are dropped
            under pydantic's default ``extra`` handling.

        Raises:
            pydantic.ValidationError: a value cannot be coerced to the
                declared field type.
            TypeError: ``type_specific`` contains a non-string key.
        """
        # Built per call because pydantic rejects plain (non-ClassVar)
        # class attributes on models. The mapping covers every value of the
        # ``document_type`` Literal, so the lookup cannot miss.
        model_cls = {
            "invoice": InvoiceFields,
            "purchase_order": PurchaseOrderFields,
            "receipt": ReceiptFields,
            "bank_statement": BankStatementFields,
        }[self.common.document_type]
        # Keyword construction works on both pydantic v1 and v2.
        return model_cls(**self.type_specific)
|
|