""" 对 real 测试集跑模型,并把输出图保存到指定文件夹。 输入:./dataset/test/real_input(或 --real_inp) 输出:保存到 --output_dir,文件名与输入一致。 """ import os import argparse import torch from PIL import Image from torchvision.transforms import ToTensor from restormer import ChannelShuffleWithGBPDeep import utils as utils def load_model(weights_path: str, device: torch.device) -> torch.nn.Module: model = ChannelShuffleWithGBPDeep().to(device) utils.load_checkpointG1(model, weights_path, strict=False) model.eval() return model def list_image_files(folder: str): exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif") return sorted([f for f in os.listdir(folder) if f.lower().endswith(exts)]) def main(): parser = argparse.ArgumentParser(description="Run model on real test images and save outputs.") parser.add_argument("--weights", type=str, required=True, help="Path to model checkpoint.") parser.add_argument( "--real_inp", type=str, default="./dataset/test/real_input", help="Path to real-world test input folder.", ) parser.add_argument( "--output_dir", type=str, default="./dataset/test/real_output", help="Folder to save output images.", ) parser.add_argument("--gpu", type=str, default="0") args = parser.parse_args() os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") assert os.path.exists(args.real_inp), f"real_inp not found: {args.real_inp}" os.makedirs(args.output_dir, exist_ok=True) print(f"Output dir: {args.output_dir}") model = load_model(args.weights, device) to_tensor = ToTensor() img_names = list_image_files(args.real_inp) assert len(img_names) > 0, f"No images found in {args.real_inp}" for name in img_names: inp_path = os.path.join(args.real_inp, name) inp_img = Image.open(inp_path).convert("RGB") inp = to_tensor(inp_img).unsqueeze(0).to(device) with torch.no_grad(): out = model(inp) # (1, 3, H, W) RGB 0~1 -> 保存为 PNG out_np = out.squeeze(0).cpu().clamp(0, 1).numpy() out_np = (out_np.transpose(1, 2, 0) * 255.0).clip(0, 255).astype("uint8") out_pil = Image.fromarray(out_np) # 输出文件名与输入一致,扩展名统一为 .png base, _ = os.path.splitext(name) out_name = base + ".png" out_path = os.path.join(args.output_dir, out_name) out_pil.save(out_path) print(f"Saved: {out_path}") print(f"Done. {len(img_names)} images saved to {args.output_dir}") if __name__ == "__main__": main()