Spaces:
Build error
Build error
| //#include <ATen/cuda/CUDAContext.h> | |
| //#include <THC/THCAtomics.cuh> | |
| //#include <THC/THCDeviceUtils.cuh> | |
| //extern THCState *state; | |
| // author: Charles Shang | |
| // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu | |
| // modified from the CUDA version for CPU use by Daniel K. Suhendro | |
| // edit by: James Bockman and Matthew Howe | |
| // modified for torch implementation to remove use of deprecated torch access to Blas | |
| at::Tensor | |
| dcn_v2_cpu_forward(const at::Tensor &input, | |
| const at::Tensor &weight, | |
| const at::Tensor &bias, | |
| const at::Tensor &offset, | |
| const at::Tensor &mask, | |
| const int kernel_h, | |
| const int kernel_w, | |
| const int stride_h, | |
| const int stride_w, | |
| const int pad_h, | |
| const int pad_w, | |
| const int dilation_h, | |
| const int dilation_w, | |
| const int deformable_group) | |
| { | |
| // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); | |
| /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); | |
| AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor"); | |
| AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor"); | |
| AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor"); | |
| AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/ | |
| const int batch = input.size(0); | |
| const int channels = input.size(1); | |
| const int height = input.size(2); | |
| const int width = input.size(3); | |
| const int channels_out = weight.size(0); | |
| const int channels_kernel = weight.size(1); | |
| const int kernel_h_ = weight.size(2); | |
| const int kernel_w_ = weight.size(3); | |
| // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h); | |
| // printf("Channels: %d %d\n", channels, channels_kernel); | |
| // printf("Channels: %d %d\n", channels_out, channels_kernel); | |
| AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, | |
| "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); | |
| AT_ASSERTM(channels == channels_kernel, | |
| "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); | |
| const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; | |
| const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; | |
| // auto ones = at::ones({height_out, width_out}, input.options()); | |
| auto ones = at::ones({bias.sizes()[0], height_out, width_out}, input.options()); | |
| auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); | |
| auto output = at::zeros({batch, channels_out, height_out, width_out}, input.options()); | |
| using scalar_t = float; | |
| for (int b = 0; b < batch; b++) | |
| { | |
| auto input_n = input.select(0, b); | |
| auto offset_n = offset.select(0, b); | |
| auto mask_n = mask.select(0, b); | |
| auto output_n = output.select(0, b); | |
| // std::cout << "output_n: " << output_n << "output.select(0,b): " << output.select(0,b) << "\n"; | |
| // Do Bias first: | |
| // M,N,K are dims of matrix A and B | |
| // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) | |
| // (N x 1) (1 x M) | |
| // torch implementation | |
| auto ones_T = at::transpose(ones.contiguous(), 2, 0); | |
| ones_T = at::mul(ones_T, bias.contiguous()); | |
| ones_T = at::transpose(ones_T, 2, 0); | |
| output_n = at::add(output_n, ones_T); | |
| modulated_deformable_im2col_cpu(input_n.data_ptr<scalar_t>(), | |
| offset_n.data_ptr<scalar_t>(), | |
| mask_n.data_ptr<scalar_t>(), | |
| 1, channels, height, width, | |
| height_out, width_out, kernel_h, kernel_w, | |
| pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, | |
| deformable_group, | |
| columns.data_ptr<scalar_t>()); | |
| //(k * m) x (m * n) | |
| // Y = WC | |
| // torch implementation | |
| auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w}); | |
| auto product = at::matmul(weight_flat, columns); | |
| output.select(0, b) = at::add(output_n, product.view({channels_out, height_out, width_out})); | |
| } | |
| return output; | |
| } | |
| std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input, | |
| const at::Tensor &weight, | |
| const at::Tensor &bias, | |
| const at::Tensor &offset, | |
| const at::Tensor &mask, | |
| const at::Tensor &grad_output, | |
| int kernel_h, int kernel_w, | |
| int stride_h, int stride_w, | |
| int pad_h, int pad_w, | |
| int dilation_h, int dilation_w, | |
| int deformable_group) | |
| { | |
| THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous"); | |
| THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous"); | |
| /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor"); | |
| AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor"); | |
| AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor"); | |
| AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor"); | |
| AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/ | |
| const int batch = input.size(0); | |
| const int channels = input.size(1); | |
| const int height = input.size(2); | |
| const int width = input.size(3); | |
| const int channels_out = weight.size(0); | |
| const int channels_kernel = weight.size(1); | |
| const int kernel_h_ = weight.size(2); | |
| const int kernel_w_ = weight.size(3); | |
| AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w, | |
| "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); | |
| AT_ASSERTM(channels == channels_kernel, | |
| "Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel); | |
| const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; | |
| const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; | |
| auto ones = at::ones({height_out, width_out}, input.options()); | |
| auto columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); | |
| auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); | |
| auto grad_input = at::zeros_like(input); | |
| auto grad_weight = at::zeros_like(weight); | |
| auto grad_bias = at::zeros_like(bias); | |
| auto grad_offset = at::zeros_like(offset); | |
| auto grad_mask = at::zeros_like(mask); | |
| using scalar_t = float; | |
| for (int b = 0; b < batch; b++) | |
| { | |
| auto input_n = input.select(0, b); | |
| auto offset_n = offset.select(0, b); | |
| auto mask_n = mask.select(0, b); | |
| auto grad_output_n = grad_output.select(0, b); | |
| auto grad_input_n = grad_input.select(0, b); | |
| auto grad_offset_n = grad_offset.select(0, b); | |
| auto grad_mask_n = grad_mask.select(0, b); | |
| // Torch implementation | |
| auto weight_flat = weight.view({channels_out, channels*kernel_h*kernel_w}); | |
| weight_flat = at::transpose(weight_flat, 1, 0); | |
| auto grad_output_n_flat = grad_output_n.view({channels_out, height_out*width_out}); | |
| columns = at::matmul(weight_flat, grad_output_n_flat); | |
| // gradient w.r.t. input coordinate data | |
| modulated_deformable_col2im_coord_cpu(columns.data_ptr<scalar_t>(), | |
| input_n.data_ptr<scalar_t>(), | |
| offset_n.data_ptr<scalar_t>(), | |
| mask_n.data_ptr<scalar_t>(), | |
| 1, channels, height, width, | |
| height_out, width_out, kernel_h, kernel_w, | |
| pad_h, pad_w, stride_h, stride_w, | |
| dilation_h, dilation_w, deformable_group, | |
| grad_offset_n.data_ptr<scalar_t>(), | |
| grad_mask_n.data_ptr<scalar_t>()); | |
| // gradient w.r.t. input data | |
| modulated_deformable_col2im_cpu(columns.data_ptr<scalar_t>(), | |
| offset_n.data_ptr<scalar_t>(), | |
| mask_n.data_ptr<scalar_t>(), | |
| 1, channels, height, width, | |
| height_out, width_out, kernel_h, kernel_w, | |
| pad_h, pad_w, stride_h, stride_w, | |
| dilation_h, dilation_w, deformable_group, | |
| grad_input_n.data_ptr<scalar_t>()); | |
| // gradient w.r.t. weight, dWeight should accumulate across the batch and group | |
| modulated_deformable_im2col_cpu(input_n.data_ptr<scalar_t>(), | |
| offset_n.data_ptr<scalar_t>(), | |
| mask_n.data_ptr<scalar_t>(), | |
| 1, channels, height, width, | |
| height_out, width_out, kernel_h, kernel_w, | |
| pad_h, pad_w, stride_h, stride_w, | |
| dilation_h, dilation_w, deformable_group, | |
| columns.data_ptr<scalar_t>()); | |
| // Torch implementation | |
| auto product = at::matmul(grad_output_n_flat, at::transpose(columns, 1, 0)); | |
| grad_weight = at::add(grad_weight, product.view({channels_out, channels, kernel_h, kernel_w})); | |
| // Torch implementation | |
| auto ones_flat = ones.view({height_out*width_out}); | |
| product = at::matmul(grad_output_n_flat, ones_flat); | |
| grad_bias = at::add(grad_bias, product); | |
| } | |
| return { | |
| grad_input, grad_offset, grad_mask, grad_weight, grad_bias | |
| }; | |
| } |