""" MicroExperts — Self-organizing dynamic Mixture-of-Experts for continual learning. Target hardware: Apple M4 with 48 GB unified memory. """ import time import math import uuid import json import numpy as np import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from mlx.utils import tree_flatten from datasets import load_dataset from transformers import PreTrainedTokenizerFast import os import glob import re from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Any from collections import defaultdict def one_hot(indices: mx.array, num_classes: int) -> mx.array: # Build a range vector [0, 1, ..., num_classes-1] and compare with indices flat = indices.reshape(-1) # (K,) arange = mx.arange(num_classes) # (num_classes,) oh = (flat[:, None] == arange[None, :]).astype(mx.float32) # (K, num_classes) return oh.reshape(*indices.shape, num_classes) # ========================================== # 1. CONFIGURATION # ========================================== @dataclass class ModelArgs: dim: int = 768 n_layers: int = 12 n_heads: int = 12 n_kv_heads: int = 12 vocab_size: int = -1 norm_eps: float = 1e-8 max_seq_len: int = 2048 rope_theta: float = 10000.0 @dataclass class MicroExpertConfig: """All hyperparameters for the MicroExperts MoE system.""" #tier_hidden_dims: Tuple[int, ...] = (512, 1024, 2048, 4096) tier_hidden_dims: Tuple[int, ...] = (256, 512, 1024, 2048) monolith_split_enabled: bool = True monolith_variance_ema_alpha: float = 0.02 monolith_variance_z_threshold: float = 1.5 # Router router_embed_dim: int = 128 min_experts_per_token: int = 1 max_experts_per_token: int = 64 # Cannibalization / lifecycle ema_fast_alpha: float = 0.05 ema_slow_alpha: float = 0.005 split_threshold: float = 2.0 # Relaxed merge thresholds so merges actually fire merge_co_route_threshold: float = 0.5 merge_weakness_threshold: float = 0.05 death_threshold: float = 0.001 min_expert_age: int = 50 cooldown_steps: int = 100 # Base freeze duration — actual duration scaled by importance preserver_base_freeze_steps: int = 100 preserver_max_freeze_steps: int = 200 adapter_noise_scale: float = 0.02 max_experts_per_layer: int = 12 max_params_per_layer: int = 20_000_000 # 20 M # Initial state init_tier: int = 2 # Interference interference_subsample: int = 64 # Load balance loss load_balance_weight: float = 0.01 # Capacity-pressure merge: trigger when pool exceeds this fraction of budget merge_capacity_pressure_frac: float = 0.8 # Tier-gravity merge: same-tier co-activation threshold (lower than fragment) merge_tier_gravity_co_route: float = 0.4 merge_tier_gravity_min_co_activation: float = 0.3 # both activated > 30 % of tokens density_ema_alpha: float = 0.02 density_spike_z: float = 2.5 # z-score above mean to flag distribution shift @dataclass class TrainConfig: """Training hyperparameters.""" mode: str = "pretrain" batch_size: int = 8 learning_rate: float = 3e-4 max_steps: int = 30_000 tokenizer_file: str = "gutenberg_tokenizer.json" checkpoint_dir: str = "checkpoints_me" log_every: int = 10 summary_every: int = 500 checkpoint_every: int = 1000 lifecycle_every: int = 10 # Active learning al_data_dir: str = "./domains" al_steps_per_domain: int = 2000 al_learning_rate: float = 1e-4 al_lifecycle_every: int = 5 al_split_threshold: float = 1.5 al_min_expert_age: int = 100 # ========================================== # 2. EXPERT MODULE # ========================================== class Expert(nn.Module): """Single MicroExpert: SwiGLU FFN.""" def __init__(self, model_dim: int, hidden_dim: int): super().__init__() self.w1 = nn.Linear(model_dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, model_dim, bias=False) self.w3 = nn.Linear(model_dim, hidden_dim, bias=False) def __call__(self, x): return self.w2(nn.silu(self.w1(x)) * self.w3(x)) # ========================================== # 3. EXPERT METADATA # ========================================== @dataclass class ExpertMeta: """Non-parameter state for one expert.""" expert_id: str tier: int hidden_dim: int age: int = 0 cooldown: int = 0 frozen_steps: int = 0 ema_interference_fast: float = 0.0 ema_interference_slow: float = 0.0 ema_interference_var: float = 1.0 avg_routing_weight: float = 0.1 avg_activation_freq: float = 0.1 parent_id: Optional[str] = None generation: int = 0 def to_dict(self) -> dict: return { "expert_id": self.expert_id, "tier": self.tier, "hidden_dim": self.hidden_dim, "age": self.age, "cooldown": self.cooldown, "frozen_steps": self.frozen_steps, "ema_fast": self.ema_interference_fast, "ema_slow": self.ema_interference_slow, "ema_var": self.ema_interference_var, "avg_rw": self.avg_routing_weight, "avg_af": self.avg_activation_freq, "parent_id": self.parent_id, "generation": self.generation, } # ========================================== # 4. EXPERT EMBEDDING (trainable nn.Module) # ========================================== class ExpertEmbedding(nn.Module): def __init__(self, dim: int, init: Optional[mx.array] = None): super().__init__() if init is not None: self.embedding = init else: scale = 1.0 / math.sqrt(dim) self.embedding = mx.random.normal((dim,)) * scale # ========================================== # 5. ADAPTIVE ROUTER # ========================================== class AdaptiveRouter(nn.Module): def __init__(self, model_dim: int, config: MicroExpertConfig): super().__init__() self.config = config self.d = config.router_embed_dim self.proj = nn.Linear(model_dim, self.d, bias=False) self.threshold_head = nn.Linear(model_dim, 1, bias=True) # Trainable embeddings — list of nn.Module (MLX discovers these) self.embeddings: List[ExpertEmbedding] = [] # Parallel ID list (same order) self._emb_ids: List[str] = [] def _id_to_idx(self, eid: str) -> int: return self._emb_ids.index(eid) def add_expert(self, expert_id: str, init_embedding: Optional[mx.array] = None): emb = ExpertEmbedding(self.d, init=init_embedding) mx.eval(emb.parameters()) self.embeddings.append(emb) self._emb_ids.append(expert_id) def remove_expert(self, expert_id: str): if expert_id not in self._emb_ids: return idx = self._id_to_idx(expert_id) self.embeddings.pop(idx) self._emb_ids.pop(idx) def get_embedding(self, expert_id: str) -> mx.array: return self.embeddings[self._id_to_idx(expert_id)].embedding def set_embedding(self, expert_id: str, emb: mx.array): self.embeddings[self._id_to_idx(expert_id)].embedding = emb def __call__(self, x: mx.array, expert_ids: List[str]): """ Returns: routing_weights: (B, L, N) sparse softmax-normalized raw_scores: (B, L, N) cosine similarities density: (B, L) active expert count per token """ B, L, D = x.shape N = len(expert_ids) if N == 0: z = mx.zeros((B, L, 1)) return z[:, :, :0], z[:, :, :0], mx.zeros((B, L)) # Project input to routing space and normalize h = self.proj(x) # (B, L, d) h_norm = h / (mx.linalg.norm(h, axis=-1, keepdims=True) + 1e-8) # Stack expert embeddings into matrix E = mx.stack([self.embeddings[self._emb_ids.index(eid)].embedding for eid in expert_ids], axis=0) # (N, d) E_norm = E / (mx.linalg.norm(E, axis=-1, keepdims=True) + 1e-8) raw_scores = h_norm @ E_norm.T # (B, L, N) # Adaptive per-token threshold threshold = mx.sigmoid(self.threshold_head(x)) # (B, L, 1) gate_mask = (raw_scores > threshold).astype(mx.float32) # Guarantee top-1 always active best_idx = mx.argmax(raw_scores, axis=-1) # (B, L) best_oh = one_hot(best_idx, N) # (B, L, N) gate_mask = mx.maximum(gate_mask, best_oh) # Cap maximum active experts max_k = self.config.max_experts_per_token if max_k < N: sorted_idx = mx.argsort(-raw_scores, axis=-1) rank = mx.argsort(sorted_idx, axis=-1) gate_mask = gate_mask * (rank < max_k).astype(mx.float32) # Softmax over active experts masked = raw_scores * gate_mask + (1.0 - gate_mask) * (-1e9) routing_weights = mx.softmax(masked, axis=-1) * gate_mask density = gate_mask.sum(axis=-1) return routing_weights, raw_scores, density # ========================================== # 6. UTILITY: zero a nested grad tree # ========================================== def _zero_tree(tree): """Recursively zero all mx.arrays in a nested structure.""" if isinstance(tree, mx.array): return mx.zeros_like(tree) elif isinstance(tree, dict): return {k: _zero_tree(v) for k, v in tree.items()} elif isinstance(tree, list): return [_zero_tree(v) for v in tree] return tree # ========================================== # 7. MoE LAYER # ========================================== class MicroExpertsMoELayer(nn.Module): def __init__(self, model_dim: int, config: MicroExpertConfig, layer_idx: int): super().__init__() self.model_dim = model_dim self.config = config self.layer_idx = layer_idx self.router = AdaptiveRouter(model_dim, config) self._variance_ema: Dict[str, float] = {} self._variance_ema_sq: Dict[str, float] = {} # Expert modules — list for MLX parameter discovery self.expert_modules: List[Expert] = [] self._expert_id_list: List[str] = [] self._expert_meta: Dict[str, ExpertMeta] = {} self._lifecycle_log: List[str] = [] self.global_step: int = 0 # Cached from forward pass (detached) self._last_routing_weights: Optional[mx.array] = None self._last_density: Optional[mx.array] = None self._last_input: Optional[mx.array] = None # FIX: Cache expert outputs to avoid redundant forward in interference self._last_expert_outputs: Optional[List[mx.array]] = None # Frozen expert tracking self._frozen_eids: set = set() # FIX: Density drift tracking self._density_ema: float = 1.0 self._density_var: float = 1.0 self._drift_detected: bool = False # Create initial monolith self._create_expert(tier=config.init_tier) # --- Helpers --- @property def expert_ids(self) -> List[str]: return list(self._expert_id_list) def _eid_to_index(self, eid: str) -> int: return self._expert_id_list.index(eid) def _get_expert(self, eid: str) -> Expert: return self.expert_modules[self._eid_to_index(eid)] def _tier_to_hidden(self, tier: int) -> int: t = min(tier, len(self.config.tier_hidden_dims) - 1) return self.config.tier_hidden_dims[t] def _expert_param_count(self, tier: int) -> int: return 3 * self.model_dim * self._tier_to_hidden(tier) def _total_params(self) -> int: return sum(self._expert_param_count(m.tier) for m in self._expert_meta.values()) def _make_id(self) -> str: return uuid.uuid4().hex[:12] """ def _copy_optimizer_state(self, optimizer, parent_idx: int, child_eid: str): try: layers_state = optimizer.state.get("layers", []) if self.layer_idx >= len(layers_state): return moe_state = layers_state[self.layer_idx].get("moe", {}) expert_states = moe_state.get("expert_modules", []) if parent_idx >= len(expert_states): return parent_state = expert_states[parent_idx] child_idx = self._eid_to_index(child_eid) # Grow the list if needed while len(expert_states) <= child_idx: expert_states.append({}) # Deep copy the parent state import copy expert_states[child_idx] = copy.deepcopy(parent_state) except (KeyError, IndexError, TypeError): pass """ def _copy_optimizer_state(self, optimizer, parent_idx: int, children_eids: list): """Copy parent's optimizer state to children, then rebuild list.""" try: layers_state = optimizer.state.get("layers", []) if self.layer_idx >= len(layers_state): return moe_state = layers_state[self.layer_idx].get("moe", {}) expert_states = moe_state.get("expert_modules", []) if parent_idx >= len(expert_states): return import copy parent_state = copy.deepcopy(expert_states[parent_idx]) # Build new list matching current expert_modules order new_states = [] for i, eid in enumerate(self._expert_id_list): if eid in children_eids: new_states.append(copy.deepcopy(parent_state)) elif i < len(expert_states): new_states.append(expert_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states except (KeyError, IndexError, TypeError): pass # --- Expert creation / removal --- def _create_expert( self, tier: int, parent_id: Optional[str] = None, init_weights_from: Optional[Expert] = None, noise_scale: float = 0.0, frozen_steps: int = 0, init_embedding: Optional[mx.array] = None, ) -> str: eid = self._make_id() hidden = self._tier_to_hidden(tier) expert = Expert(self.model_dim, hidden) if init_weights_from is not None: src = dict(tree_flatten(init_weights_from.parameters())) dst = dict(tree_flatten(expert.parameters())) pairs = [] for k in dst: if k in src and src[k].shape == dst[k].shape: w = src[k] if noise_scale > 0: w = w + mx.random.normal(w.shape) * noise_scale * (mx.abs(w).mean() + 1e-8) pairs.append((k, w)) if pairs: expert.load_weights(pairs) mx.eval(expert.parameters()) self.expert_modules.append(expert) self._expert_id_list.append(eid) gen = 0 if parent_id and parent_id in self._expert_meta: gen = self._expert_meta[parent_id].generation + 1 self._expert_meta[eid] = ExpertMeta( expert_id=eid, tier=tier, hidden_dim=hidden, frozen_steps=frozen_steps, parent_id=parent_id, generation=gen, ) if frozen_steps > 0: self._frozen_eids.add(eid) self.router.add_expert(eid, init_embedding=init_embedding) return eid def _remove_expert(self, eid: str): if eid not in self._expert_id_list: return idx = self._eid_to_index(eid) self.expert_modules.pop(idx) self._expert_id_list.pop(idx) self._expert_meta.pop(eid, None) self._frozen_eids.discard(eid) self.router.remove_expert(eid) # --- Forward --- def __call__(self, x: mx.array) -> mx.array: B, L, D = x.shape N = len(self._expert_id_list) if N == 0: return mx.zeros_like(x) routing_weights, raw_scores, density = self.router(x, self._expert_id_list) # Compute and cache individual expert outputs expert_outputs = [self.expert_modules[i](x) for i in range(N)] output = mx.zeros_like(x) for i in range(N): w_i = routing_weights[:, :, i:i + 1] output = output + w_i * expert_outputs[i] # Cache detached copies for interference computation self._last_routing_weights = mx.stop_gradient(routing_weights) self._last_density = mx.stop_gradient(density) self._last_input = mx.stop_gradient(x) self._last_expert_outputs = [mx.stop_gradient(eo) for eo in expert_outputs] return output # --- Load balance loss --- def load_balance_loss(self) -> mx.array: """ Variance of per-expert activation frequency across the last batch. Penalizes uneven usage — prevents expert starvation without forcing uniform routing (which would defeat specialization). """ if self._last_routing_weights is None: return mx.array(0.0) N = self._last_routing_weights.shape[-1] if N <= 1: return mx.array(0.0) # Per-expert fraction of tokens where it's active (weight > 0.01) active = (self._last_routing_weights > 0.01).astype(mx.float32) freq = active.reshape(-1, N).mean(axis=0) return freq.var() # --- Frozen gradient zeroing --- def zero_frozen_grads(self, expert_grads: Any) -> Any: """Zero gradients for the expert_modules subtree of frozen experts.""" if not self._frozen_eids or not isinstance(expert_grads, list): return expert_grads result = [] for i, g in enumerate(expert_grads): eid = self._expert_id_list[i] if i < len(self._expert_id_list) else None if eid and eid in self._frozen_eids: result.append(_zero_tree(g)) else: result.append(g) return result def dr(self): """Update density EMA and detect distribution shift spikes.""" if self._last_density is None: return cfg = self.config current = self._last_density.mean().item() alpha = cfg.density_ema_alpha # Update EMA of density old_ema = self._density_ema self._density_ema = (1 - alpha) * self._density_ema + alpha * current diff = current - old_ema self._density_var = (1 - alpha) * self._density_var + alpha * diff * diff # Z-score spike detection std = math.sqrt(max(self._density_var, 1e-8)) z = (current - self._density_ema) / std self._drift_detected = z > cfg.density_spike_z if self._drift_detected: msg = (f"[step {self.global_step}][L{self.layer_idx}] " f"DRIFT density={current:.1f} ema={self._density_ema:.1f} z={z:.1f}") self._lifecycle_log.append(msg) print(msg) def compute_interference(self) -> Dict[str, float]: if (self._last_routing_weights is None or self._last_input is None or self._last_expert_outputs is None): return {} x = self._last_input rw = self._last_routing_weights B, L, D = x.shape N = len(self._expert_id_list) if N == 0: return {} T = min(self.config.interference_subsample, B * L) rw_flat = rw.reshape(-1, N)[:T] # Use cached expert outputs instead of re-running forward passes expert_outs_flat = [eo.reshape(-1, D)[:T] for eo in self._last_expert_outputs] # Combined mixture output on subsample combined = mx.zeros((T, D)) for i in range(N): combined = combined + rw_flat[:, i:i + 1] * expert_outs_flat[i] combined = mx.stop_gradient(combined) interference = {} for i in range(N): eid = self._expert_id_list[i] w_i = rw_flat[:, i] e_out = expert_outs_flat[i] active = (w_i > 0.01).astype(mx.float32) n_active = active.sum().item() if n_active < 1.0: interference[eid] = 0.0 continue diff_norm = mx.linalg.norm(combined - e_out, axis=-1) e_norm = mx.linalg.norm(e_out, axis=-1) + 1e-8 relative = diff_norm / e_norm score = (relative * w_i * active).sum() / (n_active + 1e-8) interference[eid] = score.item() mx.eval(list(interference.values())) return interference def _compute_monolith_split_scores(self) -> Dict[str, float]: scores = {} if self._last_expert_outputs is None or not self.config.monolith_split_enabled: return scores cfg = self.config for i, eid in enumerate(self._expert_id_list): if i >= len(self._last_expert_outputs): continue eo = self._last_expert_outputs[i] norms = mx.linalg.norm(eo.reshape(-1, eo.shape[-1]), axis=-1) var = norms.var().item() alpha = cfg.monolith_variance_ema_alpha prev_mean = self._variance_ema.get(eid, var) prev_sq = self._variance_ema_sq.get(eid, var * var) new_mean = (1 - alpha) * prev_mean + alpha * var new_sq = (1 - alpha) * prev_sq + alpha * var * var self._variance_ema[eid] = new_mean self._variance_ema_sq[eid] = new_sq running_std = math.sqrt(max(new_sq - new_mean * new_mean, 1e-8)) z = (var - new_mean) / running_std scores[eid] = z return scores # --- Lifecycle --- def lifecycle_step(self, optimizer=None): self.dr() interference = self.compute_interference() events = [] all_ids = list(self._expert_id_list) # snapshot before mutations monolith_scores = self._compute_monolith_split_scores() N = len(all_ids) for eid in all_ids: meta = self._expert_meta.get(eid) if meta is None: continue meta.age += 1 if meta.cooldown > 0: meta.cooldown -= 1 if meta.frozen_steps > 0: meta.frozen_steps -= 1 if meta.frozen_steps == 0: self._frozen_eids.discard(eid) # Routing stats from cached data if self._last_routing_weights is not None and eid in self._expert_id_list: idx = self._eid_to_index(eid) if idx < self._last_routing_weights.shape[-1]: w = self._last_routing_weights[:, :, idx] meta.avg_routing_weight = ( 0.95 * meta.avg_routing_weight + 0.05 * w.mean().item() ) meta.avg_activation_freq = ( 0.95 * meta.avg_activation_freq + 0.05 * (w > 0.01).astype(mx.float32).mean().item() ) # Interference EMAs intf = interference.get(eid, 0.0) af = self.config.ema_fast_alpha asl = self.config.ema_slow_alpha meta.ema_interference_fast = (1 - af) * meta.ema_interference_fast + af * intf meta.ema_interference_slow = (1 - asl) * meta.ema_interference_slow + asl * intf diff = intf - meta.ema_interference_slow meta.ema_interference_var = 0.99 * meta.ema_interference_var + 0.01 * diff * diff # Score by cannibalization z-score scored = [] for eid in all_ids: meta = self._expert_meta.get(eid) if meta is None or eid not in self._expert_id_list: continue std = math.sqrt(max(meta.ema_interference_var, 1e-8)) intf_z = (meta.ema_interference_fast - meta.ema_interference_slow) / std mono_z = monolith_scores.get(eid, 0.0) if N <= 2: z = mono_z else: z = max(intf_z, mono_z) scored.append((eid, z, meta)) scored.sort(key=lambda t: -t[1]) # FIX: Lower split threshold during detected drift — system should react faster effective_split_threshold = self.config.split_threshold if self._drift_detected: effective_split_threshold *= 0.7 # 30 % more sensitive during drift # Split / Death touched = set() for eid, z_score, meta in scored: if eid in touched or eid not in self._expert_id_list: continue if meta.age < self.config.min_expert_age or meta.cooldown > 0: continue budget_usage = self._total_params() / self.config.max_params_per_layer if budget_usage > 0.7: continue threshold = self.config.monolith_variance_z_threshold if N <= 2 else effective_split_threshold if (z_score > threshold and len(self._expert_id_list) < self.config.max_experts_per_layer and (self._total_params() + self._expert_param_count(meta.tier) < self.config.max_params_per_layer)): events.append(self._do_split(eid,optimizer=optimizer)) touched.add(eid) continue if (meta.avg_routing_weight < self.config.death_threshold and len(self._expert_id_list) > 1): events.append(self._do_death(eid, optimizer=optimizer)) touched.add(eid) continue events.extend(self._check_merges(touched, optimizer=optimizer)) for e in events: msg = f"[step {self.global_step}][L{self.layer_idx}] {e}" self._lifecycle_log.append(msg) print(msg) return events # --- Importance-proportional preserver freeze --- def _compute_freeze_steps(self, meta: ExpertMeta) -> int: cfg = self.config importance = max(0.0, min(1.0, meta.avg_routing_weight * 10.0)) freeze = int( cfg.preserver_base_freeze_steps + importance * (cfg.preserver_max_freeze_steps - cfg.preserver_base_freeze_steps) ) return freeze """ def _do_split(self, eid: str) -> str: meta = self._expert_meta[eid] parent = self._get_expert(eid) parent_emb = self.router.get_embedding(eid) freeze_steps = self._compute_freeze_steps(meta) preserver_id = self._create_expert( tier=meta.tier, parent_id=eid, init_weights_from=parent, noise_scale=0.0, frozen_steps=freeze_steps, init_embedding=parent_emb, ) adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1 mx.eval(adapter_emb) adapter_id = self._create_expert( tier=meta.tier, parent_id=eid, init_weights_from=parent, noise_scale=self.config.adapter_noise_scale, frozen_steps=0, init_embedding=adapter_emb, ) self._remove_expert(eid) self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> " f"preserver {preserver_id[:8]} (frozen={freeze_steps}) " f"+ adapter {adapter_id[:8]}") """ """ def _do_split(self, eid: str, optimizer=None) -> str: meta = self._expert_meta[eid] parent = self._get_expert(eid) parent_emb = self.router.get_embedding(eid) parent_idx = self._eid_to_index(eid) parent_opt_state = None parent_emb_opt_state = None if optimizer is not None: try: import copy layers_state = optimizer.state.get("layers", []) moe_state = layers_state[self.layer_idx].get("moe", {}) expert_states = moe_state.get("expert_modules", []) if parent_idx < len(expert_states): parent_opt_state = copy.deepcopy(expert_states[parent_idx]) # Save parent router embedding state router_state = moe_state.get("router", {}) emb_states = router_state.get("embeddings", []) if parent_idx < len(emb_states): parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx]) except (KeyError, IndexError, TypeError): pass freeze_steps = self._compute_freeze_steps(meta) preserver_id = self._create_expert( tier=meta.tier, parent_id=eid, init_weights_from=parent, noise_scale=0.0, frozen_steps=freeze_steps, init_embedding=parent_emb, ) adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1 mx.eval(adapter_emb) adapter_id = self._create_expert( tier=meta.tier, parent_id=eid, init_weights_from=parent, noise_scale=self.config.adapter_noise_scale, frozen_steps=0, init_embedding=adapter_emb, ) # Copy optimizer state before removing parent if optimizer is not None: self._copy_optimizer_state(optimizer, parent_idx, preserver_id) self._copy_optimizer_state(optimizer, parent_idx, adapter_id) self._remove_expert(eid) if optimizer is not None and parent_opt_state is not None: try: import copy layers_state = optimizer.state["layers"] moe_state = layers_state[self.layer_idx]["moe"] old_states = moe_state.get("expert_modules", []) new_states = [] for i, expert_eid in enumerate(self._expert_id_list): if expert_eid == preserver_id or expert_eid == adapter_id: new_states.append(copy.deepcopy(parent_opt_state)) elif i < len(old_states): new_states.append(old_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states except (KeyError, IndexError, TypeError): pass if optimizer is not None: try: layers_state = optimizer.state.get("layers", []) expert_states = layers_state[self.layer_idx]["moe"]["expert_modules"] if parent_idx < len(expert_states): expert_states.pop(parent_idx) except (KeyError, IndexError, TypeError): pass self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> " f"preserver {preserver_id[:8]} (frozen={freeze_steps}) " f"+ adapter {adapter_id[:8]}") """ def _do_split(self, eid: str, optimizer=None) -> str: meta = self._expert_meta[eid] parent = self._get_expert(eid) parent_emb = self.router.get_embedding(eid) parent_idx = self._eid_to_index(eid) parent_opt_state = None parent_emb_opt_state = None if optimizer is not None: try: import copy layers_state = optimizer.state.get("layers", []) moe_state = layers_state[self.layer_idx].get("moe", {}) expert_states = moe_state.get("expert_modules", []) if parent_idx < len(expert_states): parent_opt_state = copy.deepcopy(expert_states[parent_idx]) router_state = moe_state.get("router", {}) emb_states = router_state.get("embeddings", []) if parent_idx < len(emb_states): parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx]) except (KeyError, IndexError, TypeError): pass freeze_steps = self._compute_freeze_steps(meta) preserver_id = self._create_expert( tier=meta.tier, parent_id=eid, init_weights_from=parent, noise_scale=0.0, frozen_steps=freeze_steps, init_embedding=parent_emb, ) adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1 mx.eval(adapter_emb) adapter_id = self._create_expert( tier=meta.tier, parent_id=eid, init_weights_from=parent, noise_scale=self.config.adapter_noise_scale, frozen_steps=0, init_embedding=adapter_emb, ) self._remove_expert(eid) if optimizer is not None and parent_opt_state is not None: try: import copy layers_state = optimizer.state["layers"] moe_state = layers_state[self.layer_idx]["moe"] old_states = moe_state.get("expert_modules", []) new_states = [] for i, expert_eid in enumerate(self._expert_id_list): if expert_eid == preserver_id or expert_eid == adapter_id: new_states.append(copy.deepcopy(parent_opt_state)) elif i < len(old_states): new_states.append(old_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states # Rebuild router embeddings state router_state = moe_state.get("router", {}) old_emb_states = router_state.get("embeddings", []) new_emb_states = [] for i, emb_eid in enumerate(self.router._emb_ids): if emb_eid == preserver_id or emb_eid == adapter_id: if parent_emb_opt_state is not None: new_emb_states.append(copy.deepcopy(parent_emb_opt_state)) else: new_emb_states.append({}) elif i < len(old_emb_states): new_emb_states.append(old_emb_states[i]) else: new_emb_states.append({}) router_state["embeddings"] = new_emb_states except (KeyError, IndexError, TypeError): pass self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> " f"preserver {preserver_id[:8]} (frozen={freeze_steps}) " f"+ adapter {adapter_id[:8]}") def _do_death(self, eid: str, optimizer=None) -> str: meta = self._expert_meta[eid] info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})" self._remove_expert(eid) if optimizer is not None: try: layers_state = optimizer.state.get("layers", []) if self.layer_idx < len(layers_state): moe_state = layers_state[self.layer_idx].get("moe", {}) old_states = moe_state.get("expert_modules", []) new_states = [] for i, expert_eid in enumerate(self._expert_id_list): if i < len(old_states): new_states.append(old_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states # Rebuild router embeddings state router_state = moe_state.get("router", {}) old_emb_states = router_state.get("embeddings", []) new_emb_states = [] for i in range(len(self.router._emb_ids)): if i < len(old_emb_states): new_emb_states.append(old_emb_states[i]) else: new_emb_states.append({}) router_state["embeddings"] = new_emb_states except (KeyError, IndexError, TypeError): pass return info """ def _do_death(self, eid: str, optimizer=None) -> str: meta = self._expert_meta[eid] info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})" self._remove_expert(eid) if optimizer is not None: try: layers_state = optimizer.state.get("layers", []) if self.layer_idx < len(layers_state): moe_state = layers_state[self.layer_idx].get("moe", {}) old_states = moe_state.get("expert_modules", []) new_states = [] for i, expert_eid in enumerate(self._expert_id_list): if i < len(old_states): new_states.append(old_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states except (KeyError, IndexError, TypeError): pass return info """ def _average_expert_weights(self, expert_a: Expert, expert_b: Expert) -> List[Tuple[str, mx.array]]: """Average the weights of two same-shape experts.""" src_a = dict(tree_flatten(expert_a.parameters())) src_b = dict(tree_flatten(expert_b.parameters())) pairs = [] for k in src_a: if k in src_b and src_a[k].shape == src_b[k].shape: pairs.append((k, (src_a[k] + src_b[k]) / 2.0)) return pairs def _check_merges(self, touched: set, optimizer=None) -> List[str]: events = [] merged = set() ids = list(self._expert_id_list) cfg = self.config # Pre-compute co-activation matrix from cached routing weights co_activation = {} if self._last_routing_weights is not None: N = self._last_routing_weights.shape[-1] active = (self._last_routing_weights > 0.01).astype(mx.float32) # (B*L, N) binary activation matrix act_flat = active.reshape(-1, N) # Per-expert activation freq act_freq = act_flat.mean(axis=0) # (N,) mx.eval(act_freq) def _can_merge(eid): return (eid not in merged and eid not in touched and eid in self._expert_id_list and (meta := self._expert_meta.get(eid)) is not None and meta.age >= cfg.min_expert_age and meta.cooldown == 0) def _do_merge(eid_a, eid_b, meta_a, meta_b, reason: str, optimizer=None) -> Optional[str]: """Execute a merge and return event string, or None if budget exceeded.""" new_tier = min(meta_a.tier + 1, len(cfg.tier_hidden_dims) - 1) cost = self._expert_param_count(new_tier) freed = (self._expert_param_count(meta_a.tier) + self._expert_param_count(meta_b.tier)) if self._total_params() - freed + cost > cfg.max_params_per_layer: return None emb_a = self.router.get_embedding(eid_a) emb_b = self.router.get_embedding(eid_b) avg_emb = (emb_a + emb_b) / 2.0 mx.eval(avg_emb) if new_tier == meta_a.tier: merged_expert_id = self._create_expert( tier=new_tier, parent_id=eid_a, init_weights_from=self._get_expert(eid_a), init_embedding=avg_emb, ) # Overwrite with averaged weights avg_weights = self._average_expert_weights( self._get_expert(eid_a), self._get_expert(eid_b)) if avg_weights: self._get_expert(merged_expert_id).load_weights(avg_weights) mx.eval(self._get_expert(merged_expert_id).parameters()) else: # Tier-up merge: different hidden dim, can't average weights merged_expert_id = self._create_expert( tier=new_tier, parent_id=eid_a, init_embedding=avg_emb, ) self._expert_meta[merged_expert_id].cooldown = cfg.cooldown_steps self._remove_expert(eid_a) self._remove_expert(eid_b) merged.add(eid_a) merged.add(eid_b) """ if optimizer is not None: try: layers_state = optimizer.state.get("layers", []) if self.layer_idx < len(layers_state): moe_state = layers_state[self.layer_idx].get("moe", {}) old_states = moe_state.get("expert_modules", []) new_states = [] for i, expert_eid in enumerate(self._expert_id_list): if expert_eid == merged_expert_id: new_states.append({}) # fresh state, no momentum to copy elif i < len(old_states): new_states.append(old_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states except (KeyError, IndexError, TypeError): pass """ if optimizer is not None: try: layers_state = optimizer.state.get("layers", []) if self.layer_idx < len(layers_state): moe_state = layers_state[self.layer_idx].get("moe", {}) # Rebuild expert_modules state old_states = moe_state.get("expert_modules", []) new_states = [] for i, expert_eid in enumerate(self._expert_id_list): if expert_eid == merged_expert_id: new_states.append({}) elif i < len(old_states): new_states.append(old_states[i]) else: new_states.append({}) moe_state["expert_modules"] = new_states # Rebuild router embeddings state router_state = moe_state.get("router", {}) old_emb_states = router_state.get("embeddings", []) new_emb_states = [] for i in range(len(self.router._emb_ids)): if i < len(old_emb_states): new_emb_states.append(old_emb_states[i]) else: new_emb_states.append({}) router_state["embeddings"] = new_emb_states except (KeyError, IndexError, TypeError): pass return (f"MERGE({reason}) {eid_a[:8]}+{eid_b[:8]} (T{meta_a.tier}) " f"-> {merged_expert_id[:8]} (T{new_tier})") # --- Force 1: Fragment merge (original: co-route + both weak) --- for i, eid_a in enumerate(ids): if not _can_merge(eid_a): continue meta_a = self._expert_meta[eid_a] for j in range(i + 1, len(ids)): eid_b = ids[j] if not _can_merge(eid_b): continue meta_b = self._expert_meta[eid_b] if meta_a.tier != meta_b.tier: continue emb_a = self.router.get_embedding(eid_a) emb_b = self.router.get_embedding(eid_b) cos = ((emb_a * emb_b).sum() / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8)) both_weak = (meta_a.avg_routing_weight < cfg.merge_weakness_threshold and meta_b.avg_routing_weight < cfg.merge_weakness_threshold) if cos.item() > cfg.merge_co_route_threshold and both_weak: result = _do_merge(eid_a, eid_b, meta_a, meta_b, "fragment", optimizer=optimizer) if result: events.append(result) break # --- Force 2: Capacity-pressure merge --- budget_frac = self._total_params() / cfg.max_params_per_layer if budget_frac > cfg.merge_capacity_pressure_frac: # Find weakest same-tier pair with highest cosine similarity candidates = [] for i, eid_a in enumerate(ids): if not _can_merge(eid_a): continue meta_a = self._expert_meta.get(eid_a) if meta_a is None: continue for j in range(i + 1, len(ids)): eid_b = ids[j] if not _can_merge(eid_b): continue meta_b = self._expert_meta.get(eid_b) if meta_b is None or meta_a.tier != meta_b.tier: continue emb_a = self.router.get_embedding(eid_a) emb_b = self.router.get_embedding(eid_b) cos = ((emb_a * emb_b).sum() / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8)) combined_w = meta_a.avg_routing_weight + meta_b.avg_routing_weight # Score: high cosine + low combined weight = best merge candidate score = cos.item() - combined_w candidates.append((score, eid_a, eid_b, meta_a, meta_b)) candidates.sort(key=lambda t: -t[0]) for score, eid_a, eid_b, meta_a, meta_b in candidates: if not _can_merge(eid_a) or not _can_merge(eid_b): continue result = _do_merge(eid_a, eid_b, meta_a, meta_b, "capacity",optimizer=optimizer) if result: events.append(result) # Only do one capacity merge per lifecycle step to avoid cascades break # --- Force 3: Tier-gravity merge (same-tier co-activate frequently) --- if self._last_routing_weights is not None: N = self._last_routing_weights.shape[-1] act_flat = (self._last_routing_weights > 0.01).astype(mx.float32).reshape(-1, N) total_tokens = act_flat.shape[0] for i, eid_a in enumerate(ids): if not _can_merge(eid_a): continue meta_a = self._expert_meta.get(eid_a) if meta_a is None: continue idx_a = self._eid_to_index(eid_a) if eid_a in self._expert_id_list else None if idx_a is None or idx_a >= N: continue for j in range(i + 1, len(ids)): eid_b = ids[j] if not _can_merge(eid_b): continue meta_b = self._expert_meta.get(eid_b) if meta_b is None or meta_a.tier != meta_b.tier: continue idx_b = self._eid_to_index(eid_b) if eid_b in self._expert_id_list else None if idx_b is None or idx_b >= N: continue # Co-activation: fraction of tokens where both are active both_active = (act_flat[:, idx_a] * act_flat[:, idx_b]).mean().item() emb_a = self.router.get_embedding(eid_a) emb_b = self.router.get_embedding(eid_b) cos = ((emb_a * emb_b).sum() / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8)) if (both_active > cfg.merge_tier_gravity_min_co_activation and cos.item() > cfg.merge_tier_gravity_co_route): result = _do_merge(eid_a, eid_b, meta_a, meta_b, "tier-gravity", optimizer=optimizer) if result: events.append(result) break return events # ========================================== # 8. MODEL COMPONENTS # ========================================== class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() self.weight = mx.ones((dims,)) self.eps = eps def __call__(self, x): return mx.fast.rms_norm(x, self.weight, self.eps) class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_heads = args.n_heads self.n_kv_heads = args.n_kv_heads self.head_dim = args.dim // args.n_heads self.scale = self.head_dim ** -0.5 self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta) def __call__(self, x, mask=None): B, L, D = x.shape queries, keys, values = self.wq(x), self.wk(x), self.wv(x) queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) queries = self.rope(queries) keys = self.rope(keys) output = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask) return self.wo(output.transpose(0, 2, 1, 3).reshape(B, L, -1)) class MicroExpertsBlock(nn.Module): def __init__(self, args: ModelArgs, me_config: MicroExpertConfig, layer_idx: int): super().__init__() self.attention = Attention(args) self.moe = MicroExpertsMoELayer(args.dim, me_config, layer_idx) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def __call__(self, x, mask=None): h = x + self.attention(self.attention_norm(x), mask) return h + self.moe(self.ffn_norm(h)) class MicroExpertsModel(nn.Module): def __init__(self, args: ModelArgs, me_config: MicroExpertConfig): super().__init__() self.args = args self.me_config = me_config self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) self.layers = [ MicroExpertsBlock(args, me_config, layer_idx=i) for i in range(args.n_layers) ] self.norm = RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear(args.dim, args.vocab_size, bias=False) def __call__(self, x): L = x.shape[1] mask = nn.MultiHeadAttention.create_additive_causal_mask(L).astype(mx.float32) mask = mask[None, None, :, :] h = self.tok_embeddings(x) for layer in self.layers: h = layer(h, mask) return self.output(self.norm(h)) def set_global_step(self, step: int): for layer in self.layers: layer.moe.global_step = step def run_lifecycle(self, optimizer=None): all_events = [] for layer in self.layers: all_events.extend(layer.moe.lifecycle_step(optimizer=optimizer)) return all_events def total_load_balance_loss(self) -> mx.array: """Sum of per-layer activation frequency variance.""" lb = mx.array(0.0) for layer in self.layers: lb = lb + layer.moe.load_balance_loss() return lb def zero_frozen_grads(self, grads): """Walk gradient tree, zero frozen expert parameters.""" if not isinstance(grads, dict) or "layers" not in grads: return grads new_layers = [] for i, lg in enumerate(grads["layers"]): if (isinstance(lg, dict) and "moe" in lg and isinstance(lg["moe"], dict) and "expert_modules" in lg["moe"]): moe = self.layers[i].moe fixed = moe.zero_frozen_grads(lg["moe"]["expert_modules"]) new_moe = dict(lg["moe"]) new_moe["expert_modules"] = fixed new_lg = dict(lg) new_lg["moe"] = new_moe new_layers.append(new_lg) else: new_layers.append(lg) new_grads = dict(grads) new_grads["layers"] = new_layers return new_grads def expert_summary(self) -> str: lines = [] total_e, total_p = 0, 0 for i, layer in enumerate(self.layers): moe = layer.moe n = len(moe._expert_id_list) p = moe._total_params() total_e += n total_p += p tiers = defaultdict(int) for m in moe._expert_meta.values(): tiers[m.tier] += 1 ts = " ".join(f"T{t}:{c}" for t, c in sorted(tiers.items())) frozen = sum(1 for eid in moe._expert_id_list if eid in moe._frozen_eids) drift = " DRIFT" if moe._drift_detected else "" lines.append( f" L{i:2d}: {n:3d} experts ({ts}) | {p/1e6:.1f}M | " f"{frozen} frozen | d={moe._density_ema:.1f}{drift}") lines.append(f" TOTAL: {total_e} experts | {total_p/1e6:.1f}M MoE params") return "\n".join(lines) def save_meta(self, path: str): data = {} for i, layer in enumerate(self.layers): moe = layer.moe data[f"layer_{i}"] = { "expert_ids": list(moe._expert_id_list), "experts": {eid: m.to_dict() for eid, m in moe._expert_meta.items()}, "density_ema": moe._density_ema, } with open(path, "w") as f: json.dump(data, f, indent=2) # ========================================== # 9. DATA STREAMS # ========================================== def stream_gutenberg(tokenizer, batch_size: int, seq_len: int): print("Connecting to Gutenberg stream...") dataset = load_dataset("teknium/OpenHermes-2.5", split="train", streaming=True,) dataset_iter = iter(dataset) buffers = [[] for _ in range(batch_size)] while True: for i in range(batch_size): while len(buffers[i]) < seq_len + 1: try: row = next(dataset_iter) except StopIteration: dataset_iter = iter(dataset) row = next(dataset_iter) text = row.get("conversations", "") if isinstance(text, list): parts = [] for msg in text: role = msg.get("from", "") content = msg.get("value", []) if isinstance(content, str): parts.append(f"{role}\n{content}") text = "\n".join(parts) # if not text or len(text) < 10: continue buffers[i].extend(tokenizer.encode(text)) batch = [] for i in range(batch_size): batch.append(buffers[i][:seq_len + 1]) buffers[i] = buffers[i][seq_len:] yield mx.array(batch, dtype=mx.int32) def stream_domain_files(tokenizer, data_dir: str, batch_size: int, seq_len: int): files = sorted(glob.glob(os.path.join(data_dir, "*.txt"))) if not files: raise FileNotFoundError(f"No .txt files in {data_dir}") for fpath in files: domain = os.path.splitext(os.path.basename(fpath))[0] print(f"\n{'='*60}") print(f" ACTIVE LEARNING — Domain: {domain}") print(f"{'='*60}") with open(fpath, "r", encoding="utf-8", errors="replace") as f: text = f.read() tokens = tokenizer.encode(text) min_tokens = (seq_len + 1) * batch_size if len(tokens) < min_tokens: print(f" Skipping {domain}: {len(tokens)} tokens < {min_tokens} needed") continue def batch_gen(toks=tokens, bs=batch_size, sl=seq_len): while True: buf = list(toks) while len(buf) >= bs * (sl + 1): batch = [] for _ in range(bs): batch.append(buf[:sl + 1]) buf = buf[sl:] yield mx.array(batch, dtype=mx.int32) yield domain, batch_gen() # ========================================== # 10. LOSS + CHECKPOINT # ========================================== def loss_fn(model, x): """Cross-entropy + load balance auxiliary loss.""" logits = model(x) ce = nn.losses.cross_entropy(logits[:, :-1, :], x[:, 1:], reduction="mean") lb = model.total_load_balance_loss() return ce + model.me_config.load_balance_weight * lb def load_checkpoint(model, path: str): weights = dict(mx.load(path)) meta_path = path.replace(".npz", ".json") with open(meta_path, "r") as f: meta = json.load(f) for i, layer in enumerate(model.layers): moe = layer.moe layer_key = f"layer_{i}" if layer_key not in meta: continue layer_meta = meta[layer_key] for eid in list(moe._expert_id_list): moe._remove_expert(eid) for eid in layer_meta["expert_ids"]: em = layer_meta["experts"][eid] tier = em["tier"] hidden = moe._tier_to_hidden(tier) expert = Expert(moe.model_dim, hidden) mx.eval(expert.parameters()) moe.expert_modules.append(expert) moe._expert_id_list.append(eid) moe._expert_meta[eid] = ExpertMeta( expert_id=eid, tier=tier, hidden_dim=hidden, age=em.get("age", 0), cooldown=em.get("cooldown", 0), frozen_steps=em.get("frozen_steps", 0), ema_interference_fast=em.get("ema_fast", 0.0), ema_interference_slow=em.get("ema_slow", 0.0), ema_interference_var=em.get("ema_var", 1.0), avg_routing_weight=em.get("avg_rw", 0.1), avg_activation_freq=em.get("avg_af", 0.1), parent_id=em.get("parent_id"), generation=em.get("generation", 0), ) if em.get("frozen_steps", 0) > 0: moe._frozen_eids.add(eid) router_key = f"__router__.{i}.{eid}" init_emb = weights.pop(router_key, None) moe.router.add_expert(eid, init_embedding=init_emb) moe._density_ema = layer_meta.get("density_ema", 1.0) remaining = [(k, v) for k, v in weights.items() if not k.startswith("__router__")] model.load_weights(remaining, strict=False) mx.eval(model.parameters()) print(f" Loaded checkpoint from {path}") def get_latest_checkpoint(checkpoint_dir: str): if not os.path.exists(checkpoint_dir): return None, 0 ckpts = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_step_*.npz"))) if not ckpts: return None, 0 latest = ckpts[-1] m = re.search(r"step_(\d+)", latest) return latest, int(m.group(1)) def save_checkpoint(model, step: int, checkpoint_dir: str): path = os.path.join(checkpoint_dir, f"checkpoint_step_{step}.npz") save_dict = {} for k, v in tree_flatten(model.parameters()): save_dict[k] = v for i, layer in enumerate(model.layers): moe = layer.moe for j, eid in enumerate(moe.router._emb_ids): save_dict[f"__router__.{i}.{eid}"] = moe.router.embeddings[j].embedding mx.savez(path, **save_dict) model.save_meta(path.replace(".npz", ".json")) print(f" Saved checkpoint {path}") # ========================================== # 11. TRAINING LOOP # ========================================== def train_loop(model, optimizer, data_iter, tc: TrainConfig, start_step=0, max_steps=30000, lifecycle_every=10, label="train"): loss_and_grad_fn = nn.value_and_grad(model, loss_fn) compiled_loss_and_grad = mx.compile(loss_and_grad_fn) step = start_step tic = time.time() topology_changed = False for batch in data_iter: if step >= max_steps: break model.set_global_step(step) # After a lifecycle event changes the expert topology (add/remove modules), if topology_changed: compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn)) topology_changed = False try: loss, grads = compiled_loss_and_grad(model, batch) except Exception: loss_and_grad_fn_eager = nn.value_and_grad(model, loss_fn) loss, grads = loss_and_grad_fn_eager(model, batch) compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn)) grads = model.zero_frozen_grads(grads) try: optimizer.update(model, grads) except (ValueError, KeyError, IndexError): # Topology change left stale optimizer state — wipe and retry optimizer.state = {k: v for k, v in optimizer.state.items() if not isinstance(v, (dict, list))} optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state, loss) if step > 0 and step % lifecycle_every == 0: events = model.run_lifecycle(optimizer=optimizer) if events: topology_changed = True #optimizer.state = {k: v for k, v in optimizer.state.items() if not isinstance(v, (dict, list))} """ optimizer.update(model, grads) mx.eval(model.parameters(), optimizer.state, loss) """ if step % tc.log_every == 0: toc = time.time() n_exp = sum(len(l.moe._expert_id_list) for l in model.layers) avg_d = sum( l.moe._last_density.mean().item() for l in model.layers if l.moe._last_density is not None ) / model.args.n_layers elapsed = toc - tic tok_per_sec = (tc.log_every * tc.batch_size * model.args.max_seq_len) / max(elapsed, 1e-6) print(f"[{label}] Step {step:6d} | Loss {loss.item():.4f} | " f"Experts {n_exp} | Density {avg_d:.1f} | " f"{tok_per_sec:.0f} tok/s | {elapsed:.2f}s") tic = time.time() if step > 0 and step % tc.summary_every == 0: print(f"\n--- Expert Summary @ step {step} ---") print(model.expert_summary()) print() if step > 0 and step % tc.checkpoint_every == 0: save_checkpoint(model, step, tc.checkpoint_dir) step += 1 return step # ========================================== # 12. INTERACTIVE SETUP + MAIN # ========================================== def prompt_config() -> TrainConfig: """Interactive configuration via input() prompts.""" tc = TrainConfig() print("\n" + "="*60) print(" MicroExperts — Training Configuration") print("="*60) # Mode print(" 1. pretrain — Gutenberg streaming pretraining") print(" 2. active_learning — Sequential domain continual learning(not implemented yet)") print(" 3. inference — Chat with the trained model") print(" 4. interactive_learning — Chat and learn from your inputs") print(" 5. train_and_chat — Train with periodic chat breaks") choice = input("Mode [1]: ").strip() if choice == "2": tc.mode = "active_learning" elif choice == "3": tc.mode = "inference" elif choice == "4": tc.mode = "interactive_learning" elif choice == "5": tc.mode = "train_and_chat" else: tc.mode = "pretrain" # Tokenizer tok = "gutenberg_tokenizer.json" if tok: tc.tokenizer_file = tok # Checkpoint dir cd = input(f"Checkpoint directory [{tc.checkpoint_dir}]: ").strip() if cd: tc.checkpoint_dir = cd # Batch size bs = input(f"Batch size [{tc.batch_size}]: ").strip() if bs: tc.batch_size = int(bs) # Learning rate if tc.mode == "pretrain": default_lr = tc.learning_rate else: default_lr = tc.al_learning_rate lr = input(f"Learning rate [{default_lr}]: ").strip() if lr: tc.learning_rate = float(lr) else: tc.learning_rate = default_lr # Max steps ms = input(f"Max steps [{tc.max_steps}]: ").strip() if ms: tc.max_steps = int(ms) # Resume resume = input("Resume from checkpoint? [Y/n]: ").strip().lower() tc._resume = resume != "n" # Mode-specific if tc.mode == "active_learning": dd = input(f"Domain data directory [{tc.al_data_dir}]: ").strip() if dd: tc.al_data_dir = dd spd = input(f"Steps per domain [{tc.al_steps_per_domain}]: ").strip() if spd: tc.al_steps_per_domain = int(spd) print("\n" + "-"*60) print(f" Mode: {tc.mode}") print(f" LR: {tc.learning_rate}") print(f" Batch: {tc.batch_size}") print(f" Max steps: {tc.max_steps}") print(f" Checkpoint: {tc.checkpoint_dir}") print(f" Resume: {tc._resume}") if tc.mode == "active_learning": print(f" Data dir: {tc.al_data_dir}") print(f" Steps/dom: {tc.al_steps_per_domain}") print(f" M4 budget: 150M params/layer, 128 experts/layer max") print("-"*60) confirm = input("Continue? [Y/n]: ").strip().lower() if confirm == "n": print("Aborted.") exit(0) return tc def generate(model, tokenizer, prompt: str, max_tokens: int = 256, temperature: float = 0.8): tokens = tokenizer.encode(prompt) tokens = mx.array([tokens], dtype=mx.int32) for _ in range(max_tokens): logits = model(tokens) next_logits = logits[:, -1, :] / temperature next_token = mx.random.categorical(next_logits) next_token = next_token.reshape(1, 1) tokens = mx.concatenate([tokens, next_token], axis=1) mx.eval(tokens) token_id = next_token.item() if token_id == tokenizer.eos_token_id: break # Print expert usage per layer print("\n Expert routing:") for i, layer in enumerate(model.layers): moe = layer.moe if moe._last_routing_weights is None: continue rw = moe._last_routing_weights N = rw.shape[-1] # Average routing weight per expert across all tokens avg_w = rw.reshape(-1, N).mean(axis=0) active = (avg_w > 0.01) parts = [] for j, eid in enumerate(moe._expert_id_list): if j < N and active[j].item(): meta = moe._expert_meta.get(eid) tier = meta.tier if meta else "?" parts.append(f"{eid[:6]}(T{tier} w={avg_w[j].item():.3f})") if parts: print(f" L{i:2d}: {' '.join(parts)}") return tokenizer.decode(tokens[0].tolist()) def main(): tc = prompt_config() os.makedirs(tc.checkpoint_dir, exist_ok=True) # Tokenizer print(f"\nLoading tokenizer: {tc.tokenizer_file}") tokenizer = PreTrainedTokenizerFast(tokenizer_file=tc.tokenizer_file) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Model args = ModelArgs() args.vocab_size = len(tokenizer) me_config = MicroExpertConfig() if tc.mode == "active_learning": me_config.split_threshold = tc.al_split_threshold me_config.min_expert_age = tc.al_min_expert_age print(f"Initializing MicroExperts model (vocab={args.vocab_size})...") model = MicroExpertsModel(args, me_config) # Resume current_step = 0 if tc._resume: ckpt, ckpt_step = get_latest_checkpoint(tc.checkpoint_dir) if ckpt: print(f"Resuming from {ckpt} @ step {ckpt_step}") load_checkpoint(model, ckpt) current_step = ckpt_step else: print("No checkpoint found — starting fresh.") mx.eval(model.parameters()) n_params = sum(v.size for _, v in tree_flatten(model.parameters())) print(f"Total params: {n_params / 1e6:.2f}M") print("Initial layout:") print(model.expert_summary()) optimizer = optim.AdamW(learning_rate=tc.learning_rate) # ---- PRETRAIN ---- if tc.mode == "pretrain": data = stream_gutenberg(tokenizer, tc.batch_size, args.max_seq_len) print(f"\nStarting pretraining for {tc.max_steps} steps...") final_step = train_loop( model, optimizer, data, tc, start_step=current_step, max_steps=tc.max_steps, lifecycle_every=tc.lifecycle_every, label="pretrain", ) elif tc.mode == "inference": print("\nChat ready. Type 'quit' to exit.\n") while True: user_input = input("You: ").strip() if user_input.lower() in ("quit", "exit"): break if not user_input: continue response = generate(model, tokenizer, user_input) print(f"Model: {response}\n") final_step = current_step # ---- ACTIVE LEARNING ---- elif tc.mode == "active_learning": lifecycle_every = tc.al_lifecycle_every print(f"\nActive learning from: {tc.al_data_dir}") print(f" Steps/domain: {tc.al_steps_per_domain} | Lifecycle every: {lifecycle_every}") domain_gen = stream_domain_files( tokenizer, tc.al_data_dir, tc.batch_size, args.max_seq_len) global_step = current_step for domain_name, batches in domain_gen: domain_max = global_step + tc.al_steps_per_domain n_before = sum(len(l.moe._expert_id_list) for l in model.layers) print(f"\n Training '{domain_name}': steps {global_step} -> {domain_max}") global_step = train_loop( model, optimizer, batches, tc, start_step=global_step, max_steps=domain_max, lifecycle_every=lifecycle_every, label=f"AL:{domain_name}", ) n_after = sum(len(l.moe._expert_id_list) for l in model.layers) print(f"\n '{domain_name}' done. Experts: {n_before} -> {n_after} ({n_after-n_before:+d})") print(model.expert_summary()) final_step = global_step elif tc.mode == "interactive_learning": if not tc._resume: print("WARNING: No checkpoint loaded, model is random.") il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate) il_step = current_step conversation_tokens = [] message_count = 0 print("\nInteractive learning ready. Type 'quit' to exit.") print("The model learns from the conversation.\n") while True: user_input = input("You: ").strip() if user_input.lower() in ("quit", "exit"): break if not user_input: continue response = generate(model, tokenizer, user_input) print(f"Model: {response}\n") conversation_tokens.extend(tokenizer.encode(user_input)) conversation_tokens.extend(tokenizer.encode(response)) message_count += 1 seq_len = model.args.max_seq_len trained = False # Train on full sequences when available while len(conversation_tokens) >= seq_len + 1: batch = mx.array([conversation_tokens[:seq_len + 1]], dtype=mx.int32) conversation_tokens = conversation_tokens[seq_len:] loss_and_grad_fn = nn.value_and_grad(model, loss_fn) loss, grads = loss_and_grad_fn(model, batch) grads = model.zero_frozen_grads(grads) il_optimizer.update(model, grads) mx.eval(model.parameters(), il_optimizer.state, loss) il_step += 1 model.set_global_step(il_step) trained = True print(f" [learned: loss={loss.item():.4f}, step={il_step}]") # Force train every 2 messages even with partial sequence if not trained and message_count % 2 == 0 and len(conversation_tokens) > 2: pad_len = seq_len + 1 tokens_to_use = conversation_tokens[-pad_len:] if len(conversation_tokens) >= pad_len else conversation_tokens # Pad if too short while len(tokens_to_use) < pad_len: tokens_to_use = tokens_to_use + tokens_to_use tokens_to_use = tokens_to_use[:pad_len] batch = mx.array([tokens_to_use], dtype=mx.int32) loss_and_grad_fn = nn.value_and_grad(model, loss_fn) loss, grads = loss_and_grad_fn(model, batch) grads = model.zero_frozen_grads(grads) il_optimizer.update(model, grads) mx.eval(model.parameters(), il_optimizer.state, loss) il_step += 1 model.set_global_step(il_step) print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]") # Lifecycle check if il_step > 0 and il_step % tc.al_lifecycle_every == 0: events = model.run_lifecycle() if events: il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))} print(model.expert_summary()) save_checkpoint(model, il_step, tc.checkpoint_dir) print("Model saved.") final_step = il_step elif tc.mode == "train_and_chat": if not tc._resume: print("WARNING: No checkpoint loaded, model is random.") il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate) il_step = current_step conversation_tokens = [] message_count = 0 system_prompt = "You are a helpful assistant." chat_history = [] print("\nChat Learning ready. Type 'quit' to exit.") print("The model learns from the conversation with chat format.\n") while True: user_input = input("You: ").strip() if user_input.lower() in ("quit", "exit"): break if not user_input: continue response = generate(model, tokenizer, user_input) print(f"Model: {response}\n") # Build chat-formatted training text chat_history.append({"role": "user", "content": user_input}) chat_history.append({"role": "assistant", "content": response}) chat_text = f"system\n{system_prompt}\n" for msg in chat_history: role = "human" if msg["role"] == "user" else "gpt" chat_text += f"{role}\n{msg['content']}\n" conversation_tokens = tokenizer.encode(chat_text) message_count += 1 seq_len = model.args.max_seq_len trained = False # Train on full sequences from chat history train_tokens = list(conversation_tokens) while len(train_tokens) >= seq_len + 1: batch = mx.array([train_tokens[:seq_len + 1]], dtype=mx.int32) train_tokens = train_tokens[seq_len:] loss_and_grad_fn = nn.value_and_grad(model, loss_fn) loss, grads = loss_and_grad_fn(model, batch) grads = model.zero_frozen_grads(grads) try: il_optimizer.update(model, grads) except (ValueError, KeyError, IndexError): il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))} il_optimizer.update(model, grads) mx.eval(model.parameters(), il_optimizer.state, loss) il_step += 1 model.set_global_step(il_step) trained = True print(f" [learned: loss={loss.item():.4f}, step={il_step}]") # Force train every 2 messages even with partial sequence if not trained and message_count % 2 == 0 and len(train_tokens) > 2: pad_len = seq_len + 1 tokens_to_use = train_tokens[-pad_len:] if len(train_tokens) >= pad_len else train_tokens while len(tokens_to_use) < pad_len: tokens_to_use = tokens_to_use + tokens_to_use tokens_to_use = tokens_to_use[:pad_len] batch = mx.array([tokens_to_use], dtype=mx.int32) loss_and_grad_fn = nn.value_and_grad(model, loss_fn) loss, grads = loss_and_grad_fn(model, batch) grads = model.zero_frozen_grads(grads) try: il_optimizer.update(model, grads) except (ValueError, KeyError, IndexError): il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))} il_optimizer.update(model, grads) mx.eval(model.parameters(), il_optimizer.state, loss) il_step += 1 model.set_global_step(il_step) print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]") # Trim chat history if too long max_history = 20 if len(chat_history) > max_history: chat_history = chat_history[-max_history:] # Lifecycle check if il_step > 0 and il_step % tc.al_lifecycle_every == 0: events = model.run_lifecycle(optimizer=il_optimizer) if events: pass # optimizer state already rebuilt in lifecycle print(model.expert_summary()) save_checkpoint(model, il_step, tc.checkpoint_dir) print("Model saved.") final_step = il_step # Save final print("\nTraining complete.") save_checkpoint(model, final_step, tc.checkpoint_dir) print("Final layout:") print(model.expert_summary()) if __name__ == "__main__": main()