Text Generation

gustavlangstroem committed (verified) · Commit 9987dd2 · Parent(s): 82c5207

Upload 4 files

Files changed (3):
  1. gutenberg_tokenizer.json (+0 -0)
  2. microexpert.py (+2024 -0)
  3. tokenizer.py (+57 -0)

gutenberg_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff.
 
microexpert.py ADDED
@@ -0,0 +1,2024 @@
"""
MicroExperts — Self-organizing dynamic Mixture-of-Experts for continual learning.

Target hardware: Apple M4 with 48 GB unified memory.
"""

import time
import math
import uuid
import json
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
import os
import glob
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from collections import defaultdict


def one_hot(indices: mx.array, num_classes: int) -> mx.array:
    # Build a range vector [0, 1, ..., num_classes-1] and compare with indices
    flat = indices.reshape(-1)                                  # (K,)
    arange = mx.arange(num_classes)                             # (num_classes,)
    oh = (flat[:, None] == arange[None, :]).astype(mx.float32)  # (K, num_classes)
    return oh.reshape(*indices.shape, num_classes)
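
# Sketch (illustrative, not from the original file): expected output for a
# concrete input, as a quick sanity check.
#   one_hot(mx.array([0, 2]), num_classes=3)
#   -> [[1.0, 0.0, 0.0],
#       [0.0, 0.0, 1.0]]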


# ==========================================
# 1. CONFIGURATION
# ==========================================
@dataclass
class ModelArgs:
    dim: int = 768
    n_layers: int = 12
    n_heads: int = 12
    n_kv_heads: int = 12
    vocab_size: int = -1
    norm_eps: float = 1e-8
    max_seq_len: int = 2048
    rope_theta: float = 10000.0


@dataclass
class MicroExpertConfig:
    """All hyperparameters for the MicroExperts MoE system."""
    # tier_hidden_dims: Tuple[int, ...] = (512, 1024, 2048, 4096)
    tier_hidden_dims: Tuple[int, ...] = (256, 512, 1024, 2048)

    monolith_split_enabled: bool = True
    monolith_variance_ema_alpha: float = 0.02
    monolith_variance_z_threshold: float = 1.5

    # Router
    router_embed_dim: int = 128
    min_experts_per_token: int = 1
    max_experts_per_token: int = 64

    # Cannibalization / lifecycle
    ema_fast_alpha: float = 0.05
    ema_slow_alpha: float = 0.005
    split_threshold: float = 2.0
    # Relaxed merge thresholds so merges actually fire
    merge_co_route_threshold: float = 0.5
    merge_weakness_threshold: float = 0.05
    death_threshold: float = 0.001
    min_expert_age: int = 50
    cooldown_steps: int = 100
    # Base freeze duration — actual duration scaled by importance
    preserver_base_freeze_steps: int = 100
    preserver_max_freeze_steps: int = 200
    adapter_noise_scale: float = 0.02

    max_experts_per_layer: int = 12
    max_params_per_layer: int = 20_000_000  # 20 M

    # Initial state
    init_tier: int = 2

    # Interference
    interference_subsample: int = 64

    # Load balance loss
    load_balance_weight: float = 0.01

    # Capacity-pressure merge: trigger when pool exceeds this fraction of budget
    merge_capacity_pressure_frac: float = 0.8
    # Tier-gravity merge: same-tier co-activation threshold (lower than fragment)
    merge_tier_gravity_co_route: float = 0.4
    merge_tier_gravity_min_co_activation: float = 0.3  # both activated > 30 % of tokens

    density_ema_alpha: float = 0.02
    density_spike_z: float = 2.5  # z-score above mean to flag distribution shift


@dataclass
class TrainConfig:
    """Training hyperparameters."""
    mode: str = "pretrain"
    batch_size: int = 8
    learning_rate: float = 3e-4
    max_steps: int = 30_000
    tokenizer_file: str = "gutenberg_tokenizer.json"
    checkpoint_dir: str = "checkpoints_me"
    log_every: int = 10
    summary_every: int = 500
    checkpoint_every: int = 1000
    lifecycle_every: int = 10

    # Active learning
    al_data_dir: str = "./domains"
    al_steps_per_domain: int = 2000
    al_learning_rate: float = 1e-4
    al_lifecycle_every: int = 5
    al_split_threshold: float = 1.5
    al_min_expert_age: int = 100


# ==========================================
# 2. EXPERT MODULE
# ==========================================
class Expert(nn.Module):
    """Single MicroExpert: SwiGLU FFN."""

    def __init__(self, model_dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)

    def __call__(self, x):
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
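
# Sketch (illustrative, not from the original file): an Expert is three
# bias-free Linear layers, so its parameter count is 3 * model_dim *
# hidden_dim, the same formula used by _expert_param_count below.
#   e = Expert(768, 1024)
#   y = e(mx.zeros((2, 16, 768)))  # y.shape == (2, 16, 768)
#   # params: 3 * 768 * 1024 = 2_359_296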


# ==========================================
# 3. EXPERT METADATA
# ==========================================
@dataclass
class ExpertMeta:
    """Non-parameter state for one expert."""
    expert_id: str
    tier: int
    hidden_dim: int
    age: int = 0
    cooldown: int = 0
    frozen_steps: int = 0
    ema_interference_fast: float = 0.0
    ema_interference_slow: float = 0.0
    ema_interference_var: float = 1.0
    avg_routing_weight: float = 0.1
    avg_activation_freq: float = 0.1
    parent_id: Optional[str] = None
    generation: int = 0

    def to_dict(self) -> dict:
        return {
            "expert_id": self.expert_id, "tier": self.tier,
            "hidden_dim": self.hidden_dim, "age": self.age,
            "cooldown": self.cooldown, "frozen_steps": self.frozen_steps,
            "ema_fast": self.ema_interference_fast,
            "ema_slow": self.ema_interference_slow,
            "ema_var": self.ema_interference_var,
            "avg_rw": self.avg_routing_weight,
            "avg_af": self.avg_activation_freq,
            "parent_id": self.parent_id, "generation": self.generation,
        }


# ==========================================
# 4. EXPERT EMBEDDING (trainable nn.Module)
# ==========================================
class ExpertEmbedding(nn.Module):

    def __init__(self, dim: int, init: Optional[mx.array] = None):
        super().__init__()
        if init is not None:
            self.embedding = init
        else:
            scale = 1.0 / math.sqrt(dim)
            self.embedding = mx.random.normal((dim,)) * scale


# ==========================================
# 5. ADAPTIVE ROUTER
# ==========================================
class AdaptiveRouter(nn.Module):

    def __init__(self, model_dim: int, config: MicroExpertConfig):
        super().__init__()
        self.config = config
        self.d = config.router_embed_dim
        self.proj = nn.Linear(model_dim, self.d, bias=False)
        self.threshold_head = nn.Linear(model_dim, 1, bias=True)

        # Trainable embeddings — list of nn.Module (MLX discovers these)
        self.embeddings: List[ExpertEmbedding] = []
        # Parallel ID list (same order)
        self._emb_ids: List[str] = []

    def _id_to_idx(self, eid: str) -> int:
        return self._emb_ids.index(eid)

    def add_expert(self, expert_id: str, init_embedding: Optional[mx.array] = None):
        emb = ExpertEmbedding(self.d, init=init_embedding)
        mx.eval(emb.parameters())
        self.embeddings.append(emb)
        self._emb_ids.append(expert_id)

    def remove_expert(self, expert_id: str):
        if expert_id not in self._emb_ids:
            return
        idx = self._id_to_idx(expert_id)
        self.embeddings.pop(idx)
        self._emb_ids.pop(idx)

    def get_embedding(self, expert_id: str) -> mx.array:
        return self.embeddings[self._id_to_idx(expert_id)].embedding

    def set_embedding(self, expert_id: str, emb: mx.array):
        self.embeddings[self._id_to_idx(expert_id)].embedding = emb

    def __call__(self, x: mx.array, expert_ids: List[str]):
        """
        Returns:
            routing_weights: (B, L, N) sparse softmax-normalized
            raw_scores: (B, L, N) cosine similarities
            density: (B, L) active expert count per token
        """
        B, L, D = x.shape
        N = len(expert_ids)

        if N == 0:
            z = mx.zeros((B, L, 1))
            return z[:, :, :0], z[:, :, :0], mx.zeros((B, L))

        # Project input to routing space and normalize
        h = self.proj(x)  # (B, L, d)
        h_norm = h / (mx.linalg.norm(h, axis=-1, keepdims=True) + 1e-8)

        # Stack expert embeddings into matrix
        E = mx.stack([self.embeddings[self._emb_ids.index(eid)].embedding
                      for eid in expert_ids], axis=0)  # (N, d)
        E_norm = E / (mx.linalg.norm(E, axis=-1, keepdims=True) + 1e-8)

        raw_scores = h_norm @ E_norm.T  # (B, L, N)

        # Adaptive per-token threshold
        threshold = mx.sigmoid(self.threshold_head(x))  # (B, L, 1)
        gate_mask = (raw_scores > threshold).astype(mx.float32)

        # Guarantee top-1 always active
        best_idx = mx.argmax(raw_scores, axis=-1)  # (B, L)
        best_oh = one_hot(best_idx, N)             # (B, L, N)
        gate_mask = mx.maximum(gate_mask, best_oh)

        # Cap maximum active experts
        max_k = self.config.max_experts_per_token
        if max_k < N:
            sorted_idx = mx.argsort(-raw_scores, axis=-1)
            rank = mx.argsort(sorted_idx, axis=-1)
            gate_mask = gate_mask * (rank < max_k).astype(mx.float32)

        # Softmax over active experts
        masked = raw_scores * gate_mask + (1.0 - gate_mask) * (-1e9)
        routing_weights = mx.softmax(masked, axis=-1) * gate_mask

        density = gate_mask.sum(axis=-1)
        return routing_weights, raw_scores, density
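
# Shape sketch (illustrative, not from the original file): for x of shape
# (B, L, D) and N experts, the router returns routing_weights (B, L, N) that
# sum to ~1 over each token's active set, raw cosine scores (B, L, N) in
# [-1, 1], and density (B, L) counting how many experts fired per token.
#   router = AdaptiveRouter(768, MicroExpertConfig())
#   router.add_expert("e0"); router.add_expert("e1")
#   w, s, d = router(mx.zeros((2, 4, 768)), ["e0", "e1"])
#   # w.shape == s.shape == (2, 4, 2); d.shape == (2, 4)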


# ==========================================
# 6. UTILITY: zero a nested grad tree
# ==========================================
def _zero_tree(tree):
    """Recursively zero all mx.arrays in a nested structure."""
    if isinstance(tree, mx.array):
        return mx.zeros_like(tree)
    elif isinstance(tree, dict):
        return {k: _zero_tree(v) for k, v in tree.items()}
    elif isinstance(tree, list):
        return [_zero_tree(v) for v in tree]
    return tree


# ==========================================
# 7. MoE LAYER
# ==========================================
class MicroExpertsMoELayer(nn.Module):

    def __init__(self, model_dim: int, config: MicroExpertConfig, layer_idx: int):
        super().__init__()
        self.model_dim = model_dim
        self.config = config
        self.layer_idx = layer_idx
        self.router = AdaptiveRouter(model_dim, config)
        self._variance_ema: Dict[str, float] = {}
        self._variance_ema_sq: Dict[str, float] = {}

        # Expert modules — list for MLX parameter discovery
        self.expert_modules: List[Expert] = []
        self._expert_id_list: List[str] = []
        self._expert_meta: Dict[str, ExpertMeta] = {}
        self._lifecycle_log: List[str] = []
        self.global_step: int = 0

        # Cached from forward pass (detached)
        self._last_routing_weights: Optional[mx.array] = None
        self._last_density: Optional[mx.array] = None
        self._last_input: Optional[mx.array] = None
        # FIX: Cache expert outputs to avoid redundant forward in interference
        self._last_expert_outputs: Optional[List[mx.array]] = None

        # Frozen expert tracking
        self._frozen_eids: set = set()

        # FIX: Density drift tracking
        self._density_ema: float = 1.0
        self._density_var: float = 1.0
        self._drift_detected: bool = False

        # Create initial monolith
        self._create_expert(tier=config.init_tier)

    # --- Helpers ---
    @property
    def expert_ids(self) -> List[str]:
        return list(self._expert_id_list)

    def _eid_to_index(self, eid: str) -> int:
        return self._expert_id_list.index(eid)

    def _get_expert(self, eid: str) -> Expert:
        return self.expert_modules[self._eid_to_index(eid)]

    def _tier_to_hidden(self, tier: int) -> int:
        t = min(tier, len(self.config.tier_hidden_dims) - 1)
        return self.config.tier_hidden_dims[t]

    def _expert_param_count(self, tier: int) -> int:
        return 3 * self.model_dim * self._tier_to_hidden(tier)

    def _total_params(self) -> int:
        return sum(self._expert_param_count(m.tier) for m in self._expert_meta.values())

    def _make_id(self) -> str:
        return uuid.uuid4().hex[:12]
+ """
358
+ def _copy_optimizer_state(self, optimizer, parent_idx: int, child_eid: str):
359
+ try:
360
+ layers_state = optimizer.state.get("layers", [])
361
+ if self.layer_idx >= len(layers_state):
362
+ return
363
+ moe_state = layers_state[self.layer_idx].get("moe", {})
364
+ expert_states = moe_state.get("expert_modules", [])
365
+ if parent_idx >= len(expert_states):
366
+ return
367
+
368
+ parent_state = expert_states[parent_idx]
369
+ child_idx = self._eid_to_index(child_eid)
370
+
371
+ # Grow the list if needed
372
+ while len(expert_states) <= child_idx:
373
+ expert_states.append({})
374
+
375
+ # Deep copy the parent state
376
+ import copy
377
+ expert_states[child_idx] = copy.deepcopy(parent_state)
378
+ except (KeyError, IndexError, TypeError):
379
+ pass
380
+ """
381
    def _copy_optimizer_state(self, optimizer, parent_idx: int, children_eids: list):
        """Copy parent's optimizer state to children, then rebuild list."""
        try:
            layers_state = optimizer.state.get("layers", [])
            if self.layer_idx >= len(layers_state):
                return
            moe_state = layers_state[self.layer_idx].get("moe", {})
            expert_states = moe_state.get("expert_modules", [])
            if parent_idx >= len(expert_states):
                return

            import copy
            parent_state = copy.deepcopy(expert_states[parent_idx])

            # Build new list matching current expert_modules order
            new_states = []
            for i, eid in enumerate(self._expert_id_list):
                if eid in children_eids:
                    new_states.append(copy.deepcopy(parent_state))
                elif i < len(expert_states):
                    new_states.append(expert_states[i])
                else:
                    new_states.append({})

            moe_state["expert_modules"] = new_states
        except (KeyError, IndexError, TypeError):
            pass

    # --- Expert creation / removal ---
    def _create_expert(
        self, tier: int,
        parent_id: Optional[str] = None,
        init_weights_from: Optional[Expert] = None,
        noise_scale: float = 0.0,
        frozen_steps: int = 0,
        init_embedding: Optional[mx.array] = None,
    ) -> str:
        eid = self._make_id()
        hidden = self._tier_to_hidden(tier)
        expert = Expert(self.model_dim, hidden)

        if init_weights_from is not None:
            src = dict(tree_flatten(init_weights_from.parameters()))
            dst = dict(tree_flatten(expert.parameters()))
            pairs = []
            for k in dst:
                if k in src and src[k].shape == dst[k].shape:
                    w = src[k]
                    if noise_scale > 0:
                        w = w + mx.random.normal(w.shape) * noise_scale * (mx.abs(w).mean() + 1e-8)
                    pairs.append((k, w))
            if pairs:
                expert.load_weights(pairs)

        mx.eval(expert.parameters())

        self.expert_modules.append(expert)
        self._expert_id_list.append(eid)

        gen = 0
        if parent_id and parent_id in self._expert_meta:
            gen = self._expert_meta[parent_id].generation + 1

        self._expert_meta[eid] = ExpertMeta(
            expert_id=eid, tier=tier, hidden_dim=hidden,
            frozen_steps=frozen_steps, parent_id=parent_id, generation=gen,
        )
        if frozen_steps > 0:
            self._frozen_eids.add(eid)

        self.router.add_expert(eid, init_embedding=init_embedding)
        return eid

    def _remove_expert(self, eid: str):
        if eid not in self._expert_id_list:
            return
        idx = self._eid_to_index(eid)
        self.expert_modules.pop(idx)
        self._expert_id_list.pop(idx)
        self._expert_meta.pop(eid, None)
        self._frozen_eids.discard(eid)
        self.router.remove_expert(eid)

    # --- Forward ---
    def __call__(self, x: mx.array) -> mx.array:
        B, L, D = x.shape
        N = len(self._expert_id_list)
        if N == 0:
            return mx.zeros_like(x)

        routing_weights, raw_scores, density = self.router(x, self._expert_id_list)

        # Compute and cache individual expert outputs
        expert_outputs = [self.expert_modules[i](x) for i in range(N)]

        output = mx.zeros_like(x)
        for i in range(N):
            w_i = routing_weights[:, :, i:i + 1]
            output = output + w_i * expert_outputs[i]

        # Cache detached copies for interference computation
        self._last_routing_weights = mx.stop_gradient(routing_weights)
        self._last_density = mx.stop_gradient(density)
        self._last_input = mx.stop_gradient(x)
        self._last_expert_outputs = [mx.stop_gradient(eo) for eo in expert_outputs]

        return output

    # --- Load balance loss ---
    def load_balance_loss(self) -> mx.array:
        """
        Variance of per-expert activation frequency across the last batch.
        Penalizes uneven usage — prevents expert starvation without forcing
        uniform routing (which would defeat specialization).
        """
        if self._last_routing_weights is None:
            return mx.array(0.0)

        N = self._last_routing_weights.shape[-1]
        if N <= 1:
            return mx.array(0.0)

        # Per-expert fraction of tokens where it's active (weight > 0.01)
        active = (self._last_routing_weights > 0.01).astype(mx.float32)
        freq = active.reshape(-1, N).mean(axis=0)

        return freq.var()
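
    # Worked example (illustrative): with two experts active on 90 % and 10 %
    # of tokens, freq = [0.9, 0.1] and the penalty is the population variance
    # ((0.4)**2 + (0.4)**2) / 2 = 0.16; perfectly even usage gives 0.0.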

    # --- Frozen gradient zeroing ---
    def zero_frozen_grads(self, expert_grads: Any) -> Any:
        """Zero gradients for the expert_modules subtree of frozen experts."""
        if not self._frozen_eids or not isinstance(expert_grads, list):
            return expert_grads
        result = []
        for i, g in enumerate(expert_grads):
            eid = self._expert_id_list[i] if i < len(self._expert_id_list) else None
            if eid and eid in self._frozen_eids:
                result.append(_zero_tree(g))
            else:
                result.append(g)
        return result

    def dr(self):
        """Update density EMA and detect distribution shift spikes."""
        if self._last_density is None:
            return
        cfg = self.config
        current = self._last_density.mean().item()
        alpha = cfg.density_ema_alpha

        # Update EMA of density
        old_ema = self._density_ema
        self._density_ema = (1 - alpha) * self._density_ema + alpha * current
        diff = current - old_ema
        self._density_var = (1 - alpha) * self._density_var + alpha * diff * diff

        # Z-score spike detection
        std = math.sqrt(max(self._density_var, 1e-8))
        z = (current - self._density_ema) / std
        self._drift_detected = z > cfg.density_spike_z

        if self._drift_detected:
            msg = (f"[step {self.global_step}][L{self.layer_idx}] "
                   f"DRIFT density={current:.1f} ema={self._density_ema:.1f} z={z:.1f}")
            self._lifecycle_log.append(msg)
            print(msg)

    def compute_interference(self) -> Dict[str, float]:
        if (self._last_routing_weights is None or self._last_input is None
                or self._last_expert_outputs is None):
            return {}

        x = self._last_input
        rw = self._last_routing_weights
        B, L, D = x.shape
        N = len(self._expert_id_list)
        if N == 0:
            return {}

        T = min(self.config.interference_subsample, B * L)
        rw_flat = rw.reshape(-1, N)[:T]

        # Use cached expert outputs instead of re-running forward passes
        expert_outs_flat = [eo.reshape(-1, D)[:T] for eo in self._last_expert_outputs]

        # Combined mixture output on subsample
        combined = mx.zeros((T, D))
        for i in range(N):
            combined = combined + rw_flat[:, i:i + 1] * expert_outs_flat[i]
        combined = mx.stop_gradient(combined)

        interference = {}
        for i in range(N):
            eid = self._expert_id_list[i]
            w_i = rw_flat[:, i]
            e_out = expert_outs_flat[i]
            active = (w_i > 0.01).astype(mx.float32)
            n_active = active.sum().item()
            if n_active < 1.0:
                interference[eid] = 0.0
                continue
            diff_norm = mx.linalg.norm(combined - e_out, axis=-1)
            e_norm = mx.linalg.norm(e_out, axis=-1) + 1e-8
            relative = diff_norm / e_norm
            score = (relative * w_i * active).sum() / (n_active + 1e-8)
            interference[eid] = score.item()

        mx.eval(list(interference.values()))
        return interference
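
    # In formula form: for token t with mixture output y_t and expert output
    # e_{t,i}, the score for expert i is
    #   interference_i = (1 / n_active) * sum_t w_{t,i} * [w_{t,i} > 0.01]
    #                    * ||y_t - e_{t,i}|| / ||e_{t,i}||
    # i.e. how far the blended output is pulled away from what this expert
    # alone would produce, weighted by how much the router trusts it.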

    def _compute_monolith_split_scores(self) -> Dict[str, float]:
        scores = {}
        if self._last_expert_outputs is None or not self.config.monolith_split_enabled:
            return scores
        cfg = self.config
        for i, eid in enumerate(self._expert_id_list):
            if i >= len(self._last_expert_outputs):
                continue
            eo = self._last_expert_outputs[i]
            norms = mx.linalg.norm(eo.reshape(-1, eo.shape[-1]), axis=-1)
            var = norms.var().item()
            alpha = cfg.monolith_variance_ema_alpha
            prev_mean = self._variance_ema.get(eid, var)
            prev_sq = self._variance_ema_sq.get(eid, var * var)
            new_mean = (1 - alpha) * prev_mean + alpha * var
            new_sq = (1 - alpha) * prev_sq + alpha * var * var
            self._variance_ema[eid] = new_mean
            self._variance_ema_sq[eid] = new_sq
            running_std = math.sqrt(max(new_sq - new_mean * new_mean, 1e-8))
            z = (var - new_mean) / running_std
            scores[eid] = z
        return scores

    # --- Lifecycle ---
    def lifecycle_step(self, optimizer=None):
        self.dr()

        interference = self.compute_interference()
        events = []
        all_ids = list(self._expert_id_list)  # snapshot before mutations

        monolith_scores = self._compute_monolith_split_scores()
        N = len(all_ids)

        for eid in all_ids:
            meta = self._expert_meta.get(eid)
            if meta is None:
                continue
            meta.age += 1
            if meta.cooldown > 0:
                meta.cooldown -= 1
            if meta.frozen_steps > 0:
                meta.frozen_steps -= 1
                if meta.frozen_steps == 0:
                    self._frozen_eids.discard(eid)

            # Routing stats from cached data
            if self._last_routing_weights is not None and eid in self._expert_id_list:
                idx = self._eid_to_index(eid)
                if idx < self._last_routing_weights.shape[-1]:
                    w = self._last_routing_weights[:, :, idx]
                    meta.avg_routing_weight = (
                        0.95 * meta.avg_routing_weight + 0.05 * w.mean().item()
                    )
                    meta.avg_activation_freq = (
                        0.95 * meta.avg_activation_freq
                        + 0.05 * (w > 0.01).astype(mx.float32).mean().item()
                    )

            # Interference EMAs
            intf = interference.get(eid, 0.0)
            af = self.config.ema_fast_alpha
            asl = self.config.ema_slow_alpha
            meta.ema_interference_fast = (1 - af) * meta.ema_interference_fast + af * intf
            meta.ema_interference_slow = (1 - asl) * meta.ema_interference_slow + asl * intf
            diff = intf - meta.ema_interference_slow
            meta.ema_interference_var = 0.99 * meta.ema_interference_var + 0.01 * diff * diff

        # Score by cannibalization z-score
        scored = []
        for eid in all_ids:
            meta = self._expert_meta.get(eid)
            if meta is None or eid not in self._expert_id_list:
                continue
            std = math.sqrt(max(meta.ema_interference_var, 1e-8))
            intf_z = (meta.ema_interference_fast - meta.ema_interference_slow) / std
            mono_z = monolith_scores.get(eid, 0.0)
            if N <= 2:
                z = mono_z
            else:
                z = max(intf_z, mono_z)
            scored.append((eid, z, meta))
        scored.sort(key=lambda t: -t[1])

        # FIX: Lower split threshold during detected drift — system should react faster
        effective_split_threshold = self.config.split_threshold
        if self._drift_detected:
            effective_split_threshold *= 0.7  # 30 % more sensitive during drift

        # Split / Death
        touched = set()
        for eid, z_score, meta in scored:
            if eid in touched or eid not in self._expert_id_list:
                continue
            if meta.age < self.config.min_expert_age or meta.cooldown > 0:
                continue
            budget_usage = self._total_params() / self.config.max_params_per_layer
            if budget_usage > 0.7:
                continue

            threshold = self.config.monolith_variance_z_threshold if N <= 2 else effective_split_threshold
            if (z_score > threshold
                    and len(self._expert_id_list) < self.config.max_experts_per_layer
                    and (self._total_params() + self._expert_param_count(meta.tier)
                         < self.config.max_params_per_layer)):
                events.append(self._do_split(eid, optimizer=optimizer))
                touched.add(eid)
                continue

            if (meta.avg_routing_weight < self.config.death_threshold
                    and len(self._expert_id_list) > 1):
                events.append(self._do_death(eid, optimizer=optimizer))
                touched.add(eid)
                continue

        events.extend(self._check_merges(touched, optimizer=optimizer))

        for e in events:
            msg = f"[step {self.global_step}][L{self.layer_idx}] {e}"
            self._lifecycle_log.append(msg)
            print(msg)
        return events

    # --- Importance-proportional preserver freeze ---
    def _compute_freeze_steps(self, meta: ExpertMeta) -> int:
        cfg = self.config
        importance = max(0.0, min(1.0, meta.avg_routing_weight * 10.0))
        freeze = int(
            cfg.preserver_base_freeze_steps
            + importance * (cfg.preserver_max_freeze_steps - cfg.preserver_base_freeze_steps)
        )
        return freeze
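
    # Worked example (illustrative): with the defaults (base 100, max 200),
    # avg_routing_weight = 0.05 gives importance = min(1.0, 0.05 * 10) = 0.5,
    # so the preserver is frozen for 100 + 0.5 * (200 - 100) = 150 steps;
    # weights >= 0.1 saturate at the full 200-step freeze.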
    def _do_split(self, eid: str, optimizer=None) -> str:
        meta = self._expert_meta[eid]
        parent = self._get_expert(eid)
        parent_emb = self.router.get_embedding(eid)
        parent_idx = self._eid_to_index(eid)

        parent_opt_state = None
        parent_emb_opt_state = None
        if optimizer is not None:
            try:
                import copy
                layers_state = optimizer.state.get("layers", [])
                moe_state = layers_state[self.layer_idx].get("moe", {})
                expert_states = moe_state.get("expert_modules", [])
                if parent_idx < len(expert_states):
                    parent_opt_state = copy.deepcopy(expert_states[parent_idx])
                router_state = moe_state.get("router", {})
                emb_states = router_state.get("embeddings", [])
                if parent_idx < len(emb_states):
                    parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx])
            except (KeyError, IndexError, TypeError):
                pass

        freeze_steps = self._compute_freeze_steps(meta)

        preserver_id = self._create_expert(
            tier=meta.tier, parent_id=eid,
            init_weights_from=parent, noise_scale=0.0,
            frozen_steps=freeze_steps,
            init_embedding=parent_emb,
        )

        adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1
        mx.eval(adapter_emb)
        adapter_id = self._create_expert(
            tier=meta.tier, parent_id=eid,
            init_weights_from=parent,
            noise_scale=self.config.adapter_noise_scale,
            frozen_steps=0, init_embedding=adapter_emb,
        )

        self._remove_expert(eid)

        if optimizer is not None and parent_opt_state is not None:
            try:
                import copy
                layers_state = optimizer.state["layers"]
                moe_state = layers_state[self.layer_idx]["moe"]
                old_states = moe_state.get("expert_modules", [])

                new_states = []
                for i, expert_eid in enumerate(self._expert_id_list):
                    if expert_eid == preserver_id or expert_eid == adapter_id:
                        new_states.append(copy.deepcopy(parent_opt_state))
                    elif i < len(old_states):
                        new_states.append(old_states[i])
                    else:
                        new_states.append({})
                moe_state["expert_modules"] = new_states

                # Rebuild router embeddings state
                router_state = moe_state.get("router", {})
                old_emb_states = router_state.get("embeddings", [])
                new_emb_states = []
                for i, emb_eid in enumerate(self.router._emb_ids):
                    if emb_eid == preserver_id or emb_eid == adapter_id:
                        if parent_emb_opt_state is not None:
                            new_emb_states.append(copy.deepcopy(parent_emb_opt_state))
                        else:
                            new_emb_states.append({})
                    elif i < len(old_emb_states):
                        new_emb_states.append(old_emb_states[i])
                    else:
                        new_emb_states.append({})
                router_state["embeddings"] = new_emb_states
            except (KeyError, IndexError, TypeError):
                pass

        self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps
        self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps

        return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> "
                f"preserver {preserver_id[:8]} (frozen={freeze_steps}) "
                f"+ adapter {adapter_id[:8]}")

    def _do_death(self, eid: str, optimizer=None) -> str:
        meta = self._expert_meta[eid]
        info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})"
        self._remove_expert(eid)

        if optimizer is not None:
            try:
                layers_state = optimizer.state.get("layers", [])
                if self.layer_idx < len(layers_state):
                    moe_state = layers_state[self.layer_idx].get("moe", {})
                    old_states = moe_state.get("expert_modules", [])
                    new_states = []
                    for i, expert_eid in enumerate(self._expert_id_list):
                        if i < len(old_states):
                            new_states.append(old_states[i])
                        else:
                            new_states.append({})
                    moe_state["expert_modules"] = new_states

                    # Rebuild router embeddings state
                    router_state = moe_state.get("router", {})
                    old_emb_states = router_state.get("embeddings", [])
                    new_emb_states = []
                    for i in range(len(self.router._emb_ids)):
                        if i < len(old_emb_states):
                            new_emb_states.append(old_emb_states[i])
                        else:
                            new_emb_states.append({})
                    router_state["embeddings"] = new_emb_states
            except (KeyError, IndexError, TypeError):
                pass

        return info

    def _average_expert_weights(self, expert_a: Expert, expert_b: Expert) -> List[Tuple[str, mx.array]]:
        """Average the weights of two same-shape experts."""
        src_a = dict(tree_flatten(expert_a.parameters()))
        src_b = dict(tree_flatten(expert_b.parameters()))
        pairs = []
        for k in src_a:
            if k in src_b and src_a[k].shape == src_b[k].shape:
                pairs.append((k, (src_a[k] + src_b[k]) / 2.0))
        return pairs

    def _check_merges(self, touched: set, optimizer=None) -> List[str]:
        events = []
        merged = set()
        ids = list(self._expert_id_list)
        cfg = self.config

        # Pre-compute co-activation matrix from cached routing weights
        co_activation = {}
        if self._last_routing_weights is not None:
            N = self._last_routing_weights.shape[-1]
            active = (self._last_routing_weights > 0.01).astype(mx.float32)
            # (B*L, N) binary activation matrix
            act_flat = active.reshape(-1, N)
            # Per-expert activation freq
            act_freq = act_flat.mean(axis=0)  # (N,)
            mx.eval(act_freq)

        def _can_merge(eid):
            return (eid not in merged and eid not in touched
                    and eid in self._expert_id_list
                    and (meta := self._expert_meta.get(eid)) is not None
                    and meta.age >= cfg.min_expert_age
                    and meta.cooldown == 0)

        def _do_merge(eid_a, eid_b, meta_a, meta_b, reason: str, optimizer=None) -> Optional[str]:
            """Execute a merge and return event string, or None if budget exceeded."""
            new_tier = min(meta_a.tier + 1, len(cfg.tier_hidden_dims) - 1)
            cost = self._expert_param_count(new_tier)
            freed = (self._expert_param_count(meta_a.tier)
                     + self._expert_param_count(meta_b.tier))
            if self._total_params() - freed + cost > cfg.max_params_per_layer:
                return None

            emb_a = self.router.get_embedding(eid_a)
            emb_b = self.router.get_embedding(eid_b)
            avg_emb = (emb_a + emb_b) / 2.0
            mx.eval(avg_emb)

            if new_tier == meta_a.tier:
                merged_expert_id = self._create_expert(
                    tier=new_tier, parent_id=eid_a,
                    init_weights_from=self._get_expert(eid_a),
                    init_embedding=avg_emb,
                )
                # Overwrite with averaged weights
                avg_weights = self._average_expert_weights(
                    self._get_expert(eid_a), self._get_expert(eid_b))
                if avg_weights:
                    self._get_expert(merged_expert_id).load_weights(avg_weights)
                    mx.eval(self._get_expert(merged_expert_id).parameters())
            else:
                # Tier-up merge: different hidden dim, can't average weights
                merged_expert_id = self._create_expert(
                    tier=new_tier, parent_id=eid_a,
                    init_embedding=avg_emb,
                )

            self._expert_meta[merged_expert_id].cooldown = cfg.cooldown_steps
            self._remove_expert(eid_a)
            self._remove_expert(eid_b)
            merged.add(eid_a)
            merged.add(eid_b)
+ """
1070
+ if optimizer is not None:
1071
+ try:
1072
+ layers_state = optimizer.state.get("layers", [])
1073
+ if self.layer_idx < len(layers_state):
1074
+ moe_state = layers_state[self.layer_idx].get("moe", {})
1075
+ old_states = moe_state.get("expert_modules", [])
1076
+ new_states = []
1077
+ for i, expert_eid in enumerate(self._expert_id_list):
1078
+ if expert_eid == merged_expert_id:
1079
+ new_states.append({}) # fresh state, no momentum to copy
1080
+ elif i < len(old_states):
1081
+ new_states.append(old_states[i])
1082
+ else:
1083
+ new_states.append({})
1084
+ moe_state["expert_modules"] = new_states
1085
+ except (KeyError, IndexError, TypeError):
1086
+ pass
1087
+ """
1088

            if optimizer is not None:
                try:
                    layers_state = optimizer.state.get("layers", [])
                    if self.layer_idx < len(layers_state):
                        moe_state = layers_state[self.layer_idx].get("moe", {})

                        # Rebuild expert_modules state
                        old_states = moe_state.get("expert_modules", [])
                        new_states = []
                        for i, expert_eid in enumerate(self._expert_id_list):
                            if expert_eid == merged_expert_id:
                                new_states.append({})  # fresh state, no momentum to copy
                            elif i < len(old_states):
                                new_states.append(old_states[i])
                            else:
                                new_states.append({})
                        moe_state["expert_modules"] = new_states

                        # Rebuild router embeddings state
                        router_state = moe_state.get("router", {})
                        old_emb_states = router_state.get("embeddings", [])
                        new_emb_states = []
                        for i in range(len(self.router._emb_ids)):
                            if i < len(old_emb_states):
                                new_emb_states.append(old_emb_states[i])
                            else:
                                new_emb_states.append({})
                        router_state["embeddings"] = new_emb_states
                except (KeyError, IndexError, TypeError):
                    pass

            return (f"MERGE({reason}) {eid_a[:8]}+{eid_b[:8]} (T{meta_a.tier}) "
                    f"-> {merged_expert_id[:8]} (T{new_tier})")

        # --- Force 1: Fragment merge (original: co-route + both weak) ---
        for i, eid_a in enumerate(ids):
            if not _can_merge(eid_a):
                continue
            meta_a = self._expert_meta[eid_a]

            for j in range(i + 1, len(ids)):
                eid_b = ids[j]
                if not _can_merge(eid_b):
                    continue
                meta_b = self._expert_meta[eid_b]
                if meta_a.tier != meta_b.tier:
                    continue

                emb_a = self.router.get_embedding(eid_a)
                emb_b = self.router.get_embedding(eid_b)
                cos = ((emb_a * emb_b).sum()
                       / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8))

                both_weak = (meta_a.avg_routing_weight < cfg.merge_weakness_threshold
                             and meta_b.avg_routing_weight < cfg.merge_weakness_threshold)

                if cos.item() > cfg.merge_co_route_threshold and both_weak:
                    result = _do_merge(eid_a, eid_b, meta_a, meta_b, "fragment", optimizer=optimizer)
                    if result:
                        events.append(result)
                        break

        # --- Force 2: Capacity-pressure merge ---
        budget_frac = self._total_params() / cfg.max_params_per_layer
        if budget_frac > cfg.merge_capacity_pressure_frac:
            # Find weakest same-tier pair with highest cosine similarity
            candidates = []
            for i, eid_a in enumerate(ids):
                if not _can_merge(eid_a):
                    continue
                meta_a = self._expert_meta.get(eid_a)
                if meta_a is None:
                    continue
                for j in range(i + 1, len(ids)):
                    eid_b = ids[j]
                    if not _can_merge(eid_b):
                        continue
                    meta_b = self._expert_meta.get(eid_b)
                    if meta_b is None or meta_a.tier != meta_b.tier:
                        continue
                    emb_a = self.router.get_embedding(eid_a)
                    emb_b = self.router.get_embedding(eid_b)
                    cos = ((emb_a * emb_b).sum()
                           / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8))
                    combined_w = meta_a.avg_routing_weight + meta_b.avg_routing_weight
                    # Score: high cosine + low combined weight = best merge candidate
                    score = cos.item() - combined_w
                    candidates.append((score, eid_a, eid_b, meta_a, meta_b))

            candidates.sort(key=lambda t: -t[0])
            for score, eid_a, eid_b, meta_a, meta_b in candidates:
                if not _can_merge(eid_a) or not _can_merge(eid_b):
                    continue
                result = _do_merge(eid_a, eid_b, meta_a, meta_b, "capacity", optimizer=optimizer)
                if result:
                    events.append(result)
                    # Only do one capacity merge per lifecycle step to avoid cascades
                    break

        # --- Force 3: Tier-gravity merge (same-tier co-activate frequently) ---
        if self._last_routing_weights is not None:
            N = self._last_routing_weights.shape[-1]
            act_flat = (self._last_routing_weights > 0.01).astype(mx.float32).reshape(-1, N)
            total_tokens = act_flat.shape[0]

            for i, eid_a in enumerate(ids):
                if not _can_merge(eid_a):
                    continue
                meta_a = self._expert_meta.get(eid_a)
                if meta_a is None:
                    continue
                idx_a = self._eid_to_index(eid_a) if eid_a in self._expert_id_list else None
                if idx_a is None or idx_a >= N:
                    continue

                for j in range(i + 1, len(ids)):
                    eid_b = ids[j]
                    if not _can_merge(eid_b):
                        continue
                    meta_b = self._expert_meta.get(eid_b)
                    if meta_b is None or meta_a.tier != meta_b.tier:
                        continue
                    idx_b = self._eid_to_index(eid_b) if eid_b in self._expert_id_list else None
                    if idx_b is None or idx_b >= N:
                        continue

                    # Co-activation: fraction of tokens where both are active
                    both_active = (act_flat[:, idx_a] * act_flat[:, idx_b]).mean().item()

                    emb_a = self.router.get_embedding(eid_a)
                    emb_b = self.router.get_embedding(eid_b)
                    cos = ((emb_a * emb_b).sum()
                           / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8))

                    if (both_active > cfg.merge_tier_gravity_min_co_activation
                            and cos.item() > cfg.merge_tier_gravity_co_route):
                        result = _do_merge(eid_a, eid_b, meta_a, meta_b, "tier-gravity", optimizer=optimizer)
                        if result:
                            events.append(result)
                            break

        return events
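
    # Budget arithmetic (illustrative, with dim=768 and the default
    # tier_hidden_dims): a tier-up merge of two T2 experts frees
    # 2 * 3*768*1024 = 4_718_592 params and the new T3 expert costs
    # 3*768*2048 = 4_718_592, so it is parameter-neutral; a same-tier merge
    # at the top tier halves the pair's footprint instead.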


# ==========================================
# 8. MODEL COMPONENTS
# ==========================================
class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return mx.fast.rms_norm(x, self.weight, self.eps)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.scale = self.head_dim ** -0.5
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta)

    def __call__(self, x, mask=None):
        B, L, D = x.shape
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        queries = self.rope(queries)
        keys = self.rope(keys)
        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask)
        return self.wo(output.transpose(0, 2, 1, 3).reshape(B, L, -1))


class MicroExpertsBlock(nn.Module):
    def __init__(self, args: ModelArgs, me_config: MicroExpertConfig, layer_idx: int):
        super().__init__()
        self.attention = Attention(args)
        self.moe = MicroExpertsMoELayer(args.dim, me_config, layer_idx)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def __call__(self, x, mask=None):
        h = x + self.attention(self.attention_norm(x), mask)
        return h + self.moe(self.ffn_norm(h))


class MicroExpertsModel(nn.Module):
    def __init__(self, args: ModelArgs, me_config: MicroExpertConfig):
        super().__init__()
        self.args = args
        self.me_config = me_config
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [
            MicroExpertsBlock(args, me_config, layer_idx=i)
            for i in range(args.n_layers)
        ]
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(self, x):
        L = x.shape[1]
        mask = nn.MultiHeadAttention.create_additive_causal_mask(L).astype(mx.float32)
        mask = mask[None, None, :, :]
        h = self.tok_embeddings(x)
        for layer in self.layers:
            h = layer(h, mask)
        return self.output(self.norm(h))

    def set_global_step(self, step: int):
        for layer in self.layers:
            layer.moe.global_step = step

    def run_lifecycle(self, optimizer=None):
        all_events = []
        for layer in self.layers:
            all_events.extend(layer.moe.lifecycle_step(optimizer=optimizer))
        return all_events

    def total_load_balance_loss(self) -> mx.array:
        """Sum of per-layer activation frequency variance."""
        lb = mx.array(0.0)
        for layer in self.layers:
            lb = lb + layer.moe.load_balance_loss()
        return lb

    def zero_frozen_grads(self, grads):
        """Walk gradient tree, zero frozen expert parameters."""
        if not isinstance(grads, dict) or "layers" not in grads:
            return grads
        new_layers = []
        for i, lg in enumerate(grads["layers"]):
            if (isinstance(lg, dict) and "moe" in lg
                    and isinstance(lg["moe"], dict)
                    and "expert_modules" in lg["moe"]):
                moe = self.layers[i].moe
                fixed = moe.zero_frozen_grads(lg["moe"]["expert_modules"])
                new_moe = dict(lg["moe"])
                new_moe["expert_modules"] = fixed
                new_lg = dict(lg)
                new_lg["moe"] = new_moe
                new_layers.append(new_lg)
            else:
                new_layers.append(lg)
        new_grads = dict(grads)
        new_grads["layers"] = new_layers
        return new_grads

    def expert_summary(self) -> str:
        lines = []
        total_e, total_p = 0, 0
        for i, layer in enumerate(self.layers):
            moe = layer.moe
            n = len(moe._expert_id_list)
            p = moe._total_params()
            total_e += n
            total_p += p
            tiers = defaultdict(int)
            for m in moe._expert_meta.values():
                tiers[m.tier] += 1
            ts = " ".join(f"T{t}:{c}" for t, c in sorted(tiers.items()))
            frozen = sum(1 for eid in moe._expert_id_list if eid in moe._frozen_eids)
            drift = " DRIFT" if moe._drift_detected else ""
            lines.append(
                f" L{i:2d}: {n:3d} experts ({ts}) | {p/1e6:.1f}M | "
                f"{frozen} frozen | d={moe._density_ema:.1f}{drift}")
        lines.append(f" TOTAL: {total_e} experts | {total_p/1e6:.1f}M MoE params")
        return "\n".join(lines)

    def save_meta(self, path: str):
        data = {}
        for i, layer in enumerate(self.layers):
            moe = layer.moe
            data[f"layer_{i}"] = {
                "expert_ids": list(moe._expert_id_list),
                "experts": {eid: m.to_dict() for eid, m in moe._expert_meta.items()},
                "density_ema": moe._density_ema,
            }
        with open(path, "w") as f:
            json.dump(data, f, indent=2)
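
# Minimal usage sketch (illustrative; vocab_size=32000 is a placeholder, the
# real value comes from the tokenizer):
#   args = ModelArgs(vocab_size=32000)
#   model = MicroExpertsModel(args, MicroExpertConfig())
#   logits = model(mx.zeros((1, 16), dtype=mx.int32))  # (1, 16, 32000)
#   print(model.expert_summary())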


# ==========================================
# 9. DATA STREAMS
# ==========================================
def stream_gutenberg(tokenizer, batch_size: int, seq_len: int):
    # NOTE: despite the name, this streams the teknium/OpenHermes-2.5
    # conversations dataset, not Project Gutenberg text.
    print("Connecting to OpenHermes-2.5 stream...")
    dataset = load_dataset("teknium/OpenHermes-2.5", split="train", streaming=True)
    dataset_iter = iter(dataset)
    buffers = [[] for _ in range(batch_size)]
    while True:
        for i in range(batch_size):
            while len(buffers[i]) < seq_len + 1:
                try:
                    row = next(dataset_iter)
                except StopIteration:
                    dataset_iter = iter(dataset)
                    row = next(dataset_iter)
                text = row.get("conversations", "")
                if isinstance(text, list):
                    parts = []
                    for msg in text:
                        role = msg.get("from", "")
                        # Non-string values (including the [] default) are skipped
                        content = msg.get("value", [])
                        if isinstance(content, str):
                            parts.append(f"{role}\n{content}")
                    text = "\n".join(parts)
                if not text or len(text) < 10:
                    continue
                buffers[i].extend(tokenizer.encode(text))
        batch = []
        for i in range(batch_size):
            batch.append(buffers[i][:seq_len + 1])
            buffers[i] = buffers[i][seq_len:]
        yield mx.array(batch, dtype=mx.int32)
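
# Usage sketch (illustrative; seq_len=512 is an arbitrary choice): each
# yielded batch has shape (batch_size, seq_len + 1), which loss_fn below
# splits into inputs and next-token targets.
#   tok = PreTrainedTokenizerFast(tokenizer_file="gutenberg_tokenizer.json")
#   batches = stream_gutenberg(tok, batch_size=8, seq_len=512)
#   x = next(batches)  # (8, 513) int32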


def stream_domain_files(tokenizer, data_dir: str, batch_size: int, seq_len: int):
    files = sorted(glob.glob(os.path.join(data_dir, "*.txt")))
    if not files:
        raise FileNotFoundError(f"No .txt files in {data_dir}")
    for fpath in files:
        domain = os.path.splitext(os.path.basename(fpath))[0]
        print(f"\n{'='*60}")
        print(f" ACTIVE LEARNING — Domain: {domain}")
        print(f"{'='*60}")
        with open(fpath, "r", encoding="utf-8", errors="replace") as f:
            text = f.read()
        tokens = tokenizer.encode(text)
        min_tokens = (seq_len + 1) * batch_size
        if len(tokens) < min_tokens:
            print(f" Skipping {domain}: {len(tokens)} tokens < {min_tokens} needed")
            continue

        def batch_gen(toks=tokens, bs=batch_size, sl=seq_len):
            while True:
                buf = list(toks)
                while len(buf) >= bs * (sl + 1):
                    batch = []
                    for _ in range(bs):
                        batch.append(buf[:sl + 1])
                        buf = buf[sl:]
                    yield mx.array(batch, dtype=mx.int32)

        yield domain, batch_gen()
+
+ # ==========================================
+ # 10. LOSS + CHECKPOINT
+ # ==========================================
+ def loss_fn(model, x):
+     """Cross-entropy + load-balance auxiliary loss."""
+     logits = model(x)
+     ce = nn.losses.cross_entropy(logits[:, :-1, :], x[:, 1:], reduction="mean")
+     lb = model.total_load_balance_loss()
+     return ce + model.me_config.load_balance_weight * lb
+
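The slicing implements the standard next-token shift: the logits at position t are scored against the token at position t + 1, so the final position's logits and the first token are both dropped (shapes below are illustrative):

    # x: (B, T) token ids, logits: (B, T, V)
    preds   = logits[:, :-1, :]   # positions 0 .. T-2, shape (B, T-1, V)
    targets = x[:, 1:]            # tokens    1 .. T-1, shape (B, T-1)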
+ def load_checkpoint(model, path: str):
+     """Rebuild each layer's expert topology from the sidecar JSON, then load
+     the matching weights from the .npz archive."""
+     weights = dict(mx.load(path))
+     meta_path = path.replace(".npz", ".json")
+     with open(meta_path, "r") as f:
+         meta = json.load(f)
+
+     for i, layer in enumerate(model.layers):
+         moe = layer.moe
+         layer_key = f"layer_{i}"
+         if layer_key not in meta:
+             continue
+         layer_meta = meta[layer_key]
+
+         # Drop the freshly initialized experts, then re-create the saved ones.
+         for eid in list(moe._expert_id_list):
+             moe._remove_expert(eid)
+
+         for eid in layer_meta["expert_ids"]:
+             em = layer_meta["experts"][eid]
+             tier = em["tier"]
+             hidden = moe._tier_to_hidden(tier)
+             expert = Expert(moe.model_dim, hidden)
+             mx.eval(expert.parameters())
+             moe.expert_modules.append(expert)
+             moe._expert_id_list.append(eid)
+             moe._expert_meta[eid] = ExpertMeta(
+                 expert_id=eid, tier=tier, hidden_dim=hidden,
+                 age=em.get("age", 0),
+                 cooldown=em.get("cooldown", 0),
+                 frozen_steps=em.get("frozen_steps", 0),
+                 ema_interference_fast=em.get("ema_fast", 0.0),
+                 ema_interference_slow=em.get("ema_slow", 0.0),
+                 ema_interference_var=em.get("ema_var", 1.0),
+                 avg_routing_weight=em.get("avg_rw", 0.1),
+                 avg_activation_freq=em.get("avg_af", 0.1),
+                 parent_id=em.get("parent_id"),
+                 generation=em.get("generation", 0),
+             )
+             if em.get("frozen_steps", 0) > 0:
+                 moe._frozen_eids.add(eid)
+             router_key = f"__router__.{i}.{eid}"
+             init_emb = weights.pop(router_key, None)
+             moe.router.add_expert(eid, init_embedding=init_emb)
+
+         moe._density_ema = layer_meta.get("density_ema", 1.0)
+
+     remaining = [(k, v) for k, v in weights.items() if not k.startswith("__router__")]
+     model.load_weights(remaining, strict=False)
+     mx.eval(model.parameters())
+     print(f" Loaded checkpoint from {path}")
+
+
+ def get_latest_checkpoint(checkpoint_dir: str):
+     if not os.path.exists(checkpoint_dir):
+         return None, 0
+     ckpts = glob.glob(os.path.join(checkpoint_dir, "checkpoint_step_*.npz"))
+     if not ckpts:
+         return None, 0
+
+     # Pick by step number, not lexicographically: a plain sorted() would
+     # rank checkpoint_step_9000 above checkpoint_step_10000.
+     def step_of(p):
+         m = re.search(r"step_(\d+)", p)
+         return int(m.group(1)) if m else -1
+
+     latest = max(ckpts, key=step_of)
+     return latest, step_of(latest)
+
+
+ def save_checkpoint(model, step: int, checkpoint_dir: str):
+     path = os.path.join(checkpoint_dir, f"checkpoint_step_{step}.npz")
+
+     save_dict = {}
+     for k, v in tree_flatten(model.parameters()):
+         save_dict[k] = v
+
+     # Router embeddings are stored under (layer index, expert id) keys so that
+     # load_checkpoint can re-attach them after rebuilding the topology.
+     for i, layer in enumerate(model.layers):
+         moe = layer.moe
+         for j, eid in enumerate(moe.router._emb_ids):
+             save_dict[f"__router__.{i}.{eid}"] = moe.router.embeddings[j].embedding
+
+     mx.savez(path, **save_dict)
+     model.save_meta(path.replace(".npz", ".json"))
+     print(f" Saved checkpoint {path}")
+
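A save/load round trip, sketched (directory and step are illustrative; load_checkpoint expects the sidecar JSON written by save_meta to sit next to the .npz):

    save_checkpoint(model, step=1000, checkpoint_dir="checkpoints")
    # ... later, with a freshly constructed MicroExpertsModel:
    ckpt, step = get_latest_checkpoint("checkpoints")
    if ckpt:
        load_checkpoint(model, ckpt)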
+ # ==========================================
+ # 11. TRAINING LOOP
+ # ==========================================
+ def train_loop(model, optimizer, data_iter, tc: TrainConfig,
+                start_step=0, max_steps=30000, lifecycle_every=10, label="train"):
+
+     loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+     compiled_loss_and_grad = mx.compile(loss_and_grad_fn)
+
+     step = start_step
+     tic = time.time()
+
+     topology_changed = False
+
+     for batch in data_iter:
+         if step >= max_steps:
+             break
+         model.set_global_step(step)
+
+         # After a lifecycle event changes the expert topology (add/remove
+         # modules), the compiled graph is stale, so rebuild it before reuse.
+         if topology_changed:
+             compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn))
+             topology_changed = False
+
+         try:
+             loss, grads = compiled_loss_and_grad(model, batch)
+         except Exception:
+             # The compiled graph can still fail right after a topology change;
+             # fall back to eager for this step and recompile for the next one.
+             loss_and_grad_fn_eager = nn.value_and_grad(model, loss_fn)
+             loss, grads = loss_and_grad_fn_eager(model, batch)
+             compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn))
+
+         grads = model.zero_frozen_grads(grads)
+         try:
+             optimizer.update(model, grads)
+         except (ValueError, KeyError, IndexError):
+             # Topology change left stale optimizer state — wipe and retry
+             optimizer.state = {k: v for k, v in optimizer.state.items()
+                                if not isinstance(v, (dict, list))}
+             optimizer.update(model, grads)
+         mx.eval(model.parameters(), optimizer.state, loss)
+
+         if step > 0 and step % lifecycle_every == 0:
+             events = model.run_lifecycle(optimizer=optimizer)
+             if events:
+                 topology_changed = True
+
+         if step % tc.log_every == 0:
+             toc = time.time()
+             n_exp = sum(len(l.moe._expert_id_list) for l in model.layers)
+             densities = [l.moe._last_density.mean().item()
+                          for l in model.layers if l.moe._last_density is not None]
+             avg_d = sum(densities) / max(len(densities), 1)
+             elapsed = toc - tic
+             tok_per_sec = (tc.log_every * tc.batch_size * model.args.max_seq_len) / max(elapsed, 1e-6)
+             print(f"[{label}] Step {step:6d} | Loss {loss.item():.4f} | "
+                   f"Experts {n_exp} | Density {avg_d:.1f} | "
+                   f"{tok_per_sec:.0f} tok/s | {elapsed:.2f}s")
+             tic = time.time()
+
+         if step > 0 and step % tc.summary_every == 0:
+             print(f"\n--- Expert Summary @ step {step} ---")
+             print(model.expert_summary())
+             print()
+
+         if step > 0 and step % tc.checkpoint_every == 0:
+             save_checkpoint(model, step, tc.checkpoint_dir)
+
+         step += 1
+     return step
+
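Wiring the pieces together: a hedged sketch of a minimal pretraining run (hyperparameters illustrative; assumes the TrainConfig defaults defined earlier in this file):

    tc = TrainConfig()
    data = stream_gutenberg(tokenizer, tc.batch_size, model.args.max_seq_len)
    optimizer = optim.AdamW(learning_rate=tc.learning_rate)
    final_step = train_loop(model, optimizer, data, tc,
                            start_step=0, max_steps=1000, label="smoke")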
+ # ==========================================
+ # 12. INTERACTIVE SETUP + MAIN
+ # ==========================================
+ def prompt_config() -> TrainConfig:
+     """Interactive configuration via input() prompts."""
+     tc = TrainConfig()
+
+     print("\n" + "="*60)
+     print(" MicroExperts — Training Configuration")
+     print("="*60)
+
+     # Mode
+     print(" 1. pretrain — Gutenberg streaming pretraining")
+     print(" 2. active_learning — Sequential domain continual learning (not implemented yet)")
+     print(" 3. inference — Chat with the trained model")
+     print(" 4. interactive_learning — Chat and learn from your inputs")
+     print(" 5. train_and_chat — Train with periodic chat breaks")
+     choice = input("Mode [1]: ").strip()
+     if choice == "2":
+         tc.mode = "active_learning"
+     elif choice == "3":
+         tc.mode = "inference"
+     elif choice == "4":
+         tc.mode = "interactive_learning"
+     elif choice == "5":
+         tc.mode = "train_and_chat"
+     else:
+         tc.mode = "pretrain"
+
+     # Tokenizer (fixed: the file shipped alongside this script)
+     tc.tokenizer_file = "gutenberg_tokenizer.json"
+
+     # Checkpoint dir
+     cd = input(f"Checkpoint directory [{tc.checkpoint_dir}]: ").strip()
+     if cd:
+         tc.checkpoint_dir = cd
+
+     # Batch size
+     bs = input(f"Batch size [{tc.batch_size}]: ").strip()
+     if bs:
+         tc.batch_size = int(bs)
+
+     # Learning rate: pretraining uses the base LR; every other mode defaults
+     # to the lower active-learning LR
+     default_lr = tc.learning_rate if tc.mode == "pretrain" else tc.al_learning_rate
+     lr = input(f"Learning rate [{default_lr}]: ").strip()
+     tc.learning_rate = float(lr) if lr else default_lr
+
+     # Max steps
+     ms = input(f"Max steps [{tc.max_steps}]: ").strip()
+     if ms:
+         tc.max_steps = int(ms)
+
+     # Resume
+     resume = input("Resume from checkpoint? [Y/n]: ").strip().lower()
+     tc._resume = resume != "n"
+
+     # Mode-specific
+     if tc.mode == "active_learning":
+         dd = input(f"Domain data directory [{tc.al_data_dir}]: ").strip()
+         if dd:
+             tc.al_data_dir = dd
+         spd = input(f"Steps per domain [{tc.al_steps_per_domain}]: ").strip()
+         if spd:
+             tc.al_steps_per_domain = int(spd)
+
+     print("\n" + "-"*60)
+     print(f" Mode: {tc.mode}")
+     print(f" LR: {tc.learning_rate}")
+     print(f" Batch: {tc.batch_size}")
+     print(f" Max steps: {tc.max_steps}")
+     print(f" Checkpoint: {tc.checkpoint_dir}")
+     print(f" Resume: {tc._resume}")
+     if tc.mode == "active_learning":
+         print(f" Data dir: {tc.al_data_dir}")
+         print(f" Steps/dom: {tc.al_steps_per_domain}")
+     print(" M4 budget: 150M params/layer, 128 experts/layer max")
+     print("-"*60)
+
+     confirm = input("Continue? [Y/n]: ").strip().lower()
+     if confirm == "n":
+         print("Aborted.")
+         exit(0)
+
+     return tc
+
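For scripted or CI runs the prompts can be skipped by filling in a TrainConfig directly, a sketch assuming the fields exercised above (values illustrative):

    tc = TrainConfig()
    tc.mode = "pretrain"
    tc.batch_size = 8
    tc.max_steps = 5000
    tc._resume = True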
+ def generate(model, tokenizer, prompt: str, max_tokens: int = 256, temperature: float = 0.8):
+     tokens = tokenizer.encode(prompt)
+     tokens = mx.array([tokens], dtype=mx.int32)
+
+     # Naive sampling loop: re-runs the full forward pass each step (no KV
+     # cache), which is acceptable for short interactive generations.
+     for _ in range(max_tokens):
+         logits = model(tokens)
+         next_logits = logits[:, -1, :] / temperature
+         next_token = mx.random.categorical(next_logits)
+         next_token = next_token.reshape(1, 1)
+         tokens = mx.concatenate([tokens, next_token], axis=1)
+         mx.eval(tokens)
+
+         token_id = next_token.item()
+         if token_id == tokenizer.eos_token_id:
+             break
+
+     # Print expert usage per layer
+     print("\n Expert routing:")
+     for i, layer in enumerate(model.layers):
+         moe = layer.moe
+         if moe._last_routing_weights is None:
+             continue
+         rw = moe._last_routing_weights
+         N = rw.shape[-1]
+         # Average routing weight per expert across all tokens
+         avg_w = rw.reshape(-1, N).mean(axis=0)
+         active = (avg_w > 0.01)
+         parts = []
+         for j, eid in enumerate(moe._expert_id_list):
+             if j < N and active[j].item():
+                 meta = moe._expert_meta.get(eid)
+                 tier = meta.tier if meta else "?"
+                 parts.append(f"{eid[:6]}(T{tier} w={avg_w[j].item():.3f})")
+         if parts:
+             print(f" L{i:2d}: {' '.join(parts)}")
+
+     return tokenizer.decode(tokens[0].tolist())
+
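Example call (the generated text depends on the trained weights, so none is reproduced here; the routing report prints as a side effect):

    text = generate(model, tokenizer, "Once upon a time", max_tokens=64, temperature=0.7)
    print(text)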
+ def main():
+     tc = prompt_config()
+     os.makedirs(tc.checkpoint_dir, exist_ok=True)
+
+     # Tokenizer
+     print(f"\nLoading tokenizer: {tc.tokenizer_file}")
+     tokenizer = PreTrainedTokenizerFast(tokenizer_file=tc.tokenizer_file)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     # Model
+     args = ModelArgs()
+     args.vocab_size = len(tokenizer)
+     me_config = MicroExpertConfig()
+
+     if tc.mode == "active_learning":
+         me_config.split_threshold = tc.al_split_threshold
+         me_config.min_expert_age = tc.al_min_expert_age
+
+     print(f"Initializing MicroExperts model (vocab={args.vocab_size})...")
+     model = MicroExpertsModel(args, me_config)
+
+     # Resume
+     current_step = 0
+     if tc._resume:
+         ckpt, ckpt_step = get_latest_checkpoint(tc.checkpoint_dir)
+         if ckpt:
+             print(f"Resuming from {ckpt} @ step {ckpt_step}")
+             load_checkpoint(model, ckpt)
+             current_step = ckpt_step
+         else:
+             print("No checkpoint found — starting fresh.")
+
+     mx.eval(model.parameters())
+     n_params = sum(v.size for _, v in tree_flatten(model.parameters()))
+     print(f"Total params: {n_params / 1e6:.2f}M")
+     print("Initial layout:")
+     print(model.expert_summary())
+
+     optimizer = optim.AdamW(learning_rate=tc.learning_rate)
+
+     # ---- PRETRAIN ----
+     if tc.mode == "pretrain":
+         data = stream_gutenberg(tokenizer, tc.batch_size, args.max_seq_len)
+         print(f"\nStarting pretraining for {tc.max_steps} steps...")
+         final_step = train_loop(
+             model, optimizer, data, tc,
+             start_step=current_step, max_steps=tc.max_steps,
+             lifecycle_every=tc.lifecycle_every, label="pretrain",
+         )
+
+     # ---- INFERENCE ----
+     elif tc.mode == "inference":
+         print("\nChat ready. Type 'quit' to exit.\n")
+         while True:
+             user_input = input("You: ").strip()
+             if user_input.lower() in ("quit", "exit"):
+                 break
+             if not user_input:
+                 continue
+             response = generate(model, tokenizer, user_input)
+             print(f"Model: {response}\n")
+
+         final_step = current_step
+
+     # ---- ACTIVE LEARNING ----
+     elif tc.mode == "active_learning":
+         lifecycle_every = tc.al_lifecycle_every
+         print(f"\nActive learning from: {tc.al_data_dir}")
+         print(f" Steps/domain: {tc.al_steps_per_domain} | Lifecycle every: {lifecycle_every}")
+
+         domain_gen = stream_domain_files(
+             tokenizer, tc.al_data_dir, tc.batch_size, args.max_seq_len)
+
+         global_step = current_step
+         for domain_name, batches in domain_gen:
+             domain_max = global_step + tc.al_steps_per_domain
+             n_before = sum(len(l.moe._expert_id_list) for l in model.layers)
+
+             print(f"\n Training '{domain_name}': steps {global_step} -> {domain_max}")
+             global_step = train_loop(
+                 model, optimizer, batches, tc,
+                 start_step=global_step, max_steps=domain_max,
+                 lifecycle_every=lifecycle_every, label=f"AL:{domain_name}",
+             )
+
+             n_after = sum(len(l.moe._expert_id_list) for l in model.layers)
+             print(f"\n '{domain_name}' done. Experts: {n_before} -> {n_after} ({n_after-n_before:+d})")
+             print(model.expert_summary())
+
+         final_step = global_step
+
+     # ---- INTERACTIVE LEARNING ----
+     elif tc.mode == "interactive_learning":
+         if not tc._resume:
+             print("WARNING: No checkpoint loaded, model is random.")
+
+         il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate)
+         il_step = current_step
+         conversation_tokens = []
+         message_count = 0
+
+         print("\nInteractive learning ready. Type 'quit' to exit.")
+         print("The model learns from the conversation.\n")
+
+         while True:
+             user_input = input("You: ").strip()
+             if user_input.lower() in ("quit", "exit"):
+                 break
+             if not user_input:
+                 continue
+
+             response = generate(model, tokenizer, user_input)
+             print(f"Model: {response}\n")
+
+             conversation_tokens.extend(tokenizer.encode(user_input))
+             conversation_tokens.extend(tokenizer.encode(response))
+             message_count += 1
+
+             seq_len = model.args.max_seq_len
+             trained = False
+
+             # Train on full sequences when available
+             while len(conversation_tokens) >= seq_len + 1:
+                 batch = mx.array([conversation_tokens[:seq_len + 1]], dtype=mx.int32)
+                 conversation_tokens = conversation_tokens[seq_len:]
+
+                 loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+                 loss, grads = loss_and_grad_fn(model, batch)
+                 grads = model.zero_frozen_grads(grads)
+                 il_optimizer.update(model, grads)
+                 mx.eval(model.parameters(), il_optimizer.state, loss)
+
+                 il_step += 1
+                 model.set_global_step(il_step)
+                 trained = True
+                 print(f" [learned: loss={loss.item():.4f}, step={il_step}]")
+
+             # Force a training step every 2 messages, even on a partial sequence
+             if not trained and message_count % 2 == 0 and len(conversation_tokens) > 2:
+                 pad_len = seq_len + 1
+                 tokens_to_use = conversation_tokens[-pad_len:] if len(conversation_tokens) >= pad_len else conversation_tokens
+                 # Tile the sequence until it fills a full window
+                 while len(tokens_to_use) < pad_len:
+                     tokens_to_use = tokens_to_use + tokens_to_use
+                 tokens_to_use = tokens_to_use[:pad_len]
+
+                 batch = mx.array([tokens_to_use], dtype=mx.int32)
+
+                 loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+                 loss, grads = loss_and_grad_fn(model, batch)
+                 grads = model.zero_frozen_grads(grads)
+                 il_optimizer.update(model, grads)
+                 mx.eval(model.parameters(), il_optimizer.state, loss)
+
+                 il_step += 1
+                 model.set_global_step(il_step)
+                 print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]")
+
+             # Lifecycle check
+             if il_step > 0 and il_step % tc.al_lifecycle_every == 0:
+                 events = model.run_lifecycle()
+                 if events:
+                     # Topology changed; drop stale per-parameter optimizer state
+                     il_optimizer.state = {k: v for k, v in il_optimizer.state.items()
+                                           if not isinstance(v, (dict, list))}
+                 print(model.expert_summary())
+
+         save_checkpoint(model, il_step, tc.checkpoint_dir)
+         print("Model saved.")
+         final_step = il_step
+
+     # ---- TRAIN AND CHAT ----
+     elif tc.mode == "train_and_chat":
+         if not tc._resume:
+             print("WARNING: No checkpoint loaded, model is random.")
+
+         il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate)
+         il_step = current_step
+         message_count = 0
+
+         system_prompt = "You are a helpful assistant."
+         chat_history = []
+
+         print("\nChat Learning ready. Type 'quit' to exit.")
+         print("The model learns from the conversation with chat format.\n")
+
+         while True:
+             user_input = input("You: ").strip()
+             if user_input.lower() in ("quit", "exit"):
+                 break
+             if not user_input:
+                 continue
+
+             response = generate(model, tokenizer, user_input)
+             print(f"Model: {response}\n")
+
+             # Build chat-formatted training text
+             chat_history.append({"role": "user", "content": user_input})
+             chat_history.append({"role": "assistant", "content": response})
+
+             chat_text = f"system\n{system_prompt}\n"
+             for msg in chat_history:
+                 role = "human" if msg["role"] == "user" else "gpt"
+                 chat_text += f"{role}\n{msg['content']}\n"
+
+             conversation_tokens = tokenizer.encode(chat_text)
+             message_count += 1
+
+             seq_len = model.args.max_seq_len
+             trained = False
+
+             # Train on full sequences from chat history
+             train_tokens = list(conversation_tokens)
+             while len(train_tokens) >= seq_len + 1:
+                 batch = mx.array([train_tokens[:seq_len + 1]], dtype=mx.int32)
+                 train_tokens = train_tokens[seq_len:]
+
+                 loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+                 loss, grads = loss_and_grad_fn(model, batch)
+                 grads = model.zero_frozen_grads(grads)
+                 try:
+                     il_optimizer.update(model, grads)
+                 except (ValueError, KeyError, IndexError):
+                     il_optimizer.state = {k: v for k, v in il_optimizer.state.items()
+                                           if not isinstance(v, (dict, list))}
+                     il_optimizer.update(model, grads)
+                 mx.eval(model.parameters(), il_optimizer.state, loss)
+
+                 il_step += 1
+                 model.set_global_step(il_step)
+                 trained = True
+                 print(f" [learned: loss={loss.item():.4f}, step={il_step}]")
+
+             # Force a training step every 2 messages, even on a partial sequence
+             if not trained and message_count % 2 == 0 and len(train_tokens) > 2:
+                 pad_len = seq_len + 1
+                 tokens_to_use = train_tokens[-pad_len:] if len(train_tokens) >= pad_len else train_tokens
+                 # Tile the sequence until it fills a full window
+                 while len(tokens_to_use) < pad_len:
+                     tokens_to_use = tokens_to_use + tokens_to_use
+                 tokens_to_use = tokens_to_use[:pad_len]
+
+                 batch = mx.array([tokens_to_use], dtype=mx.int32)
+
+                 loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+                 loss, grads = loss_and_grad_fn(model, batch)
+                 grads = model.zero_frozen_grads(grads)
+                 try:
+                     il_optimizer.update(model, grads)
+                 except (ValueError, KeyError, IndexError):
+                     il_optimizer.state = {k: v for k, v in il_optimizer.state.items()
+                                           if not isinstance(v, (dict, list))}
+                     il_optimizer.update(model, grads)
+                 mx.eval(model.parameters(), il_optimizer.state, loss)
+
+                 il_step += 1
+                 model.set_global_step(il_step)
+                 print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]")
+
+             # Trim chat history if too long
+             max_history = 20
+             if len(chat_history) > max_history:
+                 chat_history = chat_history[-max_history:]
+
+             # Lifecycle check (run_lifecycle rebuilds optimizer state itself
+             # when handed the optimizer)
+             if il_step > 0 and il_step % tc.al_lifecycle_every == 0:
+                 model.run_lifecycle(optimizer=il_optimizer)
+                 print(model.expert_summary())
+
+         save_checkpoint(model, il_step, tc.checkpoint_dir)
+         print("Model saved.")
+         final_step = il_step
+
+     # Save final
+     print("\nTraining complete.")
+     save_checkpoint(model, final_step, tc.checkpoint_dir)
+     print("Final layout:")
+     print(model.expert_summary())
+
+
+ if __name__ == "__main__":
+     main()
tokenizer.py ADDED
@@ -0,0 +1,57 @@
+ from datasets import load_dataset
+ from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors, Regex
+
+ # --- CONFIGURATION ---
+ DATASET_NAME = "sedthh/gutenberg_english"
+ VOCAB_SIZE = 32000
+ SAMPLE_SIZE = 3000   # number of books to stream for training
+ BATCH_SIZE = 100
+
+ # 1. Connect
+ print(f"1. Connecting to {DATASET_NAME}...")
+ dataset = load_dataset(DATASET_NAME, split="train", streaming=True)
+
+ # 2. The generator that feeds the trainer
+ def batch_iterator():
+     batch = []
+     print("2. Collecting data...")
+     for i, item in enumerate(dataset):
+         if i >= SAMPLE_SIZE:
+             break
+         batch.append(item['TEXT'])
+         if len(batch) == BATCH_SIZE:
+             print(f" > Processing batch {(i+1)//BATCH_SIZE}...", end='\r')
+             yield batch
+             batch = []
+     if batch:
+         yield batch
+
+ # 3. TOKENIZER: byte-level BPE with a Qwen-style pre-tokenization split
+ print("\n3. Initializing Tokenizer...")
+ tokenizer = Tokenizer(models.BPE())
+
+ qwen_pattern = Regex(r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""")
+
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
+     pre_tokenizers.Split(pattern=qwen_pattern, behavior="isolated"),
+     pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
+ ])
+
+ tokenizer.decoder = decoders.ByteLevel()
+
+ trainer = trainers.BpeTrainer(
+     vocab_size=VOCAB_SIZE,
+     special_tokens=["<|endoftext|>", "<|padding|>"],
+     show_progress=True,
+     initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
+ )
+
+ # 4. Train
+ print("4. Training Qwen-style tokenizer...")
+ tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)
+
+ # 5. Save (use the filename that microexpert.py loads by default)
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
+ tokenizer.save("gutenberg_tokenizer.json")
+ print("\nSUCCESS! Saved 'gutenberg_tokenizer.json'")
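A quick smoke test of the trained file, matching how microexpert.py loads it (the sample sentence is illustrative):

    from transformers import PreTrainedTokenizerFast
    tok = PreTrainedTokenizerFast(tokenizer_file="gutenberg_tokenizer.json")
    ids = tok.encode("It was a dark and stormy night.")
    print(len(ids), tok.decode(ids))   # byte-level BPE should round-trip the text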