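"""Training entry point for the smooth diffusion model.

Trains SmoothDiffusionUNet with the FrequencyAwareNoise scheduler: logs
losses and gradient norms to TensorBoard, samples periodically, and saves
checkpoints every 30 epochs.
"""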
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from dataloader import get_dataloaders
from loss import diffusion_loss
from sample import sample

def train():
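    """Run the training loop: optimize, track gradient norms, reduce the LR
    on loss plateaus, sample periodically, and checkpoint every 30 epochs."""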
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Setup logging
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(config.log_dir, timestamp)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)
    
    # Initialize components
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    # Halve the LR when the average epoch loss plateaus for 5 epochs
    # (the `verbose` argument is deprecated in recent PyTorch releases, so it is omitted here)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    train_loader, val_loader = get_dataloaders(config)  # val_loader is not used in this loop
    
    # Training loop
    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_idx, (x0, _) in enumerate(train_loader):
            x0 = x0.to(device)
            
            # Sample random timesteps
            t = torch.randint(0, config.T, (x0.size(0),), device=device)
            
            # Compute loss
            loss = diffusion_loss(model, x0, t, noise_scheduler, config)
            
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            
            # Clip gradients for stability (max_norm raised from 1.0 to 5.0)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            
            optimizer.step()
            
            # Track epoch loss for scheduler
            epoch_loss += loss.item()
            num_batches += 1
            
            # Periodic logging
            if batch_idx % 100 == 0:
                # Warn on NaN loss (note: the optimizer step has already run,
                # so the weights may be corrupted by the time this fires)
                if torch.isnan(loss):
                    print(f"WARNING: NaN loss detected at Epoch {epoch}, Batch {batch_idx}")
                
                # Global (post-clipping) L2 norm over all parameter gradients
                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        total_norm += p.grad.data.norm(2).item() ** 2
                total_norm = total_norm ** 0.5
                
                # Print noise statistics on the first batch of every 5th epoch
                if batch_idx == 0 and epoch % 5 == 0:
                    print(f"Debug for Epoch {epoch}:")
                    noise_scheduler.debug_noise_stats(x0[:1], t[:1])
                
                # Console output every 500 batches; TensorBoard scalars every 100
                if batch_idx % 500 == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Grad Norm: {total_norm:.4f}")
                writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
                writer.add_scalar('Grad_Norm/train', total_norm, epoch * len(train_loader) + batch_idx)
        
        # Update learning rate based on epoch loss
        avg_epoch_loss = epoch_loss / num_batches
        scheduler.step(avg_epoch_loss)
        
        # Log epoch statistics
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}, LR: {current_lr:.2e}")
        writer.add_scalar('Loss/epoch', avg_epoch_loss, epoch)
        writer.add_scalar('Learning_Rate', current_lr, epoch)
        
        # Periodically generate samples and log them to TensorBoard
        if epoch % config.sample_every == 0:
            sample(model, noise_scheduler, device, epoch, writer)
        
        # Save a full checkpoint every 30 epochs, starting at epoch 30
        if epoch >= 30 and epoch % 30 == 0:
            checkpoint_path = os.path.join(log_dir, f"model_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_epoch_loss,
                'config': config
            }, checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch}: {checkpoint_path}")
            
    # Save the final weights (state_dict only; resume training from the epoch checkpoints above)
    torch.save(model.state_dict(), os.path.join(log_dir, "model_final.pth"))

if __name__ == "__main__":
    train()