File size: 14,520 Bytes
1262b25 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 | """
SID-GPT v2 training script.
nanoGPT-style training loop with frame-aligned batch sampling,
cosine LR schedule, gradient accumulation, and AMP support.
Supports single-GPU and multi-GPU (DDP via torchrun).
"""
import argparse
import math
import os
import struct
import time
from contextlib import nullcontext
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from model import ModelConfig, Transformer
# Special token ids. Register bytes are emitted as tokens 0..255, so the
# control tokens live just above the byte range.
TOKEN_SEP = 256
TOKEN_FRAME = 257
# A frame is one TOKEN_FRAME marker followed by BYTES_PER_FRAME register
# bytes, hence one extra token per frame.
BYTES_PER_FRAME = 25
TOKENS_PER_FRAME = 1 + BYTES_PER_FRAME
def setup_ddp():
    """
    Initialise torch.distributed when the script is launched by torchrun.

    DDP mode is detected by the presence of the RANK environment variable;
    RANK, LOCAL_RANK and WORLD_SIZE are then read as integers, the process
    is pinned to its local CUDA device and an NCCL process group is created.

    Returns:
        (rank, local_rank, world_size, is_ddp). A plain single-process
        launch returns (0, 0, 1, False) without touching CUDA.
    """
    env = os.environ
    if "RANK" in env:
        global_rank = int(env["RANK"])
        gpu_index = int(env["LOCAL_RANK"])
        n_procs = int(env["WORLD_SIZE"])
        # Select this process's GPU before creating the NCCL group.
        torch.cuda.set_device(gpu_index)
        dist.init_process_group(backend="nccl")
        return global_rank, gpu_index, n_procs, True
    return 0, 0, 1, False
def get_device(requested: str) -> str:
    """
    Resolve the --device option to a concrete device string.

    Any value other than "auto" is returned unchanged. "auto" picks the
    first available backend in the order cuda -> mps -> cpu.
    """
    if requested == "auto":
        if torch.cuda.is_available():
            return "cuda"
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"
    return requested
def get_dtype(requested: str, device: str) -> torch.dtype:
    """
    Map the --dtype option to the torch dtype used for autocast.

    Args:
        requested: "bfloat16", "float16", or anything else (-> float32).
        device: device string such as "cuda", "cuda:1", "mps" or "cpu".

    Returns:
        torch.bfloat16 only when running on a CUDA device that reports bf16
        support; a "bfloat16" request anywhere else falls back to
        torch.float16 with a printed warning. "float16" maps to
        torch.float16 and every other value to torch.float32.
    """
    if requested == "bfloat16":
        # Compare the device *type*: under DDP the device is "cuda:<n>",
        # which a plain `device == "cuda"` check would never match.
        is_cuda = device.split(":")[0] == "cuda"
        if is_cuda and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        print("[WARN] bfloat16 not supported, falling back to float16")
        return torch.float16
    if requested == "float16":
        return torch.float16
    return torch.float32
def load_data(path: str, device: str) -> torch.Tensor:
    """
    Load a flat token file and return it as an int64 tensor on `device`.

    The file is read with np.fromfile as headerless uint16 values in native
    byte order, then widened to int64 before being moved to `device`.
    """
    tokens_u16 = np.fromfile(path, dtype=np.uint16)
    print(f"[DATA] Loaded {len(tokens_u16)} tokens from {path}")
    widened = tokens_u16.astype(np.int64)
    return torch.from_numpy(widened).to(device)
def generate_synth_data(device: str, num_songs: int = 20) -> torch.Tensor:
    """
    Generate deterministic synthetic training data for pipeline testing.

    Produces `num_songs` short songs (default 20) so the full training path
    can be exercised end to end without HVSC data. Each song is one
    separator frame (TOKENS_PER_FRAME copies of TOKEN_SEP) followed by
    80 + 5 * song_idx register frames. Every register frame is TOKEN_FRAME
    followed by BYTES_PER_FRAME register bytes: three voices whose
    frequencies rise each frame and whose control bytes toggle
    periodically, plus filter/volume bytes. The RNG is seeded with 42, so
    the output is identical on every call.

    Args:
        device: device the returned tensor is moved to.
        num_songs: number of songs to generate.

    Returns:
        1-D int64 tensor of token ids on `device`.
    """
    tokens = []
    rng = np.random.RandomState(42)
    for song_idx in range(num_songs):
        # SEP frame
        tokens.extend([TOKEN_SEP] * TOKENS_PER_FRAME)
        num_frames = 80 + song_idx * 5
        base_freq = 1000 + song_idx * 200
        for f in range(num_frames):
            tokens.append(TOKEN_FRAME)
            regs = [0] * BYTES_PER_FRAME
            # Voice 1: ascending frequency (lo/hi bytes)
            freq = (base_freq + f * 50) & 0xFFFF
            regs[0] = freq & 0xFF
            regs[1] = (freq >> 8) & 0xFF
            # Pulse width
            regs[2] = 0x00
            regs[3] = 0x08
            # Control: gate on, triangle; gate bit cleared for the last
            # 5 frames of the song.
            regs[4] = 0x11 if f < num_frames - 5 else 0x10
            # ADSR
            regs[5] = 0x09
            regs[6] = 0x00
            # Voice 2: harmony (offset frequency); low control bit set for
            # 12 of every 16 frames.
            freq2 = (base_freq + f * 37 + 500) & 0xFFFF
            regs[7] = freq2 & 0xFF
            regs[8] = (freq2 >> 8) & 0xFF
            regs[9] = 0x00
            regs[10] = 0x08
            regs[11] = 0x21 if f % 16 < 12 else 0x20
            regs[12] = 0x0A
            regs[13] = 0x00
            # Voice 3: bass (slow frequency); low control bit set for
            # 24 of every 32 frames.
            freq3 = (base_freq // 2 + f * 10) & 0xFFFF
            regs[14] = freq3 & 0xFF
            regs[15] = (freq3 >> 8) & 0xFF
            regs[16] = 0x00
            regs[17] = 0x04
            regs[18] = 0x41 if f % 32 < 24 else 0x40
            regs[19] = 0x0C
            regs[20] = 0x00
            # Filter + volume; regs[22] is the only randomised byte.
            regs[21] = 0x00
            regs[22] = rng.randint(0, 8)
            regs[23] = 0x00
            regs[24] = 0x0F
            tokens.extend(regs)
    data = np.array(tokens, dtype=np.uint16)
    # Report the actual song count rather than a duplicated literal.
    print(f"[SYNTH] Generated {len(data)} tokens ({num_songs} songs)")
    return torch.from_numpy(data.astype(np.int64)).to(device)
def split_data(data, block_size):
    """
    Split a token tensor into (train, val) at a frame-aligned boundary.

    About 95% of the tokens go to train. The split index is rounded down
    to a multiple of TOKENS_PER_FRAME, so if `data` starts on a frame
    boundary both halves do too.

    Args:
        data: 1-D token tensor (or any sliceable sequence with len()).
        block_size: model context length; each half must be able to
            supply at least one (x, y) window of block_size + 1 tokens.

    Returns:
        (train_data, val_data) slices of `data`.

    Raises:
        ValueError: if either half has fewer than block_size + 1 tokens.
            Without this check the problem only surfaces later, as
            mismatched x/y shapes during training or the first eval.
    """
    n = len(data)
    split_tok = int(n * 0.95)
    # Align to frame boundary
    split_tok = (split_tok // TOKENS_PER_FRAME) * TOKENS_PER_FRAME
    train_data, val_data = data[:split_tok], data[split_tok:]
    needed = block_size + 1
    if len(train_data) < needed or len(val_data) < needed:
        raise ValueError(
            f"[DATA] Not enough tokens for block_size={block_size}: "
            f"train={len(train_data)}, val={len(val_data)} "
            f"(each split needs at least {needed})"
        )
    return train_data, val_data
def get_batch(data, block_size, batch_size, device):
    """
    Sample a batch of (input, target) windows starting on frame boundaries.

    Start offsets are multiples of TOKENS_PER_FRAME, drawn uniformly from
    every aligned position o with o + block_size + 1 <= len(data). The
    targets are the inputs shifted left by one token. Assumes `data` is a
    1-D token tensor.

    Returns:
        (x, y), each of shape (batch_size, block_size), moved to `device`.

    Raises:
        ValueError: if `data` holds fewer than block_size + 1 tokens, so
            not even one full window fits.
    """
    last_start = len(data) - block_size - 1
    if last_start < 0:
        raise ValueError(
            f"get_batch: need at least {block_size + 1} tokens, "
            f"got {len(data)}"
        )
    # +1 because frame index last_start // TOKENS_PER_FRAME is itself a
    # valid start; torch.randint's upper bound is exclusive.
    n_starts = last_start // TOKENS_PER_FRAME + 1
    offsets = (
        torch.randint(n_starts, (batch_size,)) * TOKENS_PER_FRAME
    ).tolist()
    x = torch.stack([data[o : o + block_size] for o in offsets])
    y = torch.stack(
        [data[o + 1 : o + 1 + block_size] for o in offsets]
    )
    return x.to(device), y.to(device)
@torch.no_grad()
def estimate_loss(
    model, train_data, val_data, config, args, device,
):
    """
    Average the model loss over `args.eval_iters` random batches per split.

    The model is switched to eval mode for the measurement and back to
    train mode afterwards. Batches come from get_batch with
    config.block_size and args.batch_size; each forward pass runs under
    autocast for the device type of `device` with args.amp_dtype.

    Returns:
        {"train": mean_train_loss, "val": mean_val_loss} as Python floats.
    """
    autocast_device = device.split(":")[0]
    model.eval()
    results = {}
    for split_name, tokens in (("train", train_data), ("val", val_data)):
        running_total = 0.0
        for _ in range(args.eval_iters):
            xb, yb = get_batch(
                tokens, config.block_size, args.batch_size, device,
            )
            with torch.amp.autocast(
                device_type=autocast_device, dtype=args.amp_dtype,
            ):
                _, batch_loss = model(xb, yb)
            running_total += batch_loss.item()
        results[split_name] = running_total / args.eval_iters
    model.train()
    return results
def get_lr(step, args):
    """
    Learning rate for `step`: linear warmup followed by cosine decay.

    - step < args.warmup: linear ramp, reaching args.lr at
      step == args.warmup - 1.
    - args.warmup <= step < args.max_steps: half-cosine from args.lr
      down towards args.min_lr.
    - step >= args.max_steps: held at args.min_lr.
    """
    peak, floor_lr = args.lr, args.min_lr
    if step < args.warmup:
        return peak * (step + 1) / args.warmup
    if step >= args.max_steps:
        return floor_lr
    decay_span = args.max_steps - args.warmup
    frac = (step - args.warmup) / decay_span
    cosine_weight = (1.0 + math.cos(math.pi * frac)) / 2
    return floor_lr + (peak - floor_lr) * cosine_weight
def configure_optimizer(model, args, device):
    """
    Build an AdamW optimizer with decay applied only to >=2-D parameters.

    Trainable parameters with dim() >= 2 form the first group, with
    args.weight_decay. Those with fewer dimensions (e.g. biases and norm
    scales) form the second group, with weight_decay 0.0. Frozen
    parameters are skipped. The fused AdamW implementation is requested
    when `device` starts with "cuda".
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    decayed = [p for p in trainable if p.dim() >= 2]
    undecayed = [p for p in trainable if p.dim() < 2]
    param_groups = [
        {"params": decayed, "weight_decay": args.weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        fused=device.startswith("cuda"),
    )
def save_checkpoint(model, optimizer, config, step, path):
    """
    Write a training checkpoint to `path`, replacing it atomically.

    The saved dict has keys "model" (model.state_dict()), "optimizer"
    (optimizer.state_dict()), "config" (the config object, pickled as-is
    by torch.save) and "step". The data is first written to `path + ".tmp"`
    and then moved over `path` with os.replace. An interrupted save
    therefore never leaves a truncated checkpoint under the final name.
    """
    payload = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "config": config,
        "step": step,
    }
    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file, then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"[CKPT] Saved {path}")
def _parse_args():
    """Build the training CLI, parse sys.argv and validate the result."""
    parser = argparse.ArgumentParser(
        description="SID-GPT v2 training"
    )
    parser.add_argument("--data", type=str, default=None)
    parser.add_argument(
        "--config", type=str, default="small",
        choices=["small", "large"],
    )
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--max-steps", type=int, default=5000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--min-lr", type=float, default=3e-5)
    parser.add_argument("--warmup", type=int, default=200)
    parser.add_argument("--weight-decay", type=float, default=0.1)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.95)
    parser.add_argument("--eval-interval", type=int, default=250)
    parser.add_argument("--eval-iters", type=int, default=50)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument(
        "--out-dir", type=str, default="training/checkpoints"
    )
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype", type=str, default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--synth", action="store_true")
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()
    # A checkpoint stores model/optimizer/config/step but not the data
    # path, so training data is required even when resuming.
    if not args.synth and args.data is None:
        parser.error("--data or --synth required (also with --resume)")
    return args


def _strip_compile_prefix(state_dict):
    """
    Return `state_dict` with torch.compile's "_orig_mod." key prefix removed.

    Checkpoints saved from a torch.compile'd module carry that prefix on
    every key; stripping it lets them load into a plain Transformer. Keys
    without the prefix are passed through unchanged.
    """
    prefix = "_orig_mod."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """
    Entry point: parse CLI args, build model/data/optimizer, then train.

    Runs as a single process, or one process per GPU when launched by
    torchrun (DDP). Only rank 0 evaluates, logs and writes checkpoints.
    """
    args = _parse_args()
    # Enable experimental Flash Attention on ROCm
    os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"
    # DDP setup (auto-detect torchrun)
    rank, local_rank, world_size, is_ddp = setup_ddp()
    is_master = rank == 0
    if is_ddp:
        device = f"cuda:{local_rank}"
        device_type = "cuda"
    else:
        device = get_device(args.device)
        device_type = device.split(":")[0]
    torch.manual_seed(args.seed + rank)
    args.amp_dtype = get_dtype(args.dtype, device)
    if is_master:
        where = f"DDP: {world_size} GPUs" if is_ddp else f"Device: {device}"
        print(f"[INIT] {where}, dtype: {args.amp_dtype}")

    # Model config
    if args.config == "large":
        config = ModelConfig.large()
    else:
        config = ModelConfig.small()
    start_step = 0
    ckpt = None
    if args.resume:
        if is_master:
            print(f"[RESUME] Loading checkpoint {args.resume}")
        # weights_only=False is required because the checkpoint pickles the
        # config object; only resume from checkpoint files you trust.
        ckpt = torch.load(
            args.resume, map_location=device,
            weights_only=False,
        )
        config = ckpt["config"]
        model = Transformer(config).to(device)
        # Checkpoints written by older versions of this script from a
        # compiled model have "_orig_mod."-prefixed keys.
        model.load_state_dict(_strip_compile_prefix(ckpt["model"]))
        start_step = ckpt["step"]
        if is_master:
            print(f"[RESUME] Resuming from step {start_step}")
    else:
        model = Transformer(config).to(device)
    # Keep the plain, unwrapped module: it shares parameters with any
    # compiled/DDP wrapper below, and its state_dict has clean keys that
    # a fresh Transformer can load.
    raw_model = model
    if is_master:
        print(
            f"[MODEL] {args.config}: "
            f"{raw_model.count_params():,} params, "
            f"{config.n_layer}L/{config.n_head}H/"
            f"{config.n_embd}D"
        )
    if args.compile and device_type == "cuda":
        if is_master:
            print("[COMPILE] torch.compile enabled")
        model = torch.compile(model)
    # Wrap in DDP after compile
    if is_ddp:
        model = DDP(model, device_ids=[local_rank])

    # Data
    if args.synth:
        data = generate_synth_data(device)
    else:
        data = load_data(args.data, device)
    train_data, val_data = split_data(data, config.block_size)
    if is_master:
        print(
            f"[DATA] Train: {len(train_data):,} tokens, "
            f"Val: {len(val_data):,} tokens"
        )

    # Optimizer (on raw model params)
    optimizer = configure_optimizer(raw_model, args, device)
    if ckpt is not None:
        if "optimizer" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer"])
        # Release the checkpoint's tensors (loaded onto `device`).
        del ckpt
    # GradScaler only for float16
    use_scaler = args.amp_dtype == torch.float16
    scaler = torch.amp.GradScaler(enabled=use_scaler)
    if is_master:
        os.makedirs(args.out_dir, exist_ok=True)

    # Training loop
    model.train()
    t0 = time.time()
    for step in range(start_step, args.max_steps):
        lr = get_lr(step, args)
        for pg in optimizer.param_groups:
            pg["lr"] = lr
        # Eval + checkpoint (rank 0 only). Skipped at start_step: that state
        # is either the fresh init or the checkpoint we just resumed from.
        if (
            step % args.eval_interval == 0
            and step > start_step
            and is_master
        ):
            losses = estimate_loss(
                model, train_data, val_data,
                config, args, device,
            )
            print(
                f"[EVAL] step {step}: "
                f"train={losses['train']:.4f}, "
                f"val={losses['val']:.4f}"
            )
            save_checkpoint(
                raw_model, optimizer, config, step,
                os.path.join(
                    args.out_dir, f"ckpt_{step}.pt"
                ),
            )
        # Gradient accumulation
        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0
        for micro in range(args.grad_accum):
            x, y = get_batch(
                train_data, config.block_size,
                args.batch_size, device,
            )
            # Under DDP, all-reduce gradients only on the last micro-step;
            # earlier micro-steps accumulate locally inside no_sync().
            is_last_micro = micro == args.grad_accum - 1
            sync_ctx = (
                model.no_sync()
                if is_ddp and not is_last_micro
                else nullcontext()
            )
            with sync_ctx:
                with torch.amp.autocast(
                    device_type=device_type, dtype=args.amp_dtype
                ):
                    _, loss = model(x, y)
                    loss = loss / args.grad_accum
                accum_loss += loss.item()
                scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), 1.0
        )
        scaler.step(optimizer)
        scaler.update()
        # Logging (rank 0 only)
        if step % args.log_interval == 0 and is_master:
            dt = time.time() - t0
            t0 = time.time()
            if dt > 0 and step > start_step:
                ms_per_step = (
                    dt / args.log_interval * 1000
                )
                tps = (
                    args.batch_size * args.grad_accum
                    * config.block_size
                    * args.log_interval
                    * world_size / dt
                )
            else:
                ms_per_step = 0
                tps = 0
            print(
                f"[TRAIN] step {step:5d} | "
                f"loss {accum_loss:.4f} | "
                f"lr {lr:.2e} | "
                f"{ms_per_step:.0f}ms/step | "
                f"{dt:.2f}s/{args.log_interval}steps | "
                f"{tps/1e6:.2f}M tok/s"
            )

    # Final save (rank 0 only)
    if is_master:
        save_checkpoint(
            raw_model, optimizer, config, args.max_steps,
            os.path.join(
                args.out_dir, f"ckpt_{args.max_steps}.pt"
            ),
        )
        losses = estimate_loss(
            model, train_data, val_data,
            config, args, device,
        )
        print(
            f"[DONE] Final: train={losses['train']:.4f}, "
            f"val={losses['val']:.4f}"
        )
    if is_ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
|