Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- src/ocr_bench/__init__.py +3 -0
- src/ocr_bench/backends.py +238 -0
- src/ocr_bench/cli.py +589 -0
- src/ocr_bench/dataset.py +297 -0
- src/ocr_bench/elo.py +309 -0
- src/ocr_bench/judge.py +287 -0
- src/ocr_bench/publish.py +262 -0
- src/ocr_bench/run.py +187 -0
- src/ocr_bench/space.py +18 -0
- src/ocr_bench/static/style.css +379 -0
- src/ocr_bench/templates/base.html +48 -0
- src/ocr_bench/templates/comparison_card.html +88 -0
- src/ocr_bench/templates/comparisons.html +40 -0
- src/ocr_bench/templates/leaderboard.html +43 -0
- src/ocr_bench/templates/stats_panel.html +10 -0
- src/ocr_bench/validate.py +311 -0
- src/ocr_bench/viewer.py +202 -0
- src/ocr_bench/web.py +487 -0
src/ocr_bench/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
src/ocr_bench/backends.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Judge backends — API-based (HF Inference Providers, OpenAI-compatible)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import abc
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import stamina
|
| 11 |
+
import structlog
|
| 12 |
+
from huggingface_hub import InferenceClient
|
| 13 |
+
from openai import OpenAI
|
| 14 |
+
|
| 15 |
+
from ocr_bench.judge import JUDGE_SCHEMA, Comparison, parse_judge_output
|
| 16 |
+
|
| 17 |
+
# Module-level structured logger for this backend module.
logger = structlog.get_logger()

# Retry on these exception types with exponential backoff + jitter.
# NOTE(review): retrying on bare Exception also retries non-transient
# programming errors (TypeError, KeyError, ...) — consider narrowing to
# transport/HTTP error types; confirm against the providers' SDK exceptions.
_RETRYABLE = (Exception,)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class JudgeBackend(abc.ABC):
    """Base class for judge backends.

    Subclasses implement :meth:`_call_single`; :meth:`judge` fans calls out
    over a thread pool when ``concurrency`` > 1.
    """

    # Human-readable backend identifier (set by subclasses).
    name: str
    # Max number of concurrent _call_single invocations in judge().
    concurrency: int = 1

    @abc.abstractmethod
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Run the judge on a single comparison."""

    def judge(self, comparisons: list[Comparison]) -> list[dict[str, str]]:
        """Run the judge on a list of comparisons (concurrently if supported).

        Returns a list of parsed results (one per comparison).
        Each result is a dict with ``winner`` and ``reason`` keys,
        or an empty dict on failure.
        """
        if self.concurrency <= 1 or len(comparisons) <= 1:
            return [self._call_single(comp) for comp in comparisons]

        # Concurrent execution preserving input order.
        # BUG FIX: the original pre-filled with ``[{}] * len(comparisons)``,
        # which makes every slot alias the SAME dict object — any slot left
        # unassigned would share mutations with all the others. Build
        # independent dicts instead.
        results: list[dict[str, str]] = [{} for _ in comparisons]
        with ThreadPoolExecutor(max_workers=self.concurrency) as pool:
            future_to_idx = {
                pool.submit(self._call_single, comp): i
                for i, comp in enumerate(comparisons)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as exc:
                    # A single failed call yields an empty dict rather than
                    # aborting the whole batch.
                    logger.warning("judge_call_failed", idx=idx, error=str(exc))
                    results[idx] = {}
        return results
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Default cap on judge-response length (tokens), shared by all backends.
DEFAULT_MAX_TOKENS = 1024
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class InferenceProviderJudge(JudgeBackend):
    """HF Inference Providers backend (Novita, Together, etc.).

    Args:
        model: Model repo id, e.g. ``"org/model"``.
        provider: Optional provider slug to pin (e.g. ``"novita"``); ``None``
            lets huggingface_hub pick one.
        max_tokens: Response token cap passed to the chat completion.
    """

    def __init__(
        self, model: str, provider: str | None = None, max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        # Display name: "provider:org/model" when pinned, else the model id.
        self.name = f"{provider + ':' if provider else ''}{model}"
        self.model = model
        self.max_tokens = max_tokens
        self.client = InferenceClient(model=model, provider=provider)  # type: ignore[invalid-argument-type]

    @stamina.retry(on=_RETRYABLE, attempts=6)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Judge one comparison; returns ``{}`` when the output cannot be parsed."""
        response = self.client.chat_completion(  # type: ignore[no-matching-overload]
            messages=comp.messages,
            max_tokens=self.max_tokens,
            temperature=0.0,
            response_format={"type": "json_object"},
            # Disable provider-side "thinking" so the response is pure JSON.
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        # BUG FIX: message.content may be None (refusals, tool-call replies);
        # the original crashed with AttributeError on .strip() and then burned
        # all retry attempts on the same deterministic failure.
        raw = (response.choices[0].message.content or "").strip()
        result = parse_judge_output(raw)
        if not result:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return result
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class OpenAICompatibleJudge(JudgeBackend):
    """OpenAI-compatible endpoint (local vLLM server, Ollama, HF IE, etc.).

    Args:
        base_url: OpenAI-compatible API base URL (``.../v1``).
        model: Model name to request; ``"default"`` for single-model servers.
        max_tokens: Response token cap.
        api_key: API key; local servers usually ignore it.
        extra_body: Extra request body; defaults to vLLM guided decoding
            against :data:`JUDGE_SCHEMA`.
        temperature: Sampling temperature.
        concurrency: Max concurrent judge calls (see :class:`JudgeBackend`).
    """

    def __init__(
        self,
        base_url: str,
        model: str = "default",
        max_tokens: int = DEFAULT_MAX_TOKENS,
        api_key: str = "not-needed",
        extra_body: dict | None = None,
        temperature: float = 0.0,
        concurrency: int = 1,
    ):
        self.name = model if model != "default" else f"openai@{base_url}"
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Default to constrained JSON decoding (vLLM "guided_json").
        self.extra_body = extra_body if extra_body is not None else {"guided_json": JUDGE_SCHEMA}
        self.concurrency = concurrency
        self.client = OpenAI(base_url=base_url, api_key=api_key)

    @stamina.retry(on=_RETRYABLE, attempts=3)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Judge one comparison; returns ``{}`` when the output cannot be parsed."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=comp.messages,  # type: ignore[invalid-argument-type]
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            extra_body=self.extra_body,
        )
        # BUG FIX: the OpenAI SDK types message.content as Optional[str];
        # the original crashed with AttributeError on .strip() when None.
        raw = (response.choices[0].message.content or "").strip()
        result = parse_judge_output(raw)
        if not result:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return result
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# Spec parsing
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
# Default judge spec (``provider:org/model`` form) used when no --model is given.
DEFAULT_JUDGE = "novita:moonshotai/Kimi-K2.5"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def parse_judge_spec(
    spec: str, max_tokens: int = DEFAULT_MAX_TOKENS, concurrency: int = 1,
) -> JudgeBackend:
    """Parse a judge specification string into a backend.

    Formats:
    - ``"https://xxx.endpoints.huggingface.cloud"`` → :class:`OpenAICompatibleJudge`
      (HF Inference Endpoints, OpenAI-compatible with HF token auth)
    - ``"http://..."`` or ``"https://..."`` (other) → :class:`OpenAICompatibleJudge`
    - ``"provider:org/model"`` (colon before first ``/``) → :class:`InferenceProviderJudge`
    - anything else → :class:`InferenceProviderJudge` (no provider)
    """
    if spec.startswith(("http://", "https://")):
        # A "url/v1/:org/model" spec pins an explicit model name on the URL.
        if "/v1/:" in spec:
            url_part, model_name = spec.split("/v1/:", 1)
            url_part += "/v1"
        else:
            url_part, model_name = spec, "default"

        if ".endpoints.huggingface." in url_part:
            # HF Inference Endpoints — OpenAI-compatible, auth via HF token.
            from huggingface_hub import get_token

            base_url = url_part.rstrip("/")
            if not base_url.endswith("/v1"):
                base_url += "/v1"
            return OpenAICompatibleJudge(
                base_url=base_url,
                model=model_name,
                api_key=get_token() or "not-needed",
                max_tokens=max_tokens,
                temperature=0.7,
                extra_body={"chat_template_kwargs": {"enable_thinking": False}},
                concurrency=concurrency,
            )
        return OpenAICompatibleJudge(
            base_url=url_part, model=model_name, max_tokens=max_tokens,
            concurrency=concurrency,
        )

    colon_idx = spec.find(":")
    if colon_idx != -1:
        # provider:model — only when the colon precedes the first slash.
        slash_idx = spec.find("/")
        if slash_idx == -1 or colon_idx < slash_idx:
            provider, model = spec[:colon_idx], spec[colon_idx + 1 :]
            return InferenceProviderJudge(model=model, provider=provider, max_tokens=max_tokens)

    return InferenceProviderJudge(model=spec, max_tokens=max_tokens)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# ---------------------------------------------------------------------------
|
| 189 |
+
# Jury aggregation
|
| 190 |
+
# ---------------------------------------------------------------------------
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def aggregate_jury_votes(
    all_results: list[list[dict[str, str]]],
    judge_names: list[str],
) -> list[dict[str, Any]]:
    """Aggregate votes from multiple judges using majority voting.

    Args:
        all_results: List of result lists, one per judge. Each inner list
            has one dict per comparison.
        judge_names: Names of the judges (same order as *all_results*).

    Returns:
        Aggregated results with ``winner``, ``reason``, and ``agreement`` fields.
    """
    if not all_results:
        return []

    aggregated: list[dict[str, Any]] = []
    for comp_idx in range(len(all_results[0])):
        votes: list[str] = []
        reasons: list[str] = []
        # Collect each judge's vote for this comparison, skipping judges
        # that produced no result (short list or empty dict).
        for judge_idx, judge_results in enumerate(all_results):
            entry = judge_results[comp_idx] if comp_idx < len(judge_results) else {}
            vote = entry.get("winner", "")
            if vote:
                votes.append(vote)
                reasons.append(f"{judge_names[judge_idx]}: {entry.get('reason', '')}")

        if not votes:
            aggregated.append({"winner": "tie", "reason": "no valid votes", "agreement": "0/0"})
            continue

        top_vote, top_count = Counter(votes).most_common(1)[0]
        aggregated.append({
            "winner": top_vote,
            "reason": "; ".join(reasons),
            "agreement": f"{top_count}/{len(votes)}",
        })

    return aggregated
|
src/ocr_bench/cli.py
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI entrypoint for ocr-bench."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import structlog
|
| 9 |
+
from rich.console import Console
|
| 10 |
+
from rich.table import Table
|
| 11 |
+
|
| 12 |
+
from ocr_bench.backends import (
|
| 13 |
+
DEFAULT_JUDGE,
|
| 14 |
+
DEFAULT_MAX_TOKENS,
|
| 15 |
+
aggregate_jury_votes,
|
| 16 |
+
parse_judge_spec,
|
| 17 |
+
)
|
| 18 |
+
from ocr_bench.dataset import (
|
| 19 |
+
DatasetError,
|
| 20 |
+
discover_configs,
|
| 21 |
+
discover_pr_configs,
|
| 22 |
+
load_config_dataset,
|
| 23 |
+
load_flat_dataset,
|
| 24 |
+
)
|
| 25 |
+
from ocr_bench.elo import ComparisonResult, Leaderboard, compute_elo, rankings_resolved
|
| 26 |
+
from ocr_bench.judge import Comparison, _normalize_pair, build_comparisons, sample_indices
|
| 27 |
+
from ocr_bench.publish import (
|
| 28 |
+
EvalMetadata,
|
| 29 |
+
load_existing_comparisons,
|
| 30 |
+
load_existing_metadata,
|
| 31 |
+
publish_results,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Module-level structured logger and Rich console used by all CLI commands.
logger = structlog.get_logger()
console = Console()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _add_judge_parser(sub) -> None:
    """Register the ``judge`` subcommand (pairwise VLM judging)."""
    judge = sub.add_parser("judge", help="Run pairwise VLM judge on OCR outputs")

    # Dataset
    judge.add_argument("dataset", help="HF dataset repo id")
    judge.add_argument("--split", default="train", help="Dataset split (default: train)")
    judge.add_argument("--columns", nargs="+", default=None, help="Explicit OCR column names")
    judge.add_argument(
        "--configs", nargs="+", default=None, help="Config-per-model: list of config names"
    )
    judge.add_argument("--from-prs", action="store_true", help="Force PR-based config discovery")
    judge.add_argument(
        "--merge",
        action="store_true",
        help="Merge PRs to main after discovery (default: load via revision)",
    )

    # Judge
    judge.add_argument(
        "--model",
        action="append",
        dest="models",
        help=f"Judge model spec (repeatable for jury). Default: {DEFAULT_JUDGE}",
    )

    # Eval
    judge.add_argument("--max-samples", type=int, default=None, help="Max samples to evaluate")
    judge.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    judge.add_argument(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Max tokens for judge response (default: {DEFAULT_MAX_TOKENS})",
    )

    # Output
    judge.add_argument(
        "--save-results",
        default=None,
        help="HF repo id to publish results to (default: {dataset}-results)",
    )
    judge.add_argument(
        "--no-publish",
        action="store_true",
        help="Don't publish results (default: publish to {dataset}-results)",
    )
    judge.add_argument(
        "--full-rejudge",
        action="store_true",
        help="Re-judge all pairs, ignoring existing comparisons in --save-results repo",
    )
    judge.add_argument(
        "--no-adaptive",
        action="store_true",
        help="Disable adaptive stopping (default: adaptive is on)",
    )
    judge.add_argument(
        "--concurrency",
        type=int,
        default=1,
        help="Number of concurrent judge API calls (default: 1)",
    )


def _add_run_parser(sub) -> None:
    """Register the ``run`` subcommand (launch OCR models via HF Jobs)."""
    run = sub.add_parser("run", help="Launch OCR models on a dataset via HF Jobs")
    run.add_argument("input_dataset", help="HF dataset repo id with images")
    run.add_argument("output_repo", help="Output dataset repo (all models push here)")
    run.add_argument(
        "--models", nargs="+", default=None, help="Model slugs to run (default: all 4 core)"
    )
    run.add_argument("--max-samples", type=int, default=None, help="Per-model sample limit")
    run.add_argument("--split", default="train", help="Dataset split (default: train)")
    run.add_argument("--flavor", default=None, help="Override GPU flavor for all models")
    run.add_argument("--timeout", default="4h", help="Per-job timeout (default: 4h)")
    run.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    run.add_argument("--shuffle", action="store_true", help="Shuffle source dataset")
    run.add_argument("--list-models", action="store_true", help="Print available models and exit")
    run.add_argument(
        "--dry-run", action="store_true", help="Show what would launch without launching"
    )
    run.add_argument(
        "--no-wait", action="store_true", help="Launch and exit without polling (default: wait)"
    )


def _add_view_parser(sub) -> None:
    """Register the ``view`` subcommand (local results browser)."""
    view = sub.add_parser("view", help="Browse and validate results in a web UI")
    view.add_argument("results", help="HF dataset repo id with published results")
    view.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
    view.add_argument("--host", default="127.0.0.1", help="Host (default: 127.0.0.1)")
    view.add_argument("--output", default=None, help="Path to save annotations JSON")


def build_parser() -> argparse.ArgumentParser:
    """Build the top-level CLI parser with judge/run/view subcommands.

    Decomposed into one private helper per subcommand so each option group
    can be read (and extended) in isolation; the parse result is unchanged.
    """
    parser = argparse.ArgumentParser(
        prog="ocr-bench",
        description="OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards",
    )
    sub = parser.add_subparsers(dest="command")
    _add_judge_parser(sub)
    _add_run_parser(sub)
    _add_view_parser(sub)
    return parser
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def print_leaderboard(board: Leaderboard) -> None:
    """Print leaderboard as a Rich table."""
    table = Table(title="OCR Model Leaderboard")
    table.add_column("Rank", style="bold")
    table.add_column("Model")
    # The ELO column header advertises confidence intervals only when we have them.
    has_ci = bool(board.elo_ci)
    table.add_column("ELO (95% CI)" if has_ci else "ELO", justify="right")
    for header in ("Wins", "Losses", "Ties", "Win%"):
        table.add_column(header, justify="right")

    for rank, (model, elo) in enumerate(board.ranked, 1):
        pct = board.win_pct(model)
        pct_str = "-" if pct is None else f"{pct:.0f}%"
        ci = board.elo_ci.get(model) if has_ci else None
        if ci is None:
            elo_str = str(round(elo))
        else:
            lo, hi = ci
            elo_str = f"{round(elo)} ({round(lo)}\u2013{round(hi)})"
        table.add_row(
            str(rank),
            model,
            elo_str,
            str(board.wins[model]),
            str(board.losses[model]),
            str(board.ties[model]),
            pct_str,
        )

    console.print(table)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _convert_results(
    comparisons: list[Comparison], aggregated: list[dict]
) -> list[ComparisonResult]:
    """Convert judged comparisons + aggregated outputs into ComparisonResult list."""
    # Comparisons whose judge output is empty (failed/unparseable) are dropped.
    return [
        ComparisonResult(
            sample_idx=comp.sample_idx,
            model_a=comp.model_a,
            model_b=comp.model_b,
            winner=outcome.get("winner", "tie"),
            reason=outcome.get("reason", ""),
            agreement=outcome.get("agreement", "1/1"),
            swapped=comp.swapped,
            text_a=comp.text_a,
            text_b=comp.text_b,
            col_a=comp.col_a,
            col_b=comp.col_b,
        )
        for comp, outcome in zip(comparisons, aggregated)
        if outcome
    ]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _resolve_results_repo(dataset: str, save_results: str | None, no_publish: bool) -> str | None:
|
| 201 |
+
"""Derive the results repo id. Returns None if publishing is disabled."""
|
| 202 |
+
if no_publish:
|
| 203 |
+
return None
|
| 204 |
+
if save_results:
|
| 205 |
+
return save_results
|
| 206 |
+
return f"{dataset}-results"
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def cmd_judge(args: argparse.Namespace) -> None:
|
| 210 |
+
"""Orchestrate: load → compare → judge → elo → print → publish."""
|
| 211 |
+
# --- Resolve flags ---
|
| 212 |
+
adaptive = not args.no_adaptive
|
| 213 |
+
merge = args.merge
|
| 214 |
+
results_repo = _resolve_results_repo(args.dataset, args.save_results, args.no_publish)
|
| 215 |
+
from_prs = False # track for metadata
|
| 216 |
+
|
| 217 |
+
if results_repo:
|
| 218 |
+
console.print(f"Results will be published to [bold]{results_repo}[/bold]")
|
| 219 |
+
|
| 220 |
+
# --- Load dataset (cascading auto-detection) ---
|
| 221 |
+
if args.configs:
|
| 222 |
+
# Explicit configs — use them directly
|
| 223 |
+
config_names = args.configs
|
| 224 |
+
ds, ocr_columns = load_config_dataset(args.dataset, config_names, split=args.split)
|
| 225 |
+
elif args.columns:
|
| 226 |
+
# Explicit columns — flat loading
|
| 227 |
+
ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split, columns=args.columns)
|
| 228 |
+
elif args.from_prs:
|
| 229 |
+
# Forced PR discovery
|
| 230 |
+
config_names, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
|
| 231 |
+
if not config_names:
|
| 232 |
+
raise DatasetError("No configs found in open PRs")
|
| 233 |
+
from_prs = True
|
| 234 |
+
console.print(f"Discovered {len(config_names)} configs from PRs: {config_names}")
|
| 235 |
+
ds, ocr_columns = load_config_dataset(
|
| 236 |
+
args.dataset,
|
| 237 |
+
config_names,
|
| 238 |
+
split=args.split,
|
| 239 |
+
pr_revisions=pr_revisions if not merge else None,
|
| 240 |
+
)
|
| 241 |
+
else:
|
| 242 |
+
# Auto-detect: PRs + main branch configs combined, fall back to flat
|
| 243 |
+
pr_configs, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
|
| 244 |
+
main_configs = discover_configs(args.dataset)
|
| 245 |
+
|
| 246 |
+
# Combine: PR configs + main configs not already in PRs
|
| 247 |
+
config_names = list(pr_configs)
|
| 248 |
+
for mc in main_configs:
|
| 249 |
+
if mc not in pr_configs:
|
| 250 |
+
config_names.append(mc)
|
| 251 |
+
|
| 252 |
+
if config_names:
|
| 253 |
+
if pr_configs:
|
| 254 |
+
from_prs = True
|
| 255 |
+
console.print(f"Auto-detected {len(pr_configs)} configs from PRs: {pr_configs}")
|
| 256 |
+
if main_configs:
|
| 257 |
+
main_only = [c for c in main_configs if c not in pr_configs]
|
| 258 |
+
if main_only:
|
| 259 |
+
console.print(f"Auto-detected {len(main_only)} configs on main: {main_only}")
|
| 260 |
+
ds, ocr_columns = load_config_dataset(
|
| 261 |
+
args.dataset,
|
| 262 |
+
config_names,
|
| 263 |
+
split=args.split,
|
| 264 |
+
pr_revisions=pr_revisions if pr_configs else None,
|
| 265 |
+
)
|
| 266 |
+
else:
|
| 267 |
+
# No configs anywhere — fall back to flat loading
|
| 268 |
+
ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split)
|
| 269 |
+
|
| 270 |
+
console.print(f"Loaded {len(ds)} samples with {len(ocr_columns)} models:")
|
| 271 |
+
for col, model in ocr_columns.items():
|
| 272 |
+
console.print(f" {col} → {model}")
|
| 273 |
+
|
| 274 |
+
# --- Incremental: load existing comparisons ---
|
| 275 |
+
existing_results: list[ComparisonResult] = []
|
| 276 |
+
existing_meta_rows: list[dict] = []
|
| 277 |
+
skip_pairs: set[tuple[str, str]] | None = None
|
| 278 |
+
|
| 279 |
+
if results_repo and not args.full_rejudge:
|
| 280 |
+
existing_results = load_existing_comparisons(results_repo)
|
| 281 |
+
if existing_results:
|
| 282 |
+
judged_pairs = {_normalize_pair(r.model_a, r.model_b) for r in existing_results}
|
| 283 |
+
skip_pairs = judged_pairs
|
| 284 |
+
console.print(
|
| 285 |
+
f"\nIncremental mode: {len(existing_results)} existing comparisons "
|
| 286 |
+
f"across {len(judged_pairs)} model pairs — skipping those."
|
| 287 |
+
)
|
| 288 |
+
existing_meta_rows = load_existing_metadata(results_repo)
|
| 289 |
+
else:
|
| 290 |
+
console.print("\nNo existing comparisons found — full judge run.")
|
| 291 |
+
|
| 292 |
+
model_names = list(set(ocr_columns.values()))
|
| 293 |
+
|
| 294 |
+
# --- Judge setup (shared by both paths) ---
|
| 295 |
+
model_specs = args.models or [DEFAULT_JUDGE]
|
| 296 |
+
judges = [
|
| 297 |
+
parse_judge_spec(spec, max_tokens=args.max_tokens, concurrency=args.concurrency)
|
| 298 |
+
for spec in model_specs
|
| 299 |
+
]
|
| 300 |
+
is_jury = len(judges) > 1
|
| 301 |
+
|
| 302 |
+
def _judge_batch(batch_comps: list[Comparison]) -> list[ComparisonResult]:
|
| 303 |
+
"""Run judge(s) on a batch of comparisons and return ComparisonResults."""
|
| 304 |
+
all_judge_outputs: list[list[dict]] = []
|
| 305 |
+
for judge in judges:
|
| 306 |
+
results = judge.judge(batch_comps)
|
| 307 |
+
all_judge_outputs.append(results)
|
| 308 |
+
if is_jury:
|
| 309 |
+
judge_names = [j.name for j in judges]
|
| 310 |
+
aggregated = aggregate_jury_votes(all_judge_outputs, judge_names)
|
| 311 |
+
else:
|
| 312 |
+
aggregated = all_judge_outputs[0]
|
| 313 |
+
return _convert_results(batch_comps, aggregated)
|
| 314 |
+
|
| 315 |
+
if adaptive:
|
| 316 |
+
# --- Adaptive stopping: batch-by-batch with convergence check ---
|
| 317 |
+
from itertools import combinations as _combs
|
| 318 |
+
|
| 319 |
+
all_indices = sample_indices(len(ds), args.max_samples, args.seed)
|
| 320 |
+
n_pairs = len(list(_combs(model_names, 2)))
|
| 321 |
+
batch_samples = 5
|
| 322 |
+
min_before_check = max(3 * n_pairs, 20)
|
| 323 |
+
|
| 324 |
+
if is_jury:
|
| 325 |
+
console.print(f"\nJury mode: {len(judges)} judges")
|
| 326 |
+
console.print(
|
| 327 |
+
f"\n[bold]Adaptive mode[/bold]: {len(all_indices)} samples, "
|
| 328 |
+
f"{n_pairs} pairs, batch size {batch_samples}, "
|
| 329 |
+
f"checking after {min_before_check} comparisons"
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
new_results: list[ComparisonResult] = []
|
| 333 |
+
total_comparisons = 0
|
| 334 |
+
for batch_num, batch_start in enumerate(range(0, len(all_indices), batch_samples)):
|
| 335 |
+
batch_indices = all_indices[batch_start : batch_start + batch_samples]
|
| 336 |
+
batch_comps = build_comparisons(
|
| 337 |
+
ds,
|
| 338 |
+
ocr_columns,
|
| 339 |
+
skip_pairs=skip_pairs,
|
| 340 |
+
indices=batch_indices,
|
| 341 |
+
seed=args.seed,
|
| 342 |
+
)
|
| 343 |
+
if not batch_comps:
|
| 344 |
+
continue
|
| 345 |
+
|
| 346 |
+
batch_results = _judge_batch(batch_comps)
|
| 347 |
+
new_results.extend(batch_results)
|
| 348 |
+
total_comparisons += len(batch_comps)
|
| 349 |
+
# batch_comps goes out of scope → GC can free images
|
| 350 |
+
|
| 351 |
+
total = len(existing_results) + len(new_results)
|
| 352 |
+
console.print(f" Batch {batch_num + 1}: {len(batch_results)} new, {total} total")
|
| 353 |
+
|
| 354 |
+
if total >= min_before_check:
|
| 355 |
+
board = compute_elo(existing_results + new_results, model_names)
|
| 356 |
+
# Show CI gaps for each adjacent pair
|
| 357 |
+
ranked = board.ranked
|
| 358 |
+
if board.elo_ci:
|
| 359 |
+
gaps: list[str] = []
|
| 360 |
+
for i in range(len(ranked) - 1):
|
| 361 |
+
hi_model, _ = ranked[i]
|
| 362 |
+
lo_model, _ = ranked[i + 1]
|
| 363 |
+
hi_ci = board.elo_ci.get(hi_model)
|
| 364 |
+
lo_ci = board.elo_ci.get(lo_model)
|
| 365 |
+
if hi_ci and lo_ci:
|
| 366 |
+
gap = hi_ci[0] - lo_ci[1] # positive = resolved
|
| 367 |
+
if gap > 0:
|
| 368 |
+
status = "[green]ok[/green]"
|
| 369 |
+
else:
|
| 370 |
+
status = f"[yellow]overlap {-gap:.0f}[/yellow]"
|
| 371 |
+
gaps.append(f" {hi_model} vs {lo_model}: gap={gap:+.0f} {status}")
|
| 372 |
+
if gaps:
|
| 373 |
+
console.print(" CI gaps:")
|
| 374 |
+
for g in gaps:
|
| 375 |
+
console.print(g)
|
| 376 |
+
|
| 377 |
+
if rankings_resolved(board):
|
| 378 |
+
remaining = len(all_indices) - batch_start - len(batch_indices)
|
| 379 |
+
console.print(
|
| 380 |
+
f"[green]Rankings converged after {total} comparisons! "
|
| 381 |
+
f"Skipped ~{remaining * n_pairs} remaining.[/green]"
|
| 382 |
+
)
|
| 383 |
+
break
|
| 384 |
+
|
| 385 |
+
console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
|
| 386 |
+
else:
|
| 387 |
+
# --- Standard single-pass flow ---
|
| 388 |
+
comparisons = build_comparisons(
|
| 389 |
+
ds,
|
| 390 |
+
ocr_columns,
|
| 391 |
+
max_samples=args.max_samples,
|
| 392 |
+
seed=args.seed,
|
| 393 |
+
skip_pairs=skip_pairs,
|
| 394 |
+
)
|
| 395 |
+
console.print(f"\nBuilt {len(comparisons)} new pairwise comparisons")
|
| 396 |
+
|
| 397 |
+
if not comparisons and not existing_results:
|
| 398 |
+
console.print(
|
| 399 |
+
"[yellow]No valid comparisons — check that OCR columns have text.[/yellow]"
|
| 400 |
+
)
|
| 401 |
+
return
|
| 402 |
+
|
| 403 |
+
if not comparisons:
|
| 404 |
+
console.print("[green]All pairs already judged — refitting leaderboard.[/green]")
|
| 405 |
+
board = compute_elo(existing_results, model_names)
|
| 406 |
+
console.print()
|
| 407 |
+
print_leaderboard(board)
|
| 408 |
+
if results_repo:
|
| 409 |
+
metadata = EvalMetadata(
|
| 410 |
+
source_dataset=args.dataset,
|
| 411 |
+
judge_models=[],
|
| 412 |
+
seed=args.seed,
|
| 413 |
+
max_samples=args.max_samples or len(ds),
|
| 414 |
+
total_comparisons=0,
|
| 415 |
+
valid_comparisons=0,
|
| 416 |
+
from_prs=from_prs,
|
| 417 |
+
)
|
| 418 |
+
publish_results(
|
| 419 |
+
results_repo,
|
| 420 |
+
board,
|
| 421 |
+
metadata,
|
| 422 |
+
existing_metadata=existing_meta_rows,
|
| 423 |
+
)
|
| 424 |
+
console.print(f"\nResults published to [bold]{results_repo}[/bold]")
|
| 425 |
+
return
|
| 426 |
+
|
| 427 |
+
if is_jury:
|
| 428 |
+
console.print(f"\nJury mode: {len(judges)} judges")
|
| 429 |
+
|
| 430 |
+
for judge in judges:
|
| 431 |
+
console.print(f"\nRunning judge: {judge.name}")
|
| 432 |
+
|
| 433 |
+
new_results = _judge_batch(comparisons)
|
| 434 |
+
total_comparisons = len(comparisons)
|
| 435 |
+
console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
|
| 436 |
+
|
| 437 |
+
# --- Merge existing + new, compute ELO ---
|
| 438 |
+
all_results = existing_results + new_results
|
| 439 |
+
board = compute_elo(all_results, model_names)
|
| 440 |
+
console.print()
|
| 441 |
+
print_leaderboard(board)
|
| 442 |
+
|
| 443 |
+
# --- Publish ---
|
| 444 |
+
if results_repo:
|
| 445 |
+
metadata = EvalMetadata(
|
| 446 |
+
source_dataset=args.dataset,
|
| 447 |
+
judge_models=[j.name for j in judges],
|
| 448 |
+
seed=args.seed,
|
| 449 |
+
max_samples=args.max_samples or len(ds),
|
| 450 |
+
total_comparisons=total_comparisons,
|
| 451 |
+
valid_comparisons=len(new_results),
|
| 452 |
+
from_prs=from_prs,
|
| 453 |
+
)
|
| 454 |
+
publish_results(results_repo, board, metadata, existing_metadata=existing_meta_rows)
|
| 455 |
+
console.print(f"\nResults published to [bold]{results_repo}[/bold]")
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def cmd_run(args: argparse.Namespace) -> None:
    """Launch OCR models on a dataset via HF Jobs.

    Three modes, selected by flags on ``args``:
      * ``--list-models``: print the model registry table and return.
      * ``--dry-run``: print the per-model job configuration without launching.
      * default: launch one HF Job per selected model and (unless
        ``--no-wait``) block until all jobs finish.
    """
    # Lazy import: the job-launching machinery is only needed for this subcommand.
    from ocr_bench.run import (
        DEFAULT_MODELS,
        MODEL_REGISTRY,
        build_script_args,
        launch_ocr_jobs,
        poll_jobs,
    )

    # --list-models: informational table, then exit early.
    if args.list_models:
        table = Table(title="Available OCR Models", show_lines=True)
        table.add_column("Slug", style="cyan bold")
        table.add_column("Model ID")
        table.add_column("Size", justify="right")
        table.add_column("Default GPU", justify="center")

        for slug in sorted(MODEL_REGISTRY):
            cfg = MODEL_REGISTRY[slug]
            # Tag registry entries that belong to the default model set.
            default = " (default)" if slug in DEFAULT_MODELS else ""
            table.add_row(slug + default, cfg.model_id, cfg.size, cfg.default_flavor)

        console.print(table)
        console.print(f"\nDefault set: {', '.join(DEFAULT_MODELS)}")
        return

    # Validate the requested model slugs before doing anything expensive.
    selected = args.models or DEFAULT_MODELS
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            console.print(f"[red]Unknown model: {slug}[/red]")
            console.print(f"Available: {', '.join(MODEL_REGISTRY.keys())}")
            sys.exit(1)

    # Summary banner of the run configuration.
    console.print("\n[bold]OCR Benchmark Run[/bold]")
    console.print(f" Source: {args.input_dataset}")
    console.print(f" Output: {args.output_repo}")
    console.print(f" Models: {', '.join(selected)}")
    if args.max_samples:
        console.print(f" Samples: {args.max_samples} per model")
    console.print()

    # Dry run: show exactly what each job would execute, launch nothing.
    if args.dry_run:
        console.print("[bold yellow]DRY RUN[/bold yellow] — no jobs will be launched\n")
        for slug in selected:
            cfg = MODEL_REGISTRY[slug]
            # A CLI --flavor overrides the registry's per-model default GPU flavor.
            flavor = args.flavor or cfg.default_flavor
            script_args = build_script_args(
                args.input_dataset,
                args.output_repo,
                slug,
                max_samples=args.max_samples,
                shuffle=args.shuffle,
                seed=args.seed,
                extra_args=cfg.default_args or None,
            )
            console.print(f"[cyan]{slug}[/cyan] ({cfg.model_id})")
            console.print(f" Flavor: {flavor}")
            console.print(f" Timeout: {args.timeout}")
            console.print(f" Script: {cfg.script}")
            console.print(f" Args: {' '.join(script_args)}")
            console.print()
        console.print("Remove --dry-run to launch these jobs.")
        return

    # Launch one HF Job per selected model.
    jobs = launch_ocr_jobs(
        args.input_dataset,
        args.output_repo,
        models=selected,
        max_samples=args.max_samples,
        split=args.split,
        shuffle=args.shuffle,
        seed=args.seed,
        flavor_override=args.flavor,
        timeout=args.timeout,
    )

    console.print(f"\n[green]{len(jobs)} jobs launched.[/green]")
    for job in jobs:
        console.print(f" [cyan]{job.model_slug}[/cyan]: {job.job_url}")

    if not args.no_wait:
        # Block until every job reaches a terminal state, then point the user
        # at the follow-up judge command.
        console.print("\n[bold]Waiting for jobs to complete...[/bold]")
        poll_jobs(jobs)
        console.print("\n[bold green]All jobs finished![/bold green]")
        console.print("\nEvaluate:")
        console.print(f" ocr-bench judge {args.output_repo}")
    else:
        # Fire-and-forget: jobs keep running on the Hub.
        console.print("\nJobs running in background.")
        console.print("Check status at: https://huggingface.co/settings/jobs")
        console.print(f"When complete: ocr-bench judge {args.output_repo}")
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def cmd_view(args: argparse.Namespace) -> None:
    """Launch the FastAPI + HTMX results viewer.

    Exits with status 1 when the optional viewer dependencies
    (FastAPI/uvicorn) are not installed.
    """
    try:
        # Optional extras: imported lazily so the core CLI works without them.
        import uvicorn

        from ocr_bench.web import create_app
    except ImportError:
        console.print(
            "[red]Error:[/red] FastAPI/uvicorn not installed. "
            "Install the viewer extra: [bold]pip install ocr-bench\\[viewer][/bold]"
        )
        sys.exit(1)

    console.print(f"Loading results from [bold]{args.results}[/bold]...")
    app = create_app(args.results, output_path=args.output)
    console.print(f"Starting viewer at [bold]http://{args.host}:{args.port}[/bold]")
    # Blocking call — serves until interrupted.
    uvicorn.run(app, host=args.host, port=args.port)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the subcommand handler.

    Dataset-level failures surface as a friendly error message and exit
    status 1 rather than a traceback.
    """
    parser = build_parser()
    args = parser.parse_args()

    # No subcommand given: show usage and exit cleanly.
    if args.command is None:
        parser.print_help()
        sys.exit(0)

    # Dispatch table keeps the mapping between subcommand name and handler flat.
    dispatch = {
        "judge": cmd_judge,
        "run": cmd_run,
        "view": cmd_view,
    }
    try:
        handler = dispatch.get(args.command)
        if handler is not None:
            handler(args)
    except DatasetError as exc:
        console.print(f"[red]Error:[/red] {exc}")
        sys.exit(1)
|
src/ocr_bench/dataset.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset loading — flat, config-per-model, PR-based. OCR column discovery."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
import structlog
|
| 8 |
+
from datasets import Dataset, get_dataset_config_names, load_dataset
|
| 9 |
+
from huggingface_hub import HfApi
|
| 10 |
+
|
| 11 |
+
logger = structlog.get_logger()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DatasetError(Exception):
    """Signals a failure while loading a dataset or discovering its OCR columns."""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
# OCR column discovery
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def discover_ocr_columns(dataset: Dataset) -> dict[str, str]:
    """Discover OCR output columns and their model names from a dataset.

    Strategy:
    1. Parse ``inference_info`` JSON from the first row (list or single entry).
    2. Fallback: heuristic column-name matching (``markdown``, ``ocr``, ``text``).
    3. Disambiguate duplicate model names by appending the column name.

    Returns:
        Mapping of ``column_name → model_name``.

    Raises:
        DatasetError: If no OCR columns can be found.
    """
    discovered: dict[str, str] = {}

    try:
        if "inference_info" not in dataset.column_names:
            # Jump straight to the heuristic fallback via the except clause.
            raise KeyError("no inference_info column")
        raw = dataset["inference_info"][0]  # column access avoids image decode
        if raw:
            entries = json.loads(raw)
            if not isinstance(entries, list):
                entries = [entries]
            for entry in entries:
                col = entry.get("column_name", "")
                model = entry.get("model_id", entry.get("model_name", "unknown"))
                if col and col in dataset.column_names:
                    discovered[col] = model
    except (json.JSONDecodeError, TypeError, KeyError) as exc:
        logger.warning("could_not_parse_inference_info", error=str(exc))

    # Heuristic fallback when inference_info yielded nothing usable.
    if not discovered:
        discovered = {
            col: col
            for col in dataset.column_names
            if "markdown" in col.lower() or "ocr" in col.lower() or col == "text"
        }

    if not discovered:
        raise DatasetError(f"No OCR columns found. Available columns: {dataset.column_names}")

    # Count model-name occurrences so duplicates can be told apart.
    occurrences: dict[str, int] = {}
    for model in discovered.values():
        occurrences[model] = occurrences.get(model, 0) + 1

    resolved: dict[str, str] = {}
    for col, model in discovered.items():
        if occurrences[model] > 1:
            # Same model produced multiple columns: suffix with the column name.
            short = model.rsplit("/", 1)[-1]
            resolved[col] = f"{short} ({col})"
        else:
            resolved[col] = model

    return resolved
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# PR-based config discovery
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def discover_pr_configs(
    repo_id: str,
    merge: bool = False,
    api: HfApi | None = None,
) -> tuple[list[str], dict[str, str]]:
    """Find dataset configs advertised by open PRs on a Hub dataset repo.

    Only PRs whose title ends with ``[config_name]`` are considered.

    Args:
        repo_id: HF dataset repo id.
        merge: If True, merge each discovered PR before loading.
        api: Optional pre-configured HfApi instance.

    Returns:
        Tuple of (config_names, {config_name: pr_revision}).
    """
    client = api if api is not None else HfApi()

    found: list[str] = []
    pr_refs: dict[str, str] = {}

    for disc in client.get_repo_discussions(repo_id, repo_type="dataset"):
        # Only open pull requests qualify.
        if not (disc.is_pull_request and disc.status == "open"):
            continue
        title = disc.title
        # Title must carry a trailing "[config]" marker.
        if not ("[" in title and title.endswith("]")):
            continue
        config = title[title.rindex("[") + 1 : -1].strip()
        if not config:
            continue
        if merge:
            client.merge_pull_request(repo_id, disc.num, repo_type="dataset")
            logger.info("merged_pr", pr=disc.num, config=config)
        else:
            pr_refs[config] = f"refs/pr/{disc.num}"
        found.append(config)

    return found, pr_refs
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def discover_configs(repo_id: str) -> list[str]:
    """List non-default configs from the main branch of a Hub dataset.

    Returns:
        Config names excluding "default"; an empty list when the repo has no
        loadable configs (best-effort: any loading error is swallowed).
    """
    try:
        names = get_dataset_config_names(repo_id)
    except Exception as exc:  # best-effort: treat any failure as "no configs"
        logger.info("no_configs_on_main", repo=repo_id, reason=str(exc))
        return []
    return [name for name in names if name != "default"]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
# Config-per-model loading
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def load_config_dataset(
    repo_id: str,
    config_names: list[str],
    split: str = "train",
    pr_revisions: dict[str, str] | None = None,
) -> tuple[Dataset, dict[str, str]]:
    """Load multiple configs from a Hub dataset and merge into one.

    Each config becomes a column whose name is the config name and whose value
    is the OCR text (from the first column matching heuristics, or ``markdown``).

    Args:
        repo_id: HF dataset repo id.
        config_names: List of config names to load.
        split: Dataset split to load.
        pr_revisions: Optional mapping of config_name → revision for PR-based loading.

    Returns:
        Tuple of (unified Dataset, {column_name: model_id}).

    Raises:
        DatasetError: if ``config_names`` is empty or no config yields a
            usable text column.
    """
    if not config_names:
        raise DatasetError("No config names provided")

    pr_revisions = pr_revisions or {}
    unified: Dataset | None = None
    ocr_columns: dict[str, str] = {}

    # NOTE(review): merging assumes all configs share the same row order and
    # source samples — confirm the upload pipeline guarantees this.
    for config in config_names:
        revision = pr_revisions.get(config)
        kwargs: dict = {"path": repo_id, "name": config, "split": split}
        if revision:
            # Load the config straight off its PR branch (refs/pr/N).
            kwargs["revision"] = revision

        ds = load_dataset(**kwargs)

        # Find the OCR text column in this config
        text_col = _find_text_column(ds)
        if text_col is None:
            # Config contributes nothing; skip it rather than failing the run.
            logger.warning("no_text_column_in_config", config=config)
            continue

        # Extract model_id from inference_info if available
        model_id = _extract_model_id(ds, config)
        ocr_columns[config] = model_id

        # Build unified dataset using Arrow-level ops (no per-row image decode)
        text_values = ds[text_col]  # column access — no image decoding
        if unified is None:
            # First config: keep all columns except text_col, add text as config name
            drop = [text_col] if text_col != config else []
            unified = ds.remove_columns(drop) if drop else ds
            if config != text_col:
                unified = unified.add_column(config, text_values)
            # Also rename text_col to config if they differ and text_col was kept
        else:
            if len(ds) != len(unified):
                logger.warning(
                    "config_length_mismatch",
                    config=config,
                    expected=len(unified),
                    got=len(ds),
                )
                # NOTE(review): only LONGER configs are handled here by
                # truncation; a config SHORTER than `unified` would still fail
                # inside add_column — confirm upstream guarantees equal lengths.
                text_values = text_values[: len(unified)]
            unified = unified.add_column(config, text_values)

    if unified is None:
        raise DatasetError("No configs loaded successfully")

    return unified, ocr_columns
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def _extract_model_id(ds: Dataset, config: str) -> str:
|
| 218 |
+
"""Extract model_id from inference_info in first row, falling back to config name."""
|
| 219 |
+
if "inference_info" not in ds.column_names:
|
| 220 |
+
return config
|
| 221 |
+
try:
|
| 222 |
+
info_raw = ds["inference_info"][0] # column access avoids image decode
|
| 223 |
+
if info_raw:
|
| 224 |
+
info = json.loads(info_raw)
|
| 225 |
+
if isinstance(info, list):
|
| 226 |
+
info = info[0]
|
| 227 |
+
return info.get("model_id", info.get("model_name", config))
|
| 228 |
+
except (json.JSONDecodeError, TypeError, KeyError, IndexError):
|
| 229 |
+
pass
|
| 230 |
+
return config
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _find_text_column(ds: Dataset) -> str | None:
|
| 234 |
+
"""Find the likely OCR text column in a dataset.
|
| 235 |
+
|
| 236 |
+
Priority:
|
| 237 |
+
1. ``inference_info[0]["column_name"]`` if present and exists in dataset.
|
| 238 |
+
2. First column matching ``markdown`` (case-insensitive).
|
| 239 |
+
3. First column matching ``ocr`` (case-insensitive).
|
| 240 |
+
4. Column named exactly ``text``.
|
| 241 |
+
"""
|
| 242 |
+
# Try inference_info first (column access avoids image decoding)
|
| 243 |
+
if "inference_info" in ds.column_names:
|
| 244 |
+
try:
|
| 245 |
+
info_raw = ds["inference_info"][0]
|
| 246 |
+
if info_raw:
|
| 247 |
+
info = json.loads(info_raw)
|
| 248 |
+
if isinstance(info, list):
|
| 249 |
+
info = info[0]
|
| 250 |
+
col_name = info.get("column_name", "")
|
| 251 |
+
if col_name and col_name in ds.column_names:
|
| 252 |
+
return col_name
|
| 253 |
+
except (json.JSONDecodeError, TypeError, KeyError, IndexError):
|
| 254 |
+
pass
|
| 255 |
+
|
| 256 |
+
# Prioritized heuristic: markdown > ocr > text
|
| 257 |
+
for pattern in ["markdown", "ocr"]:
|
| 258 |
+
for col in ds.column_names:
|
| 259 |
+
if pattern in col.lower():
|
| 260 |
+
return col
|
| 261 |
+
if "text" in ds.column_names:
|
| 262 |
+
return "text"
|
| 263 |
+
return None
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ---------------------------------------------------------------------------
|
| 267 |
+
# Flat dataset loading
|
| 268 |
+
# ---------------------------------------------------------------------------
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def load_flat_dataset(
    repo_id: str,
    split: str = "train",
    columns: list[str] | None = None,
) -> tuple[Dataset, dict[str, str]]:
    """Load a flat dataset from the Hub and resolve its OCR columns.

    Args:
        repo_id: HF dataset repo id.
        split: Dataset split.
        columns: If given, use these as OCR columns (maps col→col).

    Returns:
        Tuple of (Dataset, {column_name: model_name}).

    Raises:
        DatasetError: when an explicitly requested column is missing.
    """
    ds = load_dataset(repo_id, split=split)

    # No explicit selection: auto-discover from inference_info / heuristics.
    if not columns:
        return ds, discover_ocr_columns(ds)

    # Explicit selection: every requested column must exist.
    for col in columns:
        if col not in ds.column_names:
            raise DatasetError(f"Column '{col}' not found. Available: {ds.column_names}")
    return ds, {col: col for col in columns}
|
src/ocr_bench/elo.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bradley-Terry MLE rating computation for pairwise comparisons."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Literal
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from scipy.optimize import minimize
|
| 13 |
+
|
| 14 |
+
INITIAL_ELO: float = 1500.0
|
| 15 |
+
|
| 16 |
+
Winner = Literal["A", "B", "tie"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class ComparisonResult:
    """Result of a single pairwise comparison, ready for ELO computation."""

    # Index of the source sample the two OCR outputs came from.
    sample_idx: int
    # Model names in their original (pre-swap) positions; ELO tallying credits
    # these after unswapping the winner (see _build_win_matrix).
    model_a: str
    model_b: str
    # Judge verdict relative to the PRESENTED order; when ``swapped`` is True
    # it must be flipped back before tallying (see _unswap_winner).
    winner: Winner
    # Free-text judge rationale.
    reason: str = ""
    # Agreement string, default "1/1" — presumably "votes/judges" in jury
    # mode (e.g. "2/3"); verify against the jury aggregation code.
    agreement: str = "1/1"
    # Whether the A/B presentation order was randomized (swapped) for the judge.
    swapped: bool = False
    # OCR output texts as shown to the judge.
    text_a: str = ""
    text_b: str = ""
    # Source dataset columns the two texts were read from.
    col_a: str = ""
    col_b: str = ""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class Leaderboard:
    """ELO leaderboard computed from pairwise comparison results."""

    # model name → ELO rating
    elo: dict[str, float] = field(default_factory=dict)
    # per-model win / loss / tie tallies
    wins: dict[str, int] = field(default_factory=dict)
    losses: dict[str, int] = field(default_factory=dict)
    ties: dict[str, int] = field(default_factory=dict)
    # raw per-comparison records for auditing/publishing
    comparison_log: list[dict[str, object]] = field(default_factory=list)
    # model name → (lower, upper) bootstrap confidence interval
    elo_ci: dict[str, tuple[float, float]] = field(default_factory=dict)

    @property
    def ranked(self) -> list[tuple[str, float]]:
        """Models sorted by ELO rating, descending."""
        return sorted(self.elo.items(), key=lambda x: x[1], reverse=True)

    def win_pct(self, model: str) -> float | None:
        """Win percentage for a model, or None if no comparisons.

        Fix: previously raised KeyError for a model absent from the tally
        dicts (they default to empty); missing entries now count as zero,
        so an unknown model returns None as documented.
        """
        total = (
            self.wins.get(model, 0)
            + self.losses.get(model, 0)
            + self.ties.get(model, 0)
        )
        if total == 0:
            return None
        return self.wins.get(model, 0) / total * 100
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _unswap_winner(winner: Winner, swapped: bool) -> Winner:
|
| 61 |
+
"""Unswap winner if positions were randomized."""
|
| 62 |
+
if swapped:
|
| 63 |
+
if winner == "A":
|
| 64 |
+
return "B"
|
| 65 |
+
elif winner == "B":
|
| 66 |
+
return "A"
|
| 67 |
+
return winner
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _build_win_matrix(
    results: list[ComparisonResult],
) -> tuple[dict[tuple[str, str], float], set[str]]:
    """Tally fractional wins per ordered model pair.

    A tie contributes 0.5 to each direction. Returns (win_counts, models_seen)
    where win_counts[(i, j)] is the fractional win count of i over j.
    """
    tallies: dict[tuple[str, str], float] = defaultdict(float)
    seen: set[str] = set()

    for res in results:
        seen.update((res.model_a, res.model_b))
        # Map the verdict back to the original A/B positions first.
        verdict = _unswap_winner(res.winner, res.swapped)

        if verdict == "A":
            tallies[res.model_a, res.model_b] += 1.0
        elif verdict == "B":
            tallies[res.model_b, res.model_a] += 1.0
        else:
            # Tie: split the credit evenly in both directions.
            tallies[res.model_a, res.model_b] += 0.5
            tallies[res.model_b, res.model_a] += 0.5

    return tallies, seen
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _bt_mle(
    win_counts: dict[tuple[str, str], float],
    model_names: list[str],
) -> dict[str, float]:
    """Fit Bradley-Terry model via maximum likelihood estimation.

    Returns theta (strength) per model. Uses scipy L-BFGS-B on the
    negative log-likelihood with log-parameterization for positivity.
    """
    n = len(model_names)
    # Degenerate cases: no fit needed.
    if n == 0:
        return {}
    if n == 1:
        return {model_names[0]: 1.0}

    idx = {name: i for i, name in enumerate(model_names)}

    # Collect all pairs with nonzero games
    pairs: list[tuple[int, int, float, float]] = []
    for i_name in model_names:
        for j_name in model_names:
            # Visit each unordered pair once (lexicographic ordering).
            if i_name >= j_name:
                continue
            w_ij = win_counts.get((i_name, j_name), 0.0)
            w_ji = win_counts.get((j_name, i_name), 0.0)
            if w_ij + w_ji > 0:
                pairs.append((idx[i_name], idx[j_name], w_ij, w_ji))

    # No games at all: all models are equally strong by convention.
    if not pairs:
        return {name: 1.0 for name in model_names}

    def neg_log_likelihood(log_theta: np.ndarray) -> float:
        # BT negative log-likelihood in log-strength space.
        nll = 0.0
        for i, j, w_ij, w_ji in pairs:
            diff = log_theta[i] - log_theta[j]
            # log(theta_i / (theta_i + theta_j)) = diff - log(1 + exp(diff))
            # log(theta_j / (theta_i + theta_j)) = -diff - log(1 + exp(-diff))
            # Use log-sum-exp for numerical stability
            log_p_ij = diff - np.logaddexp(0.0, diff)
            log_p_ji = -diff - np.logaddexp(0.0, -diff)
            nll -= w_ij * log_p_ij + w_ji * log_p_ji
        return nll

    def gradient(log_theta: np.ndarray) -> np.ndarray:
        # Analytic gradient of the NLL w.r.t. each log_theta component.
        grad = np.zeros(n)
        for i, j, w_ij, w_ji in pairs:
            diff = log_theta[i] - log_theta[j]
            # NOTE(review): np.exp(-diff) can overflow for very negative diff
            # (result saturates to p_ij = 0, correct but may emit a warning).
            p_ij = 1.0 / (1.0 + np.exp(-diff))  # sigmoid(diff)
            total = w_ij + w_ji
            # d(NLL)/d(log_theta_i)
            grad[i] -= w_ij - total * p_ij
            grad[j] -= w_ji - total * (1.0 - p_ij)
        return grad

    # Start all strengths at log θ = 0. The BT model is shift-invariant in
    # log θ, so the scale is fixed by centering after the fit (below), not by
    # pinning a reference model.
    x0 = np.zeros(n)
    result = minimize(
        neg_log_likelihood,
        x0,
        jac=gradient,
        method="L-BFGS-B",
    )

    log_theta = result.x
    # Center: subtract geometric mean (= mean of log_theta)
    log_theta -= log_theta.mean()
    theta = np.exp(log_theta)

    return {name: float(theta[idx[name]]) for name in model_names}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _theta_to_elo(theta: dict[str, float], center: float = 1500.0) -> dict[str, float]:
|
| 169 |
+
"""Convert BT theta values to ELO scale.
|
| 170 |
+
|
| 171 |
+
ELO_i = 400 * log10(theta_i / theta_ref) + center
|
| 172 |
+
where theta_ref is the geometric mean of all theta values.
|
| 173 |
+
"""
|
| 174 |
+
if not theta:
|
| 175 |
+
return {}
|
| 176 |
+
|
| 177 |
+
values = list(theta.values())
|
| 178 |
+
log_geo_mean = sum(math.log(v) for v in values) / len(values)
|
| 179 |
+
geo_mean = math.exp(log_geo_mean)
|
| 180 |
+
|
| 181 |
+
return {
|
| 182 |
+
name: 400.0 * math.log10(t / geo_mean) + center
|
| 183 |
+
for name, t in theta.items()
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _bootstrap_ci(
    results: list[ComparisonResult],
    model_names: list[str],
    n_bootstrap: int = 1000,
    ci: float = 0.95,
    seed: int = 42,
) -> dict[str, tuple[float, float]]:
    """Compute bootstrap confidence intervals for ELO ratings.

    Resamples comparisons with replacement, fits BT-MLE each time,
    returns percentile-based CIs.

    Args:
        results: Judged pairwise comparisons to resample from.
        model_names: Models to produce intervals for.
        n_bootstrap: Number of bootstrap resamples.
        ci: Confidence level (0.95 -> 95% interval).
        seed: Seed for the resampling RNG, making CIs reproducible.

    Returns:
        Mapping of model name -> (low, high) ELO bounds; empty when there
        are no results or no models.
    """
    if not results or not model_names:
        return {}

    rng = random.Random(seed)
    n = len(results)
    elo_samples: dict[str, list[float]] = {name: [] for name in model_names}

    for _ in range(n_bootstrap):
        # Resample the same number of comparisons, with replacement.
        boot = rng.choices(results, k=n)
        win_counts, _ = _build_win_matrix(boot)
        theta = _bt_mle(win_counts, model_names)
        elos = _theta_to_elo(theta)
        for name in model_names:
            # .get default is defensive; _theta_to_elo normally covers
            # every model passed to _bt_mle.
            elo_samples[name].append(elos.get(name, 1500.0))

    alpha = (1.0 - ci) / 2.0
    lo_pct = alpha * 100
    hi_pct = (1.0 - alpha) * 100

    # Percentile bounds read directly off the sorted bootstrap samples.
    cis: dict[str, tuple[float, float]] = {}
    for name in model_names:
        samples = sorted(elo_samples[name])
        lo_idx = int(len(samples) * lo_pct / 100)
        hi_idx = min(int(len(samples) * hi_pct / 100), len(samples) - 1)
        cis[name] = (samples[lo_idx], samples[hi_idx])

    return cis
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def rankings_resolved(board: Leaderboard) -> bool:
    """Check if all adjacent ranks have non-overlapping 95% CIs.

    Returns True when the ranking order is statistically resolved — i.e. for
    every pair of adjacent models in the ranking, the higher-ranked model's
    CI lower bound exceeds the lower-ranked model's CI upper bound.
    """
    intervals = board.elo_ci
    if not intervals:
        return False

    ordering = board.ranked
    if len(ordering) < 2:
        return False

    # Walk adjacent (higher, lower) rank pairs.
    for (upper_model, _), (lower_model, _) in zip(ordering, ordering[1:]):
        if upper_model not in intervals or lower_model not in intervals:
            return False
        upper_low = intervals[upper_model][0]
        lower_high = intervals[lower_model][1]
        if lower_high >= upper_low:
            # Overlapping CIs — order not resolved.
            return False
    return True
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def compute_elo(
    results: list[ComparisonResult],
    model_names: list[str],
    n_bootstrap: int = 1000,
) -> Leaderboard:
    """Compute ELO ratings from pairwise comparison results using Bradley-Terry MLE.

    Handles position-bias unswapping: if a result has swapped=True,
    the winner is flipped before updating ratings.

    Bootstrap confidence intervals are computed when n_bootstrap > 0.

    Args:
        results: Judged comparisons (possibly position-swapped).
        model_names: All models to rate, including ones with no games.
        n_bootstrap: Resample count for the bootstrap CIs; 0 disables them.

    Returns:
        Leaderboard with ratings, W/L/T tallies, optional CIs, and a full
        comparison log (with canonical, unswapped winners).
    """
    # Seed the board with INITIAL_ELO; the ratings are fully replaced by
    # the BT-MLE fit below — they are never updated incrementally.
    board = Leaderboard(
        elo={m: INITIAL_ELO for m in model_names},
        wins={m: 0 for m in model_names},
        losses={m: 0 for m in model_names},
        ties={m: 0 for m in model_names},
    )

    # Tally wins/losses/ties and build comparison log
    for r in results:
        # Undo the A/B position swap so tallies refer to the true models.
        winner = _unswap_winner(r.winner, r.swapped)

        if winner == "A":
            board.wins[r.model_a] += 1
            board.losses[r.model_b] += 1
        elif winner == "B":
            board.losses[r.model_a] += 1
            board.wins[r.model_b] += 1
        else:
            # Anything other than "A"/"B" is treated as a tie.
            board.ties[r.model_a] += 1
            board.ties[r.model_b] += 1

        # Log the canonical (unswapped) verdict for publishing/review.
        board.comparison_log.append(
            {
                "sample_idx": r.sample_idx,
                "model_a": r.model_a,
                "model_b": r.model_b,
                "winner": winner,
                "reason": r.reason,
                "agreement": r.agreement,
                "text_a": r.text_a,
                "text_b": r.text_b,
                "col_a": r.col_a,
                "col_b": r.col_b,
            }
        )

    # Fit BT-MLE
    win_counts, _ = _build_win_matrix(results)
    theta = _bt_mle(win_counts, model_names)
    board.elo = _theta_to_elo(theta)

    # Bootstrap CIs
    if n_bootstrap > 0 and results:
        board.elo_ci = _bootstrap_ci(results, model_names, n_bootstrap=n_bootstrap)

    return board
|
src/ocr_bench/judge.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pairwise VLM judge — prompt templates, structured output schema, comparison building."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import io
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
import random
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from itertools import combinations
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
# --- Judge prompt ---
|
| 19 |
+
|
| 20 |
+
PAIRWISE_PROMPT = """\
|
| 21 |
+
You are an expert OCR quality evaluator. You are given a document image and \
|
| 22 |
+
TWO OCR outputs (A and B) extracted from that same image.
|
| 23 |
+
|
| 24 |
+
Compare them and decide which extraction is better overall.
|
| 25 |
+
|
| 26 |
+
Evaluation criteria (in priority order):
|
| 27 |
+
|
| 28 |
+
1. Faithfulness: The output must ONLY contain text actually visible in the document. \
|
| 29 |
+
Hallucinating text that is not in the image (garbled strings, repeated tokens, \
|
| 30 |
+
nonsensical output) is the most serious error. Added commentary or notes \
|
| 31 |
+
(e.g. "it appears the text says...") is also an error, but less severe than \
|
| 32 |
+
hallucination. If a page is blank or has minimal text, saying so is acceptable — \
|
| 33 |
+
fabricating content is always worse.
|
| 34 |
+
|
| 35 |
+
2. Completeness: ALL visible text must be captured — headers, footers, marginalia, \
|
| 36 |
+
stamps, handwritten notes. Missing any section of text is a significant penalty.
|
| 37 |
+
|
| 38 |
+
3. Accuracy: Correct characters, no garbled or fabricated words.
|
| 39 |
+
|
| 40 |
+
4. Reading order: Text flows naturally as a human would read the document.
|
| 41 |
+
|
| 42 |
+
5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \
|
| 43 |
+
if present. Do NOT prefer fancier markdown formatting — plain accurate text is \
|
| 44 |
+
better than nicely formatted but incomplete text.
|
| 45 |
+
|
| 46 |
+
If both outputs capture the same text with similar accuracy, respond with "tie". \
|
| 47 |
+
Only pick a winner when there is a clear quality difference.
|
| 48 |
+
|
| 49 |
+
Output A:
|
| 50 |
+
---
|
| 51 |
+
{ocr_text_a}
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
Output B:
|
| 55 |
+
---
|
| 56 |
+
{ocr_text_b}
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
Respond with JSON only (no markdown fences, no extra text):
|
| 60 |
+
{{"winner": "A", "reason": "brief explanation"}}
|
| 61 |
+
Use "A", "B", or "tie" for the winner field."""
|
| 62 |
+
|
| 63 |
+
# JSON schema describing the judge's expected structured output; used by
# backends that support constrained decoding / response validation.
JUDGE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "winner": {"type": "string", "enum": ["A", "B", "tie"]},
        "reason": {"type": "string"},
    },
    "required": ["winner", "reason"],
}

# Max characters of OCR text to include per output in the prompt.
MAX_OCR_TEXT_LENGTH = 2500

# Max image dimension (longer side) before resizing.
MAX_IMAGE_DIM = 1024
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# --- Image helpers ---
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str:
    """Convert a PIL image to a base64-encoded JPEG string, resizing if needed.

    Args:
        image: Source image; converted to RGB first if in another mode.
        max_dim: Maximum allowed length of the longer side, in pixels.

    Returns:
        Base64-encoded JPEG bytes (quality 85) as an ASCII string.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    if max(image.size) > max_dim:
        ratio = max_dim / max(image.size)
        # Clamp each side to >= 1 px: an extreme aspect ratio could otherwise
        # truncate the short side to 0 and make resize() raise.
        new_size = (
            max(1, int(image.width * ratio)),
            max(1, int(image.height * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    buf = io.BytesIO()
    image.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# --- Comparison ---
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
class Comparison:
    """A single pairwise comparison to evaluate."""

    # Row index into the source dataset.
    sample_idx: int
    # Model names behind the two OCR outputs being compared.
    model_a: str
    model_b: str
    # Dataset column names the two texts were read from.
    col_a: str
    col_b: str
    # True when A/B positions were swapped in the prompt (position-bias control).
    swapped: bool
    # Pre-built chat messages (image + prompt) ready to send to the judge.
    messages: list[dict[str, Any]]
    # Original (unswapped) OCR texts, carried along for result records.
    text_a: str = ""
    text_b: str = ""
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]:
    """Build the pairwise comparison prompt, applying position-bias swap.

    Each text is truncated to MAX_OCR_TEXT_LENGTH characters; when
    ``swapped`` is True the two outputs trade A/B positions.

    Returns (prompt_text, swapped).
    """
    first = text_a[:MAX_OCR_TEXT_LENGTH]
    second = text_b[:MAX_OCR_TEXT_LENGTH]
    if swapped:
        first, second = second, first
    return PAIRWISE_PROMPT.format(ocr_text_a=first, ocr_text_b=second), swapped
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]:
    """Build chat messages for the judge (image + prompt)."""
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    # Single user turn: the document image first, then the instructions.
    return [{"role": "user", "content": [image_part, text_part]}]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _normalize_pair(a: str, b: str) -> tuple[str, str]:
|
| 142 |
+
"""Return a canonical (sorted) pair for symmetric lookup."""
|
| 143 |
+
return (a, b) if a <= b else (b, a)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def sample_indices(
    dataset_len: int, max_samples: int | None = None, seed: int = 42
) -> list[int]:
    """Compute sample indices (cheap — no image loading).

    Without ``max_samples`` this is simply ``range(dataset_len)``; with it,
    a reproducible random subset is drawn.

    Args:
        dataset_len: Total number of rows in the dataset.
        max_samples: If set, randomly sample this many indices.
        seed: Random seed for reproducible sampling.

    Returns:
        List of integer indices into the dataset.
    """
    indices = list(range(dataset_len))
    if max_samples and max_samples < len(indices):
        # Use a local Random instance rather than random.seed() so we don't
        # clobber the global RNG state for other callers. Random(seed).sample
        # produces exactly the same sequence as seeding the module RNG, so
        # sampled indices are unchanged.
        indices = random.Random(seed).sample(indices, max_samples)
    return indices
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def build_comparisons(
    dataset: Any,
    ocr_columns: dict[str, str],
    max_samples: int | None = None,
    seed: int = 42,
    skip_pairs: set[tuple[str, str]] | None = None,
    indices: list[int] | None = None,
) -> list[Comparison]:
    """Build pairwise comparison prompts from a dataset.

    Args:
        dataset: HF dataset with an "image" column and OCR output columns.
        ocr_columns: Mapping of column_name -> model_name.
        max_samples: If set, randomly sample this many rows. Ignored when
            ``indices`` is provided.
        seed: Random seed for sampling and position-bias randomization.
        skip_pairs: Set of (model_a, model_b) pairs to exclude. Pairs are
            normalized so (a, b) and (b, a) are treated identically.
            If None, all pairs are included.
        indices: Explicit row indices to use. When provided, ``max_samples``
            and ``seed`` are not used for index selection (seed is still used
            for position-bias randomization).

    Returns:
        List of Comparison objects with pre-built chat messages.
    """
    col_names = list(ocr_columns.keys())
    model_names = list(ocr_columns.values())
    # All unordered index pairs of OCR columns (round-robin tournament).
    pairs = list(combinations(range(len(col_names)), 2))

    # Normalize skip set for symmetric lookup
    normalized_skip: set[tuple[str, str]] = set()
    if skip_pairs:
        normalized_skip = {_normalize_pair(a, b) for a, b in skip_pairs}

    if indices is None:
        indices = sample_indices(len(dataset), max_samples, seed)

    rng = random.Random(seed)
    comparisons: list[Comparison] = []

    # Pre-fetch text columns to avoid triggering image decode per row.
    # HF Dataset supports column access (dataset["col"]), plain lists don't.
    text_cols_data: dict[str, list] | None = None
    if hasattr(dataset, "column_names"):
        text_cols_data = {col: dataset[col] for col in col_names}

    for idx in indices:
        # Determine which pairs need judging for this row
        # NOTE(review): needed_pairs does not depend on idx and could be
        # hoisted out of the loop — harmless but redundant work per row.
        needed_pairs = [
            (i, j)
            for i, j in pairs
            if _normalize_pair(model_names[i], model_names[j]) not in normalized_skip
        ]
        if not needed_pairs:
            continue  # Skip image encoding entirely

        # Check text availability before decoding the image
        valid_pairs = []
        if text_cols_data is not None:
            for i, j in needed_pairs:
                text_a = text_cols_data[col_names[i]][idx] or ""
                text_b = text_cols_data[col_names[j]][idx] or ""
                # Both sides must be non-blank for the judge to compare them.
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))
        else:
            row = dataset[idx]
            for i, j in needed_pairs:
                text_a = row[col_names[i]] or ""
                text_b = row[col_names[j]] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))

        if not valid_pairs:
            continue

        # Encode the page image once per row, shared by all its pairs.
        image_b64 = image_to_base64(dataset[idx]["image"])

        for i, j, text_a, text_b in valid_pairs:
            # Randomize A/B positions per comparison to counter position bias.
            swapped = rng.random() < 0.5
            prompt, swapped = build_prompt(text_a, text_b, swapped)
            messages = build_messages(image_b64, prompt)

            comparisons.append(
                Comparison(
                    sample_idx=idx,
                    model_a=model_names[i],
                    model_b=model_names[j],
                    col_a=col_names[i],
                    col_b=col_names[j],
                    swapped=swapped,
                    messages=messages,
                    text_a=text_a,
                    text_b=text_b,
                )
            )

    return comparisons
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# --- Output parsing ---
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def parse_judge_output(text: str) -> dict[str, str]:
    """Parse judge JSON output, handling markdown fences and invalid values.

    Tolerates fenced output (```json ... ```), non-object JSON, and a
    non-string/invalid ``winner`` value (coerced to "tie").

    Returns dict with "winner" and "reason" keys, or empty dict on failure.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (guarding against single-line fenced
        # output, which previously raised IndexError), then everything
        # after the closing fence.
        parts = text.split("\n", 1)
        text = (parts[1] if len(parts) == 2 else "").rsplit("```", 1)[0].strip()
    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}
    if not isinstance(result, dict):
        # Valid JSON but not an object (e.g. a bare list or string) —
        # previously this raised AttributeError on .get().
        logger.warning("Judge output is not a JSON object: %s", text[:200])
        return {}
    winner = result.get("winner", "tie")
    winner = winner.upper().strip() if isinstance(winner, str) else "tie"
    if winner == "TIE":
        winner = "tie"
    if winner not in ("A", "B", "tie"):
        winner = "tie"
    reason = result.get("reason", "")
    return {"winner": winner, "reason": reason if isinstance(reason, str) else str(reason)}
|
src/ocr_bench/publish.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hub publishing — push comparisons, leaderboard, and metadata configs to HF Hub."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import datetime
|
| 6 |
+
import json
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
import structlog
|
| 10 |
+
from datasets import Dataset, load_dataset
|
| 11 |
+
from huggingface_hub import HfApi
|
| 12 |
+
|
| 13 |
+
from ocr_bench.elo import ComparisonResult, Leaderboard
|
| 14 |
+
|
| 15 |
+
logger = structlog.get_logger()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class EvalMetadata:
    """Metadata for an evaluation run, stored alongside results on Hub."""

    # Hub id of the dataset the OCR outputs were read from.
    source_dataset: str
    # Judge model identifiers used for the pairwise comparisons.
    judge_models: list[str]
    # Seed used for this run's sampling/randomization.
    seed: int
    # Row-sampling cap used for this run.
    max_samples: int
    # Comparisons attempted in this run.
    total_comparisons: int
    # Comparisons that yielded a usable verdict.
    valid_comparisons: int
    # Presumably: whether OCR columns came from pull-request revisions
    # — confirm against the CLI that sets it.
    from_prs: bool = False
    # UTC ISO-8601 timestamp; auto-filled when left empty.
    timestamp: str = ""

    def __post_init__(self):
        # Default the timestamp to "now" (UTC) when the caller didn't set one.
        if not self.timestamp:
            self.timestamp = datetime.datetime.now(datetime.UTC).isoformat()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_existing_comparisons(repo_id: str) -> list[ComparisonResult]:
    """Load existing comparisons from a Hub results repo.

    The stored winner is already unswapped (canonical), so ``swapped=False``.
    Returns an empty list if the repo or config doesn't exist.
    """
    try:
        ds = load_dataset(repo_id, name="comparisons", split="train")
    except Exception as exc:
        logger.info("no_existing_comparisons", repo=repo_id, reason=str(exc))
        return []

    # Rehydrate each stored row into a ComparisonResult; optional fields
    # fall back to the same defaults the writer used.
    results = [
        ComparisonResult(
            sample_idx=row["sample_idx"],
            model_a=row["model_a"],
            model_b=row["model_b"],
            winner=row["winner"],
            reason=row.get("reason", ""),
            agreement=row.get("agreement", "1/1"),
            swapped=False,
            text_a=row.get("text_a", ""),
            text_b=row.get("text_b", ""),
            col_a=row.get("col_a", ""),
            col_b=row.get("col_b", ""),
        )
        for row in ds
    ]
    logger.info("loaded_existing_comparisons", repo=repo_id, n=len(results))
    return results
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_existing_metadata(repo_id: str) -> list[dict]:
    """Load existing metadata rows from a Hub results repo.

    Returns an empty list if the repo or config doesn't exist.
    """
    try:
        # Iteration stays inside the try so download/decode failures are
        # also treated as "no metadata yet".
        rows = load_dataset(repo_id, name="metadata", split="train")
        return [dict(entry) for entry in rows]
    except Exception as exc:
        logger.info("no_existing_metadata", repo=repo_id, reason=str(exc))
        return []
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def build_leaderboard_rows(board: Leaderboard) -> list[dict]:
    """Convert a Leaderboard into rows suitable for a Hub dataset."""
    rows: list[dict] = []
    for model, rating in board.ranked:
        games = board.wins[model] + board.losses[model] + board.ties[model]
        entry = {
            "model": model,
            "elo": round(rating),
            "wins": board.wins[model],
            "losses": board.losses[model],
            "ties": board.ties[model],
            # Win percentage over all games; 0 for models with no games.
            "win_pct": round(board.wins[model] / games * 100) if games > 0 else 0,
        }
        # Attach CI bounds only when they were computed for this model.
        if board.elo_ci and model in board.elo_ci:
            low, high = board.elo_ci[model]
            entry["elo_low"] = round(low)
            entry["elo_high"] = round(high)
        rows.append(entry)
    return rows
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_metadata_row(metadata: EvalMetadata) -> dict:
    """Convert EvalMetadata into a single row for a Hub dataset."""
    # judge_models is a list — serialize it so every column is a flat scalar.
    serialized_judges = json.dumps(metadata.judge_models)
    return {
        "source_dataset": metadata.source_dataset,
        "judge_models": serialized_judges,
        "seed": metadata.seed,
        "max_samples": metadata.max_samples,
        "total_comparisons": metadata.total_comparisons,
        "valid_comparisons": metadata.valid_comparisons,
        "from_prs": metadata.from_prs,
        "timestamp": metadata.timestamp,
    }
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def publish_results(
    repo_id: str,
    board: Leaderboard,
    metadata: EvalMetadata,
    existing_metadata: list[dict] | None = None,
) -> None:
    """Push evaluation results to Hub as a dataset with multiple configs.

    Configs:
    - (default): Leaderboard table — ``load_dataset("repo")`` returns this.
    - ``leaderboard``: Same table, named config (backward compat for viewer).
    - ``comparisons``: Full comparison log from the board (caller merges
      existing + new before ``compute_elo``, so ``board.comparison_log``
      is already the complete set).
    - ``metadata``: Append-only run log. New row is appended to
      ``existing_metadata``.
    """
    # Comparisons — skipped entirely when the log is empty.
    if board.comparison_log:
        comp_ds = Dataset.from_list(board.comparison_log)
        comp_ds.push_to_hub(repo_id, config_name="comparisons")
        logger.info("published_comparisons", repo=repo_id, n=len(board.comparison_log))

    # Leaderboard — dual push: default config + named config
    rows = build_leaderboard_rows(board)
    lb_ds = Dataset.from_list(rows)
    lb_ds.push_to_hub(repo_id)
    lb_ds.push_to_hub(repo_id, config_name="leaderboard")
    logger.info("published_leaderboard", repo=repo_id, n=len(rows))

    # Metadata — append-only: re-upload prior history plus this run's row.
    meta_row = build_metadata_row(metadata)
    all_meta = (existing_metadata or []) + [meta_row]
    Dataset.from_list(all_meta).push_to_hub(repo_id, config_name="metadata")
    logger.info("published_metadata", repo=repo_id, n=len(all_meta))

    # README — auto-generated dataset card with leaderboard
    readme = _build_readme(repo_id, rows, board, metadata)
    api = HfApi()
    api.upload_file(
        path_or_fileobj=readme.encode(),
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="dataset",
    )
    logger.info("published_readme", repo=repo_id)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _build_readme(
    repo_id: str,
    rows: list[dict],
    board: Leaderboard,
    metadata: EvalMetadata,
) -> str:
    """Build a dataset card README with the leaderboard table.

    Args:
        repo_id: Hub repo the card is for (used in the usage snippets).
        rows: Leaderboard rows from ``build_leaderboard_rows``.
        board: Leaderboard (for CI presence and comparison count).
        metadata: Run metadata (source dataset, judges).

    Returns:
        Complete README.md contents with a trailing newline.
    """
    has_ci = bool(board.elo_ci)
    source_short = metadata.source_dataset.split("/")[-1]
    # judge_models may arrive as a JSON string (round-tripped through Hub)
    # or a plain list — normalize to a list either way.
    judges = json.loads(
        metadata.judge_models
        if isinstance(metadata.judge_models, str)
        else json.dumps(metadata.judge_models)
    )
    judge_str = ", ".join(j.split("/")[-1] for j in judges) if judges else "N/A"
    n_comparisons = len(board.comparison_log)

    # YAML front matter declares the parquet layout for each named config.
    lines = [
        "---",
        "license: mit",
        "tags:",
        "  - ocr-bench",
        "  - leaderboard",
        "configs:",
        "  - config_name: default",
        "    data_files:",
        "      - split: train",
        "        path: data/train-*.parquet",
        "  - config_name: comparisons",
        "    data_files:",
        "      - split: train",
        "        path: comparisons/train-*.parquet",
        "  - config_name: leaderboard",
        "    data_files:",
        "      - split: train",
        "        path: leaderboard/train-*.parquet",
        "  - config_name: metadata",
        "    data_files:",
        "      - split: train",
        "        path: metadata/train-*.parquet",
        "---",
        "",
        f"# OCR Bench Results: {source_short}",
        "",
        "VLM-as-judge pairwise evaluation of OCR models. "
        "Rankings depend on document type — there is no single best OCR model.",
        "",
        "## Leaderboard",
        "",
    ]

    # Table header
    if has_ci:
        lines.append("| Rank | Model | ELO | 95% CI | Wins | Losses | Ties | Win% |")
        lines.append("|------|-------|-----|--------|------|--------|------|------|")
    else:
        lines.append("| Rank | Model | ELO | Wins | Losses | Ties | Win% |")
        lines.append("|------|-------|-----|------|--------|------|------|")

    for rank, row in enumerate(rows, 1):
        model = row["model"]
        elo = row["elo"]
        if has_ci and "elo_low" in row:
            # \u2013 is an en dash: renders as "low–high".
            ci = f"{row['elo_low']}\u2013{row['elo_high']}"
            lines.append(
                f"| {rank} | {model} | {elo} | {ci} "
                f"| {row['wins']} | {row['losses']} | {row['ties']} "
                f"| {row['win_pct']}% |"
            )
        else:
            lines.append(
                f"| {rank} | {model} | {elo} "
                f"| {row['wins']} | {row['losses']} | {row['ties']} "
                f"| {row['win_pct']}% |"
            )

    lines += [
        "",
        "## Details",
        "",
        f"- **Source dataset**: [`{metadata.source_dataset}`]"
        f"(https://huggingface.co/datasets/{metadata.source_dataset})",
        f"- **Judge**: {judge_str}",
        f"- **Comparisons**: {n_comparisons}",
        "- **Method**: Bradley-Terry MLE with bootstrap 95% CIs",
        "",
        "## Configs",
        "",
        f"- `load_dataset(\"{repo_id}\")` — leaderboard table",
        f"- `load_dataset(\"{repo_id}\", name=\"comparisons\")` "
        "— full pairwise comparison log",
        f"- `load_dataset(\"{repo_id}\", name=\"metadata\")` "
        "— evaluation run history",
        "",
        "*Generated by [ocr-bench](https://github.com/davanstrien/ocr-bench)*",
    ]

    return "\n".join(lines) + "\n"
|
src/ocr_bench/run.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OCR model orchestration — launch HF Jobs for multiple OCR models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import time
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
|
| 8 |
+
import structlog
|
| 9 |
+
from huggingface_hub import HfApi, get_token
|
| 10 |
+
|
| 11 |
+
logger = structlog.get_logger()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class ModelConfig:
    """Configuration for a single OCR model runnable via HF Jobs."""

    # URL of the standalone uv script that performs OCR for this model
    # (HF Jobs fetches and executes it directly).
    script: str
    # Hugging Face model repo id (e.g. "zai-org/GLM-OCR"); informational —
    # not read by the launch code in this module.
    model_id: str
    # Human-readable parameter count, e.g. "0.9B" (display only).
    size: str
    # Default HF Jobs hardware flavor; launch_ocr_jobs can override it.
    default_flavor: str = "l4x1"
    # Extra CLI arguments always appended for this model's script.
    default_args: list[str] = field(default_factory=list)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Registry of supported OCR models, keyed by the CLI slug users pass in.
# Each entry points at a standalone uv script hosted in the uv-scripts/ocr
# dataset repo.
MODEL_REGISTRY: dict[str, ModelConfig] = {
    "glm-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/glm-ocr.py",
        model_id="zai-org/GLM-OCR",
        size="0.9B",
        default_flavor="l4x1",
    ),
    "deepseek-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/deepseek-ocr-vllm.py",
        model_id="deepseek-ai/DeepSeek-OCR",
        size="4B",
        default_flavor="l4x1",
        # Always run DeepSeek in its "free" prompt mode.
        default_args=["--prompt-mode", "free"],
    ),
    "lighton-ocr-2": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/lighton-ocr2.py",
        model_id="lightonai/LightOnOCR-2-1B",
        size="1B",
        # NOTE(review): uses a larger flavor than the other 1B-class models —
        # presumably a memory requirement of its script; confirm before changing.
        default_flavor="a100-large",
    ),
    "dots-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-ocr.py",
        model_id="rednote-hilab/dots.ocr",
        size="1.7B",
        default_flavor="l4x1",
    ),
}

# Models launched when the caller does not pass an explicit selection.
DEFAULT_MODELS = ["glm-ocr", "deepseek-ocr", "lighton-ocr-2", "dots-ocr"]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
class JobRun:
    """Tracks a launched HF Job."""

    # Slug of the model this job runs (key into MODEL_REGISTRY).
    model_slug: str
    # HF Jobs identifier, used later by HfApi.inspect_job in poll_jobs.
    job_id: str
    # Web URL for viewing the job's logs/status.
    job_url: str
    # "running" until poll_jobs observes a terminal stage, then the
    # lowercased stage name ("completed", "error", "canceled", "deleted").
    status: str = "running"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def list_models() -> list[str]:
    """Return the available model slugs in alphabetical order."""
    # Iterating a dict yields its keys, so no explicit .keys() is needed.
    return sorted(MODEL_REGISTRY)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_script_args(
    input_dataset: str,
    output_repo: str,
    config_name: str,
    *,
    max_samples: int | None = None,
    shuffle: bool = False,
    seed: int = 42,
    extra_args: list[str] | None = None,
) -> list[str]:
    """Build the script_args list for run_uv_job.

    The positional dataset ids come first, followed by the target config
    name and --create-pr (results are always submitted as a pull request).
    Optional flags are appended only when they deviate from the defaults.
    """
    cli: list[str] = [
        input_dataset,
        output_repo,
        "--config",
        config_name,
        "--create-pr",
    ]
    if max_samples is not None:
        cli.extend(["--max-samples", str(max_samples)])
    if shuffle:
        cli.append("--shuffle")
    # 42 is the scripts' own default seed, so passing it would be redundant.
    if seed != 42:
        cli.extend(["--seed", str(seed)])
    if extra_args:
        cli.extend(extra_args)
    return cli
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def launch_ocr_jobs(
    input_dataset: str,
    output_repo: str,
    *,
    models: list[str] | None = None,
    max_samples: int | None = None,
    split: str = "train",
    shuffle: bool = False,
    seed: int = 42,
    flavor_override: str | None = None,
    timeout: str = "4h",
    api: HfApi | None = None,
) -> list[JobRun]:
    """Launch HF Jobs for each model. Returns list of JobRun tracking objects.

    Args:
        input_dataset: Dataset repo id with the documents to OCR.
        output_repo: Dataset repo id the jobs write results to — one config
            per model slug, submitted as a PR (--create-pr).
        models: Model slugs to run; defaults to DEFAULT_MODELS.
        max_samples: Optional cap on samples processed per job.
        split: NOTE(review): currently unused in this function — confirm
            whether the uv scripts accept a split flag and wire it through.
        shuffle: Forwarded to the scripts as --shuffle.
        seed: Forwarded as --seed only when it differs from the default 42.
        flavor_override: Hardware flavor applied to every job, overriding
            each model's default_flavor.
        timeout: HF Jobs timeout string (e.g. "4h").
        api: Optional HfApi instance (a default one is created if omitted).

    Raises:
        RuntimeError: If no HF token is available.
        ValueError: If any requested slug is not in MODEL_REGISTRY.
    """
    if api is None:
        api = HfApi()

    # The token is forwarded to each job so it can push to output_repo.
    token = get_token()
    if not token:
        raise RuntimeError("No HF token found. Log in with `hf login` or set HF_TOKEN.")

    # Validate every slug up front so we fail before launching anything.
    selected = models or DEFAULT_MODELS
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            raise ValueError(
                f"Unknown model: {slug}. Available: {', '.join(MODEL_REGISTRY.keys())}"
            )

    jobs: list[JobRun] = []
    for slug in selected:
        config = MODEL_REGISTRY[slug]
        flavor = flavor_override or config.default_flavor
        script_args = build_script_args(
            input_dataset,
            output_repo,
            slug,  # each model writes to its own dataset config named after the slug
            max_samples=max_samples,
            shuffle=shuffle,
            seed=seed,
            extra_args=config.default_args or None,
        )

        logger.info("launching_job", model=slug, flavor=flavor, script=config.script)
        job = api.run_uv_job(
            script=config.script,
            script_args=script_args,
            flavor=flavor,
            secrets={"HF_TOKEN": token},
            timeout=timeout,
        )
        jobs.append(JobRun(model_slug=slug, job_id=job.id, job_url=job.url))
        logger.info("job_launched", model=slug, job_id=job.id, url=job.url)

    return jobs
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# HF Job lifecycle stages after which a job will not change state again.
_TERMINAL_STAGES = frozenset({"COMPLETED", "ERROR", "CANCELED", "DELETED"})


def poll_jobs(
    jobs: list[JobRun],
    *,
    interval: int = 30,
    api: HfApi | None = None,
) -> list[JobRun]:
    """Poll until all jobs complete or fail. Updates status in-place and returns the list.

    Args:
        jobs: JobRun records from launch_ocr_jobs; only entries whose
            status is "running" are polled.
        interval: Seconds to sleep between polling rounds.
        api: Optional HfApi instance (a default one is created if omitted).

    Returns:
        The same list object, with each finished JobRun's status set to the
        lowercased terminal stage (e.g. "completed", "error").
    """
    if api is None:
        api = HfApi()

    # Track only jobs that still need watching; finished ones keep their status.
    pending = {j.job_id: j for j in jobs if j.status == "running"}

    while pending:
        # Sleep first: jobs launched moments ago cannot be done yet.
        time.sleep(interval)
        still_running: dict[str, JobRun] = {}
        for job_id, job_run in pending.items():
            info = api.inspect_job(job_id=job_id)
            stage = info.status.stage
            if stage in _TERMINAL_STAGES:
                job_run.status = stage.lower()
                logger.info("job_finished", model=job_run.model_slug, status=job_run.status)
            else:
                still_running[job_id] = job_run
        pending = still_running
        if pending:
            slugs = [j.model_slug for j in pending.values()]
            logger.info("jobs_pending", models=slugs)

    return jobs
|
src/ocr_bench/space.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HF Space entry point for ocr-bench viewer."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import uvicorn
|
| 6 |
+
|
| 7 |
+
from ocr_bench.web import create_app
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
    """Serve the ocr-bench viewer as an HF Space.

    Reads the results dataset id from the REPOS env var (comma-separated;
    only the first entry is served) and runs the FastAPI app with uvicorn.
    """
    repos = os.environ.get("REPOS", "davanstrien/bpl-ocr-bench-results")
    # Only the first comma-separated repo is served by this entry point.
    repo_id = repos.split(",")[0].strip()
    app = create_app(repo_id)
    # HF Spaces routes traffic to 7860 by default; honor $PORT when set so
    # the same entry point works behind other PaaS-style launchers.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()
|
src/ocr_bench/static/style.css
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* ocr-bench viewer — Tufte-inspired minimal styles */
|
| 2 |
+
|
| 3 |
+
*,
|
| 4 |
+
*::before,
|
| 5 |
+
*::after {
|
| 6 |
+
box-sizing: border-box;
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
body {
|
| 10 |
+
font-family: system-ui, -apple-system, sans-serif;
|
| 11 |
+
color: #333;
|
| 12 |
+
background: #fff;
|
| 13 |
+
margin: 0;
|
| 14 |
+
padding: 0;
|
| 15 |
+
line-height: 1.5;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
.container {
|
| 19 |
+
max-width: 960px;
|
| 20 |
+
margin: 0 auto;
|
| 21 |
+
padding: 0 1.5rem 3rem;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
/* Navigation */
|
| 25 |
+
nav {
|
| 26 |
+
border-bottom: 1px solid #ddd;
|
| 27 |
+
padding: 0.75rem 0;
|
| 28 |
+
margin-bottom: 2rem;
|
| 29 |
+
display: flex;
|
| 30 |
+
align-items: baseline;
|
| 31 |
+
gap: 2rem;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
nav .brand {
|
| 35 |
+
font-weight: 600;
|
| 36 |
+
color: #333;
|
| 37 |
+
text-decoration: none;
|
| 38 |
+
font-size: 0.9rem;
|
| 39 |
+
letter-spacing: 0.02em;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
nav a {
|
| 43 |
+
color: #666;
|
| 44 |
+
text-decoration: none;
|
| 45 |
+
font-size: 0.85rem;
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
nav a:hover,
|
| 49 |
+
nav a.active {
|
| 50 |
+
color: #333;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
nav a.active {
|
| 54 |
+
border-bottom: 2px solid #333;
|
| 55 |
+
padding-bottom: 2px;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
/* Comparison layout */
|
| 59 |
+
.comparison-columns {
|
| 60 |
+
display: grid;
|
| 61 |
+
grid-template-columns: 1fr 1fr;
|
| 62 |
+
gap: 2rem;
|
| 63 |
+
margin: 1.5rem 0;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
.ocr-column h3 {
|
| 67 |
+
font-size: 0.85rem;
|
| 68 |
+
font-weight: 600;
|
| 69 |
+
color: #666;
|
| 70 |
+
margin: 0 0 0.5rem;
|
| 71 |
+
padding-bottom: 0.35rem;
|
| 72 |
+
border-bottom: 1px solid #ddd;
|
| 73 |
+
letter-spacing: 0.02em;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.ocr-column h3.revealed {
|
| 77 |
+
color: #333;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
.ocr-text {
|
| 81 |
+
font-family: "SF Mono", "Menlo", "Consolas", monospace;
|
| 82 |
+
font-size: 0.82rem;
|
| 83 |
+
line-height: 1.6;
|
| 84 |
+
white-space: pre-wrap;
|
| 85 |
+
word-break: break-word;
|
| 86 |
+
max-height: 50vh;
|
| 87 |
+
overflow-y: auto;
|
| 88 |
+
padding: 0.25rem 0;
|
| 89 |
+
color: #444;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
/* Navigation header */
|
| 93 |
+
.comp-nav {
|
| 94 |
+
display: flex;
|
| 95 |
+
justify-content: flex-end;
|
| 96 |
+
align-items: baseline;
|
| 97 |
+
gap: 0.75rem;
|
| 98 |
+
margin-bottom: 0.5rem;
|
| 99 |
+
color: #999;
|
| 100 |
+
font-size: 0.8rem;
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
.comp-nav a {
|
| 104 |
+
color: #999;
|
| 105 |
+
text-decoration: none;
|
| 106 |
+
font-size: 0.85rem;
|
| 107 |
+
padding: 0.15rem 0.4rem;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
.comp-nav a:hover {
|
| 111 |
+
color: #333;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
/* Vote prompt */
|
| 115 |
+
.vote-prompt {
|
| 116 |
+
text-align: center;
|
| 117 |
+
font-size: 0.8rem;
|
| 118 |
+
color: #999;
|
| 119 |
+
margin: 1.5rem 0 0.5rem;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
/* Vote buttons */
|
| 123 |
+
.vote-row {
|
| 124 |
+
text-align: center;
|
| 125 |
+
margin: 0.25rem 0 0.5rem;
|
| 126 |
+
display: flex;
|
| 127 |
+
justify-content: center;
|
| 128 |
+
gap: 0.5rem;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
.vote-btn {
|
| 132 |
+
display: inline-block;
|
| 133 |
+
color: #555;
|
| 134 |
+
text-decoration: none;
|
| 135 |
+
padding: 0.35rem 1rem;
|
| 136 |
+
border: 1px solid #ddd;
|
| 137 |
+
border-radius: 4px;
|
| 138 |
+
font-size: 0.85rem;
|
| 139 |
+
transition: border-color 0.15s, color 0.15s;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.vote-btn:hover {
|
| 143 |
+
color: #333;
|
| 144 |
+
border-color: #999;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
.vote-btn.vote-tie {
|
| 148 |
+
color: #888;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/* Hints below vote buttons */
|
| 152 |
+
.vote-hints {
|
| 153 |
+
text-align: center;
|
| 154 |
+
margin: 0.5rem 0 1rem;
|
| 155 |
+
font-size: 0.75rem;
|
| 156 |
+
color: #bbb;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
.vote-hints a {
|
| 160 |
+
color: #999;
|
| 161 |
+
text-decoration: none;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
.vote-hints a:hover {
|
| 165 |
+
color: #666;
|
| 166 |
+
text-decoration: underline;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
.vote-hints .separator {
|
| 170 |
+
color: #ddd;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
.vote-hints kbd {
|
| 174 |
+
font-family: system-ui, sans-serif;
|
| 175 |
+
font-size: 0.7rem;
|
| 176 |
+
padding: 0.05rem 0.3rem;
|
| 177 |
+
border: 1px solid #ddd;
|
| 178 |
+
border-radius: 3px;
|
| 179 |
+
background: #f8f8f8;
|
| 180 |
+
color: #999;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
/* Legacy reveal-row (kept for compat) */
|
| 184 |
+
.reveal-row {
|
| 185 |
+
text-align: right;
|
| 186 |
+
margin: 0.25rem 0 1rem;
|
| 187 |
+
font-size: 0.8rem;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
.reveal-row a {
|
| 191 |
+
color: #999;
|
| 192 |
+
text-decoration: none;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
.reveal-row a:hover {
|
| 196 |
+
color: #666;
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
/* Verdict display */
|
| 200 |
+
.verdict {
|
| 201 |
+
margin: 1rem 0;
|
| 202 |
+
font-size: 0.85rem;
|
| 203 |
+
color: #555;
|
| 204 |
+
line-height: 1.6;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
.verdict .agreement {
|
| 208 |
+
font-weight: 500;
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
.verdict .agreement.agreed {
|
| 212 |
+
color: #457b4d;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
.verdict .agreement.soft-disagree {
|
| 216 |
+
color: #a07828;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
.verdict .agreement.hard-disagree {
|
| 220 |
+
color: #b04040;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
.verdict .reason {
|
| 224 |
+
font-style: italic;
|
| 225 |
+
color: #777;
|
| 226 |
+
display: block;
|
| 227 |
+
margin-top: 0.25rem;
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
/* Document image */
|
| 231 |
+
.doc-image {
|
| 232 |
+
margin: 1.5rem 0;
|
| 233 |
+
text-align: center;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
.doc-image img {
|
| 237 |
+
max-width: 100%;
|
| 238 |
+
height: auto;
|
| 239 |
+
max-height: 60vh;
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
/* Leaderboard table */
|
| 243 |
+
table {
|
| 244 |
+
width: 100%;
|
| 245 |
+
border-collapse: collapse;
|
| 246 |
+
font-size: 0.85rem;
|
| 247 |
+
margin: 1.5rem 0;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
thead th {
|
| 251 |
+
text-align: left;
|
| 252 |
+
font-weight: 600;
|
| 253 |
+
padding: 0.5rem 0.75rem;
|
| 254 |
+
border-bottom: 2px solid #333;
|
| 255 |
+
color: #333;
|
| 256 |
+
font-size: 0.8rem;
|
| 257 |
+
letter-spacing: 0.02em;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
thead th.num {
|
| 261 |
+
text-align: right;
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
tbody td {
|
| 265 |
+
padding: 0.4rem 0.75rem;
|
| 266 |
+
border-bottom: 1px solid #eee;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
tbody td.num {
|
| 270 |
+
text-align: right;
|
| 271 |
+
font-variant-numeric: tabular-nums;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
tbody td.model {
|
| 275 |
+
font-weight: 500;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
tbody tr:hover {
|
| 279 |
+
background: #fafafa;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
/* Filters */
|
| 283 |
+
.filters {
|
| 284 |
+
display: flex;
|
| 285 |
+
gap: 1rem;
|
| 286 |
+
margin-bottom: 1rem;
|
| 287 |
+
align-items: center;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
.filters label {
|
| 291 |
+
font-size: 0.8rem;
|
| 292 |
+
color: #666;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
.filters select {
|
| 296 |
+
font-size: 0.8rem;
|
| 297 |
+
padding: 0.25rem 0.5rem;
|
| 298 |
+
border: 1px solid #ddd;
|
| 299 |
+
border-radius: 3px;
|
| 300 |
+
background: #fff;
|
| 301 |
+
color: #333;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/* Stats panel */
|
| 305 |
+
.stats-panel {
|
| 306 |
+
color: #888;
|
| 307 |
+
font-size: 0.8rem;
|
| 308 |
+
padding: 1rem 0;
|
| 309 |
+
border-top: 1px solid #eee;
|
| 310 |
+
margin-top: 2rem;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
.stats-panel .calibrated {
|
| 314 |
+
color: #457b4d;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
.stats-panel .warning {
|
| 318 |
+
color: #b04040;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
/* Pair summary table */
|
| 322 |
+
.pair-summary {
|
| 323 |
+
margin-bottom: 1rem;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
.pair-table {
|
| 327 |
+
width: auto;
|
| 328 |
+
font-size: 0.8rem;
|
| 329 |
+
color: #888;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
.pair-table th {
|
| 333 |
+
font-size: 0.75rem;
|
| 334 |
+
color: #999;
|
| 335 |
+
font-weight: 500;
|
| 336 |
+
padding: 0.2rem 0.6rem;
|
| 337 |
+
border-bottom: 1px solid #ddd;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
.pair-table td {
|
| 341 |
+
padding: 0.15rem 0.6rem;
|
| 342 |
+
border-bottom: 1px solid #f0f0f0;
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
/* HTMX loading indicator */
|
| 346 |
+
.htmx-indicator {
|
| 347 |
+
opacity: 0;
|
| 348 |
+
transition: opacity 200ms ease-in;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
.htmx-request .htmx-indicator,
|
| 352 |
+
.htmx-request.htmx-indicator {
|
| 353 |
+
opacity: 1;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
/* Empty state */
|
| 357 |
+
.empty {
|
| 358 |
+
text-align: center;
|
| 359 |
+
color: #999;
|
| 360 |
+
padding: 3rem 0;
|
| 361 |
+
font-size: 0.9rem;
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
/* Responsive */
|
| 365 |
+
@media (max-width: 768px) {
|
| 366 |
+
.comparison-columns {
|
| 367 |
+
grid-template-columns: 1fr;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
.container {
|
| 371 |
+
padding: 0 1rem 2rem;
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
table {
|
| 375 |
+
display: block;
|
| 376 |
+
overflow-x: auto;
|
| 377 |
+
-webkit-overflow-scrolling: touch;
|
| 378 |
+
}
|
| 379 |
+
}
|
src/ocr_bench/templates/base.html
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{% block title %}OCR Bench{% endblock %}</title>
<link rel="stylesheet" href="/static/style.css">
<!-- HTMX drives the partial swaps used by the comparison views. -->
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
</head>
<body>
<div class="container">
<nav>
<a href="/" class="brand">ocr-bench</a>
<a href="/leaderboard" {% if active_tab == "leaderboard" %}class="active"{% endif %}>Leaderboard</a>
<a href="/comparisons" {% if active_tab == "comparisons" %}class="active"{% endif %}>Comparisons</a>
</nav>
{% block content %}{% endblock %}
</div>

<!-- Global keyboard shortcuts: arrows navigate comparisons, a/t/b vote,
     r reveals the judge verdict. Each key simply clicks the matching
     data-* anchor rendered by the comparison card, so HTMX performs the
     actual request; keys are no-ops on pages without those anchors. -->
<script>
document.addEventListener("keydown", function(e) {
  // Ignore when focus is in input/select/textarea
  var tag = document.activeElement.tagName.toLowerCase();
  if (tag === "input" || tag === "select" || tag === "textarea") return;

  if (e.key === "ArrowLeft") {
    var prev = document.querySelector("[data-nav='prev']");
    if (prev) { prev.click(); e.preventDefault(); }
  } else if (e.key === "ArrowRight") {
    var next = document.querySelector("[data-nav='next']");
    if (next) { next.click(); e.preventDefault(); }
  } else if (e.key === "a" || e.key === "A") {
    var voteA = document.querySelector("[data-vote='A']");
    if (voteA) { voteA.click(); e.preventDefault(); }
  } else if (e.key === "b" || e.key === "B") {
    var voteB = document.querySelector("[data-vote='B']");
    if (voteB) { voteB.click(); e.preventDefault(); }
  } else if (e.key === "t" || e.key === "T") {
    var voteTie = document.querySelector("[data-vote='tie']");
    if (voteTie) { voteTie.click(); e.preventDefault(); }
  } else if (e.key === "r" || e.key === "R") {
    var reveal = document.querySelector("[data-action='reveal']");
    if (reveal) { reveal.click(); e.preventDefault(); }
  }
});
</script>
</body>
</html>
|
src/ocr_bench/templates/comparison_card.html
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% if comp %}
|
| 2 |
+
<div class="comp-nav">
|
| 3 |
+
<span>{{ nav_idx + 1 }} of {{ nav_total }}</span>
|
| 4 |
+
{% if nav_idx > 0 %}
|
| 5 |
+
<a href="#" data-nav="prev"
|
| 6 |
+
hx-get="/comparisons/{{ nav_idx - 1 }}{% if winner_filter and winner_filter != 'All' %}?winner={{ winner_filter }}{% endif %}{% if model_filter and model_filter != 'All' %}{{ '&' if winner_filter and winner_filter != 'All' else '?' }}model={{ model_filter }}{% endif %}"
|
| 7 |
+
hx-target="#comparison-container">←</a>
|
| 8 |
+
{% endif %}
|
| 9 |
+
{% if nav_idx < nav_total - 1 %}
|
| 10 |
+
<a href="#" data-nav="next"
|
| 11 |
+
hx-get="/comparisons/{{ nav_idx + 1 }}{% if winner_filter and winner_filter != 'All' %}?winner={{ winner_filter }}{% endif %}{% if model_filter and model_filter != 'All' %}{{ '&' if winner_filter and winner_filter != 'All' else '?' }}model={{ model_filter }}{% endif %}"
|
| 12 |
+
hx-target="#comparison-container">→</a>
|
| 13 |
+
{% endif %}
|
| 14 |
+
</div>
|
| 15 |
+
|
| 16 |
+
<div class="comparison-columns">
|
| 17 |
+
<div class="ocr-column">
|
| 18 |
+
{% if revealed %}
|
| 19 |
+
<h3 class="revealed">{{ model_a_name }}</h3>
|
| 20 |
+
{% else %}
|
| 21 |
+
<h3>A</h3>
|
| 22 |
+
{% endif %}
|
| 23 |
+
<div class="ocr-text">{{ display_text_a }}</div>
|
| 24 |
+
</div>
|
| 25 |
+
<div class="ocr-column">
|
| 26 |
+
{% if revealed %}
|
| 27 |
+
<h3 class="revealed">{{ model_b_name }}</h3>
|
| 28 |
+
{% else %}
|
| 29 |
+
<h3>B</h3>
|
| 30 |
+
{% endif %}
|
| 31 |
+
<div class="ocr-text">{{ display_text_b }}</div>
|
| 32 |
+
</div>
|
| 33 |
+
</div>
|
| 34 |
+
|
| 35 |
+
{% if not voted %}
|
| 36 |
+
<div class="vote-prompt">Which OCR output is better?</div>
|
| 37 |
+
<div class="vote-row">
|
| 38 |
+
<a href="#" data-vote="A" class="vote-btn"
|
| 39 |
+
hx-post="/vote/{{ comp_idx }}"
|
| 40 |
+
hx-vals='{"winner": "A"}'
|
| 41 |
+
hx-target="#comparison-container">A is better</a>
|
| 42 |
+
<a href="#" data-vote="tie" class="vote-btn vote-tie"
|
| 43 |
+
hx-post="/vote/{{ comp_idx }}"
|
| 44 |
+
hx-vals='{"winner": "tie"}'
|
| 45 |
+
hx-target="#comparison-container">Tie</a>
|
| 46 |
+
<a href="#" data-vote="B" class="vote-btn"
|
| 47 |
+
hx-post="/vote/{{ comp_idx }}"
|
| 48 |
+
hx-vals='{"winner": "B"}'
|
| 49 |
+
hx-target="#comparison-container">B is better</a>
|
| 50 |
+
</div>
|
| 51 |
+
<div class="vote-hints">
|
| 52 |
+
{% if not revealed %}
|
| 53 |
+
<a href="#" data-action="reveal"
|
| 54 |
+
hx-get="/reveal/{{ comp_idx }}"
|
| 55 |
+
hx-target="#comparison-container">show judge verdict</a>
|
| 56 |
+
<span class="separator">·</span>
|
| 57 |
+
{% endif %}
|
| 58 |
+
<span class="keys">keys: <kbd>a</kbd> <kbd>t</kbd> <kbd>b</kbd> vote · <kbd>←</kbd> <kbd>→</kbd> navigate{% if not revealed %} · <kbd>r</kbd> reveal{% endif %}</span>
|
| 59 |
+
</div>
|
| 60 |
+
{% endif %}
|
| 61 |
+
|
| 62 |
+
{% if revealed %}
|
| 63 |
+
<div class="verdict">
|
| 64 |
+
{% if voted %}
|
| 65 |
+
Judge: {{ judge_verdict }}
|
| 66 |
+
· You: {{ human_vote }}
|
| 67 |
+
· <span class="agreement {{ agreement_class }}">{{ agreement_word }}</span>
|
| 68 |
+
{% else %}
|
| 69 |
+
Judge: {{ judge_verdict }}
|
| 70 |
+
{% endif %}
|
| 71 |
+
{% if reason %}
|
| 72 |
+
<span class="reason">"{{ reason }}"</span>
|
| 73 |
+
{% endif %}
|
| 74 |
+
</div>
|
| 75 |
+
{% if just_voted and next_url %}
|
| 76 |
+
<div hx-get="{{ next_url }}" hx-trigger="load delay:1.2s" hx-target="#comparison-container"></div>
|
| 77 |
+
{% endif %}
|
| 78 |
+
{% endif %}
|
| 79 |
+
|
| 80 |
+
{% if has_image %}
|
| 81 |
+
<div class="doc-image">
|
| 82 |
+
<img src="/image/{{ sample_idx }}" alt="Document image" loading="lazy">
|
| 83 |
+
</div>
|
| 84 |
+
{% endif %}
|
| 85 |
+
|
| 86 |
+
{% else %}
|
| 87 |
+
<div class="empty">No comparisons match the current filters.</div>
|
| 88 |
+
{% endif %}
|
src/ocr_bench/templates/comparisons.html
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% extends "base.html" %}
|
| 2 |
+
{% block title %}Comparisons — OCR Bench{% endblock %}
|
| 3 |
+
{% block content %}
|
| 4 |
+
<div class="filters">
|
| 5 |
+
<label>Winner
|
| 6 |
+
<select name="winner"
|
| 7 |
+
hx-get="/comparisons/filter"
|
| 8 |
+
hx-target="#comparison-container"
|
| 9 |
+
hx-include="[name='model']">
|
| 10 |
+
<option value="All" {% if winner_filter == "All" %}selected{% endif %}>All</option>
|
| 11 |
+
<option value="A" {% if winner_filter == "A" %}selected{% endif %}>A</option>
|
| 12 |
+
<option value="B" {% if winner_filter == "B" %}selected{% endif %}>B</option>
|
| 13 |
+
<option value="tie" {% if winner_filter == "tie" %}selected{% endif %}>tie</option>
|
| 14 |
+
</select>
|
| 15 |
+
</label>
|
| 16 |
+
<label>Model
|
| 17 |
+
<select name="model"
|
| 18 |
+
hx-get="/comparisons/filter"
|
| 19 |
+
hx-target="#comparison-container"
|
| 20 |
+
hx-include="[name='winner']">
|
| 21 |
+
<option value="All" {% if model_filter == "All" %}selected{% endif %}>All</option>
|
| 22 |
+
{% for m in models %}
|
| 23 |
+
<option value="{{ m }}" {% if model_filter == m %}selected{% endif %}>{{ m }}</option>
|
| 24 |
+
{% endfor %}
|
| 25 |
+
</select>
|
| 26 |
+
</label>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
{% if pair_summary %}
|
| 30 |
+
<div class="pair-summary">{{ pair_summary | safe }}</div>
|
| 31 |
+
{% endif %}
|
| 32 |
+
|
| 33 |
+
<div id="comparison-container">
|
| 34 |
+
{% include "comparison_card.html" %}
|
| 35 |
+
</div>
|
| 36 |
+
|
| 37 |
+
<div id="stats-panel" hx-get="/stats" hx-trigger="vote-recorded from:body" hx-swap="innerHTML">
|
| 38 |
+
{% include "stats_panel.html" %}
|
| 39 |
+
</div>
|
| 40 |
+
{% endblock %}
|
src/ocr_bench/templates/leaderboard.html
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% extends "base.html" %}
|
| 2 |
+
{% block title %}Leaderboard — OCR Bench{% endblock %}
|
| 3 |
+
{% block content %}
|
| 4 |
+
<h2 style="font-size: 1.1rem; font-weight: 600; margin-bottom: 0.25rem;">Leaderboard</h2>
|
| 5 |
+
<p style="font-size: 0.8rem; color: #888; margin-top: 0;">{{ repo_id }}</p>
|
| 6 |
+
|
| 7 |
+
<table>
|
| 8 |
+
<thead>
|
| 9 |
+
<tr>
|
| 10 |
+
<th>#</th>
|
| 11 |
+
<th>Model</th>
|
| 12 |
+
<th class="num">Judge ELO</th>
|
| 13 |
+
{% if has_ci %}<th class="num">95% CI</th>{% endif %}
|
| 14 |
+
<th class="num">Wins</th>
|
| 15 |
+
<th class="num">Losses</th>
|
| 16 |
+
<th class="num">Ties</th>
|
| 17 |
+
<th class="num">Win%</th>
|
| 18 |
+
{% if has_human_elo %}
|
| 19 |
+
<th class="num">Human ELO</th>
|
| 20 |
+
<th class="num">H-Win%</th>
|
| 21 |
+
{% endif %}
|
| 22 |
+
</tr>
|
| 23 |
+
</thead>
|
| 24 |
+
<tbody>
|
| 25 |
+
{% for row in rows %}
|
| 26 |
+
<tr>
|
| 27 |
+
<td>{{ loop.index }}</td>
|
| 28 |
+
<td class="model">{{ row.model_short }}</td>
|
| 29 |
+
<td class="num">{{ row.elo }}</td>
|
| 30 |
+
{% if has_ci %}<td class="num">{{ row.elo_low }}–{{ row.elo_high }}</td>{% endif %}
|
| 31 |
+
<td class="num">{{ row.wins }}</td>
|
| 32 |
+
<td class="num">{{ row.losses }}</td>
|
| 33 |
+
<td class="num">{{ row.ties }}</td>
|
| 34 |
+
<td class="num">{{ row.win_pct }}%</td>
|
| 35 |
+
{% if has_human_elo %}
|
| 36 |
+
<td class="num">{{ row.human_elo if row.human_elo is not none else "—" }}</td>
|
| 37 |
+
<td class="num">{{ row.human_win_pct if row.human_win_pct is not none else "—" }}</td>
|
| 38 |
+
{% endif %}
|
| 39 |
+
</tr>
|
| 40 |
+
{% endfor %}
|
| 41 |
+
</tbody>
|
| 42 |
+
</table>
|
| 43 |
+
{% endblock %}
|
src/ocr_bench/templates/stats_panel.html
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{# Compact human-validation summary; re-fetched via HTMX after each vote. #}
{% if vote_count > 0 %}
<span>{{ vote_count }} vote{{ "s" if vote_count != 1 else "" }}</span>
·
<span>{{ agreement_pct }}% agree</span>
{# The 15-vote floor matches MIN_ANNOTATIONS_FOR_CONFIDENCE in validate.py;
   the 25% hard-disagree cutoff is set here — TODO keep the two in sync. #}
{% if hard_disagree_rate > 25 %}
· <span class="warning">judge may be miscalibrated</span>
{% elif vote_count >= 15 %}
· <span class="calibrated">judge well-calibrated</span>
{% endif %}
{% endif %}
|
src/ocr_bench/validate.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Blind human A/B validation for OCR judge quality."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
import structlog
|
| 13 |
+
|
| 14 |
+
logger = structlog.get_logger()
|
| 15 |
+
|
| 16 |
+
# Confidence thresholds
|
| 17 |
+
MIN_ANNOTATIONS_FOR_CONFIDENCE = 15
|
| 18 |
+
HIGH_AGREEMENT_THRESHOLD = 0.75
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class AgreementStats:
    """Tallies of how human annotations line up with VLM judge verdicts.

    ``agree`` counts identical verdicts, ``soft_disagree`` counts cases where
    exactly one side called a tie, and ``hard_disagree`` counts cases where
    both picked winners but opposite ones.
    """

    agree: int = 0
    soft_disagree: int = 0  # one picks tie, other picks winner
    hard_disagree: int = 0  # both pick winners but opposite
    total: int = 0

    @property
    def agreement_rate(self) -> float:
        """Agreement rate treating soft disagreements as partial agreement."""
        if not self.total:
            return 0.0
        return (self.agree + self.soft_disagree) / self.total

    @property
    def hard_disagree_rate(self) -> float:
        """Fraction of annotations where human and judge picked opposite winners."""
        if not self.total:
            return 0.0
        return self.hard_disagree / self.total
| 39 |
+
|
| 40 |
+
@dataclass
class ValidationComparison:
    """A single comparison for human validation.

    Built from enriched comparison data published by the judge. The human
    annotator sees ``display_text_a``/``display_text_b``, which may be
    swapped relative to canonical A/B order to control position bias.
    """

    comparison_id: int  # sequential ID after ordering; used to track votes
    sample_idx: int  # index of the source sample being compared
    model_a: str  # canonical model A identifier
    model_b: str  # canonical model B identifier
    winner: str  # judge's verdict in canonical order (hidden during annotation)
    reason: str  # judge's free-text rationale for the verdict
    agreement: str  # jury agreement (e.g. "2/2"; "1/2" means a split vote)
    text_a: str  # OCR text from model A (canonical order)
    text_b: str  # OCR text from model B (canonical order)
    col_a: str  # source column for model A's text (may be empty)
    col_b: str  # source column for model B's text (may be empty)
    swapped: bool  # position-bias randomization for human display
    display_text_a: str = ""  # text shown to human in slot A (may be swapped)
    display_text_b: str = ""  # text shown to human in slot B (may be swapped)
| 62 |
+
|
| 63 |
+
@dataclass
class ValidationSession:
    """Holds state for a validation session."""

    comparisons: list[ValidationComparison]  # ordered comparisons to annotate
    model_names: list[str]  # models under comparison
    metadata: dict[str, Any] = field(default_factory=dict)  # session-level info
    annotations: list[dict[str, Any]] = field(default_factory=list)  # human votes so far
    completed_ids: set[int] = field(default_factory=set)  # comparison_ids already annotated
|
| 73 |
+
|
| 74 |
+
def _is_split_jury(agreement: str) -> bool:
|
| 75 |
+
"""Check if a jury vote was split (e.g. '1/2' not '2/2')."""
|
| 76 |
+
parts = agreement.split("/")
|
| 77 |
+
return len(parts) == 2 and parts[0] != parts[1]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _interleave_by_sample(
|
| 81 |
+
comparisons: list[ValidationComparison],
|
| 82 |
+
) -> list[ValidationComparison]:
|
| 83 |
+
"""Interleave comparisons so you see different samples before repeating."""
|
| 84 |
+
by_sample: dict[int, list[ValidationComparison]] = defaultdict(list)
|
| 85 |
+
for comp in comparisons:
|
| 86 |
+
by_sample[comp.sample_idx].append(comp)
|
| 87 |
+
|
| 88 |
+
result: list[ValidationComparison] = []
|
| 89 |
+
queues = list(by_sample.values())
|
| 90 |
+
while queues:
|
| 91 |
+
next_round = []
|
| 92 |
+
for q in queues:
|
| 93 |
+
result.append(q.pop(0))
|
| 94 |
+
if q:
|
| 95 |
+
next_round.append(q)
|
| 96 |
+
queues = next_round
|
| 97 |
+
return result
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_validation_comparisons(
    comparison_rows: list[dict[str, Any]],
    *,
    n: int | None = None,
    prioritize_splits: bool = True,
    seed: int = 42,
) -> list[ValidationComparison]:
    """Build validation comparisons from published judge results.

    Args:
        comparison_rows: Rows from the comparisons config of a results dataset.
        n: Max number of comparisons to include (None = all).
        prioritize_splits: Show split-jury cases first (most informative).
        seed: Random seed for position-bias randomization.

    Returns:
        Ordered comparisons with fresh sequential ``comparison_id`` values.
    """
    from dataclasses import replace

    rng = random.Random(seed)

    comps: list[ValidationComparison] = []
    for i, row in enumerate(comparison_rows):
        # Randomly flip display order so annotators can't learn a position bias.
        swapped = rng.random() < 0.5
        text_a = row.get("text_a", "")
        text_b = row.get("text_b", "")
        display_a, display_b = (text_b, text_a) if swapped else (text_a, text_b)

        comps.append(
            ValidationComparison(
                comparison_id=i,
                sample_idx=row.get("sample_idx", i),
                model_a=row.get("model_a", ""),
                model_b=row.get("model_b", ""),
                winner=row.get("winner", "tie"),
                reason=row.get("reason", ""),
                agreement=row.get("agreement", "1/1"),
                text_a=text_a,
                text_b=text_b,
                col_a=row.get("col_a", ""),
                col_b=row.get("col_b", ""),
                swapped=swapped,
                display_text_a=display_a,
                display_text_b=display_b,
            )
        )

    if prioritize_splits:
        # Split-jury verdicts are the most informative to validate, so surface
        # them first; within each group, interleave across samples.
        splits = [c for c in comps if _is_split_jury(c.agreement)]
        unanimous = [c for c in comps if not _is_split_jury(c.agreement)]
        ordered = _interleave_by_sample(splits) + _interleave_by_sample(unanimous)
    else:
        ordered = _interleave_by_sample(comps)

    if n is not None and n < len(ordered):
        ordered = ordered[:n]

    # Re-assign comparison IDs after reordering. dataclasses.replace copies
    # every field, so new fields can't silently be dropped here.
    return [replace(c, comparison_id=i) for i, c in enumerate(ordered)]
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def compute_agreement(
    annotations: list[dict[str, Any]],
    comparisons: list[ValidationComparison],
) -> AgreementStats:
    """Tally how often human annotations match the judge's verdicts."""
    lookup = {c.comparison_id: c for c in comparisons}
    stats = AgreementStats()

    for ann in annotations:
        comp = lookup.get(ann.get("comparison_id"))
        if not comp:
            continue

        # Map the human's displayed choice back to canonical A/B order.
        human = ann["winner"]
        if comp.swapped and human in ("A", "B"):
            human = "B" if human == "A" else "A"

        stats.total += 1
        if human == comp.winner:
            stats.agree += 1
        elif "tie" in (human, comp.winner):
            stats.soft_disagree += 1
        else:
            stats.hard_disagree += 1

    return stats
+
|
| 212 |
+
|
| 213 |
+
def compute_human_elo(
    annotations: list[dict[str, Any]],
    comparisons: list[ValidationComparison],
) -> Any:
    """Compute ELO leaderboard from human annotations.

    Returns a ``Leaderboard`` from ``elo.py``, or None if no annotations.
    """
    from ocr_bench.elo import ComparisonResult, compute_elo

    lookup = {c.comparison_id: c for c in comparisons}
    models: set[str] = set()
    results: list[ComparisonResult] = []

    for ann in annotations:
        comp = lookup.get(ann.get("comparison_id"))
        if not comp:
            continue

        # Undo the display swap so the vote refers to canonical model A/B.
        vote = ann["winner"]
        if comp.swapped and vote in ("A", "B"):
            vote = "B" if vote == "A" else "A"

        models.update((comp.model_a, comp.model_b))
        results.append(
            ComparisonResult(
                sample_idx=comp.sample_idx,
                model_a=comp.model_a,
                model_b=comp.model_b,
                winner=vote,
            )
        )

    if not results:
        return None

    return compute_elo(results, sorted(models))
+
|
| 256 |
+
|
| 257 |
+
def save_annotations(
    path: str,
    metadata: dict[str, Any],
    annotations: list[dict[str, Any]],
) -> None:
    """Write annotations to ``path`` as JSON via an atomic replace."""
    payload = {"metadata": metadata, "annotations": annotations}
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=2)
    # os.replace is atomic, so readers never observe a partially written file.
    os.replace(tmp_path, path)
+
|
| 269 |
+
|
| 270 |
+
def load_annotations(path: str) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Read annotations JSON from ``path``; returns (metadata, annotations).

    A missing file yields empty results rather than raising.
    """
    if not os.path.exists(path):
        return {}, []
    with open(path) as f:
        payload = json.load(f)
    return payload.get("metadata", {}), payload.get("annotations", [])
+
|
| 278 |
+
|
| 279 |
+
def _agreement_banner(stats: AgreementStats) -> str:
    """Render an agreement summary line, with a confidence note when earned."""
    if not stats.total:
        return ""

    pieces = [f"Agree: {stats.agree}"]
    if stats.soft_disagree:
        pieces.append(f"Soft: {stats.soft_disagree}")
    if stats.hard_disagree:
        pieces.append(f"**Hard: {stats.hard_disagree}**")
    pieces.append(f"(of {stats.total})")

    note = ""
    # Only draw a conclusion once enough annotations have been collected.
    if stats.total >= MIN_ANNOTATIONS_FOR_CONFIDENCE:
        rate = stats.hard_disagree_rate
        if rate == 0:
            note = (
                f" -- No hard disagreements after {stats.total} annotations. "
                "Judge rankings reliable for this domain."
            )
        elif rate <= 0.1:
            note = (
                f" -- Very few hard disagreements ({stats.hard_disagree}). "
                "Rankings likely trustworthy."
            )
        elif rate > 0.25:
            note = (
                f" -- Many hard disagreements ({stats.hard_disagree}/{stats.total}). "
                "Judge may not be calibrated for this content."
            )

    return f"Judge: {' | '.join(pieces)}{note}"
| 310 |
+
|
| 311 |
+
|
src/ocr_bench/viewer.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Results viewer — data loading and helpers for OCR bench results."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING, Any
|
| 6 |
+
|
| 7 |
+
import structlog
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
logger = structlog.get_logger()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_results(repo_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Load leaderboard and comparisons from a Hub results dataset.

    Tries the default config first (new repos), then falls back to the
    named ``leaderboard`` config (old repos).

    Returns:
        (leaderboard_rows, comparison_rows)
    """
    try:
        board = load_dataset(repo_id, split="train")
    except Exception:
        # Older repos published the leaderboard under a named config.
        board = load_dataset(repo_id, name="leaderboard", split="train")
    leaderboard_rows = [dict(row) for row in board]

    try:
        comps = load_dataset(repo_id, name="comparisons", split="train")
    except Exception:
        logger.warning("no_comparisons_config", repo=repo_id)
        return leaderboard_rows, []

    return leaderboard_rows, [dict(row) for row in comps]
+
|
| 41 |
+
|
| 42 |
+
def _load_source_metadata(repo_id: str) -> dict[str, Any]:
    """Load metadata config from results repo to find the source dataset."""
    try:
        meta = load_dataset(repo_id, name="metadata", split="train")
        return dict(meta[0]) if len(meta) > 0 else {}
    except Exception as exc:
        logger.warning("could_not_load_metadata", repo=repo_id, error=str(exc))
        return {}
+
|
| 52 |
+
|
| 53 |
+
class ImageLoader:
    """Lazy image loader — fetches images from source dataset by sample_idx."""

    def __init__(self, source_dataset: str, from_prs: bool = False):
        self._source = source_dataset
        self._from_prs = from_prs
        self._cache: dict[int, Any] = {}  # sample_idx -> image object
        self._image_col: str | None = None  # detected image column name
        self._pr_config: str | None = None  # first PR config name (when from_prs)
        self._pr_revision: str | None = None  # revision ref for that PR config
        self._available = True  # set False once loading is known to fail
        self._init_done = False

    def _init_source(self) -> None:
        """Lazy init: discover image column and PR revision on first call."""
        if self._init_done:
            return
        self._init_done = True

        try:
            if self._from_prs:
                from ocr_bench.dataset import discover_pr_configs

                _, revisions = discover_pr_configs(self._source)
                if revisions:
                    # Cache the first PR config + revision here so get() does
                    # not have to repeat the discovery call on every cache miss.
                    self._pr_config = next(iter(revisions))
                    self._pr_revision = revisions[self._pr_config]

            # Probe for the image column by loading a single row.
            kwargs: dict[str, Any] = {"path": self._source, "split": "train[:1]"}
            if self._pr_revision:
                kwargs["name"] = self._pr_config
                kwargs["revision"] = self._pr_revision
            probe = load_dataset(**kwargs)
            for col in probe.column_names:
                if col == "image" or "image" in col.lower():
                    self._image_col = col
                    break
            if not self._image_col:
                logger.info("no_image_column_in_source", source=self._source)
                self._available = False
        except Exception as exc:
            logger.warning("image_loader_init_failed", source=self._source, error=str(exc))
            self._available = False

    def get(self, sample_idx: int) -> Image.Image | None:
        """Fetch image for a sample index. Returns None on failure."""
        self._init_source()
        if not self._available or self._image_col is None:
            return None
        if sample_idx in self._cache:
            return self._cache[sample_idx]
        try:
            kwargs: dict[str, Any] = {
                "path": self._source,
                "split": f"train[{sample_idx}:{sample_idx + 1}]",
            }
            if self._pr_revision:
                # Reuse the config/revision discovered at init time instead of
                # calling discover_pr_configs again (a Hub round-trip per miss).
                kwargs["name"] = self._pr_config
                kwargs["revision"] = self._pr_revision
            row = load_dataset(**kwargs)
            img = row[0][self._image_col]
            self._cache[sample_idx] = img
            return img
        except Exception as exc:
            logger.debug("image_load_failed", sample_idx=sample_idx, error=str(exc))
            return None
+
|
| 128 |
+
|
| 129 |
+
def _filter_comparisons(
|
| 130 |
+
comparisons: list[dict[str, Any]],
|
| 131 |
+
winner_filter: str,
|
| 132 |
+
model_filter: str,
|
| 133 |
+
) -> list[dict[str, Any]]:
|
| 134 |
+
"""Filter comparison rows by winner and model."""
|
| 135 |
+
filtered = comparisons
|
| 136 |
+
if winner_filter and winner_filter != "All":
|
| 137 |
+
filtered = [c for c in filtered if c.get("winner") == winner_filter]
|
| 138 |
+
if model_filter and model_filter != "All":
|
| 139 |
+
filtered = [
|
| 140 |
+
c
|
| 141 |
+
for c in filtered
|
| 142 |
+
if c.get("model_a") == model_filter or c.get("model_b") == model_filter
|
| 143 |
+
]
|
| 144 |
+
return filtered
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _winner_badge(winner: str) -> str:
|
| 148 |
+
"""Return a badge string for the winner."""
|
| 149 |
+
if winner == "A":
|
| 150 |
+
return "Winner: A"
|
| 151 |
+
elif winner == "B":
|
| 152 |
+
return "Winner: B"
|
| 153 |
+
else:
|
| 154 |
+
return "Tie"
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _model_label(model: str, col: str) -> str:
|
| 158 |
+
"""Format model name with optional column name. Avoids empty parens."""
|
| 159 |
+
if col:
|
| 160 |
+
return f"{model} ({col})"
|
| 161 |
+
return model
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _build_pair_summary(comparisons: list[dict[str, Any]]) -> str:
|
| 165 |
+
"""Build a win/loss summary string for each model pair."""
|
| 166 |
+
from collections import Counter
|
| 167 |
+
|
| 168 |
+
pair_counts: dict[tuple[str, str], Counter[str]] = {}
|
| 169 |
+
for c in comparisons:
|
| 170 |
+
ma = c.get("model_a", "")
|
| 171 |
+
mb = c.get("model_b", "")
|
| 172 |
+
winner = c.get("winner", "tie")
|
| 173 |
+
key = (ma, mb) if ma <= mb else (mb, ma)
|
| 174 |
+
if key not in pair_counts:
|
| 175 |
+
pair_counts[key] = Counter()
|
| 176 |
+
# Track from perspective of first model in sorted pair
|
| 177 |
+
if winner == "A":
|
| 178 |
+
actual_winner = ma
|
| 179 |
+
elif winner == "B":
|
| 180 |
+
actual_winner = mb
|
| 181 |
+
else:
|
| 182 |
+
actual_winner = "tie"
|
| 183 |
+
|
| 184 |
+
if actual_winner == key[0]:
|
| 185 |
+
pair_counts[key]["W"] += 1
|
| 186 |
+
elif actual_winner == key[1]:
|
| 187 |
+
pair_counts[key]["L"] += 1
|
| 188 |
+
else:
|
| 189 |
+
pair_counts[key]["T"] += 1
|
| 190 |
+
|
| 191 |
+
if not pair_counts:
|
| 192 |
+
return ""
|
| 193 |
+
|
| 194 |
+
parts = []
|
| 195 |
+
for (ma, mb), counts in sorted(pair_counts.items()):
|
| 196 |
+
short_a = ma.split("/")[-1] if "/" in ma else ma
|
| 197 |
+
short_b = mb.split("/")[-1] if "/" in mb else mb
|
| 198 |
+
wins, losses, ties = counts["W"], counts["L"], counts["T"]
|
| 199 |
+
parts.append(f"**{short_a}** vs **{short_b}**: {wins}W {losses}L {ties}T")
|
| 200 |
+
return " | ".join(parts)
|
| 201 |
+
|
| 202 |
+
|
src/ocr_bench/web.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI + HTMX viewer — unified browse + validate for OCR bench results."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import io
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from datetime import UTC, datetime
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import structlog
|
| 12 |
+
from fastapi import FastAPI, Form, Request
|
| 13 |
+
from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
|
| 14 |
+
from fastapi.staticfiles import StaticFiles
|
| 15 |
+
from fastapi.templating import Jinja2Templates
|
| 16 |
+
|
| 17 |
+
from ocr_bench.validate import (
|
| 18 |
+
ValidationComparison,
|
| 19 |
+
build_validation_comparisons,
|
| 20 |
+
compute_agreement,
|
| 21 |
+
compute_human_elo,
|
| 22 |
+
load_annotations,
|
| 23 |
+
save_annotations,
|
| 24 |
+
)
|
| 25 |
+
from ocr_bench.viewer import (
|
| 26 |
+
ImageLoader,
|
| 27 |
+
_filter_comparisons,
|
| 28 |
+
_load_source_metadata,
|
| 29 |
+
load_results,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
logger = structlog.get_logger()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _short_model(model: str) -> str:
|
| 36 |
+
"""Return just the model name after the org prefix."""
|
| 37 |
+
return model.split("/")[-1] if "/" in model else model
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_pair_summary_html(comparisons: list[dict[str, Any]]) -> str:
    """Build a compact HTML table of head-to-head records."""
    from collections import Counter

    records: dict[tuple[str, str], Counter[str]] = {}
    for row in comparisons:
        ma = row.get("model_a", "")
        mb = row.get("model_b", "")
        verdict = row.get("winner", "tie")
        # Canonical pair key: lexicographically smaller model first.
        pair = (ma, mb) if ma <= mb else (mb, ma)
        counter = records.setdefault(pair, Counter())
        # Resolve the positional verdict ('A'/'B') to an actual model name.
        victor = {"A": ma, "B": mb}.get(verdict, "tie")
        if victor == pair[0]:
            counter["W"] += 1
        elif victor == pair[1]:
            counter["L"] += 1
        else:
            counter["T"] += 1

    if not records:
        return ""

    body_rows = []
    for (ma, mb), counter in sorted(records.items()):
        body_rows.append(
            f"<tr><td>{_short_model(ma)}</td><td>{_short_model(mb)}</td>"
            f"<td class='num'>{counter['W']}</td><td class='num'>{counter['L']}</td>"
            f"<td class='num'>{counter['T']}</td></tr>"
        )
    return (
        '<table class="pair-table"><thead><tr>'
        "<th>Model A</th><th>Model B</th>"
        '<th class="num">W</th><th class="num">L</th><th class="num">T</th>'
        "</tr></thead><tbody>" + "".join(body_rows) + "</tbody></table>"
    )
|
| 85 |
+
|
| 86 |
+
PKG_DIR = Path(__file__).parent
|
| 87 |
+
TEMPLATES_DIR = PKG_DIR / "templates"
|
| 88 |
+
STATIC_DIR = PKG_DIR / "static"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass
class ViewerState:
    """In-memory state for the single-user viewer."""

    repo_id: str  # HF dataset repo the results were loaded from
    leaderboard_rows: list[dict[str, Any]]  # published leaderboard rows
    comparison_rows: list[dict[str, Any]]  # raw judge comparison rows
    validation_comps: list[ValidationComparison]  # ordered comps for human validation
    models: list[str]  # sorted union of model names seen in comparisons
    img_loader: ImageLoader | None  # lazy source-image fetcher (None if no source known)
    save_path: str  # where human annotations JSON is persisted
    annotations: list[dict[str, Any]] = field(default_factory=list)  # human votes so far
    completed_ids: set[int] = field(default_factory=set)  # comparison_ids already voted on
    filtered_indices: list[int] = field(default_factory=list)  # nav order into validation_comps
| 106 |
+
|
| 107 |
+
def _build_filtered_indices(
    state: ViewerState,
    winner_filter: str = "All",
    model_filter: str = "All",
) -> list[int]:
    """Map nav indices to validation_comps indices, respecting filters."""
    kept = _filter_comparisons(state.comparison_rows, winner_filter, model_filter)
    # Identify surviving rows by (sample, model pair) and keep only the
    # validation comparisons whose key made it through the filter.
    surviving = {(r["sample_idx"], r["model_a"], r["model_b"]) for r in kept}
    indices: list[int] = []
    for idx, vc in enumerate(state.validation_comps):
        if (vc.sample_idx, vc.model_a, vc.model_b) in surviving:
            indices.append(idx)
    return indices
|
| 124 |
+
|
| 125 |
+
def create_app(
|
| 126 |
+
repo_id: str,
|
| 127 |
+
*,
|
| 128 |
+
output_path: str | None = None,
|
| 129 |
+
n_validate: int | None = None,
|
| 130 |
+
) -> FastAPI:
|
| 131 |
+
"""Create the FastAPI app with all routes.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
repo_id: HF dataset repo with published judge results.
|
| 135 |
+
output_path: Path to save human annotations JSON.
|
| 136 |
+
n_validate: Max comparisons to include for validation (None = all).
|
| 137 |
+
"""
|
| 138 |
+
app = FastAPI(title=f"OCR Bench — {repo_id}")
|
| 139 |
+
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
| 140 |
+
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
|
| 141 |
+
|
| 142 |
+
# --- Load data ---
|
| 143 |
+
leaderboard_rows, comparison_rows = load_results(repo_id)
|
| 144 |
+
|
| 145 |
+
metadata = _load_source_metadata(repo_id)
|
| 146 |
+
source_dataset = metadata.get("source_dataset", "")
|
| 147 |
+
from_prs = metadata.get("from_prs", False)
|
| 148 |
+
|
| 149 |
+
img_loader: ImageLoader | None = None
|
| 150 |
+
if source_dataset:
|
| 151 |
+
img_loader = ImageLoader(source_dataset, from_prs=from_prs)
|
| 152 |
+
|
| 153 |
+
validation_comps = build_validation_comparisons(
|
| 154 |
+
comparison_rows, n=n_validate, prioritize_splits=True
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
models = sorted(
|
| 158 |
+
{c.get("model_a", "") for c in comparison_rows}
|
| 159 |
+
| {c.get("model_b", "") for c in comparison_rows}
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
slug = repo_id.replace("/", "-")
|
| 163 |
+
save_path = output_path or f"human-eval-{slug}.json"
|
| 164 |
+
|
| 165 |
+
# Resume existing annotations
|
| 166 |
+
_, existing_annotations = load_annotations(save_path)
|
| 167 |
+
completed_ids = {ann["comparison_id"] for ann in existing_annotations}
|
| 168 |
+
|
| 169 |
+
state = ViewerState(
|
| 170 |
+
repo_id=repo_id,
|
| 171 |
+
leaderboard_rows=leaderboard_rows,
|
| 172 |
+
comparison_rows=comparison_rows,
|
| 173 |
+
validation_comps=validation_comps,
|
| 174 |
+
models=models,
|
| 175 |
+
img_loader=img_loader,
|
| 176 |
+
save_path=save_path,
|
| 177 |
+
annotations=existing_annotations,
|
| 178 |
+
completed_ids=completed_ids,
|
| 179 |
+
filtered_indices=list(range(len(validation_comps))),
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# Store state on app for access in routes
|
| 183 |
+
app.state.viewer = state
|
| 184 |
+
|
| 185 |
+
ann_metadata = {
|
| 186 |
+
"results_repo": repo_id,
|
| 187 |
+
"n_comparisons": len(validation_comps),
|
| 188 |
+
"models": models,
|
| 189 |
+
"started_at": datetime.now(UTC).isoformat(),
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
# --- Helpers ---
|
| 193 |
+
|
| 194 |
+
def _get_comp_context(
    nav_idx: int,
    *,
    revealed: bool = False,
    voted: bool = False,
    human_vote: str = "",
    winner_filter: str = "All",
    model_filter: str = "All",
) -> dict[str, Any]:
    """Build template context for a comparison card.

    ``nav_idx`` is a position within ``state.filtered_indices`` (the
    currently filtered view), not a raw comparison index.  Out-of-range
    indices yield a ``comp: None`` context so the template can render an
    empty card instead of erroring.

    Because each comparison may have been shown to the judge with A/B
    swapped (``comp.swapped``), both the model names and the judge verdict
    are re-mapped here into *display* order, while stored votes remain in
    canonical order.
    """
    indices = state.filtered_indices
    if nav_idx < 0 or nav_idx >= len(indices):
        return {"comp": None, "nav_idx": nav_idx, "nav_total": len(indices)}

    comp_idx = indices[nav_idx]
    comp = state.validation_comps[comp_idx]

    # Check if already voted — if so, force the revealed/voted display and
    # recover the stored human vote instead of trusting the caller's args.
    already_voted = comp.comparison_id in state.completed_ids
    if already_voted:
        voted = True
        revealed = True
        # Find the annotation to get human vote
        for ann in state.annotations:
            if ann["comparison_id"] == comp.comparison_id:
                human_vote = ann["winner"]
                break

    # Model names — short form for clean headers; swap so the headers match
    # the order the texts are actually displayed in.
    model_a_name = _short_model(comp.model_a)
    model_b_name = _short_model(comp.model_b)
    if comp.swapped:
        model_a_name, model_b_name = model_b_name, model_a_name

    # Judge verdict (canonical → display): mirror A/B when swapped.
    judge_winner = comp.winner
    if comp.swapped:
        if judge_winner == "A":
            judge_verdict = "B"
        elif judge_winner == "B":
            judge_verdict = "A"
        else:
            judge_verdict = "tie"
    else:
        judge_verdict = judge_winner

    # Agreement between the human vote and the judge verdict.
    agreement_word = ""
    agreement_class = ""
    if voted and human_vote:
        # Unswap human vote for comparison (human voted in display order;
        # comp.winner is in canonical order).
        unswapped_human = human_vote
        if comp.swapped:
            if human_vote == "A":
                unswapped_human = "B"
            elif human_vote == "B":
                unswapped_human = "A"

        # Exact match → agreed; one side said tie → soft disagree;
        # opposite winners → hard disagree.
        if unswapped_human == comp.winner:
            agreement_word = "agreed"
            agreement_class = "agreed"
        elif unswapped_human == "tie" or comp.winner == "tie":
            agreement_word = "soft disagree"
            agreement_class = "soft-disagree"
        else:
            agreement_word = "hard disagree"
            agreement_class = "hard-disagree"

    has_image = img_loader is not None

    return {
        "comp": comp,
        "comp_idx": comp_idx,
        "nav_idx": nav_idx,
        "nav_total": len(indices),
        "revealed": revealed,
        "voted": voted,
        "display_text_a": comp.display_text_a,
        "display_text_b": comp.display_text_b,
        "model_a_name": model_a_name,
        "model_b_name": model_b_name,
        "judge_verdict": judge_verdict,
        "human_vote": human_vote,
        "agreement_word": agreement_word,
        "agreement_class": agreement_class,
        "reason": comp.reason,
        "sample_idx": comp.sample_idx,
        "has_image": has_image,
        "winner_filter": winner_filter,
        "model_filter": model_filter,
    }
|
| 285 |
+
|
| 286 |
+
def _stats_context() -> dict[str, Any]:
    """Assemble the context dict consumed by the stats-panel template."""
    agg = compute_agreement(state.annotations, state.validation_comps)

    # With zero votes the rates are undefined — report 0 instead.
    def _as_pct(rate: float) -> int:
        return round(rate * 100) if agg.total else 0

    return {
        "vote_count": agg.total,
        "agreement_pct": _as_pct(agg.agreement_rate),
        "hard_disagree_rate": _as_pct(agg.hard_disagree_rate),
    }
|
| 294 |
+
|
| 295 |
+
def _nav_idx_for_comp_idx(comp_idx: int) -> int:
    """Map a raw comparison index to its position in the filtered list.

    Falls back to position 0 when the comparison is not in the current
    filtered view.
    """
    indices = state.filtered_indices
    if comp_idx in indices:
        return indices.index(comp_idx)
    return 0
|
| 301 |
+
|
| 302 |
+
# --- Routes ---
|
| 303 |
+
|
| 304 |
+
@app.get("/", response_class=RedirectResponse)
async def index():
    """Root entry point — redirect straight to the comparisons view."""
    return RedirectResponse(url="/comparisons", status_code=302)
|
| 307 |
+
|
| 308 |
+
@app.get("/leaderboard", response_class=HTMLResponse)
async def leaderboard(request: Request):
    """Render the leaderboard page, merging judge ELO with human ELO.

    Judge-derived rows come from the published results; a human ELO column
    is added when enough annotations exist to compute one.
    """
    # Build human ELO if we have annotations
    human_board = compute_human_elo(state.annotations, state.validation_comps)

    rows = []
    # Sort by judge ELO, best first.
    for row in sorted(state.leaderboard_rows, key=lambda r: r.get("elo", 0), reverse=True):
        model = row.get("model", "")
        # Drop the org prefix ("org/model" -> "model") for display.
        short = model.split("/")[-1] if "/" in model else model
        human_elo = None
        human_win_pct = None
        if human_board and model in human_board.elo:
            human_elo = round(human_board.elo[model])
            wp = human_board.win_pct(model)
            human_win_pct = f"{wp:.0f}" if wp is not None else None

        rows.append({
            "model": model,
            "model_short": short,
            "elo": round(row.get("elo", 0)),
            "elo_low": row.get("elo_low"),
            "elo_high": row.get("elo_high"),
            "wins": row.get("wins", 0),
            "losses": row.get("losses", 0),
            "ties": row.get("ties", 0),
            "win_pct": row.get("win_pct", 0),
            "human_elo": human_elo,
            "human_win_pct": human_win_pct,
        })

    # Confidence-interval columns are shown only when any row provides them.
    has_ci = any(r.get("elo_low") is not None for r in rows)
    return templates.TemplateResponse(request, "leaderboard.html", {
        "active_tab": "leaderboard",
        "repo_id": state.repo_id,
        "rows": rows,
        "has_ci": has_ci,
        "has_human_elo": human_board is not None,
    })
|
| 346 |
+
|
| 347 |
+
@app.get("/comparisons", response_class=HTMLResponse)
async def comparisons_page(request: Request):
    """Render the full comparisons page, reset to an unfiltered view."""
    # A fresh page load always starts from the complete, unfiltered list.
    state.filtered_indices = _build_filtered_indices(state)

    context: dict[str, Any] = {
        "active_tab": "comparisons",
        "models": state.models,
        "pair_summary": _build_pair_summary_html(state.comparison_rows),
        "winner_filter": "All",
        "model_filter": "All",
    }
    # First card in the filtered list, plus the live agreement stats.
    context.update(_get_comp_context(0))
    context.update(_stats_context())
    return templates.TemplateResponse(request, "comparisons.html", context)
|
| 362 |
+
|
| 363 |
+
@app.get("/comparisons/filter", response_class=HTMLResponse)
async def comparisons_filter(
    request: Request,
    winner: str = "All",
    model: str = "All",
):
    """Apply winner/model filters and render the first matching card."""
    # Rebuild the filtered view, then show its first entry.
    state.filtered_indices = _build_filtered_indices(state, winner, model)
    card_ctx = _get_comp_context(0, winner_filter=winner, model_filter=model)
    return templates.TemplateResponse(request, "comparison_card.html", card_ctx)
|
| 372 |
+
|
| 373 |
+
@app.get("/comparisons/{nav_idx}", response_class=HTMLResponse)
async def comparison_at(
    request: Request,
    nav_idx: int,
    winner: str = "All",
    model: str = "All",
):
    """Render the comparison card at a position in the filtered list."""
    # Clamp into [0, len-1] so stale or hand-edited links never error.
    upper = len(state.filtered_indices) - 1
    if nav_idx > upper:
        nav_idx = upper
    if nav_idx < 0:
        nav_idx = 0
    card_ctx = _get_comp_context(nav_idx, winner_filter=winner, model_filter=model)
    return templates.TemplateResponse(request, "comparison_card.html", card_ctx)
|
| 384 |
+
|
| 385 |
+
@app.post("/vote/{comp_idx}", response_class=HTMLResponse)
async def vote(request: Request, comp_idx: int, winner: str = Form(...)):
    """Record a human vote for a comparison and return the revealed card.

    ``winner`` arrives in *display* order ("A"/"B"/"tie"); it is unswapped
    to canonical order before being resolved to a model name and persisted.
    Re-posting a vote for an already-annotated comparison is a no-op for
    storage but still returns the revealed card.
    """
    if comp_idx < 0 or comp_idx >= len(state.validation_comps):
        return HTMLResponse("Invalid comparison", status_code=404)

    comp = state.validation_comps[comp_idx]

    # Idempotent: if already voted, just return revealed card
    if comp.comparison_id not in state.completed_ids:
        # Unswap for storage
        winner_unswapped = winner
        if comp.swapped:
            if winner == "A":
                winner_unswapped = "B"
            elif winner == "B":
                winner_unswapped = "A"

        # Resolve the canonical vote to the actual model id (or "tie").
        if winner_unswapped == "A":
            winner_model = comp.model_a
        elif winner_unswapped == "B":
            winner_model = comp.model_b
        else:
            winner_model = "tie"

        # NOTE: "winner" is stored in display order; "winner_model" carries
        # the canonical resolution alongside the "swapped" flag.
        ann = {
            "comparison_id": comp.comparison_id,
            "sample_idx": comp.sample_idx,
            "model_a": comp.model_a,
            "model_b": comp.model_b,
            "swapped": comp.swapped,
            "winner": winner,
            "winner_model": winner_model,
            "timestamp": datetime.now(UTC).isoformat(),
        }

        # Persist immediately so a crash never loses a recorded vote.
        state.annotations.append(ann)
        state.completed_ids.add(comp.comparison_id)
        save_annotations(state.save_path, ann_metadata, state.annotations)

    nav_idx = _nav_idx_for_comp_idx(comp_idx)
    # Read current filters from request query params (forwarded by htmx)
    winner_filter = request.query_params.get("winner", "All")
    model_filter = request.query_params.get("model", "All")

    ctx = _get_comp_context(
        nav_idx,
        revealed=True,
        voted=True,
        human_vote=winner,
        winner_filter=winner_filter,
        model_filter=model_filter,
    )
    # Auto-advance: tell template this was a fresh vote
    next_nav = nav_idx + 1 if nav_idx + 1 < len(state.filtered_indices) else None
    ctx["just_voted"] = True
    ctx["next_nav_idx"] = next_nav
    # Preserve active filters in the auto-advance URL ("?" vs "&" depends on
    # whether the winner filter already opened the query string).
    ctx["next_url"] = (
        f"/comparisons/{next_nav}"
        + (f"?winner={winner_filter}" if winner_filter != "All" else "")
        + (f"{'&' if winner_filter != 'All' else '?'}model={model_filter}" if model_filter != "All" else "")
        if next_nav is not None
        else None
    )
    response = templates.TemplateResponse(request, "comparison_card.html", ctx)
    # Let the page refresh the stats panel via htmx event listener.
    response.headers["HX-Trigger"] = "vote-recorded"
    return response
|
| 451 |
+
|
| 452 |
+
@app.get("/reveal/{comp_idx}", response_class=HTMLResponse)
async def reveal(request: Request, comp_idx: int):
    """Reveal model names and judge verdict without recording a vote."""
    if not (0 <= comp_idx < len(state.validation_comps)):
        return HTMLResponse("Invalid comparison", status_code=404)

    # Filters are forwarded by htmx as query parameters.
    params = request.query_params
    card_ctx = _get_comp_context(
        _nav_idx_for_comp_idx(comp_idx),
        revealed=True,
        voted=False,
        winner_filter=params.get("winner", "All"),
        model_filter=params.get("model", "All"),
    )
    return templates.TemplateResponse(request, "comparison_card.html", card_ctx)
|
| 469 |
+
|
| 470 |
+
@app.get("/stats", response_class=HTMLResponse)
async def stats(request: Request):
    """Render the agreement-stats panel (htmx partial)."""
    return templates.TemplateResponse(request, "stats_panel.html", _stats_context())
|
| 474 |
+
|
| 475 |
+
@app.get("/image/{sample_idx}")
async def image(sample_idx: int):
    """Stream the source page image for a sample as PNG."""
    if img_loader is None:
        return HTMLResponse("No images available", status_code=404)

    pil_img = img_loader.get(sample_idx)
    if pil_img is None:
        return HTMLResponse("Image not found", status_code=404)

    # Re-encode to PNG in an in-memory buffer and rewind before streaming.
    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
|
| 486 |
+
|
| 487 |
+
return app
|