#include "common.h"
#include "flash_attn.h"
#include "gemm.h"
namespace {
// [NOTE]: extend attention for CPU
// 1. BLOCK_M and BLOCK_N tuned for various seq lengths
// 2. can handle non-contiguous k_extend and v_extend
// 3. computes attention for prefix and extend separately
// 4. TODO: apply head dimension blocking to optimize GQA
//
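// A rough, scalar sketch of the per-block online-softmax update that both
// stages below rely on (the real flash_attn_softmax in flash_attn.h is
// vectorized; this illustrative version only documents the assumed recurrence
// over the running max m', denominator s' and accumulator V'):
//
//   for (int row = 0; row < m_size; ++row) {
//     float m_new = m_prime[row];
//     for (int col = 0; col < n_size; ++col)
//       m_new = std::max(m_new, s_i[row * BLOCK_N + col] * sm_scale);
//     float scale = std::exp(m_prime[row] - m_new);  // rescale previous state
//     s_prime[row] *= scale;
//     for (int d = 0; d < head_size_v; ++d)
//       v_prime[row * head_size_v + d] *= scale;
//     for (int col = 0; col < n_size; ++col) {
//       float p = std::exp(s_i[row * BLOCK_N + col] * sm_scale - m_new);
//       s_prime[row] += p;
//       s_delta[row * BLOCK_N + col] = static_cast<scalar_t>(p);  // 2nd gemm A
//     }
//     // columns [n_size, padded_n_size) of s_delta are expected to be zeroed
//     // so the padded V gemm adds nothing
//     m_prime[row] = m_new;
//   }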
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
void extend_attention_kernel_impl(
scalar_t* __restrict__ o_extend,
const scalar_t* __restrict__ q_extend,
const scalar_t* __restrict__ k_extend,
const scalar_t* __restrict__ v_extend,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
const index_t* __restrict__ extend_seq_lens,
const index_t* __restrict__ extend_start_loc,
const void* __restrict__ buffer,
int batches,
int num_heads,
int num_heads_kv,
int head_size,
int head_size_v,
int q_strideM,
int q_strideH,
int ke_strideN,
int ke_strideH,
int ve_strideN,
int ve_strideH,
int k_strideN,
int k_strideH,
int v_strideN,
int v_strideH,
float sm_scale,
int max_num_reqs,
int max_context_len,
int max_total_num_tokens,
int max_len_extend,
int buffer_size_per_thread,
bool is_prefix_skipped) {
// strides
const int o_strideM = num_heads * head_size_v;
const int o_strideH = head_size_v;
// we use the same buffer for the packed key and value
const int ldb_tmp = std::max(head_size, head_size_v);
const int num_groups = num_heads / num_heads_kv;
TORCH_CHECK(num_groups * num_heads_kv == num_heads, "num_heads (", num_heads, ") must be divisible by num_heads_kv (", num_heads_kv, ")");
// number of blocks along M
int MB = div_up(max_len_extend, BLOCK_M);
// parallel on [batches, num_heads, BM]
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
int bs{0}, head_id{0}, mb{0};
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
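// data_index_init / data_index_step unflatten the flat work index in row-major
// order over (bs, head_id, mb). For example, with hypothetical sizes
// batches = 2, num_heads = 8, MB = 4, flat index 29 maps to
// bs = 0, head_id = 7, mb = 1 (29 = 0 * 32 + 7 * 4 + 1).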
int tid = at::get_thread_num();
// s_i and s_delta: [BLOCK_M, BLOCK_N]
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
scalar_t* __restrict__ s_delta = reinterpret_cast<scalar_t*>(s_i);
// v_prime: [BLOCK_M, head_size_v]
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
scalar_t* __restrict__ Btmp = reinterpret_cast<scalar_t*>(v_prime + BLOCK_M * head_size_v);
// zero-init Btmp once per thread so its padded region can never inject NaN into the gemms
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
alignas(64) float s_prime[BLOCK_M];
alignas(64) float m_prime[BLOCK_M];
for (int i = begin; i < end; ++i) {
// seq_len = prefix + extend
int head_kv_id = head_id / num_groups;
int seq_len = seq_lens[bs];
int seq_len_extend = extend_seq_lens[bs];
int seq_len_prefix = seq_len - seq_len_extend;
int seq_extend_start_loc = extend_start_loc[bs];
int req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of range!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of range!");
if (is_prefix_skipped) {
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
}
// offset and size in MB
int m = mb * BLOCK_M;
int m_size = std::min(BLOCK_M, seq_len_extend - m);
if (m_size <= 0) {
data_index_step(bs, batches, head_id, num_heads, mb, MB);
continue;
}
// get query
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
// init v', s' and m'
fill_stub(v_prime, 0.f, m_size * head_size_v);
fill_stub(s_prime, 0.f, m_size);
fill_stub(m_prime, -std::numeric_limits<scalar_t>::infinity(), m_size);
// stage 1: compute scores with prefix
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
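// e.g. assuming TILE_K is 32 (a typical tile depth for 16-bit types),
// n_size = 37 pads to 64 while n_size = 64 stays 64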
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_buffer + head_kv_id * k_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ k_strideN,
/* ld_dst */ BLOCK_N);
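// pack_vnni gathers the n_size prefix keys through req_to_token and repacks
// them into the 2-element-interleaved (VNNI) layout that brgemm expects for
// 16-bit B operands, using BLOCK_N as the packed leading dimension.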
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
flash_attn_softmax<scalar_t, BLOCK_M, BLOCK_N>::apply(
s_i, s_delta, v_prime, s_prime, m_prime, m_size, n_size, padded_n_size, head_size_v, sm_scale);
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_buffer + head_kv_id * v_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ v_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_prefix
// stage 2: compute the triangle part
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
for (int n = 0; n < num_keys; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, num_keys - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t>(
/* dst */ Btmp,
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ ke_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
// apply causal mask to any key block that crosses the diagonal
if (n + BLOCK_N > m) {
for (int row = 0; row < m_size; ++row) {
// last visible key column for query position (m + row) in this block
int last_col = m + row - n;
if (last_col + 1 >= n_size) {
continue;  // whole block already visible for this row
}
// fill [max(last_col + 1, 0), n_size) to -inf
int start = std::max(last_col + 1, 0);
float* row_ptr = s_i + row * BLOCK_N;
fill_stub(row_ptr + start, -std::numeric_limits<float>::infinity(), n_size - start);
}
}
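// Worked example of the mask above: with m = 64, n = 64 and n_size = 32,
// row 0 (query position 64) keeps only column 0 (key 64) and columns 1..31
// become -inf; row 5 keeps columns 0..5; row 31 keeps the whole block.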
flash_attn_softmax<scalar_t, BLOCK_M, BLOCK_N>::apply(
s_i, s_delta, v_prime, s_prime, m_prime, m_size, n_size, padded_n_size, head_size_v, sm_scale);
// get value and pack
pack_vnni2<scalar_t>(
/* dst */ Btmp,
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ ve_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_extend
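// Final rescale: assuming the standard online-softmax recurrence sketched at
// the top of this file, v_prime row i now holds sum_j exp(s_ij - m_i) * V_j and
// s_prime[i] holds sum_j exp(s_ij - m_i), so the row-wise division below
// recovers softmax(sm_scale * Q K^T) @ V for this [m_size, head_size_v] tile.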
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
for (int row = 0; row < m_size; ++row) {
float s = 1 / s_prime[row];
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
}
// move to the next index
data_index_step(bs, batches, head_id, num_heads, mb, MB);
}
at::native::cpublas::brgemm_release();
});
}
} // anonymous namespace
template <int BLOCK_M, int BLOCK_N>
inline int resize_buffer(at::Tensor& buffer, int num_threads, int head_size, int head_size_v) {
static_assert(BLOCK_M <= BLOCK_N, "Make sure BLOCK_M <= BLOCK_N to prevent buffer overflows during causal masking");
const int size_per_thread =
/* s_i */ BLOCK_M * BLOCK_N * sizeof(float) +
/* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
/* Btmp */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);
buffer.resize_({num_threads, size_per_thread});
return size_per_thread;
}
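// For example, with BLOCK_M = 32, BLOCK_N = 64 and hypothetical
// head_size = head_size_v = 128, each thread gets
//   s_i:     32 * 64  * 4 B =  8192 B
//   v_prime: 32 * 128 * 4 B = 16384 B
//   Btmp:    64 * 128 * 2 B = 16384 B
// i.e. 40 KiB of scratch per thread for the smallest launch configuration.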
#define LAUNCH_EXTEND_ATTENTION_KERNEL(BLOCK_M, BLOCK_N) \
do { \
int sz = resize_buffer<BLOCK_M, BLOCK_N>(buffer, num_threads, head_size, head_size_v); \
\
extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>( \
o_extend.data_ptr<scalar_t>(), \
q_extend.data_ptr<scalar_t>(), \
k_extend.data_ptr<scalar_t>(), \
v_extend.data_ptr<scalar_t>(), \
k_buffer.data_ptr<scalar_t>(), \
v_buffer.data_ptr<scalar_t>(), \
req_to_token.data_ptr<index_t>(), \
req_pool_indices.data_ptr<int64_t>(), \
seq_lens.data_ptr<int64_t>(), \
extend_seq_lens.data_ptr<index_t>(), \
extend_start_loc.data_ptr<index_t>(), \
buffer.data_ptr(), \
num_seqs, \
num_heads, \
num_heads_kv, \
head_size, \
head_size_v, \
q_strideM, \
q_strideH, \
ke_strideN, \
ke_strideH, \
ve_strideN, \
ve_strideH, \
k_strideN, \
k_strideH, \
v_strideN, \
v_strideH, \
sm_scale, \
max_num_reqs, \
max_context_len, \
max_total_num_tokens, \
max_len_extend, \
sz, \
is_prefix_skipped); \
} while (0)
// q_extend, k_extend, v_extend: last-dim contiguous tensors; o_extend: contiguous tensor
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
//
// q_extend: [num_tokens, num_heads, head_size]
// k_extend: [num_extend_tokens, num_heads_kv, head_size]
// v_extend: [num_extend_tokens, num_heads_kv, head_size_v]
// o_extend: [num_tokens, num_heads, head_size_v]
// k_buffer: [max_total_num_tokens, num_heads_kv, head_size]
// v_buffer: [max_total_num_tokens, num_heads_kv, head_size_v]
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
// req_pool_indices: [num_seqs] int64
// seq_lens: [num_seqs] int64
// extend_seq_lens: [num_seqs] same dtype as req_to_token
// extend_start_loc: [num_seqs] same dtype as req_to_token
//
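// For example, with hypothetical values num_seqs = 2, seq_lens = {10, 7} and
// extend_seq_lens = {4, 3}, the prefix lengths are 6 and 4, and extend_start_loc
// would typically be {0, 4}, i.e. the exclusive prefix sum of extend_seq_lens
// giving each request's offset into the packed extend tokens.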
void extend_attention_cpu(
at::Tensor& q_extend,
at::Tensor& k_extend,
at::Tensor& v_extend,
at::Tensor& o_extend,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
at::Tensor& extend_seq_lens,
at::Tensor& extend_start_loc,
int64_t max_len_extend,
double sm_scale,
double logit_cap) {
RECORD_FUNCTION(
"sgl-kernel::extend_attention_cpu",
std::vector<c10::IValue>(
{q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_token,
req_pool_indices,
seq_lens,
extend_seq_lens,
extend_start_loc,
max_len_extend}));
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
CHECK_INPUT(o_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
int num_seqs = seq_lens.size(0);
int max_num_reqs = req_to_token.size(0);
int max_context_len = req_to_token.size(1);
int max_total_num_tokens = k_buffer.size(0);
int num_heads = q_extend.size(1);
int num_heads_kv = k_extend.size(1);
int head_size = q_extend.size(2);
int head_size_v = v_extend.size(2);
// strides for q_extend, k_extend and v_extend
int q_strideM = q_extend.stride(0);
int q_strideH = q_extend.stride(1);
int ke_strideN = k_extend.stride(0);
int ke_strideH = k_extend.stride(1);
int ve_strideN = v_extend.stride(0);
int ve_strideH = v_extend.stride(1);
// strides for k_buffer and v_buffer
int k_strideN = k_buffer.stride(0);
int k_strideH = k_buffer.stride(1);
int v_strideN = v_buffer.stride(0);
int v_strideH = v_buffer.stride(1);
// check sizes
CHECK_EQ(req_pool_indices.size(0), num_seqs);
CHECK_EQ(extend_seq_lens.size(0), num_seqs);
CHECK_EQ(extend_start_loc.size(0), num_seqs);
CHECK_EQ(v_extend.size(1), num_heads_kv);
CHECK_EQ(k_buffer.size(1), v_buffer.size(1));
// MLA skips the prefix part (its kv buffers carry a different head count)
const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;
// check index data types
const auto index_dtype = req_to_token.scalar_type();
TORCH_CHECK(
index_dtype == at::kInt || index_dtype == at::kLong,
"extend: expect req_to_token to be int32 or int64, got ",
index_dtype);
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect seq_lens to be int64, got ", seq_lens.scalar_type());
TORCH_CHECK(
req_pool_indices.scalar_type() == at::kLong,
"extend: expect req_pool_indices to be int64, got ",
req_pool_indices.scalar_type());
TORCH_CHECK(
extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
"extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");
// head_size (D) and head_size_v (DV) must be multiples of 32 since the pack routines transpose 512 bits (32 x 16-bit elements) at a time
TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);
int num_threads = at::get_num_threads();
auto buffer = at::empty({}, q_extend.options().dtype(at::kChar));
AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
if (max_len_extend <= 256) {
LAUNCH_EXTEND_ATTENTION_KERNEL(32, 64);
} else if (max_len_extend <= 1024) {
LAUNCH_EXTEND_ATTENTION_KERNEL(128, 256);
} else if (max_len_extend <= 4096) {
LAUNCH_EXTEND_ATTENTION_KERNEL(256, 768);
} else { // max_len_extend > 4096
LAUNCH_EXTEND_ATTENTION_KERNEL(512, 768);
}
});
});
}