File size: 779 Bytes
ef18673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
"""SwiGLU feed-forward module."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn

from model.config import ModelConfig


class SwiGLUMLP(nn.Module):
    """Bias-free SwiGLU feed-forward network."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False)
        self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False)
        self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SwiGLU and project back to the model width."""
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))