Spaces:
Sleeping
Sleeping
refactor: update model loading logic to improve compatibility and adjust output textbox settings
aeca817 | import os | |
| import sys | |
| import re | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
# Hugging Face Hub repository that hosts the dots.ocr weights and remote code.
MODEL_ID = "rednote-hilab/dots.ocr"

# The snapshot is downloaded into a folder that sits next to this script.
_APP_DIR = os.path.dirname(__file__)
MODEL_DIR = os.path.join(_APP_DIR, "model_weights")

# Instruction sent to the model when the caller does not provide a usable one.
DEFAULT_PROMPT = "Extract the text content from this image."
def patch_configuration_dots(model_path: str) -> None:
    """Patch configuration_dots.py to fix the video_processor TypeError.

    Recent transformers versions require DotsVLProcessor to explicitly
    declare `attributes` and accept `video_processor=None`.
    See: https://huggingface.co/rednote-hilab/dots.ocr/discussions/38

    The patch is idempotent. Each edit is skipped when already present, and
    the file is only rewritten when its content actually changes.

    Args:
        model_path: Directory of the downloaded dots.ocr snapshot. If it does
            not contain ``configuration_dots.py`` the function does nothing.
    """
    config_path = os.path.join(model_path, "configuration_dots.py")
    if not os.path.exists(config_path):
        return
    with open(config_path, "r", encoding="utf-8") as f:
        source = f.read()

    patched = source
    attributes_line = 'attributes = ["image_processor", "tokenizer"]'

    # Force processor mixin to treat dots.ocr as image+tokenizer only.
    # This avoids newer transformers requiring BaseVideoProcessor.
    if attributes_line not in patched:
        # The pattern matches:
        #   group 1: the class header line (trailing whitespace tolerated),
        #   group 2: any blank lines that follow it,
        #   group 3: the indentation of the first body line.
        # The inserted attribute reuses that indentation, so tab- or
        # 2-space-indented files stay syntactically valid.
        header_re = re.compile(
            r"^(class[ \t]+DotsVLProcessor\(Qwen2_5_VLProcessor\):[ \t]*\n)"
            r"((?:[ \t]*\n)*)([ \t]+)",
            re.MULTILINE,
        )
        patched, n_subs = header_re.subn(
            lambda m: (
                f"{m.group(1)}{m.group(3)}{attributes_line}\n"
                f"{m.group(2)}{m.group(3)}"
            ),
            patched,
            count=1,
        )
        if n_subs == 0:
            # Without the attribute, loading the processor later fails with
            # a much less obvious error, so report it here.
            print(
                "Warning: DotsVLProcessor class header not found in "
                f"{config_path}; 'attributes' patch not applied."
            )

    # Handle both older and newer remote class signatures.
    # The rewritten signature no longer contains the old one as a substring,
    # so running this again is a no-op.
    patched = patched.replace(
        "def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):",
        "def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):",
    )

    if patched == source:
        print("No dots.ocr processor patch changes were required.")
        return
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(patched)
    print(f"Patched {config_path} for newer transformers compatibility.")
def _select_attn_implementation():
    """Return "flash_attention_2" when it can run on this host, else "eager"."""
    # FlashAttention-2 kernels need a GPU. An importable flash_attn package
    # alone is not enough, because a CPU-only host would then fail inside
    # from_pretrained.
    if not torch.cuda.is_available():
        return "eager"
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        return "eager"
    return "flash_attention_2"


def load_model():
    """Download dots.ocr, patch its remote code, and load model + processor.

    Returns:
        A ``(model, processor)`` tuple loaded with ``trust_remote_code=True``
        from the local snapshot stored in ``MODEL_DIR``.
    """
    print(f"Downloading {MODEL_ID} ...")
    model_path = snapshot_download(
        repo_id=MODEL_ID,
        local_dir=MODEL_DIR,
    )
    # Rewrite the remote processor code before from_pretrained loads it.
    patch_configuration_dots(model_path)
    # Make the snapshot directory importable. The membership check avoids
    # stacking duplicate entries if this is ever called more than once.
    if model_path not in sys.path:
        sys.path.insert(0, model_path)

    # Prefer flash_attention_2, falling back to eager for compatibility.
    attn_impl = _select_attn_implementation()
    print(f"Loading model with attn_implementation={attn_impl} ...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation=attn_impl,
        # NOTE(review): `dtype=` is the newer spelling of `torch_dtype`;
        # confirm the installed transformers release honours it.
        dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=False,
    )
    return model, processor
# Load the model and processor once at import time; predict() reuses them
# for every request.
MODEL, PROCESSOR = load_model()

# Device that tokenised inputs are moved to before generation.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def predict(image: Image.Image, prompt: str = DEFAULT_PROMPT) -> str:
    """Run dots.ocr on a single image and return the decoded generation.

    Args:
        image: PIL image to read, or ``None`` when nothing was uploaded.
        prompt: Instruction for the model. Empty or whitespace-only values
            are replaced by ``DEFAULT_PROMPT``.

    Returns:
        The raw text/JSON generated by dots.ocr. Returns an error message
        when ``image`` is ``None``, or ``""`` if decoding yields no sequence.
    """
    if image is None:
        return "Error: no image provided."
    if not (prompt and prompt.strip()):
        prompt = DEFAULT_PROMPT

    # Single-turn chat message in the shape consumed by both the processor's
    # chat template and qwen_vl_utils.process_vision_info.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image.convert("RGB")},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    chat_prompt = PROCESSOR.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)
    batch = PROCESSOR(
        text=[chat_prompt],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    batch = batch.to(DEVICE)

    with torch.no_grad():
        sequences = MODEL.generate(**batch, max_new_tokens=24000)

    # Each generated row begins with its prompt tokens; keep only the continuation.
    continuations = []
    for prompt_ids, full_ids in zip(batch.input_ids, sequences):
        continuations.append(full_ids[len(prompt_ids):])

    decoded = PROCESSOR.batch_decode(
        continuations,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    if not decoded:
        return ""
    return decoded[0]
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# One image + prompt in, the model's text out. The click handler is exposed
# under api_name="predict" so scripts can call it through gradio_client.
with gr.Blocks(title="dots.ocr API") as demo:
    # The usage example wraps the path in handle_file(): gradio_client 1.x no
    # longer uploads bare filepath strings automatically.
    gr.Markdown(
        """
# dots.ocr -- OCR API

Upload an image and get the extracted text. This Space is optimized for
**programmatic API access** so you can batch-process hundreds of images from
an external script.

### Calling the API from Python

```python
from gradio_client import Client, handle_file

client = Client("openpecha/bec-dot.orc-api")
result = client.predict(
    handle_file("path/to/image.png"),  # image filepath
    "Extract the text content from this image.",  # prompt
    api_name="/predict",
)
print(result)
```
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            img_input = gr.Image(type="pil", label="Upload Image")
            prompt_input = gr.Textbox(
                value=DEFAULT_PROMPT,
                label="Prompt",
                lines=2,
            )
            run_btn = gr.Button("Run OCR", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Model Output",
                lines=20,
            )
    run_btn.click(
        fn=predict,
        inputs=[img_input, prompt_input],
        outputs=output_text,
        api_name="predict",
    )

# Bounded queue: at most 20 pending requests wait for the single model instance.
demo.queue(max_size=20).launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
)