# llm_lab/model/feedforward.py (LLM-1B-Lab)
# Commit baf4768 (Vjeong): Replace F.silu with explicit SiLU implementation in SwiGLUFeedForward
"""SwiGLU Feed-Forward Network."""
import torch
import torch.nn as nn
from llm_lab.config import ModelConfig


class SwiGLUFeedForward(nn.Module):
    """SwiGLU: Gated Linear Unit with Swish activation function.

    Standard FFN:
        FFN(x) = ReLU(x·W1 + b1)·W2 + b2
        → simple nonlinear transformation

    SwiGLU FFN:
        SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
        → controls information flow via a gating mechanism

    Why is SwiGLU better?
    - Swish(x) = x · sigmoid(x): smooth activation, allows some negative values
    - The gate vector learns "which information to let through"
    - Consistently reported to outperform ReLU FFNs in PaLM, LLaMA, etc.

    Note: Having two up-projections (W_gate and W_up) means 1.5x the
    parameters of a standard FFN, but intermediate_dim is adjusted to
    match the total parameter count.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Gate projection: hidden_dim → intermediate_dim
        self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Up projection: hidden_dim → intermediate_dim
        self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Down projection: intermediate_dim → hidden_dim
        self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · W_down
        #
        # 1) gate: decides which information to pass through (Swish activation)
        gate_val = self.gate_proj(x)
        gate = gate_val * torch.sigmoid(gate_val)  # SiLU(x) = x * sigmoid(x)
        # 2) up: projects information to a higher dimension
        up = self.up_proj(x)
        # 3) element-wise multiplication (gating) → project back to original dimension
        return self.down_proj(gate * up)
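

# Sketch of the intermediate_dim adjustment mentioned in the docstring:
# a standard FFN holds 2·h·i parameters (W1 and W2), while SwiGLU holds
# 3·h·i' (gate, up, down), so picking i' ≈ (2/3)·i keeps the totals
# comparable. The helper name and the multiple-of-256 rounding are
# illustrative (LLaMA-style), not part of ModelConfig.
def _matched_intermediate_dim(hidden_dim: int, ffn_mult: int = 4, multiple_of: int = 256) -> int:
    # Shrink the usual ffn_mult-x expansion by 2/3, then round up to a
    # hardware-friendly multiple.
    target = (2 * ffn_mult * hidden_dim) // 3
    return multiple_of * ((target + multiple_of - 1) // multiple_of)
# e.g. _matched_intermediate_dim(4096) == 11008, LLaMA-7B's value.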
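

# Sanity check (sketch): the explicit x * sigmoid(x) in forward() should
# agree with PyTorch's built-in F.silu, which the commit replaced. Run this
# module directly to verify.
if __name__ == "__main__":
    import torch.nn.functional as F

    _x = torch.randn(4, 8)
    assert torch.allclose(_x * torch.sigmoid(_x), F.silu(_x), atol=1e-6)
    print("explicit SiLU matches F.silu")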