Spaces:
Sleeping
Sleeping
| /** | |
| * Copyright (c) Facebook, Inc. and its affiliates. | |
| * | |
| * This source code is licensed under the MIT license found in the | |
| * LICENSE file in the root directory of this source tree. | |
| */ | |
// FS is filter size and kernels are specialized for filter sizes.
//
// Launch layout: grid = (minibatch, numFeatures), block = (SB).
// Each block handles one (batch, feature) row of the B x C x T input and
// walks the time axis in chunks of SB positions. The per-position dynamic
// filter taps live in `weight`, indexed as B x H x FS x T (see offset math
// below). Requires blockDim.x == SB; uses (SB + FS) * sizeof(scalar_t)
// static shared memory.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void dynamicconv_forward_kernel(
    const scalar_t* input,
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* output) {
  assert(blockDim.x == SB);

  const int t = threadIdx.x;              // position within the current chunk
  const int batchIdx = blockIdx.x;
  const int featureIdx = blockIdx.y;
  const int headIdx = featureIdx / numFiltersInBlock;

  // Base of this (batch, feature) row in the flat B x C x T layout.
  const int ioBase =
      batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength;
  const scalar_t* inputRow = &input[ioBase];
  scalar_t* outputRow = &output[ioBase];

  // Staging buffer for one chunk of input plus FS halo elements; helper
  // zeroes the padding region before the first load.
  __shared__ scalar_t smemInput[SB + FS];
  zeroSharedMem<FS, SB, padding_l>(smemInput);

  const int chunkCount = divUp<int, int>(sequenceLength, SB);
  for (int chunk = 0; chunk < chunkCount; ++chunk) {
    // Barrier before the load: previous iteration's readers must be done
    // with smemInput before it is overwritten.
    __syncthreads();
    const int chunkStart = chunk * SB;
    load_input_to_shared<FS, SB, padding_l>(
        inputRow, chunkStart, sequenceLength, chunk, chunkCount, false,
        smemInput);
    __syncthreads();

    if (chunkStart + t < sequenceLength) {
      // weight layout is B x H x FS x T; everything except the tap index
      // is invariant across the inner loop, so hoist it.
      const int weightBase = batchIdx * numHeads * FS * sequenceLength +
          headIdx * FS * sequenceLength + chunkStart + t;
      scalar_t acc = scalar_t(0.0);
      for (int tap = 0; tap < FS; ++tap) {
        acc += weight[weightBase + tap * sequenceLength] * smemInput[t + tap];
      }
      outputRow[chunkStart + t] = acc;
    }
  }
}
// Backward pass for the dynamic convolution above: computes gradInput
// (B x C x T) and gradWeight (B x H x FS x T) from gradOutput and input.
//
// Launch layout: grid = (minibatch, numHeads, numChunks), block = (SB).
// Each block owns one (batch, head, time-chunk) triple and iterates over
// the numFiltersInBlock channels that share this head's filter, so the
// weight-gradient sums over those channels stay in per-thread registers.
// Requires blockDim.x == SB; uses 2 * (SB + FS) * sizeof(scalar_t) static
// shared memory.
template <int FS, int SB, int padding_l, typename scalar_t>
__global__ void dynamicconv_backward_kernel(
    const scalar_t* gradOutput, // B * C * T
    const scalar_t* input, // B * C * T
    const scalar_t* weight,
    int minibatch,
    int sequenceLength,
    int numFeatures,
    int numFiltersInBlock,
    int numHeads,
    scalar_t* gradWeight,
    scalar_t* gradInput) { // B * H * k * T
  assert(blockDim.x == SB);

  // each block operates on a single batch and filter head
  const int tid = threadIdx.x;
  const int batchIdx = blockIdx.x;
  const int headIdx = blockIdx.y;
  const int chunkIdx = blockIdx.z;

  const int numChunks = divUp<int, int>(sequenceLength, SB);
  const int inputOffset = chunkIdx * SB;

  // initialize shared memory for output gradient and input
  __shared__ scalar_t tempGradOutput[SB + FS];
  __shared__ scalar_t tempInput[SB + FS];
  // The input gradient is a correlation with the time-reversed filter, so
  // gradOutput is staged with the mirrored padding (FS - padding_l - 1).
  const int padding = FS - padding_l - 1;

  zeroSharedMem<FS, SB, padding>(tempGradOutput);
  zeroSharedMem<FS, SB, padding_l>(tempInput);

  // initialize local filter and weight gradient sum arrays
  scalar_t tempGradSum[FS];
  scalar_t bfilter[FS];
  for (int k = 0; k < FS; ++k) {
    tempGradSum[k] = scalar_t(0.0);

    // Preload the filter taps this thread needs, flipped in time
    // ((FS - k - 1) tap index) and sampled at the output position
    // (idxOffset) each gradOutput element came from; out-of-range
    // positions contribute zero. weight layout is B x H x FS x T.
    int idxOffset = inputOffset + tid + k - padding;
    if (idxOffset >= 0 && idxOffset < sequenceLength) {
      int bfilterOffset = batchIdx * numHeads * FS * sequenceLength +
          headIdx * FS * sequenceLength + (FS - k - 1) * sequenceLength +
          idxOffset;
      bfilter[k] = weight[bfilterOffset];
    } else {
      bfilter[k] = scalar_t(0.0);
    }
  }

  // iterate over filter block
  for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) {
    // Barrier before the loads: all threads must be done reading the
    // previous channel's staged data before it is overwritten.
    __syncthreads();

    // load input and output gradient for this channel and chunk
    const int IOOffset = batchIdx * numFeatures * sequenceLength +
        (headIdx * numFiltersInBlock + featureIdx) * sequenceLength;
    const scalar_t* inputFeature = &input[IOOffset];
    const scalar_t* gradOutputFeature = &gradOutput[IOOffset];
    scalar_t* gradInputFeature = &gradInput[IOOffset];

    load_input_to_shared<FS, SB, padding>(
        gradOutputFeature,
        inputOffset,
        sequenceLength,
        chunkIdx,
        numChunks,
        true,
        tempGradOutput);
    load_input_to_shared<FS, SB, padding_l>(
        inputFeature,
        inputOffset,
        sequenceLength,
        chunkIdx,
        numChunks,
        true,
        tempInput);
    __syncthreads();

    // sum input and weight gradients
    // dW[k] accumulates input[pos + k] * gradOutput[pos] across all
    // channels in this filter block (gradOutput read at the fixed offset
    // tid + padding); dInput is the correlation of gradOutput with the
    // time-reversed taps preloaded in bfilter.
    scalar_t out = scalar_t(0.0);
    for (int k = 0; k < FS; ++k) {
      tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding];
      out += bfilter[k] * tempGradOutput[tid + k];
    }

    if (inputOffset + tid < sequenceLength) {
      gradInputFeature[inputOffset + tid] = out;
    }
  }

  // write weight gradient: one B x H x FS x T entry per (tap, position),
  // holding the sum over this head's channels computed above.
  const int gradOffset =
      batchIdx * numHeads * FS * sequenceLength + headIdx * FS * sequenceLength;
  scalar_t* gradWeightFeature = &gradWeight[gradOffset];

  // write weight gradient
  if (inputOffset + tid < sequenceLength) {
    for (int k = 0; k < FS; ++k) {
      const int outputOffset = k * sequenceLength + inputOffset + tid;
      gradWeightFeature[outputOffset] = tempGradSum[k];
    }
  }
}