File size: 4,633 Bytes
8abfb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image
import numpy as np

def _make_test_pattern(device):
    """Return a 1x3x64x64 synthetic image of flat rectangles on a zero background.

    Pixel values lie in [-1, 1]; each rectangle occupies a single channel so the
    pattern is easy to recognise after denoising.
    """
    x = torch.zeros(1, 3, 64, 64, device=device)
    x[0, 0, 20:44, 20:44] = 1.0   # Red square
    x[0, 1, 10:30, 40:60] = -1.0  # Green rectangle
    x[0, 2, 35:50, 10:25] = 0.5   # Blue rectangle
    return x


def _estimate_x0(x_noisy, noise_pred, alpha_bar):
    """Recover a clean-image estimate from a noisy image and predicted noise.

    Inverts x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps for x0 and
    clamps the result to [-1, 1].

    Args:
        x_noisy: Noisy image tensor x_t.
        noise_pred: Model's noise prediction for ``x_noisy``.
        alpha_bar: Cumulative alpha at the timestep, as a Python float.

    NOTE(review): this assumes FrequencyAwareNoise.apply_noise uses the standard
    DDPM forward process above -- confirm against noise_scheduler.py.
    """
    # Plain Python floats keep the scalar arithmetic out of numpy/torch mixing.
    sqrt_ab = float(np.sqrt(alpha_bar))
    sqrt_one_minus_ab = float(np.sqrt(1 - alpha_bar))
    x0 = (x_noisy - sqrt_one_minus_ab * noise_pred) / sqrt_ab
    return torch.clamp(x0, -1, 1)


def test_model_quality(
    checkpoint_path="model_final.pth",
    test_timesteps=(50, 100, 200, 400),
    heavy_timestep=400,
):
    """Test if the model can actually denoise.

    Builds a synthetic pattern and noises it with ``FrequencyAwareNoise`` at
    each timestep in ``test_timesteps``. It then prints the MSE/MAE between the
    model's predicted noise and the true noise, plus the MSE of the
    reconstructed clean image. Finally it saves clean, noisy (at
    ``heavy_timestep``) and denoised PNGs to the working directory and prints a
    pass/fail verdict.

    Args:
        checkpoint_path: File containing the model ``state_dict``.
        test_timesteps: Timesteps at which prediction/reconstruction errors are
            reported.
        heavy_timestep: Timestep used for the saved noisy/denoised images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust
    # (torch>=1.13 accepts weights_only=True for plain state dicts).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    # Assumes the checkpoint is a bare state_dict, as in the original script.
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== TESTING MODEL DENOISING ABILITY ===")

    with torch.no_grad():
        # Clear patterns that should be easy to denoise.
        x_clean = _make_test_pattern(device)
        print(f"Created test pattern with range [{x_clean.min():.3f}, {x_clean.max():.3f}]")

        for t_val in test_timesteps:
            print(f"\n--- Testing at timestep {t_val} ---")

            t_tensor = torch.full((1,), t_val, device=device, dtype=torch.long)

            # Add noise like in training, then ask the model to predict it.
            x_noisy, noise_target = noise_scheduler.apply_noise(x_clean, t_tensor)
            noise_pred = model(x_noisy, t_tensor)

            diff = noise_pred - noise_target
            mse = torch.mean(diff ** 2)
            mae = torch.mean(torch.abs(diff))

            print(f"  Noisy image range: [{x_noisy.min():.3f}, {x_noisy.max():.3f}]")
            print(f"  Target noise range: [{noise_target.min():.3f}, {noise_target.max():.3f}]")
            print(f"  Predicted noise range: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")
            print(f"  MSE: {mse.item():.6f}")
            print(f"  MAE: {mae.item():.6f}")

            # Try to reconstruct the clean image from the predicted noise.
            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x_reconstructed = _estimate_x0(x_noisy, noise_pred, alpha_bar_t)

            reconstruction_error = torch.mean((x_reconstructed - x_clean) ** 2)
            print(f"  Reconstruction MSE: {reconstruction_error.item():.6f}")

            if mse.item() > 1.0:
                print("  ❌ High prediction error - model didn't learn well")
            elif reconstruction_error.item() > 0.5:
                print("  ⚠️  Poor reconstruction - model learned noise but not images")
            else:
                print("  ✅ Good denoising performance")

        print("\n=== SAVING TEST IMAGES ===")

        # Images live in [-1, 1]; save_image expects [0, 1].
        save_image((x_clean + 1) / 2, "test_pattern_clean.png")
        print("Clean test pattern saved to test_pattern_clean.png")

        # One timestep value drives both the noising and the alpha_bar lookup,
        # so the two can no longer drift apart.
        t_heavy = torch.full((1,), heavy_timestep, device=device, dtype=torch.long)
        x_heavy_noisy, _ = noise_scheduler.apply_noise(x_clean, t_heavy)
        # Noisy samples can leave [-1, 1], so clamp after rescaling.
        save_image(torch.clamp((x_heavy_noisy + 1) / 2, 0, 1), "test_pattern_noisy.png")
        print("Noisy test pattern saved to test_pattern_noisy.png")

        # Try to denoise it in a single step.
        noise_pred = model(x_heavy_noisy, t_heavy)
        alpha_bar_t = noise_scheduler.alpha_bars[heavy_timestep].item()
        x_denoised = _estimate_x0(x_heavy_noisy, noise_pred, alpha_bar_t)
        save_image((x_denoised + 1) / 2, "test_pattern_denoised.png")
        print("Denoised test pattern saved to test_pattern_denoised.png")

        final_error = torch.mean((x_denoised - x_clean) ** 2)
        print(f"Final reconstruction error: {final_error.item():.6f}")

        if final_error.item() < 0.1:
            print("✅ Model can denoise simple patterns!")
        else:
            print("❌ Model cannot denoise - training was unsuccessful")

if __name__ == "__main__":
    # Script entry point: run the checkpoint denoising sanity check.
    test_model_quality()