Upload golden_dump.py with huggingface_hub
Browse files- golden_dump.py +124 -0
golden_dump.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dump golden intermediate values from Falcon-Mamba 7B forward pass.
|
| 3 |
+
One token (BOS=1), layer 0 only. Dumps after each op so we can compare
|
| 4 |
+
against WebGPU shader outputs.
|
| 5 |
+
"""
|
| 6 |
+
import torch, json, sys
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
print("Loading model...", flush=True)
|
| 10 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 11 |
+
"tiiuae/falcon-mamba-7b-instruct",
|
| 12 |
+
dtype=torch.float32,
|
| 13 |
+
)
|
| 14 |
+
model.eval()
# Inference only: switch autograd off so the hand-written ops below operate on
# plain values instead of recording a graph through the model parameters.
torch.set_grad_enabled(False)
|
| 15 |
+
|
| 16 |
+
# Single token: BOS = 1
|
| 17 |
+
token_id = 1
|
| 18 |
+
input_ids = torch.tensor([[token_id]], dtype=torch.long)
|
| 19 |
+
|
| 20 |
+
# Get embedding
|
| 21 |
+
emb_weight = model.backbone.embeddings.weight
|
| 22 |
+
hidden = emb_weight[token_id].unsqueeze(0) # [1, 4096]
|
| 23 |
+
print(f"embedding[0:10]: {hidden[0, :10].tolist()}")
|
| 24 |
+
|
| 25 |
+
# Layer 0
|
| 26 |
+
layer = model.backbone.layers[0]
|
| 27 |
+
mixer = layer.mixer
|
| 28 |
+
|
| 29 |
+
# RMSNorm
|
| 30 |
+
norm_weight = layer.norm.weight
|
| 31 |
+
variance = hidden.pow(2).mean(-1, keepdim=True)
|
| 32 |
+
# Prefer the epsilon stored on the norm module when it exposes one; fall back
# to the previously hard-coded 1e-5 otherwise.
norm_eps = getattr(layer.norm, "variance_epsilon", 1e-5)
norm_out = hidden * torch.rsqrt(variance + norm_eps) * norm_weight
|
| 33 |
+
print(f"norm_out[0:10]: {norm_out[0, :10].tolist()}")
|
| 34 |
+
|
| 35 |
+
# in_proj
|
| 36 |
+
projected = torch.nn.functional.linear(norm_out, mixer.in_proj.weight)
|
| 37 |
+
print(f"projected[0:10]: {projected[0, :10].tolist()}")
|
| 38 |
+
print(f"projected shape: {projected.shape}")
|
| 39 |
+
|
| 40 |
+
# Split into hidden and gate
|
| 41 |
+
I = projected.shape[-1] // 2 # 8192
|
| 42 |
+
x = projected[..., :I]
|
| 43 |
+
gate = projected[..., I:]
|
| 44 |
+
print(f"x[0:10]: {x[0, :10].tolist()}")
|
| 45 |
+
print(f"gate[0:10]: {gate[0, :10].tolist()}")
|
| 46 |
+
|
| 47 |
+
# Conv1d (first step, state is zeros)
|
| 48 |
+
conv_weight = mixer.conv1d.weight # [8192, 1, 4]
|
| 49 |
+
conv_bias = mixer.conv1d.bias # [8192]
|
| 50 |
+
# With zero state, conv1d output = x * w[:, 0, 3] + bias (only the last kernel position)
|
| 51 |
+
# (depthwise conv, zero history). The window is still built explicitly below so a non-zero conv_state can be substituted.
|
| 52 |
+
conv_state = torch.zeros(1, I, 3) # [batch, channels, kernel-1]
|
| 53 |
+
window = torch.cat([conv_state.squeeze(0), x.squeeze(0).unsqueeze(-1)], dim=-1) # [8192, 4]
|
| 54 |
+
conv_out = (window * conv_weight.squeeze(1)).sum(-1) + conv_bias # [8192]
|
| 55 |
+
print(f"conv_out[0:10]: {conv_out[:10].tolist()}")
|
| 56 |
+
|
| 57 |
+
# SiLU on conv output
|
| 58 |
+
hidden_silu = conv_out * torch.sigmoid(conv_out)
|
| 59 |
+
print(f"after_silu[0:10]: {hidden_silu[:10].tolist()}")
|
| 60 |
+
|
| 61 |
+
# x_proj: hidden_silu -> dt_pre, B, C
|
| 62 |
+
x_proj_out = torch.nn.functional.linear(hidden_silu.unsqueeze(0), mixer.x_proj.weight)
|
| 63 |
+
dt_rank = mixer.dt_proj.weight.shape[1] # 256
|
| 64 |
+
ssm_state_size = (x_proj_out.shape[-1] - dt_rank) // 2 # 16
|
| 65 |
+
dt_pre = x_proj_out[..., :dt_rank]
|
| 66 |
+
B = x_proj_out[..., dt_rank:dt_rank + ssm_state_size]
|
| 67 |
+
C = x_proj_out[..., dt_rank + ssm_state_size:]
# Falcon-Mamba (unlike vanilla Mamba) applies a weight-free RMS normalization
# to dt, B and C right after the x_proj split, before dt_proj and the state
# update. Without it, every value dumped from here on diverges from the real
# forward pass. Epsilon comes from the mixer when available (default 1e-6).
rms_eps = getattr(mixer, "rms_eps", 1e-6)
dt_pre = dt_pre * torch.rsqrt(dt_pre.pow(2).mean(-1, keepdim=True) + rms_eps)
B = B * torch.rsqrt(B.pow(2).mean(-1, keepdim=True) + rms_eps)
C = C * torch.rsqrt(C.pow(2).mean(-1, keepdim=True) + rms_eps)
|
| 68 |
+
print(f"dt_pre[0:10]: {dt_pre[0, :10].tolist()}")
|
| 69 |
+
print(f"B[0:10]: {B[0, :10].tolist()}")
|
| 70 |
+
print(f"C[0:10]: {C[0, :10].tolist()}")
|
| 71 |
+
|
| 72 |
+
# dt_proj: dt_pre -> dt
|
| 73 |
+
dt = torch.nn.functional.linear(dt_pre, mixer.dt_proj.weight)
|
| 74 |
+
print(f"dt[0:10]: {dt[0, :10].tolist()}")
|
| 75 |
+
|
| 76 |
+
# SSU
|
| 77 |
+
dt_bias = mixer.dt_proj.bias
|
| 78 |
+
A_log = mixer.A_log # [8192, 16]
|
| 79 |
+
D_param = mixer.D # [8192]
|
| 80 |
+
|
| 81 |
+
# delta = softplus(dt + dt_bias)
|
| 82 |
+
delta = torch.nn.functional.softplus(dt + dt_bias)
|
| 83 |
+
print(f"delta[0:10]: {delta[0, :10].tolist()}")
|
| 84 |
+
|
| 85 |
+
# A = -exp(A_log)
|
| 86 |
+
A = -torch.exp(A_log.float())
|
| 87 |
+
print(f"A[0:10,0]: {A[:10, 0].tolist()}")
|
| 88 |
+
|
| 89 |
+
# delta_A = exp(delta * A) -- for each (h, s)
|
| 90 |
+
delta_expanded = delta.squeeze(0).unsqueeze(-1) # [8192, 1]
|
| 91 |
+
delta_A = torch.exp(delta_expanded * A) # [8192, 16]
|
| 92 |
+
print(f"delta_A[0:5,0:4]: {delta_A[:5, :4].tolist()}")
|
| 93 |
+
|
| 94 |
+
# state update: state = state * delta_A + delta_B * x
|
| 95 |
+
# state starts at zeros, so state = delta_B * x
|
| 96 |
+
# delta [8192, 1] * B [16] broadcasts to [8192, 16]; the product is formed
# directly as part of delta_B_x below, so no separate delta_B tensor is kept.
|
| 98 |
+
B_sq = B.squeeze(0) # [16]
|
| 99 |
+
delta_B_x = delta_expanded * B_sq.unsqueeze(0) * hidden_silu.unsqueeze(-1) # [8192, 16]
|
| 100 |
+
state = delta_A * 0 + delta_B_x # state was zero
|
| 101 |
+
print(f"state[0:5,0:4]: {state[:5, :4].tolist()}")
|
| 102 |
+
|
| 103 |
+
# y = sum_s(state * C) + D * x
|
| 104 |
+
C_sq = C.squeeze(0) # [16]
|
| 105 |
+
y = (state * C_sq.unsqueeze(0)).sum(-1) + D_param * hidden_silu
|
| 106 |
+
print(f"y[0:10]: {y[:10].tolist()}")
|
| 107 |
+
|
| 108 |
+
# gate silu
|
| 109 |
+
gate_silu = gate.squeeze(0) * torch.sigmoid(gate.squeeze(0))
|
| 110 |
+
print(f"gate_silu[0:10]: {gate_silu[:10].tolist()}")
|
| 111 |
+
|
| 112 |
+
# y * gate
|
| 113 |
+
gated = y * gate_silu
|
| 114 |
+
print(f"gated[0:10]: {gated[:10].tolist()}")
|
| 115 |
+
|
| 116 |
+
# out_proj
|
| 117 |
+
out = torch.nn.functional.linear(gated.unsqueeze(0), mixer.out_proj.weight)
|
| 118 |
+
print(f"out_proj[0:10]: {out[0, :10].tolist()}")
|
| 119 |
+
|
| 120 |
+
# residual
|
| 121 |
+
result = hidden + out
|
| 122 |
+
print(f"after_residual[0:10]: {result[0, :10].tolist()}")
|
| 123 |
+
|
| 124 |
+
print("\nDone. Compare these values against WebGPU readbacks.")
|