usiddiquee786 commited on
Commit
bb82e85
·
verified ·
1 Parent(s): f55cfe0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -54
app.py CHANGED
@@ -47,7 +47,7 @@ def apply_patches():
47
  else:
48
  print("⚠️ tracker_patch.py not found, skipping patches")
49
 
50
- def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_classes, conf_threshold):
51
  """Run object tracking on the uploaded video."""
52
  try:
53
  # Create temporary workspace
@@ -74,11 +74,18 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_cla
74
  "--exist-ok"
75
  ]
76
 
77
- # Add class filtering if specific classes are selected
78
- if target_classes and target_classes != ["all"]:
79
- classes_arg = ",".join(str(YOLO_CLASSES.index(cls)) for cls in target_classes if cls in YOLO_CLASSES)
80
- if classes_arg:
81
- cmd.extend(["--classes", classes_arg])
 
 
 
 
 
 
 
82
 
83
  # Special handling for OcSort
84
  if tracking_method == "ocsort":
@@ -157,20 +164,20 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_cla
157
  return None, f"Error: {str(e)}"
158
 
159
  # Define the Gradio interface
160
- def process_video(video_path, yolo_model, reid_model, tracking_method, target_classes, conf_threshold):
161
  # Validate inputs
162
  if not video_path:
163
  return None, "Please upload a video file"
164
 
165
  print(f"Processing video: {video_path}")
166
- print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={target_classes}, conf={conf_threshold}")
167
 
168
  output_path, status = run_tracking(
169
  video_path,
170
  yolo_model,
171
  reid_model,
172
  tracking_method,
173
- target_classes,
174
  conf_threshold
175
  )
176
 
@@ -188,28 +195,6 @@ yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
188
  reid_models = ["osnet_x0_25_msmt17.pt"]
189
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
190
 
191
- # YOLO COCO class names
192
- YOLO_CLASSES = [
193
- "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
194
- "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
195
- "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
196
- "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
197
- "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
198
- "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
199
- "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
200
- "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
201
- "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
202
- ]
203
-
204
- # Common object groups for convenience
205
- COMMON_CLASSES = {
206
- "all": ["all"],
207
- "people": ["person"],
208
- "vehicles": ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"],
209
- "animals": ["bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"],
210
- "common_objects": ["backpack", "umbrella", "handbag", "tie", "suitcase", "cell phone", "laptop", "book"]
211
- }
212
-
213
  # Ensure dependencies and apply patches at startup
214
  ensure_dependencies()
215
  apply_patches()
@@ -219,6 +204,28 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
219
  gr.Markdown("# 🚀 YOLO Object Tracking")
220
  gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  with gr.Row():
223
  with gr.Column(scale=1):
224
  input_video = gr.Video(label="Input Video", sources=["upload"])
@@ -240,19 +247,11 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
240
  label="Tracking Method"
241
  )
242
 
243
- # Object category selection
244
- class_category = gr.Radio(
245
- choices=list(COMMON_CLASSES.keys()),
246
- value="all",
247
- label="Object Category"
248
- )
249
-
250
- # Individual class selection
251
- target_classes = gr.CheckboxGroup(
252
- choices=YOLO_CLASSES,
253
- value=["all"],
254
- label="Target Classes",
255
- interactive=True
256
  )
257
 
258
  conf_threshold = gr.Slider(
@@ -269,18 +268,9 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
269
  output_video = gr.Video(label="Output Video with Tracking")
270
  status_text = gr.Textbox(label="Status", value="Ready to process video")
271
 
272
- # Update the target classes based on the selected category
273
- def update_classes(category):
274
- if category == "all":
275
- return ["all"]
276
- else:
277
- return COMMON_CLASSES[category]
278
-
279
- class_category.change(fn=update_classes, inputs=class_category, outputs=target_classes)
280
-
281
  process_btn.click(
282
  fn=process_video,
283
- inputs=[input_video, yolo_model, reid_model, tracking_method, target_classes, conf_threshold],
284
  outputs=[output_video, status_text]
285
  )
286
 
 
47
  else:
48
  print("⚠️ tracker_patch.py not found, skipping patches")
49
 
50
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
51
  """Run object tracking on the uploaded video."""
52
  try:
53
  # Create temporary workspace
 
74
  "--exist-ok"
75
  ]
76
 
77
+ # Add class filtering if specific classes are provided
78
+ if class_ids and class_ids.strip():
79
+ # Parse the comma-separated class IDs
80
+ try:
81
+ # Split by comma and convert to integers to validate
82
+ class_list = [int(c.strip()) for c in class_ids.split(",") if c.strip()]
83
+ # Add each class ID as a separate argument
84
+ if class_list:
85
+ cmd.append("--classes")
86
+ cmd.extend(str(c) for c in class_list)
87
+ except ValueError:
88
+ return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."
89
 
90
  # Special handling for OcSort
91
  if tracking_method == "ocsort":
 
164
  return None, f"Error: {str(e)}"
165
 
166
  # Define the Gradio interface
167
+ def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
168
  # Validate inputs
169
  if not video_path:
170
  return None, "Please upload a video file"
171
 
172
  print(f"Processing video: {video_path}")
173
+ print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")
174
 
175
  output_path, status = run_tracking(
176
  video_path,
177
  yolo_model,
178
  reid_model,
179
  tracking_method,
180
+ class_ids,
181
  conf_threshold
182
  )
183
 
 
195
  reid_models = ["osnet_x0_25_msmt17.pt"]
196
  tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  # Ensure dependencies and apply patches at startup
199
  ensure_dependencies()
200
  apply_patches()
 
204
  gr.Markdown("# 🚀 YOLO Object Tracking")
205
  gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
206
 
207
+ # Add class reference information
208
+ with gr.Accordion("YOLO Class Reference", open=False):
209
+ gr.Markdown("""
210
+ # YOLO Class IDs Reference
211
+
212
+ Enter the class IDs as comma-separated numbers in the "Target Classes" field.
213
+ Leave empty to track all classes.
214
+
215
+ ## Common Class IDs:
216
+ - 0: person
217
+ - 1: bicycle
218
+ - 2: car
219
+ - 3: motorcycle
220
+ - 5: bus
221
+ - 7: truck
222
+ - 16: dog
223
+ - 17: horse
224
+ - 67: cell phone
225
+
226
+ [See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/datasets/coco128.yaml)
227
+ """)
228
+
229
  with gr.Row():
230
  with gr.Column(scale=1):
231
  input_video = gr.Video(label="Input Video", sources=["upload"])
 
247
  label="Tracking Method"
248
  )
249
 
250
+ # Class ID input field
251
+ class_ids = gr.Textbox(
252
+ value="",
253
+ label="Target Classes (comma-separated IDs, e.g. '0,2,5', leave empty for all classes)",
254
+ placeholder="e.g. 0,2,5"
 
 
 
 
 
 
 
 
255
  )
256
 
257
  conf_threshold = gr.Slider(
 
268
  output_video = gr.Video(label="Output Video with Tracking")
269
  status_text = gr.Textbox(label="Status", value="Ready to process video")
270
 
 
 
 
 
 
 
 
 
 
271
  process_btn.click(
272
  fn=process_video,
273
+ inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold],
274
  outputs=[output_video, status_text]
275
  )
276