#pragma once

#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>
|
namespace at { namespace native { |
|
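// Returns a human-readable name for a cuDNN data type, useful for error
// messages and debug printing.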
std::string cudnnTypeToString(cudnnDataType_t dtype); |
|
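// Returns the size in bytes of one element of the given cuDNN data type
// (2 for half/bfloat16, 4 for float, 8 otherwise).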
inline int dataSize(cudnnDataType_t dataType)
{
  switch (dataType) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
    case CUDNN_DATA_BFLOAT16:
#endif
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}
|
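// The stride of a size-1 dimension is not uniquely determined: no pointer is
// ever advanced along such a dimension, so any value would work.  cuDNN is
// stricter, though: for a contiguous input it expects stride[i] to equal the
// product of the sizes of the dimensions inner to i.  This helper rewrites
// 'stride' in place so that size-1 dimensions satisfy that invariant, walking
// the dimensions from innermost to outermost in either NCHW or NHWC order.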
static inline void fixSizeOneDimStride(int dim, const int *size, int *stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);

  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim-1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}
|
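// unique_ptr deleter that releases a cuDNN descriptor via the matching
// cudnnDestroyXxxDescriptor function.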
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};
|
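// A generic RAII wrapper around a cuDNN descriptor.  The template takes the
// underlying struct type (e.g. cudnnTensorStruct for cudnnTensorDescriptor_t)
// together with the cuDNN create and destroy functions.  A Descriptor default
// constructs to nullptr; the underlying descriptor is created lazily the
// first time mut_desc() (and therefore any set() function of a subclass) is
// called.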
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class TORCH_CUDA_CPP_API Descriptor {
 public:
|
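  // desc() gives read-only access to the underlying descriptor pointer; it
  // returns nullptr if the descriptor has never been initialized.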
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }
|
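  // mut_desc() is for code that intends to modify what the descriptor points
  // to (e.g. via cudnnSetFooDescriptor); it initializes the descriptor on
  // first use.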
  T* mut_desc() { init(); return desc_.get(); }

 protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }

 private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
|
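// Wraps cudnnTensorDescriptor_t: describes the data type, dimensionality,
// sizes and strides of a tensor as cuDNN sees it.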
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                                cudnnTensorStruct,
                                                &cudnnCreateTensorDescriptor,
                                                &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() {}
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }
|
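  // Note: 'pad' specifies the minimum dimensionality of the descriptor we
  // produce; it has nothing to do with convolution padding.  If the tensor
  // has fewer dimensions than 'pad', the remaining dimensions (on the right)
  // are padded out with size 1, which does not affect the underlying data
  // layout.  cuDNN broadcasting expects all inputs to be padded to a common
  // dimensionality in exactly this way.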
  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    fixSizeOneDimStride(dim, size, stride, nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
  }
};

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
|
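// Wraps cudnnFilterDescriptor_t: describes a convolution weight tensor.
// Filters are specified by a cudnnTensorFormat_t plus sizes rather than by
// explicit strides, so the weight is expected to be contiguous in the
// requested memory format.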
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                                cudnnFilterStruct,
                                                &cudnnCreateFilterDescriptor,
                                                &cudnnDestroyFilterDescriptor> {
 public:
  void set(const at::Tensor &t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
  }

  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
  }
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);
|
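// Wraps cudnnConvolutionDescriptor_t: holds the padding, stride, dilation,
// group count and math type of a convolution.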
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
  : public Descriptor<
        cudnnConvolutionStruct,
        &cudnnCreateConvolutionDescriptor,
        &cudnnDestroyConvolutionDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int* upscale /* aka dilation */, int groups, bool allow_tf32) {
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                   CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));

    // Pick the math type: tensor core math for fp16, FMA math (TF32 disabled)
    // for fp32 when TF32 is not allowed, and the default math type otherwise.
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (dataType == CUDNN_DATA_HALF) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
#endif
    }
  }
};
|
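// Wraps cudnnSpatialTransformerDescriptor_t, configured for bilinear sampling.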
struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
  : public Descriptor<
        cudnnSpatialTransformerStruct,
        &cudnnCreateSpatialTransformerDescriptor,
        &cudnnDestroySpatialTransformerDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};
|
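// Wraps cudnnDropoutDescriptor_t.  The descriptor references an RNG state
// buffer that has to stay alive while the descriptor is in use, so that
// buffer is kept here as 'state'.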
struct TORCH_CUDA_CPP_API DropoutDescriptor
  : public Descriptor<
        cudnnDropoutStruct,
        &cudnnCreateDropoutDescriptor,
        &cudnnDestroyDropoutDescriptor> {
  at::Tensor state;
|
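  // Allocate a fresh RNG state buffer and initialize the dropout descriptor
  // with it.  Seeding the cuDNN dropout RNG is expensive, so avoid calling
  // this more often than necessary.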
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }
|
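  // Restore the dropout descriptor from a previously initialized RNG state
  // tensor (e.g. one produced by initialize_rng).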
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void *state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }
|
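  // Configure the descriptor for dropout probability 0.  No RNG state is
  // needed in this case, so nullptr is passed and the seed is irrelevant.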
  void set_no_dropout(cudnnHandle_t handle) {
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
|
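// Wraps cudnnRNNDescriptor_t.  The DropoutDescriptor is stored here as well,
// so that it outlives the RNN descriptor that was configured with it.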
struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                              cudnnRNNStruct,
                                              &cudnnCreateRNNDescriptor,
                                              &cudnnDestroyRNNDescriptor> {
  DropoutDescriptor dropout_desc_;
  void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);

    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
        handle,
        mut_desc(),
        hidden_size,
        num_layers,
        dropout_desc_.desc(),
        input_mode,
        bidirectional,
        mode,
        algo,
        datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
          handle,
          mut_desc(),
          proj_size,
          0));
    }
|
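    // On devices with compute capability 7.0 and above, pick the RNN math
    // type: tensor core math for fp16 inputs, FMA math (TF32 disabled) for
    // fp32 inputs when TF32 is not allowed, and the default math type
    // otherwise.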
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      }
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      }
#endif
      else {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
  }
};
|
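// Wraps cudnnCTCLossDescriptor_t for configuring CTC loss.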
struct TORCH_CUDA_CPP_API CTCLossDescriptor
  : public Descriptor<
        cudnnCTCLossStruct,
        &cudnnCreateCTCLossDescriptor,
        &cudnnDestroyCTCLossDescriptor> {
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
#if CUDNN_VERSION >= 7600
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }
#endif
};
|
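// Wraps cudnnActivationDescriptor_t; only ReLU is supported at the moment.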
struct TORCH_CUDA_CPP_API ActivationDescriptor
  : public Descriptor<
        cudnnActivationStruct,
        &cudnnCreateActivationDescriptor,
        &cudnnDestroyActivationDescriptor> {
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        std::numeric_limits<double>::max()));
  }
};
|
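// Helper for passing scaling factors (e.g. alpha/beta) to cuDNN: cuDNN expects
// a float scalar when the tensor data type is half or float, and a double
// scalar otherwise.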
union Constant
{
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};

}}  // namespace at::native
|
|
|