|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef GPU_CONVOLUTION |
|
|
#define GPU_CONVOLUTION |
|
|
|
|
|
#include <iostream> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "allocators.cuh" |
|
|
#include "convolution_kernel.cuh" |
|
|
#include "math_functions.cuh" |
|
|
|
|
|
#include <ATen/cuda/CUDAUtils.h> |
|
|
#include <torch/extension.h> |
|
|
|
|
|
namespace minkowski { |
|
|
|
|
|
namespace detail { |
|
|
|
|
|
// Heuristic choice between the direct (gather-GEMM-scatter in one kernel) and
// the copy-GEMM (explicit gather, cuBLAS GEMM, explicit scatter) forward path.
//
// algo_index        - user-requested Minkowski algorithm mode; MEMORY_EFFICIENT
//                     forces the direct path (no temporary gather buffers).
// convolution_mode  - explicit DIRECT_GEMM / COPY_GEMM requests short-circuit
//                     the heuristic.
// sA, sB            - input / output channel counts.
// N                 - number of input rows.
//
// Returns true when the direct-GEMM kernel is expected to be faster (small
// channel counts or moderate N); thresholds are empirical.
//
// NOTE: the original used the alternative token `and` on one line; normalized
// to `&&` for consistency with the rest of the file (identical semantics).
bool check_direct_gemm_forward(MinkowskiAlgorithm::Mode const algo_index,
                               ConvolutionMode::Type const &convolution_mode,
                               long const sA, long const sB, long const N) {
  // Explicit user overrides take precedence over the heuristic.
  if (convolution_mode == ConvolutionMode::DIRECT_GEMM ||
      algo_index == MinkowskiAlgorithm::MEMORY_EFFICIENT)
    return true;
  if (convolution_mode == ConvolutionMode::COPY_GEMM)
    return false;

  // Empirically tuned decision tree.
  if (sA == 32 && sB == 64 && N <= 490537)
    return true;
  if (sB <= 40) {
    if (sB <= 20)
      return true;
    return N <= 295625 || sA <= 12;
  }
  if (sA <= 20)
    return true;
  return N <= 74556 && sB <= 112;
}
|
|
|
|
|
// Heuristic choice between the direct and copy-GEMM backward paths; mirrors
// check_direct_gemm_forward but with thresholds tuned for the backward pass
// (which runs two GEMMs: input gradient and kernel gradient).
//
// algo_index        - MEMORY_EFFICIENT forces the direct path.
// convolution_mode  - explicit DIRECT_GEMM / COPY_GEMM requests short-circuit
//                     the heuristic.
// sA, sB            - input / output channel counts.
// N                 - number of input rows.
//
// Returns true when the direct-GEMM backward kernel is expected to win.
//
// NOTE: the original used the alternative token `and` on one line; normalized
// to `&&` for consistency with the rest of the file (identical semantics).
bool check_direct_gemm_backward(MinkowskiAlgorithm::Mode const algo_index,
                                ConvolutionMode::Type const &convolution_mode,
                                long const sA, long const sB, long const N) {
  // Explicit user overrides take precedence over the heuristic.
  if (convolution_mode == ConvolutionMode::DIRECT_GEMM ||
      algo_index == MinkowskiAlgorithm::MEMORY_EFFICIENT)
    return true;
  if (convolution_mode == ConvolutionMode::COPY_GEMM)
    return false;

  // Empirically tuned decision tree.
  if (sA == 32 && sB == 64 && N <= 490537)
    return true;
  if (sB <= 40) {
    if (sA <= 20)
      return true;
    return N <= 490540 || sA <= 12;
  }
  if (sA <= 20)
    return true;
  return N <= 30612 && sB <= 160;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Gathered, tiled matrix multiplication used by the direct-GEMM forward path:
// for each logical output row y, the A row actually read is in_map[y] and the
// C row actually written is out_map[y] (scatter via atomicAdd, since several
// y may map to the same output row).
//
//   C[out_map[y]][x] += sum_k A[in_map[y]][k] * B[k][x]
//
// A: (hA rows gathered through in_map) x wA, row-major.
// B: hB x wB, row-major (here hB == wA at the call sites below).
// C: wB columns, row-major; accumulated into, not overwritten.
//
// Launch contract (established by the indexing below): blockDim must be
// (BLOCK_SIZE, BLOCK_SIZE); the grid tiles wB columns in x and hA rows in y.
template <typename Dtype, typename Itype, int BLOCK_SIZE>
__global__ void
matmul(const Dtype *__restrict__ A, const int wA, const int hA,
       const Dtype *__restrict__ B, const int wB, const int hB,
       Dtype *__restrict__ C,
       const Itype *__restrict__ in_map, const Itype *__restrict__ out_map) {
  // Block and thread coordinates within the tiled grid.
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;

  // Global output coordinate handled by this thread: column x of C,
  // logical row y (remapped through in_map/out_map).
  const int x = BLOCK_SIZE * bx + tx;
  const int y = BLOCK_SIZE * by + ty;

  // Per-thread accumulator for one C element.
  Dtype Csub = 0;

  // Gather/scatter row indices; out-of-range rows read index 0 but are never
  // used, because every later access is guarded by y < hA.
  const Itype in_row = y < hA ? in_map[y] : 0;
  const Itype out_row = y < hA ? out_map[y] : 0;

  // March across the K dimension (wA) one BLOCK_SIZE-wide tile at a time.
  for (int s = 0; s < wA; s += BLOCK_SIZE) {
    // Shared-memory tiles of A and B for this K-slab.
    __shared__ Dtype As[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ Dtype Bs[BLOCK_SIZE][BLOCK_SIZE];

    // Cooperative load; zero-pad past the matrix edges so the inner product
    // below needs no bounds checks.
    As[ty][tx] = ((s + tx) < wA && y < hA) ? A[wA * in_row + s + tx] : 0;
    Bs[ty][tx] = ((s + ty) < hB && x < wB) ? B[wB * (s + ty) + x] : 0;

    // All loads must land before any thread reads the tiles.
    __syncthreads();

#pragma unroll
    for (int k = 0; k < BLOCK_SIZE; ++k) {
      Csub += As[ty][k] * Bs[k][tx];
    }

    // Keep fast threads from overwriting tiles the slow ones still read.
    __syncthreads();
  }

  // Scatter-accumulate; atomic because out_map may repeat a row.
  if (y < hA && x < wB)
    atomicAdd(&C[wB * out_row + x], Csub);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Fused backward kernel for the direct-GEMM path (see the call sites in
// ConvolutionBackwardKernelGPU below). In one pass over the K dimension it
// accumulates two products:
//
//   C[in_map[y]][x]  += sum_k A[out_map[y]][k] * B[x][k]      (A * B^T)
//   E                += D^T * A   (per-tile partial, scattered by atomicAdd)
//
// As used below: A = gradient of the output features (gathered via out_map),
// B = the kernel slice, D = input features (gathered via in_map), C = input
// feature gradient, E = kernel gradient.
//
// Launch contract (established by the indexing): blockDim must be
// (BLOCK_SIZE, BLOCK_SIZE).
template <typename Dtype, typename Itype, int BLOCK_SIZE>
__global__ void
matmul2(const Dtype *__restrict__ A, const int wA, const int hA,
        const Dtype *__restrict__ B, const int wB, const int hB,
        const Dtype *__restrict__ D, const int wD, const int hD,
        Dtype *__restrict__ C, Dtype *__restrict__ E,
        const Itype *__restrict__ in_map, const Itype *__restrict__ out_map) {
  // Block and thread coordinates.
  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;

  // Global coordinate: column x, logical row y (remapped through the maps).
  const int x = BLOCK_SIZE * bx + tx;
  const int y = BLOCK_SIZE * by + ty;

  // Gather/scatter rows; guarded reads as in matmul above.
  const Itype in_row = y < hA ? in_map[y] : 0;
  const Itype out_row = y < hA ? out_map[y] : 0;

  // Accumulators: Csub for one C element, Esub for one E partial per K-tile.
  Dtype Csub = 0;
  Dtype Esub = 0;

  // Shared tiles: A slab, B^T slab, and the D slab (loaded once — D is
  // tiled along x only, so it is invariant across the K loop).
  __shared__ Dtype As[BLOCK_SIZE][BLOCK_SIZE];
  __shared__ Dtype BTs[BLOCK_SIZE][BLOCK_SIZE];
  __shared__ Dtype DTs[BLOCK_SIZE][BLOCK_SIZE];

  // Load the D tile for this block (zero-padded at the edges).
  DTs[ty][tx] = (x < wD && y < hD) ? D[wD * in_row + x] : 0;

  // March across the K dimension (wA) one tile at a time.
  for (int s = 0; s < wA; s += BLOCK_SIZE) {
    // A gathered through out_map; B read transposed (row x, column s+ty).
    As[ty][tx] = ((s + tx) < wA && y < hA) ? A[wA * out_row + s + tx] : 0;
    BTs[ty][tx] = ((s + ty) < wB && x < hB) ? B[wB * x + s + ty] : 0;

    // All three tiles must be populated before use.
    __syncthreads();

#pragma unroll
    for (int k = 0; k < BLOCK_SIZE; ++k) {
      Csub += As[ty][k] * BTs[k][tx];
    }

    // E partial for this K-tile only; reset each iteration because it is
    // scattered to global memory per tile below.
    Esub = 0;
#pragma unroll
    for (int k = 0; k < BLOCK_SIZE; ++k) {
      Esub += DTs[k][ty] * As[k][tx];
    }

    // Tiles are about to be overwritten (As) in the next iteration.
    __syncthreads();

    // Scatter this tile's D^T * A partial into E; atomic because every
    // y-block of the grid contributes to the same E entries.
    if ((bx * BLOCK_SIZE + ty) < wD && (s + tx) < wA)
      atomicAdd(&E[wA * (bx * BLOCK_SIZE + ty) + (s + tx)], Esub);
  }

  // Scatter-accumulate the C result; atomic because in_map may repeat a row.
  if (y < hA && x < hB)
    atomicAdd(&C[hB * in_row + x], Csub);
}
|
|
|
|
|
// Accumulates a transposed feature buffer into mapped rows of the output:
//
//   out_feat[map[x]][y] += in_feat[y][x]   for x < n, y < out_nchannel
//
// i.e. in_feat is read column-major relative to out_feat (it is the
// column-major result of the preceding cuBLAS GEMM in the backward path).
//
// Launch contract (established by the call site below): dynamic shared
// memory of blockDim.x * sizeof(Itype) bytes, used to stage one map entry
// per x-column so only one thread per column touches global `map`.
template <typename Dtype, typename Itype>
__global__ void
add_mapped_output_tr(const size_t n, const Dtype *__restrict__ in_feat,
                     const size_t in_nchannel, Dtype *__restrict__ out_feat,
                     const size_t out_nchannel, const Itype *__restrict__ map) {
  // One map index per x-column of this block, staged in shared memory.
  extern __shared__ Itype map_index[];

  const int bx = blockIdx.x;
  const int by = blockIdx.y;
  const int tx = threadIdx.x;
  const int ty = threadIdx.y;

  // x indexes mapped rows (0..n), y indexes output channels.
  const int x = blockDim.x * bx + tx;
  const int y = blockDim.y * by + ty;

  // Only the ty == 0 row loads the map; everyone else waits at the barrier.
  if (x < n && ty == 0)
    map_index[tx] = map[x];

  // Barrier is unconditional — safe even for out-of-range threads.
  __syncthreads();

  // Scatter-accumulate; atomic because map may repeat a destination row.
  if (x < n && y < out_nchannel) {
    atomicAdd(&out_feat[map_index[tx] * out_nchannel + y],
              in_feat[y * in_nchannel + x]);
  }
}
|
|
|
|
|
} |
|
|
|
|
|
// Sparse convolution forward pass. For each kernel offset k in the kernel
// map, computes out_feat[out_map] += in_feat[in_map] * kernel[k] using one of
// two strategies selected by check_direct_gemm_forward:
//  - direct path: the custom gathered matmul kernel above (no temp buffers);
//  - copy path: gather rows into a dense buffer, run a cuBLAS GEMM
//    (gpu_gemm), then scatter-accumulate the result.
//
// d_in_feat / d_out_feat  - device feature matrices, row-major,
//                           in_nchannel / out_nchannel columns.
// d_kernel                - kernel weights, one in_nchannel x out_nchannel
//                           slice per kernel-map key k.
// kernel_map              - per-offset in/out row index lists.
// allocator               - device byte allocator for the copy path's
//                           temporary buffers.
// Blocks on `stream` before returning (cudaStreamSynchronize at the end).
template <typename Dtype, typename Itype, typename ByteAllocator>
void ConvolutionForwardKernelGPU(
    Dtype const *d_in_feat,
    default_types::size_type const in_nchannel,
    Dtype *d_out_feat,
    default_types::size_type const out_nchannel,
    Dtype *d_kernel, gpu_kernel_map<Itype, ByteAllocator> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    ByteAllocator &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode,
    cublasHandle_t cuhandle, cudaStream_t stream) {

  // shared_mem_size doubles as the matmul tile edge (BLOCK_SIZE); the -1
  // sentinel wraps to SIZE_MAX on size_t but is always overwritten before use.
  size_t n_active_in_volume, shared_mem_size = -1;

  if (detail::check_direct_gemm_forward(algo_index, convolution_mode,
                                        in_nchannel, out_nchannel, in_nrows)) {
    // Pick the largest tile size the channel counts can keep busy.
    if ((in_nchannel > 16 && out_nchannel > 16 &&
         in_nchannel * out_nchannel >= 512) ||
        (in_nchannel > 24 && out_nchannel > 24))
      shared_mem_size = 32;
    else if (in_nchannel % 24 == 0 && out_nchannel % 24 == 0)
      shared_mem_size = 24;
    else if ((in_nchannel > 8 && out_nchannel > 8) ||
             (in_nchannel % 16 == 0 && out_nchannel % 16 == 0))
      shared_mem_size = 16;
    else
      shared_mem_size = 8;

    dim3 threads(shared_mem_size, shared_mem_size);

    // One gathered GEMM per kernel offset k.
    for (auto it = kernel_map.key_cbegin(); it != kernel_map.key_cend(); ++it) {
      auto const k = it->first;
      n_active_in_volume = kernel_map.size(k);
      if (n_active_in_volume == 0)
        continue;

      // Split very large maps into num_div chunks so the y-grid stays
      // within MAX_GRID blocks.
      size_t const num_grid =
          (n_active_in_volume + shared_mem_size - 1) / shared_mem_size;
      size_t const num_div = (num_grid + MAX_GRID - 1) / MAX_GRID;
      size_t const step = (n_active_in_volume + num_div - 1) / num_div;

      for (size_t s = 0; s < num_div; s++) {
        size_t const offset = step * s;
        size_t const remainder = n_active_in_volume - offset;
        size_t const curr_num_active = remainder < step ? remainder : step;
        dim3 const grid((out_nchannel + threads.x - 1) / threads.x,
                        (curr_num_active + threads.y - 1) / threads.y);

#ifdef DEBUG
#endif

        // Dispatch to the matching compile-time tile size.
        switch (shared_mem_size) {
        case 32:
          detail::matmul<Dtype, Itype, 32><<<grid, threads, 0, stream>>>(
              d_in_feat, in_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel, d_out_feat, kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        case 24:
          detail::matmul<Dtype, Itype, 24><<<grid, threads, 0, stream>>>(
              d_in_feat, in_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel, d_out_feat, kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        case 16:
          detail::matmul<Dtype, Itype, 16><<<grid, threads, 0, stream>>>(
              d_in_feat, in_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel, d_out_feat, kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        case 8:
          detail::matmul<Dtype, Itype, 8><<<grid, threads, 0, stream>>>(
              d_in_feat, in_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel, d_out_feat, kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        }
      }
#ifdef DEBUG
      LOG_DEBUG("k:", k, "num:", n_active_in_volume);
      CUDA_CHECK(cudaDeviceSynchronize());
#endif
      // Catch launch-configuration errors per kernel offset.
      CUDA_CHECK(cudaGetLastError());
    }
  } else {
    // Copy-GEMM path: temp buffers sized for the largest kernel-map entry
    // and reused across offsets.
    Itype const max_numel = kernel_map.max_size();
    LOG_DEBUG("max_numel:", max_numel);
    Dtype *mapped_in_feat = reinterpret_cast<Dtype *>(
        allocator.allocate(max_numel * in_nchannel * sizeof(Dtype)));
    Dtype *mapped_out_feat = reinterpret_cast<Dtype *>(
        allocator.allocate(max_numel * out_nchannel * sizeof(Dtype)));

    for (auto it = kernel_map.key_cbegin(); it != kernel_map.key_cend(); ++it) {
      auto const k = it->first;
      n_active_in_volume = kernel_map.size(k);
      if (n_active_in_volume == 0)
        continue;

      LOG_DEBUG(n_active_in_volume * in_nchannel, in_nchannel);
      // Gather the active input rows into a dense buffer.
      detail::shared_copy_kernel_map<Dtype, Itype>(
          mapped_in_feat, d_in_feat, kernel_map.in_maps.begin(k),
          n_active_in_volume * in_nchannel, in_nchannel);

#ifdef DEBUG
#endif

      // Dense GEMM: mapped_out = 1 * mapped_in * kernel[k] + 0.
      gpu_gemm<Dtype>(cuhandle, CblasNoTrans, CblasNoTrans,
                      n_active_in_volume,
                      out_nchannel,
                      in_nchannel,
                      1,
                      mapped_in_feat,
                      &d_kernel[k * in_nchannel * out_nchannel],
                      0,
                      mapped_out_feat
      );

      // Scatter-accumulate the dense result into the sparse output rows.
      detail::shared_accumulate_kernel_map<Dtype, Itype>(
          d_out_feat, mapped_out_feat, kernel_map.out_maps.begin(k),
          n_active_in_volume * out_nchannel, out_nchannel);
    }

    allocator.deallocate((char *)mapped_in_feat,
                         max_numel * in_nchannel * sizeof(Dtype));
    allocator.deallocate((char *)mapped_out_feat,
                         max_numel * out_nchannel * sizeof(Dtype));
  }
  // Surface any asynchronous execution error before returning.
  CUDA_CHECK(cudaStreamSynchronize(stream));
}
|
|
|
|
|
|
|
|
// Explicit instantiations of the forward kernel for the supported
// {float, double} x {default_allocator, c10_allocator} combinations.
template void
ConvolutionForwardKernelGPU<float, uint32_t, detail::default_allocator<char>>(
    float const *d_in_feat, default_types::size_type const in_nchannel,
    float *d_out_feat, default_types::size_type const out_nchannel,
    float *d_kernel,
    gpu_kernel_map<uint32_t, detail::default_allocator<char>> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::default_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);

template void
ConvolutionForwardKernelGPU<double, uint32_t, detail::default_allocator<char>>(
    double const *d_in_feat, default_types::size_type const in_nchannel,
    double *d_out_feat, default_types::size_type const out_nchannel,
    double *d_kernel,
    gpu_kernel_map<uint32_t, detail::default_allocator<char>> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::default_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);

template void
ConvolutionForwardKernelGPU<float, uint32_t, detail::c10_allocator<char>>(
    float const *d_in_feat, default_types::size_type const in_nchannel,
    float *d_out_feat, default_types::size_type const out_nchannel,
    float *d_kernel,
    gpu_kernel_map<uint32_t, detail::c10_allocator<char>> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::c10_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);

template void
ConvolutionForwardKernelGPU<double, uint32_t, detail::c10_allocator<char>>(
    double const *d_in_feat, default_types::size_type const in_nchannel,
    double *d_out_feat, default_types::size_type const out_nchannel,
    double *d_kernel,
    gpu_kernel_map<uint32_t, detail::c10_allocator<char>> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::c10_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);
|
|
|
|
|
|
|
|
// Sparse convolution backward pass: computes the input-feature gradient
// (d_grad_in_feat) and the kernel-weight gradient (d_grad_kernel) from the
// output-feature gradient (d_grad_out_feat). Two strategies, selected by
// check_direct_gemm_backward:
//  - copy path: per kernel offset, gather grad_out, run two cuBLAS GEMMs
//    (grad_in = grad_out * W^T scattered via add_mapped_output_tr;
//     grad_W += in^T * grad_out accumulated in place);
//  - direct path: the fused matmul2 kernel above computes both products in
//    one launch per chunk.
//
// d_in_feat        - input features (in_nchannel columns).
// d_grad_in_feat   - input gradient, accumulated into.
// d_grad_out_feat  - incoming output gradient (out_nchannel columns).
// d_kernel / d_grad_kernel - weights and their gradient, one
//                    in_nchannel x out_nchannel slice per offset k.
// allocator        - device byte allocator for the copy path's buffers.
//
// Fix vs. original: the direct-path chunk loop used `int s` against the
// size_t bound num_div (signed/unsigned mismatch); now size_t, matching the
// forward path.
template <typename Dtype, typename Itype, typename ByteAllocator>
void ConvolutionBackwardKernelGPU(
    Dtype const *d_in_feat,
    Dtype *d_grad_in_feat, default_types::size_type const in_nchannel,
    Dtype const *d_grad_out_feat,
    default_types::size_type const out_nchannel,
    Dtype const *d_kernel, Dtype *d_grad_kernel,
    gpu_kernel_map<Itype, ByteAllocator> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    ByteAllocator &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream) {

#ifdef DEBUG
  CUDA_CHECK_ARGS(cudaDeviceSynchronize(),
                  "Error triggered from a previous kernel call.");
#endif

  // shared_mem_size doubles as the matmul2 tile edge; always overwritten
  // below before use.
  size_t n_active_in_volume, shared_mem_size = -1;

  // Pick the largest tile size the channel counts can keep busy.
  if ((in_nchannel > 16 && out_nchannel > 16 &&
       in_nchannel * out_nchannel >= 512) ||
      (in_nchannel % 32 == 0 && out_nchannel % 32 == 0))
    shared_mem_size = 32;
  else if (in_nchannel % 24 == 0 && out_nchannel % 24 == 0)
    shared_mem_size = 24;
  else if ((in_nchannel > 8 && out_nchannel > 8) ||
           (in_nchannel % 16 == 0 && out_nchannel % 16 == 0))
    shared_mem_size = 16;
  else
    shared_mem_size = 8;

  if (!detail::check_direct_gemm_backward(
          algo_index, convolution_mode, in_nchannel, out_nchannel, in_nrows)) {
    // ---------------- copy-GEMM path ----------------
    // Buffers sized for the largest kernel-map entry, reused per offset.
    size_t max_active = kernel_map.max_size();
    size_t in_buffer_size = max_active * in_nchannel * sizeof(Dtype);
    size_t out_buffer_size = max_active * out_nchannel * sizeof(Dtype);
    Dtype *d_input_buffer = (Dtype *)allocator.allocate(in_buffer_size);
    Dtype *d_output_buffer = (Dtype *)allocator.allocate(out_buffer_size);

    dim3 threads(32, shared_mem_size);
#ifdef DEBUG
    timer t;
#endif
    for (auto it = kernel_map.key_cbegin(); it != kernel_map.key_cend(); ++it) {
      auto const k = it->first;
      n_active_in_volume = kernel_map.size(k);
      if (n_active_in_volume == 0)
        continue;

      Itype const *d_in_map = kernel_map.in_maps.begin(k);
      Itype const *d_out_map = kernel_map.out_maps.begin(k);

#ifdef DEBUG
      t.tic();
#endif
      // Gather the active output-gradient rows into a dense buffer.
      detail::shared_copy_kernel_map<Dtype, Itype>(
          d_output_buffer, d_grad_out_feat, d_out_map,
          n_active_in_volume * out_nchannel, out_nchannel);
#ifdef DEBUG
      CUDA_CHECK(cudaStreamSynchronize(stream));
      LOG_DEBUG("copy input", t.toc());
      t.tic();
#endif
      // grad_in_buffer = kernel[k] * grad_out^T (result effectively
      // transposed; undone by add_mapped_output_tr below).
      gpu_gemm<Dtype>(cuhandle, CblasNoTrans, CblasTrans,
                      in_nchannel,
                      n_active_in_volume,
                      out_nchannel,
                      1,
                      &d_kernel[k * in_nchannel * out_nchannel],
                      d_output_buffer,
                      0,
                      d_input_buffer
      );
#ifdef DEBUG
      CUDA_CHECK(cudaStreamSynchronize(0));
      LOG_DEBUG("input grad gemm", t.toc());
      t.tic();
#endif

      // Scatter-accumulate the transposed GEMM result into d_grad_in_feat;
      // dynamic shared memory stages one map entry per x-column.
      dim3 const grid_tr(GET_BLOCKS(n_active_in_volume, threads.x),
                         GET_BLOCKS(in_nchannel, threads.y));
      detail::add_mapped_output_tr<Dtype, Itype>
          <<<grid_tr, threads, threads.x * sizeof(Itype), stream>>>(
              n_active_in_volume,
              d_input_buffer,
              n_active_in_volume,
              d_grad_in_feat, in_nchannel,
              d_in_map);
#ifdef DEBUG
      CUDA_CHECK(cudaStreamSynchronize(stream));
      LOG_DEBUG("accumulate in grad", t.toc());
      t.tic();
#endif

      // Reuse d_input_buffer: gather the active input-feature rows for the
      // kernel-gradient GEMM.
      dim3 const grid_in(GET_BLOCKS(n_active_in_volume, threads.x),
                         GET_BLOCKS(in_nchannel, threads.y));
      detail::shared_copy_kernel_map<Dtype, Itype>(
          d_input_buffer, d_in_feat, d_in_map, n_active_in_volume * in_nchannel,
          in_nchannel);
#ifdef DEBUG
      LOG_DEBUG("copy in feat to buffer", t.toc());
      t.tic();
#endif

      // The gather above must finish before the buffer is consumed.
      CUDA_CHECK(cudaStreamSynchronize(stream));
      // grad_kernel[k] += in^T * grad_out (beta = 1 accumulates in place).
      gpu_gemm<Dtype>(cuhandle, CblasTrans, CblasNoTrans,
                      in_nchannel,
                      out_nchannel,
                      n_active_in_volume,
                      1,
                      d_input_buffer,
                      d_output_buffer,
                      1,
                      &d_grad_kernel[k * in_nchannel * out_nchannel]
      );
      CUDA_CHECK(cudaStreamSynchronize(0));
#ifdef DEBUG
      LOG_DEBUG("grad kernel gemm", t.toc());
      t.tic();
#endif
    }
    allocator.deallocate((char *)d_input_buffer, in_buffer_size);
    allocator.deallocate((char *)d_output_buffer, out_buffer_size);
  } else {
    // ---------------- direct-GEMM path ----------------
    dim3 threads(shared_mem_size, shared_mem_size);

    for (auto it = kernel_map.key_cbegin(); it != kernel_map.key_cend(); ++it) {
      auto const k = it->first;
      n_active_in_volume = kernel_map.size(k);
      if (n_active_in_volume == 0)
        continue;

      // Chunk large maps so the y-grid stays within MAX_GRID blocks.
      size_t const num_grid =
          (n_active_in_volume + shared_mem_size - 1) / shared_mem_size;
      size_t const num_div = (num_grid + MAX_GRID - 1) / MAX_GRID;
      size_t const step = (n_active_in_volume + num_div - 1) / num_div;

      // Fixed: loop variable was `int`, mismatching size_t num_div.
      for (size_t s = 0; s < num_div; s++) {
        size_t const offset = step * s;
        size_t const remainder = n_active_in_volume - offset;
        size_t const curr_num_active = remainder < step ? remainder : step;
        dim3 const grid((in_nchannel + threads.x - 1) / threads.x,
                        (curr_num_active + threads.y - 1) / threads.y);

        // Dispatch to the matching compile-time tile size; matmul2 fuses the
        // grad-input and grad-kernel products.
        switch (shared_mem_size) {
        case 32:
          detail::matmul2<Dtype, Itype, 32><<<grid, threads, 0, stream>>>(
              d_grad_out_feat, out_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel,
              d_in_feat, in_nchannel, curr_num_active,
              d_grad_in_feat,
              &d_grad_kernel[k * in_nchannel * out_nchannel],
              kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        case 24:
          detail::matmul2<Dtype, Itype, 24><<<grid, threads, 0, stream>>>(
              d_grad_out_feat, out_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel,
              d_in_feat, in_nchannel, curr_num_active,
              d_grad_in_feat,
              &d_grad_kernel[k * in_nchannel * out_nchannel],
              kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        case 16:
          detail::matmul2<Dtype, Itype, 16><<<grid, threads, 0, stream>>>(
              d_grad_out_feat, out_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel,
              d_in_feat, in_nchannel, curr_num_active,
              d_grad_in_feat,
              &d_grad_kernel[k * in_nchannel * out_nchannel],
              kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        case 8:
          detail::matmul2<Dtype, Itype, 8><<<grid, threads, 0, stream>>>(
              d_grad_out_feat, out_nchannel, curr_num_active,
              &d_kernel[k * in_nchannel * out_nchannel], out_nchannel,
              in_nchannel,
              d_in_feat, in_nchannel, curr_num_active,
              d_grad_in_feat,
              &d_grad_kernel[k * in_nchannel * out_nchannel],
              kernel_map.in_maps.begin(k) + offset,
              kernel_map.out_maps.begin(k) + offset);
          break;
        }
      }
      // Catch launch-configuration errors per kernel offset.
      CUDA_CHECK(cudaGetLastError());
    }
    // Surface any asynchronous execution error from this path.
    CUDA_CHECK(cudaStreamSynchronize(stream));
  }
}
|
|
|
|
|
|
|
|
// Explicit instantiations of the backward kernel for the supported
// {float, double} x {default_allocator, c10_allocator} combinations.
template void
ConvolutionBackwardKernelGPU<float, uint32_t, detail::default_allocator<char>>(
    float const *d_in_feat, float *d_grad_in_feat,
    default_types::size_type const in_nchannel,
    float const *d_grad_out_feat,
    default_types::size_type const out_nchannel,
    float const *d_kernel, float *p_grad_kernel,
    gpu_kernel_map<uint32_t, detail::default_allocator<char>> const
        &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::default_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);

template void
ConvolutionBackwardKernelGPU<double, uint32_t, detail::default_allocator<char>>(
    double const *d_in_feat, double *d_grad_in_feat,
    default_types::size_type const in_nchannel,
    double const *d_grad_out_feat,
    default_types::size_type const out_nchannel,
    double const *d_kernel, double *p_grad_kernel,
    gpu_kernel_map<uint32_t, detail::default_allocator<char>> const
        &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::default_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);

template void
ConvolutionBackwardKernelGPU<float, uint32_t, detail::c10_allocator<char>>(
    float const *d_in_feat, float *d_grad_in_feat,
    default_types::size_type const in_nchannel,
    float const *d_grad_out_feat,
    default_types::size_type const out_nchannel,
    float const *d_kernel, float *p_grad_kernel,
    gpu_kernel_map<uint32_t, detail::c10_allocator<char>> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::c10_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);

template void
ConvolutionBackwardKernelGPU<double, uint32_t, detail::c10_allocator<char>>(
    double const *d_in_feat, double *d_grad_in_feat,
    default_types::size_type const in_nchannel,
    double const *d_grad_out_feat,
    default_types::size_type const out_nchannel,
    double const *d_kernel, double *p_grad_kernel,
    gpu_kernel_map<uint32_t, detail::c10_allocator<char>> const &kernel_map,
    default_types::size_type const in_nrows,
    default_types::size_type const out_nrows,
    detail::c10_allocator<char> &allocator,
    MinkowskiAlgorithm::Mode const algo_index,
    ConvolutionMode::Type const convolution_mode, cublasHandle_t cuhandle,
    cudaStream_t stream);
|
|
|
|
|
} |
|
|
|
|
|
#endif |
|
|
|