""" Full Self-Driving Model - Level 5 Autonomous Driving Complete end-to-end architecture: Sensor Fusion → Perception → Planning → Control Architecture Summary: ───────────────────── Sensors (configurable): ├── 6 Cameras → CNN Backbone → FPN → View Transform → Camera BEV └── 20 Ultrasonics → Distance Encoder → Position Encoder → US BEV ↓ Multi-Modal Fusion (Channel Attention) → Unified BEV ↓ Perception: ├── Object Detection (CenterPoint-style heatmap) ├── BEV Segmentation (road, lanes, crosswalks) ├── Occupancy Grid (current + future) └── Motion Forecasting (multi-modal trajectories) ↓ Planning: ├── Behavior Prediction (10 driving behaviors) ├── Trajectory Transformer (20 waypoints, 8-head attention) └── Safety Verification (collision + emergency brake) ↓ Control: ├── Neural Controller (end-to-end) ├── Stanley Controller (lateral) ├── PID Controller (adaptive gains) └── Bicycle Model (dynamics prediction) ↓ Output: steering, throttle, brake, predicted trajectory """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Optional, Tuple, List import math import json import os from .config import VehicleConfig, SensorConfig from .sensor_fusion import MultiModalSensorFusion from .perception import PerceptionModule from .planning import PlanningModule from .control import ControlModule from .cot_reasoning import ChainOfThoughtReasoning class FullSelfDrivingModel(nn.Module): """ End-to-end Level 5 Full Self-Driving Model. Takes raw sensor data (cameras + ultrasonics) and outputs actuator commands (steering, throttle, brake). Fully modular: sensor configuration, perception heads, planning strategy, and control method are all configurable. """ def __init__( self, vehicle_config: Optional[VehicleConfig] = None, bev_size: int = 200, bev_resolution: float = 0.25, bev_feature_dim: int = 256, num_object_classes: int = 10, num_seg_classes: int = 7, num_waypoints: int = 20, planning_d_model: int = 256, future_steps: int = 6, num_forecast_modes: int = 6, forecast_steps: int = 12, num_behaviors: int = 10, enable_cot: bool = True, cot_num_actor_queries: int = 64, cot_num_road_queries: int = 32, ): super().__init__() # Vehicle and sensor configuration if vehicle_config is None: vehicle_config = VehicleConfig() self.vehicle_config = vehicle_config self.sensor_config = vehicle_config.sensor_config self.enable_cot = enable_cot # Store hyperparameters self.hparams = { "bev_size": bev_size, "bev_resolution": bev_resolution, "bev_feature_dim": bev_feature_dim, "num_object_classes": num_object_classes, "num_seg_classes": num_seg_classes, "num_waypoints": num_waypoints, "planning_d_model": planning_d_model, "future_steps": future_steps, "num_forecast_modes": num_forecast_modes, "forecast_steps": forecast_steps, "num_behaviors": num_behaviors, "max_speed_mph": vehicle_config.max_speed_mph, "max_speed_ms": vehicle_config.max_speed_ms, "num_cameras": self.sensor_config.num_cameras, "num_ultrasonics": self.sensor_config.num_ultrasonics, "enable_cot": enable_cot, "cot_num_actor_queries": cot_num_actor_queries, "cot_num_road_queries": cot_num_road_queries, } # 1. Sensor Fusion Module self.sensor_fusion = MultiModalSensorFusion( sensor_config=self.sensor_config, bev_size=bev_size, bev_resolution=bev_resolution, bev_feature_dim=bev_feature_dim, ) # 2. 

        # 2. Perception Module
        self.perception = PerceptionModule(
            bev_channels=bev_feature_dim,
            num_object_classes=num_object_classes,
            num_seg_classes=num_seg_classes,
            future_steps=future_steps,
            num_forecast_modes=num_forecast_modes,
            forecast_steps=forecast_steps,
        )

        # 3. Planning Module
        self.planning = PlanningModule(
            bev_channels=bev_feature_dim,
            d_model=planning_d_model,
            num_waypoints=num_waypoints,
            max_speed_ms=vehicle_config.max_speed_ms,
            num_behaviors=num_behaviors,
        )

        # 4. Control Module
        self.control = ControlModule(
            bev_channels=bev_feature_dim,
            num_waypoints=num_waypoints,
            wheelbase=vehicle_config.wheelbase,
            max_speed_ms=vehicle_config.max_speed_ms,
            max_steering_deg=vehicle_config.max_steering_angle,
            max_accel=vehicle_config.max_acceleration,
            max_decel=vehicle_config.max_deceleration,
        )

        # 5. Chain-of-Thought Safety Reasoning (optional but default ON)
        if enable_cot:
            self.cot_reasoning = ChainOfThoughtReasoning(
                bev_channels=bev_feature_dim,
                d_model=planning_d_model,
                num_actor_queries=cot_num_actor_queries,
                num_road_queries=cot_num_road_queries,
                num_waypoints=num_waypoints,
                num_behaviors=num_behaviors,
                max_speed_ms=vehicle_config.max_speed_ms,
            )
        else:
            self.cot_reasoning = None

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        camera_images: Optional[torch.Tensor] = None,
        camera_intrinsics: Optional[torch.Tensor] = None,
        camera_extrinsics: Optional[torch.Tensor] = None,
        ultrasonic_distances: Optional[torch.Tensor] = None,
        ultrasonic_placements: Optional[torch.Tensor] = None,
        ego_state: Optional[torch.Tensor] = None,
        nav_command: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass: sensors → BEV → perception → planning → control.
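
        At least one of ``camera_images`` or ``ultrasonic_distances`` must be
        provided; the batch size and device are inferred from whichever input
        is present.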

        Args:
            camera_images: (B, N_cams, 3, H, W) raw camera images
            camera_intrinsics: (B, N_cams, 3, 3) camera calibration
            camera_extrinsics: (B, N_cams, 4, 4) camera-to-ego transforms
            ultrasonic_distances: (B, N_us, 1) distance readings
            ultrasonic_placements: (B, N_us, 6) sensor positions
            ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
            nav_command: (B,) navigation command integer

        Returns:
            Dict containing all intermediate and final outputs
        """
        # Fail fast with a clear message instead of an AttributeError below
        if camera_images is None and ultrasonic_distances is None:
            raise ValueError(
                "At least one of camera_images or ultrasonic_distances is required"
            )

        B = (camera_images.shape[0] if camera_images is not None
             else ultrasonic_distances.shape[0])
        device = (camera_images.device if camera_images is not None
                  else ultrasonic_distances.device)

        # Default ego state if not provided
        if ego_state is None:
            ego_state = torch.zeros(B, 6, device=device)

        # ───── Stage 1: Sensor Fusion ─────
        bev_features = self.sensor_fusion(
            camera_images=camera_images,
            camera_intrinsics=camera_intrinsics,
            camera_extrinsics=camera_extrinsics,
            ultrasonic_distances=ultrasonic_distances,
            ultrasonic_placements=ultrasonic_placements,
        )

        # ───── Stage 2: Perception ─────
        perception_out = self.perception(bev_features)

        # ───── Stage 3: Planning ─────
        planning_out = self.planning(
            bev_features=bev_features,
            ego_state=ego_state,
            nav_command=nav_command,
        )

        # ───── Stage 3.5: Chain-of-Thought Safety Reasoning ─────
        cot_out = {}
        final_waypoints = planning_out["safe_waypoints"]
        if self.cot_reasoning is not None:
            cot_out = self.cot_reasoning(
                bev_features=bev_features,
                ego_state=ego_state,
                planner_waypoints=planning_out["safe_waypoints"],
            )
            # Use CoT-enriched BEV for control (safety-aware features)
            bev_for_control = cot_out.get("enriched_bev", bev_features)
            # Use safety-gated waypoints if available
            if "cot/gated_waypoints" in cot_out:
                final_waypoints = cot_out["cot/gated_waypoints"]
        else:
            bev_for_control = bev_features

        # ───── Stage 4: Control ─────
        control_out = self.control(
            bev_features=bev_for_control,
            planned_waypoints=final_waypoints,
            ego_state=ego_state,
            emergency_brake=planning_out["emergency_brake"],
        )

        # Combine all outputs
        output = {}
        output["bev_features"] = bev_features
        output.update({f"perception/{k}": v for k, v in perception_out.items()})
        output.update({f"planning/{k}": v for k, v in planning_out.items()})
        output.update({f"control/{k}": v for k, v in control_out.items()})
        if cot_out:
            output.update({k: v for k, v in cot_out.items() if k != "enriched_bev"})

        return output

    def get_control_output(
        self, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convenience method that returns only (steering, throttle, brake).
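
        Example (illustrative sketch; tensor shapes assume the default
        6-camera / 20-ultrasonic rig, and the variable names are placeholders):

            steering, throttle, brake = model.get_control_output(
                camera_images=imgs,          # (B, 6, 3, H, W)
                camera_intrinsics=K,         # (B, 6, 3, 3)
                camera_extrinsics=T,         # (B, 6, 4, 4)
                ultrasonic_distances=dist,   # (B, 20, 1)
                ultrasonic_placements=pose,  # (B, 20, 6)
            )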
""" out = self.forward(**kwargs) return ( out["control/steering_deg"], out["control/throttle"], out["control/brake"], ) def count_parameters(self) -> Dict[str, int]: """Count parameters per module.""" counts = { "sensor_fusion": sum(p.numel() for p in self.sensor_fusion.parameters()), "perception": sum(p.numel() for p in self.perception.parameters()), "planning": sum(p.numel() for p in self.planning.parameters()), "control": sum(p.numel() for p in self.control.parameters()), } if self.cot_reasoning is not None: counts["cot_reasoning"] = sum(p.numel() for p in self.cot_reasoning.parameters()) counts["total"] = sum(counts.values()) counts["total_trainable"] = sum( p.numel() for p in self.parameters() if p.requires_grad ) return counts def save_pretrained(self, save_dir: str): """Save model with config for easy loading.""" os.makedirs(save_dir, exist_ok=True) # Save model weights torch.save(self.state_dict(), os.path.join(save_dir, "model.pt")) # Save hyperparameters with open(os.path.join(save_dir, "config.json"), "w") as f: json.dump(self.hparams, f, indent=2) # Save sensor config self.sensor_config.save(os.path.join(save_dir, "sensor_config.json")) # Save parameter counts counts = self.count_parameters() with open(os.path.join(save_dir, "model_summary.json"), "w") as f: json.dump(counts, f, indent=2) @classmethod def from_pretrained(cls, load_dir: str, device: str = "cpu"): """Load model from saved directory.""" with open(os.path.join(load_dir, "config.json"), "r") as f: hparams = json.load(f) config = VehicleConfig( max_speed_mph=hparams.get("max_speed_mph", 20.0) ) model = cls( vehicle_config=config, bev_size=hparams.get("bev_size", 200), bev_resolution=hparams.get("bev_resolution", 0.25), bev_feature_dim=hparams.get("bev_feature_dim", 256), num_object_classes=hparams.get("num_object_classes", 10), num_seg_classes=hparams.get("num_seg_classes", 7), num_waypoints=hparams.get("num_waypoints", 20), planning_d_model=hparams.get("planning_d_model", 256), future_steps=hparams.get("future_steps", 6), num_forecast_modes=hparams.get("num_forecast_modes", 6), forecast_steps=hparams.get("forecast_steps", 12), num_behaviors=hparams.get("num_behaviors", 10), enable_cot=hparams.get("enable_cot", True), cot_num_actor_queries=hparams.get("cot_num_actor_queries", 64), cot_num_road_queries=hparams.get("cot_num_road_queries", 32), ) weights = torch.load( os.path.join(load_dir, "model.pt"), map_location=device, weights_only=True, ) model.load_state_dict(weights, strict=False) return model def reset(self): """Reset stateful components (call at episode start).""" self.control.reset() class FSDLoss(nn.Module): """ Multi-task loss for training the FSD model. Combines losses from all modules with learnable task weights. 
""" def __init__( self, num_object_classes: int = 10, num_seg_classes: int = 7, num_behaviors: int = 10, # Loss weights (initial) w_detection: float = 1.0, w_segmentation: float = 1.0, w_occupancy: float = 1.0, w_motion: float = 1.0, w_behavior: float = 0.5, w_trajectory: float = 2.0, w_control: float = 2.0, w_safety: float = 1.5, learnable_weights: bool = True, ): super().__init__() self.num_object_classes = num_object_classes if learnable_weights: # Log-scale learnable task weights (homoscedastic uncertainty) self.log_vars = nn.Parameter(torch.zeros(8)) else: self.register_buffer('log_vars', None) self.fixed_weights = [ w_detection, w_segmentation, w_occupancy, w_motion, w_behavior, w_trajectory, w_control, w_safety ] def forward( self, predictions: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: """ Compute multi-task loss. Args: predictions: Model output dict targets: Ground truth dict with keys: - gt_heatmap: (B, C, H, W) object heatmap - gt_bbox: (B, 6, H, W) bounding boxes - gt_segmentation: (B, H, W) segmentation labels - gt_occupancy: (B, 1, H, W) occupancy grid - gt_behavior: (B,) behavior labels - gt_waypoints: (B, T, 4) ground truth trajectory - gt_steering: (B,) steering commands - gt_throttle: (B,) throttle commands - gt_brake: (B,) brake commands """ losses = {} # 1. Detection loss (focal loss for heatmap + L1 for bbox) if "gt_heatmap" in targets: pred_heat = predictions["perception/object_heatmap"] gt_heat = targets["gt_heatmap"] # Resize if needed if pred_heat.shape != gt_heat.shape: gt_heat = F.interpolate(gt_heat.float(), size=pred_heat.shape[2:], mode='nearest') losses["detection"] = self._focal_loss(pred_heat, gt_heat) # 2. Segmentation loss (cross entropy) if "gt_segmentation" in targets: pred_seg = predictions["perception/segmentation"] gt_seg = targets["gt_segmentation"] if pred_seg.shape[2:] != gt_seg.shape[1:]: gt_seg = F.interpolate(gt_seg.float().unsqueeze(1), size=pred_seg.shape[2:], mode='nearest').squeeze(1).long() losses["segmentation"] = F.cross_entropy(pred_seg, gt_seg) # 3. Occupancy loss (binary cross entropy) if "gt_occupancy" in targets: pred_occ = predictions["perception/occupancy_current"] gt_occ = targets["gt_occupancy"] if pred_occ.shape != gt_occ.shape: gt_occ = F.interpolate(gt_occ.float(), size=pred_occ.shape[2:], mode='nearest') losses["occupancy"] = F.binary_cross_entropy(pred_occ, gt_occ.float()) # 4. Motion forecasting loss (ADE - average displacement error) if "gt_future_trajectories" in targets: pred_traj = predictions["perception/motion_trajectories"] gt_traj = targets["gt_future_trajectories"] # Best-of-K: min ADE across modes errors = torch.norm(pred_traj - gt_traj.unsqueeze(1), dim=-1).mean(dim=-1) min_errors, _ = errors.min(dim=1) losses["motion"] = min_errors.mean() # 5. Behavior prediction loss if "gt_behavior" in targets: pred_behavior = predictions["planning/behavior_logits"] losses["behavior"] = F.cross_entropy(pred_behavior, targets["gt_behavior"]) # 6. Trajectory planning loss (L2 waypoint error) if "gt_waypoints" in targets: pred_wp = predictions["planning/safe_waypoints"] gt_wp = targets["gt_waypoints"] min_len = min(pred_wp.shape[1], gt_wp.shape[1]) losses["trajectory"] = F.mse_loss( pred_wp[:, :min_len], gt_wp[:, :min_len] ) # 7. 

        # 7. Control loss
        control_loss = torch.tensor(0.0, device=list(predictions.values())[0].device)
        if "gt_steering" in targets:
            pred_steer = predictions["control/steering_deg"]
            control_loss = control_loss + F.mse_loss(pred_steer, targets["gt_steering"])
        if "gt_throttle" in targets:
            pred_throttle = predictions["control/throttle"]
            control_loss = control_loss + F.mse_loss(pred_throttle, targets["gt_throttle"])
        if "gt_brake" in targets:
            pred_brake = predictions["control/brake"]
            control_loss = control_loss + F.mse_loss(pred_brake, targets["gt_brake"])
        losses["control"] = control_loss

        # 8. Safety loss (minimize collision risk)
        losses["safety"] = predictions["planning/collision_risk"].mean()

        # Combine losses with weights. Iterate in a fixed canonical order so
        # each task is always paired with the same log-variance even when some
        # ground-truth targets are absent from a batch (the original indexed
        # log_vars by dict position, which misaligned tasks in that case).
        task_order = [
            "detection", "segmentation", "occupancy", "motion",
            "behavior", "trajectory", "control", "safety",
        ]
        total_loss = torch.tensor(0.0, device=losses["control"].device)
        if self.log_vars is not None:
            for i, key in enumerate(task_order):
                if key in losses:
                    precision = torch.exp(-self.log_vars[i])
                    total_loss = total_loss + precision * losses[key] + self.log_vars[i]
        else:
            for w, key in zip(self.fixed_weights, task_order):
                if key in losses:
                    total_loss = total_loss + w * losses[key]

        losses["total"] = total_loss
        return losses

    def _focal_loss(self, pred, target, alpha=2.0, beta=4.0):
        """CenterNet-style focal loss for heatmap detection."""
        pos_mask = target.eq(1).float()
        neg_mask = target.lt(1).float()

        pred = torch.clamp(pred, 1e-6, 1 - 1e-6)

        pos_loss = -torch.log(pred) * torch.pow(1 - pred, alpha) * pos_mask
        neg_loss = (-torch.log(1 - pred) * torch.pow(pred, alpha)
                    * torch.pow(1 - target, beta) * neg_mask)

        num_pos = pos_mask.sum().clamp(min=1)
        loss = (pos_loss.sum() + neg_loss.sum()) / num_pos
        return loss
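

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one forward pass and one training step.
# The batch size, image resolution, random tensors, and single-key target
# dict are placeholder assumptions, and whether this runs end-to-end depends
# on the sibling modules (sensor_fusion, perception, ...). Because of the
# relative imports, execute it as a module (python -m ...), not as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, H, W = 2, 224, 480  # assumed batch size and camera resolution

    model = FullSelfDrivingModel()
    criterion = FSDLoss()

    out = model(
        camera_images=torch.randn(B, 6, 3, H, W),
        camera_intrinsics=torch.eye(3).expand(B, 6, 3, 3),
        camera_extrinsics=torch.eye(4).expand(B, 6, 4, 4),
        ultrasonic_distances=torch.rand(B, 20, 1),
        ultrasonic_placements=torch.rand(B, 20, 6),
        ego_state=torch.zeros(B, 6),
    )
    steering = out["control/steering_deg"]  # predicted steering in degrees

    # Supervise steering only; the learnable task weights balance it against
    # the always-available safety loss.
    losses = criterion(out, {"gt_steering": torch.zeros(B)})
    losses["total"].backward()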