Multimodal-OCR3 / app.py
prithivMLmods's picture
Update app.py
6caf4be verified
raw
history blame
8.57 kB
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread
from typing import Iterable
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import requests
from transformers import (
AutoTokenizer,
AutoProcessor,
TextIteratorStreamer,
)
from transformers.image_utils import load_image
# The custom model class is imported via trust_remote_code=True
from transformers import AutoModelForImageTextToText
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from docling_core.types.doc import DoclingDocument, DocTagsDocument
import re
import ast
import html
# --- Theme and CSS Definition ---
# Custom "steel_blue" palette, attached to gradio's colors module so the theme
# below can use it exactly like a built-in palette (``colors.steel_blue``).
# Keys are the shade slots expected by ``colors.Color``; c500 is SteelBlue.
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
class SteelBlueTheme(Soft):
    """Gradio ``Soft`` theme recolored with a gray / steel-blue palette.

    The keyword-only arguments mirror the ``Soft`` constructor and are
    forwarded unchanged. A fixed set of theme-variable overrides is then
    applied through ``set``: gradient backgrounds, white primary-button text,
    steel-blue buttons and sliders, 3px block borders, and large drop shadows.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        base_kwargs = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_kwargs)

        # Strings prefixed with "*" reference other theme variables
        # (Gradio theming syntax), e.g. "*secondary_500".
        theme_overrides = {
            # Page and block backgrounds
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Sliders
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Blocks, shadows, and spacing
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**theme_overrides)


steel_blue_theme = SteelBlueTheme()
# Extra CSS enlarging the two headings tagged with elem_id="main-title" and
# elem_id="output-title" in the UI.
css = (
    "\n"
    "#main-title h1 {\n"
    "font-size: 2.3em !important;\n"
    "}\n"
    "#output-title h2 {\n"
    "font-size: 2.1em !important;\n"
    "}\n"
)

# Text-generation limits. MAX_INPUT_TOKEN_LENGTH can be overridden through the
# environment variable of the same name.
# NOTE(review): MAX_INPUT_TOKEN_LENGTH is not referenced in the visible code.
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def _attn_implementation() -> str:
    """Pick the attention backend used when loading the OCR models.

    FlashAttention-2 is requested only when running on CUDA *and* the
    ``flash_attn`` package is importable. Otherwise PyTorch's built-in SDPA
    attention is used, so the app can still start (e.g. on the CPU fallback
    selected by ``device``) instead of failing inside ``from_pretrained``.
    """
    from importlib.util import find_spec

    if device == "cuda" and find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


def _load_ocr_model(model_id: str):
    """Load the processor and model for ``model_id`` onto ``device`` in eval mode.

    Returns a ``(processor, model)`` tuple. Weights are loaded in float16 on
    CUDA (unchanged from before). On CPU they are loaded in float32, because
    half precision is generally slow or only partially supported there.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        dtype=torch.float16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
        attn_implementation=_attn_implementation(),
    ).to(device).eval()
    return processor, model


# Load Nanonets-OCR2-3B
MODEL_ID_3B = "nanonets/Nanonets-OCR2-3B"
processor_3b, model_3b = _load_ocr_model(MODEL_ID_3B)

# Load Nanonets-OCR2-1.5B-exp
MODEL_ID_1_5B = "nanonets/Nanonets-OCR2-1.5B-exp"
processor_1_5b, model_1_5b = _load_ocr_model(MODEL_ID_1_5B)
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream the selected model's answer for ``image`` and prompt ``text``.

    Args:
        model_name: Either "Nanonets-OCR2-3B" or "Nanonets-OCR2-1.5B-exp".
        text: User prompt placed after the image in the chat message.
        image: PIL image to process; ``None`` when nothing was uploaded.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling controls forwarded to ``model.generate``.

    Yields:
        ``(raw, markdown)`` pairs that both hold the text accumulated so far,
        so the UI can fill the raw textbox and the Markdown view. An unknown
        model name or a missing image yields one explanatory message instead.

    Raises:
        gr.Error: if ``model.generate`` fails in the worker thread.
    """
    if model_name == "Nanonets-OCR2-3B":
        processor, model = processor_3b, model_3b
    elif model_name == "Nanonets-OCR2-1.5B-exp":
        processor, model = processor_1_5b, model_1_5b
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return
    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    # The processor returns CPU tensors. They must live on the model's device,
    # otherwise generate() fails with a device-mismatch error.
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)

    tokenizer = getattr(processor, "tokenizer", processor)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        # temperature / top_p / top_k only take effect when sampling is on.
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

    errors = []

    def _generate():
        # Runs in a worker thread. If generate() raises, the streamer never
        # receives its end signal and the loop below would block forever, so
        # record the error and close the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        yield buffer, buffer
    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]
# Define examples for image inference.
# Each row is [prompt, image path]; gr.Examples fills the query textbox and the
# image upload with the two values, in that order.
image_examples = [
    ["Reconstruct the doc [table] as it is.", "images/0.png"],
    ["Describe the image!", "images/8.png"],
    ["OCR the image", "images/2.jpg"],
    ["Convert this page to docling", "images/1.png"],
    ["Convert this page to docling", "images/3.png"],
    ["Convert chart to OTSL.", "images/4.png"],
    ["Convert code to text", "images/5.jpg"],
    ["Convert this table to OTSL.", "images/6.jpg"],
    ["Convert formula to LaTeX.", "images/7.jpg"],
]
# Create the Gradio Interface
# Layout: a left column (scale 2) holding the query, image upload, submit
# button, examples and an "Advanced options" accordion of sampling sliders, and
# a right column (scale 3) holding the streamed raw text, a Markdown rendering
# of the same text, and the model selector. Components render in the order they
# are created, so statement order here is the on-screen order.
# NOTE(review): nesting below was reconstructed from a whitespace-stripped copy.
# Confirm that model_choice sits in the right column.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **Multimodal OCR3**", elem_id="main-title")
    with gr.Row():
        with gr.Column(scale=2):
            # Image Inference Components
            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
            # type="pil" hands generate_image a PIL image (or None when empty).
            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
            image_submit = gr.Button("Submit", variant="primary")
            # Each example row is [prompt, image path] and fills the two inputs above.
            gr.Examples(examples=image_examples, inputs=[image_query, image_upload])
            with gr.Accordion("Advanced options", open=False):
                # These slider values feed generate_image's sampling arguments.
                # The UI default for max new tokens (DEFAULT_MAX_NEW_TOKENS)
                # supersedes the function's own default of 1024.
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            # Receives the first element of each (raw, markdown) pair yielded
            # by generate_image.
            raw_output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
            with gr.Accordion("(Result.md)", open=True):
                # Receives the second element: the same text rendered as Markdown.
                formatted_output = gr.Markdown(label="(Result.md)")
            # The choice strings must match the names checked in generate_image.
            model_choice = gr.Radio(
                choices=["Nanonets-OCR2-3B", "Nanonets-OCR2-1.5B-exp"],
                label="Select Model",
                value="Nanonets-OCR2-3B"
            )
    # Inputs are passed positionally, so their order must match
    # generate_image(model_name, text, image, max_new_tokens, temperature,
    # top_p, top_k, repetition_penalty).
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[raw_output, formatted_output]
    )
if __name__ == "__main__":
    # Enable the request queue (at most 50 waiting requests), then start the
    # server with server-side rendering off and errors shown in the UI.
    queued_app = demo.queue(max_size=50)
    queued_app.launch(ssr_mode=False, show_error=True)