File size: 6,091 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
from dataloader import get_dataloaders
import numpy as np

def _save_image_grid(images, path, **grid_kwargs):
    """Write ``images`` to ``path`` as an image grid with two images per row.

    Applies ``(x + 1) / 2`` and clamps to [0, 1] before tiling (i.e. the
    tensors are assumed to be nominally in [-1, 1]); ``normalize=False`` keeps
    ``make_grid`` from rescaling them again. Extra keyword arguments
    (e.g. ``pad_value``) are forwarded to ``make_grid``.
    """
    display = torch.clamp((images + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=2, normalize=False, **grid_kwargs)
    save_image(grid, path)


def _remove_predicted_noise(x, predicted_noise, alpha_bar_t, strength=1.0):
    """Return ``(x - sqrt(1 - a) * noise * strength) / sqrt(a)`` clamped to [-1, 1].

    ``alpha_bar_t`` (``a``) is a Python float. With ``strength=1.0`` this is
    the full one-step estimate of the clean image from the model output;
    smaller values subtract only a fraction of the predicted noise.
    """
    x = (x - np.sqrt(1 - alpha_bar_t) * predicted_noise * strength) / np.sqrt(alpha_bar_t)
    return torch.clamp(x, -1, 1)


def _iterative_denoise(model, noise_scheduler, x, timesteps, strength):
    """Apply one damped denoising update per entry of ``timesteps``, in order.

    The timestep tensor is sized from ``x`` itself so it always matches the
    batch that is passed to the model.
    """
    for t_val in timesteps:
        t_tensor = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = _remove_predicted_noise(x, predicted_noise, alpha_bar_t, strength)
    return x


def diagnose_and_fix(checkpoint_path: str = "model_final.pth", num_samples: int = 4) -> None:
    """Run a set of sampling diagnostics on a trained model and save the results as PNGs.

    Steps, all executed without gradient tracking:

    1. Save a grid of real training images for reference.
    2. Noise the real images at timesteps 50, 200 and 400, reconstruct them in
       a single step from the model output, and print the mean squared error.
    3. Start from a linear interpolation between the first two real images and
       apply a short, damped denoising schedule.
    4. Start from low-amplitude Gaussian noise (std 0.1) and apply an even
       shorter damped schedule.

    Files written to the current directory: ``real_training_images.png``,
    ``reconstruction_t{50,200,400}.png``, ``interpolation_sampling.png`` and
    ``minimal_noise_sampling.png``.

    Args:
        checkpoint_path: Path of the state dict passed to ``model.load_state_dict``.
        num_samples: Number of real images to take from the first training
            batch, and number of images generated by steps 3 and 4.

    Raises:
        ValueError: If fewer than two real images are available, since the
            interpolation step needs two distinct endpoints.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== FINAL DIAGNOSIS ===")

    # Load some real training data to compare
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:num_samples].to(device)

    # The loader's batch may be smaller than num_samples; everything below is
    # sized from what was actually obtained rather than from a fixed constant.
    batch_size = real_images.shape[0]
    if batch_size < 2:
        raise ValueError(
            f"Need at least 2 real images to interpolate between, got {batch_size}"
        )

    print(f"Real training data range: [{real_images.min():.3f}, {real_images.max():.3f}]")
    print(f"Real training data mean: {real_images.mean():.3f}, std: {real_images.std():.3f}")

    # Save real images for comparison
    _save_image_grid(real_images, "real_training_images.png", pad_value=1.0)
    print("Real training images saved to real_training_images.png")

    with torch.no_grad():
        # Test model on real data at different noise levels
        print("\n=== TESTING MODEL ON REAL DATA ===")

        for t_val in [50, 200, 400]:
            t_tensor = torch.full((batch_size,), t_val, device=device, dtype=torch.long)

            # Add noise to the real images, then ask the model what was added
            x_noisy, _ = noise_scheduler.apply_noise(real_images, t_tensor)
            noise_pred = model(x_noisy, t_tensor)

            # One-step reconstruction of the clean images from the prediction
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x_reconstructed = _remove_predicted_noise(x_noisy, noise_pred, alpha_bar_t)

            print(f"\nTimestep {t_val}:")
            print(f"  Reconstruction error: {torch.mean((x_reconstructed - real_images) ** 2).item():.6f}")

            _save_image_grid(x_reconstructed, f"reconstruction_t{t_val}.png")
            print(f"  Reconstruction saved to reconstruction_t{t_val}.png")

        print("\n=== TRYING INTERPOLATION SAMPLING ===")

        # Instead of starting from pure noise, interpolate between real images
        x1 = real_images[0:1]  # First real image
        x2 = real_images[1:2]  # Second real image

        # Weights have shape (N, 1, 1, 1) and broadcast against the (1, C, H, W)
        # endpoints; weight 0 yields x2 and weight 1 yields x1.
        weights = torch.linspace(0, 1, num_samples, device=device).view(-1, 1, 1, 1)
        x_interp = weights * x1 + (1 - weights) * x2

        print("Starting from real image interpolation...")
        print(f"Interpolation range: [{x_interp.min():.3f}, {x_interp.max():.3f}]")

        # Apply light denoising starting from these interpolated real images;
        # strength 0.3 removes only 30% of the predicted noise per step.
        x = _iterative_denoise(
            model, noise_scheduler, x_interp, [100, 80, 60, 40, 25, 15, 8, 3, 1], 0.3
        )

        print(f"Interpolation result range: [{x.min():.3f}, {x.max():.3f}]")

        _save_image_grid(x, "interpolation_sampling.png")
        print("Interpolation sampling saved to interpolation_sampling.png")

        print("\n=== TRYING MINIMAL NOISE SAMPLING ===")

        # Start from very light noise around zero, shaped like the real data
        # (channels and resolution) so it always matches what the model expects.
        x_minimal = torch.randn(num_samples, *real_images.shape[1:], device=device) * 0.1

        # Apply just a few denoising steps at half strength
        x_minimal = _iterative_denoise(
            model, noise_scheduler, x_minimal, [50, 30, 15, 5, 1], 0.5
        )

        print(f"Minimal noise result range: [{x_minimal.min():.3f}, {x_minimal.max():.3f}]")
        print(f"Minimal noise result std: {x_minimal.std():.3f}")

        _save_image_grid(x_minimal, "minimal_noise_sampling.png")
        print("Minimal noise sampling saved to minimal_noise_sampling.png")

        print("\n=== SUMMARY ===")
        print("Generated files:")
        print("- real_training_images.png (what we want to achieve)")
        print("- reconstruction_t*.png (model's denoising ability)")
        print("- interpolation_sampling.png (interpolation approach)")
        print("- minimal_noise_sampling.png (light noise approach)")

# Run the diagnosis only when this file is executed as a script, so importing
# the module does not load the checkpoint or write any image files.
if __name__ == "__main__":
    diagnose_and_fix()