scy639 committed on
Commit
8c90776
·
verified ·
1 Parent(s): dcab600

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +9 -0
  2. Dataset_custom.py +317 -0
  3. LICENSE +23 -0
  4. LatentDiffusion.yaml +83 -0
  5. Mediapipe_Result_Cache.py +36 -0
  6. MoE.py +141 -0
  7. Other_dependencies/arcface/add.txt +1 -0
  8. Other_dependencies/arcface/model_ir_se50.pth +3 -0
  9. Other_dependencies/face_parsing/79999_iter.pth +3 -0
  10. Other_dependencies/face_parsing/add.txt +1 -0
  11. Other_dependencies/mp_models/blaze_face_short_range.tflite +3 -0
  12. Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task +3 -0
  13. app.py +239 -0
  14. checkpoints/pretrained.json +1072 -0
  15. download_checkpoints.py +29 -0
  16. eval_tool/lpips/__init__.py +0 -0
  17. eval_tool/lpips/lpips.py +35 -0
  18. eval_tool/lpips/networks.py +96 -0
  19. eval_tool/lpips/utils.py +30 -0
  20. examples/face/ref-semantic_mask.png +0 -0
  21. examples/face/ref.png +3 -0
  22. examples/face/tgt-semantic_mask.png +0 -0
  23. examples/face/tgt.png +3 -0
  24. examples/hair/ref-semantic_mask.png +0 -0
  25. examples/hair/ref.png +3 -0
  26. examples/hair/tgt-semantic_mask.png +0 -0
  27. examples/hair/tgt.png +3 -0
  28. examples/head/ref-semantic_mask.png +0 -0
  29. examples/head/ref.png +3 -0
  30. examples/head/tgt-semantic_mask.png +0 -0
  31. examples/head/tgt.png +3 -0
  32. examples/inputs.txt +5 -0
  33. examples/motion/ref-semantic_mask.png +0 -0
  34. examples/motion/ref.png +3 -0
  35. examples/motion/tgt-semantic_mask.png +0 -0
  36. examples/motion/tgt.png +3 -0
  37. gen_lmk_and_mask.py +41 -0
  38. gen_semantic_mask.py +90 -0
  39. get_mask.py +68 -0
  40. global_.py +9 -0
  41. hf_model.py +247 -0
  42. imports.py +8 -0
  43. infer.py +366 -0
  44. infer_hf.py +279 -0
  45. init_model.py +178 -0
  46. ldm/lr_scheduler.py +99 -0
  47. ldm/models/autoencoder.py +443 -0
  48. ldm/models/diffusion/__init__.py +0 -0
  49. ldm/models/diffusion/bank.py +76 -0
  50. ldm/models/diffusion/classifier.py +267 -0
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text
37
+ examples/face/ref.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/face/tgt.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/hair/ref.png filter=lfs diff=lfs merge=lfs -text
40
+ examples/hair/tgt.png filter=lfs diff=lfs merge=lfs -text
41
+ examples/head/ref.png filter=lfs diff=lfs merge=lfs -text
42
+ examples/head/tgt.png filter=lfs diff=lfs merge=lfs -text
43
+ examples/motion/ref.png filter=lfs diff=lfs merge=lfs -text
44
+ examples/motion/tgt.png filter=lfs diff=lfs merge=lfs -text
Dataset_custom.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imports import *
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import torch
7
+ import torch.utils.data as data
8
+ import torchvision.transforms as T
9
+ from einops import rearrange
10
+ import albumentations
11
+
12
+ from util_face import *
13
+ from util_4dataset import *
14
+ from util_cv2 import cv2_resize_auto_interpolation
15
+ from Mediapipe_Result_Cache import Mediapipe_Result_Cache
16
+
17
+
18
def resize_A(img, dataset_name, size=(512, 512), interpolation=None):
    """Resize an image to ``size`` unless it already has that size.

    Args:
        img: A ``PIL.Image.Image`` or a ``numpy.ndarray`` image.
        dataset_name: Not used by this function; kept so existing callers
            keep working.
        size: Target size forwarded to ``cv2_resize_auto_interpolation``.
        interpolation: Forwarded to ``cv2_resize_auto_interpolation``.

    Returns:
        The image in the same container type (PIL or ndarray) as the input.
    """
    is_pil = isinstance(img, Image.Image)
    if is_pil:
        img = np.array(img)
    target = tuple(size)
    # The previous check compared against a hard-coded (512, 512), so a
    # 512x512 image was never resized to any other requested ``size``.
    # Only a square target is compared with ``shape[:2]``: for non-square
    # targets the (h, w) vs (w, h) convention of the resize helper is not
    # visible here, so the helper is always called in that case.
    already_sized = target[0] == target[1] and tuple(img.shape[:2]) == target
    if not already_sized:
        img = cv2_resize_auto_interpolation(img, size, interpolation=interpolation)
    if is_pil:
        img = Image.fromarray(img)
    return img
27
+
28
+
29
def un_norm_clip(x1):
    """Undo a per-channel ``(x - mean) / std`` normalization on a tensor.

    The first three channels are mapped back with ``x * std + mean`` using
    fixed constants (they look like the usual CLIP image statistics).
    Accepts a ``(C, H, W)`` or ``(N, C, H, W)`` tensor and returns a new
    tensor with the same number of dimensions; the input is not modified.
    """
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    channel_mean = (0.48145466, 0.4578275, 0.40821073)

    restored = x1 * 1.0  # multiplication yields a copy, so x1 stays untouched
    was_unbatched = len(restored.shape) == 3
    if was_unbatched:
        restored = restored.unsqueeze(0)
    for channel, (std, mean) in enumerate(zip(channel_std, channel_mean)):
        restored[:, channel, :, :] = restored[:, channel, :, :] * std + mean
    return restored.squeeze(0) if was_unbatched else restored
41
+
42
+
43
def un_norm(x):
    """Return ``(x + 1) / 2``, i.e. map values from [-1, 1] onto [0, 1]."""
    shifted = x + 1.0
    return shifted / 2.0
45
+
46
+
47
def _dilate(_mask, kernel_size, iterations):
    """Dilate a mask with a square all-ones kernel and return it as bool.

    Args:
        _mask: Array mask; it is cast to ``uint8`` before dilation.
        kernel_size: Side length of the square structuring element.
        iterations: Number of dilation passes given to ``cv2.dilate``.
    """
    structuring_element = np.ones((kernel_size, kernel_size), np.uint8)
    grown = cv2.dilate(
        _mask.astype(np.uint8), structuring_element, iterations=iterations
    )
    return grown.astype(bool)
53
+
54
+
55
def dilate_4_task0(sm_mask):
    """Build a boolean mask as the union of three dilated label groups.

    ``sm_mask`` is converted to an array of label ids. Each group of ids is
    selected with ``np.isin`` and dilated with its own kernel size and
    iteration count; the three results are OR-ed together.
    """
    labels = np.array(sm_mask)
    # (label ids, kernel size, iterations) for each group
    recipes = (
        ([2, 3, 10, 5], 7, 1),
        ([3, 10], 10, 3),
        ([1], 7, 2),
    )
    first, second, third = (
        _dilate(np.isin(labels, ids), kernel_size, iterations)
        for ids, kernel_size, iterations in recipes
    )
    return first | second | third
68
+
69
+
70
class Dataset_custom(data.Dataset):
    """Test-time dataset built from paired target/reference image paths.

    Each item loads one target image and one reference image together with
    their semantic masks (located via ``path_img_2_path_mask``) and builds
    the tensors the model consumes for the configured ``task``.
    The task ids 0..3 presumably correspond to face / hair / motion / head
    transfer — confirm against the caller.
    """

    # Normalization statistics passed to ``get_tensor``.
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)

    def get_img4clip(
        self,
        img,
        sm_mask,
        preserve,
        for_clip=True,
        add_semantic_head=False,
        mask_after_npisin=None,
        for_inpaint512=False,
    ):
        """Mask an image to the preserved labels and return it as a tensor.

        Args:
            img: ``numpy.ndarray`` or PIL image.
            sm_mask: Semantic label mask (converted with ``np.array``).
            preserve: Label ids to keep; unused for masking when
                ``mask_after_npisin`` is given.
            for_clip: If true, use ``get_tensor_clip`` and a 224 output
                size; otherwise ``get_tensor`` and a 512 output size.
            add_semantic_head: If true, ``add_colorSM`` replaces the image
                and the working mask; the returned mask is still derived
                from the mask computed before that call.
            mask_after_npisin: Optional precomputed boolean mask.
            for_inpaint512: For task 0, additionally removes the region
                given by ``get_forehead_mask`` from the mask.

        Returns:
            ``(image_tensor, mask)``.
        """
        sm_mask = np.array(sm_mask)
        if mask_after_npisin is None:
            mask = np.isin(sm_mask, preserve)
            if self.task == 0 and for_inpaint512:
                forehead_mask = get_forehead_mask(sm_mask)
                mask = mask & ~forehead_mask
        else:
            mask = mask_after_npisin

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if add_semantic_head:
            mask_before_colorSM = mask
            img, mask = add_colorSM(img, sm_mask, preserve, None)
        mask = mask_after_npisin__2__tensor(mask)

        if for_clip:
            image_tensor = get_tensor_clip()(img)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(img)
        image_tensor = T.Resize([512, 512])(image_tensor)
        image_tensor = image_tensor * mask

        # Round-trip through a uint8 image so the result can be resized
        # with albumentations and re-normalized.
        if for_clip:
            image_tensor = 255.0 * rearrange(un_norm_clip(image_tensor), "c h w -> h w c").cpu().numpy()
            _size = 224
        else:
            image_tensor = 255.0 * rearrange(un_norm(image_tensor), "c h w -> h w c").cpu().numpy()
            _size = 512

        image_tensor = albumentations.Resize(height=_size, width=_size)(image=image_tensor)
        image_tensor = Image.fromarray(image_tensor["image"].astype(np.uint8))
        if for_clip:
            image_tensor = get_tensor_clip()(image_tensor)
        else:
            image_tensor = get_tensor(mean=self.mean, std=self.std)(image_tensor)
            # Re-apply the mask only on the 512 path, where the sizes match.
            image_tensor = image_tensor * mask
        if add_semantic_head:
            mask = mask_after_npisin__2__tensor(mask_before_colorSM)
        return image_tensor, mask

    def __init__(
        self,
        state,
        task,
        paths_tgt,
        paths_ref,
        name="custom",
    ):
        """Store the path pairs and per-task configuration.

        Args:
            state: Must be ``"test"``.
            task: Task id, one of 0, 1, 2, 3.
            paths_tgt: Target image paths.
            paths_ref: Reference image paths, same length as ``paths_tgt``.
            name: Dataset name, forwarded to ``resize_A``.

        Raises:
            ValueError: If ``task`` is not one of the supported ids.
            AssertionError: If ``state`` is not ``"test"`` or the path
                lists are missing or differ in length.
        """
        # Previously an unsupported task fell through the if/elif chain and
        # crashed with an UnboundLocalError on the flag below.
        if task not in (0, 1, 2, 3):
            raise ValueError(f"Unsupported task {task!r}; expected one of 0, 1, 2, 3")
        # Every supported task reads Mediapipe landmarks from the cache.
        # (The former per-task USE_filter_mediapipe_fail_swap / USE_pts
        # locals were never read and have been dropped.)
        READ_mediapipe_result_from_cache = 1
        self.READ_mediapipe_result_from_cache = READ_mediapipe_result_from_cache

        assert state == "test"
        self.state = state
        self.image_size = 512
        self.kernel = np.ones((1, 1), np.uint8)
        self.name = name

        assert paths_tgt is not None and paths_ref is not None, "paths_tgt and paths_ref are required"
        assert len(paths_tgt) == len(paths_ref), "paths_tgt and paths_ref must be the same length"
        self.paths_tgt = list(paths_tgt)
        self.paths_ref = list(paths_ref)

        if READ_mediapipe_result_from_cache:
            self.mediapipe_Result_Cache = Mediapipe_Result_Cache()
        self.task = task

    def __getitem__(self, index):
        """Build the model inputs for the ``index``-th target/reference pair.

        Returns:
            ``(image_tensor_resize, "None", ret, out_stem)`` where ``ret``
            is the dict of model inputs and ``out_stem`` is
            ``"<target stem>-<reference stem>"``.

        Raises:
            RuntimeError: If no cached Mediapipe result exists for the
                image whose landmarks are needed.
        """
        task = self.task
        path_tgt = self.paths_tgt[index]
        path_ref = self.paths_ref[index]

        img_tgt = Image.open(path_tgt).convert("RGB")
        img_tgt = resize_A(img_tgt, self.name)

        mask_path = path_img_2_path_mask(path_tgt)
        if self.task == 0:
            preserve = [1, 2, 3, 10, 5, 6, 7, 9]
            sm_mask_tgt = Image.open(mask_path).convert("L")
            sm_mask_tgt = np.array(sm_mask_tgt)
            mask_tgt = np.isin(sm_mask_tgt, preserve)
            forehead_mask = get_forehead_mask(sm_mask_tgt)
            mask_tgt = mask_tgt & ~forehead_mask
        elif self.task == 1:
            preserve = [4]
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        elif self.task == 3:
            preserve = [1, 2, 3, 10, 4, 5, 6, 7, 9]
            mask_tgt = path_img_2_mask(path_tgt, preserve)
        elif self.task == 2:
            preserve = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21]
            sm_mask_tgt = Image.open(mask_path).convert("L")
            sm_mask_tgt = np.array(sm_mask_tgt)
            mask_tgt = np.isin(sm_mask_tgt, preserve)

        converted_mask = np.zeros_like(mask_tgt)
        converted_mask[mask_tgt] = 255
        mask_tgt = Image.fromarray(converted_mask).convert("L")
        # After this, the tensor is 1 outside the selected region.
        mask_tensor = 1 - get_tensor(normalize=False, toTensor=True)(mask_tgt)

        image_tensor = get_tensor(mean=self.mean, std=self.std)(img_tgt)
        image_tensor_resize = T.Resize([self.image_size, self.image_size])(image_tensor)
        mask_tensor_resize = T.Resize([self.image_size, self.image_size])(mask_tensor)

        if task == 2:
            inpaint_tensor_resize = image_tensor_resize
        else:
            inpaint_tensor_resize = image_tensor_resize * mask_tensor_resize
        # NOTE(review): the scraped source lost its indentation; this
        # inversion is assumed to apply to every task rather than only to
        # the non-task-2 branch above — confirm against the repository.
        mask_tensor_resize = 1 - mask_tensor_resize

        mask_path_ref = path_img_2_path_mask(path_ref)
        sm_mask_ref = Image.open(mask_path_ref).convert("L")
        sm_mask_ref = np.array(sm_mask_ref)
        img_ref = cv2.imread(str(path_ref))
        img_ref = cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB)
        img_ref = resize_A(img_ref, self.name)

        if task != 2:
            ref_image_tensor, ref_mask_tensor = self.get_img4clip(
                img_ref, sm_mask_ref, preserve, for_clip=True, add_semantic_head=0
            )
            if task == 3:
                ref_image_faceOnly_tensor, _ = self.get_img4clip(
                    img_ref,
                    sm_mask_ref,
                    [1, 2, 3, 10, 5, 6, 7, 9],
                    for_clip=False,
                    add_semantic_head=0,
                )
        else:
            ref_image_tensor = inpaint_tensor_resize

        ret = {
            "inpaint_image": inpaint_tensor_resize,
            "inpaint_mask": mask_tensor_resize,
            "ref_imgs": ref_image_tensor,
            "task": self.task,
        }

        if self.task == 0:
            ret["enInputs"] = {
                "face_ID-in": ref_image_tensor,
                "face-clip-in": ref_image_tensor,
            }
        elif self.task == 1:
            ret["enInputs"] = {
                "hair-clip-in": ref_image_tensor,
            }
        elif self.task == 2:
            tgt_nonBg_tensor, _ = self.get_img4clip(img_tgt, sm_mask_tgt, preserve)
            ret["enInputs"] = {
                "face_ID-in": tgt_nonBg_tensor,
                "head-clip-in": tgt_nonBg_tensor,
            }
        elif self.task == 3:
            ret["enInputs"] = {
                "face_ID-in": ref_image_faceOnly_tensor,
                "head-clip-in": ref_image_tensor,
            }

        if (REFNET.ENABLE and REFNET.task2layerNum[task] > 0) or CH14:
            if task != 2:
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
            else:
                # Task 2 feeds the unmasked target image (all-ones mask).
                ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
                    img_tgt,
                    sm_mask_tgt,
                    "any",
                    for_clip=False,
                    add_semantic_head=0,
                    mask_after_npisin=np.ones_like(sm_mask_tgt).astype(bool),
                )
            ref_imgs_4unet = T.Resize([self.image_size, self.image_size])(ref_imgs_4unet)
            ref_mask_512 = T.Resize([self.image_size, self.image_size])(ref_mask_4unet)
            ret["ref_imgs_4unet"] = ref_imgs_4unet
            ret["ref_mask_512"] = ref_mask_512

        if self.READ_mediapipe_result_from_cache:
            if self.state == "test":
                # Task 2 takes its landmarks from the reference image.
                if task == 2:
                    _p_lmk = path_ref
                else:
                    _p_lmk = path_tgt
            else:
                _p_lmk = path_tgt
            ret["mediapipe_lmkAll"] = self.mediapipe_Result_Cache.get(_p_lmk)
            if ret["mediapipe_lmkAll"] is None:
                raise RuntimeError(
                    f"Missing Mediapipe cache for input image: {_p_lmk}. "
                    "Precompute landmarks and ensure cache exists before inference."
                )

        if self.state == "test":
            prior_image_tensor = "None"
            out_stem = f"{Path(path_tgt).stem}-{Path(path_ref).stem}"
            if task == 2:
                ref512, _ = self.get_img4clip(
                    img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
                )
                ref512 = T.Resize([self.image_size, self.image_size])(ref512)
                ret["ref512"] = ref512
            ret = (image_tensor_resize, prior_image_tensor, ret, out_stem)
        return ret

    def __len__(self):
        """Return the number of target/reference pairs."""
        return len(self.paths_tgt)
LICENSE ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sanoojan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+
LatentDiffusion.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "inpaint"
11
+ cond_stage_key: "image"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: true # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ u_cond_percent: 0.2
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+
22
+ scheduler_config: # 10000 warmup steps
23
+ target: ldm.lr_scheduler.LambdaLinearScheduler
24
+ params:
25
+ warm_up_steps: [ 10000 ]
26
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
27
+ f_start: [ 1.e-1 ]
28
+ f_max: [ 1. ]
29
+ f_min: [ 1. ]
30
+
31
+ unet_config:
32
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
33
+ params:
34
+ image_size: 32 # unused
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+ add_conv_in_front_of_unet: False
47
+
48
+ first_stage_config:
49
+ target: ldm.models.autoencoder.AutoencoderKL
50
+ params:
51
+ embed_dim: 4
52
+ monitor: val/rec_loss
53
+ ddconfig:
54
+ double_z: true
55
+ z_channels: 4
56
+ resolution: 256
57
+ in_channels: 3
58
+ out_ch: 3
59
+ ch: 128
60
+ ch_mult:
61
+ - 1
62
+ - 2
63
+ - 4
64
+ - 4
65
+ num_res_blocks: 2
66
+ attn_resolutions: []
67
+ dropout: 0.0
68
+ lossconfig:
69
+ target: torch.nn.Identity
70
+
71
+ cond_stage_config:
72
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
73
+ other_params:
74
+ clip_weight: 1.0
75
+ arcface_path: "Other_dependencies/arcface/model_ir_se50.pth"
76
+ multi_scale_ID: False # True was used for the previous training there is an issue
77
+ Additional_config:
78
+ Reconstruct_initial: False # scy:
79
+ Target_CLIP_feat: True
80
+ Source_CLIP_feat: True
81
+ Reconstruct_DDIM_steps: 4
82
+
83
+
Mediapipe_Result_Cache.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imports import *
2
+ import json,random,os
3
+ import numpy as np
4
+
5
+
6
+
7
class Mediapipe_Result_Cache:
    """On-disk ``.npy`` cache of Mediapipe results, keyed by image path.

    Convention: when a cache entry exists, it must not be None.
    In other words, None results should not be cached; get/set guard
    against historical None values.
    """

    # DIR = Path('/inspurfs/group/mayuexin/suncy/mediapipe_result/A')
    DIR = Path('data/mediapipe_result')

    def __init__(self):
        pass

    def get_path(self, img_path):
        """Return the cache file for ``img_path``, creating its folder.

        The image's parent directory is flattened into one folder name by
        replacing ``/`` with ``|``; the file is ``<image name>.npy``.
        """
        image = Path(img_path)
        folder_key = str(image.parent)
        # '|' is the separator substitute, so it must not already occur.
        assert '|' not in folder_key
        cache_dir = self.DIR / folder_key.replace('/', '|')
        cache_dir.mkdir(parents=True, exist_ok=True)
        return cache_dir / f"{image.name}.npy"

    def get(self, img_path):
        """Return the cached array for ``img_path``, or None if absent."""
        cache_file = self.get_path(img_path)
        # print(f"[get] {path=}")
        if not cache_file.exists():
            return None
        cached = np.load(cache_file)
        assert cached is not None
        return cached

    def set(self, img_path, lmks):
        """Store ``lmks`` (must not be None) as the entry for ``img_path``."""
        assert lmks is not None
        np.save(self.get_path(img_path), lmks)
        # print(f"{path=}")
MoE.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imports import *
2
+ import global_
3
+ import torch,copy
4
+ import torch.nn as nn
5
+ from ldm.modules.attention import FeedForward,CrossAttention
6
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel,ResBlock,TimestepEmbedSequential
7
+ # import torch.nn.functional as F
8
+
9
+ # ---------------- Configs ----------------
10
+ CONV2D_PARAM_STATS = []
11
+
12
def average_module_weight(src_modules: list):
    """Average the state dicts of several modules, key by key.

    Args:
        src_modules: Modules with identical state-dict keys and shapes.

    Returns:
        A dict mapping each state-dict key to the element-wise mean of
        that entry over all modules, or ``None`` for an empty list.
        Each result keeps the dtype of the first module's entry.
    """
    if not src_modules:
        return None
    state_dicts = [module.state_dict() for module in src_modules]
    count = len(state_dicts)
    avg_state_dict = {}
    for key, reference in state_dicts[0].items():
        if reference.is_floating_point() or reference.is_complex():
            total = torch.zeros_like(reference)
            for state_dict in state_dicts:
                total += state_dict[key]
            total /= count
            avg_state_dict[key] = total
        else:
            # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked)
            # cannot be divided in place; average in float64 and cast back.
            total = torch.zeros_like(reference, dtype=torch.float64)
            for state_dict in state_dicts:
                total += state_dict[key]
            avg_state_dict[key] = (total / count).to(reference.dtype)
    return avg_state_dict
27
+
28
class ModuleDict_W(nn.Module):  # Wrapper of ModuleDict
    """``nn.ModuleDict`` wrapper keyed by integer task ids.

    ``nn.ModuleDict`` only accepts string keys, so ids are stored as
    ``str(int(k))`` internally while callers index with ints. ``forward``
    dispatches to the module of the current ``global_.task``.
    """

    def __init__(self, modules: list, keys: list):
        super().__init__()
        assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}"
        self._keys = [int(k) for k in keys]
        self._moduleDict = nn.ModuleDict({str(int(k)): m for k, m in zip(self._keys, modules)})

    def __getitem__(self, k: int):
        """Return the module registered for task id ``k``."""
        _k = str(int(k))
        return self._moduleDict[_k]

    def keys(self):
        """Return a copy of the list of available task ids."""
        return list(self._keys)

    def forward(self, *args, **kwargs):
        """Call the module of the task currently set in ``global_.task``."""
        cur_task = global_.task
        assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}"
        return self._moduleDict[str(int(cur_task))](*args, **kwargs)

    def offload_unused_tasks(self, unused_tasks, method: str):
        """Delete or move to CPU the modules of ``unused_tasks``.

        Args:
            unused_tasks: Task ids to offload; ids without a module are
                ignored.
            method: ``'del'`` removes the module, ``'cpu'`` moves it to CPU.

        Raises:
            ValueError: If ``method`` is neither ``'del'`` nor ``'cpu'``.
        """
        for i in unused_tasks:
            _k = str(int(i))
            if _k not in self._moduleDict:
                continue
            if method == 'del':
                # self._moduleDict[_k] = None # should behave the same either way
                del self._moduleDict[_k]
                # Keep the id list in sync, so keys() and the check in
                # forward() no longer advertise a removed task.
                self._keys.remove(int(i))
            elif method == 'cpu':
                self._moduleDict[_k].to('cpu')
            else:
                # Was a bare ``raise``, which fails with an unrelated
                # "No active exception to reraise" error.
                raise ValueError(f"Unknown offload method {method!r}; expected 'del' or 'cpu'")
54
+
55
class TaskSpecific_MoE(nn.Module):
    """Holds one module per task and runs the one for ``global_.task``."""

    def __init__(
        self,
        module: nn.Module,  # or list of Module
        tasks: tuple,
    ):
        """Create the per-task modules.

        Args:
            module: A single module, deep-copied once per task, or a list
                holding one module per task.
            tasks: Task ids; used as keys of the per-task container.

        Raises:
            ValueError: If ``module`` is neither a module nor a list.
        """
        super().__init__()
        self.cur_task = None
        self.tasks = tasks
        if isinstance(module, nn.Module):
            per_task = [copy.deepcopy(module) for _ in self.tasks]
        elif isinstance(module, list):
            assert len(module) == len(self.tasks), f"got {len(module)} and {len(self.tasks)}"
            per_task = module
        else:
            raise ValueError(f"got {type(module)}")
        self.tasks_2_module = ModuleDict_W(per_task, self.tasks)

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Forward all arguments to the module of the current task."""
        # The task comes from the shared ``global_`` module, not from
        # ``self.cur_task``.
        cur_task = global_.task
        assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}"
        selected = self.tasks_2_module[cur_task]
        return selected(*args, **kwargs)

    def set_task(self, task):
        """Disabled: always fails the assertion below."""
        assert 0, 'set_task is disabled for now; update to gg.task instead'
        # assert task in self.tasks, f"Task {task} not in available tasks {self.tasks}"
        self.cur_task = task
83
+
84
def is_task_specific_(name: str):
    """Return True if ``name`` contains any task-specific marker substring."""
    markers = (
        '._moduleDict.',
        'tasks_2_module',
        'task_ffn',
        'task_proj',
        'task_conv',
        'task_gate_mlps',
        'task_lora',
        'encoder_clip_',
        'proj_out_source__',
        'ID_proj_out',
        'landmark_proj_out',
        'learnable_vector',
    )
    return any(marker in name for marker in markers)
101
def tp_param_need_sync(name: str, p: torch.nn.Parameter):
    """Classify a named parameter as a pair of booleans.

    Returns:
        ``(False, True)`` for task-specific names, ``(False, False)`` for
        names containing one of the excluded sub-model markers or for
        parameters with ``requires_grad`` off, and ``(True, False)``
        otherwise. The first flag presumably means "needs sync" and the
        second "is task-specific" — confirm against the caller.
    """
    if is_task_specific_(name):
        return False, True
    excluded_markers = (
        'first_stage_model',
        'face_ID_model',
        'encoder_clip_face.tokenizer',
        'encoder_clip_face.model',
    )
    if any(marker in name for marker in excluded_markers) or not p.requires_grad:
        return False, False
    return True, False
109
def offload_unused_tasks(parent: nn.Module, active_task: int, method: str, ):
    """Recursively offload the per-task modules of every inactive task.

    Children whose class name is one of the known per-task wrappers have
    their task containers processed; all other children are recursed into.

    Args:
        parent: Module whose children are walked.
        active_task: Task id to keep; every other id in ``TASKS`` is
            offloaded.
        method: ``'del'`` to drop the modules, ``'cpu'`` to move them to CPU.
    """
    inactive = [_t for _t in TASKS if _t != active_task]  # inactive tasks
    wrapper_class_names = (
        'TaskSpecific_MoE',
        'FFN_TaskSpecific_Plus_Shared',
        'Linear_TaskSpecific_Plus_Shared',
        'Conv_TaskSpecific_Plus_Shared',
        'FFN_Shared_Plus_TaskLoRA',
        'Linear_Shared_Plus_TaskLoRA',
        'Conv_Shared_Plus_TaskLoRA',
    )
    container_attrs = (  # attributes that may hold the per-task modules
        'tasks_2_module',
        'task_ffn', 'task_proj', 'task_conv',
        'task_lora_in', 'task_lora_out', 'task_lora',
    )
    for _child_name, child in parent.named_children():
        is_wrapper = hasattr(child, '__class__') and child.__class__.__name__ in wrapper_class_names
        if not is_wrapper:
            offload_unused_tasks(child, active_task, method)
            continue
        for attr_name in container_attrs:
            if not hasattr(child, attr_name):
                continue
            container = getattr(child, attr_name)
            if isinstance(container, nn.ModuleList):
                for idx in inactive:  # move or delete parameters for inactive tasks
                    if method == 'del':
                        container[idx] = None
                    elif method == 'cpu':
                        container[idx].to('cpu')
                    else:
                        raise Exception
            elif isinstance(container, ModuleDict_W):
                container.offload_unused_tasks(inactive, method)
139
def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str, ):
    """Offload every task except ``task_keep`` from ``modelMOE``.

    Removes or offloads inactive task-related parameters to save CUDA
    memory (method: del|cpu).
    """
    offload_unused_tasks(parent=modelMOE, active_task=task_keep, method=method)
Other_dependencies/arcface/add.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Add arcface model
Other_dependencies/arcface/model_ir_se50.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a035c768259b98ab1ce0e646312f48b9e1e218197a0f80ac6765e88f8b6ddf28
3
+ size 175367323
Other_dependencies/face_parsing/79999_iter.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
3
+ size 53289463
Other_dependencies/face_parsing/add.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Add face parsing model
Other_dependencies/mp_models/blaze_face_short_range.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
3
+ size 229746
Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
3
+ size 3758596
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Space demo for UniBioTransfer.
3
+ Gradio interface for face/hair/motion/head transfer.
4
+
5
+ ZeroGPU Compatible:
6
+ - Model initialized on CPU (no GPU memory during startup)
7
+ - Inference wrapped with @spaces.GPU decorator
8
+ - Thread-safe global variable access with Lock
9
+ """
10
+
11
+ import threading
12
+ import torch
13
+ from PIL import Image
14
+ import numpy as np
15
+
16
# ==========================================
# Compatibility layer: local testing vs. the HF ZeroGPU environment
# ==========================================
try:
    import spaces
    print("Detected spaces library (Hugging Face environment).")
except ImportError:
    print("Local environment detected. Mocking spaces.GPU...")

    class spaces:
        """Minimal local stand-in for the ``spaces`` package."""

        @staticmethod
        def GPU(func=None, **_options):
            """No-op decorator that returns the decorated function unchanged.

            Works both bare (``@spaces.GPU``) and parameterized
            (``@spaces.GPU(duration=...)``); any options are ignored.
            """
            if func is None:
                # Called with options only: hand back a pass-through decorator.
                return lambda wrapped: wrapped
            return func
28
+
29
+ from infer_hf import UniBioTransferPipeline
30
+
31
# Lock and the global singleton pipeline: the lock is taken around GPU
# inference in ``inference``; the pipeline is created lazily by ``get_pipeline``.
inference_lock = threading.Lock()
global_pipeline :UniBioTransferPipeline = None
34
+
35
+
36
def get_pipeline(task):
    """Return the process-wide pipeline, creating it on first use.

    Singleton pattern: the model is initialized only once, with the device
    fixed to CPU so that global initialization under ZeroGPU does not touch
    the GPU. Later calls only switch the task on the existing pipeline.
    """
    global global_pipeline
    if global_pipeline is not None:
        # The model is already in memory: switching the task id is enough.
        print(f"Switching existing pipeline to task: {task}")
        global_pipeline.set_task(task)
        return global_pipeline

    print("Initializing pipeline once on CPU...")
    # The device is hard-coded to CPU on purpose (see docstring).
    global_pipeline = UniBioTransferPipeline.from_pretrained(
        repo_id="scy639/UniBioTransfer",
        task=task,
        device="cpu",
    )
    return global_pipeline
55
+
56
+
57
# Core: all forward-inference logic that uses the GPU is wrapped here.
@spaces.GPU
def run_gpu_inference(pipeline:UniBioTransferPipeline, tgt_pil, ref_pil, ddim_steps, scale, seed, num_images):
    """Run the pipeline on one target/reference pair and return its result.

    Under ZeroGPU this decorated function is where GPU compute is
    allocated. On a local machine the decorator is a no-op and the
    pipeline call runs directly.
    """
    sampling_options = dict(
        ddim_steps=ddim_steps,
        scale=scale,
        seed=seed,
        num_images=num_images,
    )
    return pipeline(tgt_pil, ref_pil, **sampling_options)
72
+
73
+
74
def inference(task, tgt_img, ref_img, ddim_steps, seed, num_images):
    """Run one transfer for the demo and report the outcome.

    Args:
        task: Task name selected in the UI.
        tgt_img: Target image as an array, or None if not uploaded.
        ref_img: Reference image as an array, or None if not uploaded.
        ddim_steps: Number of DDIM steps (cast to int).
        seed: Random seed (cast to int).
        num_images: Number of images to generate (cast to int).

    Returns:
        ``(results, status_message)``; ``results`` is None when an input is
        missing or when inference failed.
    """
    if tgt_img is None or ref_img is None:
        return None, "Please upload both target and reference images."

    cfg_scale = float(3)  # guidance scale is fixed; the UI slider is disabled

    try:
        tgt_pil = Image.fromarray(tgt_img).convert("RGB")
        ref_pil = Image.fromarray(ref_img).convert("RGB")

        # Fetching the pipeline may switch the shared task, so it must
        # happen under the same lock as the GPU inference. Otherwise a
        # concurrent request could change the task while another request
        # is still running.
        with inference_lock:
            pipeline = get_pipeline(task)  # the model lives on CPU here
            results = run_gpu_inference(
                pipeline,
                tgt_pil,
                ref_pil,
                int(ddim_steps),
                cfg_scale,
                int(seed),
                int(num_images)
            )

        return results, f"Success! Task: {task} transfer completed."

    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(f"{error_msg}")
        return None, error_msg
107
+
108
+
109
+ def create_demo():
110
+ """Create Gradio demo interface."""
111
+ import gradio as gr
112
+
113
+ with gr.Blocks(title="UniBioTransfer") as demo:
114
+ gr.Markdown(
115
+ """
116
+ # UniBioTransfer
117
+
118
+ Perform face transfer, hair transfer, motion transfer (face reenactment), and head transfer.
119
+
120
+ - **Face Transfer**: Transfer face identity from reference to target
121
+ - **Hair Transfer**: Transfer hairstyle from reference to target
122
+ - **Motion Transfer**: Transfer motion(expression+head pose) from reference to target
123
+ - **Head Transfer**: Transfer entire head from reference to target
124
+
125
+ [Code](https://github.com/scy639/UniBioTransfer)
126
+ [Project Page](https://scy639.github.io/UniBioTransfer.github.io/)
127
+ [Paper](https://arxiv.org/abs/2603.19637)
128
+ """
129
+ )
130
+
131
+ with gr.Row():
132
+ with gr.Column():
133
+ task_dropdown = gr.Dropdown(
134
+ choices=["face", "hair", "motion", "head"],
135
+ value="face",
136
+ label="Task",
137
+ info="Select the transfer type",
138
+ )
139
+
140
+ with gr.Row():
141
+ tgt_image = gr.Image(
142
+ label="Target Image",
143
+ type="numpy",
144
+ height=300,
145
+ )
146
+ ref_image = gr.Image(
147
+ label="Reference Image",
148
+ type="numpy",
149
+ height=300,
150
+ )
151
+
152
+ with gr.Row():
153
+ ddim_steps = gr.Slider(
154
+ minimum=4,
155
+ maximum=50,
156
+ value=50,
157
+ step=1,
158
+ label="DDIM Steps",
159
+ info="More steps = better quality but slower",
160
+ )
161
+ # scale = gr.Slider(
162
+ # minimum=1.0,
163
+ # maximum=10.0,
164
+ # value=3.0,
165
+ # step=0.5,
166
+ # label="CFG Scale",
167
+ # info="Guidance scale for conditioning",
168
+ # )
169
+
170
+ seed = gr.Number(
171
+ value=42,
172
+ label="Random Seed",
173
+ info="For reproducibility",
174
+ )
175
+
176
+ num_images = gr.Slider(
177
+ minimum=1,
178
+ maximum=32,
179
+ value=4,
180
+ step=1,
181
+ label="Number of output images",
182
+ info="Multi-output with different initial noise",
183
+ )
184
+
185
+ run_btn = gr.Button("Run Inference", variant="primary")
186
+
187
+ with gr.Column():
188
+ output_gallery = gr.Gallery(
189
+ label="Results",
190
+ height=800,
191
+ columns=2,
192
+ )
193
+ status_text = gr.Textbox(
194
+ label="Status",
195
+ lines=3,
196
+ )
197
+
198
+ gr.Markdown(
199
+ """
200
+ ### Usage
201
+ 1. Upload a **target image** (the person whose face/hair/motion/head will be modified)
202
+ 2. Upload a **reference image** (the source of the attribute to transfer)
203
+ 3. Select the **task** type
204
+ 4. Click "Run Inference"
205
+
206
+ ### Requirements
207
+ - Works best when the heads in the two input images have similar sizes.
208
+ """
209
+ )
210
+
211
+ run_btn.click(
212
+ fn=inference,
213
+ inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images],
214
+ outputs=[output_gallery, status_text],
215
+ )
216
+
217
+ task_dropdown.change(
218
+ fn=lambda t: f"Task switched to: {t} transfer",
219
+ inputs=[task_dropdown],
220
+ outputs=[status_text],
221
+ )
222
+
223
+ gr.Examples(
224
+ examples=[
225
+ ["face", "examples/face/tgt.png", "examples/face/ref.png", 20, 42, 4],
226
+ ["hair", "examples/hair/tgt.png", "examples/hair/ref.png", 20, 42, 4],
227
+ ["motion", "examples/motion/tgt.png", "examples/motion/ref.png", 20, 42, 4],
228
+ ["head", "examples/head/tgt.png", "examples/head/ref.png", 20, 42, 4],
229
+ ],
230
+ inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images],
231
+ label="Examples",
232
+ )
233
+
234
+ return demo
235
+
236
+
237
# Script entry point: build the Gradio interface and start serving it.
if __name__ == "__main__":
    create_demo().launch()
checkpoints/pretrained.json ADDED
@@ -0,0 +1,1072 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ ".model.diffusion_model.input_blocks.0.0": [
3
+ 4,
4
+ 4,
5
+ 4,
6
+ 4
7
+ ],
8
+ ".model.diffusion_model.input_blocks.1.0.in_layers.2": [
9
+ 5,
10
+ 4,
11
+ 8,
12
+ 4
13
+ ],
14
+ ".model.diffusion_model.input_blocks.1.0.out_layers.3": [
15
+ 7,
16
+ 4,
17
+ 12,
18
+ 4
19
+ ],
20
+ ".model.diffusion_model.input_blocks.1.1.proj_in": [
21
+ 4,
22
+ 4,
23
+ 6,
24
+ 4
25
+ ],
26
+ ".model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff": [
27
+ [
28
+ 5,
29
+ 4,
30
+ 8,
31
+ 4
32
+ ],
33
+ [
34
+ 7,
35
+ 4,
36
+ 12,
37
+ 4
38
+ ]
39
+ ],
40
+ ".model.diffusion_model.input_blocks.1.1.proj_out": [
41
+ 4,
42
+ 4,
43
+ 8,
44
+ 4
45
+ ],
46
+ ".model.diffusion_model.input_blocks.2.0.in_layers.2": [
47
+ 14,
48
+ 5,
49
+ 19,
50
+ 4
51
+ ],
52
+ ".model.diffusion_model.input_blocks.2.0.out_layers.3": [
53
+ 16,
54
+ 4,
55
+ 15,
56
+ 4
57
+ ],
58
+ ".model.diffusion_model.input_blocks.2.1.proj_in": [
59
+ 9,
60
+ 4,
61
+ 11,
62
+ 4
63
+ ],
64
+ ".model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff": [
65
+ [
66
+ 16,
67
+ 4,
68
+ 14,
69
+ 4
70
+ ],
71
+ [
72
+ 17,
73
+ 4,
74
+ 14,
75
+ 4
76
+ ]
77
+ ],
78
+ ".model.diffusion_model.input_blocks.2.1.proj_out": [
79
+ 13,
80
+ 4,
81
+ 11,
82
+ 4
83
+ ],
84
+ ".model.diffusion_model.input_blocks.3.0.op": [
85
+ 26,
86
+ 7,
87
+ 31,
88
+ 8
89
+ ],
90
+ ".model.diffusion_model.input_blocks.4.0.in_layers.2": [
91
+ 23,
92
+ 6,
93
+ 31,
94
+ 8
95
+ ],
96
+ ".model.diffusion_model.input_blocks.4.0.out_layers.3": [
97
+ 27,
98
+ 6,
99
+ 37,
100
+ 8
101
+ ],
102
+ ".model.diffusion_model.input_blocks.4.0.skip_connection": [
103
+ 20,
104
+ 6,
105
+ 22,
106
+ 6
107
+ ],
108
+ ".model.diffusion_model.input_blocks.4.1.proj_in": [
109
+ 20,
110
+ 6,
111
+ 28,
112
+ 7
113
+ ],
114
+ ".model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff": [
115
+ [
116
+ 22,
117
+ 6,
118
+ 37,
119
+ 8
120
+ ],
121
+ [
122
+ 31,
123
+ 8,
124
+ 39,
125
+ 10
126
+ ]
127
+ ],
128
+ ".model.diffusion_model.input_blocks.4.1.proj_out": [
129
+ 26,
130
+ 8,
131
+ 37,
132
+ 10
133
+ ],
134
+ ".model.diffusion_model.input_blocks.5.0.in_layers.2": [
135
+ 27,
136
+ 10,
137
+ 46,
138
+ 11
139
+ ],
140
+ ".model.diffusion_model.input_blocks.5.0.out_layers.3": [
141
+ 18,
142
+ 6,
143
+ 36,
144
+ 7
145
+ ],
146
+ ".model.diffusion_model.input_blocks.5.1.proj_in": [
147
+ 20,
148
+ 7,
149
+ 29,
150
+ 7
151
+ ],
152
+ ".model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff": [
153
+ [
154
+ 22,
155
+ 7,
156
+ 41,
157
+ 9
158
+ ],
159
+ [
160
+ 26,
161
+ 10,
162
+ 33,
163
+ 12
164
+ ]
165
+ ],
166
+ ".model.diffusion_model.input_blocks.5.1.proj_out": [
167
+ 24,
168
+ 9,
169
+ 33,
170
+ 10
171
+ ],
172
+ ".model.diffusion_model.input_blocks.6.0.op": [
173
+ 52,
174
+ 17,
175
+ 76,
176
+ 20
177
+ ],
178
+ ".model.diffusion_model.input_blocks.7.0.in_layers.2": [
179
+ 50,
180
+ 14,
181
+ 80,
182
+ 19
183
+ ],
184
+ ".model.diffusion_model.input_blocks.7.0.out_layers.3": [
185
+ 56,
186
+ 15,
187
+ 90,
188
+ 22
189
+ ],
190
+ ".model.diffusion_model.input_blocks.7.0.skip_connection": [
191
+ 40,
192
+ 13,
193
+ 59,
194
+ 16
195
+ ],
196
+ ".model.diffusion_model.input_blocks.7.1.proj_in": [
197
+ 33,
198
+ 12,
199
+ 55,
200
+ 14
201
+ ],
202
+ ".model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff": [
203
+ [
204
+ 39,
205
+ 11,
206
+ 62,
207
+ 13
208
+ ],
209
+ [
210
+ 59,
211
+ 17,
212
+ 82,
213
+ 21
214
+ ]
215
+ ],
216
+ ".model.diffusion_model.input_blocks.7.1.proj_out": [
217
+ 55,
218
+ 17,
219
+ 80,
220
+ 22
221
+ ],
222
+ ".model.diffusion_model.input_blocks.8.0.in_layers.2": [
223
+ 73,
224
+ 20,
225
+ 108,
226
+ 27
227
+ ],
228
+ ".model.diffusion_model.input_blocks.8.0.out_layers.3": [
229
+ 65,
230
+ 15,
231
+ 95,
232
+ 21
233
+ ],
234
+ ".model.diffusion_model.input_blocks.8.1.proj_in": [
235
+ 43,
236
+ 13,
237
+ 69,
238
+ 18
239
+ ],
240
+ ".model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff": [
241
+ [
242
+ 41,
243
+ 10,
244
+ 68,
245
+ 13
246
+ ],
247
+ [
248
+ 56,
249
+ 17,
250
+ 85,
251
+ 21
252
+ ]
253
+ ],
254
+ ".model.diffusion_model.input_blocks.8.1.proj_out": [
255
+ 52,
256
+ 16,
257
+ 78,
258
+ 20
259
+ ],
260
+ ".model.diffusion_model.input_blocks.9.0.op": [
261
+ 90,
262
+ 30,
263
+ 157,
264
+ 39
265
+ ],
266
+ ".model.diffusion_model.input_blocks.10.0.in_layers.2": [
267
+ 81,
268
+ 21,
269
+ 113,
270
+ 26
271
+ ],
272
+ ".model.diffusion_model.input_blocks.10.0.out_layers.3": [
273
+ 80,
274
+ 21,
275
+ 123,
276
+ 28
277
+ ],
278
+ ".model.diffusion_model.input_blocks.11.0.in_layers.2": [
279
+ 87,
280
+ 23,
281
+ 118,
282
+ 28
283
+ ],
284
+ ".model.diffusion_model.input_blocks.11.0.out_layers.3": [
285
+ 77,
286
+ 20,
287
+ 113,
288
+ 26
289
+ ],
290
+ ".model.diffusion_model.middle_block.0.in_layers.2": [
291
+ 84,
292
+ 22,
293
+ 113,
294
+ 26
295
+ ],
296
+ ".model.diffusion_model.middle_block.0.out_layers.3": [
297
+ 68,
298
+ 16,
299
+ 99,
300
+ 21
301
+ ],
302
+ ".model.diffusion_model.middle_block.1.proj_in": [
303
+ 36,
304
+ 10,
305
+ 59,
306
+ 13
307
+ ],
308
+ ".model.diffusion_model.middle_block.1.transformer_blocks.0.ff": [
309
+ [
310
+ 31,
311
+ 5,
312
+ 45,
313
+ 6
314
+ ],
315
+ [
316
+ 55,
317
+ 15,
318
+ 69,
319
+ 17
320
+ ]
321
+ ],
322
+ ".model.diffusion_model.middle_block.1.proj_out": [
323
+ 39,
324
+ 10,
325
+ 61,
326
+ 14
327
+ ],
328
+ ".model.diffusion_model.middle_block.2.in_layers.2": [
329
+ 73,
330
+ 17,
331
+ 104,
332
+ 23
333
+ ],
334
+ ".model.diffusion_model.middle_block.2.out_layers.3": [
335
+ 62,
336
+ 15,
337
+ 88,
338
+ 20
339
+ ],
340
+ ".model.diffusion_model.output_blocks.0.0.in_layers.2": [
341
+ 96,
342
+ 25,
343
+ 135,
344
+ 32
345
+ ],
346
+ ".model.diffusion_model.output_blocks.0.0.out_layers.3": [
347
+ 86,
348
+ 21,
349
+ 120,
350
+ 28
351
+ ],
352
+ ".model.diffusion_model.output_blocks.0.0.skip_connection": [
353
+ 64,
354
+ 21,
355
+ 106,
356
+ 27
357
+ ],
358
+ ".model.diffusion_model.output_blocks.1.0.in_layers.2": [
359
+ 94,
360
+ 27,
361
+ 155,
362
+ 36
363
+ ],
364
+ ".model.diffusion_model.output_blocks.1.0.out_layers.3": [
365
+ 86,
366
+ 24,
367
+ 136,
368
+ 31
369
+ ],
370
+ ".model.diffusion_model.output_blocks.1.0.skip_connection": [
371
+ 72,
372
+ 23,
373
+ 115,
374
+ 29
375
+ ],
376
+ ".model.diffusion_model.output_blocks.2.0.in_layers.2": [
377
+ 84,
378
+ 31,
379
+ 164,
380
+ 39
381
+ ],
382
+ ".model.diffusion_model.output_blocks.2.0.out_layers.3": [
383
+ 42,
384
+ 19,
385
+ 123,
386
+ 29
387
+ ],
388
+ ".model.diffusion_model.output_blocks.2.0.skip_connection": [
389
+ 72,
390
+ 24,
391
+ 110,
392
+ 28
393
+ ],
394
+ ".model.diffusion_model.output_blocks.2.1.conv": [
395
+ 72,
396
+ 25,
397
+ 121,
398
+ 29
399
+ ],
400
+ ".model.diffusion_model.output_blocks.3.0.in_layers.2": [
401
+ 85,
402
+ 31,
403
+ 158,
404
+ 38
405
+ ],
406
+ ".model.diffusion_model.output_blocks.3.0.out_layers.3": [
407
+ 42,
408
+ 21,
409
+ 117,
410
+ 25
411
+ ],
412
+ ".model.diffusion_model.output_blocks.3.0.skip_connection": [
413
+ 71,
414
+ 23,
415
+ 111,
416
+ 28
417
+ ],
418
+ ".model.diffusion_model.output_blocks.3.1.proj_in": [
419
+ 42,
420
+ 14,
421
+ 73,
422
+ 18
423
+ ],
424
+ ".model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff": [
425
+ [
426
+ 37,
427
+ 10,
428
+ 68,
429
+ 13
430
+ ],
431
+ [
432
+ 60,
433
+ 18,
434
+ 83,
435
+ 20
436
+ ]
437
+ ],
438
+ ".model.diffusion_model.output_blocks.3.1.proj_out": [
439
+ 51,
440
+ 18,
441
+ 79,
442
+ 21
443
+ ],
444
+ ".model.diffusion_model.output_blocks.4.0.in_layers.2": [
445
+ 104,
446
+ 32,
447
+ 159,
448
+ 40
449
+ ],
450
+ ".model.diffusion_model.output_blocks.4.0.out_layers.3": [
451
+ 83,
452
+ 24,
453
+ 125,
454
+ 29
455
+ ],
456
+ ".model.diffusion_model.output_blocks.4.0.skip_connection": [
457
+ 73,
458
+ 22,
459
+ 101,
460
+ 28
461
+ ],
462
+ ".model.diffusion_model.output_blocks.4.1.proj_in": [
463
+ 49,
464
+ 15,
465
+ 77,
466
+ 20
467
+ ],
468
+ ".model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff": [
469
+ [
470
+ 38,
471
+ 11,
472
+ 70,
473
+ 14
474
+ ],
475
+ [
476
+ 63,
477
+ 16,
478
+ 85,
479
+ 20
480
+ ]
481
+ ],
482
+ ".model.diffusion_model.output_blocks.4.1.proj_out": [
483
+ 51,
484
+ 18,
485
+ 81,
486
+ 21
487
+ ],
488
+ ".model.diffusion_model.output_blocks.5.0.in_layers.2": [
489
+ 91,
490
+ 33,
491
+ 161,
492
+ 40
493
+ ],
494
+ ".model.diffusion_model.output_blocks.5.0.out_layers.3": [
495
+ 83,
496
+ 26,
497
+ 140,
498
+ 32
499
+ ],
500
+ ".model.diffusion_model.output_blocks.5.0.skip_connection": [
501
+ 81,
502
+ 24,
503
+ 116,
504
+ 30
505
+ ],
506
+ ".model.diffusion_model.output_blocks.5.1.proj_in": [
507
+ 48,
508
+ 16,
509
+ 82,
510
+ 21
511
+ ],
512
+ ".model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff": [
513
+ [
514
+ 34,
515
+ 12,
516
+ 76,
517
+ 15
518
+ ],
519
+ [
520
+ 55,
521
+ 16,
522
+ 81,
523
+ 18
524
+ ]
525
+ ],
526
+ ".model.diffusion_model.output_blocks.5.1.proj_out": [
527
+ 57,
528
+ 19,
529
+ 85,
530
+ 22
531
+ ],
532
+ ".model.diffusion_model.output_blocks.5.2.conv": [
533
+ 108,
534
+ 34,
535
+ 159,
536
+ 41
537
+ ],
538
+ ".model.diffusion_model.output_blocks.6.0.in_layers.2": [
539
+ 55,
540
+ 18,
541
+ 87,
542
+ 22
543
+ ],
544
+ ".model.diffusion_model.output_blocks.6.0.out_layers.3": [
545
+ 32,
546
+ 13,
547
+ 54,
548
+ 15
549
+ ],
550
+ ".model.diffusion_model.output_blocks.6.0.skip_connection": [
551
+ 25,
552
+ 9,
553
+ 30,
554
+ 14
555
+ ],
556
+ ".model.diffusion_model.output_blocks.6.1.proj_in": [
557
+ 26,
558
+ 9,
559
+ 40,
560
+ 11
561
+ ],
562
+ ".model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff": [
563
+ [
564
+ 25,
565
+ 8,
566
+ 47,
567
+ 12
568
+ ],
569
+ [
570
+ 36,
571
+ 11,
572
+ 47,
573
+ 13
574
+ ]
575
+ ],
576
+ ".model.diffusion_model.output_blocks.6.1.proj_out": [
577
+ 23,
578
+ 10,
579
+ 38,
580
+ 12
581
+ ],
582
+ ".model.diffusion_model.output_blocks.7.0.in_layers.2": [
583
+ 55,
584
+ 18,
585
+ 82,
586
+ 20
587
+ ],
588
+ ".model.diffusion_model.output_blocks.7.0.out_layers.3": [
589
+ 47,
590
+ 14,
591
+ 65,
592
+ 17
593
+ ],
594
+ ".model.diffusion_model.output_blocks.7.0.skip_connection": [
595
+ 40,
596
+ 11,
597
+ 40,
598
+ 12
599
+ ],
600
+ ".model.diffusion_model.output_blocks.7.1.proj_in": [
601
+ 27,
602
+ 9,
603
+ 41,
604
+ 11
605
+ ],
606
+ ".model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff": [
607
+ [
608
+ 27,
609
+ 8,
610
+ 47,
611
+ 11
612
+ ],
613
+ [
614
+ 34,
615
+ 11,
616
+ 47,
617
+ 12
618
+ ]
619
+ ],
620
+ ".model.diffusion_model.output_blocks.7.1.proj_out": [
621
+ 33,
622
+ 9,
623
+ 39,
624
+ 12
625
+ ],
626
+ ".model.diffusion_model.output_blocks.8.0.in_layers.2": [
627
+ 58,
628
+ 17,
629
+ 82,
630
+ 20
631
+ ],
632
+ ".model.diffusion_model.output_blocks.8.0.out_layers.3": [
633
+ 56,
634
+ 15,
635
+ 75,
636
+ 18
637
+ ],
638
+ ".model.diffusion_model.output_blocks.8.0.skip_connection": [
639
+ 44,
640
+ 10,
641
+ 47,
642
+ 11
643
+ ],
644
+ ".model.diffusion_model.output_blocks.8.1.proj_in": [
645
+ 32,
646
+ 9,
647
+ 43,
648
+ 10
649
+ ],
650
+ ".model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff": [
651
+ [
652
+ 28,
653
+ 7,
654
+ 47,
655
+ 8
656
+ ],
657
+ [
658
+ 35,
659
+ 8,
660
+ 45,
661
+ 8
662
+ ]
663
+ ],
664
+ ".model.diffusion_model.output_blocks.8.1.proj_out": [
665
+ 35,
666
+ 10,
667
+ 44,
668
+ 10
669
+ ],
670
+ ".model.diffusion_model.output_blocks.8.2.conv": [
671
+ 65,
672
+ 19,
673
+ 85,
674
+ 22
675
+ ],
676
+ ".model.diffusion_model.output_blocks.9.0.in_layers.2": [
677
+ 37,
678
+ 10,
679
+ 35,
680
+ 10
681
+ ],
682
+ ".model.diffusion_model.output_blocks.9.0.out_layers.3": [
683
+ 28,
684
+ 6,
685
+ 23,
686
+ 5
687
+ ],
688
+ ".model.diffusion_model.output_blocks.9.0.skip_connection": [
689
+ 15,
690
+ 4,
691
+ 4,
692
+ 4
693
+ ],
694
+ ".model.diffusion_model.output_blocks.9.1.proj_in": [
695
+ 16,
696
+ 4,
697
+ 6,
698
+ 4
699
+ ],
700
+ ".model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff": [
701
+ [
702
+ 24,
703
+ 5,
704
+ 23,
705
+ 5
706
+ ],
707
+ [
708
+ 23,
709
+ 5,
710
+ 24,
711
+ 6
712
+ ]
713
+ ],
714
+ ".model.diffusion_model.output_blocks.9.1.proj_out": [
715
+ 16,
716
+ 4,
717
+ 14,
718
+ 4
719
+ ],
720
+ ".model.diffusion_model.output_blocks.10.0.in_layers.2": [
721
+ 31,
722
+ 9,
723
+ 38,
724
+ 10
725
+ ],
726
+ ".model.diffusion_model.output_blocks.10.0.out_layers.3": [
727
+ 20,
728
+ 4,
729
+ 24,
730
+ 4
731
+ ],
732
+ ".model.diffusion_model.output_blocks.10.0.skip_connection": [
733
+ 4,
734
+ 4,
735
+ 7,
736
+ 4
737
+ ],
738
+ ".model.diffusion_model.output_blocks.10.1.proj_in": [
739
+ 6,
740
+ 4,
741
+ 11,
742
+ 4
743
+ ],
744
+ ".model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff": [
745
+ [
746
+ 17,
747
+ 4,
748
+ 21,
749
+ 4
750
+ ],
751
+ [
752
+ 17,
753
+ 5,
754
+ 21,
755
+ 5
756
+ ]
757
+ ],
758
+ ".model.diffusion_model.output_blocks.10.1.proj_out": [
759
+ 9,
760
+ 4,
761
+ 12,
762
+ 4
763
+ ],
764
+ ".model.diffusion_model.output_blocks.11.0.in_layers.2": [
765
+ 7,
766
+ 4,
767
+ 18,
768
+ 4
769
+ ],
770
+ ".model.diffusion_model.output_blocks.11.0.out_layers.3": [
771
+ 16,
772
+ 6,
773
+ 22,
774
+ 5
775
+ ],
776
+ ".model.diffusion_model.output_blocks.11.0.skip_connection": [
777
+ 4,
778
+ 4,
779
+ 4,
780
+ 4
781
+ ],
782
+ ".model.diffusion_model.output_blocks.11.1.proj_in": [
783
+ 9,
784
+ 4,
785
+ 13,
786
+ 4
787
+ ],
788
+ ".model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff": [
789
+ [
790
+ 19,
791
+ 4,
792
+ 24,
793
+ 4
794
+ ],
795
+ [
796
+ 12,
797
+ 4,
798
+ 14,
799
+ 4
800
+ ]
801
+ ],
802
+ ".model.diffusion_model.output_blocks.11.1.proj_out": [
803
+ 7,
804
+ 4,
805
+ 10,
806
+ 4
807
+ ],
808
+ ".model.diffusion_model.out.2": [
809
+ 4,
810
+ 4,
811
+ 4,
812
+ 4
813
+ ],
814
+ ".model.diffusion_model_refNet.input_blocks.0.0": [
815
+ 4,
816
+ 4,
817
+ 4,
818
+ 4
819
+ ],
820
+ ".model.diffusion_model_refNet.input_blocks.1.0.in_layers.2": [
821
+ 17,
822
+ 8,
823
+ 26,
824
+ 8
825
+ ],
826
+ ".model.diffusion_model_refNet.input_blocks.1.0.out_layers.3": [
827
+ 21,
828
+ 14,
829
+ 37,
830
+ 12
831
+ ],
832
+ ".model.diffusion_model_refNet.input_blocks.1.1.proj_in": [
833
+ 11,
834
+ 8,
835
+ 19,
836
+ 6
837
+ ],
838
+ ".model.diffusion_model_refNet.input_blocks.1.1.transformer_blocks.0.ff": [
839
+ [
840
+ 14,
841
+ 12,
842
+ 24,
843
+ 7
844
+ ],
845
+ [
846
+ 17,
847
+ 12,
848
+ 26,
849
+ 7
850
+ ]
851
+ ],
852
+ ".model.diffusion_model_refNet.input_blocks.1.1.proj_out": [
853
+ 11,
854
+ 7,
855
+ 20,
856
+ 5
857
+ ],
858
+ ".model.diffusion_model_refNet.input_blocks.2.0.in_layers.2": [
859
+ 27,
860
+ 15,
861
+ 40,
862
+ 13
863
+ ],
864
+ ".model.diffusion_model_refNet.input_blocks.2.0.out_layers.3": [
865
+ 26,
866
+ 15,
867
+ 38,
868
+ 12
869
+ ],
870
+ ".model.diffusion_model_refNet.input_blocks.2.1.proj_in": [
871
+ 15,
872
+ 7,
873
+ 21,
874
+ 6
875
+ ],
876
+ ".model.diffusion_model_refNet.input_blocks.2.1.transformer_blocks.0.ff": [
877
+ [
878
+ 17,
879
+ 13,
880
+ 30,
881
+ 9
882
+ ],
883
+ [
884
+ 16,
885
+ 12,
886
+ 27,
887
+ 8
888
+ ]
889
+ ],
890
+ ".model.diffusion_model_refNet.input_blocks.2.1.proj_out": [
891
+ 12,
892
+ 7,
893
+ 18,
894
+ 6
895
+ ],
896
+ ".model.diffusion_model_refNet.input_blocks.3.0.op": [
897
+ 27,
898
+ 13,
899
+ 43,
900
+ 12
901
+ ],
902
+ ".model.diffusion_model_refNet.input_blocks.4.0.in_layers.2": [
903
+ 30,
904
+ 19,
905
+ 49,
906
+ 14
907
+ ],
908
+ ".model.diffusion_model_refNet.input_blocks.4.0.out_layers.3": [
909
+ 32,
910
+ 26,
911
+ 55,
912
+ 15
913
+ ],
914
+ ".model.diffusion_model_refNet.input_blocks.4.0.skip_connection": [
915
+ 22,
916
+ 10,
917
+ 30,
918
+ 9
919
+ ],
920
+ ".model.diffusion_model_refNet.input_blocks.4.1.proj_in": [
921
+ 22,
922
+ 14,
923
+ 35,
924
+ 10
925
+ ],
926
+ ".model.diffusion_model_refNet.input_blocks.4.1.transformer_blocks.0.ff": [
927
+ [
928
+ 26,
929
+ 25,
930
+ 52,
931
+ 14
932
+ ],
933
+ [
934
+ 28,
935
+ 22,
936
+ 51,
937
+ 14
938
+ ]
939
+ ],
940
+ ".model.diffusion_model_refNet.input_blocks.4.1.proj_out": [
941
+ 24,
942
+ 15,
943
+ 40,
944
+ 11
945
+ ],
946
+ ".model.diffusion_model_refNet.input_blocks.5.0.in_layers.2": [
947
+ 44,
948
+ 30,
949
+ 78,
950
+ 22
951
+ ],
952
+ ".model.diffusion_model_refNet.input_blocks.5.0.out_layers.3": [
953
+ 28,
954
+ 29,
955
+ 56,
956
+ 15
957
+ ],
958
+ ".model.diffusion_model_refNet.input_blocks.5.1.proj_in": [
959
+ 20,
960
+ 13,
961
+ 34,
962
+ 9
963
+ ],
964
+ ".model.diffusion_model_refNet.input_blocks.5.1.transformer_blocks.0.ff": [
965
+ [
966
+ 26,
967
+ 27,
968
+ 52,
969
+ 14
970
+ ],
971
+ [
972
+ 23,
973
+ 23,
974
+ 53,
975
+ 14
976
+ ]
977
+ ],
978
+ ".model.diffusion_model_refNet.input_blocks.5.1.proj_out": [
979
+ 17,
980
+ 14,
981
+ 36,
982
+ 10
983
+ ],
984
+ ".model.diffusion_model_refNet.input_blocks.6.0.op": [
985
+ 46,
986
+ 31,
987
+ 82,
988
+ 21
989
+ ],
990
+ ".model.diffusion_model_refNet.input_blocks.7.0.in_layers.2": [
991
+ 75,
992
+ 41,
993
+ 116,
994
+ 32
995
+ ],
996
+ ".model.diffusion_model_refNet.input_blocks.7.0.out_layers.3": [
997
+ 67,
998
+ 50,
999
+ 108,
1000
+ 29
1001
+ ],
1002
+ ".model.diffusion_model_refNet.input_blocks.7.0.skip_connection": [
1003
+ 31,
1004
+ 19,
1005
+ 59,
1006
+ 15
1007
+ ],
1008
+ ".model.diffusion_model_refNet.input_blocks.7.1.proj_in": [
1009
+ 36,
1010
+ 29,
1011
+ 73,
1012
+ 19
1013
+ ],
1014
+ ".model.diffusion_model_refNet.input_blocks.7.1.transformer_blocks.0.ff": [
1015
+ [
1016
+ 74,
1017
+ 61,
1018
+ 106,
1019
+ 26
1020
+ ],
1021
+ [
1022
+ 63,
1023
+ 49,
1024
+ 90,
1025
+ 24
1026
+ ]
1027
+ ],
1028
+ ".model.diffusion_model_refNet.input_blocks.7.1.proj_out": [
1029
+ 34,
1030
+ 29,
1031
+ 68,
1032
+ 18
1033
+ ],
1034
+ ".model.diffusion_model_refNet.input_blocks.8.0.in_layers.2": [
1035
+ 92,
1036
+ 56,
1037
+ 128,
1038
+ 36
1039
+ ],
1040
+ ".model.diffusion_model_refNet.input_blocks.8.0.out_layers.3": [
1041
+ 43,
1042
+ 51,
1043
+ 66,
1044
+ 16
1045
+ ],
1046
+ ".model.diffusion_model_refNet.input_blocks.8.1.proj_in": [
1047
+ 26,
1048
+ 28,
1049
+ 59,
1050
+ 15
1051
+ ],
1052
+ ".model.diffusion_model_refNet.input_blocks.8.1.transformer_blocks.0.ff": [
1053
+ [
1054
+ 188,
1055
+ 69,
1056
+ 232,
1057
+ 69
1058
+ ],
1059
+ [
1060
+ 140,
1061
+ 51,
1062
+ 173,
1063
+ 51
1064
+ ]
1065
+ ],
1066
+ ".model.diffusion_model_refNet.input_blocks.8.1.proj_out": [
1067
+ 91,
1068
+ 33,
1069
+ 113,
1070
+ 33
1071
+ ]
1072
+ }
download_checkpoints.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import os
3
+ from imports import *
4
+
5
+
6
+
7
def _download(repo_id, filename, local_path: Path) -> Path:
    """Download one file from the Hugging Face Hub into ``local_path``'s directory.

    Args:
        repo_id: Hub repository to download from.
        filename: Path of the file inside that repository.
        local_path: Used only for its parent directory, which is created if
            missing and passed to ``hf_hub_download`` as ``local_dir``.

    Returns:
        The local path reported by ``hf_hub_download`` for the fetched file.
    """
    # Imported lazily so the dependency is only needed when a download runs.
    from huggingface_hub import hf_hub_download

    local_path = Path(local_path)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    # Either environment variable may carry the access token; None is passed
    # through when neither is set.
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    print(f"downloading to {local_path}")
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=str(local_path.parent),
        local_dir_use_symlinks=False,
        token=token,
    )
    # The signature promises a Path; previously the result was dropped and the
    # function implicitly returned None.
    return Path(downloaded)
20
+
21
+
22
+
23
# Base Stable Diffusion v1.4 weights come from the upstream CompVis repository.
_download("CompVis/stable-diffusion-v-1-4-original", SD14_filename, SD14_localpath)

# Everything else comes from the project repository and is fetched relative to
# the current working directory, in the order listed.
for _repo_file in (
    PRETRAIN_CKPT_PATH,
    PRETRAIN_JSON_PATH,
    "Other_dependencies/arcface/model_ir_se50.pth",
    "Other_dependencies/face_parsing/79999_iter.pth",
):
    _download("scy639/UniBioTransfer", _repo_file, ".")
eval_tool/lpips/__init__.py ADDED
File without changes
eval_tool/lpips/lpips.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from eval_tool.lpips.networks import get_network, LinLayers
5
+ from eval_tool.lpips.utils import get_state_dict
6
+
7
+
8
class LPIPS(nn.Module):
    r"""Criterion measuring Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): backbone whose features are compared:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """

    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
        # Checked before the module is initialised, as in the original.
        assert version in ['0.1'], 'v0.1 is only supported now'

        super().__init__()

        # Pretrained feature backbone.
        backbone = get_network(net_type)
        self.net = backbone

        # One linear head per tapped feature map, loaded from the published
        # LPIPS weights for this backbone/version.
        heads = LinLayers(backbone.n_channels_list)
        heads.load_state_dict(get_state_dict(net_type, version))
        self.lin = heads

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return the LPIPS distance summed over layers and divided by ``x.shape[0]``."""
        per_layer = []
        for head, fx, fy in zip(self.lin, self.net(x), self.net(y)):
            squared_diff = (fx - fy) ** 2
            # Average each head's response over dims 2 and 3, keeping dims.
            per_layer.append(head(squared_diff).mean((2, 3), True))

        return torch.sum(torch.cat(per_layer, 0)) / x.shape[0]
eval_tool/lpips/networks.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence
2
+
3
+ from itertools import chain
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from eval_tool.lpips.utils import normalize_activation
10
+
11
+
12
def get_network(net_type: str):
    """Build the feature backbone selected by ``net_type``.

    Accepts 'alex', 'squeeze' or 'vgg'; any other value raises
    ``NotImplementedError``.
    """
    # Constructors are wrapped in lambdas so that only the selected backbone
    # is ever instantiated.
    factories = (
        ('alex', lambda: AlexNet()),
        ('squeeze', lambda: SqueezeNet()),
        ('vgg', lambda: VGG16()),
    )
    for name, build in factories:
        if net_type == name:
            return build()
    raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21
+
22
+
23
class LinLayers(nn.ModuleList):
    """Frozen bias-free 1x1 convolutions, one per entry of ``n_channels_list``.

    Each element is ``Sequential(Identity, Conv2d)``, so the convolution sits
    at index 1 and its state-dict key is ``"<i>.1.weight"``.
    """

    def __init__(self, n_channels_list: Sequence[int]):
        heads = [
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, kernel_size=1, stride=1, padding=0, bias=False),
            )
            for nc in n_channels_list
        ]
        super().__init__(heads)

        # The heads are never trained here; freeze every weight.
        for weight in self.parameters():
            weight.requires_grad = False
34
+
35
+
36
class BaseNet(nn.Module):
    """Shared base for the feature backbones.

    Subclasses are expected to set ``self.layers`` (the feature layers),
    ``self.target_layers`` (1-based indices of the layers to tap) and
    ``self.n_channels_list``.
    """

    def __init__(self):
        super().__init__()

        # Per-channel shift and scale applied by z_score, stored as buffers
        # broadcastable over (N, 3, H, W).
        constants = (
            ('mean', [-.030, -.088, -.188]),
            ('std', [.458, .448, .450]),
        )
        for buffer_name, values in constants:
            self.register_buffer(
                buffer_name, torch.Tensor(values)[None, :, None, None])

    def set_requires_grad(self, state: bool):
        """Set ``requires_grad`` to ``state`` on every parameter and buffer."""
        for tensor in self.parameters():
            tensor.requires_grad = state
        for tensor in self.buffers():
            tensor.requires_grad = state

    def z_score(self, x: torch.Tensor):
        """Shift and scale ``x`` with the registered ``mean`` / ``std`` buffers."""
        centered = x - self.mean
        return centered / self.std

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the layers and return the normalized tapped activations."""
        x = self.z_score(x)

        wanted = len(self.target_layers)
        collected = []
        # Layer positions are counted from 1 to match target_layers.
        for position, layer in enumerate(self.layers._modules.values(), start=1):
            x = layer(x)
            if position in self.target_layers:
                collected.append(normalize_activation(x))
            # Stop as soon as every requested activation has been gathered.
            if len(collected) == wanted:
                break
        return collected
64
+
65
+
66
class SqueezeNet(BaseNet):
    """SqueezeNet 1.1 feature backbone with seven tapped layers."""

    def __init__(self):
        super().__init__()

        # (1-based layer index in `features`, channel count of that activation)
        taps = [
            (2, 64), (5, 128), (8, 256), (10, 384),
            (11, 384), (12, 512), (13, 512),
        ]

        # Positional True is torchvision's legacy `pretrained` flag.
        self.layers = models.squeezenet1_1(True).features
        self.target_layers = [index for index, _ in taps]
        self.n_channels_list = [channels for _, channels in taps]

        self.set_requires_grad(False)
75
+
76
+
77
class AlexNet(BaseNet):
    """AlexNet feature backbone with five tapped layers."""

    def __init__(self):
        super().__init__()

        # (1-based layer index in `features`, channel count of that activation)
        taps = [(2, 64), (5, 192), (8, 384), (10, 256), (12, 256)]

        # Positional True is torchvision's legacy `pretrained` flag.
        self.layers = models.alexnet(True).features
        self.target_layers = [index for index, _ in taps]
        self.n_channels_list = [channels for _, channels in taps]

        self.set_requires_grad(False)
86
+
87
+
88
class VGG16(BaseNet):
    """VGG-16 feature backbone with five tapped layers."""

    def __init__(self):
        super().__init__()

        # (1-based layer index in `features`, channel count of that activation)
        taps = [(4, 64), (9, 128), (16, 256), (23, 512), (30, 512)]

        # Positional True is torchvision's legacy `pretrained` flag.
        self.layers = models.vgg16(True).features
        self.target_layers = [index for index, _ in taps]
        self.n_channels_list = [channels for _, channels in taps]

        self.set_requires_grad(False)
eval_tool/lpips/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+
5
+
6
def normalize_activation(x, eps=1e-10):
    """Divide ``x`` by its L2 norm taken along dim 1.

    ``eps`` is added to the norm before dividing so an all-zero slice yields
    zeros instead of NaN.
    """
    squared_sum = torch.sum(x ** 2, dim=1, keepdim=True)
    # A small constant is also added under the root, so sqrt never sees 0.
    l2_norm = torch.sqrt(squared_sum + 1e-16)
    return x / (l2_norm + eps)
9
+
10
+
11
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    """Download the LPIPS linear-layer weights and rename their keys.

    The weights are fetched from the richzhang/PerceptualSimilarity repository
    for the given backbone and version. Every occurrence of ``'lin'`` and
    ``'model.'`` is removed from each key, and the result is returned as an
    ``OrderedDict`` in the downloaded order.
    """
    weights_url = (
        'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/'
        f'master/lpips/weights/v{version}/{net_type}.pth'
    )

    # Without CUDA, force the tensors onto the CPU while loading.
    target_device = None if torch.cuda.is_available() else torch.device('cpu')
    downloaded = torch.hub.load_state_dict_from_url(
        weights_url, progress=True, map_location=target_device,
    )

    return OrderedDict(
        (name.replace('lin', '').replace('model.', ''), tensor)
        for name, tensor in downloaded.items()
    )
examples/face/ref-semantic_mask.png ADDED
examples/face/ref.png ADDED

Git LFS Details

  • SHA256: a477d2f5928b4ab40046fdcd7a0b9d4f35d619822eccd4137396fc06dbb82b48
  • Pointer size: 131 Bytes
  • Size of remote file: 399 kB
examples/face/tgt-semantic_mask.png ADDED
examples/face/tgt.png ADDED

Git LFS Details

  • SHA256: dea3592ab41c766b8d1ba041eda3b545871f1684528bff5c40321a9fbd7c8546
  • Pointer size: 131 Bytes
  • Size of remote file: 410 kB
examples/hair/ref-semantic_mask.png ADDED
examples/hair/ref.png ADDED

Git LFS Details

  • SHA256: 946981b5a077df22a393d6e1ebb1bdef73c020f25e99339f732345777ae6565c
  • Pointer size: 131 Bytes
  • Size of remote file: 435 kB
examples/hair/tgt-semantic_mask.png ADDED
examples/hair/tgt.png ADDED

Git LFS Details

  • SHA256: daa1c69651861183fe113995abb20192fafe829a7b1a349c2ccc2713d7b057b4
  • Pointer size: 131 Bytes
  • Size of remote file: 399 kB
examples/head/ref-semantic_mask.png ADDED
examples/head/ref.png ADDED

Git LFS Details

  • SHA256: ff89b38ec94ee110a8760c6bb6b316c8ad2f4502a14aec1d217305e0ca2dfa47
  • Pointer size: 131 Bytes
  • Size of remote file: 440 kB
examples/head/tgt-semantic_mask.png ADDED
examples/head/tgt.png ADDED

Git LFS Details

  • SHA256: 9467c48978020761d76df2e133808f490f2eacb359f5fda61d08017a77b20151
  • Pointer size: 131 Bytes
  • Size of remote file: 336 kB
examples/inputs.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ target_path_1 reference_path_1
2
+ target_path_2 reference_path_2
3
+ target_path_3 reference_path_3
4
+ target_path_4 reference_path_4
5
+ target_path_5 reference_path_5
examples/motion/ref-semantic_mask.png ADDED
examples/motion/ref.png ADDED

Git LFS Details

  • SHA256: 8b58a80c13e5741072b6c603f7edd61ba3e3c9456536064b0d4746f4bab9c786
  • Pointer size: 131 Bytes
  • Size of remote file: 424 kB
examples/motion/tgt-semantic_mask.png ADDED
examples/motion/tgt.png ADDED

Git LFS Details

  • SHA256: 6e527760e591e97ab36892ee683f91673a678a8b23b7603779d430cfbcc0e5f3
  • Pointer size: 131 Bytes
  • Size of remote file: 427 kB
gen_lmk_and_mask.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ENABLE_lmk_cache = False
2
+ ENABLE_mask_cache = False
3
+
4
+
5
+ import cv2
6
+ from imports import *
7
+ from util_cv2 import cv2_resize_auto_interpolation
8
+ from Mediapipe_Result_Cache import Mediapipe_Result_Cache
9
+ from lmk_util.lmk_extractor import LandmarkExtractor
10
+
11
+
12
def gen_lmk_and_mask(img_paths, size=512, write_cache=True):
    """Extract landmarks and trigger mask generation for each distinct image path.

    Args:
        img_paths: Iterable of image paths; falsy entries and repeated paths
            are skipped.
        size: Side length the image is resized to (``size`` x ``size``) before
            landmark extraction.
        write_cache: When true, extracted landmarks are stored through
            ``Mediapipe_Result_Cache``.

    Raises:
        RuntimeError: If an image cannot be read by OpenCV or no landmarks are
            found for it.
    """
    extractor = LandmarkExtractor()
    cache = Mediapipe_Result_Cache()
    seen = set()
    for p in img_paths:
        if not p:
            continue
        p = str(p)
        if p in seen:
            continue
        seen.add(p)

        cache_path = cache.get_path(p)
        # Landmarks are recomputed unless a cache entry exists AND landmark
        # caching is enabled at module level.
        if not (cache_path.exists() and ENABLE_lmk_cache):
            img = cv2.imread(p)
            if img is None:
                print(f"cv2.imread failed: {p}")
                # Previously a bare `raise` with no active exception, which
                # surfaced as an uninformative RuntimeError; keep the type but
                # carry the reason.
                raise RuntimeError(f"cv2.imread failed: {p}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2_resize_auto_interpolation(img, (size, size))
            lmks = extractor.extract_single(img)
            if lmks is None:
                print(f"no lmks: {p}")
                raise RuntimeError(f"no lmks: {p}")
            if write_cache:
                cache.set(p, lmks)

        # NOTE(review): path_img_2_path_mask is not defined in this file; it
        # presumably arrives via `from imports import *` -- confirm.
        path_img_2_path_mask(p, reuse_if_exists=ENABLE_mask_cache, label_mode="RF12_")
gen_semantic_mask.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ def:
3
+ tgt: Target image to be edited (face swapped)
4
+ ref: Face ID source image (also called src in REFace)
5
+ swap: Swapped output image, using face ID from ref to replace face in tgt
6
+ """
7
+ import os
8
+ from pathlib import Path
9
+ from tqdm import tqdm
10
+ from my_py_lib.image_util import print_image_statistics
11
+ import torch
12
+ import torchvision
13
+ from PIL import Image
14
+ import numpy as np
15
+ from einops import rearrange
16
+ from torchvision.transforms import Resize
17
+ from torchvision.utils import make_grid
18
+ from contextlib import nullcontext
19
+ from torch.cuda.amp import autocast
20
+ from omegaconf import OmegaConf
21
+ import cv2
22
+
23
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ # Sampling configs
26
+ DDIM_STEPS = 50
27
+ GUIDANCE_SCALE = 3.0
28
+ IMG_SIZE = 512
29
+ LATENT_CHANNELS = 4
30
+ DOWNSAMPLE_FACTOR = 8
31
+ START_NOISE_T = 1000
32
+ DDIM_ETA = 0.0
33
+ PRECISION = "full" # or "autocast"
34
+ FIXED_CODE = False # whether to use fixed starting code
35
+ SAVE_INTERMEDIATES = False # whether to save intermediate results
36
+ LOG_EVERY_T = 100 # log frequency during sampling
37
+
38
+
39
class MaskModel_LazyLoader:
    """Class-level cache that builds the face-parsing model on first request."""

    # Cached model instance; stays None until get() builds it.
    model = None

    @classmethod
    def get(cls):
        """Return the cached face-parsing model, initialising it if needed."""
        if cls.model is not None:
            return cls.model

        faceParsing_ckpt = "Other_dependencies/face_parsing/79999_iter.pth"
        # Imported only when the model actually has to be built.
        from pretrained.face_parsing.face_parsing_demo import init_faceParsing_pretrained_model
        cls.model = init_faceParsing_pretrained_model('default', faceParsing_ckpt, '')
        print(f"Initialized face parsing model from {faceParsing_ckpt}")
        return cls.model
53
+
54
+
55
def gen_semantic_mask(path_img: Path, path_mask_to_save: Path, label_mode: str, path_vis: Path = None):
    """Generate a semantic mask for an image using the face parsing model.

    The image is loaded as RGB, resized to 1024x1024 when it is not already
    that size, parsed with ``label_mode``, and the mask is written to
    ``path_mask_to_save``. If saving fails, the error is printed and any file
    left at that path is removed. When ``path_vis`` is given, a visualisation
    of the mask is saved there too.
    """
    TMP_size = 1024
    pil_im = Image.open(path_img).convert("RGB")
    if pil_im.size != (TMP_size, TMP_size):
        pil_im = pil_im.resize((TMP_size, TMP_size), Image.BILINEAR)

    model = MaskModel_LazyLoader.get()
    from pretrained.face_parsing.face_parsing_demo import faceParsing_demo, vis_parsing_maps

    # The parsing call receives the label mode so it can emit that label set.
    mask = faceParsing_demo(model, pil_im, label_mode, model_name='default')

    try:
        Image.fromarray(mask).save(path_mask_to_save)
    except Exception as e:
        print(f"{e=}")
        print(f"{path_mask_to_save=}")
        # Remove whatever the failed save left at the target path.
        if path_mask_to_save.exists():
            path_mask_to_save.unlink()
            print(f'path_mask_to_save.unlink()')

    if not path_vis:
        return
    mask_vis = vis_parsing_maps(pil_im, mask)
    Image.fromarray(mask_vis).save(path_vis)
    print(f"Saved mask vis: {path_vis}")
get_mask.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from util_and_constant import *
2
+ from pathlib import Path
3
+ from PIL import Image
4
+ import cv2
5
+ import numpy as np
6
+
7
def path_img_2_mask(
    path_img,
    preserve=(1, 2, 3, 5, 6, 7, 9, 10, 11),  # int | list-like. Default value represents the face
):
    """
    Load the semantic mask that belongs to `path_img` and return a boolean
    array that is True wherever the pixel label is one of `preserve`.

    Label ids:
    0 bg, 1 mouth, 2 eyebrow, 3 eyes, 4 hair, 5 nose, 6 face (excluding facial parts), 7: ear, 8: neck, 9: tooth
    10: eye_glass, 11: ear_rings
    """
    # A single label may be passed as a bare int.
    if isinstance(preserve, int):
        preserve = (preserve,)
    assert isinstance(preserve, (tuple, list))
    for p in preserve:
        assert isinstance(p, int) and 0 <= p <= 11

    label_map = np.array(Image.open(path_img_2_path_mask(path_img)).convert('L'))
    return np.isin(label_map, preserve)
27
+
28
+
29
+
30
def get_forehead_mask(sm_mask, eye_margin=20, fallback_ratio=0.15):
    """Return a boolean mask that is True on the forehead pixels.

    The forehead is the part of the face-skin region (label 6) that lies
    strictly above a reference row. The reference row is taken from the
    first of these that is present in the mask:

    1. eyebrows (label 2): their topmost row;
    2. eyes (label 3): their topmost row minus `eye_margin`;
    3. eye glasses (label 10): their topmost row;
    4. otherwise the top `fallback_ratio` of the face's vertical extent.

    Args:
        sm_mask: Semantic label map (array-like, rows first). Assumed to be
            2-D, as the row broadcast below requires.
        eye_margin: Rows subtracted from the eye top when eyes are the
            reference. Defaults to the previously hard-coded 20.
        fallback_ratio: Fraction of the face height kept when no landmark
            label is present. Defaults to the previously hard-coded 0.15.

    Returns:
        Boolean array with the same shape as `sm_mask`; all False when the
        mask holds no face-skin pixels.
    """
    sm_mask = np.array(sm_mask)
    # 6 is face (excluding facial parts); only these pixels can be forehead.
    face_mask = (sm_mask == 6)
    row_index = np.arange(sm_mask.shape[0])[:, None]

    # (label, margin) pairs in priority order; the first label found wins.
    for label, margin in ((2, 0), (3, eye_margin), (10, 0)):
        rows = np.nonzero(sm_mask == label)[0]
        if rows.size:
            limit = rows.min() - margin
            break
    else:
        # No eyebrow / eye / glasses label: keep the upper part of the face.
        rows = np.nonzero(face_mask)[0]
        if not rows.size:
            return np.zeros_like(face_mask, dtype=bool)
        face_top = rows.min()
        limit = face_top + (rows.max() - face_top) * fallback_ratio

    return face_mask & (row_index < limit)
global_.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
"""
Global variables shared between modules.
"""
# Task id of the batch currently being processed.
task: int = None

# None means "not set yet"; expected to be set in imports.py.
TP_enable: bool = None
rank_: int = None

# Adaptive rank for each shared+LoRA module, keyed by module name.
moduleName_2_adaRank: dict = {}
9
+
hf_model.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Hub compatible model wrapper for UniBioTransfer.
3
+ Provides from_pretrained() and push_to_hub() functionality via PyTorchModelHubMixin.
4
+ """
5
+ from pathlib import Path
6
+ import torch
7
+ import json
8
+ import copy
9
+ import os
10
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
11
+
12
+ import global_
13
+ from ldm.models.diffusion.ddpm import LatentDiffusion, LandmarkExtractor
14
+ from ldm.util import instantiate_from_config
15
+ from omegaconf import OmegaConf
16
+ from pytorch_lightning import seed_everything
17
+ from MoE import offload_unused_tasks__LD
18
+ from multiTask_model import TaskSpecific_MoE, replace_modules_lossless
19
+ from my_py_lib.torch_util import cleanup_gpu_memory
20
+
21
# Task name <-> task id lookup tables.
TASK_NAME2ID = {"face": 0, "hair": 1, "motion": 2, "head": 3}
TASK_ID2NAME = dict(zip(TASK_NAME2ID.values(), TASK_NAME2ID.keys()))
TASKS = (0, 1, 2, 3)

# Hub repositories / file names used when fetching weights.
PRETRAIN_REPO = "scy639/UniBioTransfer"
SD14_REPO = "CompVis/stable-diffusion-v-1-4-original"
SD14_FILENAME = "sd-v1-4.ckpt"
29
+
30
+ def _load_first_stage_from_sd14(model, sd14_path):
31
+ """Load first_stage_model (VAE) from SD v1.4 checkpoint."""
32
+ print(f"Loading first_stage_model from {sd14_path}")
33
+ sd14 = torch.load(str(sd14_path), map_location="cpu")
34
+ if isinstance(sd14, dict) and "state_dict" in sd14:
35
+ sd14_sd = sd14["state_dict"]
36
+ else:
37
+ sd14_sd = sd14
38
+
39
+ prefixes = ["first_stage_model.", "model.first_stage_model."]
40
+ fs_sd = {}
41
+ for prefix in prefixes:
42
+ for k, v in sd14_sd.items():
43
+ if k.startswith(prefix):
44
+ fs_sd[k[len(prefix):]] = v
45
+ if fs_sd:
46
+ break
47
+
48
+ if not fs_sd:
49
+ raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.")
50
+
51
+ model.first_stage_model.load_state_dict(fs_sd, strict=True)
52
+
53
+
54
class UniBioTransferModel(LatentDiffusion, PyTorchModelHubMixin):
    """
    Hugging Face Hub compatible wrapper for UniBioTransfer.

    Inherits from LatentDiffusion and adds HF Hub integration via PyTorchModelHubMixin.

    Usage:
        # Load model from HF Hub
        model = UniBioTransferModel.from_pretrained("scy639/UniBioTransfer", task="face")

        # Push to HF Hub
        model.push_to_hub("your-repo/UniBioTransfer")

    Args:
        config: Model config dict (handled by PyTorchModelHubMixin)
        task: Task name or ID (face/hair/motion/head)
        **kwargs: Additional arguments passed to LatentDiffusion
    """

    def __init__(self, config=None, task="face", **kwargs):
        # An unknown task name falls back to id 0 here; an unknown id keeps
        # the raw value and is named "face".
        self._task_name = task if isinstance(task, str) else TASK_ID2NAME.get(task, "face")
        self._task_id = TASK_NAME2ID.get(self._task_name, 0) if isinstance(task, str) else task

        global_.task = self._task_id

        if config is None:
            config = {}

        super().__init__(**config)

        self._hf_config = {
            "task": self._task_name,
            "task_id": self._task_id,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path=None,
        task="face",
        device="cuda",
        download_sd14=True,
        download_deps=True,
        cache_dir=None,
        **kwargs,
    ):
        """
        Load model from Hugging Face Hub.

        Args:
            pretrained_model_name_or_path: HF repo ID.
                Default: "scy639/UniBioTransfer"
            task: Task name (face/hair/motion/head) or task ID (0/1/2/3)
            device: Device to load model to ("cuda" or "cpu")
            download_sd14: Whether to download the SD v1.4 checkpoint when it
                is missing locally (its VAE weights are always loaded from it)
            download_deps: Whether to download other dependencies (ArcFace, face_parsing)
            cache_dir: Root directory for downloaded files (default: current directory)
            **kwargs: Additional arguments (currently unused)

        Returns:
            UniBioTransferModel: Loaded model

        Raises:
            ValueError: If `task` is not a known task name or ID.
        """
        import shutil

        # Resolve and validate the task up front; an unknown name previously
        # leaked through as the task id itself.
        if isinstance(task, str):
            if task not in TASK_NAME2ID:
                raise ValueError(f"Unknown task {task!r}; expected one of {sorted(TASK_NAME2ID)}")
            task_id = TASK_NAME2ID[task]
        else:
            task_id = task
            if task_id not in TASK_ID2NAME:
                raise ValueError(f"Unknown task id {task!r}; expected one of {sorted(TASK_ID2NAME)}")
        task_name = TASK_ID2NAME[task_id]

        global_.task = task_id

        if pretrained_model_name_or_path is None:
            pretrained_model_name_or_path = PRETRAIN_REPO

        repo_id = pretrained_model_name_or_path

        cache_dir = Path(cache_dir) if cache_dir else Path(".")

        ckpt_path = cache_dir / "checkpoints" / "pretrained.ckpt"
        json_path = cache_dir / "checkpoints" / "pretrained.json"
        sd14_path = cache_dir / "checkpoints" / SD14_FILENAME
        arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth"
        face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth"

        def _download_file(repo, filename, local_path):
            """Download `filename` from `repo` and make sure it ends up at `local_path`."""
            local_path = Path(local_path)
            local_path.parent.mkdir(parents=True, exist_ok=True)
            print(f"Downloading {filename} from {repo}...")
            token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            downloaded = Path(hf_hub_download(
                repo_id=repo,
                filename=filename,
                local_dir=str(local_path.parent),
                local_dir_use_symlinks=False,
                token=token,
            ))
            # hf_hub_download keeps the repo-relative sub-folders of `filename`
            # below local_dir, so a name like "checkpoints/pretrained.ckpt"
            # lands one level deeper than `local_path`. Move it into place.
            if downloaded.resolve() != local_path.resolve():
                shutil.move(str(downloaded), str(local_path))

        if not ckpt_path.exists():
            _download_file(repo_id, "checkpoints/pretrained.ckpt", ckpt_path)
        if not json_path.exists():
            _download_file(repo_id, "checkpoints/pretrained.json", json_path)

        if download_sd14 and not sd14_path.exists():
            _download_file(SD14_REPO, SD14_FILENAME, sd14_path)

        if download_deps:
            if not arcface_path.exists():
                _download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth", arcface_path)
            if not face_parsing_path.exists():
                _download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth", face_parsing_path)

        seed_everything(42)

        # Prefer the YAML next to this file, else one in the working directory.
        cur_dir = Path(__file__).parent
        yaml_path = cur_dir / "LatentDiffusion.yaml"
        if not yaml_path.exists():
            yaml_path = Path("LatentDiffusion.yaml")

        model_config = OmegaConf.load(yaml_path).model
        model = instantiate_from_config(model_config)

        # The adaptive-rank table is published through global_ before the
        # modules are replaced below.
        with open(json_path, 'r') as f:
            global_.moduleName_2_adaRank = json.load(f)
        print(f"Loaded adaptive rank config from {json_path}")

        # One independent copy of the diffusion model per task as replacement source.
        _src0 = copy.deepcopy(model.model.diffusion_model)
        _src1 = copy.deepcopy(model.model.diffusion_model)
        _src2 = copy.deepcopy(model.model.diffusion_model)
        _src3 = copy.deepcopy(model.model.diffusion_model)
        replace_modules_lossless(
            model.model.diffusion_model,
            [_src0, _src1, _src2, _src3],
            [0, 1, 2, 3],
            parent_name=".model.diffusion_model",
        )

        # Projection heads become task-specific; the id lists name the tasks
        # each head is instantiated for.
        model.ID_proj_out = TaskSpecific_MoE([
            copy.deepcopy(model.ID_proj_out),
            copy.deepcopy(model.ID_proj_out),
            copy.deepcopy(model.ID_proj_out),
        ], [0, 2, 3])
        model.landmark_proj_out = TaskSpecific_MoE([
            copy.deepcopy(model.landmark_proj_out),
            copy.deepcopy(model.landmark_proj_out),
            copy.deepcopy(model.landmark_proj_out),
        ], [0, 2, 3])
        model.proj_out_source__head = TaskSpecific_MoE([
            copy.deepcopy(model.proj_out_source__head),
            copy.deepcopy(model.proj_out_source__head),
        ], [2, 3])

        from util_and_constant import REFNET
        if REFNET.ENABLE:
            shared_ref = model.model.diffusion_model_refNet
            src0 = shared_ref
            src1 = copy.deepcopy(shared_ref)
            src2 = copy.deepcopy(shared_ref)
            src3 = copy.deepcopy(shared_ref)
            replace_modules_lossless(shared_ref, [src0, src1, src2, src3], [0, 1, 2, 3], parent_name=".model.diffusion_model_refNet", for_refnet=True)
            from ldm.models.diffusion.bank import Bank
            model.model.bank = Bank(
                reader=model.model.diffusion_model,
                writer=model.model.diffusion_model_refNet
            )

        print(f"Loading model weights from {ckpt_path}")
        pl_sd = torch.load(str(ckpt_path), map_location="cpu")
        if isinstance(pl_sd, dict) and "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd

        # Non-strict load: only the number of mismatching keys is reported.
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0:
            print(f"Missing keys: {len(m)}")
        if len(u) > 0:
            print(f"Unexpected keys: {len(u)}")

        _load_first_stage_from_sd14(model, sd14_path)

        # offload_unused_tasks__LD(model, task_id, method="cpu")

        model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False)
        cleanup_gpu_memory()

        # ZeroGPU compatibility: move to the GPU only when `device` is not
        # "cpu" and CUDA is available. With device="cpu" the model stays on
        # the CPU (ZeroGPU initialisation must not touch the GPU).
        if device != "cpu" and torch.cuda.is_available():
            model = model.to(torch.device(device))
        else:
            model = model.to(torch.device("cpu"))
        model.eval()

        model._task_id = task_id
        model._task_name = task_name
        model._hf_config = {"task": task_name, "task_id": task_id}

        return model
imports.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ #---------------------------------------------------------------------------------------------------------------------
5
+ from util_and_constant import *
6
+ from get_mask import *
7
+ from util_cv2 import *
8
+
infer.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# --------------------------------------------------------- Config -------------------------------------------------
num_workers: int = 1
DDIM_STEPS = 50
BATCH_SIZE = 1
FIXED_CODE = False
# for vis
SAVE_INTERMEDIATES = True
NUM_grid_in_a_column = 5
# ------------------------------------------------------------------------------------------------------------------------
import argparse

# Task name -> task id; also the source of the accepted --task-name values.
_TASK_NAME_TO_ID = {
    'face': 0,
    'hair': 1,
    'motion': 2,
    'head': 3,
}

parser = argparse.ArgumentParser(description="Custom inference for tgt/ref image pairs.")
parser.add_argument("--task-name", type=str,
    default='face',
    # Reject unknown names with a usage message instead of a bare KeyError below.
    choices=list(_TASK_NAME_TO_ID),
    help="face|hair|motion|head")
parser.add_argument("--out-dir", type=str, default='examples/outputs', help="Output directory")
# option 1: pass 2 paths
parser.add_argument("--tgt", type=str, default=None, help="Path to target image. if None, will use paths read from --pair-list")
parser.add_argument("--ref", type=str, default=None, help="Path to reference image")
# option 2: pass a txt containing paths
parser.add_argument("--pair-list", type=str, default='examples/inputs.txt', help="white-space-separated list file: tgt_path ref_path")
args = parser.parse_args()

#-----------------------------------------set TASK--------------------------------------------------------------------------

task_name: str = args.task_name
TASK: int = _TASK_NAME_TO_ID[task_name]
print(f'task: {task_name} transfer (ID: {TASK})')
33
+ # ------------------------------------------------------------------------------------------------------------------------
34
+
35
+
36
+ import sys
37
+ import os
38
+ from pathlib import Path
39
+
40
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
41
+
42
+ from imports import *
43
+ import torch
44
+ import numpy as np
45
+ from omegaconf import OmegaConf
46
+ from PIL import Image
47
+ from tqdm import tqdm
48
+ from einops import rearrange
49
+ from torchvision.utils import make_grid
50
+ from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
51
+ from pytorch_lightning import seed_everything
52
+ from torch import autocast
53
+ from contextlib import nullcontext
54
+ import torchvision
55
+
56
+ from ldm.models.diffusion.ddpm import LatentDiffusion
57
+ from ldm.util import instantiate_from_config
58
+ from ldm.models.diffusion.ddim import DDIMSampler
59
+ from Dataset_custom import Dataset_custom
60
+ from MoE import offload_unused_tasks__LD
61
+ from ldm.models.diffusion.ddpm import LandmarkExtractor
62
+ from my_py_lib.torch_util import cleanup_gpu_memory
63
+ from gen_lmk_and_mask import gen_lmk_and_mask
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+ # ------------------------------------------------------------------------------------------------------------------------
74
# Sampler settings.
DDIM_ETA = 0.0
SCALE = 3.0  # passed as unconditional_guidance_scale
PRECISION = "full"  # "full" or "autocast"

# Image size (H, W), latent channels (C) and the factor (F) dividing H and W
# to get the latent size.
H, W = 512, 512
C = 4
F = 8
81
+ # ------------------------------------------------------------------------------------------------------------------------
82
+
83
+
84
def load_first_stage_from_sd14(model: LatentDiffusion, sd14_path: Path) -> None:
    """Copy the first-stage (VAE) weights of an SD v1.4 checkpoint into `model`.

    Raises:
        RuntimeError: If the checkpoint has no key with a known first-stage prefix.
    """
    print(f"Loading first_stage_model from {sd14_path}")
    loaded = torch.load(str(sd14_path), map_location="cpu")
    # Unwrap a {"state_dict": ...} container when present.
    if isinstance(loaded, dict) and "state_dict" in loaded:
        loaded = loaded["state_dict"]

    def _strip(prefix):
        """Keys starting with `prefix`, with the prefix removed."""
        cut = len(prefix)
        return {name[cut:]: tensor for name, tensor in loaded.items() if name.startswith(prefix)}

    # First prefix that matches anything wins; later ones are not evaluated.
    candidates = (_strip(p) for p in ("first_stage_model.", "model.first_stage_model."))
    fs_sd = next((found for found in candidates if found), {})

    if not fs_sd:
        raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.")

    model.first_stage_model.load_state_dict(fs_sd, strict=True)
105
+
106
+
107
def save_sample_by_decode(x, model, base_path, segment_id, intermediate_num):
    """Decode latents with the model's first stage and save them as PNG.

    Files are written to ``<base_path>/<segment_id>/<intermediate_num>.png``.

    NOTE(review): every item of the batch is written to the same file name,
    so only the last one survives when `x` holds more than one sample; the
    callers in this file pass single-sample slices.
    """
    decoded = model.decode_first_stage(x)
    decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
    frames = decoded.cpu().permute(0, 2, 3, 1).numpy()  # NCHW -> NHWC
    target_dir = Path(base_path) / segment_id
    for frame in frames:
        img = Image.fromarray((frame * 255).astype(np.uint8))
        target_dir.mkdir(parents=True, exist_ok=True)
        img.save(target_dir / f"{intermediate_num}.png")
116
+
117
+
118
def get_tensor_clip(normalize=True, toTensor=True):
    """Build a torchvision transform: optional ToTensor, then optional Normalize.

    The mean/std values are the same constants that `un_norm_clip` inverts.
    """
    steps = []
    if toTensor:
        steps.append(torchvision.transforms.ToTensor())
    if normalize:
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        steps.append(torchvision.transforms.Normalize(mean, std))
    return torchvision.transforms.Compose(steps)
130
+
131
+
132
def load_model_from_config(ckpt, verbose=1):
    """Build the MoE model, load its weights from `ckpt` and put it on CUDA.

    Args:
        ckpt: Path of the checkpoint (either a raw state dict or a dict with
            a ``"state_dict"`` entry).
        verbose: When truthy, missing / unexpected keys are printed.

    Returns:
        The model in eval mode, after unused-task offloading and ``.cuda()``.
    """
    ckpt = Path(ckpt)
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(str(ckpt), map_location="cpu")
    has_wrapper = isinstance(pl_sd, dict) and "state_dict" in pl_sd
    sd = pl_sd["state_dict"] if has_wrapper else pl_sd

    from init_model import get_moe
    model: LatentDiffusion = get_moe()
    model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False)
    cleanup_gpu_memory()

    # Non-strict load; mismatching keys are only reported.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        pretty_print_torch_module_keys(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        pretty_print_torch_module_keys(unexpected)
    load_first_stage_from_sd14(model, SD14_localpath)

    offload_unused_tasks__LD(model, TASK, method="del")  # for save cuda mem
    model.cuda()
    model.eval()
    return model
162
+
163
+
164
+
165
+
166
def load_pairs(pair_list, tgt, ref):
    """Collect the (target, reference) image path pairs to process.

    Args:
        pair_list: Path of a text file with one ``tgt_path ref_path`` pair per
            line, separated by any amount of whitespace (spaces or tabs).
            Blank lines and lines starting with ``#`` are ignored.
        tgt: Target image path; used together with `ref` and takes priority
            over `pair_list` when both are given.
        ref: Reference image path.

    Returns:
        List of ``(tgt_path, ref_path)`` tuples.

    Raises:
        ValueError: If a line does not hold exactly two fields, or if neither
            a tgt/ref pair nor a pair list was supplied.
    """
    if tgt and ref:
        pairs = [(tgt, ref), ]
    elif pair_list:
        pairs = []
        with open(pair_list, "r") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                # split() with no argument splits on runs of any whitespace;
                # split(" ") rejected tabs and repeated spaces.
                parts = line.split()
                if len(parts) != 2:
                    raise ValueError(f"Invalid pair list line {line_num}: expected white-space-separated tgt/ref. got {parts=}")
                pairs.append((parts[0], parts[1]))
    else:
        raise ValueError("No input pairs provided. Use --tgt/--ref or --pair-list.")
    print(f"{pairs=}")
    return pairs
184
+
185
+
186
def un_norm(x):
    """Map values from the [-1, 1] range to [0, 1]."""
    shifted = x + 1.0
    return shifted / 2.0
188
+
189
+
190
def un_norm_clip(x1):
    """Undo the per-channel mean/std normalisation used by `get_tensor_clip`.

    Accepts a 3-D (channel-first) or 4-D (batch, channel, ...) tensor and
    returns a new tensor of the same rank; the input is not modified.
    """
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    channel_mean = (0.48145466, 0.4578275, 0.40821073)

    out = x1 * 1.0  # multiplication yields a fresh tensor, so x1 stays intact
    single = len(out.shape) == 3
    if single:
        out = out.unsqueeze(0)
    for ch, (std, mean) in enumerate(zip(channel_std, channel_mean)):
        out[:, ch, :, :] = out[:, ch, :, :] * std + mean
    return out.squeeze(0) if single else out
202
+
203
+
204
# Script entry point: run DDIM sampling for every tgt/ref pair and write the
# results, side-by-side grids and (optionally) the intermediates below --out-dir.
if __name__ == "__main__":
    pairs = load_pairs(args.pair_list, args.tgt, args.ref)

    # Output layout: <out_dir>/results, <out_dir>/grid, <out_dir>/intermediates/{pred_x0,noised}
    out_dir = Path(args.out_dir)
    result_path = out_dir / "results"
    grid_path = out_dir / "grid"
    inter_path = out_dir / "intermediates"
    inter_pred_path = inter_path / "pred_x0"
    inter_noised_path = inter_path / "noised"
    # parents=False: the parent of out_dir has to exist already.
    out_dir.mkdir(parents=False, exist_ok=True)
    result_path.mkdir(parents=False, exist_ok=True)
    grid_path.mkdir(parents=False, exist_ok=True)
    inter_path.mkdir(parents=False, exist_ok=True)
    if SAVE_INTERMEDIATES:
        inter_pred_path.mkdir(parents=False, exist_ok=True)
        inter_noised_path.mkdir(parents=False, exist_ok=True)
    paths_tgt = [p[0] for p in pairs]
    paths_ref = [p[1] for p in pairs]
    # Landmark / mask preprocessing for every input image (targets and references).
    gen_lmk_and_mask(paths_tgt + paths_ref)

    seed_everything(42)

    model: LatentDiffusion = load_model_from_config(PRETRAIN_CKPT_PATH, )
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    dataset = Dataset_custom(
        "test",
        task=TASK,
        paths_tgt=paths_tgt,
        paths_ref=paths_ref,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    # With FIXED_CODE the same initial noise is reused for every batch.
    start_code = None
    if FIXED_CODE:
        start_code = torch.randn([BATCH_SIZE, C, H // F, W // F], device=device)

    precision_scope = autocast if PRECISION == "autocast" else nullcontext
    # Grid images collected for the next column sheet, and their stems.
    grids = []
    grid_stems = []

    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for test_batch, prior, test_model_kwargs, out_stem_batch in tqdm(dataloader):
                    model.set_task(test_model_kwargs)
                    bs = test_batch.shape[0]

                    # No ground truth at inference time: "GT" is a zero tensor
                    # shaped like the inpaint image.
                    batch_ = {
                        **test_model_kwargs,
                        "GT": torch.zeros_like(test_model_kwargs["inpaint_image"]),
                    }
                    batch_, c = model.get_input_and_conditioning(batch_, device=device)
                    z_inpaint = batch_["z4_inpaint"]
                    z_inpaint_mask = batch_["tgt_mask_64"]
                    z_ref = batch_["z_ref"]
                    z9 = batch_["z9"]

                    # An unconditional embedding is only built when the guidance scale is not 1.
                    uc = None
                    if SCALE != 1.0:
                        uc = model.learnable_vector[TASK].repeat(bs, 1, 1)

                    shape = [C, H // F, W // F]
                    local_start_code = start_code
                    # Re-draw the fixed noise when this batch's size differs from the stored one.
                    if FIXED_CODE and (local_start_code is None or local_start_code.shape[0] != bs):
                        local_start_code = torch.randn([bs, C, H // F, W // F], device=device)
                    samples_ddim, intermediates = sampler.sample(
                        S=DDIM_STEPS,
                        conditioning=c,
                        batch_size=bs,
                        shape=shape,
                        verbose=False,
                        unconditional_guidance_scale=SCALE,
                        unconditional_conditioning=uc,
                        eta=DDIM_ETA,
                        x_T=local_start_code,
                        log_every_t=100,
                        z_inpaint=z_inpaint,
                        z_inpaint_mask=z_inpaint_mask,
                        z_ref=z_ref,
                        z9=z9,
                    )

                    if SAVE_INTERMEDIATES:
                        intermediate_pred_x0 = intermediates["pred_x0"]
                        intermediate_noised = intermediates["x_inter"]
                        # One decode + PNG per logged step and per sample; single-sample
                        # slices keep each sample in its own <stem>/ folder.
                        for i in range(len(intermediate_pred_x0)):
                            for j in range(bs):
                                stem = f"{out_stem_batch[j]}"
                                save_sample_by_decode(
                                    intermediate_pred_x0[i][j : j + 1],
                                    model,
                                    inter_pred_path,
                                    stem,
                                    i,
                                )
                                save_sample_by_decode(
                                    intermediate_noised[i][j : j + 1],
                                    model,
                                    inter_noised_path,
                                    stem,
                                    i,
                                )

                    # Decode the final latents and map them from [-1, 1] to [0, 1].
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)
                    # Pass 1: write each result image.
                    for i, x_sample in enumerate(x_checked_image_torch):
                        stem = f"{out_stem_batch[i]}"
                        out_path = result_path / f"{stem}.png"
                        img = Image.fromarray((x_sample.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
                        img.save(out_path)
                        print(f"{out_path=}")

                    # Pass 2: write a [target | reference | result] grid per sample.
                    for i, x_sample in enumerate(x_checked_image_torch):
                        all_img = []
                        all_img.append(un_norm(test_batch[i]).cpu())
                        if TASK != 2:
                            # Resized to 512x512 and de-normalised with un_norm_clip.
                            ref_img = test_model_kwargs["ref_imgs"].squeeze(1)
                            ref_img = torchvision.transforms.Resize([512, 512])(ref_img)
                            ref_img = un_norm_clip(ref_img[i]).cpu()
                        else:
                            # Task 2 (motion) takes the reference from "ref512" instead.
                            ref_img = un_norm(test_model_kwargs["ref512"].squeeze(1)[i]).cpu()
                        all_img.append(ref_img)
                        all_img.append(x_sample)

                        grid = torch.stack(all_img, 0)
                        grid = make_grid(grid)
                        grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
                        img = Image.fromarray(grid.astype(np.uint8))
                        stem = f"{out_stem_batch[i]}"
                        path_save_img = grid_path / f"grid-{stem}.jpg"
                        img.save(path_save_img)
                        print(f"{path_save_img=}")
                        grids.append(img)
                        grid_stems.append(stem)
                        # Once NUM_grid_in_a_column grids are collected, stack them into one
                        # column sheet and start over.
                        # NOTE(review): grids left over when the loop ends are never written
                        # as a column sheet - confirm this is intended.
                        if len(grids) >= NUM_grid_in_a_column:
                            stem_start = grid_stems[0]
                            stem_end = grid_stems[-1]
                            grid_column = imgs_2_grid_A(
                                grids,
                                grid_layout='column',
                                grid_path=os.path.join(grid_path, f"{stem_start}--{stem_end}.jpg"),
                            )
                            grids = []
                            grid_stems = []

                    model.unset_task()

    print(f"Your samples are ready and waiting for you here: {out_dir}")
365
+
366
+
infer_hf.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ High-level inference pipeline for UniBioTransfer.
3
+ Designed for easy use in Hugging Face Spaces and other applications.
4
+
5
+ ZeroGPU Compatible:
6
+ - Supports CPU initialization (device="cpu")
7
+ - Dynamically switches to CUDA during inference when called from @spaces.GPU
8
+ """
9
+ from pathlib import Path
10
+ import torch
11
+ import numpy as np
12
+ from PIL import Image
13
+ import cv2
14
+
15
+ import global_
16
+ from hf_model import UniBioTransferModel, TASK_NAME2ID, TASK_ID2NAME
17
+ from ldm.models.diffusion.ddim import DDIMSampler
18
+ from pytorch_lightning import seed_everything
19
+
20
# Default sampler settings used by the pipeline.
DDIM_STEPS_DEFAULT = 50
SCALE_DEFAULT = 3.0

# Image size (H, W), latent channels (C) and the factor (F) dividing H and W
# to get the latent size.
H = 512
W = 512
C = 4
F = 8
25
class UniBioTransferPipeline:
    """
    High-level pipeline for UniBioTransfer inference.

    Wraps a loaded model and a DDIM sampler; calling the pipeline with a
    target and a reference image returns a list of PIL images.
    """

    def __init__(self, model, task="face", device="cpu"):
        """
        Initialize pipeline with a loaded model.

        Args:
            model: Loaded model; it is handed to DDIMSampler.
            task: Task name (looked up in TASK_NAME2ID) or task id.
            device: Device the model was initialised on; only recorded here.
        """
        self.model = model
        self.task = task
        self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task
        self._init_device = device

        global_.task = self.task_id
        self.model.task = self.task_id

        self.sampler = DDIMSampler(model)

    @classmethod
    def from_pretrained(
        cls,
        repo_id="scy639/UniBioTransfer",
        task="face",
        device="cpu",
        cache_dir=None,
        **kwargs,
    ):
        """
        Load pipeline from Hugging Face Hub.

        All arguments are forwarded to UniBioTransferModel.from_pretrained.
        """
        model = UniBioTransferModel.from_pretrained(
            pretrained_model_name_or_path=repo_id,
            task=task,
            device=device,
            cache_dir=cache_dir,
            **kwargs,
        )
        return cls(model, task=task, device=device)

    def set_task(self, task):
        """Switch to a different task."""
        self.task = task
        self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task
        global_.task = self.task_id
        self.model.task = self.task_id

    def __call__(
        self,
        tgt_image,
        ref_image,
        ddim_steps=DDIM_STEPS_DEFAULT,
        scale=SCALE_DEFAULT,
        seed=42,
        num_images=1,
    ):
        """
        Run inference on a pair of images.

        Args:
            tgt_image: Target image (PIL image, numpy array or path).
            ref_image: Reference image (same accepted types).
            ddim_steps: Number of DDIM steps.
            scale: Unconditional guidance scale.
            seed: Seed passed to seed_everything.
            num_images: Number of samples to generate for the pair.

        Returns:
            list of PIL.Image.Image, one per generated sample.
        """
        seed_everything(seed)

        tgt_img = self._load_image(tgt_image)
        ref_img = self._load_image(ref_image)

        # PIL sizes are (width, height).
        tgt_img = self._resize_image(tgt_img, (W, H))
        ref_img = self._resize_image(ref_img, (W, H))

        result_tensors = self._run_inference(tgt_img, ref_img, ddim_steps, scale, num_images)

        result_imgs = [self._postprocess(result_tensors[i]) for i in range(result_tensors.shape[0])]
        return result_imgs

    def _load_image(self, img):
        """Load image from various formats and return it as an RGB PIL image."""
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        elif isinstance(img, np.ndarray):
            return Image.fromarray(img).convert("RGB")
        elif isinstance(img, (str, Path)):
            return Image.open(img).convert("RGB")
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

    def _resize_image(self, img, size):
        """Resize image to target size, given as (width, height)."""
        if img.size != size:
            img = img.resize(size, Image.LANCZOS)
        return img

    @staticmethod
    def _expand_to_batch(value, num_images, device):
        """Move a tensor to `device` and tile a leading dim of 1 to `num_images`.

        Non-tensor values are returned unchanged.
        """
        if not isinstance(value, torch.Tensor):
            return value
        value = value.to(device)
        if value.shape[0] == 1:
            value = value.repeat(num_images, *([1] * (value.ndim - 1)))
        return value

    def _run_inference(self, tgt_img, ref_img, ddim_steps, scale, num_images):
        """
        Run diffusion sampling.

        Reuses the logic of infer.py: the two images are written to a
        temporary directory and fed through Dataset_custom / a DataLoader.

        Returns:
            Tensor of decoded samples, clamped to [0, 1].

        Raises:
            RuntimeError: If the data loader yields no batch.
        """
        from Dataset_custom import Dataset_custom
        from gen_lmk_and_mask import gen_lmk_and_mask
        import tempfile

        with tempfile.TemporaryDirectory() as tmpdir:
            tgt_path = Path(tmpdir) / "tgt.png"
            ref_path = Path(tmpdir) / "ref.png"
            tgt_img.save(tgt_path)
            ref_img.save(ref_path)

            gen_lmk_and_mask([str(tgt_path), str(ref_path)], write_cache=True)

            dataset = Dataset_custom(
                "test",
                task=self.task_id,
                paths_tgt=[str(tgt_path)],
                paths_ref=[str(ref_path)],
            )

            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=1,
                num_workers=1,
                pin_memory=True,
                shuffle=False,
                drop_last=False,
            )

            # The device is chosen at call time so a CPU-initialised model can
            # run on CUDA when it is available.
            run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = self.model.to(run_device)

            x_samples_ddim = None
            with torch.no_grad():
                for test_batch, prior, test_model_kwargs, out_stem_batch in dataloader:
                    test_batch = self._expand_to_batch(test_batch, num_images, run_device)
                    prior = self._expand_to_batch(prior, num_images, run_device)
                    # Expand every entry to `num_images` samples: tensors are
                    # moved and tiled, dicts get the same treatment one level
                    # down, lists are repeated, anything else is left alone.
                    for k, v in list(test_model_kwargs.items()):
                        if isinstance(v, torch.Tensor):
                            test_model_kwargs[k] = self._expand_to_batch(v, num_images, run_device)
                        elif isinstance(v, dict):
                            test_model_kwargs[k] = {
                                kk: self._expand_to_batch(vv, num_images, run_device)
                                for kk, vv in v.items()
                            }
                        elif isinstance(v, list):
                            test_model_kwargs[k] = v * num_images

                    self.model.set_task(test_model_kwargs)
                    try:
                        bs = num_images

                        # No ground truth at inference time: "GT" is a zero
                        # tensor shaped like the inpaint image.
                        batch_ = {
                            **test_model_kwargs,
                            "GT": torch.zeros(num_images, *test_model_kwargs["inpaint_image"].shape[1:], device=run_device),
                        }
                        batch_, c = self.model.get_input_and_conditioning(batch_, device=run_device)

                        z_inpaint = batch_["z4_inpaint"]
                        z_inpaint_mask = batch_["tgt_mask_64"]
                        z_ref = batch_["z_ref"]
                        z9 = batch_["z9"]

                        uc = None
                        if scale != 1.0:
                            uc = self.model.learnable_vector[self.task_id].repeat(bs, 1, 1)

                        shape = [C, H // F, W // F]
                        start_code = None

                        samples_ddim, _ = self.sampler.sample(
                            S=ddim_steps,
                            conditioning=c,
                            batch_size=bs,
                            shape=shape,
                            verbose=False,
                            unconditional_guidance_scale=scale,
                            unconditional_conditioning=uc,
                            eta=0.0,
                            x_T=start_code,
                            log_every_t=100,
                            z_inpaint=z_inpaint,
                            z_inpaint_mask=z_inpaint_mask,
                            z_ref=z_ref,
                            z9=z9,
                        )

                        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    finally:
                        # Always clear the task state, even when sampling fails.
                        self.model.unset_task()

            if x_samples_ddim is None:
                raise RuntimeError("The data loader produced no batch; nothing was sampled.")
            return x_samples_ddim

    def _postprocess(self, tensor):
        """Convert model output tensor to PIL Image."""
        img_array = tensor.cpu().permute(1, 2, 0).numpy()
        img_array = (img_array * 255).astype(np.uint8)
        return Image.fromarray(img_array)
231
+
232
+
233
def infer_single(
    tgt_path,
    ref_path,
    task="face",
    output_path=None,
    ddim_steps=DDIM_STEPS_DEFAULT,
    scale=SCALE_DEFAULT,
    device="cuda",
):
    """
    Convenience function for single inference.

    Builds a pipeline with ``UniBioTransferPipeline.from_pretrained`` and
    runs it once on the given pair.

    Args:
        tgt_path: Target image (anything the pipeline accepts).
        ref_path: Reference image.
        task: Task name or id.
        output_path: When not None, the first generated image is saved there.
        ddim_steps: Number of DDIM steps.
        scale: Unconditional guidance scale.
        device: Device passed to the model loader.

    Returns:
        Whatever the pipeline call returns (a list of PIL images).
    """
    pipeline = UniBioTransferPipeline.from_pretrained(task=task, device=device)
    result = pipeline(tgt_path, ref_path, ddim_steps=ddim_steps, scale=scale)

    if output_path is not None:
        # The pipeline returns a list of images; calling .save on the list
        # itself raised AttributeError. Save the first (here: only) image.
        first_image = result[0] if isinstance(result, (list, tuple)) else result
        first_image.save(output_path)
        print(f"Saved result to {output_path}")

    return result
253
+
254
+
255
# Command-line entry point: run one tgt/ref pair and save the result.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="UniBioTransfer inference")
    parser.add_argument("--task", type=str, default="face", choices=["face", "hair", "motion", "head"])
    parser.add_argument("--tgt", type=str, required=True, help="Path to target image")
    parser.add_argument("--ref", type=str, required=True, help="Path to reference image")
    parser.add_argument("--out", type=str, default="result.png", help="Output path")
    parser.add_argument("--ddim-steps", type=int, default=50)
    parser.add_argument("--scale", type=float, default=3.0)
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()

    result = infer_single(
        args.tgt,
        args.ref,
        task=args.task,
        output_path=args.out,
        ddim_steps=args.ddim_steps,
        scale=args.scale,
        device=args.device,
    )

    # infer_single hands back the pipeline output, which is a list of images;
    # report the size of the first one (a list has no .size attribute).
    first_image = result[0] if isinstance(result, (list, tuple)) else result
    print(f"Inference complete. Result shape: {first_image.size}")
init_model.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys,os
2
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
3
+ if __name__=='__main__': sys.path.append(os.path.abspath(os.path.join(cur_dir, '..')))
4
+
5
+ from imports import *
6
+ import json
7
+ import argparse, os, sys, glob
8
+ import cv2
9
+ import torch
10
+ import numpy as np
11
+ from MoE import *
12
+ from multiTask_model import *
13
+ from lora_layers import *
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from tqdm import tqdm, trange
17
+ from itertools import islice
18
+ from einops import rearrange
19
+ from torchvision.utils import make_grid
20
+ from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
21
+ import time
22
+ import copy
23
+ from pytorch_lightning import seed_everything
24
+ from torch import autocast
25
+ from contextlib import contextmanager, nullcontext
26
+ import torchvision
27
+ from ldm.models.diffusion.ddpm import LatentDiffusion
28
+ from ldm.models.diffusion.bank import Bank
29
+ from ldm.util import instantiate_from_config
30
+
31
+ from ldm.models.diffusion.ddim import DDIMSampler
32
+
33
+ from transformers import AutoFeatureExtractor
34
+
35
+ # import clip
36
+ from torchvision.transforms import Resize
37
+ from fnmatch import fnmatch
38
+
39
+
40
+ from PIL import Image
41
+ from torchvision.transforms import PILToTensor
42
+ #----------------------------------------------------------------------------
43
+
44
+
45
def get_moe():
    """Build the LatentDiffusion model and wrap its parts for multi-task (MoE) use.

    The model is instantiated from ``LatentDiffusion.yaml`` and kept on CPU.
    When ``FOR_upcycle_ckpt_GEN_or_USE`` is falsy, the diffusion UNet (and the
    refNet when ``REFNET.ENABLE``) is passed through ``replace_modules_lossless``
    together with deep copies of itself, and three projection heads are wrapped
    in ``TaskSpecific_MoE``.

    NOTE(review): this copy of the file lost its indentation; the nesting below
    is a reconstruction -- confirm against the real source. With
    ``FOR_upcycle_ckpt_GEN_or_USE`` truthy, ``modelMOE`` is never bound and the
    function cannot return successfully.

    Returns:
        The assembled model (``modelMOE``).
    """
    if 1:
        seed_everything(42)
        # torch.cuda.set_device(opt.device_ID)
        model :LatentDiffusion = instantiate_from_config(OmegaConf.load(f"LatentDiffusion.yaml").model,)
        if REFNET.ENABLE:
            assert model.model.diffusion_model_refNet.is_refNet

        # The CUDA/CPU choice on the next line is immediately overridden: the
        # model always stays on CPU in this function.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        device = torch.device("cpu")
        model = model.to(device)
        if FOR_upcycle_ckpt_GEN_or_USE:
            del model.ptsM_Generator

    # NOTE: the three nested helpers below are defined but not called anywhere
    # in the visible body of this function.
    def average_module_weight(
        src_modules: list,
    ):
        """Return a state dict holding the element-wise mean of the modules' state dicts.

        Returns None for an empty list. Keys are taken from the first module.
        NOTE(review): the in-place ``/=`` would fail on integer tensors --
        confirm all state-dict entries are floating point.
        """
        if not src_modules:
            return None
        # Get the state dict of the first module as template
        avg_state_dict = {}
        first_state_dict = src_modules[0].state_dict()
        # Initialize with zeros
        for key in first_state_dict:
            avg_state_dict[key] = torch.zeros_like(first_state_dict[key])
        # Sum
        for module in src_modules:
            module_state_dict = module.state_dict()
            for key in avg_state_dict:
                avg_state_dict[key] += module_state_dict[key]
        # Average
        for key in avg_state_dict:
            avg_state_dict[key] /= len(src_modules)
        return avg_state_dict
    def recursive_average_module_weight(
        tgt_module: nn.Module,
        src_modules: list,
        cb,
    ):
        """
        Recursively find modules and replace with averaged weights based on callback.

        For every child of ``tgt_module`` the same-named children of all
        ``src_modules`` are collected. If ``cb(child, name, tgt_module)`` is
        truthy the child loads the averaged weights, otherwise the walk
        descends into that child. Returns ``tgt_module``.
        """
        for name, child in tgt_module.named_children():
            if 1: # Get corresponding modules from source models
                src_child_modules = []
                for src_module in src_modules:
                    src_child = getattr(src_module, name)
                    assert src_child is not None,name
                    src_child_modules.append(src_child)
            # assert not isinstance(child, TaskSpecific_MoE)
            if cb(child, name, tgt_module):
                print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                # Average & load
                avg_weights = average_module_weight(src_child_modules)
                child.load_state_dict(avg_weights)
            else:
                recursive_average_module_weight(child, src_child_modules, cb)
        return tgt_module

    def replace_module_with_TaskSpecific(
        tgt_module: nn.Module,# tgt module
        src_modules: list,
        cb,
        parent_name: str = "",
        depth :int = 0,
    ):
        """Swap selected children of ``tgt_module`` for ``TaskSpecific_MoE`` wrappers.

        A child is replaced when ``cb(child, name, full_name, tgt_module)`` is
        truthy; the wrapper receives the same-named children of all
        ``src_modules`` plus ``TASKS``. Otherwise the walk recurses, but only
        while ``depth <= 0`` (i.e. one extra level). Returns ``tgt_module``.
        """
        for name, child in tgt_module.named_children():
            if 1: # Get corresponding modules from source models
                src_child_modules = []
                for src_module in src_modules:
                    src_child = getattr(src_module, name)
                    assert src_child is not None,name
                    src_child_modules.append(src_child)
            assert not isinstance(child, TaskSpecific_MoE)
            full_name = f"{parent_name}.{name}"
            if cb(child, name, full_name, tgt_module):
                print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
                setattr(tgt_module, name, TaskSpecific_MoE(src_child_modules,TASKS))
            else:
                if depth<=0:
                    replace_module_with_TaskSpecific(child, src_child_modules,cb,parent_name=full_name,depth=depth+1)
        return tgt_module

    if not FOR_upcycle_ckpt_GEN_or_USE:
        modelMOE :LatentDiffusion = model
        del model
        if 1: # ensure distinct module instances per task (avoid shared identities)
            # The JSON content is stored on the shared global_ module.
            with open(PRETRAIN_JSON_PATH, 'r') as f: global_.moduleName_2_adaRank = json.load(f)
            print(f"loaded from {PRETRAIN_JSON_PATH=}")
            # Four independent copies of the UNet, paired with task ids 0..3 below.
            _src0 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src1 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src2 = copy.deepcopy(modelMOE.model.diffusion_model)
            _src3 = copy.deepcopy(modelMOE.model.diffusion_model)
            replace_modules_lossless(
                modelMOE.model.diffusion_model,
                [ _src0, _src1, _src2, _src3 ],
                [0,1,2,3],
                parent_name=".model.diffusion_model",
            )
        # Build-time dummy wrapping for task-specific heads so that ckpt keys match
        modelMOE.ID_proj_out = TaskSpecific_MoE([
            copy.deepcopy(modelMOE.ID_proj_out),
            copy.deepcopy(modelMOE.ID_proj_out),
            copy.deepcopy(modelMOE.ID_proj_out),
        ], [0,2,3])
        modelMOE.landmark_proj_out = TaskSpecific_MoE([
            copy.deepcopy(modelMOE.landmark_proj_out),
            copy.deepcopy(modelMOE.landmark_proj_out),
            copy.deepcopy(modelMOE.landmark_proj_out),
        ], [0,2,3])
        modelMOE.proj_out_source__head = TaskSpecific_MoE([
            copy.deepcopy(modelMOE.proj_out_source__head),
            copy.deepcopy(modelMOE.proj_out_source__head),
        ], [2,3])
        # Upcycle single refNet using three source refNets, and keep only one
        if REFNET.ENABLE:
            shared_ref = modelMOE.model.diffusion_model_refNet
            # src0 is the target itself; src1..src3 are deep copies.
            src0 = shared_ref
            src1 = copy.deepcopy(shared_ref)
            src2 = copy.deepcopy(shared_ref)
            src3 = copy.deepcopy(shared_ref)
            replace_modules_lossless(shared_ref, [src0, src1, src2, src3],[0,1,2,3], parent_name=".model.diffusion_model_refNet", for_refnet=True)
        # load from ./modelMOE.ckpt
        # NOTE(review): no checkpoint load is visible here; the sleep still
        # delays by 20 s * rank_ (rank_ is defined elsewhere, presumably the
        # process rank -- confirm whether the delay is still needed).
        time.sleep(20*rank_)
        print(f"ckpt load over. m,u:")
    # Initialize bank here (after model structure is finalized)
    if REFNET.ENABLE :
        modelMOE.model.bank = Bank(reader=modelMOE.model.diffusion_model,writer=modelMOE.model.diffusion_model_refNet)
    if __name__=='__main__':
        # Debug listing, only when this file is executed directly.
        for key in sorted( get_representative_moduleNames(modelMOE.state_dict().keys()) ):
            print(f" - {key}")
    return modelMOE
+
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
class LambdaWarmUpCosineScheduler:
    """
    Learning-rate multiplier: linear warm-up, then cosine decay.

    During the first ``warm_up_steps`` steps the value rises linearly from
    ``lr_start`` to ``lr_max``; afterwards it follows half a cosine from
    ``lr_max`` down to ``lr_min``, reached at ``max_decay_steps`` and held
    from then on.

    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.  # most recently returned multiplier, shown by the verbose print
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the multiplier for step ``n`` and remember it in ``last_lr``."""
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")

        if n < self.lr_warm_up_steps:
            # Linear ramp from lr_start towards lr_max.
            slope = (self.lr_max - self.lr_start) / self.lr_warm_up_steps
            lr = slope * n + self.lr_start
        else:
            # Fraction of the decay phase already covered, capped at 1.
            progress = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            progress = min(progress, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(progress * np.pi))

        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n,**kwargs)
+
35
+
36
class LambdaWarmUpCosineScheduler2:
    """
    Multi-cycle warm-up + cosine-decay multiplier, configured via lists.

    Every constructor list holds one entry per cycle. Within a cycle the value
    rises linearly from ``f_start`` to ``f_max`` during ``warm_up_steps`` and
    then follows half a cosine down to ``f_min`` at the end of the cycle.

    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cum_cycles[i] is the global step at which cycle i starts.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Return the index of the cycle that contains global step ``n``.

        A step equal to a cycle's end boundary still belongs to that cycle.

        Raises:
            ValueError: if ``n`` lies beyond the last configured cycle.
                Previously this fell through and returned ``None``, which only
                failed later with an unrelated indexing error.
        """
        for interval, cycle_end in enumerate(self.cum_cycles[1:]):
            if n <= cycle_end:
                return interval
        raise ValueError(
            f"step {n} lies beyond the last configured cycle "
            f"(total length {self.cum_cycles[-1]})"
        )

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n`` and store it in ``last_f``."""
        cycle = self.find_in_interval(n)
        # From here on n is the step index relative to the start of its cycle.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            # Fraction of the decay phase covered so far, capped at 1.
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
+
80
+
81
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Variant of the multi-cycle scheduler whose post-warm-up decay is linear."""

    def schedule(self, n, **kwargs):# n is the step index
        """Return the multiplier for global step ``n`` and store it in ``last_f``."""
        cycle = self.find_in_interval(n)
        # Switch to the step index relative to the start of the current cycle.
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
            print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                  f"current cycle {cycle}")

        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            # Linear ramp from f_start towards f_max.
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
        else:
            # Straight line that reaches f_min when n equals the cycle length.
            length = self.cycle_lengths[cycle]
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (length - n) / (length)
        self.last_f = f
        return f
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+
11
+ from ldm.util import instantiate_from_config
12
+
13
+
14
class VQModel(pl.LightningModule):
    """Vector-quantised autoencoder.

    Data path: ``encoder -> quant_conv -> quantize -> post_quant_conv -> decoder``.
    The loss object built from ``lossconfig`` is called with optimizer index 0
    (autoencoder) and 1 (discriminator).
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """Build the sub-modules and optionally restore weights from ``ckpt_path``.

        ``ignore_keys`` lists state-dict key prefixes to drop when restoring.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            # LitEma is not among this module's top-level imports, so it is
            # imported here. NOTE(review): path follows the upstream
            # latent-diffusion layout -- confirm it exists in this repo.
            from ldm.modules.ema import LitEma
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights (no-op when ``use_ema`` is off)."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            # Always put the training weights back, even if the body raised.
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load ``state_dict`` from the checkpoint at ``path`` (non-strict).

        Keys starting with any prefix in ``ignore_keys`` are dropped first.
        """
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # any() deletes each key at most once; the old nested loop raised
            # KeyError when two prefixes matched the same key.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA copy after every training batch."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Return ``(quant, emb_loss, info)`` from the quantizer for input ``x``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Return the latent right before quantization."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode an (already quantized) latent."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Return ``(dec, diff)``, plus the code indices when requested."""
        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a float tensor with axes permuted by (0, 3, 1, 2).

        With ``batch_resize_range`` set, the batch is additionally resized to a
        size drawn in steps of 16 from that range.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            # numpy is not among this module's top-level imports.
            import numpy as np
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                # int(): hand F.interpolate a plain Python int, not a numpy scalar.
                new_resize = int(np.random.choice(np.arange(lower_size, upper_size+16, 16)))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the autoencoder loss (optimizer 0) or discriminator loss (optimizer 1)."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate with the current weights, then again under the EMA weights."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            # Called for its logging side effects; the return value is unused.
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log both validation losses under the split ``"val" + suffix``."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # rec_loss was logged explicitly above, so drop it before log_dict to
        # avoid logging the same key twice. The former guard called
        # ``version.parse`` although ``version`` is not imported here, which
        # made every validation step fail with NameError.
        log_dict_ae.pop(f"val{suffix}/rec_loss", None)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): this returns the bound ``log_dict`` method, not a dict;
        # kept unchanged for compatibility -- confirm nothing relies on it.
        return self.log_dict

    def configure_optimizers(self):
        """Create Adam optimizers for autoencoder and discriminator (+ optional LambdaLR).

        ``self.learning_rate`` is not set in this class and must be assigned
        by the caller beforehand.
        """
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            # LambdaLR is not among this module's top-level imports.
            from torch.optim.lr_scheduler import LambdaLR
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict with inputs and reconstructions (optionally the EMA ones)."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel map to 3 channels and rescale it to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
+
263
+
264
class VQModelInterface(VQModel):
    """VQModel variant that quantizes at decode time instead of encode time."""

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return the pre-quantization latent for ``x``."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Decode ``h``, running it through the quantizer first unless told not to."""
        if force_not_quantize:
            quant = h
        else:
            # also go through quantization layer
            quant, emb_loss, info = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))
+
284
+
285
class AutoencoderKL(pl.LightningModule):
    """KL-regularised autoencoder.

    ``encode`` returns a ``DiagonalGaussianDistribution`` built from the
    ``quant_conv`` output; ``decode`` maps a latent back through
    ``post_quant_conv`` and the decoder.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=(),
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        """Build the sub-modules and optionally restore weights from ``ckpt_path``.

        ``ignore_keys`` lists state-dict key prefixes to drop when restoring.
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load ``state_dict`` from the checkpoint at ``path`` (non-strict).

        Keys starting with any prefix in ``ignore_keys`` are dropped first.
        """
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # any() deletes each key at most once; the old nested loop raised
            # KeyError when two prefixes matched the same key.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Return the posterior distribution for input ``x``."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latent ``z``."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Return ``(reconstruction, posterior)``; uses a sample or the mode of the posterior."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a float tensor with axes permuted by (0, 3, 1, 2)."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the autoencoder loss (optimizer 0) or discriminator loss (optimizer 1)."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Compute and log both validation losses."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): this returns the bound ``log_dict`` method, not a dict;
        # kept unchanged for compatibility -- confirm nothing relies on it.
        return self.log_dict

    def configure_optimizers(self):
        """Create the two Adam optimizers.

        ``self.learning_rate`` is not set in this class and must be assigned
        by the caller beforehand.
        """
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the weight of the decoder's output convolution."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return a dict with inputs and, unless ``only_inputs``, samples and reconstructions."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            # "samples" decodes pure Gaussian noise shaped like a posterior sample.
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel map to 3 channels and rescale it to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
+
425
+
426
class IdentityFirstStage(torch.nn.Module):
    """First-stage stand-in whose encode/decode/forward return the input unchanged."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialise nn.Module before assigning attributes, as PyTorch requires.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``; with ``vq_interface`` set, mimic the quantizer's 3-tuple."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
File without changes
ldm/models/diffusion/bank.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .misc_4ddpm import *
2
+
3
+
4
+ from ldm.modules.attention import BasicTransformerBlock
5
class Bank:
    def __init__(self,reader:nn.Module, writer:nn.Module) -> None:
        """
        Shared name -> data store between a writer network and a reader network.

        Every BasicTransformerBlock of ``writer`` that is not in the skip list
        gets the attributes ``bank``, ``name4bank`` and ``isReader_4bank=False``.
        Blocks of ``reader`` with one of those same names get the same
        attributes with ``isReader_4bank=True``.
        """
        self.name2data = {}
        self.name2count = {} # number of get() calls served per stored name
        self.WHEN_clear_a_field = 2 # an entry is dropped once it was fetched this often
        skip_names = [
            'input_blocks.1.1.transformer_blocks.0',
            'input_blocks.2.1.transformer_blocks.0',
            # 'input_blocks.4.1.transformer_blocks.0',
            # 'input_blocks.5.1.transformer_blocks.0',
            # 'input_blocks.7.1.transformer_blocks.0',
            # 'input_blocks.8.1.transformer_blocks.0',
            ##-----------all middle and output_blocks (everything outside input_blocks)----
            'middle_block.1.transformer_blocks.0',
            'output_blocks.3.1.transformer_blocks.0',
            'output_blocks.4.1.transformer_blocks.0',
            'output_blocks.5.1.transformer_blocks.0',
            'output_blocks.6.1.transformer_blocks.0',
            'output_blocks.7.1.transformer_blocks.0',
            'output_blocks.8.1.transformer_blocks.0',
            'output_blocks.9.1.transformer_blocks.0',
            'output_blocks.10.1.transformer_blocks.0',
            'output_blocks.11.1.transformer_blocks.0',
        ]

        # Names of the writer blocks that take part in the bank.
        l_name = []
        for name, _module in writer.named_modules():
            if not isinstance(_module, BasicTransformerBlock):
                continue
            if DEBUG:
                print(f"{name=}")
            if name in skip_names:
                continue
            _module.bank = self
            _module.name4bank = name
            _module.isReader_4bank = False
            l_name.append(name)

        # Only reader blocks that have a writer counterpart are hooked up.
        for name, _module in reader.named_modules():
            if not isinstance(_module, BasicTransformerBlock):
                continue
            if name not in l_name:
                continue
            _module.bank = self
            _module.name4bank = name
            _module.isReader_4bank = True
    def set(self,name,data):
        """Store ``data`` under ``name`` (the fetch counter is left untouched)."""
        self.name2data[name] = data
    def get(self,name):
        """Return the data stored under ``name``.

        The entry and its counter are removed once it has been fetched
        ``WHEN_clear_a_field`` times. Raises ``Exception`` for unknown names.
        """
        printC('bank get', name)
        if name not in self.name2data:
            raise Exception(f"{name}\n{list(self.name2data.keys())}")
        hits = self.name2count.get(name, 0) + 1
        self.name2count[name] = hits
        data = self.name2data[name]
        if hits >= self.WHEN_clear_a_field: # fetch budget used up, drop the entry
            del self.name2data[name]
            del self.name2count[name]
        return data
    def clear(self,):
        """Report the mean fetch count, then drop all data and counters."""
        printC('clear')
        counts = self.name2count.values()
        mean_ct = sum( counts ) / len( counts ) if len( counts )>0 else 'null'
        printC('mean ct:', mean_ct )
        self.name2data.clear()
        self.name2count.clear()
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
# Registry mapping a conditioning key (the classifier's ``label_key``) to the
# network class instantiated for it in ``load_classifier``.
__models__ = dict(
    class_label=EncoderUNetModel,
    segmentation=UNetModel,
)
20
+
21
+
22
def disabled_train(self, mode=True):
    """No-op replacement for a module's ``train`` method.

    Assigned over ``model.train`` so that later attempts to switch between
    train and eval mode have no effect. ``mode`` is accepted and ignored;
    the first argument is returned unchanged.
    """
    return self
26
+
27
+
28
class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noised inputs produced by a frozen diffusion model.

    The diffusion model is instantiated from the newest ``*-project.yaml``
    found under ``<diffusion_path>/configs`` and frozen. The classifier
    network is picked from ``__models__`` according to ``label_key`` and is
    the only part that receives gradients.
    """

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        """Build the frozen diffusion model and the trainable classifier.

        Args:
            diffusion_path: directory containing ``configs/*-project.yaml``.
            num_classes: number of output channels of the classifier.
            ckpt_path: optional classifier checkpoint to restore from.
            pool: pooling mode, forwarded to the model config only when
                ``label_key == 'class_label'``.
            label_key: batch key of the targets; overridden by the diffusion
                model's ``cond_stage_key`` when that attribute exists.
            diffusion_ckpt_path: written into the diffusion config as its
                ``ckpt_path``.
            scheduler_config: optional config for an LR schedule object.
            weight_decay: weight decay passed to AdamW.
            log_steps: number of timesteps visualised by ``log_images``.
            monitor: stored on the instance as ``self.monitor``.

        Raises:
            NotImplementedError: if the resolved label key is not in
                ``__models__``.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        # Number of 2x downsamplings applied to segmentation targets.
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        # The diffusion model's conditioning key wins over the argument.
        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
        """Restore weights from a checkpoint file (non-strict).

        Args:
            path: checkpoint path readable by ``torch.load``.
            ignore_keys: iterable of key prefixes; every state-dict entry
                whose key starts with one of them is dropped before loading.
            only_model: when true, load into ``self.model`` instead of the
                whole module.
        """
        # NOTE(review): torch.load unpickles arbitrary objects -- only use
        # with trusted checkpoint files.
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in sd:
            sd = sd["state_dict"]

        # A single startswith(tuple) test removes each key at most once, so a
        # key matching several prefixes no longer triggers a second
        # ``del sd[k]`` (which raised KeyError).
        prefixes = tuple(ignore_keys)
        for k in list(sd.keys()):
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]

        target = self.model if only_model else self
        missing, unexpected = target.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model, put it in eval mode and freeze it."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        # Prevent later train()/eval() calls from changing its mode.
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Create the classifier network from the diffusion UNet config.

        The UNet parameters are copied, the input channels are set to the
        UNet's output channels and the output channels to ``num_classes``.
        """
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Return ``x`` noised to timestep ``t`` via the diffusion model's ``q_sample``.

        NOTE(review): relies on the diffusion model exposing
        ``use_continuous_noise``, ``sample_continuous_noise_level`` and a
        ``q_sample`` accepting ``continuous_sqrt_alpha_cumprod`` -- confirm
        against the diffusion class in use.
        """
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        """Run the classifier on a noised input at timestep ``t``."""
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor in ``b c h w`` layout."""
        x = batch[k]
        if len(x.shape) == 3:
            # Add a trailing channel axis to 3-D inputs.
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Return the targets for ``batch`` on this module's device.

        For ``label_key == 'segmentation'`` the targets are moved to
        ``b c h w`` and halved in resolution ``self.numd`` times with
        nearest-neighbour interpolation.
        """
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Count matches between ``labels`` and the top-``k`` entries of ``logits`` along dim 1.

        Args:
            logits: scores with the class axis at dim 1.
            labels: target indices, compared against the top-k indices after
                inserting an axis at dim 1.
            k: number of top entries considered; clamped to the size of
                dim 1 so that e.g. acc@5 does not crash with fewer classes.
            reduction: ``"mean"`` returns a Python float, ``"none"`` returns
                the per-element tensor.

        Raises:
            ValueError: for any other ``reduction`` (previously this
                silently returned ``None``).
        """
        if reduction not in ("mean", "none"):
            raise ValueError(f"Unsupported reduction: {reduction!r} (expected 'mean' or 'none')")

        k = min(k, logits.shape[1])
        _, top_ks = torch.topk(logits, k, dim=1)
        hits = (top_ks == labels[:, None]).float().sum(dim=-1)
        if reduction == "mean":
            return hits.mean().item()
        return hits

    def on_train_epoch_start(self):
        """Move the diffusion model's inner ``model`` to CPU at epoch start."""
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, acc@1, acc@5, global step and learning rate."""
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """Noise the input, classify it and compute the cross-entropy loss.

        Args:
            batch: data batch read by the diffusion model and by
                ``get_conditioning``.
            t: fixed timestep for the whole batch; when ``None`` a timestep
                is drawn uniformly per sample.

        Returns:
            Tuple ``(loss, logits, x_noisy, targets)`` with ``loss`` reduced
            to its mean.
        """
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            # Collapse the channel axis of 4-D targets to class indices.
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        """Reset the per-timestep accuracy buffers used during validation."""
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        """Start every validation run with empty accuracy buffers."""
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and record accuracies at fixed timesteps."""
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        """Create AdamW over the classifier parameters, plus an optional LambdaLR.

        ``self.learning_rate`` is not assigned anywhere in this class; it
        must be set on the instance before this hook runs.
        """
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Collect inputs, labels and per-timestep predictions for image logging.

        Returns:
            Dict of entries, each truncated to its first ``N`` items.
        """
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

            # NOTE(review): nesting reconstructed -- the prediction loop is
            # kept under the map branch because the 'b h w c' rearrange below
            # requires 4-D predictions; confirm against the original file.
            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f'inputs@t{current_time}'] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, 'b h w c -> b c h w')

                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log