Alfred Liu committed on
Commit ·
fe8222e
1
Parent(s): c2f1911
Support native sampling operator
Browse files
- models/csrc/wrapper.py +12 -6
- models/sparsebev_transformer.py +9 -2
models/csrc/wrapper.py
CHANGED
|
@@ -1,7 +1,14 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
|
|
@@ -12,7 +19,7 @@ def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
|
|
| 12 |
"""
|
| 13 |
assert scale_weights.shape[-1] == len(mlvl_feats)
|
| 14 |
|
| 15 |
-
B,
|
| 16 |
_, Q, P, _ = sampling_locations.shape
|
| 17 |
|
| 18 |
sampling_locations = sampling_locations * 2 - 1
|
|
@@ -21,7 +28,6 @@ def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
|
|
| 21 |
final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
|
| 22 |
|
| 23 |
for lvl, feat in enumerate(mlvl_feats):
|
| 24 |
-
feat = feat.permute(0, 4, 1, 2, 3)
|
| 25 |
out = F.grid_sample(
|
| 26 |
feat, sampling_locations, mode='bilinear',
|
| 27 |
padding_mode='zeros', align_corners=True,
|
|
@@ -79,9 +85,9 @@ class MSMVSamplingC23456(torch.autograd.Function):
|
|
| 79 |
|
| 80 |
|
| 81 |
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
|
| 82 |
-
if len(mlvl_feats) == 4:
|
| 83 |
return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
|
| 84 |
-
elif len(mlvl_feats) == 5:
|
| 85 |
return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
|
| 86 |
else:
|
| 87 |
return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
try:
|
| 5 |
+
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
|
| 6 |
+
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
|
| 7 |
+
MSMV_CUDA = True
|
| 8 |
+
except ImportError as e:
|
| 9 |
+
print('Warning: failed to load one or more CUDA extensions, performance may be hurt.')
|
| 10 |
+
print('Error message:', e)
|
| 11 |
+
MSMV_CUDA = False
|
| 12 |
|
| 13 |
|
| 14 |
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
|
|
|
|
| 19 |
"""
|
| 20 |
assert scale_weights.shape[-1] == len(mlvl_feats)
|
| 21 |
|
| 22 |
+
B, C, _, _, _ = mlvl_feats[0].shape
|
| 23 |
_, Q, P, _ = sampling_locations.shape
|
| 24 |
|
| 25 |
sampling_locations = sampling_locations * 2 - 1
|
|
|
|
| 28 |
final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
|
| 29 |
|
| 30 |
for lvl, feat in enumerate(mlvl_feats):
|
|
|
|
| 31 |
out = F.grid_sample(
|
| 32 |
feat, sampling_locations, mode='bilinear',
|
| 33 |
padding_mode='zeros', align_corners=True,
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
|
| 88 |
+
if len(mlvl_feats) == 4 and MSMV_CUDA:
|
| 89 |
return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
|
| 90 |
+
elif len(mlvl_feats) == 5 and MSMV_CUDA:
|
| 91 |
return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
|
| 92 |
else:
|
| 93 |
return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
|
models/sparsebev_transformer.py
CHANGED
|
@@ -10,6 +10,7 @@ from .bbox.utils import decode_bbox
|
|
| 10 |
from .utils import inverse_sigmoid, DUMP
|
| 11 |
from .sparsebev_sampling import sampling_4d, make_sample_points
|
| 12 |
from .checkpoint import checkpoint as cp
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
@TRANSFORMER.register_module()
|
|
@@ -73,8 +74,14 @@ class SparseBEVTransformerDecoder(BaseModule):
|
|
| 73 |
B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
|
| 74 |
N, T, G, C = 6, TN // 6, 4, GC // 4
|
| 75 |
feat = feat.reshape(B, T, N, G, C, H, W)
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
mlvl_feats[lvl] = feat.contiguous()
|
| 79 |
|
| 80 |
for i in range(self.num_layers):
|
|
|
|
| 10 |
from .utils import inverse_sigmoid, DUMP
|
| 11 |
from .sparsebev_sampling import sampling_4d, make_sample_points
|
| 12 |
from .checkpoint import checkpoint as cp
|
| 13 |
+
from .csrc.wrapper import MSMV_CUDA
|
| 14 |
|
| 15 |
|
| 16 |
@TRANSFORMER.register_module()
|
|
|
|
| 74 |
B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
|
| 75 |
N, T, G, C = 6, TN // 6, 4, GC // 4
|
| 76 |
feat = feat.reshape(B, T, N, G, C, H, W)
|
| 77 |
+
|
| 78 |
+
if MSMV_CUDA: # Our CUDA operator requires channel_last
|
| 79 |
+
feat = feat.permute(0, 1, 3, 2, 5, 6, 4) # [B, T, G, N, H, W, C]
|
| 80 |
+
feat = feat.reshape(B*T*G, N, H, W, C)
|
| 81 |
+
else: # Torch's grid_sample requires channel_first
|
| 82 |
+
feat = feat.permute(0, 1, 3, 4, 2, 5, 6) # [B, T, G, C, N, H, W]
|
| 83 |
+
feat = feat.reshape(B*T*G, C, N, H, W)
|
| 84 |
+
|
| 85 |
mlvl_feats[lvl] = feat.contiguous()
|
| 86 |
|
| 87 |
for i in range(self.num_layers):
|