| | """ |
| | Layout Detection Model Interface |
| | |
| | Abstract interface for document layout analysis models. |
| | Detects regions like text blocks, tables, figures, headers, etc. |
| | """ |
| |
|
| | from abc import abstractmethod |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | from ..chunks.models import BoundingBox, ChunkType |
| | from .base import ( |
| | BaseModel, |
| | BatchableModel, |
| | ImageInput, |
| | ModelCapability, |
| | ModelConfig, |
| | ) |
| |
|
| |
|
class LayoutRegionType(str, Enum):
    """Categories a layout detector can assign to a page region.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (for example ``LayoutRegionType.TABLE == "table"``).
    """

    # Running text
    TEXT = "text"
    TITLE = "title"
    HEADING = "heading"
    PARAGRAPH = "paragraph"
    LIST = "list"

    # Non-prose content blocks
    TABLE = "table"
    FIGURE = "figure"
    CHART = "chart"
    FORMULA = "formula"
    CODE = "code"

    # Page furniture and annotations
    HEADER = "header"
    FOOTER = "footer"
    PAGE_NUMBER = "page_number"
    CAPTION = "caption"
    FOOTNOTE = "footnote"

    # Marks and form elements
    LOGO = "logo"
    SIGNATURE = "signature"
    STAMP = "stamp"
    WATERMARK = "watermark"
    FORM_FIELD = "form_field"
    CHECKBOX = "checkbox"

    # Catch-all
    UNKNOWN = "unknown"

    def to_chunk_type(self) -> ChunkType:
        """Translate this region type into the matching ``ChunkType``.

        Every member maps to the ``ChunkType`` member of the same name,
        except ``UNKNOWN``, which has no table entry and therefore falls
        back to ``ChunkType.TEXT``.

        Returns:
            The ``ChunkType`` corresponding to this region type.
        """
        chunk_type_for = {
            LayoutRegionType.TEXT: ChunkType.TEXT,
            LayoutRegionType.TITLE: ChunkType.TITLE,
            LayoutRegionType.HEADING: ChunkType.HEADING,
            LayoutRegionType.PARAGRAPH: ChunkType.PARAGRAPH,
            LayoutRegionType.LIST: ChunkType.LIST,
            LayoutRegionType.TABLE: ChunkType.TABLE,
            LayoutRegionType.FIGURE: ChunkType.FIGURE,
            LayoutRegionType.CHART: ChunkType.CHART,
            LayoutRegionType.FORMULA: ChunkType.FORMULA,
            LayoutRegionType.CODE: ChunkType.CODE,
            LayoutRegionType.HEADER: ChunkType.HEADER,
            LayoutRegionType.FOOTER: ChunkType.FOOTER,
            LayoutRegionType.PAGE_NUMBER: ChunkType.PAGE_NUMBER,
            LayoutRegionType.CAPTION: ChunkType.CAPTION,
            LayoutRegionType.FOOTNOTE: ChunkType.FOOTNOTE,
            LayoutRegionType.LOGO: ChunkType.LOGO,
            LayoutRegionType.SIGNATURE: ChunkType.SIGNATURE,
            LayoutRegionType.STAMP: ChunkType.STAMP,
            LayoutRegionType.WATERMARK: ChunkType.WATERMARK,
            LayoutRegionType.FORM_FIELD: ChunkType.FORM_FIELD,
            LayoutRegionType.CHECKBOX: ChunkType.CHECKBOX,
        }
        if self in chunk_type_for:
            return chunk_type_for[self]
        # No dedicated chunk type (currently only UNKNOWN): treat as text.
        return ChunkType.TEXT
| |
|
| |
|
@dataclass
class LayoutConfig(ModelConfig):
    """Settings for layout detection models, layered on ``ModelConfig``.

    NOTE(review): the per-field comments below are inferred from the field
    names; confirm them against the concrete models that consume this
    configuration.
    """

    # Presumably the lowest detection score that is still kept.
    min_confidence: float = 0.5
    # Presumably whether overlapping detections are merged, and the overlap
    # level at which that happens.
    merge_overlapping: bool = True
    overlap_threshold: float = 0.5
    # Read by LayoutModel.get_capabilities to advertise reading-order support.
    detect_reading_order: bool = True
    detect_columns: bool = True
    # Presumably restricts detection to these types; None by default.
    region_types: Optional[List[LayoutRegionType]] = None

    def __post_init__(self):
        """Run the parent hook, then default an empty ``name``.

        An empty (falsy) ``name`` is replaced with ``"layout_detector"``.
        """
        super().__post_init__()
        if self.name:
            return
        self.name = "layout_detector"
| |
|
| |
|
@dataclass
class LayoutRegion:
    """One region found on a page by layout detection.

    When ``region_id`` is left empty, a 12-character hexadecimal id is
    derived from the region type and the bounding-box coordinates during
    initialisation, so the same type and box always yield the same id.
    """

    region_type: LayoutRegionType
    bbox: BoundingBox
    confidence: float
    region_id: str = ""

    # Position in the reading sequence; defaults to -1
    # (presumably "not assigned yet" — confirm with producers).
    reading_order: int = -1

    # Hierarchy links (presumably region_id values of related regions).
    parent_id: Optional[str] = None
    child_ids: List[str] = field(default_factory=list)

    # Column placement of the region.
    column_index: int = 0
    num_columns: int = 1

    # Free-form extra data attached to the region.
    attributes: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Assign a content-derived ``region_id`` when none was supplied."""
        if self.region_id:
            return

        from hashlib import md5

        # The digest is used only as a short, deterministic identifier.
        fingerprint = f"{self.region_type.value}_{self.bbox.xyxy}"
        digest = md5(fingerprint.encode()).hexdigest()
        self.region_id = digest[:12]
| |
|
| |
|
@dataclass
class LayoutResult:
    """Complete layout analysis result for a single page.

    Bundles the detected regions with page-level values (reading order,
    column count, orientation, image size, timing, model metadata) and
    provides lookup helpers over ``regions``.
    """

    regions: List[LayoutRegion] = field(default_factory=list)
    # region_id values in reading order; when empty, get_ordered_regions()
    # falls back to a geometric sort.
    reading_order: List[str] = field(default_factory=list)
    num_columns: int = 1
    # NOTE(review): unit is not shown in this file (degrees?) — confirm.
    page_orientation: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    def get_regions_by_type(self, region_type: LayoutRegionType) -> List[LayoutRegion]:
        """Return every region whose ``region_type`` equals ``region_type``.

        Args:
            region_type: The type to filter on.

        Returns:
            Matching regions, in the order they appear in ``regions``.
        """
        return [r for r in self.regions if r.region_type == region_type]

    def get_region_by_id(self, region_id: str) -> Optional[LayoutRegion]:
        """Return the first region with the given id, or ``None`` if absent.

        Args:
            region_id: Identifier to look for.

        Returns:
            The first matching region in ``regions``, or ``None``.
        """
        for region in self.regions:
            if region.region_id == region_id:
                return region
        return None

    def get_ordered_regions(self) -> List[LayoutRegion]:
        """Return the regions in reading order.

        When ``reading_order`` is empty, all regions are returned sorted
        top-to-bottom and then left-to-right by the top-left corner of
        their bounding boxes. Otherwise the regions named in
        ``reading_order`` are returned in that order; ids that match no
        region are skipped, and regions not listed in ``reading_order``
        are not included.

        Returns:
            A new list of regions in reading order.
        """
        if not self.reading_order:
            return sorted(
                self.regions,
                key=lambda r: (r.bbox.y_min, r.bbox.x_min),
            )

        # Index the regions once rather than scanning the list for every id.
        # setdefault keeps the first region for a duplicated id, matching
        # get_region_by_id().
        first_by_id: Dict[str, LayoutRegion] = {}
        for region in self.regions:
            first_by_id.setdefault(region.region_id, region)

        return [
            first_by_id[region_id]
            for region_id in self.reading_order
            if region_id in first_by_id
        ]

    def get_tables(self) -> List[LayoutRegion]:
        """Return all regions of type ``LayoutRegionType.TABLE``."""
        return self.get_regions_by_type(LayoutRegionType.TABLE)

    def get_figures(self) -> List[LayoutRegion]:
        """Return all regions of type ``LayoutRegionType.FIGURE``."""
        return self.get_regions_by_type(LayoutRegionType.FIGURE)

    def get_text_regions(self) -> List[LayoutRegion]:
        """Return all text-based regions.

        Text-based means TEXT, TITLE, HEADING, PARAGRAPH, LIST, CAPTION
        or FOOTNOTE.

        Returns:
            Matching regions, in the order they appear in ``regions``.
        """
        text_types = {
            LayoutRegionType.TEXT,
            LayoutRegionType.TITLE,
            LayoutRegionType.HEADING,
            LayoutRegionType.PARAGRAPH,
            LayoutRegionType.LIST,
            LayoutRegionType.CAPTION,
            LayoutRegionType.FOOTNOTE,
        }
        return [r for r in self.regions if r.region_type in text_types]
| |
|
| |
|
class LayoutModel(BatchableModel):
    """
    Abstract base class for layout detection models.

    Concrete subclasses implement ``detect`` and are expected to provide:
    - Document regions (text, tables, figures, etc.)
    - Reading order
    - Column structure
    - Region hierarchy
    """

    def __init__(self, config: Optional[LayoutConfig] = None):
        """Initialise the model.

        Args:
            config: Layout configuration. When omitted (or falsy), a
                ``LayoutConfig`` named ``"layout"`` is created.
        """
        if not config:
            config = LayoutConfig(name="layout")
        super().__init__(config)
        # Re-bind the attribute to itself so the narrower LayoutConfig
        # annotation is what type checkers see.
        self.config: LayoutConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """List the capabilities this model advertises.

        Returns:
            ``LAYOUT_DETECTION``, followed by ``READING_ORDER`` when
            ``config.detect_reading_order`` is enabled.
        """
        capabilities = [ModelCapability.LAYOUT_DETECTION]
        if self.config.detect_reading_order:
            capabilities += [ModelCapability.READING_ORDER]
        return capabilities

    @abstractmethod
    def detect(
        self,
        image: ImageInput,
        **kwargs
    ) -> LayoutResult:
        """
        Detect layout regions in an image.

        Args:
            image: Input document image
            **kwargs: Additional, implementation-specific parameters

        Returns:
            LayoutResult with detected regions
        """

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[LayoutResult]:
        """Run ``detect`` on each image in turn.

        Args:
            inputs: Images to analyse.
            **kwargs: Forwarded unchanged to every ``detect`` call.

        Returns:
            One ``LayoutResult`` per input, in input order.
        """
        results = []
        for single_image in inputs:
            results.append(self.detect(single_image, **kwargs))
        return results

    def detect_tables(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[LayoutRegion]:
        """
        Detect only table regions.

        Runs the full ``detect`` pass and keeps the table regions.

        Args:
            image: Input document image
            **kwargs: Forwarded to ``detect``

        Returns:
            The table regions of the detection result
        """
        return self.detect(image, **kwargs).get_tables()

    def detect_figures(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[LayoutRegion]:
        """
        Detect only figure regions.

        Runs the full ``detect`` pass and keeps the figure regions.

        Args:
            image: Input document image
            **kwargs: Forwarded to ``detect``

        Returns:
            The figure regions of the detection result
        """
        return self.detect(image, **kwargs).get_figures()
| |
|
| |
|
class ReadingOrderModel(BaseModel):
    """
    Abstract base class for models that determine reading order.

    Reading order may be computed separately from layout detection, for
    example by a specialised model when the layout is complex.
    """

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the single capability ``READING_ORDER``."""
        capabilities = [ModelCapability.READING_ORDER]
        return capabilities

    @abstractmethod
    def determine_order(
        self,
        regions: List[LayoutRegion],
        image: Optional[ImageInput] = None,
        **kwargs
    ) -> List[str]:
        """
        Determine reading order for a list of regions.

        Args:
            regions: Layout regions to order
            image: Optional image for visual cues
            **kwargs: Additional, implementation-specific parameters

        Returns:
            List of region_ids in reading order
        """
| |
|
| |
|
class HeuristicReadingOrderModel(ReadingOrderModel):
    """
    Simple heuristic-based reading order model.

    Groups regions into columns by horizontal overlap, then reads each
    column top-to-bottom and the columns left-to-right. Suitable for
    simple document layouts.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialise the model.

        Args:
            config: Model configuration. When omitted (or falsy), a
                ``ModelConfig`` named ``"heuristic_reading_order"`` is used.
        """
        super().__init__(config or ModelConfig(name="heuristic_reading_order"))

    def load(self) -> None:
        """Mark the model as loaded; there is nothing to load."""
        self._is_loaded = True

    def unload(self) -> None:
        """Mark the model as unloaded."""
        self._is_loaded = False

    def determine_order(
        self,
        regions: List[LayoutRegion],
        image: Optional[ImageInput] = None,
        column_threshold: float = 0.3,
        **kwargs
    ) -> List[str]:
        """
        Determine reading order using heuristics.

        Strategy:
        1. Detect columns based on x-coordinate clustering
        2. Within each column, sort top-to-bottom
        3. Process columns left-to-right

        Args:
            regions: Layout regions to order
            image: Accepted for interface compatibility; not used
            column_threshold: Minimum horizontal overlap, as a fraction of
                the narrower region's width, for two regions to be placed
                in the same column. Values <= 0 accept any positive overlap.
            **kwargs: Accepted for interface compatibility; not used

        Returns:
            List of region_ids in reading order
        """
        if not regions:
            return []

        ordered_ids: List[str] = []
        for column in self._detect_columns(regions, column_threshold):
            # sorted() is stable, so regions sharing a top edge keep their
            # left-to-right order from column detection.
            top_to_bottom = sorted(column, key=lambda r: r.bbox.y_min)
            ordered_ids.extend(r.region_id for r in top_to_bottom)
        return ordered_ids

    def _detect_columns(
        self,
        regions: List[LayoutRegion],
        threshold: float
    ) -> List[List[LayoutRegion]]:
        """Group regions into columns by horizontal overlap.

        Regions are visited in order of their left edge. Each region is
        compared with the region most recently added to the current column:
        it joins that column when the two overlap horizontally by at least
        ``threshold`` times the narrower region's width, and starts a new
        column otherwise.

        Args:
            regions: Regions to group
            threshold: Minimum overlap ratio (relative to the narrower of
                the two regions). Values <= 0 accept any positive overlap.

        Returns:
            Columns in left-to-right order; within a column, regions are
            ordered by their left edge.
        """
        if not regions:
            return []

        by_left_edge = sorted(regions, key=lambda r: r.bbox.x_min)

        columns: List[List[LayoutRegion]] = [[by_left_edge[0]]]
        for region in by_left_edge[1:]:
            current_column = columns[-1]
            if self._shares_column(current_column[-1].bbox, region.bbox, threshold):
                current_column.append(region)
            else:
                columns.append([region])
        return columns

    @staticmethod
    def _shares_column(first: BoundingBox, second: BoundingBox, threshold: float) -> bool:
        """Return True when two boxes overlap horizontally by enough to share a column."""
        overlap = min(first.x_max, second.x_max) - max(first.x_min, second.x_min)
        if overlap <= 0:
            # Disjoint or merely touching boxes never share a column.
            return False
        # A positive overlap implies both widths are positive and at least
        # as large as the overlap, so the ratio lies in (0, 1].
        narrower_width = min(first.x_max - first.x_min, second.x_max - second.x_min)
        return overlap >= threshold * narrower_width
| |
|