sage / model /mlp.py
sage002's picture
feat: rewrite SAGE 1B architecture and replace legacy repo contents
ef18673 verified
raw
history blame contribute delete
779 Bytes
"""SwiGLU feed-forward module."""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from model.config import ModelConfig
class SwiGLUMLP(nn.Module):
    """Gated feed-forward block using the SwiGLU activation, with no bias terms.

    The input is projected twice to ``config.ffn_hidden_dim`` features: once
    through ``gate_proj`` (followed by SiLU) and once through ``up_proj``. The
    two results are multiplied element-wise, and ``down_proj`` maps the product
    back to ``config.d_model`` features.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        model_width = config.d_model
        hidden_width = config.ffn_hidden_dim
        # Creation order (gate, up, down) determines which RNG draws each weight
        # receives at init; the attribute names determine the state_dict keys.
        self.gate_proj = nn.Linear(model_width, hidden_width, bias=False)
        self.up_proj = nn.Linear(model_width, hidden_width, bias=False)
        self.down_proj = nn.Linear(hidden_width, model_width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

        ``nn.Linear`` acts on the last dimension, so the last dimension of ``x``
        must equal ``config.d_model``. Leading dimensions are preserved.
        """
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)