|
|
#pragma once
|
|
|
|
|
|
#include <string>
|
|
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
|
#include <ATen/cuda/Exceptions.h>
|
|
|
|
|
|
#include <ATen/cudnn/cudnn-wrapper.h>
|
|
|
#include <ATen/cudnn/Utils.h>
|
|
|
#include <ATen/core/Tensor.h>
|
|
|
#include <ATen/TensorUtils.h>
|
|
|
#include <ATen/cuda/ATenCUDAGeneral.h>
|
|
|
#include <cuda.h>
|
|
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
|
#include <ATen/Functions.h>
|
|
|
#else
|
|
|
#include <ATen/ops/empty.h>
|
|
|
#endif
|
|
|
|
|
|
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907
|
|
|
#define USE_CUDNN_RNN_V8_API
|
|
|
#endif
|
|
|
|
|
|
namespace at::native {
|
|
|
|
|
|
// Returns a std::string for the given cuDNN data type. Only the declaration is
// visible in this header; presumably the result is a printable name for
// `dtype` -- confirm against the definition.
std::string cudnnTypeToString(cudnnDataType_t dtype);
|
|
|
|
|
|
|
|
|
|
|
|
// Maps a cuDNN data type to a size value: 2 for CUDNN_DATA_HALF and
// CUDNN_DATA_BFLOAT16, 4 for CUDNN_DATA_FLOAT, and 8 for every other value.
inline int dataSize(cudnnDataType_t dataType)
{
  if (dataType == CUDNN_DATA_BFLOAT16 || dataType == CUDNN_DATA_HALF) {
    return 2;
  }
  if (dataType == CUDNN_DATA_FLOAT) {
    return 4;
  }
  // Any data type not listed above is reported as 8.
  return 8;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Rewrites, in place, the stride of every dimension whose size is 1.
//
// Dimensions are visited in a fixed order while a running product `z` of the
// sizes of the non-size-1 dimensions seen so far is maintained:
//   * nhwc == false: dim-1, dim-2, ..., 2, then 1, then 0
//   * nhwc == true : 1, then dim-1, dim-2, ..., 2, then 0
// A dimension of size 1 gets stride `z`; any other dimension keeps its stride
// and multiplies its size into `z`.
//
// Parameters:
//   dim    - number of dimensions; `size` and `stride` must hold `dim` entries.
//   size   - per-dimension sizes (read only).
//   stride - per-dimension strides (size-1 entries are overwritten).
//   nhwc   - selects the visit order described above.
//
// Shapes with fewer than two dimensions have no dimension 1, so it is left out
// of the visit order; `dim <= 0` is a no-op. (Previously these cases wrote past
// the end of the permutation vector and read `size[1]` out of bounds.)
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  if (dim <= 0) {
    return;
  }

  // Dimension 1 only exists when there are at least two dimensions.
  const bool has_dim_one = dim > 1;

  std::vector<int> permutation;
  permutation.reserve(dim);
  if (nhwc && has_dim_one) {
    permutation.push_back(1);
  }
  for (int d = dim - 1; d > 1; d--) {
    permutation.push_back(d);
  }
  if (!nhwc && has_dim_one) {
    permutation.push_back(1);
  }
  permutation.push_back(0);

  // Running product of the sizes of the non-size-1 dimensions visited so far.
  int64_t z = 1;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = static_cast<T>(z);
    } else {
      z *= size[d];
    }
  }
}
|
|
|
|
|
|
// Deleter functor that releases a raw descriptor pointer by calling `dtor` on
// it and passing the returned status through AT_CUDNN_CHECK.
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    // A null pointer holds no descriptor, so `dtor` is not invoked.
    if (x == nullptr) {
      return;
    }
    AT_CUDNN_CHECK(dtor(x));
  }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
|
|
|
|
|
|
class TORCH_CUDA_CPP_API Descriptor {
|
|
|
public:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T* desc() const { return desc_.get(); }
|
|
|
T* desc() { return desc_.get(); }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T* mut_desc() { init(); return desc_.get(); }
|
|
|
protected:
|
|
|
void init() {
|
|
|
if (desc_ == nullptr) {
|
|
|
T* raw_desc = nullptr;
|
|
|
AT_CUDNN_CHECK(ctor(&raw_desc));
|
|
|
desc_.reset(raw_desc);
|
|
|
}
|
|
|
}
|
|
|
private:
|
|
|
std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
|
|
|
};
|
|
|
|
|
|
// Descriptor wrapper for cudnnRNNDataStruct, created with
// cudnnCreateRNNDataDescriptor and destroyed with cudnnDestroyRNNDataDescriptor.
class TORCH_CUDA_CPP_API RNNDataDescriptor : public Descriptor<
                                                 cudnnRNNDataStruct,
                                                 &cudnnCreateRNNDataDescriptor,
                                                 &cudnnDestroyRNNDataDescriptor> {
 public:
  // Tensor-based overload; declared here, defined outside this header.
  void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray);

 private:
  // Forwards the arguments to cudnnSetRNNDataDescriptor on the (lazily
  // created) descriptor; the call's trailing argument is passed as nullptr.
  void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) {
    auto* rnn_data_desc = mut_desc();
    AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(
        rnn_data_desc,
        dataType,
        layout,
        maxSeqLength,
        batchSize,
        vectorSize,
        seqLengthArray,
        nullptr));
  }
};
|
|
|
|
|
|
// Descriptor wrapper for cudnnTensorStruct, created with
// cudnnCreateTensorDescriptor and destroyed with cudnnDestroyTensorDescriptor.
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                                cudnnTensorStruct,
                                                &cudnnCreateTensorDescriptor,
                                                &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;

  // Convenience constructor: equivalent to default construction followed by
  // set(t, pad).
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    this->set(t, pad);
  }

  // Public setters; declared here, defined outside this header.
  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

 private:
  // Declared here, defined outside this header.
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  // Configures the descriptor from raw size/stride arrays of length `dim`.
  // The strides are copied first, so the caller's `stride` array is never
  // modified; the copy is passed through fixSizeOneDimStride before being
  // handed to cudnnSetTensorNdDescriptor.
  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    std::vector<int> adjusted_strides;
    adjusted_strides.assign(stride, stride + dim);
    int* const stride_data = adjusted_strides.data();
    fixSizeOneDimStride<int>(dim, size, stride_data, nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride_data));
  }
};
|
|
|
|
|
|
// Stream-insertion operator for TensorDescriptor. Only the declaration is
// visible in this header; the definition lives elsewhere.
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
|
|
|
|
|
|
// Descriptor wrapper for cudnnFilterStruct, created with
// cudnnCreateFilterDescriptor and destroyed with cudnnDestroyFilterDescriptor.
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                                cudnnFilterStruct,
                                                &cudnnCreateFilterDescriptor,
                                                &cudnnDestroyFilterDescriptor> {
 public:
  // Shorthand for set(t, at::MemoryFormat::Contiguous, pad).
  void set(const at::Tensor &t, int64_t pad = 0) {
    const at::MemoryFormat default_format = at::MemoryFormat::Contiguous;
    set(t, default_format, pad);
  }

  // Declared here, defined outside this header.
  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();

 private:
  // Forwards to cudnnSetFilterNdDescriptor on the (lazily created) descriptor.
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    auto* filter_desc = mut_desc();
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(filter_desc, dataType, filter_format, dim, size));
  }
};
|
|
|
|
|
|
// Stream-insertion operator for FilterDescriptor. Only the declaration is
// visible in this header; the definition lives elsewhere.
std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);
|
|
|
|
|
|
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
|
|
|
: public Descriptor<
|
|
|
cudnnConvolutionStruct,
|
|
|
&cudnnCreateConvolutionDescriptor,
|
|
|
&cudnnDestroyConvolutionDescriptor> {
|
|
|
void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale , int groups, bool allow_tf32) {
|
|
|
cudnnDataType_t mathType = dataType;
|
|
|
if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
|
|
|
AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
|
|
|
CUDNN_CROSS_CORRELATION, mathType));
|
|
|
AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
|
|
|
|
|
|
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
|
|
|
if(dataType == CUDNN_DATA_HALF) {
|
|
|
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
} else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
|
|
|
AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
|
|
|
}
|
|
|
}
|
|
|
};
|
|
|
|
|
|
struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
|
|
|
: public Descriptor<
|
|
|
cudnnSpatialTransformerStruct,
|
|
|
&cudnnCreateSpatialTransformerDescriptor,
|
|
|
&cudnnDestroySpatialTransformerDescriptor> {
|
|
|
void set(cudnnDataType_t dataType, int dim, int* size) {
|
|
|
AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
|
|
|
}
|
|
|
};
|
|
|
|
|
|
|
|
|
struct TORCH_CUDA_CPP_API DropoutDescriptor
|
|
|
: public Descriptor<
|
|
|
cudnnDropoutStruct,
|
|
|
&cudnnCreateDropoutDescriptor,
|
|
|
&cudnnDestroyDropoutDescriptor> {
|
|
|
at::Tensor state;
|
|
|
|
|
|
|
|
|
|
|
|
void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
|
|
|
TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
|
|
|
size_t state_size = 0;
|
|
|
AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
|
|
|
AT_ASSERT(options.device().type() == kCUDA);
|
|
|
AT_ASSERT(options.dtype() == kByte);
|
|
|
state = at::empty({static_cast<int64_t>(state_size)}, options);
|
|
|
AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
|
|
|
}
|
|
|
|
|
|
|
|
|
void set(cudnnHandle_t handle, float dropout, const at::Tensor& state) {
|
|
|
TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
|
|
|
void *state_ptr = state.data_ptr();
|
|
|
size_t state_size = state.size(0);
|
|
|
|
|
|
AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 ));
|
|
|
}
|
|
|
|
|
|
|
|
|
void set_no_dropout(cudnnHandle_t handle) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 , nullptr, 0 , 0 ));
|
|
|
}
|
|
|
};
|
|
|
|
|
|
// Descriptor wrapper for cudnnRNNStruct, created with cudnnCreateRNNDescriptor
// and destroyed with cudnnDestroyRNNDescriptor. It also takes ownership of the
// DropoutDescriptor whose raw descriptor is handed to cuDNN.
struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                             cudnnRNNStruct,
                                             &cudnnCreateRNNDescriptor,
                                             &cudnnDestroyRNNDescriptor> {
  // Moved in by set().
  DropoutDescriptor dropout_desc_;

  // Configures the RNN descriptor.
  //
  // The parameter list depends on USE_CUDNN_RNN_V8_API: `input_size` and
  // `packed` exist only when it is defined.
  //
  //   handle        - cuDNN handle (used only by the pre-v8 calls).
  //   input_size    - (v8 only) forwarded to cudnnSetRNNDescriptor_v8.
  //   packed        - (v8 only) true selects CUDNN_RNN_PADDED_IO_DISABLED,
  //                   false selects CUDNN_RNN_PADDED_IO_ENABLED.
  //   hidden_size, num_layers, input_mode, bidirectional, mode, datatype, algo
  //                 - forwarded to the cuDNN setter.
  //   proj_size     - 0 means no projection: the pre-v8 path skips
  //                   cudnnSetRNNProjectionLayers, the v8 path passes
  //                   hidden_size instead.
  //   dropout_desc  - moved into dropout_desc_; its desc() is passed to cuDNN.
  //   input_type    - drives the math-type selection (and is forwarded in v8).
  //   allow_tf32    - when false, CUDNN_FMA_MATH is selected (see below).
  //
  // Math type: chosen only when the current device's major compute capability
  // is >= 7. CUDNN_DATA_HALF input selects CUDNN_TENSOR_OP_MATH; otherwise
  // CUDNN_FMA_MATH is selected when TF32 is disallowed; otherwise
  // CUDNN_DEFAULT_MATH.
  // NOTE(review): the pre-v8 branch additionally requires
  // input_type == CUDNN_DATA_FLOAT for CUDNN_FMA_MATH, the v8 branch does not
  // -- confirm whether that difference is intentional.
  void set(cudnnHandle_t handle,
#ifdef USE_CUDNN_RNN_V8_API
           int input_size,
           bool packed,
#endif
           int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);
#ifndef USE_CUDNN_RNN_V8_API
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
        handle,
        mut_desc(),
        hidden_size,
        num_layers,
        dropout_desc_.desc(),
        input_mode,
        bidirectional,
        mode,
        algo,
        datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
          handle,
          mut_desc(),
          proj_size,
          0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      auto math_type = CUDNN_DEFAULT_MATH;
      if (input_type == CUDNN_DATA_HALF) {
        math_type = CUDNN_TENSOR_OP_MATH;
      } else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        math_type = CUDNN_FMA_MATH;
      }
      // The status used to be discarded here; check it like every other cuDNN
      // call so a rejected math type is reported instead of silently ignored.
      AT_CUDNN_CHECK(cudnnSetRNNMatrixMathType(mut_desc(), math_type));
    }
#else
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    auto math_type = CUDNN_DEFAULT_MATH;
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        math_type = CUDNN_TENSOR_OP_MATH;
      } else if (!allow_tf32) {
        math_type = CUDNN_FMA_MATH;
      }
    }
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8(
        mut_desc(),
        algo,
        mode,
        CUDNN_RNN_DOUBLE_BIAS,
        bidirectional,
        input_mode,
        input_type,
        datatype,
        math_type,
        input_size,
        hidden_size,
        proj_size ? proj_size : hidden_size,
        num_layers,
        dropout_desc_.desc(),
        packed ? CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED));
#endif
  }
};
|
|
|
|
|
|
struct TORCH_CUDA_CPP_API CTCLossDescriptor
|
|
|
: public Descriptor<
|
|
|
cudnnCTCLossStruct,
|
|
|
&cudnnCreateCTCLossDescriptor,
|
|
|
&cudnnDestroyCTCLossDescriptor> {
|
|
|
void set(cudnnDataType_t datatype) {
|
|
|
AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
|
|
|
}
|
|
|
void setEx(
|
|
|
cudnnDataType_t datatype,
|
|
|
cudnnLossNormalizationMode_t normMode,
|
|
|
cudnnNanPropagation_t gradMode) {
|
|
|
AT_CUDNN_CHECK(
|
|
|
cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
|
|
|
}
|
|
|
void set_v8_v9(
|
|
|
cudnnDataType_t datatype,
|
|
|
cudnnLossNormalizationMode_t normMode,
|
|
|
cudnnNanPropagation_t gradMode,
|
|
|
int maxLabelLength) {
|
|
|
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000
|
|
|
auto gradModev9 = CUDNN_CTC_ZERO_OOB_GRADIENTS;
|
|
|
if (gradMode == cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN) {
|
|
|
gradModev9 = CUDNN_CTC_SKIP_OOB_GRADIENTS;
|
|
|
}
|
|
|
AT_CUDNN_CHECK(
|
|
|
cudnnSetCTCLossDescriptor_v9(mut_desc(), datatype, normMode, gradModev9, maxLabelLength));
|
|
|
#else
|
|
|
AT_CUDNN_CHECK(
|
|
|
cudnnSetCTCLossDescriptor_v8(mut_desc(), datatype, normMode, gradMode, maxLabelLength));
|
|
|
#endif
|
|
|
}
|
|
|
|
|
|
};
|
|
|
|
|
|
struct TORCH_CUDA_CPP_API ActivationDescriptor
|
|
|
: public Descriptor<
|
|
|
cudnnActivationStruct,
|
|
|
&cudnnCreateActivationDescriptor,
|
|
|
&cudnnDestroyActivationDescriptor> {
|
|
|
void set(cudnnActivationMode_t mode) {
|
|
|
AT_ASSERT(
|
|
|
mode == CUDNN_ACTIVATION_RELU,
|
|
|
"TODO: support more cuDNN activation modes");
|
|
|
AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
|
|
|
mut_desc(),
|
|
|
mode,
|
|
|
cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
|
|
|
std::numeric_limits<double>::max()));
|
|
|
}
|
|
|
};
|
|
|
|
|
|
union Constant
|
|
|
{
|
|
|
float f;
|
|
|
double d;
|
|
|
Constant(cudnnDataType_t dataType, double value) {
|
|
|
if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
|
|
|
f = static_cast<float>(value);
|
|
|
} else {
|
|
|
d = value;
|
|
|
}
|
|
|
}
|
|
|
};
|
|
|
|
|
|
}
|
|
|
|