|
|
#pragma once |
|
|
|
|
|
#include <assert.h> |
|
|
#include <cfloat> |
|
|
#include <limits> |
|
|
#include <stdint.h> |
|
|
#include <cuda_fp16.h> |
|
|
#include <c10/macros/Macros.h> |
|
|
|
|
|
#include <ATen/cuda/DeviceUtils.cuh> |
|
|
|
|
|
namespace { |
|
|
|
|
|
// Returns the smallest n >= 0 such that (1 << n) >= value; any value <= 1
// (including zero and negatives) maps to 0.
//
// The power of two is formed in 64-bit arithmetic. With a plain `int` shift,
// inputs above 2^30 would shift into/past the sign bit (undefined behaviour,
// and in practice a possible infinite loop). In 64-bit, every int input
// terminates; INT_MAX maps to 31.
int log2_ceil(int value) {
    int log2_value = 0;
    while ((int64_t{1} << log2_value) < value) ++log2_value;
    return log2_value;
}
|
|
|
|
|
// Device-side binary functor that adds its two arguments.
// Used as the ReduceOp of warp_reduce to build per-row sums.
template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T lhs, T rhs) const {
        const T total = lhs + rhs;
        return total;
    }
};
|
|
|
|
|
// Device-side binary functor returning the larger argument, using only
// operator<. When `lhs < rhs` is false (ties, or either argument is NaN),
// `lhs` is returned. Used as the ReduceOp of warp_reduce for per-row maxima.
template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T lhs, T rhs) const {
        if (lhs < rhs) {
            return rhs;
        }
        return lhs;
    }
};
|
|
|
|
|
// Warp-wide all-reduce of WARP_BATCH independent values held in `sum`.
//
// Each step exchanges values with the lane whose index differs in one bit
// (XOR shuffle, offsets WARP_SIZE/2, ..., 1) and combines them with ReduceOp.
// The result is written back into `sum` in place. Assumes WARP_SIZE is a power
// of two (the softmax kernels in this file derive it from 1 << log2_elements)
// and that every lane of the logical warp calls this function.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> combine;
    #pragma unroll
    for (int lane_mask = WARP_SIZE >> 1; lane_mask >= 1; lane_mask >>= 1) {
        #pragma unroll
        for (int row = 0; row < WARP_BATCH; ++row) {
            const acc_t partner = WARP_SHFL_XOR(sum[row], lane_mask, WARP_SIZE);
            sum[row] = combine(sum[row], partner);
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Softmax / log-softmax forward over short rows, one warp per WARP_BATCH rows.
//
// Launch layout (as set up by dispatch_softmax_forward): blockDim.x lanes per
// warp (expected to equal WARP_SIZE) and blockDim.y warps per block. Warp
// threadIdx.y of block blockIdx.x owns rows starting at first_batch. Lane
// `local_idx` owns elements local_idx + it * WARP_SIZE, it in
// [0, WARP_ITERATIONS), held in registers. Reductions use warp shuffles only;
// no shared memory is used.
//
// Template parameters:
//   log2_elements  - ceil(log2(element_count)); fixes the per-lane tile size.
//   is_log_softmax - write log-softmax instead of softmax.
//   is_masked      - honour `mask`: a true mask entry excludes that element.
//
// Parameters:
//   dst, src        - output / input rows.
//   batch_size      - total number of rows.
//   stride          - used to locate the warp's first row (first_batch * stride).
//   element_count   - valid elements per row (<= 1 << log2_elements).
//   mask            - read only when is_masked; may be nullptr otherwise.
//   head_chunk_size, is_transformer_mask - when is_transformer_mask is set,
//                     the mask row is ((first_batch * stride) / head_chunk_size)
//                     and that single mask row is used for every row of the warp.
//
// Softmax rows whose accumulated sum is 0 (e.g. every element masked) are
// written as NaN. Log-softmax does not special-case that situation.
//
// NOTE(review): rows inside a warp are addressed with i * element_count while
// the warp's first row uses first_batch * stride, so this implicitly assumes
// stride == element_count — confirm with callers.
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false)
{
    // These must agree with warp_size / batches_per_warp computed by the host dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size need not be a multiple of WARP_BATCH: the last warp may own
    // fewer rows than WARP_BATCH (or none at all).
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    int local_idx = threadIdx.x;
    int idx_offset = first_batch * stride + local_idx;

    src += idx_offset;
    dst += idx_offset;

    // Only offset the mask pointer when it will be read: without is_masked it
    // may be nullptr, and offsetting a null pointer is undefined behaviour.
    if (is_masked) {
        if (is_transformer_mask) {
            mask += ((first_batch * stride) / head_chunk_size) * stride + local_idx;
        } else {
            mask += idx_offset;
        }
    }

    // Load the warp's rows into registers. Out-of-range slots get -inf so they
    // never raise the row maximum.
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                elements[i][it] = src[i*element_count+it*WARP_SIZE];
            } else {
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // Per-row maximum, subtracted before exponentiation. In the masked case only
    // in-range, unmasked elements participate; a lane with none contributes -inf.
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        bool is_meaningful_max = false;
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (is_masked) {
                int idx = it*WARP_SIZE;
                if ((idx + local_idx) < batch_element_count) {
                    // A transformer mask row is shared by all rows of the warp.
                    if (!is_transformer_mask) {
                        idx += i*element_count;
                    }
                    if (!mask[idx]) {
                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
                        is_meaningful_max = true;
                    }
                }
            } else {
                max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
            }
        }
        if (is_masked) {
            if (!is_meaningful_max) {
                max_value[i] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // Sum of exp(x - max). For softmax the exponentials are kept in `elements`.
    // In the masked path, excluded slots (masked or out of range) are set to 0.
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (!is_masked) {
                if (is_log_softmax) {
                    sum[i] += std::exp(elements[i][it] - max_value[i]);
                } else {
                    elements[i][it] = std::exp(elements[i][it] - max_value[i]);
                    sum[i] += elements[i][it];
                }
            } else {
                int idx = it*WARP_SIZE;
                bool valid = (idx + local_idx) < batch_element_count;
                if (!is_transformer_mask) {
                    idx += i*element_count;
                }
                if (valid) {
                    if (!mask[idx]) {
                        if (is_log_softmax) {
                            sum[i] += std::exp(elements[i][it] - max_value[i]);
                        } else {
                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
                            sum[i] += elements[i][it];
                        }
                    } else {
                        if (!is_log_softmax) {
                            // A masked element behaves like -inf, and exp(-inf) == 0.
                            elements[i][it] = acc_t(0);
                        }
                    }
                } else {
                    if (!is_log_softmax) {
                        elements[i][it] = acc_t(0);
                    }
                }
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // Write results, only for rows this warp actually owns.
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        if (is_log_softmax) sum[i] = std::log(sum[i]);
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                if (is_log_softmax) {
                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] - max_value[i] - sum[i];
                } else if (sum[i] == 0) {
                    // Nothing contributed to this row (e.g. fully masked): emit NaN.
                    dst[i*element_count+it*WARP_SIZE] = std::numeric_limits<acc_t>::quiet_NaN();
                } else {
                    dst[i*element_count+it*WARP_SIZE] = elements[i][it] / sum[i];
                }
            } else {
                break;
            }
        }
    }
}
|
|
|
|
|
// Softmax / log-softmax backward over short rows, one warp per WARP_BATCH rows.
//
// Uses the same launch layout as softmax_warp_forward: blockDim.x lanes
// (== WARP_SIZE), blockDim.y warps per block, and WARP_BATCH rows per warp.
// For each row it reduces sum = sum(grad) over non-masked elements and writes:
//   log-softmax: gradInput = grad - exp(output) * sum
//   softmax:     gradInput = grad - output * sum
// Masked elements (mask true) are excluded from the sum and get gradInput = 0.
//
// NOTE(review): the softmax formula equals the true softmax gradient only if
// the caller passes grad * output as `grad`. Confirm this against the callers.
// Rows inside a warp are addressed with i * element_count while the first row
// uses first_batch * stride, which assumes stride == element_count.
template <typename input_t, typename output_t, typename acc_t, int log2_elements, bool is_log_softmax, bool is_masked>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr)
{
    // These must agree with warp_size / batches_per_warp computed by the host dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // batch_size need not be a multiple of WARP_BATCH: the last warp may own fewer rows.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    int local_idx = threadIdx.x % WARP_SIZE;

    // First element handled by this thread.
    int thread_offset = first_batch * stride + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;
    if (is_masked) {
        mask += thread_offset;
    }

    // Load into registers; out-of-range slots are zero-filled.
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS];
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = grad[i*element_count+it*WARP_SIZE];
                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
            } else {
                grad_reg[i][it] = acc_t(0);
                output_reg[i][it] = acc_t(0);
            }
        }
    }

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            // Only consult the mask for in-range slots; reading it for padding
            // slots or rows past local_batches would go out of bounds. Skipping
            // those slots does not change the sum because grad_reg is 0 there.
            bool included = !is_masked ||
                (element_index < batch_element_count && !mask[i*element_count+it*WARP_SIZE]);
            if (included) {
                sum[i] += grad_reg[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // Write gradients, only for rows this warp actually owns.
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                if (is_masked && mask[i*element_count+it*WARP_SIZE]) {
                    gradInput[i*element_count+it*WARP_SIZE] = 0;
                }
                else if (is_log_softmax) {
                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
                } else {
                    gradInput[i*element_count+it*WARP_SIZE] = (grad_reg[i][it] - output_reg[i][it] * sum[i]);
                }
            }
        }
    }
}
|
|
|
|
|
} |
|
|
|
|
|
// Host-side launcher for softmax_warp_forward on rows of at most 1024 elements.
//
// Instantiates the kernel for log2_ceil(softmax_elements) and launches it on the
// current CUDA stream, then checks the launch. Blocks have 128 threads laid out
// as (warp_size, 128 / warp_size). Here warp_size is
// min(next power of two >= softmax_elements, at::cuda::warp_size()). Each warp
// handles 2 rows when that power of two is <= 128, otherwise 1 row.
//
// Parameters:
//   dst, src                - output / input buffers (presumably device memory).
//   softmax_elements        - elements per row; must be in [0, 1024].
//   softmax_elements_stride - forwarded as the kernel's `stride`.
//   batch_count             - number of rows.
//   mask, chunk_size, is_transformer_mask - forwarded to the kernel. A non-null
//                             mask is required when is_masked. chunk_size must
//                             be positive when a transformer mask is used.
// Returns without launching when there are no rows or no elements.
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr, int chunk_size = -1, bool is_transformer_mask = false)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
    // Nothing to compute. With batch_count == 0 the grid below would have zero
    // blocks, which CUDA rejects as an invalid launch configuration.
    if (softmax_elements == 0 || batch_count == 0) {
        return;
    }
    // Fail fast on the host instead of faulting asynchronously in the kernel.
    TORCH_INTERNAL_ASSERT(!is_masked || mask != nullptr, "masked softmax requires a mask");
    TORCH_INTERNAL_ASSERT(!(is_masked && is_transformer_mask) || chunk_size > 0,
                          "transformer mask requires a positive chunk_size");

    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // Must match WARP_SIZE inside softmax_warp_forward.
    int warp_size = at::cuda::warp_size();
    warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;

    // Must match WARP_BATCH inside softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    switch (log2_elements) {
        #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \
        softmax_warp_forward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, \
                src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \
        C10_CUDA_KERNEL_LAUNCH_CHECK(); \
        break;

        LAUNCH_SOFTMAX_WARP_FORWARD(0);
        LAUNCH_SOFTMAX_WARP_FORWARD(1);
        LAUNCH_SOFTMAX_WARP_FORWARD(2);
        LAUNCH_SOFTMAX_WARP_FORWARD(3);
        LAUNCH_SOFTMAX_WARP_FORWARD(4);
        LAUNCH_SOFTMAX_WARP_FORWARD(5);
        LAUNCH_SOFTMAX_WARP_FORWARD(6);
        LAUNCH_SOFTMAX_WARP_FORWARD(7);
        LAUNCH_SOFTMAX_WARP_FORWARD(8);
        LAUNCH_SOFTMAX_WARP_FORWARD(9);
        LAUNCH_SOFTMAX_WARP_FORWARD(10);
        default:
            // Unreachable: the assert above bounds log2_elements to [0, 10].
            break;
    }
}
|
|
|
|
|
// Host-side launcher for softmax_warp_backward on rows of at most 1024 elements.
//
// Uses the same launch geometry as the forward dispatcher. Blocks have 128
// threads laid out as (warp_size, 128 / warp_size), with warp_size =
// min(next power of two >= softmax_elements, at::cuda::warp_size()). Each warp
// handles 2 rows when that power of two is <= 128, otherwise 1 row. The kernel
// runs on the current CUDA stream and the launch is checked.
//
// Parameters:
//   grad_input, grad, output - buffers forwarded to the kernel (presumably
//                              device memory).
//   softmax_elements         - elements per row; must be in [0, 1024].
//   softmax_elements_stride  - forwarded as the kernel's `stride`.
//   batch_count              - number of rows.
//   mask                     - required (non-null) when is_masked.
// Returns without launching when there are no rows or no elements.
template<typename input_t, typename output_t, typename acc_t, bool is_log_softmax, bool is_masked>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output, int softmax_elements, int softmax_elements_stride, int batch_count, const bool *mask = nullptr)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 1024 );
    // Nothing to compute. With batch_count == 0 the grid below would have zero
    // blocks, which CUDA rejects as an invalid launch configuration.
    if (softmax_elements == 0 || batch_count == 0) {
        return;
    }
    // Fail fast on the host instead of faulting asynchronously in the kernel.
    TORCH_INTERNAL_ASSERT(!is_masked || mask != nullptr, "masked softmax backward requires a mask");

    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;

    // Must match WARP_SIZE inside softmax_warp_backward.
    int warp_size = at::cuda::warp_size();
    warp_size = (next_power_of_two < warp_size) ? next_power_of_two : warp_size;

    // Must match WARP_BATCH inside softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    switch (log2_elements) {
        #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \
        softmax_warp_backward<input_t, output_t, acc_t, L2E, is_log_softmax, is_masked> \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>> \
            (grad_input, grad, output, batch_count, softmax_elements_stride, \
             softmax_elements, mask); \
        C10_CUDA_KERNEL_LAUNCH_CHECK(); \
        break;

        LAUNCH_SOFTMAX_WARP_BACKWARD(0);
        LAUNCH_SOFTMAX_WARP_BACKWARD(1);
        LAUNCH_SOFTMAX_WARP_BACKWARD(2);
        LAUNCH_SOFTMAX_WARP_BACKWARD(3);
        LAUNCH_SOFTMAX_WARP_BACKWARD(4);
        LAUNCH_SOFTMAX_WARP_BACKWARD(5);
        LAUNCH_SOFTMAX_WARP_BACKWARD(6);
        LAUNCH_SOFTMAX_WARP_BACKWARD(7);
        LAUNCH_SOFTMAX_WARP_BACKWARD(8);
        LAUNCH_SOFTMAX_WARP_BACKWARD(9);
        LAUNCH_SOFTMAX_WARP_BACKWARD(10);
        default:
            // Unreachable: the assert above bounds log2_elements to [0, 10].
            break;
    }
}
|
|
|