File size: 4,861 Bytes
e9f9fd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
#include <ATen/ATen.h>
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
template <typename scalar_t>
__global__ void forget_mult_cuda_forward_kernel(const scalar_t* __restrict__ x,
                const scalar_t* __restrict__ f, scalar_t* __restrict__ output,
                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  /*
  ForgetMult forward: h_t = f_t * x_t + (1 - f_t) * h_{t-1}, computed
  sequentially over time with one thread per (batch, hidden) pair.

  Launch layout: grid.x * blockDim.x covers n_hidden, grid.y covers batch_size
  (blockDim.y == 1, so bid == blockIdx.y in practice).

  Note: output is assumed to be one timestep longer than f or x where output[0] = h_{-1}
  This means output array has a size of seq_length+1 on the word dimension.

  NOTE(review): indices are 32-bit ints; this overflows for tensors with more
  than 2^31 elements — acceptable for typical QRNN sizes, but worth confirming.
  */
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  if (hid < n_hidden && bid < batch_size) {
    // Carry the previous hidden state in a register: the value written at
    // timestep ts is exactly the value read back at ts+1, so there is no need
    // to re-read output[] from global memory every iteration. Initialize from
    // output[0] (= h_{-1}) for this (batch, hidden) pair.
    const int init_idx = batch_first ? bid * n_hidden * (seq_length + 1) + hid
                                     : bid * n_hidden + hid;
    scalar_t prev = output[init_idx];
    for (int ts = 1; ts < seq_length + 1; ts++) {
      int i = 0;      // read index into x / f (seq_length timesteps)
      int dst_i = 0;  // write index into output (seq_length+1 timesteps)
      if (batch_first) {
        i     = bid * n_hidden * seq_length       + (ts - 1) * n_hidden + hid;
        dst_i = bid * n_hidden * (seq_length + 1) + ts * n_hidden + hid;
      }
      else {
        i     = (ts - 1) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_i = ts * n_hidden * batch_size + bid * n_hidden + hid;
      }
      const scalar_t f_i = f[i];
      // Single fused store instead of the original store-then-read-modify-write
      // to the same global address.
      prev = f_i * x[i] + (1 - f_i) * prev;
      output[dst_i] = prev;
    }
  }
}
template <typename scalar_t>
__global__ void forget_mult_cuda_backward_kernel(const scalar_t* __restrict__ x,
                const scalar_t* __restrict__ f, const scalar_t* __restrict__ output,
                const scalar_t* __restrict__ grad_output, scalar_t* __restrict__ grad_x,
                scalar_t* __restrict__ grad_f, scalar_t* __restrict__ grad_h,
                size_t batch_size, size_t seq_length, size_t n_hidden, bool batch_first) {
  /*
  ForgetMult backward: back-propagates through h_t = f_t * x_t + (1 - f_t) * h_{t-1}
  by iterating time in reverse, one thread per (batch, hidden) pair.

  output is the forward result (seq_length+1 timesteps, output[0] = h_{-1});
  grad_h receives the gradient w.r.t. the initial hidden state h_{-1}.

  Launch layout matches the forward kernel: grid.x * blockDim.x covers
  n_hidden, grid.y covers batch_size.
  */
  const int hid = blockIdx.x * blockDim.x + threadIdx.x;
  const int bid = blockIdx.y * blockDim.y + threadIdx.y;
  // Accumulate the running gradient in double (the original accumulator type);
  // it is implicitly narrowed back to scalar_t on each store.
  double running_f = 0;
  if (hid < n_hidden && bid < batch_size) {
    for (int ts = seq_length; ts >= 1; ts--) {
      int i = 0;            // index into x / f / grad_* (seq_length timesteps)
      int dst_iminus1 = 0;  // index of h_{t-1} in output (seq_length+1 timesteps)
      if (batch_first) {
        i           = bid * n_hidden * seq_length       + (ts - 1) * n_hidden + hid;
        dst_iminus1 = bid * n_hidden * (seq_length + 1) + (ts - 1) * n_hidden + hid;
      }
      else {
        i           = (ts - 1) * n_hidden * batch_size + bid * n_hidden + hid;
        dst_iminus1 = (ts - 1) * n_hidden * batch_size + bid * n_hidden + hid;
      }
      const scalar_t f_i = f[i];  // load once instead of three times
      running_f += grad_output[i];
      grad_x[i] = f_i * running_f;
      grad_f[i] = (x[i] - output[dst_iminus1]) * running_f;
      // The line below is likely more numerically stable than (1 - f_i) * running_f;
      running_f = running_f - f_i * running_f;
    }
    grad_h[bid * n_hidden + hid] = running_f;
  }
}
// Host-side launcher for the ForgetMult forward kernel.
//
// x, f: (batch, seq, hidden) when batch_first, else (seq, batch, hidden).
// output: preallocated with seq_length+1 timesteps, output[0] = h_{-1};
// it is filled in place and returned.
at::Tensor forget_mult_cuda_forward(at::Tensor x, at::Tensor f, at::Tensor output, bool batch_first) {
  const auto batch_dim = batch_first ? 0 : 1;
  const auto seq_dim   = batch_first ? 1 : 0;
  const auto batch_size = x.size(batch_dim);
  const auto seq_length = x.size(seq_dim);
  const auto n_hidden   = x.size(2);
  // One thread per (batch, hidden) pair: grid.x tiles the hidden dimension
  // (ceil-div by the block size), grid.y spans the batch.
  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forget_mult_cuda_forward", ([&] {
    forget_mult_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(),
        batch_size, seq_length, n_hidden, batch_first);
  }));
  // Surface any launch-configuration error from the kernel launch above.
  THCudaCheck(cudaGetLastError());
  return output;
}
// Host-side launcher for the ForgetMult backward kernel.
//
// x, f: (batch, seq, hidden) when batch_first, else (seq, batch, hidden).
// output: forward result with seq_length+1 timesteps (output[0] = h_{-1}).
// grad_output: gradient w.r.t. the forward output (same layout as x).
//
// Returns {grad_x, grad_f, grad_h}, where grad_h is the gradient w.r.t. the
// initial hidden state, shaped (batch_size, n_hidden).
std::vector<at::Tensor> forget_mult_cuda_backward(at::Tensor x, at::Tensor f,
                at::Tensor output, at::Tensor grad_output, bool batch_first) {
  const auto batch_size = (batch_first) ? x.size(0) : x.size(1);
  const auto seq_length = (batch_first) ? x.size(1) : x.size(0);
  const auto n_hidden = x.size(2);
  auto grad_x = at::zeros_like(x);
  auto grad_f = at::zeros_like(x);
  auto grad_h = at::zeros({batch_size, n_hidden}, x.options());
  // Launch configuration mirrors the forward pass: grid.x tiles n_hidden,
  // grid.y spans the batch.
  const int threads = 1024;
  const dim3 blocks((n_hidden + threads - 1) / threads, batch_size);
  // Fixed: dispatch label previously said "forget_mult_cuda_forward", which
  // mislabeled error messages raised from the backward pass.
  AT_DISPATCH_FLOATING_TYPES(x.type(), "forget_mult_cuda_backward", ([&] {
    forget_mult_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        x.data<scalar_t>(), f.data<scalar_t>(), output.data<scalar_t>(), grad_output.data<scalar_t>(),
        grad_x.data<scalar_t>(), grad_f.data<scalar_t>(), grad_h.data<scalar_t>(), batch_size,
        seq_length, n_hidden, batch_first);
  }));
  // Surface any launch-configuration error from the kernel launch above.
  THCudaCheck(cudaGetLastError());
  return {grad_x, grad_f, grad_h};
}
|