Upload folder using huggingface_hub

#40
source/model/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model — LLM architecture package.
3
+
4
+ Public API:
5
+ LLM : top-level decoder-only transformer/hybrid language model
6
+ LMConfig : configuration dataclass
7
+ Mamba2Block: Mamba-2 SSD block (used internally by LLM in hybrid mode)
8
+ """
9
+
10
+ from .config import LMConfig
11
+ from .mamba_block import Mamba2Block
12
+ from .transformer import LLM
13
+
14
+ __all__ = [
15
+ "LLM",
16
+ "LMConfig",
17
+ "Mamba2Block",
18
+ ]
source/model/attention.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Head (and Grouped-Query) Attention with optional FlashAttention-2 backend.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .config import LMConfig
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # Optional FlashAttention import
17
+ # ---------------------------------------------------------------------------
18
+ try:
19
+ from flash_attn import flash_attn_func # type: ignore[import]
20
+ HAS_FLASH_ATTN = True
21
+ except ImportError:
22
+ HAS_FLASH_ATTN = False
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Optional TransformerEngine import (FP8 support)
26
+ # ---------------------------------------------------------------------------
27
+ try:
28
+ import transformer_engine.pytorch as te # type: ignore[import]
29
+ HAS_TE = True
30
+ except ImportError:
31
+ te = None # type: ignore[assignment]
32
+ HAS_TE = False
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Rotary embedding helper
37
+ # ---------------------------------------------------------------------------
38
+
39
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate a query or key tensor with rotary positional embeddings.

    Args:
        x:   (B, T, H, D_head)
        cos: (T, D_head // 2) — as produced by RotaryEmbedding.forward
        sin: (T, D_head // 2) — as produced by RotaryEmbedding.forward

    Returns:
        A tensor with the same shape and dtype as *x*, rotated.
    """
    half = x.shape[-1] // 2
    first = x[..., :half]   # (B, T, H, D//2)
    second = x[..., half:]  # (B, T, H, D//2)

    # Reshape cos/sin from (T, D//2) to (1, T, 1, D//2) so they broadcast
    # over the batch and head dimensions.
    cos_b = cos[None, :, None, :]
    sin_b = sin[None, :, None, :]

    out = torch.cat(
        (first * cos_b - second * sin_b, first * sin_b + second * cos_b),
        dim=-1,
    )
    # cos/sin are float32, so the products may have been upcast — restore
    # the caller's dtype.
    return out.to(x.dtype)
69
+
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Multi-Head Attention
74
+ # ---------------------------------------------------------------------------
75
+
76
class MultiHeadAttention(nn.Module):
    """Multi-head (or grouped-query) causal self-attention.

    Supports:
        - Standard MHA: n_kv_heads == n_heads
        - GQA / MQA: n_kv_heads < n_heads (must evenly divide n_heads)

    Attention backend:
        - FlashAttention-2 when available, on CUDA, and config.use_flash_attn
        - Vanilla scaled dot-product otherwise (causal mask via upper-triangular)
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()

        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads  # resolved in LMConfig.__post_init__
        self.head_dim = config.d_model // config.n_heads
        self.d_model = config.d_model
        self.dropout = config.dropout
        self.use_flash = config.use_flash_attn

        # Number of query heads sharing each KV head.
        self.n_rep = self.n_heads // self.n_kv_heads

        # Projections ----------------------------------------------------
        # Select Linear implementation: te.Linear (FP8) or nn.Linear (BF16).
        _Linear = te.Linear if (config.use_fp8 and HAS_TE) else nn.Linear

        # Fused QKV projection: single GEMM (d_model → q_dim + 2 * kv_dim).
        # For GQA 24:8 with head_dim=128: 3072 + 1024 + 1024 = 5120.
        self._q_dim = self.n_heads * self.head_dim
        self._kv_dim = self.n_kv_heads * self.head_dim
        self.qkv_proj = _Linear(
            config.d_model,
            self._q_dim + 2 * self._kv_dim,
            bias=config.bias,
        )
        self.out_proj = _Linear(
            config.d_model,
            config.d_model,
            bias=config.bias,
        )

    # ------------------------------------------------------------------
    # KV-head expansion for GQA
    # ------------------------------------------------------------------

    @staticmethod
    def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Expand KV heads to match the number of query heads.

        Args:
            x: (B, T, n_kv_heads, head_dim)
            n_rep: repetition factor

        Returns:
            (B, T, n_kv_heads * n_rep, head_dim)
        """
        if n_rep == 1:
            return x
        # repeat_interleave keeps the n_rep copies of each KV head adjacent,
        # matching the query-head grouping convention.
        return x.repeat_interleave(n_rep, dim=2)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, T, C)
            cos: (T, head_dim // 2) — from RotaryEmbedding
            sin: (T, head_dim // 2) — from RotaryEmbedding

        Returns:
            (B, T, C)
        """
        B, T, _ = x.shape

        # --- Fused QKV projection (single GEMM) --------------------------------
        qkv = self.qkv_proj(x)  # (B, T, q_dim + 2*kv_dim)
        q, k, v = qkv.split([self._q_dim, self._kv_dim, self._kv_dim], dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)

        # FlashAttention-2 and rotary embedding require bf16/fp16.
        # te.Linear with MXFP8 may emit FP8-format output tensors; cast if needed.
        if q.dtype not in (torch.float16, torch.bfloat16):
            q = q.to(torch.bfloat16)
            k = k.to(torch.bfloat16)
            v = v.to(torch.bfloat16)

        # --- Rotary embeddings -----------------------------------------------
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # --- Attention -------------------------------------------------------
        if self.use_flash and HAS_FLASH_ATTN and x.is_cuda:
            attn_out = self._flash_attention(q, k, v, B, T)
        else:
            attn_out = self._standard_attention(q, k, v, B, T)

        # --- Output projection -----------------------------------------------
        # attn_out: (B, T, C)
        return self.out_proj(attn_out)

    # ------------------------------------------------------------------
    # FlashAttention-2 path
    # ------------------------------------------------------------------

    def _flash_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        B: int,
        T: int,
    ) -> torch.Tensor:
        """Run FlashAttention-2.

        flash_attn_func expects inputs in (B, T, H, D) layout and returns
        (B, T, H, D). FlashAttention-2 natively supports GQA via head count
        mismatch (q has n_heads, k/v have n_kv_heads) — no KV expansion needed.
        """
        # Attention dropout is only active during training.
        dropout_p = self.dropout if self.training else 0.0

        # flash_attn_func: (B, T, H, D) → (B, T, H, D)
        # GQA is handled natively: q=(B,T,n_heads,D), k/v=(B,T,n_kv_heads,D)
        out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)

        # Reshape (B, T, n_heads, head_dim) → (B, T, C)
        return out.reshape(B, T, self.n_heads * self.head_dim)

    # ------------------------------------------------------------------
    # Standard (fallback) attention path
    # ------------------------------------------------------------------

    def _standard_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        B: int,
        T: int,
    ) -> torch.Tensor:
        """Vanilla scaled dot-product causal attention.

        Softmax is computed in float32 for numerical stability.
        """
        # Expand KV heads for GQA
        k = self._repeat_kv(k, self.n_rep)  # (B, T, n_heads, head_dim)
        v = self._repeat_kv(v, self.n_rep)  # (B, T, n_heads, head_dim)

        # (B, T, H, D) → (B, H, T, D)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = math.sqrt(self.head_dim)

        # Scaled dot-product: (B, H, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Causal mask: fill upper triangle (excluding diagonal) with -inf
        causal_mask = torch.triu(
            torch.ones(T, T, device=q.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))

        # Softmax in fp32, then cast back
        attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype)

        if self.training and self.dropout > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.dropout)

        # Weighted sum: (B, H, T, D)
        out = torch.matmul(attn_weights, v)

        # (B, H, T, D) → (B, T, H, D) → (B, T, C)
        out = out.transpose(1, 2).contiguous().reshape(B, T, self.d_model)
        return out
source/model/config.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LMConfig: configuration dataclass for the LLM model architecture.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import math
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ import json
13
+
14
+ import yaml
15
+
16
+
17
+ def _round_to_multiple(n: int, multiple: int) -> int:
18
+ """Round n up to the nearest multiple of `multiple`."""
19
+ return math.ceil(n / multiple) * multiple
20
+
21
+
22
@dataclass
class LMConfig:
    """Architecture configuration for the LLM model."""

    # Vocabulary
    vocab_size: int = 32000

    # Model dimensions
    d_model: int = 768
    n_layers: int = 12
    n_heads: int = 12

    # Grouped-query attention: None → standard MHA (n_kv_heads == n_heads)
    n_kv_heads: Optional[int] = None

    # Feed-forward hidden dimension: None → auto-computed
    d_ffn: Optional[int] = None

    # Sequence length
    max_seq_len: int = 2048

    # RoPE base frequency
    rope_theta: float = 10000.0

    # Regularisation
    dropout: float = 0.0
    bias: bool = False

    # Attention backend
    use_flash_attn: bool = True

    # FP8 quantization
    use_fp8: bool = False

    # Hybrid Mamba-Transformer settings
    use_hybrid: bool = False
    hybrid_pattern: str = ""  # e.g. "M M A M M M M A M M M M M M M M M M A M" for 40-layer Nemotron-H style
    # Mamba-2 SSM parameters
    mamba_d_state: int = 128
    mamba_head_dim: int = 64
    mamba_expand: int = 2
    mamba_conv_kernel: int = 4
    mamba_n_groups: int = 1
    mamba_chunk_size: int = 256

    def __post_init__(self) -> None:
        # A None n_kv_heads means full multi-head attention.
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        # GQA requires query heads to partition evenly over KV heads.
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )

        # LLaMA-style FFN sizing: round(8/3 * d_model), rounded up to a
        # multiple of 256.
        if self.d_ffn is None:
            self.d_ffn = _round_to_multiple(int(8 / 3 * self.d_model), 256)

        # A hybrid model is meaningless without a per-layer pattern.
        if self.use_hybrid and not self.hybrid_pattern.strip():
            raise ValueError(
                "use_hybrid=True requires a non-empty hybrid_pattern "
                "(space-separated 'M'/'A' per layer)"
            )

        # TransformerEngine FP8 kernels need 16-aligned GEMM dimensions.
        if self.use_fp8:
            if self.d_model % 16 != 0:
                raise ValueError(f"FP8: d_model ({self.d_model}) must be divisible by 16")
            if self.d_ffn % 16 != 0:
                raise ValueError(f"FP8: d_ffn ({self.d_ffn}) must be divisible by 16")

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def num_params(self) -> int:
        """Approximate parameter count using the 12 * L * d^2 rule."""
        return 12 * self.n_layers * self.d_model ** 2

    @property
    def head_dim(self) -> int:
        """Dimensionality of each attention head."""
        return self.d_model // self.n_heads

    # ------------------------------------------------------------------
    # Serialisation helpers
    # ------------------------------------------------------------------

    def to_dict(self) -> dict:
        """Return a plain-Python-dict representation of the config."""
        return dict(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            d_ffn=self.d_ffn,
            max_seq_len=self.max_seq_len,
            rope_theta=self.rope_theta,
            dropout=self.dropout,
            bias=self.bias,
            use_flash_attn=self.use_flash_attn,
            use_fp8=self.use_fp8,
            use_hybrid=self.use_hybrid,
            hybrid_pattern=self.hybrid_pattern,
            mamba_d_state=self.mamba_d_state,
            mamba_head_dim=self.mamba_head_dim,
            mamba_expand=self.mamba_expand,
            mamba_conv_kernel=self.mamba_conv_kernel,
            mamba_n_groups=self.mamba_n_groups,
            mamba_chunk_size=self.mamba_chunk_size,
        )

    def to_yaml(self, path: str | Path) -> None:
        """Serialise config to a YAML file."""
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w", encoding="utf-8") as f:
            yaml.safe_dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)

    @classmethod
    def from_dict(cls, d: dict) -> "LMConfig":
        """Construct a LMConfig from a plain dict (e.g. loaded from YAML)."""
        return cls(**d)

    @classmethod
    def from_yaml(cls, path: str | Path) -> "LMConfig":
        """Load config from a YAML file."""
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # A shared multi-section config may nest the model settings under
        # a 'model' key — unwrap it when present.
        if "model" in data and isinstance(data["model"], dict):
            data = data["model"]
        return cls.from_dict(data)

    @classmethod
    def from_hf_config(cls, path: str | Path) -> "LMConfig":
        """Load config from a HuggingFace-format config.json (LlamaForCausalLM)."""
        with open(Path(path), "r", encoding="utf-8") as f:
            hf = json.load(f)

        # Newer HF configs nest rope_theta under "rope_parameters"; older
        # ones expose it at the top level. Default: 10000.0.
        rope_params = hf.get("rope_parameters")
        if isinstance(rope_params, dict):
            rope_theta = float(rope_params.get("rope_theta", 10000.0))
        elif "rope_theta" in hf:
            rope_theta = float(hf["rope_theta"])
        else:
            rope_theta = 10000.0

        return cls(
            vocab_size=hf["vocab_size"],
            d_model=hf["hidden_size"],
            n_layers=hf["num_hidden_layers"],
            n_heads=hf["num_attention_heads"],
            n_kv_heads=hf.get("num_key_value_heads", hf["num_attention_heads"]),
            d_ffn=hf["intermediate_size"],
            max_seq_len=hf.get("max_position_embeddings", 4096),
            rope_theta=rope_theta,
            dropout=hf.get("attention_dropout", 0.0),
            bias=hf.get("attention_bias", False),
        )
source/model/layers.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reusable building-block layers: RMSNorm, RotaryEmbedding, SwiGLU.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Optional TransformerEngine import (FP8 support)
14
+ # ---------------------------------------------------------------------------
15
+ try:
16
+ import transformer_engine.pytorch as te # type: ignore[import]
17
+ HAS_TE = True
18
+ except ImportError:
19
+ te = None # type: ignore[assignment]
20
+ HAS_TE = False
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # RMS Layer Normalisation
25
+ # ---------------------------------------------------------------------------
26
+
27
class RMSNorm(nn.Module):
    """Root-Mean-Square layer normalisation (Zhang & Sennrich, 2019).

    The statistics are computed in float32 for numerical stability; the
    normalised tensor is cast back to the input dtype before the learned
    per-channel gain is applied.
    """

    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast, scale by the inverse RMS over the last dim, restore dtype.
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (xf * inv_rms).to(x.dtype) * self.weight
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Rotary Positional Embedding
51
+ # ---------------------------------------------------------------------------
52
+
53
class RotaryEmbedding(nn.Module):
    """Precomputed rotary positional embeddings (Su et al., RoFormer 2021).

    The cos/sin tables (max_seq_len × dim // 2) are registered as
    non-persistent buffers so that ``.to()`` / ``.cuda()`` carry them along
    with the module.
    """

    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        cos_table, sin_table = self._build_tables(dim, max_seq_len, theta)
        self.register_buffer("_cos_cached", cos_table, persistent=False)
        self.register_buffer("_sin_cached", sin_table, persistent=False)

    @staticmethod
    def _build_tables(
        dim: int, max_seq_len: int, theta: float
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables, each of shape (max_seq_len, dim // 2)."""
        half = dim // 2
        # Per-channel inverse frequencies: theta^(-i/half) for i in [0, half).
        inv_freq = 1.0 / (
            theta ** (torch.arange(0, half, dtype=torch.float32) / half)
        )
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        # Angle matrix via broadcasting: (T, 1) * (1, D//2) → (T, D//2).
        angles = positions[:, None] * inv_freq[None, :]
        return angles.cos(), angles.sin()

    def forward(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) slices of shape (seq_len, dim // 2) on *device*.

        Sequences longer than the precomputed table trigger an on-the-fly
        rebuild (rare, but a graceful fallback).
        """
        if seq_len <= self.max_seq_len:
            return (
                self._cos_cached[:seq_len].to(device),
                self._sin_cached[:seq_len].to(device),
            )
        cos_t, sin_t = self._build_tables(self.dim, seq_len, self.theta)
        return cos_t.to(device), sin_t.to(device)
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # SwiGLU Feed-Forward Network
107
+ # ---------------------------------------------------------------------------
108
+
109
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block (Shazeer, 2020).

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``. Gate and up are
    separate linear layers so the gating path can learn an independent
    representation of the input.
    """

    def __init__(self, d_model: int, d_ffn: int, bias: bool = False) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.up_proj = nn.Linear(d_model, d_ffn, bias=bias)
        self.down_proj = nn.Linear(d_ffn, d_model, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU-gated elementwise product, projected back to d_model.
        gate = F.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
source/model/mamba_block.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mamba-2 block based on the Structured State Space Duality (SSD) formulation.
3
+
4
+ Reference: "Transformers are SSMs: Generalized Models and Efficient Algorithms
5
+ Through Structured State Space Duality" (Dao & Gu, 2024).
6
+
7
+ This implements a pure-PyTorch sequential scan for correctness and generality.
8
+ A chunked SSD kernel can be swapped in later for speed.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import math
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from .layers import RMSNorm
20
+
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # Selective Scan (sequential, numerically stable in float32)
24
+ # ---------------------------------------------------------------------------
25
+
26
def selective_scan(
    x: torch.Tensor,
    dt: torch.Tensor,
    A_log: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    D: torch.Tensor,
    n_groups: int,
) -> torch.Tensor:
    """Evaluate the diagonal SSM recurrence sequentially over time.

    Args:
        x: (B, L, n_heads, head_dim) — input after conv + activation.
        dt: (B, L, n_heads) — discretisation time-steps (after softplus).
        A_log: (n_heads,) — log(-A), learnable diagonal decay.
        B: (B, L, n_groups, d_state) — input-to-state projection per step.
        C: (B, L, n_groups, d_state) — state-to-output projection per step.
        D: (n_heads,) — skip/residual connection per head.
        n_groups: number of B/C groups (heads within a group share B/C).

    Returns:
        y: (B, L, n_heads, head_dim) — SSM output.
    """
    batch, seq_len, n_heads, head_dim = x.shape
    d_state = B.shape[-1]

    # Map each head to the B/C group it belongs to (consecutive heads share
    # one group), and expand B/C to per-head tensors once, outside the loop.
    head_to_group = torch.arange(n_heads, device=x.device) // (n_heads // n_groups)
    B_heads = B[:, :, head_to_group, :]  # (B, L, n_heads, d_state)
    C_heads = C[:, :, head_to_group, :]  # (B, L, n_heads, d_state)

    # Per-step decay exp(A * dt) with A = -exp(A_log): (B, L, n_heads).
    decay = torch.exp(-A_log.exp().view(1, 1, -1) * dt)

    # Discretised input term dt * x: (B, L, n_heads, head_dim).
    scaled_x = dt.unsqueeze(-1) * x

    y = torch.zeros_like(x)
    # Running state, accumulated in float32 for numerical stability.
    state = x.new_zeros(batch, n_heads, head_dim, d_state, dtype=torch.float32)

    for t in range(seq_len):
        # Decay the state, then add the rank-1 contribution (dt*x) ⊗ B.
        state = state * decay[:, t].float().unsqueeze(-1).unsqueeze(-1)
        state = state + (
            scaled_x[:, t].float().unsqueeze(-1)
            * B_heads[:, t].float().unsqueeze(-2)
        )

        # Read out through C: contract the state dimension.
        step_out = torch.einsum("bnhd,bnd->bnh", state, C_heads[:, t].float())
        y[:, t] = step_out.to(x.dtype)

    # Per-head skip connection D * x.
    return y + D.view(1, 1, n_heads, 1) * x
100
+
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Mamba-2 Block
104
+ # ---------------------------------------------------------------------------
105
+
106
class Mamba2Block(nn.Module):
    """Mamba-2 (SSD) block with a pre-norm residual connection.

    Pipeline:
        1. RMSNorm (pre-norm)
        2. Fused input projection → (z, x, B, C, dt)
        3. Causal depth-wise Conv1d on x
        4. SiLU activation on x
        5. Selective scan (SSM recurrence)
        6. Gated output: y * SiLU(z)
        7. Output projection + residual

    Args:
        d_model: Model hidden dimension.
        d_state: SSM state dimension N (default 128).
        head_dim: Per-head dimension for SSD (default 64).
        expand: Expansion factor for the inner dimension (default 2).
        conv_kernel: Causal 1D convolution kernel size (default 4).
        n_groups: Number of groups for the B/C projections (default 1).
        chunk_size: Chunk size for the SSD algorithm — reserved for future
            use (default 256).
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 128,
        head_dim: int = 64,
        expand: int = 2,
        conv_kernel: int = 4,
        n_groups: int = 1,
        chunk_size: int = 256,
    ) -> None:
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.head_dim = head_dim
        self.expand = expand
        self.n_groups = n_groups
        self.chunk_size = chunk_size

        # Derived dimensions.
        self.d_inner = expand * d_model
        self.n_heads = self.d_inner // head_dim
        assert self.d_inner % head_dim == 0, (
            f"d_inner ({self.d_inner}) must be divisible by head_dim ({head_dim})"
        )
        assert self.n_heads % n_groups == 0, (
            f"n_heads ({self.n_heads}) must be divisible by n_groups ({n_groups})"
        )

        # Pre-norm.
        self.norm = RMSNorm(d_model)

        # Fused input projection producing z, x, B, C and dt in one GEMM.
        self.d_proj = (
            2 * self.d_inner          # z (gate) + x (conv/SSM input)
            + 2 * n_groups * d_state  # B and C
            + self.n_heads            # dt, one scalar per head
        )
        self.in_proj = nn.Linear(d_model, self.d_proj, bias=False)

        # Depth-wise causal convolution over x; the (kernel-1) left padding
        # produces trailing extra positions that forward() trims away.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=conv_kernel,
            groups=self.d_inner,
            padding=conv_kernel - 1,
        )

        # SSM parameters.
        # A_log stores log(-A), the diagonal decay — init from log(U(1, 16)).
        self.A_log = nn.Parameter(torch.log(torch.rand(self.n_heads) * 15.0 + 1.0))

        # Per-head skip connection — init to ones.
        self.D = nn.Parameter(torch.ones(self.n_heads))

        # Bias added to dt before softplus — init from log(U(0.001, 0.1)).
        self.dt_bias = nn.Parameter(torch.log(torch.rand(self.n_heads) * 0.099 + 0.001))

        # Output projection.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _split_projection(
        self, proj: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the fused input projection into (z, x, B, C, dt).

        Args:
            proj: (B, L, d_proj)

        Returns:
            z: (B, L, d_inner)
            x: (B, L, d_inner)
            B: (B, L, n_groups, d_state)
            C: (B, L, n_groups, d_state)
            dt: (B, L, n_heads)
        """
        batch, seq_len, _ = proj.shape
        bc_dim = self.n_groups * self.d_state

        z, x, B_flat, C_flat, dt = torch.split(
            proj,
            [self.d_inner, self.d_inner, bc_dim, bc_dim, self.n_heads],
            dim=-1,
        )
        B = B_flat.reshape(batch, seq_len, self.n_groups, self.d_state)
        C = C_flat.reshape(batch, seq_len, self.n_groups, self.d_state)
        return z, x, B, C, dt

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, d_model) — input hidden states.

        Returns:
            (B, L, d_model) — output with the residual connection applied.
        """
        shortcut = x
        hidden = self.norm(x)

        # Fused input projection, then split into the five components.
        z, x_inner, B, C, dt_raw = self._split_projection(self.in_proj(hidden))

        # Causal depth-wise convolution (Conv1d wants (B, C, L)); trim the
        # trailing (kernel-1) padded positions to keep the output length L.
        seq_len = x_inner.shape[1]
        conv_out = self.conv1d(x_inner.transpose(1, 2))[:, :, :seq_len]
        conv_out = F.silu(conv_out.transpose(1, 2))  # (B, L, d_inner)

        # Discretisation steps: softplus keeps dt positive.
        dt = F.softplus(dt_raw + self.dt_bias)  # (B, L, n_heads)

        # Multi-head view for the scan.
        batch = conv_out.shape[0]
        heads_in = conv_out.reshape(batch, seq_len, self.n_heads, self.head_dim)

        # Selective scan (SSM recurrence): (B, L, n_heads, head_dim).
        scan_out = selective_scan(
            heads_in, dt, self.A_log, B, C, self.D,
            n_groups=self.n_groups,
        )

        # Flatten heads, gate with SiLU(z), project back, add residual.
        flat = scan_out.reshape(batch, seq_len, self.d_inner)
        return shortcut + self.out_proj(flat * F.silu(z))
source/model/transformer.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full transformer: TransformerBlock and top-level LLM model.
3
+ Supports pure Transformer and hybrid Mamba-2 + Transformer architectures.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .config import LMConfig
16
+ from .layers import RMSNorm, RotaryEmbedding, SwiGLU
17
+ from .attention import MultiHeadAttention
18
+ from .mamba_block import Mamba2Block
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Optional TransformerEngine import (FP8 support)
22
+ # ---------------------------------------------------------------------------
23
+ try:
24
+ import transformer_engine.pytorch as te # type: ignore[import]
25
+ HAS_TE = True
26
+ except ImportError:
27
+ te = None # type: ignore[assignment]
28
+ HAS_TE = False
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # HuggingFace ↔ Custom weight conversion helpers
33
+ # ---------------------------------------------------------------------------
34
+
35
+ def _load_hf_state_dict(path: Path) -> dict[str, torch.Tensor]:
36
+ """Load weights from HF safetensors (or pytorch_model.bin fallback)."""
37
+ safetensors_path = path / "model.safetensors"
38
+ if safetensors_path.exists():
39
+ from safetensors.torch import load_file
40
+ return load_file(str(safetensors_path), device="cpu")
41
+ bin_path = path / "pytorch_model.bin"
42
+ if bin_path.exists():
43
+ return torch.load(bin_path, map_location="cpu", weights_only=True)
44
+ raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {path}")
45
+
46
+
47
+ def _convert_hf_to_custom(hf_sd: dict[str, torch.Tensor], config: LMConfig) -> dict[str, torch.Tensor]:
48
+ """Convert HuggingFace LlamaForCausalLM state dict to our custom format.
49
+
50
+ Key mapping:
51
+ HF: model.embed_tokens.weight → embedding.weight
52
+ HF: model.layers.{i}.self_attn.q/k/v_proj.weight → layers.{i}.attn.qkv_proj.weight (fused)
53
+ HF: model.layers.{i}.self_attn.o_proj.weight → layers.{i}.attn.out_proj.weight
54
+ HF: model.layers.{i}.input_layernorm.weight → layers.{i}.attn_norm.weight
55
+ HF: model.layers.{i}.mlp.gate_proj.weight → layers.{i}.ffn.gate_proj.weight
56
+ HF: model.layers.{i}.mlp.up_proj.weight → layers.{i}.ffn.up_proj.weight
57
+ HF: model.layers.{i}.mlp.down_proj.weight → layers.{i}.ffn.down_proj.weight
58
+ HF: model.layers.{i}.post_attention_layernorm.weight → layers.{i}.ffn_norm.weight
59
+ HF: model.norm.weight → norm.weight
60
+ HF: lm_head.weight → lm_head.weight
61
+ """
62
+ sd: dict[str, torch.Tensor] = {}
63
+
64
+ sd["embedding.weight"] = hf_sd["model.embed_tokens.weight"]
65
+ sd["norm.weight"] = hf_sd["model.norm.weight"]
66
+ sd["lm_head.weight"] = hf_sd["lm_head.weight"]
67
+
68
+ for i in range(config.n_layers):
69
+ pfx = f"model.layers.{i}"
70
+ out = f"layers.{i}"
71
+
72
+ # Fuse Q, K, V into single qkv_proj
73
+ q = hf_sd[f"{pfx}.self_attn.q_proj.weight"]
74
+ k = hf_sd[f"{pfx}.self_attn.k_proj.weight"]
75
+ v = hf_sd[f"{pfx}.self_attn.v_proj.weight"]
76
+ sd[f"{out}.attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
77
+
78
+ sd[f"{out}.attn.out_proj.weight"] = hf_sd[f"{pfx}.self_attn.o_proj.weight"]
79
+ sd[f"{out}.attn_norm.weight"] = hf_sd[f"{pfx}.input_layernorm.weight"]
80
+
81
+ sd[f"{out}.ffn.gate_proj.weight"] = hf_sd[f"{pfx}.mlp.gate_proj.weight"]
82
+ sd[f"{out}.ffn.up_proj.weight"] = hf_sd[f"{pfx}.mlp.up_proj.weight"]
83
+ sd[f"{out}.ffn.down_proj.weight"] = hf_sd[f"{pfx}.mlp.down_proj.weight"]
84
+ sd[f"{out}.ffn_norm.weight"] = hf_sd[f"{pfx}.post_attention_layernorm.weight"]
85
+
86
+ return sd
87
+
88
+
89
+ # ---------------------------------------------------------------------------
90
+ # Transformer Block
91
+ # ---------------------------------------------------------------------------
92
+
93
class TransformerBlock(nn.Module):
    """Pre-norm transformer decoder block: attention then FFN, each residual.

        x = x + Attention(RMSNorm(x))
        x = x + FFN(RMSNorm(x))

    When FP8 is requested and TransformerEngine is installed, the FFN norm is
    fused into ``te.LayerNormMLP`` and no separate ``ffn_norm`` module exists.
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self._use_fp8 = config.use_fp8 and HAS_TE

        if not self._use_fp8:
            self.ffn_norm = RMSNorm(config.d_model)
            self.ffn = SwiGLU(config.d_model, config.d_ffn, bias=config.bias)
        else:
            # te.LayerNormMLP fuses RMSNorm + gate/up/down projections into one
            # kernel and normalises internally, so ffn_norm is not needed.
            self.ffn_norm = None
            self.ffn = te.LayerNormMLP(
                hidden_size=config.d_model,
                ffn_hidden_size=config.d_ffn,
                bias=config.bias,
                activation="swiglu",
                normalization="RMSNorm",
            )

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block.

        Args:
            x: (B, T, C) hidden states.
            cos: (T, head_dim // 2) rotary cosines.
            sin: (T, head_dim // 2) rotary sines.

        Returns:
            (B, T, C) transformed hidden states.
        """
        # Pre-norm attention with residual.
        x = x + self.attn(self.attn_norm(x), cos, sin)
        # The fused TE MLP normalises internally; otherwise norm explicitly.
        ffn_in = x if self._use_fp8 else self.ffn_norm(x)
        return x + self.ffn(ffn_in)
145
+
146
+
147
+ # ---------------------------------------------------------------------------
148
+ # Full Language Model
149
+ # ---------------------------------------------------------------------------
150
+
151
class LLM(nn.Module):
    """Decoder-only transformer language model.

    Features:
        - Learned token embeddings with weight tying to the LM head
        - Rotary positional embeddings (no learned position embeddings)
        - Stack of pre-norm TransformerBlocks (optionally interleaved with
          Mamba-2 blocks in hybrid mode)
        - Final RMSNorm before the LM head
        - Optional cross-entropy loss computation (for training)
    """

    def __init__(self, config: LMConfig) -> None:
        super().__init__()
        self.config = config

        # --- Embedding -------------------------------------------------------
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # --- Layers (pure Transformer or hybrid Mamba-Transformer) -----------
        if config.use_hybrid and config.hybrid_pattern:
            # hybrid_pattern is a whitespace-separated string, one entry per
            # layer: 'M' = Mamba-2 block, 'A' = attention block.
            pattern = config.hybrid_pattern.strip().split()
            if len(pattern) != config.n_layers:
                raise ValueError(
                    f"hybrid_pattern has {len(pattern)} entries but "
                    f"n_layers={config.n_layers}"
                )
            layers: list[nn.Module] = []
            # Track which layers are Mamba vs Attention for forward dispatch
            self._layer_types: list[str] = pattern
            for layer_type in pattern:
                if layer_type == "M":
                    layers.append(Mamba2Block(
                        d_model=config.d_model,
                        d_state=config.mamba_d_state,
                        head_dim=config.mamba_head_dim,
                        expand=config.mamba_expand,
                        conv_kernel=config.mamba_conv_kernel,
                        n_groups=config.mamba_n_groups,
                        chunk_size=config.mamba_chunk_size,
                    ))
                elif layer_type == "A":
                    layers.append(TransformerBlock(config))
                else:
                    raise ValueError(
                        f"Unknown layer type '{layer_type}' in hybrid_pattern. "
                        f"Use 'M' (Mamba) or 'A' (Attention)."
                    )
            self.layers = nn.ModuleList(layers)
        else:
            self._layer_types = ["A"] * config.n_layers
            self.layers = nn.ModuleList(
                [TransformerBlock(config) for _ in range(config.n_layers)]
            )

        # --- Final normalisation and LM head ---------------------------------
        self.norm = RMSNorm(config.d_model)
        # NOTE: lm_head stays a plain nn.Linear — required for embedding weight
        # tying and for TE FP8 compatibility.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: share embedding and LM-head weight matrices.
        # Tying before _init_weights is safe: both modules reference the same
        # tensor, so the later normal_ init simply runs on that shared tensor.
        self.lm_head.weight = self.embedding.weight

        # --- Rotary embeddings -----------------------------------------------
        self.rope = RotaryEmbedding(
            dim=config.head_dim,
            max_seq_len=config.max_seq_len,
            theta=config.rope_theta,
        )

        # --- Initialise weights ----------------------------------------------
        self.apply(self._init_weights)

    # ------------------------------------------------------------------
    # Weight initialisation
    # ------------------------------------------------------------------

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Apply standard initialisation:
        - Linear / Embedding weights: N(0, 0.02)
        - Bias parameters: zeros
        - te.Linear / te.LayerNormMLP: skipped (TE manages its own init)
        - Mamba2Block: skipped (manages its own init)
        """
        # TE modules handle their own weight initialisation.
        if HAS_TE and isinstance(module, (te.Linear, te.LayerNormMLP)):
            return
        # Mamba2Block handles its own parameter init (A_log, D, dt_bias, etc.)
        if isinstance(module, Mamba2Block):
            return
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input_ids: (B, T) long tensor of token indices
            targets: (B, T) long tensor of target token indices, or None.
                Use -1 (ignore_index) to mask positions.
                NOTE(review): -1 differs from HF's conventional -100 —
                callers preparing labels must use -1 here.

        Returns:
            logits: (B, T, vocab_size)
            loss: scalar cross-entropy loss, or None if targets is None
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Token embeddings: (B, T, C)
        x = self.embedding(input_ids)

        # Rotary cos/sin for this sequence length: (T, head_dim // 2)
        # Only needed for Attention layers, but precomputed once for all.
        cos, sin = self.rope(T, device)

        # Run through blocks — Mamba blocks take no positional inputs, so
        # cos/sin are only passed to Attention ('A') layers.
        for layer, ltype in zip(self.layers, self._layer_types):
            if ltype == "M":
                x = layer(x)
            else:
                x = layer(x, cos, sin)

        # Final normalisation
        x = self.norm(x)

        # LM head: (B, T, vocab_size)
        logits = self.lm_head(x)

        # Compute loss if targets are provided
        loss: Optional[torch.Tensor] = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )

        return logits, loss

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def num_params(self) -> int:
        """Number of trainable parameters (tied weights counted once per tensor)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_input_embeddings(self) -> nn.Embedding:
        """HuggingFace-compatible accessor for the token embedding layer."""
        return self.embedding

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------

    @classmethod
    def from_config(cls, config: LMConfig) -> "LLM":
        """Construct an LLM from an LMConfig instance."""
        return cls(config)

    @classmethod
    def from_pretrained(cls, path: str | Path) -> "LLM":
        """Load model from a checkpoint directory.

        Supports two formats (auto-detected):
            1. Custom: config.yaml + model.pt
            2. HuggingFace: config.json + model.safetensors (LlamaForCausalLM)

        Raises:
            FileNotFoundError: if neither config.yaml nor config.json exists.
        """
        path = Path(path)

        # --- Custom format ---
        if (path / "config.yaml").exists():
            config = LMConfig.from_yaml(path / "config.yaml")
            model = cls(config)
            # weights_only=True refuses arbitrary pickled objects.
            state_dict = torch.load(
                path / "model.pt",
                map_location="cpu",
                weights_only=True,
            )
            model.load_state_dict(state_dict)
            return model

        # --- HuggingFace format ---
        if (path / "config.json").exists():
            config = LMConfig.from_hf_config(path / "config.json")
            model = cls(config)
            hf_sd = _load_hf_state_dict(path)
            our_sd = _convert_hf_to_custom(hf_sd, config)
            # NOTE(review): this model ties lm_head.weight to embedding.weight,
            # so loading both keys writes the same tensor twice. If the HF
            # checkpoint is untied (distinct embed/lm_head weights) the later
            # assignment wins — confirm checkpoints used here are tied.
            model.load_state_dict(our_sd)
            return model

        raise FileNotFoundError(
            f"No config.yaml or config.json found in {path}"
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_pretrained(self, path: str | Path) -> None:
        """Save config and model weights to a directory.

        Creates:
            <path>/config.yaml
            <path>/model.pt
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self.config.to_yaml(path / "config.yaml")
        torch.save(self.state_dict(), path / "model.pt")