| #include "common.h" |
| #include "gemm.h" |
| #include "vec.h" |
|
|
| namespace { |
|
|
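// Helper kernels for the fused MLA-style QKV projection (qkv_proj_with_rope) below:
//   * segment_gemm_kernel_impl : two GEMMs sharing the same activation A, writing C0 and
//     C1 in one pass (bf16/fp16 weights, plus int8 w8a8 and fp8 w8a16 overloads)
//   * rms_norm_kernel_impl     : in-place RMSNorm over two tensors per row
//   * rotary_emb_kernel_impl   : rotary position embedding for the rope part of q/k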
| template <typename scalar_t> |
| void segment_gemm_kernel_impl( |
| scalar_t* __restrict__ C0, |
| scalar_t* __restrict__ C1, |
| const scalar_t* __restrict__ A, |
| const scalar_t* __restrict__ B0, |
| const scalar_t* __restrict__ B1, |
| int64_t M, |
| int64_t N0, |
| int64_t N1, |
| int64_t K) { |
| |
| constexpr int64_t BLOCK_M = block_size_m(); |
| constexpr int64_t BLOCK_N = block_size_n(); |
| const int64_t MB = div_up(M, BLOCK_M); |
| const int64_t NB0 = div_up(N0, BLOCK_N); |
| const int64_t NB1 = div_up(N1, BLOCK_N); |
| const int64_t NB = NB0 + NB1; |
|
|
| const bool use_brgemm = can_use_brgemm<scalar_t>(M); |
|
|
  // Parallelize over all [BLOCK_M, BLOCK_N] output tiles: tiles with nb < NB0 write C0,
  // the remaining tiles write C1. nb_size is never clamped, so N0 and N1 are assumed
  // to be multiples of BLOCK_N.
| at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { |
| int64_t mb{0}, nb{0}; |
| data_index_init(begin, mb, MB, nb, NB); |
|
|
    // Per-thread fp32 accumulation buffer for one output tile.
| alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; |
|
|
| for (int64_t i = begin; i < end; ++i) { |
| UNUSED(i); |
| int mb_start = mb * BLOCK_M; |
| int mb_size = std::min(M - mb_start, BLOCK_M); |
| int nb_start = nb * BLOCK_N; |
| int nb_size = BLOCK_N; |
|
|
| const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1; |
| scalar_t* __restrict__ C = nb < NB0 ? C0 : C1; |
| int64_t ldc = nb < NB0 ? N0 : N1; |
| int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0; |
|
|
| tinygemm_kernel<scalar_t>( |
| A + mb_start * K, |
          B + local_nb_start * K,
| C + mb_start * ldc + local_nb_start, |
| Ctmp, |
| mb_size, |
| nb_size, |
| K, |
| K, |
| nb_size, |
| ldc, |
| use_brgemm); |
|
|
      // Move to the next (mb, nb) tile index.
| data_index_step(mb, MB, nb, NB); |
| } |
|
|
| if (use_brgemm) { |
| at::native::cpublas::brgemm_release(); |
| } |
| }); |
| } |
|
|
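// Both quantized overloads below reuse the same [MB x NB] tiling; only the weight layout
// and the output rescaling differ. For example, with BLOCK_N = 32, N0 = 64 and N1 = 32,
// tiles nb = 0, 1 write C0 and tile nb = 2 writes C1.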
// int8 w8a8 variant: A holds row-quantized uint8 activations with per-row scales As;
// B0/B1 are prepacked int8 weights with per-output-channel scales Bs0/Bs1.
| template <typename scalar_t> |
| void segment_gemm_kernel_impl( |
| scalar_t* __restrict__ C0, |
| scalar_t* __restrict__ C1, |
| const uint8_t* __restrict__ A, |
| const int8_t* __restrict__ B0, |
| const int8_t* __restrict__ B1, |
| const float* __restrict__ As, |
| const float* __restrict__ Bs0, |
| const float* __restrict__ Bs1, |
| int64_t M, |
| int64_t N0, |
| int64_t N1, |
| int64_t K) { |
| constexpr int64_t BLOCK_M = block_size_m(); |
| constexpr int64_t BLOCK_N = block_size_n(); |
| const int64_t MB = div_up(M, BLOCK_M); |
| const int64_t NB0 = div_up(N0, BLOCK_N); |
| const int64_t NB1 = div_up(N1, BLOCK_N); |
| const int64_t NB = NB0 + NB1; |
|
|
| const bool use_brgemm = can_use_brgemm<int8_t>(M); |
|
|
  // Row stride (in elements) of the prepacked int8 weights; may exceed K due to packing.
| const int64_t packed_row_size = get_row_size<int8_t>(K); |
|
|
  // Same [MB x NB] tiling as the bf16/fp16 variant above.
| at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { |
| int64_t mb{0}, nb{0}; |
| data_index_init(begin, mb, MB, nb, NB); |
|
|
    // Per-thread int32 accumulation buffer for the int8 GEMM tile.
| alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; |
|
|
| for (int64_t i = begin; i < end; ++i) { |
| UNUSED(i); |
| int mb_start = mb * BLOCK_M; |
| int mb_size = std::min(M - mb_start, BLOCK_M); |
| int nb_start = nb * BLOCK_N; |
| int nb_size = BLOCK_N; |
|
|
| const int8_t* __restrict__ B = nb < NB0 ? B0 : B1; |
| const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1; |
| scalar_t* __restrict__ C = nb < NB0 ? C0 : C1; |
| int64_t ldc = nb < NB0 ? N0 : N1; |
| int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0; |
|
|
| tinygemm_kernel<scalar_t>( |
| A + mb_start * K, |
          B + local_nb_start * packed_row_size,
| C + mb_start * ldc + local_nb_start, |
| Ctmp, |
| As + mb_start, |
| Bs + local_nb_start, |
| mb_size, |
| nb_size, |
| K, |
| K, |
| nb_size, |
| ldc, |
| use_brgemm); |
|
|
      // Move to the next (mb, nb) tile index.
| data_index_step(mb, MB, nb, NB); |
| } |
|
|
| if (use_brgemm) { |
| at::native::cpublas::brgemm_release(); |
| } |
| }); |
| } |
|
|
// fp8 w8a16 variant: B0/B1 are Float8_e4m3fn weights with [block_size_N, block_size_K]
// block-wise scales Bs0/Bs1; Btmp provides one dequantization buffer of BLOCK_N * K
// elements per thread.
| template <typename scalar_t> |
| void segment_gemm_kernel_impl( |
| scalar_t* __restrict__ C0, |
| scalar_t* __restrict__ C1, |
| const scalar_t* __restrict__ A, |
| const at::Float8_e4m3fn* __restrict__ B0, |
| const at::Float8_e4m3fn* __restrict__ B1, |
| const float* __restrict__ Bs0, |
| const float* __restrict__ Bs1, |
| scalar_t* __restrict__ Btmp, |
| int64_t M, |
| int64_t N0, |
| int64_t N1, |
| int64_t K, |
| int64_t block_size_N, |
| int64_t block_size_K) { |
| constexpr int64_t BLOCK_M = block_size_m(); |
| constexpr int64_t BLOCK_N = block_size_n(); |
| const int64_t MB = div_up(M, BLOCK_M); |
| const int64_t NB0 = div_up(N0, BLOCK_N); |
| const int64_t NB1 = div_up(N1, BLOCK_N); |
| const int64_t NB = NB0 + NB1; |
|
|
| const int64_t scale_size_K = div_up(K, block_size_K); |
| const int64_t blocks_n_per_group = block_size_N / BLOCK_N; |
|
|
| const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M); |
|
|
  // Same [MB x NB] tiling as the variants above.
| at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { |
| int64_t mb{0}, nb{0}; |
| data_index_init(begin, mb, MB, nb, NB); |
|
|
| int tid = at::get_thread_num(); |
    // Per-thread scratch: this thread's slice of Btmp receives the dequantized weight
    // tile, Ctmp accumulates the output tile in fp32.
| alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; |
|
|
| for (int64_t i = begin; i < end; ++i) { |
| UNUSED(i); |
|
|
| int mb_start = mb * BLOCK_M; |
| int mb_size = std::min(M - mb_start, BLOCK_M); |
| int nb_start = nb * BLOCK_N; |
| int nb_size = BLOCK_N; |
|
|
| const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1; |
| const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1; |
| scalar_t* __restrict__ C = nb < NB0 ? C0 : C1; |
| int64_t ldc = nb < NB0 ? N0 : N1; |
| int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0; |
| int64_t new_nb = nb < NB0 ? nb : nb - NB0; |
|
|
| tinygemm_kernel<scalar_t>( |
| A + mb_start * K, |
          B + local_nb_start * K,
| C + mb_start * ldc + local_nb_start, |
| Btmp + tid * BLOCK_N * K, |
| Ctmp, |
| Bs + (new_nb / blocks_n_per_group) * scale_size_K, |
| mb_size, |
| nb_size, |
| K, |
| K, |
| nb_size, |
| ldc, |
| use_brgemm, |
| block_size_K); |
|
|
      // Move to the next (mb, nb) tile index.
| data_index_step(mb, MB, nb, NB); |
| } |
|
|
| if (use_brgemm) { |
| at::native::cpublas::brgemm_release(); |
| } |
| }); |
| } |
|
|
| template <typename scalar_t> |
| inline float reduce(const scalar_t* __restrict__ x, int64_t size) { |
| using bVec = at::vec::Vectorized<scalar_t>; |
| using fVec = at::vec::Vectorized<float>; |
| fVec sum_fvec = fVec(float(0)); |
|
|
  // Accumulate sum of x[d]^2 in fp32 (the RMSNorm denominator). There is no tail
  // handling, so size must be a multiple of bVec::size().
| #pragma GCC unroll 4 |
| for (int64_t d = 0; d < size; d += bVec::size()) { |
| bVec x_bvec = bVec::loadu(x + d); |
| fVec x_fvec0, x_fvec1; |
| std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); |
| sum_fvec += x_fvec0 * x_fvec0; |
| sum_fvec += x_fvec1 * x_fvec1; |
| } |
| return vec_reduce_sum(sum_fvec); |
| } |
|
|
// Computes y[i] = x[i] * scale * w[i] in fp32 and stores the result back in the
// reduced-precision type; same vector-width assumption as reduce() above.
| template <typename scalar_t> |
| inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) { |
| using bVec = at::vec::Vectorized<scalar_t>; |
| using fVec = at::vec::Vectorized<float>; |
| fVec scale_fvec = fVec(scale); |
|
|
| |
| #pragma GCC unroll 4 |
| for (int64_t d = 0; d < size; d += bVec::size()) { |
| bVec x_bvec = bVec::loadu(x + d); |
| fVec x_fvec0, x_fvec1; |
| std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); |
| bVec w_bvec = bVec::loadu(w + d); |
| fVec w_fvec0, w_fvec1; |
| std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec); |
| x_fvec0 = x_fvec0 * scale_fvec * w_fvec0; |
| x_fvec1 = x_fvec1 * scale_fvec * w_fvec1; |
| bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1); |
| out_bvec.store(y + d); |
| } |
| } |
|
|
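// In-place RMSNorm over two tensors that share the row dimension M: input0 has N0
// contiguous columns per row, input1 has N1 valid columns with a row stride of stride1
// (so it may be a strided view, e.g. the latent slice of k_input).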
| template <typename scalar_t> |
| void rms_norm_kernel_impl( |
| scalar_t* __restrict__ input0, |
| scalar_t* __restrict__ input1, |
| const scalar_t* __restrict__ weight0, |
| const scalar_t* __restrict__ weight1, |
| int64_t M, |
| int64_t N0, |
| int64_t N1, |
| int64_t stride1, |
| float eps = 1e-5) { |
| at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { |
| for (int64_t m = begin; m < end; ++m) { |
| scalar_t* x0 = input0 + m * N0; |
| scalar_t* x1 = input1 + m * stride1; |
| float scale0 = reduce(x0, N0); |
| float scale1 = reduce(x1, N1); |
| scale0 = float(1) / std::sqrt(scale0 / N0 + eps); |
| scale1 = float(1) / std::sqrt(scale1 / N1 + eps); |
| map2(x0, x0, weight0, scale0, N0); |
| map2(x1, x1, weight1, scale1, N1); |
| } |
| }); |
| } |
|
|
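// Rotary embedding over one head of `size` elements: each adjacent pair
// (x[2i], x[2i + 1]) is rotated by (cos[i], sin[i]). Only the AVX512 bf16 specialization
// below is implemented; the generic template is a deliberate hard error.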
| template <typename scalar_t> |
| inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) { |
| TORCH_CHECK(false, "rotary scalar path not implemented."); |
| } |
|
|
| #if defined(CPU_CAPABILITY_AVX512) |
| template <> |
| inline void rotary<at::BFloat16>( |
| const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) { |
  // idx1/idx2 split the 32 bf16 lanes into even/odd elements; idy1/idy2 interleave the
  // rotated halves back into the original order.
| const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); |
| const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); |
| const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); |
| const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); |
|
|
  // Process 32 bf16 elements (16 even/odd pairs) per iteration; requires size % 32 == 0.
| #pragma GCC unroll 2 |
| for (int64_t d = 0; d < size; d += 32) { |
| int64_t d2 = d >> 1; |
    // Load cos/sin for the 16 pairs and widen to fp32.
| __m512 vcos = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(cos + d2))); |
| __m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sin + d2))); |
    // Load 32 bf16 inputs and widen to two fp32 vectors.
| __m512i a16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(input + d)); |
| __m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0)); |
| __m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1)); |
    // Gather even elements into in1 and odd elements into in2.
| __m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); |
| __m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); |
    // Rotate each pair: out_even = in1 * cos - in2 * sin, out_odd = in2 * cos + in1 * sin.
| __m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin)); |
| __m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin)); |
    // Interleave the rotated results back into the original element order.
| a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2); |
| b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2); |
|
|
| _mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a))); |
| } |
| } |
| #endif |
|
|
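// Applies rotary embedding to q_pe [num_seqs, num_heads, rotary_dim] and to
// k_pe [num_seqs, rotary_dim] (strided via k_strideB); pos[seq] selects the row of the
// cos/sin cache used for every head of that sequence.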
| template <typename scalar_t> |
| void rotary_emb_kernel_impl( |
| scalar_t* q_pe_out, |
| scalar_t* k_pe_out, |
| const scalar_t* q_pe, |
| const scalar_t* k_pe, |
| const int64_t* pos, |
| const scalar_t* cos_sin, |
| int64_t num_seqs, |
| int64_t num_heads, |
| int64_t rotary_dim, |
| int64_t q_strideB, |
| int64_t q_strideH, |
| int64_t k_strideB, |
| int64_t oq_strideB, |
| int64_t oq_strideH, |
| int64_t ok_strideB) { |
  TORCH_CHECK(rotary_dim % 32 == 0, "rotary_emb: expect rotary_dim to be a multiple of 32.");
| const int64_t rotary_offset = rotary_dim / 2; |
|
|
  // Parallelize over (seq, head). The extra head index (head_id == num_heads) handles
  // k_pe, which has a single head per sequence.
| at::parallel_for(0, num_seqs * (num_heads + 1), GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) { |
| int64_t seq{0}, head_id{0}; |
| data_index_init(begin, seq, num_seqs, head_id, num_heads + 1); |
|
|
| for (int64_t i = begin; i < end; ++i) { |
| UNUSED(i); |
      // The token position selects one cache row laid out as [cos(rotary_dim/2) | sin(rotary_dim/2)].
| int64_t index = pos[seq]; |
| const scalar_t* cos = cos_sin + index * rotary_dim; |
| const scalar_t* sin = cos + rotary_offset; |
|
|
| const scalar_t* input = |
| (head_id < num_heads) ? q_pe + seq * q_strideB + head_id * q_strideH : k_pe + seq * k_strideB; |
| scalar_t* out = |
| (head_id < num_heads) ? q_pe_out + seq * oq_strideB + head_id * oq_strideH : k_pe_out + seq * ok_strideB; |
| rotary<scalar_t>(input, out, cos, sin, rotary_dim); |
|
|
      // Move to the next (seq, head_id) pair.
| data_index_step(seq, num_seqs, head_id, num_heads + 1); |
| } |
| }); |
| } |
|
|
}  // anonymous namespace
|
|
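// Forward declarations of GEMM / BMM entry points implemented in other translation units
// of this extension.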
| extern at::Tensor |
| weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni); |
|
|
| extern at::Tensor int8_scaled_mm_with_quant( |
| at::Tensor& mat1, |
| at::Tensor& mat2, |
| at::Tensor& scales2, |
| const std::optional<at::Tensor>& bias, |
| at::ScalarType out_dtype, |
| bool is_vnni); |
|
|
| extern void |
| bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale); |
|
|
| extern at::Tensor fp8_scaled_mm_cpu( |
| at::Tensor& mat1, |
| at::Tensor& mat2, |
| at::Tensor& scales2, |
| std::vector<int64_t> block_size, |
| const std::optional<at::Tensor>& bias, |
| at::ScalarType out_dtype, |
| bool is_vnni); |
|
|
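// Fused QKV projection with rotary embedding for MLA-style attention.
//
// Pipeline (per call):
//   1. segment GEMM: project hidden_states with q_a_proj_weight and kv_a_proj_weight
//   2. RMSNorm of q_a (q_a_layernorm) and of the latent part of k_input (kv_a_layernorm)
//   3. q_b projection, viewed as [num_seqs, num_heads, qk_head_dim]
//   4. per-head GEMM with w_kc mapping q_nope into the kv_lora_rank latent space
//   5. rotary embedding on the qk_rope_head_dim tail of q and k
//
// Returns (q_input, k_input, v_input) with
//   q_input: [num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim]
//   k_input: [num_seqs, 1, kv_lora_rank + qk_rope_head_dim]
//   v_input: view of the first kv_lora_rank columns of k_input
//
// Weights must be prepacked (is_vnni == true); quantized paths are selected with
// use_int8_w8a8 (requires the *_proj_scale tensors) or use_fp8_w8a16 (requires the
// *_proj_scale tensors and block_size).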
| std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( |
| at::Tensor& hidden_states, |
| at::Tensor& q_a_proj_weight, |
| at::Tensor& q_b_proj_weight, |
| at::Tensor& kv_a_proj_weight, |
| at::Tensor& w_kc, |
| at::Tensor& q_a_layernorm_weight, |
| at::Tensor& kv_a_layernorm_weight, |
| at::Tensor& positions, |
| at::Tensor& cos_sin_cache, |
| double eps, |
| bool use_int8_w8a8, |
| bool use_fp8_w8a16, |
| std::optional<at::Tensor> q_a_proj_scale, |
| std::optional<at::Tensor> q_b_proj_scale, |
| std::optional<at::Tensor> kv_a_proj_scale, |
| bool is_vnni, |
| std::optional<std::vector<int64_t>> block_size) { |
| RECORD_FUNCTION( |
| "sgl-kernel::qkv_proj_with_rope", |
| std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc})); |
|
|
| const auto st = hidden_states.scalar_type(); |
| CHECK_INPUT(hidden_states); |
| CHECK_INPUT(positions); |
| CHECK_INPUT(cos_sin_cache); |
| CHECK_EQ(q_a_layernorm_weight.scalar_type(), st); |
| CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st); |
| CHECK_EQ(positions.scalar_type(), at::kLong); |
| CHECK_EQ(cos_sin_cache.scalar_type(), st); |
| CHECK_DIM(2, hidden_states); |
| CHECK_DIM(3, w_kc); |
| CHECK_DIM(1, q_a_layernorm_weight); |
| CHECK_DIM(1, kv_a_layernorm_weight); |
| CHECK_DIM(1, positions); |
| CHECK_DIM(2, cos_sin_cache); |
|
|
  // Only prepacked (VNNI layout) weights are supported.
| TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!"); |
|
|
| int64_t num_seqs = hidden_states.size(0); |
| int64_t hidden_size = hidden_states.size(1); |
| int64_t q_lora_rank = q_a_proj_weight.size(0); |
| int64_t num_heads = w_kc.size(0); |
| int64_t kv_lora_rank = w_kc.size(1); |
| int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads; |
| int64_t qk_nope_head_dim = w_kc.size(2); |
| int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank; |
| int64_t rotary_dim = cos_sin_cache.size(1); |
|
|
| CHECK_EQ(positions.numel(), num_seqs); |
| CHECK_EQ(rotary_dim, qk_rope_head_dim); |
| CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank); |
| CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank); |
|
|
  // Packed weight rows are wider for int8 (get_row_size accounts for the extra bytes).
| CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8)); |
| CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8)); |
| CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8)); |
|
|
| if (use_int8_w8a8) { |
| TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8."); |
| TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8."); |
| TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8."); |
| } |
| if (use_fp8_w8a16) { |
| TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16."); |
| TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16."); |
| TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16."); |
| TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); |
| TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16."); |
| } |
  // Output buffers: q_input / k_input receive the final projections (rope applied below);
  // v_input aliases the first kv_lora_rank columns of k_input.
| const auto options = hidden_states.options(); |
| auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options); |
| auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options); |
| auto v_input = k_input.narrow(-1, 0, kv_lora_rank); |
|
|
  // Intermediate activation of q_a_proj (RMSNorm and q_b_proj are applied to it below).
| auto qa = at::empty({num_seqs, q_lora_rank}, options); |
|
|
  // Stage 1: fused segment GEMM projecting hidden_states with q_a_proj_weight (into qa)
  // and kv_a_proj_weight (into k_input) in a single pass over hidden_states.
| AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] { |
| if (use_int8_w8a8) { |
| auto q_a_proj_s = q_a_proj_scale.value(); |
| auto kv_a_proj_s = kv_a_proj_scale.value(); |
| TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank); |
| TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim); |
|
|
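      // Single byte buffer holding the row-quantized activations ([num_seqs, hidden_size]
      // uint8) followed by num_seqs fp32 per-row scales.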
| auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte)); |
| uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>(); |
| float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size)); |
| const scalar_t* __restrict__ A_data = hidden_states.data_ptr<scalar_t>(); |
|
|
| at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) { |
| for (int64_t m = begin; m < end; ++m) { |
| quantize_row_int8<scalar_t>(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size); |
| } |
| }); |
|
|
| segment_gemm_kernel_impl<scalar_t>( |
| qa.data_ptr<scalar_t>(), |
| k_input.data_ptr<scalar_t>(), |
| Aq_data, |
| q_a_proj_weight.data_ptr<int8_t>(), |
| kv_a_proj_weight.data_ptr<int8_t>(), |
| As_data, |
| q_a_proj_s.data_ptr<float>(), |
| kv_a_proj_s.data_ptr<float>(), |
| num_seqs, |
| q_lora_rank, |
| kv_lora_rank + qk_rope_head_dim, |
| hidden_size); |
| } else if (use_fp8_w8a16) { |
| int64_t block_size_N = block_size.value()[0]; |
| int64_t block_size_K = block_size.value()[1]; |
| auto q_a_proj_s = q_a_proj_scale.value(); |
| auto kv_a_proj_s = kv_a_proj_scale.value(); |
| CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N)); |
| CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K)); |
| CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N)); |
| CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K)); |
|
|
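      // Per-thread scratch used to hold one dequantized [BLOCK_N, hidden_size] weight
      // tile in the activation dtype.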
| const int BLOCK_N = block_size_n(); |
| const int num_threads = at::get_num_threads(); |
| auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options); |
| segment_gemm_kernel_impl<scalar_t>( |
| qa.data_ptr<scalar_t>(), |
| k_input.data_ptr<scalar_t>(), |
| hidden_states.data_ptr<scalar_t>(), |
| q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(), |
| kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(), |
| q_a_proj_s.data_ptr<float>(), |
| kv_a_proj_s.data_ptr<float>(), |
| buffer.data_ptr<scalar_t>(), |
| num_seqs, |
| q_lora_rank, |
| kv_lora_rank + qk_rope_head_dim, |
| hidden_size, |
| block_size_N, |
| block_size_K); |
| } else { |
| segment_gemm_kernel_impl<scalar_t>( |
| qa.data_ptr<scalar_t>(), |
| k_input.data_ptr<scalar_t>(), |
| hidden_states.data_ptr<scalar_t>(), |
| q_a_proj_weight.data_ptr<scalar_t>(), |
| kv_a_proj_weight.data_ptr<scalar_t>(), |
| num_seqs, |
| q_lora_rank, |
| kv_lora_rank + qk_rope_head_dim, |
| hidden_size); |
| } |
| }); |
|
|
  // Stage 2: RMSNorm, in place, over qa and over the kv_lora_rank slice of k_input.
| AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] { |
| rms_norm_kernel_impl<scalar_t>( |
| qa.data_ptr<scalar_t>(), |
| v_input.data_ptr<scalar_t>(), |
| q_a_layernorm_weight.data_ptr<scalar_t>(), |
| kv_a_layernorm_weight.data_ptr<scalar_t>(), |
| num_seqs, |
| q_lora_rank, |
| kv_lora_rank, |
| kv_lora_rank + qk_rope_head_dim, |
| eps); |
| }); |
|
|
  // Stage 3: q_b projection of the normalized qa, viewed as [num_seqs, num_heads, qk_head_dim].
| at::Tensor qb; |
| std::optional<at::Tensor> bias; |
| if (use_int8_w8a8) { |
| qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni); |
| } else if (use_fp8_w8a16) { |
| qb = fp8_scaled_mm_cpu( |
| qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni); |
| } else { |
| qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni); |
| } |
| qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1}); |
|
|
  // Stage 4: per-head GEMM with w_kc maps q_nope into the kv_lora_rank latent space,
  // writing directly into the first kv_lora_rank columns of q_input.
| std::optional<at::Tensor> scale; |
| auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1); |
| auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1); |
| bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale); |
|
|
  // Stage 5: apply rotary embedding to the rope part of q (from qb into q_input) and of
  // k_input (in place).
| AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] { |
| rotary_emb_kernel_impl<scalar_t>( |
| q_input.data_ptr<scalar_t>() + kv_lora_rank, |
| k_input.data_ptr<scalar_t>() + kv_lora_rank, |
| qb.data_ptr<scalar_t>() + qk_nope_head_dim, |
| k_input.data_ptr<scalar_t>() + kv_lora_rank, |
| positions.data_ptr<int64_t>(), |
| cos_sin_cache.data_ptr<scalar_t>(), |
| num_seqs, |
| num_heads, |
| rotary_dim, |
| num_heads * qk_head_dim, |
| qk_head_dim, |
| kv_lora_rank + qk_rope_head_dim, |
| num_heads * (kv_lora_rank + qk_rope_head_dim), |
| kv_lora_rank + qk_rope_head_dim, |
| kv_lora_rank + qk_rope_head_dim); |
| }); |
|
|
| return std::make_tuple(q_input, k_input, v_input); |
| } |
|
|
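// Variant taking the fused qkv_a_proj weight (q_a_proj and kv_a_proj stacked along dim 0,
// in that order) together with its fused scales; the weight and scales are split here and
// the work is delegated to qkv_proj_with_rope above.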
| std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight( |
| at::Tensor& hidden_states, |
| at::Tensor& qkv_a_proj_weight, |
| at::Tensor& q_b_proj_weight, |
| at::Tensor& w_kc, |
| at::Tensor& q_a_layernorm_weight, |
| at::Tensor& kv_a_layernorm_weight, |
| at::Tensor& positions, |
| at::Tensor& cos_sin_cache, |
| double eps, |
| bool use_int8_w8a8, |
| bool use_fp8_w8a16, |
| std::optional<at::Tensor> qkv_a_proj_scale, |
| std::optional<at::Tensor> q_b_proj_scale, |
| bool is_vnni, |
| std::optional<std::vector<int64_t>> block_size, |
| int64_t q_lora_rank, |
| int64_t kv_lora_rank, |
| int64_t qk_rope_head_dim) { |
| RECORD_FUNCTION( |
| "sgl-kernel::qkv_proj_with_rope_fused_weight", |
| std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc})); |
|
|
| int64_t hidden_size = hidden_states.size(1); |
| CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim); |
| CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8)); |
|
|
| std::vector<at::Tensor> weight_chunks = |
| at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0); |
| at::Tensor q_a_proj_weight = weight_chunks[0]; |
| at::Tensor kv_a_proj_weight = weight_chunks[1]; |
| at::Tensor q_a_proj_s; |
| at::Tensor kv_a_proj_s; |
|
|
| if (use_int8_w8a8) { |
| TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8."); |
| std::vector<at::Tensor> scale_chunks = |
| at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0); |
| q_a_proj_s = scale_chunks[0]; |
| kv_a_proj_s = scale_chunks[1]; |
| } |
  if (use_fp8_w8a16) {
    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
    int64_t block_size_N = block_size.value()[0];
| int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N); |
| int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N); |
| std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0); |
| q_a_proj_s = scale_chunks[0]; |
| kv_a_proj_s = scale_chunks[1]; |
| } |
|
|
| return qkv_proj_with_rope( |
| hidden_states, |
| q_a_proj_weight, |
| q_b_proj_weight, |
| kv_a_proj_weight, |
| w_kc, |
| q_a_layernorm_weight, |
| kv_a_layernorm_weight, |
| positions, |
| cos_sin_cache, |
| eps, |
| use_int8_w8a8, |
| use_fp8_w8a16, |
| q_a_proj_s, |
| q_b_proj_scale, |
| kv_a_proj_s, |
| is_vnni, |
| block_size); |
| } |
|
|