# Hanrui / sglang / test/srt/cpu/test_mamba.py
# Uploaded by Lekr0 ("Add files using upload-large-folder tool"), commit a402b9b (verified).
import unittest
import torch
import torch.nn.functional as F
from torch.nn.functional import softplus
from utils import precision
from sglang.test.test_utils import CustomTestCase
# Seed torch's global RNG once at import time so that every torch.rand /
# torch.randint call in the tests below draws the same inputs on every run.
_RNG_SEED = 1234
torch.manual_seed(_RNG_SEED)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """Scale ``x`` to (approximately) unit L2 norm along ``dim``.

    Intended to align with the l2norm implementation in the FLA library:
    ``x * rsqrt(sum(x * x, dim) + eps)``, where ``eps`` keeps an all-zero
    slice finite (it maps to zeros instead of NaN).
    """
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * (squared_sum + eps).rsqrt()
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Chunked pure-PyTorch reference for the gated delta rule.

    Inputs use a ``[batch, seq, heads, dim]`` layout. Internally they are
    transposed to ``[batch, heads, seq, dim]`` and cast to float32. The
    sequence axis is zero-padded up to a multiple of ``chunk_size`` and then
    processed one chunk at a time, with the recurrent state carried between
    chunks.

    Args:
        query, key: ``[B, T, H, K]`` tensors.
        value: ``[B, T, H, V]`` tensor.
        g: ``[B, T, H]`` per-step log decay (the state is scaled by ``exp(g)``).
        beta: ``[B, T, H]`` per-step write strength.
        chunk_size: number of timesteps per chunk.
        initial_state: optional recurrent state. It is combined with
            ``[B, H, ...]`` tensors by broadcasting. Defaults to zeros of shape
            ``[B, H, K, V]`` when ``None``.
        output_final_state: when False, the returned state is ``None``.
        use_qk_l2norm_in_kernel: L2-normalize ``query`` and ``key`` over their
            last dim before anything else.

    Returns:
        ``(out, final_state)``. ``out`` is ``[B, T, H, V]`` in ``query``'s
        original dtype. ``final_state`` is the float32 recurrent state, or
        ``None``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # [B, T, H, D] -> [B, H, T, D] (and [B, T, H] -> [B, H, T]), in float32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]

    # After the transpose the layout is [B, H, T, K]: axis 1 is heads and
    # axis 2 is time.
    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    # Zero-pad the time axis up to a multiple of chunk_size. Padded steps have
    # key = value = beta = 0 and g = 0, so they add nothing to the state, and
    # their outputs are sliced off at the end.
    pad_size = (chunk_size - seq_len % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    num_chunks = (seq_len + pad_size) // chunk_size

    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)

    # [B, H, T_pad, D] -> [B, H, num_chunks, chunk_size, D]
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # True on and above the diagonal.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
        diagonal=0,
    )

    # Cumulative log decay within each chunk, and the pairwise factors
    # exp(g_i - g_j) for i >= j (zero above the diagonal).
    g = g.cumsum(dim=-1)
    decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row recurrence over the strictly lower-triangular part. Adding
    # the identity afterwards gives the per-chunk transform applied to
    # v_beta and k_beta below.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly above the diagonal: causal masking inside a chunk.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device),
        diagonal=1,
    )

    # Walk the chunks in order, carrying the recurrent state between them.
    for i in range(num_chunks):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        # Remove the part of each value already predicted from the incoming state.
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        # Contribution of the incoming state to this chunk's outputs.
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        # Decay the state to the end of the chunk and add this chunk's writes.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
                -1, -2
            )
            @ v_new
        )

    if not output_final_state:
        last_recurrent_state = None
    # [B, H, num_chunks, chunk_size, V] -> [B, H, T_pad, V]; drop the padded
    # steps; back to [B, T, H, V].
    core_attn_out = core_attn_out.reshape(
        core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]
    )
    core_attn_out = core_attn_out[:, :, :seq_len]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
def chunk_gated_delta_rule_update(
    query,  # [B, T, HK, K]
    key,  # [B, T, HK, K]
    value,  # [B, T, HV, V]
    g,  # [B, T, HV]
    beta,  # [B, T, HV]
    cu_seqlens,  # [N+1]
    initial_state,  # [N, HV, K, V]
    use_qk_l2norm_in_kernel,
):
    """Varlen reference: run ``torch_chunk_gated_delta_rule`` once per packed sequence.

    The N sequences are packed along dim 1. Sequence ``n`` ends at
    ``cu_seqlens[n + 1]`` and starts where the previous one ended (the first
    starts at 0). When there are more value heads than query/key heads,
    ``query`` and ``key`` are repeat-interleaved along the head dim to match.

    Returns:
        ``(output, final_state)``. ``output`` has ``value``'s shape and dtype
        and holds each sequence's outputs at its own token positions.
        ``final_state[n]`` is sequence ``n``'s final recurrent state.
    """
    group_size = value.shape[2] // query.shape[2]
    if group_size > 1:
        # Give every value head its own copy of the matching query/key head.
        query, key = (x.repeat_interleave(group_size, dim=2) for x in (query, key))

    output = torch.empty_like(value)
    final_state = torch.empty_like(initial_state)

    num_seqs = initial_state.shape[0]
    seq_ends = [cu_seqlens[j + 1] for j in range(num_seqs)]
    seq_starts = [0] + seq_ends[:-1]
    for seq_idx, (begin, end) in enumerate(zip(seq_starts, seq_ends)):
        span = slice(begin, end)
        seq_out, seq_state = torch_chunk_gated_delta_rule(
            query=query[:, span],
            key=key[:, span],
            value=value[:, span],
            g=g[:, span],
            beta=beta[:, span],
            initial_state=initial_state[seq_idx],
            output_final_state=True,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        output[:, span] = seq_out
        final_state[seq_idx] = seq_state
    return output, final_state
def torch_recurrent_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    initial_state,
    output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """Step-by-step pure-PyTorch reference for the gated delta rule.

    ``query``/``key`` are ``[B, T, H, K]``, ``value`` is ``[B, T, H, V]``, and
    ``g``/``beta`` are ``[B, T, H]``. Work is done in float32 with a
    ``[B, H, T, ...]`` layout. At every step the state is decayed by
    ``exp(g_t)``. The key's current readout is then corrected toward
    ``v_t`` with strength ``beta_t``, and the output is read out with the
    scaled query.

    Returns:
        ``(out, final_state)``. ``out`` is ``[B, T, H, V]`` in ``query``'s
        original dtype. ``final_state`` is the float32 state, or ``None``
        when ``output_final_state`` is falsy.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query, key = l2norm(query, dim=-1, eps=1e-6), l2norm(key, dim=-1, eps=1e-6)

    def _heads_first(t):
        # Swap the time and head axes and upcast to float32.
        return t.transpose(1, 2).contiguous().to(torch.float32)

    q, k, v = _heads_first(query), _heads_first(key), _heads_first(value)
    beta, g = _heads_first(beta), _heads_first(g)

    n_batch, n_heads, n_steps, dk = k.shape
    dv = v.shape[-1]
    q = q * (1 / (q.shape[-1] ** 0.5))

    outputs = torch.zeros(n_batch, n_heads, n_steps, dv).to(v)
    if initial_state is None:
        state = torch.zeros(n_batch, n_heads, dk, dv).to(v)
    else:
        state = initial_state.to(v)

    for t in range(n_steps):
        decay = g[:, :, t].exp()[..., None, None]
        write = beta[:, :, t][..., None]
        k_t = k[:, :, t]
        state = state * decay
        recalled = (state * k_t[..., None]).sum(dim=-2)
        delta = (v[:, :, t] - recalled) * write
        state = state + k_t[..., None] * delta[..., None, :]
        outputs[:, :, t] = (state * q[:, :, t][..., None]).sum(dim=-2)

    final = state if output_final_state else None
    return outputs.transpose(1, 2).contiguous().to(out_dtype), final
def sigmoid_gating_delta_rule_update(
    query,
    key,
    value,
    A_log,
    a,
    dt_bias,
    b,
    initial_state,
    output_final_state,
    use_qk_l2norm_in_kernel=False,
):
    """Reference for the fused sigmoid-gating + delta-rule update.

    Computes ``g = -exp(A_log) * softplus(a + dt_bias)`` (``A_log`` and ``a``
    upcast to float32) and ``beta = sigmoid(b)``. It then adds a leading dim
    of size 1 to both and runs ``torch_recurrent_gated_delta_rule`` with them.
    """
    decay_rate = A_log.float().exp()
    log_decay = -decay_rate * softplus(a.float() + dt_bias)
    write_strength = torch.sigmoid(b)
    return torch_recurrent_gated_delta_rule(
        query,
        key,
        value,
        log_decay[None],
        write_strength[None],
        initial_state,
        output_final_state,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )
def torch_gdn_gating(A_log, a, b, dt_bias):
    """Reference gating: return ``(g, beta)``, each with a new leading dim of size 1.

    ``g = -exp(A_log) * softplus(a + dt_bias)`` is computed with ``A_log`` and
    ``a`` upcast to float32. ``beta = sigmoid(b)`` keeps ``b``'s dtype.
    """
    neg_rate = -torch.exp(A_log.float())
    gate = softplus(a.float() + dt_bias)[None]
    return neg_rate * gate, torch.sigmoid(b)[None]
class TestMambaAttention(CustomTestCase):
    """Compare ``torch.ops.sgl_kernel`` CPU ops against this file's torch references.

    Tolerances come from ``precision`` (imported from ``utils``), keyed by
    dtype. Inputs are drawn from torch's global RNG, which is seeded once at
    module import, so the order of the random draws below matters.
    """

    def test_chunk_gated_delta_rule(self):
        """``chunk_gated_delta_rule_cpu`` on packed varlen sequences vs. the per-sequence reference."""
        # B: batch, L: exclusive upper bound on each random sequence length,
        # HK/HV: query-key / value head counts, EK/EV: key / value head dims,
        # N: number of packed sequences.
        B, L, HK, HV, EK, EV, N = 1, 100, 3, 6, 64, 64, 4
        # Random lengths in [1, L). Zeroing entry 0 makes the cumsum below an
        # (N + 1)-long offsets vector starting at 0.
        seqlens = torch.randint(1, L, (N + 1,))
        seqlens[0] = 0
        cu_seqlens_ = torch.cumsum(seqlens, dim=0).to(torch.int32)
        T = cu_seqlens_[-1].item()
        # Small-magnitude random inputs with all N sequences packed along dim 1.
        query_ = torch.rand((B, T, HK, EK), dtype=torch.bfloat16) * 0.05
        key_ = torch.rand((B, T, HK, EK), dtype=torch.bfloat16) * 0.05
        value_ = torch.rand((B, T, HV, EV), dtype=torch.bfloat16) * 0.05
        g_ = torch.rand((B, T, HV), dtype=torch.float32) * 0.05
        beta_ = torch.rand((B, T, HV), dtype=torch.bfloat16) * 0.05
        initial_state_ = torch.rand((N, HV, EK, EV), dtype=torch.float32) * 0.05
        for use_qk_l2norm_in_kernel in [True, False]:
            core_attn_out_ref, last_recurrent_state_ref = chunk_gated_delta_rule_update(
                query=query_,
                key=key_,
                value=value_,
                g=g_,
                beta=beta_,
                cu_seqlens=cu_seqlens_,
                initial_state=initial_state_,
                use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            )
            # The kernel gets clones of every input. This presumably guards
            # the shared inputs against in-place modification by the kernel.
            query = query_.clone()
            key = key_.clone()
            value = value_.clone()
            g = g_.clone()
            beta = beta_.clone()
            cu_seqlens = cu_seqlens_.clone()
            initial_state = initial_state_.clone()
            core_attn_out, last_recurrent_state = (
                torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu(
                    query=query,
                    key=key,
                    value=value,
                    g=g,
                    beta=beta,
                    initial_state=initial_state,
                    output_final_state=True,
                    cu_seqlens=cu_seqlens,
                    head_first=False,
                    use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
                )
            )
            # The tolerance is looked up from the kernel output's dtype and is
            # also applied to the recurrent-state comparison.
            atol = rtol = precision[core_attn_out.dtype]
            torch.testing.assert_close(
                core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol
            )
            torch.testing.assert_close(
                last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol
            )

    def test_fused_gdn_gating(self):
        """``fused_gdn_gating_cpu`` vs. ``torch_gdn_gating`` for a couple of head counts."""
        dims = [6, 32]
        for dim in dims:
            A_log = torch.rand(dim)
            a = torch.rand(1024, dim, dtype=torch.bfloat16)
            b = torch.rand(1024, dim, dtype=torch.bfloat16)
            dt_bias = torch.rand(dim, dtype=torch.bfloat16)
            g, beta = torch_gdn_gating(A_log, a, b, dt_bias)
            g_sgl, beta_sgl = torch.ops.sgl_kernel.fused_gdn_gating_cpu(
                A_log, a, b, dt_bias
            )
            # g and beta have different dtypes (float32 vs b's bfloat16), so
            # each gets its own tolerance.
            atol = rtol = precision[g.dtype]
            atol2 = rtol2 = precision[beta.dtype]
            torch.testing.assert_close(g, g_sgl, atol=atol, rtol=rtol)
            torch.testing.assert_close(beta, beta_sgl, atol=atol2, rtol=rtol2)

    def test_fused_sigmoid_gating_delta_rule_update(self):
        """Single-token ``fused_sigmoid_gating_delta_rule_update_cpu`` vs. the torch reference."""
        batch_size = 1
        num_value_heads = 32
        head_k_dim = 128
        head_v_dim = 128
        num_heads = 16
        seq_len = 1
        attn_tp_size = 1
        key_dim = head_k_dim * num_heads
        value_dim = head_v_dim * num_value_heads
        mixed_qkv_dim = (key_dim * 2 + value_dim) // attn_tp_size
        # One fused q|k|v projection row per token, split into its three parts.
        mixed_qkv = torch.rand(
            seq_len * batch_size, mixed_qkv_dim, dtype=torch.bfloat16
        )
        query, key, value = torch.split(
            mixed_qkv,
            [
                key_dim // attn_tp_size,
                key_dim // attn_tp_size,
                value_dim // attn_tp_size,
            ],
            dim=-1,
        )
        query = query.view(1, seq_len, num_heads, head_k_dim)
        key = key.view(1, seq_len, num_heads, head_k_dim)
        value = value.view(1, seq_len, num_value_heads, head_v_dim)
        A_log = torch.rand(num_value_heads, dtype=torch.float32)
        a = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16)
        b = torch.rand(batch_size, num_value_heads, dtype=torch.bfloat16)
        dt_bias = torch.rand(num_value_heads, dtype=torch.bfloat16)
        # A pool of 513 recurrent-state slots. cache_indices picks the slot(s)
        # this batch reads from and writes to.
        ssm_states = torch.rand(
            513, num_value_heads, head_k_dim, head_v_dim, dtype=torch.float32
        )
        cache_indices = torch.randint(0, 513, (batch_size,), dtype=torch.int32)
        # Hard-coded for batch_size == 1 and seq_len == 1 (one token).
        query_start_loc = torch.tensor([0, 1], dtype=torch.int32)
        use_qk_l2norm_in_kernel = True
        # The reference needs one q/k head per value head, so it repeats them.
        # The kernel below is given the unexpanded query/key.
        query_ref = query.clone()
        key_ref = key.clone()
        if num_value_heads // num_heads > 1:
            query_ref = query_ref.repeat_interleave(num_value_heads // num_heads, dim=2)
            key_ref = key_ref.repeat_interleave(num_value_heads // num_heads, dim=2)
        # The reference must run before the kernel. Afterwards the test reads
        # the final state back out of ssm_states, i.e. it expects the kernel
        # to write its final state into the pool in place.
        core_attn_out_ref, last_recurrent_state_ref = sigmoid_gating_delta_rule_update(
            query_ref.transpose(0, 1),
            key_ref.transpose(0, 1),
            value.transpose(0, 1),
            A_log,
            a,
            dt_bias,
            b,
            initial_state=ssm_states[cache_indices],
            output_final_state=True,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        # NOTE(review): softplus_beta / softplus_threshold presumably mirror
        # the defaults of the torch softplus used by the reference — confirm.
        core_attn_out = torch.ops.sgl_kernel.fused_sigmoid_gating_delta_rule_update_cpu(
            A_log=A_log,
            dt_bias=dt_bias,
            q=query,
            k=key,
            v=value,
            a=a,
            b=b,
            initial_state_source=ssm_states,
            initial_state_indices=cache_indices,
            cu_seqlens=query_start_loc,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            softplus_beta=1.0,
            softplus_threshold=20.0,
        )
        last_recurrent_state = ssm_states[cache_indices]
        # The tolerance is looked up from the kernel output's dtype and is
        # also applied to the state comparison.
        atol = rtol = precision[core_attn_out.dtype]
        torch.testing.assert_close(
            core_attn_out, core_attn_out_ref, atol=atol, rtol=rtol
        )
        torch.testing.assert_close(
            last_recurrent_state, last_recurrent_state_ref, atol=atol, rtol=rtol
        )
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()