"""
Perception Module for FSD Model.

Handles:
1. Object Detection in BEV (vehicles, pedestrians, cyclists, etc.)
2. Lane Detection and Road Segmentation
3. Free Space Estimation
4. Traffic Sign/Signal Recognition
5. Occupancy Grid Generation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple


def _conv_bn_relu(in_channels: int, out_channels: int) -> List[nn.Module]:
    """Return the layers of one 3x3 conv -> BatchNorm -> ReLU stage.

    A flat list (not a nested ``nn.Sequential``) is returned on purpose:
    callers splat it into their own ``nn.Sequential`` so the submodules keep
    flat positional names (``0``, ``1``, ``2``, ...) in the ``state_dict``.
    """
    return [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]


def _prediction_branch(
    in_channels: int,
    out_channels: int,
    mid_channels: int = 64,
) -> nn.Sequential:
    """Build a 3x3 conv -> ReLU -> 1x1 conv branch that keeps H and W."""
    return nn.Sequential(
        nn.Conv2d(in_channels, mid_channels, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(mid_channels, out_channels, 1),
    )


class BEVObjectDetectionHead(nn.Module):
    """
    Detects objects in BEV space.

    Predicts, for every BEV cell: per-class center heatmap, bounding box
    regression (x, y, w, h, sin(yaw), cos(yaw)) and velocity (vx, vy).
    Uses anchor-free detection similar to CenterPoint.

    Args:
        in_channels: Number of channels of the incoming BEV feature map.
        num_classes: Number of object classes (heatmap channels).
        num_heads: Not used by this implementation; kept in the signature so
            existing call sites remain valid.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_classes: int = 10,
        num_heads: int = 6,
    ):
        super().__init__()
        self.num_classes = num_classes

        # Shared feature extraction (two conv/BN/ReLU stages, 128 channels).
        self.shared_conv = nn.Sequential(
            *_conv_bn_relu(in_channels, 128),
            *_conv_bn_relu(128, 128),
        )

        # Heatmap head (object center detection), one channel per class.
        self.heatmap_head = _prediction_branch(128, num_classes)

        # Bounding box regression head (x, y, w, h, sin(yaw), cos(yaw)).
        self.bbox_head = _prediction_branch(128, 6)

        # Velocity head (vx, vy).
        self.velocity_head = _prediction_branch(128, 2)

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, in_channels, H, W) BEV feature map.

        Returns:
            Dict with
            ``"heatmap"``: (B, num_classes, H, W), sigmoid-activated;
            ``"bbox"``: (B, 6, H, W), raw regression outputs;
            ``"velocity"``: (B, 2, H, W), raw regression outputs.
        """
        feat = self.shared_conv(bev)
        return {
            "heatmap": torch.sigmoid(self.heatmap_head(feat)),
            "bbox": self.bbox_head(feat),
            "velocity": self.velocity_head(feat),
        }


class BEVSegmentationHead(nn.Module):
    """
    Semantic segmentation in BEV space.

    Classes: drivable area, lane lines, crosswalks, sidewalks, etc.

    Args:
        in_channels: Number of channels of the incoming BEV feature map.
        num_seg_classes: Number of segmentation classes (output channels).
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_seg_classes: int = 7,
    ):
        super().__init__()
        # Segmentation classes:
        # 0: background, 1: drivable, 2: lane_line, 3: crosswalk,
        # 4: sidewalk, 5: stop_line, 6: road_edge
        self.num_seg_classes = num_seg_classes

        self.decoder = nn.Sequential(
            *_conv_bn_relu(in_channels, 128),
            *_conv_bn_relu(128, 64),
            nn.Conv2d(64, num_seg_classes, 1),
        )

    def forward(self, bev: torch.Tensor) -> torch.Tensor:
        """Returns: (B, num_seg_classes, H, W) logits"""
        return self.decoder(bev)


class OccupancyGridHead(nn.Module):
    """
    Predicts occupancy probability for each BEV grid cell.

    Binary: occupied / free space. Also predicts future occupancy for
    ``future_steps`` timesteps, one output channel per step.

    Args:
        in_channels: Number of channels of the incoming BEV feature map.
        future_steps: Number of future timesteps (channels of ``"future"``).
    """

    def __init__(
        self,
        in_channels: int = 256,
        future_steps: int = 6,
    ):
        super().__init__()
        self.future_steps = future_steps

        # Current occupancy: single-channel map.
        self.current_occ = nn.Sequential(
            *_conv_bn_relu(in_channels, 64),
            nn.Conv2d(64, 1, 1),
        )

        # Future occupancy: plain 2D convolutions on the same BEV map; the
        # time axis is encoded as the `future_steps` output channels.
        self.future_occ = nn.Sequential(
            *_conv_bn_relu(in_channels, 128),
            *_conv_bn_relu(128, 64),
            nn.Conv2d(64, future_steps, 1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, in_channels, H, W) BEV feature map.

        Returns:
            Dict with ``"current"``: (B, 1, H, W) and ``"future"``:
            (B, future_steps, H, W); both are sigmoid-activated.
        """
        current = torch.sigmoid(self.current_occ(bev))  # (B, 1, H, W)
        future = torch.sigmoid(self.future_occ(bev))  # (B, T, H, W)
        return {"current": current, "future": future}


class MotionForecastingHead(nn.Module):
    """
    Predicts K candidate future trajectories from the BEV feature map.

    The whole BEV map is average-pooled to a fixed grid and flattened, so the
    head emits one set of K trajectories (plus mode probabilities) per batch
    element rather than one set per detected object.

    Args:
        in_channels: Number of channels of the incoming BEV feature map.
        num_modes: Number of trajectory modes K.
        future_steps: Number of (x, y) waypoints per trajectory.
        hidden_dim: Base width of the MLP decoder.
        pool_size: Side length of the square grid the BEV map is pooled to
            before flattening. Defaults to 8, the previously fixed value.

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_modes: int = 6,
        future_steps: int = 12,
        hidden_dim: int = 128,
        pool_size: int = 8,
    ):
        super().__init__()
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")

        self.num_modes = num_modes
        self.future_steps = future_steps
        self.pool_size = pool_size

        # Global feature pooling to a fixed (pool_size x pool_size) grid so
        # the flattened size does not depend on the input H and W.
        self.pool = nn.AdaptiveAvgPool2d(pool_size)

        self.trajectory_decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * pool_size * pool_size, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),
        )

        # Multi-modal trajectory output: one linear head per mode, each
        # producing (x, y) for every future step.
        self.mode_heads = nn.ModuleList(
            [
                nn.Linear(hidden_dim * 2, future_steps * 2)
                for _ in range(num_modes)
            ]
        )

        # Mode probability (softmax over the K modes).
        self.mode_prob = nn.Sequential(
            nn.Linear(hidden_dim * 2, num_modes),
            nn.Softmax(dim=-1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, in_channels, H, W) BEV feature map.

        Returns:
            Dict with ``"trajectories"``: (B, K, T, 2) and
            ``"probabilities"``: (B, K), which sums to 1 over K.
        """
        feat = self.trajectory_decoder(self.pool(bev))

        trajectories = torch.stack(
            [
                head(feat).reshape(-1, self.future_steps, 2)
                for head in self.mode_heads
            ],
            dim=1,
        )  # (B, K, T, 2)
        probs = self.mode_prob(feat)  # (B, K)

        return {"trajectories": trajectories, "probabilities": probs}


class PerceptionModule(nn.Module):
    """
    Complete perception module combining all detection heads.

    Input: BEV features from sensor fusion.
    Output: Full scene understanding including objects, lanes, occupancy,
    motion.

    Args:
        bev_channels: Number of channels of the incoming BEV feature map.
        num_object_classes: Number of classes for the object detection head.
        num_seg_classes: Number of classes for the segmentation head.
        future_steps: Number of future steps for the occupancy head.
        num_forecast_modes: Number of trajectory modes for motion forecasting.
        forecast_steps: Number of waypoints per forecast trajectory.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        num_object_classes: int = 10,
        num_seg_classes: int = 7,
        future_steps: int = 6,
        num_forecast_modes: int = 6,
        forecast_steps: int = 12,
    ):
        super().__init__()

        # Shared BEV feature refinement. Despite the attribute name, this
        # stack only applies spatial 3x3 convolutions to the single BEV map
        # it is given; no aggregation over time happens here.
        self.temporal_conv = nn.Sequential(
            *_conv_bn_relu(bev_channels, bev_channels),
            *_conv_bn_relu(bev_channels, bev_channels),
        )

        # Detection heads
        self.object_detection = BEVObjectDetectionHead(
            bev_channels, num_object_classes
        )
        self.segmentation = BEVSegmentationHead(bev_channels, num_seg_classes)
        self.occupancy = OccupancyGridHead(bev_channels, future_steps)
        self.motion_forecasting = MotionForecastingHead(
            bev_channels, num_forecast_modes, forecast_steps
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV feature map from sensor fusion

        Returns:
            Dict with all perception outputs:
            ``object_heatmap``, ``object_bbox``, ``object_velocity``,
            ``segmentation``, ``occupancy_current``, ``occupancy_future``,
            ``motion_trajectories``, ``motion_probabilities``.
        """
        # Refine BEV features; the skip connection keeps the raw input.
        bev_refined = self.temporal_conv(bev) + bev

        # Run all heads on the same refined feature map.
        detections = self.object_detection(bev_refined)
        segmentation = self.segmentation(bev_refined)
        occupancy = self.occupancy(bev_refined)
        motion = self.motion_forecasting(bev_refined)

        return {
            "object_heatmap": detections["heatmap"],
            "object_bbox": detections["bbox"],
            "object_velocity": detections["velocity"],
            "segmentation": segmentation,
            "occupancy_current": occupancy["current"],
            "occupancy_future": occupancy["future"],
            "motion_trajectories": motion["trajectories"],
            "motion_probabilities": motion["probabilities"],
        }