Kompella Sri Aasrith Souri committed on
Commit
c866f18
·
1 Parent(s): 76a1306

fixed gradient norm error

Browse files
Files changed (2) hide show
  1. supernova/train.py +23 -4
  2. test_gradients.py +128 -0
supernova/train.py CHANGED
@@ -19,12 +19,25 @@ from .data import load_sources_from_yaml, TokenChunkDataset, DataSource
19
  # ------------------------------
20
  # Utilities
21
  # ------------------------------
22
- def compute_grad_norm(model: nn.Module) -> float:
23
  total = 0.0
24
- for p in model.parameters():
 
 
 
 
25
  if p.grad is not None:
 
26
  param_norm = p.grad.data.float().norm(2).item()
27
  total += param_norm * param_norm
 
 
 
 
 
 
 
 
28
  return math.sqrt(total)
29
 
30
  def atomic_save(obj: Dict[str, Any], path: str):
@@ -237,6 +250,13 @@ def train(
237
  scaler.unscale_(optimizer)
238
  torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
239
 
 
 
 
 
 
 
 
240
  scaler.step(optimizer)
241
  scaler.update()
242
  optimizer.zero_grad(set_to_none=True)
@@ -247,8 +267,7 @@ def train(
247
  step += 1
248
 
249
  # logging
250
- if step % 50 == 0 and (not ddp or local_rank == 0):
251
- grad_norm = compute_grad_norm(model if not ddp else model.module)
252
  avg_loss = running_loss * grad_accum / 50.0
253
  running_loss = 0.0
254
  elapsed = time.time() - t0
 
19
  # ------------------------------
20
  # Utilities
21
  # ------------------------------
22
def compute_grad_norm(model: nn.Module, debug: bool = False) -> float:
    """Return the global L2 norm of all parameter gradients in *model*.

    Args:
        model: module whose parameters are inspected.
        debug: when True, print a per-parameter gradient report plus a
            summary line (useful for diagnosing vanishing/missing grads).

    Returns:
        sqrt(sum of squared per-parameter gradient L2 norms); 0.0 when no
        parameter currently holds a gradient.
    """
    total = 0.0
    grad_count = 0
    param_count = 0

    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # .detach() instead of the deprecated .data accessor.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:  # Only print non-zero gradients
                print(f"  {name}: grad_norm={param_norm:.6f}")
        elif debug:
            print(f"  {name}: NO GRAD")

    if debug:
        print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={math.sqrt(total):.6f}")

    return math.sqrt(total)
42
 
43
  def atomic_save(obj: Dict[str, Any], path: str):
 
250
  scaler.unscale_(optimizer)
251
  torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
252
 
253
+ # Compute gradient norm BEFORE clearing gradients (only when needed for logging)
254
+ grad_norm = None
255
+ if (step + 1) % 50 == 0 and (not ddp or local_rank == 0):
256
+ # Enable debug mode for first few steps to diagnose gradient issues
257
+ debug_gradients = step < 5
258
+ grad_norm = compute_grad_norm(model if not ddp else model.module, debug=debug_gradients)
259
+
260
  scaler.step(optimizer)
261
  scaler.update()
262
  optimizer.zero_grad(set_to_none=True)
 
267
  step += 1
268
 
269
  # logging
270
+ if step % 50 == 0 and (not ddp or local_rank == 0) and grad_norm is not None:
 
271
  avg_loss = running_loss * grad_accum / 50.0
272
  running_loss = 0.0
273
  elapsed = time.time() - t0
test_gradients.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Diagnostic script to test gradient flow in SupernovaModel
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from supernova.config import ModelConfig
8
+ from supernova.model import SupernovaModel
9
+ from supernova.tokenizer import load_gpt2_tokenizer
10
+ import math
11
+
12
def compute_grad_norm(model, debug=True):
    """Return the global L2 norm of all parameter gradients in *model*.

    Diagnostic variant: always prints a summary line; when *debug* is True
    also prints a per-parameter report (norm/shape, or NO GRAD status).

    Args:
        model: torch.nn.Module whose parameters are inspected.
        debug: print per-parameter details when True.

    Returns:
        float: sqrt of the sum of squared per-parameter gradient norms.
    """
    total = 0.0
    grad_count = 0
    param_count = 0

    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # .detach() instead of the deprecated .data accessor.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:
                print(f"  {name}: grad_norm={param_norm:.6f}, shape={p.grad.shape}")
        elif debug:
            print(f"  {name}: NO GRAD, requires_grad={p.requires_grad}")

    total_norm = math.sqrt(total)
    print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={total_norm:.6f}")
    return total_norm
31
+
32
def test_gradient_flow():
    """Smoke-test gradient flow through SupernovaModel.

    Runs four diagnostics and prints the findings (returns nothing):
      1. a no-grad forward pass,
      2. a plain backward pass with a per-parameter gradient report,
      3. a mixed-precision (GradScaler/autocast) backward + unscale,
      4. a parameter inventory.
    """
    print("Testing gradient flow in SupernovaModel...")

    # Load config from disk, falling back to a small built-in config.
    try:
        cfg = ModelConfig.from_json_file("supernova_25m_config.json")
        print(f"Loaded config: {cfg.d_model}d, {cfg.n_layers}L, {cfg.n_heads}H")
    except FileNotFoundError:
        print("Config file not found, creating minimal config...")
        cfg = ModelConfig(
            vocab_size=50257,
            d_model=512,
            n_layers=8,
            n_heads=8,
            mlp_ratio=4,
            dropout=0.1,
            n_positions=1024,
            use_positional_embedding=True,
            final_layer_norm=True
        )

    # Create model
    model = SupernovaModel(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    print(f"Model parameters: {model.num_parameters():,}")
    print(f"Using device: {device}")

    # Create dummy data
    batch_size = 2
    seq_len = 64
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)
    targets = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)

    print(f"Input shape: {input_ids.shape}, Target shape: {targets.shape}")

    # Test 1: Basic forward pass
    print("\n=== Test 1: Basic forward pass ===")
    with torch.no_grad():
        logits, loss = model(input_ids, targets)
        print(f"Logits shape: {logits.shape}")
        print(f"Loss: {loss.item():.6f}")

    # Test 2: Forward pass with gradients
    print("\n=== Test 2: Forward pass with gradients ===")
    model.zero_grad()
    logits, loss = model(input_ids, targets)
    print(f"Loss before backward: {loss.item():.6f}")

    loss.backward()
    print("After backward pass:")
    compute_grad_norm(model, debug=True)

    # Test 3: With mixed precision
    print("\n=== Test 3: With mixed precision ===")
    model.zero_grad()
    # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent torch
    # in favor of torch.amp.GradScaler("cuda", ...) — keep the old spelling
    # for compatibility with the training script.
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
    # BUG FIX: the original passed a freshly constructed, never-used AdamW
    # directly into scaler.unscale_(...). unscale_ must be given the
    # optimizer that owns the parameters, created once and reused.
    optimizer = torch.optim.AdamW(model.parameters())

    device_type = 'cuda' if device.type == 'cuda' else 'cpu'
    with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
        logits, loss = model(input_ids, targets)
        print(f"Loss with autocast: {loss.item():.6f}")
    scaled_loss = scaler.scale(loss)
    print(f"Scaled loss: {scaled_loss.item():.6f}")

    scaled_loss.backward()
    print("After scaled backward pass:")
    grad_norm_before_unscale = compute_grad_norm(model, debug=False)
    print(f"Grad norm before unscale: {grad_norm_before_unscale:.6f}")

    scaler.unscale_(optimizer)
    print("After unscaling:")
    compute_grad_norm(model, debug=True)

    # Test 4: Parameter inspection
    print("\n=== Test 4: Parameter inspection ===")
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Check specific layers
    print("\nChecking specific layer parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
            break  # Just show first few

if __name__ == "__main__":
    test_gradient_flow()