# Coda/src/train.py
# Prajanya Gupta
# initial deploy
# 6b7b403
"""Train GPT on MIDI token chunks: checkpoints, CSV log, val tracking."""
from __future__ import annotations
import argparse
import csv
import importlib
import math
import sys
import time
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
# Directory containing this script, and its parent directory (used below as
# the default location for results/).
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT = _SCRIPT_DIR.parent
# Put this script's directory first on sys.path so the local `dataset` and
# `model` modules (presumably sibling files) resolve when the file is run
# directly as a script.
if str(_SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(_SCRIPT_DIR))
# These imports must follow the sys.path tweak above, hence the E402 waivers.
from dataset import build_dataloaders  # noqa: E402
from model import GPT, default_gpt_config  # noqa: E402
def _lr_lambda_factory(warmup_steps: int, total_steps: int):
"""Warmup then cosine: LR multiplier 1.0 → 0.1 over non-warmup steps."""
def lr_lambda(current_step: int) -> float:
if current_step < warmup_steps:
return float(current_step + 1) / float(max(1, warmup_steps))
if total_steps <= warmup_steps:
return 1.0
t = (current_step - warmup_steps) / float(total_steps - warmup_steps)
t = min(1.0, max(0.0, t))
min_f = 0.1
return min_f + (1.0 - min_f) * 0.5 * (1.0 + math.cos(math.pi * t))
return lr_lambda
@torch.no_grad()
def evaluate(
    model: GPT, val_loader: DataLoader, device: torch.device
) -> float:
    """Return the token-weighted mean cross-entropy (nats) over ``val_loader``.

    Each batch's mean loss is weighted by its number of target elements, so
    a short final batch does not skew the average. Gradients are disabled
    via ``torch.no_grad``. The model's previous train/eval mode is restored
    afterwards, even if evaluation raises.

    Args:
        model: Module mapping input ids ``x`` to logits whose last dimension
            indexes the vocabulary (assumed ``(batch, seq, vocab)``; TODO
            confirm against ``GPT.forward``).
        val_loader: Iterable of ``(x, y)`` tensor batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Mean loss per target element; ``0.0`` if the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    total = 0.0
    n_tokens = 0
    try:
        for x, y in val_loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
            )
            # cross_entropy averages over the batch; undo that so batches of
            # different sizes contribute proportionally.
            total += loss.item() * y.numel()
            n_tokens += y.numel()
    finally:
        # Restore the caller's mode instead of unconditionally calling
        # train(), and do it even if a batch fails mid-evaluation.
        model.train(was_training)
    return total / max(1, n_tokens)
def save_checkpoint(
    path: Path,
    model: GPT,
    optimizer: AdamW,
    scheduler: LambdaLR,
    global_step: int,
    epoch: int,
    config_dict: Dict[str, Any],
) -> None:
    """Write a resumable training snapshot to ``path`` with ``torch.save``.

    Missing parent directories are created first. The saved dict holds the
    model, optimizer and scheduler state dicts, the step and epoch counters,
    and the run configuration.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    snapshot: Dict[str, Any] = {}
    snapshot["model"] = model.state_dict()
    snapshot["optimizer"] = optimizer.state_dict()
    snapshot["scheduler"] = scheduler.state_dict()
    snapshot["global_step"] = global_step
    snapshot["epoch"] = epoch
    snapshot["config"] = config_dict
    torch.save(snapshot, path)
def save_best(
    path: Path,
    model: GPT,
    val_loss: float,
    global_step: int,
    config_dict: Dict[str, Any],
) -> None:
    """Persist the current best model weights (no optimizer state) to ``path``.

    This file is overwritten every time validation improves. The payload is
    therefore written to a sibling ``<name>.tmp`` file first and then moved
    over ``path``. An interruption mid-save leaves the previous best intact
    rather than a truncated file. On failure the temp file is removed and the
    exception re-raised.

    Saved keys: ``model`` (state dict), ``val_loss``, ``global_step``,
    ``config``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(
            {
                "model": model.state_dict(),
                "val_loss": val_loss,
                "global_step": global_step,
                "config": config_dict,
            },
            tmp_path,
        )
        # Path.replace uses os.replace; the rename is atomic on POSIX, and the
        # temp file is in the same directory so it stays on one filesystem.
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
def append_csv_row(
    csv_path: Path,
    fieldnames: list[str],
    row: Dict[str, Any],
    write_header: bool,
) -> None:
    """Append one dict ``row`` to ``csv_path`` using ``fieldnames`` as columns.

    Parent directories are created if missing. When ``write_header`` is true,
    the header line is emitted before the row. Keys absent from ``row`` are
    written as empty cells.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    with csv_path.open("a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
def _pick_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
mps = getattr(torch.backends, "mps", None)
if mps is not None and mps.is_available():
return torch.device("mps")
return torch.device("cpu")
def _safe_exp(x: float) -> float:
    """Return ``math.exp(x)``, or ``inf`` when the result overflows a float."""
    try:
        return math.exp(x)
    except OverflowError:
        return float("inf")


def _init_wandb(cfg: Any, args: argparse.Namespace) -> Optional[Any]:
    """Start an optional wandb run and return the module, or None on failure.

    Tracking is best-effort. Any exception from importing ``wandb`` or from
    ``wandb.init`` is reported, with its reason, and training continues with
    CSV logging only.
    """
    try:
        wandb_mod = importlib.import_module("wandb")
        wandb_mod.init(
            project="bach-gpt",
            name="v2-25M-5k-files",
            config={
                "d_model": cfg.d_model,
                "n_layers": cfg.n_layers,
                "n_heads": cfg.n_heads,
                "d_ff": cfg.d_ff,
                "block_size": cfg.block_size,
                "batch_size": args.batch_size,
                "max_epochs": args.max_epochs,
                "warmup_steps": args.warmup_steps,
                "sample_dir": args.sample_dir or "sample_5k",
            },
        )
    except Exception as exc:  # tracking must never abort a training run
        print(f"[train] wandb not available ({exc!r}), logging to CSV only")
        return None
    return wandb_mod


def train(args: argparse.Namespace) -> None:
    """Train a GPT on tokenized MIDI chunks as configured by CLI ``args``.

    The run does the following:

    * Seeds the RNGs, builds the train/val loaders, the model, AdamW, and a
      warmup + cosine LR schedule.
    * Trains with gradient-norm clipping at 1.0.
    * Every ``train_log_every`` steps, logs the averaged train loss and
      perplexity to stdout, ``training_log.csv`` and wandb (if available).
    * Every ``val_every`` steps, evaluates and refreshes
      ``checkpoints/best_model.pt`` when val loss improves.
    * Every ``checkpoint_every`` steps, writes a full resumable checkpoint.
    * When training completes normally, runs a final validation and
      checkpoint if the last step did not already land on those intervals,
      so the tail of training is not lost.

    Raises:
        ValueError: If ``train_log_every``, ``val_every`` or
            ``checkpoint_every`` is not positive. This is checked up front,
            before any data or model is built.
    """
    for flag in ("train_log_every", "val_every", "checkpoint_every"):
        value = getattr(args, flag)
        if value <= 0:
            raise ValueError(
                f"--{flag.replace('_', '-')} must be a positive integer, "
                f"got {value}"
            )

    device = _pick_device()
    print(f"[train] device={device}")
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    train_loader, val_loader, stats = build_dataloaders(
        sample_dir=Path(args.sample_dir) if args.sample_dir else None,
        block_size=args.block_size,
        batch_size=args.batch_size,
        split_ratio=args.split_ratio,
        seed=args.seed,
    )
    print(
        f"[train] data: train_chunks={stats.n_train_chunks} "
        f"val_chunks={stats.n_val_chunks} tokens={stats.n_tokens_total}"
    )

    cfg = default_gpt_config()
    cfg.block_size = args.block_size
    cfg.dropout = args.dropout
    cfg.vocab_size = stats.vocab_size
    model = GPT(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[train] parameters={n_params:,} (~{n_params / 1e6:.2f}M)")

    base_lr = 3e-4
    optimizer = AdamW(
        model.parameters(),
        lr=base_lr,
        betas=(0.9, 0.95),
        weight_decay=0.1,
    )
    steps_per_epoch = len(train_loader)
    total_steps = max(1, args.max_epochs * steps_per_epoch)
    if total_steps < args.warmup_steps:
        print(
            f"[train] warning: total_steps={total_steps} < "
            f"warmup={args.warmup_steps}; LR schedule may be odd."
        )
    # The scheduler is stepped once per optimizer step, so "epochs" inside
    # LambdaLR are really optimizer steps.
    scheduler = LambdaLR(
        optimizer,
        _lr_lambda_factory(args.warmup_steps, total_steps),
        last_epoch=-1,
    )

    # Model config plus run metadata, stored inside every checkpoint.
    config_dict: Dict[str, Any] = asdict(cfg)
    config_dict.update(
        {
            "vocab_size": stats.vocab_size,
            "n_bpe_merges": stats.n_bpe_merges,
            "max_epochs": args.max_epochs,
            "batch_size": args.batch_size,
            "seed": args.seed,
        }
    )

    results_dir = Path(args.results_dir)
    log_csv = results_dir / "training_log.csv"
    ckpt_dir = results_dir / "checkpoints"
    best_path = ckpt_dir / "best_model.pt"
    fieldnames = [
        "step",
        "epoch",
        "lr",
        "train_loss",
        "val_loss",
        "train_ppl",
        "val_ppl",
    ]
    # Write the header only for a fresh log; an existing log is appended to.
    if not log_csv.exists():
        log_csv.parent.mkdir(parents=True, exist_ok=True)
        with open(log_csv, "w", newline="") as f:
            csv.DictWriter(f, fieldnames=fieldnames).writeheader()

    # A uniform predictor over V tokens has CE = ln(V) and perplexity = V.
    random_ce = math.log(stats.vocab_size)
    print(
        f"[train] random baseline CE≈{random_ce:.3f} (nats), "
        f"ppl≈{math.exp(random_ce):.1f} (≈vocab {stats.vocab_size})"
    )

    best_val = float("inf")
    global_step = 0
    train_loss_accum = 0.0
    train_loss_count = 0
    last_val_loss: Optional[float] = None
    # Pre-bound so the post-loop tail logic always has values to log.
    epoch = 0
    lr = optimizer.param_groups[0]["lr"]

    wandb = _init_wandb(cfg, args)

    def run_validation() -> None:
        """Evaluate, log to stdout/wandb/CSV, and refresh best_model.pt."""
        nonlocal best_val, last_val_loss
        val_loss = evaluate(model, val_loader, device)
        last_val_loss = val_loss
        val_ppl = _safe_exp(val_loss)
        print(
            f"[val] step={global_step} val_loss={val_loss:.4f} "
            f"val_ppl={val_ppl:.2f}"
        )
        if wandb is not None:
            wandb.log(
                {
                    "val/loss": val_loss,
                    "val/ppl": val_ppl,
                },
                step=global_step,
            )
        append_csv_row(
            log_csv,
            fieldnames,
            {
                "step": global_step,
                "epoch": epoch,
                "lr": lr,
                "train_loss": "",
                "val_loss": f"{val_loss:.6f}",
                "train_ppl": "",
                "val_ppl": f"{val_ppl:.4f}",
            },
            write_header=False,
        )
        if val_loss < best_val:
            best_val = val_loss
            save_best(
                best_path,
                model,
                val_loss,
                global_step,
                config_dict,
            )
            print(
                f"[train] new best val_loss={val_loss:.4f} "
                f"→ {best_path}"
            )

    def save_periodic_checkpoint() -> None:
        """Write a full resumable checkpoint tagged with the current step."""
        ckpt_path = ckpt_dir / f"checkpoint_step_{global_step}.pt"
        save_checkpoint(
            ckpt_path,
            model,
            optimizer,
            scheduler,
            global_step,
            epoch,
            config_dict,
        )
        print(f"[train] saved {ckpt_path}")

    model.train()
    t0 = time.perf_counter()
    try:
        for epoch in range(args.max_epochs):
            for x, y in train_loader:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                )
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                train_loss_accum += loss.item()
                train_loss_count += 1
                lr = optimizer.param_groups[0]["lr"]

                if global_step % args.train_log_every == 0:
                    avg_train = train_loss_accum / max(1, train_loss_count)
                    train_ppl = _safe_exp(avg_train)
                    print(
                        f"[train] step={global_step} epoch={epoch} "
                        f"train_loss={avg_train:.4f} "
                        f"train_ppl={train_ppl:.2f} "
                        f"lr={lr:.2e}"
                    )
                    if wandb is not None:
                        wandb.log(
                            {
                                "train/loss": avg_train,
                                "train/ppl": train_ppl,
                                "lr": lr,
                            },
                            step=global_step,
                        )
                    # Also carry the most recent val loss so each train row
                    # is self-contained.
                    append_csv_row(
                        log_csv,
                        fieldnames,
                        {
                            "step": global_step,
                            "epoch": epoch,
                            "lr": lr,
                            "train_loss": f"{avg_train:.6f}",
                            "val_loss": (
                                ""
                                if last_val_loss is None
                                else f"{last_val_loss:.6f}"
                            ),
                            "train_ppl": f"{train_ppl:.4f}",
                            "val_ppl": (
                                ""
                                if last_val_loss is None
                                else f"{_safe_exp(last_val_loss):.4f}"
                            ),
                        },
                        write_header=False,
                    )
                    train_loss_accum = 0.0
                    train_loss_count = 0

                if global_step % args.val_every == 0:
                    run_validation()

                if global_step % args.checkpoint_every == 0:
                    save_periodic_checkpoint()

        # The final steps may not have landed on an interval boundary;
        # evaluate and checkpoint them so that work is not silently dropped.
        if global_step > 0 and global_step % args.val_every != 0:
            run_validation()
        if global_step > 0 and global_step % args.checkpoint_every != 0:
            save_periodic_checkpoint()
    finally:
        if wandb is not None:
            wandb.finish()

    elapsed = time.perf_counter() - t0
    print(
        f"[train] finished in {elapsed / 60:.1f} min, "
        f"best_val={best_val:.4f}"
    )
def parse_args() -> argparse.Namespace:
    """Define the training command-line interface and parse ``sys.argv``.

    Numeric hyper-parameters come from a ``(flag, type, default)`` table,
    registered in the same order they appear in ``--help``. The two path
    options follow with their help text.
    """
    parser = argparse.ArgumentParser(description="Train bach-gpt on MIDI tokens")
    numeric_options = (
        ("--max-epochs", int, 10),
        ("--batch-size", int, 32),
        ("--block-size", int, 512),
        ("--split-ratio", float, 0.9),
        ("--dropout", float, 0.1),
        ("--seed", int, 17),
        ("--warmup-steps", int, 100),
        ("--train-log-every", int, 50),
        ("--val-every", int, 500),
        ("--checkpoint-every", int, 500),
    )
    for flag, value_type, default in numeric_options:
        parser.add_argument(flag, type=value_type, default=default)
    parser.add_argument(
        "--sample-dir",
        type=str,
        default="",
        help="Override GigaMIDI sample directory (default: data/gigamidi/sample)",
    )
    parser.add_argument(
        "--results-dir",
        type=str,
        default=str(_ROOT / "results"),
        help="Directory for training_log.csv and checkpoints/",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run training.
    train(parse_args())