SparseBev / models /sparsebev_sampling.py
chinmaygarde's picture
Attempt an export to ONNX.
3ea6165 unverified
import torch
import torch.nn.functional as F
from .bbox.utils import decode_bbox
from .utils import rotation_3d_in_axis, DUMP
from .csrc.wrapper import msmv_sampling, msmv_sampling_pytorch, msmv_sampling_onnx, MSMV_CUDA
def make_sample_points(query_bbox, offset, pc_range):
'''
query_bbox: [B, Q, 10]
offset: [B, Q, num_points, 4], normalized by stride
'''
query_bbox = decode_bbox(query_bbox, pc_range) # [B, Q, 9]
xyz = query_bbox[..., 0:3] # [B, Q, 3]
wlh = query_bbox[..., 3:6] # [B, Q, 3]
ang = query_bbox[..., 6:7] # [B, Q, 1]
delta_xyz = offset[..., 0:3] # [B, Q, P, 3]
delta_xyz = wlh[:, :, None, :] * delta_xyz # [B, Q, P, 3]
delta_xyz = rotation_3d_in_axis(delta_xyz, ang) # [B, Q, P, 3]
sample_xyz = xyz[:, :, None, :] + delta_xyz # [B, Q, P, 3]
return sample_xyz # [B, Q, P, 3]
def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
"""
Args:
sample_points: 3D sampling points in shape [B, Q, T, G, P, 3]
mlvl_feats: list of multi-scale features from neck, each in shape [B*T*G, C, N, H, W]
scale_weights: weights for multi-scale aggregation, [B, Q, G, T, P, L]
lidar2img: 4x4 projection matrix in shape [B, TN, 4, 4]
Symbol meaning:
B: batch size
Q: num of queries
T: num of frames
G: num of groups (we follow the group sampling mechanism of AdaMixer)
P: num of sampling points per frame per group
N: num of views (six for nuScenes)
L: num of layers of feature pyramid (typically it is 4: C2, C3, C4, C5)
"""
B, Q, T, G, P, _ = sample_points.shape # [B, Q, T, G, P, 3]
N = 6
sample_points = sample_points.reshape(B, Q, T, G * P, 3)
# get the projection matrix
lidar2img = lidar2img[:, :, None, None, :, :] # [B, TN, 1, 1, 4, 4]
lidar2img = lidar2img.expand(B, T*N, Q, G * P, 4, 4)
lidar2img = lidar2img.reshape(B, T, N, Q, G*P, 4, 4)
# expand the points
ones = torch.ones_like(sample_points[..., :1])
sample_points = torch.cat([sample_points, ones], dim=-1) # [B, Q, GP, 4]
sample_points = sample_points[:, :, None, ..., None] # [B, Q, T, GP, 4]
sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
sample_points = sample_points.transpose(1, 3) # [B, T, N, Q, GP, 4, 1]
# project 3d sampling points to N views
sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1) # [B, T, N, Q, GP, 4]
# homo coord -> pixel coord
homo = sample_points_cam[..., 2:3]
homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)
sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero # [B, T, N, Q, GP, 2]
# normalize
sample_points_cam[..., 0] /= image_w
sample_points_cam[..., 1] /= image_h
# check if out of image
valid_mask = ((homo > eps) \
& (sample_points_cam[..., 1:2] > 0.0)
& (sample_points_cam[..., 1:2] < 1.0)
& (sample_points_cam[..., 0:1] > 0.0)
& (sample_points_cam[..., 0:1] < 1.0)
).squeeze(-1).float() # [B, T, N, Q, GP]
# for visualization only
if DUMP.enabled:
torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1).cpu(),
'{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
torch.save(valid_mask.cpu(),
'{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
valid_mask = valid_mask.permute(0, 1, 3, 4, 2) # [B, T, Q, GP, N]
sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5) # [B, T, Q, GP, N, 2]
# we only keep at most one valid sampling point, see https://zhuanlan.zhihu.com/p/654821380
i_view = torch.argmax(valid_mask, dim=-1, keepdim=True) # [B, T, Q, GP, 1]
if MSMV_CUDA:
# Original fancy-indexing path (used with CUDA kernel on Linux/Windows)
i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]
valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]
# treat the view index as a new axis for grid_sample, normalise to [0, 1]
sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / (N - 1)], dim=-1)
sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)
sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
else:
# ONNX-compatible path: torch.gather + 4D grid_sample (no custom CUDA ops)
# Select best-view UV coords via gather [B, T, Q, GP, 1, 2]
i_view_uv = i_view.unsqueeze(-1).expand(B, T, Q, G * P, 1, 2)
sample_points_cam = torch.gather(sample_points_cam, 4, i_view_uv).squeeze(4) # [B, T, Q, GP, 2]
# Reorganize UV to [B*T*G, Q, P, 2]
sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 2)
sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5) # [B, T, G, Q, P, 2]
sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 2)
# Reorganize view_idx to [B*T*G, Q, P]
i_view = i_view.squeeze(4).reshape(B, T, Q, G, P)
i_view = i_view.permute(0, 1, 3, 2, 4).reshape(B*T*G, Q, P)
scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
final = msmv_sampling_onnx(mlvl_feats, sample_points_cam, i_view, scale_weights)
# reorganize the sampled features
C = final.shape[2] # [BTG, Q, C, P]
final = final.reshape(B, T, G, Q, C, P)
final = final.permute(0, 3, 2, 1, 5, 4)
final = final.flatten(3, 4) # [B, Q, G, FP, C]
return final