"""
Full Self-Driving Model - Level 5 Autonomous Driving
Complete end-to-end architecture: Sensor Fusion → Perception → Planning → Control
Architecture Summary:
─────────────────────
Sensors (configurable):
├── 6 Cameras → CNN Backbone → FPN → View Transform → Camera BEV
└── 20 Ultrasonics → Distance Encoder → Position Encoder → US BEV
Multi-Modal Fusion (Channel Attention) → Unified BEV
Perception:
├── Object Detection (CenterPoint-style heatmap)
├── BEV Segmentation (road, lanes, crosswalks)
├── Occupancy Grid (current + future)
└── Motion Forecasting (multi-modal trajectories)
Planning:
├── Behavior Prediction (10 driving behaviors)
├── Trajectory Transformer (20 waypoints, 8-head attention)
└── Safety Verification (collision + emergency brake)
Control:
├── Neural Controller (end-to-end)
├── Stanley Controller (lateral)
├── PID Controller (adaptive gains)
└── Bicycle Model (dynamics prediction)
Output: steering, throttle, brake, predicted trajectory
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
import json
import os
from .config import VehicleConfig, SensorConfig
from .sensor_fusion import MultiModalSensorFusion
from .perception import PerceptionModule
from .planning import PlanningModule
from .control import ControlModule
from .cot_reasoning import ChainOfThoughtReasoning
class FullSelfDrivingModel(nn.Module):
"""
End-to-end Level 5 Full Self-Driving Model.
Takes raw sensor data (cameras + ultrasonics) and outputs
actuator commands (steering, throttle, brake).
Fully modular: sensor configuration, perception heads, planning
strategy, and control method are all configurable.
"""
def __init__(
self,
vehicle_config: Optional[VehicleConfig] = None,
bev_size: int = 200,
bev_resolution: float = 0.25,
bev_feature_dim: int = 256,
num_object_classes: int = 10,
num_seg_classes: int = 7,
num_waypoints: int = 20,
planning_d_model: int = 256,
future_steps: int = 6,
num_forecast_modes: int = 6,
forecast_steps: int = 12,
num_behaviors: int = 10,
enable_cot: bool = True,
cot_num_actor_queries: int = 64,
cot_num_road_queries: int = 32,
):
super().__init__()
# Vehicle and sensor configuration
if vehicle_config is None:
vehicle_config = VehicleConfig()
self.vehicle_config = vehicle_config
self.sensor_config = vehicle_config.sensor_config
self.enable_cot = enable_cot
# Store hyperparameters
self.hparams = {
"bev_size": bev_size,
"bev_resolution": bev_resolution,
"bev_feature_dim": bev_feature_dim,
"num_object_classes": num_object_classes,
"num_seg_classes": num_seg_classes,
"num_waypoints": num_waypoints,
"planning_d_model": planning_d_model,
"future_steps": future_steps,
"num_forecast_modes": num_forecast_modes,
"forecast_steps": forecast_steps,
"num_behaviors": num_behaviors,
"max_speed_mph": vehicle_config.max_speed_mph,
"max_speed_ms": vehicle_config.max_speed_ms,
"num_cameras": self.sensor_config.num_cameras,
"num_ultrasonics": self.sensor_config.num_ultrasonics,
"enable_cot": enable_cot,
"cot_num_actor_queries": cot_num_actor_queries,
"cot_num_road_queries": cot_num_road_queries,
}
# 1. Sensor Fusion Module
self.sensor_fusion = MultiModalSensorFusion(
sensor_config=self.sensor_config,
bev_size=bev_size,
bev_resolution=bev_resolution,
bev_feature_dim=bev_feature_dim,
)
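        # With the defaults (bev_size=200 cells, bev_resolution=0.25 m/cell),
        # the BEV grid spans a 50 m x 50 m window around the ego vehicle.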
# 2. Perception Module
self.perception = PerceptionModule(
bev_channels=bev_feature_dim,
num_object_classes=num_object_classes,
num_seg_classes=num_seg_classes,
future_steps=future_steps,
num_forecast_modes=num_forecast_modes,
forecast_steps=forecast_steps,
)
# 3. Planning Module
self.planning = PlanningModule(
bev_channels=bev_feature_dim,
d_model=planning_d_model,
num_waypoints=num_waypoints,
max_speed_ms=vehicle_config.max_speed_ms,
num_behaviors=num_behaviors,
)
# 4. Control Module
self.control = ControlModule(
bev_channels=bev_feature_dim,
num_waypoints=num_waypoints,
wheelbase=vehicle_config.wheelbase,
max_speed_ms=vehicle_config.max_speed_ms,
max_steering_deg=vehicle_config.max_steering_angle,
max_accel=vehicle_config.max_acceleration,
max_decel=vehicle_config.max_deceleration,
)
        # 5. Chain-of-Thought Safety Reasoning (optional; enabled by default)
if enable_cot:
self.cot_reasoning = ChainOfThoughtReasoning(
bev_channels=bev_feature_dim,
d_model=planning_d_model,
num_actor_queries=cot_num_actor_queries,
num_road_queries=cot_num_road_queries,
num_waypoints=num_waypoints,
num_behaviors=num_behaviors,
max_speed_ms=vehicle_config.max_speed_ms,
)
else:
self.cot_reasoning = None
# Initialize weights
self.apply(self._init_weights)
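    # Kaiming (fan-out, ReLU) init for convolutions, unit-gain BatchNorm,
    # Xavier-uniform for linear layers; all biases start at zero.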
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(
self,
camera_images: Optional[torch.Tensor] = None,
camera_intrinsics: Optional[torch.Tensor] = None,
camera_extrinsics: Optional[torch.Tensor] = None,
ultrasonic_distances: Optional[torch.Tensor] = None,
ultrasonic_placements: Optional[torch.Tensor] = None,
ego_state: Optional[torch.Tensor] = None,
nav_command: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""
Full forward pass: sensors → BEV → perception → planning → control.
Args:
camera_images: (B, N_cams, 3, H, W) raw camera images
camera_intrinsics: (B, N_cams, 3, 3) camera calibration
camera_extrinsics: (B, N_cams, 4, 4) camera-to-ego transforms
ultrasonic_distances: (B, N_us, 1) distance readings
ultrasonic_placements: (B, N_us, 6) sensor positions
ego_state: (B, 6) [speed, accel, steer, yaw_rate, x, y]
nav_command: (B,) navigation command integer
Returns:
Dict containing all intermediate and final outputs
"""
        if camera_images is None and ultrasonic_distances is None:
            raise ValueError(
                "At least one of camera_images or ultrasonic_distances must be provided"
            )
        B = (camera_images.shape[0] if camera_images is not None
             else ultrasonic_distances.shape[0])
        device = (camera_images.device if camera_images is not None
                  else ultrasonic_distances.device)
# Default ego state if not provided
if ego_state is None:
ego_state = torch.zeros(B, 6, device=device)
# ───── Stage 1: Sensor Fusion ─────
bev_features = self.sensor_fusion(
camera_images=camera_images,
camera_intrinsics=camera_intrinsics,
camera_extrinsics=camera_extrinsics,
ultrasonic_distances=ultrasonic_distances,
ultrasonic_placements=ultrasonic_placements,
)
# ───── Stage 2: Perception ─────
perception_out = self.perception(bev_features)
# ───── Stage 3: Planning ─────
planning_out = self.planning(
bev_features=bev_features,
ego_state=ego_state,
nav_command=nav_command,
)
# ───── Stage 3.5: Chain-of-Thought Safety Reasoning ─────
cot_out = {}
final_waypoints = planning_out["safe_waypoints"]
if self.cot_reasoning is not None:
cot_out = self.cot_reasoning(
bev_features=bev_features,
ego_state=ego_state,
planner_waypoints=planning_out["safe_waypoints"],
)
# Use CoT-enriched BEV for control (safety-aware features)
bev_for_control = cot_out.get("enriched_bev", bev_features)
# Use safety-gated waypoints if available
if "cot/gated_waypoints" in cot_out:
final_waypoints = cot_out["cot/gated_waypoints"]
else:
bev_for_control = bev_features
# ───── Stage 4: Control ─────
control_out = self.control(
bev_features=bev_for_control,
planned_waypoints=final_waypoints,
ego_state=ego_state,
emergency_brake=planning_out["emergency_brake"],
)
# Combine all outputs
output = {}
output["bev_features"] = bev_features
output.update({f"perception/{k}": v for k, v in perception_out.items()})
output.update({f"planning/{k}": v for k, v in planning_out.items()})
output.update({f"control/{k}": v for k, v in control_out.items()})
if cot_out:
output.update({k: v for k, v in cot_out.items() if k != "enriched_bev"})
return output
def get_control_output(
self, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convenience method that returns only (steering, throttle, brake).
"""
out = self.forward(**kwargs)
return (
out["control/steering_deg"],
out["control/throttle"],
out["control/brake"],
)
def count_parameters(self) -> Dict[str, int]:
"""Count parameters per module."""
counts = {
"sensor_fusion": sum(p.numel() for p in self.sensor_fusion.parameters()),
"perception": sum(p.numel() for p in self.perception.parameters()),
"planning": sum(p.numel() for p in self.planning.parameters()),
"control": sum(p.numel() for p in self.control.parameters()),
}
if self.cot_reasoning is not None:
counts["cot_reasoning"] = sum(p.numel() for p in self.cot_reasoning.parameters())
counts["total"] = sum(counts.values())
counts["total_trainable"] = sum(
p.numel() for p in self.parameters() if p.requires_grad
)
return counts
def save_pretrained(self, save_dir: str):
"""Save model with config for easy loading."""
os.makedirs(save_dir, exist_ok=True)
# Save model weights
torch.save(self.state_dict(), os.path.join(save_dir, "model.pt"))
# Save hyperparameters
with open(os.path.join(save_dir, "config.json"), "w") as f:
json.dump(self.hparams, f, indent=2)
# Save sensor config
self.sensor_config.save(os.path.join(save_dir, "sensor_config.json"))
# Save parameter counts
counts = self.count_parameters()
with open(os.path.join(save_dir, "model_summary.json"), "w") as f:
json.dump(counts, f, indent=2)
@classmethod
def from_pretrained(cls, load_dir: str, device: str = "cpu"):
"""Load model from saved directory."""
with open(os.path.join(load_dir, "config.json"), "r") as f:
hparams = json.load(f)
config = VehicleConfig(
max_speed_mph=hparams.get("max_speed_mph", 20.0)
)
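        # Note: sensor_config.json is written by save_pretrained for reference,
        # but reloading here relies on VehicleConfig's default sensor layout.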
model = cls(
vehicle_config=config,
bev_size=hparams.get("bev_size", 200),
bev_resolution=hparams.get("bev_resolution", 0.25),
bev_feature_dim=hparams.get("bev_feature_dim", 256),
num_object_classes=hparams.get("num_object_classes", 10),
num_seg_classes=hparams.get("num_seg_classes", 7),
num_waypoints=hparams.get("num_waypoints", 20),
planning_d_model=hparams.get("planning_d_model", 256),
future_steps=hparams.get("future_steps", 6),
num_forecast_modes=hparams.get("num_forecast_modes", 6),
forecast_steps=hparams.get("forecast_steps", 12),
num_behaviors=hparams.get("num_behaviors", 10),
enable_cot=hparams.get("enable_cot", True),
cot_num_actor_queries=hparams.get("cot_num_actor_queries", 64),
cot_num_road_queries=hparams.get("cot_num_road_queries", 32),
)
weights = torch.load(
os.path.join(load_dir, "model.pt"),
map_location=device,
weights_only=True,
)
model.load_state_dict(weights, strict=False)
return model
def reset(self):
"""Reset stateful components (call at episode start)."""
self.control.reset()
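# Save/load round trip (illustrative; "checkpoints/fsd" is a placeholder path):
#   model = FullSelfDrivingModel()
#   model.save_pretrained("checkpoints/fsd")
#   model = FullSelfDrivingModel.from_pretrained("checkpoints/fsd")
#   model.reset()  # clear stateful controller components at episode start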
class FSDLoss(nn.Module):
"""
Multi-task loss for training the FSD model.
Combines losses from all modules with learnable task weights.
"""
def __init__(
self,
num_object_classes: int = 10,
num_seg_classes: int = 7,
num_behaviors: int = 10,
# Loss weights (initial)
w_detection: float = 1.0,
w_segmentation: float = 1.0,
w_occupancy: float = 1.0,
w_motion: float = 1.0,
w_behavior: float = 0.5,
w_trajectory: float = 2.0,
w_control: float = 2.0,
w_safety: float = 1.5,
learnable_weights: bool = True,
):
super().__init__()
self.num_object_classes = num_object_classes
        self.fixed_weights = [
            w_detection, w_segmentation, w_occupancy, w_motion,
            w_behavior, w_trajectory, w_control, w_safety,
        ]
        if learnable_weights:
            # Log-scale learnable task weights (homoscedastic uncertainty)
            self.log_vars = nn.Parameter(torch.zeros(len(self.fixed_weights)))
        else:
            self.register_buffer('log_vars', None)
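    # With learnable weights, the total follows Kendall et al. (2018),
    # "Multi-Task Learning Using Uncertainty to Weigh Losses":
    #   total = sum_i exp(-s_i) * L_i + s_i
    # where s_i is the log-variance of task i; noisier tasks are down-weighted
    # automatically, and the +s_i term stops each s_i from growing unboundedly.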
def forward(
self,
predictions: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""
Compute multi-task loss.
Args:
predictions: Model output dict
targets: Ground truth dict with keys:
- gt_heatmap: (B, C, H, W) object heatmap
- gt_bbox: (B, 6, H, W) bounding boxes
- gt_segmentation: (B, H, W) segmentation labels
- gt_occupancy: (B, 1, H, W) occupancy grid
- gt_behavior: (B,) behavior labels
- gt_waypoints: (B, T, 4) ground truth trajectory
- gt_steering: (B,) steering commands
- gt_throttle: (B,) throttle commands
- gt_brake: (B,) brake commands
"""
losses = {}
# 1. Detection loss (focal loss for heatmap + L1 for bbox)
if "gt_heatmap" in targets:
pred_heat = predictions["perception/object_heatmap"]
gt_heat = targets["gt_heatmap"]
# Resize if needed
if pred_heat.shape != gt_heat.shape:
gt_heat = F.interpolate(gt_heat.float(), size=pred_heat.shape[2:], mode='nearest')
losses["detection"] = self._focal_loss(pred_heat, gt_heat)
# 2. Segmentation loss (cross entropy)
if "gt_segmentation" in targets:
pred_seg = predictions["perception/segmentation"]
gt_seg = targets["gt_segmentation"]
            if pred_seg.shape[2:] != gt_seg.shape[1:]:
                gt_seg = F.interpolate(
                    gt_seg.float().unsqueeze(1),
                    size=pred_seg.shape[2:],
                    mode='nearest',
                ).squeeze(1).long()
losses["segmentation"] = F.cross_entropy(pred_seg, gt_seg)
# 3. Occupancy loss (binary cross entropy)
if "gt_occupancy" in targets:
pred_occ = predictions["perception/occupancy_current"]
gt_occ = targets["gt_occupancy"]
if pred_occ.shape != gt_occ.shape:
gt_occ = F.interpolate(gt_occ.float(), size=pred_occ.shape[2:], mode='nearest')
losses["occupancy"] = F.binary_cross_entropy(pred_occ, gt_occ.float())
# 4. Motion forecasting loss (ADE - average displacement error)
if "gt_future_trajectories" in targets:
pred_traj = predictions["perception/motion_trajectories"]
gt_traj = targets["gt_future_trajectories"]
# Best-of-K: min ADE across modes
errors = torch.norm(pred_traj - gt_traj.unsqueeze(1), dim=-1).mean(dim=-1)
min_errors, _ = errors.min(dim=1)
losses["motion"] = min_errors.mean()
# 5. Behavior prediction loss
if "gt_behavior" in targets:
pred_behavior = predictions["planning/behavior_logits"]
losses["behavior"] = F.cross_entropy(pred_behavior, targets["gt_behavior"])
# 6. Trajectory planning loss (L2 waypoint error)
if "gt_waypoints" in targets:
pred_wp = predictions["planning/safe_waypoints"]
gt_wp = targets["gt_waypoints"]
min_len = min(pred_wp.shape[1], gt_wp.shape[1])
losses["trajectory"] = F.mse_loss(
pred_wp[:, :min_len], gt_wp[:, :min_len]
)
# 7. Control loss
control_loss = torch.tensor(0.0, device=list(predictions.values())[0].device)
if "gt_steering" in targets:
pred_steer = predictions["control/steering_deg"]
control_loss = control_loss + F.mse_loss(pred_steer, targets["gt_steering"])
if "gt_throttle" in targets:
pred_throttle = predictions["control/throttle"]
control_loss = control_loss + F.mse_loss(pred_throttle, targets["gt_throttle"])
if "gt_brake" in targets:
pred_brake = predictions["control/brake"]
control_loss = control_loss + F.mse_loss(pred_brake, targets["gt_brake"])
losses["control"] = control_loss
# 8. Safety loss (minimize collision risk)
losses["safety"] = predictions["planning/collision_risk"].mean()
# Combine losses with weights
        loss_order = [
            "detection", "segmentation", "occupancy", "motion",
            "behavior", "trajectory", "control", "safety",
        ]
        if self.log_vars is not None:
            # Tie each log-variance to a fixed task slot so the same weight
            # always maps to the same task, regardless of which targets
            # happen to be present in this batch
            total_loss = torch.tensor(0.0, device=self.log_vars.device)
            for i, key in enumerate(loss_order):
                if key in losses:
                    precision = torch.exp(-self.log_vars[i])
                    total_loss = total_loss + precision * losses[key] + self.log_vars[i]
        else:
            # Sum only the losses that were actually computed, so no
            # CPU-resident placeholder tensors mix with device tensors
            total_loss = sum(
                w * losses[k]
                for w, k in zip(self.fixed_weights, loss_order)
                if k in losses
            )
losses["total"] = total_loss
return losses
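    # Penalty-reduced pixel-wise focal loss (CornerNet / CenterNet style).
    # For prediction p in (0, 1) and a Gaussian-splatted target y:
    #   L = -(1 / N_pos) * sum over pixels of
    #         (1 - p)^alpha * log(p)                 where y == 1
    #         (1 - y)^beta * p^alpha * log(1 - p)    elsewhere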
def _focal_loss(self, pred, target, alpha=2.0, beta=4.0):
"""Focal loss for heatmap detection."""
pos_mask = target.eq(1).float()
neg_mask = target.lt(1).float()
pred = torch.clamp(pred, 1e-6, 1 - 1e-6)
pos_loss = -torch.log(pred) * torch.pow(1 - pred, alpha) * pos_mask
neg_loss = -torch.log(1 - pred) * torch.pow(pred, alpha) * torch.pow(1 - target, beta) * neg_mask
num_pos = pos_mask.sum().clamp(min=1)
loss = (pos_loss.sum() + neg_loss.sum()) / num_pos
return loss
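if __name__ == "__main__":
    # Minimal smoke test (illustrative only). Run as a module, e.g.
    # `python -m <package>.model`, so the relative imports resolve; the batch
    # size, image resolution, and identity calibration below are placeholders,
    # not values the model requires.
    model = FullSelfDrivingModel()
    model.eval()
    B = 1
    n_cam = model.sensor_config.num_cameras
    n_us = model.sensor_config.num_ultrasonics
    with torch.no_grad():
        out = model(
            camera_images=torch.rand(B, n_cam, 3, 224, 480),
            camera_intrinsics=torch.eye(3).expand(B, n_cam, 3, 3),  # placeholder calibration
            camera_extrinsics=torch.eye(4).expand(B, n_cam, 4, 4),  # placeholder extrinsics
            ultrasonic_distances=torch.rand(B, n_us, 1),
            ultrasonic_placements=torch.rand(B, n_us, 6),
            ego_state=torch.zeros(B, 6),
        )
    print("parameters:", model.count_parameters())
    print("control outputs:", {k: tuple(v.shape) for k, v in out.items()
                               if k.startswith("control/")})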