"""
Vehicle Speed Estimation Module
================================

Implements speed calculation for tracked vehicles using perspective
transformation and temporal position tracking, with outlier rejection and an
optional moving-average smoothing helper.

Authors:
    - Abhay Gupta (0205CC221005)
    - Aditi Lakhera (0205CC221011)
    - Balraj Patel (0205CC221049)
    - Bhumika Patel (0205CC221050)

Technical Approach:
    - Tracks vehicle positions across frames in transformed coordinate space
    - Calculates displacement over a sliding time window
    - Rejects non-finite or implausibly high speeds as outliers
    - Converts to the desired speed unit
"""
import logging
from collections import defaultdict, deque
from typing import Dict, Iterable, Optional

import numpy as np
import supervision as sv

from .view_transformer import PerspectiveTransformer

logger = logging.getLogger(__name__)


class VehicleSpeedEstimator:
    """
    Estimates vehicle speeds using perspective-corrected position tracking.

    For every tracker ID the class keeps a bounded history of recent
    positions in the transformer's output coordinate space. The speed is the
    straight-line displacement between the oldest and newest stored position
    divided by the time those samples span. Transformed coordinates are
    treated as meters (NOTE(review): confirm against how the
    PerspectiveTransformer target region is defined).
    """

    # Upper bound on a plausible vehicle speed; faster estimates are treated
    # as tracking/transform glitches and discarded.
    MAX_SPEED_KMH = 300.0

    def __init__(
        self,
        fps: int,
        transformer: PerspectiveTransformer,
        history_duration: int = 1,
        speed_unit: str = "km/h",
        min_frames_for_speed: Optional[int] = None
    ):
        """
        Initialize the speed estimator.

        Args:
            fps: Video frames per second (must be positive).
            transformer: Perspective transformation instance; the output of
                its ``apply_transformation`` is treated as meters.
            history_duration: Time window for speed calculation in seconds
                (must be positive). The window always holds at least two
                positions.
            speed_unit: Output speed unit ("km/h", "mph", or "m/s").
            min_frames_for_speed: Minimum number of stored positions required
                before a speed is reported. Defaults to ``max(int(fps / 2), 5)``.
                The value is clamped to ``[2, window length in frames]`` so a
                speed can actually be produced.

        Raises:
            ValueError: If ``fps`` or ``history_duration`` is not positive, or
                ``speed_unit`` is not supported.
        """
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")
        if history_duration <= 0:
            raise ValueError(f"history_duration must be positive, got {history_duration}")

        # Speed unit conversion factors (from m/s)
        self.unit_conversions = {
            "km/h": 3.6,
            "mph": 2.23694,
            "m/s": 1.0,
        }
        if speed_unit not in self.unit_conversions:
            raise ValueError(f"Invalid speed unit: {speed_unit}")

        self.fps = fps
        self.transformer = transformer
        self.history_duration = history_duration
        self.speed_unit = speed_unit
        self.conversion_factor = self.unit_conversions[speed_unit]

        # Maximum history length in frames; a displacement needs >= 2 samples.
        max_history_frames = max(int(fps * history_duration), 2)
        self.max_history_frames = max_history_frames

        # Use an explicit None check so that a caller-supplied 0 is not
        # silently replaced by the default.
        if min_frames_for_speed is None:
            requested_min_frames = max(int(fps / 2), 5)
        else:
            requested_min_frames = min_frames_for_speed

        # If min_frames exceeded the deque capacity, no speed would ever be
        # reported; fewer than 2 samples give no displacement at all.
        self.min_frames = min(max(requested_min_frames, 2), max_history_frames)
        if self.min_frames != requested_min_frames:
            logger.warning(
                "min_frames adjusted from %s to %s to fit a %s-frame history window",
                requested_min_frames, self.min_frames, max_history_frames,
            )

        # Store position history for each tracked object.
        # Key: tracker_id, Value: deque of (x, y) positions in transformed
        # coordinates, oldest first, capped at max_history_frames entries.
        self.position_history: Dict[int, deque] = defaultdict(
            lambda: deque(maxlen=max_history_frames)
        )

        logger.info(
            "Speed estimator initialized: %sfps, %ss history, unit=%s, min_frames=%s",
            fps, history_duration, speed_unit, self.min_frames,
        )

    def _calculate_speed_for_vehicle(self, tracker_id: int) -> Optional[float]:
        """
        Calculate speed for a specific tracked vehicle.

        ``n`` stored positions span ``n - 1`` frame intervals, so the elapsed
        time is ``(n - 1) / fps``.

        Args:
            tracker_id: Unique tracker ID.

        Returns:
            Speed in the configured unit, or None if there is insufficient
            history, the stored positions cannot be subtracted, or the result
            is non-finite or above ``MAX_SPEED_KMH``.
        """
        # .get() avoids creating an empty history entry for unknown IDs.
        positions = self.position_history.get(tracker_id)
        if positions is None or len(positions) < self.min_frames:
            return None

        # Time spanned by the samples: n samples -> n - 1 intervals.
        time_elapsed = (len(positions) - 1) / self.fps
        if time_elapsed <= 0:
            return None

        try:
            # Euclidean distance between oldest and newest position (meters).
            displacement = float(
                np.linalg.norm(np.asarray(positions[-1]) - np.asarray(positions[0]))
            )
        except (TypeError, ValueError) as e:
            logger.warning("Error calculating speed for vehicle %s: %s", tracker_id, e)
            return None

        speed_ms = displacement / time_elapsed
        speed = speed_ms * self.conversion_factor

        # Compare in m/s so the bound is correct for every output unit.
        max_speed_ms = self.MAX_SPEED_KMH / self.unit_conversions["km/h"]
        # NaN/inf (e.g. from a degenerate transform) must be rejected here:
        # estimate() casts speeds to int, which would turn NaN into garbage.
        if not np.isfinite(speed_ms) or speed_ms > max_speed_ms:
            logger.debug("Outlier speed detected for vehicle %s: %.1f", tracker_id, speed)
            return None

        return speed

    def _apply_smoothing(self, speeds: np.ndarray, window_size: int = 3) -> np.ndarray:
        """
        Apply centered moving-average smoothing to speed values.

        The ends are padded by repeating the edge values. This avoids the
        downward bias that zero padding (``np.convolve(mode='same')``) puts
        on the first and last samples.

        Args:
            speeds: 1-D array of speed values.
            window_size: Smoothing window size; values <= 1 disable smoothing.

        Returns:
            Smoothed float array of the same length. The input is returned
            unchanged if it is shorter than the window or smoothing is
            disabled.
        """
        if window_size <= 1 or len(speeds) < window_size:
            return speeds

        values = np.asarray(speeds, dtype=float)
        pad_left = (window_size - 1) // 2
        pad_right = window_size - 1 - pad_left
        padded = np.pad(values, (pad_left, pad_right), mode="edge")
        kernel = np.full(window_size, 1.0 / window_size)
        # 'valid' on the padded signal yields exactly len(values) outputs.
        return np.convolve(padded, kernel, mode="valid")

    def estimate(self, detections: sv.Detections) -> sv.Detections:
        """
        Estimate speeds for all detected vehicles in current frame.

        Each detection's bottom-center anchor is transformed, appended to that
        tracker's history, and converted to a speed.

        Args:
            detections: Detection results from current frame.

        Returns:
            The same detections object with ``data["speed"]`` set to an int
            array (one entry per detection). The entry is 0 wherever a speed
            could not be computed.
        """
        num_detections = len(detections)

        # Check if we have tracker IDs
        if getattr(detections, "tracker_id", None) is None:
            logger.warning("No tracker IDs found in detections")
            detections.data["speed"] = np.zeros(num_detections, dtype=int)
            return detections

        # Nothing to track; skip the transformer so it never sees an empty batch.
        if num_detections == 0:
            detections.data["speed"] = np.zeros(0, dtype=int)
            return detections

        # Get anchor points (bottom center of bounding boxes)
        anchor_points = detections.get_anchors_coordinates(
            anchor=sv.Position.BOTTOM_CENTER
        )

        # Transform points to real-world coordinates. The transformer is an
        # external component, so any failure degrades to zero speeds.
        try:
            transformed_points = np.asarray(
                self.transformer.apply_transformation(anchor_points), dtype=float
            )
        except Exception as e:
            logger.error("Error transforming points: %s", e)
            detections.data["speed"] = np.zeros(num_detections, dtype=int)
            return detections

        # Pre-sized so the output always has one entry per detection, even if
        # the transformer returned fewer points.
        raw_speeds = np.zeros(num_detections, dtype=float)
        for index, (tracker_id, point) in enumerate(
            zip(detections.tracker_id, transformed_points)
        ):
            key = int(tracker_id)
            # Copy so the history does not keep the whole frame's array alive.
            self.position_history[key].append(np.array(point, dtype=float))
            speed = self._calculate_speed_for_vehicle(key)
            if speed is not None:
                raw_speeds[index] = speed

        detections.data["speed"] = np.round(raw_speeds).astype(int)
        return detections

    def reset(self) -> None:
        """Clear all position history."""
        self.position_history.clear()
        logger.info("Speed estimator history cleared")

    def get_tracked_vehicle_count(self) -> int:
        """
        Get number of currently tracked vehicles.

        Returns:
            Number of tracker IDs that currently have a position history entry.
        """
        return len(self.position_history)

    def cleanup_old_tracks(self, active_ids: Iterable[int]) -> None:
        """
        Remove position history for vehicles no longer being tracked.

        The history is otherwise never pruned, so callers should invoke this
        periodically to bound memory.

        Args:
            active_ids: Currently active tracker IDs (any iterable; a set is
                typical).
        """
        active = set(active_ids)
        stale_ids = [tid for tid in self.position_history if tid not in active]
        for tid in stale_ids:
            del self.position_history[tid]

        if stale_ids:
            logger.debug("Cleaned up %d inactive tracks", len(stale_ids))