File size: 3,131 Bytes
dcd2bd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import glob
import math
import os

import torch
from PIL import Image
from torchvision import transforms

from train_network import UnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork


# Standard deviation of the additive Gaussian noise, on the [0, 1] intensity scale.
SIGMA = 25 / 255
# Directory holding the Set12 benchmark PNGs.
TEST_DIR = "./datasets/Test_Datasets/FFDNet-master/testsets/Set12"
# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def calculate_psnr(img1, img2, data_range=1.0):
    """Return the peak signal-to-noise ratio between two image tensors, in dB.

    Args:
        img1: Reference image tensor.
        img2: Image tensor compared against ``img1``; must broadcast with it.
        data_range: Peak signal value. Use ``1.0`` for images scaled to [0, 1]
            (the default and the previous hard-coded value), ``255.0`` for 8-bit data.

    Returns:
        PSNR as a Python float, or ``float("inf")`` when the images are identical.
    """
    # Promote to at least float32 before differencing: integer tensors would wrap on
    # subtraction (and torch.mean rejects them), and fp16 inputs lose precision when squared.
    dtype = torch.promote_types(torch.promote_types(img1.dtype, img2.dtype), torch.float32)
    diff = img1.to(dtype) - img2.to(dtype)
    # Pull the scalar out once so the remaining arithmetic is plain Python floats.
    mse = torch.mean(diff ** 2).item()
    if mse == 0:
        return float("inf")
    return 10 * math.log10(data_range ** 2 / mse)


def _make_model(spec):
    if spec["kind"] == "telegraph":
        return UnrolledNetwork(num_stages=spec["stages"], use_wave=spec["wave"]).to(DEVICE)
    if spec["kind"] == "tnrd":
        return TNRDBaselineNetwork(num_stages=spec["stages"]).to(DEVICE)
    raise ValueError(f"Unknown model kind: {spec['kind']}")


def _autocast_context():
    """Return the context manager wrapped around model inference.

    When ``DEVICE`` is CUDA this is an enabled ``torch.amp.autocast("cuda")``;
    otherwise it is a CPU autocast context constructed with ``enabled=False``,
    which leaves computation in its original precision.
    """
    if DEVICE.type != "cuda":
        # A disabled autocast behaves as a do-nothing context manager.
        return torch.autocast("cpu", enabled=False)
    return torch.amp.autocast("cuda")


def evaluate_model(spec):
    """Score one checkpoint on the Set12 images and return its mean PSNR.

    Each PNG in ``TEST_DIR`` is converted to a single-channel tensor (ToTensor scales
    8-bit pixels to [0, 1]). It is corrupted with Gaussian noise of standard deviation
    ``SIGMA`` and clamped to [0, 1], then denoised by the model. The output is compared
    with the clean image using ``calculate_psnr``.

    Args:
        spec: Model description dict with keys ``"kind"``, ``"stages"``, ``"file"``
            and, for telegraph models, ``"wave"``.

    Returns:
        The mean PSNR formatted as ``"<value> dB"``, or an error message string when
        no PNG images are found in ``TEST_DIR``.
    """
    # Check for test images first so a missing dataset does not cost a model build/load.
    test_paths = sorted(glob.glob(os.path.join(TEST_DIR, "*.png")))
    if not test_paths:
        return "Error: No images found in Set12 directory."

    model = _make_model(spec)
    # The loaded object is fed straight into load_state_dict, so restrict unpickling to
    # tensors and plain containers rather than executing arbitrary pickled code.
    state_dict = torch.load(spec["file"], map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

    # Re-seed for every model so each checkpoint is scored on the same noise draws.
    torch.manual_seed(42)
    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # Close the image file promptly instead of leaving it to garbage collection.
            with Image.open(path) as img:
                clean = test_transform(img).unsqueeze(0).to(DEVICE)
            noisy = torch.clamp(clean + torch.randn_like(clean) * SIGMA, 0.0, 1.0)

            with _autocast_context():
                output = model(noisy)

            # Autocast may yield fp16; score in float32 and within the valid intensity
            # range, as the FFDNet test protocol for this dataset does.
            output = torch.clamp(output.float(), 0.0, 1.0)
            total_psnr += calculate_psnr(clean, output)

    return f"{total_psnr / len(test_paths):.2f} dB"


def main():
    """Print the Set12 (sigma = 25) PSNR for every listed checkpoint.

    Checkpoints whose file is missing are reported as not found rather than evaluated.
    """
    rule = "-" * 60
    print("[*] Evaluating Checkpoints on Set12 (Sigma = 25)...")
    print(rule)

    # "kind" selects the architecture built by _make_model; "file" is the checkpoint path.
    checkpoints = (
        dict(
            name="3-Stage TDE (Proposed)",
            kind="telegraph",
            stages=3,
            wave=True,
            file="model_3stages_waveTrue.pth",
        ),
        dict(
            name="5-Stage TDE (Proposed)",
            kind="telegraph",
            stages=5,
            wave=True,
            file="model_5stages_waveTrue.pth",
        ),
        dict(
            name="5-Stage Telegraph w/o wave",
            kind="telegraph",
            stages=5,
            wave=False,
            file="model_5stages_waveFalse.pth",
        ),
        dict(
            name="5-Stage TNRD baseline",
            kind="tnrd",
            stages=5,
            file="tnrd_baseline_5stages.pth",
        ),
    )

    for entry in checkpoints:
        found = os.path.exists(entry["file"])
        result = evaluate_model(entry) if found else "[File not found]"
        print(f"{entry['name']:<30} | PSNR: {result}")

    print(rule)
    print("PDE baseline: run classical_baseline.py for the fixed telegraph/PDE comparison.")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()