| | """ |
| | Base Model Interfaces for Document Intelligence |
| | |
| | Abstract base classes defining the contract for all model components. |
| | All models are pluggable and can be swapped without changing the pipeline. |
| | """ |
| |
|
| | from abc import ABC, abstractmethod |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | import numpy as np |
| | from PIL import Image |
| |
|
| |
|
class ModelCapability(str, Enum):
    """
    Closed set of features a model implementation can advertise.

    Members are also ``str`` instances, so each compares equal to its
    lowercase string value (e.g. ``ModelCapability.OCR == "ocr"``).
    """

    # Text recognition and page structure
    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"
    # Structured-content extraction
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"
    # Multimodal / representation / labelling models
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"
| |
|
| |
|
@dataclass
class ModelConfig:
    """
    Configuration shared by every model implementation.

    Only ``name`` is required; every other setting has a default. A
    ``cache_dir`` supplied as a string (or any value ``pathlib.Path``
    accepts) is coerced to a ``Path`` right after construction.
    """

    name: str
    version: str = "1.0.0"
    device: str = "auto"
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    # Free-form, implementation-specific options; each instance gets its own dict.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # An unset cache directory stays None; anything else becomes a Path.
        if self.cache_dir is None:
            return
        self.cache_dir = Path(self.cache_dir)
| |
|
| |
|
@dataclass
class ModelMetadata:
    """
    Descriptive record about a model instance.

    ``name``, ``version``, ``capabilities`` and ``device`` must be given;
    the remaining fields default to: zero memory usage, not loaded, no
    batching support, batch size 1, and empty input/output descriptions.
    """

    name: str
    version: str
    capabilities: List[ModelCapability]
    device: str
    memory_usage_mb: float = 0.0
    is_loaded: bool = False
    supports_batching: bool = False
    max_batch_size: int = 1
    # Free-form descriptions of accepted inputs and produced outputs;
    # default_factory gives every instance its own dict.
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)
| |
|
| |
|
class BaseModel(ABC):
    """
    Abstract base class for all document intelligence models.

    Concrete models must implement ``load``, ``unload`` and
    ``get_capabilities``. The ``validate_input``, ``preprocess`` and
    ``postprocess`` hooks have permissive pass-through defaults that
    subclasses may override.

    Instances are context managers. Entering a ``with`` block loads the
    model if it is not already loaded. Leaving the block unloads it only
    when that same entry performed the load, so a model loaded before the
    ``with`` statement (or by an enclosing ``with``) stays loaded.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Args:
            config: Model configuration. When omitted (None), a default
                ``ModelConfig`` named after the concrete class is created.
        """
        self.config = (
            config
            if config is not None
            else ModelConfig(name=self.__class__.__name__)
        )
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None
        # One flag per currently-open ``with`` block: True when that block's
        # __enter__ called load() and is therefore responsible for unload().
        self._context_loads: List[bool] = []

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """Model metadata (expected to be populated by ``load()``); None until set."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Load the model into memory.

        Should set self._is_loaded = True upon successful loading.
        Should populate self._metadata with model information.
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """
        Unload the model from memory.

        Should set self._is_loaded = False.
        Should free GPU/CPU memory.
        """
        pass

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return list of capabilities this model provides."""
        pass

    def validate_input(self, input_data: Any) -> bool:
        """
        Validate input data before processing.

        The default accepts everything (always returns True); override in
        subclasses for specific validation.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Preprocess input data before model inference.

        The default returns ``input_data`` unchanged; override in subclasses
        for specific preprocessing.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Postprocess model output.

        The default returns ``output_data`` unchanged; override in subclasses
        for specific postprocessing.
        """
        return output_data

    def __enter__(self):
        """Load the model if needed and return ``self`` for the ``with`` body."""
        loaded_here = not self.is_loaded
        if loaded_here:
            self.load()
        # Record only after load() succeeded: if it raised, __exit__ never runs.
        self._context_stack().append(loaded_here)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Unload the model if the matching ``__enter__`` loaded it.

        Returns False so an exception raised inside the ``with`` body
        propagates to the caller.
        """
        stack = self._context_stack()
        # With no matching __enter__ record (e.g. __exit__ invoked directly),
        # keep the historical behaviour of always unloading.
        loaded_here = stack.pop() if stack else True
        if loaded_here:
            self.unload()
        return False

    def _context_stack(self) -> List[bool]:
        """Return the per-instance stack of ``with``-entry load flags."""
        # Tolerate subclasses whose __init__ does not call super().__init__().
        stack = getattr(self, "_context_loads", None)
        if stack is None:
            stack = self._context_loads = []
        return stack
| |
|
| |
|
class BatchableModel(BaseModel):
    """
    Base class for models that process several inputs per call.

    Subclasses implement ``process_batch``. ``process_single`` is a
    convenience wrapper that sends a one-element batch through it.
    """

    @abstractmethod
    def process_batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        Run the model over a batch of inputs.

        Args:
            inputs: Items to process.
            **kwargs: Extra, implementation-specific processing options.

        Returns:
            The outputs, expected to be one per input.
        """
        pass

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """
        Process one input by running it as a batch of size one.

        Returns:
            The first element of ``process_batch``'s result, or None when that
            result is empty (falsy).
        """
        batch_out = self.process_batch([input_data], **kwargs)
        if not batch_out:
            return None
        return batch_out[0]
| |
|
| |
|
# Every image representation the helpers in this module accept: an in-memory
# numpy array, a PIL image, or a filesystem path (pathlib.Path or str).
ImageInput = Union[
    np.ndarray,
    Image.Image,
    Path,
    str,
]
| |
|
| |
|
def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Normalize various image input formats to a numpy array.

    Args:
        image: Image as a numpy array, PIL Image, or filesystem path
            (``str`` or ``pathlib.Path``).

    Returns:
        For PIL Images and paths: the image converted to RGB, as a numpy
        array in HWC layout. Numpy arrays are returned unchanged (the same
        object, with no copy and no channel or dtype conversion), so callers
        passing arrays are responsible for their layout.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
        Errors raised by PIL while opening or decoding a path propagate.
    """
    if isinstance(image, np.ndarray):
        # Deliberate pass-through: arrays are not converted.
        return image

    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    if isinstance(image, (str, Path)):
        # Use a context manager so the file handle is always released; Pillow
        # may keep it open after loading (e.g. for multi-frame formats).
        with Image.open(image) as img:
            return np.array(img.convert("RGB"))

    raise ValueError(f"Unsupported image input type: {type(image)}")
| |
|
| |
|
def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Ensure input is a PIL Image in RGB mode.

    Args:
        image: Image as a numpy array, PIL Image, or filesystem path
            (``str`` or ``pathlib.Path``).

    Returns:
        A PIL Image in RGB mode, produced by ``convert("RGB")``. A PIL input
        is not modified in place.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
        Errors raised by PIL (e.g. ``Image.fromarray`` on an unsupported
        array, or opening or decoding a path) propagate.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")

    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")

    if isinstance(image, (str, Path)):
        # convert() returns a new, fully loaded image, so it remains valid
        # after the context manager closes the underlying file handle.
        with Image.open(image) as img:
            return img.convert("RGB")

    raise ValueError(f"Unsupported image input type: {type(image)}")
| |
|