"""Run a trained Pix2PixModel in inference mode and save each generated image.

Every generated image is written as its own file under
``<checkpoints_dir>/<name>/<results_dir>/each_img/``, using the base name of
the corresponding input path from the batch.
"""

import os
from collections import OrderedDict

import data
from options.test_options import TestOptions
from models.pix2pix_model import Pix2PixModel
from util.visualizer import Visualizer
import torchvision.utils as vutils
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Parse the test-time options.
opt = TestOptions().parse()

# Build the dataloader from the parsed options.
dataloader = data.create_dataloader(opt)

# Build the model and switch it to evaluation mode.
model = Pix2PixModel(opt)
model.eval()

# NOTE(review): the visualizer is constructed but never used below; it is kept
# in case its constructor has side effects -- confirm before removing.
visualizer = Visualizer(opt)

# Directory that receives one output file per generated image.
single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img")

# exist_ok=True avoids the check-then-create race of
# `if not os.path.exists(...): os.makedirs(...)`.
os.makedirs(single_save_url, exist_ok=True)

for i, data_i in enumerate(dataloader):
    # The limit is checked once per batch, so a final batch is saved in full
    # even if that pushes the total past opt.how_many.
    if i * opt.batchSize >= opt.how_many:
        break

    generated = model(data_i, mode="inference")

    # Input paths for this batch; only their base names are reused.
    img_path = data_i["path"]

    for b, image in enumerate(generated):
        img_name = os.path.basename(img_path[b])
        save_img_url = os.path.join(single_save_url, img_name)

        # Shift and scale before saving; presumably the generator output is
        # in [-1, 1] and this maps it to [0, 1] -- TODO confirm.
        vutils.save_image((image + 1) / 2, save_img_url)