| """ |
| Perception Module for FSD Model. |
| Handles: |
| 1. Object Detection in BEV (vehicles, pedestrians, cyclists, etc.) |
| 2. Lane Detection and Road Segmentation |
| 3. Free Space Estimation |
| 4. Traffic Sign/Signal Recognition |
| 5. Occupancy Grid Generation |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Dict, List, Optional, Tuple |
|
|
|
|
class BEVObjectDetectionHead(nn.Module):
    """
    Anchor-free object detection in BEV space (CenterPoint-style).

    Predicts, per BEV grid cell:
      - heatmap:  (B, num_classes, H, W) per-class center probabilities in [0, 1]
      - bbox:     (B, 6, H, W) raw box-regression channels; the class doc lists
                  (x, y, w, h, yaw) — presumably yaw uses a (sin, cos) encoding
                  to fill the 6th channel. TODO(review): confirm with the decoder.
      - velocity: (B, 2, H, W) per-cell (vx, vy) regression
    """
    def __init__(
        self,
        in_channels: int = 256,
        num_classes: int = 10,
        num_heads: int = 6,
    ):
        """
        Args:
            in_channels: channel count of the incoming BEV feature map.
            num_classes: number of object categories in the heatmap.
            num_heads:   unused; retained only for backward compatibility
                         with existing callers.
        """
        super().__init__()
        self.num_classes = num_classes

        # Shared trunk refining BEV features before the task-specific heads.
        self.shared_conv = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        # Per-class center heatmap (logits; sigmoid applied in forward()).
        self.heatmap_head = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, num_classes, 1),
        )
        # Focal-loss prior initialization (CenterNet/CenterPoint): bias the
        # final conv so the initial foreground probability is ~0.1, which
        # prevents the focal loss from being swamped by negatives early on.
        nn.init.constant_(self.heatmap_head[-1].bias, -2.19)

        # Box parameter regression (6 raw channels, no activation).
        self.bbox_head = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 6, 1),
        )

        # Velocity regression (vx, vy), no activation.
        self.velocity_head = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 2, 1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV feature map.
        Returns:
            Dict with "heatmap" (sigmoid probabilities), "bbox" and
            "velocity" (raw regression outputs), all spatially aligned
            with the input.
        """
        feat = self.shared_conv(bev)

        heatmap = torch.sigmoid(self.heatmap_head(feat))
        bbox = self.bbox_head(feat)
        velocity = self.velocity_head(feat)

        return {
            "heatmap": heatmap,
            "bbox": bbox,
            "velocity": velocity,
        }
|
|
|
|
class BEVSegmentationHead(nn.Module):
    """
    Dense semantic segmentation over the BEV grid.

    Covers static road-layout classes such as drivable area, lane lines,
    crosswalks and sidewalks.
    """
    def __init__(
        self,
        in_channels: int = 256,
        num_seg_classes: int = 7,
    ):
        super().__init__()

        self.num_seg_classes = num_seg_classes

        # Two conv-BN-ReLU stages followed by a 1x1 classifier conv.
        stages = []
        for cin, cout in ((in_channels, 128), (128, 64)):
            stages += [
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
            ]
        stages.append(nn.Conv2d(64, num_seg_classes, 1))
        self.decoder = nn.Sequential(*stages)

    def forward(self, bev: torch.Tensor) -> torch.Tensor:
        """Return per-class logits of shape (B, num_seg_classes, H, W)."""
        return self.decoder(bev)
|
|
|
|
class OccupancyGridHead(nn.Module):
    """
    Per-cell occupancy estimation on the BEV grid.

    Produces a binary occupied/free probability for the current frame plus a
    short-horizon forecast covering the next `future_steps` frames.
    """
    def __init__(
        self,
        in_channels: int = 256,
        future_steps: int = 6,
    ):
        super().__init__()
        self.future_steps = future_steps

        # Single-channel head for present-time occupancy (logits).
        self.current_occ = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
        )

        # Deeper head emitting one logit channel per forecast step.
        self.future_occ = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, future_steps, 1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV features.
        Returns:
            "current": (B, 1, H, W) occupancy probability,
            "future":  (B, future_steps, H, W) per-step probabilities.
        """
        logits = {
            "current": self.current_occ(bev),
            "future": self.future_occ(bev),
        }
        # Both heads output logits; squash to probabilities here.
        return {name: torch.sigmoid(x) for name, x in logits.items()}
|
|
|
|
class MotionForecastingHead(nn.Module):
    """
    Multi-modal motion forecasting from pooled BEV context.

    Pools the BEV feature map to a fixed spatial grid, decodes it with an
    MLP, and predicts `num_modes` candidate trajectories (each with
    `future_steps` (x, y) waypoints) plus a categorical probability over
    the modes.

    NOTE(review): this head predicts from the whole BEV map, not per detected
    agent as the original docstring suggested — confirm intended usage with
    callers.
    """
    def __init__(
        self,
        in_channels: int = 256,
        num_modes: int = 6,
        future_steps: int = 12,
        hidden_dim: int = 128,
        pool_size: int = 8,
    ):
        """
        Args:
            in_channels:  channels of the incoming BEV feature map.
            num_modes:    number of trajectory hypotheses (K).
            future_steps: waypoints per trajectory.
            hidden_dim:   base width of the MLP decoder.
            pool_size:    side length of the pooled spatial grid. Previously
                          hard-coded to 8 in two coupled places; the default
                          preserves the original behavior.
        """
        super().__init__()
        self.num_modes = num_modes
        self.future_steps = future_steps

        # Collapse the (H, W) extent to pool_size x pool_size so the
        # flattened feature size is independent of input resolution.
        self.pool = nn.AdaptiveAvgPool2d(pool_size)

        self.trajectory_decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * pool_size * pool_size, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),
        )

        # One regression head per mode; each emits future_steps (x, y) pairs.
        self.mode_heads = nn.ModuleList([
            nn.Linear(hidden_dim * 2, future_steps * 2)
            for _ in range(num_modes)
        ])

        # Confidence over modes (already softmax-normalized).
        self.mode_prob = nn.Sequential(
            nn.Linear(hidden_dim * 2, num_modes),
            nn.Softmax(dim=-1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV feature map.
        Returns:
            "trajectories":  (B, num_modes, future_steps, 2)
            "probabilities": (B, num_modes), each row summing to 1.
        """
        feat = self.pool(bev)
        feat = self.trajectory_decoder(feat)

        trajectories = []
        for head in self.mode_heads:
            traj = head(feat).reshape(-1, self.future_steps, 2)
            trajectories.append(traj)

        # Stack modes: (B, num_modes, future_steps, 2).
        trajectories = torch.stack(trajectories, dim=1)
        probs = self.mode_prob(feat)

        return {"trajectories": trajectories, "probabilities": probs}
|
|
|
|
class PerceptionModule(nn.Module):
    """
    Top-level perception stack.

    Refines fused BEV features with a residual conv block, then fans out to
    four task heads — object detection, semantic segmentation, occupancy
    (current + future) and multi-modal motion forecasting — and flattens all
    head outputs into a single dict.
    """
    def __init__(
        self,
        bev_channels: int = 256,
        num_object_classes: int = 10,
        num_seg_classes: int = 7,
        future_steps: int = 6,
        num_forecast_modes: int = 6,
        forecast_steps: int = 12,
    ):
        super().__init__()

        # Residual refinement block applied before the task heads.
        self.temporal_conv = nn.Sequential(
            nn.Conv2d(bev_channels, bev_channels, 3, padding=1),
            nn.BatchNorm2d(bev_channels),
            nn.ReLU(),
            nn.Conv2d(bev_channels, bev_channels, 3, padding=1),
            nn.BatchNorm2d(bev_channels),
            nn.ReLU(),
        )

        # Task-specific heads, all consuming the same refined BEV map.
        self.object_detection = BEVObjectDetectionHead(
            bev_channels, num_object_classes
        )
        self.segmentation = BEVSegmentationHead(
            bev_channels, num_seg_classes
        )
        self.occupancy = OccupancyGridHead(
            bev_channels, future_steps
        )
        self.motion_forecasting = MotionForecastingHead(
            bev_channels, num_forecast_modes, forecast_steps
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV feature map from sensor fusion.
        Returns:
            Flat dict with all perception outputs: object heatmap/bbox/
            velocity, segmentation logits, occupancy maps, and forecast
            trajectories with their mode probabilities.
        """
        # Residual connection keeps the raw fused features available.
        refined = bev + self.temporal_conv(bev)

        detections = self.object_detection(refined)
        seg_logits = self.segmentation(refined)
        occ = self.occupancy(refined)
        motion = self.motion_forecasting(refined)

        outputs: Dict[str, torch.Tensor] = {
            "object_heatmap": detections["heatmap"],
            "object_bbox": detections["bbox"],
            "object_velocity": detections["velocity"],
            "segmentation": seg_logits,
            "occupancy_current": occ["current"],
            "occupancy_future": occ["future"],
            "motion_trajectories": motion["trajectories"],
            "motion_probabilities": motion["probabilities"],
        }
        return outputs
|
|