"""
Multi-Modal Sensor Fusion Module

Inspired by BEVFusion and GaussianFusion architectures.
Fuses camera images and ultrasonic sensor data into a unified
Bird's Eye View (BEV) representation.
"""

import math
from typing import List, Optional, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import SensorConfig, CameraSensorConfig, UltrasonicSensorConfig


class CameraBackbone(nn.Module):
    """
    Lightweight CNN backbone for camera feature extraction.

    Extracts multi-scale features from each camera image and merges them with
    a small top-down Feature Pyramid Network (FPN).
    Architecture inspired by EfficientNet-lite / ResNet-18 style blocks.

    Args:
        in_channels: Number of channels of the input image.
        base_channels: Width of the first stage; later stages use 2x, 4x and
            8x this width, and every FPN output has ``4 * base_channels``
            channels.
    """

    def __init__(self, in_channels: int = 3, base_channels: int = 64):
        super().__init__()
        self.base_channels = base_channels

        # Stage 1: initial convolution + max-pool (overall stride 4)
        self.stage1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Stage 2: feature extraction blocks (overall stride 8)
        self.stage2 = self._make_stage(base_channels, base_channels * 2, num_blocks=2, stride=2)

        # Stage 3 (overall stride 16)
        self.stage3 = self._make_stage(base_channels * 2, base_channels * 4, num_blocks=2, stride=2)

        # Stage 4: deepest features (overall stride 32)
        self.stage4 = self._make_stage(base_channels * 4, base_channels * 8, num_blocks=2, stride=2)

        # Feature Pyramid Network (FPN) for multi-scale fusion.
        # Lateral 1x1 convs bring every stage to a common width ...
        self.fpn_lateral4 = nn.Conv2d(base_channels * 8, base_channels * 4, 1)
        self.fpn_lateral3 = nn.Conv2d(base_channels * 4, base_channels * 4, 1)
        self.fpn_lateral2 = nn.Conv2d(base_channels * 2, base_channels * 4, 1)

        # ... and 3x3 output convs smooth the merged maps.
        self.fpn_output4 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)
        self.fpn_output3 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)
        self.fpn_output2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        """Build a stage of ``num_blocks`` ResBlocks; only the first one strides."""
        layers = [ResBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(ResBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W) camera image tensor

        Returns:
            Dict with multi-scale features under the keys ``"p2"``, ``"p3"``
            and ``"p4"`` (strides of roughly 8, 16 and 32 respectively), each
            with ``4 * base_channels`` channels.
        """
        c1 = self.stage1(x)   # (B, base,   ~H/4,  ~W/4)
        c2 = self.stage2(c1)  # (B, base*2, ~H/8,  ~W/8)
        c3 = self.stage3(c2)  # (B, base*4, ~H/16, ~W/16)
        c4 = self.stage4(c3)  # (B, base*8, ~H/32, ~W/32)

        # FPN top-down pathway: upsample the coarser map to the finer map's
        # exact size (robust to odd input sizes) and add the lateral feature.
        p4 = self.fpn_lateral4(c4)
        p3 = self.fpn_lateral3(c3) + F.interpolate(
            p4, size=c3.shape[2:], mode='bilinear', align_corners=False
        )
        p2 = self.fpn_lateral2(c2) + F.interpolate(
            p3, size=c2.shape[2:], mode='bilinear', align_corners=False
        )

        p4 = self.fpn_output4(p4)
        p3 = self.fpn_output3(p3)
        p2 = self.fpn_output2(p2)

        return {"p2": p2, "p3": p3, "p4": p4}


class ResBlock(nn.Module):
    """
    Residual block with optional downsampling.

    Two 3x3 conv + batch-norm layers with a skip connection. When the stride
    is not 1 or the channel count changes, the skip path is a strided 1x1
    conv + batch-norm so both branches have matching shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity skip by default; projection skip when shapes differ.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the skip path, then ReLU."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)


class UltrasonicEncoder(nn.Module):
    """
    Encodes ultrasonic sensor readings into a spatial feature representation.

    Each sensor's distance reading and placement are embedded separately,
    combined per sensor, flattened across sensors and projected by an MLP to
    a coarse ``(bev_size // 10)``-sized grid, which is then upsampled 4x by
    two transposed convolutions.

    Args:
        num_sensors: Number of ultrasonic sensors (fixes the MLP input width).
        hidden_dim: Per-sensor feature width; the output map has
            ``hidden_dim // 4`` channels.
        bev_size: Target BEV size in cells; must be at least 10 so the coarse
            grid is non-empty.

    Raises:
        ValueError: If ``bev_size`` is smaller than 10.
    """

    def __init__(self, num_sensors: int, hidden_dim: int = 128, bev_size: int = 200):
        super().__init__()
        coarse_size = bev_size // 10
        if coarse_size < 1:
            # A zero-sized coarse grid would produce an empty projection layer.
            raise ValueError(
                f"bev_size must be >= 10 to build the coarse BEV grid, got {bev_size}"
            )

        self.num_sensors = num_sensors
        self.hidden_dim = hidden_dim
        self.bev_size = bev_size

        dist_dim = 64
        place_dim = 64

        # Per-sensor distance encoding
        self.distance_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, dist_dim),
            nn.ReLU(),
        )

        # Sensor placement encoding (x, y, z, yaw, pitch, roll)
        self.placement_encoder = nn.Sequential(
            nn.Linear(6, 32),
            nn.ReLU(),
            nn.Linear(32, place_dim),
            nn.ReLU(),
        )

        # Combined sensor feature (input = concatenated distance + placement)
        self.sensor_fusion = nn.Sequential(
            nn.Linear(dist_dim + place_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Project all sensor features to a coarse BEV grid
        self.bev_projection = nn.Sequential(
            nn.Linear(num_sensors * hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim * coarse_size * coarse_size),
        )

        # Upsample: each ConvTranspose2d(kernel=4, stride=2, padding=1)
        # exactly doubles the spatial size, so the result is 4x the coarse grid.
        self.bev_upsample = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim // 4),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 4, hidden_dim // 4, 3, padding=1),
            nn.BatchNorm2d(hidden_dim // 4),
            nn.ReLU(),
        )

    def forward(self, distances: torch.Tensor, placements: torch.Tensor) -> torch.Tensor:
        """
        Args:
            distances: (B, num_sensors, 1) - distance readings per sensor
            placements: (B, num_sensors, 6) - sensor positions (x,y,z,yaw,pitch,roll)

        Returns:
            bev_features: (B, hidden_dim // 4, S, S) BEV feature map with
            ``S = 4 * (bev_size // 10)`` (e.g. 80 for ``bev_size=200``).
            Callers that need ``bev_size`` must resize the result.
        """
        B = distances.shape[0]

        # Encode each sensor's distance
        dist_feat = self.distance_encoder(distances)      # (B, N, 64)

        # Encode each sensor's position
        place_feat = self.placement_encoder(placements)   # (B, N, 64)

        # Combine distance + placement
        combined = torch.cat([dist_feat, place_feat], dim=-1)  # (B, N, 128)
        sensor_feat = self.sensor_fusion(combined)              # (B, N, hidden_dim)

        # Flatten all sensors and project to the coarse BEV grid
        flat = sensor_feat.reshape(B, -1)        # (B, N * hidden_dim)
        bev_flat = self.bev_projection(flat)     # (B, hidden_dim * small * small)

        small_size = self.bev_size // 10
        bev = bev_flat.reshape(B, self.hidden_dim, small_size, small_size)

        # Upsample to the larger BEV resolution
        bev = self.bev_upsample(bev)

        return bev


class ViewTransformer(nn.Module):
    """
    Transforms camera perspective features into BEV space.

    Modelled on the Lift-Splat-Shoot (LSS) idea of predicting a per-pixel
    depth distribution, but the geometric "splat" is NOT implemented here:
    features are weighted by the depth distribution averaged over bins,
    resized to the BEV grid with adaptive average pooling, and averaged over
    cameras. Camera intrinsics/extrinsics are accepted but currently unused.

    NOTE: because the depth distribution is a softmax over bins, its mean over
    bins is the constant ``1 / num_depth_bins``; the depth head therefore does
    not yet influence the output. It is kept so a real 3D-to-BEV projection
    can be dropped in without changing the module's parameters.

    Args:
        in_channels: Channels of the incoming camera features.
        num_depth_bins: Number of discrete depth bins.
        depth_min: Depth of the first bin.
        depth_max: Depth of the last bin.
        bev_size: Output BEV grid size in cells.
        bev_resolution: Meters per BEV cell (stored, not used in ``forward``).
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_depth_bins: int = 64,
        depth_min: float = 1.0,
        depth_max: float = 50.0,
        bev_size: int = 200,
        bev_resolution: float = 0.25,  # meters per pixel
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_depth_bins = num_depth_bins
        self.bev_size = bev_size
        self.bev_resolution = bev_resolution

        # Depth distribution prediction
        self.depth_net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, num_depth_bins, 1),
        )

        # Feature compression for BEV
        self.feature_net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
        )

        # Depth bin centers (buffer: follows the module across devices and is
        # saved in the state dict, but is not trained)
        self.register_buffer(
            'depth_bins',
            torch.linspace(depth_min, depth_max, num_depth_bins)
        )

        # BEV encoder after scattering
        self.bev_encoder = nn.Sequential(
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
        )

    def forward(
        self,
        camera_features: torch.Tensor,
        intrinsics: Optional[torch.Tensor],
        extrinsics: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            camera_features: (B, N_cams, C, H, W) multi-camera features
            intrinsics: (B, N_cams, 3, 3) camera intrinsic matrices
                (currently unused; may be None)
            extrinsics: (B, N_cams, 4, 4) camera-to-ego transformation
                matrices (currently unused; may be None)

        Returns:
            bev: (B, C//2, bev_size, bev_size) BEV feature map
        """
        B, N, C, H, W = camera_features.shape

        # Reshape for batch processing
        features = camera_features.reshape(B * N, C, H, W)

        # Predict depth distribution
        depth_logits = self.depth_net(features)       # (B*N, D, H, W)
        depth_probs = F.softmax(depth_logits, dim=1)  # (B*N, D, H, W)

        # Compress features
        feat = self.feature_net(features)             # (B*N, C//2, H, W)
        C_out = feat.shape[1]

        # Simplified BEV pooling. A full implementation would lift features
        # into a (C_out, D, H, W) volume and splat it into BEV using the
        # camera geometry. Here the volume is only averaged over depth, and
        #     mean_d(feat * p_d) == feat * mean_d(p_d),
        # so the depth weight is applied directly instead of materialising a
        # D-times larger 5D tensor.
        depth_weight = depth_probs.mean(dim=1, keepdim=True)  # (B*N, 1, H, W)
        bev_per_cam = feat * depth_weight                     # (B*N, C_out, H, W)

        # Adaptive pool each camera view to BEV size
        bev_per_cam = F.adaptive_avg_pool2d(bev_per_cam, (self.bev_size, self.bev_size))
        bev_per_cam = bev_per_cam.reshape(B, N, C_out, self.bev_size, self.bev_size)

        # Fuse all camera BEV views (mean fusion)
        bev = bev_per_cam.mean(dim=1)  # (B, C_out, bev_size, bev_size)

        # Refine BEV features
        bev = self.bev_encoder(bev)

        return bev


class MultiModalSensorFusion(nn.Module):
    """
    Main sensor fusion module that combines:
    1. Multi-camera visual features (via CNN backbone + View Transformer → BEV)
    2. Ultrasonic proximity features (via encoder → BEV)

    Output: Unified BEV representation for downstream perception/planning.
    Fully configurable for any number/placement of sensors.

    The fusion convolution is sized for every modality enabled in
    ``sensor_config``. If an enabled modality is not supplied to ``forward``
    (e.g. sensor dropout), its BEV slot is filled with zeros so the channel
    layout stays fixed.

    Args:
        sensor_config: Object exposing ``num_cameras`` and ``num_ultrasonics``.
        bev_size: BEV grid size in cells.
        bev_resolution: Meters per BEV cell.
        camera_channels: Channels of the input camera images.
        backbone_base: Base width of the camera backbone.
        bev_feature_dim: Channels of the fused BEV output.
    """

    def __init__(
        self,
        sensor_config: SensorConfig,
        bev_size: int = 200,
        bev_resolution: float = 0.25,
        camera_channels: int = 3,
        backbone_base: int = 64,
        bev_feature_dim: int = 256,
    ):
        super().__init__()
        self.sensor_config = sensor_config
        self.bev_size = bev_size
        self.bev_resolution = bev_resolution
        self.bev_feature_dim = bev_feature_dim

        num_cameras = sensor_config.num_cameras
        num_ultrasonics = sensor_config.num_ultrasonics

        # Camera processing pipeline
        if num_cameras > 0:
            fpn_channels = backbone_base * 4  # FPN output channels
            self.camera_backbone = CameraBackbone(camera_channels, backbone_base)
            self.view_transformer = ViewTransformer(
                in_channels=fpn_channels,
                bev_size=bev_size,
                bev_resolution=bev_resolution,
            )
            # The view transformer halves its input channels.
            camera_bev_channels = fpn_channels // 2
        else:
            self.camera_backbone = None
            self.view_transformer = None
            camera_bev_channels = 0

        # Ultrasonic processing pipeline
        if num_ultrasonics > 0:
            us_hidden_dim = 128
            self.ultrasonic_encoder = UltrasonicEncoder(
                num_sensors=num_ultrasonics,
                hidden_dim=us_hidden_dim,
                bev_size=bev_size,
            )
            # The ultrasonic encoder outputs hidden_dim // 4 channels.
            us_bev_channels = us_hidden_dim // 4
        else:
            self.ultrasonic_encoder = None
            us_bev_channels = 0

        # Remembered so forward() can zero-fill a missing modality.
        self.camera_bev_channels = camera_bev_channels
        self.us_bev_channels = us_bev_channels

        # Adaptive fusion of different sensor modalities
        total_bev_channels = camera_bev_channels + us_bev_channels

        self.fusion_conv = nn.Sequential(
            nn.Conv2d(total_bev_channels, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
        )

        # Channel attention for adaptive sensor weighting
        # (squeeze-and-excitation style: global pool -> bottleneck MLP -> gate)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(bev_feature_dim, bev_feature_dim // 4),
            nn.ReLU(),
            nn.Linear(bev_feature_dim // 4, bev_feature_dim),
            nn.Sigmoid(),
        )

        # Final BEV refinement with residual
        self.bev_refine = nn.Sequential(
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
        )

    def forward(
        self,
        camera_images: Optional[torch.Tensor] = None,
        camera_intrinsics: Optional[torch.Tensor] = None,
        camera_extrinsics: Optional[torch.Tensor] = None,
        ultrasonic_distances: Optional[torch.Tensor] = None,
        ultrasonic_placements: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            camera_images: (B, N_cams, 3, H, W)
            camera_intrinsics: (B, N_cams, 3, 3)
            camera_extrinsics: (B, N_cams, 4, 4)
            ultrasonic_distances: (B, N_us, 1)
            ultrasonic_placements: (B, N_us, 6)

        Returns:
            bev_features: (B, bev_feature_dim, bev_size, bev_size)

        Raises:
            ValueError: If no usable sensor data is provided, or if
                ``ultrasonic_distances`` is given without
                ``ultrasonic_placements``.
        """
        bev_shape = (self.bev_size, self.bev_size)
        cam_bev = None
        us_bev = None

        # Process cameras
        if self.camera_backbone is not None and camera_images is not None:
            B, N, C, H, W = camera_images.shape

            # Extract features for all cameras in one batched pass
            imgs = camera_images.reshape(B * N, C, H, W)
            multi_scale = self.camera_backbone(imgs)

            # Use p2 (highest resolution FPN output) for view transformation
            cam_feat = multi_scale["p2"]
            _, Cf, Hf, Wf = cam_feat.shape
            cam_feat = cam_feat.reshape(B, N, Cf, Hf, Wf)

            cam_bev = self.view_transformer(
                cam_feat, camera_intrinsics, camera_extrinsics
            )

        # Process ultrasonics
        if self.ultrasonic_encoder is not None and ultrasonic_distances is not None:
            if ultrasonic_placements is None:
                raise ValueError(
                    "ultrasonic_placements is required when "
                    "ultrasonic_distances is provided"
                )
            us_bev = self.ultrasonic_encoder(ultrasonic_distances, ultrasonic_placements)
            # The encoder output is smaller than the BEV grid; resize to match.
            us_bev = F.adaptive_avg_pool2d(us_bev, bev_shape)

        if cam_bev is None and us_bev is None:
            raise ValueError("No sensor data provided!")

        # fusion_conv expects the channels of every *configured* modality.
        # Zero-fill a configured modality that was not supplied, matching the
        # batch size, dtype and device of the modality that is present.
        reference = cam_bev if cam_bev is not None else us_bev
        batch = reference.shape[0]
        if cam_bev is None and self.camera_backbone is not None:
            cam_bev = reference.new_zeros(batch, self.camera_bev_channels, *bev_shape)
        if us_bev is None and self.ultrasonic_encoder is not None:
            us_bev = reference.new_zeros(batch, self.us_bev_channels, *bev_shape)

        # Concatenate all BEV features (camera channels first, then ultrasonic)
        bev_parts = [part for part in (cam_bev, us_bev) if part is not None]
        bev_concat = torch.cat(bev_parts, dim=1)

        # Fuse modalities
        bev = self.fusion_conv(bev_concat)

        # Channel attention: per-channel gate in (0, 1), broadcast over H and W
        attn = self.channel_attention(bev).unsqueeze(-1).unsqueeze(-1)
        bev = bev * attn

        # Residual refinement
        bev = bev + self.bev_refine(bev)
        bev = F.relu(bev)

        return bev