"""
Planning Module for FSD Model.
Handles:
1. Route Planning (high-level waypoints from navigation)
2. Behavior Planning (lane changes, turns, stops, yields)
3. Trajectory Planning (smooth, collision-free path generation)
4. Safety Verification (collision checking, emergency braking)

Architecture: Transformer-based planner that attends to perception features
and produces waypoint trajectories. Inspired by UniAD and VAD planners.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, List, Optional, Tuple |
| import math |
|
|
|
|
class PositionalEncoding2D(nn.Module):
    """Add a fixed 2D sinusoidal positional encoding to a BEV feature map.

    The channel axis is split into four groups of ``channels // 4`` channels
    each, filled in order with sin(x), cos(x), sin(y) and cos(y). When
    ``channels`` is not a multiple of 4 the trailing channels receive no
    encoding (zeros are added to them).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the positional encoding added.

        Args:
            x: feature map of shape (B, C, H, W). C is expected to equal
                ``self.channels`` so the encoding lines up channel by channel.

        Returns:
            Tensor of the same shape as ``x``. Floating-point inputs keep
            their dtype (the encoding is computed in float32 and then cast),
            so half/bfloat16 features are not silently promoted to float32.
        """
        B, C, H, W = x.shape
        device = x.device
        quarter = self.channels // 4

        # Row / column indices normalised to [0, 1) and broadcast to (H, W).
        y_pos = torch.arange(H, device=device).float().unsqueeze(1).expand(H, W) / H
        x_pos = torch.arange(W, device=device).float().unsqueeze(0).expand(H, W) / W

        # One divisor per channel in a group, shaped (quarter, 1, 1).
        dim = torch.arange(0, quarter, device=device).float()
        dim = 10000 ** (2 * dim / (self.channels // 2))
        dim = dim[:, None, None]

        # Each angle tensor is shared by its sin and cos group.
        x_angle = x_pos.unsqueeze(0) / dim
        y_angle = y_pos.unsqueeze(0) / dim

        pe = torch.zeros(self.channels, H, W, device=device)
        pe[0:quarter] = torch.sin(x_angle)
        pe[quarter:2 * quarter] = torch.cos(x_angle)
        pe[2 * quarter:3 * quarter] = torch.sin(y_angle)
        pe[3 * quarter:4 * quarter] = torch.cos(y_angle)

        # Match the input's floating dtype; without this a float16/bfloat16
        # input would come back as float32 through type promotion.
        if x.is_floating_point():
            pe = pe.to(x.dtype)

        return x + pe.unsqueeze(0).expand(B, -1, -1, -1)
|
|
|
|
class BehaviorPredictor(nn.Module):
    """
    Classifier head that scores high-level driving behaviors from BEV features.
    Commands: keep_lane, turn_left, turn_right, lane_change_left,
    lane_change_right, stop, yield, park, reverse, emergency_stop
    """

    def __init__(self, in_channels: int = 256, num_behaviors: int = 10):
        super().__init__()
        self.num_behaviors = num_behaviors

        # The BEV map is pooled to a fixed 8x8 grid so the first linear layer
        # has a fixed input width regardless of the incoming H and W.
        pooled_side = 8
        flat_features = in_channels * pooled_side * pooled_side

        stages = [
            nn.AdaptiveAvgPool2d(pooled_side),
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_behaviors),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, bev: torch.Tensor) -> torch.Tensor:
        """Map BEV features to unnormalised scores; returns (B, num_behaviors) logits."""
        logits = self.encoder(bev)
        return logits
|
|
|
|
class TrajectoryTransformer(nn.Module):
    """
    Transformer-based trajectory planner.
    Generates waypoints by attending to BEV features and navigation commands.
    Uses learnable trajectory queries (similar to DETR object queries): one
    query per waypoint, conditioned on the ego state and, optionally, on a
    navigation command, then decoded against the flattened BEV map.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        nhead: int = 8,
        num_decoder_layers: int = 6,
        num_waypoints: int = 20,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        num_commands: int = 10,
    ):
        """
        Args:
            bev_channels: channel count of the incoming BEV feature map.
            d_model: transformer width; BEV features are projected to it.
            nhead: number of attention heads per decoder layer.
            num_decoder_layers: number of stacked decoder layers.
            num_waypoints: number of trajectory queries / output waypoints.
            dim_feedforward: hidden width of each decoder feed-forward block.
            dropout: dropout probability inside the decoder layers.
            num_commands: size of the navigation-command vocabulary, i.e. the
                number of rows in the command embedding table. Defaults to 10,
                the value that was previously hard-coded.
        """
        super().__init__()
        self.num_waypoints = num_waypoints
        self.d_model = d_model

        # 1x1 conv projects BEV channels to the transformer width.
        self.bev_compress = nn.Sequential(
            nn.Conv2d(bev_channels, d_model, 1),
            nn.BatchNorm2d(d_model),
            nn.ReLU(),
        )
        self.pos_encoding = PositionalEncoding2D(d_model)

        # One learnable query vector per output waypoint.
        self.trajectory_queries = nn.Parameter(
            torch.randn(num_waypoints, d_model)
        )

        # Lookup table for the discrete navigation command.
        self.command_embed = nn.Embedding(num_commands, d_model)

        # Embeds the 6-value ego state into the query space.
        self.ego_state_embed = nn.Sequential(
            nn.Linear(6, d_model),
            nn.ReLU(),
        )

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )

        # Regresses 4 values per waypoint.
        self.waypoint_head = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )

        # Sigmoid keeps the per-waypoint confidence in (0, 1).
        self.confidence_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W) from perception
            ego_state: (B, 6) current ego state [speed, accel, steer, yaw_rate, x, y]
            nav_command: (B,) integer navigation command used as an index into
                the command embedding; when None the queries are conditioned
                on the ego state only
        Returns:
            waypoints: (B, num_waypoints, 4) predicted trajectory
            confidence: (B, num_waypoints, 1) per-waypoint confidence
        """
        B = bev_features.shape[0]

        # Project and position-encode the BEV map.
        bev = self.bev_compress(bev_features)
        bev = self.pos_encoding(bev)

        # (B, d_model, H, W) -> (B, H*W, d_model): one memory token per cell.
        bev_seq = bev.flatten(2).permute(0, 2, 1)

        # Share the learned queries across the batch.
        queries = self.trajectory_queries.unsqueeze(0).expand(B, -1, -1)

        # The same ego embedding is broadcast onto every waypoint query.
        ego_feat = self.ego_state_embed(ego_state).unsqueeze(1)
        queries = queries + ego_feat

        if nav_command is not None:
            cmd_feat = self.command_embed(nav_command).unsqueeze(1)
            queries = queries + cmd_feat

        # Queries cross-attend to the BEV tokens (used as decoder memory).
        decoded = self.transformer_decoder(queries, bev_seq)

        waypoints = self.waypoint_head(decoded)
        confidence = self.confidence_head(decoded)

        return {
            "waypoints": waypoints,
            "confidence": confidence,
        }
|
|
|
|
class SafetyChecker(nn.Module):
    """
    Verifies planned trajectories against safety constraints.

    Implemented in ``forward``:
    - Collision risk: a learned score in (0, 1) computed from the BEV map
    - Emergency stop: a learned two-class detector on the BEV map
    - Speed limit: fraction of waypoints over ``max_speed_ms`` (or NaN), and
      clamping of the speed channel into ``[0, max_speed_ms]``

    ``min_following_distance`` and ``emergency_decel`` are stored on the
    module but are not read by ``forward``; lane-boundary and
    following-distance checks are not implemented here.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        max_speed_ms: float = 8.94,
        min_following_distance: float = 4.0,
        emergency_decel: float = 8.0,
    ):
        """
        Args:
            bev_channels: channel count of the BEV feature map.
            max_speed_ms: upper bound applied to the waypoint speed channel.
            min_following_distance: stored only; unused by ``forward``.
            emergency_decel: stored only; unused by ``forward``.
        """
        super().__init__()
        self.max_speed_ms = max_speed_ms
        self.min_following_distance = min_following_distance
        self.emergency_decel = emergency_decel

        # BEV -> single collision-risk score squashed by a sigmoid.
        self.collision_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(8),
            nn.Flatten(),
            nn.Linear(bev_channels * 64, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # BEV -> two logits; index 1 is read as the "emergency" class.
        self.emergency_detector = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
            nn.Linear(bev_channels * 16, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def forward(
        self,
        bev: torch.Tensor,
        planned_waypoints: torch.Tensor,
        ego_state: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV features with occupancy info
            planned_waypoints: (B, T, 4) planned trajectory; channel 3 is
                treated as speed
            ego_state: (B, 6); accepted for interface compatibility but not
                used by the current checks
        Returns:
            Dict with:
                collision_risk: (B, 1) score in (0, 1)
                emergency_brake: (B, 1) softmax probability of class 1
                speed_violation: (B, 1) fraction of waypoints whose speed is
                    above ``max_speed_ms`` or NaN
                safe_waypoints: (B, T, 4) input waypoints with the speed
                    channel clamped to ``[0, max_speed_ms]`` (NaN -> 0)
        """
        collision_risk = self.collision_net(bev)

        emergency_logits = self.emergency_detector(bev)
        # Slice with 1: (not 1) to keep the trailing dimension -> (B, 1).
        emergency_prob = F.softmax(emergency_logits, dim=-1)[:, 1:]

        planned_speeds = planned_waypoints[:, :, 3]

        # NaN compares False against the limit and passes through clamp
        # unchanged, so handle it explicitly: count it as a violation and
        # command zero speed rather than emitting NaN as "safe".
        nan_speed = torch.isnan(planned_speeds)
        over_limit = planned_speeds > self.max_speed_ms
        speed_violation = (over_limit | nan_speed).float().mean(dim=-1, keepdim=True)

        finite_speeds = torch.where(
            nan_speed, torch.zeros_like(planned_speeds), planned_speeds
        )
        clamped_speeds = torch.clamp(finite_speeds, 0.0, self.max_speed_ms)
        clamped_waypoints = torch.cat([
            planned_waypoints[:, :, :3],
            clamped_speeds.unsqueeze(-1),
        ], dim=-1)

        return {
            "collision_risk": collision_risk,
            "emergency_brake": emergency_prob,
            "speed_violation": speed_violation,
            "safe_waypoints": clamped_waypoints,
        }
|
|
|
|
class PlanningModule(nn.Module):
    """
    End-to-end planner that chains the three sub-modules.
    Pipeline: BEV → Behavior Prediction → Trajectory Generation → Safety Check
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        num_waypoints: int = 20,
        max_speed_ms: float = 8.94,
        num_behaviors: int = 10,
    ):
        super().__init__()

        planner_kwargs = dict(
            bev_channels=bev_channels,
            d_model=d_model,
            num_waypoints=num_waypoints,
        )
        checker_kwargs = dict(
            bev_channels=bev_channels,
            max_speed_ms=max_speed_ms,
        )

        # Sub-modules are registered in pipeline order.
        self.behavior_predictor = BehaviorPredictor(bev_channels, num_behaviors)
        self.trajectory_planner = TrajectoryTransformer(**planner_kwargs)
        self.safety_checker = SafetyChecker(**checker_kwargs)

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run behavior prediction, trajectory generation and the safety check.

        Args:
            bev_features: (B, C, H, W)
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
            nav_command: (B,) high-level navigation command
        Returns:
            Dict with the behavior logits, the raw waypoints and their
            confidence, and the safety checker's outputs (safe waypoints,
            collision risk, emergency brake, speed violation).
        """
        # Stage 1: behavior logits straight from the BEV map.
        logits = self.behavior_predictor(bev_features)

        # Stage 2: raw trajectory and per-waypoint confidence.
        plan = self.trajectory_planner(bev_features, ego_state, nav_command)
        raw_waypoints = plan["waypoints"]

        # Stage 3: the safety checker is run on the raw waypoints.
        verdict = self.safety_checker(bev_features, raw_waypoints, ego_state)

        result = {
            "behavior_logits": logits,
            "raw_waypoints": raw_waypoints,
            "waypoint_confidence": plan["confidence"],
        }
        for key in (
            "safe_waypoints",
            "collision_risk",
            "emergency_brake",
            "speed_violation",
        ):
            result[key] = verdict[key]
        return result
|
|