|
|
import argparse |
|
|
import json |
|
|
import math |
|
|
import os |
|
|
import time |
|
|
from typing import Optional, Dict, Any |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, DistributedSampler |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from transformers import get_cosine_schedule_with_warmup |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
from .config import ModelConfig |
|
|
from .model import SupernovaModel |
|
|
from .tokenizer import load_gpt2_tokenizer |
|
|
from .data import load_sources_from_yaml, TokenChunkDataset, DataSource |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_grad_norm(model: nn.Module, debug: bool = False) -> float:
    """Return the global L2 norm over all parameter gradients of *model*.

    Mirrors the norm computed by ``torch.nn.utils.clip_grad_norm_``: each
    parameter's gradient L2 norm is squared, summed, and square-rooted.
    Parameters with no gradient contribute nothing.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        debug: When True, print a per-parameter breakdown and a summary line
            (useful for spotting frozen or disconnected parameters).

    Returns:
        Total gradient L2 norm as a Python float (0.0 when no grads exist).
    """
    total = 0.0
    grad_count = 0
    param_count = 0

    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # detach() instead of the deprecated .data attribute; float()
            # avoids overflow when squaring fp16/bf16 gradient norms.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:
                print(f" {name}: grad_norm={param_norm:.6f}")
        elif debug:
            print(f" {name}: NO GRAD")

    if debug:
        print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={math.sqrt(total):.6f}")

    return math.sqrt(total)
|
|
|
|
|
def atomic_save(obj: Dict[str, Any], path: str):
    """Atomically write *obj* to *path* using ``torch.save``.

    Serializes to a temporary sibling file first, then renames it over the
    target. ``os.replace`` is an atomic rename on both POSIX and Windows, so
    a crash mid-write never leaves a truncated checkpoint at *path*.

    Args:
        obj: Picklable object (e.g. a checkpoint dict) to serialize.
        path: Destination file path; ``path + ".tmp"`` is used as scratch.
    """
    tmp = path + ".tmp"
    try:
        torch.save(obj, tmp)
    except BaseException:
        # Don't leak a partial .tmp file when serialization fails.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    os.replace(tmp, path)
|
|
|
|
|
def save_safetensors_checkpoint(model_state_dict: Dict[str, torch.Tensor], path: str):
    """Save model weights in safetensors format (best-effort).

    Writes to a temporary file and atomically renames it into place. Any
    failure is reported as a warning rather than raised, so a broken
    safetensors export can never interrupt a training run.
    """
    tmp_path = path + ".tmp"
    try:
        save_file(model_state_dict, tmp_path)
        os.replace(tmp_path, path)
        print(f"✓ Saved safetensors to {path}")
    except Exception as exc:
        print(f"Warning: Failed to save safetensors: {exc}")
|
|
|
|
|
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Maintains a ``shadow`` dict mapping parameter name -> EMA tensor.
    Typical evaluation flow: ``store()`` the live weights, ``copy_to()``
    the EMA weights, run eval, then ``restore()`` the live weights.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        # One shadow tensor per trainable parameter, keyed by name.
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model: nn.Module):
        """Blend the model's current trainable parameters into the shadow."""
        blend = 1.0 - self.decay
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name].mul_(self.decay).add_(param.data, alpha=blend)

    def store(self, model: nn.Module):
        """Snapshot the live parameters so ``restore`` can bring them back."""
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()

    def copy_to(self, model: nn.Module):
        """Overwrite the model's trainable parameters with the EMA values."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name])

    def restore(self, model: nn.Module):
        """Put back the parameters saved by ``store`` and drop the backup."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        del self.backup
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train( |
|
|
config_path: str, |
|
|
data_config_path: str, |
|
|
seq_len: int = 1024, |
|
|
batch_size: int = 16, |
|
|
grad_accum: int = 8, |
|
|
lr: float = 3e-4, |
|
|
warmup_steps: int = 2000, |
|
|
max_steps: int = 100_000, |
|
|
save_every: int = 10_000, |
|
|
out_dir: str = "checkpoints", |
|
|
seed: int = 42, |
|
|
validate_every: int = 1000, |
|
|
val_steps: int = 100, |
|
|
clip_grad_norm: Optional[float] = 1.0, |
|
|
use_ema: bool = True, |
|
|
ema_decay: float = 0.9999, |
|
|
resume_from: Optional[str] = None, |
|
|
use_tensorboard: bool = True, |
|
|
ddp: bool = False, |
|
|
local_rank: int = 0, |
|
|
num_workers: int = 4, |
|
|
pin_memory: bool = True, |
|
|
compile_model: bool = False, |
|
|
export_safetensors: bool = True, |
|
|
): |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
torch.cuda.manual_seed_all(seed) |
|
|
import random |
|
|
random.seed(seed) |
|
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
|
|
|
if ddp: |
|
|
torch.distributed.init_process_group(backend="nccl") |
|
|
device = torch.device(f"cuda:{local_rank}") |
|
|
torch.cuda.set_device(device) |
|
|
else: |
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
cfg = ModelConfig.from_json_file(config_path) |
|
|
cfg.assert_exact_params(expected=25_000_000) |
|
|
tok = load_gpt2_tokenizer() |
|
|
assert tok.vocab_size == cfg.vocab_size, "Tokenizer vocab size mismatch." |
|
|
|
|
|
model = SupernovaModel(cfg) |
|
|
if hasattr(model, "gradient_checkpointing_enable"): |
|
|
try: |
|
|
model.gradient_checkpointing_enable() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
model.to(device) |
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
assert total_params == 25_000_000, f"Model has {total_params} params, expected 25,000,000" |
|
|
|
|
|
if compile_model: |
|
|
try: |
|
|
model = torch.compile(model) |
|
|
except Exception as e: |
|
|
print("torch.compile not available/failed:", e) |
|
|
|
|
|
if ddp: |
|
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=False) |
|
|
|
|
|
sources = load_sources_from_yaml(data_config_path) |
|
|
ds = TokenChunkDataset( |
|
|
tokenizer=tok, |
|
|
sources=sources, |
|
|
seq_len=seq_len, |
|
|
eos_token_id=tok.eos_token_id |
|
|
) |
|
|
sampler = DistributedSampler(ds) if ddp else None |
|
|
|
|
|
dl = DataLoader( |
|
|
ds, |
|
|
batch_size=batch_size, |
|
|
sampler=sampler, |
|
|
num_workers=num_workers, |
|
|
pin_memory=pin_memory, |
|
|
prefetch_factor=2, |
|
|
drop_last=True, |
|
|
) |
|
|
|
|
|
def param_groups(model): |
|
|
decay, no_decay = [], [] |
|
|
for n, p in model.named_parameters(): |
|
|
if not p.requires_grad: |
|
|
continue |
|
|
if any(nd in n for nd in ["bias", "ln", "layernorm", "LayerNorm", "norm"]): |
|
|
no_decay.append(p) |
|
|
else: |
|
|
decay.append(p) |
|
|
return [ |
|
|
{"params": decay, "weight_decay": 0.1}, |
|
|
{"params": no_decay, "weight_decay": 0.0}, |
|
|
] |
|
|
|
|
|
optimizer = torch.optim.AdamW(param_groups(model), lr=lr, betas=(0.9, 0.95), eps=1e-8) |
|
|
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps) |
|
|
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda")) |
|
|
|
|
|
ema = EMA(model if not ddp else model.module, decay=ema_decay) if use_ema else None |
|
|
|
|
|
os.makedirs(out_dir, exist_ok=True) |
|
|
writer = SummaryWriter(log_dir=os.path.join(out_dir, "runs")) if use_tensorboard and (not ddp or local_rank == 0) else None |
|
|
|
|
|
val_ds = None |
|
|
val_dl = None |
|
|
|
|
|
start_step = 0 |
|
|
best_val_loss = float("inf") |
|
|
if resume_from and os.path.exists(resume_from): |
|
|
ckpt = torch.load(resume_from, map_location=device) |
|
|
model_state = ckpt["model_state_dict"] |
|
|
target = model.module if ddp else model |
|
|
target.load_state_dict(model_state) |
|
|
optimizer.load_state_dict(ckpt.get("optimizer_state_dict", {})) |
|
|
scheduler_state = ckpt.get("scheduler_state_dict", None) |
|
|
if scheduler_state: |
|
|
scheduler.load_state_dict(scheduler_state) |
|
|
if "scaler_state_dict" in ckpt and scaler is not None: |
|
|
scaler.load_state_dict(ckpt["scaler_state_dict"]) |
|
|
start_step = ckpt.get("step", 0) |
|
|
best_val_loss = ckpt.get("best_val_loss", best_val_loss) |
|
|
print(f"Resumed from {resume_from} at step {start_step}") |
|
|
|
|
|
model.train() |
|
|
step = start_step |
|
|
micro = 0 |
|
|
running_loss = 0.0 |
|
|
t0 = time.time() |
|
|
no_improve_steps = 0 |
|
|
early_stop_patience = 10_000 |
|
|
|
|
|
while step < max_steps: |
|
|
if sampler is not None: |
|
|
sampler.set_epoch(step) |
|
|
|
|
|
for batch in dl: |
|
|
x, y = batch |
|
|
x = x.to(device, non_blocking=True) |
|
|
y = y.to(device, non_blocking=True) |
|
|
|
|
|
device_type = 'cuda' if device.type == 'cuda' else 'cpu' |
|
|
with torch.amp.autocast(device_type, enabled=(device.type == "cuda")): |
|
|
logits, loss = model(x, y) |
|
|
loss = loss / grad_accum |
|
|
|
|
|
scaler.scale(loss).backward() |
|
|
micro += 1 |
|
|
running_loss += loss.item() |
|
|
|
|
|
if micro % grad_accum == 0: |
|
|
if clip_grad_norm is not None: |
|
|
scaler.unscale_(optimizer) |
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) |
|
|
|
|
|
grad_norm = None |
|
|
if (step + 1) % 50 == 0 and (not ddp or local_rank == 0): |
|
|
debug_gradients = step < 5 |
|
|
grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients) |
|
|
|
|
|
scaler.step(optimizer) |
|
|
scaler.update() |
|
|
optimizer.zero_grad(set_to_none=True) |
|
|
scheduler.step() |
|
|
|
|
|
if ema: |
|
|
ema.update(model if not ddp else model.module) |
|
|
step += 1 |
|
|
|
|
|
if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None: |
|
|
avg_loss = running_loss * grad_accum / 50.0 |
|
|
running_loss = 0.0 |
|
|
elapsed = time.time() - t0 |
|
|
lr_now = scheduler.get_last_lr()[0] |
|
|
print(f"step={step} loss={avg_loss:.6f} grad_norm={grad_norm:.3f} lr={lr_now:.6f} elapsed={elapsed:.1f}s") |
|
|
if writer: |
|
|
writer.add_scalar("train/loss", avg_loss, step) |
|
|
writer.add_scalar("train/grad_norm", grad_norm, step) |
|
|
writer.add_scalar("train/lr", lr_now, step) |
|
|
t0 = time.time() |
|
|
|
|
|
if validate_every and step % validate_every == 0: |
|
|
if val_dl is None: |
|
|
val_sources = [] |
|
|
for source in sources[:min(3, len(sources))]: |
|
|
val_source = DataSource( |
|
|
name=f"{source.name}_val", |
|
|
hf_path="wikitext", |
|
|
hf_name="wikitext-2-v1", |
|
|
split="validation", |
|
|
text_field="text", |
|
|
weight=1, |
|
|
streaming=False |
|
|
) |
|
|
val_sources.append(val_source) |
|
|
val_ds = TokenChunkDataset( |
|
|
tokenizer=tok, |
|
|
sources=val_sources, |
|
|
seq_len=seq_len, |
|
|
eos_token_id=tok.eos_token_id |
|
|
) |
|
|
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=False) |
|
|
|
|
|
model.eval() |
|
|
if ema: |
|
|
ema.store(model if not ddp else model.module) |
|
|
ema.copy_to(model if not ddp else model.module) |
|
|
|
|
|
val_losses = [] |
|
|
with torch.no_grad(): |
|
|
for i, (vx, vy) in enumerate(val_dl): |
|
|
if i >= val_steps: |
|
|
break |
|
|
vx = vx.to(device) |
|
|
vy = vy.to(device) |
|
|
device_type = 'cuda' if device.type == 'cuda' else 'cpu' |
|
|
with torch.amp.autocast(device_type, enabled=(device.type == "cuda")): |
|
|
_, vloss = model(vx, vy) |
|
|
val_losses.append(float(vloss.detach().cpu().item())) |
|
|
mean_val = float(sum(val_losses) / max(1, len(val_losses))) |
|
|
if writer and (not ddp or local_rank == 0): |
|
|
writer.add_scalar("val/loss", mean_val, step) |
|
|
print(f"[eval] step={step} val_loss={mean_val:.6f}") |
|
|
|
|
|
if ema: |
|
|
ema.restore(model if not ddp else model.module) |
|
|
model.train() |
|
|
|
|
|
if mean_val < best_val_loss: |
|
|
best_val_loss = mean_val |
|
|
no_improve_steps = 0 |
|
|
best_path_pt = os.path.join(out_dir, f"supernova_best_step{step}.pt") |
|
|
model_state = model.module.state_dict() if ddp else model.state_dict() |
|
|
ckpt = { |
|
|
"model_state_dict": model_state, |
|
|
"optimizer_state_dict": optimizer.state_dict(), |
|
|
"scheduler_state_dict": scheduler.state_dict(), |
|
|
"scaler_state_dict": (scaler.state_dict() if scaler else None), |
|
|
"step": step, |
|
|
"best_val_loss": best_val_loss, |
|
|
"config": cfg.__dict__, |
|
|
} |
|
|
if not ddp or local_rank == 0: |
|
|
atomic_save(ckpt, best_path_pt) |
|
|
print(f"Saved best checkpoint to {best_path_pt}") |
|
|
|
|
|
|
|
|
if export_safetensors: |
|
|
best_path_st = os.path.join(out_dir, f"supernova_best_step{step}.safetensors") |
|
|
save_safetensors_checkpoint( |
|
|
|