#!/usr/bin/env python3
"""
Production-ready Supernova training script.
Optimized for stability, monitoring, and memory efficiency.
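
Example invocation (paths are illustrative, not shipped defaults):
    python train_production.py --config configs/supernova_25m.json \
        --data-config configs/data_sources.yaml --out-dir checkpoints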
"""
import argparse
import json
import math
import os
import sys
import time
import logging
from pathlib import Path
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
# Add supernova to path
sys.path.append('.')
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.data import load_sources_from_yaml, TokenChunkDataset
def setup_logging(output_dir: str) -> logging.Logger:
"""Setup comprehensive logging."""
os.makedirs(output_dir, exist_ok=True)
logger = logging.getLogger('supernova_training')
logger.setLevel(logging.INFO)
# File handler
file_handler = logging.FileHandler(os.path.join(output_dir, 'training.log'))
file_handler.setLevel(logging.INFO)
# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Formatter
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
return logger
def compute_grad_norm(model: nn.Module) -> float:
    """Global L2 norm of all gradients; call before grads are cleared."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
    return math.sqrt(total)
def format_time(seconds: float) -> str:
"""Format seconds into readable time."""
if seconds < 60:
return f"{seconds:.1f}s"
elif seconds < 3600:
return f"{seconds//60:.0f}m{seconds%60:.0f}s"
else:
return f"{seconds//3600:.0f}h{(seconds%3600)//60:.0f}m"
def get_memory_usage() -> Dict[str, float]:
"""Get current memory usage."""
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1024**3 # GB
        reserved = torch.cuda.memory_reserved() / 1024**3  # GB held by the caching allocator
        return {'allocated': allocated, 'reserved': reserved}
    return {'allocated': 0.0, 'reserved': 0.0}
def save_checkpoint(
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Any,
step: int,
loss: float,
best_loss: float,
config: Dict[str, Any],
path: str,
logger: logging.Logger
) -> None:
"""Save training checkpoint."""
try:
checkpoint = {
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"config": config,
"step": step,
"loss": loss,
"best_loss": best_loss,
"timestamp": time.time(),
}
        # Create directory if needed (dirname may be empty for bare filenames)
        ckpt_dir = os.path.dirname(path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
torch.save(checkpoint, path)
logger.info(f"💾 Checkpoint saved: {path} (loss: {loss:.4f})")
except Exception as e:
logger.error(f"❌ Failed to save checkpoint {path}: {e}")
raise
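
# Matching restore path (a sketch; not called by this script). Assumes the
# checkpoint was written by save_checkpoint above and that the same model,
# optimizer, and scheduler objects are reconstructed before loading.
def load_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    path: str,
) -> Dict[str, Any]:
    """Restore training state in place; returns the raw checkpoint dict."""
    checkpoint = torch.load(path, map_location="cpu")  # CPU-safe load
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint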
def validate_training_setup(
config_path: str,
data_config_path: str,
logger: logging.Logger
) -> None:
"""Validate training setup before starting."""
logger.info("🔍 Validating training setup...")
# Check config files exist
if not os.path.exists(config_path):
raise FileNotFoundError(f"Model config not found: {config_path}")
if not os.path.exists(data_config_path):
raise FileNotFoundError(f"Data config not found: {data_config_path}")
# Test model creation
cfg = ModelConfig.from_json_file(config_path)
cfg.assert_exact_params(expected=25_000_000)
    model = SupernovaModel(cfg)
    total_params = sum(p.numel() for p in model.parameters())
    assert total_params == 25_000_000, f"Expected 25M params, got {total_params:,}"
    del model  # validation-only instance; free it before the real build
# Test data loading
sources = load_sources_from_yaml(data_config_path)
if not sources:
raise ValueError("No data sources configured")
# Test tokenizer
tok = load_gpt2_tokenizer()
    assert tok.vocab_size == cfg.vocab_size, (
        f"Tokenizer vocab {tok.vocab_size} != model vocab {cfg.vocab_size}"
    )
logger.info("✅ Training setup validation complete")
def train_production(
config_path: str,
data_config_path: str,
seq_len: int = 1024,
batch_size: int = 16,
grad_accum: int = 8,
lr: float = 3e-4,
warmup_steps: int = 2000,
max_steps: int = 100_000,
save_every: int = 10_000,
log_every: int = 50,
out_dir: str = "checkpoints",
seed: int = 42,
max_grad_norm: float = 1.0,
enable_mixed_precision: bool = True,
) -> None:
"""Production training with full monitoring and optimization."""
# Setup logging
logger = setup_logging(out_dir)
logger.info("🚀 SUPERNOVA PRODUCTION TRAINING STARTED")
logger.info("=" * 60)
# Validate setup
validate_training_setup(config_path, data_config_path, logger)
# Setup device and seed
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"📱 Device: {device}")
logger.info(f"🌱 Seed: {seed}")
# Load configuration
cfg = ModelConfig.from_json_file(config_path)
cfg.assert_exact_params(expected=25_000_000)
logger.info(f"⚙️ Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")
# Load tokenizer
tok = load_gpt2_tokenizer()
logger.info(f"🔤 Tokenizer: {tok.vocab_size:,} vocab size")
# Create model
model = SupernovaModel(cfg).to(device)
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"🧠 Model: {total_params:,} parameters")
# Setup mixed precision if enabled
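    # GradScaler scales the loss before backward so small fp16 gradients do not
    # underflow, then unscales them ahead of the optimizer step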
scaler = torch.cuda.amp.GradScaler() if enable_mixed_precision and torch.cuda.is_available() else None
if scaler:
logger.info("⚡ Mixed precision training enabled")
# Load data
logger.info("📚 Loading datasets...")
sources = load_sources_from_yaml(data_config_path)
logger.info(f"📊 Data sources: {len(sources)} sources loaded")
for i, source in enumerate(sources):
logger.info(f" {i+1}. {source.name} (weight: {source.weight})")
ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
    dl = DataLoader(
        ds, batch_size=batch_size, shuffle=False, num_workers=0,
        pin_memory=torch.cuda.is_available(),  # lets non_blocking copies overlap compute
    )
logger.info(f"🔄 DataLoader: batch_size={batch_size}, seq_len={seq_len}")
# Setup optimizer and scheduler
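    # betas=(0.9, 0.95) and weight_decay=0.1 follow common LLM pretraining practice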
optimizer = torch.optim.AdamW(
model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
)
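    # Linear warmup to `lr` over warmup_steps, then cosine decay toward 0 by max_steps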
scheduler = get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
)
logger.info(f"🎯 Training configuration:")
logger.info(f" Learning rate: {lr}")
logger.info(f" Warmup steps: {warmup_steps:,}")
logger.info(f" Max steps: {max_steps:,}")
logger.info(f" Gradient accumulation: {grad_accum}")
logger.info(f" Max gradient norm: {max_grad_norm}")
logger.info(f" Save every: {save_every:,} steps")
logger.info(f" Log every: {log_every} steps")
# Training variables
model.train()
    step = 0                  # optimizer steps taken
    micro = 0                 # micro-batches processed
    running_loss = 0.0        # sum of scaled losses since the last log
    micros_since_log = 0      # micro-batches since the last log
    avg_loss = 0.0            # most recently logged average loss
    last_grad_norm = 0.0      # pre-clip grad norm from the last optimizer step
    best_loss = float('inf')
    start_time = time.time()
logger.info("🏃 Starting training loop...")
logger.info("=" * 60)
try:
while step < max_steps:
for batch in dl:
x, y = batch
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
                # Forward pass with optional mixed precision; the loss is scaled
                # by 1/grad_accum so the accumulated gradient matches a full batch
                with torch.cuda.amp.autocast(enabled=scaler is not None):
                    logits, loss = model(x, y)
                    loss = loss / grad_accum
                # Backward pass (scaled when using AMP to avoid fp16 underflow)
                if scaler:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                micro += 1
                micros_since_log += 1
                running_loss += loss.item()
# Optimizer step
                if micro % grad_accum == 0:
                    if scaler:
                        scaler.unscale_(optimizer)  # unscale grads before clipping
                        grad_norm_t = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        grad_norm_t = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                        optimizer.step()
                    # Capture the pre-clip norm now; after zero_grad the grads are gone
                    last_grad_norm = float(grad_norm_t)
                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()  # scheduler advances once per optimizer step
                    step += 1
# Logging
if step % log_every == 0:
                        # running_loss sums per-micro losses already divided by
                        # grad_accum; undo the scaling and average over the micros
                        avg_loss = running_loss * grad_accum / max(1, micros_since_log)
                        running_loss = 0.0
                        micros_since_log = 0
                        lr_now = scheduler.get_last_lr()[0]
                        elapsed = time.time() - start_time
# Memory usage
memory = get_memory_usage()
# Calculate throughput
tokens_per_sec = (step * batch_size * seq_len * grad_accum) / elapsed
                        log_msg = (
                            f"Step {step:6d} | Loss: {avg_loss:.4f} | Grad: {last_grad_norm:.3f} | "
                            f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s"
                        )
if memory['allocated'] > 0:
log_msg += f" | Mem: {memory['allocated']:.1f}GB"
logger.info(log_msg)
# Track best loss
if avg_loss < best_loss:
best_loss = avg_loss
logger.info(f"💫 New best loss: {best_loss:.4f}")
# Save checkpoints
if save_every and step % save_every == 0:
ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
save_checkpoint(
                            model, optimizer, scheduler, step, avg_loss,
best_loss, cfg.__dict__, ckpt_path, logger
)
if step >= max_steps:
break
                # Periodically release cached blocks; empty_cache() costs a sync
                # but reduces fragmentation pressure on long runs
                if torch.cuda.is_available() and micro % 100 == 0:
                    torch.cuda.empty_cache()
except KeyboardInterrupt:
logger.info("\n⏹️ Training interrupted by user")
except Exception as e:
logger.error(f"\n❌ Training failed: {e}")
raise
# Final checkpoint
final_path = os.path.join(out_dir, "supernova_final.pt")
    # Average any micro-batches accumulated since the last log; otherwise fall
    # back to the best loss seen
    if micros_since_log > 0:
        final_loss = running_loss * grad_accum / micros_since_log
    else:
        final_loss = best_loss
save_checkpoint(model, optimizer, scheduler, step, final_loss, best_loss, cfg.__dict__, final_path, logger)
# Training summary
total_time = time.time() - start_time
total_tokens = step * batch_size * seq_len * grad_accum
logger.info("\n" + "=" * 60)
logger.info("🎉 TRAINING COMPLETE!")
logger.info(f"📈 Final step: {step:,}")
logger.info(f"🏆 Best loss: {best_loss:.4f}")
logger.info(f"⏱️ Total time: {format_time(total_time)}")
logger.info(f"🔢 Total tokens: {total_tokens:,}")
logger.info(f"⚡ Average throughput: {total_tokens/total_time:.0f} tokens/sec")
logger.info(f"💾 Final checkpoint: {final_path}")
logger.info("=" * 60)
def main():
parser = argparse.ArgumentParser(description="Production Supernova Training")
parser.add_argument("--config", required=True, help="Path to model config")
parser.add_argument("--data-config", required=True, help="Path to data config")
parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps")
parser.add_argument("--save-every", type=int, default=10000, help="Save frequency")
parser.add_argument("--log-every", type=int, default=50, help="Log frequency")
parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient clipping")
parser.add_argument("--no-mixed-precision", action="store_true", help="Disable mixed precision")
args = parser.parse_args()
train_production(
config_path=args.config,
data_config_path=args.data_config,
seq_len=args.seq_len,
batch_size=args.batch_size,
grad_accum=args.grad_accum,
lr=args.lr,
warmup_steps=args.warmup_steps,
max_steps=args.max_steps,
save_every=args.save_every,
log_every=args.log_every,
out_dir=args.out_dir,
seed=args.seed,
max_grad_norm=args.max_grad_norm,
enable_mixed_precision=not args.no_mixed_precision,
)
if __name__ == "__main__":
main()