"""
model/mlp.py

SwiGLU feed-forward network, as used in LLaMA, PaLM, Mistral, etc.

Standard FFN (GPT-2):
    out = dropout(W2 * GELU(W1 * x))

SwiGLU FFN (LLaMA):
    gate   = W_gate * x          # linear gate projection
    up     = W_up   * x          # linear up-projection
    hidden = SiLU(gate) * up     # element-wise gating (learned)
    out    = W_down * hidden     # down-projection back to d_model

SiLU (Sigmoid Linear Unit, also called Swish):
    SiLU(x) = x * sigmoid(x)

Why SwiGLU:
  - The gating term SiLU(gate) * up gives the model a learned way to
    amplify or suppress each hidden dimension independently.
  - Empirically, it outperforms GELU/ReLU FFNs at a comparable parameter count.
  - d_ff is set to int(2/3 * 4 * d_model), rounded to the nearest multiple
    of 256. The narrower hidden width compensates for using three weight
    matrices instead of two, so the total parameter count stays close to
    that of a standard 4x FFN.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.config import ModelConfig
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block computing ``down(SiLU(gate(x)) * up(x))``.

    Three linear projections replace the two of a standard FFN. All three
    use the same bias setting, taken from ``config.bias``.

    Args:
        config: model configuration. This module reads ``d_model``,
            ``d_ff`` and ``bias`` from it.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        d_model = config.d_model
        d_ff = config.d_ff
        use_bias = config.bias
        # Three projections. Whether they carry bias terms is decided by
        # config.bias; it is not hard-coded here.
        self.gate = nn.Linear(d_model, d_ff, bias=use_bias)  # d_model -> d_ff, fed through SiLU
        self.up = nn.Linear(d_model, d_ff, bias=use_bias)    # d_model -> d_ff, linear branch
        self.down = nn.Linear(d_ff, d_model, bias=use_bias)  # d_ff -> d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x: tensor whose last dimension is ``d_model``, e.g.
                ``(B, T, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # SiLU(z) = z * sigmoid(z), also called swish. The activated gate
        # scales the up-projection element-wise before the down-projection.
        return self.down(F.silu(self.gate(x)) * self.up(x))
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    from model.config import SLLM_100M

    cfg = SLLM_100M
    mlp = SwiGLU(cfg)

    print(f"SwiGLU d_model={cfg.d_model} d_ff={cfg.d_ff}")
    # Count each projection from its actual parameters so that bias terms are
    # included when cfg.bias is True (a plain in x out product would omit them).
    for name, proj in (("gate", mlp.gate), ("up", mlp.up), ("down", mlp.down)):
        n_proj = sum(p.numel() for p in proj.parameters())
        print(f"  {name:<5}: {proj.in_features} x {proj.out_features} = {n_proj:,}")
    n_params = sum(p.numel() for p in mlp.parameters())
    print(f"  total MLP params : {n_params/1e6:.3f}M")

    B, T = 2, 64
    x = torch.randn(B, T, cfg.d_model)
    # Smoke test only: no gradients needed, so skip building the autograd graph.
    with torch.no_grad():
        out = mlp(x)
    print(f"Input shape  : {x.shape}")
    print(f"Output shape : {out.shape}")
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if out.shape != x.shape:
        raise RuntimeError(
            f"Shape mismatch! output {tuple(out.shape)} != input {tuple(x.shape)}"
        )
    print("PASS")