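"""Training script for the smooth diffusion UNet.

Wires together the model, frequency-aware noise scheduler, data loaders,
diffusion loss, and sampling from the sibling modules, and logs training
progress and checkpoints to a timestamped TensorBoard run directory.
"""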
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from dataloader import get_dataloaders
from loss import diffusion_loss
from sample import sample
def train():
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup logging
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(config.log_dir, timestamp)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    # Initialize components
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
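    # ReduceLROnPlateau halves the LR (factor=0.5) after 5 epochs (patience)
    # without improvement in the monitored value, here the average epoch loss
    # passed to scheduler.step() at the end of each epoch below. Note that
    # `verbose` is deprecated in recent PyTorch releases.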
    train_loader, val_loader = get_dataloaders(config)

    # Training loop
    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, (x0, _) in enumerate(train_loader):
            x0 = x0.to(device)

            # Sample random timesteps
            t = torch.randint(0, config.T, (x0.size(0),), device=device)
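            # One timestep per sample, drawn uniformly from [0, T): the
            # standard diffusion training recipe, where diffusion_loss
            # presumably scores the model's noise prediction at step t.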
            # Compute loss
            loss = diffusion_loss(model, x0, t, noise_scheduler, config)

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients for stability (max_norm raised from 1.0 to 5.0)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            # Track epoch loss for the LR scheduler
            epoch_loss += loss.item()
            num_batches += 1
            # Periodic diagnostics (every 100 batches)
            if batch_idx % 100 == 0:
                # Check for NaN values
                if torch.isnan(loss):
                    print(f"WARNING: NaN loss detected at Epoch {epoch}, Batch {batch_idx}")

                # Compute the total gradient norm across all parameters
                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** 0.5
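                # Note: this loop runs after clipping, so the measured norm is
                # capped at max_norm. clip_grad_norm_ returns the pre-clip
                # norm, so the raw value could be logged instead by capturing
                # the return value at the clipping call above:
                #   total_norm = torch.nn.utils.clip_grad_norm_(
                #       model.parameters(), max_norm=5.0).item()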
                # Debug noise statistics on the first batch of every 5th epoch
                if batch_idx == 0 and epoch % 5 == 0:
                    print(f"Debug for Epoch {epoch}:")
                    noise_scheduler.debug_noise_stats(x0[:1], t[:1])

                # Batch-level logging, kept infrequent (every 500 batches)
                # now that training is stable
                if batch_idx % 500 == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Grad Norm: {total_norm:.4f}")
                    writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
                    writer.add_scalar('Grad_Norm/train', total_norm, epoch * len(train_loader) + batch_idx)
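                    # epoch * len(train_loader) + batch_idx serves as a global
                    # step so the TensorBoard x-axis increases monotonically
                    # across epochs.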
        # Update learning rate based on the average epoch loss
        avg_epoch_loss = epoch_loss / num_batches
        scheduler.step(avg_epoch_loss)

        # Log epoch statistics
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}, LR: {current_lr:.2e}")
        writer.add_scalar('Loss/epoch', avg_epoch_loss, epoch)
        writer.add_scalar('Learning_Rate', current_lr, epoch)
        # Periodically generate samples for qualitative monitoring
        # (NOTE: val_loader is currently unused; no validation loss is computed)
        if epoch % config.sample_every == 0:
            sample(model, noise_scheduler, device, epoch, writer)
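        # sample() receives the current model, noise scheduler, epoch, and
        # writer, so it presumably denoises from pure noise and logs the
        # generated images to TensorBoard (see sample.py).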
        # Save model checkpoints at epoch 30 and every 30 epochs thereafter
        if epoch >= 30 and epoch % 30 == 0:
            checkpoint_path = os.path.join(log_dir, f"model_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_epoch_loss,
                'config': config
            }, checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch}: {checkpoint_path}")
    # Save final weights once training completes
    torch.save(model.state_dict(), os.path.join(log_dir, "model_final.pth"))
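
# Resume sketch (an assumption, not part of this script; names mirror the
# checkpoint dict saved above):
#   ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
#   model.load_state_dict(ckpt['model_state_dict'])
#   optimizer.load_state_dict(ckpt['optimizer_state_dict'])
#   scheduler.load_state_dict(ckpt['scheduler_state_dict'])
#   start_epoch = ckpt['epoch'] + 1
# weights_only=False matters on recent PyTorch (where torch.load defaults to
# weights_only=True) because the checkpoint pickles the Config object
# alongside the tensors.
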
if __name__ == "__main__":
    train()