ryandt committed on
Commit
b94bee0
·
verified ·
1 Parent(s): 26079d9

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +333 -0
model.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model loading and inference for OCR Confidence Visualization.
3
+
4
+ Loads a vision-language OCR model selected from AVAILABLE_MODELS and provides
5
+ inference with token-level probability extraction.
6
+ """
7
+
8
+ import math
9
+ from dataclasses import dataclass, field
10
+ from typing import Generator, Optional
11
+
12
+ import torch
13
+ from PIL import Image
14
+ from transformers import AutoModelForImageTextToText, AutoProcessor
15
+
16
# Available models for selection: display name -> identifier handed to
# `from_pretrained` when the model is loaded.
AVAILABLE_MODELS = {
    "Nanonets-OCR2-3B": "nanonets/Nanonets-OCR2-3B",
    "olmOCR-7B": "allenai/olmOCR-7B-0725",
    "Aya-Vision-8B": "CohereLabs/aya-vision-8b",
}

# Key of AVAILABLE_MODELS used when the caller does not name a model.
DEFAULT_MODEL = "Aya-Vision-8B"

# Global model and processor (loaded once per model). Only one model is kept
# resident at a time; `_current_model_name` records which one that is.
_model = None
_processor = None
_device = None
_current_model_name = None
30
+
31
+
32
@dataclass
class TokenData:
    """Data for a single generated token with probability info."""

    # Decoded text of the selected token.
    token: str
    # Probability the model assigned to the selected token.
    probability: float
    # Runner-up candidates: [{"token": str, "probability": float}, ...]
    alternatives: list[dict[str, float]]
    # Shannon entropy in bits; defaults to 0.0 when not supplied.
    entropy: float = 0.0
40
+
41
+
42
def calculate_entropy(probs: list[float]) -> float:
    """Calculate Shannon entropy in bits from a probability distribution.

    Args:
        probs: List of probabilities (should sum to ~1.0). Entries that are
            zero or negative contribute nothing to the result.

    Returns:
        Entropy in bits. 0.0 for empty or single-certainty distributions.
    """
    # Non-positive entries are skipped because log2 is undefined for them.
    # Subtracting from 0.0 keeps the result a float (and +0.0) even when
    # `probs` is empty and `sum` returns the integer 0.
    return 0.0 - sum(p * math.log2(p) for p in probs if p > 0)
56
+
57
+
58
def load_model(model_name: str = None):
    """Load the OCR model and processor. Reloads if model_name differs from current.

    Args:
        model_name: Key of AVAILABLE_MODELS. ``None`` or an unrecognised key
            selects DEFAULT_MODEL.

    Returns:
        Tuple ``(model, processor)``. The pair is also cached in module
        globals, so repeated calls with the same name are cheap.
    """
    global _model, _processor, _device, _current_model_name

    # Normalise the name first so the cache key always matches the model that
    # is actually loaded; otherwise an unknown name followed by the default
    # name would reload the very same weights.
    if model_name is None:
        model_name = DEFAULT_MODEL
    elif model_name not in AVAILABLE_MODELS:
        print(f"Unknown model '{model_name}', falling back to {DEFAULT_MODEL}")
        model_name = DEFAULT_MODEL
    model_id = AVAILABLE_MODELS[model_name]

    # Return cached model if already loaded
    if _model is not None and _current_model_name == model_name:
        return _model, _processor

    # Unload previous model if switching: drop our references before asking
    # the CUDA allocator to release its cached blocks.
    if _model is not None:
        print(f"Unloading previous model: {_current_model_name}")
        _model = None
        _processor = None
        _current_model_name = None
        torch.cuda.empty_cache()

    use_cuda = torch.cuda.is_available()
    _device = torch.device("cuda:0" if use_cuda else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {model_id}...")

    load_kwargs = {"trust_remote_code": True, "torch_dtype": torch.float16}
    if use_cuda:
        # FlashAttention-2 is a GPU-only backend; on the CPU fallback we let
        # the library choose its default attention implementation instead.
        load_kwargs["attn_implementation"] = "flash_attention_2"

    # Load into locals and publish to the globals only once both succeeded,
    # so a failed load never leaves a half-populated cache behind.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = (
        AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
        .to(_device)
        .eval()
    )

    _model, _processor = model, processor
    _current_model_name = model_name
    print("Model loaded successfully")
    return _model, _processor
95
+
96
+
97
def run_ocr(image: Image.Image, prompt: str = None, model_name: str = None) -> str:
    """
    Run OCR on an image and return extracted text.

    Decoding uses sampling (``do_sample=True``), so repeated calls on the
    same image may return different text.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        model_name: Which model to use (from AVAILABLE_MODELS keys). ``None``
            uses the default model, matching the previous behaviour.

    Returns:
        Extracted text from the image
    """
    model, processor = load_model(model_name)

    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    # Single user turn: an image placeholder followed by the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=1,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
        )

    # `generate` returns prompt + continuation; slice off the prompt tokens
    # so only newly generated text is decoded.
    generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]

    return output_text
154
+
155
+
156
def generate_with_logprobs(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    top_k: int = 20,
    top_p: float = 0.9,
    temperature: float = 1.0,
    repetition_penalty: float = 1.1,
    model_name: str = None,
) -> Generator[TokenData, None, None]:
    """
    Generate OCR text token-by-token with probability information.

    Yields TokenData for each generated token, enabling streaming display
    with confidence visualization. Decoding is greedy: the most probable
    token (after repetition penalty) is always selected.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum tokens to generate
        top_k: Number of candidates reported per step, including the selected
            token (so at most ``top_k - 1`` alternatives). Values below 1 are
            treated as 1.
        top_p: Accepted for interface compatibility; currently unused because
            selection is greedy rather than sampled.
        temperature: Divides the logits before softmax when > 0. This changes
            the reported probabilities but not which token is selected.
        repetition_penalty: Penalty for repeating tokens (>1.0 reduces
            repetition). Applied once per distinct token already present in
            the prompt or the generated text.
        model_name: Which model to use (from AVAILABLE_MODELS keys)

    Yields:
        TokenData with token string, probability, top-k alternatives, and the
        entropy (bits) of the reported top-k probabilities.
    """
    model, processor = load_model(model_name)

    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Get EOS token ID for stopping - check model config first, then tokenizer
    eos_token_id = model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = processor.tokenizer.eos_token_id
    if eos_token_id is None:
        stop_ids = frozenset()  # No EOS token - will rely on max_new_tokens
    elif isinstance(eos_token_id, int):
        stop_ids = frozenset((eos_token_id,))
    else:
        stop_ids = frozenset(eos_token_id)

    # Prompt plus everything generated so far
    generated_ids = input_ids.clone()

    # Remaining processor outputs (e.g. pixel values) are only needed on the
    # first forward pass; afterwards the KV cache carries the image context.
    model_inputs = {k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")}

    # Use DynamicCache for proper KV cache management
    from transformers import DynamicCache
    past_key_values = DynamicCache()

    # Track sequence length for cache_position
    seq_length = input_ids.shape[1]

    # rope_deltas is reported by some multimodal models (e.g. Qwen2.5-VL) on
    # the first forward pass and must be fed back on later passes.
    rope_deltas = None

    # topk with k == 0 would return nothing to select from
    num_candidates = max(1, top_k)

    for step in range(max_new_tokens):
        # no_grad is scoped to the tensor work only. Wrapping the `yield` as
        # well would leave gradients disabled in the consumer's code while
        # this generator is suspended.
        with torch.no_grad():
            if step == 0:
                # First step: include image data, full sequence
                outputs = model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    cache_position=torch.arange(seq_length, device=_device),
                    past_key_values=past_key_values,
                    **model_inputs,
                    return_dict=True,
                    use_cache=True,
                )
                rope_deltas = getattr(outputs, "rope_deltas", None)
            else:
                # Subsequent steps: only the newest token, rest comes from cache.
                # rope_deltas is forwarded only when the model produced one, so
                # models without that keyword are not handed an unknown argument.
                extra_kwargs = {} if rope_deltas is None else {"rope_deltas": rope_deltas}
                outputs = model(
                    input_ids=generated_ids[:, -1:],
                    attention_mask=attention_mask,
                    cache_position=torch.tensor([seq_length], device=_device),
                    past_key_values=past_key_values,
                    return_dict=True,
                    use_cache=True,
                    **extra_kwargs,
                )

            past_key_values = outputs.past_key_values

            # Get logits for last token position - convert to float32 to avoid overflow
            next_token_logits = outputs.logits[:, -1, :].float()

            # Apply the repetition penalty once per distinct seen token.
            # Looping over every occurrence would compound the penalty
            # (penalty ** count) and cost one device sync per token.
            if repetition_penalty != 1.0:
                seen_ids = torch.unique(generated_ids[0])
                seen_logits = next_token_logits[0, seen_ids]
                next_token_logits[0, seen_ids] = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )

            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Compute probabilities via softmax
            probs = torch.softmax(next_token_logits, dim=-1)

            # Top candidates, sorted by descending probability
            top_probs, top_indices = torch.topk(
                probs, k=min(num_candidates, probs.shape[-1])
            )
            top_probs = top_probs[0].cpu().tolist()
            top_indices = top_indices[0].cpu().tolist()

        # Greedy selection: the first top-k entry is the argmax
        next_token_id = top_indices[0]
        next_token_prob = top_probs[0]

        # Check for EOS
        if next_token_id in stop_ids:
            break

        # Decode token
        token_str = processor.decode([next_token_id], skip_special_tokens=False)

        # Build alternatives list (excluding the selected token)
        alternatives = [
            {
                "token": processor.decode([alt_id], skip_special_tokens=False),
                "probability": alt_prob,
            }
            for alt_id, alt_prob in zip(top_indices[1:], top_probs[1:])
        ]

        # Entropy over the reported top-k probabilities (selected + alternatives)
        token_entropy = calculate_entropy(top_probs)

        yield TokenData(
            token=token_str,
            probability=next_token_prob,
            alternatives=alternatives,
            entropy=token_entropy,
        )

        # Update for next iteration
        next_token_tensor = torch.tensor([[next_token_id]], device=_device)
        generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
        # Extend attention mask to cover full sequence (required for Qwen VL models)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
            dim=-1,
        )
        seq_length += 1