| """ | |
| OCR Application with Multiple Models including DeepSeek OCR | |
| Merged version with working DeepSeek implementation | |
| """ | |
| import os | |
| import time | |
| import torch | |
| import spaces | |
| import warnings | |
| import tempfile | |
| import sys | |
| from io import StringIO | |
| from contextlib import contextmanager | |
| from threading import Thread | |
| from PIL import Image | |
| from transformers import ( | |
| AutoProcessor, | |
| AutoModelForCausalLM, | |
| AutoModel, | |
| AutoTokenizer, | |
| Qwen2_5_VLForConditionalGeneration, | |
| TextIteratorStreamer | |
| ) | |
| from qwen_vl_utils import process_vision_info | |
| # Suppress the warning about uninitialized weights | |
| warnings.filterwarnings('ignore', message='Some weights.*were not initialized') | |
| # Try importing Qwen3VL if available | |
| try: | |
| from transformers import Qwen3VLForConditionalGeneration | |
| except ImportError: | |
| Qwen3VLForConditionalGeneration = None | |
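# Qwen3VLForConditionalGeneration only ships with recent transformers
# releases; Chandra-OCR is loaded through this class below, so it is skipped
# (rather than crashing the app) when the class is unavailable.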

MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Initial Device: {device}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# Load Chandra-OCR
try:
    MODEL_ID_V = "datalab-to/chandra"
    processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
    if Qwen3VLForConditionalGeneration:
        model_v = Qwen3VLForConditionalGeneration.from_pretrained(
            MODEL_ID_V,
            trust_remote_code=True,
            torch_dtype=torch.float16
        ).eval()
        print("✓ Chandra-OCR loaded")
    else:
        model_v = None
        print("✗ Chandra-OCR: Qwen3VL not available")
except Exception as e:
    model_v = None
    processor_v = None
    print(f"✗ Chandra-OCR: Failed to load - {str(e)}")

# Load Nanonets-OCR2-3B
try:
    MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
    model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_X,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).eval()
    print("✓ Nanonets-OCR2-3B loaded")
except Exception as e:
    model_x = None
    processor_x = None
    print(f"✗ Nanonets-OCR2-3B: Failed to load - {str(e)}")

# Load Dots.OCR - will be moved to GPU when needed
try:
    MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
    processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
    model_d = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH_D,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    ).eval()
    print("✓ Dots.OCR loaded")
except Exception as e:
    model_d = None
    processor_d = None
    print(f"✗ Dots.OCR: Failed to load - {str(e)}")

# Load olmOCR-2-7B-1025
try:
    MODEL_ID_M = "allenai/olmOCR-2-7B-1025"
    processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
    model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID_M,
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).eval()
    print("✓ olmOCR-2-7B-1025 loaded")
except Exception as e:
    model_m = None
    processor_m = None
    print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")

# Load DeepSeek-OCR using the working implementation
try:
    MODEL_ID_DS = "deepseek-ai/DeepSeek-OCR"  # Note: capital letters in DeepSeek-OCR
    print(f"Loading DeepSeek-OCR from {MODEL_ID_DS}...")
    tokenizer_ds = AutoTokenizer.from_pretrained(MODEL_ID_DS, trust_remote_code=True)
    print(" - Tokenizer loaded")
    model_ds = AutoModel.from_pretrained(
        MODEL_ID_DS,
        _attn_implementation="flash_attention_2",
        trust_remote_code=True,
        use_safetensors=True,
    ).eval()
    print("✓ DeepSeek-OCR loaded successfully")
except Exception as e:
    model_ds = None
    tokenizer_ds = None
    print(f"✗ DeepSeek-OCR: Failed to load - {str(e)}")
    import traceback
    traceback.print_exc()
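
# flash_attention_2 (requested for Dots.OCR and DeepSeek-OCR above) needs the
# flash-attn package and a CUDA GPU; if it is missing, from_pretrained raises
# and the except blocks above simply disable the affected model.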

@contextmanager
def capture_stdout():
    """Temporarily redirect sys.stdout into a StringIO buffer.

    DeepSeek-OCR's infer() prints its results, so we capture stdout to
    recover the OCR text.
    """
    old_stdout = sys.stdout
    sys.stdout = StringIO()
    try:
        yield sys.stdout
    finally:
        sys.stdout = old_stdout
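
# Illustrative usage of capture_stdout (hypothetical values):
#     with capture_stdout() as buf:
#         print("hello")        # captured, not shown on the console
#     text = buf.getvalue()     # "hello\n"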

def generate_image_deepseek(text: str, image: Image.Image,
                            preset: str = "gundam"):
    """
    Special generation function for DeepSeek-OCR using its native infer method.

    Args:
        text: Prompt text (used to determine task type)
        image: PIL Image object to process
        preset: Model preset configuration

    Yields:
        tuple: (raw_text, markdown_text)
    """
    if model_ds is None:
        yield "DeepSeek-OCR is not available.", "DeepSeek-OCR is not available."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    try:
        # Move model to GPU
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_ds.to(device).to(torch.bfloat16)
        # Create temp directory for this session
        with tempfile.TemporaryDirectory() as temp_dir:
            # Save image with proper format
            temp_image_path = os.path.join(temp_dir, "input_image.jpg")
            # Convert RGBA to RGB if necessary
            if image.mode in ('RGBA', 'LA', 'P'):
                rgb_image = Image.new('RGB', image.size, (255, 255, 255))
                if image.mode == 'RGBA':
                    rgb_image.paste(image, mask=image.split()[3])
                else:
                    rgb_image.paste(image)
                rgb_image.save(temp_image_path, 'JPEG', quality=95)
            else:
                image.save(temp_image_path, 'JPEG', quality=95)
            # Set parameters based on preset
            presets = {
                "tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
                "small": {"base_size": 640, "image_size": 640, "crop_mode": False},
                "base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
                "large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
                "gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
            }
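            # In DeepSeek-OCR's dynamic-resolution ("gundam") scheme,
            # base_size is the global view and image_size the per-crop
            # resolution when crop_mode is on; the other presets process
            # the page at one fixed size.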
            config = presets[preset]
            # Determine task type from prompt
            if "markdown" in text.lower() or "convert" in text.lower():
                prompt = "<image>\n<|grounding|>Convert the document to markdown. "
            else:
                prompt = "<image>\nFree OCR. "
            # Capture stdout while running inference
            captured_output = ""
            with capture_stdout() as output:
                result = model_ds.infer(
                    tokenizer_ds,
                    prompt=prompt,
                    image_file=temp_image_path,
                    output_path=temp_dir,
                    base_size=config["base_size"],
                    image_size=config["image_size"],
                    crop_mode=config["crop_mode"],
                    save_results=True,
                    test_compress=True,
                )
                captured_output = output.getvalue()
            # Extract the text from captured output
            extracted_text = ""
            # Look for the actual OCR result in the captured output
            lines = captured_output.split('\n')
            capture_text = False
            text_lines = []
            for line in lines:
                # Start capturing after seeing certain patterns
                if "# " in line or line.strip().startswith("**"):
                    capture_text = True
                if capture_text:
                    # Stop at the separator lines
                    if (line.startswith("====") or line.startswith("---")) and len(line) > 10:
                        if text_lines:  # Only stop if we've captured something
                            break
                    # Add non-empty lines that aren't debug output
                    elif line.strip() and not line.startswith(
                            ("image size:", "valid image",
                             "output texts", "compression")):
                        text_lines.append(line)
            if text_lines:
                extracted_text = '\n'.join(text_lines)
            # If we didn't get text from stdout, check if result contains text
            if not extracted_text and result is not None:
                if isinstance(result, str):
                    extracted_text = result
                elif isinstance(result, (list, tuple)) and len(result) > 0:
                    if isinstance(result[0], str):
                        extracted_text = result[0]
                    elif hasattr(result[0], 'text'):
                        extracted_text = result[0].text
            # Clean up any remaining markers from the text
            if extracted_text:
                clean_lines = []
                for line in extracted_text.split('\n'):
                    if not any(pattern in line.lower() for pattern in
                               ['image size:', 'valid image', 'compression ratio',
                                'save results:', 'output texts']):
                        clean_lines.append(line)
                extracted_text = '\n'.join(clean_lines).strip()
        # Move model back to CPU to free GPU memory
        model_ds.to("cpu")
        torch.cuda.empty_cache()
        # Return the extracted text
        final_text = extracted_text if extracted_text else "No text could be extracted from the image."
        yield final_text, final_text
    except Exception as e:
        error_msg = f"Error during DeepSeek generation: {str(e)}"
        print(f"Full error: {e}")
        import traceback
        traceback.print_exc()
        yield error_msg, error_msg
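
# Illustrative usage sketch (assumes DeepSeek-OCR loaded; "sample.png" is a
# hypothetical local file):
#     img = Image.open("sample.png")
#     for raw, md in generate_image_deepseek("Convert the document to markdown.", img):
#         print(raw)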

@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float,
                   top_k: int, repetition_penalty: float, deepseek_preset: str = "gundam"):
| """ | |
| Generates responses using the selected model for image input. | |
| Yields raw text and Markdown-formatted text. | |
| This function is decorated with @spaces.GPU to ensure it runs on GPU | |
| when available in Hugging Face Spaces. | |
| Args: | |
| model_name: Name of the OCR model to use | |
| text: Prompt text for the model | |
| image: PIL Image object to process | |
| max_new_tokens: Maximum number of tokens to generate | |
| temperature: Sampling temperature | |
| top_p: Nucleus sampling parameter | |
| top_k: Top-k sampling parameter | |
| repetition_penalty: Penalty for repeating tokens | |
| deepseek_preset: Preset for DeepSeek model | |
| Yields: | |
| tuple: (raw_text, markdown_text) | |
| """ | |
    # Special handling for DeepSeek-OCR
    if model_name == "DeepSeek-OCR":
        yield from generate_image_deepseek(text, image, deepseek_preset)
        return
    # Device will be cuda when the @spaces.GPU decorator activates
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Select model and processor based on model_name
    if model_name == "olmOCR-2-7B-1025":
        if model_m is None:
            yield "olmOCR-2-7B-1025 is not available.", "olmOCR-2-7B-1025 is not available."
            return
        processor = processor_m
        model = model_m.to(device)
    elif model_name == "Nanonets-OCR2-3B":
        if model_x is None:
            yield "Nanonets-OCR2-3B is not available.", "Nanonets-OCR2-3B is not available."
            return
        processor = processor_x
        model = model_x.to(device)
    elif model_name == "Chandra-OCR":
        if model_v is None:
            yield "Chandra-OCR is not available.", "Chandra-OCR is not available."
            return
        processor = processor_v
        model = model_v.to(device)
    elif model_name == "Dots.OCR":
        if model_d is None:
            yield "Dots.OCR is not available.", "Dots.OCR is not available."
            return
        processor = processor_d
        model = model_d.to(device)
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return
    try:
        # Prepare messages in chat format
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ]
        }]
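        # For Qwen-style processors, apply_chat_template expands this into a
        # prompt roughly like (exact special tokens vary by model):
        #   <|im_start|>user
        #   <|vision_start|><|image_pad|><|vision_end|>{text}<|im_end|>
        #   <|im_start|>assistant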
        # Apply chat template with fallback
        try:
            prompt_full = processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception as template_error:
            # Fallback: create a simple prompt without chat template
            print(f"Chat template error: {template_error}. Using fallback prompt.")
            prompt_full = f"{text}"
        # Process inputs
        inputs = processor(
            text=[prompt_full],
            images=[image],
            return_tensors="pt",
            padding=True
        ).to(device)
        # Setup streaming generation
        streamer = TextIteratorStreamer(
            processor.tokenizer if hasattr(processor, 'tokenizer') else processor,
            skip_prompt=True,
            skip_special_tokens=True
        )
        generation_kwargs = {
            **inputs,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "repetition_penalty": repetition_penalty,
        }
        # Start generation in a separate thread so we can stream tokens
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Stream the results
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
            yield buffer, buffer
        # Ensure thread completes
        thread.join()
    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(f"Full error: {e}")
        import traceback
        traceback.print_exc()
        yield error_msg, error_msg
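
# Illustrative usage sketch (assumes Nanonets-OCR2-3B loaded; "sample.png" is
# a hypothetical local file):
#     img = Image.open("sample.png")
#     for raw, md in generate_image("Nanonets-OCR2-3B", "Extract all text.",
#                                   img, max_new_tokens=1024, temperature=0.7,
#                                   top_p=0.9, top_k=50, repetition_penalty=1.1):
#         print(raw)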

# Example usage for Gradio interface
if __name__ == "__main__":
    import gradio as gr

    # Determine available models
    available_models = []
    if model_m is not None:
        available_models.append("olmOCR-2-7B-1025")
        print(" Added: olmOCR-2-7B-1025")
    if model_x is not None:
        available_models.append("Nanonets-OCR2-3B")
        print(" Added: Nanonets-OCR2-3B")
    if model_v is not None:
        available_models.append("Chandra-OCR")
        print(" Added: Chandra-OCR")
    if model_d is not None:
        available_models.append("Dots.OCR")
        print(" Added: Dots.OCR")
    if model_ds is not None:
        available_models.append("DeepSeek-OCR")
        print(" Added: DeepSeek-OCR")
    else:
        print(" Skipped: DeepSeek-OCR (model_ds is None)")
    if not available_models:
        print("ERROR: No models were loaded successfully!")
        sys.exit(1)
    print(f"\n✓ Available models for dropdown: {', '.join(available_models)}")

    with gr.Blocks(title="Multi-Model OCR") as demo:
        gr.Markdown("# 🔍 Multi-Model OCR Application")
        gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")
        with gr.Row():
            with gr.Column():
                model_selector = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="Select OCR Model"
                )
                image_input = gr.Image(type="pil", label="Upload Image")
                text_input = gr.Textbox(
                    value="Extract all text from this image.",
                    label="Prompt",
                    lines=2
                )
                # DeepSeek-specific settings (visible when DeepSeek is selected)
                deepseek_preset = gr.Radio(
                    choices=["gundam", "base", "large", "small", "tiny"],
                    value="gundam",
                    label="DeepSeek Preset",
                    info="Gundam recommended for most documents",
                    visible=False
                )
                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=MAX_MAX_NEW_TOKENS,
                        value=DEFAULT_MAX_NEW_TOKENS,
                        step=1,
                        label="Max New Tokens (not used for DeepSeek)"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature (not used for DeepSeek)"
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        label="Top P (not used for DeepSeek)"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top K (not used for DeepSeek)"
                    )
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=1.1,
                        step=0.1,
                        label="Repetition Penalty (not used for DeepSeek)"
                    )
                submit_btn = gr.Button("Extract Text", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Extracted Text", lines=20)
                output_markdown = gr.Markdown(label="Formatted Output")
| gr.Markdown(""" | |
| ### Available Models: | |
| - **olmOCR-2-7B-1025**: Allen AI's OCR model | |
| - **Nanonets-OCR2-3B**: Nanonets OCR model | |
| - **Chandra-OCR**: Datalab OCR model | |
| - **Dots.OCR**: Stranger Vision OCR model | |
| - **DeepSeek-OCR**: DeepSeek AI's OCR model (uses native inference method) | |
| ### DeepSeek-OCR Presets: | |
| - **Gundam** (Recommended): Balanced performance with crop mode | |
| - **Base**: Standard quality without cropping | |
| - **Large**: Highest quality for complex documents | |
| - **Small**: Faster processing, good for simple text | |
| - **Tiny**: Fastest, suitable for clear printed text | |
| """) | |

        # Event handler to show/hide the DeepSeek preset based on model selection
        def update_preset_visibility(model_name):
            return gr.update(visible=(model_name == "DeepSeek-OCR"))

        model_selector.change(
            fn=update_preset_visibility,
            inputs=[model_selector],
            outputs=[deepseek_preset]
        )
        submit_btn.click(
            fn=generate_image,
            inputs=[
                model_selector,
                text_input,
                image_input,
                max_tokens,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                deepseek_preset
            ],
            outputs=[output_text, output_markdown]
        )
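        # generate_image is a generator, so Gradio streams its partial output
        # into the outputs as tokens arrive instead of waiting for completion.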

    # share=True creates a public tunnel link for local runs; Hugging Face
    # Spaces ignores it and serves the app directly.
    demo.launch(share=True)