|
|
import torch |
|
|
from model import SmoothDiffusionUNet |
|
|
from noise_scheduler import FrequencyAwareNoise |
|
|
from config import Config |
|
|
from torchvision.utils import save_image, make_grid |
|
|
from dataloader import get_dataloaders |
|
|
import numpy as np |
|
|
|
|
|
_NUM_SAMPLES = 4  # images per diagnostic grid (grids are saved two per row)


def _timestep_batch(x, t_val):
    """Return a long tensor holding ``t_val`` once per sample in ``x``.

    The length is taken from ``x`` rather than a literal so the timestep
    batch always matches the image batch passed to the model.
    """
    return torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)


def _estimate_x0(x_t, predicted_noise, alpha_bar_t, noise_scale=1.0):
    """Compute a clamped clean-image estimate from a noise prediction.

    Evaluates ``(x_t - sqrt(1 - a) * predicted_noise * noise_scale) / sqrt(a)``
    with ``a = alpha_bar_t`` and clamps the result to ``[-1, 1]``.

    Args:
        x_t: Current (noisy) image tensor.
        predicted_noise: Model output for ``x_t``.
        alpha_bar_t: Python float read from ``noise_scheduler.alpha_bars``.
        noise_scale: Fraction of the predicted noise to remove.

    Returns:
        Tensor shaped like ``x_t`` with values in ``[-1, 1]``.
    """
    # float() turns the NumPy scalars into plain Python floats so the tensor
    # arithmetic below does not rely on NumPy/torch mixed-operand dispatch.
    noise_coeff = float(np.sqrt(1 - alpha_bar_t))
    signal_coeff = float(np.sqrt(alpha_bar_t))
    x0 = (x_t - noise_coeff * predicted_noise * noise_scale) / signal_coeff
    return torch.clamp(x0, -1, 1)


def _save_grid(images, path, **grid_kwargs):
    """Map ``images`` from ``[-1, 1]`` to ``[0, 1]`` and save them as a grid.

    Extra keyword arguments are forwarded to ``make_grid``.
    """
    display = torch.clamp((images + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=2, normalize=False, **grid_kwargs)
    save_image(grid, path)


def _iterative_denoise(model, noise_scheduler, x, timesteps, noise_scale):
    """Apply the scaled x0-estimate update once per timestep, in order.

    Must be called with gradients disabled by the caller. Returns the final
    tensor; ``x`` itself is never modified in place.
    """
    for t_val in timesteps:
        predicted_noise = model(x, _timestep_batch(x, t_val))
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = _estimate_x0(x, predicted_noise, alpha_bar_t, noise_scale)
    return x


def diagnose_and_fix():
    """Final diagnosis and alternative sampling approach.

    Loads the weights in ``model_final.pth`` into a ``SmoothDiffusionUNet``,
    then runs three experiments and writes the results as PNG grids in the
    working directory:

    * ``real_training_images.png`` - up to four images from the first
      training batch.
    * ``reconstruction_t{50,200,400}.png`` - the real images are noised with
      ``noise_scheduler.apply_noise`` and a clean estimate is computed from
      the model's noise prediction; the MSE against the originals is printed.
    * ``interpolation_sampling.png`` - four linear blends of the first two
      real images, refined over a short descending timestep list while
      removing 30% of the predicted noise per step.
    * ``minimal_noise_sampling.png`` - low-amplitude Gaussian noise refined
      the same way with 50% of the predicted noise removed per step.

    The display conversion ``(x + 1) / 2`` assumes images are normalised to
    ``[-1, 1]`` - TODO confirm against the dataloader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    checkpoint = torch.load('model_final.pth', map_location=device)
    config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== FINAL DIAGNOSIS ===")

    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    # May hold fewer than _NUM_SAMPLES images if the loader's batch is small;
    # everything below sizes itself from the tensors it actually has.
    real_images = real_batch[:_NUM_SAMPLES].to(device)

    print(f"Real training data range: [{real_images.min():.3f}, {real_images.max():.3f}]")
    print(f"Real training data mean: {real_images.mean():.3f}, std: {real_images.std():.3f}")

    _save_grid(real_images, "real_training_images.png", pad_value=1.0)
    print("Real training images saved to real_training_images.png")

    with torch.no_grad():
        print("\n=== TESTING MODEL ON REAL DATA ===")

        for t_val in (50, 200, 400):
            t_tensor = _timestep_batch(real_images, t_val)

            # noise_target is returned by the scheduler but not needed here.
            x_noisy, _noise_target = noise_scheduler.apply_noise(real_images, t_tensor)
            noise_pred = model(x_noisy, t_tensor)

            alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
            x_reconstructed = _estimate_x0(x_noisy, noise_pred, alpha_bar_t)

            print(f"\nTimestep {t_val}:")
            print(f" Reconstruction error: {torch.mean((x_reconstructed - real_images) ** 2).item():.6f}")

            _save_grid(x_reconstructed, f"reconstruction_t{t_val}.png")
            print(f" Reconstruction saved to reconstruction_t{t_val}.png")

        print("\n=== TRYING INTERPOLATION SAMPLING ===")

        if real_images.shape[0] < 2:
            # With a single image the second slice is empty and the blend
            # would collapse to an empty batch.
            print("Skipping interpolation sampling: need at least two real images")
        else:
            x1 = real_images[0:1]
            x2 = real_images[1:2]

            # Weights of shape (N, 1, 1, 1) broadcast against the (1, C, H, W)
            # endpoints, giving one blended image per weight.
            alphas = torch.linspace(0, 1, _NUM_SAMPLES, device=device).view(-1, 1, 1, 1)
            x_interp = alphas * x1 + (1 - alphas) * x2

            print("Starting from real image interpolation...")
            print(f"Interpolation range: [{x_interp.min():.3f}, {x_interp.max():.3f}]")

            timesteps = [100, 80, 60, 40, 25, 15, 8, 3, 1]
            x = _iterative_denoise(model, noise_scheduler, x_interp, timesteps, 0.3)

            print(f"Interpolation result range: [{x.min():.3f}, {x.max():.3f}]")

            _save_grid(x, "interpolation_sampling.png")
            print("Interpolation sampling saved to interpolation_sampling.png")

        print("\n=== TRYING MINIMAL NOISE SAMPLING ===")

        # Channel count and resolution follow the training images instead of
        # a fixed 3x64x64, so the sample shape matches what the model was fed
        # above.
        x_minimal = torch.randn(_NUM_SAMPLES, *real_images.shape[1:], device=device) * 0.1

        light_timesteps = [50, 30, 15, 5, 1]
        x_minimal = _iterative_denoise(model, noise_scheduler, x_minimal, light_timesteps, 0.5)

        print(f"Minimal noise result range: [{x_minimal.min():.3f}, {x_minimal.max():.3f}]")
        print(f"Minimal noise result std: {x_minimal.std():.3f}")

        _save_grid(x_minimal, "minimal_noise_sampling.png")
        print("Minimal noise sampling saved to minimal_noise_sampling.png")

    print("\n=== SUMMARY ===")
    print("Generated files:")
    print("- real_training_images.png (what we want to achieve)")
    print("- reconstruction_t*.png (model's denoising ability)")
    print("- interpolation_sampling.png (interpolation approach)")
    print("- minimal_noise_sampling.png (light noise approach)")
|
|
|
|
|
# Script entry point: run the diagnosis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    diagnose_and_fix()
|
|
|