""" Dump golden intermediate values from Falcon-Mamba 7B forward pass. One token (BOS=1), layer 0 only. Dumps after each op so we can compare against WebGPU shader outputs. """ import torch, json, sys from transformers import AutoModelForCausalLM, AutoTokenizer print("Loading model...", flush=True) model = AutoModelForCausalLM.from_pretrained( "tiiuae/falcon-mamba-7b-instruct", dtype=torch.float32, ) model.eval() # Single token: BOS = 1 token_id = 1 input_ids = torch.tensor([[token_id]], dtype=torch.long) # Get embedding emb_weight = model.backbone.embeddings.weight hidden = emb_weight[token_id].unsqueeze(0) # [1, 4096] print(f"embedding[0:10]: {hidden[0, :10].tolist()}") # Layer 0 layer = model.backbone.layers[0] mixer = layer.mixer # RMSNorm norm_weight = layer.norm.weight variance = hidden.pow(2).mean(-1, keepdim=True) norm_out = hidden * torch.rsqrt(variance + 1e-5) * norm_weight print(f"norm_out[0:10]: {norm_out[0, :10].tolist()}") # in_proj projected = torch.nn.functional.linear(norm_out, mixer.in_proj.weight) print(f"projected[0:10]: {projected[0, :10].tolist()}") print(f"projected shape: {projected.shape}") # Split into hidden and gate I = projected.shape[-1] // 2 # 8192 x = projected[..., :I] gate = projected[..., I:] print(f"x[0:10]: {x[0, :10].tolist()}") print(f"gate[0:10]: {gate[0, :10].tolist()}") # Conv1d (first step, state is zeros) conv_weight = mixer.conv1d.weight # [8192, 1, 4] conv_bias = mixer.conv1d.bias # [8192] # With zero state, conv1d output = x * w[:, 0, 3] + bias (only the last kernel position) # Actually for depthwise conv with zero-padded input, it's just the last weight * x + bias conv_state = torch.zeros(1, I, 3) # [batch, channels, kernel-1] window = torch.cat([conv_state.squeeze(0), x.squeeze(0).unsqueeze(-1)], dim=-1) # [8192, 4] conv_out = (window * conv_weight.squeeze(1)).sum(-1) + conv_bias # [8192] print(f"conv_out[0:10]: {conv_out[:10].tolist()}") # SiLU on conv output hidden_silu = conv_out * torch.sigmoid(conv_out) print(f"after_silu[0:10]: {hidden_silu[:10].tolist()}") # x_proj: hidden_silu -> dt_pre, B, C x_proj_out = torch.nn.functional.linear(hidden_silu.unsqueeze(0), mixer.x_proj.weight) dt_rank = mixer.dt_proj.weight.shape[1] # 256 ssm_state_size = (x_proj_out.shape[-1] - dt_rank) // 2 # 16 dt_pre = x_proj_out[..., :dt_rank] B = x_proj_out[..., dt_rank:dt_rank + ssm_state_size] C = x_proj_out[..., dt_rank + ssm_state_size:] print(f"dt_pre[0:10]: {dt_pre[0, :10].tolist()}") print(f"B[0:10]: {B[0, :10].tolist()}") print(f"C[0:10]: {C[0, :10].tolist()}") # dt_proj: dt_pre -> dt dt = torch.nn.functional.linear(dt_pre, mixer.dt_proj.weight) print(f"dt[0:10]: {dt[0, :10].tolist()}") # SSU dt_bias = mixer.dt_proj.bias A_log = mixer.A_log # [8192, 16] D_param = mixer.D # [8192] # delta = softplus(dt + dt_bias) delta = torch.nn.functional.softplus(dt + dt_bias) print(f"delta[0:10]: {delta[0, :10].tolist()}") # A = -exp(A_log) A = -torch.exp(A_log.float()) print(f"A[0:10,0]: {A[:10, 0].tolist()}") # delta_A = exp(delta * A) -- for each (h, s) delta_expanded = delta.squeeze(0).unsqueeze(-1) # [8192, 1] delta_A = torch.exp(delta_expanded * A) # [8192, 16] print(f"delta_A[0:5,0:4]: {delta_A[:5, :4].tolist()}") # state update: state = state * delta_A + delta_B * x # state starts at zeros, so state = delta_B * x delta_B = (delta_expanded * B.squeeze(0).unsqueeze(0).expand_as(A)) # Actually B is [1, 16], need to broadcast B_sq = B.squeeze(0) # [16] delta_B_x = delta_expanded * B_sq.unsqueeze(0) * hidden_silu.unsqueeze(-1) # [8192, 16] state = delta_A * 0 + delta_B_x # state was zero print(f"state[0:5,0:4]: {state[:5, :4].tolist()}") # y = sum_s(state * C) + D * x C_sq = C.squeeze(0) # [16] y = (state * C_sq.unsqueeze(0)).sum(-1) + D_param * hidden_silu print(f"y[0:10]: {y[:10].tolist()}") # gate silu gate_silu = gate.squeeze(0) * torch.sigmoid(gate.squeeze(0)) print(f"gate_silu[0:10]: {gate_silu[:10].tolist()}") # y * gate gated = y * gate_silu print(f"gated[0:10]: {gated[:10].tolist()}") # out_proj out = torch.nn.functional.linear(gated.unsqueeze(0), mixer.out_proj.weight) print(f"out_proj[0:10]: {out[0, :10].tolist()}") # residual result = hidden + out print(f"after_residual[0:10]: {result[0, :10].tolist()}") print("\nDone. Compare these values against WebGPU readbacks.")