feylur committed on
Commit 3aa47f5 · verified · 1 Parent(s): e12eb26

Create inference.py

Files changed (1)
  1. inference.py +325 -0
inference.py ADDED
@@ -0,0 +1,325 @@
+ import os
+ import numpy as np
+ import torch
+ import argparse
+ from torch.utils.data import Dataset, DataLoader
+ from diffusers.image_processor import VaeImageProcessor
+ from tqdm import tqdm
+ from PIL import Image, ImageFilter
+
+ from model.pipeline import CatVTONPipeline
+
+ class InferenceDataset(Dataset):
+     def __init__(self, args):
+         self.args = args
+
+         self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
+         self.mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
+         self.data = self.load_data()
+
+     def load_data(self):
+         return []
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         data = self.data[idx]
+         person, cloth, mask = [Image.open(data[key]) for key in ['person', 'cloth', 'mask']]
+         return {
+             'index': idx,
+             'person_name': data['person_name'],
+             'person': self.vae_processor.preprocess(person, self.args.height, self.args.width)[0],
+             'cloth': self.vae_processor.preprocess(cloth, self.args.height, self.args.width)[0],
+             'mask': self.mask_processor.preprocess(mask, self.args.height, self.args.width)[0]
+         }
+
+ class VITONHDTestDataset(InferenceDataset):
+     def load_data(self):
+         assert os.path.exists(pair_txt := os.path.join(self.args.data_root_path, 'test_pairs_unpaired.txt')), f"File {pair_txt} does not exist."
+         with open(pair_txt, 'r') as f:
+             lines = f.readlines()
+         self.args.data_root_path = os.path.join(self.args.data_root_path, "test")
+         # Use the same directory layout as main() so already-generated results are skipped on resume.
+         output_dir = os.path.join(self.args.output_dir, f"vitonhd-{self.args.height}", 'unpaired' if not self.args.eval_pair else 'paired')
+         data = []
+         for line in lines:
+             person_img, cloth_img = line.strip().split(" ")
+             if os.path.exists(os.path.join(output_dir, person_img)):
+                 continue
+             if self.args.eval_pair:
+                 cloth_img = person_img
+             data.append({
+                 'person_name': person_img,
+                 'person': os.path.join(self.args.data_root_path, 'image', person_img),
+                 'cloth': os.path.join(self.args.data_root_path, 'cloth', cloth_img),
+                 'mask': os.path.join(self.args.data_root_path, 'agnostic-mask', person_img.replace('.jpg', '_mask.png')),
+             })
+         return data
+
+ class DressCodeTestDataset(InferenceDataset):
+     def load_data(self):
+         data = []
+         for sub_folder in ['upper_body', 'lower_body', 'dresses']:
+             assert os.path.exists(os.path.join(self.args.data_root_path, sub_folder)), f"Folder {sub_folder} does not exist."
+             pair_txt = os.path.join(self.args.data_root_path, sub_folder, 'test_pairs_paired.txt' if self.args.eval_pair else 'test_pairs_unpaired.txt')
+             assert os.path.exists(pair_txt), f"File {pair_txt} does not exist."
+             with open(pair_txt, 'r') as f:
+                 lines = f.readlines()
+
+             output_dir = os.path.join(self.args.output_dir, f"dresscode-{self.args.height}",
+                                       'unpaired' if not self.args.eval_pair else 'paired', sub_folder)
+             for line in lines:
+                 person_img, cloth_img = line.strip().split(" ")
+                 if os.path.exists(os.path.join(output_dir, person_img)):
+                     continue
+                 data.append({
+                     'person_name': os.path.join(sub_folder, person_img),
+                     'person': os.path.join(self.args.data_root_path, sub_folder, 'images', person_img),
+                     'cloth': os.path.join(self.args.data_root_path, sub_folder, 'images', cloth_img),
+                     'mask': os.path.join(self.args.data_root_path, sub_folder, 'agnostic_masks', person_img.replace('.jpg', '.png'))
+                 })
+         return data
+
+
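+ # Dataset layout expected by the loaders above (derived from the path joins; the
+ # root directory is user-supplied via --data_root_path):
+ #   vitonhd:   <root>/test_pairs_unpaired.txt and <root>/test/{image, cloth, agnostic-mask}/
+ #   dresscode: <root>/<part>/{test_pairs_*.txt, images/, agnostic_masks/}
+ #              for each part in upper_body, lower_body, dresses
+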
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Inference script for CatVTON virtual try-on evaluation.")
+     parser.add_argument(
+         "--base_model_path",
+         type=str,
+         default="booksforcharlie/stable-diffusion-inpainting",  # Changed to a mirror repo since runwayml deleted the original repo
+         help=(
+             "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
+         ),
+     )
+     parser.add_argument(
+         "--resume_path",
+         type=str,
+         default="zhengchong/CatVTON",
+         help=(
+             "The path to the checkpoint of the trained try-on model."
+         ),
+     )
+     parser.add_argument(
+         "--dataset_name",
+         type=str,
+         required=True,
+         help="The dataset to use for evaluation (`vitonhd` or `dresscode`).",
+     )
+     parser.add_argument(
+         "--data_root_path",
+         type=str,
+         required=True,
+         help="Path to the dataset to evaluate."
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="output",
+         help="The output directory where the model predictions will be written.",
+     )
+
+     parser.add_argument(
+         "--seed", type=int, default=555, help="A seed for reproducible evaluation."
+     )
+     parser.add_argument(
+         "--batch_size", type=int, default=8, help="The batch size for evaluation."
+     )
+
+     parser.add_argument(
+         "--num_inference_steps",
+         type=int,
+         default=50,
+         help="Number of inference steps to perform.",
+     )
+     parser.add_argument(
+         "--guidance_scale",
+         type=float,
+         default=2.5,
+         help="The scale of classifier-free guidance for inference.",
+     )
+
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=384,
+         help=(
+             "The resolution for input images; all images in the train/validation dataset will be resized to this"
+             " resolution."
+         ),
+     )
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images; all images in the train/validation dataset will be resized to this"
+             " resolution."
+         ),
+     )
+     parser.add_argument(
+         "--repaint",
+         action="store_true",
+         help="Whether to repaint the result image with the original background."
+     )
+     parser.add_argument(
+         "--eval_pair",
+         action="store_true",
+         help="Whether or not to evaluate the paired setting.",
+     )
+     parser.add_argument(
+         "--concat_eval_results",
+         action="store_true",
+         help="Whether or not to concatenate all conditions into one image.",
+     )
+     parser.add_argument(
+         "--allow_tf32",
+         action="store_true",
+         default=True,
+         help=(
+             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
+             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+         ),
+     )
+     parser.add_argument(
+         "--dataloader_num_workers",
+         type=int,
+         default=8,
+         help=(
+             "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="bf16",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+
+     parser.add_argument(
+         "--concat_axis",
+         type=str,
+         choices=["x", "y", "random"],
+         default="y",
+         help="The axis to concat the cloth feature, select from ['x', 'y', 'random'].",
+     )
+     parser.add_argument(
+         "--enable_condition_noise",
+         action="store_true",
+         default=True,
+         help="Whether or not to enable condition noise.",
+     )
+     # Define --local_rank so the LOCAL_RANK environment override below has an attribute to update.
+     parser.add_argument(
+         "--local_rank", type=int, default=-1, help="Local rank for distributed launchers."
+     )
+
+     args = parser.parse_args()
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     return args
+
+
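+ # Example invocations (a sketch; the dataset root paths below are placeholders
+ # that depend on where VITON-HD / DressCode are extracted locally):
+ #
+ #   python inference.py --dataset_name vitonhd --data_root_path /path/to/vitonhd --output_dir ./output
+ #   python inference.py --dataset_name dresscode --data_root_path /path/to/dresscode --output_dir ./output --eval_pair
+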
+ def repaint(person, mask, result):
+     # Blend the generated result back onto the original person image,
+     # feathering the mask edge with a Gaussian blur.
+     _, h = result.size  # PIL size is (width, height)
+     kernel_size = h // 50
+     if kernel_size % 2 == 0:
+         kernel_size += 1
+     mask = mask.filter(ImageFilter.GaussianBlur(kernel_size))
+     person_np = np.array(person)
+     result_np = np.array(result)
+     # Convert the mask to single-channel and add a channel axis so it
+     # broadcasts against the HxWx3 person/result arrays.
+     mask_np = np.array(mask.convert("L"))[..., None] / 255
+     repaint_result = person_np * (1 - mask_np) + result_np * mask_np
+     repaint_result = Image.fromarray(repaint_result.astype(np.uint8))
+     return repaint_result
+
+ def to_pil_image(images):
+     images = (images / 2 + 0.5).clamp(0, 1)
+     images = images.cpu().permute(0, 2, 3, 1).float().numpy()
+     if images.ndim == 3:
+         images = images[None, ...]
+     images = (images * 255).round().astype("uint8")
+     if images.shape[-1] == 1:
+         # special case for grayscale (single channel) images
+         pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+     else:
+         pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ @torch.no_grad()
+ def main():
+     args = parse_args()
+     # Apply the TF32 flag parsed above; it can speed up matmuls on Ampere GPUs.
+     if args.allow_tf32:
+         torch.backends.cuda.matmul.allow_tf32 = True
+     # Pipeline
+     pipeline = CatVTONPipeline(
+         attn_ckpt_version=args.dataset_name,
+         attn_ckpt=args.resume_path,
+         base_ckpt=args.base_model_path,
+         weight_dtype={
+             "no": torch.float32,
+             "fp16": torch.float16,
+             "bf16": torch.bfloat16,
+         }[args.mixed_precision],
+         device="cuda",
+         skip_safety_check=True
+     )
+     # Dataset
+     if args.dataset_name == "vitonhd":
+         dataset = VITONHDTestDataset(args)
+     elif args.dataset_name == "dresscode":
+         dataset = DressCodeTestDataset(args)
+     else:
+         raise ValueError(f"Invalid dataset name {args.dataset_name}.")
+     print(f"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.")
+     dataloader = DataLoader(
+         dataset,
+         batch_size=args.batch_size,
+         shuffle=False,
+         num_workers=args.dataloader_num_workers
+     )
+     # Inference
+     generator = torch.Generator(device='cuda').manual_seed(args.seed)
+     args.output_dir = os.path.join(args.output_dir, f"{args.dataset_name}-{args.height}", "paired" if args.eval_pair else "unpaired")
+     if not os.path.exists(args.output_dir):
+         os.makedirs(args.output_dir)
+     for batch in tqdm(dataloader):
+         person_images = batch['person']
+         cloth_images = batch['cloth']
+         masks = batch['mask']
+         results = pipeline(
+             person_images,
+             cloth_images,
+             masks,
+             num_inference_steps=args.num_inference_steps,
+             guidance_scale=args.guidance_scale,
+             height=args.height,
+             width=args.width,
+             generator=generator,
+         )
+
+         if args.concat_eval_results or args.repaint:
+             person_images = to_pil_image(person_images)
+             cloth_images = to_pil_image(cloth_images)
+             masks = to_pil_image(masks)
+         for i, result in enumerate(results):
+             person_name = batch['person_name'][i]
+             output_path = os.path.join(args.output_dir, person_name)
+             if not os.path.exists(os.path.dirname(output_path)):
+                 os.makedirs(os.path.dirname(output_path))
+             if args.repaint:
+                 person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']
+                 person_image = Image.open(person_path).resize(result.size, Image.LANCZOS)
+                 mask = Image.open(mask_path).resize(result.size, Image.NEAREST)
+                 result = repaint(person_image, mask, result)
+             if args.concat_eval_results:
+                 w, h = result.size
+                 concatenated_result = Image.new('RGB', (w*3, h))
+                 concatenated_result.paste(person_images[i], (0, 0))
+                 concatenated_result.paste(cloth_images[i], (w, 0))
+                 concatenated_result.paste(result, (w*2, 0))
+                 result = concatenated_result
+             result.save(output_path)
+
+ if __name__ == "__main__":
+     main()
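+
+ # Programmatic sketch (assumptions: the CatVTONPipeline constructor and call
+ # signature are exactly as used in main() above, the input tensors are
+ # preprocessed as in InferenceDataset, and the checkpoint identifiers are the
+ # defaults from parse_args):
+ #
+ #   pipe = CatVTONPipeline(
+ #       attn_ckpt_version="vitonhd", attn_ckpt="zhengchong/CatVTON",
+ #       base_ckpt="booksforcharlie/stable-diffusion-inpainting",
+ #       weight_dtype=torch.bfloat16, device="cuda", skip_safety_check=True,
+ #   )
+ #   results = pipe(person_batch, cloth_batch, mask_batch,
+ #                  num_inference_steps=50, guidance_scale=2.5,
+ #                  height=512, width=384)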