diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..19f846cfea6e8d0869d82423a206db6e52aaa6ff 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1d33b2065c4198f2633d0a0edb3494baf998fc2a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+checkpoints
\ No newline at end of file
diff --git a/README.md b/README.md
index 8e3db96511bfcc98683847966b838b0e35baddd3..7c8a65c66dce8a0352486211b9b14fc6caf4bea2 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,34 @@
---
-title: VITON HD
-emoji: 🌍
-colorFrom: green
-colorTo: indigo
+title: Virtual Try-On
+emoji: 👗
+colorFrom: pink
+colorTo: purple
sdk: gradio
sdk_version: 5.34.2
app_file: app.py
-pinned: false
-license: cc-by-nc-sa-4.0
-short_description: Virtual try-on
+pinned: true
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Virtual Try-On Demo
+This repository is a working demo implementation of [PromptDresser](https://arxiv.org/abs/2412.16978).
+
+> **PromptDresser: Improving the Quality and Controllability of Virtual Try-On via Generative Textual Prompt and Prompt-aware Mask**
+> [Jeongho Kim](https://scholar.google.co.kr/citations?user=4SCCBFwAAAAJ&hl=ko), [Hoiyeong Jin](https://scholar.google.com/citations?user=Jp-zhtUAAAAJ&hl=en), [Sunghyun Park](https://psh01087.github.io/), [Jaegul Choo](https://sites.google.com/site/jaegulchoo/)
+
+[[arXiv Paper](https://arxiv.org/abs/2412.16978)]
+
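+## Running the demo
+`app.py` builds the PromptDresser SDXL-inpainting pipeline and serves it through a Gradio UI (`python app.py`). Below is a minimal sketch of calling the same entry point programmatically, using the sample images bundled under `./test/`; importing `app` loads all model weights, so the first call is slow.
+
+```python
+from PIL import Image
+
+from app import generate_vton  # importing app.py downloads and loads the models
+
+person = Image.open("./test/person2.png")
+cloth = Image.open("./test/00008_00.jpg")
+result = generate_vton(person, cloth, outfit_prompt="man in skirt", clothing_prompt="black longsleeve")
+result.save("result.png")
+```
+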
+## Citation
+```
+@misc{kim2024promptdresserimprovingqualitycontrollability,
+ title={PromptDresser: Improving the Quality and Controllability of Virtual Try-On via Generative Textual Prompt and Prompt-aware Mask},
+ author={Jeongho Kim and Hoiyeong Jin and Sunghyun Park and Jaegul Choo},
+ year={2024},
+ eprint={2412.16978},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2412.16978},
+}
+```
+
+## License
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..38106b6862149952c190f0c88ee516c60725e88d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,150 @@
+import os
+import torch
+import gradio as gr
+import tempfile
+from huggingface_hub import hf_hub_download
+from diffusers import AutoencoderKL, DDPMScheduler
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+
+from promptdresser.models.unet import UNet2DConditionModel
+from promptdresser.models.cloth_encoder import ClothEncoder
+from promptdresser.pipelines.sdxl import PromptDresser
+from lib.caption import generate_caption
+from lib.mask import generate_clothing_mask
+from lib.pose import generate_openpose
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+weight_dtype = torch.float16 if device == "cuda" else torch.float32
+
+def load_models():
+    print("⚙️ Loading models...")
+
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+ subfolder="scheduler"
+ )
+ tokenizer = CLIPTokenizer.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="text_encoder")
+ tokenizer_2 = CLIPTokenizer.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="tokenizer_2")
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="text_encoder_2")
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
+ unet = UNet2DConditionModel.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet")
+ cloth_encoder = ClothEncoder.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
+
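+    # Download the VITON-HD UNet checkpoint from the Hugging Face Hub (cached under ./checkpoints).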
+ unet_checkpoint_path = hf_hub_download(
+ repo_id="Benrise/VITON-HD",
+ filename="VITONHD/model/pytorch_model.bin",
+ cache_dir="checkpoints"
+ )
+    unet.load_state_dict(torch.load(unet_checkpoint_path, map_location="cpu"))
+
+ models = {
+ "unet": unet.to(device, dtype=weight_dtype),
+ "vae": vae.to(device, dtype=weight_dtype),
+ "text_encoder": text_encoder.to(device, dtype=weight_dtype),
+ "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
+ "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
+ "noise_scheduler": noise_scheduler,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2
+ }
+
+ pipeline = PromptDresser(
+ vae=models["vae"],
+ text_encoder=models["text_encoder"],
+ text_encoder_2=models["text_encoder_2"],
+ tokenizer=models["tokenizer"],
+ tokenizer_2=models["tokenizer_2"],
+ unet=models["unet"],
+ scheduler=models["noise_scheduler"],
+ ).to(device, dtype=weight_dtype)
+
+ return {**models, "pipeline": pipeline}
+
+models = load_models()
+pipeline = models["pipeline"]
+
+def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt=""):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ person_path = os.path.join(tmp_dir, "person.png")
+ cloth_path = os.path.join(tmp_dir, "cloth.png")
+
+ person_image.save(person_path)
+ cloth_image.save(cloth_path)
+
+ mask_path = os.path.join(tmp_dir, "mask.png")
+ pose_path = os.path.join(tmp_dir, "pose.png")
+
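+        # Preprocessing: segment the clothing region to repaint (SegFormer) and extract the person's pose map (OpenPose).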
+ mask_image = generate_clothing_mask(person_path, label=4, output_path=mask_path, show_result=False)
+ pose_image = generate_openpose(person_path, output_image_path=pose_path, show_result=False)
+
+ auto_outfit_prompt = generate_caption(person_path, device)
+ auto_clothing_prompt = generate_caption(cloth_path, device)
+
+ final_outfit_prompt = outfit_prompt or auto_outfit_prompt
+ final_clothing_prompt = clothing_prompt or auto_clothing_prompt
+
+ with torch.autocast(device):
+ result = pipeline(
+ image=person_image,
+ mask_image=mask_image,
+ pose_image=pose_image,
+ cloth_encoder=models["cloth_encoder"],
+ cloth_encoder_image=cloth_image,
+ prompt=final_outfit_prompt,
+ prompt_clothing=final_clothing_prompt,
+ height=1024,
+ width=768,
+ guidance_scale=2.0,
+ guidance_scale_img=4.5,
+ guidance_scale_text=7.5,
+ num_inference_steps=30,
+ strength=1,
+ interm_cloth_start_ratio=0.5,
+ generator=None,
+ ).images[0]
+
+ return result
+
+with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
+ gr.Markdown("# 🧥 Virtual Try-On")
+ gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")
+
+ with gr.Row():
+ with gr.Column():
+            person_input = gr.Image(label="Person photo", type="pil", sources=["upload"])
+            cloth_input = gr.Image(label="Garment photo", type="pil", sources=["upload"])
+            outfit_prompt = gr.Textbox(label="Outfit description (optional)", placeholder="e.g. man in casual outfit")
+            clothing_prompt = gr.Textbox(label="Garment description (optional)", placeholder="e.g. red t-shirt with print")
+            generate_btn = gr.Button("Generate try-on", variant="primary")
+
+ gr.Examples(
+ examples=[
+ ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve"]
+ ],
+ inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
+                label="Examples for quick testing"
+ )
+
+ with gr.Column():
+            output_image = gr.Image(label="Try-on result", interactive=False)
+
+ generate_btn.click(
+ fn=generate_vton,
+ inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt],
+ outputs=output_image
+ )
+
+ gr.Markdown("### Инструкция:")
+ gr.Markdown("1. Загрузите четкое фото человека в полный рост\n"
+ "2. Загрузите фото одежды на белом фоне\n"
+ "3. При необходимости уточните описание образа или одежды\n"
+ "4. Нажмите кнопку 'Сгенерировать примерку'")
+
+if __name__ == "__main__":
+ demo.queue(max_size=3).launch(
+ server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
+ share=os.getenv("GRADIO_SHARE") == "True",
+ debug=True
+ )
\ No newline at end of file
diff --git a/configs/VITONHD.yaml b/configs/VITONHD.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..06cfd571a125c2144601a3c44c61af10793f12cb
--- /dev/null
+++ b/configs/VITONHD.yaml
@@ -0,0 +1,32 @@
+no_pose: True
+use_jointcond: True
+no_ipadapter: True
+
+use_interm_cloth_mask: True
+interm_cloth_start_ratio: 0.5
+
+dataset:
+ dataset_name: "VITONHDDataset"
+ data_root_dir: "./DATA/zalando-hd-resized"
+ img_spatial_transform_lst:
+ - "randomresizedcrop"
+ - "randomaffine"
+ cloth_spatial_transform_lst:
+ - "randomresizedcrop"
+ - "randomaffine"
+ img_cloth_spatial_transform_lst:
+ - "hflip"
+ color_transform_lst:
+ - "colorjitter"
+ i_drop_rate: 0.05
+ pose_type: "densepose"
+ train_folder_name: train_fine
+ test_folder_name: test_fine
+ prompt_version: v12
+ text_file_postfix: "gpt4o.json"
+ train_folder_name_for_interm_cloth_mask: train_coarse
+ test_folder_name_for_interm_cloth_mask: test_coarse
+ use_rand_dilate: True
+
+ rand_dilate_miniter: 0
+ rand_dilate_maxiter: 200
\ No newline at end of file
diff --git a/lib/caption.py b/lib/caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6fc8c7a342e1f48958e34675e6fde8e2275750b
--- /dev/null
+++ b/lib/caption.py
@@ -0,0 +1,19 @@
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForCausalLM
+
+
+def generate_caption(image_path, device="cuda"):
+ print("Генерация подписи...")
+ processor = AutoProcessor.from_pretrained("microsoft/git-base", use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained("microsoft/git-base").to(device)
+ image = Image.open(image_path).convert("RGB")
+
+ inputs = processor(images=image, return_tensors="pt").to(device)
+ generated_ids = model.generate(
+ pixel_values=inputs.pixel_values,
+ max_length=50,
+ pad_token_id=processor.tokenizer.pad_token_id
+ )
+ caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ print("Сгенерированная подпись:", caption)
+ return caption
\ No newline at end of file
diff --git a/lib/mask.py b/lib/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7ec0e6e26885f3d5b0277e3d5bf062011e8035
--- /dev/null
+++ b/lib/mask.py
@@ -0,0 +1,64 @@
+from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+from PIL import Image
+import numpy as np
+import requests
+import torch.nn.functional as F
+import torch
+import os
+
+def generate_clothing_mask(
+ image_path: str,
+ label: int,
+ output_path: str = "./output_mask.png",
+ model_name: str = "mattmdjaga/segformer_b2_clothes",
+ show_result: bool = False
+) -> Image.Image:
+ """
+ Генерирует бинарную маску для указанного класса одежды и сохраняет её
+
+ Args:
+ image_path: Путь к изображению или URL
+ label: Класс для сегментации (0-17)
+ output_path: Путь для сохранения маски
+ model_name: Название модели HuggingFace
+ show_result: Показать результат matplotlib
+
+ Returns:
+ PIL.Image: Бинарная маска (белый - выбранный класс, черный - остальное)
+ """
+
+ processor = SegformerImageProcessor.from_pretrained(model_name)
+ model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
+
+ if image_path.startswith(('http://', 'https://')):
+ image = Image.open(requests.get(image_path, stream=True).raw)
+ else:
+ image = Image.open(image_path)
+
+ if image.mode != 'RGB':
+ image = image.convert('RGB')
+
+ image_np = np.array(image)
+ if len(image_np.shape) != 3 or image_np.shape[2] != 3:
+ raise ValueError("Изображение должно быть в формате RGB (H, W, 3)")
+
+ inputs = processor(images=image, return_tensors="pt")
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ logits = outputs.logits
+ upsampled_logits = F.interpolate(
+ logits,
+ size=image.size[::-1],
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
+ mask = (pred_seg == label).numpy().astype('uint8') * 255
+ mask_image = Image.fromarray(mask)
+
+    out_dir = os.path.dirname(output_path)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+ mask_image.save(output_path)
+
+ return mask_image
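+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: label 4 selects the upper-body clothing class in this model's label map,
+    # matching how app.py builds the inpainting mask for the person image.
+    generate_clothing_mask("./test/person2.png", label=4, output_path="./output_mask.png")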
diff --git a/lib/pose.py b/lib/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..761e919116a0baceb6e0e95bb26598f1dcebf026
--- /dev/null
+++ b/lib/pose.py
@@ -0,0 +1,36 @@
+from controlnet_aux import OpenposeDetector
+from PIL import Image
+import torch
+
+
+def generate_openpose(
+ input_image_path: str,
+ output_image_path: str = None,
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
+ show_result: bool = False
+) -> Image.Image:
+ """
+ Генерирует OpenPose карту позы из входного изображения.
+
+ Параметры:
+ input_image_path (str): Путь к исходному изображению
+ output_image_path (str, optional): Путь для сохранения результата. Если None - не сохраняется.
+ device (str): Устройство для обработки ('cuda' или 'cpu')
+ show_result (bool): Показывать ли результат сразу
+
+ Возвращает:
+ Image.Image: Изображение с OpenPose картой позы
+ """
+ openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet").to(device)
+
+ image = Image.open(input_image_path).convert("RGB")
+
+ openpose_map = openpose(image)
+
+ if output_image_path:
+ openpose_map.save(output_image_path)
+
+ if show_result:
+ openpose_map.show()
+
+    return openpose_map
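+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: write the OpenPose map for one of the bundled test images.
+    generate_openpose("./test/person2.png", output_image_path="./output_pose.png")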
diff --git a/preprocess/__init__.py b/preprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocess/humanparsing/__init__.py b/preprocess/humanparsing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocess/humanparsing/datasets/__init__.py b/preprocess/humanparsing/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocess/humanparsing/datasets/datasets.py b/preprocess/humanparsing/datasets/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..433f15af93029538b3b039f8f207764fcfe426d9
--- /dev/null
+++ b/preprocess/humanparsing/datasets/datasets.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : datasets.py
+@Time : 8/4/19 3:35 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import os
+import numpy as np
+import random
+import torch
+import cv2
+from torch.utils import data
+from utils.transforms import get_affine_transform
+
+
+class LIPDataSet(data.Dataset):
+ def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25,
+ rotation_factor=30, ignore_label=255, transform=None):
+ self.root = root
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
+ self.crop_size = np.asarray(crop_size)
+ self.ignore_label = ignore_label
+ self.scale_factor = scale_factor
+ self.rotation_factor = rotation_factor
+ self.flip_prob = 0.5
+ self.transform = transform
+ self.dataset = dataset
+
+ list_path = os.path.join(self.root, self.dataset + '_id.txt')
+ train_list = [i_id.strip() for i_id in open(list_path)]
+
+ self.train_list = train_list
+ self.number_samples = len(self.train_list)
+
+ def __len__(self):
+ return self.number_samples
+
+ def _box2cs(self, box):
+ x, y, w, h = box[:4]
+ return self._xywh2cs(x, y, w, h)
+
+ def _xywh2cs(self, x, y, w, h):
+ center = np.zeros((2), dtype=np.float32)
+ center[0] = x + w * 0.5
+ center[1] = y + h * 0.5
+ if w > self.aspect_ratio * h:
+ h = w * 1.0 / self.aspect_ratio
+ elif w < self.aspect_ratio * h:
+ w = h * self.aspect_ratio
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
+ return center, scale
+
+ def __getitem__(self, index):
+ train_item = self.train_list[index]
+
+ im_path = os.path.join(self.root, self.dataset + '_images', train_item + '.jpg')
+ parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', train_item + '.png')
+
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
+ h, w, _ = im.shape
+        parsing_anno = np.zeros((h, w), dtype=np.int64)
+
+ # Get person center and scale
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
+ r = 0
+
+ if self.dataset != 'test':
+ # Get pose annotation
+ parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
+ if self.dataset == 'train' or self.dataset == 'trainval':
+ sf = self.scale_factor
+ rf = self.rotation_factor
+ s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
+ r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0
+
+ if random.random() <= self.flip_prob:
+ im = im[:, ::-1, :]
+ parsing_anno = parsing_anno[:, ::-1]
+ person_center[0] = im.shape[1] - person_center[0] - 1
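+                # Swap paired left/right part labels so the parsing stays consistent after the horizontal flip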
+                right_idx = [15, 17, 19]
+                left_idx = [14, 16, 18]
+ for i in range(0, 3):
+ right_pos = np.where(parsing_anno == right_idx[i])
+ left_pos = np.where(parsing_anno == left_idx[i])
+ parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
+ parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]
+
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
+ input = cv2.warpAffine(
+ im,
+ trans,
+ (int(self.crop_size[1]), int(self.crop_size[0])),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(0, 0, 0))
+
+ if self.transform:
+ input = self.transform(input)
+
+ meta = {
+ 'name': train_item,
+ 'center': person_center,
+ 'height': h,
+ 'width': w,
+ 'scale': s,
+ 'rotation': r
+ }
+
+ if self.dataset == 'val' or self.dataset == 'test':
+ return input, meta
+ else:
+ label_parsing = cv2.warpAffine(
+ parsing_anno,
+ trans,
+ (int(self.crop_size[1]), int(self.crop_size[0])),
+ flags=cv2.INTER_NEAREST,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(255))
+
+ label_parsing = torch.from_numpy(label_parsing)
+
+ return input, label_parsing, meta
+
+
+class LIPDataValSet(data.Dataset):
+ def __init__(self, root, dataset='val', crop_size=[473, 473], transform=None, flip=False):
+ self.root = root
+ self.crop_size = crop_size
+ self.transform = transform
+ self.flip = flip
+ self.dataset = dataset
+ self.root = root
+ self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
+ self.crop_size = np.asarray(crop_size)
+
+ list_path = os.path.join(self.root, self.dataset + '_id.txt')
+ val_list = [i_id.strip() for i_id in open(list_path)]
+
+ self.val_list = val_list
+ self.number_samples = len(self.val_list)
+
+ def __len__(self):
+ return len(self.val_list)
+
+ def _box2cs(self, box):
+ x, y, w, h = box[:4]
+ return self._xywh2cs(x, y, w, h)
+
+ def _xywh2cs(self, x, y, w, h):
+ center = np.zeros((2), dtype=np.float32)
+ center[0] = x + w * 0.5
+ center[1] = y + h * 0.5
+ if w > self.aspect_ratio * h:
+ h = w * 1.0 / self.aspect_ratio
+ elif w < self.aspect_ratio * h:
+ w = h * self.aspect_ratio
+ scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
+
+ return center, scale
+
+ def __getitem__(self, index):
+ val_item = self.val_list[index]
+ # Load training image
+ im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg')
+ im = cv2.imread(im_path, cv2.IMREAD_COLOR)
+ h, w, _ = im.shape
+ # Get person center and scale
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
+ r = 0
+ trans = get_affine_transform(person_center, s, r, self.crop_size)
+ input = cv2.warpAffine(
+ im,
+ trans,
+ (int(self.crop_size[1]), int(self.crop_size[0])),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(0, 0, 0))
+ input = self.transform(input)
+ flip_input = input.flip(dims=[-1])
+ if self.flip:
+ batch_input_im = torch.stack([input, flip_input])
+ else:
+ batch_input_im = input
+
+ meta = {
+ 'name': val_item,
+ 'center': person_center,
+ 'height': h,
+ 'width': w,
+ 'scale': s,
+ 'rotation': r
+ }
+
+ return batch_input_im, meta
diff --git a/preprocess/humanparsing/datasets/simple_extractor_dataset.py b/preprocess/humanparsing/datasets/simple_extractor_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e85240701231f9789b822219c8b9eda47be4de
--- /dev/null
+++ b/preprocess/humanparsing/datasets/simple_extractor_dataset.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : dataset.py
+@Time : 8/30/19 9:12 PM
+@Desc : Dataset Definition
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import os
+import pdb
+
+import cv2
+import numpy as np
+from PIL import Image
+from torch.utils import data
+from utils.transforms import get_affine_transform
+
+
+class SimpleFolderDataset(data.Dataset):
+ def __init__(self, root, input_size=[512, 512], transform=None):
+ self.root = root
+ self.input_size = input_size
+ self.transform = transform
+ self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
+ self.input_size = np.asarray(input_size)
+ self.is_pil_image = False
+ if isinstance(root, Image.Image):
+ self.file_list = [root]
+ self.is_pil_image = True
+ elif os.path.isfile(root):
+ self.file_list = [os.path.basename(root)]
+ self.root = os.path.dirname(root)
+ else:
+ self.file_list = os.listdir(self.root)
+
+ def __len__(self):
+ return len(self.file_list)
+
+ def _box2cs(self, box):
+ x, y, w, h = box[:4]
+ return self._xywh2cs(x, y, w, h)
+
+ def _xywh2cs(self, x, y, w, h):
+ center = np.zeros((2), dtype=np.float32)
+ center[0] = x + w * 0.5
+ center[1] = y + h * 0.5
+ if w > self.aspect_ratio * h:
+ h = w * 1.0 / self.aspect_ratio
+ elif w < self.aspect_ratio * h:
+ w = h * self.aspect_ratio
+ scale = np.array([w, h], dtype=np.float32)
+ return center, scale
+
+ def __getitem__(self, index):
+ if self.is_pil_image:
+ img = np.asarray(self.file_list[index])[:, :, [2, 1, 0]]
+ else:
+ img_name = self.file_list[index]
+ img_path = os.path.join(self.root, img_name)
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+ h, w, _ = img.shape
+
+ # Get person center and scale
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
+ r = 0
+ trans = get_affine_transform(person_center, s, r, self.input_size)
+ input = cv2.warpAffine(
+ img,
+ trans,
+ (int(self.input_size[1]), int(self.input_size[0])),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(0, 0, 0))
+
+ input = self.transform(input)
+ meta = {
+ 'center': person_center,
+ 'height': h,
+ 'width': w,
+ 'scale': s,
+ 'rotation': r
+ }
+
+ return input, meta
diff --git a/preprocess/humanparsing/datasets/target_generation.py b/preprocess/humanparsing/datasets/target_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8524db4427755c12ce71a4292d87ebb3e91762c1
--- /dev/null
+++ b/preprocess/humanparsing/datasets/target_generation.py
@@ -0,0 +1,40 @@
+import torch
+from torch.nn import functional as F
+
+
+def generate_edge_tensor(label, edge_width=3):
+ label = label.type(torch.cuda.FloatTensor)
+ if len(label.shape) == 2:
+ label = label.unsqueeze(0)
+ n, h, w = label.shape
+ edge = torch.zeros(label.shape, dtype=torch.float).cuda()
+ # right
+ edge_right = edge[:, 1:h, :]
+ edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
+ & (label[:, :h - 1, :] != 255)] = 1
+
+ # up
+ edge_up = edge[:, :, :w - 1]
+ edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
+ & (label[:, :, :w - 1] != 255)
+ & (label[:, :, 1:w] != 255)] = 1
+
+ # upright
+ edge_upright = edge[:, :h - 1, :w - 1]
+ edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
+ & (label[:, :h - 1, :w - 1] != 255)
+ & (label[:, 1:h, 1:w] != 255)] = 1
+
+ # bottomright
+ edge_bottomright = edge[:, :h - 1, 1:w]
+ edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
+ & (label[:, :h - 1, 1:w] != 255)
+ & (label[:, 1:h, :w - 1] != 255)] = 1
+
+ kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda()
+ with torch.no_grad():
+ edge = edge.unsqueeze(1)
+ edge = F.conv2d(edge, kernel, stride=1, padding=1)
+ edge[edge!=0] = 1
+ edge = edge.squeeze()
+ return edge
diff --git a/preprocess/humanparsing/modules/__init__.py b/preprocess/humanparsing/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a098dee5911f3613d320d23db37bc401cf57fa4
--- /dev/null
+++ b/preprocess/humanparsing/modules/__init__.py
@@ -0,0 +1,5 @@
+from .bn import ABN, InPlaceABN, InPlaceABNSync
+from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
+from .misc import GlobalAvgPool2d, SingleGPU
+from .residual import IdentityResidualBlock
+from .dense import DenseModule
diff --git a/preprocess/humanparsing/modules/bn.py b/preprocess/humanparsing/modules/bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a794698867e89140a030d550d832e6fa12561c8b
--- /dev/null
+++ b/preprocess/humanparsing/modules/bn.py
@@ -0,0 +1,132 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as functional
+
+try:
+ from queue import Queue
+except ImportError:
+ from Queue import Queue
+
+from .functions import *
+
+
+class ABN(nn.Module):
+ """Activated Batch Normalization
+
+ This gathers a `BatchNorm2d` and an activation function in a single module
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
+ """Creates an Activated Batch Normalization module
+
+ Parameters
+ ----------
+ num_features : int
+ Number of feature channels in the input and output.
+ eps : float
+ Small constant to prevent numerical issues.
+ momentum : float
+            Momentum factor applied to compute running statistics.
+ affine : bool
+ If `True` apply learned scale and shift transformation after normalization.
+ activation : str
+ Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
+ slope : float
+ Negative slope for the `leaky_relu` activation.
+ """
+ super(ABN, self).__init__()
+ self.num_features = num_features
+ self.affine = affine
+ self.eps = eps
+ self.momentum = momentum
+ self.activation = activation
+ self.slope = slope
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.constant_(self.running_mean, 0)
+ nn.init.constant_(self.running_var, 1)
+ if self.affine:
+ nn.init.constant_(self.weight, 1)
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, x):
+ x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ if self.activation == ACT_RELU:
+ return functional.relu(x, inplace=True)
+ elif self.activation == ACT_LEAKY_RELU:
+ return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
+ elif self.activation == ACT_ELU:
+ return functional.elu(x, inplace=True)
+ else:
+ return x
+
+ def __repr__(self):
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
+ ' affine={affine}, activation={activation}'
+ if self.activation == "leaky_relu":
+ rep += ', slope={slope})'
+ else:
+ rep += ')'
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
+
+
+class InPlaceABN(ABN):
+ """InPlace Activated Batch Normalization"""
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
+ """Creates an InPlace Activated Batch Normalization module
+
+ Parameters
+ ----------
+ num_features : int
+ Number of feature channels in the input and output.
+ eps : float
+ Small constant to prevent numerical issues.
+ momentum : float
+            Momentum factor applied to compute running statistics.
+ affine : bool
+ If `True` apply learned scale and shift transformation after normalization.
+ activation : str
+ Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
+ slope : float
+ Negative slope for the `leaky_relu` activation.
+ """
+ super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
+
+ def forward(self, x):
+ x, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
+ self.training, self.momentum, self.eps, self.activation, self.slope)
+ return x
+
+
+class InPlaceABNSync(ABN):
+ """InPlace Activated Batch Normalization with cross-GPU synchronization
+ This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
+ """
+
+ def forward(self, x):
+ x, _, _ = inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
+ self.training, self.momentum, self.eps, self.activation, self.slope)
+ return x
+
+ def __repr__(self):
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
+ ' affine={affine}, activation={activation}'
+ if self.activation == "leaky_relu":
+ rep += ', slope={slope})'
+ else:
+ rep += ')'
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
+
+
diff --git a/preprocess/humanparsing/modules/deeplab.py b/preprocess/humanparsing/modules/deeplab.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd25b78369b27ef02c183a0b17b9bf8354c5f7c3
--- /dev/null
+++ b/preprocess/humanparsing/modules/deeplab.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as functional
+
+from models._util import try_index
+from .bn import ABN
+
+
+class DeeplabV3(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ hidden_channels=256,
+ dilations=(12, 24, 36),
+ norm_act=ABN,
+ pooling_size=None):
+ super(DeeplabV3, self).__init__()
+ self.pooling_size = pooling_size
+
+ self.map_convs = nn.ModuleList([
+ nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
+ ])
+ self.map_bn = norm_act(hidden_channels * 4)
+
+ self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
+ self.global_pooling_bn = norm_act(hidden_channels)
+
+ self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
+ self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
+ self.red_bn = norm_act(out_channels)
+
+ self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
+
+ def reset_parameters(self, activation, slope):
+ gain = nn.init.calculate_gain(activation, slope)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight.data, gain)
+ if hasattr(m, "bias") and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, ABN):
+ if hasattr(m, "weight") and m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if hasattr(m, "bias") and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ # Map convolutions
+ out = torch.cat([m(x) for m in self.map_convs], dim=1)
+ out = self.map_bn(out)
+ out = self.red_conv(out)
+
+ # Global pooling
+ pool = self._global_pooling(x)
+ pool = self.global_pooling_conv(pool)
+ pool = self.global_pooling_bn(pool)
+ pool = self.pool_red_conv(pool)
+ if self.training or self.pooling_size is None:
+ pool = pool.repeat(1, 1, x.size(2), x.size(3))
+
+ out += pool
+ out = self.red_bn(out)
+ return out
+
+ def _global_pooling(self, x):
+ if self.training or self.pooling_size is None:
+ pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
+ pool = pool.view(x.size(0), x.size(1), 1, 1)
+ else:
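+            # Fixed-size pooling for inference: clamp the kernel to the input, then replicate-pad so the
+            # stride-1 average pooling keeps the input's spatial size (asymmetric padding when the kernel is even)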
+ pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
+ min(try_index(self.pooling_size, 1), x.shape[3]))
+ padding = (
+ (pooling_size[1] - 1) // 2,
+ (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
+ (pooling_size[0] - 1) // 2,
+ (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
+ )
+
+ pool = functional.avg_pool2d(x, pooling_size, stride=1)
+ pool = functional.pad(pool, pad=padding, mode="replicate")
+ return pool
diff --git a/preprocess/humanparsing/modules/dense.py b/preprocess/humanparsing/modules/dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..9638d6e86d2ae838550fefa9002a984af52e6cc8
--- /dev/null
+++ b/preprocess/humanparsing/modules/dense.py
@@ -0,0 +1,42 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from .bn import ABN
+
+
+class DenseModule(nn.Module):
+ def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
+ super(DenseModule, self).__init__()
+ self.in_channels = in_channels
+ self.growth = growth
+ self.layers = layers
+
+ self.convs1 = nn.ModuleList()
+ self.convs3 = nn.ModuleList()
+ for i in range(self.layers):
+ self.convs1.append(nn.Sequential(OrderedDict([
+ ("bn", norm_act(in_channels)),
+ ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
+ ])))
+ self.convs3.append(nn.Sequential(OrderedDict([
+ ("bn", norm_act(self.growth * bottleneck_factor)),
+ ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
+ dilation=dilation))
+ ])))
+ in_channels += self.growth
+
+ @property
+ def out_channels(self):
+ return self.in_channels + self.growth * self.layers
+
+ def forward(self, x):
+ inputs = [x]
+ for i in range(self.layers):
+ x = torch.cat(inputs, dim=1)
+ x = self.convs1[i](x)
+ x = self.convs3[i](x)
+ inputs += [x]
+
+ return torch.cat(inputs, dim=1)
diff --git a/preprocess/humanparsing/modules/functions.py b/preprocess/humanparsing/modules/functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2837260687dde56d4595b24aded5fddbc4bda8
--- /dev/null
+++ b/preprocess/humanparsing/modules/functions.py
@@ -0,0 +1,245 @@
+import pdb
+from os import path
+import torch
+import torch.distributed as dist
+import torch.autograd as autograd
+import torch.cuda.comm as comm
+from torch.autograd.function import once_differentiable
+from torch.utils.cpp_extension import load
+
+_src_path = path.join(path.dirname(path.abspath(__file__)), "src")
+_backend = load(name="inplace_abn",
+ extra_cflags=["-O3"],
+ sources=[path.join(_src_path, f) for f in [
+ "inplace_abn.cpp",
+ "inplace_abn_cpu.cpp",
+ "inplace_abn_cuda.cu",
+ "inplace_abn_cuda_half.cu"
+ ]],
+ extra_cuda_cflags=["--expt-extended-lambda"])
+
+# Activation names
+ACT_RELU = "relu"
+ACT_LEAKY_RELU = "leaky_relu"
+ACT_ELU = "elu"
+ACT_NONE = "none"
+
+
+def _check(fn, *args, **kwargs):
+ success = fn(*args, **kwargs)
+ if not success:
+ raise RuntimeError("CUDA Error encountered in {}".format(fn))
+
+
+def _broadcast_shape(x):
+ out_size = []
+ for i, s in enumerate(x.size()):
+ if i != 1:
+ out_size.append(1)
+ else:
+ out_size.append(s)
+ return out_size
+
+
+def _reduce(x):
+ if len(x.size()) == 2:
+ return x.sum(dim=0)
+ else:
+ n, c = x.size()[0:2]
+ return x.contiguous().view((n, c, -1)).sum(2).sum(0)
+
+
+def _count_samples(x):
+ count = 1
+ for i, s in enumerate(x.size()):
+ if i != 1:
+ count *= s
+ return count
+
+
+def _act_forward(ctx, x):
+ if ctx.activation == ACT_LEAKY_RELU:
+ _backend.leaky_relu_forward(x, ctx.slope)
+ elif ctx.activation == ACT_ELU:
+ _backend.elu_forward(x)
+ elif ctx.activation == ACT_NONE:
+ pass
+
+
+def _act_backward(ctx, x, dx):
+ if ctx.activation == ACT_LEAKY_RELU:
+ _backend.leaky_relu_backward(x, dx, ctx.slope)
+ elif ctx.activation == ACT_ELU:
+ _backend.elu_backward(x, dx)
+ elif ctx.activation == ACT_NONE:
+ pass
+
+
+class InPlaceABN(autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, bias, running_mean, running_var,
+ training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
+ # Save context
+ ctx.training = training
+ ctx.momentum = momentum
+ ctx.eps = eps
+ ctx.activation = activation
+ ctx.slope = slope
+ ctx.affine = weight is not None and bias is not None
+
+ # Prepare inputs
+ count = _count_samples(x)
+ x = x.contiguous()
+ weight = weight.contiguous() if ctx.affine else x.new_empty(0)
+ bias = bias.contiguous() if ctx.affine else x.new_empty(0)
+
+ if ctx.training:
+ mean, var = _backend.mean_var(x)
+
+ # Update running stats
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
+
+ # Mark in-place modified tensors
+ ctx.mark_dirty(x, running_mean, running_var)
+ else:
+ mean, var = running_mean.contiguous(), running_var.contiguous()
+ ctx.mark_dirty(x)
+
+ # BN forward + activation
+ _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
+ _act_forward(ctx, x)
+
+ # Output
+ ctx.var = var
+ ctx.save_for_backward(x, var, weight, bias)
+ ctx.mark_non_differentiable(running_mean, running_var)
+ return x, running_mean, running_var
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, dz, _drunning_mean, _drunning_var):
+ z, var, weight, bias = ctx.saved_tensors
+ dz = dz.contiguous()
+
+ # Undo activation
+ _act_backward(ctx, z, dz)
+
+ if ctx.training:
+ edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
+ else:
+ # TODO: implement simplified CUDA backward for inference mode
+ edz = dz.new_zeros(dz.size(1))
+ eydz = dz.new_zeros(dz.size(1))
+
+ dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
+ # dweight = eydz * weight.sign() if ctx.affine else None
+ dweight = eydz if ctx.affine else None
+ if dweight is not None:
+ dweight[weight < 0] *= -1
+ dbias = edz if ctx.affine else None
+
+ return dx, dweight, dbias, None, None, None, None, None, None, None
+
+
+class InPlaceABNSync(autograd.Function):
+ @classmethod
+ def forward(cls, ctx, x, weight, bias, running_mean, running_var,
+ training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
+ # Save context
+ ctx.training = training
+ ctx.momentum = momentum
+ ctx.eps = eps
+ ctx.activation = activation
+ ctx.slope = slope
+ ctx.affine = weight is not None and bias is not None
+
+ # Prepare inputs
+ ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+ # count = _count_samples(x)
+ batch_size = x.new_tensor([x.shape[0]], dtype=torch.long)
+
+ x = x.contiguous()
+ weight = weight.contiguous() if ctx.affine else x.new_empty(0)
+ bias = bias.contiguous() if ctx.affine else x.new_empty(0)
+
+ if ctx.training:
+ mean, var = _backend.mean_var(x)
+ if ctx.world_size > 1:
+ # get global batch size
+ if equal_batches:
+ batch_size *= ctx.world_size
+ else:
+ dist.all_reduce(batch_size, dist.ReduceOp.SUM)
+
+ ctx.factor = x.shape[0] / float(batch_size.item())
+
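+                # Weight each process's statistics by its share of the global batch, then all-reduce to global mean/var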
+ mean_all = mean.clone() * ctx.factor
+ dist.all_reduce(mean_all, dist.ReduceOp.SUM)
+
+ var_all = (var + (mean - mean_all) ** 2) * ctx.factor
+ dist.all_reduce(var_all, dist.ReduceOp.SUM)
+
+ mean = mean_all
+ var = var_all
+
+ # Update running stats
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
+ count = batch_size.item() * x.view(x.shape[0], x.shape[1], -1).shape[-1]
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
+
+ # Mark in-place modified tensors
+ ctx.mark_dirty(x, running_mean, running_var)
+ else:
+ mean, var = running_mean.contiguous(), running_var.contiguous()
+ ctx.mark_dirty(x)
+
+ # BN forward + activation
+ _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
+ _act_forward(ctx, x)
+
+ # Output
+ ctx.var = var
+ ctx.save_for_backward(x, var, weight, bias)
+ ctx.mark_non_differentiable(running_mean, running_var)
+ return x, running_mean, running_var
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, dz, _drunning_mean, _drunning_var):
+ z, var, weight, bias = ctx.saved_tensors
+ dz = dz.contiguous()
+
+ # Undo activation
+ _act_backward(ctx, z, dz)
+
+ if ctx.training:
+ edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
+ edz_local = edz.clone()
+ eydz_local = eydz.clone()
+
+ if ctx.world_size > 1:
+ edz *= ctx.factor
+ dist.all_reduce(edz, dist.ReduceOp.SUM)
+
+ eydz *= ctx.factor
+ dist.all_reduce(eydz, dist.ReduceOp.SUM)
+ else:
+ edz_local = edz = dz.new_zeros(dz.size(1))
+ eydz_local = eydz = dz.new_zeros(dz.size(1))
+
+ dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
+ # dweight = eydz_local * weight.sign() if ctx.affine else None
+ dweight = eydz_local if ctx.affine else None
+ if dweight is not None:
+ dweight[weight < 0] *= -1
+ dbias = edz_local if ctx.affine else None
+
+ return dx, dweight, dbias, None, None, None, None, None, None, None
+
+
+inplace_abn = InPlaceABN.apply
+inplace_abn_sync = InPlaceABNSync.apply
+
+__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
diff --git a/preprocess/humanparsing/modules/misc.py b/preprocess/humanparsing/modules/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c50b69b38c950801baacba8b3684ffd23aef08b
--- /dev/null
+++ b/preprocess/humanparsing/modules/misc.py
@@ -0,0 +1,21 @@
+import torch.nn as nn
+import torch
+import torch.distributed as dist
+
+class GlobalAvgPool2d(nn.Module):
+ def __init__(self):
+ """Global average pooling over the input's spatial dimensions"""
+ super(GlobalAvgPool2d, self).__init__()
+
+ def forward(self, inputs):
+ in_size = inputs.size()
+ return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
+
+class SingleGPU(nn.Module):
+ def __init__(self, module):
+ super(SingleGPU, self).__init__()
+ self.module=module
+
+ def forward(self, input):
+ return self.module(input.cuda(non_blocking=True))
+
diff --git a/preprocess/humanparsing/modules/residual.py b/preprocess/humanparsing/modules/residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a5c90e0606a451ff690f67a2feac28476241d86
--- /dev/null
+++ b/preprocess/humanparsing/modules/residual.py
@@ -0,0 +1,182 @@
+from collections import OrderedDict
+
+import torch.nn as nn
+
+from .bn import ABN, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
+import torch.nn.functional as functional
+
+
+class ResidualBlock(nn.Module):
+ """Configurable residual block
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ channels : list of int
+ Number of channels in the internal feature maps. Can either have two or three elements: if three construct
+ a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
+ `3 x 3` then `1 x 1` convolutions.
+ stride : int
+ Stride of the first `3 x 3` convolution
+ dilation : int
+ Dilation to apply to the `3 x 3` convolutions.
+ groups : int
+ Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
+ bottleneck blocks.
+ norm_act : callable
+ Function to create normalization / activation Module.
+ dropout: callable
+ Function to create Dropout Module.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ stride=1,
+ dilation=1,
+ groups=1,
+ norm_act=ABN,
+ dropout=None):
+ super(ResidualBlock, self).__init__()
+
+ # Check parameters for inconsistencies
+ if len(channels) != 2 and len(channels) != 3:
+ raise ValueError("channels must contain either two or three values")
+ if len(channels) == 2 and groups != 1:
+ raise ValueError("groups > 1 are only valid if len(channels) == 3")
+
+ is_bottleneck = len(channels) == 3
+ need_proj_conv = stride != 1 or in_channels != channels[-1]
+
+ if not is_bottleneck:
+ bn2 = norm_act(channels[1])
+ bn2.activation = ACT_NONE
+ layers = [
+ ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
+ dilation=dilation)),
+ ("bn1", norm_act(channels[0])),
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
+ dilation=dilation)),
+ ("bn2", bn2)
+ ]
+ if dropout is not None:
+ layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
+ else:
+ bn3 = norm_act(channels[2])
+ bn3.activation = ACT_NONE
+ layers = [
+ ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)),
+ ("bn1", norm_act(channels[0])),
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=stride, padding=dilation, bias=False,
+ groups=groups, dilation=dilation)),
+ ("bn2", norm_act(channels[1])),
+ ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)),
+ ("bn3", bn3)
+ ]
+ if dropout is not None:
+ layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
+ self.convs = nn.Sequential(OrderedDict(layers))
+
+ if need_proj_conv:
+ self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
+ self.proj_bn = norm_act(channels[-1])
+ self.proj_bn.activation = ACT_NONE
+
+ def forward(self, x):
+ if hasattr(self, "proj_conv"):
+ residual = self.proj_conv(x)
+ residual = self.proj_bn(residual)
+ else:
+ residual = x
+ x = self.convs(x) + residual
+
+ if self.convs.bn1.activation == ACT_LEAKY_RELU:
+ return functional.leaky_relu(x, negative_slope=self.convs.bn1.slope, inplace=True)
+ elif self.convs.bn1.activation == ACT_ELU:
+ return functional.elu(x, inplace=True)
+ else:
+ return x
+
+
+class IdentityResidualBlock(nn.Module):
+ def __init__(self,
+ in_channels,
+ channels,
+ stride=1,
+ dilation=1,
+ groups=1,
+ norm_act=ABN,
+ dropout=None):
+ """Configurable identity-mapping residual block
+
+ Parameters
+ ----------
+ in_channels : int
+ Number of input channels.
+ channels : list of int
+ Number of channels in the internal feature maps. Can either have two or three elements: if three construct
+ a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
+ `3 x 3` then `1 x 1` convolutions.
+ stride : int
+ Stride of the first `3 x 3` convolution
+ dilation : int
+ Dilation to apply to the `3 x 3` convolutions.
+ groups : int
+ Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
+ bottleneck blocks.
+ norm_act : callable
+ Function to create normalization / activation Module.
+ dropout: callable
+ Function to create Dropout Module.
+ """
+ super(IdentityResidualBlock, self).__init__()
+
+ # Check parameters for inconsistencies
+ if len(channels) != 2 and len(channels) != 3:
+ raise ValueError("channels must contain either two or three values")
+ if len(channels) == 2 and groups != 1:
+ raise ValueError("groups > 1 are only valid if len(channels) == 3")
+
+ is_bottleneck = len(channels) == 3
+ need_proj_conv = stride != 1 or in_channels != channels[-1]
+
+ self.bn1 = norm_act(in_channels)
+ if not is_bottleneck:
+ layers = [
+ ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
+ dilation=dilation)),
+ ("bn2", norm_act(channels[0])),
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
+ dilation=dilation))
+ ]
+ if dropout is not None:
+ layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
+ else:
+ layers = [
+ ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
+ ("bn2", norm_act(channels[0])),
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
+ groups=groups, dilation=dilation)),
+ ("bn3", norm_act(channels[1])),
+ ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
+ ]
+ if dropout is not None:
+ layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
+ self.convs = nn.Sequential(OrderedDict(layers))
+
+ if need_proj_conv:
+ self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
+
+ def forward(self, x):
+ if hasattr(self, "proj_conv"):
+ bn1 = self.bn1(x)
+ shortcut = self.proj_conv(bn1)
+ else:
+ shortcut = x.clone()
+ bn1 = self.bn1(x)
+
+ out = self.convs(bn1)
+ out.add_(shortcut)
+
+ return out
diff --git a/preprocess/humanparsing/modules/src/checks.h b/preprocess/humanparsing/modules/src/checks.h
new file mode 100644
index 0000000000000000000000000000000000000000..e761a6fe34d0789815b588eba7e3726026e0e868
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/checks.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
+#ifndef AT_CHECK
+#define AT_CHECK AT_ASSERT
+#endif
+
+#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
+
+#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
\ No newline at end of file
diff --git a/preprocess/humanparsing/modules/src/inplace_abn.cpp b/preprocess/humanparsing/modules/src/inplace_abn.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0a6b1128cc20cbfc476134154e23e5869a92b856
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/inplace_abn.cpp
@@ -0,0 +1,95 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+#include "inplace_abn.h"
+
+std::vector<at::Tensor> mean_var(at::Tensor x) {
+ if (x.is_cuda()) {
+ if (x.type().scalarType() == at::ScalarType::Half) {
+ return mean_var_cuda_h(x);
+ } else {
+ return mean_var_cuda(x);
+ }
+ } else {
+ return mean_var_cpu(x);
+ }
+}
+
+at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ if (x.is_cuda()) {
+ if (x.type().scalarType() == at::ScalarType::Half) {
+ return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
+ } else {
+ return forward_cuda(x, mean, var, weight, bias, affine, eps);
+ }
+ } else {
+ return forward_cpu(x, mean, var, weight, bias, affine, eps);
+ }
+}
+
+std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+                                 bool affine, float eps) {
+ if (z.is_cuda()) {
+ if (z.type().scalarType() == at::ScalarType::Half) {
+ return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
+ } else {
+ return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
+ }
+ } else {
+ return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
+ }
+}
+
+at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
+ if (z.is_cuda()) {
+ if (z.type().scalarType() == at::ScalarType::Half) {
+ return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
+ } else {
+ return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
+ }
+ } else {
+ return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
+ }
+}
+
+void leaky_relu_forward(at::Tensor z, float slope) {
+ at::leaky_relu_(z, slope);
+}
+
+void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
+ if (z.is_cuda()) {
+ if (z.type().scalarType() == at::ScalarType::Half) {
+ return leaky_relu_backward_cuda_h(z, dz, slope);
+ } else {
+ return leaky_relu_backward_cuda(z, dz, slope);
+ }
+ } else {
+ return leaky_relu_backward_cpu(z, dz, slope);
+ }
+}
+
+void elu_forward(at::Tensor z) {
+ at::elu_(z);
+}
+
+void elu_backward(at::Tensor z, at::Tensor dz) {
+ if (z.is_cuda()) {
+ return elu_backward_cuda(z, dz);
+ } else {
+ return elu_backward_cpu(z, dz);
+ }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("mean_var", &mean_var, "Mean and variance computation");
+ m.def("forward", &forward, "In-place forward computation");
+ m.def("edz_eydz", &edz_eydz, "First part of backward computation");
+ m.def("backward", &backward, "Second part of backward computation");
+ m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
+ m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
+ m.def("elu_forward", &elu_forward, "Elu forward computation");
+ m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
+}
diff --git a/preprocess/humanparsing/modules/src/inplace_abn.h b/preprocess/humanparsing/modules/src/inplace_abn.h
new file mode 100644
index 0000000000000000000000000000000000000000..17afd1196449ecb6376f28961e54b55e1537492f
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/inplace_abn.h
@@ -0,0 +1,88 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+#include <vector>
+
+std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
+std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
+std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
+
+at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps);
+at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps);
+at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps);
+
+std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+                                     bool affine, float eps);
+std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+                                      bool affine, float eps);
+std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+                                        bool affine, float eps);
+
+at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
+at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
+at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
+
+void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
+void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
+void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
+
+void elu_backward_cpu(at::Tensor z, at::Tensor dz);
+void elu_backward_cuda(at::Tensor z, at::Tensor dz);
+
+static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
+ num = x.size(0);
+ chn = x.size(1);
+ sp = 1;
+ for (int64_t i = 2; i < x.ndimension(); ++i)
+ sp *= x.size(i);
+}
+
+/*
+ * Specialized CUDA reduction functions for BN
+ */
+#ifdef __CUDACC__
+
+#include "utils/cuda.cuh"
+
+template<typename T, typename Op>
+__device__ T reduce(Op op, int plane, int N, int S) {
+ T sum = (T)0;
+ for (int batch = 0; batch < N; ++batch) {
+ for (int x = threadIdx.x; x < S; x += blockDim.x) {
+ sum += op(batch, plane, x);
+ }
+ }
+
+ // sum over NumThreads within a warp
+ sum = warpSum(sum);
+
+ // 'transpose', and reduce within warp again
+ __shared__ T shared[32];
+ __syncthreads();
+ if (threadIdx.x % WARP_SIZE == 0) {
+ shared[threadIdx.x / WARP_SIZE] = sum;
+ }
+ if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
+ // zero out the other entries in shared
+ shared[threadIdx.x] = (T)0;
+ }
+ __syncthreads();
+ if (threadIdx.x / WARP_SIZE == 0) {
+ sum = warpSum(shared[threadIdx.x]);
+ if (threadIdx.x == 0) {
+ shared[0] = sum;
+ }
+ }
+ __syncthreads();
+
+ // Everyone picks it up, should be broadcast into the whole gradInput
+ return shared[0];
+}
+#endif
diff --git a/preprocess/humanparsing/modules/src/inplace_abn_cpu.cpp b/preprocess/humanparsing/modules/src/inplace_abn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ffc6d38c52ea31661b8dd438dc3fe1958f50b61e
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/inplace_abn_cpu.cpp
@@ -0,0 +1,119 @@
+#include <ATen/ATen.h>
+
+#include <vector>
+
+#include "utils/checks.h"
+#include "inplace_abn.h"
+
+at::Tensor reduce_sum(at::Tensor x) {
+ if (x.ndimension() == 2) {
+ return x.sum(0);
+ } else {
+ auto x_view = x.view({x.size(0), x.size(1), -1});
+ return x_view.sum(-1).sum(0);
+ }
+}
+
+at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
+ if (x.ndimension() == 2) {
+ return v;
+ } else {
+    std::vector<int64_t> broadcast_size = {1, -1};
+ for (int64_t i = 2; i < x.ndimension(); ++i)
+ broadcast_size.push_back(1);
+
+ return v.view(broadcast_size);
+ }
+}
+
+int64_t count(at::Tensor x) {
+ int64_t count = x.size(0);
+ for (int64_t i = 2; i < x.ndimension(); ++i)
+ count *= x.size(i);
+
+ return count;
+}
+
+at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
+ if (affine) {
+ return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
+ } else {
+ return z;
+ }
+}
+
+std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
+ auto num = count(x);
+ auto mean = reduce_sum(x) / num;
+ auto diff = x - broadcast_to(mean, x);
+ auto var = reduce_sum(diff.pow(2)) / num;
+
+ return {mean, var};
+}
+
+at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
+ auto mul = at::rsqrt(var + eps) * gamma;
+
+ x.sub_(broadcast_to(mean, x));
+ x.mul_(broadcast_to(mul, x));
+ if (affine) x.add_(broadcast_to(bias, x));
+
+ return x;
+}
+
+std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ auto edz = reduce_sum(dz);
+ auto y = invert_affine(z, weight, bias, affine, eps);
+ auto eydz = reduce_sum(y * dz);
+
+ return {edz, eydz};
+}
+
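+// Standard batch-norm gradient written with the per-channel sums computed above:
+// dx = (dz - edz/count - y * eydz/count) * gamma / sqrt(var + eps).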
+at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
+ auto y = invert_affine(z, weight, bias, affine, eps);
+ auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
+
+ auto num = count(z);
+ auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
+ return dx;
+}
+
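+// The activation was applied in place, so for negative outputs z = slope * x:
+// divide z by slope to restore the input and scale the incoming gradient by slope.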
+void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
+ CHECK_CPU_INPUT(z);
+ CHECK_CPU_INPUT(dz);
+
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
+ int64_t count = z.numel();
+ auto *_z = z.data<scalar_t>();
+ auto *_dz = dz.data<scalar_t>();
+
+ for (int64_t i = 0; i < count; ++i) {
+ if (_z[i] < 0) {
+ _z[i] *= 1 / slope;
+ _dz[i] *= slope;
+ }
+ }
+ }));
+}
+
+void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
+ CHECK_CPU_INPUT(z);
+ CHECK_CPU_INPUT(dz);
+
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
+ int64_t count = z.numel();
+ auto *_z = z.data<scalar_t>();
+ auto *_dz = dz.data<scalar_t>();
+
+ for (int64_t i = 0; i < count; ++i) {
+ if (_z[i] < 0) {
+ _dz[i] *= (_z[i] + 1.f); // scale the gradient by (ELU output + 1) before z is overwritten, matching the CUDA path
+ _z[i] = log1p(_z[i]);
+ }
+ }
+ }));
+}
diff --git a/preprocess/humanparsing/modules/src/inplace_abn_cuda.cu b/preprocess/humanparsing/modules/src/inplace_abn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b157b06d47173d1645c6a40c89f564b737e84d43
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/inplace_abn_cuda.cu
@@ -0,0 +1,333 @@
+#include <ATen/ATen.h>
+
+#include <thrust/device_ptr.h>
+#include <thrust/transform.h>
+
+#include <vector>
+
+#include "utils/checks.h"
+#include "utils/cuda.cuh"
+#include "inplace_abn.h"
+
+#include <ATen/cuda/CUDAContext.h>
+
+// Operations for reduce
+template<typename T>
+struct SumOp {
+ __device__ SumOp(const T *t, int c, int s)
+ : tensor(t), chn(c), sp(s) {}
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
+ return tensor[(batch * chn + plane) * sp + n];
+ }
+ const T *tensor;
+ const int chn;
+ const int sp;
+};
+
+template<typename T>
+struct VarOp {
+ __device__ VarOp(T m, const T *t, int c, int s)
+ : mean(m), tensor(t), chn(c), sp(s) {}
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
+ T val = tensor[(batch * chn + plane) * sp + n];
+ return (val - mean) * (val - mean);
+ }
+ const T mean;
+ const T *tensor;
+ const int chn;
+ const int sp;
+};
+
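+// Accumulates the per-channel sums of dz and y*dz in a single pass, returned as a Pair.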
+template<typename T>
+struct GradOp {
+ __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
+ : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
+ __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
+ T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
+ T _dz = dz[(batch * chn + plane) * sp + n];
+ return Pair<T>(_dz, _y * _dz);
+ }
+ const T weight;
+ const T bias;
+ const T *z;
+ const T *dz;
+ const int chn;
+ const int sp;
+};
+
+/***********
+ * mean_var
+ ***********/
+
+template<typename T>
+__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+ T norm = T(1) / T(num * sp);
+
+ T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
+ __syncthreads();
+ T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
+
+ if (threadIdx.x == 0) {
+ mean[plane] = _mean;
+ var[plane] = _var;
+ }
+}
+
+std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
+ CHECK_CUDA_INPUT(x);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ // Prepare output tensors
+ auto mean = at::empty({chn}, x.options());
+ auto var = at::empty({chn}, x.options());
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
+ mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ x.data<scalar_t>(),
+ mean.data<scalar_t>(),
+ var.data<scalar_t>(),
+ num, chn, sp);
+ }));
+
+ return {mean, var};
+}
+
+/**********
+ * forward
+ **********/
+
+template<typename T>
+__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
+ bool affine, float eps, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+
+ T _mean = mean[plane];
+ T _var = var[plane];
+ T _weight = affine ? abs(weight[plane]) + eps : T(1);
+ T _bias = affine ? bias[plane] : T(0);
+
+ T mul = rsqrt(_var + eps) * _weight;
+
+ for (int batch = 0; batch < num; ++batch) {
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+ T _x = x[(batch * chn + plane) * sp + n];
+ T _y = (_x - _mean) * mul + _bias;
+
+ x[(batch * chn + plane) * sp + n] = _y;
+ }
+ }
+}
+
+at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ CHECK_CUDA_INPUT(x);
+ CHECK_CUDA_INPUT(mean);
+ CHECK_CUDA_INPUT(var);
+ CHECK_CUDA_INPUT(weight);
+ CHECK_CUDA_INPUT(bias);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
+ forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ x.data<scalar_t>(),
+ mean.data<scalar_t>(),
+ var.data<scalar_t>(),
+ weight.data<scalar_t>(),
+ bias.data<scalar_t>(),
+ affine, eps, num, chn, sp);
+ }));
+
+ return x;
+}
+
+/***********
+ * edz_eydz
+ ***********/
+
+template<typename T>
+__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
+ T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+
+ T _weight = affine ? abs(weight[plane]) + eps : 1.f;
+ T _bias = affine ? bias[plane] : 0.f;
+
+ Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
+ __syncthreads();
+
+ if (threadIdx.x == 0) {
+ edz[plane] = res.v1;
+ eydz[plane] = res.v2;
+ }
+}
+
+std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+ CHECK_CUDA_INPUT(weight);
+ CHECK_CUDA_INPUT(bias);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(z, num, chn, sp);
+
+ auto edz = at::empty({chn}, z.options());
+ auto eydz = at::empty({chn}, z.options());
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
+ edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ z.data<scalar_t>(),
+ dz.data<scalar_t>(),
+ weight.data<scalar_t>(),
+ bias.data<scalar_t>(),
+ edz.data<scalar_t>(),
+ eydz.data<scalar_t>(),
+ affine, eps, num, chn, sp);
+ }));
+
+ return {edz, eydz};
+}
+
+/***********
+ * backward
+ ***********/
+
+template<typename T>
+__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
+ const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+
+ T _weight = affine ? abs(weight[plane]) + eps : 1.f;
+ T _bias = affine ? bias[plane] : 0.f;
+ T _var = var[plane];
+ T _edz = edz[plane];
+ T _eydz = eydz[plane];
+
+ T _mul = _weight * rsqrt(_var + eps);
+ T count = T(num * sp);
+
+ for (int batch = 0; batch < num; ++batch) {
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+ T _dz = dz[(batch * chn + plane) * sp + n];
+ T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
+
+ dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
+ }
+ }
+}
+
+at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+ CHECK_CUDA_INPUT(var);
+ CHECK_CUDA_INPUT(weight);
+ CHECK_CUDA_INPUT(bias);
+ CHECK_CUDA_INPUT(edz);
+ CHECK_CUDA_INPUT(eydz);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(z, num, chn, sp);
+
+ auto dx = at::zeros_like(z);
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
+ backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ z.data<scalar_t>(),
+ dz.data<scalar_t>(),
+ var.data<scalar_t>(),
+ weight.data<scalar_t>(),
+ bias.data<scalar_t>(),
+ edz.data<scalar_t>(),
+ eydz.data<scalar_t>(),
+ dx.data<scalar_t>(),
+ affine, eps, num, chn, sp);
+ }));
+
+ return dx;
+}
+
+/**************
+ * activations
+ **************/
+
+template<typename T>
+inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
+ // Create thrust pointers
+ thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
+ thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
+
+ auto stream = at::cuda::getCurrentCUDAStream();
+ thrust::transform_if(thrust::cuda::par.on(stream),
+ th_dz, th_dz + count, th_z, th_dz,
+ [slope] __device__ (const T& dz) { return dz * slope; },
+ [] __device__ (const T& z) { return z < 0; });
+ thrust::transform_if(thrust::cuda::par.on(stream),
+ th_z, th_z + count, th_z,
+ [slope] __device__ (const T& z) { return z / slope; },
+ [] __device__ (const T& z) { return z < 0; });
+}
+
+void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+
+ int64_t count = z.numel();
+
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
+ leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
+ }));
+}
+
+template<typename T>
+inline void elu_backward_impl(T *z, T *dz, int64_t count) {
+ // Create thrust pointers
+ thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
+ thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
+
+ auto stream = at::cuda::getCurrentCUDAStream();
+ thrust::transform_if(thrust::cuda::par.on(stream),
+ th_dz, th_dz + count, th_z, th_z, th_dz,
+ [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
+ [] __device__ (const T& z) { return z < 0; });
+ thrust::transform_if(thrust::cuda::par.on(stream),
+ th_z, th_z + count, th_z,
+ [] __device__ (const T& z) { return log1p(z); },
+ [] __device__ (const T& z) { return z < 0; });
+}
+
+void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+
+ int64_t count = z.numel();
+
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
+ elu_backward_impl(z.data(), dz.data(), count);
+ }));
+}
diff --git a/preprocess/humanparsing/modules/src/inplace_abn_cuda_half.cu b/preprocess/humanparsing/modules/src/inplace_abn_cuda_half.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bb63e73f9d90179e5bd5dae5579c4844da9c25e2
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/inplace_abn_cuda_half.cu
@@ -0,0 +1,275 @@
+#include <ATen/ATen.h>
+
+#include <cuda_fp16.h>
+
+#include <vector>
+
+#include "utils/checks.h"
+#include "utils/cuda.cuh"
+#include "inplace_abn.h"
+
+#include <ATen/cuda/CUDAContext.h>
+
+// Operations for reduce
+struct SumOpH {
+ __device__ SumOpH(const half *t, int c, int s)
+ : tensor(t), chn(c), sp(s) {}
+ __device__ __forceinline__ float operator()(int batch, int plane, int n) {
+ return __half2float(tensor[(batch * chn + plane) * sp + n]);
+ }
+ const half *tensor;
+ const int chn;
+ const int sp;
+};
+
+struct VarOpH {
+ __device__ VarOpH(float m, const half *t, int c, int s)
+ : mean(m), tensor(t), chn(c), sp(s) {}
+ __device__ __forceinline__ float operator()(int batch, int plane, int n) {
+ const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
+ return (t - mean) * (t - mean);
+ }
+ const float mean;
+ const half *tensor;
+ const int chn;
+ const int sp;
+};
+
+struct GradOpH {
+ __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
+ : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
+ __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
+ float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
+ float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
+ return Pair<float>(_dz, _y * _dz);
+ }
+ const float weight;
+ const float bias;
+ const half *z;
+ const half *dz;
+ const int chn;
+ const int sp;
+};
+
+/***********
+ * mean_var
+ ***********/
+
+__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+ float norm = 1.f / static_cast<float>(num * sp);
+
+ float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
+ __syncthreads();
+ float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
+
+ if (threadIdx.x == 0) {
+ mean[plane] = _mean;
+ var[plane] = _var;
+ }
+}
+
+std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
+ CHECK_CUDA_INPUT(x);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ // Prepare output tensors
+ auto mean = at::empty({chn},x.options().dtype(at::kFloat));
+ auto var = at::empty({chn},x.options().dtype(at::kFloat));
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<half*>(x.data<at::Half>()),
+ mean.data<float>(),
+ var.data<float>(),
+ num, chn, sp);
+
+ return {mean, var};
+}
+
+/**********
+ * forward
+ **********/
+
+__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
+ bool affine, float eps, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+
+ const float _mean = mean[plane];
+ const float _var = var[plane];
+ const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
+ const float _bias = affine ? bias[plane] : 0.f;
+
+ const float mul = rsqrt(_var + eps) * _weight;
+
+ for (int batch = 0; batch < num; ++batch) {
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+ half *x_ptr = x + (batch * chn + plane) * sp + n;
+ float _x = __half2float(*x_ptr);
+ float _y = (_x - _mean) * mul + _bias;
+
+ *x_ptr = __float2half(_y);
+ }
+ }
+}
+
+at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ CHECK_CUDA_INPUT(x);
+ CHECK_CUDA_INPUT(mean);
+ CHECK_CUDA_INPUT(var);
+ CHECK_CUDA_INPUT(weight);
+ CHECK_CUDA_INPUT(bias);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(x, num, chn, sp);
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ forward_kernel_h<<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<half*>(x.data<at::Half>()),
+ mean.data<float>(),
+ var.data<float>(),
+ weight.data<float>(),
+ bias.data<float>(),
+ affine, eps, num, chn, sp);
+
+ return x;
+}
+
+__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
+ float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+
+ float _weight = affine ? abs(weight[plane]) + eps : 1.f;
+ float _bias = affine ? bias[plane] : 0.f;
+
+ Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
+ __syncthreads();
+
+ if (threadIdx.x == 0) {
+ edz[plane] = res.v1;
+ eydz[plane] = res.v2;
+ }
+}
+
+std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
+ bool affine, float eps) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+ CHECK_CUDA_INPUT(weight);
+ CHECK_CUDA_INPUT(bias);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(z, num, chn, sp);
+
+ auto edz = at::empty({chn},z.options().dtype(at::kFloat));
+ auto eydz = at::empty({chn},z.options().dtype(at::kFloat));
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<half*>(z.data<at::Half>()),
+ reinterpret_cast<half*>(dz.data<at::Half>()),
+ weight.data<float>(),
+ bias.data<float>(),
+ edz.data<float>(),
+ eydz.data<float>(),
+ affine, eps, num, chn, sp);
+
+ return {edz, eydz};
+}
+
+__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
+ const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
+ int plane = blockIdx.x;
+
+ float _weight = affine ? abs(weight[plane]) + eps : 1.f;
+ float _bias = affine ? bias[plane] : 0.f;
+ float _var = var[plane];
+ float _edz = edz[plane];
+ float _eydz = eydz[plane];
+
+ float _mul = _weight * rsqrt(_var + eps);
+ float count = float(num * sp);
+
+ for (int batch = 0; batch < num; ++batch) {
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
+ float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
+ float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
+
+ dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
+ }
+ }
+}
+
+at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+ CHECK_CUDA_INPUT(var);
+ CHECK_CUDA_INPUT(weight);
+ CHECK_CUDA_INPUT(bias);
+ CHECK_CUDA_INPUT(edz);
+ CHECK_CUDA_INPUT(eydz);
+
+ // Extract dimensions
+ int64_t num, chn, sp;
+ get_dims(z, num, chn, sp);
+
+ auto dx = at::zeros_like(z);
+
+ // Run kernel
+ dim3 blocks(chn);
+ dim3 threads(getNumThreads(sp));
+ auto stream = at::cuda::getCurrentCUDAStream();
+ backward_kernel_h<<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<half*>(z.data<at::Half>()),
+ reinterpret_cast<half*>(dz.data<at::Half>()),
+ var.data<float>(),
+ weight.data<float>(),
+ bias.data<float>(),
+ edz.data<float>(),
+ eydz.data<float>(),
+ reinterpret_cast<half*>(dx.data<at::Half>()),
+ affine, eps, num, chn, sp);
+
+ return dx;
+}
+
+__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
+ float _z = __half2float(z[i]);
+ if (_z < 0) {
+ dz[i] = __float2half(__half2float(dz[i]) * slope);
+ z[i] = __float2half(_z / slope);
+ }
+ }
+}
+
+void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
+ CHECK_CUDA_INPUT(z);
+ CHECK_CUDA_INPUT(dz);
+
+ int64_t count = z.numel();
+ dim3 threads(getNumThreads(count));
+ dim3 blocks = (count + threads.x - 1) / threads.x;
+ auto stream = at::cuda::getCurrentCUDAStream();
+ leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<half*>(z.data<at::Half>()),
+ reinterpret_cast<half*>(dz.data<at::Half>()),
+ slope, count);
+}
+
diff --git a/preprocess/humanparsing/modules/src/utils/checks.h b/preprocess/humanparsing/modules/src/utils/checks.h
new file mode 100644
index 0000000000000000000000000000000000000000..e761a6fe34d0789815b588eba7e3726026e0e868
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/utils/checks.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
+#ifndef AT_CHECK
+#define AT_CHECK AT_ASSERT
+#endif
+
+#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
+#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
+
+#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
\ No newline at end of file
diff --git a/preprocess/humanparsing/modules/src/utils/common.h b/preprocess/humanparsing/modules/src/utils/common.h
new file mode 100644
index 0000000000000000000000000000000000000000..e8403eef8a233b75dd4bb353c16486fe1be2039a
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/utils/common.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+/*
+ * Functions to share code between CPU and GPU
+ */
+
+#ifdef __CUDACC__
+// CUDA versions
+
+#define HOST_DEVICE __host__ __device__
+#define INLINE_HOST_DEVICE __host__ __device__ inline
+#define FLOOR(x) floor(x)
+
+#if __CUDA_ARCH__ >= 600
+// Recent compute capabilities have block-level atomicAdd for all data types, so we use that
+#define ACCUM(x,y) atomicAdd_block(&(x),(y))
+#else
+// Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
+// and use the known atomicCAS-based implementation for double
+template<typename data_t>
+__device__ inline data_t atomic_add(data_t *address, data_t val) {
+ return atomicAdd(address, val);
+}
+
+template<>
+__device__ inline double atomic_add(double *address, double val) {
+ unsigned long long int* address_as_ull = (unsigned long long int*)address;
+ unsigned long long int old = *address_as_ull, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
+#define ACCUM(x,y) atomic_add(&(x),(y))
+#endif // #if __CUDA_ARCH__ >= 600
+
+#else
+// CPU versions
+
+#define HOST_DEVICE
+#define INLINE_HOST_DEVICE inline
+#define FLOOR(x) std::floor(x)
+#define ACCUM(x,y) (x) += (y)
+
+#endif // #ifdef __CUDACC__
\ No newline at end of file
diff --git a/preprocess/humanparsing/modules/src/utils/cuda.cuh b/preprocess/humanparsing/modules/src/utils/cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..60c0023835e02c5f7c539c28ac07b75b72df394b
--- /dev/null
+++ b/preprocess/humanparsing/modules/src/utils/cuda.cuh
@@ -0,0 +1,71 @@
+#pragma once
+
+/*
+ * General settings and functions
+ */
+const int WARP_SIZE = 32;
+const int MAX_BLOCK_SIZE = 1024;
+
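+// Round the per-plane spatial size up to the next power-of-two block size, capped at MAX_BLOCK_SIZE.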
+static int getNumThreads(int nElem) {
+ int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
+ for (int i = 0; i < 6; ++i) {
+ if (nElem <= threadSizes[i]) {
+ return threadSizes[i];
+ }
+ }
+ return MAX_BLOCK_SIZE;
+}
+
+/*
+ * Reduction utilities
+ */
+template<typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
+ unsigned int mask = 0xffffffff) {
+#if CUDART_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+ return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
+
+template<typename T>
+struct Pair {
+ T v1, v2;
+ __device__ Pair() {}
+ __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
+ __device__ Pair(T v) : v1(v), v2(v) {}
+ __device__ Pair(int v) : v1(v), v2(v) {}
+ __device__ Pair &operator+=(const Pair &a) {
+ v1 += a.v1;
+ v2 += a.v2;
+ return *this;
+ }
+};
+
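+// Butterfly reduction: XOR shuffles over lane masks 1, 2, 4, ... leave every lane of the warp
+// holding the full warp sum after log2(WARP_SIZE) steps.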
+template<typename T>
+static __device__ __forceinline__ T warpSum(T val) {
+#if __CUDA_ARCH__ >= 300
+ for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+ val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
+ }
+#else
+ __shared__ T values[MAX_BLOCK_SIZE];
+ values[threadIdx.x] = val;
+ __threadfence_block();
+ const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
+ for (int i = 1; i < WARP_SIZE; i++) {
+ val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
+ }
+#endif
+ return val;
+}
+
+template<typename T>
+static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
+ value.v1 = warpSum(value.v1);
+ value.v2 = warpSum(value.v2);
+ return value;
+}
\ No newline at end of file
diff --git a/preprocess/humanparsing/networks/AugmentCE2P.py b/preprocess/humanparsing/networks/AugmentCE2P.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce32f78dd0b92d943e5b1d573a33e2f69f247f23
--- /dev/null
+++ b/preprocess/humanparsing/networks/AugmentCE2P.py
@@ -0,0 +1,388 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : AugmentCE2P.py
+@Time : 8/4/19 3:35 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import functools
+import pdb
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+# Note here we adopt the InplaceABNSync implementation from https://github.com/mapillary/inplace_abn
+# By default, the InplaceABNSync module contains a BatchNorm Layer and a LeakyReLu layer
+from modules import InPlaceABNSync
+import numpy as np
+
+BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
+
+affine_par = True
+
+pretrained_settings = {
+ 'resnet101': {
+ 'imagenet': {
+ 'input_space': 'BGR',
+ 'input_size': [3, 224, 224],
+ 'input_range': [0, 1],
+ 'mean': [0.406, 0.456, 0.485],
+ 'std': [0.225, 0.224, 0.229],
+ 'num_classes': 1000
+ }
+ },
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=dilation * multi_grid, dilation=dilation * multi_grid, bias=False)
+ self.bn2 = BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=False)
+ self.relu_inplace = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.dilation = dilation
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out = out + residual
+ out = self.relu_inplace(out)
+
+ return out
+
+
+class CostomAdaptiveAvgPool2D(nn.Module):
+
+ def __init__(self, output_size):
+
+ super(CostomAdaptiveAvgPool2D, self).__init__()
+
+ self.output_size = output_size
+
+ def forward(self, x):
+
+ H_in, W_in = x.shape[-2:]
+ H_out, W_out = self.output_size
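+ # Mimic nn.AdaptiveAvgPool2d with explicit per-cell windows: bin (i, j) covers
+ # rows [floor(i*H_in/H_out), ceil((i+1)*H_in/H_out)) and the analogous columns.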
+
+ out_i = []
+ for i in range(H_out):
+ out_j = []
+ for j in range(W_out):
+ hs = int(np.floor(i * H_in / H_out))
+ he = int(np.ceil((i + 1) * H_in / H_out))
+
+ ws = int(np.floor(j * W_in / W_out))
+ we = int(np.ceil((j + 1) * W_in / W_out))
+
+ # print(hs, he, ws, we)
+ kernel_size = [he - hs, we - ws]
+
+ out = F.avg_pool2d(x[:, :, hs:he, ws:we], kernel_size)
+ out_j.append(out)
+
+ out_j = torch.concat(out_j, -1)
+ out_i.append(out_j)
+
+ out_i = torch.concat(out_i, -2)
+ return out_i
+
+
+class PSPModule(nn.Module):
+ """
+ Reference:
+ Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
+ """
+
+ def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
+ super(PSPModule, self).__init__()
+
+ self.stages = []
+ tmp = []
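+ # Pool sizes 3 and 6 go through CostomAdaptiveAvgPool2D (explicit pooling windows); presumably
+ # this avoids nn.AdaptiveAvgPool2d for output sizes that do not divide the input feature map evenly.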
+ for size in sizes:
+ if size == 3 or size == 6:
+ tmp.append(self._make_stage_custom(features, out_features, size))
+ else:
+ tmp.append(self._make_stage(features, out_features, size))
+ self.stages = nn.ModuleList(tmp)
+ # self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes])
+ self.bottleneck = nn.Sequential(
+ nn.Conv2d(features + len(sizes) * out_features, out_features, kernel_size=3, padding=1, dilation=1,
+ bias=False),
+ InPlaceABNSync(out_features),
+ )
+
+ def _make_stage(self, features, out_features, size):
+ prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+ conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
+ bn = InPlaceABNSync(out_features)
+ return nn.Sequential(prior, conv, bn)
+
+ def _make_stage_custom(self, features, out_features, size):
+ prior = CostomAdaptiveAvgPool2D(output_size=(size, size))
+ conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
+ bn = InPlaceABNSync(out_features)
+ return nn.Sequential(prior, conv, bn)
+
+ def forward(self, feats):
+ h, w = feats.size(2), feats.size(3)
+ priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in
+ self.stages] + [feats]
+ bottle = self.bottleneck(torch.cat(priors, 1))
+ return bottle
+
+
+class ASPPModule(nn.Module):
+ """
+ Reference:
+ Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
+ """
+
+ def __init__(self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)):
+ super(ASPPModule, self).__init__()
+
+ self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1,
+ bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
+ InPlaceABNSync(inner_features))
+
+ self.bottleneck = nn.Sequential(
+ nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(out_features),
+ nn.Dropout2d(0.1)
+ )
+
+ def forward(self, x):
+ _, _, h, w = x.size()
+
+ feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
+
+ feat2 = self.conv2(x)
+ feat3 = self.conv3(x)
+ feat4 = self.conv4(x)
+ feat5 = self.conv5(x)
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
+
+ bottle = self.bottleneck(out)
+ return bottle
+
+
+class Edge_Module(nn.Module):
+ """
+ Edge Learning Branch
+ """
+
+ def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
+ super(Edge_Module, self).__init__()
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(mid_fea)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(mid_fea)
+ )
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(mid_fea)
+ )
+ self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True)
+ self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)
+
+ def forward(self, x1, x2, x3):
+ _, _, h, w = x1.size()
+
+ edge1_fea = self.conv1(x1)
+ edge1 = self.conv4(edge1_fea)
+ edge2_fea = self.conv2(x2)
+ edge2 = self.conv4(edge2_fea)
+ edge3_fea = self.conv3(x3)
+ edge3 = self.conv4(edge3_fea)
+
+ edge2_fea = F.interpolate(edge2_fea, size=(h, w), mode='bilinear', align_corners=True)
+ edge3_fea = F.interpolate(edge3_fea, size=(h, w), mode='bilinear', align_corners=True)
+ edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear', align_corners=True)
+ edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear', align_corners=True)
+
+ edge = torch.cat([edge1, edge2, edge3], dim=1)
+ edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
+ edge = self.conv5(edge)
+
+ return edge, edge_fea
+
+
+class Decoder_Module(nn.Module):
+ """
+ Parsing Branch Decoder Module.
+ """
+
+ def __init__(self, num_classes):
+ super(Decoder_Module, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(256)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(48)
+ )
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(256),
+ nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(256)
+ )
+
+ self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
+
+ def forward(self, xt, xl):
+ _, _, h, w = xl.size()
+ xt = F.interpolate(self.conv1(xt), size=(h, w), mode='bilinear', align_corners=True)
+ xl = self.conv2(xl)
+ x = torch.cat([xt, xl], dim=1)
+ x = self.conv3(x)
+ seg = self.conv4(x)
+ return seg, x
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, layers, num_classes):
+ self.inplanes = 128
+ super(ResNet, self).__init__()
+ self.conv1 = conv3x3(3, 64, stride=2)
+ self.bn1 = BatchNorm2d(64)
+ self.relu1 = nn.ReLU(inplace=False)
+ self.conv2 = conv3x3(64, 64)
+ self.bn2 = BatchNorm2d(64)
+ self.relu2 = nn.ReLU(inplace=False)
+ self.conv3 = conv3x3(64, 128)
+ self.bn3 = BatchNorm2d(128)
+ self.relu3 = nn.ReLU(inplace=False)
+
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1))
+
+ self.context_encoding = PSPModule(2048, 512)
+
+ self.edge = Edge_Module()
+ self.decoder = Decoder_Module(num_classes)
+
+ self.fushion = nn.Sequential(
+ nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(256),
+ nn.Dropout2d(0.1),
+ nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
+ )
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion, affine=affine_par))
+
+ layers = []
+ generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1
+ layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample,
+ multi_grid=generate_multi_grid(0, multi_grid)))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+ x2 = self.layer1(x)
+ x3 = self.layer2(x2)
+ x4 = self.layer3(x3)
+ x5 = self.layer4(x4)
+ x = self.context_encoding(x5)
+ parsing_result, parsing_fea = self.decoder(x, x2)
+ # Edge Branch
+ edge_result, edge_fea = self.edge(x2, x3, x4)
+ # Fusion Branch
+ x = torch.cat([parsing_fea, edge_fea], dim=1)
+ fusion_result = self.fushion(x)
+ return [[parsing_result, fusion_result], edge_result]
+
+
+def initialize_pretrained_model(model, settings, pretrained='./models/resnet101-imagenet.pth'):
+ model.input_space = settings['input_space']
+ model.input_size = settings['input_size']
+ model.input_range = settings['input_range']
+ model.mean = settings['mean']
+ model.std = settings['std']
+
+ if pretrained is not None:
+ saved_state_dict = torch.load(pretrained)
+ new_params = model.state_dict().copy()
+ for i in saved_state_dict:
+ i_parts = i.split('.')
+ if not i_parts[0] == 'fc':
+ new_params['.'.join(i_parts[0:])] = saved_state_dict[i]
+ model.load_state_dict(new_params)
+
+
+def resnet101(num_classes=20, pretrained='./models/resnet101-imagenet.pth'):
+ model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
+ settings = pretrained_settings['resnet101']['imagenet']
+ initialize_pretrained_model(model, settings, pretrained)
+ return model
diff --git a/preprocess/humanparsing/networks/__init__.py b/preprocess/humanparsing/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d5d384890e20652fa3ec282515ece6846ce447f
--- /dev/null
+++ b/preprocess/humanparsing/networks/__init__.py
@@ -0,0 +1,12 @@
+from __future__ import absolute_import
+from networks.AugmentCE2P import resnet101
+
+__factory = {
+ 'resnet101': resnet101,
+}
+
+
+def init_model(name, *args, **kwargs):
+ if name not in __factory.keys():
+ raise KeyError("Unknown model arch: {}".format(name))
+ return __factory[name](*args, **kwargs)
\ No newline at end of file
diff --git a/preprocess/humanparsing/networks/backbone/mobilenetv2.py b/preprocess/humanparsing/networks/backbone/mobilenetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2fe342877cfbc5796efea85af9abccfb80a27e
--- /dev/null
+++ b/preprocess/humanparsing/networks/backbone/mobilenetv2.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : mobilenetv2.py
+@Time : 8/4/19 3:35 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import torch.nn as nn
+import math
+import functools
+
+from torch.utils.model_zoo import load_url
+
+from modules import InPlaceABN, InPlaceABNSync
+
+BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
+
+__all__ = ['mobilenetv2']
+
+
+def conv_bn(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+def conv_1x1_bn(inp, oup):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ nn.ReLU6(inplace=True)
+ )
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self, inp, oup, stride, expand_ratio):
+ super(InvertedResidual, self).__init__()
+ self.stride = stride
+ assert stride in [1, 2]
+
+ hidden_dim = round(inp * expand_ratio)
+ self.use_res_connect = self.stride == 1 and inp == oup
+
+ if expand_ratio == 1:
+ self.conv = nn.Sequential(
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ )
+ else:
+ self.conv = nn.Sequential(
+ # pw
+ nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # dw
+ nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
+ BatchNorm2d(hidden_dim),
+ nn.ReLU6(inplace=True),
+ # pw-linear
+ nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+ BatchNorm2d(oup),
+ )
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, n_class=1000, input_size=224, width_mult=1.):
+ super(MobileNetV2, self).__init__()
+ block = InvertedResidual
+ input_channel = 32
+ last_channel = 1280
+ interverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2], # layer 2
+ [6, 32, 3, 2], # layer 3
+ [6, 64, 4, 2],
+ [6, 96, 3, 1], # layer 4
+ [6, 160, 3, 2],
+ [6, 320, 1, 1], # layer 5
+ ]
+
+ # building first layer
+ assert input_size % 32 == 0
+ input_channel = int(input_channel * width_mult)
+ self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
+ self.features = [conv_bn(3, input_channel, 2)]
+ # building inverted residual blocks
+ for t, c, n, s in interverted_residual_setting:
+ output_channel = int(c * width_mult)
+ for i in range(n):
+ if i == 0:
+ self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
+ else:
+ self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
+ input_channel = output_channel
+ # building last several layers
+ self.features.append(conv_1x1_bn(input_channel, self.last_channel))
+ # make it nn.Sequential
+ self.features = nn.Sequential(*self.features)
+
+ # building classifier
+ self.classifier = nn.Sequential(
+ nn.Dropout(0.2),
+ nn.Linear(self.last_channel, n_class),
+ )
+
+ self._initialize_weights()
+
+ def forward(self, x):
+ x = self.features(x)
+ x = x.mean(3).mean(2)
+ x = self.classifier(x)
+ return x
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ n = m.weight.size(1)
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
+
+
+def mobilenetv2(pretrained=False, **kwargs):
+ """Constructs a MobileNet_V2 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = MobileNetV2(n_class=1000, **kwargs)
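+ # NOTE: pretrained=True assumes a model_urls['mobilenetv2'] entry pointing at a checkpoint;
+ # no such mapping is defined in this file.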
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
+ return model
diff --git a/preprocess/humanparsing/networks/backbone/resnet.py b/preprocess/humanparsing/networks/backbone/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..88d6f73bc4fc327e18123020e01ccf5c1b37f025
--- /dev/null
+++ b/preprocess/humanparsing/networks/backbone/resnet.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : resnet.py
+@Time : 8/4/19 3:35 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import functools
+import torch.nn as nn
+import math
+from torch.utils.model_zoo import load_url
+
+from modules import InPlaceABNSync
+
+BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
+
+__all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
+
+model_urls = {
+ 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
+ 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
+ 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, layers, num_classes=1000):
+ self.inplanes = 128
+ super(ResNet, self).__init__()
+ self.conv1 = conv3x3(3, 64, stride=2)
+ self.bn1 = BatchNorm2d(64)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(64, 64)
+ self.bn2 = BatchNorm2d(64)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = conv3x3(64, 128)
+ self.bn3 = BatchNorm2d(128)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+def resnet18(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet18']))
+ return model
+
+
+def resnet50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
+ return model
+
+
+def resnet101(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
+ return model
diff --git a/preprocess/humanparsing/networks/backbone/resnext.py b/preprocess/humanparsing/networks/backbone/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..96adb54146addc523be71591eb93afcc2c25307f
--- /dev/null
+++ b/preprocess/humanparsing/networks/backbone/resnext.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : resnext.py.py
+@Time : 8/11/19 8:58 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+import functools
+import torch.nn as nn
+import math
+from torch.utils.model_zoo import load_url
+
+from modules import InPlaceABNSync
+
+BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
+
+__all__ = ['ResNeXt', 'resnext101'] # support resnext 101
+
+model_urls = {
+ 'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
+ 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class GroupBottleneck(nn.Module):
+ expansion = 2
+
+ def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
+ super(GroupBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, groups=groups, bias=False)
+ self.bn2 = BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm2d(planes * 2)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNeXt(nn.Module):
+
+ def __init__(self, block, layers, groups=32, num_classes=1000):
+ self.inplanes = 128
+ super(ResNeXt, self).__init__()
+ self.conv1 = conv3x3(3, 64, stride=2)
+ self.bn1 = BatchNorm2d(64)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(64, 64)
+ self.bn2 = BatchNorm2d(64)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = conv3x3(64, 128)
+ self.bn3 = BatchNorm2d(128)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
+ self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
+ self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
+ self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
+ self.avgpool = nn.AvgPool2d(7, stride=1)
+ self.fc = nn.Linear(1024 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1, groups=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, groups, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=groups))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
+
+
+def resnext101(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on Places
+ """
+ model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
+ if pretrained:
+ model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
+ return model
diff --git a/preprocess/humanparsing/networks/context_encoding/aspp.py b/preprocess/humanparsing/networks/context_encoding/aspp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0ba531a8920665c982b1f3412bc030465d56d2a
--- /dev/null
+++ b/preprocess/humanparsing/networks/context_encoding/aspp.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : aspp.py
+@Time : 8/4/19 3:36 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from modules import InPlaceABNSync
+
+
+class ASPPModule(nn.Module):
+ """
+ Reference:
+ Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*
+ """
+ def __init__(self, features, out_features=512, inner_features=256, dilations=(12, 24, 36)):
+ super(ASPPModule, self).__init__()
+
+ self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1,
+ bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
+ InPlaceABNSync(inner_features))
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
+ InPlaceABNSync(inner_features))
+
+ self.bottleneck = nn.Sequential(
+ nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(out_features),
+ nn.Dropout2d(0.1)
+ )
+
+ def forward(self, x):
+ _, _, h, w = x.size()
+
+ feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
+
+ feat2 = self.conv2(x)
+ feat3 = self.conv3(x)
+ feat4 = self.conv4(x)
+ feat5 = self.conv5(x)
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
+
+ bottle = self.bottleneck(out)
+ return bottle
\ No newline at end of file
diff --git a/preprocess/humanparsing/networks/context_encoding/ocnet.py b/preprocess/humanparsing/networks/context_encoding/ocnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac43ebf489ee478c48acf3f93b01b32bdb08cdf3
--- /dev/null
+++ b/preprocess/humanparsing/networks/context_encoding/ocnet.py
@@ -0,0 +1,226 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : ocnet.py
+@Time : 8/4/19 3:36 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import functools
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.nn import functional as F
+
+from modules import InPlaceABNSync
+BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
+
+
+class _SelfAttentionBlock(nn.Module):
+ '''
+ The basic implementation for self-attention block/non-local block
+ Input:
+ N X C X H X W
+ Parameters:
+ in_channels : the dimension of the input feature map
+ key_channels : the dimension after the key/query transform
+ value_channels : the dimension after the value transform
+ scale : choose the scale to downsample the input feature maps (save memory cost)
+ Return:
+ N X C X H X W
+ position-aware context features (w/o concat or add with the input)
+ '''
+
+ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
+ super(_SelfAttentionBlock, self).__init__()
+ self.scale = scale
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.key_channels = key_channels
+ self.value_channels = value_channels
+ if out_channels is None:
+ self.out_channels = in_channels
+ self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
+ self.f_key = nn.Sequential(
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
+ kernel_size=1, stride=1, padding=0),
+ InPlaceABNSync(self.key_channels),
+ )
+ self.f_query = self.f_key
+ self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
+ kernel_size=1, stride=1, padding=0)
+ self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels,
+ kernel_size=1, stride=1, padding=0)
+ nn.init.constant_(self.W.weight, 0)
+ nn.init.constant_(self.W.bias, 0)
+
+ def forward(self, x):
+ batch_size, h, w = x.size(0), x.size(2), x.size(3)
+ if self.scale > 1:
+ x = self.pool(x)
+
+ value = self.f_value(x).view(batch_size, self.value_channels, -1)
+ value = value.permute(0, 2, 1)
+ query = self.f_query(x).view(batch_size, self.key_channels, -1)
+ query = query.permute(0, 2, 1)
+ key = self.f_key(x).view(batch_size, self.key_channels, -1)
+
+ sim_map = torch.matmul(query, key)
+ sim_map = (self.key_channels ** -.5) * sim_map
+ sim_map = F.softmax(sim_map, dim=-1)
+
+ context = torch.matmul(sim_map, value)
+ context = context.permute(0, 2, 1).contiguous()
+ context = context.view(batch_size, self.value_channels, *x.size()[2:])
+ context = self.W(context)
+ if self.scale > 1:
+            context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True)
+ return context
+
+
+class SelfAttentionBlock2D(_SelfAttentionBlock):
+ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1):
+ super(SelfAttentionBlock2D, self).__init__(in_channels,
+ key_channels,
+ value_channels,
+ out_channels,
+ scale)
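+
+# Illustrative shape check for the block above (a sketch, not part of the upstream model;
+# assumes the InPlaceABNSync CUDA extension is importable and a GPU is available):
+#   attn = SelfAttentionBlock2D(in_channels=512, key_channels=256, value_channels=256, scale=2).cuda()
+#   y = attn(torch.randn(1, 512, 32, 32).cuda())   # y.shape == torch.Size([1, 512, 32, 32])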
+
+
+class BaseOC_Module(nn.Module):
+ """
+ Implementation of the BaseOC module
+ Parameters:
+ in_features / out_features: the channels of the input / output feature maps.
+ dropout: we choose 0.05 as the default value.
+ size: you can apply multiple sizes. Here we only use one size.
+ Return:
+ features fused with Object context information.
+ """
+
+ def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
+ super(BaseOC_Module, self).__init__()
+ self.stages = []
+ self.stages = nn.ModuleList(
+ [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
+ self.conv_bn_dropout = nn.Sequential(
+ nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0),
+ InPlaceABNSync(out_channels),
+ nn.Dropout2d(dropout)
+ )
+
+ def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
+ return SelfAttentionBlock2D(in_channels,
+ key_channels,
+ value_channels,
+ output_channels,
+ size)
+
+ def forward(self, feats):
+ priors = [stage(feats) for stage in self.stages]
+ context = priors[0]
+ for i in range(1, len(priors)):
+ context += priors[i]
+ output = self.conv_bn_dropout(torch.cat([context, feats], 1))
+ return output
+
+
+class BaseOC_Context_Module(nn.Module):
+ """
+ Output only the context features.
+ Parameters:
+ in_features / out_features: the channels of the input / output feature maps.
+ dropout: specify the dropout ratio
+        fusion: We provide two different fusion methods, "concat" or "add"
+        size: we find that directly learning the attention weights even on 1/8 feature maps is hard.
+ Return:
+ features after "concat" or "add"
+ """
+
+ def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1])):
+ super(BaseOC_Context_Module, self).__init__()
+ self.stages = []
+ self.stages = nn.ModuleList(
+ [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes])
+ self.conv_bn_dropout = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0),
+ InPlaceABNSync(out_channels),
+ )
+
+ def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size):
+ return SelfAttentionBlock2D(in_channels,
+ key_channels,
+ value_channels,
+ output_channels,
+ size)
+
+ def forward(self, feats):
+ priors = [stage(feats) for stage in self.stages]
+ context = priors[0]
+ for i in range(1, len(priors)):
+ context += priors[i]
+ output = self.conv_bn_dropout(context)
+ return output
+
+
+class ASP_OC_Module(nn.Module):
+ def __init__(self, features, out_features=256, dilations=(12, 24, 36)):
+ super(ASP_OC_Module, self).__init__()
+ self.context = nn.Sequential(nn.Conv2d(features, out_features, kernel_size=3, padding=1, dilation=1, bias=True),
+ InPlaceABNSync(out_features),
+ BaseOC_Context_Module(in_channels=out_features, out_channels=out_features,
+ key_channels=out_features // 2, value_channels=out_features,
+ dropout=0, sizes=([2])))
+ self.conv2 = nn.Sequential(nn.Conv2d(features, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(out_features))
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
+ InPlaceABNSync(out_features))
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
+ InPlaceABNSync(out_features))
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
+ InPlaceABNSync(out_features))
+
+ self.conv_bn_dropout = nn.Sequential(
+ nn.Conv2d(out_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
+ InPlaceABNSync(out_features),
+ nn.Dropout2d(0.1)
+ )
+
+ def _cat_each(self, feat1, feat2, feat3, feat4, feat5):
+ assert (len(feat1) == len(feat2))
+ z = []
+ for i in range(len(feat1)):
+ z.append(torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), 1))
+ return z
+
+ def forward(self, x):
+ if isinstance(x, Variable):
+ _, _, h, w = x.size()
+ elif isinstance(x, tuple) or isinstance(x, list):
+ _, _, h, w = x[0].size()
+ else:
+ raise RuntimeError('unknown input type')
+
+ feat1 = self.context(x)
+ feat2 = self.conv2(x)
+ feat3 = self.conv3(x)
+ feat4 = self.conv4(x)
+ feat5 = self.conv5(x)
+
+ if isinstance(x, Variable):
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)
+ elif isinstance(x, tuple) or isinstance(x, list):
+ out = self._cat_each(feat1, feat2, feat3, feat4, feat5)
+ else:
+ raise RuntimeError('unknown input type')
+ output = self.conv_bn_dropout(out)
+ return output
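+
+
+# Illustrative shape check (a sketch; assumes the InPlaceABNSync CUDA extension and a GPU):
+#   aspoc = ASP_OC_Module(features=2048, out_features=256).cuda()
+#   out = aspoc(torch.randn(1, 2048, 32, 32).cuda())   # out.shape == torch.Size([1, 256, 32, 32])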
diff --git a/preprocess/humanparsing/networks/context_encoding/psp.py b/preprocess/humanparsing/networks/context_encoding/psp.py
new file mode 100644
index 0000000000000000000000000000000000000000..47181dc3f5fddb1c7fb80ad58a6694aae9ebd746
--- /dev/null
+++ b/preprocess/humanparsing/networks/context_encoding/psp.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : psp.py
+@Time : 8/4/19 3:36 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from modules import InPlaceABNSync
+
+
+class PSPModule(nn.Module):
+ """
+ Reference:
+ Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
+ """
+ def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
+ super(PSPModule, self).__init__()
+
+ self.stages = []
+ self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes])
+ self.bottleneck = nn.Sequential(
+ nn.Conv2d(features + len(sizes) * out_features, out_features, kernel_size=3, padding=1, dilation=1,
+ bias=False),
+ InPlaceABNSync(out_features),
+ )
+
+ def _make_stage(self, features, out_features, size):
+ prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
+ conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
+ bn = InPlaceABNSync(out_features)
+ return nn.Sequential(prior, conv, bn)
+
+ def forward(self, feats):
+ h, w = feats.size(2), feats.size(3)
+ priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in
+ self.stages] + [feats]
+ bottle = self.bottleneck(torch.cat(priors, 1))
+ return bottle
\ No newline at end of file
diff --git a/preprocess/humanparsing/parsing_api.py b/preprocess/humanparsing/parsing_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e82570d62d10af224e894353900cf2d18063be9
--- /dev/null
+++ b/preprocess/humanparsing/parsing_api.py
@@ -0,0 +1,191 @@
+import sys
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
+sys.path.insert(0, str(PROJECT_ROOT))
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from datasets.simple_extractor_dataset import SimpleFolderDataset
+from PIL import Image
+from utils.transforms import transform_logits
+
+
+def get_palette(num_cls):
+ """ Returns the color map for visualizing the segmentation mask.
+ Args:
+ num_cls: Number of classes
+ Returns:
+ The color map
+ """
+ n = num_cls
+ palette = [0] * (n * 3)
+ for j in range(0, n):
+ lab = j
+ palette[j * 3 + 0] = 0
+ palette[j * 3 + 1] = 0
+ palette[j * 3 + 2] = 0
+ i = 0
+ while lab:
+ palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
+ palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
+ palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
+ i += 1
+ lab >>= 3
+ return palette
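+
+# Worked example: get_palette(19) encodes each class index bit by bit into RGB, e.g. class 1
+# (binary 001) maps to (128, 0, 0) and class 4 (binary 100) maps to (0, 0, 128). The flat list
+# targets PIL paletted images, as used below: output_img.putpalette(palette) on a mode "P" image.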
+
+
+def delete_irregular(logits_result):
+ parsing_result = np.argmax(logits_result, axis=2)
+ upper_cloth = np.where(parsing_result == 4, 255, 0)
+ contours, hierarchy = cv2.findContours(upper_cloth.astype(np.uint8),
+ cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
+ area = []
+ for i in range(len(contours)):
+ a = cv2.contourArea(contours[i], True)
+ area.append(abs(a))
+ if len(area) != 0:
+ top = area.index(max(area))
+ M = cv2.moments(contours[top])
+ cY = int(M["m01"] / M["m00"])
+
+ dresses = np.where(parsing_result == 7, 255, 0)
+ contours_dress, hierarchy_dress = cv2.findContours(dresses.astype(np.uint8),
+ cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
+ area_dress = []
+ for j in range(len(contours_dress)):
+ a_d = cv2.contourArea(contours_dress[j], True)
+ area_dress.append(abs(a_d))
+ if len(area_dress) != 0:
+ top_dress = area_dress.index(max(area_dress))
+ M_dress = cv2.moments(contours_dress[top_dress])
+ cY_dress = int(M_dress["m01"] / M_dress["m00"])
+ wear_type = "dresses"
+ if len(area) != 0:
+ if len(area_dress) != 0 and cY_dress > cY:
+ irregular_list = np.array([4, 5, 6])
+ logits_result[:, :, irregular_list] = -1
+ else:
+ irregular_list = np.array([5, 6, 7, 8, 9, 10, 12, 13])
+ logits_result[:cY, :, irregular_list] = -1
+ wear_type = "cloth_pant"
+ parsing_result = np.argmax(logits_result, axis=2)
+ # pad border
+ parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)
+ return parsing_result, wear_type
+
+
+def hole_fill(img):
+ img_copy = img.copy()
+ mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
+ cv2.floodFill(img, mask, (0, 0), 255)
+ img_inverse = cv2.bitwise_not(img)
+ dst = cv2.bitwise_or(img_copy, img_inverse)
+ return dst
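+
+# Sketch of hole_fill on a toy binary mask (values in {0, 255}): flood-filling from the (0, 0)
+# corner whitens every background pixel reachable from the border; inverting that result leaves
+# only the enclosed holes, and OR-ing it with the original mask closes them:
+#   m = np.zeros((5, 5), np.uint8); m[1:4, 1:4] = 255; m[2, 2] = 0   # 3x3 block with a 1-px hole
+#   hole_fill(m)[2, 2]                                               # -> 255, the hole is filled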
+
+
+def refine_mask(mask):
+ contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
+ cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
+ area = []
+ for j in range(len(contours)):
+ a_d = cv2.contourArea(contours[j], True)
+ area.append(abs(a_d))
+ refine_mask = np.zeros_like(mask).astype(np.uint8)
+ if len(area) != 0:
+ i = area.index(max(area))
+ cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
+ # keep large area in skin case
+ for j in range(len(area)):
+ if j != i and area[i] > 2000:
+ cv2.drawContours(refine_mask, contours, j, color=255, thickness=-1)
+ return refine_mask
+
+
+def refine_hole(parsing_result_filled, parsing_result, arm_mask):
+ filled_hole = cv2.bitwise_and(np.where(parsing_result_filled == 4, 255, 0),
+ np.where(parsing_result != 4, 255, 0)) - arm_mask * 255
+ contours, hierarchy = cv2.findContours(filled_hole, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
+ refine_hole_mask = np.zeros_like(parsing_result).astype(np.uint8)
+ for i in range(len(contours)):
+ a = cv2.contourArea(contours[i], True)
+ # keep hole > 2000 pixels
+ if abs(a) > 2000:
+ cv2.drawContours(refine_hole_mask, contours, i, color=255, thickness=-1)
+ return refine_hole_mask + arm_mask
+
+
+def onnx_inference(session, lip_session, input_dir):
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
+ ])
+ dataset = SimpleFolderDataset(root=input_dir, input_size=[512, 512], transform=transform)
+ # dataloader = DataLoader(dataset)
+ with torch.no_grad():
+ # for _, batch in enumerate(tqdm(dataloader, disable=True)):
+ image, meta = dataset[0]
+ image = image.unsqueeze(0)
+
+ # image, meta = batch
+ c = meta['center']
+ h = meta['height']
+ w = meta['width']
+ s = meta['scale']
+ output = session.run(None, {"input.1": image.numpy().astype(np.float32)})
+ upsample = torch.nn.Upsample(size=[512, 512], mode='bilinear', align_corners=True)
+ upsample_output = upsample(torch.from_numpy(output[1][0]).unsqueeze(0))
+ upsample_output = upsample_output.squeeze()
+ upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC
+ logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=[512, 512])
+ parsing_result = np.argmax(logits_result, axis=2)
+ parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)
+ # try holefilling the clothes part
+ arm_mask = (parsing_result == 14).astype(np.float32) \
+ + (parsing_result == 15).astype(np.float32)
+ upper_cloth_mask = (parsing_result == 4).astype(np.float32) + arm_mask
+ img = np.where(upper_cloth_mask, 255, 0)
+ dst = hole_fill(img.astype(np.uint8))
+ parsing_result_filled = dst / 255 * 4
+ parsing_result_woarm = np.where(parsing_result_filled == 4, parsing_result_filled, parsing_result)
+ # add back arm and refined hole between arm and cloth
+ refine_hole_mask = refine_hole(parsing_result_filled.astype(np.uint8), parsing_result.astype(np.uint8),
+ arm_mask.astype(np.uint8))
+ parsing_result = np.where(refine_hole_mask, parsing_result, parsing_result_woarm)
+ # remove padding
+ parsing_result = parsing_result[1:-1, 1:-1]
+
+ dataset_lip = SimpleFolderDataset(root=input_dir, input_size=[473, 473], transform=transform)
+ # dataloader_lip = DataLoader(dataset_lip)
+ with torch.no_grad():
+ # for _, batch in enumerate(tqdm(dataloader_lip, disable=True)):
+
+ image, meta = dataset_lip[0]
+ image = image.unsqueeze(0)
+
+ # image, meta = batch
+ c = meta['center']
+ s = meta['scale']
+ w = meta['width']
+ h = meta['height']
+
+ output_lip = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)})
+ upsample = torch.nn.Upsample(size=[473, 473], mode='bilinear', align_corners=True)
+ upsample_output_lip = upsample(torch.from_numpy(output_lip[1][0]).unsqueeze(0))
+ upsample_output_lip = upsample_output_lip.squeeze()
+ upsample_output_lip = upsample_output_lip.permute(1, 2, 0) # CHW -> HWC
+ logits_result_lip = transform_logits(upsample_output_lip.data.cpu().numpy(), c, s, w, h,
+ input_size=[473, 473])
+ parsing_result_lip = np.argmax(logits_result_lip, axis=2)
+ # add neck parsing result
+ neck_mask = np.logical_and(np.logical_not((parsing_result_lip == 13).astype(np.float32)),
+ (parsing_result == 11).astype(np.float32))
+ parsing_result = np.where(neck_mask, 18, parsing_result)
+ palette = get_palette(19)
+ output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
+ output_img.putpalette(palette)
+ face_mask = torch.from_numpy((parsing_result == 11).astype(np.float32))
+
+ return output_img, face_mask
diff --git a/preprocess/humanparsing/run_parsing.py b/preprocess/humanparsing/run_parsing.py
new file mode 100644
index 0000000000000000000000000000000000000000..51057aaa338215332d1cc08b9070d5a0bbb2d097
--- /dev/null
+++ b/preprocess/humanparsing/run_parsing.py
@@ -0,0 +1,44 @@
+import os
+import pdb
+import sys
+from pathlib import Path
+
+import onnxruntime as ort
+
+PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
+sys.path.insert(0, str(PROJECT_ROOT))
+import torch
+from parsing_api import onnx_inference
+
+
+class Parsing:
+ def __init__(self, gpu_id: int):
+ self.gpu_id = gpu_id
+ # torch.cuda.set_device(gpu_id)
+ session_options = ort.SessionOptions()
+ session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+ session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+ #### jho modified >>>>
+ providers = [
+ ('CUDAExecutionProvider', {
+ 'device_id': gpu_id,
+ }),
+ 'CPUExecutionProvider',
+ ]
+ self.session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing/parsing_atr.onnx'),
+ sess_options=session_options, providers=providers)
+ self.lip_session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing/parsing_lip.onnx'),
+ sess_options=session_options, providers=providers)
+ #### jho modified <<<<
+ # session_options.add_session_config_entry('gpu_id', str(gpu_id))
+ # self.session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing/parsing_atr.onnx'),
+ # sess_options=session_options, providers=['CUDAExecutionProvider'])
+ # self.lip_session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing/parsing_lip.onnx'),
+ # sess_options=session_options, providers=['CUDAExecutionProvider'])
+ print(f"parsing init done (gpu: {gpu_id})")
+
+ def __call__(self, input_image):
+ torch.cuda.set_device(self.gpu_id)
+ parsed_image, face_mask = onnx_inference(self.session, self.lip_session, input_image)
+ return parsed_image, face_mask
+
\ No newline at end of file
diff --git a/preprocess/humanparsing/utils/__init__.py b/preprocess/humanparsing/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/preprocess/humanparsing/utils/consistency_loss.py b/preprocess/humanparsing/utils/consistency_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b872fdcc10ecef02762399278191e48e79ea9a1f
--- /dev/null
+++ b/preprocess/humanparsing/utils/consistency_loss.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File    :   consistency_loss.py
+@Time : 7/23/19 4:02 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+from datasets.target_generation import generate_edge_tensor
+
+
+class ConsistencyLoss(nn.Module):
+ def __init__(self, ignore_index=255):
+ super(ConsistencyLoss, self).__init__()
+ self.ignore_index=ignore_index
+
+ def forward(self, parsing, edge, label):
+ parsing_pre = torch.argmax(parsing, dim=1)
+ parsing_pre[label==self.ignore_index]=self.ignore_index
+ generated_edge = generate_edge_tensor(parsing_pre)
+ edge_pre = torch.argmax(edge, dim=1)
+ v_generate_edge = generated_edge[label!=255]
+ v_edge_pre = edge_pre[label!=255]
+ v_edge_pre = v_edge_pre.type(torch.cuda.FloatTensor)
+ positive_union = (v_generate_edge==1)&(v_edge_pre==1) # only the positive values count
+ return F.smooth_l1_loss(v_generate_edge[positive_union].squeeze(0), v_edge_pre[positive_union].squeeze(0))
diff --git a/preprocess/humanparsing/utils/criterion.py b/preprocess/humanparsing/utils/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..968894319042331482692e42804f103074e4b710
--- /dev/null
+++ b/preprocess/humanparsing/utils/criterion.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : criterion.py
+@Time : 8/30/19 8:59 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import torch.nn as nn
+import torch
+import numpy as np
+from torch.nn import functional as F
+from .lovasz_softmax import LovaszSoftmax
+from .kl_loss import KLDivergenceLoss
+from .consistency_loss import ConsistencyLoss
+
+NUM_CLASSES = 20
+
+
+class CriterionAll(nn.Module):
+ def __init__(self, use_class_weight=False, ignore_index=255, lambda_1=1, lambda_2=1, lambda_3=1,
+ num_classes=20):
+ super(CriterionAll, self).__init__()
+ self.ignore_index = ignore_index
+ self.use_class_weight = use_class_weight
+ self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
+ self.lovasz = LovaszSoftmax(ignore_index=ignore_index)
+ self.kldiv = KLDivergenceLoss(ignore_index=ignore_index)
+ self.reg = ConsistencyLoss(ignore_index=ignore_index)
+ self.lamda_1 = lambda_1
+ self.lamda_2 = lambda_2
+ self.lamda_3 = lambda_3
+ self.num_classes = num_classes
+
+ def parsing_loss(self, preds, target, cycle_n=None):
+ """
+ Loss function definition.
+
+ Args:
+            preds: [[parsing result1, parsing result2], [edge result]]
+            target: [parsing label, edge label, soft parsing preds (or None), soft edge preds (or None)]
+            cycle_n: current self-training cycle, used to weight the soft labels by 1 / (cycle_n + 1)
+ Returns:
+ Calculated Loss.
+ """
+ h, w = target[0].size(1), target[0].size(2)
+
+ pos_num = torch.sum(target[1] == 1, dtype=torch.float)
+ neg_num = torch.sum(target[1] == 0, dtype=torch.float)
+
+ weight_pos = neg_num / (pos_num + neg_num)
+ weight_neg = pos_num / (pos_num + neg_num)
+ weights = torch.tensor([weight_neg, weight_pos]) # edge loss weight
+
+ loss = 0
+
+ # loss for segmentation
+ preds_parsing = preds[0]
+ for pred_parsing in preds_parsing:
+ scale_pred = F.interpolate(input=pred_parsing, size=(h, w),
+ mode='bilinear', align_corners=True)
+
+ loss += 0.5 * self.lamda_1 * self.lovasz(scale_pred, target[0])
+ if target[2] is None:
+ loss += 0.5 * self.lamda_1 * self.criterion(scale_pred, target[0])
+ else:
+ soft_scale_pred = F.interpolate(input=target[2], size=(h, w),
+ mode='bilinear', align_corners=True)
+ soft_scale_pred = moving_average(soft_scale_pred, to_one_hot(target[0], num_cls=self.num_classes),
+ 1.0 / (cycle_n + 1.0))
+ loss += 0.5 * self.lamda_1 * self.kldiv(scale_pred, soft_scale_pred, target[0])
+
+ # loss for edge
+ preds_edge = preds[1]
+ for pred_edge in preds_edge:
+ scale_pred = F.interpolate(input=pred_edge, size=(h, w),
+ mode='bilinear', align_corners=True)
+ if target[3] is None:
+ loss += self.lamda_2 * F.cross_entropy(scale_pred, target[1],
+ weights.cuda(), ignore_index=self.ignore_index)
+ else:
+ soft_scale_edge = F.interpolate(input=target[3], size=(h, w),
+ mode='bilinear', align_corners=True)
+ soft_scale_edge = moving_average(soft_scale_edge, to_one_hot(target[1], num_cls=2),
+ 1.0 / (cycle_n + 1.0))
+ loss += self.lamda_2 * self.kldiv(scale_pred, soft_scale_edge, target[0])
+
+ # consistency regularization
+ preds_parsing = preds[0]
+ preds_edge = preds[1]
+ for pred_parsing in preds_parsing:
+ scale_pred = F.interpolate(input=pred_parsing, size=(h, w),
+ mode='bilinear', align_corners=True)
+ scale_edge = F.interpolate(input=preds_edge[0], size=(h, w),
+ mode='bilinear', align_corners=True)
+ loss += self.lamda_3 * self.reg(scale_pred, scale_edge, target[0])
+
+ return loss
+
+ def forward(self, preds, target, cycle_n=None):
+ loss = self.parsing_loss(preds, target, cycle_n)
+ return loss
+
+ def _generate_weights(self, masks, num_classes):
+ """
+ masks: torch.Tensor with shape [B, H, W]
+ """
+ masks_label = masks.data.cpu().numpy().astype(np.int64)
+ pixel_nums = []
+ tot_pixels = 0
+ for i in range(num_classes):
+            pixel_num_of_cls_i = np.sum(masks_label == i).astype(np.float64)
+ pixel_nums.append(pixel_num_of_cls_i)
+ tot_pixels += pixel_num_of_cls_i
+ weights = []
+ for i in range(num_classes):
+ weights.append(
+ (tot_pixels - pixel_nums[i]) / tot_pixels / (num_classes - 1)
+ )
+        weights = np.array(weights, dtype=np.float64)
+ # weights = torch.from_numpy(weights).float().to(masks.device)
+ return weights
+
+
+def moving_average(target1, target2, alpha=1.0):
+ target = 0
+ target += (1.0 - alpha) * target1
+ target += target2 * alpha
+ return target
+
+
+def to_one_hot(tensor, num_cls, dim=1, ignore_index=255):
+ b, h, w = tensor.shape
+ tensor[tensor == ignore_index] = 0
+ onehot_tensor = torch.zeros(b, num_cls, h, w).cuda()
+ onehot_tensor.scatter_(dim, tensor.unsqueeze(dim), 1)
+ return onehot_tensor
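+
+
+# Minimal sketch for the helpers above (illustrative values; to_one_hot allocates on the current
+# CUDA device, so this assumes a GPU is available):
+#   labels = torch.tensor([[[0, 2], [1, 255]]])              # [B=1, H=2, W=2], 255 = ignore index
+#   onehot = to_one_hot(labels.clone().cuda(), num_cls=3)    # -> shape [1, 3, 2, 2]; ignored pixels fall back to class 0
+#   blended = moving_average(onehot, onehot, alpha=0.5)      # convex combination of two soft label maps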
diff --git a/preprocess/humanparsing/utils/encoding.py b/preprocess/humanparsing/utils/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4cb131a03da41645b42f8461e8343c413917f45
--- /dev/null
+++ b/preprocess/humanparsing/utils/encoding.py
@@ -0,0 +1,187 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## ECE Department, Rutgers University
+## Email: zhang.hang@rutgers.edu
+## Copyright (c) 2017
+##
+## This source code is licensed under the MIT-style license found in the
+## LICENSE file in the root directory of this source tree
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+"""Encoding Data Parallel"""
+import threading
+import torch
+from torch.autograd import Variable, Function
+import torch.cuda.comm as comm
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.parallel_apply import get_a_var
+from torch.nn.parallel._functions import Broadcast
+
+torch_ver = torch.__version__[:3]
+
+__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 'patch_replication_callback']
+
+def allreduce(*inputs):
+ """Cross GPU all reduce autograd operation for calculate mean and
+ variance in SyncBN.
+ """
+ return AllReduce.apply(*inputs)
+
+class AllReduce(Function):
+ @staticmethod
+ def forward(ctx, num_inputs, *inputs):
+ ctx.num_inputs = num_inputs
+ ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
+ inputs = [inputs[i:i + num_inputs]
+ for i in range(0, len(inputs), num_inputs)]
+ # sort before reduce sum
+ inputs = sorted(inputs, key=lambda i: i[0].get_device())
+ results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
+ outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
+ return tuple([t for tensors in outputs for t in tensors])
+
+ @staticmethod
+ def backward(ctx, *inputs):
+ inputs = [i.data for i in inputs]
+ inputs = [inputs[i:i + ctx.num_inputs]
+ for i in range(0, len(inputs), ctx.num_inputs)]
+ results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
+ outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
+ return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
+
+class Reduce(Function):
+ @staticmethod
+ def forward(ctx, *inputs):
+ ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
+ inputs = sorted(inputs, key=lambda i: i.get_device())
+ return comm.reduce_add(inputs)
+
+ @staticmethod
+ def backward(ctx, gradOutput):
+ return Broadcast.apply(ctx.target_gpus, gradOutput)
+
+
+class DataParallelModel(DataParallel):
+ """Implements data parallelism at the module level.
+
+ This container parallelizes the application of the given module by
+ splitting the input across the specified devices by chunking in the
+ batch dimension.
+ In the forward pass, the module is replicated on each device,
+ and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
+    Note that the outputs are not gathered; please use the compatible
+ :class:`encoding.parallel.DataParallelCriterion`.
+
+ The batch size should be larger than the number of GPUs used. It should
+ also be an integer multiple of the number of GPUs so that each chunk is
+ the same size (so that each GPU processes the same number of samples).
+
+ Args:
+ module: module to be parallelized
+ device_ids: CUDA devices (default: all devices)
+
+ Reference:
+ Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
+        Amit Agrawal. "Context Encoding for Semantic Segmentation."
+ *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
+
+ Example::
+
+ >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
+ >>> y = net(x)
+ """
+ def gather(self, outputs, output_device):
+ return outputs
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelModel, self).replicate(module, device_ids)
+ return modules
+
+
+class DataParallelCriterion(DataParallel):
+ """
+    Calculate loss on multiple GPUs, which balances the memory usage for
+    semantic segmentation.
+
+    The targets are split across the specified devices by chunking in
+ the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
+
+ Reference:
+ Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
+        Amit Agrawal. "Context Encoding for Semantic Segmentation."
+ *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
+
+ Example::
+
+ >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
+ >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
+ >>> y = net(x)
+ >>> loss = criterion(y, target)
+ """
+ def forward(self, inputs, *targets, **kwargs):
+        # input should already be scattered
+ # scattering the targets instead
+ if not self.device_ids:
+ return self.module(inputs, *targets, **kwargs)
+ targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ return self.module(inputs, *targets[0], **kwargs[0])
+ replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
+ outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
+ return Reduce.apply(*outputs) / len(outputs)
+
+
+def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
+ assert len(modules) == len(inputs)
+ assert len(targets) == len(inputs)
+ if kwargs_tup:
+ assert len(modules) == len(kwargs_tup)
+ else:
+ kwargs_tup = ({},) * len(modules)
+ if devices is not None:
+ assert len(modules) == len(devices)
+ else:
+ devices = [None] * len(modules)
+
+ lock = threading.Lock()
+ results = {}
+ if torch_ver != "0.3":
+ grad_enabled = torch.is_grad_enabled()
+
+ def _worker(i, module, input, target, kwargs, device=None):
+ if torch_ver != "0.3":
+ torch.set_grad_enabled(grad_enabled)
+ if device is None:
+ device = get_a_var(input).get_device()
+ try:
+ if not isinstance(input, tuple):
+ input = (input,)
+ with torch.cuda.device(device):
+ output = module(*(input + target), **kwargs)
+ with lock:
+ results[i] = output
+ except Exception as e:
+ with lock:
+ results[i] = e
+
+ if len(modules) > 1:
+ threads = [threading.Thread(target=_worker,
+ args=(i, module, input, target,
+ kwargs, device),)
+ for i, (module, input, target, kwargs, device) in
+ enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
+
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ else:
+        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
+
+ outputs = []
+ for i in range(len(inputs)):
+ output = results[i]
+ if isinstance(output, Exception):
+ raise output
+ outputs.append(output)
+ return outputs
diff --git a/preprocess/humanparsing/utils/kl_loss.py b/preprocess/humanparsing/utils/kl_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb50203be566672d4cf42e7c60e92998a37d6fe
--- /dev/null
+++ b/preprocess/humanparsing/utils/kl_loss.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : kl_loss.py
+@Time : 7/23/19 4:02 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+import torch.nn.functional as F
+from torch import nn
+
+
+def flatten_probas(input, target, labels, ignore=255):
+ """
+ Flattens predictions in the batch.
+ """
+ B, C, H, W = input.size()
+ input = input.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
+ target = target.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
+ labels = labels.view(-1)
+ if ignore is None:
+ return input, target
+ valid = (labels != ignore)
+ vinput = input[valid.nonzero().squeeze()]
+ vtarget = target[valid.nonzero().squeeze()]
+ return vinput, vtarget
+
+
+class KLDivergenceLoss(nn.Module):
+ def __init__(self, ignore_index=255, T=1):
+ super(KLDivergenceLoss, self).__init__()
+ self.ignore_index=ignore_index
+ self.T = T
+
+ def forward(self, input, target, label):
+ log_input_prob = F.log_softmax(input / self.T, dim=1)
+        target_prob = F.softmax(target / self.T, dim=1)
+        loss = F.kl_div(*flatten_probas(log_input_prob, target_prob, label, ignore=self.ignore_index))
+ return self.T*self.T*loss # balanced
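+
+
+# Minimal usage sketch (random tensors, shapes only; values are illustrative):
+#   import torch
+#   kld = KLDivergenceLoss(T=1)
+#   student = torch.randn(2, 20, 64, 64)          # raw logits from the current model
+#   teacher = torch.randn(2, 20, 64, 64)          # soft targets, e.g. from a previous self-correction cycle
+#   label = torch.randint(0, 20, (2, 64, 64))     # hard labels, only used to mask ignored pixels
+#   loss = kld(student, teacher, label)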
diff --git a/preprocess/humanparsing/utils/lovasz_softmax.py b/preprocess/humanparsing/utils/lovasz_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e444f684c0d9bda9d7c2d54a4e79fac0ddf081
--- /dev/null
+++ b/preprocess/humanparsing/utils/lovasz_softmax.py
@@ -0,0 +1,279 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : lovasz_softmax.py
+@Time : 8/30/19 7:12 PM
+@Desc : Lovasz-Softmax and Jaccard hinge loss in PyTorch
+ Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import print_function, division
+
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+import numpy as np
+from torch import nn
+
+try:
+ from itertools import ifilterfalse
+except ImportError: # py3k
+ from itertools import filterfalse as ifilterfalse
+
+
+def lovasz_grad(gt_sorted):
+ """
+ Computes gradient of the Lovasz extension w.r.t sorted errors
+ See Alg. 1 in paper
+ """
+ p = len(gt_sorted)
+ gts = gt_sorted.sum()
+ intersection = gts - gt_sorted.float().cumsum(0)
+ union = gts + (1 - gt_sorted).float().cumsum(0)
+ jaccard = 1. - intersection / union
+ if p > 1: # cover 1-pixel case
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+ return jaccard
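+
+# Worked example: for gt_sorted = [1, 1, 0] (labels reordered by descending error), gts = 2,
+# intersection = [1, 0, 0], union = [2, 2, 3], jaccard = [0.5, 1.0, 1.0]; after the differencing
+# step the returned gradient is [0.5, 0.5, 0.0], i.e. the weight applied to each sorted error.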
+
+
+def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
+ """
+ IoU for foreground class
+ binary: 1 foreground, 0 background
+ """
+ if not per_image:
+ preds, labels = (preds,), (labels,)
+ ious = []
+ for pred, label in zip(preds, labels):
+ intersection = ((label == 1) & (pred == 1)).sum()
+ union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
+ if not union:
+ iou = EMPTY
+ else:
+ iou = float(intersection) / float(union)
+ ious.append(iou)
+    iou = mean(ious)  # mean across images if per_image
+ return 100 * iou
+
+
+def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
+ """
+ Array of IoU for each (non ignored) class
+ """
+ if not per_image:
+ preds, labels = (preds,), (labels,)
+ ious = []
+ for pred, label in zip(preds, labels):
+ iou = []
+ for i in range(C):
+ if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
+ intersection = ((label == i) & (pred == i)).sum()
+ union = ((label == i) | ((pred == i) & (label != ignore))).sum()
+ if not union:
+ iou.append(EMPTY)
+ else:
+ iou.append(float(intersection) / float(union))
+ ious.append(iou)
+    ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
+ return 100 * np.array(ious)
+
+
+# --------------------------- BINARY LOSSES ---------------------------
+
+
+def lovasz_hinge(logits, labels, per_image=True, ignore=None):
+ """
+ Binary Lovasz hinge loss
+ logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
+ labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+ per_image: compute the loss per image instead of per batch
+ ignore: void class id
+ """
+ if per_image:
+ loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
+ for log, lab in zip(logits, labels))
+ else:
+ loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
+ return loss
+
+
+def lovasz_hinge_flat(logits, labels):
+ """
+ Binary Lovasz hinge loss
+ logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
+ labels: [P] Tensor, binary ground truth labels (0 or 1)
+ ignore: label to ignore
+ """
+ if len(labels) == 0:
+ # only void pixels, the gradients should be 0
+ return logits.sum() * 0.
+ signs = 2. * labels.float() - 1.
+ errors = (1. - logits * Variable(signs))
+ errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+ perm = perm.data
+ gt_sorted = labels[perm]
+ grad = lovasz_grad(gt_sorted)
+ loss = torch.dot(F.relu(errors_sorted), Variable(grad))
+ return loss
+
+
+def flatten_binary_scores(scores, labels, ignore=None):
+ """
+ Flattens predictions in the batch (binary case)
+ Remove labels equal to 'ignore'
+ """
+ scores = scores.view(-1)
+ labels = labels.view(-1)
+ if ignore is None:
+ return scores, labels
+ valid = (labels != ignore)
+ vscores = scores[valid]
+ vlabels = labels[valid]
+ return vscores, vlabels
+
+
+class StableBCELoss(torch.nn.modules.Module):
+ def __init__(self):
+ super(StableBCELoss, self).__init__()
+
+ def forward(self, input, target):
+ neg_abs = - input.abs()
+ loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
+ return loss.mean()
+
+
+def binary_xloss(logits, labels, ignore=None):
+ """
+ Binary Cross entropy loss
+ logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
+ labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
+ ignore: void class id
+ """
+ logits, labels = flatten_binary_scores(logits, labels, ignore)
+ loss = StableBCELoss()(logits, Variable(labels.float()))
+ return loss
+
+
+# --------------------------- MULTICLASS LOSSES ---------------------------
+
+
+def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255, weighted=None):
+ """
+ Multi-class Lovasz-Softmax loss
+ probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
+ Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
+ labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
+ classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+ per_image: compute the loss per image instead of per batch
+ ignore: void class labels
+ """
+ if per_image:
+ loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes, weighted=weighted)
+ for prob, lab in zip(probas, labels))
+ else:
+ loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes, weighted=weighted )
+ return loss
+
+
+def lovasz_softmax_flat(probas, labels, classes='present', weighted=None):
+ """
+ Multi-class Lovasz-Softmax loss
+ probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
+ labels: [P] Tensor, ground truth labels (between 0 and C - 1)
+ classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+ """
+ if probas.numel() == 0:
+ # only void pixels, the gradients should be 0
+ return probas * 0.
+ C = probas.size(1)
+ losses = []
+ class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+ for c in class_to_sum:
+ fg = (labels == c).float() # foreground for class c
+        if (classes == 'present' and fg.sum() == 0):
+ continue
+ if C == 1:
+ if len(classes) > 1:
+ raise ValueError('Sigmoid output possible only with 1 class')
+ class_pred = probas[:, 0]
+ else:
+ class_pred = probas[:, c]
+ errors = (Variable(fg) - class_pred).abs()
+ errors_sorted, perm = torch.sort(errors, 0, descending=True)
+ perm = perm.data
+ fg_sorted = fg[perm]
+ if weighted is not None:
+ losses.append(weighted[c]*torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
+ else:
+ losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
+ return mean(losses)
+
+
+def flatten_probas(probas, labels, ignore=None):
+ """
+ Flattens predictions in the batch
+ """
+ if probas.dim() == 3:
+ # assumes output of a sigmoid layer
+ B, H, W = probas.size()
+ probas = probas.view(B, 1, H, W)
+ B, C, H, W = probas.size()
+ probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
+ labels = labels.view(-1)
+ if ignore is None:
+ return probas, labels
+ valid = (labels != ignore)
+ vprobas = probas[valid.nonzero().squeeze()]
+ vlabels = labels[valid]
+ return vprobas, vlabels
+
+
+def xloss(logits, labels, ignore=None):
+ """
+ Cross entropy loss
+ """
+ return F.cross_entropy(logits, Variable(labels), ignore_index=255)
+
+
+# --------------------------- HELPER FUNCTIONS ---------------------------
+def isnan(x):
+ return x != x
+
+
+def mean(l, ignore_nan=False, empty=0):
+ """
+ nanmean compatible with generators.
+ """
+ l = iter(l)
+ if ignore_nan:
+ l = ifilterfalse(isnan, l)
+ try:
+ n = 1
+ acc = next(l)
+ except StopIteration:
+ if empty == 'raise':
+ raise ValueError('Empty mean')
+ return empty
+ for n, v in enumerate(l, 2):
+ acc += v
+ if n == 1:
+ return acc
+ return acc / n
+
+# --------------------------- Class ---------------------------
+class LovaszSoftmax(nn.Module):
+ def __init__(self, per_image=False, ignore_index=255, weighted=None):
+ super(LovaszSoftmax, self).__init__()
+ self.lovasz_softmax = lovasz_softmax
+ self.per_image = per_image
+ self.ignore_index=ignore_index
+ self.weighted = weighted
+
+ def forward(self, pred, label):
+ pred = F.softmax(pred, dim=1)
+ return self.lovasz_softmax(pred, label, per_image=self.per_image, ignore=self.ignore_index, weighted=self.weighted)
\ No newline at end of file
diff --git a/preprocess/humanparsing/utils/miou.py b/preprocess/humanparsing/utils/miou.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a2cc965a5c0cfd5497c9191906898da31485dd
--- /dev/null
+++ b/preprocess/humanparsing/utils/miou.py
@@ -0,0 +1,155 @@
+import cv2
+import os
+import numpy as np
+
+from collections import OrderedDict
+from PIL import Image as PILImage
+from utils.transforms import transform_parsing
+
+LABELS = ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat', \
+ 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', 'Left-leg',
+ 'Right-leg', 'Left-shoe', 'Right-shoe']
+
+
+# LABELS = ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs']
+
+def get_palette(num_cls):
+ """ Returns the color map for visualizing the segmentation mask.
+ Args:
+ num_cls: Number of classes
+ Returns:
+ The color map
+ """
+
+ n = num_cls
+ palette = [0] * (n * 3)
+ for j in range(0, n):
+ lab = j
+ palette[j * 3 + 0] = 0
+ palette[j * 3 + 1] = 0
+ palette[j * 3 + 2] = 0
+ i = 0
+ while lab:
+ palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
+ palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
+ palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
+ i += 1
+ lab >>= 3
+ return palette
+
+
+def get_confusion_matrix(gt_label, pred_label, num_classes):
+ """
+    Calculate the confusion matrix for the given ground-truth and predicted labels
+    :param gt_label: the ground truth label
+    :param pred_label: the predicted label
+    :param num_classes: the number of classes
+ :return: the confusion matrix
+ """
+ index = (gt_label * num_classes + pred_label).astype('int32')
+ label_count = np.bincount(index)
+ confusion_matrix = np.zeros((num_classes, num_classes))
+
+ for i_label in range(num_classes):
+ for i_pred_label in range(num_classes):
+ cur_index = i_label * num_classes + i_pred_label
+ if cur_index < len(label_count):
+ confusion_matrix[i_label, i_pred_label] = label_count[cur_index]
+
+ return confusion_matrix
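+
+# Worked example with num_classes = 3: a pixel with gt = 1 and pred = 2 is encoded as
+# index = 1 * 3 + 2 = 5, so np.bincount counts it at position 5 and it lands in row 1,
+# column 2 of the matrix (rows index the ground truth, columns the prediction).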
+
+
+def compute_mean_ioU(preds, scales, centers, num_classes, datadir, input_size=[473, 473], dataset='val'):
+ val_file = os.path.join(datadir, dataset + '_id.txt')
+ val_id = [i_id.strip() for i_id in open(val_file)]
+
+ confusion_matrix = np.zeros((num_classes, num_classes))
+
+ for i, pred_out in enumerate(preds):
+ im_name = val_id[i]
+ gt_path = os.path.join(datadir, dataset + '_segmentations', im_name + '.png')
+ gt = np.array(PILImage.open(gt_path))
+ h, w = gt.shape
+ s = scales[i]
+ c = centers[i]
+ pred = transform_parsing(pred_out, c, s, w, h, input_size)
+
+ gt = np.asarray(gt, dtype=np.int32)
+ pred = np.asarray(pred, dtype=np.int32)
+
+ ignore_index = gt != 255
+
+ gt = gt[ignore_index]
+ pred = pred[ignore_index]
+
+ confusion_matrix += get_confusion_matrix(gt, pred, num_classes)
+
+ pos = confusion_matrix.sum(1)
+ res = confusion_matrix.sum(0)
+ tp = np.diag(confusion_matrix)
+
+ pixel_accuracy = (tp.sum() / pos.sum()) * 100
+ mean_accuracy = ((tp / np.maximum(1.0, pos)).mean()) * 100
+ IoU_array = (tp / np.maximum(1.0, pos + res - tp))
+ IoU_array = IoU_array * 100
+ mean_IoU = IoU_array.mean()
+ print('Pixel accuracy: %f \n' % pixel_accuracy)
+ print('Mean accuracy: %f \n' % mean_accuracy)
+ print('Mean IU: %f \n' % mean_IoU)
+ name_value = []
+
+ for i, (label, iou) in enumerate(zip(LABELS, IoU_array)):
+ name_value.append((label, iou))
+
+ name_value.append(('Pixel accuracy', pixel_accuracy))
+ name_value.append(('Mean accuracy', mean_accuracy))
+ name_value.append(('Mean IU', mean_IoU))
+ name_value = OrderedDict(name_value)
+ return name_value
+
+
+def compute_mean_ioU_file(preds_dir, num_classes, datadir, dataset='val'):
+ list_path = os.path.join(datadir, dataset + '_id.txt')
+ val_id = [i_id.strip() for i_id in open(list_path)]
+
+ confusion_matrix = np.zeros((num_classes, num_classes))
+
+ for i, im_name in enumerate(val_id):
+ gt_path = os.path.join(datadir, 'segmentations', im_name + '.png')
+ gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
+
+ pred_path = os.path.join(preds_dir, im_name + '.png')
+ pred = np.asarray(PILImage.open(pred_path))
+
+ gt = np.asarray(gt, dtype=np.int32)
+ pred = np.asarray(pred, dtype=np.int32)
+
+ ignore_index = gt != 255
+
+ gt = gt[ignore_index]
+ pred = pred[ignore_index]
+
+ confusion_matrix += get_confusion_matrix(gt, pred, num_classes)
+
+ pos = confusion_matrix.sum(1)
+ res = confusion_matrix.sum(0)
+ tp = np.diag(confusion_matrix)
+
+ pixel_accuracy = (tp.sum() / pos.sum()) * 100
+ mean_accuracy = ((tp / np.maximum(1.0, pos)).mean()) * 100
+ IoU_array = (tp / np.maximum(1.0, pos + res - tp))
+ IoU_array = IoU_array * 100
+ mean_IoU = IoU_array.mean()
+ print('Pixel accuracy: %f \n' % pixel_accuracy)
+ print('Mean accuracy: %f \n' % mean_accuracy)
+ print('Mean IU: %f \n' % mean_IoU)
+ name_value = []
+
+ for i, (label, iou) in enumerate(zip(LABELS, IoU_array)):
+ name_value.append((label, iou))
+
+ name_value.append(('Pixel accuracy', pixel_accuracy))
+ name_value.append(('Mean accuracy', mean_accuracy))
+ name_value.append(('Mean IU', mean_IoU))
+ name_value = OrderedDict(name_value)
+ return name_value
diff --git a/preprocess/humanparsing/utils/schp.py b/preprocess/humanparsing/utils/schp.py
new file mode 100644
index 0000000000000000000000000000000000000000..f57470452fac8183dc5c17156439416c15bd3265
--- /dev/null
+++ b/preprocess/humanparsing/utils/schp.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : schp.py
+@Time : 4/8/19 2:11 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import os
+import torch
+import modules
+
+def moving_average(net1, net2, alpha=1):
+ for param1, param2 in zip(net1.parameters(), net2.parameters()):
+ param1.data *= (1.0 - alpha)
+ param1.data += param2.data * alpha
+
+
+def _check_bn(module, flag):
+ if issubclass(module.__class__, modules.bn.InPlaceABNSync):
+ flag[0] = True
+
+
+def check_bn(model):
+ flag = [False]
+ model.apply(lambda module: _check_bn(module, flag))
+ return flag[0]
+
+
+def reset_bn(module):
+ if issubclass(module.__class__, modules.bn.InPlaceABNSync):
+ module.running_mean = torch.zeros_like(module.running_mean)
+ module.running_var = torch.ones_like(module.running_var)
+
+
+def _get_momenta(module, momenta):
+ if issubclass(module.__class__, modules.bn.InPlaceABNSync):
+ momenta[module] = module.momentum
+
+
+def _set_momenta(module, momenta):
+ if issubclass(module.__class__, modules.bn.InPlaceABNSync):
+ module.momentum = momenta[module]
+
+
+def bn_re_estimate(loader, model):
+ if not check_bn(model):
+ print('No batch norm layer detected')
+ return
+ model.train()
+ momenta = {}
+ model.apply(reset_bn)
+ model.apply(lambda module: _get_momenta(module, momenta))
+ n = 0
+ for i_iter, batch in enumerate(loader):
+ images, labels, _ = batch
+ b = images.data.size(0)
+ momentum = b / (n + b)
+ for module in momenta.keys():
+ module.momentum = momentum
+ model(images)
+ n += b
+ model.apply(lambda module: _set_momenta(module, momenta))
+
+
+def save_schp_checkpoint(states, is_best_parsing, output_dir, filename='schp_checkpoint.pth.tar'):
+ save_path = os.path.join(output_dir, filename)
+ if os.path.exists(save_path):
+ os.remove(save_path)
+ torch.save(states, save_path)
+ if is_best_parsing and 'state_dict' in states:
+ best_save_path = os.path.join(output_dir, 'model_parsing_best.pth.tar')
+ if os.path.exists(best_save_path):
+ os.remove(best_save_path)
+ torch.save(states, best_save_path)
diff --git a/preprocess/humanparsing/utils/soft_dice_loss.py b/preprocess/humanparsing/utils/soft_dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb5895fd37467d36f213f941d1b01d6d6f7f194c
--- /dev/null
+++ b/preprocess/humanparsing/utils/soft_dice_loss.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : soft_dice_loss.py
+@Time : 8/13/19 5:09 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import print_function, division
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+try:
+ from itertools import ifilterfalse
+except ImportError: # py3k
+ from itertools import filterfalse as ifilterfalse
+
+
+def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6):
+ '''
+ Tversky loss function.
+ probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
+ labels: [P] Tensor, ground truth labels (between 0 and C - 1)
+
+ Same as soft dice loss when alpha=beta=0.5.
+    Same as Jaccard loss when alpha=beta=1.0.
+ See `Tversky loss function for image segmentation using 3D fully convolutional deep networks`
+ https://arxiv.org/pdf/1706.05721.pdf
+ '''
+ C = probas.size(1)
+ losses = []
+ for c in list(range(C)):
+ fg = (labels == c).float()
+ if fg.sum() == 0:
+ continue
+ class_pred = probas[:, c]
+ p0 = class_pred
+ p1 = 1 - class_pred
+ g0 = fg
+ g1 = 1 - fg
+ numerator = torch.sum(p0 * g0)
+ denominator = numerator + alpha * torch.sum(p0 * g1) + beta * torch.sum(p1 * g0)
+ losses.append(1 - ((numerator) / (denominator + epsilon)))
+ return mean(losses)
+
+
+def flatten_probas(probas, labels, ignore=255):
+ """
+ Flattens predictions in the batch
+ """
+ B, C, H, W = probas.size()
+ probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
+ labels = labels.view(-1)
+ if ignore is None:
+ return probas, labels
+ valid = (labels != ignore)
+ vprobas = probas[valid.nonzero().squeeze()]
+ vlabels = labels[valid]
+ return vprobas, vlabels
+
+
+def isnan(x):
+ return x != x
+
+
+def mean(l, ignore_nan=False, empty=0):
+ """
+ nanmean compatible with generators.
+ """
+ l = iter(l)
+ if ignore_nan:
+ l = ifilterfalse(isnan, l)
+ try:
+ n = 1
+ acc = next(l)
+ except StopIteration:
+ if empty == 'raise':
+ raise ValueError('Empty mean')
+ return empty
+ for n, v in enumerate(l, 2):
+ acc += v
+ if n == 1:
+ return acc
+ return acc / n
+
+
+class SoftDiceLoss(nn.Module):
+ def __init__(self, ignore_index=255):
+ super(SoftDiceLoss, self).__init__()
+ self.ignore_index = ignore_index
+
+ def forward(self, pred, label):
+ pred = F.softmax(pred, dim=1)
+ return tversky_loss(*flatten_probas(pred, label, ignore=self.ignore_index), alpha=0.5, beta=0.5)
+
+
+class SoftJaccordLoss(nn.Module):
+ def __init__(self, ignore_index=255):
+ super(SoftJaccordLoss, self).__init__()
+ self.ignore_index = ignore_index
+
+ def forward(self, pred, label):
+ pred = F.softmax(pred, dim=1)
+ return tversky_loss(*flatten_probas(pred, label, ignore=self.ignore_index), alpha=1.0, beta=1.0)
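+
+
+# Minimal usage sketch (random tensors; values are illustrative):
+#   logits = torch.randn(2, 20, 32, 32)
+#   labels = torch.randint(0, 20, (2, 32, 32))
+#   dice = SoftDiceLoss(ignore_index=255)(logits, labels)        # Tversky with alpha = beta = 0.5
+#   jaccard = SoftJaccordLoss(ignore_index=255)(logits, labels)  # Tversky with alpha = beta = 1.0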
diff --git a/preprocess/humanparsing/utils/transforms.py b/preprocess/humanparsing/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..1442a728938ca19fcb4ac21ae6588266df45631c
--- /dev/null
+++ b/preprocess/humanparsing/utils/transforms.py
@@ -0,0 +1,167 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+import torch
+
+class BRG2Tensor_transform(object):
+ def __call__(self, pic):
+ img = torch.from_numpy(pic.transpose((2, 0, 1)))
+ if isinstance(img, torch.ByteTensor):
+ return img.float()
+ else:
+ return img
+
+class BGR2RGB_transform(object):
+ def __call__(self, tensor):
+ return tensor[[2,1,0],:,:]
+
+def flip_back(output_flipped, matched_parts):
+ '''
+    output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
+ '''
+ assert output_flipped.ndim == 4,\
+ 'output_flipped should be [batch_size, num_joints, height, width]'
+
+ output_flipped = output_flipped[:, :, :, ::-1]
+
+ for pair in matched_parts:
+ tmp = output_flipped[:, pair[0], :, :].copy()
+ output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
+ output_flipped[:, pair[1], :, :] = tmp
+
+ return output_flipped
+
+
+def fliplr_joints(joints, joints_vis, width, matched_parts):
+ """
+ flip coords
+ """
+ # Flip horizontal
+ joints[:, 0] = width - joints[:, 0] - 1
+
+ # Change left-right parts
+ for pair in matched_parts:
+ joints[pair[0], :], joints[pair[1], :] = \
+ joints[pair[1], :], joints[pair[0], :].copy()
+ joints_vis[pair[0], :], joints_vis[pair[1], :] = \
+ joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
+
+ return joints*joints_vis, joints_vis
+
+
+def transform_preds(coords, center, scale, input_size):
+ target_coords = np.zeros(coords.shape)
+ trans = get_affine_transform(center, scale, 0, input_size, inv=1)
+ for p in range(coords.shape[0]):
+ target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+ return target_coords
+
+def transform_parsing(pred, center, scale, width, height, input_size):
+
+ trans = get_affine_transform(center, scale, 0, input_size, inv=1)
+ target_pred = cv2.warpAffine(
+ pred,
+ trans,
+ (int(width), int(height)), #(int(width), int(height)),
+ flags=cv2.INTER_NEAREST,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(0))
+
+ return target_pred
+
+def transform_logits(logits, center, scale, width, height, input_size):
+
+ trans = get_affine_transform(center, scale, 0, input_size, inv=1)
+ channel = logits.shape[2]
+ target_logits = []
+ for i in range(channel):
+ target_logit = cv2.warpAffine(
+ logits[:,:,i],
+ trans,
+ (int(width), int(height)), #(int(width), int(height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(0))
+ target_logits.append(target_logit)
+ target_logits = np.stack(target_logits,axis=2)
+
+ return target_logits
+
+
+def get_affine_transform(center,
+ scale,
+ rot,
+ output_size,
+ shift=np.array([0, 0], dtype=np.float32),
+ inv=0):
+ if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+ print(scale)
+ scale = np.array([scale, scale])
+
+ scale_tmp = scale
+
+ src_w = scale_tmp[0]
+ dst_w = output_size[1]
+ dst_h = output_size[0]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = get_dir([0, src_w * -0.5], rot_rad)
+ dst_dir = np.array([0, (dst_w-1) * -0.5], np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ dst = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift
+ src[1, :] = center + src_dir + scale_tmp * shift
+ dst[0, :] = [(dst_w-1) * 0.5, (dst_h-1) * 0.5]
+ dst[1, :] = np.array([(dst_w-1) * 0.5, (dst_h-1) * 0.5]) + dst_dir
+
+ src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+ dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
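+
+# Usage sketch: map a 512 x 512 network-space prediction back to the original image frame
+# (center, scale, width and height are illustrative; `pred` stands for any H x W label map):
+#   center, scale = np.array([320.0, 240.0]), np.array([400.0, 400.0])
+#   trans = get_affine_transform(center, scale, 0, [512, 512], inv=1)
+#   restored = cv2.warpAffine(pred, trans, (640, 480), flags=cv2.INTER_NEAREST)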
+
+
+def affine_transform(pt, t):
+ new_pt = np.array([pt[0], pt[1], 1.]).T
+ new_pt = np.dot(t, new_pt)
+ return new_pt[:2]
+
+
+def get_3rd_point(a, b):
+ direct = a - b
+ return b + np.array([-direct[1], direct[0]], dtype=np.float32)
+
+
+def get_dir(src_point, rot_rad):
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+
+ src_result = [0, 0]
+ src_result[0] = src_point[0] * cs - src_point[1] * sn
+ src_result[1] = src_point[0] * sn + src_point[1] * cs
+
+ return src_result
+
+
+def crop(img, center, scale, output_size, rot=0):
+ trans = get_affine_transform(center, scale, rot, output_size)
+
+ dst_img = cv2.warpAffine(img,
+ trans,
+ (int(output_size[1]), int(output_size[0])),
+ flags=cv2.INTER_LINEAR)
+
+ return dst_img
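+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (added for illustration; the image size, center, scale, and
+    # point below are assumptions, not values used elsewhere in this repository).
+    # A person box is warped into a fixed input resolution with `crop`, and a point
+    # predicted in that crop is mapped back to image coordinates with `transform_preds`.
+    img = np.zeros((512, 384, 3), dtype=np.uint8)
+    center = np.array([192.0, 256.0], dtype=np.float32)
+    scale = np.array([384.0, 512.0], dtype=np.float32)  # source box size in pixels
+    input_size = (473, 473)                             # (height, width) of the network input
+
+    patch = crop(img, center, scale, input_size)
+    pt_in_crop = np.array([[100.0, 200.0]])             # a point in crop coordinates
+    pt_in_image = transform_preds(pt_in_crop, center, scale, input_size)
+    print(patch.shape, pt_in_image)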
diff --git a/preprocess/humanparsing/utils/warmup_scheduler.py b/preprocess/humanparsing/utils/warmup_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2528a9c598d5ee3477d60e2f8591ec37e8afb41d
--- /dev/null
+++ b/preprocess/humanparsing/utils/warmup_scheduler.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+"""
+@Author : Peike Li
+@Contact : peike.li@yahoo.com
+@File : warmup_scheduler.py
+@Time : 3/28/19 2:24 PM
+@Desc :
+@License : This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+"""
+
+import math
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class GradualWarmupScheduler(_LRScheduler):
+ """ Gradually warm-up learning rate with cosine annealing in optimizer.
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+ """
+
+ def __init__(self, optimizer, total_epoch, eta_min=0, warmup_epoch=10, last_epoch=-1):
+ self.total_epoch = total_epoch
+ self.eta_min = eta_min
+ self.warmup_epoch = warmup_epoch
+ super(GradualWarmupScheduler, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch <= self.warmup_epoch:
+ return [self.eta_min + self.last_epoch*(base_lr - self.eta_min)/self.warmup_epoch for base_lr in self.base_lrs]
+ else:
+ return [self.eta_min + (base_lr-self.eta_min)*(1+math.cos(math.pi*(self.last_epoch-self.warmup_epoch)/(self.total_epoch-self.warmup_epoch))) / 2 for base_lr in self.base_lrs]
+
+
+class SGDRScheduler(_LRScheduler):
+ """ Consine annealing with warm up and restarts.
+ Proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`.
+ """
+ def __init__(self, optimizer, total_epoch=150, start_cyclical=100, cyclical_base_lr=7e-4, cyclical_epoch=10, eta_min=0, warmup_epoch=10, last_epoch=-1):
+ self.total_epoch = total_epoch
+ self.start_cyclical = start_cyclical
+ self.cyclical_epoch = cyclical_epoch
+ self.cyclical_base_lr = cyclical_base_lr
+ self.eta_min = eta_min
+ self.warmup_epoch = warmup_epoch
+ super(SGDRScheduler, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch < self.warmup_epoch:
+ return [self.eta_min + self.last_epoch*(base_lr - self.eta_min)/self.warmup_epoch for base_lr in self.base_lrs]
+ elif self.last_epoch < self.start_cyclical:
+ return [self.eta_min + (base_lr-self.eta_min)*(1+math.cos(math.pi*(self.last_epoch-self.warmup_epoch)/(self.start_cyclical-self.warmup_epoch))) / 2 for base_lr in self.base_lrs]
+ else:
+ return [self.eta_min + (self.cyclical_base_lr-self.eta_min)*(1+math.cos(math.pi* ((self.last_epoch-self.start_cyclical)% self.cyclical_epoch)/self.cyclical_epoch)) / 2 for base_lr in self.base_lrs]
+
+
+if __name__ == '__main__':
+ import matplotlib.pyplot as plt
+ import torch
+ model = torch.nn.Linear(10, 2)
+ optimizer = torch.optim.SGD(params=model.parameters(), lr=7e-3, momentum=0.9, weight_decay=5e-4)
+ scheduler_warmup = SGDRScheduler(optimizer, total_epoch=150, eta_min=7e-5, warmup_epoch=10, start_cyclical=100, cyclical_base_lr=3.5e-3, cyclical_epoch=10)
+ lr = []
+ for epoch in range(0,150):
+ scheduler_warmup.step(epoch)
+ lr.append(scheduler_warmup.get_lr())
+ plt.style.use('ggplot')
+ plt.plot(list(range(0,150)), lr)
+ plt.show()
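+
+    # Illustrative addition: the same sweep for GradualWarmupScheduler (linear warm-up
+    # followed by a single cosine decay, no restarts); the hyperparameters are assumptions.
+    optimizer2 = torch.optim.SGD(params=model.parameters(), lr=7e-3, momentum=0.9, weight_decay=5e-4)
+    scheduler_single = GradualWarmupScheduler(optimizer2, total_epoch=150, eta_min=7e-5, warmup_epoch=10)
+    lr_single = []
+    for epoch in range(0, 150):
+        scheduler_single.step(epoch)
+        lr_single.append(scheduler_single.get_lr())
+    plt.plot(list(range(0, 150)), lr_single)
+    plt.show()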
+
diff --git a/promptdresser/data/data_utils.py b/promptdresser/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..79ca238f501ac6dfd592786b456fedb57b3ceb5b
--- /dev/null
+++ b/promptdresser/data/data_utils.py
@@ -0,0 +1,445 @@
+import os, json
+from os.path import join as opj
+
+import numpy as np
+from scipy.ndimage import binary_dilation
+import cv2
+from PIL import Image, ImageDraw
+import torchvision.transforms as T
+
+from ..utils import split_procidx
+
+def remove_postfix(fn):
+    # Keep only the leading id token (e.g. "048392_0" -> "048392"); names with more than
+    # three "_"-separated parts are returned unchanged.
+    splits = fn.split("_")
+    if len(splits) > 3:
+        return fn
+    else:
+        return splits[0]
+
+def get_validation_pairs(
+ data_root_dir,
+ is_paired,
+ data_type,
+ category=None,
+ proc_idx=0,
+ n_proc=1,
+ n_samples=None,
+ prompt_version=None,
+ text_file_postfix=None,
+ use_dc_cloth=False,
+ test_file_postfix=None,
+ **kwargs,
+ ):
+ img_names = []
+ c_names = []
+ full_texts = []
+ clothing_texts = []
+ is_vitonhd = category is None
+ assert data_type in ["train", "test"]
+ if data_type == "train":
+ pair_postfix = "pairs"
+ else:
+ if is_paired:
+ pair_postfix = "pairs"
+ else:
+ pair_postfix = "unpairs"
+ if test_file_postfix is not None:
+ print(f"pair postfix : {test_file_postfix}")
+ pair_postfix = test_file_postfix
+ if not is_vitonhd:
+ assert category in ["upper_body", "lower_body", "dresses"]
+ data_root_dir = opj(data_root_dir, category)
+
+ if not use_dc_cloth:
+ txt_path = opj(data_root_dir, f"{data_type}_{pair_postfix}.txt")
+ else:
+ txt_path = opj(data_root_dir, f"{data_type}_{pair_postfix}_dc.txt")
+
+ with open(txt_path, "r") as f:
+ for line in f.readlines():
+ img_name, c_name = line.strip().split()
+ img_names.append(img_name)
+ c_names.append(c_name)
+ img_names, c_names = map(list, zip(*sorted(zip(img_names, c_names))))
+ img_names = split_procidx(img_names, n_proc, proc_idx)
+ c_names = split_procidx(c_names, n_proc, proc_idx)
+ img_names = img_names[:n_samples]
+ c_names = c_names[:n_samples]
+
+ prompter = Prompter(category="upper_body" if is_vitonhd else category, version=prompt_version, data_type=data_type)
+ if text_file_postfix is not None:
+ if is_vitonhd: # vitonhd
+ textfile_bn = f"{data_type}_{text_file_postfix}"
+ else:
+ textfile_bn = text_file_postfix
+
+ with open(opj(data_root_dir, textfile_bn), "rb") as f:
+ text_dict = json.load(f)
+
+ if text_file_postfix is None:
+ full_texts = ["" for _ in range(len(img_names))]
+ clothing_texts = ["" for _ in range(len(img_names))]
+ else:
+ for img_name, c_name in zip(img_names, c_names):
+ if (category == "upper_body") and ("012143_0" in img_name): # error
+ continue
+
+ img_fn = os.path.splitext(img_name)[0]
+ c_fn = os.path.splitext(c_name)[0]
+ if not is_vitonhd:
+ img_fn = remove_postfix(img_fn)
+ c_fn = remove_postfix(c_fn)
+
+ person_dict = text_dict[img_fn]["person"]
+ clothing_dict = text_dict[c_fn]["clothing"]
+ if "person" in text_dict[c_fn].keys():
+ clothing_person_dict = text_dict[c_fn]["person"]
+ else:
+ clothing_person_dict = person_dict
+ full_txt, clothing_txt = prompter.generate(person_dict, clothing_dict, clothing_person_dict)
+
+ full_texts.append(full_txt)
+ clothing_texts.append(clothing_txt)
+ return img_names, c_names, full_texts, clothing_texts
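+
+# Note (added for clarity; the example rows are assumptions following VITON-HD-style
+# naming, not entries from a dataset): each line of `{data_type}_{pair_postfix}.txt`
+# holds a person image name and a clothing image name separated by whitespace, e.g.
+#   00001_00.jpg 00013_00.jpg
+#   00002_00.jpg 00002_00.jpg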
+
+
+class Prompter:
+ def __init__(self, category, version, data_type="train"):
+ assert category in ["upper_body", "lower_body", "dresses"]
+ self.category = category
+ self.version = version
+ self.data_type = data_type
+ print(f"category : {self.category}, version : {self.version}")
+
+ @staticmethod
+ def create_prompt_v12(person_dict, clothing_dict, clothing_person_dict):
+ clothing_template = "a {category}, {material}, with {sleeve}, {neckline}"
+ full_template = "a {body_shape} {gender} wears {fit_of_clothing}, {category} ({material}), {neckline}, {sleeve_rolling_style}, {tucking_style}. With {hair_length} hair, {pose} with hands {hand_pose}"
+
+ # tucking style
+ if "crop" in clothing_dict["upper cloth length"]:
+ tucking_style = "untucked"
+ else:
+ tucking_style = person_dict["tucking style"]
+
+ if "short" in clothing_dict["sleeve"]:
+ sleeve_rolling_style = "short sleeve"
+ else:
+ sleeve_rolling_style = clothing_person_dict["sleeve rolling style"]
+
+ clothing_prompt = clothing_template.format(
+ category=clothing_dict['upper cloth category'],
+ material=clothing_dict["material"],
+ sleeve=clothing_dict["sleeve"],
+ neckline=clothing_dict["neckline"],
+ ).lower()
+
+ full_prompt = full_template.format(
+ gender=person_dict["gender"],
+ body_shape=person_dict['body shape'],
+ hair_length=person_dict['hair length'],
+ pose=person_dict["pose"],
+ hand_pose=person_dict['hand pose'],
+ fit_of_clothing=person_dict['fit of upper cloth'],
+ sleeve_rolling_style=sleeve_rolling_style,
+ tucking_style=tucking_style,
+ category=clothing_dict['upper cloth category'],
+ material=clothing_dict["material"],
+ neckline=clothing_dict["neckline"],
+ ).lower()
+
+ return full_prompt, clothing_prompt
+
+ def generate(self, person_dict, clothing_dict, clothing_person_dict):
+ full_prompt, clothing_prompt = self.create_prompt_v12(person_dict, clothing_dict, clothing_person_dict)
+ return full_prompt, clothing_prompt
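+
+    # Usage sketch (the attribute values are illustrative assumptions, not dataset annotations):
+    #   >>> p = Prompter(category="upper_body", version="v1.2")
+    #   >>> person = {"gender": "woman", "body shape": "slim", "hair length": "long",
+    #   ...           "pose": "standing", "hand pose": "on hips",
+    #   ...           "fit of upper cloth": "a regular fit", "tucking style": "tucked in"}
+    #   >>> clothing = {"upper cloth category": "t-shirt", "material": "cotton",
+    #   ...             "sleeve": "short sleeves", "neckline": "a crew neckline",
+    #   ...             "upper cloth length": "regular"}
+    #   >>> _, cloth_txt = p.generate(person, clothing, person)
+    #   >>> cloth_txt
+    #   'a t-shirt, cotton, with short sleeves, a crew neckline'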
+
+class IdentityTransform:
+ def __call__(self, x):
+ return x
+
+def get_transform(txt_lst, **kwargs):
+ trans_lst = []
+ for tr in txt_lst:
+ tr = tr.lower()
+ if tr == "hflip":
+ trans_lst.append(T.RandomHorizontalFlip())
+ elif tr == "randomresizedcrop":
+ trans_lst.append(T.RandomResizedCrop((kwargs["img_h"], kwargs["img_w"]), scale=(0.8, 1)))
+ elif tr == "resize":
+ trans_lst.append(T.Resize((kwargs["img_h"], kwargs["img_w"]), antialias=True))
+ elif tr == "randomresizedcrop_dynamic":
+ trans_lst.append(T.RandomResizedCrop((kwargs["img_h"], kwargs["img_w"]), scale=(0.5, 1), ratio=(0.3,2)))
+ elif tr == "randomaffine":
+ trans_lst.append(T.RandomAffine(degrees=0, translate=(0,0), scale=(0.8, 1.2)))
+ elif tr == "randomaffine_dynamic":
+ trans_lst.append(T.RandomAffine(degrees=(-30,30), translate=(0.1, 0.2), scale=(0.8, 1.2), fill=246))
+ elif tr == "rotate":
+ trans_lst.append(T.RandomAffine(degrees=(-30,30), fill=246))
+ elif tr == "colorjitter":
+ trans_lst.append(T.ColorJitter(
+ brightness=(0.8,1.2),
+ contrast=(0.8,1.2),
+ saturation=(0.8,1.2),
+ hue=(-0.1,0.1),
+ ))
+ elif tr == "colorjitter2":
+ trans_lst.append(T.ColorJitter(
+ brightness=0.4,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.1,
+ ))
+ elif tr == "elastictransform":
+ trans_lst.append(T.ElasticTransform())
+ elif tr == "identity":
+ trans_lst.append(IdentityTransform())
+ else:
+ raise NotImplementedError(tr)
+ return trans_lst
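+
+# Usage sketch (sizes are illustrative): compose the returned transforms into a pipeline.
+#   >>> trans = get_transform(["resize", "hflip"], img_h=1024, img_w=768)
+#   >>> pipeline = T.Compose(trans)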
+
+label_map = {
+ "background": 0,
+ "hat": 1,
+ "hair": 2,
+ "sunglasses": 3,
+ "upper_clothes": 4,
+ "skirt": 5,
+ "pants": 6,
+ "dress": 7,
+ "belt": 8,
+ "left_shoe": 9,
+ "right_shoe": 10,
+ "head": 11,
+ "left_leg": 12,
+ "right_leg": 13,
+ "left_arm": 14,
+ "right_arm": 15,
+ "bag": 16,
+ "scarf": 17,
+}
+
+def extend_arm_mask(wrist, elbow, scale):
+ wrist = elbow + scale * (wrist - elbow)
+ return wrist
+
+
+def hole_fill(img):
+ img = np.pad(img[1:-1, 1:-1], pad_width=1, mode='constant', constant_values=0)
+ img_copy = img.copy()
+ mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
+
+ cv2.floodFill(img, mask, (0, 0), 255)
+ img_inverse = cv2.bitwise_not(img)
+ dst = cv2.bitwise_or(img_copy, img_inverse)
+ return dst
+
+
+def refine_mask(mask):
+ contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
+ cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
+ area = []
+ for j in range(len(contours)):
+ a_d = cv2.contourArea(contours[j], True)
+ area.append(abs(a_d))
+ refine_mask = np.zeros_like(mask).astype(np.uint8)
+ if len(area) != 0:
+ i = area.index(max(area))
+ cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
+
+ return refine_mask
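+
+# Toy sketch of the two helpers above (illustrative only): a ring-shaped mask has its
+# interior filled by `hole_fill`, and `refine_mask` keeps only the largest contour as a solid blob.
+#   >>> toy = np.zeros((64, 64), dtype=np.uint8)
+#   >>> cv2.circle(toy, (32, 32), 20, 255, thickness=3)
+#   >>> filled = hole_fill(toy)
+#   >>> largest = refine_mask(filled)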
+
+def get_mask_location(
+ model_type,
+ category,
+ model_parse: Image.Image,
+ keypoint: dict,
+ width=384, height=512,
+ radius=5,
+ version=None,
+ only_cloth=False,
+ only_cloth_arm=False,
+ only_cloth_armneck_with_dilate=False,
+ densepose=None,
+ use_pad=False,
+):
+ if category is None or category == "": category = "upper_body"
+ im_parse = model_parse.resize((width, height), Image.NEAREST)
+ parse_array = np.array(im_parse)
+
+ if model_type == 'hd':
+ arm_width = 60
+ elif model_type == 'dc':
+ arm_width = 45
+ else:
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
+
+ parse_head = (parse_array == 1).astype(np.float32) + \
+ (parse_array == 3).astype(np.float32) + \
+ (parse_array == 11).astype(np.float32)
+
+ parser_mask_fixed = (parse_array == label_map["left_shoe"]).astype(np.float32) + \
+ (parse_array == label_map["right_shoe"]).astype(np.float32) + \
+ (parse_array == label_map["hat"]).astype(np.float32) + \
+ (parse_array == label_map["sunglasses"]).astype(np.float32) + \
+ (parse_array == label_map["bag"]).astype(np.float32)
+
+
+ parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
+
+ arms_left = (parse_array == 14).astype(np.float32)
+ arms_right = (parse_array == 15).astype(np.float32)
+
+ if category == 'dresses':
+ parse_mask = (parse_array == 7).astype(np.float32) + \
+ (parse_array == 4).astype(np.float32) + \
+ (parse_array == 5).astype(np.float32) + \
+ (parse_array == 6).astype(np.float32)
+
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
+
+ elif category == 'upper_body':
+ parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
+ parser_mask_fixed_lower_cloth = (parse_array == label_map["skirt"]).astype(np.float32) + \
+ (parse_array == label_map["pants"]).astype(np.float32)
+
+ parser_mask_fixed_nolower = parser_mask_fixed.copy()
+
+ parser_mask_fixed += parser_mask_fixed_lower_cloth
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
+ elif category == 'lower_body':
+ parse_mask = (parse_array == 6).astype(np.float32) + \
+ (parse_array == 12).astype(np.float32) + \
+ (parse_array == 13).astype(np.float32) + \
+ (parse_array == 5).astype(np.float32)
+
+ parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
+ (parse_array == 14).astype(np.float32) + \
+ (parse_array == 15).astype(np.float32)
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
+ else:
+ raise NotImplementedError(f"category {category}")
+
+ if only_cloth:
+ mask = Image.fromarray(parse_mask.astype(np.uint8) * 255)
+ mask_gray = Image.fromarray(parse_mask.astype(np.uint8) * 127)
+ if use_pad:
+ dil_mask = binary_dilation(mask, iterations=30)
+ mask = Image.fromarray(dil_mask.astype(np.uint8)*255)
+ mask_gray = Image.fromarray(dil_mask.astype(np.uint8)*127)
+ return mask, mask_gray
+ elif only_cloth_arm:
+ parse_mask = parse_mask + (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32)
+ mask = Image.fromarray(parse_mask.astype(np.uint8) * 255)
+ mask_gray = Image.fromarray(parse_mask.astype(np.uint8) * 127)
+ return mask, mask_gray
+ elif only_cloth_armneck_with_dilate:
+ parse_mask = parse_mask + (parse_array == 14).astype(np.float32) + (parse_array == 15).astype(np.float32)
+ parse_mask = cv2.dilate(parse_mask.astype(np.uint8), np.ones((3,3), dtype=np.uint8), iterations=3)
+ parse_mask = np.logical_or(parse_mask, (parse_array == 18).astype(np.float32))
+ mask = Image.fromarray(parse_mask.astype(np.uint8) * 255)
+ mask_gray = Image.fromarray(parse_mask.astype(np.uint8) * 127)
+ return mask, mask_gray
+
+
+ # Load pose points
+ pose_data = keypoint["pose_keypoints_2d"]
+ pose_data = np.array(pose_data)
+ pose_data = pose_data.reshape((-1, 2))
+
+ im_arms_left = Image.new('L', (width, height))
+ im_arms_right = Image.new('L', (width, height))
+ arms_draw_left = ImageDraw.Draw(im_arms_left)
+ arms_draw_right = ImageDraw.Draw(im_arms_right)
+ if category == 'dresses' or category == 'upper_body':
+ shoulder_right = np.multiply(tuple(pose_data[2][:2]), height / 512.0)
+ shoulder_left = np.multiply(tuple(pose_data[5][:2]), height / 512.0)
+ elbow_right = np.multiply(tuple(pose_data[3][:2]), height / 512.0)
+ elbow_left = np.multiply(tuple(pose_data[6][:2]), height / 512.0)
+ wrist_right = np.multiply(tuple(pose_data[4][:2]), height / 512.0)
+ wrist_left = np.multiply(tuple(pose_data[7][:2]), height / 512.0)
+
+ hip_right = np.multiply(tuple(pose_data[8][:2]), height / 512.0)
+ hip_left = np.multiply(tuple(pose_data[11][:2]), height / 512.0)
+
+ ARM_LINE_WIDTH = int(arm_width / 512 * height)
+ size_left = [shoulder_left[0] - ARM_LINE_WIDTH // 2, shoulder_left[1] - ARM_LINE_WIDTH // 2, shoulder_left[0] + ARM_LINE_WIDTH // 2, shoulder_left[1] + ARM_LINE_WIDTH // 2]
+ size_right = [shoulder_right[0] - ARM_LINE_WIDTH // 2, shoulder_right[1] - ARM_LINE_WIDTH // 2, shoulder_right[0] + ARM_LINE_WIDTH // 2,
+ shoulder_right[1] + ARM_LINE_WIDTH // 2]
+
+ if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
+ im_arms_right = arms_right
+ else:
+ wrist_right = extend_arm_mask(wrist_right, elbow_right, 1.2)
+ arms_draw_right.line(np.concatenate((shoulder_right, elbow_right, wrist_right)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
+ arms_draw_right.arc(size_right, 0, 360, 'white', ARM_LINE_WIDTH // 2)
+
+ if wrist_left[0] <= 1. and wrist_left[1] <= 1.:
+ im_arms_left = arms_left
+ else:
+ wrist_left = extend_arm_mask(wrist_left, elbow_left, 1.2)
+ arms_draw_left.line(np.concatenate((wrist_left, elbow_left, shoulder_left)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
+ arms_draw_left.arc(size_left, 0, 360, 'white', ARM_LINE_WIDTH // 2)
+
+
+ if version == "v9":
+ pass
+ else:
+ hands_left = np.logical_and(np.logical_not(im_arms_left), arms_left)
+ hands_right = np.logical_and(np.logical_not(im_arms_right), arms_right)
+ hands = hands_left + hands_right
+ parser_mask_fixed += hands
+ if category == 'upper_body':
+ parser_mask_fixed_nolower += hands
+
+ parser_mask_fixed = np.logical_or(parser_mask_fixed, parse_head)
+ if category == 'upper_body':
+ parser_mask_fixed_nolower = np.logical_or(parser_mask_fixed_nolower, parse_head)
+ parse_mask = cv2.dilate(parse_mask, np.ones((radius, radius), np.uint16), iterations=5)
+ if category == 'dresses' or category == 'upper_body':
+ neck_mask = (parse_array == 18).astype(np.float32)
+ neck_mask = cv2.dilate(neck_mask, np.ones((radius, radius), np.uint16), iterations=1)
+ neck_mask = np.logical_and(neck_mask, np.logical_not(parse_head))
+ parse_mask = np.logical_or(parse_mask, neck_mask)
+
+ if version == 'v7':
+ arm_mask = cv2.dilate(np.logical_or(im_arms_left, im_arms_right).astype('float32'), np.ones((5, 5), np.uint16), iterations=15)
+ else:
+ arm_mask = cv2.dilate(np.logical_or(im_arms_left, im_arms_right).astype('float32'), np.ones((5, 5), np.uint16), iterations=4)
+ parse_mask += np.logical_or(parse_mask, arm_mask)
+
+ parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
+
+ parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
+ inpaint_mask = 1 - parse_mask_total
+ img = np.where(inpaint_mask, 255, 0)
+ dst = hole_fill(img.astype(np.uint8))
+ dst = refine_mask(dst)
+ if version not in ["v5", "v6"]:
+ inpaint_mask = dst / 255 * 1
+
+ if version is None or "official" in version:
+ pass
+ elif version == 'v9':
+ inpaint_mask = dst
+ hip_x1, hip_y1 = hip_left.astype(np.int64)
+ hip_x2, hip_y2 = hip_right.astype(np.int64)
+ inpaint_mask[hip_y1, hip_x1] = 255
+ inpaint_mask[hip_y2, hip_x2] = 255
+ coords = np.column_stack(np.where(inpaint_mask == 255))
+ y_min, x_min = coords.min(axis=0)
+ y_max, x_max = coords.max(axis=0)
+
+ inpaint_mask[y_min:y_max+1, x_min:x_max+1] = 255
+ inpaint_mask = np.logical_and(inpaint_mask, np.logical_not(parser_mask_fixed_nolower))
+ inpaint_mask = hole_fill(inpaint_mask.astype(np.uint8))
+ inpaint_mask = np.logical_and(inpaint_mask, np.logical_not(parse_head))
+ inpaint_mask = inpaint_mask.astype(np.uint8)
+ else:
+ raise NotImplementedError(f"upper body version {version}")
+
+
+ mask = Image.fromarray(inpaint_mask.astype(np.uint8) * 255)
+ mask_gray = Image.fromarray(inpaint_mask.astype(np.uint8) * 127)
+
+ return mask, mask_gray
\ No newline at end of file
diff --git a/promptdresser/models/attention_processor.py b/promptdresser/models/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b715741a6e503bb8c1769309ccfe30c0d492799f
--- /dev/null
+++ b/promptdresser/models/attention_processor.py
@@ -0,0 +1,476 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.utils import USE_PEFT_BACKEND
+
+class AttnProcessor(nn.Module):
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ *args,
+ **kwargs,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+ r"""
+    Attention processor for IP-Adapter.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
+            The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ *args,
+ **kwargs,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ self.attn_map = ip_attention_probs
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
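+
+    # Usage note (an assumption about the calling convention, not code from this repository):
+    # `encoder_hidden_states` is expected to be the text context with `num_tokens` projected
+    # image-prompt tokens appended along the sequence axis, e.g.
+    #   >>> text_embeds = torch.randn(2, 77, 2048)   # text context (dimensions are illustrative)
+    #   >>> image_tokens = torch.randn(2, 4, 2048)   # output of an image projection model
+    #   >>> encoder_hidden_states = torch.cat([text_embeds, image_tokens], dim=1)
+    # The processor splits the last `num_tokens` entries back off, attends to them with its own
+    # `to_k_ip` / `to_v_ip` projections, and adds the result scaled by `self.scale`.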
+
+
+class IPAttnProcessor2_0(torch.nn.Module):
+ r"""
+    Attention processor for IP-Adapter for PyTorch 2.0.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
+            The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ *args,
+ **kwargs,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ with torch.no_grad():
+            # softmax is taken over the attention scores (query @ key^T), not over the transposed key
+            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+## for controlnet
+class CNAttnProcessor:
+ r"""
+    Attention processor for ControlNet branches: any appended image-prompt tokens are dropped so that only the text context is attended to.
+ """
+
+ def __init__(self, num_tokens=4):
+ self.num_tokens = num_tokens
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs,):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CNAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self, num_tokens=4):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.num_tokens = num_tokens
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ *args,
+ **kwargs,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
\ No newline at end of file
diff --git a/promptdresser/models/cloth_encoder.py b/promptdresser/models/cloth_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7865237534188c567d7c54d77e4b108576286068
--- /dev/null
+++ b/promptdresser/models/cloth_encoder.py
@@ -0,0 +1,1285 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from os.path import join as opj
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+import json
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ GLIGENTextBoundingboxProjection,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ get_down_block,
+ get_up_block,
+)
+
+from diffusers.models.lora import LoRALinearLayer
+from promptdresser.utils import zero_rank_print_
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+class Identity(torch.nn.Module):
+ r"""A placeholder identity operator that is argument-insensitive.
+
+ Args:
+ args: any argument (unused)
+ kwargs: any keyword argument (unused)
+
+ Shape:
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.
+
+ Examples::
+
+        >>> m = Identity(54, unused_argument1=0.1, unused_argument2=False)
+ >>> input = torch.randn(128, 20)
+ >>> output = m(input)
+ >>> print(output.size())
+ torch.Size([128, 20])
+
+ """
+ def __init__(self, scale=None, *args, **kwargs) -> None:
+ super(Identity, self).__init__()
+
+ def forward(self, input, *args, **kwargs):
+ return input
+
+
+class _LoRACompatibleLinear(nn.Module):
+    """
+    A pass-through stub mimicking the LoRA-compatible linear interface: it stores an
+    optional `lora_layer` but applies no projection and returns its input unchanged.
+    """
+
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+        super().__init__()  # nn.Module's constructor takes no arguments; extra args are ignored
+ self.lora_layer = lora_layer
+
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+ self.lora_layer = lora_layer
+
+ def _fuse_lora(self):
+ pass
+
+ def _unfuse_lora(self):
+ pass
+
+ def forward(self, hidden_states, scale=None, lora_scale: int = 1):
+ return hidden_states
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class ClothEncoder(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+        # SDXL-style UNets use 3 up blocks, while SD 1.5-style UNets use 4
+        self.use_sd15 = len(up_block_types) != 3
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
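+        # Depending on the backbone layout, the tail of the decoder is pruned below: the last
+        # transformer block's self-attention projections are swapped for pass-through linears and
+        # the remaining cross-attention / feed-forward / upsampling modules for Identity, so those
+        # layers become no-ops at runtime (presumably because their outputs are not consumed when
+        # this UNet is used as a reference feature branch).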
+ if not self.use_sd15:
+ self.up_blocks[1].attentions[2].transformer_blocks[1].attn1.to_q = _LoRACompatibleLinear()
+ self.up_blocks[1].attentions[2].transformer_blocks[1].attn1.to_k = _LoRACompatibleLinear()
+ self.up_blocks[1].attentions[2].transformer_blocks[1].attn1.to_v = _LoRACompatibleLinear()
+ self.up_blocks[1].attentions[2].transformer_blocks[1].attn1.to_out = nn.ModuleList([Identity(), Identity()])
+ self.up_blocks[1].attentions[2].transformer_blocks[1].norm2 = Identity()
+ self.up_blocks[1].attentions[2].transformer_blocks[1].attn2 = None
+ self.up_blocks[1].attentions[2].transformer_blocks[1].norm3 = Identity()
+ self.up_blocks[1].attentions[2].transformer_blocks[1].ff = Identity()
+ self.up_blocks[1].attentions[2].proj_out = Identity()
+ self.up_blocks[1].upsamplers[0] = Identity()
+ self.up_blocks[2] = Identity()
+ else:
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
+ self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
+ self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
+ self.up_blocks[3].attentions[2].proj_out = Identity()
+
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ def add_c_text_proj_layer(self):
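+        # Flag every BasicTransformerBlock in the mid and up blocks so the mutual self-attention
+        # hooks (see mutual_self_attention.py) also cache and project the clothing text embeddings
+        # alongside the reference hidden states.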
+ from .mutual_self_attention import torch_dfs
+ from diffusers.models.attention import BasicTransformerBlock
+        attn_modules = [module for module in (torch_dfs(self.mid_block) + torch_dfs(self.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module.c_text_proj_layer = True
+
+ self.config.add_c_text_proj_layer = True
+ zero_rank_print_("clothing encoder proj layer check")
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
+                that are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added to the UNet long skip connections from down blocks to up blocks,
+                for example from ControlNet side model(s).
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                Additional residual to be added to the UNet mid block output, for example from a ControlNet side
+                model.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side
+                model(s).
+            encoder_attention_mask (`torch.Tensor`):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
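+            # The SD 1.5 layout keeps the full decoder, so skip connections are popped for every up
+            # block; in the SDXL layout the pruned (Identity) final block has no residuals to consume.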
+ if self.use_sd15:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ else:
+ if not is_final_block:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ if self.use_sd15:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ scale=lora_scale,
+ )
+
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/promptdresser/models/combined_model.py b/promptdresser/models/combined_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af50d0bc38dab92691f6b12f541aad1310228e7
--- /dev/null
+++ b/promptdresser/models/combined_model.py
@@ -0,0 +1,11 @@
+import torch.nn as nn
+
+class CombinedModel(nn.Module):
+ def __init__(self, unet, cloth_encoder):
+ super().__init__()
+ self.unet = unet
+ self.cloth_encoder = cloth_encoder
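+
+    # Thin wrapper that simply groups the try-on UNet and the cloth encoder (e.g. so they can be
+    # checkpointed or optimized together); the forward pass is an intentional no-op placeholder.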
+ def forward(self, x):
+ return x
+
+
\ No newline at end of file
diff --git a/promptdresser/models/mutual_self_attention.py b/promptdresser/models/mutual_self_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7bc18050b000f3efcdda09cc282f026ce9b1e30
--- /dev/null
+++ b/promptdresser/models/mutual_self_attention.py
@@ -0,0 +1,608 @@
+# Copyright 2023 ByteDance and/or its affiliates.
+#
+# Copyright (2023) MagicAnimate Authors
+#
+# ByteDance, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from ByteDance or
+# its affiliates is strictly prohibited.
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import random
+
+from typing import Any, Dict, Optional, Tuple
+
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
+
+def torch_dfs(model: torch.nn.Module):
+ result = [model]
+ for child in model.children():
+ result += torch_dfs(child)
+ return result
+
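+# ReferenceAttentionControl patches the UNet's BasicTransformerBlock.forward (and, when
+# reference_adain is enabled, the down/mid/up block forwards) so that a "write" UNet caches its
+# self-attention hidden states (and feature statistics) in per-module banks, while a "read" UNet
+# concatenates those cached states into its own self-attention. update() copies the banks from the
+# writer to the reader and clear() empties them again.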
+class ReferenceAttentionControl:
+ def __init__(self,
+ unet=None,
+ mode="write",
+ do_classifier_free_guidance=False,
+ attention_auto_machine_weight = float('inf'),
+ gn_auto_machine_weight = 1.0,
+ style_fidelity = 1.0,
+ reference_attn=True,
+ reference_adain=False,
+ fusion_blocks="full",
+ batch_size=1,
+ is_train=False,
+ is_second_stage=False,
+ use_jointcond=False,
+ ) -> None:
+ # 10. Modify self attention and group norm
+ self.unet = unet
+ assert mode in ["read", "write"] #, "write_control"]
+ assert fusion_blocks in ["midup", "full"] #, "downmid"]
+ self.reference_attn = reference_attn
+ self.reference_adain = reference_adain
+ self.fusion_blocks = fusion_blocks
+ self.batch_size = batch_size
+ self.is_train = is_train
+ self.is_second_stage=is_second_stage
+ self.add_clothing_text = getattr(unet, "add_clothing_text", False)
+ self.do_classifier_free_guidance = do_classifier_free_guidance
+ self.use_jointcond = use_jointcond
+
+ self.register_reference_hooks(
+ mode,
+ do_classifier_free_guidance,
+ attention_auto_machine_weight,
+ gn_auto_machine_weight,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ fusion_blocks,
+ batch_size=batch_size,
+ is_train=is_train,
+ is_second_stage=is_second_stage,
+ add_clothing_text=self.add_clothing_text,
+ use_jointcond=self.use_jointcond
+ )
+
+
+
+ def register_reference_hooks(
+ self,
+ mode,
+ do_classifier_free_guidance,
+ attention_auto_machine_weight,
+ gn_auto_machine_weight,
+ style_fidelity,
+ reference_attn,
+ reference_adain,
+ dtype=torch.float16,
+ batch_size=1,
+ num_images_per_prompt=1,
+ device=torch.device("cpu"),
+ fusion_blocks='full',
+ is_train=False,
+ is_second_stage=False,
+ add_clothing_text=False,
+ use_jointcond=False,
+ ):
+ MODE = mode
+ do_classifier_free_guidance = do_classifier_free_guidance
+ attention_auto_machine_weight = attention_auto_machine_weight
+ gn_auto_machine_weight = gn_auto_machine_weight
+ style_fidelity = style_fidelity
+ reference_attn = reference_attn
+ reference_adain = reference_adain
+ fusion_blocks = fusion_blocks
+ num_images_per_prompt = num_images_per_prompt
+ dtype=dtype
+ batch_size=batch_size
+ is_train=is_train
+ is_second_stage=is_second_stage
+ add_clothing_text=add_clothing_text
+ use_jointcond=use_jointcond
+
+ if do_classifier_free_guidance:
+ uc_mask = (
+ torch.Tensor([1] * batch_size * num_images_per_prompt * 16 + [0] * batch_size * num_images_per_prompt * 16)
+ .to(device)
+ .bool()
+ )
+ else:
+ uc_mask = (
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
+ .to(device)
+ .bool()
+ )
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.clone())
+
+
+ if getattr(self, "c_text_proj_layer", None):
+ self.bank.append(encoder_hidden_states.clone())
+
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if getattr(self, "c_text_proj_layer", None):
+ c_text = self.bank[-1]
+ if c_text.shape[-1] == 2048:
+ self.bank[-1] = self.c_text_proj_layer(c_text)
+
+ if getattr(self, "c_attn1", None):
+ hidden_states_uc_p = self.attn1(norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states], dim=1),
+ attention_mask=attention_mask)
+ hidden_states_uc_c = self.c_attn1(norm_hidden_states,
+ encoder_hidden_states=torch.cat(self.bank, dim=1),
+ attention_mask=attention_mask)
+
+ hidden_states_uc = hidden_states_uc_p + hidden_states_uc_c * self.gate_val.to(dtype=hidden_states_uc_c.dtype).tanh() + hidden_states
+
+ else:
+ hidden_states_uc = self.attn1(norm_hidden_states,
+ encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),
+ attention_mask=attention_mask,
+ ) + hidden_states
+
+ hidden_states_c = hidden_states_uc.clone()
+                    if use_jointcond:  # with joint conditioning, a zero tensor is used at the input, so no CFG masking is needed here
+ hidden_states = hidden_states_uc
+ else:
+ if is_train and not is_second_stage:
+ _uc_mask = self.cfg_uc_mask.clone()
+ assert hidden_states.shape[0] == _uc_mask.shape[0], f"in training, cfg_uc_mask is used to drop the reference images so that batch_size must be equal : {hidden_states.shape[0]} vs {_uc_mask.shape[0]}"
+ else:
+ _uc_mask = uc_mask.clone()
+
+
+ if do_classifier_free_guidance and torch.any(_uc_mask) and not is_second_stage:
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
+ _uc_mask = (
+ torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
+ .to(device)
+ .bool()
+ )
+
+ hidden_states_c[_uc_mask] = self.attn1(
+ norm_hidden_states[_uc_mask],
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
+ attention_mask=attention_mask
+ ) + hidden_states[_uc_mask]
+ hidden_states = hidden_states_c.clone()
+
+ # self.bank.clear()
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
+
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
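+            # AdaIN-style statistic transfer for the mid block: "write" mode caches the per-channel
+            # mean/variance of the reference features, "read" mode re-normalizes the current features
+            # to those cached statistics (blended with style_fidelity for the conditional branch).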
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask].to(hidden_states_c.dtype)
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if self.reference_attn:
+ if self.fusion_blocks == "midup":
+                attn_modules = [module for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+            elif self.fusion_blocks == "full":
+                attn_modules = [module for module in (torch_dfs(self.unet.down_blocks) + torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ if self.is_train:
+                cfg_uc_mask = torch.BoolTensor([random.random() < 0.1 for _ in range(self.batch_size)])
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if self.is_train:
+ module.cfg_uc_mask = cfg_uc_mask.clone()
+
+
+ if self.reference_adain:
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ def update(self, writer, dtype=torch.float16):
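+        # Copy the cached reference features (attention banks and, when reference_adain is used, the
+        # mean/variance banks) from the writer UNet's modules into this reader's modules.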
+ if self.reference_attn:
+ if self.fusion_blocks == "midup":
+ reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ elif self.fusion_blocks == "full":
+ reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, BasicTransformerBlock)]
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+ writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ if self.is_train:
+                cfg_uc_mask = torch.BoolTensor([random.random() < 0.1 for _ in range(self.batch_size)])
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
+ r.bank = [v.clone().to(dtype) for v in w.bank]
+ # w.bank.clear()
+ if self.is_train:
+ r.cfg_uc_mask = cfg_uc_mask.clone()
+
+
+ if self.reference_adain:
+ reader_gn_modules = [self.unet.mid_block]
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ reader_gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ reader_gn_modules.append(module)
+
+ writer_gn_modules = [writer.unet.mid_block]
+
+ down_blocks = writer.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ writer_gn_modules.append(module)
+
+ up_blocks = writer.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ writer_gn_modules.append(module)
+
+ for r, w in zip(reader_gn_modules, writer_gn_modules):
+ if len(w.mean_bank) > 0 and isinstance(w.mean_bank[0], list):
+ r.mean_bank = [[v.clone().to(dtype) for v in vl] for vl in w.mean_bank]
+ r.var_bank = [[v.clone().to(dtype) for v in vl] for vl in w.var_bank]
+ else:
+ r.mean_bank = [v.clone().to(dtype) for v in w.mean_bank]
+ r.var_bank = [v.clone().to(dtype) for v in w.var_bank]
+
+ def clear(self):
+ if self.reference_attn:
+ if self.fusion_blocks == "midup":
+                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+            elif self.fusion_blocks == "full":
+                reader_attn_modules = [module for module in (torch_dfs(self.unet.down_blocks) + torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for r in reader_attn_modules:
+ r.bank.clear()
+ if self.reference_adain:
+ reader_gn_modules = [self.unet.mid_block]
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ reader_gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ reader_gn_modules.append(module)
+
+ for r in reader_gn_modules:
+ r.mean_bank.clear()
+ r.var_bank.clear()
+
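+# Illustrative call order (a sketch, not from the original file; variable names are hypothetical):
+#   writer_unet(...)                  # hacked forwards append reference features into each module's bank
+#   reader_control.update(writer)     # copy the writer's banks into the reader UNet's modules
+#   reader_unet(...)                  # the reader consumes the copied banks during its own forward
+#   reader_control.clear()            # empty the banks before the next denoising step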
\ No newline at end of file
diff --git a/promptdresser/models/pose_encoder.py b/promptdresser/models/pose_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0d64ab37ac0af34561f70cfb488730689964830
--- /dev/null
+++ b/promptdresser/models/pose_encoder.py
@@ -0,0 +1,52 @@
+from typing import Tuple
+import torch.nn as nn
+from torch.nn import functional as F
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
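+
+
+# Minimal usage sketch (illustrative only; shapes assume a 1024x768 RGB pose map, not taken from this repo):
+#   enc = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)
+#   feat = enc(torch.randn(1, 3, 1024, 768))   # -> (1, 320, 128, 96): three stride-2 convs downsample by 8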
\ No newline at end of file
diff --git a/promptdresser/models/unet.py b/promptdresser/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..60aed31c95e785a4f046c356a9c4bac00b352ba0
--- /dev/null
+++ b/promptdresser/models/unet.py
@@ -0,0 +1,1298 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ GLIGENTextBoundingboxProjection,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ UNetMidBlock2DSimpleCrossAttn,
+ get_down_block,
+ get_up_block,
+)
+
+from promptdresser.utils import zero_rank_print_
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# for clothing text proj layer
+import torch.nn.functional as F
+class MLP(nn.Module):
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ # self.dummy = nn.Parameter(torch.Tensor([0.]))
+ # self.dummy.requires_grad = False # trick to check device and dtype
+
+ def forward(self, x):
+ # x = x.to(device=self.dummy.device, dtype=self.dummy.dtype)
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, the normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: int = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim[-1],
+ attention_head_dim=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type is None:
+ self.mid_block = None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ def expand_first_conv(self, additional_in_channel=4):
+ old_conv = self.conv_in
+ old_in_channels = old_conv.in_channels
+ new_conv = nn.Conv2d(old_in_channels+additional_in_channel, old_conv.out_channels, kernel_size=old_conv.kernel_size, padding=old_conv.padding)
+ for key, param in new_conv.named_parameters():
+ nn.init.zeros_(param)
+ new_conv.weight.data[:, :old_in_channels] = old_conv.weight.data
+ new_conv.bias.data = old_conv.bias.data
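+        # Only the weights for the newly added input channels stay zero, so right after expansion the
+        # conv reproduces the pretrained layer's output and ignores the extra conditioning channels
+        # until training updates them.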
+ self.conv_in = new_conv
+ self.config.in_channels = old_in_channels + additional_in_channel
+ zero_rank_print_("expand first conv of UNet")
+
+ def add_c_text_proj_layer(self):
+ from .mutual_self_attention import torch_dfs
+ from diffusers.models.attention import BasicTransformerBlock
+        attn_modules = [module for module in (torch_dfs(self.mid_block) + torch_dfs(self.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ # module.c_text_proj_layer = MLP(input_dim=2048, hidden_dim=768, output_dim=module.attn1.inner_dim, num_layers=2)
+ module.c_text_proj_layer = nn.Linear(2048, module.attn1.inner_dim, bias=False)
+
+ self.use_c_text_proj_layer = True
+ zero_rank_print_("unet clothing text projection layers are added")
+
+ def add_c_self_attention(self, fusion_blocks):
+ from copy import deepcopy
+ from .mutual_self_attention import torch_dfs
+ from diffusers.models.attention import BasicTransformerBlock
+
+ if fusion_blocks == "midup":
+            attn_modules = [module for module in (torch_dfs(self.mid_block) + torch_dfs(self.up_blocks)) if isinstance(module, BasicTransformerBlock)]
+        elif fusion_blocks == "full":
+            attn_modules = [module for module in torch_dfs(self) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
+
+ for i, module in enumerate(attn_modules):
+ module.c_attn1 = deepcopy(module.attn1)
+ module.gate_val = nn.Parameter(torch.Tensor([0.]), requires_grad=True)
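+            # gate_val is initialized to 0, so the duplicated clothing self-attention branch should
+            # contribute nothing at the start of training (assuming the hacked block forward scales
+            # the c_attn1 output by this gate).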
+
+ self.use_c_self_attention = True
+ zero_rank_print_("unet clothing self attention layers are added")
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor_timestep(self, timestep):
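+        # Recursively forwards the current denoising timestep to any submodule that implements
+        # `set_timestep` (e.g. custom attention processors used elsewhere in this repo).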
+ def fn_recursive_attn_processor(module: torch.nn.Module, timestep):
+ if hasattr(module, "set_timestep"):
+ module.set_timestep(timestep)
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(child, timestep)
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(module, timestep=timestep)
+
+ def set_attn_processor_posetemp_map(self, posetemp_map):
+ def fn_recursive_attn_processor(module: torch.nn.Module, posetemp_map):
+ if hasattr(module, "set_posetemp_map"):
+ module.set_posetemp_map(posetemp_map)
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(child, posetemp_map)
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(module, posetemp_map=posetemp_map)
+
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1, s2, b1, b2):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+
+ pose_encoder_input=None,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+ added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
+ example from ControlNet side model(s)
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb, hint = self.add_embedding(image_embs, hint)
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
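+ # Descriptive note (added): project-specific extension to the upstream UNet forward. When a
+ # pose encoder is attached and no separate controlnet is used, the encoded pose map is added
+ # to the stem features right after `conv_in`.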
+ if hasattr(self, "pose_encoder") and not hasattr(self, "controlnet"):
+ pose_feature = self.pose_encoder(pose_encoder_input)
+ sample = sample + pose_feature
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ self.set_attn_processor_timestep(timestep=timestep)
+
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ scale=lora_scale,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/promptdresser/pipelines/sdxl.py b/promptdresser/pipelines/sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9e052b24219b28839acbaacc2d2635729654d4f
--- /dev/null
+++ b/promptdresser/pipelines/sdxl.py
@@ -0,0 +1,2795 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import os
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+from promptdresser.models.mutual_self_attention import ReferenceAttentionControl
+from promptdresser.data.data_utils import get_mask_location
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = load_image(img_url).convert("RGB")
+ >>> mask_image = load_image(mask_url).convert("RGB")
+
+ >>> prompt = "A majestic tiger sitting on a bench"
+ >>> image = pipe(
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+ ... ).images[0]
+ ```
+"""
+
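+# Helper (descriptive comments added): previews the predicted clean image x0 at an
+# intermediate denoising step. `scheduler.step` is called only to read
+# `pred_original_sample`, so `_step_index` is rewound afterwards to leave the scheduler
+# state untouched; the latent prediction is then decoded with the VAE and post-processed
+# into the requested output type.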
+def get_pred_x0(scheduler, noise_pred, t, latents, extra_step_kwargs, vae, image_processor, output_type):
+ pred_z0 = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)["pred_original_sample"]
+ if hasattr(scheduler, "_step_index"):
+ scheduler._step_index -= 1
+ pred_x0 = vae.decode(pred_z0 / vae.config.scaling_factor, return_dict=False)[0]
+ pred_x0 = image_processor.postprocess(pred_x0, output_type=output_type)[0]
+ return pred_x0
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
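+# Illustrative usage of `rescale_noise_cfg` (assumed variable names, mirroring the standard
+# classifier-free-guidance combination used later in this pipeline):
+#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+#   if guidance_rescale > 0.0:
+#       noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)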
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+ (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ # TODO(Yiyi): need to clean this up later
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+ deprecate(
+ "prepare_mask_and_masked_image",
+ "0.30.0",
+ deprecation_message,
+ )
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Either a single batched mask with no channel dim, or a single unbatched mask with a channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are in latent space and thus can't be masked;
+ # set masked_image to None. We assume that the checkpoint
+ # is not an inpainting checkpoint.
+ # TODO(Yiyi): need to clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
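+# Illustrative usage of `prepare_mask_and_masked_image` (assumed PIL inputs; the helper is
+# deprecated in favor of `VaeImageProcessor.preprocess`):
+#   mask, masked_image, init_image = prepare_mask_and_masked_image(
+#       person_pil, mask_pil, height=1024, width=768, return_image=True
+#   )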
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
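+# Illustrative usage of `retrieve_latents` (assumed names): encode an image tensor in [-1, 1]
+# and draw a latent sample from the resulting posterior:
+#   posterior = vae.encode(pixel_values)
+#   latents = retrieve_latents(posterior, generator=generator) * vae.config.scaling_factor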
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
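+# Illustrative usage of `retrieve_timesteps` (assumed values): pass either a step count or an
+# explicit schedule, never both:
+#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device=device)
+#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249], device=device)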
+
+class PromptDresser(
+ DiffusionPipeline,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ FromSingleFileMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "add_neg_time_ids",
+ "mask",
+ "masked_image_latents",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.pose_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ self.use_posetemp = False
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
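+ # Descriptive note (added): returns CLIP image embeddings for IP-Adapter style conditioning,
+ # plus unconditional counterparts (zero embeddings, or hidden states of a zeroed input) that
+ # are typically concatenated with the positive embeddings for classifier-free guidance.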
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+
+ save_eos=False,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ # if save_eos:
+ # self.eos_loc_index = untruncated_ids.shape[1] - 1
+ # print(f"save eos : {self.eos_loc_index}")
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1., initialize the latents to noise; otherwise, initialize them to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
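+ # Descriptive note (added): the returned tuple is (latents,), optionally followed by the
+ # sampled noise and/or the encoded image latents depending on the flags below.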
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, use_jointcond
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
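+ # Interpretive note (added, not from the upstream source): with `use_jointcond` the mask is
+ # tripled rather than doubled, presumably to match a three-way guidance batch
+ # (unconditional / image-conditioned / fully conditioned) used for joint image-and-text CFG.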
+ if use_jointcond:
+ mask = torch.cat([mask] * 3)
+ else:
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ if use_jointcond:
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 3)
+ )
+ else:
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
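+ # Worked example (illustrative): with num_inference_steps=50 and strength=0.8,
+ # init_timestep=40 and t_start=10, so only the last 40 of the 50 scheduler timesteps are
+ # used and denoising starts from a partially noised image.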
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between the 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ timesteps = timesteps[-num_inference_steps:]
+ return timesteps, num_inference_steps
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
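+ # Illustrative note (added): for SDXL the micro-conditioning vector concatenates
+ # original_size + crops_coords_top_left + target_size, e.g. (1024, 768) + (0, 0) + (1024, 768)
+ # -> [1024, 768, 0, 0, 1024, 768]; when `requires_aesthetics_score` is set (refiner-style
+ # models), target_size is replaced by the aesthetic score.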
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values for which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def denoising_start(self):
+ return self._denoising_start
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+
+ cloth_encoder=None,
+ cloth_encoder_image=None,
+ prompt_clothing=None,
+ prompt_embeds_clothing: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds_clothing: Optional[torch.FloatTensor] = None,
+ pose_image=None,
+ use_jointcond=False,
+ guidance_scale_img=4.5,
+ guidance_scale_text=7.5,
+ interm_cloth_start_ratio=0.5,
+ detach_cloth_encoder=False,
+
+ category="upper_body",
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+ self._use_jointcond = use_jointcond
+ self._guidance_scale_img = guidance_scale_img
+ self._guidance_scale_text = guidance_scale_text
+
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds_clothing,
+ _,
+ pooled_prompt_embeds_clothing,
+ _,
+ ) = self.encode_prompt(
+ prompt=prompt_clothing,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds_clothing,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds_clothing,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ return_image_latents = True
+
+ add_noise = True if self.denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ cloth_encoder_init_image = self.image_processor.preprocess(cloth_encoder_image, height=height, width=width)
+ cloth_encoder_init_image = cloth_encoder_init_image.to(dtype=torch.float32)
+ cloth_encoder_latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents=None,
+ image=cloth_encoder_init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=True,
+ )
+ _, _, cloth_encoder_latents = cloth_encoder_latents_outputs
+
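+ # Build the garment latents in the guidance batch order. With joint conditioning the
+ # batch is [unconditional, image-conditioned, image+text-conditioned], so an all-zero
+ # "null" garment is encoded for the unconditional branch and the real garment latents
+ # are repeated for the two conditioned branches.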
+ if self._use_jointcond:
+ cloth_encoder_null_latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents=None,
+ image=torch.zeros_like(cloth_encoder_init_image),
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=True,
+ )
+ _, _, cloth_encoder_null_latents = cloth_encoder_null_latents_outputs
+ cloth_encoder_latents = torch.cat([cloth_encoder_null_latents] + [cloth_encoder_latents] * 2)
+ else:
+ cloth_encoder_latents = torch.cat([cloth_encoder_latents] * 2) if self.do_classifier_free_guidance else cloth_encoder_latents
+
+ pose_init_image = self.image_processor.preprocess(pose_image, height=height, width=width)
+ pose_encoder_input = None
+ pose_init_image = pose_init_image.to(dtype=torch.float32)
+ pose_latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=pose_init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=True,
+ )
+
+ _, _, pose_latents = pose_latents_outputs
+
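+ # Replicate the pose latents so one copy lines up with each guidance branch
+ # (three copies under joint conditioning, two under standard classifier-free guidance).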
+ if self._use_jointcond:
+ pose_latents = torch.cat([pose_latents] * 3)
+ else:
+ pose_latents = torch.cat([pose_latents] * 2) if self.do_classifier_free_guidance else pose_latents
+
+
+ # 7. Prepare mask latent variables
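+ # `prepare_mask_latents` resizes the mask to latent resolution, encodes the masked image
+ # with the VAE, and repeats both to match the guidance batch (the `use_jointcond` flag
+ # accounts for the extra joint-conditioning branch).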
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ self._use_jointcond,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet == 13:
+ # extended case: the UNet additionally takes pose latents as input channels
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ num_channels_pose_image = pose_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image + num_channels_pose_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Recompute height and width from the latent shape. TODO: Logic should ideally just be moved out of the pipeline
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ add_text_embeds_clothing = pooled_prompt_embeds_clothing
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
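+ # Assemble the text conditioning in the guidance batch order. Under joint conditioning
+ # the main prompt is only active in the final (image+text) branch, while the clothing
+ # prompt conditions both image-conditioned branches of the cloth encoder.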
+ if self._use_jointcond:
+ prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds_clothing = torch.cat([negative_prompt_embeds, prompt_embeds_clothing, prompt_embeds_clothing], dim=0)
+ add_text_embeds_clothing = torch.cat([negative_pooled_prompt_embeds, add_text_embeds_clothing, add_text_embeds_clothing], dim=0)
+ else:
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds_clothing = torch.cat([negative_prompt_embeds, prompt_embeds_clothing], dim=0)
+ add_text_embeds_clothing = torch.cat([negative_pooled_prompt_embeds, add_text_embeds_clothing], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ prompt_embeds_clothing = prompt_embeds_clothing.to(device)
+ add_text_embeds_clothing = add_text_embeds_clothing.to(device)
+
+ if ip_adapter_image is not None:
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self._use_jointcond:
+ image_embeds = torch.cat([negative_image_embeds, negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+ else:
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+
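+ # Reference attention control: the cloth encoder acts as a "writer" and the main UNet as
+ # a "reader". The encoder runs once at timestep 0 (a noise-free pass over the garment
+ # latents) and its self-attention features are cached so they are available to the UNet's
+ # attention layers throughout the denoising loop.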
+ if not detach_cloth_encoder:
+ reference_control_writer = ReferenceAttentionControl(cloth_encoder, do_classifier_free_guidance=True, mode="write", fusion_blocks="midup" if os.environ.get("MIDUP_FUSION_BLOCK") else "full", batch_size=batch_size, is_train=False, is_second_stage=False, use_jointcond=self._use_jointcond)
+ reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode="read", fusion_blocks="midup" if os.environ.get("MIDUP_FUSION_BLOCK") else "full", batch_size=batch_size, is_train=False, is_second_stage=False, use_jointcond=self._use_jointcond)
+
+ if not detach_cloth_encoder:
+ zero_timesteps = torch.randint(0, 1, (cloth_encoder_latents.shape[0],), device=cloth_encoder_latents.device)
+ zero_timesteps = zero_timesteps.long()
+ added_cond_kwargs_clothing = {"text_embeds": add_text_embeds_clothing, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs_clothing["image_embeds"] = image_embeds
+
+ cloth_encoder(
+ cloth_encoder_latents,
+ zero_timesteps,
+ encoder_hidden_states=prompt_embeds_clothing,
+ added_cond_kwargs=added_cond_kwargs_clothing,
+ return_dict=False,
+ )[0]
+ reference_control_reader.update(reference_control_writer)
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 11.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
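+ # Step index from which intermediate clothing-mask handling is meant to start, derived
+ # from `interm_cloth_start_ratio` (not used further inside this loop).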
+ interm_cloth_start_timestep = int(len(timesteps) * interm_cloth_start_ratio)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # expand the latents if we are doing classifier free guidance
+ if self._use_jointcond:
+ latent_model_input = torch.cat([latents] * 3)
+ else:
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
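+ # A 9-channel UNet takes [latents | mask | masked-image latents]; the 13-channel variant
+ # additionally receives the pose latents as extra input channels.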
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ if num_channels_unet == 13:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents, pose_latents], dim=1)
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ pose_encoder_input=pose_encoder_input,
+ )[0]
+
+
+ # perform guidance
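+ # Joint conditioning splits the prediction into three chunks and applies two guidance
+ # terms: an image term pulling away from the unconditional branch and a text term
+ # pulling the image+text branch away from the image-only branch.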
+ if self._use_jointcond:
+ noise_pred_uncond, noise_pred_cond_img, noise_pred_cond_imgtext = noise_pred.chunk(3)
+ noise_pred = noise_pred_uncond + \
+ self._guidance_scale_img * (noise_pred_cond_img - noise_pred_uncond) + \
+ self._guidance_scale_text * (noise_pred_cond_imgtext - noise_pred_cond_img)
+ else:
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
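+ # For a plain 4-channel UNet, the known (unmasked) region is restored each step by
+ # re-noising the original image latents to the next timestep and blending them back
+ # in with the inpainting mask.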
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+ if not detach_cloth_encoder:
+ reference_control_reader.clear()
+ reference_control_writer.clear()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def get_interm_clothmask(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+
+ cloth_encoder=None,
+ cloth_encoder_image=None,
+ prompt_clothing=None,
+ prompt_embeds_clothing: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds_clothing: Optional[torch.FloatTensor] = None,
+ pose_image=None,
+ use_jointcond=False,
+ guidance_scale_img=4.5,
+ guidance_scale_text=7.5,
+ interm_cloth_start_ratio=0.5,
+ detach_cloth_encoder=False,
+
+ category="upper_body",
+ use_pad=False,
+ **kwargs,
+ ):
+ r"""
+ Function invoked to compute an intermediate clothing mask; it accepts the same arguments as the main `__call__`.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+ integer, the value of `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._denoising_start = denoising_start
+ self._interrupt = False
+ self._use_jointcond = use_jointcond
+ self._guidance_scale_img = guidance_scale_img
+ self._guidance_scale_text = guidance_scale_text
+
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds_clothing,
+ _,
+ pooled_prompt_embeds_clothing,
+ _,
+ ) = self.encode_prompt(
+ prompt=prompt_clothing,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds_clothing,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds_clothing,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps,
+ strength,
+ device,
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+ )
+
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ return_image_latents = True
+
+ add_noise = True if self.denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ cloth_encoder_init_image = self.image_processor.preprocess(cloth_encoder_image, height=height, width=width)
+ cloth_encoder_init_image = cloth_encoder_init_image.to(dtype=torch.float32)
+ cloth_encoder_latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents=None,
+ image=cloth_encoder_init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=True,
+ )
+ _, _, cloth_encoder_latents = cloth_encoder_latents_outputs
+
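+ # As in `__call__`, an all-zero "null" garment is encoded for the unconditional branch
+ # under joint conditioning, and the real garment latents are repeated for the two
+ # conditioned branches.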
+ if self._use_jointcond:
+ cloth_encoder_null_latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents=None,
+ image=torch.zeros_like(cloth_encoder_init_image),
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=True,
+ )
+ _, _, cloth_encoder_null_latents = cloth_encoder_null_latents_outputs
+ cloth_encoder_latents = torch.cat([cloth_encoder_null_latents] + [cloth_encoder_latents] * 2)
+ else:
+ cloth_encoder_latents = torch.cat([cloth_encoder_latents] * 2) if self.do_classifier_free_guidance else cloth_encoder_latents
+
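+ # Encode the pose map into latents; they are concatenated to the UNet input along the channel dimension for the 13-channel UNet variant.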
+ pose_init_image = self.image_processor.preprocess(pose_image, height=height, width=width)
+ pose_encoder_input = None
+ pose_init_image = pose_init_image.to(dtype=torch.float32)
+ pose_latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=pose_init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=True,
+ )
+
+ _, _, pose_latents = pose_latents_outputs
+ if self._use_jointcond:
+ pose_latents = torch.cat([pose_latents] * 3)
+ else:
+ pose_latents = torch.cat([pose_latents] * 2) if self.do_classifier_free_guidance else pose_latents
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ self.do_classifier_free_guidance,
+ self._use_jointcond,
+ )
+
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet == 13:
+ # extended case: inpainting channels plus additional pose latents concatenated at the UNet input
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ num_channels_pose_image = pose_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image + num_channels_pose_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Recover the output height and width from the latent resolution
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds_clothing
+ add_text_embeds_clothing = pooled_prompt_embeds_clothing
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if self._use_jointcond:
+ prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds_clothing = torch.cat([negative_prompt_embeds, prompt_embeds_clothing, prompt_embeds_clothing], dim=0)
+ add_text_embeds_clothing = torch.cat([negative_pooled_prompt_embeds, add_text_embeds_clothing, add_text_embeds_clothing], dim=0)
+ else:
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds_clothing = torch.cat([negative_prompt_embeds, prompt_embeds_clothing], dim=0)
+ add_text_embeds_clothing = torch.cat([negative_pooled_prompt_embeds, add_text_embeds_clothing], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ prompt_embeds_clothing = prompt_embeds_clothing.to(device)
+ add_text_embeds_clothing = add_text_embeds_clothing.to(device)
+ if ip_adapter_image is not None:
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+ image_embeds, negative_image_embeds = self.encode_image(
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+ )
+ if self._use_jointcond:
+ image_embeds = torch.cat([negative_image_embeds, negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+ else:
+ if self.do_classifier_free_guidance:
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
+ image_embeds = image_embeds.to(device)
+
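+ # Reference attention control: the cloth encoder "writes" its self-attention features and the main UNet "reads" them during denoising.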
+ if not detach_cloth_encoder:
+ reference_control_writer = ReferenceAttentionControl(cloth_encoder, do_classifier_free_guidance=True, mode="write", fusion_blocks="midup" if os.environ.get("MIDUP_FUSION_BLOCK") else "full", batch_size=batch_size, is_train=False, is_second_stage=False, use_jointcond=self._use_jointcond)
+ reference_control_reader = ReferenceAttentionControl(self.unet, do_classifier_free_guidance=True, mode="read", fusion_blocks="midup" if os.environ.get("MIDUP_FUSION_BLOCK") else "full", batch_size=batch_size, is_train=False, is_second_stage=False, use_jointcond=self._use_jointcond)
+
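+ # Run the cloth encoder once at timestep 0 to populate the reference attention states, then transfer them to the main UNet.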
+ if not detach_cloth_encoder:
+ zero_timesteps = torch.randint(0, 1, (cloth_encoder_latents.shape[0],), device=cloth_encoder_latents.device)
+ zero_timesteps = zero_timesteps.long()
+ added_cond_kwargs_clothing = {"text_embeds": add_text_embeds_clothing, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs_clothing["image_embeds"] = image_embeds
+
+ cloth_encoder(
+ cloth_encoder_latents,
+ zero_timesteps,
+ encoder_hidden_states=prompt_embeds_clothing,
+ added_cond_kwargs=added_cond_kwargs_clothing,
+ return_dict=False,
+ )[0]
+ reference_control_reader.update(reference_control_writer)
+
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ self.denoising_end is not None
+ and self.denoising_start is not None
+ and denoising_value_valid(self.denoising_end)
+ and denoising_value_valid(self.denoising_start)
+ and self.denoising_start >= self.denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {self.denoising_end} when using type float."
+ )
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 11.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ interm_cloth_start_timestep = int(len(timesteps) * interm_cloth_start_ratio)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+
+ # expand the latents if we are doing classifier free guidance
+ if self._use_jointcond:
+ latent_model_input = torch.cat([latents] * 3)
+ else:
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+ if num_channels_unet == 13:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents, pose_latents], dim=1)
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ pose_encoder_input=pose_encoder_input,
+ )[0]
+
+
+ # perform guidance
+ if self._use_jointcond:
+ noise_pred_uncond, noise_pred_cond_img, noise_pred_cond_imgtext = noise_pred.chunk(3)
+ noise_pred = noise_pred_uncond + \
+ self._guidance_scale_img * (noise_pred_cond_img - noise_pred_uncond) + \
+ self._guidance_scale_text * (noise_pred_cond_imgtext - noise_pred_cond_img)
+ else:
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+
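+ # At the chosen intermediate step, decode the predicted x0, run human parsing on it, and return the resulting clothing mask instead of completing denoising.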
+ if i == interm_cloth_start_timestep:
+ if not hasattr(self, "parsing_model_hd"):
+ from preprocess.humanparsing.run_parsing import Parsing
+ self.parsing_model_hd = Parsing(device.index)
+
+ pred_x0 = get_pred_x0(self.scheduler, noise_pred, t, latents, extra_step_kwargs, self.vae, self.image_processor, output_type)
+ model_parse, _ = self.parsing_model_hd(pred_x0.resize((384, 512)))
+ mask_image_clothing, _ = get_mask_location("hd", category, model_parse, keypoint=None, only_cloth=True, use_pad=use_pad)
+ interm_cloth_mask_image = PIL.Image.fromarray(np.array(mask_image_clothing))
+
+ reference_control_reader.clear()
+ reference_control_writer.clear()
+
+ return interm_cloth_mask_image
+
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if self.do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+ mask = callback_outputs.pop("mask", mask)
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+ if not detach_cloth_encoder:
+ reference_control_reader.clear()
+ reference_control_writer.clear()
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
+
+
diff --git a/promptdresser/utils.py b/promptdresser/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a8a047ca7cd8c6bea451e6fcb6230650541f1f
--- /dev/null
+++ b/promptdresser/utils.py
@@ -0,0 +1,296 @@
+import os
+from os.path import join as opj
+import json
+import math
+
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+import torch
+from torchvision.transforms import functional as TF
+from safetensors.torch import load_file as sf_load_file
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
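+ # Print only on the main process (local rank 0) when running distributed; otherwise print normally.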
+def zero_rank_print_(s):
+ if "LOCAL_RANK" in os.environ.keys():
+ if int(os.environ["LOCAL_RANK"]) == 0:
+ print(s)
+ else:
+ print(s)
+
+def save_args(args, to_path):
+ with open(to_path, "w") as f:
+ json.dump(args.__dict__, f, indent=2)
+
+def load_args(from_path):
+ with open(from_path, "r") as f:
+ args_dict = json.load(f)
+ return args_dict
+
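+ # Load a checkpoint from either a .safetensors file or a regular torch pickle.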
+def load_file(p):
+ if p.endswith(".safetensors"):
+ cp = sf_load_file(p)
+ else:
+ cp = torch.load(p, map_location="cpu")
+ return cp
+
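+ # Convert the first tensor of a batch to a PIL image: masks are scaled from [0, 1], images from [-1, 1].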
+def tensor2pil(tensor, is_mask=False):
+ tensor = tensor.cpu()
+ if is_mask:
+ return Image.fromarray(np.uint8(tensor[0][0].numpy() * 255)).convert("RGB")
+ else:
+ tensor = (tensor[0].permute(1,2,0)+1) * 127.5
+ return Image.fromarray(np.uint8(tensor))
+
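+ # Resize a list of PIL images to a common height (keeping aspect ratio) and concatenate them side by side.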
+def concat_pil_imgs(pil_img_lst):
+ max_img_h = -1
+ ratio_lst = []
+ for pil_img in pil_img_lst:
+ img_w, img_h = pil_img.size
+ max_img_h = max(max_img_h, img_h)
+ ratio_lst.append(img_w / img_h)
+
+
+ new_img_lst = []
+ for pil_img, ratio in zip(pil_img_lst, ratio_lst):
+ np_img = np.array(pil_img.resize((int(ratio * max_img_h), max_img_h)))
+ if np_img.ndim == 2:
+ np_img = np.stack([np_img] * 3, axis=-1)
+ if np_img.shape[-1] == 1:
+ np_img = np.concatenate([np_img]*3, axis=-1)
+ new_img_lst.append(np_img)
+
+ concat_img = np.concatenate(new_img_lst, axis=1)
+ return Image.fromarray(concat_img)
+
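+ # Compute an attention map between hidden states and encoder hidden states, averaged over heads and normalized to [0, 255] for visualization.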
+@torch.no_grad()
+def get_attn_map(hidden_states, encoder_hidden_states, attn, norm_axis=-1):
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight_logit = attn_weight
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = attn_weight.cpu().mean(dim=1)
+
+
+ min_ = attn_weight.min(dim=norm_axis, keepdim=True)[0]
+ max_ = attn_weight.max(dim=norm_axis, keepdim=True)[0]
+ norm_attn_weight = (attn_weight - min_) / (max_ - min_) * 255.0
+ norm_attn_weight = norm_attn_weight.numpy().astype(np.uint8)
+ return norm_attn_weight, attn_weight_logit
+
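+ # Resize directly to the target size, or pad symmetrically with the given fill value when a pad_type is specified.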
+def pad_resize(img, trg_h, trg_w, pixel_value, pad_type=None):
+ if pad_type is None:
+ img = img.resize((trg_w, trg_h))
+ else:
+ cur_w, cur_h = img.size
+ pad_w = max(trg_w - cur_w, 0)
+ pad_h = max(trg_h - cur_h, 0)
+
+ pad_left = pad_w // 2
+ pad_right = pad_w - pad_left
+ pad_top = pad_h // 2
+ pad_bottom = pad_h - pad_top
+
+ padding = (pad_left, pad_top, pad_right, pad_bottom)
+
+ img = TF.pad(img, padding=padding, fill=pixel_value, padding_mode=pad_type)
+ return img
+
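+ # Load the person image, agnostic mask, pose map, and cloth image following the VITON-HD folder layout, then pad/resize them to the target resolution.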
+def get_inputs(
+ root_dir, data_type, pose_type, img_bn, c_bn, img_h, img_w, train_folder_name, test_folder_name,
+ # use_repaint, train_folder_name_for_interm_cloth_mask=None, test_repaint_folder_name=None,
+ # return_inversion_latents=False,
+ category=None, pad_type=None, use_dc_cloth=False
+):
+ is_vitonhd = category is None or category == ""
+ img_fn = os.path.splitext(img_bn)[0]
+ if is_vitonhd:
+ if data_type == "train":
+ folder_name = train_folder_name if train_folder_name is not None else "train"
+ else:
+ folder_name = test_folder_name if test_folder_name is not None else "test"
+
+ person = Image.open(opj(root_dir, f"{folder_name}/image", img_bn)).convert("RGB").resize((img_w, img_h))
+ mask = Image.open(opj(root_dir, f"{folder_name}/agnostic-mask", f"{img_fn}_mask.png")).convert("RGB").resize((img_w, img_h))
+ cloth = Image.open(opj(root_dir, f"{folder_name}/cloth", c_bn)).convert("RGB").resize((img_w, img_h))
+
+ if pose_type == "openpose": pose = Image.open(opj(root_dir, f"{folder_name}/dwpose", f"{img_fn}.png")).convert("RGB").resize((img_w, img_h))
+ elif pose_type == "openpose_thick": pose = Image.open(opj(root_dir, f"{folder_name}/dwpose_thick", f"{img_fn}.png")).convert("RGB").resize((img_w, img_h))
+ elif pose_type == "densepose": pose = Image.open(opj(root_dir, f"{folder_name}/image-densepose", f"{img_fn}.jpg")).convert("RGB").resize((img_w, img_h))
+
+ person = Image.open(opj(root_dir, f"{folder_name}/image", img_bn)).convert("RGB")
+ mask = Image.open(opj(root_dir, f"{folder_name}/agnostic-mask", f"{img_fn}_mask.png")).convert("RGB")
+ if not use_dc_cloth:
+ cloth = Image.open(opj(root_dir, f"{folder_name}/cloth", c_bn)).convert("RGB")
+ else:
+ cloth = Image.open(opj(root_dir, f"{folder_name}/cloth_dc", c_bn)).convert("RGB")
+
+ if pose_type == "openpose": pose = Image.open(opj(root_dir, f"{folder_name}/dwpose", f"{img_fn}.png")).convert("RGB")
+ elif pose_type == "openpose_thick": pose = Image.open(opj(root_dir, f"{folder_name}/dwpose_thick", f"{img_fn}.png")).convert("RGB")
+ elif pose_type == "densepose": pose = Image.open(opj(root_dir, f"{folder_name}/image-densepose", f"{img_fn}.jpg")).convert("RGB")
+
+ person = pad_resize(person, img_h, img_w, (255,255,255), pad_type=pad_type)
+ if pad_type is None or pad_type == "resize":
+ other_pad_type = None
+ else:
+ other_pad_type = "constant"
+ mask = pad_resize(mask, img_h, img_w, (0,0,0), pad_type=other_pad_type)
+ cloth = pad_resize(cloth, img_h, img_w, (255,255,255), pad_type=other_pad_type)
+ pose = pad_resize(pose, img_h, img_w, (0,0,0), pad_type=other_pad_type)
+
+ return person, mask, pose, cloth
+
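+ # Count how many parameter tensors of a model are trainable (requires_grad) versus frozen and return a summary string.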
+def get_leanable_param_count(model_name, model):
+ named_param = model.named_parameters()
+ total_count = 0
+ lparam_count = 0
+ not_lparam_count = 0
+ for name, param in named_param:
+ if param.requires_grad:
+ lparam_count += 1
+ else:
+ not_lparam_count += 1
+ total_count += 1
+ return f" {model_name} | total : {total_count}, lparam : {lparam_count}, not_lparam : {not_lparam_count}"
+
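+ # Split a list into contiguous, roughly equal chunks across n_proc workers and return the chunk for proc_idx.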
+def split_procidx(ps, n_proc, proc_idx):
+ len_ps = len(ps)
+ if len_ps % n_proc == 0:
+ n_infer = len_ps // n_proc
+ else:
+ n_infer = len_ps // n_proc + 1
+
+ start_idx = int(proc_idx * n_infer)
+ end_idx = start_idx + n_infer
+ ps = ps[start_idx:end_idx]
+ return ps
+
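+ # Resize a PIL image and convert it to a CUDA tensor: images are scaled to [-1, 1], masks are binarized (dark pixels -> 1).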
+def get_tensor(img, h, w, is_mask=False):
+ img = np.array(img.resize((w, h))).astype(np.float32)
+ if not is_mask:
+ img = (img / 127.5) - 1.0
+ else:
+ img = (img < 128).astype(np.float32)[:,:,None]
+ return torch.from_numpy(img)[None].cuda()
+
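+ # Assemble the inference batch (person image, cloth, densepose, agnostic image and mask, empty text prompt) as CUDA tensors.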
+def get_batch(image, cloth, densepose, agn_img, agn_mask, img_h, img_w):
+ batch = dict()
+ batch["image"] = get_tensor(image, img_h, img_w)
+ batch["cloth"] = get_tensor(cloth, img_h, img_w)
+ batch["image_densepose"] = get_tensor(densepose, img_h, img_w)
+ batch["agn"] = get_tensor(agn_img, img_h, img_w)
+ batch["agn_mask"] = get_tensor(agn_mask, img_h, img_w, is_mask=True)
+ batch["txt"] = ""
+ return batch
+
+def tensor2img(x):
+ '''
+ x : [BS x c x H x W] or [c x H x W]
+ '''
+ if x.ndim == 3:
+ x = x.unsqueeze(0)
+ BS, C, H, W = x.shape
+ x = x.permute(0,2,3,1).reshape(-1, W, C).detach().cpu().numpy()
+ x = np.clip(x, -1, 1)
+ x = (x+1)/2
+ x = np.uint8(x*255.0)
+ if x.shape[-1] == 1:
+ x = np.concatenate([x,x,x], axis=-1)
+ return x
+
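+ # Center-crop an image to a 3:4 (width:height) aspect ratio while keeping the full height.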
+def center_crop(image):
+ width, height = image.size
+ new_height = height
+ new_width = height*3/4
+ left = (width - new_width)/2
+ top = (height - new_height)/2
+ right = (width + new_width)/2
+ bottom = (height + new_height)/2
+
+ image = image.crop((left, top, right, bottom))
+ return image
+
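+ # Collect the names of Linear/Embedding/Conv2d modules whose keys match all of all_names, any of any_names, and none of not_names, for use as LoRA targets.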
+def get_lora_target_modules(named_modules, all_names, any_names, not_names):
+ output = []
+ lora_modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d]
+ for key, module in named_modules:
+ if all(all_name in key for all_name in all_names) and any(any_name in key for any_name in any_names) and not any(not_name in key for not_name in not_names):
+ for lora_module in lora_modules:
+ if isinstance(module, lora_module):
+ output.append(key)
+ return output
+
+
+def unfreeze_unet(unet, all_names, any_names, not_names):
+ for key, param in unet.named_parameters():
+ if all(all_name in key for all_name in all_names) and any(any_name in key for any_name in any_names) and not any(not_name in key for not_name in not_names):
+ param.requires_grad_(True)
+
+
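+ # Build the full-image and clothing-only text prompts for a (person, clothing) pair using the Prompter templates.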
+def get_txt(jf, person_id, clothing_id=None, prompt_version="v5", category="upper_body", verbose=True):
+ from .data.data_utils import Prompter
+ pt = Prompter(category=category, version=prompt_version)
+ if clothing_id is None:
+ clothing_id = person_id
+ person_dict = jf[person_id]["person"]
+ clothing_dict = jf[clothing_id]["clothing"]
+ clothing_person_dict = jf[clothing_id]["person"]
+ full_txt, clothing_txt = pt.generate(person_dict, clothing_dict, clothing_person_dict)
+ if verbose:
+ print(full_txt)
+ print("\n")
+ print(clothing_txt)
+ print("\n\n")
+
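+ # Horizontally concatenate the i-th image from each list of paths and save the composite for side-by-side comparison.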
+def concat_save_images(ps_lst, save_dir, cut_right_two=False):
+ import cv2
+ from tqdm import tqdm
+ os.makedirs(save_dir, exist_ok=True)
+
+ min_value = min([len(ps) for ps in ps_lst])
+ for i in tqdm(range(min_value), total=min_value):
+ concat = []
+ for ps in ps_lst:
+ p = ps[i]
+ concat.append(cv2.imread(p))
+ concat = np.concatenate(concat, axis=1)
+ if cut_right_two:
+ concat = concat[:,:-2*768]
+
+ save_p = opj(save_dir, os.path.basename(p))
+ cv2.imwrite(save_p, concat)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..19413ef27ea621e3ce918612a5ab618a2e3c43c2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+--extra-index-url https://download.pytorch.org/whl/cu126
+torch==2.7.1
+torchvision==0.22.1
+torchaudio==2.7.1
+pillow==11.2.1
+transformers==4.52.4
+numpy==1.26.4
+diffusers==0.33.1
+tqdm==4.67.1
+opencv-python==4.11.0.86
+onnxruntime==1.22.0
+matplotlib==3.10.3
+scipy==1.11.0
+controlnet-aux==0.0.10
+accelerate==1.8.1
+mediapipe==0.10.21
+gradio==5.34.2
+huggingface-hub==0.33.0
\ No newline at end of file
diff --git a/test/00008_00.jpg b/test/00008_00.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2fb91da5997ccdfad61d79c8ab70855a4b966f7f
--- /dev/null
+++ b/test/00008_00.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70516f73beb8c275c457a589bc7db74c4f503daa1c0037a7e02dce1b93dedebf
+size 118479
diff --git a/test/mask.png b/test/mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc8413900d468b60abee54cdd4b1f27af085eb23
--- /dev/null
+++ b/test/mask.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:70f11f96344c6d2e9366c27a3c5daeef62ef12c2f22e50054c8e366ee5967e44
+size 6256
diff --git a/test/mask2.png b/test/mask2.png
new file mode 100644
index 0000000000000000000000000000000000000000..b4980add65b3d3722a95d945fca00bfc7494c94a
--- /dev/null
+++ b/test/mask2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56a4a2991f73a18ee05f5e3b0144ec944aaeff4c6e6fc4e5f74236147c02c6ca
+size 2805
diff --git a/test/output_image.jpg b/test/output_image.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f9042464d5d2f288f4cb0d757c2ab08972896c84
--- /dev/null
+++ b/test/output_image.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:09ee76d86575e8b4de595282831477a3db645174cc96d8c9c7a9640a9b9ed678
+size 47734
diff --git a/test/person.jpg b/test/person.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..752db761f80aa1b4a61b78f6647a6636da523a00
--- /dev/null
+++ b/test/person.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c29753ccb0e245ea96920afce565548bddee173494b4e044510f95b5b2daa9ce
+size 87905
diff --git a/test/person2.png b/test/person2.png
new file mode 100644
index 0000000000000000000000000000000000000000..8c8a06e6f3fb0ebb30de3e4b8109d81c87164402
--- /dev/null
+++ b/test/person2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc8b289bd90009320c37c80b3318773643c5cc692c4e996806bf6ef9a9725fe1
+size 198632
diff --git a/test/pose.png b/test/pose.png
new file mode 100644
index 0000000000000000000000000000000000000000..9def3be07f475866ad8fa3fd6b9627ce02fe0e26
--- /dev/null
+++ b/test/pose.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6358e6f810218d59b355c4a3dd77680bfa280a5d853e5bc196705858ee2c2910
+size 9810
diff --git a/test/pose2.png b/test/pose2.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a15e637a1309a37dc5a826c128a0293f5d03efe
--- /dev/null
+++ b/test/pose2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0c105fe0817d6db77d0b8fb42188f7bdfcba207685dd29a34d37d0be0eabde4a
+size 5637