"""SwiGLU MLP. Extracted from nanochat-v3/nanochat/gpt.py — silu(gate(x)) * up(x), then project down. """ import torch.nn as nn import torch.nn.functional as F class SwiGLU(nn.Module): """SwiGLU MLP: 3 matrices, uniform expansion ratio. If `hidden` is provided explicitly, it overrides `ffn_mult * n_embd`. Pass `hidden` to get clean power-of-2 / divisible-by-16 dims for FP8 kernels. """ def __init__(self, n_embd: int, ffn_mult: float = 4, hidden: int | None = None): super().__init__() if hidden is None: hidden = int(ffn_mult * n_embd) self.c_gate = nn.Linear(n_embd, hidden, bias=False) self.c_up = nn.Linear(n_embd, hidden, bias=False) self.c_proj = nn.Linear(hidden, n_embd, bias=False) def forward(self, x): return self.c_proj(F.silu(self.c_gate(x)) * self.c_up(x))