Bapt120 committed on
Commit
3dcdc69
·
verified ·
1 Parent(s): f2b79b0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +510 -0
app.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ import threading
6
+
7
+ import spaces
8
+ import torch
9
+
10
+ import gradio as gr
11
+ from PIL import Image
12
+ from io import BytesIO
13
+ import pypdfium2 as pdfium
14
+ from transformers import (
15
+ LightOnOcrForConditionalGeneration,
16
+ LightOnOcrProcessor,
17
+ TextIteratorStreamer,
18
+ )
19
+ import re
20
+ import base64
21
+ from collections import OrderedDict
22
+
23
# Model registry with all supported models.
# Each row is (dropdown label, Hugging Face model id, emits bounding boxes?,
# short description). Row order is the order of the resulting dict.
_MODEL_VARIANTS = (
    ("LightOnOCR-2-1B (Best OCR)", "lightonai/LightOnOCR-2-1B", False, "Best overall OCR performance"),
    ("LightOnOCR-2-1B-base", "lightonai/LightOnOCR-2-1B-base", False, "Base OCR model"),
    ("LightOnOCR-2-1B-ocr-soup", "lightonai/LightOnOCR-2-1B-ocr-soup", False, "OCR soup variant"),
    ("LightOnOCR-2-1B-bbox (Best Bbox)", "lightonai/LightOnOCR-2-1B-bbox", True, "Best bounding box detection"),
    ("LightOnOCR-2-1B-bbox-base", "lightonai/LightOnOCR-2-1B-bbox-base", True, "Base bounding box model"),
    ("LightOnOCR-2-1B-bbox-soup", "lightonai/LightOnOCR-2-1B-bbox-soup", True, "Bounding box soup variant"),
)

# Maps the dropdown label to its configuration dict.
MODEL_REGISTRY = {
    label: {"model_id": repo_id, "has_bbox": emits_bbox, "description": blurb}
    for label, repo_id, emits_bbox, blurb in _MODEL_VARIANTS
}

DEFAULT_MODEL = "LightOnOCR-2-1B (Best OCR)"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Choose attention implementation and dtype based on the device.
if device != "cuda":
    attn_implementation = "eager"  # Best for CPU
    dtype = torch.float32
    print("Using eager attention for CPU")
else:
    attn_implementation = "sdpa"
    dtype = torch.bfloat16
    print("Using sdpa for GPU")
70
+
71
+
72
class ModelManager:
    """Load models on demand and keep the most recently used ones cached.

    Entries live in an ``OrderedDict`` keyed by Hugging Face model id, ordered
    from least to most recently used. When the cache is full the oldest entry
    is dropped before a new model is loaded.
    """

    def __init__(self, max_cached=2):
        """Create an empty cache.

        Args:
            max_cached: Maximum number of (model, processor) pairs kept alive.
                Values below 1 are clamped to 1; otherwise the eviction loop
                would try to pop from an empty cache.
        """
        self._cache = OrderedDict()  # {model_id: (model, processor)}
        self._max_cached = max(1, int(max_cached))

    def _evict_until_room(self):
        """Drop least-recently-used entries until one more model fits."""
        import gc

        while self._cache and len(self._cache) >= self._max_cached:
            evicted_id, evicted_entry = self._cache.popitem(last=False)
            print(f"Evicting model from cache: {evicted_id}")
            # Release both the model and its processor, then collect any
            # reference cycles so the CUDA allocator can actually free memory.
            del evicted_entry
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()

    def get_model(self, model_name):
        """Return ``(model, processor)`` for a registry label, loading if needed.

        Args:
            model_name: A key of ``MODEL_REGISTRY``.

        Returns:
            Tuple ``(model, processor)``; the model has been moved to the
            module-level ``device`` and put in eval mode.

        Raises:
            ValueError: If ``model_name`` is not in ``MODEL_REGISTRY``.
        """
        config = MODEL_REGISTRY.get(model_name)
        if config is None:
            raise ValueError(f"Unknown model: {model_name}")

        model_id = config["model_id"]

        # Cache hit: mark as most recently used and return it.
        if model_id in self._cache:
            self._cache.move_to_end(model_id)
            print(f"Using cached model: {model_name}")
            return self._cache[model_id]

        # Make room BEFORE loading so two full caches never coexist in memory.
        self._evict_until_room()

        # Load new model
        print(f"Loading model: {model_name} ({model_id})...")
        hf_token = os.environ.get("HF_TOKEN")
        model = LightOnOcrForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation=attn_implementation,
            torch_dtype=dtype,
            trust_remote_code=True,
            token=hf_token
        ).to(device).eval()

        processor = LightOnOcrProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            token=hf_token
        )

        # Add to cache
        self._cache[model_id] = (model, processor)
        print(f"Model loaded successfully: {model_name}")

        return model, processor

    def get_model_info(self, model_name):
        """Return the registry entry for ``model_name`` (or None) without loading."""
        return MODEL_REGISTRY.get(model_name)


# Initialize model manager
model_manager = ModelManager(max_cached=2)
print("Model manager initialized. Models will be loaded on first use.")
133
+
134
+
135
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Rasterize one PDF page to a PIL image, capped at ``max_resolution``.

    The page is rendered at ``scale`` unless that would make either side
    exceed ``max_resolution`` pixels, in which case the scale is reduced
    proportionally. The scale is never increased.
    """
    page_w, page_h = page.get_size()
    # Shrink factor is at most 1: only ever scale down to fit the cap.
    shrink = min(
        1,
        max_resolution / (page_w * scale),
        max_resolution / (page_h * scale),
    )
    bitmap = page.render(scale=scale * shrink, rev_byteorder=True)
    return bitmap.to_pil()
143
+
144
+
145
def process_pdf(pdf_path, page_num=1):
    """Render a single page of a PDF to an image.

    Args:
        pdf_path: Path of the PDF file to open.
        page_num: 1-based page number; out-of-range values are clamped to the
            first or last page.

    Returns:
        Tuple ``(image, total_pages, actual_page)`` where ``actual_page`` is
        the 1-based number of the page that was rendered after clamping.

    Raises:
        ValueError: If the document contains no pages.
    """
    pdf = pdfium.PdfDocument(pdf_path)
    try:
        total_pages = len(pdf)
        if total_pages == 0:
            # Without this guard the clamp below would yield index -1.
            raise ValueError("PDF contains no pages")
        page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
        img = render_pdf_page(pdf[page_idx])
    finally:
        # Always release the document, even when rendering fails.
        pdf.close()
    return img, total_pages, page_idx + 1
156
+
157
+
158
def clean_output_text(text):
    """Remove chat template artifacts from decoded model output.

    If a line consisting solely of the ``assistant`` role marker is present,
    everything after the first such line is returned. Otherwise, lines that
    consist solely of a role marker (``system``, ``user``, ``assistant``) are
    dropped. Matching is case-insensitive and ignores surrounding whitespace.

    Only whole-line markers are considered, so OCR content that merely
    contains the word "assistant" inside a sentence is left intact.

    Args:
        text: Decoded model output.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    markers_to_remove = ("system", "user", "assistant")
    lines = text.split('\n')

    # Prefer cutting at the assistant marker: what follows is the answer.
    for idx, line in enumerate(lines):
        if line.strip().lower() == "assistant":
            return '\n'.join(lines[idx + 1:]).strip()

    # No assistant marker: just drop any stray marker-only lines.
    kept = [line for line in lines if line.strip().lower() not in markers_to_remove]
    return '\n'.join(kept).strip()
183
+
184
+
185
# Bbox parsing pattern: ![image](image_N.png) x1,y1,x2,y2
BBOX_PATTERN = r'!\[image\]\((image_\d+\.png)\)\s+(\d+),(\d+),(\d+),(\d+)'


def parse_bbox_output(text):
    """Split bbox-model output into plain markdown and a list of detections.

    Returns:
        Tuple ``(cleaned, detections)``. ``cleaned`` is ``text`` with the
        trailing ``x1,y1,x2,y2`` removed from every matched image reference.
        ``detections`` is a list of ``{"ref": str, "coords": (x1, y1, x2, y2)}``
        dicts in order of appearance, with integer coordinates.
    """
    detections = [
        {"ref": m.group(1), "coords": tuple(int(value) for value in m.groups()[1:])}
        for m in re.finditer(BBOX_PATTERN, text)
    ]
    # Keep the markdown image reference, drop the coordinates after it.
    without_coords = re.sub(BBOX_PATTERN, r'![image](\1)', text)
    return without_coords, detections
201
+
202
+
203
def crop_from_bbox(source_image, bbox, padding=5):
    """Crop the region described by ``bbox["coords"]`` out of ``source_image``.

    Coordinates are treated as normalized to a 0-1000 range on each axis and
    are scaled by the image size. The box is then grown by ``padding`` pixels
    on every side and clamped to the image bounds.
    """
    img_w, img_h = source_image.size
    nx1, ny1, nx2, ny2 = bbox["coords"]

    # Scale to pixels, pad outward, and clamp each edge in a single step.
    left = max(0, int(nx1 * img_w / 1000) - padding)
    top = max(0, int(ny1 * img_h / 1000) - padding)
    right = min(img_w, int(nx2 * img_w / 1000) + padding)
    bottom = min(img_h, int(ny2 * img_h / 1000) + padding)

    return source_image.crop((left, top, right, bottom))
219
+
220
+
221
def image_to_data_uri(image):
    """Serialize ``image`` as PNG and return it as a base64 ``data:`` URI."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode()
227
+
228
+
229
def render_bbox_with_crops(raw_output, source_image):
    """Replace bbox image placeholders with inline crops of ``source_image``.

    Every ``![image](image_N.png) x1,y1,x2,y2`` occurrence in ``raw_output``
    is replaced by ``![Cropped region](data:image/png;base64,...)`` holding
    the cropped region. Each occurrence is handled with its own coordinates,
    so a reference name that appears more than once still gets the right crop.

    If cropping or encoding fails for an occurrence, the error is printed and
    the plain ``![image](image_N.png)`` reference (without coordinates) is
    kept instead.

    Args:
        raw_output: Raw text emitted by a bbox-capable model.
        source_image: Image the coordinates refer to.

    Returns:
        The text with placeholders substituted.
    """
    def _embed_crop(match):
        # One substitution per match keeps ref and coordinates paired.
        image_ref, *coords = match.groups()
        bbox = {"ref": image_ref, "coords": tuple(int(c) for c in coords)}
        try:
            data_uri = image_to_data_uri(crop_from_bbox(source_image, bbox))
        except Exception as e:
            print(f"Error cropping bbox {bbox}: {e}")
            # Keep original reference if cropping fails
            return f"![image]({image_ref})"
        return f"![Cropped region]({data_uri})"

    return re.sub(BBOX_PATTERN, _embed_crop, raw_output)
248
+
249
+
250
@spaces.GPU
def extract_text_from_image(image, model_name, temperature=0.2, stream=False):
    """Run OCR on ``image`` with the selected model; generator of cleaned text.

    Args:
        image: Image passed to the processor's chat template.
        model_name: A key of ``MODEL_REGISTRY``.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding and no temperature is passed to ``generate``.
        stream: When true, yields the accumulated cleaned text after every
            streamed chunk; otherwise yields the final text once.

    Yields:
        Cleaned output text (see ``clean_output_text``).

    Raises:
        Exception: Any error raised by ``model.generate``. In streaming mode
            the error is re-raised here after the worker thread has stopped.
    """
    # Get model and processor from cache or load
    model, processor = model_manager.get_model(model_name)

    # Prepare the chat format
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
            ],
        }
    ]

    # Apply chat template and tokenize
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    )

    # Move inputs to device AND convert floating tensors to the model dtype
    inputs = {
        k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
        else v.to(device) if isinstance(v, torch.Tensor)
        else v
        for k, v in inputs.items()
    }

    do_sample = temperature > 0
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=2048,
        use_cache=True,
        do_sample=do_sample,
    )
    if do_sample:
        # Temperature is only meaningful when sampling; passing 0.0 together
        # with do_sample=False is flagged by transformers' config validation.
        generation_kwargs["temperature"] = temperature

    if stream:
        # Setup streamer for streaming generation
        streamer = TextIteratorStreamer(
            processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer
        worker_errors = []

        def _generate():
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                # Without signalling the end of the stream, the consumer loop
                # below would block forever waiting for text.
                worker_errors.append(exc)
                streamer.end()

        # Run generation in a separate thread
        thread = threading.Thread(target=_generate)
        thread.start()

        # Yield the cleaned accumulated text as chunks arrive
        full_text = ""
        for new_text in streamer:
            full_text += new_text
            yield clean_output_text(full_text)

        thread.join()
        if worker_errors:
            raise worker_errors[0]
    else:
        # Non-streaming generation
        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)

        # Decode only the newly generated tokens, mirroring skip_prompt=True
        # in the streaming path.
        prompt_len = inputs["input_ids"].shape[-1]
        output_text = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        yield clean_output_text(output_text)
325
+
326
+
327
def process_input(file_input, model_name, temperature, page_num, enable_streaming):
    """Process an uploaded image or PDF and stream OCR results to the UI.

    Args:
        file_input: File path string, or an object with a ``.name`` path.
        model_name: A key of ``MODEL_REGISTRY``.
        temperature: Sampling temperature forwarded to the extractor.
        page_num: 1-based PDF page number (ignored for images).
        enable_streaming: Whether to yield partial results while generating.

    Yields:
        5-tuples ``(rendered_text, raw_text, page_info, image, gr.update())``.
        On failure the first element (and, for extraction errors, the second)
        carries the error message.
    """
    if file_input is None:
        yield "Please upload an image or PDF first.", "", "", None, gr.update()
        return

    image_to_process = None
    page_info = ""

    file_path = file_input if isinstance(file_input, str) else file_input.name

    # Handle PDF files
    if file_path.lower().endswith('.pdf'):
        try:
            image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
            page_info = f"Processing page {actual_page} of {total_pages}"
        except Exception as e:
            yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
            return
    # Handle image files
    else:
        try:
            # Image.open is lazy: convert() forces a full decode here so that
            # corrupt files are reported as "Error opening image", and the
            # context manager releases the file handle afterwards.
            with Image.open(file_path) as opened:
                image_to_process = opened.convert("RGB")
            page_info = "Processing image"
        except Exception as e:
            yield f"Error opening image: {str(e)}", "", "", None, gr.update()
            return

    # Check if model has bbox capability
    model_info = MODEL_REGISTRY.get(model_name, {})
    has_bbox = model_info.get("has_bbox", False)

    try:
        # Extract text using LightOnOCR with optional streaming
        for extracted_text in extract_text_from_image(image_to_process, model_name, temperature, stream=enable_streaming):
            # For bbox models, render cropped images inline
            if has_bbox:
                rendered_text = render_bbox_with_crops(extracted_text, image_to_process)
            else:
                rendered_text = extracted_text
            yield rendered_text, extracted_text, page_info, image_to_process, gr.update()

    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        yield error_msg, error_msg, page_info, image_to_process, gr.update()
372
+
373
+
374
def update_slider(file_input):
    """Return a slider update matching the uploaded file's page count.

    Args:
        file_input: File path string, an object with a ``.name`` path, or None.

    Returns:
        ``gr.update`` resetting the value to 1 with ``maximum`` set to the PDF
        page count (at least 1), to 1 for non-PDF files, or to 20 when no file
        is given or the PDF cannot be read.
    """
    if file_input is None:
        return gr.update(maximum=20, value=1)

    file_path = file_input if isinstance(file_input, str) else file_input.name

    if not file_path.lower().endswith('.pdf'):
        return gr.update(maximum=1, value=1)

    try:
        pdf = pdfium.PdfDocument(file_path)
        try:
            total_pages = len(pdf)
        finally:
            # Release the document even if counting pages fails.
            pdf.close()
    except Exception as e:
        # Best effort: fall back to the default range, but say why.
        print(f"Could not read PDF page count: {e}")
        return gr.update(maximum=20, value=1)

    # The slider minimum is 1, so never offer a maximum below that.
    return gr.update(maximum=max(total_pages, 1), value=1)
391
+
392
+
393
# Helper function to get model info text
def get_model_info_text(model_name):
    """Build the markdown summary shown under the model dropdown.

    Unknown model names fall back to a description of ``N/A`` and no bounding
    box support.
    """
    entry = MODEL_REGISTRY.get(model_name, {})
    description = entry.get('description', 'N/A')
    if entry.get("has_bbox", False):
        has_bbox = "Yes - will show cropped regions inline"
    else:
        has_bbox = "No"
    return f"**Description:** {description}\n**Bounding Box Detection:** {has_bbox}"
399
+
400
+
401
# Create Gradio interface
# Layout: a two-column row (controls on the left, rendered markdown on the
# right) followed by a full-width row holding the raw model output.
with gr.Blocks(title="LightOnOCR-2 Multi-Model OCR") as demo:
    # Header with usage notes; the last line reports the runtime device and
    # attention implementation chosen at import time.
    # NOTE(review): the original indentation inside this string could not be
    # recovered from the scraped source - confirm against the real file.
    gr.Markdown(f"""
# LightOnOCR-2 Multi-Model OCR

**How to use:**
1. Select a model (OCR models for text extraction, Bbox models for region detection)
2. Upload an image or PDF
3. For PDFs: select which page to extract
4. Click "Extract Text"

**Note:** Bbox models output cropped regions inline. Check raw output for coordinates.

**Device:** {device.upper()} | **Attention:** {attn_implementation}
""")

    with gr.Row():
        # Left column: model choice, upload, preview and generation options.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model",
                info="Select OCR model variant"
            )
            model_info = gr.Markdown(
                value=get_model_info_text(DEFAULT_MODEL),
                label="Model Info"
            )
            file_input = gr.File(
                label="Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            # Read-only preview; filled by process_input with the image that
            # was sent to the model.
            rendered_image = gr.Image(
                label="Preview",
                type="pil",
                height=400,
                interactive=False
            )
            # The maximum is adjusted by update_slider when a file is chosen.
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract"
            )
            page_info = gr.Textbox(
                label="Processing Info",
                value="",
                interactive=False
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic, Higher = more varied"
            )
            enable_streaming = gr.Checkbox(
                label="Enable Streaming",
                value=True,
                info="Show text progressively as it's generated"
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        # Right column: extracted text rendered as markdown.
        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="*Extracted text will appear here...*"
            )

    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30
            )

    # Event handlers
    # process_input is a generator, so each yielded tuple updates the outputs.
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, model_selector, temperature, num_pages, enable_streaming],
        outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
    )

    # Keep the page slider range in sync with the uploaded file.
    file_input.change(
        fn=update_slider,
        inputs=[file_input],
        outputs=[num_pages]
    )

    # Refresh the description text whenever another model is selected.
    model_selector.change(
        fn=get_model_info_text,
        inputs=[model_selector],
        outputs=[model_info]
    )

    # Reset every component to its initial value; the tuple order must match
    # the order of `outputs`.
    clear_btn.click(
        fn=lambda: (None, DEFAULT_MODEL, get_model_info_text(DEFAULT_MODEL), "*Extracted text will appear here...*", "", "", None, 1),
        outputs=[file_input, model_selector, model_info, output_text, raw_output, page_info, rendered_image, num_pages]
    )


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())