# RL-OCR-Openenv / env / environment.py
# SpandanM110's picture
# Clamp task scores to strictly (0, 1) — never exactly 0.0 or 1.0
# 15e2e13
"""Core OCR Table RL Environment."""
from __future__ import annotations
from typing import Optional
from .models import OCRAction, OCRObservation, OCRState
from .graders import (
cer, markdown_score, kpi_score, kpi_hallucination_penalty,
calibration_score, score_task1, score_task2, score_task3,
)
from .tasks import TASK_REGISTRY
class OCREnvironment:
    """
    OpenEnv-compatible environment for OCR Table Extraction RL.

    Three tasks:
      - clean_table      (easy)
      - noisy_financial  (medium)
      - degraded_report  (hard)

    Supports:
      step(action)  -> (OCRObservation, reward, done, info)
      reset(task)   -> OCRObservation
      state()       -> OCRState
    """

    def __init__(self):
        self._task_name: str = "clean_table"
        self._task_data: dict = {}
        self._step: int = 0
        self._max_steps: int = 5
        self._done: bool = False

        # Submitted outputs
        self._markdown: Optional[str] = None
        self._kpis: Optional[dict] = None
        self._confidences: list = []

        # State tracking
        self._best_cer: float = 1.0  # lowest cer() value seen this episode
        self._kpi_fields_correct: int = 0  # highest KPI match count rewarded so far
        self._active_table: int = 1  # 1 or 2; toggled by switch_table
        self._last_action_type: Optional[str] = None
        self._repeat_count: int = 0  # consecutive repeats of the last action type
        self._cropped_hint: Optional[str] = None  # valid for a single step only

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(self, task: str = "clean_table") -> OCRObservation:
        """Start a new episode and return the initial observation.

        Args:
            task: Key into TASK_REGISTRY. Unknown names fall back to
                "clean_table" instead of raising.

        Returns:
            An observation with zero reward carrying the task's image,
            text hint and instructions.
        """
        if task not in TASK_REGISTRY:
            task = "clean_table"
        self._task_name = task
        # Registry values are called to build the task data for this episode.
        self._task_data = TASK_REGISTRY[task]()
        self._step = 0
        self._max_steps = self._task_data["max_steps"]
        self._done = False
        self._markdown = None
        self._kpis = None
        self._confidences = []
        self._best_cer = 1.0
        self._kpi_fields_correct = 0
        self._active_table = 1
        self._last_action_type = None
        self._repeat_count = 0
        self._cropped_hint = None

        return OCRObservation(
            image_b64=self._task_data.get("image_b64"),
            text_hint=self._task_data.get("text_hint", ""),
            reward=0.0,
            done=False,
            cer=None,
            kpi_score=None,
            error=None,
            metadata={
                "task": self._task_name,
                "max_steps": self._max_steps,
                "instructions": self._instructions(),
            },
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: OCRAction) -> tuple[OCRObservation, float, bool, dict]:
        """Apply one action and return (observation, reward, done, info).

        The third consecutive action of the same type (and every further
        repeat) costs 0.10. That penalty is added to the shaping reward of
        every action except "finalize", whose result replaces the step
        reward entirely. An unrecognised action type is reported through
        the error field. When the step budget runs out before "finalize",
        the episode ends and the reward becomes the larger of the step
        reward and the final score.
        """
        if self._done:
            obs = OCRObservation(
                text_hint=self._task_data.get("text_hint", ""),
                reward=0.0,
                done=True,
                error="Episode already done. Call reset().",
            )
            return obs, 0.0, True, {"error": "Episode already done."}

        self._step += 1
        reward = 0.0
        error = None

        # Loop detection
        action_type = action.action_type
        if action_type == self._last_action_type:
            self._repeat_count += 1
        else:
            self._repeat_count = 0
        self._last_action_type = action_type
        if self._repeat_count >= 2:
            reward -= 0.10

        # Process action. Everything except finalize accumulates onto
        # `reward` so the loop penalty above is never discarded.
        if action_type == "extract_table_md":
            reward += self._handle_extract_md(action)
        elif action_type == "extract_kpis":
            reward += self._handle_extract_kpis(action)
        elif action_type == "crop_region":
            crop_reward, error = self._handle_crop_region(action)
            reward += crop_reward
        elif action_type == "retry_region":
            reward += self._handle_retry_region()
        elif action_type == "correct_cell":
            reward += self._handle_correct_cell(action)
        elif action_type == "switch_table":
            # Toggling carries no reward of its own.
            self._active_table = 2 if self._active_table == 1 else 1
        elif action_type == "finalize":
            reward, self._done = self._handle_finalize()
        else:
            error = f"Unknown action_type: {action_type!r}"

        # Max steps exceeded
        if self._step >= self._max_steps and not self._done:
            self._done = True
            # Final score on timeout
            final = self._compute_final_score()
            reward = max(reward, final)

        gt_kpis = self._gt_kpis()
        current_kpi_score = kpi_score(self._kpis or {}, gt_kpis) if gt_kpis else None
        current_cer = cer(self._markdown or "", self._gt_md()) if self._markdown else None

        obs = OCRObservation(
            # A crop only narrows the text hint; no cropped image is produced,
            # so the image field must not carry the hint text.
            image_b64=None,
            text_hint=self._cropped_hint or self._task_data.get("text_hint", ""),
            reward=round(reward, 4),
            done=self._done,
            cer=round(current_cer, 4) if current_cer is not None else None,
            kpi_score=round(current_kpi_score, 4) if current_kpi_score is not None else None,
            error=error,
            metadata={
                "step": self._step,
                "max_steps": self._max_steps,
                "active_table": self._active_table,
            },
        )
        self._cropped_hint = None  # reset after one step
        return obs, round(reward, 4), self._done, {"error": error}

    # ------------------------------------------------------------------
    # state
    # ------------------------------------------------------------------
    def state(self) -> OCRState:
        """Return a snapshot of episode progress.

        The KPI count reflects the most recently submitted KPI dict, not
        the best submission of the episode.
        """
        gt_kpis = self._gt_kpis()
        correct = 0
        if self._kpis and gt_kpis:
            correct = self._count_correct_kpis(self._kpis, gt_kpis)
        return OCRState(
            task_name=self._task_name,
            step=self._step,
            max_steps=self._max_steps,
            best_cer=round(self._best_cer, 4),
            kpi_fields_correct=correct,
            kpi_fields_total=len(gt_kpis) if gt_kpis else 0,
            markdown_submitted=self._markdown is not None,
            kpis_submitted=self._kpis is not None,
            active_table=self._active_table,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _count_correct_kpis(kpis: dict, gt_kpis: dict) -> int:
        """Count ground-truth fields whose value matches the submission.

        Values are compared as whitespace-stripped strings; a field missing
        from `kpis` is compared as the empty string.
        """
        return sum(
            1
            for key, expected in gt_kpis.items()
            if str(kpis.get(key, "")).strip() == str(expected).strip()
        )

    @staticmethod
    def _is_table_row(line: str) -> bool:
        """Return True for a pipe-table row that is not a separator row.

        A separator row consists only of pipes, dashes, colons and
        whitespace. Rows that merely contain "---" inside a cell still
        count as table rows.
        """
        row = line.strip()
        if not row.startswith("|"):
            return False
        is_separator = "-" in row and not row.strip("|-: \t")
        return not is_separator

    @classmethod
    def _replace_cell(cls, markdown: str, row: int, col: int, value) -> str:
        """Return `markdown` with one table cell replaced.

        Args:
            markdown: Text containing a pipe table.
            row: Zero-based index among the table rows (the header row is
                row 0; separator rows are not counted).
            col: Zero-based cell index within that row.
            value: New cell content, written with one space of padding on
                each side.

        Returns:
            The text re-joined with "\\n". Out-of-range or negative
            indices leave every line unchanged.
        """
        lines = markdown.splitlines()
        row_positions = [i for i, line in enumerate(lines) if cls._is_table_row(line)]
        if 0 <= row < len(row_positions):
            target = row_positions[row]
            # Drop exactly one outer pipe per side so that empty edge cells
            # and indented rows keep their column positions.
            inner = lines[target].strip()[1:]
            if inner.endswith("|"):
                inner = inner[:-1]
            cells = inner.split("|")
            if 0 <= col < len(cells):
                cells[col] = f" {value} "
                lines[target] = "|" + "|".join(cells) + "|"
        return "\n".join(lines)

    def _gt_md(self) -> str:
        """Return the ground-truth Markdown for the active table.

        For "degraded_report" the stored ground truth is split on a
        "\\n---\\n" delimiter and the part for the active table is returned;
        if the second part is missing, the whole text is returned instead.
        """
        gt_md = self._task_data.get("gt_md", "")
        if self._task_name != "degraded_report":
            return gt_md
        parts = gt_md.split("\n---\n")
        if self._active_table == 1:
            return parts[0]
        return parts[1] if len(parts) > 1 else gt_md

    def _gt_kpis(self) -> dict:
        """Return the ground-truth KPI dict (empty when the task has none)."""
        return self._task_data.get("gt_kpis", {})

    def _handle_extract_md(self, action: OCRAction) -> float:
        """Store a Markdown submission and reward any improvement in cer().

        Returns -0.05 for an empty submission, one tenth of the decrease
        when the new cer() value beats the best so far, and 0.0 otherwise.
        The submission is stored even when it does not improve the best.
        """
        if not action.markdown:
            return -0.05
        prev_cer = self._best_cer
        new_cer = cer(action.markdown, self._gt_md())
        self._markdown = action.markdown
        if action.confidences:
            # Entries may be plain dicts or model objects exposing model_dump().
            self._confidences = [
                c if isinstance(c, dict) else c.model_dump()
                for c in action.confidences
            ]

        if new_cer < self._best_cer:
            self._best_cer = new_cer
            return round(0.1 * (prev_cer - new_cer), 4)
        return 0.0

    def _handle_extract_kpis(self, action: OCRAction) -> float:
        """Store a KPI submission and reward newly correct fields.

        Each field that raises the best match count so far earns 0.10, and
        each unit reported by kpi_hallucination_penalty() costs 0.05. The
        count is kept as a high-water mark so that resubmitting an earlier,
        better answer after a worse one cannot earn the same reward twice.
        Returns -0.05 for an empty submission.
        """
        if not action.kpis:
            return -0.05
        gt = self._gt_kpis()
        new_correct = self._count_correct_kpis(action.kpis, gt)
        gained = max(0, new_correct - self._kpi_fields_correct)
        hallucinations = kpi_hallucination_penalty(action.kpis, gt)
        reward = 0.10 * gained - 0.05 * hallucinations

        self._kpis = action.kpis
        self._kpi_fields_correct = max(self._kpi_fields_correct, new_correct)
        return round(reward, 4)

    def _handle_crop_region(self, action: OCRAction) -> tuple[float, Optional[str]]:
        """Simulate a zoom by selecting lines r1..r2 of the text hint.

        `action.region` is read as a mapping with optional integer keys
        "r1" (default 0) and "r2" (default 999). The selected lines are
        handed back as the text hint of the observation for this step.

        Returns:
            (0.0, None) on success, or (0.0, message) when the bounds are
            not integers or select nothing.
        """
        hint = self._task_data.get("text_hint", "")
        region = action.region or {}
        try:
            r1 = int(region.get("r1", 0))
            r2 = int(region.get("r2", 999))
        except (TypeError, ValueError):
            return 0.0, "Invalid crop region: r1 and r2 must be integers."
        # A negative start would wrap around to the end of the hint.
        r1 = max(r1, 0)
        if r2 <= r1:
            return 0.0, "Invalid crop region: r2 must be greater than r1."
        lines = hint.splitlines()
        sub = "\n".join(lines[r1:r2]) if lines else hint
        self._cropped_hint = sub  # returned as text_hint in this step's observation
        return 0.0, None

    def _handle_retry_region(self) -> float:
        """Re-score the stored Markdown against the active table.

        No re-OCR happens here: the stored submission is simply compared
        with the current ground truth again (a real system would re-OCR).
        Returns 0.05 when that beats the best cer() so far, else 0.0.
        """
        if self._markdown:
            new_cer = cer(self._markdown, self._gt_md())
            if new_cer < self._best_cer:
                self._best_cer = new_cer
                return 0.05
        return 0.0

    def _handle_correct_cell(self, action: OCRAction) -> float:
        """Overwrite one cell of the stored Markdown table.

        Returns -0.02 when the row, column or value is missing or when no
        Markdown has been submitted yet, 0.05 when the edit beats the best
        cer() so far, and 0.0 otherwise (including out-of-range indices).
        """
        if action.cell_row is None or action.cell_col is None or not action.cell_value:
            return -0.02
        if not self._markdown:
            return -0.02

        self._markdown = self._replace_cell(
            self._markdown, action.cell_row, action.cell_col, action.cell_value
        )
        new_cer = cer(self._markdown, self._gt_md())
        if new_cer < self._best_cer:
            self._best_cer = new_cer
            return 0.05
        return 0.0

    def _handle_finalize(self) -> tuple[float, bool]:
        """End the episode and return (final score, True).

        With nothing submitted the score is a fixed 0.01, so it stays
        strictly above zero.
        """
        if not self._markdown and not self._kpis:
            return 0.01, True  # strictly > 0
        score = self._compute_final_score()
        return round(score, 4), True

    def _compute_final_score(self) -> float:
        """Score the stored submission with the grader for this task.

        Dispatches on the task data's "task_id" (default 1): 1 goes to
        score_task1, 2 to score_task2 (which also receives the stored
        confidences and "gt_cells"), anything else to score_task3 (which
        also receives the number of steps used). The full ground-truth
        Markdown is passed, not just the active table's part.
        """
        td = self._task_data
        gt_md = td.get("gt_md", "")
        gt_kpis = self._gt_kpis()
        task_id = td.get("task_id", 1)
        markdown = self._markdown or ""
        kpis = self._kpis or {}

        if task_id == 1:
            return score_task1(markdown, kpis, gt_md, gt_kpis)
        if task_id == 2:
            gt_cells = td.get("gt_cells", {})
            return score_task2(
                markdown, kpis, self._confidences,
                gt_md, gt_kpis, gt_cells,
            )
        return score_task3(
            markdown, kpis, gt_md, gt_kpis,
            steps_used=self._step,
        )

    def _instructions(self) -> str:
        """Return the agent-facing instructions for the current task."""
        base = (
            "You are an OCR agent. Extract the table(s) from the document.\n"
            "Actions: extract_table_md, extract_kpis, crop_region, retry_region, "
            "correct_cell, switch_table (task 3 only), finalize.\n"
            "Output BOTH a Markdown table AND a JSON KPI dict before calling finalize.\n"
        )
        if self._task_name == "degraded_report":
            base += "This document has TWO tables. Use switch_table to toggle between them.\n"
        return base