"""Pairwise VLM judge — prompt templates, structured output schema, comparison building.""" from __future__ import annotations import base64 import io import json import logging import random from dataclasses import dataclass from itertools import combinations from typing import Any from PIL import Image logger = logging.getLogger(__name__) # --- Judge prompt --- PAIRWISE_PROMPT = """\ You are an expert OCR quality evaluator. You are given a document image and \ TWO OCR outputs (A and B) extracted from that same image. Compare them and decide which extraction is better overall. Evaluation criteria (in priority order): 1. Faithfulness: The output must ONLY contain text actually visible in the document. \ Hallucinating text that is not in the image (garbled strings, repeated tokens, \ nonsensical output) is the most serious error. Added commentary or notes \ (e.g. "it appears the text says...") is also an error, but less severe than \ hallucination. If a page is blank or has minimal text, saying so is acceptable — \ fabricating content is always worse. 2. Completeness: ALL visible text must be captured — headers, footers, marginalia, \ stamps, handwritten notes. Missing any section of text is a significant penalty. 3. Accuracy: Correct characters, no garbled or fabricated words. 4. Reading order: Text flows naturally as a human would read the document. 5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \ if present. Do NOT prefer fancier markdown formatting — plain accurate text is \ better than nicely formatted but incomplete text. If both outputs capture the same text with similar accuracy, respond with "tie". \ Only pick a winner when there is a clear quality difference. Output A: --- {ocr_text_a} --- Output B: --- {ocr_text_b} --- Respond with JSON only (no markdown fences, no extra text): {{"winner": "A", "reason": "brief explanation"}} Use "A", "B", or "tie" for the winner field.""" JUDGE_SCHEMA: dict[str, Any] = { "type": "object", "properties": { "winner": {"type": "string", "enum": ["A", "B", "tie"]}, "reason": {"type": "string"}, }, "required": ["winner", "reason"], } # Max characters of OCR text to include per output in the prompt. MAX_OCR_TEXT_LENGTH = 2500 # Max image dimension (longer side) before resizing. MAX_IMAGE_DIM = 1024 # --- Image helpers --- def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str: """Convert a PIL image to a base64-encoded JPEG string, resizing if needed.""" if image.mode != "RGB": image = image.convert("RGB") if max(image.size) > max_dim: ratio = max_dim / max(image.size) new_size = (int(image.width * ratio), int(image.height * ratio)) image = image.resize(new_size, Image.Resampling.LANCZOS) buf = io.BytesIO() image.save(buf, format="JPEG", quality=85) return base64.b64encode(buf.getvalue()).decode() # --- Comparison --- @dataclass class Comparison: """A single pairwise comparison to evaluate.""" sample_idx: int model_a: str model_b: str col_a: str col_b: str swapped: bool messages: list[dict[str, Any]] text_a: str = "" text_b: str = "" def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]: """Build the pairwise comparison prompt, applying position-bias swap. Returns (prompt_text, swapped). """ a = text_a[:MAX_OCR_TEXT_LENGTH] b = text_b[:MAX_OCR_TEXT_LENGTH] if swapped: a, b = b, a return PAIRWISE_PROMPT.format(ocr_text_a=a, ocr_text_b=b), swapped def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]: """Build chat messages for the judge (image + prompt).""" return [ { "role": "user", "content": [ { "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}, }, {"type": "text", "text": prompt}, ], } ] def _normalize_pair(a: str, b: str) -> tuple[str, str]: """Return a canonical (sorted) pair for symmetric lookup.""" return (a, b) if a <= b else (b, a) def sample_indices( dataset_len: int, max_samples: int | None = None, seed: int = 42 ) -> list[int]: """Compute shuffled sample indices (cheap — no image loading). Args: dataset_len: Total number of rows in the dataset. max_samples: If set, randomly sample this many indices. seed: Random seed for reproducible sampling. Returns: List of integer indices into the dataset. """ indices = list(range(dataset_len)) if max_samples and max_samples < len(indices): random.seed(seed) indices = random.sample(indices, max_samples) return indices def build_comparisons( dataset: Any, ocr_columns: dict[str, str], max_samples: int | None = None, seed: int = 42, skip_pairs: set[tuple[str, str]] | None = None, indices: list[int] | None = None, ) -> list[Comparison]: """Build pairwise comparison prompts from a dataset. Args: dataset: HF dataset with an "image" column and OCR output columns. ocr_columns: Mapping of column_name -> model_name. max_samples: If set, randomly sample this many rows. Ignored when ``indices`` is provided. seed: Random seed for sampling and position-bias randomization. skip_pairs: Set of (model_a, model_b) pairs to exclude. Pairs are normalized so (a, b) and (b, a) are treated identically. If None, all pairs are included. indices: Explicit row indices to use. When provided, ``max_samples`` and ``seed`` are not used for index selection (seed is still used for position-bias randomization). Returns: List of Comparison objects with pre-built chat messages. """ col_names = list(ocr_columns.keys()) model_names = list(ocr_columns.values()) pairs = list(combinations(range(len(col_names)), 2)) # Normalize skip set for symmetric lookup normalized_skip: set[tuple[str, str]] = set() if skip_pairs: normalized_skip = {_normalize_pair(a, b) for a, b in skip_pairs} if indices is None: indices = sample_indices(len(dataset), max_samples, seed) rng = random.Random(seed) comparisons: list[Comparison] = [] # Pre-fetch text columns to avoid triggering image decode per row. # HF Dataset supports column access (dataset["col"]), plain lists don't. text_cols_data: dict[str, list] | None = None if hasattr(dataset, "column_names"): text_cols_data = {col: dataset[col] for col in col_names} for idx in indices: # Determine which pairs need judging for this row needed_pairs = [ (i, j) for i, j in pairs if _normalize_pair(model_names[i], model_names[j]) not in normalized_skip ] if not needed_pairs: continue # Skip image encoding entirely # Check text availability before decoding the image valid_pairs = [] if text_cols_data is not None: for i, j in needed_pairs: text_a = text_cols_data[col_names[i]][idx] or "" text_b = text_cols_data[col_names[j]][idx] or "" if text_a.strip() and text_b.strip(): valid_pairs.append((i, j, text_a, text_b)) else: row = dataset[idx] for i, j in needed_pairs: text_a = row[col_names[i]] or "" text_b = row[col_names[j]] or "" if text_a.strip() and text_b.strip(): valid_pairs.append((i, j, text_a, text_b)) if not valid_pairs: continue image_b64 = image_to_base64(dataset[idx]["image"]) for i, j, text_a, text_b in valid_pairs: swapped = rng.random() < 0.5 prompt, swapped = build_prompt(text_a, text_b, swapped) messages = build_messages(image_b64, prompt) comparisons.append( Comparison( sample_idx=idx, model_a=model_names[i], model_b=model_names[j], col_a=col_names[i], col_b=col_names[j], swapped=swapped, messages=messages, text_a=text_a, text_b=text_b, ) ) return comparisons # --- Output parsing --- def parse_judge_output(text: str) -> dict[str, str]: """Parse judge JSON output, handling markdown fences and invalid values. Returns dict with "winner" and "reason" keys, or empty dict on failure. """ text = text.strip() if text.startswith("```"): text = text.split("\n", 1)[1].rsplit("```", 1)[0].strip() try: result = json.loads(text) winner = result.get("winner", "tie").upper().strip() if winner == "TIE": winner = "tie" if winner not in ("A", "B", "tie"): winner = "tie" return {"winner": winner, "reason": result.get("reason", "")} except json.JSONDecodeError: logger.warning("Failed to parse judge output: %s", text[:200]) return {}