davanstrien HF Staff committed on
Commit
1118181
·
verified ·
1 Parent(s): 810cfa0

Upload folder using huggingface_hub

Browse files
src/ocr_bench/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards."""
2
+
3
+ __version__ = "0.1.0"
src/ocr_bench/backends.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Judge backends — API-based (HF Inference Providers, OpenAI-compatible)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import abc
6
+ from collections import Counter
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from typing import Any
9
+
10
+ import stamina
11
+ import structlog
12
+ from huggingface_hub import InferenceClient
13
+ from openai import OpenAI
14
+
15
+ from ocr_bench.judge import JUDGE_SCHEMA, Comparison, parse_judge_output
16
+
17
+ logger = structlog.get_logger()
18
+
19
+ # Retry on these exception types with exponential backoff + jitter.
20
+ _RETRYABLE = (Exception,)
21
+
22
+
23
class JudgeBackend(abc.ABC):
    """Base class for judge backends.

    Subclasses implement ``_call_single`` (one judge API call per
    comparison); ``judge`` fans calls out over a thread pool whenever
    ``concurrency > 1``.
    """

    # Human-readable identifier, set by subclasses; used in logs.
    name: str
    # Maximum number of concurrent _call_single invocations.
    concurrency: int = 1

    @abc.abstractmethod
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """Run the judge on a single comparison."""

    def judge(self, comparisons: list[Comparison]) -> list[dict[str, str]]:
        """Run the judge on a list of comparisons (concurrently if supported).

        Returns:
            A list of parsed results, one per comparison and in the same
            order. Each result is a dict with ``winner`` and ``reason``
            keys, or an empty dict on failure.
        """
        if self.concurrency <= 1 or len(comparisons) <= 1:
            return [self._call_single(comp) for comp in comparisons]

        # Concurrent execution preserving input order.
        # NOTE: build with a comprehension, not ``[{}] * n`` — the latter
        # puts the SAME dict object in every slot, so mutating one default
        # entry would silently mutate them all.
        results: list[dict[str, str]] = [{} for _ in comparisons]
        with ThreadPoolExecutor(max_workers=self.concurrency) as pool:
            future_to_idx = {
                pool.submit(self._call_single, comp): i
                for i, comp in enumerate(comparisons)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as exc:
                    # A failed call yields an empty dict; callers treat it as
                    # "no verdict" rather than aborting the whole batch.
                    logger.warning("judge_call_failed", idx=idx, error=str(exc))
                    results[idx] = {}
        return results
58
+
59
+
60
+ DEFAULT_MAX_TOKENS = 1024
61
+
62
+
63
class InferenceProviderJudge(JudgeBackend):
    """HF Inference Providers backend (Novita, Together, etc.)."""

    def __init__(
        self, model: str, provider: str | None = None, max_tokens: int = DEFAULT_MAX_TOKENS,
    ):
        # Display name is "provider:model" when a provider is pinned.
        self.name = f"{provider}:{model}" if provider else model
        self.model = model
        self.max_tokens = max_tokens
        self.client = InferenceClient(model=model, provider=provider)  # type: ignore[invalid-argument-type]

    @stamina.retry(on=_RETRYABLE, attempts=6)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """One judge call via chat_completion; empty dict when parsing fails."""
        completion = self.client.chat_completion(  # type: ignore[no-matching-overload]
            messages=comp.messages,
            max_tokens=self.max_tokens,
            temperature=0.0,
            response_format={"type": "json_object"},
            extra_body={"chat_template_kwargs": {"enable_thinking": False}},
        )
        text = completion.choices[0].message.content.strip()
        parsed = parse_judge_output(text)
        if not parsed:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return parsed
88
+
89
+
90
class OpenAICompatibleJudge(JudgeBackend):
    """OpenAI-compatible endpoint (local vLLM server, Ollama, HF IE, etc.)."""

    def __init__(
        self,
        base_url: str,
        model: str = "default",
        max_tokens: int = DEFAULT_MAX_TOKENS,
        api_key: str = "not-needed",
        extra_body: dict | None = None,
        temperature: float = 0.0,
        concurrency: int = 1,
    ):
        # Fall back to the endpoint URL as the display name when no
        # explicit model name was supplied.
        if model != "default":
            self.name = model
        else:
            self.name = f"openai@{base_url}"
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Default extra body requests guided decoding against the judge schema.
        if extra_body is None:
            extra_body = {"guided_json": JUDGE_SCHEMA}
        self.extra_body = extra_body
        self.concurrency = concurrency
        self.client = OpenAI(base_url=base_url, api_key=api_key)

    @stamina.retry(on=_RETRYABLE, attempts=3)
    def _call_single(self, comp: Comparison) -> dict[str, str]:
        """One judge call via the OpenAI chat API; empty dict when parsing fails."""
        reply = self.client.chat.completions.create(
            model=self.model,
            messages=comp.messages,  # type: ignore[invalid-argument-type]
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            extra_body=self.extra_body,
        )
        text = reply.choices[0].message.content.strip()
        parsed = parse_judge_output(text)
        if not parsed:
            logger.warning("empty_parse", backend=self.name, sample=comp.sample_idx)
        return parsed
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Spec parsing
129
+ # ---------------------------------------------------------------------------
130
+
131
+ DEFAULT_JUDGE = "novita:moonshotai/Kimi-K2.5"
132
+
133
+
134
def parse_judge_spec(
    spec: str, max_tokens: int = DEFAULT_MAX_TOKENS, concurrency: int = 1,
) -> JudgeBackend:
    """Parse a judge specification string into a backend.

    Formats:
        - ``"https://xxx.endpoints.huggingface.cloud"`` → :class:`OpenAICompatibleJudge`
          (HF Inference Endpoints, OpenAI-compatible with HF token auth)
        - ``"http://..."`` or ``"https://..."`` (other) → :class:`OpenAICompatibleJudge`
        - ``"provider:org/model"`` (colon before first ``/``) → :class:`InferenceProviderJudge`
        - anything else → :class:`InferenceProviderJudge` (no provider)
    """
    if spec.startswith(("http://", "https://")):
        # Optional "url/v1/:model" form carries an explicit model name.
        if "/v1/:" in spec:
            prefix, _, model_name = spec.partition("/v1/:")
            url_part = prefix + "/v1"
        else:
            url_part, model_name = spec, "default"

        if ".endpoints.huggingface." in url_part:
            # HF Inference Endpoints — OpenAI-compatible, auth via HF token.
            from huggingface_hub import get_token

            base_url = url_part.rstrip("/")
            if not base_url.endswith("/v1"):
                base_url += "/v1"
            return OpenAICompatibleJudge(
                base_url=base_url,
                model=model_name,
                api_key=get_token() or "not-needed",
                max_tokens=max_tokens,
                temperature=0.7,
                extra_body={"chat_template_kwargs": {"enable_thinking": False}},
                concurrency=concurrency,
            )
        return OpenAICompatibleJudge(
            base_url=url_part,
            model=model_name,
            max_tokens=max_tokens,
            concurrency=concurrency,
        )

    colon_idx = spec.find(":")
    if colon_idx != -1:
        # "provider:model" only counts when the colon precedes any slash.
        slash_idx = spec.find("/")
        if slash_idx == -1 or colon_idx < slash_idx:
            provider, _, model = spec.partition(":")
            return InferenceProviderJudge(model=model, provider=provider, max_tokens=max_tokens)

    return InferenceProviderJudge(model=spec, max_tokens=max_tokens)
186
+
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # Jury aggregation
190
+ # ---------------------------------------------------------------------------
191
+
192
+
193
def aggregate_jury_votes(
    all_results: list[list[dict[str, str]]],
    judge_names: list[str],
) -> list[dict[str, Any]]:
    """Aggregate votes from multiple judges using majority voting.

    Args:
        all_results: One result list per judge; each inner list holds one
            dict per comparison.
        judge_names: Names of the judges, aligned with *all_results*.

    Returns:
        One aggregated dict per comparison with ``winner``, ``reason``,
        and ``agreement`` fields.
    """
    if not all_results:
        return []

    aggregated: list[dict[str, Any]] = []
    for comp_idx in range(len(all_results[0])):
        ballots: list[str] = []
        rationales: list[str] = []
        for j in range(len(all_results)):
            # A judge may have returned fewer results; treat missing as empty.
            entry = all_results[j][comp_idx] if comp_idx < len(all_results[j]) else {}
            choice = entry.get("winner", "")
            if choice:
                ballots.append(choice)
                rationales.append(f"{judge_names[j]}: {entry.get('reason', '')}")

        if not ballots:
            # Every judge failed on this comparison.
            aggregated.append({"winner": "tie", "reason": "no valid votes", "agreement": "0/0"})
            continue

        top_choice, top_count = Counter(ballots).most_common(1)[0]
        aggregated.append(
            {
                "winner": top_choice,
                "reason": "; ".join(rationales),
                "agreement": f"{top_count}/{len(ballots)}",
            }
        )

    return aggregated
src/ocr_bench/cli.py ADDED
@@ -0,0 +1,589 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI entrypoint for ocr-bench."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+
8
+ import structlog
9
+ from rich.console import Console
10
+ from rich.table import Table
11
+
12
+ from ocr_bench.backends import (
13
+ DEFAULT_JUDGE,
14
+ DEFAULT_MAX_TOKENS,
15
+ aggregate_jury_votes,
16
+ parse_judge_spec,
17
+ )
18
+ from ocr_bench.dataset import (
19
+ DatasetError,
20
+ discover_configs,
21
+ discover_pr_configs,
22
+ load_config_dataset,
23
+ load_flat_dataset,
24
+ )
25
+ from ocr_bench.elo import ComparisonResult, Leaderboard, compute_elo, rankings_resolved
26
+ from ocr_bench.judge import Comparison, _normalize_pair, build_comparisons, sample_indices
27
+ from ocr_bench.publish import (
28
+ EvalMetadata,
29
+ load_existing_comparisons,
30
+ load_existing_metadata,
31
+ publish_results,
32
+ )
33
+
34
+ logger = structlog.get_logger()
35
+ console = Console()
36
+
37
+
38
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``ocr-bench`` parser with judge/run/view subcommands.

    Decomposed into one private helper per subcommand so each option group
    can be read (and extended) in isolation; argument order, defaults, and
    help text are unchanged.
    """
    parser = argparse.ArgumentParser(
        prog="ocr-bench",
        description="OCR model evaluation toolkit — VLM-as-judge with per-dataset leaderboards",
    )
    sub = parser.add_subparsers(dest="command")

    _add_judge_args(sub.add_parser("judge", help="Run pairwise VLM judge on OCR outputs"))
    _add_run_args(sub.add_parser("run", help="Launch OCR models on a dataset via HF Jobs"))
    _add_view_args(sub.add_parser("view", help="Browse and validate results in a web UI"))

    return parser


def _add_judge_args(judge: argparse.ArgumentParser) -> None:
    """Register arguments for the ``judge`` subcommand."""
    # Dataset
    judge.add_argument("dataset", help="HF dataset repo id")
    judge.add_argument("--split", default="train", help="Dataset split (default: train)")
    judge.add_argument("--columns", nargs="+", default=None, help="Explicit OCR column names")
    judge.add_argument(
        "--configs", nargs="+", default=None, help="Config-per-model: list of config names"
    )
    judge.add_argument("--from-prs", action="store_true", help="Force PR-based config discovery")
    judge.add_argument(
        "--merge",
        action="store_true",
        help="Merge PRs to main after discovery (default: load via revision)",
    )

    # Judge
    judge.add_argument(
        "--model",
        action="append",
        dest="models",
        help=f"Judge model spec (repeatable for jury). Default: {DEFAULT_JUDGE}",
    )

    # Eval
    judge.add_argument("--max-samples", type=int, default=None, help="Max samples to evaluate")
    judge.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    judge.add_argument(
        "--max-tokens",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help=f"Max tokens for judge response (default: {DEFAULT_MAX_TOKENS})",
    )

    # Output
    judge.add_argument(
        "--save-results",
        default=None,
        help="HF repo id to publish results to (default: {dataset}-results)",
    )
    judge.add_argument(
        "--no-publish",
        action="store_true",
        help="Don't publish results (default: publish to {dataset}-results)",
    )
    judge.add_argument(
        "--full-rejudge",
        action="store_true",
        help="Re-judge all pairs, ignoring existing comparisons in --save-results repo",
    )
    judge.add_argument(
        "--no-adaptive",
        action="store_true",
        help="Disable adaptive stopping (default: adaptive is on)",
    )
    judge.add_argument(
        "--concurrency",
        type=int,
        default=1,
        help="Number of concurrent judge API calls (default: 1)",
    )


def _add_run_args(run: argparse.ArgumentParser) -> None:
    """Register arguments for the ``run`` subcommand."""
    run.add_argument("input_dataset", help="HF dataset repo id with images")
    run.add_argument("output_repo", help="Output dataset repo (all models push here)")
    run.add_argument(
        "--models", nargs="+", default=None, help="Model slugs to run (default: all 4 core)"
    )
    run.add_argument("--max-samples", type=int, default=None, help="Per-model sample limit")
    run.add_argument("--split", default="train", help="Dataset split (default: train)")
    run.add_argument("--flavor", default=None, help="Override GPU flavor for all models")
    run.add_argument("--timeout", default="4h", help="Per-job timeout (default: 4h)")
    run.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    run.add_argument("--shuffle", action="store_true", help="Shuffle source dataset")
    run.add_argument("--list-models", action="store_true", help="Print available models and exit")
    run.add_argument(
        "--dry-run", action="store_true", help="Show what would launch without launching"
    )
    run.add_argument(
        "--no-wait", action="store_true", help="Launch and exit without polling (default: wait)"
    )


def _add_view_args(view: argparse.ArgumentParser) -> None:
    """Register arguments for the ``view`` subcommand."""
    view.add_argument("results", help="HF dataset repo id with published results")
    view.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
    view.add_argument("--host", default="127.0.0.1", help="Host (default: 127.0.0.1)")
    view.add_argument("--output", default=None, help="Path to save annotations JSON")
136
+
137
+
138
def print_leaderboard(board: Leaderboard) -> None:
    """Render the leaderboard as a Rich table on the shared console."""
    show_ci = bool(board.elo_ci)

    table = Table(title="OCR Model Leaderboard")
    table.add_column("Rank", style="bold")
    table.add_column("Model")
    # Header reflects whether bootstrap confidence intervals are available.
    table.add_column("ELO (95% CI)" if show_ci else "ELO", justify="right")
    for heading in ("Wins", "Losses", "Ties", "Win%"):
        table.add_column(heading, justify="right")

    for position, (model, rating) in enumerate(board.ranked, 1):
        pct = board.win_pct(model)
        pct_cell = "-" if pct is None else f"{pct:.0f}%"
        ci = board.elo_ci.get(model) if show_ci else None
        if ci is not None:
            low, high = ci
            rating_cell = f"{round(rating)} ({round(low)}\u2013{round(high)})"
        else:
            rating_cell = str(round(rating))
        table.add_row(
            str(position),
            model,
            rating_cell,
            str(board.wins[model]),
            str(board.losses[model]),
            str(board.ties[model]),
            pct_cell,
        )

    console.print(table)
172
+
173
+
174
def _convert_results(
    comparisons: list[Comparison], aggregated: list[dict]
) -> list[ComparisonResult]:
    """Pair each comparison with its judged verdict, keeping only valid ones.

    An empty verdict dict means the judge call failed or produced
    unparseable output; those comparisons are dropped.
    """
    converted: list[ComparisonResult] = []
    for comp, verdict in zip(comparisons, aggregated):
        if not verdict:
            continue
        converted.append(
            ComparisonResult(
                sample_idx=comp.sample_idx,
                model_a=comp.model_a,
                model_b=comp.model_b,
                winner=verdict.get("winner", "tie"),
                reason=verdict.get("reason", ""),
                agreement=verdict.get("agreement", "1/1"),
                swapped=comp.swapped,
                text_a=comp.text_a,
                text_b=comp.text_b,
                col_a=comp.col_a,
                col_b=comp.col_b,
            )
        )
    return converted
198
+
199
+
200
+ def _resolve_results_repo(dataset: str, save_results: str | None, no_publish: bool) -> str | None:
201
+ """Derive the results repo id. Returns None if publishing is disabled."""
202
+ if no_publish:
203
+ return None
204
+ if save_results:
205
+ return save_results
206
+ return f"{dataset}-results"
207
+
208
+
209
def cmd_judge(args: argparse.Namespace) -> None:
    """Orchestrate: load → compare → judge → elo → print → publish.

    Dataset resolution cascades: explicit --configs, then explicit
    --columns, then forced --from-prs discovery, then auto-detection
    (PR configs + main-branch configs, falling back to flat loading).
    Unless --full-rejudge is set, previously judged model pairs in the
    results repo are skipped (incremental mode). Adaptive mode (default)
    judges in small batches and stops once ELO rankings are resolved.
    """
    # --- Resolve flags ---
    adaptive = not args.no_adaptive
    merge = args.merge
    results_repo = _resolve_results_repo(args.dataset, args.save_results, args.no_publish)
    from_prs = False  # track for metadata

    if results_repo:
        console.print(f"Results will be published to [bold]{results_repo}[/bold]")

    # --- Load dataset (cascading auto-detection) ---
    if args.configs:
        # Explicit configs — use them directly
        config_names = args.configs
        ds, ocr_columns = load_config_dataset(args.dataset, config_names, split=args.split)
    elif args.columns:
        # Explicit columns — flat loading
        ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split, columns=args.columns)
    elif args.from_prs:
        # Forced PR discovery
        config_names, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        if not config_names:
            raise DatasetError("No configs found in open PRs")
        from_prs = True
        console.print(f"Discovered {len(config_names)} configs from PRs: {config_names}")
        ds, ocr_columns = load_config_dataset(
            args.dataset,
            config_names,
            split=args.split,
            # After a merge the PR revisions no longer need to be pinned.
            pr_revisions=pr_revisions if not merge else None,
        )
    else:
        # Auto-detect: PRs + main branch configs combined, fall back to flat
        pr_configs, pr_revisions = discover_pr_configs(args.dataset, merge=merge)
        main_configs = discover_configs(args.dataset)

        # Combine: PR configs + main configs not already in PRs
        config_names = list(pr_configs)
        for mc in main_configs:
            if mc not in pr_configs:
                config_names.append(mc)

        if config_names:
            if pr_configs:
                from_prs = True
                console.print(f"Auto-detected {len(pr_configs)} configs from PRs: {pr_configs}")
            if main_configs:
                main_only = [c for c in main_configs if c not in pr_configs]
                if main_only:
                    console.print(f"Auto-detected {len(main_only)} configs on main: {main_only}")
            ds, ocr_columns = load_config_dataset(
                args.dataset,
                config_names,
                split=args.split,
                pr_revisions=pr_revisions if pr_configs else None,
            )
        else:
            # No configs anywhere — fall back to flat loading
            ds, ocr_columns = load_flat_dataset(args.dataset, split=args.split)

    console.print(f"Loaded {len(ds)} samples with {len(ocr_columns)} models:")
    for col, model in ocr_columns.items():
        console.print(f" {col} → {model}")

    # --- Incremental: load existing comparisons ---
    existing_results: list[ComparisonResult] = []
    existing_meta_rows: list[dict] = []
    skip_pairs: set[tuple[str, str]] | None = None

    if results_repo and not args.full_rejudge:
        existing_results = load_existing_comparisons(results_repo)
        if existing_results:
            # Normalized (model_a, model_b) pairs already judged are skipped.
            judged_pairs = {_normalize_pair(r.model_a, r.model_b) for r in existing_results}
            skip_pairs = judged_pairs
            console.print(
                f"\nIncremental mode: {len(existing_results)} existing comparisons "
                f"across {len(judged_pairs)} model pairs — skipping those."
            )
            existing_meta_rows = load_existing_metadata(results_repo)
        else:
            console.print("\nNo existing comparisons found — full judge run.")

    model_names = list(set(ocr_columns.values()))

    # --- Judge setup (shared by both paths) ---
    model_specs = args.models or [DEFAULT_JUDGE]
    judges = [
        parse_judge_spec(spec, max_tokens=args.max_tokens, concurrency=args.concurrency)
        for spec in model_specs
    ]
    is_jury = len(judges) > 1

    def _judge_batch(batch_comps: list[Comparison]) -> list[ComparisonResult]:
        """Run judge(s) on a batch of comparisons and return ComparisonResults."""
        all_judge_outputs: list[list[dict]] = []
        for judge in judges:
            results = judge.judge(batch_comps)
            all_judge_outputs.append(results)
        if is_jury:
            judge_names = [j.name for j in judges]
            aggregated = aggregate_jury_votes(all_judge_outputs, judge_names)
        else:
            aggregated = all_judge_outputs[0]
        return _convert_results(batch_comps, aggregated)

    if adaptive:
        # --- Adaptive stopping: batch-by-batch with convergence check ---
        from itertools import combinations as _combs

        all_indices = sample_indices(len(ds), args.max_samples, args.seed)
        n_pairs = len(list(_combs(model_names, 2)))
        # Batch size and minimum-comparisons threshold before checking
        # convergence (at least 3 comparisons per pair, floor of 20).
        batch_samples = 5
        min_before_check = max(3 * n_pairs, 20)

        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")
        console.print(
            f"\n[bold]Adaptive mode[/bold]: {len(all_indices)} samples, "
            f"{n_pairs} pairs, batch size {batch_samples}, "
            f"checking after {min_before_check} comparisons"
        )

        new_results: list[ComparisonResult] = []
        total_comparisons = 0
        for batch_num, batch_start in enumerate(range(0, len(all_indices), batch_samples)):
            batch_indices = all_indices[batch_start : batch_start + batch_samples]
            batch_comps = build_comparisons(
                ds,
                ocr_columns,
                skip_pairs=skip_pairs,
                indices=batch_indices,
                seed=args.seed,
            )
            if not batch_comps:
                continue

            batch_results = _judge_batch(batch_comps)
            new_results.extend(batch_results)
            total_comparisons += len(batch_comps)
            # batch_comps goes out of scope → GC can free images

            total = len(existing_results) + len(new_results)
            console.print(f" Batch {batch_num + 1}: {len(batch_results)} new, {total} total")

            if total >= min_before_check:
                board = compute_elo(existing_results + new_results, model_names)
                # Show CI gaps for each adjacent pair
                ranked = board.ranked
                if board.elo_ci:
                    gaps: list[str] = []
                    for i in range(len(ranked) - 1):
                        hi_model, _ = ranked[i]
                        lo_model, _ = ranked[i + 1]
                        hi_ci = board.elo_ci.get(hi_model)
                        lo_ci = board.elo_ci.get(lo_model)
                        if hi_ci and lo_ci:
                            gap = hi_ci[0] - lo_ci[1]  # positive = resolved
                            if gap > 0:
                                status = "[green]ok[/green]"
                            else:
                                status = f"[yellow]overlap {-gap:.0f}[/yellow]"
                            gaps.append(f" {hi_model} vs {lo_model}: gap={gap:+.0f} {status}")
                    if gaps:
                        console.print(" CI gaps:")
                        for g in gaps:
                            console.print(g)

                if rankings_resolved(board):
                    # Converged — stop early and report roughly how many
                    # comparisons were saved.
                    remaining = len(all_indices) - batch_start - len(batch_indices)
                    console.print(
                        f"[green]Rankings converged after {total} comparisons! "
                        f"Skipped ~{remaining * n_pairs} remaining.[/green]"
                    )
                    break

        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")
    else:
        # --- Standard single-pass flow ---
        comparisons = build_comparisons(
            ds,
            ocr_columns,
            max_samples=args.max_samples,
            seed=args.seed,
            skip_pairs=skip_pairs,
        )
        console.print(f"\nBuilt {len(comparisons)} new pairwise comparisons")

        if not comparisons and not existing_results:
            console.print(
                "[yellow]No valid comparisons — check that OCR columns have text.[/yellow]"
            )
            return

        if not comparisons:
            # Everything was judged in a previous run — just recompute and
            # republish the leaderboard from the stored comparisons.
            console.print("[green]All pairs already judged — refitting leaderboard.[/green]")
            board = compute_elo(existing_results, model_names)
            console.print()
            print_leaderboard(board)
            if results_repo:
                metadata = EvalMetadata(
                    source_dataset=args.dataset,
                    judge_models=[],
                    seed=args.seed,
                    max_samples=args.max_samples or len(ds),
                    total_comparisons=0,
                    valid_comparisons=0,
                    from_prs=from_prs,
                )
                publish_results(
                    results_repo,
                    board,
                    metadata,
                    existing_metadata=existing_meta_rows,
                )
                console.print(f"\nResults published to [bold]{results_repo}[/bold]")
            return

        if is_jury:
            console.print(f"\nJury mode: {len(judges)} judges")

        for judge in judges:
            console.print(f"\nRunning judge: {judge.name}")

        new_results = _judge_batch(comparisons)
        total_comparisons = len(comparisons)
        console.print(f"\n{len(new_results)}/{total_comparisons} valid comparisons")

    # --- Merge existing + new, compute ELO ---
    all_results = existing_results + new_results
    board = compute_elo(all_results, model_names)
    console.print()
    print_leaderboard(board)

    # --- Publish ---
    if results_repo:
        metadata = EvalMetadata(
            source_dataset=args.dataset,
            judge_models=[j.name for j in judges],
            seed=args.seed,
            max_samples=args.max_samples or len(ds),
            total_comparisons=total_comparisons,
            valid_comparisons=len(new_results),
            from_prs=from_prs,
        )
        publish_results(results_repo, board, metadata, existing_metadata=existing_meta_rows)
        console.print(f"\nResults published to [bold]{results_repo}[/bold]")
456
+
457
+
458
def cmd_run(args: argparse.Namespace) -> None:
    """Launch OCR models on a dataset via HF Jobs.

    Handles three modes: ``--list-models`` (print registry and exit),
    ``--dry-run`` (show the jobs that would launch), and the real launch,
    optionally polling until all jobs finish (default unless --no-wait).
    """
    # Imported lazily so the judge/view paths don't pay for HF Jobs deps.
    from ocr_bench.run import (
        DEFAULT_MODELS,
        MODEL_REGISTRY,
        build_script_args,
        launch_ocr_jobs,
        poll_jobs,
    )

    # --list-models
    if args.list_models:
        table = Table(title="Available OCR Models", show_lines=True)
        table.add_column("Slug", style="cyan bold")
        table.add_column("Model ID")
        table.add_column("Size", justify="right")
        table.add_column("Default GPU", justify="center")

        for slug in sorted(MODEL_REGISTRY):
            cfg = MODEL_REGISTRY[slug]
            default = " (default)" if slug in DEFAULT_MODELS else ""
            table.add_row(slug + default, cfg.model_id, cfg.size, cfg.default_flavor)

        console.print(table)
        console.print(f"\nDefault set: {', '.join(DEFAULT_MODELS)}")
        return

    # Validate requested slugs against the registry before doing anything.
    selected = args.models or DEFAULT_MODELS
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            console.print(f"[red]Unknown model: {slug}[/red]")
            console.print(f"Available: {', '.join(MODEL_REGISTRY.keys())}")
            sys.exit(1)

    console.print("\n[bold]OCR Benchmark Run[/bold]")
    console.print(f" Source: {args.input_dataset}")
    console.print(f" Output: {args.output_repo}")
    console.print(f" Models: {', '.join(selected)}")
    if args.max_samples:
        console.print(f" Samples: {args.max_samples} per model")
    console.print()

    # Dry run
    if args.dry_run:
        console.print("[bold yellow]DRY RUN[/bold yellow] — no jobs will be launched\n")
        for slug in selected:
            cfg = MODEL_REGISTRY[slug]
            # Per-model GPU flavor unless overridden on the command line.
            flavor = args.flavor or cfg.default_flavor
            script_args = build_script_args(
                args.input_dataset,
                args.output_repo,
                slug,
                max_samples=args.max_samples,
                shuffle=args.shuffle,
                seed=args.seed,
                extra_args=cfg.default_args or None,
            )
            console.print(f"[cyan]{slug}[/cyan] ({cfg.model_id})")
            console.print(f" Flavor: {flavor}")
            console.print(f" Timeout: {args.timeout}")
            console.print(f" Script: {cfg.script}")
            console.print(f" Args: {' '.join(script_args)}")
            console.print()
        console.print("Remove --dry-run to launch these jobs.")
        return

    # Launch
    jobs = launch_ocr_jobs(
        args.input_dataset,
        args.output_repo,
        models=selected,
        max_samples=args.max_samples,
        split=args.split,
        shuffle=args.shuffle,
        seed=args.seed,
        flavor_override=args.flavor,
        timeout=args.timeout,
    )

    console.print(f"\n[green]{len(jobs)} jobs launched.[/green]")
    for job in jobs:
        console.print(f" [cyan]{job.model_slug}[/cyan]: {job.job_url}")

    if not args.no_wait:
        # Block until every job reaches a terminal state.
        console.print("\n[bold]Waiting for jobs to complete...[/bold]")
        poll_jobs(jobs)
        console.print("\n[bold green]All jobs finished![/bold green]")
        console.print("\nEvaluate:")
        console.print(f" ocr-bench judge {args.output_repo}")
    else:
        console.print("\nJobs running in background.")
        console.print("Check status at: https://huggingface.co/settings/jobs")
        console.print(f"When complete: ocr-bench judge {args.output_repo}")
551
+
552
+
553
def cmd_view(args: argparse.Namespace) -> None:
    """Launch the FastAPI + HTMX results viewer.

    The web stack is an optional extra, so the imports are guarded: a
    missing dependency prints an install hint and exits with status 1
    instead of raising a traceback.
    """
    try:
        import uvicorn

        from ocr_bench.web import create_app
    except ImportError:
        console.print(
            "[red]Error:[/red] FastAPI/uvicorn not installed. "
            "Install the viewer extra: [bold]pip install ocr-bench\\[viewer][/bold]"
        )
        sys.exit(1)

    console.print(f"Loading results from [bold]{args.results}[/bold]...")
    app = create_app(args.results, output_path=args.output)
    console.print(f"Starting viewer at [bold]http://{args.host}:{args.port}[/bold]")
    # Blocks until the server is stopped.
    uvicorn.run(app, host=args.host, port=args.port)
570
+
571
+
572
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the chosen subcommand."""
    parser = build_parser()
    args = parser.parse_args()

    # No subcommand given — show help and exit successfully.
    if args.command is None:
        parser.print_help()
        sys.exit(0)

    handlers = {"judge": cmd_judge, "run": cmd_run, "view": cmd_view}
    try:
        handler = handlers.get(args.command)
        if handler is not None:
            handler(args)
    except DatasetError as exc:
        # Dataset problems are expected user-facing errors: friendly
        # message, non-zero exit, no traceback.
        console.print(f"[red]Error:[/red] {exc}")
        sys.exit(1)
src/ocr_bench/dataset.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset loading — flat, config-per-model, PR-based. OCR column discovery."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+
7
+ import structlog
8
+ from datasets import Dataset, get_dataset_config_names, load_dataset
9
+ from huggingface_hub import HfApi
10
+
11
+ logger = structlog.get_logger()
12
+
13
+
14
class DatasetError(Exception):
    """Raised when dataset loading or column discovery fails.

    Callers (e.g. the CLI entry point) catch this to print a clean error
    message instead of a traceback.
    """
16
+
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # OCR column discovery
20
+ # ---------------------------------------------------------------------------
21
+
22
+
23
def discover_ocr_columns(dataset: Dataset) -> dict[str, str]:
    """Discover OCR output columns and the model names that produced them.

    Strategy:
    1. Parse ``inference_info`` JSON from the first row (list or single entry).
    2. Fallback: heuristic column-name matching (``markdown``, ``ocr``, ``text``).
    3. Disambiguate duplicate model names by appending the column name.

    Returns:
        Mapping of ``column_name → model_name``.

    Raises:
        DatasetError: If no OCR columns can be found.
    """
    found: dict[str, str] = {}

    # Preferred path: explicit metadata written by the inference pipeline.
    try:
        if "inference_info" not in dataset.column_names:
            raise KeyError("no inference_info column")
        raw = dataset["inference_info"][0]  # column access avoids image decode
        if raw:
            entries = json.loads(raw)
            if not isinstance(entries, list):
                entries = [entries]
            for entry in entries:
                col = entry.get("column_name", "")
                model = entry.get("model_id", entry.get("model_name", "unknown"))
                if col and col in dataset.column_names:
                    found[col] = model
    except (json.JSONDecodeError, TypeError, KeyError) as exc:
        logger.warning("could_not_parse_inference_info", error=str(exc))

    # Fallback: guess columns by name.
    if not found:
        found = {
            col: col
            for col in dataset.column_names
            if "markdown" in col.lower() or "ocr" in col.lower() or col == "text"
        }

    if not found:
        raise DatasetError(f"No OCR columns found. Available columns: {dataset.column_names}")

    # Count how many columns each model produced so duplicates can be tagged.
    tally: dict[str, int] = {}
    for model in found.values():
        tally[model] = tally.get(model, 0) + 1

    result: dict[str, str] = {}
    for col, model in found.items():
        if tally[model] > 1:
            # Same model appears more than once: append the column for uniqueness.
            short = model.rsplit("/", 1)[-1]
            result[col] = f"{short} ({col})"
        else:
            result[col] = model

    return result
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # PR-based config discovery
83
+ # ---------------------------------------------------------------------------
84
+
85
+
86
def discover_pr_configs(
    repo_id: str,
    merge: bool = False,
    api: HfApi | None = None,
) -> tuple[list[str], dict[str, str]]:
    """Discover dataset configs from open PRs on a Hub dataset repo.

    PR titles must end with ``[config_name]`` to be detected.

    Args:
        repo_id: HF dataset repo id.
        merge: If True, merge each discovered PR before loading.
        api: Optional pre-configured HfApi instance.

    Returns:
        Tuple of (config_names, {config_name: pr_revision}).
    """
    hub = api if api is not None else HfApi()

    names: list[str] = []
    revs: dict[str, str] = {}

    for disc in hub.get_repo_discussions(repo_id, repo_type="dataset"):
        # Only open pull requests are considered.
        if not (disc.is_pull_request and disc.status == "open"):
            continue
        title = disc.title
        if "[" not in title or not title.endswith("]"):
            continue
        config = title[title.rindex("[") + 1 : -1].strip()
        if not config:
            continue
        if merge:
            hub.merge_pull_request(repo_id, disc.num, repo_type="dataset")
            logger.info("merged_pr", pr=disc.num, config=config)
        else:
            # Leave the PR open and load directly from its revision.
            revs[config] = f"refs/pr/{disc.num}"
        names.append(config)

    return names, revs
125
+
126
+
127
def discover_configs(repo_id: str) -> list[str]:
    """List non-default configs from the main branch of a Hub dataset.

    Returns:
        Config names excluding "default", or empty list if none found.
    """
    try:
        all_configs = get_dataset_config_names(repo_id)
    except Exception as exc:  # best-effort: a repo without configs is not fatal
        logger.info("no_configs_on_main", repo=repo_id, reason=str(exc))
        return []
    return [name for name in all_configs if name != "default"]
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # Config-per-model loading
143
+ # ---------------------------------------------------------------------------
144
+
145
+
146
def load_config_dataset(
    repo_id: str,
    config_names: list[str],
    split: str = "train",
    pr_revisions: dict[str, str] | None = None,
) -> tuple[Dataset, dict[str, str]]:
    """Load multiple configs from a Hub dataset and merge into one.

    Each config becomes a column whose name is the config name and whose value
    is the OCR text (from the first column matching heuristics, or ``markdown``).

    Args:
        repo_id: HF dataset repo id.
        config_names: List of config names to load.
        split: Dataset split to load.
        pr_revisions: Optional mapping of config_name → revision for PR-based loading.

    Returns:
        Tuple of (unified Dataset, {column_name: model_id}).

    Raises:
        DatasetError: If ``config_names`` is empty or no config loads usable text.
    """
    if not config_names:
        raise DatasetError("No config names provided")

    pr_revisions = pr_revisions or {}
    unified: Dataset | None = None
    ocr_columns: dict[str, str] = {}

    for config in config_names:
        # Load from a PR revision when one was discovered for this config.
        revision = pr_revisions.get(config)
        kwargs: dict = {"path": repo_id, "name": config, "split": split}
        if revision:
            kwargs["revision"] = revision

        ds = load_dataset(**kwargs)

        # Find the OCR text column in this config; configs without one are skipped.
        text_col = _find_text_column(ds)
        if text_col is None:
            logger.warning("no_text_column_in_config", config=config)
            continue

        # Extract model_id from inference_info if available
        model_id = _extract_model_id(ds, config)
        ocr_columns[config] = model_id

        # Build unified dataset using Arrow-level ops (no per-row image decode)
        text_values = ds[text_col]  # column access — no image decoding
        if unified is None:
            # First config: keep all columns except text_col, add text as config name
            drop = [text_col] if text_col != config else []
            unified = ds.remove_columns(drop) if drop else ds
            if config != text_col:
                unified = unified.add_column(config, text_values)
        else:
            # Later configs contribute only their text column. Rows are assumed
            # to be aligned with the first config — TODO confirm upstream.
            if len(ds) != len(unified):
                logger.warning(
                    "config_length_mismatch",
                    config=config,
                    expected=len(unified),
                    got=len(ds),
                )
                # NOTE(review): silently truncates the longer config to the
                # unified length — extra rows are dropped.
                text_values = text_values[: len(unified)]
            unified = unified.add_column(config, text_values)

    if unified is None:
        raise DatasetError("No configs loaded successfully")

    return unified, ocr_columns
215
+
216
+
217
+ def _extract_model_id(ds: Dataset, config: str) -> str:
218
+ """Extract model_id from inference_info in first row, falling back to config name."""
219
+ if "inference_info" not in ds.column_names:
220
+ return config
221
+ try:
222
+ info_raw = ds["inference_info"][0] # column access avoids image decode
223
+ if info_raw:
224
+ info = json.loads(info_raw)
225
+ if isinstance(info, list):
226
+ info = info[0]
227
+ return info.get("model_id", info.get("model_name", config))
228
+ except (json.JSONDecodeError, TypeError, KeyError, IndexError):
229
+ pass
230
+ return config
231
+
232
+
233
+ def _find_text_column(ds: Dataset) -> str | None:
234
+ """Find the likely OCR text column in a dataset.
235
+
236
+ Priority:
237
+ 1. ``inference_info[0]["column_name"]`` if present and exists in dataset.
238
+ 2. First column matching ``markdown`` (case-insensitive).
239
+ 3. First column matching ``ocr`` (case-insensitive).
240
+ 4. Column named exactly ``text``.
241
+ """
242
+ # Try inference_info first (column access avoids image decoding)
243
+ if "inference_info" in ds.column_names:
244
+ try:
245
+ info_raw = ds["inference_info"][0]
246
+ if info_raw:
247
+ info = json.loads(info_raw)
248
+ if isinstance(info, list):
249
+ info = info[0]
250
+ col_name = info.get("column_name", "")
251
+ if col_name and col_name in ds.column_names:
252
+ return col_name
253
+ except (json.JSONDecodeError, TypeError, KeyError, IndexError):
254
+ pass
255
+
256
+ # Prioritized heuristic: markdown > ocr > text
257
+ for pattern in ["markdown", "ocr"]:
258
+ for col in ds.column_names:
259
+ if pattern in col.lower():
260
+ return col
261
+ if "text" in ds.column_names:
262
+ return "text"
263
+ return None
264
+
265
+
266
+ # ---------------------------------------------------------------------------
267
+ # Flat dataset loading
268
+ # ---------------------------------------------------------------------------
269
+
270
+
271
def load_flat_dataset(
    repo_id: str,
    split: str = "train",
    columns: list[str] | None = None,
) -> tuple[Dataset, dict[str, str]]:
    """Load a flat dataset from Hub and discover OCR columns.

    Args:
        repo_id: HF dataset repo id.
        split: Dataset split.
        columns: If given, use these as OCR columns (maps col→col).

    Returns:
        Tuple of (Dataset, {column_name: model_name}).

    Raises:
        DatasetError: If an explicitly requested column does not exist.
    """
    ds = load_dataset(repo_id, split=split)

    if not columns:
        # No explicit column list: fall back to metadata/heuristic discovery.
        return ds, discover_ocr_columns(ds)

    # Validate each requested column against the loaded schema.
    for col in columns:
        if col not in ds.column_names:
            raise DatasetError(f"Column '{col}' not found. Available: {ds.column_names}")
    return ds, {col: col for col in columns}
src/ocr_bench/elo.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bradley-Terry MLE rating computation for pairwise comparisons."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import random
7
+ from collections import defaultdict
8
+ from dataclasses import dataclass, field
9
+ from typing import Literal
10
+
11
+ import numpy as np
12
+ from scipy.optimize import minimize
13
+
14
+ INITIAL_ELO: float = 1500.0
15
+
16
+ Winner = Literal["A", "B", "tie"]
17
+
18
+
19
@dataclass
class ComparisonResult:
    """Result of a single pairwise comparison, ready for ELO computation."""

    # Row index in the source dataset the compared outputs came from.
    sample_idx: int
    # Model names in canonical (unswapped) order.
    model_a: str
    model_b: str
    # Judge verdict as given relative to prompt positions; see ``swapped``.
    winner: Winner
    # Judge's free-text justification for the verdict.
    reason: str = ""
    # Agreement tally string, e.g. "1/1" — presumably agreeing/total judge
    # votes; confirm against the judge backend.
    agreement: str = "1/1"
    # True when A/B positions were flipped in the prompt; _unswap_winner
    # restores the canonical orientation before tallying.
    swapped: bool = False
    # Raw OCR texts and the dataset columns they came from (for inspection).
    text_a: str = ""
    text_b: str = ""
    col_a: str = ""
    col_b: str = ""
34
+
35
+
36
@dataclass
class Leaderboard:
    """ELO leaderboard computed from pairwise comparison results.

    Per-model dicts are keyed by model name; ``elo_ci`` stores
    (lower, upper) bootstrap confidence bounds.
    """

    elo: dict[str, float] = field(default_factory=dict)
    wins: dict[str, int] = field(default_factory=dict)
    losses: dict[str, int] = field(default_factory=dict)
    ties: dict[str, int] = field(default_factory=dict)
    comparison_log: list[dict[str, object]] = field(default_factory=list)
    elo_ci: dict[str, tuple[float, float]] = field(default_factory=dict)

    @property
    def ranked(self) -> list[tuple[str, float]]:
        """Models sorted by ELO rating, descending."""
        return sorted(self.elo.items(), key=lambda pair: -pair[1])

    def win_pct(self, model: str) -> float | None:
        """Win percentage for a model, or None if no comparisons."""
        played = self.wins[model] + self.losses[model] + self.ties[model]
        if not played:
            return None
        return self.wins[model] / played * 100
58
+
59
+
60
+ def _unswap_winner(winner: Winner, swapped: bool) -> Winner:
61
+ """Unswap winner if positions were randomized."""
62
+ if swapped:
63
+ if winner == "A":
64
+ return "B"
65
+ elif winner == "B":
66
+ return "A"
67
+ return winner
68
+
69
+
70
def _build_win_matrix(
    results: list[ComparisonResult],
) -> tuple[dict[tuple[str, str], float], set[str]]:
    """Count wins per ordered pair. Ties count as 0.5 for each side.

    Returns (win_counts, models_seen) where win_counts[(i, j)] = fractional
    wins of i over j.
    """
    counts: dict[tuple[str, str], float] = defaultdict(float)
    seen: set[str] = set()

    for result in results:
        a, b = result.model_a, result.model_b
        seen.update((a, b))
        # Restore canonical orientation before crediting the win.
        outcome = _unswap_winner(result.winner, result.swapped)

        if outcome == "A":
            counts[(a, b)] += 1.0
        elif outcome == "B":
            counts[(b, a)] += 1.0
        else:  # tie — half a win credited in each direction
            counts[(a, b)] += 0.5
            counts[(b, a)] += 0.5

    return counts, seen
95
+
96
+
97
def _bt_mle(
    win_counts: dict[tuple[str, str], float],
    model_names: list[str],
) -> dict[str, float]:
    """Fit Bradley-Terry model via maximum likelihood estimation.

    Returns theta (strength) per model. Uses scipy L-BFGS-B on the
    negative log-likelihood with log-parameterization for positivity.

    Args:
        win_counts: Fractional wins per ordered pair; win_counts[(i, j)] is
            the wins of i over j (ties contribute 0.5 in each direction).
        model_names: All models to rate, including ones with no games.
    """
    n = len(model_names)
    if n == 0:
        return {}
    if n == 1:
        # A lone model has no opponents; give it neutral strength.
        return {model_names[0]: 1.0}

    idx = {name: i for i, name in enumerate(model_names)}

    # Collect all pairs with nonzero games
    pairs: list[tuple[int, int, float, float]] = []
    for i_name in model_names:
        for j_name in model_names:
            if i_name >= j_name:
                continue  # visit each unordered pair exactly once
            w_ij = win_counts.get((i_name, j_name), 0.0)
            w_ji = win_counts.get((j_name, i_name), 0.0)
            if w_ij + w_ji > 0:
                pairs.append((idx[i_name], idx[j_name], w_ij, w_ji))

    if not pairs:
        # No games at all: every model gets the same neutral strength.
        return {name: 1.0 for name in model_names}

    def neg_log_likelihood(log_theta: np.ndarray) -> float:
        """Negative log-likelihood of the observed wins under the BT model."""
        nll = 0.0
        for i, j, w_ij, w_ji in pairs:
            diff = log_theta[i] - log_theta[j]
            # log(theta_i / (theta_i + theta_j)) = diff - log(1 + exp(diff))
            # log(theta_j / (theta_i + theta_j)) = -diff - log(1 + exp(-diff))
            # Use log-sum-exp for numerical stability
            log_p_ij = diff - np.logaddexp(0.0, diff)
            log_p_ji = -diff - np.logaddexp(0.0, -diff)
            nll -= w_ij * log_p_ij + w_ji * log_p_ji
        return nll

    def gradient(log_theta: np.ndarray) -> np.ndarray:
        """Analytic gradient of the NLL with respect to log_theta."""
        grad = np.zeros(n)
        for i, j, w_ij, w_ji in pairs:
            diff = log_theta[i] - log_theta[j]
            p_ij = 1.0 / (1.0 + np.exp(-diff))  # sigmoid(diff)
            total = w_ij + w_ji
            # d(NLL)/d(log_theta_i)
            grad[i] -= w_ij - total * p_ij
            grad[j] -= w_ji - total * (1.0 - p_ij)
        return grad

    # Start all strengths equal (log_theta = 0). The BT likelihood is
    # invariant to a common shift of log_theta; nothing is pinned during
    # optimization — the solution is re-centered below instead.
    x0 = np.zeros(n)
    result = minimize(
        neg_log_likelihood,
        x0,
        jac=gradient,
        method="L-BFGS-B",
    )

    log_theta = result.x
    # Center: subtract geometric mean (= mean of log_theta)
    log_theta -= log_theta.mean()
    theta = np.exp(log_theta)

    return {name: float(theta[idx[name]]) for name in model_names}
166
+
167
+
168
+ def _theta_to_elo(theta: dict[str, float], center: float = 1500.0) -> dict[str, float]:
169
+ """Convert BT theta values to ELO scale.
170
+
171
+ ELO_i = 400 * log10(theta_i / theta_ref) + center
172
+ where theta_ref is the geometric mean of all theta values.
173
+ """
174
+ if not theta:
175
+ return {}
176
+
177
+ values = list(theta.values())
178
+ log_geo_mean = sum(math.log(v) for v in values) / len(values)
179
+ geo_mean = math.exp(log_geo_mean)
180
+
181
+ return {
182
+ name: 400.0 * math.log10(t / geo_mean) + center
183
+ for name, t in theta.items()
184
+ }
185
+
186
+
187
def _bootstrap_ci(
    results: list[ComparisonResult],
    model_names: list[str],
    n_bootstrap: int = 1000,
    ci: float = 0.95,
    seed: int = 42,
) -> dict[str, tuple[float, float]]:
    """Compute bootstrap confidence intervals for ELO ratings.

    Resamples comparisons with replacement, fits BT-MLE each time,
    returns percentile-based CIs.
    """
    if not results or not model_names:
        return {}

    rng = random.Random(seed)
    sample_size = len(results)
    draws: dict[str, list[float]] = {name: [] for name in model_names}

    for _ in range(n_bootstrap):
        # Resample with replacement and refit the full rating pipeline.
        resample = rng.choices(results, k=sample_size)
        counts, _ = _build_win_matrix(resample)
        ratings = _theta_to_elo(_bt_mle(counts, model_names))
        for name in model_names:
            draws[name].append(ratings.get(name, 1500.0))

    tail = (1.0 - ci) / 2.0
    lo_pct = tail * 100
    hi_pct = (1.0 - tail) * 100

    intervals: dict[str, tuple[float, float]] = {}
    for name in model_names:
        ordered = sorted(draws[name])
        lo_i = int(len(ordered) * lo_pct / 100)
        hi_i = min(int(len(ordered) * hi_pct / 100), len(ordered) - 1)
        intervals[name] = (ordered[lo_i], ordered[hi_i])

    return intervals
226
+
227
+
228
def rankings_resolved(board: Leaderboard) -> bool:
    """Check if all adjacent ranks have non-overlapping 95% CIs.

    Returns True when the ranking order is statistically resolved — i.e. for
    every pair of adjacent models in the ranking, the higher-ranked model's
    CI lower bound exceeds the lower-ranked model's CI upper bound.
    """
    if not board.elo_ci:
        return False
    order = board.ranked
    if len(order) < 2:
        return False

    for (upper_model, _), (lower_model, _) in zip(order, order[1:]):
        upper_ci = board.elo_ci.get(upper_model)
        lower_ci = board.elo_ci.get(lower_model)
        if upper_ci is None or lower_ci is None:
            return False  # missing CI — cannot claim resolution
        if lower_ci[1] >= upper_ci[0]:
            return False  # CIs overlap
    return True
250
+
251
+
252
def compute_elo(
    results: list[ComparisonResult],
    model_names: list[str],
    n_bootstrap: int = 1000,
) -> Leaderboard:
    """Compute ELO ratings from pairwise comparison results using Bradley-Terry MLE.

    Handles position-bias unswapping: if a result has swapped=True,
    the winner is flipped before updating ratings.

    Bootstrap confidence intervals are computed when n_bootstrap > 0.
    """
    board = Leaderboard(
        elo=dict.fromkeys(model_names, INITIAL_ELO),
        wins=dict.fromkeys(model_names, 0),
        losses=dict.fromkeys(model_names, 0),
        ties=dict.fromkeys(model_names, 0),
    )

    # Tally win/loss/tie counts and record each comparison in the log.
    for r in results:
        outcome = _unswap_winner(r.winner, r.swapped)

        if outcome == "A":
            board.wins[r.model_a] += 1
            board.losses[r.model_b] += 1
        elif outcome == "B":
            board.wins[r.model_b] += 1
            board.losses[r.model_a] += 1
        else:
            board.ties[r.model_a] += 1
            board.ties[r.model_b] += 1

        board.comparison_log.append(
            {
                "sample_idx": r.sample_idx,
                "model_a": r.model_a,
                "model_b": r.model_b,
                "winner": outcome,
                "reason": r.reason,
                "agreement": r.agreement,
                "text_a": r.text_a,
                "text_b": r.text_b,
                "col_a": r.col_a,
                "col_b": r.col_b,
            }
        )

    # Replace the provisional ratings with the fitted BT-MLE values.
    win_counts, _ = _build_win_matrix(results)
    board.elo = _theta_to_elo(_bt_mle(win_counts, model_names))

    # Bootstrap CIs (skipped when disabled or when there is nothing to resample).
    if n_bootstrap > 0 and results:
        board.elo_ci = _bootstrap_ci(results, model_names, n_bootstrap=n_bootstrap)

    return board
src/ocr_bench/judge.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pairwise VLM judge — prompt templates, structured output schema, comparison building."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import io
7
+ import json
8
+ import logging
9
+ import random
10
+ from dataclasses import dataclass
11
+ from itertools import combinations
12
+ from typing import Any
13
+
14
+ from PIL import Image
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# --- Judge prompt ---

# Pairwise comparison prompt. The {ocr_text_a}/{ocr_text_b} placeholders are
# filled by build_prompt(); the doubled {{...}} braces survive .format() as a
# literal JSON example shown to the judge.
PAIRWISE_PROMPT = """\
You are an expert OCR quality evaluator. You are given a document image and \
TWO OCR outputs (A and B) extracted from that same image.

Compare them and decide which extraction is better overall.

Evaluation criteria (in priority order):

1. Faithfulness: The output must ONLY contain text actually visible in the document. \
Hallucinating text that is not in the image (garbled strings, repeated tokens, \
nonsensical output) is the most serious error. Added commentary or notes \
(e.g. "it appears the text says...") is also an error, but less severe than \
hallucination. If a page is blank or has minimal text, saying so is acceptable — \
fabricating content is always worse.

2. Completeness: ALL visible text must be captured — headers, footers, marginalia, \
stamps, handwritten notes. Missing any section of text is a significant penalty.

3. Accuracy: Correct characters, no garbled or fabricated words.

4. Reading order: Text flows naturally as a human would read the document.

5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \
if present. Do NOT prefer fancier markdown formatting — plain accurate text is \
better than nicely formatted but incomplete text.

If both outputs capture the same text with similar accuracy, respond with "tie". \
Only pick a winner when there is a clear quality difference.

Output A:
---
{ocr_text_a}
---

Output B:
---
{ocr_text_b}
---

Respond with JSON only (no markdown fences, no extra text):
{{"winner": "A", "reason": "brief explanation"}}
Use "A", "B", or "tie" for the winner field."""

# JSON schema for the judge's structured output; mirrors what
# parse_judge_output() accepts.
JUDGE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "winner": {"type": "string", "enum": ["A", "B", "tie"]},
        "reason": {"type": "string"},
    },
    "required": ["winner", "reason"],
}

# Max characters of OCR text to include per output in the prompt.
MAX_OCR_TEXT_LENGTH = 2500

# Max image dimension (longer side) before resizing.
MAX_IMAGE_DIM = 1024
78
+
79
+ # --- Image helpers ---
80
+
81
+
82
def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str:
    """Convert a PIL image to a base64-encoded JPEG string, resizing if needed."""
    rgb = image.convert("RGB") if image.mode != "RGB" else image
    longest = max(rgb.size)
    if longest > max_dim:
        # Downscale so the longer side equals max_dim, preserving aspect ratio.
        scale = max_dim / longest
        rgb = rgb.resize(
            (int(rgb.width * scale), int(rgb.height * scale)),
            Image.Resampling.LANCZOS,
        )
    buffer = io.BytesIO()
    rgb.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
93
+
94
+
95
+ # --- Comparison ---
96
+
97
+
98
@dataclass
class Comparison:
    """A single pairwise comparison to evaluate."""

    # Row index in the source dataset.
    sample_idx: int
    # Model names in canonical (unswapped) order.
    model_a: str
    model_b: str
    # Dataset column names the two texts were read from.
    col_a: str
    col_b: str
    # True when the texts were swapped in the prompt (position-bias control).
    swapped: bool
    # Pre-built chat messages (image + prompt) ready for the judge backend.
    messages: list[dict[str, Any]]
    # Raw (unswapped) OCR texts, kept for logging and inspection.
    text_a: str = ""
    text_b: str = ""
111
+
112
+
113
def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]:
    """Build the pairwise comparison prompt, applying position-bias swap.

    Returns (prompt_text, swapped).
    """
    # Truncate both outputs so the prompt stays within a bounded size.
    first = text_a[:MAX_OCR_TEXT_LENGTH]
    second = text_b[:MAX_OCR_TEXT_LENGTH]
    if swapped:
        first, second = second, first
    return PAIRWISE_PROMPT.format(ocr_text_a=first, ocr_text_b=second), swapped
123
+
124
+
125
def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]:
    """Build chat messages for the judge (image + prompt)."""
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    return [{"role": "user", "content": [image_part, text_part]}]
139
+
140
+
141
+ def _normalize_pair(a: str, b: str) -> tuple[str, str]:
142
+ """Return a canonical (sorted) pair for symmetric lookup."""
143
+ return (a, b) if a <= b else (b, a)
144
+
145
+
146
def sample_indices(
    dataset_len: int, max_samples: int | None = None, seed: int = 42
) -> list[int]:
    """Compute shuffled sample indices (cheap — no image loading).

    Args:
        dataset_len: Total number of rows in the dataset.
        max_samples: If set, randomly sample this many indices.
        seed: Random seed for reproducible sampling.

    Returns:
        List of integer indices into the dataset.
    """
    all_indices = list(range(dataset_len))
    if not max_samples or max_samples >= dataset_len:
        return all_indices
    # NOTE(review): seeds the module-global RNG (not a local random.Random)
    # to keep historical sampling reproducible.
    random.seed(seed)
    return random.sample(all_indices, max_samples)
164
+
165
+
166
def build_comparisons(
    dataset: Any,
    ocr_columns: dict[str, str],
    max_samples: int | None = None,
    seed: int = 42,
    skip_pairs: set[tuple[str, str]] | None = None,
    indices: list[int] | None = None,
) -> list[Comparison]:
    """Build pairwise comparison prompts from a dataset.

    Args:
        dataset: HF dataset with an "image" column and OCR output columns.
        ocr_columns: Mapping of column_name -> model_name.
        max_samples: If set, randomly sample this many rows. Ignored when
            ``indices`` is provided.
        seed: Random seed for sampling and position-bias randomization.
        skip_pairs: Set of (model_a, model_b) pairs to exclude. Pairs are
            normalized so (a, b) and (b, a) are treated identically.
            If None, all pairs are included.
        indices: Explicit row indices to use. When provided, ``max_samples``
            and ``seed`` are not used for index selection (seed is still used
            for position-bias randomization).

    Returns:
        List of Comparison objects with pre-built chat messages.
    """
    col_names = list(ocr_columns.keys())
    model_names = list(ocr_columns.values())
    # All unordered column-index pairs; each is a candidate matchup per row.
    pairs = list(combinations(range(len(col_names)), 2))

    # Normalize skip set for symmetric lookup
    normalized_skip: set[tuple[str, str]] = set()
    if skip_pairs:
        normalized_skip = {_normalize_pair(a, b) for a, b in skip_pairs}

    if indices is None:
        indices = sample_indices(len(dataset), max_samples, seed)

    # Dedicated RNG so the position-bias coin flips are reproducible per seed.
    rng = random.Random(seed)
    comparisons: list[Comparison] = []

    # Pre-fetch text columns to avoid triggering image decode per row.
    # HF Dataset supports column access (dataset["col"]), plain lists don't.
    text_cols_data: dict[str, list] | None = None
    if hasattr(dataset, "column_names"):
        text_cols_data = {col: dataset[col] for col in col_names}

    for idx in indices:
        # Determine which pairs need judging for this row
        needed_pairs = [
            (i, j)
            for i, j in pairs
            if _normalize_pair(model_names[i], model_names[j]) not in normalized_skip
        ]
        if not needed_pairs:
            continue  # Skip image encoding entirely

        # Check text availability before decoding the image; a pair is only
        # judged when both outputs contain non-whitespace text.
        valid_pairs = []
        if text_cols_data is not None:
            for i, j in needed_pairs:
                text_a = text_cols_data[col_names[i]][idx] or ""
                text_b = text_cols_data[col_names[j]][idx] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))
        else:
            # Fallback for list-like datasets: full row access.
            row = dataset[idx]
            for i, j in needed_pairs:
                text_a = row[col_names[i]] or ""
                text_b = row[col_names[j]] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))

        if not valid_pairs:
            continue

        # One image encode per row, shared by all of the row's matchups.
        image_b64 = image_to_base64(dataset[idx]["image"])

        for i, j, text_a, text_b in valid_pairs:
            # Coin flip decides whether A/B are swapped in the prompt text;
            # the Comparison stores the unswapped texts plus the flag.
            swapped = rng.random() < 0.5
            prompt, swapped = build_prompt(text_a, text_b, swapped)
            messages = build_messages(image_b64, prompt)

            comparisons.append(
                Comparison(
                    sample_idx=idx,
                    model_a=model_names[i],
                    model_b=model_names[j],
                    col_a=col_names[i],
                    col_b=col_names[j],
                    swapped=swapped,
                    messages=messages,
                    text_a=text_a,
                    text_b=text_b,
                )
            )

    return comparisons
264
+
265
+
266
+ # --- Output parsing ---
267
+
268
+
269
def parse_judge_output(text: str) -> dict[str, str]:
    """Parse judge JSON output, handling markdown fences and invalid values.

    Accepts raw JSON or JSON wrapped in a markdown code fence (with an
    optional language tag such as ``json``). The ``winner`` field is
    normalized to exactly "A", "B", or "tie"; anything unrecognized — or a
    non-string value — falls back to "tie".

    Args:
        text: Raw completion text from the judge model.

    Returns:
        Dict with "winner" and "reason" keys, or an empty dict when the
        output cannot be parsed as a JSON object.
    """
    text = text.strip()
    if text.startswith("```"):
        # Strip the opening fence line (``` or ```json) and the closing fence.
        # Guard against a fence with no newline, which previously raised
        # IndexError from split("\n", 1)[1].
        newline = text.find("\n")
        if newline == -1:
            return {}
        text = text[newline + 1 :].rsplit("```", 1)[0].strip()
    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}
    if not isinstance(result, dict):
        # e.g. the judge returned a bare list or scalar; previously this
        # crashed with AttributeError on .get().
        logger.warning("Judge output is not a JSON object: %s", text[:200])
        return {}
    # str() guards against non-string winner values (previously AttributeError).
    winner = str(result.get("winner", "tie")).upper().strip()
    if winner == "TIE":
        winner = "tie"
    if winner not in ("A", "B", "tie"):
        winner = "tie"
    return {"winner": winner, "reason": str(result.get("reason", ""))}
src/ocr_bench/publish.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub publishing — push comparisons, leaderboard, and metadata configs to HF Hub."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import datetime
6
+ import json
7
+ from dataclasses import dataclass
8
+
9
+ import structlog
10
+ from datasets import Dataset, load_dataset
11
+ from huggingface_hub import HfApi
12
+
13
+ from ocr_bench.elo import ComparisonResult, Leaderboard
14
+
15
+ logger = structlog.get_logger()
16
+
17
+
18
+ @dataclass
19
+ class EvalMetadata:
20
+ """Metadata for an evaluation run, stored alongside results on Hub."""
21
+
22
+ source_dataset: str
23
+ judge_models: list[str]
24
+ seed: int
25
+ max_samples: int
26
+ total_comparisons: int
27
+ valid_comparisons: int
28
+ from_prs: bool = False
29
+ timestamp: str = ""
30
+
31
+ def __post_init__(self):
32
+ if not self.timestamp:
33
+ self.timestamp = datetime.datetime.now(datetime.UTC).isoformat()
34
+
35
+
36
def load_existing_comparisons(repo_id: str) -> list[ComparisonResult]:
    """Load previously published comparisons from a Hub results repo.

    Stored winners are already canonical (unswapped), hence ``swapped=False``.
    A missing repo or config is not an error: an empty list is returned.
    """
    try:
        ds = load_dataset(repo_id, name="comparisons", split="train")
    except Exception as exc:
        logger.info("no_existing_comparisons", repo=repo_id, reason=str(exc))
        return []

    # Optional fields use .get() so older repos without those columns load.
    results = [
        ComparisonResult(
            sample_idx=row["sample_idx"],
            model_a=row["model_a"],
            model_b=row["model_b"],
            winner=row["winner"],
            reason=row.get("reason", ""),
            agreement=row.get("agreement", "1/1"),
            swapped=False,
            text_a=row.get("text_a", ""),
            text_b=row.get("text_b", ""),
            col_a=row.get("col_a", ""),
            col_b=row.get("col_b", ""),
        )
        for row in ds
    ]
    logger.info("loaded_existing_comparisons", repo=repo_id, n=len(results))
    return results
67
+
68
+
69
def load_existing_metadata(repo_id: str) -> list[dict]:
    """Load prior run-metadata rows from a Hub results repo.

    A missing repo or config is not an error: an empty list is returned.
    """
    try:
        # Materialize plain dicts inside the try so any load/iteration
        # failure also falls back to the empty list.
        ds = load_dataset(repo_id, name="metadata", split="train")
        return [dict(record) for record in ds]
    except Exception as exc:
        logger.info("no_existing_metadata", repo=repo_id, reason=str(exc))
        return []
80
+
81
+
82
def build_leaderboard_rows(board: Leaderboard) -> list[dict]:
    """Flatten a Leaderboard into plain dict rows for a Hub dataset."""
    rows: list[dict] = []
    for model, rating in board.ranked:
        wins = board.wins[model]
        losses = board.losses[model]
        ties = board.ties[model]
        games = wins + losses + ties
        entry = {
            "model": model,
            "elo": round(rating),
            "wins": wins,
            "losses": losses,
            "ties": ties,
            # Win% over all games played; 0 for a model with no games.
            "win_pct": round(wins / games * 100) if games > 0 else 0,
        }
        # Confidence-interval columns only exist when the board computed CIs.
        bounds = board.elo_ci.get(model) if board.elo_ci else None
        if bounds is not None:
            entry["elo_low"] = round(bounds[0])
            entry["elo_high"] = round(bounds[1])
        rows.append(entry)
    return rows
101
+
102
+
103
def build_metadata_row(metadata: EvalMetadata) -> dict:
    """Serialize an EvalMetadata instance into a single flat dataset row."""
    m = metadata
    return {
        "source_dataset": m.source_dataset,
        # Lists don't fit a flat row schema; store JSON-encoded.
        "judge_models": json.dumps(m.judge_models),
        "seed": m.seed,
        "max_samples": m.max_samples,
        "total_comparisons": m.total_comparisons,
        "valid_comparisons": m.valid_comparisons,
        "from_prs": m.from_prs,
        "timestamp": m.timestamp,
    }
115
+
116
+
117
def publish_results(
    repo_id: str,
    board: Leaderboard,
    metadata: EvalMetadata,
    existing_metadata: list[dict] | None = None,
) -> None:
    """Push evaluation results to Hub as a dataset with multiple configs.

    Configs:
    - (default): Leaderboard table — ``load_dataset("repo")`` returns this.
    - ``leaderboard``: Same table, named config (backward compat for viewer).
    - ``comparisons``: Full comparison log from the board (caller merges
      existing + new before ``compute_elo``, so ``board.comparison_log``
      is already the complete set).
    - ``metadata``: Append-only run log. New row is appended to
      ``existing_metadata``.

    Nothing is caught here: network/auth errors from ``push_to_hub`` and
    ``upload_file`` propagate to the caller.
    """
    # Comparisons — only pushed when the board actually logged any.
    if board.comparison_log:
        comp_ds = Dataset.from_list(board.comparison_log)
        comp_ds.push_to_hub(repo_id, config_name="comparisons")
        logger.info("published_comparisons", repo=repo_id, n=len(board.comparison_log))

    # Leaderboard — dual push: default config (what a bare load_dataset call
    # returns) + named "leaderboard" config for viewer backward compatibility.
    rows = build_leaderboard_rows(board)
    lb_ds = Dataset.from_list(rows)
    lb_ds.push_to_hub(repo_id)
    lb_ds.push_to_hub(repo_id, config_name="leaderboard")
    logger.info("published_leaderboard", repo=repo_id, n=len(rows))

    # Metadata — append-only: prior rows (if any) are re-pushed with the new
    # run's row added, so the config accumulates full run history.
    meta_row = build_metadata_row(metadata)
    all_meta = (existing_metadata or []) + [meta_row]
    Dataset.from_list(all_meta).push_to_hub(repo_id, config_name="metadata")
    logger.info("published_metadata", repo=repo_id, n=len(all_meta))

    # README — auto-generated dataset card with leaderboard; encode() uses
    # UTF-8 by default, which the card (en-dashes etc.) requires.
    readme = _build_readme(repo_id, rows, board, metadata)
    api = HfApi()
    api.upload_file(
        path_or_fileobj=readme.encode(),
        path_in_repo="README.md",
        repo_id=repo_id,
        repo_type="dataset",
    )
    logger.info("published_readme", repo=repo_id)
163
+
164
+
165
def _build_readme(
    repo_id: str,
    rows: list[dict],
    board: Leaderboard,
    metadata: EvalMetadata,
) -> str:
    """Build a dataset card README (YAML front matter + markdown leaderboard).

    ``rows`` are the dicts produced by :func:`build_leaderboard_rows`; the
    returned string is the full README.md content, newline-terminated.
    """
    has_ci = bool(board.elo_ci)
    source_short = metadata.source_dataset.split("/")[-1]
    # judge_models may arrive JSON-encoded (e.g. round-tripped through a
    # stored metadata row) or as a plain list; normalize to a list either way.
    judges = json.loads(
        metadata.judge_models
        if isinstance(metadata.judge_models, str)
        else json.dumps(metadata.judge_models)
    )
    judge_str = ", ".join(j.split("/")[-1] for j in judges) if judges else "N/A"
    n_comparisons = len(board.comparison_log)

    # YAML front matter: declares the parquet layout of each config so the
    # Hub dataset viewer can render all four configs.
    lines = [
        "---",
        "license: mit",
        "tags:",
        "  - ocr-bench",
        "  - leaderboard",
        "configs:",
        "  - config_name: default",
        "    data_files:",
        "      - split: train",
        "        path: data/train-*.parquet",
        "  - config_name: comparisons",
        "    data_files:",
        "      - split: train",
        "        path: comparisons/train-*.parquet",
        "  - config_name: leaderboard",
        "    data_files:",
        "      - split: train",
        "        path: leaderboard/train-*.parquet",
        "  - config_name: metadata",
        "    data_files:",
        "      - split: train",
        "        path: metadata/train-*.parquet",
        "---",
        "",
        f"# OCR Bench Results: {source_short}",
        "",
        "VLM-as-judge pairwise evaluation of OCR models. "
        "Rankings depend on document type — there is no single best OCR model.",
        "",
        "## Leaderboard",
        "",
    ]

    # Table header — the CI column only exists when the board computed CIs.
    if has_ci:
        lines.append("| Rank | Model | ELO | 95% CI | Wins | Losses | Ties | Win% |")
        lines.append("|------|-------|-----|--------|------|--------|------|------|")
    else:
        lines.append("| Rank | Model | ELO | Wins | Losses | Ties | Win% |")
        lines.append("|------|-------|-----|------|--------|------|------|")

    for rank, row in enumerate(rows, 1):
        model = row["model"]
        elo = row["elo"]
        # Per-row guard on "elo_low": a bootstrap may omit CIs for some models.
        if has_ci and "elo_low" in row:
            # \u2013 is an en-dash for the low–high range.
            ci = f"{row['elo_low']}\u2013{row['elo_high']}"
            lines.append(
                f"| {rank} | {model} | {elo} | {ci} "
                f"| {row['wins']} | {row['losses']} | {row['ties']} "
                f"| {row['win_pct']}% |"
            )
        else:
            lines.append(
                f"| {rank} | {model} | {elo} "
                f"| {row['wins']} | {row['losses']} | {row['ties']} "
                f"| {row['win_pct']}% |"
            )

    lines += [
        "",
        "## Details",
        "",
        f"- **Source dataset**: [`{metadata.source_dataset}`]"
        f"(https://huggingface.co/datasets/{metadata.source_dataset})",
        f"- **Judge**: {judge_str}",
        f"- **Comparisons**: {n_comparisons}",
        "- **Method**: Bradley-Terry MLE with bootstrap 95% CIs",
        "",
        "## Configs",
        "",
        f"- `load_dataset(\"{repo_id}\")` — leaderboard table",
        f"- `load_dataset(\"{repo_id}\", name=\"comparisons\")` "
        "— full pairwise comparison log",
        f"- `load_dataset(\"{repo_id}\", name=\"metadata\")` "
        "— evaluation run history",
        "",
        "*Generated by [ocr-bench](https://github.com/davanstrien/ocr-bench)*",
    ]

    return "\n".join(lines) + "\n"
src/ocr_bench/run.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OCR model orchestration — launch HF Jobs for multiple OCR models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from dataclasses import dataclass, field
7
+
8
+ import structlog
9
+ from huggingface_hub import HfApi, get_token
10
+
11
+ logger = structlog.get_logger()
12
+
13
+
14
@dataclass
class ModelConfig:
    """Configuration for a single OCR model."""

    # URL of the uv script on the Hub that runs this model end-to-end.
    script: str
    # Hub model repo id the script loads.
    model_id: str
    # Human-readable parameter-count label (e.g. "0.9B"); not read by the
    # launch code in this module.
    size: str
    # HF Jobs hardware flavor used unless overridden at launch time.
    default_flavor: str = "l4x1"
    # Extra CLI args always appended for this model (e.g. prompt mode).
    default_args: list[str] = field(default_factory=list)
23
+
24
+
25
# Registry of runnable OCR models, keyed by short slug. Each entry points at
# a uv script hosted on the Hub (uv-scripts/ocr) that runs the model.
MODEL_REGISTRY: dict[str, ModelConfig] = {
    "glm-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/glm-ocr.py",
        model_id="zai-org/GLM-OCR",
        size="0.9B",
        default_flavor="l4x1",
    ),
    "deepseek-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/deepseek-ocr-vllm.py",
        model_id="deepseek-ai/DeepSeek-OCR",
        size="4B",
        default_flavor="l4x1",
        # Free-form prompt mode rather than the script's default.
        default_args=["--prompt-mode", "free"],
    ),
    "lighton-ocr-2": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/lighton-ocr2.py",
        model_id="lightonai/LightOnOCR-2-1B",
        size="1B",
        default_flavor="a100-large",
    ),
    "dots-ocr": ModelConfig(
        script="https://huggingface.co/datasets/uv-scripts/ocr/raw/main/dots-ocr.py",
        model_id="rednote-hilab/dots.ocr",
        size="1.7B",
        default_flavor="l4x1",
    ),
}

# Models launched when the caller doesn't pass an explicit selection.
DEFAULT_MODELS = ["glm-ocr", "deepseek-ocr", "lighton-ocr-2", "dots-ocr"]
54
+
55
+
56
@dataclass
class JobRun:
    """Tracks a launched HF Job."""

    # Slug from MODEL_REGISTRY that this job runs.
    model_slug: str
    # HF Jobs id, used for polling via inspect_job.
    job_id: str
    # Web URL of the job (logs/status page).
    job_url: str
    # "running" until poll_jobs observes a terminal stage, then the
    # lowercased stage name (e.g. "completed", "error").
    status: str = "running"
64
+
65
+
66
def list_models() -> list[str]:
    """Return the available model slugs in alphabetical order."""
    # Iterating the registry directly yields its keys.
    return sorted(MODEL_REGISTRY)
69
+
70
+
71
def build_script_args(
    input_dataset: str,
    output_repo: str,
    config_name: str,
    *,
    max_samples: int | None = None,
    shuffle: bool = False,
    seed: int = 42,
    extra_args: list[str] | None = None,
) -> list[str]:
    """Assemble the CLI argument list handed to an OCR uv script.

    Optional flags are emitted only when they differ from the scripts' own
    defaults, keeping the command line minimal.
    """
    cli = [
        input_dataset,
        output_repo,
        "--config",
        config_name,
        "--create-pr",
    ]
    if max_samples is not None:
        cli.extend(["--max-samples", str(max_samples)])
    if shuffle:
        cli.append("--shuffle")
    # 42 matches the scripts' default seed; only pass it when it differs.
    if seed != 42:
        cli.extend(["--seed", str(seed)])
    cli.extend(extra_args or [])
    return cli
98
+
99
+
100
def launch_ocr_jobs(
    input_dataset: str,
    output_repo: str,
    *,
    models: list[str] | None = None,
    max_samples: int | None = None,
    split: str = "train",
    shuffle: bool = False,
    seed: int = 42,
    flavor_override: str | None = None,
    timeout: str = "4h",
    api: HfApi | None = None,
) -> list[JobRun]:
    """Launch one HF Job per selected model.

    Each job runs the model's uv script against ``input_dataset`` and writes
    OCR output to ``output_repo`` under a config named after the model slug.
    Returns a JobRun per launched job for use with :func:`poll_jobs`.

    Raises:
        RuntimeError: if no HF token is available.
        ValueError: if any requested model slug is not in MODEL_REGISTRY.

    NOTE(review): ``split`` is accepted but never used in this body —
    build_script_args has no split flag; confirm whether it should be passed.
    """
    if api is None:
        api = HfApi()

    # The token is forwarded to each job as a secret so the script can push.
    token = get_token()
    if not token:
        raise RuntimeError("No HF token found. Log in with `hf login` or set HF_TOKEN.")

    # Validate every slug up front so we fail before launching anything.
    selected = models or DEFAULT_MODELS
    for slug in selected:
        if slug not in MODEL_REGISTRY:
            raise ValueError(
                f"Unknown model: {slug}. Available: {', '.join(MODEL_REGISTRY.keys())}"
            )

    jobs: list[JobRun] = []
    for slug in selected:
        config = MODEL_REGISTRY[slug]
        flavor = flavor_override or config.default_flavor
        # The slug doubles as the output repo's config name.
        script_args = build_script_args(
            input_dataset,
            output_repo,
            slug,
            max_samples=max_samples,
            shuffle=shuffle,
            seed=seed,
            extra_args=config.default_args or None,
        )

        logger.info("launching_job", model=slug, flavor=flavor, script=config.script)
        job = api.run_uv_job(
            script=config.script,
            script_args=script_args,
            flavor=flavor,
            secrets={"HF_TOKEN": token},
            timeout=timeout,
        )
        jobs.append(JobRun(model_slug=slug, job_id=job.id, job_url=job.url))
        logger.info("job_launched", model=slug, job_id=job.id, url=job.url)

    return jobs
+
155
+
156
# Job lifecycle stages that will never transition again.
_TERMINAL_STAGES = frozenset({"COMPLETED", "ERROR", "CANCELED", "DELETED"})


def poll_jobs(
    jobs: list[JobRun],
    *,
    interval: int = 30,
    api: HfApi | None = None,
) -> list[JobRun]:
    """Poll until every job reaches a terminal stage.

    Each JobRun's ``status`` is updated in place (lowercased terminal stage)
    and the same list is returned. ``interval`` is the number of seconds
    slept between polling rounds.
    """
    if api is None:
        api = HfApi()

    # Only jobs still marked "running" need watching.
    watching = {job.job_id: job for job in jobs if job.status == "running"}

    while watching:
        time.sleep(interval)
        remaining: dict[str, JobRun] = {}
        for job_id, tracked in watching.items():
            stage = api.inspect_job(job_id=job_id).status.stage
            if stage in _TERMINAL_STAGES:
                tracked.status = stage.lower()
                logger.info("job_finished", model=tracked.model_slug, status=tracked.status)
            else:
                remaining[job_id] = tracked
        watching = remaining
        if watching:
            logger.info("jobs_pending", models=[j.model_slug for j in watching.values()])

    return jobs
src/ocr_bench/space.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HF Space entry point for ocr-bench viewer."""
2
+
3
+ import os
4
+
5
+ import uvicorn
6
+
7
+ from ocr_bench.web import create_app
8
+
9
+
10
def main():
    """Serve the ocr-bench viewer for the configured results repo."""
    # REPOS may be a comma-separated list; the viewer serves the first one.
    repos = os.environ.get("REPOS", "davanstrien/bpl-ocr-bench-results")
    repo_id = repos.split(",")[0].strip()
    app = create_app(repo_id)
    # Port 7860 is the HF Spaces convention; bind all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()
src/ocr_bench/static/style.css ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* ocr-bench viewer — Tufte-inspired minimal styles */
2
+
3
+ *,
4
+ *::before,
5
+ *::after {
6
+ box-sizing: border-box;
7
+ }
8
+
9
+ body {
10
+ font-family: system-ui, -apple-system, sans-serif;
11
+ color: #333;
12
+ background: #fff;
13
+ margin: 0;
14
+ padding: 0;
15
+ line-height: 1.5;
16
+ }
17
+
18
+ .container {
19
+ max-width: 960px;
20
+ margin: 0 auto;
21
+ padding: 0 1.5rem 3rem;
22
+ }
23
+
24
+ /* Navigation */
25
+ nav {
26
+ border-bottom: 1px solid #ddd;
27
+ padding: 0.75rem 0;
28
+ margin-bottom: 2rem;
29
+ display: flex;
30
+ align-items: baseline;
31
+ gap: 2rem;
32
+ }
33
+
34
+ nav .brand {
35
+ font-weight: 600;
36
+ color: #333;
37
+ text-decoration: none;
38
+ font-size: 0.9rem;
39
+ letter-spacing: 0.02em;
40
+ }
41
+
42
+ nav a {
43
+ color: #666;
44
+ text-decoration: none;
45
+ font-size: 0.85rem;
46
+ }
47
+
48
+ nav a:hover,
49
+ nav a.active {
50
+ color: #333;
51
+ }
52
+
53
+ nav a.active {
54
+ border-bottom: 2px solid #333;
55
+ padding-bottom: 2px;
56
+ }
57
+
58
+ /* Comparison layout */
59
+ .comparison-columns {
60
+ display: grid;
61
+ grid-template-columns: 1fr 1fr;
62
+ gap: 2rem;
63
+ margin: 1.5rem 0;
64
+ }
65
+
66
+ .ocr-column h3 {
67
+ font-size: 0.85rem;
68
+ font-weight: 600;
69
+ color: #666;
70
+ margin: 0 0 0.5rem;
71
+ padding-bottom: 0.35rem;
72
+ border-bottom: 1px solid #ddd;
73
+ letter-spacing: 0.02em;
74
+ }
75
+
76
+ .ocr-column h3.revealed {
77
+ color: #333;
78
+ }
79
+
80
+ .ocr-text {
81
+ font-family: "SF Mono", "Menlo", "Consolas", monospace;
82
+ font-size: 0.82rem;
83
+ line-height: 1.6;
84
+ white-space: pre-wrap;
85
+ word-break: break-word;
86
+ max-height: 50vh;
87
+ overflow-y: auto;
88
+ padding: 0.25rem 0;
89
+ color: #444;
90
+ }
91
+
92
+ /* Navigation header */
93
+ .comp-nav {
94
+ display: flex;
95
+ justify-content: flex-end;
96
+ align-items: baseline;
97
+ gap: 0.75rem;
98
+ margin-bottom: 0.5rem;
99
+ color: #999;
100
+ font-size: 0.8rem;
101
+ }
102
+
103
+ .comp-nav a {
104
+ color: #999;
105
+ text-decoration: none;
106
+ font-size: 0.85rem;
107
+ padding: 0.15rem 0.4rem;
108
+ }
109
+
110
+ .comp-nav a:hover {
111
+ color: #333;
112
+ }
113
+
114
+ /* Vote prompt */
115
+ .vote-prompt {
116
+ text-align: center;
117
+ font-size: 0.8rem;
118
+ color: #999;
119
+ margin: 1.5rem 0 0.5rem;
120
+ }
121
+
122
+ /* Vote buttons */
123
+ .vote-row {
124
+ text-align: center;
125
+ margin: 0.25rem 0 0.5rem;
126
+ display: flex;
127
+ justify-content: center;
128
+ gap: 0.5rem;
129
+ }
130
+
131
+ .vote-btn {
132
+ display: inline-block;
133
+ color: #555;
134
+ text-decoration: none;
135
+ padding: 0.35rem 1rem;
136
+ border: 1px solid #ddd;
137
+ border-radius: 4px;
138
+ font-size: 0.85rem;
139
+ transition: border-color 0.15s, color 0.15s;
140
+ }
141
+
142
+ .vote-btn:hover {
143
+ color: #333;
144
+ border-color: #999;
145
+ }
146
+
147
+ .vote-btn.vote-tie {
148
+ color: #888;
149
+ }
150
+
151
+ /* Hints below vote buttons */
152
+ .vote-hints {
153
+ text-align: center;
154
+ margin: 0.5rem 0 1rem;
155
+ font-size: 0.75rem;
156
+ color: #bbb;
157
+ }
158
+
159
+ .vote-hints a {
160
+ color: #999;
161
+ text-decoration: none;
162
+ }
163
+
164
+ .vote-hints a:hover {
165
+ color: #666;
166
+ text-decoration: underline;
167
+ }
168
+
169
+ .vote-hints .separator {
170
+ color: #ddd;
171
+ }
172
+
173
+ .vote-hints kbd {
174
+ font-family: system-ui, sans-serif;
175
+ font-size: 0.7rem;
176
+ padding: 0.05rem 0.3rem;
177
+ border: 1px solid #ddd;
178
+ border-radius: 3px;
179
+ background: #f8f8f8;
180
+ color: #999;
181
+ }
182
+
183
+ /* Legacy reveal-row (kept for compat) */
184
+ .reveal-row {
185
+ text-align: right;
186
+ margin: 0.25rem 0 1rem;
187
+ font-size: 0.8rem;
188
+ }
189
+
190
+ .reveal-row a {
191
+ color: #999;
192
+ text-decoration: none;
193
+ }
194
+
195
+ .reveal-row a:hover {
196
+ color: #666;
197
+ }
198
+
199
+ /* Verdict display */
200
+ .verdict {
201
+ margin: 1rem 0;
202
+ font-size: 0.85rem;
203
+ color: #555;
204
+ line-height: 1.6;
205
+ }
206
+
207
+ .verdict .agreement {
208
+ font-weight: 500;
209
+ }
210
+
211
+ .verdict .agreement.agreed {
212
+ color: #457b4d;
213
+ }
214
+
215
+ .verdict .agreement.soft-disagree {
216
+ color: #a07828;
217
+ }
218
+
219
+ .verdict .agreement.hard-disagree {
220
+ color: #b04040;
221
+ }
222
+
223
+ .verdict .reason {
224
+ font-style: italic;
225
+ color: #777;
226
+ display: block;
227
+ margin-top: 0.25rem;
228
+ }
229
+
230
+ /* Document image */
231
+ .doc-image {
232
+ margin: 1.5rem 0;
233
+ text-align: center;
234
+ }
235
+
236
+ .doc-image img {
237
+ max-width: 100%;
238
+ height: auto;
239
+ max-height: 60vh;
240
+ }
241
+
242
+ /* Leaderboard table */
243
+ table {
244
+ width: 100%;
245
+ border-collapse: collapse;
246
+ font-size: 0.85rem;
247
+ margin: 1.5rem 0;
248
+ }
249
+
250
+ thead th {
251
+ text-align: left;
252
+ font-weight: 600;
253
+ padding: 0.5rem 0.75rem;
254
+ border-bottom: 2px solid #333;
255
+ color: #333;
256
+ font-size: 0.8rem;
257
+ letter-spacing: 0.02em;
258
+ }
259
+
260
+ thead th.num {
261
+ text-align: right;
262
+ }
263
+
264
+ tbody td {
265
+ padding: 0.4rem 0.75rem;
266
+ border-bottom: 1px solid #eee;
267
+ }
268
+
269
+ tbody td.num {
270
+ text-align: right;
271
+ font-variant-numeric: tabular-nums;
272
+ }
273
+
274
+ tbody td.model {
275
+ font-weight: 500;
276
+ }
277
+
278
+ tbody tr:hover {
279
+ background: #fafafa;
280
+ }
281
+
282
+ /* Filters */
283
+ .filters {
284
+ display: flex;
285
+ gap: 1rem;
286
+ margin-bottom: 1rem;
287
+ align-items: center;
288
+ }
289
+
290
+ .filters label {
291
+ font-size: 0.8rem;
292
+ color: #666;
293
+ }
294
+
295
+ .filters select {
296
+ font-size: 0.8rem;
297
+ padding: 0.25rem 0.5rem;
298
+ border: 1px solid #ddd;
299
+ border-radius: 3px;
300
+ background: #fff;
301
+ color: #333;
302
+ }
303
+
304
+ /* Stats panel */
305
+ .stats-panel {
306
+ color: #888;
307
+ font-size: 0.8rem;
308
+ padding: 1rem 0;
309
+ border-top: 1px solid #eee;
310
+ margin-top: 2rem;
311
+ }
312
+
313
+ .stats-panel .calibrated {
314
+ color: #457b4d;
315
+ }
316
+
317
+ .stats-panel .warning {
318
+ color: #b04040;
319
+ }
320
+
321
+ /* Pair summary table */
322
+ .pair-summary {
323
+ margin-bottom: 1rem;
324
+ }
325
+
326
+ .pair-table {
327
+ width: auto;
328
+ font-size: 0.8rem;
329
+ color: #888;
330
+ }
331
+
332
+ .pair-table th {
333
+ font-size: 0.75rem;
334
+ color: #999;
335
+ font-weight: 500;
336
+ padding: 0.2rem 0.6rem;
337
+ border-bottom: 1px solid #ddd;
338
+ }
339
+
340
+ .pair-table td {
341
+ padding: 0.15rem 0.6rem;
342
+ border-bottom: 1px solid #f0f0f0;
343
+ }
344
+
345
+ /* HTMX loading indicator */
346
+ .htmx-indicator {
347
+ opacity: 0;
348
+ transition: opacity 200ms ease-in;
349
+ }
350
+
351
+ .htmx-request .htmx-indicator,
352
+ .htmx-request.htmx-indicator {
353
+ opacity: 1;
354
+ }
355
+
356
+ /* Empty state */
357
+ .empty {
358
+ text-align: center;
359
+ color: #999;
360
+ padding: 3rem 0;
361
+ font-size: 0.9rem;
362
+ }
363
+
364
+ /* Responsive */
365
+ @media (max-width: 768px) {
366
+ .comparison-columns {
367
+ grid-template-columns: 1fr;
368
+ }
369
+
370
+ .container {
371
+ padding: 0 1rem 2rem;
372
+ }
373
+
374
+ table {
375
+ display: block;
376
+ overflow-x: auto;
377
+ -webkit-overflow-scrolling: touch;
378
+ }
379
+ }
src/ocr_bench/templates/base.html ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>{% block title %}OCR Bench{% endblock %}</title>
  <link rel="stylesheet" href="/static/style.css">
  <script src="https://unpkg.com/htmx.org@2.0.4"></script>
</head>
<body>
  <div class="container">
    <nav>
      <a href="/" class="brand">ocr-bench</a>
      <a href="/leaderboard" {% if active_tab == "leaderboard" %}class="active"{% endif %}>Leaderboard</a>
      <a href="/comparisons" {% if active_tab == "comparisons" %}class="active"{% endif %}>Comparisons</a>
    </nav>
    {% block content %}{% endblock %}
  </div>

  <!-- Global keyboard shortcuts. Keys are delegated to elements the current
       page tags with data-nav / data-vote / data-action attributes, so each
       shortcut is a no-op on pages without the matching control. -->
  <script>
  document.addEventListener("keydown", function(e) {
    // Ignore when focus is in input/select/textarea
    var tag = document.activeElement.tagName.toLowerCase();
    if (tag === "input" || tag === "select" || tag === "textarea") return;

    if (e.key === "ArrowLeft") {
      // Previous comparison
      var prev = document.querySelector("[data-nav='prev']");
      if (prev) { prev.click(); e.preventDefault(); }
    } else if (e.key === "ArrowRight") {
      // Next comparison
      var next = document.querySelector("[data-nav='next']");
      if (next) { next.click(); e.preventDefault(); }
    } else if (e.key === "a" || e.key === "A") {
      var voteA = document.querySelector("[data-vote='A']");
      if (voteA) { voteA.click(); e.preventDefault(); }
    } else if (e.key === "b" || e.key === "B") {
      var voteB = document.querySelector("[data-vote='B']");
      if (voteB) { voteB.click(); e.preventDefault(); }
    } else if (e.key === "t" || e.key === "T") {
      var voteTie = document.querySelector("[data-vote='tie']");
      if (voteTie) { voteTie.click(); e.preventDefault(); }
    } else if (e.key === "r" || e.key === "R") {
      // Reveal the judge verdict
      var reveal = document.querySelector("[data-action='reveal']");
      if (reveal) { reveal.click(); e.preventDefault(); }
    }
  });
  </script>
</body>
</html>
src/ocr_bench/templates/comparison_card.html ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% if comp %}
2
+ <div class="comp-nav">
3
+ <span>{{ nav_idx + 1 }} of {{ nav_total }}</span>
4
+ {% if nav_idx > 0 %}
5
+ <a href="#" data-nav="prev"
6
+ hx-get="/comparisons/{{ nav_idx - 1 }}{% if winner_filter and winner_filter != 'All' %}?winner={{ winner_filter }}{% endif %}{% if model_filter and model_filter != 'All' %}{{ '&' if winner_filter and winner_filter != 'All' else '?' }}model={{ model_filter }}{% endif %}"
7
+ hx-target="#comparison-container">&larr;</a>
8
+ {% endif %}
9
+ {% if nav_idx < nav_total - 1 %}
10
+ <a href="#" data-nav="next"
11
+ hx-get="/comparisons/{{ nav_idx + 1 }}{% if winner_filter and winner_filter != 'All' %}?winner={{ winner_filter }}{% endif %}{% if model_filter and model_filter != 'All' %}{{ '&' if winner_filter and winner_filter != 'All' else '?' }}model={{ model_filter }}{% endif %}"
12
+ hx-target="#comparison-container">&rarr;</a>
13
+ {% endif %}
14
+ </div>
15
+
16
+ <div class="comparison-columns">
17
+ <div class="ocr-column">
18
+ {% if revealed %}
19
+ <h3 class="revealed">{{ model_a_name }}</h3>
20
+ {% else %}
21
+ <h3>A</h3>
22
+ {% endif %}
23
+ <div class="ocr-text">{{ display_text_a }}</div>
24
+ </div>
25
+ <div class="ocr-column">
26
+ {% if revealed %}
27
+ <h3 class="revealed">{{ model_b_name }}</h3>
28
+ {% else %}
29
+ <h3>B</h3>
30
+ {% endif %}
31
+ <div class="ocr-text">{{ display_text_b }}</div>
32
+ </div>
33
+ </div>
34
+
35
+ {% if not voted %}
36
+ <div class="vote-prompt">Which OCR output is better?</div>
37
+ <div class="vote-row">
38
+ <a href="#" data-vote="A" class="vote-btn"
39
+ hx-post="/vote/{{ comp_idx }}"
40
+ hx-vals='{"winner": "A"}'
41
+ hx-target="#comparison-container">A is better</a>
42
+ <a href="#" data-vote="tie" class="vote-btn vote-tie"
43
+ hx-post="/vote/{{ comp_idx }}"
44
+ hx-vals='{"winner": "tie"}'
45
+ hx-target="#comparison-container">Tie</a>
46
+ <a href="#" data-vote="B" class="vote-btn"
47
+ hx-post="/vote/{{ comp_idx }}"
48
+ hx-vals='{"winner": "B"}'
49
+ hx-target="#comparison-container">B is better</a>
50
+ </div>
51
+ <div class="vote-hints">
52
+ {% if not revealed %}
53
+ <a href="#" data-action="reveal"
54
+ hx-get="/reveal/{{ comp_idx }}"
55
+ hx-target="#comparison-container">show judge verdict</a>
56
+ <span class="separator">&middot;</span>
57
+ {% endif %}
58
+ <span class="keys">keys: <kbd>a</kbd> <kbd>t</kbd> <kbd>b</kbd> vote &middot; <kbd>&larr;</kbd> <kbd>&rarr;</kbd> navigate{% if not revealed %} &middot; <kbd>r</kbd> reveal{% endif %}</span>
59
+ </div>
60
+ {% endif %}
61
+
62
+ {% if revealed %}
63
+ <div class="verdict">
64
+ {% if voted %}
65
+ Judge: {{ judge_verdict }}
66
+ &middot; You: {{ human_vote }}
67
+ &middot; <span class="agreement {{ agreement_class }}">{{ agreement_word }}</span>
68
+ {% else %}
69
+ Judge: {{ judge_verdict }}
70
+ {% endif %}
71
+ {% if reason %}
72
+ <span class="reason">"{{ reason }}"</span>
73
+ {% endif %}
74
+ </div>
75
+ {% if just_voted and next_url %}
76
+ <div hx-get="{{ next_url }}" hx-trigger="load delay:1.2s" hx-target="#comparison-container"></div>
77
+ {% endif %}
78
+ {% endif %}
79
+
80
+ {% if has_image %}
81
+ <div class="doc-image">
82
+ <img src="/image/{{ sample_idx }}" alt="Document image" loading="lazy">
83
+ </div>
84
+ {% endif %}
85
+
86
+ {% else %}
87
+ <div class="empty">No comparisons match the current filters.</div>
88
+ {% endif %}
src/ocr_bench/templates/comparisons.html ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% extends "base.html" %}
2
+ {% block title %}Comparisons — OCR Bench{% endblock %}
3
+ {% block content %}
4
+ <div class="filters">
5
+ <label>Winner
6
+ <select name="winner"
7
+ hx-get="/comparisons/filter"
8
+ hx-target="#comparison-container"
9
+ hx-include="[name='model']">
10
+ <option value="All" {% if winner_filter == "All" %}selected{% endif %}>All</option>
11
+ <option value="A" {% if winner_filter == "A" %}selected{% endif %}>A</option>
12
+ <option value="B" {% if winner_filter == "B" %}selected{% endif %}>B</option>
13
+ <option value="tie" {% if winner_filter == "tie" %}selected{% endif %}>tie</option>
14
+ </select>
15
+ </label>
16
+ <label>Model
17
+ <select name="model"
18
+ hx-get="/comparisons/filter"
19
+ hx-target="#comparison-container"
20
+ hx-include="[name='winner']">
21
+ <option value="All" {% if model_filter == "All" %}selected{% endif %}>All</option>
22
+ {% for m in models %}
23
+ <option value="{{ m }}" {% if model_filter == m %}selected{% endif %}>{{ m }}</option>
24
+ {% endfor %}
25
+ </select>
26
+ </label>
27
+ </div>
28
+
29
+ {% if pair_summary %}
30
+ <div class="pair-summary">{{ pair_summary | safe }}</div>
31
+ {% endif %}
32
+
33
+ <div id="comparison-container">
34
+ {% include "comparison_card.html" %}
35
+ </div>
36
+
37
+ <div id="stats-panel" hx-get="/stats" hx-trigger="vote-recorded from:body" hx-swap="innerHTML">
38
+ {% include "stats_panel.html" %}
39
+ </div>
40
+ {% endblock %}
src/ocr_bench/templates/leaderboard.html ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% extends "base.html" %}
2
+ {% block title %}Leaderboard — OCR Bench{% endblock %}
3
+ {% block content %}
4
+ <h2 style="font-size: 1.1rem; font-weight: 600; margin-bottom: 0.25rem;">Leaderboard</h2>
5
+ <p style="font-size: 0.8rem; color: #888; margin-top: 0;">{{ repo_id }}</p>
6
+
7
+ <table>
8
+ <thead>
9
+ <tr>
10
+ <th>#</th>
11
+ <th>Model</th>
12
+ <th class="num">Judge ELO</th>
13
+ {% if has_ci %}<th class="num">95% CI</th>{% endif %}
14
+ <th class="num">Wins</th>
15
+ <th class="num">Losses</th>
16
+ <th class="num">Ties</th>
17
+ <th class="num">Win%</th>
18
+ {% if has_human_elo %}
19
+ <th class="num">Human ELO</th>
20
+ <th class="num">H-Win%</th>
21
+ {% endif %}
22
+ </tr>
23
+ </thead>
24
+ <tbody>
25
+ {% for row in rows %}
26
+ <tr>
27
+ <td>{{ loop.index }}</td>
28
+ <td class="model">{{ row.model_short }}</td>
29
+ <td class="num">{{ row.elo }}</td>
30
+ {% if has_ci %}<td class="num">{{ row.elo_low }}&ndash;{{ row.elo_high }}</td>{% endif %}
31
+ <td class="num">{{ row.wins }}</td>
32
+ <td class="num">{{ row.losses }}</td>
33
+ <td class="num">{{ row.ties }}</td>
34
+ <td class="num">{{ row.win_pct }}%</td>
35
+ {% if has_human_elo %}
36
+ <td class="num">{{ row.human_elo if row.human_elo is not none else "—" }}</td>
37
+ <td class="num">{{ row.human_win_pct if row.human_win_pct is not none else "—" }}</td>
38
+ {% endif %}
39
+ </tr>
40
+ {% endfor %}
41
+ </tbody>
42
+ </table>
43
+ {% endblock %}
src/ocr_bench/templates/stats_panel.html ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
{# Compact agreement summary, re-fetched via htmx (hx-get="/stats") after
   each vote. Renders nothing until at least one vote exists. The 25% /
   15-vote thresholds mirror the calibration heuristics in validate.py. #}
{% if vote_count > 0 %}
<span>{{ vote_count }} vote{{ "s" if vote_count != 1 else "" }}</span>
&middot;
<span>{{ agreement_pct }}% agree</span>
{% if hard_disagree_rate > 25 %}
&middot; <span class="warning">judge may be miscalibrated</span>
{% elif vote_count >= 15 %}
&middot; <span class="calibrated">judge well-calibrated</span>
{% endif %}
{% endif %}
src/ocr_bench/validate.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Blind human A/B validation for OCR judge quality."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import random
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass, field
10
+ from typing import Any
11
+
12
+ import structlog
13
+
14
+ logger = structlog.get_logger()
15
+
16
+ # Confidence thresholds
17
+ MIN_ANNOTATIONS_FOR_CONFIDENCE = 15
18
+ HIGH_AGREEMENT_THRESHOLD = 0.75
19
+
20
+
21
@dataclass
class AgreementStats:
    """Running tally of human-vs-VLM-judge agreement.

    Disagreements are graded: a *soft* disagreement means one side said
    "tie" while the other picked a winner; a *hard* disagreement means
    both picked winners, but opposite ones.
    """

    agree: int = 0
    soft_disagree: int = 0  # one picks tie, other picks winner
    hard_disagree: int = 0  # both pick winners but opposite
    total: int = 0

    @property
    def agreement_rate(self) -> float:
        """Rate including soft disagreements as partial agreement."""
        if not self.total:
            return 0.0
        return (self.agree + self.soft_disagree) / self.total

    @property
    def hard_disagree_rate(self) -> float:
        """Fraction of annotations where human and judge picked opposite winners."""
        if not self.total:
            return 0.0
        return self.hard_disagree / self.total
38
+
39
+
40
@dataclass
class ValidationComparison:
    """A single comparison for human validation.

    Built from enriched comparison data published by the judge. The two
    OCR texts may be shown to the human in swapped order (``swapped``) to
    control for position bias; ``display_text_a``/``display_text_b`` hold
    the texts in the order actually shown.
    """

    comparison_id: int  # stable id within the session (re-assigned after ordering)
    sample_idx: int  # index into the source dataset
    model_a: str
    model_b: str
    winner: str  # judge's verdict ("A"/"B"/"tie") — hidden during annotation
    reason: str  # judge's explanation for the verdict
    agreement: str  # jury agreement (e.g. "2/2"; "1/2" means a split jury)
    text_a: str  # OCR text from model A (canonical, unswapped)
    text_b: str  # OCR text from model B (canonical, unswapped)
    col_a: str  # column name associated with model A's output
    col_b: str  # column name associated with model B's output
    swapped: bool  # position-bias randomization for human display
    display_text_a: str = ""  # text shown to human as "A" (may be swapped)
    display_text_b: str = ""  # text shown to human as "B" (may be swapped)
61
+
62
+
63
@dataclass
class ValidationSession:
    """Holds state for a validation session."""

    comparisons: list[ValidationComparison]  # ordered items to annotate
    model_names: list[str]  # all models appearing in the comparisons
    metadata: dict[str, Any] = field(default_factory=dict)  # session-level info
    annotations: list[dict[str, Any]] = field(default_factory=list)  # recorded human votes
    completed_ids: set[int] = field(default_factory=set)  # comparison_ids already voted on
72
+
73
+
74
+ def _is_split_jury(agreement: str) -> bool:
75
+ """Check if a jury vote was split (e.g. '1/2' not '2/2')."""
76
+ parts = agreement.split("/")
77
+ return len(parts) == 2 and parts[0] != parts[1]
78
+
79
+
80
+ def _interleave_by_sample(
81
+ comparisons: list[ValidationComparison],
82
+ ) -> list[ValidationComparison]:
83
+ """Interleave comparisons so you see different samples before repeating."""
84
+ by_sample: dict[int, list[ValidationComparison]] = defaultdict(list)
85
+ for comp in comparisons:
86
+ by_sample[comp.sample_idx].append(comp)
87
+
88
+ result: list[ValidationComparison] = []
89
+ queues = list(by_sample.values())
90
+ while queues:
91
+ next_round = []
92
+ for q in queues:
93
+ result.append(q.pop(0))
94
+ if q:
95
+ next_round.append(q)
96
+ queues = next_round
97
+ return result
98
+
99
+
100
def build_validation_comparisons(
    comparison_rows: list[dict[str, Any]],
    *,
    n: int | None = None,
    prioritize_splits: bool = True,
    seed: int = 42,
) -> list[ValidationComparison]:
    """Build validation comparisons from published judge results.

    Args:
        comparison_rows: Rows from the comparisons config of a results dataset.
        n: Max number of comparisons to include (None = all).
        prioritize_splits: Show split-jury cases first (most informative).
        seed: Random seed for position-bias randomization.

    Returns:
        Ordered comparisons with sequential ``comparison_id``s, interleaved
        by sample so the annotator sees varied documents.
    """
    from dataclasses import replace

    rng = random.Random(seed)

    comps: list[ValidationComparison] = []
    for i, row in enumerate(comparison_rows):
        # Randomly swap the display order so human votes carry no position bias.
        swapped = rng.random() < 0.5
        text_a = row.get("text_a", "")
        text_b = row.get("text_b", "")
        display_a, display_b = (text_b, text_a) if swapped else (text_a, text_b)

        comps.append(
            ValidationComparison(
                comparison_id=i,
                sample_idx=row.get("sample_idx", i),
                model_a=row.get("model_a", ""),
                model_b=row.get("model_b", ""),
                winner=row.get("winner", "tie"),
                reason=row.get("reason", ""),
                agreement=row.get("agreement", "1/1"),
                text_a=text_a,
                text_b=text_b,
                col_a=row.get("col_a", ""),
                col_b=row.get("col_b", ""),
                swapped=swapped,
                display_text_a=display_a,
                display_text_b=display_b,
            )
        )

    if prioritize_splits:
        # Split-jury verdicts are the most informative for calibration, so
        # surface them first; classify each comparison exactly once.
        splits: list[ValidationComparison] = []
        unanimous: list[ValidationComparison] = []
        for c in comps:
            (splits if _is_split_jury(c.agreement) else unanimous).append(c)
        ordered = _interleave_by_sample(splits) + _interleave_by_sample(unanimous)
    else:
        ordered = _interleave_by_sample(comps)

    if n is not None and n < len(ordered):
        ordered = ordered[:n]

    # Re-assign sequential comparison IDs after reordering; dataclasses.replace
    # avoids the error-prone 15-field manual reconstruction.
    return [replace(c, comparison_id=i) for i, c in enumerate(ordered)]
177
+
178
+
179
def compute_agreement(
    annotations: list[dict[str, Any]],
    comparisons: list[ValidationComparison],
) -> AgreementStats:
    """Tally agreement between human annotations and the judge's verdicts."""
    lookup = {c.comparison_id: c for c in comparisons}
    tally = AgreementStats()

    for ann in annotations:
        comp = lookup.get(ann.get("comparison_id"))
        if not comp:
            continue

        # The human voted on the displayed (possibly swapped) order; map the
        # vote back to canonical A/B before comparing with the judge.
        human = ann["winner"]
        if comp.swapped and human in ("A", "B"):
            human = "B" if human == "A" else "A"

        tally.total += 1
        if human == comp.winner:
            tally.agree += 1
        elif "tie" in (human, comp.winner):
            tally.soft_disagree += 1
        else:
            tally.hard_disagree += 1

    return tally
211
+
212
+
213
def compute_human_elo(
    annotations: list[dict[str, Any]],
    comparisons: list[ValidationComparison],
) -> Any:
    """Compute ELO leaderboard from human annotations.

    Returns a ``Leaderboard`` from ``elo.py``, or None if no annotations.
    """
    from ocr_bench.elo import ComparisonResult, compute_elo

    lookup = {c.comparison_id: c for c in comparisons}
    models: set[str] = set()
    outcomes: list[ComparisonResult] = []

    for ann in annotations:
        comp = lookup.get(ann.get("comparison_id"))
        if not comp:
            continue

        # Map the human's displayed-order vote back to canonical A/B.
        vote = ann["winner"]
        if comp.swapped and vote in ("A", "B"):
            vote = "B" if vote == "A" else "A"

        models.update((comp.model_a, comp.model_b))
        outcomes.append(
            ComparisonResult(
                sample_idx=comp.sample_idx,
                model_a=comp.model_a,
                model_b=comp.model_b,
                winner=vote,
            )
        )

    return compute_elo(outcomes, sorted(models)) if outcomes else None
255
+
256
+
257
def save_annotations(
    path: str,
    metadata: dict[str, Any],
    annotations: list[dict[str, Any]],
) -> None:
    """Atomically save annotations to a JSON file.

    Writes to ``<path>.tmp`` and then ``os.replace``s it into place, so a
    crash mid-write never leaves a truncated annotations file.

    Args:
        path: Destination JSON file path.
        metadata: Session metadata, stored under the ``metadata`` key.
        annotations: Annotation records, stored under ``annotations``.
    """
    data = {"metadata": metadata, "annotations": annotations}
    tmp = path + ".tmp"
    # Explicit UTF-8 + ensure_ascii=False: JSON is UTF-8 by spec, and relying
    # on the platform default encoding (e.g. cp1252 on Windows) would break
    # on non-ASCII OCR text; keeping it unescaped also keeps the file readable.
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    os.replace(tmp, path)
268
+
269
+
270
def load_annotations(path: str) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Load annotations from a JSON file.

    Args:
        path: JSON file previously written by ``save_annotations``.

    Returns:
        ``(metadata, annotations)`` — both empty if the file does not exist.
    """
    if not os.path.exists(path):
        return {}, []
    # Explicit UTF-8 to match the writer, independent of the platform default.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return data.get("metadata", {}), data.get("annotations", [])
277
+
278
+
279
+ def _agreement_banner(stats: AgreementStats) -> str:
280
+ """Format agreement stats for display."""
281
+ if stats.total == 0:
282
+ return ""
283
+
284
+ parts = [f"Agree: {stats.agree}"]
285
+ if stats.soft_disagree:
286
+ parts.append(f"Soft: {stats.soft_disagree}")
287
+ if stats.hard_disagree:
288
+ parts.append(f"**Hard: {stats.hard_disagree}**")
289
+ parts.append(f"(of {stats.total})")
290
+
291
+ confidence = ""
292
+ if stats.total >= MIN_ANNOTATIONS_FOR_CONFIDENCE:
293
+ if stats.hard_disagree_rate == 0:
294
+ confidence = (
295
+ f" -- No hard disagreements after {stats.total} annotations. "
296
+ "Judge rankings reliable for this domain."
297
+ )
298
+ elif stats.hard_disagree_rate <= 0.1:
299
+ confidence = (
300
+ f" -- Very few hard disagreements ({stats.hard_disagree}). "
301
+ "Rankings likely trustworthy."
302
+ )
303
+ elif stats.hard_disagree_rate > 0.25:
304
+ confidence = (
305
+ f" -- Many hard disagreements ({stats.hard_disagree}/{stats.total}). "
306
+ "Judge may not be calibrated for this content."
307
+ )
308
+
309
+ return f"Judge: {' | '.join(parts)}{confidence}"
310
+
311
+
src/ocr_bench/viewer.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Results viewer — data loading and helpers for OCR bench results."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import structlog
8
+ from datasets import load_dataset
9
+
10
+ if TYPE_CHECKING:
11
+ from PIL import Image
12
+
13
+ logger = structlog.get_logger()
14
+
15
+
16
def load_results(repo_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Load leaderboard and comparisons from a Hub results dataset.

    Tries the default config first (new repos), then falls back to the
    named ``leaderboard`` config (old repos).

    Returns:
        (leaderboard_rows, comparison_rows)
    """
    try:
        board = load_dataset(repo_id, split="train")
        leaderboard_rows = [dict(row) for row in board]
    except Exception:
        # Older repos published the leaderboard under a named config.
        fallback = load_dataset(repo_id, name="leaderboard", split="train")
        leaderboard_rows = [dict(row) for row in fallback]

    try:
        comps = load_dataset(repo_id, name="comparisons", split="train")
    except Exception:
        # Comparisons are optional; some repos only publish the leaderboard.
        logger.warning("no_comparisons_config", repo=repo_id)
        return leaderboard_rows, []

    return leaderboard_rows, [dict(row) for row in comps]
40
+
41
+
42
def _load_source_metadata(repo_id: str) -> dict[str, Any]:
    """Load metadata config from results repo to find the source dataset."""
    try:
        meta = load_dataset(repo_id, name="metadata", split="train")
        if len(meta) > 0:
            return dict(meta[0])
    except Exception as exc:
        # Metadata is optional; log and fall through to an empty dict.
        logger.warning("could_not_load_metadata", repo=repo_id, error=str(exc))
    return {}
+
52
+
53
class ImageLoader:
    """Lazy image loader — fetches images from source dataset by sample_idx.

    All network work (image-column discovery, PR-revision lookup) is deferred
    to the first ``get`` call, so constructing the loader is free.
    """

    def __init__(self, source_dataset: str, from_prs: bool = False):
        # Hub dataset id the results were generated from.
        self._source = source_dataset
        # Whether OCR outputs live on PR refs of the source repo.
        self._from_prs = from_prs
        # sample_idx -> image; unbounded, grows with each distinct index viewed.
        self._cache: dict[int, Any] = {}
        self._image_col: str | None = None
        self._pr_revision: str | None = None
        self._available = True  # flipped off permanently on init failure
        self._init_done = False

    def _init_source(self) -> None:
        """Lazy init: discover image column and PR revision on first call."""
        if self._init_done:
            return
        # Mark done up-front so a failing init is attempted only once.
        self._init_done = True

        try:
            if self._from_prs:
                from ocr_bench.dataset import discover_pr_configs

                _, revisions = discover_pr_configs(self._source)
                if revisions:
                    # Use the first PR revision to get images
                    first_config = next(iter(revisions))
                    self._pr_revision = revisions[first_config]

            # Probe for image column by loading 1 row
            kwargs: dict[str, Any] = {"path": self._source, "split": "train[:1]"}
            if self._pr_revision:
                # Load from the first PR config
                first_config = next(iter(revisions))
                kwargs["name"] = first_config
                kwargs["revision"] = self._pr_revision
            probe = load_dataset(**kwargs)
            # First column whose name contains "image" wins.
            for col in probe.column_names:
                if col == "image" or "image" in col.lower():
                    self._image_col = col
                    break
            if not self._image_col:
                logger.info("no_image_column_in_source", source=self._source)
                self._available = False
        except Exception as exc:
            # Best-effort: the viewer degrades to text-only on any failure.
            logger.warning("image_loader_init_failed", source=self._source, error=str(exc))
            self._available = False

    def get(self, sample_idx: int) -> Image.Image | None:
        """Fetch image for a sample index. Returns None on failure."""
        self._init_source()
        if not self._available or self._image_col is None:
            return None
        if sample_idx in self._cache:
            return self._cache[sample_idx]
        try:
            # Slice-load a single row instead of downloading the whole split.
            kwargs: dict[str, Any] = {
                "path": self._source,
                "split": f"train[{sample_idx}:{sample_idx + 1}]",
            }
            if self._pr_revision:
                from ocr_bench.dataset import discover_pr_configs

                # NOTE(review): re-discovers PR configs on every cache miss;
                # the revision found in _init_source could be reused instead.
                _, revisions = discover_pr_configs(self._source)
                if revisions:
                    first_config = next(iter(revisions))
                    kwargs["name"] = first_config
                    kwargs["revision"] = revisions[first_config]
            row = load_dataset(**kwargs)
            img = row[0][self._image_col]
            self._cache[sample_idx] = img
            return img
        except Exception as exc:
            logger.debug("image_load_failed", sample_idx=sample_idx, error=str(exc))
            return None
+
128
+
129
+ def _filter_comparisons(
130
+ comparisons: list[dict[str, Any]],
131
+ winner_filter: str,
132
+ model_filter: str,
133
+ ) -> list[dict[str, Any]]:
134
+ """Filter comparison rows by winner and model."""
135
+ filtered = comparisons
136
+ if winner_filter and winner_filter != "All":
137
+ filtered = [c for c in filtered if c.get("winner") == winner_filter]
138
+ if model_filter and model_filter != "All":
139
+ filtered = [
140
+ c
141
+ for c in filtered
142
+ if c.get("model_a") == model_filter or c.get("model_b") == model_filter
143
+ ]
144
+ return filtered
145
+
146
+
147
+ def _winner_badge(winner: str) -> str:
148
+ """Return a badge string for the winner."""
149
+ if winner == "A":
150
+ return "Winner: A"
151
+ elif winner == "B":
152
+ return "Winner: B"
153
+ else:
154
+ return "Tie"
155
+
156
+
157
+ def _model_label(model: str, col: str) -> str:
158
+ """Format model name with optional column name. Avoids empty parens."""
159
+ if col:
160
+ return f"{model} ({col})"
161
+ return model
162
+
163
+
164
+ def _build_pair_summary(comparisons: list[dict[str, Any]]) -> str:
165
+ """Build a win/loss summary string for each model pair."""
166
+ from collections import Counter
167
+
168
+ pair_counts: dict[tuple[str, str], Counter[str]] = {}
169
+ for c in comparisons:
170
+ ma = c.get("model_a", "")
171
+ mb = c.get("model_b", "")
172
+ winner = c.get("winner", "tie")
173
+ key = (ma, mb) if ma <= mb else (mb, ma)
174
+ if key not in pair_counts:
175
+ pair_counts[key] = Counter()
176
+ # Track from perspective of first model in sorted pair
177
+ if winner == "A":
178
+ actual_winner = ma
179
+ elif winner == "B":
180
+ actual_winner = mb
181
+ else:
182
+ actual_winner = "tie"
183
+
184
+ if actual_winner == key[0]:
185
+ pair_counts[key]["W"] += 1
186
+ elif actual_winner == key[1]:
187
+ pair_counts[key]["L"] += 1
188
+ else:
189
+ pair_counts[key]["T"] += 1
190
+
191
+ if not pair_counts:
192
+ return ""
193
+
194
+ parts = []
195
+ for (ma, mb), counts in sorted(pair_counts.items()):
196
+ short_a = ma.split("/")[-1] if "/" in ma else ma
197
+ short_b = mb.split("/")[-1] if "/" in mb else mb
198
+ wins, losses, ties = counts["W"], counts["L"], counts["T"]
199
+ parts.append(f"**{short_a}** vs **{short_b}**: {wins}W {losses}L {ties}T")
200
+ return " | ".join(parts)
201
+
202
+
src/ocr_bench/web.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI + HTMX viewer — unified browse + validate for OCR bench results."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import io
6
+ from dataclasses import dataclass, field
7
+ from datetime import UTC, datetime
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import structlog
12
+ from fastapi import FastAPI, Form, Request
13
+ from fastapi.responses import HTMLResponse, RedirectResponse, StreamingResponse
14
+ from fastapi.staticfiles import StaticFiles
15
+ from fastapi.templating import Jinja2Templates
16
+
17
+ from ocr_bench.validate import (
18
+ ValidationComparison,
19
+ build_validation_comparisons,
20
+ compute_agreement,
21
+ compute_human_elo,
22
+ load_annotations,
23
+ save_annotations,
24
+ )
25
+ from ocr_bench.viewer import (
26
+ ImageLoader,
27
+ _filter_comparisons,
28
+ _load_source_metadata,
29
+ load_results,
30
+ )
31
+
32
+ logger = structlog.get_logger()
33
+
34
+
35
+ def _short_model(model: str) -> str:
36
+ """Return just the model name after the org prefix."""
37
+ return model.split("/")[-1] if "/" in model else model
38
+
39
+
40
def _build_pair_summary_html(comparisons: list[dict[str, Any]]) -> str:
    """Build a compact HTML table of head-to-head records.

    Pairs are keyed in sorted model order so (a, b) and (b, a) matchups
    merge; W/L are from the perspective of the first model in the pair.

    Returns:
        An HTML ``<table>`` string, or ``""`` when there are no comparisons.

    NOTE(review): model names are interpolated into HTML without escaping;
    they come from the results dataset, but consider ``html.escape`` if the
    repo content is not trusted.
    """
    from collections import Counter, defaultdict

    # defaultdict(Counter) removes the manual "if key not in" init dance.
    pair_counts: dict[tuple[str, str], Counter[str]] = defaultdict(Counter)
    for c in comparisons:
        ma = c.get("model_a", "")
        mb = c.get("model_b", "")
        winner = c.get("winner", "tie")
        key = (ma, mb) if ma <= mb else (mb, ma)
        if winner == "A":
            actual_winner = ma
        elif winner == "B":
            actual_winner = mb
        else:
            actual_winner = "tie"
        if actual_winner == key[0]:
            pair_counts[key]["W"] += 1
        elif actual_winner == key[1]:
            pair_counts[key]["L"] += 1
        else:
            pair_counts[key]["T"] += 1

    if not pair_counts:
        return ""

    rows = []
    for (ma, mb), counts in sorted(pair_counts.items()):
        short_a = _short_model(ma)
        short_b = _short_model(mb)
        wins, losses, ties = counts["W"], counts["L"], counts["T"]
        rows.append(
            f"<tr><td>{short_a}</td><td>{short_b}</td>"
            f"<td class='num'>{wins}</td><td class='num'>{losses}</td>"
            f"<td class='num'>{ties}</td></tr>"
        )
    return (
        '<table class="pair-table"><thead><tr>'
        "<th>Model A</th><th>Model B</th>"
        '<th class="num">W</th><th class="num">L</th><th class="num">T</th>'
        "</tr></thead><tbody>" + "".join(rows) + "</tbody></table>"
    )
84
+
85
+
86
+ PKG_DIR = Path(__file__).parent
87
+ TEMPLATES_DIR = PKG_DIR / "templates"
88
+ STATIC_DIR = PKG_DIR / "static"
89
+
90
+
91
@dataclass
class ViewerState:
    """In-memory state for the single-user viewer."""

    repo_id: str  # HF dataset repo with published judge results
    leaderboard_rows: list[dict[str, Any]]  # rows from the leaderboard config
    comparison_rows: list[dict[str, Any]]  # rows from the comparisons config
    validation_comps: list[ValidationComparison]  # ordered items for blind voting
    models: list[str]  # sorted union of model_a/model_b names
    img_loader: ImageLoader | None  # None when no source dataset is known
    save_path: str  # where the human-annotations JSON is persisted
    annotations: list[dict[str, Any]] = field(default_factory=list)  # recorded votes
    completed_ids: set[int] = field(default_factory=set)  # comparison_ids already voted
    filtered_indices: list[int] = field(default_factory=list)  # nav idx -> comp idx map
+
106
+
107
def _build_filtered_indices(
    state: ViewerState,
    winner_filter: str = "All",
    model_filter: str = "All",
) -> list[int]:
    """Map nav indices to validation_comps indices, respecting filters."""
    surviving = _filter_comparisons(state.comparison_rows, winner_filter, model_filter)
    # Identify surviving rows by their (sample, model pair) identity so they
    # can be matched against the reordered validation comparisons.
    surviving_keys = {
        (row["sample_idx"], row["model_a"], row["model_b"]) for row in surviving
    }
    indices: list[int] = []
    for idx, comp in enumerate(state.validation_comps):
        if (comp.sample_idx, comp.model_a, comp.model_b) in surviving_keys:
            indices.append(idx)
    return indices
+ ]
123
+
124
+
125
+ def create_app(
126
+ repo_id: str,
127
+ *,
128
+ output_path: str | None = None,
129
+ n_validate: int | None = None,
130
+ ) -> FastAPI:
131
+ """Create the FastAPI app with all routes.
132
+
133
+ Args:
134
+ repo_id: HF dataset repo with published judge results.
135
+ output_path: Path to save human annotations JSON.
136
+ n_validate: Max comparisons to include for validation (None = all).
137
+ """
138
+ app = FastAPI(title=f"OCR Bench — {repo_id}")
139
+ app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
140
+ templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
141
+
142
+ # --- Load data ---
143
+ leaderboard_rows, comparison_rows = load_results(repo_id)
144
+
145
+ metadata = _load_source_metadata(repo_id)
146
+ source_dataset = metadata.get("source_dataset", "")
147
+ from_prs = metadata.get("from_prs", False)
148
+
149
+ img_loader: ImageLoader | None = None
150
+ if source_dataset:
151
+ img_loader = ImageLoader(source_dataset, from_prs=from_prs)
152
+
153
+ validation_comps = build_validation_comparisons(
154
+ comparison_rows, n=n_validate, prioritize_splits=True
155
+ )
156
+
157
+ models = sorted(
158
+ {c.get("model_a", "") for c in comparison_rows}
159
+ | {c.get("model_b", "") for c in comparison_rows}
160
+ )
161
+
162
+ slug = repo_id.replace("/", "-")
163
+ save_path = output_path or f"human-eval-{slug}.json"
164
+
165
+ # Resume existing annotations
166
+ _, existing_annotations = load_annotations(save_path)
167
+ completed_ids = {ann["comparison_id"] for ann in existing_annotations}
168
+
169
+ state = ViewerState(
170
+ repo_id=repo_id,
171
+ leaderboard_rows=leaderboard_rows,
172
+ comparison_rows=comparison_rows,
173
+ validation_comps=validation_comps,
174
+ models=models,
175
+ img_loader=img_loader,
176
+ save_path=save_path,
177
+ annotations=existing_annotations,
178
+ completed_ids=completed_ids,
179
+ filtered_indices=list(range(len(validation_comps))),
180
+ )
181
+
182
+ # Store state on app for access in routes
183
+ app.state.viewer = state
184
+
185
+ ann_metadata = {
186
+ "results_repo": repo_id,
187
+ "n_comparisons": len(validation_comps),
188
+ "models": models,
189
+ "started_at": datetime.now(UTC).isoformat(),
190
+ }
191
+
192
+ # --- Helpers ---
193
+
194
+ def _get_comp_context(
195
+ nav_idx: int,
196
+ *,
197
+ revealed: bool = False,
198
+ voted: bool = False,
199
+ human_vote: str = "",
200
+ winner_filter: str = "All",
201
+ model_filter: str = "All",
202
+ ) -> dict[str, Any]:
203
+ """Build template context for a comparison card."""
204
+ indices = state.filtered_indices
205
+ if nav_idx < 0 or nav_idx >= len(indices):
206
+ return {"comp": None, "nav_idx": nav_idx, "nav_total": len(indices)}
207
+
208
+ comp_idx = indices[nav_idx]
209
+ comp = state.validation_comps[comp_idx]
210
+
211
+ # Check if already voted
212
+ already_voted = comp.comparison_id in state.completed_ids
213
+ if already_voted:
214
+ voted = True
215
+ revealed = True
216
+ # Find the annotation to get human vote
217
+ for ann in state.annotations:
218
+ if ann["comparison_id"] == comp.comparison_id:
219
+ human_vote = ann["winner"]
220
+ break
221
+
222
+ # Model names — short form for clean headers
223
+ model_a_name = _short_model(comp.model_a)
224
+ model_b_name = _short_model(comp.model_b)
225
+ if comp.swapped:
226
+ model_a_name, model_b_name = model_b_name, model_a_name
227
+
228
+ # Judge verdict (canonical → display)
229
+ judge_winner = comp.winner
230
+ if comp.swapped:
231
+ if judge_winner == "A":
232
+ judge_verdict = "B"
233
+ elif judge_winner == "B":
234
+ judge_verdict = "A"
235
+ else:
236
+ judge_verdict = "tie"
237
+ else:
238
+ judge_verdict = judge_winner
239
+
240
+ # Agreement
241
+ agreement_word = ""
242
+ agreement_class = ""
243
+ if voted and human_vote:
244
+ # Unswap human vote for comparison
245
+ unswapped_human = human_vote
246
+ if comp.swapped:
247
+ if human_vote == "A":
248
+ unswapped_human = "B"
249
+ elif human_vote == "B":
250
+ unswapped_human = "A"
251
+
252
+ if unswapped_human == comp.winner:
253
+ agreement_word = "agreed"
254
+ agreement_class = "agreed"
255
+ elif unswapped_human == "tie" or comp.winner == "tie":
256
+ agreement_word = "soft disagree"
257
+ agreement_class = "soft-disagree"
258
+ else:
259
+ agreement_word = "hard disagree"
260
+ agreement_class = "hard-disagree"
261
+
262
+ has_image = img_loader is not None
263
+
264
+ return {
265
+ "comp": comp,
266
+ "comp_idx": comp_idx,
267
+ "nav_idx": nav_idx,
268
+ "nav_total": len(indices),
269
+ "revealed": revealed,
270
+ "voted": voted,
271
+ "display_text_a": comp.display_text_a,
272
+ "display_text_b": comp.display_text_b,
273
+ "model_a_name": model_a_name,
274
+ "model_b_name": model_b_name,
275
+ "judge_verdict": judge_verdict,
276
+ "human_vote": human_vote,
277
+ "agreement_word": agreement_word,
278
+ "agreement_class": agreement_class,
279
+ "reason": comp.reason,
280
+ "sample_idx": comp.sample_idx,
281
+ "has_image": has_image,
282
+ "winner_filter": winner_filter,
283
+ "model_filter": model_filter,
284
+ }
285
+
286
+ def _stats_context() -> dict[str, Any]:
287
+ """Build template context for the stats panel."""
288
+ stats = compute_agreement(state.annotations, state.validation_comps)
289
+ return {
290
+ "vote_count": stats.total,
291
+ "agreement_pct": round(stats.agreement_rate * 100) if stats.total else 0,
292
+ "hard_disagree_rate": round(stats.hard_disagree_rate * 100) if stats.total else 0,
293
+ }
294
+
295
+ def _nav_idx_for_comp_idx(comp_idx: int) -> int:
296
+ """Find the nav_idx for a given comp_idx in filtered_indices."""
297
+ try:
298
+ return state.filtered_indices.index(comp_idx)
299
+ except ValueError:
300
+ return 0
301
+
302
+ # --- Routes ---
303
+
304
+ @app.get("/", response_class=RedirectResponse)
305
+ async def index():
306
+ return RedirectResponse(url="/comparisons", status_code=302)
307
+
308
+ @app.get("/leaderboard", response_class=HTMLResponse)
309
+ async def leaderboard(request: Request):
310
+ # Build human ELO if we have annotations
311
+ human_board = compute_human_elo(state.annotations, state.validation_comps)
312
+
313
+ rows = []
314
+ for row in sorted(state.leaderboard_rows, key=lambda r: r.get("elo", 0), reverse=True):
315
+ model = row.get("model", "")
316
+ short = model.split("/")[-1] if "/" in model else model
317
+ human_elo = None
318
+ human_win_pct = None
319
+ if human_board and model in human_board.elo:
320
+ human_elo = round(human_board.elo[model])
321
+ wp = human_board.win_pct(model)
322
+ human_win_pct = f"{wp:.0f}" if wp is not None else None
323
+
324
+ rows.append({
325
+ "model": model,
326
+ "model_short": short,
327
+ "elo": round(row.get("elo", 0)),
328
+ "elo_low": row.get("elo_low"),
329
+ "elo_high": row.get("elo_high"),
330
+ "wins": row.get("wins", 0),
331
+ "losses": row.get("losses", 0),
332
+ "ties": row.get("ties", 0),
333
+ "win_pct": row.get("win_pct", 0),
334
+ "human_elo": human_elo,
335
+ "human_win_pct": human_win_pct,
336
+ })
337
+
338
+ has_ci = any(r.get("elo_low") is not None for r in rows)
339
+ return templates.TemplateResponse(request, "leaderboard.html", {
340
+ "active_tab": "leaderboard",
341
+ "repo_id": state.repo_id,
342
+ "rows": rows,
343
+ "has_ci": has_ci,
344
+ "has_human_elo": human_board is not None,
345
+ })
346
+
347
+ @app.get("/comparisons", response_class=HTMLResponse)
348
+ async def comparisons_page(request: Request):
349
+ state.filtered_indices = _build_filtered_indices(state)
350
+ pair_summary = _build_pair_summary_html(state.comparison_rows)
351
+ ctx = _get_comp_context(0)
352
+ stats = _stats_context()
353
+ return templates.TemplateResponse(request, "comparisons.html", {
354
+ "active_tab": "comparisons",
355
+ "models": state.models,
356
+ "pair_summary": pair_summary,
357
+ "winner_filter": "All",
358
+ "model_filter": "All",
359
+ **ctx,
360
+ **stats,
361
+ })
362
+
363
+ @app.get("/comparisons/filter", response_class=HTMLResponse)
364
+ async def comparisons_filter(
365
+ request: Request,
366
+ winner: str = "All",
367
+ model: str = "All",
368
+ ):
369
+ state.filtered_indices = _build_filtered_indices(state, winner, model)
370
+ ctx = _get_comp_context(0, winner_filter=winner, model_filter=model)
371
+ return templates.TemplateResponse(request, "comparison_card.html", ctx)
372
+
373
+ @app.get("/comparisons/{nav_idx}", response_class=HTMLResponse)
374
+ async def comparison_at(
375
+ request: Request,
376
+ nav_idx: int,
377
+ winner: str = "All",
378
+ model: str = "All",
379
+ ):
380
+ # Clamp nav_idx
381
+ nav_idx = max(0, min(nav_idx, len(state.filtered_indices) - 1))
382
+ ctx = _get_comp_context(nav_idx, winner_filter=winner, model_filter=model)
383
+ return templates.TemplateResponse(request, "comparison_card.html", ctx)
384
+
385
+ @app.post("/vote/{comp_idx}", response_class=HTMLResponse)
386
+ async def vote(request: Request, comp_idx: int, winner: str = Form(...)):
387
+ if comp_idx < 0 or comp_idx >= len(state.validation_comps):
388
+ return HTMLResponse("Invalid comparison", status_code=404)
389
+
390
+ comp = state.validation_comps[comp_idx]
391
+
392
+ # Idempotent: if already voted, just return revealed card
393
+ if comp.comparison_id not in state.completed_ids:
394
+ # Unswap for storage
395
+ winner_unswapped = winner
396
+ if comp.swapped:
397
+ if winner == "A":
398
+ winner_unswapped = "B"
399
+ elif winner == "B":
400
+ winner_unswapped = "A"
401
+
402
+ if winner_unswapped == "A":
403
+ winner_model = comp.model_a
404
+ elif winner_unswapped == "B":
405
+ winner_model = comp.model_b
406
+ else:
407
+ winner_model = "tie"
408
+
409
+ ann = {
410
+ "comparison_id": comp.comparison_id,
411
+ "sample_idx": comp.sample_idx,
412
+ "model_a": comp.model_a,
413
+ "model_b": comp.model_b,
414
+ "swapped": comp.swapped,
415
+ "winner": winner,
416
+ "winner_model": winner_model,
417
+ "timestamp": datetime.now(UTC).isoformat(),
418
+ }
419
+
420
+ state.annotations.append(ann)
421
+ state.completed_ids.add(comp.comparison_id)
422
+ save_annotations(state.save_path, ann_metadata, state.annotations)
423
+
424
+ nav_idx = _nav_idx_for_comp_idx(comp_idx)
425
+ # Read current filters from request query params (forwarded by htmx)
426
+ winner_filter = request.query_params.get("winner", "All")
427
+ model_filter = request.query_params.get("model", "All")
428
+
429
+ ctx = _get_comp_context(
430
+ nav_idx,
431
+ revealed=True,
432
+ voted=True,
433
+ human_vote=winner,
434
+ winner_filter=winner_filter,
435
+ model_filter=model_filter,
436
+ )
437
+ # Auto-advance: tell template this was a fresh vote
438
+ next_nav = nav_idx + 1 if nav_idx + 1 < len(state.filtered_indices) else None
439
+ ctx["just_voted"] = True
440
+ ctx["next_nav_idx"] = next_nav
441
+ ctx["next_url"] = (
442
+ f"/comparisons/{next_nav}"
443
+ + (f"?winner={winner_filter}" if winner_filter != "All" else "")
444
+ + (f"{'&' if winner_filter != 'All' else '?'}model={model_filter}" if model_filter != "All" else "")
445
+ if next_nav is not None
446
+ else None
447
+ )
448
+ response = templates.TemplateResponse(request, "comparison_card.html", ctx)
449
+ response.headers["HX-Trigger"] = "vote-recorded"
450
+ return response
451
+
452
+ @app.get("/reveal/{comp_idx}", response_class=HTMLResponse)
453
+ async def reveal(request: Request, comp_idx: int):
454
+ if comp_idx < 0 or comp_idx >= len(state.validation_comps):
455
+ return HTMLResponse("Invalid comparison", status_code=404)
456
+
457
+ nav_idx = _nav_idx_for_comp_idx(comp_idx)
458
+ winner_filter = request.query_params.get("winner", "All")
459
+ model_filter = request.query_params.get("model", "All")
460
+
461
+ ctx = _get_comp_context(
462
+ nav_idx,
463
+ revealed=True,
464
+ voted=False,
465
+ winner_filter=winner_filter,
466
+ model_filter=model_filter,
467
+ )
468
+ return templates.TemplateResponse(request, "comparison_card.html", ctx)
469
+
470
+ @app.get("/stats", response_class=HTMLResponse)
471
+ async def stats(request: Request):
472
+ ctx = _stats_context()
473
+ return templates.TemplateResponse(request, "stats_panel.html", ctx)
474
+
475
+ @app.get("/image/{sample_idx}")
476
+ async def image(sample_idx: int):
477
+ if img_loader is None:
478
+ return HTMLResponse("No images available", status_code=404)
479
+ img = img_loader.get(sample_idx)
480
+ if img is None:
481
+ return HTMLResponse("Image not found", status_code=404)
482
+ buf = io.BytesIO()
483
+ img.save(buf, format="PNG")
484
+ buf.seek(0)
485
+ return StreamingResponse(buf, media_type="image/png")
486
+
487
+ return app