File size: 2,486 Bytes
bd04204 0aa2638 d8a0153 0aa2638 d8a0153 bd04204 0aa2638 d8a0153 0aa2638 a1cae3c 0aa2638 d8a0153 0aa2638 bd04204 0aa2638 d8a0153 0aa2638 bd04204 0aa2638 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
import torch
from threading import Thread
from transformers import (
SmolVLMProcessor,
AutoModelForImageTextToText,
TextIteratorStreamer,
)
# ----------------------------------------------------------------------
# Model initialisation (runs once at import time)
# ----------------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_ID = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

# Processor handles chat templating and image preprocessing for the model.
processor = SmolVLMProcessor.from_pretrained(MODEL_ID)

# bfloat16 on GPU for speed/memory; full float32 on CPU.
_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = (
    AutoModelForImageTextToText
    .from_pretrained(MODEL_ID, torch_dtype=_dtype)
    .to(DEVICE)
    .eval()
)
# ======================
# STREAMING INFERENCE
# ======================
def analyze_stream(text, image, max_tokens):
    """Stream a SmolVLM2 answer for a text prompt and/or an image.

    Args:
        text: User question/description (may be empty or whitespace).
        image: Filepath to an image from the Gradio component, or None.
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        The accumulated response text, growing token by token.
    """
    if image is None and not text.strip():
        # BUGFIX: this function is a generator, so `return "msg"` would be
        # swallowed as a StopIteration value and never reach the UI.
        # The message must be yielded instead.
        yield "❌ Veuillez fournir un texte ou une image."
        return
    content = []
    if image:
        content.append({"type": "image", "path": image})
    if text.strip():
        content.append({"type": "text", "text": text})
    messages = [{"role": "user", "content": content}]
    # BUGFIX: return_dict=True is required — without it apply_chat_template
    # returns a bare input_ids tensor, so the `**inputs` expansion below
    # fails and pixel_values for image inputs are lost.  The dtype cast
    # only affects floating tensors (pixel_values); input_ids stay integer.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE, dtype=model.dtype)
    streamer = TextIteratorStreamer(
        processor,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    # Run generation in a background thread so this generator can consume
    # the streamer as tokens arrive.  do_sample=False selects greedy
    # decoding; temperature is meaningless for greedy and thus not passed.
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=False,
        ),
    ).start()
    output = ""
    for token in streamer:
        output += token
        yield output
# ======================
# UI GRADIO
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## ⚡ SmolVLM2 – Analyse Temps Réel")
    with gr.Row():
        # Left column: user inputs and the trigger button.
        with gr.Column():
            question_box = gr.Textbox(label="Question / Description", lines=3)
            image_input = gr.Image(type="filepath", label="Image")
            token_slider = gr.Slider(
                50, 400, value=200, step=50, label="Max Tokens"
            )
            run_button = gr.Button("🚀 Analyser", variant="primary")
        # Right column: the streamed model answer.
        with gr.Column():
            answer_box = gr.Textbox(label="Réponse en Temps Réel", lines=14)
    # analyze_stream is a generator, so Gradio streams each yielded
    # partial answer into the output textbox.
    run_button.click(
        fn=analyze_stream,
        inputs=[question_box, image_input, token_slider],
        outputs=answer_box,
    )
demo.launch()
|