|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from model.config import PyCraftConfig
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block (Shazeer, 2020, "GLU Variants Improve Transformer").

    Architecture::

        x -> [silu(gate_proj(x)) * up_proj(x)] -> down_proj -> output

    where ``silu(z) = z * sigmoid(z)`` (Swish with beta = 1). The SiLU is
    applied to the *gate* branch. This matches the published SwiGLU
    formulation and the LLaMA-style ``gate_proj`` / ``up_proj`` /
    ``down_proj`` naming used for the parameters below. All three
    projections are bias-free.

    Note:
        A previous version of this module applied SiLU to ``up_proj``
        instead of ``gate_proj``. To load a checkpoint saved by that version,
        swap its ``gate_proj`` and ``up_proj`` weights.

    Args:
        config: Model config. Only ``d_model`` (input/output width) and
            ``d_ff`` (hidden width) are read.
    """

    def __init__(self, config: "PyCraftConfig"):
        super().__init__()
        # Two parallel d_model -> d_ff projections. gate_proj is passed through
        # SiLU and gates the linear "value" branch produced by up_proj.
        self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        # Maps the gated hidden activations back to the model width.
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform.

        Args:
            x: Tensor whose last dimension is ``d_model``. Any leading
                dimensions are passed through unchanged by the linear layers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Activation goes on the gate branch, not the up branch.
        hidden = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(hidden)

    @property
    def param_count(self) -> int:
        """Total number of scalar elements across all parameters of this block."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a tiny-config FFN, check that the forward pass
    # preserves shape, and check that gradients flow back to the input.
    from model.config import get_config_tiny

    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = get_config_tiny()
    batch_size, seq_len = 2, 64

    print(f"Testing SwiGLU FFN on {device}...")
    print(f" d_model={cfg.d_model}, d_ff={cfg.d_ff}")

    model = SwiGLU(cfg).to(device)
    print(f" FFN params: {model.param_count:,}")

    # Inference-style forward pass: no autograd graph is needed here.
    inputs = torch.randn(batch_size, seq_len, cfg.d_model, device=device)
    with torch.no_grad():
        outputs = model(inputs)

    print(f" Input shape: {tuple(inputs.shape)}")
    print(f" Output shape: {tuple(outputs.shape)}")
    assert outputs.shape == inputs.shape, "Output shape mismatch!"

    # Backward pass: a scalar loss must propagate gradients to the input.
    grad_inputs = torch.randn(
        batch_size, seq_len, cfg.d_model, device=device, requires_grad=True
    )
    model(grad_inputs).sum().backward()
    assert grad_inputs.grad is not None
    print(f" Gradient norm: {grad_inputs.grad.norm().item():.4f}")

    print(" All FFN tests PASSED.")
|
|
|