LJTSG commited on
Commit
075f0ef
·
verified ·
1 Parent(s): ed8ac31

Upload golden_dump.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. golden_dump.py +124 -0
golden_dump.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dump golden intermediate values from Falcon-Mamba 7B forward pass.
3
+ One token (BOS=1), layer 0 only. Dumps after each op so we can compare
4
+ against WebGPU shader outputs.
5
+ """
6
+ import torch, json, sys
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ print("Loading model...", flush=True)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ "tiiuae/falcon-mamba-7b-instruct",
12
+ dtype=torch.float32,
13
+ )
14
+ model.eval()
15
+
16
+ # Single token: BOS = 1
17
+ token_id = 1
18
+ input_ids = torch.tensor([[token_id]], dtype=torch.long)
19
+
20
+ # Get embedding
21
+ emb_weight = model.backbone.embeddings.weight
22
+ hidden = emb_weight[token_id].unsqueeze(0) # [1, 4096]
23
+ print(f"embedding[0:10]: {hidden[0, :10].tolist()}")
24
+
25
+ # Layer 0
26
+ layer = model.backbone.layers[0]
27
+ mixer = layer.mixer
28
+
29
+ # RMSNorm
30
+ norm_weight = layer.norm.weight
31
+ variance = hidden.pow(2).mean(-1, keepdim=True)
32
+ norm_out = hidden * torch.rsqrt(variance + 1e-5) * norm_weight
33
+ print(f"norm_out[0:10]: {norm_out[0, :10].tolist()}")
34
+
35
+ # in_proj
36
+ projected = torch.nn.functional.linear(norm_out, mixer.in_proj.weight)
37
+ print(f"projected[0:10]: {projected[0, :10].tolist()}")
38
+ print(f"projected shape: {projected.shape}")
39
+
40
+ # Split into hidden and gate
41
+ I = projected.shape[-1] // 2 # 8192
42
+ x = projected[..., :I]
43
+ gate = projected[..., I:]
44
+ print(f"x[0:10]: {x[0, :10].tolist()}")
45
+ print(f"gate[0:10]: {gate[0, :10].tolist()}")
46
+
47
+ # Conv1d (first step, state is zeros)
48
+ conv_weight = mixer.conv1d.weight # [8192, 1, 4]
49
+ conv_bias = mixer.conv1d.bias # [8192]
50
+ # With zero state, conv1d output = x * w[:, 0, 3] + bias (only the last kernel position)
51
+ # Actually for depthwise conv with zero-padded input, it's just the last weight * x + bias
52
+ conv_state = torch.zeros(1, I, 3) # [batch, channels, kernel-1]
53
+ window = torch.cat([conv_state.squeeze(0), x.squeeze(0).unsqueeze(-1)], dim=-1) # [8192, 4]
54
+ conv_out = (window * conv_weight.squeeze(1)).sum(-1) + conv_bias # [8192]
55
+ print(f"conv_out[0:10]: {conv_out[:10].tolist()}")
56
+
57
+ # SiLU on conv output
58
+ hidden_silu = conv_out * torch.sigmoid(conv_out)
59
+ print(f"after_silu[0:10]: {hidden_silu[:10].tolist()}")
60
+
61
+ # x_proj: hidden_silu -> dt_pre, B, C
62
+ x_proj_out = torch.nn.functional.linear(hidden_silu.unsqueeze(0), mixer.x_proj.weight)
63
+ dt_rank = mixer.dt_proj.weight.shape[1] # 256
64
+ ssm_state_size = (x_proj_out.shape[-1] - dt_rank) // 2 # 16
65
+ dt_pre = x_proj_out[..., :dt_rank]
66
+ B = x_proj_out[..., dt_rank:dt_rank + ssm_state_size]
67
+ C = x_proj_out[..., dt_rank + ssm_state_size:]
68
+ print(f"dt_pre[0:10]: {dt_pre[0, :10].tolist()}")
69
+ print(f"B[0:10]: {B[0, :10].tolist()}")
70
+ print(f"C[0:10]: {C[0, :10].tolist()}")
71
+
72
+ # dt_proj: dt_pre -> dt
73
+ dt = torch.nn.functional.linear(dt_pre, mixer.dt_proj.weight)
74
+ print(f"dt[0:10]: {dt[0, :10].tolist()}")
75
+
76
+ # SSU
77
+ dt_bias = mixer.dt_proj.bias
78
+ A_log = mixer.A_log # [8192, 16]
79
+ D_param = mixer.D # [8192]
80
+
81
+ # delta = softplus(dt + dt_bias)
82
+ delta = torch.nn.functional.softplus(dt + dt_bias)
83
+ print(f"delta[0:10]: {delta[0, :10].tolist()}")
84
+
85
+ # A = -exp(A_log)
86
+ A = -torch.exp(A_log.float())
87
+ print(f"A[0:10,0]: {A[:10, 0].tolist()}")
88
+
89
+ # delta_A = exp(delta * A) -- for each (h, s)
90
+ delta_expanded = delta.squeeze(0).unsqueeze(-1) # [8192, 1]
91
+ delta_A = torch.exp(delta_expanded * A) # [8192, 16]
92
+ print(f"delta_A[0:5,0:4]: {delta_A[:5, :4].tolist()}")
93
+
94
+ # state update: state = state * delta_A + delta_B * x
95
+ # state starts at zeros, so state = delta_B * x
96
+ delta_B = (delta_expanded * B.squeeze(0).unsqueeze(0).expand_as(A))
97
+ # Actually B is [1, 16], need to broadcast
98
+ B_sq = B.squeeze(0) # [16]
99
+ delta_B_x = delta_expanded * B_sq.unsqueeze(0) * hidden_silu.unsqueeze(-1) # [8192, 16]
100
+ state = delta_A * 0 + delta_B_x # state was zero
101
+ print(f"state[0:5,0:4]: {state[:5, :4].tolist()}")
102
+
103
+ # y = sum_s(state * C) + D * x
104
+ C_sq = C.squeeze(0) # [16]
105
+ y = (state * C_sq.unsqueeze(0)).sum(-1) + D_param * hidden_silu
106
+ print(f"y[0:10]: {y[:10].tolist()}")
107
+
108
+ # gate silu
109
+ gate_silu = gate.squeeze(0) * torch.sigmoid(gate.squeeze(0))
110
+ print(f"gate_silu[0:10]: {gate_silu[:10].tolist()}")
111
+
112
+ # y * gate
113
+ gated = y * gate_silu
114
+ print(f"gated[0:10]: {gated[:10].tolist()}")
115
+
116
+ # out_proj
117
+ out = torch.nn.functional.linear(gated.unsqueeze(0), mixer.out_proj.weight)
118
+ print(f"out_proj[0:10]: {out[0, :10].tolist()}")
119
+
120
+ # residual
121
+ result = hidden + out
122
+ print(f"after_residual[0:10]: {result[0, :10].tolist()}")
123
+
124
+ print("\nDone. Compare these values against WebGPU readbacks.")