"""
Schema definitions for field extraction.

Pydantic-compatible schemas for defining extraction targets.
"""
| |
|
| | from dataclasses import dataclass, field as dataclass_field |
| | from enum import Enum |
| | from typing import Any, Callable, Dict, List, Optional, Type, Union |
| |
|
| | from pydantic import BaseModel, Field, create_model |
| |
|
| |
|
class FieldType(str, Enum):
    """Enumeration of the kinds of value a field can be declared as.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (for example ``FieldType.DATE == "date"``).
    """

    # Scalar primitives.
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"

    # Temporal values.
    DATE = "date"
    DATETIME = "datetime"

    # Formatted / domain-specific text values.
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"

    # Containers.
    LIST = "list"
    OBJECT = "object"
| |
|
| |
|
@dataclass
class FieldSpec:
    """Specification for a single extraction field.

    Holds the field's name and type, optional validation constraints, and
    optional nesting information, and can render itself as a JSON Schema
    fragment via :meth:`to_json_schema`.
    """

    # Identity and basic typing.
    name: str
    field_type: FieldType = FieldType.STRING
    description: str = ""
    required: bool = True
    default: Any = None

    # Validation constraints; ``None`` means "no constraint".
    pattern: Optional[str] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    allowed_values: Optional[List[Any]] = None

    # Nesting: ``nested_schema`` describes OBJECT fields (or the items of a
    # LIST of objects); ``list_item_type`` describes the items of a LIST of
    # primitive values.
    nested_schema: Optional["ExtractionSchema"] = None
    list_item_type: Optional[FieldType] = None

    # Extraction hints. NOTE(review): not read anywhere in this module;
    # presumably consumed by the extractor - confirm against callers.
    aliases: List[str] = dataclass_field(default_factory=list)
    examples: List[str] = dataclass_field(default_factory=list)
    context_hints: List[str] = dataclass_field(default_factory=list)

    # NOTE(review): not read in this module; presumably a per-field
    # confidence threshold used by the extractor - confirm.
    min_confidence: float = 0.5

    @staticmethod
    def _json_type(field_type: FieldType) -> str:
        """Return the JSON Schema ``type`` keyword value for *field_type*."""
        json_types = {
            FieldType.STRING: "string",
            FieldType.INTEGER: "integer",
            FieldType.FLOAT: "number",
            FieldType.BOOLEAN: "boolean",
            FieldType.DATE: "string",
            FieldType.DATETIME: "string",
            FieldType.CURRENCY: "string",
            FieldType.PERCENTAGE: "string",
            FieldType.EMAIL: "string",
            FieldType.PHONE: "string",
            FieldType.ADDRESS: "string",
            FieldType.LIST: "array",
            FieldType.OBJECT: "object",
        }
        # Unknown types degrade to "string" rather than raising.
        return json_types.get(field_type, "string")

    @staticmethod
    def _json_format(field_type: FieldType) -> Optional[str]:
        """Return the JSON Schema ``format`` for *field_type*, if it has one."""
        json_formats = {
            FieldType.DATE: "date",
            FieldType.DATETIME: "date-time",
            FieldType.EMAIL: "email",
        }
        return json_formats.get(field_type)

    def to_json_schema(self) -> Dict[str, Any]:
        """Convert this field to a JSON Schema fragment.

        Returns:
            A new dict containing ``type`` plus whichever of ``description``,
            ``pattern``, ``format``, ``minimum``, ``maximum``, ``minLength``,
            ``maxLength`` and ``enum`` are set on this spec. LIST fields gain
            ``items`` (from ``nested_schema`` if present, otherwise from
            ``list_item_type``); OBJECT fields with a ``nested_schema`` gain
            that schema's keys.
        """
        schema: Dict[str, Any] = {"type": self._json_type(self.field_type)}

        if self.description:
            schema["description"] = self.description
        if self.pattern:
            schema["pattern"] = self.pattern

        format_name = self._json_format(self.field_type)
        if format_name is not None:
            schema["format"] = format_name

        # Numeric / length bounds: 0 is a legitimate bound, so test against
        # None rather than truthiness.
        bounds = (
            ("minimum", self.min_value),
            ("maximum", self.max_value),
            ("minLength", self.min_length),
            ("maxLength", self.max_length),
        )
        for keyword, value in bounds:
            if value is not None:
                schema[keyword] = value

        if self.allowed_values:
            # Copy so later edits to the emitted schema cannot mutate the spec.
            schema["enum"] = list(self.allowed_values)

        if self.field_type == FieldType.LIST:
            if self.nested_schema is not None:
                schema["items"] = self.nested_schema.to_json_schema()
            elif self.list_item_type is not None:
                # List of primitives: describe the item type (and format).
                items: Dict[str, Any] = {
                    "type": self._json_type(self.list_item_type),
                }
                item_format = self._json_format(self.list_item_type)
                if item_format is not None:
                    items["format"] = item_format
                schema["items"] = items
        elif self.field_type == FieldType.OBJECT and self.nested_schema is not None:
            # Merge the nested object schema without clobbering keys already
            # set at field level (notably this field's own description).
            for keyword, value in self.nested_schema.to_json_schema().items():
                schema.setdefault(keyword, value)

        return schema
| |
|
| |
|
@dataclass
class ExtractionSchema:
    """
    Schema defining fields to extract from a document.

    Can be nested for complex document structures: a ``FieldSpec`` may carry
    another ``ExtractionSchema`` in its ``nested_schema`` attribute.
    """

    name: str
    description: str = ""
    fields: List[FieldSpec] = dataclass_field(default_factory=list)

    # Extraction behavior switches. NOTE(review): not read anywhere in this
    # module; presumably consumed by the extractor - confirm against callers.
    allow_partial: bool = True
    abstain_on_low_confidence: bool = True
    min_overall_confidence: float = 0.5

    def add_field(self, field: FieldSpec) -> "ExtractionSchema":
        """Append *field* to the schema and return ``self`` for chaining."""
        self.fields.append(field)
        return self

    def _add_typed_field(
        self,
        field_type: FieldType,
        name: str,
        description: str,
        required: bool,
        **kwargs
    ) -> "ExtractionSchema":
        """Build a ``FieldSpec`` of *field_type* and append it to the schema."""
        spec = FieldSpec(
            name=name,
            field_type=field_type,
            description=description,
            required=required,
            **kwargs
        )
        return self.add_field(spec)

    def add_string_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a STRING field; extra keyword arguments go to ``FieldSpec``."""
        return self._add_typed_field(
            FieldType.STRING, name, description, required, **kwargs
        )

    def add_number_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        is_integer: bool = False,
        **kwargs
    ) -> "ExtractionSchema":
        """Add an INTEGER field if *is_integer* is true, otherwise a FLOAT field."""
        number_type = FieldType.INTEGER if is_integer else FieldType.FLOAT
        return self._add_typed_field(
            number_type, name, description, required, **kwargs
        )

    def add_date_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a DATE field; extra keyword arguments go to ``FieldSpec``."""
        return self._add_typed_field(
            FieldType.DATE, name, description, required, **kwargs
        )

    def add_currency_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a CURRENCY field; extra keyword arguments go to ``FieldSpec``."""
        return self._add_typed_field(
            FieldType.CURRENCY, name, description, required, **kwargs
        )

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field called *name*, or ``None`` if there is none."""
        return next((spec for spec in self.fields if spec.name == name), None)

    def get_required_fields(self) -> List[FieldSpec]:
        """Return the fields whose ``required`` flag is true, in schema order."""
        return [spec for spec in self.fields if spec.required]

    def get_optional_fields(self) -> List[FieldSpec]:
        """Return the fields whose ``required`` flag is false, in schema order."""
        return [spec for spec in self.fields if not spec.required]

    def to_json_schema(self) -> Dict[str, Any]:
        """Convert to a JSON Schema ``object`` definition.

        Returns:
            A dict with ``type`` and ``properties``; ``required`` and
            ``description`` are included only when non-empty.
        """
        schema: Dict[str, Any] = {
            "type": "object",
            "properties": {
                spec.name: spec.to_json_schema() for spec in self.fields
            },
        }

        required_names = [spec.name for spec in self.fields if spec.required]
        if required_names:
            schema["required"] = required_names

        if self.description:
            schema["description"] = self.description

        return schema

    def to_pydantic_model(self) -> Type[BaseModel]:
        """Generate a Pydantic model class named after this schema.

        Required fields use the plain Python type and no default. Optional
        fields are annotated ``Optional[...]`` and default to the spec's
        ``default``, so a missing value (``None``) validates instead of being
        rejected by the bare type.
        """
        field_definitions: Dict[str, Any] = {}

        for spec in self.fields:
            python_type = self._get_python_type(spec.field_type)
            if spec.required:
                annotation: Any = python_type
                default: Any = ...  # Ellipsis marks the field as required.
            else:
                annotation = Optional[python_type]
                default = spec.default

            field_definitions[spec.name] = (
                annotation,
                Field(default=default, description=spec.description),
            )

        return create_model(self.name, **field_definitions)

    def _get_python_type(self, field_type: FieldType) -> type:
        """Return the Python type used to represent *field_type* values."""
        python_types = {
            FieldType.STRING: str,
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.DATE: str,
            FieldType.DATETIME: str,
            FieldType.CURRENCY: str,
            FieldType.PERCENTAGE: str,
            FieldType.EMAIL: str,
            FieldType.PHONE: str,
            FieldType.ADDRESS: str,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        # Unknown types fall back to ``str``.
        return python_types.get(field_type, str)

    @classmethod
    def from_json_schema(cls, schema: Dict[str, Any], name: str = "Schema") -> "ExtractionSchema":
        """Create an ``ExtractionSchema`` from a JSON Schema object definition.

        Args:
            schema: JSON Schema dict; its ``properties`` become fields and
                its ``required`` list sets each field's ``required`` flag.
            name: Name given to the resulting schema.

        Returns:
            The populated schema. Object properties that declare their own
            ``properties``, and arrays whose ``items`` are such objects, are
            converted recursively into ``nested_schema``; arrays of typed
            primitive items record the item type in ``list_item_type``.
        """
        extraction_schema = cls(
            name=name,
            description=schema.get("description", ""),
        )

        # ``or`` also covers an explicit null for either keyword.
        properties = schema.get("properties") or {}
        required = set(schema.get("required") or [])

        for field_name, field_schema in properties.items():
            field_type = cls._json_type_to_field_type(field_schema)

            nested_schema = None
            list_item_type = None
            if field_type == FieldType.OBJECT:
                if field_schema.get("properties"):
                    nested_schema = cls.from_json_schema(field_schema, name=field_name)
            elif field_type == FieldType.LIST:
                items = field_schema.get("items")
                # Only a typed dict schema tells us what the items are.
                if isinstance(items, dict) and "type" in items:
                    item_type = cls._json_type_to_field_type(items)
                    if item_type != FieldType.OBJECT:
                        list_item_type = item_type
                    elif items.get("properties"):
                        nested_schema = cls.from_json_schema(items, name=field_name)

            extraction_schema.add_field(
                FieldSpec(
                    name=field_name,
                    field_type=field_type,
                    description=field_schema.get("description", ""),
                    required=field_name in required,
                    pattern=field_schema.get("pattern"),
                    min_value=field_schema.get("minimum"),
                    max_value=field_schema.get("maximum"),
                    min_length=field_schema.get("minLength"),
                    max_length=field_schema.get("maxLength"),
                    allowed_values=field_schema.get("enum"),
                    nested_schema=nested_schema,
                    list_item_type=list_item_type,
                )
            )

        return extraction_schema

    @staticmethod
    def _json_type_to_field_type(field_schema: Dict[str, Any]) -> FieldType:
        """Map a JSON Schema property definition to a ``FieldType``.

        The ``type`` keyword decides first; for anything that is not an
        integer, number, boolean, array or object, ``format`` selects DATE,
        DATETIME or EMAIL, and everything else is STRING. A union type such
        as ``["string", "null"]`` is resolved to its first non-null member.
        """
        json_type = field_schema.get("type", "string")
        if isinstance(json_type, (list, tuple)):
            json_type = next(
                (member for member in json_type if member != "null"),
                "string",
            )

        by_type = {
            "integer": FieldType.INTEGER,
            "number": FieldType.FLOAT,
            "boolean": FieldType.BOOLEAN,
            "array": FieldType.LIST,
            "object": FieldType.OBJECT,
        }
        # isinstance guards keep malformed (unhashable) values from raising.
        if isinstance(json_type, str) and json_type in by_type:
            return by_type[json_type]

        by_format = {
            "date": FieldType.DATE,
            "date-time": FieldType.DATETIME,
            "email": FieldType.EMAIL,
        }
        format_ = field_schema.get("format", "")
        if isinstance(format_, str):
            return by_format.get(format_, FieldType.STRING)
        return FieldType.STRING
| |
|
| |
|
| | |
| |
|
def create_invoice_schema() -> ExtractionSchema:
    """Build the predefined schema for invoice documents.

    Required fields: ``invoice_number``, ``invoice_date``, ``vendor_name``
    and ``total_amount``. All other fields are optional.
    """
    # Every add_* builder returns the schema itself, so calls can be chained.
    return (
        ExtractionSchema(
            name="Invoice",
            description="Invoice document extraction schema",
        )
        .add_string_field("invoice_number", "Invoice number or ID")
        .add_date_field("invoice_date", "Date of invoice")
        .add_date_field("due_date", "Payment due date", required=False)
        .add_string_field("vendor_name", "Name of vendor/seller")
        .add_string_field("vendor_address", "Address of vendor", required=False)
        .add_string_field("customer_name", "Name of customer/buyer", required=False)
        .add_string_field("customer_address", "Address of customer", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount due")
        .add_string_field("currency", "Currency code (USD, EUR, etc.)", required=False)
        .add_string_field("payment_terms", "Payment terms", required=False)
    )
| |
|
| |
|
def create_receipt_schema() -> ExtractionSchema:
    """Build the predefined schema for receipt documents.

    Required fields: ``merchant_name``, ``transaction_date`` and
    ``total_amount``. All other fields are optional.
    """
    # Every add_* builder returns the schema itself, so calls can be chained.
    return (
        ExtractionSchema(
            name="Receipt",
            description="Receipt document extraction schema",
        )
        .add_string_field("merchant_name", "Name of merchant/store")
        .add_string_field("merchant_address", "Address of merchant", required=False)
        .add_date_field("transaction_date", "Date of transaction")
        .add_string_field("transaction_time", "Time of transaction", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount paid")
        .add_string_field("payment_method", "Method of payment", required=False)
        .add_string_field("last_four_digits", "Last 4 digits of card", required=False)
    )
| |
|
| |
|
def create_contract_schema() -> ExtractionSchema:
    """Build the predefined schema for contract documents.

    Required fields: ``effective_date``, ``party_a_name`` and
    ``party_b_name``. All other fields are optional.
    """
    # Every add_* builder returns the schema itself, so calls can be chained.
    return (
        ExtractionSchema(
            name="Contract",
            description="Contract document extraction schema",
        )
        .add_string_field("contract_title", "Title of the contract", required=False)
        .add_date_field("effective_date", "Date contract becomes effective")
        .add_date_field("expiration_date", "Date contract expires", required=False)
        .add_string_field("party_a_name", "Name of first party")
        .add_string_field("party_b_name", "Name of second party")
        .add_currency_field("contract_value", "Total contract value", required=False)
        .add_string_field("governing_law", "Governing law/jurisdiction", required=False)
        .add_string_field("termination_clause", "Summary of termination terms", required=False)
    )
| |
|