"""
model/mlp.py

SwiGLU Feed-Forward Network — used in LLaMA, PaLM, Mistral, etc.

Standard FFN (GPT-2):
    out = dropout(W2 * GELU(W1 * x))

SwiGLU FFN (LLaMA):
    gate   = W_gate * x          # linear gate
    up     = W_up * x            # linear up-proj
    hidden = SiLU(gate) * up     # element-wise gating (learned)
    out    = W_down * hidden     # down-proj back to d_model

SiLU (Sigmoid Linear Unit):
    SiLU(x) = x * sigmoid(x)

Why SwiGLU is better:
  - The gating mechanism (SiLU(gate) * up) gives the model a learned way to
    activate or suppress each hidden dimension independently.
  - Empirically outperforms GELU/ReLU FFNs at the same parameter count.
  - d_ff is set to int(2/3 * 4 * d_model) rounded to nearest 256. This
    compensates for having 3 matrices instead of 2, keeping total parameter
    count comparable to a standard 4x FFN. (This module only reads
    ``config.d_ff``; the rounding itself is not performed here.)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.config import ModelConfig


class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down(SiLU(gate(x)) * up(x))``.

    Reads ``d_model``, ``d_ff`` and ``bias`` from the supplied config and
    builds three linear projections named ``gate``, ``up`` and ``down``.
    """

    def __init__(self, config: ModelConfig):
        """Create the three projections.

        Args:
            config: object exposing ``d_model`` (input/output width),
                ``d_ff`` (hidden width) and ``bias`` (forwarded to every
                ``nn.Linear`` as its ``bias`` argument).
        """
        super().__init__()

        in_dim = config.d_model
        hidden_dim = config.d_ff
        use_bias = config.bias

        # Creation order (gate, up, down) is kept so parameter registration
        # and random initialisation order are unchanged.
        # Whether the layers carry a bias term is decided by config.bias.
        self.gate = nn.Linear(in_dim, hidden_dim, bias=use_bias)  # gate projection
        self.up = nn.Linear(in_dim, hidden_dim, bias=use_bias)    # up projection
        self.down = nn.Linear(hidden_dim, in_dim, bias=use_bias)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform.

        Args:
            x : (B, T, d_model)

        Returns:
            out : (B, T, d_model)
        """
        # SiLU(z) = z * sigmoid(z), also known as swish.
        activated_gate = F.silu(self.gate(x))
        # The activated gate scales the up-projection element-wise before
        # the result is projected back to d_model.
        hidden = activated_gate * self.up(x)
        return self.down(hidden)


# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    from model.config import SLLM_100M

    cfg = SLLM_100M
    ffn = SwiGLU(cfg)

    # Count every scalar parameter held by the module.
    n_params = 0
    for weight in ffn.parameters():
        n_params += weight.numel()

    print(f"SwiGLU d_model={cfg.d_model} d_ff={cfg.d_ff}")
    print(f" gate : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
    print(f" up : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
    print(f" down : {cfg.d_ff} x {cfg.d_model} = {cfg.d_ff * cfg.d_model:,}")
    print(f" total MLP params : {n_params/1e6:.3f}M")

    # Push one random batch through and confirm the shape is preserved.
    batch, seq_len = 2, 64
    x = torch.randn(batch, seq_len, cfg.d_model)
    out = ffn(x)

    print(f"Input shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    assert out.shape == x.shape, "Shape mismatch!"
    print("PASS")