# Copyright 2026 Jakub Sykała
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gc
import math
import time
import json
import argparse
from datetime import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import LunaConfig, Luna, N_FEATURES
# The cyclic garbage collector is switched off for the whole process as soon
# as this module is imported, so collector pauses do not add jitter to the
# per-step timings measured during training. train() turns it back on and
# forces one collection after the final model has been written.
gc.disable()
class DataLoaderLite:
    """Sequential, memory-mapped loader of next-token training batches.

    The token file is opened with ``np.memmap`` as an ``(n_tokens, N_FEATURES)``
    int32 array, so only the rows of the current batch are copied into RAM.
    Each call to :meth:`next_batch` returns ``B`` consecutive windows of ``T``
    rows (inputs) together with the same windows shifted forward by one row
    (targets). It then advances by ``B * T`` rows, and wraps back to row 0 once
    the following batch would run past the end of the file.

    Args:
        tokens_path: Path of the raw int32 token file.
        n_tokens: Number of rows (tokens) stored in the file.
        B: Sequences per batch.
        T: Tokens per sequence.
        device: Device that batches are moved to.

    Raises:
        ValueError: If ``B`` or ``T`` is not positive, or the file holds fewer
            than ``B * T + 1`` tokens (not even a single batch).
    """

    def __init__(self, tokens_path: str, n_tokens: int, B: int, T: int, device: str = 'cuda'):
        if B <= 0 or T <= 0:
            raise ValueError(f"B and T must be positive, got B={B}, T={T}")
        # One batch reads B*T + 1 rows: the extra row supplies the last target.
        if n_tokens < B * T + 1:
            raise ValueError(
                f"{tokens_path} holds {n_tokens:,} tokens but one batch needs {B * T + 1:,}"
            )
        self.B = B
        self.T = T
        self.device = device
        self.n_tokens = n_tokens

        # Memory-map the file read-only; rows are only paged in when sliced.
        print(f"Memory-mapping {tokens_path}...")
        self.tokens = np.memmap(tokens_path, dtype=np.int32, mode='r', shape=(n_tokens, N_FEATURES))
        file_size_gb = (n_tokens * N_FEATURES * self.tokens.itemsize) / 1e9
        print(f" {n_tokens:,} tokens ({file_size_gb:.2f} GB on disk, memory-mapped)")

        self.current_position = 0
        # Start positions k*B*T are valid while k*B*T + B*T + 1 <= n_tokens,
        # which gives (n_tokens - 1) // (B*T) distinct batches before wrapping.
        self.n_batches = (n_tokens - 1) // (B * T)
        print(f" {self.n_batches:,} batches available")

    def reset(self):
        """Rewind to the first row of the file."""
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each shaped ``[B, T, N_FEATURES]``.

        ``x[i]`` holds rows ``pos + i*T .. pos + i*T + T - 1`` and ``y`` holds
        the same rows shifted forward by one. Both come from a single contiguous
        slice of ``B*T + 1`` rows.
        """
        B, T = self.B, self.T
        tokens_needed = B * T + 1

        start = self.current_position
        chunk = self.tokens[start : start + tokens_needed]
        # astype copies the slice out of the memmap; only this batch enters RAM.
        buf = torch.from_numpy(chunk.astype(np.int64))  # [B*T+1, N_FEATURES]

        x = buf[:-1].view(B, T, N_FEATURES)  # [B, T, N_FEATURES]
        y = buf[1:].view(B, T, N_FEATURES)   # [B, T, N_FEATURES], shifted by one row

        self.current_position += B * T
        # Wrap around if the following batch would not fit entirely in the file.
        if self.current_position + tokens_needed > self.n_tokens:
            self.current_position = 0

        # NOTE: non_blocking only overlaps the copy when the source is pinned
        # memory; these tensors are not pinned.
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
# ==============================================================================
# TRAINING
# ==============================================================================
def _strip_compile_prefix(state_dict):
    """Return a copy of *state_dict* with the ``torch.compile`` key prefix removed.

    A compiled model's ``state_dict()`` keys start with ``_orig_mod.``;
    stripping that prefix lets the weights load into an uncompiled model.
    """
    prefix = '_orig_mod.'
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def _resume_data_position(batches_consumed, batch_tokens, n_tokens):
    """Return the loader position after *batches_consumed* ``next_batch`` calls.

    This mirrors DataLoaderLite's advance-and-wrap rule. Each batch moves the
    position forward by ``batch_tokens`` rows, and the position resets to 0
    once the next batch (``batch_tokens + 1`` rows) would not fit. The valid
    start positions therefore form a cycle of ``(n_tokens - 1) // batch_tokens``
    entries.
    """
    cycle = (n_tokens - 1) // batch_tokens
    if cycle <= 0:
        return 0
    return (batches_consumed % cycle) * batch_tokens


def _estimate_val_loss(model, loader, device_type, n_batches=20):
    """Return the mean loss of *model* over the first *n_batches* of *loader*.

    The loader is rewound first, so every evaluation sees the same batches.
    The forward passes run under ``torch.no_grad()`` with bfloat16 autocast,
    like the training step. The caller toggles ``model.eval()``/``model.train()``.
    """
    loader.reset()
    total = 0.0
    with torch.no_grad():
        for _ in range(n_batches):
            x, y = loader.next_batch()
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                _, loss = model(x, y)
            total += loss.item()
    return total / n_batches


def train(args):
    """Train a Luna model on the memory-mapped token files in ``args.data_dir``.

    The function reads ``config.json`` (keys ``vocab_sizes``, ``train_tokens``,
    ``val_tokens``) from ``args.data_dir`` and builds the model and an AdamW
    optimizer. It optionally resumes from ``args.resume`` and wraps the model
    with ``torch.compile``. It then runs ``max_steps`` optimizer steps with
    gradient accumulation, linear warmup and cosine decay to 10% of ``args.lr``.

    Files written to the log directory:
      * ``model_best.pt`` - whenever the validation loss improves;
      * ``checkpoint_latest.pt`` - every ``args.checkpoint_interval`` steps
        (5000 when the attribute is absent), including optimizer state;
      * ``model_final.pt`` - after the last step.

    The ``step`` stored in every checkpoint is the index of the next step to
    run, so resuming from any of them neither repeats nor skips a step.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_type = "cuda" if device == "cuda" else "cpu"
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f" GPU: {torch.cuda.get_device_name(0)}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
        print(f" Compute: {torch.cuda.get_device_capability()}")
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    # Seeds
    torch.manual_seed(1337)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
    torch.set_float32_matmul_precision('high')

    # Dataset metadata
    config_path = os.path.join(args.data_dir, "config.json")
    with open(config_path) as f:
        data_config = json.load(f)
    vocab_sizes = data_config['vocab_sizes']
    train_tokens = data_config['train_tokens']
    val_tokens = data_config['val_tokens']

    # Step budget
    tokens_per_step = args.batch_size * args.block_size * args.grad_accum_steps
    max_steps = int(train_tokens * args.epochs / tokens_per_step)
    warmup_steps = max(100, max_steps // 100)
    checkpoint_interval = getattr(args, 'checkpoint_interval', 5000)

    print(f"\n{'='*70}")
    print("Luna Training")
    print(f"{'='*70}")
    print(f"Train tokens: {train_tokens:,}")
    print(f"Batch size: {args.batch_size}")
    print(f"Block size: {args.block_size}")
    print(f"Grad accum: {args.grad_accum_steps}")
    print(f"Effective batch: {tokens_per_step:,} tokens")
    print(f"Max steps: {max_steps:,}")
    print(f"Warmup steps: {warmup_steps}")

    # Data loaders
    train_path = os.path.join(args.data_dir, "train_tokens.dat")
    val_path = os.path.join(args.data_dir, "val_tokens.dat")
    train_loader = DataLoaderLite(train_path, train_tokens, args.batch_size, args.block_size, device)
    val_loader = DataLoaderLite(val_path, val_tokens, args.batch_size, args.block_size, device)

    # Model
    model_config = LunaConfig(
        syllable_vocab=vocab_sizes['syllables'],
        onset_vocab=vocab_sizes['onsets'],
        nucleus_vocab=vocab_sizes['nuclei'],
        coda_vocab=vocab_sizes['codas'],
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        max_seq_len=args.block_size,
        # NOTE(review): dropout is forced to 0 under --compile; the reason is not
        # visible in this file.
        dropout=args.dropout if not args.compile else 0.0,
        fuse_output_heads=True,
    )
    model = Luna(model_config)
    model.to(device)

    # Resume weights BEFORE compiling, so keys match the uncompiled module.
    start_step = 0
    best_val_loss = float('inf')
    checkpoint = None
    if args.resume:
        print(f"\nResuming from: {args.resume}")
        # weights_only=False unpickles arbitrary objects (the checkpoint stores a
        # LunaConfig); only resume from checkpoint files you trust.
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
        result = model.load_state_dict(_strip_compile_prefix(checkpoint['model']), strict=False)
        # strict=False tolerates key mismatches; report them instead of hiding them.
        if result.missing_keys:
            print(f" Warning: keys missing from checkpoint: {result.missing_keys}")
        if result.unexpected_keys:
            print(f" Warning: unexpected keys in checkpoint: {result.unexpected_keys}")
        start_step = checkpoint.get('step', 0)
        best_val_loss = checkpoint.get('val_loss', float('inf'))
        print(f" Resumed from step {start_step}, val_loss: {best_val_loss:.4f}")

    if args.compile:
        print("\nCompiling model with torch.compile()...")
        # Default mode - more stable than reduce-overhead.
        model = torch.compile(model)

    # Optimizer: weight decay only on matrices/embeddings (dim >= 2), not on
    # biases and norm gains.
    params = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]
    nodecay_params = [p for p in params if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': 0.1},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    print("\nOptimizer:")
    print(f" Decayed: {sum(p.numel() for p in decay_params):,}")
    print(f" Non-decayed: {sum(p.numel() for p in nodecay_params):,}")
    # The fused kernel is only used on CUDA; CPU support for fused AdamW is
    # version-dependent in PyTorch, so use the default implementation there.
    optimizer = torch.optim.AdamW(
        optim_groups, lr=args.lr, betas=(0.9, 0.95), eps=1e-8,
        fused=(device_type == "cuda"),
    )
    if checkpoint is not None and 'optimizer' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(" Optimizer state restored!")
        except Exception as e:  # best-effort: continue with a fresh optimizer
            print(f" Warning: Could not restore optimizer state: {e}")

    # LR schedule: linear warmup, then cosine decay from max_lr to min_lr.
    max_lr = args.lr
    min_lr = max_lr * 0.1

    def get_lr(it):
        if it < warmup_steps:
            return max_lr * (it + 1) / warmup_steps
        if it > max_steps:
            return min_lr
        decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)

    # Logging - reuse the checkpoint's directory when resuming, else create one.
    if args.resume:
        # abspath so a bare filename yields the current directory rather than '',
        # which os.makedirs rejects.
        log_dir = os.path.dirname(os.path.abspath(args.resume))
        print(f" Continuing in log_dir: {log_dir}")
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_dir = os.path.join(args.log_dir, f"Luna_{timestamp}")
    os.makedirs(log_dir, exist_ok=True)

    # Put the train loader where the interrupted run left it: every step
    # consumes grad_accum_steps batches.
    if args.resume:
        train_loader.current_position = _resume_data_position(
            start_step * args.grad_accum_steps,
            args.batch_size * args.block_size,
            train_loader.n_tokens,
        )

    print(f"\n{'='*70}")
    print("Starting Training")
    print(f"{'='*70}")

    val_loss = float('nan')  # stays NaN if the loop never evaluates
    start_time = time.time()
    for step in range(start_step, max_steps):
        t0 = time.time()

        # Evaluation (runs before this step's update).
        if step % args.eval_interval == 0 or step == max_steps - 1:
            if device_type == "cuda":
                torch.cuda.synchronize()
            model.eval()
            val_loss = _estimate_val_loss(model, val_loader, device_type)
            elapsed = time.time() - start_time
            # Count only tokens processed by this run; start_time is per run.
            tokens_so_far = (step - start_step) * tokens_per_step
            tok_per_sec = tokens_so_far / elapsed if elapsed > 0 else 0
            print(f"\n[Step {step:,}] val_loss: {val_loss:.4f} | {tok_per_sec:,.0f} tok/s")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'model': model.state_dict(),
                    'config': model_config,
                    'step': step,
                    'val_loss': val_loss,
                }, os.path.join(log_dir, "model_best.pt"))
                print(" ✓ New best model saved!¯\\_(ツ)_/¯")
            if device_type == "cuda":
                torch.cuda.synchronize()
            model.train()

        # Training step with gradient accumulation.
        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0
        for _ in range(args.grad_accum_steps):
            x, y = train_loader.next_batch()
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                _, loss = model(x, y)
            # Scale so the accumulated gradient is the mean over micro-batches.
            loss = loss / args.grad_accum_steps
            loss_accum += loss.detach()
            loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        optimizer.step()

        if device_type == "cuda":
            torch.cuda.synchronize()
        dt = time.time() - t0
        tok_per_sec = tokens_per_step / dt
        if step % 10 == 0:
            print(f"step {step:5d} | loss: {loss_accum.item():.4f} | lr {lr:.2e} | norm: {norm:.2f} | dt: {dt*1000:.0f}ms | tok/s: {tok_per_sec:,.0f}")

        # Periodic checkpoint for safe resume.
        if step > 0 and step % checkpoint_interval == 0:
            torch.save({
                'model': model.state_dict(),
                'config': model_config,
                # This step's update is already applied; store the next step to run.
                'step': step + 1,
                'val_loss': best_val_loss,
                'optimizer': optimizer.state_dict(),
            }, os.path.join(log_dir, "checkpoint_latest.pt"))
            print(f" Checkpoint saved at step {step}")

    # Final save
    torch.save({
        'model': model.state_dict(),
        'config': model_config,
        'step': max_steps,
        'val_loss': val_loss,
    }, os.path.join(log_dir, "model_final.pt"))

    total_time = time.time() - start_time
    steps_run = max(max_steps - start_step, 0)
    avg_tok_per_sec = steps_run * tokens_per_step / total_time if total_time > 0 else 0
    print(f"\n{'='*70}")
    print("Training Complete")
    print(f"{'='*70}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(f" Total time: {total_time/60:.1f} min")
    print(f" Avg throughput: {avg_tok_per_sec:,.0f} tok/s")
    print(f" Model saved: {log_dir}")

    gc.enable()
    gc.collect()
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def main():
    """Parse command-line arguments and run ``train``.

    Size, accumulation and interval options must be positive. A zero batch
    size, block size or accumulation count makes the step-budget division in
    ``train`` fail. A zero interval makes ``step % interval`` raise.
    """
    parser = argparse.ArgumentParser(description="Train Luna")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="directory with config.json, train_tokens.dat and val_tokens.dat")
    parser.add_argument("--n_layer", type=_positive_int, default=12)
    parser.add_argument("--n_head", type=_positive_int, default=12)
    parser.add_argument("--n_embd", type=_positive_int, default=768)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--batch_size", type=_positive_int, default=8,
                        help="sequences per micro-batch")
    parser.add_argument("--block_size", type=_positive_int, default=1024,
                        help="tokens per sequence")
    parser.add_argument("--grad_accum_steps", type=_positive_int, default=2,
                        help="micro-batches accumulated per optimizer step")
    parser.add_argument("--lr", type=float, default=6e-4, help="peak learning rate")
    parser.add_argument("--epochs", type=float, default=1.0)
    parser.add_argument("--compile", action="store_true", help="wrap the model with torch.compile")
    parser.add_argument("--resume", type=str, default=None, help="checkpoint file to resume from")
    parser.add_argument("--eval_interval", type=_positive_int, default=5000,
                        help="steps between validation runs")
    parser.add_argument("--checkpoint_interval", type=_positive_int, default=5000,
                        help="steps between resumable checkpoints")
    parser.add_argument("--log_dir", type=str, default="./logs")
    args = parser.parse_args()
    train(args)
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()