# FSD-Level5-CoT / fsd_model/planning.py
# Last commit by Reality123b: ff1bca4 (verified)
# "Fix fsd_model/planning.py for training (autograd + in-place ops)"
"""
Planning Module for FSD Model.
Handles:
1. Route Planning (high-level waypoints from navigation)
2. Behavior Planning (lane changes, turns, stops, yields)
3. Trajectory Planning (smooth, collision-free path generation)
4. Safety Verification (learned collision-risk and emergency-brake scores, speed clamping)
Architecture: Transformer-based planner that attends to perception features
and produces waypoint trajectories. Inspired by UniAD and VAD planners.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math
class PositionalEncoding2D(nn.Module):
    """2D sinusoidal positional encoding for BEV features.

    Adds a fixed (non-learned) sinusoidal encoding of the normalized column
    (x) and row (y) position to a ``(B, C, H, W)`` feature map. The channel
    axis is split into four equal quarters of size ``q = C // 4``:

    * ``[0, q)``    -> sin(x / freq)
    * ``[q, 2q)``   -> cos(x / freq)
    * ``[2q, 3q)``  -> sin(y / freq)
    * ``[3q, 4q)``  -> cos(y / freq)

    where ``x = col / W`` and ``y = row / H`` lie in ``[0, 1)``. If ``C`` is not
    divisible by 4, the trailing ``C % 4`` channels receive no encoding.

    Args:
        channels: Channel count ``C`` of the feature maps this module is
            applied to.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the positional encoding.

        Args:
            x: Feature map of shape ``(B, C, H, W)`` with ``C == self.channels``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        B, _, H, W = x.shape
        device = x.device
        quarter = self.channels // 4

        # Normalized grid coordinates in [0, 1), each of shape (H, W).
        y_pos = torch.arange(H, device=device).float().unsqueeze(1).expand(H, W) / H
        x_pos = torch.arange(W, device=device).float().unsqueeze(0).expand(H, W) / W

        # Per-channel frequency divisors, shaped (quarter, 1, 1) to broadcast
        # against the (1, H, W) coordinate grids.
        dim = torch.arange(0, quarter, device=device).float()
        dim = (10000 ** (2 * dim / (self.channels // 2))).view(-1, 1, 1)
        x_arg = x_pos.unsqueeze(0) / dim  # (quarter, H, W)
        y_arg = y_pos.unsqueeze(0) / dim  # (quarter, H, W)

        # Built in float32 for accuracy; channels beyond 4 * quarter stay zero.
        pe = torch.zeros(self.channels, H, W, device=device)
        pe[0:quarter] = torch.sin(x_arg)
        pe[quarter:2 * quarter] = torch.cos(x_arg)
        pe[2 * quarter:3 * quarter] = torch.sin(y_arg)
        pe[3 * quarter:4 * quarter] = torch.cos(y_arg)

        # Cast to x's dtype: adding a float32 tensor to a float16/bfloat16
        # tensor would otherwise promote the activations to float32.
        pe = pe.to(dtype=x.dtype)
        return x + pe.unsqueeze(0).expand(B, -1, -1, -1)
class BehaviorPredictor(nn.Module):
    """
    Classify the high-level driving behavior from a BEV feature map.

    Intended command vocabulary (10 classes by default): keep_lane, turn_left,
    turn_right, lane_change_left, lane_change_right, stop, yield, park,
    reverse, emergency_stop. The index-to-command mapping is not enforced here.

    Args:
        in_channels: Channel count C of the incoming BEV features.
        num_behaviors: Number of output classes.
    """

    def __init__(self, in_channels: int = 256, num_behaviors: int = 10):
        super().__init__()
        self.num_behaviors = num_behaviors
        grid = 8  # BEV is pooled to a fixed grid x grid map before the MLP
        layers = [
            nn.AdaptiveAvgPool2d(grid),
            nn.Flatten(),
            # Two Linear -> ReLU -> Dropout stages, then the classifier layer.
            nn.Linear(in_channels * grid * grid, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_behaviors),
        ]
        # Layer order (and therefore state_dict keys) is fixed by this list.
        self.encoder = nn.Sequential(*layers)

    def forward(self, bev: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C, H, W)`` BEV features to ``(B, num_behaviors)`` logits."""
        logits = self.encoder(bev)
        return logits
class TrajectoryTransformer(nn.Module):
    """
    Transformer-based trajectory planner.

    A fixed set of learnable trajectory queries (one per waypoint, similar to
    DETR object queries) cross-attends to the flattened BEV feature map through
    a Transformer decoder. Ego state and, optionally, a navigation command are
    embedded and added to every query before decoding.

    Args:
        bev_channels: Channel count of the incoming BEV features.
        d_model: Transformer hidden size.
        nhead: Number of attention heads.
        num_decoder_layers: Number of stacked decoder layers.
        num_waypoints: Planning horizon, i.e. number of predicted waypoints.
        dim_feedforward: Hidden size of each decoder layer's feed-forward net.
        dropout: Dropout probability inside the decoder layers.
        num_commands: Size of the navigation-command vocabulary; valid
            ``nav_command`` values are ``0 .. num_commands - 1``. Defaults to
            10, the previously hard-coded value.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        nhead: int = 8,
        num_decoder_layers: int = 6,
        num_waypoints: int = 20,  # planning horizon waypoints
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        num_commands: int = 10,
    ):
        super().__init__()
        self.num_waypoints = num_waypoints
        self.d_model = d_model
        self.num_commands = num_commands
        # NOTE: sub-modules are created in the original order so random
        # initialization under a fixed seed and state_dict keys are unchanged.

        # BEV feature compression: 1x1 conv projects bev_channels -> d_model.
        self.bev_compress = nn.Sequential(
            nn.Conv2d(bev_channels, d_model, 1),
            nn.BatchNorm2d(d_model),
            nn.ReLU(),
        )
        self.pos_encoding = PositionalEncoding2D(d_model)

        # One learnable query vector per output waypoint.
        self.trajectory_queries = nn.Parameter(
            torch.randn(num_waypoints, d_model)
        )

        # Navigation command embedding (high-level route).
        self.command_embed = nn.Embedding(num_commands, d_model)

        # Ego state embedding: [speed, accel, steer, yaw_rate, x, y].
        self.ego_state_embed = nn.Sequential(
            nn.Linear(6, d_model),
            nn.ReLU(),
        )

        # Transformer decoder (inputs are batch-first: (B, seq, d_model)).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers
        )

        # Waypoint regression head: (x, y, heading, speed) per query.
        self.waypoint_head = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )

        # Per-waypoint confidence in [0, 1] (sigmoid output).
        self.confidence_head = nn.Sequential(
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W) from perception, C == bev_channels.
            ego_state: (B, 6) current ego state [speed, accel, steer, yaw_rate, x, y].
            nav_command: Optional (B,) integer navigation command in
                ``[0, num_commands)``; when None, no command conditioning is added.

        Returns:
            Dict with
              waypoints: (B, num_waypoints, 4) predicted trajectory
              confidence: (B, num_waypoints, 1) per-waypoint confidence
        """
        B = bev_features.shape[0]

        # Compress channels, add positional encoding, then flatten the spatial
        # grid into a memory sequence: (B, d_model, H, W) -> (B, H*W, d_model).
        bev = self.pos_encoding(self.bev_compress(bev_features))
        bev_seq = bev.flatten(2).permute(0, 2, 1)

        # Shared queries for every batch element, conditioned on ego state
        # (and the navigation command when given) by broadcast addition.
        queries = self.trajectory_queries.unsqueeze(0).expand(B, -1, -1)
        queries = queries + self.ego_state_embed(ego_state).unsqueeze(1)
        if nav_command is not None:
            queries = queries + self.command_embed(nav_command).unsqueeze(1)

        decoded = self.transformer_decoder(queries, bev_seq)

        return {
            "waypoints": self.waypoint_head(decoded),  # (B, T, 4)
            "confidence": self.confidence_head(decoded),  # (B, T, 1)
        }
class SafetyChecker(nn.Module):
    """
    Learned safety scores plus a hard speed clamp for planned trajectories.

    What this module computes:
      - ``collision_risk``: a learned scalar in [0, 1] from pooled BEV features.
      - ``emergency_brake``: probability of the "emergency" class from a learned
        2-way classifier ([no_emergency, emergency]) on pooled BEV features.
      - ``speed_violation``: fraction of waypoints whose planned speed exceeds
        ``max_speed_ms``.
      - ``safe_waypoints``: the input trajectory with the speed channel
        clamped to ``[min_speed_ms, max_speed_ms]``.

    No geometric collision, lane-boundary or following-distance check is
    performed. ``min_following_distance``, ``emergency_decel`` and the
    ``ego_state`` argument of ``forward`` are stored/accepted for interface
    compatibility but are currently unused.

    Args:
        bev_channels: Channel count of the incoming BEV features.
        max_speed_ms: Upper speed bound in m/s (default 8.94 m/s, about 20 mph).
        min_following_distance: Meters; stored only.
        emergency_decel: m/s^2; stored only.
        min_speed_ms: Lower speed bound in m/s for the clamp. Defaults to 0.0,
            the previously hard-coded bound (no reverse motion).
    """

    def __init__(
        self,
        bev_channels: int = 256,
        max_speed_ms: float = 8.94,  # 20 mph
        min_following_distance: float = 4.0,  # meters
        emergency_decel: float = 8.0,  # m/s^2
        min_speed_ms: float = 0.0,
    ):
        super().__init__()
        self.max_speed_ms = max_speed_ms
        self.min_speed_ms = min_speed_ms
        # Not yet used by forward(); kept so callers/configs remain valid.
        self.min_following_distance = min_following_distance
        self.emergency_decel = emergency_decel

        # Collision risk estimator: pooled 8x8 BEV -> MLP -> sigmoid.
        self.collision_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(8),
            nn.Flatten(),
            nn.Linear(bev_channels * 64, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Emergency brake detector: pooled 4x4 BEV -> 2 logits.
        self.emergency_detector = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
            nn.Linear(bev_channels * 16, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # [no_emergency, emergency]
        )

    def forward(
        self,
        bev: torch.Tensor,
        planned_waypoints: torch.Tensor,
        ego_state: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV features.
            planned_waypoints: (B, T, 4) planned trajectory; channel 3 is speed.
            ego_state: (B, 6) ego state; currently unused.

        Returns:
            Dict with
              collision_risk: (B, 1) learned risk in [0, 1]
              emergency_brake: (B, 1) probability of the emergency class
              speed_violation: (B, 1) fraction of waypoints above max_speed_ms
              safe_waypoints: (B, T, 4) waypoints with speed clamped
        """
        collision_risk = self.collision_net(bev)

        # Keep only P(emergency); slicing with 1: preserves a trailing dim.
        emergency_logits = self.emergency_detector(bev)
        emergency_prob = F.softmax(emergency_logits, dim=-1)[:, 1:]

        planned_speeds = planned_waypoints[:, :, 3]
        speed_violation = (planned_speeds > self.max_speed_ms).float().mean(
            dim=-1, keepdim=True
        )

        # Out-of-place clamp + cat (no in-place writes) so autograd can
        # backpropagate into the un-clamped planner output.
        clamped_speeds = torch.clamp(
            planned_speeds, self.min_speed_ms, self.max_speed_ms
        )
        clamped_waypoints = torch.cat(
            [planned_waypoints[:, :, :3], clamped_speeds.unsqueeze(-1)], dim=-1
        )

        return {
            "collision_risk": collision_risk,
            "emergency_brake": emergency_prob,
            "speed_violation": speed_violation,
            "safe_waypoints": clamped_waypoints,
        }
class PlanningModule(nn.Module):
    """
    End-to-end planning head.

    Runs, on the same BEV features, a behavior classifier, the transformer
    trajectory planner, and the safety checker (applied to the planner's
    waypoints), then merges their outputs into one dict.

    Args:
        bev_channels: Channel count of the incoming BEV features.
        d_model: Hidden size of the trajectory transformer.
        num_waypoints: Planning horizon (number of waypoints).
        max_speed_ms: Speed cap in m/s used by the safety checker.
        num_behaviors: Number of behavior classes.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        d_model: int = 256,
        num_waypoints: int = 20,
        max_speed_ms: float = 8.94,
        num_behaviors: int = 10,
    ):
        super().__init__()
        # Sub-modules are registered in a fixed order (behavior -> trajectory
        # -> safety) so parameter ordering and seeded initialization are stable.
        self.behavior_predictor = BehaviorPredictor(
            in_channels=bev_channels, num_behaviors=num_behaviors
        )
        self.trajectory_planner = TrajectoryTransformer(
            bev_channels=bev_channels, d_model=d_model, num_waypoints=num_waypoints
        )
        self.safety_checker = SafetyChecker(
            bev_channels=bev_channels, max_speed_ms=max_speed_ms
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        ego_state: torch.Tensor,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W)
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
            nav_command: (B,) high-level navigation command, or None

        Returns:
            Dict with behavior logits, raw and speed-clamped waypoints,
            per-waypoint confidence, and the safety checker's scores.
        """
        # Call order is significant in training mode: each sub-module's dropout
        # draws from the global RNG in this sequence.
        behavior_logits = self.behavior_predictor(bev_features)
        plan = self.trajectory_planner(bev_features, ego_state, nav_command)
        raw_waypoints = plan["waypoints"]
        checked = self.safety_checker(bev_features, raw_waypoints, ego_state)

        result = {
            "behavior_logits": behavior_logits,
            "raw_waypoints": raw_waypoints,
            "waypoint_confidence": plan["confidence"],
        }
        for key in ("safe_waypoints", "collision_risk", "emergency_brake", "speed_violation"):
            result[key] = checked[key]
        return result