| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| #if !defined(CUDNN_ADV_H_) |
| #define CUDNN_ADV_H_ |
|
|
| #include <stdint.h> |
|
|
| #include "cudnn_version.h" |
| #include "cudnn_ops.h" |
|
|
| /* Version of this header (9.10.2); compared against CUDNN_MAJOR/CUDNN_MINOR/CUDNN_PATCHLEVEL by the #if check that follows. */ |
| #define CUDNN_ADV_MAJOR 9 |
| #define CUDNN_ADV_MINOR 10 |
| #define CUDNN_ADV_PATCH 2 |
|
|
| #if (CUDNN_ADV_MAJOR != CUDNN_MAJOR) || (CUDNN_ADV_MINOR != CUDNN_MINOR) || (CUDNN_ADV_PATCH != CUDNN_PATCHLEVEL) |
| #error Version mismatch in cuDNN ADV INFER!!! |
| #endif |
|
|
| #if defined(__cplusplus) |
| extern "C" { |
| #endif |
|
|
| |
|
|
| typedef enum { /* RNN algorithm selector; the type of the `algo` argument of cudnnSetRNNDescriptor_v8(). */ |
| CUDNN_RNN_ALGO_STANDARD = 0, |
| CUDNN_RNN_ALGO_PERSIST_STATIC = 1, |
| CUDNN_RNN_ALGO_PERSIST_DYNAMIC = 2, /* NOTE(review): presumably the algorithm that requires cudnnBuildRNNDynamic() - confirm */ |
| CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H = 3, |
| CUDNN_RNN_ALGO_COUNT = 4, /* equals the number of algorithm enumerators above (values 0..3) */ |
| } cudnnRNNAlgo_t; |
|
|
| typedef enum { /* Forward-pass mode; the type of the `fwdMode` argument of cudnnGetRNNTempSpaceSizes() and cudnnRNNForward(). */ |
| CUDNN_FWD_MODE_INFERENCE = 0, |
| CUDNN_FWD_MODE_TRAINING = 1, |
| } cudnnForwardMode_t; |
|
|
| typedef enum { /* RNN cell type; the type of the `cellMode` argument of cudnnSetRNNDescriptor_v8(). */ |
| CUDNN_RNN_RELU = 0, /* NOTE(review): presumably a basic RNN cell with ReLU activation - confirm */ |
| CUDNN_RNN_TANH = 1, /* NOTE(review): presumably a basic RNN cell with tanh activation - confirm */ |
| CUDNN_LSTM = 2, |
| CUDNN_GRU = 3, |
| } cudnnRNNMode_t; |
|
|
| typedef enum { /* Bias configuration; the type of the `biasMode` argument of cudnnSetRNNDescriptor_v8(). */ |
| CUDNN_RNN_NO_BIAS = 0, |
| CUDNN_RNN_SINGLE_INP_BIAS = 1, /* NOTE(review): INP presumably means the input-side bias only - confirm */ |
| CUDNN_RNN_DOUBLE_BIAS = 2, |
| CUDNN_RNN_SINGLE_REC_BIAS = 3 /* NOTE(review): REC presumably means the recurrent-side bias only - confirm */ |
| } cudnnRNNBiasMode_t; |
|
|
| typedef enum { /* Recurrence direction; the type of the `dirMode` argument of cudnnSetRNNDescriptor_v8(). */ |
| CUDNN_UNIDIRECTIONAL = 0, |
| CUDNN_BIDIRECTIONAL = 1, |
| } cudnnDirectionMode_t; |
|
|
| typedef enum { /* First-layer input handling; the type of the `inputMode` argument of cudnnSetRNNDescriptor_v8(). */ |
| CUDNN_LINEAR_INPUT = 0, |
| CUDNN_SKIP_INPUT = 1, /* NOTE(review): presumably bypasses the input projection - confirm against the cuDNN reference */ |
| } cudnnRNNInputMode_t; |
|
|
| typedef enum { /* Clipping mode; the type of the `clipMode` argument of the cudnnRNNSetClip and cudnnRNNGetClip functions (_v8 and _v9). */ |
| CUDNN_RNN_CLIP_NONE = 0, |
| CUDNN_RNN_CLIP_MINMAX = 1, /* NOTE(review): presumably clips to the [lclip, rclip] range given to the setter - confirm */ |
| } cudnnRNNClipMode_t; |
|
|
| typedef enum { /* Memory layout of RNN sequence data; the type of the `layout` argument of cudnnSetRNNDataDescriptor(). */ |
| CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED = 0, |
| CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED = 1, |
| CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED = 2, |
| } cudnnRNNDataLayout_t; |
|
|
| /* Bit-flag values (bit 0); NOTE(review): presumably passed in the uint32_t `auxFlags` argument of cudnnSetRNNDescriptor_v8() - confirm. */ |
| #define CUDNN_RNN_PADDED_IO_DISABLED 0 |
| #define CUDNN_RNN_PADDED_IO_ENABLED (1U << 0) |
|
|
| struct cudnnRNNStruct; /* forward declaration only: the layout is not visible to users of this header */ |
| typedef struct cudnnRNNStruct *cudnnRNNDescriptor_t; /* opaque RNN descriptor handle (a pointer type) */ |
|
|
| struct cudnnRNNDataStruct; /* forward declaration only: the layout is not visible to users of this header */ |
| typedef struct cudnnRNNDataStruct *cudnnRNNDataDescriptor_t; /* opaque RNN data descriptor handle (a pointer type) */ |
|
|
| cudnnStatus_t CUDNNWINAPI /* NOTE(review): presumably creates an RNN descriptor and stores the handle in *rnnDesc; pair with cudnnDestroyRNNDescriptor() - confirm */ |
| cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Takes the handle by value; NOTE(review): presumably releases a descriptor obtained from cudnnCreateRNNDescriptor() - confirm */ |
| cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc); |
|
|
| /* cudnnSetRNNDescriptor_v8: sets the configuration of an RNN descriptor handle (see cudnnCreateRNNDescriptor). */ |
| /* Every setting is passed by value; cudnnGetRNNDescriptor_v8() takes the same list, in the same order, as pointers. */ |
| /* dataType and mathPrec are both cudnnDataType_t; NOTE(review): presumably storage type vs. compute precision - confirm. */ |
| /* inputSize, hiddenSize, projSize and numLayers are int32_t; their valid ranges are not stated in this header. */ |
| /* dropoutDesc is a cudnnDropoutDescriptor_t, a type declared outside this header. */ |
| /* auxFlags is a uint32_t; NOTE(review): presumably a bit mask of CUDNN_RNN_PADDED_IO_* values - confirm. */ |
|
|
| cudnnStatus_t CUDNNWINAPI |
| cudnnSetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc, |
| cudnnRNNAlgo_t algo, |
| cudnnRNNMode_t cellMode, |
| cudnnRNNBiasMode_t biasMode, |
| cudnnDirectionMode_t dirMode, |
| cudnnRNNInputMode_t inputMode, |
| cudnnDataType_t dataType, |
| cudnnDataType_t mathPrec, |
| cudnnMathType_t mathType, |
| int32_t inputSize, |
| int32_t hiddenSize, |
| int32_t projSize, |
| int32_t numLayers, |
| cudnnDropoutDescriptor_t dropoutDesc, |
| uint32_t auxFlags); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Getter counterpart of cudnnSetRNNDescriptor_v8(): same settings in the same order, each taken as a pointer. */ |
| cudnnGetRNNDescriptor_v8(cudnnRNNDescriptor_t rnnDesc, /* descriptor handle to query, passed by value */ |
| cudnnRNNAlgo_t *algo, /* NOTE(review): whether NULL is accepted for unwanted outputs is not stated in this header - confirm */ |
| cudnnRNNMode_t *cellMode, |
| cudnnRNNBiasMode_t *biasMode, |
| cudnnDirectionMode_t *dirMode, |
| cudnnRNNInputMode_t *inputMode, |
| cudnnDataType_t *dataType, |
| cudnnDataType_t *mathPrec, |
| cudnnMathType_t *mathType, |
| int32_t *inputSize, |
| int32_t *hiddenSize, |
| int32_t *projSize, |
| int32_t *numLayers, |
| cudnnDropoutDescriptor_t *dropoutDesc, |
| uint32_t *auxFlags); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED; cudnnRNNSetClip_v9() declares the same parameters without clipNanOpt. */ |
| cudnnRNNSetClip_v8(cudnnRNNDescriptor_t rnnDesc, |
| cudnnRNNClipMode_t clipMode, |
| cudnnNanPropagation_t clipNanOpt, /* the only parameter that the _v9 variant drops */ |
| double lclip, /* NOTE(review): presumably the lower clipping bound - confirm */ |
| double rclip); /* NOTE(review): presumably the upper clipping bound - confirm */ |
|
|
| cudnnStatus_t CUDNNWINAPI /* Non-deprecated clip setter: clip mode plus two double bounds, with no cudnnNanPropagation_t argument (unlike the _v8 variant). */ |
| cudnnRNNSetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t clipMode, double lclip, double rclip); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED; cudnnRNNGetClip_v9() declares the same parameters without clipNanOpt. */ |
| cudnnRNNGetClip_v8(cudnnRNNDescriptor_t rnnDesc, |
| cudnnRNNClipMode_t *clipMode, |
| cudnnNanPropagation_t *clipNanOpt, /* the only parameter that the _v9 variant drops */ |
| double *lclip, |
| double *rclip); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Non-deprecated clip getter; pointer counterpart of cudnnRNNSetClip_v9() (clipMode, lclip, rclip). */ |
| cudnnRNNGetClip_v9(cudnnRNNDescriptor_t rnnDesc, cudnnRNNClipMode_t *clipMode, double *lclip, double *rclip); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Takes a library handle, an RNN descriptor and an int mini-batch size. NOTE(review): presumably prepares CUDNN_RNN_ALGO_PERSIST_DYNAMIC for that batch size - confirm */ |
| cudnnBuildRNNDynamic(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int miniBatch); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Reports two buffer sizes through size_t pointers for a given RNN descriptor, forward mode and input data descriptor. */ |
| cudnnGetRNNTempSpaceSizes(cudnnHandle_t handle, |
| cudnnRNNDescriptor_t rnnDesc, |
| cudnnForwardMode_t fwdMode, |
| cudnnRNNDataDescriptor_t xDesc, |
| size_t *workSpaceSize, /* NOTE(review): presumably bytes, matching the workSpaceSize argument of cudnnRNNForward() - confirm */ |
| size_t *reserveSpaceSize); /* NOTE(review): presumably bytes, matching the reserveSpaceSize argument of cudnnRNNForward() - confirm */ |
|
|
| cudnnStatus_t CUDNNWINAPI /* Reports the weight-space size through *weightSpaceSize. NOTE(review): presumably in bytes, for use as the weightSpaceSize argument elsewhere in this API - confirm */ |
| cudnnGetRNNWeightSpaceSize(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, size_t *weightSpaceSize); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Looks up one weight matrix / bias pair inside a weight-space buffer and returns their descriptors and addresses. */ |
| cudnnGetRNNWeightParams(cudnnHandle_t handle, |
| cudnnRNNDescriptor_t rnnDesc, |
| int32_t pseudoLayer, /* NOTE(review): numbering scheme of pseudo-layers is not stated in this header - confirm */ |
| size_t weightSpaceSize, |
| const void *weightSpace, /* const pointee: the buffer is only read through this pointer */ |
| int32_t linLayerID, /* NOTE(review): meaning of the ID depends on the cell mode - confirm against the cuDNN reference */ |
| cudnnTensorDescriptor_t mDesc, /* NOTE(review): presumably a caller-created descriptor that is filled for the matrix - confirm */ |
| void **mAddr, /* receives a pointer; NOTE(review): presumably an address inside weightSpace - confirm */ |
| cudnnTensorDescriptor_t bDesc, /* NOTE(review): presumably a caller-created descriptor that is filled for the bias - confirm */ |
| void **bAddr); /* receives a pointer; NOTE(review): presumably an address inside weightSpace - confirm */ |
|
|
| cudnnStatus_t CUDNNWINAPI /* NOTE(review): presumably creates an RNN data descriptor and stores the handle in *rnnDataDesc; pair with cudnnDestroyRNNDataDescriptor() - confirm */ |
| cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Takes the handle by value; NOTE(review): presumably releases a descriptor obtained from cudnnCreateRNNDataDescriptor() - confirm */ |
| cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Sets data type, layout, dimensions, per-sequence lengths and padding fill of an RNN data descriptor. */ |
| cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc, |
| cudnnDataType_t dataType, |
| cudnnRNNDataLayout_t layout, |
| int maxSeqLength, |
| int batchSize, |
| int vectorSize, |
| const int seqLengthArray[], /* read-only array; NOTE(review): presumably batchSize entries - confirm */ |
| void *paddingFill); /* non-const void* even in this setter; NOTE(review): pointee type presumably follows dataType, NULL handling not stated - confirm */ |
|
|
| cudnnStatus_t CUDNNWINAPI /* Getter counterpart of cudnnSetRNNDataDescriptor(); adds arrayLengthRequested for the sequence-length array. */ |
| cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc, |
| cudnnDataType_t *dataType, |
| cudnnRNNDataLayout_t *layout, |
| int *maxSeqLength, |
| int *batchSize, |
| int *vectorSize, |
| int arrayLengthRequested, /* NOTE(review): presumably the maximum number of entries written to seqLengthArray - confirm */ |
| int seqLengthArray[], /* writable (non-const) array supplied by the caller */ |
| void *paddingFill); |
|
|
| cudnnStatus_t CUDNNWINAPI /* RNN forward pass; fwdMode selects CUDNN_FWD_MODE_INFERENCE or CUDNN_FWD_MODE_TRAINING. const void* arguments are read-only, plain void* arguments are writable. */ |
| cudnnRNNForward(cudnnHandle_t handle, |
| cudnnRNNDescriptor_t rnnDesc, |
| cudnnForwardMode_t fwdMode, |
| const int32_t devSeqLengths[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| cudnnRNNDataDescriptor_t xDesc, |
| const void *x, |
| cudnnRNNDataDescriptor_t yDesc, |
| void *y, |
| cudnnTensorDescriptor_t hDesc, /* one descriptor shared by hx and hy */ |
| const void *hx, |
| void *hy, |
| cudnnTensorDescriptor_t cDesc, /* one descriptor shared by cx and cy */ |
| const void *cx, |
| void *cy, |
| size_t weightSpaceSize, /* see cudnnGetRNNWeightSpaceSize() */ |
| const void *weightSpace, |
| size_t workSpaceSize, /* see cudnnGetRNNTempSpaceSizes() */ |
| void *workSpace, |
| size_t reserveSpaceSize, /* see cudnnGetRNNTempSpaceSizes() */ |
| void *reserveSpace); /* NOTE(review): presumably only needed in training mode and reused by the backward calls - confirm */ |
|
|
| |
|
|
| typedef enum { /* Axis identifiers for sequence data; element type of the `axes` arrays of cudnnSetSeqDataDescriptor() and cudnnGetSeqDataDescriptor(). */ |
| CUDNN_SEQDATA_TIME_DIM = 0, |
| CUDNN_SEQDATA_BATCH_DIM = 1, |
| CUDNN_SEQDATA_BEAM_DIM = 2, |
| CUDNN_SEQDATA_VECT_DIM = 3 |
| } cudnnSeqDataAxis_t; |
|
|
| struct cudnnSeqDataStruct; /* forward declaration only: the layout is not visible to users of this header */ |
| typedef struct cudnnSeqDataStruct *cudnnSeqDataDescriptor_t CUDNN_DEPRECATED; /* opaque handle type; the typedef itself carries CUDNN_DEPRECATED */ |
|
|
| #define CUDNN_SEQDATA_DIM_COUNT 4 /* equals the number of cudnnSeqDataAxis_t enumerators (values 0..3) */ |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. NOTE(review): presumably creates a sequence data descriptor and stores the handle in *seqDataDesc - confirm */ |
| cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Takes the handle by value; NOTE(review): presumably releases a descriptor from cudnnCreateSeqDataDescriptor() - confirm */ |
| cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Sets data type, dimensions, axis order, sequence lengths and padding fill of a sequence data descriptor. */ |
| cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc, |
| cudnnDataType_t dataType, |
| int nbDims, /* NOTE(review): presumably CUDNN_SEQDATA_DIM_COUNT - confirm */ |
| const int dimA[], /* NOTE(review): presumably indexed by cudnnSeqDataAxis_t value - confirm */ |
| const cudnnSeqDataAxis_t axes[], /* NOTE(review): presumably the memory ordering of the axes - confirm */ |
| size_t seqLengthArraySize, /* element count of seqLengthArray, as a size_t */ |
| const int seqLengthArray[], |
| void *paddingFill); /* non-const void* even in this setter; NULL handling is not stated in this header */ |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Getter counterpart of cudnnSetSeqDataDescriptor(), with caller-supplied capacities for the output arrays. */ |
| cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc, /* const applies to the handle value itself (pointer typedef), not to the pointee */ |
| cudnnDataType_t *dataType, |
| int *nbDims, |
| int nbDimsRequested, /* NOTE(review): presumably the capacity of dimA and axes - confirm */ |
| int dimA[], |
| cudnnSeqDataAxis_t axes[], |
| size_t *seqLengthArraySize, |
| size_t seqLengthSizeRequested, /* NOTE(review): presumably the capacity of seqLengthArray - confirm */ |
| int seqLengthArray[], |
| void *paddingFill); |
|
|
| |
|
|
| |
| |
| |
| /* Attention-mode bit flags; NOTE(review): presumably OR-ed into the unsigned `attnMode` argument of cudnnSetAttnDescriptor() - confirm. */ |
| /* Bit 0 selects the query map (clear = ALL_TO_ONE, set = ONE_TO_ONE); bit 1 controls projection biases (clear = disabled, set = enabled). */ |
| #define CUDNN_ATTN_QUERYMAP_ALL_TO_ONE 0 |
| #define CUDNN_ATTN_QUERYMAP_ONE_TO_ONE (1U << 0) /* multiple Q-s map to multiple (K,V) sets when beam size > 1 */ |
| #define CUDNN_ATTN_DISABLE_PROJ_BIASES 0 |
| #define CUDNN_ATTN_ENABLE_PROJ_BIASES (1U << 1) |
|
|
| struct cudnnAttnStruct; /* forward declaration only: the layout is not visible to users of this header */ |
| typedef struct cudnnAttnStruct *cudnnAttnDescriptor_t CUDNN_DEPRECATED; /* opaque handle type; the typedef itself carries CUDNN_DEPRECATED */ |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. NOTE(review): presumably creates an attention descriptor and stores the handle in *attnDesc - confirm */ |
| cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Takes the handle by value; NOTE(review): presumably releases a descriptor from cudnnCreateAttnDescriptor() - confirm */ |
| cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Sets the configuration of a multi-head attention descriptor; all settings are passed by value. */ |
| cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc, |
| unsigned attnMode, /* NOTE(review): presumably a mask of CUDNN_ATTN_QUERYMAP_* and CUDNN_ATTN_*_PROJ_BIASES flags - confirm */ |
| int nHeads, |
| double smScaler, /* NOTE(review): presumably the softmax scaling factor - confirm */ |
| cudnnDataType_t dataType, |
| cudnnDataType_t computePrec, |
| cudnnMathType_t mathType, |
| cudnnDropoutDescriptor_t attnDropoutDesc, |
| cudnnDropoutDescriptor_t postDropoutDesc, |
| int qSize, |
| int kSize, |
| int vSize, |
| int qProjSize, /* NOTE(review): whether 0 disables a projection is not stated in this header - confirm */ |
| int kProjSize, |
| int vProjSize, |
| int oProjSize, |
| int qoMaxSeqLength, |
| int kvMaxSeqLength, |
| int maxBatchSize, |
| int maxBeamSize); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Getter counterpart of cudnnSetAttnDescriptor(): same settings in the same order, each taken as a pointer. */ |
| cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc, |
| unsigned *attnMode, /* NOTE(review): whether NULL is accepted for unwanted outputs is not stated in this header - confirm */ |
| int *nHeads, |
| double *smScaler, |
| cudnnDataType_t *dataType, |
| cudnnDataType_t *computePrec, |
| cudnnMathType_t *mathType, |
| cudnnDropoutDescriptor_t *attnDropoutDesc, |
| cudnnDropoutDescriptor_t *postDropoutDesc, |
| int *qSize, |
| int *kSize, |
| int *vSize, |
| int *qProjSize, |
| int *kProjSize, |
| int *vProjSize, |
| int *oProjSize, |
| int *qoMaxSeqLength, |
| int *kvMaxSeqLength, |
| int *maxBatchSize, |
| int *maxBeamSize); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Reports three buffer sizes, in bytes per the parameter names, for an attention descriptor. */ |
| cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle, |
| const cudnnAttnDescriptor_t attnDesc, |
| size_t *weightSizeInBytes, |
| size_t *workSpaceSizeInBytes, |
| size_t *reserveSpaceSizeInBytes); /* NOTE(review): whether NULL may be passed (e.g. for inference-only use) is not stated in this header - confirm */ |
|
|
| typedef enum { /* Selects one weight or bias group; the type of the `wKind` argument of cudnnGetMultiHeadAttnWeights(). Values 0..3 are weights, 4..7 are biases. */ |
| CUDNN_MH_ATTN_Q_WEIGHTS = 0, |
| CUDNN_MH_ATTN_K_WEIGHTS = 1, |
| CUDNN_MH_ATTN_V_WEIGHTS = 2, |
| CUDNN_MH_ATTN_O_WEIGHTS = 3, |
| CUDNN_MH_ATTN_Q_BIASES = 4, |
| CUDNN_MH_ATTN_K_BIASES = 5, |
| CUDNN_MH_ATTN_V_BIASES = 6, |
| CUDNN_MH_ATTN_O_BIASES = 7, |
| } cudnnMultiHeadAttnWeightKind_t; |
|
|
| #define CUDNN_ATTN_WKIND_COUNT 8 /* equals the number of cudnnMultiHeadAttnWeightKind_t enumerators (values 0..7) */ |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Looks up the weight or bias group selected by wKind inside a weights buffer. */ |
| cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle, |
| const cudnnAttnDescriptor_t attnDesc, |
| cudnnMultiHeadAttnWeightKind_t wKind, |
| size_t weightSizeInBytes, |
| const void *weights, /* const pointee: the buffer is only read through this pointer */ |
| cudnnTensorDescriptor_t wDesc, /* NOTE(review): presumably a caller-created descriptor that is filled in - confirm */ |
| void **wAddr); /* receives a pointer; NOTE(review): presumably an address inside `weights` - confirm */ |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Multi-head attention forward pass; const void* arguments are read-only, plain void* arguments are writable. */ |
| cudnnMultiHeadAttnForward(cudnnHandle_t handle, |
| const cudnnAttnDescriptor_t attnDesc, |
| int currIdx, /* NOTE(review): presumably a time-step index, with a negative value meaning all steps - confirm */ |
| const int loWinIdx[], /* NOTE(review): presumably per-step lower bounds of the attention window - confirm */ |
| const int hiWinIdx[], /* NOTE(review): presumably per-step upper bounds of the attention window - confirm */ |
| const int devSeqLengthsQO[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| const int devSeqLengthsKV[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| const cudnnSeqDataDescriptor_t qDesc, |
| const void *queries, |
| const void *residuals, /* NOTE(review): presumably optional (NULL when unused) - confirm */ |
| const cudnnSeqDataDescriptor_t kDesc, |
| const void *keys, |
| const cudnnSeqDataDescriptor_t vDesc, |
| const void *values, |
| const cudnnSeqDataDescriptor_t oDesc, |
| void *out, |
| size_t weightSizeInBytes, /* see cudnnGetMultiHeadAttnBuffers() */ |
| const void *weights, |
| size_t workSpaceSizeInBytes, /* see cudnnGetMultiHeadAttnBuffers() */ |
| void *workSpace, |
| size_t reserveSpaceSizeInBytes, /* see cudnnGetMultiHeadAttnBuffers() */ |
| void *reserveSpace); |
|
|
| |
| |
| |
| |
| |
| |
| /* Takes no arguments and returns a cudnnStatus_t. NOTE(review): presumably a run-time check that this library's version matches its dependencies - confirm. */ |
| cudnnStatus_t CUDNNWINAPI |
| cudnnAdvVersionCheck(void); |
|
|
| typedef enum { /* Weight-gradient mode; the type of the `addGrad` argument of cudnnRNNBackwardWeights_v8() and cudnnMultiHeadAttnBackwardWeights(). */ |
| CUDNN_WGRAD_MODE_ADD = 0, /* NOTE(review): presumably accumulates into the existing gradient buffer - confirm */ |
| CUDNN_WGRAD_MODE_SET = 1, /* NOTE(review): presumably overwrites the gradient buffer - confirm */ |
| } cudnnWgradMode_t; |
|
|
| cudnnStatus_t CUDNNWINAPI /* RNN backward pass with respect to data: read-only inputs are const void*, gradient outputs (dx, dhx, dcx) are plain void*. */ |
| cudnnRNNBackwardData_v8(cudnnHandle_t handle, |
| cudnnRNNDescriptor_t rnnDesc, |
| const int32_t devSeqLengths[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| cudnnRNNDataDescriptor_t yDesc, /* one descriptor shared by y and dy */ |
| const void *y, |
| const void *dy, |
| cudnnRNNDataDescriptor_t xDesc, /* describes dx; note that x itself is not a parameter of this function */ |
| void *dx, |
| cudnnTensorDescriptor_t hDesc, /* one descriptor shared by hx, dhy and dhx */ |
| const void *hx, |
| const void *dhy, |
| void *dhx, |
| cudnnTensorDescriptor_t cDesc, /* one descriptor shared by cx, dcy and dcx */ |
| const void *cx, |
| const void *dcy, |
| void *dcx, |
| size_t weightSpaceSize, |
| const void *weightSpace, |
| size_t workSpaceSize, |
| void *workSpace, |
| size_t reserveSpaceSize, |
| void *reserveSpace); /* NOTE(review): presumably the buffer produced by cudnnRNNForward() in training mode - confirm */ |
|
|
| cudnnStatus_t CUDNNWINAPI /* RNN backward pass with respect to weights; the only writable data buffers are dweightSpace, workSpace and reserveSpace. */ |
| cudnnRNNBackwardWeights_v8(cudnnHandle_t handle, |
| cudnnRNNDescriptor_t rnnDesc, |
| cudnnWgradMode_t addGrad, /* CUDNN_WGRAD_MODE_ADD or CUDNN_WGRAD_MODE_SET; NOTE(review): supported values may be restricted - confirm */ |
| const int32_t devSeqLengths[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| cudnnRNNDataDescriptor_t xDesc, |
| const void *x, |
| cudnnTensorDescriptor_t hDesc, |
| const void *hx, |
| cudnnRNNDataDescriptor_t yDesc, |
| const void *y, |
| size_t weightSpaceSize, /* size paired with dweightSpace; the weights themselves are not a parameter here */ |
| void *dweightSpace, |
| size_t workSpaceSize, |
| void *workSpace, |
| size_t reserveSpaceSize, |
| void *reserveSpace); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Attention backward pass with respect to data: writes dqueries, dkeys and dvalues; dout and the forward inputs are read-only. */ |
| cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle, |
| const cudnnAttnDescriptor_t attnDesc, |
| const int loWinIdx[], |
| const int hiWinIdx[], |
| const int devSeqLengthsDQDO[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| const int devSeqLengthsDKDV[], /* NOTE(review): the `dev` prefix suggests an array in device memory - confirm */ |
| const cudnnSeqDataDescriptor_t doDesc, |
| const void *dout, |
| const cudnnSeqDataDescriptor_t dqDesc, /* the only descriptor declared for the dqueries / queries pair */ |
| void *dqueries, |
| const void *queries, |
| const cudnnSeqDataDescriptor_t dkDesc, /* the only descriptor declared for the dkeys / keys pair */ |
| void *dkeys, |
| const void *keys, |
| const cudnnSeqDataDescriptor_t dvDesc, /* the only descriptor declared for the dvalues / values pair */ |
| void *dvalues, |
| const void *values, |
| size_t weightSizeInBytes, |
| const void *weights, |
| size_t workSpaceSizeInBytes, |
| void *workSpace, |
| size_t reserveSpaceSizeInBytes, |
| void *reserveSpace); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Attention backward pass with respect to weights: writes dweights; queries, keys, values, dout and weights are read-only. */ |
| cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle, |
| const cudnnAttnDescriptor_t attnDesc, |
| cudnnWgradMode_t addGrad, /* CUDNN_WGRAD_MODE_ADD or CUDNN_WGRAD_MODE_SET */ |
| const cudnnSeqDataDescriptor_t qDesc, |
| const void *queries, |
| const cudnnSeqDataDescriptor_t kDesc, |
| const void *keys, |
| const cudnnSeqDataDescriptor_t vDesc, |
| const void *values, |
| const cudnnSeqDataDescriptor_t doDesc, |
| const void *dout, |
| size_t weightSizeInBytes, /* NOTE(review): presumably the size of both `weights` and `dweights` - confirm */ |
| const void *weights, |
| void *dweights, |
| size_t workSpaceSizeInBytes, |
| void *workSpace, |
| size_t reserveSpaceSizeInBytes, |
| void *reserveSpace); |
|
|
| |
| |
| |
| /* Loss normalization mode; the type of the `normMode` argument of the cudnnSetCTCLossDescriptorEx/_v8/_v9 setters and the matching getters. */ |
| typedef enum { |
| CUDNN_LOSS_NORMALIZATION_NONE = 0, |
| CUDNN_LOSS_NORMALIZATION_SOFTMAX = 1, |
| } cudnnLossNormalizationMode_t; |
|
|
| cudnnStatus_t CUDNNWINAPI /* NOTE(review): presumably creates a CTC loss descriptor and stores the handle in *ctcLossDesc; pair with cudnnDestroyCTCLossDescriptor() - confirm */ |
| cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Smallest setter: compute type only; cudnnSetCTCLossDescriptor_v9() is the non-deprecated setter in this header. */ |
| cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Extends cudnnSetCTCLossDescriptor() with normMode and gradMode. */ |
| cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc, |
| cudnnDataType_t compType, |
| cudnnLossNormalizationMode_t normMode, |
| cudnnNanPropagation_t gradMode); /* a cudnnNanPropagation_t here; the _v9 setter takes a cudnnCTCGradMode_t instead */ |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Same parameters as cudnnSetCTCLossDescriptorEx() plus maxLabelLength. */ |
| cudnnSetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc, |
| cudnnDataType_t compType, |
| cudnnLossNormalizationMode_t normMode, |
| cudnnNanPropagation_t gradMode, /* a cudnnNanPropagation_t here; the _v9 setter takes a cudnnCTCGradMode_t instead */ |
| int maxLabelLength); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Non-deprecated setter; differs from the _v8 variant only in the fourth parameter (cudnnCTCGradMode_t instead of cudnnNanPropagation_t). */ |
| cudnnSetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc, |
| cudnnDataType_t compType, |
| cudnnLossNormalizationMode_t normMode, |
| cudnnCTCGradMode_t ctcGradMode, /* type declared outside this header */ |
| int maxLabelLength); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Getter counterpart of cudnnSetCTCLossDescriptor(): compute type only, via pointer. */ |
| cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Getter counterpart of cudnnSetCTCLossDescriptorEx(); each setting is taken as a pointer. */ |
| cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc, |
| cudnnDataType_t *compType, |
| cudnnLossNormalizationMode_t *normMode, |
| cudnnNanPropagation_t *gradMode); |
|
|
| CUDNN_DEPRECATED cudnnStatus_t CUDNNWINAPI /* Marked CUDNN_DEPRECATED. Getter counterpart of cudnnSetCTCLossDescriptor_v8(); each setting is taken as a pointer. */ |
| cudnnGetCTCLossDescriptor_v8(cudnnCTCLossDescriptor_t ctcLossDesc, |
| cudnnDataType_t *compType, |
| cudnnLossNormalizationMode_t *normMode, |
| cudnnNanPropagation_t *gradMode, |
| int *maxLabelLength); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Non-deprecated getter; counterpart of cudnnSetCTCLossDescriptor_v9(), each setting is taken as a pointer. */ |
| cudnnGetCTCLossDescriptor_v9(cudnnCTCLossDescriptor_t ctcLossDesc, |
| cudnnDataType_t *compType, |
| cudnnLossNormalizationMode_t *normMode, |
| cudnnCTCGradMode_t *ctcGradMode, |
| int *maxLabelLength); |
|
|
| cudnnStatus_t CUDNNWINAPI /* Takes the handle by value; NOTE(review): presumably releases a descriptor obtained from cudnnCreateCTCLossDescriptor() - confirm */ |
| cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc); |
|
|
| /* CTC loss: writes `costs` and `gradients`; label and length arrays carry a `host` prefix here (NOTE(review): presumably host memory - confirm). */ |
| cudnnStatus_t CUDNNWINAPI |
| cudnnCTCLoss( |
| cudnnHandle_t handle, |
| const cudnnTensorDescriptor_t probsDesc, |
| /* NOTE(review): the dimension layout that probsDesc must describe is not stated in this header - confirm against the cuDNN reference. */ |
| const void *probs, |
| const int hostLabels[], |
| const int hostLabelLengths[], |
| const int hostInputLengths[], |
| void *costs, |
| const cudnnTensorDescriptor_t gradientsDesc, |
| void *gradients, /* NOTE(review): presumably may be NULL when only costs are wanted - confirm */ |
| cudnnCTCLossAlgo_t algo, |
| cudnnCTCLossDescriptor_t ctcLossDesc, |
| void *workspace, |
| size_t workSpaceSizeInBytes); /* see cudnnGetCTCLossWorkspaceSize(); note the order here is workspace pointer first, size second */ |
|
|
| /* CTC loss, _v8 signature: algo and ctcLossDesc come right after the handle, and the label and length arrays have no `host` prefix. */ |
| cudnnStatus_t CUDNNWINAPI |
| cudnnCTCLoss_v8( |
| cudnnHandle_t handle, |
| cudnnCTCLossAlgo_t algo, |
| cudnnCTCLossDescriptor_t ctcLossDesc, |
| const cudnnTensorDescriptor_t probsDesc, |
| /* NOTE(review): the dimension layout that probsDesc must describe is not stated in this header - confirm against the cuDNN reference. */ |
| const void *probs, |
| const int labels[], /* NOTE(review): presumably device memory, unlike hostLabels in cudnnCTCLoss() - confirm */ |
| const int labelLengths[], |
| const int inputLengths[], |
| void *costs, |
| const cudnnTensorDescriptor_t gradientsDesc, |
| void *gradients, |
| size_t workSpaceSizeInBytes, /* see cudnnGetCTCLossWorkspaceSize_v8(); note the order here is size first, workspace pointer second */ |
| void *workspace); |
|
|
| /* Reports, through *sizeInBytes, a workspace size for cudnnCTCLoss(); takes the label and length arrays in addition to the descriptors. */ |
| cudnnStatus_t CUDNNWINAPI |
| cudnnGetCTCLossWorkspaceSize( |
| cudnnHandle_t handle, |
| const cudnnTensorDescriptor_t probsDesc, |
| |
| const cudnnTensorDescriptor_t gradientsDesc, |
| |
| |
| const int *labels, /* declared as pointers here, whereas cudnnCTCLoss() uses array syntax; both are read-only int arrays */ |
| const int *labelLengths, |
| const int *inputLengths, |
| cudnnCTCLossAlgo_t algo, |
| cudnnCTCLossDescriptor_t ctcLossDesc, |
| size_t *sizeInBytes); |
|
|
| /* Reports, through *sizeInBytes, a workspace size for cudnnCTCLoss_v8(); unlike the non-_v8 query it takes no label or length arrays. */ |
| cudnnStatus_t CUDNNWINAPI |
| cudnnGetCTCLossWorkspaceSize_v8( |
| cudnnHandle_t handle, |
| cudnnCTCLossAlgo_t algo, |
| cudnnCTCLossDescriptor_t ctcLossDesc, |
| const cudnnTensorDescriptor_t probsDesc, |
| |
| const cudnnTensorDescriptor_t gradientsDesc, |
| |
| |
| size_t *sizeInBytes); |
|
|
| #if defined(__cplusplus) |
| } |
| #endif |
|
|
| #endif |
|
|