"""
Planning Module for FSD Model.

Components defined in this file:
1. Behavior Planning -- classifies a high-level driving command (lane keep,
   turns, lane changes, stop, yield, ...) from BEV features.
2. Trajectory Planning -- a transformer decoder whose learnable waypoint
   queries attend to BEV features, conditioned on the ego state and an
   optional integer navigation command.
3. Safety Verification -- learned collision-risk / emergency-brake scores and
   a hard speed clamp applied to the planned trajectory.

Route planning itself is not performed here: the route enters only as an
integer navigation command consumed by the trajectory planner.

Architecture: transformer-based planner that attends to perception features
and produces waypoint trajectories. Inspired by UniAD and VAD planners.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math


class PositionalEncoding2D(nn.Module):
    """2D sinusoidal positional encoding for BEV features.

    The first ``4 * (channels // 4)`` channels are split into four equal
    quarters holding sin(x), cos(x), sin(y) and cos(y) at ``channels // 4``
    frequencies. Row/column indices are normalized by H / W (so they lie in
    [0, 1)) and then multiplied by ``scale``. If ``channels`` is not divisible
    by 4, the trailing ``channels % 4`` channels receive no encoding.

    Args:
        channels: Channel count; must equal ``C`` of the tensor passed to
            ``forward``.
        temperature: Base of the geometric frequency progression.
        scale: Multiplier applied to the normalized positions.

    NOTE(review): with the default ``scale=1.0`` every divisor is >= 1, so
    no channel ever sees an angle of 1 radian or more and the lower-frequency
    channels are nearly constant. DETR-style encodings use ``scale=2 * pi``.
    The default is kept so existing behavior does not change.
    """

    def __init__(
        self,
        channels: int,
        temperature: float = 10000.0,
        scale: float = 1.0,
    ):
        super().__init__()
        self.channels = channels
        self.temperature = temperature
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``x``.

        Args:
            x: (B, C, H, W) feature map with ``C == self.channels``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        _, _, H, W = x.shape
        device = x.device
        quarter = self.channels // 4

        # Compute in float32 regardless of x.dtype: sin/cos of fp16 arguments
        # lose accuracy. The finished encoding is cast to x.dtype below.
        y_pos = (
            torch.arange(H, device=device, dtype=torch.float32)
            .unsqueeze(1)
            .expand(H, W)
            / H
            * self.scale
        )
        x_pos = (
            torch.arange(W, device=device, dtype=torch.float32)
            .unsqueeze(0)
            .expand(H, W)
            / W
            * self.scale
        )

        freq_idx = torch.arange(0, quarter, device=device, dtype=torch.float32)
        div = self.temperature ** (2 * freq_idx / (self.channels // 2))
        div = div.view(-1, 1, 1)  # (quarter, 1, 1): broadcasts over (H, W)

        pe = torch.zeros(self.channels, H, W, device=device, dtype=torch.float32)
        pe[0:quarter] = torch.sin(x_pos / div)
        pe[quarter:2 * quarter] = torch.cos(x_pos / div)
        pe[2 * quarter:3 * quarter] = torch.sin(y_pos / div)
        pe[3 * quarter:4 * quarter] = torch.cos(y_pos / div)

        # Cast to the input dtype. A float32 encoding would silently promote
        # fp16/bf16 features to float32, and the half-precision transformer
        # layers downstream would then receive mismatched dtypes.
        return x + pe.to(dtype=x.dtype).unsqueeze(0)


class BehaviorPredictor(nn.Module):
    """
    Predicts high-level driving behavior/command logits from BEV features.

    Intended label set for the default ``num_behaviors=10``: keep_lane,
    turn_left, turn_right, lane_change_left, lane_change_right, stop, yield,
    park, reverse, emergency_stop. The index-to-label mapping is not enforced
    by this module.

    Args:
        in_channels: Channel count of the incoming BEV tensor.
        num_behaviors: Number of output classes.
    """

    def __init__(self, in_channels: int = 256, num_behaviors: int = 10):
        super().__init__()
        self.num_behaviors = num_behaviors
        self.encoder = nn.Sequential(
            # Pool to a fixed 8x8 grid so the MLP input size does not depend
            # on the BEV resolution.
            nn.AdaptiveAvgPool2d(8),
            nn.Flatten(),
            nn.Linear(in_channels * 64, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_behaviors),
        )

    def forward(self, bev: torch.Tensor) -> torch.Tensor:
        """
        Args:
            bev: (B, in_channels, H, W) BEV features.

        Returns:
            (B, num_behaviors) unnormalized logits.
        """
        return self.encoder(bev)


class TrajectoryTransformer(nn.Module):
    """
    Transformer-based trajectory planner.

    Learnable trajectory queries (similar to DETR object queries), one per
    waypoint, are conditioned on the ego state and an optional navigation
    command. A transformer decoder then lets them cross-attend to the
    flattened BEV features.

    Args:
        bev_channels: Channel count of the incoming BEV tensor.
        d_model: Transformer width.
        nhead: Attention heads per decoder layer.
        num_decoder_layers: Number of stacked decoder layers.
        num_waypoints: Planning horizon (number of waypoint queries).
        dim_feedforward: Hidden size of each decoder layer's FFN.
        dropout: Dropout inside the decoder layers.
        num_commands: Size of the navigation-command vocabulary.
        ego_state_dim: Width of the ego-state vector.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        nhead: int = 8,
        num_decoder_layers: int = 6,
        num_waypoints: int = 20,  # planning horizon waypoints
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        num_commands: int = 10,
        ego_state_dim: int = 6,
    ):
        super().__init__()
        self.num_waypoints = num_waypoints
        self.d_model = d_model
        self.num_commands = num_commands

        # BEV feature compression to the transformer width.
        self.bev_compress = nn.Sequential(
            nn.Conv2d(bev_channels, d_model, 1),
            nn.BatchNorm2d(d_model),
            nn.ReLU(),
        )
        self.pos_encoding = PositionalEncoding2D(d_model)

        # Learnable trajectory queries, one per waypoint.
        self.trajectory_queries = nn.Parameter(
            torch.randn(num_waypoints, d_model)
        )

        # Navigation command embedding (high-level route).
        self.command_embed = nn.Embedding(num_commands, d_model)

        # Ego state embedding; default layout is
        # [speed, accel, steer, yaw_rate, x, y].
        self.ego_state_embed = nn.Sequential(
            nn.Linear(ego_state_dim, d_model),
            nn.ReLU(),
        )

        # Transformer decoder.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )

        # Waypoint prediction head: (x, y, heading, speed) per waypoint.
        # Outputs are unconstrained; SafetyChecker clamps the speed channel.
        self.waypoint_head = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )

        # Per-waypoint confidence in (0, 1).
        self.confidence_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W) from perception.
            ego_state: (B, ego_state_dim) current ego state
                [speed, accel, steer, yaw_rate, x, y] by default.
            nav_command: (B,) integer navigation command in
                [0, num_commands), or None to skip command conditioning.

        Returns:
            waypoints: (B, num_waypoints, 4) predicted trajectory.
            confidence: (B, num_waypoints, 1) per-waypoint confidence.
        """
        B = bev_features.shape[0]

        # Compress and add positional encoding to BEV.
        bev = self.pos_encoding(self.bev_compress(bev_features))

        # Flatten BEV to a token sequence: (B, H*W, d_model).
        bev_seq = bev.flatten(2).permute(0, 2, 1)

        # Build trajectory queries: (B, num_waypoints, d_model).
        queries = self.trajectory_queries.unsqueeze(0).expand(B, -1, -1)

        # Ego state is broadcast-added to every waypoint query.
        queries = queries + self.ego_state_embed(ego_state).unsqueeze(1)

        # The navigation command is broadcast-added in the same way, if given.
        if nav_command is not None:
            queries = queries + self.command_embed(nav_command).unsqueeze(1)

        # Queries cross-attend to the BEV tokens.
        decoded = self.transformer_decoder(queries, bev_seq)

        return {
            "waypoints": self.waypoint_head(decoded),  # (B, T, 4)
            "confidence": self.confidence_head(decoded),  # (B, T, 1)
        }


class SafetyChecker(nn.Module):
    """
    Scores and constrains planned trajectories.

    What ``forward`` currently computes:
    - collision_risk: learned (B, 1) score in (0, 1) from pooled BEV features.
    - emergency_brake: (B, 1) softmax probability of the "emergency" class.
    - speed_violation: (B, 1) fraction of waypoints whose speed exceeds
      ``max_speed_ms`` (negative speeds are not counted).
    - safe_waypoints: the input trajectory with the speed channel clamped to
      [0, max_speed_ms].

    NOTE(review): lane-boundary and minimum-following-distance checks are not
    implemented. ``min_following_distance``, ``emergency_decel`` and the
    ``ego_state`` argument are stored/accepted but unused by ``forward``.

    Args:
        bev_channels: Channel count of the incoming BEV tensor.
        max_speed_ms: Speed limit in m/s (default 8.94 m/s = 20 mph).
        min_following_distance: Meters; currently unused.
        emergency_decel: m/s^2; currently unused.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        max_speed_ms: float = 8.94,  # 20 mph
        min_following_distance: float = 4.0,  # meters
        emergency_decel: float = 8.0,  # m/s^2
    ):
        super().__init__()
        self.max_speed_ms = max_speed_ms
        self.min_following_distance = min_following_distance
        self.emergency_decel = emergency_decel

        # Collision risk estimator: (B, C, H, W) -> (B, 1) in (0, 1).
        self.collision_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(8),
            nn.Flatten(),
            nn.Linear(bev_channels * 64, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Emergency brake detector: (B, C, H, W) -> (B, 2) logits.
        self.emergency_detector = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
            nn.Linear(bev_channels * 16, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # [no_emergency, emergency]
        )

    def forward(
        self,
        bev: torch.Tensor,
        planned_waypoints: torch.Tensor,
        ego_state: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV features.
            planned_waypoints: (B, T, 4) planned trajectory; channel 3 is speed.
            ego_state: (B, 6); currently unused.

        Returns:
            Dict with collision_risk, emergency_brake, speed_violation and
            safe_waypoints (see class docstring).
        """
        collision_risk = self.collision_net(bev)

        # Keep the "emergency" column as (B, 1).
        emergency_logits = self.emergency_detector(bev)
        emergency_prob = F.softmax(emergency_logits, dim=-1)[:, 1:]

        # Fraction of waypoints above the speed limit.
        planned_speeds = planned_waypoints[:, :, 3]
        speed_violation = (
            (planned_speeds > self.max_speed_ms).float().mean(dim=-1, keepdim=True)
        )

        # Clamp speeds out-of-place so autograd still flows to the raw plan.
        clamped_speeds = torch.clamp(planned_speeds, 0.0, self.max_speed_ms)
        clamped_waypoints = torch.cat(
            [planned_waypoints[:, :, :3], clamped_speeds.unsqueeze(-1)],
            dim=-1,
        )

        return {
            "collision_risk": collision_risk,
            "emergency_brake": emergency_prob,
            "speed_violation": speed_violation,
            "safe_waypoints": clamped_waypoints,
        }


class PlanningModule(nn.Module):
    """
    Complete planning module.

    Pipeline: BEV -> Behavior Prediction -> Trajectory Generation -> Safety Check

    Behavior prediction and trajectory generation both read the same BEV
    features independently. The behavior logits are returned but do not
    condition the trajectory.

    Args:
        bev_channels: Channel count of the incoming BEV tensor.
        d_model: Transformer width of the trajectory planner.
        num_waypoints: Planning horizon.
        max_speed_ms: Speed limit enforced by the safety checker.
        num_behaviors: Number of behavior classes.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        num_waypoints: int = 20,
        max_speed_ms: float = 8.94,
        num_behaviors: int = 10,
    ):
        super().__init__()
        self.behavior_predictor = BehaviorPredictor(bev_channels, num_behaviors)
        self.trajectory_planner = TrajectoryTransformer(
            bev_channels=bev_channels,
            d_model=d_model,
            num_waypoints=num_waypoints,
        )
        self.safety_checker = SafetyChecker(
            bev_channels=bev_channels,
            max_speed_ms=max_speed_ms,
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W).
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y].
            nav_command: (B,) high-level navigation command, or None.

        Returns:
            Complete planning output, including the speed-clamped trajectory
            under "safe_waypoints" and the unclamped one under "raw_waypoints".
        """
        behavior_logits = self.behavior_predictor(bev_features)

        traj_output = self.trajectory_planner(bev_features, ego_state, nav_command)

        safety = self.safety_checker(
            bev_features, traj_output["waypoints"], ego_state
        )

        return {
            "behavior_logits": behavior_logits,
            "raw_waypoints": traj_output["waypoints"],
            "waypoint_confidence": traj_output["confidence"],
            "safe_waypoints": safety["safe_waypoints"],
            "collision_risk": safety["collision_risk"],
            "emergency_brake": safety["emergency_brake"],
            "speed_violation": safety["speed_violation"],
        }