File size: 4,239 Bytes
8b83582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import config
from data.dataset import get_sr_dataloaders
from models.autoencoder import SuperResAutoencoder
from training.train import load_checkpoint
from evaluation.evaluate import psnr, ssim


def evaluate_sr(noise_type: str = config.NOISE_TYPE_EVAL, n_samples: int = 8):
    """Evaluate the trained super-resolution model against a bicubic baseline.

    Loads ``best_sr_{noise_type}.pth`` from ``config.CHECKPOINT_DIR`` into a
    ``SuperResAutoencoder`` and runs it over the SR test loader for
    ``noise_type``. For every test image it computes PSNR and SSIM against the
    clean HR target for two outputs:

    * the model's output, clamped to [0, 1];
    * bicubic upsampling of the noisy LR input, also clamped to [0, 1].

    It prints the mean scores. The first ``n_samples`` test images are passed
    to ``save_comparison`` for visualization.

    Args:
        noise_type: Noise variant. Selects both the checkpoint file and the
            test dataloader.
        n_samples: Number of test images to visualize. A value ``<= 0`` skips
            the comparison figure.

    Returns:
        ``(mean_psnr, mean_ssim)`` of the SR model over the whole test set.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the test loader yields no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SuperResAutoencoder().to(device)
    ckpt_path = os.path.join(config.CHECKPOINT_DIR, f"best_sr_{noise_type}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}. Train first.")
    epoch, _ = load_checkpoint(ckpt_path, model)
    print(f"Loaded SR checkpoint from epoch {epoch}")
    model.eval()

    _, test_loader = get_sr_dataloaders(noise_type)

    psnr_bicubic, psnr_sr = [], []
    ssim_bicubic, ssim_sr = [], []
    samples = []  # (noisy_lr, bicubic, sr_out, clean_hr) CPU tensors for the figure

    with torch.no_grad():
        for noisy_lr, clean_hr in test_loader:
            noisy_lr, clean_hr = noisy_lr.to(device), clean_hr.to(device)
            # Clamp exactly like the bicubic baseline below, so both are
            # scored on the same [0, 1] footing.
            sr_out = model(noisy_lr).clamp(0, 1)

            # Bicubic baseline: upsample the noisy LR input straight to the
            # ground-truth resolution. Bicubic can overshoot, hence the clamp.
            bicubic = F.interpolate(
                noisy_lr, size=tuple(clean_hr.shape[-2:]),
                mode="bicubic", align_corners=False,
            ).clamp(0, 1)

            # One device->host transfer per batch tensor rather than per image.
            noisy_cpu, bicubic_cpu = noisy_lr.cpu(), bicubic.cpu()
            sr_cpu, clean_cpu = sr_out.cpu(), clean_hr.cpu()

            # Iterating a tensor walks its first (batch) dimension.
            for b, s, c in zip(bicubic_cpu, sr_cpu, clean_cpu):
                psnr_bicubic.append(psnr(b, c))
                psnr_sr.append(psnr(s, c))
                ssim_bicubic.append(ssim(b, c))
                ssim_sr.append(ssim(s, c))

            if len(samples) < n_samples:
                n_take = min(n_samples - len(samples), noisy_cpu.size(0))
                # clone() so the kept samples don't pin whole batches in memory.
                samples.extend(
                    (noisy_cpu[i].clone(), bicubic_cpu[i].clone(),
                     sr_cpu[i].clone(), clean_cpu[i].clone())
                    for i in range(n_take)
                )

    if not psnr_sr:
        raise RuntimeError(
            f"SR test loader for noise type {noise_type!r} yielded no images."
        )

    mean_psnr_sr, mean_ssim_sr = np.mean(psnr_sr), np.mean(ssim_sr)
    print(f"\n=== SR Evaluation ({noise_type} noise) ===")
    print(f"PSNR  — Bicubic: {np.mean(psnr_bicubic):.2f} dB | SR Model: {mean_psnr_sr:.2f} dB")
    print(f"SSIM  — Bicubic: {np.mean(ssim_bicubic):.4f}    | SR Model: {mean_ssim_sr:.4f}")

    # An empty grid cannot be plotted, so skip the figure when no samples were
    # requested (n_samples <= 0).
    if samples:
        save_comparison(samples, noise_type)
    return mean_psnr_sr, mean_ssim_sr


def save_comparison(samples, noise_type):
    """Save a grid comparing noisy LR, bicubic, SR output and ground truth.

    Each sample becomes one column. The rows, top to bottom, are:

    1. the noisy LR input, nearest-neighbour upsampled to the ground-truth
       size for display only;
    2. the bicubic baseline;
    3. the SR model output;
    4. the clean HR target.

    The figure is written to
    ``config.RESULTS_DIR/comparison_sr_{noise_type}.png``.

    Args:
        samples: Non-empty sequence of ``(noisy_lr, bicubic, sr_out, clean_hr)``
            CPU tensors, each channels-first ``(C, H, W)``. They are permuted
            to ``(H, W, C)`` for ``imshow``. Single-channel images are drawn
            in grayscale.
        noise_type: Noise variant. Used in the title and the output filename.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    n = len(samples)
    if n == 0:
        raise ValueError("save_comparison needs at least one sample to plot.")
    row_labels = ["Noisy LR", "Bicubic", "SR Model", "Ground Truth"]
    # squeeze=False keeps ``axes`` 2-D even when n == 1. Otherwise subplots()
    # returns a 1-D array and axes[row, col] would fail.
    fig, axes = plt.subplots(4, n, figsize=(2.5 * n, 10), squeeze=False)

    try:
        for col, (noisy_lr, bicubic, sr_out, clean_hr) in enumerate(samples):
            # Upsample the LR input to the ground-truth size so every row lines
            # up. Nearest-neighbour keeps the raw LR pixels visible.
            noisy_display = F.interpolate(
                noisy_lr.unsqueeze(0), size=tuple(clean_hr.shape[-2:]),
                mode="nearest"
            ).squeeze(0)

            imgs = [noisy_display, bicubic, sr_out, clean_hr]
            for row, img in enumerate(imgs):
                ax = axes[row, col]
                arr = img.permute(1, 2, 0).clamp(0, 1).numpy()
                if arr.shape[-1] == 1:
                    # imshow rejects (H, W, 1); draw single-channel as grayscale.
                    ax.imshow(arr[..., 0], cmap="gray", vmin=0, vmax=1)
                else:
                    ax.imshow(arr)
                # ax.axis("off") would also hide the y-label, so strip ticks
                # and frame individually to keep the row labels visible.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if col == 0:
                    ax.set_ylabel(row_labels[row], fontsize=11, rotation=90, labelpad=5)

        fig.suptitle(f"Super-Resolution Results — {noise_type} noise", fontsize=14)
        fig.tight_layout()

        out_path = os.path.join(config.RESULTS_DIR, f"comparison_sr_{noise_type}.png")
        os.makedirs(config.RESULTS_DIR, exist_ok=True)
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
    print(f"Saved comparison image -> {out_path}")


if __name__ == "__main__":
    # Command-line entry point: evaluate one noise variant and save its figure.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--noise",
        default=config.NOISE_TYPE_EVAL,
        choices=config.NOISE_TYPES,
        help="Noise type",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=8,
        help="Number of samples to visualize",
    )
    opts = cli.parse_args()
    evaluate_sr(noise_type=opts.noise, n_samples=opts.samples)