"""
MicroExperts — Self-organizing dynamic Mixture-of-Experts for continual learning.
Target hardware: Apple M4 with 48 GB unified memory.
"""
import time
import math
import uuid
import json
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
from datasets import load_dataset
from transformers import PreTrainedTokenizerFast
import os
import glob
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from collections import defaultdict
def one_hot(indices: mx.array, num_classes: int) -> mx.array:
# Build a range vector [0, 1, ..., num_classes-1] and compare with indices
flat = indices.reshape(-1) # (K,)
arange = mx.arange(num_classes) # (num_classes,)
oh = (flat[:, None] == arange[None, :]).astype(mx.float32) # (K, num_classes)
return oh.reshape(*indices.shape, num_classes)
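
# Quick illustration of the helper above (a sketch, not used by training):
# one_hot(mx.array([2, 0]), num_classes=4) yields
#   [[0., 0., 1., 0.],
#    [1., 0., 0., 0.]]
# and for a (B, L) index array the result has shape (B, L, num_classes).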
# ==========================================
# 1. CONFIGURATION
# ==========================================
@dataclass
class ModelArgs:
dim: int = 768
n_layers: int = 12
n_heads: int = 12
n_kv_heads: int = 12
vocab_size: int = -1
norm_eps: float = 1e-8
max_seq_len: int = 2048
rope_theta: float = 10000.0
@dataclass
class MicroExpertConfig:
"""All hyperparameters for the MicroExperts MoE system."""
#tier_hidden_dims: Tuple[int, ...] = (512, 1024, 2048, 4096)
tier_hidden_dims: Tuple[int, ...] = (256, 512, 1024, 2048)
monolith_split_enabled: bool = True
monolith_variance_ema_alpha: float = 0.02
monolith_variance_z_threshold: float = 1.5
# Router
router_embed_dim: int = 128
min_experts_per_token: int = 1
max_experts_per_token: int = 64
# Cannibalization / lifecycle
ema_fast_alpha: float = 0.05
ema_slow_alpha: float = 0.005
split_threshold: float = 2.0
# Relaxed merge thresholds so merges actually fire
merge_co_route_threshold: float = 0.5
merge_weakness_threshold: float = 0.05
death_threshold: float = 0.001
min_expert_age: int = 50
cooldown_steps: int = 100
# Base freeze duration — actual duration scaled by importance
preserver_base_freeze_steps: int = 100
preserver_max_freeze_steps: int = 200
adapter_noise_scale: float = 0.02
max_experts_per_layer: int = 12
max_params_per_layer: int = 20_000_000 # 20 M
# Initial state
init_tier: int = 2
# Interference
interference_subsample: int = 64
# Load balance loss
load_balance_weight: float = 0.01
# Capacity-pressure merge: trigger when pool exceeds this fraction of budget
merge_capacity_pressure_frac: float = 0.8
# Tier-gravity merge: same-tier co-activation threshold (lower than fragment)
merge_tier_gravity_co_route: float = 0.4
merge_tier_gravity_min_co_activation: float = 0.3 # both activated > 30 % of tokens
density_ema_alpha: float = 0.02
density_spike_z: float = 2.5 # z-score above mean to flag distribution shift
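
# Worked budget example (illustrative, using the defaults above together with
# ModelArgs.dim = 768): a tier-2 expert has hidden_dim = tier_hidden_dims[2]
# = 1024 and costs 3 * 768 * 1024 = 2,359,296 params (w1, w2, w3 of the SwiGLU
# FFN), so max_params_per_layer = 20M admits roughly 8 such experts per layer.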
@dataclass
class TrainConfig:
"""Training hyperparameters."""
mode: str = "pretrain"
batch_size: int = 8
learning_rate: float = 3e-4
max_steps: int = 30_000
tokenizer_file: str = "gutenberg_tokenizer.json"
checkpoint_dir: str = "checkpoints_me"
log_every: int = 10
summary_every: int = 500
checkpoint_every: int = 1000
lifecycle_every: int = 10
# Active learning
al_data_dir: str = "./domains"
al_steps_per_domain: int = 2000
al_learning_rate: float = 1e-4
al_lifecycle_every: int = 5
al_split_threshold: float = 1.5
al_min_expert_age: int = 100
# ==========================================
# 2. EXPERT MODULE
# ==========================================
class Expert(nn.Module):
"""Single MicroExpert: SwiGLU FFN."""
def __init__(self, model_dim: int, hidden_dim: int):
super().__init__()
self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)
self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)
def __call__(self, x):
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
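
# Minimal sanity sketch for the expert block (a hypothetical helper; never
# called by the training code): checks that the SwiGLU FFN preserves the
# model dimension and that its parameter count is 3 * model_dim * hidden_dim.
def _demo_expert_shapes():
    e = Expert(model_dim=8, hidden_dim=16)
    y = e(mx.zeros((2, 4, 8)))
    n_params = sum(v.size for _, v in tree_flatten(e.parameters()))
    assert y.shape == (2, 4, 8)
    assert n_params == 3 * 8 * 16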
# ==========================================
# 3. EXPERT METADATA
# ==========================================
@dataclass
class ExpertMeta:
"""Non-parameter state for one expert."""
expert_id: str
tier: int
hidden_dim: int
age: int = 0
cooldown: int = 0
frozen_steps: int = 0
ema_interference_fast: float = 0.0
ema_interference_slow: float = 0.0
ema_interference_var: float = 1.0
avg_routing_weight: float = 0.1
avg_activation_freq: float = 0.1
parent_id: Optional[str] = None
generation: int = 0
def to_dict(self) -> dict:
return {
"expert_id": self.expert_id, "tier": self.tier,
"hidden_dim": self.hidden_dim, "age": self.age,
"cooldown": self.cooldown, "frozen_steps": self.frozen_steps,
"ema_fast": self.ema_interference_fast,
"ema_slow": self.ema_interference_slow,
"ema_var": self.ema_interference_var,
"avg_rw": self.avg_routing_weight,
"avg_af": self.avg_activation_freq,
"parent_id": self.parent_id, "generation": self.generation,
}
# ==========================================
# 4. EXPERT EMBEDDING (trainable nn.Module)
# ==========================================
class ExpertEmbedding(nn.Module):
def __init__(self, dim: int, init: Optional[mx.array] = None):
super().__init__()
if init is not None:
self.embedding = init
else:
scale = 1.0 / math.sqrt(dim)
self.embedding = mx.random.normal((dim,)) * scale
# ==========================================
# 5. ADAPTIVE ROUTER
# ==========================================
class AdaptiveRouter(nn.Module):
def __init__(self, model_dim: int, config: MicroExpertConfig):
super().__init__()
self.config = config
self.d = config.router_embed_dim
self.proj = nn.Linear(model_dim, self.d, bias=False)
self.threshold_head = nn.Linear(model_dim, 1, bias=True)
# Trainable embeddings — list of nn.Module (MLX discovers these)
self.embeddings: List[ExpertEmbedding] = []
# Parallel ID list (same order)
self._emb_ids: List[str] = []
def _id_to_idx(self, eid: str) -> int:
return self._emb_ids.index(eid)
def add_expert(self, expert_id: str, init_embedding: Optional[mx.array] = None):
emb = ExpertEmbedding(self.d, init=init_embedding)
mx.eval(emb.parameters())
self.embeddings.append(emb)
self._emb_ids.append(expert_id)
def remove_expert(self, expert_id: str):
if expert_id not in self._emb_ids:
return
idx = self._id_to_idx(expert_id)
self.embeddings.pop(idx)
self._emb_ids.pop(idx)
def get_embedding(self, expert_id: str) -> mx.array:
return self.embeddings[self._id_to_idx(expert_id)].embedding
def set_embedding(self, expert_id: str, emb: mx.array):
self.embeddings[self._id_to_idx(expert_id)].embedding = emb
def __call__(self, x: mx.array, expert_ids: List[str]):
"""
Returns:
routing_weights: (B, L, N) sparse softmax-normalized
raw_scores: (B, L, N) cosine similarities
density: (B, L) active expert count per token
"""
B, L, D = x.shape
N = len(expert_ids)
if N == 0:
z = mx.zeros((B, L, 1))
return z[:, :, :0], z[:, :, :0], mx.zeros((B, L))
# Project input to routing space and normalize
h = self.proj(x) # (B, L, d)
h_norm = h / (mx.linalg.norm(h, axis=-1, keepdims=True) + 1e-8)
# Stack expert embeddings into matrix
E = mx.stack([self.embeddings[self._emb_ids.index(eid)].embedding
for eid in expert_ids], axis=0) # (N, d)
E_norm = E / (mx.linalg.norm(E, axis=-1, keepdims=True) + 1e-8)
raw_scores = h_norm @ E_norm.T # (B, L, N)
# Adaptive per-token threshold
threshold = mx.sigmoid(self.threshold_head(x)) # (B, L, 1)
gate_mask = (raw_scores > threshold).astype(mx.float32)
# Guarantee top-1 always active
best_idx = mx.argmax(raw_scores, axis=-1) # (B, L)
best_oh = one_hot(best_idx, N) # (B, L, N)
gate_mask = mx.maximum(gate_mask, best_oh)
# Cap maximum active experts
max_k = self.config.max_experts_per_token
if max_k < N:
sorted_idx = mx.argsort(-raw_scores, axis=-1)
rank = mx.argsort(sorted_idx, axis=-1)
gate_mask = gate_mask * (rank < max_k).astype(mx.float32)
# Softmax over active experts
masked = raw_scores * gate_mask + (1.0 - gate_mask) * (-1e9)
routing_weights = mx.softmax(masked, axis=-1) * gate_mask
density = gate_mask.sum(axis=-1)
return routing_weights, raw_scores, density
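
# Illustrative routing sketch (a hypothetical helper, not used by training):
# registers two experts and checks the routing invariants documented above:
# every token activates at least one expert (top-1 guarantee), and the active
# weights renormalize to roughly 1 per token.
def _demo_router():
    cfg = MicroExpertConfig()
    router = AdaptiveRouter(model_dim=16, config=cfg)
    router.add_expert("expert_a")
    router.add_expert("expert_b")
    x = mx.random.normal((1, 4, 16))
    weights, scores, density = router(x, ["expert_a", "expert_b"])
    assert mx.all(density >= 1).item()
    assert abs(weights.sum(axis=-1).mean().item() - 1.0) < 1e-3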
# ==========================================
# 6. UTILITY: zero a nested grad tree
# ==========================================
def _zero_tree(tree):
"""Recursively zero all mx.arrays in a nested structure."""
if isinstance(tree, mx.array):
return mx.zeros_like(tree)
elif isinstance(tree, dict):
return {k: _zero_tree(v) for k, v in tree.items()}
elif isinstance(tree, list):
return [_zero_tree(v) for v in tree]
return tree
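
# e.g. _zero_tree({"w": mx.ones((2,)), "sub": [mx.ones((3,))]}) returns the
# same structure with every array replaced by zeros; it is used below to
# null out the gradients of frozen experts.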
# ==========================================
# 7. MoE LAYER
# ==========================================
class MicroExpertsMoELayer(nn.Module):
def __init__(self, model_dim: int, config: MicroExpertConfig, layer_idx: int):
super().__init__()
self.model_dim = model_dim
self.config = config
self.layer_idx = layer_idx
self.router = AdaptiveRouter(model_dim, config)
self._variance_ema: Dict[str, float] = {}
self._variance_ema_sq: Dict[str, float] = {}
# Expert modules — list for MLX parameter discovery
self.expert_modules: List[Expert] = []
self._expert_id_list: List[str] = []
self._expert_meta: Dict[str, ExpertMeta] = {}
self._lifecycle_log: List[str] = []
self.global_step: int = 0
# Cached from forward pass (detached)
self._last_routing_weights: Optional[mx.array] = None
self._last_density: Optional[mx.array] = None
self._last_input: Optional[mx.array] = None
# FIX: Cache expert outputs to avoid redundant forward in interference
self._last_expert_outputs: Optional[List[mx.array]] = None
# Frozen expert tracking
self._frozen_eids: set = set()
# FIX: Density drift tracking
self._density_ema: float = 1.0
self._density_var: float = 1.0
self._drift_detected: bool = False
# Create initial monolith
self._create_expert(tier=config.init_tier)
# --- Helpers ---
@property
def expert_ids(self) -> List[str]:
return list(self._expert_id_list)
def _eid_to_index(self, eid: str) -> int:
return self._expert_id_list.index(eid)
def _get_expert(self, eid: str) -> Expert:
return self.expert_modules[self._eid_to_index(eid)]
def _tier_to_hidden(self, tier: int) -> int:
t = min(tier, len(self.config.tier_hidden_dims) - 1)
return self.config.tier_hidden_dims[t]
def _expert_param_count(self, tier: int) -> int:
return 3 * self.model_dim * self._tier_to_hidden(tier)
def _total_params(self) -> int:
return sum(self._expert_param_count(m.tier) for m in self._expert_meta.values())
def _make_id(self) -> str:
return uuid.uuid4().hex[:12]
"""
def _copy_optimizer_state(self, optimizer, parent_idx: int, child_eid: str):
try:
layers_state = optimizer.state.get("layers", [])
if self.layer_idx >= len(layers_state):
return
moe_state = layers_state[self.layer_idx].get("moe", {})
expert_states = moe_state.get("expert_modules", [])
if parent_idx >= len(expert_states):
return
parent_state = expert_states[parent_idx]
child_idx = self._eid_to_index(child_eid)
# Grow the list if needed
while len(expert_states) <= child_idx:
expert_states.append({})
# Deep copy the parent state
import copy
expert_states[child_idx] = copy.deepcopy(parent_state)
except (KeyError, IndexError, TypeError):
pass
"""
def _copy_optimizer_state(self, optimizer, parent_idx: int, children_eids: list):
"""Copy parent's optimizer state to children, then rebuild list."""
try:
layers_state = optimizer.state.get("layers", [])
if self.layer_idx >= len(layers_state):
return
moe_state = layers_state[self.layer_idx].get("moe", {})
expert_states = moe_state.get("expert_modules", [])
if parent_idx >= len(expert_states):
return
import copy
parent_state = copy.deepcopy(expert_states[parent_idx])
# Build new list matching current expert_modules order
new_states = []
for i, eid in enumerate(self._expert_id_list):
if eid in children_eids:
new_states.append(copy.deepcopy(parent_state))
elif i < len(expert_states):
new_states.append(expert_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
except (KeyError, IndexError, TypeError):
pass
# --- Expert creation / removal ---
def _create_expert(
self, tier: int,
parent_id: Optional[str] = None,
init_weights_from: Optional[Expert] = None,
noise_scale: float = 0.0,
frozen_steps: int = 0,
init_embedding: Optional[mx.array] = None,
) -> str:
eid = self._make_id()
hidden = self._tier_to_hidden(tier)
expert = Expert(self.model_dim, hidden)
if init_weights_from is not None:
src = dict(tree_flatten(init_weights_from.parameters()))
dst = dict(tree_flatten(expert.parameters()))
pairs = []
for k in dst:
if k in src and src[k].shape == dst[k].shape:
w = src[k]
if noise_scale > 0:
w = w + mx.random.normal(w.shape) * noise_scale * (mx.abs(w).mean() + 1e-8)
pairs.append((k, w))
if pairs:
expert.load_weights(pairs)
mx.eval(expert.parameters())
self.expert_modules.append(expert)
self._expert_id_list.append(eid)
gen = 0
if parent_id and parent_id in self._expert_meta:
gen = self._expert_meta[parent_id].generation + 1
self._expert_meta[eid] = ExpertMeta(
expert_id=eid, tier=tier, hidden_dim=hidden,
frozen_steps=frozen_steps, parent_id=parent_id, generation=gen,
)
if frozen_steps > 0:
self._frozen_eids.add(eid)
self.router.add_expert(eid, init_embedding=init_embedding)
return eid
def _remove_expert(self, eid: str):
if eid not in self._expert_id_list:
return
idx = self._eid_to_index(eid)
self.expert_modules.pop(idx)
self._expert_id_list.pop(idx)
self._expert_meta.pop(eid, None)
self._frozen_eids.discard(eid)
self.router.remove_expert(eid)
# --- Forward ---
def __call__(self, x: mx.array) -> mx.array:
B, L, D = x.shape
N = len(self._expert_id_list)
if N == 0:
return mx.zeros_like(x)
routing_weights, raw_scores, density = self.router(x, self._expert_id_list)
# Compute and cache individual expert outputs
expert_outputs = [self.expert_modules[i](x) for i in range(N)]
output = mx.zeros_like(x)
for i in range(N):
w_i = routing_weights[:, :, i:i + 1]
output = output + w_i * expert_outputs[i]
# Cache detached copies for interference computation
self._last_routing_weights = mx.stop_gradient(routing_weights)
self._last_density = mx.stop_gradient(density)
self._last_input = mx.stop_gradient(x)
self._last_expert_outputs = [mx.stop_gradient(eo) for eo in expert_outputs]
return output
# --- Load balance loss ---
def load_balance_loss(self) -> mx.array:
"""
Variance of per-expert activation frequency across the last batch.
Penalizes uneven usage — prevents expert starvation without forcing
uniform routing (which would defeat specialization).
"""
if self._last_routing_weights is None:
return mx.array(0.0)
N = self._last_routing_weights.shape[-1]
if N <= 1:
return mx.array(0.0)
# Per-expert fraction of tokens where it's active (weight > 0.01)
active = (self._last_routing_weights > 0.01).astype(mx.float32)
freq = active.reshape(-1, N).mean(axis=0)
return freq.var()
# --- Frozen gradient zeroing ---
def zero_frozen_grads(self, expert_grads: Any) -> Any:
"""Zero gradients for the expert_modules subtree of frozen experts."""
if not self._frozen_eids or not isinstance(expert_grads, list):
return expert_grads
result = []
for i, g in enumerate(expert_grads):
eid = self._expert_id_list[i] if i < len(self._expert_id_list) else None
if eid and eid in self._frozen_eids:
result.append(_zero_tree(g))
else:
result.append(g)
return result
    def update_density_drift(self):
"""Update density EMA and detect distribution shift spikes."""
if self._last_density is None:
return
cfg = self.config
current = self._last_density.mean().item()
alpha = cfg.density_ema_alpha
# Update EMA of density
old_ema = self._density_ema
self._density_ema = (1 - alpha) * self._density_ema + alpha * current
diff = current - old_ema
self._density_var = (1 - alpha) * self._density_var + alpha * diff * diff
# Z-score spike detection
std = math.sqrt(max(self._density_var, 1e-8))
z = (current - self._density_ema) / std
self._drift_detected = z > cfg.density_spike_z
if self._drift_detected:
msg = (f"[step {self.global_step}][L{self.layer_idx}] "
f"DRIFT density={current:.1f} ema={self._density_ema:.1f} z={z:.1f}")
self._lifecycle_log.append(msg)
print(msg)
def compute_interference(self) -> Dict[str, float]:
if (self._last_routing_weights is None or self._last_input is None
or self._last_expert_outputs is None):
return {}
x = self._last_input
rw = self._last_routing_weights
B, L, D = x.shape
N = len(self._expert_id_list)
if N == 0:
return {}
T = min(self.config.interference_subsample, B * L)
rw_flat = rw.reshape(-1, N)[:T]
# Use cached expert outputs instead of re-running forward passes
expert_outs_flat = [eo.reshape(-1, D)[:T] for eo in self._last_expert_outputs]
# Combined mixture output on subsample
combined = mx.zeros((T, D))
for i in range(N):
combined = combined + rw_flat[:, i:i + 1] * expert_outs_flat[i]
combined = mx.stop_gradient(combined)
interference = {}
for i in range(N):
eid = self._expert_id_list[i]
w_i = rw_flat[:, i]
e_out = expert_outs_flat[i]
active = (w_i > 0.01).astype(mx.float32)
n_active = active.sum().item()
if n_active < 1.0:
interference[eid] = 0.0
continue
diff_norm = mx.linalg.norm(combined - e_out, axis=-1)
e_norm = mx.linalg.norm(e_out, axis=-1) + 1e-8
relative = diff_norm / e_norm
score = (relative * w_i * active).sum() / (n_active + 1e-8)
interference[eid] = score.item()
mx.eval(list(interference.values()))
return interference
def _compute_monolith_split_scores(self) -> Dict[str, float]:
scores = {}
if self._last_expert_outputs is None or not self.config.monolith_split_enabled:
return scores
cfg = self.config
for i, eid in enumerate(self._expert_id_list):
if i >= len(self._last_expert_outputs):
continue
eo = self._last_expert_outputs[i]
norms = mx.linalg.norm(eo.reshape(-1, eo.shape[-1]), axis=-1)
var = norms.var().item()
alpha = cfg.monolith_variance_ema_alpha
prev_mean = self._variance_ema.get(eid, var)
prev_sq = self._variance_ema_sq.get(eid, var * var)
new_mean = (1 - alpha) * prev_mean + alpha * var
new_sq = (1 - alpha) * prev_sq + alpha * var * var
self._variance_ema[eid] = new_mean
self._variance_ema_sq[eid] = new_sq
running_std = math.sqrt(max(new_sq - new_mean * new_mean, 1e-8))
z = (var - new_mean) / running_std
scores[eid] = z
return scores
# --- Lifecycle ---
def lifecycle_step(self, optimizer=None):
        self.update_density_drift()
interference = self.compute_interference()
events = []
all_ids = list(self._expert_id_list) # snapshot before mutations
monolith_scores = self._compute_monolith_split_scores()
N = len(all_ids)
for eid in all_ids:
meta = self._expert_meta.get(eid)
if meta is None:
continue
meta.age += 1
if meta.cooldown > 0:
meta.cooldown -= 1
if meta.frozen_steps > 0:
meta.frozen_steps -= 1
if meta.frozen_steps == 0:
self._frozen_eids.discard(eid)
# Routing stats from cached data
if self._last_routing_weights is not None and eid in self._expert_id_list:
idx = self._eid_to_index(eid)
if idx < self._last_routing_weights.shape[-1]:
w = self._last_routing_weights[:, :, idx]
meta.avg_routing_weight = (
0.95 * meta.avg_routing_weight + 0.05 * w.mean().item()
)
meta.avg_activation_freq = (
0.95 * meta.avg_activation_freq
+ 0.05 * (w > 0.01).astype(mx.float32).mean().item()
)
# Interference EMAs
intf = interference.get(eid, 0.0)
af = self.config.ema_fast_alpha
asl = self.config.ema_slow_alpha
meta.ema_interference_fast = (1 - af) * meta.ema_interference_fast + af * intf
meta.ema_interference_slow = (1 - asl) * meta.ema_interference_slow + asl * intf
diff = intf - meta.ema_interference_slow
meta.ema_interference_var = 0.99 * meta.ema_interference_var + 0.01 * diff * diff
# Score by cannibalization z-score
scored = []
for eid in all_ids:
meta = self._expert_meta.get(eid)
if meta is None or eid not in self._expert_id_list:
continue
std = math.sqrt(max(meta.ema_interference_var, 1e-8))
intf_z = (meta.ema_interference_fast - meta.ema_interference_slow) / std
mono_z = monolith_scores.get(eid, 0.0)
if N <= 2:
z = mono_z
else:
z = max(intf_z, mono_z)
scored.append((eid, z, meta))
scored.sort(key=lambda t: -t[1])
# FIX: Lower split threshold during detected drift — system should react faster
effective_split_threshold = self.config.split_threshold
if self._drift_detected:
effective_split_threshold *= 0.7 # 30 % more sensitive during drift
# Split / Death
touched = set()
for eid, z_score, meta in scored:
if eid in touched or eid not in self._expert_id_list:
continue
if meta.age < self.config.min_expert_age or meta.cooldown > 0:
continue
budget_usage = self._total_params() / self.config.max_params_per_layer
if budget_usage > 0.7:
continue
threshold = self.config.monolith_variance_z_threshold if N <= 2 else effective_split_threshold
if (z_score > threshold
and len(self._expert_id_list) < self.config.max_experts_per_layer
and (self._total_params() + self._expert_param_count(meta.tier)
< self.config.max_params_per_layer)):
                events.append(self._do_split(eid, optimizer=optimizer))
touched.add(eid)
continue
if (meta.avg_routing_weight < self.config.death_threshold
and len(self._expert_id_list) > 1):
events.append(self._do_death(eid, optimizer=optimizer))
touched.add(eid)
continue
events.extend(self._check_merges(touched, optimizer=optimizer))
for e in events:
msg = f"[step {self.global_step}][L{self.layer_idx}] {e}"
self._lifecycle_log.append(msg)
print(msg)
return events
# --- Importance-proportional preserver freeze ---
def _compute_freeze_steps(self, meta: ExpertMeta) -> int:
cfg = self.config
importance = max(0.0, min(1.0, meta.avg_routing_weight * 10.0))
freeze = int(
cfg.preserver_base_freeze_steps
+ importance * (cfg.preserver_max_freeze_steps - cfg.preserver_base_freeze_steps)
)
return freeze
"""
def _do_split(self, eid: str) -> str:
meta = self._expert_meta[eid]
parent = self._get_expert(eid)
parent_emb = self.router.get_embedding(eid)
freeze_steps = self._compute_freeze_steps(meta)
preserver_id = self._create_expert(
tier=meta.tier, parent_id=eid,
init_weights_from=parent, noise_scale=0.0,
frozen_steps=freeze_steps,
init_embedding=parent_emb,
)
adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1
mx.eval(adapter_emb)
adapter_id = self._create_expert(
tier=meta.tier, parent_id=eid,
init_weights_from=parent,
noise_scale=self.config.adapter_noise_scale,
frozen_steps=0, init_embedding=adapter_emb,
)
self._remove_expert(eid)
self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps
self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps
return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> "
f"preserver {preserver_id[:8]} (frozen={freeze_steps}) "
f"+ adapter {adapter_id[:8]}")
"""
"""
def _do_split(self, eid: str, optimizer=None) -> str:
meta = self._expert_meta[eid]
parent = self._get_expert(eid)
parent_emb = self.router.get_embedding(eid)
parent_idx = self._eid_to_index(eid)
parent_opt_state = None
parent_emb_opt_state = None
if optimizer is not None:
try:
import copy
layers_state = optimizer.state.get("layers", [])
moe_state = layers_state[self.layer_idx].get("moe", {})
expert_states = moe_state.get("expert_modules", [])
if parent_idx < len(expert_states):
parent_opt_state = copy.deepcopy(expert_states[parent_idx])
# Save parent router embedding state
router_state = moe_state.get("router", {})
emb_states = router_state.get("embeddings", [])
if parent_idx < len(emb_states):
parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx])
except (KeyError, IndexError, TypeError):
pass
freeze_steps = self._compute_freeze_steps(meta)
preserver_id = self._create_expert(
tier=meta.tier, parent_id=eid,
init_weights_from=parent, noise_scale=0.0,
frozen_steps=freeze_steps,
init_embedding=parent_emb,
)
adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1
mx.eval(adapter_emb)
adapter_id = self._create_expert(
tier=meta.tier, parent_id=eid,
init_weights_from=parent,
noise_scale=self.config.adapter_noise_scale,
frozen_steps=0, init_embedding=adapter_emb,
)
# Copy optimizer state before removing parent
if optimizer is not None:
self._copy_optimizer_state(optimizer, parent_idx, preserver_id)
self._copy_optimizer_state(optimizer, parent_idx, adapter_id)
self._remove_expert(eid)
if optimizer is not None and parent_opt_state is not None:
try:
import copy
layers_state = optimizer.state["layers"]
moe_state = layers_state[self.layer_idx]["moe"]
old_states = moe_state.get("expert_modules", [])
new_states = []
for i, expert_eid in enumerate(self._expert_id_list):
if expert_eid == preserver_id or expert_eid == adapter_id:
new_states.append(copy.deepcopy(parent_opt_state))
elif i < len(old_states):
new_states.append(old_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
except (KeyError, IndexError, TypeError):
pass
if optimizer is not None:
try:
layers_state = optimizer.state.get("layers", [])
expert_states = layers_state[self.layer_idx]["moe"]["expert_modules"]
if parent_idx < len(expert_states):
expert_states.pop(parent_idx)
except (KeyError, IndexError, TypeError):
pass
self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps
self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps
return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> "
f"preserver {preserver_id[:8]} (frozen={freeze_steps}) "
f"+ adapter {adapter_id[:8]}")
"""
def _do_split(self, eid: str, optimizer=None) -> str:
meta = self._expert_meta[eid]
parent = self._get_expert(eid)
parent_emb = self.router.get_embedding(eid)
parent_idx = self._eid_to_index(eid)
parent_opt_state = None
parent_emb_opt_state = None
if optimizer is not None:
try:
import copy
layers_state = optimizer.state.get("layers", [])
moe_state = layers_state[self.layer_idx].get("moe", {})
expert_states = moe_state.get("expert_modules", [])
if parent_idx < len(expert_states):
parent_opt_state = copy.deepcopy(expert_states[parent_idx])
router_state = moe_state.get("router", {})
emb_states = router_state.get("embeddings", [])
if parent_idx < len(emb_states):
parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx])
except (KeyError, IndexError, TypeError):
pass
freeze_steps = self._compute_freeze_steps(meta)
preserver_id = self._create_expert(
tier=meta.tier, parent_id=eid,
init_weights_from=parent, noise_scale=0.0,
frozen_steps=freeze_steps,
init_embedding=parent_emb,
)
adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1
mx.eval(adapter_emb)
adapter_id = self._create_expert(
tier=meta.tier, parent_id=eid,
init_weights_from=parent,
noise_scale=self.config.adapter_noise_scale,
frozen_steps=0, init_embedding=adapter_emb,
)
self._remove_expert(eid)
if optimizer is not None and parent_opt_state is not None:
try:
import copy
layers_state = optimizer.state["layers"]
moe_state = layers_state[self.layer_idx]["moe"]
old_states = moe_state.get("expert_modules", [])
new_states = []
for i, expert_eid in enumerate(self._expert_id_list):
if expert_eid == preserver_id or expert_eid == adapter_id:
new_states.append(copy.deepcopy(parent_opt_state))
elif i < len(old_states):
new_states.append(old_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
# Rebuild router embeddings state
router_state = moe_state.get("router", {})
old_emb_states = router_state.get("embeddings", [])
new_emb_states = []
for i, emb_eid in enumerate(self.router._emb_ids):
if emb_eid == preserver_id or emb_eid == adapter_id:
if parent_emb_opt_state is not None:
new_emb_states.append(copy.deepcopy(parent_emb_opt_state))
else:
new_emb_states.append({})
elif i < len(old_emb_states):
new_emb_states.append(old_emb_states[i])
else:
new_emb_states.append({})
router_state["embeddings"] = new_emb_states
except (KeyError, IndexError, TypeError):
pass
self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps
self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps
return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> "
f"preserver {preserver_id[:8]} (frozen={freeze_steps}) "
f"+ adapter {adapter_id[:8]}")
def _do_death(self, eid: str, optimizer=None) -> str:
meta = self._expert_meta[eid]
info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})"
self._remove_expert(eid)
if optimizer is not None:
try:
layers_state = optimizer.state.get("layers", [])
if self.layer_idx < len(layers_state):
moe_state = layers_state[self.layer_idx].get("moe", {})
old_states = moe_state.get("expert_modules", [])
new_states = []
for i, expert_eid in enumerate(self._expert_id_list):
if i < len(old_states):
new_states.append(old_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
# Rebuild router embeddings state
router_state = moe_state.get("router", {})
old_emb_states = router_state.get("embeddings", [])
new_emb_states = []
for i in range(len(self.router._emb_ids)):
if i < len(old_emb_states):
new_emb_states.append(old_emb_states[i])
else:
new_emb_states.append({})
router_state["embeddings"] = new_emb_states
except (KeyError, IndexError, TypeError):
pass
return info
"""
def _do_death(self, eid: str, optimizer=None) -> str:
meta = self._expert_meta[eid]
info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})"
self._remove_expert(eid)
if optimizer is not None:
try:
layers_state = optimizer.state.get("layers", [])
if self.layer_idx < len(layers_state):
moe_state = layers_state[self.layer_idx].get("moe", {})
old_states = moe_state.get("expert_modules", [])
new_states = []
for i, expert_eid in enumerate(self._expert_id_list):
if i < len(old_states):
new_states.append(old_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
except (KeyError, IndexError, TypeError):
pass
return info
"""
def _average_expert_weights(self, expert_a: Expert, expert_b: Expert) -> List[Tuple[str, mx.array]]:
"""Average the weights of two same-shape experts."""
src_a = dict(tree_flatten(expert_a.parameters()))
src_b = dict(tree_flatten(expert_b.parameters()))
pairs = []
for k in src_a:
if k in src_b and src_a[k].shape == src_b[k].shape:
pairs.append((k, (src_a[k] + src_b[k]) / 2.0))
return pairs
def _check_merges(self, touched: set, optimizer=None) -> List[str]:
events = []
merged = set()
ids = list(self._expert_id_list)
cfg = self.config
        # Co-activation for the tier-gravity merge is computed inline in
        # Force 3 below, directly from the cached routing weights.
def _can_merge(eid):
return (eid not in merged and eid not in touched
and eid in self._expert_id_list
and (meta := self._expert_meta.get(eid)) is not None
and meta.age >= cfg.min_expert_age
and meta.cooldown == 0)
def _do_merge(eid_a, eid_b, meta_a, meta_b, reason: str, optimizer=None) -> Optional[str]:
"""Execute a merge and return event string, or None if budget exceeded."""
new_tier = min(meta_a.tier + 1, len(cfg.tier_hidden_dims) - 1)
cost = self._expert_param_count(new_tier)
freed = (self._expert_param_count(meta_a.tier)
+ self._expert_param_count(meta_b.tier))
if self._total_params() - freed + cost > cfg.max_params_per_layer:
return None
emb_a = self.router.get_embedding(eid_a)
emb_b = self.router.get_embedding(eid_b)
avg_emb = (emb_a + emb_b) / 2.0
mx.eval(avg_emb)
if new_tier == meta_a.tier:
merged_expert_id = self._create_expert(
tier=new_tier, parent_id=eid_a,
init_weights_from=self._get_expert(eid_a),
init_embedding=avg_emb,
)
# Overwrite with averaged weights
avg_weights = self._average_expert_weights(
self._get_expert(eid_a), self._get_expert(eid_b))
if avg_weights:
self._get_expert(merged_expert_id).load_weights(avg_weights)
mx.eval(self._get_expert(merged_expert_id).parameters())
else:
# Tier-up merge: different hidden dim, can't average weights
merged_expert_id = self._create_expert(
tier=new_tier, parent_id=eid_a,
init_embedding=avg_emb,
)
self._expert_meta[merged_expert_id].cooldown = cfg.cooldown_steps
self._remove_expert(eid_a)
self._remove_expert(eid_b)
merged.add(eid_a)
merged.add(eid_b)
"""
if optimizer is not None:
try:
layers_state = optimizer.state.get("layers", [])
if self.layer_idx < len(layers_state):
moe_state = layers_state[self.layer_idx].get("moe", {})
old_states = moe_state.get("expert_modules", [])
new_states = []
for i, expert_eid in enumerate(self._expert_id_list):
if expert_eid == merged_expert_id:
new_states.append({}) # fresh state, no momentum to copy
elif i < len(old_states):
new_states.append(old_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
except (KeyError, IndexError, TypeError):
pass
"""
if optimizer is not None:
try:
layers_state = optimizer.state.get("layers", [])
if self.layer_idx < len(layers_state):
moe_state = layers_state[self.layer_idx].get("moe", {})
# Rebuild expert_modules state
old_states = moe_state.get("expert_modules", [])
new_states = []
for i, expert_eid in enumerate(self._expert_id_list):
if expert_eid == merged_expert_id:
new_states.append({})
elif i < len(old_states):
new_states.append(old_states[i])
else:
new_states.append({})
moe_state["expert_modules"] = new_states
# Rebuild router embeddings state
router_state = moe_state.get("router", {})
old_emb_states = router_state.get("embeddings", [])
new_emb_states = []
for i in range(len(self.router._emb_ids)):
if i < len(old_emb_states):
new_emb_states.append(old_emb_states[i])
else:
new_emb_states.append({})
router_state["embeddings"] = new_emb_states
except (KeyError, IndexError, TypeError):
pass
return (f"MERGE({reason}) {eid_a[:8]}+{eid_b[:8]} (T{meta_a.tier}) "
f"-> {merged_expert_id[:8]} (T{new_tier})")
# --- Force 1: Fragment merge (original: co-route + both weak) ---
for i, eid_a in enumerate(ids):
if not _can_merge(eid_a):
continue
meta_a = self._expert_meta[eid_a]
for j in range(i + 1, len(ids)):
eid_b = ids[j]
if not _can_merge(eid_b):
continue
meta_b = self._expert_meta[eid_b]
if meta_a.tier != meta_b.tier:
continue
emb_a = self.router.get_embedding(eid_a)
emb_b = self.router.get_embedding(eid_b)
cos = ((emb_a * emb_b).sum()
/ (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8))
both_weak = (meta_a.avg_routing_weight < cfg.merge_weakness_threshold
and meta_b.avg_routing_weight < cfg.merge_weakness_threshold)
if cos.item() > cfg.merge_co_route_threshold and both_weak:
result = _do_merge(eid_a, eid_b, meta_a, meta_b, "fragment", optimizer=optimizer)
if result:
events.append(result)
break
# --- Force 2: Capacity-pressure merge ---
budget_frac = self._total_params() / cfg.max_params_per_layer
if budget_frac > cfg.merge_capacity_pressure_frac:
# Find weakest same-tier pair with highest cosine similarity
candidates = []
for i, eid_a in enumerate(ids):
if not _can_merge(eid_a):
continue
meta_a = self._expert_meta.get(eid_a)
if meta_a is None:
continue
for j in range(i + 1, len(ids)):
eid_b = ids[j]
if not _can_merge(eid_b):
continue
meta_b = self._expert_meta.get(eid_b)
if meta_b is None or meta_a.tier != meta_b.tier:
continue
emb_a = self.router.get_embedding(eid_a)
emb_b = self.router.get_embedding(eid_b)
cos = ((emb_a * emb_b).sum()
/ (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8))
combined_w = meta_a.avg_routing_weight + meta_b.avg_routing_weight
# Score: high cosine + low combined weight = best merge candidate
score = cos.item() - combined_w
candidates.append((score, eid_a, eid_b, meta_a, meta_b))
candidates.sort(key=lambda t: -t[0])
for score, eid_a, eid_b, meta_a, meta_b in candidates:
if not _can_merge(eid_a) or not _can_merge(eid_b):
continue
result = _do_merge(eid_a, eid_b, meta_a, meta_b, "capacity",optimizer=optimizer)
if result:
events.append(result)
# Only do one capacity merge per lifecycle step to avoid cascades
break
# --- Force 3: Tier-gravity merge (same-tier co-activate frequently) ---
if self._last_routing_weights is not None:
N = self._last_routing_weights.shape[-1]
act_flat = (self._last_routing_weights > 0.01).astype(mx.float32).reshape(-1, N)
total_tokens = act_flat.shape[0]
for i, eid_a in enumerate(ids):
if not _can_merge(eid_a):
continue
meta_a = self._expert_meta.get(eid_a)
if meta_a is None:
continue
idx_a = self._eid_to_index(eid_a) if eid_a in self._expert_id_list else None
if idx_a is None or idx_a >= N:
continue
for j in range(i + 1, len(ids)):
eid_b = ids[j]
if not _can_merge(eid_b):
continue
meta_b = self._expert_meta.get(eid_b)
if meta_b is None or meta_a.tier != meta_b.tier:
continue
idx_b = self._eid_to_index(eid_b) if eid_b in self._expert_id_list else None
if idx_b is None or idx_b >= N:
continue
# Co-activation: fraction of tokens where both are active
both_active = (act_flat[:, idx_a] * act_flat[:, idx_b]).mean().item()
emb_a = self.router.get_embedding(eid_a)
emb_b = self.router.get_embedding(eid_b)
cos = ((emb_a * emb_b).sum()
/ (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8))
if (both_active > cfg.merge_tier_gravity_min_co_activation
and cos.item() > cfg.merge_tier_gravity_co_route):
result = _do_merge(eid_a, eid_b, meta_a, meta_b, "tier-gravity", optimizer=optimizer)
if result:
events.append(result)
break
return events
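
# Lifecycle sketch (a hypothetical helper, never called by the trainer): runs
# a few forward passes through a standalone MoE layer and invokes
# lifecycle_step() so the split/merge/death bookkeeping executes end to end.
# With optimizer=None the optimizer-state surgery is simply skipped.
def _demo_lifecycle():
    cfg = MicroExpertConfig(min_expert_age=0, cooldown_steps=0)
    layer = MicroExpertsMoELayer(model_dim=32, config=cfg, layer_idx=0)
    for step in range(3):
        layer.global_step = step
        _ = layer(mx.random.normal((2, 8, 32)))
        layer.lifecycle_step()
    print(layer.expert_ids)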
# ==========================================
# 8. MODEL COMPONENTS
# ==========================================
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.n_kv_heads = args.n_kv_heads
self.head_dim = args.dim // args.n_heads
self.scale = self.head_dim ** -0.5
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta)
def __call__(self, x, mask=None):
B, L, D = x.shape
queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask)
return self.wo(output.transpose(0, 2, 1, 3).reshape(B, L, -1))
class MicroExpertsBlock(nn.Module):
def __init__(self, args: ModelArgs, me_config: MicroExpertConfig, layer_idx: int):
super().__init__()
self.attention = Attention(args)
self.moe = MicroExpertsMoELayer(args.dim, me_config, layer_idx)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def __call__(self, x, mask=None):
h = x + self.attention(self.attention_norm(x), mask)
return h + self.moe(self.ffn_norm(h))
class MicroExpertsModel(nn.Module):
def __init__(self, args: ModelArgs, me_config: MicroExpertConfig):
super().__init__()
self.args = args
self.me_config = me_config
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
self.layers = [
MicroExpertsBlock(args, me_config, layer_idx=i)
for i in range(args.n_layers)
]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
def __call__(self, x):
L = x.shape[1]
mask = nn.MultiHeadAttention.create_additive_causal_mask(L).astype(mx.float32)
mask = mask[None, None, :, :]
h = self.tok_embeddings(x)
for layer in self.layers:
h = layer(h, mask)
return self.output(self.norm(h))
def set_global_step(self, step: int):
for layer in self.layers:
layer.moe.global_step = step
def run_lifecycle(self, optimizer=None):
all_events = []
for layer in self.layers:
all_events.extend(layer.moe.lifecycle_step(optimizer=optimizer))
return all_events
def total_load_balance_loss(self) -> mx.array:
"""Sum of per-layer activation frequency variance."""
lb = mx.array(0.0)
for layer in self.layers:
lb = lb + layer.moe.load_balance_loss()
return lb
def zero_frozen_grads(self, grads):
"""Walk gradient tree, zero frozen expert parameters."""
if not isinstance(grads, dict) or "layers" not in grads:
return grads
new_layers = []
for i, lg in enumerate(grads["layers"]):
if (isinstance(lg, dict) and "moe" in lg
and isinstance(lg["moe"], dict)
and "expert_modules" in lg["moe"]):
moe = self.layers[i].moe
fixed = moe.zero_frozen_grads(lg["moe"]["expert_modules"])
new_moe = dict(lg["moe"])
new_moe["expert_modules"] = fixed
new_lg = dict(lg)
new_lg["moe"] = new_moe
new_layers.append(new_lg)
else:
new_layers.append(lg)
new_grads = dict(grads)
new_grads["layers"] = new_layers
return new_grads
def expert_summary(self) -> str:
lines = []
total_e, total_p = 0, 0
for i, layer in enumerate(self.layers):
moe = layer.moe
n = len(moe._expert_id_list)
p = moe._total_params()
total_e += n
total_p += p
tiers = defaultdict(int)
for m in moe._expert_meta.values():
tiers[m.tier] += 1
ts = " ".join(f"T{t}:{c}" for t, c in sorted(tiers.items()))
frozen = sum(1 for eid in moe._expert_id_list if eid in moe._frozen_eids)
drift = " DRIFT" if moe._drift_detected else ""
lines.append(
f" L{i:2d}: {n:3d} experts ({ts}) | {p/1e6:.1f}M | "
f"{frozen} frozen | d={moe._density_ema:.1f}{drift}")
lines.append(f" TOTAL: {total_e} experts | {total_p/1e6:.1f}M MoE params")
return "\n".join(lines)
def save_meta(self, path: str):
data = {}
for i, layer in enumerate(self.layers):
moe = layer.moe
data[f"layer_{i}"] = {
"expert_ids": list(moe._expert_id_list),
"experts": {eid: m.to_dict() for eid, m in moe._expert_meta.items()},
"density_ema": moe._density_ema,
}
with open(path, "w") as f:
json.dump(data, f, indent=2)
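
# Minimal end-to-end sketch (illustrative; these tiny ModelArgs values are
# hypothetical, not the training defaults): builds a 2-layer model, runs one
# forward pass, and prints the initial expert layout.
def _demo_tiny_model():
    args = ModelArgs(dim=32, n_layers=2, n_heads=4, n_kv_heads=4,
                     vocab_size=128, max_seq_len=16)
    model = MicroExpertsModel(args, MicroExpertConfig())
    logits = model(mx.zeros((1, 8), dtype=mx.int32))
    assert logits.shape == (1, 8, 128)
    print(model.expert_summary())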
# ==========================================
# 9. DATA STREAMS
# ==========================================
def stream_gutenberg(tokenizer, batch_size: int, seq_len: int):
print("Connecting to Gutenberg stream...")
dataset = load_dataset("teknium/OpenHermes-2.5", split="train", streaming=True,)
dataset_iter = iter(dataset)
buffers = [[] for _ in range(batch_size)]
while True:
for i in range(batch_size):
while len(buffers[i]) < seq_len + 1:
try:
row = next(dataset_iter)
except StopIteration:
dataset_iter = iter(dataset)
row = next(dataset_iter)
text = row.get("conversations", "")
if isinstance(text, list):
parts = []
for msg in text:
role = msg.get("from", "")
content = msg.get("value", [])
if isinstance(content, str):
parts.append(f"{role}\n{content}")
text = "\n".join(parts)
if not text or len(text) < 10:
continue
buffers[i].extend(tokenizer.encode(text))
batch = []
for i in range(batch_size):
batch.append(buffers[i][:seq_len + 1])
buffers[i] = buffers[i][seq_len:]
yield mx.array(batch, dtype=mx.int32)
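
# Each yielded batch has shape (batch_size, seq_len + 1): the extra token
# provides the shifted next-token targets used by loss_fn below.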
def stream_domain_files(tokenizer, data_dir: str, batch_size: int, seq_len: int):
files = sorted(glob.glob(os.path.join(data_dir, "*.txt")))
if not files:
raise FileNotFoundError(f"No .txt files in {data_dir}")
for fpath in files:
domain = os.path.splitext(os.path.basename(fpath))[0]
print(f"\n{'='*60}")
print(f" ACTIVE LEARNING — Domain: {domain}")
print(f"{'='*60}")
with open(fpath, "r", encoding="utf-8", errors="replace") as f:
text = f.read()
tokens = tokenizer.encode(text)
min_tokens = (seq_len + 1) * batch_size
if len(tokens) < min_tokens:
print(f" Skipping {domain}: {len(tokens)} tokens < {min_tokens} needed")
continue
def batch_gen(toks=tokens, bs=batch_size, sl=seq_len):
while True:
buf = list(toks)
while len(buf) >= bs * (sl + 1):
batch = []
for _ in range(bs):
batch.append(buf[:sl + 1])
buf = buf[sl:]
yield mx.array(batch, dtype=mx.int32)
yield domain, batch_gen()
# ==========================================
# 10. LOSS + CHECKPOINT
# ==========================================
def loss_fn(model, x):
"""Cross-entropy + load balance auxiliary loss."""
logits = model(x)
ce = nn.losses.cross_entropy(logits[:, :-1, :], x[:, 1:], reduction="mean")
lb = model.total_load_balance_loss()
return ce + model.me_config.load_balance_weight * lb
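
# Equivalently, the objective above is
#   L = CE(shifted next-token) + lambda * sum over layers of
#       Var(per-expert activation frequency),
# with lambda = MicroExpertConfig.load_balance_weight (0.01 by default).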
def load_checkpoint(model, path: str):
weights = dict(mx.load(path))
meta_path = path.replace(".npz", ".json")
with open(meta_path, "r") as f:
meta = json.load(f)
for i, layer in enumerate(model.layers):
moe = layer.moe
layer_key = f"layer_{i}"
if layer_key not in meta:
continue
layer_meta = meta[layer_key]
for eid in list(moe._expert_id_list):
moe._remove_expert(eid)
for eid in layer_meta["expert_ids"]:
em = layer_meta["experts"][eid]
tier = em["tier"]
hidden = moe._tier_to_hidden(tier)
expert = Expert(moe.model_dim, hidden)
mx.eval(expert.parameters())
moe.expert_modules.append(expert)
moe._expert_id_list.append(eid)
moe._expert_meta[eid] = ExpertMeta(
expert_id=eid, tier=tier, hidden_dim=hidden,
age=em.get("age", 0),
cooldown=em.get("cooldown", 0),
frozen_steps=em.get("frozen_steps", 0),
ema_interference_fast=em.get("ema_fast", 0.0),
ema_interference_slow=em.get("ema_slow", 0.0),
ema_interference_var=em.get("ema_var", 1.0),
avg_routing_weight=em.get("avg_rw", 0.1),
avg_activation_freq=em.get("avg_af", 0.1),
parent_id=em.get("parent_id"),
generation=em.get("generation", 0),
)
if em.get("frozen_steps", 0) > 0:
moe._frozen_eids.add(eid)
router_key = f"__router__.{i}.{eid}"
init_emb = weights.pop(router_key, None)
moe.router.add_expert(eid, init_embedding=init_emb)
moe._density_ema = layer_meta.get("density_ema", 1.0)
remaining = [(k, v) for k, v in weights.items() if not k.startswith("__router__")]
model.load_weights(remaining, strict=False)
mx.eval(model.parameters())
print(f" Loaded checkpoint from {path}")
def get_latest_checkpoint(checkpoint_dir: str):
if not os.path.exists(checkpoint_dir):
return None, 0
ckpts = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_step_*.npz")))
if not ckpts:
return None, 0
latest = ckpts[-1]
m = re.search(r"step_(\d+)", latest)
return latest, int(m.group(1))
def save_checkpoint(model, step: int, checkpoint_dir: str):
path = os.path.join(checkpoint_dir, f"checkpoint_step_{step}.npz")
save_dict = {}
for k, v in tree_flatten(model.parameters()):
save_dict[k] = v
for i, layer in enumerate(model.layers):
moe = layer.moe
for j, eid in enumerate(moe.router._emb_ids):
save_dict[f"__router__.{i}.{eid}"] = moe.router.embeddings[j].embedding
mx.savez(path, **save_dict)
model.save_meta(path.replace(".npz", ".json"))
print(f" Saved checkpoint {path}")
# ==========================================
# 11. TRAINING LOOP
# ==========================================
def train_loop(model, optimizer, data_iter, tc: TrainConfig,
start_step=0, max_steps=30000, lifecycle_every=10, label="train"):
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
compiled_loss_and_grad = mx.compile(loss_and_grad_fn)
step = start_step
tic = time.time()
topology_changed = False
for batch in data_iter:
if step >= max_steps:
break
model.set_global_step(step)
# After a lifecycle event changes the expert topology (add/remove modules),
if topology_changed:
compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn))
topology_changed = False
try:
loss, grads = compiled_loss_and_grad(model, batch)
except Exception:
loss_and_grad_fn_eager = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn_eager(model, batch)
compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn))
grads = model.zero_frozen_grads(grads)
try:
optimizer.update(model, grads)
except (ValueError, KeyError, IndexError):
# Topology change left stale optimizer state — wipe and retry
optimizer.state = {k: v for k, v in optimizer.state.items() if not isinstance(v, (dict, list))}
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)
if step > 0 and step % lifecycle_every == 0:
events = model.run_lifecycle(optimizer=optimizer)
if events:
topology_changed = True
if step % tc.log_every == 0:
toc = time.time()
n_exp = sum(len(l.moe._expert_id_list) for l in model.layers)
avg_d = sum(
l.moe._last_density.mean().item()
for l in model.layers if l.moe._last_density is not None
) / model.args.n_layers
elapsed = toc - tic
tok_per_sec = (tc.log_every * tc.batch_size * model.args.max_seq_len) / max(elapsed, 1e-6)
print(f"[{label}] Step {step:6d} | Loss {loss.item():.4f} | "
f"Experts {n_exp} | Density {avg_d:.1f} | "
f"{tok_per_sec:.0f} tok/s | {elapsed:.2f}s")
tic = time.time()
if step > 0 and step % tc.summary_every == 0:
print(f"\n--- Expert Summary @ step {step} ---")
print(model.expert_summary())
print()
if step > 0 and step % tc.checkpoint_every == 0:
save_checkpoint(model, step, tc.checkpoint_dir)
step += 1
return step
# ==========================================
# 12. INTERACTIVE SETUP + MAIN
# ==========================================
def prompt_config() -> TrainConfig:
"""Interactive configuration via input() prompts."""
tc = TrainConfig()
print("\n" + "="*60)
print(" MicroExperts — Training Configuration")
print("="*60)
# Mode
print(" 1. pretrain — Gutenberg streaming pretraining")
print(" 2. active_learning — Sequential domain continual learning(not implemented yet)")
print(" 3. inference — Chat with the trained model")
print(" 4. interactive_learning — Chat and learn from your inputs")
print(" 5. train_and_chat — Train with periodic chat breaks")
choice = input("Mode [1]: ").strip()
if choice == "2":
tc.mode = "active_learning"
elif choice == "3":
tc.mode = "inference"
elif choice == "4":
tc.mode = "interactive_learning"
elif choice == "5":
tc.mode = "train_and_chat"
else:
tc.mode = "pretrain"
    # Tokenizer (fixed; not prompted for)
    tc.tokenizer_file = "gutenberg_tokenizer.json"
# Checkpoint dir
cd = input(f"Checkpoint directory [{tc.checkpoint_dir}]: ").strip()
if cd:
tc.checkpoint_dir = cd
# Batch size
bs = input(f"Batch size [{tc.batch_size}]: ").strip()
if bs:
tc.batch_size = int(bs)
# Learning rate
if tc.mode == "pretrain":
default_lr = tc.learning_rate
else:
default_lr = tc.al_learning_rate
lr = input(f"Learning rate [{default_lr}]: ").strip()
if lr:
tc.learning_rate = float(lr)
else:
tc.learning_rate = default_lr
# Max steps
ms = input(f"Max steps [{tc.max_steps}]: ").strip()
if ms:
tc.max_steps = int(ms)
# Resume
resume = input("Resume from checkpoint? [Y/n]: ").strip().lower()
tc._resume = resume != "n"
# Mode-specific
if tc.mode == "active_learning":
dd = input(f"Domain data directory [{tc.al_data_dir}]: ").strip()
if dd:
tc.al_data_dir = dd
spd = input(f"Steps per domain [{tc.al_steps_per_domain}]: ").strip()
if spd:
tc.al_steps_per_domain = int(spd)
print("\n" + "-"*60)
print(f" Mode: {tc.mode}")
print(f" LR: {tc.learning_rate}")
print(f" Batch: {tc.batch_size}")
print(f" Max steps: {tc.max_steps}")
print(f" Checkpoint: {tc.checkpoint_dir}")
print(f" Resume: {tc._resume}")
if tc.mode == "active_learning":
print(f" Data dir: {tc.al_data_dir}")
print(f" Steps/dom: {tc.al_steps_per_domain}")
print(f" M4 budget: 150M params/layer, 128 experts/layer max")
print("-"*60)
confirm = input("Continue? [Y/n]: ").strip().lower()
if confirm == "n":
print("Aborted.")
exit(0)
return tc
def generate(model, tokenizer, prompt: str, max_tokens: int = 256, temperature: float = 0.8):
tokens = tokenizer.encode(prompt)
tokens = mx.array([tokens], dtype=mx.int32)
for _ in range(max_tokens):
logits = model(tokens)
next_logits = logits[:, -1, :] / temperature
next_token = mx.random.categorical(next_logits)
next_token = next_token.reshape(1, 1)
tokens = mx.concatenate([tokens, next_token], axis=1)
mx.eval(tokens)
token_id = next_token.item()
if token_id == tokenizer.eos_token_id:
break
# Print expert usage per layer
print("\n Expert routing:")
for i, layer in enumerate(model.layers):
moe = layer.moe
if moe._last_routing_weights is None:
continue
rw = moe._last_routing_weights
N = rw.shape[-1]
# Average routing weight per expert across all tokens
avg_w = rw.reshape(-1, N).mean(axis=0)
active = (avg_w > 0.01)
parts = []
for j, eid in enumerate(moe._expert_id_list):
if j < N and active[j].item():
meta = moe._expert_meta.get(eid)
tier = meta.tier if meta else "?"
parts.append(f"{eid[:6]}(T{tier} w={avg_w[j].item():.3f})")
if parts:
print(f" L{i:2d}: {' '.join(parts)}")
return tokenizer.decode(tokens[0].tolist())
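
# Usage sketch for generate() (illustrative; assumes a trained checkpoint and
# a matching tokenizer file exist at these hypothetical paths):
#   tok = PreTrainedTokenizerFast(tokenizer_file="gutenberg_tokenizer.json")
#   args = ModelArgs(); args.vocab_size = len(tok)
#   model = MicroExpertsModel(args, MicroExpertConfig())
#   load_checkpoint(model, "checkpoints_me/checkpoint_step_1000.npz")
#   print(generate(model, tok, "Once upon a time", max_tokens=64))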
def main():
tc = prompt_config()
os.makedirs(tc.checkpoint_dir, exist_ok=True)
# Tokenizer
print(f"\nLoading tokenizer: {tc.tokenizer_file}")
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tc.tokenizer_file)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Model
args = ModelArgs()
args.vocab_size = len(tokenizer)
me_config = MicroExpertConfig()
if tc.mode == "active_learning":
me_config.split_threshold = tc.al_split_threshold
me_config.min_expert_age = tc.al_min_expert_age
print(f"Initializing MicroExperts model (vocab={args.vocab_size})...")
model = MicroExpertsModel(args, me_config)
# Resume
current_step = 0
if tc._resume:
ckpt, ckpt_step = get_latest_checkpoint(tc.checkpoint_dir)
if ckpt:
print(f"Resuming from {ckpt} @ step {ckpt_step}")
load_checkpoint(model, ckpt)
current_step = ckpt_step
else:
print("No checkpoint found — starting fresh.")
mx.eval(model.parameters())
n_params = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Total params: {n_params / 1e6:.2f}M")
print("Initial layout:")
print(model.expert_summary())
optimizer = optim.AdamW(learning_rate=tc.learning_rate)
# ---- PRETRAIN ----
if tc.mode == "pretrain":
data = stream_gutenberg(tokenizer, tc.batch_size, args.max_seq_len)
print(f"\nStarting pretraining for {tc.max_steps} steps...")
final_step = train_loop(
model, optimizer, data, tc,
start_step=current_step, max_steps=tc.max_steps,
lifecycle_every=tc.lifecycle_every, label="pretrain",
)
elif tc.mode == "inference":
print("\nChat ready. Type 'quit' to exit.\n")
while True:
user_input = input("You: ").strip()
if user_input.lower() in ("quit", "exit"):
break
if not user_input:
continue
response = generate(model, tokenizer, user_input)
print(f"Model: {response}\n")
final_step = current_step
# ---- ACTIVE LEARNING ----
elif tc.mode == "active_learning":
lifecycle_every = tc.al_lifecycle_every
print(f"\nActive learning from: {tc.al_data_dir}")
print(f" Steps/domain: {tc.al_steps_per_domain} | Lifecycle every: {lifecycle_every}")
domain_gen = stream_domain_files(
tokenizer, tc.al_data_dir, tc.batch_size, args.max_seq_len)
global_step = current_step
for domain_name, batches in domain_gen:
domain_max = global_step + tc.al_steps_per_domain
n_before = sum(len(l.moe._expert_id_list) for l in model.layers)
print(f"\n Training '{domain_name}': steps {global_step} -> {domain_max}")
global_step = train_loop(
model, optimizer, batches, tc,
start_step=global_step, max_steps=domain_max,
lifecycle_every=lifecycle_every, label=f"AL:{domain_name}",
)
n_after = sum(len(l.moe._expert_id_list) for l in model.layers)
print(f"\n '{domain_name}' done. Experts: {n_before} -> {n_after} ({n_after-n_before:+d})")
print(model.expert_summary())
final_step = global_step
elif tc.mode == "interactive_learning":
if not tc._resume:
print("WARNING: No checkpoint loaded, model is random.")
il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate)
il_step = current_step
conversation_tokens = []
message_count = 0
print("\nInteractive learning ready. Type 'quit' to exit.")
print("The model learns from the conversation.\n")
while True:
user_input = input("You: ").strip()
if user_input.lower() in ("quit", "exit"):
break
if not user_input:
continue
response = generate(model, tokenizer, user_input)
print(f"Model: {response}\n")
conversation_tokens.extend(tokenizer.encode(user_input))
conversation_tokens.extend(tokenizer.encode(response))
message_count += 1
seq_len = model.args.max_seq_len
trained = False
# Train on full sequences when available
while len(conversation_tokens) >= seq_len + 1:
batch = mx.array([conversation_tokens[:seq_len + 1]], dtype=mx.int32)
conversation_tokens = conversation_tokens[seq_len:]
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, batch)
grads = model.zero_frozen_grads(grads)
il_optimizer.update(model, grads)
mx.eval(model.parameters(), il_optimizer.state, loss)
il_step += 1
model.set_global_step(il_step)
trained = True
print(f" [learned: loss={loss.item():.4f}, step={il_step}]")
# Force train every 2 messages even with partial sequence
if not trained and message_count % 2 == 0 and len(conversation_tokens) > 2:
pad_len = seq_len + 1
tokens_to_use = conversation_tokens[-pad_len:] if len(conversation_tokens) >= pad_len else conversation_tokens
# Pad if too short
while len(tokens_to_use) < pad_len:
tokens_to_use = tokens_to_use + tokens_to_use
tokens_to_use = tokens_to_use[:pad_len]
batch = mx.array([tokens_to_use], dtype=mx.int32)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, batch)
grads = model.zero_frozen_grads(grads)
il_optimizer.update(model, grads)
mx.eval(model.parameters(), il_optimizer.state, loss)
il_step += 1
model.set_global_step(il_step)
print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]")
# Lifecycle check
if il_step > 0 and il_step % tc.al_lifecycle_every == 0:
events = model.run_lifecycle()
if events:
il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))}
print(model.expert_summary())
save_checkpoint(model, il_step, tc.checkpoint_dir)
print("Model saved.")
final_step = il_step
elif tc.mode == "train_and_chat":
if not tc._resume:
print("WARNING: No checkpoint loaded, model is random.")
il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate)
il_step = current_step
conversation_tokens = []
message_count = 0
system_prompt = "You are a helpful assistant."
chat_history = []
print("\nChat Learning ready. Type 'quit' to exit.")
print("The model learns from the conversation with chat format.\n")
while True:
user_input = input("You: ").strip()
if user_input.lower() in ("quit", "exit"):
break
if not user_input:
continue
response = generate(model, tokenizer, user_input)
print(f"Model: {response}\n")
# Build chat-formatted training text
chat_history.append({"role": "user", "content": user_input})
chat_history.append({"role": "assistant", "content": response})
chat_text = f"system\n{system_prompt}\n"
for msg in chat_history:
role = "human" if msg["role"] == "user" else "gpt"
chat_text += f"{role}\n{msg['content']}\n"
conversation_tokens = tokenizer.encode(chat_text)
message_count += 1
seq_len = model.args.max_seq_len
trained = False
# Train on full sequences from chat history
train_tokens = list(conversation_tokens)
while len(train_tokens) >= seq_len + 1:
batch = mx.array([train_tokens[:seq_len + 1]], dtype=mx.int32)
train_tokens = train_tokens[seq_len:]
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, batch)
grads = model.zero_frozen_grads(grads)
try:
il_optimizer.update(model, grads)
except (ValueError, KeyError, IndexError):
il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))}
il_optimizer.update(model, grads)
mx.eval(model.parameters(), il_optimizer.state, loss)
il_step += 1
model.set_global_step(il_step)
trained = True
print(f" [learned: loss={loss.item():.4f}, step={il_step}]")
# Force train every 2 messages even with partial sequence
if not trained and message_count % 2 == 0 and len(train_tokens) > 2:
pad_len = seq_len + 1
tokens_to_use = train_tokens[-pad_len:] if len(train_tokens) >= pad_len else train_tokens
while len(tokens_to_use) < pad_len:
tokens_to_use = tokens_to_use + tokens_to_use
tokens_to_use = tokens_to_use[:pad_len]
batch = mx.array([tokens_to_use], dtype=mx.int32)
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, batch)
grads = model.zero_frozen_grads(grads)
try:
il_optimizer.update(model, grads)
except (ValueError, KeyError, IndexError):
il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))}
il_optimizer.update(model, grads)
mx.eval(model.parameters(), il_optimizer.state, loss)
il_step += 1
model.set_global_step(il_step)
print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]")
# Trim chat history if too long
max_history = 20
if len(chat_history) > max_history:
chat_history = chat_history[-max_history:]
# Lifecycle check
if il_step > 0 and il_step % tc.al_lifecycle_every == 0:
events = model.run_lifecycle(optimizer=il_optimizer)
if events:
pass # optimizer state already rebuilt in lifecycle
print(model.expert_summary())
save_checkpoint(model, il_step, tc.checkpoint_dir)
print("Model saved.")
final_step = il_step
# Save final
print("\nTraining complete.")
save_checkpoint(model, final_step, tc.checkpoint_dir)
print("Final layout:")
print(model.expert_summary())
if __name__ == "__main__":
main()