Alfred Liu committed on
Commit
c2f1911
·
1 Parent(s): e3f4e57

Add comments to core components

Browse files
models/sparsebev_head.py CHANGED
@@ -70,6 +70,7 @@ class SparseBEVHead(DETRHead):
70
  query_bbox = self.init_query_bbox.weight.clone() # [Q, 10]
71
  #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()
72
 
 
73
  B = mlvl_feats[0].shape[0]
74
  query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)
75
 
@@ -92,7 +93,7 @@ class SparseBEVHead(DETRHead):
92
  bbox_preds[..., 5:10],
93
  ], dim=-1) # [cx, cy, w, l, cz, h, sin, cos, vx, vy]
94
 
95
- if mask_dict is not None and mask_dict['pad_size'] > 0:
96
  output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
97
  output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
98
  output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
@@ -116,6 +117,10 @@ class SparseBEVHead(DETRHead):
116
  return outs
117
 
118
  def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
 
 
 
 
119
  device = init_query_bbox.device
120
  indicator0 = torch.zeros([self.num_query, 1], device=device)
121
  init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
 
70
  query_bbox = self.init_query_bbox.weight.clone() # [Q, 10]
71
  #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()
72
 
73
+ # query denoising
74
  B = mlvl_feats[0].shape[0]
75
  query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)
76
 
 
93
  bbox_preds[..., 5:10],
94
  ], dim=-1) # [cx, cy, w, l, cz, h, sin, cos, vx, vy]
95
 
96
+ if mask_dict is not None and mask_dict['pad_size'] > 0: # if using query denoising
97
  output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
98
  output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
99
  output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
 
117
  return outs
118
 
119
  def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
120
+ # mostly borrowed from:
121
+ # - https://github.com/IDEA-Research/DN-DETR/blob/main/models/DN_DAB_DETR/dn_components.py
122
+ # - https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models/dense_heads/petrv2_dnhead.py
123
+
124
  device = init_query_bbox.device
125
  indicator0 = torch.zeros([self.num_query, 1], device=device)
126
  init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
models/sparsebev_sampling.py CHANGED
@@ -25,9 +25,25 @@ def make_sample_points(query_bbox, offset, pc_range):
25
 
26
 
27
  def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
28
- B, Q, T, G, P, _ = sample_points.shape # [B, Q, T, G, P, 4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  N = 6
30
-
31
  sample_points = sample_points.reshape(B, Q, T, G * P, 3)
32
 
33
  # get the projection matrix
@@ -42,7 +58,7 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
42
  sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
43
  sample_points = sample_points.transpose(1, 3) # [B, T, N, Q, GP, 4, 1]
44
 
45
- # project 3d sampling points to image
46
  sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1) # [B, T, N, Q, GP, 4]
47
 
48
  # homo coord -> pixel coord
@@ -62,6 +78,7 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
62
  & (sample_points_cam[..., 0:1] < 1.0)
63
  ).squeeze(-1).float() # [B, T, N, Q, GP]
64
 
 
65
  if DUMP.enabled:
66
  torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
67
  '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
@@ -71,6 +88,7 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
71
  valid_mask = valid_mask.permute(0, 1, 3, 4, 2) # [B, T, Q, GP, N]
72
  sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5) # [B, T, Q, GP, N, 2]
73
 
 
74
  i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
75
  i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
76
  i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
@@ -79,21 +97,31 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
79
  i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
80
  i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
81
  i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
 
 
82
  i_view = torch.argmax(valid_mask, dim=-1)[..., None] # [B, T, Q, GP, 1]
83
 
 
84
  sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :] # [B, Q, GP, 1, 2]
85
  valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view] # [B, Q, GP, 1]
86
 
87
- sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / 5], dim=-1)
 
 
 
88
  sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
89
  sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6) # [B, T, G, Q, P, 1, 3]
90
  sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
91
 
 
92
  scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
93
  scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
94
  scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
95
 
 
96
  final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
 
 
97
  C = final.shape[2] # [BTG, Q, C, P]
98
  final = final.reshape(B, T, G, Q, C, P)
99
  final = final.permute(0, 3, 2, 1, 5, 4)
 
25
 
26
 
27
  def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
28
+ """
29
+ Args:
30
+ sample_points: 3D sampling points in shape [B, Q, T, G, P, 3]
31
+ mlvl_feats: list of multi-scale features from neck, each in shape [B*T*G, C, N, H, W]
32
+ scale_weights: weights for multi-scale aggregation, [B, Q, G, T, P, L]
33
+ lidar2img: 4x4 projection matrix in shape [B, TN, 4, 4]
34
+ Symbol meanings:
35
+ B: batch size
36
+ Q: num of queries
37
+ T: num of frames
38
+ G: num of groups (we follow the group sampling mechanism of AdaMixer)
39
+ P: num of sampling points per frame per group
40
+ N: num of views (six for nuScenes)
41
+ L: num of layers of feature pyramid (typically it is 4: C2, C3, C4, C5)
42
+ """
43
+
44
+ B, Q, T, G, P, _ = sample_points.shape # [B, Q, T, G, P, 3]
45
  N = 6
46
+
47
  sample_points = sample_points.reshape(B, Q, T, G * P, 3)
48
 
49
  # get the projection matrix
 
58
  sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
59
  sample_points = sample_points.transpose(1, 3) # [B, T, N, Q, GP, 4, 1]
60
 
61
+ # project 3d sampling points to N views
62
  sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1) # [B, T, N, Q, GP, 4]
63
 
64
  # homo coord -> pixel coord
 
78
  & (sample_points_cam[..., 0:1] < 1.0)
79
  ).squeeze(-1).float() # [B, T, N, Q, GP]
80
 
81
+ # for visualization only
82
  if DUMP.enabled:
83
  torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
84
  '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
 
88
  valid_mask = valid_mask.permute(0, 1, 3, 4, 2) # [B, T, Q, GP, N]
89
  sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5) # [B, T, Q, GP, N, 2]
90
 
91
+ # prepare batched indexing
92
  i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
93
  i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
94
  i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
 
97
  i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
98
  i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
99
  i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
100
+
101
+ # we only keep at most one valid sampling point, see https://zhuanlan.zhihu.com/p/654821380
102
  i_view = torch.argmax(valid_mask, dim=-1)[..., None] # [B, T, Q, GP, 1]
103
 
104
+ # index the single selected sampling point and its valid flag
105
  sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :] # [B, Q, GP, 1, 2]
106
  valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view] # [B, Q, GP, 1]
107
 
108
+ # treat the view index as a new axis for grid_sample and normalize the view index to [0, 1]
109
+ sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / (N - 1)], dim=-1)
110
+
111
+ # reorganize the tensor to stack T and G to the batch dim for better parallelism
112
  sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
113
  sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6) # [B, T, G, Q, P, 1, 3]
114
  sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
115
 
116
+ # reorganize the tensor to stack T and G to the batch dim for better parallelism
117
  scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
118
  scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
119
  scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
120
 
121
+ # multi-scale multi-view grid sample
122
  final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
123
+
124
+ # reorganize the sampled features
125
  C = final.shape[2] # [BTG, Q, C, P]
126
  final = final.reshape(B, T, G, Q, C, P)
127
  final = final.permute(0, 3, 2, 1, 5, 4)
models/sparsebev_transformer.py CHANGED
@@ -43,6 +43,7 @@ class SparseBEVTransformerDecoder(BaseModule):
43
  self.num_layers = num_layers
44
  self.pc_range = pc_range
45
 
 
46
  self.decoder_layer = SparseBEVTransformerDecoderLayer(
47
  embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
48
  )
@@ -54,6 +55,7 @@ class SparseBEVTransformerDecoder(BaseModule):
54
  def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
55
  cls_scores, bbox_preds = [], []
56
 
 
57
  timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
58
  timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
59
  time_diff = timestamps[:, :1, :] - timestamps
@@ -61,10 +63,12 @@ class SparseBEVTransformerDecoder(BaseModule):
61
  time_diff = torch.from_numpy(time_diff).to(query_bbox.device) # [B, F]
62
  img_metas[0]['time_diff'] = time_diff
63
 
 
64
  lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
65
  lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device) # [B, N, 4, 4]
66
  img_metas[0]['lidar2img'] = lidar2img
67
 
 
68
  for lvl, feat in enumerate(mlvl_feats):
69
  B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
70
  N, T, G, C = 6, TN // 6, 4, GC // 4
@@ -164,6 +168,7 @@ class SparseBEVTransformerDecoderLayer(BaseModule):
164
  bbox_pred = self.reg_branch(query_feat) # [B, Q, code_size]
165
  bbox_pred = self.refine_bbox(query_bbox, bbox_pred)
166
 
 
167
  time_diff = img_metas[0]['time_diff'] # [B, F]
168
  if time_diff.shape[1] > 1:
169
  time_diff = time_diff.clone()
@@ -182,6 +187,7 @@ class SparseBEVTransformerDecoderLayer(BaseModule):
182
 
183
 
184
  class SparseBEVSelfAttention(BaseModule):
 
185
  def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], init_cfg=None):
186
  super().__init__(init_cfg)
187
  self.pc_range = pc_range
@@ -207,8 +213,10 @@ class SparseBEVSelfAttention(BaseModule):
207
 
208
  tau = tau.permute(0, 2, 1) # [B, 8, Q]
209
  attn_mask = dist[:, None, :, :] * tau[..., None] # [B, 8, Q, Q]
210
- if pre_attn_mask is not None:
 
211
  attn_mask[:, :, pre_attn_mask] = float('-inf')
 
212
  attn_mask = attn_mask.flatten(0, 1) # [Bx8, Q, Q]
213
  return self.attention(query_feat, attn_mask=attn_mask)
214
 
@@ -234,6 +242,7 @@ class SparseBEVSelfAttention(BaseModule):
234
 
235
 
236
  class SparseBEVSampling(BaseModule):
 
237
  def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
238
  super().__init__(init_cfg)
239
 
@@ -302,6 +311,7 @@ class SparseBEVSampling(BaseModule):
302
 
303
 
304
  class AdaptiveMixing(nn.Module):
 
305
  def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
306
  super(AdaptiveMixing, self).__init__()
307
 
 
43
  self.num_layers = num_layers
44
  self.pc_range = pc_range
45
 
46
+ # params are shared across all decoder layers
47
  self.decoder_layer = SparseBEVTransformerDecoderLayer(
48
  embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
49
  )
 
55
  def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
56
  cls_scores, bbox_preds = [], []
57
 
58
+ # calculate time difference according to timestamps
59
  timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
60
  timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
61
  time_diff = timestamps[:, :1, :] - timestamps
 
63
  time_diff = torch.from_numpy(time_diff).to(query_bbox.device) # [B, F]
64
  img_metas[0]['time_diff'] = time_diff
65
 
66
+ # organize projection matrices and copy to CUDA
67
  lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
68
  lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device) # [B, N, 4, 4]
69
  img_metas[0]['lidar2img'] = lidar2img
70
 
71
+ # group image features in advance for sampling, see `sampling_4d` for more details
72
  for lvl, feat in enumerate(mlvl_feats):
73
  B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
74
  N, T, G, C = 6, TN // 6, 4, GC // 4
 
168
  bbox_pred = self.reg_branch(query_feat) # [B, Q, code_size]
169
  bbox_pred = self.refine_bbox(query_bbox, bbox_pred)
170
 
171
+ # calculate absolute velocity according to time difference
172
  time_diff = img_metas[0]['time_diff'] # [B, F]
173
  if time_diff.shape[1] > 1:
174
  time_diff = time_diff.clone()
 
187
 
188
 
189
  class SparseBEVSelfAttention(BaseModule):
190
+ """Scale-adaptive Self Attention"""
191
  def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], init_cfg=None):
192
  super().__init__(init_cfg)
193
  self.pc_range = pc_range
 
213
 
214
  tau = tau.permute(0, 2, 1) # [B, 8, Q]
215
  attn_mask = dist[:, None, :, :] * tau[..., None] # [B, 8, Q, Q]
216
+
217
+ if pre_attn_mask is not None: # for query denoising
218
  attn_mask[:, :, pre_attn_mask] = float('-inf')
219
+
220
  attn_mask = attn_mask.flatten(0, 1) # [Bx8, Q, Q]
221
  return self.attention(query_feat, attn_mask=attn_mask)
222
 
 
242
 
243
 
244
  class SparseBEVSampling(BaseModule):
245
+ """Adaptive Spatio-temporal Sampling"""
246
  def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
247
  super().__init__(init_cfg)
248
 
 
311
 
312
 
313
  class AdaptiveMixing(nn.Module):
314
+ """Adaptive Mixing"""
315
  def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
316
  super(AdaptiveMixing, self).__init__()
317