b3h-young123 commited on
Commit
88c3efd
·
verified ·
1 Parent(s): 21e0fdd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. OOTDiffusion/checkpoints/ootd/ootd_dc/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors +3 -0
  2. OOTDiffusion/checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm/diffusion_pytorch_model.safetensors +3 -0
  3. OOTDiffusion/checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors +3 -0
  4. OOTDiffusion/checkpoints/ootd/text_encoder/pytorch_model.bin +3 -0
  5. OOTDiffusion/ootd/inference_ootd.py +133 -0
  6. OOTDiffusion/ootd/inference_ootd_dc.py +132 -0
  7. OOTDiffusion/ootd/inference_ootd_hd.py +132 -0
  8. OOTDiffusion/ootd/pipelines_ootd/attention_garm.py +402 -0
  9. OOTDiffusion/ootd/pipelines_ootd/attention_vton.py +407 -0
  10. OOTDiffusion/ootd/pipelines_ootd/pipeline_ootd.py +846 -0
  11. OOTDiffusion/ootd/pipelines_ootd/transformer_garm_2d.py +449 -0
  12. OOTDiffusion/ootd/pipelines_ootd/transformer_vton_2d.py +452 -0
  13. OOTDiffusion/ootd/pipelines_ootd/unet_garm_2d_blocks.py +0 -0
  14. OOTDiffusion/ootd/pipelines_ootd/unet_garm_2d_condition.py +1183 -0
  15. OOTDiffusion/ootd/pipelines_ootd/unet_vton_2d_blocks.py +0 -0
  16. OOTDiffusion/ootd/pipelines_ootd/unet_vton_2d_condition.py +1183 -0
  17. OOTDiffusion/preprocess/humanparsing/datasets/__init__.py +0 -0
  18. OOTDiffusion/preprocess/humanparsing/datasets/datasets.py +201 -0
  19. OOTDiffusion/preprocess/humanparsing/datasets/simple_extractor_dataset.py +89 -0
  20. OOTDiffusion/preprocess/humanparsing/datasets/target_generation.py +40 -0
  21. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/CODE_OF_CONDUCT.md +5 -0
  22. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/CONTRIBUTING.md +49 -0
  23. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/Detectron2-Logo-Horz.svg +1 -0
  24. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE.md +5 -0
  25. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/bugs.md +36 -0
  26. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/config.yml +9 -0
  27. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/feature-request.md +31 -0
  28. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/questions-help-support.md +26 -0
  29. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md +45 -0
  30. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/pull_request_template.md +9 -0
  31. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/Dockerfile +49 -0
  32. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/Dockerfile-circleci +17 -0
  33. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/README.md +36 -0
  34. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/docker-compose.yml +18 -0
  35. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docs/tutorials/datasets.md +221 -0
  36. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docs/tutorials/evaluation.md +43 -0
  37. OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docs/tutorials/index.rst +18 -0
  38. OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/global_local_datasets.py +200 -0
  39. OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/global_local_evaluate.py +210 -0
  40. OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/global_local_train.py +232 -0
  41. OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/make_id_list.py +13 -0
  42. OOTDiffusion/run/examples/garment/00055_00.jpg +3 -0
  43. OOTDiffusion/run/examples/garment/00126_00.jpg +3 -0
  44. OOTDiffusion/run/examples/garment/00151_00.jpg +3 -0
  45. OOTDiffusion/run/examples/garment/00470_00.jpg +3 -0
  46. OOTDiffusion/run/examples/garment/02015_00.jpg +3 -0
  47. OOTDiffusion/run/examples/garment/02305_00.jpg +3 -0
  48. OOTDiffusion/run/examples/garment/03032_00.jpg +3 -0
  49. OOTDiffusion/run/examples/garment/03244_00.jpg +3 -0
  50. OOTDiffusion/run/examples/garment/04825_00.jpg +3 -0
OOTDiffusion/checkpoints/ootd/ootd_dc/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b3cb1398172757fe1f49c130d104ec4da8d20d2132958dfff0748a2b6a7506b
3
+ size 3438213624
OOTDiffusion/checkpoints/ootd/ootd_hd/checkpoint-36000/unet_garm/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dea03c6b3339f13e1432711608d5c7ac83fcb9b14a430aee52b0015834ba41da
3
+ size 3438167536
OOTDiffusion/checkpoints/ootd/ootd_hd/checkpoint-36000/unet_vton/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3587b5025565060842eac78c74f87fc06d8b82c2b51d9938a492d42858679fe
3
+ size 3438213624
OOTDiffusion/checkpoints/ootd/text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67
3
+ size 492305335
OOTDiffusion/ootd/inference_ootd.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ from pathlib import Path
3
+ import sys
4
+ PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
5
+ sys.path.insert(0, str(PROJECT_ROOT))
6
+ import os
7
+
8
+ import torch
9
+ import numpy as np
10
+ from PIL import Image
11
+ import cv2
12
+
13
+ import random
14
+ import time
15
+ import pdb
16
+
17
+ from pipelines_ootd.pipeline_ootd import OotdPipeline
18
+ from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
19
+ from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
20
+ from diffusers import UniPCMultistepScheduler
21
+ from diffusers import AutoencoderKL
22
+
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
26
+ from transformers import CLIPTextModel, CLIPTokenizer
27
+
28
+ VIT_PATH = "openai/clip-vit-large-patch14"
29
+ VAE_PATH = "./checkpoints/ootd"
30
+ UNET_PATH = "./checkpoints/ootd/ootd_hd/checkpoint-36000"
31
+ MODEL_PATH = "./checkpoints/ootd"
32
+
33
+ class OOTDiffusion:
34
+
35
+ def __init__(self, gpu_id):
36
+ # self.gpu_id = 'cuda:' + str(gpu_id)
37
+
38
+ vae = AutoencoderKL.from_pretrained(
39
+ VAE_PATH,
40
+ subfolder="vae",
41
+ torch_dtype=torch.float16,
42
+ )
43
+
44
+ unet_garm = UNetGarm2DConditionModel.from_pretrained(
45
+ UNET_PATH,
46
+ subfolder="unet_garm",
47
+ torch_dtype=torch.float16,
48
+ use_safetensors=True,
49
+ )
50
+ unet_vton = UNetVton2DConditionModel.from_pretrained(
51
+ UNET_PATH,
52
+ subfolder="unet_vton",
53
+ torch_dtype=torch.float16,
54
+ use_safetensors=True,
55
+ )
56
+
57
+ self.pipe = OotdPipeline.from_pretrained(
58
+ MODEL_PATH,
59
+ unet_garm=unet_garm,
60
+ unet_vton=unet_vton,
61
+ vae=vae,
62
+ torch_dtype=torch.float16,
63
+ variant="fp16",
64
+ use_safetensors=True,
65
+ safety_checker=None,
66
+ requires_safety_checker=False,
67
+ )#.to(self.gpu_id)
68
+
69
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
70
+
71
+ self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
72
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
73
+
74
+ self.tokenizer = CLIPTokenizer.from_pretrained(
75
+ MODEL_PATH,
76
+ subfolder="tokenizer",
77
+ )
78
+ self.text_encoder = CLIPTextModel.from_pretrained(
79
+ MODEL_PATH,
80
+ subfolder="text_encoder",
81
+ )#.to(self.gpu_id)
82
+
83
+
84
+ def tokenize_captions(self, captions, max_length):
85
+ inputs = self.tokenizer(
86
+ captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
87
+ )
88
+ return inputs.input_ids
89
+
90
+
91
+ def __call__(self,
92
+ model_type='hd',
93
+ category='upperbody',
94
+ image_garm=None,
95
+ image_vton=None,
96
+ mask=None,
97
+ image_ori=None,
98
+ num_samples=1,
99
+ num_steps=20,
100
+ image_scale=1.0,
101
+ seed=-1,
102
+ ):
103
+ if seed == -1:
104
+ random.seed(time.time())
105
+ seed = random.randint(0, 2147483647)
106
+ print('Initial seed: ' + str(seed))
107
+ generator = torch.manual_seed(seed)
108
+
109
+ with torch.no_grad():
110
+ prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
111
+ prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
112
+ prompt_image = prompt_image.unsqueeze(1)
113
+ if model_type == 'hd':
114
+ prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
115
+ prompt_embeds[:, 1:] = prompt_image[:]
116
+ elif model_type == 'dc':
117
+ prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
118
+ prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
119
+ else:
120
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
121
+
122
+ images = self.pipe(prompt_embeds=prompt_embeds,
123
+ image_garm=image_garm,
124
+ image_vton=image_vton,
125
+ mask=mask,
126
+ image_ori=image_ori,
127
+ num_inference_steps=num_steps,
128
+ image_guidance_scale=image_scale,
129
+ num_images_per_prompt=num_samples,
130
+ generator=generator,
131
+ ).images
132
+
133
+ return images
OOTDiffusion/ootd/inference_ootd_dc.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ from pathlib import Path
3
+ import sys
4
+ PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
5
+ sys.path.insert(0, str(PROJECT_ROOT))
6
+ import os
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ import cv2
11
+
12
+ import random
13
+ import time
14
+ import pdb
15
+
16
+ from pipelines_ootd.pipeline_ootd import OotdPipeline
17
+ from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
18
+ from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
19
+ from diffusers import UniPCMultistepScheduler
20
+ from diffusers import AutoencoderKL
21
+
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
25
+ from transformers import CLIPTextModel, CLIPTokenizer
26
+
27
+ VIT_PATH = "openai/clip-vit-large-patch14"
28
+ VAE_PATH = "./checkpoints/ootd"
29
+ UNET_PATH = "./checkpoints/ootd/ootd_dc/checkpoint-36000"
30
+ MODEL_PATH = "./checkpoints/ootd"
31
+
32
+ class OOTDiffusionDC:
33
+
34
+ def __init__(self, gpu_id):
35
+ # self.gpu_id = 'cuda:' + str(gpu_id)
36
+
37
+ vae = AutoencoderKL.from_pretrained(
38
+ VAE_PATH,
39
+ subfolder="vae",
40
+ torch_dtype=torch.float16,
41
+ )
42
+
43
+ unet_garm = UNetGarm2DConditionModel.from_pretrained(
44
+ UNET_PATH,
45
+ subfolder="unet_garm",
46
+ torch_dtype=torch.float16,
47
+ use_safetensors=True,
48
+ )
49
+ unet_vton = UNetVton2DConditionModel.from_pretrained(
50
+ UNET_PATH,
51
+ subfolder="unet_vton",
52
+ torch_dtype=torch.float16,
53
+ use_safetensors=True,
54
+ )
55
+
56
+ self.pipe = OotdPipeline.from_pretrained(
57
+ MODEL_PATH,
58
+ unet_garm=unet_garm,
59
+ unet_vton=unet_vton,
60
+ vae=vae,
61
+ torch_dtype=torch.float16,
62
+ variant="fp16",
63
+ use_safetensors=True,
64
+ safety_checker=None,
65
+ requires_safety_checker=False,
66
+ )#.to(self.gpu_id)
67
+
68
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
+
70
+ self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
71
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
72
+
73
+ self.tokenizer = CLIPTokenizer.from_pretrained(
74
+ MODEL_PATH,
75
+ subfolder="tokenizer",
76
+ )
77
+ self.text_encoder = CLIPTextModel.from_pretrained(
78
+ MODEL_PATH,
79
+ subfolder="text_encoder",
80
+ )#.to(self.gpu_id)
81
+
82
+
83
+ def tokenize_captions(self, captions, max_length):
84
+ inputs = self.tokenizer(
85
+ captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
86
+ )
87
+ return inputs.input_ids
88
+
89
+
90
+ def __call__(self,
91
+ model_type='hd',
92
+ category='upperbody',
93
+ image_garm=None,
94
+ image_vton=None,
95
+ mask=None,
96
+ image_ori=None,
97
+ num_samples=1,
98
+ num_steps=20,
99
+ image_scale=1.0,
100
+ seed=-1,
101
+ ):
102
+ if seed == -1:
103
+ random.seed(time.time())
104
+ seed = random.randint(0, 2147483647)
105
+ print('Initial seed: ' + str(seed))
106
+ generator = torch.manual_seed(seed)
107
+
108
+ with torch.no_grad():
109
+ prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
110
+ prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
111
+ prompt_image = prompt_image.unsqueeze(1)
112
+ if model_type == 'hd':
113
+ prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
114
+ prompt_embeds[:, 1:] = prompt_image[:]
115
+ elif model_type == 'dc':
116
+ prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
117
+ prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
118
+ else:
119
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
120
+
121
+ images = self.pipe(prompt_embeds=prompt_embeds,
122
+ image_garm=image_garm,
123
+ image_vton=image_vton,
124
+ mask=mask,
125
+ image_ori=image_ori,
126
+ num_inference_steps=num_steps,
127
+ image_guidance_scale=image_scale,
128
+ num_images_per_prompt=num_samples,
129
+ generator=generator,
130
+ ).images
131
+
132
+ return images
OOTDiffusion/ootd/inference_ootd_hd.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ from pathlib import Path
3
+ import sys
4
+ PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
5
+ sys.path.insert(0, str(PROJECT_ROOT))
6
+ import os
7
+ import torch
8
+ import numpy as np
9
+ from PIL import Image
10
+ import cv2
11
+
12
+ import random
13
+ import time
14
+ import pdb
15
+
16
+ from pipelines_ootd.pipeline_ootd import OotdPipeline
17
+ from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
18
+ from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
19
+ from diffusers import UniPCMultistepScheduler
20
+ from diffusers import AutoencoderKL
21
+
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
25
+ from transformers import CLIPTextModel, CLIPTokenizer
26
+
27
+ VIT_PATH = "openai/clip-vit-large-patch14"
28
+ VAE_PATH = "./checkpoints/ootd"
29
+ UNET_PATH = "./checkpoints/ootd/ootd_hd/checkpoint-36000"
30
+ MODEL_PATH = "./checkpoints/ootd"
31
+
32
+ class OOTDiffusionHD:
33
+
34
+ def __init__(self, gpu_id):
35
+ # self.gpu_id = 'cuda:' + str(gpu_id)
36
+
37
+ vae = AutoencoderKL.from_pretrained(
38
+ VAE_PATH,
39
+ subfolder="vae",
40
+ torch_dtype=torch.float16,
41
+ )
42
+
43
+ unet_garm = UNetGarm2DConditionModel.from_pretrained(
44
+ UNET_PATH,
45
+ subfolder="unet_garm",
46
+ torch_dtype=torch.float16,
47
+ use_safetensors=True,
48
+ )
49
+ unet_vton = UNetVton2DConditionModel.from_pretrained(
50
+ UNET_PATH,
51
+ subfolder="unet_vton",
52
+ torch_dtype=torch.float16,
53
+ use_safetensors=True,
54
+ )
55
+
56
+ self.pipe = OotdPipeline.from_pretrained(
57
+ MODEL_PATH,
58
+ unet_garm=unet_garm,
59
+ unet_vton=unet_vton,
60
+ vae=vae,
61
+ torch_dtype=torch.float16,
62
+ variant="fp16",
63
+ use_safetensors=True,
64
+ safety_checker=None,
65
+ requires_safety_checker=False,
66
+ )#.to(self.gpu_id)
67
+
68
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
69
+
70
+ self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
71
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
72
+
73
+ self.tokenizer = CLIPTokenizer.from_pretrained(
74
+ MODEL_PATH,
75
+ subfolder="tokenizer",
76
+ )
77
+ self.text_encoder = CLIPTextModel.from_pretrained(
78
+ MODEL_PATH,
79
+ subfolder="text_encoder",
80
+ )#.to(self.gpu_id)
81
+
82
+
83
+ def tokenize_captions(self, captions, max_length):
84
+ inputs = self.tokenizer(
85
+ captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
86
+ )
87
+ return inputs.input_ids
88
+
89
+
90
+ def __call__(self,
91
+ model_type='hd',
92
+ category='upperbody',
93
+ image_garm=None,
94
+ image_vton=None,
95
+ mask=None,
96
+ image_ori=None,
97
+ num_samples=1,
98
+ num_steps=20,
99
+ image_scale=1.0,
100
+ seed=-1,
101
+ ):
102
+ if seed == -1:
103
+ random.seed(time.time())
104
+ seed = random.randint(0, 2147483647)
105
+ print('Initial seed: ' + str(seed))
106
+ generator = torch.manual_seed(seed)
107
+
108
+ with torch.no_grad():
109
+ prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
110
+ prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
111
+ prompt_image = prompt_image.unsqueeze(1)
112
+ if model_type == 'hd':
113
+ prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
114
+ prompt_embeds[:, 1:] = prompt_image[:]
115
+ elif model_type == 'dc':
116
+ prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
117
+ prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
118
+ else:
119
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
120
+
121
+ images = self.pipe(prompt_embeds=prompt_embeds,
122
+ image_garm=image_garm,
123
+ image_vton=image_vton,
124
+ mask=mask,
125
+ image_ori=image_ori,
126
+ num_inference_steps=num_steps,
127
+ image_guidance_scale=image_scale,
128
+ num_images_per_prompt=num_samples,
129
+ generator=generator,
130
+ ).images
131
+
132
+ return images
OOTDiffusion/ootd/pipelines_ootd/attention_garm.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from diffusers.utils import USE_PEFT_BACKEND
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
24
+ from diffusers.models.attention_processor import Attention
25
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
26
+ from diffusers.models.lora import LoRACompatibleLinear
27
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
28
+
29
+
30
+ @maybe_allow_in_graph
31
+ class GatedSelfAttentionDense(nn.Module):
32
+ r"""
33
+ A gated self-attention dense layer that combines visual features and object features.
34
+
35
+ Parameters:
36
+ query_dim (`int`): The number of channels in the query.
37
+ context_dim (`int`): The number of channels in the context.
38
+ n_heads (`int`): The number of heads to use for attention.
39
+ d_head (`int`): The number of channels in each head.
40
+ """
41
+
42
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
43
+ super().__init__()
44
+
45
+ # we need a linear projection since we need cat visual feature and obj feature
46
+ self.linear = nn.Linear(context_dim, query_dim)
47
+
48
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
49
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
50
+
51
+ self.norm1 = nn.LayerNorm(query_dim)
52
+ self.norm2 = nn.LayerNorm(query_dim)
53
+
54
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
55
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
56
+
57
+ self.enabled = True
58
+
59
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
60
+ if not self.enabled:
61
+ return x
62
+
63
+ n_visual = x.shape[1]
64
+ objs = self.linear(objs)
65
+
66
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
67
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
68
+
69
+ return x
70
+
71
+
72
+ @maybe_allow_in_graph
73
+ class BasicTransformerBlock(nn.Module):
74
+ r"""
75
+ A basic Transformer block.
76
+
77
+ Parameters:
78
+ dim (`int`): The number of channels in the input and output.
79
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
80
+ attention_head_dim (`int`): The number of channels in each head.
81
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
82
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
83
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
+ num_embeds_ada_norm (:
85
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
86
+ attention_bias (:
87
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
88
+ only_cross_attention (`bool`, *optional*):
89
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
90
+ double_self_attention (`bool`, *optional*):
91
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
92
+ upcast_attention (`bool`, *optional*):
93
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
94
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
95
+ Whether to use learnable elementwise affine parameters for normalization.
96
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
97
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
98
+ final_dropout (`bool` *optional*, defaults to False):
99
+ Whether to apply a final dropout after the last feed-forward layer.
100
+ attention_type (`str`, *optional*, defaults to `"default"`):
101
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
102
+ positional_embeddings (`str`, *optional*, defaults to `None`):
103
+ The type of positional embeddings to apply to.
104
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
105
+ The maximum number of positional embeddings to apply.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim: int,
111
+ num_attention_heads: int,
112
+ attention_head_dim: int,
113
+ dropout=0.0,
114
+ cross_attention_dim: Optional[int] = None,
115
+ activation_fn: str = "geglu",
116
+ num_embeds_ada_norm: Optional[int] = None,
117
+ attention_bias: bool = False,
118
+ only_cross_attention: bool = False,
119
+ double_self_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ norm_elementwise_affine: bool = True,
122
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
123
+ norm_eps: float = 1e-5,
124
+ final_dropout: bool = False,
125
+ attention_type: str = "default",
126
+ positional_embeddings: Optional[str] = None,
127
+ num_positional_embeddings: Optional[int] = None,
128
+ ):
129
+ super().__init__()
130
+ self.only_cross_attention = only_cross_attention
131
+
132
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
133
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
134
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
135
+ self.use_layer_norm = norm_type == "layer_norm"
136
+
137
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
138
+ raise ValueError(
139
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
140
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
141
+ )
142
+
143
+ if positional_embeddings and (num_positional_embeddings is None):
144
+ raise ValueError(
145
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
146
+ )
147
+
148
+ if positional_embeddings == "sinusoidal":
149
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
150
+ else:
151
+ self.pos_embed = None
152
+
153
+ # Define 3 blocks. Each block has its own normalization layer.
154
+ # 1. Self-Attn
155
+ if self.use_ada_layer_norm:
156
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
157
+ elif self.use_ada_layer_norm_zero:
158
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
159
+ else:
160
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
161
+
162
+ self.attn1 = Attention(
163
+ query_dim=dim,
164
+ heads=num_attention_heads,
165
+ dim_head=attention_head_dim,
166
+ dropout=dropout,
167
+ bias=attention_bias,
168
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
169
+ upcast_attention=upcast_attention,
170
+ )
171
+
172
+ # 2. Cross-Attn
173
+ if cross_attention_dim is not None or double_self_attention:
174
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
175
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
176
+ # the second cross attention block.
177
+ self.norm2 = (
178
+ AdaLayerNorm(dim, num_embeds_ada_norm)
179
+ if self.use_ada_layer_norm
180
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
181
+ )
182
+ self.attn2 = Attention(
183
+ query_dim=dim,
184
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
185
+ heads=num_attention_heads,
186
+ dim_head=attention_head_dim,
187
+ dropout=dropout,
188
+ bias=attention_bias,
189
+ upcast_attention=upcast_attention,
190
+ ) # is self-attn if encoder_hidden_states is none
191
+ else:
192
+ self.norm2 = None
193
+ self.attn2 = None
194
+
195
+ # 3. Feed-forward
196
+ if not self.use_ada_layer_norm_single:
197
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
198
+
199
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
200
+
201
+ # 4. Fuser
202
+ if attention_type == "gated" or attention_type == "gated-text-image":
203
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
204
+
205
+ # 5. Scale-shift for PixArt-Alpha.
206
+ if self.use_ada_layer_norm_single:
207
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
208
+
209
+ # let chunk size default to None
210
+ self._chunk_size = None
211
+ self._chunk_dim = 0
212
+
213
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
214
+ # Sets chunk feed-forward
215
+ self._chunk_size = chunk_size
216
+ self._chunk_dim = dim
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states: torch.FloatTensor,
221
+ spatial_attn_inputs = [],
222
+ attention_mask: Optional[torch.FloatTensor] = None,
223
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
224
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
225
+ timestep: Optional[torch.LongTensor] = None,
226
+ cross_attention_kwargs: Dict[str, Any] = None,
227
+ class_labels: Optional[torch.LongTensor] = None,
228
+ ) -> torch.FloatTensor:
229
+ # Notice that normalization is always applied before the real computation in the following blocks.
230
+ # 0. Self-Attention
231
+ batch_size = hidden_states.shape[0]
232
+
233
+ spatial_attn_input = hidden_states
234
+ spatial_attn_inputs.append(spatial_attn_input)
235
+
236
+ if self.use_ada_layer_norm:
237
+ norm_hidden_states = self.norm1(hidden_states, timestep)
238
+ elif self.use_ada_layer_norm_zero:
239
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
240
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
241
+ )
242
+ elif self.use_layer_norm:
243
+ norm_hidden_states = self.norm1(hidden_states)
244
+ elif self.use_ada_layer_norm_single:
245
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
246
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
247
+ ).chunk(6, dim=1)
248
+ norm_hidden_states = self.norm1(hidden_states)
249
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
250
+ norm_hidden_states = norm_hidden_states.squeeze(1)
251
+ else:
252
+ raise ValueError("Incorrect norm used")
253
+
254
+ if self.pos_embed is not None:
255
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
256
+
257
+ # 1. Retrieve lora scale.
258
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
259
+
260
+ # 2. Prepare GLIGEN inputs
261
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
262
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
263
+
264
+ attn_output = self.attn1(
265
+ norm_hidden_states,
266
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
267
+ attention_mask=attention_mask,
268
+ **cross_attention_kwargs,
269
+ )
270
+ if self.use_ada_layer_norm_zero:
271
+ attn_output = gate_msa.unsqueeze(1) * attn_output
272
+ elif self.use_ada_layer_norm_single:
273
+ attn_output = gate_msa * attn_output
274
+
275
+ hidden_states = attn_output + hidden_states
276
+ if hidden_states.ndim == 4:
277
+ hidden_states = hidden_states.squeeze(1)
278
+
279
+ # 2.5 GLIGEN Control
280
+ if gligen_kwargs is not None:
281
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
282
+
283
+ # 3. Cross-Attention
284
+ if self.attn2 is not None:
285
+ if self.use_ada_layer_norm:
286
+ norm_hidden_states = self.norm2(hidden_states, timestep)
287
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
288
+ norm_hidden_states = self.norm2(hidden_states)
289
+ elif self.use_ada_layer_norm_single:
290
+ # For PixArt norm2 isn't applied here:
291
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
292
+ norm_hidden_states = hidden_states
293
+ else:
294
+ raise ValueError("Incorrect norm")
295
+
296
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
297
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
298
+
299
+ attn_output = self.attn2(
300
+ norm_hidden_states,
301
+ encoder_hidden_states=encoder_hidden_states,
302
+ attention_mask=encoder_attention_mask,
303
+ **cross_attention_kwargs,
304
+ )
305
+ hidden_states = attn_output + hidden_states
306
+
307
+ # 4. Feed-forward
308
+ if not self.use_ada_layer_norm_single:
309
+ norm_hidden_states = self.norm3(hidden_states)
310
+
311
+ if self.use_ada_layer_norm_zero:
312
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
313
+
314
+ if self.use_ada_layer_norm_single:
315
+ norm_hidden_states = self.norm2(hidden_states)
316
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
317
+
318
+ if self._chunk_size is not None:
319
+ # "feed_forward_chunk_size" can be used to save memory
320
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
321
+ raise ValueError(
322
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
323
+ )
324
+
325
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
326
+ ff_output = torch.cat(
327
+ [
328
+ self.ff(hid_slice, scale=lora_scale)
329
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
330
+ ],
331
+ dim=self._chunk_dim,
332
+ )
333
+ else:
334
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
335
+
336
+ if self.use_ada_layer_norm_zero:
337
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
338
+ elif self.use_ada_layer_norm_single:
339
+ ff_output = gate_mlp * ff_output
340
+
341
+ hidden_states = ff_output + hidden_states
342
+ if hidden_states.ndim == 4:
343
+ hidden_states = hidden_states.squeeze(1)
344
+
345
+ return hidden_states, spatial_attn_inputs
346
+
347
+
348
+ class FeedForward(nn.Module):
349
+ r"""
350
+ A feed-forward layer.
351
+
352
+ Parameters:
353
+ dim (`int`): The number of channels in the input.
354
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
355
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
356
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
357
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
358
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
359
+ """
360
+
361
+ def __init__(
362
+ self,
363
+ dim: int,
364
+ dim_out: Optional[int] = None,
365
+ mult: int = 4,
366
+ dropout: float = 0.0,
367
+ activation_fn: str = "geglu",
368
+ final_dropout: bool = False,
369
+ ):
370
+ super().__init__()
371
+ inner_dim = int(dim * mult)
372
+ dim_out = dim_out if dim_out is not None else dim
373
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
374
+
375
+ if activation_fn == "gelu":
376
+ act_fn = GELU(dim, inner_dim)
377
+ if activation_fn == "gelu-approximate":
378
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
379
+ elif activation_fn == "geglu":
380
+ act_fn = GEGLU(dim, inner_dim)
381
+ elif activation_fn == "geglu-approximate":
382
+ act_fn = ApproximateGELU(dim, inner_dim)
383
+
384
+ self.net = nn.ModuleList([])
385
+ # project in
386
+ self.net.append(act_fn)
387
+ # project dropout
388
+ self.net.append(nn.Dropout(dropout))
389
+ # project out
390
+ self.net.append(linear_cls(inner_dim, dim_out))
391
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
392
+ if final_dropout:
393
+ self.net.append(nn.Dropout(dropout))
394
+
395
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
396
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
397
+ for module in self.net:
398
+ if isinstance(module, compatible_cls):
399
+ hidden_states = module(hidden_states, scale)
400
+ else:
401
+ hidden_states = module(hidden_states)
402
+ return hidden_states
OOTDiffusion/ootd/pipelines_ootd/attention_vton.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from diffusers.utils import USE_PEFT_BACKEND
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
24
+ from diffusers.models.attention_processor import Attention
25
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
26
+ from diffusers.models.lora import LoRACompatibleLinear
27
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
28
+
29
+
30
+ @maybe_allow_in_graph
31
+ class GatedSelfAttentionDense(nn.Module):
32
+ r"""
33
+ A gated self-attention dense layer that combines visual features and object features.
34
+
35
+ Parameters:
36
+ query_dim (`int`): The number of channels in the query.
37
+ context_dim (`int`): The number of channels in the context.
38
+ n_heads (`int`): The number of heads to use for attention.
39
+ d_head (`int`): The number of channels in each head.
40
+ """
41
+
42
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
43
+ super().__init__()
44
+
45
+ # we need a linear projection since we need cat visual feature and obj feature
46
+ self.linear = nn.Linear(context_dim, query_dim)
47
+
48
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
49
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
50
+
51
+ self.norm1 = nn.LayerNorm(query_dim)
52
+ self.norm2 = nn.LayerNorm(query_dim)
53
+
54
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
55
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
56
+
57
+ self.enabled = True
58
+
59
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
60
+ if not self.enabled:
61
+ return x
62
+
63
+ n_visual = x.shape[1]
64
+ objs = self.linear(objs)
65
+
66
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
67
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
68
+
69
+ return x
70
+
71
+
72
+ @maybe_allow_in_graph
73
+ class BasicTransformerBlock(nn.Module):
74
+ r"""
75
+ A basic Transformer block.
76
+
77
+ Parameters:
78
+ dim (`int`): The number of channels in the input and output.
79
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
80
+ attention_head_dim (`int`): The number of channels in each head.
81
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
82
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
83
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
84
+ num_embeds_ada_norm (:
85
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
86
+ attention_bias (:
87
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
88
+ only_cross_attention (`bool`, *optional*):
89
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
90
+ double_self_attention (`bool`, *optional*):
91
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
92
+ upcast_attention (`bool`, *optional*):
93
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
94
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
95
+ Whether to use learnable elementwise affine parameters for normalization.
96
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
97
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
98
+ final_dropout (`bool` *optional*, defaults to False):
99
+ Whether to apply a final dropout after the last feed-forward layer.
100
+ attention_type (`str`, *optional*, defaults to `"default"`):
101
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
102
+ positional_embeddings (`str`, *optional*, defaults to `None`):
103
+ The type of positional embeddings to apply to.
104
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
105
+ The maximum number of positional embeddings to apply.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim: int,
111
+ num_attention_heads: int,
112
+ attention_head_dim: int,
113
+ dropout=0.0,
114
+ cross_attention_dim: Optional[int] = None,
115
+ activation_fn: str = "geglu",
116
+ num_embeds_ada_norm: Optional[int] = None,
117
+ attention_bias: bool = False,
118
+ only_cross_attention: bool = False,
119
+ double_self_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ norm_elementwise_affine: bool = True,
122
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
123
+ norm_eps: float = 1e-5,
124
+ final_dropout: bool = False,
125
+ attention_type: str = "default",
126
+ positional_embeddings: Optional[str] = None,
127
+ num_positional_embeddings: Optional[int] = None,
128
+ ):
129
+ super().__init__()
130
+ self.only_cross_attention = only_cross_attention
131
+
132
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
133
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
134
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
135
+ self.use_layer_norm = norm_type == "layer_norm"
136
+
137
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
138
+ raise ValueError(
139
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
140
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
141
+ )
142
+
143
+ if positional_embeddings and (num_positional_embeddings is None):
144
+ raise ValueError(
145
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
146
+ )
147
+
148
+ if positional_embeddings == "sinusoidal":
149
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
150
+ else:
151
+ self.pos_embed = None
152
+
153
+ # Define 3 blocks. Each block has its own normalization layer.
154
+ # 1. Self-Attn
155
+ if self.use_ada_layer_norm:
156
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
157
+ elif self.use_ada_layer_norm_zero:
158
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
159
+ else:
160
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
161
+
162
+ self.attn1 = Attention(
163
+ query_dim=dim,
164
+ heads=num_attention_heads,
165
+ dim_head=attention_head_dim,
166
+ dropout=dropout,
167
+ bias=attention_bias,
168
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
169
+ upcast_attention=upcast_attention,
170
+ )
171
+
172
+ # 2. Cross-Attn
173
+ if cross_attention_dim is not None or double_self_attention:
174
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
175
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
176
+ # the second cross attention block.
177
+ self.norm2 = (
178
+ AdaLayerNorm(dim, num_embeds_ada_norm)
179
+ if self.use_ada_layer_norm
180
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
181
+ )
182
+ self.attn2 = Attention(
183
+ query_dim=dim,
184
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
185
+ heads=num_attention_heads,
186
+ dim_head=attention_head_dim,
187
+ dropout=dropout,
188
+ bias=attention_bias,
189
+ upcast_attention=upcast_attention,
190
+ ) # is self-attn if encoder_hidden_states is none
191
+ else:
192
+ self.norm2 = None
193
+ self.attn2 = None
194
+
195
+ # 3. Feed-forward
196
+ if not self.use_ada_layer_norm_single:
197
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
198
+
199
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
200
+
201
+ # 4. Fuser
202
+ if attention_type == "gated" or attention_type == "gated-text-image":
203
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
204
+
205
+ # 5. Scale-shift for PixArt-Alpha.
206
+ if self.use_ada_layer_norm_single:
207
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
208
+
209
+ # let chunk size default to None
210
+ self._chunk_size = None
211
+ self._chunk_dim = 0
212
+
213
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
214
+ # Sets chunk feed-forward
215
+ self._chunk_size = chunk_size
216
+ self._chunk_dim = dim
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states: torch.FloatTensor,
221
+ spatial_attn_inputs = [],
222
+ spatial_attn_idx = 0,
223
+ attention_mask: Optional[torch.FloatTensor] = None,
224
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
225
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
226
+ timestep: Optional[torch.LongTensor] = None,
227
+ cross_attention_kwargs: Dict[str, Any] = None,
228
+ class_labels: Optional[torch.LongTensor] = None,
229
+ ) -> torch.FloatTensor:
230
+ # Notice that normalization is always applied before the real computation in the following blocks.
231
+ # 0. Self-Attention
232
+ batch_size = hidden_states.shape[0]
233
+
234
+ spatial_attn_input = spatial_attn_inputs[spatial_attn_idx]
235
+ spatial_attn_idx += 1
236
+ hidden_states = torch.cat((hidden_states, spatial_attn_input), dim=1)
237
+
238
+ if self.use_ada_layer_norm:
239
+ norm_hidden_states = self.norm1(hidden_states, timestep)
240
+ elif self.use_ada_layer_norm_zero:
241
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
242
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
243
+ )
244
+ elif self.use_layer_norm:
245
+ norm_hidden_states = self.norm1(hidden_states)
246
+ elif self.use_ada_layer_norm_single:
247
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
248
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
249
+ ).chunk(6, dim=1)
250
+ norm_hidden_states = self.norm1(hidden_states)
251
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
252
+ norm_hidden_states = norm_hidden_states.squeeze(1)
253
+ else:
254
+ raise ValueError("Incorrect norm used")
255
+
256
+ if self.pos_embed is not None:
257
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
258
+
259
+ # 1. Retrieve lora scale.
260
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
261
+
262
+ # 2. Prepare GLIGEN inputs
263
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
264
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
265
+
266
+ attn_output = self.attn1(
267
+ norm_hidden_states,
268
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269
+ attention_mask=attention_mask,
270
+ **cross_attention_kwargs,
271
+ )
272
+ if self.use_ada_layer_norm_zero:
273
+ attn_output = gate_msa.unsqueeze(1) * attn_output
274
+ elif self.use_ada_layer_norm_single:
275
+ attn_output = gate_msa * attn_output
276
+
277
+
278
+ hidden_states = attn_output + hidden_states
279
+ hidden_states, _ = hidden_states.chunk(2, dim=1)
280
+
281
+ if hidden_states.ndim == 4:
282
+ hidden_states = hidden_states.squeeze(1)
283
+
284
+ # 2.5 GLIGEN Control
285
+ if gligen_kwargs is not None:
286
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
287
+
288
+ # 3. Cross-Attention
289
+ if self.attn2 is not None:
290
+ if self.use_ada_layer_norm:
291
+ norm_hidden_states = self.norm2(hidden_states, timestep)
292
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
293
+ norm_hidden_states = self.norm2(hidden_states)
294
+ elif self.use_ada_layer_norm_single:
295
+ # For PixArt norm2 isn't applied here:
296
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
297
+ norm_hidden_states = hidden_states
298
+ else:
299
+ raise ValueError("Incorrect norm")
300
+
301
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
302
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
303
+
304
+ attn_output = self.attn2(
305
+ norm_hidden_states,
306
+ encoder_hidden_states=encoder_hidden_states,
307
+ attention_mask=encoder_attention_mask,
308
+ **cross_attention_kwargs,
309
+ )
310
+ hidden_states = attn_output + hidden_states
311
+
312
+ # 4. Feed-forward
313
+ if not self.use_ada_layer_norm_single:
314
+ norm_hidden_states = self.norm3(hidden_states)
315
+
316
+ if self.use_ada_layer_norm_zero:
317
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
318
+
319
+ if self.use_ada_layer_norm_single:
320
+ norm_hidden_states = self.norm2(hidden_states)
321
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
322
+
323
+ if self._chunk_size is not None:
324
+ # "feed_forward_chunk_size" can be used to save memory
325
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
326
+ raise ValueError(
327
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
328
+ )
329
+
330
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
331
+ ff_output = torch.cat(
332
+ [
333
+ self.ff(hid_slice, scale=lora_scale)
334
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
335
+ ],
336
+ dim=self._chunk_dim,
337
+ )
338
+ else:
339
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
340
+
341
+ if self.use_ada_layer_norm_zero:
342
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
343
+ elif self.use_ada_layer_norm_single:
344
+ ff_output = gate_mlp * ff_output
345
+
346
+ hidden_states = ff_output + hidden_states
347
+ if hidden_states.ndim == 4:
348
+ hidden_states = hidden_states.squeeze(1)
349
+
350
+ return hidden_states, spatial_attn_inputs, spatial_attn_idx
351
+
352
+
353
+ class FeedForward(nn.Module):
354
+ r"""
355
+ A feed-forward layer.
356
+
357
+ Parameters:
358
+ dim (`int`): The number of channels in the input.
359
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
360
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
361
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
362
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
363
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
364
+ """
365
+
366
+ def __init__(
367
+ self,
368
+ dim: int,
369
+ dim_out: Optional[int] = None,
370
+ mult: int = 4,
371
+ dropout: float = 0.0,
372
+ activation_fn: str = "geglu",
373
+ final_dropout: bool = False,
374
+ ):
375
+ super().__init__()
376
+ inner_dim = int(dim * mult)
377
+ dim_out = dim_out if dim_out is not None else dim
378
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
379
+
380
+ if activation_fn == "gelu":
381
+ act_fn = GELU(dim, inner_dim)
382
+ if activation_fn == "gelu-approximate":
383
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
384
+ elif activation_fn == "geglu":
385
+ act_fn = GEGLU(dim, inner_dim)
386
+ elif activation_fn == "geglu-approximate":
387
+ act_fn = ApproximateGELU(dim, inner_dim)
388
+
389
+ self.net = nn.ModuleList([])
390
+ # project in
391
+ self.net.append(act_fn)
392
+ # project dropout
393
+ self.net.append(nn.Dropout(dropout))
394
+ # project out
395
+ self.net.append(linear_cls(inner_dim, dim_out))
396
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
397
+ if final_dropout:
398
+ self.net.append(nn.Dropout(dropout))
399
+
400
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
401
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
402
+ for module in self.net:
403
+ if isinstance(module, compatible_cls):
404
+ hidden_states = module(hidden_states, scale)
405
+ else:
406
+ hidden_states = module(hidden_states)
407
+ return hidden_states
OOTDiffusion/ootd/pipelines_ootd/pipeline_ootd.py ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ from packaging import version
23
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
24
+
25
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection
26
+
27
+ from .unet_vton_2d_condition import UNetVton2DConditionModel
28
+ from .unet_garm_2d_condition import UNetGarm2DConditionModel
29
+
30
+ from diffusers.configuration_utils import FrozenDict
31
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
33
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
34
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
35
+ from diffusers.schedulers import KarrasDiffusionSchedulers
36
+ from diffusers.utils import (
37
+ PIL_INTERPOLATION,
38
+ USE_PEFT_BACKEND,
39
+ deprecate,
40
+ logging,
41
+ replace_example_docstring,
42
+ scale_lora_layers,
43
+ unscale_lora_layers,
44
+ )
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
47
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
48
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
49
+
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+
54
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
55
+ def preprocess(image):
56
+ deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
57
+ deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
58
+ if isinstance(image, torch.Tensor):
59
+ return image
60
+ elif isinstance(image, PIL.Image.Image):
61
+ image = [image]
62
+
63
+ if isinstance(image[0], PIL.Image.Image):
64
+ w, h = image[0].size
65
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
66
+
67
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
68
+ image = np.concatenate(image, axis=0)
69
+ image = np.array(image).astype(np.float32) / 255.0
70
+ image = image.transpose(0, 3, 1, 2)
71
+ image = 2.0 * image - 1.0
72
+ image = torch.from_numpy(image)
73
+ elif isinstance(image[0], torch.Tensor):
74
+ image = torch.cat(image, dim=0)
75
+ return image
76
+
77
+
78
+ class OotdPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
79
+ r"""
80
+ Args:
81
+ vae ([`AutoencoderKL`]):
82
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
83
+ text_encoder ([`~transformers.CLIPTextModel`]):
84
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
85
+ tokenizer ([`~transformers.CLIPTokenizer`]):
86
+ A `CLIPTokenizer` to tokenize text.
87
+ unet ([`UNet2DConditionModel`]):
88
+ A `UNet2DConditionModel` to denoise the encoded image latents.
89
+ scheduler ([`SchedulerMixin`]):
90
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
91
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
92
+ safety_checker ([`StableDiffusionSafetyChecker`]):
93
+ Classification module that estimates whether generated images could be considered offensive or harmful.
94
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
95
+ about a model's potential harms.
96
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
97
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
98
+ """
99
+ model_cpu_offload_seq = "text_encoder->unet->vae"
100
+ _optional_components = ["safety_checker", "feature_extractor"]
101
+ _exclude_from_cpu_offload = ["safety_checker"]
102
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "vton_latents"]
103
+
104
+ def __init__(
105
+ self,
106
+ vae: AutoencoderKL,
107
+ text_encoder: CLIPTextModel,
108
+ tokenizer: CLIPTokenizer,
109
+ unet_garm: UNetGarm2DConditionModel,
110
+ unet_vton: UNetVton2DConditionModel,
111
+ scheduler: KarrasDiffusionSchedulers,
112
+ safety_checker: StableDiffusionSafetyChecker,
113
+ feature_extractor: CLIPImageProcessor,
114
+ requires_safety_checker: bool = True,
115
+ ):
116
+ super().__init__()
117
+
118
+ if safety_checker is None and requires_safety_checker:
119
+ logger.warning(
120
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
121
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
122
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
123
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
124
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
125
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
126
+ )
127
+
128
+ if safety_checker is not None and feature_extractor is None:
129
+ raise ValueError(
130
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
131
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
132
+ )
133
+
134
+ self.register_modules(
135
+ vae=vae,
136
+ text_encoder=text_encoder,
137
+ tokenizer=tokenizer,
138
+ unet_garm=unet_garm,
139
+ unet_vton=unet_vton,
140
+ scheduler=scheduler,
141
+ safety_checker=safety_checker,
142
+ feature_extractor=feature_extractor,
143
+ )
144
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
145
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
146
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
147
+
148
+ @torch.no_grad()
149
+ def __call__(
150
+ self,
151
+ prompt: Union[str, List[str]] = None,
152
+ image_garm: PipelineImageInput = None,
153
+ image_vton: PipelineImageInput = None,
154
+ mask: PipelineImageInput = None,
155
+ image_ori: PipelineImageInput = None,
156
+ num_inference_steps: int = 100,
157
+ guidance_scale: float = 7.5,
158
+ image_guidance_scale: float = 1.5,
159
+ negative_prompt: Optional[Union[str, List[str]]] = None,
160
+ num_images_per_prompt: Optional[int] = 1,
161
+ eta: float = 0.0,
162
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
163
+ latents: Optional[torch.FloatTensor] = None,
164
+ prompt_embeds: Optional[torch.FloatTensor] = None,
165
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
166
+ output_type: Optional[str] = "pil",
167
+ return_dict: bool = True,
168
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
169
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
170
+ **kwargs,
171
+ ):
172
+ r"""
173
+ The call function to the pipeline for generation.
174
+
175
+ Args:
176
+ prompt (`str` or `List[str]`, *optional*):
177
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
178
+ image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
179
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
180
+ image latents as `image`, but if passing latents directly it is not encoded again.
181
+ num_inference_steps (`int`, *optional*, defaults to 100):
182
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
183
+ expense of slower inference.
184
+ guidance_scale (`float`, *optional*, defaults to 7.5):
185
+ A higher guidance scale value encourages the model to generate images closely linked to the text
186
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
187
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
188
+ Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
189
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
190
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
191
+ value of at least `1`.
192
+ negative_prompt (`str` or `List[str]`, *optional*):
193
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
194
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
195
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
196
+ The number of images to generate per prompt.
197
+ eta (`float`, *optional*, defaults to 0.0):
198
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
199
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
200
+ generator (`torch.Generator`, *optional*):
201
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
202
+ generation deterministic.
203
+ latents (`torch.FloatTensor`, *optional*):
204
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
205
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
206
+ tensor is generated by sampling using the supplied random `generator`.
207
+ prompt_embeds (`torch.FloatTensor`, *optional*):
208
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
209
+ provided, text embeddings are generated from the `prompt` input argument.
210
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
211
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
212
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
213
+ output_type (`str`, *optional*, defaults to `"pil"`):
214
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
215
+ return_dict (`bool`, *optional*, defaults to `True`):
216
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
217
+ plain tuple.
218
+ callback_on_step_end (`Callable`, *optional*):
219
+ A function that calls at the end of each denoising steps during the inference. The function is called
220
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
221
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
222
+ `callback_on_step_end_tensor_inputs`.
223
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
224
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
225
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
226
+ `._callback_tensor_inputs` attribute of your pipeline class.
227
+
228
+ Returns:
229
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
230
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
231
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
232
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
233
+ "not-safe-for-work" (nsfw) content.
234
+ """
235
+
236
+ callback = kwargs.pop("callback", None)
237
+ callback_steps = kwargs.pop("callback_steps", None)
238
+
239
+ if callback is not None:
240
+ deprecate(
241
+ "callback",
242
+ "1.0.0",
243
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
244
+ )
245
+ if callback_steps is not None:
246
+ deprecate(
247
+ "callback_steps",
248
+ "1.0.0",
249
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
250
+ )
251
+
252
+ # 0. Check inputs
253
+ self.check_inputs(
254
+ prompt,
255
+ callback_steps,
256
+ negative_prompt,
257
+ prompt_embeds,
258
+ negative_prompt_embeds,
259
+ callback_on_step_end_tensor_inputs,
260
+ )
261
+ self._guidance_scale = guidance_scale
262
+ self._image_guidance_scale = image_guidance_scale
263
+
264
+ if (image_vton is None) or (image_garm is None):
265
+ raise ValueError("`image` input cannot be undefined.")
266
+
267
+ # 1. Define call parameters
268
+ if prompt is not None and isinstance(prompt, str):
269
+ batch_size = 1
270
+ elif prompt is not None and isinstance(prompt, list):
271
+ batch_size = len(prompt)
272
+ else:
273
+ batch_size = prompt_embeds.shape[0]
274
+
275
+ device = self._execution_device
276
+ # check if scheduler is in sigmas space
277
+ scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
278
+
279
+ # 2. Encode input prompt
280
+ prompt_embeds = self._encode_prompt(
281
+ prompt,
282
+ device,
283
+ num_images_per_prompt,
284
+ self.do_classifier_free_guidance,
285
+ negative_prompt,
286
+ prompt_embeds=prompt_embeds,
287
+ negative_prompt_embeds=negative_prompt_embeds,
288
+ )
289
+
290
+ # 3. Preprocess image
291
+ image_garm = self.image_processor.preprocess(image_garm)
292
+ image_vton = self.image_processor.preprocess(image_vton)
293
+ image_ori = self.image_processor.preprocess(image_ori)
294
+ mask = np.array(mask)
295
+ mask[mask < 127] = 0
296
+ mask[mask >= 127] = 255
297
+ mask = torch.tensor(mask)
298
+ mask = mask / 255
299
+ mask = mask.reshape(-1, 1, mask.size(-2), mask.size(-1))
300
+
301
+ # 4. set timesteps
302
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
303
+ timesteps = self.scheduler.timesteps
304
+
305
+ # 5. Prepare Image latents
306
+ garm_latents = self.prepare_garm_latents(
307
+ image_garm,
308
+ batch_size,
309
+ num_images_per_prompt,
310
+ prompt_embeds.dtype,
311
+ device,
312
+ self.do_classifier_free_guidance,
313
+ generator,
314
+ )
315
+
316
+ vton_latents, mask_latents, image_ori_latents = self.prepare_vton_latents(
317
+ image_vton,
318
+ mask,
319
+ image_ori,
320
+ batch_size,
321
+ num_images_per_prompt,
322
+ prompt_embeds.dtype,
323
+ device,
324
+ self.do_classifier_free_guidance,
325
+ generator,
326
+ )
327
+
328
+ height, width = vton_latents.shape[-2:]
329
+ height = height * self.vae_scale_factor
330
+ width = width * self.vae_scale_factor
331
+
332
+ # 6. Prepare latent variables
333
+ num_channels_latents = self.vae.config.latent_channels
334
+ latents = self.prepare_latents(
335
+ batch_size * num_images_per_prompt,
336
+ num_channels_latents,
337
+ height,
338
+ width,
339
+ prompt_embeds.dtype,
340
+ device,
341
+ generator,
342
+ latents,
343
+ )
344
+
345
+ noise = latents.clone()
346
+
347
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
348
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
349
+
350
+ # 9. Denoising loop
351
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
352
+ self._num_timesteps = len(timesteps)
353
+
354
+ _, spatial_attn_outputs = self.unet_garm(
355
+ garm_latents,
356
+ 0,
357
+ encoder_hidden_states=prompt_embeds,
358
+ return_dict=False,
359
+ )
360
+
361
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
362
+ for i, t in enumerate(timesteps):
363
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
364
+
365
+ # concat latents, image_latents in the channel dimension
366
+ scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
367
+ latent_vton_model_input = torch.cat([scaled_latent_model_input, vton_latents], dim=1)
368
+ # latent_vton_model_input = scaled_latent_model_input + vton_latents
369
+
370
+ spatial_attn_inputs = spatial_attn_outputs.copy()
371
+
372
+ # predict the noise residual
373
+ noise_pred = self.unet_vton(
374
+ latent_vton_model_input,
375
+ spatial_attn_inputs,
376
+ t,
377
+ encoder_hidden_states=prompt_embeds,
378
+ return_dict=False,
379
+ )[0]
380
+
381
+ # Hack:
382
+ # For karras style schedulers the model does classifer free guidance using the
383
+ # predicted_original_sample instead of the noise_pred. So we need to compute the
384
+ # predicted_original_sample here if we are using a karras style scheduler.
385
+ if scheduler_is_in_sigma_space:
386
+ step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
387
+ sigma = self.scheduler.sigmas[step_index]
388
+ noise_pred = latent_model_input - sigma * noise_pred
389
+
390
+ # perform guidance
391
+ if self.do_classifier_free_guidance:
392
+ noise_pred_text_image, noise_pred_text = noise_pred.chunk(2)
393
+ noise_pred = (
394
+ noise_pred_text
395
+ + self.image_guidance_scale * (noise_pred_text_image - noise_pred_text)
396
+ )
397
+
398
+ # Hack:
399
+ # For karras style schedulers the model does classifer free guidance using the
400
+ # predicted_original_sample instead of the noise_pred. But the scheduler.step function
401
+ # expects the noise_pred and computes the predicted_original_sample internally. So we
402
+ # need to overwrite the noise_pred here such that the value of the computed
403
+ # predicted_original_sample is correct.
404
+ if scheduler_is_in_sigma_space:
405
+ noise_pred = (noise_pred - latents) / (-sigma)
406
+
407
+ # compute the previous noisy sample x_t -> x_t-1
408
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
409
+
410
+ init_latents_proper = image_ori_latents * self.vae.config.scaling_factor
411
+
412
+ # repainting
413
+ if i < len(timesteps) - 1:
414
+ noise_timestep = timesteps[i + 1]
415
+ init_latents_proper = self.scheduler.add_noise(
416
+ init_latents_proper, noise, torch.tensor([noise_timestep])
417
+ )
418
+
419
+ latents = (1 - mask_latents) * init_latents_proper + mask_latents * latents
420
+
421
+ if callback_on_step_end is not None:
422
+ callback_kwargs = {}
423
+ for k in callback_on_step_end_tensor_inputs:
424
+ callback_kwargs[k] = locals()[k]
425
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
426
+
427
+ latents = callback_outputs.pop("latents", latents)
428
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
429
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
430
+ vton_latents = callback_outputs.pop("vton_latents", vton_latents)
431
+
432
+ # call the callback, if provided
433
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
434
+ progress_bar.update()
435
+ if callback is not None and i % callback_steps == 0:
436
+ step_idx = i // getattr(self.scheduler, "order", 1)
437
+ callback(step_idx, t, latents)
438
+
439
+ if not output_type == "latent":
440
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
441
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
442
+ else:
443
+ image = latents
444
+ has_nsfw_concept = None
445
+
446
+ if has_nsfw_concept is None:
447
+ do_denormalize = [True] * image.shape[0]
448
+ else:
449
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
450
+
451
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
452
+
453
+ # Offload all models
454
+ self.maybe_free_model_hooks()
455
+
456
+ if not return_dict:
457
+ return (image, has_nsfw_concept)
458
+
459
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
460
+
461
+ def _encode_prompt(
462
+ self,
463
+ prompt,
464
+ device,
465
+ num_images_per_prompt,
466
+ do_classifier_free_guidance,
467
+ negative_prompt=None,
468
+ prompt_embeds: Optional[torch.FloatTensor] = None,
469
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
470
+ ):
471
+ r"""
472
+ Encodes the prompt into text encoder hidden states.
473
+
474
+ Args:
475
+ prompt (`str` or `List[str]`, *optional*):
476
+ prompt to be encoded
477
+ device: (`torch.device`):
478
+ torch device
479
+ num_images_per_prompt (`int`):
480
+ number of images that should be generated per prompt
481
+ do_classifier_free_guidance (`bool`):
482
+ whether to use classifier free guidance or not
483
+ negative_ prompt (`str` or `List[str]`, *optional*):
484
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
485
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
486
+ less than `1`).
487
+ prompt_embeds (`torch.FloatTensor`, *optional*):
488
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
489
+ provided, text embeddings will be generated from `prompt` input argument.
490
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
491
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
492
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
493
+ argument.
494
+ """
495
+ if prompt is not None and isinstance(prompt, str):
496
+ batch_size = 1
497
+ elif prompt is not None and isinstance(prompt, list):
498
+ batch_size = len(prompt)
499
+ else:
500
+ batch_size = prompt_embeds.shape[0]
501
+
502
+ if prompt_embeds is None:
503
+ # textual inversion: procecss multi-vector tokens if necessary
504
+ if isinstance(self, TextualInversionLoaderMixin):
505
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
506
+
507
+ text_inputs = self.tokenizer(
508
+ prompt,
509
+ padding="max_length",
510
+ max_length=self.tokenizer.model_max_length,
511
+ truncation=True,
512
+ return_tensors="pt",
513
+ )
514
+ text_input_ids = text_inputs.input_ids
515
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
516
+
517
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
518
+ text_input_ids, untruncated_ids
519
+ ):
520
+ removed_text = self.tokenizer.batch_decode(
521
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
522
+ )
523
+ logger.warning(
524
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
525
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
526
+ )
527
+
528
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
529
+ attention_mask = text_inputs.attention_mask.to(device)
530
+ else:
531
+ attention_mask = None
532
+
533
+ prompt_embeds = self.text_encoder(
534
+ text_input_ids.to(device),
535
+ attention_mask=attention_mask,
536
+ )
537
+ prompt_embeds = prompt_embeds[0]
538
+
539
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
540
+
541
+ bs_embed, seq_len, _ = prompt_embeds.shape
542
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
543
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
544
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
545
+
546
+ # get unconditional embeddings for classifier free guidance
547
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
548
+ uncond_tokens: List[str]
549
+ if negative_prompt is None:
550
+ uncond_tokens = [""] * batch_size
551
+ elif type(prompt) is not type(negative_prompt):
552
+ raise TypeError(
553
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
554
+ f" {type(prompt)}."
555
+ )
556
+ elif isinstance(negative_prompt, str):
557
+ uncond_tokens = [negative_prompt]
558
+ elif batch_size != len(negative_prompt):
559
+ raise ValueError(
560
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
561
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
562
+ " the batch size of `prompt`."
563
+ )
564
+ else:
565
+ uncond_tokens = negative_prompt
566
+
567
+ # textual inversion: procecss multi-vector tokens if necessary
568
+ if isinstance(self, TextualInversionLoaderMixin):
569
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
570
+
571
+ max_length = prompt_embeds.shape[1]
572
+ uncond_input = self.tokenizer(
573
+ uncond_tokens,
574
+ padding="max_length",
575
+ max_length=max_length,
576
+ truncation=True,
577
+ return_tensors="pt",
578
+ )
579
+
580
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
581
+ attention_mask = uncond_input.attention_mask.to(device)
582
+ else:
583
+ attention_mask = None
584
+
585
+ if do_classifier_free_guidance:
586
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
587
+
588
+ return prompt_embeds
589
+
590
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
591
+ def run_safety_checker(self, image, device, dtype):
592
+ if self.safety_checker is None:
593
+ has_nsfw_concept = None
594
+ else:
595
+ if torch.is_tensor(image):
596
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
597
+ else:
598
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
599
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
600
+ image, has_nsfw_concept = self.safety_checker(
601
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
602
+ )
603
+ return image, has_nsfw_concept
604
+
605
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
606
+ def prepare_extra_step_kwargs(self, generator, eta):
607
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
608
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
609
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
610
+ # and should be between [0, 1]
611
+
612
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
613
+ extra_step_kwargs = {}
614
+ if accepts_eta:
615
+ extra_step_kwargs["eta"] = eta
616
+
617
+ # check if the scheduler accepts generator
618
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
619
+ if accepts_generator:
620
+ extra_step_kwargs["generator"] = generator
621
+ return extra_step_kwargs
622
+
623
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
624
+ def decode_latents(self, latents):
625
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
626
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
627
+
628
+ latents = 1 / self.vae.config.scaling_factor * latents
629
+ image = self.vae.decode(latents, return_dict=False)[0]
630
+ image = (image / 2 + 0.5).clamp(0, 1)
631
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
632
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
633
+ return image
634
+
635
+ def check_inputs(
636
+ self,
637
+ prompt,
638
+ callback_steps,
639
+ negative_prompt=None,
640
+ prompt_embeds=None,
641
+ negative_prompt_embeds=None,
642
+ callback_on_step_end_tensor_inputs=None,
643
+ ):
644
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
645
+ raise ValueError(
646
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
647
+ f" {type(callback_steps)}."
648
+ )
649
+
650
+ if callback_on_step_end_tensor_inputs is not None and not all(
651
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
652
+ ):
653
+ raise ValueError(
654
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
655
+ )
656
+
657
+ if prompt is not None and prompt_embeds is not None:
658
+ raise ValueError(
659
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
660
+ " only forward one of the two."
661
+ )
662
+ elif prompt is None and prompt_embeds is None:
663
+ raise ValueError(
664
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
665
+ )
666
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
667
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
668
+
669
+ if negative_prompt is not None and negative_prompt_embeds is not None:
670
+ raise ValueError(
671
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
672
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
673
+ )
674
+
675
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
676
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
677
+ raise ValueError(
678
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
679
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
680
+ f" {negative_prompt_embeds.shape}."
681
+ )
682
+
683
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
684
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
685
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
686
+ if isinstance(generator, list) and len(generator) != batch_size:
687
+ raise ValueError(
688
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
689
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
690
+ )
691
+
692
+ if latents is None:
693
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
694
+ else:
695
+ latents = latents.to(device)
696
+
697
+ # scale the initial noise by the standard deviation required by the scheduler
698
+ latents = latents * self.scheduler.init_noise_sigma
699
+ return latents
700
+
701
+ def prepare_garm_latents(
702
+ self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
703
+ ):
704
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
705
+ raise ValueError(
706
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
707
+ )
708
+
709
+ image = image.to(device=device, dtype=dtype)
710
+
711
+ batch_size = batch_size * num_images_per_prompt
712
+
713
+ if image.shape[1] == 4:
714
+ image_latents = image
715
+ else:
716
+ if isinstance(generator, list) and len(generator) != batch_size:
717
+ raise ValueError(
718
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
719
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
720
+ )
721
+
722
+ if isinstance(generator, list):
723
+ image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
724
+ image_latents = torch.cat(image_latents, dim=0)
725
+ else:
726
+ image_latents = self.vae.encode(image).latent_dist.mode()
727
+
728
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
729
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
730
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
731
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
732
+ raise ValueError(
733
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
734
+ )
735
+ else:
736
+ image_latents = torch.cat([image_latents], dim=0)
737
+
738
+ if do_classifier_free_guidance:
739
+ uncond_image_latents = torch.zeros_like(image_latents)
740
+ image_latents = torch.cat([image_latents, uncond_image_latents], dim=0)
741
+
742
+ return image_latents
743
+
744
+ def prepare_vton_latents(
745
+ self, image, mask, image_ori, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
746
+ ):
747
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
748
+ raise ValueError(
749
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
750
+ )
751
+
752
+ image = image.to(device=device, dtype=dtype)
753
+ image_ori = image_ori.to(device=device, dtype=dtype)
754
+
755
+ batch_size = batch_size * num_images_per_prompt
756
+
757
+ if image.shape[1] == 4:
758
+ image_latents = image
759
+ image_ori_latents = image_ori
760
+ else:
761
+ if isinstance(generator, list) and len(generator) != batch_size:
762
+ raise ValueError(
763
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
764
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
765
+ )
766
+
767
+ if isinstance(generator, list):
768
+ image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
769
+ image_latents = torch.cat(image_latents, dim=0)
770
+ image_ori_latents = [self.vae.encode(image_ori[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
771
+ image_ori_latents = torch.cat(image_ori_latents, dim=0)
772
+ else:
773
+ image_latents = self.vae.encode(image).latent_dist.mode()
774
+ image_ori_latents = self.vae.encode(image_ori).latent_dist.mode()
775
+
776
+ mask = torch.nn.functional.interpolate(
777
+ mask, size=(image_latents.size(-2), image_latents.size(-1))
778
+ )
779
+ mask = mask.to(device=device, dtype=dtype)
780
+
781
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
782
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
783
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
784
+ mask = torch.cat([mask] * additional_image_per_prompt, dim=0)
785
+ image_ori_latents = torch.cat([image_ori_latents] * additional_image_per_prompt, dim=0)
786
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
787
+ raise ValueError(
788
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
789
+ )
790
+ else:
791
+ image_latents = torch.cat([image_latents], dim=0)
792
+ mask = torch.cat([mask], dim=0)
793
+ image_ori_latents = torch.cat([image_ori_latents], dim=0)
794
+
795
+ if do_classifier_free_guidance:
796
+ # uncond_image_latents = torch.zeros_like(image_latents)
797
+ image_latents = torch.cat([image_latents] * 2, dim=0)
798
+
799
+ return image_latents, mask, image_ori_latents
800
+
801
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
802
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
803
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
804
+
805
+ The suffixes after the scaling factors represent the stages where they are being applied.
806
+
807
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
808
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
809
+
810
+ Args:
811
+ s1 (`float`):
812
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
813
+ mitigate "oversmoothing effect" in the enhanced denoising process.
814
+ s2 (`float`):
815
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
816
+ mitigate "oversmoothing effect" in the enhanced denoising process.
817
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
818
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
819
+ """
820
+ if not hasattr(self, "unet"):
821
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
822
+ self.unet_vton.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
823
+
824
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
825
+ def disable_freeu(self):
826
+ """Disables the FreeU mechanism if enabled."""
827
+ self.unet_vton.disable_freeu()
828
+
829
+ @property
830
+ def guidance_scale(self):
831
+ return self._guidance_scale
832
+
833
+ @property
834
+ def image_guidance_scale(self):
835
+ return self._image_guidance_scale
836
+
837
+ @property
838
+ def num_timesteps(self):
839
+ return self._num_timesteps
840
+
841
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
842
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
843
+ # corresponds to doing no classifier free guidance.
844
+ @property
845
+ def do_classifier_free_guidance(self):
846
+ return self.image_guidance_scale >= 1.0
OOTDiffusion/ootd/pipelines_ootd/transformer_garm_2d.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+
23
+ from .attention_garm import BasicTransformerBlock
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
27
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
28
+ # from diffusers.models.attention import BasicTransformerBlock
29
+ from diffusers.models.embeddings import CaptionProjection, PatchEmbed
30
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import AdaLayerNormSingle
33
+
34
+
35
+ @dataclass
36
+ class Transformer2DModelOutput(BaseOutput):
37
+ """
38
+ The output of [`Transformer2DModel`].
39
+
40
+ Args:
41
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
42
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
43
+ distributions for the unnoised latent pixels.
44
+ """
45
+
46
+ sample: torch.FloatTensor
47
+
48
+
49
+ class Transformer2DModel(ModelMixin, ConfigMixin):
50
+ """
51
+ A 2D Transformer model for image-like data.
52
+
53
+ Parameters:
54
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
55
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
56
+ in_channels (`int`, *optional*):
57
+ The number of channels in the input and output (specify if the input is **continuous**).
58
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
59
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
60
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
61
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
62
+ This is fixed during training since it is used to learn a number of position embeddings.
63
+ num_vector_embeds (`int`, *optional*):
64
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
65
+ Includes the class for the masked latent pixel.
66
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
67
+ num_embeds_ada_norm ( `int`, *optional*):
68
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
69
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
70
+ added to the hidden states.
71
+
72
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
73
+ attention_bias (`bool`, *optional*):
74
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
75
+ """
76
+
77
+ @register_to_config
78
+ def __init__(
79
+ self,
80
+ num_attention_heads: int = 16,
81
+ attention_head_dim: int = 88,
82
+ in_channels: Optional[int] = None,
83
+ out_channels: Optional[int] = None,
84
+ num_layers: int = 1,
85
+ dropout: float = 0.0,
86
+ norm_num_groups: int = 32,
87
+ cross_attention_dim: Optional[int] = None,
88
+ attention_bias: bool = False,
89
+ sample_size: Optional[int] = None,
90
+ num_vector_embeds: Optional[int] = None,
91
+ patch_size: Optional[int] = None,
92
+ activation_fn: str = "geglu",
93
+ num_embeds_ada_norm: Optional[int] = None,
94
+ use_linear_projection: bool = False,
95
+ only_cross_attention: bool = False,
96
+ double_self_attention: bool = False,
97
+ upcast_attention: bool = False,
98
+ norm_type: str = "layer_norm",
99
+ norm_elementwise_affine: bool = True,
100
+ norm_eps: float = 1e-5,
101
+ attention_type: str = "default",
102
+ caption_channels: int = None,
103
+ ):
104
+ super().__init__()
105
+ self.use_linear_projection = use_linear_projection
106
+ self.num_attention_heads = num_attention_heads
107
+ self.attention_head_dim = attention_head_dim
108
+ inner_dim = num_attention_heads * attention_head_dim
109
+
110
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
111
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
112
+
113
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
114
+ # Define whether input is continuous or discrete depending on configuration
115
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
116
+ self.is_input_vectorized = num_vector_embeds is not None
117
+ self.is_input_patches = in_channels is not None and patch_size is not None
118
+
119
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
120
+ deprecation_message = (
121
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
122
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
123
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
124
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
125
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
126
+ )
127
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
128
+ norm_type = "ada_norm"
129
+
130
+ if self.is_input_continuous and self.is_input_vectorized:
131
+ raise ValueError(
132
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
133
+ " sure that either `in_channels` or `num_vector_embeds` is None."
134
+ )
135
+ elif self.is_input_vectorized and self.is_input_patches:
136
+ raise ValueError(
137
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
138
+ " sure that either `num_vector_embeds` or `num_patches` is None."
139
+ )
140
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
141
+ raise ValueError(
142
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
143
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
144
+ )
145
+
146
+ # 2. Define input layers
147
+ if self.is_input_continuous:
148
+ self.in_channels = in_channels
149
+
150
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
151
+ if use_linear_projection:
152
+ self.proj_in = linear_cls(in_channels, inner_dim)
153
+ else:
154
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
155
+ elif self.is_input_vectorized:
156
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
157
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
158
+
159
+ self.height = sample_size
160
+ self.width = sample_size
161
+ self.num_vector_embeds = num_vector_embeds
162
+ self.num_latent_pixels = self.height * self.width
163
+
164
+ self.latent_image_embedding = ImagePositionalEmbeddings(
165
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
166
+ )
167
+ elif self.is_input_patches:
168
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
169
+
170
+ self.height = sample_size
171
+ self.width = sample_size
172
+
173
+ self.patch_size = patch_size
174
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
175
+ interpolation_scale = max(interpolation_scale, 1)
176
+ self.pos_embed = PatchEmbed(
177
+ height=sample_size,
178
+ width=sample_size,
179
+ patch_size=patch_size,
180
+ in_channels=in_channels,
181
+ embed_dim=inner_dim,
182
+ interpolation_scale=interpolation_scale,
183
+ )
184
+
185
+ # 3. Define transformers blocks
186
+ self.transformer_blocks = nn.ModuleList(
187
+ [
188
+ BasicTransformerBlock(
189
+ inner_dim,
190
+ num_attention_heads,
191
+ attention_head_dim,
192
+ dropout=dropout,
193
+ cross_attention_dim=cross_attention_dim,
194
+ activation_fn=activation_fn,
195
+ num_embeds_ada_norm=num_embeds_ada_norm,
196
+ attention_bias=attention_bias,
197
+ only_cross_attention=only_cross_attention,
198
+ double_self_attention=double_self_attention,
199
+ upcast_attention=upcast_attention,
200
+ norm_type=norm_type,
201
+ norm_elementwise_affine=norm_elementwise_affine,
202
+ norm_eps=norm_eps,
203
+ attention_type=attention_type,
204
+ )
205
+ for d in range(num_layers)
206
+ ]
207
+ )
208
+
209
+ # 4. Define output layers
210
+ self.out_channels = in_channels if out_channels is None else out_channels
211
+ if self.is_input_continuous:
212
+ # TODO: should use out_channels for continuous projections
213
+ if use_linear_projection:
214
+ self.proj_out = linear_cls(inner_dim, in_channels)
215
+ else:
216
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
217
+ elif self.is_input_vectorized:
218
+ self.norm_out = nn.LayerNorm(inner_dim)
219
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
220
+ elif self.is_input_patches and norm_type != "ada_norm_single":
221
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
222
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
223
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
224
+ elif self.is_input_patches and norm_type == "ada_norm_single":
225
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
226
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
227
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
228
+
229
+ # 5. PixArt-Alpha blocks.
230
+ self.adaln_single = None
231
+ self.use_additional_conditions = False
232
+ if norm_type == "ada_norm_single":
233
+ self.use_additional_conditions = self.config.sample_size == 128
234
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
235
+ # additional conditions until we find better name
236
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
237
+
238
+ self.caption_projection = None
239
+ if caption_channels is not None:
240
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
241
+
242
+ self.gradient_checkpointing = False
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.Tensor,
247
+ spatial_attn_inputs = [],
248
+ encoder_hidden_states: Optional[torch.Tensor] = None,
249
+ timestep: Optional[torch.LongTensor] = None,
250
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
251
+ class_labels: Optional[torch.LongTensor] = None,
252
+ cross_attention_kwargs: Dict[str, Any] = None,
253
+ attention_mask: Optional[torch.Tensor] = None,
254
+ encoder_attention_mask: Optional[torch.Tensor] = None,
255
+ return_dict: bool = True,
256
+ ):
257
+ """
258
+ The [`Transformer2DModel`] forward method.
259
+
260
+ Args:
261
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
262
+ Input `hidden_states`.
263
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
264
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
265
+ self-attention.
266
+ timestep ( `torch.LongTensor`, *optional*):
267
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
268
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
269
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
270
+ `AdaLayerZeroNorm`.
271
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
272
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
273
+ `self.processor` in
274
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
275
+ attention_mask ( `torch.Tensor`, *optional*):
276
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
277
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
278
+ negative values to the attention scores corresponding to "discard" tokens.
279
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
280
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
281
+
282
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
283
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
284
+
285
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
286
+ above. This bias will be added to the cross-attention scores.
287
+ return_dict (`bool`, *optional*, defaults to `True`):
288
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
289
+ tuple.
290
+
291
+ Returns:
292
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
293
+ `tuple` where the first element is the sample tensor.
294
+ """
295
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
296
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
297
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
298
+ # expects mask of shape:
299
+ # [batch, key_tokens]
300
+ # adds singleton query_tokens dimension:
301
+ # [batch, 1, key_tokens]
302
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
303
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
304
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
305
+ if attention_mask is not None and attention_mask.ndim == 2:
306
+ # assume that mask is expressed as:
307
+ # (1 = keep, 0 = discard)
308
+ # convert mask into a bias that can be added to attention scores:
309
+ # (keep = +0, discard = -10000.0)
310
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
311
+ attention_mask = attention_mask.unsqueeze(1)
312
+
313
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
314
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
315
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
316
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
317
+
318
+ # Retrieve lora scale.
319
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
320
+
321
+ # 1. Input
322
+ if self.is_input_continuous:
323
+ batch, _, height, width = hidden_states.shape
324
+ residual = hidden_states
325
+
326
+ hidden_states = self.norm(hidden_states)
327
+ if not self.use_linear_projection:
328
+ hidden_states = (
329
+ self.proj_in(hidden_states, scale=lora_scale)
330
+ if not USE_PEFT_BACKEND
331
+ else self.proj_in(hidden_states)
332
+ )
333
+ inner_dim = hidden_states.shape[1]
334
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
335
+ else:
336
+ inner_dim = hidden_states.shape[1]
337
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
338
+ hidden_states = (
339
+ self.proj_in(hidden_states, scale=lora_scale)
340
+ if not USE_PEFT_BACKEND
341
+ else self.proj_in(hidden_states)
342
+ )
343
+
344
+ elif self.is_input_vectorized:
345
+ hidden_states = self.latent_image_embedding(hidden_states)
346
+ elif self.is_input_patches:
347
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
348
+ hidden_states = self.pos_embed(hidden_states)
349
+
350
+ if self.adaln_single is not None:
351
+ if self.use_additional_conditions and added_cond_kwargs is None:
352
+ raise ValueError(
353
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
354
+ )
355
+ batch_size = hidden_states.shape[0]
356
+ timestep, embedded_timestep = self.adaln_single(
357
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
358
+ )
359
+
360
+ # 2. Blocks
361
+ if self.caption_projection is not None:
362
+ batch_size = hidden_states.shape[0]
363
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
364
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
365
+
366
+ for block in self.transformer_blocks:
367
+ if self.training and self.gradient_checkpointing:
368
+ hidden_states, spatial_attn_inputs = torch.utils.checkpoint.checkpoint(
369
+ block,
370
+ hidden_states,
371
+ spatial_attn_inputs,
372
+ attention_mask,
373
+ encoder_hidden_states,
374
+ encoder_attention_mask,
375
+ timestep,
376
+ cross_attention_kwargs,
377
+ class_labels,
378
+ use_reentrant=False,
379
+ )
380
+ else:
381
+ hidden_states, spatial_attn_inputs = block(
382
+ hidden_states,
383
+ spatial_attn_inputs,
384
+ attention_mask=attention_mask,
385
+ encoder_hidden_states=encoder_hidden_states,
386
+ encoder_attention_mask=encoder_attention_mask,
387
+ timestep=timestep,
388
+ cross_attention_kwargs=cross_attention_kwargs,
389
+ class_labels=class_labels,
390
+ )
391
+
392
+ # 3. Output
393
+ if self.is_input_continuous:
394
+ if not self.use_linear_projection:
395
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
396
+ hidden_states = (
397
+ self.proj_out(hidden_states, scale=lora_scale)
398
+ if not USE_PEFT_BACKEND
399
+ else self.proj_out(hidden_states)
400
+ )
401
+ else:
402
+ hidden_states = (
403
+ self.proj_out(hidden_states, scale=lora_scale)
404
+ if not USE_PEFT_BACKEND
405
+ else self.proj_out(hidden_states)
406
+ )
407
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
408
+
409
+ output = hidden_states + residual
410
+ elif self.is_input_vectorized:
411
+ hidden_states = self.norm_out(hidden_states)
412
+ logits = self.out(hidden_states)
413
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
414
+ logits = logits.permute(0, 2, 1)
415
+
416
+ # log(p(x_0))
417
+ output = F.log_softmax(logits.double(), dim=1).float()
418
+
419
+ if self.is_input_patches:
420
+ if self.config.norm_type != "ada_norm_single":
421
+ conditioning = self.transformer_blocks[0].norm1.emb(
422
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
423
+ )
424
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
425
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
426
+ hidden_states = self.proj_out_2(hidden_states)
427
+ elif self.config.norm_type == "ada_norm_single":
428
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
429
+ hidden_states = self.norm_out(hidden_states)
430
+ # Modulation
431
+ hidden_states = hidden_states * (1 + scale) + shift
432
+ hidden_states = self.proj_out(hidden_states)
433
+ hidden_states = hidden_states.squeeze(1)
434
+
435
+ # unpatchify
436
+ if self.adaln_single is None:
437
+ height = width = int(hidden_states.shape[1] ** 0.5)
438
+ hidden_states = hidden_states.reshape(
439
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
440
+ )
441
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
442
+ output = hidden_states.reshape(
443
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
444
+ )
445
+
446
+ if not return_dict:
447
+ return (output,), spatial_attn_inputs
448
+
449
+ return Transformer2DModelOutput(sample=output), spatial_attn_inputs
OOTDiffusion/ootd/pipelines_ootd/transformer_vton_2d.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+
23
+ from .attention_vton import BasicTransformerBlock
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
27
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
28
+ # from diffusers.models.attention import BasicTransformerBlock
29
+ from diffusers.models.embeddings import CaptionProjection, PatchEmbed
30
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import AdaLayerNormSingle
33
+
34
+
35
+ @dataclass
36
+ class Transformer2DModelOutput(BaseOutput):
37
+ """
38
+ The output of [`Transformer2DModel`].
39
+
40
+ Args:
41
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
42
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
43
+ distributions for the unnoised latent pixels.
44
+ """
45
+
46
+ sample: torch.FloatTensor
47
+
48
+
49
+ class Transformer2DModel(ModelMixin, ConfigMixin):
50
+ """
51
+ A 2D Transformer model for image-like data.
52
+
53
+ Parameters:
54
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
55
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
56
+ in_channels (`int`, *optional*):
57
+ The number of channels in the input and output (specify if the input is **continuous**).
58
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
59
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
60
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
61
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
62
+ This is fixed during training since it is used to learn a number of position embeddings.
63
+ num_vector_embeds (`int`, *optional*):
64
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
65
+ Includes the class for the masked latent pixel.
66
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
67
+ num_embeds_ada_norm ( `int`, *optional*):
68
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
69
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
70
+ added to the hidden states.
71
+
72
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
73
+ attention_bias (`bool`, *optional*):
74
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
75
+ """
76
+
77
+ @register_to_config
78
+ def __init__(
79
+ self,
80
+ num_attention_heads: int = 16,
81
+ attention_head_dim: int = 88,
82
+ in_channels: Optional[int] = None,
83
+ out_channels: Optional[int] = None,
84
+ num_layers: int = 1,
85
+ dropout: float = 0.0,
86
+ norm_num_groups: int = 32,
87
+ cross_attention_dim: Optional[int] = None,
88
+ attention_bias: bool = False,
89
+ sample_size: Optional[int] = None,
90
+ num_vector_embeds: Optional[int] = None,
91
+ patch_size: Optional[int] = None,
92
+ activation_fn: str = "geglu",
93
+ num_embeds_ada_norm: Optional[int] = None,
94
+ use_linear_projection: bool = False,
95
+ only_cross_attention: bool = False,
96
+ double_self_attention: bool = False,
97
+ upcast_attention: bool = False,
98
+ norm_type: str = "layer_norm",
99
+ norm_elementwise_affine: bool = True,
100
+ norm_eps: float = 1e-5,
101
+ attention_type: str = "default",
102
+ caption_channels: int = None,
103
+ ):
104
+ super().__init__()
105
+ self.use_linear_projection = use_linear_projection
106
+ self.num_attention_heads = num_attention_heads
107
+ self.attention_head_dim = attention_head_dim
108
+ inner_dim = num_attention_heads * attention_head_dim
109
+
110
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
111
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
112
+
113
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
114
+ # Define whether input is continuous or discrete depending on configuration
115
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
116
+ self.is_input_vectorized = num_vector_embeds is not None
117
+ self.is_input_patches = in_channels is not None and patch_size is not None
118
+
119
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
120
+ deprecation_message = (
121
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
122
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
123
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
124
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
125
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
126
+ )
127
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
128
+ norm_type = "ada_norm"
129
+
130
+ if self.is_input_continuous and self.is_input_vectorized:
131
+ raise ValueError(
132
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
133
+ " sure that either `in_channels` or `num_vector_embeds` is None."
134
+ )
135
+ elif self.is_input_vectorized and self.is_input_patches:
136
+ raise ValueError(
137
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
138
+ " sure that either `num_vector_embeds` or `num_patches` is None."
139
+ )
140
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
141
+ raise ValueError(
142
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
143
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
144
+ )
145
+
146
+ # 2. Define input layers
147
+ if self.is_input_continuous:
148
+ self.in_channels = in_channels
149
+
150
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
151
+ if use_linear_projection:
152
+ self.proj_in = linear_cls(in_channels, inner_dim)
153
+ else:
154
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
155
+ elif self.is_input_vectorized:
156
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
157
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
158
+
159
+ self.height = sample_size
160
+ self.width = sample_size
161
+ self.num_vector_embeds = num_vector_embeds
162
+ self.num_latent_pixels = self.height * self.width
163
+
164
+ self.latent_image_embedding = ImagePositionalEmbeddings(
165
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
166
+ )
167
+ elif self.is_input_patches:
168
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
169
+
170
+ self.height = sample_size
171
+ self.width = sample_size
172
+
173
+ self.patch_size = patch_size
174
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
175
+ interpolation_scale = max(interpolation_scale, 1)
176
+ self.pos_embed = PatchEmbed(
177
+ height=sample_size,
178
+ width=sample_size,
179
+ patch_size=patch_size,
180
+ in_channels=in_channels,
181
+ embed_dim=inner_dim,
182
+ interpolation_scale=interpolation_scale,
183
+ )
184
+
185
+ # 3. Define transformers blocks
186
+ self.transformer_blocks = nn.ModuleList(
187
+ [
188
+ BasicTransformerBlock(
189
+ inner_dim,
190
+ num_attention_heads,
191
+ attention_head_dim,
192
+ dropout=dropout,
193
+ cross_attention_dim=cross_attention_dim,
194
+ activation_fn=activation_fn,
195
+ num_embeds_ada_norm=num_embeds_ada_norm,
196
+ attention_bias=attention_bias,
197
+ only_cross_attention=only_cross_attention,
198
+ double_self_attention=double_self_attention,
199
+ upcast_attention=upcast_attention,
200
+ norm_type=norm_type,
201
+ norm_elementwise_affine=norm_elementwise_affine,
202
+ norm_eps=norm_eps,
203
+ attention_type=attention_type,
204
+ )
205
+ for d in range(num_layers)
206
+ ]
207
+ )
208
+
209
+ # 4. Define output layers
210
+ self.out_channels = in_channels if out_channels is None else out_channels
211
+ if self.is_input_continuous:
212
+ # TODO: should use out_channels for continuous projections
213
+ if use_linear_projection:
214
+ self.proj_out = linear_cls(inner_dim, in_channels)
215
+ else:
216
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
217
+ elif self.is_input_vectorized:
218
+ self.norm_out = nn.LayerNorm(inner_dim)
219
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
220
+ elif self.is_input_patches and norm_type != "ada_norm_single":
221
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
222
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
223
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
224
+ elif self.is_input_patches and norm_type == "ada_norm_single":
225
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
226
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
227
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
228
+
229
+ # 5. PixArt-Alpha blocks.
230
+ self.adaln_single = None
231
+ self.use_additional_conditions = False
232
+ if norm_type == "ada_norm_single":
233
+ self.use_additional_conditions = self.config.sample_size == 128
234
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
235
+ # additional conditions until we find better name
236
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
237
+
238
+ self.caption_projection = None
239
+ if caption_channels is not None:
240
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
241
+
242
+ self.gradient_checkpointing = False
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.Tensor,
247
+ spatial_attn_inputs = [],
248
+ spatial_attn_idx = 0,
249
+ encoder_hidden_states: Optional[torch.Tensor] = None,
250
+ timestep: Optional[torch.LongTensor] = None,
251
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
252
+ class_labels: Optional[torch.LongTensor] = None,
253
+ cross_attention_kwargs: Dict[str, Any] = None,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ encoder_attention_mask: Optional[torch.Tensor] = None,
256
+ return_dict: bool = True,
257
+ ):
258
+ """
259
+ The [`Transformer2DModel`] forward method.
260
+
261
+ Args:
262
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
263
+ Input `hidden_states`.
264
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
265
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
266
+ self-attention.
267
+ timestep ( `torch.LongTensor`, *optional*):
268
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
269
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
270
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
271
+ `AdaLayerZeroNorm`.
272
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
273
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
274
+ `self.processor` in
275
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
276
+ attention_mask ( `torch.Tensor`, *optional*):
277
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
278
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
279
+ negative values to the attention scores corresponding to "discard" tokens.
280
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
281
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
282
+
283
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
284
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
285
+
286
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
287
+ above. This bias will be added to the cross-attention scores.
288
+ return_dict (`bool`, *optional*, defaults to `True`):
289
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
290
+ tuple.
291
+
292
+ Returns:
293
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
294
+ `tuple` where the first element is the sample tensor.
295
+ """
296
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
297
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
298
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
299
+ # expects mask of shape:
300
+ # [batch, key_tokens]
301
+ # adds singleton query_tokens dimension:
302
+ # [batch, 1, key_tokens]
303
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
304
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
305
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
306
+ if attention_mask is not None and attention_mask.ndim == 2:
307
+ # assume that mask is expressed as:
308
+ # (1 = keep, 0 = discard)
309
+ # convert mask into a bias that can be added to attention scores:
310
+ # (keep = +0, discard = -10000.0)
311
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
312
+ attention_mask = attention_mask.unsqueeze(1)
313
+
314
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
315
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
316
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
317
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
318
+
319
+ # Retrieve lora scale.
320
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
321
+
322
+ # 1. Input
323
+ if self.is_input_continuous:
324
+ batch, _, height, width = hidden_states.shape
325
+ residual = hidden_states
326
+
327
+ hidden_states = self.norm(hidden_states)
328
+ if not self.use_linear_projection:
329
+ hidden_states = (
330
+ self.proj_in(hidden_states, scale=lora_scale)
331
+ if not USE_PEFT_BACKEND
332
+ else self.proj_in(hidden_states)
333
+ )
334
+ inner_dim = hidden_states.shape[1]
335
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
336
+ else:
337
+ inner_dim = hidden_states.shape[1]
338
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
339
+ hidden_states = (
340
+ self.proj_in(hidden_states, scale=lora_scale)
341
+ if not USE_PEFT_BACKEND
342
+ else self.proj_in(hidden_states)
343
+ )
344
+
345
+ elif self.is_input_vectorized:
346
+ hidden_states = self.latent_image_embedding(hidden_states)
347
+ elif self.is_input_patches:
348
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
349
+ hidden_states = self.pos_embed(hidden_states)
350
+
351
+ if self.adaln_single is not None:
352
+ if self.use_additional_conditions and added_cond_kwargs is None:
353
+ raise ValueError(
354
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
355
+ )
356
+ batch_size = hidden_states.shape[0]
357
+ timestep, embedded_timestep = self.adaln_single(
358
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
359
+ )
360
+
361
+ # 2. Blocks
362
+ if self.caption_projection is not None:
363
+ batch_size = hidden_states.shape[0]
364
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
365
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
366
+
367
+ for block in self.transformer_blocks:
368
+ if self.training and self.gradient_checkpointing:
369
+ hidden_states, spatial_attn_inputs, spatial_attn_idx = torch.utils.checkpoint.checkpoint(
370
+ block,
371
+ hidden_states,
372
+ spatial_attn_inputs,
373
+ spatial_attn_idx,
374
+ attention_mask,
375
+ encoder_hidden_states,
376
+ encoder_attention_mask,
377
+ timestep,
378
+ cross_attention_kwargs,
379
+ class_labels,
380
+ use_reentrant=False,
381
+ )
382
+ else:
383
+ hidden_states, spatial_attn_inputs, spatial_attn_idx = block(
384
+ hidden_states,
385
+ spatial_attn_inputs,
386
+ spatial_attn_idx,
387
+ attention_mask=attention_mask,
388
+ encoder_hidden_states=encoder_hidden_states,
389
+ encoder_attention_mask=encoder_attention_mask,
390
+ timestep=timestep,
391
+ cross_attention_kwargs=cross_attention_kwargs,
392
+ class_labels=class_labels,
393
+ )
394
+
395
+ # 3. Output
396
+ if self.is_input_continuous:
397
+ if not self.use_linear_projection:
398
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
399
+ hidden_states = (
400
+ self.proj_out(hidden_states, scale=lora_scale)
401
+ if not USE_PEFT_BACKEND
402
+ else self.proj_out(hidden_states)
403
+ )
404
+ else:
405
+ hidden_states = (
406
+ self.proj_out(hidden_states, scale=lora_scale)
407
+ if not USE_PEFT_BACKEND
408
+ else self.proj_out(hidden_states)
409
+ )
410
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
411
+
412
+ output = hidden_states + residual
413
+ elif self.is_input_vectorized:
414
+ hidden_states = self.norm_out(hidden_states)
415
+ logits = self.out(hidden_states)
416
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
417
+ logits = logits.permute(0, 2, 1)
418
+
419
+ # log(p(x_0))
420
+ output = F.log_softmax(logits.double(), dim=1).float()
421
+
422
+ if self.is_input_patches:
423
+ if self.config.norm_type != "ada_norm_single":
424
+ conditioning = self.transformer_blocks[0].norm1.emb(
425
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
426
+ )
427
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
428
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
429
+ hidden_states = self.proj_out_2(hidden_states)
430
+ elif self.config.norm_type == "ada_norm_single":
431
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
432
+ hidden_states = self.norm_out(hidden_states)
433
+ # Modulation
434
+ hidden_states = hidden_states * (1 + scale) + shift
435
+ hidden_states = self.proj_out(hidden_states)
436
+ hidden_states = hidden_states.squeeze(1)
437
+
438
+ # unpatchify
439
+ if self.adaln_single is None:
440
+ height = width = int(hidden_states.shape[1] ** 0.5)
441
+ hidden_states = hidden_states.reshape(
442
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
443
+ )
444
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
445
+ output = hidden_states.reshape(
446
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
447
+ )
448
+
449
+ if not return_dict:
450
+ return (output,), spatial_attn_inputs, spatial_attn_idx
451
+
452
+ return Transformer2DModelOutput(sample=output), spatial_attn_inputs, spatial_attn_idx
OOTDiffusion/ootd/pipelines_ootd/unet_garm_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
OOTDiffusion/ootd/pipelines_ootd/unet_garm_2d_condition.py ADDED
@@ -0,0 +1,1183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from .unet_garm_2d_blocks import (
24
+ UNetMidBlock2D,
25
+ UNetMidBlock2DCrossAttn,
26
+ UNetMidBlock2DSimpleCrossAttn,
27
+ get_down_block,
28
+ get_up_block,
29
+ )
30
+
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.loaders import UNet2DConditionLoadersMixin
33
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
34
+ from diffusers.models.activations import get_activation
35
+ from diffusers.models.attention_processor import (
36
+ ADDED_KV_ATTENTION_PROCESSORS,
37
+ CROSS_ATTENTION_PROCESSORS,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.embeddings import (
43
+ GaussianFourierProjection,
44
+ ImageHintTimeEmbedding,
45
+ ImageProjection,
46
+ ImageTimeEmbedding,
47
+ PositionNet,
48
+ TextImageProjection,
49
+ TextImageTimeEmbedding,
50
+ TextTimeEmbedding,
51
+ TimestepEmbedding,
52
+ Timesteps,
53
+ )
54
+ from diffusers.models.modeling_utils import ModelMixin
55
+ # from diffusers.models.unet_2d_blocks import (
56
+ # UNetMidBlock2D,
57
+ # UNetMidBlock2DCrossAttn,
58
+ # UNetMidBlock2DSimpleCrossAttn,
59
+ # get_down_block,
60
+ # get_up_block,
61
+ # )
62
+
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+ @dataclass
68
+ class UNet2DConditionOutput(BaseOutput):
69
+ """
70
+ The output of [`UNet2DConditionModel`].
71
+
72
+ Args:
73
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
74
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
75
+ """
76
+
77
+ sample: torch.FloatTensor = None
78
+
79
+
80
+ class UNetGarm2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
81
+ r"""
82
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
83
+ shaped output.
84
+
85
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
86
+ for all models (such as downloading or saving).
87
+
88
+ Parameters:
89
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
90
+ Height and width of input/output sample.
91
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
92
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
93
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
94
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
95
+ Whether to flip the sin to cos in the time embedding.
96
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
97
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
98
+ The tuple of downsample blocks to use.
99
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
100
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
101
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
102
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
103
+ The tuple of upsample blocks to use.
104
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
105
+ Whether to include self-attention in the basic transformer blocks, see
106
+ [`~models.attention.BasicTransformerBlock`].
107
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
108
+ The tuple of output channels for each block.
109
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
110
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
111
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
112
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
113
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
114
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
115
+ If `None`, normalization and activation layers is skipped in post-processing.
116
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
117
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
118
+ The dimension of the cross attention features.
119
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
120
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
121
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
122
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
124
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
125
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
126
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
+ encoder_hid_dim (`int`, *optional*, defaults to None):
129
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
+ dimension to `cross_attention_dim`.
131
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
134
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
+ num_attention_heads (`int`, *optional*):
136
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
137
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
+ class_embed_type (`str`, *optional*, defaults to `None`):
140
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
+ addition_embed_type (`str`, *optional*, defaults to `None`):
143
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
+ "text". "text" will use the `TextTimeEmbedding` layer.
145
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
146
+ Dimension for the timestep embeddings.
147
+ num_class_embeds (`int`, *optional*, defaults to `None`):
148
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
+ class conditioning with `class_embed_type` equal to `None`.
150
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
151
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
153
+ An optional override for the dimension of the projected time embedding.
154
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
+ timestep_post_act (`str`, *optional*, defaults to `None`):
158
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
+ The dimension of `cond_proj` layer in the timestep embedding.
161
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
162
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
163
+ *optional*): The dimension of the `class_labels` input when
164
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
+ embeddings with the class embeddings.
167
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
+ otherwise.
172
+ """
173
+
174
+ _supports_gradient_checkpointing = True
175
+
176
+ @register_to_config
177
+ def __init__(
178
+ self,
179
+ sample_size: Optional[int] = None,
180
+ in_channels: int = 4,
181
+ out_channels: int = 4,
182
+ center_input_sample: bool = False,
183
+ flip_sin_to_cos: bool = True,
184
+ freq_shift: int = 0,
185
+ down_block_types: Tuple[str] = (
186
+ "CrossAttnDownBlock2D",
187
+ "CrossAttnDownBlock2D",
188
+ "CrossAttnDownBlock2D",
189
+ "DownBlock2D",
190
+ ),
191
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
192
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
193
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
194
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
195
+ layers_per_block: Union[int, Tuple[int]] = 2,
196
+ downsample_padding: int = 1,
197
+ mid_block_scale_factor: float = 1,
198
+ dropout: float = 0.0,
199
+ act_fn: str = "silu",
200
+ norm_num_groups: Optional[int] = 32,
201
+ norm_eps: float = 1e-5,
202
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
203
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
204
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
205
+ encoder_hid_dim: Optional[int] = None,
206
+ encoder_hid_dim_type: Optional[str] = None,
207
+ attention_head_dim: Union[int, Tuple[int]] = 8,
208
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
+ dual_cross_attention: bool = False,
210
+ use_linear_projection: bool = False,
211
+ class_embed_type: Optional[str] = None,
212
+ addition_embed_type: Optional[str] = None,
213
+ addition_time_embed_dim: Optional[int] = None,
214
+ num_class_embeds: Optional[int] = None,
215
+ upcast_attention: bool = False,
216
+ resnet_time_scale_shift: str = "default",
217
+ resnet_skip_time_act: bool = False,
218
+ resnet_out_scale_factor: int = 1.0,
219
+ time_embedding_type: str = "positional",
220
+ time_embedding_dim: Optional[int] = None,
221
+ time_embedding_act_fn: Optional[str] = None,
222
+ timestep_post_act: Optional[str] = None,
223
+ time_cond_proj_dim: Optional[int] = None,
224
+ conv_in_kernel: int = 3,
225
+ conv_out_kernel: int = 3,
226
+ projection_class_embeddings_input_dim: Optional[int] = None,
227
+ attention_type: str = "default",
228
+ class_embeddings_concat: bool = False,
229
+ mid_block_only_cross_attention: Optional[bool] = None,
230
+ cross_attention_norm: Optional[str] = None,
231
+ addition_embed_type_num_heads=64,
232
+ ):
233
+ super().__init__()
234
+
235
+ self.sample_size = sample_size
236
+
237
+ if num_attention_heads is not None:
238
+ raise ValueError(
239
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
240
+ )
241
+
242
+ # If `num_attention_heads` is not defined (which is the case for most models)
243
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
244
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
245
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
246
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
247
+ # which is why we correct for the naming here.
248
+ num_attention_heads = num_attention_heads or attention_head_dim
249
+
250
+ # Check inputs
251
+ if len(down_block_types) != len(up_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
254
+ )
255
+
256
+ if len(block_out_channels) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
264
+ )
265
+
266
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
267
+ raise ValueError(
268
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
272
+ raise ValueError(
273
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
274
+ )
275
+
276
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
277
+ raise ValueError(
278
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
279
+ )
280
+
281
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
282
+ raise ValueError(
283
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
284
+ )
285
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
286
+ for layer_number_per_block in transformer_layers_per_block:
287
+ if isinstance(layer_number_per_block, list):
288
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
289
+
290
+ # input
291
+ conv_in_padding = (conv_in_kernel - 1) // 2
292
+ self.conv_in = nn.Conv2d(
293
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
294
+ )
295
+
296
+ # time
297
+ if time_embedding_type == "fourier":
298
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
299
+ if time_embed_dim % 2 != 0:
300
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
301
+ self.time_proj = GaussianFourierProjection(
302
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
303
+ )
304
+ timestep_input_dim = time_embed_dim
305
+ elif time_embedding_type == "positional":
306
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
307
+
308
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
309
+ timestep_input_dim = block_out_channels[0]
310
+ else:
311
+ raise ValueError(
312
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
313
+ )
314
+
315
+ self.time_embedding = TimestepEmbedding(
316
+ timestep_input_dim,
317
+ time_embed_dim,
318
+ act_fn=act_fn,
319
+ post_act_fn=timestep_post_act,
320
+ cond_proj_dim=time_cond_proj_dim,
321
+ )
322
+
323
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
324
+ encoder_hid_dim_type = "text_proj"
325
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
326
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
327
+
328
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
329
+ raise ValueError(
330
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
331
+ )
332
+
333
+ if encoder_hid_dim_type == "text_proj":
334
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
335
+ elif encoder_hid_dim_type == "text_image_proj":
336
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
339
+ self.encoder_hid_proj = TextImageProjection(
340
+ text_embed_dim=encoder_hid_dim,
341
+ image_embed_dim=cross_attention_dim,
342
+ cross_attention_dim=cross_attention_dim,
343
+ )
344
+ elif encoder_hid_dim_type == "image_proj":
345
+ # Kandinsky 2.2
346
+ self.encoder_hid_proj = ImageProjection(
347
+ image_embed_dim=encoder_hid_dim,
348
+ cross_attention_dim=cross_attention_dim,
349
+ )
350
+ elif encoder_hid_dim_type is not None:
351
+ raise ValueError(
352
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
353
+ )
354
+ else:
355
+ self.encoder_hid_proj = None
356
+
357
+ # class embedding
358
+ if class_embed_type is None and num_class_embeds is not None:
359
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
360
+ elif class_embed_type == "timestep":
361
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
362
+ elif class_embed_type == "identity":
363
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
364
+ elif class_embed_type == "projection":
365
+ if projection_class_embeddings_input_dim is None:
366
+ raise ValueError(
367
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
368
+ )
369
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
370
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
371
+ # 2. it projects from an arbitrary input dimension.
372
+ #
373
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
374
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
375
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
376
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
377
+ elif class_embed_type == "simple_projection":
378
+ if projection_class_embeddings_input_dim is None:
379
+ raise ValueError(
380
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
381
+ )
382
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
383
+ else:
384
+ self.class_embedding = None
385
+
386
+ if addition_embed_type == "text":
387
+ if encoder_hid_dim is not None:
388
+ text_time_embedding_from_dim = encoder_hid_dim
389
+ else:
390
+ text_time_embedding_from_dim = cross_attention_dim
391
+
392
+ self.add_embedding = TextTimeEmbedding(
393
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
394
+ )
395
+ elif addition_embed_type == "text_image":
396
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
397
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
398
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
399
+ self.add_embedding = TextImageTimeEmbedding(
400
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
401
+ )
402
+ elif addition_embed_type == "text_time":
403
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
404
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
405
+ elif addition_embed_type == "image":
406
+ # Kandinsky 2.2
407
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
408
+ elif addition_embed_type == "image_hint":
409
+ # Kandinsky 2.2 ControlNet
410
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
+ elif addition_embed_type is not None:
412
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
413
+
414
+ if time_embedding_act_fn is None:
415
+ self.time_embed_act = None
416
+ else:
417
+ self.time_embed_act = get_activation(time_embedding_act_fn)
418
+
419
+ self.down_blocks = nn.ModuleList([])
420
+ self.up_blocks = nn.ModuleList([])
421
+
422
+ if isinstance(only_cross_attention, bool):
423
+ if mid_block_only_cross_attention is None:
424
+ mid_block_only_cross_attention = only_cross_attention
425
+
426
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
427
+
428
+ if mid_block_only_cross_attention is None:
429
+ mid_block_only_cross_attention = False
430
+
431
+ if isinstance(num_attention_heads, int):
432
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
433
+
434
+ if isinstance(attention_head_dim, int):
435
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
436
+
437
+ if isinstance(cross_attention_dim, int):
438
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
439
+
440
+ if isinstance(layers_per_block, int):
441
+ layers_per_block = [layers_per_block] * len(down_block_types)
442
+
443
+ if isinstance(transformer_layers_per_block, int):
444
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
445
+
446
+ if class_embeddings_concat:
447
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
448
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
449
+ # regular time embeddings
450
+ blocks_time_embed_dim = time_embed_dim * 2
451
+ else:
452
+ blocks_time_embed_dim = time_embed_dim
453
+
454
+ # down
455
+ output_channel = block_out_channels[0]
456
+ for i, down_block_type in enumerate(down_block_types):
457
+ input_channel = output_channel
458
+ output_channel = block_out_channels[i]
459
+ is_final_block = i == len(block_out_channels) - 1
460
+
461
+ down_block = get_down_block(
462
+ down_block_type,
463
+ num_layers=layers_per_block[i],
464
+ transformer_layers_per_block=transformer_layers_per_block[i],
465
+ in_channels=input_channel,
466
+ out_channels=output_channel,
467
+ temb_channels=blocks_time_embed_dim,
468
+ add_downsample=not is_final_block,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ resnet_groups=norm_num_groups,
472
+ cross_attention_dim=cross_attention_dim[i],
473
+ num_attention_heads=num_attention_heads[i],
474
+ downsample_padding=downsample_padding,
475
+ dual_cross_attention=dual_cross_attention,
476
+ use_linear_projection=use_linear_projection,
477
+ only_cross_attention=only_cross_attention[i],
478
+ upcast_attention=upcast_attention,
479
+ resnet_time_scale_shift=resnet_time_scale_shift,
480
+ attention_type=attention_type,
481
+ resnet_skip_time_act=resnet_skip_time_act,
482
+ resnet_out_scale_factor=resnet_out_scale_factor,
483
+ cross_attention_norm=cross_attention_norm,
484
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
485
+ dropout=dropout,
486
+ )
487
+ self.down_blocks.append(down_block)
488
+
489
+ # mid
490
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
491
+ self.mid_block = UNetMidBlock2DCrossAttn(
492
+ transformer_layers_per_block=transformer_layers_per_block[-1],
493
+ in_channels=block_out_channels[-1],
494
+ temb_channels=blocks_time_embed_dim,
495
+ dropout=dropout,
496
+ resnet_eps=norm_eps,
497
+ resnet_act_fn=act_fn,
498
+ output_scale_factor=mid_block_scale_factor,
499
+ resnet_time_scale_shift=resnet_time_scale_shift,
500
+ cross_attention_dim=cross_attention_dim[-1],
501
+ num_attention_heads=num_attention_heads[-1],
502
+ resnet_groups=norm_num_groups,
503
+ dual_cross_attention=dual_cross_attention,
504
+ use_linear_projection=use_linear_projection,
505
+ upcast_attention=upcast_attention,
506
+ attention_type=attention_type,
507
+ )
508
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
509
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
510
+ in_channels=block_out_channels[-1],
511
+ temb_channels=blocks_time_embed_dim,
512
+ dropout=dropout,
513
+ resnet_eps=norm_eps,
514
+ resnet_act_fn=act_fn,
515
+ output_scale_factor=mid_block_scale_factor,
516
+ cross_attention_dim=cross_attention_dim[-1],
517
+ attention_head_dim=attention_head_dim[-1],
518
+ resnet_groups=norm_num_groups,
519
+ resnet_time_scale_shift=resnet_time_scale_shift,
520
+ skip_time_act=resnet_skip_time_act,
521
+ only_cross_attention=mid_block_only_cross_attention,
522
+ cross_attention_norm=cross_attention_norm,
523
+ )
524
+ elif mid_block_type == "UNetMidBlock2D":
525
+ self.mid_block = UNetMidBlock2D(
526
+ in_channels=block_out_channels[-1],
527
+ temb_channels=blocks_time_embed_dim,
528
+ dropout=dropout,
529
+ num_layers=0,
530
+ resnet_eps=norm_eps,
531
+ resnet_act_fn=act_fn,
532
+ output_scale_factor=mid_block_scale_factor,
533
+ resnet_groups=norm_num_groups,
534
+ resnet_time_scale_shift=resnet_time_scale_shift,
535
+ add_attention=False,
536
+ )
537
+ elif mid_block_type is None:
538
+ self.mid_block = None
539
+ else:
540
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
541
+
542
+ # count how many layers upsample the images
543
+ self.num_upsamplers = 0
544
+
545
+ # up
546
+ reversed_block_out_channels = list(reversed(block_out_channels))
547
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
548
+ reversed_layers_per_block = list(reversed(layers_per_block))
549
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
550
+ reversed_transformer_layers_per_block = (
551
+ list(reversed(transformer_layers_per_block))
552
+ if reverse_transformer_layers_per_block is None
553
+ else reverse_transformer_layers_per_block
554
+ )
555
+ only_cross_attention = list(reversed(only_cross_attention))
556
+
557
+ output_channel = reversed_block_out_channels[0]
558
+ for i, up_block_type in enumerate(up_block_types):
559
+ is_final_block = i == len(block_out_channels) - 1
560
+
561
+ prev_output_channel = output_channel
562
+ output_channel = reversed_block_out_channels[i]
563
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
564
+
565
+ # add upsample block for all BUT final layer
566
+ if not is_final_block:
567
+ add_upsample = True
568
+ self.num_upsamplers += 1
569
+ else:
570
+ add_upsample = False
571
+
572
+ up_block = get_up_block(
573
+ up_block_type,
574
+ num_layers=reversed_layers_per_block[i] + 1,
575
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
576
+ in_channels=input_channel,
577
+ out_channels=output_channel,
578
+ prev_output_channel=prev_output_channel,
579
+ temb_channels=blocks_time_embed_dim,
580
+ add_upsample=add_upsample,
581
+ resnet_eps=norm_eps,
582
+ resnet_act_fn=act_fn,
583
+ resolution_idx=i,
584
+ resnet_groups=norm_num_groups,
585
+ cross_attention_dim=reversed_cross_attention_dim[i],
586
+ num_attention_heads=reversed_num_attention_heads[i],
587
+ dual_cross_attention=dual_cross_attention,
588
+ use_linear_projection=use_linear_projection,
589
+ only_cross_attention=only_cross_attention[i],
590
+ upcast_attention=upcast_attention,
591
+ resnet_time_scale_shift=resnet_time_scale_shift,
592
+ attention_type=attention_type,
593
+ resnet_skip_time_act=resnet_skip_time_act,
594
+ resnet_out_scale_factor=resnet_out_scale_factor,
595
+ cross_attention_norm=cross_attention_norm,
596
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
597
+ dropout=dropout,
598
+ )
599
+ self.up_blocks.append(up_block)
600
+ prev_output_channel = output_channel
601
+
602
+ # out
603
+ if norm_num_groups is not None:
604
+ self.conv_norm_out = nn.GroupNorm(
605
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
606
+ )
607
+
608
+ self.conv_act = get_activation(act_fn)
609
+
610
+ else:
611
+ self.conv_norm_out = None
612
+ self.conv_act = None
613
+
614
+ conv_out_padding = (conv_out_kernel - 1) // 2
615
+ self.conv_out = nn.Conv2d(
616
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
617
+ )
618
+
619
+ if attention_type in ["gated", "gated-text-image"]:
620
+ positive_len = 768
621
+ if isinstance(cross_attention_dim, int):
622
+ positive_len = cross_attention_dim
623
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
624
+ positive_len = cross_attention_dim[0]
625
+
626
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
627
+ self.position_net = PositionNet(
628
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
629
+ )
630
+
631
+ @property
632
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
633
+ r"""
634
+ Returns:
635
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
636
+ indexed by its weight name.
637
+ """
638
+ # set recursively
639
+ processors = {}
640
+
641
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
642
+ if hasattr(module, "get_processor"):
643
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
644
+
645
+ for sub_name, child in module.named_children():
646
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
647
+
648
+ return processors
649
+
650
+ for name, module in self.named_children():
651
+ fn_recursive_add_processors(name, module, processors)
652
+
653
+ return processors
654
+
655
+ def set_attn_processor(
656
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
657
+ ):
658
+ r"""
659
+ Sets the attention processor to use to compute attention.
660
+
661
+ Parameters:
662
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
663
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
664
+ for **all** `Attention` layers.
665
+
666
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
667
+ processor. This is strongly recommended when setting trainable attention processors.
668
+
669
+ """
670
+ count = len(self.attn_processors.keys())
671
+
672
+ if isinstance(processor, dict) and len(processor) != count:
673
+ raise ValueError(
674
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
675
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
676
+ )
677
+
678
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
679
+ if hasattr(module, "set_processor"):
680
+ if not isinstance(processor, dict):
681
+ module.set_processor(processor, _remove_lora=_remove_lora)
682
+ else:
683
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
684
+
685
+ for sub_name, child in module.named_children():
686
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
687
+
688
+ for name, module in self.named_children():
689
+ fn_recursive_attn_processor(name, module, processor)
690
+
691
+ def set_default_attn_processor(self):
692
+ """
693
+ Disables custom attention processors and sets the default attention implementation.
694
+ """
695
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
696
+ processor = AttnAddedKVProcessor()
697
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
698
+ processor = AttnProcessor()
699
+ else:
700
+ raise ValueError(
701
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
702
+ )
703
+
704
+ self.set_attn_processor(processor, _remove_lora=True)
705
+
706
+ def set_attention_slice(self, slice_size):
707
+ r"""
708
+ Enable sliced attention computation.
709
+
710
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
711
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
712
+
713
+ Args:
714
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
715
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
716
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
717
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
718
+ must be a multiple of `slice_size`.
719
+ """
720
+ sliceable_head_dims = []
721
+
722
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
723
+ if hasattr(module, "set_attention_slice"):
724
+ sliceable_head_dims.append(module.sliceable_head_dim)
725
+
726
+ for child in module.children():
727
+ fn_recursive_retrieve_sliceable_dims(child)
728
+
729
+ # retrieve number of attention layers
730
+ for module in self.children():
731
+ fn_recursive_retrieve_sliceable_dims(module)
732
+
733
+ num_sliceable_layers = len(sliceable_head_dims)
734
+
735
+ if slice_size == "auto":
736
+ # half the attention head size is usually a good trade-off between
737
+ # speed and memory
738
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
739
+ elif slice_size == "max":
740
+ # make smallest slice possible
741
+ slice_size = num_sliceable_layers * [1]
742
+
743
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
744
+
745
+ if len(slice_size) != len(sliceable_head_dims):
746
+ raise ValueError(
747
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
748
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
749
+ )
750
+
751
+ for i in range(len(slice_size)):
752
+ size = slice_size[i]
753
+ dim = sliceable_head_dims[i]
754
+ if size is not None and size > dim:
755
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
756
+
757
+ # Recursively walk through all the children.
758
+ # Any children which exposes the set_attention_slice method
759
+ # gets the message
760
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
761
+ if hasattr(module, "set_attention_slice"):
762
+ module.set_attention_slice(slice_size.pop())
763
+
764
+ for child in module.children():
765
+ fn_recursive_set_attention_slice(child, slice_size)
766
+
767
+ reversed_slice_size = list(reversed(slice_size))
768
+ for module in self.children():
769
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
770
+
771
+ def _set_gradient_checkpointing(self, module, value=False):
772
+ if hasattr(module, "gradient_checkpointing"):
773
+ module.gradient_checkpointing = value
774
+
775
+ def enable_freeu(self, s1, s2, b1, b2):
776
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
777
+
778
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
779
+
780
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
781
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
782
+
783
+ Args:
784
+ s1 (`float`):
785
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
786
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
787
+ s2 (`float`):
788
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
789
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
790
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
791
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
792
+ """
793
+ for i, upsample_block in enumerate(self.up_blocks):
794
+ setattr(upsample_block, "s1", s1)
795
+ setattr(upsample_block, "s2", s2)
796
+ setattr(upsample_block, "b1", b1)
797
+ setattr(upsample_block, "b2", b2)
798
+
799
+ def disable_freeu(self):
800
+ """Disables the FreeU mechanism."""
801
+ freeu_keys = {"s1", "s2", "b1", "b2"}
802
+ for i, upsample_block in enumerate(self.up_blocks):
803
+ for k in freeu_keys:
804
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
805
+ setattr(upsample_block, k, None)
806
+
807
+ def forward(
808
+ self,
809
+ sample: torch.FloatTensor,
810
+ timestep: Union[torch.Tensor, float, int],
811
+ encoder_hidden_states: torch.Tensor,
812
+ class_labels: Optional[torch.Tensor] = None,
813
+ timestep_cond: Optional[torch.Tensor] = None,
814
+ attention_mask: Optional[torch.Tensor] = None,
815
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
816
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
817
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
818
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
819
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
820
+ encoder_attention_mask: Optional[torch.Tensor] = None,
821
+ return_dict: bool = True,
822
+ ) -> Union[UNet2DConditionOutput, Tuple]:
823
+ r"""
824
+ The [`UNet2DConditionModel`] forward method.
825
+
826
+ Args:
827
+ sample (`torch.FloatTensor`):
828
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
829
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
830
+ encoder_hidden_states (`torch.FloatTensor`):
831
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
832
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
833
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
834
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
835
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
836
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
837
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
838
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
839
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
840
+ negative values to the attention scores corresponding to "discard" tokens.
841
+ cross_attention_kwargs (`dict`, *optional*):
842
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
843
+ `self.processor` in
844
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
845
+ added_cond_kwargs: (`dict`, *optional*):
846
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
847
+ are passed along to the UNet blocks.
848
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
849
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
850
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
851
+ A tensor that if specified is added to the residual of the middle unet block.
852
+ encoder_attention_mask (`torch.Tensor`):
853
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
854
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
855
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
856
+ return_dict (`bool`, *optional*, defaults to `True`):
857
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
858
+ tuple.
859
+ cross_attention_kwargs (`dict`, *optional*):
860
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
861
+ added_cond_kwargs: (`dict`, *optional*):
862
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
863
+ are passed along to the UNet blocks.
864
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
865
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
866
+ example from ControlNet side model(s)
867
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
868
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
869
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
870
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
871
+
872
+ Returns:
873
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
874
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
875
+ a `tuple` is returned where the first element is the sample tensor.
876
+ """
877
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
878
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
879
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
880
+ # on the fly if necessary.
881
+ default_overall_up_factor = 2**self.num_upsamplers
882
+
883
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
884
+ forward_upsample_size = False
885
+ upsample_size = None
886
+
887
+ for dim in sample.shape[-2:]:
888
+ if dim % default_overall_up_factor != 0:
889
+ # Forward upsample size to force interpolation output size.
890
+ forward_upsample_size = True
891
+ break
892
+
893
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
894
+ # expects mask of shape:
895
+ # [batch, key_tokens]
896
+ # adds singleton query_tokens dimension:
897
+ # [batch, 1, key_tokens]
898
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
899
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
900
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
901
+ if attention_mask is not None:
902
+ # assume that mask is expressed as:
903
+ # (1 = keep, 0 = discard)
904
+ # convert mask into a bias that can be added to attention scores:
905
+ # (keep = +0, discard = -10000.0)
906
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
907
+ attention_mask = attention_mask.unsqueeze(1)
908
+
909
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
910
+ if encoder_attention_mask is not None:
911
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
912
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
913
+
914
+ # 0. center input if necessary
915
+ if self.config.center_input_sample:
916
+ sample = 2 * sample - 1.0
917
+
918
+ # 1. time
919
+ timesteps = timestep
920
+ if not torch.is_tensor(timesteps):
921
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
922
+ # This would be a good case for the `match` statement (Python 3.10+)
923
+ is_mps = sample.device.type == "mps"
924
+ if isinstance(timestep, float):
925
+ dtype = torch.float32 if is_mps else torch.float64
926
+ else:
927
+ dtype = torch.int32 if is_mps else torch.int64
928
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
929
+ elif len(timesteps.shape) == 0:
930
+ timesteps = timesteps[None].to(sample.device)
931
+
932
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
933
+ timesteps = timesteps.expand(sample.shape[0])
934
+
935
+ t_emb = self.time_proj(timesteps)
936
+
937
+ # `Timesteps` does not contain any weights and will always return f32 tensors
938
+ # but time_embedding might actually be running in fp16. so we need to cast here.
939
+ # there might be better ways to encapsulate this.
940
+ t_emb = t_emb.to(dtype=sample.dtype)
941
+
942
+ emb = self.time_embedding(t_emb, timestep_cond)
943
+ aug_emb = None
944
+
945
+ if self.class_embedding is not None:
946
+ if class_labels is None:
947
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
948
+
949
+ if self.config.class_embed_type == "timestep":
950
+ class_labels = self.time_proj(class_labels)
951
+
952
+ # `Timesteps` does not contain any weights and will always return f32 tensors
953
+ # there might be better ways to encapsulate this.
954
+ class_labels = class_labels.to(dtype=sample.dtype)
955
+
956
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
957
+
958
+ if self.config.class_embeddings_concat:
959
+ emb = torch.cat([emb, class_emb], dim=-1)
960
+ else:
961
+ emb = emb + class_emb
962
+
963
+ if self.config.addition_embed_type == "text":
964
+ aug_emb = self.add_embedding(encoder_hidden_states)
965
+ elif self.config.addition_embed_type == "text_image":
966
+ # Kandinsky 2.1 - style
967
+ if "image_embeds" not in added_cond_kwargs:
968
+ raise ValueError(
969
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
970
+ )
971
+
972
+ image_embs = added_cond_kwargs.get("image_embeds")
973
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
974
+ aug_emb = self.add_embedding(text_embs, image_embs)
975
+ elif self.config.addition_embed_type == "text_time":
976
+ # SDXL - style
977
+ if "text_embeds" not in added_cond_kwargs:
978
+ raise ValueError(
979
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
980
+ )
981
+ text_embeds = added_cond_kwargs.get("text_embeds")
982
+ if "time_ids" not in added_cond_kwargs:
983
+ raise ValueError(
984
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
985
+ )
986
+ time_ids = added_cond_kwargs.get("time_ids")
987
+ time_embeds = self.add_time_proj(time_ids.flatten())
988
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
989
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
990
+ add_embeds = add_embeds.to(emb.dtype)
991
+ aug_emb = self.add_embedding(add_embeds)
992
+ elif self.config.addition_embed_type == "image":
993
+ # Kandinsky 2.2 - style
994
+ if "image_embeds" not in added_cond_kwargs:
995
+ raise ValueError(
996
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
997
+ )
998
+ image_embs = added_cond_kwargs.get("image_embeds")
999
+ aug_emb = self.add_embedding(image_embs)
1000
+ elif self.config.addition_embed_type == "image_hint":
1001
+ # Kandinsky 2.2 - style
1002
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1003
+ raise ValueError(
1004
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1005
+ )
1006
+ image_embs = added_cond_kwargs.get("image_embeds")
1007
+ hint = added_cond_kwargs.get("hint")
1008
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1009
+ sample = torch.cat([sample, hint], dim=1)
1010
+
1011
+ emb = emb + aug_emb if aug_emb is not None else emb
1012
+
1013
+ if self.time_embed_act is not None:
1014
+ emb = self.time_embed_act(emb)
1015
+
1016
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1017
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1018
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1019
+ # Kadinsky 2.1 - style
1020
+ if "image_embeds" not in added_cond_kwargs:
1021
+ raise ValueError(
1022
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1023
+ )
1024
+
1025
+ image_embeds = added_cond_kwargs.get("image_embeds")
1026
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1027
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1028
+ # Kandinsky 2.2 - style
1029
+ if "image_embeds" not in added_cond_kwargs:
1030
+ raise ValueError(
1031
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1032
+ )
1033
+ image_embeds = added_cond_kwargs.get("image_embeds")
1034
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1035
+ # 2. pre-process
1036
+ sample = self.conv_in(sample)
1037
+
1038
+ # 2.5 GLIGEN position net
1039
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1040
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1041
+ gligen_args = cross_attention_kwargs.pop("gligen")
1042
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1043
+
1044
+ # For Vton
1045
+ spatial_attn_inputs = []
1046
+
1047
+ # 3. down
1048
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1049
+ if USE_PEFT_BACKEND:
1050
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1051
+ scale_lora_layers(self, lora_scale)
1052
+
1053
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1054
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1055
+ is_adapter = down_intrablock_additional_residuals is not None
1056
+ # maintain backward compatibility for legacy usage, where
1057
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1058
+ # but can only use one or the other
1059
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1060
+ deprecate(
1061
+ "T2I should not use down_block_additional_residuals",
1062
+ "1.3.0",
1063
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1064
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1065
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1066
+ standard_warn=False,
1067
+ )
1068
+ down_intrablock_additional_residuals = down_block_additional_residuals
1069
+ is_adapter = True
1070
+
1071
+ down_block_res_samples = (sample,)
1072
+ for downsample_block in self.down_blocks:
1073
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1074
+ # For t2i-adapter CrossAttnDownBlock2D
1075
+ additional_residuals = {}
1076
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1077
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1078
+
1079
+ sample, res_samples, spatial_attn_inputs = downsample_block(
1080
+ hidden_states=sample,
1081
+ spatial_attn_inputs=spatial_attn_inputs,
1082
+ temb=emb,
1083
+ encoder_hidden_states=encoder_hidden_states,
1084
+ attention_mask=attention_mask,
1085
+ cross_attention_kwargs=cross_attention_kwargs,
1086
+ encoder_attention_mask=encoder_attention_mask,
1087
+ **additional_residuals,
1088
+ )
1089
+ else:
1090
+ sample, res_samples = downsample_block(
1091
+ hidden_states=sample,
1092
+ temb=emb,
1093
+ scale=lora_scale,
1094
+ )
1095
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1096
+ sample += down_intrablock_additional_residuals.pop(0)
1097
+
1098
+ down_block_res_samples += res_samples
1099
+
1100
+ # if is_controlnet:
1101
+ # new_down_block_res_samples = ()
1102
+
1103
+ # for down_block_res_sample, down_block_additional_residual in zip(
1104
+ # down_block_res_samples, down_block_additional_residuals
1105
+ # ):
1106
+ # down_block_res_sample = down_block_res_sample + down_block_additional_residual
1107
+ # new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1108
+
1109
+ # down_block_res_samples = new_down_block_res_samples
1110
+
1111
+ # 4. mid
1112
+ if self.mid_block is not None:
1113
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1114
+ sample, spatial_attn_inputs = self.mid_block(
1115
+ sample,
1116
+ spatial_attn_inputs=spatial_attn_inputs,
1117
+ temb=emb,
1118
+ encoder_hidden_states=encoder_hidden_states,
1119
+ attention_mask=attention_mask,
1120
+ cross_attention_kwargs=cross_attention_kwargs,
1121
+ encoder_attention_mask=encoder_attention_mask,
1122
+ )
1123
+ else:
1124
+ sample = self.mid_block(sample, emb)
1125
+
1126
+ # To support T2I-Adapter-XL
1127
+ if (
1128
+ is_adapter
1129
+ and len(down_intrablock_additional_residuals) > 0
1130
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1131
+ ):
1132
+ sample += down_intrablock_additional_residuals.pop(0)
1133
+
1134
+ if is_controlnet:
1135
+ sample = sample + mid_block_additional_residual
1136
+
1137
+ # 5. up
1138
+ for i, upsample_block in enumerate(self.up_blocks):
1139
+ is_final_block = i == len(self.up_blocks) - 1
1140
+
1141
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1142
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1143
+
1144
+ # if we have not reached the final block and need to forward the
1145
+ # upsample size, we do it here
1146
+ if not is_final_block and forward_upsample_size:
1147
+ upsample_size = down_block_res_samples[-1].shape[2:]
1148
+
1149
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1150
+ sample, spatial_attn_inputs = upsample_block(
1151
+ hidden_states=sample,
1152
+ spatial_attn_inputs=spatial_attn_inputs,
1153
+ temb=emb,
1154
+ res_hidden_states_tuple=res_samples,
1155
+ encoder_hidden_states=encoder_hidden_states,
1156
+ cross_attention_kwargs=cross_attention_kwargs,
1157
+ upsample_size=upsample_size,
1158
+ attention_mask=attention_mask,
1159
+ encoder_attention_mask=encoder_attention_mask,
1160
+ )
1161
+ else:
1162
+ sample = upsample_block(
1163
+ hidden_states=sample,
1164
+ temb=emb,
1165
+ res_hidden_states_tuple=res_samples,
1166
+ upsample_size=upsample_size,
1167
+ scale=lora_scale,
1168
+ )
1169
+
1170
+ # 6. post-process
1171
+ if self.conv_norm_out:
1172
+ sample = self.conv_norm_out(sample)
1173
+ sample = self.conv_act(sample)
1174
+ sample = self.conv_out(sample)
1175
+
1176
+ if USE_PEFT_BACKEND:
1177
+ # remove `lora_scale` from each PEFT layer
1178
+ unscale_lora_layers(self, lora_scale)
1179
+
1180
+ if not return_dict:
1181
+ return (sample,), spatial_attn_inputs
1182
+
1183
+ return UNet2DConditionOutput(sample=sample), spatial_attn_inputs
OOTDiffusion/ootd/pipelines_ootd/unet_vton_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
OOTDiffusion/ootd/pipelines_ootd/unet_vton_2d_condition.py ADDED
@@ -0,0 +1,1183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified by Yuhao Xu for OOTDiffusion (https://github.com/levihsu/OOTDiffusion)
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from .unet_vton_2d_blocks import (
24
+ UNetMidBlock2D,
25
+ UNetMidBlock2DCrossAttn,
26
+ UNetMidBlock2DSimpleCrossAttn,
27
+ get_down_block,
28
+ get_up_block,
29
+ )
30
+
31
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
32
+ from diffusers.loaders import UNet2DConditionLoadersMixin
33
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
34
+ from diffusers.models.activations import get_activation
35
+ from diffusers.models.attention_processor import (
36
+ ADDED_KV_ATTENTION_PROCESSORS,
37
+ CROSS_ATTENTION_PROCESSORS,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.embeddings import (
43
+ GaussianFourierProjection,
44
+ ImageHintTimeEmbedding,
45
+ ImageProjection,
46
+ ImageTimeEmbedding,
47
+ PositionNet,
48
+ TextImageProjection,
49
+ TextImageTimeEmbedding,
50
+ TextTimeEmbedding,
51
+ TimestepEmbedding,
52
+ Timesteps,
53
+ )
54
+ from diffusers.models.modeling_utils import ModelMixin
55
+ # from ..diffusers.src.diffusers.models.unet_2d_blocks import (
56
+ # UNetMidBlock2D,
57
+ # UNetMidBlock2DCrossAttn,
58
+ # UNetMidBlock2DSimpleCrossAttn,
59
+ # get_down_block,
60
+ # get_up_block,
61
+ # )
62
+
63
+
64
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
65
+
66
+
67
+ @dataclass
68
+ class UNet2DConditionOutput(BaseOutput):
69
+ """
70
+ The output of [`UNet2DConditionModel`].
71
+
72
+ Args:
73
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
74
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
75
+ """
76
+
77
+ sample: torch.FloatTensor = None
78
+
79
+
80
+ class UNetVton2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
81
+ r"""
82
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
83
+ shaped output.
84
+
85
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
86
+ for all models (such as downloading or saving).
87
+
88
+ Parameters:
89
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
90
+ Height and width of input/output sample.
91
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
92
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
93
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
94
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
95
+ Whether to flip the sin to cos in the time embedding.
96
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
97
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
98
+ The tuple of downsample blocks to use.
99
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
100
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
101
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
102
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
103
+ The tuple of upsample blocks to use.
104
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
105
+ Whether to include self-attention in the basic transformer blocks, see
106
+ [`~models.attention.BasicTransformerBlock`].
107
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
108
+ The tuple of output channels for each block.
109
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
110
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
111
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
112
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
113
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
114
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
115
+ If `None`, normalization and activation layers is skipped in post-processing.
116
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
117
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
118
+ The dimension of the cross attention features.
119
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
120
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
121
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
122
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
124
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
125
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
126
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
127
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
128
+ encoder_hid_dim (`int`, *optional*, defaults to None):
129
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
130
+ dimension to `cross_attention_dim`.
131
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
132
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
133
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
134
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
135
+ num_attention_heads (`int`, *optional*):
136
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
137
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
138
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
139
+ class_embed_type (`str`, *optional*, defaults to `None`):
140
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
141
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
142
+ addition_embed_type (`str`, *optional*, defaults to `None`):
143
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
144
+ "text". "text" will use the `TextTimeEmbedding` layer.
145
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
146
+ Dimension for the timestep embeddings.
147
+ num_class_embeds (`int`, *optional*, defaults to `None`):
148
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
149
+ class conditioning with `class_embed_type` equal to `None`.
150
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
151
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
152
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
153
+ An optional override for the dimension of the projected time embedding.
154
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
155
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
156
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
157
+ timestep_post_act (`str`, *optional*, defaults to `None`):
158
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
159
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
160
+ The dimension of `cond_proj` layer in the timestep embedding.
161
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
162
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
163
+ *optional*): The dimension of the `class_labels` input when
164
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
165
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
166
+ embeddings with the class embeddings.
167
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
168
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
169
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
170
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
171
+ otherwise.
172
+ """
173
+
174
+ _supports_gradient_checkpointing = True
175
+
176
+ @register_to_config
177
+ def __init__(
178
+ self,
179
+ sample_size: Optional[int] = None,
180
+ in_channels: int = 4,
181
+ out_channels: int = 4,
182
+ center_input_sample: bool = False,
183
+ flip_sin_to_cos: bool = True,
184
+ freq_shift: int = 0,
185
+ down_block_types: Tuple[str] = (
186
+ "CrossAttnDownBlock2D",
187
+ "CrossAttnDownBlock2D",
188
+ "CrossAttnDownBlock2D",
189
+ "DownBlock2D",
190
+ ),
191
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
192
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
193
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
194
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
195
+ layers_per_block: Union[int, Tuple[int]] = 2,
196
+ downsample_padding: int = 1,
197
+ mid_block_scale_factor: float = 1,
198
+ dropout: float = 0.0,
199
+ act_fn: str = "silu",
200
+ norm_num_groups: Optional[int] = 32,
201
+ norm_eps: float = 1e-5,
202
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
203
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
204
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
205
+ encoder_hid_dim: Optional[int] = None,
206
+ encoder_hid_dim_type: Optional[str] = None,
207
+ attention_head_dim: Union[int, Tuple[int]] = 8,
208
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
+ dual_cross_attention: bool = False,
210
+ use_linear_projection: bool = False,
211
+ class_embed_type: Optional[str] = None,
212
+ addition_embed_type: Optional[str] = None,
213
+ addition_time_embed_dim: Optional[int] = None,
214
+ num_class_embeds: Optional[int] = None,
215
+ upcast_attention: bool = False,
216
+ resnet_time_scale_shift: str = "default",
217
+ resnet_skip_time_act: bool = False,
218
+ resnet_out_scale_factor: int = 1.0,
219
+ time_embedding_type: str = "positional",
220
+ time_embedding_dim: Optional[int] = None,
221
+ time_embedding_act_fn: Optional[str] = None,
222
+ timestep_post_act: Optional[str] = None,
223
+ time_cond_proj_dim: Optional[int] = None,
224
+ conv_in_kernel: int = 3,
225
+ conv_out_kernel: int = 3,
226
+ projection_class_embeddings_input_dim: Optional[int] = None,
227
+ attention_type: str = "default",
228
+ class_embeddings_concat: bool = False,
229
+ mid_block_only_cross_attention: Optional[bool] = None,
230
+ cross_attention_norm: Optional[str] = None,
231
+ addition_embed_type_num_heads=64,
232
+ ):
233
+ super().__init__()
234
+
235
+ self.sample_size = sample_size
236
+
237
+ if num_attention_heads is not None:
238
+ raise ValueError(
239
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
240
+ )
241
+
242
+ # If `num_attention_heads` is not defined (which is the case for most models)
243
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
244
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
245
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
246
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
247
+ # which is why we correct for the naming here.
248
+ num_attention_heads = num_attention_heads or attention_head_dim
249
+
250
+ # Check inputs
251
+ if len(down_block_types) != len(up_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
254
+ )
255
+
256
+ if len(block_out_channels) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
264
+ )
265
+
266
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
267
+ raise ValueError(
268
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
272
+ raise ValueError(
273
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
274
+ )
275
+
276
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
277
+ raise ValueError(
278
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
279
+ )
280
+
281
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
282
+ raise ValueError(
283
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
284
+ )
285
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
286
+ for layer_number_per_block in transformer_layers_per_block:
287
+ if isinstance(layer_number_per_block, list):
288
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
289
+
290
+ # input
291
+ conv_in_padding = (conv_in_kernel - 1) // 2
292
+ self.conv_in = nn.Conv2d(
293
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
294
+ )
295
+
296
+ # time
297
+ if time_embedding_type == "fourier":
298
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
299
+ if time_embed_dim % 2 != 0:
300
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
301
+ self.time_proj = GaussianFourierProjection(
302
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
303
+ )
304
+ timestep_input_dim = time_embed_dim
305
+ elif time_embedding_type == "positional":
306
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
307
+
308
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
309
+ timestep_input_dim = block_out_channels[0]
310
+ else:
311
+ raise ValueError(
312
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
313
+ )
314
+
315
+ self.time_embedding = TimestepEmbedding(
316
+ timestep_input_dim,
317
+ time_embed_dim,
318
+ act_fn=act_fn,
319
+ post_act_fn=timestep_post_act,
320
+ cond_proj_dim=time_cond_proj_dim,
321
+ )
322
+
323
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
324
+ encoder_hid_dim_type = "text_proj"
325
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
326
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
327
+
328
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
329
+ raise ValueError(
330
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
331
+ )
332
+
333
+ if encoder_hid_dim_type == "text_proj":
334
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
335
+ elif encoder_hid_dim_type == "text_image_proj":
336
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
339
+ self.encoder_hid_proj = TextImageProjection(
340
+ text_embed_dim=encoder_hid_dim,
341
+ image_embed_dim=cross_attention_dim,
342
+ cross_attention_dim=cross_attention_dim,
343
+ )
344
+ elif encoder_hid_dim_type == "image_proj":
345
+ # Kandinsky 2.2
346
+ self.encoder_hid_proj = ImageProjection(
347
+ image_embed_dim=encoder_hid_dim,
348
+ cross_attention_dim=cross_attention_dim,
349
+ )
350
+ elif encoder_hid_dim_type is not None:
351
+ raise ValueError(
352
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
353
+ )
354
+ else:
355
+ self.encoder_hid_proj = None
356
+
357
+ # class embedding
358
+ if class_embed_type is None and num_class_embeds is not None:
359
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
360
+ elif class_embed_type == "timestep":
361
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
362
+ elif class_embed_type == "identity":
363
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
364
+ elif class_embed_type == "projection":
365
+ if projection_class_embeddings_input_dim is None:
366
+ raise ValueError(
367
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
368
+ )
369
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
370
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
371
+ # 2. it projects from an arbitrary input dimension.
372
+ #
373
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
374
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
375
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
376
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
377
+ elif class_embed_type == "simple_projection":
378
+ if projection_class_embeddings_input_dim is None:
379
+ raise ValueError(
380
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
381
+ )
382
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
383
+ else:
384
+ self.class_embedding = None
385
+
386
+ if addition_embed_type == "text":
387
+ if encoder_hid_dim is not None:
388
+ text_time_embedding_from_dim = encoder_hid_dim
389
+ else:
390
+ text_time_embedding_from_dim = cross_attention_dim
391
+
392
+ self.add_embedding = TextTimeEmbedding(
393
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
394
+ )
395
+ elif addition_embed_type == "text_image":
396
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
397
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
398
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
399
+ self.add_embedding = TextImageTimeEmbedding(
400
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
401
+ )
402
+ elif addition_embed_type == "text_time":
403
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
404
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
405
+ elif addition_embed_type == "image":
406
+ # Kandinsky 2.2
407
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
408
+ elif addition_embed_type == "image_hint":
409
+ # Kandinsky 2.2 ControlNet
410
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
+ elif addition_embed_type is not None:
412
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
413
+
414
+ if time_embedding_act_fn is None:
415
+ self.time_embed_act = None
416
+ else:
417
+ self.time_embed_act = get_activation(time_embedding_act_fn)
418
+
419
+ self.down_blocks = nn.ModuleList([])
420
+ self.up_blocks = nn.ModuleList([])
421
+
422
+ if isinstance(only_cross_attention, bool):
423
+ if mid_block_only_cross_attention is None:
424
+ mid_block_only_cross_attention = only_cross_attention
425
+
426
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
427
+
428
+ if mid_block_only_cross_attention is None:
429
+ mid_block_only_cross_attention = False
430
+
431
+ if isinstance(num_attention_heads, int):
432
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
433
+
434
+ if isinstance(attention_head_dim, int):
435
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
436
+
437
+ if isinstance(cross_attention_dim, int):
438
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
439
+
440
+ if isinstance(layers_per_block, int):
441
+ layers_per_block = [layers_per_block] * len(down_block_types)
442
+
443
+ if isinstance(transformer_layers_per_block, int):
444
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
445
+
446
+ if class_embeddings_concat:
447
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
448
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
449
+ # regular time embeddings
450
+ blocks_time_embed_dim = time_embed_dim * 2
451
+ else:
452
+ blocks_time_embed_dim = time_embed_dim
453
+
454
+ # down
455
+ output_channel = block_out_channels[0]
456
+ for i, down_block_type in enumerate(down_block_types):
457
+ input_channel = output_channel
458
+ output_channel = block_out_channels[i]
459
+ is_final_block = i == len(block_out_channels) - 1
460
+
461
+ down_block = get_down_block(
462
+ down_block_type,
463
+ num_layers=layers_per_block[i],
464
+ transformer_layers_per_block=transformer_layers_per_block[i],
465
+ in_channels=input_channel,
466
+ out_channels=output_channel,
467
+ temb_channels=blocks_time_embed_dim,
468
+ add_downsample=not is_final_block,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ resnet_groups=norm_num_groups,
472
+ cross_attention_dim=cross_attention_dim[i],
473
+ num_attention_heads=num_attention_heads[i],
474
+ downsample_padding=downsample_padding,
475
+ dual_cross_attention=dual_cross_attention,
476
+ use_linear_projection=use_linear_projection,
477
+ only_cross_attention=only_cross_attention[i],
478
+ upcast_attention=upcast_attention,
479
+ resnet_time_scale_shift=resnet_time_scale_shift,
480
+ attention_type=attention_type,
481
+ resnet_skip_time_act=resnet_skip_time_act,
482
+ resnet_out_scale_factor=resnet_out_scale_factor,
483
+ cross_attention_norm=cross_attention_norm,
484
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
485
+ dropout=dropout,
486
+ )
487
+ self.down_blocks.append(down_block)
488
+
489
+ # mid
490
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
491
+ self.mid_block = UNetMidBlock2DCrossAttn(
492
+ transformer_layers_per_block=transformer_layers_per_block[-1],
493
+ in_channels=block_out_channels[-1],
494
+ temb_channels=blocks_time_embed_dim,
495
+ dropout=dropout,
496
+ resnet_eps=norm_eps,
497
+ resnet_act_fn=act_fn,
498
+ output_scale_factor=mid_block_scale_factor,
499
+ resnet_time_scale_shift=resnet_time_scale_shift,
500
+ cross_attention_dim=cross_attention_dim[-1],
501
+ num_attention_heads=num_attention_heads[-1],
502
+ resnet_groups=norm_num_groups,
503
+ dual_cross_attention=dual_cross_attention,
504
+ use_linear_projection=use_linear_projection,
505
+ upcast_attention=upcast_attention,
506
+ attention_type=attention_type,
507
+ )
508
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
509
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
510
+ in_channels=block_out_channels[-1],
511
+ temb_channels=blocks_time_embed_dim,
512
+ dropout=dropout,
513
+ resnet_eps=norm_eps,
514
+ resnet_act_fn=act_fn,
515
+ output_scale_factor=mid_block_scale_factor,
516
+ cross_attention_dim=cross_attention_dim[-1],
517
+ attention_head_dim=attention_head_dim[-1],
518
+ resnet_groups=norm_num_groups,
519
+ resnet_time_scale_shift=resnet_time_scale_shift,
520
+ skip_time_act=resnet_skip_time_act,
521
+ only_cross_attention=mid_block_only_cross_attention,
522
+ cross_attention_norm=cross_attention_norm,
523
+ )
524
+ elif mid_block_type == "UNetMidBlock2D":
525
+ self.mid_block = UNetMidBlock2D(
526
+ in_channels=block_out_channels[-1],
527
+ temb_channels=blocks_time_embed_dim,
528
+ dropout=dropout,
529
+ num_layers=0,
530
+ resnet_eps=norm_eps,
531
+ resnet_act_fn=act_fn,
532
+ output_scale_factor=mid_block_scale_factor,
533
+ resnet_groups=norm_num_groups,
534
+ resnet_time_scale_shift=resnet_time_scale_shift,
535
+ add_attention=False,
536
+ )
537
+ elif mid_block_type is None:
538
+ self.mid_block = None
539
+ else:
540
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
541
+
542
+ # count how many layers upsample the images
543
+ self.num_upsamplers = 0
544
+
545
+ # up
546
+ reversed_block_out_channels = list(reversed(block_out_channels))
547
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
548
+ reversed_layers_per_block = list(reversed(layers_per_block))
549
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
550
+ reversed_transformer_layers_per_block = (
551
+ list(reversed(transformer_layers_per_block))
552
+ if reverse_transformer_layers_per_block is None
553
+ else reverse_transformer_layers_per_block
554
+ )
555
+ only_cross_attention = list(reversed(only_cross_attention))
556
+
557
+ output_channel = reversed_block_out_channels[0]
558
+ for i, up_block_type in enumerate(up_block_types):
559
+ is_final_block = i == len(block_out_channels) - 1
560
+
561
+ prev_output_channel = output_channel
562
+ output_channel = reversed_block_out_channels[i]
563
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
564
+
565
+ # add upsample block for all BUT final layer
566
+ if not is_final_block:
567
+ add_upsample = True
568
+ self.num_upsamplers += 1
569
+ else:
570
+ add_upsample = False
571
+
572
+ up_block = get_up_block(
573
+ up_block_type,
574
+ num_layers=reversed_layers_per_block[i] + 1,
575
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
576
+ in_channels=input_channel,
577
+ out_channels=output_channel,
578
+ prev_output_channel=prev_output_channel,
579
+ temb_channels=blocks_time_embed_dim,
580
+ add_upsample=add_upsample,
581
+ resnet_eps=norm_eps,
582
+ resnet_act_fn=act_fn,
583
+ resolution_idx=i,
584
+ resnet_groups=norm_num_groups,
585
+ cross_attention_dim=reversed_cross_attention_dim[i],
586
+ num_attention_heads=reversed_num_attention_heads[i],
587
+ dual_cross_attention=dual_cross_attention,
588
+ use_linear_projection=use_linear_projection,
589
+ only_cross_attention=only_cross_attention[i],
590
+ upcast_attention=upcast_attention,
591
+ resnet_time_scale_shift=resnet_time_scale_shift,
592
+ attention_type=attention_type,
593
+ resnet_skip_time_act=resnet_skip_time_act,
594
+ resnet_out_scale_factor=resnet_out_scale_factor,
595
+ cross_attention_norm=cross_attention_norm,
596
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
597
+ dropout=dropout,
598
+ )
599
+ self.up_blocks.append(up_block)
600
+ prev_output_channel = output_channel
601
+
602
+ # out
603
+ if norm_num_groups is not None:
604
+ self.conv_norm_out = nn.GroupNorm(
605
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
606
+ )
607
+
608
+ self.conv_act = get_activation(act_fn)
609
+
610
+ else:
611
+ self.conv_norm_out = None
612
+ self.conv_act = None
613
+
614
+ conv_out_padding = (conv_out_kernel - 1) // 2
615
+ self.conv_out = nn.Conv2d(
616
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
617
+ )
618
+
619
+ if attention_type in ["gated", "gated-text-image"]:
620
+ positive_len = 768
621
+ if isinstance(cross_attention_dim, int):
622
+ positive_len = cross_attention_dim
623
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
624
+ positive_len = cross_attention_dim[0]
625
+
626
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
627
+ self.position_net = PositionNet(
628
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
629
+ )
630
+
631
+ @property
632
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
633
+ r"""
634
+ Returns:
635
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
636
+ indexed by its weight name.
637
+ """
638
+ # set recursively
639
+ processors = {}
640
+
641
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
642
+ if hasattr(module, "get_processor"):
643
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
644
+
645
+ for sub_name, child in module.named_children():
646
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
647
+
648
+ return processors
649
+
650
+ for name, module in self.named_children():
651
+ fn_recursive_add_processors(name, module, processors)
652
+
653
+ return processors
654
+
655
+ def set_attn_processor(
656
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
657
+ ):
658
+ r"""
659
+ Sets the attention processor to use to compute attention.
660
+
661
+ Parameters:
662
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
663
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
664
+ for **all** `Attention` layers.
665
+
666
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
667
+ processor. This is strongly recommended when setting trainable attention processors.
668
+
669
+ """
670
+ count = len(self.attn_processors.keys())
671
+
672
+ if isinstance(processor, dict) and len(processor) != count:
673
+ raise ValueError(
674
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
675
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
676
+ )
677
+
678
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
679
+ if hasattr(module, "set_processor"):
680
+ if not isinstance(processor, dict):
681
+ module.set_processor(processor, _remove_lora=_remove_lora)
682
+ else:
683
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
684
+
685
+ for sub_name, child in module.named_children():
686
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
687
+
688
+ for name, module in self.named_children():
689
+ fn_recursive_attn_processor(name, module, processor)
690
+
691
+ def set_default_attn_processor(self):
692
+ """
693
+ Disables custom attention processors and sets the default attention implementation.
694
+ """
695
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
696
+ processor = AttnAddedKVProcessor()
697
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
698
+ processor = AttnProcessor()
699
+ else:
700
+ raise ValueError(
701
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
702
+ )
703
+
704
+ self.set_attn_processor(processor, _remove_lora=True)
705
+
706
+ def set_attention_slice(self, slice_size):
707
+ r"""
708
+ Enable sliced attention computation.
709
+
710
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
711
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
712
+
713
+ Args:
714
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
715
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
716
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
717
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
718
+ must be a multiple of `slice_size`.
719
+ """
720
+ sliceable_head_dims = []
721
+
722
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
723
+ if hasattr(module, "set_attention_slice"):
724
+ sliceable_head_dims.append(module.sliceable_head_dim)
725
+
726
+ for child in module.children():
727
+ fn_recursive_retrieve_sliceable_dims(child)
728
+
729
+ # retrieve number of attention layers
730
+ for module in self.children():
731
+ fn_recursive_retrieve_sliceable_dims(module)
732
+
733
+ num_sliceable_layers = len(sliceable_head_dims)
734
+
735
+ if slice_size == "auto":
736
+ # half the attention head size is usually a good trade-off between
737
+ # speed and memory
738
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
739
+ elif slice_size == "max":
740
+ # make smallest slice possible
741
+ slice_size = num_sliceable_layers * [1]
742
+
743
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
744
+
745
+ if len(slice_size) != len(sliceable_head_dims):
746
+ raise ValueError(
747
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
748
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
749
+ )
750
+
751
+ for i in range(len(slice_size)):
752
+ size = slice_size[i]
753
+ dim = sliceable_head_dims[i]
754
+ if size is not None and size > dim:
755
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
756
+
757
+ # Recursively walk through all the children.
758
+ # Any children which exposes the set_attention_slice method
759
+ # gets the message
760
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
761
+ if hasattr(module, "set_attention_slice"):
762
+ module.set_attention_slice(slice_size.pop())
763
+
764
+ for child in module.children():
765
+ fn_recursive_set_attention_slice(child, slice_size)
766
+
767
+ reversed_slice_size = list(reversed(slice_size))
768
+ for module in self.children():
769
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
770
+
771
+ def _set_gradient_checkpointing(self, module, value=False):
772
+ if hasattr(module, "gradient_checkpointing"):
773
+ module.gradient_checkpointing = value
774
+
775
+ def enable_freeu(self, s1, s2, b1, b2):
776
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
777
+
778
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
779
+
780
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
781
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
782
+
783
+ Args:
784
+ s1 (`float`):
785
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
786
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
787
+ s2 (`float`):
788
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
789
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
790
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
791
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
792
+ """
793
+ for i, upsample_block in enumerate(self.up_blocks):
794
+ setattr(upsample_block, "s1", s1)
795
+ setattr(upsample_block, "s2", s2)
796
+ setattr(upsample_block, "b1", b1)
797
+ setattr(upsample_block, "b2", b2)
798
+
799
+ def disable_freeu(self):
800
+ """Disables the FreeU mechanism."""
801
+ freeu_keys = {"s1", "s2", "b1", "b2"}
802
+ for i, upsample_block in enumerate(self.up_blocks):
803
+ for k in freeu_keys:
804
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
805
+ setattr(upsample_block, k, None)
806
+
807
+ def forward(
808
+ self,
809
+ sample: torch.FloatTensor,
810
+ spatial_attn_inputs,
811
+ timestep: Union[torch.Tensor, float, int],
812
+ encoder_hidden_states: torch.Tensor,
813
+ class_labels: Optional[torch.Tensor] = None,
814
+ timestep_cond: Optional[torch.Tensor] = None,
815
+ attention_mask: Optional[torch.Tensor] = None,
816
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
817
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
818
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
819
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
820
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
821
+ encoder_attention_mask: Optional[torch.Tensor] = None,
822
+ return_dict: bool = True,
823
+ ) -> Union[UNet2DConditionOutput, Tuple]:
824
+ r"""
825
+ The [`UNet2DConditionModel`] forward method.
826
+
827
+ Args:
828
+ sample (`torch.FloatTensor`):
829
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
830
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
831
+ encoder_hidden_states (`torch.FloatTensor`):
832
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
833
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
834
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
835
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
836
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
837
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
838
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
839
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
840
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
841
+ negative values to the attention scores corresponding to "discard" tokens.
842
+ cross_attention_kwargs (`dict`, *optional*):
843
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
844
+ `self.processor` in
845
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
846
+ added_cond_kwargs: (`dict`, *optional*):
847
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
848
+ are passed along to the UNet blocks.
849
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
850
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
851
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
852
+ A tensor that if specified is added to the residual of the middle unet block.
853
+ encoder_attention_mask (`torch.Tensor`):
854
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
855
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
856
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
857
+ return_dict (`bool`, *optional*, defaults to `True`):
858
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
859
+ tuple.
860
+ cross_attention_kwargs (`dict`, *optional*):
861
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
862
+ added_cond_kwargs: (`dict`, *optional*):
863
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
864
+ are passed along to the UNet blocks.
865
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
866
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
867
+ example from ControlNet side model(s)
868
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
869
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
870
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
871
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
872
+
873
+ Returns:
874
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
875
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
876
+ a `tuple` is returned where the first element is the sample tensor.
877
+ """
878
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
879
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
880
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
881
+ # on the fly if necessary.
882
+ default_overall_up_factor = 2**self.num_upsamplers
883
+
884
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
885
+ forward_upsample_size = False
886
+ upsample_size = None
887
+
888
+ for dim in sample.shape[-2:]:
889
+ if dim % default_overall_up_factor != 0:
890
+ # Forward upsample size to force interpolation output size.
891
+ forward_upsample_size = True
892
+ break
893
+
894
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
895
+ # expects mask of shape:
896
+ # [batch, key_tokens]
897
+ # adds singleton query_tokens dimension:
898
+ # [batch, 1, key_tokens]
899
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
900
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
901
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
902
+ if attention_mask is not None:
903
+ # assume that mask is expressed as:
904
+ # (1 = keep, 0 = discard)
905
+ # convert mask into a bias that can be added to attention scores:
906
+ # (keep = +0, discard = -10000.0)
907
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
908
+ attention_mask = attention_mask.unsqueeze(1)
909
+
910
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
911
+ if encoder_attention_mask is not None:
912
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
913
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
914
+
915
+ # 0. center input if necessary
916
+ if self.config.center_input_sample:
917
+ sample = 2 * sample - 1.0
918
+
919
+ # 1. time
920
+ timesteps = timestep
921
+ if not torch.is_tensor(timesteps):
922
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
923
+ # This would be a good case for the `match` statement (Python 3.10+)
924
+ is_mps = sample.device.type == "mps"
925
+ if isinstance(timestep, float):
926
+ dtype = torch.float32 if is_mps else torch.float64
927
+ else:
928
+ dtype = torch.int32 if is_mps else torch.int64
929
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
930
+ elif len(timesteps.shape) == 0:
931
+ timesteps = timesteps[None].to(sample.device)
932
+
933
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
934
+ timesteps = timesteps.expand(sample.shape[0])
935
+
936
+ t_emb = self.time_proj(timesteps)
937
+
938
+ # `Timesteps` does not contain any weights and will always return f32 tensors
939
+ # but time_embedding might actually be running in fp16. so we need to cast here.
940
+ # there might be better ways to encapsulate this.
941
+ t_emb = t_emb.to(dtype=sample.dtype)
942
+
943
+ emb = self.time_embedding(t_emb, timestep_cond)
944
+ aug_emb = None
945
+
946
+ if self.class_embedding is not None:
947
+ if class_labels is None:
948
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
949
+
950
+ if self.config.class_embed_type == "timestep":
951
+ class_labels = self.time_proj(class_labels)
952
+
953
+ # `Timesteps` does not contain any weights and will always return f32 tensors
954
+ # there might be better ways to encapsulate this.
955
+ class_labels = class_labels.to(dtype=sample.dtype)
956
+
957
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
958
+
959
+ if self.config.class_embeddings_concat:
960
+ emb = torch.cat([emb, class_emb], dim=-1)
961
+ else:
962
+ emb = emb + class_emb
963
+
964
+ if self.config.addition_embed_type == "text":
965
+ aug_emb = self.add_embedding(encoder_hidden_states)
966
+ elif self.config.addition_embed_type == "text_image":
967
+ # Kandinsky 2.1 - style
968
+ if "image_embeds" not in added_cond_kwargs:
969
+ raise ValueError(
970
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
971
+ )
972
+
973
+ image_embs = added_cond_kwargs.get("image_embeds")
974
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
975
+ aug_emb = self.add_embedding(text_embs, image_embs)
976
+ elif self.config.addition_embed_type == "text_time":
977
+ # SDXL - style
978
+ if "text_embeds" not in added_cond_kwargs:
979
+ raise ValueError(
980
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
981
+ )
982
+ text_embeds = added_cond_kwargs.get("text_embeds")
983
+ if "time_ids" not in added_cond_kwargs:
984
+ raise ValueError(
985
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
986
+ )
987
+ time_ids = added_cond_kwargs.get("time_ids")
988
+ time_embeds = self.add_time_proj(time_ids.flatten())
989
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
990
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
991
+ add_embeds = add_embeds.to(emb.dtype)
992
+ aug_emb = self.add_embedding(add_embeds)
993
+ elif self.config.addition_embed_type == "image":
994
+ # Kandinsky 2.2 - style
995
+ if "image_embeds" not in added_cond_kwargs:
996
+ raise ValueError(
997
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
998
+ )
999
+ image_embs = added_cond_kwargs.get("image_embeds")
1000
+ aug_emb = self.add_embedding(image_embs)
1001
+ elif self.config.addition_embed_type == "image_hint":
1002
+ # Kandinsky 2.2 - style
1003
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1004
+ raise ValueError(
1005
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1006
+ )
1007
+ image_embs = added_cond_kwargs.get("image_embeds")
1008
+ hint = added_cond_kwargs.get("hint")
1009
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1010
+ sample = torch.cat([sample, hint], dim=1)
1011
+
1012
+ emb = emb + aug_emb if aug_emb is not None else emb
1013
+
1014
+ if self.time_embed_act is not None:
1015
+ emb = self.time_embed_act(emb)
1016
+
1017
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1018
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1019
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1020
+ # Kadinsky 2.1 - style
1021
+ if "image_embeds" not in added_cond_kwargs:
1022
+ raise ValueError(
1023
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1024
+ )
1025
+
1026
+ image_embeds = added_cond_kwargs.get("image_embeds")
1027
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1028
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1029
+ # Kandinsky 2.2 - style
1030
+ if "image_embeds" not in added_cond_kwargs:
1031
+ raise ValueError(
1032
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1033
+ )
1034
+ image_embeds = added_cond_kwargs.get("image_embeds")
1035
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1036
+ # 2. pre-process
1037
+ sample = self.conv_in(sample)
1038
+
1039
+ # 2.5 GLIGEN position net
1040
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1041
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1042
+ gligen_args = cross_attention_kwargs.pop("gligen")
1043
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1044
+
1045
+ # for spatial attention
1046
+ spatial_attn_idx = 0
1047
+
1048
+ # 3. down
1049
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1050
+ if USE_PEFT_BACKEND:
1051
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1052
+ scale_lora_layers(self, lora_scale)
1053
+
1054
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1055
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1056
+ is_adapter = down_intrablock_additional_residuals is not None
1057
+ # maintain backward compatibility for legacy usage, where
1058
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1059
+ # but can only use one or the other
1060
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1061
+ deprecate(
1062
+ "T2I should not use down_block_additional_residuals",
1063
+ "1.3.0",
1064
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1065
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1066
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1067
+ standard_warn=False,
1068
+ )
1069
+ down_intrablock_additional_residuals = down_block_additional_residuals
1070
+ is_adapter = True
1071
+
1072
+ down_block_res_samples = (sample,)
1073
+ for downsample_block in self.down_blocks:
1074
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1075
+ # For t2i-adapter CrossAttnDownBlock2D
1076
+ additional_residuals = {}
1077
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1078
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1079
+
1080
+ sample, res_samples, spatial_attn_inputs, spatial_attn_idx = downsample_block(
1081
+ hidden_states=sample,
1082
+ spatial_attn_inputs=spatial_attn_inputs,
1083
+ spatial_attn_idx=spatial_attn_idx,
1084
+ temb=emb,
1085
+ encoder_hidden_states=encoder_hidden_states,
1086
+ attention_mask=attention_mask,
1087
+ cross_attention_kwargs=cross_attention_kwargs,
1088
+ encoder_attention_mask=encoder_attention_mask,
1089
+ **additional_residuals,
1090
+ )
1091
+ else:
1092
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
1093
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1094
+ sample += down_intrablock_additional_residuals.pop(0)
1095
+
1096
+ down_block_res_samples += res_samples
1097
+
1098
+ if is_controlnet:
1099
+ new_down_block_res_samples = ()
1100
+
1101
+ for down_block_res_sample, down_block_additional_residual in zip(
1102
+ down_block_res_samples, down_block_additional_residuals
1103
+ ):
1104
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1105
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1106
+
1107
+ down_block_res_samples = new_down_block_res_samples
1108
+
1109
+ # 4. mid
1110
+ if self.mid_block is not None:
1111
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1112
+ sample, spatial_attn_inputs, spatial_attn_idx = self.mid_block(
1113
+ sample,
1114
+ spatial_attn_inputs=spatial_attn_inputs,
1115
+ spatial_attn_idx=spatial_attn_idx,
1116
+ temb=emb,
1117
+ encoder_hidden_states=encoder_hidden_states,
1118
+ attention_mask=attention_mask,
1119
+ cross_attention_kwargs=cross_attention_kwargs,
1120
+ encoder_attention_mask=encoder_attention_mask,
1121
+ )
1122
+ else:
1123
+ sample = self.mid_block(sample, emb)
1124
+
1125
+ # To support T2I-Adapter-XL
1126
+ if (
1127
+ is_adapter
1128
+ and len(down_intrablock_additional_residuals) > 0
1129
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1130
+ ):
1131
+ sample += down_intrablock_additional_residuals.pop(0)
1132
+
1133
+ if is_controlnet:
1134
+ sample = sample + mid_block_additional_residual
1135
+
1136
+ # 5. up
1137
+ for i, upsample_block in enumerate(self.up_blocks):
1138
+ is_final_block = i == len(self.up_blocks) - 1
1139
+
1140
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1141
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1142
+
1143
+ # if we have not reached the final block and need to forward the
1144
+ # upsample size, we do it here
1145
+ if not is_final_block and forward_upsample_size:
1146
+ upsample_size = down_block_res_samples[-1].shape[2:]
1147
+
1148
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1149
+ sample, spatial_attn_inputs, spatial_attn_idx = upsample_block(
1150
+ hidden_states=sample,
1151
+ spatial_attn_inputs=spatial_attn_inputs,
1152
+ spatial_attn_idx=spatial_attn_idx,
1153
+ temb=emb,
1154
+ res_hidden_states_tuple=res_samples,
1155
+ encoder_hidden_states=encoder_hidden_states,
1156
+ cross_attention_kwargs=cross_attention_kwargs,
1157
+ upsample_size=upsample_size,
1158
+ attention_mask=attention_mask,
1159
+ encoder_attention_mask=encoder_attention_mask,
1160
+ )
1161
+ else:
1162
+ sample = upsample_block(
1163
+ hidden_states=sample,
1164
+ temb=emb,
1165
+ res_hidden_states_tuple=res_samples,
1166
+ upsample_size=upsample_size,
1167
+ scale=lora_scale,
1168
+ )
1169
+
1170
+ # 6. post-process
1171
+ if self.conv_norm_out:
1172
+ sample = self.conv_norm_out(sample)
1173
+ sample = self.conv_act(sample)
1174
+ sample = self.conv_out(sample)
1175
+
1176
+ if USE_PEFT_BACKEND:
1177
+ # remove `lora_scale` from each PEFT layer
1178
+ unscale_lora_layers(self, lora_scale)
1179
+
1180
+ if not return_dict:
1181
+ return (sample,)
1182
+
1183
+ return UNet2DConditionOutput(sample=sample)
OOTDiffusion/preprocess/humanparsing/datasets/__init__.py ADDED
File without changes
OOTDiffusion/preprocess/humanparsing/datasets/datasets.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : datasets.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import numpy as np
16
+ import random
17
+ import torch
18
+ import cv2
19
+ from torch.utils import data
20
+ from utils.transforms import get_affine_transform
21
+
22
+
23
+ class LIPDataSet(data.Dataset):
24
+ def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25,
25
+ rotation_factor=30, ignore_label=255, transform=None):
26
+ self.root = root
27
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
28
+ self.crop_size = np.asarray(crop_size)
29
+ self.ignore_label = ignore_label
30
+ self.scale_factor = scale_factor
31
+ self.rotation_factor = rotation_factor
32
+ self.flip_prob = 0.5
33
+ self.transform = transform
34
+ self.dataset = dataset
35
+
36
+ list_path = os.path.join(self.root, self.dataset + '_id.txt')
37
+ train_list = [i_id.strip() for i_id in open(list_path)]
38
+
39
+ self.train_list = train_list
40
+ self.number_samples = len(self.train_list)
41
+
42
+ def __len__(self):
43
+ return self.number_samples
44
+
45
+ def _box2cs(self, box):
46
+ x, y, w, h = box[:4]
47
+ return self._xywh2cs(x, y, w, h)
48
+
49
+ def _xywh2cs(self, x, y, w, h):
50
+ center = np.zeros((2), dtype=np.float32)
51
+ center[0] = x + w * 0.5
52
+ center[1] = y + h * 0.5
53
+ if w > self.aspect_ratio * h:
54
+ h = w * 1.0 / self.aspect_ratio
55
+ elif w < self.aspect_ratio * h:
56
+ w = h * self.aspect_ratio
57
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
58
+ return center, scale
59
+
60
+ def __getitem__(self, index):
61
+ train_item = self.train_list[index]
62
+
63
+ im_path = os.path.join(self.root, self.dataset + '_images', train_item + '.jpg')
64
+ parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', train_item + '.png')
65
+
66
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
67
+ h, w, _ = im.shape
68
+ parsing_anno = np.zeros((h, w), dtype=np.long)
69
+
70
+ # Get person center and scale
71
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
72
+ r = 0
73
+
74
+ if self.dataset != 'test':
75
+ # Get pose annotation
76
+ parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
77
+ if self.dataset == 'train' or self.dataset == 'trainval':
78
+ sf = self.scale_factor
79
+ rf = self.rotation_factor
80
+ s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
81
+ r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0
82
+
83
+ if random.random() <= self.flip_prob:
84
+ im = im[:, ::-1, :]
85
+ parsing_anno = parsing_anno[:, ::-1]
86
+ person_center[0] = im.shape[1] - person_center[0] - 1
87
+ right_idx = [15, 17, 19]
88
+ left_idx = [14, 16, 18]
89
+ for i in range(0, 3):
90
+ right_pos = np.where(parsing_anno == right_idx[i])
91
+ left_pos = np.where(parsing_anno == left_idx[i])
92
+ parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
93
+ parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]
94
+
95
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
96
+ input = cv2.warpAffine(
97
+ im,
98
+ trans,
99
+ (int(self.crop_size[1]), int(self.crop_size[0])),
100
+ flags=cv2.INTER_LINEAR,
101
+ borderMode=cv2.BORDER_CONSTANT,
102
+ borderValue=(0, 0, 0))
103
+
104
+ if self.transform:
105
+ input = self.transform(input)
106
+
107
+ meta = {
108
+ 'name': train_item,
109
+ 'center': person_center,
110
+ 'height': h,
111
+ 'width': w,
112
+ 'scale': s,
113
+ 'rotation': r
114
+ }
115
+
116
+ if self.dataset == 'val' or self.dataset == 'test':
117
+ return input, meta
118
+ else:
119
+ label_parsing = cv2.warpAffine(
120
+ parsing_anno,
121
+ trans,
122
+ (int(self.crop_size[1]), int(self.crop_size[0])),
123
+ flags=cv2.INTER_NEAREST,
124
+ borderMode=cv2.BORDER_CONSTANT,
125
+ borderValue=(255))
126
+
127
+ label_parsing = torch.from_numpy(label_parsing)
128
+
129
+ return input, label_parsing, meta
130
+
131
+
132
+ class LIPDataValSet(data.Dataset):
133
+ def __init__(self, root, dataset='val', crop_size=[473, 473], transform=None, flip=False):
134
+ self.root = root
135
+ self.crop_size = crop_size
136
+ self.transform = transform
137
+ self.flip = flip
138
+ self.dataset = dataset
139
+ self.root = root
140
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
141
+ self.crop_size = np.asarray(crop_size)
142
+
143
+ list_path = os.path.join(self.root, self.dataset + '_id.txt')
144
+ val_list = [i_id.strip() for i_id in open(list_path)]
145
+
146
+ self.val_list = val_list
147
+ self.number_samples = len(self.val_list)
148
+
149
+ def __len__(self):
150
+ return len(self.val_list)
151
+
152
+ def _box2cs(self, box):
153
+ x, y, w, h = box[:4]
154
+ return self._xywh2cs(x, y, w, h)
155
+
156
+ def _xywh2cs(self, x, y, w, h):
157
+ center = np.zeros((2), dtype=np.float32)
158
+ center[0] = x + w * 0.5
159
+ center[1] = y + h * 0.5
160
+ if w > self.aspect_ratio * h:
161
+ h = w * 1.0 / self.aspect_ratio
162
+ elif w < self.aspect_ratio * h:
163
+ w = h * self.aspect_ratio
164
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
165
+
166
+ return center, scale
167
+
168
+ def __getitem__(self, index):
169
+ val_item = self.val_list[index]
170
+ # Load training image
171
+ im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg')
172
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
173
+ h, w, _ = im.shape
174
+ # Get person center and scale
175
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
176
+ r = 0
177
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
178
+ input = cv2.warpAffine(
179
+ im,
180
+ trans,
181
+ (int(self.crop_size[1]), int(self.crop_size[0])),
182
+ flags=cv2.INTER_LINEAR,
183
+ borderMode=cv2.BORDER_CONSTANT,
184
+ borderValue=(0, 0, 0))
185
+ input = self.transform(input)
186
+ flip_input = input.flip(dims=[-1])
187
+ if self.flip:
188
+ batch_input_im = torch.stack([input, flip_input])
189
+ else:
190
+ batch_input_im = input
191
+
192
+ meta = {
193
+ 'name': val_item,
194
+ 'center': person_center,
195
+ 'height': h,
196
+ 'width': w,
197
+ 'scale': s,
198
+ 'rotation': r
199
+ }
200
+
201
+ return batch_input_im, meta
OOTDiffusion/preprocess/humanparsing/datasets/simple_extractor_dataset.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : dataset.py
8
+ @Time : 8/30/19 9:12 PM
9
+ @Desc : Dataset Definition
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import pdb
16
+
17
+ import cv2
18
+ import numpy as np
19
+ from PIL import Image
20
+ from torch.utils import data
21
+ from utils.transforms import get_affine_transform
22
+
23
+
24
+ class SimpleFolderDataset(data.Dataset):
25
+ def __init__(self, root, input_size=[512, 512], transform=None):
26
+ self.root = root
27
+ self.input_size = input_size
28
+ self.transform = transform
29
+ self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
30
+ self.input_size = np.asarray(input_size)
31
+ self.is_pil_image = False
32
+ if isinstance(root, Image.Image):
33
+ self.file_list = [root]
34
+ self.is_pil_image = True
35
+ elif os.path.isfile(root):
36
+ self.file_list = [os.path.basename(root)]
37
+ self.root = os.path.dirname(root)
38
+ else:
39
+ self.file_list = os.listdir(self.root)
40
+
41
+ def __len__(self):
42
+ return len(self.file_list)
43
+
44
+ def _box2cs(self, box):
45
+ x, y, w, h = box[:4]
46
+ return self._xywh2cs(x, y, w, h)
47
+
48
+ def _xywh2cs(self, x, y, w, h):
49
+ center = np.zeros((2), dtype=np.float32)
50
+ center[0] = x + w * 0.5
51
+ center[1] = y + h * 0.5
52
+ if w > self.aspect_ratio * h:
53
+ h = w * 1.0 / self.aspect_ratio
54
+ elif w < self.aspect_ratio * h:
55
+ w = h * self.aspect_ratio
56
+ scale = np.array([w, h], dtype=np.float32)
57
+ return center, scale
58
+
59
+ def __getitem__(self, index):
60
+ if self.is_pil_image:
61
+ img = np.asarray(self.file_list[index])[:, :, [2, 1, 0]]
62
+ else:
63
+ img_name = self.file_list[index]
64
+ img_path = os.path.join(self.root, img_name)
65
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
66
+ h, w, _ = img.shape
67
+
68
+ # Get person center and scale
69
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
70
+ r = 0
71
+ trans = get_affine_transform(person_center, s, r, self.input_size)
72
+ input = cv2.warpAffine(
73
+ img,
74
+ trans,
75
+ (int(self.input_size[1]), int(self.input_size[0])),
76
+ flags=cv2.INTER_LINEAR,
77
+ borderMode=cv2.BORDER_CONSTANT,
78
+ borderValue=(0, 0, 0))
79
+
80
+ input = self.transform(input)
81
+ meta = {
82
+ 'center': person_center,
83
+ 'height': h,
84
+ 'width': w,
85
+ 'scale': s,
86
+ 'rotation': r
87
+ }
88
+
89
+ return input, meta
OOTDiffusion/preprocess/humanparsing/datasets/target_generation.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def generate_edge_tensor(label, edge_width=3):
6
+ label = label.type(torch.cuda.FloatTensor)
7
+ if len(label.shape) == 2:
8
+ label = label.unsqueeze(0)
9
+ n, h, w = label.shape
10
+ edge = torch.zeros(label.shape, dtype=torch.float).cuda()
11
+ # right
12
+ edge_right = edge[:, 1:h, :]
13
+ edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
14
+ & (label[:, :h - 1, :] != 255)] = 1
15
+
16
+ # up
17
+ edge_up = edge[:, :, :w - 1]
18
+ edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
19
+ & (label[:, :, :w - 1] != 255)
20
+ & (label[:, :, 1:w] != 255)] = 1
21
+
22
+ # upright
23
+ edge_upright = edge[:, :h - 1, :w - 1]
24
+ edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
25
+ & (label[:, :h - 1, :w - 1] != 255)
26
+ & (label[:, 1:h, 1:w] != 255)] = 1
27
+
28
+ # bottomright
29
+ edge_bottomright = edge[:, :h - 1, 1:w]
30
+ edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
31
+ & (label[:, :h - 1, 1:w] != 255)
32
+ & (label[:, 1:h, :w - 1] != 255)] = 1
33
+
34
+ kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda()
35
+ with torch.no_grad():
36
+ edge = edge.unsqueeze(1)
37
+ edge = F.conv2d(edge, kernel, stride=1, padding=1)
38
+ edge[edge!=0] = 1
39
+ edge = edge.squeeze()
40
+ return edge
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4
+ Please read the [full text](https://code.fb.com/codeofconduct/)
5
+ so that you can understand what actions will and will not be tolerated.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to detectron2
2
+
3
+ ## Issues
4
+ We use GitHub issues to track public bugs and questions.
5
+ Please make sure to follow one of the
6
+ [issue templates](https://github.com/facebookresearch/detectron2/issues/new/choose)
7
+ when reporting any issues.
8
+
9
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
10
+ disclosure of security bugs. In those cases, please go through the process
11
+ outlined on that page and do not file a public issue.
12
+
13
+ ## Pull Requests
14
+ We actively welcome your pull requests.
15
+
16
+ However, if you're adding any significant features (e.g. > 50 lines), please
17
+ make sure to have a corresponding issue to discuss your motivation and proposals,
18
+ before sending a PR. We do not always accept new features, and we take the following
19
+ factors into consideration:
20
+
21
+ 1. Whether the same feature can be achieved without modifying detectron2.
22
+ Detectron2 is designed so that you can implement many extensions from the outside, e.g.
23
+ those in [projects](https://github.com/facebookresearch/detectron2/tree/master/projects).
24
+ If some part is not as extensible, you can also bring up the issue to make it more extensible.
25
+ 2. Whether the feature is potentially useful to a large audience, or only to a small portion of users.
26
+ 3. Whether the proposed solution has a good design / interface.
27
+ 4. Whether the proposed solution adds extra mental/practical overhead to users who don't
28
+ need such feature.
29
+ 5. Whether the proposed solution breaks existing APIs.
30
+
31
+ When sending a PR, please do:
32
+
33
+ 1. If a PR contains multiple orthogonal changes, split it to several PRs.
34
+ 2. If you've added code that should be tested, add tests.
35
+ 3. For PRs that need experiments (e.g. adding a new model or new methods),
36
+ you don't need to update model zoo, but do provide experiment results in the description of the PR.
37
+ 4. If APIs are changed, update the documentation.
38
+ 5. Make sure your code lints with `./dev/linter.sh`.
39
+
40
+
41
+ ## Contributor License Agreement ("CLA")
42
+ In order to accept your pull request, we need you to submit a CLA. You only need
43
+ to do this once to work on any of Facebook's open source projects.
44
+
45
+ Complete your CLA here: <https://code.facebook.com/cla>
46
+
47
+ ## License
48
+ By contributing to detectron2, you agree that your contributions will be licensed
49
+ under the LICENSE file in the root directory of this source tree.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/Detectron2-Logo-Horz.svg ADDED
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ Please select an issue template from
3
+ https://github.com/facebookresearch/detectron2/issues/new/choose .
4
+
5
+ Otherwise your issue will be closed.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/bugs.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "🐛 Bugs"
3
+ about: Report bugs in detectron2
4
+ title: Please read & provide the following
5
+
6
+ ---
7
+
8
+ ## Instructions To Reproduce the 🐛 Bug:
9
+
10
+ 1. what changes you made (`git diff`) or what code you wrote
11
+ ```
12
+ <put diff or code here>
13
+ ```
14
+ 2. what exact command you run:
15
+ 3. what you observed (including __full logs__):
16
+ ```
17
+ <put logs here>
18
+ ```
19
+ 4. please simplify the steps as much as possible so they do not require additional resources to
20
+ run, such as a private dataset.
21
+
22
+ ## Expected behavior:
23
+
24
+ If there are no obvious error in "what you observed" provided above,
25
+ please tell us the expected behavior.
26
+
27
+ ## Environment:
28
+
29
+ Provide your environment information using the following command:
30
+ ```
31
+ wget -nc -q https://github.com/facebookresearch/detectron2/raw/master/detectron2/utils/collect_env.py && python collect_env.py
32
+ ```
33
+
34
+ If your issue looks like an installation issue / environment issue,
35
+ please first try to solve it yourself with the instructions in
36
+ https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # require an issue template to be chosen
2
+ blank_issues_enabled: false
3
+
4
+ # Unexpected behaviors & bugs are split to two templates.
5
+ # When they are one template, users think "it's not a bug" and don't choose the template.
6
+ #
7
+ # But the file name is still "unexpected-problems-bugs.md" so that old references
8
+ # to this issue template still works.
9
+ # It's ok since this template should be a superset of "bugs.md" (unexpected behaviors is a superset of bugs)
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/feature-request.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "\U0001F680Feature Request"
3
+ about: Submit a proposal/request for a new detectron2 feature
4
+
5
+ ---
6
+
7
+ ## 🚀 Feature
8
+ A clear and concise description of the feature proposal.
9
+
10
+
11
+ ## Motivation & Examples
12
+
13
+ Tell us why the feature is useful.
14
+
15
+ Describe what the feature would look like, if it is implemented.
16
+ Best demonstrated using **code examples** in addition to words.
17
+
18
+ ## Note
19
+
20
+ We only consider adding new features if they are relevant to many users.
21
+
22
+ If you request implementation of research papers --
23
+ we only consider papers that have enough significance and prevalance in the object detection field.
24
+
25
+ We do not take requests for most projects in the `projects/` directory,
26
+ because they are research code release that is mainly for other researchers to reproduce results.
27
+
28
+ Instead of adding features inside detectron2,
29
+ you can implement many features by [extending detectron2](https://detectron2.readthedocs.io/tutorials/extend.html).
30
+ The [projects/](https://github.com/facebookresearch/detectron2/tree/master/projects/) directory contains many of such examples.
31
+
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/questions-help-support.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "❓How to do something?"
3
+ about: How to do something using detectron2? What does an API do?
4
+
5
+ ---
6
+
7
+ ## ❓ How to do something using detectron2
8
+
9
+ Describe what you want to do, including:
10
+ 1. what inputs you will provide, if any:
11
+ 2. what outputs you are expecting:
12
+
13
+ ## ❓ What does an API do and how to use it?
14
+ Please link to which API or documentation you're asking about from
15
+ https://detectron2.readthedocs.io/
16
+
17
+
18
+ NOTE:
19
+
20
+ 1. Only general answers are provided.
21
+ If you want to ask about "why X did not work", please use the
22
+ [Unexpected behaviors](https://github.com/facebookresearch/detectron2/issues/new/choose) issue template.
23
+
24
+ 2. About how to implement new models / new dataloader / new training logic, etc., check documentation first.
25
+
26
+ 3. We do not answer general machine learning / computer vision questions that are not specific to detectron2, such as how a model works, how to improve your training/make it converge, or what algorithm/methods can be used to achieve X.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "Unexpected behaviors"
3
+ about: Run into unexpected behaviors when using detectron2
4
+ title: Please read & provide the following
5
+
6
+ ---
7
+
8
+ If you do not know the root cause of the problem, and wish someone to help you, please
9
+ post according to this template:
10
+
11
+ ## Instructions To Reproduce the Issue:
12
+
13
+ 1. what changes you made (`git diff`) or what code you wrote
14
+ ```
15
+ <put diff or code here>
16
+ ```
17
+ 2. what exact command you run:
18
+ 3. what you observed (including __full logs__):
19
+ ```
20
+ <put logs here>
21
+ ```
22
+ 4. please simplify the steps as much as possible so they do not require additional resources to
23
+ run, such as a private dataset.
24
+
25
+ ## Expected behavior:
26
+
27
+ If there are no obvious error in "what you observed" provided above,
28
+ please tell us the expected behavior.
29
+
30
+ If you expect the model to converge / work better, note that we do not give suggestions
31
+ on how to train a new model.
32
+ Only in one of the two conditions we will help with it:
33
+ (1) You're unable to reproduce the results in detectron2 model zoo.
34
+ (2) It indicates a detectron2 bug.
35
+
36
+ ## Environment:
37
+
38
+ Provide your environment information using the following command:
39
+ ```
40
+ wget -nc -q https://github.com/facebookresearch/detectron2/raw/master/detectron2/utils/collect_env.py && python collect_env.py
41
+ ```
42
+
43
+ If your issue looks like an installation issue / environment issue,
44
+ please first try to solve it yourself with the instructions in
45
+ https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/.github/pull_request_template.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Thanks for your contribution!
2
+
3
+ If you're sending a large PR (e.g., >50 lines),
4
+ please open an issue first about the feature / bug, and indicate how you want to contribute.
5
+
6
+ Before submitting a PR, please run `dev/linter.sh` to lint the code.
7
+
8
+ See https://detectron2.readthedocs.io/notes/contributing.html#pull-requests
9
+ about how we handle PRs.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:10.1-cudnn7-devel
2
+
3
+ ENV DEBIAN_FRONTEND noninteractive
4
+ RUN apt-get update && apt-get install -y \
5
+ python3-opencv ca-certificates python3-dev git wget sudo \
6
+ cmake ninja-build protobuf-compiler libprotobuf-dev && \
7
+ rm -rf /var/lib/apt/lists/*
8
+ RUN ln -sv /usr/bin/python3 /usr/bin/python
9
+
10
+ # create a non-root user
11
+ ARG USER_ID=1000
12
+ RUN useradd -m --no-log-init --system --uid ${USER_ID} appuser -g sudo
13
+ RUN echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
14
+ USER appuser
15
+ WORKDIR /home/appuser
16
+
17
+ ENV PATH="/home/appuser/.local/bin:${PATH}"
18
+ RUN wget https://bootstrap.pypa.io/get-pip.py && \
19
+ python3 get-pip.py --user && \
20
+ rm get-pip.py
21
+
22
+ # install dependencies
23
+ # See https://pytorch.org/ for other options if you use a different version of CUDA
24
+ RUN pip install --user tensorboard cython
25
+ RUN pip install --user torch==1.5+cu101 torchvision==0.6+cu101 -f https://download.pytorch.org/whl/torch_stable.html
26
+ RUN pip install --user 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
27
+
28
+ RUN pip install --user 'git+https://github.com/facebookresearch/fvcore'
29
+ # install detectron2
30
+ RUN git clone https://github.com/facebookresearch/detectron2 detectron2_repo
31
+ # set FORCE_CUDA because during `docker build` cuda is not accessible
32
+ ENV FORCE_CUDA="1"
33
+ # This will by default build detectron2 for all common cuda architectures and take a lot more time,
34
+ # because inside `docker build`, there is no way to tell which architecture will be used.
35
+ ARG TORCH_CUDA_ARCH_LIST="Kepler;Kepler+Tesla;Maxwell;Maxwell+Tegra;Pascal;Volta;Turing"
36
+ ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}"
37
+
38
+ RUN pip install --user -e detectron2_repo
39
+
40
+ # Set a fixed model cache directory.
41
+ ENV FVCORE_CACHE="/tmp"
42
+ WORKDIR /home/appuser/detectron2_repo
43
+
44
+ # run detectron2 under user "appuser":
45
+ # wget http://images.cocodataset.org/val2017/000000439715.jpg -O input.jpg
46
+ # python3 demo/demo.py \
47
+ #--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
48
+ #--input input.jpg --output outputs/ \
49
+ #--opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/Dockerfile-circleci ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:10.1-cudnn7-devel
2
+ # This dockerfile only aims to provide an environment for unittest on CircleCI
3
+
4
+ ENV DEBIAN_FRONTEND noninteractive
5
+ RUN apt-get update && apt-get install -y \
6
+ python3-opencv ca-certificates python3-dev git wget sudo ninja-build && \
7
+ rm -rf /var/lib/apt/lists/*
8
+
9
+ RUN wget -q https://bootstrap.pypa.io/get-pip.py && \
10
+ python3 get-pip.py && \
11
+ rm get-pip.py
12
+
13
+ # install dependencies
14
+ # See https://pytorch.org/ for other options if you use a different version of CUDA
15
+ RUN pip install tensorboard cython
16
+ RUN pip install torch==1.5+cu101 torchvision==0.6+cu101 -f https://download.pytorch.org/whl/torch_stable.html
17
+ RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Use the container (with docker ≥ 19.03)
3
+
4
+ ```
5
+ cd docker/
6
+ # Build:
7
+ docker build --build-arg USER_ID=$UID -t detectron2:v0 .
8
+ # Run:
9
+ docker run --gpus all -it \
10
+ --shm-size=8gb --env="DISPLAY" --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \
11
+ --name=detectron2 detectron2:v0
12
+
13
+ # Grant docker access to host X server to show images
14
+ xhost +local:`docker inspect --format='{{ .Config.Hostname }}' detectron2`
15
+ ```
16
+
17
+ ## Use the container (with docker < 19.03)
18
+
19
+ Install docker-compose and nvidia-docker2, then run:
20
+ ```
21
+ cd docker && USER_ID=$UID docker-compose run detectron2
22
+ ```
23
+
24
+ #### Using a persistent cache directory
25
+
26
+ You can prevent models from being re-downloaded on every run,
27
+ by storing them in a cache directory.
28
+
29
+ To do this, add `--volume=$HOME/.torch/fvcore_cache:/tmp:rw` in the run command.
30
+
31
+ ## Install new dependencies
32
+ Add the following to `Dockerfile` to make persistent changes.
33
+ ```
34
+ RUN sudo apt-get update && sudo apt-get install -y vim
35
+ ```
36
+ Or run them in the container to make temporary changes.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docker/docker-compose.yml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "2.3"
2
+ services:
3
+ detectron2:
4
+ build:
5
+ context: .
6
+ dockerfile: Dockerfile
7
+ args:
8
+ USER_ID: ${USER_ID:-1000}
9
+ runtime: nvidia # TODO: Exchange with "gpu: all" in the future (see https://github.com/facebookresearch/detectron2/pull/197/commits/00545e1f376918db4a8ce264d427a07c1e896c5a).
10
+ shm_size: "8gb"
11
+ ulimits:
12
+ memlock: -1
13
+ stack: 67108864
14
+ volumes:
15
+ - /tmp/.X11-unix:/tmp/.X11-unix:ro
16
+ environment:
17
+ - DISPLAY=$DISPLAY
18
+ - NVIDIA_VISIBLE_DEVICES=all
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docs/tutorials/datasets.md ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Custom Datasets
2
+
3
+ Datasets that have builtin support in detectron2 are listed in [datasets](../../datasets).
4
+ If you want to use a custom dataset while also reusing detectron2's data loaders,
5
+ you will need to
6
+
7
+ 1. __Register__ your dataset (i.e., tell detectron2 how to obtain your dataset).
8
+ 2. Optionally, __register metadata__ for your dataset.
9
+
10
+ Next, we explain the above two concepts in detail.
11
+
12
+ The [Colab tutorial](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
13
+ has a live example of how to register and train on a dataset of custom formats.
14
+
15
+ ### Register a Dataset
16
+
17
+ To let detectron2 know how to obtain a dataset named "my_dataset", you will implement
18
+ a function that returns the items in your dataset and then tell detectron2 about this
19
+ function:
20
+ ```python
21
+ def my_dataset_function():
22
+ ...
23
+ return list[dict] in the following format
24
+
25
+ from detectron2.data import DatasetCatalog
26
+ DatasetCatalog.register("my_dataset", my_dataset_function)
27
+ ```
28
+
29
+ Here, the snippet associates a dataset "my_dataset" with a function that returns the data.
30
+ The registration stays effective until the process exists.
31
+
32
+ The function can processes data from its original format into either one of the following:
33
+ 1. Detectron2's standard dataset dict, described below. This will work with many other builtin
34
+ features in detectron2, so it's recommended to use it when it's sufficient for your task.
35
+ 2. Your custom dataset dict. You can also return arbitrary dicts in your own format,
36
+ such as adding extra keys for new tasks.
37
+ Then you will need to handle them properly downstream as well.
38
+ See below for more details.
39
+
40
+ #### Standard Dataset Dicts
41
+
42
+ For standard tasks
43
+ (instance detection, instance/semantic/panoptic segmentation, keypoint detection),
44
+ we load the original dataset into `list[dict]` with a specification similar to COCO's json annotations.
45
+ This is our standard representation for a dataset.
46
+
47
+ Each dict contains information about one image.
48
+ The dict may have the following fields,
49
+ and the required fields vary based on what the dataloader or the task needs (see more below).
50
+
51
+ + `file_name`: the full path to the image file. Will apply rotation and flipping if the image has such exif information.
52
+ + `height`, `width`: integer. The shape of image.
53
+ + `image_id` (str or int): a unique id that identifies this image. Used
54
+ during evaluation to identify the images, but a dataset may use it for different purposes.
55
+ + `annotations` (list[dict]): each dict corresponds to annotations of one instance
56
+ in this image. Required by instance detection/segmentation or keypoint detection tasks.
57
+
58
+ Images with empty `annotations` will by default be removed from training,
59
+ but can be included using `DATALOADER.FILTER_EMPTY_ANNOTATIONS`.
60
+
61
+ Each dict contains the following keys, of which `bbox`,`bbox_mode` and `category_id` are required:
62
+ + `bbox` (list[float]): list of 4 numbers representing the bounding box of the instance.
63
+ + `bbox_mode` (int): the format of bbox.
64
+ It must be a member of
65
+ [structures.BoxMode](../modules/structures.html#detectron2.structures.BoxMode).
66
+ Currently supports: `BoxMode.XYXY_ABS`, `BoxMode.XYWH_ABS`.
67
+ + `category_id` (int): an integer in the range [0, num_categories) representing the category label.
68
+ The value num_categories is reserved to represent the "background" category, if applicable.
69
+ + `segmentation` (list[list[float]] or dict): the segmentation mask of the instance.
70
+ + If `list[list[float]]`, it represents a list of polygons, one for each connected component
71
+ of the object. Each `list[float]` is one simple polygon in the format of `[x1, y1, ..., xn, yn]`.
72
+ The Xs and Ys are either relative coordinates in [0, 1], or absolute coordinates,
73
+ depend on whether "bbox_mode" is relative.
74
+ + If `dict`, it represents the per-pixel segmentation mask in COCO's RLE format. The dict should have
75
+ keys "size" and "counts". You can convert a uint8 segmentation mask of 0s and 1s into
76
+ RLE format by `pycocotools.mask.encode(np.asarray(mask, order="F"))`.
77
+ + `keypoints` (list[float]): in the format of [x1, y1, v1,..., xn, yn, vn].
78
+ v[i] means the [visibility](http://cocodataset.org/#format-data) of this keypoint.
79
+ `n` must be equal to the number of keypoint categories.
80
+ The Xs and Ys are either relative coordinates in [0, 1], or absolute coordinates,
81
+ depend on whether "bbox_mode" is relative.
82
+
83
+ Note that the coordinate annotations in COCO format are integers in range [0, H-1 or W-1].
84
+ By default, detectron2 adds 0.5 to absolute keypoint coordinates to convert them from discrete
85
+ pixel indices to floating point coordinates.
86
+ + `iscrowd`: 0 (default) or 1. Whether this instance is labeled as COCO's "crowd
87
+ region". Don't include this field if you don't know what it means.
88
+ + `sem_seg_file_name`: the full path to the ground truth semantic segmentation file.
89
+ Required by semantic segmentation task.
90
+ It should be an image whose pixel values are integer labels.
91
+
92
+
93
+ Fast R-CNN (with precomputed proposals) is rarely used today.
94
+ To train a Fast R-CNN, the following extra keys are needed:
95
+
96
+ + `proposal_boxes` (array): 2D numpy array with shape (K, 4) representing K precomputed proposal boxes for this image.
97
+ + `proposal_objectness_logits` (array): numpy array with shape (K, ), which corresponds to the objectness
98
+ logits of proposals in 'proposal_boxes'.
99
+ + `proposal_bbox_mode` (int): the format of the precomputed proposal bbox.
100
+ It must be a member of
101
+ [structures.BoxMode](../modules/structures.html#detectron2.structures.BoxMode).
102
+ Default is `BoxMode.XYXY_ABS`.
103
+
104
+ #### Custom Dataset Dicts for New Tasks
105
+
106
+ In the `list[dict]` that your dataset function returns, the dictionary can also have arbitrary custom data.
107
+ This will be useful for a new task that needs extra information not supported
108
+ by the standard dataset dicts. In this case, you need to make sure the downstream code can handle your data
109
+ correctly. Usually this requires writing a new `mapper` for the dataloader (see [Use Custom Dataloaders](./data_loading.md)).
110
+
111
+ When designing a custom format, note that all dicts are stored in memory
112
+ (sometimes serialized and with multiple copies).
113
+ To save memory, each dict is meant to contain small but sufficient information
114
+ about each sample, such as file names and annotations.
115
+ Loading full samples typically happens in the data loader.
116
+
117
+ For attributes shared among the entire dataset, use `Metadata` (see below).
118
+ To avoid extra memory, do not save such information repeatly for each sample.
119
+
120
+ ### "Metadata" for Datasets
121
+
122
+ Each dataset is associated with some metadata, accessible through
123
+ `MetadataCatalog.get(dataset_name).some_metadata`.
124
+ Metadata is a key-value mapping that contains information that's shared among
125
+ the entire dataset, and usually is used to interpret what's in the dataset, e.g.,
126
+ names of classes, colors of classes, root of files, etc.
127
+ This information will be useful for augmentation, evaluation, visualization, logging, etc.
128
+ The structure of metadata depends on the what is needed from the corresponding downstream code.
129
+
130
+ If you register a new dataset through `DatasetCatalog.register`,
131
+ you may also want to add its corresponding metadata through
132
+ `MetadataCatalog.get(dataset_name).some_key = some_value`, to enable any features that need the metadata.
133
+ You can do it like this (using the metadata key "thing_classes" as an example):
134
+
135
+ ```python
136
+ from detectron2.data import MetadataCatalog
137
+ MetadataCatalog.get("my_dataset").thing_classes = ["person", "dog"]
138
+ ```
139
+
140
+ Here is a list of metadata keys that are used by builtin features in detectron2.
141
+ If you add your own dataset without these metadata, some features may be
142
+ unavailable to you:
143
+
144
+ * `thing_classes` (list[str]): Used by all instance detection/segmentation tasks.
145
+ A list of names for each instance/thing category.
146
+ If you load a COCO format dataset, it will be automatically set by the function `load_coco_json`.
147
+
148
+ * `thing_colors` (list[tuple(r, g, b)]): Pre-defined color (in [0, 255]) for each thing category.
149
+ Used for visualization. If not given, random colors are used.
150
+
151
+ * `stuff_classes` (list[str]): Used by semantic and panoptic segmentation tasks.
152
+ A list of names for each stuff category.
153
+
154
+ * `stuff_colors` (list[tuple(r, g, b)]): Pre-defined color (in [0, 255]) for each stuff category.
155
+ Used for visualization. If not given, random colors are used.
156
+
157
+ * `keypoint_names` (list[str]): Used by keypoint localization. A list of names for each keypoint.
158
+
159
+ * `keypoint_flip_map` (list[tuple[str]]): Used by the keypoint localization task. A list of pairs of names,
160
+ where each pair are the two keypoints that should be flipped if the image is
161
+ flipped horizontally during augmentation.
162
+ * `keypoint_connection_rules`: list[tuple(str, str, (r, g, b))]. Each tuple specifies a pair of keypoints
163
+ that are connected and the color to use for the line between them when visualized.
164
+
165
+ Some additional metadata that are specific to the evaluation of certain datasets (e.g. COCO):
166
+
167
+ * `thing_dataset_id_to_contiguous_id` (dict[int->int]): Used by all instance detection/segmentation tasks in the COCO format.
168
+ A mapping from instance class ids in the dataset to contiguous ids in range [0, #class).
169
+ Will be automatically set by the function `load_coco_json`.
170
+
171
+ * `stuff_dataset_id_to_contiguous_id` (dict[int->int]): Used when generating prediction json files for
172
+ semantic/panoptic segmentation.
173
+ A mapping from semantic segmentation class ids in the dataset
174
+ to contiguous ids in [0, num_categories). It is useful for evaluation only.
175
+
176
+ * `json_file`: The COCO annotation json file. Used by COCO evaluation for COCO-format datasets.
177
+ * `panoptic_root`, `panoptic_json`: Used by panoptic evaluation.
178
+ * `evaluator_type`: Used by the builtin main training script to select
179
+ evaluator. Don't use it in a new training script.
180
+ You can just provide the [DatasetEvaluator](../modules/evaluation.html#detectron2.evaluation.DatasetEvaluator)
181
+ for your dataset directly in your main script.
182
+
183
+ NOTE: For background on the concept of "thing" and "stuff", see
184
+ [On Seeing Stuff: The Perception of Materials by Humans and Machines](http://persci.mit.edu/pub_pdfs/adelson_spie_01.pdf).
185
+ In detectron2, the term "thing" is used for instance-level tasks,
186
+ and "stuff" is used for semantic segmentation tasks.
187
+ Both are used in panoptic segmentation.
188
+
189
+ ### Register a COCO Format Dataset
190
+
191
+ If your dataset is already a json file in the COCO format,
192
+ the dataset and its associated metadata can be registered easily with:
193
+ ```python
194
+ from detectron2.data.datasets import register_coco_instances
195
+ register_coco_instances("my_dataset", {}, "json_annotation.json", "path/to/image/dir")
196
+ ```
197
+
198
+ If your dataset is in COCO format but with extra custom per-instance annotations,
199
+ the [load_coco_json](../modules/data.html#detectron2.data.datasets.load_coco_json)
200
+ function might be useful.
201
+
202
+ ### Update the Config for New Datasets
203
+
204
+ Once you've registered the dataset, you can use the name of the dataset (e.g., "my_dataset" in
205
+ example above) in `cfg.DATASETS.{TRAIN,TEST}`.
206
+ There are other configs you might want to change to train or evaluate on new datasets:
207
+
208
+ * `MODEL.ROI_HEADS.NUM_CLASSES` and `MODEL.RETINANET.NUM_CLASSES` are the number of thing classes
209
+ for R-CNN and RetinaNet models, respectively.
210
+ * `MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS` sets the number of keypoints for Keypoint R-CNN.
211
+ You'll also need to set [Keypoint OKS](http://cocodataset.org/#keypoints-eval)
212
+ with `TEST.KEYPOINT_OKS_SIGMAS` for evaluation.
213
+ * `MODEL.SEM_SEG_HEAD.NUM_CLASSES` sets the number of stuff classes for Semantic FPN & Panoptic FPN.
214
+ * If you're training Fast R-CNN (with precomputed proposals), `DATASETS.PROPOSAL_FILES_{TRAIN,TEST}`
215
+ need to match the datasets. The format of proposal files are documented
216
+ [here](../modules/data.html#detectron2.data.load_proposals_into_dataset).
217
+
218
+ New models
219
+ (e.g. [TensorMask](../../projects/TensorMask),
220
+ [PointRend](../../projects/PointRend))
221
+ often have similar configs of their own that need to be changed as well.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docs/tutorials/evaluation.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Evaluation
3
+
4
+ Evaluation is a process that takes a number of inputs/outputs pairs and aggregate them.
5
+ You can always [use the model](./models.md) directly and just parse its inputs/outputs manually to perform
6
+ evaluation.
7
+ Alternatively, evaluation is implemented in detectron2 using the [DatasetEvaluator](../modules/evaluation.html#detectron2.evaluation.DatasetEvaluator)
8
+ interface.
9
+
10
+ Detectron2 includes a few `DatasetEvaluator` that computes metrics using standard dataset-specific
11
+ APIs (e.g., COCO, LVIS).
12
+ You can also implement your own `DatasetEvaluator` that performs some other jobs
13
+ using the inputs/outputs pairs.
14
+ For example, to count how many instances are detected on the validation set:
15
+
16
+ ```
17
+ class Counter(DatasetEvaluator):
18
+ def reset(self):
19
+ self.count = 0
20
+ def process(self, inputs, outputs):
21
+ for output in outputs:
22
+ self.count += len(output["instances"])
23
+ def evaluate(self):
24
+ # save self.count somewhere, or print it, or return it.
25
+ return {"count": self.count}
26
+ ```
27
+
28
+ Once you have some `DatasetEvaluator`, you can run it with
29
+ [inference_on_dataset](../modules/evaluation.html#detectron2.evaluation.inference_on_dataset).
30
+ For example,
31
+
32
+ ```python
33
+ val_results = inference_on_dataset(
34
+ model,
35
+ val_data_loader,
36
+ DatasetEvaluators([COCOEvaluator(...), Counter()]))
37
+ ```
38
+ Compared to running the evaluation manually using the model, the benefit of this function is that
39
+ you can merge evaluators together using [DatasetEvaluators](../modules/evaluation.html#detectron2.evaluation.DatasetEvaluators).
40
+ In this way you can run all evaluations without having to go through the dataset multiple times.
41
+
42
+ The `inference_on_dataset` function also provides accurate speed benchmarks for the
43
+ given model and dataset.
OOTDiffusion/preprocess/humanparsing/mhp_extension/detectron2/docs/tutorials/index.rst ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tutorials
2
+ ======================================
3
+
4
+ .. toctree::
5
+ :maxdepth: 2
6
+
7
+ install
8
+ getting_started
9
+ builtin_datasets
10
+ extend
11
+ datasets
12
+ data_loading
13
+ models
14
+ write-models
15
+ training
16
+ evaluation
17
+ configs
18
+ deployment
OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/global_local_datasets.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : datasets.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import numpy as np
16
+ import random
17
+ import torch
18
+ import cv2
19
+ from torch.utils import data
20
+ from utils.transforms import get_affine_transform
21
+
22
+
23
+ class CropDataSet(data.Dataset):
24
+ def __init__(self, root, split_name, crop_size=[473, 473], scale_factor=0.25,
25
+ rotation_factor=30, ignore_label=255, transform=None):
26
+ self.root = root
27
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
28
+ self.crop_size = np.asarray(crop_size)
29
+ self.ignore_label = ignore_label
30
+ self.scale_factor = scale_factor
31
+ self.rotation_factor = rotation_factor
32
+ self.flip_prob = 0.5
33
+ self.transform = transform
34
+ self.split_name = split_name
35
+
36
+ list_path = os.path.join(self.root, self.split_name + '.txt')
37
+ train_list = [i_id.strip() for i_id in open(list_path)]
38
+
39
+ self.train_list = train_list
40
+ self.number_samples = len(self.train_list)
41
+
42
+ def __len__(self):
43
+ return self.number_samples
44
+
45
+ def _box2cs(self, box):
46
+ x, y, w, h = box[:4]
47
+ return self._xywh2cs(x, y, w, h)
48
+
49
+ def _xywh2cs(self, x, y, w, h):
50
+ center = np.zeros((2), dtype=np.float32)
51
+ center[0] = x + w * 0.5
52
+ center[1] = y + h * 0.5
53
+ if w > self.aspect_ratio * h:
54
+ h = w * 1.0 / self.aspect_ratio
55
+ elif w < self.aspect_ratio * h:
56
+ w = h * self.aspect_ratio
57
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
58
+ return center, scale
59
+
60
+ def __getitem__(self, index):
61
+ train_item = self.train_list[index]
62
+
63
+ im_path = os.path.join(self.root, self.split_name + '_images', train_item + '.jpg')
64
+ parsing_anno_path = os.path.join(self.root, self.split_name + '_segmentations', train_item + '.png')
65
+
66
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
67
+ h, w, _ = im.shape
68
+ parsing_anno = np.zeros((h, w), dtype=np.long)
69
+
70
+ # Get person center and scale
71
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
72
+ r = 0
73
+
74
+ if self.split_name != 'test':
75
+ # Get pose annotation
76
+ parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
77
+ sf = self.scale_factor
78
+ rf = self.rotation_factor
79
+ s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
80
+ r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0
81
+
82
+ if random.random() <= self.flip_prob:
83
+ im = im[:, ::-1, :]
84
+ parsing_anno = parsing_anno[:, ::-1]
85
+ person_center[0] = im.shape[1] - person_center[0] - 1
86
+ right_idx = [15, 17, 19]
87
+ left_idx = [14, 16, 18]
88
+ for i in range(0, 3):
89
+ right_pos = np.where(parsing_anno == right_idx[i])
90
+ left_pos = np.where(parsing_anno == left_idx[i])
91
+ parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
92
+ parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]
93
+
94
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
95
+ input = cv2.warpAffine(
96
+ im,
97
+ trans,
98
+ (int(self.crop_size[1]), int(self.crop_size[0])),
99
+ flags=cv2.INTER_LINEAR,
100
+ borderMode=cv2.BORDER_CONSTANT,
101
+ borderValue=(0, 0, 0))
102
+
103
+ if self.transform:
104
+ input = self.transform(input)
105
+
106
+ meta = {
107
+ 'name': train_item,
108
+ 'center': person_center,
109
+ 'height': h,
110
+ 'width': w,
111
+ 'scale': s,
112
+ 'rotation': r
113
+ }
114
+
115
+ if self.split_name == 'val' or self.split_name == 'test':
116
+ return input, meta
117
+ else:
118
+ label_parsing = cv2.warpAffine(
119
+ parsing_anno,
120
+ trans,
121
+ (int(self.crop_size[1]), int(self.crop_size[0])),
122
+ flags=cv2.INTER_NEAREST,
123
+ borderMode=cv2.BORDER_CONSTANT,
124
+ borderValue=(255))
125
+
126
+ label_parsing = torch.from_numpy(label_parsing)
127
+
128
+ return input, label_parsing, meta
129
+
130
+
131
+ class CropDataValSet(data.Dataset):
132
+ def __init__(self, root, split_name='crop_pic', crop_size=[473, 473], transform=None, flip=False):
133
+ self.root = root
134
+ self.crop_size = crop_size
135
+ self.transform = transform
136
+ self.flip = flip
137
+ self.split_name = split_name
138
+ self.root = root
139
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
140
+ self.crop_size = np.asarray(crop_size)
141
+
142
+ list_path = os.path.join(self.root, self.split_name + '.txt')
143
+ val_list = [i_id.strip() for i_id in open(list_path)]
144
+
145
+ self.val_list = val_list
146
+ self.number_samples = len(self.val_list)
147
+
148
+ def __len__(self):
149
+ return len(self.val_list)
150
+
151
+ def _box2cs(self, box):
152
+ x, y, w, h = box[:4]
153
+ return self._xywh2cs(x, y, w, h)
154
+
155
+ def _xywh2cs(self, x, y, w, h):
156
+ center = np.zeros((2), dtype=np.float32)
157
+ center[0] = x + w * 0.5
158
+ center[1] = y + h * 0.5
159
+ if w > self.aspect_ratio * h:
160
+ h = w * 1.0 / self.aspect_ratio
161
+ elif w < self.aspect_ratio * h:
162
+ w = h * self.aspect_ratio
163
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
164
+
165
+ return center, scale
166
+
167
+ def __getitem__(self, index):
168
+ val_item = self.val_list[index]
169
+ # Load training image
170
+ im_path = os.path.join(self.root, self.split_name, val_item + '.jpg')
171
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
172
+ h, w, _ = im.shape
173
+ # Get person center and scale
174
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
175
+ r = 0
176
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
177
+ input = cv2.warpAffine(
178
+ im,
179
+ trans,
180
+ (int(self.crop_size[1]), int(self.crop_size[0])),
181
+ flags=cv2.INTER_LINEAR,
182
+ borderMode=cv2.BORDER_CONSTANT,
183
+ borderValue=(0, 0, 0))
184
+ input = self.transform(input)
185
+ flip_input = input.flip(dims=[-1])
186
+ if self.flip:
187
+ batch_input_im = torch.stack([input, flip_input])
188
+ else:
189
+ batch_input_im = input
190
+
191
+ meta = {
192
+ 'name': val_item,
193
+ 'center': person_center,
194
+ 'height': h,
195
+ 'width': w,
196
+ 'scale': s,
197
+ 'rotation': r
198
+ }
199
+
200
+ return batch_input_im, meta
OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/global_local_evaluate.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : evaluate.py
8
+ @Time : 8/4/19 3:36 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import argparse
16
+ import numpy as np
17
+ import torch
18
+
19
+ from torch.utils import data
20
+ from tqdm import tqdm
21
+ from PIL import Image as PILImage
22
+ import torchvision.transforms as transforms
23
+ import torch.backends.cudnn as cudnn
24
+
25
+ import networks
26
+ from utils.miou import compute_mean_ioU
27
+ from utils.transforms import BGR2RGB_transform
28
+ from utils.transforms import transform_parsing, transform_logits
29
+ from mhp_extension.global_local_parsing.global_local_datasets import CropDataValSet
30
+
31
+
32
+ def get_arguments():
33
+ """Parse all the arguments provided from the CLI.
34
+
35
+ Returns:
36
+ A list of parsed arguments.
37
+ """
38
+ parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
39
+
40
+ # Network Structure
41
+ parser.add_argument("--arch", type=str, default='resnet101')
42
+ # Data Preference
43
+ parser.add_argument("--data-dir", type=str, default='./data/LIP')
44
+ parser.add_argument("--batch-size", type=int, default=1)
45
+ parser.add_argument("--split-name", type=str, default='crop_pic')
46
+ parser.add_argument("--input-size", type=str, default='473,473')
47
+ parser.add_argument("--num-classes", type=int, default=20)
48
+ parser.add_argument("--ignore-label", type=int, default=255)
49
+ parser.add_argument("--random-mirror", action="store_true")
50
+ parser.add_argument("--random-scale", action="store_true")
51
+ # Evaluation Preference
52
+ parser.add_argument("--log-dir", type=str, default='./log')
53
+ parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
54
+ parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
55
+ parser.add_argument("--save-results", action="store_true", help="whether to save the results.")
56
+ parser.add_argument("--flip", action="store_true", help="random flip during the test.")
57
+ parser.add_argument("--multi-scales", type=str, default='1', help="multiple scales during the test")
58
+ return parser.parse_args()
59
+
60
+
61
+ def get_palette(num_cls):
62
+ """ Returns the color map for visualizing the segmentation mask.
63
+ Args:
64
+ num_cls: Number of classes
65
+ Returns:
66
+ The color map
67
+ """
68
+ n = num_cls
69
+ palette = [0] * (n * 3)
70
+ for j in range(0, n):
71
+ lab = j
72
+ palette[j * 3 + 0] = 0
73
+ palette[j * 3 + 1] = 0
74
+ palette[j * 3 + 2] = 0
75
+ i = 0
76
+ while lab:
77
+ palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
78
+ palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
79
+ palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
80
+ i += 1
81
+ lab >>= 3
82
+ return palette
83
+
84
+
85
+ def multi_scale_testing(model, batch_input_im, crop_size=[473, 473], flip=True, multi_scales=[1]):
86
+ flipped_idx = (15, 14, 17, 16, 19, 18)
87
+ if len(batch_input_im.shape) > 4:
88
+ batch_input_im = batch_input_im.squeeze()
89
+ if len(batch_input_im.shape) == 3:
90
+ batch_input_im = batch_input_im.unsqueeze(0)
91
+
92
+ interp = torch.nn.Upsample(size=crop_size, mode='bilinear', align_corners=True)
93
+ ms_outputs = []
94
+ for s in multi_scales:
95
+ interp_im = torch.nn.Upsample(scale_factor=s, mode='bilinear', align_corners=True)
96
+ scaled_im = interp_im(batch_input_im)
97
+ parsing_output = model(scaled_im)
98
+ parsing_output = parsing_output[0][-1]
99
+ output = parsing_output[0]
100
+ if flip:
101
+ flipped_output = parsing_output[1]
102
+ flipped_output[14:20, :, :] = flipped_output[flipped_idx, :, :]
103
+ output += flipped_output.flip(dims=[-1])
104
+ output *= 0.5
105
+ output = interp(output.unsqueeze(0))
106
+ ms_outputs.append(output[0])
107
+ ms_fused_parsing_output = torch.stack(ms_outputs)
108
+ ms_fused_parsing_output = ms_fused_parsing_output.mean(0)
109
+ ms_fused_parsing_output = ms_fused_parsing_output.permute(1, 2, 0) # HWC
110
+ parsing = torch.argmax(ms_fused_parsing_output, dim=2)
111
+ parsing = parsing.data.cpu().numpy()
112
+ ms_fused_parsing_output = ms_fused_parsing_output.data.cpu().numpy()
113
+ return parsing, ms_fused_parsing_output
114
+
115
+
116
+ def main():
117
+ """Create the model and start the evaluation process."""
118
+ args = get_arguments()
119
+ multi_scales = [float(i) for i in args.multi_scales.split(',')]
120
+ gpus = [int(i) for i in args.gpu.split(',')]
121
+ assert len(gpus) == 1
122
+ if not args.gpu == 'None':
123
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
124
+
125
+ cudnn.benchmark = True
126
+ cudnn.enabled = True
127
+
128
+ h, w = map(int, args.input_size.split(','))
129
+ input_size = [h, w]
130
+
131
+ model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=None)
132
+
133
+ IMAGE_MEAN = model.mean
134
+ IMAGE_STD = model.std
135
+ INPUT_SPACE = model.input_space
136
+ print('image mean: {}'.format(IMAGE_MEAN))
137
+ print('image std: {}'.format(IMAGE_STD))
138
+ print('input space:{}'.format(INPUT_SPACE))
139
+ if INPUT_SPACE == 'BGR':
140
+ print('BGR Transformation')
141
+ transform = transforms.Compose([
142
+ transforms.ToTensor(),
143
+ transforms.Normalize(mean=IMAGE_MEAN,
144
+ std=IMAGE_STD),
145
+
146
+ ])
147
+ if INPUT_SPACE == 'RGB':
148
+ print('RGB Transformation')
149
+ transform = transforms.Compose([
150
+ transforms.ToTensor(),
151
+ BGR2RGB_transform(),
152
+ transforms.Normalize(mean=IMAGE_MEAN,
153
+ std=IMAGE_STD),
154
+ ])
155
+
156
+ # Data loader
157
+ lip_test_dataset = CropDataValSet(args.data_dir, args.split_name, crop_size=input_size, transform=transform,
158
+ flip=args.flip)
159
+ num_samples = len(lip_test_dataset)
160
+ print('Totoal testing sample numbers: {}'.format(num_samples))
161
+ testloader = data.DataLoader(lip_test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
162
+
163
+ # Load model weight
164
+ state_dict = torch.load(args.model_restore)
165
+ from collections import OrderedDict
166
+ new_state_dict = OrderedDict()
167
+ for k, v in state_dict.items():
168
+ name = k[7:] # remove `module.`
169
+ new_state_dict[name] = v
170
+ model.load_state_dict(new_state_dict)
171
+ model.cuda()
172
+ model.eval()
173
+
174
+ sp_results_dir = os.path.join(args.log_dir, args.split_name + '_parsing')
175
+ if not os.path.exists(sp_results_dir):
176
+ os.makedirs(sp_results_dir)
177
+
178
+ palette = get_palette(20)
179
+ parsing_preds = []
180
+ scales = np.zeros((num_samples, 2), dtype=np.float32)
181
+ centers = np.zeros((num_samples, 2), dtype=np.int32)
182
+ with torch.no_grad():
183
+ for idx, batch in enumerate(tqdm(testloader)):
184
+ image, meta = batch
185
+ if (len(image.shape) > 4):
186
+ image = image.squeeze()
187
+ im_name = meta['name'][0]
188
+ c = meta['center'].numpy()[0]
189
+ s = meta['scale'].numpy()[0]
190
+ w = meta['width'].numpy()[0]
191
+ h = meta['height'].numpy()[0]
192
+ scales[idx, :] = s
193
+ centers[idx, :] = c
194
+ parsing, logits = multi_scale_testing(model, image.cuda(), crop_size=input_size, flip=args.flip,
195
+ multi_scales=multi_scales)
196
+ if args.save_results:
197
+ parsing_result = transform_parsing(parsing, c, s, w, h, input_size)
198
+ parsing_result_path = os.path.join(sp_results_dir, im_name + '.png')
199
+ output_im = PILImage.fromarray(np.asarray(parsing_result, dtype=np.uint8))
200
+ output_im.putpalette(palette)
201
+ output_im.save(parsing_result_path)
202
+ # save logits
203
+ logits_result = transform_logits(logits, c, s, w, h, input_size)
204
+ logits_result_path = os.path.join(sp_results_dir, im_name + '.npy')
205
+ np.save(logits_result_path, logits_result)
206
+ return
207
+
208
+
209
+ if __name__ == '__main__':
210
+ main()
OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/global_local_train.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : train.py
8
+ @Time : 8/4/19 3:36 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import json
16
+ import timeit
17
+ import argparse
18
+
19
+ import torch
20
+ import torch.optim as optim
21
+ import torchvision.transforms as transforms
22
+ import torch.backends.cudnn as cudnn
23
+ from torch.utils import data
24
+
25
+ import networks
26
+ import utils.schp as schp
27
+ from datasets.datasets import LIPDataSet
28
+ from datasets.target_generation import generate_edge_tensor
29
+ from utils.transforms import BGR2RGB_transform
30
+ from utils.criterion import CriterionAll
31
+ from utils.encoding import DataParallelModel, DataParallelCriterion
32
+ from utils.warmup_scheduler import SGDRScheduler
33
+
34
+
35
+ def get_arguments():
36
+ """Parse all the arguments provided from the CLI.
37
+ Returns:
38
+ A list of parsed arguments.
39
+ """
40
+ parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
41
+
42
+ # Network Structure
43
+ parser.add_argument("--arch", type=str, default='resnet101')
44
+ # Data Preference
45
+ parser.add_argument("--data-dir", type=str, default='./data/LIP')
46
+ parser.add_argument("--batch-size", type=int, default=16)
47
+ parser.add_argument("--input-size", type=str, default='473,473')
48
+ parser.add_argument("--split-name", type=str, default='crop_pic')
49
+ parser.add_argument("--num-classes", type=int, default=20)
50
+ parser.add_argument("--ignore-label", type=int, default=255)
51
+ parser.add_argument("--random-mirror", action="store_true")
52
+ parser.add_argument("--random-scale", action="store_true")
53
+ # Training Strategy
54
+ parser.add_argument("--learning-rate", type=float, default=7e-3)
55
+ parser.add_argument("--momentum", type=float, default=0.9)
56
+ parser.add_argument("--weight-decay", type=float, default=5e-4)
57
+ parser.add_argument("--gpu", type=str, default='0,1,2')
58
+ parser.add_argument("--start-epoch", type=int, default=0)
59
+ parser.add_argument("--epochs", type=int, default=150)
60
+ parser.add_argument("--eval-epochs", type=int, default=10)
61
+ parser.add_argument("--imagenet-pretrain", type=str, default='./pretrain_model/resnet101-imagenet.pth')
62
+ parser.add_argument("--log-dir", type=str, default='./log')
63
+ parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
64
+ parser.add_argument("--schp-start", type=int, default=100, help='schp start epoch')
65
+ parser.add_argument("--cycle-epochs", type=int, default=10, help='schp cyclical epoch')
66
+ parser.add_argument("--schp-restore", type=str, default='./log/schp_checkpoint.pth.tar')
67
+ parser.add_argument("--lambda-s", type=float, default=1, help='segmentation loss weight')
68
+ parser.add_argument("--lambda-e", type=float, default=1, help='edge loss weight')
69
+ parser.add_argument("--lambda-c", type=float, default=0.1, help='segmentation-edge consistency loss weight')
70
+ return parser.parse_args()
71
+
72
+
73
+ def main():
74
+ args = get_arguments()
75
+ print(args)
76
+
77
+ start_epoch = 0
78
+ cycle_n = 0
79
+
80
+ if not os.path.exists(args.log_dir):
81
+ os.makedirs(args.log_dir)
82
+ with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
83
+ json.dump(vars(args), opt_file)
84
+
85
+ gpus = [int(i) for i in args.gpu.split(',')]
86
+ if not args.gpu == 'None':
87
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
88
+
89
+ input_size = list(map(int, args.input_size.split(',')))
90
+
91
+ cudnn.enabled = True
92
+ cudnn.benchmark = True
93
+
94
+ # Model Initialization
95
+ AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
96
+ model = DataParallelModel(AugmentCE2P)
97
+ model.cuda()
98
+
99
+ IMAGE_MEAN = AugmentCE2P.mean
100
+ IMAGE_STD = AugmentCE2P.std
101
+ INPUT_SPACE = AugmentCE2P.input_space
102
+ print('image mean: {}'.format(IMAGE_MEAN))
103
+ print('image std: {}'.format(IMAGE_STD))
104
+ print('input space:{}'.format(INPUT_SPACE))
105
+
106
+ restore_from = args.model_restore
107
+ if os.path.exists(restore_from):
108
+ print('Resume training from {}'.format(restore_from))
109
+ checkpoint = torch.load(restore_from)
110
+ model.load_state_dict(checkpoint['state_dict'])
111
+ start_epoch = checkpoint['epoch']
112
+
113
+ SCHP_AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
114
+ schp_model = DataParallelModel(SCHP_AugmentCE2P)
115
+ schp_model.cuda()
116
+
117
+ if os.path.exists(args.schp_restore):
118
+ print('Resuming schp checkpoint from {}'.format(args.schp_restore))
119
+ schp_checkpoint = torch.load(args.schp_restore)
120
+ schp_model_state_dict = schp_checkpoint['state_dict']
121
+ cycle_n = schp_checkpoint['cycle_n']
122
+ schp_model.load_state_dict(schp_model_state_dict)
123
+
124
+ # Loss Function
125
+ criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
126
+ num_classes=args.num_classes)
127
+ criterion = DataParallelCriterion(criterion)
128
+ criterion.cuda()
129
+
130
+ # Data Loader
131
+ if INPUT_SPACE == 'BGR':
132
+ print('BGR Transformation')
133
+ transform = transforms.Compose([
134
+ transforms.ToTensor(),
135
+ transforms.Normalize(mean=IMAGE_MEAN,
136
+ std=IMAGE_STD),
137
+ ])
138
+
139
+ elif INPUT_SPACE == 'RGB':
140
+ print('RGB Transformation')
141
+ transform = transforms.Compose([
142
+ transforms.ToTensor(),
143
+ BGR2RGB_transform(),
144
+ transforms.Normalize(mean=IMAGE_MEAN,
145
+ std=IMAGE_STD),
146
+ ])
147
+
148
+ train_dataset = LIPDataSet(args.data_dir, args.split_name, crop_size=input_size, transform=transform)
149
+ train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * len(gpus),
150
+ num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
151
+ print('Total training samples: {}'.format(len(train_dataset)))
152
+
153
+ # Optimizer Initialization
154
+ optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
155
+ weight_decay=args.weight_decay)
156
+
157
+ lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
158
+ eta_min=args.learning_rate / 100, warmup_epoch=10,
159
+ start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
160
+ cyclical_epoch=args.cycle_epochs)
161
+
162
+ total_iters = args.epochs * len(train_loader)
163
+ start = timeit.default_timer()
164
+ for epoch in range(start_epoch, args.epochs):
165
+ lr_scheduler.step(epoch=epoch)
166
+ lr = lr_scheduler.get_lr()[0]
167
+
168
+ model.train()
169
+ for i_iter, batch in enumerate(train_loader):
170
+ i_iter += len(train_loader) * epoch
171
+
172
+ images, labels, _ = batch
173
+ labels = labels.cuda(non_blocking=True)
174
+
175
+ edges = generate_edge_tensor(labels)
176
+ labels = labels.type(torch.cuda.LongTensor)
177
+ edges = edges.type(torch.cuda.LongTensor)
178
+
179
+ preds = model(images)
180
+
181
+ # Online Self Correction Cycle with Label Refinement
182
+ if cycle_n >= 1:
183
+ with torch.no_grad():
184
+ soft_preds = schp_model(images)
185
+ soft_parsing = []
186
+ soft_edge = []
187
+ for soft_pred in soft_preds:
188
+ soft_parsing.append(soft_pred[0][-1])
189
+ soft_edge.append(soft_pred[1][-1])
190
+ soft_preds = torch.cat(soft_parsing, dim=0)
191
+ soft_edges = torch.cat(soft_edge, dim=0)
192
+ else:
193
+ soft_preds = None
194
+ soft_edges = None
195
+
196
+ loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)
197
+
198
+ optimizer.zero_grad()
199
+ loss.backward()
200
+ optimizer.step()
201
+
202
+ if i_iter % 100 == 0:
203
+ print('iter = {} of {} completed, lr = {}, loss = {}'.format(i_iter, total_iters, lr,
204
+ loss.data.cpu().numpy()))
205
+ if (epoch + 1) % (args.eval_epochs) == 0:
206
+ schp.save_checkpoint({
207
+ 'epoch': epoch + 1,
208
+ 'state_dict': model.state_dict(),
209
+ }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))
210
+
211
+ # Self Correction Cycle with Model Aggregation
212
+ if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
213
+ print('Self-correction cycle number {}'.format(cycle_n))
214
+ schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
215
+ cycle_n += 1
216
+ schp.bn_re_estimate(train_loader, schp_model)
217
+ schp.save_schp_checkpoint({
218
+ 'state_dict': schp_model.state_dict(),
219
+ 'cycle_n': cycle_n,
220
+ }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))
221
+
222
+ torch.cuda.empty_cache()
223
+ end = timeit.default_timer()
224
+ print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
225
+ (end - start) / (epoch - start_epoch + 1)))
226
+
227
+ end = timeit.default_timer()
228
+ print('Training Finished in {} seconds'.format(end - start))
229
+
230
+
231
+ if __name__ == '__main__':
232
+ main()
OOTDiffusion/preprocess/humanparsing/mhp_extension/global_local_parsing/make_id_list.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ DATASET = 'VIP' # DATASET: MHPv2 or CIHP or VIP
4
+ TYPE = 'crop_pic' # crop_pic or DemoDataset
5
+ IMG_DIR = '../demo/cropped_img/crop_pic'
6
+ SAVE_DIR = '../demo/cropped_img'
7
+
8
+ if not os.path.exists(SAVE_DIR):
9
+ os.makedirs(SAVE_DIR)
10
+
11
+ with open(os.path.join(SAVE_DIR, TYPE + '.txt'), "w") as f:
12
+ for img_name in os.listdir(IMG_DIR):
13
+ f.write(img_name[:-4] + '\n')
OOTDiffusion/run/examples/garment/00055_00.jpg ADDED

Git LFS Details

  • SHA256: c9c13719456663be040a63711d0ee92c6bac2e259017b072d396d874bdc367ad
  • Pointer size: 130 Bytes
  • Size of remote file: 94.8 kB
OOTDiffusion/run/examples/garment/00126_00.jpg ADDED

Git LFS Details

  • SHA256: 661ebf9c36b4ef974503023f8e59fe28937757646fb27b21c960869ab1cf3ac1
  • Pointer size: 130 Bytes
  • Size of remote file: 71.1 kB
OOTDiffusion/run/examples/garment/00151_00.jpg ADDED

Git LFS Details

  • SHA256: d8ad0a63cdacfa9e7876c25624195def2f4005cc9eb8e63b3f8f275b4fa0d7f7
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB
OOTDiffusion/run/examples/garment/00470_00.jpg ADDED

Git LFS Details

  • SHA256: b9d10eb367818c666c73b3ab42e1a62c51b3a0bea2b2284dd0a675980fb87a5b
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
OOTDiffusion/run/examples/garment/02015_00.jpg ADDED

Git LFS Details

  • SHA256: 54fdb08bde028706e345c35e2dbd8d3c2a4b57b2036703dc626b6a1597431f75
  • Pointer size: 130 Bytes
  • Size of remote file: 95.4 kB
OOTDiffusion/run/examples/garment/02305_00.jpg ADDED

Git LFS Details

  • SHA256: 6a836de0928148a0c8a993a36a49fa6559e2709c9913caff4c692496a6b86c68
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
OOTDiffusion/run/examples/garment/03032_00.jpg ADDED

Git LFS Details

  • SHA256: f129842d75d9cd347414710bdc148b3277abae85b17f44b71158fd5c8d641eed
  • Pointer size: 130 Bytes
  • Size of remote file: 84.1 kB
OOTDiffusion/run/examples/garment/03244_00.jpg ADDED

Git LFS Details

  • SHA256: 1216be01fdc43398f2adc4d747e12c6c8d4304465396525a73fd4b0260d5b074
  • Pointer size: 130 Bytes
  • Size of remote file: 88.7 kB
OOTDiffusion/run/examples/garment/04825_00.jpg ADDED

Git LFS Details

  • SHA256: 45ed29d00f63b8db2ee3952afbb85dbcb3b993eed2328667a3d72e58b31db2aa
  • Pointer size: 131 Bytes
  • Size of remote file: 168 kB