"""
Layout Detection Base Interface

Defines the abstract interface for document layout detection.
"""
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import List, Optional, Dict, Any |
| | from dataclasses import dataclass, field |
| | from pydantic import BaseModel, Field |
| | import numpy as np |
| |
|
| | from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion |
| |
|
| |
|
class LayoutConfig(BaseModel):
    """Settings accepted by layout detectors.

    Every field declares a default, so ``LayoutConfig()`` can be built
    without arguments. Numeric bounds listed on a field are enforced by
    pydantic when the model is validated.
    """

    # Name of the detection method to use.
    method: str = Field(
        "rule_based",
        description="Detection method: rule_based, paddle_structure, layoutlm",
    )

    # Confidence floor, constrained to the closed interval [0, 1].
    min_confidence: float = Field(
        0.5,
        description="Minimum confidence for detected regions",
        ge=0.0,
        le=1.0,
    )

    # On/off switches for the individual region kinds.
    detect_tables: bool = Field(
        True,
        description="Detect table regions",
    )
    detect_figures: bool = Field(
        True,
        description="Detect figure regions",
    )
    detect_headers: bool = Field(
        True,
        description="Detect header/footer",
    )
    detect_titles: bool = Field(
        True,
        description="Detect title/heading",
    )
    detect_lists: bool = Field(
        True,
        description="Detect list structures",
    )

    # IoU threshold for merging, constrained to the closed interval [0, 1].
    merge_threshold: float = Field(
        0.7,
        description="IoU threshold for merging overlapping regions",
        ge=0.0,
        le=1.0,
    )

    # GPU selection; the device ID must be non-negative.
    use_gpu: bool = Field(
        True,
        description="Use GPU acceleration",
    )
    gpu_id: int = Field(
        0,
        description="GPU device ID",
        ge=0,
    )

    # Table size minimums; each must be at least 1.
    table_min_rows: int = Field(
        2,
        description="Minimum rows for table",
        ge=1,
    )
    table_min_cols: int = Field(
        2,
        description="Minimum columns for table",
        ge=1,
    )

    # Title/heading parameters (no bounds are declared on these two fields).
    title_max_lines: int = Field(
        3,
        description="Max lines for title",
    )
    heading_font_ratio: float = Field(
        1.2,
        description="Font size ratio vs body text for headings",
    )
| |
|
| |
|
@dataclass
class LayoutResult:
    """Layout detection outcome for a single page.

    Bundles the detected regions with the page number, the image
    dimensions, a processing time and a success/error status.
    """

    page: int
    regions: List[LayoutRegion] = field(default_factory=list)
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0

    # Status fields: ``success`` defaults to True and ``error`` to None.
    success: bool = True
    error: Optional[str] = None

    def get_regions_by_type(self, layout_type: LayoutType) -> List[LayoutRegion]:
        """Return the regions whose ``type`` equals ``layout_type``.

        The original ordering of ``self.regions`` is preserved.
        """
        matching = []
        for region in self.regions:
            if region.type == layout_type:
                matching.append(region)
        return matching

    def get_tables(self) -> List[LayoutRegion]:
        """Return the regions typed ``LayoutType.TABLE``."""
        wanted = LayoutType.TABLE
        return self.get_regions_by_type(wanted)

    def get_figures(self) -> List[LayoutRegion]:
        """Return the regions typed ``LayoutType.FIGURE``."""
        wanted = LayoutType.FIGURE
        return self.get_regions_by_type(wanted)

    def get_text_regions(self) -> List[LayoutRegion]:
        """Return the regions typed TEXT, TITLE, HEADING, PARAGRAPH or LIST.

        The original ordering of ``self.regions`` is preserved.
        """
        textual = frozenset(
            (
                LayoutType.TEXT,
                LayoutType.TITLE,
                LayoutType.HEADING,
                LayoutType.PARAGRAPH,
                LayoutType.LIST,
            )
        )
        selected = []
        for region in self.regions:
            if region.type in textual:
                selected.append(region)
        return selected
| |
|
| |
|
class LayoutDetector(ABC):
    """
    Abstract base class for layout detectors.

    Subclasses must implement :meth:`initialize` and :meth:`detect`.
    :meth:`detect_batch` is provided here and calls :meth:`detect` once
    per image. Instances can be used as context managers; entering the
    context calls :meth:`initialize` if the detector is not yet marked
    as initialized.
    """

    def __init__(self, config: Optional[LayoutConfig] = None):
        """
        Initialize layout detector.

        Args:
            config: Layout detection configuration. A default
                ``LayoutConfig()`` is created when omitted.
        """
        self.config = config if config is not None else LayoutConfig()
        # NOTE(review): nothing in this base class ever sets this flag to
        # True; implementations of initialize() are presumably expected to
        # do so -- confirm against the concrete detectors.
        self._initialized = False

    @abstractmethod
    def initialize(self) -> None:
        """Initialize the detector (load models, etc.)."""
        pass

    @abstractmethod
    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions in an image.

        Args:
            image: Image as numpy array (RGB, HWC format)
            page_number: Page number
            ocr_regions: Optional OCR regions for text-aware detection

        Returns:
            LayoutResult with detected regions
        """
        pass

    def detect_batch(
        self,
        images: List[np.ndarray],
        page_numbers: Optional[List[int]] = None,
        ocr_results: Optional[List[Optional[List[OCRRegion]]]] = None,
    ) -> List[LayoutResult]:
        """
        Detect layout in multiple images.

        Each image is passed to :meth:`detect` in order, together with its
        page number and OCR regions.

        Args:
            images: List of images
            page_numbers: Optional page numbers, one per image. Defaults
                to ``0 .. len(images) - 1``.
            ocr_results: Optional OCR regions for each page, one entry per
                image. Defaults to ``None`` for every page.

        Returns:
            List of LayoutResult, one per image, in input order

        Raises:
            ValueError: If ``page_numbers`` or ``ocr_results`` is given
                but its length differs from the number of images.
        """
        # Materialize the inputs so their lengths can be compared; zip()
        # would otherwise stop at the shortest one and silently drop pages.
        images = list(images)
        image_count = len(images)

        if page_numbers is None:
            page_numbers = list(range(image_count))
        else:
            page_numbers = list(page_numbers)
            if len(page_numbers) != image_count:
                raise ValueError(
                    f"page_numbers has {len(page_numbers)} entries "
                    f"but {image_count} images were given"
                )

        if ocr_results is None:
            ocr_results = [None] * image_count
        else:
            ocr_results = list(ocr_results)
            if len(ocr_results) != image_count:
                raise ValueError(
                    f"ocr_results has {len(ocr_results)} entries "
                    f"but {image_count} images were given"
                )

        return [
            self.detect(img, page_num, ocr)
            for img, page_num, ocr in zip(images, page_numbers, ocr_results)
        ]

    @property
    def name(self) -> str:
        """Return detector name (the concrete class name)."""
        return self.__class__.__name__

    @property
    def is_initialized(self) -> bool:
        """Check if detector is initialized."""
        return self._initialized

    def __enter__(self) -> "LayoutDetector":
        """Context manager entry: initialize if needed and return self."""
        if not self._initialized:
            self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit.

        Does nothing and returns None, so an exception raised inside the
        ``with`` body is not suppressed.
        """
        pass
| |
|