# Multi-Model-OCR / app.py
# NOTE(review): the lines below were Hugging Face web-page residue captured by
# the scrape and are preserved here as a comment so the file parses as Python:
# "IFMedTechdemo's picture", "Update app.py", "acb34dc verified",
# "raw", "history blame", "11.4 kB"
# CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces
import os
# Now import other packages
import gradio as gr
import torch
from PIL import Image
from transformers import (
AutoProcessor,
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
Qwen2VLForConditionalGeneration, # Changed from Qwen3VL
Qwen2_5_VLForConditionalGeneration,
TextIteratorStreamer
)
from threading import Thread
import time
# Device setup: prefer GPU when available (ZeroGPU allocates it per-call).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Chandra-OCR (uses the Qwen2.5-VL architecture, hence this model class).
MODEL_ID_V = "datalab-to/chandra"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="sdpa"  # scaled-dot-product attention; avoids a flash-attn dependency
).to(device).eval()

# Load Nanonets-OCR2-3B (also Qwen2.5-VL based).
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="sdpa"
).to(device).eval()

# Load Dots.OCR via its remote-code model class. device_map="auto" lets
# accelerate place the weights, so there is no explicit .to(device) here.
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
).eval()

# Load olmOCR-2-7B-1025-FP8 (quantized checkpoint). The processor is taken
# from the base Qwen2.5-VL-7B-Instruct repo instead of the quantized repo —
# presumably because the FP8 repo ships no processor config; TODO confirm.
MODEL_ID_M = "allenai/olmOCR-2-7B-1025-FP8"
processor_m = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa"
).to(device).eval()

# Load DeepSeek-OCR. It exposes a custom `infer` API (used in generate_image)
# rather than the standard generate/streamer flow, and is cast to bfloat16
# after being moved to the device.
MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
model_ds = AutoModel.from_pretrained(
    MODEL_ID_DS,
    attn_implementation="sdpa",
    trust_remote_code=True,
    use_safetensors=True
).eval().to(device).to(torch.bfloat16)
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, resolution_mode: str):
    """
    Run OCR on *image* with the selected model and stream the result.

    Yields ``(raw_text, markdown_text)`` tuples (the same string twice) so
    Gradio can update both the plain-text and Markdown outputs as tokens
    arrive. Parameters mirror the Gradio controls; *resolution_mode* only
    affects DeepSeek-OCR, and the sampling knobs only affect the HF models.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    # ---- DeepSeek-OCR: custom `infer` API, file-based input, no streaming ----
    if model_name == "DeepSeek-OCR":
        resolution_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        # Fall back to "Gundam" instead of raising KeyError if an unexpected
        # mode value ever reaches this function.
        config = resolution_configs.get(resolution_mode, resolution_configs["Gundam"])

        # Unique temp file per request: the previous fixed
        # "/tmp/temp_ocr_image.jpg" path raced under concurrent queue requests.
        import tempfile
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg", dir="/tmp")
        os.close(fd)

        # DeepSeek-OCR expects the "<image>\n" prefix in its prompt; default
        # to its generic OCR instruction when the prompt box is empty.
        if not text:
            text = "Free OCR."
        prompt_ds = f"<image>\n{text}"
        try:
            # JPEG cannot store an alpha channel — convert so RGBA/P uploads
            # (e.g. transparent PNGs) don't crash the save.
            image.convert("RGB").save(temp_image_path)
            # DeepSeek-OCR's custom infer method (provided by trust_remote_code).
            result = model_ds.infer(
                tokenizer_ds,
                prompt=prompt_ds,
                image_file=temp_image_path,
                output_path="/tmp",
                base_size=config["base_size"],
                image_size=config["image_size"],
                crop_mode=config["crop_mode"],
                test_compress=True,
                save_results=False
            )
            yield result, result
        except Exception as e:
            yield f"Error: {str(e)}", f"Error: {str(e)}"
        finally:
            # Always clean up the temp file, even on failure.
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
        return

    # ---- Standard HF models: pick the processor/model pair ----
    model_map = {
        "olmOCR-2-7B-1025-FP8": (processor_m, model_m),
        "Nanonets-OCR2-3B": (processor_x, model_x),
        "Chandra-OCR": (processor_v, model_v),
        "Dots.OCR": (processor_d, model_d),
    }
    if model_name not in model_map:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = model_map[model_name]

    # Single-turn chat message with one image placeholder plus the user prompt.
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ]
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True
    ).to(device)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )

    # transformers raises for do_sample=True with temperature == 0 (the
    # slider's minimum); fall back to greedy decoding in that case and only
    # pass the sampling knobs when sampling is actually enabled.
    do_sample = temperature > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
    }
    if do_sample:
        generation_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # Run generation in a background thread; the streamer yields decoded text
    # chunks on this thread as they are produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Belt-and-braces: strip the end marker in case a tokenizer leaks it.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)  # small pause so the UI update rate stays reasonable
        yield buffer, buffer
    thread.join()  # ensure generation has fully finished before returning
# Prompt/image pairs rendered as clickable examples in the UI
# (paths are relative to the Space repo's examples/ directory).
image_examples = [
    ["OCR the content perfectly.", "examples/3.jpg"],
    ["Perform OCR on the image.", "examples/1.jpg"],
    ["Extract the contents. [page].", "examples/2.jpg"],
]

# CSS injected into gr.Blocks: cap the app width and size the model dropdown.
css = """
.gradio-container {
max-width: 1400px;
margin: auto;
}
.model-selector {
font-size: 16px;
}
"""
# Build the Gradio interface: model picker and image input on the left,
# streamed text/Markdown output on the right.
with gr.Blocks(css=css, title="Multi-Model OCR Space") as demo:
    gr.Markdown(
        """
# 🔍 Multi-Model OCR Comparison Space
Compare five state-of-the-art OCR models on your images:
- **Chandra-OCR**: Specialized OCR model
- **Nanonets-OCR2-3B**: High-accuracy OCR
- **Dots.OCR**: Lightweight OCR solution
- **olmOCR-2-7B-1025-FP8**: Advanced FP8 quantized OCR model
- **DeepSeek-OCR**: Context compression OCR with 10× compression ratio (97% accuracy)
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            # The selected name drives which branch generate_image takes.
            model_selector = gr.Dropdown(
                choices=[
                    "Chandra-OCR",
                    "Nanonets-OCR2-3B",
                    "Dots.OCR",
                    "olmOCR-2-7B-1025-FP8",
                    "DeepSeek-OCR"
                ],
                value="DeepSeek-OCR",
                label="Select OCR Model",
                elem_classes=["model-selector"]
            )
            # Only meaningful for DeepSeek-OCR; toggled by the change handler below.
            resolution_selector = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam"],
                value="Gundam",
                label="DeepSeek-OCR Resolution Mode",
                info="Only applies to DeepSeek-OCR. Gundam mode recommended for best results.",
                visible=True
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                value="Perform OCR on this image.",
                label="Prompt",
                lines=2
            )
            # Sampling controls forwarded to model.generate (HF models only;
            # DeepSeek-OCR ignores them).
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=256,
                    maximum=8192,
                    value=2048,
                    step=256,
                    label="Max New Tokens"
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty"
                )
            submit_btn = gr.Button("🚀 Extract Text", variant="primary")
            clear_btn = gr.ClearButton()
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Extracted Text",
                lines=20,
                show_copy_button=True
            )
            output_markdown = gr.Markdown(label="Formatted Output")
    gr.Examples(
        examples=image_examples,
        inputs=[text_input, image_input],
        label="Example Images"
    )

    # Show/hide resolution selector based on model
    def update_resolution_visibility(model_name):
        # Resolution modes only apply to DeepSeek-OCR's custom infer API.
        return gr.update(visible=(model_name == "DeepSeek-OCR"))

    model_selector.change(
        fn=update_resolution_visibility,
        inputs=[model_selector],
        outputs=[resolution_selector]
    )

    # Event handlers: generate_image is a generator, so both outputs stream.
    submit_btn.click(
        fn=generate_image,
        inputs=[
            model_selector,
            text_input,
            image_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            resolution_selector
        ],
        outputs=[output_text, output_markdown]
    )
    # Register the components the ClearButton resets.
    clear_btn.add([image_input, text_input, output_text, output_markdown])
    gr.Markdown(
        """
### Model Information:
**DeepSeek-OCR Modes:**
- **Tiny**: 64 tokens @ 512×512 (fastest, basic documents)
- **Small**: 100 tokens @ 640×640 (good for simple pages)
- **Base**: 256 tokens @ 1024×1024 (standard documents)
- **Large**: 400 tokens @ 1280×1280 (complex layouts)
- **Gundam**: Dynamic multi-view (recommended for best accuracy)
### Tips:
- Upload clear images for best results
- DeepSeek-OCR excels at table extraction and markdown conversion
- Adjust temperature for more creative or conservative outputs
- Try different models to compare performance on your specific use case
"""
    )
if __name__ == "__main__":
    # Enable the request queue (required for streaming generators under load),
    # then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()