Reality123b committed on
Commit cd793e6 · verified · 1 Parent(s): c5ea85f

Add sensor_fusion.py

Files changed (1)
  1. fsd_model/sensor_fusion.py +447 -0
fsd_model/sensor_fusion.py ADDED
@@ -0,0 +1,447 @@
"""
Multi-Modal Sensor Fusion Module
Inspired by BEVFusion and GaussianFusion architectures.
Fuses camera images and ultrasonic sensor data into a unified
Bird's Eye View (BEV) representation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional, Dict, Tuple

from .config import SensorConfig, CameraSensorConfig, UltrasonicSensorConfig


class CameraBackbone(nn.Module):
    """
    Lightweight CNN backbone for camera feature extraction.
    Extracts multi-scale features from each camera image.
    Architecture inspired by EfficientNet-lite / ResNet-18 style blocks.
    """

    def __init__(self, in_channels: int = 3, base_channels: int = 64):
        super().__init__()
        self.base_channels = base_channels

        # Stage 1: Initial convolution
        self.stage1 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # Stage 2: Feature extraction blocks
        self.stage2 = self._make_stage(base_channels, base_channels * 2, num_blocks=2, stride=2)

        # Stage 3
        self.stage3 = self._make_stage(base_channels * 2, base_channels * 4, num_blocks=2, stride=2)

        # Stage 4: Deepest features
        self.stage4 = self._make_stage(base_channels * 4, base_channels * 8, num_blocks=2, stride=2)

        # Feature Pyramid Network (FPN) for multi-scale fusion
        self.fpn_lateral4 = nn.Conv2d(base_channels * 8, base_channels * 4, 1)
        self.fpn_lateral3 = nn.Conv2d(base_channels * 4, base_channels * 4, 1)
        self.fpn_lateral2 = nn.Conv2d(base_channels * 2, base_channels * 4, 1)

        self.fpn_output4 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)
        self.fpn_output3 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)
        self.fpn_output2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(ResBlock(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(ResBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W) camera image tensor
        Returns:
            Dict with multi-scale features
        """
        c1 = self.stage1(x)   # (B, 64, H/4, W/4)
        c2 = self.stage2(c1)  # (B, 128, H/8, W/8)
        c3 = self.stage3(c2)  # (B, 256, H/16, W/16)
        c4 = self.stage4(c3)  # (B, 512, H/32, W/32)

        # FPN top-down pathway
        p4 = self.fpn_lateral4(c4)
        p3 = self.fpn_lateral3(c3) + F.interpolate(p4, size=c3.shape[2:], mode='bilinear', align_corners=False)
        p2 = self.fpn_lateral2(c2) + F.interpolate(p3, size=c2.shape[2:], mode='bilinear', align_corners=False)

        p4 = self.fpn_output4(p4)
        p3 = self.fpn_output3(p3)
        p2 = self.fpn_output2(p2)

        return {"p2": p2, "p3": p3, "p4": p4}

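# Example (sketch): expected FPN output shapes, assuming the default
# base_channels=64 and a hypothetical 256x704 input. Every level carries
# base_channels * 4 = 256 channels:
#
#     backbone = CameraBackbone(in_channels=3, base_channels=64)
#     feats = backbone(torch.randn(1, 3, 256, 704))
#     feats["p2"].shape  # torch.Size([1, 256, 32, 88])  -> H/8,  W/8
#     feats["p3"].shape  # torch.Size([1, 256, 16, 44])  -> H/16, W/16
#     feats["p4"].shape  # torch.Size([1, 256, 8, 22])   -> H/32, W/32
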
class ResBlock(nn.Module):
    """Residual block with optional downsampling."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)

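# Example (sketch): the 1x1 projection shortcut kicks in whenever the block
# changes resolution or width, keeping the residual sum well-formed:
#
#     ResBlock(64, 64)(torch.randn(1, 64, 56, 56)).shape     # (1, 64, 56, 56)
#     ResBlock(64, 128, 2)(torch.randn(1, 64, 56, 56)).shape  # (1, 128, 28, 28)
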
class UltrasonicEncoder(nn.Module):
    """
    Encodes ultrasonic sensor readings into a spatial feature representation.
    Each ultrasonic sensor provides a distance reading that is mapped to a
    spatial cone in BEV space.
    """

    def __init__(self, num_sensors: int, hidden_dim: int = 128, bev_size: int = 200):
        super().__init__()
        self.num_sensors = num_sensors
        self.hidden_dim = hidden_dim
        self.bev_size = bev_size

        # Per-sensor distance encoding
        self.distance_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
        )

        # Sensor placement encoding (x, y, z, yaw, pitch, roll)
        self.placement_encoder = nn.Sequential(
            nn.Linear(6, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
        )

        # Combined sensor feature
        self.sensor_fusion = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Project all sensor features to a coarse BEV grid
        self.bev_projection = nn.Sequential(
            nn.Linear(num_sensors * hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim * (bev_size // 10) * (bev_size // 10)),
        )

        # Upsample the coarse grid 4x toward full BEV resolution
        self.bev_upsample = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim // 2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim // 2, hidden_dim // 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(hidden_dim // 4),
            nn.ReLU(),
            nn.Conv2d(hidden_dim // 4, hidden_dim // 4, 3, padding=1),
            nn.BatchNorm2d(hidden_dim // 4),
            nn.ReLU(),
        )

    def forward(self, distances: torch.Tensor, placements: torch.Tensor) -> torch.Tensor:
        """
        Args:
            distances: (B, num_sensors, 1) - distance readings per sensor
            placements: (B, num_sensors, 6) - sensor poses (x, y, z, yaw, pitch, roll)
        Returns:
            bev_features: (B, hidden_dim // 4, 4 * (bev_size // 10), 4 * (bev_size // 10))
                BEV feature map; the caller resizes it to the full BEV grid
        """
        B = distances.shape[0]

        # Encode each sensor's distance
        dist_feat = self.distance_encoder(distances)  # (B, N, 64)

        # Encode each sensor's pose
        place_feat = self.placement_encoder(placements)  # (B, N, 64)

        # Combine distance + placement
        combined = torch.cat([dist_feat, place_feat], dim=-1)  # (B, N, 128)
        sensor_feat = self.sensor_fusion(combined)  # (B, N, hidden_dim)

        # Flatten all sensors and project to the coarse BEV grid
        flat = sensor_feat.reshape(B, -1)     # (B, N * hidden_dim)
        bev_flat = self.bev_projection(flat)  # (B, hidden_dim * small_h * small_w)

        small_size = self.bev_size // 10
        bev = bev_flat.reshape(B, self.hidden_dim, small_size, small_size)

        # Upsample toward full BEV resolution
        bev = self.bev_upsample(bev)

        return bev

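# Example (sketch): the coarse-to-fine BEV projection above, traced for the
# defaults num_sensors=8, hidden_dim=128, bev_size=200:
#
#     enc = UltrasonicEncoder(num_sensors=8)
#     bev = enc(torch.rand(1, 8, 1), torch.randn(1, 8, 6))
#     bev.shape  # torch.Size([1, 32, 80, 80]): the coarse grid is
#                # 200 // 10 = 20, doubled twice by the transposed convs
#                # -> 80; channels shrink to 128 // 4 = 32.
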
class ViewTransformer(nn.Module):
    """
    Transforms camera perspective features into BEV space.
    Uses the Lift-Splat-Shoot (LSS) approach: predict a depth distribution
    per pixel, then scatter features into 3D space and collapse to BEV.
    """

    def __init__(
        self,
        in_channels: int = 256,
        num_depth_bins: int = 64,
        depth_min: float = 1.0,
        depth_max: float = 50.0,
        bev_size: int = 200,
        bev_resolution: float = 0.25,  # meters per pixel
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_depth_bins = num_depth_bins
        self.bev_size = bev_size
        self.bev_resolution = bev_resolution

        # Depth distribution prediction
        self.depth_net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, num_depth_bins, 1),
        )

        # Feature compression for BEV
        self.feature_net = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
        )

        # Depth bin centers (registered for the full geometric projection;
        # the simplified pooling below does not consume them yet)
        self.register_buffer(
            'depth_bins',
            torch.linspace(depth_min, depth_max, num_depth_bins)
        )

        # BEV encoder after scattering
        self.bev_encoder = nn.Sequential(
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(),
        )

    def forward(
        self,
        camera_features: torch.Tensor,
        intrinsics: torch.Tensor,
        extrinsics: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            camera_features: (B, N_cams, C, H, W) multi-camera features
            intrinsics: (B, N_cams, 3, 3) camera intrinsic matrices
            extrinsics: (B, N_cams, 4, 4) camera-to-ego transformation matrices
        Returns:
            bev: (B, C // 2, bev_size, bev_size) BEV feature map

        Note: intrinsics/extrinsics are accepted for the full geometric
        projection but are not used by the simplified pooling below.
        """
        B, N, C, H, W = camera_features.shape

        # Reshape for batch processing
        features = camera_features.reshape(B * N, C, H, W)

        # Predict depth distribution
        depth_logits = self.depth_net(features)       # (B*N, D, H, W)
        depth_probs = F.softmax(depth_logits, dim=1)  # (B*N, D, H, W)

        # Compress features
        feat = self.feature_net(features)  # (B*N, C//2, H, W)
        C_out = feat.shape[1]

        # Outer product: depth_probs * features -> frustum volume
        # (B*N, C_out, D, H, W)
        feat_expanded = feat.unsqueeze(2)          # (B*N, C_out, 1, H, W)
        depth_expanded = depth_probs.unsqueeze(1)  # (B*N, 1, D, H, W)
        volume = feat_expanded * depth_expanded    # (B*N, C_out, D, H, W)

        # Simplified BEV pooling: average over depth and resize spatially.
        # A full implementation would do a proper 3D-to-BEV projection.
        volume = volume.reshape(B, N, C_out, self.num_depth_bins, H, W)

        # Pool over depth dimension
        bev_per_cam = volume.mean(dim=3)  # (B, N, C_out, H, W)

        # Adaptive pool each camera view to BEV size
        bev_per_cam = bev_per_cam.reshape(B * N, C_out, H, W)
        bev_per_cam = F.adaptive_avg_pool2d(bev_per_cam, (self.bev_size, self.bev_size))
        bev_per_cam = bev_per_cam.reshape(B, N, C_out, self.bev_size, self.bev_size)

        # Fuse all camera BEV views (mean fusion)
        bev = bev_per_cam.mean(dim=1)  # (B, C_out, bev_size, bev_size)

        # Refine BEV features
        bev = self.bev_encoder(bev)

        return bev

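# Example (sketch): the "lift" outer product at the heart of LSS, in
# isolation. Each pixel's C_out-dim feature is spread across D depth bins,
# weighted by that pixel's predicted depth distribution:
#
#     feat = torch.randn(2, 128, 32, 88)                      # (B*N, C_out, H, W)
#     depth_probs = torch.rand(2, 64, 32, 88).softmax(dim=1)  # (B*N, D, H, W)
#     volume = feat.unsqueeze(2) * depth_probs.unsqueeze(1)
#     volume.shape  # torch.Size([2, 128, 64, 32, 88])
#     # Summing the volume over depth recovers the original features,
#     # since the depth weights sum to 1 per pixel:
#     torch.allclose(volume.sum(dim=2), feat, atol=1e-5)  # True
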
class MultiModalSensorFusion(nn.Module):
    """
    Main sensor fusion module that combines:
    1. Multi-camera visual features (via CNN backbone + view transformer -> BEV)
    2. Ultrasonic proximity features (via encoder -> BEV)

    Output: a unified BEV representation for downstream perception/planning.
    Fully configurable for any number/placement of sensors.
    """

    def __init__(
        self,
        sensor_config: SensorConfig,
        bev_size: int = 200,
        bev_resolution: float = 0.25,
        camera_channels: int = 3,
        backbone_base: int = 64,
        bev_feature_dim: int = 256,
    ):
        super().__init__()
        self.sensor_config = sensor_config
        self.bev_size = bev_size
        self.bev_resolution = bev_resolution
        self.bev_feature_dim = bev_feature_dim

        num_cameras = sensor_config.num_cameras
        num_ultrasonics = sensor_config.num_ultrasonics

        # Camera processing pipeline
        if num_cameras > 0:
            self.camera_backbone = CameraBackbone(camera_channels, backbone_base)
            self.view_transformer = ViewTransformer(
                in_channels=backbone_base * 4,  # FPN output channels
                bev_size=bev_size,
                bev_resolution=bev_resolution,
            )
            camera_bev_channels = backbone_base * 2  # output of view transformer
        else:
            self.camera_backbone = None
            self.view_transformer = None
            camera_bev_channels = 0

        # Ultrasonic processing pipeline
        if num_ultrasonics > 0:
            self.ultrasonic_encoder = UltrasonicEncoder(
                num_sensors=num_ultrasonics,
                hidden_dim=128,
                bev_size=bev_size,
            )
            us_bev_channels = 32  # hidden_dim // 4
        else:
            self.ultrasonic_encoder = None
            us_bev_channels = 0

        # Adaptive fusion of the sensor modalities
        total_bev_channels = camera_bev_channels + us_bev_channels

        self.fusion_conv = nn.Sequential(
            nn.Conv2d(total_bev_channels, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
        )

        # Channel attention for adaptive sensor weighting (SE-style)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(bev_feature_dim, bev_feature_dim // 4),
            nn.ReLU(),
            nn.Linear(bev_feature_dim // 4, bev_feature_dim),
            nn.Sigmoid(),
        )

        # Final BEV refinement with residual connection
        self.bev_refine = nn.Sequential(
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
            nn.ReLU(),
            nn.Conv2d(bev_feature_dim, bev_feature_dim, 3, padding=1),
            nn.BatchNorm2d(bev_feature_dim),
        )

    def forward(
        self,
        camera_images: Optional[torch.Tensor] = None,
        camera_intrinsics: Optional[torch.Tensor] = None,
        camera_extrinsics: Optional[torch.Tensor] = None,
        ultrasonic_distances: Optional[torch.Tensor] = None,
        ultrasonic_placements: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            camera_images: (B, N_cams, 3, H, W)
            camera_intrinsics: (B, N_cams, 3, 3)
            camera_extrinsics: (B, N_cams, 4, 4)
            ultrasonic_distances: (B, N_us, 1)
            ultrasonic_placements: (B, N_us, 6)
        Returns:
            bev_features: (B, bev_feature_dim, bev_size, bev_size)

        Note: data must be supplied for every modality configured in
        sensor_config; otherwise the concatenated channel count will not
        match what fusion_conv was built for.
        """
        bev_parts = []

        # Process cameras
        if self.camera_backbone is not None and camera_images is not None:
            B, N, C, H, W = camera_images.shape
            # Extract features for each camera
            imgs = camera_images.reshape(B * N, C, H, W)
            multi_scale = self.camera_backbone(imgs)

            # Use p2 (highest-resolution FPN output) for view transformation
            cam_feat = multi_scale["p2"]
            _, Cf, Hf, Wf = cam_feat.shape
            cam_feat = cam_feat.reshape(B, N, Cf, Hf, Wf)

            cam_bev = self.view_transformer(
                cam_feat, camera_intrinsics, camera_extrinsics
            )
            bev_parts.append(cam_bev)

        # Process ultrasonics
        if self.ultrasonic_encoder is not None and ultrasonic_distances is not None:
            us_bev = self.ultrasonic_encoder(ultrasonic_distances, ultrasonic_placements)
            # Resize to match the BEV grid
            us_bev = F.adaptive_avg_pool2d(us_bev, (self.bev_size, self.bev_size))
            bev_parts.append(us_bev)

        if len(bev_parts) == 0:
            raise ValueError("No sensor data provided!")

        # Concatenate all BEV features
        bev_concat = torch.cat(bev_parts, dim=1)

        # Fuse modalities
        bev = self.fusion_conv(bev_concat)

        # Channel attention
        attn = self.channel_attention(bev).unsqueeze(-1).unsqueeze(-1)
        bev = bev * attn

        # Residual refinement
        bev = bev + self.bev_refine(bev)
        bev = F.relu(bev)

        return bev
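

if __name__ == "__main__":
    # Smoke test (sketch): run with `python -m fsd_model.sensor_fusion`.
    # Only `num_cameras` and `num_ultrasonics` are read from the config in
    # __init__, so a SimpleNamespace stands in for SensorConfig here; swap in
    # the real config class for actual use.
    from types import SimpleNamespace

    cfg = SimpleNamespace(num_cameras=2, num_ultrasonics=8)
    model = MultiModalSensorFusion(cfg, bev_size=200).eval()

    with torch.no_grad():
        bev = model(
            camera_images=torch.randn(1, 2, 3, 128, 352),
            camera_intrinsics=torch.eye(3).expand(1, 2, 3, 3),
            camera_extrinsics=torch.eye(4).expand(1, 2, 4, 4),
            ultrasonic_distances=torch.rand(1, 8, 1),
            ultrasonic_placements=torch.randn(1, 8, 6),
        )

    # Camera path contributes backbone_base * 2 = 128 channels, ultrasonic
    # path 32; fusion_conv maps the 160-channel concat to bev_feature_dim.
    print(bev.shape)  # expected: torch.Size([1, 256, 200, 200])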