usiddiquee commited on
Commit
b72d0b5
·
1 Parent(s): e1832f4
Files changed (1) hide show
  1. app.py +117 -76
app.py CHANGED
@@ -6,6 +6,7 @@ import shutil
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
 
9
 
10
  # Ensure models directory exists
11
  MODELS_DIR = Path("models")
@@ -43,96 +44,125 @@ def apply_patches():
43
  else:
44
  print("⚠️ tracker_patch.py not found, skipping patches")
45
 
46
- def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_threshold):
47
- """Run object tracking on the uploaded video."""
48
  try:
49
- # Create temporary workspace
50
- with tempfile.TemporaryDirectory() as temp_dir:
51
- # Prepare input
52
- input_path = os.path.join(temp_dir, "input_video.mp4")
53
  shutil.copy(video_file, input_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # Prepare output directory
56
- output_dir = os.path.join(temp_dir, "output")
57
- os.makedirs(output_dir, exist_ok=True)
58
-
59
- # Build command
60
- cmd = [
61
- "python", "tracking/track.py",
62
- "--yolo-model", str(MODELS_DIR / yolo_model),
63
- "--reid-model", str(MODELS_DIR / reid_model),
64
- "--tracking-method", tracking_method,
65
- "--source", input_path,
66
- "--conf", str(conf_threshold),
67
- "--save",
68
- "--project", output_dir,
69
- "--name", "track",
70
- "--exist-ok"
71
- ]
72
-
73
- # Special handling for OcSort
74
- if tracking_method == "ocsort":
75
- cmd.append("--per-class")
76
-
77
- # Execute tracking with error handling
78
- process = subprocess.run(
79
- cmd,
80
- capture_output=True,
81
- text=True
82
- )
83
-
84
- # Check for errors in output
85
- if process.returncode != 0:
86
- error_message = process.stderr or process.stdout
87
- return None, f"Error in tracking process: {error_message}"
88
-
89
- # Find output video
90
- output_files = []
91
- for root, _, files in os.walk(output_dir):
92
- for file in files:
93
- if file.lower().endswith((".mp4", ".avi", ".mov")):
94
- output_files.append(os.path.join(root, file))
95
-
96
- if not output_files:
97
- return None, "No output video was generated. Check if tracking was successful."
98
-
99
- return output_files[0], "Processing completed successfully!"
100
-
101
  except Exception as e:
102
- return None, f"Error: {str(e)}"
103
 
104
- # Define the Gradio interface
105
- def process_video(video_path, yolo_model, reid_model, tracking_method, conf_threshold):
 
 
 
 
106
  # Validate inputs
107
  if not video_path:
108
- return None, "Please upload a video file"
 
 
 
109
 
110
- output_path, status = run_tracking(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  video_path,
112
  yolo_model,
113
  reid_model,
114
- tracking_method,
115
  conf_threshold
116
  )
117
 
118
- return output_path, status
119
 
120
- # Available models and tracking methods
121
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
122
  reid_models = ["osnet_x0_25_msmt17.pt"]
123
- tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
124
 
125
  # Ensure dependencies and apply patches at startup
126
  ensure_dependencies()
127
  apply_patches()
128
 
129
  # Create the Gradio interface
130
- with gr.Blocks(title="YOLO Object Tracking") as app:
131
- gr.Markdown("# 🚀 YOLO Object Tracking")
132
- gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
133
 
134
  with gr.Row():
135
- with gr.Column():
136
  input_video = gr.Video(label="Input Video", sources=["upload"])
137
 
138
  with gr.Group():
@@ -146,11 +176,6 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
146
  value="osnet_x0_25_msmt17.pt",
147
  label="ReID Model"
148
  )
149
- tracking_method = gr.Dropdown(
150
- choices=tracking_methods,
151
- value="bytetrack",
152
- label="Tracking Method"
153
- )
154
  conf_threshold = gr.Slider(
155
  minimum=0.1,
156
  maximum=0.9,
@@ -159,16 +184,32 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
159
  label="Confidence Threshold"
160
  )
161
 
162
- process_btn = gr.Button("Process Video", variant="primary")
163
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  with gr.Column():
165
- output_video = gr.Video(label="Output Video with Tracking", autoplay=True)
166
- status_text = gr.Textbox(label="Status", value="Ready to process video")
167
 
168
  process_btn.click(
169
  fn=process_video,
170
- inputs=[input_video, yolo_model, reid_model, tracking_method, conf_threshold],
171
- outputs=[output_video, status_text]
172
  )
173
 
174
  # Launch the app
 
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
9
+ from concurrent.futures import ThreadPoolExecutor
10
 
11
  # Ensure models directory exists
12
  MODELS_DIR = Path("models")
 
44
  else:
45
  print("⚠️ tracker_patch.py not found, skipping patches")
46
 
47
+ def run_single_tracker(video_file, yolo_model, reid_model, tracking_method, conf_threshold, temp_dir):
48
+ """Run object tracking with a single tracking method."""
49
  try:
50
+ # Prepare input
51
+ input_path = os.path.join(temp_dir, "input_video.mp4")
52
+ if not os.path.exists(input_path):
 
53
  shutil.copy(video_file, input_path)
54
+
55
+ # Prepare output directory for this tracker
56
+ output_dir = os.path.join(temp_dir, f"output_{tracking_method}")
57
+ os.makedirs(output_dir, exist_ok=True)
58
+
59
+ # Build command
60
+ cmd = [
61
+ "python", "tracking/track.py",
62
+ "--yolo-model", str(MODELS_DIR / yolo_model),
63
+ "--reid-model", str(MODELS_DIR / reid_model),
64
+ "--tracking-method", tracking_method,
65
+ "--source", input_path,
66
+ "--conf", str(conf_threshold),
67
+ "--save",
68
+ "--project", output_dir,
69
+ "--name", tracking_method,
70
+ "--exist-ok"
71
+ ]
72
+
73
+ # Special handling for OcSort
74
+ if tracking_method == "ocsort":
75
+ cmd.append("--per-class")
76
+
77
+ # Execute tracking
78
+ process = subprocess.run(
79
+ cmd,
80
+ capture_output=True,
81
+ text=True
82
+ )
83
+
84
+ # Check for errors
85
+ if process.returncode != 0:
86
+ error_message = process.stderr or process.stdout
87
+ return None, f"Error in {tracking_method}: {error_message}"
88
+
89
+ # Find output video
90
+ output_files = []
91
+ for root, _, files in os.walk(output_dir):
92
+ for file in files:
93
+ if file.lower().endswith((".mp4", ".avi", ".mov")):
94
+ output_files.append(os.path.join(root, file))
95
+
96
+ if not output_files:
97
+ return None, f"No output video generated for {tracking_method}"
98
 
99
+ return output_files[0], f"{tracking_method} completed successfully!"
100
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  except Exception as e:
102
+ return None, f"Error in {tracking_method}: {str(e)}"
103
 
104
+ def run_all_trackers(video_path, yolo_model, reid_model, conf_threshold):
105
+ """Run all tracking methods and return results."""
106
+ tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
107
+ results = {}
108
+ status_messages = []
109
+
110
  # Validate inputs
111
  if not video_path:
112
+ return None, None, None, None, "Please upload a video file"
113
+
114
+ # Create temporary workspace
115
+ temp_dir = tempfile.mkdtemp()
116
 
117
+ try:
118
+ for method in tracking_methods:
119
+ # Run tracking for this method
120
+ output_path, status = run_single_tracker(
121
+ video_path,
122
+ yolo_model,
123
+ reid_model,
124
+ method,
125
+ conf_threshold,
126
+ temp_dir
127
+ )
128
+
129
+ results[method] = output_path
130
+ status_messages.append(status)
131
+
132
+ combined_status = "\n".join(status_messages)
133
+ return results.get("bytetrack"), results.get("botsort"), results.get("ocsort"), results.get("strongsort"), combined_status
134
+
135
+ except Exception as e:
136
+ shutil.rmtree(temp_dir, ignore_errors=True)
137
+ return None, None, None, None, f"Error: {str(e)}"
138
+
139
+ # Define the Gradio interface
140
+ def process_video(video_path, yolo_model, reid_model, conf_threshold):
141
+ # Run all tracking methods
142
+ bytetrack_output, botsort_output, ocsort_output, strongsort_output, status = run_all_trackers(
143
  video_path,
144
  yolo_model,
145
  reid_model,
 
146
  conf_threshold
147
  )
148
 
149
+ return bytetrack_output, botsort_output, ocsort_output, strongsort_output, status
150
 
151
+ # Available models
152
  yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
153
  reid_models = ["osnet_x0_25_msmt17.pt"]
 
154
 
155
  # Ensure dependencies and apply patches at startup
156
  ensure_dependencies()
157
  apply_patches()
158
 
159
  # Create the Gradio interface
160
+ with gr.Blocks(title="YOLO Multi-Tracker Comparison") as app:
161
+ gr.Markdown("# 🚀 YOLO Multi-Tracker Comparison")
162
+ gr.Markdown("Upload a video file to see tracking results from all four trackers side by side.")
163
 
164
  with gr.Row():
165
+ with gr.Column(scale=1):
166
  input_video = gr.Video(label="Input Video", sources=["upload"])
167
 
168
  with gr.Group():
 
176
  value="osnet_x0_25_msmt17.pt",
177
  label="ReID Model"
178
  )
 
 
 
 
 
179
  conf_threshold = gr.Slider(
180
  minimum=0.1,
181
  maximum=0.9,
 
184
  label="Confidence Threshold"
185
  )
186
 
187
+ process_btn = gr.Button("Process Video with All Trackers", variant="primary")
188
+ status_text = gr.Textbox(label="Status", value="Ready to process video", lines=5)
189
+
190
+ gr.Markdown("## Tracking Results")
191
+ with gr.Row():
192
+ with gr.Column():
193
+ gr.Markdown("### ByteTrack")
194
+ bytetrack_video = gr.Video(label="ByteTrack Result", autoplay=False)
195
+
196
+ with gr.Column():
197
+ gr.Markdown("### BOT-SORT")
198
+ botsort_video = gr.Video(label="BOT-SORT Result", autoplay=False)
199
+
200
+ with gr.Row():
201
+ with gr.Column():
202
+ gr.Markdown("### OC-SORT")
203
+ ocsort_video = gr.Video(label="OC-SORT Result", autoplay=False)
204
+
205
  with gr.Column():
206
+ gr.Markdown("### StrongSORT")
207
+ strongsort_video = gr.Video(label="StrongSORT Result", autoplay=False)
208
 
209
  process_btn.click(
210
  fn=process_video,
211
+ inputs=[input_video, yolo_model, reid_model, conf_threshold],
212
+ outputs=[bytetrack_video, botsort_video, ocsort_video, strongsort_video, status_text]
213
  )
214
 
215
  # Launch the app