File size: 4,725 Bytes
c0f89d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Reusable building-block layers: RMSNorm, RotaryEmbedding, SwiGLU.
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Optional TransformerEngine import (FP8 support)
# ---------------------------------------------------------------------------
# Probe for TransformerEngine once at import time; HAS_TE lets callers
# branch on FP8 availability without repeating the try/except themselves.
# NOTE(review): `te` is not referenced elsewhere in this chunk — presumably
# consumed by other modules importing it from here; confirm against callers.
try:
    import transformer_engine.pytorch as te  # type: ignore[import]
    HAS_TE = True
except ImportError:
    te = None  # type: ignore[assignment]
    HAS_TE = False


# ---------------------------------------------------------------------------
# RMS Layer Normalisation
# ---------------------------------------------------------------------------

class RMSNorm(nn.Module):
    """Root-Mean-Square Layer Normalisation (Zhang & Sennrich, 2019).

    The normalisation statistics are computed in float32 for numerical
    stability and the result is cast back to the input dtype before the
    learned per-channel scale is applied.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., D) — divide by the RMS taken over the last dimension.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in fp32, restore the caller's dtype, then apply the scale.
        normalised = self._norm(x.float()).to(x.dtype)
        return normalised * self.weight


# ---------------------------------------------------------------------------
# Rotary Positional Embedding
# ---------------------------------------------------------------------------

class RotaryEmbedding(nn.Module):
    """Precomputed rotary positional embeddings (Su et al., RoFormer 2021).

    Cos/sin tables are stored as non-persistent buffers (shape:
    max_seq_len × D//2) so they move with the module to the correct device
    automatically and are rebuilt (not loaded) from checkpoints.
    """

    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Precompute and register (non-persistent: excluded from state_dict).
        cos, sin = self._build_tables(dim, max_seq_len, theta)
        self.register_buffer("_cos_cached", cos, persistent=False)
        self.register_buffer("_sin_cached", sin, persistent=False)

    @staticmethod
    def _build_tables(
        dim: int, max_seq_len: int, theta: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute cos/sin tables with shape (max_seq_len, dim // 2)."""
        half_dim = dim // 2
        # Inverse frequencies: shape (half_dim,)
        freqs = 1.0 / (
            theta ** (torch.arange(0, half_dim, dtype=torch.float32) / half_dim)
        )
        # Positions: shape (max_seq_len,)
        t = torch.arange(max_seq_len, dtype=torch.float32)
        # Outer product → (max_seq_len, half_dim)
        emb = torch.outer(t, freqs)
        return emb.cos(), emb.sin()

    def forward(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) slices of shape (seq_len, D//2) on *device*.

        If *seq_len* exceeds the cached length, the tables are grown ONCE to
        the new length and kept. (The previous implementation rebuilt the
        full tables on every such call — an O(seq_len · D/2) recomputation
        and fresh allocation per forward pass.)
        """
        if seq_len > self.max_seq_len:
            # Grow the cache in place. Assigning a tensor to a registered
            # buffer name keeps it registered, so later .to()/device moves
            # of the module still carry the tables along.
            cos, sin = self._build_tables(self.dim, seq_len, self.theta)
            self._cos_cached = cos.to(self._cos_cached.device)
            self._sin_cached = sin.to(self._sin_cached.device)
            self.max_seq_len = seq_len
        cos = self._cos_cached[:seq_len].to(device)
        sin = self._sin_cached[:seq_len].to(device)
        return cos, sin


# ---------------------------------------------------------------------------
# SwiGLU Feed-Forward Network
# ---------------------------------------------------------------------------

class SwiGLU(nn.Module):
    """SwiGLU feed-forward block (Shazeer, 2020).

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``.  The gate and
    up projections are independent linear layers, letting the gating path
    learn a representation separate from the value path.
    """

    def __init__(self, d_model: int, d_ffn: int, bias: bool = False) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.up_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.down_proj = nn.Linear(d_ffn, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU-gated element-wise product, then project back to d_model.
        gate = F.silu(self.gate_proj(x))
        hidden = gate * self.up_proj(x)
        return self.down_proj(hidden)