| | """Pairwise VLM judge — prompt templates, structured output schema, comparison building.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import base64 |
| | import io |
| | import json |
| | import logging |
| | import random |
| | from dataclasses import dataclass |
| | from itertools import combinations |
| | from typing import Any |
| |
|
| | from PIL import Image |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| |
|
# Prompt template for a single A/B comparison. The {ocr_text_a} and
# {ocr_text_b} placeholders are filled by build_prompt(); the literal braces
# in the JSON example at the bottom are escaped as {{ }} so str.format()
# leaves them intact.
PAIRWISE_PROMPT = """\
You are an expert OCR quality evaluator. You are given a document image and \
TWO OCR outputs (A and B) extracted from that same image.

Compare them and decide which extraction is better overall.

Evaluation criteria (in priority order):

1. Faithfulness: The output must ONLY contain text actually visible in the document. \
Hallucinating text that is not in the image (garbled strings, repeated tokens, \
nonsensical output) is the most serious error. Added commentary or notes \
(e.g. "it appears the text says...") is also an error, but less severe than \
hallucination. If a page is blank or has minimal text, saying so is acceptable — \
fabricating content is always worse.

2. Completeness: ALL visible text must be captured — headers, footers, marginalia, \
stamps, handwritten notes. Missing any section of text is a significant penalty.

3. Accuracy: Correct characters, no garbled or fabricated words.

4. Reading order: Text flows naturally as a human would read the document.

5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \
if present. Do NOT prefer fancier markdown formatting — plain accurate text is \
better than nicely formatted but incomplete text.

If both outputs capture the same text with similar accuracy, respond with "tie". \
Only pick a winner when there is a clear quality difference.

Output A:
---
{ocr_text_a}
---

Output B:
---
{ocr_text_b}
---

Respond with JSON only (no markdown fences, no extra text):
{{"winner": "A", "reason": "brief explanation"}}
Use "A", "B", or "tie" for the winner field."""
| |
|
# JSON Schema for structured judge output. Mirrors the JSON shape requested
# at the end of PAIRWISE_PROMPT: winner is constrained to "A"/"B"/"tie",
# reason is free text; both keys are required.
JUDGE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "winner": {"type": "string", "enum": ["A", "B", "tie"]},
        "reason": {"type": "string"},
    },
    "required": ["winner", "reason"],
}
| |
|
| | |
# Each OCR output is truncated to this many characters before being inserted
# into the judge prompt (see build_prompt).
MAX_OCR_TEXT_LENGTH = 2500

# Longest allowed image side in pixels; larger images are downscaled before
# being base64-encoded for the judge (see image_to_base64).
MAX_IMAGE_DIM = 1024
| |
|
| |
|
| | |
| |
|
| |
|
def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str:
    """Convert a PIL image to a base64-encoded JPEG string, resizing if needed.

    Args:
        image: Source image. Converted to RGB first if in any other mode,
            since JPEG cannot encode e.g. RGBA or palette images.
        max_dim: Maximum allowed length of the longer side, in pixels.
            Larger images are downscaled proportionally (LANCZOS).

    Returns:
        Base64-encoded JPEG bytes (quality 85) as an ASCII string.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    longest = max(image.size)
    if longest > max_dim:
        ratio = max_dim / longest
        # Guard against degenerate zero-pixel dimensions for extreme aspect
        # ratios (e.g. a 4000x1 strip would otherwise truncate to height 0
        # and make resize() fail).
        new_size = (
            max(1, int(image.width * ratio)),
            max(1, int(image.height * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    buf = io.BytesIO()
    image.save(buf, format="JPEG", quality=85)
    return base64.b64encode(buf.getvalue()).decode()
| |
|
| |
|
| | |
| |
|
| |
|
@dataclass
class Comparison:
    """A single pairwise comparison to evaluate."""

    # Row index into the source dataset this comparison was built from.
    sample_idx: int
    # Model names for the two sides of the pair, in original (unswapped)
    # order — the swap only affects which prompt slot each text occupies.
    model_a: str
    model_b: str
    # Dataset column names holding each model's OCR output.
    col_a: str
    col_b: str
    # True when the texts were presented to the judge in B-first order
    # (position-bias mitigation); text_a/text_b keep their original slots.
    swapped: bool
    # Pre-built chat messages (image + prompt) ready to send to the judge.
    messages: list[dict[str, Any]]
    # Raw (untruncated) OCR texts for model_a and model_b respectively.
    text_a: str = ""
    text_b: str = ""
| |
|
| |
|
def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]:
    """Render the pairwise judge prompt for two OCR outputs.

    Each text is truncated to MAX_OCR_TEXT_LENGTH characters. When
    ``swapped`` is True the two outputs trade the A/B slots so position
    bias can be measured by the caller.

    Returns (prompt_text, swapped).
    """
    first = text_a[:MAX_OCR_TEXT_LENGTH]
    second = text_b[:MAX_OCR_TEXT_LENGTH]
    if swapped:
        first, second = second, first
    prompt = PAIRWISE_PROMPT.format(ocr_text_a=first, ocr_text_b=second)
    return prompt, swapped
| |
|
| |
|
def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]:
    """Assemble the single-turn chat payload for the judge.

    The lone user message carries the page image (as a base64 JPEG data
    URI) followed by the comparison prompt text.
    """
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    return [{"role": "user", "content": [image_part, text_part]}]
| |
|
| |
|
| | def _normalize_pair(a: str, b: str) -> tuple[str, str]: |
| | """Return a canonical (sorted) pair for symmetric lookup.""" |
| | return (a, b) if a <= b else (b, a) |
| |
|
| |
|
| | def sample_indices( |
| | dataset_len: int, max_samples: int | None = None, seed: int = 42 |
| | ) -> list[int]: |
| | """Compute shuffled sample indices (cheap — no image loading). |
| | |
| | Args: |
| | dataset_len: Total number of rows in the dataset. |
| | max_samples: If set, randomly sample this many indices. |
| | seed: Random seed for reproducible sampling. |
| | |
| | Returns: |
| | List of integer indices into the dataset. |
| | """ |
| | indices = list(range(dataset_len)) |
| | if max_samples and max_samples < len(indices): |
| | random.seed(seed) |
| | indices = random.sample(indices, max_samples) |
| | return indices |
| |
|
| |
|
def build_comparisons(
    dataset: Any,
    ocr_columns: dict[str, str],
    max_samples: int | None = None,
    seed: int = 42,
    skip_pairs: set[tuple[str, str]] | None = None,
    indices: list[int] | None = None,
) -> list[Comparison]:
    """Build pairwise comparison prompts from a dataset.

    Args:
        dataset: HF dataset with an "image" column and OCR output columns.
        ocr_columns: Mapping of column_name -> model_name.
        max_samples: If set, randomly sample this many rows. Ignored when
            ``indices`` is provided.
        seed: Random seed for sampling and position-bias randomization.
        skip_pairs: Set of (model_a, model_b) pairs to exclude. Orientation
            does not matter: (a, b) and (b, a) are treated identically.
            If None, all pairs are included.
        indices: Explicit row indices to use. When provided, ``max_samples``
            and ``seed`` play no part in index selection (``seed`` still
            drives position-bias randomization).

    Returns:
        List of Comparison objects with pre-built chat messages.
    """
    col_names = list(ocr_columns)
    model_names = list(ocr_columns.values())

    skip: set[tuple[str, str]] = set()
    if skip_pairs:
        skip = {_normalize_pair(a, b) for a, b in skip_pairs}

    # The skip filter does not depend on the row, so apply it once up front.
    active_pairs = [
        (i, j)
        for i, j in combinations(range(len(col_names)), 2)
        if _normalize_pair(model_names[i], model_names[j]) not in skip
    ]
    if not active_pairs:
        return []

    if indices is None:
        indices = sample_indices(len(dataset), max_samples, seed)

    # HF datasets expose ``column_names``; pre-fetching whole text columns
    # lets us read text cells without touching (and decoding) the image cell.
    column_cache: dict[str, list] | None = None
    if hasattr(dataset, "column_names"):
        column_cache = {name: dataset[name] for name in col_names}

    rng = random.Random(seed)
    comparisons: list[Comparison] = []

    for row_idx in indices:
        # Pick a text accessor for this row: cached columns when available,
        # otherwise a single row fetch from the dataset.
        if column_cache is not None:

            def read_cell(name: str) -> str:
                return column_cache[name][row_idx] or ""

        else:
            row = dataset[row_idx]

            def read_cell(name: str) -> str:
                return row[name] or ""

        # Keep only the pairs where both sides produced non-blank text.
        candidates = []
        for i, j in active_pairs:
            left = read_cell(col_names[i])
            right = read_cell(col_names[j])
            if left.strip() and right.strip():
                candidates.append((i, j, left, right))
        if not candidates:
            continue

        # Encode the page image once per row; shared by all of its pairs.
        image_b64 = image_to_base64(dataset[row_idx]["image"])

        for i, j, left, right in candidates:
            flip = rng.random() < 0.5
            prompt, flip = build_prompt(left, right, flip)
            comparisons.append(
                Comparison(
                    sample_idx=row_idx,
                    model_a=model_names[i],
                    model_b=model_names[j],
                    col_a=col_names[i],
                    col_b=col_names[j],
                    swapped=flip,
                    messages=build_messages(image_b64, prompt),
                    text_a=left,
                    text_b=right,
                )
            )

    return comparisons
| |
|
| |
|
| | |
| |
|
| |
|
def parse_judge_output(text: str) -> dict[str, str]:
    """Parse judge JSON output, handling markdown fences and invalid values.

    Args:
        text: Raw completion text from the judge. May be wrapped in markdown
            code fences and may contain malformed or off-schema JSON.

    Returns:
        Dict with "winner" ("A", "B", or "tie") and "reason" keys, or an
        empty dict when the output cannot be parsed as a JSON object.
    """
    text = text.strip()
    if text.startswith("```"):
        # Strip a markdown fence (``` or ```json). A fence with no newline
        # after it has no body; leave an empty string so JSON parsing below
        # fails cleanly instead of raising IndexError on [1].
        parts = text.split("\n", 1)
        text = parts[1].rsplit("```", 1)[0].strip() if len(parts) > 1 else ""
    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}
    if not isinstance(result, dict):
        # Valid JSON but not an object (e.g. a bare list, number, or string)
        # — previously this crashed on .get() with AttributeError.
        logger.warning("Judge output is not a JSON object: %s", text[:200])
        return {}
    # Coerce to str first so non-string values (null, numbers) cannot crash
    # on .upper(); anything other than a clear A/B verdict counts as a tie.
    winner = str(result.get("winner", "tie")).upper().strip()
    if winner not in ("A", "B"):
        winner = "tie"
    return {"winner": winner, "reason": str(result.get("reason", ""))}
| |
|