import torch
import torch.nn.functional as F
# Try to load the fused CUDA kernels for multi-scale multi-view sampling.
# Two variants are compiled: one for 4 pyramid levels (C2-C5) and one for
# 5 levels (C2-C6).  MSMV_CUDA records whether they are usable.
try:
    from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
    from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
    MSMV_CUDA = True
except ImportError as e:
    # Best-effort fallback: msmv_sampling() will route to the pure-PyTorch
    # implementation instead, so only performance is affected.
    print('Warning: failed to load one or more CUDA extensions, performance may be hurt.')
    print('Error message:', e)
    MSMV_CUDA = False
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
    """Pure-PyTorch multi-scale multi-view sampling via 5D ``F.grid_sample``.

    Args:
        mlvl_feats: list of 5D feature volumes, one per pyramid level, each
            of shape [B, C, D1, D2, D3] (presumably [B, C, N, H, W] as in
            ``msmv_sampling_onnx`` — TODO confirm axis meaning against the
            CUDA kernels).
        sampling_locations: [B, Q, P, 3] normalised coordinates in [0, 1].
        scale_weights: [B, Q, P, L] per-level weights, L == len(mlvl_feats).

    Returns:
        [B, Q, C, P] weighted sum of the trilinear samples over all levels.
    """
    assert scale_weights.shape[-1] == len(mlvl_feats)
    B, C, _, _, _ = mlvl_feats[0].shape
    _, Q, P, _ = sampling_locations.shape
    # Map [0, 1] coordinates to the [-1, 1] range expected by grid_sample.
    sampling_locations = sampling_locations * 2 - 1
    sampling_locations = sampling_locations[:, :, :, None, :]  # [B, Q, P, 1, 3]
    # Accumulate in the features' dtype: a default-float32 buffer breaks the
    # in-place += for fp16/fp64 features (matches msmv_sampling_onnx, which
    # already passes dtype explicitly).
    final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device,
                        dtype=mlvl_feats[0].dtype)
    for lvl, feat in enumerate(mlvl_feats):
        out = F.grid_sample(
            feat, sampling_locations, mode='bilinear',
            padding_mode='zeros', align_corners=True,
        )[..., 0]  # [B, C, Q, P]
        out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)
        final += out
    return final.permute(0, 2, 1, 3)
class MSMVSamplingC2345(torch.autograd.Function):
    """Autograd bridge to the fused CUDA kernel for 4 pyramid levels (C2-C5)."""

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
        # Stash every forward input; the backward kernel needs all of them.
        inputs = (feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
        ctx.save_for_backward(*inputs)
        assert callable(_ms_deform_attn_cuda_c2345_forward)
        return _ms_deform_attn_cuda_c2345_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        assert callable(_ms_deform_attn_cuda_c2345_backward)
        # Kernel yields one gradient per forward input, in the same order:
        # four feature grads, then sampling-location and scale-weight grads.
        grads = _ms_deform_attn_cuda_c2345_backward(grad_output.contiguous(), *saved)
        return tuple(grads)
class MSMVSamplingC23456(torch.autograd.Function):
    """Autograd bridge to the fused CUDA kernel for 5 pyramid levels (C2-C6)."""

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
        # Stash every forward input; the backward kernel needs all of them.
        inputs = (feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
        ctx.save_for_backward(*inputs)
        assert callable(_ms_deform_attn_cuda_c23456_forward)
        return _ms_deform_attn_cuda_c23456_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        saved = ctx.saved_tensors
        assert callable(_ms_deform_attn_cuda_c23456_backward)
        # Kernel yields one gradient per forward input, in the same order:
        # five feature grads, then sampling-location and scale-weight grads.
        grads = _ms_deform_attn_cuda_c23456_backward(grad_output.contiguous(), *saved)
        return tuple(grads)
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
    """Dispatch to a fused CUDA kernel when available, else the PyTorch path."""
    if MSMV_CUDA:
        num_levels = len(mlvl_feats)
        if num_levels == 4:
            return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
        if num_levels == 5:
            return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
    # No kernel for this level count, or CUDA extensions failed to load.
    return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
def msmv_sampling_onnx(mlvl_feats, uv, view_idx, scale_weights):
    """ONNX-compatible multi-scale multi-view sampling using 4D F.grid_sample.

    Avoids the 5D volumetric grid_sample of msmv_sampling_pytorch: each of the
    N views is sampled with a 4D grid_sample, then torch.gather picks the view
    each query point belongs to. All ops are in ONNX opset 16.

    Args:
        mlvl_feats: list of [BTG, C, N, H, W] channel-first feature maps
        uv: [BTG, Q, P, 2] normalised (u, v) in [0, 1]
        view_idx: [BTG, Q, P] integer camera-view indices
        scale_weights:[BTG, Q, P, L] softmax weights over pyramid levels
    Returns:
        [BTG, Q, C, P]
    """
    BTG, C, N, _, _ = mlvl_feats[0].shape
    _, Q, P, _ = uv.shape

    # grid_sample wants coordinates in [-1, 1].
    grid = uv * 2.0 - 1.0  # [BTG, Q, P, 2]
    # Replicate the grid once per view: [BTG*N, Q, P, 2]. expand+contiguous+
    # reshape maps to ONNX Expand, which the CoreML EP handles better than the
    # Tile emitted by repeat_interleave.
    grid = grid.unsqueeze(1).expand(BTG, N, Q, P, 2).contiguous().reshape(BTG * N, Q, P, 2)

    # Index tensor for selecting the camera view along dim 2: [BTG, C, 1, Q, P]
    gather_idx = view_idx[:, None, None, :, :].expand(BTG, C, 1, Q, P)

    acc = torch.zeros(BTG, C, Q, P, device=mlvl_feats[0].device, dtype=mlvl_feats[0].dtype)
    for level, feat in enumerate(mlvl_feats):
        H_l, W_l = feat.shape[-2], feat.shape[-1]
        # Fold views into the batch: [BTG, C, N, H, W] -> [BTG*N, C, H, W]
        per_view = feat.permute(0, 2, 1, 3, 4).reshape(BTG * N, C, H_l, W_l)
        # 4D bilinear sampling: [BTG*N, C, Q, P]
        sampled = F.grid_sample(per_view, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
        # Unfold views and move them after channels: [BTG, C, N, Q, P]
        sampled = sampled.reshape(BTG, N, C, Q, P).permute(0, 2, 1, 3, 4)
        # Select the target view per (query, point): [BTG, C, Q, P]
        selected = torch.gather(sampled, 2, gather_idx).squeeze(2)
        # Weighted accumulation over pyramid levels.
        acc = acc + selected * scale_weights[..., level].reshape(BTG, 1, Q, P)
    return acc.permute(0, 2, 1, 3)  # [BTG, Q, C, P]