# Multimodal-OCR2 / app.py
# Hugging Face Space by prithivMLmods -- commit e61207c ("Update app.py"), 16.1 kB.
import os
import random
import uuid
import time
import base64
from http import HTTPStatus
from threading import Thread
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
Qwen2_5_VLForConditionalGeneration,
AutoModelForVision2Seq,
AutoProcessor,
TextIteratorStreamer,
)
from gradio_client import utils as client_utils
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro
# --- Constants and Configuration ---
# Generation budget. MAX_MAX_NEW_TOKENS is the hard cap handed to `generate`;
# DEFAULT_MAX_NEW_TOKENS is not referenced by the code visible in this file.
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072

# Use the first CUDA device when one is available, otherwise fall back to CPU.
_DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_DEVICE_NAME)

# --- Model Loading ---
# Registries keyed by the display names listed in MODEL_CHOICES; they are
# filled in by `load_model`.
models, processors = {}, {}

# Display names offered in the model dropdown. The first entry is the default.
# `load_model` removes a name from this list when its model fails to load.
MODEL_CHOICES = [
    "Nanonets-OCR-s",
    "MonkeyOCR-Recognition",
    "Thyme-RL",
    "Typhoon-OCR-7B",
    "SmolDocling-256M-preview",
]
def load_model(model_id, processor_class, model_class, subfolder=None, model_key=''):
    """Load one processor/model pair and register both under ``model_key``.

    The processor and the model are loaded with ``from_pretrained``; the model
    is loaded with ``torch.float16`` weights, moved to ``device`` and put in
    eval mode. Both objects are registered in the module-level ``processors``
    and ``models`` dicts only when *both* loads succeed, so the two registries
    never disagree about which keys are available.

    Loading is best-effort: any failure is printed, any stale registry entry
    for ``model_key`` is dropped, and ``model_key`` is removed from
    ``MODEL_CHOICES`` so the UI does not offer a model that cannot run.

    Args:
        model_id: Repository id passed to ``from_pretrained``.
        processor_class: Class providing ``from_pretrained`` for the processor.
        model_class: Class providing ``from_pretrained`` for the model.
        subfolder: Optional sub-directory of the repository to load from.
        model_key: Display name used as the registry key.

    Returns:
        None.
    """
    print(f"Loading model: {model_key}...")

    # NOTE(review): trust_remote_code=True executes code shipped with the
    # repository; acceptable only because the ids are hard-coded by the caller.
    processor_args = {"trust_remote_code": True}
    model_args = {"trust_remote_code": True, "torch_dtype": torch.float16}
    if subfolder:
        processor_args["subfolder"] = subfolder
        model_args["subfolder"] = subfolder

    try:
        # Load into locals first so a failure in the second step cannot leave
        # a processor registered without its model.
        processor = processor_class.from_pretrained(model_id, **processor_args)
        model = model_class.from_pretrained(model_id, **model_args).to(device).eval()
    except Exception as e:
        print(f"Error loading model {model_key}: {e}")
        processors.pop(model_key, None)
        models.pop(model_key, None)
        # If a model fails to load, remove it from the choices
        if model_key in MODEL_CHOICES:
            MODEL_CHOICES.remove(model_key)
        return

    processors[model_key] = processor
    models[model_key] = model
    print(f"Successfully loaded {model_key}.")
# Load all models
# One row per model: (repo id, processor class, model class, subfolder, registry key).
# Rows are processed top to bottom, so the load order is the row order.
_MODEL_SPECS = (
    ("nanonets/Nanonets-OCR-s", AutoProcessor, Qwen2_5_VLForConditionalGeneration, None, "Nanonets-OCR-s"),
    ("echo840/MonkeyOCR", AutoProcessor, Qwen2_5_VLForConditionalGeneration, "Recognition", "MonkeyOCR-Recognition"),
    ("scb10x/typhoon-ocr-7b", AutoProcessor, Qwen2_5_VLForConditionalGeneration, None, "Typhoon-OCR-7B"),
    ("ds4sd/SmolDocling-256M-preview", AutoProcessor, AutoModelForVision2Seq, None, "SmolDocling-256M-preview"),
    ("Kwai-Keye/Thyme-RL", AutoProcessor, Qwen2_5_VLForConditionalGeneration, None, "Thyme-RL"),
)
for _repo_id, _processor_cls, _model_cls, _subfolder, _registry_key in _MODEL_SPECS:
    load_model(_repo_id, _processor_cls, _model_cls, subfolder=_subfolder, model_key=_registry_key)
# --- Preprocessing and Helper Functions ---
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Return an RGB copy of ``image`` with a border added on every side.

    The left/right border is a fraction of the width and the top/bottom border
    a fraction of the height; each fraction is drawn independently with
    ``random.uniform(min_percent, max_percent)``. With the default arguments
    both bounds are 0.1, so the border is always 10% of the matching dimension.
    The border is filled with the colour of the top-left pixel.
    """
    rgb = image.convert("RGB")
    img_w, img_h = rgb.size

    # Draw the horizontal fraction before the vertical one so the random
    # stream is consumed in the same order on every call.
    border_x = int(img_w * random.uniform(min_percent, max_percent))
    border_y = int(img_h * random.uniform(min_percent, max_percent))

    corner_colour = rgb.getpixel((0, 0))
    return ImageOps.expand(
        rgb,
        border=(border_x, border_y, border_x, border_y),
        fill=corner_colour,
    )
def downsample_video(video_path, num_frames=10):
    """Sample evenly spaced frames from a video as RGB PIL images.

    Args:
        video_path: Path of the video file on disk.
        num_frames: Maximum number of frames to sample. Never more samples
            are requested than the video reports frames, so short clips are
            not padded out with repeated frames.

    Returns:
        A list of ``PIL.Image`` frames in playback order. The list is empty
        when the path is missing or does not exist, when ``num_frames`` is not
        positive, or when the video reports no frames. Frames that fail to
        decode are skipped, so the list may be shorter than requested.
    """
    if not video_path or not os.path.exists(video_path) or num_frames <= 0:
        return []

    frames = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return frames

        # Cap the sample count at the frame count: linspace with more samples
        # than frames would repeat indices and therefore duplicate frames.
        sample_count = min(num_frames, total_frames)
        frame_indices = np.linspace(0, total_frames - 1, sample_count, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            success, image = vidcap.read()
            if success:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    finally:
        # Always release the capture handle, even if decoding raises.
        vidcap.release()
    return frames
def format_history_for_model(history, selected_model):
    """Extract the text and images of the most recent user message.

    Only the last ``"user"`` entry of ``history`` is considered. Its
    ``"content"`` parts of type ``"text"`` supply the prompt (the last such
    part wins) and parts of type ``"file"`` supply file paths. Image files are
    opened with PIL, video files are expanded into sampled frames, and files
    of any other or unknown type are ignored.

    For the ``"SmolDocling-256M-preview"`` model, images are padded with
    ``add_random_padding`` when the prompt mentions ``"OTSL"`` or ``"code"``.

    Args:
        history: List of message dicts with ``"role"`` and ``"content"`` keys.
        selected_model: Display name of the model that will receive the input.

    Returns:
        ``(text, images, selected_model)``, or ``(None, [], "")`` when
        ``history`` contains no user message.
    """
    last_user_message = next(
        (item for item in reversed(history) if item["role"] == "user"), None
    )
    if not last_user_message:
        return None, [], ""

    text = ""
    files = []
    for content_part in last_user_message["content"]:
        if content_part["type"] == "text":
            # Normalise a missing prompt to "" so the substring checks below
            # never run against None.
            text = content_part["content"] or ""
        elif content_part["type"] == "file":
            # A message without attachments may carry None instead of a list.
            files.extend(content_part["content"] or [])

    images = []
    for file_path in files:
        # The mimetype lookup can yield None for unrecognised extensions;
        # treat that as "neither image nor video" instead of crashing.
        mime_type = client_utils.get_mimetype(file_path) or ""
        if mime_type.startswith("image"):
            images.append(Image.open(file_path))
        elif mime_type.startswith("video"):
            images.extend(downsample_video(file_path))

    # Apply model-specific preprocessing
    if selected_model == "SmolDocling-256M-preview" and ("OTSL" in text or "code" in text):
        images = [add_random_padding(img) for img in images]
    return text, images, selected_model
# --- Gradio Events and Application Logic ---
class Gradio_Events:
    """Event handlers wired to the chat UI components.

    Every handler is a ``@staticmethod`` working on the per-session
    ``state_value`` dict, which uses these keys:

    * ``"conversation_id"``: key of the active conversation, ``""`` if none.
    * ``"conversations"``: list of ``{"label", "key"}`` sidebar items.
    * ``"conversation_contexts"``: maps a conversation key to
      ``{"history": [...], "selected_model": str}``.
    * ``"pending_model"`` (optional): model chosen in the dropdown while no
      conversation is active; consumed when the next conversation is created.

    The handlers reference the module-level components ``chatbot``, ``state``,
    ``input_comp``, ``conversations`` and ``add_conversation_btn``, which must
    exist by the time an event fires.
    """

    @staticmethod
    def submit(state_value):
        """Stream the model's answer to the latest user message.

        Generator yielding ``{chatbot: ..., state: ...}`` update dicts: once
        with a loading placeholder, once per streamed text chunk, and once
        with the final text (or an error message if generation failed).
        Only the most recent user message is sent to the model.
        """
        conv_id = state_value["conversation_id"]
        context = state_value["conversation_contexts"][conv_id]
        history = context["history"]
        model_name = context.get("selected_model", MODEL_CHOICES[0])
        processor = processors.get(model_name)
        model = models.get(model_name)
        if processor is None or model is None:
            history.append({"role": "assistant", "content": [{"type": "text", "content": f"Error: Model '{model_name}' not loaded."}]})
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            return

        text, images, _ = format_history_for_model(history, model_name)
        if not text and not images:
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            return

        # Placeholder assistant message shown while the model is working.
        history.append({
            "role": "assistant",
            "content": [],
            "key": str(uuid.uuid4()),
            "loading": True,
        })
        yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

        try:
            messages = [{"role": "user", "content": []}]
            if images:
                messages[0]["content"].extend([{"type": "image"}] * len(images))
            messages[0]["content"].append({"type": "text", "text": text or "Describe the media."})
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            # Pass None rather than an empty list for text-only prompts so the
            # processor skips its image pipeline instead of receiving [].
            inputs = processor(text=prompt, images=images or None, return_tensors="pt").to(device)
            streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_MAX_NEW_TOKENS}

            generation_errors = []

            def _generate():
                # Runs in the worker thread. If generation fails, close the
                # streamer so the consuming loop below terminates instead of
                # blocking forever on a queue that will never be fed.
                try:
                    model.generate(**generation_kwargs)
                except Exception as exc:
                    generation_errors.append(exc)
                    streamer.end()

            thread = Thread(target=_generate)
            thread.start()

            buffer = ""
            for new_text in streamer:
                buffer += new_text.replace("<|im_end|>", "")
                history[-1]["content"] = [{"type": "text", "content": buffer}]
                history[-1]["loading"] = True
                yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

            thread.join()
            if generation_errors:
                # Surface the worker failure through the handler below.
                raise generation_errors[0]

            history[-1]["loading"] = False
            # Final post-processing, especially for models like SmolDocling.
            # Remove the marker before stripping so whitespace that preceded
            # it does not survive at the end of the message.
            final_content = buffer.replace("<end_of_utterance>", "").strip()
            history[-1]["content"] = [{"type": "text", "content": final_content}]
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
        except Exception as e:
            print(f"Error during model generation: {e}")
            history[-1]["loading"] = False
            history[-1]["content"] = [{"type": "text", "content": f'<span style="color: red;">An error occurred: {e}</span>'}]
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

    @staticmethod
    def add_message(input_value, state_value):
        """Append the submitted message and stream the model's reply.

        Creates a new conversation when none is active, records the user's
        text and files in its history, then yields, in order: the
        ``preprocess_submit`` update, every update from ``submit``, and the
        ``postprocess_submit`` update.
        """
        text = input_value["text"]
        files = input_value["files"]
        if not state_value["conversation_id"]:
            random_id = str(uuid.uuid4())
            state_value["conversation_id"] = random_id
            state_value["conversations"].append({"label": text or "New Chat", "key": random_id})
            state_value["conversation_contexts"][random_id] = {
                "history": [],
                # Honour a model picked before the conversation existed;
                # otherwise fall back to the default model.
                "selected_model": state_value.pop("pending_model", None) or MODEL_CHOICES[0],
            }
        conv_id = state_value["conversation_id"]
        history = state_value["conversation_contexts"][conv_id]["history"]
        history.append({
            "key": str(uuid.uuid4()),
            "role": "user",
            "content": [{"type": "file", "content": files}, {"type": "text", "content": text}]
        })
        yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)
        for chunk in Gradio_Events.submit(state_value):
            yield chunk
        yield Gradio_Events.postprocess_submit(state_value)

    @staticmethod
    def preprocess_submit(clear_input=True):
        """Build the handler that puts the UI into its "busy" state.

        The returned ``handler(state_value)`` produces an update dict that
        marks the input as loading (emptying it when ``clear_input`` is true),
        refreshes the conversation list, disables the new-conversation button
        and shows the active conversation's history.
        """
        def handler(state_value):
            conv_id = state_value["conversation_id"]
            history = state_value["conversation_contexts"][conv_id]["history"]
            return {
                input_comp: gr.update(value={'text': '', 'files': []} if clear_input else {}, loading=True),
                conversations: gr.update(active_key=conv_id, items=state_value["conversations"]),
                add_conversation_btn: gr.update(disabled=True),
                chatbot: gr.update(value=history),
                state: gr.update(value=state_value),
            }
        return handler

    @staticmethod
    def postprocess_submit(state_value):
        """Return the update dict that takes the UI out of its "busy" state."""
        conv_id = state_value["conversation_id"]
        history = state_value["conversation_contexts"][conv_id]["history"]
        return {
            input_comp: gr.update(loading=False),
            add_conversation_btn: gr.update(disabled=False),
            chatbot: gr.update(value=history),
            state: gr.update(value=state_value),
        }

    @staticmethod
    def apply_prompt(e: gr.EventData):
        """Copy the selected welcome prompt into the input component."""
        # Example format: {"description": "Query text", "urls": ["path/to/image.png"]}
        prompt_data = e._data["payload"][0]["value"]
        return gr.update(value={'text': prompt_data['description'], 'files': prompt_data['urls']})

    @staticmethod
    def new_chat(state_value):
        """Deselect the active conversation and reset the view.

        Returns updates for, in order: the conversation list, the chatbot,
        the state and the model selector (reset to the default model).
        """
        state_value["conversation_id"] = ""
        # The selector is reset to the default below, so a model remembered
        # for a not-yet-created conversation must not outlive this reset.
        state_value.pop("pending_model", None)
        return gr.update(active_key=""), gr.update(value=None), gr.update(value=state_value), gr.update(value=MODEL_CHOICES[0])

    @staticmethod
    def select_conversation(state_value, e: gr.EventData):
        """Switch to the conversation chosen in the sidebar.

        Skips every output when the chosen key is already active or unknown.
        Otherwise returns updates for, in order: the conversation list, the
        chatbot, the state and the model selector.
        """
        active_key = e._data["payload"][0]
        if state_value["conversation_id"] == active_key or active_key not in state_value["conversation_contexts"]:
            return gr.skip()
        state_value["conversation_id"] = active_key
        context = state_value["conversation_contexts"][active_key]
        return gr.update(active_key=active_key), gr.update(value=context["history"]), gr.update(value=state_value), gr.update(value=context.get("selected_model", MODEL_CHOICES[0]))

    @staticmethod
    def on_model_change(model_name, state_value):
        """Record the model chosen in the dropdown and return the state.

        With an active conversation the choice is stored on that
        conversation. Without one it is kept under ``"pending_model"`` so
        that ``add_message`` applies it to the conversation it creates next,
        instead of silently falling back to the default model.
        """
        conv_id = state_value["conversation_id"]
        if conv_id:
            state_value["conversation_contexts"][conv_id]["selected_model"] = model_name
        else:
            state_value["pending_model"] = model_name
        return state_value
# --- UI Layout and Components ---
# Stylesheet passed to `gr.Blocks(css=...)`. Kept as one rule fragment per
# entry; the leading and trailing empty entries reproduce the newline that
# opens and closes the stylesheet text.
_CSS_LINES = (
    "",
    ".gradio-container { padding: 0 !important; }",
    "main.fillable { padding: 0 !important; }",
    "#chatbot_container { height: calc(100vh - 80px); max-height: 1000px; }",
    "#conversations_sidebar .chatbot-conversations {",
    "height: 100vh; background-color: var(--ms-gr-ant-color-bg-layout); padding: 8px;",
    "}",
    "#main_chat_area { padding: 16px; height: 100%; }",
    "",
)
css = "\n".join(_CSS_LINES)
# Define welcome prompts based on available examples
# Every example image lives under the same folder of the Space repository.
_EXAMPLE_IMAGE_ROOT = "https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/"

# One row per welcome card: (title, prompt text, image file name).
_WELCOME_EXAMPLES = (
    ("Reconstruct Table", "Reconstruct the doc [table] as it is.", "0.png"),
    ("Describe Image", "Describe the image!", "8.png"),
    ("OCR Image", "OCR the image", "2.jpg"),
    ("Convert to Docling", "Convert this page to docling", "1.png"),
    ("Convert Chart", "Convert chart to OTSL.", "4.png"),
    ("Extract Code", "Convert code to text", "5.jpg"),
)

# Each prompt is a dict with "title", "description" (the text placed in the
# input box) and "urls" (a single-item list holding the example image URL).
welcome_prompts = [
    {
        "title": card_title,
        "description": prompt_text,
        "urls": [_EXAMPLE_IMAGE_ROOT + image_name],
    }
    for card_title, prompt_text, image_name in _WELCOME_EXAMPLES
]
# Build the application: a conversation sidebar on the left and the chat area
# (chatbot, multimodal input, model selector) on the right.
# NOTE(review): the nesting below was reconstructed from the statement order
# and the "Left Sidebar" / "Right Main Chat Area" markers -- confirm it against
# the deployed layout.
with gr.Blocks(css=css, fill_width=True, title="Multimodal OCR2") as demo:
    # Per-session state shared by the Gradio_Events handlers:
    #   conversation_contexts: conversation key -> {"history", "selected_model"}
    #   conversations:         sidebar items ({"label", "key"})
    #   conversation_id:       key of the active conversation ("" when none)
    state = gr.State({
        "conversation_contexts": {},
        "conversations": [],
        "conversation_id": "",
    })
    with ms.Application(), antdx.XProvider(), ms.AutoLoading():
        with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot_container"):
            # Left Sidebar for Conversations
            with antd.Col(md=dict(flex="0 0 260px"), elem_id="conversations_sidebar"):
                with ms.Div(elem_classes="chatbot-conversations"):
                    with antd.Flex(vertical=True, gap="small", elem_style=dict(height="100%")):
                        gr.Markdown("### OCR Conversations")
                        with antd.Button(color="primary", variant="filled", block=True) as add_conversation_btn:
                            ms.Text("New Conversation")
                            with ms.Slot("icon"): antd.Icon("PlusOutlined")
                        with antdx.Conversations() as conversations:
                            pass  # Handled by events
            # Right Main Chat Area
            with antd.Col(flex=1, elem_style=dict(height="100%")):
                with antd.Flex(vertical=True, gap="small", elem_id="main_chat_area"):
                    gr.Markdown("## Multimodal OCR2")
                    chatbot = pro.Chatbot(
                        height="calc(100vh - 200px)",
                        welcome_config=pro.Chatbot.WelcomeConfig(prompts=welcome_prompts, title="Start by selecting an example:")
                    )
                    with pro.MultimodalInput(placeholder="Ask a question about your image or video...") as input_comp:
                        with ms.Slot("prefix"):
                            # Offers only the names still present in
                            # MODEL_CHOICES; the first one is preselected.
                            model_selector = gr.Dropdown(
                                choices=MODEL_CHOICES,
                                value=MODEL_CHOICES[0],
                                label="Select Model",
                                container=False
                            )
    # --- Event Wiring ---
    # "New Conversation" button: deselect the active conversation and reset
    # the chatbot and the model selector.
    add_conversation_btn.click(
        fn=Gradio_Events.new_chat,
        inputs=[state],
        outputs=[conversations, chatbot, state, model_selector]
    )
    # Sidebar selection: show the chosen conversation's history and model.
    conversations.active_change(
        fn=Gradio_Events.select_conversation,
        inputs=[state],
        outputs=[conversations, chatbot, state, model_selector]
    )
    # Welcome card click: copy the card's prompt text and image URL into the input.
    chatbot.welcome_prompt_select(
        fn=Gradio_Events.apply_prompt,
        inputs=[],
        outputs=[input_comp]
    )
    # Message submit: record the message, then stream the model's reply.
    # The returned event handle is kept in `submit_event`; it is not used
    # again in the code visible here.
    submit_event = input_comp.submit(
        fn=Gradio_Events.add_message,
        inputs=[input_comp, state],
        outputs=[input_comp, add_conversation_btn, conversations, chatbot, state]
    )
    # Model dropdown change: store the choice in the session state.
    model_selector.change(
        fn=Gradio_Events.on_model_change,
        inputs=[model_selector, state],
        outputs=[state]
    )
if __name__ == "__main__":
    # Script entry point: enable queuing on the app, then launch it with
    # error display and debug mode switched on.
    queued_app = demo.queue()
    queued_app.launch(show_error=True, debug=True)