|
|
import gradio as gr |
|
|
import torch |
|
|
from threading import Thread |
|
|
from transformers import ( |
|
|
SmolVLMProcessor, |
|
|
AutoModelForImageTextToText, |
|
|
TextIteratorStreamer, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on GPU when available; both the model and the tokenized inputs are
# moved to this device before generation.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SmolVLM2 2.2B instruction-tuned checkpoint from the Hugging Face Hub.
MODEL_ID = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

# The processor bundles image preprocessing and the chat-template tokenizer
# used by apply_chat_template() below.
processor = SmolVLMProcessor.from_pretrained(MODEL_ID)

# bfloat16 on CUDA halves memory; float32 on CPU where bf16 is typically
# slow or unsupported. .eval() switches off dropout/训练-mode layers for
# inference. NOTE: loading happens at import time and may download weights.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(DEVICE).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_stream(text, image, max_tokens):
    """Stream a SmolVLM2 answer for the given text and/or image.

    Args:
        text: User question/description; may be empty when an image is given.
        image: Filepath to an image (Gradio ``type="filepath"``) or ``None``.
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        The accumulated response text after each new token, so Gradio can
        update the output textbox live.
    """
    if image is None and not text.strip():
        # BUG FIX: this function is a generator (it contains ``yield``), so a
        # bare ``return "msg"`` would bury the message in StopIteration and
        # the user would see nothing. Yield the error, then stop.
        yield "❌ Veuillez fournir un texte ou une image."
        return

    # Build one multimodal user turn: image part first (if any), then text.
    content = []
    if image:
        content.append({"type": "image", "path": image})
    if text.strip():
        content.append({"type": "text", "text": text})

    messages = [{"role": "user", "content": content}]

    # BUG FIX: return_dict=True is required so ``inputs`` is a mapping
    # (input_ids, attention_mask, pixel_values, ...) that can be splatted
    # into model.generate(); without it apply_chat_template returns only the
    # input_ids tensor and ``**inputs`` below would fail.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    streamer = TextIteratorStreamer(
        processor,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Run generation in a background thread so this generator can consume
    # the streamer concurrently. Greedy decoding: do_sample=False; the
    # original temperature=0.0 is ignored under greedy decoding and only
    # triggers a transformers warning, so it is omitted.
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=False,
        ),
    ).start()

    # Accumulate and re-yield the growing answer for the live textbox.
    output = ""
    for token in streamer:
        output += token
        yield output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI ---------------------------------------------------------
# Two-column layout: inputs (question, image, token budget) on the left,
# the streamed answer on the right. analyze_stream is a generator, which is
# how Gradio implements token-by-token streaming into the output textbox.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ⚡ SmolVLM2 – Analyse Temps Réel")

    with gr.Row():
        with gr.Column():
            # Free-text prompt; may be left empty if an image is provided.
            txt = gr.Textbox(
                label="Question / Description",
                lines=3,
            )
            # type="filepath" hands analyze_stream a path string, matching
            # the {"type": "image", "path": ...} chat-template entry it builds.
            img = gr.Image(type="filepath", label="Image")
            max_tokens = gr.Slider(
                50, 400, value=200, step=50, label="Max Tokens"
            )
            btn = gr.Button("🚀 Analyser", variant="primary")

        with gr.Column():
            out = gr.Textbox(
                label="Réponse en Temps Réel",
                lines=14,
            )

    # Wire the button to the streaming generator; each yielded string
    # replaces the textbox content.
    btn.click(
        fn=analyze_stream,
        inputs=[txt, img, max_tokens],
        outputs=out,
    )

# NOTE(review): launch() runs at import time; consider guarding with
# ``if __name__ == "__main__":`` so the module can be imported without
# starting the server — confirm no caller relies on the current behavior.
demo.launch()
|
|
|