|
|
""" |
|
|
Eve-2-MoE Training Script — Multi-GPU DDP |
|
|
========================================== |
|
|
Usage: |
|
|
Single GPU: python train.py |
|
|
Multi-GPU: torchrun --nproc_per_node=2 train.py |
|
|
4x GPU: torchrun --nproc_per_node=4 train.py |
|
|
|
|
|
Override config: torchrun --nproc_per_node=2 train.py --max_steps 15000 --batch_size 48 |
|
|
|
|
|
Author: Anthony Maio / Making Minds AI Research |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import math |
|
|
import time |
|
|
import json |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from contextlib import nullcontext |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.distributed as dist |
|
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
|
import tiktoken |
|
|
from datasets import load_dataset |
|
|
|
|
|
from modeling_eve import ModelConfig, DeepSeekMoE |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_distributed():
    """Initialize DDP if launched with torchrun, otherwise single-GPU.

    Returns:
        Tuple of (rank, world_size, local_rank, device, is_master).
    """
    launched_with_torchrun = "RANK" in os.environ

    if launched_with_torchrun:
        # torchrun exports RANK / WORLD_SIZE / LOCAL_RANK for every worker.
        dist.init_process_group(backend="nccl")
        rank, world_size = dist.get_rank(), dist.get_world_size()
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        # Plain `python train.py`: one process on the best available device.
        rank, world_size, local_rank = 0, 1, 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return rank, world_size, local_rank, device, rank == 0
|
|
|
|
|
|
|
|
def cleanup_distributed():
    """Tear down the process group if DDP was initialized (no-op otherwise)."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamingDataLoader:
    """Streams tokenized batches from FineWeb-Edu.

    Each DDP rank reads a disjoint shard so no two GPUs see the same data.

    Tokens left over after filling a batch are kept in a buffer and used to
    start the next batch. (The previous implementation discarded everything
    beyond ``total_tokens + 1`` on every call, silently wasting a slice of
    each streamed document.)
    """

    def __init__(self, batch_size: int, block_size: int, rank: int = 0,
                 world_size: int = 1, dataset_name: str = "sample-10BT"):
        self.batch_size = batch_size
        self.block_size = block_size
        self.rank = rank
        self.world_size = world_size
        self.dataset_name = dataset_name
        self.enc = tiktoken.get_encoding("gpt2")
        # Carry-over tokens between get_batch() calls so no data is dropped.
        self._buffer: list[int] = []
        self._init_stream()

    def _init_stream(self):
        """(Re)open the streaming dataset, sharded per DDP rank."""
        ds = load_dataset("HuggingFaceFW/fineweb-edu", name=self.dataset_name,
                          split="train", streaming=True)

        if self.world_size > 1:
            # Interleaved sharding: each rank sees a disjoint subset.
            ds = ds.shard(num_shards=self.world_size, index=self.rank)
        self.iter_dataset = iter(ds)

    def get_batch(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (x, y) LongTensors of shape (batch_size, block_size).

        y is x shifted right by one token (next-token prediction targets).
        """
        total_tokens = self.batch_size * self.block_size

        # Need one extra token so targets can be shifted by one position.
        while len(self._buffer) < total_tokens + 1:
            try:
                text = next(self.iter_dataset)["text"]
                tokens = self.enc.encode(text, allowed_special={"<|endoftext|>"})
                self._buffer.extend(tokens)
            except StopIteration:
                # Stream ran dry: reopen it and keep accumulating.
                print(f"[Rank {self.rank}] Dataset exhausted, restarting stream...")
                self._init_stream()

        data = torch.tensor(self._buffer[:total_tokens + 1], dtype=torch.long)
        # Retain the remainder; the boundary token is reused as the first
        # input of the next batch (it was only a target here).
        self._buffer = self._buffer[total_tokens:]
        x = data[:total_tokens].view(self.batch_size, self.block_size)
        y = data[1:].view(self.batch_size, self.block_size)
        return x, y
|
|
|
|
|
|
|
|
class ValidationLoader:
    """Held-out evaluation on the WikiText-2 test split."""

    def __init__(self, block_size: int, device: torch.device):
        self.block_size = block_size
        self.device = device
        enc = tiktoken.get_encoding("gpt2")

        # Tokenize the whole test split into one contiguous token stream.
        ds = load_dataset("wikitext", "wikitext-2-v1", split="test")
        text = "\n\n".join(ds["text"])
        tokens = enc.encode(text, allowed_special={"<|endoftext|>"})
        self.data = torch.tensor(tokens, dtype=torch.long, device=device)

    @torch.no_grad()
    def estimate_loss(self, model, eval_iters: int = 50, batch_size: int = 32) -> float:
        """Average loss over `eval_iters` random windows of the val stream.

        Puts the model in eval mode for the duration and restores train
        mode before returning.
        """
        model.eval()
        # Accumulate on the host: avoids a device sync per iteration that
        # writing .item() into a CUDA tensor used to force.
        losses = []

        for _ in range(eval_iters):
            # Random window starts; overlapping windows are fine for eval.
            ix = torch.randint(len(self.data) - self.block_size, (batch_size,))
            x = torch.stack([self.data[i:i + self.block_size] for i in ix])
            y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])

            # Use the actual device type so the CPU fallback path works too
            # (previously hard-coded to "cuda").
            with torch.amp.autocast(device_type=self.device.type, dtype=torch.bfloat16):
                _, loss = model(x, y)
            losses.append(loss.item())

        model.train()
        return torch.tensor(losses).mean().item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lr(step: int, max_steps: int, warmup_steps: int, peak_lr: float, min_lr_ratio: float = 0.1) -> float:
    """Cosine decay with linear warmup.

    Args:
        step: Current training step (0-indexed).
        max_steps: Step at which the decay reaches its floor.
        warmup_steps: Number of linear warmup steps.
        peak_lr: Learning rate reached at the end of warmup.
        min_lr_ratio: Decay floor, expressed as a fraction of peak_lr.

    Returns:
        Learning rate to apply at this step.
    """
    min_lr = peak_lr * min_lr_ratio

    # Linear warmup: ramp from ~0 up to peak_lr over warmup_steps.
    if step < warmup_steps:
        return peak_lr * (step + 1) / (warmup_steps + 1)

    # Past the schedule — or a degenerate schedule (max_steps <= warmup_steps,
    # which previously divided by zero below): hold at the floor.
    if step > max_steps or max_steps <= warmup_steps:
        return min_lr

    # Cosine decay from peak_lr down to min_lr.
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (peak_lr - min_lr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(model, optimizer, step: int, loss: float, val_loss: float,
                    config: ModelConfig, checkpoint_dir: Path, is_ddp: bool):
    """Save training checkpoint (model weights, optimizer state, metadata)."""
    # Unwrap DDP so the saved state dict has clean (non-"module.") keys.
    raw_model = model.module if is_ddp else model

    # Mirror the architecture hyperparameters so the checkpoint is
    # self-describing and can be re-instantiated without the training args.
    arch_keys = (
        "vocab_size", "n_layer", "n_embd", "n_head", "head_dim", "block_size",
        "num_experts", "top_k", "expert_intermediate_size",
        "shared_expert_intermediate_size", "rope_theta",
    )
    checkpoint = {
        "step": step,
        "model_state_dict": raw_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_loss": loss,
        "val_loss": val_loss,
        "config": {key: getattr(config, key) for key in arch_keys},
    }

    path = checkpoint_dir / f"step_{step}.pt"
    torch.save(checkpoint, path)
    print(f" Checkpoint saved: {path}")

    # Also refresh a stable "latest" pointer for easy --resume.
    latest = checkpoint_dir / "latest.pt"
    torch.save(checkpoint, latest)
|
|
|
|
|
|
|
|
def save_final_model(model, config: ModelConfig, output_dir: Path, is_ddp: bool):
    """Save just the model weights + config for HuggingFace upload."""
    # Unwrap DDP so the weight keys have no "module." prefix.
    raw_model = model.module if is_ddp else model
    output_dir.mkdir(parents=True, exist_ok=True)

    # Weights only — full checkpoints (with optimizer state) live elsewhere.
    torch.save(raw_model.state_dict(), output_dir / "pytorch_model.bin")

    # Architecture metadata, serialized in a fixed, HF-friendly order.
    arch_keys = (
        "vocab_size", "n_layer", "n_embd", "n_head", "head_dim", "block_size",
        "num_experts", "top_k", "expert_intermediate_size",
        "shared_expert_intermediate_size", "rope_theta",
    )
    config_data = {"architecture": "Eve-2-MoE"}
    config_data.update({key: getattr(config, key) for key in arch_keys})

    with open(output_dir / "config.json", "w") as f:
        json.dump(config_data, f, indent=2)

    print(f" Final model saved to {output_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse CLI flags; every flag has a sensible default for a quick run."""
    parser = argparse.ArgumentParser(description="Eve-2-MoE Training")

    # --- model architecture ---
    parser.add_argument("--n_layer", type=int, default=12)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--num_experts", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=2048)

    # --- optimization schedule ---
    parser.add_argument("--max_steps", type=int, default=7500,
                        help="Total training steps. 7500 steps ≈ 500M tokens (1hr single B200)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Per-GPU batch size")
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--warmup_steps", type=int, default=200)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--min_lr_ratio", type=float, default=0.1,
                        help="Minimum LR as fraction of peak (cosine decay floor)")

    # --- data ---
    parser.add_argument("--dataset", type=str, default="sample-10BT",
                        help="FineWeb-Edu subset name")

    # --- checkpointing / output ---
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--val_every", type=int, default=500)
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
    parser.add_argument("--output_dir", type=str, default="model_final")

    # --- runtime toggles ---
    # NOTE: --compile is on by default; --no_compile is the actual off switch
    # (both flags are kept for CLI backward compatibility).
    parser.add_argument("--compile", action="store_true", default=True,
                        help="Use torch.compile (recommended for B200/H100)")
    parser.add_argument("--no_compile", action="store_true",
                        help="Disable torch.compile")
    parser.add_argument("--wandb_project", type=str, default="Eve-2-MoE",
                        help="WandB project name (empty to disable)")
    parser.add_argument("--wandb_run", type=str, default=None,
                        help="WandB run name")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint to resume from")
    parser.add_argument("--use_checkpointing", action="store_true",
                        help="Enable gradient checkpointing (saves VRAM)")

    return parser.parse_args()
|
|
|
|
|
|
|
|
def main():
    """Entry point: configure, build, and train Eve-2-MoE, then save artifacts."""
    args = parse_args()

    # DDP init (or single-process fallback); only rank 0 logs and saves.
    rank, world_size, local_rank, device, is_master = setup_distributed()

    if is_master:
        print(f"{'=' * 60}")
        print(f" Eve-2-MoE Training")
        print(f" GPUs: {world_size} | Device: {torch.cuda.get_device_name(device)}")
        print(f" Steps: {args.max_steps} | Batch/GPU: {args.batch_size}")
        print(f" Global batch: {args.batch_size * world_size} × {args.block_size} = "
              f"{args.batch_size * world_size * args.block_size:,} tokens/step")
        print(f" Total tokens: ~{args.max_steps * args.batch_size * world_size * args.block_size / 1e9:.1f}B")
        print(f"{'=' * 60}")

    # Architecture hyperparameters from CLI; the rest use ModelConfig defaults.
    config = ModelConfig(
        n_layer=args.n_layer,
        n_embd=args.n_embd,
        n_head=args.n_head,
        num_experts=args.num_experts,
        block_size=args.block_size,
        use_checkpointing=args.use_checkpointing,
    )

    model = DeepSeekMoE(config).to(device)

    if is_master:
        param_count = sum(p.numel() for p in model.parameters())
        print(f" Parameters: {param_count / 1e6:.2f}M")

    # torch.compile before DDP wrapping; --no_compile overrides the default-on flag.
    if args.compile and not args.no_compile:
        if is_master:
            print(" Compiling model with torch.compile...")
        model = torch.compile(model)

    is_ddp = world_size > 1
    if is_ddp:
        # find_unused_parameters=True: with top-k MoE routing some experts can
        # receive no tokens in a step, leaving their params without gradients.
        model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

    # Unwrapped module for state_dict saving / clipping / eval. NOTE(review):
    # when torch.compile is on, this is still the compiled wrapper, so saved
    # state dicts carry the "_orig_mod." key prefix — confirm downstream
    # loaders expect that.
    raw_model = model.module if is_ddp else model

    optimizer = torch.optim.AdamW(
        raw_model.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
    )

    # Optionally resume model/optimizer/step from a prior checkpoint.
    start_step = 0
    if args.resume:
        if is_master:
            print(f" Resuming from {args.resume}...")
        # NOTE(review): torch.load unpickles arbitrary objects — only resume
        # from trusted checkpoint files.
        ckpt = torch.load(args.resume, map_location=device)
        raw_model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_step = ckpt["step"] + 1
        if is_master:
            print(f" Resumed at step {start_step}")

    # Per-rank sharded streaming training data.
    train_loader = StreamingDataLoader(
        batch_size=args.batch_size,
        block_size=config.block_size,
        rank=rank,
        world_size=world_size,
        dataset_name=args.dataset,
    )

    # Validation runs on rank 0 only.
    val_loader = None
    if is_master:
        val_loader = ValidationLoader(config.block_size, device)

    checkpoint_dir = Path(args.checkpoint_dir)
    if is_master:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Optional WandB logging (master only; silently skipped if not installed).
    wandb_enabled = False
    if is_master and args.wandb_project:
        try:
            import wandb
            wandb.init(
                project=args.wandb_project,
                name=args.wandb_run or f"eve2-{world_size}gpu-{args.max_steps}steps",
                config=vars(args),
            )
            wandb_enabled = True
        except ImportError:
            print(" WandB not installed, skipping.")

    model.train()
    tokens_per_step = args.batch_size * world_size * config.block_size

    if is_master:
        print(f"\n Starting training from step {start_step}...\n")

    # ---- training loop ----
    for step in range(start_step, args.max_steps):
        t0 = time.time()

        # Cosine-with-warmup schedule, applied manually every step.
        lr = get_lr(step, args.max_steps, args.warmup_steps, args.learning_rate, args.min_lr_ratio)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        x, y = train_loader.get_batch()
        x, y = x.to(device), y.to(device)

        # bf16 autocast forward; the model returns (logits, loss).
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, loss = model(x, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Global-norm gradient clipping (disabled when --grad_clip <= 0).
        if args.grad_clip > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(raw_model.parameters(), args.grad_clip)
        else:
            grad_norm = None

        optimizer.step()

        # Synchronize so the step timing reflects actual GPU work.
        torch.cuda.synchronize()
        t1 = time.time()
        dt_ms = (t1 - t0) * 1000
        tok_per_sec = tokens_per_step / (t1 - t0)

        # Console + WandB logging every 10 steps (rank 0 only).
        if is_master and step % 10 == 0:
            grad_str = f" | Grad: {grad_norm:.2f}" if grad_norm is not None else ""
            print(f" Step {step:>6d}/{args.max_steps} | Loss: {loss.item():.4f} | "
                  f"LR: {lr:.2e} | {tok_per_sec:,.0f} tok/s | {dt_ms:.0f}ms{grad_str}")

            if wandb_enabled:
                import wandb
                log = {
                    "train_loss": loss.item(),
                    "lr": lr,
                    "tokens_per_sec": tok_per_sec,
                    "step_time_ms": dt_ms,
                }
                if grad_norm is not None:
                    log["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                wandb.log(log, step=step)

        # Periodic validation; a checkpoint piggy-backs on every val step.
        if is_master and val_loader and step > 0 and step % args.val_every == 0:
            val_loss = val_loader.estimate_loss(raw_model)
            print(f" >>> Validation Loss: {val_loss:.4f}")
            if wandb_enabled:
                wandb.log({"val_loss": val_loss}, step=step)

            save_checkpoint(model, optimizer, step, loss.item(), val_loss,
                            config, checkpoint_dir, is_ddp)

        # Checkpoint-only steps (no validation); -1.0 marks "no val loss".
        elif is_master and step > 0 and step % args.save_every == 0 and step % args.val_every != 0:
            save_checkpoint(model, optimizer, step, loss.item(), -1.0,
                            config, checkpoint_dir, is_ddp)

    # ---- final evaluation + artifact saving (rank 0 only) ----
    if is_master:
        print(f"\n{'=' * 60}")
        print(" Training complete!")

        if val_loader:
            final_val = val_loader.estimate_loss(raw_model)
            print(f" Final Val Loss: {final_val:.4f}")

        output_dir = Path(args.output_dir)
        save_final_model(model, config, output_dir, is_ddp)

        # NOTE(review): `loss` is only bound if the loop ran at least once —
        # resuming with start_step >= max_steps would raise NameError here.
        save_checkpoint(model, optimizer, args.max_steps, loss.item(),
                        final_val if val_loader else -1.0,
                        config, checkpoint_dir, is_ddp)

        print(f"\n Upload to HuggingFace:")
        print(f" huggingface-cli upload anthonym21/Eve-2-MoE-250M {output_dir}/")
        print(f"{'=' * 60}")

    if wandb_enabled:
        import wandb
        wandb.finish()

    cleanup_distributed()
|
|
|
|
|
|
|
|
# Script entry point (run directly or via torchrun).
if __name__ == "__main__":
    main()
|
|
|