# supernova/train.py
# Commit ca8d994 ("Update supernova/train.py", verified) by algorythmtechnologies.
import argparse
import json
import math
import os
import time
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import get_cosine_schedule_with_warmup
from safetensors.torch import save_file
from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer
from .data import load_sources_from_yaml, TokenChunkDataset, DataSource
# ------------------------------
# Utilities
# ------------------------------
def compute_grad_norm(model: nn.Module, debug: bool = False) -> float:
    """Return the global L2 norm of all parameter gradients of ``model``.

    The result is ``sqrt(sum_i ||grad_i||_2 ** 2)`` over every parameter whose
    ``.grad`` is set; parameters without a gradient are skipped. Gradients are
    read as they currently are, so if they are still loss-scaled (e.g. by a
    GradScaler that has not been unscaled yet) the returned norm is scaled too.

    Args:
        model: Module whose parameters' ``.grad`` tensors are inspected.
        debug: If True, print each parameter's gradient norm (only when it is
            above 1e-8), a "NO GRAD" line for parameters without a gradient,
            and a final summary line.

    Returns:
        The global gradient norm as a Python float (0.0 if no parameter has a
        gradient).
    """
    norms = []
    param_count = 0
    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is None:
            if debug:
                print(f" {name}: NO GRAD")
            continue
        # Per-tensor norm computed in float32 and kept on its device: the
        # non-debug path then needs only one host sync (the final .item()),
        # instead of one sync per parameter.
        g_norm = p.grad.detach().float().norm(2)
        norms.append(g_norm)
        if debug:
            value = g_norm.item()
            if value > 1e-8:
                print(f" {name}: grad_norm={value:.6f}")

    if norms:
        # Gather onto one device so parameters spread over several devices
        # can still be combined; a no-op when they already share a device.
        ref_device = norms[0].device
        total_norm = torch.stack([n.to(ref_device) for n in norms]).norm(2).item()
    else:
        total_norm = 0.0

    if debug:
        print(f"Gradient stats: {len(norms)}/{param_count} parameters have gradients, total_norm={total_norm:.6f}")
    return total_norm
def atomic_save(obj: Dict[str, Any], path: str):
    """Serialize ``obj`` with ``torch.save`` without ever leaving ``path`` half-written.

    The object is written to ``path + ".tmp"``, flushed and fsync'ed, and then
    moved over ``path`` with ``os.replace``. The rename is atomic when both
    paths are on the same filesystem. If serialization or the rename fails, the
    temporary file is removed and the exception is re-raised; any existing file
    at ``path`` is left untouched.

    Args:
        obj: Checkpoint dictionary to serialize.
        path: Destination file path.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "wb") as f:
            torch.save(obj, f)
            f.flush()
            # Force the bytes to disk before the rename publishes them; without
            # this a crash could leave a truncated checkpoint at ``path``.
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file (pickling error, disk
        # full, KeyboardInterrupt, ...), then propagate the original error.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise
def save_safetensors_checkpoint(model_state_dict: Dict[str, torch.Tensor], path: str):
    """Best-effort export of model weights to ``path`` in safetensors format.

    Every tensor is first copied to a standalone, contiguous CPU tensor.
    ``save_file`` refuses non-contiguous tensors and tensors that share storage
    (e.g. tied weights), so without this copy those state dicts would fail to
    export. A consequence is that shared tensors are written once per key.

    The file is written to ``path + ".tmp"`` and then renamed into place, so
    ``path`` is never half-written. Errors are not raised: a warning is printed
    and any leftover temporary file is removed, so a failed export does not
    interrupt the caller.

    Args:
        model_state_dict: Mapping of parameter/buffer names to tensors.
        path: Destination ``.safetensors`` file path.
    """
    tmp = path + ".tmp"
    try:
        tensors = {
            name: t.detach().to("cpu", copy=True).contiguous()
            for name, t in model_state_dict.items()
        }
        save_file(tensors, tmp)
        os.replace(tmp, path)
        print(f"✓ Saved safetensors to {path}")
    except Exception as e:
        print(f"Warning: Failed to save safetensors: {e}")
        # Don't leave a partial temp file behind after a failed export.
        try:
            os.remove(tmp)
        except OSError:
            pass
class EMA:
    """Exponential moving average ("shadow" copy) of a model's trainable parameters.

    The shadow starts as a snapshot of every parameter with ``requires_grad``
    and is blended toward the live weights by ``update``. ``store``,
    ``copy_to`` and ``restore`` let a caller temporarily swap the averaged
    weights into the model and then put the original weights back.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        # parameter name -> cloned tensor holding the running average
        self.shadow = {pname: param.data.clone() for pname, param in self._trainable(model)}

    @staticmethod
    def _trainable(model: nn.Module):
        """Yield ``(name, parameter)`` pairs for parameters that require grad."""
        return ((pname, param) for pname, param in model.named_parameters() if param.requires_grad)

    def update(self, model: nn.Module):
        """Blend current weights into the shadow: ``shadow = decay * shadow + (1 - decay) * param``."""
        keep = self.decay
        blend = 1.0 - keep
        for pname, param in self._trainable(model):
            avg = self.shadow[pname]
            avg.mul_(keep)
            avg.add_(param.data, alpha=blend)

    def store(self, model: nn.Module):
        """Snapshot the model's current trainable weights so ``restore`` can bring them back."""
        snapshot = {}
        for pname, param in self._trainable(model):
            snapshot[pname] = param.data.clone()
        self.backup = snapshot

    def copy_to(self, model: nn.Module):
        """Overwrite the model's trainable weights in place with the shadow values."""
        for pname, param in self._trainable(model):
            param.data.copy_(self.shadow[pname])

    def restore(self, model: nn.Module):
        """Copy the weights saved by ``store`` back into the model and drop the snapshot."""
        saved = self.backup
        for pname, param in self._trainable(model):
            param.data.copy_(saved[pname])
        del self.backup
# ------------------------------
# Training loop
# ------------------------------
def train(
config_path: str,
data_config_path: str,
seq_len: int = 1024,
batch_size: int = 16,
grad_accum: int = 8,
lr: float = 3e-4,
warmup_steps: int = 2000,
max_steps: int = 100_000,
save_every: int = 10_000,
out_dir: str = "checkpoints",
seed: int = 42,
validate_every: int = 1000,
val_steps: int = 100,
clip_grad_norm: Optional[float] = 1.0,
use_ema: bool = True,
ema_decay: float = 0.9999,
resume_from: Optional[str] = None,
use_tensorboard: bool = True,
ddp: bool = False,
local_rank: int = 0,
num_workers: int = 4,
pin_memory: bool = True,
compile_model: bool = False,
export_safetensors: bool = True,
):
# reproducibility
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
import random
random.seed(seed)
torch.backends.cudnn.benchmark = True
# device / distributed
if ddp:
torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# config & tokenizer
cfg = ModelConfig.from_json_file(config_path)
cfg.assert_exact_params(expected=25_000_000)
tok = load_gpt2_tokenizer()
assert tok.vocab_size == cfg.vocab_size, "Tokenizer vocab size mismatch."
model = SupernovaModel(cfg)
if hasattr(model, "gradient_checkpointing_enable"):
try:
model.gradient_checkpointing_enable()
except Exception:
pass
model.to(device)
total_params = sum(p.numel() for p in model.parameters())
assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000"
if compile_model:
try:
model = torch.compile(model)
except Exception as e:
print("torch.compile not available/failed:", e)
if ddp:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False)
sources = load_sources_from_yaml(data_config_path)
ds = TokenChunkDataset(
tokenizer=tok,
sources=sources,
seq_len=seq_len,
eos_token_id=tok.eos_token_id
)
sampler = DistributedSampler(ds) if ddp else None
dl = DataLoader(
ds,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=2,
drop_last=True,
)
def param_groups(model):
decay, no_decay = [], []
for n, p in model.named_parameters():
if not p.requires_grad:
continue
if any(nd in n for nd in ["bias", "ln", "layernorm", "LayerNorm", "norm"]):
no_decay.append(p)
else:
decay.append(p)
return [
{"params": decay, "weight_decay": 0.1},
{"params": no_decay, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None
os.makedirs(out_dir, exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None
val_ds = None
val_dl = None
start_step = 0
best_val_loss = float("inf")
if resume_from and os.path.exists(resume_from):
ckpt = torch.load(resume_from, map_location=device)
model_state = ckpt["model_state_dict"]
target = model.module if ddp else model
target.load_state_dict(model_state)
optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {}))
scheduler_state = ckpt.get("scheduler_state_dict", None)
if scheduler_state:
scheduler.load_state_dict(scheduler_state)
if "scaler_state_dict" in ckpt and scaler is not None:
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_step = ckpt.get("step", 0)
best_val_loss = ckpt.get("best_val_loss", best_val_loss)
print(f"Resumed from {resume_from} at step {start_step}")
model.train()
step = start_step
micro = 0
running_loss = 0.0
t0 = time.time()
no_improve_steps = 0
early_stop_patience = 10_000
while step < max_steps:
if sampler is not None:
sampler.set_epoch(step)
for batch in dl:
x, y = batch
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
device_type = 'cuda' if device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
logits, loss = model(x, y)
loss = loss / grad_accum
scaler.scale(loss).backward()
micro += 1
running_loss += loss.item()
if micro % grad_accum == 0:
if clip_grad_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
grad_norm = None
if (step + 1) % 50 == 0 and (not ddp or local_rank == 0):
debug_gradients = step < 5
grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
scheduler.step()
if ema:
ema.update(model if not ddp else model.module)
step += 1
if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None:
avg_loss = running_loss * grad_accum / 50.0
running_loss = 0.0
elapsed = time.time() - t0
lr_now = scheduler.get_last_lr()[0]
print(f"step={step} loss={avg_loss:.6f} grad_norm={grad_norm:.3f} lr={lr_now:.6f} elapsed={elapsed:.1f}s")
if writer:
writer.add_scalar("train/loss", avg_loss, step)
writer.add_scalar("train/grad_norm", grad_norm, step)
writer.add_scalar("train/lr", lr_now, step)
t0 = time.time()
if validate_every and step % validate_every == 0:
if val_dl is None:
val_sources = []
for source in sources[:min(3, len(sources))]:
val_source = DataSource(
name=f"{source.name}_val",
hf_path="wikitext",
hf_name="wikitext-2-v1",
split="validation",
text_field="text",
weight=1,
streaming=False
)
val_sources.append(val_source)
val_ds = TokenChunkDataset(
tokenizer=tok,
sources=val_sources,
seq_len=seq_len,
eos_token_id=tok.eos_token_id
)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False)
model.eval()
if ema:
ema.store(model if not ddp else model.module)
ema.copy_to(model if not ddp else model.module)
val_losses = []
with torch.no_grad():
for i, (vx, vy) in enumerate(val_dl):
if i >= val_steps:
break
vx = vx.to(device)
vy = vy.to(device)
device_type = 'cuda' if device.type == 'cuda' else 'cpu'
with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
_, vloss = model(vx, vy)
val_losses.append(float(vloss.detach().cpu().item()))
mean_val = float(sum(val_losses) / max(1, len(val_losses)))
if writer and (not ddp or local_rank == 0):
writer.add_scalar("val/loss", mean_val, step)
print(f"[eval] step={step} val_loss={mean_val:.6f}")
if ema:
ema.restore(model if not ddp else model.module)
model.train()
if mean_val < best_val_loss:
best_val_loss = mean_val
no_improve_steps = 0
best_path_pt = os.path.join(out_dir, f"supernova_best_step{step}.pt")
model_state = model.module.state_dict() if ddp else model.state_dict()
ckpt = {
"model_state_dict": model_state,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": (scaler.state_dict() if scaler else None),
"step": step,
"best_val_loss": best_val_loss,
"config": cfg.__dict__,
}
if not ddp or local_rank == 0:
atomic_save(ckpt, best_path_pt)
print(f"Saved best checkpoint to {best_path_pt}")
# Save safetensors
if export_safetensors:
best_path_st = os.path.join(out_dir, f"supernova_best_step{step}.safetensors")
save_safetensors_checkpoint(