"""
Chart Extraction Model Interface

Abstract interface for chart/graph understanding models.
Extracts data points, axes, legends, and interprets visualizations.
"""

import logging
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from ..chunks.models import BoundingBox, ChartChunk, ChartDataPoint
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class ChartType(str, Enum):
    """Closed set of chart categories a model may report.

    Members are also plain strings (the class mixes in ``str``), so a member
    compares equal to its lowercase value, e.g. ``ChartType.BAR == "bar"``.
    """

    # Basic chart families
    BAR = "bar"
    LINE = "line"
    PIE = "pie"
    SCATTER = "scatter"
    AREA = "area"

    # More specialised visualisations
    HISTOGRAM = "histogram"
    BOX_PLOT = "box_plot"
    HEATMAP = "heatmap"
    TREEMAP = "treemap"
    RADAR = "radar"
    BUBBLE = "bubble"
    WATERFALL = "waterfall"
    FUNNEL = "funnel"
    GANTT = "gantt"

    # Multi-series / composite variants
    STACKED_BAR = "stacked_bar"
    GROUPED_BAR = "grouped_bar"
    MULTI_LINE = "multi_line"
    COMBO = "combo"

    # Everything else
    DIAGRAM = "diagram"
    UNKNOWN = "unknown"


@dataclass
class ChartConfig(ModelConfig):
    """Settings for chart extraction models, layered on ``ModelConfig``.

    NOTE(review): none of the options below is read by code in this module;
    presumably concrete model implementations consult them - confirm there.
    """

    min_confidence: float = 0.5
    extract_data_points: bool = True
    extract_trends: bool = True
    max_data_points: int = 1000
    detect_chart_type: bool = True

    def __post_init__(self):
        # Run the parent's post-init first, then supply a fallback name when
        # the inherited ``name`` field was left empty/falsy.
        super().__post_init__()
        if not self.name:
            self.name = "chart_extractor"


@dataclass
class AxisInfo:
    """Descriptive record for one axis of a chart.

    Every field has a default, so ``AxisInfo()`` yields an unlabeled,
    linear, horizontal axis with no ticks and no known value range.
    """

    # Textual annotations
    label: str = ""
    unit: str = ""

    # Value range; ``None`` means the bound is not known
    min_value: Optional[float] = None
    max_value: Optional[float] = None

    # Scale name; only the default "linear" is defined in this module
    scale: str = "linear"

    # Tick marks; each instance gets its own fresh lists
    tick_labels: List[str] = field(default_factory=list)
    tick_values: List[float] = field(default_factory=list)

    is_datetime: bool = False
    orientation: str = "horizontal"


@dataclass
class LegendItem:
    """One entry of a chart legend.

    ``label`` is required; ``color`` defaults to ``None`` and
    ``series_index`` defaults to ``0``.
    """

    label: str
    color: Optional[str] = None
    series_index: int = 0


@dataclass
class DataSeries:
    """A named collection of data points belonging to one chart."""

    name: str
    data_points: List[ChartDataPoint] = field(default_factory=list)
    color: Optional[str] = None
    series_type: Optional[ChartType] = None

    @property
    def x_values(self) -> List[Any]:
        """The ``x`` attribute of every data point, in series order."""
        return [point.x for point in self.data_points]

    @property
    def y_values(self) -> List[Any]:
        """The ``y`` attribute of every data point, in series order."""
        return [point.y for point in self.data_points]

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the series.

        ``series_type`` is emitted as the enum's string value, or ``None``
        when no type is set. Each data point is reduced to its ``x``, ``y``,
        ``label`` and ``value`` attributes.
        """
        type_value = self.series_type.value if self.series_type else None
        point_dicts = [
            {"x": pt.x, "y": pt.y, "label": pt.label, "value": pt.value}
            for pt in self.data_points
        ]
        return {
            "name": self.name,
            "color": self.color,
            "series_type": type_value,
            "data_points": point_dicts,
        }


@dataclass
class TrendInfo:
    """A single trend observation about chart data.

    Only ``description`` is required. ``direction`` defaults to "neutral";
    the default trend detector in this module sets it to "increasing",
    "decreasing" or "stable".
    """

    description: str
    direction: str = "neutral"

    # Optional endpoints of the span the trend was observed over
    start_point: Optional[ChartDataPoint] = None
    end_point: Optional[ChartDataPoint] = None

    # Relative change in percent, when it could be computed
    change_percent: Optional[float] = None
    confidence: float = 0.0


@dataclass
class ChartStructure:
    """
    Complete extracted chart structure.

    Contains all detected elements of a chart including
    type, axes, data series, legends, and interpretations.

    When ``chart_id`` is left empty, ``__post_init__`` derives a 12-character
    identifier from the chart type and the bounding box coordinates.
    """

    bbox: BoundingBox
    chart_type: ChartType = ChartType.UNKNOWN
    confidence: float = 0.0

    # Titles
    title: str = ""
    subtitle: str = ""

    # Axes
    x_axis: Optional[AxisInfo] = None
    y_axis: Optional[AxisInfo] = None
    secondary_y_axis: Optional[AxisInfo] = None

    # Data
    series: List[DataSeries] = field(default_factory=list)
    legend_items: List[LegendItem] = field(default_factory=list)

    # Interpretation
    key_values: Dict[str, Any] = field(default_factory=dict)
    trends: List[TrendInfo] = field(default_factory=list)
    summary: str = ""

    # Identification
    chart_id: str = ""
    source_text: str = ""

    def __post_init__(self):
        """Fill in ``chart_id`` when the caller did not provide one."""
        if not self.chart_id:
            import hashlib

            # The id depends only on chart type and bbox, so two charts with
            # the same type and coordinates receive the same id. MD5 is used
            # here purely as a short fingerprint, not for security.
            content = f"chart_{self.chart_type.value}_{self.bbox.xyxy}"
            self.chart_id = hashlib.md5(content.encode()).hexdigest()[:12]

    @property
    def total_data_points(self) -> int:
        """Number of data points summed over all series."""
        return sum(len(s.data_points) for s in self.series)

    @property
    def all_data_points(self) -> List[ChartDataPoint]:
        """Get all data points from all series, flattened in series order."""
        return [point for s in self.series for point in s.data_points]

    def get_series_by_name(self, name: str) -> Optional[DataSeries]:
        """Find a series by name (case-insensitive).

        Returns:
            The first series whose name matches, or ``None`` if none does.
        """
        for series in self.series:
            if series.name.lower() == name.lower():
                return series
        return None

    def to_text_description(self) -> str:
        """Generate a newline-separated text description of the chart.

        The first line is the title when one is set, otherwise the chart
        type. Axis labels, series names, key values and trend descriptions
        follow, each only when there is something non-empty to report.
        """
        parts = []

        if self.title:
            parts.append(f"Chart: {self.title}")
        else:
            parts.append(f"Chart Type: {self.chart_type.value}")

        if self.x_axis and self.x_axis.label:
            parts.append(f"X-Axis: {self.x_axis.label}")
        if self.y_axis and self.y_axis.label:
            parts.append(f"Y-Axis: {self.y_axis.label}")

        # Skip the line entirely when no series carries a name, so an empty
        # "Series: " entry is never emitted (mirrors the trends handling).
        series_names = [s.name for s in self.series if s.name]
        if series_names:
            parts.append(f"Series: {', '.join(series_names)}")

        if self.key_values:
            kv_str = ", ".join(f"{k}: {v}" for k, v in self.key_values.items())
            parts.append(f"Key Values: {kv_str}")

        trend_strs = [t.description for t in self.trends if t.description]
        if trend_strs:
            parts.append(f"Trends: {'; '.join(trend_strs)}")

        return "\n".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to structured dictionary.

        Axis entries are always present; their ``label`` and ``unit`` are
        empty strings when the corresponding axis is missing. ``key_values``
        is the same dict object held by this instance, not a copy.
        """
        return {
            "chart_type": self.chart_type.value,
            "title": self.title,
            "x_axis": {
                "label": self.x_axis.label if self.x_axis else "",
                "unit": self.x_axis.unit if self.x_axis else "",
            },
            "y_axis": {
                "label": self.y_axis.label if self.y_axis else "",
                "unit": self.y_axis.unit if self.y_axis else "",
            },
            "series": [s.to_dict() for s in self.series],
            "key_values": self.key_values,
            "trends": [
                {"description": t.description, "direction": t.direction}
                for t in self.trends
            ],
            "summary": self.summary,
        }

    def to_chart_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> ChartChunk:
        """Convert to ChartChunk for the chunks module.

        Args:
            doc_id: Identifier of the document containing the chart.
            page: Page value passed through to the chunk.
            sequence_index: Sequence index passed through to the chunk.

        Returns:
            A ``ChartChunk`` whose text is ``to_text_description()`` and whose
            data points are those of all series, flattened.
        """
        return ChartChunk(
            chunk_id=ChartChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="chart",
            ),
            doc_id=doc_id,
            text=self.to_text_description(),
            page=page,
            bbox=self.bbox,
            confidence=self.confidence,
            sequence_index=sequence_index,
            chart_type=self.chart_type.value,
            title=self.title,
            x_axis_label=self.x_axis.label if self.x_axis else None,
            y_axis_label=self.y_axis.label if self.y_axis else None,
            data_points=self.all_data_points,
            key_values=self.key_values,
            trends=[t.description for t in self.trends],
        )


@dataclass
class ChartExtractionResult:
    """Outcome of running chart extraction on a single page image."""

    # Charts that were extracted
    charts: List[ChartStructure] = field(default_factory=list)
    # Elapsed extraction time in milliseconds
    processing_time_ms: float = 0.0
    # Free-form metadata; empty by default
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def chart_count(self) -> int:
        """Number of charts held in ``charts``."""
        return len(self.charts)


class ChartModel(BatchableModel):
    """
    Abstract base class for chart extraction models.

    Implementations should handle:
    - Chart type classification
    - Axis detection and labeling
    - Data point extraction
    - Legend parsing
    - Trend detection

    Subclasses must implement ``extract_chart`` and ``classify_chart_type``.
    ``extract_all_charts``, ``process_batch`` and ``detect_trends`` have
    default implementations built on top of them.
    """

    def __init__(self, config: Optional[ChartConfig] = None):
        """Initialise the model, defaulting to ``ChartConfig(name="chart")``."""
        super().__init__(config or ChartConfig(name="chart"))
        # Re-declare the attribute with the narrower config type.
        self.config: ChartConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities of this model: chart extraction only."""
        return [ModelCapability.CHART_EXTRACTION]

    @abstractmethod
    def extract_chart(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> ChartStructure:
        """
        Extract chart structure from an image.

        Args:
            image: Input image containing a chart
            chart_region: Optional bounding box of the chart
            **kwargs: Additional parameters

        Returns:
            ChartStructure with extracted data
        """
        pass

    def extract_all_charts(
        self,
        image: ImageInput,
        chart_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> ChartExtractionResult:
        """
        Extract all charts from an image.

        With ``chart_regions``, ``extract_chart`` is called once per region;
        a region whose extraction raises is logged and skipped so the other
        regions are still processed. Without regions, ``extract_chart`` is
        called once with no region and any exception propagates to the
        caller. Charts typed ``ChartType.UNKNOWN`` are dropped in both modes.

        Args:
            image: Input document image
            chart_regions: Optional list of chart bounding boxes
            **kwargs: Additional parameters forwarded to ``extract_chart``

        Returns:
            ChartExtractionResult with all detected charts and the elapsed
            processing time in milliseconds
        """
        import time

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start_time = time.perf_counter()

        charts = []

        if chart_regions:
            for region in chart_regions:
                try:
                    chart = self.extract_chart(image, region, **kwargs)
                    if chart.chart_type != ChartType.UNKNOWN:
                        charts.append(chart)
                except Exception:
                    # Best effort: one bad region must not abort the page,
                    # but the failure should leave a trace for debugging.
                    logging.getLogger(__name__).warning(
                        "Chart extraction failed for region %r; skipping",
                        region,
                        exc_info=True,
                    )
                    continue
        else:
            chart = self.extract_chart(image, **kwargs)
            if chart.chart_type != ChartType.UNKNOWN:
                charts.append(chart)

        processing_time = (time.perf_counter() - start_time) * 1000

        return ChartExtractionResult(
            charts=charts,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[ChartExtractionResult]:
        """Process multiple images sequentially, one result per input."""
        return [self.extract_all_charts(img, **kwargs) for img in inputs]

    @abstractmethod
    def classify_chart_type(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> Tuple[ChartType, float]:
        """
        Classify the type of chart in an image.

        Args:
            image: Input image
            chart_region: Optional bounding box
            **kwargs: Additional parameters

        Returns:
            Tuple of (ChartType, confidence)
        """
        pass

    def detect_trends(
        self,
        chart: ChartStructure,
        **kwargs
    ) -> List[TrendInfo]:
        """
        Analyze chart data for trends.

        Default implementation provides basic trend detection.
        Override for more sophisticated analysis.

        For each series with at least two numeric y-values, the average of
        the first half is compared with the average of the second half. A
        change larger than 10% of the magnitude of the first-half average is
        reported as "increasing" or "decreasing"; anything smaller is
        "stable". Series with fewer than two usable values are skipped.

        Args:
            chart: Chart whose series are analysed.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            One TrendInfo per analysed series, each with a fixed confidence
            of 0.7.
        """
        trends = []

        for series in chart.series:
            if len(series.data_points) < 2:
                continue

            # Keep only y-values that can be interpreted as numbers.
            y_values = []
            for dp in series.data_points:
                if dp.y is None:
                    continue
                try:
                    y_values.append(float(dp.y))
                except (ValueError, TypeError):
                    continue

            if len(y_values) < 2:
                continue

            # Compare first-half and second-half averages; with an odd count
            # the extra value goes to the second half.
            midpoint = len(y_values) // 2
            first_half_avg = sum(y_values[:midpoint]) / midpoint
            second_half_avg = sum(y_values[midpoint:]) / (len(y_values) - midpoint)

            # Measure the change against the magnitude of the baseline so the
            # +/-10% band and the sign of the percentage stay correct when
            # the first-half average is negative.
            delta = second_half_avg - first_half_avg
            tolerance = abs(first_half_avg) * 0.1

            if delta > tolerance:
                direction = "increasing"
            elif delta < -tolerance:
                direction = "decreasing"
            else:
                direction = "stable"

            # A zero baseline has no meaningful relative change.
            change_pct = (
                delta / abs(first_half_avg) * 100
                if first_half_avg != 0
                else 0.0
            )

            trends.append(
                TrendInfo(
                    description=f"{series.name}: {direction} trend ({change_pct:+.1f}%)",
                    direction=direction,
                    start_point=series.data_points[0],
                    end_point=series.data_points[-1],
                    change_percent=change_pct,
                    confidence=0.7,
                )
            )

        return trends