File size: 9,392 Bytes
1118181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""Pairwise VLM judge β€” prompt templates, structured output schema, comparison building."""

from __future__ import annotations

import base64
import io
import json
import logging
import random
from dataclasses import dataclass
from itertools import combinations
from typing import Any

from PIL import Image

logger = logging.getLogger(__name__)

# --- Judge prompt ---

# Template for a single judge request. The {ocr_text_a}/{ocr_text_b}
# placeholders are filled by build_prompt(); the doubled braces near the end
# survive str.format() as a literal JSON example for the model to imitate.
PAIRWISE_PROMPT = """\
You are an expert OCR quality evaluator. You are given a document image and \
TWO OCR outputs (A and B) extracted from that same image.

Compare them and decide which extraction is better overall.

Evaluation criteria (in priority order):

1. Faithfulness: The output must ONLY contain text actually visible in the document. \
Hallucinating text that is not in the image (garbled strings, repeated tokens, \
nonsensical output) is the most serious error. Added commentary or notes \
(e.g. "it appears the text says...") is also an error, but less severe than \
hallucination. If a page is blank or has minimal text, saying so is acceptable β€” \
fabricating content is always worse.

2. Completeness: ALL visible text must be captured β€” headers, footers, marginalia, \
stamps, handwritten notes. Missing any section of text is a significant penalty.

3. Accuracy: Correct characters, no garbled or fabricated words.

4. Reading order: Text flows naturally as a human would read the document.

5. Formatting: Clean structure. Ignore bounding box tags like <|ref|> <|det|> \
if present. Do NOT prefer fancier markdown formatting β€” plain accurate text is \
better than nicely formatted but incomplete text.

If both outputs capture the same text with similar accuracy, respond with "tie". \
Only pick a winner when there is a clear quality difference.

Output A:
---
{ocr_text_a}
---

Output B:
---
{ocr_text_b}
---

Respond with JSON only (no markdown fences, no extra text):
{{"winner": "A", "reason": "brief explanation"}}
Use "A", "B", or "tie" for the winner field."""

# JSON Schema for structured-output decoding of the judge response; mirrors
# the JSON object requested at the end of PAIRWISE_PROMPT.
JUDGE_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "winner": {"type": "string", "enum": ["A", "B", "tie"]},
        "reason": {"type": "string"},
    },
    "required": ["winner", "reason"],
}

# Max characters of OCR text to include per output in the prompt.
MAX_OCR_TEXT_LENGTH = 2500

# Max image dimension (longer side) before resizing.
MAX_IMAGE_DIM = 1024


# --- Image helpers ---


def image_to_base64(image: Image.Image, max_dim: int = MAX_IMAGE_DIM) -> str:
    """Encode *image* as a base64 JPEG string, downscaling oversized pages.

    Non-RGB images are converted first so JPEG encoding succeeds; images
    whose longer side exceeds ``max_dim`` are shrunk proportionally.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    longest = max(rgb.size)
    if longest > max_dim:
        scale = max_dim / longest
        new_size = (int(rgb.width * scale), int(rgb.height * scale))
        rgb = rgb.resize(new_size, Image.Resampling.LANCZOS)
    out = io.BytesIO()
    rgb.save(out, format="JPEG", quality=85)
    return base64.b64encode(out.getvalue()).decode()


# --- Comparison ---


@dataclass
class Comparison:
    """A single pairwise comparison to evaluate."""

    sample_idx: int  # row index into the source dataset
    model_a: str  # model name in canonical order (NOT affected by `swapped`)
    model_b: str  # model name in canonical order (NOT affected by `swapped`)
    col_a: str  # dataset column holding model_a's OCR output
    col_b: str  # dataset column holding model_b's OCR output
    swapped: bool  # True if A/B positions were exchanged inside the prompt
    messages: list[dict[str, Any]]  # pre-built chat payload (image + prompt)
    text_a: str = ""  # untruncated OCR text for model_a, canonical order
    text_b: str = ""  # untruncated OCR text for model_b, canonical order


def build_prompt(text_a: str, text_b: str, swapped: bool) -> tuple[str, bool]:
    """Build the pairwise comparison prompt, applying the position-bias swap.

    Both texts are clipped to MAX_OCR_TEXT_LENGTH before being inserted;
    when ``swapped`` is True the two outputs trade prompt positions.

    Returns (prompt_text, swapped).
    """
    clipped = [text_a[:MAX_OCR_TEXT_LENGTH], text_b[:MAX_OCR_TEXT_LENGTH]]
    if swapped:
        clipped.reverse()
    prompt = PAIRWISE_PROMPT.format(ocr_text_a=clipped[0], ocr_text_b=clipped[1])
    return prompt, swapped


def build_messages(image_b64: str, prompt: str) -> list[dict[str, Any]]:
    """Assemble the single-turn chat payload: the page image, then the prompt."""
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
    }
    text_part = {"type": "text", "text": prompt}
    return [{"role": "user", "content": [image_part, text_part]}]


def _normalize_pair(a: str, b: str) -> tuple[str, str]:
    """Return a canonical (sorted) pair for symmetric lookup."""
    return (a, b) if a <= b else (b, a)


def sample_indices(
    dataset_len: int, max_samples: int | None = None, seed: int = 42
) -> list[int]:
    """Compute shuffled sample indices (cheap β€” no image loading).

    Args:
        dataset_len: Total number of rows in the dataset.
        max_samples: If set, randomly sample this many indices.
        seed: Random seed for reproducible sampling.

    Returns:
        List of integer indices into the dataset.
    """
    indices = list(range(dataset_len))
    if max_samples and max_samples < len(indices):
        random.seed(seed)
        indices = random.sample(indices, max_samples)
    return indices


def build_comparisons(
    dataset: Any,
    ocr_columns: dict[str, str],
    max_samples: int | None = None,
    seed: int = 42,
    skip_pairs: set[tuple[str, str]] | None = None,
    indices: list[int] | None = None,
) -> list[Comparison]:
    """Build pairwise comparison prompts from a dataset.

    Args:
        dataset: HF dataset with an "image" column and OCR output columns.
        ocr_columns: Mapping of column_name -> model_name.
        max_samples: If set, randomly sample this many rows. Ignored when
            ``indices`` is provided.
        seed: Random seed for sampling and position-bias randomization.
        skip_pairs: Set of (model_a, model_b) pairs to exclude. Pairs are
            normalized so (a, b) and (b, a) are treated identically.
            If None, all pairs are included.
        indices: Explicit row indices to use. When provided, ``max_samples``
            and ``seed`` are not used for index selection (seed is still used
            for position-bias randomization).

    Returns:
        List of Comparison objects with pre-built chat messages.
    """
    col_names = list(ocr_columns.keys())
    model_names = list(ocr_columns.values())
    # All unordered model pairs, as index pairs into col_names/model_names.
    pairs = list(combinations(range(len(col_names)), 2))

    # Normalize skip set for symmetric lookup
    normalized_skip: set[tuple[str, str]] = set()
    if skip_pairs:
        normalized_skip = {_normalize_pair(a, b) for a, b in skip_pairs}

    if indices is None:
        indices = sample_indices(len(dataset), max_samples, seed)

    # One RNG drives every position-bias coin flip, so results are fully
    # reproducible for a fixed seed and index order.
    rng = random.Random(seed)
    comparisons: list[Comparison] = []

    # Pre-fetch text columns to avoid triggering image decode per row.
    # HF Dataset supports column access (dataset["col"]), plain lists don't.
    text_cols_data: dict[str, list] | None = None
    if hasattr(dataset, "column_names"):
        text_cols_data = {col: dataset[col] for col in col_names}

    for idx in indices:
        # Determine which pairs need judging for this row
        needed_pairs = [
            (i, j)
            for i, j in pairs
            if _normalize_pair(model_names[i], model_names[j]) not in normalized_skip
        ]
        if not needed_pairs:
            continue  # Skip image encoding entirely

        # Check text availability before decoding the image
        valid_pairs = []
        if text_cols_data is not None:
            # Column-access path: texts come from the pre-fetched columns,
            # so no row (and thus no image) is materialized here.
            for i, j in needed_pairs:
                text_a = text_cols_data[col_names[i]][idx] or ""
                text_b = text_cols_data[col_names[j]][idx] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))
        else:
            # Fallback for plain sequences of dicts: row access is cheap here.
            row = dataset[idx]
            for i, j in needed_pairs:
                text_a = row[col_names[i]] or ""
                text_b = row[col_names[j]] or ""
                if text_a.strip() and text_b.strip():
                    valid_pairs.append((i, j, text_a, text_b))

        if not valid_pairs:
            continue

        # Decode + encode the page image only once per row; every pair for
        # this row shares the same base64 payload.
        image_b64 = image_to_base64(dataset[idx]["image"])

        for i, j, text_a, text_b in valid_pairs:
            # 50/50 chance the two outputs trade prompt positions; model_a/
            # model_b and text_a/text_b below stay in canonical order.
            swapped = rng.random() < 0.5
            prompt, swapped = build_prompt(text_a, text_b, swapped)
            messages = build_messages(image_b64, prompt)

            comparisons.append(
                Comparison(
                    sample_idx=idx,
                    model_a=model_names[i],
                    model_b=model_names[j],
                    col_a=col_names[i],
                    col_b=col_names[j],
                    swapped=swapped,
                    messages=messages,
                    text_a=text_a,
                    text_b=text_b,
                )
            )

    return comparisons


# --- Output parsing ---


def parse_judge_output(text: str) -> dict[str, str]:
    """Parse judge JSON output, handling markdown fences and invalid values.

    The judge is asked for a bare JSON object, but models sometimes wrap it
    in ```-fences, return non-object JSON, or use non-string field values.
    Any winner value outside {"A", "B", "tie"} is coerced to "tie".

    Returns:
        Dict with "winner" and "reason" keys, or an empty dict when the
        output cannot be interpreted as a JSON object.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (e.g. ```json) and the closing fence.
        # partition() avoids the IndexError split() raised on a bare "```".
        _, _, rest = text.partition("\n")
        text = rest.rsplit("```", 1)[0].strip()
    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse judge output: %s", text[:200])
        return {}
    if not isinstance(result, dict):
        # Valid JSON but not an object (list, string, number) would have
        # crashed the old result.get() call; treat it as a parse failure.
        logger.warning("Judge output is not a JSON object: %s", text[:200])
        return {}
    # str() guards against non-string values (e.g. {"winner": 1}).
    winner = str(result.get("winner", "tie")).upper().strip()
    if winner not in ("A", "B"):
        winner = "tie"
    return {"winner": winner, "reason": str(result.get("reason", ""))}