Upload folder using huggingface_hub
#40
by somebody-to-love - opened
- source/model/__init__.py +18 -0
- source/model/attention.py +263 -0
- source/model/config.py +186 -0
- source/model/layers.py +127 -0
- source/model/mamba_block.py +280 -0
- source/model/transformer.py +370 -0
source/model/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
model — LLM architecture package.

Public API:
    LLM        : top-level decoder-only transformer/hybrid language model
    LMConfig   : configuration dataclass
    Mamba2Block: Mamba-2 SSD block (used internally by LLM in hybrid mode)
"""

# Re-export the package's public names so callers can simply do
# `from model import LLM, LMConfig`.
from .config import LMConfig
from .mamba_block import Mamba2Block
from .transformer import LLM

# Explicit public API declaration; anything else in the package is internal.
__all__ = [
    "LLM",
    "LMConfig",
    "Mamba2Block",
]
|
source/model/attention.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-Head (and Grouped-Query) Attention with optional FlashAttention-2 backend.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
from .config import LMConfig
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# Optional FlashAttention import
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
try:
|
| 19 |
+
from flash_attn import flash_attn_func # type: ignore[import]
|
| 20 |
+
HAS_FLASH_ATTN = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
HAS_FLASH_ATTN = False
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Optional TransformerEngine import (FP8 support)
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
try:
|
| 28 |
+
import transformer_engine.pytorch as te # type: ignore[import]
|
| 29 |
+
HAS_TE = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
te = None # type: ignore[assignment]
|
| 32 |
+
HAS_TE = False
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Rotary embedding helper
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate a query or key tensor with precomputed rotary tables.

    Args:
        x: (B, T, H, D_head) query or key tensor.
        cos: (T, D_head // 2) cosine table from RotaryEmbedding.forward.
        sin: (T, D_head // 2) sine table from RotaryEmbedding.forward.

    Returns:
        Tensor with the same shape and dtype as *x*, rotated.
    """
    half = x.shape[-1] // 2

    # Split the head dimension into the two rotation halves (NeoX layout).
    first = x[..., :half]   # (B, T, H, D//2)
    second = x[..., half:]  # (B, T, H, D//2)

    # Insert broadcast axes: (T, D//2) -> (1, T, 1, D//2) so the tables
    # broadcast over batch and head dimensions.
    cos_b = cos[None, :, None, :]
    sin_b = sin[None, :, None, :]

    # Standard 2-D rotation applied pairwise across the two halves.
    rot_first = first * cos_b - second * sin_b
    rot_second = first * sin_b + second * cos_b

    return torch.cat((rot_first, rot_second), dim=-1).to(x.dtype)
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Multi-Head Attention
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head (or grouped-query) causal self-attention.

    Supports:
    - Standard MHA: n_kv_heads == n_heads
    - GQA / MQA: n_kv_heads < n_heads (must evenly divide n_heads)

    Attention backend:
    - FlashAttention-2 when available and config.use_flash_attn is True
    - Vanilla scaled dot-product otherwise (causal mask via upper-triangular)
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()

        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads  # resolved in __post_init__
        # NOTE(review): assumes config.d_model % config.n_heads == 0 —
        # a remainder here silently shrinks the projection width; confirm
        # this is validated upstream.
        self.head_dim = config.d_model // config.n_heads
        self.d_model = config.d_model
        self.dropout = config.dropout
        self.use_flash = config.use_flash_attn

        # Number of query-head groups per KV head
        self.n_rep = self.n_heads // self.n_kv_heads

        # Projections ----------------------------------------------------
        # Select Linear implementation: te.Linear (FP8) or nn.Linear (BF16)
        _Linear = te.Linear if (config.use_fp8 and HAS_TE) else nn.Linear

        # Fused QKV projection: single GEMM (d_model → q_dim + k_dim + v_dim)
        # For GQA 24:8 with head_dim=128: 3072 + 1024 + 1024 = 5120
        self._q_dim = self.n_heads * self.head_dim  # e.g. 24 * 128 = 3072
        self._kv_dim = self.n_kv_heads * self.head_dim  # e.g. 8 * 128 = 1024
        self.qkv_proj = _Linear(
            config.d_model,
            self._q_dim + 2 * self._kv_dim,  # 3072 + 2*1024 = 5120
            bias=config.bias,
        )
        self.out_proj = _Linear(
            config.d_model,
            config.d_model,
            bias=config.bias,
        )

    # ------------------------------------------------------------------
    # KV-head expansion for GQA
    # ------------------------------------------------------------------

    @staticmethod
    def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Expand KV heads to match the number of query heads.

        repeat_interleave on dim=2 places the n_rep copies of each KV head
        adjacently, so query head h reads KV head h // n_rep — the same
        grouping FlashAttention-2 uses natively.

        Args:
            x: (B, T, n_kv_heads, head_dim)
            n_rep: repetition factor

        Returns:
            (B, T, n_kv_heads * n_rep, head_dim)
        """
        if n_rep == 1:
            return x
        # Shape unpack kept for readability; the values are unused below.
        B, T, n_kv, D = x.shape
        return x.repeat_interleave(n_rep, dim=2)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run causal self-attention over a full sequence (no KV cache).

        Args:
            x: (B, T, C)
            cos: (T, head_dim // 2) — from RotaryEmbedding
            sin: (T, head_dim // 2) — from RotaryEmbedding

        Returns:
            (B, T, C)
        """
        B, T, C = x.shape

        # --- Fused QKV projection (single GEMM) --------------------------------
        qkv = self.qkv_proj(x)  # (B, T, q_dim + 2*kv_dim)
        q, k, v = qkv.split([self._q_dim, self._kv_dim, self._kv_dim], dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)

        # FlashAttention-2 and rotary embedding require bf16/fp16.
        # te.Linear with MXFP8 may emit FP8-format output tensors; cast if needed.
        # (Checking q alone is sufficient: q/k/v come from the same GEMM.)
        if q.dtype not in (torch.float16, torch.bfloat16):
            q = q.to(torch.bfloat16)
            k = k.to(torch.bfloat16)
            v = v.to(torch.bfloat16)

        # --- Rotary embeddings -----------------------------------------------
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # --- Attention -------------------------------------------------------
        # Flash path requires CUDA; CPU tensors always take the fallback.
        if self.use_flash and HAS_FLASH_ATTN and x.is_cuda:
            attn_out = self._flash_attention(q, k, v, B, T)
        else:
            attn_out = self._standard_attention(q, k, v, B, T)

        # --- Output projection -----------------------------------------------
        # attn_out: (B, T, C)
        return self.out_proj(attn_out)

    # ------------------------------------------------------------------
    # FlashAttention-2 path
    # ------------------------------------------------------------------

    def _flash_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        B: int,
        T: int,
    ) -> torch.Tensor:
        """Run FlashAttention-2.

        flash_attn_func expects inputs in (B, T, H, D) layout and returns
        (B, T, H, D). FlashAttention-2 natively supports GQA via head count
        mismatch (q has n_heads, k/v have n_kv_heads) — no KV expansion needed.
        """
        # Dropout only during training; inference is deterministic.
        dropout_p = self.dropout if self.training else 0.0

        # flash_attn_func: (B, T, H, D) → (B, T, H, D)
        # GQA is handled natively: q=(B,T,n_heads,D), k/v=(B,T,n_kv_heads,D)
        out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)

        # Reshape (B, T, n_heads, head_dim) → (B, T, C)
        return out.reshape(B, T, self.n_heads * self.head_dim)

    # ------------------------------------------------------------------
    # Standard (fallback) attention path
    # ------------------------------------------------------------------

    def _standard_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        B: int,
        T: int,
    ) -> torch.Tensor:
        """Vanilla scaled dot-product causal attention.

        Softmax is computed in float32 for numerical stability.
        Materialises the full (B, H, T, T) score matrix — O(T^2) memory.
        """
        # Expand KV heads for GQA
        k = self._repeat_kv(k, self.n_rep)  # (B, T, n_heads, head_dim)
        v = self._repeat_kv(v, self.n_rep)  # (B, T, n_heads, head_dim)

        # (B, T, H, D) → (B, H, T, D)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # 1/sqrt(d) scaling, applied as a division below.
        scale = math.sqrt(self.head_dim)

        # Scaled dot-product: (B, H, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Causal mask: fill upper triangle (excluding diagonal) with -inf
        causal_mask = torch.triu(
            torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))

        # Softmax in fp32, then cast back
        attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)

        if self.training and self.dropout > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.dropout)

        # Weighted sum: (B, H, T, D)
        out = torch.matmul(attn_weights, v)

        # (B, H, T, D) → (B, T, H, D) → (B, T, C)
        out = out.transpose(1, 2).contiguous().reshape(B, T, self.d_model)
        return out
|
source/model/config.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LMConfig: configuration dataclass for the LLM model architecture.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
import yaml
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _round_to_multiple(n: int, multiple: int) -> int:
|
| 18 |
+
"""Round n up to the nearest multiple of `multiple`."""
|
| 19 |
+
return math.ceil(n / multiple) * multiple
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class LMConfig:
    """Architecture configuration for the LLM model.

    Derived fields (``n_kv_heads``, ``d_ffn``) are resolved in
    ``__post_init__``; invalid combinations raise ``ValueError`` at
    construction time.
    """

    # Vocabulary
    vocab_size: int = 32000

    # Model dimensions
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12

    # Grouped-query attention: None → standard MHA (n_kv_heads == n_heads)
    n_kv_heads: Optional[int] = None

    # Feed-forward hidden dimension: None → auto-computed
    d_ffn: Optional[int] = None

    # Sequence length
    max_seq_len: int = 2048

    # RoPE base frequency
    rope_theta: float = 10000.0

    # Regularisation
    dropout: float = 0.0
    bias: bool = False

    # Attention backend
    use_flash_attn: bool = True

    # FP8 quantization
    use_fp8: bool = False

    # Hybrid Mamba-Transformer settings
    use_hybrid: bool = False
    hybrid_pattern: str = ""  # e.g. "M M A M M M M A M M M M M M M M M M A M" for 40-layer Nemotron-H style
    # Mamba-2 SSM parameters
    mamba_d_state: int = 128
    mamba_head_dim: int = 64
    mamba_expand: int = 2
    mamba_conv_kernel: int = 4
    mamba_n_groups: int = 1
    mamba_chunk_size: int = 256

    def __post_init__(self) -> None:
        """Resolve derived fields and validate the configuration.

        Raises:
            ValueError: on GQA divisibility failure, empty hybrid pattern
                with use_hybrid=True, or FP8 dimension misalignment.
        """
        # Resolve n_kv_heads: None → full MHA
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        # Validate GQA divisibility
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )

        # Compute d_ffn using the LLaMA-style formula: round(8/3 * d_model)
        # rounded up to the nearest multiple of 256.
        if self.d_ffn is None:
            raw = int(8 / 3 * self.d_model)
            self.d_ffn = _round_to_multiple(raw, 256)

        # Hybrid Mamba-Transformer validation
        # NOTE(review): only emptiness is checked — pattern length is not
        # compared against n_layers and tokens are not restricted to
        # 'M'/'A' here; presumably validated by the consumer. Confirm.
        if self.use_hybrid and not self.hybrid_pattern.strip():
            raise ValueError(
                "use_hybrid=True requires a non-empty hybrid_pattern "
                "(space-separated 'M'/'A' per layer)"
            )

        # FP8 alignment: TE requires dimensions divisible by 16
        if self.use_fp8:
            if self.d_model % 16 != 0:
                raise ValueError(f"FP8: d_model ({self.d_model}) must be divisible by 16")
            if self.d_ffn % 16 != 0:
                raise ValueError(f"FP8: d_ffn ({self.d_ffn}) must be divisible by 16")

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def num_params(self) -> int:
        """Approximate parameter count using the 12 * L * d^2 rule.

        Excludes embeddings and the LM head; intended for rough sizing only.
        """
        return 12 * self.n_layers * self.d_model ** 2

    @property
    def head_dim(self) -> int:
        """Dimensionality of each attention head (d_model // n_heads)."""
        return self.d_model // self.n_heads

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    def to_dict(self) -> dict:
        """Return a plain-Python-dict representation of the config."""
        return {
            "vocab_size": self.vocab_size,
            "d_model": self.d_model,
            "n_layers": self.n_layers,
            "n_heads": self.n_heads,
            "n_kv_heads": self.n_kv_heads,
            "d_ffn": self.d_ffn,
            "max_seq_len": self.max_seq_len,
            "rope_theta": self.rope_theta,
            "dropout": self.dropout,
            "bias": self.bias,
            "use_flash_attn": self.use_flash_attn,
            "use_fp8": self.use_fp8,
            "use_hybrid": self.use_hybrid,
            "hybrid_pattern": self.hybrid_pattern,
            "mamba_d_state": self.mamba_d_state,
            "mamba_head_dim": self.mamba_head_dim,
            "mamba_expand": self.mamba_expand,
            "mamba_conv_kernel": self.mamba_conv_kernel,
            "mamba_n_groups": self.mamba_n_groups,
            "mamba_chunk_size": self.mamba_chunk_size,
        }

    def to_yaml(self, path: str | Path) -> None:
        """Serialise config to a YAML file, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

    @classmethod
    def from_dict(cls, d: dict) -> "LMConfig":
        """Construct a LMConfig from a plain dict (e.g. loaded from YAML)."""
        return cls(**d)

    @classmethod
    def from_yaml(cls, path: str | Path) -> "LMConfig":
        """Load config from a YAML file.

        NOTE(review): an empty YAML file yields ``data = None`` and this
        will raise a TypeError rather than a clear error — confirm callers
        never pass empty files.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # Support nested YAML with 'model' section (e.g., shared multi-section configs)
        if "model" in data and isinstance(data["model"], dict):
            data = data["model"]
        return cls.from_dict(data)

    @classmethod
    def from_hf_config(cls, path: str | Path) -> "LMConfig":
        """Load config from a HuggingFace-format config.json (LlamaForCausalLM).

        Required HF keys (vocab_size, hidden_size, num_hidden_layers,
        num_attention_heads, intermediate_size) raise KeyError if absent —
        fail-fast on non-Llama configs. Mamba/hybrid fields keep defaults.
        """
        path = Path(path)
        with open(path, "r", encoding="utf-8") as f:
            hf = json.load(f)

        # rope_theta may live at the top level (older configs) or inside a
        # nested 'rope_parameters' dict (newer configs); default 10000.0.
        rope_theta = 10000.0
        if "rope_parameters" in hf and isinstance(hf["rope_parameters"], dict):
            rope_theta = float(hf["rope_parameters"].get("rope_theta", rope_theta))
        elif "rope_theta" in hf:
            rope_theta = float(hf["rope_theta"])

        return cls(
            vocab_size=hf["vocab_size"],
            d_model=hf["hidden_size"],
            n_layers=hf["num_hidden_layers"],
            n_heads=hf["num_attention_heads"],
            n_kv_heads=hf.get("num_key_value_heads", hf["num_attention_heads"]),
            d_ffn=hf["intermediate_size"],
            max_seq_len=hf.get("max_position_embeddings", 4096),
            rope_theta=rope_theta,
            dropout=hf.get("attention_dropout", 0.0),
            bias=hf.get("attention_bias", False),
        )
source/model/layers.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reusable building-block layers: RMSNorm, RotaryEmbedding, SwiGLU.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
# Optional TransformerEngine import (FP8 support)
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
try:
|
| 16 |
+
import transformer_engine.pytorch as te # type: ignore[import]
|
| 17 |
+
HAS_TE = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
te = None # type: ignore[assignment]
|
| 20 |
+
HAS_TE = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# RMS Layer Normalisation
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
class RMSNorm(nn.Module):
    """Root-Mean-Square Layer Normalisation (Zhang & Sennrich, 2019).

    The normalisation statistic is computed in float32 for numerical
    stability; the result is cast back to the input dtype before the
    learned per-channel scale is applied.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learned gain, initialised to identity.
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., D). Divide by the RMS of the last dimension.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Promote to fp32, normalise, restore dtype, then apply the gain.
        normalised = self._norm(x.float()).to(x.dtype)
        return normalised * self.weight
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# Rotary Positional Embedding
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
class RotaryEmbedding(nn.Module):
    """Precomputed rotary positional embeddings (Su et al., RoFormer 2021).

    Cos/sin tables are registered as non-persistent buffers with shape
    (max_seq_len, dim // 2), so they follow the module across devices but
    are excluded from checkpoints.
    """

    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Build the full-length tables once and cache them as buffers.
        cos_table, sin_table = self._build_tables(dim, max_seq_len, theta)
        self.register_buffer("_cos_cached", cos_table, persistent=False)
        self.register_buffer("_sin_cached", sin_table, persistent=False)

    @staticmethod
    def _build_tables(
        dim: int, max_seq_len: int, theta: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute cos/sin tables with shape (max_seq_len, dim // 2)."""
        half_dim = dim // 2
        # Inverse frequencies theta^(-i/half_dim), shape (half_dim,).
        exponents = torch.arange(0, half_dim, dtype=torch.float32) / half_dim
        inv_freq = 1.0 / (theta ** exponents)
        # Angle for every (position, frequency) pair: (max_seq_len, half_dim).
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)
        return angles.cos(), angles.sin()

    def forward(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) slices of shape (seq_len, dim // 2) on *device*.

        Sequences longer than the precomputed length trigger an on-the-fly
        rebuild (rare, but keeps long inputs working).
        """
        if seq_len <= self.max_seq_len:
            # Fast path: slice the cached buffers.
            return (
                self._cos_cached[:seq_len].to(device),
                self._sin_cached[:seq_len].to(device),
            )
        # Graceful fallback: recompute tables at the requested length.
        cos_table, sin_table = self._build_tables(self.dim, seq_len, self.theta)
        return cos_table.to(device), sin_table.to(device)
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ---------------------------------------------------------------------------
|
| 106 |
+
# SwiGLU Feed-Forward Network
|
| 107 |
+
# ---------------------------------------------------------------------------
|
| 108 |
+
|
| 109 |
+
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block (Shazeer, 2020).

    Computes:
        out = down_proj( SiLU(gate_proj(x)) * up_proj(x) )

    Gate and up projections are independent linear layers, letting the
    gating path learn its own representation.
    """

    def __init__(self, d_model: int, d_ffn: int, bias: bool = False) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.up_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.down_proj = nn.Linear(d_ffn, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU-activated gate modulates the up projection element-wise.
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
source/model/mamba_block.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mamba-2 block based on the Structured State Space Duality (SSD) formulation.
|
| 3 |
+
|
| 4 |
+
Reference: "Transformers are SSMs: Generalized Models and Efficient Algorithms
|
| 5 |
+
Through Structured State Space Duality" (Dao & Gu, 2024).
|
| 6 |
+
|
| 7 |
+
This implements a pure-PyTorch sequential scan for correctness and generality.
|
| 8 |
+
A chunked SSD kernel can be swapped in later for speed.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
|
| 19 |
+
from .layers import RMSNorm
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
# Selective Scan (sequential, numerically stable in float32)
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
def selective_scan(
    x: torch.Tensor,
    dt: torch.Tensor,
    A_log: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor,
    n_groups: int,
) -> torch.Tensor:
    """Run the SSM recurrence sequentially over the time axis.

    Per head (diagonal A), the discretised recurrence is:

        h_t = exp(-exp(A_log) * dt_t) * h_{t-1} + (dt_t * x_t) ⊗ B_t
        y_t = <h_t, C_t> + D * x_t

    The state is accumulated in float32 for numerical stability regardless
    of the input dtype; each step's output is cast back to ``x.dtype``.

    Args:
        x: (B, L, n_heads, head_dim) — input after conv + activation.
        dt: (B, L, n_heads) — discretisation time-steps (after softplus).
        A_log: (n_heads,) — log(-A), learnable diagonal decay.
        B: (B, L, n_groups, d_state) — input-to-state projection per step.
        C: (B, L, n_groups, d_state) — state-to-output projection per step.
        D: (n_heads,) — skip/residual connection per head.
        n_groups: Number of B/C groups (heads within a group share B/C).

    Returns:
        y: (B, L, n_heads, head_dim) — SSM output.
    """
    batch, seq_len, n_heads, head_dim = x.shape
    d_state = B.shape[-1]
    heads_per_group = n_heads // n_groups

    # Decay factor dA = exp(-exp(A_log) * dt), cast to float32 once up front.
    neg_A = A_log.exp()  # (n_heads,)
    dA = torch.exp(-neg_A.unsqueeze(0).unsqueeze(0) * dt).float()  # (B, L, n_heads)

    # Input scaled by the time-step; this is what gets folded into the state.
    dt_x = (dt.unsqueeze(-1) * x).float()  # (B, L, n_heads, head_dim)

    # Expand B/C from groups to heads ONCE, outside the time loop — the
    # per-step advanced indexing in the loop was loop-invariant work.
    group_idx = torch.arange(n_heads, device=x.device) // heads_per_group  # (n_heads,)
    B_heads = B[:, :, group_idx, :].float()  # (B, L, n_heads, d_state)
    C_heads = C[:, :, group_idx, :].float()  # (B, L, n_heads, d_state)

    # Output buffer, filled step by step.
    y = torch.zeros_like(x)

    # Recurrent state, kept in float32: (B, n_heads, head_dim, d_state).
    h = torch.zeros(
        batch, n_heads, head_dim, d_state,
        dtype=torch.float32, device=x.device,
    )

    for t in range(seq_len):
        # Decay the state: (B, n_heads) broadcast over (head_dim, d_state).
        h = h * dA[:, t, :].unsqueeze(-1).unsqueeze(-1)

        # Input contribution as an outer product over (head_dim, d_state).
        h = h + dt_x[:, t, :, :].unsqueeze(-1) * B_heads[:, t, :, :].unsqueeze(-2)

        # Read-out: contract the state dimension against C.
        y_t = torch.einsum("bnhd,bnd->bnh", h, C_heads[:, t, :, :])
        y[:, t, :, :] = y_t.to(x.dtype)

    # Per-head skip connection: D * x.
    return y + D.view(1, 1, n_heads, 1) * x
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
# Mamba-2 Block
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
class Mamba2Block(nn.Module):
    """Mamba-2 block with a pre-norm residual connection.

    Pipeline:
        1. RMSNorm (pre-norm)
        2. Fused input projection -> (z, x, B, C, dt)
        3. Causal depth-wise Conv1d on x
        4. SiLU activation on x
        5. Selective scan (SSM recurrence)
        6. Gated output: y * SiLU(z)
        7. Output projection + residual

    Args:
        d_model: Model hidden dimension.
        d_state: SSM state dimension N (default 128).
        head_dim: Per-head dimension for SSD (default 64).
        expand: Expansion factor for the inner dimension (default 2).
        conv_kernel: Causal 1D convolution kernel size (default 4).
        n_groups: Number of groups for the B/C projections (default 1).
        chunk_size: Chunk size for the SSD algorithm — reserved for
            future use (default 256).
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 128,
        head_dim: int = 64,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        chunk_size: int = 256,
    ) -> None:
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.head_dim = head_dim
        self.expand = expand
        self.n_groups = n_groups
        self.chunk_size = chunk_size

        # Derived sizes.
        self.d_inner = expand * d_model
        self.n_heads = self.d_inner // head_dim
        assert self.d_inner % head_dim == 0, (
            f"d_inner ({self.d_inner}) must be divisible by head_dim ({head_dim})"
        )
        assert self.n_heads % n_groups == 0, (
            f"n_heads ({self.n_heads}) must be divisible by n_groups ({n_groups})"
        )

        # Pre-norm.
        self.norm = RMSNorm(d_model)

        # Fused input projection width: z + x + B + C + dt.
        self.d_proj = (
            2 * self.d_inner            # z (gate) and x (conv/SSM branch)
            + 2 * n_groups * d_state    # B and C
            + self.n_heads              # dt, one scalar per head
        )
        self.in_proj = nn.Linear(d_model, self.d_proj, bias=False)

        # Causal depth-wise conv over the x branch; the symmetric padding of
        # (kernel - 1) is trimmed on the right in forward() to stay causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=conv_kernel,
            groups=self.d_inner,
            padding=conv_kernel - 1,
        )

        # SSM parameters.
        # A_log stores log(-A); initialised from log(U(1, 16)).
        self.A_log = nn.Parameter(torch.log(torch.rand(self.n_heads) * 15.0 + 1.0))

        # D: per-head skip connection, initialised to ones.
        self.D = nn.Parameter(torch.ones(self.n_heads))

        # dt_bias: added before softplus; initialised from log(U(0.001, 0.1)).
        self.dt_bias = nn.Parameter(torch.log(torch.rand(self.n_heads) * 0.099 + 0.001))

        # Output projection back to d_model.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _split_projection(
        self, proj: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the fused input projection into (z, x, B, C, dt).

        Args:
            proj: (B, L, d_proj)

        Returns:
            z: (B, L, d_inner) — gate branch.
            x: (B, L, d_inner) — conv/SSM branch.
            B: (B, L, n_groups, d_state) — input-to-state projections.
            C: (B, L, n_groups, d_state) — state-to-output projections.
            dt: (B, L, n_heads) — raw time-steps (pre-softplus).
        """
        batch, seq_len, _ = proj.shape
        bc_dim = self.n_groups * self.d_state
        z, x, B_flat, C_flat, dt = torch.split(
            proj,
            [self.d_inner, self.d_inner, bc_dim, bc_dim, self.n_heads],
            dim=-1,
        )
        B = B_flat.reshape(batch, seq_len, self.n_groups, self.d_state)
        C = C_flat.reshape(batch, seq_len, self.n_groups, self.d_state)
        return z, x, B, C, dt

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, d_model) — input hidden states.

        Returns:
            (B, L, d_model) — output with the residual connection applied.
        """
        residual = x
        hidden = self.norm(x)

        # Fused projection into the five branches.
        z, ssm_in, B, C, dt_raw = self._split_projection(self.in_proj(hidden))

        # Causal conv: run in (B, C, L) layout, then drop the (kernel - 1)
        # trailing positions introduced by the padding.
        seq_len = ssm_in.shape[1]
        conv_out = self.conv1d(ssm_in.transpose(1, 2))[:, :, :seq_len]
        conv_out = F.silu(conv_out.transpose(1, 2))  # (B, L, d_inner)

        # Discretise the time-steps.
        dt = F.softplus(dt_raw + self.dt_bias)  # (B, L, n_heads)

        # Multi-head view for the scan.
        batch = conv_out.shape[0]
        heads = conv_out.reshape(batch, seq_len, self.n_heads, self.head_dim)

        # SSM recurrence.
        y = selective_scan(
            heads, dt, self.A_log, B, C, self.D,
            n_groups=self.n_groups,
        )  # (B, L, n_heads, head_dim)

        # Merge heads, apply the SiLU gate, project back, add the residual.
        y = y.reshape(batch, seq_len, self.d_inner)
        y = y * F.silu(z)
        return residual + self.out_proj(y)
|
source/model/transformer.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Full transformer: TransformerBlock and top-level LLM model.
|
| 3 |
+
Supports pure Transformer and hybrid Mamba-2 + Transformer architectures.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from .config import LMConfig
|
| 16 |
+
from .layers import RMSNorm, RotaryEmbedding, SwiGLU
|
| 17 |
+
from .attention import MultiHeadAttention
|
| 18 |
+
from .mamba_block import Mamba2Block
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Optional TransformerEngine import (FP8 support)
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
try:
|
| 24 |
+
import transformer_engine.pytorch as te # type: ignore[import]
|
| 25 |
+
HAS_TE = True
|
| 26 |
+
except ImportError:
|
| 27 |
+
te = None # type: ignore[assignment]
|
| 28 |
+
HAS_TE = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# HuggingFace ↔ Custom weight conversion helpers
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
def _load_hf_state_dict(path: Path) -> dict[str, torch.Tensor]:
|
| 36 |
+
"""Load weights from HF safetensors (or pytorch_model.bin fallback)."""
|
| 37 |
+
safetensors_path = path / "model.safetensors"
|
| 38 |
+
if safetensors_path.exists():
|
| 39 |
+
from safetensors.torch import load_file
|
| 40 |
+
return load_file(str(safetensors_path), device="cpu")
|
| 41 |
+
bin_path = path / "pytorch_model.bin"
|
| 42 |
+
if bin_path.exists():
|
| 43 |
+
return torch.load(bin_path, map_location="cpu", weights_only=True)
|
| 44 |
+
raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {path}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _convert_hf_to_custom(hf_sd: dict[str, torch.Tensor], config: LMConfig) -> dict[str, torch.Tensor]:
|
| 48 |
+
"""Convert HuggingFace LlamaForCausalLM state dict to our custom format.
|
| 49 |
+
|
| 50 |
+
Key mapping:
|
| 51 |
+
HF: model.embed_tokens.weight → embedding.weight
|
| 52 |
+
HF: model.layers.{i}.self_attn.q/k/v_proj.weight → layers.{i}.attn.qkv_proj.weight (fused)
|
| 53 |
+
HF: model.layers.{i}.self_attn.o_proj.weight → layers.{i}.attn.out_proj.weight
|
| 54 |
+
HF: model.layers.{i}.input_layernorm.weight → layers.{i}.attn_norm.weight
|
| 55 |
+
HF: model.layers.{i}.mlp.gate_proj.weight → layers.{i}.ffn.gate_proj.weight
|
| 56 |
+
HF: model.layers.{i}.mlp.up_proj.weight → layers.{i}.ffn.up_proj.weight
|
| 57 |
+
HF: model.layers.{i}.mlp.down_proj.weight → layers.{i}.ffn.down_proj.weight
|
| 58 |
+
HF: model.layers.{i}.post_attention_layernorm.weight → layers.{i}.ffn_norm.weight
|
| 59 |
+
HF: model.norm.weight → norm.weight
|
| 60 |
+
HF: lm_head.weight → lm_head.weight
|
| 61 |
+
"""
|
| 62 |
+
sd: dict[str, torch.Tensor] = {}
|
| 63 |
+
|
| 64 |
+
sd["embedding.weight"] = hf_sd["model.embed_tokens.weight"]
|
| 65 |
+
sd["norm.weight"] = hf_sd["model.norm.weight"]
|
| 66 |
+
sd["lm_head.weight"] = hf_sd["lm_head.weight"]
|
| 67 |
+
|
| 68 |
+
for i in range(config.n_layers):
|
| 69 |
+
pfx = f"model.layers.{i}"
|
| 70 |
+
out = f"layers.{i}"
|
| 71 |
+
|
| 72 |
+
# Fuse Q, K, V into single qkv_proj
|
| 73 |
+
q = hf_sd[f"{pfx}.self_attn.q_proj.weight"]
|
| 74 |
+
k = hf_sd[f"{pfx}.self_attn.k_proj.weight"]
|
| 75 |
+
v = hf_sd[f"{pfx}.self_attn.v_proj.weight"]
|
| 76 |
+
sd[f"{out}.attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
|
| 77 |
+
|
| 78 |
+
sd[f"{out}.attn.out_proj.weight"] = hf_sd[f"{pfx}.self_attn.o_proj.weight"]
|
| 79 |
+
sd[f"{out}.attn_norm.weight"] = hf_sd[f"{pfx}.input_layernorm.weight"]
|
| 80 |
+
|
| 81 |
+
sd[f"{out}.ffn.gate_proj.weight"] = hf_sd[f"{pfx}.mlp.gate_proj.weight"]
|
| 82 |
+
sd[f"{out}.ffn.up_proj.weight"] = hf_sd[f"{pfx}.mlp.up_proj.weight"]
|
| 83 |
+
sd[f"{out}.ffn.down_proj.weight"] = hf_sd[f"{pfx}.mlp.down_proj.weight"]
|
| 84 |
+
sd[f"{out}.ffn_norm.weight"] = hf_sd[f"{pfx}.post_attention_layernorm.weight"]
|
| 85 |
+
|
| 86 |
+
return sd
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
# Transformer Block
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
class TransformerBlock(nn.Module):
    """Single pre-norm transformer decoder block.

    Layout:
        x = x + Attention( RMSNorm(x) )
        x = x + FFN( RMSNorm(x) )

    When FP8 via TransformerEngine is active, the FFN path is
    te.LayerNormMLP, which fuses the RMSNorm with the gate/up/down
    projections — so no separate ffn_norm module exists in that mode.
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self._use_fp8 = config.use_fp8 and HAS_TE

        if self._use_fp8:
            # The fused TE kernel normalises internally; ffn_norm stays None.
            self.ffn_norm = None
            self.ffn = te.LayerNormMLP(
                hidden_size=config.d_model,
                ffn_hidden_size=config.d_ffn,
                bias=config.bias,
                activation="swiglu",
                normalization="RMSNorm",
            )
        else:
            self.ffn_norm = RMSNorm(config.d_model)
            self.ffn = SwiGLU(config.d_model, config.d_ffn, bias=config.bias)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, T, C) hidden states.
            cos: (T, head_dim // 2) rotary cosines.
            sin: (T, head_dim // 2) rotary sines.

        Returns:
            (B, T, C) updated hidden states.
        """
        # Attention sub-layer (pre-norm + residual).
        x = x + self.attn(self.attn_norm(x), cos, sin)

        # FFN sub-layer: in FP8 mode the fused MLP normalises internally,
        # otherwise apply the explicit pre-norm first.
        ffn_input = x if self._use_fp8 else self.ffn_norm(x)
        return x + self.ffn(ffn_input)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ---------------------------------------------------------------------------
|
| 148 |
+
# Full Language Model
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
|
| 151 |
+
class LLM(nn.Module):
    """Decoder-only transformer language model.

    Features:
        - Learned token embeddings, weight-tied to the LM head
        - Rotary positional embeddings (no learned position embeddings)
        - Stack of pre-norm TransformerBlocks, optionally interleaved with
          Mamba-2 blocks in hybrid mode
        - Final RMSNorm before the LM head
        - Optional cross-entropy loss computation (for training)
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()
        self.config = config

        # --- Embedding -------------------------------------------------------
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # --- Layers (pure Transformer or hybrid Mamba-Transformer) -----------
        self._layer_types, self.layers = self._build_layers(config)

        # --- Final normalisation and LM head ---------------------------------
        self.norm = RMSNorm(config.d_model)
        # NOTE: lm_head stays an nn.Linear — required for embedding weight
        # tying and for TransformerEngine FP8 compatibility.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: embedding and LM head share one weight matrix.
        self.lm_head.weight = self.embedding.weight

        # --- Rotary embeddings -----------------------------------------------
        self.rope = RotaryEmbedding(
            dim=config.head_dim,
            max_seq_len=config.max_seq_len,
            theta=config.rope_theta,
        )

        # --- Initialise weights ----------------------------------------------
        self.apply(self._init_weights)

    # ------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_layers(config: LMConfig) -> tuple[list[str], nn.ModuleList]:
        """Build the block stack plus its per-layer 'M'/'A' type tags."""
        if not (config.use_hybrid and config.hybrid_pattern):
            # Pure attention stack.
            blocks = [TransformerBlock(config) for _ in range(config.n_layers)]
            return ["A"] * config.n_layers, nn.ModuleList(blocks)

        pattern = config.hybrid_pattern.strip().split()
        if len(pattern) != config.n_layers:
            raise ValueError(
                f"hybrid_pattern has {len(pattern)} entries but "
                f"n_layers={config.n_layers}"
            )

        blocks: list[nn.Module] = []
        for tag in pattern:
            if tag == "M":
                blocks.append(Mamba2Block(
                    d_model=config.d_model,
                    d_state=config.mamba_d_state,
                    head_dim=config.mamba_head_dim,
                    expand=config.mamba_expand,
                    conv_kernel=config.mamba_conv_kernel,
                    n_groups=config.mamba_n_groups,
                    chunk_size=config.mamba_chunk_size,
                ))
            elif tag == "A":
                blocks.append(TransformerBlock(config))
            else:
                raise ValueError(
                    f"Unknown layer type '{tag}' in hybrid_pattern. "
                    f"Use 'M' (Mamba) or 'A' (Attention)."
                )
        return pattern, nn.ModuleList(blocks)

    # ------------------------------------------------------------------
    # Weight initialisation
    # ------------------------------------------------------------------

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Initialise Linear/Embedding weights to N(0, 0.02), biases to zero.

        TE modules (te.Linear / te.LayerNormMLP) are skipped — TE manages
        its own initialisation. A Mamba2Block module itself is skipped too,
        which protects its raw parameters (A_log, D, dt_bias); note that
        nn.Module.apply still recurses into the block's child modules, so
        its internal nn.Linear layers ARE re-initialised here.
        """
        if HAS_TE and isinstance(module, (te.Linear, te.LayerNormMLP)):
            return
        if isinstance(module, Mamba2Block):
            return
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input_ids: (B, T) long tensor of token indices.
            targets: (B, T) long tensor of target token indices, or None.
                Use -1 (ignore_index) to mask positions.

        Returns:
            logits: (B, T, vocab_size)
            loss: scalar cross-entropy loss, or None if targets is None
        """
        _batch, seq_len = input_ids.shape

        # Token embeddings: (B, T, C).
        hidden = self.embedding(input_ids)

        # Rotary cos/sin for this sequence length: (T, head_dim // 2).
        # Only attention layers consume these; computed once for all blocks.
        cos, sin = self.rope(seq_len, input_ids.device)

        # Block stack — Mamba blocks take no positional inputs.
        for block, tag in zip(self.layers, self._layer_types):
            hidden = block(hidden) if tag == "M" else block(hidden, cos, sin)

        # Final normalisation and LM head: (B, T, vocab_size).
        hidden = self.norm(hidden)
        logits = self.lm_head(hidden)

        if targets is None:
            return logits, None

        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
        )
        return logits, loss

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def num_params(self) -> int:
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_input_embeddings(self) -> nn.Embedding:
        """HuggingFace-compatible accessor for the token embedding layer."""
        return self.embedding

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_config(cls, config: LMConfig) -> "LLM":
        """Construct an LLM from an LMConfig instance."""
        return cls(config)

    @classmethod
    def from_pretrained(cls, path: str | Path) -> "LLM":
        """Load a model from a checkpoint directory.

        Two formats are auto-detected:
            1. Custom: config.yaml + model.pt
            2. HuggingFace: config.json + model.safetensors (LlamaForCausalLM)
        """
        path = Path(path)

        if (path / "config.yaml").exists():
            # Custom format: YAML config + raw state dict.
            config = LMConfig.from_yaml(path / "config.yaml")
            model = cls(config)
            model.load_state_dict(
                torch.load(
                    path / "model.pt",
                    map_location="cpu",
                    weights_only=True,
                )
            )
            return model

        if (path / "config.json").exists():
            # HuggingFace format: translate keys into our layout first.
            config = LMConfig.from_hf_config(path / "config.json")
            model = cls(config)
            hf_sd = _load_hf_state_dict(path)
            model.load_state_dict(_convert_hf_to_custom(hf_sd, config))
            return model

        raise FileNotFoundError(
            f"No config.yaml or config.json found in {path}"
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_pretrained(self, path: str | Path) -> None:
        """Save config and model weights to a directory.

        Creates:
            <path>/config.yaml
            <path>/model.pt
        """
        out_dir = Path(path)
        out_dir.mkdir(parents=True, exist_ok=True)
        self.config.to_yaml(out_dir / "config.yaml")
        torch.save(self.state_dict(), out_dir / "model.pt")
|