Supernova25million / test_training.py
algorythmtechnologies's picture
Upload folder using huggingface_hub
8174855 verified
#!/usr/bin/env python3
"""Quick test to validate the training pipeline works."""
import sys
import os
import torch
from torch.utils.data import DataLoader
# Add supernova to path
sys.path.append('.')
from supernova.data import load_sources_from_yaml, TokenChunkDataset
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
def test_training_pipeline(
    config_path='./configs/supernova_25m.json',
    data_sources_path='./configs/data_sources.yaml',
    seq_len=256,
):
    """Smoke-test the Supernova training pipeline on a single batch.

    Loads the model config and tokenizer, then builds a ``TokenChunkDataset``
    and ``DataLoader`` over the configured data sources. It instantiates
    ``SupernovaModel`` on CUDA when available (CPU otherwise) and runs one
    forward and one backward pass. It reports the loss and the global gradient
    L2 norm.

    Args:
        config_path: JSON model config passed to ``ModelConfig.from_json_file``.
            The default is relative to the working directory.
        data_sources_path: YAML file passed to ``load_sources_from_yaml``.
        seq_len: Sequence length passed to ``TokenChunkDataset``.

    Returns:
        True if every step succeeded. False if any step raised; in that case
        the error message and traceback are printed instead of propagated, so
        the ``__main__`` guard can map the result to a process exit code.
    """
    print("Testing Supernova training pipeline...")
    try:
        # Load config and tokenizer
        cfg = ModelConfig.from_json_file(config_path)
        tok = load_gpt2_tokenizer()
        print(f"Config loaded: {cfg.n_layers} layers, {cfg.d_model} d_model")

        # Load data sources
        sources = load_sources_from_yaml(data_sources_path)
        print(f"Data sources loaded: {len(sources)} sources")

        # Create dataset. batch_size=1 and num_workers=0 keep the smoke test
        # cheap and in-process, so errors surface with a direct traceback.
        ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
        dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
        print("Dataset and DataLoader created")

        # Create model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SupernovaModel(cfg).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model created on {device}: {total_params:,} parameters")

        # Test one forward pass
        print("Testing forward pass...")
        model.train()
        # A bare next(iter(dl)) on an empty dataset raises StopIteration with
        # an empty message; fail with an explicit reason instead.
        batch = next(iter(dl), None)
        if batch is None:
            raise RuntimeError("DataLoader produced no batches; check the data sources")
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        print(f"Batch loaded: x.shape={x.shape}, y.shape={y.shape}")
        _, loss = model(x, y)  # logits are not needed for this check
        if loss is None:
            raise RuntimeError("Model returned no loss even though targets were passed")
        # A NaN/inf loss means training cannot proceed, so the pipeline is not ready.
        if not torch.isfinite(loss).all():
            raise RuntimeError(f"Forward pass produced a non-finite loss: {loss.item()}")
        print(f"Forward pass successful: loss={loss.item():.4f}")

        # Test backward pass
        print("Testing backward pass...")
        loss.backward()
        grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        if not grads:
            raise RuntimeError("Backward pass produced no parameter gradients")
        # Global L2 norm over all parameters: the square root of the sum of
        # squared per-parameter norms (the quantity clip_grad_norm_ reports).
        # Summing per-parameter norms would overstate it. Reducing on-device
        # also avoids one host sync per parameter.
        total_norm = torch.stack([g.norm(2).float() for g in grads]).norm(2)
        if not torch.isfinite(total_norm):
            raise RuntimeError(f"Backward pass produced non-finite gradients: norm={total_norm.item()}")
        grad_norm = total_norm.item()
        print(f"Backward pass successful: grad_norm={grad_norm:.4f}")

        print("ALL TESTS PASSED! Training pipeline is ready!")
        return True
    except Exception as e:
        # Top-level boundary of the smoke test: report the failure and signal
        # it through the return value rather than propagating.
        print(f"CRITICAL ERROR in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Run the smoke test and map its boolean result to a process exit status
    # (0 = pipeline ready, 1 = failure), so CI or shell scripts can check it.
    success = test_training_pipeline()
    # Use sys.exit rather than the site-provided exit() builtin, which is meant
    # for the interactive prompt and is missing under `python -S` or frozen builds.
    sys.exit(0 if success else 1)