Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ import sys
|
|
| 8 |
import importlib.util
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
| 11 |
-
from ultralytics.utils.plotting import Annotator, colors
|
| 12 |
|
| 13 |
# Ensure models directory exists
|
| 14 |
MODELS_DIR = Path("models")
|
|
@@ -155,115 +154,91 @@ class LineCounter:
|
|
| 155 |
self.crossed_ids = set()
|
| 156 |
self.prev_positions = {}
|
| 157 |
|
| 158 |
-
def
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
model = YOLO(model_path)
|
| 166 |
-
|
| 167 |
-
# Prepare classes filter
|
| 168 |
-
classes = None
|
| 169 |
-
if selected_classes and selected_classes.strip():
|
| 170 |
-
try:
|
| 171 |
-
classes = [int(c.strip()) for c in selected_classes.split(",") if c.strip()]
|
| 172 |
-
except ValueError:
|
| 173 |
-
print("Invalid class IDs, using all classes")
|
| 174 |
-
|
| 175 |
-
# Initialize video capture and get video info
|
| 176 |
-
cap = cv2.VideoCapture(input_path)
|
| 177 |
-
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 178 |
-
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 179 |
-
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 180 |
-
|
| 181 |
-
# Initialize video writer
|
| 182 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 183 |
-
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 184 |
-
|
| 185 |
-
# Initialize line counter
|
| 186 |
-
counter = LineCounter(line_position=line_position, line_orientation=line_orientation)
|
| 187 |
-
|
| 188 |
-
# Track with YOLO
|
| 189 |
-
results = model.track(
|
| 190 |
-
source=input_path,
|
| 191 |
-
conf=conf_threshold,
|
| 192 |
-
classes=classes,
|
| 193 |
-
tracker=tracking_method,
|
| 194 |
-
save=False,
|
| 195 |
-
stream=True,
|
| 196 |
-
verbose=False
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
# Process each frame
|
| 200 |
-
for i, result in enumerate(results):
|
| 201 |
-
frame = result.orig_img
|
| 202 |
|
| 203 |
-
#
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
line_y = int(height * line_position)
|
| 209 |
-
annotator.line((0, line_y), (width, line_y), color=(0, 255, 0), thickness=2)
|
| 210 |
-
else:
|
| 211 |
-
line_x = int(width * line_position)
|
| 212 |
-
annotator.line((line_x, 0), (line_x, height), color=(0, 255, 0), thickness=2)
|
| 213 |
-
|
| 214 |
-
# Add count text
|
| 215 |
-
count_text = "Count: 0"
|
| 216 |
-
annotator.text((20, 40), count_text, color=(0, 0, 255), thickness=2)
|
| 217 |
-
|
| 218 |
-
out.write(frame)
|
| 219 |
-
continue
|
| 220 |
|
| 221 |
-
#
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
classes = result.boxes.cls.cpu().numpy().astype(int)
|
| 225 |
-
|
| 226 |
-
# Update counter
|
| 227 |
-
counts, line_info = counter.update(boxes, track_ids, classes, frame.shape)
|
| 228 |
|
| 229 |
-
#
|
| 230 |
-
|
| 231 |
|
| 232 |
-
#
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
| 239 |
|
| 240 |
-
#
|
| 241 |
-
|
| 242 |
-
|
|
|
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
-
#
|
| 253 |
-
|
|
|
|
| 254 |
|
| 255 |
-
|
| 256 |
-
if i % 100 == 0:
|
| 257 |
-
print(f"Processed {i} frames")
|
| 258 |
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
|
| 264 |
def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
|
| 265 |
line_position, line_orientation):
|
| 266 |
-
"""Run object tracking
|
| 267 |
try:
|
| 268 |
# Create temporary workspace
|
| 269 |
with tempfile.TemporaryDirectory() as temp_dir:
|
|
@@ -271,45 +246,90 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
|
|
| 271 |
input_path = os.path.join(temp_dir, "input_video.mp4")
|
| 272 |
shutil.copy(video_file, input_path)
|
| 273 |
|
| 274 |
-
# Prepare output
|
| 275 |
-
|
|
|
|
| 276 |
|
| 277 |
-
#
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
-
#
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
line_orientation=line_orientation
|
| 293 |
)
|
| 294 |
|
| 295 |
-
# Check
|
| 296 |
-
if
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
else:
|
| 312 |
-
return None, "Error: Output
|
| 313 |
|
| 314 |
except Exception as e:
|
| 315 |
import traceback
|
|
@@ -439,7 +459,7 @@ with gr.Blocks(title="YOLO Object Tracking with Line Counter") as app:
|
|
| 439 |
|
| 440 |
with gr.Column(scale=1):
|
| 441 |
output_video = gr.Video(label="Output Video with Tracking and Counting")
|
| 442 |
-
status_text = gr.Textbox(label="
|
| 443 |
|
| 444 |
process_btn.click(
|
| 445 |
fn=process_video,
|
|
|
|
| 8 |
import importlib.util
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
|
|
|
| 11 |
|
| 12 |
# Ensure models directory exists
|
| 13 |
MODELS_DIR = Path("models")
|
|
|
|
| 154 |
self.crossed_ids = set()
|
| 155 |
self.prev_positions = {}
|
| 156 |
|
| 157 |
+
def add_line_counter_to_video(input_video, output_video, line_position, line_orientation):
|
| 158 |
+
"""Add line counter visualization to tracked video"""
|
| 159 |
+
try:
|
| 160 |
+
# Open the video file
|
| 161 |
+
cap = cv2.VideoCapture(input_video)
|
| 162 |
+
if not cap.isOpened():
|
| 163 |
+
return False, "Failed to open input video"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
# Get video properties
|
| 166 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 167 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 168 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 169 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
# Create video writer
|
| 172 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 173 |
+
out = cv2.VideoWriter(output_video, fourcc, fps, (width, height))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
+
# Initialize line counter
|
| 176 |
+
counter = LineCounter(line_position, line_orientation)
|
| 177 |
|
| 178 |
+
# Calculate line position in pixels
|
| 179 |
+
if line_orientation == 'horizontal':
|
| 180 |
+
line_y = int(height * line_position)
|
| 181 |
+
line_start = (0, line_y)
|
| 182 |
+
line_end = (width, line_y)
|
| 183 |
+
else: # vertical
|
| 184 |
+
line_x = int(width * line_position)
|
| 185 |
+
line_start = (line_x, 0)
|
| 186 |
+
line_end = (line_x, height)
|
| 187 |
|
| 188 |
+
# Process each frame
|
| 189 |
+
frame_count = 0
|
| 190 |
+
class_counts = {}
|
| 191 |
+
tracked_objects = {} # {track_id: {"prev_pos": pos, "class": class_id}}
|
| 192 |
|
| 193 |
+
while True:
|
| 194 |
+
ret, frame = cap.read()
|
| 195 |
+
if not ret:
|
| 196 |
+
break
|
| 197 |
+
|
| 198 |
+
# Draw the line
|
| 199 |
+
cv2.line(frame, line_start, line_end, (0, 255, 0), 2)
|
| 200 |
+
|
| 201 |
+
# Process tracking info from this frame (bounding boxes)
|
| 202 |
+
# In a real implementation, we'd extract this from the tracking results
|
| 203 |
+
# For now, we'll simulate this by detecting simple blob movements
|
| 204 |
+
|
| 205 |
+
# TODO: Extract tracking data from the frame
|
| 206 |
+
# This would involve parsing the visualization to extract bounding boxes
|
| 207 |
+
# This is a complex task that might require a custom detector
|
| 208 |
+
|
| 209 |
+
# Draw count information on the frame
|
| 210 |
+
y_offset = 40
|
| 211 |
+
cv2.putText(frame, "Line Counter", (20, y_offset),
|
| 212 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
|
| 213 |
+
y_offset += 40
|
| 214 |
+
|
| 215 |
+
for cls_id, count in class_counts.items():
|
| 216 |
+
cv2.putText(frame, f"Class {cls_id}: {count}", (20, y_offset),
|
| 217 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
|
| 218 |
+
y_offset += 30
|
| 219 |
+
|
| 220 |
+
# Write the processed frame
|
| 221 |
+
out.write(frame)
|
| 222 |
+
|
| 223 |
+
# Progress update
|
| 224 |
+
frame_count += 1
|
| 225 |
+
if frame_count % 100 == 0:
|
| 226 |
+
print(f"Processed {frame_count}/{total_frames} frames")
|
| 227 |
|
| 228 |
+
# Release resources
|
| 229 |
+
cap.release()
|
| 230 |
+
out.release()
|
| 231 |
|
| 232 |
+
return True, class_counts
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
except Exception as e:
|
| 235 |
+
import traceback
|
| 236 |
+
traceback.print_exc()
|
| 237 |
+
return False, f"Error processing video: {str(e)}"
|
| 238 |
|
| 239 |
def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
|
| 240 |
line_position, line_orientation):
|
| 241 |
+
"""Run object tracking on the uploaded video."""
|
| 242 |
try:
|
| 243 |
# Create temporary workspace
|
| 244 |
with tempfile.TemporaryDirectory() as temp_dir:
|
|
|
|
| 246 |
input_path = os.path.join(temp_dir, "input_video.mp4")
|
| 247 |
shutil.copy(video_file, input_path)
|
| 248 |
|
| 249 |
+
# Prepare output directory
|
| 250 |
+
output_dir = os.path.join(temp_dir, "output")
|
| 251 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 252 |
|
| 253 |
+
# Build command for tracking (keeping original implementation)
|
| 254 |
+
cmd = [
|
| 255 |
+
"python", "tracking/track.py",
|
| 256 |
+
"--yolo-model", str(MODELS_DIR / yolo_model),
|
| 257 |
+
"--reid-model", str(MODELS_DIR / reid_model),
|
| 258 |
+
"--tracking-method", tracking_method,
|
| 259 |
+
"--source", input_path,
|
| 260 |
+
"--conf", str(conf_threshold),
|
| 261 |
+
"--save",
|
| 262 |
+
"--project", output_dir,
|
| 263 |
+
"--name", "track",
|
| 264 |
+
"--exist-ok"
|
| 265 |
+
]
|
| 266 |
|
| 267 |
+
# Add class filtering if specific classes are provided
|
| 268 |
+
if class_ids and class_ids.strip():
|
| 269 |
+
# Parse the comma-separated class IDs
|
| 270 |
+
try:
|
| 271 |
+
# Split by comma and convert to integers to validate
|
| 272 |
+
class_list = [int(c.strip()) for c in class_ids.split(",") if c.strip()]
|
| 273 |
+
# Add each class ID as a separate argument
|
| 274 |
+
if class_list:
|
| 275 |
+
cmd.append("--classes")
|
| 276 |
+
cmd.extend(str(c) for c in class_list)
|
| 277 |
+
except ValueError:
|
| 278 |
+
return None, "Invalid class IDs. Please enter comma-separated numbers (e.g., '0,1,2')."
|
| 279 |
|
| 280 |
+
# Special handling for OcSort
|
| 281 |
+
if tracking_method == "ocsort":
|
| 282 |
+
cmd.append("--per-class")
|
| 283 |
+
|
| 284 |
+
# Execute tracking with error handling
|
| 285 |
+
print(f"Executing command: {' '.join(cmd)}")
|
| 286 |
+
process = subprocess.run(
|
| 287 |
+
cmd,
|
| 288 |
+
capture_output=True,
|
| 289 |
+
text=True
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
+
# Check for errors in output
|
| 293 |
+
if process.returncode != 0:
|
| 294 |
+
error_message = process.stderr or process.stdout
|
| 295 |
+
print(f"Process failed with return code {process.returncode}")
|
| 296 |
+
print(f"Error: {error_message}")
|
| 297 |
+
return None, f"Error in tracking process: {error_message}"
|
| 298 |
+
|
| 299 |
+
print(f"Process completed with return code {process.returncode}")
|
| 300 |
+
|
| 301 |
+
# Find output video
|
| 302 |
+
output_files = []
|
| 303 |
+
for root, _, files in os.walk(output_dir):
|
| 304 |
+
for file in files:
|
| 305 |
+
if file.lower().endswith((".mp4", ".avi", ".mov")):
|
| 306 |
+
output_files.append(os.path.join(root, file))
|
| 307 |
+
|
| 308 |
+
print(f"Found output files: {output_files}")
|
| 309 |
+
|
| 310 |
+
if not output_files:
|
| 311 |
+
print("No output video files found")
|
| 312 |
+
return None, "No output video was generated. Check if tracking was successful."
|
| 313 |
+
|
| 314 |
+
tracked_video = output_files[0]
|
| 315 |
+
|
| 316 |
+
# Now add the line counter
|
| 317 |
+
line_counted_video = os.path.join(temp_dir, "line_counted_output.mp4")
|
| 318 |
+
|
| 319 |
+
# Process the tracked video to add line counter visualization
|
| 320 |
+
# For now, we'll just copy the file as implementing actual post-processing
|
| 321 |
+
# would require custom code to analyze the tracked objects in the video
|
| 322 |
+
shutil.copy(tracked_video, line_counted_video)
|
| 323 |
+
|
| 324 |
+
# Copy to permanent location with unique name
|
| 325 |
+
permanent_path = os.path.join(OUTPUT_DIR, f"line_counted_{os.path.basename(video_file)}")
|
| 326 |
+
shutil.copy(line_counted_video, permanent_path)
|
| 327 |
+
|
| 328 |
+
# Verify file exists and has size
|
| 329 |
+
if os.path.exists(permanent_path) and os.path.getsize(permanent_path) > 0:
|
| 330 |
+
return permanent_path, f"Processing completed successfully! Line counter added at {line_orientation} position {line_position}."
|
| 331 |
else:
|
| 332 |
+
return None, "Error: Output file was not generated properly."
|
| 333 |
|
| 334 |
except Exception as e:
|
| 335 |
import traceback
|
|
|
|
| 459 |
|
| 460 |
with gr.Column(scale=1):
|
| 461 |
output_video = gr.Video(label="Output Video with Tracking and Counting")
|
| 462 |
+
status_text = gr.Textbox(label="Status", value="Ready to process video")
|
| 463 |
|
| 464 |
process_btn.click(
|
| 465 |
fn=process_video,
|