# mamba-webgpu / golden_dump.py
# Author: LJTSG - uploaded with huggingface_hub (commit 075f0ef, verified).
"""
Dump golden intermediate values from Falcon-Mamba 7B forward pass.
One token (BOS=1), layer 0 only. Dumps after each op so we can compare
against WebGPU shader outputs.
"""
import torch, json, sys
from transformers import AutoModelForCausalLM, AutoTokenizer
# Inference only: disable autograd globally so none of the tensor ops below
# build a graph that keeps intermediates (and the parameters) referenced.
torch.set_grad_enabled(False)

print("Loading model...", flush=True)
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-mamba-7b-instruct",
    dtype=torch.float32,  # full precision: this dump is the numerical reference
)
model.eval()

# Single token: BOS = 1
token_id = 1
input_ids = torch.tensor([[token_id]], dtype=torch.long)

# Embedding lookup: row `token_id` of the embedding table, kept 2-D as [1, hidden].
emb_weight = model.backbone.embeddings.weight
hidden = emb_weight[token_id].unsqueeze(0)  # [1, 4096]
print(f"embedding[0:10]: {hidden[0, :10].tolist()}")
# Layer 0
layer = model.backbone.layers[0]
mixer = layer.mixer

# RMSNorm: x * rsqrt(mean(x^2) + eps) * weight, computed in float32.
# eps is read from the norm module so the dump follows the checkpoint config;
# 1e-5 is only a fallback for modules that do not expose it.
norm_weight = layer.norm.weight
norm_eps = getattr(layer.norm, "variance_epsilon", 1e-5)
hidden_f32 = hidden.float()
variance = hidden_f32.pow(2).mean(-1, keepdim=True)
norm_out = hidden_f32 * torch.rsqrt(variance + norm_eps) * norm_weight
print(f"norm_out[0:10]: {norm_out[0, :10].tolist()}")
# in_proj: a single matmul that produces both the SSM input stream and the gate.
# The bias is passed through too; it is None when in_proj has no bias, in which
# case this is identical to a weight-only matmul.
projected = torch.nn.functional.linear(norm_out, mixer.in_proj.weight, mixer.in_proj.bias)
print(f"projected[0:10]: {projected[0, :10].tolist()}")
print(f"projected shape: {projected.shape}")

# Split into hidden and gate: the first I channels feed conv/SSM, the last I gate it.
I = projected.shape[-1] // 2  # 8192
x = projected[..., :I]
gate = projected[..., I:]
print(f"x[0:10]: {x[0, :10].tolist()}")
print(f"gate[0:10]: {gate[0, :10].tolist()}")
# Depthwise causal conv1d for the first decode step. The rolling conv state
# (the previous kernel_size - 1 inputs) is all zeros, so the window is
# [0, ..., 0, x] and only the last kernel tap contributes:
#   conv_out = x * w[:, 0, -1] + bias
conv_weight = mixer.conv1d.weight  # [I, 1, kernel_size], e.g. [8192, 1, 4]
conv_bias = mixer.conv1d.bias  # [I], or None if the conv was built without bias
kernel_size = conv_weight.shape[-1]
conv_state = torch.zeros(1, I, kernel_size - 1, dtype=x.dtype)  # [batch, channels, kernel-1]
window = torch.cat(
    [conv_state.squeeze(0), x.squeeze(0).unsqueeze(-1)], dim=-1
)  # [I, kernel_size]
conv_out = (window * conv_weight.squeeze(1)).sum(-1)  # [I]
if conv_bias is not None:
    conv_out = conv_out + conv_bias
print(f"conv_out[0:10]: {conv_out[:10].tolist()}")

# SiLU on conv output
hidden_silu = conv_out * torch.sigmoid(conv_out)
print(f"after_silu[0:10]: {hidden_silu[:10].tolist()}")
def _unweighted_rms_norm(t, eps):
    """RMS-normalize the last dim of *t* with no learned scale (computed in float32)."""
    t32 = t.float()
    return (t32 * torch.rsqrt(t32.pow(2).mean(-1, keepdim=True) + eps)).to(t.dtype)


# x_proj: hidden_silu -> [dt_pre | B | C]
x_proj_out = torch.nn.functional.linear(
    hidden_silu.unsqueeze(0), mixer.x_proj.weight, mixer.x_proj.bias
)
dt_rank = mixer.dt_proj.weight.shape[1]  # 256
ssm_state_size = (x_proj_out.shape[-1] - dt_rank) // 2  # 16
dt_pre = x_proj_out[..., :dt_rank]
B = x_proj_out[..., dt_rank:dt_rank + ssm_state_size]
C = x_proj_out[..., dt_rank + ssm_state_size:]
print(f"dt_pre[0:10]: {dt_pre[0, :10].tolist()}")
print(f"B[0:10]: {B[0, :10].tolist()}")
print(f"C[0:10]: {C[0, :10].tolist()}")

# Falcon-Mamba (unlike vanilla Mamba) RMS-normalizes dt_pre, B and C, with no
# learned weight, before dt_proj / the SSM step. The transformers
# FalconMambaMixer keeps that epsilon in `rms_eps`; mixers without it skip this.
mixer_rms_eps = getattr(mixer, "rms_eps", None)
if mixer_rms_eps is not None:
    dt_pre = _unweighted_rms_norm(dt_pre, mixer_rms_eps)
    B = _unweighted_rms_norm(B, mixer_rms_eps)
    C = _unweighted_rms_norm(C, mixer_rms_eps)
    print(f"dt_pre_rms[0:10]: {dt_pre[0, :10].tolist()}")
    print(f"B_rms[0:10]: {B[0, :10].tolist()}")
    print(f"C_rms[0:10]: {C[0, :10].tolist()}")

# dt_proj: dt_pre -> dt (weight only; the dt_proj bias is added before softplus
# in the SSM step)
dt = torch.nn.functional.linear(dt_pre, mixer.dt_proj.weight)
print(f"dt[0:10]: {dt[0, :10].tolist()}")
# SSU: one selective-state-update step for a single token from a zero state.
dt_bias = mixer.dt_proj.bias
A_log = mixer.A_log  # [8192, 16]
D_param = mixer.D  # [8192]

# delta = softplus(dt + dt_bias)
delta = torch.nn.functional.softplus(dt + dt_bias)
print(f"delta[0:10]: {delta[0, :10].tolist()}")

# A = -exp(A_log)
A = -torch.exp(A_log.float())
print(f"A[0:10,0]: {A[:10, 0].tolist()}")

# delta_A = exp(delta * A) -- one decay factor per (channel h, state s)
delta_expanded = delta.squeeze(0).unsqueeze(-1)  # [I, 1]
delta_A = torch.exp(delta_expanded * A)  # [I, N]
print(f"delta_A[0:5,0:4]: {delta_A[:5, :4].tolist()}")

# state update: state = prev_state * delta_A + delta * B * x
# B is [1, N]: broadcast it across channels, and x across the N states.
B_sq = B.squeeze(0)  # [N]
delta_B_x = delta_expanded * B_sq.unsqueeze(0) * hidden_silu.unsqueeze(-1)  # [I, N]
prev_state = torch.zeros_like(delta_A)  # first token: recurrent state starts at zero
state = prev_state * delta_A + delta_B_x
print(f"state[0:5,0:4]: {state[:5, :4].tolist()}")

# y = sum_s(state * C) + D * x
C_sq = C.squeeze(0)  # [N]
y = (state * C_sq.unsqueeze(0)).sum(-1) + D_param * hidden_silu
print(f"y[0:10]: {y[:10].tolist()}")
# gate silu: SiLU(gate), applied elementwise to the SSM output below.
gate_vec = gate.squeeze(0)
gate_silu = gate_vec * torch.sigmoid(gate_vec)
print(f"gate_silu[0:10]: {gate_silu[:10].tolist()}")

# y * gate
gated = y * gate_silu
print(f"gated[0:10]: {gated[:10].tolist()}")

# out_proj (bias passed through; None when out_proj has no bias)
out = torch.nn.functional.linear(gated.unsqueeze(0), mixer.out_proj.weight, mixer.out_proj.bias)
print(f"out_proj[0:10]: {out[0, :10].tolist()}")

# residual
result = hidden + out
print(f"after_residual[0:10]: {result[0, :10].tolist()}")

# Cross-check the hand-rolled pipeline against the library's own layer forward,
# so a missing or incorrect op above shows up as a large difference instead of
# silently producing bad golden values. Diagnostic only: a failure here must not
# discard the dump already printed, so the error is reported and the run continues.
try:
    reference = layer(hidden.unsqueeze(0))
except Exception as err:
    print(f"reference check skipped: {type(err).__name__}: {err}")
else:
    if isinstance(reference, tuple):
        reference = reference[0]
    max_diff = (reference.reshape_as(result) - result).abs().max().item()
    print(f"max |manual - layer(hidden)|: {max_diff:.3e}")

print("\nDone. Compare these values against WebGPU readbacks.")