import os
import random
import uuid
import time
import base64
from http import HTTPStatus
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2

from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from gradio_client import utils as client_utils
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro

# --- Constants and Configuration ---
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Model Loading ---
# Dictionaries holding the loaded models and processors, keyed by model name.
models = {}
processors = {}

MODEL_CHOICES = [
    "Nanonets-OCR-s",
    "MonkeyOCR-Recognition",
    "Thyme-RL",
    "Typhoon-OCR-7B",
    "SmolDocling-256M-preview",
]


def load_model(model_id, processor_class, model_class, subfolder=None, model_key=''):
    """Load a model and its processor into the global registries.

    If loading fails, the model is removed from MODEL_CHOICES so the UI does not offer it.
    """
    print(f"Loading model: {model_key}...")
    try:
        processor_args = {"trust_remote_code": True}
        model_args = {"trust_remote_code": True, "torch_dtype": torch.float16}
        if subfolder:
            processor_args["subfolder"] = subfolder
            model_args["subfolder"] = subfolder

        processors[model_key] = processor_class.from_pretrained(model_id, **processor_args)
        models[model_key] = model_class.from_pretrained(model_id, **model_args).to(device).eval()
        print(f"Successfully loaded {model_key}.")
    except Exception as e:
        print(f"Error loading model {model_key}: {e}")
        # If a model fails to load, remove it from the choices
        if model_key in MODEL_CHOICES:
            MODEL_CHOICES.remove(model_key)


# Load all models
load_model("nanonets/Nanonets-OCR-s", AutoProcessor, Qwen2_5_VLForConditionalGeneration,
           model_key="Nanonets-OCR-s")
load_model("echo840/MonkeyOCR", AutoProcessor, Qwen2_5_VLForConditionalGeneration,
           subfolder="Recognition", model_key="MonkeyOCR-Recognition")
load_model("scb10x/typhoon-ocr-7b", AutoProcessor, Qwen2_5_VLForConditionalGeneration,
           model_key="Typhoon-OCR-7B")
load_model("ds4sd/SmolDocling-256M-preview", AutoProcessor, AutoModelForVision2Seq,
           model_key="SmolDocling-256M-preview")
load_model("Kwai-Keye/Thyme-RL", AutoProcessor, Qwen2_5_VLForConditionalGeneration,
           model_key="Thyme-RL")


# --- Preprocessing and Helper Functions ---
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Pad an image on every side, filling with its top-left pixel colour.

    The padding on each axis is a random fraction in [min_percent, max_percent] of that
    dimension. With the default arguments both bounds are 0.1, so the padding is always 10%.
    """
    image = image.convert("RGB")
    width, height = image.size
    pad_w = int(width * random.uniform(min_percent, max_percent))
    pad_h = int(height * random.uniform(min_percent, max_percent))
    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h),
                                   fill=image.getpixel((0, 0)))
    return padded_image


def downsample_video(video_path, num_frames=10):
    """Sample up to num_frames evenly spaced frames from a video as RGB PIL Images."""
    if not os.path.exists(video_path):
        return []
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total_frames > 0:
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            success, image = vidcap.read()
            if success:
                # OpenCV decodes frames as BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    vidcap.release()
    return frames


def format_history_for_model(history, selected_model):
    """Extract the text and images from the latest user turn in the history.

    Image files are opened directly and videos are expanded into sampled frames.
    Returns (text, images, selected_model), or (None, [], "") if there is no user turn.
    """
    last_user_message = next((item for item in reversed(history) if item["role"] == "user"), None)
    if not last_user_message:
        return None, [], ""

    text = ""
    files = []
    images = []

    for content_part in last_user_message["content"]:
        if content_part["type"] == "text":
            text = content_part["content"]
        elif content_part["type"] == "file":
            files.extend(content_part["content"])

    for file_path in files:
        # get_mimetype() returns None for unrecognised extensions, so check before startswith().
        mime_type = client_utils.get_mimetype(file_path)
        if mime_type and mime_type.startswith("image"):
            images.append(Image.open(file_path))
        elif mime_type and mime_type.startswith("video"):
            images.extend(downsample_video(file_path))

    # Apply model-specific preprocessing
    if selected_model == "SmolDocling-256M-preview":
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]

    return text, images, selected_model


# --- Gradio Events and Application Logic ---
class Gradio_Events:

    @staticmethod
    def submit(state_value):
        conv_id = state_value["conversation_id"]
        context = state_value["conversation_contexts"][conv_id]
        history = context["history"]
        model_name = context.get("selected_model", MODEL_CHOICES[0])

        processor = processors.get(model_name)
        model = models.get(model_name)

        if not processor or not model:
            history.append({"role": "assistant",
                            "content": [{"type": "text", "content": f"Error: Model '{model_name}' not loaded."}]})
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            return

        text, images, _ = format_history_for_model(history, model_name)

        if not text and not images:
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            return

        history.append({
            "role": "assistant",
            "content": [],
            "key": str(uuid.uuid4()),
            "loading": True,
        })
        yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

        try:
            # One image placeholder per image or video frame, followed by the text prompt.
            messages = [{"role": "user", "content": []}]
            if images:
                messages[0]["content"].extend([{"type": "image"}] * len(images))
            messages[0]["content"].append({"type": "text", "text": text or "Describe the media."})

            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            # Pass images=None rather than an empty list for text-only requests.
            inputs = processor(text=prompt, images=images or None, return_tensors="pt").to(device)

            # Generate in a background thread and consume decoded text from the streamer.
            streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_MAX_NEW_TOKENS}
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            buffer = ""
            for new_text in streamer:
                buffer += new_text.replace("<|im_end|>", "")
                history[-1]["content"] = [{"type": "text", "content": buffer}]
                history[-1]["loading"] = True
                yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

            history[-1]["loading"] = False
            # Final clean-up. If a model (for example SmolDocling) emits an end-of-turn
            # marker that is not a registered special token, strip it here as well.
            final_content = buffer.strip()
            history[-1]["content"] = [{"type": "text", "content": final_content}]
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

        except Exception as e:
            print(f"Error during model generation: {e}")
            history[-1]["loading"] = False
            history[-1]["content"] = [{"type": "text", "content": f'An error occurred: {e}'}]
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}

    @staticmethod
    def add_message(input_value, model_name, state_value):
        text = input_value["text"]
        files = input_value["files"]

        if not state_value["conversation_id"]:
            random_id = str(uuid.uuid4())
            state_value["conversation_id"] = random_id
            state_value["conversations"].append({"label": text or "New Chat", "key": random_id})
            state_value["conversation_contexts"][random_id] = {
                "history": [],
                # Start the conversation with the model currently shown in the dropdown.
                "selected_model": model_name or MODEL_CHOICES[0],
            }

        conv_id = state_value["conversation_id"]
        history = state_value["conversation_contexts"][conv_id]["history"]
        history.append({
            "key": str(uuid.uuid4()),
            "role": "user",
            "content": [{"type": "file", "content": files}, {"type": "text", "content": text}]
        })

        yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)

        for chunk in Gradio_Events.submit(state_value):
            yield chunk

        yield Gradio_Events.postprocess_submit(state_value)

    @staticmethod
    def preprocess_submit(clear_input=True):
        def handler(state_value):
            conv_id = state_value["conversation_id"]
            history = state_value["conversation_contexts"][conv_id]["history"]
            return {
                input_comp: gr.update(value={'text': '', 'files': []} if clear_input else {}, loading=True),
                conversations: gr.update(active_key=conv_id, items=state_value["conversations"]),
                add_conversation_btn: gr.update(disabled=True),
                chatbot: gr.update(value=history),
                state: gr.update(value=state_value),
            }
        return handler

    @staticmethod
    def postprocess_submit(state_value):
        conv_id = state_value["conversation_id"]
        history = state_value["conversation_contexts"][conv_id]["history"]
        return {
            input_comp: gr.update(loading=False),
            add_conversation_btn: gr.update(disabled=False),
            chatbot: gr.update(value=history),
            state: gr.update(value=state_value),
        }

    @staticmethod
    def apply_prompt(e: gr.EventData):
        # Example payload value: {"description": "Query text", "urls": ["path/to/image.png"]}
        prompt_data = e._data["payload"][0]["value"]
        return gr.update(value={'text': prompt_data['description'], 'files': prompt_data['urls']})

    @staticmethod
    def new_chat(state_value):
        state_value["conversation_id"] = ""
        return (
            gr.update(active_key=""),
            gr.update(value=None),
            gr.update(value=state_value),
            gr.update(value=MODEL_CHOICES[0]),
        )

    @staticmethod
    def select_conversation(state_value, e: gr.EventData):
        active_key = e._data["payload"][0]
        if state_value["conversation_id"] == active_key or active_key not in state_value["conversation_contexts"]:
            # Leave all four outputs unchanged.
            return gr.skip(), gr.skip(), gr.skip(), gr.skip()

        state_value["conversation_id"] = active_key
        context = state_value["conversation_contexts"][active_key]
        return (
            gr.update(active_key=active_key),
            gr.update(value=context["history"]),
            gr.update(value=state_value),
            gr.update(value=context.get("selected_model", MODEL_CHOICES[0])),
        )

    @staticmethod
    def on_model_change(model_name, state_value):
        # Remember the choice on the active conversation, if there is one.
        if state_value["conversation_id"]:
            state_value["conversation_contexts"][state_value["conversation_id"]]["selected_model"] = model_name
        return state_value


# --- UI Layout and Components ---
css = """
.gradio-container { padding: 0 !important; }
main.fillable { padding: 0 !important; }
#chatbot_container { height: calc(100vh - 80px); max-height: 1000px; }
#conversations_sidebar .chatbot-conversations {
    height: 100vh;
    background-color: var(--ms-gr-ant-color-bg-layout);
    padding: 8px;
}
#main_chat_area { padding: 16px; height: 100%; }
"""

# Define welcome prompts based on available examples
welcome_prompts = [
    {
        "title": "Reconstruct Table",
        "description": "Reconstruct the doc [table] as it is.",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/0.png"]
    },
    {
        "title": "Describe Image",
        "description": "Describe the image!",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/8.png"]
    },
    {
        "title": "OCR Image",
        "description": "OCR the image",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/2.jpg"]
    },
    {
        "title": "Convert to Docling",
        "description": "Convert this page to docling",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/1.png"]
    },
    {
        "title": "Convert Chart",
        "description": "Convert chart to OTSL.",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/4.png"]
    },
    {
        "title": "Extract Code",
        "description": "Convert code to text",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/5.jpg"]
    },
]

with gr.Blocks(css=css, fill_width=True, title="Multimodal OCR2") as demo:
    state = gr.State({
        "conversation_contexts": {},
        "conversations": [],
        "conversation_id": "",
    })

    with ms.Application(), antdx.XProvider(), ms.AutoLoading():
        with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot_container"):
            # Left Sidebar for Conversations
            with antd.Col(md=dict(flex="0 0 260px"), elem_id="conversations_sidebar"):
                with ms.Div(elem_classes="chatbot-conversations"):
                    with antd.Flex(vertical=True, gap="small", elem_style=dict(height="100%")):
                        gr.Markdown("### OCR Conversations")
                        with antd.Button(color="primary", variant="filled", block=True) as add_conversation_btn:
                            ms.Text("New Conversation")
                            with ms.Slot("icon"):
                                antd.Icon("PlusOutlined")
                        with antdx.Conversations() as conversations:
                            pass  # Handled by events

            # Right Main Chat Area
            with antd.Col(flex=1, elem_style=dict(height="100%")):
                with antd.Flex(vertical=True, gap="small", elem_id="main_chat_area"):
                    gr.Markdown("## Multimodal OCR2")
                    chatbot = pro.Chatbot(
                        height="calc(100vh - 200px)",
                        welcome_config=pro.Chatbot.WelcomeConfig(prompts=welcome_prompts,
                                                                 title="Start by selecting an example:")
                    )
                    with pro.MultimodalInput(placeholder="Ask a question about your image or video...") as input_comp:
                        with ms.Slot("prefix"):
                            model_selector = gr.Dropdown(
                                choices=MODEL_CHOICES,
                                value=MODEL_CHOICES[0],
                                label="Select Model",
                                container=False
                            )

    # --- Event Wiring ---
    add_conversation_btn.click(
        fn=Gradio_Events.new_chat,
        inputs=[state],
        outputs=[conversations, chatbot, state, model_selector]
    )
    conversations.active_change(
        fn=Gradio_Events.select_conversation,
        inputs=[state],
        outputs=[conversations, chatbot, state, model_selector]
    )
    chatbot.welcome_prompt_select(
        fn=Gradio_Events.apply_prompt,
        inputs=[],
        outputs=[input_comp]
    )
    submit_event = input_comp.submit(
        fn=Gradio_Events.add_message,
        inputs=[input_comp, model_selector, state],
        outputs=[input_comp, add_conversation_btn, conversations, chatbot, state]
    )
    model_selector.change(
        fn=Gradio_Events.on_model_change,
        inputs=[model_selector, state],
        outputs=[state]
    )

if __name__ == "__main__":
    demo.queue().launch(show_error=True, debug=True)
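# --- Sketch: random padding geometry (illustrative, not used by the app) ---
# add_random_padding() draws one fraction per axis with
# random.uniform(min_percent, max_percent) and pads both sides of that axis by
# int(size * fraction) pixels. The hypothetical helper below is a minimal sketch
# that isolates only that arithmetic. It takes an optional random.Random instance
# so results are reproducible. With the defaults used by add_random_padding()
# (0.1 and 0.10), the range collapses to a single value, so every image gets
# exactly 10% padding.
import random


def padding_border(width, height, min_percent=0.1, max_percent=0.10, rng=None):
    """Return the (left, top, right, bottom) border that add_random_padding() would apply."""
    rng = rng or random.Random()
    pad_w = int(width * rng.uniform(min_percent, max_percent))
    pad_h = int(height * rng.uniform(min_percent, max_percent))
    # ImageOps.expand accepts a 4-tuple border in (left, top, right, bottom) order.
    return (pad_w, pad_h, pad_w, pad_h)


# Example: padding_border(640, 480) -> (64, 48, 64, 48)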
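# --- Sketch: evenly spaced frame sampling (illustrative, not used by the app) ---
# downsample_video() chooses frame positions with
# np.linspace(0, total_frames - 1, num_frames, dtype=int). The hypothetical helper
# below is a minimal pure-Python sketch of the same index rule, assuming
# non-negative integer inputs. It lets you check the selection without OpenCV or
# a video file. When a clip has fewer frames than requested, indices repeat, which
# means downsample_video() would append duplicate frames.


def sample_frame_indices(total_frames, num_frames=10):
    """Return num_frames evenly spaced indices in [0, total_frames - 1], rounded down."""
    if total_frames <= 0 or num_frames <= 0:
        return []
    if num_frames == 1:
        # np.linspace with num=1 returns only the start point.
        return [0]
    step = (total_frames - 1) / (num_frames - 1)
    indices = [int(i * step) for i in range(num_frames)]
    # np.linspace pins the final sample to `stop` exactly; do the same here.
    indices[-1] = total_frames - 1
    return indices


# Example: sample_frame_indices(100, 10) -> [0, 11, 22, 33, 44, 55, 66, 77, 88, 99]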
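# --- Sketch: routing uploads by MIME type (illustrative, not used by the app) ---
# format_history_for_model() opens image files directly and expands videos into
# sampled frames, choosing the path by MIME type. gradio_client's get_mimetype()
# returns None for extensions it cannot recognise, so the value has to be checked
# before calling startswith(). This hypothetical helper is a minimal sketch that
# uses the standard library's mimetypes.guess_type as a stand-in. It only sorts
# the paths and does not open them.
import mimetypes


def classify_media(paths):
    """Split file paths into (images, videos, skipped) lists by guessed MIME type."""
    images, videos, skipped = [], [], []
    for path in paths:
        mime_type, _encoding = mimetypes.guess_type(path)
        if mime_type and mime_type.startswith("image"):
            images.append(path)
        elif mime_type and mime_type.startswith("video"):
            videos.append(path)
        else:
            # Unknown or non-media types (mime_type may be None) are not used, as in the app.
            skipped.append(path)
    return images, videos, skipped


# Example: classify_media(["page.png", "clip.mp4", "notes"]) -> (["page.png"], ["clip.mp4"], ["notes"])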
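# --- Sketch: from chat history to a chat-template message (illustrative) ---
# The app stores each user turn as {"role": "user", "content": [{"type": "file", ...},
# {"type": "text", ...}]}. format_history_for_model() reads the most recent user
# turn, and submit() then emits one {"type": "image"} placeholder per image,
# followed by the text. The hypothetical helper below is a minimal sketch of that
# flow on plain dicts. It assumes one image per file, whereas the app expands each
# video into several frames and therefore adds several placeholders. This lets you
# inspect the message shape without a processor.


def build_user_message(history, default_prompt="Describe the media."):
    """Return (message, file_paths) for the latest user turn, or (None, []) if there is none."""
    last_user = next((m for m in reversed(history) if m["role"] == "user"), None)
    if last_user is None:
        return None, []
    text, files = "", []
    for part in last_user["content"]:
        if part["type"] == "text":
            text = part["content"]
        elif part["type"] == "file":
            files.extend(part["content"])
    # Image placeholders come first, then the text prompt (or a default one).
    content = [{"type": "image"} for _ in files]
    content.append({"type": "text", "text": text or default_prompt})
    return {"role": "user", "content": content}, files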
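# --- Sketch: thread + streamer pattern used in submit() (illustrative, not used by the app) ---
# submit() runs model.generate in a background Thread and iterates a
# TextIteratorStreamer on the caller's side, appending each decoded chunk to a
# buffer. The sketch below reproduces that producer/consumer shape using only the
# standard library. QueueStreamer and fake_generate are hypothetical stand-ins
# for the transformers streamer and model.generate.
#
# The producer always pushes an end sentinel in a `finally` block, so the
# consumer loop cannot wait forever if generation raises. transformers'
# TextIteratorStreamer also accepts a `timeout` argument as a similar safeguard.
# Removing the stop marker per chunk only works when the marker arrives inside a
# single chunk.
import queue
import threading


class QueueStreamer:
    """Iterable that yields text chunks pushed from another thread until end() is called."""

    _END = object()

    def __init__(self, timeout=None):
        self._queue = queue.Queue()
        self._timeout = timeout

    def put(self, text):
        self._queue.put(text)

    def end(self):
        self._queue.put(self._END)

    def __iter__(self):
        while True:
            # Raises queue.Empty if nothing arrives within `timeout` seconds.
            item = self._queue.get(timeout=self._timeout)
            if item is self._END:
                return
            yield item


def fake_generate(streamer, chunks):
    """Hypothetical producer standing in for model.generate(streamer=...)."""
    try:
        for chunk in chunks:
            streamer.put(chunk)
    finally:
        streamer.end()


def stream_response(chunks, stop_marker="<|im_end|>"):
    """Yield the growing, cleaned buffer as chunks arrive, mirroring submit()'s loop."""
    streamer = QueueStreamer(timeout=10)
    thread = threading.Thread(target=fake_generate, args=(streamer, chunks))
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace(stop_marker, "")
        yield buffer
    thread.join()


# Example: list(stream_response(["Hel", "lo", "<|im_end|>"])) -> ["Hel", "Hello", "Hello"]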