"""
Refactored training script for SupernovaModel.

Features:
- AMP mixed-precision training
- Resume from checkpoint (optimizer and scheduler state are saved and restored)
- TensorBoard logging
- Optional validation loop when --val-data-config is provided
- DataLoader pin_memory and non_blocking host-to-device transfers
- Checkpoints store model/optimizer/scheduler state, config, and step
- CLI flags for common hyperparameters

Usage:
    python -m supernova.train_refactor --config path/to/config.json --data-config path/to/data.yaml
"""

import argparse
import math
import os
import time
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import get_cosine_schedule_with_warmup

from .config import ModelConfig
from .model import SupernovaModel
from .tokenizer import load_gpt2_tokenizer
from .data import load_sources_from_yaml, TokenChunkDataset


def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm over all parameter gradients."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.float().norm(2).item()
            total += param_norm * param_norm
    return math.sqrt(total)


class Trainer:
    def __init__(
        self,
        cfg: ModelConfig,
        tok,
        train_sources,
        device: torch.device,
        seq_len: int = 1024,
        batch_size: int = 16,
        grad_accum: int = 8,
        lr: float = 3e-4,
        warmup_steps: int = 2000,
        max_steps: int = 100_000,
        out_dir: str = "checkpoints",
        weight_decay: float = 0.1,
        betas: tuple = (0.9, 0.95),
        num_workers: int = 4,
        pin_memory: bool = True,
        seed: int = 42,
        validate_every: Optional[int] = None,
        val_sources: Optional[list] = None,
        clip_grad_norm: Optional[float] = None,
    ):
        torch.manual_seed(seed)
        self.device = device
        self.cfg = cfg
        self.tok = tok
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.grad_accum = grad_accum
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.out_dir = out_dir
        self.weight_decay = weight_decay
        self.betas = betas
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.validate_every = validate_every
        self.val_sources = val_sources
        self.clip_grad_norm = clip_grad_norm

        os.makedirs(out_dir, exist_ok=True)

        self.model = SupernovaModel(cfg).to(device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
        )
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
        )

        self.train_ds = TokenChunkDataset(tok, train_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
        self.train_dl = DataLoader(
            self.train_ds,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

        if val_sources is not None:
            self.val_ds = TokenChunkDataset(tok, val_sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
            self.val_dl = DataLoader(
                self.val_ds,
                batch_size=batch_size,
                shuffle=False,
                num_workers=max(0, num_workers // 2),
                pin_memory=pin_memory,
            )
        else:
            self.val_dl = None

        # GradScaler is only needed for CUDA mixed-precision training.
        self.scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None

        self.writer = SummaryWriter(log_dir=os.path.join(out_dir, "logs"))

        self.step = 0
        self.micro = 0
        self.running_loss = 0.0

        torch.backends.cudnn.benchmark = True

    def save_ckpt(self, path: str):
        payload = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.cfg.__dict__,
            "step": self.step,
        }
        torch.save(payload, path)

    def load_ckpt(self, path: str):
        ckpt = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ckpt["model_state_dict"])
        if "optimizer_state_dict" in ckpt:
            self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        self.step = ckpt.get("step", 0)
        print(f"Resumed from {path}, step={self.step}")

    @torch.no_grad()
    def validate(self):
        if self.val_dl is None:
            return None
        self.model.eval()
        tot = 0.0
        count = 0
        for batch in self.val_dl:
            x, y = batch
            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=(self.scaler is not None)):
                _, loss = self.model(x, y)
            tot += float(loss.detach().item())
            count += 1
        self.model.train()
        return tot / max(1, count)

    def train_loop(self, save_every: int = 10000, log_every: int = 50):
        t0 = time.time()
        while True:  # iterate over the dataset indefinitely; max_steps ends training
            for batch in self.train_dl:
                x, y = batch
                x = x.to(self.device, non_blocking=True)
                y = y.to(self.device, non_blocking=True)

                if self.scaler is not None:
                    with torch.cuda.amp.autocast():
                        _, loss = self.model(x, y)
                else:
                    _, loss = self.model(x, y)

                # Scale the loss so gradients accumulate to the mean over micro-batches.
                loss = loss / self.grad_accum

                if self.scaler is not None:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()

                self.micro += 1
                self.running_loss += float(loss.detach().item())

                if self.micro % self.grad_accum == 0:
                    if self.scaler is not None:
                        # Unscale so clipping and the logged norm see true gradients.
                        self.scaler.unscale_(self.optimizer)
                    if self.clip_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm)

                    grad_norm = None
                    if (self.step + 1) % log_every == 0:
                        # Gradients are cleared right after the optimizer step,
                        # so measure their norm before stepping.
                        grad_norm = compute_grad_norm(self.model)

                    if self.scaler is not None:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.optimizer.zero_grad(set_to_none=True)
                    self.scheduler.step()

                    self.step += 1

                    if self.step % log_every == 0:
                        # running_loss already holds per-step means (each micro-batch
                        # loss was divided by grad_accum), so average over log_every steps.
                        avg_loss = self.running_loss / log_every
                        elapsed = time.time() - t0
                        lr_now = self.scheduler.get_last_lr()[0]
                        tokens_per_sec = (
                            self.batch_size * self.seq_len * self.grad_accum * log_every
                        ) / max(1e-9, elapsed)

                        print(
                            f"step={self.step} loss={avg_loss:.4f} grad_norm={grad_norm:.2f} "
                            f"lr={lr_now:.6f} elapsed={elapsed:.1f}s tokens/s={tokens_per_sec:.1f}"
                        )

                        self.writer.add_scalar("train/loss", avg_loss, self.step)
                        self.writer.add_scalar("train/grad_norm", grad_norm, self.step)
                        self.writer.add_scalar("train/lr", lr_now, self.step)
                        self.writer.add_scalar("train/tokens_per_sec", tokens_per_sec, self.step)

                        self.running_loss = 0.0
                        t0 = time.time()

                    if save_every and self.step % save_every == 0:
                        ckpt_path = os.path.join(self.out_dir, f"supernova_step{self.step}.pt")
                        self.save_ckpt(ckpt_path)
                        print(f"Saved checkpoint {ckpt_path}")

                    if self.validate_every and self.step % self.validate_every == 0:
                        val_loss = self.validate()
                        if val_loss is not None:
                            print(f"Validation loss at step {self.step}: {val_loss:.4f}")
                            self.writer.add_scalar("val/loss", val_loss, self.step)

                    if self.step >= self.max_steps:
                        print("Reached max_steps; finishing training")
                        final_ckpt = os.path.join(self.out_dir, "supernova_final.pt")
                        self.save_ckpt(final_ckpt)
                        return


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True)
    ap.add_argument("--data-config", required=True)
    ap.add_argument("--val-data-config", default=None)
    ap.add_argument("--seq-len", type=int, default=1024)
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--grad-accum", type=int, default=8)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--warmup-steps", type=int, default=2000)
    ap.add_argument("--max-steps", type=int, default=100000)
    ap.add_argument("--save-every", type=int, default=10000)
    ap.add_argument("--out-dir", type=str, default="checkpoints")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--weight-decay", type=float, default=0.1)
    ap.add_argument("--betas", type=float, nargs=2, default=(0.9, 0.95))
    ap.add_argument("--num-workers", type=int, default=4)
    ap.add_argument("--resume", type=str, default=None, help="path to checkpoint to resume from")
    ap.add_argument("--validate-every", type=int, default=None)
    ap.add_argument("--clip-grad-norm", type=float, default=None)
    return ap.parse_args()


def main():
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg = ModelConfig.from_json_file(args.config)
    cfg.assert_exact_params(expected=25_000_000)

    tok = load_gpt2_tokenizer()
    assert tok.vocab_size == cfg.vocab_size, (
        f"Tokenizer vocab size ({tok.vocab_size}) != config ({cfg.vocab_size})"
    )

    train_sources = load_sources_from_yaml(args.data_config)
    val_sources = load_sources_from_yaml(args.val_data_config) if args.val_data_config else None

    trainer = Trainer(
        cfg=cfg,
        tok=tok,
        train_sources=train_sources,
        device=device,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        out_dir=args.out_dir,
        weight_decay=args.weight_decay,
        betas=tuple(args.betas),
        num_workers=args.num_workers,
        seed=args.seed,
        validate_every=args.validate_every,
        val_sources=val_sources,
        clip_grad_norm=args.clip_grad_norm,
    )

    if args.resume:
        trainer.load_ckpt(args.resume)

    trainer.train_loop(save_every=args.save_every)


if __name__ == "__main__":
    main()
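

# The checkpoint payload written by Trainer.save_ckpt can also be consumed
# outside this script, e.g. for evaluation. A minimal sketch, assuming
# ModelConfig accepts the saved "config" dict as keyword arguments (that
# constructor is not shown in this file):
#
#     ckpt = torch.load("checkpoints/supernova_final.pt", map_location="cpu")
#     cfg = ModelConfig(**ckpt["config"])
#     model = SupernovaModel(cfg)
#     model.load_state_dict(ckpt["model_state_dict"])
#     model.eval()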