File size: 4,339 Bytes
075f0ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Dump golden intermediate values from Falcon-Mamba 7B forward pass.
One token (BOS=1), layer 0 only. Dumps after each op so we can compare
against WebGPU shader outputs.
"""
import torch, json, sys
from transformers import AutoModelForCausalLM, AutoTokenizer

# Only forward values are needed. The model weights are nn.Parameters with
# requires_grad=True, so without this every op below would record an autograd
# graph for nothing.
torch.set_grad_enabled(False)

print("Loading model...", flush=True)
# NOTE(review): older transformers releases spell this kwarg `torch_dtype`.
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-mamba-7b-instruct",
    dtype=torch.float32,
)
model.eval()

# Single token, id 1.
# NOTE(review): id 1 is treated as BOS here; confirm against the tokenizer's
# bos_token_id before relying on it.
token_id = 1
input_ids = torch.tensor([[token_id]], dtype=torch.long)

# Embedding lookup for the single token -> [1, hidden_size] (4096 for the 7B)
emb_weight = model.backbone.embeddings.weight
hidden = emb_weight[token_id].unsqueeze(0)
print(f"embedding[0:10]: {hidden[0, :10].tolist()}")

# Layer 0
layer = model.backbone.layers[0]
mixer = layer.mixer

# RMSNorm (with learned weight). Take eps from the module so the dump follows
# the checkpoint's config; 1e-5 is the value this script previously assumed.
norm_weight = layer.norm.weight
norm_eps = getattr(layer.norm, "variance_epsilon", 1e-5)
variance = hidden.pow(2).mean(-1, keepdim=True)
norm_out = hidden * torch.rsqrt(variance + norm_eps) * norm_weight
print(f"norm_out[0:10]: {norm_out[0, :10].tolist()}")

# in_proj. `.bias` is None for a bias-free Linear, in which case this is the
# same weight-only projection as before.
projected = torch.nn.functional.linear(norm_out, mixer.in_proj.weight, mixer.in_proj.bias)
print(f"projected[0:10]: {projected[0, :10].tolist()}")
print(f"projected shape: {projected.shape}")

# Split into hidden and gate halves
I = projected.shape[-1] // 2  # 8192 for the 7B
x = projected[..., :I]
gate = projected[..., I:]
print(f"x[0:10]: {x[0, :10].tolist()}")
print(f"gate[0:10]: {gate[0, :10].tolist()}")

# Depthwise conv1d on the first step. The conv state starts at zero, so the
# window is [0, ..., 0, x] and only the last tap contributes:
#   conv_out = x * w[:, 0, -1] + bias
conv_weight = mixer.conv1d.weight  # [I, 1, kernel]
conv_bias = mixer.conv1d.bias      # [I], or None if the conv has no bias
kernel_size = conv_weight.shape[-1]
conv_state = torch.zeros(1, I, kernel_size - 1, dtype=x.dtype)  # [batch, channels, kernel-1]
window = torch.cat([conv_state.squeeze(0), x.squeeze(0).unsqueeze(-1)], dim=-1)  # [I, kernel]
conv_out = (window * conv_weight.squeeze(1)).sum(-1)  # [I]
if conv_bias is not None:
    conv_out = conv_out + conv_bias
print(f"conv_out[0:10]: {conv_out[:10].tolist()}")

# SiLU on conv output
hidden_silu = conv_out * torch.sigmoid(conv_out)
print(f"after_silu[0:10]: {hidden_silu[:10].tolist()}")

# x_proj: hidden_silu -> dt_pre, B, C
x_proj_out = torch.nn.functional.linear(
    hidden_silu.unsqueeze(0), mixer.x_proj.weight, mixer.x_proj.bias
)
dt_rank = mixer.dt_proj.weight.shape[1]  # 256
ssm_state_size = (x_proj_out.shape[-1] - dt_rank) // 2  # 16
dt_pre = x_proj_out[..., :dt_rank]
B = x_proj_out[..., dt_rank:dt_rank + ssm_state_size]
C = x_proj_out[..., dt_rank + ssm_state_size:]
print(f"dt_pre[0:10]: {dt_pre[0, :10].tolist()}")
print(f"B[0:10]: {B[0, :10].tolist()}")
print(f"C[0:10]: {C[0, :10].tolist()}")


def _weightless_rms_norm(t, eps):
    """RMS-normalize `t` over its last dim, with no learned scale."""
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)


# Falcon-Mamba (unlike vanilla Mamba) RMS-normalizes dt, B and C, without a
# learned weight, right after x_proj. Skipping this makes every value from dt
# onward diverge from the real model. eps comes from the mixer
# (config.mixer_rms_eps); 1e-6 is the fallback if the attribute is absent.
mixer_rms_eps = getattr(mixer, "rms_eps", 1e-6)
dt_pre = _weightless_rms_norm(dt_pre, mixer_rms_eps)
B = _weightless_rms_norm(B, mixer_rms_eps)
C = _weightless_rms_norm(C, mixer_rms_eps)
print(f"dt_pre_norm[0:10]: {dt_pre[0, :10].tolist()}")
print(f"B_norm[0:10]: {B[0, :10].tolist()}")
print(f"C_norm[0:10]: {C[0, :10].tolist()}")

# dt_proj: dt_pre -> dt (weight only; the bias is added inside the softplus below
# so the pre-bias value can be dumped separately)
dt = torch.nn.functional.linear(dt_pre, mixer.dt_proj.weight)
print(f"dt[0:10]: {dt[0, :10].tolist()}")

# SSU
dt_bias = mixer.dt_proj.bias
A_log = mixer.A_log  # [8192, 16]
D_param = mixer.D    # [8192]

# delta = softplus(dt + dt_bias)
delta = torch.nn.functional.softplus(dt + dt_bias)
print(f"delta[0:10]: {delta[0, :10].tolist()}")

# A = -exp(A_log)
A = -torch.exp(A_log.float())
print(f"A[0:10,0]: {A[:10, 0].tolist()}")

# delta_A = exp(delta * A), per (channel, state) pair
delta_expanded = delta.squeeze(0).unsqueeze(-1)  # [8192, 1]
delta_A = torch.exp(delta_expanded * A)  # [8192, 16]
print(f"delta_A[0:5,0:4]: {delta_A[:5, :4].tolist()}")

# State update: state = state * delta_A + (delta * B) * x.
# The state starts at zero, so only the input term survives.
B_sq = B.squeeze(0)  # [16]
delta_B_x = delta_expanded * B_sq.unsqueeze(0) * hidden_silu.unsqueeze(-1)  # [8192, 16]
state = delta_A * 0 + delta_B_x  # state was zero
print(f"state[0:5,0:4]: {state[:5, :4].tolist()}")

# y = sum_s(state * C) + D * x
C_sq = C.squeeze(0)  # [16]
y = (state * C_sq.unsqueeze(0)).sum(-1) + D_param * hidden_silu
print(f"y[0:10]: {y[:10].tolist()}")

# gate silu
gate_silu = gate.squeeze(0) * torch.sigmoid(gate.squeeze(0))
print(f"gate_silu[0:10]: {gate_silu[:10].tolist()}")

# y * gate
gated = y * gate_silu
print(f"gated[0:10]: {gated[:10].tolist()}")

# out_proj. `.bias` is None for a bias-free Linear, in which case this is the
# same weight-only projection as before.
out = torch.nn.functional.linear(gated.unsqueeze(0), mixer.out_proj.weight, mixer.out_proj.bias)
print(f"out_proj[0:10]: {out[0, :10].tolist()}")

# residual
result = hidden + out
print(f"after_residual[0:10]: {result[0, :10].tolist()}")

# Cross-check the hand-rolled values against transformers' own implementation
# of layer 0 on the same embedding. A large diff means the manual pipeline above
# has drifted from the reference and the golden values cannot be trusted.
# Best-effort only: a failure here is reported, not fatal.
try:
    with torch.no_grad():
        reference = layer(hidden.unsqueeze(0))
except Exception as err:
    print(f"reference check skipped: {err}")
else:
    if isinstance(reference, tuple):
        reference = reference[0]
    max_abs_diff = (result - reference.reshape(result.shape)).abs().max().item()
    print(f"max |manual - transformers layer 0|: {max_abs_diff:.3e}")

print("\nDone. Compare these values against WebGPU readbacks.")