"""Evaluate trained denoising checkpoints on Set12 at a fixed Gaussian noise level.

Every checkpoint listed in ``main`` is rebuilt from its spec, loaded, run on
each Set12 image corrupted with the same seeded noise, and reported as the
average PSNR over the set.
"""

import glob
import math
import os

import torch
from PIL import Image
from torchvision import transforms

from train_network import UnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork

SIGMA = 25.0 / 255.0
TEST_DIR = "./datasets/Test_Datasets/FFDNet-master/testsets/Set12"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed for the synthetic test noise; the same noise is reused for every model
# so their scores are directly comparable.
NOISE_SEED = 42


def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio (dB) between two image tensors.

    Args:
        img1: Reference image tensor.
        img2: Image tensor to compare; must broadcast against ``img1``.
        max_val: Peak signal value. The default of 1.0 matches images scaled
            to [0, 1], which is what ``transforms.ToTensor`` produces.

    Returns:
        PSNR as a Python float, or ``float("inf")`` when the images are equal.
    """
    # Reduce in float32 so a half-precision model output (under autocast)
    # does not degrade the error estimate; .item() yields a plain float.
    mse = torch.mean((img1.float() - img2.float()) ** 2).item()
    if mse == 0:
        return float("inf")
    return 20 * math.log10(max_val / math.sqrt(mse))


def _make_model(spec):
    """Build the network described by ``spec`` and move it to ``DEVICE``.

    ``spec["kind"]`` selects the architecture. ``"telegraph"`` reads
    ``spec["stages"]`` and ``spec["wave"]``, and ``"tnrd"`` reads
    ``spec["stages"]``.

    Raises:
        ValueError: If ``spec["kind"]`` is not a known model kind.
    """
    if spec["kind"] == "telegraph":
        return UnrolledNetwork(num_stages=spec["stages"], use_wave=spec["wave"]).to(DEVICE)
    if spec["kind"] == "tnrd":
        return TNRDBaselineNetwork(num_stages=spec["stages"]).to(DEVICE)
    raise ValueError(f"Unknown model kind: {spec['kind']}")


def _autocast_context():
    """Return a mixed-precision context on CUDA and a disabled autocast on CPU."""
    if DEVICE.type == "cuda":
        return torch.amp.autocast("cuda")
    return torch.autocast("cpu", enabled=False)


def evaluate_model(spec, *, sigma=SIGMA, test_dir=TEST_DIR):
    """Load the checkpoint described by ``spec`` and return its mean PSNR.

    Args:
        spec: Model spec dict with ``"kind"``, ``"stages"``, ``"file"`` and,
            for telegraph models, ``"wave"`` (see ``_make_model``).
        sigma: Standard deviation of the additive Gaussian noise, on the
            [0, 1] intensity scale.
        test_dir: Directory whose ``*.png`` files are used as clean images.

    Returns:
        A string like ``"29.87 dB"``. If no images are found, the error string
        ``"Error: No images found in Set12 directory."`` is returned instead.
    """
    # Check for data before paying for model construction / checkpoint load.
    test_paths = sorted(glob.glob(os.path.join(test_dir, "*.png")))
    if not test_paths:
        return "Error: No images found in Set12 directory."

    model = _make_model(spec)
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs, and blocks code execution from a
    # tampered checkpoint file.
    state_dict = torch.load(spec["file"], map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    # A dedicated CPU generator makes the test noise identical regardless of
    # whether evaluation runs on CPU or CUDA, and leaves global RNG state alone.
    generator = torch.Generator().manual_seed(NOISE_SEED)

    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # Close each file promptly instead of leaking one handle per image.
            with Image.open(path) as img:
                clean = test_transform(img).unsqueeze(0)
            noise = torch.randn(clean.shape, generator=generator) * sigma
            noisy = torch.clamp(clean + noise, 0.0, 1.0).to(DEVICE)
            clean = clean.to(DEVICE)
            with _autocast_context():
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)
    return f"{total_psnr / len(test_paths):.2f} dB"


def main():
    """Evaluate every known checkpoint that exists on disk and print a table."""
    print("[*] Evaluating Checkpoints on Set12 (Sigma = 25)...")
    print("-" * 60)

    models_to_test = [
        {
            "name": "3-Stage TDE (Proposed)",
            "kind": "telegraph",
            "stages": 3,
            "wave": True,
            "file": "model_3stages_waveTrue.pth",
        },
        {
            "name": "5-Stage TDE (Proposed)",
            "kind": "telegraph",
            "stages": 5,
            "wave": True,
            "file": "model_5stages_waveTrue.pth",
        },
        {
            "name": "5-Stage Telegraph w/o wave",
            "kind": "telegraph",
            "stages": 5,
            "wave": False,
            "file": "model_5stages_waveFalse.pth",
        },
        {
            "name": "5-Stage TNRD baseline",
            "kind": "tnrd",
            "stages": 5,
            "file": "tnrd_baseline_5stages.pth",
        },
    ]

    for spec in models_to_test:
        if os.path.exists(spec["file"]):
            score = evaluate_model(spec)
            print(f"{spec['name']:<30} | PSNR: {score}")
        else:
            print(f"{spec['name']:<30} | PSNR: [File not found]")

    print("-" * 60)
    print("PDE baseline: run classical_baseline.py for the fixed telegraph/PDE comparison.")


if __name__ == "__main__":
    main()