# Multimodal-OCR3 / app.py
# Author: prithivMLmods
# Commit 33cd763 ("update app", verified) — 8.44 kB
# (Header captured from the Hugging Face Spaces file view: raw / history / blame.)
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread
from typing import Iterable
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import requests
from transformers import (
Qwen2_5_VLForConditionalGeneration,
AutoModelForCausalLM,
AutoProcessor,
TextIteratorStreamer,
)
from transformers.image_utils import load_image
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from huggingface_hub import snapshot_download
# --- Theme and CSS Definition ---
# Register a custom "steel_blue" palette on gradio's `colors` module so the
# theme below can refer to it exactly like a built-in hue. Shades run from
# c50 (lightest) to c950 (darkest); c500 is CSS "SteelBlue".
colors.steel_blue = colors.Color(
    name="steel_blue",
    **{
        "c50": "#EBF3F8",
        "c100": "#D3E5F0",
        "c200": "#A8CCE1",
        "c300": "#7DB3D2",
        "c400": "#529AC3",
        "c500": "#4682B4",
        "c600": "#3E72A0",
        "c700": "#36638C",
        "c800": "#2E5378",
        "c900": "#264364",
        "c950": "#1E3450",
    },
)
class SteelBlueTheme(Soft):
    """Soft-derived Gradio theme with a steel-blue accent.

    Hues: gray primary, steel-blue secondary, slate neutral. Fonts: Outfit for
    UI text and IBM Plex Mono for monospace. On top of the ``Soft`` base it
    applies gradient page backgrounds, steel-blue primary buttons and sliders,
    and heavier block borders and shadows.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Page and panel surfaces (light / dark variants).
        surface_overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
        }
        # Primary buttons: white text on a secondary-hue gradient; the hover
        # state shifts the gradient one shade (darker in light mode, lighter in dark).
        button_overrides = {
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
        }
        # Sliders, block chrome and accents.
        misc_overrides = {
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        # This class does not override `set`, so self.set is the inherited method.
        self.set(**surface_overrides, **button_overrides, **misc_overrides)
# Extra CSS enlarging the main page title and the "Output" heading; it targets
# the elem_id values given to those Markdown components in the Blocks layout.
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
"""
# One theme instance, passed to gr.Blocks(theme=...) below.
steel_blue_theme = SteelBlueTheme()
# Constants for text generation
# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 5120
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 3072
# Overridable through the MAX_INPUT_TOKEN_LENGTH environment variable (default "4096").
# NOTE(review): not referenced by any code visible in this file — confirm whether
# prompts were meant to be truncated to this length.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# First CUDA device if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load Nanonets-OCR2-3B (a Qwen2.5-VL model class) in float16, moved to `device`
# and switched to inference mode.
MODEL_ID_M = "nanonets/Nanonets-OCR2-3B"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to(device).eval()
# Load Dots.OCR: download the Hub repo into a local directory first, then load
# the model and processor from that path (repo-provided remote code enabled).
# Note that model_d.config.name_or_path will therefore be the local path, not the repo id.
MODEL_ID_D = "rednote-hilab/dots.ocr"
model_path_d = "./models/dots-ocr-local"
snapshot_download(
    repo_id=MODEL_ID_D,
    local_dir=model_path_d,
    # NOTE(review): deprecated in recent huggingface_hub releases (ignored with a warning).
    local_dir_use_symlinks=False,
)
model_d = AutoModelForCausalLM.from_pretrained(
    model_path_d,
    # FlashAttention-2 only when running on CUDA; eager attention otherwise.
    attn_implementation="flash_attention_2" if "cuda" in device.type else "eager",
    torch_dtype=torch.bfloat16,
    # Placement is chosen by device_map rather than the `device` defined above.
    device_map="auto",
    trust_remote_code=True
)
processor_d = AutoProcessor.from_pretrained(
    model_path_d,
    trust_remote_code=True
)
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int, temperature: float, top_p: float, top_k: int, repetition_penalty: float):
    """Stream a response for one image from the selected OCR model.

    Args:
        model_name: "Nanonets-OCR2-3B" or "Dots.OCR"; any other value yields an
            error message and stops.
        text: User query, placed after the image in a single user chat turn.
        image: Uploaded PIL image; ``None`` yields an upload prompt and stops.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Forwarded unchanged to ``model.generate``.

    Yields:
        ``(raw, formatted)`` tuples. Both elements are the same accumulated
        decoded text: one feeds the raw Textbox, the other the Markdown view.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread. It is
        re-raised after the partial output has been streamed, instead of
        leaving the request blocked on the streamer.
    """
    if model_name == "Nanonets-OCR2-3B":
        processor, model = processor_m, model_m
    elif model_name == "Dots.OCR":
        processor, model = processor_d, model_d
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    images = [image]
    # One image placeholder per image, followed by the text query.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}] * len(images) + [
                {"type": "text", "text": text}
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # Intended to pass the tokenizer's EOS id explicitly for Dots.OCR.
    # NOTE(review): model_d is loaded from "./models/dots-ocr-local", whose
    # lowercased path contains "dots-ocr", not "dots.ocr", so this branch likely
    # never fires. Verify the correct stop token before keying it on model_name.
    if "dots.ocr" in model.config.name_or_path.lower():
        generation_kwargs["eos_token_id"] = processor.tokenizer.eos_token_id

    errors: list[Exception] = []

    def _run_generate():
        """Run generation; on failure, record the error and unblock the streamer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Without the stop signal, the consumer loop below would wait forever.
            streamer.end()

    thread = Thread(target=_run_generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Strip end-of-turn markers that can survive skip_special_tokens.
        buffer += new_text.replace("<|im_end|>", "").replace("</s>", "")
        yield buffer, buffer
    thread.join()
    if errors:
        raise errors[0]
    # The formatted output is the same as the raw output in this version.
    yield buffer, buffer
# Example rows for gr.Examples: each row is [query text, image path], matching
# the order of its `inputs=[query_input, image_upload]`.
image_examples = [
    [query, path]
    for query, path in (
        ("Reconstruct the doc [table] as it is.", "images/0.png"),
        ("Describe the image!", "images/8.png"),
        ("OCR the image", "images/2.jpg"),
        ("Convert this page to markdown", "images/1.png"),
    )
]
# Create the Gradio Interface
# Two-column layout: inputs, examples and sampling options on the left; the
# streamed raw text and its Markdown rendering on the right.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Multimodal Image OCR**", elem_id="main-title")
    with gr.Row():
        with gr.Column(scale=2):
            # The default must be one of `choices`: generate_image() dispatches
            # on these exact names and rejects anything else as
            # "Invalid model selected."
            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "Dots.OCR"],
                label="Select Model",
                value="Nanonets-OCR2-3B",
            )
            query_input = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
            image_upload = gr.Image(type="pil", label="Upload Image", height=320)
            submit_button = gr.Button("Submit", variant="primary")
            gr.Examples(examples=image_examples, inputs=[query_input, image_upload])
            # Values here are forwarded unchanged to model.generate().
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=18, show_copy_button=True)
            formatted_output = gr.Markdown(label="Formatted Output (Result.md)")
    # generate_image is a generator, so both outputs update as tokens stream in.
    submit_button.click(
        fn=generate_image,
        inputs=[model_choice, query_input, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[raw_output, formatted_output],
    )
if __name__ == "__main__":
    # Enable request queuing (at most 50 waiting events), then start the
    # server with server-side rendering off and errors shown in the UI.
    app = demo.queue(max_size=50)
    app.launch(ssr_mode=False, show_error=True)