|
|
|
|
|
"""Quick test to validate the training pipeline works."""
|
|
|
|
|
|
import sys
|
|
|
import os
|
|
|
import torch
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
|
# Make the local `supernova` package importable when this script is run
# directly from the repository root (no installation required).
sys.path.append('.')
|
|
|
|
|
|
from supernova.data import load_sources_from_yaml, TokenChunkDataset
|
|
|
from supernova.tokenizer import load_gpt2_tokenizer
|
|
|
from supernova.config import ModelConfig
|
|
|
from supernova.model import SupernovaModel
|
|
|
|
|
|
def test_training_pipeline():
    """Smoke-test the Supernova training pipeline end to end.

    Loads the model config, tokenizer and data sources, builds the dataset,
    dataloader and model, then runs a single forward and backward pass on
    one batch, printing progress at each stage.

    Returns:
        bool: True if every stage completed without raising, False otherwise.
    """
    print("Testing Supernova training pipeline...")

    try:
        # Config + tokenizer.
        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        tok = load_gpt2_tokenizer()
        print(f"Config loaded: {cfg.n_layers} layers, {cfg.d_model} d_model")

        # Data sources (paths/weights declared in YAML).
        sources = load_sources_from_yaml('./configs/data_sources.yaml')
        print(f"Data sources loaded: {len(sources)} sources")

        # Dataset + dataloader; num_workers=0 keeps the test single-process.
        ds = TokenChunkDataset(tok, sources, seq_len=256, eos_token_id=tok.eos_token_id)
        dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
        print("Dataset and DataLoader created")

        # Model on GPU when available, else CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SupernovaModel(cfg).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model created on {device}: {total_params:,} parameters")

        # One forward pass on a single batch.
        print("Testing forward pass...")
        model.train()
        batch = next(iter(dl))
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        print(f"Batch loaded: x.shape={x.shape}, y.shape={y.shape}")

        logits, loss = model(x, y)
        print(f"Forward pass successful: loss={loss.item():.4f}")

        # One backward pass; report the global gradient L2 norm.
        print("Testing backward pass...")
        loss.backward()
        # Fix: the global grad norm is sqrt(sum of squared per-parameter
        # norms); the previous plain sum of norms overstated the magnitude.
        grad_norm = sum(
            p.grad.norm().item() ** 2
            for p in model.parameters()
            if p.grad is not None
        ) ** 0.5
        print(f"Backward pass successful: grad_norm={grad_norm:.4f}")

        print("ALL TESTS PASSED! Training pipeline is ready!")
        return True

    except Exception as e:
        # Broad catch is intentional for a smoke test: report the failure
        # with a traceback and signal it via the return value.
        print(f"CRITICAL ERROR in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
if __name__ == "__main__":
    success = test_training_pipeline()
    # Use sys.exit: the bare `exit()` builtin is an interactive helper
    # injected by the `site` module and is not guaranteed in all
    # environments (e.g. `python -S` or frozen executables).
    sys.exit(0 if success else 1)