"""Gradio Space that runs one of five OCR models on an uploaded image.

All models are loaded once at import time; ``generate_image`` is the single
request handler wired to the "Extract Text" button.
"""

# CRITICAL: import `spaces` FIRST, before any CUDA-related package.
import spaces

import os
import tempfile
import time
from threading import Thread

# Now import other packages
import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load Chandra-OCR (loaded with the Qwen2.5-VL model class)
MODEL_ID_V = "datalab-to/chandra"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to(device).eval()

# Load Nanonets-OCR2-3B
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_X,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to(device).eval()

# Load Dots.OCR (placed by `device_map="auto"`, so no explicit `.to(device)`)
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
).eval()

# Load olmOCR-2-7B-1025-FP8 (quantized version); the processor comes from the
# Qwen2.5-VL-7B-Instruct repository.
MODEL_ID_M = "allenai/olmOCR-2-7B-1025-FP8"
processor_m = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True
)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to(device).eval()

# Load DeepSeek-OCR
MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
model_ds = AutoModel.from_pretrained(
    MODEL_ID_DS,
    attn_implementation="sdpa",
    trust_remote_code=True,
    use_safetensors=True,
).eval().to(device).to(torch.bfloat16)

# Models that share the processor + `generate` API, keyed by their UI name.
STANDARD_MODELS = {
    "olmOCR-2-7B-1025-FP8": (processor_m, model_m),
    "Nanonets-OCR2-3B": (processor_x, model_x),
    "Chandra-OCR": (processor_v, model_v),
    "Dots.OCR": (processor_d, model_d),
}

# DeepSeek-OCR resolution presets, keyed by the UI "resolution mode" name.
DEEPSEEK_RESOLUTION_CONFIGS = {
    "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
    "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
    "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
    "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
    "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
}

# Placeholder that marks where the image goes in a DeepSeek-OCR prompt.
DEEPSEEK_IMAGE_TAG = "<image>"


def _run_deepseek_ocr(text: str, image: Image.Image, resolution_mode: str) -> str:
    """Run DeepSeek-OCR on *image* and return the recognised text.

    The model's custom ``infer`` method takes a file path, so the image is
    written to a per-request temporary file that is removed afterwards.

    Raises:
        ValueError: if *resolution_mode* is not a known preset.
    """
    config = DEEPSEEK_RESOLUTION_CONFIGS.get(resolution_mode)
    if config is None:
        raise ValueError(f"Unknown resolution mode: {resolution_mode!r}")

    prompt_text = text or "Free OCR."
    # The prompt must carry the image placeholder; add it unless the user
    # already supplied one.
    if DEEPSEEK_IMAGE_TAG in prompt_text:
        prompt_ds = prompt_text
    else:
        prompt_ds = f"{DEEPSEEK_IMAGE_TAG}\n{prompt_text}"

    # A unique file per request avoids concurrent requests overwriting or
    # deleting each other's image.
    fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        # JPEG cannot store alpha/palette modes, so normalise to RGB first.
        image.convert("RGB").save(temp_image_path)
        result = model_ds.infer(
            tokenizer_ds,
            prompt=prompt_ds,
            image_file=temp_image_path,
            output_path=tempfile.gettempdir(),
            base_size=config["base_size"],
            image_size=config["image_size"],
            crop_mode=config["crop_mode"],
            test_compress=True,
            save_results=False,
            # NOTE(review): `infer` is expected to return the decoded text only
            # in eval mode -- confirm against the model's remote code.
            eval_mode=True,
        )
    finally:
        try:
            os.remove(temp_image_path)
        except FileNotFoundError:
            pass

    if result is None:
        raise RuntimeError("DeepSeek-OCR returned no text.")
    return result


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, resolution_mode: str):
    """Generate an OCR response for *image* with the selected model.

    Args:
        model_name: UI name of the model ("DeepSeek-OCR" or a key of
            ``STANDARD_MODELS``).
        text: Prompt text sent along with the image.
        image: The uploaded image, or ``None`` if nothing was uploaded.
        max_new_tokens: Upper bound on generated tokens (standard models).
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.
        resolution_mode: DeepSeek-OCR preset name; ignored by other models.

    Yields:
        ``(raw_text, markdown_text)`` tuples. Both elements are the same
        string; standard models yield the growing output as it streams.
        Failures are reported as an ``"Error: ..."`` message rather than
        raised.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    # DeepSeek-OCR has its own inference API and is not streamed.
    if model_name == "DeepSeek-OCR":
        try:
            result = _run_deepseek_ocr(text, image, resolution_mode)
        except Exception as e:  # UI boundary: show the failure to the user
            message = f"Error: {e}"
            yield message, message
        else:
            yield result, result
        return

    # All other models share the processor + generate API.
    selected = STANDARD_MODELS.get(model_name)
    if selected is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = selected

    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Use the model's own device: Dots.OCR is placed by `device_map="auto"`.
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )
    else:
        # Sampling rejects a non-positive temperature; decode greedily instead.
        generation_kwargs["do_sample"] = False

    errors = []

    def _generate():
        """Run generation; on failure, end the stream so the consumer stops."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer

    thread.join()
    if errors:
        message = f"Error: {errors[0]}"
        yield message, message


# Image examples
image_examples = [
    ["OCR the content perfectly.", "examples/3.jpg"],
    ["Perform OCR on the image.", "examples/1.jpg"],
    ["Extract the contents. [page].", "examples/2.jpg"],
]

# CSS styling
css = """
.gradio-container {
    max-width: 1400px;
    margin: auto;
}
.model-selector {
    font-size: 16px;
}
"""

# Markdown is kept at column 0 so that list items and headings render as such.
HEADER_MARKDOWN = """
# 🔍 Multi-Model OCR Comparison Space

Compare five state-of-the-art OCR models on your images:

- **Chandra-OCR**: Specialized OCR model
- **Nanonets-OCR2-3B**: High-accuracy OCR
- **Dots.OCR**: Lightweight OCR solution
- **olmOCR-2-7B-1025-FP8**: Advanced FP8 quantized OCR model
- **DeepSeek-OCR**: Context compression OCR with 10× compression ratio (97% accuracy)
"""

FOOTER_MARKDOWN = """
### Model Information:

**DeepSeek-OCR Modes:**

- **Tiny**: 64 tokens @ 512×512 (fastest, basic documents)
- **Small**: 100 tokens @ 640×640 (good for simple pages)
- **Base**: 256 tokens @ 1024×1024 (standard documents)
- **Large**: 400 tokens @ 1280×1280 (complex layouts)
- **Gundam**: Dynamic multi-view (recommended for best accuracy)

### Tips:

- Upload clear images for best results
- DeepSeek-OCR excels at table extraction and markdown conversion
- Adjust temperature for more creative or conservative outputs
- Try different models to compare performance on your specific use case
"""

# Build Gradio interface
with gr.Blocks(css=css, title="Multi-Model OCR Space") as demo:
    gr.Markdown(HEADER_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=[
                    "Chandra-OCR",
                    "Nanonets-OCR2-3B",
                    "Dots.OCR",
                    "olmOCR-2-7B-1025-FP8",
                    "DeepSeek-OCR",
                ],
                value="DeepSeek-OCR",
                label="Select OCR Model",
                elem_classes=["model-selector"],
            )
            resolution_selector = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam"],
                value="Gundam",
                label="DeepSeek-OCR Resolution Mode",
                info="Only applies to DeepSeek-OCR. Gundam mode recommended for best results.",
                visible=True,
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                value="Perform OCR on this image.",
                label="Prompt",
                lines=2,
            )

            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=256, maximum=8192, value=2048, step=256,
                    label="Max New Tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top P",
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, value=50, step=1,
                    label="Top K",
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.1,
                    label="Repetition Penalty",
                )

            submit_btn = gr.Button("🚀 Extract Text", variant="primary")
            clear_btn = gr.ClearButton()

        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Extracted Text",
                lines=20,
                show_copy_button=True,
            )
            output_markdown = gr.Markdown(label="Formatted Output")

    gr.Examples(
        examples=image_examples,
        inputs=[text_input, image_input],
        label="Example Images",
    )

    # Show/hide resolution selector based on model
    def update_resolution_visibility(model_name):
        """Show the resolution dropdown only while DeepSeek-OCR is selected."""
        return gr.update(visible=(model_name == "DeepSeek-OCR"))

    model_selector.change(
        fn=update_resolution_visibility,
        inputs=[model_selector],
        outputs=[resolution_selector],
    )

    # Event handlers
    submit_btn.click(
        fn=generate_image,
        inputs=[
            model_selector,
            text_input,
            image_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            resolution_selector,
        ],
        outputs=[output_text, output_markdown],
    )
    clear_btn.add([image_input, text_input, output_text, output_markdown])

    gr.Markdown(FOOTER_MARKDOWN)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()