import os
import random
import uuid
import time
import base64
from http import HTTPStatus
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageOps
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoModelForVision2Seq,
    AutoProcessor,
    TextIteratorStreamer,
)
from gradio_client import utils as client_utils

import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro
# --- Constants and Configuration ---
MAX_MAX_NEW_TOKENS = 5120
DEFAULT_MAX_NEW_TOKENS = 3072
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Model Loading ---
# A dictionary to hold our models and processors for easy access
models = {}
processors = {}
MODEL_CHOICES = [
    "Nanonets-OCR-s",
    "MonkeyOCR-Recognition",
    "Thyme-RL",
    "Typhoon-OCR-7B",
    "SmolDocling-256M-preview",
]
def load_model(model_id, processor_class, model_class, subfolder=None, model_key=''):
    """Helper function to load a model and processor."""
    print(f"Loading model: {model_key}...")
    try:
        processor_args = {"trust_remote_code": True}
        model_args = {"trust_remote_code": True, "torch_dtype": torch.float16}
        if subfolder:
            processor_args["subfolder"] = subfolder
            model_args["subfolder"] = subfolder
        processors[model_key] = processor_class.from_pretrained(model_id, **processor_args)
        models[model_key] = model_class.from_pretrained(model_id, **model_args).to(device).eval()
        print(f"Successfully loaded {model_key}.")
    except Exception as e:
        print(f"Error loading model {model_key}: {e}")
        # If a model fails to load, remove it from the choices
        if model_key in MODEL_CHOICES:
            MODEL_CHOICES.remove(model_key)
# Load all models
load_model("nanonets/Nanonets-OCR-s", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Nanonets-OCR-s")
load_model("echo840/MonkeyOCR", AutoProcessor, Qwen2_5_VLForConditionalGeneration, subfolder="Recognition", model_key="MonkeyOCR-Recognition")
load_model("scb10x/typhoon-ocr-7b", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Typhoon-OCR-7B")
load_model("ds4sd/SmolDocling-256M-preview", AutoProcessor, AutoModelForVision2Seq, model_key="SmolDocling-256M-preview")
load_model("Kwai-Keye/Thyme-RL", AutoProcessor, Qwen2_5_VLForConditionalGeneration, model_key="Thyme-RL")
# --- Preprocessing and Helper Functions ---
def add_random_padding(image, min_percent=0.1, max_percent=0.10):
    """Pad an image by a random fraction of each dimension.

    Note: with the default min_percent == max_percent, the padding is effectively
    a fixed 10% of each side; lower min_percent to make it genuinely random.
    """
    image = image.convert("RGB")
    width, height = image.size
    pad_w = int(width * random.uniform(min_percent, max_percent))
    pad_h = int(height * random.uniform(min_percent, max_percent))
    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=image.getpixel((0, 0)))
    return padded_image
def downsample_video(video_path, num_frames=10):
    """Downsample a video into a list of PIL Image frames."""
    if not os.path.exists(video_path):
        return []
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total_frames > 0:
        # Pick num_frames indices spread evenly across the whole clip
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
    vidcap.release()
    return frames
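# Example: downsample_video("clip.mp4") yields up to 10 RGB PIL frames, which
# the rest of the pipeline treats exactly like user-uploaded images.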
def format_history_for_model(history, selected_model):
    """Prepare the latest user turn for the model: returns (text, images, model name)."""
    last_user_message = next((item for item in reversed(history) if item["role"] == "user"), None)
    if not last_user_message:
        return None, [], ""
    text = ""
    files = []
    images = []
    for content_part in last_user_message["content"]:
        if content_part["type"] == "text":
            text = content_part["content"]
        elif content_part["type"] == "file":
            files.extend(content_part["content"])
    for file_path in files:
        # get_mimetype can return None for unknown extensions; guard against it
        mime_type = client_utils.get_mimetype(file_path) or ""
        if mime_type.startswith("image"):
            images.append(Image.open(file_path))
        elif mime_type.startswith("video"):
            images.extend(downsample_video(file_path))
    # Apply model-specific preprocessing
    if selected_model == "SmolDocling-256M-preview":
        if "OTSL" in text or "code" in text:
            images = [add_random_padding(img) for img in images]
    return text, images, selected_model
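# Only the most recent user turn is sent to the model; earlier turns stay in the
# on-screen history but are not included in the prompt.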
# --- Gradio Events and Application Logic ---
class Gradio_Events:

    @spaces.GPU  # on ZeroGPU Spaces this allocates a GPU for generation; it is a no-op elsewhere
    def submit(state_value):
        conv_id = state_value["conversation_id"]
        context = state_value["conversation_contexts"][conv_id]
        history = context["history"]
        model_name = context.get("selected_model", MODEL_CHOICES[0])
        processor = processors.get(model_name)
        model = models.get(model_name)
        if not processor or not model:
            history.append({"role": "assistant", "content": [{"type": "text", "content": f"Error: Model '{model_name}' not loaded."}]})
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            return
        text, images, _ = format_history_for_model(history, model_name)
        if not text and not images:
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            return
        history.append({
            "role": "assistant",
            "content": [],
            "key": str(uuid.uuid4()),
            "loading": True,
        })
        yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
        try:
            messages = [{"role": "user", "content": []}]
            if images:
                messages[0]["content"].extend([{"type": "image"}] * len(images))
            messages[0]["content"].append({"type": "text", "text": text or "Describe the media."})
            prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
            # Pass None rather than an empty list for text-only turns
            inputs = processor(text=prompt, images=images if images else None, return_tensors="pt").to(device)
            streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
            generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": MAX_MAX_NEW_TOKENS}
            # Run generation in a background thread; the streamer yields decoded
            # text incrementally so the UI can update as tokens arrive
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()
            buffer = ""
            for new_text in streamer:
                buffer += new_text.replace("<|im_end|>", "")
                history[-1]["content"] = [{"type": "text", "content": buffer}]
                history[-1]["loading"] = False  # hide the spinner once text starts streaming
                yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
            history[-1]["loading"] = False
            # Final post-processing, especially for models like SmolDocling
            final_content = buffer.strip().replace("<end_of_utterance>", "")
            history[-1]["content"] = [{"type": "text", "content": final_content}]
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
        except Exception as e:
            print(f"Error during model generation: {e}")
            history[-1]["loading"] = False
            history[-1]["content"] = [{"type": "text", "content": f'<span style="color: red;">An error occurred: {e}</span>'}]
            yield {chatbot: gr.update(value=history), state: gr.update(value=state_value)}
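
    # add_message drives a three-phase update cycle: lock the input
    # (preprocess_submit), stream model output (submit), then unlock it
    # (postprocess_submit).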
    def add_message(input_value, model_name, state_value):
        text = input_value["text"]
        files = input_value["files"]
        if not state_value["conversation_id"]:
            random_id = str(uuid.uuid4())
            state_value["conversation_id"] = random_id
            state_value["conversations"].append({"label": text or "New Chat", "key": random_id})
            state_value["conversation_contexts"][random_id] = {
                "history": [],
                # Seed new chats with the currently selected model instead of always defaulting
                "selected_model": model_name or MODEL_CHOICES[0],
            }
        conv_id = state_value["conversation_id"]
        history = state_value["conversation_contexts"][conv_id]["history"]
        history.append({
            "key": str(uuid.uuid4()),
            "role": "user",
            "content": [{"type": "file", "content": files}, {"type": "text", "content": text}]
        })
        yield Gradio_Events.preprocess_submit(clear_input=True)(state_value)
        for chunk in Gradio_Events.submit(state_value):
            yield chunk
        yield Gradio_Events.postprocess_submit(state_value)
    def preprocess_submit(clear_input=True):
        def handler(state_value):
            conv_id = state_value["conversation_id"]
            history = state_value["conversation_contexts"][conv_id]["history"]
            return {
                input_comp: gr.update(value={'text': '', 'files': []} if clear_input else {}, loading=True),
                conversations: gr.update(active_key=conv_id, items=state_value["conversations"]),
                add_conversation_btn: gr.update(disabled=True),
                chatbot: gr.update(value=history),
                state: gr.update(value=state_value),
            }
        return handler
    def postprocess_submit(state_value):
        conv_id = state_value["conversation_id"]
        history = state_value["conversation_contexts"][conv_id]["history"]
        return {
            input_comp: gr.update(loading=False),
            add_conversation_btn: gr.update(disabled=False),
            chatbot: gr.update(value=history),
            state: gr.update(value=state_value),
        }
    def apply_prompt(e: gr.EventData):
        # Example format: {"description": "Query text", "urls": ["path/to/image.png"]}
        prompt_data = e._data["payload"][0]["value"]
        return gr.update(value={'text': prompt_data['description'], 'files': prompt_data['urls']})
    def new_chat(state_value):
        state_value["conversation_id"] = ""
        return gr.update(active_key=""), gr.update(value=None), gr.update(value=state_value), gr.update(value=MODEL_CHOICES[0])
    def select_conversation(state_value, e: gr.EventData):
        active_key = e._data["payload"][0]
        if state_value["conversation_id"] == active_key or active_key not in state_value["conversation_contexts"]:
            return gr.skip()
        state_value["conversation_id"] = active_key
        context = state_value["conversation_contexts"][active_key]
        return gr.update(active_key=active_key), gr.update(value=context["history"]), gr.update(value=state_value), gr.update(value=context.get("selected_model", MODEL_CHOICES[0]))
    def on_model_change(model_name, state_value):
        # Persist the choice on the active conversation; before the first message,
        # add_message picks up the selector value directly.
        if state_value["conversation_id"]:
            state_value["conversation_contexts"][state_value["conversation_id"]]["selected_model"] = model_name
        return state_value
# --- UI Layout and Components ---
css = """
.gradio-container { padding: 0 !important; }
main.fillable { padding: 0 !important; }
#chatbot_container { height: calc(100vh - 80px); max-height: 1000px; }
#conversations_sidebar .chatbot-conversations {
    height: 100vh; background-color: var(--ms-gr-ant-color-bg-layout); padding: 8px;
}
#main_chat_area { padding: 16px; height: 100%; }
"""
# Define welcome prompts based on available examples
welcome_prompts = [
    {
        "title": "Reconstruct Table",
        "description": "Reconstruct the doc [table] as it is.",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/0.png"]
    },
    {
        "title": "Describe Image",
        "description": "Describe the image!",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/8.png"]
    },
    {
        "title": "OCR Image",
        "description": "OCR the image",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/2.jpg"]
    },
    {
        "title": "Convert to Docling",
        "description": "Convert this page to docling",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/1.png"]
    },
    {
        "title": "Convert Chart",
        "description": "Convert chart to OTSL.",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/4.png"]
    },
    {
        "title": "Extract Code",
        "description": "Convert code to text",
        "urls": ["https://huggingface.co/spaces/prithivMLmods/Multimodal-OCR2/resolve/main/images/5.jpg"]
    },
]
with gr.Blocks(css=css, fill_width=True, title="Multimodal OCR2") as demo:
    state = gr.State({
        "conversation_contexts": {},
        "conversations": [],
        "conversation_id": "",
    })
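    # Illustrative state layout once a chat exists (keys are generated uuid4 strings):
    # {
    #   "conversation_contexts": {"<uuid>": {"history": [...], "selected_model": "Nanonets-OCR-s"}},
    #   "conversations": [{"label": "first message", "key": "<uuid>"}],
    #   "conversation_id": "<uuid>",  # "" while no chat is active
    # }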
    with ms.Application(), antdx.XProvider(), ms.AutoLoading():
        with antd.Row(gutter=[0, 0], wrap=False, elem_id="chatbot_container"):
            # Left Sidebar for Conversations
            with antd.Col(md=dict(flex="0 0 260px"), elem_id="conversations_sidebar"):
                with ms.Div(elem_classes="chatbot-conversations"):
                    with antd.Flex(vertical=True, gap="small", elem_style=dict(height="100%")):
                        gr.Markdown("### OCR Conversations")
                        with antd.Button(color="primary", variant="filled", block=True) as add_conversation_btn:
                            ms.Text("New Conversation")
                            with ms.Slot("icon"):
                                antd.Icon("PlusOutlined")
                        with antdx.Conversations() as conversations:
                            pass  # Handled by events
            # Right Main Chat Area
            with antd.Col(flex=1, elem_style=dict(height="100%")):
                with antd.Flex(vertical=True, gap="small", elem_id="main_chat_area"):
                    gr.Markdown("## Multimodal OCR2")
                    chatbot = pro.Chatbot(
                        height="calc(100vh - 200px)",
                        welcome_config=pro.Chatbot.WelcomeConfig(prompts=welcome_prompts, title="Start by selecting an example:")
                    )
                    with pro.MultimodalInput(placeholder="Ask a question about your image or video...") as input_comp:
                        with ms.Slot("prefix"):
                            model_selector = gr.Dropdown(
                                choices=MODEL_CHOICES,
                                value=MODEL_CHOICES[0],
                                label="Select Model",
                                container=False
                            )
    # --- Event Wiring ---
    add_conversation_btn.click(
        fn=Gradio_Events.new_chat,
        inputs=[state],
        outputs=[conversations, chatbot, state, model_selector]
    )
    conversations.active_change(
        fn=Gradio_Events.select_conversation,
        inputs=[state],
        outputs=[conversations, chatbot, state, model_selector]
    )
    chatbot.welcome_prompt_select(
        fn=Gradio_Events.apply_prompt,
        inputs=[],
        outputs=[input_comp]
    )
    submit_event = input_comp.submit(
        fn=Gradio_Events.add_message,
        # model_selector is passed through so new chats start on the selected model
        inputs=[input_comp, model_selector, state],
        outputs=[input_comp, add_conversation_btn, conversations, chatbot, state]
    )
    model_selector.change(
        fn=Gradio_Events.on_model_change,
        inputs=[model_selector, state],
        outputs=[state]
    )
if __name__ == "__main__":
    demo.queue().launch(show_error=True, debug=True)