| | import gradio as gr |
| | import torch |
| | import json |
| | import spaces |
| | import os |
| | from PIL import Image |
| | from transformers import AutoModelForCausalLM, AutoProcessor |
| | from transformers.processing_utils import ProcessorMixin |
| | from qwen_vl_utils import process_vision_info |
| | from huggingface_hub import login |
| |
|
| | |
# Hugging Face Hub repo id of the dots.ocr checkpoint (custom/remote code).
MODEL_PATH = "rednote-hilab/dots.ocr"

# Optional Hub token (needed if the model repo is gated for this account).
# Truthiness check deliberately skips login for an empty-string token too.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    print("Authenticating with Hugging Face token...")
    login(token=HF_TOKEN, add_to_git_credential=False)

# Lazily-initialized singletons populated by load_model() on first use
# (loading is deferred so it happens inside the @spaces.GPU context).
model = None
processor = None
| |
|
def load_model():
    """Load the dots.ocr model and processor once, caching them in globals.

    Subsequent calls return the cached objects. Note: in the flattened
    original, the processor section was re-run on every call (re-patching a
    ProcessorMixin class attribute each time); it is now guarded by its own
    ``if processor is None`` check, which is behavior-compatible and avoids
    the redundant reload/patch.

    Returns:
        tuple: ``(model, processor)`` ready for inference.
    """
    global model, processor

    if model is None:
        print(f"Loading model weights from {MODEL_PATH}...")

        # Prefer FlashAttention2 when the package is installed; otherwise
        # fall back to the default eager attention implementation.
        try:
            import flash_attn  # noqa: F401 -- presence check only
            attn_implementation = "flash_attention_2"
            print("Using FlashAttention2 for faster inference")
        except ImportError:
            attn_implementation = "eager"
            print("FlashAttention2 not available, using default attention")

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            token=HF_TOKEN,
            attn_implementation=attn_implementation,
        )
        print("Model loaded successfully.")

    if processor is None:
        print(f"Loading processor from {MODEL_PATH}...")

        # Workaround: this checkpoint's processor config apparently yields a
        # None "video_processor", which ProcessorMixin's class check rejects.
        # Temporarily patch the check to tolerate that one case, and always
        # restore the original so other processors are unaffected.
        _original_check = ProcessorMixin.check_argument_for_proper_class

        def _patched_check(self, attribute_name, value):
            if attribute_name == "video_processor" and value is None:
                return
            return _original_check(self, attribute_name, value)

        ProcessorMixin.check_argument_for_proper_class = _patched_check
        try:
            processor = AutoProcessor.from_pretrained(
                MODEL_PATH,
                trust_remote_code=True,
                token=HF_TOKEN,
            )
            print("Processor loaded successfully.")
        finally:
            # Restore even if from_pretrained raised.
            ProcessorMixin.check_argument_for_proper_class = _original_check

    return model, processor
| |
|
| | |
# Preset prompts selectable in the UI dropdown. Keys double as dropdown
# labels; "Custom" maps to an empty string and the effective prompt is then
# taken from the custom_prompt textbox (see process_image).
PROMPTS = {
    "Full Layout + OCR (English)": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.

1. Bbox format: [x1, y1, x2, y2]

2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].

3. Text Extraction & Formatting Rules:
    - Picture: For the 'Picture' category, the text field should be omitted.
    - Formula: Format its text as LaTeX.
    - Table: Format its text as HTML.
    - All Others (Text, Title, etc.): Format their text as Markdown.

4. Constraints:
    - The output text must be the original text from the image, with no translation.
    - All layout elements must be sorted according to human reading order.

5. Final Output: The entire output must be a single JSON object.""",

    "OCR Only": """Please extract all text from the image in reading order. Format the output as plain text, preserving the original structure as much as possible.""",

    "Layout Detection Only": """Please detect all layout elements in the image and output their bounding boxes and categories. Format: [{"bbox": [x1, y1, x2, y2], "category": "category_name"}]""",

    "Custom": ""
}
| |
|
@spaces.GPU(duration=120)
def process_image(image, prompt_type, custom_prompt):
    """Run dots.ocr on a document image and return the decoded text.

    Args:
        image: PIL image from the Gradio image input (may be None).
        prompt_type: Key into PROMPTS selected in the dropdown.
        custom_prompt: Free-form prompt, used only when prompt_type == "Custom".

    Returns:
        str: Model output (pretty-printed JSON when the output parses as
        JSON, otherwise raw text), or an "Error: ..." string on failure.
    """
    try:
        # Guard against clicking "Process" with no image uploaded.
        if image is None:
            return "Error: please upload an image first."

        current_model, current_processor = load_model()

        # A non-empty custom prompt wins when "Custom" is selected;
        # otherwise fall back to the preset for the chosen type.
        if prompt_type == "Custom" and custom_prompt.strip():
            prompt = custom_prompt
        else:
            prompt = PROMPTS[prompt_type]

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = current_processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = current_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move tensors to wherever device_map="auto" placed the model,
        # instead of hardcoding "cuda".
        inputs = inputs.to(current_model.device)

        with torch.no_grad():
            generated_ids = current_model.generate(
                **inputs,
                max_new_tokens=24000,
                # do_sample=True is required for temperature/top_p to take
                # effect; without it transformers ignores them with a warning.
                do_sample=True,
                temperature=0.1,
                top_p=0.9,
            )

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = current_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Pretty-print when the model produced JSON (the layout prompts ask
        # for a JSON object); plain-text outputs pass through unchanged.
        try:
            parsed_json = json.loads(output_text)
            output_text = json.dumps(parsed_json, ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            pass

        return output_text

    except Exception as e:
        # Top-level Gradio boundary: surface the error in the output box
        # rather than crashing the UI.
        return f"Error: {str(e)}"
| |
|
| | |
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(title="dots.ocr - Multilingual Document OCR") as demo:
    gr.Markdown("""
    # ๐ dots.ocr - Multilingual Document Layout Parsing

    Upload a document image and get OCR results with layout detection.
    This space uses the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model.

    **Features:**
    - Multilingual support
    - Layout detection (tables, formulas, text, etc.)
    - Reading order preservation
    - Formula extraction (LaTeX format)
    - Table extraction (HTML format)
    """)

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Document Image",
                height=400
            )

            prompt_type = gr.Dropdown(
                choices=list(PROMPTS.keys()),
                value="Full Layout + OCR (English)",
                label="Prompt Type",
                info="Select the type of processing you want"
            )

            # Hidden unless "Custom" is selected (see toggle_custom_prompt).
            custom_prompt = gr.Textbox(
                label="Custom Prompt (used when 'Custom' is selected)",
                placeholder="Enter your custom prompt here...",
                lines=5,
                visible=False
            )

            submit_btn = gr.Button("Process Document", variant="primary", size="lg")

        # Right column: model output.
        with gr.Column():
            output_text = gr.Textbox(
                label="OCR Result",
                lines=25,
                show_copy_button=True
            )

    # Show the custom-prompt textbox only when the "Custom" preset is chosen.
    def toggle_custom_prompt(choice):
        return gr.update(visible=(choice == "Custom"))

    prompt_type.change(
        fn=toggle_custom_prompt,
        inputs=[prompt_type],
        outputs=[custom_prompt]
    )

    submit_btn.click(
        fn=process_image,
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text]
    )

    # Example inputs; cache_examples=False because inference needs the GPU
    # context and examples would otherwise be precomputed at build time.
    gr.Markdown("## ๐ Examples")
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Full Layout + OCR (English)", ""],
            ["examples/example2.jpg", "OCR Only", ""],
        ],
        inputs=[image_input, prompt_type, custom_prompt],
        outputs=[output_text],
        fn=process_image,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
| |
|