| """SwiGLU MLP. | |
| Extracted from nanochat-v3/nanochat/gpt.py — silu(gate(x)) * up(x), then project down. | |
| """ | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class SwiGLU(nn.Module): | |
| """SwiGLU MLP: 3 matrices, uniform expansion ratio. | |
| If `hidden` is provided explicitly, it overrides `ffn_mult * n_embd`. | |
| Pass `hidden` to get clean power-of-2 / divisible-by-16 dims for FP8 kernels. | |
| """ | |
| def __init__(self, n_embd: int, ffn_mult: float = 4, hidden: int | None = None): | |
| super().__init__() | |
| if hidden is None: | |
| hidden = int(ffn_mult * n_embd) | |
| self.c_gate = nn.Linear(n_embd, hidden, bias=False) | |
| self.c_up = nn.Linear(n_embd, hidden, bias=False) | |
| self.c_proj = nn.Linear(hidden, n_embd, bias=False) | |

    def forward(self, x):
        # Gate with SiLU, elementwise-multiply with the up projection, then project back to n_embd.
        return self.c_proj(F.silu(self.c_gate(x)) * self.c_up(x))
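

# A minimal smoke test, not part of the original nanochat module: the shapes and
# the explicit `hidden` value below are illustrative assumptions.
if __name__ == "__main__":
    import torch

    # Default sizing: hidden = int(ffn_mult * n_embd) = int(4.0 * 768) = 3072.
    mlp = SwiGLU(n_embd=768)
    x = torch.randn(2, 16, 768)  # (batch, seq_len, n_embd)
    assert mlp(x).shape == x.shape  # SwiGLU maps back to n_embd

    # Explicit `hidden` override, e.g. to keep the dim divisible by 16 for FP8
    # kernels: 1280 * 8/3 ≈ 3413 is awkward, while 3456 = 16 * 216 is clean.
    mlp_fp8 = SwiGLU(n_embd=1280, hidden=3456)
    y = torch.randn(2, 16, 1280)
    assert mlp_fp8(y).shape == y.shape
    print("ok")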