"""
Production-ready Supernova training script.

Optimized for stability, monitoring, and memory efficiency.
"""

import argparse
import logging
import math
import os
import sys
import time
from typing import Any, Dict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup

sys.path.append('.')

from supernova.config import ModelConfig
from supernova.data import load_sources_from_yaml, TokenChunkDataset
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer


def setup_logging(output_dir: str) -> logging.Logger:
    """Set up file and console logging for the training run."""
    os.makedirs(output_dir, exist_ok=True)

    logger = logging.getLogger('supernova_training')
    logger.setLevel(logging.INFO)

    # Clear handlers from any previous call so repeated setup does not
    # duplicate every log line.
    logger.handlers.clear()

    file_handler = logging.FileHandler(os.path.join(output_dir, 'training.log'))
    file_handler.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger


def compute_grad_norm(model: nn.Module) -> float:
    """Compute the global L2 norm over all parameter gradients."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.float().norm(2).item()
            total += param_norm * param_norm
    return math.sqrt(total)
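
# Note: this is the same quantity torch.nn.utils.clip_grad_norm_ returns
# (the global pre-clip L2 norm); the training loop below logs that return
# value directly instead of re-walking the gradients after they are zeroed.

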
def format_time(seconds: float) -> str:
    """Format seconds into a readable duration, e.g. 95 -> '1m35s'."""
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        return f"{seconds//60:.0f}m{seconds%60:.0f}s"
    else:
        return f"{seconds//3600:.0f}h{(seconds%3600)//60:.0f}m"


def get_memory_usage() -> Dict[str, float]:
    """Get current CUDA memory usage in GiB (zeros on CPU)."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        cached = torch.cuda.memory_reserved() / 1024**3
        return {'allocated': allocated, 'cached': cached}
    return {'allocated': 0.0, 'cached': 0.0}


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    step: int,
    loss: float,
    best_loss: float,
    config: Dict[str, Any],
    path: str,
    logger: logging.Logger,
) -> None:
    """Save a training checkpoint."""
    try:
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "config": config,
            "step": step,
            "loss": loss,
            "best_loss": best_loss,
            "timestamp": time.time(),
        }

        # dirname is empty for bare filenames; fall back to the cwd.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

        torch.save(checkpoint, path)
        logger.info(f"💾 Checkpoint saved: {path} (loss: {loss:.4f})")

    except Exception as e:
        logger.error(f"❌ Failed to save checkpoint {path}: {e}")
        raise
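
# Restoring one of these checkpoints might look like this (sketch; assumes
# model, optimizer, and scheduler were rebuilt with the same configuration):
#
#   ckpt = torch.load(path, map_location="cpu")
#   model.load_state_dict(ckpt["model_state_dict"])
#   optimizer.load_state_dict(ckpt["optimizer_state_dict"])
#   scheduler.load_state_dict(ckpt["scheduler_state_dict"])
#   step = ckpt["step"]

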
def validate_training_setup(
    config_path: str,
    data_config_path: str,
    logger: logging.Logger,
) -> None:
    """Validate configs, data sources, and tokenizer before training starts."""
    logger.info("🔍 Validating training setup...")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Model config not found: {config_path}")
    if not os.path.exists(data_config_path):
        raise FileNotFoundError(f"Data config not found: {data_config_path}")

    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)

    # Instantiating the model confirms the config is actually constructible
    # and re-verifies the parameter count. Raise explicitly rather than use a
    # bare assert so the check survives `python -O`.
    model = SupernovaModel(cfg)
    total_params = sum(p.numel() for p in model.parameters())
    if total_params != 25_000_000:
        raise ValueError(f"Expected exactly 25,000,000 parameters, got {total_params:,}")
    del model

    sources = load_sources_from_yaml(data_config_path)
    if not sources:
        raise ValueError("No data sources configured")

    tok = load_gpt2_tokenizer()
    if tok.vocab_size != cfg.vocab_size:
        raise ValueError(
            f"Tokenizer vocab size ({tok.vocab_size}) does not match "
            f"model vocab size ({cfg.vocab_size})"
        )

    logger.info("✅ Training setup validation complete")


def train_production(
    config_path: str,
    data_config_path: str,
    seq_len: int = 1024,
    batch_size: int = 16,
    grad_accum: int = 8,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    max_steps: int = 100_000,
    save_every: int = 10_000,
    log_every: int = 50,
    out_dir: str = "checkpoints",
    seed: int = 42,
    max_grad_norm: float = 1.0,
    enable_mixed_precision: bool = True,
) -> None:
    """Production training with full monitoring and optimization."""

    logger = setup_logging(out_dir)
    logger.info("🚀 SUPERNOVA PRODUCTION TRAINING STARTED")
    logger.info("=" * 60)

    validate_training_setup(config_path, data_config_path, logger)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"📱 Device: {device}")
    logger.info(f"🌱 Seed: {seed}")

    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)
    logger.info(f"⚙️ Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")

    tok = load_gpt2_tokenizer()
    logger.info(f"🔤 Tokenizer: {tok.vocab_size:,} vocab size")

    model = SupernovaModel(cfg).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"🧠 Model: {total_params:,} parameters")

    scaler = torch.cuda.amp.GradScaler() if enable_mixed_precision and torch.cuda.is_available() else None
    if scaler:
        logger.info("⚡ Mixed precision training enabled")
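
    # AMP flow: the forward pass runs in reduced precision under autocast;
    # GradScaler scales the loss before backward so small gradients do not
    # underflow in float16, then unscales them before clipping and stepping.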
logger.info("📚 Loading datasets...")
|
|
|
sources = load_sources_from_yaml(data_config_path)
|
|
|
logger.info(f"📊 Data sources: {len(sources)} sources loaded")
|
|
|
for i, source in enumerate(sources):
|
|
|
logger.info(f" {i+1}. {source.name} (weight: {source.weight})")
|
|
|
|
|
|
ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
|
|
|
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
|
|
|
logger.info(f"🔄 DataLoader: batch_size={batch_size}, seq_len={seq_len}")
|
|
|
|
|
|
|
|
|
optimizer = torch.optim.AdamW(
|
|
|
model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
|
|
|
)
|
|
|
scheduler = get_cosine_schedule_with_warmup(
|
|
|
optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
|
|
|
)
|
|
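
    # Schedule shape: LR ramps linearly from 0 to `lr` over warmup_steps,
    # then decays along a cosine curve toward 0 at max_steps.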

    logger.info("🎯 Training configuration:")
    logger.info(f"  Learning rate: {lr}")
    logger.info(f"  Warmup steps: {warmup_steps:,}")
    logger.info(f"  Max steps: {max_steps:,}")
    logger.info(f"  Gradient accumulation: {grad_accum}")
    logger.info(f"  Max gradient norm: {max_grad_norm}")
    logger.info(f"  Save every: {save_every:,} steps")
    logger.info(f"  Log every: {log_every} steps")

    model.train()
    step = 0
    micro = 0
    running_loss = 0.0
    avg_loss = 0.0  # most recent logged average; used by checkpoint saves
    best_loss = float('inf')
    start_time = time.time()

    logger.info("🏃 Starting training loop...")
    logger.info("=" * 60)
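
    # Gradient accumulation arithmetic: each optimizer step consumes
    # grad_accum micro-batches, so the effective batch is
    # batch_size * grad_accum sequences (16 * 8 = 128 at the defaults) and
    # batch_size * grad_accum * seq_len tokens (131,072 at the defaults).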
    try:
        while step < max_steps:
            for batch in dl:
                x, y = batch
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                if scaler:
                    with torch.cuda.amp.autocast():
                        logits, loss = model(x, y)
                        loss = loss / grad_accum
                else:
                    logits, loss = model(x, y)
                    loss = loss / grad_accum

                if scaler:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                micro += 1
                running_loss += loss.item()

                if micro % grad_accum == 0:
                    # clip_grad_norm_ returns the pre-clip global norm;
                    # capture it here because zero_grad() below discards the
                    # gradients before the logging block runs.
                    if scaler:
                        scaler.unscale_(optimizer)
                        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm).item()
                        optimizer.step()

                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()
                    step += 1

                    if step % log_every == 0:
                        # running_loss already sums per-step means (each micro
                        # loss was divided by grad_accum), so the average per
                        # optimizer step is running_loss / log_every.
                        avg_loss = running_loss / log_every
                        running_loss = 0.0
                        lr_now = scheduler.get_last_lr()[0]
                        elapsed = time.time() - start_time

                        memory = get_memory_usage()

                        # Each optimizer step processes grad_accum micro-batches.
                        tokens_per_sec = (step * batch_size * seq_len * grad_accum) / elapsed

                        log_msg = (
                            f"Step {step:6d} | Loss: {avg_loss:.4f} | Grad: {grad_norm:.3f} | "
                            f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s"
                        )

                        if memory['allocated'] > 0:
                            log_msg += f" | Mem: {memory['allocated']:.1f}GB"

                        logger.info(log_msg)

                        if avg_loss < best_loss:
                            best_loss = avg_loss
                            logger.info(f"💫 New best loss: {best_loss:.4f}")

                    if save_every and step % save_every == 0:
                        ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
                        save_checkpoint(
                            model, optimizer, scheduler, step, avg_loss,
                            best_loss, cfg.__dict__, ckpt_path, logger
                        )

                if step >= max_steps:
                    break

                if torch.cuda.is_available() and micro % 100 == 0:
                    torch.cuda.empty_cache()

    except KeyboardInterrupt:
        logger.info("\n⏹️ Training interrupted by user")
    except Exception as e:
        logger.error(f"\n❌ Training failed: {e}")
        raise

    final_path = os.path.join(out_dir, "supernova_final.pt")
    # Average whatever accumulated since the last log step; if we stopped
    # exactly on a log boundary, fall back to the last logged average.
    steps_since_log = step % log_every
    final_loss = running_loss / steps_since_log if steps_since_log > 0 else avg_loss
    save_checkpoint(model, optimizer, scheduler, step, final_loss, best_loss, cfg.__dict__, final_path, logger)

    total_time = time.time() - start_time
    total_tokens = step * batch_size * seq_len * grad_accum

    logger.info("\n" + "=" * 60)
    logger.info("🎉 TRAINING COMPLETE!")
    logger.info(f"📈 Final step: {step:,}")
    logger.info(f"🏆 Best loss: {best_loss:.4f}")
    logger.info(f"⏱️ Total time: {format_time(total_time)}")
    logger.info(f"🔢 Total tokens: {total_tokens:,}")
    logger.info(f"⚡ Average throughput: {total_tokens/total_time:.0f} tokens/sec")
    logger.info(f"💾 Final checkpoint: {final_path}")
    logger.info("=" * 60)


def main():
    parser = argparse.ArgumentParser(description="Production Supernova Training")
    parser.add_argument("--config", required=True, help="Path to model config")
    parser.add_argument("--data-config", required=True, help="Path to data config")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Save frequency (steps)")
    parser.add_argument("--log-every", type=int, default=50, help="Log frequency (steps)")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient clipping norm")
    parser.add_argument("--no-mixed-precision", action="store_true", help="Disable mixed precision")

    args = parser.parse_args()

    train_production(
        config_path=args.config,
        data_config_path=args.data_config,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        save_every=args.save_every,
        log_every=args.log_every,
        out_dir=args.out_dir,
        seed=args.seed,
        max_grad_norm=args.max_grad_norm,
        enable_mixed_precision=not args.no_mixed_precision,
    )


if __name__ == "__main__":
    main()