harness / diffs /41470.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp b/src/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
deleted file mode 100644
index 388a73d22d4c..000000000000
--- a/src/transformers/kernels/deta/cpu/ms_deform_attn_cpu.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include <vector>
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-
-
-at::Tensor
-ms_deform_attn_cpu_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step)
-{
- AT_ERROR("Not implement on cpu");
-}
-
-std::vector<at::Tensor>
-ms_deform_attn_cpu_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step)
-{
- AT_ERROR("Not implement on cpu");
-}
diff --git a/src/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h b/src/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
deleted file mode 100644
index 7eac8c8bcd1b..000000000000
--- a/src/transformers/kernels/deta/cpu/ms_deform_attn_cpu.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#pragma once
-#include <torch/extension.h>
-
-at::Tensor
-ms_deform_attn_cpu_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step);
-
-std::vector<at::Tensor>
-ms_deform_attn_cpu_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step);
-
diff --git a/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu b/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
deleted file mode 100644
index 8ea1d7fabe26..000000000000
--- a/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cu
+++ /dev/null
@@ -1,156 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include <vector>
-#include "cuda/ms_deform_im2col_cuda.cuh"
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-#pragma once
-#include <torch/extension.h>
-
-
-at::Tensor ms_deform_attn_cuda_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step)
-{
- AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
- AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
- AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
- AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
- AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-
- AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
- AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
- AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
- AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
- AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
-
- const int batch = value.size(0);
- const int spatial_size = value.size(1);
- const int num_heads = value.size(2);
- const int channels = value.size(3);
-
- const int num_levels = spatial_shapes.size(0);
-
- const int num_query = sampling_loc.size(1);
- const int num_point = sampling_loc.size(4);
-
- const int im2col_step_ = std::min(batch, im2col_step);
-
- AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
- auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
-
- const int batch_n = im2col_step_;
- auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
- auto per_value_size = spatial_size * num_heads * channels;
- auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
- auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
- for (int n = 0; n < batch/im2col_step_; ++n)
- {
- auto columns = output_n.select(0, n);
- AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
- ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
- value.data<scalar_t>() + n * im2col_step_ * per_value_size,
- spatial_shapes.data<int64_t>(),
- level_start_index.data<int64_t>(),
- sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
- attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
- batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
- columns.data<scalar_t>());
-
- }));
- }
-
- output = output.view({batch, num_query, num_heads*channels});
-
- return output;
-}
-
-
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step)
-{
-
- AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
- AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
- AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
- AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
- AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
- AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
- AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
- AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
- AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
- AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
- AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
- AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
-
- const int batch = value.size(0);
- const int spatial_size = value.size(1);
- const int num_heads = value.size(2);
- const int channels = value.size(3);
-
- const int num_levels = spatial_shapes.size(0);
-
- const int num_query = sampling_loc.size(1);
- const int num_point = sampling_loc.size(4);
-
- const int im2col_step_ = std::min(batch, im2col_step);
-
- AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
- auto grad_value = at::zeros_like(value);
- auto grad_sampling_loc = at::zeros_like(sampling_loc);
- auto grad_attn_weight = at::zeros_like(attn_weight);
-
- const int batch_n = im2col_step_;
- auto per_value_size = spatial_size * num_heads * channels;
- auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
- auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
- auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-
- for (int n = 0; n < batch/im2col_step_; ++n)
- {
- auto grad_output_g = grad_output_n.select(0, n);
- AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
- ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
- grad_output_g.data<scalar_t>(),
- value.data<scalar_t>() + n * im2col_step_ * per_value_size,
- spatial_shapes.data<int64_t>(),
- level_start_index.data<int64_t>(),
- sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
- attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
- batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
- grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
- grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
- grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
- }));
- }
-
- return {
- grad_value, grad_sampling_loc, grad_attn_weight
- };
-}
diff --git a/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh b/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
deleted file mode 100644
index 34f8ae9cb77b..000000000000
--- a/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.cuh
+++ /dev/null
@@ -1,1467 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include <vector>
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-#include <cstdio>
-#include <algorithm>
-#include <cstring>
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-
-#include <THC/THCAtomics.cuh>
-
-#define CUDA_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
- i < (n); \
- i += blockDim.x * gridDim.x)
-
-
-at::Tensor ms_deform_attn_cuda_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step)
-{
- AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
- AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
- AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
- AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
- AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-
- AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
- AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
- AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
- AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
- AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
-
- const int batch = value.size(0);
- const int spatial_size = value.size(1);
- const int num_heads = value.size(2);
- const int channels = value.size(3);
-
- const int num_levels = spatial_shapes.size(0);
-
- const int num_query = sampling_loc.size(1);
- const int num_point = sampling_loc.size(4);
-
- const int im2col_step_ = std::min(batch, im2col_step);
-
- AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
- auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
-
- const int batch_n = im2col_step_;
- auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
- auto per_value_size = spatial_size * num_heads * channels;
- auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
- auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
- for (int n = 0; n < batch/im2col_step_; ++n)
- {
- auto columns = output_n.select(0, n);
- AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
- ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
- value.data<scalar_t>() + n * im2col_step_ * per_value_size,
- spatial_shapes.data<int64_t>(),
- level_start_index.data<int64_t>(),
- sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
- attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
- batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
- columns.data<scalar_t>());
-
- }));
- }
-
- output = output.view({batch, num_query, num_heads*channels});
-
- return output;
-}
-
-
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step)
-{
-
- AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
- AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
- AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
- AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
- AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
- AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
- AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
- AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
- AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
- AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
- AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
- AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
-
- const int batch = value.size(0);
- const int spatial_size = value.size(1);
- const int num_heads = value.size(2);
- const int channels = value.size(3);
-
- const int num_levels = spatial_shapes.size(0);
-
- const int num_query = sampling_loc.size(1);
- const int num_point = sampling_loc.size(4);
-
- const int im2col_step_ = std::min(batch, im2col_step);
-
- AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
- auto grad_value = at::zeros_like(value);
- auto grad_sampling_loc = at::zeros_like(sampling_loc);
- auto grad_attn_weight = at::zeros_like(attn_weight);
-
- const int batch_n = im2col_step_;
- auto per_value_size = spatial_size * num_heads * channels;
- auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
- auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
- auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-
- for (int n = 0; n < batch/im2col_step_; ++n)
- {
- auto grad_output_g = grad_output_n.select(0, n);
- AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
- ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
- grad_output_g.data<scalar_t>(),
- value.data<scalar_t>() + n * im2col_step_ * per_value_size,
- spatial_shapes.data<int64_t>(),
- level_start_index.data<int64_t>(),
- sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
- attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
- batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
- grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
- grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
- grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
- }));
- }
-
- return {
- grad_value, grad_sampling_loc, grad_attn_weight
- };
-}
-
-const int CUDA_NUM_THREADS = 1024;
-inline int GET_BLOCKS(const int N, const int num_threads)
-{
- return (N + num_threads - 1) / num_threads;
-}
-
-
-template <typename scalar_t>
-__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c)
-{
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
-
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
-
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
-
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- }
-
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
-
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- return val;
-}
-
-
-template <typename scalar_t>
-__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c,
- const scalar_t &top_grad,
- const scalar_t &attn_weight,
- scalar_t* &grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
-{
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
-
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
-
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
-
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- const scalar_t top_grad_value = top_grad * attn_weight;
- scalar_t grad_h_weight = 0, grad_w_weight = 0;
-
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- grad_h_weight -= hw * v1;
- grad_w_weight -= hh * v1;
- atomicAdd(grad_value+ptr1, w1*top_grad_value);
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- grad_h_weight -= lw * v2;
- grad_w_weight += hh * v2;
- atomicAdd(grad_value+ptr2, w2*top_grad_value);
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- grad_h_weight += hw * v3;
- grad_w_weight -= lh * v3;
- atomicAdd(grad_value+ptr3, w3*top_grad_value);
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- grad_h_weight += lw * v4;
- grad_w_weight += lh * v4;
- atomicAdd(grad_value+ptr4, w4*top_grad_value);
- }
-
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- *grad_attn_weight = top_grad * val;
- *grad_sampling_loc = width * grad_w_weight * top_grad_value;
- *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
-}
-
-
-template <typename scalar_t>
-__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c,
- const scalar_t &top_grad,
- const scalar_t &attn_weight,
- scalar_t* &grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
-{
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
-
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
-
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
-
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- const scalar_t top_grad_value = top_grad * attn_weight;
- scalar_t grad_h_weight = 0, grad_w_weight = 0;
-
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- grad_h_weight -= hw * v1;
- grad_w_weight -= hh * v1;
- atomicAdd(grad_value+ptr1, w1*top_grad_value);
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- grad_h_weight -= lw * v2;
- grad_w_weight += hh * v2;
- atomicAdd(grad_value+ptr2, w2*top_grad_value);
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- grad_h_weight += hw * v3;
- grad_w_weight -= lh * v3;
- atomicAdd(grad_value+ptr3, w3*top_grad_value);
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- grad_h_weight += lw * v4;
- grad_w_weight += lh * v4;
- atomicAdd(grad_value+ptr4, w4*top_grad_value);
- }
-
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- atomicAdd(grad_attn_weight, top_grad * val);
- atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
- atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
-}
-
-
-template <typename scalar_t>
-__global__ void ms_deformable_im2col_gpu_kernel(const int n,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *data_col)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- scalar_t *data_col_ptr = data_col + index;
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- scalar_t col = 0;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
-
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
- }
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- }
- }
- *data_col_ptr = col;
- }
-}
-
-template <typename scalar_t, unsigned int blockSize>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
- __shared__ scalar_t cache_grad_attn_weight[blockSize];
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
- if (tid == 0)
- {
- scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
- int sid=2;
- for (unsigned int tid = 1; tid < blockSize; ++tid)
- {
- _grad_w += cache_grad_sampling_loc[sid];
- _grad_h += cache_grad_sampling_loc[sid + 1];
- _grad_a += cache_grad_attn_weight[tid];
- sid += 2;
- }
-
-
- *grad_sampling_loc = _grad_w;
- *(grad_sampling_loc + 1) = _grad_h;
- *grad_attn_weight = _grad_a;
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t, unsigned int blockSize>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
- __shared__ scalar_t cache_grad_attn_weight[blockSize];
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
-
- for (unsigned int s=blockSize/2; s>0; s>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- }
- __syncthreads();
- }
-
- if (tid == 0)
- {
- *grad_sampling_loc = cache_grad_sampling_loc[0];
- *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
- *grad_attn_weight = cache_grad_attn_weight[0];
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
- if (tid == 0)
- {
- scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
- int sid=2;
- for (unsigned int tid = 1; tid < blockDim.x; ++tid)
- {
- _grad_w += cache_grad_sampling_loc[sid];
- _grad_h += cache_grad_sampling_loc[sid + 1];
- _grad_a += cache_grad_attn_weight[tid];
- sid += 2;
- }
-
-
- *grad_sampling_loc = _grad_w;
- *(grad_sampling_loc + 1) = _grad_h;
- *grad_attn_weight = _grad_a;
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
-
- for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- if (tid + (s << 1) < spre)
- {
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
- }
- }
- __syncthreads();
- }
-
- if (tid == 0)
- {
- *grad_sampling_loc = cache_grad_sampling_loc[0];
- *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
- *grad_attn_weight = cache_grad_attn_weight[0];
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
-
- for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- if (tid + (s << 1) < spre)
- {
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
- }
- }
- __syncthreads();
- }
-
- if (tid == 0)
- {
- atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
- atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
- atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear_gm(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- grad_sampling_loc, grad_attn_weight);
- }
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t>
-void ms_deformable_im2col_cuda(cudaStream_t stream,
- const scalar_t* data_value,
- const int64_t* data_spatial_shapes,
- const int64_t* data_level_start_index,
- const scalar_t* data_sampling_loc,
- const scalar_t* data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t* data_col)
-{
- const int num_kernels = batch_size * num_query * num_heads * channels;
- const int num_actual_kernels = batch_size * num_query * num_heads * channels;
- const int num_threads = CUDA_NUM_THREADS;
- ms_deformable_im2col_gpu_kernel<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
- batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
-
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess)
- {
- printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
- }
-
-}
-
-template <typename scalar_t>
-void ms_deformable_col2im_cuda(cudaStream_t stream,
- const scalar_t* grad_col,
- const scalar_t* data_value,
- const int64_t * data_spatial_shapes,
- const int64_t * data_level_start_index,
- const scalar_t * data_sampling_loc,
- const scalar_t * data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t* grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
-{
- const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
- const int num_kernels = batch_size * num_query * num_heads * channels;
- const int num_actual_kernels = batch_size * num_query * num_heads * channels;
- if (channels > 1024)
- {
- if ((channels & 1023) == 0)
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- else
- {
- ms_deformable_col2im_gpu_kernel_gm<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- }
- else{
- switch(channels)
- {
- case 1:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 2:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 4:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 8:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 16:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 32:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 64:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 128:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 256:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 512:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 1024:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- default:
- if (channels < 64)
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- else
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- }
- }
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess)
- {
- printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
- }
-
-}
diff --git a/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h b/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
deleted file mode 100644
index fbcf4543e66b..000000000000
--- a/src/transformers/kernels/deta/cuda/ms_deform_attn_cuda.h
+++ /dev/null
@@ -1,29 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#pragma once
-#include <torch/extension.h>
-
-at::Tensor ms_deform_attn_cuda_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step);
-
-std::vector<at::Tensor> ms_deform_attn_cuda_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step);
diff --git a/src/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh b/src/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
deleted file mode 100644
index c0db0c88c9db..000000000000
--- a/src/transformers/kernels/deta/cuda/ms_deform_im2col_cuda.cuh
+++ /dev/null
@@ -1,1327 +0,0 @@
-/*!
-**************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************
-* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
-* Copyright (c) 2018 Microsoft
-**************************************************************************
-*/
-
-#include <cstdio>
-#include <algorithm>
-#include <cstring>
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-
-#include <THC/THCAtomics.cuh>
-
-#define CUDA_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
- i < (n); \
- i += blockDim.x * gridDim.x)
-
-const int CUDA_NUM_THREADS = 1024;
-inline int GET_BLOCKS(const int N, const int num_threads)
-{
- return (N + num_threads - 1) / num_threads;
-}
-
-
-template <typename scalar_t>
-__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c)
-{
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
-
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
-
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
-
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- }
-
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
-
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- return val;
-}
-
-
-template <typename scalar_t>
-__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c,
- const scalar_t &top_grad,
- const scalar_t &attn_weight,
- scalar_t* &grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
-{
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
-
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
-
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
-
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- const scalar_t top_grad_value = top_grad * attn_weight;
- scalar_t grad_h_weight = 0, grad_w_weight = 0;
-
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- grad_h_weight -= hw * v1;
- grad_w_weight -= hh * v1;
- atomicAdd(grad_value+ptr1, w1*top_grad_value);
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- grad_h_weight -= lw * v2;
- grad_w_weight += hh * v2;
- atomicAdd(grad_value+ptr2, w2*top_grad_value);
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- grad_h_weight += hw * v3;
- grad_w_weight -= lh * v3;
- atomicAdd(grad_value+ptr3, w3*top_grad_value);
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- grad_h_weight += lw * v4;
- grad_w_weight += lh * v4;
- atomicAdd(grad_value+ptr4, w4*top_grad_value);
- }
-
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- *grad_attn_weight = top_grad * val;
- *grad_sampling_loc = width * grad_w_weight * top_grad_value;
- *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
-}
-
-
-template <typename scalar_t>
-__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c,
- const scalar_t &top_grad,
- const scalar_t &attn_weight,
- scalar_t* &grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
-{
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
-
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
-
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
-
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- const scalar_t top_grad_value = top_grad * attn_weight;
- scalar_t grad_h_weight = 0, grad_w_weight = 0;
-
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- grad_h_weight -= hw * v1;
- grad_w_weight -= hh * v1;
- atomicAdd(grad_value+ptr1, w1*top_grad_value);
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- grad_h_weight -= lw * v2;
- grad_w_weight += hh * v2;
- atomicAdd(grad_value+ptr2, w2*top_grad_value);
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- grad_h_weight += hw * v3;
- grad_w_weight -= lh * v3;
- atomicAdd(grad_value+ptr3, w3*top_grad_value);
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- grad_h_weight += lw * v4;
- grad_w_weight += lh * v4;
- atomicAdd(grad_value+ptr4, w4*top_grad_value);
- }
-
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- atomicAdd(grad_attn_weight, top_grad * val);
- atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
- atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
-}
-
-
-template <typename scalar_t>
-__global__ void ms_deformable_im2col_gpu_kernel(const int n,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *data_col)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- scalar_t *data_col_ptr = data_col + index;
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- scalar_t col = 0;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
-
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
- }
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- }
- }
- *data_col_ptr = col;
- }
-}
-
-template <typename scalar_t, unsigned int blockSize>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
- __shared__ scalar_t cache_grad_attn_weight[blockSize];
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
- if (tid == 0)
- {
- scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
- int sid=2;
- for (unsigned int tid = 1; tid < blockSize; ++tid)
- {
- _grad_w += cache_grad_sampling_loc[sid];
- _grad_h += cache_grad_sampling_loc[sid + 1];
- _grad_a += cache_grad_attn_weight[tid];
- sid += 2;
- }
-
-
- *grad_sampling_loc = _grad_w;
- *(grad_sampling_loc + 1) = _grad_h;
- *grad_attn_weight = _grad_a;
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t, unsigned int blockSize>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
- __shared__ scalar_t cache_grad_attn_weight[blockSize];
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
-
- for (unsigned int s=blockSize/2; s>0; s>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- }
- __syncthreads();
- }
-
- if (tid == 0)
- {
- *grad_sampling_loc = cache_grad_sampling_loc[0];
- *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
- *grad_attn_weight = cache_grad_attn_weight[0];
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
- if (tid == 0)
- {
- scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
- int sid=2;
- for (unsigned int tid = 1; tid < blockDim.x; ++tid)
- {
- _grad_w += cache_grad_sampling_loc[sid];
- _grad_h += cache_grad_sampling_loc[sid + 1];
- _grad_a += cache_grad_attn_weight[tid];
- sid += 2;
- }
-
-
- *grad_sampling_loc = _grad_w;
- *(grad_sampling_loc + 1) = _grad_h;
- *grad_attn_weight = _grad_a;
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
-
- for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- if (tid + (s << 1) < spre)
- {
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
- }
- }
- __syncthreads();
- }
-
- if (tid == 0)
- {
- *grad_sampling_loc = cache_grad_sampling_loc[0];
- *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
- *grad_attn_weight = cache_grad_attn_weight[0];
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
-
- __syncthreads();
-
- for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- if (tid + (s << 1) < spre)
- {
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
- }
- }
- __syncthreads();
- }
-
- if (tid == 0)
- {
- atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
- atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
- atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
- }
- __syncthreads();
-
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t>
-__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
-{
- CUDA_KERNEL_LOOP(index, n)
- {
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
-
- const scalar_t top_grad = grad_col[index];
-
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
-
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear_gm(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- grad_sampling_loc, grad_attn_weight);
- }
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
-}
-
-
-template <typename scalar_t>
-void ms_deformable_im2col_cuda(cudaStream_t stream,
- const scalar_t* data_value,
- const int64_t* data_spatial_shapes,
- const int64_t* data_level_start_index,
- const scalar_t* data_sampling_loc,
- const scalar_t* data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t* data_col)
-{
- const int num_kernels = batch_size * num_query * num_heads * channels;
- const int num_actual_kernels = batch_size * num_query * num_heads * channels;
- const int num_threads = CUDA_NUM_THREADS;
- ms_deformable_im2col_gpu_kernel<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
- batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
-
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess)
- {
- printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
- }
-
-}
-
-template <typename scalar_t>
-void ms_deformable_col2im_cuda(cudaStream_t stream,
- const scalar_t* grad_col,
- const scalar_t* data_value,
- const int64_t * data_spatial_shapes,
- const int64_t * data_level_start_index,
- const scalar_t * data_sampling_loc,
- const scalar_t * data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t* grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
-{
- const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
- const int num_kernels = batch_size * num_query * num_heads * channels;
- const int num_actual_kernels = batch_size * num_query * num_heads * channels;
- if (channels > 1024)
- {
- if ((channels & 1023) == 0)
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- else
- {
- ms_deformable_col2im_gpu_kernel_gm<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- }
- else{
- switch(channels)
- {
- case 1:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 2:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 4:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 8:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 16:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 32:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 64:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 128:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 256:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 512:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 1024:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- default:
- if (channels < 64)
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- else
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- }
- }
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess)
- {
- printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
- }
-
-}
diff --git a/src/transformers/kernels/deta/ms_deform_attn.h b/src/transformers/kernels/deta/ms_deform_attn.h
deleted file mode 100644
index 119b1fa317d1..000000000000
--- a/src/transformers/kernels/deta/ms_deform_attn.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#pragma once
-
-#include "cpu/ms_deform_attn_cpu.h"
-
-#ifdef WITH_CUDA
-#include "cuda/ms_deform_attn_cuda.h"
-#endif
-
-
-at::Tensor
-ms_deform_attn_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step)
-{
- if (value.type().is_cuda())
- {
-#ifdef WITH_CUDA
- return ms_deform_attn_cuda_forward(
- value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
-#else
- AT_ERROR("Not compiled with GPU support");
-#endif
- }
- AT_ERROR("Not implemented on the CPU");
-}
-
-std::vector<at::Tensor>
-ms_deform_attn_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step)
-{
- if (value.type().is_cuda())
- {
-#ifdef WITH_CUDA
- return ms_deform_attn_cuda_backward(
- value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
-#else
- AT_ERROR("Not compiled with GPU support");
-#endif
- }
- AT_ERROR("Not implemented on the CPU");
-}
diff --git a/src/transformers/kernels/deta/vision.cpp b/src/transformers/kernels/deta/vision.cpp
deleted file mode 100644
index 6ce3875568b9..000000000000
--- a/src/transformers/kernels/deta/vision.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-/*!
-**************************************************************************************************
-* Deformable DETR
-* Copyright (c) 2020 SenseTime. All Rights Reserved.
-* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
-**************************************************************************************************
-* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
-**************************************************************************************************
-*/
-
-#include "ms_deform_attn.h"
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
- m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
-}
\ No newline at end of file
diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py
index 1e233608e3e4..c2a9d0816e07 100644
--- a/src/transformers/models/deprecated/deta/modeling_deta.py
+++ b/src/transformers/models/deprecated/deta/modeling_deta.py
@@ -16,17 +16,13 @@
import copy
import math
-import os
import warnings
from dataclasses import dataclass
-from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
from ....activations import ACT2FN
from ....file_utils import (
@@ -34,101 +30,75 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
- is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
)
+from ....integrations.hub_kernels import use_kernel_forward_from_hub
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import meshgrid
-from ....utils import is_accelerate_available, is_ninja_available, is_torchvision_available, logging, requires_backends
+from ....utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
from ....utils.backbone_utils import load_backbone
from .configuration_deta import DetaConfig
logger = logging.get_logger(__name__)
-MultiScaleDeformableAttention = None
-
-def load_cuda_kernels():
- from torch.utils.cpp_extension import load
-
- global MultiScaleDeformableAttention
-
- root = Path(__file__).resolve().parent.parent.parent.parent / "kernels" / "deta"
- src_files = [
- root / filename
- for filename in [
- "vision.cpp",
- os.path.join("cpu", "ms_deform_attn_cpu.cpp"),
- os.path.join("cuda", "ms_deform_attn_cuda.cu"),
- ]
- ]
-
- MultiScaleDeformableAttention = load(
- "MultiScaleDeformableAttention",
- src_files,
- with_cuda=True,
- extra_include_paths=[str(root)],
- extra_cflags=["-DWITH_CUDA=1"],
- extra_cuda_cflags=[
- "-DCUDA_HAS_FP16=1",
- "-D__CUDA_NO_HALF_OPERATORS__",
- "-D__CUDA_NO_HALF_CONVERSIONS__",
- "-D__CUDA_NO_HALF2_OPERATORS__",
- ],
- )
-
-
-class MultiScaleDeformableAttentionFunction(Function):
- @staticmethod
+@use_kernel_forward_from_hub("MultiScaleDeformableAttention")
+class MultiScaleDeformableAttention(nn.Module):
def forward(
- context,
- value,
- value_spatial_shapes,
- value_level_start_index,
- sampling_locations,
- attention_weights,
- im2col_step,
+ self,
+ value: Tensor,
+ value_spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ sampling_locations: Tensor,
+ attention_weights: Tensor,
+ im2col_step: int,
):
- context.im2col_step = im2col_step
- output = MultiScaleDeformableAttention.ms_deform_attn_forward(
- value,
- value_spatial_shapes,
- value_level_start_index,
- sampling_locations,
- attention_weights,
- context.im2col_step,
- )
- context.save_for_backward(
- value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights
+ batch_size, _, num_heads, hidden_dim = value.shape
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+ value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for level_id, (height, width) in enumerate(value_spatial_shapes):
+ # batch_size, height*width, num_heads, hidden_dim
+ # -> batch_size, height*width, num_heads*hidden_dim
+ # -> batch_size, num_heads*hidden_dim, height*width
+ # -> batch_size*num_heads, hidden_dim, height, width
+ value_l_ = (
+ value_list[level_id]
+ .flatten(2)
+ .transpose(1, 2)
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
+ )
+ # batch_size, num_queries, num_heads, num_points, 2
+ # -> batch_size, num_heads, num_queries, num_points, 2
+ # -> batch_size*num_heads, num_queries, num_points, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
+ sampling_value_l_ = nn.functional.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
)
- return output
-
- @staticmethod
- @once_differentiable
- def backward(context, grad_output):
- (
- value,
- value_spatial_shapes,
- value_level_start_index,
- sampling_locations,
- attention_weights,
- ) = context.saved_tensors
- grad_value, grad_sampling_loc, grad_attn_weight = MultiScaleDeformableAttention.ms_deform_attn_backward(
- value,
- value_spatial_shapes,
- value_level_start_index,
- sampling_locations,
- attention_weights,
- grad_output,
- context.im2col_step,
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(batch_size, num_heads * hidden_dim, num_queries)
)
-
- return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+ return output.transpose(1, 2).contiguous()
if is_accelerate_available():
@@ -571,12 +541,7 @@ class DetaMultiscaleDeformableAttention(nn.Module):
def __init__(self, config: DetaConfig, num_heads: int, n_points: int):
super().__init__()
- kernel_loaded = MultiScaleDeformableAttention is not None
- if is_torch_cuda_available() and is_ninja_available() and not kernel_loaded:
- try:
- load_cuda_kernels()
- except Exception as e:
- logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
+ self.attn = MultiScaleDeformableAttention()
if config.d_model % num_heads != 0:
raise ValueError(
@@ -684,23 +649,14 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
- if self.disable_custom_kernels:
- # PyTorch implementation
- output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
- else:
- try:
- # custom kernel
- output = MultiScaleDeformableAttentionFunction.apply(
- value,
- spatial_shapes,
- level_start_index,
- sampling_locations,
- attention_weights,
- self.im2col_step,
- )
- except Exception:
- # PyTorch implementation
- output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+ output = self.attn(
+ value,
+ spatial_shapes,
+ level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
output = self.output_proj(output)
return output, attention_weights