"""Evaluate the super-resolution autoencoder against a bicubic baseline.

Loads the best SR checkpoint for a noise type, reports mean PSNR/SSIM for
bicubic upsampling and for the model over the test loader, and saves a grid
of example reconstructions to ``config.RESULTS_DIR``.
"""
import sys
import os

# Make the project root importable when this file is run as a script.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt

import config
from data.dataset import get_sr_dataloaders
from models.autoencoder import SuperResAutoencoder
from training.train import load_checkpoint
from evaluation.evaluate import psnr, ssim


def evaluate_sr(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """Compare the SR model with a bicubic-upsampling baseline on the test set.

    Args:
        noise_type: Noise variant; selects both the checkpoint
            (``best_sr_{noise_type}.pth``) and the dataloaders.
        n_samples: Number of test examples to include in the saved
            comparison figure.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)`` of the SR model over the test set.

    Raises:
        FileNotFoundError: If the checkpoint for ``noise_type`` does not exist.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SuperResAutoencoder().to(device)

    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_sr_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded SR checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_sr_dataloaders(noise_type)
    hr_size = (config.SR_OUTPUT_SIZE, config.SR_OUTPUT_SIZE)

    psnr_bicubic, psnr_sr = [], []
    ssim_bicubic, ssim_sr = [], []
    samples = []  # (noisy_lr, bicubic, sr_out, clean_hr) CPU tensors

    with torch.no_grad():
        for noisy_lr, clean_hr in test_loader:
            noisy_lr, clean_hr = noisy_lr.to(device), clean_hr.to(device)
            # Clamp to the same [0, 1] range as the baseline so both methods
            # are scored under identical conditions.
            sr_out = model(noisy_lr).clamp(0, 1)
            # Bicubic baseline: upsample noisy LR to HR size
            bicubic = F.interpolate(
                noisy_lr, size=hr_size, mode="bicubic", align_corners=False
            ).clamp(0, 1)

            # Transfer each batch to the CPU once rather than once per image.
            lr_cpu = noisy_lr.cpu()
            bic_cpu = bicubic.cpu()
            sr_cpu = sr_out.cpu()
            hr_cpu = clean_hr.cpu()

            for b, s, c in zip(bic_cpu, sr_cpu, hr_cpu):
                psnr_bicubic.append(psnr(b, c))
                psnr_sr.append(psnr(s, c))
                ssim_bicubic.append(ssim(b, c))
                ssim_sr.append(ssim(s, c))

            # Keep the first `n_samples` examples seen for the figure.
            n_missing = n_samples - len(samples)
            if n_missing > 0:
                samples.extend(zip(
                    lr_cpu[:n_missing],
                    bic_cpu[:n_missing],
                    sr_cpu[:n_missing],
                    hr_cpu[:n_missing],
                ))

    mean_psnr_sr = np.mean(psnr_sr)
    mean_ssim_sr = np.mean(ssim_sr)
    print(f"\n=== SR Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Bicubic: {np.mean(psnr_bicubic):.2f} dB | SR Model: {mean_psnr_sr:.2f} dB")
    print(f"SSIM — Bicubic: {np.mean(ssim_bicubic):.4f} | SR Model: {mean_ssim_sr:.4f}")

    save_comparison(samples, noise_type)
    return mean_psnr_sr, mean_ssim_sr


def save_comparison(samples, noise_type):
    """Save a 4-row grid (noisy LR, bicubic, SR model, ground truth) to disk.

    Args:
        samples: Sequence of ``(noisy_lr, bicubic, sr_out, clean_hr)`` CPU
            tensors, each channel-first ``(C, H, W)``; one column per tuple.
        noise_type: Used in the figure title and the output filename.
    """
    n = len(samples)
    if n == 0:
        # plt.subplots cannot build a grid with zero columns.
        print("No samples to visualize; skipping comparison image.")
        return

    row_labels = ["Noisy LR", "Bicubic", "SR Model", "Ground Truth"]
    # squeeze=False keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(4, n, figsize=(2.5 * n, 10), squeeze=False)
    hr_size = (config.SR_OUTPUT_SIZE, config.SR_OUTPUT_SIZE)

    for col, (noisy_lr, bicubic, sr_out, clean_hr) in enumerate(samples):
        # The LR input is lower resolution than the other images; nearest-
        # neighbour upsampling displays it at HR size without smoothing pixels.
        noisy_display = F.interpolate(
            noisy_lr.unsqueeze(0), size=hr_size, mode="nearest"
        ).squeeze(0)

        for row, img in enumerate((noisy_display, bicubic, sr_out, clean_hr)):
            ax = axes[row, col]
            ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
            # Hide ticks and frame individually: ax.axis("off") would also
            # hide the y-label that carries the row title.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(row_labels[row], fontsize=11, rotation=90, labelpad=5)

    fig.suptitle(f"Super-Resolution Results — {noise_type} noise", fontsize=14)
    fig.tight_layout()

    out_path = os.path.join(config.RESULTS_DIR, f"comparison_sr_{noise_type}.png")
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image -> {out_path}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--noise", default=config.NOISE_TYPE_EVAL,
                        choices=config.NOISE_TYPES, help="Noise type")
    parser.add_argument("--samples", type=int, default=8,
                        help="Number of samples to visualize")
    args = parser.parse_args()
    evaluate_sr(noise_type=args.noise, n_samples=args.samples)