File size: 3,321 Bytes
dcd2bd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import glob
import torch
import torch.nn.functional as F
import math
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# --- Configuration ---
# Compute on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Absolute directory of this script; test-set paths are resolved relative to it.
_SCRIPT_DIR = os.path.dirname(
    os.path.abspath(__file__)
)
# Noise standard deviations 15, 50 and 75 (8-bit scale), rescaled to [0, 1].
_SIGMAS = tuple(level / 255.0 for level in (15.0, 50.0, 75.0))


def _testset_root(name):
    """Return the directory of the FFDNet test set called *name*.

    The path is ``<_SCRIPT_DIR>/datasets/Test_Datasets/FFDNet-master/testsets/<name>``;
    nothing is checked on disk.
    """
    subdirs = ("datasets", "Test_Datasets", "FFDNet-master", "testsets")
    return os.path.join(_SCRIPT_DIR, *subdirs, name)


class TestDataset(Dataset):
    """Grayscale test images paired with synthetically noised copies.

    Each item is ``(clean, noisy)``. ``clean`` is the image converted to a
    single-channel tensor by ``Grayscale`` + ``ToTensor``. ``noisy`` is
    ``clean`` plus Gaussian noise of standard deviation ``sigma``, clamped
    to [0, 1].

    Args:
        root_dir: Directory scanned (non-recursively) for ``*.png`` and
            ``*.jpg`` files. Images are served in sorted path order.
        sigma: Standard deviation of the additive noise.
        seed: Optional base seed. When given, the noise for item ``idx`` is
            drawn from a private generator seeded with ``seed + idx``, so it
            is reproducible and independent of access order. When ``None``
            (default) the global torch RNG is used.
    """

    def __init__(self, root_dir, sigma, seed=None):
        self.sigma = sigma
        self.seed = seed
        self.image_paths = self._list_images(root_dir)
        self.transform = transforms.Compose(
            [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
        )

    @staticmethod
    def _list_images(root_dir):
        """Return the ``*.png`` / ``*.jpg`` files directly inside *root_dir*, sorted."""
        # glob.escape keeps characters such as "[" in the directory name
        # literal instead of treating them as wildcards.
        base = glob.escape(root_dir)
        paths = glob.glob(os.path.join(base, "*.png")) + glob.glob(
            os.path.join(base, "*.jpg")
        )
        # glob returns entries in arbitrary filesystem order; sort so the
        # image order (and therefore every per-image result) is deterministic.
        return sorted(paths)

    def __len__(self):
        return len(self.image_paths)

    def _make_noisy(self, clean, idx):
        """Return ``clean + N(0, sigma^2)`` noise, clamped to [0, 1]."""
        if self.seed is None:
            noise = torch.randn_like(clean)
        else:
            gen = torch.Generator(device=clean.device)
            gen.manual_seed(self.seed + idx)
            noise = torch.randn(
                clean.shape, generator=gen, dtype=clean.dtype, device=clean.device
            )
        return torch.clamp(clean + noise * self.sigma, 0.0, 1.0)

    def __getitem__(self, idx):
        # The context manager closes the file once the pixels have been
        # decoded by the transform, instead of leaking the handle.
        with Image.open(self.image_paths[idx]) as img:
            clean = self.transform(img)
        return clean, self._make_noisy(clean, idx)


def calculate_psnr(img1, img2, *, data_range=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    ``PSNR = 10 * log10(data_range**2 / MSE)``, where MSE is the mean squared
    difference over all elements of ``img1 - img2``.

    Args:
        img1, img2: Tensors of the same (or broadcast-compatible) shape.
            Integer tensors such as uint8 are accepted: the difference is
            taken in float64, so it cannot wrap around.
        data_range: Peak-to-peak value range of the images (1.0 for images
            scaled to [0, 1], 255.0 for 8-bit pixel values).

    Returns:
        PSNR as a Python float, or ``float("inf")`` when the images are equal.
    """
    diff = img1.to(torch.float64) - img2.to(torch.float64)
    # A single .item() moves the scalar to the host once.
    mse = torch.mean(diff * diff).item()
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(data_range**2 / mse)


def _telegraph_conv3x3(x, weight, padding_mode):
    """Depthwise 3x3 filtering of *x* (one kernel per channel), same spatial size."""
    groups = x.shape[-3]
    if padding_mode == "zeros":
        return F.conv2d(x, weight, padding=1, groups=groups)
    padded = F.pad(x, (1, 1, 1, 1), mode=padding_mode)
    return F.conv2d(padded, weight, groups=groups)


def classical_telegraph_step(
    u_n, u_n_minus_1, tau=0.2, gamma=1.0, *, k=0.1, padding_mode="zeros"
):
    """Advance a nonlinear telegraph-diffusion scheme by one time step.

    The update coefficients correspond to discretizing
    ``u_tt + gamma * u_t = div(c * grad u)`` as::

        (u_{n+1} - 2 u_n + u_{n-1}) / tau**2 + gamma * (u_{n+1} - u_n) / tau = div

    and solving for ``u_{n+1}``. The diffusivity is
    ``c = 1 / (1 + (|grad u_n| / k)**2)``, so regions with gradient
    magnitude well above ``k`` diffuse little.

    Args:
        u_n: Current state, a ``(N, C, H, W)`` or ``(C, H, W)`` float tensor.
            Channels are processed independently.
        u_n_minus_1: Previous state, same shape as ``u_n``.
        tau: Time step.
        gamma: Damping coefficient of the ``u_t`` term.
        k: Edge threshold of the diffusivity.
        padding_mode: Border handling for the finite differences. ``"zeros"``
            (default, original behavior) treats pixels outside the image as 0.
            ``"replicate"``, ``"reflect"`` and ``"circular"`` use
            ``F.pad``-style padding instead.

    Returns:
        ``u_{n+1}`` with the same shape, dtype and device as ``u_n``.

    Raises:
        ValueError: If ``padding_mode`` is not one of the supported modes.
    """
    if padding_mode not in ("zeros", "replicate", "reflect", "circular"):
        raise ValueError(f"unsupported padding_mode: {padding_mode!r}")

    channels = u_n.shape[-3]
    # Build the kernels in the input's dtype/device so float64, fp16 and
    # non-default-device inputs work.
    opts = {"dtype": u_n.dtype, "device": u_n.device}
    # F.conv2d is a cross-correlation, so these kernels compute the central
    # differences (u[i+1] - u[i-1]) / 2 along width (kx) and height (ky).
    kx = torch.tensor([[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]], **opts)
    ky = torch.tensor([[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]], **opts)
    # One copy per channel, used with groups=channels (depthwise).
    kx = kx.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    ky = ky.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    grad_x = _telegraph_conv3x3(u_n, kx, padding_mode)
    grad_y = _telegraph_conv3x3(u_n, ky, padding_mode)
    # The small epsilon keeps the square root away from zero.
    grad_mag = torch.sqrt(grad_x**2 + grad_y**2 + 1e-8)

    c = 1.0 / (1.0 + (grad_mag / k) ** 2)

    divergence = _telegraph_conv3x3(c * grad_x, kx, padding_mode) + _telegraph_conv3x3(
        c * grad_y, ky, padding_mode
    )

    denom = 1 + gamma * tau
    alpha = (2 + gamma * tau) / denom
    beta = -1 / denom
    lam = (tau**2) / denom

    return alpha * u_n + beta * u_n_minus_1 + lam * divergence


def run_eval(dataset_name, sigma, *, num_iters=20, tau=0.2, gamma=1.0, seed=None):
    """Denoise every image of a test set with the telegraph scheme and report mean PSNR.

    Each noisy image is evolved for ``num_iters`` steps of
    ``classical_telegraph_step``, starting with both time levels equal to the
    noisy image. The result is scored against the clean image with
    ``calculate_psnr``.

    Args:
        dataset_name: Test-set folder name resolved by ``_testset_root``.
        sigma: Noise standard deviation passed to ``TestDataset``.
        num_iters: Number of time steps per image (default 20).
        tau: Time step forwarded to ``classical_telegraph_step``.
        gamma: Damping coefficient forwarded to ``classical_telegraph_step``.
        seed: If not ``None``, ``torch.manual_seed(seed)`` is called before
            loading so the noise draws are reproducible across runs.

    Returns:
        The mean PSNR in dB, or ``None`` if the test-set folder has no images.

    Raises:
        ValueError: If ``num_iters`` is negative.
    """
    if num_iters < 0:
        raise ValueError(f"num_iters must be non-negative, got {num_iters}")

    root = _testset_root(dataset_name)
    dataset = TestDataset(root, sigma)
    if len(dataset) == 0:
        print(f"[!] Skip {dataset_name}: no images in {os.path.abspath(root)}")
        return None

    if seed is not None:
        # The DataLoader below uses no worker processes, so the dataset's
        # noise is drawn in this process; seeding here fixes those draws.
        torch.manual_seed(seed)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    total_psnr = 0.0
    # Pure inference: no autograd graph is needed for the PDE iterations.
    with torch.no_grad():
        for clean, noisy in dataloader:
            clean, noisy = clean.to(DEVICE), noisy.to(DEVICE)
            # Both time levels start at the noisy image (zero initial velocity).
            u_prev, u_curr = noisy.clone(), noisy.clone()

            for _ in range(num_iters):
                # The right-hand side is evaluated before rebinding, so the
                # step sees the old (u_curr, u_prev) pair.
                u_prev, u_curr = u_curr, classical_telegraph_step(
                    u_curr, u_prev, tau=tau, gamma=gamma
                )

            total_psnr += calculate_psnr(clean, u_curr)

    avg_psnr = total_psnr / len(dataset)
    sigma_int = int(round(sigma * 255.0))
    print(
        f"[+] {dataset_name}  sigma={sigma_int}/255  PSNR: {avg_psnr:.2f} dB  "
        f"({len(dataset)} images)"
    )
    return avg_psnr


def main():
    """Evaluate the baseline on Set12 and then BSD68 at every noise level in ``_SIGMAS``.

    ``run_eval`` is called once per (dataset, sigma) pair. A blank line is
    printed after each dataset's group of results.
    """
    print("[*] Classical Majee 2020 baseline — Set12 & BSD68")
    benchmark_sets = ["Set12", "BSD68"]
    for benchmark in benchmark_sets:
        for noise_sigma in _SIGMAS:
            run_eval(benchmark, noise_sigma)
        # Visual separator between the two benchmark sets.
        print()


if __name__ == "__main__":
    main()