/*
* Copyright (c) 2020 NVIDIA Corporation.
* Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*
* Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
* Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
* of the code.
*/
#ifndef MATH_FUNCTIONS_CUH
#define MATH_FUNCTIONS_CUH
#include "mkl_alternate.hpp"
#include "gpu.cuh"
namespace minkowski {
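// gpu_gemm: thin wrapper around cuBLAS GEMM. Expected to compute
// C = alpha * op(A) * op(B) + beta * C for an M x N output with inner
// dimension K, using CBLAS-style transpose flags; the exact row-/column-major
// handling lives in the corresponding .cu implementation.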
template <typename Dtype>
void gpu_gemm(cublasHandle_t handle, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const Dtype alpha, const Dtype *A, const Dtype *B,
const Dtype beta, Dtype *C);
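// gpu_addition: element-wise y[i] = a[i] + b[i] for i in [0, N), launched on
// the given CUDA stream.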
template <typename Dtype>
void gpu_addition(const int N, const Dtype *a, const Dtype *b, Dtype *y,
cudaStream_t stream);
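// gpu_multiplication: element-wise y[i] = a[i] * b[i] for i in [0, N),
// launched on the given CUDA stream.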
template <typename Dtype>
void gpu_multiplication(const int N, const Dtype *a, const Dtype *b, Dtype *y,
cudaStream_t stream);
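// col2row_major: copy the nrows x ncols column-major matrix colA into the
// row-major buffer rowA (a layout-transposing copy).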
template <typename Dtype>
void col2row_major(const int nrows, const int ncols, const Dtype *colA,
Dtype *rowA, cudaStream_t stream);
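// row2col_major: the inverse conversion, from row-major to column-major
// storage, with the same nrows x ncols shape.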
template <typename Dtype>
void row2col_major(const int nrows, const int ncols, const Dtype *colA,
Dtype *rowA, cudaStream_t stream);
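// sort_coo_gpu: sort the COO index arrays d_coo_row / d_coo_col of an m x n
// sparse matrix with nnz entries in place (presumably via cusparseXcoosort),
// drawing temporary device memory from the given allocator.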
template <typename allocator_type>
void sort_coo_gpu(cusparseHandle_t handle, const int m, const int n,
const int nnz, int *d_coo_row, int *d_coo_col,
allocator_type &allocator);
namespace detail {
// Gather kernel: dst[i] = src[map[i / length] * length + i % length] for
// i < nthreads. Each block first caches the map entries it touches in shared
// memory, so a map value shared by `length` consecutive threads is read from
// global memory only once. Intended for the case where the number of threads
// per block exceeds `length`.
template <typename Dtype, typename Itype>
__global__ void __shared_copy_kernel_map(Dtype *__restrict__ dst,
const Dtype *__restrict__ const src,
const Itype *__restrict__ const map,
const Itype nthreads,
const Itype length) {
// cchoy: cache map and benchmark.
extern __shared__ unsigned int smap[];
const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
const Itype src_index = i / length;
const Itype length_index = i % length;
const Itype block_rem = (blockIdx.x * blockDim.x) % length;
const Itype smap_index = (threadIdx.x + block_rem) / length;
if ((threadIdx.x == 0 || (threadIdx.x + block_rem) % length == 0) &&
i < nthreads)
smap[smap_index] = map[src_index];
__syncthreads();
if (i < nthreads) {
dst[i] = src[smap[smap_index] * length + length_index];
}
}
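// Scatter-add counterpart of the gather kernel above:
// atomicAdd(&dst[map[i / length] * length + i % length], src[i]) for
// i < nthreads, with the same shared-memory caching of map entries.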
template <typename Dtype, typename Itype>
__global__ void
__shared_accumulate_kernel_map(Dtype *__restrict__ dst,
const Dtype *__restrict__ const src,
const Itype *__restrict__ const map,
const Itype nthreads, const Itype length) {
// cchoy: cache map and benchmark.
extern __shared__ unsigned int smap[];
const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
const Itype src_index = i / length;
const Itype length_index = i % length;
const Itype block_rem = (blockIdx.x * blockDim.x) % length;
const Itype smap_index = (threadIdx.x + block_rem) / length;
if ((threadIdx.x == 0 || (threadIdx.x + block_rem) % length == 0) &&
i < nthreads)
smap[smap_index] = map[src_index];
__syncthreads();
if (i < nthreads)
atomicAdd(&dst[smap[smap_index] * length + length_index], src[i]);
}
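// Host-side launcher for the gather kernel above. The dynamic shared-memory
// size is the number of distinct map entries a block may touch: roughly
// MAX_THREADS / length when a feature row fits inside a block, and
// length / MAX_THREADS (at least 1) when a single row spans several blocks.
// A typical call, gathering `nnz` mapped rows of width `nchannel` (variable
// names here are illustrative, not part of this header):
//
//   shared_copy_kernel_map<float, int32_t>(d_out, d_in, d_map,
//                                          nnz * nchannel /* nthreads */,
//                                          nchannel /* length */);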
template <typename Dtype, typename Itype>
void shared_copy_kernel_map(Dtype *dst, const Dtype *const src,
const Itype *const map, const Itype nthreads,
const Itype length) {
constexpr Itype MAX_THREADS = 512;
if (MAX_THREADS >= length) {
LOG_DEBUG("Blocks:", GET_BLOCKS(nthreads, MAX_THREADS),
"Threads:", MAX_THREADS,
"Shared:", GET_BLOCKS(MAX_THREADS, length));
__shared_copy_kernel_map<Dtype, Itype>
<<<GET_BLOCKS(nthreads, MAX_THREADS), MAX_THREADS,
GET_BLOCKS(MAX_THREADS, length) * sizeof(unsigned int)>>>(
dst, src, map, nthreads, length);
} else {
LOG_DEBUG("Blocks:", GET_BLOCKS(nthreads, MAX_THREADS),
"Threads:", MAX_THREADS,
"Shared:", GET_BLOCKS(length, MAX_THREADS));
__shared_copy_kernel_map<Dtype, Itype>
<<<GET_BLOCKS(nthreads, MAX_THREADS), MAX_THREADS,
GET_BLOCKS(length, MAX_THREADS) * sizeof(unsigned int)>>>(
dst, src, map, nthreads, length);
}
}
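// Host-side launcher for the scatter-add kernel, with the same block and
// shared-memory sizing logic as shared_copy_kernel_map above.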
template <typename Dtype, typename Itype>
void shared_accumulate_kernel_map(Dtype *dst, const Dtype *const src,
const Itype *const map, const Itype nthreads,
const Itype length) {
constexpr Itype MAX_THREADS = 512;
if (MAX_THREADS >= length)
__shared_accumulate_kernel_map<Dtype, Itype>
<<<GET_BLOCKS(nthreads, MAX_THREADS), MAX_THREADS,
GET_BLOCKS(MAX_THREADS, length) * sizeof(unsigned int)>>>(
dst, src, map, nthreads, length);
else
__shared_accumulate_kernel_map<Dtype, Itype>
<<<GET_BLOCKS(nthreads, MAX_THREADS), MAX_THREADS,
GET_BLOCKS(length, MAX_THREADS) * sizeof(unsigned int)>>>(
dst, src, map, nthreads, length);
}
} // end namespace detail
} // end namespace minkowski
#endif // MATH_FUNCTIONS_CUH