File size: 1,947 Bytes
8a58ffe
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
858e8b2
8a58ffe
858e8b2
 
 
 
8a58ffe
858e8b2
 
 
8a58ffe
 
 
 
858e8b2
8a58ffe
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
 
 
858e8b2
baf4768
 
858e8b2
8a58ffe
858e8b2
8a58ffe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""SwiGLU Feed-Forward Network."""

import torch
import torch.nn as nn

from llm_lab.config import ModelConfig


class SwiGLUFeedForward(nn.Module):
    """SwiGLU: Gated Linear Unit with Swish (SiLU) activation.

    Standard FFN:
      FFN(x) = ReLU(x·W1 + b1)·W2 + b2
      → simple nonlinear transformation

    SwiGLU FFN:
      SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
      → controls information flow via a gating mechanism

    Why is SwiGLU better?
      - Swish(x) = x · sigmoid(x): smooth activation, allows some negative values
      - The gate vector learns "which information to let through"
      - Consistently reported to outperform ReLU FFN in PaLM, LLaMA, etc.

    Note: Having two up-projections (W_gate and W_up) means
    1.5x the parameters of a standard FFN, but intermediate_dim is
    adjusted to match the total parameter count.
    """

    def __init__(self, config: "ModelConfig"):
        """Build the three bias-free projections from ``config``.

        Args:
            config: model configuration; only ``hidden_dim`` and
                ``intermediate_dim`` are read here.
        """
        super().__init__()
        # Gate projection: hidden_dim → intermediate_dim
        self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Up projection: hidden_dim → intermediate_dim
        self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Down projection: intermediate_dim → hidden_dim
        self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SwiGLU(x) = (SiLU(gate(x)) ⊙ up(x)) · W_down.

        Args:
            x: input of shape (..., hidden_dim).

        Returns:
            Tensor of shape (..., hidden_dim).
        """
        # 1) gate: decides which information to pass through.
        #    nn.functional.silu computes x * sigmoid(x) — same math as the
        #    hand-rolled form, but uses the library's (fused) kernel.
        gate = nn.functional.silu(self.gate_proj(x))
        # 2) up: projects information to a higher dimension
        up = self.up_proj(x)
        # 3) element-wise multiplication (gating) → project back to original dimension
        return self.down_proj(gate * up)