karthick
Upload TinyStories 24.5M model - article generation success
fb67af8
"""SwiGLU (Swish-Gated Linear Unit) and related gated feed-forward implementations.
Critical implementation details:
1. Requires THREE weight matrices (gate, value/up, down-projection)
2. Hidden dimension should be ~8/3 * d_model to match the parameter count of a standard 4 * d_model FFN
3. Modern implementations typically omit bias terms
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class SwiGLU(nn.Module):
    """Swish-Gated Linear Unit feed-forward block.

    Computes ``W_down(SiLU(W_gate x) * W_up x)``: a SiLU-activated gate path
    is multiplied element-wise with a linear value path, and the result is
    projected to ``output_dim``.

    Based on the paper: 'GLU Variants Improve Transformer'
    https://arxiv.org/abs/2002.05202
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        multiple_of: int = 256,
        bias: bool = False,
    ):
        """
        Args:
            input_dim: Input dimension (d_model).
            hidden_dim: Hidden dimension of the FFN. If None, defaults to
                floor(8 * input_dim / 3).
            output_dim: Output dimension. If None (or 0), uses input_dim.
            multiple_of: If > 1, hidden_dim -- whether defaulted or passed
                explicitly -- is rounded UP to the next multiple of this
                value for hardware efficiency (e.g. 1365 -> 1536 for 256).
                Values <= 1 disable rounding.
            bias: Whether the three projections use bias terms (modern LLMs
                use False).
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim or input_dim

        # A standard ReLU/GELU FFN has 2 matrices of size d_model x 4*d_model.
        # SwiGLU has 3 matrices, so a hidden size of (8/3) * d_model keeps the
        # parameter count the same. Integer floor division avoids float
        # round-off in the default size.
        if hidden_dim is None:
            hidden_dim = 8 * input_dim // 3

        # Ceiling (not nearest) to a multiple of `multiple_of`.
        if multiple_of > 1:
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.hidden_dim = hidden_dim

        # Three linear projections required for SwiGLU.
        self.w_gate = nn.Linear(input_dim, hidden_dim, bias=bias)  # gate projection
        self.w_up = nn.Linear(input_dim, hidden_dim, bias=bias)  # value/up projection
        self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform.

        Formula: SwiGLU(x) = (Swish(x W_gate) ⊗ x W_up) W_down
        where Swish(x) = x * sigmoid(x) = SiLU(x)

        Args:
            x: Input tensor of shape [..., input_dim]
        Returns:
            Output tensor of shape [..., output_dim]
        """
        gate = F.silu(self.w_gate(x))  # gate path with SiLU activation
        value = self.w_up(x)  # value path, no activation
        # Element-wise gating, then down projection to output_dim.
        return self.w_down(gate * value)

    def extra_repr(self) -> str:
        return (
            f'input_dim={self.input_dim}, '
            f'hidden_dim={self.hidden_dim}, '
            f'output_dim={self.output_dim}'
        )
class SwiGLUParallel(nn.Module):
    """SwiGLU feed-forward block with the gate and up projections fused.

    Mathematically the same transform as a SwiGLU FFN, but the gate and up
    projections are stored as ONE ``input_dim -> 2 * hidden_dim`` linear
    layer, so a single matmul produces both. The result is split in half
    along the last dimension: first half = gate, second half = up. Fused
    gate/up layers are a common optimization in LLaMA-style implementations.

    Note: to load weights from separate gate/up matrices, concatenate them
    along the output dimension in (gate, up) order.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        multiple_of: int = 256,
        bias: bool = False,
    ):
        """
        Args:
            input_dim: Input dimension (d_model).
            hidden_dim: Hidden dimension of the FFN. If None, defaults to
                floor(8 * input_dim / 3) for parameter parity with a
                standard 4 * d_model FFN.
            output_dim: Output dimension. If None (or 0), uses input_dim.
            multiple_of: If > 1, hidden_dim (defaulted or explicit) is
                rounded UP to the next multiple of this value. Values <= 1
                disable rounding.
            bias: Whether the projections use bias terms.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim or input_dim

        if hidden_dim is None:
            # Integer floor division avoids float round-off.
            hidden_dim = 8 * input_dim // 3
        if multiple_of > 1:
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.hidden_dim = hidden_dim

        # Combined gate and up projection: last dim of output is 2 * hidden_dim.
        self.w_gate_up = nn.Linear(input_dim, 2 * hidden_dim, bias=bias)
        self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply SwiGLU using a single fused gate/up matmul.

        Args:
            x: Input tensor of shape [..., input_dim]
        Returns:
            Output tensor of shape [..., output_dim]
        """
        # One matmul for both projections, then split into (gate, up) halves.
        gate, up = self.w_gate_up(x).chunk(2, dim=-1)
        return self.w_down(F.silu(gate) * up)

    def extra_repr(self) -> str:
        return (
            f'input_dim={self.input_dim}, '
            f'hidden_dim={self.hidden_dim}, '
            f'output_dim={self.output_dim}'
        )
class GeGLU(nn.Module):
    """GELU-Gated Linear Unit feed-forward block - alternative to SwiGLU.

    Same structure as SwiGLU (gate, up and down projections), but the gate
    path uses GELU instead of SiLU:
    ``W_down(GELU(W_gate x) * W_up x)``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        bias: bool = False,
        multiple_of: int = 1,
    ):
        """
        Args:
            input_dim: Input dimension (d_model).
            hidden_dim: Hidden dimension of the FFN. If None, defaults to
                floor(8 * input_dim / 3).
            output_dim: Output dimension. If None (or 0), uses input_dim.
            bias: Whether the three projections use bias terms.
            multiple_of: If > 1, hidden_dim (defaulted or explicit) is
                rounded UP to the next multiple of this value. Defaults to 1
                (no rounding) so existing GeGLU sizing is unchanged; pass
                256 to size it like SwiGLU's default. Appended after `bias`
                so positional callers are unaffected.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim or input_dim

        if hidden_dim is None:
            # Integer floor division avoids float round-off.
            hidden_dim = 8 * input_dim // 3
        if multiple_of > 1:
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.hidden_dim = hidden_dim

        self.w_gate = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.w_up = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.w_down = nn.Linear(hidden_dim, self.output_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the GeGLU feed-forward transform.

        Args:
            x: Input tensor of shape [..., input_dim]
        Returns:
            Output tensor of shape [..., output_dim]
        """
        gate = F.gelu(self.w_gate(x))  # gate path with GELU activation
        value = self.w_up(x)  # value path, no activation
        return self.w_down(gate * value)

    def extra_repr(self) -> str:
        return (
            f'input_dim={self.input_dim}, '
            f'hidden_dim={self.hidden_dim}, '
            f'output_dim={self.output_dim}'
        )
def calculate_ffn_params(
    d_model: int, activation: str = "swiglu", multiple_of: int = 256
) -> dict:
    """Calculate the FFN hidden size and weight count for an activation type.

    Standard (ReLU/GELU) FFNs use 2 matrices with hidden_dim = 4 * d_model.
    Gated (SwiGLU/GeGLU) FFNs use 3 matrices with hidden_dim =
    floor(8 * d_model / 3), rounded up to a multiple of ``multiple_of``, so
    both variants have roughly the same parameter count. Bias terms are not
    counted.

    Args:
        d_model: Model (input/output) dimension.
        activation: One of "relu", "gelu", "swiglu", "geglu" (case-sensitive).
        multiple_of: Rounding granularity for the gated hidden_dim; values
            <= 1 disable rounding. The default of 256 preserves the previous
            hard-coded behavior.

    Returns:
        Dict with keys "activation", "d_model", "hidden_dim", "num_params"
        and "params_millions" (num_params / 1e6).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    if activation in ("relu", "gelu"):
        # Standard FFN: up and down projections.
        hidden_dim = 4 * d_model
        num_matrices = 2
    elif activation in ("swiglu", "geglu"):
        # Gated FFN: gate, up and down projections; shrink hidden_dim to 8/3
        # so that 3 matrices cost about the same as 2 at 4 * d_model.
        hidden_dim = 8 * d_model // 3
        if multiple_of > 1:
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        num_matrices = 3
    else:
        raise ValueError(f"Unknown activation: {activation}")

    # Every matrix is d_model x hidden_dim (or its transpose).
    num_params = num_matrices * d_model * hidden_dim
    return {
        "activation": activation,
        "d_model": d_model,
        "hidden_dim": hidden_dim,
        "num_params": num_params,
        "params_millions": num_params / 1e6,
    }
# Example usage and parameter comparison
if __name__ == "__main__":
    d_model = 768

    # Compare parameter counts of standard vs gated FFNs at the same d_model.
    print("FFN Parameter Comparison:")
    for act in ["relu", "gelu", "swiglu", "geglu"]:
        params = calculate_ffn_params(d_model, act)
        print(f"{act.upper()}:")
        print(f" Hidden dim: {params['hidden_dim']}")
        print(f" Parameters: {params['params_millions']:.2f}M")

    # Smoke-test a SwiGLU forward pass; no backward is run, so skip autograd.
    batch_size, seq_len = 2, 512
    x = torch.randn(batch_size, seq_len, d_model)
    swiglu = SwiGLU(d_model)
    with torch.no_grad():
        output = swiglu(x)

    # Plain string: this line has no placeholders, so no f-prefix is needed.
    print("\nSwiGLU Test:")
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"SwiGLU parameters: {sum(p.numel() for p in swiglu.parameters()) / 1e6:.2f}M")