usiddiquee786 commited on
Commit
7ec8a6e
·
verified ·
1 Parent(s): acf7711

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -58
app.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  import sys
8
  import importlib.util
9
  import time
 
10
 
11
  # Ensure models directory exists
12
  MODELS_DIR = Path("models")
@@ -47,7 +48,19 @@ def apply_patches():
47
  else:
48
  print("⚠️ tracker_patch.py not found, skipping patches")
49
 
50
- def run_tracking_single(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, temp_dir):
 
 
 
 
 
 
 
 
 
 
 
 
51
  """Run object tracking with a single tracking method."""
52
  try:
53
  # Prepare input
@@ -167,66 +180,64 @@ def run_tracking_single(video_file, yolo_model, reid_model, tracking_method, cla
167
  traceback.print_exc()
168
  return None, f"Error in {tracking_method}: {str(e)}", 0
169
 
170
- def benchmark_trackers(video_path, yolo_model, reid_model, class_ids, conf_threshold):
171
- """Run all tracking methods and return their results."""
172
  if not video_path:
173
- return (None, None, None, None, # Videos
174
- "Please upload a video file", "Please upload a video file",
175
- "Please upload a video file", "Please upload a video file", # Statuses
176
- "", "", "", "", # Times
177
- None) # DataFrame
178
 
179
- print(f"Benchmarking video: {video_path}")
180
  print(f"Parameters: model={yolo_model}, reid={reid_model}, classes={class_ids}, conf={conf_threshold}")
181
 
182
- tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
183
- results = [None] * 4
184
- statuses = ["Processing..."] * 4
185
- times = [0] * 4
186
 
187
  try:
188
- # Create a single temporary directory for all processes
189
  with tempfile.TemporaryDirectory() as temp_dir:
190
- # Process each tracking method sequentially
191
- for i, method in enumerate(tracking_methods):
192
- result, status, process_time = run_tracking_single(
193
- video_path,
194
- yolo_model,
195
- reid_model,
196
- method,
197
- class_ids,
198
- conf_threshold,
199
- temp_dir
200
- )
201
-
202
- # Store results
203
- if result:
204
- results[i] = os.path.abspath(result)
205
- statuses[i] = status
206
- times[i] = f"{process_time:.2f} seconds"
 
 
 
 
 
207
 
208
  except Exception as e:
209
  import traceback
210
  traceback.print_exc()
211
- # On failure, fill in the error for all pending results
212
- error_msg = f"Benchmark process error: {str(e)}"
213
- for i in range(len(tracking_methods)):
214
- if statuses[i] == "Processing...":
215
- statuses[i] = error_msg
216
-
217
- # Create the DataFrame data
218
- comparison_data = [[tracking_methods[i], times[i], "Success" if results[i] else "Failed"]
219
- for i in range(len(tracking_methods))]
 
 
220
 
221
- # Return all results as separate outputs
222
- return (results[0], results[1], results[2], results[3],
223
- statuses[0], statuses[1], statuses[2], statuses[3],
224
- times[0], times[1], times[2], times[3],
225
- comparison_data)
226
 
227
  # Available models
228
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
229
  reid_models = ["osnet_x0_25_msmt17.pt"]
 
230
 
231
  # Ensure dependencies and apply patches at startup
232
  ensure_dependencies()
@@ -235,7 +246,7 @@ apply_patches()
235
  # Create the Gradio interface
236
  with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
237
  gr.Markdown("# 🔍 YOLO Object Tracking Benchmark")
238
- gr.Markdown("Upload a video file to benchmark all four tracking methods. Processing may take several minutes depending on video length.")
239
 
240
  # Add class reference information
241
  with gr.Accordion("YOLO Class Reference", open=False):
@@ -259,6 +270,9 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
259
  [See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/datasets/coco128.yaml)
260
  """)
261
 
 
 
 
262
  with gr.Row():
263
  with gr.Column(scale=1):
264
  input_video = gr.Video(label="Input Video", sources=["upload"])
@@ -289,8 +303,18 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
289
  step=0.05,
290
  label="Confidence Threshold"
291
  )
292
-
293
- benchmark_btn = gr.Button("Benchmark All Trackers", variant="primary")
 
 
 
 
 
 
 
 
 
 
294
 
295
  # Output Tabs for each tracking method
296
  with gr.Tabs() as tabs:
@@ -323,17 +347,49 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
323
  label="Performance Comparison"
324
  )
325
 
326
- # Connect the benchmark button to the function and outputs
327
- benchmark_btn.click(
328
- fn=benchmark_trackers,
329
- inputs=[input_video, yolo_model, reid_model, class_ids, conf_threshold],
330
- outputs=[
331
- bytetrack_video, botsort_video, ocsort_video, strongsort_video,
332
- bytetrack_status, botsort_status, ocsort_status, strongsort_status,
333
- bytetrack_time, botsort_time, ocsort_time, strongsort_time,
334
- perf_table
335
- ],
 
 
 
 
 
 
336
  show_progress="full"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  )
338
 
339
  # Add a debug section
@@ -362,10 +418,23 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
362
  for model in os.listdir(MODELS_DIR) if os.path.exists(MODELS_DIR) else []:
363
  info.append(f" - {model}")
364
 
 
 
 
 
 
 
 
 
 
 
365
  return "\n".join(info)
366
 
367
  check_btn = gr.Button("Check Environment")
368
  check_btn.click(fn=check_environment, outputs=debug_text)
 
 
 
369
 
370
  # Launch the app
371
  if __name__ == "__main__":
 
7
  import sys
8
  import importlib.util
9
  import time
10
+ import gc
11
 
12
  # Ensure models directory exists
13
  MODELS_DIR = Path("models")
 
48
  else:
49
  print("⚠️ tracker_patch.py not found, skipping patches")
50
 
51
+ def clean_memory():
52
+ """Force garbage collection to free memory."""
53
+ gc.collect()
54
+ if hasattr(torch, 'cuda') and torch.cuda.is_available():
55
+ try:
56
+ import torch
57
+ torch.cuda.empty_cache()
58
+ print("🧹 Cleared CUDA memory cache")
59
+ except (ImportError, NameError):
60
+ print("⚠️ Could not clear CUDA memory (torch not available)")
61
+ print("🧹 Memory cleaned")
62
+
63
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, temp_dir):
64
  """Run object tracking with a single tracking method."""
65
  try:
66
  # Prepare input
 
180
  traceback.print_exc()
181
  return None, f"Error in {tracking_method}: {str(e)}", 0
182
 
183
+ def sequential_tracker(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, progress=gr.Progress()):
184
+ """Run a single tracker and return its results."""
185
  if not video_path:
186
+ return None, "Please upload a video file", "", None
 
 
 
 
187
 
188
+ print(f"Processing video: {video_path} with {tracking_method}")
189
  print(f"Parameters: model={yolo_model}, reid={reid_model}, classes={class_ids}, conf={conf_threshold}")
190
 
191
+ progress(0, desc=f"Starting {tracking_method}...")
 
 
 
192
 
193
  try:
194
+ # Create a temporary directory for processing
195
  with tempfile.TemporaryDirectory() as temp_dir:
196
+ progress(0.1, desc=f"Running {tracking_method}...")
197
+
198
+ result, status, process_time = run_tracking(
199
+ video_path,
200
+ yolo_model,
201
+ reid_model,
202
+ tracking_method,
203
+ class_ids,
204
+ conf_threshold,
205
+ temp_dir
206
+ )
207
+
208
+ progress(0.9, desc="Finalizing results...")
209
+
210
+ # Clean up memory after each tracker run
211
+ clean_memory()
212
+
213
+ # Create the DataFrame data for this single tracker
214
+ comparison_data = [[tracking_method, f"{process_time:.2f} seconds", "Success" if result else "Failed"]]
215
+
216
+ # Return the results
217
+ return result, status, f"{process_time:.2f} seconds", comparison_data
218
 
219
  except Exception as e:
220
  import traceback
221
  traceback.print_exc()
222
+ error_msg = f"Process error for {tracking_method}: {str(e)}"
223
+
224
+ # Clean memory even if there was an error
225
+ clean_memory()
226
+
227
+ return None, error_msg, "", [[tracking_method, "Error", "Failed"]]
228
+
229
+ def update_comparison_table(current_data, new_data):
230
+ """Update the comparison table with results from a new tracker."""
231
+ if current_data is None:
232
+ return new_data
233
 
234
+ # Append the new tracker data to the existing table
235
+ return current_data + new_data
 
 
 
236
 
237
  # Available models
238
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
239
  reid_models = ["osnet_x0_25_msmt17.pt"]
240
+ tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
241
 
242
  # Ensure dependencies and apply patches at startup
243
  ensure_dependencies()
 
246
  # Create the Gradio interface
247
  with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
248
  gr.Markdown("# 🔍 YOLO Object Tracking Benchmark")
249
+ gr.Markdown("Upload a video file and run each tracking method sequentially. Results will be displayed as they become available.")
250
 
251
  # Add class reference information
252
  with gr.Accordion("YOLO Class Reference", open=False):
 
270
  [See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/datasets/coco128.yaml)
271
  """)
272
 
273
+ # State variables to keep track of comparison data
274
+ comparison_data_state = gr.State([])
275
+
276
  with gr.Row():
277
  with gr.Column(scale=1):
278
  input_video = gr.Video(label="Input Video", sources=["upload"])
 
303
  step=0.05,
304
  label="Confidence Threshold"
305
  )
306
+
307
+ # Individual tracker buttons
308
+ with gr.Group():
309
+ gr.Markdown("### Run Trackers One by One")
310
+ with gr.Row():
311
+ bytetrack_btn = gr.Button("Run ByteTrack", variant="primary")
312
+ with gr.Row():
313
+ botsort_btn = gr.Button("Run BoTSORT", variant="primary")
314
+ with gr.Row():
315
+ ocsort_btn = gr.Button("Run OC-SORT", variant="primary")
316
+ with gr.Row():
317
+ strongsort_btn = gr.Button("Run StrongSORT", variant="primary")
318
 
319
  # Output Tabs for each tracking method
320
  with gr.Tabs() as tabs:
 
347
  label="Performance Comparison"
348
  )
349
 
350
+ # Connect individual tracker buttons to their respective functions
351
+ bytetrack_btn.click(
352
+ fn=sequential_tracker,
353
+ inputs=[input_video, yolo_model, reid_model, gr.State("bytetrack"), class_ids, conf_threshold],
354
+ outputs=[bytetrack_video, bytetrack_status, bytetrack_time, comparison_data_state],
355
+ show_progress="full"
356
+ ).then(
357
+ fn=update_comparison_table,
358
+ inputs=[perf_table, comparison_data_state],
359
+ outputs=perf_table
360
+ )
361
+
362
+ botsort_btn.click(
363
+ fn=sequential_tracker,
364
+ inputs=[input_video, yolo_model, reid_model, gr.State("botsort"), class_ids, conf_threshold],
365
+ outputs=[botsort_video, botsort_status, botsort_time, comparison_data_state],
366
  show_progress="full"
367
+ ).then(
368
+ fn=update_comparison_table,
369
+ inputs=[perf_table, comparison_data_state],
370
+ outputs=perf_table
371
+ )
372
+
373
+ ocsort_btn.click(
374
+ fn=sequential_tracker,
375
+ inputs=[input_video, yolo_model, reid_model, gr.State("ocsort"), class_ids, conf_threshold],
376
+ outputs=[ocsort_video, ocsort_status, ocsort_time, comparison_data_state],
377
+ show_progress="full"
378
+ ).then(
379
+ fn=update_comparison_table,
380
+ inputs=[perf_table, comparison_data_state],
381
+ outputs=perf_table
382
+ )
383
+
384
+ strongsort_btn.click(
385
+ fn=sequential_tracker,
386
+ inputs=[input_video, yolo_model, reid_model, gr.State("strongsort"), class_ids, conf_threshold],
387
+ outputs=[strongsort_video, strongsort_status, strongsort_time, comparison_data_state],
388
+ show_progress="full"
389
+ ).then(
390
+ fn=update_comparison_table,
391
+ inputs=[perf_table, comparison_data_state],
392
+ outputs=perf_table
393
  )
394
 
395
  # Add a debug section
 
418
  for model in os.listdir(MODELS_DIR) if os.path.exists(MODELS_DIR) else []:
419
  info.append(f" - {model}")
420
 
421
+ # Check for GPU
422
+ try:
423
+ import torch
424
+ info.append(f"CUDA available: {torch.cuda.is_available()}")
425
+ if torch.cuda.is_available():
426
+ info.append(f"CUDA device: {torch.cuda.get_device_name(0)}")
427
+ info.append(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
428
+ except ImportError:
429
+ info.append("PyTorch: Not installed")
430
+
431
  return "\n".join(info)
432
 
433
  check_btn = gr.Button("Check Environment")
434
  check_btn.click(fn=check_environment, outputs=debug_text)
435
+
436
+ clean_mem_btn = gr.Button("Force Memory Cleanup")
437
+ clean_mem_btn.click(fn=clean_memory)
438
 
439
  # Launch the app
440
  if __name__ == "__main__":