Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- app.py +415 -0
- cp_dataset_test.py +264 -0
- network_generator.py +433 -0
- networks.py +453 -0
- requirements.txt +12 -0
- test_generator.py +278 -0
- utils.py +119 -0
app.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from multiprocessing import set_start_method
|
| 2 |
+
#set_start_method("fork")
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
#sys.path.insert(0, "../HR-VITON-main")
|
| 6 |
+
from test_generator import *
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import inspect
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import os
|
| 14 |
+
import torch
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import gradio as gr
|
| 17 |
+
import streamlit as st
|
| 18 |
+
from io import BytesIO
|
| 19 |
+
|
| 20 |
+
#### pip install streamlit-image-select
|
| 21 |
+
from streamlit_image_select import image_select
|
| 22 |
+
|
| 23 |
+
demo_image_dir = "demo_images_dir"
|
| 24 |
+
assert os.path.exists(demo_image_dir)
|
| 25 |
+
demo_images = list(map(lambda y: os.path.join(demo_image_dir, y) ,filter(lambda x: x.endswith(".png") or x.endswith(".jpeg") or x.endswith(".jpg")
|
| 26 |
+
,os.listdir(demo_image_dir))))
|
| 27 |
+
assert demo_images
|
| 28 |
+
|
| 29 |
+
#https://github.com/jrieke/streamlit-image-select/issues/10
|
| 30 |
+
|
| 31 |
+
#.image-box {
|
| 32 |
+
# border: 1px solid rgba(49, 51, 63, 0.2);
|
| 33 |
+
# border-radius: 0.25rem;
|
| 34 |
+
# padding: calc(0.25rem + 1px);
|
| 35 |
+
# height: 10rem;
|
| 36 |
+
# min-width: 10rem;
|
| 37 |
+
#}
|
| 38 |
+
|
| 39 |
+
demo_images = list(map(lambda x: x.resize((256, 256)), map(Image.open, demo_images)))
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class OPT:
|
| 43 |
+
#### ConditionGenerator
|
| 44 |
+
out_layer = None
|
| 45 |
+
warp_feature = None
|
| 46 |
+
#### SPADEGenerator
|
| 47 |
+
semantic_nc = None
|
| 48 |
+
fine_height = None
|
| 49 |
+
fine_width = None
|
| 50 |
+
ngf = None
|
| 51 |
+
num_upsampling_layers = None
|
| 52 |
+
norm_G = None
|
| 53 |
+
gen_semantic_nc = None
|
| 54 |
+
#### weight load
|
| 55 |
+
tocg_checkpoint = None
|
| 56 |
+
gen_checkpoint = None
|
| 57 |
+
cuda = False
|
| 58 |
+
|
| 59 |
+
data_list = None
|
| 60 |
+
datamode = None
|
| 61 |
+
dataroot = None
|
| 62 |
+
|
| 63 |
+
batch_size = None
|
| 64 |
+
shuffle = False
|
| 65 |
+
workers = None
|
| 66 |
+
|
| 67 |
+
clothmask_composition = None
|
| 68 |
+
occlusion = False
|
| 69 |
+
datasetting = None
|
| 70 |
+
|
| 71 |
+
opt = OPT()
|
| 72 |
+
opt.out_layer = "relu"
|
| 73 |
+
opt.warp_feature = "T1"
|
| 74 |
+
|
| 75 |
+
input1_nc = 4 # cloth + cloth-mask
|
| 76 |
+
nc = 13
|
| 77 |
+
input2_nc = nc + 3 # parse_agnostic + densepose
|
| 78 |
+
output_nc = nc
|
| 79 |
+
tocg = ConditionGenerator(opt,
|
| 80 |
+
input1_nc=input1_nc,
|
| 81 |
+
input2_nc=input2_nc, output_nc=output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
|
| 82 |
+
|
| 83 |
+
#### SPADEResBlock
|
| 84 |
+
from network_generator import SPADEResBlock
|
| 85 |
+
|
| 86 |
+
opt.semantic_nc = 7
|
| 87 |
+
opt.fine_height = 1024
|
| 88 |
+
opt.fine_width = 768
|
| 89 |
+
opt.ngf = 64
|
| 90 |
+
opt.num_upsampling_layers = "most"
|
| 91 |
+
opt.norm_G = "spectralaliasinstance"
|
| 92 |
+
opt.gen_semantic_nc = 7
|
| 93 |
+
|
| 94 |
+
generator = SPADEGenerator(opt, 3+3+3)
|
| 95 |
+
generator.print_network()
|
| 96 |
+
|
| 97 |
+
#### https://drive.google.com/open?id=1XJTCdRBOPVgVTmqzhVGFAgMm2NLkw5uQ&authuser=0
|
| 98 |
+
opt.tocg_checkpoint = "mtviton.pth"
|
| 99 |
+
#### https://drive.google.com/open?id=1T5_YDUhYSSKPC_nZMk2NeC-XXUFoYeNy&authuser=0
|
| 100 |
+
opt.gen_checkpoint = "gen.pth"
|
| 101 |
+
|
| 102 |
+
opt.cuda = False
|
| 103 |
+
|
| 104 |
+
load_checkpoint(tocg, opt.tocg_checkpoint,opt)
|
| 105 |
+
load_checkpoint_G(generator, opt.gen_checkpoint,opt)
|
| 106 |
+
|
| 107 |
+
#### def test scope
|
| 108 |
+
tocg.eval()
|
| 109 |
+
generator.eval()
|
| 110 |
+
|
| 111 |
+
opt.data_list = "test_pairs.txt"
|
| 112 |
+
opt.datamode = "test"
|
| 113 |
+
opt.dataroot = "zalando-hd-resized"
|
| 114 |
+
|
| 115 |
+
opt.batch_size = 1
|
| 116 |
+
opt.shuffle = False
|
| 117 |
+
opt.workers = 1
|
| 118 |
+
opt.semantic_nc = 13
|
| 119 |
+
|
| 120 |
+
test_dataset = CPDatasetTest(opt)
|
| 121 |
+
test_loader = CPDataLoader(opt, test_dataset)
|
| 122 |
+
|
| 123 |
+
def construct_images(img_tensors, img_names = [None]):
|
| 124 |
+
#for img_tensor, img_name in zip(img_tensors, img_names):
|
| 125 |
+
for img_tensor, img_name in zip(img_tensors, img_names):
|
| 126 |
+
tensor = (img_tensor.clone() + 1) * 0.5 * 255
|
| 127 |
+
tensor = tensor.cpu().clamp(0, 255)
|
| 128 |
+
try:
|
| 129 |
+
array = tensor.numpy().astype('uint8')
|
| 130 |
+
except:
|
| 131 |
+
array = tensor.detach().numpy().astype('uint8')
|
| 132 |
+
|
| 133 |
+
if array.shape[0] == 1:
|
| 134 |
+
array = array.squeeze(0)
|
| 135 |
+
elif array.shape[0] == 3:
|
| 136 |
+
array = array.swapaxes(0, 1).swapaxes(1, 2)
|
| 137 |
+
|
| 138 |
+
im = Image.fromarray(array)
|
| 139 |
+
return im
|
| 140 |
+
|
| 141 |
+
def single_pred_slim_func(opt, inputs, tocg = tocg, generator = generator):
|
| 142 |
+
gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
|
| 143 |
+
if opt.cuda:
|
| 144 |
+
gauss = gauss.cuda()
|
| 145 |
+
|
| 146 |
+
# Model
|
| 147 |
+
if opt.cuda:
|
| 148 |
+
tocg.cuda()
|
| 149 |
+
tocg.eval()
|
| 150 |
+
generator.eval()
|
| 151 |
+
|
| 152 |
+
num = 0
|
| 153 |
+
iter_start_time = time.time()
|
| 154 |
+
with torch.no_grad():
|
| 155 |
+
for inputs in [inputs]:
|
| 156 |
+
if opt.cuda :
|
| 157 |
+
#pose_map = inputs['pose'].cuda()
|
| 158 |
+
pre_clothes_mask = inputs['cloth_mask'][opt.datasetting].cuda()
|
| 159 |
+
#label = inputs['parse']
|
| 160 |
+
parse_agnostic = inputs['parse_agnostic']
|
| 161 |
+
agnostic = inputs['agnostic'].cuda()
|
| 162 |
+
clothes = inputs['cloth'][opt.datasetting].cuda() # target cloth
|
| 163 |
+
densepose = inputs['densepose'].cuda()
|
| 164 |
+
#im = inputs['image']
|
| 165 |
+
#input_label, input_parse_agnostic = label.cuda(), parse_agnostic.cuda()
|
| 166 |
+
input_parse_agnostic = parse_agnostic.cuda()
|
| 167 |
+
pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
|
| 168 |
+
else :
|
| 169 |
+
#pose_map = inputs['pose']
|
| 170 |
+
pre_clothes_mask = inputs['cloth_mask'][opt.datasetting]
|
| 171 |
+
#label = inputs['parse']
|
| 172 |
+
parse_agnostic = inputs['parse_agnostic']
|
| 173 |
+
agnostic = inputs['agnostic']
|
| 174 |
+
clothes = inputs['cloth'][opt.datasetting] # target cloth
|
| 175 |
+
densepose = inputs['densepose']
|
| 176 |
+
#im = inputs['image']
|
| 177 |
+
#input_label, input_parse_agnostic = label, parse_agnostic
|
| 178 |
+
input_parse_agnostic = parse_agnostic
|
| 179 |
+
pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float))
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# down
|
| 184 |
+
#pose_map_down = F.interpolate(pose_map, size=(256, 192), mode='bilinear')
|
| 185 |
+
pre_clothes_mask_down = F.interpolate(pre_clothes_mask, size=(256, 192), mode='nearest')
|
| 186 |
+
#input_label_down = F.interpolate(input_label, size=(256, 192), mode='bilinear')
|
| 187 |
+
input_parse_agnostic_down = F.interpolate(input_parse_agnostic, size=(256, 192), mode='nearest')
|
| 188 |
+
#agnostic_down = F.interpolate(agnostic, size=(256, 192), mode='nearest')
|
| 189 |
+
clothes_down = F.interpolate(clothes, size=(256, 192), mode='bilinear')
|
| 190 |
+
densepose_down = F.interpolate(densepose, size=(256, 192), mode='bilinear')
|
| 191 |
+
|
| 192 |
+
shape = pre_clothes_mask.shape
|
| 193 |
+
|
| 194 |
+
# multi-task inputs
|
| 195 |
+
input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
|
| 196 |
+
input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)
|
| 197 |
+
|
| 198 |
+
# forward
|
| 199 |
+
flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(opt,input1, input2)
|
| 200 |
+
|
| 201 |
+
# warped cloth mask one hot
|
| 202 |
+
if opt.cuda :
|
| 203 |
+
warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
|
| 204 |
+
else :
|
| 205 |
+
warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float))
|
| 206 |
+
|
| 207 |
+
if opt.clothmask_composition != 'no_composition':
|
| 208 |
+
if opt.clothmask_composition == 'detach':
|
| 209 |
+
cloth_mask = torch.ones_like(fake_segmap)
|
| 210 |
+
cloth_mask[:,3:4, :, :] = warped_cm_onehot
|
| 211 |
+
fake_segmap = fake_segmap * cloth_mask
|
| 212 |
+
|
| 213 |
+
if opt.clothmask_composition == 'warp_grad':
|
| 214 |
+
cloth_mask = torch.ones_like(fake_segmap)
|
| 215 |
+
cloth_mask[:,3:4, :, :] = warped_clothmask_paired
|
| 216 |
+
fake_segmap = fake_segmap * cloth_mask
|
| 217 |
+
|
| 218 |
+
# make generator input parse map
|
| 219 |
+
fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(opt.fine_height, opt.fine_width), mode='bilinear'))
|
| 220 |
+
fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
|
| 221 |
+
|
| 222 |
+
if opt.cuda :
|
| 223 |
+
old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
|
| 224 |
+
else:
|
| 225 |
+
old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_()
|
| 226 |
+
old_parse.scatter_(1, fake_parse, 1.0)
|
| 227 |
+
|
| 228 |
+
labels = {
|
| 229 |
+
0: ['background', [0]],
|
| 230 |
+
1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
|
| 231 |
+
2: ['upper', [3]],
|
| 232 |
+
3: ['hair', [1]],
|
| 233 |
+
4: ['left_arm', [5]],
|
| 234 |
+
5: ['right_arm', [6]],
|
| 235 |
+
6: ['noise', [12]]
|
| 236 |
+
}
|
| 237 |
+
if opt.cuda :
|
| 238 |
+
parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
|
| 239 |
+
else:
|
| 240 |
+
parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_()
|
| 241 |
+
for i in range(len(labels)):
|
| 242 |
+
for label in labels[i][1]:
|
| 243 |
+
parse[:, i] += old_parse[:, label]
|
| 244 |
+
|
| 245 |
+
# warped cloth
|
| 246 |
+
N, _, iH, iW = clothes.shape
|
| 247 |
+
flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
|
| 248 |
+
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((96 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((128 - 1.0) / 2.0)], 3)
|
| 249 |
+
|
| 250 |
+
grid = make_grid(N, iH, iW,opt)
|
| 251 |
+
warped_grid = grid + flow_norm
|
| 252 |
+
warped_cloth = F.grid_sample(clothes, warped_grid, padding_mode='border')
|
| 253 |
+
warped_clothmask = F.grid_sample(pre_clothes_mask, warped_grid, padding_mode='border')
|
| 254 |
+
if opt.occlusion:
|
| 255 |
+
warped_clothmask = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask)
|
| 256 |
+
warped_cloth = warped_cloth * warped_clothmask + torch.ones_like(warped_cloth) * (1-warped_clothmask)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
output = generator(torch.cat((agnostic, densepose, warped_cloth), dim=1), parse)
|
| 260 |
+
# save output
|
| 261 |
+
return output
|
| 262 |
+
#save_images(output, unpaired_names, output_dir)
|
| 263 |
+
#num += shape[0]
|
| 264 |
+
#print(num)
|
| 265 |
+
|
| 266 |
+
opt.clothmask_composition = "warp_grad"
|
| 267 |
+
opt.occlusion = False
|
| 268 |
+
opt.datasetting = "unpaired"
|
| 269 |
+
|
| 270 |
+
def read_img_and_trans(dataset ,opt ,img_path):
|
| 271 |
+
if type(img_path) in [type("")]:
|
| 272 |
+
im = Image.open(img_path)
|
| 273 |
+
else:
|
| 274 |
+
im = img_path
|
| 275 |
+
im = transforms.Resize(opt.fine_width, interpolation=2)(im)
|
| 276 |
+
im = dataset.transform(im)
|
| 277 |
+
return im
|
| 278 |
+
|
| 279 |
+
import sys
|
| 280 |
+
sys.path.insert(0, "fashion-eye-try-on")
|
| 281 |
+
|
| 282 |
+
import os
|
| 283 |
+
from PIL import Image
|
| 284 |
+
import gradio as gr
|
| 285 |
+
from cloth_segmentation import generate_cloth_mask
|
| 286 |
+
|
| 287 |
+
def generate_cloth_mask_and_display(cloth_img):
|
| 288 |
+
path = 'fashion-eye-try-on/cloth/cloth.jpg'
|
| 289 |
+
if os.path.exists(path):
|
| 290 |
+
os.remove(path)
|
| 291 |
+
cloth_img.save(path)
|
| 292 |
+
try:
|
| 293 |
+
# os.system('.\cloth_segmentation\generate_cloth_mask.py')
|
| 294 |
+
generate_cloth_mask()
|
| 295 |
+
except Exception as e:
|
| 296 |
+
print(e)
|
| 297 |
+
return
|
| 298 |
+
cloth_mask_img = Image.open("fashion-eye-try-on/cloth_mask/cloth.jpg")
|
| 299 |
+
return cloth_mask_img
|
| 300 |
+
|
| 301 |
+
def take_human_feature_from_dataset(dataset, idx):
|
| 302 |
+
inputs_upper = list(torch.utils.data.DataLoader(
|
| 303 |
+
[dataset[idx]], batch_size=1))[0]
|
| 304 |
+
return {
|
| 305 |
+
"parse_agnostic": inputs_upper["parse_agnostic"],
|
| 306 |
+
"agnostic": inputs_upper["agnostic"],
|
| 307 |
+
"densepose": inputs_upper["densepose"],
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
def take_all_feature_with_dataset(cloth_img_path, idx, opt = opt, dataset = test_dataset, only_show_human = False):
|
| 311 |
+
if type(cloth_img_path) != type(""):
|
| 312 |
+
assert hasattr(cloth_img_path, "save")
|
| 313 |
+
cloth_img_path.save("tmp_cloth.jpg")
|
| 314 |
+
cloth_img_path = "tmp_cloth.jpg"
|
| 315 |
+
assert type(cloth_img_path) == type("")
|
| 316 |
+
inputs_upper_dict = take_human_feature_from_dataset(dataset, idx)
|
| 317 |
+
if only_show_human:
|
| 318 |
+
return Image.fromarray((inputs_upper_dict["densepose"][0].numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
|
| 319 |
+
cloth_readed = read_img_and_trans(dataset, opt,
|
| 320 |
+
cloth_img_path
|
| 321 |
+
)
|
| 322 |
+
#assert ((cloth_readed - inputs_upper["cloth"][opt.datasetting][0]) ** 2).sum().numpy() < 1e-15
|
| 323 |
+
cloth_input = {
|
| 324 |
+
opt.datasetting: cloth_readed[None,:]
|
| 325 |
+
}
|
| 326 |
+
mask_img = generate_cloth_mask_and_display(
|
| 327 |
+
Image.open(
|
| 328 |
+
cloth_img_path
|
| 329 |
+
)
|
| 330 |
+
)
|
| 331 |
+
cloth_mask_input = {
|
| 332 |
+
opt.datasetting:
|
| 333 |
+
torch.Tensor((np.asarray(mask_img) / 255))[None, None, :]
|
| 334 |
+
}
|
| 335 |
+
inputs_upper_dict["cloth"] = cloth_input
|
| 336 |
+
inputs_upper_dict["cloth_mask"] = cloth_mask_input
|
| 337 |
+
return inputs_upper_dict
|
| 338 |
+
|
| 339 |
+
def pred_func(cloth_img, pidx
|
| 340 |
+
):
|
| 341 |
+
idx = int(pidx)
|
| 342 |
+
im = cloth_img
|
| 343 |
+
|
| 344 |
+
#### truly input
|
| 345 |
+
inputs_upper_dict = take_all_feature_with_dataset(
|
| 346 |
+
im, idx, only_show_human = False)
|
| 347 |
+
|
| 348 |
+
output_slim = single_pred_slim_func(opt, inputs_upper_dict)
|
| 349 |
+
output_img = construct_images(output_slim)
|
| 350 |
+
return output_img
|
| 351 |
+
|
| 352 |
+
option = st.selectbox(
|
| 353 |
+
"Choose cloth image or Upload cloth image",
|
| 354 |
+
("Choose", "Upload", )
|
| 355 |
+
)
|
| 356 |
+
if type(option) != type(""):
|
| 357 |
+
option = "Choose"
|
| 358 |
+
|
| 359 |
+
img = None
|
| 360 |
+
uploaded_file = None
|
| 361 |
+
if option == "Upload":
|
| 362 |
+
# To read file as bytes:
|
| 363 |
+
uploaded_file = st.file_uploader("Upload img")
|
| 364 |
+
if uploaded_file is not None:
|
| 365 |
+
bytes_data = uploaded_file.getvalue()
|
| 366 |
+
img = Image.open(BytesIO(bytes_data))
|
| 367 |
+
cloth_img = img.convert("RGB").resize((256 + 128, 512))
|
| 368 |
+
st.image(cloth_img)
|
| 369 |
+
uploaded_file = st.selectbox(
|
| 370 |
+
"Have Choose the image",
|
| 371 |
+
("Wait", "Have Done")
|
| 372 |
+
)
|
| 373 |
+
else:
|
| 374 |
+
img = image_select("Choose img", demo_images)
|
| 375 |
+
#img = Image.open(img)
|
| 376 |
+
cloth_img = img.convert("RGB").resize((256 + 128, 512))
|
| 377 |
+
st.image(cloth_img)
|
| 378 |
+
uploaded_file = st.selectbox(
|
| 379 |
+
"Have Choose the image",
|
| 380 |
+
("Wait", "Have Done")
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
if img is not None and (uploaded_file is not "Wait" and uploaded_file is not None):
|
| 384 |
+
cloth_img = img.convert("RGB").resize((768, 1024))
|
| 385 |
+
#pidx = 44
|
| 386 |
+
pidx_index_list = [44, 84, 67]
|
| 387 |
+
poeses = []
|
| 388 |
+
for idx in range(len(pidx_index_list)):
|
| 389 |
+
poeses.append(
|
| 390 |
+
take_all_feature_with_dataset(
|
| 391 |
+
cloth_img, pidx_index_list[idx], only_show_human = True)
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
col1, col2, col3 = st.columns(3)
|
| 395 |
+
|
| 396 |
+
with col1:
|
| 397 |
+
st.header("Pose 0")
|
| 398 |
+
pose_img = poeses[0]
|
| 399 |
+
st.image(pose_img)
|
| 400 |
+
b = pred_func(cloth_img, pidx_index_list[0])
|
| 401 |
+
st.image(b)
|
| 402 |
+
|
| 403 |
+
with col2:
|
| 404 |
+
st.header("Pose 1")
|
| 405 |
+
pose_img = poeses[1]
|
| 406 |
+
st.image(pose_img)
|
| 407 |
+
b = pred_func(cloth_img, pidx_index_list[1])
|
| 408 |
+
st.image(b)
|
| 409 |
+
|
| 410 |
+
with col3:
|
| 411 |
+
st.header("Pose 2")
|
| 412 |
+
pose_img = poeses[2]
|
| 413 |
+
st.image(pose_img)
|
| 414 |
+
b = pred_func(cloth_img, pidx_index_list[2])
|
| 415 |
+
st.image(b)
|
cp_dataset_test.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.utils.data as data
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
|
| 5 |
+
from PIL import Image, ImageDraw
|
| 6 |
+
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import numpy as np
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CPDatasetTest(data.Dataset):
|
| 13 |
+
"""
|
| 14 |
+
Test Dataset for CP-VTON.
|
| 15 |
+
"""
|
| 16 |
+
def __init__(self, opt):
|
| 17 |
+
super(CPDatasetTest, self).__init__()
|
| 18 |
+
# base setting
|
| 19 |
+
self.opt = opt
|
| 20 |
+
self.root = opt.dataroot
|
| 21 |
+
self.datamode = opt.datamode # train or test or self-defined
|
| 22 |
+
self.data_list = opt.data_list
|
| 23 |
+
self.fine_height = opt.fine_height
|
| 24 |
+
self.fine_width = opt.fine_width
|
| 25 |
+
self.semantic_nc = opt.semantic_nc
|
| 26 |
+
self.data_path = osp.join(opt.dataroot, opt.datamode)
|
| 27 |
+
self.transform = transforms.Compose([ \
|
| 28 |
+
transforms.ToTensor(), \
|
| 29 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
| 30 |
+
|
| 31 |
+
# load data list
|
| 32 |
+
im_names = []
|
| 33 |
+
c_names = []
|
| 34 |
+
with open(osp.join(opt.dataroot, opt.data_list), 'r') as f:
|
| 35 |
+
for line in f.readlines():
|
| 36 |
+
im_name, c_name = line.strip().split()
|
| 37 |
+
im_names.append(im_name)
|
| 38 |
+
c_names.append(c_name)
|
| 39 |
+
|
| 40 |
+
self.im_names = im_names
|
| 41 |
+
self.c_names = dict()
|
| 42 |
+
self.c_names['paired'] = im_names
|
| 43 |
+
self.c_names['unpaired'] = c_names
|
| 44 |
+
|
| 45 |
+
def name(self):
|
| 46 |
+
return "CPDataset"
|
| 47 |
+
def get_agnostic(self, im, im_parse, pose_data):
|
| 48 |
+
parse_array = np.array(im_parse)
|
| 49 |
+
parse_head = ((parse_array == 4).astype(np.float32) +
|
| 50 |
+
(parse_array == 13).astype(np.float32))
|
| 51 |
+
parse_lower = ((parse_array == 9).astype(np.float32) +
|
| 52 |
+
(parse_array == 12).astype(np.float32) +
|
| 53 |
+
(parse_array == 16).astype(np.float32) +
|
| 54 |
+
(parse_array == 17).astype(np.float32) +
|
| 55 |
+
(parse_array == 18).astype(np.float32) +
|
| 56 |
+
(parse_array == 19).astype(np.float32))
|
| 57 |
+
|
| 58 |
+
agnostic = im.copy()
|
| 59 |
+
agnostic_draw = ImageDraw.Draw(agnostic)
|
| 60 |
+
|
| 61 |
+
length_a = np.linalg.norm(pose_data[5] - pose_data[2])
|
| 62 |
+
length_b = np.linalg.norm(pose_data[12] - pose_data[9])
|
| 63 |
+
point = (pose_data[9] + pose_data[12]) / 2
|
| 64 |
+
pose_data[9] = point + (pose_data[9] - point) / length_b * length_a
|
| 65 |
+
pose_data[12] = point + (pose_data[12] - point) / length_b * length_a
|
| 66 |
+
|
| 67 |
+
r = int(length_a / 16) + 1
|
| 68 |
+
|
| 69 |
+
# mask torso
|
| 70 |
+
for i in [9, 12]:
|
| 71 |
+
pointx, pointy = pose_data[i]
|
| 72 |
+
agnostic_draw.ellipse((pointx-r*3, pointy-r*6, pointx+r*3, pointy+r*6), 'gray', 'gray')
|
| 73 |
+
agnostic_draw.line([tuple(pose_data[i]) for i in [2, 9]], 'gray', width=r*6)
|
| 74 |
+
agnostic_draw.line([tuple(pose_data[i]) for i in [5, 12]], 'gray', width=r*6)
|
| 75 |
+
agnostic_draw.line([tuple(pose_data[i]) for i in [9, 12]], 'gray', width=r*12)
|
| 76 |
+
agnostic_draw.polygon([tuple(pose_data[i]) for i in [2, 5, 12, 9]], 'gray', 'gray')
|
| 77 |
+
|
| 78 |
+
# mask neck
|
| 79 |
+
pointx, pointy = pose_data[1]
|
| 80 |
+
agnostic_draw.rectangle((pointx-r*5, pointy-r*9, pointx+r*5, pointy), 'gray', 'gray')
|
| 81 |
+
|
| 82 |
+
# mask arms
|
| 83 |
+
agnostic_draw.line([tuple(pose_data[i]) for i in [2, 5]], 'gray', width=r*12)
|
| 84 |
+
for i in [2, 5]:
|
| 85 |
+
pointx, pointy = pose_data[i]
|
| 86 |
+
agnostic_draw.ellipse((pointx-r*5, pointy-r*6, pointx+r*5, pointy+r*6), 'gray', 'gray')
|
| 87 |
+
for i in [3, 4, 6, 7]:
|
| 88 |
+
if (pose_data[i-1, 0] == 0.0 and pose_data[i-1, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
|
| 89 |
+
continue
|
| 90 |
+
agnostic_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'gray', width=r*10)
|
| 91 |
+
pointx, pointy = pose_data[i]
|
| 92 |
+
agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
|
| 93 |
+
|
| 94 |
+
for parse_id, pose_ids in [(14, [5, 6, 7]), (15, [2, 3, 4])]:
|
| 95 |
+
mask_arm = Image.new('L', (768, 1024), 'white')
|
| 96 |
+
mask_arm_draw = ImageDraw.Draw(mask_arm)
|
| 97 |
+
pointx, pointy = pose_data[pose_ids[0]]
|
| 98 |
+
mask_arm_draw.ellipse((pointx-r*5, pointy-r*6, pointx+r*5, pointy+r*6), 'black', 'black')
|
| 99 |
+
for i in pose_ids[1:]:
|
| 100 |
+
if (pose_data[i-1, 0] == 0.0 and pose_data[i-1, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
|
| 101 |
+
continue
|
| 102 |
+
mask_arm_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'black', width=r*10)
|
| 103 |
+
pointx, pointy = pose_data[i]
|
| 104 |
+
if i != pose_ids[-1]:
|
| 105 |
+
mask_arm_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'black', 'black')
|
| 106 |
+
mask_arm_draw.ellipse((pointx-r*4, pointy-r*4, pointx+r*4, pointy+r*4), 'black', 'black')
|
| 107 |
+
|
| 108 |
+
parse_arm = (np.array(mask_arm) / 255) * (parse_array == parse_id).astype(np.float32)
|
| 109 |
+
agnostic.paste(im, None, Image.fromarray(np.uint8(parse_arm * 255), 'L'))
|
| 110 |
+
|
| 111 |
+
agnostic.paste(im, None, Image.fromarray(np.uint8(parse_head * 255), 'L'))
|
| 112 |
+
agnostic.paste(im, None, Image.fromarray(np.uint8(parse_lower * 255), 'L'))
|
| 113 |
+
return agnostic
|
| 114 |
+
def __getitem__(self, index):
|
| 115 |
+
im_name = self.im_names[index]
|
| 116 |
+
c_name = {}
|
| 117 |
+
c = {}
|
| 118 |
+
cm = {}
|
| 119 |
+
for key in self.c_names:
|
| 120 |
+
c_name[key] = self.c_names[key][index]
|
| 121 |
+
c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')
|
| 122 |
+
c[key] = transforms.Resize(self.fine_width, interpolation=2)(c[key])
|
| 123 |
+
cm[key] = Image.open(osp.join(self.data_path, 'cloth-mask', c_name[key]))
|
| 124 |
+
cm[key] = transforms.Resize(self.fine_width, interpolation=0)(cm[key])
|
| 125 |
+
|
| 126 |
+
c[key] = self.transform(c[key]) # [-1,1]
|
| 127 |
+
cm_array = np.array(cm[key])
|
| 128 |
+
cm_array = (cm_array >= 128).astype(np.float32)
|
| 129 |
+
cm[key] = torch.from_numpy(cm_array) # [0,1]
|
| 130 |
+
cm[key].unsqueeze_(0)
|
| 131 |
+
|
| 132 |
+
# person image
|
| 133 |
+
im_pil_big = Image.open(osp.join(self.data_path, 'image', im_name))
|
| 134 |
+
im_pil = transforms.Resize(self.fine_width, interpolation=2)(im_pil_big)
|
| 135 |
+
|
| 136 |
+
im = self.transform(im_pil)
|
| 137 |
+
|
| 138 |
+
# load parsing image
|
| 139 |
+
parse_name = im_name.replace('.jpg', '.png')
|
| 140 |
+
im_parse_pil_big = Image.open(osp.join(self.data_path, 'image-parse-v3', parse_name))
|
| 141 |
+
im_parse_pil = transforms.Resize(self.fine_width, interpolation=0)(im_parse_pil_big)
|
| 142 |
+
parse = torch.from_numpy(np.array(im_parse_pil)[None]).long()
|
| 143 |
+
im_parse = self.transform(im_parse_pil.convert('RGB'))
|
| 144 |
+
|
| 145 |
+
labels = {
|
| 146 |
+
0: ['background', [0, 10]],
|
| 147 |
+
1: ['hair', [1, 2]],
|
| 148 |
+
2: ['face', [4, 13]],
|
| 149 |
+
3: ['upper', [5, 6, 7]],
|
| 150 |
+
4: ['bottom', [9, 12]],
|
| 151 |
+
5: ['left_arm', [14]],
|
| 152 |
+
6: ['right_arm', [15]],
|
| 153 |
+
7: ['left_leg', [16]],
|
| 154 |
+
8: ['right_leg', [17]],
|
| 155 |
+
9: ['left_shoe', [18]],
|
| 156 |
+
10: ['right_shoe', [19]],
|
| 157 |
+
11: ['socks', [8]],
|
| 158 |
+
12: ['noise', [3, 11]]
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
parse_map = torch.FloatTensor(20, self.fine_height, self.fine_width).zero_()
|
| 162 |
+
parse_map = parse_map.scatter_(0, parse, 1.0)
|
| 163 |
+
new_parse_map = torch.FloatTensor(self.semantic_nc, self.fine_height, self.fine_width).zero_()
|
| 164 |
+
|
| 165 |
+
for i in range(len(labels)):
|
| 166 |
+
for label in labels[i][1]:
|
| 167 |
+
new_parse_map[i] += parse_map[label]
|
| 168 |
+
|
| 169 |
+
parse_onehot = torch.FloatTensor(1, self.fine_height, self.fine_width).zero_()
|
| 170 |
+
for i in range(len(labels)):
|
| 171 |
+
for label in labels[i][1]:
|
| 172 |
+
parse_onehot[0] += parse_map[label] * i
|
| 173 |
+
|
| 174 |
+
# load image-parse-agnostic
|
| 175 |
+
image_parse_agnostic = Image.open(osp.join(self.data_path, 'image-parse-agnostic-v3.2', parse_name))
|
| 176 |
+
image_parse_agnostic = transforms.Resize(self.fine_width, interpolation=0)(image_parse_agnostic)
|
| 177 |
+
parse_agnostic = torch.from_numpy(np.array(image_parse_agnostic)[None]).long()
|
| 178 |
+
image_parse_agnostic = self.transform(image_parse_agnostic.convert('RGB'))
|
| 179 |
+
|
| 180 |
+
parse_agnostic_map = torch.FloatTensor(20, self.fine_height, self.fine_width).zero_()
|
| 181 |
+
parse_agnostic_map = parse_agnostic_map.scatter_(0, parse_agnostic, 1.0)
|
| 182 |
+
new_parse_agnostic_map = torch.FloatTensor(self.semantic_nc, self.fine_height, self.fine_width).zero_()
|
| 183 |
+
for i in range(len(labels)):
|
| 184 |
+
for label in labels[i][1]:
|
| 185 |
+
new_parse_agnostic_map[i] += parse_agnostic_map[label]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
# parse cloth & parse cloth mask
|
| 189 |
+
pcm = new_parse_map[3:4]
|
| 190 |
+
im_c = im * pcm + (1 - pcm)
|
| 191 |
+
|
| 192 |
+
# load pose points
|
| 193 |
+
pose_name = im_name.replace('.jpg', '_rendered.png')
|
| 194 |
+
pose_map = Image.open(osp.join(self.data_path, 'openpose_img', pose_name))
|
| 195 |
+
pose_map = transforms.Resize(self.fine_width, interpolation=2)(pose_map)
|
| 196 |
+
pose_map = self.transform(pose_map) # [-1,1]
|
| 197 |
+
|
| 198 |
+
pose_name = im_name.replace('.jpg', '_keypoints.json')
|
| 199 |
+
with open(osp.join(self.data_path, 'openpose_json', pose_name), 'r') as f:
|
| 200 |
+
pose_label = json.load(f)
|
| 201 |
+
pose_data = pose_label['people'][0]['pose_keypoints_2d']
|
| 202 |
+
pose_data = np.array(pose_data)
|
| 203 |
+
pose_data = pose_data.reshape((-1, 3))[:, :2]
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# load densepose
|
| 207 |
+
densepose_name = im_name.replace('image', 'image-densepose')
|
| 208 |
+
densepose_map = Image.open(osp.join(self.data_path, 'image-densepose', densepose_name))
|
| 209 |
+
densepose_map = transforms.Resize(self.fine_width, interpolation=2)(densepose_map)
|
| 210 |
+
densepose_map = self.transform(densepose_map) # [-1,1]
|
| 211 |
+
agnostic = self.get_agnostic(im_pil_big, im_parse_pil_big, pose_data)
|
| 212 |
+
agnostic = transforms.Resize(self.fine_width, interpolation=2)(agnostic)
|
| 213 |
+
agnostic = self.transform(agnostic)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
result = {
|
| 218 |
+
'c_name': c_name, # for visualization
|
| 219 |
+
'im_name': im_name, # for visualization or ground truth
|
| 220 |
+
# intput 1 (clothfloww)
|
| 221 |
+
'cloth': c, # for input
|
| 222 |
+
'cloth_mask': cm, # for input
|
| 223 |
+
# intput 2 (segnet)
|
| 224 |
+
'parse_agnostic': new_parse_agnostic_map,
|
| 225 |
+
'densepose': densepose_map,
|
| 226 |
+
'pose': pose_map, # for conditioning
|
| 227 |
+
# GT
|
| 228 |
+
'parse_onehot' : parse_onehot, # Cross Entropy
|
| 229 |
+
'parse': new_parse_map, # GAN Loss real
|
| 230 |
+
'pcm': pcm, # L1 Loss & vis
|
| 231 |
+
'parse_cloth': im_c, # VGG Loss & vis
|
| 232 |
+
# visualization
|
| 233 |
+
'image': im, # for visualization
|
| 234 |
+
'agnostic' : agnostic
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
return result
|
| 238 |
+
|
| 239 |
+
def __len__(self):
|
| 240 |
+
return len(self.im_names)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class CPDataLoader(object):
|
| 244 |
+
def __init__(self, opt, dataset):
|
| 245 |
+
super(CPDataLoader, self).__init__()
|
| 246 |
+
if opt.shuffle :
|
| 247 |
+
train_sampler = torch.utils.data.sampler.RandomSampler(dataset)
|
| 248 |
+
else:
|
| 249 |
+
train_sampler = None
|
| 250 |
+
|
| 251 |
+
self.data_loader = torch.utils.data.DataLoader(
|
| 252 |
+
dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
|
| 253 |
+
num_workers=opt.workers, pin_memory=True, drop_last=True, sampler=train_sampler)
|
| 254 |
+
self.dataset = dataset
|
| 255 |
+
self.data_iter = self.data_loader.__iter__()
|
| 256 |
+
|
| 257 |
+
def next_batch(self):
|
| 258 |
+
try:
|
| 259 |
+
batch = self.data_iter.__next__()
|
| 260 |
+
except StopIteration:
|
| 261 |
+
self.data_iter = self.data_loader.__iter__()
|
| 262 |
+
batch = self.data_iter.__next__()
|
| 263 |
+
|
| 264 |
+
return batch
|
network_generator.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import init
|
| 5 |
+
from torch.nn.utils import spectral_norm
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseNetwork(nn.Module):
|
| 10 |
+
def __init__(self):
|
| 11 |
+
super(BaseNetwork, self).__init__()
|
| 12 |
+
|
| 13 |
+
def print_network(self):
|
| 14 |
+
num_params = 0
|
| 15 |
+
for param in self.parameters():
|
| 16 |
+
num_params += param.numel()
|
| 17 |
+
print("Network [{}] was created. Total number of parameters: {:.1f} million. "
|
| 18 |
+
"To see the architecture, do print(network).".format(self.__class__.__name__, num_params / 1000000))
|
| 19 |
+
|
| 20 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
| 21 |
+
def init_func(m):
|
| 22 |
+
classname = m.__class__.__name__
|
| 23 |
+
if 'BatchNorm2d' in classname:
|
| 24 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
| 25 |
+
init.normal_(m.weight.data, 1.0, gain)
|
| 26 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 27 |
+
init.constant_(m.bias.data, 0.0)
|
| 28 |
+
elif ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
|
| 29 |
+
if init_type == 'normal':
|
| 30 |
+
init.normal_(m.weight.data, 0.0, gain)
|
| 31 |
+
elif init_type == 'xavier':
|
| 32 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
| 33 |
+
elif init_type == 'xavier_uniform':
|
| 34 |
+
init.xavier_uniform_(m.weight.data, gain=1.0)
|
| 35 |
+
elif init_type == 'kaiming':
|
| 36 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
| 37 |
+
elif init_type == 'orthogonal':
|
| 38 |
+
init.orthogonal_(m.weight.data, gain=gain)
|
| 39 |
+
elif init_type == 'none': # uses pytorch's default init method
|
| 40 |
+
m.reset_parameters()
|
| 41 |
+
else:
|
| 42 |
+
raise NotImplementedError("initialization method '{}' is not implemented".format(init_type))
|
| 43 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 44 |
+
init.constant_(m.bias.data, 0.0)
|
| 45 |
+
|
| 46 |
+
self.apply(init_func)
|
| 47 |
+
|
| 48 |
+
def forward(self, *inputs):
|
| 49 |
+
pass
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MaskNorm(nn.Module):
|
| 53 |
+
def __init__(self, norm_nc):
|
| 54 |
+
super(MaskNorm, self).__init__()
|
| 55 |
+
|
| 56 |
+
self.norm_layer = nn.InstanceNorm2d(norm_nc, affine=False)
|
| 57 |
+
|
| 58 |
+
def normalize_region(self, region, mask):
|
| 59 |
+
b, c, h, w = region.size()
|
| 60 |
+
|
| 61 |
+
num_pixels = mask.sum((2, 3), keepdim=True) # size: (b, 1, 1, 1)
|
| 62 |
+
num_pixels[num_pixels == 0] = 1
|
| 63 |
+
mu = region.sum((2, 3), keepdim=True) / num_pixels # size: (b, c, 1, 1)
|
| 64 |
+
|
| 65 |
+
normalized_region = self.norm_layer(region + (1 - mask) * mu)
|
| 66 |
+
return normalized_region * torch.sqrt(num_pixels / (h * w))
|
| 67 |
+
|
| 68 |
+
def forward(self, x, mask):
|
| 69 |
+
mask = mask.detach()
|
| 70 |
+
normalized_foreground = self.normalize_region(x * mask, mask)
|
| 71 |
+
normalized_background = self.normalize_region(x * (1 - mask), 1 - mask)
|
| 72 |
+
return normalized_foreground + normalized_background
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class SPADENorm(nn.Module):
|
| 76 |
+
def __init__(self,opt, norm_type, norm_nc, label_nc):
|
| 77 |
+
super(SPADENorm, self).__init__()
|
| 78 |
+
self.param_opt=opt
|
| 79 |
+
self.noise_scale = nn.Parameter(torch.zeros(norm_nc))
|
| 80 |
+
|
| 81 |
+
assert norm_type.startswith('alias')
|
| 82 |
+
param_free_norm_type = norm_type[len('alias'):]
|
| 83 |
+
if param_free_norm_type == 'batch':
|
| 84 |
+
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
|
| 85 |
+
elif param_free_norm_type == 'instance':
|
| 86 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
| 87 |
+
elif param_free_norm_type == 'mask':
|
| 88 |
+
self.param_free_norm = MaskNorm(norm_nc)
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(
|
| 91 |
+
"'{}' is not a recognized parameter-free normalization type in SPADENorm".format(param_free_norm_type)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
nhidden = 128
|
| 95 |
+
ks = 3
|
| 96 |
+
pw = ks // 2
|
| 97 |
+
self.conv_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
|
| 98 |
+
self.conv_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
| 99 |
+
self.conv_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
| 100 |
+
|
| 101 |
+
def forward(self, x, seg, misalign_mask=None):
|
| 102 |
+
# Part 1. Generate parameter-free normalized activations.
|
| 103 |
+
b, c, h, w = x.size()
|
| 104 |
+
if self.param_opt.cuda :
|
| 105 |
+
noise = (torch.randn(b, w, h, 1).cuda() * self.noise_scale).transpose(1, 3)
|
| 106 |
+
else:
|
| 107 |
+
noise = (torch.randn(b, w, h, 1)* self.noise_scale).transpose(1, 3)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if misalign_mask is None:
|
| 111 |
+
normalized = self.param_free_norm(x + noise)
|
| 112 |
+
else:
|
| 113 |
+
normalized = self.param_free_norm(x + noise, misalign_mask)
|
| 114 |
+
|
| 115 |
+
# Part 2. Produce affine parameters conditioned on the segmentation map.
|
| 116 |
+
actv = self.conv_shared(seg)
|
| 117 |
+
gamma = self.conv_gamma(actv)
|
| 118 |
+
beta = self.conv_beta(actv)
|
| 119 |
+
|
| 120 |
+
# Apply the affine parameters.
|
| 121 |
+
output = normalized * (1 + gamma) + beta
|
| 122 |
+
return output
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class SPADEResBlock(nn.Module):
|
| 126 |
+
def __init__(self, opt, input_nc, output_nc, use_mask_norm=True):
|
| 127 |
+
super(SPADEResBlock, self).__init__()
|
| 128 |
+
self.param_opt=opt
|
| 129 |
+
self.learned_shortcut = (input_nc != output_nc)
|
| 130 |
+
middle_nc = min(input_nc, output_nc)
|
| 131 |
+
|
| 132 |
+
self.conv_0 = nn.Conv2d(input_nc, middle_nc, kernel_size=3, padding=1)
|
| 133 |
+
self.conv_1 = nn.Conv2d(middle_nc, output_nc, kernel_size=3, padding=1)
|
| 134 |
+
if self.learned_shortcut:
|
| 135 |
+
self.conv_s = nn.Conv2d(input_nc, output_nc, kernel_size=1, bias=False)
|
| 136 |
+
|
| 137 |
+
subnorm_type = opt.norm_G
|
| 138 |
+
if subnorm_type.startswith('spectral'):
|
| 139 |
+
subnorm_type = subnorm_type[len('spectral'):]
|
| 140 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
| 141 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
| 142 |
+
if self.learned_shortcut:
|
| 143 |
+
self.conv_s = spectral_norm(self.conv_s)
|
| 144 |
+
|
| 145 |
+
gen_semantic_nc = opt.gen_semantic_nc
|
| 146 |
+
if use_mask_norm:
|
| 147 |
+
subnorm_type = 'aliasmask'
|
| 148 |
+
gen_semantic_nc = gen_semantic_nc + 1
|
| 149 |
+
|
| 150 |
+
self.norm_0 = SPADENorm(opt,subnorm_type, input_nc, gen_semantic_nc)
|
| 151 |
+
self.norm_1 = SPADENorm(opt,subnorm_type, middle_nc, gen_semantic_nc)
|
| 152 |
+
if self.learned_shortcut:
|
| 153 |
+
self.norm_s = SPADENorm(opt,subnorm_type, input_nc, gen_semantic_nc)
|
| 154 |
+
|
| 155 |
+
self.relu = nn.LeakyReLU(0.2)
|
| 156 |
+
|
| 157 |
+
def shortcut(self, x, seg, misalign_mask):
|
| 158 |
+
if self.learned_shortcut:
|
| 159 |
+
return self.conv_s(self.norm_s(x, seg, misalign_mask))
|
| 160 |
+
else:
|
| 161 |
+
return x
|
| 162 |
+
|
| 163 |
+
def forward(self, x, seg, misalign_mask=None):
|
| 164 |
+
seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
|
| 165 |
+
if misalign_mask is not None:
|
| 166 |
+
misalign_mask = F.interpolate(misalign_mask, size=x.size()[2:], mode='nearest')
|
| 167 |
+
|
| 168 |
+
x_s = self.shortcut(x, seg, misalign_mask)
|
| 169 |
+
|
| 170 |
+
dx = self.conv_0(self.relu(self.norm_0(x, seg, misalign_mask)))
|
| 171 |
+
dx = self.conv_1(self.relu(self.norm_1(dx, seg, misalign_mask)))
|
| 172 |
+
output = x_s + dx
|
| 173 |
+
return output
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class SPADEGenerator(BaseNetwork):
|
| 177 |
+
def __init__(self, opt, input_nc):
|
| 178 |
+
super(SPADEGenerator, self).__init__()
|
| 179 |
+
self.num_upsampling_layers = opt.num_upsampling_layers
|
| 180 |
+
self.param_opt=opt
|
| 181 |
+
self.sh, self.sw = self.compute_latent_vector_size(opt)
|
| 182 |
+
|
| 183 |
+
nf = opt.ngf
|
| 184 |
+
self.conv_0 = nn.Conv2d(input_nc, nf * 16, kernel_size=3, padding=1)
|
| 185 |
+
for i in range(1, 8):
|
| 186 |
+
self.add_module('conv_{}'.format(i), nn.Conv2d(input_nc, 16, kernel_size=3, padding=1))
|
| 187 |
+
|
| 188 |
+
self.head_0 = SPADEResBlock(opt, nf * 16, nf * 16, use_mask_norm=False)
|
| 189 |
+
|
| 190 |
+
self.G_middle_0 = SPADEResBlock(opt, nf * 16 + 16, nf * 16, use_mask_norm=False)
|
| 191 |
+
self.G_middle_1 = SPADEResBlock(opt, nf * 16 + 16, nf * 16, use_mask_norm=False)
|
| 192 |
+
|
| 193 |
+
self.up_0 = SPADEResBlock(opt, nf * 16 + 16, nf * 8, use_mask_norm=False)
|
| 194 |
+
self.up_1 = SPADEResBlock(opt, nf * 8 + 16, nf * 4, use_mask_norm=False)
|
| 195 |
+
self.up_2 = SPADEResBlock(opt, nf * 4 + 16, nf * 2, use_mask_norm=False)
|
| 196 |
+
self.up_3 = SPADEResBlock(opt, nf * 2 + 16, nf * 1, use_mask_norm=False)
|
| 197 |
+
if self.num_upsampling_layers == 'most':
|
| 198 |
+
self.up_4 = SPADEResBlock(opt, nf * 1 + 16, nf // 2, use_mask_norm=False)
|
| 199 |
+
nf = nf // 2
|
| 200 |
+
|
| 201 |
+
self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
|
| 202 |
+
|
| 203 |
+
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
| 204 |
+
self.relu = nn.LeakyReLU(0.2)
|
| 205 |
+
self.tanh = nn.Tanh()
|
| 206 |
+
|
| 207 |
+
def compute_latent_vector_size(self, opt):
|
| 208 |
+
if self.num_upsampling_layers == 'normal':
|
| 209 |
+
num_up_layers = 5
|
| 210 |
+
elif self.num_upsampling_layers == 'more':
|
| 211 |
+
num_up_layers = 6
|
| 212 |
+
elif self.num_upsampling_layers == 'most':
|
| 213 |
+
num_up_layers = 7
|
| 214 |
+
else:
|
| 215 |
+
raise ValueError("opt.num_upsampling_layers '{}' is not recognized".format(self.num_upsampling_layers))
|
| 216 |
+
|
| 217 |
+
sh = opt.fine_height // 2**num_up_layers
|
| 218 |
+
sw = opt.fine_width // 2**num_up_layers
|
| 219 |
+
return sh, sw
|
| 220 |
+
|
| 221 |
+
def forward(self, x, seg):
|
| 222 |
+
samples = [F.interpolate(x, size=(self.sh * 2**i, self.sw * 2**i), mode='nearest') for i in range(8)]
|
| 223 |
+
features = [self._modules['conv_{}'.format(i)](samples[i]) for i in range(8)]
|
| 224 |
+
|
| 225 |
+
x = self.head_0(features[0], seg)
|
| 226 |
+
x = self.up(x)
|
| 227 |
+
x = self.G_middle_0(torch.cat((x, features[1]), 1), seg)
|
| 228 |
+
if self.num_upsampling_layers in ['more', 'most']:
|
| 229 |
+
x = self.up(x)
|
| 230 |
+
x = self.G_middle_1(torch.cat((x, features[2]), 1), seg)
|
| 231 |
+
|
| 232 |
+
x = self.up(x)
|
| 233 |
+
x = self.up_0(torch.cat((x, features[3]), 1), seg)
|
| 234 |
+
x = self.up(x)
|
| 235 |
+
x = self.up_1(torch.cat((x, features[4]), 1), seg)
|
| 236 |
+
x = self.up(x)
|
| 237 |
+
x = self.up_2(torch.cat((x, features[5]), 1), seg)
|
| 238 |
+
x = self.up(x)
|
| 239 |
+
x = self.up_3(torch.cat((x, features[6]), 1), seg)
|
| 240 |
+
if self.num_upsampling_layers == 'most':
|
| 241 |
+
x = self.up(x)
|
| 242 |
+
x = self.up_4(torch.cat((x, features[7]), 1), seg)
|
| 243 |
+
|
| 244 |
+
x = self.conv_img(self.relu(x))
|
| 245 |
+
return self.tanh(x)
|
| 246 |
+
########################################################################
|
| 247 |
+
|
| 248 |
+
########################################################################
|
| 249 |
+
|
| 250 |
+
class NLayerDiscriminator(BaseNetwork):
|
| 251 |
+
|
| 252 |
+
def __init__(self, opt):
|
| 253 |
+
super().__init__()
|
| 254 |
+
self.no_ganFeat_loss = opt.no_ganFeat_loss
|
| 255 |
+
nf = opt.ndf
|
| 256 |
+
|
| 257 |
+
kw = 4
|
| 258 |
+
pw = int(np.ceil((kw - 1.0) / 2))
|
| 259 |
+
norm_layer = get_nonspade_norm_layer(opt.norm_D)
|
| 260 |
+
|
| 261 |
+
input_nc = opt.gen_semantic_nc + 3
|
| 262 |
+
# input_nc = opt.gen_semantic_nc + 13
|
| 263 |
+
sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=pw),
|
| 264 |
+
nn.LeakyReLU(0.2, False)]]
|
| 265 |
+
|
| 266 |
+
for n in range(1, opt.n_layers_D):
|
| 267 |
+
nf_prev = nf
|
| 268 |
+
nf = min(nf * 2, 512)
|
| 269 |
+
sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=pw)),
|
| 270 |
+
nn.LeakyReLU(0.2, False)]]
|
| 271 |
+
|
| 272 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=pw)]]
|
| 273 |
+
|
| 274 |
+
# We divide the layers into groups to extract intermediate layer outputs
|
| 275 |
+
for n in range(len(sequence)):
|
| 276 |
+
self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
|
| 277 |
+
|
| 278 |
+
def forward(self, input):
|
| 279 |
+
results = [input]
|
| 280 |
+
for submodel in self.children():
|
| 281 |
+
intermediate_output = submodel(results[-1])
|
| 282 |
+
results.append(intermediate_output)
|
| 283 |
+
|
| 284 |
+
get_intermediate_features = not self.no_ganFeat_loss
|
| 285 |
+
if get_intermediate_features:
|
| 286 |
+
return results[1:]
|
| 287 |
+
else:
|
| 288 |
+
return results[-1]
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class MultiscaleDiscriminator(BaseNetwork):
|
| 292 |
+
|
| 293 |
+
def __init__(self, opt):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.no_ganFeat_loss = opt.no_ganFeat_loss
|
| 296 |
+
|
| 297 |
+
for i in range(opt.num_D):
|
| 298 |
+
subnetD = NLayerDiscriminator(opt)
|
| 299 |
+
self.add_module('discriminator_%d' % i, subnetD)
|
| 300 |
+
|
| 301 |
+
def downsample(self, input):
|
| 302 |
+
return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
|
| 303 |
+
|
| 304 |
+
# Returns list of lists of discriminator outputs.
|
| 305 |
+
# The final result is of size opt.num_D x opt.n_layers_D
|
| 306 |
+
def forward(self, input):
|
| 307 |
+
result = []
|
| 308 |
+
get_intermediate_features = not self.no_ganFeat_loss
|
| 309 |
+
for name, D in self.named_children():
|
| 310 |
+
out = D(input)
|
| 311 |
+
if not get_intermediate_features:
|
| 312 |
+
out = [out]
|
| 313 |
+
result.append(out)
|
| 314 |
+
input = self.downsample(input)
|
| 315 |
+
|
| 316 |
+
return result
|
| 317 |
+
|
| 318 |
+
class GANLoss(nn.Module):
|
| 319 |
+
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):
|
| 320 |
+
super(GANLoss, self).__init__()
|
| 321 |
+
self.real_label = target_real_label
|
| 322 |
+
self.fake_label = target_fake_label
|
| 323 |
+
self.real_label_tensor = None
|
| 324 |
+
self.fake_label_tensor = None
|
| 325 |
+
self.zero_tensor = None
|
| 326 |
+
self.Tensor = tensor
|
| 327 |
+
self.gan_mode = gan_mode
|
| 328 |
+
if gan_mode == 'ls':
|
| 329 |
+
pass
|
| 330 |
+
elif gan_mode == 'original':
|
| 331 |
+
pass
|
| 332 |
+
elif gan_mode == 'w':
|
| 333 |
+
pass
|
| 334 |
+
elif gan_mode == 'hinge':
|
| 335 |
+
pass
|
| 336 |
+
else:
|
| 337 |
+
raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
|
| 338 |
+
|
| 339 |
+
def get_target_tensor(self, input, target_is_real):
|
| 340 |
+
if target_is_real:
|
| 341 |
+
if self.real_label_tensor is None:
|
| 342 |
+
self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
|
| 343 |
+
self.real_label_tensor.requires_grad_(False)
|
| 344 |
+
return self.real_label_tensor.expand_as(input)
|
| 345 |
+
else:
|
| 346 |
+
if self.fake_label_tensor is None:
|
| 347 |
+
self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
|
| 348 |
+
self.fake_label_tensor.requires_grad_(False)
|
| 349 |
+
return self.fake_label_tensor.expand_as(input)
|
| 350 |
+
|
| 351 |
+
def get_zero_tensor(self, input):
|
| 352 |
+
if self.zero_tensor is None:
|
| 353 |
+
self.zero_tensor = self.Tensor(1).fill_(0)
|
| 354 |
+
self.zero_tensor.requires_grad_(False)
|
| 355 |
+
return self.zero_tensor.expand_as(input)
|
| 356 |
+
|
| 357 |
+
def loss(self, input, target_is_real, for_discriminator=True):
|
| 358 |
+
if self.gan_mode == 'original': # cross entropy loss
|
| 359 |
+
target_tensor = self.get_target_tensor(input, target_is_real)
|
| 360 |
+
loss = F.binary_cross_entropy_with_logits(input, target_tensor)
|
| 361 |
+
return loss
|
| 362 |
+
elif self.gan_mode == 'ls':
|
| 363 |
+
target_tensor = self.get_target_tensor(input, target_is_real)
|
| 364 |
+
return F.mse_loss(input, target_tensor)
|
| 365 |
+
elif self.gan_mode == 'hinge':
|
| 366 |
+
if for_discriminator:
|
| 367 |
+
if target_is_real:
|
| 368 |
+
minval = torch.min(input - 1, self.get_zero_tensor(input))
|
| 369 |
+
loss = -torch.mean(minval)
|
| 370 |
+
else:
|
| 371 |
+
minval = torch.min(-input - 1, self.get_zero_tensor(input))
|
| 372 |
+
loss = -torch.mean(minval)
|
| 373 |
+
else:
|
| 374 |
+
assert target_is_real, "The generator's hinge loss must be aiming for real"
|
| 375 |
+
loss = -torch.mean(input)
|
| 376 |
+
return loss
|
| 377 |
+
else:
|
| 378 |
+
# wgan
|
| 379 |
+
if target_is_real:
|
| 380 |
+
return -input.mean()
|
| 381 |
+
else:
|
| 382 |
+
return input.mean()
|
| 383 |
+
|
| 384 |
+
def __call__(self, input, target_is_real, for_discriminator=True):
|
| 385 |
+
# computing loss is a bit complicated because |input| may not be
|
| 386 |
+
# a tensor, but list of tensors in case of multiscale discriminator
|
| 387 |
+
if isinstance(input, list):
|
| 388 |
+
loss = 0
|
| 389 |
+
for pred_i in input:
|
| 390 |
+
if isinstance(pred_i, list):
|
| 391 |
+
pred_i = pred_i[-1]
|
| 392 |
+
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
|
| 393 |
+
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
|
| 394 |
+
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
|
| 395 |
+
loss += new_loss
|
| 396 |
+
return loss / len(input)
|
| 397 |
+
else:
|
| 398 |
+
return self.loss(input, target_is_real, for_discriminator)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def get_nonspade_norm_layer(norm_type='instance'):
|
| 402 |
+
def get_out_channel(layer):
|
| 403 |
+
if hasattr(layer, 'out_channels'):
|
| 404 |
+
return getattr(layer, 'out_channels')
|
| 405 |
+
return layer.weight.size(0)
|
| 406 |
+
|
| 407 |
+
def add_norm_layer(layer):
|
| 408 |
+
nonlocal norm_type
|
| 409 |
+
if norm_type.startswith('spectral'):
|
| 410 |
+
layer = spectral_norm(layer)
|
| 411 |
+
subnorm_type = norm_type[len('spectral'):]
|
| 412 |
+
|
| 413 |
+
if subnorm_type == 'none' or len(subnorm_type) == 0:
|
| 414 |
+
return layer
|
| 415 |
+
|
| 416 |
+
# remove bias in the previous layer, which is meaningless
|
| 417 |
+
# since it has no effect after normalization
|
| 418 |
+
if getattr(layer, 'bias', None) is not None:
|
| 419 |
+
delattr(layer, 'bias')
|
| 420 |
+
layer.register_parameter('bias', None)
|
| 421 |
+
|
| 422 |
+
if subnorm_type == 'batch':
|
| 423 |
+
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
|
| 424 |
+
# elif subnorm_type == 'sync_batch':
|
| 425 |
+
# norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
|
| 426 |
+
elif subnorm_type == 'instance':
|
| 427 |
+
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
|
| 428 |
+
else:
|
| 429 |
+
raise ValueError('normalization layer %s is not recognized' % subnorm_type)
|
| 430 |
+
|
| 431 |
+
return nn.Sequential(layer, norm_layer)
|
| 432 |
+
|
| 433 |
+
return add_norm_layer
|
networks.py
ADDED
|
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
from torchvision import models
|
| 6 |
+
import os
|
| 7 |
+
from torch.nn.utils import spectral_norm
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import functools
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ConditionGenerator(nn.Module):
|
| 14 |
+
def __init__(self, opt, input1_nc, input2_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d):
|
| 15 |
+
super(ConditionGenerator, self).__init__()
|
| 16 |
+
self.warp_feature = opt.warp_feature
|
| 17 |
+
self.out_layer_opt = opt.out_layer
|
| 18 |
+
|
| 19 |
+
self.ClothEncoder = nn.Sequential(
|
| 20 |
+
ResBlock(input1_nc, ngf, norm_layer=norm_layer, scale='down'), # 128
|
| 21 |
+
ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'), # 64
|
| 22 |
+
ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'), # 32
|
| 23 |
+
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'), # 16
|
| 24 |
+
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down') # 8
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
self.PoseEncoder = nn.Sequential(
|
| 28 |
+
ResBlock(input2_nc, ngf, norm_layer=norm_layer, scale='down'),
|
| 29 |
+
ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),
|
| 30 |
+
ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'),
|
| 31 |
+
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
|
| 32 |
+
ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down')
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
self.conv = ResBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, scale='same')
|
| 36 |
+
|
| 37 |
+
if opt.warp_feature == 'T1':
|
| 38 |
+
# in_nc -> skip connection + T1, T2 channel
|
| 39 |
+
self.SegDecoder = nn.Sequential(
|
| 40 |
+
ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'), # 16
|
| 41 |
+
ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 4, norm_layer=norm_layer, scale='up'), # 32
|
| 42 |
+
ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 2, norm_layer=norm_layer, scale='up'), # 64
|
| 43 |
+
ResBlock(ngf * 2 * 2 + ngf * 4 , ngf, norm_layer=norm_layer, scale='up'), # 128
|
| 44 |
+
ResBlock(ngf * 1 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up') # 256
|
| 45 |
+
)
|
| 46 |
+
if opt.warp_feature == 'encoder':
|
| 47 |
+
# in_nc -> [x, skip_connection, warped_cloth_encoder_feature(E1)]
|
| 48 |
+
self.SegDecoder = nn.Sequential(
|
| 49 |
+
ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'), # 16
|
| 50 |
+
ResBlock(ngf * 4 * 3, ngf * 4, norm_layer=norm_layer, scale='up'), # 32
|
| 51 |
+
ResBlock(ngf * 4 * 3, ngf * 2, norm_layer=norm_layer, scale='up'), # 64
|
| 52 |
+
ResBlock(ngf * 2 * 3, ngf, norm_layer=norm_layer, scale='up'), # 128
|
| 53 |
+
ResBlock(ngf * 1 * 3, ngf, norm_layer=norm_layer, scale='up') # 256
|
| 54 |
+
)
|
| 55 |
+
if opt.out_layer == 'relu':
|
| 56 |
+
self.out_layer = ResBlock(ngf + input1_nc + input2_nc, output_nc, norm_layer=norm_layer, scale='same')
|
| 57 |
+
if opt.out_layer == 'conv':
|
| 58 |
+
self.out_layer = nn.Sequential(
|
| 59 |
+
ResBlock(ngf + input1_nc + input2_nc, ngf, norm_layer=norm_layer, scale='same'),
|
| 60 |
+
nn.Conv2d(ngf, output_nc, kernel_size=1, bias=True)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Cloth Conv 1x1
|
| 64 |
+
self.conv1 = nn.Sequential(
|
| 65 |
+
nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
|
| 66 |
+
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
|
| 67 |
+
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
|
| 68 |
+
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Person Conv 1x1
|
| 72 |
+
self.conv2 = nn.Sequential(
|
| 73 |
+
nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
|
| 74 |
+
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
|
| 75 |
+
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
|
| 76 |
+
nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
self.flow_conv = nn.ModuleList([
|
| 80 |
+
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
|
| 81 |
+
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
|
| 82 |
+
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
|
| 83 |
+
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
|
| 84 |
+
nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
|
| 85 |
+
]
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.bottleneck = nn.Sequential(
|
| 89 |
+
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
|
| 90 |
+
nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
|
| 91 |
+
nn.Sequential(nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True) , nn.ReLU()),
|
| 92 |
+
nn.Sequential(nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def normalize(self, x):
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
def forward(self,opt,input1, input2, upsample='bilinear'):
|
| 99 |
+
E1_list = []
|
| 100 |
+
E2_list = []
|
| 101 |
+
flow_list = []
|
| 102 |
+
# warped_grid_list = []
|
| 103 |
+
|
| 104 |
+
# Feature Pyramid Network
|
| 105 |
+
for i in range(5):
|
| 106 |
+
if i == 0:
|
| 107 |
+
E1_list.append(self.ClothEncoder[i](input1))
|
| 108 |
+
E2_list.append(self.PoseEncoder[i](input2))
|
| 109 |
+
else:
|
| 110 |
+
E1_list.append(self.ClothEncoder[i](E1_list[i - 1]))
|
| 111 |
+
E2_list.append(self.PoseEncoder[i](E2_list[i - 1]))
|
| 112 |
+
|
| 113 |
+
# Compute Clothflow
|
| 114 |
+
for i in range(5):
|
| 115 |
+
N, _, iH, iW = E1_list[4 - i].size()
|
| 116 |
+
grid = make_grid(N, iH, iW,opt)
|
| 117 |
+
|
| 118 |
+
if i == 0:
|
| 119 |
+
T1 = E1_list[4 - i] # (ngf * 4) x 8 x 6
|
| 120 |
+
T2 = E2_list[4 - i]
|
| 121 |
+
E4 = torch.cat([T1, T2], 1)
|
| 122 |
+
|
| 123 |
+
flow = self.flow_conv[i](self.normalize(E4)).permute(0, 2, 3, 1)
|
| 124 |
+
flow_list.append(flow)
|
| 125 |
+
|
| 126 |
+
x = self.conv(T2)
|
| 127 |
+
x = self.SegDecoder[i](x)
|
| 128 |
+
|
| 129 |
+
else:
|
| 130 |
+
T1 = F.interpolate(T1, scale_factor=2, mode=upsample) + self.conv1[4 - i](E1_list[4 - i])
|
| 131 |
+
T2 = F.interpolate(T2, scale_factor=2, mode=upsample) + self.conv2[4 - i](E2_list[4 - i])
|
| 132 |
+
|
| 133 |
+
flow = F.interpolate(flow_list[i - 1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1) # upsample n-1 flow
|
| 134 |
+
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH/2 - 1.0) / 2.0)], 3)
|
| 135 |
+
warped_T1 = F.grid_sample(T1, flow_norm + grid, padding_mode='border')
|
| 136 |
+
|
| 137 |
+
flow = flow + self.flow_conv[i](self.normalize(torch.cat([warped_T1, self.bottleneck[i-1](x)], 1))).permute(0, 2, 3, 1) # F(n)
|
| 138 |
+
flow_list.append(flow)
|
| 139 |
+
|
| 140 |
+
if self.warp_feature == 'T1':
|
| 141 |
+
x = self.SegDecoder[i](torch.cat([x, E2_list[4-i], warped_T1], 1))
|
| 142 |
+
if self.warp_feature == 'encoder':
|
| 143 |
+
warped_E1 = F.grid_sample(E1_list[4-i], flow_norm + grid, padding_mode='border')
|
| 144 |
+
x = self.SegDecoder[i](torch.cat([x, E2_list[4-i], warped_E1], 1))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
N, _, iH, iW = input1.size()
|
| 148 |
+
grid = make_grid(N, iH, iW,opt)
|
| 149 |
+
|
| 150 |
+
flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1)
|
| 151 |
+
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH/2 - 1.0) / 2.0)], 3)
|
| 152 |
+
warped_input1 = F.grid_sample(input1, flow_norm + grid, padding_mode='border')
|
| 153 |
+
|
| 154 |
+
x = self.out_layer(torch.cat([x, input2, warped_input1], 1))
|
| 155 |
+
|
| 156 |
+
warped_c = warped_input1[:, :-1, :, :]
|
| 157 |
+
warped_cm = warped_input1[:, -1:, :, :]
|
| 158 |
+
|
| 159 |
+
return flow_list, x, warped_c, warped_cm
|
| 160 |
+
|
| 161 |
+
def make_grid(N, iH, iW,opt):
|
| 162 |
+
grid_x = torch.linspace(-1.0, 1.0, iW).view(1, 1, iW, 1).expand(N, iH, -1, -1)
|
| 163 |
+
grid_y = torch.linspace(-1.0, 1.0, iH).view(1, iH, 1, 1).expand(N, -1, iW, -1)
|
| 164 |
+
if opt.cuda :
|
| 165 |
+
grid = torch.cat([grid_x, grid_y], 3).cuda()
|
| 166 |
+
else:
|
| 167 |
+
grid = torch.cat([grid_x, grid_y], 3)
|
| 168 |
+
return grid
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class ResBlock(nn.Module):
|
| 172 |
+
def __init__(self, in_nc, out_nc, scale='down', norm_layer=nn.BatchNorm2d):
|
| 173 |
+
super(ResBlock, self).__init__()
|
| 174 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
| 175 |
+
assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"
|
| 176 |
+
|
| 177 |
+
if scale == 'same':
|
| 178 |
+
self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
|
| 179 |
+
if scale == 'up':
|
| 180 |
+
self.scale = nn.Sequential(
|
| 181 |
+
nn.Upsample(scale_factor=2, mode='bilinear'),
|
| 182 |
+
nn.Conv2d(in_nc, out_nc, kernel_size=1,bias=True)
|
| 183 |
+
)
|
| 184 |
+
if scale == 'down':
|
| 185 |
+
self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)
|
| 186 |
+
|
| 187 |
+
self.block = nn.Sequential(
|
| 188 |
+
nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
| 189 |
+
norm_layer(out_nc),
|
| 190 |
+
nn.ReLU(inplace=True),
|
| 191 |
+
nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
| 192 |
+
norm_layer(out_nc)
|
| 193 |
+
)
|
| 194 |
+
self.relu = nn.ReLU(inplace=True)
|
| 195 |
+
|
| 196 |
+
def forward(self, x):
|
| 197 |
+
residual = self.scale(x)
|
| 198 |
+
return self.relu(residual + self.block(residual))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class Vgg19(nn.Module):
|
| 202 |
+
def __init__(self, requires_grad=False):
|
| 203 |
+
super(Vgg19, self).__init__()
|
| 204 |
+
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
| 205 |
+
self.slice1 = torch.nn.Sequential()
|
| 206 |
+
self.slice2 = torch.nn.Sequential()
|
| 207 |
+
self.slice3 = torch.nn.Sequential()
|
| 208 |
+
self.slice4 = torch.nn.Sequential()
|
| 209 |
+
self.slice5 = torch.nn.Sequential()
|
| 210 |
+
for x in range(2):
|
| 211 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
| 212 |
+
for x in range(2, 7):
|
| 213 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
| 214 |
+
for x in range(7, 12):
|
| 215 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
| 216 |
+
for x in range(12, 21):
|
| 217 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
| 218 |
+
for x in range(21, 30):
|
| 219 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
| 220 |
+
if not requires_grad:
|
| 221 |
+
for param in self.parameters():
|
| 222 |
+
param.requires_grad = False
|
| 223 |
+
|
| 224 |
+
def forward(self, X):
|
| 225 |
+
h_relu1 = self.slice1(X)
|
| 226 |
+
h_relu2 = self.slice2(h_relu1)
|
| 227 |
+
h_relu3 = self.slice3(h_relu2)
|
| 228 |
+
h_relu4 = self.slice4(h_relu3)
|
| 229 |
+
h_relu5 = self.slice5(h_relu4)
|
| 230 |
+
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
| 231 |
+
return out
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class VGGLoss(nn.Module):
|
| 235 |
+
def __init__(self, opt,layids = None):
|
| 236 |
+
super(VGGLoss, self).__init__()
|
| 237 |
+
self.vgg = Vgg19()
|
| 238 |
+
if opt.cuda:
|
| 239 |
+
self.vgg.cuda()
|
| 240 |
+
self.criterion = nn.L1Loss()
|
| 241 |
+
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
|
| 242 |
+
self.layids = layids
|
| 243 |
+
|
| 244 |
+
def forward(self, x, y):
|
| 245 |
+
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
| 246 |
+
loss = 0
|
| 247 |
+
if self.layids is None:
|
| 248 |
+
self.layids = list(range(len(x_vgg)))
|
| 249 |
+
for i in self.layids:
|
| 250 |
+
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
| 251 |
+
return loss
|
| 252 |
+
|
| 253 |
+
# Defines the GAN loss which uses either LSGAN or the regular GAN.
|
| 254 |
+
# When LSGAN is used, it is basically same as MSELoss,
|
| 255 |
+
# but it abstracts away the need to create the target label tensor
|
| 256 |
+
# that has the same size as the input
|
| 257 |
+
|
| 258 |
+
class GANLoss(nn.Module):
|
| 259 |
+
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
|
| 260 |
+
tensor=torch.FloatTensor):
|
| 261 |
+
super(GANLoss, self).__init__()
|
| 262 |
+
self.real_label = target_real_label
|
| 263 |
+
self.fake_label = target_fake_label
|
| 264 |
+
self.real_label_var = None
|
| 265 |
+
self.fake_label_var = None
|
| 266 |
+
self.Tensor = tensor
|
| 267 |
+
if use_lsgan:
|
| 268 |
+
self.loss = nn.MSELoss()
|
| 269 |
+
else:
|
| 270 |
+
self.loss = nn.BCELoss()
|
| 271 |
+
|
| 272 |
+
def get_target_tensor(self, input, target_is_real):
|
| 273 |
+
if target_is_real:
|
| 274 |
+
create_label = ((self.real_label_var is None) or
|
| 275 |
+
(self.real_label_var.numel() != input.numel()))
|
| 276 |
+
if create_label:
|
| 277 |
+
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
|
| 278 |
+
self.real_label_var = Variable(real_tensor, requires_grad=False)
|
| 279 |
+
target_tensor = self.real_label_var
|
| 280 |
+
else:
|
| 281 |
+
create_label = ((self.fake_label_var is None) or
|
| 282 |
+
(self.fake_label_var.numel() != input.numel()))
|
| 283 |
+
if create_label:
|
| 284 |
+
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
|
| 285 |
+
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
|
| 286 |
+
target_tensor = self.fake_label_var
|
| 287 |
+
return target_tensor
|
| 288 |
+
|
| 289 |
+
def __call__(self, input, target_is_real):
|
| 290 |
+
if isinstance(input[0], list):
|
| 291 |
+
loss = 0
|
| 292 |
+
for input_i in input:
|
| 293 |
+
pred = input_i[-1]
|
| 294 |
+
target_tensor = self.get_target_tensor(pred, target_is_real)
|
| 295 |
+
loss += self.loss(pred, target_tensor)
|
| 296 |
+
return loss
|
| 297 |
+
else:
|
| 298 |
+
target_tensor = self.get_target_tensor(input[-1], target_is_real)
|
| 299 |
+
return self.loss(input[-1], target_tensor)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class MultiscaleDiscriminator(nn.Module):
|
| 303 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
|
| 304 |
+
use_sigmoid=False, num_D=3, getIntermFeat=False, Ddownx2=False, Ddropout=False, spectral=False):
|
| 305 |
+
super(MultiscaleDiscriminator, self).__init__()
|
| 306 |
+
self.num_D = num_D
|
| 307 |
+
self.n_layers = n_layers
|
| 308 |
+
self.getIntermFeat = getIntermFeat
|
| 309 |
+
self.Ddownx2 = Ddownx2
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
for i in range(num_D):
|
| 313 |
+
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat, Ddropout, spectral=spectral)
|
| 314 |
+
if getIntermFeat:
|
| 315 |
+
for j in range(n_layers + 2):
|
| 316 |
+
setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
|
| 317 |
+
else:
|
| 318 |
+
setattr(self, 'layer' + str(i), netD.model)
|
| 319 |
+
|
| 320 |
+
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
| 321 |
+
|
| 322 |
+
def singleD_forward(self, model, input):
|
| 323 |
+
if self.getIntermFeat:
|
| 324 |
+
result = [input]
|
| 325 |
+
for i in range(len(model)):
|
| 326 |
+
result.append(model[i](result[-1]))
|
| 327 |
+
return result[1:]
|
| 328 |
+
else:
|
| 329 |
+
return [model(input)]
|
| 330 |
+
|
| 331 |
+
def forward(self, input):
|
| 332 |
+
num_D = self.num_D
|
| 333 |
+
|
| 334 |
+
result = []
|
| 335 |
+
if self.Ddownx2:
|
| 336 |
+
input_downsampled = self.downsample(input)
|
| 337 |
+
else:
|
| 338 |
+
input_downsampled = input
|
| 339 |
+
for i in range(num_D):
|
| 340 |
+
|
| 341 |
+
if self.getIntermFeat:
|
| 342 |
+
model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
|
| 343 |
+
range(self.n_layers + 2)]
|
| 344 |
+
else:
|
| 345 |
+
model = getattr(self, 'layer' + str(num_D - 1 - i))
|
| 346 |
+
result.append(self.singleD_forward(model, input_downsampled))
|
| 347 |
+
if i != (num_D - 1):
|
| 348 |
+
input_downsampled = self.downsample(input_downsampled)
|
| 349 |
+
return result
|
| 350 |
+
|
| 351 |
+
class NLayerDiscriminator(nn.Module):
|
| 352 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False, Ddropout=False, spectral=False):
|
| 353 |
+
super(NLayerDiscriminator, self).__init__()
|
| 354 |
+
self.getIntermFeat = getIntermFeat
|
| 355 |
+
self.n_layers = n_layers
|
| 356 |
+
self.spectral_norm = spectral_norm if spectral else lambda x: x
|
| 357 |
+
|
| 358 |
+
kw = 4
|
| 359 |
+
padw = int(np.ceil((kw - 1.0) / 2))
|
| 360 |
+
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
|
| 361 |
+
|
| 362 |
+
nf = ndf
|
| 363 |
+
for n in range(1, n_layers):
|
| 364 |
+
nf_prev = nf
|
| 365 |
+
nf = min(nf * 2, 512)
|
| 366 |
+
if Ddropout:
|
| 367 |
+
sequence += [[
|
| 368 |
+
self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
|
| 369 |
+
norm_layer(nf), nn.LeakyReLU(0.2, True), nn.Dropout(0.5)
|
| 370 |
+
]]
|
| 371 |
+
else:
|
| 372 |
+
|
| 373 |
+
sequence += [[
|
| 374 |
+
self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
|
| 375 |
+
norm_layer(nf), nn.LeakyReLU(0.2, True)
|
| 376 |
+
]]
|
| 377 |
+
|
| 378 |
+
nf_prev = nf
|
| 379 |
+
nf = min(nf * 2, 512)
|
| 380 |
+
sequence += [[
|
| 381 |
+
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
| 382 |
+
norm_layer(nf),
|
| 383 |
+
nn.LeakyReLU(0.2, True)
|
| 384 |
+
]]
|
| 385 |
+
|
| 386 |
+
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
| 387 |
+
|
| 388 |
+
if use_sigmoid:
|
| 389 |
+
sequence += [[nn.Sigmoid()]]
|
| 390 |
+
|
| 391 |
+
if getIntermFeat:
|
| 392 |
+
for n in range(len(sequence)):
|
| 393 |
+
setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
|
| 394 |
+
else:
|
| 395 |
+
sequence_stream = []
|
| 396 |
+
for n in range(len(sequence)):
|
| 397 |
+
sequence_stream += sequence[n]
|
| 398 |
+
self.model = nn.Sequential(*sequence_stream)
|
| 399 |
+
|
| 400 |
+
def forward(self, input):
|
| 401 |
+
if self.getIntermFeat:
|
| 402 |
+
res = [input]
|
| 403 |
+
for n in range(self.n_layers + 2):
|
| 404 |
+
model = getattr(self, 'model' + str(n))
|
| 405 |
+
res.append(model(res[-1]))
|
| 406 |
+
return res[1:]
|
| 407 |
+
else:
|
| 408 |
+
return self.model(input)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def save_checkpoint(model, save_path,opt):
|
| 412 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
| 413 |
+
os.makedirs(os.path.dirname(save_path))
|
| 414 |
+
|
| 415 |
+
torch.save(model.cpu().state_dict(), save_path)
|
| 416 |
+
if opt.cuda :
|
| 417 |
+
model.cuda()
|
| 418 |
+
|
| 419 |
+
def load_checkpoint(model, checkpoint_path,opt):
|
| 420 |
+
if not os.path.exists(checkpoint_path):
|
| 421 |
+
print('no checkpoint')
|
| 422 |
+
raise
|
| 423 |
+
log = model.load_state_dict(torch.load(checkpoint_path), strict=False)
|
| 424 |
+
if opt.cuda :
|
| 425 |
+
model.cuda()
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def weights_init(m):
|
| 429 |
+
classname = m.__class__.__name__
|
| 430 |
+
if classname.find('Conv2d') != -1:
|
| 431 |
+
m.weight.data.normal_(0.0, 0.02)
|
| 432 |
+
elif classname.find('BatchNorm2d') != -1:
|
| 433 |
+
m.weight.data.normal_(1.0, 0.02)
|
| 434 |
+
m.bias.data.fill_(0)
|
| 435 |
+
|
| 436 |
+
def get_norm_layer(norm_type='instance'):
|
| 437 |
+
if norm_type == 'batch':
|
| 438 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
| 439 |
+
elif norm_type == 'instance':
|
| 440 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
|
| 441 |
+
else:
|
| 442 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
| 443 |
+
return norm_layer
|
| 444 |
+
|
| 445 |
+
def define_D(input_nc, ndf=64, n_layers_D=3, norm='instance', use_sigmoid=False, num_D=2, getIntermFeat=False, gpu_ids=[], Ddownx2=False, Ddropout=False, spectral=False):
|
| 446 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
| 447 |
+
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat, Ddownx2, Ddropout, spectral=spectral)
|
| 448 |
+
print(netD)
|
| 449 |
+
if len(gpu_ids) > 0:
|
| 450 |
+
assert (torch.cuda.is_available())
|
| 451 |
+
netD.cuda()
|
| 452 |
+
netD.apply(weights_init)
|
| 453 |
+
return netD
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
torchaudio
|
| 4 |
+
opencv-python
|
| 5 |
+
torchgeometry
|
| 6 |
+
Pillow
|
| 7 |
+
tqdm
|
| 8 |
+
tensorboardX
|
| 9 |
+
scikit-image
|
| 10 |
+
scipy
|
| 11 |
+
streamlit-image-select
|
| 12 |
+
pandas
|
test_generator.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from torchvision.utils import make_grid as make_image_grid
|
| 5 |
+
from torchvision.utils import save_image
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
from cp_dataset_test import CPDatasetTest, CPDataLoader
|
| 10 |
+
|
| 11 |
+
from networks import ConditionGenerator, load_checkpoint, make_grid
|
| 12 |
+
from network_generator import SPADEGenerator
|
| 13 |
+
from tensorboardX import SummaryWriter
|
| 14 |
+
from utils import *
|
| 15 |
+
|
| 16 |
+
import torchgeometry as tgm
|
| 17 |
+
from collections import OrderedDict
|
| 18 |
+
|
| 19 |
+
def remove_overlap(seg_out, warped_cm):
|
| 20 |
+
|
| 21 |
+
assert len(warped_cm.shape) == 4
|
| 22 |
+
|
| 23 |
+
warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
|
| 24 |
+
return warped_cm
|
| 25 |
+
def get_opt():
|
| 26 |
+
parser = argparse.ArgumentParser()
|
| 27 |
+
|
| 28 |
+
parser.add_argument("--gpu_ids", default="")
|
| 29 |
+
parser.add_argument('-j', '--workers', type=int, default=4)
|
| 30 |
+
parser.add_argument('-b', '--batch-size', type=int, default=1)
|
| 31 |
+
parser.add_argument('--fp16', action='store_true', help='use amp')
|
| 32 |
+
# Cuda availability
|
| 33 |
+
parser.add_argument('--cuda',default=False, help='cuda or cpu')
|
| 34 |
+
|
| 35 |
+
parser.add_argument('--test_name', type=str, default='test', help='test name')
|
| 36 |
+
parser.add_argument("--dataroot", default="./data/zalando-hd-resize")
|
| 37 |
+
parser.add_argument("--datamode", default="test")
|
| 38 |
+
parser.add_argument("--data_list", default="test_pairs.txt")
|
| 39 |
+
parser.add_argument("--output_dir", type=str, default="./Output")
|
| 40 |
+
parser.add_argument("--datasetting", default="unpaired")
|
| 41 |
+
parser.add_argument("--fine_width", type=int, default=768)
|
| 42 |
+
parser.add_argument("--fine_height", type=int, default=1024)
|
| 43 |
+
|
| 44 |
+
parser.add_argument('--tensorboard_dir', type=str, default='./data/zalando-hd-resize/tensorboard', help='save tensorboard infos')
|
| 45 |
+
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
|
| 46 |
+
parser.add_argument('--tocg_checkpoint', type=str, default='./eval_models/weights/v0.1/mtviton.pth', help='tocg checkpoint')
|
| 47 |
+
parser.add_argument('--gen_checkpoint', type=str, default='./eval_models/weights/v0.1/gen.pth', help='G checkpoint')
|
| 48 |
+
|
| 49 |
+
parser.add_argument("--tensorboard_count", type=int, default=100)
|
| 50 |
+
parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
|
| 51 |
+
parser.add_argument("--semantic_nc", type=int, default=13)
|
| 52 |
+
parser.add_argument("--output_nc", type=int, default=13)
|
| 53 |
+
parser.add_argument('--gen_semantic_nc', type=int, default=7, help='# of input label classes without unknown class')
|
| 54 |
+
|
| 55 |
+
# network
|
| 56 |
+
parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
|
| 57 |
+
parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
|
| 58 |
+
|
| 59 |
+
# training
|
| 60 |
+
parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')
|
| 61 |
+
|
| 62 |
+
# Hyper-parameters
|
| 63 |
+
parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
|
| 64 |
+
parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")
|
| 65 |
+
|
| 66 |
+
# generator
|
| 67 |
+
parser.add_argument('--norm_G', type=str, default='spectralaliasinstance', help='instance normalization or batch normalization')
|
| 68 |
+
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
|
| 69 |
+
parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
|
| 70 |
+
parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
|
| 71 |
+
parser.add_argument('--num_upsampling_layers', choices=('normal', 'more', 'most'), default='most', # normal: 256, more: 512
|
| 72 |
+
help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator")
|
| 73 |
+
|
| 74 |
+
opt = parser.parse_args()
|
| 75 |
+
return opt
|
| 76 |
+
|
| 77 |
+
def load_checkpoint_G(model, checkpoint_path,opt):
|
| 78 |
+
if not os.path.exists(checkpoint_path):
|
| 79 |
+
print("Invalid path!")
|
| 80 |
+
return
|
| 81 |
+
state_dict = torch.load(checkpoint_path)
|
| 82 |
+
new_state_dict = OrderedDict([(k.replace('ace', 'alias').replace('.Spade', ''), v) for (k, v) in state_dict.items()])
|
| 83 |
+
new_state_dict._metadata = OrderedDict([(k.replace('ace', 'alias').replace('.Spade', ''), v) for (k, v) in state_dict._metadata.items()])
|
| 84 |
+
model.load_state_dict(new_state_dict, strict=True)
|
| 85 |
+
if opt.cuda :
|
| 86 |
+
model.cuda()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test(opt, test_loader, tocg, generator):
|
| 91 |
+
gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
|
| 92 |
+
if opt.cuda:
|
| 93 |
+
gauss = gauss.cuda()
|
| 94 |
+
|
| 95 |
+
# Model
|
| 96 |
+
if opt.cuda :
|
| 97 |
+
tocg.cuda()
|
| 98 |
+
tocg.eval()
|
| 99 |
+
generator.eval()
|
| 100 |
+
|
| 101 |
+
if opt.output_dir is not None:
|
| 102 |
+
output_dir = opt.output_dir
|
| 103 |
+
else:
|
| 104 |
+
output_dir = os.path.join('./output', opt.test_name,
|
| 105 |
+
opt.datamode, opt.datasetting, 'generator', 'output')
|
| 106 |
+
grid_dir = os.path.join('./output', opt.test_name,
|
| 107 |
+
opt.datamode, opt.datasetting, 'generator', 'grid')
|
| 108 |
+
|
| 109 |
+
os.makedirs(grid_dir, exist_ok=True)
|
| 110 |
+
|
| 111 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 112 |
+
|
| 113 |
+
num = 0
|
| 114 |
+
iter_start_time = time.time()
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
for inputs in test_loader.data_loader:
|
| 117 |
+
|
| 118 |
+
if opt.cuda :
|
| 119 |
+
pose_map = inputs['pose'].cuda()
|
| 120 |
+
pre_clothes_mask = inputs['cloth_mask'][opt.datasetting].cuda()
|
| 121 |
+
label = inputs['parse']
|
| 122 |
+
parse_agnostic = inputs['parse_agnostic']
|
| 123 |
+
agnostic = inputs['agnostic'].cuda()
|
| 124 |
+
clothes = inputs['cloth'][opt.datasetting].cuda() # target cloth
|
| 125 |
+
densepose = inputs['densepose'].cuda()
|
| 126 |
+
im = inputs['image']
|
| 127 |
+
input_label, input_parse_agnostic = label.cuda(), parse_agnostic.cuda()
|
| 128 |
+
pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
|
| 129 |
+
else :
|
| 130 |
+
pose_map = inputs['pose']
|
| 131 |
+
pre_clothes_mask = inputs['cloth_mask'][opt.datasetting]
|
| 132 |
+
label = inputs['parse']
|
| 133 |
+
parse_agnostic = inputs['parse_agnostic']
|
| 134 |
+
agnostic = inputs['agnostic']
|
| 135 |
+
clothes = inputs['cloth'][opt.datasetting] # target cloth
|
| 136 |
+
densepose = inputs['densepose']
|
| 137 |
+
im = inputs['image']
|
| 138 |
+
input_label, input_parse_agnostic = label, parse_agnostic
|
| 139 |
+
pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# down
|
| 144 |
+
pose_map_down = F.interpolate(pose_map, size=(256, 192), mode='bilinear')
|
| 145 |
+
pre_clothes_mask_down = F.interpolate(pre_clothes_mask, size=(256, 192), mode='nearest')
|
| 146 |
+
input_label_down = F.interpolate(input_label, size=(256, 192), mode='bilinear')
|
| 147 |
+
input_parse_agnostic_down = F.interpolate(input_parse_agnostic, size=(256, 192), mode='nearest')
|
| 148 |
+
agnostic_down = F.interpolate(agnostic, size=(256, 192), mode='nearest')
|
| 149 |
+
clothes_down = F.interpolate(clothes, size=(256, 192), mode='bilinear')
|
| 150 |
+
densepose_down = F.interpolate(densepose, size=(256, 192), mode='bilinear')
|
| 151 |
+
|
| 152 |
+
shape = pre_clothes_mask.shape
|
| 153 |
+
|
| 154 |
+
# multi-task inputs
|
| 155 |
+
input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
|
| 156 |
+
input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)
|
| 157 |
+
|
| 158 |
+
# forward
|
| 159 |
+
flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(opt,input1, input2)
|
| 160 |
+
|
| 161 |
+
# warped cloth mask one hot
|
| 162 |
+
if opt.cuda :
|
| 163 |
+
warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
|
| 164 |
+
else :
|
| 165 |
+
warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(np.float))
|
| 166 |
+
|
| 167 |
+
if opt.clothmask_composition != 'no_composition':
|
| 168 |
+
if opt.clothmask_composition == 'detach':
|
| 169 |
+
cloth_mask = torch.ones_like(fake_segmap)
|
| 170 |
+
cloth_mask[:,3:4, :, :] = warped_cm_onehot
|
| 171 |
+
fake_segmap = fake_segmap * cloth_mask
|
| 172 |
+
|
| 173 |
+
if opt.clothmask_composition == 'warp_grad':
|
| 174 |
+
cloth_mask = torch.ones_like(fake_segmap)
|
| 175 |
+
cloth_mask[:,3:4, :, :] = warped_clothmask_paired
|
| 176 |
+
fake_segmap = fake_segmap * cloth_mask
|
| 177 |
+
|
| 178 |
+
# make generator input parse map
|
| 179 |
+
fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(opt.fine_height, opt.fine_width), mode='bilinear'))
|
| 180 |
+
fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
|
| 181 |
+
|
| 182 |
+
if opt.cuda :
|
| 183 |
+
old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
|
| 184 |
+
else:
|
| 185 |
+
old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_()
|
| 186 |
+
old_parse.scatter_(1, fake_parse, 1.0)
|
| 187 |
+
|
| 188 |
+
labels = {
|
| 189 |
+
0: ['background', [0]],
|
| 190 |
+
1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
|
| 191 |
+
2: ['upper', [3]],
|
| 192 |
+
3: ['hair', [1]],
|
| 193 |
+
4: ['left_arm', [5]],
|
| 194 |
+
5: ['right_arm', [6]],
|
| 195 |
+
6: ['noise', [12]]
|
| 196 |
+
}
|
| 197 |
+
if opt.cuda :
|
| 198 |
+
parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
|
| 199 |
+
else:
|
| 200 |
+
parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_()
|
| 201 |
+
for i in range(len(labels)):
|
| 202 |
+
for label in labels[i][1]:
|
| 203 |
+
parse[:, i] += old_parse[:, label]
|
| 204 |
+
|
| 205 |
+
# warped cloth
|
| 206 |
+
N, _, iH, iW = clothes.shape
|
| 207 |
+
flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
|
| 208 |
+
flow_norm = torch.cat([flow[:, :, :, 0:1] / ((96 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((128 - 1.0) / 2.0)], 3)
|
| 209 |
+
|
| 210 |
+
grid = make_grid(N, iH, iW,opt)
|
| 211 |
+
warped_grid = grid + flow_norm
|
| 212 |
+
warped_cloth = F.grid_sample(clothes, warped_grid, padding_mode='border')
|
| 213 |
+
warped_clothmask = F.grid_sample(pre_clothes_mask, warped_grid, padding_mode='border')
|
| 214 |
+
if opt.occlusion:
|
| 215 |
+
warped_clothmask = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask)
|
| 216 |
+
warped_cloth = warped_cloth * warped_clothmask + torch.ones_like(warped_cloth) * (1-warped_clothmask)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
output = generator(torch.cat((agnostic, densepose, warped_cloth), dim=1), parse)
|
| 220 |
+
# visualize
|
| 221 |
+
unpaired_names = []
|
| 222 |
+
for i in range(shape[0]):
|
| 223 |
+
grid = make_image_grid([(clothes[i].cpu() / 2 + 0.5), (pre_clothes_mask[i].cpu()).expand(3, -1, -1), visualize_segmap(parse_agnostic.cpu(), batch=i), ((densepose.cpu()[i]+1)/2),
|
| 224 |
+
(warped_cloth[i].cpu().detach() / 2 + 0.5), (warped_clothmask[i].cpu().detach()).expand(3, -1, -1), visualize_segmap(fake_parse_gauss.cpu(), batch=i),
|
| 225 |
+
(pose_map[i].cpu()/2 +0.5), (warped_cloth[i].cpu()/2 + 0.5), (agnostic[i].cpu()/2 + 0.5),
|
| 226 |
+
(im[i]/2 +0.5), (output[i].cpu()/2 +0.5)],
|
| 227 |
+
nrow=4)
|
| 228 |
+
unpaired_name = (inputs['c_name']['paired'][i].split('.')[0] + '_' + inputs['c_name'][opt.datasetting][i].split('.')[0] + '.png')
|
| 229 |
+
save_image(grid, os.path.join(grid_dir, unpaired_name))
|
| 230 |
+
unpaired_names.append(unpaired_name)
|
| 231 |
+
|
| 232 |
+
# save output
|
| 233 |
+
save_images(output, unpaired_names, output_dir)
|
| 234 |
+
|
| 235 |
+
num += shape[0]
|
| 236 |
+
print(num)
|
| 237 |
+
|
| 238 |
+
print(f"Test time {time.time() - iter_start_time}")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def main():
|
| 242 |
+
opt = get_opt()
|
| 243 |
+
print(opt)
|
| 244 |
+
print("Start to test %s!")
|
| 245 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
|
| 246 |
+
|
| 247 |
+
# create test dataset & loader
|
| 248 |
+
test_dataset = CPDatasetTest(opt)
|
| 249 |
+
test_loader = CPDataLoader(opt, test_dataset)
|
| 250 |
+
|
| 251 |
+
# visualization
|
| 252 |
+
# if not os.path.exists(opt.tensorboard_dir):
|
| 253 |
+
# os.makedirs(opt.tensorboard_dir)
|
| 254 |
+
# board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.test_name, opt.datamode, opt.datasetting))
|
| 255 |
+
|
| 256 |
+
## Model
|
| 257 |
+
# tocg
|
| 258 |
+
input1_nc = 4 # cloth + cloth-mask
|
| 259 |
+
input2_nc = opt.semantic_nc + 3 # parse_agnostic + densepose
|
| 260 |
+
tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
|
| 261 |
+
|
| 262 |
+
# generator
|
| 263 |
+
opt.semantic_nc = 7
|
| 264 |
+
generator = SPADEGenerator(opt, 3+3+3)
|
| 265 |
+
generator.print_network()
|
| 266 |
+
|
| 267 |
+
# Load Checkpoint
|
| 268 |
+
load_checkpoint(tocg, opt.tocg_checkpoint,opt)
|
| 269 |
+
load_checkpoint_G(generator, opt.gen_checkpoint,opt)
|
| 270 |
+
|
| 271 |
+
# Train
|
| 272 |
+
test(opt, test_loader, tocg, generator)
|
| 273 |
+
|
| 274 |
+
print("Finished testing!")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if __name__ == "__main__":
|
| 278 |
+
main()
|
utils.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
def get_clothes_mask(old_label) :
|
| 10 |
+
clothes = torch.FloatTensor((old_label.cpu().numpy() == 3).astype(np.int))
|
| 11 |
+
return clothes
|
| 12 |
+
|
| 13 |
+
def changearm(old_label):
|
| 14 |
+
label=old_label
|
| 15 |
+
arm1=torch.FloatTensor((old_label.cpu().numpy()==5).astype(np.int))
|
| 16 |
+
arm2=torch.FloatTensor((old_label.cpu().numpy()==6).astype(np.int))
|
| 17 |
+
label=label*(1-arm1)+arm1*3
|
| 18 |
+
label=label*(1-arm2)+arm2*3
|
| 19 |
+
return label
|
| 20 |
+
|
| 21 |
+
def gen_noise(shape):
|
| 22 |
+
noise = np.zeros(shape, dtype=np.uint8)
|
| 23 |
+
### noise
|
| 24 |
+
noise = cv2.randn(noise, 0, 255)
|
| 25 |
+
noise = np.asarray(noise / 255, dtype=np.uint8)
|
| 26 |
+
noise = torch.tensor(noise, dtype=torch.float32)
|
| 27 |
+
return noise
|
| 28 |
+
|
| 29 |
+
def cross_entropy2d(input, target, weight=None, size_average=True):
|
| 30 |
+
n, c, h, w = input.size()
|
| 31 |
+
nt, ht, wt = target.size()
|
| 32 |
+
|
| 33 |
+
# Handle inconsistent size between input and target
|
| 34 |
+
if h != ht or w != wt:
|
| 35 |
+
input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
|
| 36 |
+
|
| 37 |
+
input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
|
| 38 |
+
target = target.view(-1)
|
| 39 |
+
loss = F.cross_entropy(
|
| 40 |
+
input, target, weight=weight, size_average=size_average, ignore_index=250
|
| 41 |
+
)
|
| 42 |
+
return loss
|
| 43 |
+
|
| 44 |
+
def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
|
| 45 |
+
image_numpy = image_tensor[batch].cpu().float().numpy()
|
| 46 |
+
result = np.argmax(image_numpy, axis=0)
|
| 47 |
+
return result.astype(imtype)
|
| 48 |
+
|
| 49 |
+
def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0) :
|
| 50 |
+
palette = [
|
| 51 |
+
0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51,
|
| 52 |
+
254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85,
|
| 53 |
+
85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220,
|
| 54 |
+
0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0
|
| 55 |
+
]
|
| 56 |
+
input = input.detach()
|
| 57 |
+
if multi_channel :
|
| 58 |
+
input = ndim_tensor2im(input,batch=batch)
|
| 59 |
+
else :
|
| 60 |
+
input = input[batch][0].cpu()
|
| 61 |
+
input = np.asarray(input)
|
| 62 |
+
input = input.astype(np.uint8)
|
| 63 |
+
input = Image.fromarray(input, 'P')
|
| 64 |
+
input.putpalette(palette)
|
| 65 |
+
|
| 66 |
+
if tensor_out :
|
| 67 |
+
trans = transforms.ToTensor()
|
| 68 |
+
return trans(input.convert('RGB'))
|
| 69 |
+
|
| 70 |
+
return input
|
| 71 |
+
|
| 72 |
+
def pred_to_onehot(prediction) :
|
| 73 |
+
size = prediction.shape
|
| 74 |
+
prediction_max = torch.argmax(prediction, dim=1)
|
| 75 |
+
oneHot_size = (size[0], 13, size[2], size[3])
|
| 76 |
+
pred_onehot = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
| 77 |
+
pred_onehot = pred_onehot.scatter_(1, prediction_max.unsqueeze(1).data.long(), 1.0)
|
| 78 |
+
return pred_onehot
|
| 79 |
+
|
| 80 |
+
def cal_miou(prediction, target) :
|
| 81 |
+
size = prediction.shape
|
| 82 |
+
target = target.cpu()
|
| 83 |
+
prediction = pred_to_onehot(prediction.detach().cpu())
|
| 84 |
+
list = [1,2,3,4,5,6,7,8]
|
| 85 |
+
union = 0
|
| 86 |
+
intersection = 0
|
| 87 |
+
for b in range(size[0]) :
|
| 88 |
+
for c in list :
|
| 89 |
+
intersection += torch.logical_and(target[b,c], prediction[b,c]).sum()
|
| 90 |
+
union += torch.logical_or(target[b,c], prediction[b,c]).sum()
|
| 91 |
+
return intersection.item()/union.item()
|
| 92 |
+
|
| 93 |
+
def save_images(img_tensors, img_names, save_dir):
|
| 94 |
+
for img_tensor, img_name in zip(img_tensors, img_names):
|
| 95 |
+
tensor = (img_tensor.clone() + 1) * 0.5 * 255
|
| 96 |
+
tensor = tensor.cpu().clamp(0, 255)
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
array = tensor.numpy().astype('uint8')
|
| 100 |
+
except:
|
| 101 |
+
array = tensor.detach().numpy().astype('uint8')
|
| 102 |
+
|
| 103 |
+
if array.shape[0] == 1:
|
| 104 |
+
array = array.squeeze(0)
|
| 105 |
+
elif array.shape[0] == 3:
|
| 106 |
+
array = array.swapaxes(0, 1).swapaxes(1, 2)
|
| 107 |
+
|
| 108 |
+
im = Image.fromarray(array)
|
| 109 |
+
im.save(os.path.join(save_dir, img_name), format='JPEG')
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_network(cls, opt):
|
| 113 |
+
net = cls(opt)
|
| 114 |
+
net.print_network()
|
| 115 |
+
if len(opt.gpu_ids) > 0:
|
| 116 |
+
assert(torch.cuda.is_available())
|
| 117 |
+
net.cuda()
|
| 118 |
+
net.init_weights(opt.init_type, opt.init_variance)
|
| 119 |
+
return net
|