Spaces:
Sleeping
Sleeping
File size: 5,507 Bytes
8b83582 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import config
from data.dataset import get_dataloaders
from models.autoencoder import DenoisingAutoencoder
from training.train import load_checkpoint
def psnr(pred: torch.Tensor, target: torch.Tensor, data_range: float = 1.0) -> float:
    """
    Peak Signal-to-Noise Ratio in dB (higher = better).

    PSNR = 10 * log10(data_range**2 / MSE), where MSE is the mean squared
    difference over every element of the two tensors.

    Args:
        pred: Predicted image tensor.
        target: Reference image tensor, same shape as ``pred``.
        data_range: Peak-to-peak range of the pixel values (1.0 for images
            scaled to [0, 1], 255 for 8-bit images). Defaults to 1.0.

    Returns:
        PSNR in decibels, or ``inf`` when the tensors are identical.
    """
    # Work in float64 so integer inputs (e.g. uint8) cannot wrap on subtraction
    # and torch.mean receives a floating dtype.
    diff = pred.to(torch.float64) - target.to(torch.float64)
    mse = torch.mean(diff ** 2).item()
    if mse == 0:
        return float("inf")
    return float(10 * np.log10(data_range ** 2 / mse))
def ssim(pred: torch.Tensor, target: torch.Tensor, window_size: int = 11,
         data_range: float = 1.0) -> float:
    """
    Structural Similarity Index (higher = better).

    Local statistics use a uniform ``window_size`` x ``window_size`` window.
    At the image border they are averaged over in-image pixels only. The SSIM
    map is computed for every channel and averaged into one score in [-1, 1].

    Args:
        pred: Predicted image, shaped (C, H, W) or (H, W).
        target: Reference image with the same shape as ``pred``.
        window_size: Side length of the averaging window.
        data_range: Peak-to-peak range of the pixel values. It sets the
            stabilising constants C1 = (0.01 * L)**2 and C2 = (0.03 * L)**2.

    Returns:
        Mean SSIM over all channels and pixels.

    Raises:
        ValueError: If ``pred`` and ``target`` have different shapes.
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"ssim expects equal shapes, got {tuple(pred.shape)} and {tuple(target.shape)}"
        )
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # A bare (H, W) image is treated as a single channel.
    if pred.dim() == 2:
        pred, target = pred.unsqueeze(0), target.unsqueeze(0)
    # (C, H, W) -> (1, C, H, W). avg_pool2d pools each channel independently,
    # so all channels are handled in one call. float64 limits cancellation
    # error in the E[x^2] - E[x]^2 variance estimates below.
    p = pred.unsqueeze(0).to(torch.float64)
    t = target.unsqueeze(0).to(torch.float64)

    def local_mean(x: torch.Tensor) -> torch.Tensor:
        # count_include_pad=False: border windows average only real pixels
        # instead of counting the zero padding as image content.
        return torch.nn.functional.avg_pool2d(
            x, window_size, stride=1, padding=window_size // 2,
            count_include_pad=False,
        )

    mu_p = local_mean(p)
    mu_t = local_mean(t)
    mu_p_sq = mu_p ** 2
    mu_t_sq = mu_t ** 2
    mu_pt = mu_p * mu_t
    sigma_p = local_mean(p * p) - mu_p_sq
    sigma_t = local_mean(t * t) - mu_t_sq
    sigma_pt = local_mean(p * t) - mu_pt

    num = (2 * mu_pt + C1) * (2 * sigma_pt + C2)
    den = (mu_p_sq + mu_t_sq + C1) * (sigma_p + sigma_t + C2)
    # Every channel map has the same size, so the global mean equals the
    # mean of the per-channel means.
    return float(torch.mean(num / den))
def evaluate(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """
    Evaluate the best checkpoint for ``noise_type`` on the test split.

    Loads ``best_{noise_type}.pth`` from ``config.CHECKPOINT_DIR`` into a
    fresh ``DenoisingAutoencoder`` and runs it over every test batch. It
    prints the mean PSNR and SSIM of both the noisy input and the denoised
    output, each measured against the clean image. The first ``n_samples``
    test images are passed to ``save_comparison`` for a visual figure.

    Args:
        noise_type: Noise type used to select the checkpoint and test data.
        n_samples: Number of test images to include in the comparison figure;
            values <= 0 skip the figure.

    Returns:
        Tuple ``(mean_psnr_denoised, mean_ssim_denoised)``.

    Raises:
        FileNotFoundError: If no checkpoint exists for ``noise_type``.
        RuntimeError: If the test loader yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DenoisingAutoencoder().to(device)
    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_dataloaders(noise_type)

    all_psnr_noisy, all_psnr_denoised = [], []
    all_ssim_noisy, all_ssim_denoised = [], []
    sample_noisy, sample_clean, sample_denoised = [], [], []

    with torch.no_grad():
        for noisy, clean in test_loader:
            denoised = model(noisy.to(device))
            # Metrics are computed on CPU: transfer each batch once rather
            # than once per image.
            noisy, clean, denoised = noisy.cpu(), clean.cpu(), denoised.cpu()
            for n, c, d in zip(noisy, clean, denoised):
                all_psnr_noisy.append(psnr(n, c))
                all_psnr_denoised.append(psnr(d, c))
                all_ssim_noisy.append(ssim(n, c))
                all_ssim_denoised.append(ssim(d, c))
            # Collect the first n_samples images (possibly across batches)
            # for the comparison figure.
            take = min(n_samples - len(sample_noisy), noisy.size(0))
            for i in range(max(take, 0)):
                sample_noisy.append(noisy[i])
                sample_clean.append(clean[i])
                sample_denoised.append(denoised[i])

    if not all_psnr_denoised:
        # Without this guard the averages below would silently be NaN.
        raise RuntimeError(f"Test loader for '{noise_type}' noise yielded no images.")

    avg_psnr_noisy = np.mean(all_psnr_noisy)
    avg_psnr_denoised = np.mean(all_psnr_denoised)
    avg_ssim_noisy = np.mean(all_ssim_noisy)
    avg_ssim_denoised = np.mean(all_ssim_denoised)
    print(f"\n=== Evaluation ({noise_type} noise) ===")
    print(f"PSNR — Noisy: {avg_psnr_noisy:.2f} dB | Denoised: {avg_psnr_denoised:.2f} dB")
    print(f"SSIM — Noisy: {avg_ssim_noisy:.4f} | Denoised: {avg_ssim_denoised:.4f}")
    # Save visual comparison (skipped when no samples were requested).
    if sample_noisy:
        save_comparison(sample_noisy, sample_clean, sample_denoised, noise_type)
    return avg_psnr_denoised, avg_ssim_denoised
def save_comparison(noisy_list, clean_list, denoised_list, noise_type):
    """
    Save a noisy / denoised / clean image grid to
    ``config.RESULTS_DIR/comparison_{noise_type}.png``.

    Each column is one sample. The three rows are labelled "Noisy",
    "Denoised" and "Clean".

    Args:
        noisy_list: Noisy input images as (C, H, W) tensors.
        clean_list: Ground-truth images, index-aligned with ``noisy_list``.
        denoised_list: Model outputs, index-aligned with ``noisy_list``.
        noise_type: Noise type name, used in the title and output file name.

    Raises:
        ValueError: If ``noisy_list`` is empty (there is nothing to plot).
    """
    n = len(noisy_list)
    if n == 0:
        raise ValueError("save_comparison needs at least one sample to plot.")
    # squeeze=False keeps ``axes`` 2-D even when n == 1, so axes[row, col]
    # is always valid.
    fig, axes = plt.subplots(3, n, figsize=(2.5 * n, 7), squeeze=False)
    titles = ["Noisy", "Denoised", "Clean"]
    for col in range(n):
        imgs = [noisy_list[col], denoised_list[col], clean_list[col]]
        for row, img in enumerate(imgs):
            ax = axes[row, col]
            arr = img.detach().cpu().permute(1, 2, 0).numpy()
            if arr.shape[-1] == 1:
                # imshow rejects (H, W, 1) arrays; draw single-channel images
                # as grayscale. Assumes pixel values in [0, 1] — TODO confirm
                # against the dataset's normalisation.
                ax.imshow(arr[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
            else:
                ax.imshow(arr)
            # Hide ticks and frame only. ax.axis("off") would also hide the
            # row label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(titles[row], fontsize=12, rotation=90, labelpad=5)
    plt.suptitle(f"Denoising Results — {noise_type} noise", fontsize=14)
    plt.tight_layout()
    out_path = os.path.join(config.RESULTS_DIR, f"comparison_{noise_type}.png")
    os.makedirs(config.RESULTS_DIR, exist_ok=True)
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved comparison image → {out_path}")
if __name__ == "__main__":
    # Command-line entry point: choose the noise type and figure size, then evaluate.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--noise",
        default=config.NOISE_TYPE_EVAL,
        choices=config.NOISE_TYPES,
        help="Noise type",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=8,
        help="Number of samples to visualize",
    )
    cli_args = cli.parse_args()
    evaluate(noise_type=cli_args.noise, n_samples=cli_args.samples)
|