hasanbasbunar committed on
Commit 2b120f6 · verified · 1 Parent(s): a7057b9

Create app.py

Files changed (1): app.py (+333, −0)
app.py ADDED
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import matplotlib
import tempfile
import os
import gc
import spaces  # <--- REQUIRED IMPORT FOR ZEROGPU
from transformers import (
    Sam3Model, Sam3Processor,
    Sam3TrackerModel, Sam3TrackerProcessor,
    Sam3VideoModel, Sam3VideoProcessor,
    Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
)

# --- CONFIGURATION ---
MODELS = {}
# On ZeroGPU we can hardcode "cuda": the decorator guarantees a GPU is attached.
device = "cuda"
print("🖥️ ZeroGPU configuration active.")

def cleanup_memory():
    """Force VRAM cleanup."""
    if MODELS:
        print("🧹 Preventive memory cleanup...")
        MODELS.clear()
        gc.collect()
        torch.cuda.empty_cache()

def get_model(model_type):
    """Load a model. On ZeroGPU this runs inside the @spaces.GPU-decorated function."""
    if model_type in MODELS:
        return MODELS[model_type]

    # Even with 70 GB of VRAM on ZeroGPU, swapping models out is good practice for stability.
    cleanup_memory()

    print(f"⏳ Loading {model_type} on H200...")

    try:
        if model_type == "sam3_image_text":
            model = Sam3Model.from_pretrained("facebook/sam3").to(device)
            processor = Sam3Processor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)

        elif model_type == "sam3_image_tracker":
            model = Sam3TrackerModel.from_pretrained("facebook/sam3").to(device)
            processor = Sam3TrackerProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)

        elif model_type == "sam3_video_text":
            # The H200 could afford float32, but bfloat16 remains faster.
            model = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
            processor = Sam3VideoProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)

        elif model_type == "sam3_video_tracker":
            model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
            processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")
            MODELS[model_type] = (model, processor)

        print(f"✅ {model_type} loaded.")
        return MODELS[model_type]

    except Exception as e:
        print(f"❌ Loading error: {e}")
        cleanup_memory()
        raise
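
# Usage note: every handler below calls get_model(...) from inside its
# @spaces.GPU-decorated body, so weights are only moved to CUDA once a GPU
# has actually been attached to the request.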

# --- UTILITIES (no GPU needed here) ---

def overlay_masks(image, masks, scores=None, alpha=0.5):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGBA")
    if masks is None or len(masks) == 0:
        return image
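    # Normalize mask shapes: accept torch or numpy input and squeeze
    # (batch, n, H, W) or (1, H, W) down to a stack of (H, W) binary masks.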
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    masks = masks.astype(np.uint8)
    if masks.ndim == 4:
        masks = masks[0]
    if masks.ndim == 3 and masks.shape[0] == 1:
        masks = masks[0]
    n_masks = masks.shape[0] if masks.ndim == 3 else 1
    if masks.ndim == 2:
        masks = [masks]
        n_masks = 1
    try:
        cmap = matplotlib.colormaps["rainbow"].resampled(max(n_masks, 1))
    except AttributeError:
        # Older matplotlib lacks both the colormaps registry and Colormap.resampled,
        # so resample via the lut argument of get_cmap instead.
        import matplotlib.cm as cm
        cmap = cm.get_cmap("rainbow", max(n_masks, 1))
    colors = [tuple(int(c * 255) for c in cmap(i)[:3]) for i in range(n_masks)]
    overlay_layer = Image.new("RGBA", image.size, (0, 0, 0, 0))
    for i, mask in enumerate(masks):
        mask_img = Image.fromarray((mask * 255).astype(np.uint8))
        if mask_img.size != image.size:
            mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
        color = colors[i]
        color_layer = Image.new("RGBA", image.size, color + (0,))
        mask_alpha = mask_img.point(lambda v: int(v * alpha) if v > 0 else 0)
        color_layer.putalpha(mask_alpha)
        overlay_layer = Image.alpha_composite(overlay_layer, color_layer)
    return Image.alpha_composite(image, overlay_layer).convert("RGB")

def get_first_frame(video_path):
    """Extract the first frame of a video so the user can click on it."""
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if ret:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None

# --- HELPERS FOR ZEROGPU DYNAMIC DURATION ---

def compute_duration_text(video_path, text_prompt, max_frames, timeout_seconds):
    return timeout_seconds

def compute_duration_tracker(video_path, first_frame_click, max_frames, timeout_seconds):
    return timeout_seconds
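
# ZeroGPU accepts a callable for `duration`: it is invoked with the same
# arguments as the decorated function and must return the number of GPU
# seconds to reserve. That is why these helpers mirror the handler signatures
# and simply forward the user's slider choice.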

# --- LOGIC WITH ZEROGPU DECORATORS ---

@spaces.GPU
def process_image_text(image, text_prompt, threshold, mask_threshold):
    if image is None or not text_prompt:
        return image, "Please provide an image and a text prompt."
    try:
        model, processor = get_model("sam3_image_text")
        inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_instance_segmentation(
            outputs, threshold=threshold, mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist()
        )[0]
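        # `results` holds per-instance "masks" and "scores", with the masks
        # already mapped back to the input resolution via target_sizes.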
        final_img = overlay_masks(image, results["masks"])
        info = f"Objects found: {len(results['masks'])}\nScores: {results['scores'].cpu().numpy()}"
        return final_img, info
    except Exception as e:
        return image, f"Error: {str(e)}"

@spaces.GPU
def process_image_tracker(image, evt: gr.SelectData, points_state, labels_state, multimask):
    if image is None:
        return image, [], []
    x, y = evt.index
    if points_state is None:
        points_state = []
        labels_state = []
    points_state.append([x, y])
    labels_state.append(1)
    try:
        model, processor = get_model("sam3_image_tracker")
        input_points = [[points_state]]
        input_labels = [[labels_state]]
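        # Point nesting: (image batch) -> (objects) -> (points), so one object
        # with N accumulated clicks becomes [[[x1, y1], ..., [xN, yN]]].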
        inputs = processor(images=image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs, multimask_output=multimask)
        masks = processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
        masks_to_show = masks[0]
        if multimask and masks_to_show.shape[0] > 1:
            # With multimask output, keep only the candidate with the best IoU score.
            scores = outputs.iou_scores.cpu().numpy()[0, 0]
            best_idx = np.argmax(scores)
            masks_to_show = masks_to_show[best_idx:best_idx + 1]
        final_img = overlay_masks(image, masks_to_show)
        draw = Image.fromarray(np.array(final_img))
        d = ImageDraw.Draw(draw)
        for pt in points_state:
            d.ellipse((pt[0] - 5, pt[1] - 5, pt[0] + 5, pt[1] + 5), fill="red", outline="white")
        return draw, points_state, labels_state
    except Exception as e:
        print(f"Tracker Error: {e}")
        return image, points_state, labels_state

@spaces.GPU(duration=compute_duration_text)
def process_video_text(video_path, text_prompt, max_frames, timeout_seconds):
    if not video_path or not text_prompt:
        return None, "Missing video or prompt."
    try:
        model, processor = get_model("sam3_video_text")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 fps if the container reports none
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (max_frames > 0 and frame_count >= max_frames):
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
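        # The slider-driven frame cap keeps decoding plus propagation within
        # the GPU window reserved by compute_duration_text.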
        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
        inference_session = processor.add_text_prompt(inference_session=inference_session, text=text_prompt)
        # NamedTemporaryFile avoids the race condition of the deprecated tempfile.mktemp.
        output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for model_outputs in model.propagate_in_video_iterator(inference_session=inference_session, max_frame_num_to_track=len(frames)):
            processed_outputs = processor.postprocess_outputs(inference_session, model_outputs)
            frame_idx = model_outputs.frame_idx
            orig_frame = Image.fromarray(frames[frame_idx])
            if 'masks' in processed_outputs:
                masks = processed_outputs['masks']
                if masks.ndim == 4:
                    masks = masks.squeeze(1)
                res_frame = overlay_masks(orig_frame, masks)
            else:
                res_frame = orig_frame
            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Done!"
    except Exception as e:
        return None, f"Error: {str(e)}"

@spaces.GPU(duration=compute_duration_tracker)
def process_video_tracker(video_path, first_frame_click, max_frames, timeout_seconds):
    if not video_path or not first_frame_click:
        return None, "Please click on the first frame."
    if hasattr(first_frame_click, 'index'):
        x, y = first_frame_click.index
    else:
        return None, "Click error."
    try:
        model, processor = get_model("sam3_video_tracker")
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 fps if the container reports none
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or (max_frames > 0 and frame_count >= max_frames):
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_count += 1
        cap.release()
        inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
        processor.add_inputs_to_inference_session(inference_session=inference_session, frame_idx=0, obj_ids=1, input_points=[[[[x, y]]]], input_labels=[[[1]]])
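        # Point nesting for the video API: (batch) -> (objects) -> (points) ->
        # (x, y); a single click on one object is therefore [[[[x, y]]]].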
        output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        model(inference_session=inference_session, frame_idx=0)  # Lock the prompt onto frame 0 before propagating

        for sam3_out in model.propagate_in_video_iterator(inference_session):
            masks = processor.post_process_masks([sam3_out.pred_masks], original_sizes=[[height, width]], binarize=True)[0]
            frame_idx = sam3_out.frame_idx
            orig_frame = Image.fromarray(frames[frame_idx])
            res_frame = overlay_masks(orig_frame, masks[:, 0, :, :])
            out.write(cv2.cvtColor(np.array(res_frame), cv2.COLOR_RGB2BGR))
        out.release()
        return output_path, "Tracking Finished!"
    except Exception as e:
        print(f"Video Tracker Error: {e}")
        return None, f"Fatal Error: {str(e)}"

# --- GRADIO INTERFACE ---

with gr.Blocks(title="SAM3 Ultimate Suite", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 SAM 3: Unified Promptable Segmentation")
    gr.Markdown("This application lets you use **all the core capabilities of the SAM 3 model** for segmenting images and videos with text or visual prompts.")

    with gr.Tabs():
        # TAB 1: IMAGE + TEXT
        with gr.Tab("🖼️ Image - Text Prompt"):
            gr.Markdown("### Segment objects by description\nSimply upload an image and type the name of the objects you want to detect (e.g., 'cat', 'wheel', 'person'). The model will find all instances.")
            with gr.Row():
                with gr.Column():
                    i1_input = gr.Image(type="pil", label="Input Image")
                    i1_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: cat, wheel, person")
                    with gr.Accordion("Advanced Settings", open=False):
                        i1_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Confidence Threshold")
                        i1_mask_thresh = gr.Slider(0.0, 1.0, value=0.5, label="Mask Threshold")
                    i1_btn = gr.Button("Segment Image", variant="primary")
                with gr.Column():
                    i1_output = gr.Image(type="pil", label="Result")
                    i1_info = gr.Textbox(label="Details", lines=2)

            i1_btn.click(process_image_text, [i1_input, i1_text, i1_thresh, i1_mask_thresh], [i1_output, i1_info])

        # TAB 2: IMAGE + TRACKER
        with gr.Tab("🖱️ Image - Visual Tracker"):
            gr.Markdown("### Segment objects by clicking\nUpload an image and click on the object you wish to segment. You can click multiple times to refine the selection.")
            with gr.Row():
                with gr.Column():
                    i2_input = gr.Image(type="pil", label="Input Image (Click to add points)", interactive=True)
                    i2_multimask = gr.Checkbox(label="Return Multiple Masks (Handles Ambiguity)", value=False)
                    i2_clear = gr.Button("Reset Points")
                    points_state = gr.State([])
                    labels_state = gr.State([])
                with gr.Column():
                    i2_output = gr.Image(type="pil", label="Interactive Result")

            i2_input.select(process_image_tracker, [i2_input, points_state, labels_state, i2_multimask], [i2_output, points_state, labels_state])
            i2_clear.click(lambda: (None, [], []), outputs=[i2_output, points_state, labels_state])

        # TAB 3: VIDEO + TEXT
        with gr.Tab("🎥 Video - Text Prompt"):
            gr.Markdown("### Track objects in video by description\nUpload a video and describe what to track. The model will detect and segment all matching objects throughout the video.")
            with gr.Row():
                with gr.Column():
                    v3_input = gr.Video(label="Input Video", format="mp4")
                    v3_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
                    v3_max_frames = gr.Slider(10, 300, value=50, step=10, label="Max Frames to Process")
                    # Duration selector feeding the ZeroGPU dynamic-duration helper
                    v3_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
                    v3_btn = gr.Button("Start Video Segmentation", variant="primary")
                with gr.Column():
                    v3_output = gr.Video(label="Result Video")
                    v3_status = gr.Textbox(label="Status")
            # v3_duration is passed as an input so compute_duration_text can read it
            v3_btn.click(process_video_text, [v3_input, v3_text, v3_max_frames, v3_duration], [v3_output, v3_status])

        # TAB 4: VIDEO + TRACKER
        with gr.Tab("🎯 Video - Visual Tracker"):
            gr.Markdown("### Track a specific object in video\n1. Upload a video.\n2. Wait for the first frame to appear below.\n3. Click on the object you want to track.\n4. Start tracking.")
            with gr.Row():
                with gr.Column():
                    v4_input = gr.Video(label="Input Video", format="mp4")
                    v4_frame0 = gr.Image(label="First Frame (Click the object here)", interactive=True)
                    v4_max_frames = gr.Slider(10, 300, value=50, step=10, label="Max Frames to Process")
                    # Duration selector feeding the ZeroGPU dynamic-duration helper
                    v4_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
                    v4_btn = gr.Button("Start Object Tracking", variant="primary")
                with gr.Column():
                    v4_output = gr.Video(label="Result Video")
                    v4_status = gr.Textbox(label="Status")

            v4_input.change(get_first_frame, inputs=v4_input, outputs=v4_frame0)
            click_state = gr.State()
            def save_click(evt: gr.SelectData):
                return evt
            v4_frame0.select(save_click, None, click_state)
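            # The raw SelectData event is stashed in click_state so the
            # tracking button can read the clicked coordinates later.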
            # v4_duration is passed as an input so compute_duration_tracker can read it
            v4_btn.click(process_video_tracker, [v4_input, click_state, v4_max_frames, v4_duration], [v4_output, v4_status])

if __name__ == "__main__":
    # `theme` belongs to gr.Blocks (set above), not to launch().
    demo.launch(share=False, debug=True)
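
# Local-run note (assumption): outside a Hugging Face Space the `spaces`
# decorators are effectively no-ops, and since `device` is hardcoded to
# "cuda", running this file locally still requires a CUDA-capable GPU.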