# personal_math / evaluate_checkpoints.py
# psidharth567's picture
# Sync full project: code, checkpoints, datasets, logs
# dcd2bd2 verified
import argparse
import os
import torch
from PIL import Image
from torchvision import transforms
from train_network import (
DEFAULT_GAMMAS,
DEFAULT_SIGMA_LEVELS,
DEFAULT_TAU_INITS,
_collect_image_paths,
_SCRIPT_DIR,
calculate_psnr,
gamma_tag,
sigma_int_to_float,
sigma_tag,
tau_tag,
)
from train_network import UnrolledNetwork as MLPUnrolledNetwork
from train_network_rbf import UnrolledNetwork as RBFUnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork
# Prefer the CUDA device when PyTorch reports one is available; otherwise use the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Names of the test-set folders every checkpoint is evaluated on (see _testset_root).
TESTSETS = ("Set12", "BSD68")
def _testset_root(name):
    """Return the path of the FFDNet test-set folder called ``name``.

    The path is rooted at ``_SCRIPT_DIR``:
    ``<_SCRIPT_DIR>/datasets/Test_Datasets/FFDNet-master/testsets/<name>``.
    """
    relative_parts = ("datasets", "Test_Datasets", "FFDNet-master", "testsets", name)
    return os.path.join(_SCRIPT_DIR, *relative_parts)
def _autocast_context():
    """Return the autocast context manager used around the forward pass.

    On CUDA this enables ``torch.amp.autocast("cuda")``. On any other device it
    returns a CPU autocast context created with ``enabled=False``, so the model
    runs in its native precision.
    """
    if DEVICE.type != "cuda":
        return torch.autocast("cpu", enabled=False)
    return torch.amp.autocast("cuda")
def _build_model(model_type, stages, use_wave, damping_gamma, tau_init):
    """Instantiate the network for ``model_type`` and move it to ``DEVICE``.

    Args:
        model_type: one of ``"mlp"``, ``"rbf"`` or ``"tnrd"``.
        stages: number of unrolled stages.
        use_wave: forwarded to the MLP/RBF networks; ignored for TNRD.
        damping_gamma: forwarded to the MLP/RBF networks; ignored for TNRD.
        tau_init: initial tau forwarded to every network type.

    Raises:
        ValueError: for an unrecognised ``model_type``.
    """
    if model_type == "tnrd":
        network = TNRDBaselineNetwork(stages, tau_init=tau_init)
    elif model_type in ("mlp", "rbf"):
        network_cls = MLPUnrolledNetwork if model_type == "mlp" else RBFUnrolledNetwork
        network = network_cls(
            stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return network.to(DEVICE)
# Learned (MLP / RBF) model families, as
# (model_type, label prefix, trained-file prefix, finetuned-file prefix).
# NOTE(review): finetuned MLP checkpoints are named "finetuned_<stages>stages..."
# (not "finetuned_model_...") while RBF uses "finetuned_rbf_model_..."; kept
# exactly as before -- confirm against the finetuning script's save paths.
_LEARNED_FAMILIES = (
    ("mlp", "MLP", "model_", "finetuned_"),
    ("rbf", "RBF", "rbf_model_", "finetuned_rbf_model_"),
)


def _make_spec(label, model_type, use_wave, damping_gamma, tau_init, sigma, path):
    """Bundle the fields describing one checkpoint into a spec dict."""
    return {
        "label": label,
        "model_type": model_type,
        "use_wave": use_wave,
        "damping_gamma": damping_gamma,
        "tau_init": tau_init,
        "sigma": sigma,
        "path": path,
    }


def _checkpoint_specs(stages, sigmas, gammas, tau_inits, include_finetuned):
    """Enumerate the checkpoint files expected for every hyper-parameter combination.

    For each ``sigma`` the list contains, in order:

    * one TNRD baseline per ``tau_init``, each immediately followed by its
      finetuned variant when ``include_finetuned`` is set;
    * for every ``(gamma, tau_init)`` pair: MLP Telegraph, MLP No-wave,
      RBF Telegraph, RBF No-wave, followed (when ``include_finetuned``) by
      the four finetuned counterparts in the same order.

    Each spec has ``label``, ``model_type``, ``use_wave``, ``damping_gamma``,
    ``tau_init``, ``sigma`` (the integer noise level, so callers need not
    parse it back out of ``label``) and a relative checkpoint ``path``.
    """
    specs = []
    for sigma in sigmas:
        sigma_name = sigma_tag(sigma)

        # TNRD baselines ignore wave / damping, so they are keyed only by tau.
        for tau_init in tau_inits:
            tau_name = tau_tag(tau_init)
            label = f"TNRD baseline sigma={sigma} tau={tau_init}"
            path = f"tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth"
            specs.append(_make_spec(label, "tnrd", False, 1.0, tau_init, sigma, path))
            if include_finetuned:
                specs.append(
                    _make_spec(
                        f"Finetuned {label}", "tnrd", False, 1.0, tau_init, sigma,
                        f"finetuned_{path}",
                    )
                )

        for damping_gamma in gammas:
            gamma_name = gamma_tag(damping_gamma)
            for tau_init in tau_inits:
                tau_name = tau_tag(tau_init)
                trained, finetuned = [], []
                for model_type, label_name, file_prefix, finetuned_prefix in _LEARNED_FAMILIES:
                    for use_wave in (True, False):
                        variant = "Telegraph" if use_wave else "No-wave"
                        label = (
                            f"{label_name} {variant} sigma={sigma} "
                            f"gamma={damping_gamma} tau={tau_init}"
                        )
                        # f"wave{use_wave}" renders as "waveTrue" / "waveFalse".
                        stem = (
                            f"{stages}stages_wave{use_wave}_"
                            f"{sigma_name}_{gamma_name}_{tau_name}.pth"
                        )
                        trained.append(
                            _make_spec(
                                label, model_type, use_wave, damping_gamma,
                                tau_init, sigma, f"{file_prefix}{stem}",
                            )
                        )
                        finetuned.append(
                            _make_spec(
                                f"Finetuned {label}", model_type, use_wave,
                                damping_gamma, tau_init, sigma,
                                f"{finetuned_prefix}{stem}",
                            )
                        )
                specs.extend(trained)
                if include_finetuned:
                    specs.extend(finetuned)
    return specs
def evaluate_checkpoint(spec, dataset_name, sigma, stages):
    """Return the mean PSNR of one checkpoint over one test set.

    Each clean image is converted to grayscale. Gaussian noise of standard
    deviation ``sigma_int_to_float(sigma)`` is added, and the noisy image is
    clamped to [0, 1] before being denoised by the model.

    Args:
        spec: checkpoint description with ``model_type``, ``use_wave``,
            ``damping_gamma``, ``tau_init`` and ``path`` entries.
        dataset_name: test-set folder name, resolved with ``_testset_root``.
        sigma: noise level in 0-255 units.
        stages: number of unrolled stages the model is built with.

    Raises:
        FileNotFoundError: if no test images are found for ``dataset_name``.
    """
    # Resolve the test images first, so a missing dataset fails before the
    # model is built and its checkpoint loaded.
    test_root = _testset_root(dataset_name)
    test_paths = _collect_image_paths(test_root)
    if not test_paths:
        raise FileNotFoundError(
            f"No test images found in {os.path.abspath(test_root)} for {dataset_name}."
        )

    model = _build_model(
        spec["model_type"],
        stages,
        spec["use_wave"],
        spec["damping_gamma"],
        spec["tau_init"],
    )
    state = torch.load(spec["path"], map_location=DEVICE)
    model.load_state_dict(state)
    model.eval()

    sigma_float = sigma_int_to_float(sigma)
    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    # Re-seed on every call, so repeated evaluations draw the same noise.
    torch.manual_seed(42)
    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # Image.open keeps the file handle open until closed. Scope it so
            # handles do not accumulate across the test set.
            with Image.open(path) as image:
                clean = test_transform(image).unsqueeze(0).to(DEVICE)
            noisy = torch.clamp(clean + torch.randn_like(clean) * sigma_float, 0.0, 1.0)
            with _autocast_context():
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)
    return total_psnr / len(test_paths)
def _spec_sigma(spec):
    """Return the integer noise level (0-255 units) a checkpoint spec refers to.

    Uses an explicit ``"sigma"`` entry when present. Otherwise it parses the
    ``sigma=<int>`` token from the label (e.g. ``"MLP Telegraph sigma=25 gamma=0.5 tau=0.1"``).

    Raises:
        ValueError: if neither source provides a noise level.
    """
    if "sigma" in spec:
        return int(spec["sigma"])
    prefix = "sigma="
    for token in spec["label"].split():
        if token.startswith(prefix):
            return int(token[len(prefix):])
    raise ValueError(f"Cannot determine sigma from spec label: {spec['label']!r}")


def main(args):
    """Evaluate every expected checkpoint on each test set and print a PSNR table.

    Checkpoint files that do not exist are reported as ``[missing]`` instead of
    aborting the run.
    """
    specs = _checkpoint_specs(
        args.stages,
        [int(s) for s in args.sigmas],
        [float(g) for g in args.gammas],
        [float(t) for t in args.tau_inits],
        args.include_finetuned,
    )
    # Size the label column to the longest label (minimum 38) so rows stay
    # aligned; labels such as "MLP Telegraph sigma=25 gamma=0.5 tau=0.1" are
    # longer than 38 characters.
    label_width = max([38, *(len(spec["label"]) for spec in specs)])
    rule = "-" * (label_width + 52)
    print(f"[*] Evaluating checkpoints on {', '.join(TESTSETS)}")
    print(f"[*] Device: {DEVICE}")
    print(rule)
    print(f"{'Model':<{label_width}} {'Dataset':<8} {'PSNR':>8} Checkpoint")
    print(rule)
    for spec in specs:
        if not os.path.exists(spec["path"]):
            print(f"{spec['label']:<{label_width}} {'-':<8} {'[missing]':>8} {spec['path']}")
            continue
        # The old `int(label.split("sigma=")[-1])` also captured the trailing
        # "tau=..." / "gamma=..." text and raised ValueError on every label.
        sigma_value = _spec_sigma(spec)
        for dataset_name in TESTSETS:
            psnr = evaluate_checkpoint(spec, dataset_name, sigma_value, args.stages)
            print(f"{spec['label']:<{label_width}} {dataset_name:<8} {psnr:>8.2f} {spec['path']}")
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--stages", type=int, default=5)
    # Sweep axes: each takes one or more values and defaults to the values
    # imported from train_network.
    for flag, value_type, defaults, help_text in (
        ("--sigmas", int, DEFAULT_SIGMA_LEVELS,
         "Noise levels to evaluate, specified in 0-255 units."),
        ("--gammas", float, DEFAULT_GAMMAS,
         "Fixed damping gamma values to evaluate for MLP/RBF models."),
        ("--tau_inits", float, DEFAULT_TAU_INITS,
         "Initial tau values to evaluate."),
    ):
        cli.add_argument(
            flag, type=value_type, nargs="+", default=list(defaults), help=help_text
        )
    cli.add_argument(
        "--include_finetuned",
        action="store_true",
        help="Also evaluate finetuned checkpoints.",
    )
    main(cli.parse_args())