| """ |
| train_himoe.py β HiMoE (Hierarchical Mixture of Experts) Training Script |
| ========================================================================= |
| Architecture inspired by Matryoshka MoE: a nested, two-level routing system |
| where a top-level router selects a MoE block, and each MoE block has its own |
| router selecting among its local experts. |
| |
| Saved layout: |
| model/ |
| main_router.pt β top-level (Level-1) gate weights |
| moe_expert_001/ |
| router.pt β Level-2 gate for this MoE block |
| model_001.pt β¦ model_008.pt β individual expert weights |
| moe_expert_002/ β¦ |
| β¦ |
| backbone.pt β embeddings, attention, LN, LM head |
| config.json β full config for re-loading |
| |
| Usage: |
| python train_himoe.py # train from scratch |
| python train_himoe.py --resume # continue from saved checkpoint |
| """ |
|
|
import os
import json
import time
import math
import argparse

import torch
import torch.nn as nn
from torch.nn import functional as F
|
|
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
|
|
class HiMoEConfig:
    # model
    block_size: int = 128
    n_layer: int = 2
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.1

    # hierarchical MoE
    num_moes: int = 6
    num_experts: int = 8

    # training
    batch_size: int = 32
    max_iters: int = 750
    eval_interval: int = 50
    eval_iters: int = 20
    lr: float = 3e-4

    # I/O
    data_file: str = "hamlet.txt"
    model_dir: str = "model"

    def to_dict(self):
        # use getattr() so per-instance overrides (e.g. CLI flags applied via
        # setattr) are captured, not just the class-level defaults
        return {k: getattr(self, k)
                for k, v in self.__class__.__dict__.items()
                if not k.startswith("_") and not callable(v)}
|
|
|
|
# ---------------------------------------------------------------------------
# Model components
# ---------------------------------------------------------------------------
|
|
class Expert(nn.Module):
    """A single feed-forward expert network."""
    def __init__(self, n_embd: int, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
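
# Size note: each expert has two linear layers with a 4x expansion, i.e.
# n_embd*(4*n_embd) + 4*n_embd + (4*n_embd)*n_embd + n_embd
# = 8*n_embd^2 + 5*n_embd parameters; with the default n_embd=256 that is
# 525,568 (~0.53M) parameters per expert.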
|
|
|
|
class MoEBlock(nn.Module):
    """
    Level-2 MoE: owns `num_experts` experts and its own gate (router).
    Top-1 routing: only one expert is activated per token.
    """
    def __init__(self, n_embd: int, num_experts: int, dropout: float = 0.0):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [Expert(n_embd, dropout) for _ in range(num_experts)]
        )
        self.router = nn.Linear(n_embd, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        """
        x : (tokens, C) - already flattened to 2-D before entering here
        Returns: output (tokens, C), chosen expert indices (tokens,)
        """
        logits = self.router(x)
        probs = F.softmax(logits, dim=-1)
        chosen = probs.argmax(dim=-1)

        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = (chosen == i)
            if mask.any():
                # scale by the gate probability (Switch-style top-1): a bare
                # argmax is non-differentiable and would leave the router
                # without gradients, i.e. untrained
                out[mask] = probs[mask, i].unsqueeze(-1) * expert(x[mask])
        return out, chosen
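
# Hard argmax routing tends to collapse onto a few experts. A common remedy is
# a load-balancing auxiliary loss in the style of Switch Transformers; the
# helper below is a minimal sketch of that idea. It is NOT wired into the
# training loop, and the name `load_balancing_loss` is our own.
def load_balancing_loss(probs: torch.Tensor, chosen: torch.Tensor) -> torch.Tensor:
    """
    probs  : (tokens, num_experts) softmax router probabilities
    chosen : (tokens,) argmax expert index per token
    Returns a scalar that is minimised when tokens spread evenly over experts.
    """
    num_experts = probs.size(-1)
    # f_i: fraction of tokens dispatched to each expert
    dispatch = F.one_hot(chosen, num_experts).float().mean(dim=0)
    # P_i: mean router probability assigned to each expert
    importance = probs.mean(dim=0)
    # N * sum_i f_i * P_i -- equals 1.0 at a perfectly uniform split
    return num_experts * torch.sum(dispatch * importance)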
|
|
|
|
class HiMoEFFN(nn.Module):
    """
    Hierarchical MoE FFN (replaces the standard FFN in a Transformer block).

    Level-1 router selects one MoEBlock; that block's Level-2 router selects
    one expert - Matryoshka-style nesting.
    """
    def __init__(self, cfg: HiMoEConfig):
        super().__init__()
        self.num_moes = cfg.num_moes
        self.num_experts = cfg.num_experts

        self.main_router = nn.Linear(cfg.n_embd, cfg.num_moes, bias=False)
        self.moe_blocks = nn.ModuleList(
            [MoEBlock(cfg.n_embd, cfg.num_experts, cfg.dropout)
             for _ in range(cfg.num_moes)]
        )
|
|
    def forward(self, x: torch.Tensor):
        """
        x : (B, T, C)
        Returns: output (B, T, C),
                 moe_ids (B, T) - which MoE was chosen,
                 exp_ids (B, T) - which expert inside that MoE was chosen
        """
        B, T, C = x.shape
        flat = x.view(B * T, C)

        # Level-1 routing: pick one MoE block per token
        l1_logits = self.main_router(flat)
        l1_probs = F.softmax(l1_logits, dim=-1)
        moe_ids = l1_probs.argmax(dim=-1)

        out = torch.zeros_like(flat)
        exp_ids = torch.zeros_like(moe_ids)

        for i, moe_block in enumerate(self.moe_blocks):
            mask = (moe_ids == i)
            if mask.any():
                result, chosen_exp = moe_block(flat[mask])
                # as in MoEBlock: weight by the gate probability so the
                # Level-1 router receives gradients
                out[mask] = l1_probs[mask, i].unsqueeze(-1) * result
                exp_ids[mask] = chosen_exp

        return (out.view(B, T, C),
                moe_ids.view(B, T),
                exp_ids.view(B, T))
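
# A quick shape check for HiMoEFFN (illustrative; values assume the default
# config):
#
#   cfg = HiMoEConfig()
#   ffn = HiMoEFFN(cfg)
#   x = torch.randn(2, cfg.block_size, cfg.n_embd)   # (B, T, C)
#   out, moe_ids, exp_ids = ffn(x)
#   out.shape      # torch.Size([2, 128, 256])
#   moe_ids.shape  # torch.Size([2, 128]), values in [0, cfg.num_moes)
#   exp_ids.shape  # torch.Size([2, 128]), values in [0, cfg.num_experts)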
|
|
|
|
class TransformerBlock(nn.Module):
    def __init__(self, cfg: HiMoEConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.n_embd)
        self.attn = nn.MultiheadAttention(
            cfg.n_embd, cfg.n_head,
            dropout=cfg.dropout, batch_first=True
        )
        self.ln2 = nn.LayerNorm(cfg.n_embd)
        self.himoe = HiMoEFFN(cfg)

    def forward(self, x: torch.Tensor, attn_mask=None):
        # pre-norm self-attention with residual connection
        xn = self.ln1(x)
        attn_out, _ = self.attn(xn, xn, xn,
                                attn_mask=attn_mask,
                                need_weights=False,
                                is_causal=attn_mask is None)
        x = x + attn_out

        # pre-norm hierarchical MoE FFN with residual connection
        xn = self.ln2(x)
        ffn_out, moe_ids, exp_ids = self.himoe(xn)
        x = x + ffn_out
        return x, moe_ids, exp_ids
|
|
|
|
class HiMoEModel(nn.Module):
    def __init__(self, cfg: HiMoEConfig, vocab_size: int):
        super().__init__()
        self.cfg = cfg
        self.vocab_size = vocab_size

        self.tok_emb = nn.Embedding(vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg.n_layer)]
        )
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, vocab_size, bias=False)

        # weight tying: the token embedding and LM head share one matrix
        self.tok_emb.weight = self.lm_head.weight

        self._init_weights()
|
|
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)
|
|
    def forward(self, idx: torch.Tensor, targets=None):
        B, T = idx.shape
        assert T <= self.cfg.block_size, \
            f"Sequence length {T} > block_size {self.cfg.block_size}"

        # additive causal mask: -inf above the diagonal blocks future tokens
        mask = torch.full((T, T), float('-inf'), device=idx.device)
        mask = torch.triu(mask, diagonal=1)

        tok = self.tok_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = self.drop(tok + pos)

        all_moe_ids, all_exp_ids = [], []
        for block in self.blocks:
            x, moe_ids, exp_ids = block(x, attn_mask=mask)
            all_moe_ids.append(moe_ids)
            all_exp_ids.append(exp_ids)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss, all_moe_ids, all_exp_ids
|
|
    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 0.8, top_k: int = 40):
        routing_log = []
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.block_size:]
            logits, _, moe_ids, exp_ids = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            routing_log.append({
                "moe": [m[:, -1].tolist() for m in moe_ids],
                "exp": [e[:, -1].tolist() for e in exp_ids],
            })
        return idx, routing_log
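    # Each routing_log entry describes the token generated at that step: "moe"
    # and "exp" hold one list per layer with the batch's Level-1 / Level-2
    # choices at the final position, e.g. for n_layer=2 and batch size 1:
    # {"moe": [[3], [0]], "exp": [[7], [2]]} (indices illustrative).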
|
|
    def num_params(self):
        return sum(p.numel() for p in self.parameters())
|
|
|
|
# ---------------------------------------------------------------------------
# Saving / loading
# ---------------------------------------------------------------------------
|
|
def _moe_dir(base: str, moe_idx: int, layer_prefix: str = "") -> str:
    return os.path.join(base, f"{layer_prefix}moe_expert_{moe_idx+1:03d}")
|
|
|
|
def save_model(model: HiMoEModel, cfg: HiMoEConfig, vocab_size: int,
               stoi: dict, itos: dict, step: int):
    """
    Save the full model in the modular layout described in the module docstring.

    model/
        config.json
        backbone.pt
        main_router.pt  (one per layer; with n_layer > 1 each router file and
                         MoE sub-directory is prefixed with "layer_XX_")
        moe_expert_001/
            router.pt
            model_001.pt ... model_008.pt
        ...
    """
    base = cfg.model_dir
    os.makedirs(base, exist_ok=True)

    # config + vocab metadata
    meta = {
        "config": cfg.to_dict(),
        "vocab_size": vocab_size,
        "step": step,
        "stoi": stoi,
        "itos": itos,
    }
    with open(os.path.join(base, "config.json"), "w") as f:
        json.dump(meta, f, indent=2)

    # backbone: everything except the MoE routers and experts
    backbone_sd = {
        "tok_emb": model.tok_emb.state_dict(),
        "pos_emb": model.pos_emb.state_dict(),
        "ln_f": model.ln_f.state_dict(),
        "lm_head": model.lm_head.state_dict(),

        "blocks_attn": [
            {
                "ln1": blk.ln1.state_dict(),
                "attn": blk.attn.state_dict(),
                "ln2": blk.ln2.state_dict(),
            }
            for blk in model.blocks
        ],
    }
    torch.save(backbone_sd, os.path.join(base, "backbone.pt"))

    # routers and experts, one directory per MoE block
    for layer_idx, blk in enumerate(model.blocks):
        himoe = blk.himoe

        # with a single layer, keep the flat layout from the docstring
        layer_prefix = f"layer_{layer_idx+1:02d}_" if cfg.n_layer > 1 else ""

        torch.save(
            himoe.main_router.state_dict(),
            os.path.join(base, f"{layer_prefix}main_router.pt")
        )

        for moe_i, moe_block in enumerate(himoe.moe_blocks):
            moe_path = _moe_dir(base, moe_i, layer_prefix)
            os.makedirs(moe_path, exist_ok=True)

            torch.save(
                moe_block.router.state_dict(),
                os.path.join(moe_path, "router.pt")
            )

            for exp_i, expert in enumerate(moe_block.experts):
                torch.save(
                    expert.state_dict(),
                    os.path.join(moe_path, f"model_{exp_i+1:03d}.pt")
                )

    print(f"[save] Model saved to '{base}/' at step {step}.")
|
|
|
|
def load_model(model_dir: str, device: str) -> tuple:
    """
    Load the full model from the modular directory layout.
    Returns (model, cfg, stoi, itos, step).
    """
    with open(os.path.join(model_dir, "config.json")) as f:
        meta = json.load(f)

    cfg = HiMoEConfig()
    for k, v in meta["config"].items():
        setattr(cfg, k, v)
    cfg.model_dir = model_dir
    vocab_size = meta["vocab_size"]
    stoi = meta["stoi"]
    itos = {int(k): v for k, v in meta["itos"].items()}  # JSON keys are strings
    step = meta["step"]

    model = HiMoEModel(cfg, vocab_size).to(device)

    # backbone
    bb = torch.load(os.path.join(model_dir, "backbone.pt"), map_location=device)
    model.tok_emb.load_state_dict(bb["tok_emb"])
    model.pos_emb.load_state_dict(bb["pos_emb"])
    model.ln_f.load_state_dict(bb["ln_f"])
    model.lm_head.load_state_dict(bb["lm_head"])
    for i, blk in enumerate(model.blocks):
        blk.ln1.load_state_dict(bb["blocks_attn"][i]["ln1"])
        blk.attn.load_state_dict(bb["blocks_attn"][i]["attn"])
        blk.ln2.load_state_dict(bb["blocks_attn"][i]["ln2"])

    # routers and experts
    for layer_idx, blk in enumerate(model.blocks):
        himoe = blk.himoe
        layer_prefix = f"layer_{layer_idx+1:02d}_" if cfg.n_layer > 1 else ""

        himoe.main_router.load_state_dict(
            torch.load(os.path.join(model_dir, f"{layer_prefix}main_router.pt"),
                       map_location=device)
        )
        for moe_i, moe_block in enumerate(himoe.moe_blocks):
            moe_path = _moe_dir(model_dir, moe_i, layer_prefix)
            moe_block.router.load_state_dict(
                torch.load(os.path.join(moe_path, "router.pt"),
                           map_location=device)
            )
            for exp_i, expert in enumerate(moe_block.experts):
                expert.load_state_dict(
                    torch.load(os.path.join(moe_path, f"model_{exp_i+1:03d}.pt"),
                               map_location=device)
                )

    print(f"[load] Resumed from '{model_dir}/' at step {step}.")
    return model, cfg, stoi, itos, step
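
# Example reload for inference only (assumes a prior run wrote "model/" and
# that the prompt characters exist in the checkpoint's vocabulary):
#
#   model, cfg, stoi, itos, step = load_model("model", "cpu")
#   model.eval()
#   ctx = torch.tensor([[stoi["H"]]], dtype=torch.long)
#   ids, _ = model.generate(ctx, max_new_tokens=100)
#   print("".join(itos[i] for i in ids[0].tolist()))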
|
|
|
|
# ---------------------------------------------------------------------------
# Data utilities
# ---------------------------------------------------------------------------
|
|
def build_vocab(text: str):
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for i, c in enumerate(chars)}
    return stoi, itos
|
|
|
|
def encode(text: str, stoi: dict) -> list:
    return [stoi[c] for c in text]
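
# Convenience inverse of encode(); a small addition that the rest of the
# script does not call (it joins itos lookups inline instead).
def decode(ids: list, itos: dict) -> str:
    return "".join(itos[i] for i in ids)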
|
|
|
|
def get_batch(data: torch.Tensor, block_size: int,
              batch_size: int, device: str):
    # sample random windows; y is x shifted one token to the right, so y[t]
    # is the next-token target for x[t]
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device)
    return x, y
|
|
|
|
@torch.no_grad()
def estimate_loss(model, train_data, val_data, cfg, device):
    model.eval()
    result = {}
    for split, ds in [("train", train_data), ("val", val_data)]:
        losses = torch.zeros(cfg.eval_iters)
        for k in range(cfg.eval_iters):
            x, y = get_batch(ds, cfg.block_size, cfg.batch_size, device)
            _, loss, _, _ = model(x, y)
            losses[k] = loss.item()
        result[split] = losses.mean().item()
    model.train()
    return result
|
|
|
|
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
|
|
def train(cfg: HiMoEConfig, resume: bool = False):
    # pick the best available device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    print(f"[himoe] Device: {device}")

    # load and encode the dataset at character level
    with open(cfg.data_file, "r", encoding="utf-8") as f:
        text = f.read()
    print(f"[himoe] Dataset: {len(text):,} characters")

    stoi, itos = build_vocab(text)
    vocab_size = len(stoi)
    data = torch.tensor(encode(text, stoi), dtype=torch.long)
    n = int(0.9 * len(data))  # 90/10 train/val split
    train_data = data[:n]
    val_data = data[n:]
|
|
    # build the model, optionally resuming from a saved checkpoint
    start_step = 0
    if resume and os.path.isfile(os.path.join(cfg.model_dir, "config.json")):
        model, cfg, stoi, itos, start_step = load_model(cfg.model_dir, device)
    else:
        model = HiMoEModel(cfg, vocab_size).to(device)
|
|
    total_params = model.num_params()
    # parameters active for every token: attention, norms, and embeddings
    # (tok_emb and lm_head are tied, so the shared matrix is counted once)
    active_params = (
        sum(p.numel() for blk in model.blocks
            for p in list(blk.attn.parameters()) +
                     list(blk.ln1.parameters()) +
                     list(blk.ln2.parameters()))
        + sum(p.numel() for p in model.tok_emb.parameters())
        + sum(p.numel() for p in model.pos_emb.parameters())
        + sum(p.numel() for p in model.ln_f.parameters())
        # per layer: both routers plus exactly one expert fire per token
        + cfg.n_layer * (
            sum(p.numel() for p in model.blocks[0].himoe.main_router.parameters())
            + sum(p.numel() for p in model.blocks[0].himoe.moe_blocks[0].router.parameters())
            + sum(p.numel() for p in model.blocks[0].himoe.moe_blocks[0].experts[0].parameters())
        )
    )

    print(f"[himoe] Total params  : {total_params/1e6:.2f}M")
    print(f"[himoe] Active/token  : ~{active_params/1e6:.2f}M "
          f"({100*active_params/total_params:.1f}% of total)")
    print(f"[himoe] Vocab size    : {vocab_size}")
    print(f"[himoe] MoE structure : {cfg.num_moes} MoEs x {cfg.num_experts} experts "
          f"= {cfg.num_moes * cfg.num_experts} total experts")
|
|
    # AdamW with weight decay applied only to matrices (dim >= 2)
    decay = {p for name, p in model.named_parameters()
             if p.dim() >= 2 and p.requires_grad}
    no_decay = {p for name, p in model.named_parameters()
                if p.dim() < 2 and p.requires_grad}
    optimizer = torch.optim.AdamW([
        {"params": list(decay), "weight_decay": 0.1},
        {"params": list(no_decay), "weight_decay": 0.0},
    ], lr=cfg.lr, betas=(0.9, 0.95))

    # linear warmup for 100 steps, then cosine decay to 10% of the peak LR
    def lr_schedule(step):
        warmup = 100
        if step < warmup:
            return step / warmup
        progress = (step - warmup) / max(1, cfg.max_iters - warmup)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    # when resuming, fast-forward the schedule to the saved step
    for _ in range(start_step):
        scheduler.step()
|
|
    print(f"\n[himoe] Training for {cfg.max_iters} steps ...\n")
    t0 = time.time()

    for step in range(start_step, cfg.max_iters):
        # periodic evaluation, checkpointing, and sample generation
        if step % cfg.eval_interval == 0:
            if step > start_step:
                print()  # finish the in-place progress line below
            losses = estimate_loss(model, train_data, val_data, cfg, device)
            elapsed = time.time() - t0
            eta = (elapsed / max(step - start_step, 1)) * (cfg.max_iters - step)
            lr_now = optimizer.param_groups[0]["lr"]
            print(f"step {step:>5}/{cfg.max_iters} | "
                  f"train {losses['train']:.4f} | "
                  f"val {losses['val']:.4f} | "
                  f"lr {lr_now:.2e} | "
                  f"ETA {eta/60:.1f}m")
            save_model(model, cfg, vocab_size, stoi, itos, step)

            # write an intermediate sample (generated on CPU, then the model
            # is moved back to its training device)
            model.eval()
            with torch.no_grad():
                original_device = next(model.parameters()).device
                model.to("cpu")
                context = torch.zeros((1, 1), dtype=torch.long, device="cpu")
                gen_ids, r_log = model.generate(context, max_new_tokens=400,
                                                temperature=0.8, top_k=40)
                smp = "".join(itos[i] for i in gen_ids[0].tolist())
                with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
                    f.write(smp)
                with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
                    json.dump(r_log, f, indent=2)
                model.to(original_device)
            model.train()

        # one optimisation step
        x, y = get_batch(train_data, cfg.block_size,
                         cfg.batch_size, device)
        _, loss, _, _ = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # in-place progress line
        if step % 5 == 0:
            print(f"\rstep {step:>5}/{cfg.max_iters} | loss {loss.item():.4f} | "
                  f"lr {optimizer.param_groups[0]['lr']:.2e}", end="", flush=True)
|
|
    print()  # finish the progress line
    # final checkpoint
    save_model(model, cfg, vocab_size, stoi, itos, cfg.max_iters)
    print("\n[himoe] Training complete.")

    # final sample
    print("\n[himoe] Sample generation:\n" + "-" * 60)
    model.eval()
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    gen_ids, routing_log = model.generate(context, max_new_tokens=400,
                                          temperature=0.8, top_k=40)
    sample = "".join(itos[i] for i in gen_ids[0].tolist())
    print(sample)
    print("-" * 60)

    with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
        f.write(sample)
    with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
        json.dump(routing_log, f, indent=2)

    print(f"\n[himoe] Sample + routing log saved to '{cfg.model_dir}/'")
|
|
    # expert utilisation for layer 0, batch item 0 of the final generation
    print("\n[himoe] Expert utilisation (last generation, layer 0):")
    moe_counts = [0] * cfg.num_moes
    exp_counts = [[0] * cfg.num_experts for _ in range(cfg.num_moes)]
    for entry in routing_log:
        m = entry["moe"][0][0]
        e = entry["exp"][0][0]
        moe_counts[m] += 1
        exp_counts[m][e] += 1
    total = sum(moe_counts)
    for mi, mc in enumerate(moe_counts):
        bar = "█" * int(40 * mc / max(total, 1))
        print(f"  MoE {mi+1:02d} [{bar:<40}] {mc:4d} tokens "
              f"({100*mc/max(total,1):.1f}%)")
        for ei, ec in enumerate(exp_counts[mi]):
            if ec > 0:
                print(f"      Expert {ei+1:02d}: {ec} tokens")
|
|
|
|
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train HiMoE on hamlet.txt")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from existing checkpoint in model/")
    parser.add_argument("--max_iters", type=int, default=None)
    parser.add_argument("--n_layer", type=int, default=None)
    parser.add_argument("--n_embd", type=int, default=None)
    parser.add_argument("--num_moes", type=int, default=None)
    parser.add_argument("--num_experts", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--data_file", type=str, default=None)
    parser.add_argument("--model_dir", type=str, default=None)
    args = parser.parse_args()

    cfg = HiMoEConfig()
    # CLI flags override the config defaults only when explicitly given
    for attr in ["max_iters", "n_layer", "n_embd", "num_moes",
                 "num_experts", "lr", "data_file", "model_dir"]:
        val = getattr(args, attr)
        if val is not None:
            setattr(cfg, attr, val)

    train(cfg, resume=args.resume)