| """ |
| Full Self-Driving Model - Level 5 Autonomous Driving |
| Complete end-to-end architecture: Sensor Fusion → Perception → Planning → Control |
| |
| Architecture Summary: |
| ───────────────────── |
| Sensors (configurable): |
| ├── 6 Cameras → CNN Backbone → FPN → View Transform → Camera BEV |
| └── 20 Ultrasonics → Distance Encoder → Position Encoder → US BEV |
| ↓ |
| Multi-Modal Fusion (Channel Attention) → Unified BEV |
| ↓ |
| Perception: |
| ├── Object Detection (CenterPoint-style heatmap) |
| ├── BEV Segmentation (road, lanes, crosswalks) |
| ├── Occupancy Grid (current + future) |
| └── Motion Forecasting (multi-modal trajectories) |
| ↓ |
| Planning: |
| ├── Behavior Prediction (10 driving behaviors) |
| ├── Trajectory Transformer (20 waypoints, 8-head attention) |
| └── Safety Verification (collision + emergency brake) |
| ↓ |
| Control: |
| ├── Neural Controller (end-to-end) |
| ├── Stanley Controller (lateral) |
| ├── PID Controller (adaptive gains) |
| └── Bicycle Model (dynamics prediction) |
| ↓ |
| Output: steering, throttle, brake, predicted trajectory |
| """ |

import json
import os
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import VehicleConfig, SensorConfig
from .sensor_fusion import MultiModalSensorFusion
from .perception import PerceptionModule
from .planning import PlanningModule
from .control import ControlModule
from .cot_reasoning import ChainOfThoughtReasoning


class FullSelfDrivingModel(nn.Module):
    """
    End-to-end Level 5 Full Self-Driving model.

    Takes raw sensor data (cameras + ultrasonics) and outputs actuator
    commands (steering, throttle, brake).

    Fully modular: the sensor configuration, perception heads, planning
    strategy, and control method are all configurable.
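
    Example (a minimal sketch; the 6-camera/20-ultrasonic counts and the
    224x224 image resolution are illustrative assumptions taken from the
    module docstring, not validated values)::

        model = FullSelfDrivingModel()
        out = model(
            camera_images=torch.rand(1, 6, 3, 224, 224),
            camera_intrinsics=torch.rand(1, 6, 3, 3),
            camera_extrinsics=torch.rand(1, 6, 4, 4),
            ultrasonic_distances=torch.rand(1, 20, 1),
            ultrasonic_placements=torch.rand(1, 20, 6),
        )
        steering = out["control/steering_deg"]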
| """ |
| |
    def __init__(
        self,
        vehicle_config: Optional[VehicleConfig] = None,
        bev_size: int = 200,
        bev_resolution: float = 0.25,
        bev_feature_dim: int = 256,
        num_object_classes: int = 10,
        num_seg_classes: int = 7,
        num_waypoints: int = 20,
        planning_d_model: int = 256,
        future_steps: int = 6,
        num_forecast_modes: int = 6,
        forecast_steps: int = 12,
        num_behaviors: int = 10,
        enable_cot: bool = True,
        cot_num_actor_queries: int = 64,
        cot_num_road_queries: int = 32,
    ):
        super().__init__()

        # Vehicle / sensor configuration (defaults if none provided)
        if vehicle_config is None:
            vehicle_config = VehicleConfig()
        self.vehicle_config = vehicle_config
        self.sensor_config = vehicle_config.sensor_config
        self.enable_cot = enable_cot

        # Hyperparameters, kept for save_pretrained/from_pretrained round-trips
        self.hparams = {
            "bev_size": bev_size,
            "bev_resolution": bev_resolution,
            "bev_feature_dim": bev_feature_dim,
            "num_object_classes": num_object_classes,
            "num_seg_classes": num_seg_classes,
            "num_waypoints": num_waypoints,
            "planning_d_model": planning_d_model,
            "future_steps": future_steps,
            "num_forecast_modes": num_forecast_modes,
            "forecast_steps": forecast_steps,
            "num_behaviors": num_behaviors,
            "max_speed_mph": vehicle_config.max_speed_mph,
            "max_speed_ms": vehicle_config.max_speed_ms,
            "num_cameras": self.sensor_config.num_cameras,
            "num_ultrasonics": self.sensor_config.num_ultrasonics,
            "enable_cot": enable_cot,
            "cot_num_actor_queries": cot_num_actor_queries,
            "cot_num_road_queries": cot_num_road_queries,
        }

        # 1. Sensor fusion: cameras + ultrasonics → unified BEV features
        self.sensor_fusion = MultiModalSensorFusion(
            sensor_config=self.sensor_config,
            bev_size=bev_size,
            bev_resolution=bev_resolution,
            bev_feature_dim=bev_feature_dim,
        )

        # 2. Perception: detection, segmentation, occupancy, motion forecasting
        self.perception = PerceptionModule(
            bev_channels=bev_feature_dim,
            num_object_classes=num_object_classes,
            num_seg_classes=num_seg_classes,
            future_steps=future_steps,
            num_forecast_modes=num_forecast_modes,
            forecast_steps=forecast_steps,
        )

        # 3. Planning: behavior prediction, trajectory generation, safety checks
        self.planning = PlanningModule(
            bev_channels=bev_feature_dim,
            d_model=planning_d_model,
            num_waypoints=num_waypoints,
            max_speed_ms=vehicle_config.max_speed_ms,
            num_behaviors=num_behaviors,
        )

        # 4. Control: neural + classical controllers over the planned trajectory
        self.control = ControlModule(
            bev_channels=bev_feature_dim,
            num_waypoints=num_waypoints,
            wheelbase=vehicle_config.wheelbase,
            max_speed_ms=vehicle_config.max_speed_ms,
            max_steering_deg=vehicle_config.max_steering_angle,
            max_accel=vehicle_config.max_acceleration,
            max_decel=vehicle_config.max_deceleration,
        )

        # Optional chain-of-thought reasoning that refines the planner output
        if enable_cot:
            self.cot_reasoning = ChainOfThoughtReasoning(
                bev_channels=bev_feature_dim,
                d_model=planning_d_model,
                num_actor_queries=cot_num_actor_queries,
                num_road_queries=cot_num_road_queries,
                num_waypoints=num_waypoints,
                num_behaviors=num_behaviors,
                max_speed_ms=vehicle_config.max_speed_ms,
            )
        else:
            self.cot_reasoning = None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming init for convs, Xavier for linears, constants for norms."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        camera_images: Optional[torch.Tensor] = None,
        camera_intrinsics: Optional[torch.Tensor] = None,
        camera_extrinsics: Optional[torch.Tensor] = None,
        ultrasonic_distances: Optional[torch.Tensor] = None,
        ultrasonic_placements: Optional[torch.Tensor] = None,
        ego_state: Optional[torch.Tensor] = None,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass: sensors → BEV → perception → planning → control.

        Args:
            camera_images: (B, N_cams, 3, H, W) raw camera images
            camera_intrinsics: (B, N_cams, 3, 3) camera calibration matrices
            camera_extrinsics: (B, N_cams, 4, 4) camera-to-ego transforms
            ultrasonic_distances: (B, N_us, 1) distance readings
            ultrasonic_placements: (B, N_us, 6) sensor mounting positions
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
            nav_command: (B,) integer navigation command

        Returns:
            Dict of all intermediate and final outputs, with keys namespaced
            as "perception/*", "planning/*", "control/*", and (when enabled)
            "cot/*", plus the raw "bev_features".
        """
        if camera_images is None and ultrasonic_distances is None:
            raise ValueError(
                "At least one of camera_images or ultrasonic_distances "
                "must be provided"
            )
        B = (camera_images.shape[0] if camera_images is not None
             else ultrasonic_distances.shape[0])
        device = (camera_images.device if camera_images is not None
                  else ultrasonic_distances.device)

        # Default to a stationary ego state when none is provided
        if ego_state is None:
            ego_state = torch.zeros(B, 6, device=device)

        # 1. Fuse all sensors into a unified bird's-eye-view representation
        bev_features = self.sensor_fusion(
            camera_images=camera_images,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
            ultrasonic_distances=ultrasonic_distances,
            ultrasonic_placements=ultrasonic_placements,
        )

        # 2. Perception heads on the BEV features
        perception_out = self.perception(bev_features)

        # 3. Planning: behavior, trajectory, safety verification
        planning_out = self.planning(
            bev_features=bev_features,
            ego_state=ego_state,
            nav_command=nav_command,
        )

        # 4. Optional chain-of-thought refinement. The enriched BEV is only
        # forwarded to control when the gated waypoints are also available,
        # so control always sees a consistent (features, waypoints) pair.
        cot_out = {}
        final_waypoints = planning_out["safe_waypoints"]
        bev_for_control = bev_features
        if self.cot_reasoning is not None:
            cot_out = self.cot_reasoning(
                bev_features=bev_features,
                ego_state=ego_state,
                planner_waypoints=planning_out["safe_waypoints"],
            )
            if "cot/gated_waypoints" in cot_out:
                final_waypoints = cot_out["cot/gated_waypoints"]
                bev_for_control = cot_out.get("enriched_bev", bev_features)

        # 5. Control: trajectory + BEV context → actuator commands
        control_out = self.control(
            bev_features=bev_for_control,
            planned_waypoints=final_waypoints,
            ego_state=ego_state,
            emergency_brake=planning_out["emergency_brake"],
        )

        # Flatten everything into one namespaced output dict
        output = {"bev_features": bev_features}
        output.update({f"perception/{k}": v for k, v in perception_out.items()})
        output.update({f"planning/{k}": v for k, v in planning_out.items()})
        output.update({f"control/{k}": v for k, v in control_out.items()})
        if cot_out:
            output.update({k: v for k, v in cot_out.items() if k != "enriched_bev"})

        return output

    def get_control_output(
        self, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convenience wrapper that returns only (steering, throttle, brake).
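
        Example sketch (``sensor_batch`` is a hypothetical dict holding the
        keyword arguments that ``forward`` accepts)::

            model.reset()
            steering, throttle, brake = model.get_control_output(**sensor_batch)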
| """ |
| out = self.forward(**kwargs) |
| return ( |
| out["control/steering_deg"], |
| out["control/throttle"], |
| out["control/brake"], |
| ) |

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters per module."""
        counts = {
            "sensor_fusion": sum(p.numel() for p in self.sensor_fusion.parameters()),
            "perception": sum(p.numel() for p in self.perception.parameters()),
            "planning": sum(p.numel() for p in self.planning.parameters()),
            "control": sum(p.numel() for p in self.control.parameters()),
        }
        if self.cot_reasoning is not None:
            counts["cot_reasoning"] = sum(
                p.numel() for p in self.cot_reasoning.parameters()
            )
        counts["total"] = sum(counts.values())
        counts["total_trainable"] = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return counts

    def save_pretrained(self, save_dir: str):
        """Save weights and configs for later loading via `from_pretrained`."""
        os.makedirs(save_dir, exist_ok=True)

        # Weights
        torch.save(self.state_dict(), os.path.join(save_dir, "model.pt"))

        # Hyperparameters
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(self.hparams, f, indent=2)

        # Sensor configuration
        self.sensor_config.save(os.path.join(save_dir, "sensor_config.json"))

        # Human-readable parameter summary
        counts = self.count_parameters()
        with open(os.path.join(save_dir, "model_summary.json"), "w") as f:
            json.dump(counts, f, indent=2)

    @classmethod
    def from_pretrained(cls, load_dir: str, device: str = "cpu"):
        """Load a model from a directory written by `save_pretrained`."""
        with open(os.path.join(load_dir, "config.json"), "r") as f:
            hparams = json.load(f)

        # Only the speed limit is restored here; the sensor layout falls back
        # to the VehicleConfig default (sensor_config.json is written by
        # save_pretrained but not read back at this point).
        config = VehicleConfig(
            max_speed_mph=hparams.get("max_speed_mph", 20.0)
        )

        model = cls(
            vehicle_config=config,
            bev_size=hparams.get("bev_size", 200),
            bev_resolution=hparams.get("bev_resolution", 0.25),
            bev_feature_dim=hparams.get("bev_feature_dim", 256),
            num_object_classes=hparams.get("num_object_classes", 10),
            num_seg_classes=hparams.get("num_seg_classes", 7),
            num_waypoints=hparams.get("num_waypoints", 20),
            planning_d_model=hparams.get("planning_d_model", 256),
            future_steps=hparams.get("future_steps", 6),
            num_forecast_modes=hparams.get("num_forecast_modes", 6),
            forecast_steps=hparams.get("forecast_steps", 12),
            num_behaviors=hparams.get("num_behaviors", 10),
            enable_cot=hparams.get("enable_cot", True),
            cot_num_actor_queries=hparams.get("cot_num_actor_queries", 64),
            cot_num_road_queries=hparams.get("cot_num_road_queries", 32),
        )

        # strict=False tolerates checkpoints saved with enable_cot toggled
        weights = torch.load(
            os.path.join(load_dir, "model.pt"),
            map_location=device,
            weights_only=True,
        )
        model.load_state_dict(weights, strict=False)
        return model

    def reset(self):
        """Reset stateful components (call at the start of each episode)."""
        self.control.reset()


class FSDLoss(nn.Module):
    """
    Multi-task loss for training the FSD model.

    Combines per-module losses either with fixed weights or with learnable
    uncertainty-based weights (in the style of Kendall et al., 2018):
    ``total = Σ_i exp(-s_i) · L_i + s_i``, where ``s_i`` is a learned
    log-variance per task.
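
    Example (dummy tensors; only a behavior target is supplied here, so
    every other task loss is simply skipped)::

        criterion = FSDLoss()
        preds = {
            "planning/behavior_logits": torch.randn(2, 10),
            "planning/collision_risk": torch.rand(2, 1),
        }
        losses = criterion(preds, {"gt_behavior": torch.randint(0, 10, (2,))})
        losses["total"].backward()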
| """ |

    # Canonical task order: log-variance i always refers to task i, even
    # when some targets are absent from a given batch.
    TASK_NAMES = (
        "detection", "segmentation", "occupancy", "motion",
        "behavior", "trajectory", "control", "safety",
    )

    def __init__(
        self,
        num_object_classes: int = 10,
        num_seg_classes: int = 7,
        num_behaviors: int = 10,
        # Fixed task weights, used only when learnable_weights=False
        w_detection: float = 1.0,
        w_segmentation: float = 1.0,
        w_occupancy: float = 1.0,
        w_motion: float = 1.0,
        w_behavior: float = 0.5,
        w_trajectory: float = 2.0,
        w_control: float = 2.0,
        w_safety: float = 1.5,
        learnable_weights: bool = True,
    ):
        super().__init__()

        self.num_object_classes = num_object_classes

        if learnable_weights:
            # One log-variance per task; exp(-log_var) acts as the task weight
            self.log_vars = nn.Parameter(torch.zeros(len(self.TASK_NAMES)))
            self.fixed_weights = None
        else:
            self.register_buffer('log_vars', None)
            self.fixed_weights = [
                w_detection, w_segmentation, w_occupancy, w_motion,
                w_behavior, w_trajectory, w_control, w_safety,
            ]

    def forward(
        self,
        predictions: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the multi-task loss.

        Args:
            predictions: Model output dict (see FullSelfDrivingModel.forward)
            targets: Ground-truth dict; every key is optional, and a task's
                loss term is skipped when its key is absent:
                - gt_heatmap: (B, C, H, W) object-center heatmap
                - gt_segmentation: (B, H, W) segmentation labels
                - gt_occupancy: (B, 1, H, W) occupancy grid
                - gt_future_trajectories: (B, T_f, 2) future positions,
                  matched against the best of the K forecast modes
                - gt_behavior: (B,) behavior class labels
                - gt_waypoints: (B, T, 4) ground-truth trajectory
                - gt_steering: (B,) steering commands
                - gt_throttle: (B,) throttle commands
                - gt_brake: (B,) brake commands

        Returns:
            Dict of per-task losses plus the weighted "total".
        """
        losses = {}

        # Object detection: CenterNet-style penalty-reduced focal loss
        if "gt_heatmap" in targets:
            pred_heat = predictions["perception/object_heatmap"]
            gt_heat = targets["gt_heatmap"]
            if pred_heat.shape != gt_heat.shape:
                gt_heat = F.interpolate(
                    gt_heat.float(), size=pred_heat.shape[2:], mode='nearest'
                )
            losses["detection"] = self._focal_loss(pred_heat, gt_heat)

        # BEV segmentation
        if "gt_segmentation" in targets:
            pred_seg = predictions["perception/segmentation"]
            gt_seg = targets["gt_segmentation"]
            if pred_seg.shape[2:] != gt_seg.shape[1:]:
                gt_seg = F.interpolate(
                    gt_seg.float().unsqueeze(1), size=pred_seg.shape[2:], mode='nearest'
                ).squeeze(1).long()
            losses["segmentation"] = F.cross_entropy(pred_seg, gt_seg)

        # Occupancy grid (predictions are probabilities in [0, 1])
        if "gt_occupancy" in targets:
            pred_occ = predictions["perception/occupancy_current"]
            gt_occ = targets["gt_occupancy"]
            if pred_occ.shape != gt_occ.shape:
                gt_occ = F.interpolate(
                    gt_occ.float(), size=pred_occ.shape[2:], mode='nearest'
                )
            losses["occupancy"] = F.binary_cross_entropy(pred_occ, gt_occ.float())

        # Motion forecasting: winner-takes-all over the K forecast modes
        if "gt_future_trajectories" in targets:
            pred_traj = predictions["perception/motion_trajectories"]
            gt_traj = targets["gt_future_trajectories"]
            errors = torch.norm(pred_traj - gt_traj.unsqueeze(1), dim=-1).mean(dim=-1)
            min_errors, _ = errors.min(dim=1)
            losses["motion"] = min_errors.mean()

        # Behavior classification
        if "gt_behavior" in targets:
            pred_behavior = predictions["planning/behavior_logits"]
            losses["behavior"] = F.cross_entropy(pred_behavior, targets["gt_behavior"])

        # Trajectory imitation (truncated to the shorter horizon)
        if "gt_waypoints" in targets:
            pred_wp = predictions["planning/safe_waypoints"]
            gt_wp = targets["gt_waypoints"]
            min_len = min(pred_wp.shape[1], gt_wp.shape[1])
            losses["trajectory"] = F.mse_loss(
                pred_wp[:, :min_len], gt_wp[:, :min_len]
            )

        # Control command regression. Only add the term when at least one
        # control target exists, so a learnable task weight is never
        # optimized against a constant zero loss.
        control_terms = []
        if "gt_steering" in targets:
            control_terms.append(
                F.mse_loss(predictions["control/steering_deg"], targets["gt_steering"])
            )
        if "gt_throttle" in targets:
            control_terms.append(
                F.mse_loss(predictions["control/throttle"], targets["gt_throttle"])
            )
        if "gt_brake" in targets:
            control_terms.append(
                F.mse_loss(predictions["control/brake"], targets["gt_brake"])
            )
        if control_terms:
            losses["control"] = sum(control_terms)

        # Safety: directly penalize the planner's predicted collision risk
        if "planning/collision_risk" in predictions:
            losses["safety"] = predictions["planning/collision_risk"].mean()

        # Combine task losses. log_vars is indexed by TASK_NAMES so that each
        # log-variance is always paired with the same task, regardless of
        # which targets were present in this batch.
        if self.log_vars is not None:
            total_loss = torch.tensor(0.0, device=self.log_vars.device)
            for i, key in enumerate(self.TASK_NAMES):
                if key in losses:
                    precision = torch.exp(-self.log_vars[i])
                    total_loss = total_loss + precision * losses[key] + self.log_vars[i]
        else:
            total_loss = sum(
                w * losses[k]
                for w, k in zip(self.fixed_weights, self.TASK_NAMES)
                if k in losses
            )

        losses["total"] = total_loss
        return losses

    def _focal_loss(self, pred, target, alpha=2.0, beta=4.0):
        """CenterNet-style penalty-reduced focal loss for detection heatmaps."""
        pos_mask = target.eq(1).float()
        neg_mask = target.lt(1).float()

        pred = torch.clamp(pred, 1e-6, 1 - 1e-6)

        pos_loss = -torch.log(pred) * torch.pow(1 - pred, alpha) * pos_mask
        # Negatives near a ground-truth peak are down-weighted by (1 - target)^beta
        neg_loss = (
            -torch.log(1 - pred) * torch.pow(pred, alpha)
            * torch.pow(1 - target, beta) * neg_mask
        )

        num_pos = pos_mask.sum().clamp(min=1)
        return (pos_loss.sum() + neg_loss.sum()) / num_pos
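

if __name__ == "__main__":
    # Minimal end-to-end smoke test: a sketch, not a validated training
    # setup. Batch size, image resolution, and target shapes below are
    # illustrative assumptions; real data comes from the project's loaders.
    model = FullSelfDrivingModel()
    criterion = FSDLoss()
    optimizer = torch.optim.AdamW(
        list(model.parameters()) + list(criterion.parameters()), lr=1e-4
    )

    batch = {
        "camera_images": torch.rand(1, 6, 3, 224, 224),
        "camera_intrinsics": torch.rand(1, 6, 3, 3),
        "camera_extrinsics": torch.rand(1, 6, 4, 4),
        "ultrasonic_distances": torch.rand(1, 20, 1),
        "ultrasonic_placements": torch.rand(1, 20, 6),
    }
    targets = {
        "gt_behavior": torch.randint(0, 10, (1,)),
        "gt_waypoints": torch.rand(1, 20, 4),  # assumed (B, T, 4) waypoints
    }

    predictions = model(**batch)
    losses = criterion(predictions, targets)

    optimizer.zero_grad()
    losses["total"].backward()
    optimizer.step()

    print(model.count_parameters())
    print("total loss:", float(losses["total"]))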