| import unittest |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.nn.functional import softplus |
| from utils import precision |
|
|
| from sglang.test.test_utils import CustomTestCase |
|
|
# Seed PyTorch's global RNG at import time so the random tensors built by the
# tests in this module are the same from run to run.
torch.manual_seed(1234)
|
|
|
|
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """Rescale ``x`` to (approximately) unit L2 norm along ``dim``.

    This is intended to align with the l2norm implementation in the FLA
    library: ``eps`` is added to the sum of squares before taking the
    reciprocal square root, so an all-zero slice yields zeros rather than NaN.

    Args:
        x: Tensor to normalise.
        dim: Dimension along which the norm is taken.
        eps: Small constant added to the squared norm.

    Returns:
        A tensor with the same shape as ``x``.
    """
    squared_norm = torch.sum(x * x, dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_norm + eps)
|
|
|
|
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Reference gated delta rule, evaluated chunk by chunk in plain PyTorch.

    Every input has dims 1 and 2 swapped on entry and is upcast to float32;
    chunking then runs along dim 2 of the swapped tensors. The callers in this
    file pass ``query``/``key``/``value`` as (batch, tokens, heads, head_dim)
    and ``g``/``beta`` as (batch, tokens, heads), so the chunked axis is the
    token axis.

    Args:
        query: Query tensor; scaled by ``1 / sqrt(last_dim)`` internally.
        key: Key tensor.
        value: Value tensor.
        g: Gate in log space (it is cumulatively summed, then exponentiated).
        beta: Per-token factor multiplied into both ``key`` and ``value``.
        chunk_size: Number of tokens per chunk; the token axis is zero-padded
            up to a multiple of this.
        initial_state: Starting recurrent state, or ``None`` for zeros.
        output_final_state: When false, ``None`` is returned in place of the
            final state.
        use_qk_l2norm_in_kernel: When true, ``query`` and ``key`` are passed
            through ``l2norm`` over their last dim first.

    Returns:
        ``(output, final_state)``. ``output`` is in the input layout, trimmed
        to the unpadded token count and cast back to ``query``'s original
        dtype. ``final_state`` is the float32 state after the last chunk, or
        ``None``.
    """
    result_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)

    def _swap_and_upcast(t):
        # Swap dims 1 and 2 and do all arithmetic in float32.
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _swap_and_upcast(query)
    key = _swap_and_upcast(key)
    value = _swap_and_upcast(value)
    beta = _swap_and_upcast(beta)
    g = _swap_and_upcast(g)

    n_batch, n_heads, n_tokens, k_dim = key.shape
    v_dim = value.shape[-1]

    # Zero-pad the token axis up to a whole number of chunks.
    n_pad = (chunk_size - n_tokens % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, n_pad))
    key = F.pad(key, (0, 0, 0, n_pad))
    value = F.pad(value, (0, 0, 0, n_pad))
    beta = F.pad(beta, (0, n_pad))
    g = F.pad(g, (0, n_pad))
    n_chunks = (n_tokens + n_pad) // chunk_size

    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    beta_col = beta.unsqueeze(-1)
    v_beta = value * beta_col
    k_beta = key * beta_col

    def _split_chunks(t):
        # (..., tokens, dim) -> (..., n_chunks, chunk_size, dim)
        return t.reshape(t.shape[0], t.shape[1], -1, chunk_size, t.shape[-1])

    query = _split_chunks(query)
    key = _split_chunks(key)
    k_beta = _split_chunks(k_beta)
    v_beta = _split_chunks(v_beta)
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    # Within each chunk g becomes a running sum, and decay_mask[i, j] holds
    # exp(g_i - g_j) on and below the diagonal (zero above it).
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()

    on_or_above_diag = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
        diagonal=0,
    )
    tri = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(
        on_or_above_diag, 0
    )
    # Row-by-row update: each row folds in the rows already finalised above
    # it. Clones are taken because ``tri`` is written in place.
    for r in range(1, chunk_size):
        head = tri[..., r, :r].clone()
        block = tri[..., :r, :r].clone()
        tri[..., r, :r] = head + (head.unsqueeze(-1) * block).sum(-2)
    tri = tri + torch.eye(chunk_size, dtype=tri.dtype, device=tri.device)

    chunk_values = tri @ v_beta
    k_cumdecay = tri @ (k_beta * g.exp().unsqueeze(-1))

    if initial_state is None:
        state = torch.zeros(n_batch, n_heads, k_dim, v_dim).to(chunk_values)
    else:
        state = initial_state.to(chunk_values)
    out = torch.zeros_like(chunk_values)
    above_diag = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
        diagonal=1,
    )

    # Walk the chunks in order, carrying the recurrent state between them.
    for c in range(0, n_chunks):
        q_c, k_c, v_c = query[:, :, c], key[:, :, c], chunk_values[:, :, c]
        g_c = g[:, :, c]
        local = (q_c @ k_c.transpose(-1, -2) * decay_mask[:, :, c]).masked_fill_(
            above_diag, 0
        )
        v_new = v_c - (k_cumdecay[:, :, c]) @ state
        carried = (q_c * g_c[..., None].exp()) @ state
        out[:, :, c] = carried + local @ v_new
        state = (
            state * g[:, :, c, -1, None, None].exp()
            + (k_c * (g[:, :, c, -1, None] - g_c).exp()[..., None]).transpose(-1, -2)
            @ v_new
        )

    final_state = state if output_final_state else None

    # Merge the chunk axes back together, drop the padding, restore the layout.
    out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1])
    out = out[:, :, :n_tokens]
    out = out.transpose(1, 2).contiguous().to(result_dtype)
    return out, final_state
|
|
|
|
def chunk_gated_delta_rule_update(
    query,
    key,
    value,
    g,
    beta,
    cu_seqlens,
    initial_state,
    use_qk_l2norm_in_kernel,
):
    """Apply the chunked reference separately to each packed sequence.

    Dim 1 of the inputs holds several sequences back to back. Sequence ``i``
    ends at ``cu_seqlens[i + 1]`` (sequence 0 starts at offset 0) and starts
    from ``initial_state[i]``; the number of sequences is
    ``initial_state.shape[0]``. When ``value`` has more heads (dim 2) than
    ``query``/``key``, each query/key head is repeated so the counts match.

    Returns:
        ``(output, final_state)``, allocated like ``value`` and
        ``initial_state`` respectively and filled one sequence at a time.
    """
    qk_heads = query.shape[2]
    v_heads = value.shape[2]
    heads_per_group = v_heads // qk_heads
    if heads_per_group > 1:
        query = query.repeat_interleave(heads_per_group, dim=2)
        key = key.repeat_interleave(heads_per_group, dim=2)

    output = torch.empty_like(value)
    final_state = torch.empty_like(initial_state)

    seq_start = 0
    for seq_idx in range(initial_state.shape[0]):
        seq_end = cu_seqlens[seq_idx + 1]
        tokens = slice(seq_start, seq_end)
        seq_out, seq_state = torch_chunk_gated_delta_rule(
            query=query[:, tokens],
            key=key[:, tokens],
            value=value[:, tokens],
            g=g[:, tokens],
            beta=beta[:, tokens],
            initial_state=initial_state[seq_idx],
            output_final_state=True,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        output[:, tokens] = seq_out
        final_state[seq_idx] = seq_state
        seq_start = seq_end
    return output, final_state
|
|
|
|
def torch_recurrent_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    initial_state,
    output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """Reference gated delta rule, evaluated one step at a time.

    Every input has dims 1 and 2 swapped on entry and is upcast to float32;
    the recurrence then iterates over dim 2 of the swapped tensors (dim 1 of
    the tensors as passed in).

    Per step the state is first decayed by ``exp(g)``, then corrected by the
    outer product of the key with ``(value - state-read-by-key) * beta``, and
    finally read out with the scaled query.

    Args:
        query: Query tensor; scaled by ``1 / sqrt(last_dim)`` internally.
        key: Key tensor.
        value: Value tensor.
        g: Gate in log space (exponentiated per step).
        beta: Per-step factor applied to the state correction.
        initial_state: Starting recurrent state, or ``None`` for zeros.
        output_final_state: When false, ``None`` is returned in place of the
            final state.
        use_qk_l2norm_in_kernel: When true, ``query`` and ``key`` are passed
            through ``l2norm`` over their last dim first.

    Returns:
        ``(output, final_state)``. ``output`` is in the input layout and cast
        back to ``query``'s original dtype; ``final_state`` is the float32
        state after the last step, or ``None``.
    """
    result_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)

    def _swap_and_upcast(t):
        # Swap dims 1 and 2 and do all arithmetic in float32.
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _swap_and_upcast(query)
    key = _swap_and_upcast(key)
    value = _swap_and_upcast(value)
    beta = _swap_and_upcast(beta)
    g = _swap_and_upcast(g)

    n_batch, n_heads, n_steps, k_dim = key.shape
    v_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    out = torch.zeros(n_batch, n_heads, n_steps, v_dim).to(value)
    if initial_state is None:
        state = torch.zeros(n_batch, n_heads, k_dim, v_dim).to(value)
    else:
        state = initial_state.to(value)

    for t in range(n_steps):
        q_t, k_t, v_t = query[:, :, t], key[:, :, t], value[:, :, t]
        decay_t = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, t].unsqueeze(-1)

        # Decay first, then read the decayed state with the key.
        state = state * decay_t
        predicted = (state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - predicted) * beta_t
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        out[:, :, t] = (state * q_t.unsqueeze(-1)).sum(dim=-2)

    final_state = state if output_final_state else None
    out = out.transpose(1, 2).contiguous().to(result_dtype)
    return out, final_state
|
|
|
|
def sigmoid_gating_delta_rule_update(
    query,
    key,
    value,
    A_log,
    a,
    dt_bias,
    b,
    initial_state,
    output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """Derive the gates from raw parameters, then run the recurrent reference.

    ``beta`` is ``sigmoid(b)`` and the log-space gate is
    ``-exp(A_log) * softplus(a + dt_bias)`` (computed in float32). Both get a
    new leading dimension before being handed to
    ``torch_recurrent_gated_delta_rule``, whose result is returned unchanged.
    """
    gate_beta = b.sigmoid()
    decay_rate = A_log.float().exp()
    log_gate = -decay_rate * softplus(a.float() + dt_bias)
    return torch_recurrent_gated_delta_rule(
        query=query,
        key=key,
        value=value,
        g=log_gate.unsqueeze(0),
        beta=gate_beta.unsqueeze(0),
        initial_state=initial_state,
        output_final_state=output_final_state,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )
|
|
|
|
def torch_gdn_gating(A_log, a, b, dt_bias):
    """Reference gating computation.

    Returns:
        ``(g, beta)`` where ``g = -exp(A_log) * softplus(a + dt_bias)``
        (computed in float32) and ``beta = sigmoid(b)``; both carry an extra
        leading dimension of size 1.
    """
    pre_activation = a.float() + dt_bias
    g = -A_log.float().exp() * softplus(pre_activation).unsqueeze(0)
    beta = b.sigmoid().unsqueeze(0)
    return g, beta
|
|
|
|
class TestMambaAttention(CustomTestCase):
    """Compare the ``sgl_kernel`` CPU gated-delta-rule ops with the PyTorch
    reference implementations defined in this module."""

    def test_chunk_gated_delta_rule(self):
        """``chunk_gated_delta_rule_cpu`` vs. the per-sequence chunked reference."""
        batch, max_len, qk_heads, v_heads, k_dim, v_dim, num_seqs = (
            1,
            100,
            3,
            6,
            64,
            64,
            4,
        )

        # Random sequence lengths with a leading 0, turned into cumulative
        # offsets into the packed token axis.
        seq_lens = torch.randint(1, max_len, (num_seqs + 1,))
        seq_lens[0] = 0
        cu_seqlens_ref = torch.cumsum(seq_lens, dim=0).to(torch.int32)
        total_tokens = cu_seqlens_ref[-1].item()

        # Built in this order so the RNG draws stay reproducible.
        ref_inputs = {
            "query": torch.rand(
                (batch, total_tokens, qk_heads, k_dim), dtype=torch.bfloat16
            )
            * 0.05,
            "key": torch.rand(
                (batch, total_tokens, qk_heads, k_dim), dtype=torch.bfloat16
            )
            * 0.05,
            "value": torch.rand(
                (batch, total_tokens, v_heads, v_dim), dtype=torch.bfloat16
            )
            * 0.05,
            "g": torch.rand((batch, total_tokens, v_heads), dtype=torch.float32)
            * 0.05,
            "beta": torch.rand((batch, total_tokens, v_heads), dtype=torch.bfloat16)
            * 0.05,
            "initial_state": torch.rand(
                (num_seqs, v_heads, k_dim, v_dim), dtype=torch.float32
            )
            * 0.05,
        }

        for l2norm_flag in (True, False):
            expected_out, expected_state = chunk_gated_delta_rule_update(
                cu_seqlens=cu_seqlens_ref,
                use_qk_l2norm_in_kernel=l2norm_flag,
                **ref_inputs,
            )

            # The kernel receives clones, so the reference inputs are never
            # shared with it.
            kernel_inputs = {name: t.clone() for name, t in ref_inputs.items()}
            actual_out, actual_state = torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu(
                output_final_state=True,
                cu_seqlens=cu_seqlens_ref.clone(),
                head_first=False,
                use_qk_l2norm_in_kernel=l2norm_flag,
                **kernel_inputs,
            )

            tol = precision[actual_out.dtype]
            torch.testing.assert_close(actual_out, expected_out, atol=tol, rtol=tol)
            torch.testing.assert_close(
                actual_state, expected_state, atol=tol, rtol=tol
            )

    def test_fused_gdn_gating(self):
        """``fused_gdn_gating_cpu`` vs. ``torch_gdn_gating`` for two head counts."""
        for dim in (6, 32):
            A_log = torch.rand(dim)
            a = torch.rand(1024, dim, dtype=torch.bfloat16)
            b = torch.rand(1024, dim, dtype=torch.bfloat16)
            dt_bias = torch.rand(dim, dtype=torch.bfloat16)

            g_ref, beta_ref = torch_gdn_gating(A_log, a, b, dt_bias)
            g_kernel, beta_kernel = torch.ops.sgl_kernel.fused_gdn_gating_cpu(
                A_log, a, b, dt_bias
            )

            # Tolerances are looked up separately for each output's dtype.
            g_tol = precision[g_ref.dtype]
            beta_tol = precision[beta_ref.dtype]
            torch.testing.assert_close(g_ref, g_kernel, atol=g_tol, rtol=g_tol)
            torch.testing.assert_close(
                beta_ref, beta_kernel, atol=beta_tol, rtol=beta_tol
            )

    def test_fused_sigmoid_gating_delta_rule_update(self):
        """``fused_sigmoid_gating_delta_rule_update_cpu`` vs. the recurrent reference."""
        batch_size = 1
        seq_len = 1
        attn_tp_size = 1
        num_heads = 16
        num_value_heads = 32
        head_k_dim = 128
        head_v_dim = 128

        key_dim = head_k_dim * num_heads
        value_dim = head_v_dim * num_value_heads
        qk_width = key_dim // attn_tp_size
        v_width = value_dim // attn_tp_size
        mixed_qkv_dim = (key_dim * 2 + value_dim) // attn_tp_size

        # One fused projection that is split into query / key / value views.
        mixed_qkv = torch.rand(
            seq_len * batch_size, mixed_qkv_dim, dtype=torch.bfloat16
        )
        query, key, value = torch.split(
            mixed_qkv, [qk_width, qk_width, v_width], dim=-1
        )
        query = query.view(1, seq_len, num_heads, head_k_dim)
        key = key.view(1, seq_len, num_heads, head_k_dim)
        value = value.view(1, seq_len, num_value_heads, head_v_dim)

        A_log = torch.rand(num_value_heads, dtype=torch.float32)
        a = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16)
        b = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16)
        dt_bias = torch.rand(num_value_heads, dtype=torch.bfloat16)
        ssm_states = torch.rand(
            513, num_value_heads, head_k_dim, head_v_dim, dtype=torch.float32
        )
        cache_indices = torch.randint(0, 513, (batch_size,), dtype=torch.int32)
        query_start_loc = torch.tensor([0, 1], dtype=torch.int32)
        use_qk_l2norm_in_kernel = True

        # The reference needs query/key expanded to one head per value head.
        heads_per_group = num_value_heads // num_heads
        query_ref = query.clone()
        key_ref = key.clone()
        if heads_per_group > 1:
            query_ref = query_ref.repeat_interleave(heads_per_group, dim=2)
            key_ref = key_ref.repeat_interleave(heads_per_group, dim=2)

        # Advanced indexing copies the selected states, so the reference
        # starts from the values held before the kernel call below.
        expected_out, expected_state = sigmoid_gating_delta_rule_update(
            query_ref.transpose(0, 1),
            key_ref.transpose(0, 1),
            value.transpose(0, 1),
            A_log,
            a,
            dt_bias,
            b,
            initial_state=ssm_states[cache_indices],
            output_final_state=True,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

        actual_out = torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu(
            A_log=A_log,
            dt_bias=dt_bias,
            q=query,
            k=key,
            v=value,
            a=a,
            b=b,
            initial_state_source=ssm_states,
            initial_state_indices=cache_indices,
            cu_seqlens=query_start_loc,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            softplus_beta=1.0,
            softplus_threshold=20.0,
        )
        # The final state is read back from the state pool after the call.
        actual_state = ssm_states[cache_indices]

        tol = precision[actual_out.dtype]
        torch.testing.assert_close(actual_out, expected_out, atol=tol, rtol=tol)
        torch.testing.assert_close(actual_state, expected_state, atol=tol, rtol=tol)
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|