File size: 6,091 Bytes
8abfb97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
from dataloader import get_dataloaders
import numpy as np
def _to_display(images):
    """Map images from the model's [-1, 1] range to [0, 1] for saving."""
    return torch.clamp((images + 1) / 2, 0, 1)


def _save_grid(images, path, **grid_kwargs):
    """Save *images* (values in [-1, 1]) to *path* as a grid, 2 images per row.

    Extra keyword arguments are forwarded to ``make_grid`` (e.g. ``pad_value``).
    """
    grid = make_grid(_to_display(images), nrow=2, normalize=False, **grid_kwargs)
    save_image(grid, path)


def _timestep_batch(t_val, like):
    """Return a long tensor filled with *t_val*, one entry per image in *like*."""
    # Sized from the actual batch instead of a hard-coded 4, so a short
    # first batch no longer produces mismatched timestep/image batch sizes.
    return torch.full((like.shape[0],), t_val, device=like.device, dtype=torch.long)


def _estimate_x0(x_t, predicted_noise, alpha_bar_t, noise_scale=1.0):
    """Estimate the clean image from *x_t* and a noise prediction, clamped to [-1, 1].

    Computes ``(x_t - sqrt(1 - alpha_bar_t) * noise_scale * eps) / sqrt(alpha_bar_t)``.
    A ``noise_scale`` below 1 removes only part of the predicted noise.

    NOTE(review): this inverts x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps,
    i.e. it assumes FrequencyAwareNoise.apply_noise uses the standard DDPM
    forward process with its ``alpha_bars`` -- confirm against the scheduler.
    """
    x0 = (x_t - np.sqrt(1 - alpha_bar_t) * predicted_noise * noise_scale) / np.sqrt(alpha_bar_t)
    return torch.clamp(x0, -1, 1)


def _gentle_denoise(model, noise_scheduler, x, timesteps, noise_scale):
    """Run the heuristic "gentle denoising" loop used by the sampling experiments.

    At each timestep the model's noise prediction is partially removed via
    ``_estimate_x0``, and the clamped estimate is fed straight back in at the
    next timestep. Returns the final tensor.

    NOTE(review): this is not a DDPM/DDIM sampler -- there is no re-noising
    between steps, so every step divides by sqrt(alpha_bar_t) again. Treat the
    output as an exploratory diagnostic, not as a true sample.
    """
    for t_val in timesteps:
        predicted_noise = model(x, _timestep_batch(t_val, x))
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = _estimate_x0(x, predicted_noise, alpha_bar_t, noise_scale)
    return x


def _load_model(config, device, checkpoint_path):
    """Build a SmoothDiffusionUNet on *device*, load *checkpoint_path* into it, set eval mode."""
    # torch.load is pickle-based: only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = SmoothDiffusionUNet(config).to(device)
    # The checkpoint is passed directly to load_state_dict, so it must be a bare state dict.
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def _load_real_images(config, device, num_images):
    """Return up to *num_images* images from the first training batch, moved to *device*.

    Also prints their range/statistics and saves them to
    ``real_training_images.png`` as the visual reference.
    """
    train_loader, _ = get_dataloaders(config)
    # assumes the loader yields (images, labels) pairs -- TODO confirm
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:num_images].to(device)
    print(f"Real training data range: [{real_images.min():.3f}, {real_images.max():.3f}]")
    print(f"Real training data mean: {real_images.mean():.3f}, std: {real_images.std():.3f}")
    _save_grid(real_images, "real_training_images.png", pad_value=1.0)
    print("Real training images saved to real_training_images.png")
    return real_images


def _run_reconstruction_test(model, noise_scheduler, real_images, timesteps=(50, 200, 400)):
    """Noise *real_images* at each timestep, reconstruct them, and report/save the result."""
    print("\n=== TESTING MODEL ON REAL DATA ===")
    for t_val in timesteps:
        t_tensor = _timestep_batch(t_val, real_images)
        # The scheduler's second return value (the noise target) is not needed here.
        x_noisy, _ = noise_scheduler.apply_noise(real_images, t_tensor)
        noise_pred = model(x_noisy, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x_reconstructed = _estimate_x0(x_noisy, noise_pred, alpha_bar_t)
        print(f"\nTimestep {t_val}:")
        print(f" Reconstruction error: {torch.mean((x_reconstructed - real_images) ** 2).item():.6f}")
        _save_grid(x_reconstructed, f"reconstruction_t{t_val}.png")
        print(f" Reconstruction saved to reconstruction_t{t_val}.png")


def _run_interpolation_sampling(model, noise_scheduler, real_images, num_frames=4):
    """Gently denoise linear blends of the first two real images and save the result."""
    print("\n=== TRYING INTERPOLATION SAMPLING ===")
    if real_images.shape[0] < 2:
        # Previously this produced an empty tensor and crashed on .min().
        print("Skipping interpolation sampling: need at least 2 real images.")
        return
    x1 = real_images[0:1]  # First real image
    x2 = real_images[1:2]  # Second real image
    # (frames, 1, 1, 1) weights broadcast against the single-image slices
    # (assumed (1, C, H, W)); frame 0 equals x2 and the last frame equals x1.
    alphas = torch.linspace(0, 1, num_frames, device=real_images.device).view(-1, 1, 1, 1)
    x_interp = alphas * x1 + (1 - alphas) * x2
    print("Starting from real image interpolation...")
    print(f"Interpolation range: [{x_interp.min():.3f}, {x_interp.max():.3f}]")
    x = _gentle_denoise(
        model, noise_scheduler, x_interp,
        timesteps=[100, 80, 60, 40, 25, 15, 8, 3, 1], noise_scale=0.3,
    )
    print(f"Interpolation result range: [{x.min():.3f}, {x.max():.3f}]")
    _save_grid(x, "interpolation_sampling.png")
    print("Interpolation sampling saved to interpolation_sampling.png")


def _run_minimal_noise_sampling(model, noise_scheduler, real_images):
    """Gently denoise low-amplitude Gaussian noise shaped like *real_images* and save it."""
    print("\n=== TRYING MINIMAL NOISE SAMPLING ===")
    # Very light noise around zero; shape follows the real data instead of a
    # hard-coded 4x3x64x64, so it always matches what the model was fed above.
    x_minimal = torch.randn_like(real_images) * 0.1
    x_minimal = _gentle_denoise(
        model, noise_scheduler, x_minimal,
        timesteps=[50, 30, 15, 5, 1], noise_scale=0.5,
    )
    print(f"Minimal noise result range: [{x_minimal.min():.3f}, {x_minimal.max():.3f}]")
    print(f"Minimal noise result std: {x_minimal.std():.3f}")
    _save_grid(x_minimal, "minimal_noise_sampling.png")
    print("Minimal noise sampling saved to minimal_noise_sampling.png")


def _print_summary():
    """List the PNG files written by ``diagnose_and_fix``."""
    print("\n=== SUMMARY ===")
    print("Generated files:")
    print("- real_training_images.png (what we want to achieve)")
    print("- reconstruction_t*.png (model's denoising ability)")
    print("- interpolation_sampling.png (interpolation approach)")
    print("- minimal_noise_sampling.png (light noise approach)")


def diagnose_and_fix(checkpoint_path="model_final.pth", num_images=4):
    """Final diagnosis and alternative sampling approach.

    Loads *checkpoint_path* into a ``SmoothDiffusionUNet`` and then runs four
    steps. Every result is written as a PNG in the current working directory,
    progress is printed to stdout, and nothing is returned.

    1. Save up to *num_images* real training images for visual comparison.
    2. Noise those images at a few timesteps and reconstruct them from the
       model's noise prediction, reporting the MSE to the originals.
    3. Run a heuristic gentle-denoising loop starting from blends of two real
       images.
    4. Run the same kind of loop starting from low-amplitude Gaussian noise.

    Args:
        checkpoint_path: path of the state-dict checkpoint to load.
        num_images: maximum number of real images taken from the first batch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = Config()
    model = _load_model(config, device, checkpoint_path)
    noise_scheduler = FrequencyAwareNoise(config)

    print("=== FINAL DIAGNOSIS ===")
    real_images = _load_real_images(config, device, num_images)

    with torch.no_grad():
        _run_reconstruction_test(model, noise_scheduler, real_images)
        _run_interpolation_sampling(model, noise_scheduler, real_images)
        _run_minimal_noise_sampling(model, noise_scheduler, real_images)

    _print_summary()
if __name__ == "__main__":
    # Only when run as a script (not on import): run every diagnostic step.
    diagnose_and_fix()
|