| """Pairwise VLM judge — prompt templates, structured output schema, comparison building.""" | |
| from __future__ import annotations | |
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import random | |
| from dataclasses import dataclass | |
| from itertools import combinations | |
| from typing import Any | |
| from PIL import Image | |
| logger = logging.getLogger(__name__) | |
| # --- Judge prompt --- | |
# Judge prompt template. {ocr_text_a}/{ocr_text_b} are filled in by
# build_prompt(); the doubled braces {{ }} escape literal JSON braces so
# str.format() leaves the example output untouched. Trailing backslashes
# join long sentences into single logical lines inside the string.
PAIRWISE_PROMPT = """\
You are an expert OCR quality evaluator. You are given a document image and \
TWO OCR outputs (A and B) extracted from that same image.
Compare them and decide which extraction is better overall.
Evaluation criteria (in priority order):
1. Faithfulness: The output must ONLY contain text actually visible in the document. \
Hallucinating text that is not in the image (garbled strings, repeated tokens, \
nonsensical output) is the most serious error. Added commentary or notes \
(e.g. "it appears the text says...") is also an error, but less severe than \
hallucination. If a page is blank or has minimal text, saying so is acceptable — \
fabricating content is always worse.
2. Completeness: ALL visible text must be captured — headers, footers, marginalia, \
stamps, handwritten notes. Missing any section of text is a significant penalty.
3. Accuracy: Correct characters, no garbled or fabricated words.
4. Reading order: Text flows naturally as a human would read the document.
5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \
if present. Do NOT prefer fancier markdown formatting — plain accurate text is \
better than nicely formatted but incomplete text.
If both outputs capture the same text with similar accuracy, respond with "tie". \
Only pick a winner when there is a clear quality difference.
Output A:
---
{ocr_text_a}
---
Output B:
---
{ocr_text_b}
---
Respond with JSON only (no markdown fences, no extra text):
{{"winner": "A", "reason": "brief explanation"}}
Use "A", "B", or "tie" for the winner field."""
# JSON Schema describing the judge's expected reply — mirrors the example in
# PAIRWISE_PROMPT. NOTE(review): presumably passed to the API's structured
# output / response_format option by the caller — confirm at call site.
JUDGE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "winner": {"type": "string", "enum": ["A", "B", "tie"]},
        "reason": {"type": "string"},
    },
    "required": ["winner", "reason"],
}
# Max characters of OCR text to include per output in the prompt.
# (Applied per side by build_prompt via simple slicing.)
MAX_OCR_TEXT_LENGTH = 2500
# Max image dimension (longer side) before resizing in image_to_base64.
MAX_IMAGE_DIM = 1024
| # --- Image helpers --- | |
def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str:
    """Encode a PIL image as a base64 JPEG string.

    The image is converted to RGB (JPEG cannot store other modes) and, when
    its longer side exceeds ``max_dim``, downscaled proportionally with
    Lanczos resampling before encoding at quality 85.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    longest_side = max(rgb.size)
    if longest_side > max_dim:
        scale = max_dim / longest_side
        target = (int(rgb.width * scale), int(rgb.height * scale))
        rgb = rgb.resize(target, Image.Resampling.LANCZOS)
    buffer = io.BytesIO()
    rgb.save(buffer, format="JPEG", quality=85)
    return base64.b64encode(buffer.getvalue()).decode()
| # --- Comparison --- | |
@dataclass
class Comparison:
    """A single pairwise comparison to evaluate.

    Fix: the ``@dataclass`` decorator was missing. Instances are constructed
    with keyword arguments in ``build_comparisons``; bare class-level
    annotations generate no ``__init__``, so instantiation would raise
    ``TypeError`` without the decorator.
    """

    sample_idx: int  # row index into the dataset
    model_a: str  # model name for side "A" (before any prompt-level swap)
    model_b: str  # model name for side "B"
    col_a: str  # dataset column holding model_a's OCR text
    col_b: str  # dataset column holding model_b's OCR text
    swapped: bool  # True if A/B were swapped in the prompt (position-bias control)
    messages: list[dict[str, Any]]  # pre-built chat messages (image + prompt)
    text_a: str = ""  # full (untruncated) OCR text for model_a
    text_b: str = ""  # full (untruncated) OCR text for model_b
def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]:
    """Render the pairwise judge prompt, applying the position-bias swap.

    Each side is truncated to MAX_OCR_TEXT_LENGTH characters before being
    substituted into PAIRWISE_PROMPT. When ``swapped`` is True the two texts
    trade places (A shows text_b, B shows text_a).

    Returns (prompt_text, swapped).
    """
    side_a, side_b = text_a[:MAX_OCR_TEXT_LENGTH], text_b[:MAX_OCR_TEXT_LENGTH]
    if swapped:
        side_a, side_b = side_b, side_a
    rendered = PAIRWISE_PROMPT.format(ocr_text_a=side_a, ocr_text_b=side_b)
    return rendered, swapped
def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]:
    """Assemble the single-turn chat payload for the judge: image, then prompt."""
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    return [{"role": "user", "content": [image_part, text_part]}]
def _normalize_pair(a: str, b: str) -> tuple[str, str]:
    """Canonicalize an unordered pair so (a, b) and (b, a) compare equal."""
    return min(a, b), max(a, b)
| def sample_indices( | |
| dataset_len: int, max_samples: int | None = None, seed: int = 42 | |
| ) -> list[int]: | |
| """Compute shuffled sample indices (cheap — no image loading). | |
| Args: | |
| dataset_len: Total number of rows in the dataset. | |
| max_samples: If set, randomly sample this many indices. | |
| seed: Random seed for reproducible sampling. | |
| Returns: | |
| List of integer indices into the dataset. | |
| """ | |
| indices = list(range(dataset_len)) | |
| if max_samples and max_samples < len(indices): | |
| random.seed(seed) | |
| indices = random.sample(indices, max_samples) | |
| return indices | |
def build_comparisons(
    dataset: Any,
    ocr_columns: dict[str, str],
    max_samples: int | None = None,
    seed: int = 42,
    skip_pairs: set[tuple[str, str]] | None = None,
    indices: list[int] | None = None,
) -> list[Comparison]:
    """Build pairwise comparison prompts from a dataset.

    Every unordered pair of OCR columns is compared on every selected row,
    except pairs listed in ``skip_pairs`` and rows where either side's text
    is empty/whitespace. Image encoding is deferred until at least one valid
    pair exists for the row, since decoding is the expensive step.

    Args:
        dataset: HF dataset with an "image" column and OCR output columns.
        ocr_columns: Mapping of column_name -> model_name.
        max_samples: If set, randomly sample this many rows. Ignored when
            ``indices`` is provided.
        seed: Random seed for sampling and position-bias randomization.
        skip_pairs: Set of (model_a, model_b) pairs to exclude. Pairs are
            normalized so (a, b) and (b, a) are treated identically.
            If None, all pairs are included.
        indices: Explicit row indices to use. When provided, ``max_samples``
            and ``seed`` are not used for index selection (seed is still used
            for position-bias randomization).

    Returns:
        List of Comparison objects with pre-built chat messages.
    """
    col_names = list(ocr_columns.keys())
    model_names = list(ocr_columns.values())
    # All unordered column-index pairs, e.g. 3 columns -> (0,1),(0,2),(1,2).
    pairs = list(combinations(range(len(col_names)), 2))
    # Normalize skip set for symmetric lookup
    normalized_skip: set[tuple[str, str]] = set()
    if skip_pairs:
        normalized_skip = {_normalize_pair(a, b) for a, b in skip_pairs}
    if indices is None:
        indices = sample_indices(len(dataset), max_samples, seed)
    # Dedicated RNG for the A/B position coin flips (independent of sampling).
    rng = random.Random(seed)
    comparisons: list[Comparison] = []
    # Pre-fetch text columns to avoid triggering image decode per row.
    # HF Dataset supports column access (dataset["col"]), plain lists don't.
    text_cols_data: dict[str, list] | None = None
    if hasattr(dataset, "column_names"):
        text_cols_data = {col: dataset[col] for col in col_names}
    for idx in indices:
        # Determine which pairs need judging for this row
        needed_pairs = [
            (i, j)
            for i, j in pairs
            if _normalize_pair(model_names[i], model_names[j]) not in normalized_skip
        ]
        if not needed_pairs:
            continue  # Skip image encoding entirely
        # Check text availability before decoding the image
        valid_pairs = []
        if text_cols_data is not None:
            # Fast path: read from the pre-fetched columns.
            for i, j in needed_pairs:
                text_a = text_cols_data[col_names[i]][idx] or ""
                text_b = text_cols_data[col_names[j]][idx] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))
        else:
            # Fallback for plain list-of-dict datasets: one row lookup.
            row = dataset[idx]
            for i, j in needed_pairs:
                text_a = row[col_names[i]] or ""
                text_b = row[col_names[j]] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))
        if not valid_pairs:
            continue
        # Encode the row image once and reuse it for every pair on this row.
        image_b64 = image_to_base64(dataset[idx]["image"])
        for i, j, text_a, text_b in valid_pairs:
            # Coin flip decides prompt ordering to counter position bias;
            # model_a/model_b bookkeeping stays in canonical (i, j) order.
            swapped = rng.random() < 0.5
            prompt, swapped = build_prompt(text_a, text_b, swapped)
            messages = build_messages(image_b64, prompt)
            comparisons.append(
                Comparison(
                    sample_idx=idx,
                    model_a=model_names[i],
                    model_b=model_names[j],
                    col_a=col_names[i],
                    col_b=col_names[j],
                    swapped=swapped,
                    messages=messages,
                    text_a=text_a,
                    text_b=text_b,
                )
            )
    return comparisons
| # --- Output parsing --- | |
def parse_judge_output(text: str) -> dict[str, str]:
    """Parse judge JSON output, handling markdown fences and invalid values.

    Fixes over the previous version:
    - JSON that parses but is not an object (e.g. a list or bare string) no
      longer raises ``AttributeError`` on ``.get`` — it is logged and treated
      as a parse failure.
    - A non-string ``winner`` value (e.g. a number) no longer raises on
      ``.upper()`` — it is coerced via ``str()`` and normalized to "tie".
    - A degenerate fence with no newline (input of just "```") no longer
      raises ``IndexError`` during fence stripping.

    Args:
        text: Raw text returned by the judge model.

    Returns:
        Dict with "winner" ("A", "B", or "tie") and "reason" keys, or an
        empty dict when the output cannot be parsed.
    """
    text = text.strip()
    if text.startswith("```"):
        # Strip a markdown code fence (``` or ```json). partition() is safe
        # even when there is no newline after the opening fence.
        _fence, sep, rest = text.partition("\n")
        text = rest.rsplit("```", 1)[0].strip() if sep else ""
    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}
    if not isinstance(result, dict):
        logger.warning("Judge output is not a JSON object: %s", text[:200])
        return {}
    # Normalize winner: any casing/whitespace accepted; anything that is not
    # A or B (including "tie"/"TIE", missing, or non-string) becomes "tie".
    winner = str(result.get("winner", "tie")).upper().strip()
    if winner not in ("A", "B"):
        winner = "tie"
    return {"winner": winner, "reason": result.get("reason", "")}