File size: 3,487 Bytes
88a5b8d
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88a5b8d
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
88a5b8d
 
 
4aff560
 
 
 
 
 
 
 
 
 
88a5b8d
4aff560
 
 
88a5b8d
4aff560
 
 
 
 
 
88a5b8d
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88a5b8d
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88a5b8d
 
4aff560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88a5b8d
4aff560
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import torch
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
)
from peft import PeftModel
from transformers.image_utils import load_image
from threading import Thread
import time
import html


def progress_bar_html(label: str) -> str:
    """Build the HTML/CSS snippet for a labelled, animated loading bar.

    Args:
        label: Text placed in the ``<span>`` to the left of the bar. It is
            interpolated verbatim (no HTML escaping is applied here).

    Returns:
        A string containing a flex container with the label and a thin bar,
        followed by a ``<style>`` element defining the ``loading`` keyframes
        the bar's inner element animates with.
    """
    label_span = f'    <span style="margin-right: 10px; font-size: 14px;">{label}</span>'
    markup_lines = [
        # Leading empty entry: the snippet starts with a newline.
        "",
        '<div style="display: flex; align-items: center;">',
        label_span,
        '    <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">',
        '        <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>',
        "    </div>",
        "</div>",
        "<style>",
        "@keyframes loading {",
        "    0% { transform: translateX(-100%); }",
        "    100% { transform: translateX(100%); }",
        "}",
        "</style>",
        # Trailing four-space entry: the snippet ends with an indented blank line.
        "    ",
    ]
    return "\n".join(markup_lines)


# Checkpoint identifier handed to both `from_pretrained` calls below.
model_name = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"

# Load the weights in bfloat16 and let `device_map="auto"` choose placement.
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="auto",
)
# Put the module in evaluation mode; `eval()` returns the same module object,
# so calling it as a separate statement leaves `model` unchanged.
model.eval()

# The processor comes from the same checkpoint as the model; in this file it
# builds the chat-template inputs and is passed to the text streamer.
processor = AutoProcessor.from_pretrained(model_name)

print(f"Successfully load the model: {model}")


def model_inference(input_dict, history):
    """Stream the model's reply to a single multimodal chat message.

    Args:
        input_dict: Message payload from the multimodal textbox. The keys read
            here are ``"text"`` (the prompt) and ``"files"`` (image sources,
            each passed to ``load_image``).
        history: Earlier conversation turns supplied by the chat interface.
            Not used: only the current message is sent to the model.

    Yields:
        First an HTML progress-bar snippet, then, after every streamed chunk,
        the HTML-escaped reply accumulated so far.

    Raises:
        gr.Error: If the message has no (non-whitespace) text prompt.
    """
    text = input_dict["text"] or ""
    files = input_dict["files"] or []

    # Validate before loading any image so a rejected request does no work.
    # `gr.Error` has to be *raised*; merely constructing it shows nothing.
    if not text.strip():
        if files:
            raise gr.Error("Please input a text query along with the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # One comprehension covers zero, one, or many attachments.
    images = [load_image(image) for image in files]

    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # `generate` runs in a worker thread so this generator can consume the
    # streamer and forward partial text while generation is in progress.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield progress_bar_html("Processing...")
    for new_text in streamer:
        # Escape model output before it is sent to the chat UI.
        buffer += html.escape(new_text)

        # NOTE(review): tiny pause kept from the original; presumably it
        # gives the generation thread a chance to run - confirm it is needed.
        time.sleep(0.001)
        yield buffer


# Example prompts offered in the UI. Every example is a one-element row holding
# a multimodal message dict with "text" and "files" keys, and all of them use
# the same bundled sample image.
examples = [
    [{"text": prompt, "files": ["example_images/example.png"]}]
    for prompt in (
        "Write a descriptive caption for this image in a formal tone.",
        "What are the characters wearing?",
    )
]

# Chat UI wired to `model_inference`. The textbox accepts text plus image
# uploads, and example caching is turned off.
demo = gr.ChatInterface(
    fn=model_inference,
    multimodal=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"]),
    examples=examples,
    cache_examples=False,
    description=(
        "# **Smolvlm2-500M-illustration-description** \n"
        " (running on CPU) The model only sees the last input, "
        "it ignores the previous conversation history."
    ),
    stop_btn="Stop Generation",
    fill_height=True,
)

# Start the app at module level with debug mode enabled.
demo.launch(debug=True)