| #include "ATen/ATen.h" |
| #include "ATen/cuda/CUDAContext.h" |
| #include "ATen/cuda/detail/IndexUtils.cuh" |
| #include <cuda.h> |
| #include <cuda_runtime.h> |
| #include <stdio.h> |
| #include <cmath> |
| #include "ATen/TensorUtils.h" |
| #include "ATen/Type.h" |
| #include "ATen/AccumulateType.h" |
| #include <THC/THCGeneral.h> |
|
|
| #include "type_shim.h" |
|
|
// Selects where epsilon enters the Adam denominator in adam_cuda_kernel:
//   ADAM_MODE_0 -> sqrt(v + eps)
//   ADAM_MODE_1 -> sqrt(v) + eps
enum adamMode_t {
  ADAM_MODE_0 = 0,
  ADAM_MODE_1 = 1
};
|
|
// Precision-matched square root for adam_cuda_kernel: sqrtf for float state,
// sqrt for double state, so the double path is not truncated to float.
static __device__ __forceinline__ float adam_sqrt(float x) { return sqrtf(x); }
static __device__ __forceinline__ double adam_sqrt(double x) { return sqrt(x); }

// Element-wise fused Adam step, updating p, m and v in place.
//
// Template parameters:
//   T      - type of the optimizer state m / v (and p_copy) and of all
//            intermediate arithmetic. The host wrapper passes
//            at::acc_type<GRAD_T, true> for Half gradients and GRAD_T itself
//            for float/double gradients.
//   GRAD_T - element type of the parameters p and the gradients g.
//
// For every flat index j < tsize:
//   sg  = g / grad_scale
//   m   = b1 * m + (1 - b1) * sg
//   v   = b2 * v + (1 - b2) * sg * sg
//   den = sqrt(v + eps)   if mode == ADAM_MODE_0
//         sqrt(v) + eps   otherwise
//   p   = p - step_size * (m / den + decay * p)
//   p_copy[j] (only if p_copy != NULL) receives the updated p, converted to T.
//
// Launch: uses the x and y dimensions of grid and block (z assumed 1). Each
// thread starts at its flattened global id and advances by the total number
// of launched threads, so the grid need not cover tsize exactly.
// Preconditions: all pointers address flat buffers of at least tsize elements.
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
    GRAD_T* __restrict__ p,
    T* __restrict__ p_copy,
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay)
{
  // Flatten the (possibly 2D) grid of (possibly 2D) blocks into one linear id.
  // size_t keeps j + stride from overflowing a signed int near INT_MAX.
  const size_t blockId = (size_t)gridDim.x * blockIdx.y + blockIdx.x;
  const size_t threadsPerBlock = (size_t)blockDim.x * blockDim.y;
  const size_t threadIdInBlock = (size_t)threadIdx.y * blockDim.x + threadIdx.x;
  const size_t stride = (size_t)gridDim.x * gridDim.y * threadsPerBlock;

  for (size_t j = blockId * threadsPerBlock + threadIdInBlock; j < tsize; j += stride) {
    // Load once into registers and do all math in T (float or double), so
    // the double path keeps full precision.
    const T param = static_cast<T>(p[j]);
    const T scaled_grad = static_cast<T>(g[j]) / grad_scale;
    const T m_j = b1 * m[j] + (1 - b1) * scaled_grad;
    const T v_j = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
    m[j] = m_j;
    v[j] = v_j;

    const T denom = (mode == ADAM_MODE_0) ? adam_sqrt(v_j + eps)
                                          : adam_sqrt(v_j) + eps;
    const T update = (m_j / denom) + (decay * param);

    // Round once to the parameter type; the optional copy mirrors exactly
    // the value stored in p.
    const GRAD_T new_param = static_cast<GRAD_T>(param - step_size * update);
    p[j] = new_param;
    if (p_copy != NULL) p_copy[j] = static_cast<T>(new_param);
  }
}
|
|
// Learning rate actually applied by the kernel: when bias_correction == 1 it
// is lr * sqrt(1 - beta2^step) / (1 - beta1^step), otherwise plain lr.
static float adam_step_size(float lr, float beta1, float beta2, int step, int bias_correction)
{
  if (bias_correction != 1) {
    return lr;
  }
  const float bias_correction1 = 1 - std::pow(beta1, step);
  const float bias_correction2 = 1 - std::pow(beta2, step);
  return lr * std::sqrt(bias_correction2) / bias_correction1;
}

// Host entry point: launches one fused Adam update of p, m and v (in place)
// on the current CUDA stream.
//
// Tensors (all treated as flat, contiguous buffers of p.numel() elements):
//   p      - parameters; read via the gradient's dtype (Half, Float or Double).
//   p_copy - NOTE(review): currently unused; the kernel is always launched
//            with a NULL copy pointer.
//   m, v   - first/second moment state; at::acc_type<Half, true> when g is
//            Half, otherwise the same dtype as g.
//   g      - gradients; divided by grad_scale inside the kernel.
// Scalars:
//   mode            - cast to adamMode_t: 0 uses sqrt(v + eps), 1 uses
//                     sqrt(v) + eps as the denominator.
//   bias_correction - when 1, lr is scaled by sqrt(1 - beta2^step) /
//                     (1 - beta1^step).
//   decay           - added to the update as decay * p.
// The launch is asynchronous: only launch errors are checked here.
// Errors raised while the kernel runs surface at a later synchronizing call.
void fused_adam_cuda(
    at::Tensor & p,
    at::Tensor & p_copy,
    at::Tensor & m,
    at::Tensor & v,
    at::Tensor & g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
  // Validate the size before anything derives int-sized values from it.
  AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
  const int64_t tsize = p.numel();

  // The kernel indexes every tensor with the same flat index, so mismatched
  // sizes or strided layouts would read/write the wrong (or out-of-bounds) memory.
  AT_ASSERTM(m.numel() == tsize && v.numel() == tsize && g.numel() == tsize,
             "fused_adam_cuda: p, m, v and g must have the same number of elements");
  AT_ASSERTM(p.is_contiguous() && m.is_contiguous() && v.is_contiguous() && g.is_contiguous(),
             "fused_adam_cuda: p, m, v and g must be contiguous");

  // A zero-block launch is an invalid configuration; nothing to update anyway.
  if (tsize == 0) {
    return;
  }

  const int threadsPerBlock = 512;
  const dim3 blocks(static_cast<unsigned int>((tsize + threadsPerBlock - 1) / threadsPerBlock));
  const float step_size = adam_step_size(lr, beta1, beta2, step, bias_correction);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (g.scalar_type() == at::ScalarType::Half) {
    // Half parameters/gradients; m and v are kept in the accumulate type.
    using namespace at;  // presumably needed for unqualified ATen names in the dispatch macro
    DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
        using accscalar_t = at::acc_type<scalar_t_0, true>;
        adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            p.data<scalar_t_0>(),
            nullptr,
            m.data<accscalar_t>(),
            v.data<accscalar_t>(),
            g.data<scalar_t_0>(),
            beta1,
            beta2,
            eps,
            grad_scale,
            step_size,
            tsize,
            (adamMode_t) mode,
            decay);
    )
  } else {
    // Float/double: every tensor shares the gradient's dtype.
    using namespace at;
    DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
        adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
            p.data<scalar_t_0>(),
            nullptr,
            m.data<scalar_t_0>(),
            v.data<scalar_t_0>(),
            g.data<scalar_t_0>(),
            beta1,
            beta2,
            eps,
            grad_scale,
            step_size,
            tsize,
            (adamMode_t) mode,
            decay);
    )
  }
  THCudaCheck(cudaGetLastError());
}
|
|