FSD-Level5-CoT / fsd_model /sensor_fusion.py
Reality123b's picture
Add sensor_fusion.py
cd793e6 verified
"""
Multi-Modal Sensor Fusion Module
Inspired by BEVFusion and GaussianFusion architectures.
Fuses camera images and ultrasonic sensor data into a unified
Bird's Eye View (BEV) representation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional, Dict, Tuple
from .config import SensorConfig, CameraSensorConfig, UltrasonicSensorConfig
class CameraBackbone(nn.Module):
    """
    Compact convolutional encoder applied to each camera image.

    A strided stem is followed by three residual stages (ResNet-18 /
    EfficientNet-lite flavoured). The three residual-stage outputs are merged
    by a top-down feature pyramid so that every returned map has
    ``4 * base_channels`` channels.
    """

    def __init__(self, in_channels: int = 3, base_channels: int = 64):
        super().__init__()
        self.base_channels = base_channels
        width = base_channels
        pyramid_width = 4 * width

        # Stem: 7x7 stride-2 convolution plus stride-2 max pooling (overall /4).
        stem_layers = [
            nn.Conv2d(in_channels, width, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        self.stage1 = nn.Sequential(*stem_layers)

        # Residual stages: each one halves the resolution and doubles the width.
        self.stage2 = self._make_stage(width, 2 * width, num_blocks=2, stride=2)
        self.stage3 = self._make_stage(2 * width, 4 * width, num_blocks=2, stride=2)
        self.stage4 = self._make_stage(4 * width, 8 * width, num_blocks=2, stride=2)

        # FPN lateral 1x1 convs bring every stage to the common pyramid width.
        self.fpn_lateral4 = nn.Conv2d(8 * width, pyramid_width, 1)
        self.fpn_lateral3 = nn.Conv2d(4 * width, pyramid_width, 1)
        self.fpn_lateral2 = nn.Conv2d(2 * width, pyramid_width, 1)

        # FPN 3x3 output convs smooth the merged maps.
        self.fpn_output4 = nn.Conv2d(pyramid_width, pyramid_width, 3, padding=1)
        self.fpn_output3 = nn.Conv2d(pyramid_width, pyramid_width, 3, padding=1)
        self.fpn_output2 = nn.Conv2d(pyramid_width, pyramid_width, 3, padding=1)

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage: a (possibly strided) block, then stride-1 blocks."""
        head = [ResBlock(in_channels, out_channels, stride)]
        tail = [ResBlock(out_channels, out_channels, 1) for _ in range(num_blocks - 1)]
        return nn.Sequential(*(head + tail))

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W) camera image tensor
        Returns:
            Dict with keys "p2", "p3", "p4": pyramid maps at 1/8, 1/16 and
            1/32 of the input resolution, each with 4 * base_channels channels.
        """
        def resize_like(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # Bilinear resize of `source` to the spatial size of `target`.
            return F.interpolate(
                source, size=target.shape[2:], mode='bilinear', align_corners=False
            )

        # Bottom-up pathway (channel counts shown for base_channels=64).
        stem = self.stage1(x)          # (B, 64, H/4, W/4)
        feat8 = self.stage2(stem)      # (B, 128, H/8, W/8)
        feat16 = self.stage3(feat8)    # (B, 256, H/16, W/16)
        feat32 = self.stage4(feat16)   # (B, 512, H/32, W/32)

        # Top-down pathway: coarser levels are upsampled and added to finer ones.
        top = self.fpn_lateral4(feat32)
        mid = self.fpn_lateral3(feat16) + resize_like(top, feat16)
        low = self.fpn_lateral2(feat8) + resize_like(mid, feat8)

        top = self.fpn_output4(top)
        mid = self.fpn_output3(mid)
        low = self.fpn_output2(low)
        return {"p2": low, "p3": mid, "p4": top}
class ResBlock(nn.Module):
    """
    Two-convolution residual unit.

    The skip path is the identity unless the block changes resolution
    (``stride != 1``) or channel count, in which case a 1x1 convolution plus
    batch norm projects the input to the output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path: 3x3 conv (optionally strided) -> BN -> ReLU -> 3x3 conv -> BN.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        merged = hidden + self.shortcut(x)
        return F.relu(merged)
class UltrasonicEncoder(nn.Module):
    """
    Encodes ultrasonic sensor readings into a spatial feature map.

    Every sensor contributes a distance reading and a 6-DoF placement. Both are
    embedded per sensor, concatenated, and the embeddings of all sensors are
    projected jointly onto a coarse grid of side ``bev_size // 10``. Two
    stride-2 transposed convolutions then upsample that grid by a factor of 4.

    Attributes:
        out_channels: channel count of the returned map (``hidden_dim // 4``).
        out_size: side length of the returned map (``4 * (bev_size // 10)``).

    Raises:
        ValueError: if ``num_sensors < 1``, ``hidden_dim < 4`` or
            ``bev_size < 10`` (each would create a zero-sized layer).
    """

    def __init__(self, num_sensors: int, hidden_dim: int = 128, bev_size: int = 200):
        super().__init__()
        if num_sensors < 1:
            raise ValueError(f"num_sensors must be >= 1, got {num_sensors}")
        if hidden_dim < 4:
            raise ValueError(f"hidden_dim must be >= 4, got {hidden_dim}")
        if bev_size < 10:
            raise ValueError(f"bev_size must be >= 10, got {bev_size}")

        self.num_sensors = num_sensors
        self.hidden_dim = hidden_dim
        self.bev_size = bev_size
        # Side of the coarse grid the sensors are projected onto.
        self.coarse_size = bev_size // 10
        # Shape of the map returned by forward(), exposed for downstream wiring.
        self.out_channels = hidden_dim // 4
        self.out_size = 4 * self.coarse_size

        # Per-sensor distance encoding
        self.distance_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
        )
        # Sensor placement encoding (x, y, z, yaw, pitch, roll)
        self.placement_encoder = nn.Sequential(
            nn.Linear(6, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
        )
        # Combined sensor feature (64 distance + 64 placement features in)
        self.sensor_fusion = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Project all sensor features jointly to the coarse BEV grid
        self.bev_projection = nn.Sequential(
            nn.Linear(num_sensors * hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim * self.coarse_size * self.coarse_size),
        )
        # Upsample x4: each ConvTranspose2d(k=4, s=2, p=1) doubles the grid side
        self.bev_upsample = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim // 4),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 4, hidden_dim // 4, 3, padding=1),
            nn.BatchNorm2d(hidden_dim // 4),
            nn.ReLU(),
        )

    def forward(self, distances: torch.Tensor, placements: torch.Tensor) -> torch.Tensor:
        """
        Args:
            distances: (B, num_sensors, 1) distance reading per sensor;
                (B, num_sensors) is also accepted and gets a trailing axis.
            placements: (B, num_sensors, 6) sensor poses (x,y,z,yaw,pitch,roll)
        Returns:
            bev_features: (B, hidden_dim // 4, S, S) with
                S = 4 * (bev_size // 10), i.e. 80 for the default bev_size=200.
        Raises:
            ValueError: if the inputs do not match the shapes above or the
                sensor count differs from ``num_sensors``.
        """
        if distances.dim() == 2:
            distances = distances.unsqueeze(-1)
        if (
            distances.dim() != 3
            or distances.shape[1] != self.num_sensors
            or distances.shape[2] != 1
        ):
            raise ValueError(
                f"distances must have shape (B, {self.num_sensors}, 1), "
                f"got {tuple(distances.shape)}"
            )
        if (
            placements.dim() != 3
            or tuple(placements.shape[:2]) != tuple(distances.shape[:2])
            or placements.shape[2] != 6
        ):
            raise ValueError(
                f"placements must have shape ({distances.shape[0]}, "
                f"{self.num_sensors}, 6), got {tuple(placements.shape)}"
            )

        B = distances.shape[0]
        # Encode each sensor's distance and pose independently
        dist_feat = self.distance_encoder(distances)     # (B, N, 64)
        place_feat = self.placement_encoder(placements)  # (B, N, 64)
        # Combine distance + placement
        combined = torch.cat([dist_feat, place_feat], dim=-1)  # (B, N, 128)
        sensor_feat = self.sensor_fusion(combined)              # (B, N, hidden_dim)
        # Flatten all sensors and project to the coarse BEV grid
        flat = sensor_feat.reshape(B, -1)                       # (B, N * hidden_dim)
        bev_flat = self.bev_projection(flat)
        bev = bev_flat.reshape(B, self.hidden_dim, self.coarse_size, self.coarse_size)
        # Upsample to the output resolution
        return self.bev_upsample(bev)
class ViewTransformer(nn.Module):
    """
    Transforms camera perspective features into BEV space.

    The layer layout follows Lift-Splat-Shoot (LSS): a depth head predicts a
    per-pixel distribution over depth bins and a feature head compresses the
    image features. The geometric "splat" step is NOT implemented yet; a
    simplified pooling is used instead (see ``forward``).

    NOTE: ``depth_bins`` is registered as a buffer but is not read by
    ``forward`` in this simplified version.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_depth_bins: int = 64,
        depth_min: float = 1.0,
        depth_max: float = 50.0,
        bev_size: int = 200,
        bev_resolution: float = 0.25,  # meters per pixel
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_depth_bins = num_depth_bins
        self.bev_size = bev_size
        self.bev_resolution = bev_resolution
        # Depth distribution prediction
        self.depth_net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, num_depth_bins, 1),
        )
        # Feature compression for BEV
        self.feature_net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
        )
        # Depth bin centres, evenly spaced between depth_min and depth_max
        self.register_buffer(
            'depth_bins',
            torch.linspace(depth_min, depth_max, num_depth_bins)
        )
        # BEV encoder after pooling
        self.bev_encoder = nn.Sequential(
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
        )

    def forward(
        self,
        camera_features: torch.Tensor,
        intrinsics: torch.Tensor,
        extrinsics: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            camera_features: (B, N_cams, C, H, W) multi-camera features
            intrinsics: (B, N_cams, 3, 3) camera intrinsic matrices.
                Currently unused by the simplified pooling.
            extrinsics: (B, N_cams, 4, 4) camera-to-ego transformation
                matrices. Currently unused by the simplified pooling.
        Returns:
            bev: (B, C//2, bev_size, bev_size) BEV feature map

        The simplified pooling is the depth-mean of the LSS outer product
        ``feat[:, :, None] * depth_probs[:, None]``. That mean equals
        ``feat * depth_probs.mean(dim=1)``, so the 5-D
        ``(B*N, C//2, D, H, W)`` volume is never materialised. Likewise the
        per-camera adaptive pooling and the camera mean are both linear, so
        cameras are averaged first and pooled once. Results match the
        volume-based formulation up to floating-point rounding.

        NOTE: because the softmax sums to 1 over depth, the depth-mean weight
        is the constant ``1 / num_depth_bins``; the depth head only becomes
        meaningful once a real 3D-to-BEV projection replaces this pooling.
        """
        B, N, C, H, W = camera_features.shape
        # Fold cameras into the batch axis for the 2-D heads
        features = camera_features.reshape(B * N, C, H, W)
        # Predict depth distribution
        depth_logits = self.depth_net(features)        # (B*N, D, H, W)
        depth_probs = F.softmax(depth_logits, dim=1)   # (B*N, D, H, W)
        # Compress features
        feat = self.feature_net(features)              # (B*N, C//2, H, W)
        C_out = feat.shape[1]
        # Depth-mean of the outer product, without building the volume.
        # depth_net stays in the autograd graph so its parameters are used.
        depth_weight = depth_probs.mean(dim=1, keepdim=True)  # (B*N, 1, H, W)
        per_cam = (feat * depth_weight).reshape(B, N, C_out, H, W)
        # Fuse all camera views (mean fusion), then resample once to BEV size
        fused = per_cam.mean(dim=1)                    # (B, C_out, H, W)
        bev = F.adaptive_avg_pool2d(fused, (self.bev_size, self.bev_size))
        # Refine BEV features
        return self.bev_encoder(bev)
class MultiModalSensorFusion(nn.Module):
    """
    Main sensor fusion module that combines:
    1. Multi-camera visual features (CNN backbone + View Transformer -> BEV)
    2. Ultrasonic proximity features (encoder -> BEV)

    Output: unified BEV representation for downstream perception/planning.
    The set of branches is chosen from ``sensor_config.num_cameras`` and
    ``sensor_config.num_ultrasonics``; a branch is only built when its count
    is positive.

    If a configured modality is missing in a ``forward`` call, its channels
    are filled with zeros so the fusion layers always see the channel count
    they were built for.

    Raises:
        ValueError: if the config defines neither cameras nor ultrasonics.
    """

    def __init__(
        self,
        sensor_config: SensorConfig,
        bev_size: int = 200,
        bev_resolution: float = 0.25,
        camera_channels: int = 3,
        backbone_base: int = 64,
        bev_feature_dim: int = 256,
    ):
        super().__init__()
        self.sensor_config = sensor_config
        self.bev_size = bev_size
        self.bev_resolution = bev_resolution
        self.bev_feature_dim = bev_feature_dim
        num_cameras = sensor_config.num_cameras
        num_ultrasonics = sensor_config.num_ultrasonics
        if num_cameras <= 0 and num_ultrasonics <= 0:
            # Without any branch the fusion conv would have zero input channels.
            raise ValueError(
                "sensor_config must define at least one camera or ultrasonic sensor"
            )

        # Camera processing pipeline
        if num_cameras > 0:
            self.camera_backbone = CameraBackbone(camera_channels, backbone_base)
            self.view_transformer = ViewTransformer(
                in_channels=backbone_base * 4,  # FPN output channels
                bev_size=bev_size,
                bev_resolution=bev_resolution,
            )
            # The view transformer halves its input channel count.
            camera_bev_channels = backbone_base * 2
        else:
            self.camera_backbone = None
            self.view_transformer = None
            camera_bev_channels = 0

        # Ultrasonic processing pipeline
        if num_ultrasonics > 0:
            us_hidden_dim = 128
            self.ultrasonic_encoder = UltrasonicEncoder(
                num_sensors=num_ultrasonics,
                hidden_dim=us_hidden_dim,
                bev_size=bev_size,
            )
            # The encoder emits hidden_dim // 4 channels; derive the count
            # rather than hard-coding it so the two cannot drift apart.
            us_bev_channels = us_hidden_dim // 4
        else:
            self.ultrasonic_encoder = None
            us_bev_channels = 0

        # Kept so forward() can zero-fill a modality that is absent at runtime.
        self.camera_bev_channels = camera_bev_channels
        self.us_bev_channels = us_bev_channels

        # Fusion of the concatenated modality maps
        total_bev_channels = camera_bev_channels + us_bev_channels
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(total_bev_channels, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
        )
        # Channel attention (squeeze-and-excite style gate over fused channels)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(bev_feature_dim, bev_feature_dim // 4),
            nn.ReLU(),
            nn.Linear(bev_feature_dim // 4, bev_feature_dim),
            nn.Sigmoid(),
        )
        # Final BEV refinement; applied as a residual in forward()
        self.bev_refine = nn.Sequential(
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
        )

    def forward(
        self,
        camera_images: Optional[torch.Tensor] = None,
        camera_intrinsics: Optional[torch.Tensor] = None,
        camera_extrinsics: Optional[torch.Tensor] = None,
        ultrasonic_distances: Optional[torch.Tensor] = None,
        ultrasonic_placements: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            camera_images: (B, N_cams, 3, H, W)
            camera_intrinsics: (B, N_cams, 3, 3)
            camera_extrinsics: (B, N_cams, 4, 4)
            ultrasonic_distances: (B, N_us, 1)
            ultrasonic_placements: (B, N_us, 6)
        Returns:
            bev_features: (B, bev_feature_dim, bev_size, bev_size)
        Raises:
            ValueError: if no usable sensor data is provided, or if
                ``ultrasonic_distances`` is given without
                ``ultrasonic_placements``.
        """
        # Process cameras
        cam_bev = None
        if self.camera_backbone is not None and camera_images is not None:
            B, N, C, H, W = camera_images.shape
            # Run the backbone on all cameras at once
            imgs = camera_images.reshape(B * N, C, H, W)
            multi_scale = self.camera_backbone(imgs)
            # Use p2 (highest resolution FPN output) for view transformation
            cam_feat = multi_scale["p2"]
            _, Cf, Hf, Wf = cam_feat.shape
            cam_feat = cam_feat.reshape(B, N, Cf, Hf, Wf)
            cam_bev = self.view_transformer(
                cam_feat, camera_intrinsics, camera_extrinsics
            )

        # Process ultrasonics
        us_bev = None
        if self.ultrasonic_encoder is not None and ultrasonic_distances is not None:
            if ultrasonic_placements is None:
                raise ValueError(
                    "ultrasonic_placements is required when "
                    "ultrasonic_distances is provided"
                )
            us_bev = self.ultrasonic_encoder(ultrasonic_distances, ultrasonic_placements)
            # Resample the encoder grid to the common BEV size
            us_bev = F.adaptive_avg_pool2d(us_bev, (self.bev_size, self.bev_size))

        if cam_bev is None and us_bev is None:
            raise ValueError("No sensor data provided!")

        # Assemble the channel stack in the fixed order [camera, ultrasonic].
        # A configured-but-absent modality is zero-filled so fusion_conv
        # receives the channel count it was constructed with.
        reference = cam_bev if cam_bev is not None else us_bev
        bev_parts = []
        for part, channels in (
            (cam_bev, self.camera_bev_channels),
            (us_bev, self.us_bev_channels),
        ):
            if part is not None:
                bev_parts.append(part)
            elif channels > 0:
                bev_parts.append(
                    reference.new_zeros(
                        (reference.shape[0], channels, self.bev_size, self.bev_size)
                    )
                )

        # Concatenate all BEV features
        bev_concat = torch.cat(bev_parts, dim=1)
        # Fuse modalities
        bev = self.fusion_conv(bev_concat)
        # Channel attention: per-channel gate in (0, 1), broadcast over space
        attn = self.channel_attention(bev).unsqueeze(-1).unsqueeze(-1)
        bev = bev * attn
        # Residual refinement
        bev = bev + self.bev_refine(bev)
        bev = F.relu(bev)
        return bev