| | |
| | |
| | |
| | |
| | |
| | |
| |
|
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

#include <stdlib.h>
#include <assert.h>

// Lane mask with all 32 bits set, i.e. every lane of a 32-thread warp.
// Presumably passed as the `mask` argument of the `__shfl_*_sync` intrinsics
// in the kernel definitions (not visible here) — NOTE(review): confirm; if a
// kernel ever shuffles inside a partially-active warp this mask is wrong.
#define SHFL_MASK 0xffffffff

// Forward lightweight-convolution kernel. Declaration only; the definition is
// not visible here (presumably later in this translation unit).
//
// Template parameters — meanings inferred from names, NOTE(review): confirm
// against the definition:
//   FS        - filter width (number of taps), presumably
//   SB        - sequence block length handled per thread block, presumably
//   padding_l - left padding along the sequence axis, presumably
//   scalar_t  - element type of `input`, `filters` and `output`
//
// Arguments: `input` and `filters` are read-only (const) and `output` is the
// non-const destination. All three are dereferenced by a __global__ kernel, so
// they must be device-accessible. The integer arguments give the problem
// size; the tensor layout is not visible here (TODO confirm). The grid/block
// configuration is chosen by the host-side launcher, which is not shown.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_forward_kernel(const scalar_t* input,
                              const scalar_t* filters,
                              int minibatch, int sequenceLength,
                              int numFeatures, int numFiltersInBlock,
                              scalar_t* output);

// Backward kernel for the gradient with respect to the convolution input.
// Declaration only; the definition is not visible here.
//
// Template parameters — meanings inferred from names, NOTE(review): confirm:
//   FS        - filter width (number of taps), presumably
//   SB        - sequence block length handled per thread block, presumably
//   padding_l - left padding along the sequence axis, presumably
//   scalar_t  - element type of `input`, `filters` and `output`
//
// Arguments: `input` and `filters` are read-only (const) and `output` is the
// non-const destination. All of them must be device-accessible because they
// are dereferenced by a __global__ kernel. Given the kernel name, `input` is
// presumably the incoming gradient and `output` the gradient w.r.t. the
// forward input — NOTE(review): confirm against the definition.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_input_kernel(
    const scalar_t* input,
    const scalar_t* filters,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    scalar_t* output);

// First pass of the filter-weight gradient, "short" variant. Declaration
// only; the definition is not visible here.
//
// Reads `input` and `gradInput` (both scalar_t, read-only) and writes float
// intermediates to `output`, regardless of scalar_t. The float buffer is
// presumably reduced by lightconv_grad_wrt_weights_secondpass_short_kernel,
// whose `input` parameter has the matching `const float*` type (confirm).
// Unlike the non-short first pass, this variant also takes `numHeads`.
// What qualifies as "short" (presumably short sequences) is not visible
// here — NOTE(review): confirm.
//
// Template parameters (inferred from names, confirm): FS filter width, SB
// sequence block length, padding_l left padding, scalar_t element type of
// the two input pointers. All pointers must be device-accessible.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_short_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    float* output);

// Second pass of the "short"-variant filter-weight gradient. Declaration only;
// the definition is not visible here.
//
// Reads float intermediates from `input` and writes scalar_t results to
// `output`. It takes no padding_l template parameter and no sequence-length
// argument. It presumably consumes the float buffer produced by
// lightconv_grad_wrt_weights_firstpass_short_kernel — NOTE(review): confirm.
//
// Template parameters (inferred from names, confirm): FS filter width, SB
// block size, scalar_t element type of `output`. Both pointers must be
// device-accessible.
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_short_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);

// First pass of the filter-weight gradient, general (non-"short") variant.
// Declaration only; the definition is not visible here.
//
// Reads `input` and `gradInput` (both scalar_t, read-only) and writes float
// intermediates to `output`, regardless of scalar_t. The float buffer is
// presumably reduced by lightconv_grad_wrt_weights_secondpass_kernel, whose
// `input` parameter has the matching `const float*` type (confirm).
//
// Template parameters (inferred from names, NOTE(review): confirm): FS filter
// width, SB sequence block length, padding_l left padding, scalar_t element
// type of the two input pointers. All pointers must be device-accessible.
template<int FS, int SB, int padding_l, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_firstpass_kernel(
    const scalar_t* input,
    const scalar_t* gradInput,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    float* output);

// Second pass of the general (non-"short") filter-weight gradient.
// Declaration only; the definition is not visible here.
//
// Reads float intermediates from `input` and writes scalar_t results to
// `output`. It takes no padding_l template parameter and no sequence-length
// argument. It presumably consumes the float buffer produced by
// lightconv_grad_wrt_weights_firstpass_kernel — NOTE(review): confirm.
//
// Template parameters (inferred from names, confirm): FS filter width, SB
// block size, scalar_t element type of `output`. Both pointers must be
// device-accessible.
template<int FS, int SB, typename scalar_t>
__global__
void lightconv_grad_wrt_weights_secondpass_kernel(
    const float* input,
    const int minibatch,
    const int numFiltersInBlock,
    scalar_t* output);
