b3h-young123 committed on
Commit
31c8e71
·
verified ·
1 Parent(s): 9391a86

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CatVTON/.gitattributes +36 -0
  2. CatVTON/.gitignore +2 -0
  3. CatVTON/README.md +13 -0
  4. CatVTON/__pycache__/utils.cpython-39.pyc +0 -0
  5. CatVTON/app.py +778 -0
  6. CatVTON/densepose/__init__.py +22 -0
  7. CatVTON/densepose/__pycache__/__init__.cpython-39.pyc +0 -0
  8. CatVTON/densepose/__pycache__/config.cpython-39.pyc +0 -0
  9. CatVTON/densepose/config.py +277 -0
  10. CatVTON/densepose/converters/__init__.py +17 -0
  11. CatVTON/densepose/converters/__pycache__/__init__.cpython-39.pyc +0 -0
  12. CatVTON/densepose/converters/__pycache__/base.cpython-39.pyc +0 -0
  13. CatVTON/densepose/converters/__pycache__/builtin.cpython-39.pyc +0 -0
  14. CatVTON/densepose/converters/__pycache__/chart_output_hflip.cpython-39.pyc +0 -0
  15. CatVTON/densepose/converters/__pycache__/chart_output_to_chart_result.cpython-39.pyc +0 -0
  16. CatVTON/densepose/converters/__pycache__/hflip.cpython-39.pyc +0 -0
  17. CatVTON/densepose/converters/__pycache__/segm_to_mask.cpython-39.pyc +0 -0
  18. CatVTON/densepose/converters/__pycache__/to_chart_result.cpython-39.pyc +0 -0
  19. CatVTON/densepose/converters/__pycache__/to_mask.cpython-39.pyc +0 -0
  20. CatVTON/densepose/converters/base.py +95 -0
  21. CatVTON/densepose/converters/builtin.py +33 -0
  22. CatVTON/densepose/converters/chart_output_hflip.py +73 -0
  23. CatVTON/densepose/converters/chart_output_to_chart_result.py +190 -0
  24. CatVTON/densepose/converters/hflip.py +36 -0
  25. CatVTON/densepose/converters/segm_to_mask.py +152 -0
  26. CatVTON/densepose/converters/to_chart_result.py +72 -0
  27. CatVTON/densepose/converters/to_mask.py +51 -0
  28. CatVTON/densepose/engine/__init__.py +5 -0
  29. CatVTON/densepose/engine/trainer.py +260 -0
  30. CatVTON/densepose/modeling/__init__.py +15 -0
  31. CatVTON/densepose/modeling/build.py +89 -0
  32. CatVTON/densepose/modeling/confidence.py +75 -0
  33. CatVTON/densepose/modeling/densepose_checkpoint.py +37 -0
  34. CatVTON/densepose/modeling/filter.py +96 -0
  35. CatVTON/densepose/modeling/hrfpn.py +184 -0
  36. CatVTON/densepose/modeling/hrnet.py +476 -0
  37. CatVTON/densepose/modeling/inference.py +46 -0
  38. CatVTON/densepose/modeling/test_time_augmentation.py +209 -0
  39. CatVTON/densepose/modeling/utils.py +13 -0
  40. CatVTON/densepose/utils/__init__.py +0 -0
  41. CatVTON/densepose/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  42. CatVTON/densepose/utils/__pycache__/transform.cpython-39.pyc +0 -0
  43. CatVTON/densepose/utils/dbhelper.py +149 -0
  44. CatVTON/densepose/utils/logger.py +15 -0
  45. CatVTON/densepose/utils/transform.py +17 -0
  46. CatVTON/model/DensePose/__init__.py +158 -0
  47. CatVTON/model/DensePose/__pycache__/__init__.cpython-310.pyc +0 -0
  48. CatVTON/model/DensePose/__pycache__/__init__.cpython-312.pyc +0 -0
  49. CatVTON/model/DensePose/__pycache__/__init__.cpython-39.pyc +0 -0
  50. CatVTON/model/SCHP/__init__.py +179 -0
CatVTON/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
CatVTON/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ playground.py
2
+ __pycache__
CatVTON/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CatVTON
3
+ emoji: 🐈
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.40.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-sa-4.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
CatVTON/__pycache__/utils.cpython-39.pyc ADDED
Binary file (20.3 kB). View file
 
CatVTON/app.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ os.environ['CUDA_HOME'] = '/usr/local/cuda'
4
+ os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
5
+ from datetime import datetime
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from huggingface_hub import snapshot_download
13
+ from PIL import Image
14
+ torch.jit.script = lambda f: f
15
+ from model.cloth_masker import AutoMasker, vis_mask
16
+ from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
17
+ from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
18
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
+
20
+
21
def parse_args():
    """Parse command-line options for the CatVTON Gradio demo.

    Returns:
        argparse.Namespace with model checkpoint paths, output directory,
        target image size, precision/TF32 flags and the distributed
        local rank (overridden by the ``LOCAL_RANK`` environment variable).
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--p2p_base_model_path",
        type=str,
        default="timbrooks/instruct-pix2pix",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
    ),
    )
    # BUG FIX: the original code read args.local_rank below without ever
    # declaring the --local_rank argument, so merely setting the LOCAL_RANK
    # environment variable crashed with AttributeError. Declare it with the
    # conventional torch.distributed default of -1 (not distributed).
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed runs; overridden by the LOCAL_RANK env var.",
    )

    args = parser.parse_args()
    # The LOCAL_RANK environment variable (set by distributed launchers)
    # takes precedence over the command-line flag.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
+
105
def image_grid(imgs, rows, cols):
    """Tile *imgs* row-major into a single rows x cols composite image.

    All images must share the same size, and exactly rows*cols of them
    must be supplied.
    """
    assert len(imgs) == rows * cols

    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
114
+
115
+
116
args = parse_args()

# ---------------------------------------------------------------------------
# Global model setup — executed once at import time. Every pipeline below is
# placed on CUDA; this module assumes a GPU is available.
# ---------------------------------------------------------------------------

# Mask-based CatVTON
catvton_repo = "zhengchong/CatVTON"
repo_path = snapshot_download(repo_id=catvton_repo)
# Pipeline
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)
# AutoMasker: generates try-on masks from DensePose + SCHP checkpoints
# bundled in the downloaded CatVTON repo snapshot.
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)


# Flux-based CatVTON
# NOTE(review): FLUX.1-Fill-dev is a gated model — HUGGING_FACE_HUB_TOKEN
# must be set in the environment for the download to succeed.
access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
pipeline_flux.load_lora_weights(
    os.path.join(repo_path, "flux-lora"),
    weight_name='pytorch_lora_weights.safetensors'
)
pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))


# Mask-free CatVTON (InstructPix2Pix-style pipeline, no mask input required)
catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
repo_path_mf = snapshot_download(repo_id=catvton_mf_repo, use_auth_token=access_token)
pipeline_p2p = CatVTONPix2PixPipeline(
    base_ckpt=args.p2p_base_model_path,
    attn_ckpt=repo_path_mf,
    attn_ckpt_version="mix-48k-1024",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)
161
+
162
+
163
@spaces.GPU(duration=120)
def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Run one mask-based SD1.5 CatVTON try-on request from the Gradio UI.

    Args:
        person_image: gr.ImageEditor payload — "background" is the person
            image filepath, "layers"[0] the filepath of a user-drawn mask.
        cloth_image: filepath of the garment / condition image.
        cloth_type: "upper" / "lower" / "overall"; used by the automasker
            only when no mask was drawn.
        num_inference_steps: number of diffusion denoising steps.
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed; -1 disables the fixed-seed generator.
        show_type: "result only", "input & result", or anything else for
            input & mask & result.

    Returns:
        PIL.Image: the try-on result, optionally composited with inputs.
    """
    # Split the editor payload into the person photo and the drawn mask layer.
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    # A single-valued layer means the user drew nothing -> use the automasker.
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255  # binarize any drawn strokes
        mask = Image.fromarray(mask)

    # Results are saved under <output_dir>/<YYYYMMDD>/<HHMMSS>.png.
    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask: resize the drawn mask, or synthesize one automatically.
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    # try:
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )[0]
    # except Exception as e:
    #     raise gr.Error(
    #         "An error occurred. Please try again later: {}".format(e)
    #     )

    # Post-process: persist a 1x4 contact sheet, then build the display image.
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        # 5 px gutter between the condition strip and the result image.
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image
241
+
242
@spaces.GPU(duration=120)
def submit_function_p2p(
    person_image,
    cloth_image,
    num_inference_steps,
    guidance_scale,
    seed):
    """Run one mask-free (pix2pix-style) CatVTON try-on request.

    Args mirror submit_function minus the mask-related inputs:
    person_image is a gr.ImageEditor payload (only "background" is used),
    cloth_image is a filepath, seed == -1 disables the fixed-seed generator.

    Returns:
        PIL.Image: the raw try-on result.
    """
    person_image= person_image["background"]

    # Results are saved under <output_dir>/<YYYYMMDD>/<HHMMSS>.png.
    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Inference — surface any failure to the UI as a gr.Error.
    try:
        result_image = pipeline_p2p(
            image=person_image,
            condition_image=cloth_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )[0]
    except Exception as e:
        raise gr.Error(
            "An error occurred. Please try again later: {}".format(e)
        )

    # Post-process: save a 1x3 contact sheet of the inputs plus result.
    save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
    save_result_image.save(result_save_path)
    return result_image
284
+
285
@spaces.GPU(duration=120)
def submit_function_flux(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Run one FLUX.1-Fill based CatVTON try-on request.

    Same contract as submit_function, but drives pipeline_flux and does
    NOT save a contact sheet to disk (unlike the SD1.5 paths).

    Returns:
        PIL.Image: the result, optionally composited with the inputs.
    """

    # Process image editor input: background photo + user-drawn mask layer.
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    # A single-valued layer means nothing was drawn -> use the automasker.
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255  # binarize any drawn strokes
        mask = Image.fromarray(mask)

    # Set random seed (-1 means non-deterministic: no generator)
    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    # Process input images
    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")

    # Adjust image sizes
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask: resize the drawn mask, or synthesize one automatically.
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline_flux(
        image=person_image,
        condition_image=cloth_image,
        mask_image=mask,
        width=args.width,
        height=args.height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]

    # Post-processing
    masked_person = vis_mask(person_image, mask)

    # Return result based on show type
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)

        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        # 5 px gutter between the condition strip and the result image.
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image
361
+
362
+
363
def person_example_fn(image_path):
    """Forward a clicked example's filepath into the person image editor."""
    return image_path
365
+
366
+
367
# HTML banner rendered at the top of the Gradio app: title, badge links
# (paper / checkpoints / code / demos) and usage notes.
HEADER = """
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
<div style="display: flex; justify-content: center; align-items: center;">
  <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
  </a>
  <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
  </a>
  <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
  </a>
  <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
  </a>
  <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
  </a>
  <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
  </a>
  <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
    <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
  </a>
</div>
<br>
· This demo and our weights are only for Non-commercial Use. <br>
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
"""
397
+
398
def app_gradio():
    """Build and launch the three-tab CatVTON Gradio demo.

    Tabs: mask-based SD1.5, mask-based FLUX.1-Fill, and mask-free SD1.5.
    Each tab wires its own inputs to the corresponding submit_function*.
    Blocks until the (publicly shared) server is shut down.
    """
    with gr.Blocks(title="CatVTON") as demo:
        gr.Markdown(HEADER)
        # ------------------------- Tab 1: mask-based SD1.5 ------------------
        with gr.Tab("Mask-based & SD1.5"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        # Hidden filepath component: example galleries write
                        # here, and its .change() forwards into the editor.
                        image_path = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )
                        with gr.Column(scale=1, min_width=120):
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                            )
                            cloth_type = gr.Radio(
                                label="Try-On Cloth Type",
                                choices=["upper", "lower", "overall"],
                                value="upper",
                            )


                    submit = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    gr.Markdown(
                        '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                    )
                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidence Scale
                        guidance_scale = gr.Slider(
                            label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                        )
                        # Random Seed
                        seed = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        show_type = gr.Radio(
                            label="Show Type",
                            choices=["result only", "input & result", "input & mask & result"],
                            value="input & mask & result",
                        )

                with gr.Column(scale=2, min_width=500):
                    result_image = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Photo Examples
                        root_path = "resource/demo/example"
                        with gr.Column():
                            men_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path,
                                label="Person Examples ①",
                            )
                            women_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            condition_upper_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Upper Examples",
                            )
                            condition_overall_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Overall Examples",
                            )
                            condition_person_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            # Clicking an example gallery entry loads it into the editor.
            image_path.change(
                person_example_fn, inputs=image_path, outputs=person_image
            )

            submit.click(
                submit_function,
                [
                    person_image,
                    cloth_image,
                    cloth_type,
                    num_inference_steps,
                    guidance_scale,
                    seed,
                    show_type,
                ],
                result_image,
            )

        # --------------------- Tab 2: mask-based FLUX.1-Fill ----------------
        with gr.Tab("Mask-based & Flux.1 Fill Dev"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        image_path_flux = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image_flux = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image_flux = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )
                        with gr.Column(scale=1, min_width=120):
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                            )
                            # NOTE(review): this rebinds the `cloth_type` name
                            # from Tab 1 — both tabs share the last-created
                            # Radio in their .click() wiring; confirm intended.
                            cloth_type = gr.Radio(
                                label="Try-On Cloth Type",
                                choices=["upper", "lower", "overall"],
                                value="upper",
                            )

                    submit_flux = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps_flux = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidence Scale
                        guidance_scale_flux = gr.Slider(
                            label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
                        )
                        # Random Seed
                        seed_flux = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        show_type = gr.Radio(
                            label="Show Type",
                            choices=["result only", "input & result", "input & mask & result"],
                            value="input & mask & result",
                        )

                with gr.Column(scale=2, min_width=500):
                    result_image_flux = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Photo Examples
                        root_path = "resource/demo/example"
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_flux,
                                label="Person Examples ①",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_flux,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_flux,
                                label="Condition Upper Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_flux,
                                label="Condition Overall Examples",
                            )
                            condition_person_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_flux,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )


            image_path_flux.change(
                person_example_fn, inputs=image_path_flux, outputs=person_image_flux
            )

            submit_flux.click(
                submit_function_flux,
                [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
                result_image_flux,
            )


        # ------------------------ Tab 3: mask-free SD1.5 --------------------
        with gr.Tab("Mask-free & SD1.5"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        image_path_p2p = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image_p2p = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image_p2p = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )

                    submit_p2p = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    gr.Markdown(
                        '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                    )
                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps_p2p = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidence Scale
                        guidance_scale_p2p = gr.Slider(
                            label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                        )
                        # Random Seed
                        seed_p2p = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        # show_type = gr.Radio(
                        #     label="Show Type",
                        #     choices=["result only", "input & result", "input & mask & result"],
                        #     value="input & mask & result",
                        # )

                with gr.Column(scale=2, min_width=500):
                    result_image_p2p = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Photo Examples
                        root_path = "resource/demo/example"
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_p2p,
                                label="Person Examples ①",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_p2p,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Upper Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Overall Examples",
                            )
                            condition_person_exm = gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            image_path_p2p.change(
                person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
            )

            submit_p2p.click(
                submit_function_p2p,
                [
                    person_image_p2p,
                    cloth_image_p2p,
                    num_inference_steps_p2p,
                    guidance_scale_p2p,
                    seed_p2p],
                result_image_p2p,
            )

        # Queue requests (required for ZeroGPU) and expose a public share link.
        demo.queue().launch(share=True, show_error=True)
775
+
776
+
777
if __name__ == "__main__":
    # Build and launch the Gradio demo when run as a script.
    app_gradio()
CatVTON/densepose/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from .data.datasets import builtin # just to register data
5
+ from .converters import builtin as builtin_converters # register converters
6
+ from .config import (
7
+ add_densepose_config,
8
+ add_densepose_head_config,
9
+ add_hrnet_config,
10
+ add_dataset_category_config,
11
+ add_bootstrap_config,
12
+ load_bootstrap_config,
13
+ )
14
+ from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
15
+ from .evaluation import DensePoseCOCOEvaluator
16
+ from .modeling.roi_heads import DensePoseROIHeads
17
+ from .modeling.test_time_augmentation import (
18
+ DensePoseGeneralizedRCNNWithTTA,
19
+ DensePoseDatasetMapperTTA,
20
+ )
21
+ from .utils.transform import load_from_cfg
22
+ from .modeling.hrfpn import build_hrfpn_backbone
CatVTON/densepose/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (925 Bytes). View file
 
CatVTON/densepose/__pycache__/config.cpython-39.pyc ADDED
Binary file (5.82 kB). View file
 
CatVTON/densepose/config.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding = utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # pyre-ignore-all-errors
4
+
5
+ from detectron2.config import CfgNode as CN
6
+
7
+
8
+ def add_dataset_category_config(cfg: CN) -> None:
9
+ """
10
+ Add config for additional category-related dataset options
11
+ - category whitelisting
12
+ - category mapping
13
+ """
14
+ _C = cfg
15
+ _C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
16
+ _C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
17
+ # class to mesh mapping
18
+ _C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
19
+
20
+
21
+ def add_evaluation_config(cfg: CN) -> None:
22
+ _C = cfg
23
+ _C.DENSEPOSE_EVALUATION = CN()
24
+ # evaluator type, possible values:
25
+ # - "iou": evaluator for models that produce iou data
26
+ # - "cse": evaluator for models that produce cse data
27
+ _C.DENSEPOSE_EVALUATION.TYPE = "iou"
28
+ # storage for DensePose results, possible values:
29
+ # - "none": no explicit storage, all the results are stored in the
30
+ # dictionary with predictions, memory intensive;
31
+ # historically the default storage type
32
+ # - "ram": RAM storage, uses per-process RAM storage, which is
33
+ # reduced to a single process storage on later stages,
34
+ # less memory intensive
35
+ # - "file": file storage, uses per-process file-based storage,
36
+ # the least memory intensive, but may create bottlenecks
37
+ # on file system accesses
38
+ _C.DENSEPOSE_EVALUATION.STORAGE = "none"
39
+ # minimum threshold for IOU values: the lower its values is,
40
+ # the more matches are produced (and the higher the AP score)
41
+ _C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
42
+ # Non-distributed inference is slower (at inference time) but can avoid RAM OOM
43
+ _C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
44
+ # evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
45
+ _C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
46
+ # meshes to compute mesh alignment for
47
+ _C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
48
+
49
+
50
+ def add_bootstrap_config(cfg: CN) -> None:
51
+ """ """
52
+ _C = cfg
53
+ _C.BOOTSTRAP_DATASETS = []
54
+ _C.BOOTSTRAP_MODEL = CN()
55
+ _C.BOOTSTRAP_MODEL.WEIGHTS = ""
56
+ _C.BOOTSTRAP_MODEL.DEVICE = "cuda"
57
+
58
+
59
+ def get_bootstrap_dataset_config() -> CN:
60
+ _C = CN()
61
+ _C.DATASET = ""
62
+ # ratio used to mix data loaders
63
+ _C.RATIO = 0.1
64
+ # image loader
65
+ _C.IMAGE_LOADER = CN(new_allowed=True)
66
+ _C.IMAGE_LOADER.TYPE = ""
67
+ _C.IMAGE_LOADER.BATCH_SIZE = 4
68
+ _C.IMAGE_LOADER.NUM_WORKERS = 4
69
+ _C.IMAGE_LOADER.CATEGORIES = []
70
+ _C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
71
+ _C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
72
+ # inference
73
+ _C.INFERENCE = CN()
74
+ # batch size for model inputs
75
+ _C.INFERENCE.INPUT_BATCH_SIZE = 4
76
+ # batch size to group model outputs
77
+ _C.INFERENCE.OUTPUT_BATCH_SIZE = 2
78
+ # sampled data
79
+ _C.DATA_SAMPLER = CN(new_allowed=True)
80
+ _C.DATA_SAMPLER.TYPE = ""
81
+ _C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
82
+ # filter
83
+ _C.FILTER = CN(new_allowed=True)
84
+ _C.FILTER.TYPE = ""
85
+ return _C
86
+
87
+
88
+ def load_bootstrap_config(cfg: CN) -> None:
89
+ """
90
+ Bootstrap datasets are given as a list of `dict` that are not automatically
91
+ converted into CfgNode. This method processes all bootstrap dataset entries
92
+ and ensures that they are in CfgNode format and comply with the specification
93
+ """
94
+ if not cfg.BOOTSTRAP_DATASETS:
95
+ return
96
+
97
+ bootstrap_datasets_cfgnodes = []
98
+ for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
99
+ _C = get_bootstrap_dataset_config().clone()
100
+ _C.merge_from_other_cfg(CN(dataset_cfg))
101
+ bootstrap_datasets_cfgnodes.append(_C)
102
+ cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
103
+
104
+
105
+ def add_densepose_head_cse_config(cfg: CN) -> None:
106
+ """
107
+ Add configuration options for Continuous Surface Embeddings (CSE)
108
+ """
109
+ _C = cfg
110
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
111
+ # Dimensionality D of the embedding space
112
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
113
+ # Embedder specifications for various mesh IDs
114
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
115
+ # normalization coefficient for embedding distances
116
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
117
+ # normalization coefficient for geodesic distances
118
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
119
+ # embedding loss weight
120
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
121
+ # embedding loss name, currently the following options are supported:
122
+ # - EmbeddingLoss: cross-entropy on vertex labels
123
+ # - SoftEmbeddingLoss: cross-entropy on vertex label combined with
124
+ # Gaussian penalty on distance between vertices
125
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
126
+ # optimizer hyperparameters
127
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
128
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
129
+ # Shape to shape cycle consistency loss parameters:
130
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
131
+ # shape to shape cycle consistency loss weight
132
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
133
+ # norm type used for loss computation
134
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
135
+ # normalization term for embedding similarity matrices
136
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
137
+ # maximum number of vertices to include into shape to shape cycle loss
138
+ # if negative or zero, all vertices are considered
139
+ # if positive, random subset of vertices of given size is considered
140
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
141
+ # Pixel to shape cycle consistency loss parameters:
142
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
143
+ # pixel to shape cycle consistency loss weight
144
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
145
+ # norm type used for loss computation
146
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
147
+ # map images to all meshes and back (if false, use only gt meshes from the batch)
148
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
149
+ # Randomly select at most this number of pixels from every instance
150
+ # if negative or zero, all vertices are considered
151
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
152
+ # normalization factor for pixel to pixel distances (higher value = smoother distribution)
153
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
154
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
155
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
156
+
157
+
158
+ def add_densepose_head_config(cfg: CN) -> None:
159
+ """
160
+ Add config for densepose head.
161
+ """
162
+ _C = cfg
163
+
164
+ _C.MODEL.DENSEPOSE_ON = True
165
+
166
+ _C.MODEL.ROI_DENSEPOSE_HEAD = CN()
167
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
168
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
169
+ # Number of parts used for point labels
170
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
171
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
172
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
173
+ _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
174
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
175
+ _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
176
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
177
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
178
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
179
+ _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2 # 15 or 2
180
+ # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
181
+ _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
182
+ # Loss weights for annotation masks.(14 Parts)
183
+ _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
184
+ # Loss weights for surface parts. (24 Parts)
185
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
186
+ # Loss weights for UV regression.
187
+ _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
188
+ # Coarse segmentation is trained using instance segmentation task data
189
+ _C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
190
+ # For Decoder
191
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
192
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
193
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
194
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
195
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
196
+ # For DeepLab head
197
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
198
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
199
+ _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
200
+ # Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
201
+ # Some registered predictors:
202
+ # "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
203
+ # "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
204
+ # and associated confidences for predefined charts (default)
205
+ # "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
206
+ # and associated confidences for CSE
207
+ _C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
208
+ # Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
209
+ # Some registered losses:
210
+ # "DensePoseChartLoss": loss for chart-based models that estimate
211
+ # segmentation and UV coordinates
212
+ # "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
213
+ # segmentation, UV coordinates and the corresponding confidences (default)
214
+ _C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
215
+ # Confidences
216
+ # Enable learning UV confidences (variances) along with the actual values
217
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
218
+ # UV confidence lower bound
219
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
220
+ # Enable learning segmentation confidences (variances) along with the actual values
221
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
222
+ # Segmentation confidence lower bound
223
+ _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
224
+ # Statistical model type for confidence learning, possible values:
225
+ # - "iid_iso": statistically independent identically distributed residuals
226
+ # with isotropic covariance
227
+ # - "indep_aniso": statistically independent residuals with anisotropic
228
+ # covariances
229
+ _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
230
+ # List of angles for rotation in data augmentation during training
231
+ _C.INPUT.ROTATION_ANGLES = [0]
232
+ _C.TEST.AUG.ROTATION_ANGLES = () # Rotation TTA
233
+
234
+ add_densepose_head_cse_config(cfg)
235
+
236
+
237
+ def add_hrnet_config(cfg: CN) -> None:
238
+ """
239
+ Add config for HRNet backbone.
240
+ """
241
+ _C = cfg
242
+
243
+ # For HigherHRNet w32
244
+ _C.MODEL.HRNET = CN()
245
+ _C.MODEL.HRNET.STEM_INPLANES = 64
246
+ _C.MODEL.HRNET.STAGE2 = CN()
247
+ _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
248
+ _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
249
+ _C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
250
+ _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
251
+ _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
252
+ _C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
253
+ _C.MODEL.HRNET.STAGE3 = CN()
254
+ _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
255
+ _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
256
+ _C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
257
+ _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
258
+ _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
259
+ _C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
260
+ _C.MODEL.HRNET.STAGE4 = CN()
261
+ _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
262
+ _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
263
+ _C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
264
+ _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
265
+ _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
266
+ _C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
267
+
268
+ _C.MODEL.HRNET.HRFPN = CN()
269
+ _C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
270
+
271
+
272
+ def add_densepose_config(cfg: CN) -> None:
273
+ add_densepose_head_config(cfg)
274
+ add_hrnet_config(cfg)
275
+ add_bootstrap_config(cfg)
276
+ add_dataset_category_config(cfg)
277
+ add_evaluation_config(cfg)
CatVTON/densepose/converters/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .hflip import HFlipConverter
6
+ from .to_mask import ToMaskConverter
7
+ from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
8
+ from .segm_to_mask import (
9
+ predictor_output_with_fine_and_coarse_segm_to_mask,
10
+ predictor_output_with_coarse_segm_to_mask,
11
+ resample_fine_and_coarse_segm_to_bbox,
12
+ )
13
+ from .chart_output_to_chart_result import (
14
+ densepose_chart_predictor_output_to_result,
15
+ densepose_chart_predictor_output_to_result_with_confidences,
16
+ )
17
+ from .chart_output_hflip import densepose_chart_predictor_output_hflip
CatVTON/densepose/converters/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (799 Bytes). View file
 
CatVTON/densepose/converters/__pycache__/base.cpython-39.pyc ADDED
Binary file (3.68 kB). View file
 
CatVTON/densepose/converters/__pycache__/builtin.cpython-39.pyc ADDED
Binary file (804 Bytes). View file
 
CatVTON/densepose/converters/__pycache__/chart_output_hflip.cpython-39.pyc ADDED
Binary file (1.95 kB). View file
 
CatVTON/densepose/converters/__pycache__/chart_output_to_chart_result.cpython-39.pyc ADDED
Binary file (6.03 kB). View file
 
CatVTON/densepose/converters/__pycache__/hflip.cpython-39.pyc ADDED
Binary file (1.35 kB). View file
 
CatVTON/densepose/converters/__pycache__/segm_to_mask.cpython-39.pyc ADDED
Binary file (5.75 kB). View file
 
CatVTON/densepose/converters/__pycache__/to_chart_result.cpython-39.pyc ADDED
Binary file (2.74 kB). View file
 
CatVTON/densepose/converters/__pycache__/to_mask.cpython-39.pyc ADDED
Binary file (1.76 kB). View file
 
CatVTON/densepose/converters/base.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple, Type
6
+ import torch
7
+
8
+
9
+ class BaseConverter:
10
+ """
11
+ Converter base class to be reused by various converters.
12
+ Converter allows one to convert data from various source types to a particular
13
+ destination type. Each source type needs to register its converter. The
14
+ registration for each source type is valid for all descendants of that type.
15
+ """
16
+
17
+ @classmethod
18
+ def register(cls, from_type: Type, converter: Any = None):
19
+ """
20
+ Registers a converter for the specified type.
21
+ Can be used as a decorator (if converter is None), or called as a method.
22
+
23
+ Args:
24
+ from_type (type): type to register the converter for;
25
+ all instances of this type will use the same converter
26
+ converter (callable): converter to be registered for the given
27
+ type; if None, this method is assumed to be a decorator for the converter
28
+ """
29
+
30
+ if converter is not None:
31
+ cls._do_register(from_type, converter)
32
+
33
+ def wrapper(converter: Any) -> Any:
34
+ cls._do_register(from_type, converter)
35
+ return converter
36
+
37
+ return wrapper
38
+
39
+ @classmethod
40
+ def _do_register(cls, from_type: Type, converter: Any):
41
+ cls.registry[from_type] = converter # pyre-ignore[16]
42
+
43
+ @classmethod
44
+ def _lookup_converter(cls, from_type: Type) -> Any:
45
+ """
46
+ Perform recursive lookup for the given type
47
+ to find registered converter. If a converter was found for some base
48
+ class, it gets registered for this class to save on further lookups.
49
+
50
+ Args:
51
+ from_type: type for which to find a converter
52
+ Return:
53
+ callable or None - registered converter or None
54
+ if no suitable entry was found in the registry
55
+ """
56
+ if from_type in cls.registry: # pyre-ignore[16]
57
+ return cls.registry[from_type]
58
+ for base in from_type.__bases__:
59
+ converter = cls._lookup_converter(base)
60
+ if converter is not None:
61
+ cls._do_register(from_type, converter)
62
+ return converter
63
+ return None
64
+
65
+ @classmethod
66
+ def convert(cls, instance: Any, *args, **kwargs):
67
+ """
68
+ Convert an instance to the destination type using some registered
69
+ converter. Does recursive lookup for base classes, so there's no need
70
+ for explicit registration for derived classes.
71
+
72
+ Args:
73
+ instance: source instance to convert to the destination type
74
+ Return:
75
+ An instance of the destination type obtained from the source instance
76
+ Raises KeyError, if no suitable converter found
77
+ """
78
+ instance_type = type(instance)
79
+ converter = cls._lookup_converter(instance_type)
80
+ if converter is None:
81
+ if cls.dst_type is None: # pyre-ignore[16]
82
+ output_type_str = "itself"
83
+ else:
84
+ output_type_str = cls.dst_type
85
+ raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
86
+ return converter(instance, *args, **kwargs)
87
+
88
+
89
+ IntTupleBox = Tuple[int, int, int, int]
90
+
91
+
92
+ def make_int_box(box: torch.Tensor) -> IntTupleBox:
93
+ int_box = [0, 0, 0, 0]
94
+ int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
95
+ return int_box[0], int_box[1], int_box[2], int_box[3]
CatVTON/densepose/converters/builtin.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
6
+ from . import (
7
+ HFlipConverter,
8
+ ToChartResultConverter,
9
+ ToChartResultConverterWithConfidences,
10
+ ToMaskConverter,
11
+ densepose_chart_predictor_output_hflip,
12
+ densepose_chart_predictor_output_to_result,
13
+ densepose_chart_predictor_output_to_result_with_confidences,
14
+ predictor_output_with_coarse_segm_to_mask,
15
+ predictor_output_with_fine_and_coarse_segm_to_mask,
16
+ )
17
+
18
+ ToMaskConverter.register(
19
+ DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
20
+ )
21
+ ToMaskConverter.register(
22
+ DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
23
+ )
24
+
25
+ ToChartResultConverter.register(
26
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
27
+ )
28
+
29
+ ToChartResultConverterWithConfidences.register(
30
+ DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
31
+ )
32
+
33
+ HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
CatVTON/densepose/converters/chart_output_hflip.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from dataclasses import fields
5
+ import torch
6
+
7
+ from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
8
+
9
+
10
+ def densepose_chart_predictor_output_hflip(
11
+ densepose_predictor_output: DensePoseChartPredictorOutput,
12
+ transform_data: DensePoseTransformData,
13
+ ) -> DensePoseChartPredictorOutput:
14
+ """
15
+ Change to take into account a Horizontal flip.
16
+ """
17
+ if len(densepose_predictor_output) > 0:
18
+
19
+ PredictorOutput = type(densepose_predictor_output)
20
+ output_dict = {}
21
+
22
+ for field in fields(densepose_predictor_output):
23
+ field_value = getattr(densepose_predictor_output, field.name)
24
+ # flip tensors
25
+ if isinstance(field_value, torch.Tensor):
26
+ setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
27
+
28
+ densepose_predictor_output = _flip_iuv_semantics_tensor(
29
+ densepose_predictor_output, transform_data
30
+ )
31
+ densepose_predictor_output = _flip_segm_semantics_tensor(
32
+ densepose_predictor_output, transform_data
33
+ )
34
+
35
+ for field in fields(densepose_predictor_output):
36
+ output_dict[field.name] = getattr(densepose_predictor_output, field.name)
37
+
38
+ return PredictorOutput(**output_dict)
39
+ else:
40
+ return densepose_predictor_output
41
+
42
+
43
+ def _flip_iuv_semantics_tensor(
44
+ densepose_predictor_output: DensePoseChartPredictorOutput,
45
+ dp_transform_data: DensePoseTransformData,
46
+ ) -> DensePoseChartPredictorOutput:
47
+ point_label_symmetries = dp_transform_data.point_label_symmetries
48
+ uv_symmetries = dp_transform_data.uv_symmetries
49
+
50
+ N, C, H, W = densepose_predictor_output.u.shape
51
+ u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
52
+ v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
53
+ Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
54
+ None, :, None, None
55
+ ].expand(N, C - 1, H, W)
56
+ densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
57
+ densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
58
+
59
+ for el in ["fine_segm", "u", "v"]:
60
+ densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
61
+ :, point_label_symmetries, :, :
62
+ ]
63
+ return densepose_predictor_output
64
+
65
+
66
+ def _flip_segm_semantics_tensor(
67
+ densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
68
+ ):
69
+ if densepose_predictor_output.coarse_segm.shape[1] > 2:
70
+ densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
71
+ :, dp_transform_data.mask_label_symmetries, :, :
72
+ ]
73
+ return densepose_predictor_output
CatVTON/densepose/converters/chart_output_to_chart_result.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Dict
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures.boxes import Boxes, BoxMode
10
+
11
+ from ..structures import (
12
+ DensePoseChartPredictorOutput,
13
+ DensePoseChartResult,
14
+ DensePoseChartResultWithConfidences,
15
+ )
16
+ from . import resample_fine_and_coarse_segm_to_bbox
17
+ from .base import IntTupleBox, make_int_box
18
+
19
+
20
+ def resample_uv_tensors_to_bbox(
21
+ u: torch.Tensor,
22
+ v: torch.Tensor,
23
+ labels: torch.Tensor,
24
+ box_xywh_abs: IntTupleBox,
25
+ ) -> torch.Tensor:
26
+ """
27
+ Resamples U and V coordinate estimates for the given bounding box
28
+
29
+ Args:
30
+ u (tensor [1, C, H, W] of float): U coordinates
31
+ v (tensor [1, C, H, W] of float): V coordinates
32
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
33
+ outputs for the given bounding box
34
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
35
+ Return:
36
+ Resampled U and V coordinates - a tensor [2, H, W] of float
37
+ """
38
+ x, y, w, h = box_xywh_abs
39
+ w = max(int(w), 1)
40
+ h = max(int(h), 1)
41
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
42
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
43
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
44
+ for part_id in range(1, u_bbox.size(1)):
45
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
46
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
47
+ return uv
48
+
49
+
50
+ def resample_uv_to_bbox(
51
+ predictor_output: DensePoseChartPredictorOutput,
52
+ labels: torch.Tensor,
53
+ box_xywh_abs: IntTupleBox,
54
+ ) -> torch.Tensor:
55
+ """
56
+ Resamples U and V coordinate estimates for the given bounding box
57
+
58
+ Args:
59
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
60
+ output to be resampled
61
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
62
+ outputs for the given bounding box
63
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
64
+ Return:
65
+ Resampled U and V coordinates - a tensor [2, H, W] of float
66
+ """
67
+ return resample_uv_tensors_to_bbox(
68
+ predictor_output.u,
69
+ predictor_output.v,
70
+ labels,
71
+ box_xywh_abs,
72
+ )
73
+
74
+
75
+ def densepose_chart_predictor_output_to_result(
76
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
77
+ ) -> DensePoseChartResult:
78
+ """
79
+ Convert densepose chart predictor outputs to results
80
+
81
+ Args:
82
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
83
+ output to be converted to results, must contain only 1 output
84
+ boxes (Boxes): bounding box that corresponds to the predictor output,
85
+ must contain only 1 bounding box
86
+ Return:
87
+ DensePose chart-based result (DensePoseChartResult)
88
+ """
89
+ assert len(predictor_output) == 1 and len(boxes) == 1, (
90
+ f"Predictor output to result conversion can operate only single outputs"
91
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
92
+ )
93
+
94
+ boxes_xyxy_abs = boxes.tensor.clone()
95
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
96
+ box_xywh = make_int_box(boxes_xywh_abs[0])
97
+
98
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
99
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
100
+ return DensePoseChartResult(labels=labels, uv=uv)
101
+
102
+
103
+ def resample_confidences_to_bbox(
104
+ predictor_output: DensePoseChartPredictorOutput,
105
+ labels: torch.Tensor,
106
+ box_xywh_abs: IntTupleBox,
107
+ ) -> Dict[str, torch.Tensor]:
108
+ """
109
+ Resamples confidences for the given bounding box
110
+
111
+ Args:
112
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
113
+ output to be resampled
114
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
115
+ outputs for the given bounding box
116
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
117
+ Return:
118
+ Resampled confidences - a dict of [H, W] tensors of float
119
+ """
120
+
121
+ x, y, w, h = box_xywh_abs
122
+ w = max(int(w), 1)
123
+ h = max(int(h), 1)
124
+
125
+ confidence_names = [
126
+ "sigma_1",
127
+ "sigma_2",
128
+ "kappa_u",
129
+ "kappa_v",
130
+ "fine_segm_confidence",
131
+ "coarse_segm_confidence",
132
+ ]
133
+ confidence_results = {key: None for key in confidence_names}
134
+ confidence_names = [
135
+ key for key in confidence_names if getattr(predictor_output, key) is not None
136
+ ]
137
+ confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
138
+
139
+ # assign data from channels that correspond to the labels
140
+ for key in confidence_names:
141
+ resampled_confidence = F.interpolate(
142
+ getattr(predictor_output, key),
143
+ (h, w),
144
+ mode="bilinear",
145
+ align_corners=False,
146
+ )
147
+ result = confidence_base.clone()
148
+ for part_id in range(1, predictor_output.u.size(1)):
149
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
150
+ # confidence is not part-based, don't try to fill it part by part
151
+ continue
152
+ result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
153
+
154
+ if resampled_confidence.size(1) != predictor_output.u.size(1):
155
+ # confidence is not part-based, fill the data with the first channel
156
+ # (targeted for segmentation confidences that have only 1 channel)
157
+ result = resampled_confidence[0, 0]
158
+
159
+ confidence_results[key] = result
160
+
161
+ return confidence_results # pyre-ignore[7]
162
+
163
+
164
+ def densepose_chart_predictor_output_to_result_with_confidences(
165
+ predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
166
+ ) -> DensePoseChartResultWithConfidences:
167
+ """
168
+ Convert densepose chart predictor outputs to results
169
+
170
+ Args:
171
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
172
+ output with confidences to be converted to results, must contain only 1 output
173
+ boxes (Boxes): bounding box that corresponds to the predictor output,
174
+ must contain only 1 bounding box
175
+ Return:
176
+ DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
177
+ """
178
+ assert len(predictor_output) == 1 and len(boxes) == 1, (
179
+ f"Predictor output to result conversion can operate only single outputs"
180
+ f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
181
+ )
182
+
183
+ boxes_xyxy_abs = boxes.tensor.clone()
184
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
185
+ box_xywh = make_int_box(boxes_xywh_abs[0])
186
+
187
+ labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
188
+ uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
189
+ confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
190
+ return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
CatVTON/densepose/converters/hflip.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from .base import BaseConverter
8
+
9
+
10
+ class HFlipConverter(BaseConverter):
11
+ """
12
+ Converts various DensePose predictor outputs to DensePose results.
13
+ Each DensePose predictor output type has to register its convertion strategy.
14
+ """
15
+
16
+ registry = {}
17
+ dst_type = None
18
+
19
+ @classmethod
20
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
21
+ # inconsistently.
22
+ def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
23
+ """
24
+ Performs an horizontal flip on DensePose predictor outputs.
25
+ Does recursive lookup for base classes, so there's no need
26
+ for explicit registration for derived classes.
27
+
28
+ Args:
29
+ predictor_outputs: DensePose predictor output to be converted to BitMasks
30
+ transform_data: Anything useful for the flip
31
+ Return:
32
+ An instance of the same type as predictor_outputs
33
+ """
34
+ return super(HFlipConverter, cls).convert(
35
+ predictor_outputs, transform_data, *args, **kwargs
36
+ )
CatVTON/densepose/converters/segm_to_mask.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures import BitMasks, Boxes, BoxMode
10
+
11
+ from .base import IntTupleBox, make_int_box
12
+ from .to_mask import ImageSizeType
13
+
14
+
15
+ def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
16
+ """
17
+ Resample coarse segmentation tensor to the given
18
+ bounding box and derive labels for each pixel of the bounding box
19
+
20
+ Args:
21
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
22
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
23
+ corner coordinates, width (W) and height (H)
24
+ Return:
25
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
26
+ """
27
+ x, y, w, h = box_xywh_abs
28
+ w = max(int(w), 1)
29
+ h = max(int(h), 1)
30
+ labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
31
+ return labels
32
+
33
+
34
+ def resample_fine_and_coarse_segm_tensors_to_bbox(
35
+ fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
36
+ ):
37
+ """
38
+ Resample fine and coarse segmentation tensors to the given
39
+ bounding box and derive labels for each pixel of the bounding box
40
+
41
+ Args:
42
+ fine_segm: float tensor of shape [1, C, Hout, Wout]
43
+ coarse_segm: float tensor of shape [1, K, Hout, Wout]
44
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
45
+ corner coordinates, width (W) and height (H)
46
+ Return:
47
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
48
+ """
49
+ x, y, w, h = box_xywh_abs
50
+ w = max(int(w), 1)
51
+ h = max(int(h), 1)
52
+ # coarse segmentation
53
+ coarse_segm_bbox = F.interpolate(
54
+ coarse_segm,
55
+ (h, w),
56
+ mode="bilinear",
57
+ align_corners=False,
58
+ ).argmax(dim=1)
59
+ # combined coarse and fine segmentation
60
+ labels = (
61
+ F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
62
+ * (coarse_segm_bbox > 0).long()
63
+ )
64
+ return labels
65
+
66
+
67
+ def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
68
+ """
69
+ Resample fine and coarse segmentation outputs from a predictor to the given
70
+ bounding box and derive labels for each pixel of the bounding box
71
+
72
+ Args:
73
+ predictor_output: DensePose predictor output that contains segmentation
74
+ results to be resampled
75
+ box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
76
+ corner coordinates, width (W) and height (H)
77
+ Return:
78
+ Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
79
+ """
80
+ return resample_fine_and_coarse_segm_tensors_to_bbox(
81
+ predictor_output.fine_segm,
82
+ predictor_output.coarse_segm,
83
+ box_xywh_abs,
84
+ )
85
+
86
+
87
+ def predictor_output_with_coarse_segm_to_mask(
88
+ predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
89
+ ) -> BitMasks:
90
+ """
91
+ Convert predictor output with coarse and fine segmentation to a mask.
92
+ Assumes that predictor output has the following attributes:
93
+ - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
94
+ unnormalized scores for N instances; D is the number of coarse
95
+ segmentation labels, H and W is the resolution of the estimate
96
+
97
+ Args:
98
+ predictor_output: DensePose predictor output to be converted to mask
99
+ boxes (Boxes): bounding boxes that correspond to the DensePose
100
+ predictor outputs
101
+ image_size_hw (tuple [int, int]): image height Himg and width Wimg
102
+ Return:
103
+ BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
104
+ a mask of the size of the image for each instance
105
+ """
106
+ H, W = image_size_hw
107
+ boxes_xyxy_abs = boxes.tensor.clone()
108
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
109
+ N = len(boxes_xywh_abs)
110
+ masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
111
+ for i in range(len(boxes_xywh_abs)):
112
+ box_xywh = make_int_box(boxes_xywh_abs[i])
113
+ box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
114
+ x, y, w, h = box_xywh
115
+ masks[i, y : y + h, x : x + w] = box_mask
116
+
117
+ return BitMasks(masks)
118
+
119
+
120
+ def predictor_output_with_fine_and_coarse_segm_to_mask(
121
+ predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
122
+ ) -> BitMasks:
123
+ """
124
+ Convert predictor output with coarse and fine segmentation to a mask.
125
+ Assumes that predictor output has the following attributes:
126
+ - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
127
+ unnormalized scores for N instances; D is the number of coarse
128
+ segmentation labels, H and W is the resolution of the estimate
129
+ - fine_segm (tensor of size [N, C, H, W]): fine segmentation
130
+ unnormalized scores for N instances; C is the number of fine
131
+ segmentation labels, H and W is the resolution of the estimate
132
+
133
+ Args:
134
+ predictor_output: DensePose predictor output to be converted to mask
135
+ boxes (Boxes): bounding boxes that correspond to the DensePose
136
+ predictor outputs
137
+ image_size_hw (tuple [int, int]): image height Himg and width Wimg
138
+ Return:
139
+ BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
140
+ a mask of the size of the image for each instance
141
+ """
142
+ H, W = image_size_hw
143
+ boxes_xyxy_abs = boxes.tensor.clone()
144
+ boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
145
+ N = len(boxes_xywh_abs)
146
+ masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
147
+ for i in range(len(boxes_xywh_abs)):
148
+ box_xywh = make_int_box(boxes_xywh_abs[i])
149
+ labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
150
+ x, y, w, h = box_xywh
151
+ masks[i, y : y + h, x : x + w] = labels_i > 0
152
+ return BitMasks(masks)
CatVTON/densepose/converters/to_chart_result.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any
6
+
7
+ from detectron2.structures import Boxes
8
+
9
+ from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
10
+ from .base import BaseConverter
11
+
12
+
13
+ class ToChartResultConverter(BaseConverter):
14
+ """
15
+ Converts various DensePose predictor outputs to DensePose results.
16
+ Each DensePose predictor output type has to register its convertion strategy.
17
+ """
18
+
19
+ registry = {}
20
+ dst_type = DensePoseChartResult
21
+
22
+ @classmethod
23
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
24
+ # inconsistently.
25
+ def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
26
+ """
27
+ Convert DensePose predictor outputs to DensePoseResult using some registered
28
+ converter. Does recursive lookup for base classes, so there's no need
29
+ for explicit registration for derived classes.
30
+
31
+ Args:
32
+ densepose_predictor_outputs: DensePose predictor output to be
33
+ converted to BitMasks
34
+ boxes (Boxes): bounding boxes that correspond to the DensePose
35
+ predictor outputs
36
+ Return:
37
+ An instance of DensePoseResult. If no suitable converter was found, raises KeyError
38
+ """
39
+ return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
40
+
41
+
42
+ class ToChartResultConverterWithConfidences(BaseConverter):
43
+ """
44
+ Converts various DensePose predictor outputs to DensePose results.
45
+ Each DensePose predictor output type has to register its convertion strategy.
46
+ """
47
+
48
+ registry = {}
49
+ dst_type = DensePoseChartResultWithConfidences
50
+
51
+ @classmethod
52
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
53
+ # inconsistently.
54
+ def convert(
55
+ cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
56
+ ) -> DensePoseChartResultWithConfidences:
57
+ """
58
+ Convert DensePose predictor outputs to DensePoseResult with confidences
59
+ using some registered converter. Does recursive lookup for base classes,
60
+ so there's no need for explicit registration for derived classes.
61
+
62
+ Args:
63
+ densepose_predictor_outputs: DensePose predictor output with confidences
64
+ to be converted to BitMasks
65
+ boxes (Boxes): bounding boxes that correspond to the DensePose
66
+ predictor outputs
67
+ Return:
68
+ An instance of DensePoseResult. If no suitable converter was found, raises KeyError
69
+ """
70
+ return super(ToChartResultConverterWithConfidences, cls).convert(
71
+ predictor_outputs, boxes, *args, **kwargs
72
+ )
CatVTON/densepose/converters/to_mask.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Any, Tuple
6
+
7
+ from detectron2.structures import BitMasks, Boxes
8
+
9
+ from .base import BaseConverter
10
+
11
+ ImageSizeType = Tuple[int, int]
12
+
13
+
14
+ class ToMaskConverter(BaseConverter):
15
+ """
16
+ Converts various DensePose predictor outputs to masks
17
+ in bit mask format (see `BitMasks`). Each DensePose predictor output type
18
+ has to register its convertion strategy.
19
+ """
20
+
21
+ registry = {}
22
+ dst_type = BitMasks
23
+
24
+ @classmethod
25
+ # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
26
+ # inconsistently.
27
+ def convert(
28
+ cls,
29
+ densepose_predictor_outputs: Any,
30
+ boxes: Boxes,
31
+ image_size_hw: ImageSizeType,
32
+ *args,
33
+ **kwargs
34
+ ) -> BitMasks:
35
+ """
36
+ Convert DensePose predictor outputs to BitMasks using some registered
37
+ converter. Does recursive lookup for base classes, so there's no need
38
+ for explicit registration for derived classes.
39
+
40
+ Args:
41
+ densepose_predictor_outputs: DensePose predictor output to be
42
+ converted to BitMasks
43
+ boxes (Boxes): bounding boxes that correspond to the DensePose
44
+ predictor outputs
45
+ image_size_hw (tuple [int, int]): image height and width
46
+ Return:
47
+ An instance of `BitMasks`. If no suitable converter was found, raises KeyError
48
+ """
49
+ return super(ToMaskConverter, cls).convert(
50
+ densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
51
+ )
CatVTON/densepose/engine/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .trainer import Trainer
CatVTON/densepose/engine/trainer.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ # pyre-unsafe
4
+
5
+ import logging
6
+ import os
7
+ from collections import OrderedDict
8
+ from typing import List, Optional, Union
9
+ import torch
10
+ from torch import nn
11
+
12
+ from detectron2.checkpoint import DetectionCheckpointer
13
+ from detectron2.config import CfgNode
14
+ from detectron2.engine import DefaultTrainer
15
+ from detectron2.evaluation import (
16
+ DatasetEvaluator,
17
+ DatasetEvaluators,
18
+ inference_on_dataset,
19
+ print_csv_format,
20
+ )
21
+ from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
22
+ from detectron2.utils import comm
23
+ from detectron2.utils.events import EventWriter, get_event_storage
24
+
25
+ from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
26
+ from densepose.data import (
27
+ DatasetMapper,
28
+ build_combined_loader,
29
+ build_detection_test_loader,
30
+ build_detection_train_loader,
31
+ build_inference_based_loaders,
32
+ has_inference_based_loaders,
33
+ )
34
+ from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
35
+ from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
36
+ from densepose.modeling.cse import Embedder
37
+
38
+
39
+ class SampleCountingLoader:
40
+ def __init__(self, loader):
41
+ self.loader = loader
42
+
43
+ def __iter__(self):
44
+ it = iter(self.loader)
45
+ storage = get_event_storage()
46
+ while True:
47
+ try:
48
+ batch = next(it)
49
+ num_inst_per_dataset = {}
50
+ for data in batch:
51
+ dataset_name = data["dataset"]
52
+ if dataset_name not in num_inst_per_dataset:
53
+ num_inst_per_dataset[dataset_name] = 0
54
+ num_inst = len(data["instances"])
55
+ num_inst_per_dataset[dataset_name] += num_inst
56
+ for dataset_name in num_inst_per_dataset:
57
+ storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
58
+ yield batch
59
+ except StopIteration:
60
+ break
61
+
62
+
63
+ class SampleCountMetricPrinter(EventWriter):
64
+ def __init__(self):
65
+ self.logger = logging.getLogger(__name__)
66
+
67
+ def write(self):
68
+ storage = get_event_storage()
69
+ batch_stats_strs = []
70
+ for key, buf in storage.histories().items():
71
+ if key.startswith("batch/"):
72
+ batch_stats_strs.append(f"{key} {buf.avg(20)}")
73
+ self.logger.info(", ".join(batch_stats_strs))
74
+
75
+
76
+ class Trainer(DefaultTrainer):
77
+ @classmethod
78
+ def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
79
+ if isinstance(model, nn.parallel.DistributedDataParallel):
80
+ model = model.module
81
+ if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
82
+ return model.roi_heads.embedder
83
+ return None
84
+
85
+ # TODO: the only reason to copy the base class code here is to pass the embedder from
86
+ # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
87
+ @classmethod
88
+ def test(
89
+ cls,
90
+ cfg: CfgNode,
91
+ model: nn.Module,
92
+ evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
93
+ ):
94
+ """
95
+ Args:
96
+ cfg (CfgNode):
97
+ model (nn.Module):
98
+ evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
99
+ :meth:`build_evaluator`. Otherwise, must have the same length as
100
+ ``cfg.DATASETS.TEST``.
101
+
102
+ Returns:
103
+ dict: a dict of result metrics
104
+ """
105
+ logger = logging.getLogger(__name__)
106
+ if isinstance(evaluators, DatasetEvaluator):
107
+ evaluators = [evaluators]
108
+ if evaluators is not None:
109
+ assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
110
+ len(cfg.DATASETS.TEST), len(evaluators)
111
+ )
112
+
113
+ results = OrderedDict()
114
+ for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
115
+ data_loader = cls.build_test_loader(cfg, dataset_name)
116
+ # When evaluators are passed in as arguments,
117
+ # implicitly assume that evaluators can be created before data_loader.
118
+ if evaluators is not None:
119
+ evaluator = evaluators[idx]
120
+ else:
121
+ try:
122
+ embedder = cls.extract_embedder_from_model(model)
123
+ evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
124
+ except NotImplementedError:
125
+ logger.warn(
126
+ "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
127
+ "or implement its `build_evaluator` method."
128
+ )
129
+ results[dataset_name] = {}
130
+ continue
131
+ if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
132
+ results_i = inference_on_dataset(model, data_loader, evaluator)
133
+ else:
134
+ results_i = {}
135
+ results[dataset_name] = results_i
136
+ if comm.is_main_process():
137
+ assert isinstance(
138
+ results_i, dict
139
+ ), "Evaluator must return a dict on the main process. Got {} instead.".format(
140
+ results_i
141
+ )
142
+ logger.info("Evaluation results for {} in csv format:".format(dataset_name))
143
+ print_csv_format(results_i)
144
+
145
+ if len(results) == 1:
146
+ results = list(results.values())[0]
147
+ return results
148
+
149
+ @classmethod
150
+ def build_evaluator(
151
+ cls,
152
+ cfg: CfgNode,
153
+ dataset_name: str,
154
+ output_folder: Optional[str] = None,
155
+ embedder: Optional[Embedder] = None,
156
+ ) -> DatasetEvaluators:
157
+ if output_folder is None:
158
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
159
+ evaluators = []
160
+ distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
161
+ # Note: we currently use COCO evaluator for both COCO and LVIS datasets
162
+ # to have compatible metrics. LVIS bbox evaluator could also be used
163
+ # with an adapter to properly handle filtered / mapped categories
164
+ # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
165
+ # if evaluator_type == "coco":
166
+ # evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
167
+ # elif evaluator_type == "lvis":
168
+ # evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
169
+ evaluators.append(
170
+ Detectron2COCOEvaluatorAdapter(
171
+ dataset_name, output_dir=output_folder, distributed=distributed
172
+ )
173
+ )
174
+ if cfg.MODEL.DENSEPOSE_ON:
175
+ storage = build_densepose_evaluator_storage(cfg, output_folder)
176
+ evaluators.append(
177
+ DensePoseCOCOEvaluator(
178
+ dataset_name,
179
+ distributed,
180
+ output_folder,
181
+ evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
182
+ min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
183
+ storage=storage,
184
+ embedder=embedder,
185
+ should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
186
+ mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
187
+ )
188
+ )
189
+ return DatasetEvaluators(evaluators)
190
+
191
+ @classmethod
192
+ def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
193
+ params = get_default_optimizer_params(
194
+ model,
195
+ base_lr=cfg.SOLVER.BASE_LR,
196
+ weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
197
+ bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
198
+ weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
199
+ overrides={
200
+ "features": {
201
+ "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
202
+ },
203
+ "embeddings": {
204
+ "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
205
+ },
206
+ },
207
+ )
208
+ optimizer = torch.optim.SGD(
209
+ params,
210
+ cfg.SOLVER.BASE_LR,
211
+ momentum=cfg.SOLVER.MOMENTUM,
212
+ nesterov=cfg.SOLVER.NESTEROV,
213
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
214
+ )
215
+ # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
216
+ return maybe_add_gradient_clipping(cfg, optimizer)
217
+
218
+ @classmethod
219
+ def build_test_loader(cls, cfg: CfgNode, dataset_name):
220
+ return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
221
+
222
+ @classmethod
223
+ def build_train_loader(cls, cfg: CfgNode):
224
+ data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
225
+ if not has_inference_based_loaders(cfg):
226
+ return data_loader
227
+ model = cls.build_model(cfg)
228
+ model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
229
+ DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
230
+ inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
231
+ loaders = [data_loader] + inference_based_loaders
232
+ ratios = [1.0] + ratios
233
+ combined_data_loader = build_combined_loader(cfg, loaders, ratios)
234
+ sample_counting_loader = SampleCountingLoader(combined_data_loader)
235
+ return sample_counting_loader
236
+
237
+ def build_writers(self):
238
+ writers = super().build_writers()
239
+ writers.append(SampleCountMetricPrinter())
240
+ return writers
241
+
242
+ @classmethod
243
+ def test_with_TTA(cls, cfg: CfgNode, model):
244
+ logger = logging.getLogger("detectron2.trainer")
245
+ # In the end of training, run an evaluation with TTA
246
+ # Only support some R-CNN models.
247
+ logger.info("Running inference with test-time augmentation ...")
248
+ transform_data = load_from_cfg(cfg)
249
+ model = DensePoseGeneralizedRCNNWithTTA(
250
+ cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
251
+ )
252
+ evaluators = [
253
+ cls.build_evaluator(
254
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
255
+ )
256
+ for name in cfg.DATASETS.TEST
257
+ ]
258
+ res = cls.test(cfg, model, evaluators) # pyre-ignore[6]
259
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
260
+ return res
CatVTON/densepose/modeling/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from .confidence import DensePoseConfidenceModelConfig, DensePoseUVConfidenceType
6
+ from .filter import DensePoseDataFilter
7
+ from .inference import densepose_inference
8
+ from .utils import initialize_module_params
9
+ from .build import (
10
+ build_densepose_data_filter,
11
+ build_densepose_embedder,
12
+ build_densepose_head,
13
+ build_densepose_losses,
14
+ build_densepose_predictor,
15
+ )
CatVTON/densepose/modeling/build.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Optional
6
+ from torch import nn
7
+
8
+ from detectron2.config import CfgNode
9
+
10
+ from .cse.embedder import Embedder
11
+ from .filter import DensePoseDataFilter
12
+
13
+
14
+ def build_densepose_predictor(cfg: CfgNode, input_channels: int):
15
+ """
16
+ Create an instance of DensePose predictor based on configuration options.
17
+
18
+ Args:
19
+ cfg (CfgNode): configuration options
20
+ input_channels (int): input tensor size along the channel dimension
21
+ Return:
22
+ An instance of DensePose predictor
23
+ """
24
+ from .predictors import DENSEPOSE_PREDICTOR_REGISTRY
25
+
26
+ predictor_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME
27
+ return DENSEPOSE_PREDICTOR_REGISTRY.get(predictor_name)(cfg, input_channels)
28
+
29
+
30
+ def build_densepose_data_filter(cfg: CfgNode):
31
+ """
32
+ Build DensePose data filter which selects data for training
33
+
34
+ Args:
35
+ cfg (CfgNode): configuration options
36
+
37
+ Return:
38
+ Callable: list(Tensor), list(Instances) -> list(Tensor), list(Instances)
39
+ An instance of DensePose filter, which takes feature tensors and proposals
40
+ as an input and returns filtered features and proposals
41
+ """
42
+ dp_filter = DensePoseDataFilter(cfg)
43
+ return dp_filter
44
+
45
+
46
+ def build_densepose_head(cfg: CfgNode, input_channels: int):
47
+ """
48
+ Build DensePose head based on configurations options
49
+
50
+ Args:
51
+ cfg (CfgNode): configuration options
52
+ input_channels (int): input tensor size along the channel dimension
53
+ Return:
54
+ An instance of DensePose head
55
+ """
56
+ from .roi_heads.registry import ROI_DENSEPOSE_HEAD_REGISTRY
57
+
58
+ head_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME
59
+ return ROI_DENSEPOSE_HEAD_REGISTRY.get(head_name)(cfg, input_channels)
60
+
61
+
62
+ def build_densepose_losses(cfg: CfgNode):
63
+ """
64
+ Build DensePose loss based on configurations options
65
+
66
+ Args:
67
+ cfg (CfgNode): configuration options
68
+ Return:
69
+ An instance of DensePose loss
70
+ """
71
+ from .losses import DENSEPOSE_LOSS_REGISTRY
72
+
73
+ loss_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME
74
+ return DENSEPOSE_LOSS_REGISTRY.get(loss_name)(cfg)
75
+
76
+
77
+ def build_densepose_embedder(cfg: CfgNode) -> Optional[nn.Module]:
78
+ """
79
+ Build embedder used to embed mesh vertices into an embedding space.
80
+ Embedder contains sub-embedders, one for each mesh ID.
81
+
82
+ Args:
83
+ cfg (cfgNode): configuration options
84
+ Return:
85
+ Embedding module
86
+ """
87
+ if cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS:
88
+ return Embedder(cfg)
89
+ return None
CatVTON/densepose/modeling/confidence.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+
8
+ from detectron2.config import CfgNode
9
+
10
+
11
+ class DensePoseUVConfidenceType(Enum):
12
+ """
13
+ Statistical model type for confidence learning, possible values:
14
+ - "iid_iso": statistically independent identically distributed residuals
15
+ with anisotropic covariance
16
+ - "indep_aniso": statistically independent residuals with anisotropic
17
+ covariances
18
+ For details, see:
19
+ N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
20
+ Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
21
+ """
22
+
23
+ # fmt: off
24
+ IID_ISO = "iid_iso"
25
+ INDEP_ANISO = "indep_aniso"
26
+ # fmt: on
27
+
28
+
29
+ @dataclass
30
+ class DensePoseUVConfidenceConfig:
31
+ """
32
+ Configuration options for confidence on UV data
33
+ """
34
+
35
+ enabled: bool = False
36
+ # lower bound on UV confidences
37
+ epsilon: float = 0.01
38
+ type: DensePoseUVConfidenceType = DensePoseUVConfidenceType.IID_ISO
39
+
40
+
41
+ @dataclass
42
+ class DensePoseSegmConfidenceConfig:
43
+ """
44
+ Configuration options for confidence on segmentation
45
+ """
46
+
47
+ enabled: bool = False
48
+ # lower bound on confidence values
49
+ epsilon: float = 0.01
50
+
51
+
52
+ @dataclass
53
+ class DensePoseConfidenceModelConfig:
54
+ """
55
+ Configuration options for confidence models
56
+ """
57
+
58
+ # confidence for U and V values
59
+ uv_confidence: DensePoseUVConfidenceConfig
60
+ # segmentation confidence
61
+ segm_confidence: DensePoseSegmConfidenceConfig
62
+
63
+ @staticmethod
64
+ def from_cfg(cfg: CfgNode) -> "DensePoseConfidenceModelConfig":
65
+ return DensePoseConfidenceModelConfig(
66
+ uv_confidence=DensePoseUVConfidenceConfig(
67
+ enabled=cfg.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.ENABLED,
68
+ epsilon=cfg.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON,
69
+ type=DensePoseUVConfidenceType(cfg.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE),
70
+ ),
71
+ segm_confidence=DensePoseSegmConfidenceConfig(
72
+ enabled=cfg.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.ENABLED,
73
+ epsilon=cfg.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON,
74
+ ),
75
+ )
CatVTON/densepose/modeling/densepose_checkpoint.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from collections import OrderedDict
5
+
6
+ from detectron2.checkpoint import DetectionCheckpointer
7
+
8
+
9
+ def _rename_HRNet_weights(weights):
10
+ # We detect and rename HRNet weights for DensePose. 1956 and 1716 are values that are
11
+ # common to all HRNet pretrained weights, and should be enough to accurately identify them
12
+ if (
13
+ len(weights["model"].keys()) == 1956
14
+ and len([k for k in weights["model"].keys() if k.startswith("stage")]) == 1716
15
+ ):
16
+ hrnet_weights = OrderedDict()
17
+ for k in weights["model"].keys():
18
+ hrnet_weights["backbone.bottom_up." + str(k)] = weights["model"][k]
19
+ return {"model": hrnet_weights}
20
+ else:
21
+ return weights
22
+
23
+
24
+ class DensePoseCheckpointer(DetectionCheckpointer):
25
+ """
26
+ Same as :class:`DetectionCheckpointer`, but is able to handle HRNet weights
27
+ """
28
+
29
+ def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
30
+ super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)
31
+
32
+ def _load_file(self, filename: str) -> object:
33
+ """
34
+ Adding hrnet support
35
+ """
36
+ weights = super()._load_file(filename)
37
+ return _rename_HRNet_weights(weights)
CatVTON/densepose/modeling/filter.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import List
6
+ import torch
7
+
8
+ from detectron2.config import CfgNode
9
+ from detectron2.structures import Instances
10
+ from detectron2.structures.boxes import matched_pairwise_iou
11
+
12
+
13
+ class DensePoseDataFilter:
14
+ def __init__(self, cfg: CfgNode):
15
+ self.iou_threshold = cfg.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD
16
+ self.keep_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
17
+
18
+ @torch.no_grad()
19
+ def __call__(self, features: List[torch.Tensor], proposals_with_targets: List[Instances]):
20
+ """
21
+ Filters proposals with targets to keep only the ones relevant for
22
+ DensePose training
23
+
24
+ Args:
25
+ features (list[Tensor]): input data as a list of features,
26
+ each feature is a tensor. Axis 0 represents the number of
27
+ images `N` in the input data; axes 1-3 are channels,
28
+ height, and width, which may vary between features
29
+ (e.g., if a feature pyramid is used).
30
+ proposals_with_targets (list[Instances]): length `N` list of
31
+ `Instances`. The i-th `Instances` contains instances
32
+ (proposals, GT) for the i-th input image,
33
+ Returns:
34
+ list[Tensor]: filtered features
35
+ list[Instances]: filtered proposals
36
+ """
37
+ proposals_filtered = []
38
+ # TODO: the commented out code was supposed to correctly deal with situations
39
+ # where no valid DensePose GT is available for certain images. The corresponding
40
+ # image features were sliced and proposals were filtered. This led to performance
41
+ # deterioration, both in terms of runtime and in terms of evaluation results.
42
+ #
43
+ # feature_mask = torch.ones(
44
+ # len(proposals_with_targets),
45
+ # dtype=torch.bool,
46
+ # device=features[0].device if len(features) > 0 else torch.device("cpu"),
47
+ # )
48
+ for i, proposals_per_image in enumerate(proposals_with_targets):
49
+ if not proposals_per_image.has("gt_densepose") and (
50
+ not proposals_per_image.has("gt_masks") or not self.keep_masks
51
+ ):
52
+ # feature_mask[i] = 0
53
+ continue
54
+ gt_boxes = proposals_per_image.gt_boxes
55
+ est_boxes = proposals_per_image.proposal_boxes
56
+ # apply match threshold for densepose head
57
+ iou = matched_pairwise_iou(gt_boxes, est_boxes)
58
+ iou_select = iou > self.iou_threshold
59
+ proposals_per_image = proposals_per_image[iou_select] # pyre-ignore[6]
60
+
61
+ N_gt_boxes = len(proposals_per_image.gt_boxes)
62
+ assert N_gt_boxes == len(proposals_per_image.proposal_boxes), (
63
+ f"The number of GT boxes {N_gt_boxes} is different from the "
64
+ f"number of proposal boxes {len(proposals_per_image.proposal_boxes)}"
65
+ )
66
+ # filter out any target without suitable annotation
67
+ if self.keep_masks:
68
+ gt_masks = (
69
+ proposals_per_image.gt_masks
70
+ if hasattr(proposals_per_image, "gt_masks")
71
+ else [None] * N_gt_boxes
72
+ )
73
+ else:
74
+ gt_masks = [None] * N_gt_boxes
75
+ gt_densepose = (
76
+ proposals_per_image.gt_densepose
77
+ if hasattr(proposals_per_image, "gt_densepose")
78
+ else [None] * N_gt_boxes
79
+ )
80
+ assert len(gt_masks) == N_gt_boxes
81
+ assert len(gt_densepose) == N_gt_boxes
82
+ selected_indices = [
83
+ i
84
+ for i, (dp_target, mask_target) in enumerate(zip(gt_densepose, gt_masks))
85
+ if (dp_target is not None) or (mask_target is not None)
86
+ ]
87
+ # if not len(selected_indices):
88
+ # feature_mask[i] = 0
89
+ # continue
90
+ if len(selected_indices) != N_gt_boxes:
91
+ proposals_per_image = proposals_per_image[selected_indices] # pyre-ignore[6]
92
+ assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes)
93
+ proposals_filtered.append(proposals_per_image)
94
+ # features_filtered = [feature[feature_mask] for feature in features]
95
+ # return features_filtered, proposals_filtered
96
+ return features, proposals_filtered
CatVTON/densepose/modeling/hrfpn.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ """
5
+ MIT License
6
+ Copyright (c) 2019 Microsoft
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+ """
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from detectron2.layers import ShapeSpec
29
+ from detectron2.modeling.backbone import BACKBONE_REGISTRY
30
+ from detectron2.modeling.backbone.backbone import Backbone
31
+
32
+ from .hrnet import build_pose_hrnet_backbone
33
+
34
+
35
class HRFPN(Backbone):
    """HRFPN (High Resolution Feature Pyramids)

    Transforms outputs of an HRNet backbone so they are suitable for the ROI heads.
    arXiv: https://arxiv.org/abs/1904.04514
    Adapted from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/hrfpn.py

    Args:
        bottom_up: HRNet backbone module producing the multi-resolution features
        in_features (list): names of the input features (output of HRNet)
        n_out_features (int): number of output stages
        in_channels (list): number of channels for each branch
        out_channels (int): output channels of feature pyramids
        pooling (str): pooling for generating feature pyramids (from {MAX, AVG})
        share_conv (bool): Have one conv per output, or share one with all the outputs
    """

    def __init__(
        self,
        bottom_up,
        in_features,
        n_out_features,
        in_channels,
        out_channels,
        pooling="AVG",
        share_conv=False,
    ):
        super(HRFPN, self).__init__()
        assert isinstance(in_channels, list)
        self.bottom_up = bottom_up
        self.in_features = in_features
        self.n_out_features = n_out_features
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.share_conv = share_conv

        if self.share_conv:
            # A single 3x3 conv shared by every pyramid level.
            self.fpn_conv = nn.Conv2d(
                in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1
            )
        else:
            # One dedicated 3x3 conv per output level.
            self.fpn_conv = nn.ModuleList()
            for _ in range(self.n_out_features):
                self.fpn_conv.append(
                    nn.Conv2d(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        kernel_size=3,
                        padding=1,
                    )
                )

        # Custom change: replaces a simple bilinear interpolation.
        # Branch i is upsampled by 2**i back toward the finest resolution with a
        # learned transposed conv (+ BN + ReLU) instead of interpolation.
        self.interp_conv = nn.ModuleList()
        for i in range(len(self.in_features)):
            self.interp_conv.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels=in_channels[i],
                        out_channels=in_channels[i],
                        kernel_size=4,
                        stride=2**i,
                        padding=0,
                        output_padding=0,
                        bias=False,
                    ),
                    nn.BatchNorm2d(in_channels[i], momentum=0.1),
                    nn.ReLU(inplace=True),
                )
            )

        # Custom change: replaces a couple (reduction conv + pooling) by one conv.
        # Pyramid level i downsamples the concatenated map by 2**i via a strided conv.
        self.reduction_pooling_conv = nn.ModuleList()
        for i in range(self.n_out_features):
            self.reduction_pooling_conv.append(
                nn.Sequential(
                    nn.Conv2d(sum(in_channels), out_channels, kernel_size=2**i, stride=2**i),
                    nn.BatchNorm2d(out_channels, momentum=0.1),
                    nn.ReLU(inplace=True),
                )
            )

        # NOTE: with the reduction convs above, `self.pooling` is kept for
        # interface compatibility but is not referenced in forward().
        if pooling == "MAX":
            self.pooling = F.max_pool2d
        else:
            self.pooling = F.avg_pool2d

        self._out_features = []
        self._out_feature_channels = {}
        self._out_feature_strides = {}

        for i in range(self.n_out_features):
            self._out_features.append("p%d" % (i + 1))
            self._out_feature_channels.update({self._out_features[-1]: self.out_channels})
            self._out_feature_strides.update({self._out_features[-1]: 2 ** (i + 2)})

    # default init_weights for conv(msra) and norm in ConvModule
    def init_weights(self):
        """Kaiming-initialize every Conv2d weight and zero its bias (if any).

        Fix: `self.modules()` also recurses into `self.bottom_up`, which contains
        Conv2d layers created with bias=False; the previous unconditional
        `nn.init.constant_(m.bias, 0)` raised on those. Guard on a present bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, inputs):
        """Run HRNet, fuse its branches, and return an FPN-style dict p1..pN."""
        bottom_up_features = self.bottom_up(inputs)
        assert len(bottom_up_features) == len(self.in_features)
        inputs = [bottom_up_features[f] for f in self.in_features]

        # Upsample every branch toward the finest resolution.
        outs = []
        for i in range(len(inputs)):
            outs.append(self.interp_conv[i](inputs[i]))
        # Crop to the smallest common spatial size before concatenation
        # (transposed convs may overshoot by a few pixels).
        shape_2 = min(o.shape[2] for o in outs)
        shape_3 = min(o.shape[3] for o in outs)
        out = torch.cat([o[:, :, :shape_2, :shape_3] for o in outs], dim=1)
        # Build the pyramid: level i is the concatenated map reduced by 2**i.
        outs = []
        for i in range(self.n_out_features):
            outs.append(self.reduction_pooling_conv[i](out))
        for i in range(len(outs)):  # Make shapes consistent
            outs[-1 - i] = outs[-1 - i][
                :, :, : outs[-1].shape[2] * 2**i, : outs[-1].shape[3] * 2**i
            ]
        outputs = []
        for i in range(len(outs)):
            if self.share_conv:
                outputs.append(self.fpn_conv(outs[i]))
            else:
                outputs.append(self.fpn_conv[i](outs[i]))

        assert len(self._out_features) == len(outputs)
        return dict(zip(self._out_features, outputs))
164
+
165
+
166
@BACKBONE_REGISTRY.register()
def build_hrfpn_backbone(cfg, input_shape: ShapeSpec) -> HRFPN:
    """Detectron2 backbone factory: HRNet trunk followed by an HRFPN neck."""
    stage4 = cfg.MODEL.HRNET.STAGE4
    branch_names = ["p%d" % (i + 1) for i in range(stage4.NUM_BRANCHES)]
    trunk = build_pose_hrnet_backbone(cfg, input_shape)
    return HRFPN(
        trunk,
        branch_names,
        len(cfg.MODEL.ROI_HEADS.IN_FEATURES),
        stage4.NUM_CHANNELS,
        cfg.MODEL.HRNET.HRFPN.OUT_CHANNELS,
        pooling="AVG",
        share_conv=False,
    )
CatVTON/densepose/modeling/hrnet.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # ------------------------------------------------------------------------------
3
+ # Copyright (c) Microsoft
4
+ # Licensed under the MIT License.
5
+ # Written by Bin Xiao (leoxiaobin@gmail.com)
6
+ # Modified by Bowen Cheng (bcheng9@illinois.edu)
7
+ # Adapted from https://github.com/HRNet/Higher-HRNet-Human-Pose-Estimation/blob/master/lib/models/pose_higher_hrnet.py # noqa
8
+ # ------------------------------------------------------------------------------
9
+
10
+ # pyre-unsafe
11
+
12
+ from __future__ import absolute_import, division, print_function
13
+ import logging
14
+ import torch.nn as nn
15
+
16
+ from detectron2.layers import ShapeSpec
17
+ from detectron2.modeling.backbone import BACKBONE_REGISTRY
18
+ from detectron2.modeling.backbone.backbone import Backbone
19
+
20
+ BN_MOMENTUM = 0.1
21
+ logger = logging.getLogger(__name__)
22
+
23
+ __all__ = ["build_pose_hrnet_backbone", "PoseHigherResolutionNet"]
24
+
25
+
26
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution that preserves spatial size (padding=1)."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
29
+
30
+
31
class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet-style), channel expansion 1."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Main branch: two 3x3 convs, each followed by BN; ReLU is shared.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection if shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
61
+
62
+
63
class Bottleneck(nn.Module):
    """Three-conv residual bottleneck (1x1 reduce, 3x3, 1x1 expand), expansion 4."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # 3x3 spatial conv; carries the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # 1x1 expansion back to planes * expansion channels.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut path: identity, or a projection if shapes change.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
99
+
100
+
101
class HighResolutionModule(nn.Module):
    """HighResolutionModule
    Building block of the PoseHigherResolutionNet (see lower)
    arXiv: https://arxiv.org/abs/1908.10357
    Args:
        num_branches (int): number of branches of the module
        blocks (str): type of block of the module
        num_blocks (int): number of blocks of the module
        num_inchannels (int): number of input channels of the module
        num_channels (list): number of channels of each branch
        multi_scale_output (bool): only used by the last module of PoseHigherResolutionNet
    """

    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        num_inchannels,
        num_channels,
        multi_scale_output=True,
    ):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        # Validate that all per-branch config lists have one entry per branch.
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
                num_branches, len(num_channels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
                num_branches, len(num_inchannels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        # Build one branch: a sequence of residual blocks at a single resolution.
        # A projection `downsample` is added when stride or channel count changes.
        downsample = None
        if (
            stride != 1
            or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion
        ):
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(
            block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)
        )
        # NOTE: mutates self.num_inchannels in place so callers (and the
        # enclosing network, via get_num_inchannels) see the updated widths.
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for _ in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        # Build fuse_layers[i][j]: the transform applied to branch j's output
        # before it is summed into branch i's output.
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Coarser branch -> finer target: 1x1 conv then nearest upsample.
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    # Same resolution: summed as-is (handled in forward).
                    fuse_layer.append(None)
                else:
                    # Finer branch -> coarser target: chain of stride-2 3x3 convs;
                    # only the last one changes the channel count (and has no ReLU).
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False,
                                    ),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        # x: list of per-branch tensors. NOTE: mutated in place below.
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        # For each output branch i, sum the (transformed) outputs of all branches.
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    # Crop to y's spatial size: upsampled maps may overshoot by
                    # a pixel or two when sizes are not exact powers of two.
                    z = self.fuse_layers[i][j](x[j])[:, :, : y.shape[2], : y.shape[3]]
                    y = y + z
            x_fuse.append(self.relu(y))

        return x_fuse
272
+
273
+
274
# Maps cfg BLOCK names to the residual block classes used by _make_layer/_make_stage.
blocks_dict = {"BASIC": BasicBlock, "BOTTLENECK": Bottleneck}
275
+
276
+
277
class PoseHigherResolutionNet(Backbone):
    """PoseHigherResolutionNet
    Composed of several HighResolutionModule tied together with ConvNets
    Adapted from the GitHub version to fit with HRFPN and the Detectron2 infrastructure
    arXiv: https://arxiv.org/abs/1908.10357
    """

    def __init__(self, cfg, **kwargs):
        # Plain int attribute, consumed by _make_layer; set before super().__init__().
        self.inplanes = cfg.MODEL.HRNET.STEM_INPLANES
        super(PoseHigherResolutionNet, self).__init__()

        # stem net: two stride-2 3x3 convs (overall stride 4), then a bottleneck stage
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        # Stages 2-4: each stage adds one lower-resolution branch via a
        # transition layer, then runs a sequence of HighResolutionModules.
        self.stage2_cfg = cfg.MODEL.HRNET.STAGE2
        num_channels = self.stage2_cfg.NUM_CHANNELS
        block = blocks_dict[self.stage2_cfg.BLOCK]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer([256], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)

        self.stage3_cfg = cfg.MODEL.HRNET.STAGE3
        num_channels = self.stage3_cfg.NUM_CHANNELS
        block = blocks_dict[self.stage3_cfg.BLOCK]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)

        self.stage4_cfg = cfg.MODEL.HRNET.STAGE4
        num_channels = self.stage4_cfg.NUM_CHANNELS
        block = blocks_dict[self.stage4_cfg.BLOCK]
        num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True
        )

        # Detectron2 Backbone metadata: one output feature per stage-4 branch.
        self._out_features = []
        self._out_feature_channels = {}
        self._out_feature_strides = {}

        for i in range(cfg.MODEL.HRNET.STAGE4.NUM_BRANCHES):
            self._out_features.append("p%d" % (i + 1))
            self._out_feature_channels.update(
                {self._out_features[-1]: cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS[i]}
            )
            self._out_feature_strides.update({self._out_features[-1]: 1})

    def _get_deconv_cfg(self, deconv_kernel):
        # Map a deconv kernel size to (kernel, padding, output_padding) so a
        # stride-2 transposed conv exactly doubles the spatial size.
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            # Fix: an unsupported kernel previously fell through and raised
            # UnboundLocalError on `padding`; fail with an explicit error instead.
            raise ValueError("Unsupported deconv kernel size: {}".format(deconv_kernel))

        return deconv_kernel, padding, output_padding

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        # For existing branches: identity (None) or a 3x3 conv when channel counts
        # differ. For each new (lower-resolution) branch: a chain of stride-2 convs
        # starting from the coarsest previous branch.
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False,
                            ),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = (
                        num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
                    )
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        # ResNet-style stage builder; updates self.inplanes for subsequent layers.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
        """Build one HRNet stage: a sequence of HighResolutionModules.

        Returns the stage module and the (possibly updated) per-branch channel
        counts, which the next transition layer consumes.
        """
        num_modules = layer_config["NUM_MODULES"]
        num_branches = layer_config["NUM_BRANCHES"]
        num_blocks = layer_config["NUM_BLOCKS"]
        num_channels = layer_config["NUM_CHANNELS"]
        block = blocks_dict[layer_config["BLOCK"]]

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    reset_multi_scale_output,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        # Stem: stride-4 reduction + bottleneck stage.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg.NUM_BRANCHES):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        # NOTE(review): transition2/3 are applied to y_list[-1] (coarsest branch)
        # rather than y_list[i]; in practice non-None transitions here only occur
        # for the newly added branch, which is derived from the coarsest output —
        # confirm against configs if channel counts ever differ per branch.
        x_list = []
        for i in range(self.stage3_cfg.NUM_BRANCHES):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg.NUM_BRANCHES):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        assert len(self._out_features) == len(y_list)
        return dict(zip(self._out_features, y_list))  # final_outputs
471
+
472
+
473
@BACKBONE_REGISTRY.register()
def build_pose_hrnet_backbone(cfg, input_shape: ShapeSpec):
    """Detectron2 backbone factory for PoseHigherResolutionNet.

    `input_shape` is unused: the HRNet stem is hard-wired to 3-channel input.
    """
    return PoseHigherResolutionNet(cfg)
CatVTON/densepose/modeling/inference.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from dataclasses import fields
5
+ from typing import Any, List
6
+ import torch
7
+
8
+ from detectron2.structures import Instances
9
+
10
+
11
def densepose_inference(densepose_predictor_output: Any, detections: List[Instances]) -> None:
    """
    Split batched DensePose predictor outputs into per-image chunks and store
    each chunk in the `pred_densepose` attribute of the matching `Instances`.

    Args:
        densepose_predictor_output: a dataclass instance (can be of different
            types, depending on the predictor used for inference). Each field is
            either `None`, a non-tensor value shared across images, or a tensor
            of size [N, ...] where N = N_1 + ... + N_k is the total number of
            detections over all images. If the whole output is `None`, no
            `pred_densepose` attribute is added.
        detections: a list of `Instances`, the k-th entry holding the detections
            of the k-th image.
    """
    if densepose_predictor_output is None:
        # Nothing was inferred; leave all detections untouched.
        return
    PredictorOutput = type(densepose_predictor_output)
    offset = 0
    for detection_i in detections:
        n_i = len(detection_i)
        chunk = {}
        # `densepose_predictor_output` is assumed to be a dataclass: slice each
        # tensor field to this image's rows, pass everything else through as-is.
        for field in fields(densepose_predictor_output):
            value = getattr(densepose_predictor_output, field.name)
            if isinstance(value, torch.Tensor):
                chunk[field.name] = value[offset : offset + n_i]
            else:
                chunk[field.name] = value
        detection_i.pred_densepose = PredictorOutput(**chunk)
        offset += n_i
CatVTON/densepose/modeling/test_time_augmentation.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import copy
5
+ import numpy as np
6
+ import torch
7
+ from fvcore.transforms import HFlipTransform, TransformList
8
+ from torch.nn import functional as F
9
+
10
+ from detectron2.data.transforms import RandomRotation, RotationTransform, apply_transform_gens
11
+ from detectron2.modeling.postprocessing import detector_postprocess
12
+ from detectron2.modeling.test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA
13
+
14
+ from ..converters import HFlipConverter
15
+
16
+
17
class DensePoseDatasetMapperTTA(DatasetMapperTTA):
    """TTA mapper that extends the standard augmented variants with rotated
    copies of the image, one per angle in cfg.TEST.AUG.ROTATION_ANGLES."""

    def __init__(self, cfg):
        super().__init__(cfg=cfg)
        self.angles = cfg.TEST.AUG.ROTATION_ANGLES

    def __call__(self, dataset_dict):
        variants = super().__call__(dataset_dict=dataset_dict)
        numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        # DatasetMapperTTA prepends a pre_tfm (resize or no-op) to every
        # TransformList; reuse that first transform so the rotated variants
        # share the same initial step.
        pre_tfm = variants[-1]["transforms"].transforms[0]
        for angle in self.angles:
            rotation = RandomRotation(angle=angle, expand=True)
            rotated_image, tfms = apply_transform_gens([rotation], np.copy(numpy_image))
            variant = copy.deepcopy(dataset_dict)
            variant["transforms"] = TransformList([pre_tfm] + tfms.transforms)
            variant["image"] = torch.from_numpy(
                np.ascontiguousarray(rotated_image.transpose(2, 0, 1))
            )
            variants.append(variant)
        return variants
38
+
39
+
40
class DensePoseGeneralizedRCNNWithTTA(GeneralizedRCNNWithTTA):
    def __init__(self, cfg, model, transform_data, tta_mapper=None, batch_size=1):
        """
        Args:
            cfg (CfgNode):
            model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on.
            transform_data (DensePoseTransformData): contains symmetry label
                transforms used for horizontal flip
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size for inference.
        """
        # Symmetry-label transform data, moved to the model's device; used by
        # HFlipConverter when undoing horizontal-flip augmentations.
        self._transform_data = transform_data.to(model.device)
        super().__init__(cfg=cfg, model=model, tta_mapper=tta_mapper, batch_size=batch_size)

    # the implementation follows closely the one from detectron2/modeling
    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor

        Returns:
            dict: one output dict
        """
        orig_shape = (input["height"], input["width"])
        # For some reason, resize with uint8 slightly increases box AP but decreases densepose AP
        input["image"] = input["image"].to(torch.uint8)
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Detect boxes from all augmented versions
        with self._turn_off_roi_heads(["mask_on", "keypoint_on", "densepose_on"]):
            # temporarily disable roi heads
            all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)

        if self.cfg.MODEL.MASK_ON or self.cfg.MODEL.DENSEPOSE_ON:
            # Use the detected boxes to obtain new fields
            augmented_instances = self._rescale_detected_boxes(
                augmented_inputs, merged_instances, tfms
            )
            # run forward on the detected boxes
            outputs = self._batch_inference(augmented_inputs, augmented_instances)
            # Delete now useless variables to avoid being out of memory
            del augmented_inputs, augmented_instances
            # average the predictions
            if self.cfg.MODEL.MASK_ON:
                merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
            if self.cfg.MODEL.DENSEPOSE_ON:
                merged_instances.pred_densepose = self._reduce_pred_densepose(outputs, tfms)
            # postprocess
            merged_instances = detector_postprocess(merged_instances, *orig_shape)
            return {"instances": merged_instances}
        else:
            return {"instances": merged_instances}

    def _get_augmented_boxes(self, augmented_inputs, tfms):
        # Heavily based on detectron2/modeling/test_time_augmentation.py
        # Only difference is that RotationTransform is excluded from bbox computation
        # 1: forward with all augmented images
        outputs = self._batch_inference(augmented_inputs)
        # 2: union the results
        all_boxes = []
        all_scores = []
        all_classes = []
        for output, tfm in zip(outputs, tfms):
            # Need to inverse the transforms on boxes, to obtain results on original image
            if not any(isinstance(t, RotationTransform) for t in tfm.transforms):
                # Some transforms can't compute bbox correctly
                pred_boxes = output.pred_boxes.tensor
                original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy())
                all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device))
                all_scores.extend(output.scores)
                all_classes.extend(output.pred_classes)
        all_boxes = torch.cat(all_boxes, dim=0)
        return all_boxes, all_scores, all_classes

    def _reduce_pred_densepose(self, outputs, tfms):
        # Should apply inverse transforms on densepose preds.
        # We assume only rotation, resize & flip are used. pred_masks is a scale-invariant
        # representation, so we handle the other ones specially
        # Each per-augmentation output is mapped back to the original image frame
        # (inverse rotation per transform, then symmetry-aware horizontal unflip),
        # and the four densepose fields are averaged incrementally into outputs[0].
        for idx, (output, tfm) in enumerate(zip(outputs, tfms)):
            for t in tfm.transforms:
                for attr in ["coarse_segm", "fine_segm", "u", "v"]:
                    setattr(
                        output.pred_densepose,
                        attr,
                        _inverse_rotation(
                            getattr(output.pred_densepose, attr), output.pred_boxes.tensor, t
                        ),
                    )
            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
                output.pred_densepose = HFlipConverter.convert(
                    output.pred_densepose, self._transform_data
                )
            self._incremental_avg_dp(outputs[0].pred_densepose, output.pred_densepose, idx)
        return outputs[0].pred_densepose

    # incrementally computed average: u_(n + 1) = u_n + (x_(n+1) - u_n) / (n + 1).
    def _incremental_avg_dp(self, avg, new_el, idx):
        # Folds `new_el` into the running mean `avg` (idx = number of elements
        # already averaged); mutates both arguments in place.
        for attr in ["coarse_segm", "fine_segm", "u", "v"]:
            setattr(avg, attr, (getattr(avg, attr) * idx + getattr(new_el, attr)) / (idx + 1))
            if idx:
                # Deletion of the > 0 index intermediary values to prevent GPU OOM
                setattr(new_el, attr, None)
        return avg
145
+
146
+
147
def _inverse_rotation(densepose_attrs, boxes, transform):
    """Rotate per-box DensePose attribute maps back to the original image space.

    Resamples outputs to image size and rotates back the densepose preds
    on the rotated images to the space of the original image.
    No-op unless ``transform`` is a RotationTransform and boxes are present.
    Mutates and returns ``densepose_attrs``.
    """
    if len(boxes) == 0 or not isinstance(transform, RotationTransform):
        return densepose_attrs
    boxes = boxes.int().cpu().numpy()
    wh_boxes = boxes[:, 2:] - boxes[:, :2]  # bboxes in the rotated space
    inv_boxes = rotate_box_inverse(transform, boxes).astype(int)  # bboxes in original image
    wh_diff = (inv_boxes[:, 2:] - inv_boxes[:, :2] - wh_boxes) // 2  # diff between new/old bboxes
    # rm_image is the rotation matrix used on the image; the translation column
    # is zeroed because grid_sample rotates around the (padded) tensor center.
    rotation_matrix = torch.tensor([transform.rm_image]).to(device=densepose_attrs.device).float()
    rotation_matrix[:, :, -1] = 0
    # To apply grid_sample for rotation, we need to have enough space to fit the original and
    # rotated bboxes. l_bds and r_bds are the left/right bounds that will be used to
    # crop the difference once the rotation is done
    l_bds = np.maximum(0, -wh_diff)
    for i in range(len(densepose_attrs)):
        # Skip degenerate (zero-area) boxes.
        if min(wh_boxes[i]) <= 0:
            continue
        densepose_attr = densepose_attrs[[i]].clone()
        # 1. Interpolate densepose attribute to size of the rotated bbox
        densepose_attr = F.interpolate(densepose_attr, wh_boxes[i].tolist()[::-1], mode="bilinear")
        # 2. Pad the interpolated attribute so it has room for the original + rotated bbox
        densepose_attr = F.pad(densepose_attr, tuple(np.repeat(np.maximum(0, wh_diff[i]), 2)))
        # 3. Compute rotation grid and transform
        grid = F.affine_grid(rotation_matrix, size=densepose_attr.shape)
        densepose_attr = F.grid_sample(densepose_attr, grid)
        # 4. Compute right bounds and crop the densepose_attr to the size of the original bbox
        r_bds = densepose_attr.shape[2:][::-1] - l_bds[i]
        densepose_attr = densepose_attr[:, :, l_bds[i][1] : r_bds[1], l_bds[i][0] : r_bds[0]]
        if min(densepose_attr.shape) > 0:
            # Interpolate back to the original size of the densepose attribute
            densepose_attr = F.interpolate(
                densepose_attr, densepose_attrs.shape[-2:], mode="bilinear"
            )
            # Adding a very small probability to the background class to fill padded zones
            densepose_attr[:, 0] += 1e-10
            densepose_attrs[i] = densepose_attr
    return densepose_attrs
185
+
186
+
187
def rotate_box_inverse(rot_tfm, rotated_box):
    """
    Invert a rotation on axis-aligned boxes and shrink them to original size.

    rotated_box is an N * 4 array of [x0, y0, x1, y1] boxes. When a bbox is
    rotated it grows (the new box must enclose the tilted one), so a plain
    inverse rotation yields a box much bigger than the original. This inverts
    the rotation AND analytically resizes the result to the original extent.
    """
    # 1. Inverse-rotate the boxes; the result still over-covers the originals.
    inv_box = rot_tfm.inverse().apply_box(rotated_box)
    rot_h = rotated_box[:, 3] - rotated_box[:, 1]
    rot_w = rotated_box[:, 2] - rotated_box[:, 0]
    inv_h = inv_box[:, 3] - inv_box[:, 1]
    inv_w = inv_box[:, 2] - inv_box[:, 0]
    assert 2 * rot_tfm.abs_sin**2 != 1, "45 degrees angle can't be inverted"
    # 2. Invert the width/height expansion done by the rotation transform to
    # recover the original height/width of the boxes.
    denom = 1 - 2 * rot_tfm.abs_sin**2
    orig_h = (rot_h * rot_tfm.abs_cos - rot_w * rot_tfm.abs_sin) / denom
    orig_w = (rot_w * rot_tfm.abs_cos - rot_h * rot_tfm.abs_sin) / denom
    # 3. Shrink each box symmetrically around its center to the original size.
    dx = (inv_w - orig_w) / 2
    dy = (inv_h - orig_h) / 2
    inv_box[:, 0] += dx
    inv_box[:, 1] += dy
    inv_box[:, 2] -= dx
    inv_box[:, 3] -= dy

    return inv_box
CatVTON/densepose/modeling/utils.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from torch import nn
6
+
7
+
8
def initialize_module_params(module: nn.Module) -> None:
    """Initialize a module's parameters in place.

    Parameters whose name contains "bias" are zeroed; those containing
    "weight" get Kaiming-normal init (fan-out, ReLU). Anything else is left
    untouched. A name containing both is treated as a bias.
    """
    for param_name, param in module.named_parameters():
        if "bias" in param_name:
            nn.init.constant_(param, 0)
            continue
        if "weight" in param_name:
            nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
CatVTON/densepose/utils/__init__.py ADDED
File without changes
CatVTON/densepose/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (160 Bytes). View file
 
CatVTON/densepose/utils/__pycache__/transform.cpython-39.pyc ADDED
Binary file (733 Bytes). View file
 
CatVTON/densepose/utils/dbhelper.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from typing import Any, Dict, Optional, Tuple
5
+
6
+
7
class EntrySelector:
    """Base class for entry selectors."""

    @staticmethod
    def from_string(spec: str) -> "EntrySelector":
        """Build a selector from a textual spec; "*" accepts everything."""
        return AllEntrySelector() if spec == "*" else FieldEntrySelector(spec)
17
+
18
+
19
class AllEntrySelector(EntrySelector):
    """Selector that accepts every entry unconditionally."""

    # The spec string that denotes this selector.
    SPECIFIER = "*"

    def __call__(self, entry):
        # Every entry matches.
        return True
28
+
29
+
30
class FieldEntrySelector(EntrySelector):
    """
    Selector that accepts only entries matching the provided field
    specifier(s). Only a limited grammar is supported for now:
        <specifiers>::=<specifier>[<comma><specifiers>]
        <specifier>::=<field_name>[<type_delim><type>]<equal><value_or_range>
        <field_name> is a valid identifier
        <type> ::= "int" | "str"
        <equal> ::= "="
        <comma> ::= ","
        <type_delim> ::= ":"
        <value_or_range> ::= <value> | <range>
        <range> ::= <value><range_delim><value>
        <range_delim> ::= "-"
        <value> is a string without spaces and special symbols
        (e.g. <comma>, <equal>, <type_delim>, <range_delim>)
    """

    _SPEC_DELIM = ","
    _TYPE_DELIM = ":"
    _RANGE_DELIM = "-"
    _EQUAL = "="
    _ERROR_PREFIX = "Invalid field selector specifier"

    class _FieldEntryValuePredicate:
        """Strict equality check of one entry field against a typed value."""

        def __init__(self, name: str, typespec: Optional[str], value: str):
            import builtins

            self.name = name
            # Resolve "int"/"str" to the corresponding builtin; default to str.
            self.type = str if typespec is None else getattr(builtins, typespec)
            self.value = value

        def __call__(self, entry):
            return entry[self.name] == self.type(self.value)

    class _FieldEntryRangePredicate:
        """Inclusive range check of one entry field against typed bounds."""

        def __init__(self, name: str, typespec: Optional[str], vmin: str, vmax: str):
            import builtins

            self.name = name
            self.type = str if typespec is None else getattr(builtins, typespec)
            self.vmin = vmin
            self.vmax = vmax

        def __call__(self, entry):
            return self.type(self.vmin) <= entry[self.name] <= self.type(self.vmax)

    def __init__(self, spec: str):
        self._predicates = self._parse_specifier_into_predicates(spec)

    def __call__(self, entry: Dict[str, Any]):
        # An entry is accepted only if every predicate matches.
        return all(predicate(entry) for predicate in self._predicates)

    def _parse_specifier_into_predicates(self, spec: str):
        """Turn a comma-separated spec into a list of predicate callables."""
        predicates = []
        for subspec in spec.split(self._SPEC_DELIM):
            name_part, sep, value_part = subspec.partition(self._EQUAL)
            if not sep:
                self._parse_error(f'"{subspec}", should have format <field>=<value_or_range>!')
            if not name_part:
                self._parse_error(f'"{subspec}", field name is empty!')
            field_name, field_type = self._parse_field_name_type(name_part)
            if self._is_range_spec(value_part):
                vmin, vmax = self._get_range_spec(value_part)
                predicate = FieldEntrySelector._FieldEntryRangePredicate(
                    field_name, field_type, vmin, vmax
                )
            else:
                predicate = FieldEntrySelector._FieldEntryValuePredicate(
                    field_name, field_type, value_part
                )
            predicates.append(predicate)
        return predicates

    def _parse_field_name_type(self, field_name_with_type: str) -> Tuple[str, Optional[str]]:
        """Split "name[:type]" into (name, type-or-None)."""
        name, sep, typespec = field_name_with_type.partition(self._TYPE_DELIM)
        if sep and not name:
            self._parse_error(f'"{field_name_with_type}", field name is empty!')
        return (name, typespec) if sep else (name, None)

    def _is_range_spec(self, field_value_or_range):
        # A delimiter at position 0 would leave an empty lower bound, so it
        # does not count as a range.
        return field_value_or_range.find(self._RANGE_DELIM) > 0

    def _get_range_spec(self, field_value_or_range):
        """Split "vmin-vmax" at the first delimiter into its two bounds."""
        if not self._is_range_spec(field_value_or_range):
            self._parse_error('"field_value_or_range", range of values expected!')
        vmin, _, vmax = field_value_or_range.partition(self._RANGE_DELIM)
        return vmin, vmax

    def _parse_error(self, msg):
        raise ValueError(f"{self._ERROR_PREFIX}: {msg}")
CatVTON/densepose/utils/logger.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ import logging
5
+
6
+
7
def verbosity_to_level(verbosity) -> int:
    """Map a CLI verbosity count to a logging level.

    None, 0, or any unrecognized value falls back to WARNING; 1 maps to INFO
    and 2 or more maps to DEBUG.
    """
    if verbosity is None:
        return logging.WARNING
    if verbosity >= 2:
        return logging.DEBUG
    if verbosity == 1:
        return logging.INFO
    return logging.WARNING
CatVTON/densepose/utils/transform.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+ from detectron2.data import MetadataCatalog
5
+ from detectron2.utils.file_io import PathManager
6
+
7
+ from densepose import DensePoseTransformData
8
+
9
+
10
def load_for_dataset(dataset_name):
    """Load the DensePose UV-transform data registered for a dataset."""
    src = MetadataCatalog.get(dataset_name).densepose_transform_src
    local_fpath = PathManager.get_local_path(src)
    return DensePoseTransformData.load(local_fpath)
14
+
15
+
16
def load_from_cfg(cfg):
    """Load transform data for the first test dataset named in the config."""
    return load_for_dataset(cfg.DATASETS.TEST[0])
CatVTON/model/DensePose/__init__.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import glob
3
+ import os
4
+ from random import randint
5
+ import shutil
6
+ import time
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+ from densepose import add_densepose_config
13
+ from densepose.vis.base import CompoundVisualizer
14
+ from densepose.vis.densepose_results import DensePoseResultsFineSegmentationVisualizer
15
+ from densepose.vis.extractor import create_extractor, CompoundExtractor
16
+ from detectron2.config import get_cfg
17
+ from detectron2.data.detection_utils import read_image
18
+ from detectron2.engine.defaults import DefaultPredictor
19
+
20
+
21
class DensePose:
    """
    DensePose used in this project is from Detectron2 (https://github.com/facebookresearch/detectron2).
    These codes are modified from https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose.
    The checkpoint is downloaded from https://github.com/facebookresearch/detectron2/blob/main/projects/DensePose/doc/DENSEPOSE_IUV.md#ModelZoo.

    We use the model R_50_FPN_s1x with id 165712039, but other models should also work.
    The config file is downloaded from https://github.com/facebookresearch/detectron2/tree/main/projects/DensePose/configs.
    Noted that the config file should match the model checkpoint and Base-DensePose-RCNN-FPN.yaml is also needed.
    """

    def __init__(self, model_path="./checkpoints/densepose_", device="cuda"):
        # model_path is a directory expected to contain both the yaml config
        # and the pickled model-zoo checkpoint.
        self.device = device
        self.config_path = os.path.join(model_path, 'densepose_rcnn_R_50_FPN_s1x.yaml')
        self.model_path = os.path.join(model_path, 'model_final_162be9.pkl')
        # Only the fine-segmentation visualization is used to rasterize labels.
        self.visualizations = ["dp_segm"]
        self.VISUALIZERS = {"dp_segm": DensePoseResultsFineSegmentationVisualizer}
        # Detections scoring below this threshold are discarded by the predictor.
        self.min_score = 0.8

        self.cfg = self.setup_config()
        self.predictor = DefaultPredictor(self.cfg)
        self.predictor.model.to(self.device)

    def setup_config(self):
        """Build a frozen Detectron2 config with DensePose extensions applied."""
        opts = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(self.min_score)]
        cfg = get_cfg()
        add_densepose_config(cfg)
        cfg.merge_from_file(self.config_path)
        cfg.merge_from_list(opts)
        cfg.MODEL.WEIGHTS = self.model_path
        cfg.freeze()
        return cfg

    @staticmethod
    def _get_input_file_list(input_spec: str):
        # input_spec may be a directory, a single file, or a glob pattern.
        if os.path.isdir(input_spec):
            file_list = [os.path.join(input_spec, fname) for fname in os.listdir(input_spec)
                         if os.path.isfile(os.path.join(input_spec, fname))]
        elif os.path.isfile(input_spec):
            file_list = [input_spec]
        else:
            file_list = glob.glob(input_spec)
        return file_list

    def create_context(self, cfg, output_path):
        """Create the visualizer/extractor context shared across images."""
        vis_specs = self.visualizations
        visualizers = []
        extractors = []
        for vis_spec in vis_specs:
            texture_atlas = texture_atlases_dict = None
            vis = self.VISUALIZERS[vis_spec](
                cfg=cfg,
                texture_atlas=texture_atlas,
                texture_atlases_dict=texture_atlases_dict,
                alpha=1.0
            )
            visualizers.append(vis)
            extractor = create_extractor(vis)
            extractors.append(extractor)
        visualizer = CompoundVisualizer(visualizers)
        extractor = CompoundExtractor(extractors)
        context = {
            "extractor": extractor,
            "visualizer": visualizer,
            "out_fname": output_path,
            "entry_idx": 0,
        }
        return context

    def execute_on_outputs(self, context, entry, outputs):
        """Rasterize fine-segmentation labels of the first detection into a
        grayscale image saved at context["out_fname"]."""
        extractor = context["extractor"]

        data = extractor(outputs)

        H, W, _ = entry["image"].shape
        result = np.zeros((H, W), dtype=np.uint8)

        # Only the first detection is used.
        data, box = data[0]
        # NOTE(review): the paste below treats box as (x, y, w, h) — confirm
        # the extractor really returns XYWH and not XYXY boxes.
        x, y, w, h = [int(_) for _ in box[0].cpu().numpy()]
        i_array = data[0].labels[None].cpu().numpy()[0]
        result[y:y + h, x:x + w] = i_array
        result = Image.fromarray(result)
        result.save(context["out_fname"])

    def __call__(self, image_or_path, resize=512) -> Image.Image:
        """
        :param image_or_path: Path of the input image.
        :param resize: Resize the input image if its max size is larger than this value.
        :return: Dense pose image.
        """
        # random tmp path with timestamp
        tmp_path = f"./densepose_/tmp/"
        if not os.path.exists(tmp_path):
            os.makedirs(tmp_path)

        # Unique temp filename: timestamp + device + random suffix.
        image_path = os.path.join(tmp_path, f"{int(time.time())}-{self.device}-{randint(0, 100000)}.png")
        if isinstance(image_or_path, str):
            assert image_or_path.split(".")[-1] in ["jpg", "png"], "Only support jpg and png images."
            shutil.copy(image_or_path, image_path)
        elif isinstance(image_or_path, Image.Image):
            image_or_path.save(image_path)
        else:
            shutil.rmtree(tmp_path)
            raise TypeError("image_path must be str or PIL.Image.Image")

        output_path = image_path.replace(".png", "_dense.png").replace(".jpg", "_dense.png")
        w, h = Image.open(image_path).size

        file_list = self._get_input_file_list(image_path)
        assert len(file_list), "No input images found!"
        context = self.create_context(self.cfg, output_path)
        for file_name in file_list:
            img = read_image(file_name, format="BGR")  # predictor expects BGR image.
            # resize so the longest side is at most `resize`
            if (_ := max(img.shape)) > resize:
                scale = resize / _
                img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale)))

            with torch.no_grad():
                outputs = self.predictor(img)["instances"]
                try:
                    self.execute_on_outputs(context, {"file_name": file_name, "image": img}, outputs)
                except Exception as e:
                    # Best-effort fallback: write a 1x1 placeholder so the
                    # resize below still succeeds; the failure is swallowed.
                    null_gray = Image.new('L', (1, 1))
                    null_gray.save(output_path)

        dense_gray = Image.open(output_path).convert("L")
        dense_gray = dense_gray.resize((w, h), Image.NEAREST)
        # remove image_path and output_path
        os.remove(image_path)
        os.remove(output_path)

        return dense_gray


if __name__ == '__main__':
    pass
CatVTON/model/DensePose/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.85 kB). View file
 
CatVTON/model/DensePose/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (8.91 kB). View file
 
CatVTON/model/DensePose/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (5.83 kB). View file
 
CatVTON/model/SCHP/__init__.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.SCHP import networks
2
+ from model.SCHP.utils.transforms import get_affine_transform, transform_logits
3
+
4
+ from collections import OrderedDict
5
+ import torch
6
+ import numpy as np
7
+ import cv2
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
def get_palette(num_cls):
    """ Returns the color map for visualizing the segmentation mask.
    Args:
        num_cls: Number of classes
    Returns:
        The color map: a flat list of num_cls * 3 ints (R, G, B per class),
        where each class color is derived from the bits of its index.
    """
    palette = [0] * (num_cls * 3)
    for cls_id in range(num_cls):
        r = g = b = 0
        bits = cls_id
        shift = 7
        # Spread groups of three index bits across the high bits of R, G, B.
        while bits:
            r |= (bits & 1) << shift
            g |= ((bits >> 1) & 1) << shift
            b |= ((bits >> 2) & 1) << shift
            bits >>= 3
            shift -= 1
        base = cls_id * 3
        palette[base] = r
        palette[base + 1] = g
        palette[base + 2] = b
    return palette
33
+
34
# Per-dataset parsing configuration: network input resolution [H, W], number
# of semantic classes, and the human-readable label for each class index.
dataset_settings = {
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
                  'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
                  'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
    },
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
                  'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    },
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
    }
}
54
+
55
class SCHP:
    """Self-Correction Human Parsing (SCHP) wrapper.

    Loads a pretrained resnet101 parsing network and produces per-pixel
    human-parsing label maps as palettized PIL images. The dataset flavor
    (lip / atr / pascal) is inferred from the checkpoint filename.
    """

    def __init__(self, ckpt_path, device):
        # Infer which dataset the checkpoint was trained on from its path.
        dataset_type = None
        if 'lip' in ckpt_path:
            dataset_type = 'lip'
        elif 'atr' in ckpt_path:
            dataset_type = 'atr'
        elif 'pascal' in ckpt_path:
            dataset_type = 'pascal'
        assert dataset_type is not None, 'Dataset type not found in checkpoint path'
        self.device = device
        self.num_classes = dataset_settings[dataset_type]['num_classes']
        self.input_size = dataset_settings[dataset_type]['input_size']
        # Width / height ratio used to fit arbitrary images into input_size.
        self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
        self.palette = get_palette(self.num_classes)

        self.label = dataset_settings[dataset_type]['label']
        self.model = networks.init_model('resnet101', num_classes=self.num_classes, pretrained=None).to(device)
        self.load_ckpt(ckpt_path)
        self.model.eval()

        # NOTE(review): mean/std look like BGR-ordered stats (cv2 images), but
        # the PIL branch of preprocess() yields RGB arrays — confirm intended.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
        ])
        self.upsample = torch.nn.Upsample(size=self.input_size, mode='bilinear', align_corners=True)

    def load_ckpt(self, ckpt_path):
        """Load a DataParallel checkpoint: strip the 'module.' prefix and
        remap a few renamed decoder/fusion layer keys (loaded non-strictly)."""
        rename_map = {
            "decoder.conv3.2.weight": "decoder.conv3.3.weight",
            "decoder.conv3.3.weight": "decoder.conv3.4.weight",
            "decoder.conv3.3.bias": "decoder.conv3.4.bias",
            "decoder.conv3.3.running_mean": "decoder.conv3.4.running_mean",
            "decoder.conv3.3.running_var": "decoder.conv3.4.running_var",
            "fushion.3.weight": "fushion.4.weight",
            "fushion.3.bias": "fushion.4.bias",
        }
        state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        new_state_dict_ = OrderedDict()
        for k, v in list(new_state_dict.items()):
            if k in rename_map:
                new_state_dict_[rename_map[k]] = v
            else:
                new_state_dict_[k] = v
        self.model.load_state_dict(new_state_dict_, strict=False)

    def _box2cs(self, box):
        # box is [x, y, w, h]; convert to the center/scale representation.
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # Expand the box around its center to match the network aspect ratio.
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def preprocess(self, image):
        """Warp an image (path or PIL.Image) into the network input tensor.

        Returns the (1, C, H, W) tensor on self.device and the affine metadata
        needed to map logits back to the original resolution.
        """
        if isinstance(image, str):
            img = cv2.imread(image, cv2.IMREAD_COLOR)
        elif isinstance(image, Image.Image):
            # to cv2 format
            img = np.array(image)
        # NOTE(review): any other input type leaves `img` unbound and raises
        # NameError below — an explicit TypeError would be clearer.

        h, w, _ = img.shape
        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        input = self.transform(input).to(self.device).unsqueeze(0)
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }
        return input, meta

    def __call__(self, image_or_path):
        """Parse one image or a list of images.

        Returns a palettized PIL image of per-pixel class indices, or a list
        of such images when a list was given.
        """
        if isinstance(image_or_path, list):
            # Batch mode: preprocess each image and stack into one batch.
            image_list = []
            meta_list = []
            for image in image_or_path:
                image, meta = self.preprocess(image)
                image_list.append(image)
                meta_list.append(meta)
            image = torch.cat(image_list, dim=0)
        else:
            image, meta = self.preprocess(image_or_path)
            meta_list = [meta]

        output = self.model(image)
        # upsample_outputs = self.upsample(output[0][-1])
        upsample_outputs = self.upsample(output)
        upsample_outputs = upsample_outputs.permute(0, 2, 3, 1)  # BCHW -> BHWC

        output_img_list = []
        for upsample_output, meta in zip(upsample_outputs, meta_list):
            c, s, w, h = meta['center'], meta['scale'], meta['width'], meta['height']
            # Undo the input affine warp so logits align with the source image.
            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=self.input_size)
            parsing_result = np.argmax(logits_result, axis=2)
            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(self.palette)
            output_img_list.append(output_img)

        return output_img_list[0] if len(output_img_list) == 1 else output_img_list