"""Gradio app exposing the dots.ocr model as a programmatic OCR API.

On import this module downloads the model weights, patches the remote
processor code for compatibility with recent ``transformers`` releases,
loads the model, builds the Gradio UI and launches the server.
"""

import os
import re
import sys

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "rednote-hilab/dots.ocr"
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model_weights")
DEFAULT_PROMPT = "Extract the text content from this image."

# Pieces used to patch the remote ``DotsVLProcessor`` class definition.
_PROCESSOR_CLASS_HEADER = re.compile(
    r"class\s+DotsVLProcessor\(Qwen2_5_VLProcessor\):\n"
)
_FIRST_BODY_INDENT = re.compile(r"(?:[ \t]*\n)*([ \t]+)\S")
_ATTRIBUTES_LINE = 'attributes = ["image_processor", "tokenizer"]'
_OLD_INIT_SIGNATURE = (
    "def __init__(self, image_processor=None, tokenizer=None, "
    "chat_template=None, **kwargs):"
)
_NEW_INIT_SIGNATURE = (
    "def __init__(self, image_processor=None, tokenizer=None, "
    "video_processor=None, chat_template=None, **kwargs):"
)


def _insert_processor_attributes(source: str) -> str:
    """Return *source* with an ``attributes`` line added to DotsVLProcessor.

    The line is inserted directly after the class header and uses the same
    indentation as the first non-blank line of the class body, so the patched
    file stays syntactically valid whatever indent width the remote code uses.
    *source* is returned unchanged when the line is already present or the
    class header cannot be found.
    """
    if _ATTRIBUTES_LINE in source:
        return source

    header = _PROCESSOR_CLASS_HEADER.search(source)
    if header is None:
        return source

    insert_at = header.end()
    body_indent = _FIRST_BODY_INDENT.match(source, insert_at)
    # Fall back to the conventional four spaces if no body line is found.
    indent = body_indent.group(1) if body_indent else "    "
    return (
        f"{source[:insert_at]}{indent}{_ATTRIBUTES_LINE}\n{source[insert_at:]}"
    )


def patch_configuration_dots(model_path: str) -> None:
    """Patch configuration_dots.py to fix the video_processor TypeError.

    Recent transformers versions require DotsVLProcessor to explicitly
    declare `attributes` and accept `video_processor=None`.
    See: https://huggingface.co/rednote-hilab/dots.ocr/discussions/38

    Args:
        model_path: Directory containing the downloaded model snapshot.
            Nothing happens if it has no ``configuration_dots.py``.
    """
    config_path = os.path.join(model_path, "configuration_dots.py")
    if not os.path.exists(config_path):
        return

    with open(config_path, "r", encoding="utf-8") as f:
        source = f.read()

    # Force processor mixin to treat dots.ocr as image+tokenizer only.
    # This avoids newer transformers requiring BaseVideoProcessor.
    patched = _insert_processor_attributes(source)

    # Handle both older and newer remote class signatures: the replacement
    # is a no-op when the signature already accepts ``video_processor``.
    patched = patched.replace(_OLD_INIT_SIGNATURE, _NEW_INIT_SIGNATURE)

    if patched == source:
        print("No dots.ocr processor patch changes were required.")
        return

    with open(config_path, "w", encoding="utf-8") as f:
        f.write(patched)


def load_model():
    """Download, patch and load dots.ocr.

    Returns:
        A ``(model, processor)`` tuple as returned by
        ``AutoModelForCausalLM.from_pretrained`` and
        ``AutoProcessor.from_pretrained``.
    """
    print(f"Downloading {MODEL_ID} ...")
    model_path = snapshot_download(
        repo_id=MODEL_ID,
        local_dir=MODEL_DIR,
    )

    # The patch must be applied before the remote code is imported below.
    patch_configuration_dots(model_path)
    sys.path.insert(0, model_path)

    # Try flash_attention_2 first, fall back to eager for compatibility.
    attn_impl = "flash_attention_2"
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        attn_impl = "eager"

    print(f"Loading model with attn_implementation={attn_impl} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation=attn_impl,
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=False,
    )
    return model, processor


MODEL, PROCESSOR = load_model()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def predict(image: Image.Image, prompt: str = DEFAULT_PROMPT) -> str:
    """Run OCR inference on a single image.

    Args:
        image: PIL Image to process.
        prompt: Instruction for the model. A blank prompt falls back to
            ``DEFAULT_PROMPT``.

    Returns:
        Raw text/JSON generated by dots.ocr, or an error message when no
        image was provided.
    """
    if image is None:
        return "Error: no image provided."

    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    image = image.convert("RGB")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = PROCESSOR.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = PROCESSOR(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        generated_ids = MODEL.generate(**inputs, max_new_tokens=24000)

    # Drop the prompt tokens so only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = PROCESSOR.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0] if output_text else ""


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="dots.ocr API") as demo:
    gr.Markdown(
        """
        # dots.ocr -- OCR API

        Upload an image and get the extracted text.
        This Space is optimized for **programmatic API access** so you can
        batch-process hundreds of images from an external script.

        ### Calling the API from Python

        ```python
        from gradio_client import Client

        client = Client("openpecha/bec-dot.orc-api")
        result = client.predict(
            "path/to/image.png",  # image filepath
            "Extract the text content from this image.",  # prompt
            api_name="/predict",
        )
        print(result)
        ```
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                value=DEFAULT_PROMPT,
                label="Prompt",
                lines=2,
            )
            run_btn = gr.Button("Run OCR", variant="primary")

        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Model Output",
                lines=20,
            )

    run_btn.click(
        fn=predict,
        inputs=[img_input, prompt_input],
        outputs=output_text,
        api_name="predict",
    )

demo.queue(max_size=20).launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
)