davanstrien's picture
davanstrien HF Staff
Upload folder using huggingface_hub
1118181 verified
"""Pairwise VLM judge — prompt templates, structured output schema, comparison building."""
from __future__ import annotations
import base64
import io
import json
import logging
import random
from dataclasses import dataclass
from itertools import combinations
from typing import Any
from PIL import Image
logger = logging.getLogger(__name__)
# --- Judge prompt ---
# Text prompt template sent to the judge alongside the document image.
# ``{ocr_text_a}`` and ``{ocr_text_b}`` are filled in via ``str.format`` (see
# ``build_prompt``); the doubled braces around the JSON example are format
# escapes that render as literal ``{`` / ``}``. A trailing backslash inside
# the literal joins the wrapped source line with the next one, so each
# criterion reaches the judge as a single line.
PAIRWISE_PROMPT = """\
You are an expert OCR quality evaluator. You are given a document image and \
TWO OCR outputs (A and B) extracted from that same image.
Compare them and decide which extraction is better overall.
Evaluation criteria (in priority order):
1. Faithfulness: The output must ONLY contain text actually visible in the document. \
Hallucinating text that is not in the image (garbled strings, repeated tokens, \
nonsensical output) is the most serious error. Added commentary or notes \
(e.g. "it appears the text says...") is also an error, but less severe than \
hallucination. If a page is blank or has minimal text, saying so is acceptable — \
fabricating content is always worse.
2. Completeness: ALL visible text must be captured — headers, footers, marginalia, \
stamps, handwritten notes. Missing any section of text is a significant penalty.
3. Accuracy: Correct characters, no garbled or fabricated words.
4. Reading order: Text flows naturally as a human would read the document.
5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \
if present. Do NOT prefer fancier markdown formatting — plain accurate text is \
better than nicely formatted but incomplete text.
If both outputs capture the same text with similar accuracy, respond with "tie". \
Only pick a winner when there is a clear quality difference.
Output A:
---
{ocr_text_a}
---
Output B:
---
{ocr_text_b}
---
Respond with JSON only (no markdown fences, no extra text):
{{"winner": "A", "reason": "brief explanation"}}
Use "A", "B", or "tie" for the winner field."""
# JSON-Schema-style description of the judge's reply: an object whose
# ``winner`` is one of "A", "B" or "tie" and whose ``reason`` is free text,
# both required. It mirrors the response format requested at the end of
# PAIRWISE_PROMPT.
# NOTE(review): presumably handed to the inference backend for structured
# output; it is not referenced elsewhere in this section — confirm at the
# call site.
JUDGE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "winner": {"type": "string", "enum": ["A", "B", "tie"]},
        "reason": {"type": "string"},
    },
    "required": ["winner", "reason"],
}
# Max characters of OCR text to include per output in the prompt.
# ``build_prompt`` slices each text to this length before formatting.
MAX_OCR_TEXT_LENGTH = 2500
# Max image dimension (longer side) before resizing.
# Used as the default ``max_dim`` of ``image_to_base64``.
MAX_IMAGE_DIM = 1024
# --- Image helpers ---
def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str:
    """Convert a PIL image to a base64-encoded JPEG string, resizing if needed.

    Args:
        image: Source image. Converted to RGB first when its mode is anything
            else.
        max_dim: Upper bound for the longer side. Larger images are scaled
            down proportionally with LANCZOS resampling; smaller images are
            left at their original size.

    Returns:
        The JPEG bytes (quality 85) encoded as a base64 ASCII string.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    longest_side = max(image.size)
    if longest_side > max_dim:
        ratio = max_dim / longest_side
        # int() truncates, so the short side of a very elongated image
        # (e.g. 5000x2 scaled to 1024) would become 0, which resize()
        # rejects. Clamp each side to at least one pixel.
        new_size = (
            max(1, int(image.width * ratio)),
            max(1, int(image.height * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    buf = io.BytesIO()
    image.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()
# --- Comparison ---
@dataclass
class Comparison:
    """One pairwise judging task: two OCR outputs for the same dataset row.

    Attributes:
        sample_idx: Index of the dataset row the outputs belong to.
        model_a: Name of the first model of the pair.
        model_b: Name of the second model of the pair.
        col_a: Dataset column holding ``model_a``'s OCR text.
        col_b: Dataset column holding ``model_b``'s OCR text.
        swapped: Whether the two texts were presented in reversed order in
            the prompt (position-bias control).
        messages: Chat messages (image + prompt) to send to the judge.
        text_a: OCR text of ``model_a``; defaults to an empty string.
        text_b: OCR text of ``model_b``; defaults to an empty string.
    """

    # Required fields — positional order is part of the constructor interface.
    sample_idx: int
    model_a: str
    model_b: str
    col_a: str
    col_b: str
    swapped: bool
    messages: list[dict[str, Any]]

    # Optional fields.
    text_a: str = ""
    text_b: str = ""
def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]:
    """Render the judge prompt for two OCR outputs.

    Each text is first cut to ``MAX_OCR_TEXT_LENGTH`` characters. When
    ``swapped`` is true the texts trade places, so ``text_b`` is shown as
    "Output A" and ``text_a`` as "Output B" (position-bias control).

    Returns:
        ``(prompt_text, swapped)`` — the flag is passed back unchanged.
    """
    truncated_a = text_a[:MAX_OCR_TEXT_LENGTH]
    truncated_b = text_b[:MAX_OCR_TEXT_LENGTH]
    if swapped:
        shown_first, shown_second = truncated_b, truncated_a
    else:
        shown_first, shown_second = truncated_a, truncated_b
    prompt = PAIRWISE_PROMPT.format(ocr_text_a=shown_first, ocr_text_b=shown_second)
    return prompt, swapped
def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]:
    """Assemble the single-turn chat payload for the judge.

    The result is one ``user`` message whose content holds the image, as a
    base64 JPEG data URL, followed by the text prompt.
    """
    image_part: dict[str, Any] = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    text_part: dict[str, Any] = {"type": "text", "text": prompt}
    user_message: dict[str, Any] = {
        "role": "user",
        "content": [image_part, text_part],
    }
    return [user_message]
def _normalize_pair(a: str, b: str) -> tuple[str, str]:
"""Return a canonical (sorted) pair for symmetric lookup."""
return (a, b) if a <= b else (b, a)
def sample_indices(
    dataset_len: int, max_samples: int | None = None, seed: int = 42
) -> list[int]:
    """Compute the row indices to evaluate (cheap — no image loading).

    Args:
        dataset_len: Total number of rows in the dataset.
        max_samples: If set (and non-zero) and smaller than ``dataset_len``,
            randomly sample this many indices without replacement. Otherwise
            every index is returned in ascending order.
        seed: Random seed for reproducible sampling.

    Returns:
        List of integer indices into the dataset.
    """
    indices = list(range(dataset_len))
    if max_samples and max_samples < len(indices):
        # A private generator draws the same sample as seeding the module-level
        # one, but does not reset the process-wide random state for callers.
        indices = random.Random(seed).sample(indices, max_samples)
    return indices
def build_comparisons(
    dataset: Any,
    ocr_columns: dict[str, str],
    max_samples: int | None = None,
    seed: int = 42,
    skip_pairs: set[tuple[str, str]] | None = None,
    indices: list[int] | None = None,
) -> list[Comparison]:
    """Build pairwise comparison prompts from a dataset.

    For every selected row, each pair of OCR columns that is not skipped and
    has non-blank text on both sides yields one ``Comparison``. The row image
    is encoded once per row and shared by all of that row's comparisons; rows
    with no usable pair never have their image encoded.

    Args:
        dataset: HF dataset with an "image" column and OCR output columns.
            Any indexable sequence of row mappings also works; column-wise
            access is used only when the object has ``column_names``.
        ocr_columns: Mapping of column_name -> model_name.
        max_samples: If set, randomly sample this many rows. Ignored when
            ``indices`` is provided.
        seed: Random seed for sampling and position-bias randomization.
        skip_pairs: Set of (model_a, model_b) pairs to exclude. Pairs are
            normalized so (a, b) and (b, a) are treated identically.
            If None, all pairs are included.
        indices: Explicit row indices to use. When provided, ``max_samples``
            and ``seed`` are not used for index selection (seed is still used
            for position-bias randomization).

    Returns:
        List of Comparison objects with pre-built chat messages. ``text_a``
        and ``text_b`` hold the full texts in column order; ``swapped``
        records whether the prompt presented them in reverse.
    """
    col_names = list(ocr_columns.keys())
    model_names = list(ocr_columns.values())

    # Normalize skip set for symmetric lookup.
    normalized_skip: set[tuple[str, str]] = set()
    if skip_pairs:
        normalized_skip = {_normalize_pair(a, b) for a, b in skip_pairs}

    # Which pairs need judging does not depend on the row, so filter once
    # instead of once per row.
    needed_pairs = [
        (i, j)
        for i, j in combinations(range(len(col_names)), 2)
        if _normalize_pair(model_names[i], model_names[j]) not in normalized_skip
    ]

    if indices is None:
        indices = sample_indices(len(dataset), max_samples, seed)

    comparisons: list[Comparison] = []
    if not needed_pairs:
        return comparisons  # Nothing to judge: skip all text and image access.

    rng = random.Random(seed)

    # Pre-fetch text columns to avoid triggering image decode per row.
    # HF Dataset supports column access (dataset["col"]), plain lists don't.
    text_cols_data: dict[str, list] | None = None
    if hasattr(dataset, "column_names"):
        text_cols_data = {col: dataset[col] for col in col_names}

    for idx in indices:
        # Without column access the whole row has to be fetched to read the
        # texts; keep it so the image lookup below does not fetch it again.
        row = dataset[idx] if text_cols_data is None else None

        # Check text availability before decoding the image.
        valid_pairs = []
        for i, j in needed_pairs:
            if text_cols_data is not None:
                text_a = text_cols_data[col_names[i]][idx] or ""
                text_b = text_cols_data[col_names[j]][idx] or ""
            else:
                text_a = row[col_names[i]] or ""
                text_b = row[col_names[j]] or ""
            if text_a.strip() and text_b.strip():
                valid_pairs.append((i, j, text_a, text_b))
        if not valid_pairs:
            continue

        if row is None:
            row = dataset[idx]
        image_b64 = image_to_base64(row["image"])

        for i, j, text_a, text_b in valid_pairs:
            # One RNG draw per comparison, in pair order, so the swap pattern
            # is reproducible for a given seed.
            swapped = rng.random() < 0.5
            prompt, swapped = build_prompt(text_a, text_b, swapped)
            messages = build_messages(image_b64, prompt)
            comparisons.append(
                Comparison(
                    sample_idx=idx,
                    model_a=model_names[i],
                    model_b=model_names[j],
                    col_a=col_names[i],
                    col_b=col_names[j],
                    swapped=swapped,
                    messages=messages,
                    text_a=text_a,
                    text_b=text_b,
                )
            )
    return comparisons
# --- Output parsing ---
def parse_judge_output(text: str) -> dict[str, str]:
    """Parse judge JSON output, handling markdown fences and invalid values.

    The reply is stripped, an optional surrounding markdown code fence is
    removed, and the remainder is decoded as JSON. The ``winner`` value is
    matched case-insensitively; anything other than "A", "B" or "tie"
    (including a missing or non-string value) is normalized to "tie".

    Args:
        text: Raw text returned by the judge model.

    Returns:
        Dict with "winner" ("A", "B" or "tie") and "reason" keys, or an empty
        dict when the reply is not valid JSON or is not a JSON object.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (e.g. "```json"). If fence and payload
        # share a single line there is no newline to split on, so remove only
        # the leading backticks.
        _, newline, remainder = text.partition("\n")
        body = remainder if newline else text[3:]
        text = body.rsplit("```", 1)[0].strip()

    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}

    # Valid JSON that is not an object (list, string, number, ...) carries no
    # verdict; report it as a parse failure rather than raising.
    if not isinstance(result, dict):
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}

    winner = result.get("winner", "tie")
    winner = winner.upper().strip() if isinstance(winner, str) else "tie"
    if winner not in ("A", "B"):
        winner = "tie"

    # Keep the documented str -> str mapping even if the judge emits a
    # non-string reason (null, number, ...).
    reason = result.get("reason", "")
    if reason is None:
        reason = ""
    elif not isinstance(reason, str):
        reason = str(reason)

    return {"winner": winner, "reason": reason}