llir / test_real_save_output.py
linxin02's picture
Upload portable Low_light_rainy_new code export
4336727 verified
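"""
Run the model on the real-world test set and save the restored images to a folder.

Input:  images in ./dataset/test/real_input (override with --real_inp).
Output: written to --output_dir (default ./dataset/test/real_output); each output
        keeps its input's base file name but is always saved with a .png extension.
"""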
import os
import argparse
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from restormer import ChannelShuffleWithGBPDeep
import utils as utils
def load_model(weights_path: str, device: torch.device) -> torch.nn.Module:
    """Instantiate ChannelShuffleWithGBPDeep on *device* and load its weights.

    Weights are loaded via ``utils.load_checkpointG1(..., strict=False)``; how
    missing/unexpected keys are treated is decided by that helper (not visible
    here). The network is switched to eval mode before it is returned.
    """
    network = ChannelShuffleWithGBPDeep()
    network = network.to(device)
    utils.load_checkpointG1(network, weights_path, strict=False)
    network.eval()
    return network
def list_image_files(folder: str):
    """Return the sorted names of image files directly inside *folder*.

    An entry counts as an image when it is a regular file whose name ends
    (case-insensitively) in one of the recognised extensions. Sub-directories
    are skipped even when their names look like images, because they cannot
    be opened as images later. Only bare file names are returned, not paths.
    """
    exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".gif")
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(exts)
        and os.path.isfile(os.path.join(folder, name))
    )
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script."""
    parser = argparse.ArgumentParser(description="Run model on real test images and save outputs.")
    parser.add_argument("--weights", type=str, required=True, help="Path to model checkpoint.")
    parser.add_argument(
        "--real_inp",
        type=str,
        default="./dataset/test/real_input",
        help="Path to real-world test input folder.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./dataset/test/real_output",
        help="Folder to save output images.",
    )
    parser.add_argument("--gpu", type=str, default="0")
    return parser.parse_args()


def _restore_image(
    model: torch.nn.Module,
    to_tensor: ToTensor,
    inp_path: str,
    device: torch.device,
) -> Image.Image:
    """Run *model* on the image at *inp_path* and return an 8-bit RGB PIL image."""
    # Context manager closes the source file immediately instead of leaving
    # the handle open until garbage collection.
    with Image.open(inp_path) as src:
        inp_img = src.convert("RGB")
    inp = to_tensor(inp_img).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(inp)
    # Model output is expected to be (1, 3, H, W) RGB in 0~1. Scale to 0-255 and
    # round to nearest (+0.5 before the uint8 cast); a bare cast truncates and
    # darkens every pixel by half a level on average.
    out_np = out.squeeze(0).cpu().clamp(0, 1).numpy()
    out_np = (out_np.transpose(1, 2, 0) * 255.0 + 0.5).clip(0, 255).astype("uint8")
    return Image.fromarray(out_np)


def main():
    """Restore every image in --real_inp with the model and save results as PNG.

    Each input ``<name>.<ext>`` is written to ``--output_dir/<name>.png``.

    Raises:
        SystemExit: if --real_inp is not a directory, contains no images, or
            is the same folder as --output_dir.
    """
    args = _parse_args()
    # These environment variables only take effect if CUDA has not been
    # initialised yet, so keep them before the first torch.cuda call.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not os.path.isdir(args.real_inp):
        raise SystemExit(f"real_inp is not a directory: {args.real_inp}")
    if os.path.realpath(args.real_inp) == os.path.realpath(args.output_dir):
        # Outputs are .png files, so writing into the input folder could
        # overwrite inputs that have not been processed yet.
        raise SystemExit("output_dir must differ from real_inp to avoid overwriting inputs")
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Output dir: {args.output_dir}")

    # List inputs before loading the model so an empty folder fails fast.
    img_names = list_image_files(args.real_inp)
    if not img_names:
        raise SystemExit(f"No images found in {args.real_inp}")

    model = load_model(args.weights, device)
    to_tensor = ToTensor()
    written = set()
    for name in img_names:
        # Keep the input's base name but always save as (lossless) PNG.
        out_name = os.path.splitext(name)[0] + ".png"
        if out_name in written:
            # e.g. "x.jpg" and "x.png" both map to "x.png".
            print(f"Warning: {name} overwrites previously saved {out_name}")
        out_path = os.path.join(args.output_dir, out_name)
        out_pil = _restore_image(model, to_tensor, os.path.join(args.real_inp, name), device)
        out_pil.save(out_path)
        written.add(out_name)
        print(f"Saved: {out_path}")
    print(f"Done. {len(written)} images saved to {args.output_dir}")


if __name__ == "__main__":
    main()