File size: 2,486 Bytes
bd04204
0aa2638
 
d8a0153
0aa2638
 
 
d8a0153
bd04204
0aa2638
 
 
d8a0153
0aa2638
a1cae3c
0aa2638
 
 
d8a0153
 
 
0aa2638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd04204
0aa2638
 
d8a0153
 
0aa2638
 
 
 
 
 
 
 
 
bd04204
0aa2638
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
import torch
from threading import Thread
from transformers import (
    SmolVLMProcessor,
    AutoModelForImageTextToText,
    TextIteratorStreamer,
)

# ======================
# MODEL INIT
# ======================
# Prefer the GPU when available; the dtype choice below mirrors this check.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_ID = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

# Loaded once at import time (first run downloads the weights from the Hub).
# bfloat16 on CUDA halves memory; CPU stays in float32.  .eval() switches
# off training-only behavior (dropout etc.) for inference.
processor = SmolVLMProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(DEVICE).eval()


# ======================
# STREAMING INFERENCE
# ======================
def analyze_stream(text, image, max_tokens):
    """Run SmolVLM2 on (text, image) and stream the answer incrementally.

    Args:
        text: User question/description; may be empty when an image is given.
        image: Path to an image file or None (``gr.Image(type="filepath")``).
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        str: The partial response, growing as tokens arrive.
    """
    if image is None and not text.strip():
        # BUG FIX: this function is a generator (it yields below), so a plain
        # ``return "..."`` is swallowed as the StopIteration value and Gradio
        # never displays it — the error message must be yielded instead.
        yield "❌ Veuillez fournir un texte ou une image."
        return

    # Build the multimodal chat message in the order (image, text).
    content = []
    if image:
        content.append({"type": "image", "path": image})
    if text.strip():
        content.append({"type": "text", "text": text})

    messages = [{"role": "user", "content": content}]

    # BUG FIX: return_dict=True is required so ``**inputs`` below unpacks a
    # mapping of tensors (input_ids, attention_mask, pixel_values) rather
    # than a bare input_ids tensor.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    streamer = TextIteratorStreamer(
        # NOTE(review): processors delegate decode() to their tokenizer;
        # confirm this holds for SmolVLMProcessor in the pinned transformers
        # version, otherwise pass ``processor.tokenizer`` explicitly.
        processor,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Generation runs in a background thread; the streamer hands tokens back
    # to this generator as they are produced.  ``temperature`` was dropped:
    # it is ignored (and warned about) when do_sample=False (greedy decoding).
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=False,
        ),
    ).start()

    output = ""
    for token in streamer:
        output += token
        yield output


# ======================
# GRADIO UI
# ======================
# Two-column layout: inputs (question, image, token budget) on the left,
# the streamed answer on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ⚡ SmolVLM2 – Analyse Temps Réel")

    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(
                label="Question / Description",
                lines=3,
            )
            # type="filepath" → analyze_stream receives a path string,
            # not a PIL image.
            img = gr.Image(type="filepath", label="Image")
            max_tokens = gr.Slider(
                50, 400, value=200, step=50, label="Max Tokens"
            )
            btn = gr.Button("🚀 Analyser", variant="primary")

        with gr.Column():
            out = gr.Textbox(
                label="Réponse en Temps Réel",
                lines=14,
            )

    # analyze_stream is a generator, so Gradio streams each yielded value
    # into `out` as it arrives.
    btn.click(
        fn=analyze_stream,
        inputs=[txt, img, max_tokens],
        outputs=out,
    )

demo.launch()