Alfred Liu committed on
Commit ·
c2f1911
1
Parent(s): e3f4e57
Add comments to core components
Browse files- models/sparsebev_head.py +6 -1
- models/sparsebev_sampling.py +32 -4
- models/sparsebev_transformer.py +11 -1
models/sparsebev_head.py
CHANGED
|
@@ -70,6 +70,7 @@ class SparseBEVHead(DETRHead):
|
|
| 70 |
query_bbox = self.init_query_bbox.weight.clone() # [Q, 10]
|
| 71 |
#query_bbox[..., :3] = query_bbox[..., :3].sigmoid()
|
| 72 |
|
|
|
|
| 73 |
B = mlvl_feats[0].shape[0]
|
| 74 |
query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)
|
| 75 |
|
|
@@ -92,7 +93,7 @@ class SparseBEVHead(DETRHead):
|
|
| 92 |
bbox_preds[..., 5:10],
|
| 93 |
], dim=-1) # [cx, cy, w, l, cz, h, sin, cos, vx, vy]
|
| 94 |
|
| 95 |
-
if mask_dict is not None and mask_dict['pad_size'] > 0:
|
| 96 |
output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
|
| 97 |
output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
|
| 98 |
output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
|
|
@@ -116,6 +117,10 @@ class SparseBEVHead(DETRHead):
|
|
| 116 |
return outs
|
| 117 |
|
| 118 |
def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
device = init_query_bbox.device
|
| 120 |
indicator0 = torch.zeros([self.num_query, 1], device=device)
|
| 121 |
init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
|
|
|
|
| 70 |
query_bbox = self.init_query_bbox.weight.clone() # [Q, 10]
|
| 71 |
#query_bbox[..., :3] = query_bbox[..., :3].sigmoid()
|
| 72 |
|
| 73 |
+
# query denoising
|
| 74 |
B = mlvl_feats[0].shape[0]
|
| 75 |
query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)
|
| 76 |
|
|
|
|
| 93 |
bbox_preds[..., 5:10],
|
| 94 |
], dim=-1) # [cx, cy, w, l, cz, h, sin, cos, vx, vy]
|
| 95 |
|
| 96 |
+
if mask_dict is not None and mask_dict['pad_size'] > 0: # if using query denoising
|
| 97 |
output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
|
| 98 |
output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
|
| 99 |
output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
|
|
|
|
| 117 |
return outs
|
| 118 |
|
| 119 |
def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
|
| 120 |
+
# mostly borrowed from:
|
| 121 |
+
# - https://github.com/IDEA-Research/DN-DETR/blob/main/models/DN_DAB_DETR/dn_components.py
|
| 122 |
+
# - https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models/dense_heads/petrv2_dnhead.py
|
| 123 |
+
|
| 124 |
device = init_query_bbox.device
|
| 125 |
indicator0 = torch.zeros([self.num_query, 1], device=device)
|
| 126 |
init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
|
models/sparsebev_sampling.py
CHANGED
|
@@ -25,9 +25,25 @@ def make_sample_points(query_bbox, offset, pc_range):
|
|
| 25 |
|
| 26 |
|
| 27 |
def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
N = 6
|
| 30 |
-
|
| 31 |
sample_points = sample_points.reshape(B, Q, T, G * P, 3)
|
| 32 |
|
| 33 |
# get the projection matrix
|
|
@@ -42,7 +58,7 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
|
|
| 42 |
sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
|
| 43 |
sample_points = sample_points.transpose(1, 3) # [B, T, N, Q, GP, 4, 1]
|
| 44 |
|
| 45 |
-
# project 3d sampling points to
|
| 46 |
sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1) # [B, T, N, Q, GP, 4]
|
| 47 |
|
| 48 |
# homo coord -> pixel coord
|
|
@@ -62,6 +78,7 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
|
|
| 62 |
& (sample_points_cam[..., 0:1] < 1.0)
|
| 63 |
).squeeze(-1).float() # [B, T, N, Q, GP]
|
| 64 |
|
|
|
|
| 65 |
if DUMP.enabled:
|
| 66 |
torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
|
| 67 |
'{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
|
|
@@ -71,6 +88,7 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
|
|
| 71 |
valid_mask = valid_mask.permute(0, 1, 3, 4, 2) # [B, T, Q, GP, N]
|
| 72 |
sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5) # [B, T, Q, GP, N, 2]
|
| 73 |
|
|
|
|
| 74 |
i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
|
| 75 |
i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
|
| 76 |
i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
|
|
@@ -79,21 +97,31 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
|
|
| 79 |
i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
|
| 80 |
i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
|
| 81 |
i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
|
|
|
|
|
|
|
| 82 |
i_view = torch.argmax(valid_mask, dim=-1)[..., None] # [B, T, Q, GP, 1]
|
| 83 |
|
|
|
|
| 84 |
sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :] # [B, Q, GP, 1, 2]
|
| 85 |
valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view] # [B, Q, GP, 1]
|
| 86 |
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
|
| 89 |
sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6) # [B, T, G, Q, P, 1, 3]
|
| 90 |
sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
|
| 91 |
|
|
|
|
| 92 |
scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
|
| 93 |
scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
|
| 94 |
scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
|
| 95 |
|
|
|
|
| 96 |
final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
|
|
|
|
|
|
|
| 97 |
C = final.shape[2] # [BTG, Q, C, P]
|
| 98 |
final = final.reshape(B, T, G, Q, C, P)
|
| 99 |
final = final.permute(0, 3, 2, 1, 5, 4)
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
|
| 28 |
+
"""
|
| 29 |
+
Args:
|
| 30 |
+
sample_points: 3D sampling points in shape [B, Q, T, G, P, 3]
|
| 31 |
+
mlvl_feats: list of multi-scale features from neck, each in shape [B*T*G, C, N, H, W]
|
| 32 |
+
scale_weights: weights for multi-scale aggregation, [B, Q, G, T, P, L]
|
| 33 |
+
lidar2img: 4x4 projection matrix in shape [B, TN, 4, 4]
|
| 34 |
+
Symbol meaning:
|
| 35 |
+
B: batch size
|
| 36 |
+
Q: num of queries
|
| 37 |
+
T: num of frames
|
| 38 |
+
G: num of groups (we follow the group sampling mechanism of AdaMixer)
|
| 39 |
+
P: num of sampling points per frame per group
|
| 40 |
+
N: num of views (six for nuScenes)
|
| 41 |
+
L: num of layers of feature pyramid (typically it is 4: C2, C3, C4, C5)
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
B, Q, T, G, P, _ = sample_points.shape # [B, Q, T, G, P, 3]
|
| 45 |
N = 6
|
| 46 |
+
|
| 47 |
sample_points = sample_points.reshape(B, Q, T, G * P, 3)
|
| 48 |
|
| 49 |
# get the projection matrix
|
|
|
|
| 58 |
sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
|
| 59 |
sample_points = sample_points.transpose(1, 3) # [B, T, N, Q, GP, 4, 1]
|
| 60 |
|
| 61 |
+
# project 3d sampling points to N views
|
| 62 |
sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1) # [B, T, N, Q, GP, 4]
|
| 63 |
|
| 64 |
# homo coord -> pixel coord
|
|
|
|
| 78 |
& (sample_points_cam[..., 0:1] < 1.0)
|
| 79 |
).squeeze(-1).float() # [B, T, N, Q, GP]
|
| 80 |
|
| 81 |
+
# for visualization only
|
| 82 |
if DUMP.enabled:
|
| 83 |
torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
|
| 84 |
'{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
|
|
|
|
| 88 |
valid_mask = valid_mask.permute(0, 1, 3, 4, 2) # [B, T, Q, GP, N]
|
| 89 |
sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5) # [B, T, Q, GP, N, 2]
|
| 90 |
|
| 91 |
+
# prepare batched indexing
|
| 92 |
i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
|
| 93 |
i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
|
| 94 |
i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
|
|
|
|
| 97 |
i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
|
| 98 |
i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
|
| 99 |
i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
|
| 100 |
+
|
| 101 |
+
# we only keep at most one valid sampling point, see https://zhuanlan.zhihu.com/p/654821380
|
| 102 |
i_view = torch.argmax(valid_mask, dim=-1)[..., None] # [B, T, Q, GP, 1]
|
| 103 |
|
| 104 |
+
# index the only one sampling point and its valid flag
|
| 105 |
sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :] # [B, Q, GP, 1, 2]
|
| 106 |
valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view] # [B, Q, GP, 1]
|
| 107 |
|
| 108 |
+
# treat the view index as a new axis for grid_sample and normalize the view index to [0, 1]
|
| 109 |
+
sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / (N - 1)], dim=-1)
|
| 110 |
+
|
| 111 |
+
# reorganize the tensor to stack T and G to the batch dim for better parallelism
|
| 112 |
sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
|
| 113 |
sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6) # [B, T, G, Q, P, 1, 3]
|
| 114 |
sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
|
| 115 |
|
| 116 |
+
# reorganize the tensor to stack T and G to the batch dim for better parallelism
|
| 117 |
scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
|
| 118 |
scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
|
| 119 |
scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
|
| 120 |
|
| 121 |
+
# multi-scale multi-view grid sample
|
| 122 |
final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
|
| 123 |
+
|
| 124 |
+
# reorganize the sampled features
|
| 125 |
C = final.shape[2] # [BTG, Q, C, P]
|
| 126 |
final = final.reshape(B, T, G, Q, C, P)
|
| 127 |
final = final.permute(0, 3, 2, 1, 5, 4)
|
models/sparsebev_transformer.py
CHANGED
|
@@ -43,6 +43,7 @@ class SparseBEVTransformerDecoder(BaseModule):
|
|
| 43 |
self.num_layers = num_layers
|
| 44 |
self.pc_range = pc_range
|
| 45 |
|
|
|
|
| 46 |
self.decoder_layer = SparseBEVTransformerDecoderLayer(
|
| 47 |
embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
|
| 48 |
)
|
|
@@ -54,6 +55,7 @@ class SparseBEVTransformerDecoder(BaseModule):
|
|
| 54 |
def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
|
| 55 |
cls_scores, bbox_preds = [], []
|
| 56 |
|
|
|
|
| 57 |
timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
|
| 58 |
timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
|
| 59 |
time_diff = timestamps[:, :1, :] - timestamps
|
|
@@ -61,10 +63,12 @@ class SparseBEVTransformerDecoder(BaseModule):
|
|
| 61 |
time_diff = torch.from_numpy(time_diff).to(query_bbox.device) # [B, F]
|
| 62 |
img_metas[0]['time_diff'] = time_diff
|
| 63 |
|
|
|
|
| 64 |
lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
|
| 65 |
lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device) # [B, N, 4, 4]
|
| 66 |
img_metas[0]['lidar2img'] = lidar2img
|
| 67 |
|
|
|
|
| 68 |
for lvl, feat in enumerate(mlvl_feats):
|
| 69 |
B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
|
| 70 |
N, T, G, C = 6, TN // 6, 4, GC // 4
|
|
@@ -164,6 +168,7 @@ class SparseBEVTransformerDecoderLayer(BaseModule):
|
|
| 164 |
bbox_pred = self.reg_branch(query_feat) # [B, Q, code_size]
|
| 165 |
bbox_pred = self.refine_bbox(query_bbox, bbox_pred)
|
| 166 |
|
|
|
|
| 167 |
time_diff = img_metas[0]['time_diff'] # [B, F]
|
| 168 |
if time_diff.shape[1] > 1:
|
| 169 |
time_diff = time_diff.clone()
|
|
@@ -182,6 +187,7 @@ class SparseBEVTransformerDecoderLayer(BaseModule):
|
|
| 182 |
|
| 183 |
|
| 184 |
class SparseBEVSelfAttention(BaseModule):
|
|
|
|
| 185 |
def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], init_cfg=None):
|
| 186 |
super().__init__(init_cfg)
|
| 187 |
self.pc_range = pc_range
|
|
@@ -207,8 +213,10 @@ class SparseBEVSelfAttention(BaseModule):
|
|
| 207 |
|
| 208 |
tau = tau.permute(0, 2, 1) # [B, 8, Q]
|
| 209 |
attn_mask = dist[:, None, :, :] * tau[..., None] # [B, 8, Q, Q]
|
| 210 |
-
|
|
|
|
| 211 |
attn_mask[:, :, pre_attn_mask] = float('-inf')
|
|
|
|
| 212 |
attn_mask = attn_mask.flatten(0, 1) # [Bx8, Q, Q]
|
| 213 |
return self.attention(query_feat, attn_mask=attn_mask)
|
| 214 |
|
|
@@ -234,6 +242,7 @@ class SparseBEVSelfAttention(BaseModule):
|
|
| 234 |
|
| 235 |
|
| 236 |
class SparseBEVSampling(BaseModule):
|
|
|
|
| 237 |
def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
|
| 238 |
super().__init__(init_cfg)
|
| 239 |
|
|
@@ -302,6 +311,7 @@ class SparseBEVSampling(BaseModule):
|
|
| 302 |
|
| 303 |
|
| 304 |
class AdaptiveMixing(nn.Module):
|
|
|
|
| 305 |
def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
|
| 306 |
super(AdaptiveMixing, self).__init__()
|
| 307 |
|
|
|
|
| 43 |
self.num_layers = num_layers
|
| 44 |
self.pc_range = pc_range
|
| 45 |
|
| 46 |
+
# params are shared across all decoder layers
|
| 47 |
self.decoder_layer = SparseBEVTransformerDecoderLayer(
|
| 48 |
embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
|
| 49 |
)
|
|
|
|
| 55 |
def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
|
| 56 |
cls_scores, bbox_preds = [], []
|
| 57 |
|
| 58 |
+
# calculate time difference according to timestamps
|
| 59 |
timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
|
| 60 |
timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
|
| 61 |
time_diff = timestamps[:, :1, :] - timestamps
|
|
|
|
| 63 |
time_diff = torch.from_numpy(time_diff).to(query_bbox.device) # [B, F]
|
| 64 |
img_metas[0]['time_diff'] = time_diff
|
| 65 |
|
| 66 |
+
# organize projections matrix and copy to CUDA
|
| 67 |
lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
|
| 68 |
lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device) # [B, N, 4, 4]
|
| 69 |
img_metas[0]['lidar2img'] = lidar2img
|
| 70 |
|
| 71 |
+
# group image features in advance for sampling, see `sampling_4d` for more details
|
| 72 |
for lvl, feat in enumerate(mlvl_feats):
|
| 73 |
B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
|
| 74 |
N, T, G, C = 6, TN // 6, 4, GC // 4
|
|
|
|
| 168 |
bbox_pred = self.reg_branch(query_feat) # [B, Q, code_size]
|
| 169 |
bbox_pred = self.refine_bbox(query_bbox, bbox_pred)
|
| 170 |
|
| 171 |
+
# calculate absolute velocity according to time difference
|
| 172 |
time_diff = img_metas[0]['time_diff'] # [B, F]
|
| 173 |
if time_diff.shape[1] > 1:
|
| 174 |
time_diff = time_diff.clone()
|
|
|
|
| 187 |
|
| 188 |
|
| 189 |
class SparseBEVSelfAttention(BaseModule):
|
| 190 |
+
"""Scale-adaptive Self Attention"""
|
| 191 |
def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], init_cfg=None):
|
| 192 |
super().__init__(init_cfg)
|
| 193 |
self.pc_range = pc_range
|
|
|
|
| 213 |
|
| 214 |
tau = tau.permute(0, 2, 1) # [B, 8, Q]
|
| 215 |
attn_mask = dist[:, None, :, :] * tau[..., None] # [B, 8, Q, Q]
|
| 216 |
+
|
| 217 |
+
if pre_attn_mask is not None: # for query denoising
|
| 218 |
attn_mask[:, :, pre_attn_mask] = float('-inf')
|
| 219 |
+
|
| 220 |
attn_mask = attn_mask.flatten(0, 1) # [Bx8, Q, Q]
|
| 221 |
return self.attention(query_feat, attn_mask=attn_mask)
|
| 222 |
|
|
|
|
| 242 |
|
| 243 |
|
| 244 |
class SparseBEVSampling(BaseModule):
|
| 245 |
+
"""Adaptive Spatio-temporal Sampling"""
|
| 246 |
def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
|
| 247 |
super().__init__(init_cfg)
|
| 248 |
|
|
|
|
| 311 |
|
| 312 |
|
| 313 |
class AdaptiveMixing(nn.Module):
|
| 314 |
+
"""Adaptive Mixing"""
|
| 315 |
def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
|
| 316 |
super(AdaptiveMixing, self).__init__()
|
| 317 |
|