| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | #include <torch/types.h> |
| |
|
| | #include "deform_conv.h" |
| |
|
| | #include <cmath> |
| | #include <vector> |
| |
|
| | namespace detectron2 { |
| |
|
| | void deformable_im2col( |
| | const at::Tensor data_im, |
| | const at::Tensor data_offset, |
| | const int channels, |
| | const int height, |
| | const int width, |
| | const int ksize_h, |
| | const int ksize_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int parallel_imgs, |
| | const int deformable_group, |
| | at::Tensor data_col); |
| |
|
| | void deformable_col2im( |
| | const at::Tensor data_col, |
| | const at::Tensor data_offset, |
| | const int channels, |
| | const int height, |
| | const int width, |
| | const int ksize_h, |
| | const int ksize_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int parallel_imgs, |
| | const int deformable_group, |
| | at::Tensor grad_im); |
| |
|
| | void deformable_col2im_coord( |
| | const at::Tensor data_col, |
| | const at::Tensor data_im, |
| | const at::Tensor data_offset, |
| | const int channels, |
| | const int height, |
| | const int width, |
| | const int ksize_h, |
| | const int ksize_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int parallel_imgs, |
| | const int deformable_group, |
| | at::Tensor grad_offset); |
| |
|
| | void modulated_deformable_im2col_cuda( |
| | const at::Tensor data_im, |
| | const at::Tensor data_offset, |
| | const at::Tensor data_mask, |
| | const int batch_size, |
| | const int channels, |
| | const int height_im, |
| | const int width_im, |
| | const int height_col, |
| | const int width_col, |
| | const int kernel_h, |
| | const int kenerl_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int deformable_group, |
| | at::Tensor data_col); |
| |
|
| | void modulated_deformable_col2im_cuda( |
| | const at::Tensor data_col, |
| | const at::Tensor data_offset, |
| | const at::Tensor data_mask, |
| | const int batch_size, |
| | const int channels, |
| | const int height_im, |
| | const int width_im, |
| | const int height_col, |
| | const int width_col, |
| | const int kernel_h, |
| | const int kenerl_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int deformable_group, |
| | at::Tensor grad_im); |
| |
|
| | void modulated_deformable_col2im_coord_cuda( |
| | const at::Tensor data_col, |
| | const at::Tensor data_im, |
| | const at::Tensor data_offset, |
| | const at::Tensor data_mask, |
| | const int batch_size, |
| | const int channels, |
| | const int height_im, |
| | const int width_im, |
| | const int height_col, |
| | const int width_col, |
| | const int kernel_h, |
| | const int kenerl_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int deformable_group, |
| | at::Tensor grad_offset, |
| | at::Tensor grad_mask); |
| |
|
| | void shape_check( |
| | at::Tensor input, |
| | at::Tensor offset, |
| | at::Tensor* gradOutput, |
| | at::Tensor weight, |
| | int kH, |
| | int kW, |
| | int dH, |
| | int dW, |
| | int padH, |
| | int padW, |
| | int dilationH, |
| | int dilationW, |
| | int group, |
| | int deformable_group) { |
| | TORCH_CHECK( |
| | weight.ndimension() == 4, |
| | "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " |
| | "but got: %s", |
| | weight.ndimension()); |
| |
|
| | TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); |
| |
|
| | TORCH_CHECK( |
| | kW > 0 && kH > 0, |
| | "kernel size should be greater than zero, but got kH: %d kW: %d", |
| | kH, |
| | kW); |
| |
|
| | TORCH_CHECK( |
| | (weight.size(2) == kH && weight.size(3) == kW), |
| | "kernel size should be consistent with weight, ", |
| | "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", |
| | kH, |
| | kW, |
| | weight.size(2), |
| | weight.size(3)); |
| |
|
| | TORCH_CHECK( |
| | dW > 0 && dH > 0, |
| | "stride should be greater than zero, but got dH: %d dW: %d", |
| | dH, |
| | dW); |
| |
|
| | TORCH_CHECK( |
| | dilationW > 0 && dilationH > 0, |
| | "dilation should be greater than 0, but got dilationH: %d dilationW: %d", |
| | dilationH, |
| | dilationW); |
| |
|
| | int ndim = input.ndimension(); |
| | int dimf = 0; |
| | int dimh = 1; |
| | int dimw = 2; |
| |
|
| | if (ndim == 4) { |
| | dimf++; |
| | dimh++; |
| | dimw++; |
| | } |
| |
|
| | TORCH_CHECK( |
| | ndim == 3 || ndim == 4, |
| | "3D or 4D input tensor expected but got: %s", |
| | ndim); |
| |
|
| | long nInputPlane = weight.size(1) * group; |
| | long inputHeight = input.size(dimh); |
| | long inputWidth = input.size(dimw); |
| | long nOutputPlane = weight.size(0); |
| | long outputHeight = |
| | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; |
| | long outputWidth = |
| | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; |
| |
|
| | TORCH_CHECK( |
| | nInputPlane % deformable_group == 0, |
| | "input channels must divide deformable group size"); |
| |
|
| | if (outputWidth < 1 || outputHeight < 1) |
| | AT_ERROR( |
| | "Given input size: (%ld x %ld x %ld). " |
| | "Calculated output size: (%ld x %ld x %ld). Output size is too small", |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth, |
| | nOutputPlane, |
| | outputHeight, |
| | outputWidth); |
| |
|
| | TORCH_CHECK( |
| | input.size(1) == nInputPlane, |
| | "invalid number of input planes, expected: %d, but got: %d", |
| | nInputPlane, |
| | input.size(1)); |
| |
|
| | TORCH_CHECK( |
| | (inputHeight + 2 * padH >= kH && inputWidth + 2 * padW >= kW), |
| | "input image is smaller than kernel"); |
| |
|
| | TORCH_CHECK( |
| | (offset.size(2) == outputHeight && offset.size(3) == outputWidth), |
| | "invalid spatial size of offset, expected height: %d width: %d, but " |
| | "got height: %d width: %d", |
| | outputHeight, |
| | outputWidth, |
| | offset.size(2), |
| | offset.size(3)); |
| |
|
| | TORCH_CHECK( |
| | (offset.size(1) == deformable_group * 2 * kH * kW), |
| | "invalid number of channels of offset"); |
| |
|
| | if (gradOutput != NULL) { |
| | TORCH_CHECK( |
| | gradOutput->size(dimf) == nOutputPlane, |
| | "invalid number of gradOutput planes, expected: %d, but got: %d", |
| | nOutputPlane, |
| | gradOutput->size(dimf)); |
| |
|
| | TORCH_CHECK( |
| | (gradOutput->size(dimh) == outputHeight && |
| | gradOutput->size(dimw) == outputWidth), |
| | "invalid size of gradOutput, expected height: %d width: %d , but " |
| | "got height: %d width: %d", |
| | outputHeight, |
| | outputWidth, |
| | gradOutput->size(dimh), |
| | gradOutput->size(dimw)); |
| | } |
| | } |
| |
|
| | int deform_conv_forward_cuda( |
| | at::Tensor input, |
| | at::Tensor weight, |
| | at::Tensor offset, |
| | at::Tensor output, |
| | at::Tensor columns, |
| | at::Tensor ones, |
| | int kW, |
| | int kH, |
| | int dW, |
| | int dH, |
| | int padW, |
| | int padH, |
| | int dilationW, |
| | int dilationH, |
| | int group, |
| | int deformable_group, |
| | int im2col_step) { |
| | |
| | |
| | |
| | |
| | |
| |
|
| | shape_check( |
| | input, |
| | offset, |
| | NULL, |
| | weight, |
| | kH, |
| | kW, |
| | dH, |
| | dW, |
| | padH, |
| | padW, |
| | dilationH, |
| | dilationW, |
| | group, |
| | deformable_group); |
| |
|
| | input = input.contiguous(); |
| | offset = offset.contiguous(); |
| | weight = weight.contiguous(); |
| |
|
| | int batch = 1; |
| | if (input.ndimension() == 3) { |
| | |
| | batch = 0; |
| | input.unsqueeze_(0); |
| | offset.unsqueeze_(0); |
| | } |
| |
|
| | |
| |
|
| | long batchSize = input.size(0); |
| | long nInputPlane = input.size(1); |
| | long inputHeight = input.size(2); |
| | long inputWidth = input.size(3); |
| |
|
| | long nOutputPlane = weight.size(0); |
| |
|
| | long outputWidth = |
| | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; |
| | long outputHeight = |
| | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; |
| |
|
| | TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); |
| |
|
| | output = output.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nOutputPlane, |
| | outputHeight, |
| | outputWidth}); |
| | columns = at::zeros( |
| | {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, |
| | input.options()); |
| |
|
| | if (ones.ndimension() != 2 || |
| | ones.size(0) * ones.size(1) < outputHeight * outputWidth) { |
| | ones = at::ones({outputHeight, outputWidth}, input.options()); |
| | } |
| |
|
| | input = input.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth}); |
| | offset = offset.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | deformable_group * 2 * kH * kW, |
| | outputHeight, |
| | outputWidth}); |
| |
|
| | at::Tensor output_buffer = at::zeros( |
| | {batchSize / im2col_step, |
| | nOutputPlane, |
| | im2col_step * outputHeight, |
| | outputWidth}, |
| | output.options()); |
| |
|
| | output_buffer = output_buffer.view( |
| | {output_buffer.size(0), |
| | group, |
| | output_buffer.size(1) / group, |
| | output_buffer.size(2), |
| | output_buffer.size(3)}); |
| |
|
| | for (int elt = 0; elt < batchSize / im2col_step; elt++) { |
| | deformable_im2col( |
| | input[elt], |
| | offset[elt], |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth, |
| | kH, |
| | kW, |
| | padH, |
| | padW, |
| | dH, |
| | dW, |
| | dilationH, |
| | dilationW, |
| | im2col_step, |
| | deformable_group, |
| | columns); |
| |
|
| | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); |
| | weight = weight.view( |
| | {group, |
| | weight.size(0) / group, |
| | weight.size(1), |
| | weight.size(2), |
| | weight.size(3)}); |
| |
|
| | for (int g = 0; g < group; g++) { |
| | output_buffer[elt][g] = output_buffer[elt][g] |
| | .flatten(1) |
| | .addmm_(weight[g].flatten(1), columns[g]) |
| | .view_as(output_buffer[elt][g]); |
| | } |
| | } |
| |
|
| | output_buffer = output_buffer.view( |
| | {output_buffer.size(0), |
| | output_buffer.size(1) * output_buffer.size(2), |
| | output_buffer.size(3), |
| | output_buffer.size(4)}); |
| |
|
| | output_buffer = output_buffer.view( |
| | {batchSize / im2col_step, |
| | nOutputPlane, |
| | im2col_step, |
| | outputHeight, |
| | outputWidth}); |
| | output_buffer.transpose_(1, 2); |
| | output.copy_(output_buffer); |
| | output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); |
| |
|
| | input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); |
| | offset = offset.view( |
| | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); |
| |
|
| | if (batch == 0) { |
| | output = output.view({nOutputPlane, outputHeight, outputWidth}); |
| | input = input.view({nInputPlane, inputHeight, inputWidth}); |
| | offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); |
| | } |
| |
|
| | return 1; |
| | } |
| |
|
| | int deform_conv_backward_input_cuda( |
| | at::Tensor input, |
| | at::Tensor offset, |
| | at::Tensor gradOutput, |
| | at::Tensor gradInput, |
| | at::Tensor gradOffset, |
| | at::Tensor weight, |
| | at::Tensor columns, |
| | int kW, |
| | int kH, |
| | int dW, |
| | int dH, |
| | int padW, |
| | int padH, |
| | int dilationW, |
| | int dilationH, |
| | int group, |
| | int deformable_group, |
| | int im2col_step) { |
| | shape_check( |
| | input, |
| | offset, |
| | &gradOutput, |
| | weight, |
| | kH, |
| | kW, |
| | dH, |
| | dW, |
| | padH, |
| | padW, |
| | dilationH, |
| | dilationW, |
| | group, |
| | deformable_group); |
| |
|
| | input = input.contiguous(); |
| | offset = offset.contiguous(); |
| | gradOutput = gradOutput.contiguous(); |
| | weight = weight.contiguous(); |
| |
|
| | int batch = 1; |
| |
|
| | if (input.ndimension() == 3) { |
| | |
| | batch = 0; |
| | input = input.view({1, input.size(0), input.size(1), input.size(2)}); |
| | offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); |
| | gradOutput = gradOutput.view( |
| | {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); |
| | } |
| |
|
| | long batchSize = input.size(0); |
| | long nInputPlane = input.size(1); |
| | long inputHeight = input.size(2); |
| | long inputWidth = input.size(3); |
| |
|
| | long nOutputPlane = weight.size(0); |
| |
|
| | long outputWidth = |
| | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; |
| | long outputHeight = |
| | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; |
| |
|
| | TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); |
| | gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); |
| | columns = at::zeros( |
| | {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, |
| | input.options()); |
| |
|
| | |
| | gradOutput = gradOutput.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nOutputPlane, |
| | outputHeight, |
| | outputWidth}); |
| | gradOutput.transpose_(1, 2); |
| |
|
| | gradInput = gradInput.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth}); |
| | input = input.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth}); |
| | gradOffset = gradOffset.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | deformable_group * 2 * kH * kW, |
| | outputHeight, |
| | outputWidth}); |
| | offset = offset.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | deformable_group * 2 * kH * kW, |
| | outputHeight, |
| | outputWidth}); |
| |
|
| | for (int elt = 0; elt < batchSize / im2col_step; elt++) { |
| | |
| | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); |
| | weight = weight.view( |
| | {group, |
| | weight.size(0) / group, |
| | weight.size(1), |
| | weight.size(2), |
| | weight.size(3)}); |
| | gradOutput = gradOutput.view( |
| | {gradOutput.size(0), |
| | group, |
| | gradOutput.size(1) / group, |
| | gradOutput.size(2), |
| | gradOutput.size(3), |
| | gradOutput.size(4)}); |
| |
|
| | for (int g = 0; g < group; g++) { |
| | columns[g] = columns[g].addmm_( |
| | weight[g].flatten(1).transpose(0, 1), |
| | gradOutput[elt][g].flatten(1), |
| | 0.0f, |
| | 1.0f); |
| | } |
| |
|
| | columns = |
| | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); |
| | gradOutput = gradOutput.view( |
| | {gradOutput.size(0), |
| | gradOutput.size(1) * gradOutput.size(2), |
| | gradOutput.size(3), |
| | gradOutput.size(4), |
| | gradOutput.size(5)}); |
| |
|
| | deformable_col2im_coord( |
| | columns, |
| | input[elt], |
| | offset[elt], |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth, |
| | kH, |
| | kW, |
| | padH, |
| | padW, |
| | dH, |
| | dW, |
| | dilationH, |
| | dilationW, |
| | im2col_step, |
| | deformable_group, |
| | gradOffset[elt]); |
| |
|
| | deformable_col2im( |
| | columns, |
| | offset[elt], |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth, |
| | kH, |
| | kW, |
| | padH, |
| | padW, |
| | dH, |
| | dW, |
| | dilationH, |
| | dilationW, |
| | im2col_step, |
| | deformable_group, |
| | gradInput[elt]); |
| | } |
| |
|
| | gradOutput.transpose_(1, 2); |
| | gradOutput = |
| | gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); |
| |
|
| | gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); |
| | input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); |
| | gradOffset = gradOffset.view( |
| | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); |
| | offset = offset.view( |
| | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); |
| |
|
| | if (batch == 0) { |
| | gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); |
| | input = input.view({nInputPlane, inputHeight, inputWidth}); |
| | gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); |
| | offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); |
| | gradOffset = |
| | gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); |
| | } |
| |
|
| | return 1; |
| | } |
| |
|
| | int deform_conv_backward_parameters_cuda( |
| | at::Tensor input, |
| | at::Tensor offset, |
| | at::Tensor gradOutput, |
| | at::Tensor gradWeight, |
| | at::Tensor columns, |
| | at::Tensor ones, |
| | int kW, |
| | int kH, |
| | int dW, |
| | int dH, |
| | int padW, |
| | int padH, |
| | int dilationW, |
| | int dilationH, |
| | int group, |
| | int deformable_group, |
| | float scale, |
| | int im2col_step) { |
| | |
| | |
| | |
| |
|
| | shape_check( |
| | input, |
| | offset, |
| | &gradOutput, |
| | gradWeight, |
| | kH, |
| | kW, |
| | dH, |
| | dW, |
| | padH, |
| | padW, |
| | dilationH, |
| | dilationW, |
| | group, |
| | deformable_group); |
| |
|
| | input = input.contiguous(); |
| | offset = offset.contiguous(); |
| | gradOutput = gradOutput.contiguous(); |
| |
|
| | int batch = 1; |
| |
|
| | if (input.ndimension() == 3) { |
| | |
| | batch = 0; |
| | input = input.view( |
| | at::IntList({1, input.size(0), input.size(1), input.size(2)})); |
| | gradOutput = gradOutput.view( |
| | {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); |
| | } |
| |
|
| | long batchSize = input.size(0); |
| | long nInputPlane = input.size(1); |
| | long inputHeight = input.size(2); |
| | long inputWidth = input.size(3); |
| |
|
| | long nOutputPlane = gradWeight.size(0); |
| |
|
| | long outputWidth = |
| | (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; |
| | long outputHeight = |
| | (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; |
| |
|
| | TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); |
| |
|
| | columns = at::zeros( |
| | {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, |
| | input.options()); |
| |
|
| | gradOutput = gradOutput.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nOutputPlane, |
| | outputHeight, |
| | outputWidth}); |
| | gradOutput.transpose_(1, 2); |
| |
|
| | at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); |
| | gradOutputBuffer = gradOutputBuffer.view( |
| | {batchSize / im2col_step, |
| | nOutputPlane, |
| | im2col_step, |
| | outputHeight, |
| | outputWidth}); |
| | gradOutputBuffer.copy_(gradOutput); |
| | |
| | gradOutputBuffer = gradOutputBuffer.reshape( |
| | {batchSize / im2col_step, |
| | nOutputPlane, |
| | im2col_step * outputHeight, |
| | outputWidth}); |
| |
|
| | gradOutput.transpose_(1, 2); |
| | gradOutput = |
| | gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); |
| |
|
| | input = input.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth}); |
| | offset = offset.view( |
| | {batchSize / im2col_step, |
| | im2col_step, |
| | deformable_group * 2 * kH * kW, |
| | outputHeight, |
| | outputWidth}); |
| |
|
| | for (int elt = 0; elt < batchSize / im2col_step; elt++) { |
| | deformable_im2col( |
| | input[elt], |
| | offset[elt], |
| | nInputPlane, |
| | inputHeight, |
| | inputWidth, |
| | kH, |
| | kW, |
| | padH, |
| | padW, |
| | dH, |
| | dW, |
| | dilationH, |
| | dilationW, |
| | im2col_step, |
| | deformable_group, |
| | columns); |
| |
|
| | |
| | gradOutputBuffer = gradOutputBuffer.view( |
| | {gradOutputBuffer.size(0), |
| | group, |
| | gradOutputBuffer.size(1) / group, |
| | gradOutputBuffer.size(2), |
| | gradOutputBuffer.size(3)}); |
| | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); |
| | gradWeight = gradWeight.view( |
| | {group, |
| | gradWeight.size(0) / group, |
| | gradWeight.size(1), |
| | gradWeight.size(2), |
| | gradWeight.size(3)}); |
| |
|
| | for (int g = 0; g < group; g++) { |
| | gradWeight[g] = gradWeight[g] |
| | .flatten(1) |
| | .addmm_( |
| | gradOutputBuffer[elt][g].flatten(1), |
| | columns[g].transpose(1, 0), |
| | 1.0, |
| | scale) |
| | .view_as(gradWeight[g]); |
| | } |
| | gradOutputBuffer = gradOutputBuffer.view( |
| | {gradOutputBuffer.size(0), |
| | gradOutputBuffer.size(1) * gradOutputBuffer.size(2), |
| | gradOutputBuffer.size(3), |
| | gradOutputBuffer.size(4)}); |
| | columns = |
| | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); |
| | gradWeight = gradWeight.view( |
| | {gradWeight.size(0) * gradWeight.size(1), |
| | gradWeight.size(2), |
| | gradWeight.size(3), |
| | gradWeight.size(4)}); |
| | } |
| |
|
| | input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); |
| | offset = offset.view( |
| | {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); |
| |
|
| | if (batch == 0) { |
| | gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); |
| | input = input.view({nInputPlane, inputHeight, inputWidth}); |
| | } |
| |
|
| | return 1; |
| | } |
| |
|
| | void modulated_deform_conv_cuda_forward( |
| | at::Tensor input, |
| | at::Tensor weight, |
| | at::Tensor bias, |
| | at::Tensor ones, |
| | at::Tensor offset, |
| | at::Tensor mask, |
| | at::Tensor output, |
| | at::Tensor columns, |
| | int kernel_h, |
| | int kernel_w, |
| | const int stride_h, |
| | const int stride_w, |
| | const int pad_h, |
| | const int pad_w, |
| | const int dilation_h, |
| | const int dilation_w, |
| | const int group, |
| | const int deformable_group, |
| | const bool with_bias) { |
| | shape_check( |
| | input, |
| | offset, |
| | NULL, |
| | weight, |
| | kernel_h, |
| | kernel_w, |
| | stride_h, |
| | stride_w, |
| | pad_h, |
| | pad_w, |
| | dilation_h, |
| | dilation_w, |
| | group, |
| | deformable_group); |
| |
|
| | TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); |
| | TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); |
| |
|
| | const int batch = input.size(0); |
| | const int channels = input.size(1); |
| | const int height = input.size(2); |
| | const int width = input.size(3); |
| |
|
| | const int channels_out = weight.size(0); |
| | const int channels_kernel = weight.size(1); |
| | const int kernel_h_ = weight.size(2); |
| | const int kernel_w_ = weight.size(3); |
| |
|
| | if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) |
| | AT_ERROR( |
| | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", |
| | kernel_h_, |
| | kernel_w, |
| | kernel_h_, |
| | kernel_w_); |
| | if (channels != channels_kernel * group) |
| | AT_ERROR( |
| | "Input shape and kernel channels wont match: (%d vs %d).", |
| | channels, |
| | channels_kernel * group); |
| |
|
| | const int height_out = |
| | (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; |
| | const int width_out = |
| | (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; |
| |
|
| | |
| | TORCH_CHECK( |
| | (mask.size(2) == height_out && mask.size(3) == width_out), |
| | "invalid spatial size of mask, expected height: %d width: %d, but " |
| | "got height: %d width: %d", |
| | height_out, |
| | width_out, |
| | mask.size(2), |
| | mask.size(3)); |
| |
|
| | TORCH_CHECK( |
| | (mask.size(1) == deformable_group * kernel_h * kernel_w), |
| | "invalid number of channels of mask"); |
| |
|
| | if (ones.ndimension() != 2 || |
| | ones.size(0) * ones.size(1) < height_out * width_out) { |
| | |
| | ones = at::ones({height_out, width_out}, input.options()); |
| | } |
| |
|
| | |
| | output = output.view({batch, channels_out, height_out, width_out}).zero_(); |
| | |
| | columns = at::zeros( |
| | {channels * kernel_h * kernel_w, 1 * height_out * width_out}, |
| | input.options()); |
| |
|
| | output = output.view( |
| | {output.size(0), |
| | group, |
| | output.size(1) / group, |
| | output.size(2), |
| | output.size(3)}); |
| |
|
| | for (int b = 0; b < batch; b++) { |
| | modulated_deformable_im2col_cuda( |
| | input[b], |
| | offset[b], |
| | mask[b], |
| | 1, |
| | channels, |
| | height, |
| | width, |
| | height_out, |
| | width_out, |
| | kernel_h, |
| | kernel_w, |
| | pad_h, |
| | pad_w, |
| | stride_h, |
| | stride_w, |
| | dilation_h, |
| | dilation_w, |
| | deformable_group, |
| | columns); |
| |
|
| | |
| | weight = weight.view( |
| | {group, |
| | weight.size(0) / group, |
| | weight.size(1), |
| | weight.size(2), |
| | weight.size(3)}); |
| | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); |
| |
|
| | for (int g = 0; g < group; g++) { |
| | output[b][g] = output[b][g] |
| | .flatten(1) |
| | .addmm_(weight[g].flatten(1), columns[g]) |
| | .view_as(output[b][g]); |
| | } |
| |
|
| | weight = weight.view( |
| | {weight.size(0) * weight.size(1), |
| | weight.size(2), |
| | weight.size(3), |
| | weight.size(4)}); |
| | columns = |
| | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); |
| | } |
| |
|
| | output = output.view( |
| | {output.size(0), |
| | output.size(1) * output.size(2), |
| | output.size(3), |
| | output.size(4)}); |
| |
|
| | if (with_bias) { |
| | output += bias.view({1, bias.size(0), 1, 1}); |
| | } |
| | } |
| |
|
| | void modulated_deform_conv_cuda_backward( |
| | at::Tensor input, |
| | at::Tensor weight, |
| | at::Tensor bias, |
| | at::Tensor ones, |
| | at::Tensor offset, |
| | at::Tensor mask, |
| | at::Tensor columns, |
| | at::Tensor grad_input, |
| | at::Tensor grad_weight, |
| | at::Tensor grad_bias, |
| | at::Tensor grad_offset, |
| | at::Tensor grad_mask, |
| | at::Tensor grad_output, |
| | int kernel_h, |
| | int kernel_w, |
| | int stride_h, |
| | int stride_w, |
| | int pad_h, |
| | int pad_w, |
| | int dilation_h, |
| | int dilation_w, |
| | int group, |
| | int deformable_group, |
| | const bool with_bias) { |
| | shape_check( |
| | input, |
| | offset, |
| | &grad_output, |
| | weight, |
| | kernel_h, |
| | kernel_w, |
| | stride_h, |
| | stride_w, |
| | pad_h, |
| | pad_w, |
| | dilation_h, |
| | dilation_w, |
| | group, |
| | deformable_group); |
| |
|
| | TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); |
| | TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); |
| |
|
| | const int batch = input.size(0); |
| | const int channels = input.size(1); |
| | const int height = input.size(2); |
| | const int width = input.size(3); |
| |
|
| | const int channels_kernel = weight.size(1); |
| | const int kernel_h_ = weight.size(2); |
| | const int kernel_w_ = weight.size(3); |
| | if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) |
| | AT_ERROR( |
| | "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", |
| | kernel_h_, |
| | kernel_w, |
| | kernel_h_, |
| | kernel_w_); |
| | if (channels != channels_kernel * group) |
| | AT_ERROR( |
| | "Input shape and kernel channels wont match: (%d vs %d).", |
| | channels, |
| | channels_kernel * group); |
| |
|
| | const int height_out = |
| | (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; |
| | const int width_out = |
| | (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; |
| |
|
| | |
| | TORCH_CHECK( |
| | (mask.size(2) == height_out && mask.size(3) == width_out), |
| | "invalid spatial size of mask, expected height: %d width: %d, but " |
| | "got height: %d width: %d", |
| | height_out, |
| | width_out, |
| | mask.size(2), |
| | mask.size(3)); |
| |
|
| | TORCH_CHECK( |
| | (mask.size(1) == deformable_group * kernel_h * kernel_w), |
| | "invalid number of channels of mask"); |
| |
|
| | if (ones.ndimension() != 2 || |
| | ones.size(0) * ones.size(1) < height_out * width_out) { |
| | |
| | ones = at::ones({height_out, width_out}, input.options()); |
| | } |
| |
|
| | grad_input = grad_input.view({batch, channels, height, width}); |
| | columns = at::zeros( |
| | {channels * kernel_h * kernel_w, height_out * width_out}, |
| | input.options()); |
| |
|
| | grad_output = grad_output.view( |
| | {grad_output.size(0), |
| | group, |
| | grad_output.size(1) / group, |
| | grad_output.size(2), |
| | grad_output.size(3)}); |
| |
|
| | for (int b = 0; b < batch; b++) { |
| | |
| | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); |
| | weight = weight.view( |
| | {group, |
| | weight.size(0) / group, |
| | weight.size(1), |
| | weight.size(2), |
| | weight.size(3)}); |
| |
|
| | for (int g = 0; g < group; g++) { |
| | columns[g].addmm_( |
| | weight[g].flatten(1).transpose(0, 1), |
| | grad_output[b][g].flatten(1), |
| | 0.0f, |
| | 1.0f); |
| | } |
| |
|
| | columns = |
| | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); |
| | weight = weight.view( |
| | {weight.size(0) * weight.size(1), |
| | weight.size(2), |
| | weight.size(3), |
| | weight.size(4)}); |
| |
|
| | |
| | modulated_deformable_col2im_coord_cuda( |
| | columns, |
| | input[b], |
| | offset[b], |
| | mask[b], |
| | 1, |
| | channels, |
| | height, |
| | width, |
| | height_out, |
| | width_out, |
| | kernel_h, |
| | kernel_w, |
| | pad_h, |
| | pad_w, |
| | stride_h, |
| | stride_w, |
| | dilation_h, |
| | dilation_w, |
| | deformable_group, |
| | grad_offset[b], |
| | grad_mask[b]); |
| | |
| | modulated_deformable_col2im_cuda( |
| | columns, |
| | offset[b], |
| | mask[b], |
| | 1, |
| | channels, |
| | height, |
| | width, |
| | height_out, |
| | width_out, |
| | kernel_h, |
| | kernel_w, |
| | pad_h, |
| | pad_w, |
| | stride_h, |
| | stride_w, |
| | dilation_h, |
| | dilation_w, |
| | deformable_group, |
| | grad_input[b]); |
| |
|
| | |
| | |
| | modulated_deformable_im2col_cuda( |
| | input[b], |
| | offset[b], |
| | mask[b], |
| | 1, |
| | channels, |
| | height, |
| | width, |
| | height_out, |
| | width_out, |
| | kernel_h, |
| | kernel_w, |
| | pad_h, |
| | pad_w, |
| | stride_h, |
| | stride_w, |
| | dilation_h, |
| | dilation_w, |
| | deformable_group, |
| | columns); |
| |
|
| | columns = columns.view({group, columns.size(0) / group, columns.size(1)}); |
| | grad_weight = grad_weight.view( |
| | {group, |
| | grad_weight.size(0) / group, |
| | grad_weight.size(1), |
| | grad_weight.size(2), |
| | grad_weight.size(3)}); |
| | if (with_bias) |
| | grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); |
| |
|
| | for (int g = 0; g < group; g++) { |
| | grad_weight[g] = |
| | grad_weight[g] |
| | .flatten(1) |
| | .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) |
| | .view_as(grad_weight[g]); |
| | if (with_bias) { |
| | grad_bias[g] = |
| | grad_bias[g] |
| | .view({-1, 1}) |
| | .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) |
| | .view(-1); |
| | } |
| | } |
| |
|
| | columns = |
| | columns.view({columns.size(0) * columns.size(1), columns.size(2)}); |
| | grad_weight = grad_weight.view( |
| | {grad_weight.size(0) * grad_weight.size(1), |
| | grad_weight.size(2), |
| | grad_weight.size(3), |
| | grad_weight.size(4)}); |
| | if (with_bias) |
| | grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); |
| | } |
| | grad_output = grad_output.view( |
| | {grad_output.size(0) * grad_output.size(1), |
| | grad_output.size(2), |
| | grad_output.size(3), |
| | grad_output.size(4)}); |
| | } |
| |
|
| | } |
| |
|