ml-ryanlee commited on
Commit
f2ab902
·
verified ·
1 Parent(s): 900eeb9

Upload folder using huggingface_hub

Browse files
checkpoint_step_41340.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6340a2c575617d09a3f41218eba40046542166eaa9112bf417555209a4e839a
3
+ size 2952571035
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LoopLMForCausalLM"
4
+ ],
5
+ "context_length": 1024,
6
+ "d_ff": 2432,
7
+ "d_model": 896,
8
+ "dtype": "float32",
9
+ "lb_loss_factor": 0.01,
10
+ "lz_loss_factor": 0.001,
11
+ "max_length": 1024,
12
+ "model_type": "loop-lm",
13
+ "model_variant": "base",
14
+ "num_active": 2,
15
+ "num_experts": 8,
16
+ "num_heads": 14,
17
+ "num_layers": 16,
18
+ "num_layers_in_stack": 8,
19
+ "num_stacks": 2,
20
+ "rope_theta": 10000.0,
21
+ "transformers_version": "5.3.0",
22
+ "vocab_size": 50257,
23
+ "weight_tying": false,
24
+ "width_ratio": 7.0,
25
+ "auto_map": {
26
+ "AutoConfig": "modeling_loop_lm.LoopLMConfig",
27
+ "AutoModelForCausalLM": "modeling_loop_lm.LoopLMForCausalLM"
28
+ }
29
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "max_length": 1024,
4
+ "output_attentions": false,
5
+ "output_hidden_states": false,
6
+ "transformers_version": "5.3.0"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7fa4c3ee5a6a22401801919cbe7b13f8f727b8771d4db9fead77ff5d9a81180
3
+ size 992547408
modeling_loop_lm.py ADDED
@@ -0,0 +1,984 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Self-contained modeling file for trust_remote_code use.
2
+
3
+ This file merges mup_models.py and hf_wrapper.py into a single module with no
4
+ imports from looped_scaling.*. It is intended to be placed alongside a
5
+ config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that
6
+ HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it
7
+ without requiring the looped_scaling package to be installed.
8
+
9
+ Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer),
10
+ "moe" (MoETransformer), "looped-moe" (LoopedMoETransformer).
11
+ """
12
+
13
+ import torch
14
+ import math
15
+ import sys
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from collections.abc import Callable, Iterable
19
+ from einops import rearrange, einsum, reduce, repeat
20
+ from typing import IO, Any, BinaryIO, Optional
21
+ from torch import Tensor
22
+ from collections import Counter, defaultdict
23
+ from torch.nn.functional import scaled_dot_product_attention as sdpa # for flash attention
24
+ from torch.nn.functional import grouped_mm, silu
25
+ from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
26
+ from transformers.generation import GenerationMixin
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+
29
+ BASE_D_MODEL = 128
30
+ BASE_D_FF = 384
31
+
32
+ """ Standard Transformer and Components implemented with muP """
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Numerically stable softmax (inlined from looped_scaling/utils.py)
37
+ # ---------------------------------------------------------------------------
38
+
39
def softmax(logits: Tensor, dim: int) -> Tensor:
    """Numerically stable softmax over ``dim``, computed and returned in float32.

    The input is upcast to float32 and the result is NOT cast back — callers
    rely on float32 probabilities for mixed-precision stability.
    """
    logits = logits.float()
    # Shift by the per-slice maximum so the largest exponent is exp(0) = 1;
    # this prevents overflow without changing the mathematical result.
    shifted = logits - logits.amax(dim=dim, keepdim=True)
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=dim, keepdim=True)
57
+
58
+
59
# y = Wx (no bias terms!)
class Linear(nn.Module):
    """Bias-free linear layer with muP-scaled truncated-normal init."""

    def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
        super().__init__()
        # Allocate the parameter up front so its shape is always recorded
        # (required for HF meta-device loading).
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device=device))
        # muP: shrink the base model's init std by sqrt(width_ratio).
        std = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: Tensor) -> Tensor:
        # (..., d_in) -> (..., d_out); equivalent to x @ W^T with W (d_out, d_in).
        return x.matmul(self.weight.t())
76
+
77
class Embedding(nn.Module):
    """Token-embedding table with unit-std truncated-normal init (muP spec)."""

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        # Parameter registered first so the shape is stored even on meta device
        # (required for HF meta-device loading).
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device))
        # muP spec: embeddings initialized at std 1.0, truncated at +/- 3.
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)

    def forward(self, token_ids: Tensor) -> Tensor:
        # Each id selects its row vector from the table.
        return F.embedding(token_ids, self.weight)
90
+
91
class RMSNorm(nn.Module):
    """RMS normalization over the feature dim, with no learnable gain (muP)."""

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        # muP convention: no gain parameter on the RMS norm.
        self.d_model = d_model
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32 for stability, then restore the input dtype.
        in_dtype = x.dtype
        x = x.to(torch.float32)

        # Mean of squares over the feature dimension -> one scalar per position.
        mean_sq = (1 / self.d_model) * x.pow(2).sum(dim=-1)
        rms = torch.sqrt(mean_sq + self.eps)

        # Scale each feature vector by the reciprocal RMS (no gain, per muP).
        normalized = x * (1 / rms).unsqueeze(-1)
        return normalized.to(in_dtype)
114
+
115
class PositionwiseFeedforward(nn.Module):
    """SwiGLU feed-forward network: SwiGLU(x) = W2(SiLU(W1 x) * W3 x)."""

    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()
        # muP base-model std; symmetric in (d_model, d_ff), so one value
        # serves all three projection matrices.
        std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))

        self.w1 = Linear(d_model, d_ff, width_ratio, std_base, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, width_ratio, std_base, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, width_ratio, std_base, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # Gated SwiGLU: the w3 branch gates the SiLU(w1) branch elementwise.
        return self.w2(silu(self.w1(x)) * self.w3(x))
136
+
137
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) applied pairwise to Q/K features.

    Precomputes a ``(max_seq_len, d_k//2, 2, 2)`` buffer of 2x2 rotation
    matrices, one per (position, feature-pair), with
    ``angle[i, k] = i / theta**(2k / d_k)``.

    Args:
        theta: RoPE base (e.g. 10000.0).
        d_k: dimension of the per-head query/key vectors (must be even).
        max_seq_len: maximum sequence length to precompute rotations for.
        device: device to store the buffer on.
        dtype: dtype of the buffer (float32 if None).
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        super().__init__()
        # Vectorized buffer construction. The previous nested Python loop
        # built one tiny CPU tensor per (position, pair) and copied it into
        # the (possibly non-CPU) buffer element-by-element — O(max_seq_len * d_k)
        # tensor allocations; this computes the same values in a few ops.
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        inv_freq = theta ** (-2.0 * torch.arange(d_k // 2, dtype=torch.float32) / d_k)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, d_k//2)
        cos, sin = torch.cos(angles), torch.sin(angles)
        # rot[i, k] = [[cos, -sin], [sin, cos]] — a standard 2D rotation.
        rotations = torch.stack(
            (torch.stack((cos, -sin), dim=-1),   # row 0: [cos, -sin]
             torch.stack((sin, cos), dim=-1)),   # row 1: [sin,  cos]
            dim=-2,
        )
        # .to(...) with both None is a no-op, matching the old default (float32, CPU).
        self.register_buffer("rotations", rotations.to(device=device, dtype=dtype), persistent=True)

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """Rotate each adjacent feature pair of ``x`` by its position's angle.

        Args:
            x: tensor of shape (..., seq_dim, feature_dim).
            token_positions: tensor of shape (..., seq_dim) indexing the buffer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Gather per-position rotation matrices; cast the float32 buffer to the
        # activation dtype (activations may be bfloat16).
        rot = self.rotations[token_positions].to(dtype=x.dtype)

        # Group the feature dim into adjacent pairs.
        x_pairs = rearrange(x, "... seq_dim (feature_dim i) -> ... seq_dim feature_dim i", i=2)

        # Apply each 2x2 rotation to its pair: (i x j) @ (j,) -> (i,).
        y_pairs = einsum(rot, x_pairs, "... seq_dim feature_dim i j, ... seq_dim feature_dim j -> ... seq_dim feature_dim i")

        # Flatten the pairs back to the original feature layout.
        return rearrange(y_pairs, "... seq_dim feature_dim i -> ... seq_dim (feature_dim i)")
179
+
180
def scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """muP-scaled dot-product attention.

    Scores are divided by d_k (not sqrt(d_k), per muP), the softmax is taken
    in float32 over the key dimension, and an optional mask marks *allowed*
    (n, m) pairs — masked-out positions receive -inf before the softmax.

    Args:
        Q: query tensor, shape (batch, ..., n, d_k).
        K: key tensor, shape (batch, ..., m, d_k).
        V: value tensor, shape (batch, ..., m, d_v).
        mask: optional bool/float tensor of shape (..., n, m); truthy = keep.

    Returns:
        Attention output of shape (..., n, d_v).
    """
    d_k = Q.shape[-1]
    assert d_k == K.shape[-1]

    # Similarity scores with muP scaling (1/d_k rather than 1/sqrt(d_k)).
    scores = Q.matmul(K.transpose(-2, -1)) / d_k

    if mask is not None:
        # Accept either bool or float masks: allowed -> 0, disallowed -> -inf.
        additive = torch.where(mask.bool(), 0.0, float('-inf')).to(scores.dtype)
        scores = scores + additive

    # Stable softmax computed in float32 (weights stay float32).
    weights = torch.softmax(scores.float(), dim=-1)

    # Convex combination of the values.
    return weights.matmul(V)
219
+
220
class MultiheadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional RoPE, muP-initialized.

    Args:
        d_model (int): model width; split evenly across heads.
        num_heads (int): number of attention heads (must divide d_model).
        max_seq_len (int): RoPE cache size; required when theta is given.
        theta (float): RoPE base; None/0 disables RoPE.
        width_ratio (float): muP width multiplier used for initialization.

    Forward returns a tensor of shape (..., sequence_length, d_model).
    """
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads

        # muP: Q/K/V/O share the base model's std since d_in == d_out == d_model.
        attn_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL))
        self.q_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)

        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."

        if theta:
            self.rope = RotaryPositionalEmbedding(theta, d_model // num_heads, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        head_dim = self.d_model // self.num_heads

        def split_heads(t: Tensor) -> Tensor:
            # (..., seq, d_model) -> (..., heads, seq, head_dim)
            return t.unflatten(-1, (self.num_heads, head_dim)).transpose(-3, -2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        if self.rope:
            if token_positions is None:
                # Default to positions 0..seq-1; leading 1 broadcasts over batch.
                token_positions = torch.arange(x.shape[-2], device=x.device).unsqueeze(0)
            q = self.rope(q, token_positions)
            k = self.rope(k, token_positions)

        # Flash attention with a causal mask and muP 1/d_k scaling.
        attended = sdpa(q, k, v, is_causal=True, scale=1.0 / head_dim)

        # (..., heads, seq, head_dim) -> (..., seq, d_model), then project out.
        merged = attended.transpose(-3, -2).flatten(-2)
        return self.output_proj(merged)
311
+
312
class PrenormBlock(nn.Module):
    """Pre-norm transformer block: x + MHSA(LN1(x)), then + SwiGLU(LN2(.))."""

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        # Attention sub-block: RMS norm followed by MHSA with RoPE.
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # Feed-forward sub-block: RMS norm followed by SwiGLU FFN.
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # Attention with residual connection.
        attn_out = self.attn(self.ln1(x), token_positions)
        # Shapes must match exactly so the residual add is elementwise.
        assert x.shape == attn_out.shape
        x = x + attn_out

        # Feed-forward with residual connection.
        ffn_out = self.ffn(self.ln2(x))
        assert ffn_out.shape == x.shape
        return x + ffn_out
353
+
354
class MuTransformer(nn.Module):
    """Standard pre-norm decoder-only transformer with muP initialization.

    With weight_tying=True the unembedding reuses the embedding matrix and
    logits are divided by width_ratio (muP); otherwise a separate muP Linear
    head is used.
    """

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio
        if weight_tying:
            # Tied head: reuse the embedding parameter directly.
            self.lm_head = self.token_embeddings.weight
        else:
            std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # Embed tokens (no muP alpha_in scaling).
        h = self.token_embeddings(x)

        # Run the prenorm blocks in sequence.
        for layer in self.layers:
            h = layer(h)

        # Final norm before unembedding.
        h = self.ln_final(h)

        # Unembed. muP is implemented via init/LR scaling of lm_head, not
        # output scaling — except the tied path, which scales logits by
        # 1/width_ratio.
        if self.weight_tying:
            return h.matmul(self.lm_head.t()) / self.width_ratio
        return self.lm_head(h)
397
+
398
+ """ Looped Language Models implemented with MuP """
399
+
400
class LoopedStack(nn.Module):
    """A stack of prenorm blocks intended to be applied repeatedly (looped).

    With mixture_of_experts=True the blocks are grouped MoE blocks and the
    forward pass returns (x, lb_total, lz_total), summing the auxiliary
    losses over the layers; otherwise it returns x alone.
    """

    def __init__(
            self,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            mixture_of_experts: bool = False,
            num_experts: Optional[int] = None,
            num_active: Optional[int] = None,
            device=None, dtype=None):
        super().__init__()
        self.mixture_of_experts = mixture_of_experts
        if mixture_of_experts:
            # Grouped MoE blocks (grouped_mm-based expert dispatch).
            blocks = [
                GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                       context_length, rope_theta, width_ratio, device, dtype)
                for _ in range(num_layers_in_stack)
            ]
        else:
            blocks = [
                PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                             width_ratio, device, dtype)
                for _ in range(num_layers_in_stack)
            ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x: Tensor):
        if not self.mixture_of_experts:
            for layer in self.layers:
                x = layer(x)
            return x
        # MoE path: accumulate load-balancing and router z-losses per layer.
        lb_total = 0
        lz_total = 0
        for layer in self.layers:
            x, lb, lz = layer(x)
            lb_total += lb
            lz_total += lz
        return x, lb_total, lz_total
442
+
443
class LoopedTransformer(nn.Module):
    """Looped transformer: one shared stack applied num_stacks times.

    The same stack parameters are reused on every loop iteration (weight
    sharing across depth). Head tying follows the same muP convention as
    MuTransformer.
    """

    def __init__(
            self,
            vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio

        if weight_tying:
            # Tied head: reuse the embedding parameter directly.
            self.lm_head = self.token_embeddings.weight
        else:
            std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio, std_base_lm_head, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        h = self.token_embeddings(x)

        # Apply the shared stack repeatedly.
        for _ in range(self.num_stacks):
            h = self.stack(h)

        h = self.ln_final(h)

        # Vocab projection: tied (scaled by 1/width_ratio per muP) or lm_head.
        if self.weight_tying:
            return h.matmul(self.lm_head.t()) / self.width_ratio
        return self.lm_head(h)
490
+
491
+ """ Mixture-of-Experts Implementation in muP """
492
+
493
# Router Class
class Router(nn.Module):
    """Top-k softmax router over experts, backed by a single muP Linear gate."""

    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # The gate is a plain linear projection onto expert logits,
        # initialized with the muP base-model std.
        std_base = math.sqrt(2/(BASE_D_MODEL+num_experts))
        self.gate = Linear(d_model, num_experts, width_ratio, std_base, device=device, dtype=dtype)  # adjusted for muP
        self.num_active = num_active

    def forward(self, x: Tensor):
        """Return (logits, probs, top_scores, top_experts).

        top_scores are renormalized so the selected experts' weights sum to 1
        per token, allowing a convex mixture of expert outputs.
        """
        logits = self.gate(x)  # (batch, seq, num_experts)
        probs = softmax(logits, dim=-1)

        # Select the num_active highest-probability experts per token.
        top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1)

        # Renormalize so the mixture weights sum to 1.
        top_scores = top_scores / torch.sum(top_scores, dim=-1, keepdim=True)

        return logits, probs, top_scores, top_experts
517
+
518
class MoEPrenormBlock(nn.Module):
    """Pre-norm transformer block whose FFN is a top-k mixture of experts.

    Structure: x + MHSA(RMSNorm(x)), then x + MoE(RMSNorm(x)), where the MoE
    dispatches each token to ``num_active`` of ``num_experts`` SwiGLU FFNs
    chosen by a softmax Router. Besides the block output, forward() returns
    the batch's load-balancing loss and router z-loss.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before mHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # norm layer before position-wise feedfoward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)

        # router (top-num_active softmax gate over num_experts)
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active

        # initialize MoE FFNs as a module list; each expert gets d_ff // num_active
        # hidden units so the active parameter count matches a dense d_ff FFN
        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)]) # adjusted for muP

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, float, Tensor]:
        """Run the block.

        Args:
            x: activations of shape (batch, seq, d_model).
            token_positions: optional RoPE positions forwarded to attention.

        Returns:
            (output, lb, lz): output has x's shape; lb is the load-balancing
            loss num_experts * sum_i(f_i * p_i); lz is the router z-loss
            (mean squared logsumexp of the gate logits).
        """
        # input dims
        batch, seq, dim = x.shape

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        # we may have to define token_positions if it is not given
        attn_out = self.attn(norm1_out, token_positions)

        # ensure no broadcasting, elementwise addition on [batch seq d_model]
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)

        # get scores from Router. top_scores/top_experts have shape (batch,seq,k)
        logits, probs, top_scores, top_experts = self.router(norm2_out) # logits and probs are (batch, seq, n_routers)
        expert_mean_probs = torch.mean(probs, dim=(0, 1)) # take mean across batch and seq dims; p_i term of the lb loss

        # apply mixture of experts
        experts_out = torch.zeros_like(norm2_out) # copies shape, device and dtype
        total_tokens_assigned = batch*seq*self.num_active
        lb_sum = 0

        # one dense pass per expert over the tokens routed to it
        for expert_idx in range(self.num_experts):
            # get masks for expert selection
            expert_mask = (top_experts == expert_idx)
            embed_mask = expert_mask.any(dim=-1) # if any of the k is expert, we want to transform embed
            if not embed_mask.any(): continue
            pi = expert_mean_probs[expert_idx].item()
            fi = (expert_mask.sum().item())/total_tokens_assigned # num embeds assigned to expert in batch
            lb_sum += fi*pi

            # extract embeds and weights for activated experts
            # (topk indices are distinct per token, so expert_mask selects at
            # most one slot per token and weights aligns 1:1 with expert_embeds)
            weights = top_scores[expert_mask] # (num_embeds)
            expert_embeds = norm2_out[embed_mask] # (num_embeds, hidden_dim)

            # forward for the correct experts
            expert_out = self.experts[expert_idx](expert_embeds) # Vanilla Implementation

            # map back to experts output, scaled by the renormalized router score
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out # broadcast elementwise multiply by hidden dim

        # calculate batch's load balancing loss
        lb = self.num_experts*lb_sum

        # calculate batch's router z loss (in float32 for stability)
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)

        # ensure no broadcasting, elementwise addition
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
597
+
598
+
599
class GroupedMoEPrenormBlock(nn.Module):
    """MoE pre-norm block using grouped matrix multiplies for expert dispatch.

    Functionally equivalent to MoEPrenormBlock, but the per-expert SwiGLU
    weights are stored as stacked (num_experts, in, out) parameters and all
    experts are evaluated in one grouped_mm call over tokens sorted by expert,
    instead of a Python loop over nn.Module experts.
    """

    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        # Stacked per-expert weights; muP init identical to Linear's
        # (std_base / sqrt(width_ratio), truncated at +/- 3 std).
        w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype) # (num_experts, in, out)
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
        return nn.Parameter(w)

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before mHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)

        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # norm layer before position-wise feedfoward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)

        # router (top-num_active softmax gate over num_experts)
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active

        # each expert gets d_ff // num_active hidden units so the active
        # parameter count matches a dense d_ff FFN
        d_ff_expert = d_ff // num_active

        # expose and stack the MoE SwiGLU weights for all experts. with experts stacked, optimizer scales weights by width_ratio
        w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, Tensor, Tensor]:
        """Run the block.

        Args:
            x: activations of shape (batch, seq, d_model).
            token_positions: optional RoPE positions forwarded to attention.

        Returns:
            (output, lb, lz): output has x's shape; lb is the scalar
            load-balancing loss tensor; lz is the scalar router z-loss tensor.
        """
        batch, seq, dim = x.shape
        total_tokens = batch * seq

        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)

        # ensure no broadcasting in the residual add
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)

        # get scores from Router. top_scores/top_experts have shape (batch, seq, k)
        logits, probs, top_scores, top_experts = self.router(norm2_out)

        # flatten to 2D for grouped_mm; one (token, expert) assignment per row
        x_flat = rearrange(norm2_out, 'b s d -> (b s) d')                      # (total_tokens, d_model)
        flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)')           # (total_tokens * k,)
        flat_scores = rearrange(top_scores, 'b s k -> (b s k)')                # (total_tokens * k,)
        flat_positions = torch.arange(total_tokens, device=x.device)           # (total_tokens)
        # each token index repeated k times, aligned with its k expert slots
        flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active) # (total_tokens * k)

        # sort assignments by expert so each expert's tokens are contiguous
        # (stable sort keeps token order within an expert deterministic)
        sort_indices = flat_expert_ids.argsort(stable=True)
        sorted_expert_ids = flat_expert_ids[sort_indices]
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_scores = flat_scores[sort_indices]
        sorted_x = x_flat[sorted_token_ids]                                    # (total_tokens * k, d_model)

        # build offs (cumulative token counts per expert)
        # NOTE(review): assumes grouped_mm takes int32 cumulative group end
        # offsets — confirm against the installed torch version's signature
        counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32)                                # (num_experts,)

        # grouped SwiGLU: W2(SiLU(W1 x) dot W3 x), one group per expert
        h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        gated = silu(h1) * h3
        expert_out = grouped_mm(gated, self.experts_w2, offs=offs)             # (total_tokens * k, d_model)

        # weight by router scores and scatter-add back (a token may receive
        # contributions from up to k experts)
        expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d')
        output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype)
        output_flat.index_add_(0, sorted_token_ids, expert_out)

        # reshape back to (batch, seq, d_model)
        experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq)

        # aux losses: load balancing (num_experts * sum_i f_i * p_i) ...
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')

        # ... and router z-loss (mean squared logsumexp, in float32)
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')

        # residual connection (no broadcasting)
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
696
+
697
+
698
+ # MoE Implementation
699
# MoE Implementation
class MoETransformer(nn.Module):
    """Decoder-only transformer whose feed-forward is a grouped-GEMM MoE.

    forward() returns ``(logits, lb_avg, lz_avg)`` where the last two are the
    per-layer-averaged load-balance and router z auxiliary losses.
    """

    def __init__(
        self, vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers: int,
        num_heads: int,
        d_ff: int,
        num_experts: int,
        num_active: int,
        rope_theta: float,
        width_ratio: float = 1.0,
        device=None, dtype=None):
        super().__init__()
        self.num_layers = num_layers
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        blocks = [
            GroupedMoEPrenormBlock(
                d_model, num_heads, d_ff, num_experts, num_active,
                context_length, rope_theta, width_ratio, device, dtype,
            )
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # Output projection is never tied to the embedding; its init std is
        # scaled from the base width so muP-style transfer holds.
        head_std = math.sqrt(2 / (BASE_D_MODEL + vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio,
                              std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Embed tokens, run all MoE blocks, and project to vocab logits."""
        hidden = self.token_embeddings(x)

        # Accumulate the two auxiliary losses across layers.
        balance_sum = 0
        zloss_sum = 0
        for block in self.layers:
            hidden, balance, zloss = block(hidden)
            balance_sum = balance_sum + balance
            zloss_sum = zloss_sum + zloss

        logits = self.lm_head(self.ln_final(hidden))

        # Report the average auxiliary loss per layer.
        return logits, balance_sum / self.num_layers, zloss_sum / self.num_layers
750
+
751
class LoopedMoETransformer(nn.Module):
    """MoE transformer that reuses one stack of layers ``num_stacks`` times.

    The same ``LoopedStack`` weights are applied repeatedly; auxiliary losses
    are averaged over the *effective* layer count
    (``num_stacks * num_layers_in_stack``).
    """

    def __init__(
        self, vocab_size: int,
        context_length: int,
        d_model: int,
        num_layers_in_stack: int,
        num_stacks: int,
        num_heads: int,
        d_ff: int,
        num_experts: int,
        num_active: int,
        rope_theta: float,
        width_ratio: float,
        device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        self.total_layers = num_stacks * num_layers_in_stack
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        # One weight-shared stack, configured for mixture-of-experts blocks.
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads,
                                 d_ff, rope_theta, width_ratio, mixture_of_experts=True,
                                 num_experts=num_experts, num_active=num_active,
                                 device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # Scaled, un-tied output head (muP-style init from the base width).
        head_std = math.sqrt(2 / (BASE_D_MODEL + vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio,
                              std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Embed, loop the shared stack, and project to vocab logits."""
        hidden = self.token_embeddings(x)

        # Run the shared stack stack_depth times, accumulating aux losses.
        balance_sum = 0
        zloss_sum = 0
        for _ in range(self.stack_depth):
            hidden, balance, zloss = self.stack(hidden)
            balance_sum = balance_sum + balance
            zloss_sum = zloss_sum + zloss

        logits = self.lm_head(self.ln_final(hidden))

        # Average aux losses over the unrolled layer count.
        return logits, balance_sum / self.total_layers, zloss_sum / self.total_layers
805
+
806
+
807
+ # ---------------------------------------------------------------------------
808
+ # HuggingFace wrapper (from hf_wrapper.py)
809
+ # ---------------------------------------------------------------------------
810
+
811
class LoopLMConfig(PretrainedConfig):
    """Config for all four loop-lm model variants.

    ``model_variant`` selects which architecture ``LoopLMForCausalLM`` builds;
    the remaining hyperparameters are grouped below by which variants consume
    them. Unused fields are simply ignored by the other variants.
    """

    model_type = "loop-lm"

    def __init__(
        self,
        # which of the four architectures to use
        model_variant: str = "base",  # "base" | "looped" | "moe" | "looped-moe"
        # shared
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,  # d_model / base_d_model (128); set at training time
        # base + moe only
        num_layers: int = 16,
        # base + looped only
        weight_tying: bool = False,
        # looped + looped-moe only
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # moe + looped-moe only
        num_experts: int = 8,
        num_active: int = 2,
        # aux loss weights — used when forward() is called with labels
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        # kwargs lets PretrainedConfig handle standard HF fields
        # (architectures, dtype, auto_map, ...).
        super().__init__(**kwargs)
        self.model_variant = model_variant
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.rope_theta = rope_theta
        self.width_ratio = width_ratio
        self.num_layers = num_layers
        self.weight_tying = weight_tying
        self.num_layers_in_stack = num_layers_in_stack
        self.num_stacks = num_stacks
        self.num_experts = num_experts
        self.num_active = num_active
        self.lb_loss_factor = lb_loss_factor
        self.lz_loss_factor = lz_loss_factor
        # lm-evaluation-harness looks for this attribute to cap sequence length
        self.max_length = context_length
862
+
863
+
864
class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.

    Implements the HuggingFace PreTrainedModel interface so you can:
      - Upload/download via push_to_hub / from_pretrained
      - Run lm-evaluation-harness evals
      - Fine-tune with TRL's SFTTrainer / DPOTrainer
    """

    config_class = LoopLMConfig
    # No parameter names need to be ignored when loading checkpoints.
    _keys_to_ignore_on_load_missing = []

    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        self.model = self._build_inner_model(config)
        self.post_init()

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------

    def _build_inner_model(self, config: LoopLMConfig):
        """Instantiate the inner nn.Module selected by ``config.model_variant``."""
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
            # device=None so weights are placed on CPU; caller uses .to(device)
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")

    # ------------------------------------------------------------------
    # Embedding access (required by some HF utilities)
    # ------------------------------------------------------------------

    def get_input_embeddings(self):
        return self.model.token_embeddings

    def set_input_embeddings(self, value):
        self.model.token_embeddings = value

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # causal mask is handled internally
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored — models use a built-in causal mask
            labels: (batch, seq) token ids aligned with input_ids (HF
                convention: pass labels == input_ids; the one-position shift
                for next-token prediction is applied here). If provided,
                returns cross-entropy loss. For MoE variants, aux losses
                (lb + lz) are added to the CE loss while training.
        """
        is_moe = self.config.model_variant in ("moe", "looped-moe")

        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0

        loss = None
        if labels is not None:
            # Causal-LM shift: position t predicts token t+1. Without this
            # the loss compared logits[t] to labels[t], i.e. trained the
            # model to copy its input — and broke HF trainers (TRL SFT/DPO)
            # which pass labels == input_ids and expect the model to shift.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
            if self.training:
                # Router auxiliary losses only regularize during training.
                loss = loss + self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    # ------------------------------------------------------------------
    # Generation support (no KV cache — generation is correct but slow)
    # ------------------------------------------------------------------

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Full-sequence recompute every step; only input_ids are needed.
        return {"input_ids": input_ids}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "model_max_length": 1024,
9
+ "pad_token": null,
10
+ "tokenizer_class": "GPT2Tokenizer",
11
+ "unk_token": "<|endoftext|>"
12
+ }