File size: 2,521 Bytes
8174855 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
#!/usr/bin/env python3
"""Quick smoke test that validates the Supernova training pipeline works.

Builds the tokenizer, data sources, dataset/DataLoader and model, then runs a
single forward and backward pass on one batch. When run as a script it exits
with status 0 on success and 1 on failure.

Run from the repository root: the package import path below and the config
file paths used by the test are resolved relative to the current working
directory, not relative to this file.
"""
import sys
import os
import torch
from torch.utils.data import DataLoader
# Make the local `supernova` package importable. '.' is the current working
# directory (not this script's directory), so this relies on running from the
# repository root.
sys.path.append('.')
from supernova.data import load_sources_from_yaml, TokenChunkDataset
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
def test_training_pipeline():
    """Smoke-test the Supernova training pipeline on a single batch.

    Loads the 25M model config, the GPT-2 tokenizer and the YAML data sources,
    builds a ``TokenChunkDataset`` / ``DataLoader``, moves the model to CUDA if
    available (CPU otherwise), then runs one forward and one backward pass on
    the first batch.

    Config and data paths are relative to the current working directory, so
    this must be run from the repository root.

    Returns:
        True if every stage completed and produced sane values (finite loss,
        at least one parameter gradient); False otherwise. Any exception is
        reported (message plus traceback) rather than raised, so the script
        entry point can map the result to an exit code.

    NOTE(review): because failures are swallowed and a bool is returned, a
    pytest run that collects this ``test_*`` function will not fail when the
    pipeline is broken; pytest ignores the return value.
    """
    print("Testing Supernova training pipeline...")
    try:
        # Load config and tokenizer
        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        tok = load_gpt2_tokenizer()
        print(f"Config loaded: {cfg.n_layers} layers, {cfg.d_model} d_model")

        # Load data sources
        sources = load_sources_from_yaml('./configs/data_sources.yaml')
        print(f"Data sources loaded: {len(sources)} sources")

        # Create dataset
        ds = TokenChunkDataset(tok, sources, seq_len=256, eos_token_id=tok.eos_token_id)
        dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
        print("Dataset and DataLoader created")

        # Create model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SupernovaModel(cfg).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model created on {device}: {total_params:,} parameters")

        # Test one forward pass
        print("Testing forward pass...")
        model.train()
        try:
            batch = next(iter(dl))
        except StopIteration:
            # A bare StopIteration has an empty message; name the real cause.
            raise RuntimeError(
                "DataLoader yielded no batches; check the configured data sources"
            ) from None
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        print(f"Batch loaded: x.shape={x.shape}, y.shape={y.shape}")
        _, loss = model(x, y)
        # A NaN/inf (or missing) loss means the pipeline is not usable for
        # training, even though the call itself did not raise.
        if loss is None or not torch.isfinite(loss).all():
            raise RuntimeError(f"Forward pass produced an invalid loss: {loss}")
        print(f"Forward pass successful: loss={loss.item():.4f}")

        # Test backward pass
        print("Testing backward pass...")
        loss.backward()
        # Global L2 norm over all parameter gradients, sqrt(sum_i ||g_i||^2),
        # i.e. the same quantity gradient clipping uses, not the sum of the
        # per-parameter norms. Computed in float32 with a single host sync.
        grad_norms = [
            p.grad.detach().float().norm(2)
            for p in model.parameters()
            if p.grad is not None
        ]
        if not grad_norms:
            raise RuntimeError("Backward pass produced no parameter gradients")
        grad_norm = torch.stack(grad_norms).norm(2).item()
        print(f"Backward pass successful: grad_norm={grad_norm:.4f}")

        print("ALL TESTS PASSED! Training pipeline is ready!")
        return True
    except Exception as e:
        # Top-level boundary of the smoke test: report and signal failure.
        print(f"CRITICAL ERROR in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Map the smoke-test result to a process exit status (0 = pass, 1 = fail).
    # sys.exit is used instead of the site-injected exit(), which is missing
    # under `python -S` and in some frozen/embedded interpreters.
    success = test_training_pipeline()
    sys.exit(0 if success else 1)