"""
Refactored training script for SupernovaModel
- AMP mixed precision training
- Resume from checkpoint (saves optimizer + scheduler state)
- TensorBoard logging
- Optional validation loop if --val-data-config provided
- DataLoader pin_memory and non_blocking transfers
- Save optimizer/scheduler/model/config/step
- CLI flags for common hyperparams
Usage:
python -m supernova.train_refactor --config path/to/config.json --data-config path/to/data.yaml
"""
import argparse
import math
import os
import time
from typing import Optional
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import get_cosine_schedule_with_warmup
from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer
from .data import load_sources_from_yaml, TokenChunkDataset
def compute_grad_norm(model: nn.Module) -> float:
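    """Return the global L2 norm over all parameter gradients; parameters without a grad are skipped."""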
total = 0.0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.float().norm(2).item()
total += param_norm * param_norm
return math.sqrt(total)
class Trainer:
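    """Single-device trainer: AMP mixed precision, gradient accumulation, cosine LR schedule, TensorBoard logging."""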
def __init__(
self,
cfg: ModelConfig,
tok,
train_sources,
device: torch.device,
seq_len: int = 1024,
batch_size: int = 16,
grad_accum: int = 8,
lr: float = 3e-4,
warmup_steps: int = 2000,
max_steps: int = 100_000,
out_dir: str = "checkpoints",
weight_decay: float = 0.1,
betas: tuple = (0.9, 0.95),
num_workers: int = 4,
pin_memory: bool = True,
seed: int = 42,
validate_every: Optional[int] = None,
val_sources: Optional[list] = None,
clip_grad_norm: Optional[float] = None,
):
torch.manual_seed(seed)
self.device = device
self.cfg = cfg
self.tok = tok
self.seq_len = seq_len
self.batch_size = batch_size
self.grad_accum = grad_accum
self.lr = lr
self.warmup_steps = warmup_steps
self.max_steps = max_steps
self.out_dir = out_dir
self.weight_decay = weight_decay
self.betas = betas
self.num_workers = num_workers
self.pin_memory = pin_memory
self.validate_every = validate_every
self.val_sources = val_sources
self.clip_grad_norm = clip_grad_norm
os.makedirs(out_dir, exist_ok=True)
self.model = SupernovaModel(cfg).to(device)
# optimizer + scheduler
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
)
self.scheduler = get_cosine_schedule_with_warmup(
self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
)
self.train_ds = TokenChunkDataset(tok, train_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
self.train_dl = DataLoader(
self.train_ds,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
drop_last=True,
)
if val_sources is not None:
self.val_ds = TokenChunkDataset(tok, val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
            self.val_dl = DataLoader(
                self.val_ds,
                batch_size=batch_size,
                shuffle=False,
                num_workers=max(0, num_workers // 2),
                pin_memory=pin_memory,
            )
else:
self.val_dl = None
# AMP scaler
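        # GradScaler handles fp16 loss scaling under autocast; without CUDA we train in full precision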
self.scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None
# logging
self.writer = SummaryWriter(log_dir=os.path.join(out_dir, "logs"))
# training state
self.step = 0
self.micro = 0
self.running_loss = 0.0
# perf
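        # allow cuDNN to benchmark and cache the fastest kernels for the fixed input shape (no effect on CPU)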
torch.backends.cudnn.benchmark = True
def save_ckpt(self, path: str):
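        """Write model, optimizer, scheduler, config dict, and current step to a single .pt payload."""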
payload = {
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"config": self.cfg.__dict__,
"step": self.step,
}
torch.save(payload, path)
def load_ckpt(self, path: str):
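        """Restore state written by save_ckpt; optimizer/scheduler entries are loaded only if present."""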
ckpt = torch.load(path, map_location=self.device)
self.model.load_state_dict(ckpt["model_state_dict"])
if "optimizer_state_dict" in ckpt:
self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
if "scheduler_state_dict" in ckpt:
self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
self.step = ckpt.get("step", 0)
print(f"Resumed from {path}, step={self.step}")
@torch.no_grad()
def validate(self):
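        """Average the loss over the validation loader; returns None when no validation data is configured."""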
if self.val_dl is None:
return None
self.model.eval()
tot = 0.0
count = 0
for batch in self.val_dl:
x, y = batch
x = x.to(self.device, non_blocking=True)
y = y.to(self.device, non_blocking=True)
with torch.cuda.amp.autocast(enabled=(self.scaler is not None)):
_, loss = self.model(x, y)
tot += float(loss.detach().item())
count += 1
self.model.train()
return tot / max(1, count)
def train_loop(self, save_every: int = 10000, log_every: int = 50):
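        """Train until max_steps, logging every log_every optimizer steps and checkpointing every save_every steps."""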
t0 = time.time()
        while True:  # iterate over the dataset repeatedly; training ends when max_steps is reached
for batch in self.train_dl:
x, y = batch
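                # with pin_memory=True these non-blocking copies can overlap host-to-device transfer with compute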
x = x.to(self.device, non_blocking=True)
y = y.to(self.device, non_blocking=True)
# forward (AMP-capable)
if self.scaler is not None:
with torch.cuda.amp.autocast():
_, loss = self.model(x, y)
else:
_, loss = self.model(x, y)
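                # scale the loss so gradients accumulated over grad_accum micro-batches average to one full batch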
loss = loss / self.grad_accum
if self.scaler is not None:
self.scaler.scale(loss).backward()
else:
loss.backward()
self.micro += 1
self.running_loss += float(loss.detach().item())
                if self.micro % self.grad_accum == 0:
                    # unscale first so clipping and the logged norm see fp32 gradients
                    if self.scaler is not None:
                        self.scaler.unscale_(self.optimizer)
                    # measure the gradient norm before zero_grad() clears it below
                    grad_norm = compute_grad_norm(self.model)
                    # optional clipping
                    if self.clip_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)
if self.scaler is not None:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
self.scheduler.step()
self.step += 1
if self.step % log_every == 0:
                        # running_loss accumulates loss/grad_accum per micro-batch, so dividing by
                        # log_every optimizer steps gives the mean micro-batch loss for the window
                        avg_loss = self.running_loss / log_every
elapsed = time.time() - t0
lr_now = self.scheduler.get_last_lr()[0]
                        tokens_per_sec = (self.batch_size * self.seq_len * self.grad_accum * log_every) / max(1e-9, elapsed)
print(f"step={self.step} loss={avg_loss:.4f} grad_norm={grad_norm:.2f} lr={lr_now:.6f} elapsed={elapsed:.1f}s tokens/s={tokens_per_sec:.1f}")
# tensorboard
self.writer.add_scalar("train/loss", avg_loss, self.step)
self.writer.add_scalar("train/grad_norm", grad_norm, self.step)
self.writer.add_scalar("train/lr", lr_now, self.step)
self.writer.add_scalar("train/tokens_per_sec", tokens_per_sec, self.step)
self.running_loss = 0.0
t0 = time.time()
if save_every and self.step % save_every == 0:
ckpt_path = os.path.join(self.out_dir, f"supernova_step{self.step}.pt")
self.save_ckpt(ckpt_path)
print(f"Saved checkpoint {ckpt_path}")
if self.validate_every and self.step % self.validate_every == 0:
val_loss = self.validate()
if val_loss is not None:
print(f"Validation loss at step {self.step}: {val_loss:.4f}")
self.writer.add_scalar("val/loss", val_loss, self.step)
if self.step >= self.max_steps:
print("Reached max_steps; finishing training")
final_ckpt = os.path.join(self.out_dir, "supernova_final.pt")
self.save_ckpt(final_ckpt)
return
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--config", required=True)
ap.add_argument("--data-config", required=True)
ap.add_argument("--val-data-config", default=None)
ap.add_argument("--seq-len", type=int, default=1024)
ap.add_argument("--batch-size", type=int, default=16)
ap.add_argument("--grad-accum", type=int, default=8)
ap.add_argument("--lr", type=float, default=3e-4)
ap.add_argument("--warmup-steps", type=int, default=2000)
ap.add_argument("--max-steps", type=int, default=100000)
ap.add_argument("--save-every", type=int, default=10000)
ap.add_argument("--out-dir", type=str, default="checkpoints")
ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--weight-decay", type=float, default=0.1)
ap.add_argument("--betas", type=float, nargs=2, default=(0.9, 0.95))
ap.add_argument("--num-workers", type=int, default=4)
ap.add_argument("--resume", type=str, default=None, help="path to checkpoint to resume from")
ap.add_argument("--validate-every", type=int, default=None)
ap.add_argument("--clip-grad-norm", type=float, default=None)
return ap.parse_args()
def main():
args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg = ModelConfig.from_json_file(args.config)
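    # sanity-check that the configured architecture matches the advertised 25M-parameter budget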
cfg.assert_exact_params(expected=25_000_000)
tok = load_gpt2_tokenizer()
assert tok.vocab_size == cfg.vocab_size, (
f"Tokenizer vocab size ({tok.vocab_size}) != config ({cfg.vocab_size})"
)
train_sources = load_sources_from_yaml(args.data_config)
val_sources = load_sources_from_yaml(args.val_data_config) if args.val_data_config else None
trainer = Trainer(
cfg=cfg,
tok=tok,
train_sources=train_sources,
device=device,
seq_len=args.seq_len,
batch_size=args.batch_size,
grad_accum=args.grad_accum,
lr=args.lr,
warmup_steps=args.warmup_steps,
max_steps=args.max_steps,
out_dir=args.out_dir,
weight_decay=args.weight_decay,
betas=tuple(args.betas),
num_workers=args.num_workers,
seed=args.seed,
validate_every=args.validate_every,
val_sources=val_sources,
clip_grad_norm=args.clip_grad_norm,
)
if args.resume:
trainer.load_ckpt(args.resume)
trainer.train_loop(save_every=args.save_every)
if __name__ == "__main__":
main()