"""SwiGLU feed-forward module."""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from model.config import ModelConfig
class SwiGLUMLP(nn.Module):
    """Feed-forward block using the SwiGLU gated activation, with no bias terms.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))``. The input is
    projected to ``config.ffn_hidden_dim`` twice. One branch is passed through
    SiLU and multiplies the other elementwise. The product is then projected
    back to ``config.d_model``.

    Args:
        config: Model configuration; only ``d_model`` and ``ffn_hidden_dim``
            are read here.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        width = config.d_model
        hidden = config.ffn_hidden_dim
        # Keep the registration order fixed (gate, up, down). It sets the order
        # in which the weights draw from the global RNG at construction and
        # their order in state_dict().
        self.gate_proj = nn.Linear(width, hidden, bias=False)
        self.up_proj = nn.Linear(width, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the hidden projection with SiLU and map it back to ``d_model``.

        Args:
            x: Input whose last dimension is ``d_model``. nn.Linear acts on the
                last axis, so any leading dimensions pass through unchanged.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last
            dimension ``d_model``.
        """
        # The gate branch is evaluated before the up branch, matching the
        # left-to-right order of the original single expression.
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|