| | """Results viewer — data loading and helpers for OCR bench results.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import TYPE_CHECKING, Any |
| |
|
| | import structlog |
| | from datasets import load_dataset |
| |
|
| | if TYPE_CHECKING: |
| | from PIL import Image |
| |
|
| | logger = structlog.get_logger() |
| |
|
| |
|
def load_results(repo_id: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Load leaderboard and comparisons from a Hub results dataset.

    Tries the default config first (new repos), then falls back to the
    named ``leaderboard`` config (old repos). A ``comparisons`` config that
    cannot be loaded is not fatal: a warning is logged and an empty list is
    returned in its place.

    Args:
        repo_id: Dataset repo id passed straight to ``load_dataset``.

    Returns:
        (leaderboard_rows, comparison_rows)

    Raises:
        Exception: Whatever ``load_dataset`` raises when the ``leaderboard``
            fallback fails as well.
    """
    try:
        leaderboard_rows = [dict(row) for row in load_dataset(repo_id, split="train")]
    except Exception as exc:
        # Record why the default config was rejected so a successful
        # fallback does not hide the underlying error.
        logger.debug("default_config_failed_trying_leaderboard", repo=repo_id, error=str(exc))
        leaderboard_rows = [
            dict(row) for row in load_dataset(repo_id, name="leaderboard", split="train")
        ]

    try:
        comparisons_ds = load_dataset(repo_id, name="comparisons", split="train")
    except Exception as exc:
        logger.warning("no_comparisons_config", repo=repo_id, error=str(exc))
        return leaderboard_rows, []

    return leaderboard_rows, [dict(row) for row in comparisons_ds]
| |
|
| |
|
def _load_source_metadata(repo_id: str) -> dict[str, Any]:
    """Return the first row of the results repo's ``metadata`` config.

    The row is what callers use to locate the source dataset. Any failure
    while loading is logged as a warning, and an empty dict is returned when
    the config is missing, unreadable, or has no rows.
    """
    metadata: dict[str, Any] = {}
    try:
        meta_ds = load_dataset(repo_id, name="metadata", split="train")
        if len(meta_ds) > 0:
            metadata = dict(meta_ds[0])
    except Exception as exc:
        logger.warning("could_not_load_metadata", repo=repo_id, error=str(exc))
    return metadata
| |
|
| |
|
class ImageLoader:
    """Lazy image loader — fetches images from source dataset by sample_idx.

    Nothing is loaded at construction time. The first ``get`` call probes the
    source dataset once to find the image column and, when ``from_prs`` is
    set, the PR config/revision to read from. Those are then reused for every
    later fetch, and successfully loaded images are cached per sample index.
    """

    def __init__(self, source_dataset: str, from_prs: bool = False):
        """Store settings only; no dataset access happens here.

        Args:
            source_dataset: Dataset path handed to ``load_dataset``.
            from_prs: When true, read from the first PR config reported by
                ``discover_pr_configs`` instead of the default revision.
        """
        self._source = source_dataset
        self._from_prs = from_prs
        self._cache: dict[int, Any] = {}
        self._image_col: str | None = None
        # Config name and revision are discovered together during lazy init.
        self._pr_config: str | None = None
        self._pr_revision: str | None = None
        self._available = True
        self._init_done = False

    def _load_kwargs(self, split: str) -> dict[str, Any]:
        """Build ``load_dataset`` keyword arguments for the given split slice."""
        kwargs: dict[str, Any] = {"path": self._source, "split": split}
        if self._pr_revision:
            kwargs["name"] = self._pr_config
            kwargs["revision"] = self._pr_revision
        return kwargs

    def _init_source(self) -> None:
        """Lazy init: discover image column and PR revision on first call."""
        if self._init_done:
            return
        # Marked done up front so a failed probe is not retried on every get().
        self._init_done = True

        try:
            if self._from_prs:
                from ocr_bench.dataset import discover_pr_configs

                # NOTE(review): the second item is used as a mapping of
                # config name -> revision; confirm against discover_pr_configs.
                _, revisions = discover_pr_configs(self._source)
                if revisions:
                    first_config = next(iter(revisions))
                    self._pr_config = first_config
                    self._pr_revision = revisions[first_config]

            # A single-row probe is enough to read the column names.
            probe = load_dataset(**self._load_kwargs("train[:1]"))
            columns = list(probe.column_names)
            if "image" in columns:
                # Prefer the exact name so e.g. "image_id" cannot shadow it.
                self._image_col = "image"
            else:
                self._image_col = next(
                    (col for col in columns if "image" in col.lower()), None
                )
            if not self._image_col:
                logger.info("no_image_column_in_source", source=self._source)
                self._available = False
        except Exception as exc:
            logger.warning("image_loader_init_failed", source=self._source, error=str(exc))
            self._available = False

    def get(self, sample_idx: int) -> Image.Image | None:
        """Fetch image for a sample index. Returns None on failure.

        Failures are logged at debug level and are not cached, so a later
        call for the same index tries again.
        """
        self._init_source()
        if not self._available or self._image_col is None:
            return None
        if sample_idx in self._cache:
            return self._cache[sample_idx]
        try:
            # Reuse the config/revision found during init instead of
            # re-running PR discovery for every image.
            row = load_dataset(**self._load_kwargs(f"train[{sample_idx}:{sample_idx + 1}]"))
            img = row[0][self._image_col]
        except Exception as exc:
            logger.debug("image_load_failed", sample_idx=sample_idx, error=str(exc))
            return None
        self._cache[sample_idx] = img
        return img
| |
|
| |
|
| | def _filter_comparisons( |
| | comparisons: list[dict[str, Any]], |
| | winner_filter: str, |
| | model_filter: str, |
| | ) -> list[dict[str, Any]]: |
| | """Filter comparison rows by winner and model.""" |
| | filtered = comparisons |
| | if winner_filter and winner_filter != "All": |
| | filtered = [c for c in filtered if c.get("winner") == winner_filter] |
| | if model_filter and model_filter != "All": |
| | filtered = [ |
| | c |
| | for c in filtered |
| | if c.get("model_a") == model_filter or c.get("model_b") == model_filter |
| | ] |
| | return filtered |
| |
|
| |
|
| | def _winner_badge(winner: str) -> str: |
| | """Return a badge string for the winner.""" |
| | if winner == "A": |
| | return "Winner: A" |
| | elif winner == "B": |
| | return "Winner: B" |
| | else: |
| | return "Tie" |
| |
|
| |
|
| | def _model_label(model: str, col: str) -> str: |
| | """Format model name with optional column name. Avoids empty parens.""" |
| | if col: |
| | return f"{model} ({col})" |
| | return model |
| |
|
| |
|
| | def _build_pair_summary(comparisons: list[dict[str, Any]]) -> str: |
| | """Build a win/loss summary string for each model pair.""" |
| | from collections import Counter |
| |
|
| | pair_counts: dict[tuple[str, str], Counter[str]] = {} |
| | for c in comparisons: |
| | ma = c.get("model_a", "") |
| | mb = c.get("model_b", "") |
| | winner = c.get("winner", "tie") |
| | key = (ma, mb) if ma <= mb else (mb, ma) |
| | if key not in pair_counts: |
| | pair_counts[key] = Counter() |
| | |
| | if winner == "A": |
| | actual_winner = ma |
| | elif winner == "B": |
| | actual_winner = mb |
| | else: |
| | actual_winner = "tie" |
| |
|
| | if actual_winner == key[0]: |
| | pair_counts[key]["W"] += 1 |
| | elif actual_winner == key[1]: |
| | pair_counts[key]["L"] += 1 |
| | else: |
| | pair_counts[key]["T"] += 1 |
| |
|
| | if not pair_counts: |
| | return "" |
| |
|
| | parts = [] |
| | for (ma, mb), counts in sorted(pair_counts.items()): |
| | short_a = ma.split("/")[-1] if "/" in ma else ma |
| | short_b = mb.split("/")[-1] if "/" in mb else mb |
| | wins, losses, ties = counts["W"], counts["L"], counts["T"] |
| | parts.append(f"**{short_a}** vs **{short_b}**: {wins}W {losses}L {ties}T") |
| | return " | ".join(parts) |
| |
|
| |
|
| |
|