image-denoiser / evaluation /evaluate_sr.py
Kajuto's picture
Initial commit - image denoiser + SR + MLOps stack
8b83582
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import config
from data.dataset import get_sr_dataloaders
from models.autoencoder import SuperResAutoencoder
from training.train import load_checkpoint
from evaluation.evaluate import psnr, ssim
def evaluate_sr(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """Evaluate the trained super-resolution model against a bicubic baseline.

    Loads ``best_sr_<noise_type>.pth`` from ``config.CHECKPOINT_DIR`` into a
    ``SuperResAutoencoder`` and runs it over the SR test loader. It reports
    mean PSNR and SSIM for two outputs, each scored against the clean
    high-resolution target:

    - the model's output;
    - a bicubic upsampling of the noisy low-resolution input.

    The first ``n_samples`` examples are rendered with ``save_comparison``.

    Args:
        noise_type: Noise type whose checkpoint and test split are evaluated.
        n_samples: Number of test examples to include in the comparison
            figure. A value ``<= 0`` skips the figure entirely.

    Returns:
        ``(mean_psnr_sr, mean_ssim_sr)`` for the SR model over the test set.

    Raises:
        FileNotFoundError: If no checkpoint exists for ``noise_type``.
        RuntimeError: If the test loader yields no samples (the metrics would
            otherwise be NaN).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SuperResAutoencoder().to(device)

    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_sr_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded SR checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_sr_dataloaders(noise_type)

    psnr_bicubic, psnr_sr = [], []
    ssim_bicubic, ssim_sr = [], []
    samples = []  # (noisy_lr, bicubic, sr_out, clean_hr), all on CPU

    with torch.no_grad():
        for noisy_lr, clean_hr in test_loader:
            noisy_lr, clean_hr = noisy_lr.to(device), clean_hr.to(device)
            sr_out = model(noisy_lr)
            # NOTE(review): sr_out is scored unclamped while the bicubic
            # baseline below is clamped to [0, 1]. If the model's output range
            # is not already bounded (e.g. by a sigmoid), clamp it too for a
            # like-for-like comparison.

            # Bicubic baseline: upsample the noisy LR input to the target's own
            # spatial size so both methods are always scored on matching shapes.
            bicubic = F.interpolate(
                noisy_lr,
                size=tuple(clean_hr.shape[-2:]),
                mode="bicubic",
                align_corners=False,
            ).clamp(0, 1)

            # One device->host transfer per batch instead of one per sample.
            noisy_cpu = noisy_lr.cpu()
            bicubic_cpu = bicubic.cpu()
            sr_cpu = sr_out.cpu()
            clean_cpu = clean_hr.cpu()

            for b, s, c in zip(bicubic_cpu, sr_cpu, clean_cpu):
                psnr_bicubic.append(psnr(b, c))
                psnr_sr.append(psnr(s, c))
                ssim_bicubic.append(ssim(b, c))
                ssim_sr.append(ssim(s, c))

            # Keep only as many examples as are still needed for the figure.
            remaining = n_samples - len(samples)
            if remaining > 0:
                samples.extend(zip(
                    noisy_cpu[:remaining],
                    bicubic_cpu[:remaining],
                    sr_cpu[:remaining],
                    clean_cpu[:remaining],
                ))

    if not psnr_sr:
        raise RuntimeError(
            f"SR test loader for noise type '{noise_type}' yielded no samples."
        )

    mean_psnr_bicubic, mean_psnr_sr = np.mean(psnr_bicubic), np.mean(psnr_sr)
    mean_ssim_bicubic, mean_ssim_sr = np.mean(ssim_bicubic), np.mean(ssim_sr)

    print(f"\n=== SR Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Bicubic: {mean_psnr_bicubic:.2f} dB | SR Model: {mean_psnr_sr:.2f} dB")
    print(f"SSIM — Bicubic: {mean_ssim_bicubic:.4f} | SR Model: {mean_ssim_sr:.4f}")

    # An empty grid cannot be plotted (plt.subplots rejects 0 columns).
    if samples:
        save_comparison(samples, noise_type)
    return mean_psnr_sr, mean_ssim_sr
def save_comparison(samples, noise_type):
    """Save a 4-row comparison grid of SR results as a PNG in ``config.RESULTS_DIR``.

    Each column is one example. The rows, top to bottom, are:

    - the noisy LR input;
    - the bicubic baseline;
    - the SR model output;
    - the clean HR ground truth.

    The file is written to ``comparison_sr_<noise_type>.png``, creating the
    results directory if needed.

    Args:
        samples: Sequence of ``(noisy_lr, bicubic, sr_out, clean_hr)`` CPU
            tensors. They are displayed via ``permute(1, 2, 0)``, so each is
            assumed to be channel-first (C, H, W) — TODO confirm for
            single-channel data.
        noise_type: Noise type, used in the figure title and output filename.

    An empty ``samples`` sequence produces no figure.
    """
    n = len(samples)
    if n == 0:
        print("No samples to visualize; skipping comparison image.")
        return

    row_labels = ["Noisy LR", "Bicubic", "SR Model", "Ground Truth"]
    # squeeze=False keeps `axes` 2-D even when n == 1; with the default
    # squeeze=True a single column collapses to a 1-D array and
    # axes[row][col] would fail.
    fig, axes = plt.subplots(
        len(row_labels), n, figsize=(2.5 * n, 10), squeeze=False
    )
    for col, (noisy_lr, bicubic, sr_out, clean_hr) in enumerate(samples):
        # The LR input is smaller than the HR images. Nearest-neighbour
        # upsample it to the bicubic image's size so every row shares a
        # scale and the LR pixels stay visible as blocks.
        noisy_display = F.interpolate(
            noisy_lr.unsqueeze(0), size=tuple(bicubic.shape[-2:]), mode="nearest"
        ).squeeze(0)
        imgs = (noisy_display, bicubic, sr_out, clean_hr)
        for row, img in enumerate(imgs):
            ax = axes[row, col]
            ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
            # Hide ticks and spines instead of calling ax.axis("off"), which
            # also hides the axis labels and so would swallow the row titles.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(row_labels[row], fontsize=11, rotation=90, labelpad=5)

    fig.suptitle(f"Super-Resolution Results — {noise_type} noise", fontsize=14)
    fig.tight_layout()
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    out_path = os.path.join(config.RESULTS_DIR, f"comparison_sr_{noise_type}.png")
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image -> {out_path}")
if __name__ == "__main__":
    # Command-line entry point: evaluate one noise type's SR checkpoint and
    # write its comparison figure.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--noise",
        default=config.NOISE_TYPE_EVAL,
        choices=config.NOISE_TYPES,
        help="Noise type",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=8,
        help="Number of samples to visualize",
    )
    opts = cli.parse_args()
    evaluate_sr(noise_type=opts.noise, n_samples=opts.samples)