# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
# ------------------------------------------------------------------------------------------------
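"""Unit tests for detrex's MultiScaleDeformableAttnFunction: the fused CUDA
operator is checked against a pure-PyTorch reference implementation, for both
forward outputs and numerical gradients."""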
from __future__ import absolute_import, division, print_function

import unittest

import torch
import torch.nn.functional as F
from torch.autograd import gradcheck

from detrex.layers.multi_scale_deform_attn import MultiScaleDeformableAttnFunction
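
# Test configuration (following the Deformable-DETR shape convention):
#   N = batch size, M = attention heads, D = channels per head,
#   Lq = query length, L = feature levels, P = sampling points per level.
# `shapes` holds each level's (H, W); `S` is the total number of value
# positions across levels (6 * 4 + 3 * 2 = 30).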
N, M, D = 1, 2, 2
Lq, L, P = 2, 2, 2
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
S = sum([(H * W).item() for H, W in shapes])


class TestMsDeformAttn(unittest.TestCase):
    def ms_deform_attn_core_pytorch(
        self, value, value_spatial_shapes, sampling_locations, attention_weights
    ):
        # Pure-PyTorch reference implementation, for debugging and testing only;
        # the fused CUDA version should be used in practice.
        N_, S_, M_, D_ = value.shape
        _, Lq_, M_, L_, P_, _ = sampling_locations.shape
        value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
        sampling_grids = 2 * sampling_locations - 1
        sampling_value_list = []
        for lid_, (H_, W_) in enumerate(value_spatial_shapes):
            # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
            value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
            # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
            sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
            # N_*M_, D_, Lq_, P_
            sampling_value_l_ = F.grid_sample(
                value_l_,
                sampling_grid_l_,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=False,
            )
            sampling_value_list.append(sampling_value_l_)
        # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
        attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
        output = (
            (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
            .sum(-1)
            .view(N_, M_ * D_, Lq_)
        )
        return output.transpose(1, 2).contiguous()
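
    # gradcheck verifies the CUDA kernel's analytical gradients against finite
    # differences; all floating-point inputs are cast to double precision, which
    # gradcheck needs for reliable numerical comparison.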
    def check_gradient_numerical(
        self, channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True
    ):
        value = torch.rand(N, S, M, channels).cuda() * 0.01
        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
        im2col_step = 2
        func = MultiScaleDeformableAttnFunction.apply
        value.requires_grad = grad_value
        sampling_locations.requires_grad = grad_sampling_loc
        attention_weights.requires_grad = grad_attn_weight
        gradok = gradcheck(
            func,
            (
                value.double(),
                shapes,
                level_start_index,
                sampling_locations.double(),
                attention_weights.double(),
                im2col_step,
            ),
        )
        return gradok
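
    # Forward check: the CUDA kernel should agree with the grid_sample-based
    # reference above to floating-point tolerance in double precision.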
    def test_forward_equal_with_pytorch_double(self):
        value = torch.rand(N, S, M, D).cuda() * 0.01
        sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
        attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
        attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
        im2col_step = 2
        output_pytorch = (
            self.ms_deform_attn_core_pytorch(
                value.double(), shapes, sampling_locations.double(), attention_weights.double()
            )
            .detach()
            .cpu()
        )
        output_cuda = (
            MultiScaleDeformableAttnFunction.apply(
                value.double(),
                shapes,
                level_start_index,
                sampling_locations.double(),
                attention_weights.double(),
                im2col_step,
            )
            .detach()
            .cpu()
        )
        self.assertTrue(torch.allclose(output_cuda, output_pytorch))
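
    # The channel counts mix aligned sizes (32, 64) with awkward ones (30, 71,
    # 1025); presumably the latter exercise the kernel's generic fallback paths
    # (an assumption about the CUDA implementation, not verified here).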
    def test_gradient_numerical(self):
        for channels in [30, 32, 64, 71, 1025]:
            self.assertTrue(self.check_gradient_numerical(channels, True, True, True))
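

# Minimal sketch of a direct entry point, assuming the tests are run as a
# script on a CUDA-capable machine (all tensors above are created with .cuda()).
if __name__ == "__main__":
    unittest.main()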