# viton-hd / test.py
# Uploaded by known57 ("Add files using upload-large-folder tool").
# Commit 24870a9 (verified).
import argparse
import os
import torch
from torch import nn
from torch.nn import functional as F
import torchgeometry as tgm
from datasets import VITONDataset, VITONDataLoader
from networks import SegGenerator, GMM, ALIASGenerator
from utils import gen_noise, load_checkpoint, save_images
def get_opt():
    """Build the command-line parser for the test script and parse ``sys.argv``.

    Only ``--name`` is mandatory; every other option carries a default.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    upsampling_help = (
        "If 'more', add upsampling layer between the two middle resnet blocks. "
        "If 'most', also add one more (upsampling + resnet) layer at the end of the generator."
    )

    # Each entry is (flag strings, keyword arguments for add_argument).
    # Registration order is kept because it determines the --help listing.
    option_table = [
        # run identity and data loading
        (('--name',), dict(type=str, required=True)),
        (('-b', '--batch_size'), dict(type=int, default=1)),
        (('-j', '--workers'), dict(type=int, default=1)),
        (('--load_height',), dict(type=int, default=1024)),
        (('--load_width',), dict(type=int, default=768)),
        (('--shuffle',), dict(action='store_true')),
        # input / output locations
        (('--dataset_dir',), dict(type=str, default='./datasets/')),
        (('--dataset_mode',), dict(type=str, default='test')),
        (('--dataset_list',), dict(type=str, default='test_pairs.txt')),
        (('--checkpoint_dir',), dict(type=str, default='./checkpoints/')),
        (('--save_dir',), dict(type=str, default='./results/')),
        (('--display_freq',), dict(type=int, default=1)),
        # checkpoint file names, joined onto --checkpoint_dir by main()
        (('--seg_checkpoint',), dict(type=str, default='seg_final.pth')),
        (('--gmm_checkpoint',), dict(type=str, default='gmm_final.pth')),
        (('--alias_checkpoint',), dict(type=str, default='alias_final.pth')),
        # options shared by the networks
        (('--semantic_nc',), dict(type=int, default=13,
                                  help='# of human-parsing map classes')),
        (('--init_type',), dict(
            choices=['normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none'],
            default='xavier')),
        (('--init_variance',), dict(type=float, default=0.02,
                                    help='variance of the initialization distribution')),
        # GMM only
        (('--grid_size',), dict(type=int, default=5)),
        # ALIASGenerator only
        (('--norm_G',), dict(type=str, default='spectralaliasinstance')),
        (('--ngf',), dict(type=int, default=64,
                          help='# of generator filters in the first conv layer')),
        (('--num_upsampling_layers',), dict(choices=['normal', 'more', 'most'],
                                            default='most', help=upsampling_help)),
    ]

    cli = argparse.ArgumentParser()
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
def test(opt, seg, gmm, alias):
    """Run the three-stage try-on pipeline over the test set and save the outputs.

    For every batch produced by the loader this function:

    1. feeds ``seg`` a 256x192 input and turns its upsampled, blurred
       prediction into a one-hot parsing map, regrouped into the channels
       listed in ``labels``;
    2. calls ``gmm`` at 256x192 and uses the grid it returns to warp the
       full-resolution cloth and cloth mask;
    3. calls ``alias`` on the agnostic image, pose and warped cloth, then
       writes the result with ``save_images`` under
       ``<opt.save_dir>/<opt.name>``.

    Args:
        opt: Parsed options. Read here: ``load_height``, ``load_width``,
            ``save_dir``, ``name`` and ``display_freq``; also passed to
            ``VITONDataset`` and ``VITONDataLoader``.
        seg: Segmentation network, called as ``seg(seg_input)``.
        gmm: Geometric matching network, called as ``gmm(gmm_input, c_gmm)``;
            its second return value is used as the sampling grid.
        alias: Synthesis network, called with the concatenated image input,
            ``parse``, ``parse_div`` and ``misalign_mask``.

    Note:
        Inputs are moved with ``.cuda()``, so a CUDA device is required.
        Presumably the three networks are already on that device (``main``
        does this) - confirm for other callers.
    """
    full_size = (opt.load_height, opt.load_width)
    low_size = (256, 192)  # resolution used for the seg and GMM inputs

    up = nn.Upsample(size=full_size, mode='bilinear')
    gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
    gauss.cuda()

    # Output channel -> [name, parsing-map channels summed into it].
    # Built once: it does not depend on the batch.
    labels = {
        0: ['background', [0]],
        1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
        2: ['upper', [3]],
        3: ['hair', [1]],
        4: ['left_arm', [5]],
        5: ['right_arm', [6]],
        6: ['noise', [12]]
    }
    # Channel counts follow from the table above (13 and 7) rather than
    # being repeated as literals next to it.
    num_fine = 1 + max(label for _, fine_ids in labels.values() for label in fine_ids)
    num_grouped = len(labels)

    save_path = os.path.join(opt.save_dir, opt.name)

    test_dataset = VITONDataset(opt)
    test_loader = VITONDataLoader(opt, test_dataset)

    with torch.no_grad():
        for i, inputs in enumerate(test_loader.data_loader):
            img_names = inputs['img_name']
            c_names = inputs['c_name']['unpaired']

            img_agnostic = inputs['img_agnostic'].cuda()
            parse_agnostic = inputs['parse_agnostic'].cuda()
            pose = inputs['pose'].cuda()
            c = inputs['cloth']['unpaired'].cuda()
            cm = inputs['cloth_mask']['unpaired'].cuda()

            # Part 1. Segmentation generation
            parse_agnostic_down = F.interpolate(parse_agnostic, size=low_size, mode='bilinear')
            pose_down = F.interpolate(pose, size=low_size, mode='bilinear')
            c_masked_down = F.interpolate(c * cm, size=low_size, mode='bilinear')
            cm_down = F.interpolate(cm, size=low_size, mode='bilinear')
            seg_input = torch.cat(
                (cm_down, c_masked_down, parse_agnostic_down, pose_down,
                 gen_noise(cm_down.size()).cuda()),
                dim=1)

            parse_pred_down = seg(seg_input)
            parse_pred = gauss(up(parse_pred_down))
            parse_pred = parse_pred.argmax(dim=1)[:, None]

            # Allocate directly on the prediction's device instead of
            # creating the tensor on the host and copying it over.
            device = parse_pred.device
            batch = parse_pred.size(0)

            # One-hot encoding of the per-pixel argmax.
            parse_old = torch.zeros(batch, num_fine, opt.load_height, opt.load_width,
                                    dtype=torch.float, device=device)
            parse_old.scatter_(1, parse_pred, 1.0)

            # Merge the one-hot channels into the grouped channels.
            parse = torch.zeros(batch, num_grouped, opt.load_height, opt.load_width,
                                dtype=torch.float, device=device)
            for j, (_, fine_ids) in labels.items():
                for label in fine_ids:
                    parse[:, j] += parse_old[:, label]

            # Part 2. Clothes Deformation
            agnostic_gmm = F.interpolate(img_agnostic, size=low_size, mode='nearest')
            parse_cloth_gmm = F.interpolate(parse[:, 2:3], size=low_size, mode='nearest')
            pose_gmm = F.interpolate(pose, size=low_size, mode='nearest')
            c_gmm = F.interpolate(c, size=low_size, mode='nearest')
            gmm_input = torch.cat((parse_cloth_gmm, pose_gmm, agnostic_gmm), dim=1)

            _, warped_grid = gmm(gmm_input, c_gmm)
            warped_c = F.grid_sample(c, warped_grid, padding_mode='border')
            warped_cm = F.grid_sample(cm, warped_grid, padding_mode='border')

            # Part 3. Try-on synthesis
            # Positive part of ('upper' channel - warped cloth mask).
            misalign_mask = parse[:, 2:3] - warped_cm
            misalign_mask[misalign_mask < 0.0] = 0.0
            # torch.cat copies, so the in-place subtraction below changes
            # parse_div only and leaves `parse` untouched.
            parse_div = torch.cat((parse, misalign_mask), dim=1)
            parse_div[:, 2:3] -= misalign_mask

            output = alias(torch.cat((img_agnostic, pose, warped_c), dim=1),
                           parse, parse_div, misalign_mask)

            # "<image name up to the first underscore>_<cloth name>"
            unpaired_names = [
                '{}_{}'.format(img_name.split('_')[0], c_name)
                for img_name, c_name in zip(img_names, c_names)
            ]
            save_images(output, unpaired_names, save_path)

            # A display_freq of 0 disables progress output instead of
            # raising ZeroDivisionError on the modulo.
            if opt.display_freq != 0 and (i + 1) % opt.display_freq == 0:
                print("step: {}".format(i + 1))
def main():
    """Parse options, build and load the three networks, then run ``test``.

    Creates ``<opt.save_dir>/<opt.name>`` if needed, loads the seg, GMM and
    ALIAS checkpoints from ``opt.checkpoint_dir``, moves the networks to CUDA
    in eval mode and hands everything to ``test``.
    """
    opt = get_opt()
    print(opt)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(os.path.join(opt.save_dir, opt.name), exist_ok=True)

    seg = SegGenerator(opt, input_nc=opt.semantic_nc + 8, output_nc=opt.semantic_nc)
    gmm = GMM(opt, inputA_nc=7, inputB_nc=3)

    # ALIASGenerator is built with semantic_nc temporarily set to 7
    # (presumably it reads the value from opt - confirm in networks.py).
    # Restore the value the user asked for rather than a literal 13, because
    # opt is later passed to test(), which hands it to the dataset and loader.
    requested_semantic_nc = opt.semantic_nc
    opt.semantic_nc = 7
    try:
        alias = ALIASGenerator(opt, input_nc=9)
    finally:
        opt.semantic_nc = requested_semantic_nc

    load_checkpoint(seg, os.path.join(opt.checkpoint_dir, opt.seg_checkpoint))
    load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, opt.gmm_checkpoint))
    load_checkpoint(alias, os.path.join(opt.checkpoint_dir, opt.alias_checkpoint))

    seg.cuda().eval()
    gmm.cuda().eval()
    alias.cuda().eval()

    test(opt, seg, gmm, alias)


if __name__ == '__main__':
    main()