Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -47,7 +47,7 @@ def apply_patches():
|
|
| 47 |
else:
|
| 48 |
print("⚠️ tracker_patch.py not found, skipping patches")
|
| 49 |
|
| 50 |
-
def run_tracking(video_file, yolo_model, reid_model, tracking_method,
|
| 51 |
"""Run object tracking on the uploaded video."""
|
| 52 |
try:
|
| 53 |
# Create temporary workspace
|
|
@@ -74,11 +74,18 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_cla
|
|
| 74 |
"--exist-ok"
|
| 75 |
]
|
| 76 |
|
| 77 |
-
# Add class filtering if specific classes are
|
| 78 |
-
if
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# Special handling for OcSort
|
| 84 |
if tracking_method == "ocsort":
|
|
@@ -157,20 +164,20 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_cla
|
|
| 157 |
return None, f"Error: {str(e)}"
|
| 158 |
|
| 159 |
# Define the Gradio interface
|
| 160 |
-
def process_video(video_path, yolo_model, reid_model, tracking_method,
|
| 161 |
# Validate inputs
|
| 162 |
if not video_path:
|
| 163 |
return None, "Please upload a video file"
|
| 164 |
|
| 165 |
print(f"Processing video: {video_path}")
|
| 166 |
-
print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={
|
| 167 |
|
| 168 |
output_path, status = run_tracking(
|
| 169 |
video_path,
|
| 170 |
yolo_model,
|
| 171 |
reid_model,
|
| 172 |
tracking_method,
|
| 173 |
-
|
| 174 |
conf_threshold
|
| 175 |
)
|
| 176 |
|
|
@@ -188,28 +195,6 @@ yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
|
|
| 188 |
reid_models = ["osnet_x0_25_msmt17.pt"]
|
| 189 |
tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
|
| 190 |
|
| 191 |
-
# YOLO COCO class names
|
| 192 |
-
YOLO_CLASSES = [
|
| 193 |
-
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
| 194 |
-
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
|
| 195 |
-
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
|
| 196 |
-
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
|
| 197 |
-
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
|
| 198 |
-
"bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
|
| 199 |
-
"cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
|
| 200 |
-
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
|
| 201 |
-
"clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
|
| 202 |
-
]
|
| 203 |
-
|
| 204 |
-
# Common object groups for convenience
|
| 205 |
-
COMMON_CLASSES = {
|
| 206 |
-
"all": ["all"],
|
| 207 |
-
"people": ["person"],
|
| 208 |
-
"vehicles": ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"],
|
| 209 |
-
"animals": ["bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"],
|
| 210 |
-
"common_objects": ["backpack", "umbrella", "handbag", "tie", "suitcase", "cell phone", "laptop", "book"]
|
| 211 |
-
}
|
| 212 |
-
|
| 213 |
# Ensure dependencies and apply patches at startup
|
| 214 |
ensure_dependencies()
|
| 215 |
apply_patches()
|
|
@@ -219,6 +204,28 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
|
|
| 219 |
gr.Markdown("# 🚀 YOLO Object Tracking")
|
| 220 |
gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
with gr.Row():
|
| 223 |
with gr.Column(scale=1):
|
| 224 |
input_video = gr.Video(label="Input Video", sources=["upload"])
|
|
@@ -240,19 +247,11 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
|
|
| 240 |
label="Tracking Method"
|
| 241 |
)
|
| 242 |
|
| 243 |
-
#
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
# Individual class selection
|
| 251 |
-
target_classes = gr.CheckboxGroup(
|
| 252 |
-
choices=YOLO_CLASSES,
|
| 253 |
-
value=["all"],
|
| 254 |
-
label="Target Classes",
|
| 255 |
-
interactive=True
|
| 256 |
)
|
| 257 |
|
| 258 |
conf_threshold = gr.Slider(
|
|
@@ -269,18 +268,9 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
|
|
| 269 |
output_video = gr.Video(label="Output Video with Tracking")
|
| 270 |
status_text = gr.Textbox(label="Status", value="Ready to process video")
|
| 271 |
|
| 272 |
-
# Update the target classes based on the selected category
|
| 273 |
-
def update_classes(category):
|
| 274 |
-
if category == "all":
|
| 275 |
-
return ["all"]
|
| 276 |
-
else:
|
| 277 |
-
return COMMON_CLASSES[category]
|
| 278 |
-
|
| 279 |
-
class_category.change(fn=update_classes, inputs=class_category, outputs=target_classes)
|
| 280 |
-
|
| 281 |
process_btn.click(
|
| 282 |
fn=process_video,
|
| 283 |
-
inputs=[input_video, yolo_model, reid_model, tracking_method,
|
| 284 |
outputs=[output_video, status_text]
|
| 285 |
)
|
| 286 |
|
|
|
|
| 47 |
else:
|
| 48 |
print("⚠️ tracker_patch.py not found, skipping patches")
|
| 49 |
|
| 50 |
+
def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
|
| 51 |
"""Run object tracking on the uploaded video."""
|
| 52 |
try:
|
| 53 |
# Create temporary workspace
|
|
|
|
| 74 |
"--exist-ok"
|
| 75 |
]
|
| 76 |
|
| 77 |
+
# Add class filtering if specific classes are provided
|
| 78 |
+
if class_ids and class_ids.strip():
|
| 79 |
+
# Parse the comma-separated class IDs
|
| 80 |
+
try:
|
| 81 |
+
# Split by comma and convert to integers to validate
|
| 82 |
+
class_list = [int(c.strip()) for c in class_ids.split(",") if c.strip()]
|
| 83 |
+
# Add each class ID as a separate argument
|
| 84 |
+
if class_list:
|
| 85 |
+
cmd.append("--classes")
|
| 86 |
+
cmd.extend(str(c) for c in class_list)
|
| 87 |
+
except ValueError:
|
| 88 |
+
return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."
|
| 89 |
|
| 90 |
# Special handling for OcSort
|
| 91 |
if tracking_method == "ocsort":
|
|
|
|
| 164 |
return None, f"Error: {str(e)}"
|
| 165 |
|
| 166 |
# Define the Gradio interface
|
| 167 |
+
def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold):
|
| 168 |
# Validate inputs
|
| 169 |
if not video_path:
|
| 170 |
return None, "Please upload a video file"
|
| 171 |
|
| 172 |
print(f"Processing video: {video_path}")
|
| 173 |
+
print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={class_ids}, conf={conf_threshold}")
|
| 174 |
|
| 175 |
output_path, status = run_tracking(
|
| 176 |
video_path,
|
| 177 |
yolo_model,
|
| 178 |
reid_model,
|
| 179 |
tracking_method,
|
| 180 |
+
class_ids,
|
| 181 |
conf_threshold
|
| 182 |
)
|
| 183 |
|
|
|
|
| 195 |
reid_models = ["osnet_x0_25_msmt17.pt"]
|
| 196 |
tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
# Ensure dependencies and apply patches at startup
|
| 199 |
ensure_dependencies()
|
| 200 |
apply_patches()
|
|
|
|
| 204 |
gr.Markdown("# 🚀 YOLO Object Tracking")
|
| 205 |
gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
|
| 206 |
|
| 207 |
+
# Add class reference information
|
| 208 |
+
with gr.Accordion("YOLO Class Reference", open=False):
|
| 209 |
+
gr.Markdown("""
|
| 210 |
+
# YOLO Class IDs Reference
|
| 211 |
+
|
| 212 |
+
Enter the class IDs as comma-separated numbers in the "Target Classes" field.
|
| 213 |
+
Leave empty to track all classes.
|
| 214 |
+
|
| 215 |
+
## Common Class IDs:
|
| 216 |
+
- 0: person
|
| 217 |
+
- 1: bicycle
|
| 218 |
+
- 2: car
|
| 219 |
+
- 3: motorcycle
|
| 220 |
+
- 5: bus
|
| 221 |
+
- 7: truck
|
| 222 |
+
- 16: dog
|
| 223 |
+
- 17: horse
|
| 224 |
+
- 67: cell phone
|
| 225 |
+
|
| 226 |
+
[See full COCO class list here](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/datasets/coco128.yaml)
|
| 227 |
+
""")
|
| 228 |
+
|
| 229 |
with gr.Row():
|
| 230 |
with gr.Column(scale=1):
|
| 231 |
input_video = gr.Video(label="Input Video", sources=["upload"])
|
|
|
|
| 247 |
label="Tracking Method"
|
| 248 |
)
|
| 249 |
|
| 250 |
+
# Class ID input field
|
| 251 |
+
class_ids = gr.Textbox(
|
| 252 |
+
value="",
|
| 253 |
+
label="Target Classes (comma-separated IDs, e.g. '0,2,5', leave empty for all classes)",
|
| 254 |
+
placeholder="e.g. 0,2,5"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
)
|
| 256 |
|
| 257 |
conf_threshold = gr.Slider(
|
|
|
|
| 268 |
output_video = gr.Video(label="Output Video with Tracking")
|
| 269 |
status_text = gr.Textbox(label="Status", value="Ready to process video")
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
process_btn.click(
|
| 272 |
fn=process_video,
|
| 273 |
+
inputs=[input_video, yolo_model, reid_model, tracking_method, class_ids, conf_threshold],
|
| 274 |
outputs=[output_video, status_text]
|
| 275 |
)
|
| 276 |
|