File size: 4,521 Bytes
c866f18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
"""
Diagnostic script to test gradient flow in SupernovaModel
"""
import torch
import torch.nn.functional as F
from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
import math

def compute_grad_norm(model, debug=True):
    """Compute the global L2 norm over all parameter gradients of *model*.

    Args:
        model: any ``torch.nn.Module`` whose gradients should be inspected.
        debug: when True, print one line per parameter — either its gradient
            norm (if non-negligible) or a "NO GRAD" marker for parameters
            that received no gradient at all.

    Returns:
        float: ``sqrt(sum(||p.grad||_2 ** 2))`` across all parameters;
        0.0 when no parameter has a gradient.
    """
    total = 0.0
    grad_count = 0
    param_count = 0

    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # detach() instead of the deprecated .data; cast to float32 so
            # half-precision gradients don't lose accuracy in the norm.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:
                print(f"  {name}: grad_norm={param_norm:.6f}, shape={p.grad.shape}")
        elif debug:
            # A missing .grad usually means the parameter was unused in the
            # forward pass or the graph was detached before backward().
            print(f"  {name}: NO GRAD, requires_grad={p.requires_grad}")

    total_norm = math.sqrt(total)
    print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={total_norm:.6f}")
    return total_norm

def test_gradient_flow():
    """Run gradient-flow diagnostics against SupernovaModel.

    Four stages: (1) a no-grad forward pass, (2) a plain forward/backward
    with per-parameter gradient reporting, (3) the same under mixed
    precision with a GradScaler (scaled and unscaled norms), and (4) a
    parameter count / inspection pass. Prints results; returns nothing.
    """
    print("Testing gradient flow in SupernovaModel...")

    # Load config; fall back to a small hand-written config so the script
    # still runs standalone without the JSON file.
    try:
        cfg = ModelConfig.from_json_file("supernova_25m_config.json")
        print(f"Loaded config: {cfg.d_model}d, {cfg.n_layers}L, {cfg.n_heads}H")
    except FileNotFoundError:
        print("Config file not found, creating minimal config...")
        cfg = ModelConfig(
            vocab_size=50257,
            d_model=512,
            n_layers=8,
            n_heads=8,
            mlp_ratio=4,
            dropout=0.1,
            n_positions=1024,
            use_positional_embedding=True,
            final_layer_norm=True
        )

    # Create model on the best available device, in train mode so dropout
    # etc. behave as they do during real training.
    model = SupernovaModel(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    print(f"Model parameters: {model.num_parameters():,}")
    print(f"Using device: {device}")

    # Dummy token batch: shapes only matter, values are random vocab ids.
    batch_size = 2
    seq_len = 64
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)
    targets = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)

    print(f"Input shape: {input_ids.shape}, Target shape: {targets.shape}")

    # Test 1: Basic forward pass (no autograd graph built).
    print("\n=== Test 1: Basic forward pass ===")
    with torch.no_grad():
        logits, loss = model(input_ids, targets)
        print(f"Logits shape: {logits.shape}")
        print(f"Loss: {loss.item():.6f}")

    # Test 2: Forward pass with gradients.
    print("\n=== Test 2: Forward pass with gradients ===")
    model.zero_grad()
    logits, loss = model(input_ids, targets)
    print(f"Loss before backward: {loss.item():.6f}")

    loss.backward()
    print("After backward pass:")
    grad_norm = compute_grad_norm(model, debug=True)

    # Test 3: With mixed precision.
    print("\n=== Test 3: With mixed precision ===")
    model.zero_grad()
    use_amp = device.type == "cuda"
    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler
    # and matches the torch.amp.autocast API already used below.
    scaler = torch.amp.GradScaler(device.type, enabled=use_amp)
    # scaler.unscale_() must receive the optimizer that owns these params;
    # build it once here instead of constructing a throwaway AdamW inline.
    optimizer = torch.optim.AdamW(model.parameters())

    with torch.amp.autocast(device.type, enabled=use_amp):
        logits, loss = model(input_ids, targets)
        print(f"Loss with autocast: {loss.item():.6f}")
        scaled_loss = scaler.scale(loss)
        print(f"Scaled loss: {scaled_loss.item():.6f}")

    scaled_loss.backward()
    print("After scaled backward pass:")
    grad_norm_before_unscale = compute_grad_norm(model, debug=False)
    print(f"Grad norm before unscale: {grad_norm_before_unscale:.6f}")

    scaler.unscale_(optimizer)
    print("After unscaling:")
    grad_norm_after_unscale = compute_grad_norm(model, debug=True)

    # Test 4: Parameter inspection.
    print("\n=== Test 4: Parameter inspection ===")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Show one trainable parameter's metadata as a sanity check on
    # dtype/device placement.
    print("\nChecking specific layer parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
            break  # only the first trainable parameter

if __name__ == "__main__":
    # Entry point: run the full gradient-flow diagnostic when executed as
    # a script (no-op when this module is imported).
    test_gradient_flow()