# deep / app.py
# Last commit by jehadcheyi: ae3b06f "Update app.py" (verified), 7.17 kB
import cv2
import numpy as np
import argparse
import os
from collections import deque
from ultralytics import YOLO
from bytetracker import BYTETracker
class ObjectTracker:
    """Detect objects with a YOLO model, track them with a BYTETracker, and draw results.

    Each tracked object is annotated with its bounding box, a
    ``"<class>-<track_id> <score>"`` label, and a motion trail built from its most
    recent box centers.
    """

    def __init__(self, model_path='yolov8n.pt', conf_thresh=0.5, track_thresh=0.3, trail_length=10):
        """Load the detector and create the tracker.

        Args:
            model_path: Weights path/name passed to ``YOLO``.
            conf_thresh: Minimum detection confidence forwarded to the detector.
            track_thresh: Value forwarded to ``BYTETracker(track_thresh=...)``.
            trail_length: Maximum number of centers kept per track for its trail.
        """
        # Initialize detection model
        self.model = YOLO(model_path)
        # Initialize tracker
        self.tracker = BYTETracker(track_thresh=track_thresh)
        # Parameters
        self.conf_thresh = conf_thresh
        self.trail_length = trail_length
        # track_id -> deque of the most recent (cx, cy) box centers
        self.trails = {}
        # Palette indexed by class id (wrapped modulo its length when drawing)
        self.colors = self._generate_colors(80)
        # Class-id -> name mapping exposed by the loaded model
        self.class_names = self.model.names

    def _generate_colors(self, num_classes):
        """Return ``num_classes`` reproducible 3-component color tuples of Python ints.

        A private ``RandomState`` seeded with 42 yields the same palette as the
        former ``np.random.seed(42)`` + ``np.random.randint`` pair, without
        reseeding NumPy's global generator for every other user of ``np.random``.
        """
        rng = np.random.RandomState(42)
        colors = rng.randint(0, 255, size=(num_classes, 3), dtype=np.uint8)
        return [tuple(int(c) for c in row) for row in colors]

    @staticmethod
    def _box_center(tlwh):
        """Return the center ``(cx, cy)`` of a box given as ``(x, y, w, h)``.

        ``(x, y)`` is the top-left corner and ``(w, h)`` the size, so the center is
        offset by half the size (not the midpoint of ``x`` and ``w``).
        """
        x, y, w, h = tlwh
        return (x + w / 2, y + h / 2)

    def _update_trails(self, tracked_objects):
        """Append each track's current center to its trail and drop trails of vanished tracks."""
        active_ids = set()
        for obj in tracked_objects:
            track_id = int(obj.track_id)
            active_ids.add(track_id)
            trail = self.trails.get(track_id)
            if trail is None:
                trail = self.trails[track_id] = deque(maxlen=self.trail_length)
            trail.append(self._box_center(obj.tlwh))
        # Forget tracks not reported this frame so the dict does not grow with every
        # ID ever seen; a track that reappears starts a fresh trail.
        for stale_id in self.trails.keys() - active_ids:
            del self.trails[stale_id]

    def process_frame(self, frame):
        """Run detection and tracking on ``frame`` and return an annotated copy.

        Args:
            frame: Image accepted by the YOLO model and the OpenCV drawing calls.

        Returns:
            A copy of ``frame`` with boxes, labels and trails drawn; the input
            frame itself is left unmodified.
        """
        # Run detection
        results = self.model(frame, conf=self.conf_thresh, verbose=False)
        # One [x1, y1, x2, y2, score, class_id] row per detection.
        # NOTE(review): assumes this list-of-rows format is what the installed
        # BYTETracker.update() expects — confirm against the tracker package in use.
        detections = []
        for result in results:
            boxes = result.boxes.xyxy.cpu().numpy()
            scores = result.boxes.conf.cpu().numpy()
            classes = result.boxes.cls.cpu().numpy()
            detections.extend(
                [*box[:4], score, int(cls)]
                for box, score, cls in zip(boxes, scores, classes)
            )
        # Update tracker, then the per-track trails
        tracked_objects = self.tracker.update(detections)
        self._update_trails(tracked_objects)
        # Draw on a copy so the caller's frame stays untouched
        return self._draw_results(frame.copy(), tracked_objects)

    def _draw_results(self, frame, tracked_objects):
        """Draw boxes, labels and motion trails for ``tracked_objects`` onto ``frame``.

        ``frame`` is modified in place and also returned.
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        for obj in tracked_objects:
            track_id = int(obj.track_id)
            # NOTE(review): assumes tracker outputs expose class_id, score and tlwh.
            class_id = int(obj.class_id)
            class_name = self.class_names[class_id]
            conf = obj.score
            # Wrap the index so models with more classes than the palette still get a color
            color = self.colors[class_id % len(self.colors)]
            # Draw bounding box
            x, y, w, h = obj.tlwh
            x1, y1 = int(x), int(y)
            x2, y2 = int(x + w), int(y + h)
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
            # The label occupies a 25 px strip above the box; push it down when the box
            # touches the top edge so the text stays inside the frame.
            label = f"{class_name}-{track_id} {conf:.2f}"
            (w_label, _), _ = cv2.getTextSize(label, font, 0.6, 1)
            label_bottom = max(y1, 25)
            cv2.rectangle(frame, (x1, label_bottom - 25), (x1 + w_label, label_bottom), color, -1)
            cv2.putText(frame, label, (x1, label_bottom - 5), font, 0.6, (255, 255, 255), 1)
            # Draw motion trail (needs at least two points to form a line)
            trail = self.trails.get(track_id)
            if trail is not None and len(trail) > 1:
                points = np.array(trail, dtype=np.int32)
                cv2.polylines(frame, [points], False, color, 2)
                for px, py in points:
                    cv2.circle(frame, (int(px), int(py)), 3, color, -1)
        return frame
def main():
    """Command-line entry point: track objects in a video/webcam stream and save the result.

    While the preview window is shown: ``q`` quits, ``s`` saves a snapshot of the
    current annotated frame, and space pauses until another key is pressed
    (``q`` while paused also quits). Capture and writer are always released, even
    if processing raises, so the output file is finalized.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Object Detection and Tracking")
    parser.add_argument("--input", type=str, default="0",
                        help="Input source (video file, image, or webcam index)")
    parser.add_argument("--output", type=str, default="output.mp4",
                        help="Output video file")
    parser.add_argument("--model", type=str, default="yolov8n.pt",
                        help="YOLOv8 model path")
    parser.add_argument("--conf", type=float, default=0.5,
                        help="Confidence threshold")
    parser.add_argument("--track_thresh", type=float, default=0.3,
                        help="Tracking threshold")
    parser.add_argument("--trail_length", type=int, default=10,
                        help="Motion trail length")
    parser.add_argument("--no_display", action="store_true",
                        help="Disable display window")
    args = parser.parse_args()

    # Initialize tracker
    tracker = ObjectTracker(
        model_path=args.model,
        conf_thresh=args.conf,
        track_thresh=args.track_thresh,
        trail_length=args.trail_length
    )

    # An all-digit input selects a camera index; anything else is passed as a path/URL
    source = int(args.input) if args.input.isdigit() else args.input
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        print("Error: Could not open video source")
        return

    # Get video properties
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # `not fps > 0` also catches NaN, which `fps == 0` would let through
    if not fps > 0:
        fps = 30  # Default FPS

    # Initialize video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
    if not out.isOpened():
        print(f"Warning: Could not open video writer for {args.output}; frames will not be saved")

    window_name = "Object Detection & Tracking"
    show = not args.no_display
    if show:
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window_name, width, height)

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = tracker.process_frame(frame)
            out.write(processed_frame)
            # Count the frame as soon as it is written so quitting below does not drop it
            frame_index = frame_count
            frame_count += 1

            if not show:
                continue
            cv2.imshow(window_name, processed_frame)
            # Handle key presses
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):  # Quit
                break
            elif key == ord('s'):  # Save snapshot
                snapshot = f"snapshot_{frame_index}.jpg"
                cv2.imwrite(snapshot, processed_frame)
                print(f"Saved {snapshot}")
            elif key == ord(' '):  # Pause until any key; 'q' while paused quits
                if cv2.waitKey(0) & 0xFF == ord('q'):
                    break
    finally:
        # Release resources even on error so the output container gets finalized
        cap.release()
        out.release()
        if show:
            cv2.destroyAllWindows()

    print(f"Processing complete. Output saved to {args.output}")
    print(f"Processed {frame_count} frames")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()