Kompella Sri Aasrith Souri
committed on
Commit
·
c866f18
1
Parent(s):
76a1306
fixed gradient norm error
Browse files- supernova/train.py +23 -4
- test_gradients.py +128 -0
supernova/train.py
CHANGED
|
@@ -19,12 +19,25 @@ from .data import load_sources_from_yaml, TokenChunkDataset, DataSource
|
|
| 19 |
# ------------------------------
|
| 20 |
# Utilities
|
| 21 |
# ------------------------------
|
| 22 |
-
def compute_grad_norm(model: nn.Module) -> float:
|
| 23 |
total = 0.0
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
if p.grad is not None:
|
|
|
|
| 26 |
param_norm = p.grad.data.float().norm(2).item()
|
| 27 |
total += param_norm * param_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
return math.sqrt(total)
|
| 29 |
|
| 30 |
def atomic_save(obj: Dict[str, Any], path: str):
|
|
@@ -237,6 +250,13 @@ def train(
|
|
| 237 |
scaler.unscale_(optimizer)
|
| 238 |
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
scaler.step(optimizer)
|
| 241 |
scaler.update()
|
| 242 |
optimizer.zero_grad(set_to_none=True)
|
|
@@ -247,8 +267,7 @@ def train(
|
|
| 247 |
step += 1
|
| 248 |
|
| 249 |
# logging
|
| 250 |
-
if step % 50 == 0 and (not ddp or local_rank == 0):
|
| 251 |
-
grad_norm = compute_grad_norm(model if not ddp else model.module)
|
| 252 |
avg_loss = running_loss * grad_accum / 50.0
|
| 253 |
running_loss = 0.0
|
| 254 |
elapsed = time.time() - t0
|
|
|
|
| 19 |
# ------------------------------
|
| 20 |
# Utilities
|
| 21 |
# ------------------------------
|
| 22 |
+
def compute_grad_norm(model: nn.Module, debug: bool = False) -> float:
    """Return the global L2 norm of all parameter gradients in *model*.

    Computed as sqrt(sum over parameters of ||grad||^2), i.e. the same
    quantity reported by ``torch.nn.utils.clip_grad_norm_``.

    Args:
        model: Module whose parameters are inspected.
        debug: When True, print a per-parameter gradient report plus a
            summary line (useful for diagnosing vanishing or missing
            gradients).

    Returns:
        The total gradient L2 norm as a Python float; 0.0 when no
        parameter currently holds a gradient.
    """
    total = 0.0
    grad_count = 0
    param_count = 0

    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # .detach() instead of the deprecated .data attribute; .float()
            # guards against overflow when grads are fp16 under AMP.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:  # Only print non-zero gradients
                print(f" {name}: grad_norm={param_norm:.6f}")
        elif debug:
            print(f" {name}: NO GRAD")

    if debug:
        print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={math.sqrt(total):.6f}")

    return math.sqrt(total)
|
| 42 |
|
| 43 |
def atomic_save(obj: Dict[str, Any], path: str):
|
|
|
|
| 250 |
scaler.unscale_(optimizer)
|
| 251 |
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
|
| 252 |
|
| 253 |
+
# Compute gradient norm BEFORE clearing gradients (only when needed for logging)
|
| 254 |
+
grad_norm = None
|
| 255 |
+
if (step + 1) % 50 == 0 and (not ddp or local_rank == 0):
|
| 256 |
+
# Enable debug mode for first few steps to diagnose gradient issues
|
| 257 |
+
debug_gradients = step < 5
|
| 258 |
+
grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients)
|
| 259 |
+
|
| 260 |
scaler.step(optimizer)
|
| 261 |
scaler.update()
|
| 262 |
optimizer.zero_grad(set_to_none=True)
|
|
|
|
| 267 |
step += 1
|
| 268 |
|
| 269 |
# logging
|
| 270 |
+
if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None:
|
|
|
|
| 271 |
avg_loss = running_loss * grad_accum / 50.0
|
| 272 |
running_loss = 0.0
|
| 273 |
elapsed = time.time() - t0
|
test_gradients.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Diagnostic script to test gradient flow in SupernovaModel
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from supernova.config import ModelConfig
|
| 8 |
+
from supernova.model import SupernovaModel
|
| 9 |
+
from supernova.tokenizer import load_gpt2_tokenizer
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
def compute_grad_norm(model, debug=True):
    """Compute the global L2 gradient norm of *model*, with a printed report.

    Walks every named parameter, accumulates the squared per-parameter
    gradient norms, and always prints a one-line summary of how many
    parameters carry gradients. With ``debug`` enabled it additionally
    prints one line per parameter.

    Returns the total norm as a Python float (0.0 if nothing has a grad).
    """
    squared_sum = 0.0
    with_grad = 0
    seen = 0

    for name, p in model.named_parameters():
        seen += 1
        if p.grad is None:
            # Frozen or disconnected parameter — report and move on.
            if debug:
                print(f" {name}: NO GRAD, requires_grad={p.requires_grad}")
            continue
        with_grad += 1
        param_norm = p.grad.data.float().norm(2).item()
        squared_sum += param_norm * param_norm
        if debug and param_norm > 1e-8:
            print(f" {name}: grad_norm={param_norm:.6f}, shape={p.grad.shape}")

    total_norm = math.sqrt(squared_sum)
    print(f"Gradient stats: {with_grad}/{seen} parameters have gradients, total_norm={total_norm:.6f}")
    return total_norm
|
| 31 |
+
|
| 32 |
+
def test_gradient_flow():
    """End-to-end diagnostic of gradient flow through SupernovaModel.

    Runs four stages and prints what it finds:
      1. a no-grad forward pass (shapes, loss),
      2. a plain backward pass with a per-parameter gradient report,
      3. a mixed-precision (GradScaler) backward pass, comparing the
         gradient norm before and after unscaling,
      4. a parameter inventory (total vs trainable).

    Falls back to a hard-coded minimal config when
    ``supernova_25m_config.json`` is absent.
    """
    print("Testing gradient flow in SupernovaModel...")

    # Load config
    try:
        cfg = ModelConfig.from_json_file("supernova_25m_config.json")
        print(f"Loaded config: {cfg.d_model}d, {cfg.n_layers}L, {cfg.n_heads}H")
    except FileNotFoundError:
        print("Config file not found, creating minimal config...")
        cfg = ModelConfig(
            vocab_size=50257,
            d_model=512,
            n_layers=8,
            n_heads=8,
            mlp_ratio=4,
            dropout=0.1,
            n_positions=1024,
            use_positional_embedding=True,
            final_layer_norm=True
        )

    # Create model
    model = SupernovaModel(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    print(f"Model parameters: {model.num_parameters():,}")
    print(f"Using device: {device}")

    # Create dummy data
    batch_size = 2
    seq_len = 64
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)
    targets = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)

    print(f"Input shape: {input_ids.shape}, Target shape: {targets.shape}")

    # Test 1: Basic forward pass
    print("\n=== Test 1: Basic forward pass ===")
    with torch.no_grad():
        logits, loss = model(input_ids, targets)
    print(f"Logits shape: {logits.shape}")
    print(f"Loss: {loss.item():.6f}")

    # Test 2: Forward pass with gradients
    print("\n=== Test 2: Forward pass with gradients ===")
    model.zero_grad()
    logits, loss = model(input_ids, targets)
    print(f"Loss before backward: {loss.item():.6f}")

    loss.backward()
    print("After backward pass:")
    grad_norm = compute_grad_norm(model, debug=True)

    # Test 3: With mixed precision
    print("\n=== Test 3: With mixed precision ===")
    model.zero_grad()
    # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent torch
    # releases in favor of torch.amp.GradScaler("cuda", ...) — worth migrating.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
    # Fix: build the optimizer ONCE up front. Previously a throwaway AdamW
    # was constructed inline at the unscale_() call, which unscales the
    # gradients of an optimizer nothing else ever uses.
    optimizer = torch.optim.AdamW(model.parameters())

    device_type = 'cuda' if device.type == 'cuda' else 'cpu'
    with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
        logits, loss = model(input_ids, targets)
    print(f"Loss with autocast: {loss.item():.6f}")
    scaled_loss = scaler.scale(loss)
    print(f"Scaled loss: {scaled_loss.item():.6f}")

    scaled_loss.backward()
    print("After scaled backward pass:")
    grad_norm_before_unscale = compute_grad_norm(model, debug=False)
    print(f"Grad norm before unscale: {grad_norm_before_unscale:.6f}")

    scaler.unscale_(optimizer)
    print("After unscaling:")
    grad_norm_after_unscale = compute_grad_norm(model, debug=True)

    # Test 4: Parameter inspection
    print("\n=== Test 4: Parameter inspection ===")
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Check specific layers
    print("\nChecking specific layer parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
            break  # Just show first few
| 126 |
+
|
| 127 |
+
# Allow running the diagnostic directly: python test_gradients.py
if __name__ == "__main__":
    test_gradient_flow()
|