# File size: 2,717 Bytes
# a25ba8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# model/feedforward.py
# SwiGLU Feed-Forward Network.
#
# Standard FFN uses: Linear → ReLU → Linear  (2 matrices)
# SwiGLU uses:       (Linear_gate ⊙ Swish(Linear_up)) → Linear_down  (3 matrices)
#
# The gating mechanism gives better accuracy per FLOP.
# Used in: Llama 3, Qwen 3, Mistral, PyCraft-1 (Gemma 2 uses the closely related GeGLU variant).
# No bias terms throughout.

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.config import PyCraftConfig


class SwiGLU(nn.Module):
    """
    Bias-free SwiGLU feed-forward block.

    Computes::

        down_proj( gate_proj(x) * silu(up_proj(x)) )

    where silu(z) = z * sigmoid(z) (also known as Swish). Two projections map
    d_model -> d_ff and a third maps d_ff back to d_model, so the output has
    the same trailing dimension as the input.

    NOTE(review): SiLU is applied to the ``up_proj`` branch here, whereas
    Llama-style reference implementations (e.g. Hugging Face ``LlamaMLP``)
    apply it to ``gate_proj``. The two are equivalent when training from
    scratch, but the branch roles would need swapping to load such external
    checkpoints — confirm before importing third-party weights.
    """

    def __init__(self, config: PyCraftConfig):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff
        # Registration order (gate, up, down) determines parameters() order and
        # the order in which nn.Linear's default initialisers consume the RNG.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform; ``x``'s last dim must be d_model."""
        # Linear (un-activated) branch.
        gated = self.gate_proj(x)
        # SiLU-activated branch.
        activated = F.silu(self.up_proj(x))
        # Element-wise product, then project back down to d_model.
        return self.down_proj(gated * activated)

    @property
    def param_count(self) -> int:
        """Total number of scalar elements across every parameter of this block."""
        total = 0
        for tensor in self.parameters():
            total += tensor.numel()
        return total


# ------------------------------------------------------------------ #
# Quick self-test
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    from model.config import get_config_tiny

    def _check_forward_shape(ffn, cfg, device):
        """Run a no-grad forward pass and require the output shape to match the input."""
        x = torch.randn(2, 64, cfg.d_model, device=device)
        with torch.no_grad():
            out = ffn(x)
        print(f"  Input  shape: {tuple(x.shape)}")
        print(f"  Output shape: {tuple(out.shape)}")
        assert out.shape == x.shape, "Output shape mismatch!"

    def _check_backward(ffn, cfg, device):
        """Backpropagate a scalar loss and require a gradient on the input."""
        x2 = torch.randn(2, 64, cfg.d_model, device=device, requires_grad=True)
        ffn(x2).sum().backward()
        assert x2.grad is not None
        print(f"  Gradient norm: {x2.grad.norm().item():.4f}")

    # Seed first so weight init and the random inputs below are reproducible.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    cfg = get_config_tiny()

    print(f"Testing SwiGLU FFN on {device}...")
    print(f"  d_model={cfg.d_model}, d_ff={cfg.d_ff}")

    ffn = SwiGLU(cfg).to(device)
    print(f"  FFN params: {ffn.param_count:,}")

    # Shape test (no autograd).
    _check_forward_shape(ffn, cfg, device)

    # Gradient test.
    _check_backward(ffn, cfg, device)
    print("  All FFN tests PASSED.")