| """ |
| Reusable building-block layers: RMSNorm, RotaryEmbedding, SwiGLU. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
# Optional dependency: NVIDIA Transformer Engine. When present, `te` exposes
# its fused layers and HAS_TE lets callers pick the accelerated path; when
# absent we degrade gracefully to the pure-PyTorch implementations below.
try:
    import transformer_engine.pytorch as te
    HAS_TE = True
except ImportError:
    # Not installed — keep the name bound so `te` is always importable.
    te = None
    HAS_TE = False
|
|
|
|
| |
| |
| |
|
|
class RMSNorm(nn.Module):
    """Root-Mean-Square Layer Normalisation (Zhang & Sennrich, 2019).

    The normalisation statistics are computed in float32 for numerical
    stability, then the result is cast back to the input dtype before the
    learned per-feature scale is applied.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        """Create a norm over the last dimension of size *d_model*.

        Args:
            d_model: Feature dimension being normalised.
            eps: Small constant added to the mean square for stability.
        """
        super().__init__()
        self.eps = eps
        # Learned gain, initialised to identity (all ones).
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise *x* along its last dimension and apply the gain."""
        # Promote to float32 so the mean-square and rsqrt are stable even
        # for half-precision inputs.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.square().mean(dim=-1, keepdim=True) + self.eps)
        # Cast back to the caller's dtype before scaling.
        return (x32 * inv_rms).to(x.dtype) * self.weight
|
|
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed rotary positional embeddings (Su et al., RoFormer 2021).

    Cos/sin tables are stored as non-persistent buffers with shape
    (max_seq_len, dim // 2) so they move with the module to the correct
    device automatically and are excluded from checkpoints.
    """

    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0) -> None:
        """Precompute rotation tables.

        Args:
            dim: Per-head embedding dimension (tables use dim // 2 columns).
            max_seq_len: Number of positions to precompute.
            theta: Base of the inverse-frequency geometric progression.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        cos, sin = self._build_tables(dim, max_seq_len, theta)
        self.register_buffer("_cos_cached", cos, persistent=False)
        self.register_buffer("_sin_cached", sin, persistent=False)

    @staticmethod
    def _build_tables(
        dim: int, max_seq_len: int, theta: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute cos/sin tables with shape (max_seq_len, dim // 2)."""
        half_dim = dim // 2
        # Inverse frequencies in float32: theta ** -(i / half_dim).
        freqs = 1.0 / (
            theta ** (torch.arange(0, half_dim, dtype=torch.float32) / half_dim)
        )
        t = torch.arange(max_seq_len, dtype=torch.float32)
        # Outer product -> one rotation angle per (position, frequency) pair.
        emb = torch.outer(t, freqs)
        return emb.cos(), emb.sin()

    def forward(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) slices of shape (seq_len, D//2) on *device*.

        If *seq_len* exceeds the cached length, the tables are rebuilt once
        at the larger size and the cache (and ``max_seq_len``) are grown, so
        subsequent long-sequence calls reuse the cache instead of paying the
        rebuild cost on every forward pass.
        """
        if seq_len > self.max_seq_len:
            cos, sin = self._build_tables(self.dim, seq_len, self.theta)
            # Assigning to a registered buffer name replaces the stored
            # buffer; keep the new tables on the same device as the old ones.
            self._cos_cached = cos.to(self._cos_cached.device)
            self._sin_cached = sin.to(self._sin_cached.device)
            self.max_seq_len = seq_len
        return (
            self._cos_cached[:seq_len].to(device),
            self._sin_cached[:seq_len].to(device),
        )
|
|
|
|
| |
| |
| |
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block (Shazeer, 2020).

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``. The gate and
    up projections are distinct linear layers so the gating path can learn
    a representation independent of the value path.
    """

    def __init__(self, d_model: int, d_ffn: int, bias: bool = False) -> None:
        """Build the three projections.

        Args:
            d_model: Model (input/output) width.
            d_ffn: Hidden width of the feed-forward expansion.
            bias: Whether the linear layers carry bias terms.
        """
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.up_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.down_proj = nn.Linear(d_ffn, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to *x*."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        # Element-wise gating in the expanded space, then project back down.
        return self.down_proj(gate * value)
|
|