import gradio as gr
from transformers import AutoProcessor, TextIteratorStreamer
from threading import Thread
import re
import time

from optimum.intel import OVModelForVisualCausalLM

# model_id = "echarlaix/SmolVLM2-2.2B-Instruct-openvino"
# model_id = "echarlaix/SmolVLM-256M-Instruct-openvino"
model_id = "echarlaix/SmolVLM2-500M-Video-Instruct-openvino"

processor = AutoProcessor.from_pretrained(model_id)
model = OVModelForVisualCausalLM.from_pretrained(model_id)
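
# A hedged sketch (not used by the demo): optimum-intel can also convert the
# original PyTorch checkpoint to OpenVINO on the fly with export=True; the
# source repo name below is an assumption:
#   model = OVModelForVisualCausalLM.from_pretrained(
#       "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", export=True
#   )
#   model.save_pretrained("smolvlm2-500m-video-instruct-openvino")
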
def model_inference(input_dict, history, max_tokens):
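    """Build a chat message list from the current input (and any history),
    then stream the model's reply token by token.

    input_dict: {"text": str, "files": [paths]} from gr.MultimodalTextbox.
    history: prior chat turns supplied by gr.ChatInterface.
    max_tokens: upper bound passed to generate() as max_new_tokens.
    """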
    text = input_dict["text"]
    user_content = []
    media_queue = []
    if history == []:
        text = input_dict["text"].strip()

        # Sort uploaded files into the media queue by extension.
        for file in input_dict.get("files", []):
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                media_queue.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                media_queue.append({"type": "video", "path": file})

        if "<image>" in text or "<video>" in text:
            # Interleave media with text at the <image>/<video> placeholders.
            parts = re.split(r"(<image>|<video>)", text)
            for part in parts:
                if part in ("<image>", "<video>") and media_queue:
                    user_content.append(media_queue.pop(0))
                elif part.strip():
                    user_content.append({"type": "text", "text": part.strip()})
        else:
            # No placeholders: put the text first, then all media.
            user_content.append({"type": "text", "text": text})
            for media in media_queue:
                user_content.append(media)

        resulting_messages = [{"role": "user", "content": user_content}]
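
        # resulting_messages now looks like, e.g.:
        #   [{"role": "user", "content": [
        #       {"type": "image", "path": "/tmp/cat.png"},
        #       {"type": "text", "text": "Describe this image."}]}]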
    elif len(history) > 0:
        resulting_messages = []
        user_content = []
        media_queue = []

        # First pass: collect media files attached in previous turns.
        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple):
                file_name = hist["content"][0]
                if file_name.endswith((".png", ".jpg", ".jpeg")):
                    media_queue.append({"type": "image", "path": file_name})
                elif file_name.endswith(".mp4"):
                    media_queue.append({"type": "video", "path": file_name})

        # Second pass: rebuild each turn, interleaving text with queued media.
        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], str):
                text = hist["content"]
                parts = re.split(r"(<image>|<video>)", text)
                for part in parts:
                    if part in ("<image>", "<video>") and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part.strip():
                        user_content.append({"type": "text", "text": part.strip()})
            elif hist["role"] == "assistant":
                resulting_messages.append({"role": "user", "content": user_content})
                resulting_messages.append(
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": hist["content"]}],
                    }
                )
                user_content = []
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
if text == "" and images:
gr.Error("Please input a text query along the images(s).")
# print("resulting_messages", resulting_messages)
    inputs = processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
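    # `inputs` is a dict-like BatchEncoding; for SmolVLM-style processors it
    # typically carries input_ids, attention_mask and pixel_values (the exact
    # keys depend on the processor, so treat this as an assumption).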

    # Run generate() in a background thread so we can stream partial text.
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
    yield "..."  # placeholder shown until the first tokens arrive

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)  # small delay to avoid flooding the UI with updates
        yield buffer

examples = [
    [
        {
            "text": "Where do the severe droughts happen according to this diagram?",
            "files": ["example_images/examples_weather_events.png"],
        }
    ],
    [
        {
            "text": "What art era do this artpiece <image> and this artpiece <image> belong to?",
            "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"],
        }
    ],
    [
        {
            "text": "Describe this image.",
            "files": ["example_images/mosque.jpg"],
        }
    ],
    [
        {
            "text": "When was this purchase made and how much did it cost?",
            "files": ["example_images/fiche.jpg"],
        }
    ],
    [
        {
            "text": "What is the date in this document?",
            "files": ["example_images/document.jpg"],
        }
    ],
    [
        {
            "text": "What is happening in the video?",
            "files": ["example_images/short.mp4"],
        }
    ],
]
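
# The example files above are assumed to live in an example_images/ directory
# next to this script; the examples will fail to render otherwise.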

demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM2: The Smollest Video Model Ever 📺",
    description="Play with [SmolVLM2-500M-Video-Instruct](https://huggingface.co/echarlaix/SmolVLM2-500M-Video-Instruct-openvino) running on OpenVINO through Optimum Intel in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input", file_types=["image", ".mp4"], file_count="multiple"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    additional_inputs=[
        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")
    ],
    type="messages",
)
demo.launch(debug=True)
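
# A hedged note: for container or remote deployments, launch() also accepts
# server_name/server_port, e.g.:
#   demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)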