Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,11 +14,19 @@ from PIL import Image
|
|
| 14 |
import cv2
|
| 15 |
|
| 16 |
from transformers import (
|
| 17 |
-
Qwen2VLForConditionalGeneration,
|
| 18 |
Qwen2_5_VLForConditionalGeneration,
|
| 19 |
AutoProcessor,
|
| 20 |
TextIteratorStreamer,
|
| 21 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from transformers.image_utils import load_image
|
| 23 |
from gradio.themes import Soft
|
| 24 |
from gradio.themes.utils import colors, fonts, sizes
|
|
@@ -148,6 +156,28 @@ if torch.cuda.is_available():
|
|
| 148 |
|
| 149 |
print("Using device:", device)
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
class RadioAnimated(gr.HTML):
|
| 152 |
def __init__(self, choices, value=None, **kwargs):
|
| 153 |
if not choices or len(choices) < 2:
|
|
@@ -215,7 +245,7 @@ class RadioAnimated(gr.HTML):
|
|
| 215 |
def apply_gpu_duration(val: str):
|
| 216 |
return int(val)
|
| 217 |
|
| 218 |
-
# Model V: Nanonets-OCR2-3B
|
| 219 |
MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
|
| 220 |
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
|
| 221 |
model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
@@ -224,54 +254,69 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
| 224 |
trust_remote_code=True,
|
| 225 |
torch_dtype=torch.float16
|
| 226 |
).to(device).eval()
|
|
|
|
| 227 |
|
| 228 |
-
# Model
|
| 229 |
-
|
| 230 |
-
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
|
| 231 |
-
model_x = Qwen2VLForConditionalGeneration.from_pretrained(
|
| 232 |
-
MODEL_ID_X,
|
| 233 |
-
attn_implementation="flash_attention_2",
|
| 234 |
-
trust_remote_code=True,
|
| 235 |
-
torch_dtype=torch.float16
|
| 236 |
-
).to(device).eval()
|
| 237 |
-
|
| 238 |
-
# Model P: PaddleOCR-VL (NEW - More stable than Qwen3)
|
| 239 |
-
MODEL_ID_P = "PaddlePaddle/PaddleOCR-VL"
|
| 240 |
try:
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
attn_implementation="flash_attention_2",
|
| 245 |
trust_remote_code=True,
|
| 246 |
torch_dtype=torch.float16
|
| 247 |
).to(device).eval()
|
| 248 |
-
|
| 249 |
-
print("β
|
| 250 |
except Exception as e:
|
| 251 |
-
print(f"β
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
# Model
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
def calc_timeout_duration(model_name: str, text: str, image: Image.Image,
|
| 277 |
max_new_tokens: int, temperature: float, top_p: float,
|
|
@@ -291,24 +336,28 @@ def generate_image(model_name: str, text: str, image: Image.Image,
|
|
| 291 |
Generates responses using the selected model for image input.
|
| 292 |
Yields raw text and Markdown-formatted text.
|
| 293 |
"""
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
model = model_m
|
| 297 |
-
elif model_name == "Qwen2-VL-OCR-2B":
|
| 298 |
-
processor = processor_x
|
| 299 |
-
model = model_x
|
| 300 |
-
elif model_name == "Nanonets-OCR2-3B":
|
| 301 |
processor = processor_v
|
| 302 |
model = model_v
|
| 303 |
-
elif model_name == "
|
| 304 |
-
if not
|
| 305 |
-
yield "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
return
|
| 307 |
-
processor =
|
| 308 |
-
model =
|
| 309 |
-
elif model_name == "olmOCR-7B-0725":
|
| 310 |
-
processor = processor_w
|
| 311 |
-
model = model_w
|
| 312 |
else:
|
| 313 |
yield "Invalid model selected.", "Invalid model selected."
|
| 314 |
return
|
|
@@ -317,6 +366,10 @@ def generate_image(model_name: str, text: str, image: Image.Image,
|
|
| 317 |
yield "Please upload an image.", "Please upload an image."
|
| 318 |
return
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
messages = [{
|
| 321 |
"role": "user",
|
| 322 |
"content": [
|
|
@@ -324,7 +377,13 @@ def generate_image(model_name: str, text: str, image: Image.Image,
|
|
| 324 |
{"type": "text", "text": text},
|
| 325 |
]
|
| 326 |
}]
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
inputs = processor(
|
| 330 |
text=[prompt_full],
|
|
@@ -354,23 +413,33 @@ def generate_image(model_name: str, text: str, image: Image.Image,
|
|
| 354 |
|
| 355 |
|
| 356 |
image_examples = [
|
| 357 |
-
["Perform
|
| 358 |
-
["
|
| 359 |
-
["
|
| 360 |
-
["
|
| 361 |
-
["Convert this page
|
| 362 |
]
|
| 363 |
|
| 364 |
# Build model choices dynamically
|
| 365 |
-
model_choices = ["Nanonets-OCR2-3B"
|
| 366 |
-
if
|
| 367 |
-
model_choices.append("
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
|
| 370 |
-
gr.Markdown("# **Multimodal OCR**", elem_id="main-title")
|
|
|
|
|
|
|
| 371 |
with gr.Row():
|
| 372 |
with gr.Column(scale=2):
|
| 373 |
-
image_query = gr.Textbox(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
image_upload = gr.Image(type="pil", label="Upload Image", height=290)
|
| 375 |
|
| 376 |
image_submit = gr.Button("Submit", variant="primary")
|
|
@@ -395,7 +464,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
|
|
| 395 |
model_choice = gr.Radio(
|
| 396 |
choices=model_choices,
|
| 397 |
label="Select Model",
|
| 398 |
-
value=
|
| 399 |
)
|
| 400 |
|
| 401 |
with gr.Row(elem_id="gpu-duration-container"):
|
|
@@ -409,6 +478,7 @@ with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
|
|
| 409 |
gpu_duration_state = gr.Number(value=60, visible=False)
|
| 410 |
|
| 411 |
gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
|
|
|
|
| 412 |
|
| 413 |
radioanimated_gpu_duration.change(
|
| 414 |
fn=apply_gpu_duration,
|
|
|
|
| 14 |
import cv2
|
| 15 |
|
| 16 |
from transformers import (
|
|
|
|
| 17 |
Qwen2_5_VLForConditionalGeneration,
|
| 18 |
AutoProcessor,
|
| 19 |
TextIteratorStreamer,
|
| 20 |
)
|
| 21 |
+
|
| 22 |
+
# Try importing Qwen3VL if available
|
| 23 |
+
try:
|
| 24 |
+
from transformers import Qwen3VLForConditionalGeneration
|
| 25 |
+
QWEN3_AVAILABLE = True
|
| 26 |
+
except:
|
| 27 |
+
QWEN3_AVAILABLE = False
|
| 28 |
+
print("β οΈ Qwen3VL not available in current transformers version")
|
| 29 |
+
|
| 30 |
from transformers.image_utils import load_image
|
| 31 |
from gradio.themes import Soft
|
| 32 |
from gradio.themes.utils import colors, fonts, sizes
|
|
|
|
| 156 |
|
| 157 |
print("Using device:", device)
|
| 158 |
|
| 159 |
+
# Multilingual OCR prompt template
|
| 160 |
+
MULTILINGUAL_OCR_PROMPT = """Perform comprehensive OCR extraction on this document. Follow these rules:
|
| 161 |
+
|
| 162 |
+
1. Extract ALL text exactly as it appears in the original language
|
| 163 |
+
2. If the text is NOT in English, provide an English translation after the original text
|
| 164 |
+
3. Identify the document type and extract key fields
|
| 165 |
+
4. Preserve formatting and layout structure
|
| 166 |
+
|
| 167 |
+
Format your response as:
|
| 168 |
+
|
| 169 |
+
**Original Text:** (in source language)
|
| 170 |
+
[extracted text]
|
| 171 |
+
|
| 172 |
+
**English Translation:** (if not already in English)
|
| 173 |
+
[translated text]
|
| 174 |
+
|
| 175 |
+
**Key Fields Extracted:**
|
| 176 |
+
- Document type:
|
| 177 |
+
- [other relevant fields based on document type]
|
| 178 |
+
|
| 179 |
+
Be accurate and preserve all details."""
|
| 180 |
+
|
| 181 |
class RadioAnimated(gr.HTML):
|
| 182 |
def __init__(self, choices, value=None, **kwargs):
|
| 183 |
if not choices or len(choices) < 2:
|
|
|
|
| 245 |
def apply_gpu_duration(val: str):
|
| 246 |
return int(val)
|
| 247 |
|
| 248 |
+
# Model V: Nanonets-OCR2-3B (Kept)
|
| 249 |
MODEL_ID_V = "nanonets/Nanonets-OCR2-3B"
|
| 250 |
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
|
| 251 |
model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
|
|
| 254 |
trust_remote_code=True,
|
| 255 |
torch_dtype=torch.float16
|
| 256 |
).to(device).eval()
|
| 257 |
+
print("β Nanonets-OCR2-3B loaded")
|
| 258 |
|
| 259 |
+
# Model C1: Chhagan_ML-VL-OCR-v1 (NEW)
|
| 260 |
+
MODEL_ID_C1 = "Chhagan005/Chhagan_ML-VL-OCR-v1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
try:
|
| 262 |
+
processor_c1 = AutoProcessor.from_pretrained(MODEL_ID_C1, trust_remote_code=True)
|
| 263 |
+
model_c1 = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 264 |
+
MODEL_ID_C1,
|
| 265 |
attn_implementation="flash_attention_2",
|
| 266 |
trust_remote_code=True,
|
| 267 |
torch_dtype=torch.float16
|
| 268 |
).to(device).eval()
|
| 269 |
+
C1_AVAILABLE = True
|
| 270 |
+
print("β Chhagan_ML-VL-OCR-v1 loaded")
|
| 271 |
except Exception as e:
|
| 272 |
+
print(f"β Chhagan_ML-VL-OCR-v1 failed: {e}")
|
| 273 |
+
C1_AVAILABLE = False
|
| 274 |
+
processor_c1 = None
|
| 275 |
+
model_c1 = None
|
| 276 |
+
|
| 277 |
+
# Model C2: Chhagan-DocVL-Qwen3 (NEW)
|
| 278 |
+
MODEL_ID_C2 = "Chhagan005/Chhagan-DocVL-Qwen3"
|
| 279 |
+
C2_AVAILABLE = False
|
| 280 |
+
if QWEN3_AVAILABLE:
|
| 281 |
+
try:
|
| 282 |
+
processor_c2 = AutoProcessor.from_pretrained(MODEL_ID_C2, trust_remote_code=True)
|
| 283 |
+
model_c2 = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 284 |
+
MODEL_ID_C2,
|
| 285 |
+
attn_implementation="flash_attention_2",
|
| 286 |
+
trust_remote_code=True,
|
| 287 |
+
torch_dtype=torch.float16
|
| 288 |
+
).to(device).eval()
|
| 289 |
+
C2_AVAILABLE = True
|
| 290 |
+
print("β Chhagan-DocVL-Qwen3 loaded")
|
| 291 |
+
except Exception as e:
|
| 292 |
+
print(f"β Chhagan-DocVL-Qwen3 failed: {e}")
|
| 293 |
+
processor_c2 = None
|
| 294 |
+
model_c2 = None
|
| 295 |
+
else:
|
| 296 |
+
processor_c2 = None
|
| 297 |
+
model_c2 = None
|
| 298 |
+
|
| 299 |
+
# Model Q3: Qwen3-VL-2B-Instruct (NEW - Official)
|
| 300 |
+
MODEL_ID_Q3 = "Qwen/Qwen3-VL-2B-Instruct"
|
| 301 |
+
Q3_AVAILABLE = False
|
| 302 |
+
if QWEN3_AVAILABLE:
|
| 303 |
+
try:
|
| 304 |
+
processor_q3 = AutoProcessor.from_pretrained(MODEL_ID_Q3, trust_remote_code=True)
|
| 305 |
+
model_q3 = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 306 |
+
MODEL_ID_Q3,
|
| 307 |
+
attn_implementation="flash_attention_2",
|
| 308 |
+
trust_remote_code=True,
|
| 309 |
+
torch_dtype=torch.float16
|
| 310 |
+
).to(device).eval()
|
| 311 |
+
Q3_AVAILABLE = True
|
| 312 |
+
print("β Qwen3-VL-2B-Instruct loaded")
|
| 313 |
+
except Exception as e:
|
| 314 |
+
print(f"β Qwen3-VL-2B-Instruct failed: {e}")
|
| 315 |
+
processor_q3 = None
|
| 316 |
+
model_q3 = None
|
| 317 |
+
else:
|
| 318 |
+
processor_q3 = None
|
| 319 |
+
model_q3 = None
|
| 320 |
|
| 321 |
def calc_timeout_duration(model_name: str, text: str, image: Image.Image,
|
| 322 |
max_new_tokens: int, temperature: float, top_p: float,
|
|
|
|
| 336 |
Generates responses using the selected model for image input.
|
| 337 |
Yields raw text and Markdown-formatted text.
|
| 338 |
"""
|
| 339 |
+
# Select model and processor
|
| 340 |
+
if model_name == "Nanonets-OCR2-3B":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
processor = processor_v
|
| 342 |
model = model_v
|
| 343 |
+
elif model_name == "Chhagan-ML-VL-OCR-v1":
|
| 344 |
+
if not C1_AVAILABLE:
|
| 345 |
+
yield "Chhagan-ML-VL-OCR-v1 model is not available.", "Chhagan-ML-VL-OCR-v1 model is not available."
|
| 346 |
+
return
|
| 347 |
+
processor = processor_c1
|
| 348 |
+
model = model_c1
|
| 349 |
+
elif model_name == "Chhagan-DocVL-Qwen3":
|
| 350 |
+
if not C2_AVAILABLE:
|
| 351 |
+
yield "Chhagan-DocVL-Qwen3 model is not available. Requires transformers>=4.57", "Chhagan-DocVL-Qwen3 model is not available."
|
| 352 |
+
return
|
| 353 |
+
processor = processor_c2
|
| 354 |
+
model = model_c2
|
| 355 |
+
elif model_name == "Qwen3-VL-2B-Instruct":
|
| 356 |
+
if not Q3_AVAILABLE:
|
| 357 |
+
yield "Qwen3-VL-2B-Instruct model is not available. Requires transformers>=4.57", "Qwen3-VL-2B-Instruct model is not available."
|
| 358 |
return
|
| 359 |
+
processor = processor_q3
|
| 360 |
+
model = model_q3
|
|
|
|
|
|
|
|
|
|
| 361 |
else:
|
| 362 |
yield "Invalid model selected.", "Invalid model selected."
|
| 363 |
return
|
|
|
|
| 366 |
yield "Please upload an image.", "Please upload an image."
|
| 367 |
return
|
| 368 |
|
| 369 |
+
# Use multilingual prompt if user query is empty or simple
|
| 370 |
+
if not text or text.strip().lower() in ["ocr", "extract", "read"]:
|
| 371 |
+
text = MULTILINGUAL_OCR_PROMPT
|
| 372 |
+
|
| 373 |
messages = [{
|
| 374 |
"role": "user",
|
| 375 |
"content": [
|
|
|
|
| 377 |
{"type": "text", "text": text},
|
| 378 |
]
|
| 379 |
}]
|
| 380 |
+
|
| 381 |
+
try:
|
| 382 |
+
prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 383 |
+
except Exception as e:
|
| 384 |
+
print(f"Chat template error: {e}")
|
| 385 |
+
# Fallback to simple prompt
|
| 386 |
+
prompt_full = text
|
| 387 |
|
| 388 |
inputs = processor(
|
| 389 |
text=[prompt_full],
|
|
|
|
| 413 |
|
| 414 |
|
| 415 |
image_examples = [
|
| 416 |
+
["Perform comprehensive multilingual OCR with English translation", "examples/5.jpg"],
|
| 417 |
+
["Extract all text in original language and translate to English", "examples/4.jpg"],
|
| 418 |
+
["Perform OCR and provide structured key fields extraction", "examples/2.jpg"],
|
| 419 |
+
["Extract document details with original text and English translation", "examples/1.jpg"],
|
| 420 |
+
["Convert this page with multilingual support", "examples/3.jpg"],
|
| 421 |
]
|
| 422 |
|
| 423 |
# Build model choices dynamically
|
| 424 |
+
model_choices = ["Nanonets-OCR2-3B"]
|
| 425 |
+
if C1_AVAILABLE:
|
| 426 |
+
model_choices.append("Chhagan-ML-VL-OCR-v1")
|
| 427 |
+
if C2_AVAILABLE:
|
| 428 |
+
model_choices.append("Chhagan-DocVL-Qwen3")
|
| 429 |
+
if Q3_AVAILABLE:
|
| 430 |
+
model_choices.append("Qwen3-VL-2B-Instruct")
|
| 431 |
|
| 432 |
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
|
| 433 |
+
gr.Markdown("# **Multimodal Multilingual OCR**", elem_id="main-title")
|
| 434 |
+
gr.Markdown("*Supports multilingual text extraction with automatic English translation*")
|
| 435 |
+
|
| 436 |
with gr.Row():
|
| 437 |
with gr.Column(scale=2):
|
| 438 |
+
image_query = gr.Textbox(
|
| 439 |
+
label="Query Input",
|
| 440 |
+
placeholder="Leave empty for automatic multilingual extraction with translation...",
|
| 441 |
+
value=""
|
| 442 |
+
)
|
| 443 |
image_upload = gr.Image(type="pil", label="Upload Image", height=290)
|
| 444 |
|
| 445 |
image_submit = gr.Button("Submit", variant="primary")
|
|
|
|
| 464 |
model_choice = gr.Radio(
|
| 465 |
choices=model_choices,
|
| 466 |
label="Select Model",
|
| 467 |
+
value=model_choices[0]
|
| 468 |
)
|
| 469 |
|
| 470 |
with gr.Row(elem_id="gpu-duration-container"):
|
|
|
|
| 478 |
gpu_duration_state = gr.Number(value=60, visible=False)
|
| 479 |
|
| 480 |
gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
|
| 481 |
+
gr.Markdown(f"**Models loaded:** {', '.join(model_choices)}")
|
| 482 |
|
| 483 |
radioanimated_gpu_duration.change(
|
| 484 |
fn=apply_gpu_duration,
|