| | """ |
| | OCR Model Interface |
| | |
| | Abstract interface for Optical Character Recognition models. |
| | Supports both local engines and cloud services. |
| | """ |
| |
|
| | from abc import abstractmethod |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional, Tuple |
|
from ..chunks.models import BoundingBox
from .base import (
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)
|
|
class OCREngine(str, Enum):
    """Supported OCR engines."""

    PADDLEOCR = "paddleocr"
    TESSERACT = "tesseract"
    EASYOCR = "easyocr"
    CUSTOM = "custom"
|
|
@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models."""

    engine: OCREngine = OCREngine.PADDLEOCR
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    min_confidence: float = 0.5

    # PaddleOCR-specific options
    use_angle_cls: bool = True
    use_gpu: bool = True

    # Tesseract-specific options (psm_mode is the page segmentation mode)
    tesseract_config: str = ""
    psm_mode: int = 3

    def __post_init__(self):
        super().__post_init__()
        if not self.name:
            self.name = f"ocr_{self.engine.value}"
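
# Example configuration (illustrative values; the engine-specific fields
# above apply only to their respective engines):
#
#     config = OCRConfig(
#         engine=OCREngine.TESSERACT,
#         languages=["en", "de"],
#         min_confidence=0.6,
#         psm_mode=6,  # Tesseract: assume a single uniform block of text
#     )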
|
|
@dataclass
class OCRWord:
    """A single recognized word with its bounding box."""

    text: str
    bbox: BoundingBox
    confidence: float
    language: Optional[str] = None
    is_handwritten: bool = False
    font_size: Optional[float] = None
    is_bold: bool = False
    is_italic: bool = False


@dataclass
class OCRLine:
    """A line of text composed of words."""

    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """Create a line from a list of words."""
        if not words:
            raise ValueError("Cannot create line from empty word list")

        text = " ".join(w.text for w in words)
        confidence = sum(w.confidence for w in words) / len(words)

        # The line box is the union of all word boxes
        x_min = min(w.bbox.x_min for w in words)
        y_min = min(w.bbox.y_min for w in words)
        x_max = max(w.bbox.x_max for w in words)
        y_max = max(w.bbox.y_max for w in words)

        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=words[0].bbox.normalized
        )

        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            words=words,
            line_index=line_index
        )
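
# Example: assembling a line from recognized words (illustrative pixel
# coordinates):
#
#     w1 = OCRWord("Hello", BoundingBox(x_min=10, y_min=5, x_max=60, y_max=20, normalized=False), 0.97)
#     w2 = OCRWord("world", BoundingBox(x_min=65, y_min=5, x_max=115, y_max=20, normalized=False), 0.93)
#     line = OCRLine.from_words([w1, w2])
#     # line.text == "Hello world", line.confidence == 0.95,
#     # line.bbox spans (10, 5) to (115, 20)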
|
|
@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""

    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"

    @property
    def line_count(self) -> int:
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """Create a block from a list of lines."""
        if not lines:
            raise ValueError("Cannot create block from empty line list")

        text = "\n".join(line.text for line in lines)
        confidence = sum(line.confidence for line in lines) / len(lines)

        # The block box is the union of all line boxes
        x_min = min(line.bbox.x_min for line in lines)
        y_min = min(line.bbox.y_min for line in lines)
        x_max = max(line.bbox.x_max for line in lines)
        y_max = max(line.bbox.y_max for line in lines)

        bbox = BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=lines[0].bbox.normalized
        )

        return cls(
            text=text,
            bbox=bbox,
            confidence=confidence,
            lines=lines,
            block_type=block_type
        )


@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""

    text: str
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        return len(self.words)

    @property
    def line_count(self) -> int:
        return len(self.lines)

    @property
    def block_count(self) -> int:
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        Args:
            bbox: Region to extract text from
            threshold: Minimum IoU overlap required

        Returns:
            Concatenated text of the words in the region, in approximate
            reading order. A word is included if its IoU with the region
            meets the threshold or its center lies inside the region.
        """
        words_in_region = []
        for word in self.words:
            iou = word.bbox.iou(bbox)
            if iou >= threshold or bbox.contains(word.bbox.center):
                words_in_region.append(word)

        # Approximate reading order: top-to-bottom, then left-to-right
        words_in_region.sort(key=lambda w: (w.bbox.y_min, w.bbox.x_min))
        return " ".join(w.text for w in words_in_region)
|
|
class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-annotate so static checkers see OCRConfig rather than the
        # broader ModelConfig stored by the base class.
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
        """
        pass
|
    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation processes sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]
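
    # Engines with native batched inference can override process_batch.
    # A hypothetical sketch (`_engine.batch_infer` and `_to_result` are
    # illustrative names, not part of this interface):
    #
    #     def process_batch(self, inputs, **kwargs):
    #         images = [ensure_pil_image(i) for i in inputs]
    #         raw = self._engine.batch_infer(images)
    #         return [self._to_result(r) for r in raw]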
|
    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text
        """
        # Default implementation runs full recognition and keeps only the
        # block boxes; engines with a separate detection stage should override.
        result = self.recognize(image, **kwargs)
        return [block.bbox for block in result.blocks]
|
    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        Args:
            image: Full image
            region: Region to OCR
            **kwargs: Additional parameters

        Returns:
            OCR result for the region
        """
        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)

        # Convert a normalized region to pixel coordinates
        if region.normalized:
            pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
        else:
            pixel_bbox = region

        # Crop the region out of the full image
        cropped = pil_image.crop((
            int(pixel_bbox.x_min),
            int(pixel_bbox.y_min),
            int(pixel_bbox.x_max),
            int(pixel_bbox.y_max)
        ))

        # Recognize the crop; coordinates come back in crop-local pixel space
        result = self.recognize(cropped, **kwargs)

        # Shift all boxes from crop space back to full-image pixel space
        offset_x = pixel_bbox.x_min
        offset_y = pixel_bbox.y_min

        for word in result.words:
            word.bbox = BoundingBox(
                x_min=word.bbox.x_min + offset_x,
                y_min=word.bbox.y_min + offset_y,
                x_max=word.bbox.x_max + offset_x,
                y_max=word.bbox.y_max + offset_y,
                normalized=False
            )

        for line in result.lines:
            line.bbox = BoundingBox(
                x_min=line.bbox.x_min + offset_x,
                y_min=line.bbox.y_min + offset_y,
                x_max=line.bbox.x_max + offset_x,
                y_max=line.bbox.y_max + offset_y,
                normalized=False
            )

        for block in result.blocks:
            block.bbox = BoundingBox(
                x_min=block.bbox.x_min + offset_x,
                y_min=block.bbox.y_min + offset_y,
                x_max=block.bbox.x_max + offset_x,
                y_max=block.bbox.y_max + offset_y,
                normalized=False
            )

        return result
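

# Example: a concrete engine and region-limited OCR (hypothetical
# PaddleOCR-backed subclass; engine plumbing elided):
#
#     class PaddleOCRModel(OCRModel):
#         def recognize(self, image: ImageInput, **kwargs) -> OCRResult:
#             ...  # run the engine, map raw output to OCRWord/OCRLine/OCRBlock
#
#     model = PaddleOCRModel(OCRConfig(engine=OCREngine.PADDLEOCR))
#     page = model.recognize("invoice.png")
#     header = model.recognize_region(
#         "invoice.png",
#         BoundingBox(x_min=0.0, y_min=0.0, x_max=1.0, y_max=0.25, normalized=True),
#     )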
|