| """ |
| MicroExperts — Self-organizing dynamic Mixture-of-Experts for continual learning. |
| |
| |
| Target hardware: Apple M4 with 48 GB unified memory. |
| """ |
|
|
| import time |
| import math |
| import uuid |
| import json |
| import numpy as np |
| import mlx.core as mx |
| import mlx.nn as nn |
| import mlx.optimizers as optim |
| from mlx.utils import tree_flatten |
| from datasets import load_dataset |
| from transformers import PreTrainedTokenizerFast |
| import os |
| import glob |
| import re |
| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional, Tuple, Any |
| from collections import defaultdict |
|
|
|
|
|
|
def one_hot(indices: mx.array, num_classes: int) -> mx.array:
    """One-hot encode integer indices along a new trailing num_classes axis."""
    flat = indices.reshape(-1)
    arange = mx.arange(num_classes)
    oh = (flat[:, None] == arange[None, :]).astype(mx.float32)
    return oh.reshape(*indices.shape, num_classes)
|
|
| |
| |
| |
| @dataclass |
| class ModelArgs: |
| dim: int = 768 |
| n_layers: int = 12 |
| n_heads: int = 12 |
| n_kv_heads: int = 12 |
| vocab_size: int = -1 |
| norm_eps: float = 1e-8 |
| max_seq_len: int = 2048 |
| rope_theta: float = 10000.0 |
|
|
|
|
| @dataclass |
| class MicroExpertConfig: |
| """All hyperparameters for the MicroExperts MoE system.""" |
| |
| tier_hidden_dims: Tuple[int, ...] = (256, 512, 1024, 2048) |
|
|
| monolith_split_enabled: bool = True |
| monolith_variance_ema_alpha: float = 0.02 |
| monolith_variance_z_threshold: float = 1.5 |
|
|
| |
| router_embed_dim: int = 128 |
| min_experts_per_token: int = 1 |
| max_experts_per_token: int = 64 |
|
|
| |
| ema_fast_alpha: float = 0.05 |
| ema_slow_alpha: float = 0.005 |
| split_threshold: float = 2.0 |
| |
| merge_co_route_threshold: float = 0.5 |
| merge_weakness_threshold: float = 0.05 |
| death_threshold: float = 0.001 |
| min_expert_age: int = 50 |
| cooldown_steps: int = 100 |
| |
| preserver_base_freeze_steps: int = 100 |
| preserver_max_freeze_steps: int = 200 |
| adapter_noise_scale: float = 0.02 |
|
|
|
|
| max_experts_per_layer: int = 12 |
| max_params_per_layer: int = 20_000_000 |
|
|
| |
| init_tier: int = 2 |
|
|
| |
| interference_subsample: int = 64 |
|
|
| |
| load_balance_weight: float = 0.01 |
|
|
| |
| merge_capacity_pressure_frac: float = 0.8 |
| |
| merge_tier_gravity_co_route: float = 0.4 |
| merge_tier_gravity_min_co_activation: float = 0.3 |
|
|
|
|
| density_ema_alpha: float = 0.02 |
| density_spike_z: float = 2.5 |
|
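# Tier semantics (see tier_hidden_dims above): tier t maps to a SwiGLU hidden width of
# tier_hidden_dims[min(t, 3)], i.e. 256 / 512 / 1024 / 2048 for tiers 0-3. Splits keep
# the parent's tier; merges bump the merged expert's tier by one, capped at the top tier.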
|
|
|
| @dataclass |
| class TrainConfig: |
| """Training hyperparameters.""" |
| mode: str = "pretrain" |
| batch_size: int = 8 |
| learning_rate: float = 3e-4 |
| max_steps: int = 30_000 |
| tokenizer_file: str = "gutenberg_tokenizer.json" |
| checkpoint_dir: str = "checkpoints_me" |
| log_every: int = 10 |
| summary_every: int = 500 |
| checkpoint_every: int = 1000 |
| lifecycle_every: int = 10 |
|
|
| |
| al_data_dir: str = "./domains" |
| al_steps_per_domain: int = 2000 |
| al_learning_rate: float = 1e-4 |
| al_lifecycle_every: int = 5 |
| al_split_threshold: float = 1.5 |
| al_min_expert_age: int = 100 |
|
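# Non-interactive use (illustrative sketch -- this file only builds TrainConfig through
# the prompts in prompt_config(), so the direct construction below is an assumption):
#
#   tc = TrainConfig(mode="pretrain", batch_size=4, max_steps=5_000)
#   tc._resume = False   # prompt_config() attaches this ad-hoc attribute; mirror it here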
|
|
|
| |
| |
| |
| class Expert(nn.Module): |
| """Single MicroExpert: SwiGLU FFN.""" |
|
|
| def __init__(self, model_dim: int, hidden_dim: int): |
| super().__init__() |
| self.w1 = nn.Linear(model_dim, hidden_dim, bias=False) |
| self.w2 = nn.Linear(hidden_dim, model_dim, bias=False) |
| self.w3 = nn.Linear(model_dim, hidden_dim, bias=False) |
|
|
| def __call__(self, x): |
| return self.w2(nn.silu(self.w1(x)) * self.w3(x)) |
|
|
|
|
| |
| |
| |
| @dataclass |
| class ExpertMeta: |
| """Non-parameter state for one expert.""" |
| expert_id: str |
| tier: int |
| hidden_dim: int |
| age: int = 0 |
| cooldown: int = 0 |
| frozen_steps: int = 0 |
| ema_interference_fast: float = 0.0 |
| ema_interference_slow: float = 0.0 |
| ema_interference_var: float = 1.0 |
| avg_routing_weight: float = 0.1 |
| avg_activation_freq: float = 0.1 |
| parent_id: Optional[str] = None |
| generation: int = 0 |
|
|
| def to_dict(self) -> dict: |
| return { |
| "expert_id": self.expert_id, "tier": self.tier, |
| "hidden_dim": self.hidden_dim, "age": self.age, |
| "cooldown": self.cooldown, "frozen_steps": self.frozen_steps, |
| "ema_fast": self.ema_interference_fast, |
| "ema_slow": self.ema_interference_slow, |
| "ema_var": self.ema_interference_var, |
| "avg_rw": self.avg_routing_weight, |
| "avg_af": self.avg_activation_freq, |
| "parent_id": self.parent_id, "generation": self.generation, |
| } |
|
|
|
|
| |
| |
| |
class ExpertEmbedding(nn.Module):
    """Learnable routing key for one expert: a single vector of size router_embed_dim."""

    def __init__(self, dim: int, init: Optional[mx.array] = None):
        super().__init__()
        if init is not None:
            self.embedding = init
        else:
            scale = 1.0 / math.sqrt(dim)
            self.embedding = mx.random.normal((dim,)) * scale
|
|
|
|
| |
| |
| |
class AdaptiveRouter(nn.Module):
    """Cosine-similarity router with a learned per-token activation threshold.

    Tokens are projected into a routing space and scored against per-expert
    embeddings; every expert whose score clears the token's threshold is active
    (always at least the best one, and at most max_experts_per_token).
    """
|
| def __init__(self, model_dim: int, config: MicroExpertConfig): |
| super().__init__() |
| self.config = config |
| self.d = config.router_embed_dim |
| self.proj = nn.Linear(model_dim, self.d, bias=False) |
| self.threshold_head = nn.Linear(model_dim, 1, bias=True) |
|
|
| |
| self.embeddings: List[ExpertEmbedding] = [] |
| |
| self._emb_ids: List[str] = [] |
|
|
| def _id_to_idx(self, eid: str) -> int: |
| return self._emb_ids.index(eid) |
|
|
| def add_expert(self, expert_id: str, init_embedding: Optional[mx.array] = None): |
| emb = ExpertEmbedding(self.d, init=init_embedding) |
| mx.eval(emb.parameters()) |
| self.embeddings.append(emb) |
| self._emb_ids.append(expert_id) |
|
|
| def remove_expert(self, expert_id: str): |
| if expert_id not in self._emb_ids: |
| return |
| idx = self._id_to_idx(expert_id) |
| self.embeddings.pop(idx) |
| self._emb_ids.pop(idx) |
|
|
| def get_embedding(self, expert_id: str) -> mx.array: |
| return self.embeddings[self._id_to_idx(expert_id)].embedding |
|
|
| def set_embedding(self, expert_id: str, emb: mx.array): |
| self.embeddings[self._id_to_idx(expert_id)].embedding = emb |
|
|
| def __call__(self, x: mx.array, expert_ids: List[str]): |
| """ |
| Returns: |
| routing_weights: (B, L, N) sparse softmax-normalized |
| raw_scores: (B, L, N) cosine similarities |
| density: (B, L) active expert count per token |
| """ |
| B, L, D = x.shape |
| N = len(expert_ids) |
|
|
| if N == 0: |
| z = mx.zeros((B, L, 1)) |
| return z[:, :, :0], z[:, :, :0], mx.zeros((B, L)) |
|
|
| |
| h = self.proj(x) |
| h_norm = h / (mx.linalg.norm(h, axis=-1, keepdims=True) + 1e-8) |
|
|
| |
| E = mx.stack([self.embeddings[self._emb_ids.index(eid)].embedding |
| for eid in expert_ids], axis=0) |
| E_norm = E / (mx.linalg.norm(E, axis=-1, keepdims=True) + 1e-8) |
|
|
| raw_scores = h_norm @ E_norm.T |
|
|
| |
| threshold = mx.sigmoid(self.threshold_head(x)) |
| gate_mask = (raw_scores > threshold).astype(mx.float32) |
|
|
| |
| best_idx = mx.argmax(raw_scores, axis=-1) |
| best_oh = one_hot(best_idx, N) |
| gate_mask = mx.maximum(gate_mask, best_oh) |
|
|
| |
| max_k = self.config.max_experts_per_token |
| if max_k < N: |
| sorted_idx = mx.argsort(-raw_scores, axis=-1) |
| rank = mx.argsort(sorted_idx, axis=-1) |
| gate_mask = gate_mask * (rank < max_k).astype(mx.float32) |
|
|
| |
| masked = raw_scores * gate_mask + (1.0 - gate_mask) * (-1e9) |
| routing_weights = mx.softmax(masked, axis=-1) * gate_mask |
|
|
| density = gate_mask.sum(axis=-1) |
| return routing_weights, raw_scores, density |
|
|
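# Routing example (illustrative sketch; the expert ids and shapes below are made up):
#
#   router = AdaptiveRouter(768, MicroExpertConfig())
#   router.add_expert("expert-a")
#   router.add_expert("expert-b")
#   x = mx.random.normal((2, 16, 768))
#   weights, scores, density = router(x, ["expert-a", "expert-b"])
#   # weights: (2, 16, 2) sparse, summing to 1 over the active experts per token
#   # scores:  (2, 16, 2) cosine similarities; density: (2, 16) active-expert counts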
|
|
| |
| |
| |
| def _zero_tree(tree): |
| """Recursively zero all mx.arrays in a nested structure.""" |
| if isinstance(tree, mx.array): |
| return mx.zeros_like(tree) |
| elif isinstance(tree, dict): |
| return {k: _zero_tree(v) for k, v in tree.items()} |
| elif isinstance(tree, list): |
| return [_zero_tree(v) for v in tree] |
| return tree |
|
|
|
|
| |
| |
| |
class MicroExpertsMoELayer(nn.Module):
    """Dynamic MoE layer: routes tokens over a mutable pool of Experts and runs the
    split / merge / death lifecycle that grows and shrinks that pool."""
|
| def __init__(self, model_dim: int, config: MicroExpertConfig, layer_idx: int): |
| super().__init__() |
| self.model_dim = model_dim |
| self.config = config |
| self.layer_idx = layer_idx |
| self.router = AdaptiveRouter(model_dim, config) |
| self._variance_ema: Dict[str, float] = {} |
| self._variance_ema_sq: Dict[str, float] = {} |
|
|
| |
| self.expert_modules: List[Expert] = [] |
| self._expert_id_list: List[str] = [] |
| self._expert_meta: Dict[str, ExpertMeta] = {} |
| self._lifecycle_log: List[str] = [] |
| self.global_step: int = 0 |
|
|
| |
| self._last_routing_weights: Optional[mx.array] = None |
| self._last_density: Optional[mx.array] = None |
| self._last_input: Optional[mx.array] = None |
| |
| self._last_expert_outputs: Optional[List[mx.array]] = None |
|
|
| |
| self._frozen_eids: set = set() |
|
|
| |
| self._density_ema: float = 1.0 |
| self._density_var: float = 1.0 |
| self._drift_detected: bool = False |
|
|
| |
| self._create_expert(tier=config.init_tier) |
|
|
| |
| @property |
| def expert_ids(self) -> List[str]: |
| return list(self._expert_id_list) |
|
|
| def _eid_to_index(self, eid: str) -> int: |
| return self._expert_id_list.index(eid) |
|
|
| def _get_expert(self, eid: str) -> Expert: |
| return self.expert_modules[self._eid_to_index(eid)] |
|
|
| def _tier_to_hidden(self, tier: int) -> int: |
| t = min(tier, len(self.config.tier_hidden_dims) - 1) |
| return self.config.tier_hidden_dims[t] |
|
|
| def _expert_param_count(self, tier: int) -> int: |
| return 3 * self.model_dim * self._tier_to_hidden(tier) |
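    # Worked example with the defaults: init_tier=2 -> hidden 1024, so one expert costs
    # 3 * 768 * 1024 = 2,359,296 weights; max_params_per_layer = 20M therefore admits
    # roughly eight tier-2 experts per layer before the budget checks refuse further splits.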
|
|
| def _total_params(self) -> int: |
| return sum(self._expert_param_count(m.tier) for m in self._expert_meta.values()) |
|
|
| def _make_id(self) -> str: |
| return uuid.uuid4().hex[:12] |
| |
| """ |
| def _copy_optimizer_state(self, optimizer, parent_idx: int, child_eid: str): |
| try: |
| layers_state = optimizer.state.get("layers", []) |
| if self.layer_idx >= len(layers_state): |
| return |
| moe_state = layers_state[self.layer_idx].get("moe", {}) |
| expert_states = moe_state.get("expert_modules", []) |
| if parent_idx >= len(expert_states): |
| return |
| |
| parent_state = expert_states[parent_idx] |
| child_idx = self._eid_to_index(child_eid) |
| |
| # Grow the list if needed |
| while len(expert_states) <= child_idx: |
| expert_states.append({}) |
| |
| # Deep copy the parent state |
| import copy |
| expert_states[child_idx] = copy.deepcopy(parent_state) |
| except (KeyError, IndexError, TypeError): |
| pass |
| """ |
    def _copy_optimizer_state(self, optimizer, parent_idx: int, children_eids: list):
        """Copy the parent's optimizer state to its children, then rebuild the list.
        (Utility; the live split/merge/death paths below currently rebuild state inline.)"""
| try: |
| layers_state = optimizer.state.get("layers", []) |
| if self.layer_idx >= len(layers_state): |
| return |
| moe_state = layers_state[self.layer_idx].get("moe", {}) |
| expert_states = moe_state.get("expert_modules", []) |
| if parent_idx >= len(expert_states): |
| return |
| |
| import copy |
| parent_state = copy.deepcopy(expert_states[parent_idx]) |
| |
| |
| new_states = [] |
| for i, eid in enumerate(self._expert_id_list): |
| if eid in children_eids: |
| new_states.append(copy.deepcopy(parent_state)) |
| elif i < len(expert_states): |
| new_states.append(expert_states[i]) |
| else: |
| new_states.append({}) |
| |
| moe_state["expert_modules"] = new_states |
| except (KeyError, IndexError, TypeError): |
| pass |
|
|
| |
| def _create_expert( |
| self, tier: int, |
| parent_id: Optional[str] = None, |
| init_weights_from: Optional[Expert] = None, |
| noise_scale: float = 0.0, |
| frozen_steps: int = 0, |
| init_embedding: Optional[mx.array] = None, |
| ) -> str: |
| eid = self._make_id() |
| hidden = self._tier_to_hidden(tier) |
| expert = Expert(self.model_dim, hidden) |
|
|
| if init_weights_from is not None: |
| src = dict(tree_flatten(init_weights_from.parameters())) |
| dst = dict(tree_flatten(expert.parameters())) |
| pairs = [] |
| for k in dst: |
| if k in src and src[k].shape == dst[k].shape: |
| w = src[k] |
| if noise_scale > 0: |
| w = w + mx.random.normal(w.shape) * noise_scale * (mx.abs(w).mean() + 1e-8) |
| pairs.append((k, w)) |
| if pairs: |
| expert.load_weights(pairs) |
|
|
| mx.eval(expert.parameters()) |
|
|
| self.expert_modules.append(expert) |
| self._expert_id_list.append(eid) |
|
|
| gen = 0 |
| if parent_id and parent_id in self._expert_meta: |
| gen = self._expert_meta[parent_id].generation + 1 |
|
|
| self._expert_meta[eid] = ExpertMeta( |
| expert_id=eid, tier=tier, hidden_dim=hidden, |
| frozen_steps=frozen_steps, parent_id=parent_id, generation=gen, |
| ) |
| if frozen_steps > 0: |
| self._frozen_eids.add(eid) |
|
|
| self.router.add_expert(eid, init_embedding=init_embedding) |
| return eid |
|
|
    def _remove_expert(self, eid: str):
        if eid not in self._expert_id_list:
            return
        idx = self._eid_to_index(eid)
        self.expert_modules.pop(idx)
        self._expert_id_list.pop(idx)
        self._expert_meta.pop(eid, None)
        self._frozen_eids.discard(eid)
        # Drop any monolith-variance statistics tracked for this expert.
        self._variance_ema.pop(eid, None)
        self._variance_ema_sq.pop(eid, None)
        self.router.remove_expert(eid)
|
|
| |
| def __call__(self, x: mx.array) -> mx.array: |
| B, L, D = x.shape |
| N = len(self._expert_id_list) |
| if N == 0: |
| return mx.zeros_like(x) |
|
|
| routing_weights, raw_scores, density = self.router(x, self._expert_id_list) |
|
|
| |
| expert_outputs = [self.expert_modules[i](x) for i in range(N)] |
|
|
| output = mx.zeros_like(x) |
| for i in range(N): |
| w_i = routing_weights[:, :, i:i + 1] |
| output = output + w_i * expert_outputs[i] |
|
|
| |
        # Detached copies for lifecycle statistics (no gradient flows through these).
        self._last_routing_weights = mx.stop_gradient(routing_weights)
        self._last_density = mx.stop_gradient(density)
        self._last_input = mx.stop_gradient(x)
        self._last_expert_outputs = [mx.stop_gradient(eo) for eo in expert_outputs]
        # Differentiable load-balance statistic: variance of each expert's mean routing
        # weight over the batch. Kept separate from the detached copies above so the
        # auxiliary loss in loss_fn() can actually carry gradient to the router.
        self._lb_loss = routing_weights.reshape(-1, N).mean(axis=0).var()
|
|
| return output |
|
|
| |
    def load_balance_loss(self) -> mx.array:
        """
        Variance of per-expert mean routing weight across the last batch.
        Penalizes uneven usage -- prevents expert starvation without forcing
        uniform routing (which would defeat specialization). Uses the live
        (non-detached) statistic stored in __call__, so it is differentiable.
        """
        if getattr(self, "_lb_loss", None) is None:
            return mx.array(0.0)
        if len(self._expert_id_list) <= 1:
            return mx.array(0.0)
        return self._lb_loss
|
|
| |
| def zero_frozen_grads(self, expert_grads: Any) -> Any: |
| """Zero gradients for the expert_modules subtree of frozen experts.""" |
| if not self._frozen_eids or not isinstance(expert_grads, list): |
| return expert_grads |
| result = [] |
| for i, g in enumerate(expert_grads): |
| eid = self._expert_id_list[i] if i < len(self._expert_id_list) else None |
| if eid and eid in self._frozen_eids: |
| result.append(_zero_tree(g)) |
| else: |
| result.append(g) |
| return result |
|
|
    def update_density_drift(self):
        """Update the routing-density EMA and flag distribution-shift spikes."""
| if self._last_density is None: |
| return |
| cfg = self.config |
| current = self._last_density.mean().item() |
| alpha = cfg.density_ema_alpha |
|
|
| |
| old_ema = self._density_ema |
| self._density_ema = (1 - alpha) * self._density_ema + alpha * current |
| diff = current - old_ema |
| self._density_var = (1 - alpha) * self._density_var + alpha * diff * diff |
|
|
| |
| std = math.sqrt(max(self._density_var, 1e-8)) |
| z = (current - self._density_ema) / std |
| self._drift_detected = z > cfg.density_spike_z |
|
|
| if self._drift_detected: |
| msg = (f"[step {self.global_step}][L{self.layer_idx}] " |
| f"DRIFT density={current:.1f} ema={self._density_ema:.1f} z={z:.1f}") |
| self._lifecycle_log.append(msg) |
| print(msg) |
|
|
    def compute_interference(self) -> Dict[str, float]:
        """Per-expert interference: routing-weighted relative distance between the
        combined MoE output and the expert's own output, over tokens it served."""
| if (self._last_routing_weights is None or self._last_input is None |
| or self._last_expert_outputs is None): |
| return {} |
|
|
| x = self._last_input |
| rw = self._last_routing_weights |
| B, L, D = x.shape |
| N = len(self._expert_id_list) |
| if N == 0: |
| return {} |
|
|
| T = min(self.config.interference_subsample, B * L) |
| rw_flat = rw.reshape(-1, N)[:T] |
|
|
| |
| expert_outs_flat = [eo.reshape(-1, D)[:T] for eo in self._last_expert_outputs] |
|
|
| |
| combined = mx.zeros((T, D)) |
| for i in range(N): |
| combined = combined + rw_flat[:, i:i + 1] * expert_outs_flat[i] |
| combined = mx.stop_gradient(combined) |
|
|
| interference = {} |
| for i in range(N): |
| eid = self._expert_id_list[i] |
| w_i = rw_flat[:, i] |
| e_out = expert_outs_flat[i] |
| active = (w_i > 0.01).astype(mx.float32) |
| n_active = active.sum().item() |
| if n_active < 1.0: |
| interference[eid] = 0.0 |
| continue |
| diff_norm = mx.linalg.norm(combined - e_out, axis=-1) |
| e_norm = mx.linalg.norm(e_out, axis=-1) + 1e-8 |
| relative = diff_norm / e_norm |
| score = (relative * w_i * active).sum() / (n_active + 1e-8) |
| interference[eid] = score.item() |
|
|
| mx.eval(list(interference.values())) |
| return interference |
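    # In symbols (matching the loop above): with routing weight w_i(t), expert output
    # y_i(t), combined output y(t) = sum_j w_j(t) * y_j(t), and active set
    # A_i = {t : w_i(t) > 0.01} over the subsampled tokens,
    #   interference_i = (1 / |A_i|) * sum_{t in A_i} w_i(t) * ||y(t) - y_i(t)|| / ||y_i(t)||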
| |
| def _compute_monolith_split_scores(self) -> Dict[str, float]: |
| scores = {} |
| if self._last_expert_outputs is None or not self.config.monolith_split_enabled: |
| return scores |
| cfg = self.config |
| for i, eid in enumerate(self._expert_id_list): |
| if i >= len(self._last_expert_outputs): |
| continue |
| eo = self._last_expert_outputs[i] |
| norms = mx.linalg.norm(eo.reshape(-1, eo.shape[-1]), axis=-1) |
| var = norms.var().item() |
| alpha = cfg.monolith_variance_ema_alpha |
| prev_mean = self._variance_ema.get(eid, var) |
| prev_sq = self._variance_ema_sq.get(eid, var * var) |
| new_mean = (1 - alpha) * prev_mean + alpha * var |
| new_sq = (1 - alpha) * prev_sq + alpha * var * var |
| self._variance_ema[eid] = new_mean |
| self._variance_ema_sq[eid] = new_sq |
| running_std = math.sqrt(max(new_sq - new_mean * new_mean, 1e-8)) |
| z = (var - new_mean) / running_std |
| scores[eid] = z |
| return scores |
|
|
| |
    def lifecycle_step(self, optimizer=None):
        """One lifecycle pass: refresh statistics, then split, kill, or merge experts."""
|
|
        self.update_density_drift()
|
|
| interference = self.compute_interference() |
| events = [] |
| all_ids = list(self._expert_id_list) |
|
|
|
|
| monolith_scores = self._compute_monolith_split_scores() |
| N = len(all_ids) |
|
|
| for eid in all_ids: |
| meta = self._expert_meta.get(eid) |
| if meta is None: |
| continue |
| meta.age += 1 |
| if meta.cooldown > 0: |
| meta.cooldown -= 1 |
| if meta.frozen_steps > 0: |
| meta.frozen_steps -= 1 |
| if meta.frozen_steps == 0: |
| self._frozen_eids.discard(eid) |
|
|
| |
| if self._last_routing_weights is not None and eid in self._expert_id_list: |
| idx = self._eid_to_index(eid) |
| if idx < self._last_routing_weights.shape[-1]: |
| w = self._last_routing_weights[:, :, idx] |
| meta.avg_routing_weight = ( |
| 0.95 * meta.avg_routing_weight + 0.05 * w.mean().item() |
| ) |
| meta.avg_activation_freq = ( |
| 0.95 * meta.avg_activation_freq |
| + 0.05 * (w > 0.01).astype(mx.float32).mean().item() |
| ) |
|
|
| |
| intf = interference.get(eid, 0.0) |
| af = self.config.ema_fast_alpha |
| asl = self.config.ema_slow_alpha |
| meta.ema_interference_fast = (1 - af) * meta.ema_interference_fast + af * intf |
| meta.ema_interference_slow = (1 - asl) * meta.ema_interference_slow + asl * intf |
| diff = intf - meta.ema_interference_slow |
| meta.ema_interference_var = 0.99 * meta.ema_interference_var + 0.01 * diff * diff |
|
|
| |
| scored = [] |
| for eid in all_ids: |
| meta = self._expert_meta.get(eid) |
| if meta is None or eid not in self._expert_id_list: |
| continue |
| std = math.sqrt(max(meta.ema_interference_var, 1e-8)) |
| intf_z = (meta.ema_interference_fast - meta.ema_interference_slow) / std |
| mono_z = monolith_scores.get(eid, 0.0) |
| if N <= 2: |
| z = mono_z |
| else: |
| z = max(intf_z, mono_z) |
| scored.append((eid, z, meta)) |
| scored.sort(key=lambda t: -t[1]) |
|
|
| |
| effective_split_threshold = self.config.split_threshold |
| if self._drift_detected: |
| effective_split_threshold *= 0.7 |
|
|
| |
| touched = set() |
| for eid, z_score, meta in scored: |
| if eid in touched or eid not in self._expert_id_list: |
| continue |
| if meta.age < self.config.min_expert_age or meta.cooldown > 0: |
| continue |
| budget_usage = self._total_params() / self.config.max_params_per_layer |
| if budget_usage > 0.7: |
| continue |
|
|
| threshold = self.config.monolith_variance_z_threshold if N <= 2 else effective_split_threshold |
| if (z_score > threshold |
| and len(self._expert_id_list) < self.config.max_experts_per_layer |
| and (self._total_params() + self._expert_param_count(meta.tier) |
| < self.config.max_params_per_layer)): |
| events.append(self._do_split(eid,optimizer=optimizer)) |
| touched.add(eid) |
| continue |
|
|
| if (meta.avg_routing_weight < self.config.death_threshold |
| and len(self._expert_id_list) > 1): |
| events.append(self._do_death(eid, optimizer=optimizer)) |
| touched.add(eid) |
| continue |
|
|
| events.extend(self._check_merges(touched, optimizer=optimizer)) |
|
|
| for e in events: |
| msg = f"[step {self.global_step}][L{self.layer_idx}] {e}" |
| self._lifecycle_log.append(msg) |
| print(msg) |
| return events |
|
|
| |
| def _compute_freeze_steps(self, meta: ExpertMeta) -> int: |
| cfg = self.config |
| importance = max(0.0, min(1.0, meta.avg_routing_weight * 10.0)) |
| freeze = int( |
| cfg.preserver_base_freeze_steps |
| + importance * (cfg.preserver_max_freeze_steps - cfg.preserver_base_freeze_steps) |
| ) |
| return freeze |
|
|
|
|
| """ |
| def _do_split(self, eid: str) -> str: |
| meta = self._expert_meta[eid] |
| parent = self._get_expert(eid) |
| parent_emb = self.router.get_embedding(eid) |
| |
| freeze_steps = self._compute_freeze_steps(meta) |
| |
| preserver_id = self._create_expert( |
| tier=meta.tier, parent_id=eid, |
| init_weights_from=parent, noise_scale=0.0, |
| frozen_steps=freeze_steps, |
| init_embedding=parent_emb, |
| ) |
| |
| adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1 |
| mx.eval(adapter_emb) |
| adapter_id = self._create_expert( |
| tier=meta.tier, parent_id=eid, |
| init_weights_from=parent, |
| noise_scale=self.config.adapter_noise_scale, |
| frozen_steps=0, init_embedding=adapter_emb, |
| ) |
| |
| self._remove_expert(eid) |
| self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps |
| self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps |
| |
| return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> " |
| f"preserver {preserver_id[:8]} (frozen={freeze_steps}) " |
| f"+ adapter {adapter_id[:8]}") |
| """ |
| """ |
| def _do_split(self, eid: str, optimizer=None) -> str: |
| meta = self._expert_meta[eid] |
| parent = self._get_expert(eid) |
| parent_emb = self.router.get_embedding(eid) |
| parent_idx = self._eid_to_index(eid) |
| |
| |
| parent_opt_state = None |
| parent_emb_opt_state = None |
| if optimizer is not None: |
| try: |
| import copy |
| layers_state = optimizer.state.get("layers", []) |
| moe_state = layers_state[self.layer_idx].get("moe", {}) |
| expert_states = moe_state.get("expert_modules", []) |
| if parent_idx < len(expert_states): |
| parent_opt_state = copy.deepcopy(expert_states[parent_idx]) |
| # Save parent router embedding state |
| router_state = moe_state.get("router", {}) |
| emb_states = router_state.get("embeddings", []) |
| if parent_idx < len(emb_states): |
| parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx]) |
| except (KeyError, IndexError, TypeError): |
| pass |
| |
| |
| freeze_steps = self._compute_freeze_steps(meta) |
| |
| preserver_id = self._create_expert( |
| tier=meta.tier, parent_id=eid, |
| init_weights_from=parent, noise_scale=0.0, |
| frozen_steps=freeze_steps, |
| init_embedding=parent_emb, |
| ) |
| |
| adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1 |
| mx.eval(adapter_emb) |
| adapter_id = self._create_expert( |
| tier=meta.tier, parent_id=eid, |
| init_weights_from=parent, |
| noise_scale=self.config.adapter_noise_scale, |
| frozen_steps=0, init_embedding=adapter_emb, |
| ) |
| |
| # Copy optimizer state before removing parent |
| |
| if optimizer is not None: |
| self._copy_optimizer_state(optimizer, parent_idx, preserver_id) |
| self._copy_optimizer_state(optimizer, parent_idx, adapter_id) |
| |
| self._remove_expert(eid) |
| |
| if optimizer is not None and parent_opt_state is not None: |
| try: |
| import copy |
| layers_state = optimizer.state["layers"] |
| moe_state = layers_state[self.layer_idx]["moe"] |
| old_states = moe_state.get("expert_modules", []) |
| |
| new_states = [] |
| for i, expert_eid in enumerate(self._expert_id_list): |
| if expert_eid == preserver_id or expert_eid == adapter_id: |
| new_states.append(copy.deepcopy(parent_opt_state)) |
| elif i < len(old_states): |
| new_states.append(old_states[i]) |
| else: |
| new_states.append({}) |
| |
| moe_state["expert_modules"] = new_states |
| except (KeyError, IndexError, TypeError): |
| pass |
| |
| |
| |
| if optimizer is not None: |
| try: |
| layers_state = optimizer.state.get("layers", []) |
| expert_states = layers_state[self.layer_idx]["moe"]["expert_modules"] |
| if parent_idx < len(expert_states): |
| expert_states.pop(parent_idx) |
| except (KeyError, IndexError, TypeError): |
| pass |
| |
| self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps |
| self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps |
| |
| return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> " |
| f"preserver {preserver_id[:8]} (frozen={freeze_steps}) " |
| f"+ adapter {adapter_id[:8]}") |
| |
| """ |
| def _do_split(self, eid: str, optimizer=None) -> str: |
| meta = self._expert_meta[eid] |
| parent = self._get_expert(eid) |
| parent_emb = self.router.get_embedding(eid) |
| parent_idx = self._eid_to_index(eid) |
|
|
| parent_opt_state = None |
| parent_emb_opt_state = None |
| if optimizer is not None: |
| try: |
| import copy |
| layers_state = optimizer.state.get("layers", []) |
| moe_state = layers_state[self.layer_idx].get("moe", {}) |
| expert_states = moe_state.get("expert_modules", []) |
| if parent_idx < len(expert_states): |
| parent_opt_state = copy.deepcopy(expert_states[parent_idx]) |
| router_state = moe_state.get("router", {}) |
| emb_states = router_state.get("embeddings", []) |
| if parent_idx < len(emb_states): |
| parent_emb_opt_state = copy.deepcopy(emb_states[parent_idx]) |
| except (KeyError, IndexError, TypeError): |
| pass |
|
|
| freeze_steps = self._compute_freeze_steps(meta) |
|
|
| preserver_id = self._create_expert( |
| tier=meta.tier, parent_id=eid, |
| init_weights_from=parent, noise_scale=0.0, |
| frozen_steps=freeze_steps, |
| init_embedding=parent_emb, |
| ) |
|
|
| adapter_emb = parent_emb + mx.random.normal(parent_emb.shape) * 0.1 |
| mx.eval(adapter_emb) |
| adapter_id = self._create_expert( |
| tier=meta.tier, parent_id=eid, |
| init_weights_from=parent, |
| noise_scale=self.config.adapter_noise_scale, |
| frozen_steps=0, init_embedding=adapter_emb, |
| ) |
|
|
| self._remove_expert(eid) |
|
|
        if optimizer is not None and parent_opt_state is not None:
            try:
                import copy
                layers_state = optimizer.state["layers"]
                moe_state = layers_state[self.layer_idx]["moe"]
                old_states = moe_state.get("expert_modules", [])

                new_states = []
                for i, expert_eid in enumerate(self._expert_id_list):
                    if expert_eid == preserver_id or expert_eid == adapter_id:
                        new_states.append(copy.deepcopy(parent_opt_state))
                    else:
                        # Experts that sat after the removed parent shifted down by one
                        # slot, so read their state from the pre-removal index.
                        old_i = i if i < parent_idx else i + 1
                        new_states.append(old_states[old_i] if old_i < len(old_states) else {})
                moe_state["expert_modules"] = new_states

                router_state = moe_state.get("router", {})
                old_emb_states = router_state.get("embeddings", [])
                new_emb_states = []
                for i, emb_eid in enumerate(self.router._emb_ids):
                    if emb_eid == preserver_id or emb_eid == adapter_id:
                        new_emb_states.append(copy.deepcopy(parent_emb_opt_state)
                                              if parent_emb_opt_state is not None else {})
                    else:
                        old_i = i if i < parent_idx else i + 1
                        new_emb_states.append(old_emb_states[old_i]
                                              if old_i < len(old_emb_states) else {})
                router_state["embeddings"] = new_emb_states
            except (KeyError, IndexError, TypeError):
                pass
|
|
| self._expert_meta[preserver_id].cooldown = self.config.cooldown_steps |
| self._expert_meta[adapter_id].cooldown = self.config.cooldown_steps |
|
|
| return (f"SPLIT {eid[:8]} (T{meta.tier}, w={meta.avg_routing_weight:.4f}) -> " |
| f"preserver {preserver_id[:8]} (frozen={freeze_steps}) " |
| f"+ adapter {adapter_id[:8]}") |
| |
    def _do_death(self, eid: str, optimizer=None) -> str:
        meta = self._expert_meta[eid]
        info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})"
        # Snapshot the ordering before removal so surviving experts keep their
        # optimizer state even though their indices shift down by one.
        old_ids = list(self._expert_id_list)
        self._remove_expert(eid)

        if optimizer is not None:
            try:
                layers_state = optimizer.state.get("layers", [])
                if self.layer_idx < len(layers_state):
                    moe_state = layers_state[self.layer_idx].get("moe", {})
                    old_states = moe_state.get("expert_modules", [])
                    new_states = []
                    for expert_eid in self._expert_id_list:
                        old_idx = old_ids.index(expert_eid) if expert_eid in old_ids else -1
                        if 0 <= old_idx < len(old_states):
                            new_states.append(old_states[old_idx])
                        else:
                            new_states.append({})
                    moe_state["expert_modules"] = new_states

                    router_state = moe_state.get("router", {})
                    old_emb_states = router_state.get("embeddings", [])
                    new_emb_states = []
                    for emb_eid in self.router._emb_ids:
                        old_idx = old_ids.index(emb_eid) if emb_eid in old_ids else -1
                        if 0 <= old_idx < len(old_emb_states):
                            new_emb_states.append(old_emb_states[old_idx])
                        else:
                            new_emb_states.append({})
                    router_state["embeddings"] = new_emb_states
            except (KeyError, IndexError, TypeError):
                pass

        return info
|
|
| """ |
| def _do_death(self, eid: str, optimizer=None) -> str: |
| meta = self._expert_meta[eid] |
| info = f"DEATH {eid[:8]} (T{meta.tier}, age={meta.age}, w={meta.avg_routing_weight:.4f})" |
| self._remove_expert(eid) |
| |
| if optimizer is not None: |
| try: |
| layers_state = optimizer.state.get("layers", []) |
| if self.layer_idx < len(layers_state): |
| moe_state = layers_state[self.layer_idx].get("moe", {}) |
| old_states = moe_state.get("expert_modules", []) |
| new_states = [] |
| for i, expert_eid in enumerate(self._expert_id_list): |
| if i < len(old_states): |
| new_states.append(old_states[i]) |
| else: |
| new_states.append({}) |
| moe_state["expert_modules"] = new_states |
| except (KeyError, IndexError, TypeError): |
| pass |
| |
| return info |
| |
| """ |
|
|
| def _average_expert_weights(self, expert_a: Expert, expert_b: Expert) -> List[Tuple[str, mx.array]]: |
| """Average the weights of two same-shape experts.""" |
| src_a = dict(tree_flatten(expert_a.parameters())) |
| src_b = dict(tree_flatten(expert_b.parameters())) |
| pairs = [] |
| for k in src_a: |
| if k in src_b and src_a[k].shape == src_b[k].shape: |
| pairs.append((k, (src_a[k] + src_b[k]) / 2.0)) |
| return pairs |
|
|
| def _check_merges(self, touched: set, optimizer=None) -> List[str]: |
| events = [] |
| merged = set() |
| ids = list(self._expert_id_list) |
| cfg = self.config |
|
|
| |
| co_activation = {} |
| if self._last_routing_weights is not None: |
| N = self._last_routing_weights.shape[-1] |
| active = (self._last_routing_weights > 0.01).astype(mx.float32) |
| |
| act_flat = active.reshape(-1, N) |
| |
| act_freq = act_flat.mean(axis=0) |
| mx.eval(act_freq) |
|
|
| def _can_merge(eid): |
| return (eid not in merged and eid not in touched |
| and eid in self._expert_id_list |
| and (meta := self._expert_meta.get(eid)) is not None |
| and meta.age >= cfg.min_expert_age |
| and meta.cooldown == 0) |
|
|
| def _do_merge(eid_a, eid_b, meta_a, meta_b, reason: str, optimizer=None) -> Optional[str]: |
| """Execute a merge and return event string, or None if budget exceeded.""" |
| new_tier = min(meta_a.tier + 1, len(cfg.tier_hidden_dims) - 1) |
| cost = self._expert_param_count(new_tier) |
| freed = (self._expert_param_count(meta_a.tier) |
| + self._expert_param_count(meta_b.tier)) |
| if self._total_params() - freed + cost > cfg.max_params_per_layer: |
| return None |
|
|
| emb_a = self.router.get_embedding(eid_a) |
| emb_b = self.router.get_embedding(eid_b) |
| avg_emb = (emb_a + emb_b) / 2.0 |
| mx.eval(avg_emb) |
|
|
| if new_tier == meta_a.tier: |
| |
| merged_expert_id = self._create_expert( |
| tier=new_tier, parent_id=eid_a, |
| init_weights_from=self._get_expert(eid_a), |
| init_embedding=avg_emb, |
| ) |
| |
| avg_weights = self._average_expert_weights( |
| self._get_expert(eid_a), self._get_expert(eid_b)) |
| if avg_weights: |
| self._get_expert(merged_expert_id).load_weights(avg_weights) |
| mx.eval(self._get_expert(merged_expert_id).parameters()) |
| else: |
| |
| merged_expert_id = self._create_expert( |
| tier=new_tier, parent_id=eid_a, |
| init_embedding=avg_emb, |
| ) |
|
|
            self._expert_meta[merged_expert_id].cooldown = cfg.cooldown_steps
            # Snapshot the ordering before the removals so surviving experts can be
            # mapped back to their old optimizer-state slots below.
            old_ids = list(self._expert_id_list)
            self._remove_expert(eid_a)
            self._remove_expert(eid_b)
            merged.add(eid_a)
            merged.add(eid_b)

            if optimizer is not None:
                try:
                    layers_state = optimizer.state.get("layers", [])
                    if self.layer_idx < len(layers_state):
                        moe_state = layers_state[self.layer_idx].get("moe", {})

                        old_states = moe_state.get("expert_modules", [])
                        new_states = []
                        for expert_eid in self._expert_id_list:
                            if expert_eid == merged_expert_id:
                                new_states.append({})  # fresh state, no momentum to copy
                            else:
                                old_i = old_ids.index(expert_eid) if expert_eid in old_ids else -1
                                new_states.append(old_states[old_i]
                                                  if 0 <= old_i < len(old_states) else {})
                        moe_state["expert_modules"] = new_states

                        router_state = moe_state.get("router", {})
                        old_emb_states = router_state.get("embeddings", [])
                        new_emb_states = []
                        for emb_eid in self.router._emb_ids:
                            if emb_eid == merged_expert_id:
                                new_emb_states.append({})
                            else:
                                old_i = old_ids.index(emb_eid) if emb_eid in old_ids else -1
                                new_emb_states.append(old_emb_states[old_i]
                                                      if 0 <= old_i < len(old_emb_states) else {})
                        router_state["embeddings"] = new_emb_states
                except (KeyError, IndexError, TypeError):
                    pass
|
|
| return (f"MERGE({reason}) {eid_a[:8]}+{eid_b[:8]} (T{meta_a.tier}) " |
| f"-> {merged_expert_id[:8]} (T{new_tier})") |
|
|
| |
| for i, eid_a in enumerate(ids): |
| if not _can_merge(eid_a): |
| continue |
| meta_a = self._expert_meta[eid_a] |
|
|
| for j in range(i + 1, len(ids)): |
| eid_b = ids[j] |
| if not _can_merge(eid_b): |
| continue |
| meta_b = self._expert_meta[eid_b] |
| if meta_a.tier != meta_b.tier: |
| continue |
|
|
| emb_a = self.router.get_embedding(eid_a) |
| emb_b = self.router.get_embedding(eid_b) |
| cos = ((emb_a * emb_b).sum() |
| / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8)) |
|
|
| both_weak = (meta_a.avg_routing_weight < cfg.merge_weakness_threshold |
| and meta_b.avg_routing_weight < cfg.merge_weakness_threshold) |
|
|
| if cos.item() > cfg.merge_co_route_threshold and both_weak: |
| result = _do_merge(eid_a, eid_b, meta_a, meta_b, "fragment", optimizer=optimizer) |
| if result: |
| events.append(result) |
| break |
|
|
| |
| budget_frac = self._total_params() / cfg.max_params_per_layer |
| if budget_frac > cfg.merge_capacity_pressure_frac: |
| |
| candidates = [] |
| for i, eid_a in enumerate(ids): |
| if not _can_merge(eid_a): |
| continue |
| meta_a = self._expert_meta.get(eid_a) |
| if meta_a is None: |
| continue |
| for j in range(i + 1, len(ids)): |
| eid_b = ids[j] |
| if not _can_merge(eid_b): |
| continue |
| meta_b = self._expert_meta.get(eid_b) |
| if meta_b is None or meta_a.tier != meta_b.tier: |
| continue |
| emb_a = self.router.get_embedding(eid_a) |
| emb_b = self.router.get_embedding(eid_b) |
| cos = ((emb_a * emb_b).sum() |
| / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8)) |
| combined_w = meta_a.avg_routing_weight + meta_b.avg_routing_weight |
| |
| score = cos.item() - combined_w |
| candidates.append((score, eid_a, eid_b, meta_a, meta_b)) |
|
|
| candidates.sort(key=lambda t: -t[0]) |
| for score, eid_a, eid_b, meta_a, meta_b in candidates: |
| if not _can_merge(eid_a) or not _can_merge(eid_b): |
| continue |
| result = _do_merge(eid_a, eid_b, meta_a, meta_b, "capacity",optimizer=optimizer) |
| if result: |
| events.append(result) |
| |
| break |
|
|
| |
| if self._last_routing_weights is not None: |
| N = self._last_routing_weights.shape[-1] |
| act_flat = (self._last_routing_weights > 0.01).astype(mx.float32).reshape(-1, N) |
| total_tokens = act_flat.shape[0] |
|
|
| for i, eid_a in enumerate(ids): |
| if not _can_merge(eid_a): |
| continue |
| meta_a = self._expert_meta.get(eid_a) |
| if meta_a is None: |
| continue |
| idx_a = self._eid_to_index(eid_a) if eid_a in self._expert_id_list else None |
| if idx_a is None or idx_a >= N: |
| continue |
|
|
| for j in range(i + 1, len(ids)): |
| eid_b = ids[j] |
| if not _can_merge(eid_b): |
| continue |
| meta_b = self._expert_meta.get(eid_b) |
| if meta_b is None or meta_a.tier != meta_b.tier: |
| continue |
| idx_b = self._eid_to_index(eid_b) if eid_b in self._expert_id_list else None |
| if idx_b is None or idx_b >= N: |
| continue |
|
|
| |
| both_active = (act_flat[:, idx_a] * act_flat[:, idx_b]).mean().item() |
|
|
| emb_a = self.router.get_embedding(eid_a) |
| emb_b = self.router.get_embedding(eid_b) |
| cos = ((emb_a * emb_b).sum() |
| / (mx.linalg.norm(emb_a) * mx.linalg.norm(emb_b) + 1e-8)) |
|
|
| if (both_active > cfg.merge_tier_gravity_min_co_activation |
| and cos.item() > cfg.merge_tier_gravity_co_route): |
| result = _do_merge(eid_a, eid_b, meta_a, meta_b, "tier-gravity", optimizer=optimizer) |
| if result: |
| events.append(result) |
| break |
|
|
| return events |
|
|
|
|
| |
| |
| |
| class RMSNorm(nn.Module): |
| def __init__(self, dims: int, eps: float = 1e-5): |
| super().__init__() |
| self.weight = mx.ones((dims,)) |
| self.eps = eps |
|
|
| def __call__(self, x): |
| return mx.fast.rms_norm(x, self.weight, self.eps) |
|
|
|
|
| class Attention(nn.Module): |
| def __init__(self, args: ModelArgs): |
| super().__init__() |
| self.n_heads = args.n_heads |
| self.n_kv_heads = args.n_kv_heads |
| self.head_dim = args.dim // args.n_heads |
| self.scale = self.head_dim ** -0.5 |
| self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) |
| self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) |
| self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) |
| self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) |
| self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta) |
|
|
| def __call__(self, x, mask=None): |
| B, L, D = x.shape |
| queries, keys, values = self.wq(x), self.wk(x), self.wv(x) |
| queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) |
| keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) |
| values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) |
| queries = self.rope(queries) |
| keys = self.rope(keys) |
| output = mx.fast.scaled_dot_product_attention( |
| queries, keys, values, scale=self.scale, mask=mask) |
| return self.wo(output.transpose(0, 2, 1, 3).reshape(B, L, -1)) |
|
|
|
|
| class MicroExpertsBlock(nn.Module): |
| def __init__(self, args: ModelArgs, me_config: MicroExpertConfig, layer_idx: int): |
| super().__init__() |
| self.attention = Attention(args) |
| self.moe = MicroExpertsMoELayer(args.dim, me_config, layer_idx) |
| self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) |
| self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) |
|
|
| def __call__(self, x, mask=None): |
| h = x + self.attention(self.attention_norm(x), mask) |
| return h + self.moe(self.ffn_norm(h)) |
|
|
|
|
| class MicroExpertsModel(nn.Module): |
| def __init__(self, args: ModelArgs, me_config: MicroExpertConfig): |
| super().__init__() |
| self.args = args |
| self.me_config = me_config |
| self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) |
| self.layers = [ |
| MicroExpertsBlock(args, me_config, layer_idx=i) |
| for i in range(args.n_layers) |
| ] |
| self.norm = RMSNorm(args.dim, eps=args.norm_eps) |
| self.output = nn.Linear(args.dim, args.vocab_size, bias=False) |
|
|
| def __call__(self, x): |
| L = x.shape[1] |
| mask = nn.MultiHeadAttention.create_additive_causal_mask(L).astype(mx.float32) |
| mask = mask[None, None, :, :] |
| h = self.tok_embeddings(x) |
| for layer in self.layers: |
| h = layer(h, mask) |
| return self.output(self.norm(h)) |
|
|
| def set_global_step(self, step: int): |
| for layer in self.layers: |
| layer.moe.global_step = step |
|
|
| def run_lifecycle(self, optimizer=None): |
| all_events = [] |
| for layer in self.layers: |
| all_events.extend(layer.moe.lifecycle_step(optimizer=optimizer)) |
| return all_events |
|
|
| def total_load_balance_loss(self) -> mx.array: |
| """Sum of per-layer activation frequency variance.""" |
| lb = mx.array(0.0) |
| for layer in self.layers: |
| lb = lb + layer.moe.load_balance_loss() |
| return lb |
|
|
| def zero_frozen_grads(self, grads): |
| """Walk gradient tree, zero frozen expert parameters.""" |
| if not isinstance(grads, dict) or "layers" not in grads: |
| return grads |
| new_layers = [] |
| for i, lg in enumerate(grads["layers"]): |
| if (isinstance(lg, dict) and "moe" in lg |
| and isinstance(lg["moe"], dict) |
| and "expert_modules" in lg["moe"]): |
| moe = self.layers[i].moe |
| fixed = moe.zero_frozen_grads(lg["moe"]["expert_modules"]) |
| new_moe = dict(lg["moe"]) |
| new_moe["expert_modules"] = fixed |
| new_lg = dict(lg) |
| new_lg["moe"] = new_moe |
| new_layers.append(new_lg) |
| else: |
| new_layers.append(lg) |
| new_grads = dict(grads) |
| new_grads["layers"] = new_layers |
| return new_grads |
|
|
| def expert_summary(self) -> str: |
| lines = [] |
| total_e, total_p = 0, 0 |
| for i, layer in enumerate(self.layers): |
| moe = layer.moe |
| n = len(moe._expert_id_list) |
| p = moe._total_params() |
| total_e += n |
| total_p += p |
| tiers = defaultdict(int) |
| for m in moe._expert_meta.values(): |
| tiers[m.tier] += 1 |
| ts = " ".join(f"T{t}:{c}" for t, c in sorted(tiers.items())) |
| frozen = sum(1 for eid in moe._expert_id_list if eid in moe._frozen_eids) |
| drift = " DRIFT" if moe._drift_detected else "" |
| lines.append( |
| f" L{i:2d}: {n:3d} experts ({ts}) | {p/1e6:.1f}M | " |
| f"{frozen} frozen | d={moe._density_ema:.1f}{drift}") |
| lines.append(f" TOTAL: {total_e} experts | {total_p/1e6:.1f}M MoE params") |
| return "\n".join(lines) |
|
|
| def save_meta(self, path: str): |
| data = {} |
| for i, layer in enumerate(self.layers): |
| moe = layer.moe |
| data[f"layer_{i}"] = { |
| "expert_ids": list(moe._expert_id_list), |
| "experts": {eid: m.to_dict() for eid, m in moe._expert_meta.items()}, |
| "density_ema": moe._density_ema, |
| } |
| with open(path, "w") as f: |
| json.dump(data, f, indent=2) |
|
|
|
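# Minimal smoke test (illustrative sketch -- the tiny dims below are made up for speed
# and are not the training defaults):
#
#   args = ModelArgs(dim=64, n_layers=2, n_heads=4, n_kv_heads=4, vocab_size=100)
#   model = MicroExpertsModel(args, MicroExpertConfig(init_tier=0))
#   logits = model(mx.random.randint(0, 100, (1, 8)))   # (1, 8, 100)
#   print(model.expert_summary())                        # one tier-0 expert per layer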
|
| |
| |
| |
def stream_gutenberg(tokenizer, batch_size: int, seq_len: int):
    """Stream (batch_size, seq_len + 1) token batches. Despite the legacy name, this
    pulls the teknium/OpenHermes-2.5 conversations dataset (streaming) from the Hub."""
    print("Connecting to OpenHermes-2.5 stream...")
    dataset = load_dataset("teknium/OpenHermes-2.5", split="train", streaming=True)
    dataset_iter = iter(dataset)
| buffers = [[] for _ in range(batch_size)] |
| while True: |
| for i in range(batch_size): |
| while len(buffers[i]) < seq_len + 1: |
| try: |
| row = next(dataset_iter) |
| except StopIteration: |
| dataset_iter = iter(dataset) |
| row = next(dataset_iter) |
| text = row.get("conversations", "") |
| if isinstance(text, list): |
| parts = [] |
| for msg in text: |
| role = msg.get("from", "") |
| content = msg.get("value", []) |
| if isinstance(content, str): |
| parts.append(f"{role}\n{content}") |
| text = "\n".join(parts) |
| |
| if not text or len(text) < 10: |
| continue |
| buffers[i].extend(tokenizer.encode(text)) |
| batch = [] |
| for i in range(batch_size): |
| batch.append(buffers[i][:seq_len + 1]) |
| buffers[i] = buffers[i][seq_len:] |
| yield mx.array(batch, dtype=mx.int32) |
|
|
|
|
| def stream_domain_files(tokenizer, data_dir: str, batch_size: int, seq_len: int): |
| files = sorted(glob.glob(os.path.join(data_dir, "*.txt"))) |
| if not files: |
| raise FileNotFoundError(f"No .txt files in {data_dir}") |
| for fpath in files: |
| domain = os.path.splitext(os.path.basename(fpath))[0] |
| print(f"\n{'='*60}") |
| print(f" ACTIVE LEARNING — Domain: {domain}") |
| print(f"{'='*60}") |
| with open(fpath, "r", encoding="utf-8", errors="replace") as f: |
| text = f.read() |
| tokens = tokenizer.encode(text) |
| min_tokens = (seq_len + 1) * batch_size |
| if len(tokens) < min_tokens: |
| print(f" Skipping {domain}: {len(tokens)} tokens < {min_tokens} needed") |
| continue |
|
|
| def batch_gen(toks=tokens, bs=batch_size, sl=seq_len): |
| while True: |
| buf = list(toks) |
| while len(buf) >= bs * (sl + 1): |
| batch = [] |
| for _ in range(bs): |
| batch.append(buf[:sl + 1]) |
| buf = buf[sl:] |
| yield mx.array(batch, dtype=mx.int32) |
|
|
| yield domain, batch_gen() |
|
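# Expected layout for active learning (the file names below are just an assumption):
#
#   ./domains/law.txt
#   ./domains/medicine.txt
#
# Each *.txt file becomes one domain: it is tokenized whole, replayed repeatedly by
# batch_gen(), and skipped with a warning if it holds fewer than
# batch_size * (seq_len + 1) tokens.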
|
|
|
| |
| |
| |
| def loss_fn(model, x): |
| """Cross-entropy + load balance auxiliary loss.""" |
| logits = model(x) |
| ce = nn.losses.cross_entropy(logits[:, :-1, :], x[:, 1:], reduction="mean") |
| lb = model.total_load_balance_loss() |
| return ce + model.me_config.load_balance_weight * lb |
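# With the defaults this is: cross_entropy(next token) + 0.01 * (sum over layers of the
# per-layer load-balance variance), 0.01 being MicroExpertConfig.load_balance_weight.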
|
|
| def load_checkpoint(model, path: str): |
| weights = dict(mx.load(path)) |
| meta_path = path.replace(".npz", ".json") |
| with open(meta_path, "r") as f: |
| meta = json.load(f) |
| |
| for i, layer in enumerate(model.layers): |
| moe = layer.moe |
| layer_key = f"layer_{i}" |
| if layer_key not in meta: |
| continue |
| layer_meta = meta[layer_key] |
| |
| for eid in list(moe._expert_id_list): |
| moe._remove_expert(eid) |
| |
| for eid in layer_meta["expert_ids"]: |
| em = layer_meta["experts"][eid] |
| tier = em["tier"] |
| hidden = moe._tier_to_hidden(tier) |
| expert = Expert(moe.model_dim, hidden) |
| mx.eval(expert.parameters()) |
| moe.expert_modules.append(expert) |
| moe._expert_id_list.append(eid) |
| moe._expert_meta[eid] = ExpertMeta( |
| expert_id=eid, tier=tier, hidden_dim=hidden, |
| age=em.get("age", 0), |
| cooldown=em.get("cooldown", 0), |
| frozen_steps=em.get("frozen_steps", 0), |
| ema_interference_fast=em.get("ema_fast", 0.0), |
| ema_interference_slow=em.get("ema_slow", 0.0), |
| ema_interference_var=em.get("ema_var", 1.0), |
| avg_routing_weight=em.get("avg_rw", 0.1), |
| avg_activation_freq=em.get("avg_af", 0.1), |
| parent_id=em.get("parent_id"), |
| generation=em.get("generation", 0), |
| ) |
| if em.get("frozen_steps", 0) > 0: |
| moe._frozen_eids.add(eid) |
| router_key = f"__router__.{i}.{eid}" |
| init_emb = weights.pop(router_key, None) |
| moe.router.add_expert(eid, init_embedding=init_emb) |
| |
| moe._density_ema = layer_meta.get("density_ema", 1.0) |
| |
| remaining = [(k, v) for k, v in weights.items() if not k.startswith("__router__")] |
| model.load_weights(remaining, strict=False) |
| mx.eval(model.parameters()) |
| print(f" Loaded checkpoint from {path}") |
|
|
|
|
| def get_latest_checkpoint(checkpoint_dir: str): |
| if not os.path.exists(checkpoint_dir): |
| return None, 0 |
| ckpts = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_step_*.npz"))) |
| if not ckpts: |
| return None, 0 |
| latest = ckpts[-1] |
| m = re.search(r"step_(\d+)", latest) |
| return latest, int(m.group(1)) |
|
|
|
|
| def save_checkpoint(model, step: int, checkpoint_dir: str): |
| path = os.path.join(checkpoint_dir, f"checkpoint_step_{step}.npz") |
| |
| save_dict = {} |
| |
| for k, v in tree_flatten(model.parameters()): |
| save_dict[k] = v |
| |
| for i, layer in enumerate(model.layers): |
| moe = layer.moe |
| for j, eid in enumerate(moe.router._emb_ids): |
| save_dict[f"__router__.{i}.{eid}"] = moe.router.embeddings[j].embedding |
| |
| mx.savez(path, **save_dict) |
| model.save_meta(path.replace(".npz", ".json")) |
| print(f" Saved checkpoint {path}") |
|
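# Checkpoint layout, as written above: every model parameter plus one
# "__router__.{layer_idx}.{expert_id}" entry per routing embedding goes into the .npz,
# and per-expert metadata (tiers, ages, EMAs) goes into the sibling .json written by
# save_meta(). load_checkpoint() reads that .json first to rebuild the expert lists,
# then loads the remaining weights with strict=False.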
|
|
|
| |
| |
| |
| def train_loop(model, optimizer, data_iter, tc: TrainConfig, |
| start_step=0, max_steps=30000, lifecycle_every=10, label="train"): |
|
|
| loss_and_grad_fn = nn.value_and_grad(model, loss_fn) |
| compiled_loss_and_grad = mx.compile(loss_and_grad_fn) |
|
|
| step = start_step |
| tic = time.time() |
|
|
| topology_changed = False |
|
|
| for batch in data_iter: |
| if step >= max_steps: |
| break |
| model.set_global_step(step) |
|
|
| |
| if topology_changed: |
| compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn)) |
| topology_changed = False |
|
|
| try: |
| loss, grads = compiled_loss_and_grad(model, batch) |
| except Exception: |
| loss_and_grad_fn_eager = nn.value_and_grad(model, loss_fn) |
| loss, grads = loss_and_grad_fn_eager(model, batch) |
| compiled_loss_and_grad = mx.compile(nn.value_and_grad(model, loss_fn)) |
|
|
        grads = model.zero_frozen_grads(grads)
        try:
            optimizer.update(model, grads)
        except (ValueError, KeyError, IndexError):
            # Expert topology changed since this state was built: drop the stale
            # per-parameter entries (dicts/lists) and let AdamW re-create them,
            # keeping scalar state such as the step count and learning rate.
            optimizer.state = {k: v for k, v in optimizer.state.items()
                               if not isinstance(v, (dict, list))}
            optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state, loss)
|
|
| if step > 0 and step % lifecycle_every == 0: |
| events = model.run_lifecycle(optimizer=optimizer) |
| if events: |
| topology_changed = True |
| |
|
|
| """ |
| optimizer.update(model, grads) |
| mx.eval(model.parameters(), optimizer.state, loss) |
| """ |
|
|
| if step % tc.log_every == 0: |
| toc = time.time() |
| n_exp = sum(len(l.moe._expert_id_list) for l in model.layers) |
| avg_d = sum( |
| l.moe._last_density.mean().item() |
| for l in model.layers if l.moe._last_density is not None |
| ) / model.args.n_layers |
| elapsed = toc - tic |
| tok_per_sec = (tc.log_every * tc.batch_size * model.args.max_seq_len) / max(elapsed, 1e-6) |
| print(f"[{label}] Step {step:6d} | Loss {loss.item():.4f} | " |
| f"Experts {n_exp} | Density {avg_d:.1f} | " |
| f"{tok_per_sec:.0f} tok/s | {elapsed:.2f}s") |
| tic = time.time() |
|
|
| if step > 0 and step % tc.summary_every == 0: |
| print(f"\n--- Expert Summary @ step {step} ---") |
| print(model.expert_summary()) |
| print() |
|
|
| if step > 0 and step % tc.checkpoint_every == 0: |
| save_checkpoint(model, step, tc.checkpoint_dir) |
|
|
| step += 1 |
| return step |
|
|
|
|
| |
| |
| |
| def prompt_config() -> TrainConfig: |
| """Interactive configuration via input() prompts.""" |
| tc = TrainConfig() |
|
|
| print("\n" + "="*60) |
| print(" MicroExperts — Training Configuration") |
| print("="*60) |
|
|
| |
| print(" 1. pretrain — Gutenberg streaming pretraining") |
| print(" 2. active_learning — Sequential domain continual learning(not implemented yet)") |
| print(" 3. inference — Chat with the trained model") |
| print(" 4. interactive_learning — Chat and learn from your inputs") |
| print(" 5. train_and_chat — Train with periodic chat breaks") |
| choice = input("Mode [1]: ").strip() |
| if choice == "2": |
| tc.mode = "active_learning" |
| elif choice == "3": |
| tc.mode = "inference" |
| elif choice == "4": |
| tc.mode = "interactive_learning" |
| elif choice == "5": |
| tc.mode = "train_and_chat" |
| else: |
| tc.mode = "pretrain" |
| |
| |
| tok = "gutenberg_tokenizer.json" |
| if tok: |
| tc.tokenizer_file = tok |
|
|
| |
| cd = input(f"Checkpoint directory [{tc.checkpoint_dir}]: ").strip() |
| if cd: |
| tc.checkpoint_dir = cd |
|
|
| |
| bs = input(f"Batch size [{tc.batch_size}]: ").strip() |
| if bs: |
| tc.batch_size = int(bs) |
|
|
| |
| if tc.mode == "pretrain": |
| default_lr = tc.learning_rate |
| else: |
| default_lr = tc.al_learning_rate |
| lr = input(f"Learning rate [{default_lr}]: ").strip() |
| if lr: |
| tc.learning_rate = float(lr) |
| else: |
| tc.learning_rate = default_lr |
|
|
| |
| ms = input(f"Max steps [{tc.max_steps}]: ").strip() |
| if ms: |
| tc.max_steps = int(ms) |
|
|
| |
| resume = input("Resume from checkpoint? [Y/n]: ").strip().lower() |
| tc._resume = resume != "n" |
|
|
| |
| if tc.mode == "active_learning": |
| dd = input(f"Domain data directory [{tc.al_data_dir}]: ").strip() |
| if dd: |
| tc.al_data_dir = dd |
| spd = input(f"Steps per domain [{tc.al_steps_per_domain}]: ").strip() |
| if spd: |
| tc.al_steps_per_domain = int(spd) |
|
|
| print("\n" + "-"*60) |
| print(f" Mode: {tc.mode}") |
| print(f" LR: {tc.learning_rate}") |
| print(f" Batch: {tc.batch_size}") |
| print(f" Max steps: {tc.max_steps}") |
| print(f" Checkpoint: {tc.checkpoint_dir}") |
| print(f" Resume: {tc._resume}") |
| if tc.mode == "active_learning": |
| print(f" Data dir: {tc.al_data_dir}") |
| print(f" Steps/dom: {tc.al_steps_per_domain}") |
| print(f" M4 budget: 150M params/layer, 128 experts/layer max") |
| print("-"*60) |
|
|
| confirm = input("Continue? [Y/n]: ").strip().lower() |
| if confirm == "n": |
| print("Aborted.") |
| exit(0) |
|
|
| return tc |
|
|
| def generate(model, tokenizer, prompt: str, max_tokens: int = 256, temperature: float = 0.8): |
| tokens = tokenizer.encode(prompt) |
| tokens = mx.array([tokens], dtype=mx.int32) |
| |
| for _ in range(max_tokens): |
| logits = model(tokens) |
| next_logits = logits[:, -1, :] / temperature |
| next_token = mx.random.categorical(next_logits) |
| next_token = next_token.reshape(1, 1) |
| tokens = mx.concatenate([tokens, next_token], axis=1) |
| mx.eval(tokens) |
| |
| token_id = next_token.item() |
| if token_id == tokenizer.eos_token_id: |
| break |
| |
| |
| print("\n Expert routing:") |
| for i, layer in enumerate(model.layers): |
| moe = layer.moe |
| if moe._last_routing_weights is None: |
| continue |
| rw = moe._last_routing_weights |
| N = rw.shape[-1] |
| |
| avg_w = rw.reshape(-1, N).mean(axis=0) |
| active = (avg_w > 0.01) |
| parts = [] |
| for j, eid in enumerate(moe._expert_id_list): |
| if j < N and active[j].item(): |
| meta = moe._expert_meta.get(eid) |
| tier = meta.tier if meta else "?" |
| parts.append(f"{eid[:6]}(T{tier} w={avg_w[j].item():.3f})") |
| if parts: |
| print(f" L{i:2d}: {' '.join(parts)}") |
| |
| return tokenizer.decode(tokens[0].tolist()) |
|
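# Usage sketch (the prompt text below is just an example): generate() samples
# autoregressively and, after sampling, prints which experts each MoE layer used on
# the final forward pass.
#
#   text = generate(model, tokenizer, "Once upon a time", max_tokens=64, temperature=0.7)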
|
| def main(): |
| tc = prompt_config() |
| os.makedirs(tc.checkpoint_dir, exist_ok=True) |
|
|
| |
| print(f"\nLoading tokenizer: {tc.tokenizer_file}") |
| tokenizer = PreTrainedTokenizerFast(tokenizer_file=tc.tokenizer_file) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
|
| |
| args = ModelArgs() |
| args.vocab_size = len(tokenizer) |
| me_config = MicroExpertConfig() |
|
|
| if tc.mode == "active_learning": |
| me_config.split_threshold = tc.al_split_threshold |
| me_config.min_expert_age = tc.al_min_expert_age |
|
|
| print(f"Initializing MicroExperts model (vocab={args.vocab_size})...") |
| model = MicroExpertsModel(args, me_config) |
|
|
| |
| current_step = 0 |
| if tc._resume: |
| ckpt, ckpt_step = get_latest_checkpoint(tc.checkpoint_dir) |
| if ckpt: |
| print(f"Resuming from {ckpt} @ step {ckpt_step}") |
| load_checkpoint(model, ckpt) |
| current_step = ckpt_step |
| else: |
| print("No checkpoint found — starting fresh.") |
|
|
| mx.eval(model.parameters()) |
| n_params = sum(v.size for _, v in tree_flatten(model.parameters())) |
| print(f"Total params: {n_params / 1e6:.2f}M") |
| print("Initial layout:") |
| print(model.expert_summary()) |
|
|
| optimizer = optim.AdamW(learning_rate=tc.learning_rate) |
|
|
| |
| if tc.mode == "pretrain": |
| data = stream_gutenberg(tokenizer, tc.batch_size, args.max_seq_len) |
| print(f"\nStarting pretraining for {tc.max_steps} steps...") |
| final_step = train_loop( |
| model, optimizer, data, tc, |
| start_step=current_step, max_steps=tc.max_steps, |
| lifecycle_every=tc.lifecycle_every, label="pretrain", |
| ) |
| |
| elif tc.mode == "inference": |
| |
| print("\nChat ready. Type 'quit' to exit.\n") |
| while True: |
| user_input = input("You: ").strip() |
| if user_input.lower() in ("quit", "exit"): |
| break |
| if not user_input: |
| continue |
| response = generate(model, tokenizer, user_input) |
| print(f"Model: {response}\n") |
| |
| final_step = current_step |
|
|
| |
| elif tc.mode == "active_learning": |
| lifecycle_every = tc.al_lifecycle_every |
| print(f"\nActive learning from: {tc.al_data_dir}") |
| print(f" Steps/domain: {tc.al_steps_per_domain} | Lifecycle every: {lifecycle_every}") |
|
|
| domain_gen = stream_domain_files( |
| tokenizer, tc.al_data_dir, tc.batch_size, args.max_seq_len) |
|
|
| global_step = current_step |
| for domain_name, batches in domain_gen: |
| domain_max = global_step + tc.al_steps_per_domain |
| n_before = sum(len(l.moe._expert_id_list) for l in model.layers) |
|
|
| print(f"\n Training '{domain_name}': steps {global_step} -> {domain_max}") |
| global_step = train_loop( |
| model, optimizer, batches, tc, |
| start_step=global_step, max_steps=domain_max, |
| lifecycle_every=lifecycle_every, label=f"AL:{domain_name}", |
| ) |
|
|
| n_after = sum(len(l.moe._expert_id_list) for l in model.layers) |
| print(f"\n '{domain_name}' done. Experts: {n_before} -> {n_after} ({n_after-n_before:+d})") |
| print(model.expert_summary()) |
|
|
| final_step = global_step |
| |
| elif tc.mode == "interactive_learning": |
| if not tc._resume: |
| print("WARNING: No checkpoint loaded, model is random.") |
| |
| il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate) |
| il_step = current_step |
| conversation_tokens = [] |
| message_count = 0 |
| |
| print("\nInteractive learning ready. Type 'quit' to exit.") |
| print("The model learns from the conversation.\n") |
| |
| while True: |
| user_input = input("You: ").strip() |
| if user_input.lower() in ("quit", "exit"): |
| break |
| if not user_input: |
| continue |
| |
| response = generate(model, tokenizer, user_input) |
| print(f"Model: {response}\n") |
| |
| conversation_tokens.extend(tokenizer.encode(user_input)) |
| conversation_tokens.extend(tokenizer.encode(response)) |
| message_count += 1 |
| |
| seq_len = model.args.max_seq_len |
| trained = False |
| |
| |
| while len(conversation_tokens) >= seq_len + 1: |
| batch = mx.array([conversation_tokens[:seq_len + 1]], dtype=mx.int32) |
| conversation_tokens = conversation_tokens[seq_len:] |
| |
| loss_and_grad_fn = nn.value_and_grad(model, loss_fn) |
| loss, grads = loss_and_grad_fn(model, batch) |
| grads = model.zero_frozen_grads(grads) |
| il_optimizer.update(model, grads) |
| mx.eval(model.parameters(), il_optimizer.state, loss) |
| |
| il_step += 1 |
| model.set_global_step(il_step) |
| trained = True |
| print(f" [learned: loss={loss.item():.4f}, step={il_step}]") |
| |
| |
| if not trained and message_count % 2 == 0 and len(conversation_tokens) > 2: |
| pad_len = seq_len + 1 |
| tokens_to_use = conversation_tokens[-pad_len:] if len(conversation_tokens) >= pad_len else conversation_tokens |
| |
| while len(tokens_to_use) < pad_len: |
| tokens_to_use = tokens_to_use + tokens_to_use |
| tokens_to_use = tokens_to_use[:pad_len] |
| |
| batch = mx.array([tokens_to_use], dtype=mx.int32) |
| |
| loss_and_grad_fn = nn.value_and_grad(model, loss_fn) |
| loss, grads = loss_and_grad_fn(model, batch) |
| grads = model.zero_frozen_grads(grads) |
| il_optimizer.update(model, grads) |
| mx.eval(model.parameters(), il_optimizer.state, loss) |
| |
| il_step += 1 |
| model.set_global_step(il_step) |
| print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]") |
| |
| |
| if il_step > 0 and il_step % tc.al_lifecycle_every == 0: |
| events = model.run_lifecycle() |
| if events: |
| il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))} |
| |
| print(model.expert_summary()) |
| |
| save_checkpoint(model, il_step, tc.checkpoint_dir) |
| print("Model saved.") |
| final_step = il_step |
|
|
| elif tc.mode == "train_and_chat": |
| if not tc._resume: |
| print("WARNING: No checkpoint loaded, model is random.") |
| |
| il_optimizer = optim.AdamW(learning_rate=tc.al_learning_rate) |
| il_step = current_step |
| conversation_tokens = [] |
| message_count = 0 |
| |
| system_prompt = "You are a helpful assistant." |
| chat_history = [] |
| |
| print("\nChat Learning ready. Type 'quit' to exit.") |
| print("The model learns from the conversation with chat format.\n") |
| |
| while True: |
| user_input = input("You: ").strip() |
| if user_input.lower() in ("quit", "exit"): |
| break |
| if not user_input: |
| continue |
| |
| response = generate(model, tokenizer, user_input) |
| print(f"Model: {response}\n") |
| |
| |
| chat_history.append({"role": "user", "content": user_input}) |
| chat_history.append({"role": "assistant", "content": response}) |
| |
| chat_text = f"system\n{system_prompt}\n" |
| for msg in chat_history: |
| role = "human" if msg["role"] == "user" else "gpt" |
| chat_text += f"{role}\n{msg['content']}\n" |
| |
| conversation_tokens = tokenizer.encode(chat_text) |
| message_count += 1 |
| |
| seq_len = model.args.max_seq_len |
| trained = False |
| |
| |
| train_tokens = list(conversation_tokens) |
| while len(train_tokens) >= seq_len + 1: |
| batch = mx.array([train_tokens[:seq_len + 1]], dtype=mx.int32) |
| train_tokens = train_tokens[seq_len:] |
| |
| loss_and_grad_fn = nn.value_and_grad(model, loss_fn) |
| loss, grads = loss_and_grad_fn(model, batch) |
| grads = model.zero_frozen_grads(grads) |
| try: |
| il_optimizer.update(model, grads) |
| except (ValueError, KeyError, IndexError): |
| il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))} |
| il_optimizer.update(model, grads) |
| mx.eval(model.parameters(), il_optimizer.state, loss) |
| |
| il_step += 1 |
| model.set_global_step(il_step) |
| trained = True |
| print(f" [learned: loss={loss.item():.4f}, step={il_step}]") |
| |
| |
| if not trained and message_count % 2 == 0 and len(train_tokens) > 2: |
| pad_len = seq_len + 1 |
| tokens_to_use = train_tokens[-pad_len:] if len(train_tokens) >= pad_len else train_tokens |
| while len(tokens_to_use) < pad_len: |
| tokens_to_use = tokens_to_use + tokens_to_use |
| tokens_to_use = tokens_to_use[:pad_len] |
| |
| batch = mx.array([tokens_to_use], dtype=mx.int32) |
| |
| loss_and_grad_fn = nn.value_and_grad(model, loss_fn) |
| loss, grads = loss_and_grad_fn(model, batch) |
| grads = model.zero_frozen_grads(grads) |
| try: |
| il_optimizer.update(model, grads) |
| except (ValueError, KeyError, IndexError): |
| il_optimizer.state = {k: v for k, v in il_optimizer.state.items() if not isinstance(v, (dict, list))} |
| il_optimizer.update(model, grads) |
| mx.eval(model.parameters(), il_optimizer.state, loss) |
| |
| il_step += 1 |
| model.set_global_step(il_step) |
| print(f" [forced learn @ msg {message_count}: loss={loss.item():.4f}, step={il_step}]") |
| |
| |
| max_history = 20 |
| if len(chat_history) > max_history: |
| chat_history = chat_history[-max_history:] |
| |
| |
            if il_step > 0 and il_step % tc.al_lifecycle_every == 0:
                model.run_lifecycle(optimizer=il_optimizer)
| |
| print(model.expert_summary()) |
| |
| save_checkpoint(model, il_step, tc.checkpoint_dir) |
| print("Model saved.") |
| final_step = il_step |
|
|
| |
| print("\nTraining complete.") |
| save_checkpoint(model, final_step, tc.checkpoint_dir) |
| print("Final layout:") |
| print(model.expert_summary()) |
|
|
|
|
| if __name__ == "__main__": |
| main() |