File size: 867 Bytes
aeb21e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""SwiGLU MLP.

Extracted from nanochat-v3/nanochat/gpt.py — silu(gate(x)) * up(x), then project down.
"""

import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """SwiGLU MLP: 3 matrices, uniform expansion ratio.

    If `hidden` is provided explicitly, it overrides `ffn_mult * n_embd`.
    Pass `hidden` to get clean power-of-2 / divisible-by-16 dims for FP8 kernels.
    """

    def __init__(self, n_embd: int, ffn_mult: float = 4, hidden: int | None = None):
        super().__init__()
        if hidden is None:
            hidden = int(ffn_mult * n_embd)
        self.c_gate = nn.Linear(n_embd, hidden, bias=False)
        self.c_up = nn.Linear(n_embd, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, n_embd, bias=False)

    def forward(self, x):
        return self.c_proj(F.silu(self.c_gate(x)) * self.c_up(x))