Chhagan005 commited on
Commit
1a70a82
Β·
verified Β·
1 Parent(s): 5982d54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -70
app.py CHANGED
@@ -14,11 +14,19 @@ from PIL import Image
14
  import cv2
15
 
16
  from transformers import (
17
- Qwen2VLForConditionalGeneration,
18
  Qwen2_5_VLForConditionalGeneration,
19
  AutoProcessor,
20
  TextIteratorStreamer,
21
  )
 
 
 
 
 
 
 
 
 
22
  from transformers.image_utils import load_image
23
  from gradio.themes import Soft
24
  from gradio.themes.utils import colors, fonts, sizes
@@ -148,6 +156,28 @@ if torch.cuda.is_available():
148
 
149
  print("Using device:", device)
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  class RadioAnimated(gr.HTML):
152
  def __init__(self, choices, value=None, **kwargs):
153
  if not choices or len(choices) < 2:
@@ -215,7 +245,7 @@ class RadioAnimated(gr.HTML):
215
  def apply_gpu_duration(val: str):
216
  return int(val)
217
 
218
- # Model V: Nanonets-OCR2-3B
219
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
220
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
221
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -224,54 +254,69 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
224
  trust_remote_code=True,
225
  torch_dtype=torch.float16
226
  ).to(device).eval()
 
227
 
228
- # Model X: Qwen2-VL-OCR-2B
229
- MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
230
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
231
- model_x = Qwen2VLForConditionalGeneration.from_pretrained(
232
- MODEL_ID_X,
233
- attn_implementation="flash_attention_2",
234
- trust_remote_code=True,
235
- torch_dtype=torch.float16
236
- ).to(device).eval()
237
-
238
- # Model P: PaddleOCR-VL (NEW - More stable than Qwen3)
239
- MODEL_ID_P = "PaddlePaddle/PaddleOCR-VL"
240
  try:
241
- processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
242
- model_p = Qwen2_5_VLForConditionalGeneration.from_pretrained(
243
- MODEL_ID_P,
244
  attn_implementation="flash_attention_2",
245
  trust_remote_code=True,
246
  torch_dtype=torch.float16
247
  ).to(device).eval()
248
- PADDLE_AVAILABLE = True
249
- print("βœ“ PaddleOCR-VL model loaded successfully")
250
  except Exception as e:
251
- print(f"βœ— PaddleOCR-VL model not available: {e}")
252
- PADDLE_AVAILABLE = False
253
- processor_p = None
254
- model_p = None
255
-
256
- # Model W: olmOCR-7B-0725
257
- MODEL_ID_W = "allenai/olmOCR-7B-0725"
258
- processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
259
- model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
260
- MODEL_ID_W,
261
- attn_implementation="flash_attention_2",
262
- trust_remote_code=True,
263
- torch_dtype=torch.float16
264
- ).to(device).eval()
265
-
266
- # Model M: RolmOCR
267
- MODEL_ID_M = "reducto/RolmOCR"
268
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
269
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
270
- MODEL_ID_M,
271
- attn_implementation="flash_attention_2",
272
- trust_remote_code=True,
273
- torch_dtype=torch.float16
274
- ).to(device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  def calc_timeout_duration(model_name: str, text: str, image: Image.Image,
277
  max_new_tokens: int, temperature: float, top_p: float,
@@ -291,24 +336,28 @@ def generate_image(model_name: str, text: str, image: Image.Image,
291
  Generates responses using the selected model for image input.
292
  Yields raw text and Markdown-formatted text.
293
  """
294
- if model_name == "RolmOCR-7B":
295
- processor = processor_m
296
- model = model_m
297
- elif model_name == "Qwen2-VL-OCR-2B":
298
- processor = processor_x
299
- model = model_x
300
- elif model_name == "Nanonets-OCR2-3B":
301
  processor = processor_v
302
  model = model_v
303
- elif model_name == "PaddleOCR-VL":
304
- if not PADDLE_AVAILABLE:
305
- yield "PaddleOCR-VL model is not available.", "PaddleOCR-VL model is not available."
 
 
 
 
 
 
 
 
 
 
 
 
306
  return
307
- processor = processor_p
308
- model = model_p
309
- elif model_name == "olmOCR-7B-0725":
310
- processor = processor_w
311
- model = model_w
312
  else:
313
  yield "Invalid model selected.", "Invalid model selected."
314
  return
@@ -317,6 +366,10 @@ def generate_image(model_name: str, text: str, image: Image.Image,
317
  yield "Please upload an image.", "Please upload an image."
318
  return
319
 
 
 
 
 
320
  messages = [{
321
  "role": "user",
322
  "content": [
@@ -324,7 +377,13 @@ def generate_image(model_name: str, text: str, image: Image.Image,
324
  {"type": "text", "text": text},
325
  ]
326
  }]
327
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
 
328
 
329
  inputs = processor(
330
  text=[prompt_full],
@@ -354,23 +413,33 @@ def generate_image(model_name: str, text: str, image: Image.Image,
354
 
355
 
356
  image_examples = [
357
- ["Perform OCR on the image precisely.", "examples/5.jpg"],
358
- ["Run OCR on the image and ensure high accuracy.", "examples/4.jpg"],
359
- ["Conduct OCR on the image with exact text recognition.", "examples/2.jpg"],
360
- ["Perform precise OCR extraction on the image.", "examples/1.jpg"],
361
- ["Convert this page to docling", "examples/3.jpg"],
362
  ]
363
 
364
  # Build model choices dynamically
365
- model_choices = ["Nanonets-OCR2-3B", "olmOCR-7B-0725", "RolmOCR-7B", "Qwen2-VL-OCR-2B"]
366
- if PADDLE_AVAILABLE:
367
- model_choices.append("PaddleOCR-VL")
 
 
 
 
368
 
369
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
370
- gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
 
 
371
  with gr.Row():
372
  with gr.Column(scale=2):
373
- image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
 
 
 
 
374
  image_upload = gr.Image(type="pil", label="Upload Image", height=290)
375
 
376
  image_submit = gr.Button("Submit", variant="primary")
@@ -395,7 +464,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
395
  model_choice = gr.Radio(
396
  choices=model_choices,
397
  label="Select Model",
398
- value="Nanonets-OCR2-3B"
399
  )
400
 
401
  with gr.Row(elem_id="gpu-duration-container"):
@@ -409,6 +478,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
409
  gpu_duration_state = gr.Number(value=60, visible=False)
410
 
411
  gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
 
412
 
413
  radioanimated_gpu_duration.change(
414
  fn=apply_gpu_duration,
 
14
  import cv2
15
 
16
  from transformers import (
 
17
  Qwen2_5_VLForConditionalGeneration,
18
  AutoProcessor,
19
  TextIteratorStreamer,
20
  )
21
+
22
+ # Try importing Qwen3VL if available
23
+ try:
24
+ from transformers import Qwen3VLForConditionalGeneration
25
+ QWEN3_AVAILABLE = True
26
+ except:
27
+ QWEN3_AVAILABLE = False
28
+ print("⚠️ Qwen3VL not available in current transformers version")
29
+
30
  from transformers.image_utils import load_image
31
  from gradio.themes import Soft
32
  from gradio.themes.utils import colors, fonts, sizes
 
156
 
157
  print("Using device:", device)
158
 
159
+ # Multilingual OCR prompt template
160
+ MULTILINGUAL_OCR_PROMPT = """Perform comprehensive OCR extraction on this document. Follow these rules:
161
+
162
+ 1. Extract ALL text exactly as it appears in the original language
163
+ 2. If the text is NOT in English, provide an English translation after the original text
164
+ 3. Identify the document type and extract key fields
165
+ 4. Preserve formatting and layout structure
166
+
167
+ Format your response as:
168
+
169
+ **Original Text:** (in source language)
170
+ [extracted text]
171
+
172
+ **English Translation:** (if not already in English)
173
+ [translated text]
174
+
175
+ **Key Fields Extracted:**
176
+ - Document type:
177
+ - [other relevant fields based on document type]
178
+
179
+ Be accurate and preserve all details."""
180
+
181
  class RadioAnimated(gr.HTML):
182
  def __init__(self, choices, value=None, **kwargs):
183
  if not choices or len(choices) < 2:
 
245
  def apply_gpu_duration(val: str):
246
  return int(val)
247
 
248
+ # Model V: Nanonets-OCR2-3B (Kept)
249
  MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
250
  processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
251
  model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
254
  trust_remote_code=True,
255
  torch_dtype=torch.float16
256
  ).to(device).eval()
257
+ print("βœ“ Nanonets-OCR2-3B loaded")
258
 
259
+ # Model C1: Chhagan_ML-VL-OCR-v1 (NEW)
260
+ MODEL_ID_C1 = "Chhagan005/Chhagan_ML-VL-OCR-v1"
 
 
 
 
 
 
 
 
 
 
261
  try:
262
+ processor_c1 = AutoProcessor.from_pretrained(MODEL_ID_C1, trust_remote_code=True)
263
+ model_c1 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
264
+ MODEL_ID_C1,
265
  attn_implementation="flash_attention_2",
266
  trust_remote_code=True,
267
  torch_dtype=torch.float16
268
  ).to(device).eval()
269
+ C1_AVAILABLE = True
270
+ print("βœ“ Chhagan_ML-VL-OCR-v1 loaded")
271
  except Exception as e:
272
+ print(f"βœ— Chhagan_ML-VL-OCR-v1 failed: {e}")
273
+ C1_AVAILABLE = False
274
+ processor_c1 = None
275
+ model_c1 = None
276
+
277
+ # Model C2: Chhagan-DocVL-Qwen3 (NEW)
278
+ MODEL_ID_C2 = "Chhagan005/Chhagan-DocVL-Qwen3"
279
+ C2_AVAILABLE = False
280
+ if QWEN3_AVAILABLE:
281
+ try:
282
+ processor_c2 = AutoProcessor.from_pretrained(MODEL_ID_C2, trust_remote_code=True)
283
+ model_c2 = Qwen3VLForConditionalGeneration.from_pretrained(
284
+ MODEL_ID_C2,
285
+ attn_implementation="flash_attention_2",
286
+ trust_remote_code=True,
287
+ torch_dtype=torch.float16
288
+ ).to(device).eval()
289
+ C2_AVAILABLE = True
290
+ print("βœ“ Chhagan-DocVL-Qwen3 loaded")
291
+ except Exception as e:
292
+ print(f"βœ— Chhagan-DocVL-Qwen3 failed: {e}")
293
+ processor_c2 = None
294
+ model_c2 = None
295
+ else:
296
+ processor_c2 = None
297
+ model_c2 = None
298
+
299
+ # Model Q3: Qwen3-VL-2B-Instruct (NEW - Official)
300
+ MODEL_ID_Q3 = "Qwen/Qwen3-VL-2B-Instruct"
301
+ Q3_AVAILABLE = False
302
+ if QWEN3_AVAILABLE:
303
+ try:
304
+ processor_q3 = AutoProcessor.from_pretrained(MODEL_ID_Q3, trust_remote_code=True)
305
+ model_q3 = Qwen3VLForConditionalGeneration.from_pretrained(
306
+ MODEL_ID_Q3,
307
+ attn_implementation="flash_attention_2",
308
+ trust_remote_code=True,
309
+ torch_dtype=torch.float16
310
+ ).to(device).eval()
311
+ Q3_AVAILABLE = True
312
+ print("βœ“ Qwen3-VL-2B-Instruct loaded")
313
+ except Exception as e:
314
+ print(f"βœ— Qwen3-VL-2B-Instruct failed: {e}")
315
+ processor_q3 = None
316
+ model_q3 = None
317
+ else:
318
+ processor_q3 = None
319
+ model_q3 = None
320
 
321
  def calc_timeout_duration(model_name: str, text: str, image: Image.Image,
322
  max_new_tokens: int, temperature: float, top_p: float,
 
336
  Generates responses using the selected model for image input.
337
  Yields raw text and Markdown-formatted text.
338
  """
339
+ # Select model and processor
340
+ if model_name == "Nanonets-OCR2-3B":
 
 
 
 
 
341
  processor = processor_v
342
  model = model_v
343
+ elif model_name == "Chhagan-ML-VL-OCR-v1":
344
+ if not C1_AVAILABLE:
345
+ yield "Chhagan-ML-VL-OCR-v1 model is not available.", "Chhagan-ML-VL-OCR-v1 model is not available."
346
+ return
347
+ processor = processor_c1
348
+ model = model_c1
349
+ elif model_name == "Chhagan-DocVL-Qwen3":
350
+ if not C2_AVAILABLE:
351
+ yield "Chhagan-DocVL-Qwen3 model is not available. Requires transformers>=4.57", "Chhagan-DocVL-Qwen3 model is not available."
352
+ return
353
+ processor = processor_c2
354
+ model = model_c2
355
+ elif model_name == "Qwen3-VL-2B-Instruct":
356
+ if not Q3_AVAILABLE:
357
+ yield "Qwen3-VL-2B-Instruct model is not available. Requires transformers>=4.57", "Qwen3-VL-2B-Instruct model is not available."
358
  return
359
+ processor = processor_q3
360
+ model = model_q3
 
 
 
361
  else:
362
  yield "Invalid model selected.", "Invalid model selected."
363
  return
 
366
  yield "Please upload an image.", "Please upload an image."
367
  return
368
 
369
+ # Use multilingual prompt if user query is empty or simple
370
+ if not text or text.strip().lower() in ["ocr", "extract", "read"]:
371
+ text = MULTILINGUAL_OCR_PROMPT
372
+
373
  messages = [{
374
  "role": "user",
375
  "content": [
 
377
  {"type": "text", "text": text},
378
  ]
379
  }]
380
+
381
+ try:
382
+ prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
383
+ except Exception as e:
384
+ print(f"Chat template error: {e}")
385
+ # Fallback to simple prompt
386
+ prompt_full = text
387
 
388
  inputs = processor(
389
  text=[prompt_full],
 
413
 
414
 
415
  image_examples = [
416
+ ["Perform comprehensive multilingual OCR with English translation", "examples/5.jpg"],
417
+ ["Extract all text in original language and translate to English", "examples/4.jpg"],
418
+ ["Perform OCR and provide structured key fields extraction", "examples/2.jpg"],
419
+ ["Extract document details with original text and English translation", "examples/1.jpg"],
420
+ ["Convert this page with multilingual support", "examples/3.jpg"],
421
  ]
422
 
423
  # Build model choices dynamically
424
+ model_choices = ["Nanonets-OCR2-3B"]
425
+ if C1_AVAILABLE:
426
+ model_choices.append("Chhagan-ML-VL-OCR-v1")
427
+ if C2_AVAILABLE:
428
+ model_choices.append("Chhagan-DocVL-Qwen3")
429
+ if Q3_AVAILABLE:
430
+ model_choices.append("Qwen3-VL-2B-Instruct")
431
 
432
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
433
+ gr.Markdown("# **Multimodal Multilingual OCR**", elem_id="main-title")
434
+ gr.Markdown("*Supports multilingual text extraction with automatic English translation*")
435
+
436
  with gr.Row():
437
  with gr.Column(scale=2):
438
+ image_query = gr.Textbox(
439
+ label="Query Input",
440
+ placeholder="Leave empty for automatic multilingual extraction with translation...",
441
+ value=""
442
+ )
443
  image_upload = gr.Image(type="pil", label="Upload Image", height=290)
444
 
445
  image_submit = gr.Button("Submit", variant="primary")
 
464
  model_choice = gr.Radio(
465
  choices=model_choices,
466
  label="Select Model",
467
+ value=model_choices[0]
468
  )
469
 
470
  with gr.Row(elem_id="gpu-duration-container"):
 
478
  gpu_duration_state = gr.Number(value=60, visible=False)
479
 
480
  gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
481
+ gr.Markdown(f"**Models loaded:** {', '.join(model_choices)}")
482
 
483
  radioanimated_gpu_duration.change(
484
  fn=apply_gpu_duration,