| |
| |
| |
| |
| |
| |
| |
| |
#include <ATen/core/TensorBase.h>
#include <ATen/core/TensorBody.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>
#include <cuda_fp16.h>

#include <cstddef>
#include <cstdint>
#include <optional>
|
|
| namespace { |
|
|
// Number of indices produced for every row (the second dimension of the output tensors).
constexpr int TopK = 2048;
// Block size used by every kernel launch and __launch_bounds__ in this section.
constexpr int kThreadsPerBlock = 1024;

// Bytes of dynamic shared memory passed at launch. fast_topk_cuda_tl views it as two int
// buffers of candidate indices (kSmem / (2 * sizeof(int)) entries each).
#ifdef USE_ROCM
// On ROCm the size can be overridden at build time via SGL_TOPK_DYNAMIC_SMEM_BYTES;
// otherwise 48 KiB is used.
#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES
constexpr size_t kSmem = static_cast<size_t>(SGL_TOPK_DYNAMIC_SMEM_BYTES);
#else
constexpr size_t kSmem = 48 * 1024;
#endif
#else
// CUDA: 32 KiB (8192 uint32 slots, i.e. two buffers of 4096 candidate indices), which stays
// below the 48 KiB dynamic shared memory available without an explicit opt-in.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t);
#endif
|
|
// Kernel arguments shared by all top-k kernels (passed to them by value).
// Field order is significant: get_params fills it with designated initializers and the
// kernels unpack it with structured bindings, both in declaration order.
struct FastTopKParams {
  // Score matrix; row b begins at input + b * input_stride.
  const float* __restrict__ input;
  // Optional per-row offset added to every score index, or nullptr (treated as 0).
  const int32_t* __restrict__ row_starts;
  // Output of topk_kernel, TopK entries per row; nullptr when get_params is called without
  // an indices tensor (the transform kernels write to their own outputs instead).
  int32_t* __restrict__ indices;
  // Number of scores to consider per row, counted from that row's row_start.
  int32_t* __restrict__ lengths;
  // Elements between consecutive rows of `input`.
  int64_t input_stride;
};
|
|
| |
// Trivial selection used when a row holds at most TopK entries: every position is kept, so
// output slot s receives index s while s < length and -1 afterwards.
// `score` is not read on this path; it keeps the signature parallel to the fast path.
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
  for (int slot = static_cast<int>(threadIdx.x); slot < TopK; slot += kThreadsPerBlock) {
    int32_t value = -1;
    if (slot < length) {
      value = slot;
    }
    indice[slot] = value;
  }
}
|
|
| |
// Trivial page-table transform for rows with at most TopK entries: slot s copies
// src_page_table[s] while s < length and is set to -1 afterwards.
// NOTE: the comparison is done in unsigned arithmetic, so a negative `length` behaves like a
// very large one. `score` is not read on this path.
__device__ void naive_topk_transform(
    const float* __restrict__ score,
    int32_t length,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table) {
  const unsigned limit = static_cast<unsigned>(length);
  for (unsigned slot = threadIdx.x; slot < static_cast<unsigned>(TopK); slot += kThreadsPerBlock) {
    int32_t value = -1;
    if (slot < limit) {
      value = src_page_table[slot];
    }
    dst_page_table[slot] = value;
  }
}
|
|
| |
// Trivial ragged transform for rows with at most TopK entries: slot s receives s + offset
// while s < length and -1 afterwards. The comparison is unsigned (a negative `length` acts
// like a very large one). `score` is not read on this path.
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  const unsigned limit = static_cast<unsigned>(length);
  for (unsigned slot = threadIdx.x; slot < static_cast<unsigned>(TopK); slot += kThreadsPerBlock) {
    int32_t value = -1;
    if (slot < limit) {
      value = static_cast<int32_t>(slot) + offset;
    }
    topk_indices_ragged[slot] = value;
  }
}
|
|
// Coarse 8-bit bucket for the radix select. The value is rounded to fp16 and its bit pattern
// is remapped so that unsigned ordering follows float ordering: negative values have every
// bit inverted, non-negative values get the sign bit set. The high byte is returned.
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
  const uint16_t raw = __half_as_ushort(__float2half_rn(x));
  const bool negative = (raw & 0x8000) != 0;
  // XOR with 0xFFFF inverts all bits; XOR with 0x8000 sets the (clear) sign bit.
  const unsigned flip = negative ? 0xFFFFu : 0x8000u;
  return static_cast<uint8_t>((raw ^ flip) >> 8);
}
|
|
// Full 32-bit key whose unsigned ordering follows float ordering: negative floats have all
// bits inverted, non-negative floats get the sign bit set.
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
  const uint32_t raw = __float_as_uint(x);
  const bool negative = (raw >> 31) != 0u;
  // For a non-negative value bit 31 is clear, so XOR with 0x80000000 equals OR-ing it in.
  const uint32_t flip = negative ? 0xFFFFFFFFu : 0x80000000u;
  return raw ^ flip;
}
|
|
// Block-wide radix select of the TopK largest scores in input[row_start, row_start + length).
//
// Writes TopK indices (relative to row_start, in no particular order since slots are claimed
// with atomics) into `index`. Expectations visible from the code: the block has BLOCK_SIZE
// (1024) threads, length > TopK (callers use the naive path otherwise), and kSmem bytes of
// dynamic shared memory are provided, viewed as two candidate buffers of SMEM_INPUT_SIZE ints.
//
// Algorithm:
//   1. Coarse pass: histogram an 8-bit key derived from the fp16 rounding of every score,
//      turn it into a suffix sum and locate the bin that straddles the TopK boundary.
//      Elements in higher bins are emitted directly; threshold-bin elements become candidates.
//   2. Up to four refinement rounds over the 32-bit order-preserving key (8 bits per round,
//      most significant byte first) repeat this on the candidates, ping-ponging between the
//      two shared buffers.
//   3. In the last round, the remaining slots are filled with elements whose key equals the
//      threshold (exact ties).
//
// NOTE: candidates beyond SMEM_INPUT_SIZE per round are counted but not stored, so when a
// threshold bin is very crowded the choice among that bin's elements is approximate.
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) {
  // Slots still to be filled; shrinks as higher bins are emitted.
  int topk = TopK;
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  // Capacity of each of the two candidate buffers carved out of the dynamic shared memory.
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  // Loops stride by BLOCK_SIZE, so it must match the block size the kernels are launched with.
  static_assert(BLOCK_SIZE == kThreadsPerBlock, "BLOCK_SIZE must match kThreadsPerBlock");
  // Zeroing uses one thread per histogram entry, including the sentinel at index RADIX.
  static_assert(BLOCK_SIZE >= RADIX + 1, "need one thread per histogram entry");
  // The threshold search in every round needs strictly more *stored* candidates than `topk`
  // (which can be as large as TopK); otherwise no bin qualifies, a stale threshold is reused
  // and output slots are left unwritten. Guards small SGL_TOPK_DYNAMIC_SMEM_BYTES settings.
  static_assert(SMEM_INPUT_SIZE > static_cast<size_t>(TopK), "dynamic shared memory too small for TopK candidates");

  // Two histogram buffers that run_cumsum ping-pongs between. Entry RADIX of buffer 0 is
  // zeroed and never written by the scan, so s_histogram[tx + 1] is 0 for the top bin.
  alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
  // Next free output slot for elements strictly above the threshold.
  alignas(128) __shared__ int s_counter;
  alignas(128) __shared__ int s_threshold_bin_id;
  // Candidate counts for the two buffers (may exceed SMEM_INPUT_SIZE; see NOTE above).
  alignas(128) __shared__ int s_num_input[2];

  auto& s_histogram = s_histogram_buf[0];
  // Dynamic shared memory viewed as two candidate-index buffers.
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  const int tx = threadIdx.x;

  // ---- Stage 1: coarse 8-bit histogram over the whole row ----
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();

  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx + row_start]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // Inclusive suffix sum over RADIX bins in log2(RADIX) = 8 doubling steps. Afterwards
  // s_histogram[b] is the number of elements in bins >= b; the even step count leaves the
  // result in buffer 0.
  const auto run_cumsum = [&] {
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      static_assert(1 << 8 == RADIX);
      if (C10_LIKELY(tx < RADIX)) {
        const auto j = 1 << i;
        const auto k = i & 1;
        auto value = s_histogram_buf[k][tx];
        if (tx < RADIX - j) {
          value += s_histogram_buf[k][tx + j];
        }
        s_histogram_buf[k ^ 1][tx] = value;
      }
      __syncthreads();
    }
  };

  run_cumsum();
  // Suffix sums are non-increasing, so exactly one bin has count(>= bin) > topk while
  // count(> bin) <= topk.
  if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
    s_threshold_bin_id = tx;
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  // Slots that must still come from inside the threshold bin.
  topk -= s_histogram[threshold_bin + 1];

  if (topk == 0) {
    // Bins strictly above the threshold supply exactly TopK elements.
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto bin = static_cast<int>(convert_to_uint8(input[idx + row_start]));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      }
    }
    __syncthreads();
    return;
  } else {
    // Every thread has read s_histogram above; only now may it be reused.
    __syncthreads();
    if (tx < RADIX + 1) {
      s_histogram[tx] = 0;
    }
    __syncthreads();

    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto raw_input = input[idx + row_start];
      const auto bin = static_cast<int>(convert_to_uint8(raw_input));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin == threshold_bin) {
        const auto pos = ::atomicAdd(&s_num_input[0], 1);
        // Candidates past the buffer capacity are counted but neither stored nor histogrammed.
        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
          s_input_idx[0][pos] = idx;
          const auto bin = convert_to_uint32(raw_input);
          const auto sub_bin = (bin >> 24) & 0xFF;
          ::atomicAdd(&s_histogram[sub_bin], 1);
        }
      }
    }
    __syncthreads();
  }

  // ---- Stage 2: refine the candidates 8 bits at a time over the 32-bit key ----
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    // Tie slots still open in the final round; counted down atomically to hand them out.
    __shared__ int s_last_remain;
    const auto r_idx = round % 2;

    // Only the first SMEM_INPUT_SIZE candidates were actually stored.
    const auto _raw_num_input = s_num_input[r_idx];
    const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);

    run_cumsum();
    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
      s_threshold_bin_id = tx;
      s_num_input[r_idx ^ 1] = 0;
      s_last_remain = topk - s_histogram[tx + 1];
    }
    __syncthreads();

    const auto threshold_bin = s_threshold_bin_id;
    topk -= s_histogram[threshold_bin + 1];

    if (topk == 0) {
      // Candidates strictly above this round's threshold complete the selection.
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        }
      }
      __syncthreads();
      break;
    } else {
      __syncthreads();
      if (tx < RADIX + 1) {
        s_histogram[tx] = 0;
      }
      __syncthreads();
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto raw_input = input[idx + row_start];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else if (bin == threshold_bin) {
          if (round == 3) {
            // All 32 key bits tie. atomicAdd returns the old value, so the slots
            // TopK - s_last_remain .. TopK - 1 are each handed out exactly once.
            const auto pos = ::atomicAdd(&s_last_remain, -1);
            if (pos > 0) {
              index[TopK - pos] = idx;
            }
          } else {
            const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
            if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
              // Queue for the next round and histogram its next-lower key byte.
              s_input_idx[r_idx ^ 1][pos] = idx;
              const auto bin = convert_to_uint32(raw_input);
              const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
              ::atomicAdd(&s_histogram[sub_bin], 1);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}
|
|
// One block per row: writes TopK indices (relative to the row's row_start) for row blockIdx.x
// into params.indices + blockIdx.x * TopK. Rows with at most TopK entries take the trivial path.
__global__ __launch_bounds__(kThreadsPerBlock)
void topk_kernel(const FastTopKParams params) {
  const uint64_t row = blockIdx.x;
  const int32_t row_len = params.lengths[row];
  const int32_t* const starts = params.row_starts;
  const int row_offset = (starts != nullptr) ? starts[row] : 0;
  int32_t* const out = params.indices + row * TopK;
  const float* const scores = params.input + row * params.input_stride;

  if (row_len > TopK) {
    fast_topk_cuda_tl(scores, out, row_offset, row_len);
  } else {
    naive_topk_cuda(scores, out, row_len);
  }
}
|
|
// Decode variant of the page-table transform. There is one block per row, and score row b is
// paired with row b of src_page_table. The TopK selected positions are translated through that
// page-table row into dst_page_table + b * TopK, with -1 padding on the short path.
// row_starts is not consulted: every row starts at offset 0.
__global__ __launch_bounds__(kThreadsPerBlock)
void topk_transform_decode_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride) {
  const uint64_t row = blockIdx.x;
  const int32_t row_len = params.lengths[row];
  const float* const scores = params.input + row * params.input_stride;
  const int32_t* const src_row = src_page_table + row * src_stride;
  int32_t* const dst_row = dst_page_table + row * TopK;

  if (row_len > TopK) {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(scores, s_indices, 0, row_len);

    // Each thread translates exactly two slots (tid and tid + kThreadsPerBlock).
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    for (int slot = threadIdx.x; slot < TopK; slot += kThreadsPerBlock) {
      dst_row[slot] = src_row[s_indices[slot]];
    }
  } else {
    naive_topk_transform(scores, row_len, dst_row, src_row);
  }
}
|
|
// Prefill variant of the page-table transform, with one block per query row b.
// The page-table row is the request i with cu_seqlens_q[i] <= b < cu_seqlens_q[i + 1]. The TopK
// selected positions of score row b are translated through src_page_table + i * src_stride
// into dst_page_table + b * TopK, with -1 padding on the short path.
//
// If no request covers b (cu_seqlens_q does not span every row), the output row is filled
// with -1 instead of dereferencing an unset page-table pointer.
__global__ __launch_bounds__(kThreadsPerBlock)
void topk_transform_prefill_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride,
    const int32_t* __restrict__ cu_seqlens_q,
    const int64_t prefill_bs) {
  const auto& [input, row_starts, _, lengths, input_stride] = params;
  const auto bid = static_cast<uint64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto length = lengths[bid];
  const auto row_start = row_starts == nullptr ? 0 : row_starts[bid];
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  // Shared variables are uninitialized, so reset the pointer before searching; otherwise an
  // uncovered row would read a garbage pointer.
  __shared__ const int32_t* s_src_page_entry;
  if (tid == 0) {
    s_src_page_entry = nullptr;
  }
  __syncthreads();

  // Each thread scans a strided subset of requests (a single one when prefill_bs fits in the block).
  for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
    if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
      s_src_page_entry = src_page_table + i * src_stride;
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;

  // Every thread reads the same shared value after the barrier, so this early return is taken
  // by the whole block and cannot strand threads at a later __syncthreads().
  if (src_page_entry == nullptr) {
    for (int i = tid; i < TopK; i += kThreadsPerBlock) {
      dst_page_entry[i] = -1;
    }
    return;
  }

  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(score, s_indices, row_start, length);

    // Each thread translates exactly two slots.
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
|
|
// Ragged variant, with one block per row b. The TopK selected positions (relative to the row's
// row_start) are shifted by topk_indices_offset[b] and written to
// topk_indices_ragged + b * TopK. The short path pads with -1.
__global__ __launch_bounds__(kThreadsPerBlock)
void topk_transform_prefill_ragged_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ topk_indices_ragged,
    const int32_t* __restrict__ topk_indices_offset) {
  const uint64_t row = blockIdx.x;
  const int32_t* const starts = params.row_starts;
  const int row_offset = (starts != nullptr) ? starts[row] : 0;
  const int32_t row_len = params.lengths[row];
  const int32_t index_base = topk_indices_offset[row];
  const float* const scores = params.input + row * params.input_stride;
  int32_t* const out = topk_indices_ragged + row * TopK;

  if (row_len > TopK) {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(scores, s_indices, row_offset, row_len);

    // Each thread shifts exactly two slots (tid and tid + kThreadsPerBlock).
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    for (int slot = threadIdx.x; slot < TopK; slot += kThreadsPerBlock) {
      out[slot] = s_indices[slot] + index_base;
    }
  } else {
    naive_topk_transform_ragged(scores, row_len, out, index_base);
  }
}
|
|
// Validates the host-side tensors and packs their data pointers into a FastTopKParams.
//
// score:          2-D float tensor with unit stride along dim 1 (rows may be padded).
// lengths:        1-D contiguous int32 tensor with one entry per row.
// row_starts_opt: optional 1-D contiguous int32 tensor with one entry per row.
// indices_opt:    optional contiguous int32 output of shape [B, TopK].
// Throws (TORCH_CHECK) on any violated precondition; data_ptr<T>() additionally enforces dtypes.
auto get_params(
    const at::Tensor& score,
    const at::Tensor& lengths,
    std::optional<at::Tensor> row_starts_opt = std::nullopt,
    std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
  // Check the rank before calling size(0), so a 0-d tensor is reported by this check.
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  const auto B = score.size(0);

  const int32_t* row_starts_data_ptr = nullptr;
  if (row_starts_opt.has_value()) {
    const auto& row_starts = row_starts_opt.value();
    TORCH_CHECK(row_starts.dim() == 1);
    TORCH_CHECK(row_starts.size(0) == B);
    // The kernels read row_starts[blockIdx.x] directly, so a strided view would be misread.
    TORCH_CHECK(row_starts.is_contiguous(), "row_starts must be contiguous");
    row_starts_data_ptr = row_starts.data_ptr<int32_t>();
  }

  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);

  int32_t* indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr<int32_t>();
  }

  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .row_starts = row_starts_data_ptr,
      .indices = indices_data_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
  };
}
|
|
// Raises the dynamic shared memory limit of kernel `f` to `max_dynamic_smem` bytes.
// The attribute call runs once per template instantiation (function-local static initializer).
// Its cached status is checked on every call, so a failure is reported each time.
// NOTE(review): being process-wide, this runs only on whichever device is current at the
// first call; confirm whether multi-GPU processes need it per device.
template <auto* f, size_t max_dynamic_smem>
void setup_kernel_smem_once() {
  static const auto status = [] {
#ifdef USE_ROCM
    // The ROCm build passes the kernel as an untyped function pointer.
    const void* entry = reinterpret_cast<const void*>(f);
#else
    auto* entry = f;
#endif
    return ::cudaFuncSetAttribute(entry, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  }();
  TORCH_CHECK(status == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(status));
}
|
|
| } |
|
|
// Host-side precondition: throws (via TORCH_CHECK) unless tensor `x` lives on a CUDA device;
// the message names the offending argument.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
|
// Host entry point for plain top-k selection.
//
// For each of the B rows of `score`, topk_kernel writes TopK int32 indices (relative to the
// row's start offset) into `indices` ([B, TopK], contiguous). A row with at most TopK entries
// receives 0..length-1 followed by -1 padding. The launch runs on the current stream of the
// device that holds `score`. An empty batch (B == 0) is a no-op.
void fast_topk_interface(
    const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional<at::Tensor> row_starts_opt) {
  CHECK_CUDA(score);
  CHECK_CUDA(indices);
  if (row_starts_opt.has_value()) {
    CHECK_CUDA(row_starts_opt.value());
  }
  CHECK_CUDA(lengths);
  const auto params = get_params(score, lengths, row_starts_opt, indices);
  const auto B = score.size(0);
  // A zero-block grid is rejected as an invalid launch configuration; nothing to compute.
  if (B == 0) {
    return;
  }

  // Make the inputs' device current so the stream lookup and launch target that GPU.
  const c10::cuda::CUDAGuard device_guard(score.device());
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel, kSmem>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|
|
// Host entry point for top-k selection translated through a page table.
//
// dst_page_table ([B, TopK], contiguous int32) receives, for each score row, the selected
// positions mapped through a row of src_page_table ([prefill_bs, *], unit stride along dim 1).
// cu_seqlens_q has prefill_bs + 1 entries. The decode kernel is used when there are no
// row_starts and one page-table row per score row; otherwise the prefill kernel locates each
// row's request via cu_seqlens_q. An empty batch (B == 0) is a no-op.
void fast_topk_transform_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& dst_page_table,
    const at::Tensor& src_page_table,
    const at::Tensor& cu_seqlens_q,
    std::optional<at::Tensor> row_starts_opt) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(dst_page_table);
  CHECK_CUDA(src_page_table);
  CHECK_CUDA(cu_seqlens_q);
  if (row_starts_opt.has_value()) {
    CHECK_CUDA(row_starts_opt.value());
  }
  const auto params = get_params(score, lengths, row_starts_opt);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
  const auto prefill_bs = cu_seqlens_q.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B);

  // A zero-block grid is rejected as an invalid launch configuration; nothing to compute.
  if (B == 0) {
    return;
  }

  // Make the inputs' device current so the stream lookup and launch target that GPU.
  const c10::cuda::CUDAGuard device_guard(score.device());
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // Decode: every score row maps 1:1 to a page-table row and starts at offset 0.
  const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B;
  if (is_decode) {
    setup_kernel_smem_once<topk_transform_decode_kernel, kSmem>();
    topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_transform_prefill_kernel, kSmem>();
    topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
        params,
        dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(),
        src_stride,
        cu_seqlens_q.data_ptr<int32_t>(),
        prefill_bs);
  }

  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|
|
// Host entry point for ragged top-k output.
//
// For each score row b, topk_indices_ragged[b] ([B, TopK], contiguous int32) receives the
// selected positions (relative to the row's start) plus topk_indices_offset[b]. Padding slots
// are -1. An empty batch (B == 0) is a no-op.
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset,
    std::optional<at::Tensor> row_starts_opt) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(topk_indices_ragged);
  CHECK_CUDA(topk_indices_offset);
  if (row_starts_opt.has_value()) {
    CHECK_CUDA(row_starts_opt.value());
  }

  const auto params = get_params(score, lengths, row_starts_opt);
  const auto B = score.size(0);
  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
  TORCH_CHECK(topk_indices_offset.dim() == 1);
  // The kernel reads topk_indices_offset[blockIdx.x] directly, so a strided view would be misread.
  TORCH_CHECK(topk_indices_offset.is_contiguous(), "topk_indices_offset must be contiguous");

  TORCH_CHECK(topk_indices_ragged.size(0) == B);
  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
  TORCH_CHECK(topk_indices_offset.size(0) == B);

  // A zero-block grid is rejected as an invalid launch configuration; nothing to compute.
  if (B == 0) {
    return;
  }

  // Make the inputs' device current so the stream lookup and launch target that GPU.
  const c10::cuda::CUDAGuard device_guard(score.device());
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};

  setup_kernel_smem_once<topk_transform_prefill_ragged_kernel, kSmem>();
  topk_transform_prefill_ragged_kernel<<<grid, block, kSmem, stream>>>(
      params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());

  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|
|