File size: 25,352 Bytes
d02d576 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 | #include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: Fused kernel for QKV projection with weight absorption and RoPE
//
// 1. `q_a_proj` and `kv_a_proj_with_mqa` are fused into one gemm;
//    otherwise we would need to split IC for the 2nd gemm.
// 2. `q_a_layernorm` and `kv_a_layernorm` are fused into one parallel loop.
// 3. k_input and v_input share the same storage (v_input is a narrowed view
//    of k_input), just as the torch API does in `set_kv_buffer`, so no
//    additional memory movement is needed.
//
// [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
scalar_t* __restrict__ C0,
scalar_t* __restrict__ C1,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B0,
const scalar_t* __restrict__ B1,
int64_t M,
int64_t N0,
int64_t N1,
int64_t K) {
// convert_weight_packed make sure N0 and N1 are 32x
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB0 = div_up(N0, BLOCK_N);
const int64_t NB1 = div_up(N1, BLOCK_N);
const int64_t NB = NB0 + NB1;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// parallel on [MB, NB0 + NB1]
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = BLOCK_N;
const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1;
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
int64_t ldc = nb < NB0 ? N0 : N1;
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
tinygemm_kernel<scalar_t>(
/* A */ A + mb_start * K,
/* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
/* C */ C + mb_start * ldc + local_nb_start,
/* Ctmp*/ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ ldc,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
// [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
scalar_t* __restrict__ C0,
scalar_t* __restrict__ C1,
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B0,
const int8_t* __restrict__ B1,
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
int64_t M,
int64_t N0,
int64_t N1,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB0 = div_up(N0, BLOCK_N);
const int64_t NB1 = div_up(N1, BLOCK_N);
const int64_t NB = NB0 + NB1;
const bool use_brgemm = can_use_brgemm<int8_t>(M);
// K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K);
// parallel on [MB, NB0 + NB1]
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = BLOCK_N;
const int8_t* __restrict__ B = nb < NB0 ? B0 : B1;
const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
int64_t ldc = nb < NB0 ? N0 : N1;
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
tinygemm_kernel<scalar_t>(
/* A */ A + mb_start * K,
/* B */ B + local_nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
/* C */ C + mb_start * ldc + local_nb_start,
/* Ctmp*/ Ctmp,
/* As */ As + mb_start,
/* Bs */ Bs + local_nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ ldc,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
// [C0, C1] = A @ [B0, B1]  (fp8 w8a16 variant)
//
// `A` is the scalar_t activation ([M, K], lda = K). B0/B1 are prepacked
// Float8_e4m3fn weights with block-wise float scales `Bs0`/`Bs1` laid out as
// [div_up(N, block_size_N), div_up(K, block_size_K)]. Each thread uses its own
// BLOCK_N * K slice of `Btmp` as scratch (the caller allocates one slice per
// thread). Outputs: C0 [M, N0] and C1 [M, N1].
//
// Preconditions (checked):
//   - N0 and N1 are multiples of block_size_n() (full-width column blocks).
//   - block_size_N is a positive multiple of block_size_n(), so every BLOCK_N
//     tile lies inside a single scale group (`new_nb / blocks_n_per_group`);
//     otherwise blocks_n_per_group is 0 (division by zero) or tiles straddle
//     two scale groups and get the wrong scale.
//   - block_size_K is positive (it is a divisor in div_up).
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B0,
    const at::Float8_e4m3fn* __restrict__ B1,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    scalar_t* __restrict__ Btmp,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K,
    int64_t block_size_N,
    int64_t block_size_K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  TORCH_CHECK(
      N0 % BLOCK_N == 0 && N1 % BLOCK_N == 0,
      "segment_gemm: expect N0 and N1 to be multiples of ",
      BLOCK_N,
      ", got N0 = ",
      N0,
      ", N1 = ",
      N1);
  TORCH_CHECK(
      block_size_N > 0 && block_size_N % BLOCK_N == 0,
      "segment_gemm: expect block_size_N to be a positive multiple of ",
      BLOCK_N,
      ", got ",
      block_size_N);
  TORCH_CHECK(block_size_K > 0, "segment_gemm: expect block_size_K to be positive, got ", block_size_K);

  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;

  const int64_t scale_size_K = div_up(K, block_size_K);
  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;

  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);

  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);
    int tid = at::get_thread_num();

    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;

      // column blocks [0, NB0) belong to segment 0, [NB0, NB) to segment 1
      const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      // column offset / block index inside the selected segment
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
      int64_t new_nb = nb < NB0 ? nb : nb - NB0;

      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Btmp*/ Btmp + tid * BLOCK_N * K,
          /* Ctmp*/ Ctmp,
          /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K);

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
// Sum of squares of `x[0, size)`, accumulated in float32.
//
// The bulk is processed one Vectorized<scalar_t> (converted to two float32
// vectors) at a time. A scalar tail handles a `size` that is not a multiple
// of the vector width, so nothing past `x + size` is read. For aligned sizes
// the result is identical to a pure vector loop.
template <typename scalar_t>
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  const int64_t vec_size = bVec::size();

  fVec sum_fvec = fVec(float(0));
  int64_t d = 0;
#pragma GCC unroll 4
  for (; d + vec_size <= size; d += vec_size) {
    bVec x_bvec = bVec::loadu(x + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    sum_fvec += x_fvec0 * x_fvec0;
    sum_fvec += x_fvec1 * x_fvec1;
  }
  float sum = vec_reduce_sum(sum_fvec);

  // scalar tail for the last `size % vec_size` elements
  for (; d < size; ++d) {
    const float v = static_cast<float>(x[d]);
    sum += v * v;
  }
  return sum;
}
// Elementwise `y[i] = x[i] * scale * w[i]`, computed in float32 and rounded
// back to scalar_t. `y` may alias `x` (only `w` is __restrict__);
// rms_norm_kernel_impl relies on this to normalize in place.
//
// map2 from aten functional doesn't have fast bf16->fp32 conversion, hence
// this hand-written version. The bulk runs one Vectorized<scalar_t> at a
// time; a scalar tail with the same float32 math handles a `size` that is not
// a multiple of the vector width, so nothing past `size` is read or written.
template <typename scalar_t>
inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  const int64_t vec_size = bVec::size();

  fVec scale_fvec = fVec(scale);
  int64_t d = 0;
#pragma GCC unroll 4
  for (; d + vec_size <= size; d += vec_size) {
    bVec x_bvec = bVec::loadu(x + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

    bVec w_bvec = bVec::loadu(w + d);
    fVec w_fvec0, w_fvec1;
    std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);

    x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
    x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;

    bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
    out_bvec.store(y + d);
  }

  // scalar tail for the last `size % vec_size` elements
  for (; d < size; ++d) {
    y[d] = static_cast<scalar_t>(static_cast<float>(x[d]) * scale * static_cast<float>(w[d]));
  }
}
// In-place RMSNorm over two matrices in one parallel loop across the M rows:
//   input0 : M rows of N0 elements, row stride N0,      weight0 of N0 elements
//   input1 : M rows of N1 elements, row stride stride1, weight1 of N1 elements
// Each row becomes `row * w / sqrt(sum(row^2) / n + eps)`, computed in float32.
// input0 and input1 are __restrict__, i.e. they must not overlap.
template <typename scalar_t>
void rms_norm_kernel_impl(
    scalar_t* __restrict__ input0,
    scalar_t* __restrict__ input1,
    const scalar_t* __restrict__ weight0,
    const scalar_t* __restrict__ weight1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t stride1,
    float eps = 1e-5) {
  // normalize a single row of `n` elements in place against weight `w`
  auto normalize_row = [eps](scalar_t* row, const scalar_t* w, int64_t n) {
    const float sum_sq = reduce(row, n);
    const float inv_rms = float(1) / std::sqrt(sum_sq / n + eps);
    map2(row, row, w, inv_rms, n);
  };

  at::parallel_for(0, M, 0, [&](int64_t row_begin, int64_t row_end) {
    for (int64_t r = row_begin; r < row_end; ++r) {
      normalize_row(input0 + r * N0, weight0, N0);
      normalize_row(input1 + r * stride1, weight1, N1);
    }
  });
}
// Portable scalar RoPE used for every type/ISA without a vectorized
// specialization (previously this path threw unconditionally).
//
// Input is treated as `size / 2` interleaved pairs (input[2i], input[2i+1]);
// pair i is rotated by cos[i] / sin[i]:
//   out[2i]     = input[2i]   * cos[i] - input[2i+1] * sin[i]
//   out[2i + 1] = input[2i+1] * cos[i] + input[2i]   * sin[i]
// which matches the AVX512 BF16 specialization. Math is done in float32.
// `out` may alias `input`: each pair is fully read before it is written.
template <typename scalar_t>
inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) {
  const int64_t num_pairs = size / 2;
  for (int64_t i = 0; i < num_pairs; ++i) {
    const float x0 = static_cast<float>(input[2 * i]);
    const float x1 = static_cast<float>(input[2 * i + 1]);
    const float c = static_cast<float>(cos[i]);
    const float s = static_cast<float>(sin[i]);
    out[2 * i] = static_cast<scalar_t>(x0 * c - x1 * s);
    out[2 * i + 1] = static_cast<scalar_t>(x1 * c + x0 * s);
  }
}
#if defined(CPU_CAPABILITY_AVX512)
// AVX512 BF16 specialization of `rotary`.
//
// Treats `input` as interleaved pairs (input[2i], input[2i+1]) and rotates
// pair i by cos[i] / sin[i] in float32, writing bf16 results to `out`:
//   out[2i]     = input[2i]   * cos[i] - input[2i+1] * sin[i]
//   out[2i + 1] = input[2i+1] * cos[i] + input[2i]   * sin[i]
// Each iteration handles 32 input elements (16 pairs), so `size` must be a
// multiple of 32 (rotary_emb_kernel_impl checks rotary_dim % 32 == 0).
// `out` may alias `input`: a 32-element chunk is fully loaded before storing.
// NOTE(review): `_mm512_cvtne2ps_pbh` is an AVX512-BF16 instruction; confirm
// the CPU_CAPABILITY_AVX512 build also enables it.
template <>
inline void rotary<at::BFloat16>(
const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) {
// permute indices: idx1/idx2 gather the even/odd elements of the 32-float
// concatenation [a, b]; idy1/idy2 re-interleave [out1, out2] back into pairs
const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
// 32 elements per iteration; unroll 2 is tuned for rotary_dim = 64 (2 iters)
#pragma GCC unroll 2
for (int64_t d = 0; d < size; d += 32) {
// one cos/sin coefficient per pair, so coefficients advance at half the rate
int64_t d2 = d >> 1;
// load coefs
__m512 vcos = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(cos + d2)));
__m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sin + d2)));
// load input: 32 bf16 values, widened to two float32 vectors a and b
__m512i a16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(input + d));
__m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
__m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
// from [16, 2] to [2, 16]: in1 = first element of each pair, in2 = second
__m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
__m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
// out1 = in1 * cos - in2 * sin;
// out2 = in2 * cos + in1 * sin
__m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin));
__m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin));
// from [2, 16] to [16, 2]
a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2);
b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2);
// narrow back to bf16 and store all 32 results
_mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a)));
}
}
#endif
// Applies RoPE to the rope part of every query head and to the single shared
// key vector of each sequence.
//
// For sequence `s`, row `pos[s]` of `cos_sin` supplies the coefficients: its
// first rotary_dim / 2 entries are cos, the next rotary_dim / 2 are sin.
// Work is split over a [num_seqs, num_heads + 1] grid: slots [0, num_heads)
// are query heads (q_pe -> q_pe_out, strides q_strideB/q_strideH and
// oq_strideB/oq_strideH); the final slot is the key (k_pe -> k_pe_out,
// strides k_strideB / ok_strideB). Source and destination may coincide.
template <typename scalar_t>
void rotary_emb_kernel_impl(
    scalar_t* q_pe_out,
    scalar_t* k_pe_out,
    const scalar_t* q_pe,
    const scalar_t* k_pe,
    const int64_t* pos,
    const scalar_t* cos_sin,
    int64_t num_seqs,
    int64_t num_heads,
    int64_t rotary_dim,
    int64_t q_strideB,
    int64_t q_strideH,
    int64_t k_strideB,
    int64_t oq_strideB,
    int64_t oq_strideH,
    int64_t ok_strideB) {
  TORCH_CHECK(rotary_dim % 32 == 0, "rotary_dim is not 32x.");

  const int64_t half_dim = rotary_dim / 2;
  // query heads plus one extra slot for the key
  const int64_t slots_per_seq = num_heads + 1;

  at::parallel_for(0, num_seqs * slots_per_seq, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t seq_idx{0}, slot{0};
    data_index_init(begin, seq_idx, num_seqs, slot, slots_per_seq);

    for (int64_t task = begin; task < end; ++task) {
      UNUSED(task);
      const scalar_t* cos_ptr = cos_sin + pos[seq_idx] * rotary_dim;
      const scalar_t* sin_ptr = cos_ptr + half_dim;

      const scalar_t* src;
      scalar_t* dst;
      if (slot < num_heads) {
        src = q_pe + seq_idx * q_strideB + slot * q_strideH;
        dst = q_pe_out + seq_idx * oq_strideB + slot * oq_strideH;
      } else {
        src = k_pe + seq_idx * k_strideB;
        dst = k_pe_out + seq_idx * ok_strideB;
      }
      rotary<scalar_t>(src, dst, cos_ptr, sin_ptr, rotary_dim);

      data_index_step(seq_idx, num_seqs, slot, slots_per_seq);
    }
  });
}
} // anonymous namespace
// Kernels defined in other translation units of this extension. They are
// reused by qkv_proj_with_rope below: the three gemm entry points for the
// q_b_proj projection (stage 3) and bmm_cpu for the w_kc absorption (stage 4).

// Plain (bf16/fp16) linear with a prepacked weight.
extern at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
// int8 w8a8 linear; takes activations in floating point (the name suggests
// it quantizes them internally — NOTE(review): confirm in its definition).
extern at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// Batched matmul writing into a preallocated `out`.
extern void
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
// fp8 w8a16 linear with block-wise weight scales.
extern at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// NB: shapes in DeepSeek R1
//
// hidden_states : [num_seqs, hidden_size] [1, 7168]
// q_a_proj_weight : [q_lora_rank, hidden_size] [1536, 7168]
// q_b_proj_weight : [num_heads * qk_head_dim, q_lora_rank] [4224, 1536]
// kv_a_proj_weight : [kv_lora_rank + qk_rope_head_dim, hidden_size] [576, 7168]
// w_kc : [num_heads, kv_lora_rank, qk_nope_head_dim] [22, 512, 128]
// q_a_layernorm_weight : [q_lora_rank] [1536]
// kv_a_layernorm_weight : [kv_lora_rank] [512]
//
// Returns (q_input, k_input, v_input):
//   q_input : [num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim]
//   k_input : [num_seqs, 1, kv_lora_rank + qk_rope_head_dim]
//   v_input : k_input[..., :kv_lora_rank] (a view sharing k_input's storage)
//
// Stages:
//   1. qa = hidden_states x q_a_proj and k_input = hidden_states x kv_a_proj,
//      fused into one segment gemm (bf16/fp16, int8 w8a8 or fp8 w8a16 weights)
//   2. in-place RMSNorm of qa and of the first kv_lora_rank columns of k_input
//   3. qb = qa x q_b_proj, viewed as [num_seqs, num_heads, qk_head_dim]
//   4. q_input[..., :kv_lora_rank] = bmm of q_nope (first qk_nope_head_dim
//      columns of each head) with w_kc
//   5. RoPE: q_pe (the rest of each qb head) -> q_input[..., kv_lora_rank:],
//      and k_input[..., kv_lora_rank:] rotated in place
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& hidden_states,
    at::Tensor& q_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& kv_a_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size) {
  RECORD_FUNCTION(
      "sgl-kernel::qkv_proj_with_rope",
      std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));

  const auto st = hidden_states.scalar_type();
  CHECK_INPUT(hidden_states);
  CHECK_INPUT(positions);
  CHECK_INPUT(cos_sin_cache);
  CHECK_EQ(q_a_layernorm_weight.scalar_type(), st);
  CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st);
  CHECK_EQ(positions.scalar_type(), at::kLong);
  CHECK_EQ(cos_sin_cache.scalar_type(), st);
  CHECK_DIM(2, hidden_states);
  CHECK_DIM(3, w_kc);
  CHECK_DIM(1, q_a_layernorm_weight);
  CHECK_DIM(1, kv_a_layernorm_weight);
  CHECK_DIM(1, positions);
  CHECK_DIM(2, cos_sin_cache);

  // skip contiguous checks for weights, expect prepacked
  TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!");

  int64_t num_seqs = hidden_states.size(0);
  int64_t hidden_size = hidden_states.size(1);
  int64_t q_lora_rank = q_a_proj_weight.size(0);
  int64_t num_heads = w_kc.size(0);
  int64_t kv_lora_rank = w_kc.size(1);
  int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads;
  int64_t qk_nope_head_dim = w_kc.size(2);
  int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank;
  int64_t rotary_dim = cos_sin_cache.size(1);

  CHECK_EQ(positions.numel(), num_seqs);
  CHECK_EQ(rotary_dim, qk_rope_head_dim);
  CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank);
  CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank);

  // each qb head is [q_nope | q_pe]: stage 4 reads its first qk_nope_head_dim
  // columns and stage 5 reads rotary_dim columns right after them, so the
  // rows of q_b_proj must split evenly into heads of exactly that width.
  CHECK_EQ(q_b_proj_weight.size(0), num_heads * qk_head_dim);
  CHECK_EQ(qk_head_dim, qk_nope_head_dim + qk_rope_head_dim);

  // check the packed dimension
  CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
  CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8));
  CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));

  if (use_int8_w8a8) {
    TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8.");
    TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
    TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
  }
  if (use_fp8_w8a16) {
    TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
    TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
    TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
    TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
  }

  // outputs and temp buffer
  const auto options = hidden_states.options();
  auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
  auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options);
  auto v_input = k_input.narrow(-1, 0, kv_lora_rank);

  // outputs of q_a_proj and q_b_proj
  auto qa = at::empty({num_seqs, q_lora_rank}, options);

  // stage 1: q_a_proj and kv_a_proj
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] {
    if (use_int8_w8a8) {
      auto q_a_proj_s = q_a_proj_scale.value();
      auto kv_a_proj_s = kv_a_proj_scale.value();
      TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank);
      TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim);

      // one buffer: uint8 quantized activations followed by one float scale per row
      auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte));
      uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
      float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size));
      const scalar_t* __restrict__ A_data = hidden_states.data_ptr<scalar_t>();

      at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) {
        for (int64_t m = begin; m < end; ++m) {
          quantize_row_int8<scalar_t>(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size);
        }
      });

      segment_gemm_kernel_impl<scalar_t>(
          qa.data_ptr<scalar_t>(),
          k_input.data_ptr<scalar_t>(),
          Aq_data,
          q_a_proj_weight.data_ptr<int8_t>(),
          kv_a_proj_weight.data_ptr<int8_t>(),
          As_data,
          q_a_proj_s.data_ptr<float>(),
          kv_a_proj_s.data_ptr<float>(),
          num_seqs,
          q_lora_rank,
          kv_lora_rank + qk_rope_head_dim,
          hidden_size);
    } else if (use_fp8_w8a16) {
      int64_t block_size_N = block_size.value()[0];
      int64_t block_size_K = block_size.value()[1];
      auto q_a_proj_s = q_a_proj_scale.value();
      auto kv_a_proj_s = kv_a_proj_scale.value();
      CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
      CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
      CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
      CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));

      // per-thread scratch of BLOCK_N x hidden_size for the segment gemm
      const int BLOCK_N = block_size_n();
      const int num_threads = at::get_num_threads();
      auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options);
      segment_gemm_kernel_impl<scalar_t>(
          qa.data_ptr<scalar_t>(),
          k_input.data_ptr<scalar_t>(),
          hidden_states.data_ptr<scalar_t>(),
          q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
          kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
          q_a_proj_s.data_ptr<float>(),
          kv_a_proj_s.data_ptr<float>(),
          buffer.data_ptr<scalar_t>(),
          num_seqs,
          q_lora_rank,
          kv_lora_rank + qk_rope_head_dim,
          hidden_size,
          block_size_N,
          block_size_K);
    } else {
      segment_gemm_kernel_impl<scalar_t>(
          qa.data_ptr<scalar_t>(),
          k_input.data_ptr<scalar_t>(),
          hidden_states.data_ptr<scalar_t>(),
          q_a_proj_weight.data_ptr<scalar_t>(),
          kv_a_proj_weight.data_ptr<scalar_t>(),
          num_seqs,
          q_lora_rank,
          kv_lora_rank + qk_rope_head_dim,
          hidden_size);
    }
  });

  // stage 2: apply rmsnorm inplace
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] {
    rms_norm_kernel_impl<scalar_t>(
        qa.data_ptr<scalar_t>(),
        v_input.data_ptr<scalar_t>(),
        q_a_layernorm_weight.data_ptr<scalar_t>(),
        kv_a_layernorm_weight.data_ptr<scalar_t>(),
        num_seqs,
        q_lora_rank,
        kv_lora_rank,
        kv_lora_rank + qk_rope_head_dim,
        eps);
  });

  // stage 3: q_b_proj
  // The quantized paths must produce `st` (not a fixed bf16): stages 4 and 5
  // read qb as scalar_t and combine it with q_input, which has dtype `st`.
  at::Tensor qb;
  std::optional<at::Tensor> bias;
  if (use_int8_w8a8) {
    qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, st, is_vnni);
  } else if (use_fp8_w8a16) {
    qb = fp8_scaled_mm_cpu(qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, st, is_vnni);
  } else {
    qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
  }
  qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1});

  // stage 4: bmm
  std::optional<at::Tensor> scale;
  auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1);
  auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1);
  bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale);

  // stage 5: rope
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] {
    rotary_emb_kernel_impl<scalar_t>(
        q_input.data_ptr<scalar_t>() + kv_lora_rank,
        k_input.data_ptr<scalar_t>() + kv_lora_rank,
        qb.data_ptr<scalar_t>() + qk_nope_head_dim,
        k_input.data_ptr<scalar_t>() + kv_lora_rank,
        positions.data_ptr<int64_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        num_seqs,
        num_heads,
        rotary_dim,
        num_heads * qk_head_dim,
        qk_head_dim,
        kv_lora_rank + qk_rope_head_dim,
        num_heads * (kv_lora_rank + qk_rope_head_dim),
        kv_lora_rank + qk_rope_head_dim,
        kv_lora_rank + qk_rope_head_dim);
  });

  return std::make_tuple(q_input, k_input, v_input);
}
// Variant of qkv_proj_with_rope taking `q_a_proj` and `kv_a_proj_with_mqa`
// stacked along dim 0 into a single prepacked weight `qkv_a_proj_weight`
// with q_lora_rank + kv_lora_rank + qk_rope_head_dim rows.
//
// The weight (and, for quantized weights, its scale) is split along dim 0,
// then everything is forwarded to qkv_proj_with_rope:
//   int8 w8a8 : one scale per output row, split as
//               [q_lora_rank, kv_lora_rank + qk_rope_head_dim]
//   fp8 w8a16 : block scales, split as
//               [div_up(q_lora_rank, block_size[0]),
//                div_up(kv_lora_rank + qk_rope_head_dim, block_size[0])]
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
    at::Tensor& hidden_states,
    at::Tensor& qkv_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> qkv_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size,
    int64_t q_lora_rank,
    int64_t kv_lora_rank,
    int64_t qk_rope_head_dim) {
  RECORD_FUNCTION(
      "sgl-kernel::qkv_proj_with_rope_fused_weight",
      std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc}));

  int64_t hidden_size = hidden_states.size(1);
  CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim);
  CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));

  std::vector<at::Tensor> weight_chunks =
      at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
  at::Tensor q_a_proj_weight = weight_chunks[0];
  at::Tensor kv_a_proj_weight = weight_chunks[1];

  at::Tensor q_a_proj_s;
  at::Tensor kv_a_proj_s;
  if (use_int8_w8a8) {
    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8.");
    std::vector<at::Tensor> scale_chunks =
        at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
    q_a_proj_s = scale_chunks[0];
    kv_a_proj_s = scale_chunks[1];
  }
  if (use_fp8_w8a16) {
    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
    // block_size is indexed right here, before qkv_proj_with_rope runs its own
    // checks: validate it first so a missing/short/zero value raises a clear
    // error instead of bad_optional_access, UB, or a division by zero.
    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
    TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
    int64_t block_size_N = block_size.value()[0];
    TORCH_CHECK(block_size_N > 0, "block_size[0] should be positive for fp8 w8a16, got ", block_size_N);
    int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N);
    int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N);
    std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0);
    q_a_proj_s = scale_chunks[0];
    kv_a_proj_s = scale_chunks[1];
  }

  return qkv_proj_with_rope(
      hidden_states,
      q_a_proj_weight,
      q_b_proj_weight,
      kv_a_proj_weight,
      w_kc,
      q_a_layernorm_weight,
      kv_a_layernorm_weight,
      positions,
      cos_sin_cache,
      eps,
      use_int8_w8a8,
      use_fp8_w8a16,
      q_a_proj_s,
      q_b_proj_scale,
      kv_a_proj_s,
      is_vnni,
      block_size);
}
|