#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
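
// Small helper for stringifying values in TORCH_CHECK error messages.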
template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace marlin_moe {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
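
// For a given "a" of size [M, K], performs a permutation of the K columns
// based on the given "perm" indices.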
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = finish_row - start_row;

  // Row stride in int4 units (each int4 holds 16 bytes = 8 halves)
  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / blockDim.x;
    int rest = size_k % blockDim.x;

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += blockDim.x;
    }

    // Handle the tail when size_k is not a multiple of blockDim.x
    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}
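
// Counts how many of the routed token ids fall on each expert, then converts
// the counts into block_size-aligned start offsets (a padded prefix sum).
// Launched with one thread per expert; e.g. with block_size = 4 and counts
// {3, 5, 2} the resulting offsets are {0, 4, 12, 16}.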
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  int expert_id = threadIdx.x;
  int num_experts = blockDim.x;

  // Each thread counts the tokens routed to its own expert
  int occurrences = 0;
  for (int i = 0; i < topk_length; ++i) {
    occurrences += (topk_ids[i] == expert_id);
  }
  expert_offsets[expert_id + 1] = occurrences;
  __syncthreads();

  // Thread 0 turns the counts into a prefix sum, rounding each expert's
  // token count up to a multiple of block_size
  if (threadIdx.x == 0) {
    int tot_offset = 0;
    expert_offsets[0] = 0;
    for (int i = 0; i < num_experts; ++i) {
      tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
      expert_offsets[i + 1] = tot_offset;
    }
  }
  __syncthreads();
}

#else

__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

#endif

typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

typedef struct {
  int max_m_blocks;
  thread_config_t tb_cfg;
} exec_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {128, 128, 256},
    {128, 64, 128},
    {64, 256, 256},
    {64, 128, 128},
    {64, 64, 128},
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {64, 256, 256},
    {128, 128, 256},
    {64, 128, 128},
    {128, 64, 128},
    {64, 64, 128},
};
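
// Bytes of shared memory needed to cache the dequantization scales for one
// thread block.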
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  bool cache_scales_chunk = has_act_order && !is_k_full;

  int tb_n = th_config.thread_n;
  int tb_k = th_config.thread_k;

  // Get max scale groups per thread-block
  int tb_groups;
  if (group_size == -1) {
    tb_groups = 1;
  } else if (group_size == 0) {
    tb_groups = ceildiv(tb_k, 32);  // Worst case is 32 group size
  } else {
    tb_groups = ceildiv(tb_k, group_size);
  }

  if (cache_scales_chunk) {
    int load_groups =
        tb_groups * STAGES * 2;          // Chunk size is 2x pipeline over dim K
    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
    return load_groups * tb_n * 2;       // Scales are fp16, i.e. 2 bytes each

  } else {
    int tb_scales = tb_groups * tb_n * 2;

    return tb_scales * STAGES;
  }
}

bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int pack_factor = 32 / num_bits;

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;

  int b_size = (tb_k * tb_n / pack_factor) * 4;

  // Get A size
  int m_blocks = ceildiv(prob_m, 16);
  int tb_max_m = 16;

  while (true) {
    if (m_blocks >= max_m_blocks) {
      tb_max_m *= max_m_blocks;
      break;
    }

    max_m_blocks--;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }

  int a_size = (tb_max_m * tb_k) * 2;

  float pipe_size = (a_size + b_size) * STAGES;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
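
// Full validity check of a candidate thread-block config against the problem
// shape, quantization parameters and available shared memory.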
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // thread_k can only be 128 or 64 (the supported tile widths)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  // Determine cache for scales
  int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);

  // Check that pipeline fits into cache
  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                           num_bits, scales_cache_size, max_shared_mem)) {
    return false;
  }

  return true;
}
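
// Tries the configs in priority order, lowering max_m_blocks until one fits
// in shared memory; returns max_m_blocks == 0 when nothing fits.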
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  int max_m_blocks = 4;
  while (max_m_blocks > 0) {
    if (prob_m <= 16) {
      for (auto th_config : small_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    } else {
      for (auto th_config : large_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }

    // Process less M blocks per invocation to reduce cache usage
    max_m_blocks--;
  }

  return exec_config_t{0, {-1, -1, -1}};
}
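
// Each call_marlin_moe_kernel_* entry point returns true iff it handled the
// given quant type and tile configuration, so the macro chain below acts as
// an else-if ladder over the compiled kernel variants.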
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)                             \
  else if (KERNEL_FUNCTION(                                                   \
               q_type, thread_n_blocks, thread_k_blocks, has_act_order,       \
               group_blocks, num_threads, blocks, max_shared_mem, stream,     \
               A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr,  \
               zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
               num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks,       \
               replicate_input, apply_weights, m_block, max_par,              \
               exec_cfg.max_m_blocks)) {                                      \
  }
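
// Driver for the full MoE GEMM: computes per-expert offsets, optionally
// permutes A for act_order, then dispatches the Marlin kernel per expert and
// per block of M rows.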
void marlin_mm_moe(const void* A, const void* B, void* C,
                   const void* sorted_ids, const void* topk_weights,
                   const void* topk_ids, const void* s, void* zp,
                   const void* g_idx, const void* perm, void* a_tmp,
                   void* expert_offsets, int prob_m, int prob_n, int prob_k,
                   void* workspace, vllm::ScalarType const& q_type,
                   bool has_act_order, bool is_k_full, bool has_zp,
                   int num_groups, int group_size, int num_experts, int topk,
                   int moe_block_size, int dev, cudaStream_t stream,
                   int thread_k, int thread_n, int sms, int max_par,
                   bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  int num_bits = q_type.size_bits();

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }

  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int tot_m = prob_m;

  // Compute block-aligned expert offsets on the stream (one thread per expert)
  const int* topk_ids_ptr = (const int*)topk_ids;
  int* expert_offsets_ptr = (int*)expert_offsets;
  compute_expert_offsets<<<1, num_experts, 0, stream>>>(
      topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);

  bool do_permute_a = has_act_order;

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  int pack_factor = 32 / q_type.size_bits();

  for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
    const int4* A_ptr = (const int4*)A;
    int4* a_tmp_ptr = (int4*)a_tmp;
    const int4* B_ptr =
        (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
    int4* C_ptr = (int4*)C;
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
    const int4* zp_ptr =
        (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;

    if (do_permute_a) {
      // Permute A columns
      int topk_rows = replicate_input ? tot_m : tot_m * topk;
      int block_rows = ceildiv(topk_rows, blocks);
      permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
          A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
      A_ptr = a_tmp_ptr;
    }

    int tot_m_blocks = ceildiv(tot_m, 16);
    for (int m_block = 0; m_block < tot_m_blocks;
         m_block += 4 * exec_cfg.max_m_blocks) {
      // The empty `if (false)` head lets the CALL_MOE_KERNEL_FUNCTION
      // else-if chain fall through to the first kernel that accepts the config
      if (false) {
      }
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
      else {
        TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                               str(prob_n) + ", " + str(prob_k) + "]" +
                               ", has_act_order = " + str(has_act_order) +
                               ", num_groups = " + str(num_groups) +
                               ", group_size = " + str(group_size) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
      }
    }
  }
}

}  // namespace marlin_moe
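
// Torch-facing entry point: validates inputs, allocates the output and
// scratch tensors, infers group_size / act_order from the tensor shapes and
// forwards to marlin_moe::marlin_mm_moe. A minimal call sketch (shapes and
// flag values illustrative only; how this op is exposed to Python depends on
// the registration elsewhere in the tree):
//
//   auto c = marlin_gemm_moe(a, b_q_weights, sorted_ids, topk_weights,
//                            topk_ids, b_scales, b_zeros, g_idx, perm,
//                            workspace, b_q_type_id, size_m, size_n, size_k,
//                            /*is_k_full=*/true, num_experts, topk,
//                            /*moe_block_size=*/16, /*replicate_input=*/true,
//                            /*apply_weights=*/true);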
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    torch::Tensor& b_zeros, const torch::Tensor& g_idx,
    const torch::Tensor& perm, torch::Tensor& workspace,
    vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
    int64_t moe_block_size, bool replicate_input, bool apply_weights) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  bool has_zp = b_zeros.size(1) != 0;
  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
  }

  int pack_factor = 32 / b_q_type.size_bits();

  int max_par = 4;

  int dev = a.get_device();

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
  torch::Tensor a_tmp =
      replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                      : torch::zeros({size_m, topk, size_k}, options_dtype);
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(1) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
              "if is_k_full is false, has_act_order must be true");

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }

  } else {
    if (num_groups > 1) {
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    TORCH_CHECK(b_zeros.size(1) == num_groups,
                "b_zeros dim 1 = ", b_zeros.size(1),
                " is not num_groups = ", num_groups);
    TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                "b_zeros dim 2 = ", b_zeros.size(2),
                " is not size_n / pack_factor = ", size_n / pack_factor);
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
      num_experts, topk, moe_block_size, dev,
      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
      replicate_input, apply_weights);
  return c;
}