"""Gradio Space for the dots.ocr multilingual document OCR model.

Loads rednote-hilab/dots.ocr lazily on the GPU worker (Hugging Face Spaces
ZeroGPU) and exposes a simple image -> text/JSON interface with layout
parsing, OCR-only, and layout-detection prompts.
"""

import json
import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.processing_utils import ProcessorMixin

# Model configuration
MODEL_PATH = "rednote-hilab/dots.ocr"

# Optional authentication (required if the repository is gated)
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("Authenticating with Hugging Face token...")
    login(token=HF_TOKEN, add_to_git_credential=False)

# Model and processor are loaded lazily inside the @spaces.GPU function,
# because the GPU is only attached while that function runs.
model = None
processor = None


def load_model():
    """Load the model and processor once and cache them in module globals.

    Returns:
        tuple: ``(model, processor)`` ready for inference.
    """
    global model, processor

    if model is None:
        print(f"Loading model weights from {MODEL_PATH}...")
        # Prefer FlashAttention2 when the package is installed; fall back
        # to the default (eager) attention implementation otherwise.
        try:
            import flash_attn  # noqa: F401

            attn_implementation = "flash_attention_2"
            print("Using FlashAttention2 for faster inference")
        except ImportError:
            attn_implementation = "eager"
            print("FlashAttention2 not available, using default attention")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation=attn_implementation,
        )
        print("Model loaded successfully.")

        print(f"Loading processor from {MODEL_PATH}...")
        # The dots.ocr processor config has no video_processor; recent
        # transformers versions reject a None value during validation, so
        # temporarily relax the check while the processor loads.
        _original_check = ProcessorMixin.check_argument_for_proper_class

        def _patched_check(self, attribute_name, value):
            if attribute_name == "video_processor" and value is None:
                return  # Skip validation for None video_processor
            return _original_check(self, attribute_name, value)

        ProcessorMixin.check_argument_for_proper_class = _patched_check
        try:
            processor = AutoProcessor.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                token=HF_TOKEN,
            )
            print("Processor loaded successfully.")
        finally:
            # Restore original validation method
            ProcessorMixin.check_argument_for_proper_class = _original_check

    return model, processor


# Predefined prompts
PROMPTS = {
    "Full Layout + OCR (English)": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.""",
    "OCR Only": """Please extract all text from the image in reading order.
Format the output as plain text, preserving the original structure as much as possible.""",
    "Layout Detection Only": """Please detect all layout elements in the image and output their bounding boxes and categories.
Format: [{"bbox": [x1, y1, x2, y2], "category": "category_name"}]""",
    "Custom": "",
}


@spaces.GPU(duration=120)
def process_image(image, prompt_type, custom_prompt):
    """Run OCR / layout parsing on an uploaded document image.

    Args:
        image: PIL image provided by the Gradio image widget (may be None).
        prompt_type: Key into ``PROMPTS`` (or ``"Custom"``).
        custom_prompt: Free-form prompt used when ``prompt_type == "Custom"``.

    Returns:
        str: model output, pretty-printed as JSON when the model emits valid
        JSON, otherwise the raw decoded text; or an ``"Error: ..."`` string.
    """
    try:
        if image is None:
            return "Error: please upload an image first."

        # Load model and processor on the GPU worker
        current_model, current_processor = load_model()

        # Determine which prompt to use; guard against a None textbox value.
        if prompt_type == "Custom" and (custom_prompt or "").strip():
            prompt = custom_prompt
        else:
            prompt = PROMPTS.get(prompt_type, "")
        if not prompt:
            return "Error: empty prompt. Select a prompt type or enter a custom prompt."

        # Prepare chat messages in the Qwen-VL format
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Prepare inputs
        text = current_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = current_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move inputs to wherever device_map="auto" placed the model,
        # rather than hard-coding "cuda".
        inputs = inputs.to(current_model.device)

        # Generate output. do_sample=True is required for temperature/top_p
        # to take effect; without it transformers ignores them with a warning.
        with torch.no_grad():
            generated_ids = current_model.generate(
                **inputs,
                max_new_tokens=24000,
                do_sample=True,
                temperature=0.1,
                top_p=0.9,
            )

        # Strip the prompt tokens from each sequence before decoding
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = current_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Pretty-print when the model produced valid JSON
        try:
            parsed_json = json.loads(output_text)
            output_text = json.dumps(parsed_json, ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            pass  # Keep as plain text if not valid JSON

        return output_text

    except Exception as e:
        # Surface the failure to the UI instead of crashing the worker
        return f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="dots.ocr - Multilingual Document OCR") as demo:
    gr.Markdown("""
    # 🔍 dots.ocr - Multilingual Document Layout Parsing

    Upload a document image and get OCR results with layout detection.
    This space uses the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model.

    **Features:**
    - Multilingual support
    - Layout detection (tables, formulas, text, etc.)
    - Reading order preservation
    - Formula extraction (LaTeX format)
    - Table extraction (HTML format)
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Document Image",
                height=400,
            )
            prompt_type = gr.Dropdown(
                choices=list(PROMPTS.keys()),
                value="Full Layout + OCR (English)",
                label="Prompt Type",
                info="Select the type of processing you want",
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (used when 'Custom' is selected)",
                placeholder="Enter your custom prompt here...",
                lines=5,
                visible=False,
            )
            submit_btn = gr.Button("Process Document", variant="primary", size="lg")

        with gr.Column():
            output_text = gr.Textbox(
                label="OCR Result",
                lines=25,
                show_copy_button=True,
            )

    # Show/hide custom prompt based on selection
    def toggle_custom_prompt(choice):
        return gr.update(visible=(choice == "Custom"))

    prompt_type.change(
        fn=toggle_custom_prompt,
        inputs=[prompt_type],
        outputs=[custom_prompt],
    )

    submit_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text],
    )

    # Examples
    gr.Markdown("## 📝 Examples")
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Full Layout + OCR (English)", ""],
            ["examples/example2.jpg", "OCR Only", ""],
        ],
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text],
        fn=process_image,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()