"""
train_himoe.py β€” HiMoE (Hierarchical Mixture of Experts) Training Script
=========================================================================
Architecture inspired by Matryoshka MoE: a nested, two-level routing system
where a top-level router selects a MoE block, and each MoE block has its own
router selecting among its local experts.
Saved layout:
model/
main_router.pt ← top-level (Level-1) gate weights
moe_expert_001/
router.pt ← Level-2 gate for this MoE block
model_001.pt … model_008.pt ← individual expert weights
moe_expert_002/ …
…
backbone.pt ← embeddings, attention, LN, LM head
config.json ← full config for re-loading
Usage:
python train_himoe.py # train from scratch
python train_himoe.py --resume # continue from saved checkpoint
"""
import os
import json
import time
import math
import argparse
import torch
import torch.nn as nn
from torch.nn import functional as F
# ──────────────────────────────────────────────────────────────────────────────
# Config
# ──────────────────────────────────────────────────────────────────────────────
class HiMoEConfig:
    # Transformer backbone
    block_size: int = 128
    n_layer: int = 2
    n_head: int = 4
    n_embd: int = 256
    dropout: float = 0.1
    # HiMoE routing (Matryoshka-style nesting)
    num_moes: int = 6        # Level-1 choices
    num_experts: int = 8     # Level-2 choices per MoE
    # Training
    batch_size: int = 32
    max_iters: int = 750     # small value for testing; use ~3000 for real training
    eval_interval: int = 50
    eval_iters: int = 20
    lr: float = 3e-4
    # Paths
    data_file: str = "hamlet.txt"
    model_dir: str = "model"

    def to_dict(self):
        # Read values via getattr so instance-level overrides (CLI flags,
        # resumed configs) are serialized too, not just the class defaults.
        return {k: getattr(self, k) for k in self.__class__.__annotations__}
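# Example (sketch): overriding a default and round-tripping the config.
# Purely illustrative; this mirrors what the CLI entry point and
# save_model() below actually do.
#
#   cfg = HiMoEConfig()
#   cfg.n_layer = 4                    # instance-level override
#   cfg.to_dict()["n_layer"]           # -> 4 (overrides are included)
#   json.dumps(cfg.to_dict())          # serializable, as in config.json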
# ──────────────────────────────────────────────────────────────────────────────
# Model components
# ──────────────────────────────────────────────────────────────────────────────
class Expert(nn.Module):
    """A single feed-forward expert network."""
    def __init__(self, n_embd: int, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
class MoEBlock(nn.Module):
    """
    Level-2 MoE: owns `num_experts` experts and its own gate (router).
    Top-1 routing: only one expert is activated per token. The expert output
    is scaled by the chosen gate probability (Switch-Transformer style) so
    the router stays differentiable; with a plain hard argmax the router
    would receive no gradient and never train.
    """
    def __init__(self, n_embd: int, num_experts: int, dropout: float = 0.0):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [Expert(n_embd, dropout) for _ in range(num_experts)]
        )
        # Level-2 router (saved separately as router.pt inside the MoE folder)
        self.router = nn.Linear(n_embd, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        """
        x : (tokens, C) -- already flattened to 2-D before entering here
        Returns: output (tokens, C), chosen expert indices (tokens,)
        """
        logits = self.router(x)            # (tokens, E)
        probs = F.softmax(logits, dim=-1)
        chosen = probs.argmax(dim=-1)      # (tokens,)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = (chosen == i)
            if mask.any():
                # Scale by the gate probability so gradients reach the router
                gate = probs[mask, i].unsqueeze(-1)   # (n_i, 1)
                out[mask] = gate * expert(x[mask])
        return out, chosen
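# Example (sketch): top-1 routing through a single MoEBlock. Shapes assume
# the config defaults above; illustrative only, not executed during training.
#
#   block = MoEBlock(n_embd=256, num_experts=8)
#   x = torch.randn(10, 256)             # 10 flattened tokens
#   out, chosen = block(x)
#   out.shape                            # -> torch.Size([10, 256])
#   chosen.shape, int(chosen.max()) < 8  # -> torch.Size([10]), True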
class HiMoEFFN(nn.Module):
    """
    Hierarchical MoE FFN (replaces the standard FFN in a Transformer block).
    The Level-1 router selects one MoEBlock; that block's Level-2 router
    selects one expert -- Matryoshka-style nesting.
    """
    def __init__(self, cfg: HiMoEConfig):
        super().__init__()
        self.num_moes = cfg.num_moes
        self.num_experts = cfg.num_experts
        # Level-1 router (saved as main_router.pt at the top level)
        self.main_router = nn.Linear(cfg.n_embd, cfg.num_moes, bias=False)
        self.moe_blocks = nn.ModuleList(
            [MoEBlock(cfg.n_embd, cfg.num_experts, cfg.dropout)
             for _ in range(cfg.num_moes)]
        )

    def forward(self, x: torch.Tensor):
        """
        x : (B, T, C)
        Returns: output (B, T, C),
                 moe_ids (B, T) -- which MoE was chosen,
                 exp_ids (B, T) -- which expert inside that MoE was chosen
        """
        B, T, C = x.shape
        flat = x.view(B * T, C)              # (tokens, C)
        # Level-1 routing
        l1_logits = self.main_router(flat)   # (tokens, num_moes)
        l1_probs = F.softmax(l1_logits, dim=-1)
        moe_ids = l1_probs.argmax(dim=-1)    # (tokens,)
        out = torch.zeros_like(flat)
        exp_ids = torch.zeros_like(moe_ids)  # (tokens,)
        for i, moe_block in enumerate(self.moe_blocks):
            mask = (moe_ids == i)
            if mask.any():
                result, chosen_exp = moe_block(flat[mask])
                # Scale by the Level-1 gate probability so the main router
                # is trained by backprop as well (same trick as in MoEBlock)
                gate = l1_probs[mask, i].unsqueeze(-1)   # (n_i, 1)
                out[mask] = gate * result
                exp_ids[mask] = chosen_exp
        return (out.view(B, T, C),
                moe_ids.view(B, T),
                exp_ids.view(B, T))
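# Example (sketch): the two-level routing decision per token. With the
# defaults (6 MoEs × 8 experts) each token effectively picks one of 48
# experts, but only a single expert FFN runs for it. Illustrative only.
#
#   ffn = HiMoEFFN(HiMoEConfig())
#   x = torch.randn(2, 16, 256)          # (B, T, C)
#   out, moe_ids, exp_ids = ffn(x)
#   out.shape                            # -> torch.Size([2, 16, 256])
#   moe_ids.shape, exp_ids.shape         # -> (2, 16) each, integer indices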
class TransformerBlock(nn.Module):
    def __init__(self, cfg: HiMoEConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(cfg.n_embd)
        self.attn = nn.MultiheadAttention(
            cfg.n_embd, cfg.n_head,
            dropout=cfg.dropout, batch_first=True
        )
        self.ln2 = nn.LayerNorm(cfg.n_embd)
        self.himoe = HiMoEFFN(cfg)

    def forward(self, x: torch.Tensor, attn_mask=None):
        # Self-attention with causal mask (pre-LN residual)
        xn = self.ln1(x)
        attn_out, _ = self.attn(xn, xn, xn,
                                attn_mask=attn_mask,
                                need_weights=False,
                                is_causal=attn_mask is None)
        x = x + attn_out
        # Hierarchical MoE FFN
        xn = self.ln2(x)
        ffn_out, moe_ids, exp_ids = self.himoe(xn)
        x = x + ffn_out
        return x, moe_ids, exp_ids
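# Note: the block uses pre-LN residuals (x + sublayer(LN(x))) and threads the
# routing ids upward, so the model can log which MoE and which expert fired
# for every token at every layer.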
class HiMoEModel(nn.Module):
    def __init__(self, cfg: HiMoEConfig, vocab_size: int):
        super().__init__()
        self.cfg = cfg
        self.vocab_size = vocab_size
        # Backbone (saved as backbone.pt)
        self.tok_emb = nn.Embedding(vocab_size, cfg.n_embd)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg.n_layer)]
        )
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, vocab_size, bias=False)
        # Weight tying
        self.tok_emb.weight = self.lm_head.weight
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def forward(self, idx: torch.Tensor, targets=None):
        B, T = idx.shape
        assert T <= self.cfg.block_size, \
            f"Sequence length {T} > block_size {self.cfg.block_size}"
        # Causal mask: position t may attend only to positions <= t
        mask = torch.full((T, T), float('-inf'), device=idx.device)
        mask = torch.triu(mask, diagonal=1)
        tok = self.tok_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = self.drop(tok + pos)
        all_moe_ids, all_exp_ids = [], []
        for block in self.blocks:
            x, moe_ids, exp_ids = block(x, attn_mask=mask)
            all_moe_ids.append(moe_ids)
            all_exp_ids.append(exp_ids)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss, all_moe_ids, all_exp_ids
    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 0.8, top_k: int = 40):
        routing_log = []
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.block_size:]
            logits, _, moe_ids, exp_ids = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top-k logits; mask the rest to -inf
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            # Record which MoE/expert each layer used for the newest position
            routing_log.append({
                "moe": [m[:, -1].tolist() for m in moe_ids],
                "exp": [e[:, -1].tolist() for e in exp_ids],
            })
        return idx, routing_log

    def num_params(self):
        return sum(p.numel() for p in self.parameters())
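# Example (sketch): prompt-free generation with routing introspection.
# Assumes a trained `model` and the vocab helpers below; illustrative only.
#
#   context = torch.zeros((1, 1), dtype=torch.long)   # start from token id 0
#   ids, log = model.generate(context, max_new_tokens=50)
#   text = "".join(itos[i] for i in ids[0].tolist())
#   log[0]["moe"]   # per-layer MoE choice for the first generated position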
# ──────────────────────────────────────────────────────────────────────────────
# Modular save / load
# ──────────────────────────────────────────────────────────────────────────────
def _moe_dir(base: str, moe_idx: int) -> str:
    return os.path.join(base, f"moe_expert_{moe_idx+1:03d}")
def save_model(model: HiMoEModel, cfg: HiMoEConfig, vocab_size: int,
               stoi: dict, itos: dict, step: int):
    """
    Save the full model in the modular layout described in the module
    docstring:

        model/
            config.json
            backbone.pt
            main_router.pt
            moe_expert_001/
                router.pt
                model_001.pt … model_008.pt
            …

    For n_layer > 1, every router file and moe_expert directory is
    prefixed with its layer, e.g. layer_01_main_router.pt.
    """
    base = cfg.model_dir
    os.makedirs(base, exist_ok=True)
    # ── config + vocab ────────────────────────────────────────────────────────
    meta = {
        "config": cfg.to_dict(),
        "vocab_size": vocab_size,
        "step": step,
        "stoi": stoi,
        "itos": itos,
    }
    with open(os.path.join(base, "config.json"), "w") as f:
        json.dump(meta, f, indent=2)
    # ── backbone ─────────────────────────────────────────────────────────────
    backbone_sd = {
        "tok_emb": model.tok_emb.state_dict(),
        "pos_emb": model.pos_emb.state_dict(),
        "ln_f": model.ln_f.state_dict(),
        "lm_head": model.lm_head.state_dict(),
        # per-block attention + layer norms (not the MoE parts)
        "blocks_attn": [
            {
                "ln1": blk.ln1.state_dict(),
                "attn": blk.attn.state_dict(),
                "ln2": blk.ln2.state_dict(),
            }
            for blk in model.blocks
        ],
    }
    torch.save(backbone_sd, os.path.join(base, "backbone.pt"))
    # ── per-layer routers & experts ───────────────────────────────────────────
    # For multi-layer models we namespace by layer; single-layer stays flat.
    for layer_idx, blk in enumerate(model.blocks):
        himoe = blk.himoe
        # Determine directory prefix
        layer_prefix = f"layer_{layer_idx+1:02d}_" if cfg.n_layer > 1 else ""
        # Level-1 (main) router
        torch.save(
            himoe.main_router.state_dict(),
            os.path.join(base, f"{layer_prefix}main_router.pt")
        )
        # Per-MoE directories
        for moe_i, moe_block in enumerate(himoe.moe_blocks):
            moe_path = os.path.join(
                base,
                f"{layer_prefix}moe_expert_{moe_i+1:03d}"
            )
            os.makedirs(moe_path, exist_ok=True)
            # Level-2 router
            torch.save(
                moe_block.router.state_dict(),
                os.path.join(moe_path, "router.pt")
            )
            # Individual experts
            for exp_i, expert in enumerate(moe_block.experts):
                torch.save(
                    expert.state_dict(),
                    os.path.join(moe_path, f"model_{exp_i+1:03d}.pt")
                )
    print(f"[save] Model saved to '{base}/' at step {step}.")
def load_model(model_dir: str, device: str) -> tuple:
    """
    Load the full model from the modular directory layout.
    Returns (model, cfg, stoi, itos, step).
    """
    with open(os.path.join(model_dir, "config.json")) as f:
        meta = json.load(f)
    cfg = HiMoEConfig()
    for k, v in meta["config"].items():
        setattr(cfg, k, v)
    cfg.model_dir = model_dir
    vocab_size = meta["vocab_size"]
    stoi = meta["stoi"]
    itos = {int(k): v for k, v in meta["itos"].items()}  # JSON keys are strings
    step = meta["step"]
    model = HiMoEModel(cfg, vocab_size).to(device)
    # backbone
    bb = torch.load(os.path.join(model_dir, "backbone.pt"), map_location=device)
    model.tok_emb.load_state_dict(bb["tok_emb"])
    model.pos_emb.load_state_dict(bb["pos_emb"])
    model.ln_f.load_state_dict(bb["ln_f"])
    model.lm_head.load_state_dict(bb["lm_head"])
    for i, blk in enumerate(model.blocks):
        blk.ln1.load_state_dict(bb["blocks_attn"][i]["ln1"])
        blk.attn.load_state_dict(bb["blocks_attn"][i]["attn"])
        blk.ln2.load_state_dict(bb["blocks_attn"][i]["ln2"])
    # routers + experts
    for layer_idx, blk in enumerate(model.blocks):
        himoe = blk.himoe
        layer_prefix = f"layer_{layer_idx+1:02d}_" if cfg.n_layer > 1 else ""
        himoe.main_router.load_state_dict(
            torch.load(os.path.join(model_dir, f"{layer_prefix}main_router.pt"),
                       map_location=device)
        )
        for moe_i, moe_block in enumerate(himoe.moe_blocks):
            moe_path = os.path.join(
                model_dir, f"{layer_prefix}moe_expert_{moe_i+1:03d}"
            )
            moe_block.router.load_state_dict(
                torch.load(os.path.join(moe_path, "router.pt"),
                           map_location=device)
            )
            for exp_i, expert in enumerate(moe_block.experts):
                expert.load_state_dict(
                    torch.load(os.path.join(moe_path, f"model_{exp_i+1:03d}.pt"),
                               map_location=device)
                )
    print(f"[load] Resumed from '{model_dir}/' at step {step}.")
    return model, cfg, stoi, itos, step
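# Example (sketch): save/load round trip through the modular layout.
# Assumes `model`, `cfg`, and the vocab dicts from a training run.
#
#   save_model(model, cfg, vocab_size, stoi, itos, step=0)
#   model2, cfg2, stoi2, itos2, step = load_model("model", device="cpu")
#   # model2 now carries identical weights, reassembled from backbone.pt,
#   # the per-layer *main_router.pt files, and each moe_expert_*/ directory.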
# ──────────────────────────────────────────────────────────────────────────────
# Data helpers
# ──────────────────────────────────────────────────────────────────────────────
def build_vocab(text: str):
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for i, c in enumerate(chars)}
    return stoi, itos

def encode(text: str, stoi: dict) -> list:
    return [stoi[c] for c in text]

def get_batch(data: torch.Tensor, block_size: int,
              batch_size: int, device: str):
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).to(device)
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).to(device)
    return x, y
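# Example (sketch): the character-level data pipeline. For the text "abcab":
#
#   stoi, itos = build_vocab("abcab")    # stoi == {'a': 0, 'b': 1, 'c': 2}
#   encode("cab", stoi)                  # -> [2, 0, 1]
#
# get_batch samples random windows; y is x shifted one character right, so
# position t is trained to predict character t+1:
#
#   x, y = get_batch(data, block_size=8, batch_size=4, device="cpu")
#   x.shape, y.shape                     # -> torch.Size([4, 8]) each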
@torch.no_grad()
def estimate_loss(model, train_data, val_data, cfg, device):
    model.eval()
    result = {}
    for split, ds in [("train", train_data), ("val", val_data)]:
        losses = torch.zeros(cfg.eval_iters)
        for k in range(cfg.eval_iters):
            x, y = get_batch(ds, cfg.block_size, cfg.batch_size, device)
            _, loss, _, _ = model(x, y)
            losses[k] = loss.item()
        result[split] = losses.mean().item()
    model.train()
    return result
# ──────────────────────────────────────────────────────────────────────────────
# Training loop
# ──────────────────────────────────────────────────────────────────────────────
def train(cfg: HiMoEConfig, resume: bool = False):
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    print(f"[himoe] Device: {device}")
    # ── data ─────────────────────────────────────────────────────────────────
    with open(cfg.data_file, "r", encoding="utf-8") as f:
        text = f.read()
    print(f"[himoe] Dataset: {len(text):,} characters")
    stoi, itos = build_vocab(text)
    vocab_size = len(stoi)
    data = torch.tensor(encode(text, stoi), dtype=torch.long)
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]
    # ── model ─────────────────────────────────────────────────────────────────
    start_step = 0
    if resume and os.path.isfile(os.path.join(cfg.model_dir, "config.json")):
        model, cfg, stoi, itos, start_step = load_model(cfg.model_dir, device)
        # Re-encode with the checkpoint's vocab so token ids stay consistent
        vocab_size = model.vocab_size
        data = torch.tensor(encode(text, stoi), dtype=torch.long)
        train_data, val_data = data[:n], data[n:]
    else:
        model = HiMoEModel(cfg, vocab_size).to(device)
    total_params = model.num_params()
    active_params = (
        # attention + norms + embeddings (always active)
        sum(p.numel() for blk in model.blocks
            for p in list(blk.attn.parameters()) +
                     list(blk.ln1.parameters()) +
                     list(blk.ln2.parameters()))
        + sum(p.numel() for p in model.tok_emb.parameters())
        + sum(p.numel() for p in model.pos_emb.parameters())
        + sum(p.numel() for p in model.ln_f.parameters())
        + sum(p.numel() for p in model.lm_head.parameters())
        # only 1 MoE block × 1 expert active per layer per token
        + cfg.n_layer * (
            sum(p.numel() for p in model.blocks[0].himoe.main_router.parameters())
            + sum(p.numel() for p in model.blocks[0].himoe.moe_blocks[0].router.parameters())
            + sum(p.numel() for p in model.blocks[0].himoe.moe_blocks[0].experts[0].parameters())
        )
    )
    print(f"[himoe] Total params  : {total_params/1e6:.2f}M")
    print(f"[himoe] Active/token  : ~{active_params/1e6:.2f}M "
          f"({100*active_params/total_params:.1f}% of total)")
    print(f"[himoe] Vocab size    : {vocab_size}")
    print(f"[himoe] MoE structure : {cfg.num_moes} MoEs × {cfg.num_experts} experts "
          f"= {cfg.num_moes * cfg.num_experts} total experts")
    # ── optimiser ─────────────────────────────────────────────────────────────
    # Weight decay on weight matrices only, not on biases / norms
    decay = [p for _, p in model.named_parameters()
             if p.dim() >= 2 and p.requires_grad]
    no_decay = [p for _, p in model.named_parameters()
                if p.dim() < 2 and p.requires_grad]
    optimizer = torch.optim.AdamW([
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ], lr=cfg.lr, betas=(0.9, 0.95))

    # linear warmup, then cosine LR decay to a 10% floor
    def lr_schedule(step):
        warmup = 100
        if step < warmup:
            return step / warmup
        progress = (step - warmup) / max(1, cfg.max_iters - warmup)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))
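    # Sketch of the resulting multiplier (warmup=100, max_iters=750 defaults):
    #   step   0 -> 0.00   (start of linear warmup)
    #   step 100 -> 1.00   (warmup done: cos(0) = 1 gives 0.1 + 0.9)
    #   step 750 -> 0.10   (cos(pi) = -1 leaves only the 0.1 floor)
    # i.e. the LR ramps up to cfg.lr, then decays cosine-wise to 10% of it.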
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    # ── loop ──────────────────────────────────────────────────────────────────
    print(f"\n[himoe] Training for {cfg.max_iters} steps …\n")
    t0 = time.time()
    for step in range(start_step, cfg.max_iters):
        # periodic evaluation + save
        if step % cfg.eval_interval == 0:
            losses = estimate_loss(model, train_data, val_data, cfg, device)
            elapsed = time.time() - t0
            eta = (elapsed / max(step - start_step, 1)) * (cfg.max_iters - step)
            lr_now = optimizer.param_groups[0]["lr"]
            print(f"step {step:>5}/{cfg.max_iters} | "
                  f"train {losses['train']:.4f} | "
                  f"val {losses['val']:.4f} | "
                  f"lr {lr_now:.2e} | "
                  f"ETA {eta/60:.1f}m")
            save_model(model, cfg, vocab_size, stoi, itos, step)
            # Generate a sample and save the routing log for visualization
            model.eval()
            with torch.no_grad():
                # Workaround for MPS generation hangs: sample on the CPU
                original_device = next(model.parameters()).device
                model.to("cpu")
                context = torch.zeros((1, 1), dtype=torch.long, device="cpu")
                gen_ids, r_log = model.generate(context, max_new_tokens=400,
                                                temperature=0.8, top_k=40)
                smp = "".join(itos[i] for i in gen_ids[0].tolist())
                with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
                    f.write(smp)
                with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
                    json.dump(r_log, f, indent=2)
                model.to(original_device)
            model.train()
        # forward + backward
        x, y = get_batch(train_data, cfg.block_size,
                         cfg.batch_size, device)
        _, loss, _, _ = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        # rolling progress line
        if step % 5 == 0:
            print(f"\rstep {step:>5}/{cfg.max_iters} | loss {loss.item():.4f} | "
                  f"lr {optimizer.param_groups[0]['lr']:.2e}", end="", flush=True)
        if step % cfg.eval_interval == 0 and step > start_step:
            print()  # new line after progress bar
    # final save
    save_model(model, cfg, vocab_size, stoi, itos, cfg.max_iters)
    print("\n[himoe] Training complete.")
    # ── sample generation ─────────────────────────────────────────────────────
    print("\n[himoe] Sample generation:\n" + "─" * 60)
    model.eval()
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    gen_ids, routing_log = model.generate(context, max_new_tokens=400,
                                          temperature=0.8, top_k=40)
    sample = "".join(itos[i] for i in gen_ids[0].tolist())
    print(sample)
    print("─" * 60)
    with open(os.path.join(cfg.model_dir, "sample.txt"), "w") as f:
        f.write(sample)
    with open(os.path.join(cfg.model_dir, "routing_log.json"), "w") as f:
        json.dump(routing_log, f, indent=2)  # full log for visualization
    print(f"\n[himoe] Sample + routing log saved to '{cfg.model_dir}/'")
    # ── routing statistics ────────────────────────────────────────────────────
    print("\n[himoe] Expert utilisation (last generation, layer 0):")
    moe_counts = [0] * cfg.num_moes
    exp_counts = [[0] * cfg.num_experts for _ in range(cfg.num_moes)]
    for entry in routing_log:
        m = entry["moe"][0][0]   # layer 0, batch element 0
        e = entry["exp"][0][0]
        moe_counts[m] += 1
        exp_counts[m][e] += 1
    total = sum(moe_counts)
    for mi, mc in enumerate(moe_counts):
        bar = "█" * int(40 * mc / max(total, 1))
        print(f"  MoE {mi+1:02d} [{bar:<40}] {mc:4d} tokens "
              f"({100*mc/max(total,1):.1f}%)")
        for ei, ec in enumerate(exp_counts[mi]):
            if ec > 0:
                print(f"      Expert {ei+1:02d}: {ec} tokens")
# ──────────────────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train HiMoE on hamlet.txt")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from existing checkpoint in model/")
    parser.add_argument("--max_iters", type=int, default=None)
    parser.add_argument("--n_layer", type=int, default=None)
    parser.add_argument("--n_embd", type=int, default=None)
    parser.add_argument("--num_moes", type=int, default=None)
    parser.add_argument("--num_experts", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--data_file", type=str, default=None)
    parser.add_argument("--model_dir", type=str, default=None)
    args = parser.parse_args()
    cfg = HiMoEConfig()
    # Apply CLI overrides on top of the defaults
    for attr in ["max_iters", "n_layer", "n_embd", "num_moes",
                 "num_experts", "lr", "data_file", "model_dir"]:
        val = getattr(args, attr)
        if val is not None:
            setattr(cfg, attr, val)
    train(cfg, resume=args.resume)