hasanbasbunar committed on
Commit
5e513b8
·
verified ·
1 Parent(s): 8f0a7b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -50
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  import cv2
5
- from PIL import Image
6
  import matplotlib.pyplot as plt
7
  import matplotlib
8
  import tempfile
@@ -115,12 +115,28 @@ def get_first_frame(video_path):
115
  return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
116
  return None
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # --- HELPERS POUR DUREE DYNAMIQUE ZEROGPU ---
119
 
120
  def compute_duration_text(video_path, text_prompt, max_frames, timeout_seconds):
121
  return timeout_seconds
122
 
123
- def compute_duration_tracker(video_path, x, y, max_frames, timeout_seconds):
124
  return timeout_seconds
125
 
126
  # --- LOGIQUE AVEC DÉCORATEURS ZEROGPU ---
@@ -144,11 +160,10 @@ def process_image_text(image, text_prompt, threshold, mask_threshold):
144
  except Exception as e:
145
  return image, f"Error: {str(e)}"
146
 
147
- # MODIFICATION IMPORTANTE : Cette fonction GPU ne prend plus 'evt', mais 'x' et 'y' directement
148
  @spaces.GPU
149
  def process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask):
150
  if image is None: return image, [], []
151
- # x et y sont maintenant des entiers simples
152
  if points_state is None: points_state = []; labels_state = []
153
  points_state.append([x, y])
154
  labels_state.append(1)
@@ -166,21 +181,18 @@ def process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask
166
  best_idx = np.argmax(scores)
167
  masks_to_show = masks_to_show[best_idx:best_idx+1]
168
  final_img = overlay_masks(image, masks_to_show)
169
- draw = Image.fromarray(np.array(final_img))
170
- import PIL.ImageDraw
171
- d = PIL.ImageDraw.Draw(draw)
172
- for pt in points_state:
173
- d.ellipse((pt[0]-5, pt[1]-5, pt[0]+5, pt[1]+5), fill="red", outline="white")
174
- return draw, points_state, labels_state
175
  except Exception as e:
176
  print(f"Tracker Error: {e}")
177
  return image, points_state, labels_state
178
 
179
- # WRAPPER CPU POUR IMAGE TRACKER : Extrait les données avant d'appeler le GPU
180
  def process_image_tracker_wrapper(image, evt: gr.SelectData, points_state, labels_state, multimask):
181
  if evt is None: return image, points_state, labels_state
182
  x, y = evt.index
183
- # Appel de la fonction GPU avec des types simples
184
  return process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask)
185
 
186
 
@@ -220,9 +232,36 @@ def process_video_text(video_path, text_prompt, max_frames, timeout_seconds):
220
  return output_path, "Done!"
221
  except Exception as e: return None, f"Error: {str(e)}"
222
 
223
- # MODIFICATION IMPORTANTE : Cette fonction GPU ne prend plus 'first_frame_click' (objet complexe) mais x, y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  @spaces.GPU(duration=compute_duration_tracker)
225
- def process_video_tracker_gpu(video_path, x, y, max_frames, timeout_seconds):
 
 
226
  try:
227
  model, processor = get_model("sam3_video_tracker")
228
  cap = cv2.VideoCapture(video_path)
@@ -238,7 +277,18 @@ def process_video_tracker_gpu(video_path, x, y, max_frames, timeout_seconds):
238
  frame_count += 1
239
  cap.release()
240
  inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
241
- processor.add_inputs_to_inference_session(inference_session=inference_session, frame_idx=0, obj_ids=1, input_points=[[[[x, y]]]], input_labels=[[[1]]])
 
 
 
 
 
 
 
 
 
 
 
242
  output_path = tempfile.mktemp(suffix=".mp4")
243
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
244
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
@@ -257,25 +307,6 @@ def process_video_tracker_gpu(video_path, x, y, max_frames, timeout_seconds):
257
  print(f"Video Tracker Error: {e}")
258
  return None, f"Fatal Error: {str(e)}"
259
 
260
- # WRAPPER CPU POUR VIDEO TRACKER
261
- def process_video_tracker_wrapper(video_path, first_frame_click, max_frames, timeout_seconds):
262
- if not video_path or not first_frame_click: return None, "Please click on the first frame."
263
- # Extraction des données simples ici, sur le CPU
264
- if hasattr(first_frame_click, 'index'):
265
- x, y = first_frame_click.index
266
- else:
267
- return None, "Click error."
268
-
269
- # Appel de la fonction GPU avec des entiers
270
- return process_video_tracker_gpu(video_path, x, y, max_frames, timeout_seconds)
271
-
272
- # NOUVEAU WRAPPER POUR L'AUTO-START (Select Event)
273
- def video_select_trigger(video_path, max_frames, duration, evt: gr.SelectData):
274
- # On lance le traitement directement avec l'event du clic
275
- output_video, status = process_video_tracker_wrapper(video_path, evt, max_frames, duration)
276
- # On retourne aussi l'event pour mettre à jour le state "click_state" au cas où
277
- return output_video, status, evt
278
-
279
  # --- INTERFACE GRADIO ---
280
 
281
  with gr.Blocks(title="SAM3 Ultimate Suite") as demo:
@@ -313,7 +344,6 @@ with gr.Blocks(title="SAM3 Ultimate Suite") as demo:
313
  with gr.Column():
314
  i2_output = gr.Image(type="pil", label="Interactive Result")
315
 
316
- # APPEL DU WRAPPER CPU, PAS DE LA FONCTION GPU DIRECTEMENT
317
  i2_input.select(process_image_tracker_wrapper, [i2_input, points_state, labels_state, i2_multimask], [i2_output, points_state, labels_state])
318
  i2_clear.click(lambda: (None, [], []), outputs=[i2_output, points_state, labels_state])
319
 
@@ -325,46 +355,60 @@ with gr.Blocks(title="SAM3 Ultimate Suite") as demo:
325
  v3_input = gr.Video(label="Input Video", format="mp4")
326
  v3_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
327
  v3_max_frames = gr.Slider(10, 300, value=50, step=10, label="Max Frames to Process")
328
- # Ajout choix durée
329
  v3_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
330
  v3_btn = gr.Button("Start Video Segmentation", variant="primary")
331
  with gr.Column():
332
  v3_output = gr.Video(label="Result Video")
333
  v3_status = gr.Textbox(label="Status")
334
- # Ajout v3_duration aux inputs
335
  v3_btn.click(process_video_text, [v3_input, v3_text, v3_max_frames, v3_duration], [v3_output, v3_status])
336
 
337
  # TAB 4 : VIDEO + TRACKER
338
  with gr.Tab("🎯 Video - Visual Tracker"):
339
- gr.Markdown("### Track a specific object in video\n1. Upload a video.\n2. Wait for the first frame to appear below.\n3. Click on the object you want to track.\n4. Processing starts automatically.")
340
  with gr.Row():
341
  with gr.Column():
342
  v4_input = gr.Video(label="Input Video", format="mp4")
343
- v4_frame0 = gr.Image(label="First Frame (Click the object here)", interactive=True)
344
  v4_max_frames = gr.Slider(10, 300, value=50, step=10, label="Max Frames to Process")
345
  v4_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
346
  with gr.Row():
347
  v4_btn = gr.Button("Start Object Tracking", variant="primary")
348
- v4_clear = gr.Button("Reset Tracking") # Nouveau bouton Reset
 
 
 
349
  with gr.Column():
350
  v4_output = gr.Video(label="Result Video")
351
  v4_status = gr.Textbox(label="Status")
352
 
353
- v4_input.change(get_first_frame, inputs=v4_input, outputs=v4_frame0)
354
- click_state = gr.State()
 
 
 
 
 
 
 
 
 
355
 
356
- # AUTO-START : Le clic déclenche directement le tracking + sauvegarde l'état
357
  v4_frame0.select(
358
- video_select_trigger,
359
- inputs=[v4_input, v4_max_frames, v4_duration],
360
- outputs=[v4_output, v4_status, click_state]
361
  )
362
 
363
- # Bouton manuel (utilise l'état sauvegardé par le clic)
364
- v4_btn.click(process_video_tracker_wrapper, [v4_input, click_state, v4_max_frames, v4_duration], [v4_output, v4_status])
365
 
366
- # Bouton Reset
367
- v4_clear.click(lambda: (None, "", None), outputs=[v4_output, v4_status, click_state])
 
 
 
 
368
 
369
  if __name__ == "__main__":
370
  demo.launch(share=False, debug=True, theme=gr.themes.Soft())
 
2
  import torch
3
  import numpy as np
4
  import cv2
5
+ from PIL import Image, ImageDraw
6
  import matplotlib.pyplot as plt
7
  import matplotlib
8
  import tempfile
 
115
  return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
116
  return None
117
 
118
def draw_points_on_image(image, points, radius=5):
    """Return a copy of ``image`` with a red dot drawn at every point.

    Gives the user visual feedback for the prompt points clicked so far.

    Args:
        image: A PIL image or a NumPy array; arrays are converted with
            ``Image.fromarray``. ``None`` is returned unchanged.
        points: Iterable of ``(x, y)`` pixel coordinates. ``None`` is
            treated as "no points" and yields an unmarked copy.
        radius: Dot radius in pixels; the default keeps the original
            fixed size of 5.

    Returns:
        A copy of the image with the dots drawn on it, or ``None`` when
        no image was given. The input image is never modified.
    """
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    if points is None:
        points = ()

    # Draw on a copy so the caller's image stays untouched.
    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)
    for x, y in points:
        draw.ellipse(
            (x - radius, y - radius, x + radius, y + radius),
            fill="red",
            outline="white",
        )
    return draw_img
133
+
134
  # --- HELPERS POUR DUREE DYNAMIQUE ZEROGPU ---
135
 
136
  def compute_duration_text(video_path, text_prompt, max_frames, timeout_seconds):
137
  return timeout_seconds
138
 
139
def compute_duration_tracker(video_path, points_state, labels_state, max_frames, timeout_seconds):
    """Return the GPU time budget, in seconds, for a video tracking run.

    Used as the ``duration=`` callable of the ``@spaces.GPU`` decorator on
    ``process_video_tracker_gpu``, so the parameter list matches that
    function. Only ``timeout_seconds`` is used; the rest are ignored.
    """
    return timeout_seconds
141
 
142
  # --- LOGIQUE AVEC DÉCORATEURS ZEROGPU ---
 
160
  except Exception as e:
161
  return image, f"Error: {str(e)}"
162
 
163
+ # Image Tracker avec Multi-points
164
  @spaces.GPU
165
  def process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask):
166
  if image is None: return image, [], []
 
167
  if points_state is None: points_state = []; labels_state = []
168
  points_state.append([x, y])
169
  labels_state.append(1)
 
181
  best_idx = np.argmax(scores)
182
  masks_to_show = masks_to_show[best_idx:best_idx+1]
183
  final_img = overlay_masks(image, masks_to_show)
184
+
185
+ # Dessiner les points
186
+ final_img = draw_points_on_image(final_img, points_state)
187
+
188
+ return final_img, points_state, labels_state
 
189
  except Exception as e:
190
  print(f"Tracker Error: {e}")
191
  return image, points_state, labels_state
192
 
 
193
  def process_image_tracker_wrapper(image, evt: gr.SelectData, points_state, labels_state, multimask):
194
  if evt is None: return image, points_state, labels_state
195
  x, y = evt.index
 
196
  return process_image_tracker_gpu(image, x, y, points_state, labels_state, multimask)
197
 
198
 
 
232
  return output_path, "Done!"
233
  except Exception as e: return None, f"Error: {str(e)}"
234
 
235
+ # --- VIDEO TRACKER MULTI-POINT ---
236
+
237
+ # Fonction CPU pour ajouter un point VISUELLEMENT (sans appeler le GPU)
238
def add_point_video_preview(video_path, evt: gr.SelectData, points_state, labels_state):
    """Record a clicked point and redraw the first-frame preview (no GPU call).

    Args:
        video_path: Path of the uploaded video; falsy means no video yet.
        evt: Gradio select event; ``evt.index`` holds the clicked ``(x, y)``.
        points_state: ``[x, y]`` points accumulated so far, or ``None``.
        labels_state: Labels accumulated so far, or ``None``.

    Returns:
        ``(preview_image, points, labels)``. The preview is the first frame
        with every point drawn on it, or ``None`` when there is no video or
        its first frame cannot be read (the states are then returned as
        received).
    """
    if not video_path:
        return None, points_state, labels_state

    # Reload the pristine first frame on every click so dots are always drawn
    # on a clean image. Possible optimisation: keep the frame in a gr.State.
    orig_frame = get_first_frame(video_path)
    if orig_frame is None:
        return None, points_state, labels_state

    # Work on copies rather than appending to the incoming state lists, and
    # normalise the two lists separately so a missing labels list cannot
    # crash the append below.
    if points_state is None:
        points, labels = [], []
    else:
        points = list(points_state)
        labels = list(labels_state) if labels_state is not None else []

    # Same guard as process_image_tracker_wrapper: no event, nothing to add.
    if evt is not None:
        x, y = evt.index
        points.append([x, y])
        # NOTE(review): 1 presumably marks a positive (foreground) click --
        # confirm against the SAM processor documentation.
        labels.append(1)

    # Draw ALL accumulated points on the original frame.
    preview_img = draw_points_on_image(Image.fromarray(orig_frame), points)
    return preview_img, points, labels
259
+
260
+
261
  @spaces.GPU(duration=compute_duration_tracker)
262
+ def process_video_tracker_gpu(video_path, points_state, labels_state, max_frames, timeout_seconds):
263
+ if not video_path or not points_state: return None, "Please click on the frame first."
264
+
265
  try:
266
  model, processor = get_model("sam3_video_tracker")
267
  cap = cv2.VideoCapture(video_path)
 
277
  frame_count += 1
278
  cap.release()
279
  inference_session = processor.init_video_session(video=frames, inference_device=device, dtype=torch.bfloat16)
280
+
281
+ # Envoi de TOUS les points accumulés
282
+ input_points = [[points_state]] # [Obj=1 [Points...]]
283
+ input_labels = [[labels_state]]
284
+
285
+ processor.add_inputs_to_inference_session(
286
+ inference_session=inference_session,
287
+ frame_idx=0,
288
+ obj_ids=1,
289
+ input_points=input_points,
290
+ input_labels=input_labels
291
+ )
292
  output_path = tempfile.mktemp(suffix=".mp4")
293
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
294
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
 
307
  print(f"Video Tracker Error: {e}")
308
  return None, f"Fatal Error: {str(e)}"
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # --- INTERFACE GRADIO ---
311
 
312
  with gr.Blocks(title="SAM3 Ultimate Suite") as demo:
 
344
  with gr.Column():
345
  i2_output = gr.Image(type="pil", label="Interactive Result")
346
 
 
347
  i2_input.select(process_image_tracker_wrapper, [i2_input, points_state, labels_state, i2_multimask], [i2_output, points_state, labels_state])
348
  i2_clear.click(lambda: (None, [], []), outputs=[i2_output, points_state, labels_state])
349
 
 
355
  v3_input = gr.Video(label="Input Video", format="mp4")
356
  v3_text = gr.Textbox(label="Text Prompt", placeholder="e.g.: person, car")
357
  v3_max_frames = gr.Slider(10, 300, value=50, step=10, label="Max Frames to Process")
 
358
  v3_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
359
  v3_btn = gr.Button("Start Video Segmentation", variant="primary")
360
  with gr.Column():
361
  v3_output = gr.Video(label="Result Video")
362
  v3_status = gr.Textbox(label="Status")
 
363
  v3_btn.click(process_video_text, [v3_input, v3_text, v3_max_frames, v3_duration], [v3_output, v3_status])
364
 
365
  # TAB 4 : VIDEO + TRACKER
366
  with gr.Tab("🎯 Video - Visual Tracker"):
367
+ gr.Markdown("### Track a specific object in video (Multi-point Support)\n1. Upload a video.\n2. Click on the object in the 'First Frame'. **You can click multiple times** to refine the selection.\n3. Click 'Start Object Tracking' when ready.")
368
  with gr.Row():
369
  with gr.Column():
370
  v4_input = gr.Video(label="Input Video", format="mp4")
371
+ v4_frame0 = gr.Image(label="First Frame (Click to add points)", interactive=True)
372
  v4_max_frames = gr.Slider(10, 300, value=50, step=10, label="Max Frames to Process")
373
  v4_duration = gr.Radio([60, 120], value=60, label="Max Processing Time (seconds)", info="Choose 60s for short clips, 120s for complex tasks")
374
  with gr.Row():
375
  v4_btn = gr.Button("Start Object Tracking", variant="primary")
376
+ v4_clear = gr.Button("Reset Tracking")
377
+ # États pour stocker les points multiples
378
+ v4_points_state = gr.State([])
379
+ v4_labels_state = gr.State([])
380
  with gr.Column():
381
  v4_output = gr.Video(label="Result Video")
382
  v4_status = gr.Textbox(label="Status")
383
 
384
+ # --- CORRECTION ICI ---
385
+ # Fusion des deux événements pour éviter le conflit (affichage vs reset)
386
def on_video_upload(video_path):
    """Refresh the first-frame preview and clear the points when the video changes.

    One handler does both jobs so the preview update and the state reset
    cannot conflict with each other.

    Args:
        video_path: Path of the newly selected video; falsy when the
            video component was cleared.

    Returns:
        ``(first_frame, [], [])`` for the preview image, points state and
        labels state. ``first_frame`` is ``None`` when there is no video.
    """
    # A cleared video component has no frame to read; skip the decode.
    if not video_path:
        return None, [], []
    return get_first_frame(video_path), [], []
392
+
393
+ v4_input.change(on_video_upload, inputs=v4_input, outputs=[v4_frame0, v4_points_state, v4_labels_state])
394
+ # ----------------------
395
 
396
+ # 1. Clic -> Ajout point visuel (CPU) + Mise à jour State
397
  v4_frame0.select(
398
+ add_point_video_preview,
399
+ inputs=[v4_input, v4_points_state, v4_labels_state],
400
+ outputs=[v4_frame0, v4_points_state, v4_labels_state]
401
  )
402
 
403
+ # 2. Bouton Start -> Envoi de la liste complète des points au GPU
404
+ v4_btn.click(process_video_tracker_gpu, [v4_input, v4_points_state, v4_labels_state, v4_max_frames, v4_duration], [v4_output, v4_status])
405
 
406
+ # 3. Bouton Reset -> Vide les points, recharge l'image vierge
407
def reset_tracking_view(video_path):
    """Clear the tracking result and points, and restore the clean first frame.

    Args:
        video_path: Path of the current video; falsy when none is loaded.

    Returns:
        ``(None, "", [], [], first_frame)`` for the output video, status
        text, points state, labels state and preview image.
        ``first_frame`` is ``None`` when no video is loaded.
    """
    # Without a video there is no frame to reload.
    img = get_first_frame(video_path) if video_path else None
    return None, "", [], [], img
410
+
411
+ v4_clear.click(reset_tracking_view, inputs=[v4_input], outputs=[v4_output, v4_status, v4_points_state, v4_labels_state, v4_frame0])
412
 
413
  if __name__ == "__main__":
414
  demo.launch(share=False, debug=True, theme=gr.themes.Soft())