| import os |
| import argparse |
|
|
| import torch |
| from PIL import Image |
| from torchvision.transforms import ToTensor |
| from skimage.metrics import peak_signal_noise_ratio, structural_similarity |
|
|
| from restormer import ChannelShuffleWithGBPDeep |
| import utils as utils |
|
|
|
|
def load_model(weights_path: str, device: torch.device) -> torch.nn.Module:
    """Build a ChannelShuffleWithGBPDeep network ready for evaluation.

    The network is instantiated with its default constructor arguments,
    moved to `device`, populated through ``utils.load_checkpointG1`` and
    finally switched to evaluation mode.

    Args:
        weights_path: Checkpoint path handed to ``utils.load_checkpointG1``.
        device: Device the network is moved to before the weights are loaded.

    Returns:
        The network, in eval mode, on `device`.
    """
    net = ChannelShuffleWithGBPDeep()
    net = net.to(device)

    # strict=False is forwarded as-is to the project loader; presumably it
    # tolerates missing/unexpected keys -- confirm in utils.load_checkpointG1.
    utils.load_checkpointG1(net, weights_path, strict=False)

    net.eval()
    return net
|
|
|
|
def list_image_files(folder: str) -> list:
    """Return the sorted names of the image files directly inside `folder`.

    An entry is kept when its name ends (case-insensitively) with one of the
    recognised image extensions *and* it is a regular file, so a
    sub-directory that merely happens to be called e.g. ``crops.png`` is not
    reported as an image.

    Args:
        folder: Directory to scan. The scan is not recursive.

    Returns:
        File names (not full paths) in ascending lexicographic order.

    Raises:
        OSError: Propagated from ``os.listdir`` when `folder` cannot be listed.
    """
    exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif")
    return sorted(
        name
        for name in os.listdir(folder)
        # str.endswith accepts a tuple, so one call covers every extension.
        if name.lower().endswith(exts)
        and os.path.isfile(os.path.join(folder, name))
    )
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate synthetic test set (PSNR & SSIM).")
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="Path to model checkpoint (e.g. ./checkpoint_new/Deraining/models/MPRNet/model_200.pth)",
    )
    parser.add_argument(
        "--test_inp",
        type=str,
        default="./dataset/test/input",
        help="Path to synthetic test input folder.",
    )
    parser.add_argument(
        "--test_tar",
        type=str,
        default="./dataset/test/target",
        help="Path to synthetic test target folder.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="CUDA_VISIBLE_DEVICES setting, e.g. '0' or '0,1'.",
    )
    return parser.parse_args()


def _load_rgb_tensor(path: str, to_tensor) -> torch.Tensor:
    """Read the image at `path` as RGB and convert it with `to_tensor`.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixel data has been converted, instead of lingering until garbage
    collection.
    """
    with Image.open(path) as img:
        return to_tensor(img.convert("RGB"))


def main() -> None:
    """Evaluate the model on paired input/target images and print PSNR/SSIM.

    Every image file in ``--test_inp`` is matched by file name against
    ``--test_tar``; inputs without a matching target are skipped with a
    warning. Per-image metrics are printed, followed by their averages.

    Raises:
        FileNotFoundError: If either folder does not exist or the input
            folder contains no image files.
    """
    args = _parse_args()

    # Set before the first CUDA query below (torch.cuda.is_available()).
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Explicit exceptions instead of `assert`: assertions are removed when
    # Python runs with -O, which would silently disable these checks.
    if not os.path.exists(args.test_inp):
        raise FileNotFoundError(f"test_inp not found: {args.test_inp}")
    if not os.path.exists(args.test_tar):
        raise FileNotFoundError(f"test_tar not found: {args.test_tar}")

    model = load_model(args.weights, device)
    to_tensor = ToTensor()

    img_names = list_image_files(args.test_inp)
    if not img_names:
        raise FileNotFoundError(f"No images found in {args.test_inp}")

    psnr_sum = 0.0
    ssim_sum = 0.0
    count = 0

    for name in img_names:
        tar_path = os.path.join(args.test_tar, name)
        if not os.path.exists(tar_path):
            print(f"[Warning] Ground truth not found for {name}, skip.")
            continue

        inp = _load_rgb_tensor(os.path.join(args.test_inp, name), to_tensor)
        # The target is only needed as a NumPy array for the metrics, so it
        # stays on the CPU instead of being copied to the device and back.
        tar_np = _load_rgb_tensor(tar_path, to_tensor).numpy()

        with torch.no_grad():
            out = model(inp.unsqueeze(0).to(device))

        # NOTE(review): the output is not clamped to [0, 1] before the
        # metrics although data_range=1.0 is used -- confirm this is intended.
        out_np = out.squeeze(0).cpu().numpy()

        psnr = peak_signal_noise_ratio(tar_np, out_np, data_range=1.0)

        # SSIM is computed channel-last: move axis 0 to the end and tell
        # skimage that the last axis holds the channels.
        ssim = structural_similarity(
            tar_np.transpose(1, 2, 0),
            out_np.transpose(1, 2, 0),
            data_range=1.0,
            channel_axis=-1,
        )

        psnr_sum += psnr
        ssim_sum += ssim
        count += 1

        print(f"{name}: PSNR={psnr:.4f}, SSIM={ssim:.4f}")

    if count == 0:
        print("No valid image pairs evaluated.")
        return

    print("-" * 60)
    print(f"Average over {count} images: PSNR={psnr_sum / count:.4f}, SSIM={ssim_sum / count:.4f}")
|
|
|
|
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|