# Multi-Model OCR Comparison Space — Gradio app (Hugging Face Spaces).
# CRITICAL: import `spaces` before any CUDA-related package — presumably so the
# HF Spaces (ZeroGPU) shim can intercept CUDA initialization; confirm on the Space.
import spaces  # noqa: F401  (imported for its side effect / decorator registry)

import os
import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,  # kept for compatibility (changed from Qwen3VL)
    Qwen2_5_VLForConditionalGeneration,
    TextIteratorStreamer,
)
# Device setup: prefer the GPU when one is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_qwen25_vl(model_id, processor_id=None, dtype=torch.float16):
    """Load a Qwen2.5-VL checkpoint and its processor onto `device`.

    `processor_id` lets a model reuse another repo's processor (olmOCR ships
    without one and borrows the stock Qwen2.5-VL processor).
    """
    proc = AutoProcessor.from_pretrained(
        processor_id or model_id, trust_remote_code=True
    )
    mdl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    ).to(device).eval()
    return proc, mdl


# Chandra-OCR (uses the Qwen2.5-VL architecture).
MODEL_ID_V = "datalab-to/chandra"
processor_v, model_v = _load_qwen25_vl(MODEL_ID_V)

# Nanonets-OCR2-3B.
MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
processor_x, model_x = _load_qwen25_vl(MODEL_ID_X)

# Dots.OCR — loaded through the generic causal-LM entry point with remote code.
MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
).eval()

# olmOCR-2-7B-1025-FP8 (quantized); borrows the stock Qwen2.5-VL processor.
MODEL_ID_M = "allenai/olmOCR-2-7B-1025-FP8"
processor_m, model_m = _load_qwen25_vl(
    MODEL_ID_M,
    processor_id="Qwen/Qwen2.5-VL-7B-Instruct",
    dtype=torch.bfloat16,
)

# DeepSeek-OCR — custom `infer` API; needs a tokenizer rather than a processor.
MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"
tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
model_ds = AutoModel.from_pretrained(
    MODEL_ID_DS,
    attn_implementation="sdpa",
    trust_remote_code=True,
    use_safetensors=True,
).eval().to(device).to(torch.bfloat16)
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, resolution_mode: str):
    """
    Run OCR on `image` with the selected model, streaming partial results.

    Yields `(raw_text, markdown_text)` tuples so the caller can update both a
    plain-text and a Markdown output component. Sampling parameters apply to
    the transformers-based models; `resolution_mode` applies only to
    DeepSeek-OCR.
    """
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    # --- DeepSeek-OCR: bespoke `infer` API (image file path in, text out) ---
    if model_name == "DeepSeek-OCR":
        import tempfile  # local import: only this branch needs it

        resolution_configs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
        }
        # Fall back to the recommended mode instead of raising KeyError.
        config = resolution_configs.get(
            resolution_mode, resolution_configs["Gundam"]
        )

        # Unique temp file per request: the previous fixed /tmp path raced
        # between concurrent users.
        tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
        temp_image_path = tmp.name
        tmp.close()

        # DeepSeek-OCR uses the special "<image>\n<prompt>" format.
        if not text:
            text = "Free OCR."
        prompt_ds = f"<image>\n{text}"
        try:
            # JPEG cannot store an alpha channel; normalize to RGB first so
            # RGBA/P-mode uploads don't crash the save.
            image.convert("RGB").save(temp_image_path)
            result = model_ds.infer(
                tokenizer_ds,
                prompt=prompt_ds,
                image_file=temp_image_path,
                output_path="/tmp",
                base_size=config["base_size"],
                image_size=config["image_size"],
                crop_mode=config["crop_mode"],
                test_compress=True,
                save_results=False,
            )
            yield result, result
        except Exception as e:
            yield f"Error: {str(e)}", f"Error: {str(e)}"
        finally:
            # Always clean up the temp file.
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
        return

    # --- Standard transformers streaming path -------------------------------
    registry = {
        "olmOCR-2-7B-1025-FP8": (processor_m, model_m),
        "Nanonets-OCR2-3B": (processor_x, model_x),
        "Chandra-OCR": (processor_v, model_v),
        "Dots.OCR": (processor_d, model_d),
    }
    if model_name not in registry:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = registry[model_name]

    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(device)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    # transformers raises when do_sample=True and temperature == 0 (the
    # slider minimum), so fall back to greedy decoding in that case.
    do_sample = temperature > 0
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
    }
    if do_sample:
        generation_kwargs.update(
            temperature=temperature, top_p=top_p, top_k=top_k
        )

    # Generate on a worker thread so we can stream tokens as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Strip the chat end-of-turn token if it slips through decoding.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)  # small pause keeps the UI update rate reasonable
        yield buffer, buffer
# Prompt/image pairs surfaced as clickable examples in the UI.
_EXAMPLE_PAIRS = (
    ("OCR the content perfectly.", "examples/3.jpg"),
    ("Perform OCR on the image.", "examples/1.jpg"),
    ("Extract the contents. [page].", "examples/2.jpg"),
)
image_examples = [[prompt, path] for prompt, path in _EXAMPLE_PAIRS]

# CSS applied to the Blocks container.
css = """
.gradio-container {
    max-width: 1400px;
    margin: auto;
}
.model-selector {
    font-size: 16px;
}
"""
# --------------------------------------------------------------------------
# Gradio interface
# --------------------------------------------------------------------------
_MODEL_CHOICES = [
    "Chandra-OCR",
    "Nanonets-OCR2-3B",
    "Dots.OCR",
    "olmOCR-2-7B-1025-FP8",
    "DeepSeek-OCR",
]
_RESOLUTION_CHOICES = ["Tiny", "Small", "Base", "Large", "Gundam"]

with gr.Blocks(css=css, title="Multi-Model OCR Space") as demo:
    gr.Markdown(
        """
# 🔍 Multi-Model OCR Comparison Space
Compare five state-of-the-art OCR models on your images:
- **Chandra-OCR**: Specialized OCR model
- **Nanonets-OCR2-3B**: High-accuracy OCR
- **Dots.OCR**: Lightweight OCR solution
- **olmOCR-2-7B-1025-FP8**: Advanced FP8 quantized OCR model
- **DeepSeek-OCR**: Context compression OCR with 10× compression ratio (97% accuracy)
        """
    )

    with gr.Row():
        # Left column: model choice, prompt, image and decoding settings.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=_MODEL_CHOICES,
                value="DeepSeek-OCR",
                label="Select OCR Model",
                elem_classes=["model-selector"],
            )
            resolution_selector = gr.Dropdown(
                choices=_RESOLUTION_CHOICES,
                value="Gundam",
                label="DeepSeek-OCR Resolution Mode",
                info="Only applies to DeepSeek-OCR. Gundam mode recommended for best results.",
                visible=True,
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                value="Perform OCR on this image.",
                label="Prompt",
                lines=2,
            )
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=256, maximum=8192, value=2048, step=256,
                    label="Max New Tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                )
                top_p_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top P",
                )
                top_k_slider = gr.Slider(
                    minimum=1, maximum=100, value=50, step=1,
                    label="Top K",
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.1,
                    label="Repetition Penalty",
                )
            submit_btn = gr.Button("🚀 Extract Text", variant="primary")
            clear_btn = gr.ClearButton()

        # Right column: streamed raw text plus rendered Markdown.
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Extracted Text",
                lines=20,
                show_copy_button=True,
            )
            output_markdown = gr.Markdown(label="Formatted Output")

    gr.Examples(
        examples=image_examples,
        inputs=[text_input, image_input],
        label="Example Images",
    )

    # The resolution dropdown only matters for DeepSeek-OCR, so toggle its
    # visibility whenever the model selection changes.
    def update_resolution_visibility(model_name):
        return gr.update(visible=(model_name == "DeepSeek-OCR"))

    model_selector.change(
        fn=update_resolution_visibility,
        inputs=[model_selector],
        outputs=[resolution_selector],
    )

    # Main event: stream OCR output into both output components.
    submit_btn.click(
        fn=generate_image,
        inputs=[
            model_selector,
            text_input,
            image_input,
            max_tokens_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            repetition_penalty_slider,
            resolution_selector,
        ],
        outputs=[output_text, output_markdown],
    )

    clear_btn.add([image_input, text_input, output_text, output_markdown])

    gr.Markdown(
        """
### Model Information:
**DeepSeek-OCR Modes:**
- **Tiny**: 64 tokens @ 512×512 (fastest, basic documents)
- **Small**: 100 tokens @ 640×640 (good for simple pages)
- **Base**: 256 tokens @ 1024×1024 (standard documents)
- **Large**: 400 tokens @ 1280×1280 (complex layouts)
- **Gundam**: Dynamic multi-view (recommended for best accuracy)
### Tips:
- Upload clear images for best results
- DeepSeek-OCR excels at table extraction and markdown conversion
- Adjust temperature for more creative or conservative outputs
- Try different models to compare performance on your specific use case
        """
    )
if __name__ == "__main__":
    # Queue requests so long-running OCR jobs don't block one another.
    demo.queue(max_size=20).launch()