File size: 3,838 Bytes
88a5b8d
875f054
 
 
 
 
 
 
 
 
4aff560
 
 
 
 
 
 
 
88a5b8d
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
88a5b8d
 
 
4aff560
 
875f054
20b327e
875f054
4aff560
875f054
 
 
4aff560
875f054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02b26ea
 
875f054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8178d81
dc798af
 
 
875f054
 
dc798af
875f054
 
 
 
8178d81
875f054
 
88a5b8d
dc798af
 
 
840e373
dc798af
 
 
88a5b8d
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
875f054
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import torch
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
)
from peft import PeftModel
from transformers.image_utils import load_image
from threading import Thread
import time
import html


def progress_bar_html(label: str) -> str:
    """Build the markup for a small, endlessly looping progress indicator.

    The snippet is a flex row holding ``label`` followed by a 110x5 px track
    whose inner bar slides left-to-right via the ``loading`` CSS keyframes
    that are shipped in the same snippet.

    Args:
        label: Text shown to the left of the bar. It is inserted into the
            markup as-is (no HTML escaping is applied here).

    Returns:
        The HTML fragment (markup plus its ``<style>`` block) as one string.
    """
    track_css = (
        "width: 110px; height: 5px; background-color: #9370DB; "
        "border-radius: 2px; overflow: hidden;"
    )
    bar_css = (
        "width: 100%; height: 100%; background-color: #4B0082; "
        "animation: loading 1.5s linear infinite;"
    )
    # The leading empty entry and trailing four-space entry reproduce the
    # newline at the start and the indented tail of the emitted snippet.
    rows = [
        "",
        '<div style="display: flex; align-items: center;">',
        f'    <span style="margin-right: 10px; font-size: 14px;">{label}</span>',
        f'    <div style="{track_css}">',
        f'        <div style="{bar_css}"></div>',
        "    </div>",
        "</div>",
        "<style>",
        "@keyframes loading {",
        "    0% { transform: translateX(-100%); }",
        "    100% { transform: translateX(100%); }",
        "}",
        "</style>",
        "    ",
    ]
    return "\n".join(rows)


# Hugging Face Hub identifier of the checkpoint this demo serves.
model_name = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"

# Load the weights once, at import time, in float32; `device_map="auto"`
# leaves device placement to the library. `.eval()` switches the module to
# evaluation mode before it is used for generation.
model = AutoModelForImageTextToText.from_pretrained(
    model_name, dtype=torch.float32, device_map="auto"
).eval()

# Processor from the same checkpoint; used below both to build the chat
# prompt and as the decoder handed to the text streamer.
processor = AutoProcessor.from_pretrained(model_name)

# Load confirmation; note this interpolates the model object's full repr.
print(f"Successfully load the model: {model}")


def model_inference(input_dict, history):
    """Stream the model's reply to a single multimodal chat message.

    Generator handler for the chat UI. The query is validated first, then any
    attached images are loaded, the chat prompt is built with ``processor``,
    and ``model.generate`` runs on a background thread while the partial text
    is yielded as it is produced.

    Args:
        input_dict: Mapping with a ``"text"`` entry (the user query) and a
            ``"files"`` entry (image locations accepted by ``load_image``).
        history: Earlier conversation turns supplied by the UI. Unused: every
            request is answered from the current message only.

    Yields:
        str: First an HTML progress-bar snippet, then the accumulated,
        HTML-escaped response after each streamed chunk, and finally the same
        text with the elapsed generation time appended.

    Raises:
        gr.Error: If the text query is missing or blank (with or without
            images attached).
    """
    text = input_dict["text"] or ""
    files = input_dict["files"] or []

    # Validate before doing any image loading. ``gr.Error`` is an exception
    # type: merely constructing it shows nothing, it has to be raised.
    if not text.strip():
        if files:
            raise gr.Error("Please input a text query along with the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # One comprehension handles zero, one, or many attachments alike.
    images = [load_image(image) for image in files]

    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)

    # Monotonic clock: the measured duration cannot be skewed by system clock
    # adjustments. Started just before generation begins.
    start_time = time.perf_counter()
    # Generation runs on its own thread so this generator can consume the
    # streamer and forward partial output while tokens are still being made.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = "Baseline Model Response: "
    yield progress_bar_html("Processing...")
    for new_text in streamer:
        # Escape model output so it is displayed as text rather than markup.
        buffer += html.escape(new_text)
        # Tiny pause between updates (presumably to let the UI keep up).
        time.sleep(0.001)
        yield buffer

    # Ensure the generation thread has finished before reporting the time.
    thread.join()
    elapsed = time.perf_counter() - start_time
    buffer += html.escape(f"\nBaseline Generation Time: {elapsed:.2f} s")
    yield buffer


# Canned prompts offered in the UI. Each example is a one-element row holding
# a multimodal message (query text plus the bundled sample image).
examples = [
    [{"text": prompt, "files": ["example_images/example.png"]}]
    for prompt in (
        "Write a descriptive caption for this image in a formal tone.",
        "What are the characters wearing?",
    )
]

# Chat UI wired to `model_inference`, configured for multimodal input
# (text plus image uploads) with the canned `examples` shown as presets.
demo = gr.ChatInterface(
    fn=model_inference,
    # NOTE(review): the heading says "Smolvlm2-500M-illustration-description"
    # while `model_name` in this file is the 256M Video-Instruct checkpoint,
    # and "(running on CPU)" is stated although loading uses
    # device_map="auto" - confirm the description matches what is served.
    description="# **Smolvlm2-500M-illustration-description** \n (running on CPU) The model only sees the last input, it ignores the previous conversation history.",
    examples=examples,
    fill_height=True,
    # Input box restricted to image attachments alongside the text query.
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"]),
    stop_btn="Stop Generation",
    multimodal=True,
    # Examples are not pre-computed; they run through the handler on demand.
    cache_examples=False,
)

# Launches the app as a module-level side effect (there is no
# `__main__` guard), with Gradio's debug flag enabled.
demo.launch(debug=True)