pycraft-1 / model /feedforward.py
imshadow0's picture
Upload model/feedforward.py with huggingface_hub
a25ba8e verified
Raw
History Blame Contribute Delete
2.72 kB
# model/feedforward.py
# SwiGLU Feed-Forward Network.
#
# Standard FFN uses: Linear → ReLU → Linear (2 matrices)
# SwiGLU uses: (Linear_gate ⊙ Swish(Linear_up)) → Linear_down (3 matrices)
#
# The gating mechanism gives better accuracy per FLOP.
# Used in: Llama 3, Qwen 3, Mistral, PyCraft-1 (Gemma 2 uses the closely related GeGLU variant).
# No bias terms throughout.
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.config import PyCraftConfig
class SwiGLU(nn.Module):
    """
    Gated feed-forward block built from three bias-free linear layers.

    Computation:
        x -> gate_proj(x) * silu(up_proj(x)) -> down_proj -> output

    silu(z) = z * sigmoid(z), also known as Swish.

    NOTE(review): SiLU is applied to the ``up_proj`` branch here and the
    ``gate_proj`` branch is left linear. Some other public implementations
    apply the activation to the gate branch instead -- keep this in mind
    if weights are ever mapped to or from another codebase.

    Args:
        config: object providing ``d_model`` (input/output width) and
            ``d_ff`` (hidden width of the gated branches).
    """

    def __init__(self, config: PyCraftConfig):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff

        # Construction order is kept as gate -> up -> down so parameter
        # registration order and seeded initialisation stay the same.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated feed-forward transform.

        Args:
            x: tensor whose last dimension is ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Linear gate branch multiplied element-wise by the SiLU-activated
        # up branch; both live in the d_ff-wide hidden space.
        gated = self.gate_proj(x) * F.silu(self.up_proj(x))
        # Map the hidden representation back to d_model.
        return self.down_proj(gated)

    @property
    def param_count(self) -> int:
        """Total number of scalar elements across all parameters of this block."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
# ------------------------------------------------------------------ #
# Quick self-test
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test, run only when this file is executed as a script:
    # checks that the block preserves the input shape and that gradients
    # reach the input tensor.
    from model.config import get_config_tiny

    torch.manual_seed(42)  # fixed seed so the printed gradient norm is repeatable
    target = "cuda" if torch.cuda.is_available() else "cpu"
    tiny_cfg = get_config_tiny()

    print(f"Testing SwiGLU FFN on {target}...")
    print(f" d_model={tiny_cfg.d_model}, d_ff={tiny_cfg.d_ff}")

    block = SwiGLU(tiny_cfg).to(target)
    print(f" FFN params: {block.param_count:,}")

    # Forward-only pass: the output must have the same shape as the input.
    probe_shape = (2, 64, tiny_cfg.d_model)
    probe = torch.randn(probe_shape, device=target)
    with torch.no_grad():
        result = block(probe)
    print(f" Input shape: {tuple(probe.shape)}")
    print(f" Output shape: {tuple(result.shape)}")
    assert result.shape == probe.shape, "Output shape mismatch!"

    # Backward pass: a scalar loss must produce a gradient on the input.
    grad_probe = torch.randn(probe_shape, device=target, requires_grad=True)
    block(grad_probe).sum().backward()
    input_grad = grad_probe.grad
    assert input_grad is not None
    print(f" Gradient norm: {input_grad.norm().item():.4f}")

    print(" All FFN tests PASSED.")