"""SwiGLU feed-forward module.""" from __future__ import annotations import torch import torch.nn.functional as F from torch import nn from model.config import ModelConfig class SwiGLUMLP(nn.Module): """Bias-free SwiGLU feed-forward network.""" def __init__(self, config: ModelConfig): super().__init__() self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False) self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False) self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply SwiGLU and project back to the model width.""" return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))