CU1-X / rfdetr /models /ops /functions /ms_deform_attn_func.py
AI-DrivenTesting's picture
init
77da9e2
raw
history blame
2.67 kB
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------------------------------
# Modified from Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------
"""
ms_deform_attn_func
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
""""for debug and test only, need to use cuda version instead
"""
# B, n_heads, head_dim, N
B, n_heads, head_dim, _ = value.shape
_, Len_q, n_heads, L, P, _ = sampling_locations.shape
value_list = value.split([H * W for H, W in value_spatial_shapes], dim=3)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for lid_, (H, W) in enumerate(value_spatial_shapes):
# B, n_heads, head_dim, H, W
value_l_ = value_list[lid_].view(B * n_heads, head_dim, H, W)
# B, Len_q, n_heads, P, 2 -> B, n_heads, Len_q, P, 2 -> B*n_heads, Len_q, P, 2
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
# B*n_heads, head_dim, Len_q, P
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
mode='bilinear', padding_mode='zeros', align_corners=False)
sampling_value_list.append(sampling_value_l_)
# (B, Len_q, n_heads, L * P) -> (B, n_heads, Len_q, L, P) -> (B*n_heads, 1, Len_q, L*P)
attention_weights = attention_weights.transpose(1, 2).reshape(B * n_heads, 1, Len_q, L * P)
# B*n_heads, head_dim, Len_q, L*P
sampling_value_list = torch.stack(sampling_value_list, dim=-2).flatten(-2)
output = (sampling_value_list * attention_weights).sum(-1).view(B, n_heads * head_dim, Len_q)
return output.transpose(1, 2).contiguous()