svjack committed on
Commit f774f0f · 1 Parent(s): 9fffe95

Upload 7 files

Files changed (7)
  1. app.py +415 -0
  2. cp_dataset_test.py +264 -0
  3. network_generator.py +433 -0
  4. networks.py +453 -0
  5. requirements.txt +12 -0
  6. test_generator.py +278 -0
  7. utils.py +119 -0
app.py ADDED
@@ -0,0 +1,415 @@
+ from multiprocessing import set_start_method
+ #set_start_method("fork")
+
+ import sys
+ #sys.path.insert(0, "../HR-VITON-main")
+ # the star import provides nn, F, np, time, tgm, transforms, Image, and the
+ # HR-VITON model / helper functions used unqualified below
+ from test_generator import *
+
+ import re
+ import inspect
+ from dataclasses import dataclass, field
+ from tqdm import tqdm
+ import pandas as pd
+ import os
+ import torch
+ import gradio as gr
+ import streamlit as st
+ from io import BytesIO
+
+ #### pip install streamlit-image-select
+ from streamlit_image_select import image_select
+
+ demo_image_dir = "demo_images_dir"
+ assert os.path.exists(demo_image_dir)
+ demo_images = list(map(lambda y: os.path.join(demo_image_dir, y),
+                        filter(lambda x: x.endswith(".png") or x.endswith(".jpeg") or x.endswith(".jpg"),
+                               os.listdir(demo_image_dir))))
+ assert demo_images
+
+ # https://github.com/jrieke/streamlit-image-select/issues/10
+
+ #.image-box {
+ #    border: 1px solid rgba(49, 51, 63, 0.2);
+ #    border-radius: 0.25rem;
+ #    padding: calc(0.25rem + 1px);
+ #    height: 10rem;
+ #    min-width: 10rem;
+ #}
+
+ demo_images = list(map(lambda x: x.resize((256, 256)), map(Image.open, demo_images)))
+
+ @dataclass
+ class OPT:
+     # note: without type annotations these are plain class attributes,
+     # which is fine for a simple options bag
+     #### ConditionGenerator
+     out_layer = None
+     warp_feature = None
+     #### SPADEGenerator
+     semantic_nc = None
+     fine_height = None
+     fine_width = None
+     ngf = None
+     num_upsampling_layers = None
+     norm_G = None
+     gen_semantic_nc = None
+     #### weight load
+     tocg_checkpoint = None
+     gen_checkpoint = None
+     cuda = False
+
+     data_list = None
+     datamode = None
+     dataroot = None
+
+     batch_size = None
+     shuffle = False
+     workers = None
+
+     clothmask_composition = None
+     occlusion = False
+     datasetting = None
+
+ opt = OPT()
+ opt.out_layer = "relu"
+ opt.warp_feature = "T1"
+
+ input1_nc = 4  # cloth + cloth-mask
+ nc = 13
+ input2_nc = nc + 3  # parse_agnostic + densepose
+ output_nc = nc
+ tocg = ConditionGenerator(opt,
+                           input1_nc=input1_nc,
+                           input2_nc=input2_nc, output_nc=output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
+
+ #### SPADEResBlock
+ from network_generator import SPADEResBlock
+
+ opt.semantic_nc = 7
+ opt.fine_height = 1024
+ opt.fine_width = 768
+ opt.ngf = 64
+ opt.num_upsampling_layers = "most"
+ opt.norm_G = "spectralaliasinstance"
+ opt.gen_semantic_nc = 7
+
+ generator = SPADEGenerator(opt, 3 + 3 + 3)
+ generator.print_network()
+
+ #### https://drive.google.com/open?id=1XJTCdRBOPVgVTmqzhVGFAgMm2NLkw5uQ&authuser=0
+ opt.tocg_checkpoint = "mtviton.pth"
+ #### https://drive.google.com/open?id=1T5_YDUhYSSKPC_nZMk2NeC-XXUFoYeNy&authuser=0
+ opt.gen_checkpoint = "gen.pth"
+
+ opt.cuda = False
+
+ load_checkpoint(tocg, opt.tocg_checkpoint, opt)
+ load_checkpoint_G(generator, opt.gen_checkpoint, opt)
+
+ #### test scope
+ tocg.eval()
+ generator.eval()
+
+ opt.data_list = "test_pairs.txt"
+ opt.datamode = "test"
+ opt.dataroot = "zalando-hd-resized"
+
+ opt.batch_size = 1
+ opt.shuffle = False
+ opt.workers = 1
+ opt.semantic_nc = 13
+
+ test_dataset = CPDatasetTest(opt)
+ test_loader = CPDataLoader(opt, test_dataset)
+
+ def construct_images(img_tensors, img_names=[None]):
+     # Maps a batch of [-1, 1] tensors back to PIL images; returns the first one.
+     for img_tensor, img_name in zip(img_tensors, img_names):
+         tensor = (img_tensor.clone() + 1) * 0.5 * 255
+         tensor = tensor.cpu().clamp(0, 255)
+         try:
+             array = tensor.numpy().astype('uint8')
+         except RuntimeError:  # tensor still requires grad
+             array = tensor.detach().numpy().astype('uint8')
+
+         if array.shape[0] == 1:
+             array = array.squeeze(0)
+         elif array.shape[0] == 3:
+             array = array.swapaxes(0, 1).swapaxes(1, 2)
+
+         im = Image.fromarray(array)
+         return im
+
+ def single_pred_slim_func(opt, inputs, tocg=tocg, generator=generator):
+     gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
+     if opt.cuda:
+         gauss = gauss.cuda()
+
+     # Model
+     if opt.cuda:
+         tocg.cuda()
+     tocg.eval()
+     generator.eval()
+
+     num = 0
+     iter_start_time = time.time()
+     with torch.no_grad():
+         # single-item loop kept from the original batch loop in test_generator.py
+         for inputs in [inputs]:
+             if opt.cuda:
+                 #pose_map = inputs['pose'].cuda()
+                 pre_clothes_mask = inputs['cloth_mask'][opt.datasetting].cuda()
+                 #label = inputs['parse']
+                 parse_agnostic = inputs['parse_agnostic']
+                 agnostic = inputs['agnostic'].cuda()
+                 clothes = inputs['cloth'][opt.datasetting].cuda()  # target cloth
+                 densepose = inputs['densepose'].cuda()
+                 #im = inputs['image']
+                 #input_label, input_parse_agnostic = label.cuda(), parse_agnostic.cuda()
+                 input_parse_agnostic = parse_agnostic.cuda()
+                 pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(float)).cuda()
+             else:
+                 #pose_map = inputs['pose']
+                 pre_clothes_mask = inputs['cloth_mask'][opt.datasetting]
+                 #label = inputs['parse']
+                 parse_agnostic = inputs['parse_agnostic']
+                 agnostic = inputs['agnostic']
+                 clothes = inputs['cloth'][opt.datasetting]  # target cloth
+                 densepose = inputs['densepose']
+                 #im = inputs['image']
+                 #input_label, input_parse_agnostic = label, parse_agnostic
+                 input_parse_agnostic = parse_agnostic
+                 pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(float))
+
+             # down
+             #pose_map_down = F.interpolate(pose_map, size=(256, 192), mode='bilinear')
+             pre_clothes_mask_down = F.interpolate(pre_clothes_mask, size=(256, 192), mode='nearest')
+             #input_label_down = F.interpolate(input_label, size=(256, 192), mode='bilinear')
+             input_parse_agnostic_down = F.interpolate(input_parse_agnostic, size=(256, 192), mode='nearest')
+             #agnostic_down = F.interpolate(agnostic, size=(256, 192), mode='nearest')
+             clothes_down = F.interpolate(clothes, size=(256, 192), mode='bilinear')
+             densepose_down = F.interpolate(densepose, size=(256, 192), mode='bilinear')
+
+             shape = pre_clothes_mask.shape
+
+             # multi-task inputs
+             input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
+             input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)
+
+             # forward
+             flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(opt, input1, input2)
+
+             # warped cloth mask one hot
+             if opt.cuda:
+                 warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(float)).cuda()
+             else:
+                 warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(float))
+
+             if opt.clothmask_composition != 'no_composition':
+                 if opt.clothmask_composition == 'detach':
+                     cloth_mask = torch.ones_like(fake_segmap)
+                     cloth_mask[:, 3:4, :, :] = warped_cm_onehot
+                     fake_segmap = fake_segmap * cloth_mask
+
+                 if opt.clothmask_composition == 'warp_grad':
+                     cloth_mask = torch.ones_like(fake_segmap)
+                     cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
+                     fake_segmap = fake_segmap * cloth_mask
+
+             # make generator input parse map
+             fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(opt.fine_height, opt.fine_width), mode='bilinear'))
+             fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
+
+             if opt.cuda:
+                 old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
+             else:
+                 old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_()
+             old_parse.scatter_(1, fake_parse, 1.0)
+
+             labels = {
+                 0: ['background', [0]],
+                 1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
+                 2: ['upper', [3]],
+                 3: ['hair', [1]],
+                 4: ['left_arm', [5]],
+                 5: ['right_arm', [6]],
+                 6: ['noise', [12]]
+             }
+             if opt.cuda:
+                 parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
+             else:
+                 parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_()
+             for i in range(len(labels)):
+                 for label in labels[i][1]:
+                     parse[:, i] += old_parse[:, label]
+
+             # warped cloth
+             N, _, iH, iW = clothes.shape
+             flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
+             flow_norm = torch.cat([flow[:, :, :, 0:1] / ((96 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((128 - 1.0) / 2.0)], 3)
+
+             grid = make_grid(N, iH, iW, opt)
+             warped_grid = grid + flow_norm
+             warped_cloth = F.grid_sample(clothes, warped_grid, padding_mode='border')
+             warped_clothmask = F.grid_sample(pre_clothes_mask, warped_grid, padding_mode='border')
+             if opt.occlusion:
+                 warped_clothmask = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask)
+                 warped_cloth = warped_cloth * warped_clothmask + torch.ones_like(warped_cloth) * (1 - warped_clothmask)
+
+             output = generator(torch.cat((agnostic, densepose, warped_cloth), dim=1), parse)
+             # save output
+             return output
+             #save_images(output, unpaired_names, output_dir)
+             #num += shape[0]
+             #print(num)
+
+ opt.clothmask_composition = "warp_grad"
+ opt.occlusion = False
+ opt.datasetting = "unpaired"
+
+ def read_img_and_trans(dataset, opt, img_path):
+     if isinstance(img_path, str):
+         im = Image.open(img_path)
+     else:
+         im = img_path
+     im = transforms.Resize(opt.fine_width, interpolation=2)(im)
+     im = dataset.transform(im)
+     return im
+
+ import sys
+ sys.path.insert(0, "fashion-eye-try-on")
+
+ import os
+ from PIL import Image
+ import gradio as gr
+ from cloth_segmentation import generate_cloth_mask
+
+ def generate_cloth_mask_and_display(cloth_img):
+     path = 'fashion-eye-try-on/cloth/cloth.jpg'
+     if os.path.exists(path):
+         os.remove(path)
+     cloth_img.save(path)
+     try:
+         # os.system('.\cloth_segmentation\generate_cloth_mask.py')
+         generate_cloth_mask()
+     except Exception as e:
+         print(e)
+         return
+     cloth_mask_img = Image.open("fashion-eye-try-on/cloth_mask/cloth.jpg")
+     return cloth_mask_img
+
+ def take_human_feature_from_dataset(dataset, idx):
+     inputs_upper = list(torch.utils.data.DataLoader(
+         [dataset[idx]], batch_size=1))[0]
+     return {
+         "parse_agnostic": inputs_upper["parse_agnostic"],
+         "agnostic": inputs_upper["agnostic"],
+         "densepose": inputs_upper["densepose"],
+     }
+
+ def take_all_feature_with_dataset(cloth_img_path, idx, opt=opt, dataset=test_dataset, only_show_human=False):
+     if not isinstance(cloth_img_path, str):
+         assert hasattr(cloth_img_path, "save")
+         cloth_img_path.save("tmp_cloth.jpg")
+         cloth_img_path = "tmp_cloth.jpg"
+     assert isinstance(cloth_img_path, str)
+     inputs_upper_dict = take_human_feature_from_dataset(dataset, idx)
+     if only_show_human:
+         return Image.fromarray((inputs_upper_dict["densepose"][0].numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
+     cloth_readed = read_img_and_trans(dataset, opt, cloth_img_path)
+     #assert ((cloth_readed - inputs_upper["cloth"][opt.datasetting][0]) ** 2).sum().numpy() < 1e-15
+     cloth_input = {
+         opt.datasetting: cloth_readed[None, :]
+     }
+     mask_img = generate_cloth_mask_and_display(
+         Image.open(cloth_img_path)
+     )
+     cloth_mask_input = {
+         opt.datasetting:
+             torch.Tensor((np.asarray(mask_img) / 255))[None, None, :]
+     }
+     inputs_upper_dict["cloth"] = cloth_input
+     inputs_upper_dict["cloth_mask"] = cloth_mask_input
+     return inputs_upper_dict
+
+ def pred_func(cloth_img, pidx):
+     idx = int(pidx)
+     im = cloth_img
+
+     #### actual model input
+     inputs_upper_dict = take_all_feature_with_dataset(
+         im, idx, only_show_human=False)
+
+     output_slim = single_pred_slim_func(opt, inputs_upper_dict)
+     output_img = construct_images(output_slim)
+     return output_img
+
+ option = st.selectbox(
+     "Choose a cloth image or upload one",
+     ("Choose", "Upload", )
+ )
+ if not isinstance(option, str):
+     option = "Choose"
+
+ img = None
+ uploaded_file = None
+ if option == "Upload":
+     # To read file as bytes:
+     uploaded_file = st.file_uploader("Upload img")
+     if uploaded_file is not None:
+         bytes_data = uploaded_file.getvalue()
+         img = Image.open(BytesIO(bytes_data))
+         cloth_img = img.convert("RGB").resize((256 + 128, 512))
+         st.image(cloth_img)
+         uploaded_file = st.selectbox(
+             "Confirm the chosen image",
+             ("Wait", "Have Done")
+         )
+ else:
+     img = image_select("Choose img", demo_images)
+     #img = Image.open(img)
+     cloth_img = img.convert("RGB").resize((256 + 128, 512))
+     st.image(cloth_img)
+     uploaded_file = st.selectbox(
+         "Confirm the chosen image",
+         ("Wait", "Have Done")
+     )
+
+ if img is not None and (uploaded_file != "Wait" and uploaded_file is not None):
+     cloth_img = img.convert("RGB").resize((768, 1024))
+     #pidx = 44
+     pidx_index_list = [44, 84, 67]
+     poses = []
+     for idx in range(len(pidx_index_list)):
+         poses.append(
+             take_all_feature_with_dataset(
+                 cloth_img, pidx_index_list[idx], only_show_human=True)
+         )
+
+     col1, col2, col3 = st.columns(3)
+
+     with col1:
+         st.header("Pose 0")
+         pose_img = poses[0]
+         st.image(pose_img)
+         b = pred_func(cloth_img, pidx_index_list[0])
+         st.image(b)
+
+     with col2:
+         st.header("Pose 1")
+         pose_img = poses[1]
+         st.image(pose_img)
+         b = pred_func(cloth_img, pidx_index_list[1])
+         st.image(b)
+
+     with col3:
+         st.header("Pose 2")
+         pose_img = poses[2]
+         st.image(pose_img)
+         b = pred_func(cloth_img, pidx_index_list[2])
+         st.image(b)
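
A note on exercising the pipeline above without the Streamlit front end: the following is a minimal sketch, assuming the checkpoints (mtviton.pth, gen.pth), the zalando-hd-resized test data, and the cloth_segmentation helper are all in place; the cloth path and person index are illustrative. Importing app.py also executes its module-level Streamlit calls, so the app itself is normally launched with `streamlit run app.py`.

    # Hypothetical headless smoke test of the try-on pipeline defined above.
    from app import take_all_feature_with_dataset, single_pred_slim_func, construct_images, opt

    inputs = take_all_feature_with_dataset("tmp_cloth.jpg", 44)  # cloth image + person features for pair index 44
    output = single_pred_slim_func(opt, inputs)                  # (1, 3, 1024, 768) tensor in [-1, 1]
    construct_images(output).save("tryon_result.jpg")            # convert back to a PIL image and save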
cp_dataset_test.py ADDED
@@ -0,0 +1,264 @@
+ import torch
+ import torch.utils.data as data
+ import torchvision.transforms as transforms
+
+ from PIL import Image, ImageDraw
+
+ import os.path as osp
+ import numpy as np
+ import json
+
+
+ class CPDatasetTest(data.Dataset):
+     """
+     Test Dataset for CP-VTON.
+     """
+     def __init__(self, opt):
+         super(CPDatasetTest, self).__init__()
+         # base setting
+         self.opt = opt
+         self.root = opt.dataroot
+         self.datamode = opt.datamode  # train or test or self-defined
+         self.data_list = opt.data_list
+         self.fine_height = opt.fine_height
+         self.fine_width = opt.fine_width
+         self.semantic_nc = opt.semantic_nc
+         self.data_path = osp.join(opt.dataroot, opt.datamode)
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+         # load data list
+         im_names = []
+         c_names = []
+         with open(osp.join(opt.dataroot, opt.data_list), 'r') as f:
+             for line in f.readlines():
+                 im_name, c_name = line.strip().split()
+                 im_names.append(im_name)
+                 c_names.append(c_name)
+
+         self.im_names = im_names
+         self.c_names = dict()
+         self.c_names['paired'] = im_names
+         self.c_names['unpaired'] = c_names
+
+     def name(self):
+         return "CPDataset"
+
+     def get_agnostic(self, im, im_parse, pose_data):
+         parse_array = np.array(im_parse)
+         parse_head = ((parse_array == 4).astype(np.float32) +
+                       (parse_array == 13).astype(np.float32))
+         parse_lower = ((parse_array == 9).astype(np.float32) +
+                        (parse_array == 12).astype(np.float32) +
+                        (parse_array == 16).astype(np.float32) +
+                        (parse_array == 17).astype(np.float32) +
+                        (parse_array == 18).astype(np.float32) +
+                        (parse_array == 19).astype(np.float32))
+
+         agnostic = im.copy()
+         agnostic_draw = ImageDraw.Draw(agnostic)
+
+         length_a = np.linalg.norm(pose_data[5] - pose_data[2])
+         length_b = np.linalg.norm(pose_data[12] - pose_data[9])
+         point = (pose_data[9] + pose_data[12]) / 2
+         pose_data[9] = point + (pose_data[9] - point) / length_b * length_a
+         pose_data[12] = point + (pose_data[12] - point) / length_b * length_a
+
+         r = int(length_a / 16) + 1
+
+         # mask torso
+         for i in [9, 12]:
+             pointx, pointy = pose_data[i]
+             agnostic_draw.ellipse((pointx-r*3, pointy-r*6, pointx+r*3, pointy+r*6), 'gray', 'gray')
+         agnostic_draw.line([tuple(pose_data[i]) for i in [2, 9]], 'gray', width=r*6)
+         agnostic_draw.line([tuple(pose_data[i]) for i in [5, 12]], 'gray', width=r*6)
+         agnostic_draw.line([tuple(pose_data[i]) for i in [9, 12]], 'gray', width=r*12)
+         agnostic_draw.polygon([tuple(pose_data[i]) for i in [2, 5, 12, 9]], 'gray', 'gray')
+
+         # mask neck
+         pointx, pointy = pose_data[1]
+         agnostic_draw.rectangle((pointx-r*5, pointy-r*9, pointx+r*5, pointy), 'gray', 'gray')
+
+         # mask arms
+         agnostic_draw.line([tuple(pose_data[i]) for i in [2, 5]], 'gray', width=r*12)
+         for i in [2, 5]:
+             pointx, pointy = pose_data[i]
+             agnostic_draw.ellipse((pointx-r*5, pointy-r*6, pointx+r*5, pointy+r*6), 'gray', 'gray')
+         for i in [3, 4, 6, 7]:
+             if (pose_data[i-1, 0] == 0.0 and pose_data[i-1, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
+                 continue
+             agnostic_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'gray', width=r*10)
+             pointx, pointy = pose_data[i]
+             agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
+
+         for parse_id, pose_ids in [(14, [5, 6, 7]), (15, [2, 3, 4])]:
+             mask_arm = Image.new('L', (768, 1024), 'white')
+             mask_arm_draw = ImageDraw.Draw(mask_arm)
+             pointx, pointy = pose_data[pose_ids[0]]
+             mask_arm_draw.ellipse((pointx-r*5, pointy-r*6, pointx+r*5, pointy+r*6), 'black', 'black')
+             for i in pose_ids[1:]:
+                 if (pose_data[i-1, 0] == 0.0 and pose_data[i-1, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
+                     continue
+                 mask_arm_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'black', width=r*10)
+                 pointx, pointy = pose_data[i]
+                 if i != pose_ids[-1]:
+                     mask_arm_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'black', 'black')
+             mask_arm_draw.ellipse((pointx-r*4, pointy-r*4, pointx+r*4, pointy+r*4), 'black', 'black')
+
+             parse_arm = (np.array(mask_arm) / 255) * (parse_array == parse_id).astype(np.float32)
+             agnostic.paste(im, None, Image.fromarray(np.uint8(parse_arm * 255), 'L'))
+
+         agnostic.paste(im, None, Image.fromarray(np.uint8(parse_head * 255), 'L'))
+         agnostic.paste(im, None, Image.fromarray(np.uint8(parse_lower * 255), 'L'))
+         return agnostic
+
+     def __getitem__(self, index):
+         im_name = self.im_names[index]
+         c_name = {}
+         c = {}
+         cm = {}
+         for key in self.c_names:
+             c_name[key] = self.c_names[key][index]
+             c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')
+             c[key] = transforms.Resize(self.fine_width, interpolation=2)(c[key])
+             cm[key] = Image.open(osp.join(self.data_path, 'cloth-mask', c_name[key]))
+             cm[key] = transforms.Resize(self.fine_width, interpolation=0)(cm[key])
+
+             c[key] = self.transform(c[key])  # [-1,1]
+             cm_array = np.array(cm[key])
+             cm_array = (cm_array >= 128).astype(np.float32)
+             cm[key] = torch.from_numpy(cm_array)  # [0,1]
+             cm[key].unsqueeze_(0)
+
+         # person image
+         im_pil_big = Image.open(osp.join(self.data_path, 'image', im_name))
+         im_pil = transforms.Resize(self.fine_width, interpolation=2)(im_pil_big)
+
+         im = self.transform(im_pil)
+
+         # load parsing image
+         parse_name = im_name.replace('.jpg', '.png')
+         im_parse_pil_big = Image.open(osp.join(self.data_path, 'image-parse-v3', parse_name))
+         im_parse_pil = transforms.Resize(self.fine_width, interpolation=0)(im_parse_pil_big)
+         parse = torch.from_numpy(np.array(im_parse_pil)[None]).long()
+         im_parse = self.transform(im_parse_pil.convert('RGB'))
+
+         labels = {
+             0: ['background', [0, 10]],
+             1: ['hair', [1, 2]],
+             2: ['face', [4, 13]],
+             3: ['upper', [5, 6, 7]],
+             4: ['bottom', [9, 12]],
+             5: ['left_arm', [14]],
+             6: ['right_arm', [15]],
+             7: ['left_leg', [16]],
+             8: ['right_leg', [17]],
+             9: ['left_shoe', [18]],
+             10: ['right_shoe', [19]],
+             11: ['socks', [8]],
+             12: ['noise', [3, 11]]
+         }
+
+         parse_map = torch.FloatTensor(20, self.fine_height, self.fine_width).zero_()
+         parse_map = parse_map.scatter_(0, parse, 1.0)
+         new_parse_map = torch.FloatTensor(self.semantic_nc, self.fine_height, self.fine_width).zero_()
+
+         for i in range(len(labels)):
+             for label in labels[i][1]:
+                 new_parse_map[i] += parse_map[label]
+
+         parse_onehot = torch.FloatTensor(1, self.fine_height, self.fine_width).zero_()
+         for i in range(len(labels)):
+             for label in labels[i][1]:
+                 parse_onehot[0] += parse_map[label] * i
+
+         # load image-parse-agnostic
+         image_parse_agnostic = Image.open(osp.join(self.data_path, 'image-parse-agnostic-v3.2', parse_name))
+         image_parse_agnostic = transforms.Resize(self.fine_width, interpolation=0)(image_parse_agnostic)
+         parse_agnostic = torch.from_numpy(np.array(image_parse_agnostic)[None]).long()
+         image_parse_agnostic = self.transform(image_parse_agnostic.convert('RGB'))
+
+         parse_agnostic_map = torch.FloatTensor(20, self.fine_height, self.fine_width).zero_()
+         parse_agnostic_map = parse_agnostic_map.scatter_(0, parse_agnostic, 1.0)
+         new_parse_agnostic_map = torch.FloatTensor(self.semantic_nc, self.fine_height, self.fine_width).zero_()
+         for i in range(len(labels)):
+             for label in labels[i][1]:
+                 new_parse_agnostic_map[i] += parse_agnostic_map[label]
+
+         # parse cloth & parse cloth mask
+         pcm = new_parse_map[3:4]
+         im_c = im * pcm + (1 - pcm)
+
+         # load pose points
+         pose_name = im_name.replace('.jpg', '_rendered.png')
+         pose_map = Image.open(osp.join(self.data_path, 'openpose_img', pose_name))
+         pose_map = transforms.Resize(self.fine_width, interpolation=2)(pose_map)
+         pose_map = self.transform(pose_map)  # [-1,1]
+
+         pose_name = im_name.replace('.jpg', '_keypoints.json')
+         with open(osp.join(self.data_path, 'openpose_json', pose_name), 'r') as f:
+             pose_label = json.load(f)
+             pose_data = pose_label['people'][0]['pose_keypoints_2d']
+             pose_data = np.array(pose_data)
+             pose_data = pose_data.reshape((-1, 3))[:, :2]
+
+         # load densepose
+         densepose_name = im_name.replace('image', 'image-densepose')
+         densepose_map = Image.open(osp.join(self.data_path, 'image-densepose', densepose_name))
+         densepose_map = transforms.Resize(self.fine_width, interpolation=2)(densepose_map)
+         densepose_map = self.transform(densepose_map)  # [-1,1]
+
+         agnostic = self.get_agnostic(im_pil_big, im_parse_pil_big, pose_data)
+         agnostic = transforms.Resize(self.fine_width, interpolation=2)(agnostic)
+         agnostic = self.transform(agnostic)
+
+         result = {
+             'c_name': c_name,  # for visualization
+             'im_name': im_name,  # for visualization or ground truth
+             # input 1 (cloth flow)
+             'cloth': c,  # for input
+             'cloth_mask': cm,  # for input
+             # input 2 (segnet)
+             'parse_agnostic': new_parse_agnostic_map,
+             'densepose': densepose_map,
+             'pose': pose_map,  # for conditioning
+             # GT
+             'parse_onehot': parse_onehot,  # Cross Entropy
+             'parse': new_parse_map,  # GAN Loss real
+             'pcm': pcm,  # L1 Loss & vis
+             'parse_cloth': im_c,  # VGG Loss & vis
+             # visualization
+             'image': im,  # for visualization
+             'agnostic': agnostic
+         }
+
+         return result
+
+     def __len__(self):
+         return len(self.im_names)
+
+
+ class CPDataLoader(object):
+     def __init__(self, opt, dataset):
+         super(CPDataLoader, self).__init__()
+         if opt.shuffle:
+             train_sampler = torch.utils.data.sampler.RandomSampler(dataset)
+         else:
+             train_sampler = None
+
+         self.data_loader = torch.utils.data.DataLoader(
+             dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
+             num_workers=opt.workers, pin_memory=True, drop_last=True, sampler=train_sampler)
+         self.dataset = dataset
+         self.data_iter = self.data_loader.__iter__()
+
+     def next_batch(self):
+         try:
+             batch = self.data_iter.__next__()
+         except StopIteration:
+             self.data_iter = self.data_loader.__iter__()
+             batch = self.data_iter.__next__()
+
+         return batch
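
Each item of CPDatasetTest is a dict keyed as in the `result` dict above. A minimal usage sketch, assuming `opt` carries the fields app.py sets (dataroot, datamode, data_list, fine_height=1024, fine_width=768, semantic_nc=13, batch_size, shuffle, workers):

    # Hypothetical inspection of one test batch.
    from cp_dataset_test import CPDatasetTest, CPDataLoader

    dataset = CPDatasetTest(opt)
    loader = CPDataLoader(opt, dataset)
    batch = loader.next_batch()
    print(batch['cloth']['unpaired'].shape)  # (B, 3, 1024, 768), values in [-1, 1]
    print(batch['parse_agnostic'].shape)     # (B, 13, 1024, 768) one-hot parse map

Since next_batch() re-creates the iterator on StopIteration, it can be called indefinitely.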
network_generator.py ADDED
@@ -0,0 +1,433 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ from torch.nn.utils import spectral_norm
+ import numpy as np
+
+
+ class BaseNetwork(nn.Module):
+     def __init__(self):
+         super(BaseNetwork, self).__init__()
+
+     def print_network(self):
+         num_params = 0
+         for param in self.parameters():
+             num_params += param.numel()
+         print("Network [{}] was created. Total number of parameters: {:.1f} million. "
+               "To see the architecture, do print(network).".format(self.__class__.__name__, num_params / 1000000))
+
+     def init_weights(self, init_type='normal', gain=0.02):
+         def init_func(m):
+             classname = m.__class__.__name__
+             if 'BatchNorm2d' in classname:
+                 if hasattr(m, 'weight') and m.weight is not None:
+                     init.normal_(m.weight.data, 1.0, gain)
+                 if hasattr(m, 'bias') and m.bias is not None:
+                     init.constant_(m.bias.data, 0.0)
+             elif ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
+                 if init_type == 'normal':
+                     init.normal_(m.weight.data, 0.0, gain)
+                 elif init_type == 'xavier':
+                     init.xavier_normal_(m.weight.data, gain=gain)
+                 elif init_type == 'xavier_uniform':
+                     init.xavier_uniform_(m.weight.data, gain=1.0)
+                 elif init_type == 'kaiming':
+                     init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+                 elif init_type == 'orthogonal':
+                     init.orthogonal_(m.weight.data, gain=gain)
+                 elif init_type == 'none':  # uses pytorch's default init method
+                     m.reset_parameters()
+                 else:
+                     raise NotImplementedError("initialization method '{}' is not implemented".format(init_type))
+                 if hasattr(m, 'bias') and m.bias is not None:
+                     init.constant_(m.bias.data, 0.0)
+
+         self.apply(init_func)
+
+     def forward(self, *inputs):
+         pass
+
+
+ class MaskNorm(nn.Module):
+     def __init__(self, norm_nc):
+         super(MaskNorm, self).__init__()
+
+         self.norm_layer = nn.InstanceNorm2d(norm_nc, affine=False)
+
+     def normalize_region(self, region, mask):
+         b, c, h, w = region.size()
+
+         num_pixels = mask.sum((2, 3), keepdim=True)  # size: (b, 1, 1, 1)
+         num_pixels[num_pixels == 0] = 1
+         mu = region.sum((2, 3), keepdim=True) / num_pixels  # size: (b, c, 1, 1)
+
+         normalized_region = self.norm_layer(region + (1 - mask) * mu)
+         return normalized_region * torch.sqrt(num_pixels / (h * w))
+
+     def forward(self, x, mask):
+         mask = mask.detach()
+         normalized_foreground = self.normalize_region(x * mask, mask)
+         normalized_background = self.normalize_region(x * (1 - mask), 1 - mask)
+         return normalized_foreground + normalized_background
+
+
+ class SPADENorm(nn.Module):
+     def __init__(self, opt, norm_type, norm_nc, label_nc):
+         super(SPADENorm, self).__init__()
+         self.param_opt = opt
+         self.noise_scale = nn.Parameter(torch.zeros(norm_nc))
+
+         assert norm_type.startswith('alias')
+         param_free_norm_type = norm_type[len('alias'):]
+         if param_free_norm_type == 'batch':
+             self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+         elif param_free_norm_type == 'instance':
+             self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+         elif param_free_norm_type == 'mask':
+             self.param_free_norm = MaskNorm(norm_nc)
+         else:
+             raise ValueError(
+                 "'{}' is not a recognized parameter-free normalization type in SPADENorm".format(param_free_norm_type)
+             )
+
+         nhidden = 128
+         ks = 3
+         pw = ks // 2
+         self.conv_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
+         self.conv_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+         self.conv_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+
+     def forward(self, x, seg, misalign_mask=None):
+         # Part 1. Generate parameter-free normalized activations.
+         b, c, h, w = x.size()
+         if self.param_opt.cuda:
+             noise = (torch.randn(b, w, h, 1).cuda() * self.noise_scale).transpose(1, 3)
+         else:
+             noise = (torch.randn(b, w, h, 1) * self.noise_scale).transpose(1, 3)
+
+         if misalign_mask is None:
+             normalized = self.param_free_norm(x + noise)
+         else:
+             normalized = self.param_free_norm(x + noise, misalign_mask)
+
+         # Part 2. Produce affine parameters conditioned on the segmentation map.
+         actv = self.conv_shared(seg)
+         gamma = self.conv_gamma(actv)
+         beta = self.conv_beta(actv)
+
+         # Apply the affine parameters.
+         output = normalized * (1 + gamma) + beta
+         return output
+
+
+ class SPADEResBlock(nn.Module):
+     def __init__(self, opt, input_nc, output_nc, use_mask_norm=True):
+         super(SPADEResBlock, self).__init__()
+         self.param_opt = opt
+         self.learned_shortcut = (input_nc != output_nc)
+         middle_nc = min(input_nc, output_nc)
+
+         self.conv_0 = nn.Conv2d(input_nc, middle_nc, kernel_size=3, padding=1)
+         self.conv_1 = nn.Conv2d(middle_nc, output_nc, kernel_size=3, padding=1)
+         if self.learned_shortcut:
+             self.conv_s = nn.Conv2d(input_nc, output_nc, kernel_size=1, bias=False)
+
+         subnorm_type = opt.norm_G
+         if subnorm_type.startswith('spectral'):
+             subnorm_type = subnorm_type[len('spectral'):]
+             self.conv_0 = spectral_norm(self.conv_0)
+             self.conv_1 = spectral_norm(self.conv_1)
+             if self.learned_shortcut:
+                 self.conv_s = spectral_norm(self.conv_s)
+
+         gen_semantic_nc = opt.gen_semantic_nc
+         if use_mask_norm:
+             subnorm_type = 'aliasmask'
+             gen_semantic_nc = gen_semantic_nc + 1
+
+         self.norm_0 = SPADENorm(opt, subnorm_type, input_nc, gen_semantic_nc)
+         self.norm_1 = SPADENorm(opt, subnorm_type, middle_nc, gen_semantic_nc)
+         if self.learned_shortcut:
+             self.norm_s = SPADENorm(opt, subnorm_type, input_nc, gen_semantic_nc)
+
+         self.relu = nn.LeakyReLU(0.2)
+
+     def shortcut(self, x, seg, misalign_mask):
+         if self.learned_shortcut:
+             return self.conv_s(self.norm_s(x, seg, misalign_mask))
+         else:
+             return x
+
+     def forward(self, x, seg, misalign_mask=None):
+         seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
+         if misalign_mask is not None:
+             misalign_mask = F.interpolate(misalign_mask, size=x.size()[2:], mode='nearest')
+
+         x_s = self.shortcut(x, seg, misalign_mask)
+
+         dx = self.conv_0(self.relu(self.norm_0(x, seg, misalign_mask)))
+         dx = self.conv_1(self.relu(self.norm_1(dx, seg, misalign_mask)))
+         output = x_s + dx
+         return output
+
+
+ class SPADEGenerator(BaseNetwork):
+     def __init__(self, opt, input_nc):
+         super(SPADEGenerator, self).__init__()
+         self.num_upsampling_layers = opt.num_upsampling_layers
+         self.param_opt = opt
+         self.sh, self.sw = self.compute_latent_vector_size(opt)
+
+         nf = opt.ngf
+         self.conv_0 = nn.Conv2d(input_nc, nf * 16, kernel_size=3, padding=1)
+         for i in range(1, 8):
+             self.add_module('conv_{}'.format(i), nn.Conv2d(input_nc, 16, kernel_size=3, padding=1))
+
+         self.head_0 = SPADEResBlock(opt, nf * 16, nf * 16, use_mask_norm=False)
+
+         self.G_middle_0 = SPADEResBlock(opt, nf * 16 + 16, nf * 16, use_mask_norm=False)
+         self.G_middle_1 = SPADEResBlock(opt, nf * 16 + 16, nf * 16, use_mask_norm=False)
+
+         self.up_0 = SPADEResBlock(opt, nf * 16 + 16, nf * 8, use_mask_norm=False)
+         self.up_1 = SPADEResBlock(opt, nf * 8 + 16, nf * 4, use_mask_norm=False)
+         self.up_2 = SPADEResBlock(opt, nf * 4 + 16, nf * 2, use_mask_norm=False)
+         self.up_3 = SPADEResBlock(opt, nf * 2 + 16, nf * 1, use_mask_norm=False)
+         if self.num_upsampling_layers == 'most':
+             self.up_4 = SPADEResBlock(opt, nf * 1 + 16, nf // 2, use_mask_norm=False)
+             nf = nf // 2
+
+         self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
+
+         self.up = nn.Upsample(scale_factor=2, mode='nearest')
+         self.relu = nn.LeakyReLU(0.2)
+         self.tanh = nn.Tanh()
+
+     def compute_latent_vector_size(self, opt):
+         if self.num_upsampling_layers == 'normal':
+             num_up_layers = 5
+         elif self.num_upsampling_layers == 'more':
+             num_up_layers = 6
+         elif self.num_upsampling_layers == 'most':
+             num_up_layers = 7
+         else:
+             raise ValueError("opt.num_upsampling_layers '{}' is not recognized".format(self.num_upsampling_layers))
+
+         sh = opt.fine_height // 2**num_up_layers
+         sw = opt.fine_width // 2**num_up_layers
+         return sh, sw
+
+     def forward(self, x, seg):
+         samples = [F.interpolate(x, size=(self.sh * 2**i, self.sw * 2**i), mode='nearest') for i in range(8)]
+         features = [self._modules['conv_{}'.format(i)](samples[i]) for i in range(8)]
+
+         x = self.head_0(features[0], seg)
+         x = self.up(x)
+         x = self.G_middle_0(torch.cat((x, features[1]), 1), seg)
+         if self.num_upsampling_layers in ['more', 'most']:
+             x = self.up(x)
+         x = self.G_middle_1(torch.cat((x, features[2]), 1), seg)
+
+         x = self.up(x)
+         x = self.up_0(torch.cat((x, features[3]), 1), seg)
+         x = self.up(x)
+         x = self.up_1(torch.cat((x, features[4]), 1), seg)
+         x = self.up(x)
+         x = self.up_2(torch.cat((x, features[5]), 1), seg)
+         x = self.up(x)
+         x = self.up_3(torch.cat((x, features[6]), 1), seg)
+         if self.num_upsampling_layers == 'most':
+             x = self.up(x)
+             x = self.up_4(torch.cat((x, features[7]), 1), seg)
+
+         x = self.conv_img(self.relu(x))
+         return self.tanh(x)
+ ########################################################################
+
+ ########################################################################
+
+ class NLayerDiscriminator(BaseNetwork):
+
+     def __init__(self, opt):
+         super().__init__()
+         self.no_ganFeat_loss = opt.no_ganFeat_loss
+         nf = opt.ndf
+
+         kw = 4
+         pw = int(np.ceil((kw - 1.0) / 2))
+         norm_layer = get_nonspade_norm_layer(opt.norm_D)
+
+         input_nc = opt.gen_semantic_nc + 3
+         # input_nc = opt.gen_semantic_nc + 13
+         sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=pw),
+                      nn.LeakyReLU(0.2, False)]]
+
+         for n in range(1, opt.n_layers_D):
+             nf_prev = nf
+             nf = min(nf * 2, 512)
+             sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=pw)),
+                           nn.LeakyReLU(0.2, False)]]
+
+         sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=pw)]]
+
+         # We divide the layers into groups to extract intermediate layer outputs
+         for n in range(len(sequence)):
+             self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
+
+     def forward(self, input):
+         results = [input]
+         for submodel in self.children():
+             intermediate_output = submodel(results[-1])
+             results.append(intermediate_output)
+
+         get_intermediate_features = not self.no_ganFeat_loss
+         if get_intermediate_features:
+             return results[1:]
+         else:
+             return results[-1]
+
+
+ class MultiscaleDiscriminator(BaseNetwork):
+
+     def __init__(self, opt):
+         super().__init__()
+         self.no_ganFeat_loss = opt.no_ganFeat_loss
+
+         for i in range(opt.num_D):
+             subnetD = NLayerDiscriminator(opt)
+             self.add_module('discriminator_%d' % i, subnetD)
+
+     def downsample(self, input):
+         return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
+
+     # Returns list of lists of discriminator outputs.
+     # The final result is of size opt.num_D x opt.n_layers_D
+     def forward(self, input):
+         result = []
+         get_intermediate_features = not self.no_ganFeat_loss
+         for name, D in self.named_children():
+             out = D(input)
+             if not get_intermediate_features:
+                 out = [out]
+             result.append(out)
+             input = self.downsample(input)
+
+         return result
+
+
+ class GANLoss(nn.Module):
+     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):
+         super(GANLoss, self).__init__()
+         self.real_label = target_real_label
+         self.fake_label = target_fake_label
+         self.real_label_tensor = None
+         self.fake_label_tensor = None
+         self.zero_tensor = None
+         self.Tensor = tensor
+         self.gan_mode = gan_mode
+         if gan_mode not in ('ls', 'original', 'w', 'hinge'):
+             raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
+
+     def get_target_tensor(self, input, target_is_real):
+         if target_is_real:
+             if self.real_label_tensor is None:
+                 self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
+                 self.real_label_tensor.requires_grad_(False)
+             return self.real_label_tensor.expand_as(input)
+         else:
+             if self.fake_label_tensor is None:
+                 self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
+                 self.fake_label_tensor.requires_grad_(False)
+             return self.fake_label_tensor.expand_as(input)
+
+     def get_zero_tensor(self, input):
+         if self.zero_tensor is None:
+             self.zero_tensor = self.Tensor(1).fill_(0)
+             self.zero_tensor.requires_grad_(False)
+         return self.zero_tensor.expand_as(input)
+
+     def loss(self, input, target_is_real, for_discriminator=True):
+         if self.gan_mode == 'original':  # cross entropy loss
+             target_tensor = self.get_target_tensor(input, target_is_real)
+             loss = F.binary_cross_entropy_with_logits(input, target_tensor)
+             return loss
+         elif self.gan_mode == 'ls':
+             target_tensor = self.get_target_tensor(input, target_is_real)
+             return F.mse_loss(input, target_tensor)
+         elif self.gan_mode == 'hinge':
+             if for_discriminator:
+                 if target_is_real:
+                     minval = torch.min(input - 1, self.get_zero_tensor(input))
+                     loss = -torch.mean(minval)
+                 else:
+                     minval = torch.min(-input - 1, self.get_zero_tensor(input))
+                     loss = -torch.mean(minval)
+             else:
+                 assert target_is_real, "The generator's hinge loss must be aiming for real"
+                 loss = -torch.mean(input)
+             return loss
+         else:
+             # wgan
+             if target_is_real:
+                 return -input.mean()
+             else:
+                 return input.mean()
+
+     def __call__(self, input, target_is_real, for_discriminator=True):
+         # computing the loss is a bit complicated because |input| may not be
+         # a tensor, but a list of tensors in the case of a multiscale discriminator
+         if isinstance(input, list):
+             loss = 0
+             for pred_i in input:
+                 if isinstance(pred_i, list):
+                     pred_i = pred_i[-1]
+                 loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
+                 bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
+                 new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
+                 loss += new_loss
+             return loss / len(input)
+         else:
+             return self.loss(input, target_is_real, for_discriminator)
+
+
+ def get_nonspade_norm_layer(norm_type='instance'):
+     def get_out_channel(layer):
+         if hasattr(layer, 'out_channels'):
+             return getattr(layer, 'out_channels')
+         return layer.weight.size(0)
+
+     def add_norm_layer(layer):
+         nonlocal norm_type
+         subnorm_type = norm_type  # default when norm_type has no 'spectral' prefix
+         if norm_type.startswith('spectral'):
+             layer = spectral_norm(layer)
+             subnorm_type = norm_type[len('spectral'):]
+
+         if subnorm_type == 'none' or len(subnorm_type) == 0:
+             return layer
+
+         # remove bias in the previous layer, which is meaningless
+         # since it has no effect after normalization
+         if getattr(layer, 'bias', None) is not None:
+             delattr(layer, 'bias')
+             layer.register_parameter('bias', None)
+
+         if subnorm_type == 'batch':
+             norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+         # elif subnorm_type == 'sync_batch':
+         #     norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
+         elif subnorm_type == 'instance':
+             norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+         else:
+             raise ValueError('normalization layer %s is not recognized' % subnorm_type)
+
+         return nn.Sequential(layer, norm_layer)
+
+     return add_norm_layer
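
A quick shape check for SPADEGenerator under the options app.py uses (ngf=64, num_upsampling_layers='most', norm_G='spectralaliasinstance', gen_semantic_nc=7, fine_height=1024, fine_width=768, cuda=False): with seven upsampling stages the latent grid is 8x6, and the output returns at full resolution. A minimal sketch; `opt` stands for the OPT instance configured in app.py:

    # Hypothetical shape check with random inputs.
    import torch
    from network_generator import SPADEGenerator

    gen = SPADEGenerator(opt, 3 + 3 + 3).eval()  # agnostic + densepose + warped cloth
    x = torch.randn(1, 9, opt.fine_height, opt.fine_width)
    seg = torch.randn(1, opt.gen_semantic_nc, opt.fine_height, opt.fine_width)
    with torch.no_grad():
        out = gen(x, seg)
    print(out.shape)  # torch.Size([1, 3, 1024, 768])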
networks.py ADDED
@@ -0,0 +1,453 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from torchvision import models
6
+ import os
7
+ from torch.nn.utils import spectral_norm
8
+ import numpy as np
9
+
10
+ import functools
11
+
12
+
13
+ class ConditionGenerator(nn.Module):
14
+ def __init__(self, opt, input1_nc, input2_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d):
15
+ super(ConditionGenerator, self).__init__()
16
+ self.warp_feature = opt.warp_feature
17
+ self.out_layer_opt = opt.out_layer
18
+
19
+ self.ClothEncoder = nn.Sequential(
20
+ ResBlock(input1_nc, ngf, norm_layer=norm_layer, scale='down'), # 128
21
+ ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'), # 64
22
+ ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'), # 32
23
+ ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'), # 16
24
+ ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down') # 8
25
+ )
26
+
27
+ self.PoseEncoder = nn.Sequential(
28
+ ResBlock(input2_nc, ngf, norm_layer=norm_layer, scale='down'),
29
+ ResBlock(ngf, ngf * 2, norm_layer=norm_layer, scale='down'),
30
+ ResBlock(ngf * 2, ngf * 4, norm_layer=norm_layer, scale='down'),
31
+ ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down'),
32
+ ResBlock(ngf * 4, ngf * 4, norm_layer=norm_layer, scale='down')
33
+ )
34
+
35
+ self.conv = ResBlock(ngf * 4, ngf * 8, norm_layer=norm_layer, scale='same')
36
+
37
+ if opt.warp_feature == 'T1':
38
+ # in_nc -> skip connection + T1, T2 channel
39
+ self.SegDecoder = nn.Sequential(
40
+ ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'), # 16
41
+ ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 4, norm_layer=norm_layer, scale='up'), # 32
42
+ ResBlock(ngf * 4 * 2 + ngf * 4 , ngf * 2, norm_layer=norm_layer, scale='up'), # 64
43
+ ResBlock(ngf * 2 * 2 + ngf * 4 , ngf, norm_layer=norm_layer, scale='up'), # 128
44
+ ResBlock(ngf * 1 * 2 + ngf * 4, ngf, norm_layer=norm_layer, scale='up') # 256
45
+ )
46
+ if opt.warp_feature == 'encoder':
47
+ # in_nc -> [x, skip_connection, warped_cloth_encoder_feature(E1)]
48
+ self.SegDecoder = nn.Sequential(
49
+ ResBlock(ngf * 8, ngf * 4, norm_layer=norm_layer, scale='up'), # 16
50
+ ResBlock(ngf * 4 * 3, ngf * 4, norm_layer=norm_layer, scale='up'), # 32
51
+ ResBlock(ngf * 4 * 3, ngf * 2, norm_layer=norm_layer, scale='up'), # 64
52
+ ResBlock(ngf * 2 * 3, ngf, norm_layer=norm_layer, scale='up'), # 128
53
+ ResBlock(ngf * 1 * 3, ngf, norm_layer=norm_layer, scale='up') # 256
54
+ )
55
+ if opt.out_layer == 'relu':
56
+ self.out_layer = ResBlock(ngf + input1_nc + input2_nc, output_nc, norm_layer=norm_layer, scale='same')
57
+ if opt.out_layer == 'conv':
58
+ self.out_layer = nn.Sequential(
59
+ ResBlock(ngf + input1_nc + input2_nc, ngf, norm_layer=norm_layer, scale='same'),
60
+ nn.Conv2d(ngf, output_nc, kernel_size=1, bias=True)
61
+ )
62
+
63
+ # Cloth Conv 1x1
64
+ self.conv1 = nn.Sequential(
65
+ nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
66
+ nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
67
+ nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
68
+ nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
69
+ )
70
+
71
+ # Person Conv 1x1
72
+ self.conv2 = nn.Sequential(
73
+ nn.Conv2d(ngf, ngf * 4, kernel_size=1, bias=True),
74
+ nn.Conv2d(ngf * 2, ngf * 4, kernel_size=1, bias=True),
75
+ nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
76
+ nn.Conv2d(ngf * 4, ngf * 4, kernel_size=1, bias=True),
77
+ )
78
+
79
+ self.flow_conv = nn.ModuleList([
80
+ nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
81
+ nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
82
+ nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
83
+ nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
84
+ nn.Conv2d(ngf * 8, 2, kernel_size=3, stride=1, padding=1, bias=True),
85
+ ]
86
+ )
87
+
88
+ self.bottleneck = nn.Sequential(
89
+ nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
90
+ nn.Sequential(nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
91
+ nn.Sequential(nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True) , nn.ReLU()),
92
+ nn.Sequential(nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1, bias=True), nn.ReLU()),
93
+ )
94
+
95
+ def normalize(self, x):
96
+ return x
97
+
98
+ def forward(self,opt,input1, input2, upsample='bilinear'):
99
+ E1_list = []
100
+ E2_list = []
101
+ flow_list = []
102
+ # warped_grid_list = []
103
+
104
+ # Feature Pyramid Network
105
+ for i in range(5):
106
+ if i == 0:
107
+ E1_list.append(self.ClothEncoder[i](input1))
108
+ E2_list.append(self.PoseEncoder[i](input2))
109
+ else:
110
+ E1_list.append(self.ClothEncoder[i](E1_list[i - 1]))
111
+ E2_list.append(self.PoseEncoder[i](E2_list[i - 1]))
112
+
113
+ # Compute Clothflow
114
+ for i in range(5):
115
+ N, _, iH, iW = E1_list[4 - i].size()
116
+ grid = make_grid(N, iH, iW,opt)
117
+
118
+ if i == 0:
119
+ T1 = E1_list[4 - i] # (ngf * 4) x 8 x 6
120
+ T2 = E2_list[4 - i]
121
+ E4 = torch.cat([T1, T2], 1)
122
+
123
+ flow = self.flow_conv[i](self.normalize(E4)).permute(0, 2, 3, 1)
124
+ flow_list.append(flow)
125
+
126
+ x = self.conv(T2)
127
+ x = self.SegDecoder[i](x)
128
+
129
+ else:
130
+ T1 = F.interpolate(T1, scale_factor=2, mode=upsample) + self.conv1[4 - i](E1_list[4 - i])
131
+ T2 = F.interpolate(T2, scale_factor=2, mode=upsample) + self.conv2[4 - i](E2_list[4 - i])
132
+
133
+ flow = F.interpolate(flow_list[i - 1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1) # upsample n-1 flow
134
+ flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH/2 - 1.0) / 2.0)], 3)
135
+ warped_T1 = F.grid_sample(T1, flow_norm + grid, padding_mode='border')
136
+
137
+ flow = flow + self.flow_conv[i](self.normalize(torch.cat([warped_T1, self.bottleneck[i-1](x)], 1))).permute(0, 2, 3, 1) # F(n)
138
+ flow_list.append(flow)
139
+
140
+ if self.warp_feature == 'T1':
141
+ x = self.SegDecoder[i](torch.cat([x, E2_list[4-i], warped_T1], 1))
142
+ if self.warp_feature == 'encoder':
143
+ warped_E1 = F.grid_sample(E1_list[4-i], flow_norm + grid, padding_mode='border')
144
+ x = self.SegDecoder[i](torch.cat([x, E2_list[4-i], warped_E1], 1))
145
+
146
+
147
+ N, _, iH, iW = input1.size()
148
+ grid = make_grid(N, iH, iW,opt)
149
+
150
+ flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), scale_factor=2, mode=upsample).permute(0, 2, 3, 1)
151
+ flow_norm = torch.cat([flow[:, :, :, 0:1] / ((iW/2 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((iH/2 - 1.0) / 2.0)], 3)
152
+ warped_input1 = F.grid_sample(input1, flow_norm + grid, padding_mode='border')
153
+
154
+ x = self.out_layer(torch.cat([x, input2, warped_input1], 1))
155
+
156
+ warped_c = warped_input1[:, :-1, :, :]
157
+ warped_cm = warped_input1[:, -1:, :, :]
158
+
159
+ return flow_list, x, warped_c, warped_cm
160
+
161
+ def make_grid(N, iH, iW,opt):
162
+ grid_x = torch.linspace(-1.0, 1.0, iW).view(1, 1, iW, 1).expand(N, iH, -1, -1)
163
+ grid_y = torch.linspace(-1.0, 1.0, iH).view(1, iH, 1, 1).expand(N, -1, iW, -1)
164
+ if opt.cuda :
165
+ grid = torch.cat([grid_x, grid_y], 3).cuda()
166
+ else:
167
+ grid = torch.cat([grid_x, grid_y], 3)
168
+ return grid
169
+
170
+
171
+ class ResBlock(nn.Module):
172
+ def __init__(self, in_nc, out_nc, scale='down', norm_layer=nn.BatchNorm2d):
173
+ super(ResBlock, self).__init__()
174
+ use_bias = norm_layer == nn.InstanceNorm2d
175
+ assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"
176
+
177
+ if scale == 'same':
178
+ self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
179
+ if scale == 'up':
180
+ self.scale = nn.Sequential(
181
+ nn.Upsample(scale_factor=2, mode='bilinear'),
182
+ nn.Conv2d(in_nc, out_nc, kernel_size=1,bias=True)
183
+ )
184
+ if scale == 'down':
185
+ self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)
186
+
187
+ self.block = nn.Sequential(
188
+ nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
189
+ norm_layer(out_nc),
190
+ nn.ReLU(inplace=True),
191
+ nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
192
+ norm_layer(out_nc)
193
+ )
194
+ self.relu = nn.ReLU(inplace=True)
195
+
196
+ def forward(self, x):
197
+ residual = self.scale(x)
198
+ return self.relu(residual + self.block(residual))
199
+
200
+
201
+ class Vgg19(nn.Module):
202
+ def __init__(self, requires_grad=False):
203
+ super(Vgg19, self).__init__()
204
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
205
+ self.slice1 = torch.nn.Sequential()
206
+ self.slice2 = torch.nn.Sequential()
207
+ self.slice3 = torch.nn.Sequential()
208
+ self.slice4 = torch.nn.Sequential()
209
+ self.slice5 = torch.nn.Sequential()
210
+ for x in range(2):
211
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
212
+ for x in range(2, 7):
213
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
214
+ for x in range(7, 12):
215
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
216
+ for x in range(12, 21):
217
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
218
+ for x in range(21, 30):
219
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
220
+ if not requires_grad:
221
+ for param in self.parameters():
222
+ param.requires_grad = False
223
+
224
+ def forward(self, X):
225
+ h_relu1 = self.slice1(X)
226
+ h_relu2 = self.slice2(h_relu1)
227
+ h_relu3 = self.slice3(h_relu2)
228
+ h_relu4 = self.slice4(h_relu3)
229
+ h_relu5 = self.slice5(h_relu4)
230
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
231
+ return out
232
+
233
+
234
+ class VGGLoss(nn.Module):
235
+ def __init__(self, opt,layids = None):
236
+ super(VGGLoss, self).__init__()
237
+ self.vgg = Vgg19()
238
+ if opt.cuda:
239
+ self.vgg.cuda()
240
+ self.criterion = nn.L1Loss()
241
+ self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
242
+ self.layids = layids
243
+
244
+ def forward(self, x, y):
245
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
246
+ loss = 0
247
+ if self.layids is None:
248
+ self.layids = list(range(len(x_vgg)))
249
+ for i in self.layids:
250
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
251
+ return loss
252
+
+ # Defines the GAN loss, which uses either LSGAN or the regular GAN.
+ # When LSGAN is used, it is basically the same as MSELoss,
+ # but it abstracts away the need to create the target label tensor
+ # that has the same size as the input.
+
+ class GANLoss(nn.Module):
+     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
+                  tensor=torch.FloatTensor):
+         super(GANLoss, self).__init__()
+         self.real_label = target_real_label
+         self.fake_label = target_fake_label
+         self.real_label_var = None
+         self.fake_label_var = None
+         self.Tensor = tensor
+         if use_lsgan:
+             self.loss = nn.MSELoss()
+         else:
+             self.loss = nn.BCELoss()
+
+     def get_target_tensor(self, input, target_is_real):
+         if target_is_real:
+             create_label = ((self.real_label_var is None) or
+                             (self.real_label_var.numel() != input.numel()))
+             if create_label:
+                 real_tensor = self.Tensor(input.size()).fill_(self.real_label)
+                 self.real_label_var = Variable(real_tensor, requires_grad=False)
+             target_tensor = self.real_label_var
+         else:
+             create_label = ((self.fake_label_var is None) or
+                             (self.fake_label_var.numel() != input.numel()))
+             if create_label:
+                 fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
+                 self.fake_label_var = Variable(fake_tensor, requires_grad=False)
+             target_tensor = self.fake_label_var
+         return target_tensor
+
+     def __call__(self, input, target_is_real):
+         if isinstance(input[0], list):
+             loss = 0
+             for input_i in input:
+                 pred = input_i[-1]
+                 target_tensor = self.get_target_tensor(pred, target_is_real)
+                 loss += self.loss(pred, target_tensor)
+             return loss
+         else:
+             target_tensor = self.get_target_tensor(input[-1], target_is_real)
+             return self.loss(input[-1], target_tensor)
+
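+ # Illustrative usage sketch: with use_lsgan=True the criterion reduces to MSE
+ # against an implicit all-ones (real) or all-zeros (fake) target of matching
+ # shape. The prediction shape below is an assumption for the example only.
+ def _gan_loss_demo():
+     criterionGAN = GANLoss(use_lsgan=True)
+     pred_fake = [torch.randn(1, 1, 30, 22)]  # one discriminator score map
+     loss_g = criterionGAN(pred_fake, True)   # generator wants "real"
+     loss_d = criterionGAN(pred_fake, False)  # discriminator wants "fake"
+     return loss_g, loss_d
+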
+ class MultiscaleDiscriminator(nn.Module):
+     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
+                  use_sigmoid=False, num_D=3, getIntermFeat=False, Ddownx2=False, Ddropout=False, spectral=False):
+         super(MultiscaleDiscriminator, self).__init__()
+         self.num_D = num_D
+         self.n_layers = n_layers
+         self.getIntermFeat = getIntermFeat
+         self.Ddownx2 = Ddownx2
+
+         for i in range(num_D):
+             netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat, Ddropout, spectral=spectral)
+             if getIntermFeat:
+                 for j in range(n_layers + 2):
+                     setattr(self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j)))
+             else:
+                 setattr(self, 'layer' + str(i), netD.model)
+
+         self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
+
+     def singleD_forward(self, model, input):
+         if self.getIntermFeat:
+             result = [input]
+             for i in range(len(model)):
+                 result.append(model[i](result[-1]))
+             return result[1:]
+         else:
+             return [model(input)]
+
+     def forward(self, input):
+         num_D = self.num_D
+
+         result = []
+         if self.Ddownx2:
+             input_downsampled = self.downsample(input)
+         else:
+             input_downsampled = input
+         for i in range(num_D):
+             if self.getIntermFeat:
+                 model = [getattr(self, 'scale' + str(num_D - 1 - i) + '_layer' + str(j)) for j in
+                          range(self.n_layers + 2)]
+             else:
+                 model = getattr(self, 'layer' + str(num_D - 1 - i))
+             result.append(self.singleD_forward(model, input_downsampled))
+             if i != (num_D - 1):
+                 input_downsampled = self.downsample(input_downsampled)
+         return result
+
+ class NLayerDiscriminator(nn.Module):
+     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False, Ddropout=False, spectral=False):
+         super(NLayerDiscriminator, self).__init__()
+         self.getIntermFeat = getIntermFeat
+         self.n_layers = n_layers
+         self.spectral_norm = spectral_norm if spectral else lambda x: x
+
+         kw = 4
+         padw = int(np.ceil((kw - 1.0) / 2))
+         sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
+
+         nf = ndf
+         for n in range(1, n_layers):
+             nf_prev = nf
+             nf = min(nf * 2, 512)
+             if Ddropout:
+                 sequence += [[
+                     self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
+                     norm_layer(nf), nn.LeakyReLU(0.2, True), nn.Dropout(0.5)
+                 ]]
+             else:
+                 sequence += [[
+                     self.spectral_norm(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw)),
+                     norm_layer(nf), nn.LeakyReLU(0.2, True)
+                 ]]
+
+         nf_prev = nf
+         nf = min(nf * 2, 512)
+         sequence += [[
+             nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+             norm_layer(nf),
+             nn.LeakyReLU(0.2, True)
+         ]]
+
+         sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+         if use_sigmoid:
+             sequence += [[nn.Sigmoid()]]
+
+         if getIntermFeat:
+             for n in range(len(sequence)):
+                 setattr(self, 'model' + str(n), nn.Sequential(*sequence[n]))
+         else:
+             sequence_stream = []
+             for n in range(len(sequence)):
+                 sequence_stream += sequence[n]
+             self.model = nn.Sequential(*sequence_stream)
+
+     def forward(self, input):
+         if self.getIntermFeat:
+             res = [input]
+             for n in range(self.n_layers + 2):
+                 model = getattr(self, 'model' + str(n))
+                 res.append(model(res[-1]))
+             return res[1:]
+         else:
+             return self.model(input)
+
+
+ def save_checkpoint(model, save_path, opt):
+     if not os.path.exists(os.path.dirname(save_path)):
+         os.makedirs(os.path.dirname(save_path))
+
+     torch.save(model.cpu().state_dict(), save_path)
+     if opt.cuda:
+         model.cuda()
+
+ def load_checkpoint(model, checkpoint_path, opt):
+     if not os.path.exists(checkpoint_path):
+         print('no checkpoint')
+         raise FileNotFoundError(checkpoint_path)
+     log = model.load_state_dict(torch.load(checkpoint_path), strict=False)
+     if opt.cuda:
+         model.cuda()
+
+
+ def weights_init(m):
+     classname = m.__class__.__name__
+     if classname.find('Conv2d') != -1:
+         m.weight.data.normal_(0.0, 0.02)
+     elif classname.find('BatchNorm2d') != -1:
+         m.weight.data.normal_(1.0, 0.02)
+         m.bias.data.fill_(0)
+
+ def get_norm_layer(norm_type='instance'):
+     if norm_type == 'batch':
+         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+     elif norm_type == 'instance':
+         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+     else:
+         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+     return norm_layer
+
+ def define_D(input_nc, ndf=64, n_layers_D=3, norm='instance', use_sigmoid=False, num_D=2, getIntermFeat=False, gpu_ids=[], Ddownx2=False, Ddropout=False, spectral=False):
+     norm_layer = get_norm_layer(norm_type=norm)
+     netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat, Ddownx2, Ddropout, spectral=spectral)
+     print(netD)
+     if len(gpu_ids) > 0:
+         assert (torch.cuda.is_available())
+         netD.cuda()
+     netD.apply(weights_init)
+     return netD
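+
+ # Illustrative usage sketch: building and running the multiscale discriminator.
+ # The channel count (3 image + 7 parse channels) is an assumption for the
+ # example only; define_D prints the network and applies weights_init as above.
+ def _define_d_demo():
+     netD = define_D(input_nc=3 + 7, ndf=64, n_layers_D=3, norm='instance',
+                     num_D=2, getIntermFeat=True, Ddownx2=True)
+     preds = netD(torch.randn(1, 10, 256, 192))
+     # with getIntermFeat=True: one list of per-layer features per scale
+     return [p[-1].shape for p in preds]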
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ torchvision
+ torchaudio
+ opencv-python
+ torchgeometry
+ Pillow
+ tqdm
+ tensorboardX
+ scikit-image
+ scipy
+ streamlit-image-select
+ pandas
test_generator.py ADDED
@@ -0,0 +1,278 @@
+ import torch
+ import torch.nn as nn
+
+ from torchvision.utils import make_grid as make_image_grid
+ from torchvision.utils import save_image
+ import argparse
+ import os
+ import time
+ from cp_dataset_test import CPDatasetTest, CPDataLoader
+
+ from networks import ConditionGenerator, load_checkpoint, make_grid
+ from network_generator import SPADEGenerator
+ from tensorboardX import SummaryWriter
+ from utils import *
+
+ import torchgeometry as tgm
+ from collections import OrderedDict
+
+ def remove_overlap(seg_out, warped_cm):
+     assert len(warped_cm.shape) == 4
+     # suppress warped-mask pixels that the predicted parse assigns to non-cloth regions
+     warped_cm = warped_cm - (torch.cat([seg_out[:, 1:3, :, :], seg_out[:, 5:, :, :]], dim=1)).sum(dim=1, keepdim=True) * warped_cm
+     return warped_cm
+
+ def get_opt():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--gpu_ids", default="")
+     parser.add_argument('-j', '--workers', type=int, default=4)
+     parser.add_argument('-b', '--batch-size', type=int, default=1)
+     parser.add_argument('--fp16', action='store_true', help='use amp')
+     # Cuda availability (note: parsed as a string, so any non-empty value is truthy)
+     parser.add_argument('--cuda', default=False, help='cuda or cpu')
+
+     parser.add_argument('--test_name', type=str, default='test', help='test name')
+     parser.add_argument("--dataroot", default="./data/zalando-hd-resize")
+     parser.add_argument("--datamode", default="test")
+     parser.add_argument("--data_list", default="test_pairs.txt")
+     parser.add_argument("--output_dir", type=str, default="./Output")
+     parser.add_argument("--datasetting", default="unpaired")
+     parser.add_argument("--fine_width", type=int, default=768)
+     parser.add_argument("--fine_height", type=int, default=1024)
+
+     parser.add_argument('--tensorboard_dir', type=str, default='./data/zalando-hd-resize/tensorboard', help='save tensorboard infos')
+     parser.add_argument('--checkpoint_dir', type=str, default='checkpoints', help='save checkpoint infos')
+     parser.add_argument('--tocg_checkpoint', type=str, default='./eval_models/weights/v0.1/mtviton.pth', help='tocg checkpoint')
+     parser.add_argument('--gen_checkpoint', type=str, default='./eval_models/weights/v0.1/gen.pth', help='G checkpoint')
+
+     parser.add_argument("--tensorboard_count", type=int, default=100)
+     parser.add_argument("--shuffle", action='store_true', help='shuffle input data')
+     parser.add_argument("--semantic_nc", type=int, default=13)
+     parser.add_argument("--output_nc", type=int, default=13)
+     parser.add_argument('--gen_semantic_nc', type=int, default=7, help='# of input label classes without unknown class')
+
+     # network
+     parser.add_argument("--warp_feature", choices=['encoder', 'T1'], default="T1")
+     parser.add_argument("--out_layer", choices=['relu', 'conv'], default="relu")
+
+     # training
+     parser.add_argument("--clothmask_composition", type=str, choices=['no_composition', 'detach', 'warp_grad'], default='warp_grad')
+
+     # Hyper-parameters
+     parser.add_argument('--upsample', type=str, default='bilinear', choices=['nearest', 'bilinear'])
+     parser.add_argument('--occlusion', action='store_true', help="Occlusion handling")
+
+     # generator
+     parser.add_argument('--norm_G', type=str, default='spectralaliasinstance', help='instance normalization or batch normalization')
+     parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
+     parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
+     parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
+     parser.add_argument('--num_upsampling_layers', choices=('normal', 'more', 'most'), default='most',  # normal: 256, more: 512
+                         help="If 'more', adds an upsampling layer between the two middle resnet blocks. If 'most', also adds one more upsampling + resnet layer at the end of the generator.")
+
+     opt = parser.parse_args()
+     return opt
+
+ def load_checkpoint_G(model, checkpoint_path, opt):
+     if not os.path.exists(checkpoint_path):
+         print("Invalid path!")
+         return
+     state_dict = torch.load(checkpoint_path)
+     # remap layer names from older checkpoints ('ace' -> 'alias', drop '.Spade')
+     new_state_dict = OrderedDict([(k.replace('ace', 'alias').replace('.Spade', ''), v) for (k, v) in state_dict.items()])
+     new_state_dict._metadata = OrderedDict([(k.replace('ace', 'alias').replace('.Spade', ''), v) for (k, v) in state_dict._metadata.items()])
+     model.load_state_dict(new_state_dict, strict=True)
+     if opt.cuda:
+         model.cuda()
+
+
+ def test(opt, test_loader, tocg, generator):
+     gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
+     if opt.cuda:
+         gauss = gauss.cuda()
+
+     # Model
+     if opt.cuda:
+         tocg.cuda()
+     tocg.eval()
+     generator.eval()
+
+     if opt.output_dir is not None:
+         output_dir = opt.output_dir
+     else:
+         output_dir = os.path.join('./output', opt.test_name,
+                                   opt.datamode, opt.datasetting, 'generator', 'output')
+     grid_dir = os.path.join('./output', opt.test_name,
+                             opt.datamode, opt.datasetting, 'generator', 'grid')
+
+     os.makedirs(grid_dir, exist_ok=True)
+     os.makedirs(output_dir, exist_ok=True)
+
+     num = 0
+     iter_start_time = time.time()
+     with torch.no_grad():
+         for inputs in test_loader.data_loader:
+
+             if opt.cuda:
+                 pose_map = inputs['pose'].cuda()
+                 pre_clothes_mask = inputs['cloth_mask'][opt.datasetting].cuda()
+                 label = inputs['parse']
+                 parse_agnostic = inputs['parse_agnostic']
+                 agnostic = inputs['agnostic'].cuda()
+                 clothes = inputs['cloth'][opt.datasetting].cuda()  # target cloth
+                 densepose = inputs['densepose'].cuda()
+                 im = inputs['image']
+                 input_label, input_parse_agnostic = label.cuda(), parse_agnostic.cuda()
+                 pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(float)).cuda()
+             else:
+                 pose_map = inputs['pose']
+                 pre_clothes_mask = inputs['cloth_mask'][opt.datasetting]
+                 label = inputs['parse']
+                 parse_agnostic = inputs['parse_agnostic']
+                 agnostic = inputs['agnostic']
+                 clothes = inputs['cloth'][opt.datasetting]  # target cloth
+                 densepose = inputs['densepose']
+                 im = inputs['image']
+                 input_label, input_parse_agnostic = label, parse_agnostic
+                 pre_clothes_mask = torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(float))
+
+             # down
+             pose_map_down = F.interpolate(pose_map, size=(256, 192), mode='bilinear')
+             pre_clothes_mask_down = F.interpolate(pre_clothes_mask, size=(256, 192), mode='nearest')
+             input_label_down = F.interpolate(input_label, size=(256, 192), mode='bilinear')
+             input_parse_agnostic_down = F.interpolate(input_parse_agnostic, size=(256, 192), mode='nearest')
+             agnostic_down = F.interpolate(agnostic, size=(256, 192), mode='nearest')
+             clothes_down = F.interpolate(clothes, size=(256, 192), mode='bilinear')
+             densepose_down = F.interpolate(densepose, size=(256, 192), mode='bilinear')
+
+             shape = pre_clothes_mask.shape
+
+             # multi-task inputs
+             input1 = torch.cat([clothes_down, pre_clothes_mask_down], 1)
+             input2 = torch.cat([input_parse_agnostic_down, densepose_down], 1)
+
+             # forward
+             flow_list, fake_segmap, warped_cloth_paired, warped_clothmask_paired = tocg(opt, input1, input2)
+
+             # warped cloth mask one hot
+             if opt.cuda:
+                 warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(float)).cuda()
+             else:
+                 warped_cm_onehot = torch.FloatTensor((warped_clothmask_paired.detach().cpu().numpy() > 0.5).astype(float))
+
+             if opt.clothmask_composition != 'no_composition':
+                 if opt.clothmask_composition == 'detach':
+                     cloth_mask = torch.ones_like(fake_segmap)
+                     cloth_mask[:, 3:4, :, :] = warped_cm_onehot
+                     fake_segmap = fake_segmap * cloth_mask
+
+                 if opt.clothmask_composition == 'warp_grad':
+                     cloth_mask = torch.ones_like(fake_segmap)
+                     cloth_mask[:, 3:4, :, :] = warped_clothmask_paired
+                     fake_segmap = fake_segmap * cloth_mask
+
+             # make generator input parse map
+             fake_parse_gauss = gauss(F.interpolate(fake_segmap, size=(opt.fine_height, opt.fine_width), mode='bilinear'))
+             fake_parse = fake_parse_gauss.argmax(dim=1)[:, None]
+
+             if opt.cuda:
+                 old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_().cuda()
+             else:
+                 old_parse = torch.FloatTensor(fake_parse.size(0), 13, opt.fine_height, opt.fine_width).zero_()
+             old_parse.scatter_(1, fake_parse, 1.0)
+
+             labels = {
+                 0: ['background', [0]],
+                 1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
+                 2: ['upper', [3]],
+                 3: ['hair', [1]],
+                 4: ['left_arm', [5]],
+                 5: ['right_arm', [6]],
+                 6: ['noise', [12]]
+             }
+             if opt.cuda:
+                 parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_().cuda()
+             else:
+                 parse = torch.FloatTensor(fake_parse.size(0), 7, opt.fine_height, opt.fine_width).zero_()
+             for i in range(len(labels)):
+                 for label in labels[i][1]:
+                     parse[:, i] += old_parse[:, label]
+
+             # warped cloth
+             N, _, iH, iW = clothes.shape
+             flow = F.interpolate(flow_list[-1].permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
+             flow_norm = torch.cat([flow[:, :, :, 0:1] / ((96 - 1.0) / 2.0), flow[:, :, :, 1:2] / ((128 - 1.0) / 2.0)], 3)
+
+             grid = make_grid(N, iH, iW, opt)
+             warped_grid = grid + flow_norm
+             warped_cloth = F.grid_sample(clothes, warped_grid, padding_mode='border')
+             warped_clothmask = F.grid_sample(pre_clothes_mask, warped_grid, padding_mode='border')
+             if opt.occlusion:
+                 warped_clothmask = remove_overlap(F.softmax(fake_parse_gauss, dim=1), warped_clothmask)
+                 warped_cloth = warped_cloth * warped_clothmask + torch.ones_like(warped_cloth) * (1 - warped_clothmask)
+
+             output = generator(torch.cat((agnostic, densepose, warped_cloth), dim=1), parse)
+
+             # visualize
+             unpaired_names = []
+             for i in range(shape[0]):
+                 grid = make_image_grid([(clothes[i].cpu() / 2 + 0.5), (pre_clothes_mask[i].cpu()).expand(3, -1, -1), visualize_segmap(parse_agnostic.cpu(), batch=i), ((densepose.cpu()[i] + 1) / 2),
+                                         (warped_cloth[i].cpu().detach() / 2 + 0.5), (warped_clothmask[i].cpu().detach()).expand(3, -1, -1), visualize_segmap(fake_parse_gauss.cpu(), batch=i),
+                                         (pose_map[i].cpu() / 2 + 0.5), (warped_cloth[i].cpu() / 2 + 0.5), (agnostic[i].cpu() / 2 + 0.5),
+                                         (im[i] / 2 + 0.5), (output[i].cpu() / 2 + 0.5)],
+                                        nrow=4)
+                 unpaired_name = (inputs['c_name']['paired'][i].split('.')[0] + '_' + inputs['c_name'][opt.datasetting][i].split('.')[0] + '.png')
+                 save_image(grid, os.path.join(grid_dir, unpaired_name))
+                 unpaired_names.append(unpaired_name)
+
+             # save output
+             save_images(output, unpaired_names, output_dir)
+
+             num += shape[0]
+             print(num)
+
+     print(f"Test time {time.time() - iter_start_time}")
+
+
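+ # Illustrative sketch of the warping step inside test(), in isolation: the
+ # predicted flow field is converted to a normalized sampling grid and applied
+ # with grid_sample. The 96/128 normalization constants mirror the ones above;
+ # the argument names are assumptions for the example only.
+ def _warp_demo(opt, clothes, flow):
+     N, _, iH, iW = clothes.shape
+     flow = F.interpolate(flow.permute(0, 3, 1, 2), size=(iH, iW), mode='bilinear').permute(0, 2, 3, 1)
+     flow_norm = torch.cat([flow[..., 0:1] / ((96 - 1.0) / 2.0),
+                            flow[..., 1:2] / ((128 - 1.0) / 2.0)], 3)
+     grid = make_grid(N, iH, iW, opt)  # identity sampling grid in [-1, 1]
+     return F.grid_sample(clothes, grid + flow_norm, padding_mode='border')
+
+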
+ def main():
+     opt = get_opt()
+     print(opt)
+     print("Start to test %s!" % opt.test_name)
+     os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_ids
+
+     # create test dataset & loader
+     test_dataset = CPDatasetTest(opt)
+     test_loader = CPDataLoader(opt, test_dataset)
+
+     # visualization
+     # if not os.path.exists(opt.tensorboard_dir):
+     #     os.makedirs(opt.tensorboard_dir)
+     # board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.test_name, opt.datamode, opt.datasetting))
+
+     ## Model
+     # tocg
+     input1_nc = 4  # cloth + cloth-mask
+     input2_nc = opt.semantic_nc + 3  # parse_agnostic + densepose
+     tocg = ConditionGenerator(opt, input1_nc=input1_nc, input2_nc=input2_nc, output_nc=opt.output_nc, ngf=96, norm_layer=nn.BatchNorm2d)
+
+     # generator
+     opt.semantic_nc = 7
+     generator = SPADEGenerator(opt, 3 + 3 + 3)
+     generator.print_network()
+
+     # Load Checkpoint
+     load_checkpoint(tocg, opt.tocg_checkpoint, opt)
+     load_checkpoint_G(generator, opt.gen_checkpoint, opt)
+
+     # Test
+     test(opt, test_loader, tocg, generator)
+
+     print("Finished testing!")
+
+
+ if __name__ == "__main__":
+     main()
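+
+ # Example invocation (flags are the ones defined in get_opt above; the data
+ # and checkpoint paths are assumptions about your local layout):
+ #   python test_generator.py --occlusion --datasetting unpaired \
+ #       --dataroot ./data/zalando-hd-resize --output_dir ./Output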
utils.py ADDED
@@ -0,0 +1,119 @@
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ import torch.nn.functional as F
+ import numpy as np
+ import cv2
+ import os
+
+ def get_clothes_mask(old_label):
+     clothes = torch.FloatTensor((old_label.cpu().numpy() == 3).astype(int))
+     return clothes
+
+ def changearm(old_label):
+     label = old_label
+     arm1 = torch.FloatTensor((old_label.cpu().numpy() == 5).astype(int))
+     arm2 = torch.FloatTensor((old_label.cpu().numpy() == 6).astype(int))
+     label = label * (1 - arm1) + arm1 * 3
+     label = label * (1 - arm2) + arm2 * 3
+     return label
+
+ def gen_noise(shape):
+     noise = np.zeros(shape, dtype=np.uint8)
+     ### noise
+     noise = cv2.randn(noise, 0, 255)
+     noise = np.asarray(noise / 255, dtype=np.uint8)
+     noise = torch.tensor(noise, dtype=torch.float32)
+     return noise
+
+ def cross_entropy2d(input, target, weight=None, size_average=True):
+     n, c, h, w = input.size()
+     nt, ht, wt = target.size()
+
+     # Handle inconsistent size between input and target
+     if h != ht or w != wt:
+         input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
+
+     input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
+     target = target.view(-1)
+     loss = F.cross_entropy(
+         input, target, weight=weight,
+         reduction='mean' if size_average else 'sum', ignore_index=250
+     )
+     return loss
+
+ def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
+     image_numpy = image_tensor[batch].cpu().float().numpy()
+     result = np.argmax(image_numpy, axis=0)
+     return result.astype(imtype)
+
+ def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0):
+     palette = [
+         0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51,
+         254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85,
+         85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220,
+         0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0
+     ]
+     input = input.detach()
+     if multi_channel:
+         input = ndim_tensor2im(input, batch=batch)
+     else:
+         input = input[batch][0].cpu()
+         input = np.asarray(input)
+         input = input.astype(np.uint8)
+     input = Image.fromarray(input, 'P')
+     input.putpalette(palette)
+
+     if tensor_out:
+         trans = transforms.ToTensor()
+         return trans(input.convert('RGB'))
+
+     return input
+
+ def pred_to_onehot(prediction):
+     size = prediction.shape
+     prediction_max = torch.argmax(prediction, dim=1)
+     oneHot_size = (size[0], 13, size[2], size[3])
+     pred_onehot = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
+     pred_onehot = pred_onehot.scatter_(1, prediction_max.unsqueeze(1).data.long(), 1.0)
+     return pred_onehot
+
+ def cal_miou(prediction, target):
+     size = prediction.shape
+     target = target.cpu()
+     prediction = pred_to_onehot(prediction.detach().cpu())
+     channels = [1, 2, 3, 4, 5, 6, 7, 8]
+     union = 0
+     intersection = 0
+     for b in range(size[0]):
+         for c in channels:
+             intersection += torch.logical_and(target[b, c], prediction[b, c]).sum()
+             union += torch.logical_or(target[b, c], prediction[b, c]).sum()
+     return intersection.item() / union.item()
+
+ def save_images(img_tensors, img_names, save_dir):
+     for img_tensor, img_name in zip(img_tensors, img_names):
+         tensor = (img_tensor.clone() + 1) * 0.5 * 255
+         tensor = tensor.cpu().clamp(0, 255)
+
+         try:
+             array = tensor.numpy().astype('uint8')
+         except RuntimeError:
+             # tensors that still require grad must be detached first
+             array = tensor.detach().numpy().astype('uint8')
+
+         if array.shape[0] == 1:
+             array = array.squeeze(0)
+         elif array.shape[0] == 3:
+             array = array.swapaxes(0, 1).swapaxes(1, 2)
+
+         # note: filenames keep their .png extension but are encoded as JPEG
+         im = Image.fromarray(array)
+         im.save(os.path.join(save_dir, img_name), format='JPEG')
+
+
+ def create_network(cls, opt):
+     net = cls(opt)
+     net.print_network()
+     if len(opt.gpu_ids) > 0:
+         assert (torch.cuda.is_available())
+         net.cuda()
+     net.init_weights(opt.init_type, opt.init_variance)
+     return net
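+
+
+ # Illustrative usage sketch: exercising the segmentation helpers on random
+ # data. The 13-channel layout matches pred_to_onehot above; shapes are
+ # assumptions for the example only.
+ def _segmap_demo():
+     pred = torch.randn(2, 13, 64, 48)  # raw logits
+     target = pred_to_onehot(torch.randn(2, 13, 64, 48))
+     miou = cal_miou(pred, target)          # IoU accumulated over channels 1..8
+     vis = visualize_segmap(pred, batch=0)  # 3xHxW RGB tensor via the palette
+     return miou, vis.shape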