| import unittest |
|
|
| import torch |
| from torch.nn.functional import scaled_dot_product_attention |
|
|
| from sglang.test.test_utils import CustomTestCase |
|
|
# Seed torch's global RNG once at import so the randomly generated lengths,
# K/V buffers and queries used by the tests below are reproducible.
torch.manual_seed(1234)
|
|
|
|
class TestExtendAttention(CustomTestCase):
    """Check ``torch.ops.sgl_kernel.extend_attention_cpu`` against an SDPA reference.

    Each case builds a random batch of requests (cached prefix + new "extend"
    tokens), runs the CPU kernel and a per-request
    ``scaled_dot_product_attention`` reference, and compares the outputs.
    """

    def _run_sdpa_forward_extend(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        extend_prefix_lens: torch.Tensor,
        extend_seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference extend attention computed request-by-request with SDPA.

        For request ``i``, the ``extend_seq_lens[i]`` new query tokens attend
        over the request's full KV history of length ``seq_lens[i]``. That
        history is gathered from ``k_cache``/``v_cache`` at the slots
        ``req_to_token[req_pool_indices[i], :seq_lens[i]]``.

        Args:
            query: extend-token queries with the token axis first. The tokens of
                all requests are packed back to back. The caller in this class
                passes shape (tokens, heads, head_dim); other ranks are untested.
            output: preallocated result tensor with the same token-first
                packing as ``query``. It is written in place.
            k_cache, v_cache: token-first K/V buffers indexed by slot id.
            req_to_token: per-request table mapping position -> buffer slot.
            req_pool_indices: row of ``req_to_token`` to use for each request.
            seq_lens: total (prefix + extend) KV length per request.
            extend_prefix_lens: cached prefix length per request.
            extend_seq_lens: number of new query tokens per request.
            scaling: softmax scale forwarded to SDPA (None -> SDPA default).
            enable_gqa: forwarded to SDPA to allow fewer KV heads than Q heads.
            causal: apply a causal mask.

        Returns:
            ``output`` (the same tensor that was filled in place).
        """
        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
        assert seq_lens.shape[0] == extend_seq_lens.shape[0]

        # Move the token axis (dim 0) next to the last axis, which gives the
        # (..., heads, tokens, head_dim) layout SDPA expects.
        token_axis = query.dim() - 2
        query = query.movedim(0, token_axis)

        q_start = 0
        for i in range(seq_lens.shape[0]):
            prefix_len = extend_prefix_lens[i]
            kv_len = seq_lens[i]
            q_end = q_start + extend_seq_lens[i]

            req_query = query[:, q_start:q_end, :]
            # Pad the query up to the full KV length and place the real extend
            # tokens at the end. With query length == KV length, the causal mask
            # lets extend token j see the whole prefix plus extend tokens 0..j,
            # which is the extend-attention mask. The padding rows are
            # zero-filled (never uninitialized memory) and their outputs are
            # dropped below.
            padded_query = torch.zeros(
                (req_query.shape[0], kv_len, req_query.shape[2]),
                dtype=req_query.dtype,
                device=req_query.device,
            )
            padded_query[:, prefix_len:, :] = req_query

            # Gather this request's K/V history via its slot table.
            slot_ids = req_to_token[req_pool_indices[i], :kv_len]
            req_key = k_cache[slot_ids].movedim(0, token_axis)
            req_value = v_cache[slot_ids].movedim(0, token_axis)

            attn = scaled_dot_product_attention(
                padded_query.unsqueeze(0),
                req_key.unsqueeze(0),
                req_value.unsqueeze(0),
                enable_gqa=enable_gqa,
                scale=scaling,
                is_causal=causal,
            )
            req_out = attn.squeeze(0).movedim(token_axis, 0)

            # Keep only the rows that correspond to real extend tokens.
            output[q_start:q_end, :, :] = req_out[prefix_len:, :, :]
            q_start = q_end
        return output

    def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
        """Run one randomized case and assert the kernel matches the reference.

        Args:
            B: number of requests in the batch.
            N_CTX: prefix and extend lengths are each drawn from
                ``[1, N_CTX // 2)``.
            H_Q: number of query heads.
            H_KV: number of heads in the extend K/V tensors given to the kernel.
            D: query/key head dim.
            DV: value head dim.
            mla: if True, prefixes are forced to length 0 and the K/V buffers
                hold a single head, which is broadcast to ``H_KV`` heads when
                copied into the extend K/V tensors.
        """
        dtype = torch.bfloat16

        # Random per-request prefix / extend lengths (the prefix draw happens
        # even for MLA so the RNG sequence does not depend on ``mla``).
        b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        if mla:
            b_seq_len_prefix.zero_()
        b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
        b_seq_len = b_seq_len_prefix + b_seq_len_extend
        max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

        # Start offsets of each request in the packed KV buffer and in the
        # packed extend tensors.
        b_req_idx = torch.arange(B, dtype=torch.int32)
        req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
        b_start_loc = torch.zeros((B,), dtype=torch.int32)
        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

        # Request i owns the contiguous buffer slots
        # [b_start_loc[i], b_start_loc[i] + b_seq_len[i]).
        for i in range(B):
            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
            )

        total_token_num = torch.sum(b_seq_len).item()
        extend_token_num = torch.sum(b_seq_len_extend).item()

        # Number of heads actually stored in the KV buffers.
        H_BUF = 1 if mla else H_KV
        k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
        v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)

        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
        v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)

        # Extend K/V are copies of the tail (non-prefix) part of each request's
        # buffer slots; queries are fresh random values.
        for i in range(B):
            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
            extend_start = b_start_loc_extend[i]
            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
            k_extend[extend_start:extend_end] = k_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            v_extend[extend_start:extend_end] = v_buffer[
                extend_start_in_buffer:extend_end_in_buffer
            ]
            q_extend[extend_start:extend_end] = torch.randn(
                (b_seq_len_extend[i], H_Q, D), dtype=dtype
            )

        # Same shapes and values, but head-major strides (non-contiguous in the
        # token-first view). Presumably this checks that the kernel handles
        # strided inputs.
        q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
        v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)

        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0

        b_req_idx = b_req_idx.to(torch.int64)
        b_seq_len = b_seq_len.to(torch.int64)

        # The reference reads K/V from k_buffer/v_buffer, which hold H_BUF heads
        # (1 in MLA mode). GQA is needed whenever H_Q differs from that count,
        # not from H_KV.
        enable_gqa = H_Q != H_BUF
        o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        self._run_sdpa_forward_extend(
            q_extend,
            o_ref,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_prefix,
            b_seq_len_extend,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
            causal=True,
        )

        o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
        torch.ops.sgl_kernel.extend_attention_cpu(
            q_extend,
            k_extend,
            v_extend,
            o_extend,
            k_buffer,
            v_buffer,
            req_to_tokens,
            b_req_idx,
            b_seq_len,
            b_seq_len_extend,
            b_start_loc_extend,
            max_len_extend,
            sm_scale,
            logit_cap,
        )

        torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)

    def test_extend_attention(self):
        """Sweep batch size, context length, head counts and head dims, with and without MLA."""
        # (B, N_CTX, H_Q, H_KV, D, DV)
        cases = (
            (1, 123, 1, 1, 128, 96),
            (1, 123, 16, 1, 128, 96),
            (4, 1230, 16, 4, 128, 96),
            (1, 9000, 16, 1, 32, 32),
        )
        for is_mla in (True, False):
            for case in cases:
                self._test_extend_attention_once(*case, is_mla)
|
|
|
|
# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()
|
|