Spaces:
Sleeping
Sleeping
| """Core OCR Table RL Environment.""" | |
| from __future__ import annotations | |
| from typing import Optional | |
| from .models import OCRAction, OCRObservation, OCRState | |
| from .graders import ( | |
| cer, markdown_score, kpi_score, kpi_hallucination_penalty, | |
| calibration_score, score_task1, score_task2, score_task3, | |
| ) | |
| from .tasks import TASK_REGISTRY | |
| class OCREnvironment: | |
| """ | |
| OpenEnv-compatible environment for OCR Table Extraction RL. | |
| Three tasks: | |
| - clean_table (easy) | |
| - noisy_financial (medium) | |
| - degraded_report (hard) | |
| Supports: | |
| step(action) -> (OCRObservation, reward, done, info) | |
| reset(task) -> OCRObservation | |
| state() -> OCRState | |
| """ | |
| def __init__(self): | |
| self._task_name: str = "clean_table" | |
| self._task_data: dict = {} | |
| self._step: int = 0 | |
| self._max_steps: int = 5 | |
| self._done: bool = False | |
| # Submitted outputs | |
| self._markdown: Optional[str] = None | |
| self._kpis: Optional[dict] = None | |
| self._confidences: list = [] | |
| # State tracking | |
| self._best_cer: float = 1.0 | |
| self._kpi_fields_correct: int = 0 | |
| self._active_table: int = 1 | |
| self._last_action_type: Optional[str] = None | |
| self._repeat_count: int = 0 | |
| self._cropped_hint: Optional[str] = None | |
| # ------------------------------------------------------------------ | |
| # reset | |
| # ------------------------------------------------------------------ | |
| def reset(self, task: str = "clean_table") -> OCRObservation: | |
| if task not in TASK_REGISTRY: | |
| task = "clean_table" | |
| self._task_name = task | |
| self._task_data = TASK_REGISTRY[task]() | |
| self._step = 0 | |
| self._max_steps = self._task_data["max_steps"] | |
| self._done = False | |
| self._markdown = None | |
| self._kpis = None | |
| self._confidences = [] | |
| self._best_cer = 1.0 | |
| self._kpi_fields_correct = 0 | |
| self._active_table = 1 | |
| self._last_action_type = None | |
| self._repeat_count = 0 | |
| self._cropped_hint = None | |
| return OCRObservation( | |
| image_b64=self._task_data.get("image_b64"), | |
| text_hint=self._task_data.get("text_hint", ""), | |
| reward=0.0, | |
| done=False, | |
| cer=None, | |
| kpi_score=None, | |
| error=None, | |
| metadata={ | |
| "task": self._task_name, | |
| "max_steps": self._max_steps, | |
| "instructions": self._instructions(), | |
| }, | |
| ) | |
| # ------------------------------------------------------------------ | |
| # step | |
| # ------------------------------------------------------------------ | |
| def step(self, action: OCRAction) -> tuple[OCRObservation, float, bool, dict]: | |
| if self._done: | |
| obs = OCRObservation( | |
| text_hint=self._task_data.get("text_hint", ""), | |
| reward=0.0, | |
| done=True, | |
| error="Episode already done. Call reset().", | |
| ) | |
| return obs, 0.0, True, {"error": "Episode already done."} | |
| self._step += 1 | |
| reward = 0.0 | |
| error = None | |
| # Loop detection | |
| if action.action_type == self._last_action_type: | |
| self._repeat_count += 1 | |
| else: | |
| self._repeat_count = 0 | |
| self._last_action_type = action.action_type | |
| if self._repeat_count >= 2: | |
| reward -= 0.10 | |
| # Process action | |
| if action.action_type == "extract_table_md": | |
| reward += self._handle_extract_md(action) | |
| elif action.action_type == "extract_kpis": | |
| reward += self._handle_extract_kpis(action) | |
| elif action.action_type == "crop_region": | |
| reward, error = self._handle_crop_region(action) | |
| elif action.action_type == "retry_region": | |
| reward += self._handle_retry_region() | |
| elif action.action_type == "correct_cell": | |
| reward += self._handle_correct_cell(action) | |
| elif action.action_type == "switch_table": | |
| self._active_table = 2 if self._active_table == 1 else 1 | |
| reward = 0.0 | |
| elif action.action_type == "finalize": | |
| reward, self._done = self._handle_finalize() | |
| # Max steps exceeded | |
| if self._step >= self._max_steps and not self._done: | |
| self._done = True | |
| # Final score on timeout | |
| final = self._compute_final_score() | |
| reward = max(reward, final) | |
| gt_kpis = self._gt_kpis() | |
| current_kpi_score = kpi_score(self._kpis or {}, gt_kpis) if gt_kpis else None | |
| current_cer = cer(self._markdown or "", self._gt_md()) if self._markdown else None | |
| obs = OCRObservation( | |
| image_b64=self._cropped_hint, | |
| text_hint=self._cropped_hint or self._task_data.get("text_hint", ""), | |
| reward=round(reward, 4), | |
| done=self._done, | |
| cer=round(current_cer, 4) if current_cer is not None else None, | |
| kpi_score=round(current_kpi_score, 4) if current_kpi_score is not None else None, | |
| error=error, | |
| metadata={ | |
| "step": self._step, | |
| "max_steps": self._max_steps, | |
| "active_table": self._active_table, | |
| }, | |
| ) | |
| self._cropped_hint = None # reset after one step | |
| return obs, round(reward, 4), self._done, {"error": error} | |
| # ------------------------------------------------------------------ | |
| # state | |
| # ------------------------------------------------------------------ | |
| def state(self) -> OCRState: | |
| gt_kpis = self._gt_kpis() | |
| correct = 0 | |
| if self._kpis and gt_kpis: | |
| correct = sum(1 for k, v in gt_kpis.items() | |
| if str(self._kpis.get(k, "")).strip() == str(v).strip()) | |
| return OCRState( | |
| task_name=self._task_name, | |
| step=self._step, | |
| max_steps=self._max_steps, | |
| best_cer=round(self._best_cer, 4), | |
| kpi_fields_correct=correct, | |
| kpi_fields_total=len(gt_kpis) if gt_kpis else 0, | |
| markdown_submitted=self._markdown is not None, | |
| kpis_submitted=self._kpis is not None, | |
| active_table=self._active_table, | |
| ) | |
| # ------------------------------------------------------------------ | |
| # Internal helpers | |
| # ------------------------------------------------------------------ | |
| def _gt_md(self) -> str: | |
| td = self._task_data | |
| if self._task_name == "degraded_report": | |
| if self._active_table == 1: | |
| return td["gt_md"].split("\n---\n")[0] | |
| else: | |
| parts = td["gt_md"].split("\n---\n") | |
| return parts[1] if len(parts) > 1 else td["gt_md"] | |
| return td.get("gt_md", "") | |
| def _gt_kpis(self) -> dict: | |
| td = self._task_data | |
| if self._task_name == "degraded_report": | |
| return td.get("gt_kpis", {}) | |
| return td.get("gt_kpis", {}) | |
| def _handle_extract_md(self, action: OCRAction) -> float: | |
| if not action.markdown: | |
| return -0.05 | |
| prev_cer = self._best_cer | |
| gt = self._gt_md() | |
| new_cer = cer(action.markdown, gt) | |
| self._markdown = action.markdown | |
| if self._confidences is not None and action.confidences: | |
| self._confidences = [c if isinstance(c, dict) else c.model_dump() | |
| for c in action.confidences] | |
| if new_cer < self._best_cer: | |
| self._best_cer = new_cer | |
| return round(0.1 * (prev_cer - new_cer), 4) | |
| return 0.0 | |
| def _handle_extract_kpis(self, action: OCRAction) -> float: | |
| if not action.kpis: | |
| return -0.05 | |
| gt = self._gt_kpis() | |
| reward = 0.0 | |
| prev_correct = self._kpi_fields_correct | |
| new_correct = sum(1 for k, v in gt.items() | |
| if str(action.kpis.get(k, "")).strip() == str(v).strip()) | |
| gained = max(0, new_correct - prev_correct) | |
| reward += 0.10 * gained | |
| hallucinations = kpi_hallucination_penalty(action.kpis, gt) | |
| reward -= 0.05 * hallucinations | |
| self._kpis = action.kpis | |
| self._kpi_fields_correct = new_correct | |
| return round(reward, 4) | |
| def _handle_crop_region(self, action: OCRAction) -> tuple[float, Optional[str]]: | |
| # Simulate zoom: return a sub-hint from text_hint | |
| hint = self._task_data.get("text_hint", "") | |
| region = action.region or {} | |
| r1 = int(region.get("r1", 0)) | |
| r2 = int(region.get("r2", 999)) | |
| lines = hint.splitlines() | |
| sub = "\n".join(lines[r1:r2]) if lines else hint | |
| self._cropped_hint = sub # will be returned as text_hint next step | |
| return 0.0, None | |
| def _handle_retry_region(self) -> float: | |
| # If the cropped region would improve CER, small bonus | |
| if self._markdown: | |
| prev = self._best_cer | |
| gt = self._gt_md() | |
| # Use clean GT lines as a proxy (real system would re-OCR) | |
| new_cer = cer(self._markdown, gt) | |
| if new_cer < prev: | |
| self._best_cer = new_cer | |
| return 0.05 | |
| return 0.0 | |
| def _handle_correct_cell(self, action: OCRAction) -> float: | |
| if action.cell_row is None or action.cell_col is None or not action.cell_value: | |
| return -0.02 | |
| if not self._markdown: | |
| return -0.02 | |
| # Apply cell correction to markdown | |
| lines = self._markdown.splitlines() | |
| data_lines = [l for l in lines if l.strip().startswith("|") and "---" not in l] | |
| ri = action.cell_row | |
| ci = action.cell_col | |
| if ri < len(data_lines): | |
| cells = data_lines[ri].strip("|").split("|") | |
| if ci < len(cells): | |
| cells[ci] = f" {action.cell_value} " | |
| data_lines[ri] = "|" + "|".join(cells) + "|" | |
| # Reconstruct markdown | |
| full = [] | |
| di = 0 | |
| for l in lines: | |
| if l.strip().startswith("|") and "---" not in l: | |
| full.append(data_lines[di] if di < len(data_lines) else l) | |
| di += 1 | |
| else: | |
| full.append(l) | |
| self._markdown = "\n".join(full) | |
| new_cer = cer(self._markdown, self._gt_md()) | |
| if new_cer < self._best_cer: | |
| self._best_cer = new_cer | |
| return 0.05 | |
| return 0.0 | |
| def _handle_finalize(self) -> tuple[float, bool]: | |
| if not self._markdown and not self._kpis: | |
| return 0.01, True # strictly > 0 | |
| score = self._compute_final_score() | |
| return round(score, 4), True | |
| def _compute_final_score(self) -> float: | |
| td = self._task_data | |
| gt_md = td.get("gt_md", "") | |
| gt_kpis = self._gt_kpis() | |
| task_id = td.get("task_id", 1) | |
| if task_id == 1: | |
| return score_task1(self._markdown or "", self._kpis or {}, gt_md, gt_kpis) | |
| elif task_id == 2: | |
| gt_cells = td.get("gt_cells", {}) | |
| return score_task2( | |
| self._markdown or "", self._kpis or {}, self._confidences, | |
| gt_md, gt_kpis, gt_cells, | |
| ) | |
| else: | |
| return score_task3( | |
| self._markdown or "", self._kpis or {}, gt_md, gt_kpis, | |
| steps_used=self._step, | |
| ) | |
| def _instructions(self) -> str: | |
| base = ( | |
| "You are an OCR agent. Extract the table(s) from the document.\n" | |
| "Actions: extract_table_md, extract_kpis, crop_region, retry_region, " | |
| "correct_cell, switch_table (task 3 only), finalize.\n" | |
| "Output BOTH a Markdown table AND a JSON KPI dict before calling finalize.\n" | |
| ) | |
| if self._task_name == "degraded_report": | |
| base += "This document has TWO tables. Use switch_table to toggle between them.\n" | |
| return base | |