| """SwiGLU Feed-Forward Network.""" |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from llm_lab.config import ModelConfig |
|
|
|
|
class SwiGLUFeedForward(nn.Module):
    """SwiGLU: Gated Linear Unit with Swish (SiLU) activation function.

    Standard FFN:
        FFN(x) = ReLU(x·W1 + b1)·W2 + b2
        → simple nonlinear transformation

    SwiGLU FFN:
        SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
        → controls information flow via a gating mechanism

    Why is SwiGLU better?
    - Swish(x) = x · sigmoid(x): smooth activation, allows some negative values
    - The gate vector learns "which information to let through"
    - Consistently reported to outperform ReLU FFN in PaLM, LLaMA, etc.

    Note: Having two up-projections (W_gate and W_up) means
    1.5x the parameters of a standard FFN, but intermediate_dim is
    adjusted to match the total parameter count.
    """

    def __init__(self, config: "ModelConfig"):
        """Build the three linear projections of the SwiGLU block.

        Args:
            config: Model configuration; only ``hidden_dim`` and
                ``intermediate_dim`` are read here.
        """
        super().__init__()
        # hidden_dim -> intermediate_dim; its SiLU output acts as the gate.
        self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # hidden_dim -> intermediate_dim; carries the "content" to be gated.
        self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # intermediate_dim -> hidden_dim; projects back to the residual width.
        self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform.

        Args:
            x: Input tensor of shape ``(..., hidden_dim)``.

        Returns:
            Tensor of shape ``(..., hidden_dim)``.
        """
        # silu(v) = v * sigmoid(v) — the fused library kernel replaces the
        # previous manual `gate_val * torch.sigmoid(gate_val)` (same math).
        gate = nn.functional.silu(self.gate_proj(x))
        # Element-wise gating of the up-projection, then project back down.
        return self.down_proj(gate * self.up_proj(x))
|
|