image-denoiser / evaluation /evaluate.py
Kajuto's picture
Initial commit - image denoiser + SR + MLOps stack
8b83582
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import config
from data.dataset import get_dataloaders
from models.autoencoder import DenoisingAutoencoder
from training.train import load_checkpoint
def psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 1.0) -> float:
    """
    Peak Signal-to-Noise Ratio in dB (higher = better).

    PSNR = 10 * log10(data_range**2 / MSE), where the MSE is taken over every
    element of the two tensors.

    Args:
        pred: Predicted image tensor.
        target: Reference tensor with exactly the same shape as ``pred``.
        data_range: Span of valid pixel values (1.0 for images in [0, 1],
            255 for 8-bit images).

    Returns:
        PSNR in dB as a Python float; ``inf`` when the tensors are identical.

    Raises:
        ValueError: If the shapes differ (broadcasting would otherwise
            silently compute the MSE over the wrong elements).
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"Shape mismatch: pred {tuple(pred.shape)} vs target {tuple(target.shape)}"
        )
    # float64 avoids unsigned wrap-around on integer images (uint8 10 - 20 == 246)
    # and keeps precision when the error is tiny.
    diff = pred.double() - target.double()
    mse = torch.mean(diff ** 2).item()
    if mse == 0:
        return float("inf")
    return float(10 * np.log10(data_range ** 2 / mse))
def ssim(pred: torch.Tensor, target: torch.Tensor, window_size: int = 11,
         data_range: float = 1.0) -> float:
    """
    Structural Similarity Index (higher = better).

    Local statistics use a uniform ``window_size`` x ``window_size`` window.
    Near the image border the window is averaged over in-image pixels only.
    SSIM is computed per channel and the channel scores are averaged.
    Range [-1, 1].

    Args:
        pred: Predicted image, shape (C, H, W) or (H, W).
        target: Reference image with the same shape as ``pred``.
        window_size: Side length of the averaging window; odd values keep the
            SSIM map aligned with the input pixels.
        data_range: Span of valid pixel values (1.0 for [0, 1] images, 255 for
            8-bit); scales the stabilising constants C1 and C2.

    Returns:
        Mean SSIM over all pixels and channels, as a Python float.

    Raises:
        ValueError: If the shapes differ or the input is not 2-D / 3-D.
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"Shape mismatch: pred {tuple(pred.shape)} vs target {tuple(target.shape)}"
        )
    if pred.dim() == 2:
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)
    elif pred.dim() != 3:
        raise ValueError(
            f"Expected a (C, H, W) or (H, W) image, got a {pred.dim()}-D tensor"
        )

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # Treat every channel as an independent one-channel image: (C, 1, H, W),
    # so all channels are processed in a single pooling call.
    p = pred.float().unsqueeze(1)
    t = target.float().unsqueeze(1)

    def local_mean(x: torch.Tensor) -> torch.Tensor:
        # count_include_pad=False: border windows average only real pixels.
        # With zero padding counted, edges would look darker and artificially
        # high-variance, biasing SSIM for every pixel near the border.
        return torch.nn.functional.avg_pool2d(
            x, window_size, stride=1, padding=window_size // 2,
            count_include_pad=False,
        )

    mu_p = local_mean(p)
    mu_t = local_mean(t)
    mu_p_sq = mu_p ** 2
    mu_t_sq = mu_t ** 2
    mu_pt = mu_p * mu_t

    # Local (population) variances and covariance: E[x*y] - E[x]E[y].
    sigma_p = local_mean(p * p) - mu_p_sq
    sigma_t = local_mean(t * t) - mu_t_sq
    sigma_pt = local_mean(p * t) - mu_pt

    num = (2 * mu_pt + C1) * (2 * sigma_pt + C2)
    den = (mu_p_sq + mu_t_sq + C1) * (sigma_p + sigma_t + C2)
    ssim_map = num / den

    # Mean per channel, then average across channels.
    per_channel = ssim_map.flatten(1).mean(dim=1)
    return float(per_channel.mean().item())
def evaluate(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """
    Evaluate the best checkpoint for ``noise_type`` on the test loader.

    The function:
    - loads ``best_{noise_type}.pth`` from ``config.CHECKPOINT_DIR`` into a
      ``DenoisingAutoencoder``;
    - runs the model over the second loader returned by
      ``get_dataloaders(noise_type)``;
    - prints the mean PSNR/SSIM of the noisy inputs and of the denoised
      outputs, both measured against the clean images;
    - saves a comparison figure of the first ``n_samples`` test images.

    Args:
        noise_type: Noise setting; selects both the checkpoint and the data.
        n_samples: Number of test images shown in the comparison figure.
            Values <= 0 skip the figure.

    Returns:
        Tuple ``(mean_psnr_denoised, mean_ssim_denoised)``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If the test loader yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DenoisingAutoencoder().to(device)

    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_dataloaders(noise_type)

    all_psnr_noisy, all_psnr_denoised = [], []
    all_ssim_noisy, all_ssim_denoised = [], []
    sample_noisy, sample_clean, sample_denoised = [], [], []

    with torch.no_grad():
        for noisy, clean in test_loader:
            denoised = model(noisy.to(device))
            # One device->host transfer per batch instead of three per image;
            # the metrics below run on the CPU.
            noisy, clean, denoised = noisy.cpu(), clean.cpu(), denoised.cpu()

            for n, c, d in zip(noisy, clean, denoised):
                all_psnr_noisy.append(psnr(n, c))
                all_psnr_denoised.append(psnr(d, c))
                all_ssim_noisy.append(ssim(n, c))
                all_ssim_denoised.append(ssim(d, c))

            # Keep the first n_samples images (possibly spanning batches).
            n_missing = n_samples - len(sample_noisy)
            if n_missing > 0:
                sample_noisy.extend(noisy[:n_missing])
                sample_clean.extend(clean[:n_missing])
                sample_denoised.extend(denoised[:n_missing])

    if not all_psnr_denoised:
        # np.mean([]) would yield NaN metrics and the figure would fail later.
        raise ValueError(
            f"Test loader for '{noise_type}' noise yielded no images; nothing to evaluate."
        )

    avg_psnr_noisy = np.mean(all_psnr_noisy)
    avg_psnr_denoised = np.mean(all_psnr_denoised)
    avg_ssim_noisy = np.mean(all_ssim_noisy)
    avg_ssim_denoised = np.mean(all_ssim_denoised)

    print(f"\n=== Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Noisy: {avg_psnr_noisy:.2f} dB | Denoised: {avg_psnr_denoised:.2f} dB")
    print(f"SSIM — Noisy: {avg_ssim_noisy:.4f} | Denoised: {avg_ssim_denoised:.4f}")

    # Save visual comparison (skipped when no samples were requested).
    if sample_noisy:
        save_comparison(sample_noisy, sample_clean, sample_denoised, noise_type)

    return avg_psnr_denoised, avg_ssim_denoised
def _show_image(ax, img):
    """Draw one image tensor on ``ax``, hiding ticks/frame but keeping axis labels."""
    img = img.detach().cpu()
    if img.is_floating_point():
        # imshow would clip out-of-range floats anyway (with a warning);
        # clamp explicitly. float() also covers half-precision tensors.
        img = img.float().clamp(0.0, 1.0)
    if img.dim() == 3 and img.shape[0] == 1:
        # imshow rejects (H, W, 1) arrays; show single-channel images as 2-D.
        img = img[0]
    if img.dim() == 2:
        vmax = 1.0 if img.is_floating_point() else 255
        ax.imshow(img.numpy(), cmap="gray", vmin=0, vmax=vmax)
    else:
        # (C, H, W) -> (H, W, C) as expected by imshow for RGB(A) data.
        ax.imshow(img.permute(1, 2, 0).numpy())
    # ax.axis("off") would also hide the row label set via set_ylabel,
    # so only the ticks and the frame are removed.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def save_comparison(noisy_list, clean_list, denoised_list, noise_type):
    """
    Save a 3-row grid (Noisy / Denoised / Clean) to ``config.RESULTS_DIR``.

    Args:
        noisy_list, clean_list, denoised_list: Equal-length sequences of image
            tensors, one column per image. Each tensor is (C, H, W) or (H, W).
            Single-channel images are drawn in grayscale; otherwise the
            channels are passed to imshow as RGB(A).
        noise_type: Used in the figure title and the output filename
            ``comparison_{noise_type}.png``.
    """
    n = len(noisy_list)
    if n == 0:
        print("No samples to visualize; skipping comparison image.")
        return

    titles = ["Noisy", "Denoised", "Clean"]
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(3, n, figsize=(2.5 * n, 7), squeeze=False)
    for col, imgs in enumerate(zip(noisy_list, denoised_list, clean_list)):
        for row, img in enumerate(imgs):
            ax = axes[row][col]
            _show_image(ax, img)
            if col == 0:
                ax.set_ylabel(titles[row], fontsize=12, rotation=90, labelpad=5)

    fig.suptitle(f"Denoising Results — {noise_type} noise", fontsize=14)
    fig.tight_layout()

    out_path = os.path.join(config.RESULTS_DIR, f"comparison_{noise_type}.png")
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image → {out_path}")
if __name__ == "__main__":
    # Command-line entry point: choose the noise type and figure size, then evaluate.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--noise",
        default=config.NOISE_TYPE_EVAL,
        choices=config.NOISE_TYPES,
        help="Noise type",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=8,
        help="Number of samples to visualize",
    )
    opts = cli.parse_args()
    evaluate(noise_type=opts.noise, n_samples=opts.samples)