|
|
|
|
|
""" |
|
|
Diagnostic script to test gradient flow in SupernovaModel |
|
|
""" |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from supernova.config import ModelConfig |
|
|
from supernova.model import SupernovaModel |
|
|
from supernova.tokenizer import load_gpt2_tokenizer |
|
|
import math |
|
|
|
|
|
def compute_grad_norm(model, debug=True):
    """Compute the global L2 norm of all parameter gradients in *model*.

    Walks every named parameter, accumulates squared per-parameter gradient
    norms, and reports how many parameters actually received a gradient —
    useful for spotting detached sub-modules or frozen weights.

    Args:
        model: a ``torch.nn.Module`` on which ``backward()`` has (typically)
            already been called, so that ``p.grad`` is populated.
        debug: when True, print one line per parameter — either its gradient
            norm or a "NO GRAD" notice.

    Returns:
        float: sqrt of the sum of squared per-parameter gradient L2 norms
        (0.0 when no parameter has a gradient).
    """
    total = 0.0
    grad_count = 0
    param_count = 0

    for name, p in model.named_parameters():
        param_count += 1
        if p.grad is not None:
            grad_count += 1
            # detach() replaces the deprecated .data access; float() upcasts
            # so fp16 gradients do not overflow during the norm computation.
            param_norm = p.grad.detach().float().norm(2).item()
            total += param_norm * param_norm
            if debug and param_norm > 1e-8:
                print(f" {name}: grad_norm={param_norm:.6f}, shape={p.grad.shape}")
        elif debug:
            print(f" {name}: NO GRAD, requires_grad={p.requires_grad}")

    total_norm = math.sqrt(total)
    print(f"Gradient stats: {grad_count}/{param_count} parameters have gradients, total_norm={total_norm:.6f}")
    return total_norm
|
|
|
|
|
def test_gradient_flow():
    """Run a series of gradient-flow diagnostics on SupernovaModel.

    Test 1: plain forward pass under ``no_grad`` — sanity-checks shapes/loss.
    Test 2: forward + backward — reports per-parameter gradient norms.
    Test 3: the same backward under AMP autocast + GradScaler, with norms
        printed before and after unscaling.
    Test 4: parameter count and trainability inspection.
    """
    print("Testing gradient flow in SupernovaModel...")

    # Prefer the on-disk config; fall back to a small built-in one so the
    # diagnostic still runs in a fresh checkout without the JSON file.
    try:
        cfg = ModelConfig.from_json_file("supernova_25m_config.json")
        print(f"Loaded config: {cfg.d_model}d, {cfg.n_layers}L, {cfg.n_heads}H")
    except FileNotFoundError:
        print("Config file not found, creating minimal config...")
        cfg = ModelConfig(
            vocab_size=50257,
            d_model=512,
            n_layers=8,
            n_heads=8,
            mlp_ratio=4,
            dropout=0.1,
            n_positions=1024,
            use_positional_embedding=True,
            final_layer_norm=True,
        )

    model = SupernovaModel(cfg)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    print(f"Model parameters: {model.num_parameters():,}")
    print(f"Using device: {device}")

    # Random token ids are enough to exercise the full graph end to end.
    batch_size = 2
    seq_len = 64
    input_ids = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)
    targets = torch.randint(0, cfg.vocab_size, (batch_size, seq_len), device=device)

    print(f"Input shape: {input_ids.shape}, Target shape: {targets.shape}")

    print("\n=== Test 1: Basic forward pass ===")
    with torch.no_grad():
        logits, loss = model(input_ids, targets)
        print(f"Logits shape: {logits.shape}")
        print(f"Loss: {loss.item():.6f}")

    print("\n=== Test 2: Forward pass with gradients ===")
    model.zero_grad()
    logits, loss = model(input_ids, targets)
    print(f"Loss before backward: {loss.item():.6f}")

    loss.backward()
    print("After backward pass:")
    grad_norm = compute_grad_norm(model, debug=True)

    print("\n=== Test 3: With mixed precision ===")
    model.zero_grad()
    device_type = 'cuda' if device.type == 'cuda' else 'cpu'
    # torch.cuda.amp.GradScaler is deprecated; the torch.amp variant takes the
    # device type explicitly and matches the torch.amp.autocast call below.
    # Scaling is disabled on CPU (scale/unscale_ become no-ops).
    scaler = torch.amp.GradScaler(device_type, enabled=(device.type == "cuda"))

    with torch.amp.autocast(device_type, enabled=(device.type == "cuda")):
        logits, loss = model(input_ids, targets)
        print(f"Loss with autocast: {loss.item():.6f}")
        scaled_loss = scaler.scale(loss)
        print(f"Scaled loss: {scaled_loss.item():.6f}")

    scaled_loss.backward()
    print("After scaled backward pass:")
    grad_norm_before_unscale = compute_grad_norm(model, debug=False)
    print(f"Grad norm before unscale: {grad_norm_before_unscale:.6f}")

    # The original built a throwaway AdamW inline purely to feed unscale_;
    # keep a named optimizer so the intent is explicit. unscale_ only needs
    # the optimizer's param groups to divide the gradients by the scale.
    optimizer = torch.optim.AdamW(model.parameters())
    scaler.unscale_(optimizer)
    print("After unscaling:")
    grad_norm_after_unscale = compute_grad_norm(model, debug=True)

    print("\n=== Test 4: Parameter inspection ===")
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Show one representative trainable parameter to confirm dtype/placement.
    print("\nChecking specific layer parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: shape={param.shape}, dtype={param.dtype}, device={param.device}")
            break
|
|
# Entry point: run the full gradient-flow diagnostic suite when executed
# directly (importing this module triggers nothing).
if __name__ == "__main__":
    test_gradient_flow()