| """ |
| Control Module for FSD Model. |
| Converts planned trajectory waypoints into actuator commands: |
| - Steering angle |
| - Throttle (acceleration) |
| - Brake |
| - Gear (forward/reverse/park) (not yet emitted: ControlModule returns no gear command) |
| |
| Uses a combination of: |
| 1. PID controllers for smooth tracking |
| 2. Neural network for adaptive control |
| 3. Stanley controller for lateral control |
| 4. Bicycle model for vehicle dynamics |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, Optional, Tuple |
| import math |
|
|
|
|
| class BicycleModel(nn.Module): |
| """ |
| Kinematic bicycle model for vehicle dynamics simulation. |
| Used for both prediction and control. |
| State: [x, y, heading, speed] |
| Control: [steering_angle, acceleration] |
| """ |
| def __init__(self, wheelbase: float = 2.7, dt: float = 0.1): |
| super().__init__() |
| self.wheelbase = wheelbase |
| self.dt = dt |
| |
| def forward( |
| self, state: torch.Tensor, control: torch.Tensor |
| ) -> torch.Tensor: |
| """ |
| Args: |
| state: (B, 4) [x, y, heading, speed] |
| control: (B, 2) [steering_angle, acceleration] |
| Returns: |
| next_state: (B, 4) |
| """ |
| x, y, heading, speed = state[:, 0], state[:, 1], state[:, 2], state[:, 3] |
| steer, accel = control[:, 0], control[:, 1] |
| |
| |
| beta = torch.atan(0.5 * torch.tan(steer)) |
| |
| x_new = x + speed * torch.cos(heading + beta) * self.dt |
| y_new = y + speed * torch.sin(heading + beta) * self.dt |
| heading_new = heading + (speed / self.wheelbase) * torch.sin(beta) * self.dt |
| speed_new = speed + accel * self.dt |
| |
| |
| speed_new = torch.clamp(speed_new, min=0.0) |
| |
| return torch.stack([x_new, y_new, heading_new, speed_new], dim=-1) |
|
|
|
|
class PIDController(nn.Module):
    """
    Learnable PID controller with neural network gain scheduling.
    Gains (Kp, Ki, Kd) are predicted based on current state.

    Two error channels are controlled independently, each with its own
    Kp/Ki/Kd. The integral and previous-error memories live in buffers that
    are resized to the current batch size whenever it changes.
    """

    def __init__(self, state_dim: int = 6, hidden_dim: int = 64):
        super().__init__()

        # Predicts 6 non-negative gains (Softplus):
        # [kp_0, kp_1, ki_0, ki_1, kd_0, kd_1], one per error channel.
        self.gain_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 6),
            nn.Softplus(),
        )

        # Controller memory; shape becomes (B, 2) after the first forward pass.
        self.register_buffer('integral_error', torch.zeros(1, 2))
        self.register_buffer('prev_error', torch.zeros(1, 2))

    def forward(
        self,
        error: torch.Tensor,
        ego_state: torch.Tensor,
        dt: float = 0.1,
    ) -> torch.Tensor:
        """
        Args:
            error: (B, 2) [lateral_error, longitudinal_error]
            ego_state: (B, state_dim) current vehicle state, fed to the gain net
            dt: time step
        Returns:
            control: (B, 2) [steering_correction, accel_correction]
        """
        batch_size = error.shape[0]

        gains = self.gain_net(ego_state)
        kp = gains[:, :2]
        ki = gains[:, 2:4]
        kd = gains[:, 4:6]

        proportional = kp * error

        # A batch-size change invalidates the stored memory: start from zero,
        # matching the incoming error's dtype and device.
        if self.integral_error.shape[0] != batch_size:
            self.integral_error = error.new_zeros(batch_size, 2)
        if self.prev_error.shape[0] != batch_size:
            self.prev_error = error.new_zeros(batch_size, 2)

        # Anti-windup: the accumulated integral is clamped to [-10, 10].
        integral_error = self.integral_error.detach() + error * dt
        integral_error = torch.clamp(integral_error, -10.0, 10.0)
        self.integral_error = integral_error.detach()
        integral = ki * integral_error

        # NOTE(review): prev_error is zero after reset() or a batch resize, so
        # the first step's derivative term is kd * error / dt (derivative kick).
        derivative = kd * (error - self.prev_error.detach()) / dt
        self.prev_error = error.detach()

        return proportional + integral + derivative

    def reset(self):
        """Reset integral and derivative buffers (in place, keeping their shape)."""
        self.integral_error.zero_()
        self.prev_error.zero_()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Resize memory buffers to the checkpoint's shape before loading.

        forward() resizes the buffers to (B, 2), so a checkpoint saved after a
        pass with B != 1 would otherwise fail to load into a freshly built
        module with a size-mismatch error.
        """
        for name in ('integral_error', 'prev_error'):
            incoming = state_dict.get(prefix + name)
            current = getattr(self, name)
            if isinstance(incoming, torch.Tensor) and incoming.shape != current.shape:
                setattr(self, name, current.new_zeros(incoming.shape))
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
|
|
|
|
class StanleyController(nn.Module):
    """
    Stanley lateral controller enhanced with learned parameters.
    Computes steering angle based on:
    1. Heading error (wrapped to [-pi, pi])
    2. Cross-track error

    steering = wrap(heading_error) + atan2(k_gain * cte, speed + k_soft),
    clamped to +/- ``max_steer_deg``.
    """

    def __init__(self, k_gain: float = 0.5, k_soft: float = 1.0, max_steer_deg: float = 35.0):
        """
        Args:
            k_gain: initial cross-track gain (learnable).
            k_soft: initial softening term added to speed (learnable); keeps the
                cross-track term bounded at low speed.
            max_steer_deg: absolute steering limit in degrees (default 35).
        """
        super().__init__()
        # NOTE(review): both gains are unconstrained parameters; training could
        # drive them negative and flip the cross-track correction.
        self.k_gain = nn.Parameter(torch.tensor(k_gain))
        self.k_soft = nn.Parameter(torch.tensor(k_soft))
        self.max_steer = math.radians(max_steer_deg)

    def forward(
        self,
        heading_error: torch.Tensor,
        cross_track_error: torch.Tensor,
        speed: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            heading_error: (B,) heading difference to path (any range; wrapped here)
            cross_track_error: (B,) signed lateral distance to path
            speed: (B,) current speed
        Returns:
            steering: (B,) desired steering angle (radians)
        """
        # Wrap so a difference like 2*pi - 0.1 steers -0.1 instead of saturating.
        wrapped_heading_error = torch.atan2(torch.sin(heading_error), torch.cos(heading_error))

        cross_track_steer = torch.atan2(
            self.k_gain * cross_track_error,
            speed + self.k_soft,
        )
        steering = wrapped_heading_error + cross_track_steer

        return torch.clamp(steering, -self.max_steer, self.max_steer)
|
|
|
|
class NeuralController(nn.Module):
    """
    End-to-end neural network controller.

    Encodes three inputs separately, concatenates the encodings and maps
    them through an MLP head to three raw control logits. Serves as a
    refinement on top of the classical controllers.

    Inputs encoded:
      * BEV features  -> pooled to 4x4, flattened, projected to hidden_dim
      * waypoints     -> flattened (num_waypoints * waypoint_dim), projected to hidden_dim
      * ego state     -> projected to hidden_dim // 2
    """

    def __init__(
        self,
        bev_channels: int = 256,
        waypoint_dim: int = 4,
        num_waypoints: int = 20,
        ego_dim: int = 6,
        hidden_dim: int = 256,
    ):
        super().__init__()
        pooled_side = 4
        half_hidden = hidden_dim // 2

        # Spatial summary of the BEV map: fixed 4x4 grid regardless of H, W.
        self.bev_encoder = nn.Sequential(
            nn.AdaptiveAvgPool2d(pooled_side),
            nn.Flatten(),
            nn.Linear(bev_channels * pooled_side * pooled_side, hidden_dim),
            nn.ReLU(),
        )

        # Waypoints are flattened, so the sequence length must equal num_waypoints.
        self.waypoint_encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_waypoints * waypoint_dim, hidden_dim),
            nn.ReLU(),
        )

        self.ego_encoder = nn.Sequential(
            nn.Linear(ego_dim, half_hidden),
            nn.ReLU(),
        )

        # Head over the concatenated encodings -> 3 logits (steer, throttle, brake).
        self.control_head = nn.Sequential(
            nn.Linear(2 * hidden_dim + half_hidden, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        waypoints: torch.Tensor,
        ego_state: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, bev_channels, H, W) BEV feature map
            waypoints: (B, num_waypoints, waypoint_dim) planned waypoints
            ego_state: (B, ego_dim) vehicle state
        Returns:
            Dict with steering (tanh, -1 to 1), throttle (sigmoid, 0 to 1)
            and brake (sigmoid, 0 to 1), each of shape (B,).
        """
        fused = torch.cat(
            (
                self.bev_encoder(bev_features),
                self.waypoint_encoder(waypoints),
                self.ego_encoder(ego_state),
            ),
            dim=-1,
        )
        steer_logit, throttle_logit, brake_logit = self.control_head(fused).unbind(dim=-1)

        return {
            "steering": steer_logit.tanh(),
            "throttle": throttle_logit.sigmoid(),
            "brake": brake_logit.sigmoid(),
        }
|
|
|
|
class ControlModule(nn.Module):
    """
    Complete control module that combines:
    1. Neural controller (BEV-aware, end-to-end)
    2. Stanley controller (geometric lateral control)
    3. PID controller (error-based correction)
    4. Bicycle model (physics-based prediction)
    5. Safety limits enforcement (steering/pedal clamps, emergency brake,
       speed cap at ``max_speed_ms``)

    Steering is a convex blend (softmax weights predicted from the ego state)
    of the neural, Stanley and PID steering commands.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        num_waypoints: int = 20,
        wheelbase: float = 2.7,
        max_speed_ms: float = 8.94,
        max_steering_deg: float = 35.0,
        max_accel: float = 3.0,
        max_decel: float = 8.0,
        dt: float = 0.1,
    ):
        """
        Args:
            bev_channels: channel count of the BEV feature map.
            num_waypoints: number of planned waypoints given to forward().
            wheelbase: wheelbase used by the bicycle model.
            max_speed_ms: speed limit (presumably m/s, per the name); target
                speeds are capped at it and throttle is cut at or above it.
            max_steering_deg: absolute steering limit in degrees.
            max_accel: acceleration produced by full throttle.
            max_decel: deceleration produced by full brake.
            dt: control time step used by the PID and bicycle model.
        """
        super().__init__()
        self.max_speed_ms = max_speed_ms
        self.max_steering = math.radians(max_steering_deg)
        self.max_accel = max_accel
        self.max_decel = max_decel
        self.dt = dt

        self.neural_controller = NeuralController(
            bev_channels=bev_channels,
            num_waypoints=num_waypoints,
        )
        self.stanley_controller = StanleyController()
        self.pid_controller = PIDController()
        self.bicycle_model = BicycleModel(wheelbase, dt)

        # Softmax weights over [neural, Stanley, PID], predicted from ego_state.
        self.fusion_weights = nn.Sequential(
            nn.Linear(6, 32),
            nn.ReLU(),
            nn.Linear(32, 3),
            nn.Softmax(dim=-1),
        )

    def forward(
        self,
        bev_features: torch.Tensor,
        planned_waypoints: torch.Tensor,
        ego_state: torch.Tensor,
        emergency_brake: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev_features: (B, C, H, W) BEV features
            planned_waypoints: (B, T, 4) [x, y, heading, speed]
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
            emergency_brake: (B, 1) emergency brake probability; values > 0.5
                force full brake and zero throttle.
        Returns:
            Dict with final actuator commands: steering_rad, steering_deg,
            throttle, brake, acceleration_cmd, controller_weights (B, 3) and
            predicted_next_state (B, 4) [x, y, heading, speed].
        """
        # 1. End-to-end neural proposal.
        neural_out = self.neural_controller(bev_features, planned_waypoints, ego_state)

        # 2. Tracking errors with respect to the first planned waypoint.
        next_wp = planned_waypoints[:, 0, :]
        path_heading = next_wp[:, 2]
        # NOTE(review): the documented ego_state layout puts yaw_rate at index 3
        # and has no heading field, yet index 3 is used as the ego heading here
        # and in the bicycle-model state below. Confirm the layout with callers.
        ego_heading = ego_state[:, 3]

        # Wrap to [-pi, pi] so a 2*pi - 0.1 difference means -0.1, not full lock.
        raw_heading_error = path_heading - ego_heading
        heading_error = torch.atan2(torch.sin(raw_heading_error), torch.cos(raw_heading_error))

        # Signed cross-track error: vehicle-to-waypoint offset projected onto the
        # path's left normal (-sin, cos). Positive when the path lies to the
        # vehicle's left, so a positive correction steers back toward it. An
        # unsigned distance could only ever steer one way and would also count
        # the along-track gap to the waypoint as lateral error.
        dx = next_wp[:, 0] - ego_state[:, 4]
        dy = next_wp[:, 1] - ego_state[:, 5]
        cross_track_error = dy * torch.cos(path_heading) - dx * torch.sin(path_heading)

        stanley_steer = self.stanley_controller(
            heading_error, cross_track_error, ego_state[:, 0]
        )

        # 3. PID on [lateral, speed] errors; the target speed is capped at the limit.
        target_speed = torch.clamp(next_wp[:, 3], max=self.max_speed_ms)
        speed_err = target_speed - ego_state[:, 0]
        pid_error = torch.stack([cross_track_error, speed_err], dim=-1)
        pid_out = self.pid_controller(pid_error, ego_state, self.dt)

        # 4. Fuse the three steering proposals (columns: neural, Stanley, PID).
        weights = self.fusion_weights(ego_state)
        neural_steer = neural_out["steering"] * self.max_steering
        pid_steer = torch.clamp(pid_out[:, 0], -self.max_steering, self.max_steering)
        final_steering = (
            weights[:, 0] * neural_steer
            + weights[:, 1] * stanley_steer
            + weights[:, 2] * pid_steer
        )

        # 5. Longitudinal: neural pedals plus the PID accel correction, split into
        #    throttle (positive part) and brake (negative part), scaled by the
        #    PID fusion weight.
        pid_accel = pid_out[:, 1]
        final_throttle = neural_out["throttle"] + torch.clamp(pid_accel, 0, 1) * weights[:, 2]
        final_brake = neural_out["brake"] + torch.clamp(-pid_accel, 0, 1) * weights[:, 2]

        # 6. Safety overrides.
        if emergency_brake is not None:
            emergency_mask = (emergency_brake.squeeze(-1) > 0.5).float()
            final_throttle = final_throttle * (1 - emergency_mask)
            final_brake = torch.max(final_brake, emergency_mask)

        final_steering = torch.clamp(final_steering, -self.max_steering, self.max_steering)
        final_throttle = torch.clamp(final_throttle, 0.0, 1.0)
        final_brake = torch.clamp(final_brake, 0.0, 1.0)

        # Enforce the configured speed limit: no throttle at or above it.
        at_speed_limit = (ego_state[:, 0] >= self.max_speed_ms).float()
        final_throttle = final_throttle * (1 - at_speed_limit)

        # When brake exceeds throttle, throttle is zeroed; otherwise both pedals
        # are kept and combined in accel_cmd below.
        brake_dominant = (final_brake > final_throttle).float()
        final_throttle = final_throttle * (1 - brake_dominant)

        accel_cmd = final_throttle * self.max_accel - final_brake * self.max_decel
        steer_deg = torch.rad2deg(final_steering)

        # 7. One-step kinematic prediction of [x, y, heading, speed].
        current_state = torch.stack(
            [ego_state[:, 4], ego_state[:, 5], ego_heading, ego_state[:, 0]],
            dim=-1,
        )
        control_input = torch.stack([final_steering, accel_cmd], dim=-1)
        predicted_next_state = self.bicycle_model(current_state, control_input)

        return {
            "steering_rad": final_steering,
            "steering_deg": steer_deg,
            "throttle": final_throttle,
            "brake": final_brake,
            "acceleration_cmd": accel_cmd,
            "controller_weights": weights,
            "predicted_next_state": predicted_next_state,
        }

    def reset(self):
        """Reset controller states (call at start of new episode)."""
        self.pid_controller.reset()
|
|