|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import gc |
|
|
import math |
|
|
import time |
|
|
import json |
|
|
import argparse |
|
|
from datetime import datetime |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from model import LunaConfig, Luna, N_FEATURES |
|
|
|
|
|
|
|
|
# Switch off Python's automatic garbage collector as soon as this module is
# loaded. gc.enable() and gc.collect() are called again later in this file,
# after the training summary has been printed.
gc.disable()
|
|
|
|
|
class DataLoaderLite:
    """Sequential batch reader over a memory-mapped token file.

    The file at ``tokens_path`` is mapped read-only as an ``int32`` array of
    shape ``(n_tokens, N_FEATURES)``. Each call to :meth:`next_batch` reads
    ``B * T + 1`` consecutive rows starting at ``current_position`` and
    returns the rows as inputs together with the same rows shifted by one as
    targets.

    Args:
        tokens_path: Path of the binary token file.
        n_tokens: Number of rows (tokens) stored in the file.
        B: Batch size (sequences per batch).
        T: Sequence length (tokens per sequence).
        device: Device the returned tensors are moved to.
    """

    def __init__(self, tokens_path: str, n_tokens: int, B: int, T: int, device: str = 'cuda'):
        self.B = B
        self.T = T
        self.device = device
        self.n_tokens = n_tokens

        print(f"Memory-mapping {tokens_path}...")
        self.tokens = np.memmap(tokens_path, dtype=np.int32, mode='r', shape=(n_tokens, N_FEATURES))

        # Four bytes per int32 value.
        file_size_gb = (n_tokens * N_FEATURES * 4) / 1e9
        print(f" {n_tokens:,} tokens ({file_size_gb:.2f} GB on disk, memory-mapped)")

        self.current_position = 0
        self.n_batches = (n_tokens - T - 1) // (B * T)
        print(f" {self.n_batches:,} batches available")

    def reset(self):
        """Rewind the reader to the first token."""
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` batch and advance the read position.

        Returns:
            A pair of ``int64`` tensors on ``self.device``, each of shape
            ``(B, T, N_FEATURES)``; ``y`` holds the tokens of ``x`` shifted
            forward by one position.

        Raises:
            ValueError: If the file holds fewer than ``B * T + 1`` tokens, so
                not even a single batch can be formed.
        """
        B, T = self.B, self.T

        # One extra token so the targets can be the inputs shifted by one.
        tokens_needed = B * T + 1
        if tokens_needed > self.n_tokens:
            raise ValueError(
                f"Need at least {tokens_needed} tokens for one batch "
                f"(B={B}, T={T}), but the file only has {self.n_tokens}"
            )

        # current_position is a public attribute and may be assigned from
        # outside (e.g. when resuming). Wrap to the start when a full window
        # no longer fits, instead of reading a short slice that cannot be
        # reshaped to (B, T, N_FEATURES).
        if self.current_position < 0 or self.current_position + tokens_needed > self.n_tokens:
            self.current_position = 0

        start = self.current_position
        window = self.tokens[start:start + tokens_needed]

        # astype() copies the read-only memmap slice into regular memory.
        buf = torch.from_numpy(window.astype(np.int64))

        x = buf[:-1].view(B, T, N_FEATURES)
        y = buf[1:].view(B, T, N_FEATURES)

        # Advance by B*T (not B*T + 1): consecutive windows overlap by one token.
        self.current_position = start + B * T

        # Start over once the next full window would run past the end.
        if self.current_position + tokens_needed > self.n_tokens:
            self.current_position = 0

        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_compile_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``_orig_mod.`` removed from every key."""
    prefix = '_orig_mod.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _evaluate(model, val_loader, device_type, val_steps=20):
    """Return the mean loss over the first *val_steps* batches of *val_loader*.

    Leaves the model in eval mode; the caller switches back to train mode.
    """
    model.eval()
    val_loader.reset()

    total = 0.0
    with torch.no_grad():
        for _ in range(val_steps):
            x, y = val_loader.next_batch()
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                _, loss = model(x, y)
            total += loss.item()
    return total / val_steps


def _run_training(args):
    """Run the full training procedure described in :func:`train`."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_type = "cuda" if device == "cuda" else "cpu"
    print(f"Using device: {device}")

    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        print(f" Compute: {torch.cuda.get_device_capability()}")
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)

    torch.set_float32_matmul_precision('high')

    # ---- dataset description -------------------------------------------------
    config_path = os.path.join(args.data_dir, "config.json")
    with open(config_path) as f:
        data_config = json.load(f)

    vocab_sizes = data_config['vocab_sizes']
    train_tokens = data_config['train_tokens']
    val_tokens = data_config['val_tokens']

    # ---- schedule ------------------------------------------------------------
    micro_batch_tokens = args.batch_size * args.block_size
    tokens_per_step = micro_batch_tokens * args.grad_accum_steps
    max_steps = int(train_tokens * args.epochs / tokens_per_step)
    warmup_steps = max(100, max_steps // 100)

    print(f"\n{'='*70}")
    print("Luna Training")
    print(f"{'='*70}")
    print(f"Train tokens: {train_tokens:,}")
    print(f"Batch size: {args.batch_size}")
    print(f"Block size: {args.block_size}")
    print(f"Grad accum: {args.grad_accum_steps}")
    print(f"Effective batch: {tokens_per_step:,} tokens")
    print(f"Max steps: {max_steps:,}")
    print(f"Warmup steps: {warmup_steps}")

    # ---- data ----------------------------------------------------------------
    train_path = os.path.join(args.data_dir, "train_tokens.dat")
    val_path = os.path.join(args.data_dir, "val_tokens.dat")

    train_loader = DataLoaderLite(train_path, train_tokens, args.batch_size, args.block_size, device)
    val_loader = DataLoaderLite(val_path, val_tokens, args.batch_size, args.block_size, device)

    # ---- model ---------------------------------------------------------------
    model_config = LunaConfig(
        syllable_vocab=vocab_sizes['syllables'],
        onset_vocab=vocab_sizes['onsets'],
        nucleus_vocab=vocab_sizes['nuclei'],
        coda_vocab=vocab_sizes['codas'],
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        max_seq_len=args.block_size,
        # Dropout is forced to zero whenever --compile is given.
        dropout=0.0 if args.compile else args.dropout,
        fuse_output_heads=True,
    )

    model = Luna(model_config)
    model.to(device)

    # ---- resume: model weights -----------------------------------------------
    start_step = 0
    best_val_loss = float('inf')
    checkpoint = None

    if args.resume:
        print(f"\nResuming from: {args.resume}")
        # NOTE(security): weights_only=False runs the full unpickler, which can
        # execute arbitrary code. It is required here because the checkpoints
        # written below store the config object itself. Only resume from files
        # you trust.
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)

        # Checkpoints saved from a torch.compile-wrapped model carry an
        # "_orig_mod." key prefix; drop it so keys match the plain model.
        load_result = model.load_state_dict(_strip_compile_prefix(checkpoint['model']), strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            # strict=False tolerates mismatches; make them visible.
            print(f" Warning: missing keys: {list(load_result.missing_keys)}, "
                  f"unexpected keys: {list(load_result.unexpected_keys)}")

        start_step = checkpoint.get('step', 0)
        best_val_loss = checkpoint.get('val_loss', float('inf'))
        print(f" Resumed from step {start_step}, val_loss: {best_val_loss:.4f}")

    if args.compile:
        print("\nCompiling model with torch.compile()...")
        model = torch.compile(model)

    # ---- optimizer -----------------------------------------------------------
    # Weight decay is applied to tensors with two or more dimensions only.
    trainable = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]

    optim_groups = [
        {'params': decay_params, 'weight_decay': 0.1},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]

    print(f"\nOptimizer:")
    print(f" Decayed: {sum(p.numel() for p in decay_params):,}")
    print(f" Non-decayed: {sum(p.numel() for p in nodecay_params):,}")

    # Request the fused implementation only on CUDA: this script falls back to
    # CPU when no GPU is present, and fused=True is rejected for CPU parameters
    # on PyTorch releases that lack CPU fused kernels.
    optimizer = torch.optim.AdamW(
        optim_groups,
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=(device_type == "cuda"),
    )

    resume_optimizer_state = None
    if checkpoint is not None and 'optimizer' in checkpoint:
        resume_optimizer_state = checkpoint['optimizer']
        print(f" Optimizer state will be restored after compile")

    # ---- learning-rate schedule: linear warmup, then cosine decay ------------
    max_lr = args.lr
    min_lr = max_lr * 0.1

    def get_lr(it):
        if it < warmup_steps:
            return max_lr * (it + 1) / warmup_steps
        # ">=" gives the same value as the cosine formula at it == max_steps
        # and avoids dividing by zero when max_steps == warmup_steps.
        if it >= max_steps:
            return min_lr
        decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)

    # ---- output directory ----------------------------------------------------
    if args.resume:
        # dirname() is empty for a bare file name; fall back to the CWD.
        log_dir = os.path.dirname(args.resume) or "."
        print(f" Continuing in log_dir: {log_dir}")
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_dir = os.path.join(args.log_dir, f"Luna_{timestamp}")
    os.makedirs(log_dir, exist_ok=True)

    # ---- resume: optimizer state and data position ---------------------------
    if resume_optimizer_state is not None:
        try:
            optimizer.load_state_dict(resume_optimizer_state)
            print(f" Optimizer state restored!")
        except Exception as e:
            # Best effort: continue with a fresh optimizer state.
            print(f" Warning: Could not restore optimizer state: {e}")

    if args.resume:
        # Every optimizer step consumes grad_accum_steps micro-batches, and the
        # loader only ever starts a batch at a multiple of B*T for which a full
        # B*T + 1 window fits. Reproduce exactly that sequence so the restored
        # position is always a valid batch start.
        batches_per_pass = (train_loader.n_tokens - 1) // micro_batch_tokens
        if batches_per_pass > 0:
            micro_batches_done = start_step * args.grad_accum_steps
            train_loader.current_position = (micro_batches_done % batches_per_pass) * micro_batch_tokens

    print(f"\n{'='*70}")
    print("Starting Training")
    print(f"{'='*70}")

    # Defined up front so the final save works even if the loop body never runs
    # (e.g. resuming a run that already reached max_steps).
    val_loss = best_val_loss
    start_time = time.time()

    for step in range(start_step, max_steps):
        t0 = time.time()

        # ---- periodic evaluation (also on the very last step) ----------------
        if step % args.eval_interval == 0 or step == max_steps - 1:
            if device_type == "cuda":
                torch.cuda.synchronize()

            val_loss = _evaluate(model, val_loader, device_type)

            elapsed = time.time() - start_time
            # Count only steps executed in this process; start_time is reset on resume.
            tokens_so_far = (step - start_step) * tokens_per_step
            tok_per_sec = tokens_so_far / elapsed if elapsed > 0 else 0

            print(f"\n[Step {step:,}] val_loss: {val_loss:.4f} | {tok_per_sec:,.0f} tok/s")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'model': model.state_dict(),
                    'config': model_config,
                    'step': step,
                    'val_loss': val_loss,
                }, os.path.join(log_dir, "model_best.pt"))
                # Backslash escaped explicitly; the printed text is unchanged.
                print(" ✓ New best model saved!¯\\_(ツ)_/¯")

            if device_type == "cuda":
                torch.cuda.synchronize()

            model.train()

        # ---- one optimizer step with gradient accumulation -------------------
        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0

        for micro_step in range(args.grad_accum_steps):
            x, y = train_loader.next_batch()

            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                _, loss = model(x, y)

            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = loss / args.grad_accum_steps
            loss_accum += loss.detach()
            loss.backward()

        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.step()

        if device_type == "cuda":
            # Wait for the GPU so the timing below covers the whole step.
            torch.cuda.synchronize()

        dt = time.time() - t0
        tok_per_sec = tokens_per_step / dt

        if step % 10 == 0:
            print(f"step {step:5d} | loss: {loss_accum.item():.4f} | lr {lr:.2e} | norm: {norm:.2f} | dt: {dt*1000:.0f}ms | tok/s: {tok_per_sec:,.0f}")

        # ---- periodic resumable checkpoint -----------------------------------
        if step > 0 and step % 5000 == 0:
            torch.save({
                'model': model.state_dict(),
                'config': model_config,
                # The update for `step` is already applied, so the run must
                # continue with the following step; storing `step` would make
                # a resumed run repeat this update.
                'step': step + 1,
                'val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, os.path.join(log_dir, "checkpoint_latest.pt"))
            print(f" Checkpoint saved at step {step}")

    # ---- final weights and summary -------------------------------------------
    torch.save({
        'model': model.state_dict(),
        'config': model_config,
        'step': max_steps,
        'val_loss': val_loss,
    }, os.path.join(log_dir, "model_final.pt"))

    total_time = time.time() - start_time
    steps_run = max(0, max_steps - start_step)
    avg_tok_per_sec = steps_run * tokens_per_step / total_time if total_time > 0 else 0

    print(f"\n{'='*70}")
    print("Training Complete")
    print(f"{'='*70}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(f" Total time: {total_time/60:.1f} min")
    print(f" Avg throughput: {avg_tok_per_sec:,.0f} tok/s")
    print(f" Model saved: {log_dir}")


def train(args):
    """Train a Luna model on the tokenized dataset in ``args.data_dir``.

    Reads ``config.json`` from the data directory for vocabulary sizes and
    token counts, builds the model and an AdamW optimizer, and runs a
    warmup + cosine-decay schedule with gradient accumulation and bfloat16
    autocast. Writes ``model_best.pt`` whenever validation loss improves,
    ``checkpoint_latest.pt`` every 5000 steps, and ``model_final.pt`` at the
    end, all inside the run's log directory.

    Args:
        args: Parsed command-line namespace. Attributes read: ``data_dir``,
            ``batch_size``, ``block_size``, ``grad_accum_steps``, ``epochs``,
            ``n_layer``, ``n_head``, ``n_embd``, ``dropout``, ``compile``,
            ``resume``, ``lr``, ``eval_interval`` and ``log_dir``.
    """
    # The module already disables the collector at import time; repeating the
    # call keeps a second train() call in the same process consistent.
    gc.disable()
    try:
        _run_training(args)
    finally:
        # Re-enable the collector even if training raised.
        gc.enable()
        gc.collect()
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse the training options and start training."""
    cli = argparse.ArgumentParser(description="Train Luna")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--data_dir", {"type": str, "required": True}),
        ("--n_layer", {"type": int, "default": 12}),
        ("--n_head", {"type": int, "default": 12}),
        ("--n_embd", {"type": int, "default": 768}),
        ("--dropout", {"type": float, "default": 0.1}),
        ("--batch_size", {"type": int, "default": 8}),
        ("--block_size", {"type": int, "default": 1024}),
        ("--grad_accum_steps", {"type": int, "default": 2}),
        ("--lr", {"type": float, "default": 6e-4}),
        ("--epochs", {"type": float, "default": 1.0}),
        ("--compile", {"action": "store_true"}),
        ("--resume", {"type": str, "default": None}),
        ("--eval_interval", {"type": int, "default": 5000}),
        ("--log_dir", {"type": str, "default": "./logs"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    train(cli.parse_args())
|
|
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()