File size: 15,452 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
#pragma once

#include <string>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8907
#define USE_CUDNN_RNN_V8_API
#endif

namespace at::native {

std::string cudnnTypeToString(cudnnDataType_t dtype);

// TODO: Add constructors for all of the descriptors

// Returns the size in bytes of one element of the given cuDNN data type:
// 2 for half and bfloat16, 4 for float, and 8 for every other type.
inline int dataSize(cudnnDataType_t dataType)
{
  if (dataType == CUDNN_DATA_BFLOAT16 || dataType == CUDNN_DATA_HALF) {
    return 2;
  }
  if (dataType == CUDNN_DATA_FLOAT) {
    return 4;
  }
  return 8;
}

// The stride for a size-1 dimension is not uniquely determined; in
// fact, it can be anything you want, because the fact that the
// tensor is size 1 at this dimension means that you will never actually
// try advancing your pointer by this stride.
//
// However, CuDNN has a much more stringent requirement on strides:
// if you are passing a contiguous input, it better be the case
// that the stride for dim i is the product of the sizes of dims
// i+1 to the end.  This stride is indeed uniquely determined.  This
// function modifies 'stride' in place so this invariant holds.
// Rewrites, in place, the stride of every size-1 dimension so that it equals
// the product of the sizes of the non-size-1 dimensions visited before it.
//
// Dimensions are visited from innermost to outermost:
//   - nhwc == false: dim-1, dim-2, ..., 2, 1, 0
//   - nhwc == true:  1, dim-1, dim-2, ..., 2, 0
//
// Parameters:
//   dim    - number of entries in `size` and `stride`.
//   size   - sizes of each dimension (read only).
//   stride - strides of each dimension; only entries whose size is 1 are
//            overwritten.
//   nhwc   - selects the visiting order described above.
//
// A non-positive `dim` is a no-op, and `dim == 1` only visits dimension 0,
// so the function never touches index 1 of an array that does not have it.
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  if (dim <= 0) {
    return;
  }

  // Running product of the sizes of the dimensions visited so far.
  int64_t z = 1;
  const auto visit = [&](int d) {
    if (size[d] == 1) {
      stride[d] = static_cast<T>(z);
    } else {
      z *= size[d];
    }
  };

  if (dim == 1) {
    // There is no dimension 1 to order; only dimension 0 exists.
    visit(0);
    return;
  }

  if (nhwc) {
    visit(1);
  }
  for (int d = dim - 1; d > 1; --d) {
    visit(d);
  }
  if (!nhwc) {
    visit(1);
  }
  visit(0);
}

// Deleter functor, used with std::unique_ptr, that releases a cuDNN
// descriptor by calling `dtor` on it and checking the returned status with
// AT_CUDNN_CHECK.  Null pointers are ignored.
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x == nullptr) {
      return;
    }
    AT_CUDNN_CHECK(dtor(x));
  }
};

// A generic class for wrapping cuDNN descriptor types.  All you need
// is to give the underlying type the Descriptor_t points to (usually,
// if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct),
// the constructor and the destructor.  Subclasses are responsible
// for defining a set() function to actually set the descriptor.
//
// Descriptors default construct to a nullptr, and have a descriptor
// initialized the first time you call set() or any other initializing
// function.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // TODO: Figure out why const-correctness doesn't work here

  // Accessors for reading the wrapped descriptor pointer; most client
  // code should use these.  They never create the descriptor, so they
  // return nullptr until an initializing call has been made.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // Accessor for callers that are about to modify what the descriptor
  // points to (e.g., via cudnnSetFooDescriptor).  The descriptor is
  // created first if it does not exist yet.  Code in this file uses
  // this function.
  T* mut_desc() {
    init();
    return desc_.get();
  }

 protected:
  // Creates the descriptor with `ctor` unless one is already held.
  void init() {
    if (desc_) {
      return;
    }
    T* created = nullptr;
    AT_CUDNN_CHECK(ctor(&created));
    desc_.reset(created);
  }

 private:
  // Owns the descriptor; `dtor` is invoked on it through DescriptorDeleter.
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};

// Wrapper around a cuDNN RNN data descriptor.
class TORCH_CUDA_CPP_API RNNDataDescriptor
    : public Descriptor<
          cudnnRNNDataStruct,
          &cudnnCreateRNNDataDescriptor,
          &cudnnDestroyRNNDataDescriptor> {
 public:
  // Tensor-based overload; declared here and defined outside this header.
  void set(const at::Tensor &t, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray);

 private:
  // Forwards the arguments to cudnnSetRNNDataDescriptor, creating the
  // descriptor if needed and passing a null padding-fill pointer.
  void set(cudnnDataType_t dataType, cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, int vectorSize, const int* seqLengthArray) {
    auto* const rnn_data_desc = mut_desc();
    AT_CUDNN_CHECK(cudnnSetRNNDataDescriptor(
        rnn_data_desc,
        dataType,
        layout,
        maxSeqLength,
        batchSize,
        vectorSize,
        seqLengthArray,
        /*paddingFill=*/nullptr));
  }
};

// Wrapper around a cuDNN tensor descriptor.
class TORCH_CUDA_CPP_API TensorDescriptor
    : public Descriptor<
          cudnnTensorStruct,
          &cudnnCreateTensorDescriptor,
          &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;

  // Constructs the descriptor and immediately sets it from `t`.
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }

  // Note [CuDNN broadcast padding]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // pad specifies the minimum dimensionality of the tensor descriptor
  // we produce (it doesn't have anything to do with, e.g., convolution
  // padding).  If 't' is lower-dimensional than 'pad', the remaining
  // dimensions (on the right) are padded with ones.  This doesn't
  // affect the underlying data layout.  This is particularly useful for
  // dealing with a peculiarity of the CuDNN API, which is that
  // broadcasting in CuDNN is done in two steps: first, the client code
  // is expected to pad out the dimensions of the input tensors to match
  // the dimensionality of the target broadcast, and then second, CuDNN
  // takes care of actually broadcasting size 1 dimensions.

  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  // Lowest-level setter: normalizes the strides of size-1 dimensions on a
  // private copy (the caller's `stride` array is not modified) and hands
  // the result to cudnnSetTensorNdDescriptor.
  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    std::vector<int> fixed_strides(stride, stride + dim);
    int* const fixed = fixed_strides.data();
    fixSizeOneDimStride<int>(dim, size, fixed, nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, fixed));
  }
};

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

// Wrapper around a cuDNN filter descriptor.
class TORCH_CUDA_CPP_API FilterDescriptor
    : public Descriptor<
          cudnnFilterStruct,
          &cudnnCreateFilterDescriptor,
          &cudnnDestroyFilterDescriptor> {
 public:
  // Convenience overload: same as the memory-format overload with
  // at::MemoryFormat::Contiguous.
  void set(const at::Tensor &t, int64_t pad = 0) {
    const at::MemoryFormat default_format = at::MemoryFormat::Contiguous;
    set(t, default_format, pad);
  }

  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();

 private:
  // Forwards to cudnnSetFilterNdDescriptor, creating the descriptor if
  // needed.
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    auto* const filter_desc = mut_desc();
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(filter_desc, dataType, filter_format, dim, size));
  }
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);

// Wrapper around a cuDNN convolution descriptor.
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
    : public Descriptor<
          cudnnConvolutionStruct,
          &cudnnCreateConvolutionDescriptor,
          &cudnnDestroyConvolutionDescriptor> {
  // Configures an N-d cross-correlation with the given padding, stride,
  // dilation (`upscale`) and group count, then selects the math type:
  //   - half data:                   CUDNN_TENSOR_OP_MATH
  //   - float data with !allow_tf32: CUDNN_FMA_MATH
  //   - anything else:               CUDNN_DEFAULT_MATH
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool allow_tf32) {
    const bool is_half = (dataType == CUDNN_DATA_HALF);
    // For half data, CUDNN_DATA_FLOAT is passed as the descriptor's data
    // type argument instead of the tensor data type.
    const cudnnDataType_t mathType = is_half ? CUDNN_DATA_FLOAT : dataType;

    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(
        mut_desc(), dim, pad, stride, upscale, CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));

    // See Note [behavior of cudnnFind and cudnnGet]
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (is_half) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (!allow_tf32 && dataType == CUDNN_DATA_FLOAT) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
    }
  }
};

// Wrapper around a cuDNN spatial transformer descriptor.
struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
    : public Descriptor<
          cudnnSpatialTransformerStruct,
          &cudnnCreateSpatialTransformerDescriptor,
          &cudnnDestroySpatialTransformerDescriptor> {
  // Sets the descriptor for the given data type and dimensions, always
  // using the bilinear sampler.
  void set(cudnnDataType_t dataType, int dim, int* size) {
    const auto sampler = CUDNN_SAMPLER_BILINEAR;
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(
        mut_desc(), sampler, dataType, dim, size));
  }
};

// Wrapper around a cuDNN dropout descriptor, together with the tensor that
// backs the RNG state allocated by initialize_rng().
// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_CUDA_CPP_API DropoutDescriptor
    : public Descriptor<
          cudnnDropoutStruct,
          &cudnnCreateDropoutDescriptor,
          &cudnnDestroyDropoutDescriptor> {
  // RNG state buffer; assigned by initialize_rng().
  at::Tensor state;

  // Allocates a fresh RNG state buffer of the size cuDNN asks for and
  // initializes the descriptor with it.  `options` must describe a CUDA
  // byte tensor, and `dropout` must be greater than zero.
  // WARNING: This function is very expensive, avoid calling this function!
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");

    size_t required_bytes = 0;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &required_bytes));

    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);

    state = at::empty({static_cast<int64_t>(required_bytes)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(
        mut_desc(), handle, dropout, state.data_ptr(), required_bytes, seed));
  }

  // Restores the descriptor from a dropout probability and an existing
  // RNG state tensor.  Note that the `state` parameter shadows the member
  // of the same name; the member is not assigned here.
  void set(cudnnHandle_t handle, float dropout, const at::Tensor& state) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");

    const size_t state_bytes = state.size(0);
    void* const state_buffer = state.data_ptr();
    // NB: The seed doesn't actually matter, so we give a dummy value
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(
        mut_desc(), handle, dropout, state_buffer, state_bytes, 0 /* seed */));
  }

  // Sets the descriptor to the "no dropout" configuration.
  void set_no_dropout(cudnnHandle_t handle) {
    // NB: seed doesn't matter when dropout = 0, because no random number
    // initialization actually takes place when there is no dropout.
    // NB: Empirically, cudnnSetDropoutDescriptor is cheap when
    // dropout == 0
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(
        mut_desc(),
        handle,
        0 /* dropout */,
        nullptr,
        0 /* state_size */,
        0 /* seed */));
  }
};

// Wrapper around a cuDNN RNN descriptor.  It also stores the dropout
// descriptor that is handed to cuDNN when the RNN descriptor is set.
struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                             cudnnRNNStruct,
                                             &cudnnCreateRNNDescriptor,
                                             &cudnnDestroyRNNDescriptor> {
  // Dropout descriptor moved in by set(); its raw pointer is passed to cuDNN.
  DropoutDescriptor dropout_desc_;

  // Configures the RNN descriptor.
  //
  // The `input_size` and `packed` parameters are only part of the signature
  // when USE_CUDNN_RNN_V8_API is defined.  `dropout_desc` is moved into
  // dropout_desc_.
  //
  // Without USE_CUDNN_RNN_V8_API: calls cudnnSetRNNDescriptor_v6, then
  // cudnnSetRNNProjectionLayers when proj_size != 0, then (on devices with
  // compute capability major >= 7) cudnnSetRNNMatrixMathType.
  //
  // With USE_CUDNN_RNN_V8_API: picks the math type up front and makes a
  // single cudnnSetRNNDescriptor_v8 call.
  void set(cudnnHandle_t handle,

#ifdef USE_CUDNN_RNN_V8_API

       int input_size,

       bool packed,

#endif

       int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,

           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,

           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    // Store the dropout descriptor in this object before passing its raw
    // pointer to cuDNN below.
    dropout_desc_ = std::move(dropout_desc);
#ifndef USE_CUDNN_RNN_V8_API
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
          handle,
          mut_desc(),
          hidden_size,
          num_layers,
          dropout_desc_.desc(),
          input_mode,
          bidirectional,
          mode,
          algo,
          datatype));
    // Projection layers are only configured when a projection size is given.
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
            handle,
            /*rnnDesc=*/mut_desc(),
            /*recProjSize=*/proj_size,
            /*outProjSize=*/0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    // NOTE(review): the return values of the cudnnSetRNNMatrixMathType calls
    // below are not wrapped in AT_CUDNN_CHECK, unlike the other cuDNN calls
    // in this file -- confirm whether ignoring failures here is intended.
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      }
      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      }
      else {
        // Technically, as the default it's not necessary to explicitly
        // set this.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
#else
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    // Default math unless the device has compute capability major >= 7.
    // NOTE(review): here CUDNN_FMA_MATH is chosen for every non-half
    // input_type when !allow_tf32, whereas the non-v8 branch above only
    // does so for CUDNN_DATA_FLOAT -- confirm the difference is intended.
    auto math_type = CUDNN_DEFAULT_MATH;
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        math_type = CUDNN_TENSOR_OP_MATH;
      } else if (!allow_tf32) {
        math_type = CUDNN_FMA_MATH;
      }
    }
    // A proj_size of 0 means "no projection": hidden_size is passed in its
    // place.  `packed` selects CUDNN_RNN_PADDED_IO_DISABLED; otherwise padded
    // IO is enabled.
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v8(
          mut_desc(),
          algo,
          mode,
          CUDNN_RNN_DOUBLE_BIAS,
          bidirectional,
          input_mode,
          input_type,
          datatype,
          math_type,
          input_size,
          hidden_size,
          proj_size ? proj_size : hidden_size,
          num_layers,
          dropout_desc_.desc(),
          packed ? CUDNN_RNN_PADDED_IO_DISABLED : CUDNN_RNN_PADDED_IO_ENABLED));
#endif
  }
};

// Wrapper around a cuDNN CTC loss descriptor.
struct TORCH_CUDA_CPP_API CTCLossDescriptor
    : public Descriptor<
          cudnnCTCLossStruct,
          &cudnnCreateCTCLossDescriptor,
          &cudnnDestroyCTCLossDescriptor> {
  // Sets only the data type via cudnnSetCTCLossDescriptor.
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }

  // Sets data type, normalization mode and NaN-propagation mode via
  // cudnnSetCTCLossDescriptorEx.
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }

  // Sets the descriptor through the _v9 entry point when CUDNN_VERSION is
  // at least 90000 and through the _v8 entry point otherwise.  For _v9,
  // CUDNN_PROPAGATE_NAN maps to CUDNN_CTC_SKIP_OOB_GRADIENTS and any other
  // `gradMode` maps to CUDNN_CTC_ZERO_OOB_GRADIENTS.
  void set_v8_v9(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode,
      int maxLabelLength) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 90000
    const bool propagate_nan =
        (gradMode == cudnnNanPropagation_t::CUDNN_PROPAGATE_NAN);
    const auto gradModev9 = propagate_nan ? CUDNN_CTC_SKIP_OOB_GRADIENTS
                                          : CUDNN_CTC_ZERO_OOB_GRADIENTS;
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor_v9(
        mut_desc(), datatype, normMode, gradModev9, maxLabelLength));
#else
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor_v8(
        mut_desc(), datatype, normMode, gradMode, maxLabelLength));
#endif
  }
};

// Wrapper around a cuDNN activation descriptor.
struct TORCH_CUDA_CPP_API ActivationDescriptor
    : public Descriptor<
          cudnnActivationStruct,
          &cudnnCreateActivationDescriptor,
          &cudnnDestroyActivationDescriptor> {
  // Sets the activation mode.  Only CUDNN_ACTIVATION_RELU is accepted;
  // NaNs are not propagated and the coefficient argument is the largest
  // finite double.
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");

    constexpr double coef = std::numeric_limits<double>::max();
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        coef));
  }
};

// A scalar stored either as a float or as a double, chosen from the cuDNN
// data type it accompanies.  Presumably used for the scaling factors passed
// by pointer to cuDNN routines -- confirm against callers.
//
// Half, bfloat16 and float data types store the value in `f`; every other
// data type stores it in `d`.  bfloat16 is grouped with half (as dataSize()
// in this file already does); storing a double for it would make a consumer
// that reads a float see the low bytes of the double instead of the value.
union Constant
{
  float f;
  double d;
  // dataType: cuDNN data type the constant is paired with.
  // value:    the scalar value, narrowed to float for the float-backed types.
  Constant(cudnnDataType_t dataType, double value) {
    const bool stored_as_float = dataType == CUDNN_DATA_HALF ||
        dataType == CUDNN_DATA_BFLOAT16 || dataType == CUDNN_DATA_FLOAT;
    if (stored_as_float) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};

} // namespace