import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

import config
from data.dataset import get_dataloaders
from models.autoencoder import DenoisingAutoencoder
from training.train import load_checkpoint


def psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 1.0) -> float:
    """Peak Signal-to-Noise Ratio in dB (higher = better).

    Args:
        pred: Reconstructed image tensor.
        target: Reference image tensor with the same shape as ``pred``.
        data_range: Peak-to-peak range of valid pixel values. The default of
            1.0 corresponds to images scaled to [0, 1].

    Returns:
        ``10 * log10(data_range**2 / MSE)``, or ``inf`` when the tensors are
        identical (MSE == 0).
    """
    mse = torch.mean((pred - target) ** 2).item()
    if mse == 0:
        return float("inf")
    return float(10 * np.log10(data_range ** 2 / mse))


def _box_mean(x: torch.Tensor, window_size: int) -> torch.Tensor:
    """Local mean of ``x`` over a square ``window_size`` window (stride 1).

    ``count_include_pad=False`` divides each window by the number of real
    pixels it covers, so statistics near the border are not pulled toward zero
    by the implicit zero padding.
    """
    return torch.nn.functional.avg_pool2d(
        x,
        window_size,
        stride=1,
        padding=window_size // 2,
        count_include_pad=False,
    )


def ssim(
    pred: torch.Tensor,
    target: torch.Tensor,
    window_size: int = 11,
    data_range: float = 1.0,
) -> float:
    """Structural Similarity Index (higher = better), range [-1, 1].

    Uses a uniform (box) window. The SSIM map is computed for every channel
    and averaged over all channels and pixels.

    Args:
        pred: Image tensor laid out as (C, H, W), i.e. one element of a batch.
        target: Reference tensor with the same shape as ``pred``.
        window_size: Side length of the square averaging window.
        data_range: Peak-to-peak range of valid pixel values. It sets the
            stabilizing constants C1 = (0.01 * L)^2 and C2 = (0.03 * L)^2.

    Returns:
        Mean SSIM as a Python float.
    """
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    # (C, H, W) -> (1, C, H, W): avg_pool2d filters each channel
    # independently, so all channels are handled in a single call.
    p = pred.unsqueeze(0).float()
    t = target.unsqueeze(0).float()

    mu_p = _box_mean(p, window_size)
    mu_t = _box_mean(t, window_size)
    mu_p_sq = mu_p ** 2
    mu_t_sq = mu_t ** 2
    mu_pt = mu_p * mu_t

    sigma_p = _box_mean(p * p, window_size) - mu_p_sq
    sigma_t = _box_mean(t * t, window_size) - mu_t_sq
    sigma_pt = _box_mean(p * t, window_size) - mu_pt

    num = (2 * mu_pt + c1) * (2 * sigma_pt + c2)
    den = (mu_p_sq + mu_t_sq + c1) * (sigma_p + sigma_t + c2)
    # All channels have the same pixel count, so the global mean equals the
    # mean of the per-channel means.
    return float((num / den).mean().item())


def evaluate(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """Evaluate the best checkpoint for ``noise_type`` on the test split.

    Loads ``best_{noise_type}.pth`` from ``config.CHECKPOINT_DIR`` and runs
    the model over the test loader. It prints mean PSNR/SSIM of the noisy
    inputs and of the denoised outputs, both measured against the clean
    images. It then saves a comparison figure of the first ``n_samples``
    test images.

    Args:
        noise_type: Noise variant; selects both the checkpoint and test data.
        n_samples: Number of test images shown in the comparison figure.

    Returns:
        ``(mean PSNR of denoised images, mean SSIM of denoised images)``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If the test loader yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DenoisingAutoencoder().to(device)

    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_dataloaders(noise_type)

    all_psnr_noisy, all_psnr_denoised = [], []
    all_ssim_noisy, all_ssim_denoised = [], []
    sample_noisy, sample_clean, sample_denoised = [], [], []

    with torch.no_grad():
        for noisy, clean in test_loader:
            denoised = model(noisy.to(device))
            # Metrics run on CPU: transfer each batch once, not once per image.
            noisy_cpu = noisy.cpu()
            clean_cpu = clean.cpu()
            denoised_cpu = denoised.cpu()

            for n, c, d in zip(noisy_cpu, clean_cpu, denoised_cpu):
                all_psnr_noisy.append(psnr(n, c))
                all_psnr_denoised.append(psnr(d, c))
                all_ssim_noisy.append(ssim(n, c))
                all_ssim_denoised.append(ssim(d, c))

            # Keep the first n_samples images (in loader order) for the figure.
            missing = n_samples - len(sample_noisy)
            if missing > 0:
                sample_noisy.extend(noisy_cpu[:missing])
                sample_clean.extend(clean_cpu[:missing])
                sample_denoised.extend(denoised_cpu[:missing])

    if not all_psnr_denoised:
        raise ValueError(f"Test loader for '{noise_type}' noise yielded no images.")

    avg_psnr_noisy = np.mean(all_psnr_noisy)
    avg_psnr_denoised = np.mean(all_psnr_denoised)
    avg_ssim_noisy = np.mean(all_ssim_noisy)
    avg_ssim_denoised = np.mean(all_ssim_denoised)

    print(f"\n=== Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Noisy: {avg_psnr_noisy:.2f} dB | Denoised: {avg_psnr_denoised:.2f} dB")
    print(f"SSIM — Noisy: {avg_ssim_noisy:.4f} | Denoised: {avg_ssim_denoised:.4f}")

    # Save visual comparison
    save_comparison(sample_noisy, sample_clean, sample_denoised, noise_type)

    return avg_psnr_denoised, avg_ssim_denoised


def _show_image(ax, img: torch.Tensor) -> None:
    """Draw one (C, H, W) image on ``ax`` with no ticks or frame.

    The y-axis label stays visible so it can carry the row title.
    """
    # Clamp to [0, 1]: noisy inputs and raw model outputs may leave the valid
    # range, which imshow would otherwise clip with a logged warning.
    img = img.detach().float().clamp(0.0, 1.0)
    if img.shape[0] == 1:
        # Single channel: draw grayscale on a fixed [0, 1] scale rather than
        # the default colormap with per-image normalization.
        ax.imshow(img[0].numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    else:
        ax.imshow(img.permute(1, 2, 0).numpy())
    # ax.axis("off") would also hide the y-label used for the row titles, so
    # only the ticks and the frame are removed.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def save_comparison(noisy_list, clean_list, denoised_list, noise_type):
    """Save a 3-row grid (Noisy / Denoised / Clean) of sample images as a PNG.

    The file is written to ``config.RESULTS_DIR/comparison_{noise_type}.png``.
    The directory is created if needed. Does nothing (besides printing a
    message) when the lists are empty.

    Args:
        noisy_list: Noisy input images, each (C, H, W).
        clean_list: Matching clean images.
        denoised_list: Matching model outputs.
        noise_type: Noise variant, used in the title and the file name.
    """
    n = len(noisy_list)
    if n == 0:
        print("No samples to visualize; skipping comparison image.")
        return

    # squeeze=False keeps `axes` 2-D even when n == 1; otherwise
    # axes[row][col] would index into a single Axes object and fail.
    fig, axes = plt.subplots(3, n, figsize=(2.5 * n, 7), squeeze=False)
    titles = ["Noisy", "Denoised", "Clean"]

    for col in range(n):
        imgs = [noisy_list[col], denoised_list[col], clean_list[col]]
        for row, img in enumerate(imgs):
            ax = axes[row][col]
            _show_image(ax, img)
            if col == 0:
                ax.set_ylabel(titles[row], fontsize=12, rotation=90, labelpad=5)

    fig.suptitle(f"Denoising Results — {noise_type} noise", fontsize=14)
    fig.tight_layout()

    out_path = os.path.join(config.RESULTS_DIR, f"comparison_{noise_type}.png")
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image → {out_path}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--noise", default=config.NOISE_TYPE_EVAL,
                        choices=config.NOISE_TYPES, help="Noise type")
    parser.add_argument("--samples", type=int, default=8,
                        help="Number of samples to visualize")
    args = parser.parse_args()

    evaluate(noise_type=args.noise, n_samples=args.samples)