Spaces:
Sleeping
Sleeping
| """Train GPT on MIDI token chunks: checkpoints, CSV log, val tracking.""" | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import importlib | |
| import math | |
| import sys | |
| import time | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from torch.utils.data import DataLoader | |
| _SCRIPT_DIR = Path(__file__).resolve().parent | |
| _ROOT = _SCRIPT_DIR.parent | |
| if str(_SCRIPT_DIR) not in sys.path: | |
| sys.path.insert(0, str(_SCRIPT_DIR)) | |
| from dataset import build_dataloaders # noqa: E402 | |
| from model import GPT, default_gpt_config # noqa: E402 | |
| def _lr_lambda_factory(warmup_steps: int, total_steps: int): | |
| """Warmup then cosine: LR multiplier 1.0 → 0.1 over non-warmup steps.""" | |
| def lr_lambda(current_step: int) -> float: | |
| if current_step < warmup_steps: | |
| return float(current_step + 1) / float(max(1, warmup_steps)) | |
| if total_steps <= warmup_steps: | |
| return 1.0 | |
| t = (current_step - warmup_steps) / float(total_steps - warmup_steps) | |
| t = min(1.0, max(0.0, t)) | |
| min_f = 0.1 | |
| return min_f + (1.0 - min_f) * 0.5 * (1.0 + math.cos(math.pi * t)) | |
| return lr_lambda | |
def evaluate(
    model: GPT, val_loader: DataLoader, device: torch.device
) -> float:
    """Return the token-weighted mean cross-entropy over ``val_loader``.

    Args:
        model: Network called as ``model(x)``; its output is flattened to
            ``(-1, logits.size(-1))`` before the loss is computed.
        val_loader: Iterable of ``(x, y)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Sum of per-batch mean losses weighted by ``y.numel()``, divided by
        the total number of target elements; ``0.0`` when the loader yields
        no batches.

    The model is switched to eval mode for the pass and always switched
    back to train mode afterwards, even if a batch raises.
    """
    model.eval()
    total = 0.0
    n_tokens = 0
    try:
        # Evaluation never calls backward(); disabling autograd avoids
        # building and holding a graph for every validation batch.
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                )
                # cross_entropy returns a batch mean; re-weight by the
                # element count so uneven batches average correctly.
                total += loss.item() * y.numel()
                n_tokens += y.numel()
    finally:
        # Callers continue training right after evaluation.
        model.train()
    return total / max(1, n_tokens)
def save_checkpoint(
    path: Path,
    model: GPT,
    optimizer: AdamW,
    scheduler: LambdaLR,
    global_step: int,
    epoch: int,
    config_dict: Dict[str, Any],
) -> None:
    """Write a resumable training checkpoint to ``path``.

    The saved dict holds the model, optimizer and scheduler state dicts
    together with ``global_step``, ``epoch`` and ``config``. Missing parent
    directories are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    snapshot: Dict[str, Any] = {}
    snapshot["model"] = model.state_dict()
    snapshot["optimizer"] = optimizer.state_dict()
    snapshot["scheduler"] = scheduler.state_dict()
    snapshot["global_step"] = global_step
    snapshot["epoch"] = epoch
    snapshot["config"] = config_dict

    torch.save(snapshot, path)
def save_best(
    path: Path,
    model: GPT,
    val_loss: float,
    global_step: int,
    config_dict: Dict[str, Any],
) -> None:
    """Write the best-so-far model weights to ``path``.

    Only the model state dict is stored (no optimizer or scheduler state),
    alongside the ``val_loss`` and ``global_step`` at which it was reached
    and the run ``config``. Missing parent directories are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    best_record: Dict[str, Any] = {}
    best_record["model"] = model.state_dict()
    best_record["val_loss"] = val_loss
    best_record["global_step"] = global_step
    best_record["config"] = config_dict

    torch.save(best_record, path)
def append_csv_row(
    csv_path: Path,
    fieldnames: list[str],
    row: Dict[str, Any],
    write_header: bool,
) -> None:
    """Append one record to the CSV log at ``csv_path``.

    The file is opened in append mode (and its parent directory created if
    needed). When ``write_header`` is true, a header line built from
    ``fieldnames`` is written immediately before the row.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with csv_path.open("a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
| def _pick_device() -> torch.device: | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| mps = getattr(torch.backends, "mps", None) | |
| if mps is not None and mps.is_available(): | |
| return torch.device("mps") | |
| return torch.device("cpu") | |
def _safe_exp(value: float) -> float:
    """Return ``exp(value)``, or ``inf`` when the result overflows a float."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def train(args: argparse.Namespace) -> None:
    """Run the full training loop described by the parsed CLI ``args``.

    Builds the dataloaders and model, trains with AdamW under a
    warmup + cosine LR schedule, and periodically:

    * every ``args.train_log_every`` steps: prints and logs the mean train
      loss / perplexity accumulated since the previous train log,
    * every ``args.val_every`` steps: evaluates on the validation loader,
      logs the result, and saves ``best_model.pt`` on improvement,
    * every ``args.checkpoint_every`` steps: writes a resumable checkpoint.

    Logs always go to ``<results_dir>/training_log.csv``; they additionally
    go to Weights & Biases when ``wandb`` imports and initialises.

    Args:
        args: Namespace with the attributes produced by ``parse_args``.
    """
    device = _pick_device()
    print(f"[train] device={device}")

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    train_loader, val_loader, stats = build_dataloaders(
        sample_dir=Path(args.sample_dir) if args.sample_dir else None,
        block_size=args.block_size,
        batch_size=args.batch_size,
        split_ratio=args.split_ratio,
        seed=args.seed,
    )
    print(
        f"[train] data: train_chunks={stats.n_train_chunks} "
        f"val_chunks={stats.n_val_chunks} tokens={stats.n_tokens_total}"
    )

    cfg = default_gpt_config()
    cfg.block_size = args.block_size
    cfg.dropout = args.dropout
    cfg.vocab_size = stats.vocab_size
    model = GPT(cfg).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[train] parameters={n_params:,} (~{n_params / 1e6:.2f}M)")

    base_lr = 3e-4
    optimizer = AdamW(
        model.parameters(),
        lr=base_lr,
        betas=(0.9, 0.95),
        weight_decay=0.1,
    )

    steps_per_epoch = len(train_loader)
    total_steps = max(1, args.max_epochs * steps_per_epoch)
    if total_steps < args.warmup_steps:
        print(
            f"[train] warning: total_steps={total_steps} < "
            f"warmup={args.warmup_steps}; LR schedule may be odd."
        )
    scheduler = LambdaLR(
        optimizer,
        _lr_lambda_factory(args.warmup_steps, total_steps),
        last_epoch=-1,
    )

    # Stored inside every checkpoint alongside the weights.
    config_dict: Dict[str, Any] = asdict(cfg)
    config_dict.update(
        {
            "vocab_size": stats.vocab_size,
            "n_bpe_merges": stats.n_bpe_merges,
            "max_epochs": args.max_epochs,
            "batch_size": args.batch_size,
            "seed": args.seed,
        }
    )

    results_dir = Path(args.results_dir)
    log_csv = results_dir / "training_log.csv"
    ckpt_dir = results_dir / "checkpoints"
    best_path = ckpt_dir / "best_model.pt"
    fieldnames = [
        "step",
        "epoch",
        "lr",
        "train_loss",
        "val_loss",
        "train_ppl",
        "val_ppl",
    ]
    # Write the header once; an existing log is appended to, not truncated.
    if not log_csv.exists():
        log_csv.parent.mkdir(parents=True, exist_ok=True)
        with open(log_csv, "w", newline="") as f:
            csv.DictWriter(f, fieldnames=fieldnames).writeheader()

    # Cross-entropy of a uniform guess over the vocabulary, for reference.
    random_ce = math.log(stats.vocab_size)
    print(
        f"[train] random baseline CE≈{random_ce:.3f} (nats), "
        f"ppl≈{math.exp(random_ce):.1f} (≈vocab {stats.vocab_size})"
    )

    best_val = float("inf")
    global_step = 0
    train_loss_accum = 0.0
    train_loss_count = 0
    last_val_loss: Optional[float] = None

    # wandb is optional: any import/init failure falls back to CSV-only
    # logging, but the cause is reported rather than silently discarded.
    wandb = None
    try:
        _wandb = importlib.import_module("wandb")
        _wandb.init(
            project="bach-gpt",
            name="v2-25M-5k-files",
            config={
                "d_model": cfg.d_model,
                "n_layers": cfg.n_layers,
                "n_heads": cfg.n_heads,
                "d_ff": cfg.d_ff,
                "block_size": cfg.block_size,
                "batch_size": args.batch_size,
                "max_epochs": args.max_epochs,
                "warmup_steps": args.warmup_steps,
                "sample_dir": args.sample_dir or "sample_5k",
            },
        )
        wandb = _wandb
    except Exception as exc:
        print(
            f"[train] wandb not available "
            f"({type(exc).__name__}: {exc}), logging to CSV only"
        )

    model.train()
    t0 = time.perf_counter()
    try:
        for epoch in range(args.max_epochs):
            for x, y in train_loader:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    y.reshape(-1),
                )
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1
                train_loss_accum += loss.item()
                train_loss_count += 1
                # Read after scheduler.step(), so this is the rate the
                # scheduler has set for the upcoming step.
                lr = optimizer.param_groups[0]["lr"]

                if global_step % args.train_log_every == 0:
                    avg_train = train_loss_accum / max(1, train_loss_count)
                    train_ppl = _safe_exp(avg_train)
                    print(
                        f"[train] step={global_step} epoch={epoch} "
                        f"train_loss={avg_train:.4f} "
                        f"train_ppl={train_ppl:.2f} "
                        f"lr={lr:.2e}"
                    )
                    if wandb is not None:
                        wandb.log(
                            {
                                "train/loss": avg_train,
                                "train/ppl": train_ppl,
                                "lr": lr,
                            },
                            step=global_step,
                        )
                    # Train rows carry the most recent validation numbers
                    # (blank until the first validation pass has run).
                    append_csv_row(
                        log_csv,
                        fieldnames,
                        {
                            "step": global_step,
                            "epoch": epoch,
                            "lr": lr,
                            "train_loss": f"{avg_train:.6f}",
                            "val_loss": (
                                ""
                                if last_val_loss is None
                                else f"{last_val_loss:.6f}"
                            ),
                            "train_ppl": f"{train_ppl:.4f}",
                            "val_ppl": (
                                ""
                                if last_val_loss is None
                                else f"{_safe_exp(last_val_loss):.4f}"
                            ),
                        },
                        write_header=False,
                    )
                    train_loss_accum = 0.0
                    train_loss_count = 0

                if global_step % args.val_every == 0:
                    val_loss = evaluate(model, val_loader, device)
                    last_val_loss = val_loss
                    # Guarded like train_ppl: a diverged loss must not
                    # abort the run with OverflowError.
                    val_ppl = _safe_exp(val_loss)
                    print(
                        f"[val] step={global_step} val_loss={val_loss:.4f} "
                        f"val_ppl={val_ppl:.2f}"
                    )
                    if wandb is not None:
                        wandb.log(
                            {
                                "val/loss": val_loss,
                                "val/ppl": val_ppl,
                            },
                            step=global_step,
                        )
                    append_csv_row(
                        log_csv,
                        fieldnames,
                        {
                            "step": global_step,
                            "epoch": epoch,
                            "lr": lr,
                            "train_loss": "",
                            "val_loss": f"{val_loss:.6f}",
                            "train_ppl": "",
                            "val_ppl": f"{val_ppl:.4f}",
                        },
                        write_header=False,
                    )
                    if val_loss < best_val:
                        best_val = val_loss
                        save_best(
                            best_path,
                            model,
                            val_loss,
                            global_step,
                            config_dict,
                        )
                        print(
                            f"[train] new best val_loss={val_loss:.4f} "
                            f"→ {best_path}"
                        )

                if global_step % args.checkpoint_every == 0:
                    ckpt_path = ckpt_dir / f"checkpoint_step_{global_step}.pt"
                    save_checkpoint(
                        ckpt_path,
                        model,
                        optimizer,
                        scheduler,
                        global_step,
                        epoch,
                        config_dict,
                    )
                    print(f"[train] saved {ckpt_path}")
    finally:
        # Close the wandb run even when training is interrupted or fails.
        if wandb is not None:
            wandb.finish()

    elapsed = time.perf_counter() - t0
    print(
        f"[train] finished in {elapsed / 60:.1f} min, "
        f"best_val={best_val:.4f}"
    )
def parse_args() -> argparse.Namespace:
    """Parse the training script's command-line flags.

    Numeric hyper-parameters are registered from a table (flag, type,
    default); the two path options carry help text and are added
    explicitly. ``--results-dir`` defaults to ``results`` under ``_ROOT``.
    """
    parser = argparse.ArgumentParser(
        description="Train bach-gpt on MIDI tokens"
    )

    numeric_options = (
        ("--max-epochs", int, 10),
        ("--batch-size", int, 32),
        ("--block-size", int, 512),
        ("--split-ratio", float, 0.9),
        ("--dropout", float, 0.1),
        ("--seed", int, 17),
        ("--warmup-steps", int, 100),
        ("--train-log-every", int, 50),
        ("--val-every", int, 500),
        ("--checkpoint-every", int, 500),
    )
    for flag, caster, fallback in numeric_options:
        parser.add_argument(flag, type=caster, default=fallback)

    sample_dir_help = (
        "Override GigaMIDI sample directory "
        "(default: data/gigamidi/sample)"
    )
    parser.add_argument(
        "--sample-dir", type=str, default="", help=sample_dir_help
    )

    default_results = str(_ROOT / "results")
    parser.add_argument(
        "--results-dir",
        type=str,
        default=default_results,
        help="Directory for training_log.csv and checkpoints/",
    )

    return parser.parse_args()
# Script entry point: parse the CLI flags and run training. Skipped when
# the module is imported rather than executed.
if __name__ == "__main__":
    train(parse_args())