usiddiquee commited on
Commit
f55cfe0
·
1 Parent(s): 1d1c8d9
Files changed (1) hide show
  1. app.py +60 -6
app.py CHANGED
@@ -47,7 +47,7 @@ def apply_patches():
47
  else:
48
  print("⚠️ tracker_patch.py not found, skipping patches")
49
 
50
- def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_threshold):
51
  """Run object tracking on the uploaded video."""
52
  try:
53
  # Create temporary workspace
@@ -74,6 +74,12 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_thres
74
  "--exist-ok"
75
  ]
76
 
 
 
 
 
 
 
77
  # Special handling for OcSort
78
  if tracking_method == "ocsort":
79
  cmd.append("--per-class")
@@ -151,19 +157,20 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_thres
151
  return None, f"Error: {str(e)}"
152
 
153
  # Define the Gradio interface
154
- def process_video(video_path, yolo_model, reid_model, tracking_method, conf_threshold):
155
  # Validate inputs
156
  if not video_path:
157
  return None, "Please upload a video file"
158
 
159
  print(f"Processing video: {video_path}")
160
- print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, conf={conf_threshold}")
161
 
162
  output_path, status = run_tracking(
163
  video_path,
164
  yolo_model,
165
  reid_model,
166
  tracking_method,
 
167
  conf_threshold
168
  )
169
 
@@ -181,6 +188,28 @@ yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
181
  reid_models = ["osnet_x0_25_msmt17.pt"]
182
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  # Ensure dependencies and apply patches at startup
185
  ensure_dependencies()
186
  apply_patches()
@@ -191,7 +220,7 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
191
  gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
192
 
193
  with gr.Row():
194
- with gr.Column():
195
  input_video = gr.Video(label="Input Video", sources=["upload"])
196
 
197
  with gr.Group():
@@ -210,6 +239,22 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
210
  value="bytetrack",
211
  label="Tracking Method"
212
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  conf_threshold = gr.Slider(
214
  minimum=0.1,
215
  maximum=0.9,
@@ -220,13 +265,22 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
220
 
221
  process_btn = gr.Button("Process Video", variant="primary")
222
 
223
- with gr.Column():
224
  output_video = gr.Video(label="Output Video with Tracking")
225
  status_text = gr.Textbox(label="Status", value="Ready to process video")
226
 
 
 
 
 
 
 
 
 
 
227
  process_btn.click(
228
  fn=process_video,
229
- inputs=[input_video, yolo_model, reid_model, tracking_method, conf_threshold],
230
  outputs=[output_video, status_text]
231
  )
232
 
 
47
  else:
48
  print("⚠️ tracker_patch.py not found, skipping patches")
49
 
50
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_classes, conf_threshold):
51
  """Run object tracking on the uploaded video."""
52
  try:
53
  # Create temporary workspace
 
74
  "--exist-ok"
75
  ]
76
 
77
+ # Add class filtering if specific classes are selected
78
+ if target_classes and target_classes != ["all"]:
79
+ classes_arg = ",".join(str(YOLO_CLASSES.index(cls)) for cls in target_classes if cls in YOLO_CLASSES)
80
+ if classes_arg:
81
+ cmd.extend(["--classes", classes_arg])
82
+
83
  # Special handling for OcSort
84
  if tracking_method == "ocsort":
85
  cmd.append("--per-class")
 
157
  return None, f"Error: {str(e)}"
158
 
159
  # Define the Gradio interface
160
+ def process_video(video_path, yolo_model, reid_model, tracking_method, target_classes, conf_threshold):
161
  # Validate inputs
162
  if not video_path:
163
  return None, "Please upload a video file"
164
 
165
  print(f"Processing video: {video_path}")
166
+ print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={target_classes}, conf={conf_threshold}")
167
 
168
  output_path, status = run_tracking(
169
  video_path,
170
  yolo_model,
171
  reid_model,
172
  tracking_method,
173
+ target_classes,
174
  conf_threshold
175
  )
176
 
 
188
  reid_models = ["osnet_x0_25_msmt17.pt"]
189
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
190
 
191
+ # YOLO COCO class names
192
+ YOLO_CLASSES = [
193
+ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
194
+ "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
195
+ "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
196
+ "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
197
+ "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
198
+ "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
199
+ "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
200
+ "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
201
+ "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
202
+ ]
203
+
204
+ # Common object groups for convenience
205
+ COMMON_CLASSES = {
206
+ "all": ["all"],
207
+ "people": ["person"],
208
+ "vehicles": ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"],
209
+ "animals": ["bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"],
210
+ "common_objects": ["backpack", "umbrella", "handbag", "tie", "suitcase", "cell phone", "laptop", "book"]
211
+ }
212
+
213
  # Ensure dependencies and apply patches at startup
214
  ensure_dependencies()
215
  apply_patches()
 
220
  gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
221
 
222
  with gr.Row():
223
+ with gr.Column(scale=1):
224
  input_video = gr.Video(label="Input Video", sources=["upload"])
225
 
226
  with gr.Group():
 
239
  value="bytetrack",
240
  label="Tracking Method"
241
  )
242
+
243
+ # Object category selection
244
+ class_category = gr.Radio(
245
+ choices=list(COMMON_CLASSES.keys()),
246
+ value="all",
247
+ label="Object Category"
248
+ )
249
+
250
+ # Individual class selection
251
+ target_classes = gr.CheckboxGroup(
252
+ choices=YOLO_CLASSES,
253
+ value=["all"],
254
+ label="Target Classes",
255
+ interactive=True
256
+ )
257
+
258
  conf_threshold = gr.Slider(
259
  minimum=0.1,
260
  maximum=0.9,
 
265
 
266
  process_btn = gr.Button("Process Video", variant="primary")
267
 
268
+ with gr.Column(scale=1):
269
  output_video = gr.Video(label="Output Video with Tracking")
270
  status_text = gr.Textbox(label="Status", value="Ready to process video")
271
 
272
+ # Update the target classes based on the selected category
273
+ def update_classes(category):
274
+ if category == "all":
275
+ return ["all"]
276
+ else:
277
+ return COMMON_CLASSES[category]
278
+
279
+ class_category.change(fn=update_classes, inputs=class_category, outputs=target_classes)
280
+
281
  process_btn.click(
282
  fn=process_video,
283
+ inputs=[input_video, yolo_model, reid_model, tracking_method, target_classes, conf_threshold],
284
  outputs=[output_video, status_text]
285
  )
286