File size: 2,729 Bytes
7f974df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
model/mlp.py

SwiGLU Feed-Forward Network — used in LLaMA, PaLM, Mistral, etc.

Standard FFN (GPT-2):
    out = dropout(W2 * GELU(W1 * x))

SwiGLU FFN (LLaMA):
    gate   = W_gate * x           # linear gate
    up     = W_up   * x           # linear up-proj
    hidden = SiLU(gate) * up      # element-wise gating (learned)
    out    = W_down * hidden      # down-proj back to d_model

SiLU (Sigmoid Linear Unit):
    SiLU(x) = x * sigmoid(x)

Why SwiGLU is better:
    - The gating mechanism (SiLU(gate) * up) gives the model a learned
      way to activate or suppress each hidden dimension independently.
    - Empirically outperforms GELU/ReLU FFNs at the same parameter count.
    - d_ff is conventionally set to int(2/3 * 4 * d_model), rounded to a
      multiple of 256 (this module simply reads it from config.d_ff).
      The 2/3 factor compensates for having 3 matrices instead of 2,
      keeping total parameter count comparable to a standard 4x FFN.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.config import ModelConfig


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``down(SiLU(gate(x)) * up(x))``.

    Holds three linear projections whose sizes come from ``config.d_model``
    and ``config.d_ff``. Whether the projections carry a bias term is
    controlled by ``config.bias``.
    """

    def __init__(self, config: ModelConfig):
        """Build the gate, up and down projections from ``config``.

        Args:
            config: Supplies ``d_model``, ``d_ff`` and ``bias``.
        """
        super().__init__()

        width, hidden, use_bias = config.d_model, config.d_ff, config.bias

        # Construction order (gate, up, down) is deliberate: it fixes the
        # order in which the layers draw their initial weights and the order
        # of the entries in the state dict.
        self.gate = nn.Linear(width, hidden, bias=use_bias)   # d_model -> d_ff
        self.up = nn.Linear(width, hidden, bias=use_bias)     # d_model -> d_ff
        self.down = nn.Linear(hidden, width, bias=use_bias)   # d_ff -> d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x: Tensor whose last dimension is ``d_model``,
                e.g. ``(B, T, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # SiLU(z) = z * sigmoid(z), also known as swish.
        gate_act = F.silu(self.gate(x))
        up_proj = self.up(x)
        # The activated gate scales the up-projection element-wise before
        # the result is projected back down to d_model.
        return self.down(gate_act * up_proj)


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    # Smoke test: build the MLP from the 100M preset, report its size,
    # and confirm a forward pass keeps the input shape.
    from model.config import SLLM_100M

    cfg = SLLM_100M
    ffn = SwiGLU(cfg)

    # Total element count over every registered parameter tensor.
    n_params = sum(param.numel() for param in ffn.parameters())

    print(f"SwiGLU d_model={cfg.d_model}  d_ff={cfg.d_ff}")
    # Per-matrix figures below count weights only (d_in x d_out).
    print(f"  gate : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
    print(f"  up   : {cfg.d_model} x {cfg.d_ff} = {cfg.d_model * cfg.d_ff:,}")
    print(f"  down : {cfg.d_ff} x {cfg.d_model} = {cfg.d_ff * cfg.d_model:,}")
    print(f"  total MLP params : {n_params/1e6:.3f}M")

    # Random batch of 2 sequences, 64 positions each.
    batch, seq_len = 2, 64
    x = torch.randn(batch, seq_len, cfg.d_model)
    out = ffn(x)

    print(f"Input  shape : {x.shape}")
    print(f"Output shape : {out.shape}")
    assert out.shape == x.shape, "Shape mismatch!"
    print("PASS")