weiyuyeh committed on
Commit
a03472d
·
1 Parent(s): e41b2aa
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +57 -0
  2. README.md +2 -2
  3. app.py +152 -0
  4. requirements.txt +51 -0
  5. src/multiview_consist_edit/MVHumanNet_multi.py +403 -0
  6. src/multiview_consist_edit/Thuman2_multi.py +366 -0
  7. src/multiview_consist_edit/config/infer_tryon_multi.yaml +44 -0
  8. src/multiview_consist_edit/config/train_tryon_multi.yaml +137 -0
  9. src/multiview_consist_edit/data/MVHumanNet_multi.py +406 -0
  10. src/multiview_consist_edit/data/Thuman2_multi.py +367 -0
  11. src/multiview_consist_edit/data/camera_utils.py +479 -0
  12. src/multiview_consist_edit/infer_tryon_multi.py +185 -0
  13. src/multiview_consist_edit/models/ReferenceEncoder.py +67 -0
  14. src/multiview_consist_edit/models/ReferenceNet.py +1146 -0
  15. src/multiview_consist_edit/models/ReferenceNet_attention_multi_fp16.py +297 -0
  16. src/multiview_consist_edit/models/attention.py +320 -0
  17. src/multiview_consist_edit/models/condition_encoder.py +395 -0
  18. src/multiview_consist_edit/models/embeddings.py +385 -0
  19. src/multiview_consist_edit/models/hack_poseguider.py +97 -0
  20. src/multiview_consist_edit/models/hack_unet2d.py +329 -0
  21. src/multiview_consist_edit/models/mv_attn_processor.py +132 -0
  22. src/multiview_consist_edit/models/resnet.py +212 -0
  23. src/multiview_consist_edit/models/unet.py +523 -0
  24. src/multiview_consist_edit/parse_tool/postprocess_parse.py +42 -0
  25. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/__init__.py +0 -0
  26. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/datasets.py +201 -0
  27. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/simple_extractor_dataset.py +89 -0
  28. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/target_generation.py +40 -0
  29. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/__init__.py +5 -0
  30. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/bn.py +132 -0
  31. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/deeplab.py +84 -0
  32. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/dense.py +42 -0
  33. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/functions.py +245 -0
  34. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/misc.py +21 -0
  35. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/residual.py +182 -0
  36. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/checks.h +15 -0
  37. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn.cpp +95 -0
  38. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn.h +88 -0
  39. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn_cpu.cpp +119 -0
  40. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn_cuda.cu +333 -0
  41. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn_cuda_half.cu +275 -0
  42. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/utils/checks.h +15 -0
  43. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/utils/common.h +49 -0
  44. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/utils/cuda.cuh +71 -0
  45. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/AugmentCE2P.py +388 -0
  46. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/__init__.py +12 -0
  47. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/backbone/mobilenetv2.py +156 -0
  48. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/backbone/resnet.py +205 -0
  49. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/backbone/resnext.py +149 -0
  50. src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/context_encoding/aspp.py +64 -0
.gitignore ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python 编译文件和缓存
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+ *.pyd
6
+ *.so
7
+
8
+ # Python 虚拟环境
9
+ venv/
10
+ .env/
11
+ .venv/
12
+ env/
13
+ virtualenvs/
14
+ .Python/
15
+
16
+ # Python 打包和分发
17
+ build/
18
+ dist/
19
+ *.egg-info/
20
+ *.egg
21
+ *.whl
22
+ *.tar.gz
23
+
24
+ # 测试相关
25
+ .coverage
26
+ htmlcov/
27
+ .pytest_cache/
28
+ .mypy_cache/
29
+
30
+ # IDE 和编辑器
31
+ .idea/
32
+ .vscode/
33
+ *.suo
34
+ *.sublime-workspace
35
+ *.sublime-project
36
+
37
+ # 环境变量文件
38
+ .env
39
+ .env.local
40
+ .env.*
41
+
42
+ # 日志文件
43
+ *.log
44
+ *.log.*
45
+
46
+ # 系统文件
47
+ .DS_Store
48
+ Thumbs.db
49
+
50
+ src/render_from_thuman/ckpt/
51
+
52
+ # data
53
+ demo_data/
54
+
55
+ # models
56
+ src/multiview_consist_edit/checkpoints/
57
+ src/multiview_consist_edit/parse_tool/ckpt/
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: VTON360
3
- emoji:
4
- colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.38.2
 
1
  ---
2
  title: VTON360
3
+ emoji: 🐢
4
+ colorFrom: purple
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 5.38.2
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import shutil
5
+ import sys
6
+
7
# Absolute filesystem paths for the Hugging Face Space layout (/home/user/app).
target_paths = {
    "data": "/home/user/app/upload/data.zip",          # uploaded input archive
    "data_dir": "/home/user/app/upload/data",          # where the archive is unpacked
    "config": "/home/user/app/src/multiview_consist_edit/config/infer_tryon_multi.yaml",  # inference config destination
    "output_data": "/home/user/app/image_output_tryon_mvhumannet",  # inference results directory
    "output_zip": "/home/user/app/outputs/result.zip", # downloadable result archive
}
14
+
15
def unzip_data():
    """Unpack the uploaded archive into a fresh data directory.

    Returns:
        The extraction directory path.

    Raises:
        FileNotFoundError: when no archive has been uploaded yet.
    """
    archive = target_paths["data"]
    extract_dir = target_paths["data_dir"]
    # Guard clause: fail fast when the upload is missing.
    if not os.path.exists(archive):
        raise FileNotFoundError("Data file not found at " + archive)
    # Start from a clean directory so stale files never leak into a new run.
    if os.path.exists(extract_dir):
        shutil.rmtree(extract_dir)
    os.makedirs(extract_dir, exist_ok=True)
    shutil.unpack_archive(archive, extract_dir)
    return extract_dir
24
+
25
+
26
def zip_outputs():
    """Zip the inference output directory and return the archive path.

    Rebuilds the archive from scratch on every call so the download always
    reflects the latest run.
    """
    out_zip = target_paths["output_zip"]
    if os.path.exists(out_zip):
        os.remove(out_zip)
    # Ensure the parent directory exists; make_archive does not create it
    # and would otherwise fail on a fresh container.
    os.makedirs(os.path.dirname(out_zip), exist_ok=True)
    # removesuffix only strips a trailing ".zip"; the original
    # str.replace(".zip", "") would also mangle a ".zip" occurring
    # anywhere else in the path.
    base_name = out_zip.removesuffix(".zip")
    shutil.make_archive(base_name, 'zip', root_dir=target_paths["output_data"])
    return out_zip
31
+
32
+
33
def start_inference_stream():
    """Run the inference script and stream its log to the caller.

    Yields the full accumulated stdout/stderr text after every line, so a
    Gradio Textbox bound to this generator updates incrementally.
    """
    process = subprocess.Popen(
        ["python", "src/multiview_consist_edit/infer_tryon_multi.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr into the same stream
        text=True,                 # (universal_newlines was a redundant alias)
        bufsize=1,                 # line-buffered
    )

    output = []
    for line in process.stdout:
        output.append(line)
        yield "".join(output)

    # BUG FIX: the original never wait()ed, leaving a zombie child process.
    # Reap it and surface a non-zero exit status in the log.
    return_code = process.wait()
    if return_code != 0:
        output.append(f"\n[inference process exited with code {return_code}]\n")
        yield "".join(output)
47
+
48
def install_package(package_name):
    """pip-install *package_name* and return the combined stdout/stderr text."""
    try:
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "install", package_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except Exception as e:
        return f"Error: {str(e)}"
    return completed.stdout + "\n" + completed.stderr
60
+
61
+
62
def show_package(pkg_name):
    """Return `pip show` output for *pkg_name*, or stderr when it fails."""
    try:
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "show", pkg_name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except Exception as e:
        return str(e)
    # pip show prints nothing on stdout for unknown packages; fall back to stderr.
    return completed.stdout if completed.stdout else completed.stderr
73
+
74
+
75
def uninstall_package(package_name):
    """pip-uninstall *package_name* (non-interactive) and return combined output."""
    try:
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", package_name, "-y"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except Exception as e:
        return f"Error: {str(e)}"
    return completed.stdout + "\n" + completed.stderr
87
+
88
+ # print(uninstall_package("datasets"))
89
+ # print(install_package("uvicorn==0.30.6"))
90
+ # print(install_package("huggingface_hub==0.25.1"))
91
+ # print(install_package("diffusers==0.25.1"))
92
+ # print(install_package("gradio==5.0.0"))
93
+ # print("package version set complete")
94
+
95
def save_files(data_file, config_file):
    """Persist the uploaded archive and config to their target paths, then unpack the data."""
    # Create both destination directories up front.
    for key in ("data", "config"):
        os.makedirs(os.path.dirname(target_paths[key]), exist_ok=True)

    shutil.copy(data_file.name, target_paths["data"])
    shutil.copy(config_file.name, target_paths["config"])
    unzip_data()
    return "檔案已成功上傳、儲存並解壓縮了!"
103
+
104
+
105
# --- Gradio UI -------------------------------------------------------------
# NOTE(review): demo.launch() runs at import time; consider a
# `if __name__ == "__main__":` guard — confirm how the Space invokes this file.
with gr.Blocks(theme=gr.themes.Origin()) as demo:
    # Upload section: data archive + inference config.
    gr.Markdown("## 請先上傳檔案")
    with gr.Row():
        data_input = gr.File(label="上傳資料壓縮檔", file_types=[".zip"])
        config_input = gr.File(label="Config 檔", file_types=[".yaml", ".yml"])

    upload_button = gr.Button("上傳並儲存")
    output = gr.Textbox(label="狀態")


    # Inference section: streams the subprocess log into the textbox.
    gr.Markdown("## Inference")
    with gr.Column():
        log_output = gr.Textbox(label="Inference Log", lines=20)
        infer_btn = gr.Button("Start Inference")

    # Runtime package management helpers (debugging aids for the Space).
    gr.Markdown("## Pip Installer")
    with gr.Column():
        with gr.Row():
            pkg_input = gr.Textbox(lines=1, placeholder="輸入想安裝的套件名稱,例如 diffusers 或 numpy==1.2.0")
            install_output = gr.Textbox(label="Install Output", lines=10)
        install_btn = gr.Button("Install Package")

    gr.Markdown("## Pip Uninstaller")
    with gr.Column():
        with gr.Row():
            pkg_input2 = gr.Textbox(lines=1, placeholder="輸入想解除安裝的套件名稱,例如 diffusers 或 numpy")
            uninstall_output = gr.Textbox(label="Uninstall Output", lines=10)
        uninstall_btn = gr.Button("Uninstall Package")

    gr.Markdown("## Pip show")
    with gr.Column():
        with gr.Row():
            show_input = gr.Textbox(label="輸入套件名稱(如 diffusers)")
            show_output = gr.Textbox(label="套件資訊", lines=10)
        show_btn = gr.Button("pip show")

    # Result download: zips the output directory on demand.
    gr.Markdown("## Download results")
    with gr.Column():
        file_output = gr.File(label="點擊下載", interactive=True)
        download_btn = gr.Button("下載結果")

    # Event wiring: each button drives exactly one handler defined above.
    show_btn.click(fn=show_package, inputs=show_input, outputs=show_output)
    download_btn.click(fn=zip_outputs, outputs=file_output)
    install_btn.click(fn=install_package, inputs=pkg_input, outputs=install_output)
    infer_btn.click(fn=start_inference_stream, outputs=log_output)
    uninstall_btn.click(fn=uninstall_package, inputs=pkg_input2, outputs=uninstall_output)
    upload_button.click(fn=save_files,inputs=[data_input, config_input],outputs=output)
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ accelerate==0.25.0
3
+ av==12.3.0
4
+ basicsr==1.4.2
5
+ black==25.1.0
6
+ cityscapesscripts==2.2.4
7
+ cloudpickle==3.1.1
8
+ diffusers==0.25.1
9
+ einops==0.8.1
10
+ fairscale==0.4.13
11
+ fvcore==0.1.5.post20221221
12
+ gsplat==0.1.2.1
13
+ hydra-core==1.3.2
14
+ iopath==0.1.10
15
+ kornia==0.7.3
16
+ matplotlib==3.10.3
17
+ mmcv==2.2.0
18
+ mmdet==3.3.0
19
+ nerfstudio==1.0.0
20
+ numpy==1.24.4
21
+ omegaconf==2.3.0
22
+ onnx==1.17.0
23
+ onnxruntime==1.16.2
24
+ open_clip_torch==2.22.0
25
+ opencv_python==4.8.0.76
26
+ packaging==25.0
27
+ Pillow==11.2.1
28
+ pycocotools==2.0.8
29
+ Pygments==2.19.1
30
+ pytorch_msssim==1.0.0
31
+ PyYAML==6.0.1
32
+ Requests==2.32.3
33
+ safetensors==0.5.3
34
+ scikit_learn==1.6.1
35
+ scipy==1.15.3
36
+ setuptools==69.5.1
37
+ Shapely==2.1.0
38
+ scikit-image
39
+ tabulate==0.9.0
40
+ taichi==1.7.3
41
+ taichi_glsl==0.0.12
42
+ termcolor==3.1.0
43
+ timm
44
+ torch==2.1.2+cu118
45
+ torchvision==0.16.2+cu118
46
+ torchmetrics==1.7.1
47
+ tqdm==4.66.4
48
+ transformers==4.42.3
49
+ typing_extensions==4.13.2
50
+
51
+ xformers==0.0.23.post1
src/multiview_consist_edit/MVHumanNet_multi.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, csv, math, random
2
+ import numpy as np
3
+ from PIL import Image,ImageDraw
4
+ import json
5
+ import torch
6
+ import torchvision
7
+ import torchvision.transforms as transforms
8
+ from torch.utils.data.dataset import Dataset
9
+ from transformers import CLIPProcessor
10
+ import random
11
+ from torchvision.transforms import functional as F
12
+ import torch.distributed as dist
13
+ import copy
14
+ import cv2
15
+ import pickle
16
+ from .camera_utils import read_camera_mvhumannet
17
+
18
def crop_and_resize(img, bbox, size, aspect=2 / 3):
    """Crop *img* to a fixed-aspect window centred on *bbox*, then resize.

    The crop keeps the bbox height and derives the width from *aspect*
    (width / height, default 2:3), so all crops share one aspect ratio.
    Generalized: the original hard-coded the 2/3 ratio.

    Args:
        img: a PIL.Image (anything exposing ``crop``/``resize`` works).
        bbox: (left, top, right, bottom) pixel coordinates.
        size: target (width, height) passed straight to ``resize``.
        aspect: width-to-height ratio of the crop window.

    Returns:
        The cropped and resized image.
    """
    # Centre point of the original bounding box.
    center_x = (bbox[0] + bbox[2]) / 2
    center_y = (bbox[1] + bbox[3]) / 2
    # Keep the bbox height; derive the width from the fixed aspect ratio.
    new_height = bbox[3] - bbox[1]
    new_width = int(new_height * aspect)

    # New bounding box, centred on the original one.
    new_bbox = [
        int(center_x - new_width / 2),
        int(center_y - new_height / 2),
        int(center_x + new_width / 2),
        int(center_y + new_height / 2)
    ]

    cropped_img = img.crop(new_bbox)
    return cropped_img.resize(size)
41
+
42
+
43
class MVHumanNet_Dataset(Dataset):
    """Multi-view try-on dataset for MVHumanNet.

    Each item bundles several views of one subject/frame (person image,
    normal-map "pose" image, cloth-agnostic image, camera extrinsics)
    together with front/back cloth reference images and their
    CLIP-processed counterparts. In test mode every person is paired with
    every cloth in the cloth id list.
    """

    def __init__(
        self, dataroot, sample_size=(512,384), is_train=True, mode='pair', clip_model_path='', multi_length=8,
    ):
        """
        Args:
            dataroot: root containing ``processed_mvhumannet``, ``cloth``
                and the id-list text files.
            sample_size: (height, width) all images are resized to.
            is_train: selects the train frame list vs. test id/cloth lists.
            mode: unused here; kept for interface compatibility.
            clip_model_path: local path of the pretrained CLIP processor.
            multi_length: number of views sampled per item in training.
        """
        self.dataroot = os.path.join(dataroot, 'processed_mvhumannet')
        self.cloth_root = os.path.join(dataroot, 'cloth')
        self.data_ids = []
        self.data_frame_ids = []
        self.cloth_ids = []
        self.cloth_frame_ids = []
        if is_train:
            # Each line: "<data_id> <frame_id>". (Files now closed via `with`;
            # the original leaked handles on exceptions.)
            with open(os.path.join(dataroot, 'train_frame_ids.txt')) as f:
                for line in f:
                    line_info = line.strip().split()
                    self.data_ids.append(line_info[0])
                    self.data_frame_ids.append(line_info[1])
        else:
            with open(os.path.join(dataroot, 'test_ids.txt')) as f:
                for line in f:
                    line_info = line.strip().split()
                    self.data_ids.append(line_info[0])
                    self.data_frame_ids.append(line_info[1])
            # Cloth list drives the person x cloth pairing at test time.
            with open(os.path.join(dataroot, 'test_cloth_ids.txt')) as f2:
                for line in f2:
                    line_info = line.strip().split()
                    self.cloth_ids.append(line_info[0])
                    self.cloth_frame_ids.append(line_info[1])

        self.is_train = is_train
        self.sample_size = sample_size
        self.multi_length = multi_length
        self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path, local_files_only=True)

        # Main image transform: resize (nearest, interpolation=0) then map to [-1, 1].
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(self.sample_size, interpolation=0),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        # Split variants kept for callers that normalize/resize separately.
        self.pixel_transforms_0 = transforms.Compose([
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.pixel_transforms_1 = transforms.Compose([
            transforms.Resize(self.sample_size, interpolation=0),
        ])

        # Cloth reference transforms: mild random affine jitter in training only.
        self.ref_transforms_train = transforms.Compose([
            transforms.Resize(self.sample_size),
            transforms.RandomAffine(degrees=0, translate=(0.08, 0.08), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.ref_transforms_test = transforms.Compose([
            transforms.Resize(self.sample_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def __len__(self):
        # Test mode enumerates the full person x cloth cross product.
        if len(self.cloth_ids) >= 1:
            return len(self.data_ids) * len(self.cloth_ids)
        else:
            return len(self.data_ids)

    def __getitem__(self, idx):
        if len(self.cloth_ids) >= 1:
            # Decompose the flat index into a (person, cloth) pair.
            data_idx = idx // len(self.cloth_ids)
            cloth_idx = idx % len(self.cloth_ids)

            data_id = self.data_ids[data_idx]
            frame_id = self.data_frame_ids[data_idx]
            cloth_id = self.cloth_ids[cloth_idx]
            cloth_frame_id = self.cloth_frame_ids[cloth_idx]
            # NOTE: "front"/"back" labels are swapped in the raw data
            # (original comment: 实际是反的 -- "actually reversed").
            cloth_name_front = os.path.join(self.cloth_root, '%s_%s_front.jpg' % (cloth_id, cloth_frame_id))
            cloth_name_back = os.path.join(self.cloth_root, '%s_%s_back.jpg' % (cloth_id, cloth_frame_id))
        else:
            data_id = self.data_ids[idx]
            frame_id = self.data_frame_ids[idx]
            cloth_name_front = os.path.join(self.cloth_root, '%s_%s_front.jpg' % (data_id, frame_id))
            cloth_name_back = os.path.join(self.cloth_root, '%s_%s_back.jpg' % (data_id, frame_id))

        images_root = os.path.join(self.dataroot, data_id, 'agnostic', frame_id)
        images = sorted(os.listdir(images_root))

        if self.is_train:
            # Camera CC32871A015 is excluded from training views.
            check_images = [image for image in images if 'CC32871A015' not in image]
            select_images = random.sample(check_images, self.multi_length)
        else:
            # Fixed front-facing camera set used for evaluation.
            front_cameras = [
                'CC32871A005','CC32871A016','CC32871A017','CC32871A023','CC32871A027',
                'CC32871A030','CC32871A032','CC32871A033','CC32871A034','CC32871A035',
                'CC32871A038','CC32871A050','CC32871A051','CC32871A052','CC32871A059', 'CC32871A060'
            ]
            # Back-facing cameras; currently unused, kept for reference.
            back_cameras = [
                'CC32871A004','CC32871A010', 'CC32871A013', 'CC32871A022', 'CC32871A029',
                'CC32871A031','CC32871A037', 'CC32871A039', 'CC32871A040', 'CC32871A044',
                'CC32871A046','CC32871A048', 'CC32871A055', 'CC32871A057', 'CC32871A058', 'CC32871A041'
            ]
            select_images = []
            for image in images:
                camera_id = image.split('_')[0]
                if camera_id in front_cameras:
                    select_images.append(image)
            select_images = sorted(select_images)
        # Turn bare file names into dataroot-relative image paths.
        for i in range(len(select_images)):
            select_images[i] = os.path.join(data_id, 'resized_img', frame_id, select_images[i])
        sample = self.load_images(select_images, data_id, cloth_name_front, cloth_name_back)
        return sample

    def load_images(self, select_images, data_id, cloth_name_front, cloth_name_back):
        """Load the selected views plus cloth references into one sample dict."""
        pixel_values_list = []
        pixel_values_pose_list = []
        camera_parm_list = []
        pixel_values_agnostic_list = []
        image_name_list = []

        # Load camera calibration for this subject.
        intri_name = os.path.join(self.dataroot, data_id, 'camera_intrinsics.json')
        extri_name = os.path.join(self.dataroot, data_id, 'camera_extrinsics.json')
        camera_scale_fn = os.path.join(self.dataroot, data_id, 'camera_scale.pkl')
        with open(camera_scale_fn, "rb") as fh:
            camera_scale = pickle.load(fh)
        cameras_gt = read_camera_mvhumannet(intri_name, extri_name, camera_scale)

        # Load every selected view of the person.
        for img_name in select_images:
            camera_id = img_name.split('/')[-1].split('_')[0]

            image_name_list.append(img_name)
            pixel_values = Image.open(os.path.join(self.dataroot, img_name))
            pixel_values_pose = Image.open(os.path.join(self.dataroot, img_name).replace('resized_img', 'normals').replace('.jpg', '_normal.jpg'))
            pixel_values_agnostic = Image.open(os.path.join(self.dataroot, img_name).replace('resized_img', 'agnostic'))
            parm_matrix = cameras_gt[camera_id]['RT']  # extrinsic

            # Crop the normal map around the person using the annotated bbox.
            annot_path = os.path.join(self.dataroot, img_name.replace('resized_img', 'annots').replace('.jpg', '.json'))
            with open(annot_path) as fh:
                annot_info = json.load(fh)
            bbox = annot_info['annots'][0]['bbox']
            width = annot_info['width']
            if width == 4096 or width == 2448:
                # Annotations were made at full resolution; images are half-size.
                for i in range(4):
                    bbox[i] = bbox[i] // 2
            elif width == 2048:
                pass
            else:
                # BUG FIX: original referenced undefined ``img_path`` here,
                # which raised NameError instead of printing a warning.
                print('wrong annot size', img_name)
            pixel_values_pose = crop_and_resize(pixel_values_pose, bbox, size=self.sample_size)

            # Camera parameter: flattened 3x3 rotation of the extrinsic.
            parm_matrix = torch.tensor(parm_matrix)
            camera_parm = parm_matrix[:3, :3].reshape(-1)  # todo

            # Normalize all three images to [-1, 1].
            pixel_values = self.pixel_transforms(pixel_values)
            pixel_values_pose = self.pixel_transforms(pixel_values_pose)
            pixel_values_agnostic = self.pixel_transforms(pixel_values_agnostic)

            pixel_values_list.append(pixel_values)
            pixel_values_pose_list.append(pixel_values_pose)
            camera_parm_list.append(camera_parm)
            pixel_values_agnostic_list.append(pixel_values_agnostic)

        pixel_values = torch.stack(pixel_values_list)
        pixel_values_pose = torch.stack(pixel_values_pose_list)
        camera_parm = torch.stack(camera_parm_list)
        pixel_values_agnostic = torch.stack(pixel_values_agnostic_list)

        # BUG FIX: cloth_name_front/back already carry the cloth_root prefix
        # (built in __getitem__); the original re-joined them with
        # self.cloth_root, which double-prefixes when dataroot is relative.
        pixel_values_cloth_front = Image.open(cloth_name_front)
        pixel_values_cloth_back = Image.open(cloth_name_back)

        # CLIP-preprocessed copies of the cloth references.
        clip_ref_front = self.clip_image_processor(images=pixel_values_cloth_front, return_tensors="pt").pixel_values
        clip_ref_back = self.clip_image_processor(images=pixel_values_cloth_back, return_tensors="pt").pixel_values

        if self.is_train:
            pixel_values_cloth_front = self.ref_transforms_train(pixel_values_cloth_front)
            pixel_values_cloth_back = self.ref_transforms_train(pixel_values_cloth_back)
        else:
            pixel_values_cloth_front = self.ref_transforms_test(pixel_values_cloth_front)
            pixel_values_cloth_back = self.ref_transforms_test(pixel_values_cloth_back)

        # Per-view 10% drop flag (presumably for classifier-free guidance
        # training -- TODO confirm against the training loop).
        drop_image_embeds = torch.stack([
            torch.tensor(1) if random.random() < 0.1 else torch.tensor(0)
            for _ in range(len(select_images))
        ])
        sample = dict(
            pixel_values=pixel_values,
            pixel_values_pose=pixel_values_pose,
            pixel_values_agnostic=pixel_values_agnostic,
            clip_ref_front=clip_ref_front,
            clip_ref_back=clip_ref_back,
            pixel_values_cloth_front=pixel_values_cloth_front,
            pixel_values_cloth_back=pixel_values_cloth_back,
            camera_parm=camera_parm,
            drop_image_embeds=drop_image_embeds,
            img_name=image_name_list,
            cloth_name=cloth_name_front,
        )

        return sample
269
+
270
def collate_fn(data):
    """Collate per-item sample dicts into one batch dict.

    Image-like tensors gain a leading batch axis via ``torch.stack``; the
    CLIP tensors (already shaped (1, ...)) are concatenated instead; the
    per-view name lists are flattened and cloth names collected as-is.
    Note: cloth keys are renamed ``*_cloth_*`` -> ``*_ref_*`` in the output.
    """
    def _stack(key):
        # Stack one field across all examples along a new batch dimension.
        return torch.stack([example[key] for example in data])

    img_name = []
    cloth_name = []
    for example in data:
        img_name.extend(example['img_name'])
        cloth_name.append(example['cloth_name'])

    return {
        "pixel_values": _stack("pixel_values"),
        "pixel_values_pose": _stack("pixel_values_pose"),
        "pixel_values_agnostic": _stack("pixel_values_agnostic"),
        "clip_ref_front": torch.cat([example["clip_ref_front"] for example in data]),
        "clip_ref_back": torch.cat([example["clip_ref_back"] for example in data]),
        "pixel_values_ref_front": _stack("pixel_values_cloth_front"),
        "pixel_values_ref_back": _stack("pixel_values_cloth_back"),
        "camera_parm": _stack("camera_parm"),
        "drop_image_embeds": _stack("drop_image_embeds"),
        "img_name": img_name,
        "cloth_name": cloth_name,
    }
301
+
302
+
303
if __name__ == '__main__':
    # Manual smoke test: build the dataset, load one batch through the
    # DataLoader, and dump several tensors back to JPEG files for visual
    # inspection. Exits after the first batch.
    seed = 20
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # NOTE(review): hard-coded cluster paths; only runnable on that machine.
    dataset = MVHumanNet_Dataset(dataroot="/GPUFS/sysu_gbli2_1/hzj/mvhumannet/",
        sample_size=(768,576),is_train=True,mode='pair',
        clip_model_path = "/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32")

    # print(len(dataset))

    # for _ in range(500):

    # p = random.randint(0,len(dataset)-1)
    # p = dataset[p]

    test_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=1,
        num_workers=2,
    )

    for _, batch in enumerate(test_dataloader):
        # print(batch['cloth_name'], batch['img_name'])
        p = {}
        print('111', batch['camera_parm'].shape)
        print('111', batch['drop_image_embeds'].shape)
        # Strip the batch dimension (batch_size=1) from every field.
        for key in batch.keys():
            p[key] = batch[key][0]
        # p = dataset[12]

        print(p['camera_parm'].shape)

        # De-normalize ([-1, 1] -> [0, 255]) and save view 0 of the person image.
        pixel_values = p['pixel_values'][0].permute(1,2,0).numpy()
        print(p['pixel_values'].shape)
        pixel_values = pixel_values / 2 + 0.5
        pixel_values *=255
        pixel_values = pixel_values.astype(np.uint8)
        pixel_values= Image.fromarray(pixel_values)
        pixel_values.save('pixel_values0.jpg')

        # View 0 of the normal-map "pose" image.
        pixel_values_pose = p['pixel_values_pose'][0].permute(1,2,0).numpy()
        print(p['pixel_values_pose'].shape)
        pixel_values_pose = pixel_values_pose / 2 + 0.5
        pixel_values_pose *=255
        pixel_values_pose = pixel_values_pose.astype(np.uint8)
        pixel_values_pose= Image.fromarray(pixel_values_pose)
        pixel_values_pose.save('pixel_values_pose.jpg')

        # View 0 of the cloth-agnostic image.
        pixel_values_agnostic = p['pixel_values_agnostic'][0].permute(1,2,0).numpy()
        print(p['pixel_values_agnostic'].shape)
        pixel_values_agnostic = pixel_values_agnostic / 2 + 0.5
        pixel_values_agnostic *=255
        pixel_values_agnostic = pixel_values_agnostic.astype(np.uint8)
        pixel_values_agnostic= Image.fromarray(pixel_values_agnostic)
        pixel_values_agnostic.save('pixel_values_agnostic.jpg')

        # Same three dumps for view 2.
        pixel_values = p['pixel_values'][2].permute(1,2,0).numpy()
        print(p['pixel_values'].shape)
        pixel_values = pixel_values / 2 + 0.5
        pixel_values *=255
        pixel_values = pixel_values.astype(np.uint8)
        pixel_values= Image.fromarray(pixel_values)
        pixel_values.save('pixel_values2.jpg')

        pixel_values_pose = p['pixel_values_pose'][2].permute(1,2,0).numpy()
        print(p['pixel_values_pose'].shape)
        pixel_values_pose = pixel_values_pose / 2 + 0.5
        pixel_values_pose *=255
        pixel_values_pose = pixel_values_pose.astype(np.uint8)
        pixel_values_pose= Image.fromarray(pixel_values_pose)
        pixel_values_pose.save('pixel_values_pose2.jpg')

        pixel_values_agnostic = p['pixel_values_agnostic'][2].permute(1,2,0).numpy()
        print(p['pixel_values_agnostic'].shape)
        pixel_values_agnostic = pixel_values_agnostic / 2 + 0.5
        pixel_values_agnostic *=255
        pixel_values_agnostic = pixel_values_agnostic.astype(np.uint8)
        pixel_values_agnostic= Image.fromarray(pixel_values_agnostic)
        pixel_values_agnostic.save('pixel_values_agnostic2.jpg')

        # Front and back cloth reference images.
        pixel_values_cloth_img = p['pixel_values_ref_front'].permute(1,2,0).numpy()
        print(p['pixel_values_ref_front'].shape)
        pixel_values_cloth_img = pixel_values_cloth_img / 2 + 0.5
        pixel_values_cloth_img *=255
        pixel_values_cloth_img = pixel_values_cloth_img.astype(np.uint8)
        pixel_values_cloth_img= Image.fromarray(pixel_values_cloth_img)
        pixel_values_cloth_img.save('pixel_values_cloth_front.jpg')

        pixel_values_cloth_img = p['pixel_values_ref_back'].permute(1,2,0).numpy()
        print(p['pixel_values_ref_back'].shape)
        pixel_values_cloth_img = pixel_values_cloth_img / 2 + 0.5
        pixel_values_cloth_img *=255
        pixel_values_cloth_img = pixel_values_cloth_img.astype(np.uint8)
        pixel_values_cloth_img= Image.fromarray(pixel_values_cloth_img)
        pixel_values_cloth_img.save('pixel_values_cloth_back.jpg')
        # Stop after the first batch.
        exit()
402
+
403
+
src/multiview_consist_edit/Thuman2_multi.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, csv, math, random
2
+ import numpy as np
3
+ from PIL import Image,ImageDraw
4
+ import json
5
+ import torch
6
+ import torchvision
7
+ import torchvision.transforms as transforms
8
+ from torch.utils.data.dataset import Dataset
9
+ from transformers import CLIPProcessor
10
+ import random
11
+ from torchvision.transforms import functional as F
12
+ import torch.distributed as dist
13
+ import copy
14
+ import cv2
15
+
16
def crop_image(human_img_orig):
    """Resize to 1024x1024, then center-crop the width down to 768 px.

    The full height is kept; only equal left/right margins are trimmed,
    yielding a 768x1024 portrait crop.
    """
    resized = human_img_orig.resize((1024, 1024))
    width, height = resized.size
    margin = (width - 768) // 2
    return resized.crop((margin, 0, width - margin, height))
27
+
28
class Thuman2_Dataset(Dataset):
    """Multi-view garment try-on dataset over THuman2 renders.

    Each item bundles several rendered views of one subject -- RGB image,
    normal map, cloth-agnostic image and flattened camera rotation -- with
    front/back reference photos of a garment, pre-processed both as
    normalized tensors and as CLIP-processor inputs.

    Args:
        dataroot: root folder holding ``all/`` (per-subject renders),
            ``cloth/`` (garment references) and the ``*_ids.txt`` split files.
        sample_size: (height, width) of the returned image tensors.
        is_train: if True, use the train split with random view sampling and
            reference augmentation; otherwise use the test splits with a
            deterministic view selection.
        mode: pairing-mode tag (stored; not otherwise used by this class).
        clip_model_path: local path of the CLIP processor checkpoint.
        multi_length: number of views sampled per item during training.
    """

    def __init__(
        self, dataroot, sample_size=(512,384), is_train=True, mode='pair', clip_model_path='', multi_length=8,
    ):
        c_names_front = []
        c_names_back = []

        self.data_ids = []
        self.dataroot = os.path.join(dataroot, 'all')
        self.cloth_root = os.path.join(dataroot, 'cloth')
        # self.cloth_root = os.path.join(dataroot, 'MVG_clothes')

        self.cloth_ids = []
        if is_train:
            f = open(os.path.join(dataroot,'train_ids.txt'))
            for line in f.readlines():
                self.data_ids.append(line.strip())
            f.close()
        else:
            # At test time every subject is paired with every test garment.
            # f = open(os.path.join(dataroot, 'val_ids.txt'))
            f = open(os.path.join(dataroot, 'test_ids.txt'))
            # f = open(os.path.join(dataroot, 'test_mvg_ids.txt'))
            for line in f.readlines():
                self.data_ids.append(line.strip())
            f.close()
            f2 = open(os.path.join(dataroot, 'test_cloth_ids.txt'))
            # f2 = open(os.path.join(dataroot, 'test_mvg_cloth_ids.txt'))
            for line in f2.readlines():
                self.cloth_ids.append(line.strip())
            f2.close()

        self.mode = mode
        self.is_train = is_train
        self.sample_size = sample_size
        self.multi_length = multi_length
        self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path,local_files_only=True)

        # Main branch: upscale to 1024x768, keep the central 6/8 region,
        # resize to sample_size and normalize to [-1, 1].
        self.pixel_transforms = transforms.Compose([
            transforms.Resize((1024,768), interpolation=0),
            transforms.CenterCrop((int(1024 * 6/8), int(768 * 6/8))),
            transforms.Resize(self.sample_size, interpolation=0),
            # transforms.CenterCrop(self.sample_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.pixel_transforms_0 = transforms.Compose([
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.pixel_transforms_1 = transforms.Compose([
            transforms.Resize((1024,768), interpolation=0),
            transforms.CenterCrop((int(1024 * 6/8), int(768 * 6/8))),
            transforms.Resize(self.sample_size, interpolation=0),
        ])

        # Garment-reference branch: light geometric jitter only in training.
        self.ref_transforms_train = transforms.Compose([
            transforms.Resize(self.sample_size),
            # RandomScaleResize([1.0,1.1]),
            transforms.CenterCrop(self.sample_size),
            transforms.RandomAffine(degrees=0, translate=(0.08,0.08),scale=(0.9,1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.ref_transforms_test = transforms.Compose([
            transforms.Resize(self.sample_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.color_transform = transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.0)

    def __len__(self):
        # With a cloth list present (test time), every subject x garment
        # combination is one dataset item.
        if len(self.cloth_ids) >= 1:
            return len(self.data_ids)*len(self.cloth_ids)
        else:
            return len(self.data_ids)

    def __getitem__(self, idx):
        """Resolve subject/garment ids for *idx*, pick views, and load them."""

        if len(self.cloth_ids) >=1:
            data_idx = idx // len(self.cloth_ids)
            cloth_idx = idx % len(self.cloth_ids)

            data_id = self.data_ids[data_idx]
            cloth_id = self.cloth_ids[cloth_idx]
            # NOTE(review): the '*_front.jpg' file is assigned to the "back"
            # reference and vice versa; the sibling MVHumanNet loader marks
            # the same swap as intentional ("actually reversed") -- confirm.
            cloth_name_back = os.path.join(self.cloth_root, '%s_front.jpg' % cloth_id)
            cloth_name_front = os.path.join(self.cloth_root, '%s_back.jpg' % cloth_id)
        else:
            data_id = self.data_ids[idx]
            cloth_name_back = os.path.join(self.cloth_root, '%s_front.jpg' % data_id)
            cloth_name_front = os.path.join(self.cloth_root, '%s_back.jpg' % data_id)

        images_root = os.path.join(self.dataroot, data_id, 'agnostic') # need only val
        images = sorted(os.listdir(images_root))

        # cloth_name_back = '0001_front.jpg'
        # cloth_name_front = '0001_back.jpg'

        if self.is_train:
            select_images = random.sample(images, self.multi_length)

        else:
            # Deterministically pick ~16 views, alternating indices from the
            # two ends of the sorted view list.
            # select_idxs = [0,3,6,9,12, 15,18,21,24,27, 79,76,73,70,67,64]
            L = len(images)
            select_idxs = []
            begin = 0
            sl = 16.0
            if True:
                while begin < L//2:
                    select_idxs.append(int(begin/2))
                    select_idxs.append(int(L-1-begin/2))
                    begin += L/sl
            else:
                begin = L//4
                while begin < L*3//4:
                    select_idxs.append(int(begin))
                    begin += L/2/sl
            # print(sorted(select_idxs))
            # select_idxs = [0,3,6,9,12, 15,18,21,24,27, L-1,L-4,L-7,L-10,L-13,L-16]
            select_images = []
            for select_idx in select_idxs:
                select_images.append(images[select_idx])
            select_images = sorted(select_images)
            # print(select_images)
        # View file names were listed from 'agnostic'; the actual RGB images
        # live under the sibling 'images' folder.
        for i in range(len(select_images)):
            select_images[i] = os.path.join(data_id,'images',select_images[i])
        sample = self.load_images(select_images, cloth_name_front, cloth_name_back)
        return sample

    def color_progress(self, images):
        """Apply one shared color jitter to every view and return the result.

        Fixes the original version, which was missing ``self``, referenced an
        undefined ``color_jitter`` and discarded the adjusted images. A single
        parameter draw is shared across views so they stay photometrically
        consistent.
        """
        fn_idx, b, c, s, h = self.color_transform.get_params(
            self.color_transform.brightness, self.color_transform.contrast,
            self.color_transform.saturation, self.color_transform.hue)
        jittered = []
        for image in images:
            image = F.adjust_contrast(image, c)
            image = F.adjust_brightness(image, b)
            image = F.adjust_saturation(image, s)
            jittered.append(image)
        return jittered

    def load_images(self, select_images, cloth_name_front, cloth_name_back):
        """Load, crop and normalize the selected views plus garment refs.

        Returns a dict of stacked per-view tensors, CLIP inputs for the two
        garment references, per-view camera rotations (3x3, flattened) and
        random CFG-dropout flags (10% chance per view).
        """

        pixel_values_list = []
        pixel_values_pose_list = []
        camera_parm_list = []
        pixel_values_agnostic_list = []
        image_name_list = []

        # load person data
        for img_name in select_images:
            image_name_list.append(img_name)
            pixel_values = Image.open(os.path.join(self.dataroot, img_name))
            pixel_values_pose = Image.open(os.path.join(self.dataroot, img_name).replace('images', 'normals'))
            # parse_lip = Image.open(os.path.join(parse_lip_dir, img_name))
            pixel_values_agnostic = Image.open(os.path.join(self.dataroot, img_name).replace('images', 'agnostic'))
            parm_matrix = np.load(os.path.join(self.dataroot, img_name[:4],'parm', img_name[-7:-4]+'_extrinsic.npy'))
            pixel_values = crop_image(pixel_values)
            pixel_values_pose = crop_image(pixel_values_pose)
            # camera parameter: keep only the rotation block, flattened to 9
            parm_matrix = torch.tensor(parm_matrix)
            camera_parm = parm_matrix[:3,:3].reshape(-1) # todo
            # transform
            pixel_values = self.pixel_transforms(pixel_values)
            pixel_values_pose = self.pixel_transforms(pixel_values_pose)
            pixel_values_agnostic = self.pixel_transforms(pixel_values_agnostic)

            pixel_values_list.append(pixel_values)
            pixel_values_pose_list.append(pixel_values_pose)
            camera_parm_list.append(camera_parm)
            pixel_values_agnostic_list.append(pixel_values_agnostic)

        pixel_values = torch.stack(pixel_values_list)
        pixel_values_pose = torch.stack(pixel_values_pose_list)
        camera_parm = torch.stack(camera_parm_list)
        pixel_values_agnostic = torch.stack(pixel_values_agnostic_list)

        pixel_values_cloth_front = Image.open(os.path.join(self.cloth_root, cloth_name_front))
        pixel_values_cloth_back = Image.open(os.path.join(self.cloth_root, cloth_name_back))

        # clip
        clip_ref_front = self.clip_image_processor(images=pixel_values_cloth_front, return_tensors="pt").pixel_values
        clip_ref_back = self.clip_image_processor(images=pixel_values_cloth_back, return_tensors="pt").pixel_values

        if self.is_train:
            pixel_values_cloth_front = self.ref_transforms_train(pixel_values_cloth_front)
            pixel_values_cloth_back = self.ref_transforms_train(pixel_values_cloth_back)
        else:
            pixel_values_cloth_front = self.ref_transforms_test(pixel_values_cloth_front)
            pixel_values_cloth_back = self.ref_transforms_test(pixel_values_cloth_back)

        # Per-view classifier-free-guidance dropout flags.
        drop_image_embeds = []
        for k in range(len(select_images)):
            if random.random() < 0.1:
                drop_image_embeds.append(torch.tensor(1))
            else:
                drop_image_embeds.append(torch.tensor(0))
        drop_image_embeds = torch.stack(drop_image_embeds)
        sample = dict(
            pixel_values=pixel_values,
            pixel_values_pose=pixel_values_pose,
            pixel_values_agnostic=pixel_values_agnostic,
            clip_ref_front=clip_ref_front,
            clip_ref_back=clip_ref_back,
            pixel_values_cloth_front=pixel_values_cloth_front,
            pixel_values_cloth_back=pixel_values_cloth_back,
            camera_parm=camera_parm,
            drop_image_embeds=drop_image_embeds,
            img_name=image_name_list,
            cloth_name=cloth_name_front,
        )

        return sample
236
+
237
def collate_fn(data):
    """Collate Thuman2_Dataset samples into one batch dictionary.

    Tensor fields gain a leading batch dimension via ``torch.stack``; the
    CLIP crops already carry a leading dim of 1, so they are concatenated
    instead. Per-view image names are flattened into a single list, and the
    cloth-reference tensors are renamed to the ``pixel_values_ref_*`` keys
    the training loop expects.
    """
    def _stack(key):
        return torch.stack([example[key] for example in data])

    img_names, cloth_names = [], []
    for example in data:
        img_names.extend(example["img_name"])
        cloth_names.append(example["cloth_name"])

    return {
        "pixel_values": _stack("pixel_values"),
        "pixel_values_pose": _stack("pixel_values_pose"),
        "pixel_values_agnostic": _stack("pixel_values_agnostic"),
        "clip_ref_front": torch.cat([example["clip_ref_front"] for example in data]),
        "clip_ref_back": torch.cat([example["clip_ref_back"] for example in data]),
        "pixel_values_ref_front": _stack("pixel_values_cloth_front"),
        "pixel_values_ref_back": _stack("pixel_values_cloth_back"),
        "camera_parm": _stack("camera_parm"),
        "drop_image_embeds": _stack("drop_image_embeds"),
        "img_name": img_names,
        "cloth_name": cloth_names,
    }
268
+
269
+
270
if __name__ == '__main__':
    # Smoke test: build the test split, pull one batch through the
    # DataLoader, and dump selected tensors as JPEGs for visual inspection.
    # The eight identical denormalize-and-save snippets from the original
    # are factored into one helper.

    def save_chw_tensor(tensor, path):
        # Undo the [-1, 1] normalization and write a CHW tensor as a JPEG.
        arr = tensor.permute(1, 2, 0).numpy()
        arr = (arr / 2 + 0.5) * 255
        Image.fromarray(arr.astype(np.uint8)).save(path)

    seed = 20
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = Thuman2_Dataset(
        dataroot="/GPUFS/sysu_gbli2_1/hzj/save_render_data_yw/",
        sample_size=(768, 576), is_train=False, mode='pair',
        clip_model_path="/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32")

    test_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=2,
        num_workers=1,
    )

    for _, batch in enumerate(test_dataloader):
        print('111', batch['camera_parm'].shape)
        print('111', batch['drop_image_embeds'].shape)
        # Keep only the first sample of the batch.
        p = {key: batch[key][0] for key in batch.keys()}

        print(p['camera_parm'].shape)

        # Per-view tensors: dump view 0 and view 2.
        print(p['pixel_values'].shape)
        save_chw_tensor(p['pixel_values'][0], 'pixel_values0.jpg')
        print(p['pixel_values_pose'].shape)
        save_chw_tensor(p['pixel_values_pose'][0], 'pixel_values_pose.jpg')
        print(p['pixel_values_agnostic'].shape)
        save_chw_tensor(p['pixel_values_agnostic'][0], 'pixel_values_agnostic.jpg')

        print(p['pixel_values'].shape)
        save_chw_tensor(p['pixel_values'][2], 'pixel_values2.jpg')
        print(p['pixel_values_pose'].shape)
        save_chw_tensor(p['pixel_values_pose'][2], 'pixel_values_pose2.jpg')
        print(p['pixel_values_agnostic'].shape)
        save_chw_tensor(p['pixel_values_agnostic'][2], 'pixel_values_agnostic2.jpg')

        # Garment references (one image per sample, not per view).
        print(p['pixel_values_ref_front'].shape)
        save_chw_tensor(p['pixel_values_ref_front'], 'pixel_values_cloth_front.jpg')
        print(p['pixel_values_ref_back'].shape)
        save_chw_tensor(p['pixel_values_ref_back'], 'pixel_values_cloth_back.jpg')

        # Only the first batch is needed for the smoke test (the original
        # called exit() here; break exits the process just as cleanly).
        break
src/multiview_consist_edit/config/infer_tryon_multi.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 42
2
+
3
+ model_path: "stable-diffusion-v1-5/stable-diffusion-v1-5"
4
+ vae_path: "stabilityai/sd-vae-ft-mse"
5
+ clip_model_path: 'openai/clip-vit-base-patch32'
6
+
7
+ # unet_path: "/GPUFS/sysu_gbli2_1/hzj/animate/checkpoints/thuman_tryon_mvattn_multi_1205/checkpoint-30000"
8
+ # pretrained_poseguider_path: "/GPUFS/sysu_gbli2_1/hzj/animate/checkpoints/thuman_tryon_mvattn_multi_1205/checkpoint-30000/pose.ckpt"
9
+ # pretrained_referencenet_path: '/GPUFS/sysu_gbli2_1/hzj/animate/checkpoints/thuman_tryon_mvattn_multi_1205/checkpoint-30000'
10
+
11
+ unet_path: "./checkpoints/mvhumannet_tryon_mvattn_multi/checkpoint-40000"
12
+ pretrained_poseguider_path: "./checkpoints/mvhumannet_tryon_mvattn_multi/checkpoint-40000/pose.ckpt"
13
+ pretrained_referencenet_path: './checkpoints/mvhumannet_tryon_mvattn_multi/checkpoint-40000'
14
+
15
+ out_dir: 'image_output_tryon_mvhumannet'
16
+
17
+ batch_size: 2
18
+ dataloader_num_workers: 4
19
+ guidance_scale: 2 # thuman:3 mvhumannet:2
20
+
21
+
22
+ # infer_data:
23
+ # # dataroot: "/GPUFS/sysu_gbli2_1/hzj/render_data"
24
+ # dataroot: "/GPUFS/sysu_gbli2_1/hzj/save_render_data_yw/"
25
+ # # sample_size: [512,384] # for 40G 256
26
+ # sample_size: [768,576]
27
+ # clip_model_path: '/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32'
28
+ # is_train: false
29
+ # mode: 'pair'
30
+ # output_front: true
31
+
32
+ infer_data:
33
+ # dataroot: "/GPUFS/sysu_gbli2_1/hzj/render_data"
34
+ dataroot: "../../demo_data/mvhumannet_2D_edit/"
35
+ # sample_size: [512,384] # for 40G 256
36
+ sample_size: [768,576]
37
+ clip_model_path: 'openai/clip-vit-base-patch32'
38
+ is_train: false
39
+ mode: 'pair'
40
+ output_front: true
41
+
42
+ fusion_blocks: "full"
43
+ image_finetune: true
44
+ num_inference_steps: 30
src/multiview_consist_edit/config/train_tryon_multi.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_finetune: true
2
+ from_scratch: false
3
+
4
+ output_dir: "mvhumannet_tryon_mvattn_multi_1205"
5
+ # output_dir: "mvhumannet_tryon_exp_multi_1028"
6
+ logging_dir: "log"
7
+ # pretrained_model_path: "/data1/hezijian/pretrained_models/stable-diffusion-v1-5"
8
+ # pretrained_vae_path: "/data1/hezijian/pretrained_models/sd-vae-ft-mse"
9
+ # pretrained_clip_path: '/data1/hezijian/pretrained_models/clip-vit-base-patch32'
10
+ # clip_model_path: '/data1/hezijian/pretrained_models/clip-vit-base-patch32'
11
+ pretrained_model_path: "/GPUFS/sysu_gbli2_1/hzj/pretrained_models/stable-diffusion-v1-5"
12
+ pretrained_vae_path: "/GPUFS/sysu_gbli2_1/hzj/pretrained_models/sd-vae-ft-mse"
13
+ pretrained_clip_path: '/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32'
14
+ clip_model_path: '/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32'
15
+ controlnet_model_name_or_path: null
16
+
17
+ # trained stage1 model
18
+ trained_unet_path: "checkpoints/thuman_tryon_exp_1015_two/checkpoint-120000"
19
+ trained_referencenet_path: "checkpoints/thuman_tryon_exp_1015_two/checkpoint-120000"
20
+ trained_pose_guider_path: 'checkpoints/thuman_tryon_exp_1015_two/checkpoint-120000/pose.ckpt'
21
+ # trained_unet_path: "thuman_tryon_exp_1015_two/checkpoint-60000"
22
+ # trained_referencenet_path: "thuman_tryon_exp_1015_two/checkpoint-60000"
23
+ # trained_pose_guider_path: 'thuman_tryon_exp_1015_two/checkpoint-60000/pose.ckpt'
24
+
25
+ unet_additional_kwargs:
26
+ use_motion_module : false
27
+ motion_module_resolutions : [ 1,2,4,8 ]
28
+ unet_use_cross_frame_attention : false
29
+ unet_use_temporal_attention : false
30
+
31
+ motion_module_type: Vanilla
32
+ motion_module_kwargs:
33
+ num_attention_heads : 8
34
+ num_transformer_block : 1
35
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
36
+ temporal_position_encoding : true
37
+ temporal_position_encoding_max_len : 24
38
+ temporal_attention_dim_div : 1
39
+ zero_initialize : true
40
+ encoder_hid_dim: 1280
41
+ encoder_hid_dim_type: 'text_proj'
42
+
43
+ noise_scheduler_kwargs:
44
+ num_train_timesteps: 1000
45
+ beta_start: 0.00085
46
+ beta_end: 0.012
47
+ beta_schedule: "linear"
48
+ steps_offset: 1
49
+ clip_sample: false
50
+
51
+ train_data:
52
+ # dataroot: "/GPUFS/sysu_gbli2_1/hzj/render_data"
53
+ dataroot: "/GPUFS/sysu_gbli2_1/hzj/mvhumannet/"
54
+ # sample_size: [512,384] # for 40G 256
55
+ sample_size: [768,576]
56
+ clip_model_path: '/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32'
57
+ is_train: true
58
+ mode: 'pair'
59
+
60
+ # train_data:
61
+ # # dataroot: "/GPUFS/sysu_gbli2_1/hzj/render_data"
62
+ # dataroot: "/GPUFS/sysu_gbli2_1/hzj/save_render_data_yw/"
63
+ # # sample_size: [512,384] # for 40G 256
64
+ # sample_size: [768,576] # for 40G 256
65
+ # clip_model_path: '/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32'
66
+ # is_train: true
67
+ # mode: 'pair'
68
+
69
+ # train_data:
70
+ # # csv_path: "./data/UBC_train_info_test.csv"
71
+ # csv_path: "./data/TikTok_info.csv"
72
+ # video_folder: "../TikTok_dataset2/TikTok_dataset"
73
+ # sample_size: 512 # for 40G 256
74
+ # sample_stride: 4
75
+ # sample_n_frames: 8
76
+ # clip_model_path: 'pretrained_models/clip-vit-base-patch32'
77
+
78
+ # train_data:
79
+ # # csv_path: "./data/UBC_train_info_test.csv"
80
+ # csv_path: "./data/UBC_train_info.csv"
81
+ # video_folder: "../UBC_dataset"
82
+ # sample_size: 512 # for 40G 256
83
+ # sample_stride: 4
84
+ # sample_n_frames: 8
85
+ # clip_model_path: 'pretrained_models/clip-vit-base-patch32'
86
+
87
+ validation_data:
88
+ prompts:
89
+ - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
90
+ - "A drone view of celebration with Christmas tree and fireworks, starry sky - background."
91
+ - "Robot dancing in times square."
92
+ - "Pacific coast, carmel by the sea ocean and waves."
93
+ num_inference_steps: 25
94
+ guidance_scale: 8.
95
+
96
+ trainable_modules:
97
+ # - "motion_modules."
98
+ - "."
99
+ # - "conv_in"
100
+
101
+ fusion_blocks: "full"
102
+
103
+ unet_checkpoint_path: ""
104
+
105
+ scale_lr: false
106
+ adam_beta1: 0.9
107
+ adam_beta2: 0.999
108
+ adam_weight_decay: 1.e-2
109
+ adam_epsilon: 1.e-08
110
+ learning_rate: 2.e-5
111
+ train_batch_size: 1
112
+ gradient_accumulation_steps: 2
113
+ max_grad_norm: 1.0
114
+
115
+ lr_scheduler: 'constant'
116
+ lr_warmup_steps: 0
117
+
118
+ num_train_epochs: 10000
119
+ max_train_steps: null
120
+ checkpointing_steps: 2000
121
+
122
+ validation_steps: 5000
123
+ validation_steps_tuple: [2, 50]
124
+
125
+ seed: 42
126
+ mixed_precision_training: true
127
+ enable_xformers_memory_efficient_attention: True
128
+
129
+ is_debug: False
130
+
131
+ checkpoints_total_limit: 10
132
+ mixed_precision: "fp16"
133
+ report_to: "tensorboard"
134
+ allow_tf32: true
135
+ resume_from_checkpoint: 'latest'
136
+ # resume_from_checkpoint: null
137
+ dataloader_num_workers: 8
src/multiview_consist_edit/data/MVHumanNet_multi.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, csv, math, random
2
+ import numpy as np
3
+ from PIL import Image,ImageDraw
4
+ import json
5
+ import torch
6
+ import torchvision
7
+ import torchvision.transforms as transforms
8
+ from torch.utils.data.dataset import Dataset
9
+ from transformers import CLIPProcessor
10
+ import random
11
+ from torchvision.transforms import functional as F
12
+ import torch.distributed as dist
13
+ import copy
14
+ import cv2
15
+ import pickle
16
+ from .camera_utils import read_camera_mvhumannet
17
+
18
def crop_and_resize(img, bbox, size):
    """Crop *img* to a 2:3 (width:height) box centred on *bbox*, then resize.

    The crop keeps the bbox height and derives the width from the fixed 2/3
    aspect ratio so the subject stays centred in a portrait frame; the crop
    is finally resized to *size*.
    """
    # Centre of the detection box and the target crop extents.
    cx = (bbox[0] + bbox[2]) / 2
    cy = (bbox[1] + bbox[3]) / 2
    crop_h = bbox[3] - bbox[1]
    crop_w = int(crop_h * (2 / 3))

    # Axis-aligned crop window around the centre (may extend past the image;
    # PIL pads such regions).
    window = [
        int(cx - crop_w / 2),
        int(cy - crop_h / 2),
        int(cx + crop_w / 2),
        int(cy + crop_h / 2)
    ]

    return img.crop(window).resize(size)
41
+
42
+
43
+ class MVHumanNet_Dataset(Dataset):
44
+ def __init__(
45
+ self, dataroot, sample_size=(512,384), is_train=True, mode='pair', clip_model_path='', multi_length=8, output_front=True,
46
+ ):
47
+ im_names = []
48
+ self.dataroot = os.path.join(dataroot, 'processed_mvhumannet')
49
+ self.cloth_root = os.path.join(dataroot, 'cloth')
50
+ self.data_ids = []
51
+ self.data_frame_ids = []
52
+ self.cloth_ids = []
53
+ self.cloth_frame_ids = []
54
+ if is_train:
55
+ f = open(os.path.join(dataroot,'train_frame_ids.txt'))
56
+ for line in f.readlines():
57
+ line_info = line.strip().split()
58
+ self.data_ids.append(line_info[0])
59
+ self.data_frame_ids.append(line_info[1])
60
+ f.close()
61
+ else:
62
+ f = open(os.path.join(dataroot, 'test_ids.txt'))
63
+ for line in f.readlines():
64
+ line_info = line.strip().split()
65
+ self.data_ids.append(line_info[0])
66
+ self.data_frame_ids.append(line_info[1])
67
+ f.close()
68
+ f2 = open(os.path.join(dataroot, 'test_cloth_ids.txt'))
69
+ # f2 = open(os.path.join(dataroot, 'test_mvg_cloth_ids.txt'))
70
+ for line in f2.readlines():
71
+ line_info = line.strip().split()
72
+ self.cloth_ids.append(line_info[0])
73
+ self.cloth_frame_ids.append(line_info[1])
74
+ f2.close()
75
+
76
+ self.is_train = is_train
77
+ self.sample_size = sample_size
78
+ self.multi_length = multi_length
79
+ self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path,local_files_only=False)
80
+
81
+ self.pixel_transforms = transforms.Compose([
82
+ #transforms.Resize((1024,768), interpolation=0),
83
+ #transforms.CenterCrop((int(1024 * 6/8), int(768 * 6/8))),
84
+ transforms.Resize(self.sample_size, interpolation=0),
85
+ # transforms.CenterCrop(self.sample_size),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
88
+ ])
89
+
90
+ self.pixel_transforms_0 = transforms.Compose([
91
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
92
+ ])
93
+ self.pixel_transforms_1 = transforms.Compose([
94
+ # transforms.Resize((1024,768), interpolation=0),
95
+ # transforms.CenterCrop((int(1024 * 6/8), int(768 * 6/8))),
96
+ transforms.Resize(self.sample_size, interpolation=0),
97
+ ])
98
+
99
+ self.ref_transforms_train = transforms.Compose([
100
+ transforms.Resize(self.sample_size),
101
+ # RandomScaleResize([1.0,1.1]),
102
+ # transforms.CenterCrop(self.sample_size),
103
+ transforms.RandomAffine(degrees=0, translate=(0.08,0.08),scale=(0.9,1.1)),
104
+ transforms.ToTensor(),
105
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
106
+ ])
107
+ self.ref_transforms_test = transforms.Compose([
108
+ transforms.Resize(self.sample_size),
109
+ transforms.ToTensor(),
110
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
111
+ ])
112
+ self.output_front = True
113
+
114
+ def __len__(self):
115
+ if len(self.cloth_ids) >= 1:
116
+ return len(self.data_ids)*len(self.cloth_ids)
117
+ else:
118
+ return len(self.data_ids)
119
+
120
+ def __getitem__(self, idx):
121
+
122
+ if len(self.cloth_ids) >=1:
123
+ data_idx = idx // len(self.cloth_ids)
124
+ cloth_idx = idx % len(self.cloth_ids)
125
+
126
+ data_id = self.data_ids[data_idx]
127
+ frame_id = self.data_frame_ids[data_idx]
128
+ cloth_id = self.cloth_ids[cloth_idx]
129
+ cloth_frame_id = self.cloth_frame_ids[cloth_idx]
130
+ cloth_name_front = os.path.join(self.cloth_root, '%s_%s_front.jpg' % (cloth_id, cloth_frame_id)) # 实际是反的
131
+ cloth_name_back = os.path.join(self.cloth_root, '%s_%s_back.jpg' % (cloth_id, cloth_frame_id))
132
+ else:
133
+ data_id = self.data_ids[idx]
134
+ frame_id = self.data_frame_ids[idx]
135
+ cloth_name_front = os.path.join(self.cloth_root, '%s_%s_front.jpg' % (data_id, frame_id)) # 实际是反的
136
+ cloth_name_back = os.path.join(self.cloth_root, '%s_%s_back.jpg' % (data_id, frame_id))
137
+
138
+ # cloth_name_front = os.path.join(self.cloth_root, '%s_%s_front.jpg' % ('100030', '0540'))
139
+ # cloth_name_back = os.path.join(self.cloth_root, '%s_%s_back.jpg' % ('100030', '0540'))
140
+
141
+ images_root = os.path.join(self.dataroot, data_id, 'agnostic', frame_id)
142
+ images = sorted(os.listdir(images_root))
143
+
144
+ if self.is_train:
145
+ check_images = []
146
+ for image in images:
147
+ if 'CC32871A015' not in image:
148
+ check_images.append(image)
149
+ select_images = random.sample(check_images, self.multi_length)
150
+
151
+ else:
152
+ # front
153
+ front_cameras = [
154
+ 'CC32871A005','CC32871A016','CC32871A017','CC32871A023','CC32871A027',
155
+ 'CC32871A030','CC32871A032','CC32871A033','CC32871A034','CC32871A035',
156
+ 'CC32871A038','CC32871A050','CC32871A051','CC32871A052','CC32871A059', 'CC32871A060'
157
+ ]
158
+ back_cameras = [
159
+ 'CC32871A004','CC32871A010', 'CC32871A013', 'CC32871A022', 'CC32871A029',
160
+ 'CC32871A031','CC32871A037', 'CC32871A039', 'CC32871A040', 'CC32871A044',
161
+ 'CC32871A046','CC32871A048', 'CC32871A055', 'CC32871A057', 'CC32871A058', 'CC32871A041'
162
+ ]
163
+ select_images = []
164
+ for image in images:
165
+ camera_id = image.split('_')[0]
166
+ if camera_id in front_cameras and self.output_front:
167
+ select_images.append(image)
168
+ if camera_id in back_cameras and not self.output_front:
169
+ select_images.append(image)
170
+ select_images = sorted(select_images)
171
+ # print(select_images)
172
+ for i in range(len(select_images)):
173
+ select_images[i] = os.path.join(data_id,'resized_img', frame_id, select_images[i])
174
+ sample = self.load_images(select_images, data_id, cloth_name_front, cloth_name_back)
175
+ return sample
176
+
177
+ def load_images(self, select_images, data_id, cloth_name_front, cloth_name_back):
178
+
179
+ pixel_values_list = []
180
+ pixel_values_pose_list = []
181
+ camera_parm_list = []
182
+ pixel_values_agnostic_list = []
183
+ image_name_list = []
184
+
185
+ # load camera info
186
+ intri_name = os.path.join(self.dataroot, data_id, 'camera_intrinsics.json')
187
+ extri_name = os.path.join(self.dataroot, data_id, 'camera_extrinsics.json')
188
+ camera_scale_fn = os.path.join(self.dataroot, data_id, 'camera_scale.pkl')
189
+ camera_scale = pickle.load(open(camera_scale_fn, "rb"))
190
+ cameras_gt = read_camera_mvhumannet(intri_name, extri_name, camera_scale)
191
+
192
+ # load person data
193
+ for img_name in select_images:
194
+ camera_id = img_name.split('\\')[-1].split('_')[0]
195
+
196
+ # load data
197
+ image_name_list.append(img_name)
198
+ pixel_values = Image.open(os.path.join(self.dataroot, img_name))
199
+ pixel_values_pose = Image.open(os.path.join(self.dataroot, img_name).replace('resized_img', 'normals').replace('.jpg','_normal.jpg'))
200
+ pixel_values_agnostic = Image.open(os.path.join(self.dataroot, img_name).replace('resized_img', 'agnostic'))
201
+ parm_matrix = cameras_gt[camera_id]['RT'] # extrinsic
202
+
203
+ # crop pose
204
+ annot_path = os.path.join(self.dataroot, img_name.replace('resized_img', 'annots').replace('.jpg','.json'))
205
+ annot_info = json.load(open(annot_path))
206
+ bbox = annot_info['annots'][0]['bbox']
207
+ width = annot_info['width']
208
+ if width == 4096 or width == 2448:
209
+ for i in range(4):
210
+ bbox[i] = bbox[i] // 2
211
+ elif width == 2048:
212
+ pass
213
+ else:
214
+ print('wrong annot size',img_path)
215
+ pixel_values_pose = crop_and_resize(pixel_values_pose, bbox, size=self.sample_size)
216
+
217
+ # camera parameter
218
+ parm_matrix = torch.tensor(parm_matrix)
219
+ camera_parm = parm_matrix[:3,:3].reshape(-1) # todo
220
+
221
+ # transform
222
+ pixel_values = self.pixel_transforms(pixel_values)
223
+ pixel_values_pose = self.pixel_transforms(pixel_values_pose)
224
+ pixel_values_agnostic = self.pixel_transforms(pixel_values_agnostic)
225
+
226
+ pixel_values_list.append(pixel_values)
227
+ pixel_values_pose_list.append(pixel_values_pose)
228
+ camera_parm_list.append(camera_parm)
229
+ pixel_values_agnostic_list.append(pixel_values_agnostic)
230
+
231
+ pixel_values = torch.stack(pixel_values_list)
232
+ pixel_values_pose = torch.stack(pixel_values_pose_list)
233
+ camera_parm = torch.stack(camera_parm_list)
234
+ pixel_values_agnostic = torch.stack(pixel_values_agnostic_list)
235
+
236
+ pixel_values_cloth_front = Image.open(cloth_name_front)
237
+ pixel_values_cloth_back = Image.open(cloth_name_back)
238
+
239
+ # clip
240
+ clip_ref_front = self.clip_image_processor(images=pixel_values_cloth_front, return_tensors="pt").pixel_values
241
+ clip_ref_back = self.clip_image_processor(images=pixel_values_cloth_back, return_tensors="pt").pixel_values
242
+
243
+ if self.is_train:
244
+ pixel_values_cloth_front = self.ref_transforms_train(pixel_values_cloth_front)
245
+ pixel_values_cloth_back = self.ref_transforms_train(pixel_values_cloth_back)
246
+ else:
247
+ pixel_values_cloth_front = self.ref_transforms_test(pixel_values_cloth_front)
248
+ pixel_values_cloth_back = self.ref_transforms_test(pixel_values_cloth_back)
249
+
250
+ drop_image_embeds = []
251
+ for k in range(len(select_images)):
252
+ if random.random() < 0.1:
253
+ drop_image_embeds.append(torch.tensor(1))
254
+ else:
255
+ drop_image_embeds.append(torch.tensor(0))
256
+ drop_image_embeds = torch.stack(drop_image_embeds)
257
+ sample = dict(
258
+ pixel_values=pixel_values,
259
+ pixel_values_pose=pixel_values_pose,
260
+ pixel_values_agnostic=pixel_values_agnostic,
261
+ clip_ref_front=clip_ref_front,
262
+ clip_ref_back=clip_ref_back,
263
+ pixel_values_cloth_front=pixel_values_cloth_front,
264
+ pixel_values_cloth_back=pixel_values_cloth_back,
265
+ camera_parm=camera_parm,
266
+ drop_image_embeds=drop_image_embeds,
267
+ img_name=image_name_list,
268
+ cloth_name=cloth_name_front,
269
+ )
270
+
271
+ return sample
272
+
273
def collate_fn(data):
    """Merge a list of dataset samples into one batch dictionary.

    Per-view tensor fields gain a new leading batch dimension via
    ``torch.stack``; the CLIP tensors are concatenated because they already
    carry a leading dim of 1.  ``img_name`` entries are flattened into a
    single list, ``cloth_name`` is collected one-per-sample.
    """
    def _stack(key):
        return torch.stack([sample[key] for sample in data])

    def _cat(key):
        return torch.cat([sample[key] for sample in data])

    img_name = []
    cloth_name = []
    for sample in data:
        img_name.extend(sample['img_name'])
        cloth_name.append(sample['cloth_name'])

    return {
        "pixel_values": _stack("pixel_values"),
        "pixel_values_pose": _stack("pixel_values_pose"),
        "pixel_values_agnostic": _stack("pixel_values_agnostic"),
        "clip_ref_front": _cat("clip_ref_front"),
        "clip_ref_back": _cat("clip_ref_back"),
        "pixel_values_ref_front": _stack("pixel_values_cloth_front"),
        "pixel_values_ref_back": _stack("pixel_values_cloth_back"),
        "camera_parm": _stack("camera_parm"),
        "drop_image_embeds": _stack("drop_image_embeds"),
        "img_name": img_name,
        "cloth_name": cloth_name,
    }
304
+
305
+
306
if __name__ == '__main__':
    # Smoke test: iterate one batch and dump every tensor field as a JPEG.
    seed = 20
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = MVHumanNet_Dataset(dataroot="/GPUFS/sysu_gbli2_1/hzj/mvhumannet/",
                                 sample_size=(768,576), is_train=True, mode='pair',
                                 clip_model_path="/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32")

    test_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=1,
        num_workers=2,
    )

    def denorm_and_save(tensor, out_path):
        """Map a [-1, 1] CHW tensor back to uint8 HWC and write a JPEG."""
        arr = tensor.permute(1, 2, 0).numpy()
        arr = (arr / 2 + 0.5) * 255
        Image.fromarray(arr.astype(np.uint8)).save(out_path)

    for _, batch in enumerate(test_dataloader):
        print('111', batch['camera_parm'].shape)
        print('111', batch['drop_image_embeds'].shape)
        # Inspect the first sample of the batch.
        p = {key: value[0] for key, value in batch.items()}

        print(p['camera_parm'].shape)

        # (field, view index or None for per-sample fields, output file)
        for key, view_idx, out_name in [
            ('pixel_values', 0, 'pixel_values0.jpg'),
            ('pixel_values_pose', 0, 'pixel_values_pose.jpg'),
            ('pixel_values_agnostic', 0, 'pixel_values_agnostic.jpg'),
            ('pixel_values', 2, 'pixel_values2.jpg'),
            ('pixel_values_pose', 2, 'pixel_values_pose2.jpg'),
            ('pixel_values_agnostic', 2, 'pixel_values_agnostic2.jpg'),
            ('pixel_values_ref_front', None, 'pixel_values_cloth_front.jpg'),
            ('pixel_values_ref_back', None, 'pixel_values_cloth_back.jpg'),
        ]:
            print(p[key].shape)
            tensor = p[key] if view_idx is None else p[key][view_idx]
            denorm_and_save(tensor, out_name)
        exit()
405
+
406
+
src/multiview_consist_edit/data/Thuman2_multi.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, csv, math, random
2
+ import numpy as np
3
+ from PIL import Image,ImageDraw
4
+ import json
5
+ import torch
6
+ import torchvision
7
+ import torchvision.transforms as transforms
8
+ from torch.utils.data.dataset import Dataset
9
+ from transformers import CLIPProcessor
10
+ import random
11
+ from torchvision.transforms import functional as F
12
+ import torch.distributed as dist
13
+ import copy
14
+ import cv2
15
+
16
def crop_image(human_img_orig, target_width=768, resize_to=(1024, 1024)):
    """Resize an image and symmetrically crop its width.

    The image is first resized to ``resize_to`` (default 1024x1024), then
    cropped to ``target_width`` pixels centered horizontally while keeping
    the full height.  With the defaults this reproduces the original
    hard-coded 1024x768 behavior.

    Args:
        human_img_orig: input PIL image.
        target_width: width in pixels of the returned image.
        resize_to: (width, height) the image is resized to before cropping.

    Returns:
        The cropped PIL image.
    """
    human_img_orig = human_img_orig.resize(resize_to)
    original_width, original_height = human_img_orig.size
    crop_amount = (original_width - target_width) // 2
    box = (crop_amount, 0, original_width - crop_amount, original_height)
    return human_img_orig.crop(box)
27
+
28
class Thuman2_Dataset(Dataset):
    """Multi-view try-on dataset rendered from THuman2 scans.

    Each sample bundles several camera views of one subject (``multi_length``
    random views at train time, a deterministic fan of views at test time)
    with front/back garment references, per-view normal maps, cloth-agnostic
    person images and camera extrinsics.
    """

    def __init__(
        self, dataroot, sample_size=(512,384), is_train=True, mode='pair', clip_model_path='', multi_length=8, output_front=True,
    ):
        """
        Args:
            dataroot: root directory containing ``all/``, ``cloth/`` and the
                id list text files.
            sample_size: (H, W) of the produced person tensors.
            is_train: selects id lists, view sampling and reference augmentation.
            mode: kept for API compatibility (stored, not used here).
            clip_model_path: local path of the CLIP processor.
            multi_length: number of random views per training sample.
            output_front: at test time, select front-facing views (True) or
                back-facing views (False).
        """
        self.data_ids = []
        self.dataroot = os.path.join(dataroot, 'all')
        self.cloth_root = os.path.join(dataroot, 'cloth')
        # self.cloth_root = os.path.join(dataroot, 'MVG_clothes')

        self.cloth_ids = []
        if is_train:
            with open(os.path.join(dataroot, 'train_ids.txt')) as f:
                for line in f.readlines():
                    self.data_ids.append(line.strip())
        else:
            # f = open(os.path.join(dataroot, 'val_ids.txt'))
            # f = open(os.path.join(dataroot, 'test_mvg_ids.txt'))
            with open(os.path.join(dataroot, 'test_ids.txt')) as f:
                for line in f.readlines():
                    self.data_ids.append(line.strip())
            # f2 = open(os.path.join(dataroot, 'test_mvg_cloth_ids.txt'))
            with open(os.path.join(dataroot, 'test_cloth_ids.txt')) as f2:
                for line in f2.readlines():
                    self.cloth_ids.append(line.strip())

        self.mode = mode
        self.is_train = is_train
        self.sample_size = sample_size
        self.multi_length = multi_length
        self.clip_image_processor = CLIPProcessor.from_pretrained(clip_model_path, local_files_only=True)

        # Person-side pipeline: resize to 1024x768, center-crop 6/8 of it,
        # resize to sample_size and normalize to [-1, 1].
        self.pixel_transforms = transforms.Compose([
            transforms.Resize((1024,768), interpolation=0),
            transforms.CenterCrop((int(1024 * 6/8), int(768 * 6/8))),
            transforms.Resize(self.sample_size, interpolation=0),
            # transforms.CenterCrop(self.sample_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        # Split variants of the pipeline (normalize-only / geometry-only).
        self.pixel_transforms_0 = transforms.Compose([
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.pixel_transforms_1 = transforms.Compose([
            transforms.Resize((1024,768), interpolation=0),
            transforms.CenterCrop((int(1024 * 6/8), int(768 * 6/8))),
            transforms.Resize(self.sample_size, interpolation=0),
        ])

        # Garment reference pipelines: mild geometric jitter at train time.
        self.ref_transforms_train = transforms.Compose([
            transforms.Resize(self.sample_size),
            # RandomScaleResize([1.0,1.1]),
            transforms.CenterCrop(self.sample_size),
            transforms.RandomAffine(degrees=0, translate=(0.08,0.08), scale=(0.9,1.1)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.ref_transforms_test = transforms.Compose([
            transforms.Resize(self.sample_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.color_transform = transforms.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.0)
        # BUGFIX: the constructor argument was silently ignored
        # (hard-coded ``self.output_front = True``).
        self.output_front = output_front

    def __len__(self):
        # In test mode every subject is paired with every test garment.
        if len(self.cloth_ids) >= 1:
            return len(self.data_ids) * len(self.cloth_ids)
        else:
            return len(self.data_ids)

    def __getitem__(self, idx):
        """Return one multi-view sample (see ``load_images`` for the fields)."""
        if len(self.cloth_ids) >= 1:
            # Test mode: decode (subject, garment) from the flat index.
            data_idx = idx // len(self.cloth_ids)
            cloth_idx = idx % len(self.cloth_ids)

            data_id = self.data_ids[data_idx]
            cloth_id = self.cloth_ids[cloth_idx]
            # NOTE(review): the '*_front.jpg' file feeds cloth_name_back and
            # vice versa -- looks intentional (render orientation), verify.
            cloth_name_back = os.path.join(self.cloth_root, '%s_front.jpg' % cloth_id)
            cloth_name_front = os.path.join(self.cloth_root, '%s_back.jpg' % cloth_id)
        else:
            data_id = self.data_ids[idx]
            cloth_name_back = os.path.join(self.cloth_root, '%s_front.jpg' % data_id)
            cloth_name_front = os.path.join(self.cloth_root, '%s_back.jpg' % data_id)

        images_root = os.path.join(self.dataroot, data_id, 'agnostic')  # need only val
        images = sorted(os.listdir(images_root))

        # cloth_name_back = '0001_front.jpg'
        # cloth_name_front = '0001_back.jpg'

        if self.is_train:
            select_images = random.sample(images, self.multi_length)
        else:
            # Deterministically fan out ~16 views over the camera ring,
            # either around the front (interleaving both ends of the sorted
            # list) or around the back (middle of the list).
            L = len(images)
            select_idxs = []
            begin = 0
            sl = 16.0
            if self.output_front:
                while begin < L//2:
                    select_idxs.append(int(begin/2))
                    select_idxs.append(int(L-1-begin/2))
                    begin += L/sl
            else:
                begin = L//4
                while begin < L*3//4:
                    select_idxs.append(int(begin))
                    begin += L/2/sl
            # print(sorted(select_idxs))
            select_images = []
            for select_idx in select_idxs:
                select_images.append(images[select_idx])
            select_images = sorted(select_images)
        # print(select_images)
        for i in range(len(select_images)):
            select_images[i] = os.path.join(data_id, 'images', select_images[i])
        sample = self.load_images(select_images, cloth_name_front, cloth_name_back)
        return sample

    def color_progress(self, images):
        """Apply one shared random color jitter to all images of a sample,
        keeping the views photometrically consistent with each other.

        BUGFIX: the original had no ``self`` parameter, referenced the
        undefined name ``color_jitter`` and discarded the adjusted images,
        returning the inputs unchanged.
        """
        jitter = self.color_transform
        _, b, c, s, _ = jitter.get_params(jitter.brightness, jitter.contrast, jitter.saturation, jitter.hue)
        out = []
        for image in images:
            image = F.adjust_contrast(image, c)
            image = F.adjust_brightness(image, b)
            image = F.adjust_saturation(image, s)
            out.append(image)
        return out

    def load_images(self, select_images, cloth_name_front, cloth_name_back):
        """Load the per-view tensors and garment references for one sample.

        Args:
            select_images: image paths relative to ``self.dataroot`` of the
                selected views ('<id>/images/<view>.jpg').
            cloth_name_front / cloth_name_back: full paths of the garment
                reference images (already joined in ``__getitem__``).
        """
        pixel_values_list = []
        pixel_values_pose_list = []
        camera_parm_list = []
        pixel_values_agnostic_list = []
        image_name_list = []

        # Load per-view person data.
        for img_name in select_images:
            image_name_list.append(img_name)
            pixel_values = Image.open(os.path.join(self.dataroot, img_name))
            pixel_values_pose = Image.open(os.path.join(self.dataroot, img_name).replace('images', 'normals'))
            # parse_lip = Image.open(os.path.join(parse_lip_dir, img_name))
            pixel_values_agnostic = Image.open(os.path.join(self.dataroot, img_name).replace('images', 'agnostic'))
            # img_name[:4] is the subject id, img_name[-7:-4] the view index.
            parm_matrix = np.load(os.path.join(self.dataroot, img_name[:4], 'parm', img_name[-7:-4] + '_extrinsic.npy'))
            pixel_values = crop_image(pixel_values)
            pixel_values_pose = crop_image(pixel_values_pose)

            # Camera parameter: flattened 3x3 rotation of the extrinsic.
            parm_matrix = torch.tensor(parm_matrix)
            camera_parm = parm_matrix[:3,:3].reshape(-1)  # todo

            # Normalize to [-1, 1] tensors at sample_size.
            pixel_values = self.pixel_transforms(pixel_values)
            pixel_values_pose = self.pixel_transforms(pixel_values_pose)
            pixel_values_agnostic = self.pixel_transforms(pixel_values_agnostic)

            pixel_values_list.append(pixel_values)
            pixel_values_pose_list.append(pixel_values_pose)
            camera_parm_list.append(camera_parm)
            pixel_values_agnostic_list.append(pixel_values_agnostic)

        pixel_values = torch.stack(pixel_values_list)
        pixel_values_pose = torch.stack(pixel_values_pose_list)
        camera_parm = torch.stack(camera_parm_list)
        pixel_values_agnostic = torch.stack(pixel_values_agnostic_list)

        # BUGFIX: cloth_name_* already contain cloth_root (joined in
        # __getitem__).  The original joined them onto cloth_root *again*,
        # which only worked because os.path.join discards the first argument
        # when the second is absolute.
        pixel_values_cloth_front = Image.open(cloth_name_front)
        pixel_values_cloth_back = Image.open(cloth_name_back)

        # CLIP preprocessing of the garment references.
        clip_ref_front = self.clip_image_processor(images=pixel_values_cloth_front, return_tensors="pt").pixel_values
        clip_ref_back = self.clip_image_processor(images=pixel_values_cloth_back, return_tensors="pt").pixel_values

        if self.is_train:
            pixel_values_cloth_front = self.ref_transforms_train(pixel_values_cloth_front)
            pixel_values_cloth_back = self.ref_transforms_train(pixel_values_cloth_back)
        else:
            pixel_values_cloth_front = self.ref_transforms_test(pixel_values_cloth_front)
            pixel_values_cloth_back = self.ref_transforms_test(pixel_values_cloth_back)

        # Randomly drop image conditioning per view with 10% probability
        # (presumably for classifier-free guidance -- TODO confirm with trainer).
        drop_image_embeds = []
        for k in range(len(select_images)):
            if random.random() < 0.1:
                drop_image_embeds.append(torch.tensor(1))
            else:
                drop_image_embeds.append(torch.tensor(0))
        drop_image_embeds = torch.stack(drop_image_embeds)

        sample = dict(
            pixel_values=pixel_values,
            pixel_values_pose=pixel_values_pose,
            pixel_values_agnostic=pixel_values_agnostic,
            clip_ref_front=clip_ref_front,
            clip_ref_back=clip_ref_back,
            pixel_values_cloth_front=pixel_values_cloth_front,
            pixel_values_cloth_back=pixel_values_cloth_back,
            camera_parm=camera_parm,
            drop_image_embeds=drop_image_embeds,
            img_name=image_name_list,
            cloth_name=cloth_name_front,
        )

        return sample
237
+
238
def collate_fn(data):
    """Collate a list of Thuman2 samples into one batch dictionary.

    Tensor fields are stacked along a new batch dimension (CLIP tensors are
    concatenated since they already have a leading dim); name fields become
    flat Python lists.
    """
    def _stack(key):
        return torch.stack([sample[key] for sample in data])

    def _cat(key):
        return torch.cat([sample[key] for sample in data])

    img_name = []
    cloth_name = []
    for sample in data:
        img_name.extend(sample['img_name'])
        cloth_name.append(sample['cloth_name'])

    return {
        "pixel_values": _stack("pixel_values"),
        "pixel_values_pose": _stack("pixel_values_pose"),
        "pixel_values_agnostic": _stack("pixel_values_agnostic"),
        "clip_ref_front": _cat("clip_ref_front"),
        "clip_ref_back": _cat("clip_ref_back"),
        "pixel_values_ref_front": _stack("pixel_values_cloth_front"),
        "pixel_values_ref_back": _stack("pixel_values_cloth_back"),
        "camera_parm": _stack("camera_parm"),
        "drop_image_embeds": _stack("drop_image_embeds"),
        "img_name": img_name,
        "cloth_name": cloth_name,
    }
269
+
270
+
271
if __name__ == '__main__':
    # Smoke test: iterate one batch and dump every tensor field as a JPEG.
    seed = 20
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    dataset = Thuman2_Dataset(dataroot="/GPUFS/sysu_gbli2_1/hzj/save_render_data_yw/",
                              sample_size=(768,576), is_train=False, mode='pair',
                              clip_model_path="/GPUFS/sysu_gbli2_1/hzj/pretrained_models/clip-vit-base-patch32")

    test_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=2,
        num_workers=1,
    )

    def denorm_and_save(tensor, out_path):
        """Map a [-1, 1] CHW tensor back to uint8 HWC and write a JPEG."""
        arr = tensor.permute(1, 2, 0).numpy()
        arr = (arr / 2 + 0.5) * 255
        Image.fromarray(arr.astype(np.uint8)).save(out_path)

    for _, batch in enumerate(test_dataloader):
        print('111', batch['camera_parm'].shape)
        print('111', batch['drop_image_embeds'].shape)
        # Inspect the first sample of the batch.
        p = {key: value[0] for key, value in batch.items()}

        print(p['camera_parm'].shape)

        # (field, view index or None for per-sample fields, output file)
        for key, view_idx, out_name in [
            ('pixel_values', 0, 'pixel_values0.jpg'),
            ('pixel_values_pose', 0, 'pixel_values_pose.jpg'),
            ('pixel_values_agnostic', 0, 'pixel_values_agnostic.jpg'),
            ('pixel_values', 2, 'pixel_values2.jpg'),
            ('pixel_values_pose', 2, 'pixel_values_pose2.jpg'),
            ('pixel_values_agnostic', 2, 'pixel_values_agnostic2.jpg'),
            ('pixel_values_ref_front', None, 'pixel_values_cloth_front.jpg'),
            ('pixel_values_ref_back', None, 'pixel_values_cloth_back.jpg'),
        ]:
            print(p[key].shape)
            tensor = p[key] if view_idx is None else p[key][view_idx]
            denorm_and_save(tensor, out_name)
        exit()
src/multiview_consist_edit/data/camera_utils.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ from os.path import join
5
class FileStorage(object):
    """Reader/writer for OpenCV-style YAML camera calibration files.

    In write mode the YAML is emitted by hand through a plain text file
    (matching OpenCV's ``%YAML:1.0`` header and CRLF line endings); in read
    mode the file is parsed with ``cv2.FileStorage``.
    """

    def __init__(self, filename, isWrite=False):
        """Open *filename* for reading (default) or writing."""
        version = cv2.__version__
        self.major_version = int(version.split('.')[0])
        self.second_version = int(version.split('.')[1])

        if isWrite:
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            self.fs = open(filename, 'w')
            self.fs.write('%YAML:1.0\r\n')
            self.fs.write('---\r\n')
        else:
            assert os.path.exists(filename), filename
            self.fs = cv2.FileStorage(filename, cv2.FILE_STORAGE_READ)
        self.isWrite = isWrite

    def __del__(self):
        # Release whichever backend handle this instance owns.
        if self.isWrite:
            self.fs.close()
        else:
            cv2.FileStorage.release(self.fs)

    def _write(self, out):
        # CRLF endings to match OpenCV's own YAML output.
        self.fs.write(out + '\r\n')

    def write(self, key, value, dt='mat'):
        """Write *value* under *key* as an OpenCV matrix ('mat'),
        a string list ('list') or a plain scalar ('int')."""
        if dt == 'mat':
            self._write('{}: !!opencv-matrix'.format(key))
            self._write('  rows: {}'.format(value.shape[0]))
            self._write('  cols: {}'.format(value.shape[1]))
            self._write('  dt: d')
            self._write('  data: [{}]'.format(', '.join(['{:.6f}'.format(i) for i in value.reshape(-1)])))
        elif dt == 'list':
            self._write('{}:'.format(key))
            for elem in value:
                self._write('  - "{}"'.format(elem))
        elif dt == 'int':
            self._write('{}: {}'.format(key, value))

    def read(self, key, dt='mat'):
        """Read *key* back as a matrix, string list or int.

        For 'list' entries, purely numeric values are converted to their
        string form and 'none' placeholders are dropped.
        """
        if dt == 'mat':
            output = self.fs.getNode(key).mat()
        elif dt == 'list':
            results = []
            n = self.fs.getNode(key)
            for i in range(n.size()):
                val = n.at(i).string()
                if val == '':
                    val = str(int(n.at(i).real()))
                if val != 'none':
                    results.append(val)
            output = results
        elif dt == 'int':
            output = int(self.fs.getNode(key).real())
        else:
            raise NotImplementedError
        return output

    def close(self):
        # BUGFIX: the original called ``self.__del__(self)``, passing `self`
        # twice and raising TypeError whenever close() was used.
        self.__del__()
66
def read_intri(intri_name):
    """Read per-camera intrinsics (K, its inverse and distortion) from a
    YAML calibration file; returns a dict keyed by camera name."""
    assert os.path.exists(intri_name), intri_name
    storage = FileStorage(intri_name)
    cameras = {}
    for name in storage.read('names', dt='list'):
        K = storage.read('K_{}'.format(name))
        cameras[name] = {
            'K': K,
            'invK': np.linalg.inv(K),
            'dist': storage.read('dist_{}'.format(name)),
        }
    return cameras
78
+
79
def write_intri(intri_name, cameras):
    """Write per-camera intrinsics (K and distortion) to a YAML file.

    Camera keys may carry a file extension; only the stem is used.
    """
    out_dir = os.path.dirname(intri_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    intri = FileStorage(intri_name, True)
    intri.write('names', list(cameras.keys()), 'list')
    for key_, val in cameras.items():
        key = key_.split('.')[0]
        K, dist = val['K'], val['dist']
        assert K.shape == (3, 3), K.shape
        assert dist.shape in ((1, 5), (5, 1), (1, 4), (4, 1)), dist.shape
        intri.write('K_{}'.format(key), K)
        # Distortion is stored as a single row vector.
        intri.write('dist_{}'.format(key), dist.flatten()[None])
93
+
94
def write_extri(extri_name, cameras):
    """Write per-camera extrinsics (Rodrigues vector, rotation matrix and
    translation) to a YAML file; returns 0 for API compatibility."""
    out_dir = os.path.dirname(extri_name)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    extri = FileStorage(extri_name, True)
    extri.write('names', list(cameras.keys()), 'list')
    for key_, val in cameras.items():
        key = key_.split('.')[0]
        extri.write('R_{}'.format(key), val['Rvec'])
        extri.write('Rot_{}'.format(key), val['R'])
        extri.write('T_{}'.format(key), val['T'])
    return 0
107
+
108
def read_camera(intri_name, extri_name, cam_names=[]):
    """Read intrinsics and extrinsics from a pair of YAML calibration files.

    Returns a dict keyed by camera name with K/invK/H/W/R/Rvec/T/RT/P/center
    and distortion entries, plus a 'basenames' list of all camera names.
    The ``cam_names`` argument is overwritten from the intrinsics file, so
    its mutable default is harmless.
    """
    assert os.path.exists(intri_name), intri_name
    assert os.path.exists(extri_name), extri_name

    intri = FileStorage(intri_name)
    extri = FileStorage(extri_name)
    cams, P = {}, {}
    cam_names = intri.read('names', dt='list')
    for cam in cam_names:
        # Intrinsics: only the sub-stream entries are read.
        cams[cam] = {}
        cams[cam]['K'] = intri.read('K_{}'.format( cam))
        cams[cam]['invK'] = np.linalg.inv(cams[cam]['K'])
        H = intri.read('H_{}'.format(cam), dt='int')
        W = intri.read('W_{}'.format(cam), dt='int')
        if H is None or W is None:
            print('[camera] no H or W for {}'.format(cam))
            H, W = -1, -1
        cams[cam]['H'] = H
        cams[cam]['W'] = W
        Rvec = extri.read('R_{}'.format(cam))
        Tvec = extri.read('T_{}'.format(cam))
        assert Rvec is not None, cam
        R = cv2.Rodrigues(Rvec)[0]
        RT = np.hstack((R, Tvec))

        cams[cam]['RT'] = RT
        cams[cam]['R'] = R
        cams[cam]['Rvec'] = Rvec
        cams[cam]['T'] = Tvec
        # BUGFIX: the camera center is -R^T @ T using the 3x3 rotation
        # matrix.  The original used the Rodrigues *vector* (Rvec.T @ Tvec),
        # which yields a meaningless 1x1 value.
        cams[cam]['center'] = - R.T @ Tvec
        P[cam] = cams[cam]['K'] @ cams[cam]['RT']
        cams[cam]['P'] = P[cam]

        cams[cam]['dist'] = intri.read('dist_{}'.format(cam))
        if cams[cam]['dist'] is None:
            # Some files store distortion under 'D_<cam>' instead.
            cams[cam]['dist'] = intri.read('D_{}'.format(cam))
            if cams[cam]['dist'] is None:
                print('[camera] no dist for {}'.format(cam))
    cams['basenames'] = cam_names
    return cams
149
+
150
+
151
+
152
+
153
def read_camera_mvhumannet(intri_name, extri_name, camera_scale ,cam_names=[]):
    """Load MVHumanNet camera calibration from JSON files.

    Reads one shared intrinsic matrix and per-camera extrinsics, returning a
    dict keyed by camera id with K/invK/R/T/RT/P/dist entries.

    Args:
        intri_name: path to ``camera_intrinsics.json``.
        extri_name: path to ``camera_extrinsics.json``.
        camera_scale: scale factor applied to translations after the
            millimetre-to-metre conversion (per-capture, loaded from
            ``camera_scale.pkl`` by the caller).
        cam_names: unused placeholder -- it is overwritten below from the
            extrinsics keys, so the mutable default is harmless.
    """
    assert os.path.exists(intri_name), intri_name
    assert os.path.exists(extri_name), extri_name

    import json

    with open(intri_name, 'r') as f:
        camera_intrinsics = json.load(f)

    with open(extri_name, 'r') as f:
        camera_extrinsics = json.load(f)

    # print("intri: ", camera_intrinsics)

    # NOTE(review): `item` is computed but never used.
    item = os.path.dirname(intri_name).split("/")[-1]
    # print("item: ", item)
    # intri = FileStorage(intri_name)
    # extri = FileStorage(extri_name)
    cams, P = {}, {}

    # cam_names = intri.read('names', dt='list')

    cam_names = camera_extrinsics.keys()

    for cam in cam_names:
        # Intrinsics: only the sub-stream entries are read.

        # Extrinsics keys look like '<prefix>_<camera id>.<ext>'; keep the
        # trailing '_'-separated part of the stem as the camera id.
        updated_cam = cam.split('.')[0].split('_')
        # print("updated_cam_before: ", updated_cam)
        # updated_cam[1] = 'cache' # for test
        updated_cam = updated_cam[-1]
        # print("updated_cam_after: ", updated_cam)

        cams[updated_cam] = {}
        # All cameras share the single intrinsic matrix in this JSON layout.
        # cams[updated_cam]['K'] = intri.read('K_{}'.format( cam))
        cams[updated_cam]['K'] = np.array(camera_intrinsics['intrinsics'])
        cams[updated_cam]['invK'] = np.linalg.inv(cams[updated_cam]['K'])

        # import IPython; IPython.embed(); exit()

        # Rvec = extri.read('R_{}'.format(cam))
        # Tvec = extri.read('T_{}'.format(cam))
        # assert Rvec is not None, cam
        # R = cv2.Rodrigues(Rvec)[0]

        R = np.array(camera_extrinsics[cam]['rotation'])
        # Translation: stored in millimetres; convert and apply the
        # per-capture scale (the commented line is the old hard-coded
        # 'longgang' variant, the active one the 'futian' variant).
        # longgang
        # Tvec = np.array(camera_extrinsics[cam]['translation'])[:, None] / 1000 * 100 / 65
        # futian
        Tvec = np.array(camera_extrinsics[cam]['translation'])[:, None] / 1000 * camera_scale

        # World-to-camera [R|t] and full projection matrix P = K [R|t].
        RT = np.hstack((R, Tvec))

        cams[updated_cam]['RT'] = RT
        cams[updated_cam]['R'] = R
        # cams[updated_cam]['Rvec'] = Rvec
        cams[updated_cam]['T'] = Tvec
        # cams[updated_cam]['center'] = - Rvec.T @ Tvec
        P[updated_cam] = cams[updated_cam]['K'] @ cams[updated_cam]['RT']
        cams[updated_cam]['P'] = P[updated_cam]

        # cams[updated_cam]['dist'] = np.array(camera_intrinsics['dist'])
        cams[updated_cam]['dist'] = None # dist for cv2.undistortPoint
    # cams['basenames'] = cam_names
    return cams
218
+
219
+
220
def read_camera_ours(intri_name, extri_name, cam_names=[], camera_scale=120 / 65):
    """Load cameras for our capture setup from JSON intrinsic/extrinsic files.

    Identical format to ``read_camera_mvhumannet`` but with the translation
    scale fixed for this rig by default. The historical hard-coded value
    ``120 / 65`` is now an overridable keyword argument (backward compatible).

    Args:
        intri_name: JSON file with one shared 3x3 ``intrinsics`` matrix.
        extri_name: JSON file mapping camera names to ``rotation`` (3x3) and
            ``translation`` (3-vector in millimetres).
        cam_names: unused; kept for signature compatibility.
        camera_scale: scale applied after the mm -> m conversion.

    Returns:
        dict mapping numeric camera suffix to a dict with ``K``, ``invK``,
        ``R``, ``T``, ``RT``, ``P`` and ``dist`` (always ``None``).
    """
    assert os.path.exists(intri_name), intri_name
    assert os.path.exists(extri_name), extri_name

    import json

    with open(intri_name, 'r') as f:
        camera_intrinsics = json.load(f)
    with open(extri_name, 'r') as f:
        camera_extrinsics = json.load(f)

    cams = {}
    for cam in camera_extrinsics.keys():
        # 'cam_01.xxx' -> '01': keep only the numeric suffix as the key.
        updated_cam = cam.split('.')[0].split('_')[-1]

        cams[updated_cam] = {}
        # A single intrinsic matrix shared by every camera.
        cams[updated_cam]['K'] = np.array(camera_intrinsics['intrinsics'])
        cams[updated_cam]['invK'] = np.linalg.inv(cams[updated_cam]['K'])

        R = np.array(camera_extrinsics[cam]['rotation'])
        # mm -> m, then the rig-specific normalisation scale.
        Tvec = np.array(camera_extrinsics[cam]['translation'])[:, None] / 1000 * camera_scale

        RT = np.hstack((R, Tvec))
        cams[updated_cam]['RT'] = RT
        cams[updated_cam]['R'] = R
        cams[updated_cam]['T'] = Tvec
        cams[updated_cam]['P'] = cams[updated_cam]['K'] @ RT

        # No distortion coefficients (dist for cv2.undistortPoints).
        cams[updated_cam]['dist'] = None
    return cams
285
+
286
+
287
+
288
def read_cameras(path, intri='intri.yml', extri='extri.yml', subs=[]):
    """Read every camera under ``path`` and optionally keep only ``subs``.

    Bug fix: the original comprehension called ``.astype(np.float32)`` on the
    per-camera *dict* (raising AttributeError whenever ``subs`` was
    non-empty); the cast now applies to the numpy arrays inside each
    camera record instead.
    """
    cameras = read_camera(join(path, intri), join(path, extri))
    cameras.pop('basenames')
    if len(subs) > 0:
        cameras = {
            key: {k: (v.astype(np.float32) if isinstance(v, np.ndarray) else v)
                  for k, v in cameras[key].items()}
            for key in subs
        }
    return cameras
294
+
295
def write_camera(camera, path):
    """Write cameras to OpenCV yml files ``intri.yml`` / ``extri.yml`` in ``path``.

    Fixes: drops the unused ``results`` local, and excludes the pseudo-key
    ``'basenames'`` from the written ``names`` list (the write loop already
    skipped it, but it still leaked into ``names``).
    """
    from os.path import join
    intri = FileStorage(join(path, 'intri.yml'), True)
    extri = FileStorage(join(path, 'extri.yml'), True)
    # Only real camera entries belong in the 'names' index.
    camnames = [key_.split('.')[0] for key_ in camera.keys() if key_ != 'basenames']
    intri.write('names', camnames, 'list')
    extri.write('names', camnames, 'list')
    for key_, val in camera.items():
        if key_ == 'basenames':
            continue
        key = key_.split('.')[0]
        intri.write('K_{}'.format(key), val['K'])
        intri.write('dist_{}'.format(key), val['dist'])
        if 'H' in val.keys() and 'W' in val.keys():
            intri.write('H_{}'.format(key), val['H'], dt='int')
            intri.write('W_{}'.format(key), val['W'], dt='int')
        # Extrinsics are stored both as a Rodrigues vector and as a matrix.
        if 'Rvec' not in val.keys():
            val['Rvec'] = cv2.Rodrigues(val['R'])[0]
        extri.write('R_{}'.format(key), val['Rvec'])
        extri.write('Rot_{}'.format(key), val['R'])
        extri.write('T_{}'.format(key), val['T'])
319
+
320
def camera_from_img(img):
    """Build a default pinhole camera for an image.

    Focal length is 1.2 * min(H, W) (colmap-style default), principal point
    at the image centre, identity pose, zero distortion. Returns a dict with
    'K', 'invK', 'R', 'T', 'dist' and the 3x4 projection 'P'.
    """
    h, w = img.shape[0], img.shape[1]
    f = 1.2 * min(h, w)  # colmap-style default focal
    K = np.array([[f, 0., w / 2],
                  [0., f, h / 2],
                  [0., 0., 1.]])
    camera = {
        'K': K,
        'R': np.eye(3),
        'T': np.zeros((3, 1)),
        'dist': np.zeros((1, 5)),
    }
    camera['invK'] = np.linalg.inv(camera['K'])
    camera['P'] = camera['K'] @ np.hstack((camera['R'], camera['T']))
    return camera
329
+
330
class Undistort:
    """Helpers to remove lens distortion from images, keypoints and bboxes."""

    # Cache of (mapx, mapy) remap grids keyed by camera name `sub`, so the
    # rectification map is computed only once per camera.
    distortMap = {}

    @classmethod
    def image(cls, frame, K, dist, sub=None, interp=cv2.INTER_NEAREST):
        """Undistort an image; when `sub` is given, reuse a cached remap grid."""
        if sub is None:
            return cv2.undistort(frame, K, dist, None)
        else:
            if sub not in cls.distortMap.keys():
                h, w = frame.shape[:2]
                mapx, mapy = cv2.initUndistortRectifyMap(K, dist, None, K, (w,h), 5)
                cls.distortMap[sub] = (mapx, mapy)
            mapx, mapy = cls.distortMap[sub]
            img = cv2.remap(frame, mapx, mapy, interp)
            return img

    @staticmethod
    def points(keypoints, K, dist):
        """Undistort 2D keypoints; columns beyond x,y (e.g. confidence) pass through."""
        # keypoints: (N, 3)
        assert len(keypoints.shape) == 2, keypoints.shape
        kpts = keypoints[:, None, :2]
        kpts = np.ascontiguousarray(kpts)
        # P=K keeps the result in pixel coordinates of the same intrinsics.
        kpts = cv2.undistortPoints(kpts, K, dist, P=K)
        keypoints = np.hstack([kpts[:, 0], keypoints[:, 2:]])
        return keypoints

    @staticmethod
    def bbox(bbox, K, dist):
        """Undistort a bbox [x0, y0, x1, y1, conf] by undistorting its two corners."""
        keypoints = np.array([[bbox[0], bbox[1], 1], [bbox[2], bbox[3], 1]])
        kpts = Undistort.points(keypoints, K, dist)
        bbox = np.array([kpts[0, 0], kpts[0, 1], kpts[1, 0], kpts[1, 1], bbox[4]])
        return bbox
361
+
362
class Distort:
    """Inverse of Undistort: re-apply lens distortion to undistorted coordinates."""

    @staticmethod
    def points(keypoints, K, dist):
        # Not implemented: returns None. Only `bbox` below is usable.
        pass

    @staticmethod
    def bbox(bbox, K, dist):
        """Re-apply distortion to an undistorted bbox [x0, y0, x1, y1, conf]."""
        keypoints = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[3]]], dtype=np.float32)
        k3d = cv2.convertPointsToHomogeneous(keypoints)
        # Back-project corners to normalised camera rays, then re-project
        # through the distortion model with an identity pose.
        k3d = (np.linalg.inv(K) @ k3d[:, 0].T).T[:, None]
        k2d, _ = cv2.projectPoints(k3d, np.zeros((3,)), np.zeros((3,)), K, dist)
        k2d = k2d[:, 0]
        bbox = np.array([k2d[0,0], k2d[0,1], k2d[1, 0], k2d[1, 1], bbox[-1]])
        return bbox
376
+
377
def unproj(kpts, invK):
    """Back-project pixel keypoints to normalised camera coordinates.

    Uses invK on the (x, y) columns; any extra columns (e.g. confidence)
    are passed through unchanged.
    """
    pix_homo = np.hstack((kpts[:, :2], np.ones_like(kpts[:, :1])))
    rays = pix_homo @ invK.T
    return np.hstack((rays[:, :2], kpts[:, 2:]))
381
class UndistortFisheye:
    """Fisheye-model counterparts of the Undistort helpers (cv2.fisheye)."""

    @staticmethod
    def image(frame, K, dist):
        """Undistort a fisheye image; returns (undistorted_frame, new_K)."""
        Knew = K.copy()
        frame = cv2.fisheye.undistortImage(frame, K, dist, Knew=Knew)
        return frame, Knew

    @staticmethod
    def points(keypoints, K, dist, Knew):
        """Undistort fisheye 2D keypoints; extra columns pass through unchanged."""
        # keypoints: (N, 3)
        assert len(keypoints.shape) == 2, keypoints.shape
        kpts = keypoints[:, None, :2]
        kpts = np.ascontiguousarray(kpts)
        # P=Knew maps results into the pixel frame of the undistorted image.
        kpts = cv2.fisheye.undistortPoints(kpts, K, dist, P=Knew)
        keypoints = np.hstack([kpts[:, 0], keypoints[:, 2:]])
        return keypoints

    @staticmethod
    def bbox(bbox, K, dist, Knew):
        """Undistort a bbox [x0, y0, x1, y1, conf] via its two corner points."""
        keypoints = np.array([[bbox[0], bbox[1], 1], [bbox[2], bbox[3], 1]])
        kpts = UndistortFisheye.points(keypoints, K, dist, Knew)
        bbox = np.array([kpts[0, 0], kpts[0, 1], kpts[1, 0], kpts[1, 1], bbox[4]])
        return bbox
404
+
405
+
406
def get_Pall(cameras, camnames):
    """Stack the 3x4 projection matrices K @ [R|T] for the given camera names."""
    projections = []
    for name in camnames:
        cam = cameras[name]
        projections.append(cam['K'] @ np.hstack((cam['R'], cam['T'])))
    return np.stack(projections)
409
+
410
def get_fundamental_matrix(cameras, basenames):
    """Compute pairwise fundamental matrices for every ordered camera pair.

    Uses F = K_0^-T (R_0 R_1^T) K_1^T [t]_x with the baseline t expressed in
    the second camera's frame. An all-zero F (e.g. the i == j diagonal) is
    offset by 1e-12 to avoid NaNs downstream.

    Fix: the original pre-allocated an N x N x 3 x 3 ndarray that was
    immediately shadowed by the dict; the dead allocation is removed.

    Args:
        cameras: dict name -> {'K': 3x3, 'RT': 3x4, ...}.
        basenames: iterable of camera names.

    Returns:
        dict (icam, jcam) -> 3x3 fundamental matrix.
    """
    skew_op = lambda x: np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])
    fundamental_op = lambda K_0, R_0, T_0, K_1, R_1, T_1: np.linalg.inv(K_0).T @ (
        R_0 @ R_1.T) @ K_1.T @ skew_op(K_1 @ R_1 @ R_0.T @ (T_0 - R_0 @ R_1.T @ T_1))
    fundamental_RT_op = lambda K_0, RT_0, K_1, RT_1: fundamental_op(
        K_0, RT_0[:, :3], RT_0[:, 3], K_1, RT_1[:, :3], RT_1[:, 3])
    F = {(icam, jcam): np.zeros((3, 3)) for jcam in basenames for icam in basenames}
    for icam in basenames:
        for jcam in basenames:
            F[(icam, jcam)] += fundamental_RT_op(cameras[icam]['K'], cameras[icam]['RT'],
                                                 cameras[jcam]['K'], cameras[jcam]['RT'])
            if F[(icam, jcam)].sum() == 0:
                F[(icam, jcam)] += 1e-12  # to avoid nan
    return F
424
+
425
def interp_cameras(cameras, keys, step=20, loop=True, allstep=-1, **kwargs):
    """Generate interpolated cameras along the path keys[0] -> keys[1] -> ...

    Rotation is interpolated with quaternion slerp; the camera center is
    interpolated spherically (direction slerp plus a linearly interpolated
    radius), and K is interpolated linearly. When `loop` is True the path
    closes back to keys[0]. `allstep` (if != -1) fixes the total number of
    samples; otherwise `step` samples are produced per segment.

    Returns:
        dict '{left}-{right}-{i}' -> {'K', 'dist', 'R', 'T'} for each new camera.
    """
    from scipy.spatial.transform import Rotation as R
    from scipy.spatial.transform import Slerp
    if allstep != -1:
        tall = np.linspace(0., 1., allstep+1)[:-1].reshape(-1, 1, 1)
    elif allstep == -1 and loop:
        tall = np.linspace(0., 1., 1+step*len(keys))[:-1].reshape(-1, 1, 1)
    elif allstep == -1 and not loop:
        tall = np.linspace(0., 1., 1+step*(len(keys)-1))[:-1].reshape(-1, 1, 1)
    cameras_new = {}
    for ik in range(len(keys)):
        if ik == len(keys) -1 and not loop:
            break
        # Slice the global parameter range belonging to this segment.
        if loop:
            start, end = (ik * tall.shape[0])//len(keys), int((ik+1)*tall.shape[0])//len(keys)
            print(ik, start, end, tall.shape)
        else:
            start, end = (ik * tall.shape[0])//(len(keys)-1), int((ik+1)*tall.shape[0])//(len(keys)-1)
        t = tall[start:end].copy()
        # Re-normalise the segment parameter to [0, 1].
        t = (t-t.min())/(t.max()-t.min())
        left, right = keys[ik], keys[0 if ik == len(keys)-1 else ik + 1]
        camera_left = cameras[left]
        camera_right = cameras[right]
        # Interpolate the camera centers: center = - R.T @ T
        center_l = - camera_left['R'].T @ camera_left['T']
        center_r = - camera_right['R'].T @ camera_right['T']
        center_l, center_r = center_l[None], center_r[None]
        if False:
            # (disabled) plain linear interpolation of the centers
            centers = center_l * (1-t) + center_r * t
        else:
            # Spherical interpolation: slerp the unit directions, lerp the radius.
            norm_l, norm_r = np.linalg.norm(center_l), np.linalg.norm(center_r)
            center_l, center_r = center_l/norm_l, center_r/norm_r
            costheta = (center_l*center_r).sum()
            sintheta = np.sqrt(1. - costheta**2)
            theta = np.arctan2(sintheta, costheta)
            centers = (np.sin(theta*(1-t)) * center_l + np.sin(theta * t) * center_r)/sintheta
            norm = norm_l * (1-t) + norm_r * t
            centers = centers * norm
        key_rots = R.from_matrix(np.stack([camera_left['R'], camera_right['R']]))
        key_times = [0, 1]
        slerp = Slerp(key_times, key_rots)
        interp_rots = slerp(t.squeeze()).as_matrix()
        # Recover T from the interpolated pose: R @ X + T = 0  =>  T = - R @ center
        T = - np.einsum('bmn,bno->bmo', interp_rots, centers)
        K = camera_left['K'] * (1-t) + camera_right['K'] * t
        for i in range(T.shape[0]):
            cameras_new['{}-{}-{}'.format(left, right, i)] = \
                {
                    'K': K[i],
                    'dist': np.zeros((1, 5)),
                    'R': interp_rots[i],
                    'T': T[i]
                }
    return cameras_new
src/multiview_consist_edit/infer_tryon_multi.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ from PIL import Image
3
+ import requests
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import os
9
+ import random
10
+ import copy
11
+ import time
12
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
13
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, DDIMScheduler
14
+ from torchvision.utils import make_grid as make_image_grid
15
+ from torchvision.utils import save_image
16
+ from models.condition_encoder import FrozenOpenCLIPImageEmbedderV2
17
+ from omegaconf import OmegaConf
18
+ from pipelines.pipeline_tryon_multi import TryOnPipeline
19
+ from models.hack_poseguider import Hack_PoseGuider as PoseGuider
20
+
21
+ from models.ReferenceNet import ReferenceNet
22
+ from models.ReferenceEncoder import ReferenceEncoder
23
+
24
+ from data.Thuman2_multi import Thuman2_Dataset, collate_fn
25
+ # from data.Thuman2_multi_ps2 import Thuman2_Dataset, collate_fn
26
+ from data.MVHumanNet_multi import MVHumanNet_Dataset
27
+ from models.hack_unet2d import Hack_UNet2DConditionModel as UNet2DConditionModel
28
+
29
# Inference configuration (model paths, seed, batch size, sampler settings).
config = OmegaConf.load('config/infer_tryon_multi.yaml')

def main():
    """Run multi-view try-on inference and save results under config.out_dir.

    Loads the frozen fp16 UNet / VAE / ReferenceNet / PoseGuider / CLIP
    encoder, builds the TryOnPipeline, iterates the test dataloader, and for
    each garment writes the generated views plus a conditioning grid image.
    """
    # Seed every RNG used downstream so sampling is reproducible.
    seed = config.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Dataset: the loader class is chosen from the dataroot path.
    infer_data_config = config.infer_data
    if 'mvhumannet' in infer_data_config['dataroot']:
        infer_dataset = MVHumanNet_Dataset(**infer_data_config)
        print('using mvhumannet')
    else:
        infer_dataset = Thuman2_Dataset(**infer_data_config)
        print('using Thuman2_Dataset')

    batch_size = config.batch_size  # NOTE(review): unused; config.batch_size is read directly below

    test_dataloader = torch.utils.data.DataLoader(
        infer_dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=config.batch_size,
        num_workers=config.dataloader_num_workers,
    )

    # Frozen fp16 model components, all placed on the GPU.
    unet = UNet2DConditionModel.from_pretrained(
        config.unet_path, subfolder="unet",torch_dtype=torch.float16
    ).to("cuda")

    vae= AutoencoderKL.from_pretrained(
        config.vae_path,torch_dtype=torch.float16
    ).to("cuda")

    referencenet = ReferenceNet.from_pretrained(
        config.pretrained_referencenet_path, subfolder="referencenet",torch_dtype=torch.float16
    ).to("cuda")

    pose_guider = PoseGuider.from_pretrained(pretrained_model_path=config.pretrained_poseguider_path).to("cuda", dtype=torch.float16)
    pose_guider.eval()
    scheduler = DDIMScheduler.from_pretrained(config.model_path, subfolder='scheduler')

    pipe = TryOnPipeline(pose_guider=pose_guider, referencenet=referencenet, vae=vae, unet=unet, scheduler=scheduler)
    pipe.enable_xformers_memory_efficient_attention()

    # CLIP vision encoder producing the garment reference features.
    clip_image_encoder = ReferenceEncoder(model_path=config.clip_model_path).to(device='cuda',dtype=torch.float16)

    # Replace the loaded scheduler with a fixed DDIM configuration.
    pipe.scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    generator = torch.Generator("cuda").manual_seed(seed)

    # Output directory for generated images.
    out_dir = config.out_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    num_inference_steps = config.num_inference_steps
    guidance_scale = config.guidance_scale
    weight_dtype = torch.float16

    image_idx = 0
    for i, batch in enumerate(test_dataloader):

        # Batch layout: pixel_values appears to be (B, V, C, H, W) with V
        # views per sample (multi_length = shape[1] below), while the garment
        # reference images are per-sample (B, C, H, W) — TODO confirm against
        # the dataset's collate_fn.
        pixel_values = batch["pixel_values"]
        pixel_values_pose = batch["pixel_values_pose"].to(device='cuda')
        pixel_values_agnostic = batch["pixel_values_agnostic"].to(device='cuda')
        clip_ref_front = batch["clip_ref_front"].to(device='cuda')
        clip_ref_back = batch["clip_ref_back"].to(device='cuda')
        pixel_values_ref_front = batch["pixel_values_ref_front"].to(device='cuda')
        pixel_values_ref_back = batch["pixel_values_ref_back"].to(device='cuda')
        camera_pose = batch["camera_parm"]
        # CLIP features of the front/back garment references.
        front_dino_fea = clip_image_encoder(clip_ref_front.to(weight_dtype))
        back_dino_fea = clip_image_encoder(clip_ref_back.to(weight_dtype))
        img_name = batch["img_name"]
        cloth_name = batch["cloth_name"]
        multi_length = pixel_values.shape[1]
        print(img_name)
        edited_images = pipe(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            front_image=pixel_values_ref_front.to(weight_dtype),
            back_image=pixel_values_ref_back.to(weight_dtype),
            pose_image=pixel_values_pose.to(weight_dtype),
            camera_pose=camera_pose,
            agnostic_image=pixel_values_agnostic.to(weight_dtype),
            generator=generator,
            front_dino_fea = front_dino_fea,
            back_dino_fea = back_dino_fea,
        ).images

        for batch_idx in range(config.batch_size):

            for image_idx in range(multi_length):
                # Pipeline output is flattened over (batch, view).
                total_idx = batch_idx*multi_length + image_idx
                edited_image = edited_images[total_idx]
                # PIL -> CHW float tensor in [0, 1].
                edited_image = torch.tensor(np.array(edited_image)).permute(2,0,1) / 255.0
                # Side-by-side grid: GT, result, pose, agnostic, front ref, back ref.
                grid = make_image_grid([(pixel_values[batch_idx][image_idx].cpu() / 2 + 0.5),edited_image.cpu(), (pixel_values_pose[batch_idx][image_idx].cpu() / 2 + 0.5),
                                        (pixel_values_agnostic[batch_idx][image_idx].cpu() / 2 + 0.5), (pixel_values_ref_front[batch_idx].cpu() / 2 + 0.5),(pixel_values_ref_back[batch_idx].cpu() / 2 + 0.5)], nrow=2)
                # Sanitise names: slashes -> underscores; keep only the garment id.
                img_name[total_idx] = img_name[total_idx].replace('/','_')
                cloth_name[batch_idx] = cloth_name[batch_idx].split('/')[-1].split('_')[0]
                print(img_name[total_idx], cloth_name[batch_idx])
                # One sub-directory per garment.
                sub_cloth_root = os.path.join(out_dir, cloth_name[batch_idx])
                if not os.path.exists(sub_cloth_root):
                    os.makedirs(sub_cloth_root)
                save_image(edited_image, os.path.join(out_dir, cloth_name[batch_idx], img_name[total_idx]))
                save_image(grid, os.path.join(out_dir, cloth_name[batch_idx], 'cond_'+img_name[total_idx]))
                print(out_dir, cloth_name[batch_idx], img_name[total_idx])
                image_idx +=1


if __name__ == "__main__":
    main()
src/multiview_consist_edit/models/ReferenceEncoder.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from PIL import Image
4
+ from transformers import CLIPProcessor, CLIPVisionModel, CLIPImageProcessor
5
+ from transformers import logging
6
+ logging.set_verbosity_warning()
7
+ logging.set_verbosity_error()
8
+
9
+ # https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train_plus.py#L49
10
+
11
class ReferenceEncoder(nn.Module):
    """Frozen CLIP vision tower used as a reference-image feature extractor.

    forward() returns the full last-layer token sequence
    (outputs.last_hidden_state), not the pooled CLS embedding.
    Pattern follows IP-Adapter's tutorial_train_plus image encoder.
    """

    def __init__(self, model_path="openai/clip-vit-base-patch32"):
        super(ReferenceEncoder, self).__init__()
        self.model = CLIPVisionModel.from_pretrained(model_path, local_files_only=False)
        self.freeze()

    def freeze(self):
        # Pure feature extractor: eval mode, gradients disabled everywhere.
        self.model = self.model.eval()
        for weight in self.model.parameters():
            weight.requires_grad = False

    def forward(self, pixel_values):
        # Per-patch hidden states of the final layer; pooled output unused.
        return self.model(pixel_values).last_hidden_state
30
+
31
+
32
+
33
+
34
class ReferenceEncoder2(nn.Module):
    """Frozen CLIP vision tower with a built-in preprocessor.

    Unlike ReferenceEncoder, forward() accepts PIL image(s), runs the CLIP
    processor, and returns the pooled (CLS) embedding of shape [batch, hidden].

    Fix: removed the leftover debug ``print`` statements from the forward path.
    """

    def __init__(self, model_path="openai/clip-vit-base-patch32"):
        super(ReferenceEncoder2, self).__init__()
        self.model = CLIPVisionModel.from_pretrained(model_path, local_files_only=True)
        self.processor = CLIPProcessor.from_pretrained(model_path, local_files_only=True)
        self.freeze()

    def freeze(self):
        # Feature extractor only: eval mode, no gradients.
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image):
        # Preprocess PIL image(s) into pixel tensors, then encode.
        inputs = self.processor(images=image, return_tensors="pt")
        outputs = self.model(**inputs)
        return outputs.pooler_output
57
+
58
+ # # example
59
+ # model = ReferenceEncoder2(model_path='/root/autodl-tmp/Open-AnimateAnyone/pretrained_models/clip-vit-base-patch32')
60
+ # image_path = "../test.png"
61
+ # # image_path = "/mnt/f/research/HumanVideo/AnimateAnyone-unofficial/DWPose/0001.png"
62
+ # image = Image.open(image_path).convert('RGB')
63
+ # image = [image,image]
64
+
65
+ # pooled_output = model(image)
66
+
67
+ # print(f"Pooled Output Size: {pooled_output.size()}") # Pooled Output Size: torch.Size([bs, 768])
src/multiview_consist_edit/models/ReferenceNet.py ADDED
@@ -0,0 +1,1146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.lora import LoRALinearLayer
33
+ from diffusers.models.embeddings import (
34
+ GaussianFourierProjection,
35
+ ImageHintTimeEmbedding,
36
+ ImageProjection,
37
+ ImageTimeEmbedding,
38
+ PositionNet,
39
+ TextImageProjection,
40
+ TextImageTimeEmbedding,
41
+ TextTimeEmbedding,
42
+ TimestepEmbedding,
43
+ Timesteps,
44
+ )
45
+ from diffusers.models.modeling_utils import ModelMixin
46
+ from diffusers.models.unet_2d_blocks import (
47
+ UNetMidBlock2DCrossAttn,
48
+ UNetMidBlock2DSimpleCrossAttn,
49
+ get_down_block,
50
+ get_up_block,
51
+ )
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+
57
class Identity(torch.nn.Module):
    r"""Argument-insensitive pass-through module.

    Accepts and ignores any constructor arguments (including ``scale``) and
    any extra forward arguments, returning its input unchanged — a drop-in
    no-op replacement for modules with richer signatures.

    Shape:
        - Input: :math:`(*)`, any number of dimensions.
        - Output: :math:`(*)`, identical to the input.

    Example::

        >>> m = Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> m(torch.randn(128, 20)).shape
        torch.Size([128, 20])
    """

    def __init__(self, scale=None, *args, **kwargs) -> None:
        # All arguments are deliberately discarded.
        super(Identity, self).__init__()

    def forward(self, input, *args, **kwargs):
        # Extra positional/keyword arguments are ignored.
        return input
82
+
83
+
84
+
85
class _LoRACompatibleLinear(nn.Module):
    """
    A Linear layer that can be used with LoRA.

    NOTE(review): this local override is a no-op stub. It preserves the
    LoRA API surface (set_lora_layer / _fuse_lora / _unfuse_lora), but
    `forward` returns its input unchanged — effectively disabling the
    layer wherever this class is substituted in.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        # Stored only for API compatibility; never used by forward().
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        # Intentionally a no-op: there are no LoRA weights to fuse here.
        pass

    def _unfuse_lora(self):
        # Intentionally a no-op.
        pass

    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
        # Pass-through: `scale` and `lora_scale` are ignored.
        return hidden_states
105
+
106
+
107
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Defaults to None so the dataclass can be constructed empty and filled later.
    sample: torch.FloatTensor = None
118
+
119
+
120
+ class ReferenceNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
121
+ r"""
122
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
123
+ shaped output.
124
+
125
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
126
+ for all models (such as downloading or saving).
127
+
128
+ Parameters:
129
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
130
+ Height and width of input/output sample.
131
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
132
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
133
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
134
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
135
+ Whether to flip the sin to cos in the time embedding.
136
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
137
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
138
+ The tuple of downsample blocks to use.
139
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
140
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
141
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
142
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
    """
+
208
    # Enables torch gradient checkpointing support via `_set_gradient_checkpointing` below.
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
    ):
        """Construct the UNet: input conv, time/class/addition embeddings, down/mid/up
        blocks, then a project-specific surgery that strips the last up-block's final
        cross-attention layer (see the block of assignments near the end).

        Parameter semantics are documented in the class docstring. Note that unlike the
        upstream diffusers `UNet2DConditionModel`, the final `conv_norm_out`/`conv_out`
        layers are deliberately NOT created here (the code is commented out below).
        """
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # Infer the hidden-state projection type when only the dimension was given.
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        # Broadcast scalar per-block settings to one value per down block.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                attention_type=attention_type,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                # NOTE(review): `attention_head_dim[i]` is indexed in forward order here while the
                # other per-block settings use reversed lists; this mirrors upstream diffusers
                # (harmless for symmetric configs) -- confirm if an asymmetric config is ever used.
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel


        # # out
        # if norm_num_groups is not None:
        #     self.conv_norm_out = nn.GroupNorm(
        #         num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        #     )

        #     self.conv_act = get_activation(act_fn)

        # else:
        #     self.conv_norm_out = None
        #     self.conv_act = None

        # conv_out_padding = (conv_out_kernel - 1) // 2
        # self.conv_out = nn.Conv2d(
        #     block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        # )

        # Diff vs diffusers-0.21.4/src/diffusers/models/unet_2d_condition.py
        # skip last cross attention for slight acceleration and for DDP training
        # The following parameters (cross-attention for the last layer)
        # and conv_out are not involved in the gradient calculation of the model
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
        self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
        self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
        self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
        self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
        self.up_blocks[3].attentions[2].proj_out = Identity()

        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                positive_len = cross_attention_dim[0]

            feature_type = "text-only" if attention_type == "gated" else "text-image"
            self.position_net = PositionNet(
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )
+ @property
653
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
654
+ r"""
655
+ Returns:
656
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
657
+ indexed by its weight name.
658
+ """
659
+ # set recursively
660
+ processors = {}
661
+
662
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
663
+ if hasattr(module, "get_processor"):
664
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
665
+
666
+ for sub_name, child in module.named_children():
667
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
668
+
669
+ return processors
670
+
671
+ for name, module in self.named_children():
672
+ fn_recursive_add_processors(name, module, processors)
673
+
674
+ return processors
675
+
676
+ def set_attn_processor(
677
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
678
+ ):
679
+ r"""
680
+ Sets the attention processor to use to compute attention.
681
+
682
+ Parameters:
683
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
684
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
685
+ for **all** `Attention` layers.
686
+
687
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
688
+ processor. This is strongly recommended when setting trainable attention processors.
689
+
690
+ """
691
+ count = len(self.attn_processors.keys())
692
+
693
+ if isinstance(processor, dict) and len(processor) != count:
694
+ raise ValueError(
695
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
696
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
697
+ )
698
+
699
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
700
+ if hasattr(module, "set_processor"):
701
+ if not isinstance(processor, dict):
702
+ module.set_processor(processor)
703
+ else:
704
+ module.set_processor(processor.pop(f"{name}.processor"))
705
+
706
+ for sub_name, child in module.named_children():
707
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
708
+
709
+ for name, module in self.named_children():
710
+ fn_recursive_attn_processor(name, module, processor)
711
+
712
+ def set_default_attn_processor(self):
713
+ """
714
+ Disables custom attention processors and sets the default attention implementation.
715
+ """
716
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
717
+ processor = AttnAddedKVProcessor()
718
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
719
+ processor = AttnProcessor()
720
+ else:
721
+ raise ValueError(
722
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
723
+ )
724
+
725
+ self.set_attn_processor(processor)
726
+
727
+ def set_attention_slice(self, slice_size):
728
+ r"""
729
+ Enable sliced attention computation.
730
+
731
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
732
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
733
+
734
+ Args:
735
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
736
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
737
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
738
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
739
+ must be a multiple of `slice_size`.
740
+ """
741
+ sliceable_head_dims = []
742
+
743
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
744
+ if hasattr(module, "set_attention_slice"):
745
+ sliceable_head_dims.append(module.sliceable_head_dim)
746
+
747
+ for child in module.children():
748
+ fn_recursive_retrieve_sliceable_dims(child)
749
+
750
+ # retrieve number of attention layers
751
+ for module in self.children():
752
+ fn_recursive_retrieve_sliceable_dims(module)
753
+
754
+ num_sliceable_layers = len(sliceable_head_dims)
755
+
756
+ if slice_size == "auto":
757
+ # half the attention head size is usually a good trade-off between
758
+ # speed and memory
759
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
760
+ elif slice_size == "max":
761
+ # make smallest slice possible
762
+ slice_size = num_sliceable_layers * [1]
763
+
764
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
765
+
766
+ if len(slice_size) != len(sliceable_head_dims):
767
+ raise ValueError(
768
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
769
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
770
+ )
771
+
772
+ for i in range(len(slice_size)):
773
+ size = slice_size[i]
774
+ dim = sliceable_head_dims[i]
775
+ if size is not None and size > dim:
776
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
777
+
778
+ # Recursively walk through all the children.
779
+ # Any children which exposes the set_attention_slice method
780
+ # gets the message
781
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
782
+ if hasattr(module, "set_attention_slice"):
783
+ module.set_attention_slice(slice_size.pop())
784
+
785
+ for child in module.children():
786
+ fn_recursive_set_attention_slice(child, slice_size)
787
+
788
+ reversed_slice_size = list(reversed(slice_size))
789
+ for module in self.children():
790
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
791
+
792
+ def _set_gradient_checkpointing(self, module, value=False):
793
+ if hasattr(module, "gradient_checkpointing"):
794
+ module.gradient_checkpointing = value
795
+
796
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # Convert a boolean/0-1 attention mask into an additive bias
        # (kept positions -> 0, discarded positions -> -10000).
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary (map [0, 1] -> [-1, 1])
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        # Additional conditioning embeddings (model-family specific).
        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            # hint is concatenated onto the latent channels
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        # Optionally project the text/image conditioning into the UNet's
        # cross-attention dimension.
        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            # copy so the caller's dict is not mutated by the pop below
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down

        # ControlNet passes both mid and down residuals; T2I-Adapter passes
        # only down residuals.
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                # NOTE: pop(0) mutates the caller-supplied residual list.
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            # add the ControlNet residuals to each skip connection
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # consume the skip connections produced by the matching down block
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size
                )

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
1089
+
1090
+ @classmethod
1091
+ def load_referencenet(cls, pretrained_model_path):
1092
+ print(f"loaded ReferenceNet's pretrained weights from {pretrained_model_path} ...")
1093
+
1094
+ config = {
1095
+ "_class_name": "UNet2DConditionModel",
1096
+ "_diffusers_version": "0.6.0",
1097
+ "act_fn": "silu",
1098
+ "attention_head_dim": 8,
1099
+ "block_out_channels": [320, 640, 1280, 1280],
1100
+ "center_input_sample": False,
1101
+ "cross_attention_dim": 768,
1102
+ "down_block_types": [
1103
+ "CrossAttnDownBlock2D",
1104
+ "CrossAttnDownBlock2D",
1105
+ "CrossAttnDownBlock2D",
1106
+ "DownBlock2D"
1107
+ ],
1108
+ "downsample_padding": 1,
1109
+ "flip_sin_to_cos": True,
1110
+ "freq_shift": 0,
1111
+ "in_channels": 4,
1112
+ "layers_per_block": 2,
1113
+ "mid_block_scale_factor": 1,
1114
+ "norm_eps": 1e-05,
1115
+ "norm_num_groups": 32,
1116
+ "out_channels": 4,
1117
+ "sample_size": 64,
1118
+ "up_block_types": [
1119
+ "UpBlock2D",
1120
+ "CrossAttnUpBlock2D",
1121
+ "CrossAttnUpBlock2D",
1122
+ "CrossAttnUpBlock2D"
1123
+ ]
1124
+ }
1125
+
1126
+ # from diffusers.utils import WEIGHTS_NAME
1127
+ model = cls.from_config(config)
1128
+ # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1129
+ model_file = pretrained_model_path
1130
+
1131
+ if not os.path.isfile(model_file):
1132
+ raise RuntimeError(f"{model_file} does not exist")
1133
+ state_dict = torch.load(model_file, map_location="cpu")
1134
+
1135
+ m, u = model.load_state_dict(state_dict, strict=False)
1136
+ if m or u:
1137
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1138
+ print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
1139
+
1140
+ # params = [p.numel() for n, p in model.named_parameters() if "2D" in n]
1141
+ # print(f"### 2D Module Parameters: {sum(params) / 1e6} M")
1142
+
1143
+ params = [p.numel() for n, p in model.named_parameters()]
1144
+ print(f"### Module Parameters: {sum(params) / 1e6} M")
1145
+
1146
+ return model
src/multiview_consist_edit/models/ReferenceNet_attention_multi_fp16.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from einops import rearrange
7
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8
+
9
+ from diffusers.models.attention import BasicTransformerBlock
10
+ from .attention import BasicTransformerBlock as _BasicTransformerBlock
11
+
12
def torch_dfs(model: torch.nn.Module):
    """Return *model* followed by all descendant modules in pre-order (DFS)."""
    return [model] + [node for child in model.children() for node in torch_dfs(child)]
17
+
18
+
19
class ReferenceNetAttention():
    """Mutual self-attention bridge between a ReferenceNet ("write") and the
    denoising UNet ("read").

    Each `BasicTransformerBlock` in the wrapped UNet gets its `forward`
    monkey-patched: in "write" mode the block stores its pre-attention hidden
    states in a per-module `bank`; in "read" mode the block concatenates the
    banked reference features into its self-attention keys/values.
    Adapted from magic-animate's mutual_self_attention.py.
    """

    def __init__(self,
                 unet,
                 mode="write",
                 do_classifier_free_guidance=False,
                 attention_auto_machine_weight = float('inf'),
                 gn_auto_machine_weight = 1.0,
                 style_fidelity = 1.0,
                 reference_attn=True,
                 fusion_blocks="full",
                 batch_size=1,
                 is_image=False,
                 ) -> None:
        # 10. Modify self attention and group norm
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.fusion_blocks = fusion_blocks
        # NOTE(review): `fusion_blocks` is passed as the 7th *positional*
        # argument below, so it actually binds to `register_reference_hooks`'s
        # `dtype` parameter, not to its `fusion_blocks` parameter. Harmless
        # today because both locals are unused inside the hooks (the hooks read
        # `self.fusion_blocks` instead) — confirm before relying on either.
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            fusion_blocks,
            batch_size=batch_size,
            is_image=is_image,
        )

    def register_reference_hooks(
            self,
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            dtype=torch.float16,
            batch_size=1,
            num_images_per_prompt=1,
            device=torch.device("cpu"),
            fusion_blocks='midup',
            is_image=False,
    ):
        """Patch every fused `BasicTransformerBlock.forward` with the
        read/write logic below and attach an empty feature `bank` to each."""
        # Bind the arguments to locals so the closures below capture them.
        MODE = mode
        do_classifier_free_guidance = do_classifier_free_guidance
        attention_auto_machine_weight = attention_auto_machine_weight
        gn_auto_machine_weight = gn_auto_machine_weight
        style_fidelity = style_fidelity
        reference_attn = reference_attn
        fusion_blocks = fusion_blocks
        num_images_per_prompt = num_images_per_prompt
        dtype=dtype

        def fully_self_attn(self, hidden_states, norm_hidden_states, attention_mask, garment_fea_attn=True):
            """Self-attention over all frames jointly, optionally with the
            banked garment (reference) features appended to the sequence.
            `self` here is the patched transformer block."""
            # The garment features were never rearranged, so their leading dim
            # is the true batch size b (no need to merge b and f into bf).
            b = self.bank[0].shape[0]
            p,l,c = norm_hidden_states.shape
            f = p//b  # number of frames/views folded into the batch dim
            # Fold all frames of a sample into one long token sequence.
            norm_hidden_states = rearrange(norm_hidden_states, "(b f) l c -> b (f l) c",b=b)
            # add front view and back view feature
            if garment_fea_attn:
                # self.bank[0] = self.bank[0][0].unsqueeze(0)
                # self.bank[1] = self.bank[1][0].unsqueeze(0)
                # print('check2', norm_hidden_states.shape, self.bank[0].shape)
                modify_norm_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)
            else:
                modify_norm_hidden_states = norm_hidden_states

            # NOTE(review): `garment_fea_attn` is forwarded into attn1 — this
            # assumes the project-local Attention accepts that kwarg; confirm.
            hidden_states_uc = self.attn1(modify_norm_hidden_states,
                                          encoder_hidden_states=modify_norm_hidden_states,
                                          attention_mask=attention_mask,garment_fea_attn=garment_fea_attn)
            # Drop the appended reference tokens; keep only the f*l query tokens.
            hidden_states_uc = hidden_states_uc[:, :(f*l), :]
            hidden_states_uc = rearrange(hidden_states_uc, "b (f l) c -> (b f) l c", b=b, f=f)
            hidden_states_uc = hidden_states_uc + hidden_states  # residual
            return hidden_states_uc

        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
        ):
            """Replacement for `BasicTransformerBlock.forward`; behavior
            depends on the captured MODE ("write" banks features, "read"
            attends over them)."""
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if MODE == "write":
                    # Bank the normalized hidden states for a later "read" pass.
                    self.bank.append(norm_hidden_states.clone())
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if MODE == "read":
                    if not is_image:
                        # Repeat the banked reference features per frame so they
                        # align with the (b*f) batch layout of hidden_states.
                        self.bank = [rearrange(d.unsqueeze(1).repeat(1, video_length, 1, 1), "b t l c -> (b t) l c")[:hidden_states.shape[0]] for d in self.bank]

                    # revise here
                    if True:  # must always be True here; the False branch is the older image-level code path
                        if do_classifier_free_guidance:
                            _uc_mask_top = (
                                torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
                                .to(device)
                                .bool()
                            )
                            _uc_mask_bottom = (
                                torch.Tensor([0] * (hidden_states.shape[0]//2) + [1] * (hidden_states.shape[0]//2))
                                .to(device)
                                .bool()
                            )
                            # First half of the batch is unconditional, second half is conditional.
                            # Only the conditional half attends to the garment features.
                            hidden_states_uc = norm_hidden_states.clone()
                            hidden_states_uc[_uc_mask_top] = fully_self_attn(self, hidden_states[_uc_mask_top], norm_hidden_states[_uc_mask_top], attention_mask, garment_fea_attn=False)
                            hidden_states_uc[_uc_mask_bottom] = fully_self_attn(self, hidden_states[_uc_mask_bottom], norm_hidden_states[_uc_mask_bottom], attention_mask, garment_fea_attn=True)
                            hidden_states = hidden_states_uc.clone()
                        else:
                            hidden_states_uc = fully_self_attn(self, hidden_states, norm_hidden_states, attention_mask, garment_fea_attn=True)
                            hidden_states = hidden_states_uc.clone()

                    else:
                        # modify Reference Sec 3.2.2 (dead branch, kept for reference)
                        modify_norm_hidden_states = torch.cat([norm_hidden_states] + self.bank, dim=1)

                        hidden_states_uc = self.attn1(modify_norm_hidden_states,
                                            encoder_hidden_states=modify_norm_hidden_states,
                                            attention_mask=attention_mask)[:,:hidden_states.shape[-2],:] + hidden_states

                        hidden_states_c = hidden_states_uc.clone()
                        # NOTE(review): `uc_mask` is not defined anywhere in this
                        # scope — this branch would raise NameError if it ever ran.
                        _uc_mask = uc_mask.clone()
                        if do_classifier_free_guidance:
                            if hidden_states.shape[0] != _uc_mask.shape[0]:
                                _uc_mask = (
                                    torch.Tensor([1] * (hidden_states.shape[0]//2) + [0] * (hidden_states.shape[0]//2))
                                    .to(device)
                                    .bool()
                                )
                            # print('111111', _uc_mask.shape, norm_hidden_states.shape)
                            hidden_states_c[_uc_mask] = self.attn1(
                                norm_hidden_states[_uc_mask],
                                encoder_hidden_states=norm_hidden_states[_uc_mask],
                                attention_mask=attention_mask,
                            ) + hidden_states[_uc_mask]
                        hidden_states = hidden_states_c.clone()

                    # self.bank.clear()  # bank is cleared externally via `clear()`

                    if self.attn2 is not None:
                        # Cross-Attention
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                        )
                        hidden_states = (
                            self.attn2(
                                norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                            )
                            + hidden_states
                        )

                    # Feed-forward
                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

                    # Temporal-Attention
                    if not is_image:
                        if self.unet_use_temporal_attention:
                            d = hidden_states.shape[1]
                            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
                            norm_hidden_states = (
                                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
                            )
                            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
                            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

                    # "read" mode handles attn2/ff/temporal itself and exits here.
                    return hidden_states

            # --- original (non-"read") tail of BasicTransformerBlock.forward ---
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

        if self.reference_attn:
            # Collect the transformer blocks to patch (diffusers' and the
            # project-local BasicTransformerBlock variants).
            if self.fusion_blocks == "midup":
                attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            # Sort by descending hidden size so reader/writer module lists align.
            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

    def update(self, writer, dtype=torch.float16):
        """Copy the writer's banked reference features into this reader's
        matching transformer blocks (cast to `dtype`)."""
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, _BasicTransformerBlock)]
                writer_attn_modules = [module for module in (torch_dfs(writer.unet.mid_block)+torch_dfs(writer.unet.up_blocks)) if isinstance(module, BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, _BasicTransformerBlock) or isinstance(module, BasicTransformerBlock)]
                writer_attn_modules = [module for module in torch_dfs(writer.unet) if isinstance(module, _BasicTransformerBlock) or isinstance(module, BasicTransformerBlock)]
            # Same descending-hidden-size order as in register_reference_hooks,
            # so zip() pairs corresponding reader/writer blocks.
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            writer_attn_modules = sorted(writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])

            if len(reader_attn_modules) == 0:
                print('reader_attn_modules is null')
                assert False
            if len(writer_attn_modules) == 0:
                print('writer_attn_modules is null')
                assert False

            for r, w in zip(reader_attn_modules, writer_attn_modules):
                r.bank = [v.clone().to(dtype) for v in w.bank]
                # w.bank.clear()  # writer bank intentionally kept

    def clear(self):
        """Empty the feature banks of all patched transformer blocks."""
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [module for module in (torch_dfs(self.unet.mid_block)+torch_dfs(self.unet.up_blocks)) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock) or isinstance(module, _BasicTransformerBlock)]
            reader_attn_modules = sorted(reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
            for r in reader_attn_modules:
                r.bank.clear()
src/multiview_consist_edit/models/attention.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from dataclasses import dataclass
21
+ from typing import Optional
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch import nn
26
+
27
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.utils import BaseOutput
30
+ from diffusers.utils.import_utils import is_xformers_available
31
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
32
+ from diffusers.models.attention import Attention as CrossAttention
33
+
34
+ from einops import rearrange, repeat
35
+
36
+ @dataclass
37
+ class Transformer3DModelOutput(BaseOutput):
38
+ sample: torch.FloatTensor
39
+
40
+
41
# xformers is optional: import it when available, otherwise bind the name to
# None so downstream code can check for availability at runtime.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
46
+
47
+
48
class Transformer3DModel(ModelMixin, ConfigMixin):
    """Spatial transformer applied frame-wise to 5-D video latents.

    The input `(b, c, f, h, w)` is folded to `(b*f, c, h, w)`, normalized and
    projected to `inner_dim`, run through a stack of `BasicTransformerBlock`s
    (which receive `video_length` for cross-frame handling), projected back,
    and unfolded to 5-D with a residual connection.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        # NOTE(review): in the linear branch proj_out maps
        # in_channels -> inner_dim rather than inner_dim -> in_channels; this
        # is only shape-consistent when inner_dim == in_channels — confirm.
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        """Run the transformer over every frame of a video latent.

        Args:
            hidden_states: tensor of shape `(batch, channel, frames, height, width)`.
            encoder_hidden_states: conditioning of shape `(batch, seq, dim)`,
                repeated per frame when its batch dim does not already match.
                NOTE(review): despite the `None` default, a `None` value fails
                at the `.shape` access below — confirm callers always pass it.
            timestep: optional timestep forwarded to the blocks (AdaLayerNorm).
            return_dict: if True, wrap the output in `Transformer3DModelOutput`.
        """
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        # Fold frames into the batch dimension.
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        # JH: need not repeat when a list of prompts are given
        if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        # Conv projection runs in (b, c, h, w); the linear projection runs
        # after flattening spatial positions into tokens.
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output (mirror image of the input projection ordering)
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        # Unfold frames back out of the batch dimension.
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)
162
+
163
+
164
class BasicTransformerBlock(nn.Module):
    """Transformer block: self-attention, optional cross-attention,
    feed-forward, and optional temporal attention over video frames.

    ``unet_use_cross_frame_attention`` selects a sparse-causal attention that
    attends across frames; ``unet_use_temporal_attention`` appends an
    attention pass over the frame axis. Both flags must be explicitly set
    (asserted non-None below).
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention = None,
        unet_use_temporal_attention = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        # Timestep-conditioned AdaLayerNorm is used whenever an embedding
        # count is supplied; otherwise plain LayerNorm.
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn: first attention (sparse-causal cross-frame or plain self-attn)
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            self.attn1 = SparseCausalAttention2D(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn: only created when a conditioning dimension is given.
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)
        self.use_ada_layer_norm_zero = False

        # Temp-Attn: attention over the frame axis. The output projection is
        # zero-initialized so the block starts out as an identity mapping.
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
        # Validate that xformers is installed and a GPU is present, and do a
        # smoke-test call, before toggling the flag on attn1/attn2.
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.attn2 is not None:
            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        """Apply self-attn, cross-attn (if configured), feed-forward, and
        optional temporal attention, each with a residual connection.

        hidden_states is (batch*frames, tokens, dim); video_length is the
        frame count used by the cross-frame / temporal paths.
        """
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        if self.unet_use_cross_frame_attention:
            # Cross-frame attention additionally needs the frame count.
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention against the conditioning tokens.
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention: fold the spatial axis into the batch, attend
        # over frames, then restore the (b f) d c layout.
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
src/multiview_consist_edit/models/condition_encoder.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import kornia
4
+ import open_clip
5
+ from torch.utils.checkpoint import checkpoint
6
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
7
+ # from lvdm.common import autocast
8
+ # from utils.utils import count_params
9
+
10
+ # from https://github.com/Doubiiu/DynamiCrafter/blob/main/lvdm/modules/encoders/condition.py
11
+
12
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement encode()."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        # Concrete encoders must override this.
        raise NotImplementedError
18
+
19
+
20
class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder: returns the condition unchanged."""

    def encode(self, x):
        # Useful when the condition is already in the target representation.
        return x
23
+
24
+
25
class ClassEmbedder(nn.Module):
    """Embed integer class labels for cross-attention conditioning.

    During training, each label is replaced with probability ``ucg_rate`` by
    the last class index (reserved as the unconditional class), implementing
    classifier-free-guidance label dropout.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        lookup_key = self.key if key is None else key
        # Add a token axis so the result can be used as cross-attn context.
        c = batch[lookup_key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # Randomly swap labels for the dedicated unconditional class.
            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        return self.embedding(c)

    def get_unconditional_conditioning(self, bs, device="cuda"):
        # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc_class = self.n_classes - 1
        uc = torch.ones((bs,), device=device) * uc_class
        return {self.key: uc}
50
+
51
+
52
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores mode changes.

    Assign this to a frozen module's ``train`` attribute so later calls to
    ``.train()`` / ``.eval()`` no longer toggle its training state.
    """
    return self
56
+
57
+
58
class FrozenT5Embedder(AbstractEncoder):
    """Frozen T5 text encoder; returns the last hidden states for a prompt."""

    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
                 freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and stop all gradients."""
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=token_ids).last_hidden_state

    def encode(self, text):
        return self(text)
88
+
89
+
90
class FrozenCLIPEmbedder(AbstractEncoder):
    """Frozen HuggingFace CLIP text encoder.

    ``layer`` selects the returned representation: the final hidden states
    ("last"), the pooled output ("pooled"), or an intermediate hidden state
    ("hidden", chosen via ``layer_idx``).
    """
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the transformer in eval mode and stop all gradients."""
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.device)
        # Hidden states are only materialized when an intermediate layer is requested.
        outputs = self.transformer(input_ids=token_ids, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            return outputs.last_hidden_state
        if self.layer == "pooled":
            return outputs.pooler_output[:, None, :]
        return outputs.hidden_states[self.layer_idx]

    def encode(self, text):
        return self(text)
135
+
136
+
137
class ClipImageEmbedder(nn.Module):
    """Image embedder built on the original OpenAI CLIP implementation.

    Inputs are expected in [-1, 1]; they are resized to 224x224 and
    normalized with CLIP statistics before encoding.
    """

    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=True,
            ucg_rate=0.
    ):
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        # CLIP preprocessing statistics (kept out of the state dict).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # Resize, map [-1, 1] -> [0, 1], then apply CLIP normalization.
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        return kornia.enhance.normalize(x, self.mean, self.std)

    def forward(self, x, no_dropout=False):
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            # Zero whole embeddings at random (classifier-free guidance dropout).
            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out
173
+
174
+
175
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
        # Only the text tower is needed; drop the vision tower to save memory.
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx counts resblocks to skip from the end of the stack.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Put the model in eval mode and stop all gradients."""
        self.model = self.model.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)  ## all clip models use 77 as context length
        return self.encode_with_transformer(tokens.to(self.device))

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        return self.model.ln_final(x)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        blocks = self.model.transformer.resblocks
        for i, block in enumerate(blocks):
            # Stop early when a non-final layer was requested.
            if i == len(blocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
236
+
237
+
238
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version, )
        # Only the vision tower is needed; drop the text transformer.
        del model.transformer
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1

        self.antialias = antialias

        # CLIP preprocessing statistics (kept out of the state dict).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # Resize, map [-1, 1] -> [0, 1], then apply CLIP normalization.
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        return kornia.enhance.normalize(x, self.mean, self.std)

    def freeze(self):
        """Put the model in eval mode and stop all gradients."""
        self.model = self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    # @autocast
    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            # Zero whole embeddings at random (classifier-free guidance dropout).
            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        return self.model.visual(self.preprocess(img))

    def encode(self, text):
        # NOTE(review): the argument is an image despite the parameter name.
        return self(text)
295
+
296
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images.

    Unlike FrozenOpenCLIPImageEmbedder, this variant drives the ViT layers
    manually and returns the full token sequence (class token + patch
    tokens) before any pooling or final projection.
    """

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
                 freeze=True, layer="pooled", antialias=True, model_path=None):
        super().__init__()

        # model_path allows loading weights from a local checkpoint instead
        # of the named pretrained tag.
        if model_path is None:
            model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                                pretrained=version, )
        else:
            model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                                pretrained=model_path)
        # Only the vision tower is used; drop the text transformer to save memory.
        del model.transformer
        self.model = model
        self.device = device

        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1  # unreachable (kept from the original)

        self.antialias = antialias

        # CLIP preprocessing statistics (kept out of the state dict).
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)


    def preprocess(self, x):
        # Resize to CLIP's input size; input is assumed to be in [-1, 1].
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic', align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        # Eval mode + no gradients for the whole vision tower.
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        ## image: b c h w   (no_dropout is accepted for API parity but unused here)
        z = self.encode_with_vision_transformer(image)
        return z

    def encode_with_vision_transformer(self, x):
        x = self.preprocess(x)

        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
        if self.model.visual.input_patchnorm:
            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
            x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1)
            x = self.model.visual.patchnorm_pre_ln(x)
            x = self.model.visual.conv1(x)
        else:
            x = self.model.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat(
            [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.model.visual.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.model.visual.patch_dropout(x)
        x = self.model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        return x
379
+
380
class FrozenCLIPT5Encoder(AbstractEncoder):
    """Joint text encoder returning both CLIP and T5 embeddings for a prompt."""

    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)

        # Count parameters inline: the original called `count_params`, whose
        # import (`from utils.utils import count_params`) is commented out at
        # the top of this file, so constructing this class raised NameError.
        def _count_params(model):
            return sum(p.numel() for p in model.parameters())

        print(f"{self.clip_encoder.__class__.__name__} has {_count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {_count_params(self.t5_encoder) * 1.e-6:.2f} M params.")

    def encode(self, text):
        return self(text)

    def forward(self, text):
        # Returns the pair [clip_embeddings, t5_embeddings].
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
src/multiview_consist_edit/models/embeddings.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+
27
+
28
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Build sinusoidal timestep embeddings (DDPM-style).

    Args:
        timesteps: 1-D tensor of N (possibly fractional) timestep indices.
        embedding_dim: size of each output embedding vector.
        flip_sin_to_cos: when True, order the halves as [cos, sin] instead
            of the default [sin, cos].
        downscale_freq_shift: offset applied to the frequency denominator.
        scale: multiplier applied to the raw angle values.
        max_period: controls the minimum frequency of the embeddings.

    Returns:
        Tensor of shape [N, embedding_dim].
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometric frequency ladder from 1 down to ~1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
        / (half_dim - downscale_freq_shift)
    )
    angles = scale * (timesteps[:, None].float() * freqs[None, :])

    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    # Odd embedding dims get one zero column of padding on the right.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
69
+
70
+
71
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Return 2D sin-cos positional embeddings for a square grid.

    Output shape is [grid_size*grid_size, embed_dim], with ``extra_tokens``
    zero rows prepended when ``cls_token`` is set.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    grid = np.stack(np.meshgrid(coords, coords), axis=0)  # here w goes first
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D embeddings of the two grid axes into a 2D embedding."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # Half of the dimensions encode each axis.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Embed positions ``pos`` (any shape, flattened to (M,)) as (M, embed_dim)."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    angles = np.einsum("m,d->md", pos.reshape(-1), omega)  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
119
+
120
+
121
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding with fixed sin-cos position embeddings."""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        # Non-overlapping conv acts as the per-patch linear projection.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None

        # Fixed (non-learned) sin-cos table; assumes a square patch grid.
        pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        return latent + self.pos_embed
159
+
160
+
161
class TimestepEmbedding(nn.Module):
    """Two-layer MLP mapping sinusoidal timestep features into the UNet's
    time-embedding space, with an optional additive conditioning projection."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        def make_act(name):
            # Shared activation factory for the mid and post activations.
            if name == "silu":
                return nn.SiLU()
            if name == "mish":
                return nn.Mish()
            if name == "gelu":
                return nn.GELU()
            raise ValueError(f"{name} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        # Optional projection that lets an extra condition be added to the input.
        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim is not None else None

        self.act = make_act(act_fn)
        self.linear_2 = nn.Linear(time_embed_dim, out_dim if out_dim is not None else time_embed_dim)
        self.post_act = None if post_act_fn is None else make_act(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
219
+
220
+
221
class Timesteps(nn.Module):
    """Module wrapper around get_timestep_embedding with stored settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        # Delegate to the functional implementation with the stored config.
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
236
+
237
+
238
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        # Random, frozen projection frequencies.
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        angles = x[:, None] * self.weight[None, :] * 2 * np.pi

        halves = [torch.sin(angles), torch.cos(angles)]
        if self.flip_sin_to_cos:
            halves.reverse()
        return torch.cat(halves, dim=-1)
266
+
267
+
268
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings and adds learned 2-D
    positional embeddings over the latent height/width grid.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion the summed embeddings are used as transformer input.
    Note that these transformer embeddings are distinct from the VQVAE's own
    vector embeddings.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        """Embed latent-pixel classes ``index`` and add positional embeddings."""
        tokens = self.emb(index)

        rows = torch.arange(self.height, device=index.device).view(1, self.height)
        cols = torch.arange(self.width, device=index.device).view(1, self.width)

        # (1, H, 1, D) + (1, 1, W, D) broadcasts to (1, H, W, D),
        # then flatten the grid to a token sequence (1, H*W, D).
        grid = self.height_emb(rows).unsqueeze(2) + self.width_emb(cols).unsqueeze(1)
        grid = grid.view(1, self.height * self.width, -1)

        # Truncate positions to the actual sequence length of `index`.
        return tokens + grid[:, : tokens.shape[1], :]
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations, with label dropout for
    classifier-free guidance: dropped labels map to an extra "null" row.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # One extra embedding row is reserved for the "dropped" label
        # whenever CFG dropout is enabled.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace dropped labels with the null-class index ``num_classes``.

        When ``force_drop_ids`` is given, positions equal to 1 are dropped
        deterministically; otherwise dropping is Bernoulli(``dropout_prob``).
        """
        if force_drop_ids is not None:
            drop_ids = torch.tensor(force_drop_ids == 1)
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels)

    def forward(self, labels, force_drop_ids=None):
        should_drop = self.dropout_prob > 0
        # Random dropout only applies in training; forced drops always apply.
        if force_drop_ids is not None or (self.training and should_drop):
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)
class CombinedTimestepLabelEmbeddings(nn.Module):
    """Sum of a timestep embedding and a class-label embedding, shape (N, D)."""

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        # Sinusoidal projection is always fp32; cast to the model dtype
        # before running it through the embedding MLP.
        t_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_dtype))  # (N, D)
        cls_emb = self.class_embedder(class_labels)  # (N, D)
        return t_emb + cls_emb  # (N, D)
src/multiview_consist_edit/models/hack_poseguider.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.init as init
5
+ from einops import rearrange
6
+ import numpy as np
7
+
8
class Hack_PoseGuider(nn.Module):
    """CNN pose encoder producing a residual for the UNet's input latent.

    A ``(B, 3, H, W)`` pose image is downsampled 8x by three stride-2
    convolutions and projected to ``noise_latent_channels`` channels so it can
    be added directly onto the UNet's post-``conv_in`` latent. The final 1x1
    projection is zero-initialized, so an untrained guider contributes nothing.
    """

    def __init__(self, noise_latent_channels=320):
        super(Hack_PoseGuider, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(),
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),

            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )

        # Final projection layer into the UNet latent channel count.
        self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)

        # Initialize layers
        self._initialize_weights()

        # Learnable global gain on the pose residual (starts at 2).
        self.scale = nn.Parameter(torch.ones(1) * 2)

    def _initialize_weights(self):
        # He initialization for the conv stack, zero biases. The final
        # projection is zeroed so training starts from a no-op guider.
        for m in self.conv_layers:
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
                if m.bias is not None:
                    init.zeros_(m.bias)

        init.zeros_(self.final_proj.weight)
        if self.final_proj.bias is not None:
            init.zeros_(self.final_proj.bias)

    def forward(self, x):
        """Encode a pose image; returns (B, noise_latent_channels, H/8, W/8)."""
        x = self.conv_layers(x)
        x = self.final_proj(x)

        return x * self.scale

    @classmethod
    def from_pretrained(cls, pretrained_model_path, noise_latent_channels=320):
        """Build a guider and load weights from ``pretrained_model_path``.

        Args:
            pretrained_model_path: path to a ``torch.save``-d state dict.
            noise_latent_channels: latent channel count of the target UNet
                (kept at the historical default of 320 for compatibility).

        Raises:
            FileNotFoundError: if the checkpoint path does not exist. (The
                original code only printed a warning here and then crashed
                inside ``torch.load`` anyway.)
        """
        if not os.path.exists(pretrained_model_path):
            raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
        print(f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ...")

        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        model = cls(noise_latent_channels=noise_latent_channels)

        # strict=False tolerates checkpoints saved with extra/missing keys.
        m, u = model.load_state_dict(state_dict, strict=False)
        params = [p.numel() for n, p in model.named_parameters()]
        print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")

        return model
src/multiview_consist_edit/models/hack_unet2d.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.utils.checkpoint
7
+ # from diffusers import UNet2DConditionModel
8
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel,UNet2DConditionOutput,logger
9
+
10
+
11
class Hack_UNet2DConditionModel(UNet2DConditionModel):
    # "Hacked" UNet: identical to diffusers' UNet2DConditionModel.forward
    # except for the extra required `latent_pose` argument, which is added to
    # the latent right after `conv_in` (see step 2 below).
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        latent_pose: torch.Tensor, # new add

        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            latent_pose (`torch.Tensor`):
                Pose-guidance latent added element-wise to the sample right
                after `conv_in`; assumed to match the post-`conv_in` latent
                shape `(batch, block_out_channels[0], h, w)` — TODO confirm
                against the PoseGuider output.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        # Model-variant-specific auxiliary embeddings (SDXL, Kandinsky, ...);
        # each branch validates the `added_cond_kwargs` keys it needs.
        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        # Optional projection of the cross-attention conditioning.
        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)

        # add latent_pose
        # This is the only change versus the stock diffusers forward: the pose
        # latent is injected as a residual before the down blocks.
        sample = sample + latent_pose

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        # ControlNet supplies both mid and down residuals; T2I-Adapter only down.
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
src/multiview_consist_edit/models/mv_attn_processor.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers.utils import USE_PEFT_BACKEND
2
+ from typing import Callable, Optional
3
+ import torch
4
+ from diffusers.models.attention_processor import Attention
5
+ from diffusers.utils.import_utils import is_xformers_available
6
+ if is_xformers_available():
7
+ import xformers
8
+ import xformers.ops
9
+ else:
10
+ xformers = None
11
+
12
class MVXFormersAttnProcessor:
    r"""
    xFormers memory-efficient attention processor with per-view key
    re-weighting for multi-view consistent generation.

    ``weight_matrix`` is indexed as ``weight_matrix[b, i, :]`` where ``b`` is
    the batch index and ``i`` the query view, giving one scalar weight per key
    view (presumably derived from camera relationships — verify against the
    caller that sets it).

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, weight_matrix=None, attention_op: Optional[Callable] = None):
        # BUG FIX: the original used `if weight_matrix:`, which raises
        # "Boolean value of Tensor with more than one element is ambiguous"
        # for any real (batch, views) tensor. Compare against None instead.
        if weight_matrix is not None:
            self.bs = weight_matrix.shape[0]
            self.frame_length = weight_matrix.shape[1]
            self.weight_matrix = weight_matrix
        self.attention_op = attention_op

    def update_weight_matrix(self, weight_matrix):
        """Install a new weight matrix (e.g. for a new batch of viewpoints)."""
        self.bs = weight_matrix.shape[0]
        self.frame_length = weight_matrix.shape[1]
        self.weight_matrix = weight_matrix

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        garment_fea_attn = True,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run weighted multi-view attention.

        All views of a sample are concatenated along the token dimension;
        when ``garment_fea_attn`` is True the sequence additionally carries
        two garment-feature frames whose attention weight is fixed to 1.
        """

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # Flatten (B, C, H, W) to (B, H*W, C) for attention.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attn_out = torch.empty_like(query)

        if garment_fea_attn:
            frame_length = self.frame_length + 2  # 2 for two garments
        else:
            frame_length = self.frame_length
        token_num_per_frame = query.shape[1] // frame_length
        heads_num = attn.heads
        # Attention is computed per (batch, query-view) pair: the query slice
        # for view `i` attends to ALL keys, with each key view scaled by
        # weight_matrix[b, i, :] (garment frames get weight 1).
        for b in range(self.bs):
            for i in range(self.frame_length):
                curr_q = query[heads_num*b:heads_num*(b+1), token_num_per_frame*i:token_num_per_frame*(i+1), :]
                weight = self.weight_matrix[b, i, :]
                if garment_fea_attn:
                    weight = torch.cat([weight, torch.tensor([1, 1], dtype=weight.dtype, device=weight.device)], dim=0)  # garment's attn weight set 1
                weight = weight.repeat_interleave(token_num_per_frame)
                curr_k = key[heads_num*b:heads_num*(b+1)]
                curr_v = value[heads_num*b:heads_num*(b+1)]
                weight = weight.unsqueeze(0).unsqueeze(-1)
                # Scaling K is equivalent to biasing attention logits per view.
                curr_k = weight * curr_k
                hidden_states = xformers.ops.memory_efficient_attention(
                    curr_q, curr_k, curr_v, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
                )
                attn_out[heads_num*b:heads_num*(b+1), token_num_per_frame*i:token_num_per_frame*(i+1), :] = hidden_states
        hidden_states = attn_out
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
src/multiview_consist_edit/models/resnet.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
11
+ #
12
+ # Licensed under the Apache License, Version 2.0 (the "License");
13
+ # you may not use this file except in compliance with the License.
14
+ # You may obtain a copy of the License at
15
+ #
16
+ # http://www.apache.org/licenses/LICENSE-2.0
17
+ #
18
+ # Unless required by applicable law or agreed to in writing, software
19
+ # distributed under the License is distributed on an "AS IS" BASIS,
20
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
+ # See the License for the specific language governing permissions and
22
+ # limitations under the License.
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+
27
+ from einops import rearrange
28
+
29
+
30
class InflatedConv3d(nn.Conv2d):
    """A plain 2D convolution applied frame-wise to 5-D video tensors.

    Input/output layout is (batch, channel, frame, height, width); the frame
    axis is folded into the batch axis so the inherited Conv2d runs on every
    frame independently, then the original layout is restored.
    """

    def forward(self, x):
        batch, channels, frames, height, width = x.shape

        # (b, c, f, h, w) -> (b*f, c, h, w): frames become extra batch items.
        frames_as_batch = x.permute(0, 2, 1, 3, 4).reshape(
            batch * frames, channels, height, width
        )
        out = super().forward(frames_as_batch)

        # (b*f, c', h', w') -> (b, c', f, h', w'): restore the video layout.
        out_channels, out_h, out_w = out.shape[1:]
        return out.reshape(batch, frames, out_channels, out_h, out_w).permute(0, 2, 1, 3, 4)
39
+
40
+
41
class Upsample3D(nn.Module):
    """Nearest-neighbour 2x spatial upsampling for 5-D video tensors.

    Input layout is (batch, channel, frame, height, width). Only height and
    width are upsampled (scale factors [1, 2, 2]); the frame axis is left
    untouched. An optional frame-wise ``InflatedConv3d`` refines the result.

    Fix over the original: removed the dead local ``conv = None`` left over
    from the 2D diffusers implementation (it was assigned and never used).
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            # Transposed-conv upsampling from the 2D version is not ported.
            raise NotImplementedError
        elif use_conv:
            # NOTE(review): if both flags are False, self.conv is never set and
            # forward() will raise AttributeError — presumably callers always
            # pass use_conv=True; confirm against the up-block construction.
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        """Upsample ``hidden_states``; ``output_size`` overrides the 2x factor."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        hidden_states = self.conv(hidden_states)

        return hidden_states
85
+
86
+
87
class Downsample3D(nn.Module):
    """Strided-conv 2x spatial downsampling for 5-D video tensors.

    Input layout is (batch, channel, frame, height, width). A stride-2
    frame-wise ``InflatedConv3d`` halves height and width; the frame axis is
    left untouched.

    Fix over the original: removed a duplicated channel assert in ``forward``.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average-pool downsampling from the 2D version is not ported.
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            # Asymmetric zero-padding used by the 2D version is not ported.
            raise NotImplementedError

        hidden_states = self.conv(hidden_states)

        return hidden_states
111
+
112
+
113
class ResnetBlock3D(nn.Module):
    """Pre-norm ResNet block inflated to video tensors (b, c, f, h, w).

    GroupNorm -> act -> frame-wise conv, twice, with a timestep-embedding
    injection between the two convs and a residual connection (optionally
    through a 1x1 shortcut conv when channel counts differ).

    Fixes over the original:
    - removed the dead store ``self.pre_norm = pre_norm`` (it was immediately
      overwritten with ``True``; the argument is kept for config compatibility
      but the block always pre-normalizes),
    - unknown ``non_linearity`` now raises ``ValueError`` immediately instead
      of leaving ``self.nonlinearity`` unset and failing later in forward().
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        # Always pre-norm (the `pre_norm` argument is accepted but ignored,
        # matching the original behavior).
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                # Additive timestep embedding.
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                # FiLM-style: embedding is split into (scale, shift) halves.
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        else:
            # Fail fast instead of deferring to an AttributeError in forward().
            raise ValueError(f"unknown non_linearity : {non_linearity}")

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            # 1x1 conv matches the residual's channel count to the output.
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """Run the block.

        Args:
            input_tensor: (batch, in_channels, frames, height, width) features.
            temb: (batch, temb_channels) timestep embedding, or None.

        Returns:
            (batch, out_channels, frames, height, width) tensor, scaled by
            1 / output_scale_factor.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # Project to channel dim and broadcast over (frame, height, width).
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
208
+
209
+
210
class Mish(torch.nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, hidden_states):
        softplus = torch.nn.functional.softplus(hidden_states)
        return hidden_states * softplus.tanh()
src/multiview_consist_edit/models/unet.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
3
+ # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
4
+ # ytedance Inc..
5
+ # *************************************************************************
6
+
7
+ # Adapted from https://github.com/guoyww/AnimateDiff
8
+
9
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ from dataclasses import dataclass
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ import os
26
+ import json
27
+ import pdb
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.utils.checkpoint
32
+
33
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
34
+ from diffusers.models.modeling_utils import ModelMixin
35
+ from diffusers.utils import BaseOutput, logging
36
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
37
+ from .unet_3d_blocks import (
38
+ CrossAttnDownBlock3D,
39
+ CrossAttnUpBlock3D,
40
+ DownBlock3D,
41
+ UNetMidBlock3DCrossAttn,
42
+ UpBlock3D,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+ from .resnet import InflatedConv3d
47
+
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+
52
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Return container for UNet3DConditionModel.forward."""

    # Predicted sample; 5-D video tensor (batch, channel, frame, height, width).
    sample: torch.FloatTensor
55
+
56
+
57
class UNet3DConditionModel(ModelMixin, ConfigMixin):
    """A 2D Stable-Diffusion UNet inflated along a frame axis for video,
    with optional AnimateDiff-style motion modules.

    Samples are 5-D tensors (batch, channel, frame, height, width); every
    spatial conv is an ``InflatedConv3d`` that runs frame-wise.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        # Additional — AnimateDiff-style motion-module options.
        # NOTE(review): `motion_module_kwargs={}` is a shared mutable default;
        # safe only as long as no block mutates it.
        use_motion_module = False,
        motion_module_resolutions = ( 1,2,4,8 ),
        motion_module_mid_block = False,
        motion_module_decoder_only = False,
        motion_module_type = None,
        motion_module_kwargs = {},
        unet_use_cross_frame_attention = None,
        unet_use_temporal_attention = None,
        encoder_hid_dim: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time: sinusoidal projection followed by a small MLP.
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # Optional projection of encoder hidden states into the cross-attn dim.
        if encoder_hid_dim is not None:
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one value per down block.
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Downsampling resolution factor at this depth (1, 2, 4, 8); used
            # to decide whether a motion module is inserted here.
            res = 2 ** i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,

                # Encoder motion modules are skipped when decoder_only is set.
                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,

                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # Mirror of the encoder's resolution factor (8, 4, 2, 1).
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,

                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            # Collect head dims from every submodule that supports slicing.
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # pop() consumes from the end, so reverse to match collection order.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing on the 3D down/up blocks only.
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask: convert {0, 1} mask to additive bias.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.encoder_hid_proj is not None:
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

        # pre-process
        sample = self.conv_in(sample)

        # down: collect residuals for the skip connections.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)

            down_block_res_samples += res_samples

        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Pop this block's skip residuals off the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        """Build a 3D UNet from a pretrained 2D UNet checkpoint directory.

        Reads ``config.json``, rewrites the block types to their 3D variants,
        then loads the safetensors weights non-strictly (newly added temporal /
        motion parameters stay at their initial values).
        """
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D"
        ]
        config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ]
        config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
        from diffusers.utils import WEIGHTS_NAME
        # Used to load checkpoints saved by `accelerate` (safetensors format);
        # the imported WEIGHTS_NAME is deliberately overridden below.
        import safetensors
        WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
        model = cls.from_config(config, **unet_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        # state_dict = torch.load(model_file, map_location="cpu")
        state_dict = safetensors.torch.load_file(
            model_file, device="cpu"
        )

        # Non-strict load: report counts of missing/unexpected keys only.
        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")

        params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
        print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")

        return model
src/multiview_consist_edit/parse_tool/postprocess_parse.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Batch-run the human-parsing model over every generated try-on image under a
# root directory, writing a "parse_<name>.png" segmentation map next to each
# image. Images whose names contain 'cond' or 'parse' are skipped.
import sys
# sys.path.append('./')
from PIL import Image
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
import os
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
import argparse

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='script')

    # Add arguments: positional root directory containing per-garment subdirs.
    parser.add_argument('root', type=str)

    # Parse arguments.
    args = parser.parse_args()

    # root = '/GPUFS/sysu_gbli2_1/hzj/animate/output/image_output_tryon_1025_22000_test_multi_3_all2_mvg_back/'
    root = args.root
    # Parsing(0) — presumably 0 selects the GPU index; verify against
    # preprocess.humanparsing.run_parsing.
    parsing_model = Parsing(0)
    cloth_ids = os.listdir(root)

    for cloth_subroot in cloth_ids[:]:
        print(cloth_subroot)
        images = os.listdir(os.path.join(root, cloth_subroot))

        for image in images:
            # Skip conditioning images and already-produced parse maps.
            if 'cond' in image or 'parse' in image:
                continue
            human_img_path = os.path.join(root, cloth_subroot, image)
            human_img = Image.open(human_img_path)
            # The parser expects 384x512 input; upscale its output back to the
            # generated-image resolution (576x768).
            model_parse, _ = parsing_model(human_img.resize((384,512)))
            model_parse = model_parse.resize((576,768))
            model_parse_path = os.path.join(root, cloth_subroot, 'parse_'+image.replace('jpg','png'))
            # print(model_parse_path)
            model_parse.save(model_parse_path)
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/__init__.py ADDED
File without changes
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/datasets.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : datasets.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import numpy as np
16
+ import random
17
+ import torch
18
+ import cv2
19
+ from torch.utils import data
20
+ from utils.transforms import get_affine_transform
21
+
22
+
23
class LIPDataSet(data.Dataset):
    """LIP human-parsing dataset (train / val / trainval / test splits).

    Reads image / segmentation pairs listed in ``<dataset>_id.txt`` under
    ``root`` and returns an affine-cropped image (plus the parsing label and
    augmentation metadata for non-test splits).
    """

    def __init__(self, root, dataset, crop_size=[473, 473], scale_factor=0.25,
                 rotation_factor=30, ignore_label=255, transform=None):
        """
        root: dataset root directory.
        dataset: split name; also the prefix of the id-list / image / segmentation folders.
        crop_size: output (height, width) of the affine crop.
        scale_factor / rotation_factor: augmentation ranges (train/trainval only).
        ignore_label: label value treated as ignore (kept for reference).
        transform: optional callable applied to the cropped image.
        """
        self.root = root
        self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]
        self.crop_size = np.asarray(crop_size)
        self.ignore_label = ignore_label
        self.scale_factor = scale_factor
        self.rotation_factor = rotation_factor
        self.flip_prob = 0.5
        self.transform = transform
        self.dataset = dataset

        list_path = os.path.join(self.root, self.dataset + '_id.txt')
        train_list = [i_id.strip() for i_id in open(list_path)]

        self.train_list = train_list
        self.number_samples = len(self.train_list)

    def __len__(self):
        return self.number_samples

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box to a (center, scale) pair."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Return the box center and its extent expanded to the crop aspect ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        train_item = self.train_list[index]

        im_path = os.path.join(self.root, self.dataset + '_images', train_item + '.jpg')
        parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', train_item + '.png')

        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        h, w, _ = im.shape
        # Bug fix: np.long was removed in NumPy 1.24; np.int64 is the portable
        # equivalent. (This placeholder is only ever used before the 'test'
        # split returns early, so the dtype change cannot affect outputs.)
        parsing_anno = np.zeros((h, w), dtype=np.int64)

        # Get person center and scale for the whole frame.
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0

        if self.dataset != 'test':
            # Get parsing annotation
            parsing_anno = cv2.imread(parsing_anno_path, cv2.IMREAD_GRAYSCALE)
            if self.dataset == 'train' or self.dataset == 'trainval':
                # Random scale / rotation augmentation.
                sf = self.scale_factor
                rf = self.rotation_factor
                s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
                r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) if random.random() <= 0.6 else 0

                if random.random() <= self.flip_prob:
                    # Horizontal flip; swap paired left/right body-part labels.
                    im = im[:, ::-1, :]
                    parsing_anno = parsing_anno[:, ::-1]
                    person_center[0] = im.shape[1] - person_center[0] - 1
                    right_idx = [15, 17, 19]
                    left_idx = [14, 16, 18]
                    for i in range(0, 3):
                        right_pos = np.where(parsing_anno == right_idx[i])
                        left_pos = np.where(parsing_anno == left_idx[i])
                        parsing_anno[right_pos[0], right_pos[1]] = left_idx[i]
                        parsing_anno[left_pos[0], left_pos[1]] = right_idx[i]

        trans = get_affine_transform(person_center, s, r, self.crop_size)
        input = cv2.warpAffine(
            im,
            trans,
            (int(self.crop_size[1]), int(self.crop_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        if self.transform:
            input = self.transform(input)

        meta = {
            'name': train_item,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        if self.dataset == 'val' or self.dataset == 'test':
            return input, meta
        else:
            # Warp the label map with nearest-neighbour sampling; out-of-image
            # regions are filled with 255 (the ignore label).
            label_parsing = cv2.warpAffine(
                parsing_anno,
                trans,
                (int(self.crop_size[1]), int(self.crop_size[0])),
                flags=cv2.INTER_NEAREST,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=(255))

            label_parsing = torch.from_numpy(label_parsing)

            return input, label_parsing, meta
130
+
131
+
132
class LIPDataValSet(data.Dataset):
    """LIP human-parsing evaluation dataset.

    Returns each image affine-cropped to ``crop_size``; when ``flip`` is True
    the item is a stacked (original, horizontally-flipped) pair so the caller
    can average flipped predictions.
    """

    def __init__(self, root, dataset='val', crop_size=[473, 473], transform=None, flip=False):
        """
        root: dataset root directory.
        dataset: split name (prefix of the id-list / image folders).
        crop_size: output (height, width) of the affine crop.
        transform: callable applied to the cropped image (must be set before use).
        flip: if True, also return the horizontally flipped crop.
        """
        # (The original assigned self.root and crop_size twice; kept once.)
        self.root = root
        self.crop_size = np.asarray(crop_size)
        self.transform = transform
        self.flip = flip
        self.dataset = dataset
        self.aspect_ratio = crop_size[1] * 1.0 / crop_size[0]

        list_path = os.path.join(self.root, self.dataset + '_id.txt')
        val_list = [i_id.strip() for i_id in open(list_path)]

        self.val_list = val_list
        self.number_samples = len(self.val_list)

    def __len__(self):
        return len(self.val_list)

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box to a (center, scale) pair."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Return the box center and its extent expanded to the crop aspect ratio."""
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)

        return center, scale

    def __getitem__(self, index):
        val_item = self.val_list[index]
        # Load the image and crop the whole frame at the target aspect ratio.
        im_path = os.path.join(self.root, self.dataset + '_images', val_item + '.jpg')
        im = cv2.imread(im_path, cv2.IMREAD_COLOR)
        h, w, _ = im.shape
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.crop_size)
        input = cv2.warpAffine(
            im,
            trans,
            (int(self.crop_size[1]), int(self.crop_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))
        input = self.transform(input)
        if self.flip:
            # Only build the flipped copy when it is actually requested
            # (the original computed it unconditionally).
            batch_input_im = torch.stack([input, input.flip(dims=[-1])])
        else:
            batch_input_im = input

        meta = {
            'name': val_item,
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        return batch_input_im, meta
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/simple_extractor_dataset.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : dataset.py
8
+ @Time : 8/30/19 9:12 PM
9
+ @Desc : Dataset Definition
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import os
15
+ import pdb
16
+
17
+ import cv2
18
+ import numpy as np
19
+ from PIL import Image
20
+ from torch.utils import data
21
+ from utils.transforms import get_affine_transform
22
+
23
+
24
class SimpleFolderDataset(data.Dataset):
    """Dataset over a folder of images, a single image file, or one in-memory
    PIL image, each affine-cropped to ``input_size`` for the parsing network."""

    def __init__(self, root, input_size=[512, 512], transform=None):
        self.root = root
        self.transform = transform
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.input_size = np.asarray(input_size)
        self.is_pil_image = False
        if isinstance(root, Image.Image):
            # A single PIL image held in memory.
            self.is_pil_image = True
            self.file_list = [root]
        elif os.path.isfile(root):
            # A single image file on disk.
            self.root = os.path.dirname(root)
            self.file_list = [os.path.basename(root)]
        else:
            # A directory of images.
            self.file_list = os.listdir(self.root)

    def __len__(self):
        return len(self.file_list)

    def _box2cs(self, box):
        """Convert an (x, y, w, h) box into a (center, scale) pair."""
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        """Center of the box plus its extent padded to the dataset aspect ratio."""
        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        return center, np.array([w, h], dtype=np.float32)

    def __getitem__(self, index):
        entry = self.file_list[index]
        if self.is_pil_image:
            # PIL gives RGB; reorder channels to BGR to match cv2.imread.
            img = np.asarray(entry)[:, :, [2, 1, 0]]
        else:
            img = cv2.imread(os.path.join(self.root, entry), cv2.IMREAD_COLOR)
        h, w, _ = img.shape

        # Crop the full frame at the target aspect ratio.
        person_center, box_scale = self._box2cs([0, 0, w - 1, h - 1])
        rotation = 0
        trans = get_affine_transform(person_center, box_scale, rotation, self.input_size)
        warped = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        warped = self.transform(warped)
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': box_scale,
            'rotation': rotation
        }

        return warped, meta
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/datasets/target_generation.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+
5
def generate_edge_tensor(label, edge_width=3):
    """Compute a binary edge map from a label map (CUDA only).

    A pixel is an edge when it differs from a vertical / horizontal / diagonal
    neighbour and neither pixel carries the ignore label 255. The edge map is
    then dilated with an ``edge_width`` x ``edge_width`` box kernel.

    label: (h, w) or (n, h, w) integer label map; cast to float32 on CUDA.
    Returns a float CUDA tensor with values in {0, 1}, squeezed of size-1 dims.
    """
    label = label.type(torch.cuda.FloatTensor)
    if len(label.shape) == 2:
        label = label.unsqueeze(0)
    n, h, w = label.shape
    edge = torch.zeros(label.shape, dtype=torch.float).cuda()
    # Vertical neighbours (writes through a view into `edge`).
    edge_right = edge[:, 1:h, :]
    edge_right[(label[:, 1:h, :] != label[:, :h - 1, :]) & (label[:, 1:h, :] != 255)
               & (label[:, :h - 1, :] != 255)] = 1

    # Horizontal neighbours.
    edge_up = edge[:, :, :w - 1]
    edge_up[(label[:, :, :w - 1] != label[:, :, 1:w])
            & (label[:, :, :w - 1] != 255)
            & (label[:, :, 1:w] != 255)] = 1

    # Diagonal (up-right).
    edge_upright = edge[:, :h - 1, :w - 1]
    edge_upright[(label[:, :h - 1, :w - 1] != label[:, 1:h, 1:w])
                 & (label[:, :h - 1, :w - 1] != 255)
                 & (label[:, 1:h, 1:w] != 255)] = 1

    # Diagonal (bottom-right).
    edge_bottomright = edge[:, :h - 1, 1:w]
    edge_bottomright[(label[:, :h - 1, 1:w] != label[:, 1:h, :w - 1])
                     & (label[:, :h - 1, 1:w] != 255)
                     & (label[:, 1:h, :w - 1] != 255)] = 1

    # Dilate with a box kernel. Bug fix: the original hard-coded padding=1,
    # which only preserves the spatial size for the default edge_width=3;
    # edge_width // 2 preserves it for any odd edge_width (identical for 3).
    kernel = torch.ones((1, 1, edge_width, edge_width), dtype=torch.float).cuda()
    with torch.no_grad():
        edge = edge.unsqueeze(1)
        edge = F.conv2d(edge, kernel, stride=1, padding=edge_width // 2)
        edge[edge != 0] = 1
        edge = edge.squeeze()
    return edge
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .bn import ABN, InPlaceABN, InPlaceABNSync
2
+ from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3
+ from .misc import GlobalAvgPool2d, SingleGPU
4
+ from .residual import IdentityResidualBlock
5
+ from .dense import DenseModule
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/bn.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ try:
6
+ from queue import Queue
7
+ except ImportError:
8
+ from Queue import Queue
9
+
10
+ from .functions import *
11
+
12
+
13
class ABN(nn.Module):
    """Batch normalization fused with a configurable activation.

    Plain (non in-place) reference module: standard ``functional.batch_norm``
    followed by the selected activation function.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """
        num_features: number of input/output channels.
        eps: numerical-stability constant added to the variance.
        momentum: running-statistics update factor.
        affine: if True, learn a per-channel scale (weight) and shift (bias).
        activation: one of 'relu', 'leaky_relu', 'elu' or 'none'.
        slope: negative slope used when activation == 'leaky_relu'.
        """
        super(ABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            # Registered as None so state_dict / named_parameters stay consistent.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics and, if affine, the scale/shift to identity."""
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        normed = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                                       self.training, self.momentum, self.eps)
        if self.activation == ACT_RELU:
            return functional.relu(normed, inplace=True)
        if self.activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(normed, negative_slope=self.slope, inplace=True)
        if self.activation == ACT_ELU:
            return functional.elu(normed, inplace=True)
        return normed

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        rep += ', slope={slope})' if self.activation == "leaky_relu" else ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
82
+
83
+
84
class InPlaceABN(ABN):
    """InPlace Activated Batch Normalization.

    Same interface and parameters as :class:`ABN`, but BN and activation are
    applied in place through the compiled ``inplace_abn`` autograd function,
    saving the activation buffer.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Create an InPlace Activated Batch Normalization module.

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)

    def forward(self, x):
        # inplace_abn also returns the (updated) running stats; only the
        # activated tensor is exposed to callers.
        x, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
                              self.training, self.momentum, self.eps, self.activation, self.slope)
        return x
111
+
112
+
113
class InPlaceABNSync(ABN):
    """InPlace Activated Batch Normalization with cross-GPU synchronization.

    This assumes it is replicated across processes using the same mechanism as
    ``nn.DistributedDataParallel``; batch statistics are synchronized inside
    the compiled ``inplace_abn_sync`` autograd function.
    """

    def forward(self, x):
        # inplace_abn_sync also returns the (updated) running stats; only the
        # activated tensor is exposed to callers.
        x, _, _ = inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
                                   self.training, self.momentum, self.eps, self.activation, self.slope)
        return x

    def __repr__(self):
        # NOTE(review): byte-identical to ABN.__repr__ (the class name is taken
        # from self.__class__, which the inherited version already handles) —
        # this override could be removed.
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
131
+
132
+
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/deeplab.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ from models._util import try_index
6
+ from .bn import ABN
7
+
8
+
9
class DeeplabV3(nn.Module):
    """DeepLabV3 ASPP head: parallel atrous convolutions plus an image-level
    pooling branch, fused into ``out_channels`` feature maps."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=256,
                 dilations=(12, 24, 36),
                 norm_act=ABN,
                 pooling_size=None):
        """
        in_channels: channels of the incoming backbone feature map.
        out_channels: channels of the fused output.
        hidden_channels: channels of each parallel ASPP branch.
        dilations: atrous rates of the three 3x3 branches.
        norm_act: factory for fused normalization+activation layers.
        pooling_size: if set (and in eval mode), use a sliding average pool of
            this size instead of global image pooling.
        """
        super(DeeplabV3, self).__init__()
        self.pooling_size = pooling_size

        # One 1x1 branch plus three dilated 3x3 branches, all -> hidden_channels.
        self.map_convs = nn.ModuleList([
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
        ])
        self.map_bn = norm_act(hidden_channels * 4)

        self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.global_pooling_bn = norm_act(hidden_channels)

        # Reduction convolutions projecting both paths to out_channels.
        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = norm_act(out_channels)

        self.reset_parameters(self.map_bn.activation, self.map_bn.slope)

    def reset_parameters(self, activation, slope):
        """Xavier-init all convolutions (gain matched to the activation) and
        reset every ABN layer to the identity transform."""
        gain = nn.init.calculate_gain(activation, slope)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ABN):
                if hasattr(m, "weight") and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Map convolutions: concatenate the four parallel branches.
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.red_conv(out)

        # Global pooling branch.
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.pool_red_conv(pool)
        if self.training or self.pooling_size is None:
            # Global pooling produced a 1x1 map; broadcast it spatially so it
            # can be added to the convolutional path.
            pool = pool.repeat(1, 1, x.size(2), x.size(3))

        out += pool
        out = self.red_bn(out)
        return out

    def _global_pooling(self, x):
        """Global average pool (training / no pooling_size), otherwise a
        stride-1 sliding average pool padded back to the input size."""
        if self.training or self.pooling_size is None:
            pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
            pool = pool.view(x.size(0), x.size(1), 1, 1)
        else:
            # Clamp the pooling window to the feature-map size.
            pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
                            min(try_index(self.pooling_size, 1), x.shape[3]))
            # Asymmetric padding restores the exact input size after the
            # stride-1 avg pool (one extra pixel on one side for even kernels).
            padding = (
                (pooling_size[1] - 1) // 2,
                (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
                (pooling_size[0] - 1) // 2,
                (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
            )

            pool = functional.avg_pool2d(x, pooling_size, stride=1)
            pool = functional.pad(pool, pad=padding, mode="replicate")
        return pool
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/dense.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .bn import ABN
7
+
8
+
9
class DenseModule(nn.Module):
    """DenseNet-style block: each layer sees the concatenation of the input
    and all previous layers' outputs and contributes ``growth`` channels."""

    def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
        super(DenseModule, self).__init__()
        self.in_channels = in_channels
        self.growth = growth
        self.layers = layers

        self.convs1 = nn.ModuleList()
        self.convs3 = nn.ModuleList()
        width = in_channels
        bottleneck = growth * bottleneck_factor
        for _ in range(layers):
            # 1x1 bottleneck reducing the growing concatenation.
            self.convs1.append(nn.Sequential(OrderedDict([
                ("bn", norm_act(width)),
                ("conv", nn.Conv2d(width, bottleneck, 1, bias=False))
            ])))
            # 3x3 (possibly dilated) conv producing `growth` new channels.
            self.convs3.append(nn.Sequential(OrderedDict([
                ("bn", norm_act(bottleneck)),
                ("conv", nn.Conv2d(bottleneck, growth, 3, padding=dilation, bias=False,
                                   dilation=dilation))
            ])))
            width += growth

    @property
    def out_channels(self):
        """Channel count of the concatenated output."""
        return self.in_channels + self.growth * self.layers

    def forward(self, x):
        feats = [x]
        for conv1, conv3 in zip(self.convs1, self.convs3):
            stacked = torch.cat(feats, dim=1)
            feats.append(conv3(conv1(stacked)))
        return torch.cat(feats, dim=1)
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/functions.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import pdb
from os import path
import torch
import torch.distributed as dist
import torch.autograd as autograd
import torch.cuda.comm as comm
from torch.autograd.function import once_differentiable
from torch.utils.cpp_extension import load

# JIT-compile the in-place ABN C++/CUDA backend from ./src on first import.
_src_path = path.join(path.dirname(path.abspath(__file__)), "src")
_backend = load(name="inplace_abn",
                extra_cflags=["-O3"],
                sources=[path.join(_src_path, f) for f in [
                    "inplace_abn.cpp",
                    "inplace_abn_cpu.cpp",
                    "inplace_abn_cuda.cu",
                    "inplace_abn_cuda_half.cu"
                ]],
                extra_cuda_cflags=["--expt-extended-lambda"])

# Activation names understood by ABN modules and the backend dispatchers.
ACT_RELU = "relu"
ACT_LEAKY_RELU = "leaky_relu"
ACT_ELU = "elu"
ACT_NONE = "none"
26
+
27
+
28
def _check(fn, *args, **kwargs):
    """Invoke *fn* and raise RuntimeError if it reports failure (falsy return)."""
    if not fn(*args, **kwargs):
        raise RuntimeError("CUDA Error encountered in {}".format(fn))
32
+
33
+
34
def _broadcast_shape(x):
    """Shape like *x*'s but with every non-channel dimension collapsed to 1,
    suitable for broadcasting per-channel statistics over the tensor."""
    return [s if i == 1 else 1 for i, s in enumerate(x.size())]
42
+
43
+
44
def _reduce(x):
    """Sum *x* over every dimension except the channel dimension (dim 1)."""
    shape = x.size()
    if len(shape) == 2:
        # (n, c): just reduce the batch dimension.
        return x.sum(dim=0)
    # (n, c, ...): flatten trailing dims, then reduce them and the batch.
    flat = x.contiguous().view((shape[0], shape[1], -1))
    return flat.sum(2).sum(0)
50
+
51
+
52
def _count_samples(x):
    """Number of normalization samples per channel: the product of all
    dimensions of *x* except the channel dimension (dim 1)."""
    count = 1
    for axis, size in enumerate(x.size()):
        if axis == 1:
            continue
        count *= size
    return count
58
+
59
+
60
def _act_forward(ctx, x):
    """Apply the activation configured on *ctx* to *x* in place via the
    compiled backend.

    NOTE(review): ACT_RELU has no branch here (only leaky_relu / elu / none),
    so relu falls through silently — confirm relu is never used with the
    in-place path.
    """
    if ctx.activation == ACT_LEAKY_RELU:
        _backend.leaky_relu_forward(x, ctx.slope)
    elif ctx.activation == ACT_ELU:
        _backend.elu_forward(x)
    elif ctx.activation == ACT_NONE:
        pass
67
+
68
+
69
def _act_backward(ctx, x, dx):
    """Backward pass of the in-place activation via the compiled backend,
    given the activated output *x* and the incoming gradient *dx*.

    NOTE(review): as in _act_forward, ACT_RELU has no branch here — confirm
    the relu path is handled elsewhere or unused.
    """
    if ctx.activation == ACT_LEAKY_RELU:
        _backend.leaky_relu_backward(x, dx, ctx.slope)
    elif ctx.activation == ACT_ELU:
        _backend.elu_backward(x, dx)
    elif ctx.activation == ACT_NONE:
        pass
76
+
77
+
78
class InPlaceABN(autograd.Function):
    """Autograd function implementing single-process in-place activated BN.

    The input is normalized, affine-transformed and activated *in place*; the
    activated output itself is saved for backward (instead of the input),
    which is what saves the extra activation buffer.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var,
                training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
        # Save context
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        count = _count_samples(x)
        x = x.contiguous()
        # The backend expects real tensors; pass empty ones when not affine.
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)

            # Update running stats (count / (count - 1) gives the unbiased
            # variance estimate).
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation, both applied to x in place.
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output: the activated tensor is saved for backward.
        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        ctx.mark_non_differentiable(running_mean, running_var)
        return x, running_mean, running_var

    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _drunning_mean, _drunning_var):
        # z is the (in-place) activated output saved in forward.
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
        else:
            # TODO: implement simplified CUDA backward for inference mode
            edz = dz.new_zeros(dz.size(1))
            eydz = dz.new_zeros(dz.size(1))

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
        # dweight = eydz * weight.sign() if ctx.affine else None
        dweight = eydz if ctx.affine else None
        if dweight is not None:
            # Equivalent to multiplying by weight.sign() (see comment above).
            dweight[weight < 0] *= -1
        dbias = edz if ctx.affine else None

        # One gradient slot per forward argument; non-tensor args get None.
        return dx, dweight, dbias, None, None, None, None, None, None, None
143
+
144
+
145
class InPlaceABNSync(autograd.Function):
    """Autograd function implementing in-place activated BN with batch
    statistics synchronized across distributed processes (all_reduce of
    mean/var in forward and of the gradient reductions in backward)."""

    @classmethod
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
        # Save context
        ctx.training = training
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.activation = activation
        ctx.slope = slope
        ctx.affine = weight is not None and bias is not None

        # Prepare inputs
        ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # count = _count_samples(x)
        batch_size = x.new_tensor([x.shape[0]], dtype=torch.long)

        x = x.contiguous()
        # The backend expects real tensors; pass empty ones when not affine.
        weight = weight.contiguous() if ctx.affine else x.new_empty(0)
        bias = bias.contiguous() if ctx.affine else x.new_empty(0)

        if ctx.training:
            mean, var = _backend.mean_var(x)
            if ctx.world_size > 1:
                # get global batch size
                if equal_batches:
                    batch_size *= ctx.world_size
                else:
                    dist.all_reduce(batch_size, dist.ReduceOp.SUM)

                # Fraction of the global batch owned by this process.
                ctx.factor = x.shape[0] / float(batch_size.item())

                # Global mean: weighted sum of the per-process means.
                mean_all = mean.clone() * ctx.factor
                dist.all_reduce(mean_all, dist.ReduceOp.SUM)

                # Global variance via the law of total variance.
                var_all = (var + (mean - mean_all) ** 2) * ctx.factor
                dist.all_reduce(var_all, dist.ReduceOp.SUM)

                mean = mean_all
                var = var_all

            # Update running stats (count / (count - 1) gives the unbiased
            # variance estimate).
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            count = batch_size.item() * x.view(x.shape[0], x.shape[1], -1).shape[-1]
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))

            # Mark in-place modified tensors
            ctx.mark_dirty(x, running_mean, running_var)
        else:
            mean, var = running_mean.contiguous(), running_var.contiguous()
            ctx.mark_dirty(x)

        # BN forward + activation, both applied to x in place.
        _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
        _act_forward(ctx, x)

        # Output: the activated tensor is saved for backward.
        ctx.var = var
        ctx.save_for_backward(x, var, weight, bias)
        ctx.mark_non_differentiable(running_mean, running_var)
        return x, running_mean, running_var

    @staticmethod
    @once_differentiable
    def backward(ctx, dz, _drunning_mean, _drunning_var):
        # z is the activated output saved in forward.
        z, var, weight, bias = ctx.saved_tensors
        dz = dz.contiguous()

        # Undo activation
        _act_backward(ctx, z, dz)

        if ctx.training:
            edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
            # Keep process-local copies: parameter grads use the local sums,
            # while dx needs the globally reduced ones.
            edz_local = edz.clone()
            eydz_local = eydz.clone()

            if ctx.world_size > 1:
                edz *= ctx.factor
                dist.all_reduce(edz, dist.ReduceOp.SUM)

                eydz *= ctx.factor
                dist.all_reduce(eydz, dist.ReduceOp.SUM)
        else:
            edz_local = edz = dz.new_zeros(dz.size(1))
            eydz_local = eydz = dz.new_zeros(dz.size(1))

        dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
        # dweight = eydz_local * weight.sign() if ctx.affine else None
        dweight = eydz_local if ctx.affine else None
        if dweight is not None:
            # Equivalent to multiplying by weight.sign() (see comment above).
            dweight[weight < 0] *= -1
        dbias = edz_local if ctx.affine else None

        # NOTE(review): forward takes 11 inputs (including equal_batches) but
        # only 10 gradients are returned here — confirm autograd accepts this.
        return dx, dweight, dbias, None, None, None, None, None, None, None
240
+
241
+
242
# Functional entry points used by the nn.Module wrappers in bn.py.
inplace_abn = InPlaceABN.apply
inplace_abn_sync = InPlaceABNSync.apply

__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/misc.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
class GlobalAvgPool2d(nn.Module):
    """Average all spatial positions away, mapping (N, C, H, W) -> (N, C)."""

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        n, c = inputs.size(0), inputs.size(1)
        return inputs.view(n, c, -1).mean(dim=2)
13
+
14
class SingleGPU(nn.Module):
    """Wrapper that moves inputs onto the default CUDA device before
    forwarding them to the wrapped module."""

    def __init__(self, module):
        super(SingleGPU, self).__init__()
        self.module = module

    def forward(self, input):
        # non_blocking=True lets the host-to-device copy overlap with compute
        # when the input tensor lives in pinned memory.
        return self.module(input.cuda(non_blocking=True))
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/residual.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch.nn as nn
4
+
5
+ from .bn import ABN, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
6
+ import torch.nn.functional as functional
7
+
8
+
9
+ class ResidualBlock(nn.Module):
10
+ """Configurable residual block
11
+
12
+ Parameters
13
+ ----------
14
+ in_channels : int
15
+ Number of input channels.
16
+ channels : list of int
17
+ Number of channels in the internal feature maps. Can either have two or three elements: if three construct
18
+ a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
19
+ `3 x 3` then `1 x 1` convolutions.
20
+ stride : int
21
+ Stride of the first `3 x 3` convolution
22
+ dilation : int
23
+ Dilation to apply to the `3 x 3` convolutions.
24
+ groups : int
25
+ Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
26
+ bottleneck blocks.
27
+ norm_act : callable
28
+ Function to create normalization / activation Module.
29
+ dropout: callable
30
+ Function to create Dropout Module.
31
+ """
32
+
33
+ def __init__(self,
34
+ in_channels,
35
+ channels,
36
+ stride=1,
37
+ dilation=1,
38
+ groups=1,
39
+ norm_act=ABN,
40
+ dropout=None):
41
+ super(ResidualBlock, self).__init__()
42
+
43
+ # Check parameters for inconsistencies
44
+ if len(channels) != 2 and len(channels) != 3:
45
+ raise ValueError("channels must contain either two or three values")
46
+ if len(channels) == 2 and groups != 1:
47
+ raise ValueError("groups > 1 are only valid if len(channels) == 3")
48
+
49
+ is_bottleneck = len(channels) == 3
50
+ need_proj_conv = stride != 1 or in_channels != channels[-1]
51
+
52
+ if not is_bottleneck:
53
+ bn2 = norm_act(channels[1])
54
+ bn2.activation = ACT_NONE
55
+ layers = [
56
+ ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
57
+ dilation=dilation)),
58
+ ("bn1", norm_act(channels[0])),
59
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
60
+ dilation=dilation)),
61
+ ("bn2", bn2)
62
+ ]
63
+ if dropout is not None:
64
+ layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
65
+ else:
66
+ bn3 = norm_act(channels[2])
67
+ bn3.activation = ACT_NONE
68
+ layers = [
69
+ ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)),
70
+ ("bn1", norm_act(channels[0])),
71
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=stride, padding=dilation, bias=False,
72
+ groups=groups, dilation=dilation)),
73
+ ("bn2", norm_act(channels[1])),
74
+ ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)),
75
+ ("bn3", bn3)
76
+ ]
77
+ if dropout is not None:
78
+ layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
79
+ self.convs = nn.Sequential(OrderedDict(layers))
80
+
81
+ if need_proj_conv:
82
+ self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
83
+ self.proj_bn = norm_act(channels[-1])
84
+ self.proj_bn.activation = ACT_NONE
85
+
86
+ def forward(self, x):
87
+ if hasattr(self, "proj_conv"):
88
+ residual = self.proj_conv(x)
89
+ residual = self.proj_bn(residual)
90
+ else:
91
+ residual = x
92
+ x = self.convs(x) + residual
93
+
94
+ if self.convs.bn1.activation == ACT_LEAKY_RELU:
95
+ return functional.leaky_relu(x, negative_slope=self.convs.bn1.slope, inplace=True)
96
+ elif self.convs.bn1.activation == ACT_ELU:
97
+ return functional.elu(x, inplace=True)
98
+ else:
99
+ return x
100
+
101
+
102
+ class IdentityResidualBlock(nn.Module):
103
+ def __init__(self,
104
+ in_channels,
105
+ channels,
106
+ stride=1,
107
+ dilation=1,
108
+ groups=1,
109
+ norm_act=ABN,
110
+ dropout=None):
111
+ """Configurable identity-mapping residual block
112
+
113
+ Parameters
114
+ ----------
115
+ in_channels : int
116
+ Number of input channels.
117
+ channels : list of int
118
+ Number of channels in the internal feature maps. Can either have two or three elements: if three construct
119
+ a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
120
+ `3 x 3` then `1 x 1` convolutions.
121
+ stride : int
122
+ Stride of the first `3 x 3` convolution
123
+ dilation : int
124
+ Dilation to apply to the `3 x 3` convolutions.
125
+ groups : int
126
+ Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
127
+ bottleneck blocks.
128
+ norm_act : callable
129
+ Function to create normalization / activation Module.
130
+ dropout: callable
131
+ Function to create Dropout Module.
132
+ """
133
+ super(IdentityResidualBlock, self).__init__()
134
+
135
+ # Check parameters for inconsistencies
136
+ if len(channels) != 2 and len(channels) != 3:
137
+ raise ValueError("channels must contain either two or three values")
138
+ if len(channels) == 2 and groups != 1:
139
+ raise ValueError("groups > 1 are only valid if len(channels) == 3")
140
+
141
+ is_bottleneck = len(channels) == 3
142
+ need_proj_conv = stride != 1 or in_channels != channels[-1]
143
+
144
+ self.bn1 = norm_act(in_channels)
145
+ if not is_bottleneck:
146
+ layers = [
147
+ ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
148
+ dilation=dilation)),
149
+ ("bn2", norm_act(channels[0])),
150
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
151
+ dilation=dilation))
152
+ ]
153
+ if dropout is not None:
154
+ layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
155
+ else:
156
+ layers = [
157
+ ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
158
+ ("bn2", norm_act(channels[0])),
159
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
160
+ groups=groups, dilation=dilation)),
161
+ ("bn3", norm_act(channels[1])),
162
+ ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
163
+ ]
164
+ if dropout is not None:
165
+ layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
166
+ self.convs = nn.Sequential(OrderedDict(layers))
167
+
168
+ if need_proj_conv:
169
+ self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
170
+
171
+ def forward(self, x):
172
+ if hasattr(self, "proj_conv"):
173
+ bn1 = self.bn1(x)
174
+ shortcut = self.proj_conv(bn1)
175
+ else:
176
+ shortcut = x.clone()
177
+ bn1 = self.bn1(x)
178
+
179
+ out = self.convs(bn1)
180
+ out.add_(shortcut)
181
+
182
+ return out
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/checks.h ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6
+ #ifndef AT_CHECK
7
+ #define AT_CHECK AT_ASSERT
8
+ #endif
9
+
10
+ #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12
+ #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13
+
14
+ #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15
+ #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn.cpp ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ #include <vector>
4
+
5
+ #include "inplace_abn.h"
6
+
7
+ std::vector<at::Tensor> mean_var(at::Tensor x) {
8
+ if (x.is_cuda()) {
9
+ if (x.type().scalarType() == at::ScalarType::Half) {
10
+ return mean_var_cuda_h(x);
11
+ } else {
12
+ return mean_var_cuda(x);
13
+ }
14
+ } else {
15
+ return mean_var_cpu(x);
16
+ }
17
+ }
18
+
19
+ at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
20
+ bool affine, float eps) {
21
+ if (x.is_cuda()) {
22
+ if (x.type().scalarType() == at::ScalarType::Half) {
23
+ return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
24
+ } else {
25
+ return forward_cuda(x, mean, var, weight, bias, affine, eps);
26
+ }
27
+ } else {
28
+ return forward_cpu(x, mean, var, weight, bias, affine, eps);
29
+ }
30
+ }
31
+
32
+ std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
33
+ bool affine, float eps) {
34
+ if (z.is_cuda()) {
35
+ if (z.type().scalarType() == at::ScalarType::Half) {
36
+ return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
37
+ } else {
38
+ return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
39
+ }
40
+ } else {
41
+ return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
42
+ }
43
+ }
44
+
45
+ at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
46
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
47
+ if (z.is_cuda()) {
48
+ if (z.type().scalarType() == at::ScalarType::Half) {
49
+ return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
50
+ } else {
51
+ return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
52
+ }
53
+ } else {
54
+ return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
55
+ }
56
+ }
57
+
58
+ void leaky_relu_forward(at::Tensor z, float slope) {
59
+ at::leaky_relu_(z, slope);
60
+ }
61
+
62
+ void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
63
+ if (z.is_cuda()) {
64
+ if (z.type().scalarType() == at::ScalarType::Half) {
65
+ return leaky_relu_backward_cuda_h(z, dz, slope);
66
+ } else {
67
+ return leaky_relu_backward_cuda(z, dz, slope);
68
+ }
69
+ } else {
70
+ return leaky_relu_backward_cpu(z, dz, slope);
71
+ }
72
+ }
73
+
74
+ void elu_forward(at::Tensor z) {
75
+ at::elu_(z);
76
+ }
77
+
78
+ void elu_backward(at::Tensor z, at::Tensor dz) {
79
+ if (z.is_cuda()) {
80
+ return elu_backward_cuda(z, dz);
81
+ } else {
82
+ return elu_backward_cpu(z, dz);
83
+ }
84
+ }
85
+
86
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87
+ m.def("mean_var", &mean_var, "Mean and variance computation");
88
+ m.def("forward", &forward, "In-place forward computation");
89
+ m.def("edz_eydz", &edz_eydz, "First part of backward computation");
90
+ m.def("backward", &backward, "Second part of backward computation");
91
+ m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
92
+ m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
93
+ m.def("elu_forward", &elu_forward, "Elu forward computation");
94
+ m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
95
+ }
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ #include <vector>
6
+
7
+ std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8
+ std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9
+ std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
10
+
11
+ at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
12
+ bool affine, float eps);
13
+ at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
14
+ bool affine, float eps);
15
+ at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16
+ bool affine, float eps);
17
+
18
+ std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
19
+ bool affine, float eps);
20
+ std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
21
+ bool affine, float eps);
22
+ std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
23
+ bool affine, float eps);
24
+
25
+ at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
26
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
27
+ at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
28
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
29
+ at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
30
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
31
+
32
+ void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
33
+ void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
34
+ void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
35
+
36
+ void elu_backward_cpu(at::Tensor z, at::Tensor dz);
37
+ void elu_backward_cuda(at::Tensor z, at::Tensor dz);
38
+
39
+ static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
40
+ num = x.size(0);
41
+ chn = x.size(1);
42
+ sp = 1;
43
+ for (int64_t i = 2; i < x.ndimension(); ++i)
44
+ sp *= x.size(i);
45
+ }
46
+
47
+ /*
48
+ * Specialized CUDA reduction functions for BN
49
+ */
50
+ #ifdef __CUDACC__
51
+
52
+ #include "utils/cuda.cuh"
53
+
54
+ template <typename T, typename Op>
55
+ __device__ T reduce(Op op, int plane, int N, int S) {
56
+ T sum = (T)0;
57
+ for (int batch = 0; batch < N; ++batch) {
58
+ for (int x = threadIdx.x; x < S; x += blockDim.x) {
59
+ sum += op(batch, plane, x);
60
+ }
61
+ }
62
+
63
+ // sum over NumThreads within a warp
64
+ sum = warpSum(sum);
65
+
66
+ // 'transpose', and reduce within warp again
67
+ __shared__ T shared[32];
68
+ __syncthreads();
69
+ if (threadIdx.x % WARP_SIZE == 0) {
70
+ shared[threadIdx.x / WARP_SIZE] = sum;
71
+ }
72
+ if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
73
+ // zero out the other entries in shared
74
+ shared[threadIdx.x] = (T)0;
75
+ }
76
+ __syncthreads();
77
+ if (threadIdx.x / WARP_SIZE == 0) {
78
+ sum = warpSum(shared[threadIdx.x]);
79
+ if (threadIdx.x == 0) {
80
+ shared[0] = sum;
81
+ }
82
+ }
83
+ __syncthreads();
84
+
85
+ // Everyone picks it up, should be broadcast into the whole gradInput
86
+ return shared[0];
87
+ }
88
+ #endif
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn_cpu.cpp ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <vector>
4
+
5
+ #include "utils/checks.h"
6
+ #include "inplace_abn.h"
7
+
8
+ at::Tensor reduce_sum(at::Tensor x) {
9
+ if (x.ndimension() == 2) {
10
+ return x.sum(0);
11
+ } else {
12
+ auto x_view = x.view({x.size(0), x.size(1), -1});
13
+ return x_view.sum(-1).sum(0);
14
+ }
15
+ }
16
+
17
+ at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
18
+ if (x.ndimension() == 2) {
19
+ return v;
20
+ } else {
21
+ std::vector<int64_t> broadcast_size = {1, -1};
22
+ for (int64_t i = 2; i < x.ndimension(); ++i)
23
+ broadcast_size.push_back(1);
24
+
25
+ return v.view(broadcast_size);
26
+ }
27
+ }
28
+
29
+ int64_t count(at::Tensor x) {
30
+ int64_t count = x.size(0);
31
+ for (int64_t i = 2; i < x.ndimension(); ++i)
32
+ count *= x.size(i);
33
+
34
+ return count;
35
+ }
36
+
37
+ at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
38
+ if (affine) {
39
+ return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
40
+ } else {
41
+ return z;
42
+ }
43
+ }
44
+
45
+ std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
46
+ auto num = count(x);
47
+ auto mean = reduce_sum(x) / num;
48
+ auto diff = x - broadcast_to(mean, x);
49
+ auto var = reduce_sum(diff.pow(2)) / num;
50
+
51
+ return {mean, var};
52
+ }
53
+
54
+ at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
55
+ bool affine, float eps) {
56
+ auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
57
+ auto mul = at::rsqrt(var + eps) * gamma;
58
+
59
+ x.sub_(broadcast_to(mean, x));
60
+ x.mul_(broadcast_to(mul, x));
61
+ if (affine) x.add_(broadcast_to(bias, x));
62
+
63
+ return x;
64
+ }
65
+
66
+ std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
67
+ bool affine, float eps) {
68
+ auto edz = reduce_sum(dz);
69
+ auto y = invert_affine(z, weight, bias, affine, eps);
70
+ auto eydz = reduce_sum(y * dz);
71
+
72
+ return {edz, eydz};
73
+ }
74
+
75
+ at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
76
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
77
+ auto y = invert_affine(z, weight, bias, affine, eps);
78
+ auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
79
+
80
+ auto num = count(z);
81
+ auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
82
+ return dx;
83
+ }
84
+
85
+ void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
86
+ CHECK_CPU_INPUT(z);
87
+ CHECK_CPU_INPUT(dz);
88
+
89
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
90
+ int64_t count = z.numel();
91
+ auto *_z = z.data<scalar_t>();
92
+ auto *_dz = dz.data<scalar_t>();
93
+
94
+ for (int64_t i = 0; i < count; ++i) {
95
+ if (_z[i] < 0) {
96
+ _z[i] *= 1 / slope;
97
+ _dz[i] *= slope;
98
+ }
99
+ }
100
+ }));
101
+ }
102
+
103
+ void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
104
+ CHECK_CPU_INPUT(z);
105
+ CHECK_CPU_INPUT(dz);
106
+
107
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
108
+ int64_t count = z.numel();
109
+ auto *_z = z.data<scalar_t>();
110
+ auto *_dz = dz.data<scalar_t>();
111
+
112
+ for (int64_t i = 0; i < count; ++i) {
113
+ if (_z[i] < 0) {
114
+ _z[i] = log1p(_z[i]);
115
+ _dz[i] *= (_z[i] + 1.f);
116
+ }
117
+ }
118
+ }));
119
+ }
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn_cuda.cu ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <thrust/device_ptr.h>
4
+ #include <thrust/transform.h>
5
+
6
+ #include <vector>
7
+
8
+ #include "utils/checks.h"
9
+ #include "utils/cuda.cuh"
10
+ #include "inplace_abn.h"
11
+
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ // Operations for reduce
15
+ template<typename T>
16
+ struct SumOp {
17
+ __device__ SumOp(const T *t, int c, int s)
18
+ : tensor(t), chn(c), sp(s) {}
19
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
20
+ return tensor[(batch * chn + plane) * sp + n];
21
+ }
22
+ const T *tensor;
23
+ const int chn;
24
+ const int sp;
25
+ };
26
+
27
+ template<typename T>
28
+ struct VarOp {
29
+ __device__ VarOp(T m, const T *t, int c, int s)
30
+ : mean(m), tensor(t), chn(c), sp(s) {}
31
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
32
+ T val = tensor[(batch * chn + plane) * sp + n];
33
+ return (val - mean) * (val - mean);
34
+ }
35
+ const T mean;
36
+ const T *tensor;
37
+ const int chn;
38
+ const int sp;
39
+ };
40
+
41
+ template<typename T>
42
+ struct GradOp {
43
+ __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
44
+ : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
45
+ __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
46
+ T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
47
+ T _dz = dz[(batch * chn + plane) * sp + n];
48
+ return Pair<T>(_dz, _y * _dz);
49
+ }
50
+ const T weight;
51
+ const T bias;
52
+ const T *z;
53
+ const T *dz;
54
+ const int chn;
55
+ const int sp;
56
+ };
57
+
58
+ /***********
59
+ * mean_var
60
+ ***********/
61
+
62
+ template<typename T>
63
+ __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
64
+ int plane = blockIdx.x;
65
+ T norm = T(1) / T(num * sp);
66
+
67
+ T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
68
+ __syncthreads();
69
+ T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
70
+
71
+ if (threadIdx.x == 0) {
72
+ mean[plane] = _mean;
73
+ var[plane] = _var;
74
+ }
75
+ }
76
+
77
+ std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
78
+ CHECK_CUDA_INPUT(x);
79
+
80
+ // Extract dimensions
81
+ int64_t num, chn, sp;
82
+ get_dims(x, num, chn, sp);
83
+
84
+ // Prepare output tensors
85
+ auto mean = at::empty({chn}, x.options());
86
+ auto var = at::empty({chn}, x.options());
87
+
88
+ // Run kernel
89
+ dim3 blocks(chn);
90
+ dim3 threads(getNumThreads(sp));
91
+ auto stream = at::cuda::getCurrentCUDAStream();
92
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
93
+ mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
94
+ x.data<scalar_t>(),
95
+ mean.data<scalar_t>(),
96
+ var.data<scalar_t>(),
97
+ num, chn, sp);
98
+ }));
99
+
100
+ return {mean, var};
101
+ }
102
+
103
+ /**********
104
+ * forward
105
+ **********/
106
+
107
+ template<typename T>
108
+ __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
109
+ bool affine, float eps, int num, int chn, int sp) {
110
+ int plane = blockIdx.x;
111
+
112
+ T _mean = mean[plane];
113
+ T _var = var[plane];
114
+ T _weight = affine ? abs(weight[plane]) + eps : T(1);
115
+ T _bias = affine ? bias[plane] : T(0);
116
+
117
+ T mul = rsqrt(_var + eps) * _weight;
118
+
119
+ for (int batch = 0; batch < num; ++batch) {
120
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
121
+ T _x = x[(batch * chn + plane) * sp + n];
122
+ T _y = (_x - _mean) * mul + _bias;
123
+
124
+ x[(batch * chn + plane) * sp + n] = _y;
125
+ }
126
+ }
127
+ }
128
+
129
+ at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
130
+ bool affine, float eps) {
131
+ CHECK_CUDA_INPUT(x);
132
+ CHECK_CUDA_INPUT(mean);
133
+ CHECK_CUDA_INPUT(var);
134
+ CHECK_CUDA_INPUT(weight);
135
+ CHECK_CUDA_INPUT(bias);
136
+
137
+ // Extract dimensions
138
+ int64_t num, chn, sp;
139
+ get_dims(x, num, chn, sp);
140
+
141
+ // Run kernel
142
+ dim3 blocks(chn);
143
+ dim3 threads(getNumThreads(sp));
144
+ auto stream = at::cuda::getCurrentCUDAStream();
145
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
146
+ forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
147
+ x.data<scalar_t>(),
148
+ mean.data<scalar_t>(),
149
+ var.data<scalar_t>(),
150
+ weight.data<scalar_t>(),
151
+ bias.data<scalar_t>(),
152
+ affine, eps, num, chn, sp);
153
+ }));
154
+
155
+ return x;
156
+ }
157
+
158
+ /***********
159
+ * edz_eydz
160
+ ***********/
161
+
162
+ template<typename T>
163
+ __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
164
+ T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
165
+ int plane = blockIdx.x;
166
+
167
+ T _weight = affine ? abs(weight[plane]) + eps : 1.f;
168
+ T _bias = affine ? bias[plane] : 0.f;
169
+
170
+ Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
171
+ __syncthreads();
172
+
173
+ if (threadIdx.x == 0) {
174
+ edz[plane] = res.v1;
175
+ eydz[plane] = res.v2;
176
+ }
177
+ }
178
+
179
+ std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
180
+ bool affine, float eps) {
181
+ CHECK_CUDA_INPUT(z);
182
+ CHECK_CUDA_INPUT(dz);
183
+ CHECK_CUDA_INPUT(weight);
184
+ CHECK_CUDA_INPUT(bias);
185
+
186
+ // Extract dimensions
187
+ int64_t num, chn, sp;
188
+ get_dims(z, num, chn, sp);
189
+
190
+ auto edz = at::empty({chn}, z.options());
191
+ auto eydz = at::empty({chn}, z.options());
192
+
193
+ // Run kernel
194
+ dim3 blocks(chn);
195
+ dim3 threads(getNumThreads(sp));
196
+ auto stream = at::cuda::getCurrentCUDAStream();
197
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
198
+ edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
199
+ z.data<scalar_t>(),
200
+ dz.data<scalar_t>(),
201
+ weight.data<scalar_t>(),
202
+ bias.data<scalar_t>(),
203
+ edz.data<scalar_t>(),
204
+ eydz.data<scalar_t>(),
205
+ affine, eps, num, chn, sp);
206
+ }));
207
+
208
+ return {edz, eydz};
209
+ }
210
+
211
+ /***********
212
+ * backward
213
+ ***********/
214
+
215
+ template<typename T>
216
+ __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
217
+ const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
218
+ int plane = blockIdx.x;
219
+
220
+ T _weight = affine ? abs(weight[plane]) + eps : 1.f;
221
+ T _bias = affine ? bias[plane] : 0.f;
222
+ T _var = var[plane];
223
+ T _edz = edz[plane];
224
+ T _eydz = eydz[plane];
225
+
226
+ T _mul = _weight * rsqrt(_var + eps);
227
+ T count = T(num * sp);
228
+
229
+ for (int batch = 0; batch < num; ++batch) {
230
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
231
+ T _dz = dz[(batch * chn + plane) * sp + n];
232
+ T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
233
+
234
+ dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
235
+ }
236
+ }
237
+ }
238
+
239
+ at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
240
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
241
+ CHECK_CUDA_INPUT(z);
242
+ CHECK_CUDA_INPUT(dz);
243
+ CHECK_CUDA_INPUT(var);
244
+ CHECK_CUDA_INPUT(weight);
245
+ CHECK_CUDA_INPUT(bias);
246
+ CHECK_CUDA_INPUT(edz);
247
+ CHECK_CUDA_INPUT(eydz);
248
+
249
+ // Extract dimensions
250
+ int64_t num, chn, sp;
251
+ get_dims(z, num, chn, sp);
252
+
253
+ auto dx = at::zeros_like(z);
254
+
255
+ // Run kernel
256
+ dim3 blocks(chn);
257
+ dim3 threads(getNumThreads(sp));
258
+ auto stream = at::cuda::getCurrentCUDAStream();
259
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
260
+ backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
261
+ z.data<scalar_t>(),
262
+ dz.data<scalar_t>(),
263
+ var.data<scalar_t>(),
264
+ weight.data<scalar_t>(),
265
+ bias.data<scalar_t>(),
266
+ edz.data<scalar_t>(),
267
+ eydz.data<scalar_t>(),
268
+ dx.data<scalar_t>(),
269
+ affine, eps, num, chn, sp);
270
+ }));
271
+
272
+ return dx;
273
+ }
274
+
275
+ /**************
276
+ * activations
277
+ **************/
278
+
279
+ template<typename T>
280
+ inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
281
+ // Create thrust pointers
282
+ thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
283
+ thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
284
+
285
+ auto stream = at::cuda::getCurrentCUDAStream();
286
+ thrust::transform_if(thrust::cuda::par.on(stream),
287
+ th_dz, th_dz + count, th_z, th_dz,
288
+ [slope] __device__ (const T& dz) { return dz * slope; },
289
+ [] __device__ (const T& z) { return z < 0; });
290
+ thrust::transform_if(thrust::cuda::par.on(stream),
291
+ th_z, th_z + count, th_z,
292
+ [slope] __device__ (const T& z) { return z / slope; },
293
+ [] __device__ (const T& z) { return z < 0; });
294
+ }
295
+
296
+ void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
297
+ CHECK_CUDA_INPUT(z);
298
+ CHECK_CUDA_INPUT(dz);
299
+
300
+ int64_t count = z.numel();
301
+
302
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
303
+ leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
304
+ }));
305
+ }
306
+
307
+ template<typename T>
308
+ inline void elu_backward_impl(T *z, T *dz, int64_t count) {
309
+ // Create thrust pointers
310
+ thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
311
+ thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
312
+
313
+ auto stream = at::cuda::getCurrentCUDAStream();
314
+ thrust::transform_if(thrust::cuda::par.on(stream),
315
+ th_dz, th_dz + count, th_z, th_z, th_dz,
316
+ [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
317
+ [] __device__ (const T& z) { return z < 0; });
318
+ thrust::transform_if(thrust::cuda::par.on(stream),
319
+ th_z, th_z + count, th_z,
320
+ [] __device__ (const T& z) { return log1p(z); },
321
+ [] __device__ (const T& z) { return z < 0; });
322
+ }
323
+
324
+ void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
325
+ CHECK_CUDA_INPUT(z);
326
+ CHECK_CUDA_INPUT(dz);
327
+
328
+ int64_t count = z.numel();
329
+
330
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
331
+ elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
332
+ }));
333
+ }
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/inplace_abn_cuda_half.cu ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <cuda_fp16.h>
4
+
5
+ #include <vector>
6
+
7
+ #include "utils/checks.h"
8
+ #include "utils/cuda.cuh"
9
+ #include "inplace_abn.h"
10
+
11
+ #include <ATen/cuda/CUDAContext.h>
12
+
13
+ // Operations for reduce
14
// Element fetch functor for the sum reduction: returns one fp16 element,
// widened to fp32, addressed as (batch, channel plane, spatial offset).
struct SumOpH {
  __device__ SumOpH(const half *t, int c, int s)
    : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    return __half2float(tensor[(batch * chn + plane) * sp + n]);
  }
  const half *tensor;  // fp16 input laid out contiguously as (num, chn, sp)
  const int chn;       // number of channels
  const int sp;        // spatial size per channel
};
24
+
25
// Element functor for the variance reduction: squared deviation of one
// fp16 element from a precomputed per-channel mean, evaluated in fp32.
struct VarOpH {
  __device__ VarOpH(float m, const half *t, int c, int s)
    : mean(m), tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
    return (t - mean) * (t - mean);
  }
  const float mean;    // per-channel mean from the preceding sum pass
  const half *tensor;  // fp16 input laid out as (num, chn, sp)
  const int chn;
  const int sp;
};
37
+
38
// Element functor for the gradient reductions: produces the pair
// (dz, y * dz) where y is the normalized activation reconstructed from the
// in-place forward output via y = (z - bias) / weight.
struct GradOpH {
  __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
    : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
    float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
    float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
    return Pair<float>(_dz, _y * _dz);
  }
  const float weight;  // effective scale used by the forward (|gamma| + eps when affine)
  const float bias;    // effective shift (beta when affine, else 0)
  const half *z;       // saved forward output (fp16)
  const half *dz;      // incoming gradient (fp16)
  const int chn;
  const int sp;
};
53
+
54
+ /***********
55
+ * mean_var
56
+ ***********/
57
+
58
// Per-channel mean and (biased) variance of a half tensor, accumulated in
// fp32. One block handles one channel plane; two reduce passes share the
// block, separated by a barrier so every thread sees the finished mean.
__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
  int plane = blockIdx.x;                       // channel owned by this block
  float norm = 1.f / static_cast<float>(num * sp);

  float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
  __syncthreads();                              // mean must be final before the variance pass
  float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;

  if (threadIdx.x == 0) {                       // single writer per channel
    mean[plane] = _mean;
    var[plane] = _var;
  }
}
71
+
72
// Host wrapper: per-channel mean/variance of a half tensor, returned as two
// fp32 tensors of shape {chn}. Launches one block per channel.
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
  CHECK_CUDA_INPUT(x);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors (statistics kept in fp32 for accuracy)
  auto mean = at::empty({chn}, x.options().dtype(at::kFloat));
  auto var = at::empty({chn}, x.options().dtype(at::kFloat));

  // Run kernel
  // NOTE(review): num/chn/sp are int64_t but the kernel takes int — this
  // silently narrows; confirm per-dimension sizes stay below 2^31.
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(x.data<at::Half>()),
      mean.data<float>(),
      var.data<float>(),
      num, chn, sp);

  return {mean, var};
}
95
+
96
+ /**********
97
+ * forward
98
+ **********/
99
+
100
// In-place batch-norm forward (optionally affine) for half tensors:
//   x <- (x - mean) * (|weight| + eps) * rsqrt(var + eps) + bias
// One block per channel; arithmetic in fp32, storage in fp16.
__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
                                 bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  const float _mean = mean[plane];
  const float _var = var[plane];
  // abs(...)+eps keeps the effective scale strictly positive so the backward
  // can divide by it to reconstruct y from the overwritten buffer.
  const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  const float _bias = affine ? bias[plane] : 0.f;

  const float mul = rsqrt(_var + eps) * _weight;

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      half *x_ptr = x + (batch * chn + plane) * sp + n;
      float _x = __half2float(*x_ptr);
      float _y = (_x - _mean) * mul + _bias;

      *x_ptr = __float2half(_y);   // overwrite the input with the normalized output
    }
  }
}
121
+
122
// Host wrapper for the fused in-place BN forward on half tensors. The
// statistics and affine parameters are fp32; x is normalized in place and
// also returned for convenience.
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
                          bool affine, float eps) {
  CHECK_CUDA_INPUT(x);
  CHECK_CUDA_INPUT(mean);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Run kernel: one block per channel, block size fitted to the spatial size.
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  forward_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(x.data<at::Half>()),
      mean.data<float>(),
      var.data<float>(),
      weight.data<float>(),
      bias.data<float>(),
      affine, eps, num, chn, sp);

  return x;
}
148
+
149
// First backward pass: per-channel reduction of sum(dz) and sum(y * dz),
// with y rebuilt from the in-place forward output. One block per channel.
__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
                                  float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  // Must mirror the forward's effective scale/shift exactly to invert it.
  float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  float _bias = affine ? bias[plane] : 0.f;

  Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
  __syncthreads();

  if (threadIdx.x == 0) {   // single writer per channel
    edz[plane] = res.v1;    // sum of dz over the channel
    eydz[plane] = res.v2;   // sum of y * dz over the channel
  }
}
164
+
165
// Host wrapper computing the per-channel gradient statistics edz = sum(dz)
// and eydz = sum(y * dz) needed by backward_cuda_h, returned as fp32
// tensors of shape {chn}.
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
                                        bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  // Per-channel accumulators, kept in fp32.
  auto edz = at::empty({chn}, z.options().dtype(at::kFloat));
  auto eydz = at::empty({chn}, z.options().dtype(at::kFloat));

  // Run kernel: one block per channel.
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      weight.data<float>(),
      bias.data<float>(),
      edz.data<float>(),
      eydz.data<float>(),
      affine, eps, num, chn, sp);

  return {edz, eydz};
}
194
+
195
// Second backward pass: input gradient
//   dx = (dz - edz/count - y * eydz/count) * (|weight| + eps) * rsqrt(var + eps)
// with y reconstructed from the stored output z. One block per channel.
__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
                                  const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
  int plane = blockIdx.x;

  // Same effective scale/shift as the forward, so y = (z - bias) / weight.
  float _weight = affine ? abs(weight[plane]) + eps : 1.f;
  float _bias = affine ? bias[plane] : 0.f;
  float _var = var[plane];
  float _edz = edz[plane];        // per-channel sum of dz
  float _eydz = eydz[plane];      // per-channel sum of y * dz

  float _mul = _weight * rsqrt(_var + eps);
  float count = float(num * sp);  // elements reduced per channel

  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
      float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;

      dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
    }
  }
}
217
+
218
// Host wrapper for the BN backward on half tensors: allocates dx with the
// same dtype/shape as z and fills it from the precomputed per-channel
// statistics (var, edz, eydz).
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
                           at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);
  CHECK_CUDA_INPUT(var);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(bias);
  CHECK_CUDA_INPUT(edz);
  CHECK_CUDA_INPUT(eydz);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(z, num, chn, sp);

  // Output gradient buffer (every element is written by the kernel).
  auto dx = at::zeros_like(z);

  // Run kernel: one block per channel.
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  auto stream = at::cuda::getCurrentCUDAStream();
  backward_kernel_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      var.data<float>(),
      weight.data<float>(),
      bias.data<float>(),
      edz.data<float>(),
      eydz.data<float>(),
      reinterpret_cast<half*>(dx.data<at::Half>()),
      affine, eps, num, chn, sp);

  return dx;
}
251
+
252
// fp16 in-place leaky-ReLU backward, grid-stride loop: wherever the stored
// output z is negative, scale the gradient by `slope` and recover the input
// as z / slope. Arithmetic is done in fp32 to limit rounding error.
// NOTE(review): the loop index is `int` while `count` is int64_t — this
// overflows for tensors with >= 2^31 elements; confirm sizes stay below that.
__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
    float _z = __half2float(z[i]);
    if (_z < 0) {
      dz[i] = __float2half(__half2float(dz[i]) * slope);
      z[i] = __float2half(_z / slope);
    }
  }
}
261
+
262
// Host launcher for the fp16 in-place leaky-ReLU backward. Both z and dz
// are modified in place on the current CUDA stream.
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
  CHECK_CUDA_INPUT(z);
  CHECK_CUDA_INPUT(dz);

  int64_t count = z.numel();
  dim3 threads(getNumThreads(count));
  // Enough blocks to cover every element once; the kernel grid-strides, so
  // any truncation here only costs extra loop iterations, not correctness.
  dim3 blocks = (count + threads.x - 1) / threads.x;
  auto stream = at::cuda::getCurrentCUDAStream();
  leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<half*>(z.data<at::Half>()),
      reinterpret_cast<half*>(dz.data<at::Half>()),
      slope, count);
}
275
+
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/utils/checks.h ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <ATen/ATen.h>

// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
#ifndef AT_CHECK
#define AT_CHECK AT_ASSERT
#endif

// Device / layout guards used by every extension entry point.
#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")

// Composite guards: correct device *and* contiguous memory.
// NOTE(review): these expand to two statements — do not use them as the
// sole body of an unbraced if/else.
#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/utils/common.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <ATen/ATen.h>

/*
 * Functions to share code between CPU and GPU: the same source compiles
 * under nvcc (__CUDACC__ defined) and a host compiler; these macros select
 * the right qualifiers and accumulation primitive for each.
 */

#ifdef __CUDACC__
// CUDA versions

#define HOST_DEVICE __host__ __device__
#define INLINE_HOST_DEVICE __host__ __device__ inline
#define FLOOR(x) floor(x)

#if __CUDA_ARCH__ >= 600
// Recent compute capabilities have block-level atomicAdd for all data types, so we use that
#define ACCUM(x,y) atomicAdd_block(&(x),(y))
#else
// Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
// and use the known atomicCAS-based implementation for double
template<typename data_t>
__device__ inline data_t atomic_add(data_t *address, data_t val) {
  return atomicAdd(address, val);
}

// Double specialization: classic CAS retry loop — keep swapping the 64-bit
// word until no concurrent update raced in between read and exchange.
template<>
__device__ inline double atomic_add(double *address, double val) {
  unsigned long long int* address_as_ull = (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);  // value before this thread's add, per atomicAdd convention
}

#define ACCUM(x,y) atomic_add(&(x),(y))
#endif // #if __CUDA_ARCH__ >= 600

#else
// CPU versions

#define HOST_DEVICE
#define INLINE_HOST_DEVICE inline
#define FLOOR(x) std::floor(x)
#define ACCUM(x,y) (x) += (y)

#endif // #ifdef __CUDACC__
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/modules/src/utils/cuda.cuh ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ /*
4
+ * General settings and functions
5
+ */
6
// Warp width and the hardware cap on threads per block.
const int WARP_SIZE = 32;
const int MAX_BLOCK_SIZE = 1024;

// Smallest power-of-two block size (>= 32) that covers nElem work items,
// capped at MAX_BLOCK_SIZE.
static int getNumThreads(int nElem) {
  int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
  for (int i = 0; i < 6; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}
18
+
19
+ /*
20
+ * Reduction utilities
21
+ */
22
// Butterfly warp shuffle that compiles on both pre- and post-CUDA-9
// toolkits (the _sync variant is mandatory from CUDA 9 onward).
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
                                           unsigned int mask = 0xffffffff) {
#if CUDART_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

// Index of the most significant set bit (floor(log2(val)) for val > 0).
__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
33
+
34
// Two-component accumulator reduced with operator+=; used to carry the
// (sum(dz), sum(y*dz)) pair through the warp/block reductions.
template<typename T>
struct Pair {
  T v1, v2;
  __device__ Pair() {}
  __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
  __device__ Pair(T v) : v1(v), v2(v) {}
  __device__ Pair(int v) : v1(v), v2(v) {}  // lets templated code zero-init from a literal
  __device__ Pair &operator+=(const Pair<T> &a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
};
47
+
48
// Sum `val` across all 32 lanes of a warp with butterfly shuffles
// (getMSB(32) = 5 exchange steps); every lane ends with the full sum.
// The pre-SM30 fallback stages values through shared memory instead.
template<typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
  }
#else
  __shared__ T values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;  // first lane of this warp's segment
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}

// Component-wise warp sum for Pair accumulators.
template<typename T>
static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
  value.v1 = warpSum(value.v1);
  value.v2 = warpSum(value.v2);
  return value;
}
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/AugmentCE2P.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : AugmentCE2P.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import functools
15
+ import pdb
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ # Note here we adopt the InplaceABNSync implementation from https://github.com/mapillary/inplace_abn
21
+ # By default, the InplaceABNSync module contains a BatchNorm Layer and a LeakyReLu layer
22
+ from modules import InPlaceABNSync
23
+ import numpy as np
24
+
25
# BN alias used throughout this file: an InPlaceABNSync with its fused
# activation disabled, i.e. plain (synchronized, in-place) batch norm.
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

affine_par = True

# Preprocessing metadata attached to the model by
# initialize_pretrained_model(); note the BGR channel order and the
# correspondingly reversed mean/std lists.
pretrained_settings = {
    'resnet101': {
        'imagenet': {
            'input_space': 'BGR',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.406, 0.456, 0.485],
            'std': [0.225, 0.224, 0.229],
            'num_classes': 1000
        }
    },
}
41
+
42
+
43
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
47
+
48
+
49
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with dilation support.

    The 3x3 conv uses padding == dilation * multi_grid, so spatial size is
    preserved when stride == 1. ``BatchNorm2d`` here is the module-level
    alias for InPlaceABNSync with activation disabled.
    """
    # Output channels are planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1):
        # NOTE(review): `fist_dilation` (sic) is accepted but never used;
        # kept only for signature compatibility with existing callers.
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation * multi_grid, dilation=dilation * multi_grid, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        # Out-of-place ReLU inside the trunk; in-place only after the
        # residual add — presumably to keep buffers the in-place ABN
        # backward still needs (TODO confirm against inplace_abn docs).
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the identity path when shape/stride changed.
        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu_inplace(out)

        return out
88
+
89
+
90
class CostomAdaptiveAvgPool2D(nn.Module):
    """Adaptive average pooling built from explicit per-cell windows.

    For a fixed (H_out, W_out), output cell (i, j) averages the input patch
    rows [floor(i*H/H_out), ceil((i+1)*H/H_out)) and columns
    [floor(j*W/W_out), ceil((j+1)*W/W_out)) — the same window rule as
    ``nn.AdaptiveAvgPool2d``, expressed with plain slicing + avg_pool2d.
    """

    def __init__(self, output_size):
        super(CostomAdaptiveAvgPool2D, self).__init__()
        # Target (H_out, W_out) of the pooled map.
        self.output_size = output_size

    def forward(self, x):
        h_in, w_in = x.shape[-2:]
        h_out, w_out = self.output_size

        rows = []
        for i in range(h_out):
            top = int(np.floor(i * h_in / h_out))
            bottom = int(np.ceil((i + 1) * h_in / h_out))
            cells = []
            for j in range(w_out):
                left = int(np.floor(j * w_in / w_out))
                right = int(np.ceil((j + 1) * w_in / w_out))
                # Average the whole window via a full-size pooling kernel.
                window = x[:, :, top:bottom, left:right]
                cells.append(F.avg_pool2d(window, [bottom - top, right - left]))
            rows.append(torch.concat(cells, -1))
        return torch.concat(rows, -2)
124
+
125
+
126
class PSPModule(nn.Module):
    """
    Reference:
        Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*

    Pools the input to several grid sizes, projects each to out_features,
    upsamples back, concatenates with the input, and fuses with a 3x3 conv.
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()

        self.stages = []
        tmp = []
        for size in sizes:
            # Sizes 3 and 6 use the hand-rolled adaptive pool instead of
            # nn.AdaptiveAvgPool2d — presumably for export compatibility
            # (TODO confirm); both compute the same windows.
            if size == 3 or size == 6:
                tmp.append(self._make_stage_custom(features, out_features, size))
            else:
                tmp.append(self._make_stage(features, out_features, size))
        self.stages = nn.ModuleList(tmp)
        # self.stages = nn.ModuleList([self._make_stage(features, out_features, size) for size in sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features + len(sizes) * out_features, out_features, kernel_size=3, padding=1, dilation=1,
                      bias=False),
            InPlaceABNSync(out_features),
        )

    def _make_stage(self, features, out_features, size):
        # pool to (size, size) -> 1x1 projection -> in-place ABN
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def _make_stage_custom(self, features, out_features, size):
        # Same pipeline but with the explicit-window pooling implementation.
        prior = CostomAdaptiveAvgPool2D(output_size=(size, size))
        conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
        bn = InPlaceABNSync(out_features)
        return nn.Sequential(prior, conv, bn)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        # Upsample each pooled branch back to the input resolution, then
        # append the untouched input itself before fusing.
        priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in
                  self.stages] + [feats]
        bottle = self.bottleneck(torch.cat(priors, 1))
        return bottle
168
+
169
+
170
class ASPPModule(nn.Module):
    """
    Reference:
        Chen, Liang-Chieh, et al. *"Rethinking Atrous Convolution for Semantic Image Segmentation."*

    Five parallel branches — global average pooling, a 1x1 conv, and three
    3x3 atrous convs — concatenated and fused by a 1x1 bottleneck.
    """

    def __init__(self, features, inner_features=256, out_features=512, dilations=(12, 24, 36)):
        super(ASPPModule, self).__init__()

        # Image-level branch: global pool -> 1x1 conv -> ABN.
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                   nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1,
                                             bias=False),
                                   InPlaceABNSync(inner_features))
        self.conv2 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(inner_features))
        # Atrous branches: padding == dilation keeps spatial size unchanged.
        self.conv3 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False),
            InPlaceABNSync(inner_features))
        self.conv4 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False),
            InPlaceABNSync(inner_features))
        self.conv5 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False),
            InPlaceABNSync(inner_features))

        self.bottleneck = nn.Sequential(
            nn.Conv2d(inner_features * 5, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1)
        )

    def forward(self, x):
        _, _, h, w = x.size()

        # Broadcast the 1x1 image-level feature back to full resolution.
        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)

        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        feat5 = self.conv5(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1)

        bottle = self.bottleneck(out)
        return bottle
215
+
216
+
217
class Edge_Module(nn.Module):
    """
    Edge Learning Branch

    Predicts an edge map from three backbone stages. Each stage is first
    projected to mid_fea channels; a single shared conv4 then produces a
    per-stage edge logit, and conv5 fuses the three logits.
    """

    def __init__(self, in_fea=[256, 512, 1024], mid_fea=256, out_fea=2):
        super(Edge_Module, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(mid_fea)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(mid_fea)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(mid_fea)
        )
        # conv4 is deliberately applied to all three scales (shared weights).
        self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True)
        self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        # All outputs are resized to x1's (highest) resolution.
        _, _, h, w = x1.size()

        edge1_fea = self.conv1(x1)
        edge1 = self.conv4(edge1_fea)
        edge2_fea = self.conv2(x2)
        edge2 = self.conv4(edge2_fea)
        edge3_fea = self.conv3(x3)
        edge3 = self.conv4(edge3_fea)

        edge2_fea = F.interpolate(edge2_fea, size=(h, w), mode='bilinear', align_corners=True)
        edge3_fea = F.interpolate(edge3_fea, size=(h, w), mode='bilinear', align_corners=True)
        edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear', align_corners=True)
        edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear', align_corners=True)

        edge = torch.cat([edge1, edge2, edge3], dim=1)
        edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
        edge = self.conv5(edge)

        # Returns (fused edge logits, concatenated edge features for fusion).
        return edge, edge_fea
260
+
261
+
262
class Decoder_Module(nn.Module):
    """
    Parsing Branch Decoder Module.

    Upsamples the 512-channel context feature to the low-level feature's
    resolution, concatenates with a 48-channel projection of the low-level
    feature (256 + 48 = 304), refines, and predicts per-class logits.
    """

    def __init__(self, num_classes):
        super(Decoder_Module, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(256)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(256, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(48)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(256),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(256)
        )

        self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, xt, xl):
        # xt: high-level context feature; xl: low-level backbone feature
        # whose spatial size defines the output resolution.
        _, _, h, w = xl.size()
        xt = F.interpolate(self.conv1(xt), size=(h, w), mode='bilinear', align_corners=True)
        xl = self.conv2(xl)
        x = torch.cat([xt, xl], dim=1)
        x = self.conv3(x)
        seg = self.conv4(x)
        # Returns (class logits, refined 256-channel feature for fusion).
        return seg, x
294
+
295
+
296
class ResNet(nn.Module):
    """CE2P parsing network on a dilated ResNet backbone.

    Backbone: deep-stem (three 3x3 convs) ResNet whose layer4 uses
    dilation instead of stride (output stride 16). Heads: PSP context
    encoding, a decoder producing parsing logits, an edge branch, and a
    fusion head combining parsing and edge features.
    """

    def __init__(self, block, layers, num_classes):
        # Deep stem ends at 128 channels, so layer construction starts there.
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # stride=1 + dilation=2 keeps layer4 at the same resolution as layer3.
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 1, 1))

        self.context_encoding = PSPModule(2048, 512)

        self.edge = Edge_Module()
        self.decoder = Decoder_Module(num_classes)

        # NOTE(review): 'fushion' (sic) is kept — renaming the attribute
        # would change state_dict keys and break existing checkpoints.
        self.fushion = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=1, padding=0, dilation=1, bias=False),
            InPlaceABNSync(256),
            nn.Dropout2d(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        # 1x1 projection on the identity path when shape/stride changes.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion, affine=affine_par))

        layers = []
        # Cycle through the multi-grid rates when a tuple is given; else 1.
        generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1
        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample,
                            multi_grid=generate_multi_grid(0, multi_grid)))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)
        x = self.context_encoding(x5)
        parsing_result, parsing_fea = self.decoder(x, x2)
        # Edge Branch
        edge_result, edge_fea = self.edge(x2, x3, x4)
        # Fusion Branch
        x = torch.cat([parsing_fea, edge_fea], dim=1)
        fusion_result = self.fushion(x)
        # [[decoder logits, fused logits], edge logits]
        return [[parsing_result, fusion_result], edge_result]
365
+
366
+
367
def initialize_pretrained_model(model, settings, pretrained='./models/resnet101-imagenet.pth'):
    """Attach preprocessing metadata to ``model`` and optionally load weights.

    Args:
        model: network whose state_dict keys match the checkpoint
            (minus the ImageNet ``fc.*`` classifier head).
        settings: dict providing input_space/input_size/input_range/mean/std.
        pretrained: checkpoint path, or None to skip weight loading.
    """
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']

    if pretrained is not None:
        saved_state_dict = torch.load(pretrained)
        new_params = model.state_dict().copy()
        # Copy every checkpoint entry except the ImageNet classifier head.
        # (Cleanup: was `if not i_parts[0] == 'fc'` plus a no-op
        # split/'.'.join round-trip of the key.)
        for key in saved_state_dict:
            if key.split('.')[0] != 'fc':
                new_params[key] = saved_state_dict[key]
        model.load_state_dict(new_params)
382
+
383
+
384
def resnet101(num_classes=20, pretrained='./models/resnet101-imagenet.pth'):
    """Build the CE2P ResNet-101 parser, optionally loading ImageNet weights."""
    settings = pretrained_settings['resnet101']['imagenet']
    net = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)
    initialize_pretrained_model(net, settings, pretrained)
    return net
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from networks.AugmentCE2P import resnet101
3
+
4
+ __factory = {
5
+ 'resnet101': resnet101,
6
+ }
7
+
8
+
9
def init_model(name, *args, **kwargs):
    """Instantiate a registered backbone factory by name.

    Any extra positional/keyword arguments are forwarded to the factory.

    Raises:
        KeyError: if ``name`` is not a registered architecture.
    """
    # Membership test directly on the dict (no redundant .keys() call).
    if name not in __factory:
        raise KeyError("Unknown model arch: {}".format(name))
    return __factory[name](*args, **kwargs)
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/backbone/mobilenetv2.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : mobilenetv2.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import torch.nn as nn
15
+ import math
16
+ import functools
17
+
18
+ from modules import InPlaceABN, InPlaceABNSync
19
+
20
+ BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
21
+
22
+ __all__ = ['mobilenetv2']
23
+
24
+
25
def conv_bn(inp, oup, stride):
    """3x3 conv (padding 1, no bias) -> BN -> ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
31
+
32
+
33
def conv_1x1_bn(inp, oup):
    """Pointwise 1x1 conv (no bias) -> BN -> ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
39
+
40
+
41
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual: expand (1x1) -> depthwise 3x3 -> project.

    The residual connection is used only when stride == 1 and the channel
    count is unchanged; with expand_ratio == 1 the expansion conv is omitted.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear (no activation after the projection)
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
80
+
81
+
82
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier built from inverted residual blocks.

    ``width_mult`` scales the channel counts; ``input_size`` must be a
    multiple of 32. Normalization layers come from the module-level
    ``BatchNorm2d`` alias (a partial of InPlaceABNSync).
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s  (expand ratio, channels, repeats, first stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],  # layer 2
            [6, 32, 3, 2],  # layer 3
            [6, 64, 4, 2],
            [6, 96, 3, 1],  # layer 4
            [6, 160, 3, 2],
            [6, 320, 1, 1],  # layer 5
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of each group applies the stride.
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # Global average pool over H and W before the classifier.
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming-style init scaled by fan-out.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            # Fix: was `isinstance(m, BatchNorm2d)`, but BatchNorm2d is a
            # functools.partial, not a class — isinstance() raises TypeError
            # on any non-Conv2d module, so construction always crashed.
            # All BN layers here are InPlaceABNSync instances.
            elif isinstance(m, InPlaceABNSync):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
146
+
147
+
148
def mobilenetv2(pretrained=False, **kwargs):
    """Constructs a MobileNet_V2 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = MobileNetV2(n_class=1000, **kwargs)
    if pretrained:
        # NOTE(review): neither `load_url` nor `model_urls` is defined or
        # imported in this file, so pretrained=True raises NameError as
        # written — confirm the intended source (torch.utils.model_zoo
        # .load_url plus a URL table) before relying on this branch.
        model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
    return model
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/backbone/resnet.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : resnet.py
8
+ @Time : 8/4/19 3:35 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import functools
15
+ import torch.nn as nn
16
+ import math
17
+ from torch.utils.model_zoo import load_url
18
+
19
+ from modules import InPlaceABNSync
20
+
21
# InPlaceABNSync with no activation stands in for a plain BatchNorm2d.
# NOTE(review): this is a functools.partial, not a class, so it cannot be
# used as the second argument of isinstance().
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

# Public API of this module.
__all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']

# ImageNet-pretrained checkpoints from the MIT Scene Parsing project.
model_urls = {
    'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
    'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
}
30
+
31
+
32
def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 convolution with padding=1 and no bias."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
36
+
37
+
38
class BasicBlock(nn.Module):
    """Two-3x3-conv residual block used by ResNet-18/34."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Attribute names kept identical so pretrained state_dicts still load.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity shortcut, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
68
+
69
+
70
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck used by ResNet-50/101."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Attribute names kept identical so pretrained state_dicts still load.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity shortcut, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
107
+
108
+
109
class ResNet(nn.Module):
    """ResNet backbone with a "deep stem" (three stacked 3x3 convs, 3->64->
    64->128 channels) as used by the MIT Scene Parsing pretrained checkpoints.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: size of the final fully-connected classifier.
    """

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        # Deep stem: 3x3/s2 -> 3x3 -> 3x3 instead of a single 7x7 conv.
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He init with fan_out = k_h * k_w * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            # BUGFIX: the original checked ``isinstance(m, BatchNorm2d)``, but
            # BatchNorm2d is a functools.partial, and isinstance() with a
            # partial raises ``TypeError: isinstance() arg 2 must be a type``
            # as soon as the loop reaches any non-Conv2d module.  Check the
            # underlying class instead.
            elif isinstance(m, InPlaceABNSync):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks; the first block
        carries the stride and (if needed) a 1x1 projection shortcut."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Classifier head; assumes a 7x7 feature map at this point
        # (i.e. roughly 224x224 input) -- the 7x7 AvgPool is not adaptive.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
173
+
174
+
175
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model.

    Args:
        pretrained (bool): when True, load ImageNet weights from
            ``model_urls['resnet18']`` (strict load).
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(load_url(model_urls['resnet18']))
    return model
184
+
185
+
186
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model.

    Args:
        pretrained (bool): when True, load ImageNet weights from
            ``model_urls['resnet50']`` non-strictly (missing/unexpected
            keys are tolerated).
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
    return model
195
+
196
+
197
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model.

    Args:
        pretrained (bool): when True, load ImageNet weights from
            ``model_urls['resnet101']`` non-strictly (missing/unexpected
            keys are tolerated).
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
    return model
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/backbone/resnext.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : resnext.py.py
8
+ @Time : 8/11/19 8:58 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+ import functools
14
+ import torch.nn as nn
15
+ import math
16
+ from torch.utils.model_zoo import load_url
17
+
18
+ from modules import InPlaceABNSync
19
+
20
# InPlaceABNSync with no activation stands in for a plain BatchNorm2d.
# NOTE(review): this is a functools.partial, not a class, so it cannot be
# used as the second argument of isinstance().
BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')

# Public API: only ResNeXt-101 is wired up in this module.
__all__ = ['ResNeXt', 'resnext101']

# ImageNet-pretrained checkpoints from the MIT Scene Parsing project.
model_urls = {
    'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
    'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
}
28
+
29
+
30
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding=1."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=False)
34
+
35
+
36
class GroupBottleneck(nn.Module):
    """ResNeXt bottleneck: 1x1 -> grouped 3x3 -> 1x1, with expansion 2."""

    expansion = 2

    def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
        super(GroupBottleneck, self).__init__()
        # Attribute names kept identical so pretrained state_dicts still load.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 2)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity shortcut, or a projection when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
73
+
74
+
75
class ResNeXt(nn.Module):
    """ResNeXt backbone with a deep 3x3 stem, built from grouped bottlenecks.

    Args:
        block: residual block class (``GroupBottleneck``).
        layers: number of blocks in each of the four stages.
        groups: cardinality of the grouped 3x3 convolutions.
        num_classes: size of the final fully-connected classifier.
    """

    def __init__(self, block, layers, groups=32, num_classes=1000):
        self.inplanes = 128
        super(ResNeXt, self).__init__()
        # Deep stem: 3x3/s2 -> 3x3 -> 3x3, 3 -> 64 -> 64 -> 128 channels.
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
        self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(1024 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He init; fan_out accounts for grouped convolutions.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
                m.weight.data.normal_(0, math.sqrt(2. / n))
            # BUGFIX: the original checked ``isinstance(m, BatchNorm2d)``, but
            # BatchNorm2d is a functools.partial, and isinstance() with a
            # partial raises ``TypeError: isinstance() arg 2 must be a type``
            # as soon as the loop reaches any non-Conv2d module.  Check the
            # underlying class instead.
            elif isinstance(m, InPlaceABNSync):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, groups=1):
        """Build one residual stage; the first block carries the stride and
        (if needed) a 1x1 projection shortcut."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, groups, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=groups))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Classifier head; assumes a 7x7 feature map at this point
        # (i.e. roughly 224x224 input) -- the 7x7 AvgPool is not adaptive.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
139
+
140
+
141
def resnext101(pretrained=False, **kwargs):
    """Construct a ResNeXt-101 model.

    Args:
        pretrained (bool): when True, load pretrained weights from
            ``model_urls['resnext101']`` non-strictly.  (The original
            docstring said "ResNet-101" / "Places"; the url key is an
            ImageNet checkpoint.)
    """
    model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
    return model
src/multiview_consist_edit/parse_tool/preprocess/humanparsing/networks/context_encoding/aspp.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ @Author : Peike Li
6
+ @Contact : peike.li@yahoo.com
7
+ @File : aspp.py
8
+ @Time : 8/4/19 3:36 PM
9
+ @Desc :
10
+ @License : This source code is licensed under the license found in the
11
+ LICENSE file in the root directory of this source tree.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ from modules import InPlaceABNSync
19
+
20
+
21
class ASPPModule(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Five parallel branches -- image-level pooling, a 1x1 conv, and three
    dilated 3x3 convs -- are concatenated and fused by a 1x1 bottleneck.

    Reference:
        Chen, Liang-Chieh, et al. "Rethinking Atrous Convolution for
        Semantic Image Segmentation."
    """

    def __init__(self, features, out_features=512, inner_features=256, dilations=(12, 24, 36)):
        super(ASPPModule, self).__init__()

        # Attribute names and Sequential layouts kept identical so
        # pretrained state_dicts still load.
        d1, d2, d3 = dilations
        # Branch 1: global context (pool to 1x1, then project).
        self.conv1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(features, inner_features, kernel_size=1, padding=0,
                      dilation=1, bias=False),
            InPlaceABNSync(inner_features))
        # Branch 2: plain 1x1 projection.
        self.conv2 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=1, padding=0,
                      dilation=1, bias=False),
            InPlaceABNSync(inner_features))
        # Branches 3-5: dilated 3x3 convs (padding == dilation keeps size).
        self.conv3 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=3, padding=d1,
                      dilation=d1, bias=False),
            InPlaceABNSync(inner_features))
        self.conv4 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=3, padding=d2,
                      dilation=d2, bias=False),
            InPlaceABNSync(inner_features))
        self.conv5 = nn.Sequential(
            nn.Conv2d(features, inner_features, kernel_size=3, padding=d3,
                      dilation=d3, bias=False),
            InPlaceABNSync(inner_features))

        # Fuse the concatenated branches down to ``out_features``.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(inner_features * 5, out_features, kernel_size=1,
                      padding=0, dilation=1, bias=False),
            InPlaceABNSync(out_features),
            nn.Dropout2d(0.1)
        )

    def forward(self, x):
        _, _, height, width = x.size()

        # Upsample the pooled branch back to the input's spatial size.
        global_feat = F.interpolate(self.conv1(x), size=(height, width),
                                    mode='bilinear', align_corners=True)
        stacked = torch.cat(
            (global_feat, self.conv2(x), self.conv3(x), self.conv4(x), self.conv5(x)),
            1)
        return self.bottleneck(stacked)