usiddiquee786 commited on
Commit
77103ef
·
verified ·
1 Parent(s): 7ec8a6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -275
app.py CHANGED
@@ -6,8 +6,6 @@ import shutil
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
9
- import time
10
- import gc
11
 
12
  # Ensure models directory exists
13
  MODELS_DIR = Path("models")
@@ -48,193 +46,150 @@ def apply_patches():
48
  else:
49
  print("⚠️ tracker_patch.py not found, skipping patches")
50
 
51
- def clean_memory():
52
- """Force garbage collection to free memory."""
53
- gc.collect()
54
- if hasattr(torch, 'cuda') and torch.cuda.is_available():
55
- try:
56
- import torch
57
- torch.cuda.empty_cache()
58
- print("🧹 Cleared CUDA memory cache")
59
- except (ImportError, NameError):
60
- print("⚠️ Could not clear CUDA memory (torch not available)")
61
- print("🧹 Memory cleaned")
62
-
63
- def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, temp_dir):
64
- """Run object tracking with a single tracking method."""
65
  try:
66
- # Prepare input
67
- input_path = os.path.join(temp_dir, "input_video.mp4")
68
- if not os.path.exists(input_path):
 
69
  shutil.copy(video_file, input_path)
70
-
71
- # Prepare output directory
72
- output_dir = os.path.join(temp_dir, tracking_method)
73
- os.makedirs(output_dir, exist_ok=True)
74
-
75
- start_time = time.time()
76
-
77
- # Build command
78
- cmd = [
79
- "python", "tracking/track.py",
80
- "--yolo-model", str(MODELS_DIR / yolo_model),
81
- "--reid-model", str(MODELS_DIR / reid_model),
82
- "--tracking-method", tracking_method,
83
- "--source", input_path,
84
- "--conf", str(conf_threshold),
85
- "--save",
86
- "--project", output_dir,
87
- "--name", "track",
88
- "--exist-ok"
89
- ]
90
-
91
- # Add class filtering if specific classes are provided
92
- if class_ids and class_ids.strip():
93
- # Parse the comma-separated class IDs
94
- try:
95
- # Split by comma and convert to integers to validate
96
- class_list = [int(c.strip()) for c in class_ids.split(",") if c.strip()]
97
- # Add each class ID as a separate argument
98
- if class_list:
99
- cmd.append("--classes")
100
- cmd.extend(str(c) for c in class_list)
101
- except ValueError:
102
- return None, f"Invalid class IDs for {tracking_method}. Please enter comma-separated numbers.", 0
103
-
104
- # Special handling for OcSort
105
- if tracking_method == "ocsort":
106
- cmd.append("--per-class")
107
-
108
- # Execute tracking with error handling
109
- print(f"Executing command for {tracking_method}: {' '.join(cmd)}")
110
- process = subprocess.run(
111
- cmd,
112
- capture_output=True,
113
- text=True
114
- )
115
-
116
- end_time = time.time()
117
- processing_time = end_time - start_time
118
-
119
- # Check for errors in output
120
- if process.returncode != 0:
121
- error_message = process.stderr or process.stdout
122
- print(f"Process for {tracking_method} failed with return code {process.returncode}")
123
- print(f"Error: {error_message}")
124
- return None, f"Error in {tracking_method}: {error_message}", processing_time
125
-
126
- print(f"Process for {tracking_method} completed with return code {process.returncode}")
127
-
128
- # Find output video
129
- output_files = []
130
- for root, _, files in os.walk(output_dir):
131
- for file in files:
132
- if file.lower().endswith((".mp4", ".avi", ".mov")):
133
- output_files.append(os.path.join(root, file))
134
-
135
- print(f"Found output files for {tracking_method}: {output_files}")
136
-
137
- if not output_files:
138
- print(f"No output video files found for {tracking_method}")
139
- return None, f"No output video was generated for {tracking_method}.", processing_time
140
-
141
- output_file = output_files[0]
142
- print(f"Selected output file for {tracking_method}: {output_file}")
143
-
144
- # Verify file exists and has size
145
- if os.path.exists(output_file):
146
- file_size = os.path.getsize(output_file)
147
- print(f"Output file for {tracking_method} exists with size: {file_size} bytes")
148
 
149
- if file_size == 0:
150
- return None, f"Output video for {tracking_method} was generated but has zero size.", processing_time
 
151
 
152
- # Copy to permanent location with unique name
153
- permanent_path = os.path.join(OUTPUT_DIR, f"{tracking_method}_{os.path.basename(video_file)}")
154
- shutil.copy(output_file, permanent_path)
155
- print(f"Copied {tracking_method} output to permanent location: {permanent_path}")
 
 
 
 
 
 
 
 
 
156
 
157
- # Ensure the file is in MP4 format for better compatibility with Gradio
158
- if not permanent_path.lower().endswith('.mp4'):
159
- mp4_path = os.path.splitext(permanent_path)[0] + '.mp4'
160
  try:
161
- print(f"Converting to MP4 format: {mp4_path}")
162
- subprocess.run([
163
- 'ffmpeg', '-i', permanent_path,
164
- '-c:v', 'libx264', '-preset', 'fast',
165
- '-c:a', 'aac', mp4_path
166
- ], check=True, capture_output=True)
167
- os.remove(permanent_path) # Remove the original file
168
- permanent_path = mp4_path
169
- except Exception as e:
170
- print(f"Failed to convert to MP4: {str(e)}")
171
- # Continue with original file if conversion fails
172
 
173
- return permanent_path, f"{tracking_method} completed successfully in {processing_time:.2f} seconds!", processing_time
174
- else:
175
- print(f"Output file for {tracking_method} not found at {output_file}")
176
- return None, f"Output file for {tracking_method} was referenced but doesn't exist on disk.", processing_time
177
-
178
- except Exception as e:
179
- import traceback
180
- traceback.print_exc()
181
- return None, f"Error in {tracking_method}: {str(e)}", 0
182
-
183
- def sequential_tracker(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, progress=gr.Progress()):
184
- """Run a single tracker and return its results."""
185
- if not video_path:
186
- return None, "Please upload a video file", "", None
187
-
188
- print(f"Processing video: {video_path} with {tracking_method}")
189
- print(f"Parameters: model={yolo_model}, reid={reid_model}, classes={class_ids}, conf={conf_threshold}")
190
-
191
- progress(0, desc=f"Starting {tracking_method}...")
192
-
193
- try:
194
- # Create a temporary directory for processing
195
- with tempfile.TemporaryDirectory() as temp_dir:
196
- progress(0.1, desc=f"Running {tracking_method}...")
197
 
198
- result, status, process_time = run_tracking(
199
- video_path,
200
- yolo_model,
201
- reid_model,
202
- tracking_method,
203
- class_ids,
204
- conf_threshold,
205
- temp_dir
206
  )
207
 
208
- progress(0.9, desc="Finalizing results...")
 
 
 
 
 
209
 
210
- # Clean up memory after each tracker run
211
- clean_memory()
212
 
213
- # Create the DataFrame data for this single tracker
214
- comparison_data = [[tracking_method, f"{process_time:.2f} seconds", "Success" if result else "Failed"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- # Return the results
217
- return result, status, f"{process_time:.2f} seconds", comparison_data
218
-
219
  except Exception as e:
220
  import traceback
221
  traceback.print_exc()
222
- error_msg = f"Process error for {tracking_method}: {str(e)}"
223
-
224
- # Clean memory even if there was an error
225
- clean_memory()
226
-
227
- return None, error_msg, "", [[tracking_method, "Error", "Failed"]]
228
 
229
- def update_comparison_table(current_data, new_data):
230
- """Update the comparison table with results from a new tracker."""
231
- if current_data is None:
232
- return new_data
 
 
 
 
233
 
234
- # Append the new tracker data to the existing table
235
- return current_data + new_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
- # Available models
238
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
239
  reid_models = ["osnet_x0_25_msmt17.pt"]
240
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
@@ -244,9 +199,9 @@ ensure_dependencies()
244
  apply_patches()
245
 
246
  # Create the Gradio interface
247
- with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
248
- gr.Markdown("# 🔍 YOLO Object Tracking Benchmark")
249
- gr.Markdown("Upload a video file and run each tracking method sequentially. Results will be displayed as they become available.")
250
 
251
  # Add class reference information
252
  with gr.Accordion("YOLO Class Reference", open=False):
@@ -270,9 +225,6 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
270
  [See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/datasets/coco128.yaml)
271
  """)
272
 
273
- # State variables to keep track of comparison data
274
- comparison_data_state = gr.State([])
275
-
276
  with gr.Row():
277
  with gr.Column(scale=1):
278
  input_video = gr.Video(label="Input Video", sources=["upload"])
@@ -288,6 +240,11 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
288
  value="osnet_x0_25_msmt17.pt",
289
  label="ReID Model"
290
  )
 
 
 
 
 
291
 
292
  # Class ID input field
293
  class_ids = gr.Textbox(
@@ -303,93 +260,17 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
303
  step=0.05,
304
  label="Confidence Threshold"
305
  )
306
-
307
- # Individual tracker buttons
308
- with gr.Group():
309
- gr.Markdown("### Run Trackers One by One")
310
- with gr.Row():
311
- bytetrack_btn = gr.Button("Run ByteTrack", variant="primary")
312
- with gr.Row():
313
- botsort_btn = gr.Button("Run BoTSORT", variant="primary")
314
- with gr.Row():
315
- ocsort_btn = gr.Button("Run OC-SORT", variant="primary")
316
- with gr.Row():
317
- strongsort_btn = gr.Button("Run StrongSORT", variant="primary")
318
-
319
- # Output Tabs for each tracking method
320
- with gr.Tabs() as tabs:
321
- with gr.TabItem("ByteTrack"):
322
- bytetrack_video = gr.Video(label="ByteTrack Result")
323
- bytetrack_status = gr.Textbox(label="Status", value="Ready to process")
324
- bytetrack_time = gr.Textbox(label="Processing Time", value="")
325
-
326
- with gr.TabItem("BoTSORT"):
327
- botsort_video = gr.Video(label="BoTSORT Result")
328
- botsort_status = gr.Textbox(label="Status", value="Ready to process")
329
- botsort_time = gr.Textbox(label="Processing Time", value="")
330
-
331
- with gr.TabItem("OC-SORT"):
332
- ocsort_video = gr.Video(label="OC-SORT Result")
333
- ocsort_status = gr.Textbox(label="Status", value="Ready to process")
334
- ocsort_time = gr.Textbox(label="Processing Time", value="")
335
-
336
- with gr.TabItem("StrongSORT"):
337
- strongsort_video = gr.Video(label="StrongSORT Result")
338
- strongsort_status = gr.Textbox(label="Status", value="Ready to process")
339
- strongsort_time = gr.Textbox(label="Processing Time", value="")
340
-
341
- # Comparison Tab
342
- with gr.Tabs() as comparison_tab:
343
- with gr.TabItem("Performance Comparison"):
344
- perf_table = gr.DataFrame(
345
- headers=["Tracker", "Processing Time", "Status"],
346
- datatype=["str", "str", "str"],
347
- label="Performance Comparison"
348
- )
349
-
350
- # Connect individual tracker buttons to their respective functions
351
- bytetrack_btn.click(
352
- fn=sequential_tracker,
353
- inputs=[input_video, yolo_model, reid_model, gr.State("bytetrack"), class_ids, conf_threshold],
354
- outputs=[bytetrack_video, bytetrack_status, bytetrack_time, comparison_data_state],
355
- show_progress="full"
356
- ).then(
357
- fn=update_comparison_table,
358
- inputs=[perf_table, comparison_data_state],
359
- outputs=perf_table
360
- )
361
-
362
- botsort_btn.click(
363
- fn=sequential_tracker,
364
- inputs=[input_video, yolo_model, reid_model, gr.State("botsort"), class_ids, conf_threshold],
365
- outputs=[botsort_video, botsort_status, botsort_time, comparison_data_state],
366
- show_progress="full"
367
- ).then(
368
- fn=update_comparison_table,
369
- inputs=[perf_table, comparison_data_state],
370
- outputs=perf_table
371
- )
372
-
373
- ocsort_btn.click(
374
- fn=sequential_tracker,
375
- inputs=[input_video, yolo_model, reid_model, gr.State("ocsort"), class_ids, conf_threshold],
376
- outputs=[ocsort_video, ocsort_status, ocsort_time, comparison_data_state],
377
- show_progress="full"
378
- ).then(
379
- fn=update_comparison_table,
380
- inputs=[perf_table, comparison_data_state],
381
- outputs=perf_table
382
- )
383
 
384
- strongsort_btn.click(
385
- fn=sequential_tracker,
386
- inputs=[input_video, yolo_model, reid_model, gr.State("strongsort"), class_ids, conf_threshold],
387
- outputs=[strongsort_video, strongsort_status, strongsort_time, comparison_data_state],
388
- show_progress="full"
389
- ).then(
390
- fn=update_comparison_table,
391
- inputs=[perf_table, comparison_data_state],
392
- outputs=perf_table
393
  )
394
 
395
  # Add a debug section
@@ -418,23 +299,10 @@ with gr.Blocks(title="YOLO Object Tracking Benchmark") as app:
418
  for model in os.listdir(MODELS_DIR) if os.path.exists(MODELS_DIR) else []:
419
  info.append(f" - {model}")
420
 
421
- # Check for GPU
422
- try:
423
- import torch
424
- info.append(f"CUDA available: {torch.cuda.is_available()}")
425
- if torch.cuda.is_available():
426
- info.append(f"CUDA device: {torch.cuda.get_device_name(0)}")
427
- info.append(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
428
- except ImportError:
429
- info.append("PyTorch: Not installed")
430
-
431
  return "\n".join(info)
432
 
433
  check_btn = gr.Button("Check Environment")
434
  check_btn.click(fn=check_environment, outputs=debug_text)
435
-
436
- clean_mem_btn = gr.Button("Force Memory Cleanup")
437
- clean_mem_btn.click(fn=clean_memory)
438
 
439
  # Launch the app
440
  if __name__ == "__main__":
 
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
 
 
9
 
10
  # Ensure models directory exists
11
  MODELS_DIR = Path("models")
 
46
  else:
47
  print("⚠️ tracker_patch.py not found, skipping patches")
48
 
49
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
50
+ """Run object tracking on the uploaded video."""
 
 
 
 
 
 
 
 
 
 
 
 
51
  try:
52
+ # Create temporary workspace
53
+ with tempfile.TemporaryDirectory() as temp_dir:
54
+ # Prepare input
55
+ input_path = os.path.join(temp_dir, "input_video.mp4")
56
  shutil.copy(video_file, input_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ # Prepare output directory
59
+ output_dir = os.path.join(temp_dir, "output")
60
+ os.makedirs(output_dir, exist_ok=True)
61
 
62
+ # Build command
63
+ cmd = [
64
+ "python", "tracking/track.py",
65
+ "--yolo-model", str(MODELS_DIR / yolo_model),
66
+ "--reid-model", str(MODELS_DIR / reid_model),
67
+ "--tracking-method", tracking_method,
68
+ "--source", input_path,
69
+ "--conf", str(conf_threshold),
70
+ "--save",
71
+ "--project", output_dir,
72
+ "--name", "track",
73
+ "--exist-ok"
74
+ ]
75
 
76
+ # Add class filtering if specific classes are provided
77
+ if class_ids and class_ids.strip():
78
+ # Parse the comma-separated class IDs
79
  try:
80
+ # Split by comma and convert to integers to validate
81
+ class_list = [int(c.strip()) for c in class_ids.split(",") if c.strip()]
82
+ # Add each class ID as a separate argument
83
+ if class_list:
84
+ cmd.append("--classes")
85
+ cmd.extend(str(c) for c in class_list)
86
+ except ValueError:
87
+ return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."
 
 
 
88
 
89
+ # Special handling for OcSort
90
+ if tracking_method == "ocsort":
91
+ cmd.append("--per-class")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ # Execute tracking with error handling
94
+ print(f"Executing command: {' '.join(cmd)}")
95
+ process = subprocess.run(
96
+ cmd,
97
+ capture_output=True,
98
+ text=True
 
 
99
  )
100
 
101
+ # Check for errors in output
102
+ if process.returncode != 0:
103
+ error_message = process.stderr or process.stdout
104
+ print(f"Process failed with return code {process.returncode}")
105
+ print(f"Error: {error_message}")
106
+ return None, f"Error in tracking process: {error_message}"
107
 
108
+ print(f"Process completed with return code {process.returncode}")
 
109
 
110
+ # Find output video
111
+ output_files = []
112
+ for root, _, files in os.walk(output_dir):
113
+ for file in files:
114
+ if file.lower().endswith((".mp4", ".avi", ".mov")):
115
+ output_files.append(os.path.join(root, file))
116
+
117
+ print(f"Found output files: {output_files}")
118
+
119
+ if not output_files:
120
+ print("No output video files found")
121
+ return None, "No output video was generated. Check if tracking was successful."
122
+
123
+ output_file = output_files[0]
124
+ print(f"Selected output file: {output_file}")
125
+
126
+ # Verify file exists and has size
127
+ if os.path.exists(output_file):
128
+ file_size = os.path.getsize(output_file)
129
+ print(f"Output file exists with size: {file_size} bytes")
130
+
131
+ if file_size == 0:
132
+ return None, "Output video was generated but has zero size."
133
+
134
+ # Copy to permanent location with unique name
135
+ permanent_path = os.path.join(OUTPUT_DIR, f"output_{os.path.basename(video_file)}")
136
+ shutil.copy(output_file, permanent_path)
137
+ print(f"Copied output to permanent location: {permanent_path}")
138
+
139
+ # Ensure the file is in MP4 format for better compatibility with Gradio
140
+ if not permanent_path.lower().endswith('.mp4'):
141
+ mp4_path = os.path.splitext(permanent_path)[0] + '.mp4'
142
+ try:
143
+ print(f"Converting to MP4 format: {mp4_path}")
144
+ subprocess.run([
145
+ 'ffmpeg', '-i', permanent_path,
146
+ '-c:v', 'libx264', '-preset', 'fast',
147
+ '-c:a', 'aac', mp4_path
148
+ ], check=True, capture_output=True)
149
+ os.remove(permanent_path) # Remove the original file
150
+ permanent_path = mp4_path
151
+ except Exception as e:
152
+ print(f"Failed to convert to MP4: {str(e)}")
153
+ # Continue with original file if conversion fails
154
+
155
+ return permanent_path, "Processing completed successfully!"
156
+ else:
157
+ print(f"Output file not found at {output_file}")
158
+ return None, "Output file was referenced but doesn't exist on disk."
159
 
 
 
 
160
  except Exception as e:
161
  import traceback
162
  traceback.print_exc()
163
+ return None, f"Error: {str(e)}"
 
 
 
 
 
164
 
165
+ # Define the Gradio interface
166
+ def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
167
+ # Validate inputs
168
+ if not video_path:
169
+ return None, "Please upload a video file"
170
+
171
+ print(f"Processing video: {video_path}")
172
+ print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")
173
 
174
+ output_path, status = run_tracking(
175
+ video_path,
176
+ yolo_model,
177
+ reid_model,
178
+ tracking_method,
179
+ class_ids,
180
+ conf_threshold
181
+ )
182
+
183
+ if output_path:
184
+ print(f"Returning output path: {output_path}")
185
+ # Make sure the path is absolute for Gradio
186
+ abs_path = os.path.abspath(output_path)
187
+ return abs_path, status
188
+ else:
189
+ print(f"No output path available. Status: {status}")
190
+ return None, status
191
 
192
+ # Available models and tracking methods
193
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
194
  reid_models = ["osnet_x0_25_msmt17.pt"]
195
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
 
199
  apply_patches()
200
 
201
  # Create the Gradio interface
202
+ with gr.Blocks(title="YOLO Object Tracking") as app:
203
+ gr.Markdown("# 🚀 YOLO Object Tracking")
204
+ gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
205
 
206
  # Add class reference information
207
  with gr.Accordion("YOLO Class Reference", open=False):
 
225
  [See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/datasets/coco128.yaml)
226
  """)
227
 
 
 
 
228
  with gr.Row():
229
  with gr.Column(scale=1):
230
  input_video = gr.Video(label="Input Video", sources=["upload"])
 
240
  value="osnet_x0_25_msmt17.pt",
241
  label="ReID Model"
242
  )
243
+ tracking_method = gr.Dropdown(
244
+ choices=tracking_methods,
245
+ value="bytetrack",
246
+ label="Tracking Method"
247
+ )
248
 
249
  # Class ID input field
250
  class_ids = gr.Textbox(
 
260
  step=0.05,
261
  label="Confidence Threshold"
262
  )
263
+
264
+ process_btn = gr.Button("Process Video", variant="primary")
265
+
266
+ with gr.Column(scale=1):
267
+ output_video = gr.Video(label="Output Video with Tracking")
268
+ status_text = gr.Textbox(label="Status", value="Ready to process video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
+ process_btn.click(
271
+ fn=process_video,
272
+ inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold],
273
+ outputs=[output_video, status_text]
 
 
 
 
 
274
  )
275
 
276
  # Add a debug section
 
299
  for model in os.listdir(MODELS_DIR) if os.path.exists(MODELS_DIR) else []:
300
  info.append(f" - {model}")
301
 
 
 
 
 
 
 
 
 
 
 
302
  return "\n".join(info)
303
 
304
  check_btn = gr.Button("Check Environment")
305
  check_btn.click(fn=check_environment, outputs=debug_text)
 
 
 
306
 
307
  # Launch the app
308
  if __name__ == "__main__":