import os
import argparse
from typing import List

import torch
from PIL import Image
from torchvision.transforms import ToTensor
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

from restormer import ChannelShuffleWithGBPDeep
import utils as utils

# File extensions (lower-case) that are treated as images when scanning a folder.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif")


def load_model(weights_path: str, device: torch.device) -> torch.nn.Module:
    """Build a ChannelShuffleWithGBPDeep model, load its weights, set eval mode.

    Args:
        weights_path: Path to the checkpoint handed to
            ``utils.load_checkpointG1``.
        device: Device the freshly constructed model is moved to.

    Returns:
        The model, on ``device`` and in evaluation mode.
    """
    model = ChannelShuffleWithGBPDeep().to(device)
    # NOTE(review): strict=False presumably tolerates missing/unexpected keys
    # in the checkpoint -- confirm against utils.load_checkpointG1.
    utils.load_checkpointG1(model, weights_path, strict=False)
    model.eval()
    return model


def list_image_files(folder: str) -> List[str]:
    """Return the sorted names of the image files directly inside ``folder``.

    A name qualifies when its lower-cased form ends with one of the known
    image extensions and it refers to a regular file; a sub-directory that
    happens to be named like an image is ignored so it cannot reach
    ``Image.open`` later on.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        Bare file names (no directory component), sorted lexicographically.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluate synthetic test set (PSNR & SSIM)."
    )
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="Path to model checkpoint (e.g. ./checkpoint_new/Deraining/models/MPRNet/model_200.pth)",
    )
    parser.add_argument(
        "--test_inp",
        type=str,
        default="./dataset/test/input",
        help="Path to synthetic test input folder.",
    )
    parser.add_argument(
        "--test_tar",
        type=str,
        default="./dataset/test/target",
        help="Path to synthetic test target folder.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="CUDA_VISIBLE_DEVICES setting, e.g. '0' or '0,1'.",
    )
    return parser


def _evaluate_pair(model, inp_path, tar_path, to_tensor, device):
    """Run the model on one input image and score it against its target.

    Args:
        model: Callable network; invoked as ``model(batch)`` under ``no_grad``.
        inp_path: Path of the degraded input image.
        tar_path: Path of the ground-truth image.
        to_tensor: Callable converting a PIL image to a CHW float tensor.
        device: Device the input batch is moved to before inference.

    Returns:
        ``(psnr, ssim)`` for this image pair, both computed with
        ``data_range=1.0``.

    Raises:
        ValueError: If the model output and the target differ in shape.
    """
    # Context managers release the file handles as soon as pixels are decoded.
    with Image.open(inp_path) as img:
        inp_img = img.convert("RGB")
    with Image.open(tar_path) as img:
        tar_img = img.convert("RGB")

    inp = to_tensor(inp_img).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(inp)

    out_np = out.squeeze(0).cpu().numpy()
    # The target is only needed on the host, so it never visits the device.
    tar_np = to_tensor(tar_img).numpy()

    if out_np.shape != tar_np.shape:
        raise ValueError(
            f"Shape mismatch for {os.path.basename(inp_path)}: "
            f"output {out_np.shape} vs target {tar_np.shape}"
        )

    # NOTE(review): the output is scored as-is; it is not clamped to [0, 1]
    # before computing metrics -- confirm this is the intended protocol.
    # PSNR (channel-first, data_range=1)
    psnr = peak_signal_noise_ratio(tar_np, out_np, data_range=1.0)

    # SSIM needs a channel-last layout.
    ssim = structural_similarity(
        tar_np.transpose(1, 2, 0),
        out_np.transpose(1, 2, 0),
        data_range=1.0,
        channel_axis=-1,
    )
    return psnr, ssim


def main():
    """Evaluate a checkpoint on paired input/target folders and print metrics.

    Parses the command line, selects the GPU(s), loads the model, then for
    every image in ``--test_inp`` that has a same-named file in ``--test_tar``
    prints its PSNR/SSIM and finally the averages. Inputs without a matching
    target are skipped with a warning. Invalid folders or an empty input
    folder terminate the script through ``parser.error``.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Must be set before the first CUDA query below so the selection applies.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for option, folder in (("test_inp", args.test_inp), ("test_tar", args.test_tar)):
        if not os.path.isdir(folder):
            parser.error(f"{option} not found or not a directory: {folder}")

    # List the images before loading the model so an empty folder fails fast.
    img_names = list_image_files(args.test_inp)
    if not img_names:
        parser.error(f"No images found in {args.test_inp}")

    model = load_model(args.weights, device)
    to_tensor = ToTensor()

    psnr_sum = 0.0
    ssim_sum = 0.0
    count = 0

    for name in img_names:
        inp_path = os.path.join(args.test_inp, name)
        tar_path = os.path.join(args.test_tar, name)

        if not os.path.exists(tar_path):
            print(f"[Warning] Ground truth not found for {name}, skip.")
            continue

        psnr, ssim = _evaluate_pair(model, inp_path, tar_path, to_tensor, device)

        psnr_sum += psnr
        ssim_sum += ssim
        count += 1

        print(f"{name}: PSNR={psnr:.4f}, SSIM={ssim:.4f}")

    if count == 0:
        print("No valid image pairs evaluated.")
        return

    print("-" * 60)
    print(f"Average over {count} images: PSNR={psnr_sum / count:.4f}, SSIM={ssim_sum / count:.4f}")


if __name__ == "__main__":
    main()

# Example invocation:
# python test_synth_psnr_ssim.py \
#   --weights ./checkpoint_new/Deraining/models/MPRNet/model_46.pth \
#   --test_inp ./dataset/test/input \
#   --test_tar ./dataset/test/target \
#   --gpu 0