"""Evaluate trained denoising checkpoints on the Set12 and BSD68 test sets.

For each checkpoint produced by the training scripts (TNRD baseline, MLP and
RBF unrolled networks, and optionally their finetuned variants), the script
loads the weights, corrupts every test image with Gaussian noise at the
matching sigma level, and reports the average PSNR per dataset.
"""

import argparse
import os

import torch
from PIL import Image
from torchvision import transforms

from train_network import (
    DEFAULT_GAMMAS,
    DEFAULT_SIGMA_LEVELS,
    DEFAULT_TAU_INITS,
    _collect_image_paths,
    _SCRIPT_DIR,
    calculate_psnr,
    gamma_tag,
    sigma_int_to_float,
    sigma_tag,
    tau_tag,
)
from train_network import UnrolledNetwork as MLPUnrolledNetwork
from train_network_rbf import UnrolledNetwork as RBFUnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TESTSETS = ("Set12", "BSD68")


def _testset_root(name):
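    """Return the directory that holds the images of the named test set."""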
    return os.path.join(
        _SCRIPT_DIR,
        "datasets",
        "Test_Datasets",
        "FFDNet-master",
        "testsets",
        name,
    )


def _autocast_context():
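    """CUDA mixed-precision autocast on GPU; a disabled (no-op) autocast on CPU."""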
    return (
        torch.amp.autocast("cuda")
        if DEVICE.type == "cuda"
        else torch.autocast("cpu", enabled=False)
    )


def _build_model(model_type, stages, use_wave, damping_gamma, tau_init):
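    """Instantiate the requested architecture and move it to DEVICE."""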
    if model_type == "mlp":
        return MLPUnrolledNetwork(
            stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init
        ).to(DEVICE)
    if model_type == "rbf":
        return RBFUnrolledNetwork(
            stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init
        ).to(DEVICE)
    if model_type == "tnrd":
        return TNRDBaselineNetwork(stages, tau_init=tau_init).to(DEVICE)
    raise ValueError(f"Unknown model type: {model_type}")


def _checkpoint_specs(stages, sigmas, gammas, tau_inits, include_finetuned):
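    """Enumerate the checkpoints to evaluate for the given sigma/gamma/tau grid.

    Each spec carries a human-readable label, the model type and its
    constructor arguments, and the expected checkpoint filename.
    """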
    specs = []

    for sigma in sigmas:
        sigma_name = sigma_tag(sigma)

        # TNRD baseline checkpoints (and optional finetuned variants).
        for tau_init in tau_inits:
            tau_name = tau_tag(tau_init)
            specs.append(
                {
                    "label": f"TNRD baseline sigma={sigma} tau={tau_init}",
                    "model_type": "tnrd",
                    "use_wave": False,
                    "damping_gamma": 1.0,
                    "tau_init": tau_init,
                    "path": f"tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth",
                }
            )
            if include_finetuned:
                specs.append(
                    {
                        "label": f"Finetuned TNRD baseline sigma={sigma} tau={tau_init}",
                        "model_type": "tnrd",
                        "use_wave": False,
                        "damping_gamma": 1.0,
                        "tau_init": tau_init,
                        "path": f"finetuned_tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth",
                    }
                )

        # MLP and RBF unrolled networks, with (Telegraph) and without the wave term.
        for damping_gamma in gammas:
            gamma_name = gamma_tag(damping_gamma)
            for tau_init in tau_inits:
                tau_name = tau_tag(tau_init)
                specs.extend(
                    [
                        {
                            "label": f"MLP Telegraph sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                            "model_type": "mlp",
                            "use_wave": True,
                            "damping_gamma": damping_gamma,
                            "tau_init": tau_init,
                            "path": f"model_{stages}stages_waveTrue_{sigma_name}_{gamma_name}_{tau_name}.pth",
                        },
                        {
                            "label": f"MLP No-wave sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                            "model_type": "mlp",
                            "use_wave": False,
                            "damping_gamma": damping_gamma,
                            "tau_init": tau_init,
                            "path": f"model_{stages}stages_waveFalse_{sigma_name}_{gamma_name}_{tau_name}.pth",
                        },
                        {
                            "label": f"RBF Telegraph sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                            "model_type": "rbf",
                            "use_wave": True,
                            "damping_gamma": damping_gamma,
                            "tau_init": tau_init,
                            "path": f"rbf_model_{stages}stages_waveTrue_{sigma_name}_{gamma_name}_{tau_name}.pth",
                        },
                        {
                            "label": f"RBF No-wave sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                            "model_type": "rbf",
                            "use_wave": False,
                            "damping_gamma": damping_gamma,
                            "tau_init": tau_init,
                            "path": f"rbf_model_{stages}stages_waveFalse_{sigma_name}_{gamma_name}_{tau_name}.pth",
                        },
                    ]
                )

                # Finetuned variants of the four models above.
                if include_finetuned:
                    specs.extend(
                        [
                            {
                                "label": f"Finetuned MLP Telegraph sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                                "model_type": "mlp",
                                "use_wave": True,
                                "damping_gamma": damping_gamma,
                                "tau_init": tau_init,
                                "path": f"finetuned_{stages}stages_waveTrue_{sigma_name}_{gamma_name}_{tau_name}.pth",
                            },
                            {
                                "label": f"Finetuned MLP No-wave sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                                "model_type": "mlp",
                                "use_wave": False,
                                "damping_gamma": damping_gamma,
                                "tau_init": tau_init,
                                "path": f"finetuned_{stages}stages_waveFalse_{sigma_name}_{gamma_name}_{tau_name}.pth",
                            },
                            {
                                "label": f"Finetuned RBF Telegraph sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                                "model_type": "rbf",
                                "use_wave": True,
                                "damping_gamma": damping_gamma,
                                "tau_init": tau_init,
                                "path": f"finetuned_rbf_model_{stages}stages_waveTrue_{sigma_name}_{gamma_name}_{tau_name}.pth",
                            },
                            {
                                "label": f"Finetuned RBF No-wave sigma={sigma} gamma={damping_gamma} tau={tau_init}",
                                "model_type": "rbf",
                                "use_wave": False,
                                "damping_gamma": damping_gamma,
                                "tau_init": tau_init,
                                "path": f"finetuned_rbf_model_{stages}stages_waveFalse_{sigma_name}_{gamma_name}_{tau_name}.pth",
                            },
                        ]
                    )

    return specs


def evaluate_checkpoint(spec, dataset_name, sigma, stages):
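    """Load one checkpoint and return its average PSNR over the named test set."""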
    model = _build_model(
        spec["model_type"],
        stages,
        spec["use_wave"],
        spec["damping_gamma"],
        spec["tau_init"],
    )
    state = torch.load(spec["path"], map_location=DEVICE)
    model.load_state_dict(state)
    model.eval()

    test_root = _testset_root(dataset_name)
    test_paths = _collect_image_paths(test_root)
    if not test_paths:
        raise FileNotFoundError(
            f"No test images found in {os.path.abspath(test_root)} for {dataset_name}."
        )

    sigma_float = sigma_int_to_float(sigma)
    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    # Fixed seed so every checkpoint is evaluated on identical noise realisations.
    torch.manual_seed(42)
    total_psnr = 0.0

    with torch.no_grad():
        for path in test_paths:
            clean = test_transform(Image.open(path)).unsqueeze(0).to(DEVICE)
            noisy = torch.clamp(clean + torch.randn_like(clean) * sigma_float, 0.0, 1.0)
            with _autocast_context():
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)

    return total_psnr / len(test_paths)


def main(args):
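    """Evaluate every requested checkpoint on each test set and print a PSNR table."""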
    specs = _checkpoint_specs(
        args.stages,
        [int(s) for s in args.sigmas],
        [float(g) for g in args.gammas],
        [float(t) for t in args.tau_inits],
        args.include_finetuned,
    )
    print(f"[*] Evaluating checkpoints on {', '.join(TESTSETS)}")
    print(f"[*] Device: {DEVICE}")
    print("-" * 90)
    print(f"{'Model':<38} {'Dataset':<8} {'PSNR':>8} Checkpoint")
    print("-" * 90)

    for spec in specs:
        if not os.path.exists(spec["path"]):
            print(f"{spec['label']:<38} {'-':<8} {'[missing]':>8} {spec['path']}")
            continue

        # The label embeds the noise level as "sigma=<int>" followed by other
        # fields, so keep only the first token after the marker.
        sigma_value = int(spec["label"].split("sigma=")[-1].split()[0])
        for dataset_name in TESTSETS:
            psnr = evaluate_checkpoint(spec, dataset_name, sigma_value, args.stages)
            print(f"{spec['label']:<38} {dataset_name:<8} {psnr:>8.2f} {spec['path']}")


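# Example invocation (the script filename below is illustrative; checkpoints
# with the expected names must already exist in the working directory):
#   python evaluate_checkpoints.py --stages 5 --sigmas 25 --include_finetuned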
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--stages", type=int, default=5)
    parser.add_argument(
        "--sigmas",
        type=int,
        nargs="+",
        default=list(DEFAULT_SIGMA_LEVELS),
        help="Noise levels to evaluate, specified in 0-255 units.",
    )
    parser.add_argument(
        "--gammas",
        type=float,
        nargs="+",
        default=list(DEFAULT_GAMMAS),
        help="Fixed damping gamma values to evaluate for MLP/RBF models.",
    )
    parser.add_argument(
        "--tau_inits",
        type=float,
        nargs="+",
        default=list(DEFAULT_TAU_INITS),
        help="Initial tau values to evaluate.",
    )
    parser.add_argument(
        "--include_finetuned",
        action="store_true",
        help="Also evaluate finetuned checkpoints.",
    )
    args = parser.parse_args()
    main(args)