Spaces:
Build error
Build error
| import cv2 | |
| import numpy as np | |
| import argparse | |
| import os | |
| from collections import deque | |
| from ultralytics import YOLO | |
| from bytetracker import BYTETracker | |
class ObjectTracker:
    """Detect objects with a YOLO model, track them across frames, and draw
    bounding boxes, ``name-id score`` labels and short motion trails.

    NOTE(review): ``process_frame`` passes ``BYTETracker.update`` a list of
    ``[x1, y1, x2, y2, score, class_id]`` rows and then reads ``track_id``,
    ``tlwh``, ``class_id`` and ``score`` attributes from each returned item.
    Confirm this matches the installed ``bytetracker`` package. Some ByteTrack
    ports return a NumPy array instead, or expect extra ``update`` arguments.
    """

    def __init__(self, model_path='yolov8n.pt', conf_thresh=0.5, track_thresh=0.3,
                 trail_length=10, max_trail_age=30):
        """
        Args:
            model_path: Weights file passed to ``ultralytics.YOLO``.
            conf_thresh: Minimum detection confidence passed to the model.
            track_thresh: Forwarded to ``BYTETracker(track_thresh=...)``.
            trail_length: Maximum number of centre points kept per track.
            max_trail_age: Number of consecutive frames a track id may be
                missing from the tracker output before its trail is dropped.
                This keeps memory bounded on long videos.
        """
        # Detection model
        self.model = YOLO(model_path)
        # Multi-object tracker
        self.tracker = BYTETracker(track_thresh=track_thresh)
        self.conf_thresh = conf_thresh
        self.trail_length = trail_length
        self.max_trail_age = max_trail_age
        # track_id -> deque of recent (cx, cy) box centres
        self.trails = {}
        # track_id -> index of the last frame in which the id was reported
        self._last_seen = {}
        self._frame_idx = 0
        # Class-id -> name mapping provided by the model
        self.class_names = self.model.names
        # One colour per class. Use at least 80 colours, and more if the model has more classes.
        self.colors = self._generate_colors(max(80, len(self.class_names)))

    def _generate_colors(self, num_classes):
        """Return ``num_classes`` reproducible colours as ``[c0, c1, c2]`` int lists.

        A private ``RandomState(42)`` keeps the palette deterministic without
        reseeding NumPy's global RNG, which would affect any other code in the
        process that uses ``np.random``. The values match what
        ``np.random.seed(42)`` followed by ``np.random.randint`` would give.
        """
        rng = np.random.RandomState(42)
        colors = rng.randint(0, 255, size=(num_classes, 3), dtype=np.uint8)
        return colors.tolist()

    def process_frame(self, frame):
        """Run detection and tracking on one frame and return an annotated copy.

        The input ``frame`` itself is not drawn on.
        """
        results = self.model(frame, conf=self.conf_thresh, verbose=False)
        detections = self._extract_detections(results)
        tracked_objects = self.tracker.update(detections)
        self._update_trails(tracked_objects)
        return self._draw_results(frame.copy(), tracked_objects)

    @staticmethod
    def _extract_detections(results):
        """Flatten ultralytics results into ``[x1, y1, x2, y2, score, class_id]`` rows."""
        detections = []
        for result in results:
            boxes = result.boxes
            xyxy = boxes.xyxy.cpu().numpy()
            scores = boxes.conf.cpu().numpy()
            classes = boxes.cls.cpu().numpy()
            for (x1, y1, x2, y2), score, cls in zip(xyxy, scores, classes):
                detections.append([x1, y1, x2, y2, score, int(cls)])
        return detections

    @staticmethod
    def _tlwh_center(tlwh):
        """Return the ``(cx, cy)`` centre of a ``(left, top, width, height)`` box."""
        x, y, w, h = tlwh
        return (float(x + w / 2), float(y + h / 2))

    def _update_trails(self, tracked_objects):
        """Append each tracked box's centre to its trail and drop stale trails."""
        self._frame_idx += 1
        for obj in tracked_objects:
            track_id = int(obj.track_id)
            trail = self.trails.get(track_id)
            if trail is None:
                trail = self.trails[track_id] = deque(maxlen=self.trail_length)
            trail.append(self._tlwh_center(obj.tlwh))
            self._last_seen[track_id] = self._frame_idx
        # Forget ids that have been missing for longer than max_trail_age frames.
        # Otherwise the dict grows with every id ever seen.
        stale = [tid for tid, seen in self._last_seen.items()
                 if self._frame_idx - seen > self.max_trail_age]
        for tid in stale:
            del self._last_seen[tid]
            self.trails.pop(tid, None)

    def _class_name(self, class_id):
        """Look up a class name, falling back to the numeric id if it is unknown."""
        try:
            return self.class_names[class_id]
        except (KeyError, IndexError):
            return str(class_id)

    def _draw_results(self, frame, tracked_objects):
        """Draw boxes, labels and motion trails onto ``frame`` in place and return it."""
        font = cv2.FONT_HERSHEY_SIMPLEX
        for obj in tracked_objects:
            track_id = int(obj.track_id)
            class_id = int(obj.class_id)
            # The modulo guards against class ids beyond the palette size.
            color = self.colors[class_id % len(self.colors)]

            # Bounding box
            x, y, w, h = obj.tlwh
            left, top = int(x), int(y)
            right, bottom = int(x + w), int(y + h)
            cv2.rectangle(frame, (left, top), (right, bottom), color, 2)

            # Label on a filled 25px strip. Place it above the box, or just
            # inside the box when there is no room above, so it stays on-frame.
            label = f"{self._class_name(class_id)}-{track_id} {obj.score:.2f}"
            (label_w, _), _ = cv2.getTextSize(label, font, 0.6, 1)
            strip_top = top - 25 if top >= 25 else top
            strip_bottom = strip_top + 25
            cv2.rectangle(frame, (left, strip_top), (left + label_w, strip_bottom), color, -1)
            cv2.putText(frame, label, (left, strip_bottom - 5), font, 0.6, (255, 255, 255), 1)

            # Motion trail: a polyline through recent centres, plus a dot at each centre.
            trail = self.trails.get(track_id)
            if trail is not None and len(trail) > 1:
                points = np.array(trail, dtype=np.int32)
                cv2.polylines(frame, [points], False, color, 2)
                for px, py in points:
                    cv2.circle(frame, (int(px), int(py)), 3, color, -1)
        return frame
def main():
    """Command-line entry point: detect and track objects in a video source.

    Reads frames from ``--input``, which is treated as a webcam index when it
    is all digits and as a path/URL otherwise. Each frame is annotated with
    ``ObjectTracker``, written to ``--output`` with the ``mp4v`` codec, and
    optionally shown in a window.

    Window keys:
        ``q`` quits.
        ``s`` saves ``snapshot_<frame>.jpg``.
        Space pauses until the next key press; ``q`` or ``s`` pressed while
        paused still take effect.
    """
    import sys

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Object Detection and Tracking")
    parser.add_argument("--input", type=str, default="0",
                        help="Input source (video file, image, or webcam index)")
    parser.add_argument("--output", type=str, default="output.mp4",
                        help="Output video file")
    parser.add_argument("--model", type=str, default="yolov8n.pt",
                        help="YOLOv8 model path")
    parser.add_argument("--conf", type=float, default=0.5,
                        help="Confidence threshold")
    parser.add_argument("--track_thresh", type=float, default=0.3,
                        help="Tracking threshold")
    parser.add_argument("--trail_length", type=int, default=10,
                        help="Motion trail length")
    parser.add_argument("--no_display", action="store_true",
                        help="Disable display window")
    args = parser.parse_args()

    tracker = ObjectTracker(
        model_path=args.model,
        conf_thresh=args.conf,
        track_thresh=args.track_thresh,
        trail_length=args.trail_length,
    )

    # Open the video source. An all-digit input is a camera index.
    source = int(args.input) if args.input.isdigit() else args.input
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        print("Error: Could not open video source", file=sys.stderr)
        return

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # FPS is reported as 0 when unknown. `not fps > 0` also rejects negative or NaN values.
    if not fps > 0:
        fps = 30  # Default FPS

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
    if not out.isOpened():
        # Without this check, every write() below would be silently dropped.
        print(f"Error: Could not open output file {args.output}", file=sys.stderr)
        cap.release()
        return

    window = "Object Detection & Tracking"
    if not args.no_display:
        cv2.namedWindow(window, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(window, width, height)

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            processed_frame = tracker.process_frame(frame)
            out.write(processed_frame)
            frame_index = frame_count
            # Count the frame as soon as it is written, so that quitting with
            # 'q' does not drop it from the final summary.
            frame_count += 1

            if args.no_display:
                continue

            cv2.imshow(window, processed_frame)
            key = cv2.waitKey(1) & 0xFF
            if key == ord(' '):
                # Paused: block until any key, then act on that key as well.
                key = cv2.waitKey(0) & 0xFF
            if key == ord('q'):  # Quit
                break
            if key == ord('s'):  # Save snapshot
                snapshot = f"snapshot_{frame_index}.jpg"
                cv2.imwrite(snapshot, processed_frame)
                print(f"Saved {snapshot}")
    finally:
        # Always release the capture and finalize the output file, even if
        # processing raises or the user presses Ctrl-C.
        cap.release()
        out.release()
        if not args.no_display:
            cv2.destroyAllWindows()

    print(f"Processing complete. Output saved to {args.output}")
    print(f"Processed {frame_count} frames")
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()