Alfred Liu committed on
Commit
fe8222e
·
1 Parent(s): c2f1911

Support native sampling operator

Browse files
models/csrc/wrapper.py CHANGED
@@ -1,7 +1,14 @@
1
  import torch
2
  import torch.nn.functional as F
3
- from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
4
- from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
 
 
 
 
 
 
 
5
 
6
 
7
  def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
@@ -12,7 +19,7 @@ def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
12
  """
13
  assert scale_weights.shape[-1] == len(mlvl_feats)
14
 
15
- B, _, _, _, C = mlvl_feats[0].shape
16
  _, Q, P, _ = sampling_locations.shape
17
 
18
  sampling_locations = sampling_locations * 2 - 1
@@ -21,7 +28,6 @@ def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
21
  final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
22
 
23
  for lvl, feat in enumerate(mlvl_feats):
24
- feat = feat.permute(0, 4, 1, 2, 3)
25
  out = F.grid_sample(
26
  feat, sampling_locations, mode='bilinear',
27
  padding_mode='zeros', align_corners=True,
@@ -79,9 +85,9 @@ class MSMVSamplingC23456(torch.autograd.Function):
79
 
80
 
81
  def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
82
- if len(mlvl_feats) == 4:
83
  return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
84
- elif len(mlvl_feats) == 5:
85
  return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
86
  else:
87
  return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
 
1
  import torch
2
  import torch.nn.functional as F
3
+
4
+ try:
5
+ from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
6
+ from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
7
+ MSMV_CUDA = True
8
+ except ImportError as e:
9
+ print('Warning: failed to load one or more CUDA extensions, performance may be hurt.')
10
+ print('Error message:', e)
11
+ MSMV_CUDA = False
12
 
13
 
14
  def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
 
19
  """
20
  assert scale_weights.shape[-1] == len(mlvl_feats)
21
 
22
+ B, C, _, _, _ = mlvl_feats[0].shape
23
  _, Q, P, _ = sampling_locations.shape
24
 
25
  sampling_locations = sampling_locations * 2 - 1
 
28
  final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
29
 
30
  for lvl, feat in enumerate(mlvl_feats):
 
31
  out = F.grid_sample(
32
  feat, sampling_locations, mode='bilinear',
33
  padding_mode='zeros', align_corners=True,
 
85
 
86
 
87
  def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
88
+ if len(mlvl_feats) == 4 and MSMV_CUDA:
89
  return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
90
+ elif len(mlvl_feats) == 5 and MSMV_CUDA:
91
  return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
92
  else:
93
  return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
models/sparsebev_transformer.py CHANGED
@@ -10,6 +10,7 @@ from .bbox.utils import decode_bbox
10
  from .utils import inverse_sigmoid, DUMP
11
  from .sparsebev_sampling import sampling_4d, make_sample_points
12
  from .checkpoint import checkpoint as cp
 
13
 
14
 
15
  @TRANSFORMER.register_module()
@@ -73,8 +74,14 @@ class SparseBEVTransformerDecoder(BaseModule):
73
  B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
74
  N, T, G, C = 6, TN // 6, 4, GC // 4
75
  feat = feat.reshape(B, T, N, G, C, H, W)
76
- feat = feat.permute(0, 1, 3, 2, 5, 6, 4) # [B, T, G, N, H, W, C]
77
- feat = feat.reshape(B*T*G, N, H, W, C) # [BTG, C, N, H, W]
 
 
 
 
 
 
78
  mlvl_feats[lvl] = feat.contiguous()
79
 
80
  for i in range(self.num_layers):
 
10
  from .utils import inverse_sigmoid, DUMP
11
  from .sparsebev_sampling import sampling_4d, make_sample_points
12
  from .checkpoint import checkpoint as cp
13
+ from .csrc.wrapper import MSMV_CUDA
14
 
15
 
16
  @TRANSFORMER.register_module()
 
74
  B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
75
  N, T, G, C = 6, TN // 6, 4, GC // 4
76
  feat = feat.reshape(B, T, N, G, C, H, W)
77
+
78
+ if MSMV_CUDA: # Our CUDA operator requires channel_last
79
+ feat = feat.permute(0, 1, 3, 2, 5, 6, 4) # [B, T, G, N, H, W, C]
80
+ feat = feat.reshape(B*T*G, N, H, W, C)
81
+ else: # Torch's grid_sample requires channel_first
82
+ feat = feat.permute(0, 1, 3, 4, 2, 5, 6) # [B, T, G, C, N, H, W]
83
+ feat = feat.reshape(B*T*G, C, N, H, W)
84
+
85
  mlvl_feats[lvl] = feat.contiguous()
86
 
87
  for i in range(self.num_layers):