#!/usr/bin/env python3
"""Quick smoke test to validate that the Supernova training pipeline works.

Loads the config, tokenizer, data sources, dataset, and model, then runs a
single forward + backward pass. Exits 0 on success, 1 on any failure.
"""

import sys

import torch
from torch.utils.data import DataLoader

# Add the repo root to sys.path so the local `supernova` package resolves
# when this script is run from the repository root.
sys.path.append('.')

from supernova.data import load_sources_from_yaml, TokenChunkDataset
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.config import ModelConfig
from supernova.model import SupernovaModel


def test_training_pipeline() -> bool:
    """Run one end-to-end forward/backward pass.

    Returns:
        True if every stage (config, data, model, forward, backward)
        completed without raising; False otherwise.
    """
    print("Testing Supernova training pipeline...")
    try:
        # Load config and tokenizer.
        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        tok = load_gpt2_tokenizer()
        print(f"Config loaded: {cfg.n_layers} layers, {cfg.d_model} d_model")

        # Load data sources.
        sources = load_sources_from_yaml('./configs/data_sources.yaml')
        print(f"Data sources loaded: {len(sources)} sources")

        # Create dataset and loader; num_workers=0 keeps the test
        # single-process and deterministic.
        ds = TokenChunkDataset(tok, sources, seq_len=256,
                               eos_token_id=tok.eos_token_id)
        dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
        print("Dataset and DataLoader created")

        # Create model on the best available device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SupernovaModel(cfg).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model created on {device}: {total_params:,} parameters")

        # Forward pass on a single batch.
        print("Testing forward pass...")
        model.train()
        x, y = next(iter(dl))
        x = x.to(device)
        y = y.to(device)
        print(f"Batch loaded: x.shape={x.shape}, y.shape={y.shape}")
        logits, loss = model(x, y)
        print(f"Forward pass successful: loss={loss.item():.4f}")

        # Backward pass: verify gradients flow.
        print("Testing backward pass...")
        loss.backward()
        # NOTE(review): this sums the per-tensor L2 norms rather than computing
        # the true global gradient norm; adequate as a "gradients exist and are
        # finite-ish" smoke signal, but not comparable to clip_grad_norm_ output.
        grad_norm = sum(p.grad.norm().item()
                        for p in model.parameters() if p.grad is not None)
        print(f"Backward pass successful: grad_norm={grad_norm:.4f}")

        print("ALL TESTS PASSED! Training pipeline is ready!")
        return True
    except Exception as e:  # top-level boundary: report, then fail via exit code
        print(f"CRITICAL ERROR in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = test_training_pipeline()
    # sys.exit (not the site-module `exit` helper) is the correct way to set
    # the process exit status from a script.
    sys.exit(0 if success else 1)