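"""
Run the model on the real-world test set and save the output images to a folder.

Input:  ./dataset/test/real_input (or the folder given by --real_inp)
Output: written to --output_dir; each output keeps the input's base file name
        but is always saved with a .png extension.
"""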
| import os |
| import argparse |
|
|
| import torch |
| from PIL import Image |
| from torchvision.transforms import ToTensor |
|
|
| from restormer import ChannelShuffleWithGBPDeep |
| import utils as utils |
|
|
|
|
def load_model(weights_path: str, device: torch.device) -> torch.nn.Module:
    """Build the network on *device*, load its checkpoint and return it in eval mode.

    Args:
        weights_path: Checkpoint file handed to ``utils.load_checkpointG1``.
        device: Device the freshly constructed network is moved to.

    Returns:
        The ``ChannelShuffleWithGBPDeep`` instance, switched to ``eval()``.
    """
    net = ChannelShuffleWithGBPDeep()
    net = net.to(device)
    # strict=False is forwarded to the project loader; presumably it tolerates
    # missing/unexpected keys like nn.Module.load_state_dict — confirm in utils.
    utils.load_checkpointG1(net, weights_path, strict=False)
    net.eval()
    return net
|
|
|
|
def list_image_files(
    folder: str,
    exts=(".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif"),
):
    """Return the sorted names of image files found directly inside *folder*.

    Matching is by case-insensitive file extension. Entries that are not
    regular files (e.g. a sub-directory called ``foo.png``) are skipped,
    because they would otherwise make ``Image.open`` fail later on.

    Args:
        folder: Directory to scan (not recursive).
        exts: Extensions (with leading dot) to accept; compared lower-cased.

    Returns:
        Bare file names (not full paths), sorted lexicographically.
    """
    wanted = tuple(ext.lower() for ext in exts)  # str.endswith needs a tuple
    with os.scandir(folder) as entries:
        return sorted(
            entry.name
            for entry in entries
            if entry.is_file() and entry.name.lower().endswith(wanted)
        )
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Run model on real test images and save outputs.") |
| parser.add_argument("--weights", type=str, required=True, help="Path to model checkpoint.") |
| parser.add_argument( |
| "--real_inp", |
| type=str, |
| default="./dataset/test/real_input", |
| help="Path to real-world test input folder.", |
| ) |
| parser.add_argument( |
| "--output_dir", |
| type=str, |
| default="./dataset/test/real_output", |
| help="Folder to save output images.", |
| ) |
| parser.add_argument("--gpu", type=str, default="0") |
| args = parser.parse_args() |
|
|
| os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| print(f"Using device: {device}") |
|
|
| assert os.path.exists(args.real_inp), f"real_inp not found: {args.real_inp}" |
| os.makedirs(args.output_dir, exist_ok=True) |
| print(f"Output dir: {args.output_dir}") |
|
|
| model = load_model(args.weights, device) |
| to_tensor = ToTensor() |
| img_names = list_image_files(args.real_inp) |
| assert len(img_names) > 0, f"No images found in {args.real_inp}" |
|
|
| for name in img_names: |
| inp_path = os.path.join(args.real_inp, name) |
| inp_img = Image.open(inp_path).convert("RGB") |
| inp = to_tensor(inp_img).unsqueeze(0).to(device) |
|
|
| with torch.no_grad(): |
| out = model(inp) |
|
|
| |
| out_np = out.squeeze(0).cpu().clamp(0, 1).numpy() |
| out_np = (out_np.transpose(1, 2, 0) * 255.0).clip(0, 255).astype("uint8") |
| out_pil = Image.fromarray(out_np) |
|
|
| |
| base, _ = os.path.splitext(name) |
| out_name = base + ".png" |
| out_path = os.path.join(args.output_dir, out_name) |
| out_pil.save(out_path) |
| print(f"Saved: {out_path}") |
|
|
| print(f"Done. {len(img_names)} images saved to {args.output_dir}") |
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so a normal run exits with status 0.
    raise SystemExit(main())
|
|