""" train_himoe.py — HiMoE (Hierarchical Mixture of Experts) Training Script ========================================================================= Architecture inspired by Matryoshka MoE: a nested, two-level routing system where a top-level router selects a MoE block, and each MoE block has its own router selecting among its local experts. Saved layout: model/ main_router.pt ← top-level (Level-1) gate weights moe_expert_001/ router.pt ← Level-2 gate for this MoE block model_001.pt … model_008.pt ← individual expert weights moe_expert_002/ … … backbone.pt ← embeddings, attention, LN, LM head config.json ← full config for re-loading Usage: python train_himoe.py # train from scratch python train_himoe.py --resume # continue from saved checkpoint """ import os import json import time import math import argparse import torch import torch.nn as nn from torch.nn import functional as F # ────────────────────────────────────────────────────────────────────────────── # Config # ────────────────────────────────────────────────────────────────────────────── class HiMoEConfig: # Transformer backbone block_size: int = 128 n_layer: int = 2 n_head: int = 4 n_embd: int = 256 dropout: float = 0.1 # HiMoE routing (Matryoshka-style nesting) num_moes: int = 6 # Level-1 choices num_experts: int = 8 # Level-2 choices per MoE # Training batch_size: int = 32 max_iters: int = 750 # for testing, increase to 3000 for actual training eval_interval:int = 50 eval_iters: int = 20 lr: float = 3e-4 # Paths data_file: str = "hamlet.txt" model_dir: str = "model" def to_dict(self): return {k: v for k, v in self.__class__.__dict__.items() if not k.startswith("_") and not callable(v)} # ────────────────────────────────────────────────────────────────────────────── # Model components # ────────────────────────────────────────────────────────────────────────────── class Expert(nn.Module): """A single feed-forward expert network.""" def __init__(self, n_embd: int, dropout: float = 0.0): super().__init__() self.net = nn.Sequential( nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class MoEBlock(nn.Module): """ Level-2 MoE: owns `num_experts` experts and its own gate (router). Top-1 routing — only one expert is activated per token. """ def __init__(self, n_embd: int, num_experts: int, dropout: float = 0.0): super().__init__() self.num_experts = num_experts self.experts = nn.ModuleList( [Expert(n_embd, dropout) for _ in range(num_experts)] ) # Level-2 router (saved separately as router.pt inside the MoE folder) self.router = nn.Linear(n_embd, num_experts, bias=False) def forward(self, x: torch.Tensor): """ x : (tokens, C) — already flattened to 2-D before entering here Returns: output (tokens, C), chosen expert indices (tokens,) """ logits = self.router(x) # (tokens, E) probs = F.softmax(logits, dim=-1) chosen = probs.argmax(dim=-1) # (tokens,) out = torch.zeros_like(x) for i, expert in enumerate(self.experts): mask = (chosen == i) if mask.any(): out[mask] = expert(x[mask]) return out, chosen class HiMoEFFN(nn.Module): """ Hierarchical MoE FFN (replaces the standard FFN in a Transformer block). Level-1 router selects one MoEBlock; that block's Level-2 router selects one expert — Matryoshka-style nesting. """ def __init__(self, cfg: HiMoEConfig): super().__init__() self.num_moes = cfg.num_moes self.num_experts = cfg.num_experts # Level-1 router (saved as main_router.pt at the top level) self.main_router = nn.Linear(cfg.n_embd, cfg.num_moes, bias=False) self.moe_blocks = nn.ModuleList( [MoEBlock(cfg.n_embd, cfg.num_experts, cfg.dropout) for _ in range(cfg.num_moes)] ) def forward(self, x: torch.Tensor): """ x : (B, T, C) Returns: output (B, T, C), moe_ids (B, T) — which MoE was chosen, exp_ids (B, T) — which expert inside that MoE was chosen """ B, T, C = x.shape flat = x.view(B * T, C) # (tokens, C) # Level-1 routing l1_logits = self.main_router(flat) # (tokens, num_moes) l1_probs = F.softmax(l1_logits, dim=-1) moe_ids = l1_probs.argmax(dim=-1) # (tokens,) out = torch.zeros_like(flat) exp_ids = torch.zeros_like(moe_ids) # (tokens,) for i, moe_block in enumerate(self.moe_blocks): mask = (moe_ids == i) if mask.any(): result, chosen_exp = moe_block(flat[mask]) out[mask] = result exp_ids[mask] = chosen_exp return (out.view(B, T, C), moe_ids.view(B, T), exp_ids.view(B, T)) class TransformerBlock(nn.Module): def __init__(self, cfg: HiMoEConfig): super().__init__() self.ln1 = nn.LayerNorm(cfg.n_embd) self.attn = nn.MultiheadAttention( cfg.n_embd, cfg.n_head, dropout=cfg.dropout, batch_first=True ) self.ln2 = nn.LayerNorm(cfg.n_embd) self.himoe = HiMoEFFN(cfg) def forward(self, x: torch.Tensor, attn_mask=None): # Self-attention with causal mask xn = self.ln1(x) attn_out, _ = self.attn(xn, xn, xn, attn_mask=attn_mask, need_weights=False, is_causal=True if attn_mask is None else False) x = x + attn_out # Hierarchical MoE FFN xn = self.ln2(x) ffn_out, moe_ids, exp_ids = self.himoe(xn) x = x + ffn_out return x, moe_ids, exp_ids class HiMoEModel(nn.Module): def __init__(self, cfg: HiMoEConfig, vocab_size: int): super().__init__() self.cfg = cfg self.vocab_size = vocab_size # Backbone (saved as backbone.pt) self.tok_emb = nn.Embedding(vocab_size, cfg.n_embd) self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd) self.drop = nn.Dropout(cfg.dropout) self.blocks = nn.ModuleList( [TransformerBlock(cfg) for _ in range(cfg.n_layer)] ) self.ln_f = nn.LayerNorm(cfg.n_embd) self.lm_head = nn.Linear(cfg.n_embd, vocab_size, bias=False) # Weight tying self.tok_emb.weight = self.lm_head.weight self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) def forward(self, idx: torch.Tensor, targets=None): B, T = idx.shape assert T <= self.cfg.block_size, \ f"Sequence length {T} > block_size {self.cfg.block_size}" # Create causal mask for attention mask = torch.full((T, T), float('-inf'), device=idx.device) mask = torch.triu(mask, diagonal=1) tok = self.tok_emb(idx) pos = self.pos_emb(torch.arange(T, device=idx.device)) x = self.drop(tok + pos) all_moe_ids, all_exp_ids = [], [] for block in self.blocks: x, moe_ids, exp_ids = block(x, attn_mask=mask) all_moe_ids.append(moe_ids) all_exp_ids.append(exp_ids) x = self.ln_f(x) logits = self.lm_head(x) loss = None if targets is not None: loss = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1) ) return logits, loss, all_moe_ids, all_exp_ids @torch.no_grad() def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 0.8, top_k: int = 40): routing_log = [] for _ in range(max_new_tokens): idx_cond = idx[:, -self.cfg.block_size:] logits, _, moe_ids, exp_ids = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = float('-inf') probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) idx = torch.cat((idx, idx_next), dim=1) routing_log.append({ "moe": [m[:, -1].tolist() for m in moe_ids], "exp": [e[:, -1].tolist() for e in exp_ids], }) return idx, routing_log def num_params(self): return sum(p.numel() for p in self.parameters()) # ────────────────────────────────────────────────────────────────────────────── # Modular save / load # ────────────────────────────────────────────────────────────────────────────── def _moe_dir(base: str, moe_idx: int) -> str: return os.path.join(base, f"moe_expert_{moe_idx+1:03d}") def save_model(model: HiMoEModel, cfg: HiMoEConfig, vocab_size: int, stoi: dict, itos: dict, step: int): """ Save the full model in the modular layout described in the docstring. model/ config.json backbone.pt main_router.pt ← shared across all transformer layers (layer 0 shown; for n_layer > 1 we save per-layer sub-dirs) moe_expert_001/ router.pt model_001.pt … model_008.pt … """ base = cfg.model_dir os.makedirs(base, exist_ok=True) # ── config + vocab ──────────────────────────────────────────────────────── meta = { "config": cfg.to_dict(), "vocab_size": vocab_size, "step": step, "stoi": stoi, "itos": itos, } with open(os.path.join(base, "config.json"), "w") as f: json.dump(meta, f, indent=2) # ── backbone ───────────────────────────────────────────────────────────── backbone_sd = { "tok_emb": model.tok_emb.state_dict(), "pos_emb": model.pos_emb.state_dict(), "ln_f": model.ln_f.state_dict(), "lm_head": model.lm_head.state_dict(), # per-block attention + layer norms (not the MoE parts) "blocks_attn": [ { "ln1": blk.ln1.state_dict(), "attn": blk.attn.state_dict(), "ln2": blk.ln2.state_dict(), } for blk in model.blocks ], } torch.save(backbone_sd, os.path.join(base, "backbone.pt")) # ── per-layer routers & experts ─────────────────────────────────────────── # For multi-layer models we namespace by layer; single-layer stays flat. for layer_idx, blk in enumerate(model.blocks): himoe = blk.himoe # Determine directory prefix layer_prefix = f"layer_{layer_idx+1:02d}_" if cfg.n_layer > 1 else "" # Level-1 (main) router torch.save( himoe.main_router.state_dict(), os.path.join(base, f"{layer_prefix}main_router.pt") ) # Per-MoE directories for moe_i, moe_block in enumerate(himoe.moe_blocks): moe_path = os.path.join( base, f"{layer_prefix}moe_expert_{moe_i+1:03d}" ) os.makedirs(moe_path, exist_ok=True) # Level-2 router torch.save( moe_block.router.state_dict(), os.path.join(moe_path, "router.pt") ) # Individual experts for exp_i, expert in enumerate(moe_block.experts): torch.save( expert.state_dict(), os.path.join(moe_path, f"model_{exp_i+1:03d}.pt") ) print(f"[save] Model saved to '{base}/' at step {step}.") def load_model(model_dir: str, device: str) -> tuple: """ Load the full model from the modular directory layout. Returns (model, cfg, stoi, itos, step). """ with open(os.path.join(model_dir, "config.json")) as f: meta = json.load(f) cfg = HiMoEConfig() for k, v in meta["config"].items(): setattr(cfg, k, v) cfg.model_dir = model_dir vocab_size = meta["vocab_size"] stoi = meta["stoi"] itos = {int(k): v for k, v in meta["itos"].items()} step = meta["step"] model = HiMoEModel(cfg, vocab_size).to(device) # backbone bb = torch.load(os.path.join(model_dir, "backbone.pt"), map_location=device) model.tok_emb.load_state_dict(bb["tok_emb"]) model.pos_emb.load_state_dict(bb["pos_emb"]) model.ln_f.load_state_dict(bb["ln_f"]) model.lm_head.load_state_dict(bb["lm_head"]) for i, blk in enumerate(model.blocks): blk.ln1.load_state_dict(bb["blocks_attn"][i]["ln1"]) blk.attn.load_state_dict(bb["blocks_attn"][i]["attn"]) blk.ln2.load_state_dict(bb["blocks_attn"][i]["ln2"]) # routers + experts for layer_idx, blk in enumerate(model.blocks): himoe = blk.himoe layer_prefix = f"layer_{layer_idx+1:02d}_" if cfg.n_layer > 1 else "" himoe.main_router.load_state_dict( torch.load(os.path.join(model_dir, f"{layer_prefix}main_router.pt"), map_location=device) ) for moe_i, moe_block in enumerate(himoe.moe_blocks): moe_path = os.path.join( model_dir, f"{layer_prefix}moe_expert_{moe_i+1:03d}" ) moe_block.router.load_state_dict( torch.load(os.path.join(moe_path, "router.pt"), map_location=device) ) for exp_i, expert in enumerate(moe_block.experts): expert.load_state_dict( torch.load(os.path.join(moe_path, f"model_{exp_i+1:03d}.pt"), map_location=device) ) print(f"[load] Resumed from '{model_dir}/' at step {step}.") return model, cfg, stoi, itos, step # ────────────────────────────────────────────────────────────────────────────── # Data helpers # ────────────────────────────────────────────────────────────────────────────── def build_vocab(text: str): chars = sorted(set(text)) stoi = {c: i for i, c in enumerate(chars)} itos = {i: c for i, c in enumerate(chars)} return stoi, itos def encode(text: str, stoi: dict) -> list: return [stoi[c] for c in text] def get_batch(data: torch.Tensor, block_size: int, batch_size: int, device: str): ix = torch.randint(len(data) - block_size, (batch_size,)) x = torch.stack([data[i:i+block_size] for i in ix]).to(device) y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device) return x, y @torch.no_grad() def estimate_loss(model, train_data, val_data, cfg, device): model.eval() result = {} for split, ds in [("train", train_data), ("val", val_data)]: losses = torch.zeros(cfg.eval_iters) for k in range(cfg.eval_iters): x, y = get_batch(ds, cfg.block_size, cfg.batch_size, device) _, loss, _, _ = model(x, y) losses[k] = loss.item() result[split] = losses.mean().item() model.train() return result # ────────────────────────────────────────────────────────────────────────────── # Training loop # ────────────────────────────────────────────────────────────────────────────── def train(cfg: HiMoEConfig, resume: bool = False): device = "cpu" if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" print(f"[himoe] Device: {device}") # ── data ───────────────────────────────────────────────────────────────── with open(cfg.data_file, "r", encoding="utf-8") as f: text = f.read() print(f"[himoe] Dataset: {len(text):,} characters") stoi, itos = build_vocab(text) vocab_size = len(stoi) data = torch.tensor(encode(text, stoi), dtype=torch.long) n = int(0.9 * len(data)) train_data = data[:n] val_data = data[n:] # ── model ───────────────────────────────────────────────────────────────── start_step = 0 if resume and os.path.isfile(os.path.join(cfg.model_dir, "config.json")): model, cfg, stoi, itos, start_step = load_model(cfg.model_dir, device) else: model = HiMoEModel(cfg, vocab_size).to(device) total_params = model.num_params() active_params = ( # attention + norms + embeddings (always active) sum(p.numel() for blk in model.blocks for p in list(blk.attn.parameters()) + list(blk.ln1.parameters()) + list(blk.ln2.parameters())) + sum(p.numel() for p in model.tok_emb.parameters()) + sum(p.numel() for p in model.pos_emb.parameters()) + sum(p.numel() for p in model.ln_f.parameters()) + sum(p.numel() for p in model.lm_head.parameters()) # only 1 MoE block × 1 expert active per layer per token + cfg.n_layer * ( sum(p.numel() for p in model.blocks[0].himoe.main_router.parameters()) + sum(p.numel() for p in model.blocks[0].himoe.moe_blocks[0].router.parameters()) + sum(p.numel() for p in model.blocks[0].himoe.moe_blocks[0].experts[0].parameters()) ) ) print(f"[himoe] Total params : {total_params/1e6:.2f}M") print(f"[himoe] Active/token : ~{active_params/1e6:.2f}M " f"({100*active_params/total_params:.1f}% of total)") print(f"[himoe] Vocab size : {vocab_size}") print(f"[himoe] MoE structure : {cfg.num_moes} MoEs × {cfg.num_experts} experts " f"= {cfg.num_moes * cfg.num_experts} total experts") # ── optimiser ───────────────────────────────────────────────────────────── # Use weight decay on weight matrices, not biases/norms decay = {p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad} no_decay = {p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad} optimizer = torch.optim.AdamW([ {"params": list(decay), "weight_decay": 0.1}, {"params": list(no_decay), "weight_decay": 0.0}, ], lr=cfg.lr, betas=(0.9, 0.95)) # cosine LR decay def lr_schedule(step): warmup = 100 if step < warmup: return step / warmup progress = (step - warmup) / max(1, cfg.max_iters - warmup) return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule) # ── loop ────────────────────────────────────────────────────────────────── print(f"\n[himoe] Training for {cfg.max_iters} steps …\n") t0 = time.time() for step in range(start_step, cfg.max_iters): # periodic evaluation + save if step % cfg.eval_interval == 0: losses = estimate_loss(model, train_data, val_data, cfg, device) elapsed = time.time() - t0 eta = (elapsed / max(step - start_step, 1)) * (cfg.max_iters - step) lr_now = optimizer.param_groups[0]["lr"] print(f"step {step:>5}/{cfg.max_iters} | " f"train {losses['train']:.4f} | " f"val {losses['val']:.4f} | " f"lr {lr_now:.2e} | " f"ETA {eta/60:.1f}m") save_model(model, cfg, vocab_size, stoi, itos, step) # Generate sample and save routing log periodically for visualization model.eval() with torch.no_grad(): # Workaround for MPS generation hangs: move to CPU for sampling original_device = next(model.parameters()).device model.to("cpu") context = torch.zeros((1, 1), dtype=torch.long, device="cpu") gen_ids, r_log = model.generate(context, max_new_tokens=400, temperature=0.8, top_k=40) smp = "".join(itos[i] for i in gen_ids[0].tolist()) with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f: f.write(smp) with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f: json.dump(r_log, f, indent=2) model.to(original_device) model.train() # forward + backward x, y = get_batch(train_data, cfg.block_size, cfg.batch_size, device) _, loss, _, _ = model(x, y) optimizer.zero_grad(set_to_none=True) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() # Constant updates if step % 5 == 0: print(f"\rstep {step:>5}/{cfg.max_iters} | loss {loss.item():.4f} | lr {optimizer.param_groups[0]['lr']:.2e}", end="", flush=True) if step % cfg.eval_interval == 0 and step > start_step: print() # new line after progress bar # final save save_model(model, cfg, vocab_size, stoi, itos, cfg.max_iters) print("\n[himoe] Training complete.") # ── sample generation ───────────────────────────────────────────────────── print("\n[himoe] Sample generation:\n" + "─" * 60) model.eval() context = torch.zeros((1, 1), dtype=torch.long, device=device) gen_ids, routing_log = model.generate(context, max_new_tokens=400, temperature=0.8, top_k=40) sample = "".join(itos[i] for i in gen_ids[0].tolist()) print(sample) print("─" * 60) with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f: f.write(sample) with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f: json.dump(routing_log, f, indent=2) # save full log for visualization print(f"\n[himoe] Sample + routing log saved to '{cfg.model_dir}/'") # ── routing statistics ──────────────────────────────────────────────────── print("\n[himoe] Expert utilisation (last generation, layer 0):") moe_counts = [0] * cfg.num_moes exp_counts = [[0] * cfg.num_experts for _ in range(cfg.num_moes)] for entry in routing_log: m = entry["moe"][0][0] e = entry["exp"][0][0] moe_counts[m] += 1 exp_counts[m][e] += 1 total = sum(moe_counts) for mi, mc in enumerate(moe_counts): bar = "█" * int(40 * mc / max(total, 1)) print(f" MoE {mi+1:02d} [{bar:<40}] {mc:4d} tokens " f"({100*mc/max(total,1):.1f}%)") for ei, ec in enumerate(exp_counts[mi]): if ec > 0: print(f" Expert {ei+1:02d}: {ec} tokens") # ────────────────────────────────────────────────────────────────────────────── # Entry point # ────────────────────────────────────────────────────────────────────────────── if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train HiMoE on hamlet.txt") parser.add_argument("--resume", action="store_true", help="Resume from existing checkpoint in model/") parser.add_argument("--max_iters", type=int, default=None) parser.add_argument("--n_layer", type=int, default=None) parser.add_argument("--n_embd", type=int, default=None) parser.add_argument("--num_moes", type=int, default=None) parser.add_argument("--num_experts", type=int, default=None) parser.add_argument("--lr", type=float, default=None) parser.add_argument("--data_file", type=str, default=None) parser.add_argument("--model_dir", type=str, default=None) args = parser.parse_args() cfg = HiMoEConfig() for attr in ["max_iters", "n_layer", "n_embd", "num_moes", "num_experts", "lr", "data_file", "model_dir"]: val = getattr(args, attr) if val is not None: setattr(cfg, attr, val) train(cfg, resume=args.resume)