File size: 6,722 Bytes
d19bd3e
 
 
 
3ea6165
d19bd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f1911
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d19bd3e
c2f1911
d19bd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f1911
d19bd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f1911
d19bd3e
3403979
d19bd3e
3403979
d19bd3e
 
 
 
 
c2f1911
3ea6165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2f1911
 
d19bd3e
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import torch.nn.functional as F
from .bbox.utils import decode_bbox
from .utils import rotation_3d_in_axis, DUMP
from .csrc.wrapper import msmv_sampling, msmv_sampling_pytorch, msmv_sampling_onnx, MSMV_CUDA


def make_sample_points(query_bbox, offset, pc_range):
    '''
    Build 3D sampling points by placing per-query offsets inside each decoded box.

    Each offset is scaled by the box size, rotated by the box angle via
    rotation_3d_in_axis, and then translated to the box center.

    query_bbox: [B, Q, 10], decoded with decode_bbox(query_bbox, pc_range) into [B, Q, 9]
    offset: [B, Q, num_points, 4], normalized by stride; only channels 0:3 are used
    pc_range: forwarded unchanged to decode_bbox
    returns: [B, Q, P, 3] sampling points
    '''
    boxes = decode_bbox(query_bbox, pc_range)  # [B, Q, 9]
    center, size, yaw = boxes[..., 0:3], boxes[..., 3:6], boxes[..., 6:7]

    # scale the unit offsets by box size (broadcast over the point axis)
    local_xyz = size.unsqueeze(2) * offset[..., 0:3]  # [B, Q, P, 3]
    # rotate into the box heading; yaw is [B, Q, 1] per the slice above
    local_xyz = rotation_3d_in_axis(local_xyz, yaw)  # [B, Q, P, 3]

    return center.unsqueeze(2) + local_xyz  # [B, Q, P, 3]


def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5, num_views=6):
    """
    Project 3D sampling points into every camera view, keep a single view per
    point, and sample multi-scale image features at the projected locations.

    Args:
        sample_points: 3D sampling points in shape [B, Q, T, G, P, 3]
        mlvl_feats: list of multi-scale features from neck, each in shape [B*T*G, C, N, H, W]
        scale_weights: weights for multi-scale aggregation, [B, Q, G, T, P, L]
        lidar2img: 4x4 projection matrix in shape [B, TN, 4, 4]
        image_h: image height used to normalize the projected v coordinate
        image_w: image width used to normalize the projected u coordinate
        eps: lower bound applied to the projected depth before the perspective
            division; points whose depth is <= eps are treated as invalid
        num_views: N, number of camera views per frame. Defaults to 6
            (nuScenes); lidar2img must contain T * num_views matrices.
    Returns:
        sampled features in shape [B, Q, G, T*P, C]
    Symbol meaning:
        B: batch size
        Q: num of queries
        T: num of frames
        G: num of groups (we follow the group sampling mechanism of AdaMixer)
        P: num of sampling points per frame per group
        N: num of views (six for nuScenes)
        L: num of layers of feature pyramid (typically it is 4: C2, C3, C4, C5)
    """
    B, Q, T, G, P, _ = sample_points.shape  # [B, Q, T, G, P, 3]
    N = num_views
    GP = G * P

    sample_points = sample_points.reshape(B, Q, T, GP, 3)

    # get the projection matrix, broadcast over queries and points
    lidar2img = lidar2img[:, :, None, None, :, :]  # [B, TN, 1, 1, 4, 4]
    lidar2img = lidar2img.expand(B, T * N, Q, GP, 4, 4)
    lidar2img = lidar2img.reshape(B, T, N, Q, GP, 4, 4)

    # homogeneous coordinates, replicated across the N views
    ones = torch.ones_like(sample_points[..., :1])
    sample_points = torch.cat([sample_points, ones], dim=-1)  # [B, Q, T, GP, 4]
    sample_points = sample_points[:, :, None, ..., None]  # [B, Q, 1, T, GP, 4, 1]
    sample_points = sample_points.expand(B, Q, N, T, GP, 4, 1)
    sample_points = sample_points.transpose(1, 3)  # [B, T, N, Q, GP, 4, 1]

    # project 3d sampling points to N views
    sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1)  # [B, T, N, Q, GP, 4]

    # homo coord -> pixel coord; clamp depth so the division never hits zero
    homo = sample_points_cam[..., 2:3]
    homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)
    sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero  # [B, T, N, Q, GP, 2]

    # normalize to [0, 1]
    sample_points_cam[..., 0] /= image_w
    sample_points_cam[..., 1] /= image_h

    # a point is valid in a view if it is in front of the camera and inside the image
    valid_mask = ((homo > eps) \
        & (sample_points_cam[..., 1:2] > 0.0)
        & (sample_points_cam[..., 1:2] < 1.0)
        & (sample_points_cam[..., 0:1] > 0.0)
        & (sample_points_cam[..., 0:1] < 1.0)
    ).squeeze(-1).float()  # [B, T, N, Q, GP]

    # for visualization only
    if DUMP.enabled:
        torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1).cpu(),
                   '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
        torch.save(valid_mask.cpu(),
                   '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

    valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]

    # we only keep at most one valid sampling point, see https://zhuanlan.zhihu.com/p/654821380
    # argmax picks the first valid view, or view 0 when no view is valid
    i_view = torch.argmax(valid_mask, dim=-1, keepdim=True)  # [B, T, Q, GP, 1]

    # Select the UV of the chosen view. A single gather serves both backends
    # and avoids materializing per-axis index tensors for fancy indexing.
    i_view_uv = i_view.unsqueeze(-1).expand(B, T, Q, GP, 1, 2)
    sample_points_cam = torch.gather(sample_points_cam, 4, i_view_uv)  # [B, T, Q, GP, 1, 2]

    # Reorganize scale weights to [B*G*T, Q, P, L] (shared by both backends)
    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
    scale_weights = scale_weights.reshape(B * G * T, Q, P, -1)

    if MSMV_CUDA:
        # treat the view index as a new axis for grid_sample, normalised to [0, 1];
        # max(..., 1) guards the single-view case against division by zero
        view_coord = i_view[..., None].float() / max(N - 1, 1)  # [B, T, Q, GP, 1, 1]
        sample_points_cam = torch.cat([sample_points_cam, view_coord], dim=-1)  # [B, T, Q, GP, 1, 3]

        sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
        sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]
        sample_points_cam = sample_points_cam.reshape(B * T * G, Q, P, 3)

        final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
    else:
        # ONNX-compatible path: 4D grid_sample with an explicit view index (no custom CUDA ops)
        sample_points_cam = sample_points_cam.squeeze(4).reshape(B, T, Q, G, P, 2)
        sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5)  # [B, T, G, Q, P, 2]
        sample_points_cam = sample_points_cam.reshape(B * T * G, Q, P, 2)

        # Reorganize view_idx to [B*T*G, Q, P]
        i_view = i_view.squeeze(4).reshape(B, T, Q, G, P)
        i_view = i_view.permute(0, 1, 3, 2, 4).reshape(B * T * G, Q, P)

        final = msmv_sampling_onnx(mlvl_feats, sample_points_cam, i_view, scale_weights)

    # reorganize the sampled features
    C = final.shape[2]  # [BTG, Q, C, P]
    final = final.reshape(B, T, G, Q, C, P)
    final = final.permute(0, 3, 2, 1, 5, 4)  # [B, Q, G, T, P, C]
    final = final.flatten(3, 4)  # [B, Q, G, TP, C]

    return final