Saksak65 commited on
Commit
de47001
Β·
verified Β·
1 Parent(s): b14aea6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +684 -0
app.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import warnings
3
+
4
+ # Suppress FutureWarning from spaces library about torch.distributed.reduce_op
5
+ warnings.filterwarnings("ignore", category=FutureWarning, module="spaces")
6
+
7
+ import base64
8
+ import os
9
+ import re
10
+ import subprocess
11
+ import sys
12
+ import threading
13
+ import time
14
+ from collections import OrderedDict
15
+ from io import BytesIO
16
+
17
+ import gradio as gr
18
+ import pypdfium2 as pdfium
19
+ import spaces
20
+ import torch
21
+ from openai import OpenAI
22
+ from PIL import Image
23
+ from transformers import (
24
+ LightOnOcrForConditionalGeneration,
25
+ LightOnOcrProcessor,
26
+ TextIteratorStreamer,
27
+ )
28
+
29
# vLLM endpoint configuration from environment variables
# VLLM_ENDPOINT_OCR = os.environ.get("VLLM_ENDPOINT_OCR")
# VLLM_ENDPOINT_BBOX = os.environ.get("VLLM_ENDPOINT_BBOX")

# Streaming configuration
STREAM_YIELD_INTERVAL = 0.5  # Yield every N seconds to reduce UI overhead

# Registry of selectable models. Each entry maps a UI display name to:
#   model_id    - Hugging Face repo id to load
#   has_bbox    - whether the model emits bounding-box annotations in its output
#   description - short blurb shown in the UI
# (An optional "vllm_endpoint" key, currently commented out, routes inference
# to a remote vLLM server instead of the local model.)
MODEL_REGISTRY = {
    "LightOnOCR-2-1B (Best OCR)": {
        "model_id": "lightonai/LightOnOCR-2-1B",
        "has_bbox": False,
        "description": "Best overall OCR performance",
        # "vllm_endpoint": VLLM_ENDPOINT_OCR,
    },
    "LightOnOCR-2-1B-bbox (Best Bbox)": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox",
        "has_bbox": True,
        "description": "Best bounding box detection",
        # "vllm_endpoint": VLLM_ENDPOINT_BBOX,
    },
    "LightOnOCR-2-1B-base": {
        "model_id": "lightonai/LightOnOCR-2-1B-base",
        "has_bbox": False,
        "description": "Base OCR model",
    },
    "LightOnOCR-2-1B-bbox-base": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox-base",
        "has_bbox": True,
        "description": "Base bounding box model",
    },
    "LightOnOCR-2-1B-ocr-soup": {
        "model_id": "lightonai/LightOnOCR-2-1B-ocr-soup",
        "has_bbox": False,
        "description": "OCR soup variant",
    },
    "LightOnOCR-2-1B-bbox-soup": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox-soup",
        "has_bbox": True,
        "description": "Bounding box soup variant",
    },
}

DEFAULT_MODEL = "LightOnOCR-2-1B (Best OCR)"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Choose best attention implementation based on device:
# sdpa + bfloat16 on GPU, eager + float32 on CPU.
if device == "cuda":
    attn_implementation = "sdpa"
    dtype = torch.bfloat16
    print("Using sdpa for GPU")
else:
    attn_implementation = "eager"  # Best for CPU
    dtype = torch.float32
    print("Using eager attention for CPU")
85
+
86
+
87
class ModelManager:
    """Loads OCR models on demand and keeps a bounded LRU cache of them."""

    def __init__(self, max_cached=2):
        # Maps model_id -> (model, processor); ordering tracks recency
        # (oldest entry first, so popitem(last=False) evicts the LRU one).
        self._cache = OrderedDict()
        self._max_cached = max_cached

    def get_model(self, model_name):
        """Return (model, processor) for a registry entry, loading it if needed.

        Raises ValueError when model_name is not in MODEL_REGISTRY.
        """
        config = MODEL_REGISTRY.get(model_name)
        if config is None:
            raise ValueError(f"Unknown model: {model_name}")

        model_id = config["model_id"]

        cached = self._cache.get(model_id)
        if cached is not None:
            # Refresh recency so this entry is evicted last.
            self._cache.move_to_end(model_id)
            print(f"Using cached model: {model_name}")
            return cached

        # Make room: drop least-recently-used entries until under the cap,
        # releasing GPU memory for each evicted model.
        while len(self._cache) >= self._max_cached:
            evicted_id, (evicted_model, _) = self._cache.popitem(last=False)
            print(f"Evicting model from cache: {evicted_id}")
            del evicted_model
            if device == "cuda":
                torch.cuda.empty_cache()

        print(f"Loading model: {model_name} ({model_id})...")
        model = LightOnOcrForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation=attn_implementation,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        model = model.to(device).eval()

        processor = LightOnOcrProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

        self._cache[model_id] = (model, processor)
        print(f"Model loaded successfully: {model_name}")

        return model, processor

    def get_model_info(self, model_name):
        """Look up registry metadata for a model without loading any weights."""
        return MODEL_REGISTRY.get(model_name)
143
+
144
+
145
+ # Initialize model manager
146
+ model_manager = ModelManager(max_cached=2)
147
+ print("Model manager initialized. Models will be loaded on first use.")
148
+
149
+
150
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Render a PDF page to a PIL image, capping both sides at max_resolution px.

    The page is rendered at the nominal `scale`, shrunk (never enlarged) so
    that neither pixel dimension exceeds `max_resolution`.
    """
    page_w, page_h = page.get_size()
    full_w = page_w * scale
    full_h = page_h * scale
    # Shrink factor <= 1 keeping the aspect ratio within the cap.
    shrink = min(1, max_resolution / full_w, max_resolution / full_h)
    return page.render(scale=scale * shrink, rev_byteorder=True).to_pil()
158
+
159
+
160
def process_pdf(pdf_path, page_num=1):
    """Render one page of a PDF.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number; clamped into [1, total_pages].

    Returns:
        (image, total_pages, actual_page) — the rendered PIL image, the page
        count, and the clamped 1-based page number that was rendered.
    """
    pdf = pdfium.PdfDocument(pdf_path)
    try:
        total_pages = len(pdf)
        # Clamp the requested page into the valid 0-based index range.
        page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
        img = render_pdf_page(pdf[page_idx])
    finally:
        # Always release the document handle — the original leaked it when
        # rendering raised before reaching pdf.close().
        pdf.close()
    return img, total_pages, page_idx + 1
171
+
172
+
173
def clean_output_text(text):
    """Strip chat-template role markers ("system"/"user"/"assistant") from output.

    If a standalone "assistant" marker line is present, everything up to and
    including that line is discarded (it delimits the start of the model's
    reply). Remaining lines that are bare role markers are dropped too.

    Fix over the original: the old code did `text.split("assistant", 1)` on a
    plain substring, which truncated any document that merely contained the
    word "assistant" (and its `text.lower()` guard didn't match the
    case-sensitive split). The marker is now only honored as a standalone line.
    """
    markers = {"system", "user", "assistant"}
    lines = text.split("\n")

    # Keep only the content after the first standalone "assistant" line, if any.
    for idx, line in enumerate(lines):
        if line.strip().lower() == "assistant":
            lines = lines[idx + 1:]
            break

    # Drop any remaining lines that are bare role markers.
    kept = [line for line in lines if line.strip().lower() not in markers]
    return "\n".join(kept).strip()
198
+
199
+
200
# Bbox output format: ![image](image_N.png)x1,y1,x2,y2 — the four integer
# coordinates follow the markdown image reference (optional whitespace between).
BBOX_PATTERN = r"!\[image\]\((image_\d+\.png)\)\s*(\d+),(\d+),(\d+),(\d+)"


def parse_bbox_output(text):
    """Split bbox-annotated model output into clean markdown plus detections.

    Returns (cleaned_text, detections): each detection is a dict with "ref"
    (the image filename) and "coords" (an (x1, y1, x2, y2) int tuple);
    cleaned_text has the trailing coordinates removed but keeps the markdown
    image references.
    """
    detections = [
        {
            "ref": m.group(1),
            "coords": tuple(int(c) for c in m.groups()[1:]),
        }
        for m in re.finditer(BBOX_PATTERN, text)
    ]
    stripped = re.sub(BBOX_PATTERN, r"![image](\1)", text)
    return stripped, detections
215
+
216
+
217
def crop_from_bbox(source_image, bbox, padding=5):
    """Crop the region given by bbox["coords"] (normalized 0-1000) from the image."""
    width, height = source_image.size
    x1, y1, x2, y2 = bbox["coords"]

    # Map normalized [0, 1000] coordinates onto pixel space.
    left = int(x1 * width / 1000)
    top = int(y1 * height / 1000)
    right = int(x2 * width / 1000)
    bottom = int(y2 * height / 1000)

    # Grow the box by `padding` pixels on every side, clamped to the image bounds.
    left = max(0, left - padding)
    top = max(0, top - padding)
    right = min(width, right + padding)
    bottom = min(height, bottom + padding)

    return source_image.crop((left, top, right, bottom))
233
+
234
+
235
def image_to_data_uri(image):
    """Encode a PIL image as a PNG base64 data URI for inline markdown embedding."""
    raw = BytesIO()
    image.save(raw, format="PNG")
    encoded = base64.b64encode(raw.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
241
+
242
+
243
def extract_text_via_vllm(image, model_name, temperature=0.2, stream=False, max_tokens=2048):
    """Extract text from an image via a remote vLLM OpenAI-compatible endpoint.

    Generator: when stream=True it yields intermediate cleaned text roughly
    every STREAM_YIELD_INTERVAL seconds; in both modes it ends with exactly one
    final yield of the fully cleaned output.

    Raises:
        ValueError: if model_name is unknown or has no "vllm_endpoint" configured.
    """
    config = MODEL_REGISTRY.get(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    endpoint = config.get("vllm_endpoint")
    if endpoint is None:
        raise ValueError(f"Model {model_name} does not have a vLLM endpoint")

    model_id = config["model_id"]

    # Convert a PIL image to a base64 data URI; anything else is passed through
    # (assumed to already be a data URI or URL).
    if isinstance(image, Image.Image):
        image_uri = image_to_data_uri(image)
    else:
        image_uri = image

    # OpenAI client pointed at the vLLM endpoint (key is unused by vLLM).
    client = OpenAI(base_url=endpoint, api_key="not-needed")

    # Single user message containing only the image — no text prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_uri}},
            ],
        }
    ]

    if stream:
        # Streaming response: accumulate deltas and yield periodically.
        response = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature if temperature > 0 else 0.0,
            top_p=0.9,
            stream=True,
        )

        full_text = ""
        last_yield_time = time.time()
        for chunk in response:
            if chunk.choices and chunk.choices[0].delta.content:
                full_text += chunk.choices[0].delta.content
                # Batch yields to reduce UI overhead: at most one yield per interval.
                if time.time() - last_yield_time > STREAM_YIELD_INTERVAL:
                    yield clean_output_text(full_text)
                    last_yield_time = time.time()
        # Final yield with the complete cleaned text.
        yield clean_output_text(full_text)
    else:
        # Non-streaming: one request, one yield.
        response = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature if temperature > 0 else 0.0,
            top_p=0.9,
            stream=False,
        )

        output_text = response.choices[0].message.content
        cleaned_text = clean_output_text(output_text)
        yield cleaned_text
311
+
312
+
313
def render_bbox_with_crops(raw_output, source_image):
    """Inline each detected bounding box as a base64 cropped image in the markdown."""
    markdown, detections = parse_bbox_output(raw_output)

    for detection in detections:
        placeholder = f"![image]({detection['ref']})"
        try:
            crop = crop_from_bbox(source_image, detection)
            replacement = f"![Cropped region]({image_to_data_uri(crop)})"
        except Exception as exc:
            # Leave the plain image reference in place if cropping fails.
            print(f"Error cropping bbox {detection}: {exc}")
            continue
        markdown = markdown.replace(placeholder, replacement)

    return markdown
331
+
332
+
333
@spaces.GPU
def extract_text_from_image(image, model_name, temperature=0.2, stream=False, max_tokens=2048):
    """Extract text from an image with a locally loaded LightOnOCR model.

    Generator yielding cleaned text — progressively when stream=True, once
    otherwise. If the registry entry has a "vllm_endpoint", inference is
    delegated to the remote vLLM path instead of the local model.
    """
    # Check if model has a vLLM endpoint configured; if so, proxy to it.
    config = MODEL_REGISTRY.get(model_name, {})
    if config.get("vllm_endpoint"):
        yield from extract_text_via_vllm(image, model_name, temperature, stream, max_tokens)
        return

    # Get model and processor from cache (loads on first use).
    model, processor = model_manager.get_model(model_name)

    # Single user message containing only the image — no text prompt.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
            ],
        }
    ]

    # Apply chat template and tokenize into model-ready tensors.
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Move tensors to the device; floating-point tensors are also cast to the
    # model dtype (bf16 on GPU / f32 on CPU) while integer tensors
    # (input ids, masks) keep their dtype. Non-tensors pass through unchanged.
    inputs = {
        k: v.to(device=device, dtype=dtype)
        if isinstance(v, torch.Tensor)
        and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
        else v.to(device)
        if isinstance(v, torch.Tensor)
        else v
        for k, v in inputs.items()
    }

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature if temperature > 0 else 0.0,
        top_p=0.9,
        top_k=0,  # 0 = no top-k truncation
        use_cache=True,
        do_sample=temperature > 0,  # greedy decode at temperature 0
    )

    if stream:
        # Stream tokens from a background generation thread via the streamer.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer

        # Run generation in a separate thread so we can consume the streamer here.
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield chunks as they arrive.
        full_text = ""
        last_yield_time = time.time()
        for new_text in streamer:
            full_text += new_text
            # Batch yields to reduce UI overhead: at most one per interval.
            if time.time() - last_yield_time > STREAM_YIELD_INTERVAL:
                yield clean_output_text(full_text)
                last_yield_time = time.time()

        thread.join()
        # Final yield with the complete cleaned text.
        yield clean_output_text(full_text)
    else:
        # Non-streaming generation.
        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)

        # Decode the whole sequence (prompt tokens included; clean_output_text
        # strips the chat-template role markers afterwards).
        output_text = processor.decode(outputs[0], skip_special_tokens=True)

        # Clean the output of template artifacts.
        cleaned_text = clean_output_text(output_text)

        yield cleaned_text
422
+
423
+
424
def process_input(file_input, model_name, temperature, page_num, enable_streaming, max_output_tokens):
    """Process an uploaded image/PDF and stream extraction results to the UI.

    Generator yielding 5-tuples matching the Gradio outputs:
    (rendered_markdown, raw_text, page_info, preview_image, slider_update).
    Errors are reported as yielded messages rather than raised.
    """
    if file_input is None:
        yield "Please upload an image or PDF first.", "", "", None, gr.update()
        return

    image_to_process = None
    page_info = ""

    # gr.File may provide a path string or a file-like object with .name.
    file_path = file_input if isinstance(file_input, str) else file_input.name

    # Handle PDF files: render the requested page to an image.
    if file_path.lower().endswith(".pdf"):
        try:
            image_to_process, total_pages, actual_page = process_pdf(
                file_path, int(page_num)
            )
            page_info = f"Processing page {actual_page} of {total_pages}"
        except Exception as e:
            yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
            return
    # Handle image files.
    else:
        try:
            image_to_process = Image.open(file_path)
            page_info = "Processing image"
        except Exception as e:
            yield f"Error opening image: {str(e)}", "", "", None, gr.update()
            return

    # Bbox-capable models get their detections rendered as inline crops.
    model_info = MODEL_REGISTRY.get(model_name, {})
    has_bbox = model_info.get("has_bbox", False)

    try:
        # Re-yield every (possibly partial) result from the extractor.
        for extracted_text in extract_text_from_image(
            image_to_process, model_name, temperature, stream=enable_streaming, max_tokens=max_output_tokens
        ):
            # For bbox models, replace image placeholders with actual crops.
            if has_bbox:
                rendered_text = render_bbox_with_crops(extracted_text, image_to_process)
            else:
                rendered_text = extracted_text
            yield (
                rendered_text,
                extracted_text,
                page_info,
                image_to_process,
                gr.update(),
            )

    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        yield error_msg, error_msg, page_info, image_to_process, gr.update()
479
+
480
+
481
def update_slider_and_preview(file_input):
    """Update the page slider range and preview image for a newly uploaded file.

    Returns (slider_update, preview_image): for PDFs the slider maximum becomes
    the page count and the preview is page 1; for images the slider is pinned
    to 1. On any failure the preview is None and the slider gets a default.
    """
    if file_input is None:
        return gr.update(maximum=20, value=1), None

    # gr.File may provide a path string or a file-like object with .name.
    file_path = file_input if isinstance(file_input, str) else file_input.name

    if file_path.lower().endswith(".pdf"):
        try:
            pdf = pdfium.PdfDocument(file_path)
            try:
                total_pages = len(pdf)
                # Render the first page at a light scale for the preview pane.
                preview_image = pdf[0].render(scale=2).to_pil()
            finally:
                # Always release the handle (the original leaked it when
                # rendering raised before pdf.close()).
                pdf.close()
            return gr.update(maximum=total_pages, value=1), preview_image
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # are no longer swallowed.
        except Exception:
            return gr.update(maximum=20, value=1), None
    else:
        # It's an image file.
        try:
            preview_image = Image.open(file_path)
            return gr.update(maximum=1, value=1), preview_image
        except Exception:
            return gr.update(maximum=1, value=1), None
506
+
507
+
508
# Helper producing the markdown shown under the model dropdown.
def get_model_info_text(model_name):
    """Format a registry entry's description and bbox capability as markdown."""
    entry = MODEL_REGISTRY.get(model_name, {})
    description = entry.get("description", "N/A")
    if entry.get("has_bbox", False):
        bbox_note = "Yes - will show cropped regions inline"
    else:
        bbox_note = "No"
    return f"**Description:** {description}\n**Bounding Box Detection:** {bbox_note}"
518
+
519
+
520
# Create Gradio interface: left column = inputs/settings, right column =
# rendered markdown, plus an example gallery and a raw-output textbox.
with gr.Blocks(title="LightOnOCR-2 Multi-Model OCR") as demo:
    # Header: model family description, links, and runtime info.
    gr.Markdown(f"""
# LightOnOCR-2 — Efficient 1B VLM for OCR

State-of-the-art OCR on OlmOCR-Bench, ~9× smaller and faster than competitors. Handles tables, forms, math, multi-column layouts.

⚡ **3.3× faster** than Chandra, **1.7× faster** than OlmOCR | 💸 **<$0.01/1k pages** | 🧠 End-to-end differentiable | 📏 Bbox variants for image detection

📄 [Paper](https://arxiv.org/pdf/2601.14251) | 📝 [Blog](https://huggingface.co/blog/lightonai/lightonocr-2) | 📊 [Dataset](https://huggingface.co/datasets/lightonai/LightOnOCR-mix-0126) | 📓 [Finetuning](https://colab.research.google.com/drive/1WjbsFJZ4vOAAlKtcCauFLn_evo5UBRNa?usp=sharing)

---

**How to use:** Select a model → Upload image/PDF → Click "Extract Text" | **Device:** {device.upper()} | **Attention:** {attn_implementation}
""")

    with gr.Row():
        # Left column: model choice, upload, preview, and generation settings.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model",
                info="Select OCR model variant",
            )
            model_info = gr.Markdown(
                value=get_model_info_text(DEFAULT_MODEL), label="Model Info"
            )
            file_input = gr.File(
                label="Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath",
            )
            rendered_image = gr.Image(
                label="Preview", type="pil", height=400, interactive=False
            )
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract",
            )
            page_info = gr.Textbox(label="Processing Info", value="", interactive=False)
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic, Higher = more varied",
            )
            enable_streaming = gr.Checkbox(
                label="Enable Streaming",
                value=True,
                info="Show text progressively as it's generated",
            )
            max_output_tokens = gr.Slider(
                minimum=256,
                maximum=8192,
                value=2048,
                step=256,
                label="Max Output Tokens",
                info="Maximum number of tokens to generate",
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        # Right column: rendered markdown output with LaTeX support.
        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="*Extracted text will appear here...*",
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                ],
            )

    # Example inputs with image previews.
    EXAMPLE_IMAGES = [
        "examples/example_1.png",
        "examples/example_2.png",
        "examples/example_3.png",
        "examples/example_4.png",
        "examples/example_5.png",
        "examples/example_6.png",
        "examples/example_7.png",
        "examples/example_8.png",
        "examples/example_9.png",
    ]

    with gr.Accordion("📁 Example Documents (click an image to load)", open=True):
        example_gallery = gr.Gallery(
            value=EXAMPLE_IMAGES,
            columns=5,
            rows=2,
            height="auto",
            object_fit="contain",
            show_label=False,
            allow_preview=False,
        )

    def load_example_image(evt: gr.SelectData):
        """Load selected example image into file input."""
        return EXAMPLE_IMAGES[evt.index]

    example_gallery.select(
        fn=load_example_image,
        outputs=[file_input],
    )

    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30,
            )

    # Event handlers: run extraction on submit.
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, model_selector, temperature, num_pages, enable_streaming, max_output_tokens],
        outputs=[output_text, raw_output, page_info, rendered_image, num_pages],
    )

    # New upload: refresh the page-slider bounds and the preview pane.
    file_input.change(
        fn=update_slider_and_preview,
        inputs=[file_input],
        outputs=[num_pages, rendered_image],
    )

    model_selector.change(
        fn=get_model_info_text, inputs=[model_selector], outputs=[model_info]
    )

    # Reset every control back to its initial state.
    clear_btn.click(
        fn=lambda: (
            None,
            DEFAULT_MODEL,
            get_model_info_text(DEFAULT_MODEL),
            "*Extracted text will appear here...*",
            "",
            "",
            None,
            1,
            2048,
        ),
        outputs=[
            file_input,
            model_selector,
            model_info,
            output_text,
            raw_output,
            page_info,
            rendered_image,
            num_pages,
            max_output_tokens,
        ],
    )
681
+
682
+
683
if __name__ == "__main__":
    # NOTE(review): ssr_mode=False presumably sidesteps Gradio SSR issues on
    # Spaces — confirm against the deployment environment.
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False)