# Grad-CDM / final_diagnosis.py
# Author: nazgut — commit 8abfb97 ("Upload 24 files", verified)
import torch
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from config import Config
from torchvision.utils import save_image, make_grid
from dataloader import get_dataloaders
import numpy as np
def _to_display(images):
    """Map a batch from the model's [-1, 1] range to [0, 1], clamping outliers."""
    return torch.clamp((images + 1) / 2, 0, 1)


def _save_grid(images, path, **grid_kwargs):
    """Save a [-1, 1] image batch to ``path`` as a two-column grid.

    Extra keyword arguments (e.g. ``pad_value``) are forwarded to ``make_grid``.
    """
    grid = make_grid(_to_display(images), nrow=2, normalize=False, **grid_kwargs)
    save_image(grid, path)


def _timesteps_like(x, t_val):
    """Return a long tensor filled with ``t_val``, one entry per item in ``x``'s batch."""
    return torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)


def _predict_x0(noise_scheduler, x_t, noise_pred, t_val, noise_scale=1.0):
    """Estimate a clean image from ``x_t`` and a noise prediction, clamped to [-1, 1].

    Computes (x_t - sqrt(1 - a) * noise_scale * noise_pred) / sqrt(a), where
    ``a = noise_scheduler.alpha_bars[t_val]``. ``noise_scale`` < 1 removes only
    part of the predicted noise.

    NOTE(review): this is the standard DDPM x0 estimate; it assumes
    FrequencyAwareNoise.apply_noise uses the matching forward process — confirm.
    """
    alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
    x0 = (x_t - (1 - alpha_bar_t) ** 0.5 * noise_pred * noise_scale) / alpha_bar_t ** 0.5
    return torch.clamp(x0, -1, 1)


def _gentle_denoise(model, noise_scheduler, x, timesteps, noise_scale):
    """Apply ``_predict_x0`` to ``x`` once per entry of ``timesteps`` and return the result.

    NOTE(review): each step treats the current ``x`` as if it were x_t and
    divides by sqrt(alpha_bar_t) again, which amplifies ``x`` whenever
    alpha_bar_t < 1; values stay in range only because of the clamp. This is an
    exploratory heuristic, not a DDPM/DDIM sampler.
    """
    for t_val in timesteps:
        predicted_noise = model(x, _timesteps_like(x, t_val))
        x = _predict_x0(noise_scheduler, x, predicted_noise, t_val, noise_scale)
    return x


def _test_reconstruction(model, noise_scheduler, real_images):
    """Noise real images at several timesteps, reconstruct them in one step, report MSE and save."""
    print("\n=== TESTING MODEL ON REAL DATA ===")
    for t_val in [50, 200, 400]:
        t_tensor = _timesteps_like(real_images, t_val)
        # Add noise to the real images; the true noise target is not needed here
        x_noisy, _ = noise_scheduler.apply_noise(real_images, t_tensor)
        noise_pred = model(x_noisy, t_tensor)
        x_reconstructed = _predict_x0(noise_scheduler, x_noisy, noise_pred, t_val)
        mse = torch.mean((x_reconstructed - real_images) ** 2).item()
        print(f"\nTimestep {t_val}:")
        print(f" Reconstruction error: {mse:.6f}")
        path = f"reconstruction_t{t_val}.png"
        _save_grid(x_reconstructed, path)
        print(f" Reconstruction saved to {path}")


def _interpolation_sampling(model, noise_scheduler, real_images):
    """Gently denoise 4 linear blends of the first two real images and save the result."""
    print("\n=== TRYING INTERPOLATION SAMPLING ===")
    # Instead of starting from pure noise, blend two real images:
    # alpha = 0 gives the second image, alpha = 1 gives the first.
    x1, x2 = real_images[0:1], real_images[1:2]
    alphas = torch.linspace(0, 1, 4, device=real_images.device).view(-1, 1, 1, 1)
    x_interp = alphas * x1 + (1 - alphas) * x2
    print("Starting from real image interpolation...")
    print(f"Interpolation range: [{x_interp.min():.3f}, {x_interp.max():.3f}]")
    x = _gentle_denoise(
        model, noise_scheduler, x_interp,
        timesteps=[100, 80, 60, 40, 25, 15, 8, 3, 1],
        noise_scale=0.3,
    )
    print(f"Interpolation result range: [{x.min():.3f}, {x.max():.3f}]")
    _save_grid(x, "interpolation_sampling.png")
    print("Interpolation sampling saved to interpolation_sampling.png")


def _minimal_noise_sampling(model, noise_scheduler, real_images):
    """Gently denoise low-amplitude Gaussian noise shaped like ``real_images`` and save it."""
    print("\n=== TRYING MINIMAL NOISE SAMPLING ===")
    # Very light noise around zero, matching the real batch's shape and device
    # (previously hard-coded to 4x3x64x64, which broke for other image sizes).
    x_minimal = torch.randn_like(real_images) * 0.1
    x_minimal = _gentle_denoise(
        model, noise_scheduler, x_minimal,
        timesteps=[50, 30, 15, 5, 1],
        noise_scale=0.5,
    )
    print(f"Minimal noise result range: [{x_minimal.min():.3f}, {x_minimal.max():.3f}]")
    print(f"Minimal noise result std: {x_minimal.std():.3f}")
    _save_grid(x_minimal, "minimal_noise_sampling.png")
    print("Minimal noise sampling saved to minimal_noise_sampling.png")


def diagnose_and_fix(checkpoint_path="model_final.pth", num_images=4):
    """Final diagnosis and alternative sampling approach.

    Loads the model weights, compares the model against a batch of real
    training data and writes these PNG files to the working directory:

    - ``real_training_images.png``: the real images used as reference.
    - ``reconstruction_t{50,200,400}.png``: one-step reconstructions from
      noised real images.
    - ``interpolation_sampling.png``: heuristic denoising started from blends
      of two real images.
    - ``minimal_noise_sampling.png``: heuristic denoising started from
      low-amplitude Gaussian noise.

    Args:
        checkpoint_path: File passed to ``torch.load`` and then to
            ``model.load_state_dict``.
        num_images: Maximum number of images taken from the first training
            batch (the loader's batch size may yield fewer).

    Raises:
        ValueError: If fewer than 2 real images are available, since the
            interpolation stage blends two of them.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = Config()
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()

    print("=== FINAL DIAGNOSIS ===")

    # Load some real training data to compare
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:num_images].to(device)
    if real_images.shape[0] < 2:
        raise ValueError(
            f"Need at least 2 real images for interpolation sampling, got {real_images.shape[0]}"
        )

    print(f"Real training data range: [{real_images.min():.3f}, {real_images.max():.3f}]")
    print(f"Real training data mean: {real_images.mean():.3f}, std: {real_images.std():.3f}")

    # Save real images for comparison (white padding between tiles)
    _save_grid(real_images, "real_training_images.png", pad_value=1.0)
    print("Real training images saved to real_training_images.png")

    # Inference only: no autograd graphs are needed for any stage below
    with torch.no_grad():
        _test_reconstruction(model, noise_scheduler, real_images)
        _interpolation_sampling(model, noise_scheduler, real_images)
        _minimal_noise_sampling(model, noise_scheduler, real_images)

    print("\n=== SUMMARY ===")
    print("Generated files:")
    print("- real_training_images.png (what we want to achieve)")
    print("- reconstruction_t*.png (model's denoising ability)")
    print("- interpolation_sampling.png (interpolation approach)")
    print("- minimal_noise_sampling.png (light noise approach)")


if __name__ == "__main__":
    diagnose_and_fix()