Spaces:
Sleeping
Sleeping
usiddiquee commited on
Commit ·
f55cfe0
1
Parent(s): 1d1c8d9
done
Browse files
app.py
CHANGED
|
@@ -47,7 +47,7 @@ def apply_patches():
|
|
| 47 |
else:
|
| 48 |
print("⚠️ tracker_patch.py not found, skipping patches")
|
| 49 |
|
| 50 |
-
def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_threshold):
|
| 51 |
"""Run object tracking on the uploaded video."""
|
| 52 |
try:
|
| 53 |
# Create temporary workspace
|
|
@@ -74,6 +74,12 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_thres
|
|
| 74 |
"--exist-ok"
|
| 75 |
]
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
# Special handling for OcSort
|
| 78 |
if tracking_method == "ocsort":
|
| 79 |
cmd.append("--per-class")
|
|
@@ -151,19 +157,20 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, conf_thres
|
|
| 151 |
return None, f"Error: {str(e)}"
|
| 152 |
|
| 153 |
# Define the Gradio interface
|
| 154 |
-
def process_video(video_path, yolo_model, reid_model, tracking_method, conf_threshold):
|
| 155 |
# Validate inputs
|
| 156 |
if not video_path:
|
| 157 |
return None, "Please upload a video file"
|
| 158 |
|
| 159 |
print(f"Processing video: {video_path}")
|
| 160 |
-
print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, conf={conf_threshold}")
|
| 161 |
|
| 162 |
output_path, status = run_tracking(
|
| 163 |
video_path,
|
| 164 |
yolo_model,
|
| 165 |
reid_model,
|
| 166 |
tracking_method,
|
|
|
|
| 167 |
conf_threshold
|
| 168 |
)
|
| 169 |
|
|
@@ -181,6 +188,28 @@ yolo_models = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"]
|
|
| 181 |
reid_models = ["osnet_x0_25_msmt17.pt"]
|
| 182 |
tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
# Ensure dependencies and apply patches at startup
|
| 185 |
ensure_dependencies()
|
| 186 |
apply_patches()
|
|
@@ -191,7 +220,7 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
|
|
| 191 |
gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
|
| 192 |
|
| 193 |
with gr.Row():
|
| 194 |
-
with gr.Column():
|
| 195 |
input_video = gr.Video(label="Input Video", sources=["upload"])
|
| 196 |
|
| 197 |
with gr.Group():
|
|
@@ -210,6 +239,22 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
|
|
| 210 |
value="bytetrack",
|
| 211 |
label="Tracking Method"
|
| 212 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
conf_threshold = gr.Slider(
|
| 214 |
minimum=0.1,
|
| 215 |
maximum=0.9,
|
|
@@ -220,13 +265,22 @@ with gr.Blocks(title="YOLO Object Tracking") as app:
|
|
| 220 |
|
| 221 |
process_btn = gr.Button("Process Video", variant="primary")
|
| 222 |
|
| 223 |
-
with gr.Column():
|
| 224 |
output_video = gr.Video(label="Output Video with Tracking")
|
| 225 |
status_text = gr.Textbox(label="Status", value="Ready to process video")
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
process_btn.click(
|
| 228 |
fn=process_video,
|
| 229 |
-
inputs=[input_video, yolo_model, reid_model, tracking_method, conf_threshold],
|
| 230 |
outputs=[output_video, status_text]
|
| 231 |
)
|
| 232 |
|
|
|
|
| 47 |
else:
|
| 48 |
print("⚠️ tracker_patch.py not found, skipping patches")
|
| 49 |
|
| 50 |
+
def run_tracking(video_file, yolo_model, reid_model, tracking_method, target_classes, conf_threshold):
|
| 51 |
"""Run object tracking on the uploaded video."""
|
| 52 |
try:
|
| 53 |
# Create temporary workspace
|
|
|
|
| 74 |
"--exist-ok"
|
| 75 |
]
|
| 76 |
|
| 77 |
+
# Add class filtering if specific classes are selected
|
| 78 |
+
if target_classes and target_classes != ["all"]:
|
| 79 |
+
classes_arg = ",".join(str(YOLO_CLASSES.index(cls)) for cls in target_classes if cls in YOLO_CLASSES)
|
| 80 |
+
if classes_arg:
|
| 81 |
+
cmd.extend(["--classes", classes_arg])
|
| 82 |
+
|
| 83 |
# Special handling for OcSort
|
| 84 |
if tracking_method == "ocsort":
|
| 85 |
cmd.append("--per-class")
|
|
|
|
| 157 |
return None, f"Error: {str(e)}"
|
| 158 |
|
| 159 |
# Define the Gradio interface
|
| 160 |
+
def process_video(video_path, yolo_model, reid_model, tracking_method, target_classes, conf_threshold):
|
| 161 |
# Validate inputs
|
| 162 |
if not video_path:
|
| 163 |
return None, "Please upload a video file"
|
| 164 |
|
| 165 |
print(f"Processing video: {video_path}")
|
| 166 |
+
print(f"Parameters: model={yolo_model}, reid={reid_model}, tracker={tracking_method}, classes={target_classes}, conf={conf_threshold}")
|
| 167 |
|
| 168 |
output_path, status = run_tracking(
|
| 169 |
video_path,
|
| 170 |
yolo_model,
|
| 171 |
reid_model,
|
| 172 |
tracking_method,
|
| 173 |
+
target_classes,
|
| 174 |
conf_threshold
|
| 175 |
)
|
| 176 |
|
|
|
|
| 188 |
reid_models = ["osnet_x0_25_msmt17.pt"]
|
| 189 |
tracking_methods = ["bytetrack", "botsort", "ocsort", "strongsort"]
|
| 190 |
|
| 191 |
+
# YOLO COCO class names
|
| 192 |
+
YOLO_CLASSES = [
|
| 193 |
+
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
|
| 194 |
+
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
|
| 195 |
+
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
|
| 196 |
+
"suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
|
| 197 |
+
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
|
| 198 |
+
"bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
|
| 199 |
+
"cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
|
| 200 |
+
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book",
|
| 201 |
+
"clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
# Common object groups for convenience
|
| 205 |
+
COMMON_CLASSES = {
|
| 206 |
+
"all": ["all"],
|
| 207 |
+
"people": ["person"],
|
| 208 |
+
"vehicles": ["bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat"],
|
| 209 |
+
"animals": ["bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"],
|
| 210 |
+
"common_objects": ["backpack", "umbrella", "handbag", "tie", "suitcase", "cell phone", "laptop", "book"]
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
# Ensure dependencies and apply patches at startup
|
| 214 |
ensure_dependencies()
|
| 215 |
apply_patches()
|
|
|
|
| 220 |
gr.Markdown("Upload a video file to detect and track objects. Processing may take a few minutes depending on video length.")
|
| 221 |
|
| 222 |
with gr.Row():
|
| 223 |
+
with gr.Column(scale=1):
|
| 224 |
input_video = gr.Video(label="Input Video", sources=["upload"])
|
| 225 |
|
| 226 |
with gr.Group():
|
|
|
|
| 239 |
value="bytetrack",
|
| 240 |
label="Tracking Method"
|
| 241 |
)
|
| 242 |
+
|
| 243 |
+
# Object category selection
|
| 244 |
+
class_category = gr.Radio(
|
| 245 |
+
choices=list(COMMON_CLASSES.keys()),
|
| 246 |
+
value="all",
|
| 247 |
+
label="Object Category"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Individual class selection
|
| 251 |
+
target_classes = gr.CheckboxGroup(
|
| 252 |
+
choices=YOLO_CLASSES,
|
| 253 |
+
value=["all"],
|
| 254 |
+
label="Target Classes",
|
| 255 |
+
interactive=True
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
conf_threshold = gr.Slider(
|
| 259 |
minimum=0.1,
|
| 260 |
maximum=0.9,
|
|
|
|
| 265 |
|
| 266 |
process_btn = gr.Button("Process Video", variant="primary")
|
| 267 |
|
| 268 |
+
with gr.Column(scale=1):
|
| 269 |
output_video = gr.Video(label="Output Video with Tracking")
|
| 270 |
status_text = gr.Textbox(label="Status", value="Ready to process video")
|
| 271 |
|
| 272 |
+
# Update the target classes based on the selected category
|
| 273 |
+
def update_classes(category):
|
| 274 |
+
if category == "all":
|
| 275 |
+
return ["all"]
|
| 276 |
+
else:
|
| 277 |
+
return COMMON_CLASSES[category]
|
| 278 |
+
|
| 279 |
+
class_category.change(fn=update_classes, inputs=class_category, outputs=target_classes)
|
| 280 |
+
|
| 281 |
process_btn.click(
|
| 282 |
fn=process_video,
|
| 283 |
+
inputs=[input_video, yolo_model, reid_model, tracking_method, target_classes, conf_threshold],
|
| 284 |
outputs=[output_video, status_text]
|
| 285 |
)
|
| 286 |
|