|
|
|
|
|
"""
|
|
|
Enhanced training script with comprehensive logging and validation.
|
|
|
"""
|
|
|
|
|
|
import argparse
|
|
|
import json
|
|
|
import math
|
|
|
import os
|
|
|
import sys
|
|
|
import time
|
|
|
from typing import Optional
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.utils.data import DataLoader
|
|
|
from transformers import get_cosine_schedule_with_warmup
|
|
|
|
|
|
|
|
|
sys.path.append('.')
|
|
|
|
|
|
from supernova.config import ModelConfig
|
|
|
from supernova.model import SupernovaModel
|
|
|
from supernova.tokenizer import load_gpt2_tokenizer
|
|
|
from supernova.data import load_sources_from_yaml, TokenChunkDataset
|
|
|
|
|
|
|
|
|
def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm over all parameter gradients of *model*.

    Parameters whose ``.grad`` is ``None`` (frozen, or not yet touched by
    backward) are skipped; with no gradients at all the result is 0.0.
    """
    squared_total = sum(
        p.grad.data.float().norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    return math.sqrt(squared_total)
|
|
|
|
|
|
|
|
|
def format_time(seconds):
    """Render a duration in seconds as a compact human-readable string.

    Durations under a minute are shown with one decimal ("45.0s"), under an
    hour as minutes and seconds ("2m5s"), otherwise as hours and minutes
    ("1h2m").
    """
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    if hours >= 1:
        return f"{hours:.0f}h{minutes:.0f}m"
    if minutes >= 1:
        return f"{minutes:.0f}m{secs:.0f}s"
    return f"{seconds:.1f}s"
|
|
|
|
|
|
|
|
|
def _save_checkpoint(path, model, optimizer, scheduler, cfg, step, loss, best_loss):
    """Serialize model/optimizer/scheduler state plus run metadata to *path*."""
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "config": cfg.__dict__,
        "step": step,
        "loss": loss,
        "best_loss": best_loss,
    }, path)


def train_enhanced(
    config_path: str,
    data_config_path: str,
    seq_len: int = 1024,
    batch_size: int = 16,
    grad_accum: int = 8,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    max_steps: int = 100_000,
    save_every: int = 10_000,
    out_dir: str = "checkpoints",
    seed: int = 42,
):
    """Train a Supernova model with logging, checkpointing and cosine LR schedule.

    Args:
        config_path: Path to the model config JSON file.
        data_config_path: Path to the YAML file listing data sources.
        seq_len: Token sequence length per sample.
        batch_size: Micro-batch size per forward pass.
        grad_accum: Micro-batches accumulated per optimizer step.
        lr: Peak learning rate for AdamW.
        warmup_steps: Linear warmup steps for the cosine schedule.
        max_steps: Total optimizer steps to train for.
        save_every: Checkpoint frequency in optimizer steps (0/falsy disables).
        out_dir: Directory for checkpoints (created if missing).
        seed: Random seed for reproducibility.
    """
    print("π SUPERNOVA ENHANCED TRAINING")
    print("=" * 60)

    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"π± Device: {device}")
    print(f"π± Seed: {seed}")

    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)
    print(f"βοΈ Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")

    tok = load_gpt2_tokenizer()
    assert tok.vocab_size == cfg.vocab_size
    print(f"π€ Tokenizer: {tok.vocab_size:,} vocab size")

    model = SupernovaModel(cfg).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    assert total_params == 25_000_000
    print(f"π§ Model: {total_params:,} parameters (EXACT)")

    print("π Loading datasets...")
    sources = load_sources_from_yaml(data_config_path)
    print(f"π Data sources: {len(sources)} sources loaded")
    for i, source in enumerate(sources):
        print(f" {i+1}. {source.name} (weight: {source.weight})")

    ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    print(f"π DataLoader: batch_size={batch_size}, seq_len={seq_len}")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps,
    )

    print(f"π― Training setup:")
    print(f" Learning rate: {lr}")
    print(f" Warmup steps: {warmup_steps:,}")
    print(f" Max steps: {max_steps:,}")
    print(f" Grad accumulation: {grad_accum}")
    print(f" Save every: {save_every:,} steps")

    os.makedirs(out_dir, exist_ok=True)
    print(f"πΎ Output dir: {out_dir}")
    print()

    model.train()
    step = 0
    micro = 0
    running_loss = 0.0
    best_loss = float('inf')
    # Initialised so checkpoint saving never hits an unbound name when
    # save_every is not a multiple of the 10-step log interval (the old code
    # raised NameError on `avg_loss` in that case).
    avg_loss = float('inf')
    grad_norm = 0.0
    start_time = time.time()
    last_log_time = start_time

    print("π Starting training...")
    print("=" * 60)

    try:
        while step < max_steps:
            for batch in dl:
                x, y = batch
                x = x.to(device)
                y = y.to(device)

                logits, loss = model(x, y)
                loss = loss / grad_accum
                loss.backward()

                micro += 1
                running_loss += loss.item()

                if micro % grad_accum == 0:
                    # Measure the gradient norm BEFORE step/zero_grad: with
                    # zero_grad(set_to_none=True) every p.grad is None
                    # afterwards, so the old post-step measurement at log time
                    # always reported 0.0. Only measure on steps that will log.
                    if (step + 1) % 10 == 0:
                        grad_norm = compute_grad_norm(model)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()
                    step += 1

                    if step % 10 == 0:
                        # running_loss already holds per-step mean losses (each
                        # micro loss was pre-divided by grad_accum), so the
                        # average over 10 steps is running_loss / 10 — the old
                        # "* grad_accum" inflated the report by that factor.
                        avg_loss = running_loss / 10.0
                        running_loss = 0.0
                        total_elapsed = time.time() - start_time
                        lr_now = scheduler.get_last_lr()[0]

                        # Throughput over the whole run so far.
                        tokens_per_step = batch_size * seq_len * grad_accum
                        tokens_per_sec = step * tokens_per_step / total_elapsed

                        print(f"Step {step:5d} | Loss: {avg_loss:.4f} | Grad: {grad_norm:.3f} | "
                              f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s | {format_time(total_elapsed)}")

                        if avg_loss < best_loss:
                            best_loss = avg_loss
                            print(f"π« New best loss: {best_loss:.4f}")

                        last_log_time = time.time()

                    if save_every and step % save_every == 0:
                        ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
                        _save_checkpoint(ckpt_path, model, optimizer, scheduler,
                                         cfg, step, avg_loss, best_loss)
                        print(f"πΎ Saved checkpoint: {ckpt_path}")

                if step >= max_steps:
                    break

    except KeyboardInterrupt:
        print("\nβΉοΈ Training interrupted by user")
    except Exception as e:
        print(f"\nβ Training failed with error: {e}")
        raise

    # Record the last logged average loss. The old expression
    # `running_loss * grad_accum / max(1, micro % grad_accum)` divided a
    # partial accumulator by 1 whenever micro landed on a step boundary.
    final_path = os.path.join(out_dir, "supernova_final.pt")
    _save_checkpoint(final_path, model, optimizer, scheduler,
                     cfg, step, avg_loss, best_loss)

    total_time = time.time() - start_time
    print("\n" + "=" * 60)
    print("π TRAINING COMPLETE!")
    print(f"π Final step: {step:,}")
    print(f"π Best loss: {best_loss:.4f}")
    print(f"β±οΈ Total time: {format_time(total_time)}")
    print(f"πΎ Final checkpoint: {final_path}")
    print("=" * 60)
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse flags and launch enhanced training."""
    parser = argparse.ArgumentParser(description="Enhanced Supernova Training")
    parser.add_argument("--config", required=True, help="Path to model config")
    parser.add_argument("--data-config", required=True, help="Path to data config")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Save frequency")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    opts = vars(parser.parse_args())

    train_enhanced(
        config_path=opts["config"],
        data_config_path=opts["data_config"],
        seq_len=opts["seq_len"],
        batch_size=opts["batch_size"],
        grad_accum=opts["grad_accum"],
        lr=opts["lr"],
        warmup_steps=opts["warmup_steps"],
        max_steps=opts["max_steps"],
        save_every=opts["save_every"],
        out_dir=opts["out_dir"],
        seed=opts["seed"],
    )
|
|
|
|
|
|
|
|
|
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()