| """ |
| model/mlp.py |
| |
| SwiGLU Feed-Forward Network — used in LLaMA, PaLM, Mistral, etc. |
| |
| Standard FFN (GPT-2): |
| out = dropout(W2 * GELU(W1 * x)) |
| |
| SwiGLU FFN (LLaMA): |
| gate = W_gate * x # linear gate |
| up = W_up * x # linear up-proj |
| hidden = SiLU(gate) * up # element-wise gating (learned) |
| out = W_down * hidden # down-proj back to d_model |
| |
| SiLU (Sigmoid Linear Unit): |
| SiLU(x) = x * sigmoid(x) |
| |
| Why SwiGLU is better: |
| - The gating mechanism (SiLU(gate) * up) gives the model a learned |
| way to activate or suppress each hidden dimension independently. |
| - Empirically outperforms GELU/ReLU FFNs at the same parameter count. |
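| - d_ff is read from ModelConfig.d_ff (it is not computed in this |
| module); it is meant to be int(2/3 * 4 * d_model) rounded to a |
| multiple of 256. Three d_model x d_ff matrices then hold about |
| 3 * d_model * (8/3 * d_model) = 8 * d_model^2 weights, the same as |
| the two d_model x 4*d_model matrices of a standard 4x FFN. This |
| compensates for having 3 matrices instead of 2. |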
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from model.config import ModelConfig |
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    Holds three linear projections, all sharing the ``config.bias`` setting:

    * ``gate``: d_model -> d_ff, fed through SiLU to form the gating signal.
    * ``up``:   d_model -> d_ff, multiplied element-wise by the gating signal.
    * ``down``: d_ff -> d_model, projects the gated hidden state back.

    Args:
        config: model configuration; only ``d_model``, ``d_ff`` and ``bias``
            are read here.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        in_dim, hidden_dim = config.d_model, config.d_ff
        use_bias = config.bias

        # Submodules are created in the order gate -> up -> down. Changing the
        # order would change seeded weight initialisation and state_dict order.
        self.gate = nn.Linear(in_dim, hidden_dim, bias=use_bias)
        self.up = nn.Linear(in_dim, hidden_dim, bias=use_bias)
        self.down = nn.Linear(hidden_dim, in_dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x : (B, T, d_model). nn.Linear acts on the last dimension, so any
                leading batch dimensions are accepted.
        Returns:
            out : same leading dimensions as ``x``, last dimension d_model.
        """
        gated = F.silu(self.gate(x))   # learned, smooth gating signal
        lifted = self.up(x)            # linear branch that gets gated
        return self.down(gated * lifted)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: build the MLP from the 100M config, reconcile its parameter
    # count with the expected formula, and check the forward pass keeps shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    mlp = SwiGLU(cfg)

    # Expected counts: three d_model x d_ff weight matrices, plus bias
    # vectors when cfg.bias is truthy. gate and up each carry a d_ff bias and
    # down carries a d_model bias. nn.Linear tests `bias` by truthiness, so
    # the same test is used here.
    weight_params = 3 * cfg.d_model * cfg.d_ff
    bias_params = (2 * cfg.d_ff + cfg.d_model) if cfg.bias else 0
    n_params = sum(p.numel() for p in mlp.parameters())

    print(f"SwiGLU d_model={cfg.d_model} d_ff={cfg.d_ff}")
    print(f" gate : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
    print(f" up : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
    print(f" down : {cfg.d_ff} x {cfg.d_model} = {cfg.d_ff * cfg.d_model:,}")
    if bias_params:
        print(f" bias : {bias_params:,}")
    print(f" total MLP params : {n_params/1e6:.3f}M")

    # Explicit raises rather than `assert`: asserts vanish under `python -O`,
    # which would let this self-test report PASS without checking anything.
    if n_params != weight_params + bias_params:
        raise RuntimeError(
            f"Parameter count mismatch: got {n_params:,}, "
            f"expected {weight_params + bias_params:,}"
        )

    B, T = 2, 64
    x = torch.randn(B, T, cfg.d_model)
    with torch.no_grad():  # shape check only; no autograd graph needed
        out = mlp(x)

    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    if out.shape != x.shape:
        raise RuntimeError(
            f"Shape mismatch! expected {tuple(x.shape)}, got {tuple(out.shape)}"
        )
    print("PASS")
|
|