"""SwiGLU Feed-Forward Network.""" import torch import torch.nn as nn from llm_lab.config import ModelConfig class SwiGLUFeedForward(nn.Module): """SwiGLU: Gated Linear Unit with Swish activation function. Standard FFN: FFN(x) = ReLU(x·W1 + b1)·W2 + b2 → simple nonlinear transformation SwiGLU FFN: SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down → controls information flow via a gating mechanism Why is SwiGLU better? - Swish(x) = x · sigmoid(x): smooth activation, allows some negative values - The gate vector learns "which information to let through" - Consistently reported to outperform ReLU FFN in PaLM, LLaMA, etc. Note: Having two up-projections (W_gate and W_up) means 1.5x the parameters of a standard FFN, but intermediate_dim is adjusted to match the total parameter count. """ def __init__(self, config: ModelConfig): super().__init__() # Gate projection: hidden_dim → intermediate_dim self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False) # Up projection: hidden_dim → intermediate_dim self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False) # Down projection: intermediate_dim → hidden_dim self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: # SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · down # # 1) gate: decides which information to pass through (Swish activation) gate_val = self.gate_proj(x) gate = gate_val * torch.sigmoid(gate_val) # SiLU(x) = x * sigmoid(x) # 2) up: projects information to a higher dimension up = self.up_proj(x) # 3) element-wise multiplication (gating) → project back to original dimension return self.down_proj(gate * up)