| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
#include <algorithm>
#include <cstddef>

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include "ATen/cuda/Exceptions.h"

#include <cuda.h>
#include <cuda_runtime.h>

#include "type_shim.h"
| |
|
// Welford single-sample update: folds `curr` into the running aggregate
// (mu, sigma2, count). Note that sigma2 accumulates the sum of squared
// deviations from the mean (M2), not the variance; callers divide by the
// element count themselves.
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count += U(1);
  // Deviation from the mean *before* this sample is included.
  const U diff_before = curr - mu;
  mu = mu + diff_before / count;
  // Deviation from the mean *after* this sample is included.
  const U diff_after = curr - mu;
  sigma2 += diff_before * diff_after;
}
| |
|
// Chan et al. parallel merge: combines a partial aggregate B
// (muB, sigma2B, countB) into (mu, sigma2, count) in place. Both sigma2
// values are sums of squared deviations (M2). If the combined count is not
// positive (e.g. both sides empty), mu and sigma2 are reset to zero while
// count still receives the summed count.
template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  const U mean_gap = muB - mu;
  const U count_a = count;
  const U count_total = count + countB;
  count = count_total;

  if (!(count_total > U(0))) {
    mu = U(0);
    sigma2 = U(0);
    return;
  }

  const U weight_a = count_a / count_total;
  const U weight_b = countB / count_total;
  mu = weight_a * mu + weight_b * muB;
  sigma2 = sigma2 + sigma2B + mean_gap * mean_gap * weight_a * weight_b * count_total;
}
| |
|
// Computes the mean and the biased variance (sum of squared deviations / n2)
// of row i1 of the row-major [n1, n2] array `vals`. All threads of the block
// cooperate: each thread runs a Welford accumulation over its share of the
// row, warps merge with cuChanOnlineSum through register shuffles, and warps
// are then merged through shared memory.
//
// Launch assumptions (from the indexing below):
//   * blockDim.x == 32 -- the intra-warp merge rotates over lanes with `& 31`
//     and uses threadIdx.x as the lane id.
//   * blockDim.y is a power of two -- the inter-warp tree halves `offset`.
//   * When blockDim.y > 1, `buf` must provide blockDim.y + blockDim.y/2
//     elements of U (blockDim.y for the (mu, sigma2) pairs, blockDim.y/2 for
//     the counts).
//   * i1 is the same for every thread of the block, so the __syncthreads()
//     calls inside the `i1 < n1` branch are reached by all threads or none.
//
// On return every thread holds the same `mu` and `sigma2`. For i1 >= n1 both
// are 0.
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // Flatten the 2D block; each thread consumes 4 consecutive elements per
    // step, and steps are strided by 4 * (threads in the block).
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    // 64-bit row offset: i1*n2 in int overflows once the input holds more
    // than INT_MAX elements.
    const T* lvals = vals + static_cast<size_t>(i1) * n2;
    int l = 4*thrx;
    for (; l+3 < n2; l+=4*numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    // Tail: only the thread whose 4-wide chunk straddles n2 has l < n2 here.
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // Intra-warp merge: at step l each lane merges in the aggregate of lane
    // (lane + 2^l) mod 32, so after 5 steps every lane covers all 32 lanes.
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
    }
    // Inter-warp merge through shared memory (one representative per warp:
    // threadIdx.x == 0).
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // Upper half of the remaining warps publishes its aggregate ...
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // ... and the lower half merges it in.
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // Broadcast the block-wide result held by thread (0, 0).
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
    } else {
      // Single warp: every lane already holds the full aggregate; take lane 0's.
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}
| |
|
// Specialization of cuWelfordMuSigma2 for at::Half input with float
// statistics. It has the same contract and launch assumptions as the generic
// version (blockDim.x == 32, power-of-two blockDim.y, `buf` sized
// blockDim.y + blockDim.y/2 floats when blockDim.y > 1). The main loop reads
// pairs of halves as __half2, covering 8 elements per thread per step.
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  float count = 0.0f;
  mu= float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    // 64-bit row offset: i1*n2 in int overflows once the input holds more
    // than INT_MAX elements.
    const at::Half* lvals = vals + static_cast<size_t>(i1) * n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // The row does not start on a 4-byte boundary, so __half2 loads from
      // lvals would be misaligned. Thread 0 consumes element 0 on its own
      // (guarded so an empty row is never read), and every thread shifts its
      // window by one element, which makes lvals + l 4-byte aligned for the
      // vectorized loop below.
      if (thrx == 0 && n2 > 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // Vectorized body: 4 x __half2 per thread per step.
    for (; l+7 < n2; l+=8*numx) {
      for (int k = 0; k < 8; k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    // Scalar tail of the one chunk that straddles n2.
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // Intra-warp merge (rotating by 2^l lanes, 5 steps for 32 lanes).
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
    }
    // Inter-warp merge through shared memory.
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // Upper half of the remaining warps publishes its aggregate ...
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // ... and the lower half merges it in.
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // Broadcast the block-wide result held by thread (0, 0).
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}
| |
|
// Reciprocal square root helpers. They are called from device code
// (cuApplyLayerNorm), so they are declared __host__ __device__ instead of
// host-only. The generic version computes 1 / sqrt(v); the float and double
// specializations forward to the CUDA math library.
template<typename U> __host__ __device__ U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> __host__ __device__ float rsqrt(float v) {
  return rsqrtf(v);
}
// NOTE: this unqualified call does not recurse. Overload resolution prefers
// the non-template ::rsqrt(double) declared by the CUDA math headers over
// this template specialization when both are exact matches.
template<> __host__ __device__ double rsqrt(double v) {
  return rsqrt(v);
}
| |
|
namespace {
// Typed accessor for the block's dynamically sized shared-memory buffer (the
// third <<<...>>> launch argument). Every `extern __shared__` array starts at
// the same address, so each specialization returns the same storage viewed as
// T. Only the types specialized below are supported; using any other type is a
// compile error because the primary template is declared but never defined.
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
  __device__ float *getPointer()
  {
    extern __shared__ float s_float[];
    return s_float;
  }
};

// double specialization, so the kernels in this file can also be instantiated
// with U = double statistics.
template <>
struct SharedMemory <double>
{
  __device__ double *getPointer()
  {
    extern __shared__ double s_double[];
    return s_double;
  }
};
}  // namespace
| |
|
// Forward LayerNorm over the row-major [n1, n2] input `vals`.
//
// For every row i1 it writes (x - mean) * invvar to `output_vals`. When both
// gamma and beta are non-NULL it writes gamma * that + beta instead. The row
// statistics are stored in mean[i1] and invvar[i1] (invvar = rsqrt(var + eps)).
//
// Grid/block: rows are distributed over blockIdx.y with stride gridDim.y, and
// all threads of a block cooperate on one row. The launch must satisfy the
// assumptions of cuWelfordMuSigma2 (blockDim.x == 32, power-of-two
// blockDim.y, (blockDim.y + blockDim.y/2) * sizeof(U) bytes of dynamic shared
// memory when blockDim.y > 1).
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    // 64-bit row offset: i1 is unsigned 32-bit, so i1*n2 would wrap for
    // inputs with more than 2^32 elements.
    const size_t row_offset = static_cast<size_t>(i1) * n2;
    const T* lvals = vals + row_offset;
    V* ovals = output_vals + row_offset;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
    // The next row's cuWelfordMuSigma2 call reuses the shared buffer, so all
    // threads must be done reading it first.
    __syncthreads();
  }
}
| |
|
// Loads one strip of blockDim.y consecutive columns and OVERWRITES the
// corresponding slots of the two tile buffers.
//   * The strip is taken from row (i1_block + thr_load_row_off), starting at
//     column i2_off, of `input` and `dout` (row-major, row length n2).
//   * The slots start at thr_load_row_off * row_stride + thr_load_col_off.
//   * warp_buf1 <- dout
//   * warp_buf2 <- dout * (input - mean[i1]) * invvar[i1]
// Slots whose row is >= i1_end or whose column is >= n2 are zero-filled, so
// that later cuLoadAddStridedInputs calls can accumulate into every slot.
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
  const int i1_block,
  const int thr_load_row_off,
  const int thr_load_col_off,
  const int i2_off,
  const int row_stride,
  U* warp_buf1,
  U* warp_buf2,
  const T* input,
  const V* dout,
  const int i1_end,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar
  )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      // 64-bit element index: i1*n2 in int overflows for large inputs.
      const size_t load_idx = static_cast<size_t>(i1) * n2 + i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}
| |
|
// Accumulating counterpart of cuLoadWriteStridedInputs. It uses the same
// strip and slot addressing, but ADDS into the tile buffers:
//   warp_buf1 += dout
//   warp_buf2 += dout * (input - mean[i1]) * invvar[i1]
// Slots whose row is >= i1_end or whose column is >= n2 are left unchanged.
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
  const int i1_block,
  const int thr_load_row_off,
  const int thr_load_col_off,
  const int i2_off,
  const int row_stride,
  U* warp_buf1,
  U* warp_buf2,
  const T* input,
  const V* dout,
  const int i1_end,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar
  )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      // 64-bit element index: i1*n2 in int overflows for large inputs.
      const size_t load_idx = static_cast<size_t>(i1) * n2 + i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}
| |
|
// First stage of the gamma/beta gradient. Each block reduces a contiguous
// range of rows for blockDim.x columns and stores one partial sum per column
// in row blockIdx.y of part_grad_beta (sum of dout) and part_grad_gamma (sum
// of dout * (x - mean) * invvar), both laid out as [gridDim.y, n2].
//
// Layout assumptions (from the index math below):
//   * blockDim.x is a power of two.
//   * blockDim.y is a power of two and at least 2.
//   * Dynamic shared memory holds two tiles of
//     (blockDim.y * blockDim.y) x (blockDim.x + 1) elements of U. The extra
//     padding column sets the row stride.
// `epsilon` is accepted but unused.
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
  const V* __restrict__ dout,
  const T* __restrict__ input,
  const int n1,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar,
  U epsilon,
  U* part_grad_gamma,
  U* part_grad_beta)
{
  // Rows are consumed in chunks of blockDim.y^2, one tile per chunk.
  const unsigned int tile_rows = blockDim.y * blockDim.y;
  const int num_chunks = (n1 + tile_rows - 1) / tile_rows;
  const int chunks_per_block = (num_chunks + gridDim.y - 1) / gridDim.y;
  const int row_begin = blockIdx.y * chunks_per_block * tile_rows;
  const int row_limit = (blockIdx.y + 1) * chunks_per_block * tile_rows;
  const int row_end = (row_limit < n1) ? row_limit : n1;

  // Each thread loads blockDim.y consecutive columns of one tile row.
  const int row_stride = blockDim.x + 1;
  const int col_in_tile = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int row_in_tile = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int first_col = blockIdx.x * blockDim.x + col_in_tile;

  SharedMemory<U> shared;
  U* tile_dout = shared.getPointer();
  U* tile_dgamma = tile_dout + tile_rows * row_stride;

  // First chunk overwrites the tiles; every later chunk accumulates.
  cuLoadWriteStridedInputs(row_begin, row_in_tile, col_in_tile, first_col, row_stride,
                           tile_dout, tile_dgamma, input, dout, row_end, n2, mean, invvar);
  for (int chunk_row = row_begin + tile_rows; chunk_row < row_end; chunk_row += tile_rows) {
    cuLoadAddStridedInputs(chunk_row, row_in_tile, col_in_tile, first_col, row_stride,
                           tile_dout, tile_dgamma, input, dout, row_end, n2, mean, invvar);
  }
  __syncthreads();

  // Fold the tile onto rows 0..blockDim.y-1: thread (x, y) sums rows
  // y, y + blockDim.y, y + 2*blockDim.y, ... of column x.
  U acc_dout = U(0);
  U acc_dgamma = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    const int src = (threadIdx.y + k * blockDim.y) * row_stride + threadIdx.x;
    acc_dout += tile_dout[src];
    acc_dgamma += tile_dgamma[src];
  }
  const int own = threadIdx.y * row_stride + threadIdx.x;
  tile_dout[own] = acc_dout;
  tile_dgamma[own] = acc_dgamma;
  __syncthreads();

  // Tree-reduce down to two rows; rows 0 and 1 are combined on output.
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      const int dst = threadIdx.y * row_stride + threadIdx.x;
      const int src = (threadIdx.y + offset) * row_stride + threadIdx.x;
      tile_dout[dst] += tile_dout[src];
      tile_dgamma[dst] += tile_dgamma[src];
    }
    __syncthreads();
  }

  const int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    const unsigned int out_idx = blockIdx.y * n2 + i2;
    part_grad_beta[out_idx] = tile_dout[threadIdx.x] + tile_dout[row_stride + threadIdx.x];
    part_grad_gamma[out_idx] = tile_dgamma[threadIdx.x] + tile_dgamma[row_stride + threadIdx.x];
  }
}
| |
|
// Second stage of the gamma/beta gradient. It sums the part_size partial rows
// of part_grad_gamma / part_grad_beta (each [part_size, n2], row-major) into
// grad_gamma[i2] / grad_beta[i2], with one column per thread in x.
//
// Each threadIdx.y first accumulates its own contiguous run of
// part_size / blockDim.y partial rows. When part_size is not a multiple of
// blockDim.y, the leftover rows are taken one per threadIdx.y. The blockDim.y
// per-thread sums are then combined with a shared-memory tree.
// Assumes blockDim.y is a power of two. Needs blockDim.x * blockDim.y
// elements of U of dynamic shared memory. `n1` is accepted but unused.
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
  const U* part_grad_gamma,
  const U* part_grad_beta,
  const int part_size,
  const int n1,
  const int n2,
  V* grad_gamma,
  V* grad_beta)
{
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  const int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  // Columns past n2 must not touch global memory, but they must still reach
  // every __syncthreads() below. A barrier skipped by part of the block is
  // undefined behaviour, so the column check guards only loads and stores.
  const bool valid_col = i2 < n2;

  U sum_gamma = U(0);
  U sum_beta = U(0);
  if (valid_col) {
    const int num_warp_reductions = part_size / blockDim.y;
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
      sum_beta += part_grad_beta_ptr[warp_offset*n2];
    }
    // Rows left over when part_size % blockDim.y != 0 (none otherwise).
    const int leftover_row = num_warp_reductions * blockDim.y + threadIdx.y;
    if (leftover_row < part_size) {
      sum_gamma += part_grad_gamma[leftover_row * n2 + i2];
      sum_beta += part_grad_beta[leftover_row * n2 + i2];
    }
  }

  // Tree reduction over threadIdx.y. The gamma halves go in buf[0, nbsize3)
  // and the beta halves in buf[nbsize3, 2*nbsize3).
  const int nbsize3 = blockDim.x * blockDim.y / 2;
  for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
    if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
      const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
      buf[write_idx] = sum_gamma;
      buf[write_idx+nbsize3] = sum_beta;
    }
    __syncthreads();
    if (threadIdx.y < offset) {
      const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
      sum_gamma += buf[read_idx];
      sum_beta += buf[read_idx+nbsize3];
    }
    __syncthreads();
  }

  if (threadIdx.y == 0 && valid_col) {
    grad_gamma[i2] = sum_gamma;
    grad_beta[i2] = sum_beta;
  }
}
| |
|
// Backward LayerNorm with respect to the input. Rows i1 are strided over
// blockIdx.y by gridDim.y. For each row, with g_l = dout_l * gamma_l
// (gamma_l := 1 when gamma is NULL) and xhat_l = (x_l - mean) * invvar:
//   sum1 = sum_l g_l
//   sum2 = sum_l g_l * xhat_l
//   dx_l = invvar / n2 * (n2 * g_l - sum1 - xhat_l * sum2)
// The two row sums are reduced across the whole block:
//   * First with XOR shuffles over threadIdx.x. This assumes blockDim.x is a
//     power of two no larger than the warp.
//   * Then through shared memory over threadIdx.y. This assumes blockDim.y is
//     a power of two and, when blockDim.y > 1, needs
//     2 * (blockDim.y/2) * blockDim.x elements of U.
// `epsilon` is accepted but unused.
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
  const V* __restrict__ dout,
  const T* __restrict__ input,
  const int n1,
  const int n2,
  const U* __restrict__ mean,
  const U* __restrict__ invvar,
  U epsilon,
  const V* gamma,
  T* grad_input)
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    // 64-bit row offset: i1 is unsigned 32-bit, so i1*n2 would wrap for
    // inputs with more than 2^32 elements.
    const size_t row_offset = static_cast<size_t>(i1) * n2;
    const T* k_input = input + row_offset;
    const V* k_dout = dout + row_offset;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    // Per-thread partial sums: 4 consecutive elements per step, then a tail.
    if (gamma != NULL) {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // Butterfly all-reduce over threadIdx.x within each warp.
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // Combine the per-threadIdx.y sums through shared memory, then broadcast.
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y !=0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // dx = invvar / n2 * (n2 * g - sum1 - xhat * sum2)
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + row_offset;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // The next row reuses the shared buffer; wait until all reads are done.
    __syncthreads();
  }
}
| |
|
| |
|
| |
|
| |
|
// Host launcher for cuApplyLayerNorm on the current CUDA stream.
// Uses 32x4 blocks, one row per block along grid.y, with grid.y capped at the
// device's maxGridSize[1]; the kernel loops over any remaining rows. An empty
// batch (n1 <= 0) is a no-op. Launch errors are raised through AT_CUDA_CHECK.
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
  V* output,
  U* mean,
  U* invvar,
  const T* input,
  int n1,
  int n2,
  double epsilon,
  const V* gamma,
  const V* beta
  )
{
  // gridDim.y == 0 would be an invalid launch configuration.
  if (n1 <= 0) {
    return;
  }
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32,4,1);
  const uint64_t maxGridY =
    at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  // cuWelfordMuSigma2 needs blockDim.y slots for the (mu, sigma2) pairs plus
  // blockDim.y/2 slots for the counts.
  int nshared =
    threads.y > 1 ?
    threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
    0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
    output,
    mean,
    invvar,
    input,
    n1,n2,
    U(epsilon),
    gamma,beta);
  AT_CUDA_CHECK(cudaGetLastError());
}
| |
|
| |
|
// Forward entry point. It dispatches on the (input dtype, output dtype) pair
// and runs HostApplyLayerNorm. mean and invvar are accessed as float buffers.
// gamma and beta may each be NULL, in which case a NULL pointer is forwarded.
// `normalized_shape` is part of the interface but not used here.
void cuda_layer_norm(
  at::Tensor* output,
  at::Tensor* mean,
  at::Tensor* invvar,
  at::Tensor* input,
  int n1,
  int n2,
#ifdef VERSION_GE_1_1
  at::IntArrayRef normalized_shape,
#else
  at::IntList normalized_shape,
#endif
  at::Tensor* gamma,
  at::Tensor* beta,
  double epsilon)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
    input->scalar_type(),
    output->scalar_type(),
    "cuda_layer_norm_kernel",
    HostApplyLayerNorm(
      output->DATA_PTR<scalar_t_out>(),
      mean->DATA_PTR<float>(),
      invvar->DATA_PTR<float>(),
      input->DATA_PTR<scalar_t_in>(),
      n1, n2,
      epsilon,
      gamma == nullptr ? nullptr : gamma->DATA_PTR<scalar_t_out>(),
      beta == nullptr ? nullptr : beta->DATA_PTR<scalar_t_out>());
  )
}
| |
|
| |
|
// Host launcher for the LayerNorm backward pass on the current CUDA stream.
//   * When both gamma and beta are given (and n2 > 0), grad_gamma and
//     grad_beta are produced in two stages. First, cuComputePartGradGammaBeta
//     writes part_size partial rows into temporary [part_size, n2] tensors.
//     Then cuComputeGradGammaBeta sums those rows. The temporaries are
//     allocated as Float, so U is presumably always float here (TODO
//     confirm).
//   * grad_input is produced by cuComputeGradInput, one row per block along
//     grid.y, capped at maxGridSize[1]; the kernel loops over the remaining
//     rows. It is skipped for an empty batch.
// Only beta's presence is checked here; its values are not read. Launch
// errors are raised through AT_CUDA_CHECK.
template<typename T, typename U, typename V>
void HostLayerNormGradient(
  const V* dout,
  const U* mean,
  const U* invvar,
  at::Tensor* input,
  int n1,
  int n2,
  const V* gamma,
  const V* beta,
  double epsilon,
  T* grad_input,
  V* grad_gamma,
  V* grad_beta
  )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  // n2 == 0 would give a zero-width grid (invalid launch) and there is
  // nothing to reduce in that case.
  if (gamma != NULL && beta != NULL && n2 > 0) {
    const int part_size = 16;
    const dim3 threads2(32,4,1);
    const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
    // Two padded tiles of (threads2.y^2) x (threads2.x + 1) elements of U.
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
      (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
      {part_size,n2}, input->options().dtype(at::ScalarType::Float));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1,n2,
      mean,
      invvar,
      U(epsilon),
      part_grad_gamma.DATA_PTR<U>(),
      part_grad_beta.DATA_PTR<U>());
    AT_CUDA_CHECK(cudaGetLastError());

    const dim3 threads3(32,8,1);
    // Size the grid by the block width this launch actually uses.
    const dim3 blocks3((n2+threads3.x-1)/threads3.x,1,1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
      part_grad_gamma.DATA_PTR<U>(),
      part_grad_beta.DATA_PTR<U>(),
      part_size,
      n1,n2,
      grad_gamma,
      grad_beta);
    AT_CUDA_CHECK(cudaGetLastError());
  }

  // gridDim.y == 0 would be an invalid launch configuration.
  if (n1 <= 0) {
    return;
  }
  const uint64_t maxGridY =
    at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32,4,1);
  int nshared =
    threads1.y > 1 ?
    threads1.y*threads1.x*sizeof(U) :
    0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
    dout,
    input->DATA_PTR<T>(),
    n1,n2,
    mean,
    invvar,
    U(epsilon),
    gamma,
    grad_input);
  AT_CUDA_CHECK(cudaGetLastError());
}
| |
|
| |
|
// Backward entry point. It dispatches on (input dtype, output dtype) and runs
// HostLayerNormGradient. mean and invvar are accessed as float buffers.
//   * gamma may be NULL (no affine weight). grad_gamma and grad_beta are then
//     forwarded as NULL.
//   * beta is forwarded only when both gamma and beta are present.
// `normalized_shape` is part of the interface but not used here.
void cuda_layer_norm_gradient(
  at::Tensor* dout,
  at::Tensor* mean,
  at::Tensor* invvar,
  at::Tensor* input,
  int n1,
  int n2,
#ifdef VERSION_GE_1_1
  at::IntArrayRef normalized_shape,
#else
  at::IntList normalized_shape,
#endif
  at::Tensor* gamma,
  at::Tensor* beta,
  double epsilon,
  at::Tensor* grad_input,
  at::Tensor* grad_gamma,
  at::Tensor* grad_beta)
{
  using namespace at;
  // gamma may be NULL (see the guards below), so it cannot be dereferenced
  // unconditionally to choose the output dtype. dout is also read as
  // scalar_t_out, so its dtype is a valid choice when there is no gamma.
  const auto out_scalar_type =
    gamma != NULL ? gamma->scalar_type() : dout->scalar_type();
  // Checking beta itself avoids dereferencing a NULL beta when only gamma is
  // supplied; HostLayerNormGradient needs both for the gamma/beta gradients.
  const bool has_gamma_and_beta = gamma != NULL && beta != NULL;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
    input->scalar_type(), out_scalar_type,
    "cuda_layer_norm_gradient_kernel",
    HostLayerNormGradient(
      dout->DATA_PTR<scalar_t_out>(),
      mean->DATA_PTR<float>(),
      invvar->DATA_PTR<float>(),
      input,
      n1,n2,
      gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
      has_gamma_and_beta ? beta->DATA_PTR<scalar_t_out>() : NULL,
      epsilon,
      grad_input->DATA_PTR<scalar_t_in>(),
      gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
      gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
  )
}
| |
|