Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import config | |
| from data.dataset import get_sr_dataloaders | |
| from models.autoencoder import SuperResAutoencoder | |
| from training.train import load_checkpoint | |
| from evaluation.evaluate import psnr, ssim | |
def evaluate_sr(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """Evaluate the trained super-resolution model against a bicubic baseline.

    Loads ``best_sr_{noise_type}.pth`` from ``config.CHECKPOINT_DIR``, runs the
    model over the SR test split, and prints mean PSNR / SSIM for both the
    model and a bicubic upsampling of the noisy low-resolution input. The
    first ``n_samples`` test images are rendered via ``save_comparison``.

    Args:
        noise_type: Noise variant; selects both the checkpoint and test data.
        n_samples: Number of test images to visualise; ``<= 0`` skips the figure.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)`` of the SR model over the test set.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If the test loader yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SuperResAutoencoder().to(device)
    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_sr_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded SR checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_sr_dataloaders(noise_type)

    psnr_bicubic, psnr_sr = [], []
    ssim_bicubic, ssim_sr = [], []
    samples = []  # (noisy_lr, bicubic, sr_out, clean_hr), all on the CPU

    with torch.no_grad():
        for noisy_lr, clean_hr in test_loader:
            noisy_lr, clean_hr = noisy_lr.to(device), clean_hr.to(device)
            # Clamp to the same [0, 1] range as the bicubic baseline so both
            # are scored on equal terms (a no-op if the model already bounds
            # its output).
            sr_out = model(noisy_lr).clamp(0, 1)
            # Bicubic baseline: upsample the noisy LR input to the exact
            # ground-truth resolution the metrics compare against.
            bicubic = F.interpolate(
                noisy_lr,
                size=tuple(clean_hr.shape[-2:]),
                mode="bicubic",
                align_corners=False,
            ).clamp(0, 1)

            # One host transfer per batch instead of one per image.
            lr_cpu = noisy_lr.cpu()
            bic_cpu = bicubic.cpu()
            sr_cpu = sr_out.cpu()
            hr_cpu = clean_hr.cpu()

            # Iterating a tensor walks its batch dimension.
            for b, s, c in zip(bic_cpu, sr_cpu, hr_cpu):
                psnr_bicubic.append(psnr(b, c))
                psnr_sr.append(psnr(s, c))
                ssim_bicubic.append(ssim(b, c))
                ssim_sr.append(ssim(s, c))

            # Keep only as many images as are still needed for the figure.
            missing = n_samples - len(samples)
            if missing > 0:
                samples.extend(
                    zip(
                        lr_cpu[:missing],
                        bic_cpu[:missing],
                        sr_cpu[:missing],
                        hr_cpu[:missing],
                    )
                )

    if not psnr_sr:
        raise ValueError(
            f"SR test loader for noise type '{noise_type}' yielded no images."
        )

    mean_psnr_sr = np.mean(psnr_sr)
    mean_ssim_sr = np.mean(ssim_sr)
    print(f"\n=== SR Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Bicubic: {np.mean(psnr_bicubic):.2f} dB | SR Model: {mean_psnr_sr:.2f} dB")
    print(f"SSIM — Bicubic: {np.mean(ssim_bicubic):.4f} | SR Model: {mean_ssim_sr:.4f}")

    # An empty figure cannot be drawn; skip it when no samples were requested.
    if samples:
        save_comparison(samples, noise_type)
    return mean_psnr_sr, mean_ssim_sr
def save_comparison(samples, noise_type):
    """Render a 4-row grid comparing SR inputs, baseline, output and target.

    Rows are:
        - the noisy low-res input, nearest-upsampled to the HR size for display;
        - the bicubic baseline;
        - the SR model output;
        - the ground truth.

    There is one column per sample. The figure is saved to
    ``config.RESULTS_DIR/comparison_sr_{noise_type}.png``.

    Args:
        samples: Sequence of ``(noisy_lr, bicubic, sr_out, clean_hr)`` CHW image
            tensors on the CPU. Values are clamped to [0, 1] for display.
        noise_type: Noise variant, used in the title and output filename.
    """
    n = len(samples)
    if n == 0:
        print("No samples to visualize; skipping comparison image.")
        return

    row_labels = ["Noisy LR", "Bicubic", "SR Model", "Ground Truth"]
    # squeeze=False keeps ``axes`` 2-D even when n == 1, so axes[row, col]
    # always indexes correctly.
    fig, axes = plt.subplots(4, n, figsize=(2.5 * n, 10), squeeze=False)
    for col, (noisy_lr, bicubic, sr_out, clean_hr) in enumerate(samples):
        # Nearest upsampling to the ground-truth size keeps the LR pixels
        # visible while matching the other rows' resolution.
        noisy_display = F.interpolate(
            noisy_lr.unsqueeze(0), size=tuple(clean_hr.shape[-2:]), mode="nearest"
        ).squeeze(0)
        imgs = [noisy_display, bicubic, sr_out, clean_hr]
        for row, img in enumerate(imgs):
            ax = axes[row, col]
            arr = img.permute(1, 2, 0).clamp(0, 1).numpy()
            if arr.shape[-1] == 1:
                # imshow rejects (H, W, 1); draw single-channel images as grayscale.
                ax.imshow(arr[..., 0], cmap="gray", vmin=0, vmax=1)
            else:
                ax.imshow(arr)
            # Hide ticks and frame individually: ax.axis("off") would also
            # hide the row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(row_labels[row], fontsize=11, rotation=90, labelpad=5)

    fig.suptitle(f"Super-Resolution Results — {noise_type} noise", fontsize=14)
    fig.tight_layout()
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    out_path = os.path.join(config.RESULTS_DIR, f"comparison_sr_{noise_type}.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image -> {out_path}")
if __name__ == "__main__":
    # Command-line entry point: choose the evaluation noise type and how many
    # test images to draw in the comparison figure, then run the evaluation.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument(
        "--noise",
        default=config.NOISE_TYPE_EVAL,
        choices=config.NOISE_TYPES,
        help="Noise type",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=8,
        help="Number of samples to visualize",
    )
    opts = cli.parse_args()
    evaluate_sr(noise_type=opts.noise, n_samples=opts.samples)