File size: 11,358 Bytes
b94bee0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""
Model loading and inference for OCR Confidence Visualization.

Loads one of the vision-language OCR models listed in AVAILABLE_MODELS
(e.g. Nanonets-OCR2-3B, a Qwen2.5-VL fine-tune) and provides inference
with token-level probability extraction.
"""

import math
from dataclasses import dataclass, field
from typing import Generator, Optional, Union

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# Selectable models: display name -> Hugging Face Hub repository id.
AVAILABLE_MODELS: dict[str, str] = {
    "Nanonets-OCR2-3B": "nanonets/Nanonets-OCR2-3B",
    "olmOCR-7B": "allenai/olmOCR-7B-0725",
    "Aya-Vision-8B": "CohereLabs/aya-vision-8b",
}

# Display name used when the caller does not ask for a specific model.
DEFAULT_MODEL: str = "Aya-Vision-8B"

# Process-wide cache filled in by load_model(); all four start out empty.
_model = _processor = _device = _current_model_name = None


@dataclass
class TokenData:
    """Data for a single generated token with probability info.

    Attributes:
        token: Decoded text of the selected token.
        probability: Probability assigned to the selected token.
        alternatives: Other candidate tokens, each a dict of the form
            ``{"token": str, "probability": float}``.
        entropy: Shannon entropy in bits; defaults to 0.0.
    """

    token: str
    probability: float
    # Values are mixed: "token" maps to a str, "probability" to a float.
    alternatives: list[dict[str, Union[str, float]]]  # [{"token": str, "probability": float}, ...]
    entropy: float = field(default=0.0)  # Shannon entropy in bits


def calculate_entropy(probs: list[float]) -> float:
    """Return the Shannon entropy, in bits, of a probability distribution.

    Entries that are not strictly positive are skipped, so an empty list or a
    distribution with all of its mass on one outcome yields 0.0.

    Args:
        probs: Probabilities, expected to sum to roughly 1.0.

    Returns:
        -sum(p * log2(p)) taken over the positive entries of ``probs``.
    """
    weighted_logs = (p * math.log2(p) for p in probs if p > 0)
    return 0.0 - sum(weighted_logs)


def load_model(model_name: Optional[str] = None):
    """Load (or reuse) the OCR model and processor.

    Only one model is kept resident: requesting a different model unloads the
    current one first. ``None`` or a name not in ``AVAILABLE_MODELS`` selects
    ``DEFAULT_MODEL``; the cache is keyed by that resolved name, so repeated
    requests for the same underlying model never trigger a reload.

    Args:
        model_name: Key of ``AVAILABLE_MODELS``.

    Returns:
        Tuple ``(model, processor)``; the model is in eval mode on ``cuda:0``
        when CUDA is available, otherwise on the CPU.
    """
    global _model, _processor, _device, _current_model_name

    # Normalize first so the cache key always matches the model actually loaded.
    if model_name not in AVAILABLE_MODELS:
        if model_name is not None:
            print(f"Unknown model '{model_name}', falling back to {DEFAULT_MODEL}")
        model_name = DEFAULT_MODEL
    model_id = AVAILABLE_MODELS[model_name]

    # Return cached model if already loaded
    if _model is not None and _current_model_name == model_name:
        return _model, _processor

    # Unload previous model if switching
    if _model is not None:
        print(f"Unloading previous model: {_current_model_name}")
        _model = None
        _processor = None
        # Clear the name too, so a failed load below leaves consistent state.
        _current_model_name = None
        # Collect reference cycles first so weights reachable only through
        # them are freed before the CUDA cache is released.
        import gc
        gc.collect()
        torch.cuda.empty_cache()

    use_cuda = torch.cuda.is_available()
    _device = torch.device("cuda:0" if use_cuda else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {model_id}...")

    # FlashAttention-2 only runs on CUDA with fp16/bf16 weights. On CPU, let
    # transformers choose its default attention and use float32, since
    # half-precision CPU kernels are slow or missing for some ops.
    load_kwargs = {"trust_remote_code": True}
    if use_cuda:
        load_kwargs["attn_implementation"] = "flash_attention_2"
        load_kwargs["torch_dtype"] = torch.float16
    else:
        load_kwargs["torch_dtype"] = torch.float32

    _processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    _model = AutoModelForImageTextToText.from_pretrained(
        model_id, **load_kwargs
    ).to(_device).eval()

    _current_model_name = model_name
    print("Model loaded successfully")
    return _model, _processor


def run_ocr(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    model_name: Optional[str] = None,
) -> str:
    """
    Run OCR on an image and return extracted text.

    Decoding samples (``do_sample=True`` with top_p/top_k), so repeated calls
    on the same image may return different text.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum number of tokens to generate (default 1024)
        model_name: Which model to use (from AVAILABLE_MODELS keys); ``None``
            selects the default model. Passing the name of the model that is
            already loaded avoids an unload/reload cycle.

    Returns:
        Extracted text from the image
    """
    model, processor = load_model(model_name)

    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=1,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    # All rows share the (padded) prompt length, so one slice covers the batch.
    prompt_length = inputs.input_ids.shape[1]
    return processor.batch_decode(
        output_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]


def _resolve_eos_token_ids(model, processor) -> set[int]:
    """Return the token ids that terminate generation (empty if none are known)."""
    # Same precedence model.generate() effectively uses: generation_config
    # first, then the model config, then the tokenizer. getattr() tolerates
    # configs that do not define eos_token_id at all.
    for source in (
        getattr(model, "generation_config", None),
        getattr(model, "config", None),
        getattr(processor, "tokenizer", None),
    ):
        eos = getattr(source, "eos_token_id", None)
        if eos is not None:
            break
    else:
        return set()  # No EOS token - rely on max_new_tokens
    if isinstance(eos, int):
        return {eos}
    return {int(t) for t in eos}


def _apply_repetition_penalty(
    logits: torch.Tensor, token_ids: torch.Tensor, penalty: float
) -> torch.Tensor:
    """Return ``logits`` with every id in ``token_ids`` penalized exactly once.

    Same rule as transformers' RepetitionPenaltyLogitsProcessor: negative
    logits are multiplied by ``penalty``, non-negative ones divided by it.
    Duplicate ids in ``token_ids`` all scatter the same value, so repeated
    tokens are not penalized repeatedly. Runs entirely on-device.

    Args:
        logits: Tensor of shape ``(batch, vocab)``.
        token_ids: Long tensor of shape ``(batch, seq)`` with previous ids.
        penalty: Repetition penalty factor.
    """
    seen = torch.gather(logits, 1, token_ids)
    seen = torch.where(seen < 0, seen * penalty, seen / penalty)
    return logits.scatter(1, token_ids, seen)


def generate_with_logprobs(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    top_k: int = 20,
    top_p: float = 0.9,
    temperature: float = 1.0,  # Use 1.0 for standard distribution, pick top token (argmax)
    repetition_penalty: float = 1.1,
    model_name: Optional[str] = None,
) -> Generator[TokenData, None, None]:
    """
    Generate OCR text token-by-token with probability information.

    Decoding is greedy: after repetition penalty and temperature, the
    highest-probability token is always selected. Yields one TokenData per
    generated token, enabling streaming display with confidence
    visualization. Generation stops at an EOS token (which is not yielded)
    or after ``max_new_tokens`` tokens.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum tokens to generate
        top_k: Number of top candidates to report, including the selected
            token (so up to ``top_k - 1`` alternatives). Values < 1 act as 1.
        top_p: Currently unused; accepted for backward compatibility.
            Selection is greedy, so nucleus filtering would not change the
            chosen token.
        temperature: When > 0, logits are divided by it before softmax. This
            changes the reported probabilities and entropy, not the choice.
        repetition_penalty: Applied once to every token id already in the
            sequence, prompt included (>1.0 discourages repetition).
        model_name: Which model to use (from AVAILABLE_MODELS keys)

    Yields:
        TokenData with token string, probability, top-k alternatives, and the
        Shannon entropy (bits) of the full next-token distribution.
    """
    from transformers import DynamicCache

    model, processor = load_model(model_name)

    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    eos_token_ids = _resolve_eos_token_ids(model, processor)

    # Prompt followed by tokens generated so far (batch of one).
    generated_ids = input_ids.clone()

    # Image inputs (e.g. pixel_values, image_grid_thw for Qwen2.5-VL). They are
    # only passed on the first forward pass; later steps rely on the KV cache.
    model_inputs = {k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")}

    past_key_values = DynamicCache()
    seq_length = input_ids.shape[1]  # position of the next token, for cache_position
    # Multimodal RoPE offset (Qwen2.5-VL) returned by the first forward pass.
    rope_deltas = None
    ln2 = math.log(2)

    with torch.no_grad():
        for step in range(max_new_tokens):
            if step == 0:
                # First step: include image data, full sequence
                outputs = model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    cache_position=torch.arange(seq_length, device=_device),
                    past_key_values=past_key_values,
                    **model_inputs,
                    return_dict=True,
                    use_cache=True,
                )
                rope_deltas = getattr(outputs, "rope_deltas", None)
            else:
                # Subsequent steps: only the newest token, with the KV cache.
                # rope_deltas is forwarded only when the model produced one,
                # so models without multimodal RoPE never receive the kwarg.
                extra = {} if rope_deltas is None else {"rope_deltas": rope_deltas}
                outputs = model(
                    input_ids=generated_ids[:, -1:],
                    attention_mask=attention_mask,
                    cache_position=torch.tensor([seq_length], device=_device),
                    past_key_values=past_key_values,
                    **extra,
                    return_dict=True,
                    use_cache=True,
                )

            past_key_values = outputs.past_key_values

            # Logits for the last position, in float32 to avoid half-precision overflow.
            next_token_logits = outputs.logits[:, -1, :].float()

            if repetition_penalty != 1.0:
                next_token_logits = _apply_repetition_penalty(
                    next_token_logits, generated_ids, repetition_penalty
                )

            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            probs = torch.softmax(next_token_logits, dim=-1)

            # At least one candidate is needed to pick the next token.
            k = max(1, min(top_k, probs.shape[-1]))
            top_probs_t, top_indices_t = torch.topk(probs, k=k)
            top_probs = top_probs_t[0].tolist()
            top_indices = top_indices_t[0].tolist()

            # Greedy choice: topk output is sorted, so index 0 is the argmax.
            next_token_id = top_indices[0]
            next_token_prob = top_probs[0]

            if next_token_id in eos_token_ids:
                break

            # NOTE(review): tokens are decoded one at a time; with a byte-level
            # tokenizer a character split across tokens may decode to U+FFFD.
            token_str = processor.decode([next_token_id], skip_special_tokens=False)

            # Runners-up, excluding the selected token.
            alternatives = [
                {
                    "token": processor.decode([alt_id], skip_special_tokens=False),
                    "probability": alt_prob,
                }
                for alt_id, alt_prob in zip(top_indices[1:], top_probs[1:])
            ]

            # Entropy in bits of the whole next-token distribution. The top-k
            # slice does not sum to 1, so it is not a valid distribution.
            # entr(p) = -p * ln(p), with entr(0) = 0.
            token_entropy = (torch.special.entr(probs[0]).sum() / ln2).item()

            yield TokenData(
                token=token_str,
                probability=next_token_prob,
                alternatives=alternatives,
                entropy=token_entropy,
            )

            # Append the chosen token and extend the attention mask to cover
            # the full sequence (required for Qwen VL models).
            next_token_tensor = torch.tensor(
                [[next_token_id]], device=_device, dtype=generated_ids.dtype
            )
            generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)
            seq_length += 1