"""
Dump golden intermediate values from the Falcon-Mamba 7B forward pass.
One token (BOS=1), layer 0 only. Values are printed after each op so they
can be compared against WebGPU shader outputs.
"""
| import torch, json, sys |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Load the checkpoint with float32 weights; every value printed further down
# is computed from these parameters.
checkpoint = "tiiuae/falcon-mamba-7b-instruct"
print("Loading model...", flush=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float32)
# Put the module in evaluation mode before reading any values from it.
model.eval()
|
|
# The dump covers exactly one input token.
token_id = 1
# Batched (1, 1) form of the same token id.
input_ids = torch.tensor([[token_id]], dtype=torch.long)

# Embedding lookup done by hand: take the token's row of the embedding table
# and add a leading dimension of size 1 so `hidden` is 2-D.
emb_weight = model.backbone.embeddings.weight
hidden = emb_weight[token_id][None, :]
print(f"embedding[0:10]: {hidden[0, :10].tolist()}")
|
|
# All remaining steps operate on the first backbone layer and its mixer.
layer = model.backbone.layers[0]
mixer = layer.mixer

# RMS normalization written out explicitly: multiply by the reciprocal square
# root of the mean square over the last dimension (eps = 1e-5), then by the
# norm layer's weight.
norm_weight = layer.norm.weight
variance = torch.mean(torch.pow(hidden, 2), dim=-1, keepdim=True)
inv_rms = torch.rsqrt(variance + 1e-5)
norm_out = hidden * inv_rms * norm_weight
print(f"norm_out[0:10]: {norm_out[0, :10].tolist()}")
|
|
# Input projection: only the weight matrix is applied, no bias term.
projected = torch.nn.functional.linear(input=norm_out, weight=mixer.in_proj.weight)
print(f"projected[0:10]: {projected[0, :10].tolist()}")
print(f"projected shape: {projected.shape}")

# Split the projection's last dimension in half. The first half (`x`) goes
# through the convolution/SSM path; the second half (`gate`) is applied to
# the SSM output later in the script.
I = projected.size(-1) // 2
x, gate = projected[..., :I], projected[..., I:]
print(f"x[0:10]: {x[0, :10].tolist()}")
print(f"gate[0:10]: {gate[0, :10].tolist()}")
|
|
# Single-step causal depthwise convolution, computed without calling conv1d.
conv_weight = mixer.conv1d.weight
conv_bias = mixer.conv1d.bias

# This is the first token, so the rolling history of earlier inputs is all
# zeros. Its width is derived from the kernel (kernel_size - 1) instead of
# being hard-coded, and it is allocated with x's dtype/device so the
# concatenation below never mixes tensor types.
kernel_size = conv_weight.shape[-1]
conv_state = x.new_zeros(1, I, kernel_size - 1)

# Append the current input as the newest tap: one row per channel, one column
# per kernel tap. The weight is squeezed on dim 1 so it lines up with this
# layout for the elementwise product.
window = torch.cat([conv_state.squeeze(0), x.squeeze(0).unsqueeze(-1)], dim=-1)
conv_out = (window * conv_weight.squeeze(1)).sum(-1)
if conv_bias is not None:  # the conv layer may have been built without a bias
    conv_out = conv_out + conv_bias
print(f"conv_out[0:10]: {conv_out[:10].tolist()}")

# SiLU activation written as x * sigmoid(x).
hidden_silu = conv_out * torch.sigmoid(conv_out)
print(f"after_silu[0:10]: {hidden_silu[:10].tolist()}")
|
|
def _rms_normalize(values, eps):
    """Return `values` scaled to unit RMS over the last dimension (no learned weight)."""
    mean_square = values.pow(2).mean(-1, keepdim=True)
    return values * torch.rsqrt(mean_square + eps)


# x_proj produces the input-dependent SSM parameters in one matmul; the result
# is sliced into the time-step features (dt_pre) and the B and C vectors.
x_proj_out = torch.nn.functional.linear(hidden_silu.unsqueeze(0), mixer.x_proj.weight)
dt_rank = mixer.dt_proj.weight.shape[1]  # dt_proj's input width
ssm_state_size = (x_proj_out.shape[-1] - dt_rank) // 2
dt_pre = x_proj_out[..., :dt_rank]
B = x_proj_out[..., dt_rank:dt_rank + ssm_state_size]
C = x_proj_out[..., dt_rank + ssm_state_size:]

# Falcon-Mamba differs from plain Mamba here: its mixer applies a weight-free
# RMS normalization to dt, B and C before they are used. Skipping it makes
# every value from this point on diverge from the real model. The epsilon is
# read from the mixer, falling back to the config default (1e-6).
mixer_rms_eps = getattr(mixer, "rms_eps", 1e-6)
dt_pre = _rms_normalize(dt_pre, mixer_rms_eps)
B = _rms_normalize(B, mixer_rms_eps)
C = _rms_normalize(C, mixer_rms_eps)
# The three values printed below are the post-normalization tensors.
print(f"dt_pre[0:10]: {dt_pre[0, :10].tolist()}")
print(f"B[0:10]: {B[0, :10].tolist()}")
print(f"C[0:10]: {C[0, :10].tolist()}")

# Time-step projection, weight only; the bias is added later, just before the
# softplus.
dt = torch.nn.functional.linear(dt_pre, mixer.dt_proj.weight)
print(f"dt[0:10]: {dt[0, :10].tolist()}")
|
|
# Per-layer SSM parameters read straight off the mixer.
dt_bias = mixer.dt_proj.bias
A_log = mixer.A_log
D_param = mixer.D

# Add the dt_proj bias that the weight-only projection left out, then apply
# softplus to obtain the step size.
dt_with_bias = torch.add(dt, dt_bias)
delta = torch.nn.functional.softplus(dt_with_bias)
print(f"delta[0:10]: {delta[0, :10].tolist()}")

# The mixer stores A as a log; recover A = -exp(A_log) in float32.
A = torch.neg(torch.exp(A_log.float()))
print(f"A[0:10,0]: {A[:10, 0].tolist()}")
|
|
# Discretize A: give delta a trailing axis of size 1 so it broadcasts against
# A, then exponentiate the product.
delta_expanded = delta.squeeze(0).unsqueeze(-1)
delta_A = torch.exp(delta_expanded * A)
print(f"delta_A[0:5,0:4]: {delta_A[:5, :4].tolist()}")

# Discretize B once and reuse it. The original computed `delta_B`, never used
# it, and then recomputed the same product inline.
B_sq = B.squeeze(0)
delta_B = delta_expanded * B_sq.unsqueeze(0)
delta_B_x = delta_B * hidden_silu.unsqueeze(-1)

# State update: state = delta_A * previous_state + delta_B * x. This is the
# first token, so the previous state is all zeros; it is written out
# explicitly rather than as `delta_A * 0`.
prev_state = torch.zeros_like(delta_A)
state = delta_A * prev_state + delta_B_x
print(f"state[0:5,0:4]: {state[:5, :4].tolist()}")
|
|
# Readout: weight the state by C, sum over the state axis (last dimension),
# then add the D-scaled skip term taken from the conv/SiLU output.
C_sq = C.squeeze(0)
readout = torch.sum(state * C_sq.unsqueeze(0), dim=-1)
skip = D_param * hidden_silu
y = readout + skip
print(f"y[0:10]: {y[:10].tolist()}")

# SiLU on the gate half of the input projection, written as g * sigmoid(g).
gate_vec = gate.squeeze(0)
gate_silu = gate_vec * torch.sigmoid(gate_vec)
print(f"gate_silu[0:10]: {gate_silu[:10].tolist()}")

# Apply the gate to the SSM output elementwise.
gated = y * gate_silu
print(f"gated[0:10]: {gated[:10].tolist()}")
|
|
# Output projection (weight only). A leading dimension of size 1 is restored
# first so the result has the same layout as `hidden`.
gated_row = gated.unsqueeze(0)
out = torch.nn.functional.linear(input=gated_row, weight=mixer.out_proj.weight)
print(f"out_proj[0:10]: {out[0, :10].tolist()}")

# Residual connection: the mixer output is added to `hidden`, the tensor that
# entered the layer before normalization.
result = torch.add(hidden, out)
print(f"after_residual[0:10]: {result[0, :10].tolist()}")


print("\nDone. Compare these values against WebGPU readbacks.")
|
|