File size: 2,962 Bytes
6848cb6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | """Training utilities: optimizer setup, LR schedule, checkpointing."""
import json
import math
import os
from pathlib import Path
import torch
def cosine_with_warmup(step, warmup, total, max_lr, min_lr_ratio=0.1):
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    * While ``step < warmup`` the rate ramps linearly as
      ``max_lr * (step + 1) / warmup``, hitting ``max_lr`` at ``step == warmup - 1``.
    * After warmup it follows a half cosine from ``max_lr`` down to the floor
      ``min_lr_ratio * max_lr`` across ``total - warmup`` steps (the span is at
      least 1). Progress is clamped to 1.0, so past the end of the schedule the
      rate holds at the floor.
    """
    if step < warmup:
        return max_lr * (step + 1) / warmup

    floor_lr = min_lr_ratio * max_lr
    span = max(1, total - warmup)
    frac = min(1.0, (step - warmup) / span)
    cosine_term = 1 + math.cos(math.pi * frac)
    return floor_lr + 0.5 * (max_lr - floor_lr) * cosine_term
def make_optimizer(model, lr, weight_decay=0.1, betas=(0.9, 0.95), fused=True):
    """Build AdamW with weight decay applied only to matrix-shaped weights.

    Trainable parameters with ``dim() >= 2`` go into a group decayed by
    ``weight_decay``. The exception is any parameter whose name contains
    ``"tok_emb"``. Every other trainable parameter (1-D tensors such as biases
    and norm gains, plus ``tok_emb``) goes into a group with
    ``weight_decay=0.0``. Frozen parameters (``requires_grad=False``) are skipped.
    Per Loshchilov & Hutter; same convention as nanoGPT.

    Args:
        model: module whose ``named_parameters()`` are optimized.
        lr: learning rate for AdamW.
        weight_decay: decay used for the decayed group.
        betas: AdamW ``(beta1, beta2)`` coefficients.
        fused: request the fused AdamW kernel when CUDA is available. If the
            installed torch or the parameters do not support it, falls back to
            the default AdamW implementation.

    Returns:
        ``torch.optim.AdamW`` whose ``param_groups`` are ``[decay, no_decay]``.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # NOTE(review): only names containing "tok_emb" are treated as embeddings;
        # any other 2-D embedding table (e.g. positional) would still be decayed.
        if param.dim() >= 2 and "tok_emb" not in name:
            decay.append(param)
        else:
            no_decay.append(param)

    def _param_groups():
        # Build fresh dicts on every call: the optimizer constructor mutates the
        # group dicts it is given (e.g. setdefault of `fused`), so a failed fused
        # attempt must not leak its settings into the fallback optimizer.
        return [
            {"params": list(decay), "weight_decay": weight_decay},
            {"params": list(no_decay), "weight_decay": 0.0},
        ]

    if fused and torch.cuda.is_available():
        try:
            return torch.optim.AdamW(_param_groups(), lr=lr, betas=betas, fused=True)
        except (TypeError, RuntimeError):
            # TypeError: this torch build does not accept the `fused` argument.
            # RuntimeError: torch rejects fused=True when the params are not
            # floating-point tensors on a fused-supported device (e.g. the model
            # has not been moved to the GPU yet). Fall back to the default kernel.
            pass
    return torch.optim.AdamW(_param_groups(), lr=lr, betas=betas)
def save_checkpoint(path, model, optimizer, scheduler_state, step, extra=None):
    """Write a training checkpoint to ``path`` via a temp file + ``os.replace``.

    The saved dict has, in order, the keys:
      * ``model``: ``model.state_dict()``
      * ``optimizer``: ``optimizer.state_dict()``, or ``None`` when no optimizer is given
      * ``scheduler``: ``scheduler_state``, stored unchanged
      * ``step``: ``step``
      * ``config``: ``{field: value}`` for every dataclass field of ``model.cfg``
      * ``extra``: ``extra``, or ``{}`` when it is falsy

    Missing parent directories are created. The payload is first saved to a
    sibling file whose suffix has ``.tmp`` appended. That file is then moved
    onto ``path``, so ``path`` never holds a partially written checkpoint.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    model_state = model.state_dict()
    optim_state = None if optimizer is None else optimizer.state_dict()
    cfg = model.cfg
    cfg_snapshot = {field_name: getattr(cfg, field_name) for field_name in cfg.__dataclass_fields__}

    payload = dict(
        model=model_state,
        optimizer=optim_state,
        scheduler=scheduler_state,
        step=step,
        config=cfg_snapshot,
        extra=extra or {},
    )

    staging = target.with_suffix(target.suffix + ".tmp")
    torch.save(payload, staging)
    os.replace(staging, target)
def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
payload = torch.load(path, map_location=map_location, weights_only=False)
# Si el checkpoint tiene tie_embeddings=True, usar strict=False
# (lm_head comparte pesos con tok_emb y no se guarda por separado)
strict = not payload.get("tie_embeddings", False)
missing, unexpected = model.load_state_dict(payload["model"], strict=strict)
if missing:
print(f"[load_checkpoint] missing keys (expected with tie_embeddings): {missing[:3]}")
if optimizer is not None and payload.get("optimizer"):
optimizer.load_state_dict(payload["optimizer"])
return payload.get("step", 0), payload.get("extra", {})
def count_tokens(loader_output_iter, n_steps, block_size, batch_size):
    """Estimate tokens consumed as ``n_steps * block_size * batch_size``.

    This is approximate: every step is assumed to consume a full batch of
    ``batch_size`` sequences of ``block_size`` tokens. ``loader_output_iter`` is
    accepted for call-site compatibility but is never read.
    """
    tokens = (n_steps * block_size) * batch_size
    return tokens
def log_jsonl(path, record):
    """Append ``record`` to ``path`` as a single JSON line.

    The file is opened in append mode with UTF-8 encoding, so it is created
    if absent. ``json.dumps`` runs with ``ensure_ascii=False``, so non-ASCII
    characters are written as-is rather than as ``\\u`` escapes.
    """
    with open(path, mode="a", encoding="utf-8") as sink:
        serialized = json.dumps(record, ensure_ascii=False)
        sink.write(f"{serialized}\n")
|