| | """ |
| | Layout Detector Implementations |
| | |
| | Rule-based and model-based layout detection. |
| | """ |
| |
|
| | import time |
| | import uuid |
| | from typing import List, Optional, Dict, Tuple |
| | from collections import defaultdict |
| | import numpy as np |
| | from loguru import logger |
| |
|
| | from .base import LayoutDetector, LayoutConfig, LayoutResult |
| | from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion |
| |
|
| |
|
class RuleBasedLayoutDetector(LayoutDetector):
    """
    Rule-based layout detector using OCR region analysis.

    Uses heuristics based on:
    - Text positioning and alignment
    - Font size estimation (based on region height)
    - Spacing patterns
    - Structural patterns (tables, lists)

    Every bounding box produced here is built with ``normalized=False`` from
    the coordinates of the OCR regions it covers.
    """

    # Tunable heuristics. They are class attributes so a subclass can override
    # them without re-implementing the detection methods.
    _LINE_Y_TOLERANCE = 10          # max |y_min - line anchor| to stay on a line
    _COLUMN_X_TOLERANCE = 20        # max |x_center - column anchor| to share a column
    _PARAGRAPH_GAP_RATIO = 1.5      # gap > ratio * mean height starts a new paragraph
    _HEADING_MAX_CHARS = 100        # a heading line must be shorter than this
    _TOP_OF_PAGE_FRACTION = 0.15    # titles must start within this page fraction
    _TITLE_SCALE = 1.2              # title height must exceed threshold * scale
    _HEADER_FRACTION = 0.08         # header band: regions ending above this fraction
    _FOOTER_FRACTION = 0.92         # footer band: regions starting below this fraction
    _TABLE_MIN_WIDTH_RATIO = 0.3    # table must span more than this page fraction
    _FALLBACK_PAGE_WIDTH = 1000     # used when a bbox carries no page_width

    _BULLET_PREFIXES = ('•', '-', '–', '—', '*', '○', '●', '■', '□', '▪', '▸', '▹')
    _NUMBER_PREFIXES = ('1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.',
                        '1)', '2)', '3)', '4)', '5)', 'a.', 'b.', 'c.', 'a)', 'b)', 'c)')

    def __init__(self, config: Optional[LayoutConfig] = None):
        """Initialize rule-based detector."""
        super().__init__(config)

    def initialize(self):
        """Initialize detector (no model loading needed for rule-based)."""
        self._initialized = True
        logger.info("Initialized rule-based layout detector")

    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions using rule-based analysis.

        Args:
            image: Page image; only its first two shape dimensions
                (height, width) are read here.
            page_number: Page number, embedded in region ids and results.
            ocr_regions: OCR regions for text-based analysis. When empty or
                ``None`` all text-based detectors are skipped.

        Returns:
            LayoutResult with detected regions in reading order.
        """
        if not self._initialized:
            self.initialize()

        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the system clock is adjusted mid-run.
        start_time = time.perf_counter()
        height, width = image.shape[:2]

        regions: List[LayoutRegion] = []
        region_counter = 0

        def make_region_id():
            # Ids are sequential per call: region_<page>_<n>, n starting at 1.
            nonlocal region_counter
            region_counter += 1
            return f"region_{page_number}_{region_counter}"

        if ocr_regions:
            # The call order determines the id numbering, so keep it stable.
            regions.extend(self._detect_titles_headings(ocr_regions, page_number, make_region_id, height))
            regions.extend(self._detect_paragraphs(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_lists(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_tables_from_ocr(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_headers_footers(ocr_regions, page_number, make_region_id, height))

        if self.config.detect_figures:
            regions.extend(self._detect_figures_from_image(image, page_number, make_region_id, ocr_regions))

        regions = self._merge_overlapping_regions(regions)
        regions = self._assign_reading_order(regions)

        processing_time = (time.perf_counter() - start_time) * 1000

        return LayoutResult(
            page=page_number,
            regions=regions,
            image_width=width,
            image_height=height,
            processing_time_ms=processing_time,
            success=True,
        )

    @staticmethod
    def _union_extent(ocr_regions: List[OCRRegion]) -> Tuple[float, float, float, float]:
        """Return (x_min, y_min, x_max, y_max) enclosing all given regions.

        The list must be non-empty; callers guard against empty input.
        """
        return (
            min(r.bbox.x_min for r in ocr_regions),
            min(r.bbox.y_min for r in ocr_regions),
            max(r.bbox.x_max for r in ocr_regions),
            max(r.bbox.y_max for r in ocr_regions),
        )

    @classmethod
    def _union_bbox(cls, ocr_regions: List[OCRRegion]) -> BoundingBox:
        """Build a non-normalized BoundingBox enclosing all given regions."""
        x_min, y_min, x_max, y_max = cls._union_extent(ocr_regions)
        return BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=False,
        )

    @staticmethod
    def _member_indices(all_regions: List[OCRRegion], members: List[OCRRegion]) -> List[int]:
        """Return the positions in ``all_regions`` of the objects in ``members``.

        Membership is tested by object identity through a set, which keeps
        this linear instead of running an equality scan of ``members`` for
        every element of ``all_regions``.
        """
        member_ids = {id(r) for r in members}
        return [i for i, r in enumerate(all_regions) if id(r) in member_ids]

    def _detect_titles_headings(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """Detect title and heading regions based on font size and position.

        A line qualifies when its tallest region exceeds the median region
        height times ``config.heading_font_ratio`` and its joined text is
        shorter than ``_HEADING_MAX_CHARS``. It becomes a TITLE when it also
        sits near the top of the page and is noticeably taller still;
        otherwise it is a HEADING.
        """
        if not ocr_regions or not self.config.detect_titles:
            return []

        # Zero-height boxes would drag the median down, so ignore them.
        heights = [r.bbox.height for r in ocr_regions if r.bbox.height > 0]
        if not heights:
            return []

        median_height = np.median(heights)
        title_threshold = median_height * self.config.heading_font_ratio

        regions = []
        for line_regions in self._group_into_lines(ocr_regions).values():
            if not line_regions:
                continue

            line_height = max(r.bbox.height for r in line_regions)
            line_text = " ".join(r.text for r in line_regions)

            is_large_text = line_height > title_threshold
            is_short = len(line_text) < self._HEADING_MAX_CHARS
            if not (is_large_text and is_short):
                continue

            x_min, y_min, x_max, y_max = self._union_extent(line_regions)
            is_top_of_page = y_min < page_height * self._TOP_OF_PAGE_FRACTION

            if is_top_of_page and line_height > title_threshold * self._TITLE_SCALE:
                layout_type = LayoutType.TITLE
            else:
                layout_type = LayoutType.HEADING

            regions.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.8,
                bbox=BoundingBox(
                    x_min=x_min, y_min=y_min,
                    x_max=x_max, y_max=y_max,
                    normalized=False,
                ),
                page=page_number,
                ocr_region_ids=self._member_indices(ocr_regions, line_regions),
            ))

        return regions

    def _detect_paragraphs(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """Detect paragraph regions by grouping nearby text lines.

        Every group of lines returned by ``_group_lines_into_paragraphs``
        becomes one PARAGRAPH region covering all OCR regions on those lines.
        """
        if not ocr_regions:
            return []

        lines = self._group_into_lines(ocr_regions)
        paragraphs = self._group_lines_into_paragraphs(lines, ocr_regions)

        regions = []
        for para_lines in paragraphs:
            para_regions = [r for line_id in para_lines for r in lines.get(line_id, [])]
            if not para_regions:
                continue

            regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.PARAGRAPH,
                confidence=0.7,
                bbox=self._union_bbox(para_regions),
                page=page_number,
                ocr_region_ids=self._member_indices(ocr_regions, para_regions),
            ))

        return regions

    def _detect_lists(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """Detect list structures based on bullet/number patterns.

        A line counts as a list item when the stripped text of its left-most
        region starts with a known bullet or number marker, or is at most
        three characters long and ends with a period. Runs of two or more
        consecutive list-item lines become one LIST region.
        """
        if not ocr_regions or not self.config.detect_lists:
            return []

        lines = self._group_into_lines(ocr_regions)

        list_runs: List[List[int]] = []
        current_run: List[int] = []

        for line_id in sorted(lines):
            line_regions = lines[line_id]
            if not line_regions:
                continue

            first_text = line_regions[0].text.strip()
            is_list_item = (
                first_text.startswith(self._BULLET_PREFIXES)
                or first_text.startswith(self._NUMBER_PREFIXES)
                or (len(first_text) <= 3 and first_text.endswith('.'))
            )

            if is_list_item:
                current_run.append(line_id)
            else:
                # A non-item line ends the run; a lone item is not a list.
                if len(current_run) >= 2:
                    list_runs.append(current_run)
                current_run = []

        if len(current_run) >= 2:
            list_runs.append(current_run)

        regions = []
        for run_line_ids in list_runs:
            list_regions = [r for line_id in run_line_ids for r in lines.get(line_id, [])]
            if not list_regions:
                continue

            regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.LIST,
                confidence=0.75,
                bbox=self._union_bbox(list_regions),
                page=page_number,
                ocr_region_ids=self._member_indices(ocr_regions, list_regions),
                extra={"item_count": len(run_line_ids)},
            ))

        return regions

    def _detect_tables_from_ocr(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """Detect table regions based on aligned text patterns.

        OCR regions are clustered by horizontal centre. If there are at
        least ``config.table_min_cols`` clusters, at least
        ``table_min_rows * table_min_cols`` regions overall, and the text
        spans more than ``_TABLE_MIN_WIDTH_RATIO`` of the page width, a
        single TABLE region enclosing *all* OCR regions is emitted.
        """
        if not ocr_regions or not self.config.detect_tables:
            return []

        # Greedy clustering: a region joins the first existing column whose
        # anchor (the x-centre of the region that created it) is close enough.
        column_groups: Dict[float, List[OCRRegion]] = {}
        for region in ocr_regions:
            x_center = region.bbox.center[0]
            for anchor, members in column_groups.items():
                if abs(x_center - anchor) < self._COLUMN_X_TOLERANCE:
                    members.append(region)
                    break
            else:
                column_groups[x_center] = [region]

        if len(column_groups) < self.config.table_min_cols:
            return []

        # Every region lands in exactly one column group, so the grouped
        # regions are the same set as ocr_regions.
        if len(ocr_regions) < self.config.table_min_rows * self.config.table_min_cols:
            return []

        x_min, y_min, x_max, y_max = self._union_extent(ocr_regions)

        page_width = max(r.bbox.page_width or self._FALLBACK_PAGE_WIDTH for r in ocr_regions)
        width_ratio = (x_max - x_min) / page_width
        if not width_ratio > self._TABLE_MIN_WIDTH_RATIO:
            return []

        return [LayoutRegion(
            id=make_id(),
            type=LayoutType.TABLE,
            confidence=0.6,
            bbox=BoundingBox(
                x_min=x_min, y_min=y_min,
                x_max=x_max, y_max=y_max,
                normalized=False,
            ),
            page=page_number,
            extra={"estimated_cols": len(column_groups)},
        )]

    def _detect_headers_footers(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """Detect header and footer regions.

        Regions lying entirely above ``_HEADER_FRACTION`` of the page height
        form one HEADER; regions starting below ``_FOOTER_FRACTION`` form one
        FOOTER. Both are controlled by ``config.detect_headers``.
        """
        if not ocr_regions or not self.config.detect_headers:
            return []

        header_threshold = page_height * self._HEADER_FRACTION
        footer_threshold = page_height * self._FOOTER_FRACTION

        bands = (
            (LayoutType.HEADER, [r for r in ocr_regions if r.bbox.y_max < header_threshold]),
            (LayoutType.FOOTER, [r for r in ocr_regions if r.bbox.y_min > footer_threshold]),
        )

        regions = []
        # Header first, then footer, so the id numbering stays stable.
        for layout_type, band_regions in bands:
            if not band_regions:
                continue
            regions.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.7,
                bbox=self._union_bbox(band_regions),
                page=page_number,
            ))

        return regions

    def _detect_figures_from_image(
        self,
        image: np.ndarray,
        page_number: int,
        make_id,
        ocr_regions: Optional[List[OCRRegion]],
    ) -> List[LayoutRegion]:
        """Detect figure regions using image analysis.

        Placeholder: no figure detection is implemented, so this always
        returns an empty list. The arguments are accepted so the call site
        in ``detect`` does not need to change once detection is added.
        """
        # TODO: implement figure detection (e.g. large non-text areas).
        return []

    def _group_into_lines(
        self,
        ocr_regions: List[OCRRegion],
    ) -> Dict[int, List[OCRRegion]]:
        """Group OCR regions into lines based on y-position.

        Regions are scanned in order of ``y_min``. A new line starts when a
        region's ``y_min`` differs from the current line's anchor (the
        ``y_min`` of the region that opened the line) by more than
        ``_LINE_Y_TOLERANCE``.

        Returns:
            Mapping of line id (0-based, top to bottom) to that line's
            regions sorted left to right by ``x_min``.
        """
        if not ocr_regions:
            return {}

        by_top = sorted(ocr_regions, key=lambda r: r.bbox.y_min)

        lines: Dict[int, List[OCRRegion]] = defaultdict(list)
        line_id = 0
        anchor_y = by_top[0].bbox.y_min

        for region in by_top:
            if abs(region.bbox.y_min - anchor_y) > self._LINE_Y_TOLERANCE:
                line_id += 1
                anchor_y = region.bbox.y_min
            lines[line_id].append(region)

        return {
            key: sorted(members, key=lambda r: r.bbox.x_min)
            for key, members in lines.items()
        }

    def _group_lines_into_paragraphs(
        self,
        lines: Dict[int, List[OCRRegion]],
        all_regions: List[OCRRegion],
    ) -> List[List[int]]:
        """Group lines into paragraphs based on spacing.

        Consecutive lines stay in the same paragraph unless the vertical gap
        between them exceeds ``_PARAGRAPH_GAP_RATIO`` times the mean region
        height of the two lines.

        Args:
            lines: Line id -> regions, as produced by ``_group_into_lines``.
            all_regions: Unused; kept for interface compatibility.

        Returns:
            List of paragraphs, each a list of line ids in ascending order.
        """
        # Drop empty lines up front so one never hides its successor from
        # the gap comparison.
        line_ids = [line_id for line_id in sorted(lines) if lines[line_id]]
        if not line_ids:
            return []

        paragraphs: List[List[int]] = []
        current_para = [line_ids[0]]

        for prev_id, line_id in zip(line_ids, line_ids[1:]):
            prev_line = lines[prev_id]
            curr_line = lines[line_id]

            prev_y_max = max(r.bbox.y_max for r in prev_line)
            curr_y_min = min(r.bbox.y_min for r in curr_line)
            gap = curr_y_min - prev_y_max

            mean_height = np.mean([r.bbox.height for r in prev_line + curr_line])

            if gap > mean_height * self._PARAGRAPH_GAP_RATIO:
                paragraphs.append(current_para)
                current_para = [line_id]
            else:
                current_para.append(line_id)

        paragraphs.append(current_para)
        return paragraphs

    def _merge_overlapping_regions(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """Regroup regions by type; geometric merging is not implemented.

        Despite the name, no regions are merged or removed. The result holds
        the same objects, ordered by first appearance of each type and then
        by original order within a type.
        """
        if not regions:
            return []

        by_type = defaultdict(list)
        for region in regions:
            by_type[region.type].append(region)

        # TODO: merge overlapping boxes within each type group.
        return [region for group in by_type.values() for region in group]

    def _assign_reading_order(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """Assign reading order to regions (top-to-bottom, left-to-right).

        Regions are sorted by ``(bbox.y_min, bbox.x_min)`` and each region's
        ``reading_order`` attribute is set in place to its 0-based rank.

        Returns:
            The same region objects as a new, sorted list.
        """
        if not regions:
            return []

        ordered = sorted(regions, key=lambda r: (r.bbox.y_min, r.bbox.x_min))
        for rank, region in enumerate(ordered):
            region.reading_order = rank

        return ordered
| |
|
| |
|
| | |
# Module-level singleton slot; populated on the first call to
# get_layout_detector() and reused by every later call.
_layout_detector: Optional[LayoutDetector] = None
| |
|
| |
|
def create_layout_detector(
    config: Optional[LayoutConfig] = None,
    initialize: bool = True,
) -> LayoutDetector:
    """Build a new layout detector for the given configuration.

    Args:
        config: Detector configuration; a default ``LayoutConfig()`` is
            used when omitted.
        initialize: When true, ``initialize()`` is called on the detector
            before it is returned.

    Returns:
        A ``RuleBasedLayoutDetector``. Any ``config.method`` other than
        ``"rule_based"`` logs a warning and falls back to the rule-based
        detector.
    """
    cfg = LayoutConfig() if config is None else config

    # Only one implementation exists, so an unrecognised method merely warns.
    if cfg.method != "rule_based":
        logger.warning(f"Unknown method {cfg.method}, using rule_based")

    detector = RuleBasedLayoutDetector(cfg)
    if initialize:
        detector.initialize()
    return detector
| |
|
| |
|
def get_layout_detector(
    config: Optional[LayoutConfig] = None,
) -> LayoutDetector:
    """Return the shared layout detector, creating it on first use.

    ``config`` is only consulted when the singleton does not exist yet;
    once a detector has been created, later calls return it unchanged and
    ignore the argument.
    """
    global _layout_detector
    if _layout_detector is not None:
        return _layout_detector

    _layout_detector = create_layout_detector(config)
    return _layout_detector
| |
|