File size: 2,521 Bytes
8174855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
"""Quick test to validate the training pipeline works."""

import sys
import os
import torch
from torch.utils.data import DataLoader

# Add supernova to path
sys.path.append('.')

from supernova.data import load_sources_from_yaml, TokenChunkDataset
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.config import ModelConfig
from supernova.model import SupernovaModel

def test_training_pipeline():
    """Smoke-test the Supernova training pipeline end to end.

    Loads the model config, tokenizer, and data sources, builds the model,
    then runs a single forward and backward pass on one batch to confirm
    every stage wires together.

    Returns:
        bool: True if all stages completed; False on any exception
        (the traceback is printed for diagnosis).
    """
    print("Testing Supernova training pipeline...")

    try:
        # Load config and tokenizer
        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        tok = load_gpt2_tokenizer()
        print(f"Config loaded: {cfg.n_layers} layers, {cfg.d_model} d_model")

        # Load data sources
        sources = load_sources_from_yaml('./configs/data_sources.yaml')
        print(f"Data sources loaded: {len(sources)} sources")

        # Create dataset; num_workers=0 keeps loading deterministic and
        # avoids worker-process startup in a quick smoke test.
        ds = TokenChunkDataset(tok, sources, seq_len=256, eos_token_id=tok.eos_token_id)
        dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
        print("Dataset and DataLoader created")

        # Create model on the best available device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SupernovaModel(cfg).to(device)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model created on {device}: {total_params:,} parameters")

        # Test one forward pass
        print("Testing forward pass...")
        model.train()
        batch = next(iter(dl))
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        print(f"Batch loaded: x.shape={x.shape}, y.shape={y.shape}")

        # NOTE(review): assumes the model returns (logits, loss) when
        # targets are supplied — confirmed by the unpacking below.
        logits, loss = model(x, y)
        print(f"Forward pass successful: loss={loss.item():.4f}")

        # Test backward pass
        print("Testing backward pass...")
        loss.backward()
        # Global L2 gradient norm: sqrt of the sum of squared per-parameter
        # norms (matches torch.nn.utils.clip_grad_norm_'s definition).
        # The previous code summed the norms directly, which overstates the
        # true norm and corresponds to no standard definition.
        grad_sq_sum = sum(p.grad.norm().item() ** 2
                          for p in model.parameters() if p.grad is not None)
        grad_norm = grad_sq_sum ** 0.5
        print(f"Backward pass successful: grad_norm={grad_norm:.4f}")

        print("ALL TESTS PASSED! Training pipeline is ready!")
        return True

    except Exception as e:
        # Broad catch is deliberate: a smoke test must report any failure
        # and return False rather than crash the process mid-diagnostic.
        print(f"CRITICAL ERROR in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    # Propagate the smoke-test outcome as the process exit code so CI can
    # gate on it. Use sys.exit rather than the exit() site builtin: exit()
    # is meant for interactive sessions and is unavailable under `python -S`.
    success = test_training_pipeline()
    sys.exit(0 if success else 1)