# Uploaded by Ahmed via huggingface_hub ("Upload code/mlp.py", commit aeb21e1, verified).
"""SwiGLU MLP.
Extracted from nanochat-v3/nanochat/gpt.py — silu(gate(x)) * up(x), then project down.
"""
import torch.nn as nn
import torch.nn.functional as F
class SwiGLU(nn.Module):
"""SwiGLU MLP: 3 matrices, uniform expansion ratio.
If `hidden` is provided explicitly, it overrides `ffn_mult * n_embd`.
Pass `hidden` to get clean power-of-2 / divisible-by-16 dims for FP8 kernels.
"""
def __init__(self, n_embd: int, ffn_mult: float = 4, hidden: int | None = None):
super().__init__()
if hidden is None:
hidden = int(ffn_mult * n_embd)
self.c_gate = nn.Linear(n_embd, hidden, bias=False)
self.c_up = nn.Linear(n_embd, hidden, bias=False)
self.c_proj = nn.Linear(hidden, n_embd, bias=False)
def forward(self, x):
return self.c_proj(F.silu(self.c_gate(x)) * self.c_up(x))