# FSD-Level5-CoT / fsd_model / perception.py
"""
Perception Module for FSD Model.
Handles:
1. Object Detection in BEV (vehicles, pedestrians, cyclists, etc.)
2. Lane Detection and Road Segmentation
3. Free Space Estimation
4. Traffic Sign/Signal Recognition
5. Occupancy Grid Generation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
class BEVObjectDetectionHead(nn.Module):
    """Anchor-free object detection head operating on BEV features.

    Per grid cell it predicts a class heatmap (object centers), a box
    regression vector (x, y, w, h, sin(yaw), cos(yaw)) and a 2D velocity
    (vx, vy), in the spirit of CenterPoint.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_classes: int = 10,
        num_heads: int = 6,
    ):
        # NOTE(review): ``num_heads`` is accepted but never used; kept only so
        # existing call sites remain valid — confirm before removing.
        super().__init__()
        self.num_classes = num_classes

        def make_branch(out_channels: int) -> nn.Sequential:
            # Small conv branch; all three output heads share this shape.
            return nn.Sequential(
                nn.Conv2d(128, 64, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(64, out_channels, 1),
            )

        # Two conv+BN+ReLU stages shared by every prediction branch.
        self.shared_conv = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.heatmap_head = make_branch(num_classes)  # object-center heatmap
        self.bbox_head = make_branch(6)               # x, y, w, h, sin(yaw), cos(yaw)
        self.velocity_head = make_branch(2)           # vx, vy

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the head on a (B, C, H, W) BEV feature map."""
        features = self.shared_conv(bev)
        return {
            # Sigmoid turns heatmap logits into per-class center confidences.
            "heatmap": self.heatmap_head(features).sigmoid(),  # (B, num_classes, H, W)
            "bbox": self.bbox_head(features),                  # (B, 6, H, W), raw regression
            "velocity": self.velocity_head(features),          # (B, 2, H, W)
        }
class BEVSegmentationHead(nn.Module):
    """BEV semantic segmentation head.

    Emits raw per-class logits over road-layout classes:
    0 background, 1 drivable, 2 lane_line, 3 crosswalk,
    4 sidewalk, 5 stop_line, 6 road_edge.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_seg_classes: int = 7,
    ):
        super().__init__()
        self.num_seg_classes = num_seg_classes
        # Two conv+BN+ReLU stages narrowing the channels, then a 1x1 classifier.
        layers = [
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, num_seg_classes, 1),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, bev: torch.Tensor) -> torch.Tensor:
        """Map (B, C, H, W) BEV features to (B, num_seg_classes, H, W) logits."""
        return self.decoder(bev)
class OccupancyGridHead(nn.Module):
    """Occupancy prediction head for the BEV grid.

    Produces a per-cell probability of being occupied right now, plus one
    occupancy map per future timestep (coarse motion forecasting).
    """

    def __init__(
        self,
        in_channels: int = 256,
        future_steps: int = 6,
    ):
        super().__init__()
        self.future_steps = future_steps
        # Shallow branch: present-time occupancy logit, one channel per cell.
        self.current_occ = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
        )
        # Deeper branch: one output channel per predicted future step.
        self.future_occ = nn.Sequential(
            nn.Conv2d(in_channels, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, future_steps, 1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return sigmoid occupancy probabilities for now and for each future step."""
        now = self.current_occ(bev).sigmoid()    # (B, 1, H, W)
        ahead = self.future_occ(bev).sigmoid()   # (B, future_steps, H, W)
        return {"current": now, "future": ahead}
class MotionForecastingHead(nn.Module):
    """Multi-modal motion forecasting head.

    Pools the BEV map into a fixed-size global descriptor and decodes K
    candidate (x, y) trajectories over ``future_steps`` steps, together with
    a softmax probability per mode.

    NOTE(review): despite the per-agent wording elsewhere, this head emits
    one trajectory set per *scene* (global pooling), not one per detected
    object — confirm the intended granularity.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_modes: int = 6,
        future_steps: int = 12,
        hidden_dim: int = 128,
    ):
        super().__init__()
        self.num_modes = num_modes
        self.future_steps = future_steps
        # Collapse the spatial map to 8x8 so the MLP input size is
        # independent of the BEV resolution.
        self.pool = nn.AdaptiveAvgPool2d(8)
        self.trajectory_decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * 8 * 8, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.ReLU(),
        )
        # One linear regressor per mode, each emitting (x, y) for every step.
        self.mode_heads = nn.ModuleList(
            nn.Linear(hidden_dim * 2, future_steps * 2) for _ in range(num_modes)
        )
        # Normalized confidence over the K modes.
        self.mode_prob = nn.Sequential(
            nn.Linear(hidden_dim * 2, num_modes),
            nn.Softmax(dim=-1),
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Decode K trajectories and their mode probabilities from BEV features."""
        embedding = self.trajectory_decoder(self.pool(bev))
        per_mode = [
            head(embedding).view(-1, self.future_steps, 2)
            for head in self.mode_heads
        ]
        return {
            "trajectories": torch.stack(per_mode, dim=1),  # (B, K, T, 2)
            "probabilities": self.mode_prob(embedding),    # (B, K)
        }
class PerceptionModule(nn.Module):
    """Top-level perception stack over fused BEV features.

    Refines the BEV map with a residual conv block, then fans out to the
    object-detection, segmentation, occupancy and motion-forecasting heads
    and flattens their outputs into a single dict.
    """

    def __init__(
        self,
        bev_channels: int = 256,
        num_object_classes: int = 10,
        num_seg_classes: int = 7,
        future_steps: int = 6,
        num_forecast_modes: int = 6,
        forecast_steps: int = 12,
    ):
        super().__init__()
        # Channel-preserving refinement block applied before any head.
        # NOTE(review): named "temporal" but it only sees a single frame here;
        # any temporal aggregation must happen upstream — confirm.
        self.temporal_conv = nn.Sequential(
            nn.Conv2d(bev_channels, bev_channels, 3, padding=1),
            nn.BatchNorm2d(bev_channels),
            nn.ReLU(),
            nn.Conv2d(bev_channels, bev_channels, 3, padding=1),
            nn.BatchNorm2d(bev_channels),
            nn.ReLU(),
        )
        # One head per perception task, all consuming the refined BEV map.
        self.object_detection = BEVObjectDetectionHead(bev_channels, num_object_classes)
        self.segmentation = BEVSegmentationHead(bev_channels, num_seg_classes)
        self.occupancy = OccupancyGridHead(bev_channels, future_steps)
        self.motion_forecasting = MotionForecastingHead(
            bev_channels, num_forecast_modes, forecast_steps
        )

    def forward(self, bev: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            bev: (B, C, H, W) BEV feature map from sensor fusion.

        Returns:
            Flat dict with every perception output: object heatmap/boxes/
            velocity, segmentation logits, current and future occupancy,
            and multi-modal motion forecasts.
        """
        refined = bev + self.temporal_conv(bev)  # residual connection
        det = self.object_detection(refined)
        occ = self.occupancy(refined)
        motion = self.motion_forecasting(refined)
        return {
            "object_heatmap": det["heatmap"],
            "object_bbox": det["bbox"],
            "object_velocity": det["velocity"],
            "segmentation": self.segmentation(refined),
            "occupancy_current": occ["current"],
            "occupancy_future": occ["future"],
            "motion_trajectories": motion["trajectories"],
            "motion_probabilities": motion["probabilities"],
        }