"""Gradio demo that runs several vision-language OCR models on an uploaded image.

At import time the module builds a custom Gradio theme, downloads and patches
the Dots.OCR checkpoint, loads four OCR models, and assembles the UI.  The
``generate_image`` generator streams the selected model's output to the page.
"""

import os
import sys
from threading import Thread
from typing import Iterable

from huggingface_hub import snapshot_download
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForImageTextToText,
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Register a custom palette on gradio's colors module so the theme below can
# reference it as ``colors.steel_blue``.
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)


class SteelBlueTheme(Soft):
    """Soft-based Gradio theme using gray as primary and steel blue as secondary hue."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Override individual theme variables on top of the Soft defaults.
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )


steel_blue_theme = SteelBlueTheme()

css = """
#main-title h1 {
    font-size: 2.3em !important;
}
#output-title h2 {
    font-size: 2.1em !important;
}
"""

CACHE_PATH = "./model_cache"
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(CACHE_PATH, exist_ok=True)

model_path_d_local = snapshot_download(
    repo_id='rednote-hilab/dots.ocr',
    local_dir=os.path.join(CACHE_PATH, 'dots.ocr'),
    max_workers=20,
    local_dir_use_symlinks=False
)


def _patch_dots_processor(config_path: str) -> None:
    """Insert an ``attributes`` class attribute into ``DotsVLProcessor`` if missing.

    The file is left untouched when it does not exist, does not define
    ``class DotsVLProcessor``, or already contains an ``attributes = `` line.
    """
    if not os.path.exists(config_path):
        return

    # Explicit UTF-8: this is Python source, and the platform default
    # encoding is not guaranteed to round-trip it.
    with open(config_path, 'r', encoding="utf-8") as f:
        input_code = f.read()

    lines = input_code.splitlines()
    if "class DotsVLProcessor" not in input_code:
        return
    if any("attributes = " in line for line in lines):
        return

    output_lines = []
    for line in lines:
        output_lines.append(line)
        if line.strip().startswith("class DotsVLProcessor"):
            # Injected as the first statement of the class body.
            output_lines.append("    attributes = [\"image_processor\", \"tokenizer\"]")

    with open(config_path, 'w', encoding="utf-8") as f:
        # splitlines() dropped the line terminators; restore a final newline.
        f.write('\n'.join(output_lines) + '\n')
    print("Patched configuration_dots.py successfully.")


_patch_dots_processor(os.path.join(model_path_d_local, "configuration_dots.py"))

# Make the downloaded (and patched) remote code importable.
sys.path.append(model_path_d_local)

MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1440
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load Nanonets-OCR2-3B
MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()

# Load Nanonets-OCR2-1.5B-exp
MODEL_ID_N = "strangervisionhf/excess_layer_pruned-nanonets-1.5b"  # -> https://huggingface.co/nanonets/Nanonets-OCR2-1.5B-exp
processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
model_n = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID_N,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2"
).to(device).eval()

# Load Dots.OCR from the local, patched directory
MODEL_PATH_D = model_path_d_local
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
).eval()

# Load PaddleOCR
MODEL_ID_P = "strangervisionhf/paddle"  # -> https://huggingface.co/PaddlePaddle/PaddleOCR-VL
processor_p = AutoProcessor.from_pretrained(MODEL_ID_P, trust_remote_code=True)
model_p = AutoModelForCausalLM.from_pretrained(
    MODEL_ID_P,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to(device).eval()

# UI model name -> (processor, model) used by generate_image.
MODEL_REGISTRY = {
    "Nanonets-OCR2-3B": (processor_m, model_m),
    "Nanonets-OCR2-1.5B(exp)": (processor_n, model_n),
    "Dots.OCR": (processor_d, model_d),
    "PaddleOCR": (processor_p, model_p),
}


@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream the selected model's response for an image and a text query.

    Args:
        model_name: One of the keys of ``MODEL_REGISTRY``.
        text: The user's query / instruction.
        image: The uploaded image, or ``None`` when nothing was uploaded.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        ``(raw, formatted)`` pairs holding the text accumulated so far; both
        elements are the same string.  A single pair with an error message is
        yielded when the model name is unknown or no image was supplied.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    selected = MODEL_REGISTRY.get(model_name)
    if selected is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = selected

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    images = [image.convert("RGB")]

    if model_name == "PaddleOCR":
        # NOTE(review): this branch sends plain-text content with no image
        # placeholder; presumably the PaddleOCR processor handles the image
        # itself -- confirm against the model card.
        messages = [
            {"role": "user", "content": text}
        ]
    else:
        messages = [
            {
                "role": "user",
                "content": [{"type": "image"}, {"type": "text", "text": text}]
            }
        ]

    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "do_sample": True
    }

    errors = []

    def _run_generation() -> None:
        """Run generation; on failure, unblock the consumer and record the error."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the consumer below
            errors.append(exc)
            # Without this the ``for`` loop over the streamer would wait
            # forever for an end-of-stream signal that never arrives.
            streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer, buffer

    thread.join()
    if errors:
        raise errors[0]


image_examples = [
    ["Perform OCR on the image.", "examples/1.jpg"],
    ["Phrase the document [page].", "examples/2.jpg"],
    ["OCR the content perfectly.", "examples/3.jpg"],
]

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
    with gr.Row():
        # Left column: query, image upload and sampling controls.
        with gr.Column(scale=2):
            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
            image_upload = gr.Image(type="pil", label="Upload Image", height=320)
            image_submit = gr.Button("Submit", variant="primary")
            gr.Examples(examples=image_examples, inputs=[image_query, image_upload])

            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)

        # Right column: streamed output, rendered markdown and model picker.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
            with gr.Accordion("[Result.md]", open=False):
                formatted_output = gr.Markdown(label="Formatted Result")

            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "Dots.OCR", "Nanonets-OCR2-1.5B(exp)", "PaddleOCR"],
                label="Select Model",
                value="Nanonets-OCR2-3B"
            )
            gr.Markdown("Note: Currently, PaddleOCR VL only supports OCR inference. Structured OCR document parsing transformer inference is coming soon. [Report – Bug/Issue](https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR3/discussions/1)")

    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[raw_output, formatted_output]
    )

if __name__ == "__main__":
    demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)