ryandt committed on
Commit
c6f6682
·
verified ·
1 Parent(s): b94bee0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +264 -285
app.py CHANGED
@@ -1,333 +1,312 @@
1
  """
2
- Model loading and inference for OCR Confidence Visualization.
3
 
4
- Loads Nanonets-OCR2-3B (Qwen2.5-VL fine-tune) and provides
5
- inference with token-level probability extraction.
 
 
6
  """
7
 
8
- import math
9
- from dataclasses import dataclass, field
10
- from typing import Generator, Optional
11
 
12
- import torch
13
  from PIL import Image
14
- from transformers import AutoModelForImageTextToText, AutoProcessor
15
-
16
- # Available models for selection
17
- AVAILABLE_MODELS = {
18
- "Nanonets-OCR2-3B": "nanonets/Nanonets-OCR2-3B",
19
- "olmOCR-7B": "allenai/olmOCR-7B-0725",
20
- "Aya-Vision-8B": "CohereLabs/aya-vision-8b",
21
- }
22
 
23
- DEFAULT_MODEL = "Aya-Vision-8B"
 
 
 
 
 
24
 
25
- # Global model and processor (loaded once per model)
26
- _model = None
27
- _processor = None
28
- _device = None
29
- _current_model_name = None
30
 
31
 
32
@dataclass
class TokenData:
    """Data for a single generated token with probability info."""

    # Decoded token text.
    token: str
    # Softmax probability the model assigned to this token when it was chosen.
    probability: float
    # Top-k runner-up predictions: [{"token": str, "probability": float}, ...]
    alternatives: list[dict[str, float]]
    # Shannon entropy (bits) of the reported distribution; 0.0 = fully certain.
    entropy: float = 0.0
 
42
def calculate_entropy(probs: list[float]) -> float:
    """Calculate Shannon entropy in bits from a probability distribution.

    Args:
        probs: List of probabilities (should sum to ~1.0).

    Returns:
        Entropy in bits. 0.0 for empty or single-certainty distributions.
    """
    # Zero-probability terms contribute nothing (lim p*log2(p) -> 0), so skip
    # them; this also avoids math.log2(0) raising ValueError.
    return -sum((p * math.log2(p) for p in probs if p > 0), 0.0)
58
def load_model(model_name: Optional[str] = None):
    """Load the OCR model and processor, caching them in module globals.

    Reloads (freeing the previous checkpoint first) when model_name differs
    from the currently loaded model.

    Args:
        model_name: Key into AVAILABLE_MODELS. Defaults to DEFAULT_MODEL;
            unknown names silently fall back to the default model.

    Returns:
        Tuple of (model, processor).
    """
    global _model, _processor, _device, _current_model_name

    if model_name is None:
        model_name = DEFAULT_MODEL

    model_id = AVAILABLE_MODELS.get(model_name, AVAILABLE_MODELS[DEFAULT_MODEL])

    # Return cached model if already loaded
    if _model is not None and _current_model_name == model_name:
        return _model, _processor

    # Unload previous model if switching, releasing memory before the new
    # checkpoint is materialized.
    if _model is not None:
        print(f"Unloading previous model: {_current_model_name}")
        del _model
        del _processor
        _model = None
        _processor = None
        # Only touch the CUDA allocator when CUDA actually exists (CPU-only
        # hosts would otherwise pay for / fail CUDA initialization here).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {model_id}...")

    _processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    _model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    ).to(_device).eval()

    _current_model_name = model_name
    print("Model loaded successfully")
    return _model, _processor
97
def run_ocr(image: Image.Image, prompt: Optional[str] = None) -> str:
    """
    Run OCR on an image and return extracted text.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)

    Returns:
        Extracted text from the image
    """
    model, processor = load_model()

    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    # NOTE(review): do_sample=True makes this path non-deterministic; the
    # streaming path (generate_with_logprobs) selects argmax instead. Kept
    # as-is to preserve existing behavior.
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=1,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
        )

    # Slice off the prompt tokens so only newly generated text is decoded.
    generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]

    return output_text
156
def generate_with_logprobs(
    image: Image.Image,
    prompt: Optional[str] = None,
    max_new_tokens: int = 1024,
    top_k: int = 20,
    top_p: float = 0.9,
    temperature: float = 1.0,  # 1.0 keeps the raw distribution; selection is argmax
    repetition_penalty: float = 1.1,
    model_name: Optional[str] = None,
) -> Generator[TokenData, None, None]:
    """
    Generate OCR text token-by-token with probability information.

    Yields TokenData for each generated token, enabling streaming display
    with confidence visualization. Decoding is greedy: the argmax of the
    temperature-scaled, repetition-penalized distribution is always taken.

    Args:
        image: PIL Image to process
        prompt: Optional custom prompt (default: natural reading extraction)
        max_new_tokens: Maximum tokens to generate
        top_k: Number of top alternatives to include
        top_p: Nucleus sampling parameter. NOTE(review): accepted for API
            compatibility but currently unused — selection is pure argmax.
        temperature: Scaling applied to logits before softmax
        repetition_penalty: Penalty for repeating tokens (>1.0 reduces repetition)
        model_name: Which model to use (from AVAILABLE_MODELS keys)

    Yields:
        TokenData with token string, probability, and top-k alternatives
    """
    model, processor = load_model(model_name)

    if prompt is None:
        prompt = "Extract the text from the above document as if you were reading it naturally."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(_device)

    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Resolve EOS token id(s) for stopping: model config first, then tokenizer.
    # BUGFIX: normalize *every* source to a list — the config may hold a bare
    # int, which would make `next_token_id in eos_token_id` raise TypeError.
    eos_token_id = model.config.eos_token_id
    if eos_token_id is None:
        eos_token_id = processor.tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = []  # No EOS token - will rely on max_new_tokens
    elif isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    # Track generated tokens (prompt + everything produced so far)
    generated_ids = input_ids.clone()

    # Extract image inputs (pixel_values, image_grid_thw for Qwen2.5-VL)
    model_inputs = {k: v for k, v in inputs.items() if k not in ("input_ids", "attention_mask")}

    # Use DynamicCache for proper KV cache management
    from transformers import DynamicCache
    past_key_values = DynamicCache()

    # Track sequence length for cache_position
    seq_length = input_ids.shape[1]

    # rope_deltas for multimodal RoPE (required for Qwen2.5-VL): computed on
    # the first forward pass and passed to all subsequent passes.
    rope_deltas = None

    with torch.no_grad():
        for step in range(max_new_tokens):
            if step == 0:
                # First step: full prompt + image data primes the KV cache.
                cache_position = torch.arange(seq_length, device=_device)
                outputs = model(
                    input_ids=generated_ids,
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    **model_inputs,
                    return_dict=True,
                    use_cache=True,
                )
            else:
                # Subsequent steps: only the newest token, reusing the cache.
                cache_position = torch.tensor([seq_length], device=_device)
                outputs = model(
                    input_ids=generated_ids[:, -1:],
                    attention_mask=attention_mask,
                    cache_position=cache_position,
                    past_key_values=past_key_values,
                    rope_deltas=rope_deltas,  # correct multimodal position encoding
                    return_dict=True,
                    use_cache=True,
                )

            past_key_values = outputs.past_key_values
            # Capture rope_deltas from the first pass for later steps.
            if step == 0 and getattr(outputs, "rope_deltas", None) is not None:
                rope_deltas = outputs.rope_deltas

            # float32 logits avoid fp16 overflow/underflow in softmax.
            next_token_logits = outputs.logits[:, -1, :].float()

            # Apply repetition penalty once per *distinct* previous token.
            # BUGFIX: iterating the raw token list applied the penalty again
            # for every repeat of a token, compounding it exponentially.
            if repetition_penalty != 1.0:
                for prev_token_id in set(generated_ids[0].tolist()):
                    if next_token_logits[0, prev_token_id] < 0:
                        next_token_logits[0, prev_token_id] *= repetition_penalty
                    else:
                        next_token_logits[0, prev_token_id] /= repetition_penalty

            # Apply temperature
            if temperature > 0:
                next_token_logits = next_token_logits / temperature

            # Compute probabilities via softmax
            probs = torch.softmax(next_token_logits, dim=-1)

            # Get top-k probabilities and indices
            top_probs, top_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
            top_probs = top_probs[0].cpu().tolist()
            top_indices = top_indices[0].cpu().tolist()

            # Greedy selection: highest-probability token.
            next_token_id = top_indices[0]
            next_token_prob = top_probs[0]

            # Check for EOS
            if next_token_id in eos_token_id:
                break

            # Decode token
            token_str = processor.decode([next_token_id], skip_special_tokens=False)

            # Alternatives exclude the selected token (index 0).
            alternatives = [
                {
                    "token": processor.decode([alt_idx], skip_special_tokens=False),
                    "probability": alt_prob,
                }
                for alt_idx, alt_prob in zip(top_indices[1:], top_probs[1:])
            ]

            # Entropy over the truncated top-k distribution (underestimates
            # full-vocabulary entropy, but is stable and cheap).
            all_probs = [next_token_prob] + [alt["probability"] for alt in alternatives]
            token_entropy = calculate_entropy(all_probs)

            # Yield token data
            yield TokenData(
                token=token_str,
                probability=next_token_prob,
                alternatives=alternatives,
                entropy=token_entropy,
            )

            # Append the chosen token and extend the attention mask to cover
            # the full sequence (required for Qwen VL models).
            next_token_tensor = torch.tensor([[next_token_id]], device=_device)
            generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
                dim=-1,
            )
            seq_length += 1
 
1
  """
2
+ OCR Confidence Visualization - Gradio Application.
3
 
4
+ Upload a document image to extract text with confidence visualization.
5
+
6
+ Supports deployment to HuggingFace Spaces with ZeroGPU via @spaces.GPU decorator.
7
+ The decorator is effect-free in non-ZeroGPU environments for local development.
8
  """
9
 
10
+ import html
11
+ import json
12
+ from typing import Generator
13
 
14
+ import gradio as gr
15
  from PIL import Image
 
 
 
 
 
 
 
 
16
 
17
+ # Import spaces for ZeroGPU support (effect-free outside HuggingFace Spaces)
18
+ try:
19
+ import spaces
20
+ SPACES_AVAILABLE = True
21
+ except ImportError:
22
+ SPACES_AVAILABLE = False
23
 
24
+ from model import generate_with_logprobs, load_model, TokenData, AVAILABLE_MODELS, DEFAULT_MODEL
 
 
 
 
25
 
26
 
27
def gpu_decorator(duration: int = 120):
    """
    Return @spaces.GPU decorator if available, otherwise a no-op decorator.

    This allows the code to work both locally and on HuggingFace Spaces.
    """
    if not SPACES_AVAILABLE:
        # Identity decorator: leaves the wrapped function untouched.
        return lambda fn: fn
    return spaces.GPU(duration=duration)
 
38
def probability_to_color(prob: float) -> str:
    """
    Map probability to a color for text and underline styling.

    Args:
        prob: Confidence probability (0.0 to 1.0)

    Returns:
        Hex color string
    """
    # (exclusive lower bound, color) pairs, most to least confident.
    bands = (
        (0.99, "#3b82f6"),  # Blue - very high confidence
        (0.95, "#16a34a"),  # Dark Green - high confidence
        (0.85, "#65a30d"),  # Darker Light Green - good confidence (darkened for readability)
        (0.70, "#ca8a04"),  # Darker Yellow - moderate confidence (darkened for readability)
        (0.50, "#ef4444"),  # Red - low confidence
    )
    for lower_bound, color in bands:
        if prob > lower_bound:
            return color
    return "#a855f7"  # Purple - very low confidence
+
62
def entropy_to_color(entropy: float) -> str:
    """
    Map entropy (in bits) to a color for visualization.

    Higher entropy = more uncertainty = warmer colors.

    Args:
        entropy: Shannon entropy in bits (0.0 = certain)

    Returns:
        Hex color string
    """
    # (exclusive upper bound, color) pairs, most to least certain.
    bands = (
        (0.1, "#3b82f6"),  # Blue - very certain
        (0.3, "#16a34a"),  # Dark Green - certain
        (0.7, "#65a30d"),  # Green - fairly certain
        (1.5, "#ca8a04"),  # Amber - some uncertainty
        (2.5, "#ef4444"),  # Red - uncertain
    )
    for upper_bound, color in bands:
        if entropy < upper_bound:
            return color
    return "#a855f7"  # Purple - very uncertain
88
def build_html_output(tokens: list[TokenData], mode: str = "probability") -> str:
    """
    Build HTML output from accumulated tokens with confidence coloring.

    Args:
        tokens: List of TokenData objects
        mode: "probability" for confidence coloring, "entropy" for uncertainty coloring

    Returns:
        HTML string with styled token spans
    """
    # Font stack with emoji support
    font_family = "'Cascadia Code', 'Fira Code', Consolas, monospace, 'Apple Color Emoji', 'Segoe UI Emoji', 'Noto Color Emoji'"

    # CSS for hover underline effect
    style_tag = '<style>.token-span:hover { text-decoration: underline !important; }</style>'

    if not tokens:
        return f'{style_tag}<div class="token-container" style="font-family: {font_family}; line-height: 1.8; padding: 10px;"></div>'

    pieces: list[str] = []
    for td in tokens:
        # Escape HTML entities before any markup is added.
        token_text = html.escape(td.token)

        if "\n" in token_text:
            # Newline-bearing tokens become plain <br> breaks (uncolored).
            pieces.append(token_text.replace("\n", "<br>"))
            continue

        # Pick the color scheme for the requested view.
        if mode == "entropy":
            color = entropy_to_color(td.entropy)
        else:
            color = probability_to_color(td.probability)

        # Alternatives ride along as escaped JSON in a data attribute so the
        # front-end click handler can render them without another request.
        alternatives_json = html.escape(json.dumps(td.alternatives))

        pieces.append(
            f'<span class="token-span" style="color: {color}; '
            f'text-decoration-color: {color}; cursor: pointer;" '
            f'data-prob="{td.probability}" '
            f'data-entropy="{td.entropy}" '
            f'data-alternatives="{alternatives_json}">'
            f'{token_text}</span>'
        )

    html_content = "".join(pieces)
    return f'{style_tag}<div class="token-container" style="font-family: {font_family}; line-height: 1.6; padding: 10px; white-space: pre-wrap;">{html_content}</div>'
 
142
@gpu_decorator(duration=120)
def transcribe_full(image: Image.Image, model_name: str = None) -> list[TokenData]:
    """
    Run full OCR inference on GPU and return all tokens.

    On HuggingFace Spaces with ZeroGPU, this function is decorated with
    @spaces.GPU to allocate GPU resources for the duration of inference.
    The GPU is released when the function returns.

    Args:
        image: PIL Image to process
        model_name: Which model to use for inference

    Returns:
        List of TokenData with token strings, probabilities, and alternatives
    """
    # Materialize the whole generator inside the decorated call so the GPU
    # can be released as soon as inference completes; rendering happens later.
    collected = [token for token in generate_with_logprobs(image, model_name=model_name)]
    return collected
161
def transcribe_streaming(image: Image.Image, model_name: str = None) -> Generator[tuple[str, str], None, None]:
    """
    Stream OCR transcription with progressive HTML output for both views.

    This function separates GPU-bound inference from HTML rendering:
    1. Shows a "Processing..." indicator during inference
    2. Runs full inference in a single GPU-decorated call
    3. Streams HTML rendering from pre-computed tokens (no GPU needed)

    This architecture is required for HuggingFace ZeroGPU, which allocates
    GPU resources per decorated function call rather than for streaming.

    Args:
        image: PIL Image to process
        model_name: Which model to use for inference

    Yields:
        Tuple of (probability_html, entropy_html) as tokens stream
    """
    if image is None:
        empty = '<div style="color: #666; padding: 10px;">Please upload an image.</div>'
        yield empty, empty
        return

    # Spinner shown in both tabs while the GPU-decorated call runs.
    loading = f'''<div style="color: #60a5fa; padding: 10px; display: flex; align-items: center; gap: 10px;">
    <div style="width: 20px; height: 20px; border: 2px solid #60a5fa; border-top-color: transparent; border-radius: 50%; animation: spin 1s linear infinite;"></div>
    <style>@keyframes spin {{ to {{ transform: rotate(360deg); }} }}</style>
    Processing image with {model_name or DEFAULT_MODEL}...
    </div>'''
    yield loading, loading

    # Run full inference (GPU allocated here on ZeroGPU)
    tokens = transcribe_full(image, model_name=model_name)

    # Re-render both views after each token to simulate streaming (CPU only).
    shown: list[TokenData] = []
    for token_data in tokens:
        shown.append(token_data)
        yield (
            build_html_output(shown, mode="probability"),
            build_html_output(shown, mode="entropy"),
        )
205
+ # JavaScript for token alternatives panel (loaded via launch js parameter)
206
+ TOKEN_ALTERNATIVES_JS = """
207
+ (function() {
208
+ document.addEventListener('click', function(e) {
209
+ var token = e.target.closest('[data-alternatives]');
210
+ if (!token || !token.dataset.alternatives) return;
211
+
212
+ var panel = document.getElementById('alternatives-panel');
213
+ if (!panel) return;
214
+
215
+ var prob = parseFloat(token.dataset.prob) || 0;
216
+ var alts = JSON.parse(token.dataset.alternatives);
217
+ var tokenText = token.textContent;
218
+
219
+ // Build panel content
220
+ var html = '<div style="font-weight:600;margin-bottom:12px;padding-bottom:8px;border-bottom:1px solid #374151;">' +
221
+ 'Selected: "<span style="color:#60a5fa">' + tokenText + '</span>" (' + (prob * 100).toFixed(2) + '%)' +
222
+ '</div>';
223
+
224
+ if (alts.length === 0) {
225
+ html += '<div style="color:#9ca3af;font-style:italic">No alternatives available</div>';
226
+ } else {
227
+ html += '<div style="font-size:12px;color:#9ca3af;margin-bottom:8px;">Top ' + Math.min(alts.length, 10) + ' alternatives:</div>';
228
+ for (var i = 0; i < Math.min(alts.length, 10); i++) {
229
+ var alt = alts[i];
230
+ var altProb = (alt.probability * 100).toFixed(2);
231
+ var barWidth = Math.max(alt.probability * 100, 1);
232
+ html += '<div style="display:flex;align-items:center;margin:6px 0;">' +
233
+ '<span style="width:80px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;font-family:monospace;">' +
234
+ alt.token.replace(/</g,'&lt;').replace(/>/g,'&gt;') + '</span>' +
235
+ '<span style="width:55px;text-align:right;color:#9ca3af;font-size:12px;margin-right:10px;">' +
236
+ altProb + '%</span>' +
237
+ '<div style="flex:1;height:10px;background:#374151;border-radius:5px;overflow:hidden;">' +
238
+ '<div style="width:' + barWidth + '%;height:100%;background:#60a5fa;border-radius:5px;"></div>' +
239
+ '</div></div>';
240
+ }
241
  }
 
242
 
243
+ panel.innerHTML = html;
244
+ });
245
+ })();
246
+ """
247
 
248
+ # Initial HTML for alternatives panel
249
+ ALTERNATIVES_PANEL_INITIAL = '''
250
+ <div id="alternatives-panel" style="
251
+ padding: 16px;
252
+ background: #1f2937;
253
+ border-radius: 8px;
254
+ color: #e5e7eb;
255
+ font-family: system-ui, -apple-system, sans-serif;
256
+ font-size: 14px;
257
+ min-height: 100px;
258
+ ">
259
+ <div style="color: #9ca3af; font-style: italic;">
260
+ Click on any token above to see alternative predictions.
261
+ </div>
262
+ </div>
263
+ '''
264
+
265
+ # Build Gradio interface
266
+ with gr.Blocks(title="OCR Confidence Visualization") as demo:
267
+ gr.Markdown("# OCR Confidence Visualization")
268
+ gr.Markdown("Upload a document image to extract text with token streaming.")
269
+
270
+ with gr.Row():
271
+ with gr.Column(scale=1):
272
+ model_selector = gr.Radio(
273
+ choices=list(AVAILABLE_MODELS.keys()),
274
+ value=DEFAULT_MODEL,
275
+ label="Model",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  )
277
+ image_input = gr.Image(type="pil", label="Upload Document")
278
+ submit_btn = gr.Button("Transcribe", variant="primary")
279
+
280
+ with gr.Column(scale=2):
281
+ with gr.Tabs():
282
+ with gr.TabItem("Probability"):
283
+ output_html_prob = gr.HTML(
284
+ value='<div style="color: #666; padding: 10px;">Upload an image and click Transcribe to start.</div>',
285
+ )
286
+ with gr.TabItem("Entropy"):
287
+ output_html_entropy = gr.HTML(
288
+ value='<div style="color: #666; padding: 10px;">Upload an image and click Transcribe to start.</div>',
289
+ )
290
+ gr.Markdown("### Token Alternatives")
291
+ alternatives_html = gr.HTML(
292
+ value=ALTERNATIVES_PANEL_INITIAL,
293
  )
294
+
295
+ submit_btn.click(
296
+ fn=transcribe_streaming,
297
+ inputs=[image_input, model_selector],
298
+ outputs=[output_html_prob, output_html_entropy],
299
+ )
300
+
301
+
302
+ if __name__ == "__main__":
303
+ # Preload model at startup for local development
304
+ # On HuggingFace Spaces with ZeroGPU, model loading happens on first request
305
+ # when GPU is allocated by the @spaces.GPU decorator
306
+ if not SPACES_AVAILABLE:
307
+ print("Preloading model (local development)...")
308
+ load_model()
309
+ else:
310
+ print("ZeroGPU detected - model will load on first inference request")
311
+ print("Starting Gradio server...")
312
+ demo.launch(server_port=7860, js=TOKEN_ALTERNATIVES_JS)