Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| import torch | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import config | |
| from data.dataset import get_dataloaders | |
| from models.autoencoder import DenoisingAutoencoder | |
| from training.train import load_checkpoint | |
def psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 1.0) -> float:
    """Peak Signal-to-Noise Ratio in dB (higher = better).

    PSNR = 10 * log10(data_range**2 / MSE), where the MSE is taken over every
    element of ``pred`` and ``target``.

    Args:
        pred: Reconstructed tensor.
        target: Reference tensor, normally the same shape as ``pred``.
        data_range: Peak-to-peak range of the values. The default 1.0 matches
            images scaled to [0, 1]; pass 255 for 8-bit images.

    Returns:
        PSNR as a Python float; ``inf`` when the two tensors are identical.
    """
    # Promote to float64 first: integer tensors (e.g. uint8) would wrap around
    # on subtraction (and torch.mean rejects integer dtypes), and float64 keeps
    # the squared error accurate.
    diff = pred.double() - target.double()
    mse = torch.mean(diff ** 2).item()
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
def ssim(pred: torch.Tensor, target: torch.Tensor, window_size: int = 11,
         data_range: float = 1.0) -> float:
    """
    Structural Similarity Index (higher = better).

    Local statistics use a uniform ``window_size`` x ``window_size`` box window
    (the reference SSIM uses a Gaussian window, so values can differ slightly
    from Gaussian-weighted implementations). Computed per-channel and averaged.
    Range [-1, 1].

    Args:
        pred: Image tensor of shape (C, H, W) or (H, W).
        target: Reference image with the same shape as ``pred``.
        window_size: Side length of the box window; an odd value keeps the
            window centred on each pixel.
        data_range: Peak-to-peak range of the pixel values (1.0 for [0, 1]
            images, 255 for 8-bit). Scales the stabilising constants C1 and C2.

    Returns:
        Mean SSIM over all pixels and channels, as a Python float.
    """
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    if pred.dim() == 2:  # single-channel (H, W) image -> (1, H, W)
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)

    # Treat each channel as its own batch item, (C, 1, H, W), so one pooling
    # pass handles all channels. float64 limits cancellation in E[x^2] - E[x]^2.
    p = pred.unsqueeze(1).double()
    t = target.unsqueeze(1).double()

    def _local_mean(x: torch.Tensor) -> torch.Tensor:
        # count_include_pad=False divides each border window by the number of
        # real pixels it covers; the default would average in the zero padding
        # and bias means/variances along the image edges.
        return torch.nn.functional.avg_pool2d(
            x, window_size, stride=1, padding=window_size // 2,
            count_include_pad=False)

    mu_p = _local_mean(p)
    mu_t = _local_mean(t)
    mu_p_sq = mu_p ** 2
    mu_t_sq = mu_t ** 2
    mu_pt = mu_p * mu_t
    sigma_p = _local_mean(p * p) - mu_p_sq
    sigma_t = _local_mean(t * t) - mu_t_sq
    sigma_pt = _local_mean(p * t) - mu_pt

    num = (2 * mu_pt + C1) * (2 * sigma_pt + C2)
    den = (mu_p_sq + mu_t_sq + C1) * (sigma_p + sigma_t + C2)
    ssim_map = num / den  # (C, 1, H', W')

    # Mean per channel, then across channels (matches a per-channel loop).
    per_channel = ssim_map.flatten(start_dim=1).mean(dim=1)
    return float(per_channel.mean().item())
def evaluate(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """Evaluate the best checkpoint for ``noise_type`` on the test split.

    Loads ``best_<noise_type>.pth`` from ``config.CHECKPOINT_DIR`` into a fresh
    ``DenoisingAutoencoder`` and runs it over the test loader from
    ``get_dataloaders(noise_type)``. It then prints the mean PSNR/SSIM of the
    noisy inputs and of the denoised outputs, both measured against the clean
    images. Finally it passes the first ``n_samples`` test images to
    ``save_comparison``.

    Args:
        noise_type: Noise variant; selects both the checkpoint file and the
            test data.
        n_samples: Number of test images to visualise. Values <= 0 skip the
            comparison figure.

    Returns:
        Tuple ``(mean_psnr_denoised, mean_ssim_denoised)``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If the test loader yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DenoisingAutoencoder().to(device)
    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_dataloaders(noise_type)
    all_psnr_noisy, all_psnr_denoised = [], []
    all_ssim_noisy, all_ssim_denoised = [], []
    sample_noisy, sample_clean, sample_denoised = [], [], []

    with torch.no_grad():
        for noisy, clean in test_loader:
            denoised = model(noisy.to(device))
            # Metrics run on CPU: copy each batch once rather than issuing
            # three device->host transfers per image.
            noisy, clean, denoised = noisy.cpu(), clean.cpu(), denoised.cpu()

            for n, c, d in zip(noisy, clean, denoised):
                all_psnr_noisy.append(psnr(n, c))
                all_psnr_denoised.append(psnr(d, c))
                all_ssim_noisy.append(ssim(n, c))
                all_ssim_denoised.append(ssim(d, c))

            # Keep the first n_samples images (in loader order) for the figure.
            remaining = n_samples - len(sample_noisy)
            if remaining > 0:
                sample_noisy.extend(noisy[:remaining])
                sample_clean.extend(clean[:remaining])
                sample_denoised.extend(denoised[:remaining])

    if not all_psnr_noisy:
        # np.mean([]) would silently yield NaN (with a RuntimeWarning).
        raise ValueError(f"Test loader for noise type {noise_type!r} yielded no images.")

    avg_psnr_noisy = np.mean(all_psnr_noisy)
    avg_psnr_denoised = np.mean(all_psnr_denoised)
    avg_ssim_noisy = np.mean(all_ssim_noisy)
    avg_ssim_denoised = np.mean(all_ssim_denoised)
    print(f"\n=== Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Noisy: {avg_psnr_noisy:.2f} dB | Denoised: {avg_psnr_denoised:.2f} dB")
    print(f"SSIM — Noisy: {avg_ssim_noisy:.4f} | Denoised: {avg_ssim_denoised:.4f}")

    # Save visual comparison (a figure needs at least one column).
    if sample_noisy:
        save_comparison(sample_noisy, sample_clean, sample_denoised, noise_type)
    return avg_psnr_denoised, avg_ssim_denoised
def save_comparison(noisy_list, clean_list, denoised_list, noise_type):
    """Save a 3-row grid (Noisy / Denoised / Clean) of sample images.

    Column ``i`` shows ``noisy_list[i]``, ``denoised_list[i]`` and
    ``clean_list[i]``. Each item is an image tensor laid out as (C, H, W);
    values are assumed to lie in [0, 1] (TODO confirm against the dataset /
    model output) and are clamped to that range for display. The figure is
    written to ``<config.RESULTS_DIR>/comparison_<noise_type>.png``.

    Raises:
        ValueError: If ``noisy_list`` is empty (nothing to plot).
    """
    n = len(noisy_list)
    if n == 0:
        raise ValueError("save_comparison needs at least one sample to plot")

    # squeeze=False keeps `axes` 2-D even when n == 1; otherwise subplots()
    # returns a 1-D array and axes[row][col] fails.
    fig, axes = plt.subplots(3, n, figsize=(2.5 * n, 7), squeeze=False)
    titles = ["Noisy", "Denoised", "Clean"]
    for col in range(n):
        imgs = [noisy_list[col], denoised_list[col], clean_list[col]]
        for row, img in enumerate(imgs):
            ax = axes[row, col]
            arr = img.detach().cpu().float().clamp(0.0, 1.0).permute(1, 2, 0).numpy()
            if arr.shape[-1] == 1:
                # Single-channel: draw as grayscale rather than the default colormap.
                ax.imshow(arr[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
            else:
                ax.imshow(arr)
            # Hide ticks and frame but keep the axis visible:
            # ax.axis("off") would also hide the row label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(titles[row], fontsize=12, rotation=90, labelpad=5)

    fig.suptitle(f"Denoising Results — {noise_type} noise", fontsize=14)
    fig.tight_layout()
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    out_path = os.path.join(config.RESULTS_DIR, f"comparison_{noise_type}.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image → {out_path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a trained denoising autoencoder on the test set.")
    parser.add_argument("--noise", default=config.NOISE_TYPE_EVAL,
                        choices=config.NOISE_TYPES, help="Noise type")
    parser.add_argument("--samples", type=int, default=8, help="Number of samples to visualize")
    args = parser.parse_args()
    # The comparison figure needs at least one column; reject 0/negative
    # values up front with a usage message instead of a matplotlib traceback.
    if args.samples < 1:
        parser.error("--samples must be at least 1")
    evaluate(noise_type=args.noise, n_samples=args.samples)