Abs6187's picture
Upload 11 files
806bdda verified
"""
Vehicle Speed Estimation Module
================================
Implements speed calculation for tracked vehicles using perspective transformation
and temporal position tracking with smoothing and outlier detection.
Authors:
- Abhay Gupta (0205CC221005)
- Aditi Lakhera (0205CC221011)
- Balraj Patel (0205CC221049)
- Bhumika Patel (0205CC221050)
Technical Approach:
- Tracks vehicle positions across frames in transformed coordinate space
- Calculates displacement over time windows
- Provides a moving-average smoothing helper to reduce noise (not applied by estimate())
- Converts to desired speed units
"""
import numpy as np
import supervision as sv
from collections import defaultdict, deque
from typing import Dict, Optional
import logging
from .view_transformer import PerspectiveTransformer
# Module-level logger; the estimator reports its configuration, outlier speeds,
# and calculation/transformation errors through it.
logger = logging.getLogger(__name__)
class VehicleSpeedEstimator:
    """
    Estimates vehicle speeds using perspective-corrected position tracking.

    The estimator keeps a bounded history of each tracked vehicle's position in
    the transformed (real-world) coordinate space produced by the
    ``PerspectiveTransformer`` and derives a speed from the straight-line
    displacement between the oldest and newest stored positions.

    Per-vehicle history length is bounded, but entries for new tracker IDs are
    never removed by ``estimate`` itself; call ``cleanup_old_tracks`` or
    ``reset`` periodically to avoid unbounded growth of ``position_history``.
    """

    # Upper bound on plausible speeds, expressed in km/h. Anything faster is
    # treated as a tracking/transformation outlier and discarded.
    _MAX_PLAUSIBLE_SPEED_KMH = 300.0

    def __init__(
        self,
        fps: int,
        transformer: PerspectiveTransformer,
        history_duration: int = 1,
        speed_unit: str = "km/h",
        min_frames_for_speed: Optional[int] = None
    ):
        """
        Initialize the speed estimator.

        Args:
            fps: Video frames per second; must be positive.
            transformer: Perspective transformation instance used to map image
                points into the real-world plane. The unit table below converts
                from m/s, so the target plane is assumed to be in meters.
            history_duration: Time window for speed calculation (seconds).
            speed_unit: Output speed unit ("km/h", "mph", or "m/s").
            min_frames_for_speed: Minimum number of stored positions required
                before a speed is reported. Defaults to ``max(int(fps / 2), 5)``.
                The effective value is clamped to ``[2, history length]`` so
                that a speed can always eventually be computed.

        Raises:
            ValueError: If ``fps`` is not positive or ``speed_unit`` is unknown.
        """
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        self.fps = fps
        self.transformer = transformer
        self.history_duration = history_duration
        self.speed_unit = speed_unit

        # Speed unit conversion factors (from m/s)
        self.unit_conversions = {
            "km/h": 3.6,
            "mph": 2.23694,
            "m/s": 1.0,
        }
        if speed_unit not in self.unit_conversions:
            raise ValueError(f"Invalid speed unit: {speed_unit}")
        self.conversion_factor = self.unit_conversions[speed_unit]

        # Maximum history length in frames. Two positions are the minimum
        # needed to measure a displacement, so never allow a shorter window
        # (a maxlen of 0 or 1 would silently disable speed estimation).
        max_history_frames = max(int(fps * history_duration), 2)

        # Minimum stored positions before a speed is reported. Use ``is None``
        # rather than ``or`` so an explicit value is honoured, and clamp it to
        # the history length: a requirement longer than the deque can hold
        # would never be satisfied.
        if min_frames_for_speed is None:
            min_frames_for_speed = max(int(fps / 2), 5)
        self.min_frames = min(max(min_frames_for_speed, 2), max_history_frames)

        # Store position history for each tracked object.
        # Key: tracker_id, Value: deque of transformed anchor points
        # (presumably (x, y) in real-world coordinates).
        self.position_history: Dict[int, deque] = defaultdict(
            lambda: deque(maxlen=max_history_frames)
        )

        logger.info(
            f"Speed estimator initialized: {fps}fps, {history_duration}s history, "
            f"unit={speed_unit}, min_frames={self.min_frames}"
        )

    def _calculate_speed_for_vehicle(self, tracker_id: int) -> Optional[float]:
        """
        Calculate speed for a specific tracked vehicle.

        The speed is the straight-line displacement between the oldest and
        newest stored positions divided by the time spanned by those samples.

        Args:
            tracker_id: Unique tracker ID

        Returns:
            Speed in configured units, or None if there is insufficient
            history, the result is non-finite or implausibly high, or the
            computation fails.
        """
        # ``.get`` avoids the defaultdict creating an empty entry (which would
        # inflate ``get_tracked_vehicle_count``) for IDs never seen before.
        positions = self.position_history.get(tracker_id)
        if positions is None or len(positions) < self.min_frames:
            return None

        try:
            start_pos = np.asarray(positions[0], dtype=float)
            end_pos = np.asarray(positions[-1], dtype=float)

            # Euclidean distance in the transformed plane (meters assumed).
            displacement = float(np.linalg.norm(end_pos - start_pos))

            # N samples span N - 1 frame intervals; dividing by N / fps would
            # systematically underestimate the speed by a factor (N - 1) / N.
            time_elapsed = (len(positions) - 1) / self.fps
            if time_elapsed <= 0:
                return None

            speed = (displacement / time_elapsed) * self.conversion_factor

            # Plausibility cap: defined in km/h, converted to m/s, then into
            # the configured unit.
            max_speed = (
                self._MAX_PLAUSIBLE_SPEED_KMH / self.unit_conversions["km/h"]
            ) * self.conversion_factor
            if not np.isfinite(speed) or speed > max_speed:
                logger.debug(
                    "Outlier speed detected for vehicle %s: %.1f", tracker_id, speed
                )
                return None
            return speed
        except Exception as e:
            # Best effort: a single bad track must not abort the whole frame.
            logger.warning("Error calculating speed for vehicle %s: %s", tracker_id, e)
            return None

    def _apply_smoothing(self, speeds: np.ndarray, window_size: int = 3) -> np.ndarray:
        """
        Apply centred moving-average smoothing to speed values.

        Each output sample is the mean of the input samples that fall inside
        its window. Near the ends the window is truncated instead of being
        zero-padded, so edge values are not pulled towards zero.

        Note: ``estimate`` does not currently call this helper.

        Args:
            speeds: Array of speed values
            window_size: Smoothing window size; values <= 1 disable smoothing.

        Returns:
            Smoothed speed array, or ``speeds`` unchanged when it is shorter
            than the window or smoothing is disabled.
        """
        if window_size <= 1 or len(speeds) < window_size:
            return speeds

        values = np.asarray(speeds, dtype=float)
        kernel = np.ones(window_size)
        window_sums = np.convolve(values, kernel, mode="same")
        # Number of real samples under each window position (smaller at edges).
        window_counts = np.convolve(np.ones_like(values), kernel, mode="same")
        return window_sums / window_counts

    def estimate(self, detections: sv.Detections) -> sv.Detections:
        """
        Estimate speeds for all detected vehicles in current frame.

        Each detection's bottom-center anchor is transformed into real-world
        coordinates and appended to that tracker's history before its speed
        is computed.

        Args:
            detections: Detection results from current frame

        Returns:
            The same detections object with an integer ``speed`` array added
            to ``data`` (0 where no speed could be computed).
        """
        tracker_ids = getattr(detections, "tracker_id", None)
        if tracker_ids is None:
            logger.warning("No tracker IDs found in detections")
            # Integer zeros keep the dtype consistent with the normal path.
            detections.data["speed"] = np.zeros(len(detections), dtype=int)
            return detections

        # Get anchor points (bottom center of bounding boxes)
        anchor_points = detections.get_anchors_coordinates(
            anchor=sv.Position.BOTTOM_CENTER
        )

        # Transform points to real-world coordinates
        try:
            transformed_points = self.transformer.apply_transformation(anchor_points)
        except Exception as e:
            logger.error("Error transforming points: %s", e)
            detections.data["speed"] = np.zeros(len(detections), dtype=int)
            return detections

        speeds = []
        for tracker_id, point in zip(tracker_ids, transformed_points):
            self.position_history[tracker_id].append(point)
            speed = self._calculate_speed_for_vehicle(tracker_id)
            # Use 0 if speed cannot be calculated
            speeds.append(0.0 if speed is None else speed)

        detections.data["speed"] = np.round(np.asarray(speeds, dtype=float)).astype(int)
        return detections

    def reset(self) -> None:
        """Clear all position history."""
        self.position_history.clear()
        logger.info("Speed estimator history cleared")

    def get_tracked_vehicle_count(self) -> int:
        """
        Get number of currently tracked vehicles.

        Returns:
            Number of vehicles with position history
        """
        return len(self.position_history)

    def cleanup_old_tracks(self, active_ids: set) -> None:
        """
        Remove position history for vehicles no longer being tracked.

        Args:
            active_ids: Currently active tracker IDs. Any iterable of IDs
                (e.g. a NumPy array such as ``detections.tracker_id``) is
                accepted.
        """
        inactive_ids = self.position_history.keys() - set(active_ids)
        for tracker_id in inactive_ids:
            del self.position_history[tracker_id]
        if inactive_ids:
            logger.debug("Cleaned up %d inactive tracks", len(inactive_ids))