# llir / test_synth_psnr_ssim.py
# Uploaded by linxin02 — "Upload portable Low_light_rainy_new code export" (commit 4336727, verified)
import os
import argparse
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from restormer import ChannelShuffleWithGBPDeep
import utils as utils
def load_model(weights_path: str, device: torch.device) -> torch.nn.Module:
    """Build a ChannelShuffleWithGBPDeep network, load its weights and put it in eval mode.

    Args:
        weights_path: Path to the checkpoint file handed to ``utils.load_checkpointG1``.
        device: Device the network is moved to before the weights are loaded.

    Returns:
        The network on ``device``, switched to evaluation mode.
    """
    net = ChannelShuffleWithGBPDeep()
    net = net.to(device)
    # NOTE(review): strict=False presumably tolerates key mismatches between the
    # checkpoint and the model; confirm in utils.load_checkpointG1.
    utils.load_checkpointG1(net, weights_path, strict=False)
    net.eval()
    return net
def list_image_files(folder: str):
    """Return the names of image files directly inside *folder*, sorted.

    A name counts as an image when its lower-cased form ends with one of the
    recognised extensions; only names are returned, not full paths.
    """
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif")
    names = []
    for entry in os.listdir(folder):
        if entry.lower().endswith(image_exts):
            names.append(entry)
    names.sort()
    return names
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser(description="Evaluate synthetic test set (PSNR & SSIM).")
    parser.add_argument(
        "--weights",
        type=str,
        required=True,
        help="Path to model checkpoint (e.g. ./checkpoint_new/Deraining/models/MPRNet/model_200.pth)",
    )
    parser.add_argument(
        "--test_inp",
        type=str,
        default="./dataset/test/input",
        help="Path to synthetic test input folder.",
    )
    parser.add_argument(
        "--test_tar",
        type=str,
        default="./dataset/test/target",
        help="Path to synthetic test target folder.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="CUDA_VISIBLE_DEVICES setting, e.g. '0' or '0,1'.",
    )
    return parser


def _load_rgb(path: str) -> Image.Image:
    """Open the image at *path*, convert it to RGB, and release the file handle."""
    # convert() returns a new, fully loaded image, so it stays valid after the
    # context manager closes the underlying file.
    with Image.open(path) as img:
        return img.convert("RGB")


def _compute_metrics(out: torch.Tensor, tar: torch.Tensor) -> tuple:
    """Return ``(psnr, ssim)`` between a restored tensor and its target.

    Both tensors are expected as (1, C, H, W). ``tar`` comes from
    ``ToTensor().unsqueeze(0)`` in ``main``. The model output is assumed to have
    the same shape — NOTE(review): confirm the network preserves spatial size.
    """
    # Nothing here guarantees the network output lies in [0, 1], but both metrics
    # are computed with data_range=1.0. Clamp so the scores describe a valid image.
    out_np = out.clamp(0.0, 1.0).squeeze(0).cpu().numpy()
    tar_np = tar.squeeze(0).cpu().numpy()
    # PSNR averages over all elements, so the channel-first layout is fine here.
    psnr = peak_signal_noise_ratio(tar_np, out_np, data_range=1.0)
    # SSIM is computed per channel; move channels last to match channel_axis=-1.
    ssim = structural_similarity(
        tar_np.transpose(1, 2, 0),
        out_np.transpose(1, 2, 0),
        data_range=1.0,
        channel_axis=-1,
    )
    return psnr, ssim


def main():
    """Evaluate a checkpoint on a paired synthetic test set and print PSNR/SSIM.

    Every image in ``--test_inp`` that has a same-named file in ``--test_tar``
    is run through the model and scored against that target. Inputs without a
    target are skipped with a warning. Per-image scores and the average over all
    evaluated pairs are printed to stdout. Invalid folders or an empty input
    folder terminate the script through ``argparse``'s error mechanism.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # CUDA reads these variables when it is first initialised. This assumes
    # nothing imported above has touched CUDA yet — NOTE(review): confirm.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not os.path.isdir(args.test_inp):
        parser.error(f"test_inp not found: {args.test_inp}")
    if not os.path.isdir(args.test_tar):
        parser.error(f"test_tar not found: {args.test_tar}")

    # List inputs before loading the model so an empty folder fails fast.
    img_names = list_image_files(args.test_inp)
    if not img_names:
        parser.error(f"No images found in {args.test_inp}")

    model = load_model(args.weights, device)
    to_tensor = ToTensor()

    psnr_sum = 0.0
    ssim_sum = 0.0
    count = 0
    for name in img_names:
        inp_path = os.path.join(args.test_inp, name)
        tar_path = os.path.join(args.test_tar, name)
        if not os.path.exists(tar_path):
            print(f"[Warning] Ground truth not found for {name}, skip.")
            continue

        inp = to_tensor(_load_rgb(inp_path)).unsqueeze(0).to(device)
        # The target is only used for CPU-side metrics, so it stays on the CPU.
        tar = to_tensor(_load_rgb(tar_path)).unsqueeze(0)
        with torch.no_grad():
            out = model(inp)

        psnr, ssim = _compute_metrics(out, tar)
        psnr_sum += psnr
        ssim_sum += ssim
        count += 1
        print(f"{name}: PSNR={psnr:.4f}, SSIM={ssim:.4f}")

    if count == 0:
        print("No valid image pairs evaluated.")
        return
    print("-" * 60)
    print(f"Average over {count} images: PSNR={psnr_sum / count:.4f}, SSIM={ssim_sum / count:.4f}")
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
# Example usage:
# python test_synth_psnr_ssim.py \
# --weights ./checkpoint_new/Deraining/models/MPRNet/model_46.pth \
# --test_inp ./dataset/test/input \
# --test_tar ./dataset/test/target \
# --gpu 0