# Uploaded by Vaibuzzz with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 10ff0db (verified).
"""
Pydantic models for financial document extraction.
Option C: Common core + type-specific extensions.
"""
from pydantic import BaseModel, Field
from typing import Optional, List, Literal
from enum import Enum
class AnomalyCategory(str, Enum):
    """Closed set of anomaly kinds a flag can be tagged with.

    Because the enum mixes in ``str``, every member compares equal to its
    string value, e.g. ``AnomalyCategory.FORMAT == "format_anomaly"``, and
    ``AnomalyCategory("<value>")`` looks a member up by that string.
    """

    # Member values are the serialized strings; keep them stable.
    ARITHMETIC = "arithmetic_error"
    MISSING_FIELD = "missing_field"
    FORMAT = "format_anomaly"
    BUSINESS_LOGIC = "business_logic"
    CROSS_FIELD = "cross_field"
class Severity(str, Enum):
    """How serious a detected anomaly is: low, medium or high.

    Members are ``str`` subclasses, so they compare equal to (and serialize
    as) their lowercase string values.
    """

    # Declared from least to most severe.
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
class Party(BaseModel):
    """A named participant in a document (e.g. the issuer or the recipient).

    Both attributes are optional and default to ``None``.
    """

    # Free-form text exactly as extracted; no normalization happens here.
    name: Optional[str] = None
    address: Optional[str] = None
class CommonFields(BaseModel):
    """Core attributes present on every supported financial document.

    Only ``document_type`` is required; it must be one of the four literal
    values below. Everything else is optional.
    """

    # Restricts the model to the four supported document kinds.
    document_type: Literal["invoice", "purchase_order", "receipt", "bank_statement"]
    # Kept as a plain string; no date format is validated here.
    date: Optional[str] = None
    issuer: Optional[Party] = None
    recipient: Optional[Party] = None
    total_amount: Optional[float] = None
    # Falls back to "USD" when the extractor supplies no currency.
    currency: Optional[str] = "USD"
class LineItem(BaseModel):
    """One row of itemized content within a document.

    ``description`` is mandatory; the numeric columns are optional. This model
    performs no cross-check between ``quantity``, ``unit_price`` and ``amount``.
    """

    description: str
    # Numeric columns default to None when the extractor could not read them.
    quantity: Optional[float] = None
    unit_price: Optional[float] = None
    amount: Optional[float] = None
# === Type-Specific Extensions ===
class InvoiceFields(BaseModel):
    """Extension attributes that only apply to invoices.

    Every attribute is optional and defaults to ``None``; no cross-field
    validation is performed by this model.
    """

    invoice_number: Optional[str] = None
    # Dates and terms stay as free-form strings.
    due_date: Optional[str] = None
    payment_terms: Optional[str] = None
    tax_amount: Optional[float] = None
    subtotal: Optional[float] = None
class PurchaseOrderFields(BaseModel):
    """Extension attributes that only apply to purchase orders.

    Every attribute is an optional string defaulting to ``None``.
    """

    po_number: Optional[str] = None
    delivery_date: Optional[str] = None
    shipping_address: Optional[str] = None
    # Identifier of a related invoice, stored as text (not a nested model).
    referenced_invoice: Optional[str] = None
class ReceiptFields(BaseModel):
    """Extension attributes that only apply to receipts.

    Every attribute is an optional string defaulting to ``None``.
    """

    receipt_number: Optional[str] = None
    payment_method: Optional[str] = None
    store_location: Optional[str] = None
    cashier: Optional[str] = None
class BankStatementFields(BaseModel):
    """Extension attributes that only apply to bank statements.

    Every attribute is optional and defaults to ``None``.
    """

    account_number: Optional[str] = None
    # The period is kept as a single free-form string (not a start/end pair).
    statement_period: Optional[str] = None
    opening_balance: Optional[float] = None
    closing_balance: Optional[float] = None
class AnomalyFlag(BaseModel):
    """One anomaly reported against an extracted document.

    All four attributes are required.
    """

    # Must be one of the AnomalyCategory values.
    category: AnomalyCategory
    # Which part of the document the flag refers to; presumably a field
    # name or path — confirm the expected format with the producer.
    field: str
    # Must be one of the Severity values.
    severity: Severity
    # Human-readable explanation of the anomaly.
    description: str
class DocumentExtraction(BaseModel):
    """
    Top-level extraction result.

    Schema Option C: a common core shared by every document type
    (``common``) plus a free-form ``type_specific`` mapping holding the
    per-type extension fields. Use ``typed_fields()`` to validate that
    mapping against the extension model matching ``common.document_type``.

    Attributes:
        common: Fields shared by all supported document types.
        line_items: Extracted line items; defaults to an empty list.
        type_specific: Per-type extension fields as a plain dict.
        flags: Anomalies detected in the document; defaults to an empty list.
        confidence_score: Extraction confidence, validated to lie in
            [0.0, 1.0]; defaults to 0.95.
    """

    common: CommonFields
    # Pydantic copies field defaults for each instance, so these literal
    # []/{} defaults are not shared between instances.
    line_items: Optional[List[LineItem]] = []
    type_specific: dict = {}  # Flexible dict to handle all doc types
    flags: List[AnomalyFlag] = []
    confidence_score: float = Field(ge=0.0, le=1.0, default=0.95)

    def has_anomalies(self) -> bool:
        """Return True if at least one anomaly flag was recorded."""
        return bool(self.flags)

    def high_severity_flags(self) -> List[AnomalyFlag]:
        """Return the flags whose severity is ``Severity.HIGH``, in original order."""
        return [flag for flag in self.flags if flag.severity == Severity.HIGH]

    def typed_fields(self) -> BaseModel:
        """
        Validate ``type_specific`` against the extension model for this document.

        The model is selected from ``common.document_type``:
        invoice -> InvoiceFields, purchase_order -> PurchaseOrderFields,
        receipt -> ReceiptFields, bank_statement -> BankStatementFields.
        Keys the selected model does not declare are treated according to
        that model's ``extra`` configuration (ignored under pydantic's default).

        Returns:
            A new instance of the selected extension model.

        Raises:
            pydantic.ValidationError: if a value in ``type_specific`` cannot
                be coerced to the declared field type.
        """
        # Built per call so this method does not depend on extra module state.
        models_by_type = {
            "invoice": InvoiceFields,
            "purchase_order": PurchaseOrderFields,
            "receipt": ReceiptFields,
            "bank_statement": BankStatementFields,
        }
        # document_type is a Literal of exactly these four keys, so the lookup
        # cannot miss for a validated CommonFields instance.
        model_cls = models_by_type[self.common.document_type]
        return model_cls(**self.type_specific)