ericjedha committed on
Commit
65a0071
·
verified ·
1 Parent(s): ebb5958

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -71
app.py CHANGED
@@ -1,89 +1,147 @@
1
  import gradio as gr
2
  from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
3
  from threading import Thread
4
- import torch
5
- import spaces
6
  import re
7
  import time
 
 
 
 
8
 
9
- # Utilisation de la version ultra-légère pour garantir la stabilité
10
- MODEL_ID = "HuggingFaceTB/SmolVLM2-256M-Instruct"
11
 
12
- # Chargement du processeur et du modèle
13
- processor = AutoProcessor.from_pretrained(MODEL_ID)
14
- model = AutoModelForImageTextToText.from_pretrained(
15
- MODEL_ID,
16
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
17
- device_map="auto"
18
- )
19
 
20
- @spaces.GPU(duration=60)
21
- def model_inference(input_dict, history, max_tokens):
 
 
 
22
  text = input_dict["text"]
23
- files = input_dict.get("files", [])
24
-
25
- # Construction du contenu multimodal
26
  user_content = []
27
-
28
- # On ajoute d'abord les médias (images/vidéos)
29
- for file in files:
30
- if file.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
31
- user_content.append({"type": "image", "path": file})
32
- elif file.lower().endswith((".mp4", ".mov", ".avi", ".mkv")):
33
- user_content.append({"type": "video", "path": file})
34
-
35
- # On ajoute le texte
36
- if text.strip():
37
- user_content.append({"type": "text", "text": text})
38
-
39
- if not user_content:
40
- yield "Veuillez uploader une vidéo de chat ou poser une question."
41
- return
42
 
43
- # Structure du message pour SmolVLM2
44
- messages = [{"role": "user", "content": user_content}]
45
-
46
- # Préparation des inputs via le template
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  inputs = processor.apply_chat_template(
48
- messages,
49
- add_generation_prompt=True,
50
- tokenize=True,
51
- return_dict=True,
52
- return_tensors="pt"
53
- ).to(model.device)
54
-
55
- # Configuration du streamer pour l'affichage en direct
 
 
 
56
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
57
  generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
58
-
59
- # Lancement de la génération dans un thread séparé
60
  thread = Thread(target=model.generate, kwargs=generation_args)
61
  thread.start()
62
-
 
63
  buffer = ""
 
 
64
  for new_text in streamer:
65
- buffer += new_text
66
- yield buffer
67
-
68
- # Création de l'interface Chat
69
- demo = gr.ChatInterface(
70
- fn=model_inference,
71
- title="SmolVLM2: The Smollest Video Model Ever 📺",
72
- description="Analyse de vidéos et d'images avec SmolVLM2-256M.",
73
- textbox=gr.MultimodalTextbox(
74
- label="Query Input",
75
- file_types=["image", ".mp4"],
76
- file_count="multiple"
77
- ),
78
- stop_btn="Stop Generation",
79
- # On SUPPRIME 'multimodal=True' ici, car gr.MultimodalTextbox l'est par défaut
80
- cache_examples=False,
81
- additional_inputs=[
82
- gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")
83
- ],
84
- type="messages"
85
- )
86
-
87
- if __name__ == "__main__":
88
- # ss=False pour éviter les bugs de schéma JSON vus précédemment
89
- demo.launch(debug=True)
 
 
 
1
import gradio as gr
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
from threading import Thread
import os
import re
import time
import torch
import spaces
import subprocess

# Runtime install of flash-attn (common HF Spaces pattern):
# FLASH_ATTENTION_SKIP_CUDA_BUILD makes pip take the prebuilt wheel instead of
# compiling CUDA kernels on the Space.
# Fix: merge with os.environ — passing a bare dict as `env` REPLACES the whole
# environment (subprocess.run semantics), which drops PATH and can leave the
# shell unable to find `pip`.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

from io import BytesIO  # NOTE(review): unused in this file as shown — kept because other code may rely on it

# Load processor and model once at import time. The model is pinned to the
# first CUDA device with flash-attention-2 and bfloat16 weights.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda:0")
 
 
 
17
 
18
+
19
@spaces.GPU
def model_inference(input_dict, history, max_tokens):
    """Stream a SmolVLM2 reply for a (possibly multimodal) chat query.

    Args:
        input_dict: Gradio MultimodalTextbox payload — {"text": str, "files": [path, ...]}.
        history: Prior turns in Gradio "messages" format (may be empty).
        max_tokens: Upper bound on newly generated tokens.

    Yields:
        str: The growing partial reply, for live streaming in the UI.

    Raises:
        gr.Error: If the query text is empty.
    """
    text = input_dict["text"].strip()
    files = input_dict.get("files", [])

    # Bug fix: gr.Error is an exception type — it must be *raised*; the original
    # merely constructed it, so validation was a silent no-op. (The original
    # `images` list was always empty, making its second check dead code.)
    if text == "" and not files:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and files:
        raise gr.Error("Please input a text query along with the image(s).")

    def _media_items(paths):
        # Map uploaded file paths to chat-template media entries by extension.
        items = []
        for path in paths:
            if path.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                items.append({"type": "image", "path": path})
            elif path.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                items.append({"type": "video", "path": path})
        return items

    def _interleave(message_text, media_queue):
        # Split text on <image>/<video> placeholders and slot media in at each
        # placeholder; any media not referenced by a placeholder is appended at
        # the end. Consumes `media_queue`.
        content = []
        if "<image>" in message_text or "<video>" in message_text:
            for part in re.split(r'(<image>|<video>)', message_text):
                if part in ("<image>", "<video>") and media_queue:
                    content.append(media_queue.pop(0))
                elif part.strip():
                    content.append({"type": "text", "text": part.strip()})
        elif message_text:
            content.append({"type": "text", "text": message_text})
        content.extend(media_queue)
        media_queue.clear()
        return content

    if not history:
        # Fresh conversation: one user message built from this turn only.
        media_queue = _media_items(files)
        user_content = _interleave(text, media_queue)
        resulting_messages = [{"role": "user", "content": user_content}]
    else:
        resulting_messages = []
        user_content = []
        media_queue = []
        # First pass: collect media attached in earlier user turns (Gradio
        # represents file turns as tuples whose first element is the path).
        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple):
                file_name = hist["content"][0]
                if file_name.endswith((".png", ".jpg", ".jpeg")):
                    media_queue.append({"type": "image", "path": file_name})
                elif file_name.endswith(".mp4"):
                    media_queue.append({"type": "video", "path": file_name})
        # Second pass: rebuild alternating user/assistant messages.
        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], str):
                for part in re.split(r'(<image>|<video>)', hist["content"]):
                    if part in ("<image>", "<video>") and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part.strip():
                        user_content.append({"type": "text", "text": part.strip()})
            elif hist["role"] == "assistant":
                resulting_messages.append({"role": "user", "content": user_content})
                resulting_messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist["content"]}],
                })
                user_content = []
        # Bug fix: the original history path never added the CURRENT turn, so
        # follow-up queries were silently dropped. Append any pending user
        # content plus this turn's text and files as the final user message.
        current_media = _media_items(files)
        user_content.extend(_interleave(text, current_media))
        resulting_messages.append({"role": "user", "content": user_content})

    print("resulting_messages", resulting_messages)
    inputs = processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Generate in a background thread and stream tokens out as they arrive.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."  # placeholder shown while the first tokens are produced
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)  # small pause smooths the streaming UI updates
        yield buffer
128
# Sample prompts shown under the chat box; each spec pairs a text query with
# the file(s) it refers to.
_example_specs = [
    ("Where do the severe droughts happen according to this diagram?", ["example_images/examples_weather_events.png"]),
    ("What art era this artpiece <image> and this artpiece <image> belong to?", ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]),
    ("Describe this image.", ["example_images/mosque.jpg"]),
    ("When was this purchase made and how much did it cost?", ["example_images/fiche.jpg"]),
    ("What is the date in this document?", ["example_images/document.jpg"]),
    ("What is happening in the video?", ["example_images/short.mp4"]),
]
examples = [[{"text": query, "files": file_list}] for query, file_list in _example_specs]

# Input widgets are built up-front so the ChatInterface call stays readable.
query_box = gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple")
token_slider = gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")

# Chat UI wired to the streaming inference function above.
demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM2: The Smollest Video Model Ever 📺",
    description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
    examples=examples,
    textbox=query_box,
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
    additional_inputs=[token_slider],
    type="messages",
)

demo.launch(debug=True)