| | """ |
| | Table Extraction Model Interface |
| | |
| | Abstract interface for table structure recognition and cell extraction. |
| | Handles complex tables with merged cells, headers, and nested structures. |
| | """ |
| |
|
| | from abc import abstractmethod |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| | from ..chunks.models import BoundingBox, TableCell, TableChunk |
| | from .base import ( |
| | BaseModel, |
| | BatchableModel, |
| | ImageInput, |
| | ModelCapability, |
| | ModelConfig, |
| | ) |
| | from .layout import LayoutRegion |
| |
|
| |
|
class TableCellType(str, Enum):
    """
    String-valued labels used to classify table cells.

    The class mixes in ``str``, so each member compares equal to its raw
    value (for example ``TableCellType.HEADER == "header"``) and can be
    looked up from that value with ``TableCellType("header")``.
    """

    HEADER = "header"
    DATA = "data"
    INDEX = "index"
    MERGED = "merged"
    EMPTY = "empty"
| |
|
| |
|
@dataclass
class TableConfig(ModelConfig):
    """
    Settings for table extraction models.

    Extends ``ModelConfig`` with table-specific options. How each option is
    applied is up to the concrete model; this class only stores the values.
    """

    # Table-specific options, all with defaults.
    min_confidence: float = 0.5
    detect_headers: bool = True
    detect_merged_cells: bool = True
    max_rows: int = 500
    max_cols: int = 50
    extract_cell_text: bool = True

    def __post_init__(self):
        """Run the parent's post-init, then fall back to a default name."""
        super().__post_init__()
        if self.name:
            return
        # A falsy name (e.g. empty string) is replaced with the default.
        self.name = "table_extractor"
| |
|
| |
|
@dataclass
class TableStructure:
    """
    Detected table structure with cell grid.

    Holds the table's bounding box, its cells and grid dimensions, together
    with header indices, confidence scores and a few descriptive flags.
    Provides lookups by grid position and exports to CSV, Markdown, a plain
    dictionary and a ``TableChunk``.
    """

    bbox: BoundingBox
    cells: List[TableCell] = field(default_factory=list)
    num_rows: int = 0
    num_cols: int = 0

    # Header bookkeeping
    header_rows: List[int] = field(default_factory=list)
    header_cols: List[int] = field(default_factory=list)

    # Confidence scores
    structure_confidence: float = 0.0
    cell_confidence_avg: float = 0.0

    # Table properties
    has_merged_cells: bool = False
    is_bordered: bool = True
    table_id: str = ""

    def __post_init__(self):
        """Derive ``table_id`` from the bbox and grid size when none is given."""
        if not self.table_id:
            import hashlib
            content = f"table_{self.bbox.xyxy}_{self.num_rows}x{self.num_cols}"
            # First 12 hex characters of the MD5 digest serve as the identifier.
            self.table_id = hashlib.md5(content.encode()).hexdigest()[:12]

    def get_cell(self, row: int, col: int) -> Optional[TableCell]:
        """
        Get cell at specific position.

        Returns the first cell in ``self.cells`` that is either anchored at
        ``(row, col)`` or whose row/column span covers that position, or
        ``None`` when no cell matches.
        """
        for cell in self.cells:
            anchored_here = cell.row == row and cell.col == col
            if anchored_here or (
                cell.row <= row < cell.row + cell.rowspan
                and cell.col <= col < cell.col + cell.colspan
            ):
                return cell
        return None

    def _build_grid(self) -> Dict[Tuple[int, int], TableCell]:
        """
        Index cells by grid position in a single pass.

        For every position inside ``num_rows`` x ``num_cols`` the mapping
        holds exactly the cell ``get_cell`` would return, so exporters can
        look positions up directly instead of rescanning ``self.cells`` for
        each one.
        """
        grid: Dict[Tuple[int, int], TableCell] = {}
        for cell in self.cells:
            # setdefault keeps the earliest claimant, matching get_cell's
            # "first match in list order" rule. The anchor is registered
            # separately so it still counts when a span is zero.
            grid.setdefault((cell.row, cell.col), cell)
            # Clamp spans to the grid: positions outside it are never read.
            row_stop = min(cell.row + cell.rowspan, self.num_rows)
            col_stop = min(cell.col + cell.colspan, self.num_cols)
            for r in range(max(cell.row, 0), row_stop):
                for c in range(max(cell.col, 0), col_stop):
                    grid.setdefault((r, c), cell)
        return grid

    @staticmethod
    def _cell_text(cell: Optional[TableCell]) -> str:
        """Return the cell's text, or "" for a missing cell or ``None`` text."""
        if cell is None or cell.text is None:
            return ""
        return cell.text

    def get_row(self, row_index: int) -> List[TableCell]:
        """Get all cells anchored in a row, ordered by column."""
        return sorted(
            [c for c in self.cells if c.row == row_index],
            key=lambda c: c.col
        )

    def get_col(self, col_index: int) -> List[TableCell]:
        """Get all cells anchored in a column, ordered by row."""
        return sorted(
            [c for c in self.cells if c.col == col_index],
            key=lambda c: c.row
        )

    def get_headers(self) -> List[TableCell]:
        """Get all cells whose ``is_header`` flag is set."""
        return [c for c in self.cells if c.is_header]

    def to_csv(self, delimiter: str = ",") -> str:
        """
        Convert table to CSV string.

        Emits ``num_rows`` lines of ``num_cols`` fields joined by
        ``delimiter``; lines are separated by ``"\\n"``. A position covered
        by a spanning cell repeats that cell's text, and a position with no
        cell yields an empty field. Fields containing the delimiter, a
        double quote, or a line break (``\\n`` or ``\\r``) are wrapped in
        double quotes with embedded quotes doubled.
        """
        grid = self._build_grid()
        rows = []
        for r in range(self.num_rows):
            row_cells = []
            for c in range(self.num_cols):
                text = self._cell_text(grid.get((r, c)))
                # A bare carriage return also breaks a CSV record, so it
                # triggers quoting just like "\n".
                if delimiter in text or any(ch in text for ch in '"\n\r'):
                    text = '"' + text.replace('"', '""') + '"'
                row_cells.append(text)
            rows.append(delimiter.join(row_cells))
        return "\n".join(rows)

    def to_markdown(self) -> str:
        """
        Convert table to Markdown format.

        Returns "" when the grid has no rows or no columns. Otherwise each
        grid row becomes one pipe-delimited line, with a ``---`` separator
        line inserted after the first row. Pipes in cell text are escaped
        and line breaks are rendered as ``<br>`` so a cell can never split
        its row across several lines.
        """
        if self.num_rows == 0 or self.num_cols == 0:
            return ""

        grid = self._build_grid()
        lines = []

        for r in range(self.num_rows):
            row_texts = []
            for c in range(self.num_cols):
                text = self._cell_text(grid.get((r, c))).replace("|", "\\|")
                # Handle "\r\n" first so it becomes a single <br>.
                for line_break in ("\r\n", "\n", "\r"):
                    text = text.replace(line_break, "<br>")
                row_texts.append(text)
            lines.append("| " + " | ".join(row_texts) + " |")

            # The separator always follows the first grid row.
            if r == 0:
                separators = ["---"] * self.num_cols
                lines.append("| " + " | ".join(separators) + " |")

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to structured dictionary.

        Includes the grid dimensions, header indices, and one entry per cell
        with its position, text, spans, header flag and confidence.
        """
        return {
            "num_rows": self.num_rows,
            "num_cols": self.num_cols,
            "header_rows": self.header_rows,
            "header_cols": self.header_cols,
            "cells": [
                {
                    "row": c.row,
                    "col": c.col,
                    "text": c.text,
                    "rowspan": c.rowspan,
                    "colspan": c.colspan,
                    "is_header": c.is_header,
                    "confidence": c.confidence
                }
                for c in self.cells
            ]
        }

    def to_table_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> TableChunk:
        """
        Convert to TableChunk for the chunks module.

        The chunk's text is the Markdown rendering of this table and its
        confidence is ``structure_confidence``; the chunk id is produced by
        ``TableChunk.generate_chunk_id`` from the document id, page and bbox.

        Args:
            doc_id: Identifier of the source document
            page: Page the table was found on
            sequence_index: Position of the chunk in the document sequence

        Returns:
            TableChunk carrying this table's cells and structure metadata
        """
        return TableChunk(
            chunk_id=TableChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="table"
            ),
            doc_id=doc_id,
            text=self.to_markdown(),
            page=page,
            bbox=self.bbox,
            confidence=self.structure_confidence,
            sequence_index=sequence_index,
            cells=self.cells,
            num_rows=self.num_rows,
            num_cols=self.num_cols,
            header_rows=self.header_rows,
            header_cols=self.header_cols,
            has_merged_cells=self.has_merged_cells
        )
| |
|
| |
|
@dataclass
class TableExtractionResult:
    """Tables extracted from a page, plus timing and model metadata."""

    tables: List[TableStructure] = field(default_factory=list)
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def table_count(self) -> int:
        """Number of tables held in this result."""
        return len(self.tables)

    def get_table_at_region(
        self,
        region: LayoutRegion,
        iou_threshold: float = 0.5
    ) -> Optional[TableStructure]:
        """
        Pick the table whose bbox overlaps a layout region the most.

        A table qualifies only when its IoU with ``region.bbox`` is strictly
        greater than both ``iou_threshold`` and zero. Among qualifying tables
        the highest IoU wins; on a tie the earliest table is returned.
        Returns ``None`` when nothing qualifies.
        """
        scored = [
            (candidate.bbox.iou(region.bbox), candidate)
            for candidate in self.tables
        ]
        qualifying = [
            pair for pair in scored
            if pair[0] > iou_threshold and pair[0] > 0.0
        ]
        if not qualifying:
            return None
        # max() keeps the first of equally scored entries.
        return max(qualifying, key=lambda pair: pair[0])[1]
| |
|
| |
|
class TableModel(BatchableModel):
    """
    Abstract base class for table extraction models.

    Implementations should handle:
    - Table structure detection (rows, columns)
    - Cell boundary detection
    - Merged cell handling
    - Header detection
    - Cell content extraction

    Subclasses must implement ``extract_structure`` and
    ``extract_cell_text``; ``extract_all_tables`` and ``process_batch`` are
    built on top of ``extract_structure``.
    """

    def __init__(self, config: Optional[TableConfig] = None):
        """
        Initialise the model with a table configuration.

        Args:
            config: Configuration to use; when omitted (or falsy) a
                ``TableConfig`` named "table" is created.
        """
        super().__init__(config or TableConfig(name="table"))
        # Re-declared with the narrower type for static checkers; assumes the
        # base initialiser has already stored the config on self.config.
        self.config: TableConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report table extraction as this model's only capability."""
        return [ModelCapability.TABLE_EXTRACTION]

    @abstractmethod
    def extract_structure(
        self,
        image: ImageInput,
        table_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> TableStructure:
        """
        Extract table structure from an image.

        Args:
            image: Input image containing a table
            table_region: Optional bounding box of the table region
            **kwargs: Additional parameters

        Returns:
            TableStructure with cells and metadata
        """
        pass

    def extract_all_tables(
        self,
        image: ImageInput,
        table_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> TableExtractionResult:
        """
        Extract all tables from an image.

        With a non-empty ``table_regions`` each region is extracted
        independently: a region whose extraction raises is logged as a
        warning and skipped, so one bad region does not lose the others.
        With no regions (``None`` or an empty list) the whole image is
        passed to ``extract_structure`` once; its result is kept only if it
        has at least one row, and any exception propagates to the caller.

        Args:
            image: Input document image
            table_regions: Optional list of table bounding boxes
            **kwargs: Additional parameters forwarded to extract_structure

        Returns:
            TableExtractionResult with all detected tables and the elapsed
            time in milliseconds
        """
        import logging
        import time

        logger = logging.getLogger(__name__)
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        tables = []

        if table_regions:
            total = len(table_regions)
            for position, region in enumerate(table_regions, start=1):
                try:
                    table = self.extract_structure(image, region, **kwargs)
                    tables.append(table)
                except Exception:
                    # Best effort per region: record the failure with its
                    # traceback instead of dropping it silently.
                    logger.warning(
                        "Table extraction failed for region %d of %d; skipping",
                        position,
                        total,
                        exc_info=True,
                    )
        else:
            table = self.extract_structure(image, **kwargs)
            if table.num_rows > 0:
                tables.append(table)

        processing_time = (time.perf_counter() - start_time) * 1000

        return TableExtractionResult(
            tables=tables,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[TableExtractionResult]:
        """
        Process multiple images.

        Calls ``extract_all_tables`` on each input in order, forwarding
        ``kwargs``, and returns one result per image.
        """
        return [self.extract_all_tables(img, **kwargs) for img in inputs]

    @abstractmethod
    def extract_cell_text(
        self,
        image: ImageInput,
        cell_bbox: BoundingBox,
        **kwargs
    ) -> str:
        """
        Extract text from a specific cell region.

        Args:
            image: Image containing the cell
            cell_bbox: Bounding box of the cell
            **kwargs: Additional parameters

        Returns:
            Extracted text content
        """
        pass
| |
|