"""
Eve-2-MoE Training Script — Multi-GPU DDP
==========================================
Usage:
    Single GPU:  python train.py
    Multi-GPU:   torchrun --nproc_per_node=2 train.py
    4x GPU:      torchrun --nproc_per_node=4 train.py

Override config: torchrun --nproc_per_node=2 train.py --max_steps 15000 --batch_size 48
Resume training: torchrun --nproc_per_node=2 train.py --resume checkpoints/latest.pt

Author: Anthony Maio / Making Minds AI Research
"""

import os
import sys
import math
import time
import json
import argparse
import logging
from pathlib import Path
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import tiktoken
from datasets import load_dataset

from modeling_eve import ModelConfig, DeepSeekMoE
def setup_distributed():
    """Initialize DDP if launched with torchrun, otherwise single-GPU."""
    if "RANK" in os.environ:
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        rank = 0
        world_size = 1
        local_rank = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    is_master = rank == 0
    return rank, world_size, local_rank, device, is_master


def cleanup_distributed():
    if dist.is_initialized():
        dist.destroy_process_group()
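

# ---------------------------------------------------------------------------
# Data: training batches are streamed from FineWeb-Edu; validation uses the
# WikiText-2 test split.
# ---------------------------------------------------------------------------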
class StreamingDataLoader:
    """Streams tokenized batches from FineWeb-Edu.

    Each DDP rank reads its own shard of the stream, so no two GPUs see the
    same documents.
    """

    def __init__(self, batch_size: int, block_size: int, rank: int = 0,
                 world_size: int = 1, dataset_name: str = "sample-10BT"):
        self.batch_size = batch_size
        self.block_size = block_size
        self.rank = rank
        self.world_size = world_size
        self.dataset_name = dataset_name
        self.enc = tiktoken.get_encoding("gpt2")
        self._init_stream()

    def _init_stream(self):
        ds = load_dataset("HuggingFaceFW/fineweb-edu", name=self.dataset_name,
                          split="train", streaming=True)
        # Give each rank a disjoint shard of the stream
        if self.world_size > 1:
            ds = ds.shard(num_shards=self.world_size, index=self.rank)
        self.iter_dataset = iter(ds)

    def get_batch(self) -> tuple[torch.Tensor, torch.Tensor]:
        total_tokens = self.batch_size * self.block_size

        # Concatenate tokenized documents until there are enough tokens for
        # one batch (+1 so the targets can be shifted by one position).
        batch_tokens = []
        while len(batch_tokens) < total_tokens + 1:
            try:
                text = next(self.iter_dataset)["text"]
                tokens = self.enc.encode(text, allowed_special={"<|endoftext|>"})
                batch_tokens.extend(tokens)
            except StopIteration:
                print(f"[Rank {self.rank}] Dataset exhausted, restarting stream...")
                self._init_stream()

        # Inputs are the first total_tokens tokens; targets are the same
        # sequence shifted right by one. Any surplus tokens are discarded.
        data = torch.tensor(batch_tokens[:total_tokens + 1], dtype=torch.long)
        x = data[:total_tokens].view(self.batch_size, self.block_size)
        y = data[1:total_tokens + 1].view(self.batch_size, self.block_size)
        return x, y
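
    # Example (hypothetical sizes, single process):
    #   loader = StreamingDataLoader(batch_size=2, block_size=8)
    #   x, y = loader.get_batch()  # x.shape == y.shape == (2, 8); y is x shifted by one token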


class ValidationLoader:
    """WikiText-2 validation set."""

    def __init__(self, block_size: int, device: torch.device):
        self.block_size = block_size
        self.device = device
        enc = tiktoken.get_encoding("gpt2")

        # Tokenize the whole test split once and keep it on the device
        ds = load_dataset("wikitext", "wikitext-2-v1", split="test")
        text = "\n\n".join(ds["text"])
        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        self.data = torch.tensor(tokens, dtype=torch.long, device=device)

    @torch.no_grad()
    def estimate_loss(self, model, eval_iters: int = 50, batch_size: int = 32) -> float:
        model.eval()
        losses = torch.zeros(eval_iters, device=self.device)

        for k in range(eval_iters):
            # Sample random windows of block_size tokens
            ix = torch.randint(len(self.data) - self.block_size, (batch_size,))
            x = torch.stack([self.data[i:i + self.block_size] for i in ix])
            y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])

            with torch.amp.autocast(device_type=self.device.type, dtype=torch.bfloat16):
                _, loss = model(x, y)
            losses[k] = loss.item()

        model.train()
        return losses.mean().item()
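

# ---------------------------------------------------------------------------
# Learning-rate schedule
# ---------------------------------------------------------------------------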
def get_lr(step: int, max_steps: int, warmup_steps: int, peak_lr: float, min_lr_ratio: float = 0.1) -> float:
    """Cosine decay to a fixed floor, with linear warmup."""
    min_lr = peak_lr * min_lr_ratio

    # Linear warmup
    if step < warmup_steps:
        return peak_lr * (step + 1) / (warmup_steps + 1)

    # Past the end of the schedule: hold at the floor
    if step > max_steps:
        return min_lr

    # Cosine decay from peak_lr down to min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (peak_lr - min_lr)
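
# With the defaults (peak_lr=5e-4, min_lr_ratio=0.1, warmup_steps=200,
# max_steps=7500): the LR ramps linearly to 5e-4 over the first 200 steps,
# then follows a cosine curve down to 5e-5 at step 7500.


# ---------------------------------------------------------------------------
# Checkpointing
# ---------------------------------------------------------------------------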
def save_checkpoint(model, optimizer, step: int, loss: float, val_loss: float,
                    config: ModelConfig, checkpoint_dir: Path, is_ddp: bool):
    """Save training checkpoint (model weights, optimizer state, metadata)."""
    raw_model = model.module if is_ddp else model
    # torch.compile wraps the module; unwrap so the state-dict keys stay clean
    raw_model = getattr(raw_model, "_orig_mod", raw_model)
    checkpoint = {
        "step": step,
        "model_state_dict": raw_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": loss,
        "val_loss": val_loss,
        "config": {
            "vocab_size": config.vocab_size,
            "n_layer": config.n_layer,
            "n_embd": config.n_embd,
            "n_head": config.n_head,
            "head_dim": config.head_dim,
            "block_size": config.block_size,
            "num_experts": config.num_experts,
            "top_k": config.top_k,
            "expert_intermediate_size": config.expert_intermediate_size,
            "shared_expert_intermediate_size": config.shared_expert_intermediate_size,
            "rope_theta": config.rope_theta,
        },
    }
    path = checkpoint_dir / f"step_{step}.pt"
    torch.save(checkpoint, path)
    print(f" Checkpoint saved: {path}")

    # Also keep a rolling copy for easy resuming
    latest = checkpoint_dir / "latest.pt"
    torch.save(checkpoint, latest)


def save_final_model(model, config: ModelConfig, output_dir: Path, is_ddp: bool):
    """Save just the model weights + config for HuggingFace upload."""
    raw_model = model.module if is_ddp else model
    # Unwrap torch.compile so the saved keys match a plain nn.Module
    raw_model = getattr(raw_model, "_orig_mod", raw_model)
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(raw_model.state_dict(), output_dir / "pytorch_model.bin")

    config_data = {
        "architecture": "Eve-2-MoE",
        "vocab_size": config.vocab_size,
        "n_layer": config.n_layer,
        "n_embd": config.n_embd,
        "n_head": config.n_head,
        "head_dim": config.head_dim,
        "block_size": config.block_size,
        "num_experts": config.num_experts,
        "top_k": config.top_k,
        "expert_intermediate_size": config.expert_intermediate_size,
        "shared_expert_intermediate_size": config.shared_expert_intermediate_size,
        "rope_theta": config.rope_theta,
    }
    with open(output_dir / "config.json", "w") as f:
        json.dump(config_data, f, indent=2)

    print(f" Final model saved to {output_dir}")
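
# Re-loading the exported model later (a sketch; assumes ModelConfig accepts
# these config.json fields as keyword arguments):
#   with open("model_final/config.json") as f:
#       cfg = json.load(f)
#   cfg.pop("architecture")
#   model = DeepSeekMoE(ModelConfig(**cfg))
#   model.load_state_dict(torch.load("model_final/pytorch_model.bin"))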
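
# ---------------------------------------------------------------------------
# CLI arguments and training loop
# ---------------------------------------------------------------------------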
def parse_args():
    p = argparse.ArgumentParser(description="Eve-2-MoE Training")

    # Model
    p.add_argument("--n_layer", type=int, default=12)
    p.add_argument("--n_embd", type=int, default=512)
    p.add_argument("--n_head", type=int, default=8)
    p.add_argument("--num_experts", type=int, default=8)
    p.add_argument("--block_size", type=int, default=2048)

    # Optimization
    p.add_argument("--max_steps", type=int, default=7500,
                   help="Total training steps. 7500 steps ≈ 500M tokens (1hr single B200)")
    p.add_argument("--batch_size", type=int, default=32,
                   help="Per-GPU batch size")
    p.add_argument("--learning_rate", type=float, default=5e-4)
    p.add_argument("--warmup_steps", type=int, default=200)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--min_lr_ratio", type=float, default=0.1,
                   help="Minimum LR as fraction of peak (cosine decay floor)")

    # Data
    p.add_argument("--dataset", type=str, default="sample-10BT",
                   help="FineWeb-Edu subset name")

    # Checkpointing / evaluation
    p.add_argument("--save_every", type=int, default=500)
    p.add_argument("--val_every", type=int, default=500)
    p.add_argument("--checkpoint_dir", type=str, default="checkpoints")
    p.add_argument("--output_dir", type=str, default="model_final")

    # System / logging
    p.add_argument("--compile", action="store_true", default=True,
                   help="Use torch.compile (on by default; recommended for B200/H100)")
    p.add_argument("--no_compile", action="store_true",
                   help="Disable torch.compile")
    p.add_argument("--wandb_project", type=str, default="Eve-2-MoE",
                   help="WandB project name (empty to disable)")
    p.add_argument("--wandb_run", type=str, default=None,
                   help="WandB run name")
    p.add_argument("--resume", type=str, default=None,
                   help="Path to checkpoint to resume from")
    p.add_argument("--use_checkpointing", action="store_true",
                   help="Enable gradient checkpointing (saves VRAM)")

    return p.parse_args()


def main():
    args = parse_args()

    # Distributed / device setup
    rank, world_size, local_rank, device, is_master = setup_distributed()

    if is_master:
        print(f"{'=' * 60}")
        print(f" Eve-2-MoE Training")
        print(f" GPUs: {world_size} | Device: {torch.cuda.get_device_name(device)}")
        print(f" Steps: {args.max_steps} | Batch/GPU: {args.batch_size}")
        print(f" Global batch: {args.batch_size * world_size} × {args.block_size} = "
              f"{args.batch_size * world_size * args.block_size:,} tokens/step")
        print(f" Total tokens: ~{args.max_steps * args.batch_size * world_size * args.block_size / 1e9:.1f}B")
        print(f"{'=' * 60}")

    # Build the model
    config = ModelConfig(
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        num_experts=args.num_experts,
        block_size=args.block_size,
        use_checkpointing=args.use_checkpointing,
    )

    model = DeepSeekMoE(config).to(device)

    if is_master:
        param_count = sum(p.numel() for p in model.parameters())
        print(f" Parameters: {param_count / 1e6:.2f}M")

    # Optional torch.compile (enabled by default; disable with --no_compile)
    if args.compile and not args.no_compile:
        if is_master:
            print(" Compiling model with torch.compile...")
        model = torch.compile(model)

    # Wrap in DDP for multi-GPU. find_unused_parameters=True lets DDP tolerate
    # experts that receive no tokens on a given step under top-k routing.
    is_ddp = world_size > 1
    if is_ddp:
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

    # Underlying nn.Module (unwrap DDP and torch.compile) used for the
    # optimizer, gradient clipping, checkpointing, and validation.
    raw_model = model.module if is_ddp else model
    raw_model = getattr(raw_model, "_orig_mod", raw_model)

    optimizer = torch.optim.AdamW(
        raw_model.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
    )

    # Optionally resume model/optimizer state from a checkpoint
    start_step = 0
    if args.resume:
        if is_master:
            print(f" Resuming from {args.resume}...")
        ckpt = torch.load(args.resume, map_location=device)
        raw_model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_step = ckpt["step"] + 1
        if is_master:
            print(f" Resumed at step {start_step}")

    # Streaming training data (per-rank shards)
    train_loader = StreamingDataLoader(
        batch_size=args.batch_size,
        block_size=config.block_size,
        rank=rank,
        world_size=world_size,
        dataset_name=args.dataset,
    )

    # Validation runs on the master rank only
    val_loader = None
    if is_master:
        val_loader = ValidationLoader(config.block_size, device)

    checkpoint_dir = Path(args.checkpoint_dir)
    if is_master:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Optional Weights & Biases logging (master rank only)
    wandb_enabled = False
    if is_master and args.wandb_project:
        try:
            import wandb
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run or f"eve2-{world_size}gpu-{args.max_steps}steps",
                config=vars(args),
            )
            wandb_enabled = True
        except ImportError:
            print(" WandB not installed, skipping.")

    model.train()
    tokens_per_step = args.batch_size * world_size * config.block_size

    if is_master:
        print(f"\n Starting training from step {start_step}...\n")

    for step in range(start_step, args.max_steps):
        t0 = time.time()

        # Learning-rate schedule
        lr = get_lr(step, args.max_steps, args.warmup_steps, args.learning_rate, args.min_lr_ratio)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Fetch the next batch and move it to the device
        x, y = train_loader.get_batch()
        x, y = x.to(device), y.to(device)

        # Forward pass in bfloat16 autocast
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, loss = model(x, y)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Gradient clipping
        if args.grad_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(raw_model.parameters(), args.grad_clip)
        else:
            grad_norm = None

        optimizer.step()

        # Step timing / throughput
        torch.cuda.synchronize()
        t1 = time.time()
        dt_ms = (t1 - t0) * 1000
        tok_per_sec = tokens_per_step / (t1 - t0)

        # Console + WandB logging every 10 steps
        if is_master and step % 10 == 0:
            grad_str = f" | Grad: {grad_norm:.2f}" if grad_norm is not None else ""
            print(f" Step {step:>6d}/{args.max_steps} | Loss: {loss.item():.4f} | "
                  f"LR: {lr:.2e} | {tok_per_sec:,.0f} tok/s | {dt_ms:.0f}ms{grad_str}")

            if wandb_enabled:
                log = {
                    "train_loss": loss.item(),
                    "lr": lr,
                    "tokens_per_sec": tok_per_sec,
                    "step_time_ms": dt_ms,
                }
                if grad_norm is not None:
                    log["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                wandb.log(log, step=step)

        # Periodic validation (master only) followed by a checkpoint; on
        # non-validation save steps, checkpoint with val_loss recorded as -1.0.
        if is_master and val_loader and step > 0 and step % args.val_every == 0:
            val_loss = val_loader.estimate_loss(raw_model)
            print(f" >>> Validation Loss: {val_loss:.4f}")
            if wandb_enabled:
                wandb.log({"val_loss": val_loss}, step=step)

            save_checkpoint(model, optimizer, step, loss.item(), val_loss,
                            config, checkpoint_dir, is_ddp)

        elif is_master and step > 0 and step % args.save_every == 0 and step % args.val_every != 0:
            save_checkpoint(model, optimizer, step, loss.item(), -1.0,
                            config, checkpoint_dir, is_ddp)

    # Final evaluation, model export, and last checkpoint (master only)
    if is_master:
        print(f"\n{'=' * 60}")
        print(" Training complete!")

        if val_loader:
            final_val = val_loader.estimate_loss(raw_model)
            print(f" Final Val Loss: {final_val:.4f}")

        output_dir = Path(args.output_dir)
        save_final_model(model, config, output_dir, is_ddp)

        save_checkpoint(model, optimizer, args.max_steps, loss.item(),
                        final_val if val_loader else -1.0,
                        config, checkpoint_dir, is_ddp)

        print(f"\n Upload to HuggingFace:")
        print(f" huggingface-cli upload anthonym21/Eve-2-MoE-250M {output_dir}/")
        print(f"{'=' * 60}")

    if wandb_enabled:
        wandb.finish()

    cleanup_distributed()


if __name__ == "__main__":
    main()