Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +9 -0
- Dataset_custom.py +317 -0
- LICENSE +23 -0
- LatentDiffusion.yaml +83 -0
- Mediapipe_Result_Cache.py +36 -0
- MoE.py +141 -0
- Other_dependencies/arcface/add.txt +1 -0
- Other_dependencies/arcface/model_ir_se50.pth +3 -0
- Other_dependencies/face_parsing/79999_iter.pth +3 -0
- Other_dependencies/face_parsing/add.txt +1 -0
- Other_dependencies/mp_models/blaze_face_short_range.tflite +3 -0
- Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task +3 -0
- app.py +239 -0
- checkpoints/pretrained.json +1072 -0
- download_checkpoints.py +29 -0
- eval_tool/lpips/__init__.py +0 -0
- eval_tool/lpips/lpips.py +35 -0
- eval_tool/lpips/networks.py +96 -0
- eval_tool/lpips/utils.py +30 -0
- examples/face/ref-semantic_mask.png +0 -0
- examples/face/ref.png +3 -0
- examples/face/tgt-semantic_mask.png +0 -0
- examples/face/tgt.png +3 -0
- examples/hair/ref-semantic_mask.png +0 -0
- examples/hair/ref.png +3 -0
- examples/hair/tgt-semantic_mask.png +0 -0
- examples/hair/tgt.png +3 -0
- examples/head/ref-semantic_mask.png +0 -0
- examples/head/ref.png +3 -0
- examples/head/tgt-semantic_mask.png +0 -0
- examples/head/tgt.png +3 -0
- examples/inputs.txt +5 -0
- examples/motion/ref-semantic_mask.png +0 -0
- examples/motion/ref.png +3 -0
- examples/motion/tgt-semantic_mask.png +0 -0
- examples/motion/tgt.png +3 -0
- gen_lmk_and_mask.py +41 -0
- gen_semantic_mask.py +90 -0
- get_mask.py +68 -0
- global_.py +9 -0
- hf_model.py +247 -0
- imports.py +8 -0
- infer.py +366 -0
- infer_hf.py +279 -0
- init_model.py +178 -0
- ldm/lr_scheduler.py +99 -0
- ldm/models/autoencoder.py +443 -0
- ldm/models/diffusion/__init__.py +0 -0
- ldm/models/diffusion/bank.py +76 -0
- ldm/models/diffusion/classifier.py +267 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/face/ref.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/face/tgt.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/hair/ref.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/hair/tgt.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/head/ref.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/head/tgt.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/motion/ref.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/motion/tgt.png filter=lfs diff=lfs merge=lfs -text
|
Dataset_custom.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from imports import *
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torch
|
| 7 |
+
import torch.utils.data as data
|
| 8 |
+
import torchvision.transforms as T
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import albumentations
|
| 11 |
+
|
| 12 |
+
from util_face import *
|
| 13 |
+
from util_4dataset import *
|
| 14 |
+
from util_cv2 import cv2_resize_auto_interpolation
|
| 15 |
+
from Mediapipe_Result_Cache import Mediapipe_Result_Cache
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def resize_A(img, dataset_name, size=(512, 512), interpolation=None):
|
| 19 |
+
is_pil = isinstance(img, Image.Image)
|
| 20 |
+
if is_pil:
|
| 21 |
+
img = np.array(img)
|
| 22 |
+
if img.shape[:2] != (512, 512):
|
| 23 |
+
img = cv2_resize_auto_interpolation(img, size, interpolation=interpolation)
|
| 24 |
+
if is_pil:
|
| 25 |
+
img = Image.fromarray(img)
|
| 26 |
+
return img
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def un_norm_clip(x1):
|
| 30 |
+
x = x1 * 1.0
|
| 31 |
+
reduce = False
|
| 32 |
+
if len(x.shape) == 3:
|
| 33 |
+
x = x.unsqueeze(0)
|
| 34 |
+
reduce = True
|
| 35 |
+
x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466
|
| 36 |
+
x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275
|
| 37 |
+
x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073
|
| 38 |
+
if reduce:
|
| 39 |
+
x = x.squeeze(0)
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def un_norm(x):
|
| 44 |
+
return (x + 1.0) / 2.0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _dilate(_mask, kernel_size, iterations):
|
| 48 |
+
_mask = _mask.astype(np.uint8)
|
| 49 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
| 50 |
+
_mask = cv2.dilate(_mask, kernel, iterations=iterations)
|
| 51 |
+
_mask = _mask.astype(bool)
|
| 52 |
+
return _mask
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def dilate_4_task0(sm_mask):
|
| 56 |
+
sm_mask = np.array(sm_mask)
|
| 57 |
+
preserve1 = [2, 3, 10, 5]
|
| 58 |
+
mask1 = np.isin(sm_mask, preserve1)
|
| 59 |
+
mask1 = _dilate(mask1, 7, 1)
|
| 60 |
+
preserve2 = [3, 10]
|
| 61 |
+
mask2 = np.isin(sm_mask, preserve2)
|
| 62 |
+
mask2 = _dilate(mask2, 10, 3)
|
| 63 |
+
preserve3 = [1]
|
| 64 |
+
mask3 = np.isin(sm_mask, preserve3)
|
| 65 |
+
mask3 = _dilate(mask3, 7, 2)
|
| 66 |
+
mask = mask1 | mask2 | mask3
|
| 67 |
+
return mask
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Dataset_custom(data.Dataset):
|
| 71 |
+
mean = (0.5, 0.5, 0.5)
|
| 72 |
+
std = (0.5, 0.5, 0.5)
|
| 73 |
+
|
| 74 |
+
def get_img4clip(
|
| 75 |
+
self,
|
| 76 |
+
img,
|
| 77 |
+
sm_mask,
|
| 78 |
+
preserve,
|
| 79 |
+
for_clip=True,
|
| 80 |
+
add_semantic_head=False,
|
| 81 |
+
mask_after_npisin=None,
|
| 82 |
+
for_inpaint512=False,
|
| 83 |
+
):
|
| 84 |
+
sm_mask = np.array(sm_mask)
|
| 85 |
+
if mask_after_npisin is None:
|
| 86 |
+
if self.task == 0 and 0:
|
| 87 |
+
mask = dilate_4_task0(sm_mask)
|
| 88 |
+
else:
|
| 89 |
+
mask = np.isin(sm_mask, preserve)
|
| 90 |
+
if self.task == 0 and 1 and for_inpaint512:
|
| 91 |
+
forehead_mask = get_forehead_mask(sm_mask)
|
| 92 |
+
mask = mask & ~forehead_mask
|
| 93 |
+
else:
|
| 94 |
+
mask = mask_after_npisin
|
| 95 |
+
|
| 96 |
+
if isinstance(img, np.ndarray):
|
| 97 |
+
img = Image.fromarray(img)
|
| 98 |
+
if add_semantic_head:
|
| 99 |
+
mask_before_colorSM = mask
|
| 100 |
+
img, mask = add_colorSM(img, sm_mask, preserve, None)
|
| 101 |
+
mask = mask_after_npisin__2__tensor(mask)
|
| 102 |
+
|
| 103 |
+
if for_clip:
|
| 104 |
+
image_tensor = get_tensor_clip()(img)
|
| 105 |
+
else:
|
| 106 |
+
image_tensor = get_tensor(mean=self.mean, std=self.std)(img)
|
| 107 |
+
image_tensor = T.Resize([512, 512])(image_tensor)
|
| 108 |
+
image_tensor = image_tensor * mask
|
| 109 |
+
if for_clip:
|
| 110 |
+
image_tensor = 255.0 * rearrange(un_norm_clip(image_tensor), "c h w -> h w c").cpu().numpy()
|
| 111 |
+
_size = 224
|
| 112 |
+
else:
|
| 113 |
+
image_tensor = 255.0 * rearrange(un_norm(image_tensor), "c h w -> h w c").cpu().numpy()
|
| 114 |
+
_size = 512
|
| 115 |
+
|
| 116 |
+
image_tensor = albumentations.Resize(height=_size, width=_size)(image=image_tensor)
|
| 117 |
+
image_tensor = Image.fromarray(image_tensor["image"].astype(np.uint8))
|
| 118 |
+
if for_clip:
|
| 119 |
+
image_tensor = get_tensor_clip()(image_tensor)
|
| 120 |
+
else:
|
| 121 |
+
image_tensor = get_tensor(mean=self.mean, std=self.std)(image_tensor)
|
| 122 |
+
image_tensor = image_tensor * mask
|
| 123 |
+
if add_semantic_head:
|
| 124 |
+
mask = mask_after_npisin__2__tensor(mask_before_colorSM)
|
| 125 |
+
return image_tensor, mask
|
| 126 |
+
|
| 127 |
+
def __init__(
|
| 128 |
+
self,
|
| 129 |
+
state,
|
| 130 |
+
task,
|
| 131 |
+
paths_tgt,
|
| 132 |
+
paths_ref,
|
| 133 |
+
name="custom",
|
| 134 |
+
):
|
| 135 |
+
if task == 0:
|
| 136 |
+
USE_filter_mediapipe_fail_swap = 1
|
| 137 |
+
USE_pts = 1
|
| 138 |
+
READ_mediapipe_result_from_cache = 1
|
| 139 |
+
elif task == 1:
|
| 140 |
+
USE_filter_mediapipe_fail_swap = 0
|
| 141 |
+
USE_pts = 0
|
| 142 |
+
READ_mediapipe_result_from_cache = 1
|
| 143 |
+
elif task == 2:
|
| 144 |
+
USE_filter_mediapipe_fail_swap = 1
|
| 145 |
+
USE_pts = 1
|
| 146 |
+
READ_mediapipe_result_from_cache = 1
|
| 147 |
+
elif task == 3:
|
| 148 |
+
USE_filter_mediapipe_fail_swap = 0
|
| 149 |
+
USE_pts = 1
|
| 150 |
+
READ_mediapipe_result_from_cache = 1
|
| 151 |
+
self.READ_mediapipe_result_from_cache = READ_mediapipe_result_from_cache
|
| 152 |
+
|
| 153 |
+
assert state == "test"
|
| 154 |
+
self.state = state
|
| 155 |
+
self.image_size = 512
|
| 156 |
+
self.kernel = np.ones((1, 1), np.uint8)
|
| 157 |
+
self.name = name
|
| 158 |
+
|
| 159 |
+
assert paths_tgt is not None and paths_ref is not None, "paths_tgt and paths_ref are required"
|
| 160 |
+
assert len(paths_tgt) == len(paths_ref), "paths_tgt and paths_ref must be the same length"
|
| 161 |
+
self.paths_tgt = list(paths_tgt)
|
| 162 |
+
self.paths_ref = list(paths_ref)
|
| 163 |
+
|
| 164 |
+
if READ_mediapipe_result_from_cache:
|
| 165 |
+
self.mediapipe_Result_Cache = Mediapipe_Result_Cache()
|
| 166 |
+
self.task = task
|
| 167 |
+
|
| 168 |
+
def __getitem__(self, index):
|
| 169 |
+
task = self.task
|
| 170 |
+
path_tgt = self.paths_tgt[index]
|
| 171 |
+
path_ref = self.paths_ref[index]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
img_tgt = Image.open(path_tgt).convert("RGB")
|
| 175 |
+
img_tgt = resize_A(img_tgt, self.name)
|
| 176 |
+
|
| 177 |
+
mask_path = path_img_2_path_mask(path_tgt)
|
| 178 |
+
if self.task == 0:
|
| 179 |
+
preserve = [1, 2, 3, 10, 5, 6, 7, 9]
|
| 180 |
+
if 0:
|
| 181 |
+
preserve = [1, 2, 3, 10, 5]
|
| 182 |
+
sm_mask_tgt = Image.open(mask_path).convert("L")
|
| 183 |
+
sm_mask_tgt = np.array(sm_mask_tgt)
|
| 184 |
+
if 0:
|
| 185 |
+
mask_tgt = dilate_4_task0(sm_mask_tgt)
|
| 186 |
+
else:
|
| 187 |
+
mask_tgt = np.isin(sm_mask_tgt, preserve)
|
| 188 |
+
if self.task == 0 and 1:
|
| 189 |
+
forehead_mask = get_forehead_mask(sm_mask_tgt)
|
| 190 |
+
mask_tgt = mask_tgt & ~forehead_mask
|
| 191 |
+
elif self.task == 1:
|
| 192 |
+
preserve = [4]
|
| 193 |
+
mask_tgt = path_img_2_mask(path_tgt, preserve)
|
| 194 |
+
elif self.task == 3:
|
| 195 |
+
preserve = [1, 2, 3, 10, 4, 5, 6, 7, 9]
|
| 196 |
+
mask_tgt = path_img_2_mask(path_tgt, preserve)
|
| 197 |
+
elif self.task == 2:
|
| 198 |
+
preserve = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21]
|
| 199 |
+
sm_mask_tgt = Image.open(mask_path).convert("L")
|
| 200 |
+
sm_mask_tgt = np.array(sm_mask_tgt)
|
| 201 |
+
mask_tgt = np.isin(sm_mask_tgt, preserve)
|
| 202 |
+
|
| 203 |
+
converted_mask = np.zeros_like(mask_tgt)
|
| 204 |
+
converted_mask[mask_tgt] = 255
|
| 205 |
+
mask_tgt = Image.fromarray(converted_mask).convert("L")
|
| 206 |
+
mask_tensor = 1 - get_tensor(normalize=False, toTensor=True)(mask_tgt)
|
| 207 |
+
|
| 208 |
+
image_tensor = get_tensor(mean=self.mean, std=self.std)(img_tgt)
|
| 209 |
+
image_tensor_resize = T.Resize([self.image_size, self.image_size])(image_tensor)
|
| 210 |
+
mask_tensor_resize = T.Resize([self.image_size, self.image_size])(mask_tensor)
|
| 211 |
+
|
| 212 |
+
if task == 2:
|
| 213 |
+
inpaint_tensor_resize = image_tensor_resize
|
| 214 |
+
else:
|
| 215 |
+
inpaint_tensor_resize = image_tensor_resize * mask_tensor_resize
|
| 216 |
+
if 1:
|
| 217 |
+
mask_tensor_resize = 1 - mask_tensor_resize
|
| 218 |
+
|
| 219 |
+
if 1:
|
| 220 |
+
mask_path_ref = path_img_2_path_mask(path_ref)
|
| 221 |
+
sm_mask_ref = Image.open(mask_path_ref).convert("L")
|
| 222 |
+
sm_mask_ref = np.array(sm_mask_ref)
|
| 223 |
+
img_ref = cv2.imread(str(path_ref))
|
| 224 |
+
img_ref = cv2.cvtColor(img_ref, cv2.COLOR_BGR2RGB)
|
| 225 |
+
img_ref = resize_A(img_ref, self.name)
|
| 226 |
+
|
| 227 |
+
if task != 2:
|
| 228 |
+
ref_image_tensor, ref_mask_tensor = self.get_img4clip(
|
| 229 |
+
img_ref, sm_mask_ref, preserve, for_clip=True, add_semantic_head=0
|
| 230 |
+
)
|
| 231 |
+
if task == 3:
|
| 232 |
+
ref_image_faceOnly_tensor, _ = self.get_img4clip(
|
| 233 |
+
img_ref,
|
| 234 |
+
sm_mask_ref,
|
| 235 |
+
[1, 2, 3, 10, 5, 6, 7, 9],
|
| 236 |
+
for_clip=False,
|
| 237 |
+
add_semantic_head=0,
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
ref_image_tensor = inpaint_tensor_resize
|
| 241 |
+
|
| 242 |
+
ret = {
|
| 243 |
+
"inpaint_image": inpaint_tensor_resize,
|
| 244 |
+
"inpaint_mask": mask_tensor_resize,
|
| 245 |
+
"ref_imgs": ref_image_tensor,
|
| 246 |
+
"task": self.task,
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
if self.task == 0:
|
| 250 |
+
ret["enInputs"] = {
|
| 251 |
+
"face_ID-in": ref_image_tensor,
|
| 252 |
+
"face-clip-in": ref_image_tensor,
|
| 253 |
+
}
|
| 254 |
+
elif self.task == 1:
|
| 255 |
+
ret["enInputs"] = {
|
| 256 |
+
"hair-clip-in": ref_image_tensor,
|
| 257 |
+
}
|
| 258 |
+
elif self.task == 2:
|
| 259 |
+
tgt_nonBg_tensor, _ = self.get_img4clip(img_tgt, sm_mask_tgt, preserve)
|
| 260 |
+
ret["enInputs"] = {
|
| 261 |
+
"face_ID-in": tgt_nonBg_tensor,
|
| 262 |
+
"head-clip-in": tgt_nonBg_tensor,
|
| 263 |
+
}
|
| 264 |
+
elif self.task == 3:
|
| 265 |
+
ret["enInputs"] = {
|
| 266 |
+
"face_ID-in": ref_image_faceOnly_tensor,
|
| 267 |
+
"head-clip-in": ref_image_tensor,
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
if (REFNET.ENABLE and REFNET.task2layerNum[task] > 0) or CH14:
|
| 271 |
+
if task != 2:
|
| 272 |
+
ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
|
| 273 |
+
img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
|
| 274 |
+
)
|
| 275 |
+
else:
|
| 276 |
+
ref_imgs_4unet, ref_mask_4unet = self.get_img4clip(
|
| 277 |
+
img_tgt,
|
| 278 |
+
sm_mask_tgt,
|
| 279 |
+
"any",
|
| 280 |
+
for_clip=False,
|
| 281 |
+
add_semantic_head=0,
|
| 282 |
+
mask_after_npisin=np.ones_like(sm_mask_tgt).astype(bool),
|
| 283 |
+
)
|
| 284 |
+
ref_imgs_4unet = T.Resize([self.image_size, self.image_size])(ref_imgs_4unet)
|
| 285 |
+
ref_mask_512 = T.Resize([self.image_size, self.image_size])(ref_mask_4unet)
|
| 286 |
+
ret["ref_imgs_4unet"] = ref_imgs_4unet
|
| 287 |
+
ret["ref_mask_512"] = ref_mask_512
|
| 288 |
+
|
| 289 |
+
if self.READ_mediapipe_result_from_cache:
|
| 290 |
+
if self.state == "test":
|
| 291 |
+
if task == 2:
|
| 292 |
+
_p_lmk = path_ref
|
| 293 |
+
else:
|
| 294 |
+
_p_lmk = path_tgt
|
| 295 |
+
else:
|
| 296 |
+
_p_lmk = path_tgt
|
| 297 |
+
ret["mediapipe_lmkAll"] = self.mediapipe_Result_Cache.get(_p_lmk)
|
| 298 |
+
if ret["mediapipe_lmkAll"] is None:
|
| 299 |
+
raise RuntimeError(
|
| 300 |
+
f"Missing Mediapipe cache for input image: {_p_lmk}. "
|
| 301 |
+
"Precompute landmarks and ensure cache exists before inference."
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if self.state == "test":
|
| 305 |
+
prior_image_tensor = "None"
|
| 306 |
+
out_stem = f"{Path(path_tgt).stem}-{Path(path_ref).stem}"
|
| 307 |
+
if task == 2:
|
| 308 |
+
ref512, _ = self.get_img4clip(
|
| 309 |
+
img_ref, sm_mask_ref, preserve, for_clip=False, add_semantic_head=0
|
| 310 |
+
)
|
| 311 |
+
ref512 = T.Resize([self.image_size, self.image_size])(ref512)
|
| 312 |
+
ret["ref512"] = ref512
|
| 313 |
+
ret = (image_tensor_resize, prior_image_tensor, ret, out_stem)
|
| 314 |
+
return ret
|
| 315 |
+
|
| 316 |
+
def __len__(self):
|
| 317 |
+
return len(self.paths_tgt)
|
LICENSE
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Sanoojan
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
| 23 |
+
|
LatentDiffusion.yaml
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
base_learning_rate: 4.0e-04
|
| 3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
| 4 |
+
params:
|
| 5 |
+
linear_start: 0.00085
|
| 6 |
+
linear_end: 0.0120
|
| 7 |
+
num_timesteps_cond: 1
|
| 8 |
+
log_every_t: 200
|
| 9 |
+
timesteps: 1000
|
| 10 |
+
first_stage_key: "inpaint"
|
| 11 |
+
cond_stage_key: "image"
|
| 12 |
+
image_size: 64
|
| 13 |
+
channels: 4
|
| 14 |
+
cond_stage_trainable: true # Note: different from the one we trained before
|
| 15 |
+
conditioning_key: crossattn
|
| 16 |
+
monitor: val/loss_simple_ema
|
| 17 |
+
u_cond_percent: 0.2
|
| 18 |
+
scale_factor: 0.18215
|
| 19 |
+
use_ema: False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
scheduler_config: # 10000 warmup steps
|
| 23 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
| 24 |
+
params:
|
| 25 |
+
warm_up_steps: [ 10000 ]
|
| 26 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
| 27 |
+
f_start: [ 1.e-1 ]
|
| 28 |
+
f_max: [ 1. ]
|
| 29 |
+
f_min: [ 1. ]
|
| 30 |
+
|
| 31 |
+
unet_config:
|
| 32 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
| 33 |
+
params:
|
| 34 |
+
image_size: 32 # unused
|
| 35 |
+
out_channels: 4
|
| 36 |
+
model_channels: 320
|
| 37 |
+
attention_resolutions: [ 4, 2, 1 ]
|
| 38 |
+
num_res_blocks: 2
|
| 39 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
| 40 |
+
num_heads: 8
|
| 41 |
+
use_spatial_transformer: True
|
| 42 |
+
transformer_depth: 1
|
| 43 |
+
context_dim: 768
|
| 44 |
+
use_checkpoint: True
|
| 45 |
+
legacy: False
|
| 46 |
+
add_conv_in_front_of_unet: False
|
| 47 |
+
|
| 48 |
+
first_stage_config:
|
| 49 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
| 50 |
+
params:
|
| 51 |
+
embed_dim: 4
|
| 52 |
+
monitor: val/rec_loss
|
| 53 |
+
ddconfig:
|
| 54 |
+
double_z: true
|
| 55 |
+
z_channels: 4
|
| 56 |
+
resolution: 256
|
| 57 |
+
in_channels: 3
|
| 58 |
+
out_ch: 3
|
| 59 |
+
ch: 128
|
| 60 |
+
ch_mult:
|
| 61 |
+
- 1
|
| 62 |
+
- 2
|
| 63 |
+
- 4
|
| 64 |
+
- 4
|
| 65 |
+
num_res_blocks: 2
|
| 66 |
+
attn_resolutions: []
|
| 67 |
+
dropout: 0.0
|
| 68 |
+
lossconfig:
|
| 69 |
+
target: torch.nn.Identity
|
| 70 |
+
|
| 71 |
+
cond_stage_config:
|
| 72 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
| 73 |
+
other_params:
|
| 74 |
+
clip_weight: 1.0
|
| 75 |
+
arcface_path: "Other_dependencies/arcface/model_ir_se50.pth"
|
| 76 |
+
multi_scale_ID: False # True was used for the previous training there is an issue
|
| 77 |
+
Additional_config:
|
| 78 |
+
Reconstruct_initial: False # scy:
|
| 79 |
+
Target_CLIP_feat: True
|
| 80 |
+
Source_CLIP_feat: True
|
| 81 |
+
Reconstruct_DDIM_steps: 4
|
| 82 |
+
|
| 83 |
+
|
Mediapipe_Result_Cache.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from imports import *
|
| 2 |
+
import json,random,os
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Mediapipe_Result_Cache:
|
| 8 |
+
"""
|
| 9 |
+
Convention: when a cache entry exists, it must not be None.
|
| 10 |
+
In other words, None results should not be cached; get/set guard against historical None values.
|
| 11 |
+
"""
|
| 12 |
+
# DIR = Path('/inspurfs/group/mayuexin/suncy/mediapipe_result/A')
|
| 13 |
+
DIR = Path('data/mediapipe_result')
|
| 14 |
+
def __init__(self):
|
| 15 |
+
pass
|
| 16 |
+
def get_path(self, img_path):
|
| 17 |
+
img_path = Path(img_path)
|
| 18 |
+
str_img_folder = str(img_path.parent)
|
| 19 |
+
assert '|' not in str_img_folder
|
| 20 |
+
str_img_folder = str_img_folder.replace('/', '|')
|
| 21 |
+
lmk_folder = self.DIR / str_img_folder
|
| 22 |
+
lmk_folder.mkdir(parents=1, exist_ok=True)
|
| 23 |
+
ret= lmk_folder / (img_path.name+'.npy')
|
| 24 |
+
return ret
|
| 25 |
+
def get(self, img_path):
|
| 26 |
+
path = self.get_path(img_path)
|
| 27 |
+
# print(f"[get] {path=}")
|
| 28 |
+
if path.exists():
|
| 29 |
+
ret = np.load(path)
|
| 30 |
+
assert ret is not None
|
| 31 |
+
return ret
|
| 32 |
+
def set(self, img_path, lmks):
|
| 33 |
+
assert lmks is not None
|
| 34 |
+
path = self.get_path(img_path)
|
| 35 |
+
np.save(path, lmks)
|
| 36 |
+
# print(f"{path=}")
|
MoE.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from imports import *
|
| 2 |
+
import global_
|
| 3 |
+
import torch,copy
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from ldm.modules.attention import FeedForward,CrossAttention
|
| 6 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel,ResBlock,TimestepEmbedSequential
|
| 7 |
+
# import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
# ---------------- Configs ----------------
|
| 10 |
+
CONV2D_PARAM_STATS = []
|
| 11 |
+
|
| 12 |
+
def average_module_weight(src_modules: list):
|
| 13 |
+
"""Average the weights of multiple modules (similar to init_model.py)."""
|
| 14 |
+
if not src_modules:
|
| 15 |
+
return None
|
| 16 |
+
avg_state_dict = {}
|
| 17 |
+
first_state_dict = src_modules[0].state_dict()
|
| 18 |
+
for key in first_state_dict:
|
| 19 |
+
avg_state_dict[key] = torch.zeros_like(first_state_dict[key])
|
| 20 |
+
for module in src_modules:
|
| 21 |
+
module_state_dict = module.state_dict()
|
| 22 |
+
for key in avg_state_dict:
|
| 23 |
+
avg_state_dict[key] += module_state_dict[key]
|
| 24 |
+
for key in avg_state_dict:
|
| 25 |
+
avg_state_dict[key] /= len(src_modules)
|
| 26 |
+
return avg_state_dict
|
| 27 |
+
|
| 28 |
+
class ModuleDict_W(nn.Module): # Wrapper of ModuleDict
|
| 29 |
+
def __init__(self, modules: list, keys: list):
|
| 30 |
+
super().__init__()
|
| 31 |
+
assert len(keys) == len(modules), f"{len(keys)=} {len(modules)=}"
|
| 32 |
+
self._keys = [int(k) for k in keys]
|
| 33 |
+
self._moduleDict = nn.ModuleDict({str(int(k)): m for k, m in zip(self._keys, modules)})
|
| 34 |
+
def __getitem__(self, k: int):
|
| 35 |
+
_k = str(int(k))
|
| 36 |
+
return self._moduleDict[_k]
|
| 37 |
+
def keys(self):
|
| 38 |
+
return list(self._keys)
|
| 39 |
+
def forward(self, *args, **kwargs):
|
| 40 |
+
cur_task = global_.task
|
| 41 |
+
assert cur_task in self._keys, f"Current task {cur_task} not in available tasks {self._keys}"
|
| 42 |
+
return self._moduleDict[str(int(cur_task))](*args, **kwargs)
|
| 43 |
+
def offload_unused_tasks(self, unused_tasks, method: str):
|
| 44 |
+
for i in unused_tasks:
|
| 45 |
+
_k = str(int(i))
|
| 46 |
+
if _k in self._moduleDict:
|
| 47 |
+
if method == 'del':
|
| 48 |
+
# self._moduleDict[_k] = None # should behave the same either way
|
| 49 |
+
del self._moduleDict[_k]
|
| 50 |
+
elif method == 'cpu':
|
| 51 |
+
self._moduleDict[_k].to('cpu')
|
| 52 |
+
else:
|
| 53 |
+
raise
|
| 54 |
+
|
| 55 |
+
class TaskSpecific_MoE(nn.Module):
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
module:nn.Module,# or list of Module
|
| 59 |
+
tasks:tuple,
|
| 60 |
+
):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.cur_task = None
|
| 63 |
+
self.tasks = tasks
|
| 64 |
+
if isinstance(module, nn.Module):
|
| 65 |
+
modules = [copy.deepcopy(module) for _ in self.tasks]
|
| 66 |
+
elif isinstance(module, list):
|
| 67 |
+
assert len(module) == len(self.tasks), f"got {len(module)} and {len(self.tasks)}"
|
| 68 |
+
modules = module
|
| 69 |
+
else:
|
| 70 |
+
raise ValueError(f"got {type(module)}")
|
| 71 |
+
self.tasks_2_module = ModuleDict_W(modules, self.tasks)
|
| 72 |
+
|
| 73 |
+
def forward(self, *args, **kwargs) -> torch.Tensor:
|
| 74 |
+
# cur_task = self.cur_task
|
| 75 |
+
cur_task = global_.task
|
| 76 |
+
assert cur_task in self.tasks, f"Current task {cur_task} not in available tasks {self.tasks}"
|
| 77 |
+
return self.tasks_2_module[cur_task](*args, **kwargs)
|
| 78 |
+
|
| 79 |
+
def set_task(self, task):
|
| 80 |
+
assert 0, 'set_task is disabled for now; update to gg.task instead'
|
| 81 |
+
# assert task in self.tasks, f"Task {task} not in available tasks {self.tasks}"
|
| 82 |
+
self.cur_task = task
|
| 83 |
+
|
| 84 |
+
def is_task_specific_(name:str):
|
| 85 |
+
is_task_specific = (
|
| 86 |
+
('._moduleDict.' in name) or
|
| 87 |
+
('tasks_2_module' in name) or
|
| 88 |
+
('task_ffn' in name) or
|
| 89 |
+
('task_proj' in name) or
|
| 90 |
+
('task_conv' in name) or
|
| 91 |
+
('task_gate_mlps' in name) or
|
| 92 |
+
('task_lora' in name) or
|
| 93 |
+
|
| 94 |
+
('encoder_clip_' in name) or
|
| 95 |
+
('proj_out_source__' in name) or
|
| 96 |
+
('ID_proj_out' in name) or
|
| 97 |
+
('landmark_proj_out' in name) or
|
| 98 |
+
('learnable_vector' in name)
|
| 99 |
+
)
|
| 100 |
+
return is_task_specific
|
| 101 |
+
def tp_param_need_sync(name: str, p: torch.nn.Parameter):
|
| 102 |
+
if is_task_specific_(name):
|
| 103 |
+
return False, True
|
| 104 |
+
if 'first_stage_model' in name or 'face_ID_model' in name or 'encoder_clip_face.tokenizer' in name or 'encoder_clip_face.model' in name:
|
| 105 |
+
return False, False
|
| 106 |
+
if not p.requires_grad:
|
| 107 |
+
return False, False
|
| 108 |
+
return True, False
|
| 109 |
+
def offload_unused_tasks(parent: nn.Module, active_task: int, method: str, ):
|
| 110 |
+
unused_tasks = [_t for _t in TASKS if _t != active_task] # inactive tasks
|
| 111 |
+
for name, child in parent.named_children():
|
| 112 |
+
if hasattr(child, '__class__') and child.__class__.__name__ in [
|
| 113 |
+
'TaskSpecific_MoE',
|
| 114 |
+
'FFN_TaskSpecific_Plus_Shared',
|
| 115 |
+
'Linear_TaskSpecific_Plus_Shared',
|
| 116 |
+
'Conv_TaskSpecific_Plus_Shared',
|
| 117 |
+
'FFN_Shared_Plus_TaskLoRA',
|
| 118 |
+
'Linear_Shared_Plus_TaskLoRA',
|
| 119 |
+
'Conv_Shared_Plus_TaskLoRA',
|
| 120 |
+
]:
|
| 121 |
+
for attr_name in [ # normalize attribute handling to avoid repetition
|
| 122 |
+
'tasks_2_module',
|
| 123 |
+
'task_ffn', 'task_proj', 'task_conv',
|
| 124 |
+
'task_lora_in', 'task_lora_out', 'task_lora',
|
| 125 |
+
]:
|
| 126 |
+
if hasattr(child, attr_name):
|
| 127 |
+
ml = getattr(child, attr_name)
|
| 128 |
+
if isinstance(ml, nn.ModuleList):
|
| 129 |
+
for i in unused_tasks: # move or delete parameters for inactive tasks
|
| 130 |
+
if method == 'del':
|
| 131 |
+
ml[i] = None
|
| 132 |
+
elif method == 'cpu':
|
| 133 |
+
ml[i].to('cpu')
|
| 134 |
+
else: raise Exception
|
| 135 |
+
elif isinstance(ml, ModuleDict_W):
|
| 136 |
+
ml.offload_unused_tasks(unused_tasks,method)
|
| 137 |
+
# recurse(child)
|
| 138 |
+
else: offload_unused_tasks(child,active_task,method)
|
| 139 |
+
def offload_unused_tasks__LD(modelMOE, task_keep: int, method: str, ):
|
| 140 |
+
# Remove or offload inactive task-related parameters to save CUDA memory (method: del|cpu)
|
| 141 |
+
offload_unused_tasks(modelMOE, task_keep, method)
|
Other_dependencies/arcface/add.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Add arcface model
|
Other_dependencies/arcface/model_ir_se50.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a035c768259b98ab1ce0e646312f48b9e1e218197a0f80ac6765e88f8b6ddf28
|
| 3 |
+
size 175367323
|
Other_dependencies/face_parsing/79999_iter.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
|
| 3 |
+
size 53289463
|
Other_dependencies/face_parsing/add.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Add face parsing model
|
Other_dependencies/mp_models/blaze_face_short_range.tflite
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
|
| 3 |
+
size 229746
|
Other_dependencies/mp_models/face_landmarker_v2_with_blendshapes.task
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
|
| 3 |
+
size 3758596
|
app.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Space demo for UniBioTransfer.
|
| 3 |
+
Gradio interface for face/hair/motion/head transfer.
|
| 4 |
+
|
| 5 |
+
ZeroGPU Compatible:
|
| 6 |
+
- Model initialized on CPU (no GPU memory during startup)
|
| 7 |
+
- Inference wrapped with @spaces.GPU decorator
|
| 8 |
+
- Thread-safe global variable access with Lock
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import threading
|
| 12 |
+
import torch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
# ==========================================
|
| 17 |
+
# 兼容层:处理本地测试 vs HF ZeroGPU 环境
|
| 18 |
+
# ==========================================
|
| 19 |
+
try:
|
| 20 |
+
import spaces
|
| 21 |
+
print("Detected spaces library (Hugging Face environment).")
|
| 22 |
+
except ImportError:
|
| 23 |
+
print("Local environment detected. Mocking spaces.GPU...")
|
| 24 |
+
class spaces:
|
| 25 |
+
@staticmethod
|
| 26 |
+
def GPU(func):
|
| 27 |
+
return func # 本地测试时,装饰器变为空壳,直接执行原函数
|
| 28 |
+
|
| 29 |
+
from infer_hf import UniBioTransferPipeline
|
| 30 |
+
|
| 31 |
+
# 锁和全局单例 Pipeline
|
| 32 |
+
inference_lock = threading.Lock()
|
| 33 |
+
global_pipeline :UniBioTransferPipeline = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_pipeline(task):
|
| 37 |
+
"""
|
| 38 |
+
单例模式:全局只初始化一次模型(放在 CPU),后续只切换任务。
|
| 39 |
+
强制写死 CPU,保证 ZeroGPU 全局初始化时不碰显卡。
|
| 40 |
+
"""
|
| 41 |
+
global global_pipeline
|
| 42 |
+
if global_pipeline is None:
|
| 43 |
+
print("Initializing pipeline once on CPU...")
|
| 44 |
+
# 强制写死 CPU,保证 ZeroGPU 全局初始化时不碰显卡
|
| 45 |
+
global_pipeline = UniBioTransferPipeline.from_pretrained(
|
| 46 |
+
repo_id="scy639/UniBioTransfer",
|
| 47 |
+
task=task,
|
| 48 |
+
device="cpu",
|
| 49 |
+
)
|
| 50 |
+
else:
|
| 51 |
+
# 如果模型已经在内存中,只需切换 task ID 即可
|
| 52 |
+
print(f"Switching existing pipeline to task: {task}")
|
| 53 |
+
global_pipeline.set_task(task)
|
| 54 |
+
return global_pipeline
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# 核心:将所有会用到 GPU 的前向推理逻辑包裹在这里
|
| 58 |
+
@spaces.GPU
|
| 59 |
+
def run_gpu_inference(pipeline:UniBioTransferPipeline, tgt_pil, ref_pil, ddim_steps, scale, seed, num_images):
|
| 60 |
+
"""
|
| 61 |
+
这里是 ZeroGPU 分配算力的地方。进入此函数时可以安全地 to("cuda")。
|
| 62 |
+
如果是在本地服务器,这个装饰器没用,但内部的 .to("cuda") 同样生效。
|
| 63 |
+
"""
|
| 64 |
+
return pipeline(
|
| 65 |
+
tgt_pil,
|
| 66 |
+
ref_pil,
|
| 67 |
+
ddim_steps=ddim_steps,
|
| 68 |
+
scale=scale,
|
| 69 |
+
seed=seed,
|
| 70 |
+
num_images=num_images,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def inference(task, tgt_img, ref_img, ddim_steps, seed, num_images):
|
| 75 |
+
"""
|
| 76 |
+
Run inference for the demo.
|
| 77 |
+
"""
|
| 78 |
+
if tgt_img is None or ref_img is None:
|
| 79 |
+
return None, "Please upload both target and reference images."
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
# 1. 拿模型 (此时模型在 CPU)
|
| 83 |
+
pipeline = get_pipeline(task)
|
| 84 |
+
|
| 85 |
+
tgt_pil = Image.fromarray(tgt_img).convert("RGB")
|
| 86 |
+
ref_pil = Image.fromarray(ref_img).convert("RGB")
|
| 87 |
+
|
| 88 |
+
# 2. 加锁,防止并发污染 global_.task,进入 GPU 推理
|
| 89 |
+
with inference_lock:
|
| 90 |
+
results = run_gpu_inference(
|
| 91 |
+
pipeline,
|
| 92 |
+
tgt_pil,
|
| 93 |
+
ref_pil,
|
| 94 |
+
int(ddim_steps),
|
| 95 |
+
float(3),
|
| 96 |
+
int(seed),
|
| 97 |
+
int(num_images)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return results, f"Success! Task: {task} transfer completed."
|
| 101 |
+
|
| 102 |
+
except Exception as e:
|
| 103 |
+
import traceback
|
| 104 |
+
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
|
| 105 |
+
print(f"{error_msg}")
|
| 106 |
+
return None, error_msg
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def create_demo():
|
| 110 |
+
"""Create Gradio demo interface."""
|
| 111 |
+
import gradio as gr
|
| 112 |
+
|
| 113 |
+
with gr.Blocks(title="UniBioTransfer") as demo:
|
| 114 |
+
gr.Markdown(
|
| 115 |
+
"""
|
| 116 |
+
# UniBioTransfer
|
| 117 |
+
|
| 118 |
+
Perform face transfer, hair transfer, motion transfer (face reenactment), and head transfer.
|
| 119 |
+
|
| 120 |
+
- **Face Transfer**: Transfer face identity from reference to target
|
| 121 |
+
- **Hair Transfer**: Transfer hairstyle from reference to target
|
| 122 |
+
- **Motion Transfer**: Transfer motion(expression+head pose) from reference to target
|
| 123 |
+
- **Head Transfer**: Transfer entire head from reference to target
|
| 124 |
+
|
| 125 |
+
[Code](https://github.com/scy639/UniBioTransfer)
|
| 126 |
+
[Project Page](https://scy639.github.io/UniBioTransfer.github.io/)
|
| 127 |
+
[Paper](https://arxiv.org/abs/2603.19637)
|
| 128 |
+
"""
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
with gr.Row():
|
| 132 |
+
with gr.Column():
|
| 133 |
+
task_dropdown = gr.Dropdown(
|
| 134 |
+
choices=["face", "hair", "motion", "head"],
|
| 135 |
+
value="face",
|
| 136 |
+
label="Task",
|
| 137 |
+
info="Select the transfer type",
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
with gr.Row():
|
| 141 |
+
tgt_image = gr.Image(
|
| 142 |
+
label="Target Image",
|
| 143 |
+
type="numpy",
|
| 144 |
+
height=300,
|
| 145 |
+
)
|
| 146 |
+
ref_image = gr.Image(
|
| 147 |
+
label="Reference Image",
|
| 148 |
+
type="numpy",
|
| 149 |
+
height=300,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
with gr.Row():
|
| 153 |
+
ddim_steps = gr.Slider(
|
| 154 |
+
minimum=4,
|
| 155 |
+
maximum=50,
|
| 156 |
+
value=50,
|
| 157 |
+
step=1,
|
| 158 |
+
label="DDIM Steps",
|
| 159 |
+
info="More steps = better quality but slower",
|
| 160 |
+
)
|
| 161 |
+
# scale = gr.Slider(
|
| 162 |
+
# minimum=1.0,
|
| 163 |
+
# maximum=10.0,
|
| 164 |
+
# value=3.0,
|
| 165 |
+
# step=0.5,
|
| 166 |
+
# label="CFG Scale",
|
| 167 |
+
# info="Guidance scale for conditioning",
|
| 168 |
+
# )
|
| 169 |
+
|
| 170 |
+
seed = gr.Number(
|
| 171 |
+
value=42,
|
| 172 |
+
label="Random Seed",
|
| 173 |
+
info="For reproducibility",
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
num_images = gr.Slider(
|
| 177 |
+
minimum=1,
|
| 178 |
+
maximum=32,
|
| 179 |
+
value=4,
|
| 180 |
+
step=1,
|
| 181 |
+
label="Number of output images",
|
| 182 |
+
info="Multi-output with different initial noise",
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
run_btn = gr.Button("Run Inference", variant="primary")
|
| 186 |
+
|
| 187 |
+
with gr.Column():
|
| 188 |
+
output_gallery = gr.Gallery(
|
| 189 |
+
label="Results",
|
| 190 |
+
height=800,
|
| 191 |
+
columns=2,
|
| 192 |
+
)
|
| 193 |
+
status_text = gr.Textbox(
|
| 194 |
+
label="Status",
|
| 195 |
+
lines=3,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
gr.Markdown(
|
| 199 |
+
"""
|
| 200 |
+
### Usage
|
| 201 |
+
1. Upload a **target image** (the person whose face/hair/motion/head will be modified)
|
| 202 |
+
2. Upload a **reference image** (the source of the attribute to transfer)
|
| 203 |
+
3. Select the **task** type
|
| 204 |
+
4. Click "Run Inference"
|
| 205 |
+
|
| 206 |
+
### Requirements
|
| 207 |
+
- Works best when the heads in the two input images have similar sizes.
|
| 208 |
+
"""
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
run_btn.click(
|
| 212 |
+
fn=inference,
|
| 213 |
+
inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images],
|
| 214 |
+
outputs=[output_gallery, status_text],
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
task_dropdown.change(
|
| 218 |
+
fn=lambda t: f"Task switched to: {t} transfer",
|
| 219 |
+
inputs=[task_dropdown],
|
| 220 |
+
outputs=[status_text],
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
gr.Examples(
|
| 224 |
+
examples=[
|
| 225 |
+
["face", "examples/face/tgt.png", "examples/face/ref.png", 20, 42, 4],
|
| 226 |
+
["hair", "examples/hair/tgt.png", "examples/hair/ref.png", 20, 42, 4],
|
| 227 |
+
["motion", "examples/motion/tgt.png", "examples/motion/ref.png", 20, 42, 4],
|
| 228 |
+
["head", "examples/head/tgt.png", "examples/head/ref.png", 20, 42, 4],
|
| 229 |
+
],
|
| 230 |
+
inputs=[task_dropdown, tgt_image, ref_image, ddim_steps, seed, num_images],
|
| 231 |
+
label="Examples",
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
return demo
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
demo = create_demo()
|
| 239 |
+
demo.launch()
|
checkpoints/pretrained.json
ADDED
|
@@ -0,0 +1,1072 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
".model.diffusion_model.input_blocks.0.0": [
|
| 3 |
+
4,
|
| 4 |
+
4,
|
| 5 |
+
4,
|
| 6 |
+
4
|
| 7 |
+
],
|
| 8 |
+
".model.diffusion_model.input_blocks.1.0.in_layers.2": [
|
| 9 |
+
5,
|
| 10 |
+
4,
|
| 11 |
+
8,
|
| 12 |
+
4
|
| 13 |
+
],
|
| 14 |
+
".model.diffusion_model.input_blocks.1.0.out_layers.3": [
|
| 15 |
+
7,
|
| 16 |
+
4,
|
| 17 |
+
12,
|
| 18 |
+
4
|
| 19 |
+
],
|
| 20 |
+
".model.diffusion_model.input_blocks.1.1.proj_in": [
|
| 21 |
+
4,
|
| 22 |
+
4,
|
| 23 |
+
6,
|
| 24 |
+
4
|
| 25 |
+
],
|
| 26 |
+
".model.diffusion_model.input_blocks.1.1.transformer_blocks.0.ff": [
|
| 27 |
+
[
|
| 28 |
+
5,
|
| 29 |
+
4,
|
| 30 |
+
8,
|
| 31 |
+
4
|
| 32 |
+
],
|
| 33 |
+
[
|
| 34 |
+
7,
|
| 35 |
+
4,
|
| 36 |
+
12,
|
| 37 |
+
4
|
| 38 |
+
]
|
| 39 |
+
],
|
| 40 |
+
".model.diffusion_model.input_blocks.1.1.proj_out": [
|
| 41 |
+
4,
|
| 42 |
+
4,
|
| 43 |
+
8,
|
| 44 |
+
4
|
| 45 |
+
],
|
| 46 |
+
".model.diffusion_model.input_blocks.2.0.in_layers.2": [
|
| 47 |
+
14,
|
| 48 |
+
5,
|
| 49 |
+
19,
|
| 50 |
+
4
|
| 51 |
+
],
|
| 52 |
+
".model.diffusion_model.input_blocks.2.0.out_layers.3": [
|
| 53 |
+
16,
|
| 54 |
+
4,
|
| 55 |
+
15,
|
| 56 |
+
4
|
| 57 |
+
],
|
| 58 |
+
".model.diffusion_model.input_blocks.2.1.proj_in": [
|
| 59 |
+
9,
|
| 60 |
+
4,
|
| 61 |
+
11,
|
| 62 |
+
4
|
| 63 |
+
],
|
| 64 |
+
".model.diffusion_model.input_blocks.2.1.transformer_blocks.0.ff": [
|
| 65 |
+
[
|
| 66 |
+
16,
|
| 67 |
+
4,
|
| 68 |
+
14,
|
| 69 |
+
4
|
| 70 |
+
],
|
| 71 |
+
[
|
| 72 |
+
17,
|
| 73 |
+
4,
|
| 74 |
+
14,
|
| 75 |
+
4
|
| 76 |
+
]
|
| 77 |
+
],
|
| 78 |
+
".model.diffusion_model.input_blocks.2.1.proj_out": [
|
| 79 |
+
13,
|
| 80 |
+
4,
|
| 81 |
+
11,
|
| 82 |
+
4
|
| 83 |
+
],
|
| 84 |
+
".model.diffusion_model.input_blocks.3.0.op": [
|
| 85 |
+
26,
|
| 86 |
+
7,
|
| 87 |
+
31,
|
| 88 |
+
8
|
| 89 |
+
],
|
| 90 |
+
".model.diffusion_model.input_blocks.4.0.in_layers.2": [
|
| 91 |
+
23,
|
| 92 |
+
6,
|
| 93 |
+
31,
|
| 94 |
+
8
|
| 95 |
+
],
|
| 96 |
+
".model.diffusion_model.input_blocks.4.0.out_layers.3": [
|
| 97 |
+
27,
|
| 98 |
+
6,
|
| 99 |
+
37,
|
| 100 |
+
8
|
| 101 |
+
],
|
| 102 |
+
".model.diffusion_model.input_blocks.4.0.skip_connection": [
|
| 103 |
+
20,
|
| 104 |
+
6,
|
| 105 |
+
22,
|
| 106 |
+
6
|
| 107 |
+
],
|
| 108 |
+
".model.diffusion_model.input_blocks.4.1.proj_in": [
|
| 109 |
+
20,
|
| 110 |
+
6,
|
| 111 |
+
28,
|
| 112 |
+
7
|
| 113 |
+
],
|
| 114 |
+
".model.diffusion_model.input_blocks.4.1.transformer_blocks.0.ff": [
|
| 115 |
+
[
|
| 116 |
+
22,
|
| 117 |
+
6,
|
| 118 |
+
37,
|
| 119 |
+
8
|
| 120 |
+
],
|
| 121 |
+
[
|
| 122 |
+
31,
|
| 123 |
+
8,
|
| 124 |
+
39,
|
| 125 |
+
10
|
| 126 |
+
]
|
| 127 |
+
],
|
| 128 |
+
".model.diffusion_model.input_blocks.4.1.proj_out": [
|
| 129 |
+
26,
|
| 130 |
+
8,
|
| 131 |
+
37,
|
| 132 |
+
10
|
| 133 |
+
],
|
| 134 |
+
".model.diffusion_model.input_blocks.5.0.in_layers.2": [
|
| 135 |
+
27,
|
| 136 |
+
10,
|
| 137 |
+
46,
|
| 138 |
+
11
|
| 139 |
+
],
|
| 140 |
+
".model.diffusion_model.input_blocks.5.0.out_layers.3": [
|
| 141 |
+
18,
|
| 142 |
+
6,
|
| 143 |
+
36,
|
| 144 |
+
7
|
| 145 |
+
],
|
| 146 |
+
".model.diffusion_model.input_blocks.5.1.proj_in": [
|
| 147 |
+
20,
|
| 148 |
+
7,
|
| 149 |
+
29,
|
| 150 |
+
7
|
| 151 |
+
],
|
| 152 |
+
".model.diffusion_model.input_blocks.5.1.transformer_blocks.0.ff": [
|
| 153 |
+
[
|
| 154 |
+
22,
|
| 155 |
+
7,
|
| 156 |
+
41,
|
| 157 |
+
9
|
| 158 |
+
],
|
| 159 |
+
[
|
| 160 |
+
26,
|
| 161 |
+
10,
|
| 162 |
+
33,
|
| 163 |
+
12
|
| 164 |
+
]
|
| 165 |
+
],
|
| 166 |
+
".model.diffusion_model.input_blocks.5.1.proj_out": [
|
| 167 |
+
24,
|
| 168 |
+
9,
|
| 169 |
+
33,
|
| 170 |
+
10
|
| 171 |
+
],
|
| 172 |
+
".model.diffusion_model.input_blocks.6.0.op": [
|
| 173 |
+
52,
|
| 174 |
+
17,
|
| 175 |
+
76,
|
| 176 |
+
20
|
| 177 |
+
],
|
| 178 |
+
".model.diffusion_model.input_blocks.7.0.in_layers.2": [
|
| 179 |
+
50,
|
| 180 |
+
14,
|
| 181 |
+
80,
|
| 182 |
+
19
|
| 183 |
+
],
|
| 184 |
+
".model.diffusion_model.input_blocks.7.0.out_layers.3": [
|
| 185 |
+
56,
|
| 186 |
+
15,
|
| 187 |
+
90,
|
| 188 |
+
22
|
| 189 |
+
],
|
| 190 |
+
".model.diffusion_model.input_blocks.7.0.skip_connection": [
|
| 191 |
+
40,
|
| 192 |
+
13,
|
| 193 |
+
59,
|
| 194 |
+
16
|
| 195 |
+
],
|
| 196 |
+
".model.diffusion_model.input_blocks.7.1.proj_in": [
|
| 197 |
+
33,
|
| 198 |
+
12,
|
| 199 |
+
55,
|
| 200 |
+
14
|
| 201 |
+
],
|
| 202 |
+
".model.diffusion_model.input_blocks.7.1.transformer_blocks.0.ff": [
|
| 203 |
+
[
|
| 204 |
+
39,
|
| 205 |
+
11,
|
| 206 |
+
62,
|
| 207 |
+
13
|
| 208 |
+
],
|
| 209 |
+
[
|
| 210 |
+
59,
|
| 211 |
+
17,
|
| 212 |
+
82,
|
| 213 |
+
21
|
| 214 |
+
]
|
| 215 |
+
],
|
| 216 |
+
".model.diffusion_model.input_blocks.7.1.proj_out": [
|
| 217 |
+
55,
|
| 218 |
+
17,
|
| 219 |
+
80,
|
| 220 |
+
22
|
| 221 |
+
],
|
| 222 |
+
".model.diffusion_model.input_blocks.8.0.in_layers.2": [
|
| 223 |
+
73,
|
| 224 |
+
20,
|
| 225 |
+
108,
|
| 226 |
+
27
|
| 227 |
+
],
|
| 228 |
+
".model.diffusion_model.input_blocks.8.0.out_layers.3": [
|
| 229 |
+
65,
|
| 230 |
+
15,
|
| 231 |
+
95,
|
| 232 |
+
21
|
| 233 |
+
],
|
| 234 |
+
".model.diffusion_model.input_blocks.8.1.proj_in": [
|
| 235 |
+
43,
|
| 236 |
+
13,
|
| 237 |
+
69,
|
| 238 |
+
18
|
| 239 |
+
],
|
| 240 |
+
".model.diffusion_model.input_blocks.8.1.transformer_blocks.0.ff": [
|
| 241 |
+
[
|
| 242 |
+
41,
|
| 243 |
+
10,
|
| 244 |
+
68,
|
| 245 |
+
13
|
| 246 |
+
],
|
| 247 |
+
[
|
| 248 |
+
56,
|
| 249 |
+
17,
|
| 250 |
+
85,
|
| 251 |
+
21
|
| 252 |
+
]
|
| 253 |
+
],
|
| 254 |
+
".model.diffusion_model.input_blocks.8.1.proj_out": [
|
| 255 |
+
52,
|
| 256 |
+
16,
|
| 257 |
+
78,
|
| 258 |
+
20
|
| 259 |
+
],
|
| 260 |
+
".model.diffusion_model.input_blocks.9.0.op": [
|
| 261 |
+
90,
|
| 262 |
+
30,
|
| 263 |
+
157,
|
| 264 |
+
39
|
| 265 |
+
],
|
| 266 |
+
".model.diffusion_model.input_blocks.10.0.in_layers.2": [
|
| 267 |
+
81,
|
| 268 |
+
21,
|
| 269 |
+
113,
|
| 270 |
+
26
|
| 271 |
+
],
|
| 272 |
+
".model.diffusion_model.input_blocks.10.0.out_layers.3": [
|
| 273 |
+
80,
|
| 274 |
+
21,
|
| 275 |
+
123,
|
| 276 |
+
28
|
| 277 |
+
],
|
| 278 |
+
".model.diffusion_model.input_blocks.11.0.in_layers.2": [
|
| 279 |
+
87,
|
| 280 |
+
23,
|
| 281 |
+
118,
|
| 282 |
+
28
|
| 283 |
+
],
|
| 284 |
+
".model.diffusion_model.input_blocks.11.0.out_layers.3": [
|
| 285 |
+
77,
|
| 286 |
+
20,
|
| 287 |
+
113,
|
| 288 |
+
26
|
| 289 |
+
],
|
| 290 |
+
".model.diffusion_model.middle_block.0.in_layers.2": [
|
| 291 |
+
84,
|
| 292 |
+
22,
|
| 293 |
+
113,
|
| 294 |
+
26
|
| 295 |
+
],
|
| 296 |
+
".model.diffusion_model.middle_block.0.out_layers.3": [
|
| 297 |
+
68,
|
| 298 |
+
16,
|
| 299 |
+
99,
|
| 300 |
+
21
|
| 301 |
+
],
|
| 302 |
+
".model.diffusion_model.middle_block.1.proj_in": [
|
| 303 |
+
36,
|
| 304 |
+
10,
|
| 305 |
+
59,
|
| 306 |
+
13
|
| 307 |
+
],
|
| 308 |
+
".model.diffusion_model.middle_block.1.transformer_blocks.0.ff": [
|
| 309 |
+
[
|
| 310 |
+
31,
|
| 311 |
+
5,
|
| 312 |
+
45,
|
| 313 |
+
6
|
| 314 |
+
],
|
| 315 |
+
[
|
| 316 |
+
55,
|
| 317 |
+
15,
|
| 318 |
+
69,
|
| 319 |
+
17
|
| 320 |
+
]
|
| 321 |
+
],
|
| 322 |
+
".model.diffusion_model.middle_block.1.proj_out": [
|
| 323 |
+
39,
|
| 324 |
+
10,
|
| 325 |
+
61,
|
| 326 |
+
14
|
| 327 |
+
],
|
| 328 |
+
".model.diffusion_model.middle_block.2.in_layers.2": [
|
| 329 |
+
73,
|
| 330 |
+
17,
|
| 331 |
+
104,
|
| 332 |
+
23
|
| 333 |
+
],
|
| 334 |
+
".model.diffusion_model.middle_block.2.out_layers.3": [
|
| 335 |
+
62,
|
| 336 |
+
15,
|
| 337 |
+
88,
|
| 338 |
+
20
|
| 339 |
+
],
|
| 340 |
+
".model.diffusion_model.output_blocks.0.0.in_layers.2": [
|
| 341 |
+
96,
|
| 342 |
+
25,
|
| 343 |
+
135,
|
| 344 |
+
32
|
| 345 |
+
],
|
| 346 |
+
".model.diffusion_model.output_blocks.0.0.out_layers.3": [
|
| 347 |
+
86,
|
| 348 |
+
21,
|
| 349 |
+
120,
|
| 350 |
+
28
|
| 351 |
+
],
|
| 352 |
+
".model.diffusion_model.output_blocks.0.0.skip_connection": [
|
| 353 |
+
64,
|
| 354 |
+
21,
|
| 355 |
+
106,
|
| 356 |
+
27
|
| 357 |
+
],
|
| 358 |
+
".model.diffusion_model.output_blocks.1.0.in_layers.2": [
|
| 359 |
+
94,
|
| 360 |
+
27,
|
| 361 |
+
155,
|
| 362 |
+
36
|
| 363 |
+
],
|
| 364 |
+
".model.diffusion_model.output_blocks.1.0.out_layers.3": [
|
| 365 |
+
86,
|
| 366 |
+
24,
|
| 367 |
+
136,
|
| 368 |
+
31
|
| 369 |
+
],
|
| 370 |
+
".model.diffusion_model.output_blocks.1.0.skip_connection": [
|
| 371 |
+
72,
|
| 372 |
+
23,
|
| 373 |
+
115,
|
| 374 |
+
29
|
| 375 |
+
],
|
| 376 |
+
".model.diffusion_model.output_blocks.2.0.in_layers.2": [
|
| 377 |
+
84,
|
| 378 |
+
31,
|
| 379 |
+
164,
|
| 380 |
+
39
|
| 381 |
+
],
|
| 382 |
+
".model.diffusion_model.output_blocks.2.0.out_layers.3": [
|
| 383 |
+
42,
|
| 384 |
+
19,
|
| 385 |
+
123,
|
| 386 |
+
29
|
| 387 |
+
],
|
| 388 |
+
".model.diffusion_model.output_blocks.2.0.skip_connection": [
|
| 389 |
+
72,
|
| 390 |
+
24,
|
| 391 |
+
110,
|
| 392 |
+
28
|
| 393 |
+
],
|
| 394 |
+
".model.diffusion_model.output_blocks.2.1.conv": [
|
| 395 |
+
72,
|
| 396 |
+
25,
|
| 397 |
+
121,
|
| 398 |
+
29
|
| 399 |
+
],
|
| 400 |
+
".model.diffusion_model.output_blocks.3.0.in_layers.2": [
|
| 401 |
+
85,
|
| 402 |
+
31,
|
| 403 |
+
158,
|
| 404 |
+
38
|
| 405 |
+
],
|
| 406 |
+
".model.diffusion_model.output_blocks.3.0.out_layers.3": [
|
| 407 |
+
42,
|
| 408 |
+
21,
|
| 409 |
+
117,
|
| 410 |
+
25
|
| 411 |
+
],
|
| 412 |
+
".model.diffusion_model.output_blocks.3.0.skip_connection": [
|
| 413 |
+
71,
|
| 414 |
+
23,
|
| 415 |
+
111,
|
| 416 |
+
28
|
| 417 |
+
],
|
| 418 |
+
".model.diffusion_model.output_blocks.3.1.proj_in": [
|
| 419 |
+
42,
|
| 420 |
+
14,
|
| 421 |
+
73,
|
| 422 |
+
18
|
| 423 |
+
],
|
| 424 |
+
".model.diffusion_model.output_blocks.3.1.transformer_blocks.0.ff": [
|
| 425 |
+
[
|
| 426 |
+
37,
|
| 427 |
+
10,
|
| 428 |
+
68,
|
| 429 |
+
13
|
| 430 |
+
],
|
| 431 |
+
[
|
| 432 |
+
60,
|
| 433 |
+
18,
|
| 434 |
+
83,
|
| 435 |
+
20
|
| 436 |
+
]
|
| 437 |
+
],
|
| 438 |
+
".model.diffusion_model.output_blocks.3.1.proj_out": [
|
| 439 |
+
51,
|
| 440 |
+
18,
|
| 441 |
+
79,
|
| 442 |
+
21
|
| 443 |
+
],
|
| 444 |
+
".model.diffusion_model.output_blocks.4.0.in_layers.2": [
|
| 445 |
+
104,
|
| 446 |
+
32,
|
| 447 |
+
159,
|
| 448 |
+
40
|
| 449 |
+
],
|
| 450 |
+
".model.diffusion_model.output_blocks.4.0.out_layers.3": [
|
| 451 |
+
83,
|
| 452 |
+
24,
|
| 453 |
+
125,
|
| 454 |
+
29
|
| 455 |
+
],
|
| 456 |
+
".model.diffusion_model.output_blocks.4.0.skip_connection": [
|
| 457 |
+
73,
|
| 458 |
+
22,
|
| 459 |
+
101,
|
| 460 |
+
28
|
| 461 |
+
],
|
| 462 |
+
".model.diffusion_model.output_blocks.4.1.proj_in": [
|
| 463 |
+
49,
|
| 464 |
+
15,
|
| 465 |
+
77,
|
| 466 |
+
20
|
| 467 |
+
],
|
| 468 |
+
".model.diffusion_model.output_blocks.4.1.transformer_blocks.0.ff": [
|
| 469 |
+
[
|
| 470 |
+
38,
|
| 471 |
+
11,
|
| 472 |
+
70,
|
| 473 |
+
14
|
| 474 |
+
],
|
| 475 |
+
[
|
| 476 |
+
63,
|
| 477 |
+
16,
|
| 478 |
+
85,
|
| 479 |
+
20
|
| 480 |
+
]
|
| 481 |
+
],
|
| 482 |
+
".model.diffusion_model.output_blocks.4.1.proj_out": [
|
| 483 |
+
51,
|
| 484 |
+
18,
|
| 485 |
+
81,
|
| 486 |
+
21
|
| 487 |
+
],
|
| 488 |
+
".model.diffusion_model.output_blocks.5.0.in_layers.2": [
|
| 489 |
+
91,
|
| 490 |
+
33,
|
| 491 |
+
161,
|
| 492 |
+
40
|
| 493 |
+
],
|
| 494 |
+
".model.diffusion_model.output_blocks.5.0.out_layers.3": [
|
| 495 |
+
83,
|
| 496 |
+
26,
|
| 497 |
+
140,
|
| 498 |
+
32
|
| 499 |
+
],
|
| 500 |
+
".model.diffusion_model.output_blocks.5.0.skip_connection": [
|
| 501 |
+
81,
|
| 502 |
+
24,
|
| 503 |
+
116,
|
| 504 |
+
30
|
| 505 |
+
],
|
| 506 |
+
".model.diffusion_model.output_blocks.5.1.proj_in": [
|
| 507 |
+
48,
|
| 508 |
+
16,
|
| 509 |
+
82,
|
| 510 |
+
21
|
| 511 |
+
],
|
| 512 |
+
".model.diffusion_model.output_blocks.5.1.transformer_blocks.0.ff": [
|
| 513 |
+
[
|
| 514 |
+
34,
|
| 515 |
+
12,
|
| 516 |
+
76,
|
| 517 |
+
15
|
| 518 |
+
],
|
| 519 |
+
[
|
| 520 |
+
55,
|
| 521 |
+
16,
|
| 522 |
+
81,
|
| 523 |
+
18
|
| 524 |
+
]
|
| 525 |
+
],
|
| 526 |
+
".model.diffusion_model.output_blocks.5.1.proj_out": [
|
| 527 |
+
57,
|
| 528 |
+
19,
|
| 529 |
+
85,
|
| 530 |
+
22
|
| 531 |
+
],
|
| 532 |
+
".model.diffusion_model.output_blocks.5.2.conv": [
|
| 533 |
+
108,
|
| 534 |
+
34,
|
| 535 |
+
159,
|
| 536 |
+
41
|
| 537 |
+
],
|
| 538 |
+
".model.diffusion_model.output_blocks.6.0.in_layers.2": [
|
| 539 |
+
55,
|
| 540 |
+
18,
|
| 541 |
+
87,
|
| 542 |
+
22
|
| 543 |
+
],
|
| 544 |
+
".model.diffusion_model.output_blocks.6.0.out_layers.3": [
|
| 545 |
+
32,
|
| 546 |
+
13,
|
| 547 |
+
54,
|
| 548 |
+
15
|
| 549 |
+
],
|
| 550 |
+
".model.diffusion_model.output_blocks.6.0.skip_connection": [
|
| 551 |
+
25,
|
| 552 |
+
9,
|
| 553 |
+
30,
|
| 554 |
+
14
|
| 555 |
+
],
|
| 556 |
+
".model.diffusion_model.output_blocks.6.1.proj_in": [
|
| 557 |
+
26,
|
| 558 |
+
9,
|
| 559 |
+
40,
|
| 560 |
+
11
|
| 561 |
+
],
|
| 562 |
+
".model.diffusion_model.output_blocks.6.1.transformer_blocks.0.ff": [
|
| 563 |
+
[
|
| 564 |
+
25,
|
| 565 |
+
8,
|
| 566 |
+
47,
|
| 567 |
+
12
|
| 568 |
+
],
|
| 569 |
+
[
|
| 570 |
+
36,
|
| 571 |
+
11,
|
| 572 |
+
47,
|
| 573 |
+
13
|
| 574 |
+
]
|
| 575 |
+
],
|
| 576 |
+
".model.diffusion_model.output_blocks.6.1.proj_out": [
|
| 577 |
+
23,
|
| 578 |
+
10,
|
| 579 |
+
38,
|
| 580 |
+
12
|
| 581 |
+
],
|
| 582 |
+
".model.diffusion_model.output_blocks.7.0.in_layers.2": [
|
| 583 |
+
55,
|
| 584 |
+
18,
|
| 585 |
+
82,
|
| 586 |
+
20
|
| 587 |
+
],
|
| 588 |
+
".model.diffusion_model.output_blocks.7.0.out_layers.3": [
|
| 589 |
+
47,
|
| 590 |
+
14,
|
| 591 |
+
65,
|
| 592 |
+
17
|
| 593 |
+
],
|
| 594 |
+
".model.diffusion_model.output_blocks.7.0.skip_connection": [
|
| 595 |
+
40,
|
| 596 |
+
11,
|
| 597 |
+
40,
|
| 598 |
+
12
|
| 599 |
+
],
|
| 600 |
+
".model.diffusion_model.output_blocks.7.1.proj_in": [
|
| 601 |
+
27,
|
| 602 |
+
9,
|
| 603 |
+
41,
|
| 604 |
+
11
|
| 605 |
+
],
|
| 606 |
+
".model.diffusion_model.output_blocks.7.1.transformer_blocks.0.ff": [
|
| 607 |
+
[
|
| 608 |
+
27,
|
| 609 |
+
8,
|
| 610 |
+
47,
|
| 611 |
+
11
|
| 612 |
+
],
|
| 613 |
+
[
|
| 614 |
+
34,
|
| 615 |
+
11,
|
| 616 |
+
47,
|
| 617 |
+
12
|
| 618 |
+
]
|
| 619 |
+
],
|
| 620 |
+
".model.diffusion_model.output_blocks.7.1.proj_out": [
|
| 621 |
+
33,
|
| 622 |
+
9,
|
| 623 |
+
39,
|
| 624 |
+
12
|
| 625 |
+
],
|
| 626 |
+
".model.diffusion_model.output_blocks.8.0.in_layers.2": [
|
| 627 |
+
58,
|
| 628 |
+
17,
|
| 629 |
+
82,
|
| 630 |
+
20
|
| 631 |
+
],
|
| 632 |
+
".model.diffusion_model.output_blocks.8.0.out_layers.3": [
|
| 633 |
+
56,
|
| 634 |
+
15,
|
| 635 |
+
75,
|
| 636 |
+
18
|
| 637 |
+
],
|
| 638 |
+
".model.diffusion_model.output_blocks.8.0.skip_connection": [
|
| 639 |
+
44,
|
| 640 |
+
10,
|
| 641 |
+
47,
|
| 642 |
+
11
|
| 643 |
+
],
|
| 644 |
+
".model.diffusion_model.output_blocks.8.1.proj_in": [
|
| 645 |
+
32,
|
| 646 |
+
9,
|
| 647 |
+
43,
|
| 648 |
+
10
|
| 649 |
+
],
|
| 650 |
+
".model.diffusion_model.output_blocks.8.1.transformer_blocks.0.ff": [
|
| 651 |
+
[
|
| 652 |
+
28,
|
| 653 |
+
7,
|
| 654 |
+
47,
|
| 655 |
+
8
|
| 656 |
+
],
|
| 657 |
+
[
|
| 658 |
+
35,
|
| 659 |
+
8,
|
| 660 |
+
45,
|
| 661 |
+
8
|
| 662 |
+
]
|
| 663 |
+
],
|
| 664 |
+
".model.diffusion_model.output_blocks.8.1.proj_out": [
|
| 665 |
+
35,
|
| 666 |
+
10,
|
| 667 |
+
44,
|
| 668 |
+
10
|
| 669 |
+
],
|
| 670 |
+
".model.diffusion_model.output_blocks.8.2.conv": [
|
| 671 |
+
65,
|
| 672 |
+
19,
|
| 673 |
+
85,
|
| 674 |
+
22
|
| 675 |
+
],
|
| 676 |
+
".model.diffusion_model.output_blocks.9.0.in_layers.2": [
|
| 677 |
+
37,
|
| 678 |
+
10,
|
| 679 |
+
35,
|
| 680 |
+
10
|
| 681 |
+
],
|
| 682 |
+
".model.diffusion_model.output_blocks.9.0.out_layers.3": [
|
| 683 |
+
28,
|
| 684 |
+
6,
|
| 685 |
+
23,
|
| 686 |
+
5
|
| 687 |
+
],
|
| 688 |
+
".model.diffusion_model.output_blocks.9.0.skip_connection": [
|
| 689 |
+
15,
|
| 690 |
+
4,
|
| 691 |
+
4,
|
| 692 |
+
4
|
| 693 |
+
],
|
| 694 |
+
".model.diffusion_model.output_blocks.9.1.proj_in": [
|
| 695 |
+
16,
|
| 696 |
+
4,
|
| 697 |
+
6,
|
| 698 |
+
4
|
| 699 |
+
],
|
| 700 |
+
".model.diffusion_model.output_blocks.9.1.transformer_blocks.0.ff": [
|
| 701 |
+
[
|
| 702 |
+
24,
|
| 703 |
+
5,
|
| 704 |
+
23,
|
| 705 |
+
5
|
| 706 |
+
],
|
| 707 |
+
[
|
| 708 |
+
23,
|
| 709 |
+
5,
|
| 710 |
+
24,
|
| 711 |
+
6
|
| 712 |
+
]
|
| 713 |
+
],
|
| 714 |
+
".model.diffusion_model.output_blocks.9.1.proj_out": [
|
| 715 |
+
16,
|
| 716 |
+
4,
|
| 717 |
+
14,
|
| 718 |
+
4
|
| 719 |
+
],
|
| 720 |
+
".model.diffusion_model.output_blocks.10.0.in_layers.2": [
|
| 721 |
+
31,
|
| 722 |
+
9,
|
| 723 |
+
38,
|
| 724 |
+
10
|
| 725 |
+
],
|
| 726 |
+
".model.diffusion_model.output_blocks.10.0.out_layers.3": [
|
| 727 |
+
20,
|
| 728 |
+
4,
|
| 729 |
+
24,
|
| 730 |
+
4
|
| 731 |
+
],
|
| 732 |
+
".model.diffusion_model.output_blocks.10.0.skip_connection": [
|
| 733 |
+
4,
|
| 734 |
+
4,
|
| 735 |
+
7,
|
| 736 |
+
4
|
| 737 |
+
],
|
| 738 |
+
".model.diffusion_model.output_blocks.10.1.proj_in": [
|
| 739 |
+
6,
|
| 740 |
+
4,
|
| 741 |
+
11,
|
| 742 |
+
4
|
| 743 |
+
],
|
| 744 |
+
".model.diffusion_model.output_blocks.10.1.transformer_blocks.0.ff": [
|
| 745 |
+
[
|
| 746 |
+
17,
|
| 747 |
+
4,
|
| 748 |
+
21,
|
| 749 |
+
4
|
| 750 |
+
],
|
| 751 |
+
[
|
| 752 |
+
17,
|
| 753 |
+
5,
|
| 754 |
+
21,
|
| 755 |
+
5
|
| 756 |
+
]
|
| 757 |
+
],
|
| 758 |
+
".model.diffusion_model.output_blocks.10.1.proj_out": [
|
| 759 |
+
9,
|
| 760 |
+
4,
|
| 761 |
+
12,
|
| 762 |
+
4
|
| 763 |
+
],
|
| 764 |
+
".model.diffusion_model.output_blocks.11.0.in_layers.2": [
|
| 765 |
+
7,
|
| 766 |
+
4,
|
| 767 |
+
18,
|
| 768 |
+
4
|
| 769 |
+
],
|
| 770 |
+
".model.diffusion_model.output_blocks.11.0.out_layers.3": [
|
| 771 |
+
16,
|
| 772 |
+
6,
|
| 773 |
+
22,
|
| 774 |
+
5
|
| 775 |
+
],
|
| 776 |
+
".model.diffusion_model.output_blocks.11.0.skip_connection": [
|
| 777 |
+
4,
|
| 778 |
+
4,
|
| 779 |
+
4,
|
| 780 |
+
4
|
| 781 |
+
],
|
| 782 |
+
".model.diffusion_model.output_blocks.11.1.proj_in": [
|
| 783 |
+
9,
|
| 784 |
+
4,
|
| 785 |
+
13,
|
| 786 |
+
4
|
| 787 |
+
],
|
| 788 |
+
".model.diffusion_model.output_blocks.11.1.transformer_blocks.0.ff": [
|
| 789 |
+
[
|
| 790 |
+
19,
|
| 791 |
+
4,
|
| 792 |
+
24,
|
| 793 |
+
4
|
| 794 |
+
],
|
| 795 |
+
[
|
| 796 |
+
12,
|
| 797 |
+
4,
|
| 798 |
+
14,
|
| 799 |
+
4
|
| 800 |
+
]
|
| 801 |
+
],
|
| 802 |
+
".model.diffusion_model.output_blocks.11.1.proj_out": [
|
| 803 |
+
7,
|
| 804 |
+
4,
|
| 805 |
+
10,
|
| 806 |
+
4
|
| 807 |
+
],
|
| 808 |
+
".model.diffusion_model.out.2": [
|
| 809 |
+
4,
|
| 810 |
+
4,
|
| 811 |
+
4,
|
| 812 |
+
4
|
| 813 |
+
],
|
| 814 |
+
".model.diffusion_model_refNet.input_blocks.0.0": [
|
| 815 |
+
4,
|
| 816 |
+
4,
|
| 817 |
+
4,
|
| 818 |
+
4
|
| 819 |
+
],
|
| 820 |
+
".model.diffusion_model_refNet.input_blocks.1.0.in_layers.2": [
|
| 821 |
+
17,
|
| 822 |
+
8,
|
| 823 |
+
26,
|
| 824 |
+
8
|
| 825 |
+
],
|
| 826 |
+
".model.diffusion_model_refNet.input_blocks.1.0.out_layers.3": [
|
| 827 |
+
21,
|
| 828 |
+
14,
|
| 829 |
+
37,
|
| 830 |
+
12
|
| 831 |
+
],
|
| 832 |
+
".model.diffusion_model_refNet.input_blocks.1.1.proj_in": [
|
| 833 |
+
11,
|
| 834 |
+
8,
|
| 835 |
+
19,
|
| 836 |
+
6
|
| 837 |
+
],
|
| 838 |
+
".model.diffusion_model_refNet.input_blocks.1.1.transformer_blocks.0.ff": [
|
| 839 |
+
[
|
| 840 |
+
14,
|
| 841 |
+
12,
|
| 842 |
+
24,
|
| 843 |
+
7
|
| 844 |
+
],
|
| 845 |
+
[
|
| 846 |
+
17,
|
| 847 |
+
12,
|
| 848 |
+
26,
|
| 849 |
+
7
|
| 850 |
+
]
|
| 851 |
+
],
|
| 852 |
+
".model.diffusion_model_refNet.input_blocks.1.1.proj_out": [
|
| 853 |
+
11,
|
| 854 |
+
7,
|
| 855 |
+
20,
|
| 856 |
+
5
|
| 857 |
+
],
|
| 858 |
+
".model.diffusion_model_refNet.input_blocks.2.0.in_layers.2": [
|
| 859 |
+
27,
|
| 860 |
+
15,
|
| 861 |
+
40,
|
| 862 |
+
13
|
| 863 |
+
],
|
| 864 |
+
".model.diffusion_model_refNet.input_blocks.2.0.out_layers.3": [
|
| 865 |
+
26,
|
| 866 |
+
15,
|
| 867 |
+
38,
|
| 868 |
+
12
|
| 869 |
+
],
|
| 870 |
+
".model.diffusion_model_refNet.input_blocks.2.1.proj_in": [
|
| 871 |
+
15,
|
| 872 |
+
7,
|
| 873 |
+
21,
|
| 874 |
+
6
|
| 875 |
+
],
|
| 876 |
+
".model.diffusion_model_refNet.input_blocks.2.1.transformer_blocks.0.ff": [
|
| 877 |
+
[
|
| 878 |
+
17,
|
| 879 |
+
13,
|
| 880 |
+
30,
|
| 881 |
+
9
|
| 882 |
+
],
|
| 883 |
+
[
|
| 884 |
+
16,
|
| 885 |
+
12,
|
| 886 |
+
27,
|
| 887 |
+
8
|
| 888 |
+
]
|
| 889 |
+
],
|
| 890 |
+
".model.diffusion_model_refNet.input_blocks.2.1.proj_out": [
|
| 891 |
+
12,
|
| 892 |
+
7,
|
| 893 |
+
18,
|
| 894 |
+
6
|
| 895 |
+
],
|
| 896 |
+
".model.diffusion_model_refNet.input_blocks.3.0.op": [
|
| 897 |
+
27,
|
| 898 |
+
13,
|
| 899 |
+
43,
|
| 900 |
+
12
|
| 901 |
+
],
|
| 902 |
+
".model.diffusion_model_refNet.input_blocks.4.0.in_layers.2": [
|
| 903 |
+
30,
|
| 904 |
+
19,
|
| 905 |
+
49,
|
| 906 |
+
14
|
| 907 |
+
],
|
| 908 |
+
".model.diffusion_model_refNet.input_blocks.4.0.out_layers.3": [
|
| 909 |
+
32,
|
| 910 |
+
26,
|
| 911 |
+
55,
|
| 912 |
+
15
|
| 913 |
+
],
|
| 914 |
+
".model.diffusion_model_refNet.input_blocks.4.0.skip_connection": [
|
| 915 |
+
22,
|
| 916 |
+
10,
|
| 917 |
+
30,
|
| 918 |
+
9
|
| 919 |
+
],
|
| 920 |
+
".model.diffusion_model_refNet.input_blocks.4.1.proj_in": [
|
| 921 |
+
22,
|
| 922 |
+
14,
|
| 923 |
+
35,
|
| 924 |
+
10
|
| 925 |
+
],
|
| 926 |
+
".model.diffusion_model_refNet.input_blocks.4.1.transformer_blocks.0.ff": [
|
| 927 |
+
[
|
| 928 |
+
26,
|
| 929 |
+
25,
|
| 930 |
+
52,
|
| 931 |
+
14
|
| 932 |
+
],
|
| 933 |
+
[
|
| 934 |
+
28,
|
| 935 |
+
22,
|
| 936 |
+
51,
|
| 937 |
+
14
|
| 938 |
+
]
|
| 939 |
+
],
|
| 940 |
+
".model.diffusion_model_refNet.input_blocks.4.1.proj_out": [
|
| 941 |
+
24,
|
| 942 |
+
15,
|
| 943 |
+
40,
|
| 944 |
+
11
|
| 945 |
+
],
|
| 946 |
+
".model.diffusion_model_refNet.input_blocks.5.0.in_layers.2": [
|
| 947 |
+
44,
|
| 948 |
+
30,
|
| 949 |
+
78,
|
| 950 |
+
22
|
| 951 |
+
],
|
| 952 |
+
".model.diffusion_model_refNet.input_blocks.5.0.out_layers.3": [
|
| 953 |
+
28,
|
| 954 |
+
29,
|
| 955 |
+
56,
|
| 956 |
+
15
|
| 957 |
+
],
|
| 958 |
+
".model.diffusion_model_refNet.input_blocks.5.1.proj_in": [
|
| 959 |
+
20,
|
| 960 |
+
13,
|
| 961 |
+
34,
|
| 962 |
+
9
|
| 963 |
+
],
|
| 964 |
+
".model.diffusion_model_refNet.input_blocks.5.1.transformer_blocks.0.ff": [
|
| 965 |
+
[
|
| 966 |
+
26,
|
| 967 |
+
27,
|
| 968 |
+
52,
|
| 969 |
+
14
|
| 970 |
+
],
|
| 971 |
+
[
|
| 972 |
+
23,
|
| 973 |
+
23,
|
| 974 |
+
53,
|
| 975 |
+
14
|
| 976 |
+
]
|
| 977 |
+
],
|
| 978 |
+
".model.diffusion_model_refNet.input_blocks.5.1.proj_out": [
|
| 979 |
+
17,
|
| 980 |
+
14,
|
| 981 |
+
36,
|
| 982 |
+
10
|
| 983 |
+
],
|
| 984 |
+
".model.diffusion_model_refNet.input_blocks.6.0.op": [
|
| 985 |
+
46,
|
| 986 |
+
31,
|
| 987 |
+
82,
|
| 988 |
+
21
|
| 989 |
+
],
|
| 990 |
+
".model.diffusion_model_refNet.input_blocks.7.0.in_layers.2": [
|
| 991 |
+
75,
|
| 992 |
+
41,
|
| 993 |
+
116,
|
| 994 |
+
32
|
| 995 |
+
],
|
| 996 |
+
".model.diffusion_model_refNet.input_blocks.7.0.out_layers.3": [
|
| 997 |
+
67,
|
| 998 |
+
50,
|
| 999 |
+
108,
|
| 1000 |
+
29
|
| 1001 |
+
],
|
| 1002 |
+
".model.diffusion_model_refNet.input_blocks.7.0.skip_connection": [
|
| 1003 |
+
31,
|
| 1004 |
+
19,
|
| 1005 |
+
59,
|
| 1006 |
+
15
|
| 1007 |
+
],
|
| 1008 |
+
".model.diffusion_model_refNet.input_blocks.7.1.proj_in": [
|
| 1009 |
+
36,
|
| 1010 |
+
29,
|
| 1011 |
+
73,
|
| 1012 |
+
19
|
| 1013 |
+
],
|
| 1014 |
+
".model.diffusion_model_refNet.input_blocks.7.1.transformer_blocks.0.ff": [
|
| 1015 |
+
[
|
| 1016 |
+
74,
|
| 1017 |
+
61,
|
| 1018 |
+
106,
|
| 1019 |
+
26
|
| 1020 |
+
],
|
| 1021 |
+
[
|
| 1022 |
+
63,
|
| 1023 |
+
49,
|
| 1024 |
+
90,
|
| 1025 |
+
24
|
| 1026 |
+
]
|
| 1027 |
+
],
|
| 1028 |
+
".model.diffusion_model_refNet.input_blocks.7.1.proj_out": [
|
| 1029 |
+
34,
|
| 1030 |
+
29,
|
| 1031 |
+
68,
|
| 1032 |
+
18
|
| 1033 |
+
],
|
| 1034 |
+
".model.diffusion_model_refNet.input_blocks.8.0.in_layers.2": [
|
| 1035 |
+
92,
|
| 1036 |
+
56,
|
| 1037 |
+
128,
|
| 1038 |
+
36
|
| 1039 |
+
],
|
| 1040 |
+
".model.diffusion_model_refNet.input_blocks.8.0.out_layers.3": [
|
| 1041 |
+
43,
|
| 1042 |
+
51,
|
| 1043 |
+
66,
|
| 1044 |
+
16
|
| 1045 |
+
],
|
| 1046 |
+
".model.diffusion_model_refNet.input_blocks.8.1.proj_in": [
|
| 1047 |
+
26,
|
| 1048 |
+
28,
|
| 1049 |
+
59,
|
| 1050 |
+
15
|
| 1051 |
+
],
|
| 1052 |
+
".model.diffusion_model_refNet.input_blocks.8.1.transformer_blocks.0.ff": [
|
| 1053 |
+
[
|
| 1054 |
+
188,
|
| 1055 |
+
69,
|
| 1056 |
+
232,
|
| 1057 |
+
69
|
| 1058 |
+
],
|
| 1059 |
+
[
|
| 1060 |
+
140,
|
| 1061 |
+
51,
|
| 1062 |
+
173,
|
| 1063 |
+
51
|
| 1064 |
+
]
|
| 1065 |
+
],
|
| 1066 |
+
".model.diffusion_model_refNet.input_blocks.8.1.proj_out": [
|
| 1067 |
+
91,
|
| 1068 |
+
33,
|
| 1069 |
+
113,
|
| 1070 |
+
33
|
| 1071 |
+
]
|
| 1072 |
+
}
|
download_checkpoints.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os
|
| 3 |
+
from imports import *
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _download(repo_id, filename, local_path: Path) -> Path:
|
| 8 |
+
local_path = Path(local_path)
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 11 |
+
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
| 12 |
+
print(f"downloading to {local_path}")
|
| 13 |
+
downloaded = hf_hub_download(
|
| 14 |
+
repo_id=repo_id,
|
| 15 |
+
filename=filename,
|
| 16 |
+
local_dir=str(local_path.parent),
|
| 17 |
+
local_dir_use_symlinks=False,
|
| 18 |
+
token=token,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_download("CompVis/stable-diffusion-v-1-4-original",SD14_filename, SD14_localpath)
|
| 24 |
+
|
| 25 |
+
_download("scy639/UniBioTransfer",PRETRAIN_CKPT_PATH, ".")
|
| 26 |
+
_download("scy639/UniBioTransfer",PRETRAIN_JSON_PATH, ".")
|
| 27 |
+
|
| 28 |
+
_download("scy639/UniBioTransfer","Other_dependencies/arcface/model_ir_se50.pth", ".")
|
| 29 |
+
_download("scy639/UniBioTransfer","Other_dependencies/face_parsing/79999_iter.pth", ".")
|
eval_tool/lpips/__init__.py
ADDED
|
File without changes
|
eval_tool/lpips/lpips.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from eval_tool.lpips.networks import get_network, LinLayers
|
| 5 |
+
from eval_tool.lpips.utils import get_state_dict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class LPIPS(nn.Module):
|
| 9 |
+
r"""Creates a criterion that measures
|
| 10 |
+
Learned Perceptual Image Patch Similarity (LPIPS).
|
| 11 |
+
Arguments:
|
| 12 |
+
net_type (str): the network type to compare the features:
|
| 13 |
+
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
|
| 14 |
+
version (str): the version of LPIPS. Default: 0.1.
|
| 15 |
+
"""
|
| 16 |
+
def __init__(self, net_type: str = 'alex', version: str = '0.1'):
|
| 17 |
+
|
| 18 |
+
assert version in ['0.1'], 'v0.1 is only supported now'
|
| 19 |
+
|
| 20 |
+
super(LPIPS, self).__init__()
|
| 21 |
+
|
| 22 |
+
# pretrained network
|
| 23 |
+
self.net = get_network(net_type)
|
| 24 |
+
|
| 25 |
+
# linear layers
|
| 26 |
+
self.lin = LinLayers(self.net.n_channels_list)
|
| 27 |
+
self.lin.load_state_dict(get_state_dict(net_type, version))
|
| 28 |
+
|
| 29 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
| 30 |
+
feat_x, feat_y = self.net(x), self.net(y)
|
| 31 |
+
|
| 32 |
+
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
|
| 33 |
+
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
|
| 34 |
+
|
| 35 |
+
return torch.sum(torch.cat(res, 0)) / x.shape[0]
|
eval_tool/lpips/networks.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Sequence
|
| 2 |
+
|
| 3 |
+
from itertools import chain
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torchvision import models
|
| 8 |
+
|
| 9 |
+
from eval_tool.lpips.utils import normalize_activation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_network(net_type: str):
|
| 13 |
+
if net_type == 'alex':
|
| 14 |
+
return AlexNet()
|
| 15 |
+
elif net_type == 'squeeze':
|
| 16 |
+
return SqueezeNet()
|
| 17 |
+
elif net_type == 'vgg':
|
| 18 |
+
return VGG16()
|
| 19 |
+
else:
|
| 20 |
+
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LinLayers(nn.ModuleList):
|
| 24 |
+
def __init__(self, n_channels_list: Sequence[int]):
|
| 25 |
+
super(LinLayers, self).__init__([
|
| 26 |
+
nn.Sequential(
|
| 27 |
+
nn.Identity(),
|
| 28 |
+
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
|
| 29 |
+
) for nc in n_channels_list
|
| 30 |
+
])
|
| 31 |
+
|
| 32 |
+
for param in self.parameters():
|
| 33 |
+
param.requires_grad = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BaseNet(nn.Module):
|
| 37 |
+
def __init__(self):
|
| 38 |
+
super(BaseNet, self).__init__()
|
| 39 |
+
|
| 40 |
+
# register buffer
|
| 41 |
+
self.register_buffer(
|
| 42 |
+
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
| 43 |
+
self.register_buffer(
|
| 44 |
+
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
| 45 |
+
|
| 46 |
+
def set_requires_grad(self, state: bool):
|
| 47 |
+
for param in chain(self.parameters(), self.buffers()):
|
| 48 |
+
param.requires_grad = state
|
| 49 |
+
|
| 50 |
+
def z_score(self, x: torch.Tensor):
|
| 51 |
+
return (x - self.mean) / self.std
|
| 52 |
+
|
| 53 |
+
def forward(self, x: torch.Tensor):
|
| 54 |
+
x = self.z_score(x)
|
| 55 |
+
|
| 56 |
+
output = []
|
| 57 |
+
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
|
| 58 |
+
x = layer(x)
|
| 59 |
+
if i in self.target_layers:
|
| 60 |
+
output.append(normalize_activation(x))
|
| 61 |
+
if len(output) == len(self.target_layers):
|
| 62 |
+
break
|
| 63 |
+
return output
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class SqueezeNet(BaseNet):
|
| 67 |
+
def __init__(self):
|
| 68 |
+
super(SqueezeNet, self).__init__()
|
| 69 |
+
|
| 70 |
+
self.layers = models.squeezenet1_1(True).features
|
| 71 |
+
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
|
| 72 |
+
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
|
| 73 |
+
|
| 74 |
+
self.set_requires_grad(False)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AlexNet(BaseNet):
|
| 78 |
+
def __init__(self):
|
| 79 |
+
super(AlexNet, self).__init__()
|
| 80 |
+
|
| 81 |
+
self.layers = models.alexnet(True).features
|
| 82 |
+
self.target_layers = [2, 5, 8, 10, 12]
|
| 83 |
+
self.n_channels_list = [64, 192, 384, 256, 256]
|
| 84 |
+
|
| 85 |
+
self.set_requires_grad(False)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class VGG16(BaseNet):
|
| 89 |
+
def __init__(self):
|
| 90 |
+
super(VGG16, self).__init__()
|
| 91 |
+
|
| 92 |
+
self.layers = models.vgg16(True).features
|
| 93 |
+
self.target_layers = [4, 9, 16, 23, 30]
|
| 94 |
+
self.n_channels_list = [64, 128, 256, 512, 512]
|
| 95 |
+
|
| 96 |
+
self.set_requires_grad(False)
|
eval_tool/lpips/utils.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def normalize_activation(x, eps=1e-10):
|
| 7 |
+
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)+1e-16) #
|
| 8 |
+
return x / (norm_factor + eps)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
|
| 12 |
+
# build url
|
| 13 |
+
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
|
| 14 |
+
+ f'master/lpips/weights/v{version}/{net_type}.pth'
|
| 15 |
+
|
| 16 |
+
# download
|
| 17 |
+
old_state_dict = torch.hub.load_state_dict_from_url(
|
| 18 |
+
url, progress=True,
|
| 19 |
+
map_location=None if torch.cuda.is_available() else torch.device('cpu')
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# rename keys
|
| 23 |
+
new_state_dict = OrderedDict()
|
| 24 |
+
for key, val in old_state_dict.items():
|
| 25 |
+
new_key = key
|
| 26 |
+
new_key = new_key.replace('lin', '')
|
| 27 |
+
new_key = new_key.replace('model.', '')
|
| 28 |
+
new_state_dict[new_key] = val
|
| 29 |
+
|
| 30 |
+
return new_state_dict
|
examples/face/ref-semantic_mask.png
ADDED
|
examples/face/ref.png
ADDED
|
Git LFS Details
|
examples/face/tgt-semantic_mask.png
ADDED
|
examples/face/tgt.png
ADDED
|
Git LFS Details
|
examples/hair/ref-semantic_mask.png
ADDED
|
examples/hair/ref.png
ADDED
|
Git LFS Details
|
examples/hair/tgt-semantic_mask.png
ADDED
|
examples/hair/tgt.png
ADDED
|
Git LFS Details
|
examples/head/ref-semantic_mask.png
ADDED
|
examples/head/ref.png
ADDED
|
Git LFS Details
|
examples/head/tgt-semantic_mask.png
ADDED
|
examples/head/tgt.png
ADDED
|
Git LFS Details
|
examples/inputs.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
target_path_1 reference_path_1
|
| 2 |
+
target_path_2 reference_path_2
|
| 3 |
+
target_path_3 reference_path_3
|
| 4 |
+
target_path_4 reference_path_4
|
| 5 |
+
target_path_5 reference_path_5
|
examples/motion/ref-semantic_mask.png
ADDED
|
examples/motion/ref.png
ADDED
|
Git LFS Details
|
examples/motion/tgt-semantic_mask.png
ADDED
|
examples/motion/tgt.png
ADDED
|
Git LFS Details
|
gen_lmk_and_mask.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ENABLE_lmk_cache = False
|
| 2 |
+
ENABLE_mask_cache = False
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
from imports import *
|
| 7 |
+
from util_cv2 import cv2_resize_auto_interpolation
|
| 8 |
+
from Mediapipe_Result_Cache import Mediapipe_Result_Cache
|
| 9 |
+
from lmk_util.lmk_extractor import LandmarkExtractor
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def gen_lmk_and_mask(img_paths, size=512, write_cache=True):
|
| 13 |
+
extractor = LandmarkExtractor()
|
| 14 |
+
cache = Mediapipe_Result_Cache()
|
| 15 |
+
seen = set()
|
| 16 |
+
for p in img_paths:
|
| 17 |
+
if not p:
|
| 18 |
+
continue
|
| 19 |
+
p = str(p)
|
| 20 |
+
if p in seen:
|
| 21 |
+
continue
|
| 22 |
+
seen.add(p)
|
| 23 |
+
|
| 24 |
+
cache_path = cache.get_path(p)
|
| 25 |
+
if not ( cache_path.exists() and ENABLE_lmk_cache ):
|
| 26 |
+
img = cv2.imread(p)
|
| 27 |
+
if img is None:
|
| 28 |
+
print(f"cv2.imread failed: {p}")
|
| 29 |
+
raise
|
| 30 |
+
continue
|
| 31 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 32 |
+
img = cv2_resize_auto_interpolation(img, (size, size))
|
| 33 |
+
lmks = extractor.extract_single(img)
|
| 34 |
+
if lmks is None:
|
| 35 |
+
print(f"no lmks: {p}")
|
| 36 |
+
raise
|
| 37 |
+
continue
|
| 38 |
+
if write_cache:
|
| 39 |
+
cache.set(p, lmks)
|
| 40 |
+
|
| 41 |
+
path_img_2_path_mask(p, reuse_if_exists=ENABLE_mask_cache, label_mode="RF12_")
|
gen_semantic_mask.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
def:
|
| 3 |
+
tgt: Target image to be edited (face swapped)
|
| 4 |
+
ref: Face ID source image (also called src in REFace)
|
| 5 |
+
swap: Swapped output image, using face ID from ref to replace face in tgt
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from my_py_lib.image_util import print_image_statistics
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import numpy as np
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
from torchvision.transforms import Resize
|
| 17 |
+
from torchvision.utils import make_grid
|
| 18 |
+
from contextlib import nullcontext
|
| 19 |
+
from torch.cuda.amp import autocast
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
|
| 25 |
+
# Sampling configs
|
| 26 |
+
DDIM_STEPS = 50
|
| 27 |
+
GUIDANCE_SCALE = 3.0
|
| 28 |
+
IMG_SIZE = 512
|
| 29 |
+
LATENT_CHANNELS = 4
|
| 30 |
+
DOWNSAMPLE_FACTOR = 8
|
| 31 |
+
START_NOISE_T = 1000
|
| 32 |
+
DDIM_ETA = 0.0
|
| 33 |
+
PRECISION = "full" # or "autocast"
|
| 34 |
+
FIXED_CODE = False # whether to use fixed starting code
|
| 35 |
+
SAVE_INTERMEDIATES = False # whether to save intermediate results
|
| 36 |
+
LOG_EVERY_T = 100 # log frequency during sampling
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MaskModel_LazyLoader:
|
| 40 |
+
model = None
|
| 41 |
+
@classmethod
|
| 42 |
+
def get(cls):
|
| 43 |
+
faceParsing_ckpt = "Other_dependencies/face_parsing/79999_iter.pth"
|
| 44 |
+
if cls.model is None:
|
| 45 |
+
from pretrained.face_parsing.face_parsing_demo import init_faceParsing_pretrained_model
|
| 46 |
+
cls.model = init_faceParsing_pretrained_model(
|
| 47 |
+
'default',
|
| 48 |
+
faceParsing_ckpt,
|
| 49 |
+
''
|
| 50 |
+
)
|
| 51 |
+
print(f"Initialized face parsing model from {faceParsing_ckpt}")
|
| 52 |
+
return cls.model
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def gen_semantic_mask(path_img: Path, path_mask_to_save: Path, label_mode:str, path_vis: Path = None):
|
| 56 |
+
"""Generate semantic mask for an image using face parsing model"""
|
| 57 |
+
pil_im = Image.open(path_img).convert("RGB")
|
| 58 |
+
w, h = pil_im.size
|
| 59 |
+
# print(f"{pil_im.size=}") # 512,512
|
| 60 |
+
TMP_size = 1024
|
| 61 |
+
if w != TMP_size or h != TMP_size:
|
| 62 |
+
pil_im = pil_im.resize((TMP_size, TMP_size), Image.BILINEAR)
|
| 63 |
+
|
| 64 |
+
model = MaskModel_LazyLoader.get()
|
| 65 |
+
from pretrained.face_parsing.face_parsing_demo import faceParsing_demo, vis_parsing_maps
|
| 66 |
+
|
| 67 |
+
# print(f"{pil_im.size=}") # 1024,1024
|
| 68 |
+
# Generate mask with conversion to seg12 format
|
| 69 |
+
mask = faceParsing_demo(
|
| 70 |
+
model,
|
| 71 |
+
pil_im,
|
| 72 |
+
label_mode,
|
| 73 |
+
model_name='default'
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
Image.fromarray(mask).save(path_mask_to_save)
|
| 78 |
+
except Exception as e:
|
| 79 |
+
print(f"{e=}")
|
| 80 |
+
print(f"{path_mask_to_save=}")
|
| 81 |
+
if path_mask_to_save.exists():
|
| 82 |
+
path_mask_to_save.unlink()
|
| 83 |
+
print(f'path_mask_to_save.unlink()')
|
| 84 |
+
# print(f"Saved mask: {path_mask_to_save}")
|
| 85 |
+
# print(f"{mask.shape=}") # 512,512
|
| 86 |
+
|
| 87 |
+
if path_vis:
|
| 88 |
+
mask_vis = vis_parsing_maps(pil_im, mask)
|
| 89 |
+
Image.fromarray(mask_vis).save(path_vis)
|
| 90 |
+
print(f"Saved mask vis: {path_vis}")
|
get_mask.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from util_and_constant import *
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
def path_img_2_mask(
|
| 8 |
+
path_img,
|
| 9 |
+
preserve=(1, 2, 3, 5, 6, 7, 9, 10, 11, ), # int | list-liek. Default val represents face
|
| 10 |
+
):
|
| 11 |
+
"""
|
| 12 |
+
0 bg, 1 mouth, 2 eyebrow, 3 eyes, 4 hair, 5 nose, 6 face (excluding facial parts), 7: ear, 8: neck, 9: tooth
|
| 13 |
+
10: eye_glass, 11: ear_rings
|
| 14 |
+
"""
|
| 15 |
+
if isinstance(preserve,int):
|
| 16 |
+
preserve = (preserve,)
|
| 17 |
+
if 1:
|
| 18 |
+
assert isinstance(preserve,tuple) or isinstance(preserve,list)
|
| 19 |
+
assert all(isinstance(p, int) and 0 <= p <= 11 for p in preserve)
|
| 20 |
+
import numpy as np
|
| 21 |
+
from PIL import Image
|
| 22 |
+
mask_path = path_img_2_path_mask(path_img)
|
| 23 |
+
mask = Image.open(mask_path).convert('L')
|
| 24 |
+
mask = np.array(mask)
|
| 25 |
+
mask = np.isin(mask, preserve)
|
| 26 |
+
return mask
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_forehead_mask(sm_mask):
|
| 31 |
+
# return mask (np bool) where the forehead (face above eyebrows) is True
|
| 32 |
+
sm_mask = np.array(sm_mask)
|
| 33 |
+
# 6 is face (excluding facial parts); keep only the forehead part
|
| 34 |
+
# First get all face pixels
|
| 35 |
+
face_mask = (sm_mask == 6)
|
| 36 |
+
# Get eyebrow pixels to determine forehead boundary
|
| 37 |
+
# if 2 in sm, ; elif 3(eyes) in ; elif 10(eye_glass) in ; else
|
| 38 |
+
if 2 in sm_mask:
|
| 39 |
+
eyebrow_mask = (sm_mask == 2)
|
| 40 |
+
eyebrow_coords = np.where(eyebrow_mask)
|
| 41 |
+
eyebrow_top = np.min(eyebrow_coords[0])
|
| 42 |
+
# Forehead is face region above eyebrows
|
| 43 |
+
forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < eyebrow_top)
|
| 44 |
+
elif 3 in sm_mask:
|
| 45 |
+
eye_mask = (sm_mask == 3)
|
| 46 |
+
eye_coords = np.where(eye_mask)
|
| 47 |
+
eye_top = np.min(eye_coords[0])
|
| 48 |
+
# Estimate forehead as region above eyes with some margin
|
| 49 |
+
forehead_threshold = eye_top - 20 # 20 pixels above eyes as forehead
|
| 50 |
+
forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < forehead_threshold)
|
| 51 |
+
elif 10 in sm_mask:
|
| 52 |
+
glass_mask = (sm_mask == 10)
|
| 53 |
+
glass_coords = np.where(glass_mask)
|
| 54 |
+
glass_top = np.min(glass_coords[0])
|
| 55 |
+
# Forehead is face region above glasses
|
| 56 |
+
forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < glass_top)
|
| 57 |
+
else:
|
| 58 |
+
# If no eyebrows detected, keep upper portion of face
|
| 59 |
+
face_coords = np.where(face_mask)
|
| 60 |
+
if len(face_coords[0]) > 0:
|
| 61 |
+
face_top = np.min(face_coords[0])
|
| 62 |
+
face_height = np.max(face_coords[0]) - face_top
|
| 63 |
+
forehead_threshold = face_top + face_height * 0.15 # top 15% as forehead
|
| 64 |
+
forehead_mask = face_mask & (np.arange(sm_mask.shape[0])[:, None] < forehead_threshold)
|
| 65 |
+
else:
|
| 66 |
+
forehead_mask = np.zeros_like(face_mask, dtype=bool)
|
| 67 |
+
forehead_mask = forehead_mask & face_mask
|
| 68 |
+
return forehead_mask
|
global_.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
some global variables
|
| 3 |
+
"""
|
| 4 |
+
task :int = None # current batch task id
|
| 5 |
+
|
| 6 |
+
TP_enable:bool = None # None means not set yet. should be set in imports.py
|
| 7 |
+
rank_:int = None
|
| 8 |
+
moduleName_2_adaRank:dict = {} # adaptive rank for each shared+LoRA module
|
| 9 |
+
|
hf_model.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Hub compatible model wrapper for UniBioTransfer.
|
| 3 |
+
Provides from_pretrained() and push_to_hub() functionality via PyTorchModelHubMixin.
|
| 4 |
+
"""
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import torch
|
| 7 |
+
import json
|
| 8 |
+
import copy
|
| 9 |
+
import os
|
| 10 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 11 |
+
|
| 12 |
+
import global_
|
| 13 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion, LandmarkExtractor
|
| 14 |
+
from ldm.util import instantiate_from_config
|
| 15 |
+
from omegaconf import OmegaConf
|
| 16 |
+
from pytorch_lightning import seed_everything
|
| 17 |
+
from MoE import offload_unused_tasks__LD
|
| 18 |
+
from multiTask_model import TaskSpecific_MoE, replace_modules_lossless
|
| 19 |
+
from my_py_lib.torch_util import cleanup_gpu_memory
|
| 20 |
+
|
| 21 |
+
TASKS = (0, 1, 2, 3)
|
| 22 |
+
TASK_NAME2ID = {"face": 0, "hair": 1, "motion": 2, "head": 3}
|
| 23 |
+
TASK_ID2NAME = {v: k for k, v in TASK_NAME2ID.items()}
|
| 24 |
+
|
| 25 |
+
SD14_FILENAME = "sd-v1-4.ckpt"
|
| 26 |
+
SD14_REPO = "CompVis/stable-diffusion-v-1-4-original"
|
| 27 |
+
PRETRAIN_REPO = "scy639/UniBioTransfer"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _load_first_stage_from_sd14(model, sd14_path):
|
| 31 |
+
"""Load first_stage_model (VAE) from SD v1.4 checkpoint."""
|
| 32 |
+
print(f"Loading first_stage_model from {sd14_path}")
|
| 33 |
+
sd14 = torch.load(str(sd14_path), map_location="cpu")
|
| 34 |
+
if isinstance(sd14, dict) and "state_dict" in sd14:
|
| 35 |
+
sd14_sd = sd14["state_dict"]
|
| 36 |
+
else:
|
| 37 |
+
sd14_sd = sd14
|
| 38 |
+
|
| 39 |
+
prefixes = ["first_stage_model.", "model.first_stage_model."]
|
| 40 |
+
fs_sd = {}
|
| 41 |
+
for prefix in prefixes:
|
| 42 |
+
for k, v in sd14_sd.items():
|
| 43 |
+
if k.startswith(prefix):
|
| 44 |
+
fs_sd[k[len(prefix):]] = v
|
| 45 |
+
if fs_sd:
|
| 46 |
+
break
|
| 47 |
+
|
| 48 |
+
if not fs_sd:
|
| 49 |
+
raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.")
|
| 50 |
+
|
| 51 |
+
model.first_stage_model.load_state_dict(fs_sd, strict=True)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class UniBioTransferModel(LatentDiffusion, PyTorchModelHubMixin):
|
| 55 |
+
"""
|
| 56 |
+
Hugging Face Hub compatible wrapper for UniBioTransfer.
|
| 57 |
+
|
| 58 |
+
Inherits from LatentDiffusion and adds HF Hub integration via PyTorchModelHubMixin.
|
| 59 |
+
|
| 60 |
+
Usage:
|
| 61 |
+
# Load model from HF Hub
|
| 62 |
+
model = UniBioTransferModel.from_pretrained("scy639/UniBioTransfer", task="face")
|
| 63 |
+
|
| 64 |
+
# Push to HF Hub
|
| 65 |
+
model.push_to_hub("your-repo/UniBioTransfer")
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
config: Model config dict (handled by PyTorchModelHubMixin)
|
| 69 |
+
task: Task name or ID (face/hair/motion/head)
|
| 70 |
+
**kwargs: Additional arguments passed to LatentDiffusion
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
def __init__(self, config=None, task="face", **kwargs):
|
| 74 |
+
self._task_name = task if isinstance(task, str) else TASK_ID2NAME.get(task, "face")
|
| 75 |
+
self._task_id = TASK_NAME2ID.get(self._task_name, 0) if isinstance(task, str) else task
|
| 76 |
+
|
| 77 |
+
global_.task = self._task_id
|
| 78 |
+
|
| 79 |
+
if config is None:
|
| 80 |
+
config = {}
|
| 81 |
+
|
| 82 |
+
super().__init__(**config)
|
| 83 |
+
|
| 84 |
+
self._hf_config = {
|
| 85 |
+
"task": self._task_name,
|
| 86 |
+
"task_id": self._task_id,
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
@classmethod
|
| 90 |
+
def from_pretrained(
|
| 91 |
+
cls,
|
| 92 |
+
pretrained_model_name_or_path=None,
|
| 93 |
+
task="face",
|
| 94 |
+
device="cuda",
|
| 95 |
+
download_sd14=True,
|
| 96 |
+
download_deps=True,
|
| 97 |
+
cache_dir=None,
|
| 98 |
+
**kwargs,
|
| 99 |
+
):
|
| 100 |
+
"""
|
| 101 |
+
Load model from Hugging Face Hub.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
pretrained_model_name_or_path: HF repo ID or local path.
|
| 105 |
+
Default: "scy639/UniBioTransfer"
|
| 106 |
+
task: Task name (face/hair/motion/head) or task ID (0/1/2/3)
|
| 107 |
+
device: Device to load model to ("cuda" or "cpu")
|
| 108 |
+
download_sd14: Whether to download SD v1.4 VAE weights
|
| 109 |
+
download_deps: Whether to download other dependencies (ArcFace, DLIB, face_parsing)
|
| 110 |
+
cache_dir: Cache directory for downloads
|
| 111 |
+
**kwargs: Additional arguments
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
UniBioTransferModel: Loaded model
|
| 115 |
+
"""
|
| 116 |
+
task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task
|
| 117 |
+
task_name = TASK_ID2NAME.get(task_id, "face")
|
| 118 |
+
|
| 119 |
+
global_.task = task_id
|
| 120 |
+
|
| 121 |
+
if pretrained_model_name_or_path is None:
|
| 122 |
+
pretrained_model_name_or_path = PRETRAIN_REPO
|
| 123 |
+
|
| 124 |
+
repo_id = pretrained_model_name_or_path
|
| 125 |
+
|
| 126 |
+
cache_dir = Path(cache_dir) if cache_dir else Path(".")
|
| 127 |
+
|
| 128 |
+
ckpt_path = cache_dir / "checkpoints" / "pretrained.ckpt"
|
| 129 |
+
json_path = cache_dir / "checkpoints" / "pretrained.json"
|
| 130 |
+
sd14_path = cache_dir / "checkpoints" / SD14_FILENAME
|
| 131 |
+
arcface_path = cache_dir / "Other_dependencies" / "arcface" / "model_ir_se50.pth"
|
| 132 |
+
face_parsing_path = cache_dir / "Other_dependencies" / "face_parsing" / "79999_iter.pth"
|
| 133 |
+
|
| 134 |
+
def _download_file(repo, filename, local_path):
|
| 135 |
+
local_path = Path(local_path)
|
| 136 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
| 137 |
+
print(f"Downloading {filename} from {repo}...")
|
| 138 |
+
token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
|
| 139 |
+
hf_hub_download(
|
| 140 |
+
repo_id=repo,
|
| 141 |
+
filename=filename,
|
| 142 |
+
local_dir=str(local_path.parent),
|
| 143 |
+
local_dir_use_symlinks=False,
|
| 144 |
+
token=token,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
if not ckpt_path.exists():
|
| 148 |
+
_download_file(repo_id, "checkpoints/pretrained.ckpt", ckpt_path)
|
| 149 |
+
if not json_path.exists():
|
| 150 |
+
_download_file(repo_id, "checkpoints/pretrained.json", json_path)
|
| 151 |
+
|
| 152 |
+
if download_sd14 and not sd14_path.exists():
|
| 153 |
+
_download_file(SD14_REPO, SD14_FILENAME, sd14_path)
|
| 154 |
+
|
| 155 |
+
if download_deps:
|
| 156 |
+
if not arcface_path.exists():
|
| 157 |
+
_download_file(repo_id, "Other_dependencies/arcface/model_ir_se50.pth", arcface_path)
|
| 158 |
+
if not face_parsing_path.exists():
|
| 159 |
+
_download_file(repo_id, "Other_dependencies/face_parsing/79999_iter.pth", face_parsing_path)
|
| 160 |
+
|
| 161 |
+
seed_everything(42)
|
| 162 |
+
|
| 163 |
+
cur_dir = Path(__file__).parent
|
| 164 |
+
yaml_path = cur_dir / "LatentDiffusion.yaml"
|
| 165 |
+
if not yaml_path.exists():
|
| 166 |
+
yaml_path = Path("LatentDiffusion.yaml")
|
| 167 |
+
|
| 168 |
+
model_config = OmegaConf.load(yaml_path).model
|
| 169 |
+
model = instantiate_from_config(model_config)
|
| 170 |
+
|
| 171 |
+
with open(json_path, 'r') as f:
|
| 172 |
+
global_.moduleName_2_adaRank = json.load(f)
|
| 173 |
+
print(f"Loaded adaptive rank config from {json_path}")
|
| 174 |
+
|
| 175 |
+
_src0 = copy.deepcopy(model.model.diffusion_model)
|
| 176 |
+
_src1 = copy.deepcopy(model.model.diffusion_model)
|
| 177 |
+
_src2 = copy.deepcopy(model.model.diffusion_model)
|
| 178 |
+
_src3 = copy.deepcopy(model.model.diffusion_model)
|
| 179 |
+
replace_modules_lossless(
|
| 180 |
+
model.model.diffusion_model,
|
| 181 |
+
[_src0, _src1, _src2, _src3],
|
| 182 |
+
[0, 1, 2, 3],
|
| 183 |
+
parent_name=".model.diffusion_model",
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
model.ID_proj_out = TaskSpecific_MoE([
|
| 187 |
+
copy.deepcopy(model.ID_proj_out),
|
| 188 |
+
copy.deepcopy(model.ID_proj_out),
|
| 189 |
+
copy.deepcopy(model.ID_proj_out),
|
| 190 |
+
], [0, 2, 3])
|
| 191 |
+
model.landmark_proj_out = TaskSpecific_MoE([
|
| 192 |
+
copy.deepcopy(model.landmark_proj_out),
|
| 193 |
+
copy.deepcopy(model.landmark_proj_out),
|
| 194 |
+
copy.deepcopy(model.landmark_proj_out),
|
| 195 |
+
], [0, 2, 3])
|
| 196 |
+
model.proj_out_source__head = TaskSpecific_MoE([
|
| 197 |
+
copy.deepcopy(model.proj_out_source__head),
|
| 198 |
+
copy.deepcopy(model.proj_out_source__head),
|
| 199 |
+
], [2, 3])
|
| 200 |
+
|
| 201 |
+
from util_and_constant import REFNET
|
| 202 |
+
if REFNET.ENABLE:
|
| 203 |
+
shared_ref = model.model.diffusion_model_refNet
|
| 204 |
+
src0 = shared_ref
|
| 205 |
+
src1 = copy.deepcopy(shared_ref)
|
| 206 |
+
src2 = copy.deepcopy(shared_ref)
|
| 207 |
+
src3 = copy.deepcopy(shared_ref)
|
| 208 |
+
replace_modules_lossless(shared_ref, [src0, src1, src2, src3], [0, 1, 2, 3], parent_name=".model.diffusion_model_refNet", for_refnet=True)
|
| 209 |
+
from ldm.models.diffusion.bank import Bank
|
| 210 |
+
model.model.bank = Bank(
|
| 211 |
+
reader=model.model.diffusion_model,
|
| 212 |
+
writer=model.model.diffusion_model_refNet
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
print(f"Loading model weights from {ckpt_path}")
|
| 216 |
+
pl_sd = torch.load(str(ckpt_path), map_location="cpu")
|
| 217 |
+
if isinstance(pl_sd, dict) and "state_dict" in pl_sd:
|
| 218 |
+
sd = pl_sd["state_dict"]
|
| 219 |
+
else:
|
| 220 |
+
sd = pl_sd
|
| 221 |
+
|
| 222 |
+
m, u = model.load_state_dict(sd, strict=False)
|
| 223 |
+
if len(m) > 0:
|
| 224 |
+
print(f"Missing keys: {len(m)}")
|
| 225 |
+
if len(u) > 0:
|
| 226 |
+
print(f"Unexpected keys: {len(u)}")
|
| 227 |
+
|
| 228 |
+
_load_first_stage_from_sd14(model, sd14_path)
|
| 229 |
+
|
| 230 |
+
# offload_unused_tasks__LD(model, task_id, method="cpu")
|
| 231 |
+
|
| 232 |
+
model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False)
|
| 233 |
+
cleanup_gpu_memory()
|
| 234 |
+
|
| 235 |
+
# ZeroGPU 兼容:只在 device 不是 "cpu" 且 CUDA 可用时才移动到 GPU
|
| 236 |
+
# 如果传入 device="cpu",保持模型在 CPU 上(ZeroGPU 初始化时不碰显卡)
|
| 237 |
+
if device != "cpu" and torch.cuda.is_available():
|
| 238 |
+
model = model.to(torch.device(device))
|
| 239 |
+
else:
|
| 240 |
+
model = model.to(torch.device("cpu"))
|
| 241 |
+
model.eval()
|
| 242 |
+
|
| 243 |
+
model._task_id = task_id
|
| 244 |
+
model._task_name = task_name
|
| 245 |
+
model._hf_config = {"task": task_name, "task_id": task_id}
|
| 246 |
+
|
| 247 |
+
return model
|
imports.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
#---------------------------------------------------------------------------------------------------------------------
|
| 5 |
+
from util_and_constant import *
|
| 6 |
+
from get_mask import *
|
| 7 |
+
from util_cv2 import *
|
| 8 |
+
|
infer.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------- Config -------------------------------------------------
|
| 2 |
+
num_workers :int = 1
|
| 3 |
+
DDIM_STEPS = 50
|
| 4 |
+
BATCH_SIZE = 1
|
| 5 |
+
FIXED_CODE = False
|
| 6 |
+
# for vis
|
| 7 |
+
SAVE_INTERMEDIATES = True
|
| 8 |
+
NUM_grid_in_a_column = 5
|
| 9 |
+
# ------------------------------------------------------------------------------------------------------------------------
|
| 10 |
+
import argparse
|
| 11 |
+
parser = argparse.ArgumentParser(description="Custom inference for tgt/ref image pairs.")
|
| 12 |
+
parser.add_argument("--task-name", type=str,
|
| 13 |
+
default='face',
|
| 14 |
+
help="face|hair|motion|head")
|
| 15 |
+
parser.add_argument("--out-dir", type=str, default='examples/outputs', help="Output directory")
|
| 16 |
+
# option 1: pass 2 paths
|
| 17 |
+
parser.add_argument("--tgt", type=str, default=None, help="Path to target image. if None, will use paths read from --pair-list")
|
| 18 |
+
parser.add_argument("--ref", type=str, default=None, help="Path to reference image")
|
| 19 |
+
# option 2: pass a txt containing paths
|
| 20 |
+
parser.add_argument("--pair-list", type=str, default='examples/inputs.txt', help="white-space-separated list file: tgt_path ref_path")
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
#-----------------------------------------set TASK--------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
task_name :str = args.task_name
|
| 26 |
+
TASK :int = {
|
| 27 |
+
'face': 0,
|
| 28 |
+
'hair': 1,
|
| 29 |
+
'motion': 2,
|
| 30 |
+
'head': 3,
|
| 31 |
+
}[task_name]
|
| 32 |
+
print(f'task: {task_name} transfer (ID: {TASK})')
|
| 33 |
+
# ------------------------------------------------------------------------------------------------------------------------
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
import sys
|
| 37 |
+
import os
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
|
| 40 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 41 |
+
|
| 42 |
+
from imports import *
|
| 43 |
+
import torch
|
| 44 |
+
import numpy as np
|
| 45 |
+
from omegaconf import OmegaConf
|
| 46 |
+
from PIL import Image
|
| 47 |
+
from tqdm import tqdm
|
| 48 |
+
from einops import rearrange
|
| 49 |
+
from torchvision.utils import make_grid
|
| 50 |
+
from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
|
| 51 |
+
from pytorch_lightning import seed_everything
|
| 52 |
+
from torch import autocast
|
| 53 |
+
from contextlib import nullcontext
|
| 54 |
+
import torchvision
|
| 55 |
+
|
| 56 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 57 |
+
from ldm.util import instantiate_from_config
|
| 58 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 59 |
+
from Dataset_custom import Dataset_custom
|
| 60 |
+
from MoE import offload_unused_tasks__LD
|
| 61 |
+
from ldm.models.diffusion.ddpm import LandmarkExtractor
|
| 62 |
+
from my_py_lib.torch_util import cleanup_gpu_memory
|
| 63 |
+
from gen_lmk_and_mask import gen_lmk_and_mask
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ------------------------------------------------------------------------------------------------------------------------
|
| 74 |
+
DDIM_ETA = 0.0
|
| 75 |
+
SCALE = 3.0
|
| 76 |
+
PRECISION = "full" # "full" or "autocast"
|
| 77 |
+
H = 512
|
| 78 |
+
W = 512
|
| 79 |
+
C = 4
|
| 80 |
+
F = 8
|
| 81 |
+
# ------------------------------------------------------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def load_first_stage_from_sd14(model: LatentDiffusion, sd14_path: Path) -> None:
|
| 85 |
+
print(f"Loading first_stage_model from {sd14_path}")
|
| 86 |
+
sd14 = torch.load(str(sd14_path), map_location="cpu")
|
| 87 |
+
if isinstance(sd14, dict) and "state_dict" in sd14:
|
| 88 |
+
sd14_sd = sd14["state_dict"]
|
| 89 |
+
else:
|
| 90 |
+
sd14_sd = sd14
|
| 91 |
+
|
| 92 |
+
prefixes = ["first_stage_model.", "model.first_stage_model."]
|
| 93 |
+
fs_sd = {}
|
| 94 |
+
for prefix in prefixes:
|
| 95 |
+
for k, v in sd14_sd.items():
|
| 96 |
+
if k.startswith(prefix):
|
| 97 |
+
fs_sd[k[len(prefix):]] = v
|
| 98 |
+
if fs_sd:
|
| 99 |
+
break
|
| 100 |
+
|
| 101 |
+
if not fs_sd:
|
| 102 |
+
raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.")
|
| 103 |
+
|
| 104 |
+
model.first_stage_model.load_state_dict(fs_sd, strict=True)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def save_sample_by_decode(x, model, base_path, segment_id, intermediate_num):
|
| 108 |
+
x = model.decode_first_stage(x)
|
| 109 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
| 110 |
+
x = x.cpu().permute(0, 2, 3, 1).numpy()
|
| 111 |
+
for i in range(len(x)):
|
| 112 |
+
img = Image.fromarray((x[i] * 255).astype(np.uint8))
|
| 113 |
+
save_path = Path(base_path) / segment_id
|
| 114 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
| 115 |
+
img.save(save_path / f"{intermediate_num}.png")
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_tensor_clip(normalize=True, toTensor=True):
|
| 119 |
+
transform_list = []
|
| 120 |
+
if toTensor:
|
| 121 |
+
transform_list += [torchvision.transforms.ToTensor()]
|
| 122 |
+
if normalize:
|
| 123 |
+
transform_list += [
|
| 124 |
+
torchvision.transforms.Normalize(
|
| 125 |
+
(0.48145466, 0.4578275, 0.40821073),
|
| 126 |
+
(0.26862954, 0.26130258, 0.27577711),
|
| 127 |
+
)
|
| 128 |
+
]
|
| 129 |
+
return torchvision.transforms.Compose(transform_list)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def load_model_from_config(ckpt, verbose=1):
|
| 133 |
+
if 1:
|
| 134 |
+
ckpt = Path(ckpt)
|
| 135 |
+
print(f"Loading model from {ckpt}")
|
| 136 |
+
pl_sd = torch.load(str(ckpt), map_location="cpu")
|
| 137 |
+
if isinstance(pl_sd, dict) and "state_dict" in pl_sd:
|
| 138 |
+
sd = pl_sd["state_dict"]
|
| 139 |
+
else:
|
| 140 |
+
sd = pl_sd
|
| 141 |
+
else:
|
| 142 |
+
print("DEBUG_skip_load_ckpt")
|
| 143 |
+
if 1:
|
| 144 |
+
from init_model import get_moe
|
| 145 |
+
model: LatentDiffusion = get_moe()
|
| 146 |
+
model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False)
|
| 147 |
+
cleanup_gpu_memory()
|
| 148 |
+
if 1:
|
| 149 |
+
m, u = model.load_state_dict(sd, strict=False)
|
| 150 |
+
if len(m) > 0 and verbose:
|
| 151 |
+
print("missing keys:")
|
| 152 |
+
pretty_print_torch_module_keys(m)
|
| 153 |
+
if len(u) > 0 and verbose:
|
| 154 |
+
print("unexpected keys:")
|
| 155 |
+
pretty_print_torch_module_keys(u)
|
| 156 |
+
load_first_stage_from_sd14(model, SD14_localpath)
|
| 157 |
+
|
| 158 |
+
offload_unused_tasks__LD(model, TASK, method="del") # for save cuda mem
|
| 159 |
+
model.cuda()
|
| 160 |
+
model.eval()
|
| 161 |
+
return model
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def load_pairs(pair_list, tgt, ref):
|
| 167 |
+
if tgt and ref:
|
| 168 |
+
pairs = [(tgt, ref), ]
|
| 169 |
+
elif pair_list:
|
| 170 |
+
pairs = []
|
| 171 |
+
with open(pair_list, "r") as f:
|
| 172 |
+
for line_num, line in enumerate(f, start=1):
|
| 173 |
+
line = line.strip()
|
| 174 |
+
if not line or line.startswith("#"):
|
| 175 |
+
continue
|
| 176 |
+
parts = line.split(" ")
|
| 177 |
+
if len(parts) != 2:
|
| 178 |
+
raise ValueError(f"Invalid pair list line {line_num}: expected white-space-separated tgt/ref. got {parts=}")
|
| 179 |
+
pairs.append((parts[0], parts[1]))
|
| 180 |
+
else:
|
| 181 |
+
raise ValueError("No input pairs provided. Use --tgt/--ref or --pair-list.")
|
| 182 |
+
print(f"{pairs=}")
|
| 183 |
+
return pairs
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def un_norm(x):
|
| 187 |
+
return (x + 1.0) / 2.0
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def un_norm_clip(x1):
|
| 191 |
+
x = x1 * 1.0
|
| 192 |
+
reduce = False
|
| 193 |
+
if len(x.shape) == 3:
|
| 194 |
+
x = x.unsqueeze(0)
|
| 195 |
+
reduce = True
|
| 196 |
+
x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466
|
| 197 |
+
x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275
|
| 198 |
+
x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073
|
| 199 |
+
if reduce:
|
| 200 |
+
x = x.squeeze(0)
|
| 201 |
+
return x
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
if __name__ == "__main__":
|
| 205 |
+
pairs = load_pairs(args.pair_list, args.tgt, args.ref)
|
| 206 |
+
|
| 207 |
+
out_dir = Path(args.out_dir)
|
| 208 |
+
result_path = out_dir / "results"
|
| 209 |
+
grid_path = out_dir / "grid"
|
| 210 |
+
inter_path = out_dir / "intermediates"
|
| 211 |
+
inter_pred_path = inter_path / "pred_x0"
|
| 212 |
+
inter_noised_path = inter_path / "noised"
|
| 213 |
+
out_dir.mkdir(parents=False, exist_ok=True)
|
| 214 |
+
result_path.mkdir(parents=False, exist_ok=True)
|
| 215 |
+
grid_path.mkdir(parents=False, exist_ok=True)
|
| 216 |
+
inter_path.mkdir(parents=False, exist_ok=True)
|
| 217 |
+
if SAVE_INTERMEDIATES:
|
| 218 |
+
inter_pred_path.mkdir(parents=False, exist_ok=True)
|
| 219 |
+
inter_noised_path.mkdir(parents=False, exist_ok=True)
|
| 220 |
+
paths_tgt = [p[0] for p in pairs]
|
| 221 |
+
paths_ref = [p[1] for p in pairs]
|
| 222 |
+
gen_lmk_and_mask(paths_tgt + paths_ref)
|
| 223 |
+
|
| 224 |
+
seed_everything(42)
|
| 225 |
+
|
| 226 |
+
model: LatentDiffusion = load_model_from_config(PRETRAIN_CKPT_PATH, )
|
| 227 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 228 |
+
model = model.to(device)
|
| 229 |
+
sampler = DDIMSampler(model)
|
| 230 |
+
|
| 231 |
+
dataset = Dataset_custom(
|
| 232 |
+
"test",
|
| 233 |
+
task=TASK,
|
| 234 |
+
paths_tgt=paths_tgt,
|
| 235 |
+
paths_ref=paths_ref,
|
| 236 |
+
)
|
| 237 |
+
dataloader = torch.utils.data.DataLoader(
|
| 238 |
+
dataset,
|
| 239 |
+
batch_size=BATCH_SIZE,
|
| 240 |
+
num_workers=num_workers,
|
| 241 |
+
pin_memory=True,
|
| 242 |
+
shuffle=False,
|
| 243 |
+
drop_last=False,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
start_code = None
|
| 247 |
+
if FIXED_CODE:
|
| 248 |
+
start_code = torch.randn([BATCH_SIZE, C, H // F, W // F], device=device)
|
| 249 |
+
|
| 250 |
+
precision_scope = autocast if PRECISION == "autocast" else nullcontext
|
| 251 |
+
grids = []
|
| 252 |
+
grid_stems = []
|
| 253 |
+
|
| 254 |
+
with torch.no_grad():
|
| 255 |
+
with precision_scope("cuda"):
|
| 256 |
+
with model.ema_scope():
|
| 257 |
+
for test_batch, prior, test_model_kwargs, out_stem_batch in tqdm(dataloader):
|
| 258 |
+
model.set_task(test_model_kwargs)
|
| 259 |
+
bs = test_batch.shape[0]
|
| 260 |
+
|
| 261 |
+
batch_ = {
|
| 262 |
+
**test_model_kwargs,
|
| 263 |
+
"GT": torch.zeros_like(test_model_kwargs["inpaint_image"]),
|
| 264 |
+
}
|
| 265 |
+
batch_, c = model.get_input_and_conditioning(batch_, device=device)
|
| 266 |
+
z_inpaint = batch_["z4_inpaint"]
|
| 267 |
+
z_inpaint_mask = batch_["tgt_mask_64"]
|
| 268 |
+
z_ref = batch_["z_ref"]
|
| 269 |
+
z9 = batch_["z9"]
|
| 270 |
+
|
| 271 |
+
uc = None
|
| 272 |
+
if SCALE != 1.0:
|
| 273 |
+
uc = model.learnable_vector[TASK].repeat(bs, 1, 1)
|
| 274 |
+
|
| 275 |
+
shape = [C, H // F, W // F]
|
| 276 |
+
local_start_code = start_code
|
| 277 |
+
if FIXED_CODE and (local_start_code is None or local_start_code.shape[0] != bs):
|
| 278 |
+
local_start_code = torch.randn([bs, C, H // F, W // F], device=device)
|
| 279 |
+
samples_ddim, intermediates = sampler.sample(
|
| 280 |
+
S=DDIM_STEPS,
|
| 281 |
+
conditioning=c,
|
| 282 |
+
batch_size=bs,
|
| 283 |
+
shape=shape,
|
| 284 |
+
verbose=False,
|
| 285 |
+
unconditional_guidance_scale=SCALE,
|
| 286 |
+
unconditional_conditioning=uc,
|
| 287 |
+
eta=DDIM_ETA,
|
| 288 |
+
x_T=local_start_code,
|
| 289 |
+
log_every_t=100,
|
| 290 |
+
z_inpaint=z_inpaint,
|
| 291 |
+
z_inpaint_mask=z_inpaint_mask,
|
| 292 |
+
z_ref=z_ref,
|
| 293 |
+
z9=z9,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
if SAVE_INTERMEDIATES:
|
| 297 |
+
intermediate_pred_x0 = intermediates["pred_x0"]
|
| 298 |
+
intermediate_noised = intermediates["x_inter"]
|
| 299 |
+
for i in range(len(intermediate_pred_x0)):
|
| 300 |
+
for j in range(bs):
|
| 301 |
+
stem = f"{out_stem_batch[j]}"
|
| 302 |
+
save_sample_by_decode(
|
| 303 |
+
intermediate_pred_x0[i][j : j + 1],
|
| 304 |
+
model,
|
| 305 |
+
inter_pred_path,
|
| 306 |
+
stem,
|
| 307 |
+
i,
|
| 308 |
+
)
|
| 309 |
+
save_sample_by_decode(
|
| 310 |
+
intermediate_noised[i][j : j + 1],
|
| 311 |
+
model,
|
| 312 |
+
inter_noised_path,
|
| 313 |
+
stem,
|
| 314 |
+
i,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
| 318 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
| 319 |
+
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
|
| 320 |
+
|
| 321 |
+
x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2)
|
| 322 |
+
for i, x_sample in enumerate(x_checked_image_torch):
|
| 323 |
+
stem = f"{out_stem_batch[i]}"
|
| 324 |
+
out_path = result_path / f"{stem}.png"
|
| 325 |
+
img = Image.fromarray((x_sample.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
|
| 326 |
+
img.save(out_path)
|
| 327 |
+
print(f"{out_path=}")
|
| 328 |
+
|
| 329 |
+
for i, x_sample in enumerate(x_checked_image_torch):
|
| 330 |
+
all_img = []
|
| 331 |
+
all_img.append(un_norm(test_batch[i]).cpu())
|
| 332 |
+
if TASK != 2:
|
| 333 |
+
ref_img = test_model_kwargs["ref_imgs"].squeeze(1)
|
| 334 |
+
ref_img = torchvision.transforms.Resize([512, 512])(ref_img)
|
| 335 |
+
ref_img = un_norm_clip(ref_img[i]).cpu()
|
| 336 |
+
else:
|
| 337 |
+
ref_img = un_norm(test_model_kwargs["ref512"].squeeze(1)[i]).cpu()
|
| 338 |
+
all_img.append(ref_img)
|
| 339 |
+
all_img.append(x_sample)
|
| 340 |
+
|
| 341 |
+
grid = torch.stack(all_img, 0)
|
| 342 |
+
grid = make_grid(grid)
|
| 343 |
+
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
|
| 344 |
+
img = Image.fromarray(grid.astype(np.uint8))
|
| 345 |
+
stem = f"{out_stem_batch[i]}"
|
| 346 |
+
path_save_img = grid_path / f"grid-{stem}.jpg"
|
| 347 |
+
img.save(path_save_img)
|
| 348 |
+
print(f"{path_save_img=}")
|
| 349 |
+
grids.append(img)
|
| 350 |
+
grid_stems.append(stem)
|
| 351 |
+
if len(grids) >= NUM_grid_in_a_column:
|
| 352 |
+
stem_start = grid_stems[0]
|
| 353 |
+
stem_end = grid_stems[-1]
|
| 354 |
+
grid_column = imgs_2_grid_A(
|
| 355 |
+
grids,
|
| 356 |
+
grid_layout='column',
|
| 357 |
+
grid_path=os.path.join(grid_path, f"{stem_start}--{stem_end}.jpg"),
|
| 358 |
+
)
|
| 359 |
+
grids = []
|
| 360 |
+
grid_stems = []
|
| 361 |
+
|
| 362 |
+
model.unset_task()
|
| 363 |
+
|
| 364 |
+
print(f"Your samples are ready and waiting for you here: {out_dir}")
|
| 365 |
+
|
| 366 |
+
|
infer_hf.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
High-level inference pipeline for UniBioTransfer.
|
| 3 |
+
Designed for easy use in Hugging Face Spaces and other applications.
|
| 4 |
+
|
| 5 |
+
ZeroGPU Compatible:
|
| 6 |
+
- Supports CPU initialization (device="cpu")
|
| 7 |
+
- Dynamically switches to CUDA during inference when called from @spaces.GPU
|
| 8 |
+
"""
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import cv2
|
| 14 |
+
|
| 15 |
+
import global_
|
| 16 |
+
from hf_model import UniBioTransferModel, TASK_NAME2ID, TASK_ID2NAME
|
| 17 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 18 |
+
from pytorch_lightning import seed_everything
|
| 19 |
+
|
| 20 |
+
DDIM_STEPS_DEFAULT = 50
|
| 21 |
+
SCALE_DEFAULT = 3.0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
H, W, C, F = 512, 512, 4, 8
|
| 25 |
+
class UniBioTransferPipeline:
|
| 26 |
+
"""
|
| 27 |
+
High-level pipeline for UniBioTransfer inference.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, model, task="face", device="cpu"):
|
| 31 |
+
"""
|
| 32 |
+
Initialize pipeline with a loaded model.
|
| 33 |
+
"""
|
| 34 |
+
self.model = model
|
| 35 |
+
self.task = task
|
| 36 |
+
self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task
|
| 37 |
+
self._init_device = device
|
| 38 |
+
|
| 39 |
+
global_.task = self.task_id
|
| 40 |
+
self.model.task = self.task_id
|
| 41 |
+
|
| 42 |
+
self.sampler = DDIMSampler(model)
|
| 43 |
+
|
| 44 |
+
@classmethod
|
| 45 |
+
def from_pretrained(
|
| 46 |
+
cls,
|
| 47 |
+
repo_id="scy639/UniBioTransfer",
|
| 48 |
+
task="face",
|
| 49 |
+
device="cpu",
|
| 50 |
+
cache_dir=None,
|
| 51 |
+
**kwargs,
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
Load pipeline from Hugging Face Hub.
|
| 55 |
+
"""
|
| 56 |
+
model = UniBioTransferModel.from_pretrained(
|
| 57 |
+
pretrained_model_name_or_path=repo_id,
|
| 58 |
+
task=task,
|
| 59 |
+
device=device,
|
| 60 |
+
cache_dir=cache_dir,
|
| 61 |
+
**kwargs,
|
| 62 |
+
)
|
| 63 |
+
return cls(model, task=task, device=device)
|
| 64 |
+
|
| 65 |
+
def set_task(self, task):
|
| 66 |
+
"""Switch to a different task."""
|
| 67 |
+
self.task = task
|
| 68 |
+
self.task_id = TASK_NAME2ID.get(task, task) if isinstance(task, str) else task
|
| 69 |
+
global_.task = self.task_id
|
| 70 |
+
self.model.task = self.task_id
|
| 71 |
+
|
| 72 |
+
def __call__(
|
| 73 |
+
self,
|
| 74 |
+
tgt_image,
|
| 75 |
+
ref_image,
|
| 76 |
+
ddim_steps=DDIM_STEPS_DEFAULT,
|
| 77 |
+
scale=SCALE_DEFAULT,
|
| 78 |
+
seed=42,
|
| 79 |
+
num_images=1,
|
| 80 |
+
):
|
| 81 |
+
"""
|
| 82 |
+
Run inference on a pair of images.
|
| 83 |
+
"""
|
| 84 |
+
seed_everything(seed)
|
| 85 |
+
|
| 86 |
+
tgt_img = self._load_image(tgt_image)
|
| 87 |
+
ref_img = self._load_image(ref_image)
|
| 88 |
+
|
| 89 |
+
tgt_img = self._resize_image(tgt_img, (H, W))
|
| 90 |
+
ref_img = self._resize_image(ref_img, (H, W))
|
| 91 |
+
|
| 92 |
+
result_tensors = self._run_inference(tgt_img, ref_img, ddim_steps, scale, num_images)
|
| 93 |
+
|
| 94 |
+
result_imgs = [self._postprocess(result_tensors[i]) for i in range(result_tensors.shape[0])]
|
| 95 |
+
return result_imgs
|
| 96 |
+
|
| 97 |
+
def _load_image(self, img):
|
| 98 |
+
"""Load image from various formats."""
|
| 99 |
+
if isinstance(img, Image.Image):
|
| 100 |
+
return img.convert("RGB")
|
| 101 |
+
elif isinstance(img, np.ndarray):
|
| 102 |
+
return Image.fromarray(img).convert("RGB")
|
| 103 |
+
elif isinstance(img, (str, Path)):
|
| 104 |
+
return Image.open(img).convert("RGB")
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"Unsupported image type: {type(img)}")
|
| 107 |
+
|
| 108 |
+
def _resize_image(self, img, size):
|
| 109 |
+
"""Resize image to target size."""
|
| 110 |
+
if img.size != size:
|
| 111 |
+
img = img.resize(size, Image.LANCZOS)
|
| 112 |
+
return img
|
| 113 |
+
|
| 114 |
+
def _run_inference(self, tgt_img, ref_img, ddim_steps, scale, num_images):
|
| 115 |
+
"""
|
| 116 |
+
Run diffusion sampling.
|
| 117 |
+
完全复用 infer.py 的逻辑,使用 dataloader。
|
| 118 |
+
"""
|
| 119 |
+
from Dataset_custom import Dataset_custom
|
| 120 |
+
from gen_lmk_and_mask import gen_lmk_and_mask
|
| 121 |
+
import tempfile
|
| 122 |
+
|
| 123 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 124 |
+
tgt_path = Path(tmpdir) / "tgt.png"
|
| 125 |
+
ref_path = Path(tmpdir) / "ref.png"
|
| 126 |
+
tgt_img.save(tgt_path)
|
| 127 |
+
ref_img.save(ref_path)
|
| 128 |
+
|
| 129 |
+
gen_lmk_and_mask([str(tgt_path), str(ref_path)], write_cache=True)
|
| 130 |
+
|
| 131 |
+
dataset = Dataset_custom(
|
| 132 |
+
"test",
|
| 133 |
+
task=self.task_id,
|
| 134 |
+
paths_tgt=[str(tgt_path)],
|
| 135 |
+
paths_ref=[str(ref_path)],
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
dataloader = torch.utils.data.DataLoader(
|
| 139 |
+
dataset,
|
| 140 |
+
batch_size=1,
|
| 141 |
+
num_workers=1,
|
| 142 |
+
pin_memory=True,
|
| 143 |
+
shuffle=False,
|
| 144 |
+
drop_last=False,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 148 |
+
self.model = self.model.to(run_device)
|
| 149 |
+
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
for test_batch, prior, test_model_kwargs, out_stem_batch in dataloader:
|
| 152 |
+
test_batch = test_batch.to(run_device)
|
| 153 |
+
if test_batch.shape[0] == 1:
|
| 154 |
+
test_batch = test_batch.repeat(num_images, 1, 1, 1)
|
| 155 |
+
if isinstance(prior, torch.Tensor):
|
| 156 |
+
prior = prior.to(run_device)
|
| 157 |
+
if prior.shape[0] == 1:
|
| 158 |
+
prior = prior.repeat(num_images, 1, 1, 1)
|
| 159 |
+
for k, v in test_model_kwargs.items():
|
| 160 |
+
if isinstance(v, torch.Tensor):
|
| 161 |
+
v = v.to(run_device)
|
| 162 |
+
if v.shape[0] == 1:
|
| 163 |
+
repeats = [num_images] + [1] * (v.ndim - 1)
|
| 164 |
+
v = v.repeat(*repeats)
|
| 165 |
+
test_model_kwargs[k] = v
|
| 166 |
+
elif isinstance(v, dict):
|
| 167 |
+
new_v = {}
|
| 168 |
+
for kk, vv in v.items():
|
| 169 |
+
if isinstance(vv, torch.Tensor):
|
| 170 |
+
vv = vv.to(run_device)
|
| 171 |
+
if vv.shape[0] == 1:
|
| 172 |
+
repeats = [num_images] + [1] * (vv.ndim - 1)
|
| 173 |
+
vv = vv.repeat(*repeats)
|
| 174 |
+
new_v[kk] = vv
|
| 175 |
+
else:
|
| 176 |
+
new_v[kk] = vv
|
| 177 |
+
test_model_kwargs[k] = new_v
|
| 178 |
+
elif isinstance(v, list):
|
| 179 |
+
test_model_kwargs[k] = v * num_images
|
| 180 |
+
|
| 181 |
+
self.model.set_task(test_model_kwargs)
|
| 182 |
+
bs = num_images
|
| 183 |
+
|
| 184 |
+
batch_ = {
|
| 185 |
+
**test_model_kwargs,
|
| 186 |
+
"GT": torch.zeros(num_images, *test_model_kwargs["inpaint_image"].shape[1:], device=run_device),
|
| 187 |
+
}
|
| 188 |
+
batch_, c = self.model.get_input_and_conditioning(batch_, device=run_device)
|
| 189 |
+
|
| 190 |
+
z_inpaint = batch_["z4_inpaint"]
|
| 191 |
+
z_inpaint_mask = batch_["tgt_mask_64"]
|
| 192 |
+
z_ref = batch_["z_ref"]
|
| 193 |
+
z9 = batch_["z9"]
|
| 194 |
+
|
| 195 |
+
uc = None
|
| 196 |
+
if scale != 1.0:
|
| 197 |
+
uc = self.model.learnable_vector[self.task_id].repeat(bs, 1, 1)
|
| 198 |
+
|
| 199 |
+
shape = [C, H // F, W // F]
|
| 200 |
+
start_code = None
|
| 201 |
+
|
| 202 |
+
samples_ddim, _ = self.sampler.sample(
|
| 203 |
+
S=ddim_steps,
|
| 204 |
+
conditioning=c,
|
| 205 |
+
batch_size=bs,
|
| 206 |
+
shape=shape,
|
| 207 |
+
verbose=False,
|
| 208 |
+
unconditional_guidance_scale=scale,
|
| 209 |
+
unconditional_conditioning=uc,
|
| 210 |
+
eta=0.0,
|
| 211 |
+
x_T=start_code,
|
| 212 |
+
log_every_t=100,
|
| 213 |
+
z_inpaint=z_inpaint,
|
| 214 |
+
z_inpaint_mask=z_inpaint_mask,
|
| 215 |
+
z_ref=z_ref,
|
| 216 |
+
z9=z9,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
|
| 220 |
+
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
| 221 |
+
|
| 222 |
+
self.model.unset_task()
|
| 223 |
+
|
| 224 |
+
return x_samples_ddim
|
| 225 |
+
|
| 226 |
+
def _postprocess(self, tensor):
|
| 227 |
+
"""Convert model output tensor to PIL Image."""
|
| 228 |
+
img_array = tensor.cpu().permute(1, 2, 0).numpy()
|
| 229 |
+
img_array = (img_array * 255).astype(np.uint8)
|
| 230 |
+
return Image.fromarray(img_array)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def infer_single(
|
| 234 |
+
tgt_path,
|
| 235 |
+
ref_path,
|
| 236 |
+
task="face",
|
| 237 |
+
output_path=None,
|
| 238 |
+
ddim_steps=DDIM_STEPS_DEFAULT,
|
| 239 |
+
scale=SCALE_DEFAULT,
|
| 240 |
+
device="cuda",
|
| 241 |
+
):
|
| 242 |
+
"""
|
| 243 |
+
Convenience function for single inference.
|
| 244 |
+
"""
|
| 245 |
+
pipeline = UniBioTransferPipeline.from_pretrained(task=task, device=device)
|
| 246 |
+
result = pipeline(tgt_path, ref_path, ddim_steps=ddim_steps, scale=scale)
|
| 247 |
+
|
| 248 |
+
if output_path is not None:
|
| 249 |
+
result.save(output_path)
|
| 250 |
+
print(f"Saved result to {output_path}")
|
| 251 |
+
|
| 252 |
+
return result
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
|
| 256 |
+
import argparse
|
| 257 |
+
|
| 258 |
+
parser = argparse.ArgumentParser(description="UniBioTransfer inference")
|
| 259 |
+
parser.add_argument("--task", type=str, default="face", choices=["face", "hair", "motion", "head"])
|
| 260 |
+
parser.add_argument("--tgt", type=str, required=True, help="Path to target image")
|
| 261 |
+
parser.add_argument("--ref", type=str, required=True, help="Path to reference image")
|
| 262 |
+
parser.add_argument("--out", type=str, default="result.png", help="Output path")
|
| 263 |
+
parser.add_argument("--ddim-steps", type=int, default=50)
|
| 264 |
+
parser.add_argument("--scale", type=float, default=3.0)
|
| 265 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 266 |
+
|
| 267 |
+
args = parser.parse_args()
|
| 268 |
+
|
| 269 |
+
result = infer_single(
|
| 270 |
+
args.tgt,
|
| 271 |
+
args.ref,
|
| 272 |
+
task=args.task,
|
| 273 |
+
output_path=args.out,
|
| 274 |
+
ddim_steps=args.ddim_steps,
|
| 275 |
+
scale=args.scale,
|
| 276 |
+
device=args.device,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
print(f"Inference complete. Result shape: {result.size}")
|
init_model.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys,os
|
| 2 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 3 |
+
if __name__=='__main__': sys.path.append(os.path.abspath(os.path.join(cur_dir, '..')))
|
| 4 |
+
|
| 5 |
+
from imports import *
|
| 6 |
+
import json
|
| 7 |
+
import argparse, os, sys, glob
|
| 8 |
+
import cv2
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from MoE import *
|
| 12 |
+
from multiTask_model import *
|
| 13 |
+
from lora_layers import *
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from tqdm import tqdm, trange
|
| 17 |
+
from itertools import islice
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from torchvision.utils import make_grid
|
| 20 |
+
from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A
|
| 21 |
+
import time
|
| 22 |
+
import copy
|
| 23 |
+
from pytorch_lightning import seed_everything
|
| 24 |
+
from torch import autocast
|
| 25 |
+
from contextlib import contextmanager, nullcontext
|
| 26 |
+
import torchvision
|
| 27 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 28 |
+
from ldm.models.diffusion.bank import Bank
|
| 29 |
+
from ldm.util import instantiate_from_config
|
| 30 |
+
|
| 31 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 32 |
+
|
| 33 |
+
from transformers import AutoFeatureExtractor
|
| 34 |
+
|
| 35 |
+
# import clip
|
| 36 |
+
from torchvision.transforms import Resize
|
| 37 |
+
from fnmatch import fnmatch
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
from PIL import Image
|
| 41 |
+
from torchvision.transforms import PILToTensor
|
| 42 |
+
#----------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_moe():
|
| 46 |
+
if 1:
|
| 47 |
+
seed_everything(42)
|
| 48 |
+
# torch.cuda.set_device(opt.device_ID)
|
| 49 |
+
model :LatentDiffusion = instantiate_from_config(OmegaConf.load(f"LatentDiffusion.yaml").model,)
|
| 50 |
+
if REFNET.ENABLE:
|
| 51 |
+
assert model.model.diffusion_model_refNet.is_refNet
|
| 52 |
+
|
| 53 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 54 |
+
device = torch.device("cpu")
|
| 55 |
+
model = model.to(device)
|
| 56 |
+
if FOR_upcycle_ckpt_GEN_or_USE:
|
| 57 |
+
del model.ptsM_Generator
|
| 58 |
+
|
| 59 |
+
def average_module_weight(
|
| 60 |
+
src_modules: list,
|
| 61 |
+
):
|
| 62 |
+
"""Average the weights of multiple modules"""
|
| 63 |
+
if not src_modules:
|
| 64 |
+
return None
|
| 65 |
+
# Get the state dict of the first module as template
|
| 66 |
+
avg_state_dict = {}
|
| 67 |
+
first_state_dict = src_modules[0].state_dict()
|
| 68 |
+
# Initialize with zeros
|
| 69 |
+
for key in first_state_dict:
|
| 70 |
+
avg_state_dict[key] = torch.zeros_like(first_state_dict[key])
|
| 71 |
+
# Sum
|
| 72 |
+
for module in src_modules:
|
| 73 |
+
module_state_dict = module.state_dict()
|
| 74 |
+
for key in avg_state_dict:
|
| 75 |
+
avg_state_dict[key] += module_state_dict[key]
|
| 76 |
+
# Average
|
| 77 |
+
for key in avg_state_dict:
|
| 78 |
+
avg_state_dict[key] /= len(src_modules)
|
| 79 |
+
return avg_state_dict
|
| 80 |
+
def recursive_average_module_weight(
|
| 81 |
+
tgt_module: nn.Module,
|
| 82 |
+
src_modules: list,
|
| 83 |
+
cb,
|
| 84 |
+
):
|
| 85 |
+
"""
|
| 86 |
+
Recursively find modules and replace with averaged weights based on callback
|
| 87 |
+
"""
|
| 88 |
+
for name, child in tgt_module.named_children():
|
| 89 |
+
if 1: # Get corresponding modules from source models
|
| 90 |
+
src_child_modules = []
|
| 91 |
+
for src_module in src_modules:
|
| 92 |
+
src_child = getattr(src_module, name)
|
| 93 |
+
assert src_child is not None,name
|
| 94 |
+
src_child_modules.append(src_child)
|
| 95 |
+
# assert not isinstance(child, TaskSpecific_MoE)
|
| 96 |
+
if cb(child, name, tgt_module):
|
| 97 |
+
print(f"[recursive_average_module_weight] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
|
| 98 |
+
# Average & load
|
| 99 |
+
avg_weights = average_module_weight(src_child_modules)
|
| 100 |
+
child.load_state_dict(avg_weights)
|
| 101 |
+
else:
|
| 102 |
+
recursive_average_module_weight(child, src_child_modules, cb)
|
| 103 |
+
return tgt_module
|
| 104 |
+
|
| 105 |
+
def replace_module_with_TaskSpecific(
|
| 106 |
+
tgt_module: nn.Module,# tgt module
|
| 107 |
+
src_modules: list,
|
| 108 |
+
cb,
|
| 109 |
+
parent_name: str = "",
|
| 110 |
+
depth :int = 0,
|
| 111 |
+
):
|
| 112 |
+
for name, child in tgt_module.named_children():
|
| 113 |
+
if 1: # Get corresponding modules from source models
|
| 114 |
+
src_child_modules = []
|
| 115 |
+
for src_module in src_modules:
|
| 116 |
+
src_child = getattr(src_module, name)
|
| 117 |
+
assert src_child is not None,name
|
| 118 |
+
src_child_modules.append(src_child)
|
| 119 |
+
assert not isinstance(child, TaskSpecific_MoE)
|
| 120 |
+
full_name = f"{parent_name}.{name}"
|
| 121 |
+
if cb(child, name, full_name, tgt_module):
|
| 122 |
+
print(f"[replace_module_with_TaskSpecific] {name=} child: {repr(child)[:50]} tgt_module: {repr(tgt_module)[:50]}")
|
| 123 |
+
setattr(tgt_module, name, TaskSpecific_MoE(src_child_modules,TASKS))
|
| 124 |
+
else:
|
| 125 |
+
if depth<=0:
|
| 126 |
+
replace_module_with_TaskSpecific(child, src_child_modules,cb,parent_name=full_name,depth=depth+1)
|
| 127 |
+
return tgt_module
|
| 128 |
+
|
| 129 |
+
if not FOR_upcycle_ckpt_GEN_or_USE:
|
| 130 |
+
modelMOE :LatentDiffusion = model
|
| 131 |
+
del model
|
| 132 |
+
if 1: # ensure distinct module instances per task (avoid shared identities)
|
| 133 |
+
with open(PRETRAIN_JSON_PATH, 'r') as f: global_.moduleName_2_adaRank = json.load(f)
|
| 134 |
+
print(f"loaded from {PRETRAIN_JSON_PATH=}")
|
| 135 |
+
_src0 = copy.deepcopy(modelMOE.model.diffusion_model)
|
| 136 |
+
_src1 = copy.deepcopy(modelMOE.model.diffusion_model)
|
| 137 |
+
_src2 = copy.deepcopy(modelMOE.model.diffusion_model)
|
| 138 |
+
_src3 = copy.deepcopy(modelMOE.model.diffusion_model)
|
| 139 |
+
replace_modules_lossless(
|
| 140 |
+
modelMOE.model.diffusion_model,
|
| 141 |
+
[ _src0, _src1, _src2, _src3 ],
|
| 142 |
+
[0,1,2,3],
|
| 143 |
+
parent_name=".model.diffusion_model",
|
| 144 |
+
)
|
| 145 |
+
# Build-time dummy wrapping for task-specific heads so that ckpt keys match
|
| 146 |
+
modelMOE.ID_proj_out = TaskSpecific_MoE([
|
| 147 |
+
copy.deepcopy(modelMOE.ID_proj_out),
|
| 148 |
+
copy.deepcopy(modelMOE.ID_proj_out),
|
| 149 |
+
copy.deepcopy(modelMOE.ID_proj_out),
|
| 150 |
+
], [0,2,3])
|
| 151 |
+
modelMOE.landmark_proj_out = TaskSpecific_MoE([
|
| 152 |
+
copy.deepcopy(modelMOE.landmark_proj_out),
|
| 153 |
+
copy.deepcopy(modelMOE.landmark_proj_out),
|
| 154 |
+
copy.deepcopy(modelMOE.landmark_proj_out),
|
| 155 |
+
], [0,2,3])
|
| 156 |
+
modelMOE.proj_out_source__head = TaskSpecific_MoE([
|
| 157 |
+
copy.deepcopy(modelMOE.proj_out_source__head),
|
| 158 |
+
copy.deepcopy(modelMOE.proj_out_source__head),
|
| 159 |
+
], [2,3])
|
| 160 |
+
# Upcycle single refNet using three source refNets, and keep only one
|
| 161 |
+
if REFNET.ENABLE:
|
| 162 |
+
shared_ref = modelMOE.model.diffusion_model_refNet
|
| 163 |
+
src0 = shared_ref
|
| 164 |
+
src1 = copy.deepcopy(shared_ref)
|
| 165 |
+
src2 = copy.deepcopy(shared_ref)
|
| 166 |
+
src3 = copy.deepcopy(shared_ref)
|
| 167 |
+
replace_modules_lossless(shared_ref, [src0, src1, src2, src3],[0,1,2,3], parent_name=".model.diffusion_model_refNet", for_refnet=True)
|
| 168 |
+
# load from ./modelMOE.ckpt
|
| 169 |
+
time.sleep(20*rank_)
|
| 170 |
+
print(f"ckpt load over. m,u:")
|
| 171 |
+
# Initialize bank here (after model structure is finalized)
|
| 172 |
+
if REFNET.ENABLE :
|
| 173 |
+
modelMOE.model.bank = Bank(reader=modelMOE.model.diffusion_model,writer=modelMOE.model.diffusion_model_refNet)
|
| 174 |
+
if __name__=='__main__':
|
| 175 |
+
for key in sorted( get_representative_moduleNames(modelMOE.state_dict().keys()) ):
|
| 176 |
+
print(f" - {key}")
|
| 177 |
+
return modelMOE
|
| 178 |
+
|
ldm/lr_scheduler.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LambdaWarmUpCosineScheduler:
|
| 5 |
+
"""
|
| 6 |
+
note: use with a base_lr of 1.0
|
| 7 |
+
"""
|
| 8 |
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
| 9 |
+
self.lr_warm_up_steps = warm_up_steps
|
| 10 |
+
self.lr_start = lr_start
|
| 11 |
+
self.lr_min = lr_min
|
| 12 |
+
self.lr_max = lr_max
|
| 13 |
+
self.lr_max_decay_steps = max_decay_steps
|
| 14 |
+
self.last_lr = 0.
|
| 15 |
+
self.verbosity_interval = verbosity_interval
|
| 16 |
+
|
| 17 |
+
def schedule(self, n, **kwargs):
|
| 18 |
+
if self.verbosity_interval > 0:
|
| 19 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
| 20 |
+
if n < self.lr_warm_up_steps:
|
| 21 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
| 22 |
+
self.last_lr = lr
|
| 23 |
+
return lr
|
| 24 |
+
else:
|
| 25 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
| 26 |
+
t = min(t, 1.0)
|
| 27 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
| 28 |
+
1 + np.cos(t * np.pi))
|
| 29 |
+
self.last_lr = lr
|
| 30 |
+
return lr
|
| 31 |
+
|
| 32 |
+
def __call__(self, n, **kwargs):
|
| 33 |
+
return self.schedule(n,**kwargs)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LambdaWarmUpCosineScheduler2:
|
| 37 |
+
"""
|
| 38 |
+
supports repeated iterations, configurable via lists
|
| 39 |
+
note: use with a base_lr of 1.0.
|
| 40 |
+
"""
|
| 41 |
+
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
| 42 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
| 43 |
+
self.lr_warm_up_steps = warm_up_steps
|
| 44 |
+
self.f_start = f_start
|
| 45 |
+
self.f_min = f_min
|
| 46 |
+
self.f_max = f_max
|
| 47 |
+
self.cycle_lengths = cycle_lengths
|
| 48 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
| 49 |
+
self.last_f = 0.
|
| 50 |
+
self.verbosity_interval = verbosity_interval
|
| 51 |
+
|
| 52 |
+
def find_in_interval(self, n):
|
| 53 |
+
interval = 0
|
| 54 |
+
for cl in self.cum_cycles[1:]:
|
| 55 |
+
if n <= cl:
|
| 56 |
+
return interval
|
| 57 |
+
interval += 1
|
| 58 |
+
|
| 59 |
+
def schedule(self, n, **kwargs):
|
| 60 |
+
cycle = self.find_in_interval(n)
|
| 61 |
+
n = n - self.cum_cycles[cycle]
|
| 62 |
+
if self.verbosity_interval > 0:
|
| 63 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
| 64 |
+
f"current cycle {cycle}")
|
| 65 |
+
if n < self.lr_warm_up_steps[cycle]:
|
| 66 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
| 67 |
+
self.last_f = f
|
| 68 |
+
return f
|
| 69 |
+
else:
|
| 70 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
| 71 |
+
t = min(t, 1.0)
|
| 72 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
| 73 |
+
1 + np.cos(t * np.pi))
|
| 74 |
+
self.last_f = f
|
| 75 |
+
return f
|
| 76 |
+
|
| 77 |
+
def __call__(self, n, **kwargs):
|
| 78 |
+
return self.schedule(n, **kwargs)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
| 82 |
+
|
| 83 |
+
def schedule(self, n, **kwargs):# n is the step index
|
| 84 |
+
cycle = self.find_in_interval(n)
|
| 85 |
+
n = n - self.cum_cycles[cycle]
|
| 86 |
+
if self.verbosity_interval > 0:
|
| 87 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
| 88 |
+
f"current cycle {cycle}")
|
| 89 |
+
|
| 90 |
+
if n < self.lr_warm_up_steps[cycle]:
|
| 91 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
| 92 |
+
self.last_f = f
|
| 93 |
+
# print(f"0 {n=} {f=}")
|
| 94 |
+
return f
|
| 95 |
+
else:
|
| 96 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
| 97 |
+
self.last_f = f
|
| 98 |
+
# print(f"1 {n=} {f=}")
|
| 99 |
+
return f
|
ldm/models/autoencoder.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pytorch_lightning as pl
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from contextlib import contextmanager
|
| 5 |
+
|
| 6 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
| 7 |
+
|
| 8 |
+
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
| 9 |
+
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
| 10 |
+
|
| 11 |
+
from ldm.util import instantiate_from_config
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class VQModel(pl.LightningModule):
|
| 15 |
+
def __init__(self,
|
| 16 |
+
ddconfig,
|
| 17 |
+
lossconfig,
|
| 18 |
+
n_embed,
|
| 19 |
+
embed_dim,
|
| 20 |
+
ckpt_path=None,
|
| 21 |
+
ignore_keys=[],
|
| 22 |
+
image_key="image",
|
| 23 |
+
colorize_nlabels=None,
|
| 24 |
+
monitor=None,
|
| 25 |
+
batch_resize_range=None,
|
| 26 |
+
scheduler_config=None,
|
| 27 |
+
lr_g_factor=1.0,
|
| 28 |
+
remap=None,
|
| 29 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
| 30 |
+
use_ema=False
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.embed_dim = embed_dim
|
| 34 |
+
self.n_embed = n_embed
|
| 35 |
+
self.image_key = image_key
|
| 36 |
+
self.encoder = Encoder(**ddconfig)
|
| 37 |
+
self.decoder = Decoder(**ddconfig)
|
| 38 |
+
self.loss = instantiate_from_config(lossconfig)
|
| 39 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
| 40 |
+
remap=remap,
|
| 41 |
+
sane_index_shape=sane_index_shape)
|
| 42 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
| 43 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
| 44 |
+
if colorize_nlabels is not None:
|
| 45 |
+
assert type(colorize_nlabels)==int
|
| 46 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
| 47 |
+
if monitor is not None:
|
| 48 |
+
self.monitor = monitor
|
| 49 |
+
self.batch_resize_range = batch_resize_range
|
| 50 |
+
if self.batch_resize_range is not None:
|
| 51 |
+
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
| 52 |
+
|
| 53 |
+
self.use_ema = use_ema
|
| 54 |
+
if self.use_ema:
|
| 55 |
+
self.model_ema = LitEma(self)
|
| 56 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
| 57 |
+
|
| 58 |
+
if ckpt_path is not None:
|
| 59 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
| 60 |
+
self.scheduler_config = scheduler_config
|
| 61 |
+
self.lr_g_factor = lr_g_factor
|
| 62 |
+
|
| 63 |
+
@contextmanager
|
| 64 |
+
def ema_scope(self, context=None):
|
| 65 |
+
if self.use_ema:
|
| 66 |
+
self.model_ema.store(self.parameters())
|
| 67 |
+
self.model_ema.copy_to(self)
|
| 68 |
+
if context is not None:
|
| 69 |
+
print(f"{context}: Switched to EMA weights")
|
| 70 |
+
try:
|
| 71 |
+
yield None
|
| 72 |
+
finally:
|
| 73 |
+
if self.use_ema:
|
| 74 |
+
self.model_ema.restore(self.parameters())
|
| 75 |
+
if context is not None:
|
| 76 |
+
print(f"{context}: Restored training weights")
|
| 77 |
+
|
| 78 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
| 79 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
| 80 |
+
keys = list(sd.keys())
|
| 81 |
+
for k in keys:
|
| 82 |
+
for ik in ignore_keys:
|
| 83 |
+
if k.startswith(ik):
|
| 84 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 85 |
+
del sd[k]
|
| 86 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
| 87 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
| 88 |
+
if len(missing) > 0:
|
| 89 |
+
print(f"Missing Keys: {missing}")
|
| 90 |
+
print(f"Unexpected Keys: {unexpected}")
|
| 91 |
+
|
| 92 |
+
def on_train_batch_end(self, *args, **kwargs):
|
| 93 |
+
if self.use_ema:
|
| 94 |
+
self.model_ema(self)
|
| 95 |
+
|
| 96 |
+
def encode(self, x):
|
| 97 |
+
h = self.encoder(x)
|
| 98 |
+
h = self.quant_conv(h)
|
| 99 |
+
quant, emb_loss, info = self.quantize(h)
|
| 100 |
+
return quant, emb_loss, info
|
| 101 |
+
|
| 102 |
+
def encode_to_prequant(self, x):
|
| 103 |
+
h = self.encoder(x)
|
| 104 |
+
h = self.quant_conv(h)
|
| 105 |
+
return h
|
| 106 |
+
|
| 107 |
+
def decode(self, quant):
|
| 108 |
+
quant = self.post_quant_conv(quant)
|
| 109 |
+
dec = self.decoder(quant)
|
| 110 |
+
return dec
|
| 111 |
+
|
| 112 |
+
def decode_code(self, code_b):
|
| 113 |
+
quant_b = self.quantize.embed_code(code_b)
|
| 114 |
+
dec = self.decode(quant_b)
|
| 115 |
+
return dec
|
| 116 |
+
|
| 117 |
+
def forward(self, input, return_pred_indices=False):
|
| 118 |
+
quant, diff, (_,_,ind) = self.encode(input)
|
| 119 |
+
dec = self.decode(quant)
|
| 120 |
+
if return_pred_indices:
|
| 121 |
+
return dec, diff, ind
|
| 122 |
+
return dec, diff
|
| 123 |
+
|
| 124 |
+
def get_input(self, batch, k):
|
| 125 |
+
x = batch[k]
|
| 126 |
+
if len(x.shape) == 3:
|
| 127 |
+
x = x[..., None]
|
| 128 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
| 129 |
+
if self.batch_resize_range is not None:
|
| 130 |
+
lower_size = self.batch_resize_range[0]
|
| 131 |
+
upper_size = self.batch_resize_range[1]
|
| 132 |
+
if self.global_step <= 4:
|
| 133 |
+
# do the first few batches with max size to avoid later oom
|
| 134 |
+
new_resize = upper_size
|
| 135 |
+
else:
|
| 136 |
+
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
| 137 |
+
if new_resize != x.shape[2]:
|
| 138 |
+
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
| 139 |
+
x = x.detach()
|
| 140 |
+
return x
|
| 141 |
+
|
| 142 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
| 143 |
+
# https://github.com/pytorch/pytorch/issues/37142
|
| 144 |
+
# try not to fool the heuristics
|
| 145 |
+
x = self.get_input(batch, self.image_key)
|
| 146 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
| 147 |
+
|
| 148 |
+
if optimizer_idx == 0:
|
| 149 |
+
# autoencode
|
| 150 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
| 151 |
+
last_layer=self.get_last_layer(), split="train",
|
| 152 |
+
predicted_indices=ind)
|
| 153 |
+
|
| 154 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
| 155 |
+
return aeloss
|
| 156 |
+
|
| 157 |
+
if optimizer_idx == 1:
|
| 158 |
+
# discriminator
|
| 159 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
| 160 |
+
last_layer=self.get_last_layer(), split="train")
|
| 161 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
| 162 |
+
return discloss
|
| 163 |
+
|
| 164 |
+
def validation_step(self, batch, batch_idx):
|
| 165 |
+
log_dict = self._validation_step(batch, batch_idx)
|
| 166 |
+
with self.ema_scope():
|
| 167 |
+
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
| 168 |
+
return log_dict
|
| 169 |
+
|
| 170 |
+
def _validation_step(self, batch, batch_idx, suffix=""):
|
| 171 |
+
x = self.get_input(batch, self.image_key)
|
| 172 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
| 173 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
| 174 |
+
self.global_step,
|
| 175 |
+
last_layer=self.get_last_layer(),
|
| 176 |
+
split="val"+suffix,
|
| 177 |
+
predicted_indices=ind
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
| 181 |
+
self.global_step,
|
| 182 |
+
last_layer=self.get_last_layer(),
|
| 183 |
+
split="val"+suffix,
|
| 184 |
+
predicted_indices=ind
|
| 185 |
+
)
|
| 186 |
+
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
| 187 |
+
self.log(f"val{suffix}/rec_loss", rec_loss,
|
| 188 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
| 189 |
+
self.log(f"val{suffix}/aeloss", aeloss,
|
| 190 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
| 191 |
+
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
| 192 |
+
del log_dict_ae[f"val{suffix}/rec_loss"]
|
| 193 |
+
self.log_dict(log_dict_ae)
|
| 194 |
+
self.log_dict(log_dict_disc)
|
| 195 |
+
return self.log_dict
|
| 196 |
+
|
| 197 |
+
def configure_optimizers(self):
|
| 198 |
+
lr_d = self.learning_rate
|
| 199 |
+
lr_g = self.lr_g_factor*self.learning_rate
|
| 200 |
+
print("lr_d", lr_d)
|
| 201 |
+
print("lr_g", lr_g)
|
| 202 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
| 203 |
+
list(self.decoder.parameters())+
|
| 204 |
+
list(self.quantize.parameters())+
|
| 205 |
+
list(self.quant_conv.parameters())+
|
| 206 |
+
list(self.post_quant_conv.parameters()),
|
| 207 |
+
lr=lr_g, betas=(0.5, 0.9))
|
| 208 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
| 209 |
+
lr=lr_d, betas=(0.5, 0.9))
|
| 210 |
+
|
| 211 |
+
if self.scheduler_config is not None:
|
| 212 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
| 213 |
+
|
| 214 |
+
print("Setting up LambdaLR scheduler...")
|
| 215 |
+
scheduler = [
|
| 216 |
+
{
|
| 217 |
+
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
| 218 |
+
'interval': 'step',
|
| 219 |
+
'frequency': 1
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
| 223 |
+
'interval': 'step',
|
| 224 |
+
'frequency': 1
|
| 225 |
+
},
|
| 226 |
+
]
|
| 227 |
+
return [opt_ae, opt_disc], scheduler
|
| 228 |
+
return [opt_ae, opt_disc], []
|
| 229 |
+
|
| 230 |
+
def get_last_layer(self):
|
| 231 |
+
return self.decoder.conv_out.weight
|
| 232 |
+
|
| 233 |
+
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
| 234 |
+
log = dict()
|
| 235 |
+
x = self.get_input(batch, self.image_key)
|
| 236 |
+
x = x.to(self.device)
|
| 237 |
+
if only_inputs:
|
| 238 |
+
log["inputs"] = x
|
| 239 |
+
return log
|
| 240 |
+
xrec, _ = self(x)
|
| 241 |
+
if x.shape[1] > 3:
|
| 242 |
+
# colorize with random projection
|
| 243 |
+
assert xrec.shape[1] > 3
|
| 244 |
+
x = self.to_rgb(x)
|
| 245 |
+
xrec = self.to_rgb(xrec)
|
| 246 |
+
log["inputs"] = x
|
| 247 |
+
log["reconstructions"] = xrec
|
| 248 |
+
if plot_ema:
|
| 249 |
+
with self.ema_scope():
|
| 250 |
+
xrec_ema, _ = self(x)
|
| 251 |
+
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
| 252 |
+
log["reconstructions_ema"] = xrec_ema
|
| 253 |
+
return log
|
| 254 |
+
|
| 255 |
+
def to_rgb(self, x):
|
| 256 |
+
assert self.image_key == "segmentation"
|
| 257 |
+
if not hasattr(self, "colorize"):
|
| 258 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
| 259 |
+
x = F.conv2d(x, weight=self.colorize)
|
| 260 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
| 261 |
+
return x
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class VQModelInterface(VQModel):
|
| 265 |
+
def __init__(self, embed_dim, *args, **kwargs):
|
| 266 |
+
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
| 267 |
+
self.embed_dim = embed_dim
|
| 268 |
+
|
| 269 |
+
def encode(self, x):
|
| 270 |
+
h = self.encoder(x)
|
| 271 |
+
h = self.quant_conv(h)
|
| 272 |
+
return h
|
| 273 |
+
|
| 274 |
+
def decode(self, h, force_not_quantize=False):
|
| 275 |
+
# also go through quantization layer
|
| 276 |
+
if not force_not_quantize:
|
| 277 |
+
quant, emb_loss, info = self.quantize(h)
|
| 278 |
+
else:
|
| 279 |
+
quant = h
|
| 280 |
+
quant = self.post_quant_conv(quant)
|
| 281 |
+
dec = self.decoder(quant)
|
| 282 |
+
return dec
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class AutoencoderKL(pl.LightningModule):
|
| 286 |
+
def __init__(self,
|
| 287 |
+
ddconfig,
|
| 288 |
+
lossconfig,
|
| 289 |
+
embed_dim,
|
| 290 |
+
ckpt_path=None,
|
| 291 |
+
ignore_keys=[],
|
| 292 |
+
image_key="image",
|
| 293 |
+
colorize_nlabels=None,
|
| 294 |
+
monitor=None,
|
| 295 |
+
):
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.image_key = image_key
|
| 298 |
+
self.encoder = Encoder(**ddconfig)
|
| 299 |
+
self.decoder = Decoder(**ddconfig)
|
| 300 |
+
self.loss = instantiate_from_config(lossconfig)
|
| 301 |
+
assert ddconfig["double_z"]
|
| 302 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
| 303 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
| 304 |
+
self.embed_dim = embed_dim
|
| 305 |
+
if colorize_nlabels is not None:
|
| 306 |
+
assert type(colorize_nlabels)==int
|
| 307 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
| 308 |
+
if monitor is not None:
|
| 309 |
+
self.monitor = monitor
|
| 310 |
+
if ckpt_path is not None:
|
| 311 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
| 312 |
+
|
| 313 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
| 314 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
| 315 |
+
keys = list(sd.keys())
|
| 316 |
+
for k in keys:
|
| 317 |
+
for ik in ignore_keys:
|
| 318 |
+
if k.startswith(ik):
|
| 319 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 320 |
+
del sd[k]
|
| 321 |
+
self.load_state_dict(sd, strict=False)
|
| 322 |
+
print(f"Restored from {path}")
|
| 323 |
+
|
| 324 |
+
def encode(self, x):
|
| 325 |
+
h = self.encoder(x)
|
| 326 |
+
moments = self.quant_conv(h)
|
| 327 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 328 |
+
return posterior
|
| 329 |
+
|
| 330 |
+
def decode(self, z):
|
| 331 |
+
z = self.post_quant_conv(z)
|
| 332 |
+
dec = self.decoder(z)
|
| 333 |
+
return dec
|
| 334 |
+
|
| 335 |
+
def forward(self, input, sample_posterior=True):
|
| 336 |
+
posterior = self.encode(input)
|
| 337 |
+
if sample_posterior:
|
| 338 |
+
z = posterior.sample()
|
| 339 |
+
else:
|
| 340 |
+
z = posterior.mode()
|
| 341 |
+
dec = self.decode(z)
|
| 342 |
+
return dec, posterior
|
| 343 |
+
|
| 344 |
+
def get_input(self, batch, k):
|
| 345 |
+
x = batch[k]
|
| 346 |
+
if len(x.shape) == 3:
|
| 347 |
+
x = x[..., None]
|
| 348 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
| 349 |
+
return x
|
| 350 |
+
|
| 351 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
| 352 |
+
inputs = self.get_input(batch, self.image_key)
|
| 353 |
+
reconstructions, posterior = self(inputs)
|
| 354 |
+
|
| 355 |
+
if optimizer_idx == 0:
|
| 356 |
+
# train encoder+decoder+logvar
|
| 357 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
| 358 |
+
last_layer=self.get_last_layer(), split="train")
|
| 359 |
+
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
| 360 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
| 361 |
+
return aeloss
|
| 362 |
+
|
| 363 |
+
if optimizer_idx == 1:
|
| 364 |
+
# train the discriminator
|
| 365 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
| 366 |
+
last_layer=self.get_last_layer(), split="train")
|
| 367 |
+
|
| 368 |
+
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
| 369 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
| 370 |
+
return discloss
|
| 371 |
+
|
| 372 |
+
def validation_step(self, batch, batch_idx):
|
| 373 |
+
inputs = self.get_input(batch, self.image_key)
|
| 374 |
+
reconstructions, posterior = self(inputs)
|
| 375 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
| 376 |
+
last_layer=self.get_last_layer(), split="val")
|
| 377 |
+
|
| 378 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
| 379 |
+
last_layer=self.get_last_layer(), split="val")
|
| 380 |
+
|
| 381 |
+
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
| 382 |
+
self.log_dict(log_dict_ae)
|
| 383 |
+
self.log_dict(log_dict_disc)
|
| 384 |
+
return self.log_dict
|
| 385 |
+
|
| 386 |
+
def configure_optimizers(self):
|
| 387 |
+
lr = self.learning_rate
|
| 388 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
| 389 |
+
list(self.decoder.parameters())+
|
| 390 |
+
list(self.quant_conv.parameters())+
|
| 391 |
+
list(self.post_quant_conv.parameters()),
|
| 392 |
+
lr=lr, betas=(0.5, 0.9))
|
| 393 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
| 394 |
+
lr=lr, betas=(0.5, 0.9))
|
| 395 |
+
return [opt_ae, opt_disc], []
|
| 396 |
+
|
| 397 |
+
def get_last_layer(self):
|
| 398 |
+
return self.decoder.conv_out.weight
|
| 399 |
+
|
| 400 |
+
@torch.no_grad()
|
| 401 |
+
def log_images(self, batch, only_inputs=False, **kwargs):
|
| 402 |
+
log = dict()
|
| 403 |
+
x = self.get_input(batch, self.image_key)
|
| 404 |
+
x = x.to(self.device)
|
| 405 |
+
if not only_inputs:
|
| 406 |
+
xrec, posterior = self(x)
|
| 407 |
+
if x.shape[1] > 3:
|
| 408 |
+
# colorize with random projection
|
| 409 |
+
assert xrec.shape[1] > 3
|
| 410 |
+
x = self.to_rgb(x)
|
| 411 |
+
xrec = self.to_rgb(xrec)
|
| 412 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
| 413 |
+
log["reconstructions"] = xrec
|
| 414 |
+
log["inputs"] = x
|
| 415 |
+
return log
|
| 416 |
+
|
| 417 |
+
def to_rgb(self, x):
|
| 418 |
+
assert self.image_key == "segmentation"
|
| 419 |
+
if not hasattr(self, "colorize"):
|
| 420 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
| 421 |
+
x = F.conv2d(x, weight=self.colorize)
|
| 422 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
| 423 |
+
return x
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class IdentityFirstStage(torch.nn.Module):
|
| 427 |
+
def __init__(self, *args, vq_interface=False, **kwargs):
|
| 428 |
+
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
| 429 |
+
super().__init__()
|
| 430 |
+
|
| 431 |
+
def encode(self, x, *args, **kwargs):
|
| 432 |
+
return x
|
| 433 |
+
|
| 434 |
+
def decode(self, x, *args, **kwargs):
|
| 435 |
+
return x
|
| 436 |
+
|
| 437 |
+
def quantize(self, x, *args, **kwargs):
|
| 438 |
+
if self.vq_interface:
|
| 439 |
+
return x, None, [None, None, None]
|
| 440 |
+
return x
|
| 441 |
+
|
| 442 |
+
def forward(self, x, *args, **kwargs):
|
| 443 |
+
return x
|
ldm/models/diffusion/__init__.py
ADDED
|
File without changes
|
ldm/models/diffusion/bank.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .misc_4ddpm import *
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from ldm.modules.attention import BasicTransformerBlock
|
| 5 |
+
class Bank:
|
| 6 |
+
def __init__(self,reader:nn.Module, writer:nn.Module) -> None:
|
| 7 |
+
"""
|
| 8 |
+
For the DFS model, mark every BasicTransformerBlock with name_4bank and isReader_4bank flags.
|
| 9 |
+
Similar logic applies for the writer while checking for BasicTransformerBlock instances.
|
| 10 |
+
"""
|
| 11 |
+
self.name2data = {}
|
| 12 |
+
self.name2count = {} # track how many times each name has been retrieved
|
| 13 |
+
self.WHEN_clear_a_field = 2 # clear the entry after this many gets
|
| 14 |
+
skip_names = [
|
| 15 |
+
'input_blocks.1.1.transformer_blocks.0',
|
| 16 |
+
'input_blocks.2.1.transformer_blocks.0',
|
| 17 |
+
# 'input_blocks.4.1.transformer_blocks.0',
|
| 18 |
+
# 'input_blocks.5.1.transformer_blocks.0',
|
| 19 |
+
# 'input_blocks.7.1.transformer_blocks.0',
|
| 20 |
+
# 'input_blocks.8.1.transformer_blocks.0',
|
| 21 |
+
##-----------all middle and output_blocks (everything outside input_blocks)----
|
| 22 |
+
'middle_block.1.transformer_blocks.0',
|
| 23 |
+
'output_blocks.3.1.transformer_blocks.0',
|
| 24 |
+
'output_blocks.4.1.transformer_blocks.0',
|
| 25 |
+
'output_blocks.5.1.transformer_blocks.0',
|
| 26 |
+
'output_blocks.6.1.transformer_blocks.0',
|
| 27 |
+
'output_blocks.7.1.transformer_blocks.0',
|
| 28 |
+
'output_blocks.8.1.transformer_blocks.0',
|
| 29 |
+
'output_blocks.9.1.transformer_blocks.0',
|
| 30 |
+
'output_blocks.10.1.transformer_blocks.0',
|
| 31 |
+
'output_blocks.11.1.transformer_blocks.0',
|
| 32 |
+
]
|
| 33 |
+
# print(f"{skip_names=}")
|
| 34 |
+
|
| 35 |
+
l_name = []
|
| 36 |
+
for name, _module in writer.named_modules():
|
| 37 |
+
if isinstance(_module, BasicTransformerBlock):
|
| 38 |
+
if DEBUG:
|
| 39 |
+
print(f"{name=}")
|
| 40 |
+
if name in skip_names:
|
| 41 |
+
# print(f"skip {name=}")
|
| 42 |
+
continue
|
| 43 |
+
_module.bank = self
|
| 44 |
+
_module.name4bank = name
|
| 45 |
+
_module.isReader_4bank = False
|
| 46 |
+
l_name.append(name)
|
| 47 |
+
# print(f"{l_name=}")
|
| 48 |
+
|
| 49 |
+
for name, _module in reader.named_modules():
|
| 50 |
+
if isinstance(_module, BasicTransformerBlock):
|
| 51 |
+
if name not in l_name:
|
| 52 |
+
continue
|
| 53 |
+
_module.bank = self
|
| 54 |
+
_module.name4bank = name
|
| 55 |
+
_module.isReader_4bank = True
|
| 56 |
+
def set(self,name,data):
|
| 57 |
+
self.name2data[name] = data
|
| 58 |
+
# self.name2count[name] = 0
|
| 59 |
+
def get(self,name):
|
| 60 |
+
printC('bank get', name)
|
| 61 |
+
if name in self.name2data:
|
| 62 |
+
if name not in self.name2count:
|
| 63 |
+
self.name2count[name] = 0
|
| 64 |
+
self.name2count[name] += 1
|
| 65 |
+
data = self.name2data[name]
|
| 66 |
+
if self.name2count[name] >= self.WHEN_clear_a_field: # once the max get count is reached, remove the entry
|
| 67 |
+
del self.name2data[name]
|
| 68 |
+
del self.name2count[name]
|
| 69 |
+
return data
|
| 70 |
+
raise Exception(f"{name}\n{list(self.name2data.keys())}")
|
| 71 |
+
return None
|
| 72 |
+
def clear(self,):
|
| 73 |
+
printC('clear')
|
| 74 |
+
printC('mean ct:', sum( self.name2count.values() ) / len( self.name2count.values() ) if len( self.name2count.values() )>0 else 'null' )
|
| 75 |
+
self.name2data.clear()
|
| 76 |
+
self.name2count.clear()
|
ldm/models/diffusion/classifier.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from glob import glob
|
| 11 |
+
from natsort import natsorted
|
| 12 |
+
|
| 13 |
+
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
|
| 14 |
+
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
|
| 15 |
+
|
| 16 |
+
__models__ = {
|
| 17 |
+
'class_label': EncoderUNetModel,
|
| 18 |
+
'segmentation': UNetModel
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def disabled_train(self, mode=True):
|
| 23 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
| 24 |
+
does not change anymore."""
|
| 25 |
+
return self
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class NoisyLatentImageClassifier(pl.LightningModule):
|
| 29 |
+
|
| 30 |
+
def __init__(self,
|
| 31 |
+
diffusion_path,
|
| 32 |
+
num_classes,
|
| 33 |
+
ckpt_path=None,
|
| 34 |
+
pool='attention',
|
| 35 |
+
label_key=None,
|
| 36 |
+
diffusion_ckpt_path=None,
|
| 37 |
+
scheduler_config=None,
|
| 38 |
+
weight_decay=1.e-2,
|
| 39 |
+
log_steps=10,
|
| 40 |
+
monitor='val/loss',
|
| 41 |
+
*args,
|
| 42 |
+
**kwargs):
|
| 43 |
+
super().__init__(*args, **kwargs)
|
| 44 |
+
self.num_classes = num_classes
|
| 45 |
+
# get latest config of diffusion model
|
| 46 |
+
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
|
| 47 |
+
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
| 48 |
+
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
| 49 |
+
self.load_diffusion()
|
| 50 |
+
|
| 51 |
+
self.monitor = monitor
|
| 52 |
+
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
| 53 |
+
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
| 54 |
+
self.log_steps = log_steps
|
| 55 |
+
|
| 56 |
+
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
|
| 57 |
+
else self.diffusion_model.cond_stage_key
|
| 58 |
+
|
| 59 |
+
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
|
| 60 |
+
|
| 61 |
+
if self.label_key not in __models__:
|
| 62 |
+
raise NotImplementedError()
|
| 63 |
+
|
| 64 |
+
self.load_classifier(ckpt_path, pool)
|
| 65 |
+
|
| 66 |
+
self.scheduler_config = scheduler_config
|
| 67 |
+
self.use_scheduler = self.scheduler_config is not None
|
| 68 |
+
self.weight_decay = weight_decay
|
| 69 |
+
|
| 70 |
+
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
| 71 |
+
sd = torch.load(path, map_location="cpu")
|
| 72 |
+
if "state_dict" in list(sd.keys()):
|
| 73 |
+
sd = sd["state_dict"]
|
| 74 |
+
keys = list(sd.keys())
|
| 75 |
+
for k in keys:
|
| 76 |
+
for ik in ignore_keys:
|
| 77 |
+
if k.startswith(ik):
|
| 78 |
+
print("Deleting key {} from state_dict.".format(k))
|
| 79 |
+
del sd[k]
|
| 80 |
+
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
| 81 |
+
sd, strict=False)
|
| 82 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
| 83 |
+
if len(missing) > 0:
|
| 84 |
+
print(f"Missing Keys: {missing}")
|
| 85 |
+
if len(unexpected) > 0:
|
| 86 |
+
print(f"Unexpected Keys: {unexpected}")
|
| 87 |
+
|
| 88 |
+
def load_diffusion(self):
|
| 89 |
+
model = instantiate_from_config(self.diffusion_config)
|
| 90 |
+
self.diffusion_model = model.eval()
|
| 91 |
+
self.diffusion_model.train = disabled_train
|
| 92 |
+
for param in self.diffusion_model.parameters():
|
| 93 |
+
param.requires_grad = False
|
| 94 |
+
|
| 95 |
+
def load_classifier(self, ckpt_path, pool):
|
| 96 |
+
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
| 97 |
+
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
|
| 98 |
+
model_config.out_channels = self.num_classes
|
| 99 |
+
if self.label_key == 'class_label':
|
| 100 |
+
model_config.pool = pool
|
| 101 |
+
|
| 102 |
+
self.model = __models__[self.label_key](**model_config)
|
| 103 |
+
if ckpt_path is not None:
|
| 104 |
+
print('#####################################################################')
|
| 105 |
+
print(f'load from ckpt "{ckpt_path}"')
|
| 106 |
+
print('#####################################################################')
|
| 107 |
+
self.init_from_ckpt(ckpt_path)
|
| 108 |
+
|
| 109 |
+
@torch.no_grad()
|
| 110 |
+
def get_x_noisy(self, x, t, noise=None):
|
| 111 |
+
noise = default(noise, lambda: torch.randn_like(x))
|
| 112 |
+
continuous_sqrt_alpha_cumprod = None
|
| 113 |
+
if self.diffusion_model.use_continuous_noise:
|
| 114 |
+
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
| 115 |
+
# todo: make sure t+1 is correct here
|
| 116 |
+
|
| 117 |
+
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
|
| 118 |
+
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
|
| 119 |
+
|
| 120 |
+
def forward(self, x_noisy, t, *args, **kwargs):
|
| 121 |
+
return self.model(x_noisy, t)
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
def get_input(self, batch, k):
|
| 125 |
+
x = batch[k]
|
| 126 |
+
if len(x.shape) == 3:
|
| 127 |
+
x = x[..., None]
|
| 128 |
+
x = rearrange(x, 'b h w c -> b c h w')
|
| 129 |
+
x = x.to(memory_format=torch.contiguous_format).float()
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
@torch.no_grad()
|
| 133 |
+
def get_conditioning(self, batch, k=None):
|
| 134 |
+
if k is None:
|
| 135 |
+
k = self.label_key
|
| 136 |
+
assert k is not None, 'Needs to provide label key'
|
| 137 |
+
|
| 138 |
+
targets = batch[k].to(self.device)
|
| 139 |
+
|
| 140 |
+
if self.label_key == 'segmentation':
|
| 141 |
+
targets = rearrange(targets, 'b h w c -> b c h w')
|
| 142 |
+
for down in range(self.numd):
|
| 143 |
+
h, w = targets.shape[-2:]
|
| 144 |
+
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
|
| 145 |
+
|
| 146 |
+
# targets = rearrange(targets,'b c h w -> b h w c')
|
| 147 |
+
|
| 148 |
+
return targets
|
| 149 |
+
|
| 150 |
+
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
| 151 |
+
_, top_ks = torch.topk(logits, k, dim=1)
|
| 152 |
+
if reduction == "mean":
|
| 153 |
+
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
| 154 |
+
elif reduction == "none":
|
| 155 |
+
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
| 156 |
+
|
| 157 |
+
def on_train_epoch_start(self):
|
| 158 |
+
# save some memory
|
| 159 |
+
self.diffusion_model.model.to('cpu')
|
| 160 |
+
|
| 161 |
+
@torch.no_grad()
|
| 162 |
+
def write_logs(self, loss, logits, targets):
|
| 163 |
+
log_prefix = 'train' if self.training else 'val'
|
| 164 |
+
log = {}
|
| 165 |
+
log[f"{log_prefix}/loss"] = loss.mean()
|
| 166 |
+
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
| 167 |
+
logits, targets, k=1, reduction="mean"
|
| 168 |
+
)
|
| 169 |
+
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
| 170 |
+
logits, targets, k=5, reduction="mean"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
|
| 174 |
+
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
| 175 |
+
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
|
| 176 |
+
lr = self.optimizers().param_groups[0]['lr']
|
| 177 |
+
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
| 178 |
+
|
| 179 |
+
def shared_step(self, batch, t=None):
|
| 180 |
+
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
|
| 181 |
+
targets = self.get_conditioning(batch)
|
| 182 |
+
if targets.dim() == 4:
|
| 183 |
+
targets = targets.argmax(dim=1)
|
| 184 |
+
if t is None:
|
| 185 |
+
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
|
| 186 |
+
else:
|
| 187 |
+
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
| 188 |
+
x_noisy = self.get_x_noisy(x, t)
|
| 189 |
+
logits = self(x_noisy, t)
|
| 190 |
+
|
| 191 |
+
loss = F.cross_entropy(logits, targets, reduction='none')
|
| 192 |
+
|
| 193 |
+
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
| 194 |
+
|
| 195 |
+
loss = loss.mean()
|
| 196 |
+
return loss, logits, x_noisy, targets
|
| 197 |
+
|
| 198 |
+
def training_step(self, batch, batch_idx):
|
| 199 |
+
loss, *_ = self.shared_step(batch)
|
| 200 |
+
return loss
|
| 201 |
+
|
| 202 |
+
def reset_noise_accs(self):
|
| 203 |
+
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
|
| 204 |
+
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
|
| 205 |
+
|
| 206 |
+
def on_validation_start(self):
|
| 207 |
+
self.reset_noise_accs()
|
| 208 |
+
|
| 209 |
+
@torch.no_grad()
|
| 210 |
+
def validation_step(self, batch, batch_idx):
|
| 211 |
+
loss, *_ = self.shared_step(batch)
|
| 212 |
+
|
| 213 |
+
for t in self.noisy_acc:
|
| 214 |
+
_, logits, _, targets = self.shared_step(batch, t)
|
| 215 |
+
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
|
| 216 |
+
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
|
| 217 |
+
|
| 218 |
+
return loss
|
| 219 |
+
|
| 220 |
+
def configure_optimizers(self):
|
| 221 |
+
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
| 222 |
+
|
| 223 |
+
if self.use_scheduler:
|
| 224 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
| 225 |
+
|
| 226 |
+
print("Setting up LambdaLR scheduler...")
|
| 227 |
+
scheduler = [
|
| 228 |
+
{
|
| 229 |
+
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
| 230 |
+
'interval': 'step',
|
| 231 |
+
'frequency': 1
|
| 232 |
+
}]
|
| 233 |
+
return [optimizer], scheduler
|
| 234 |
+
|
| 235 |
+
return optimizer
|
| 236 |
+
|
| 237 |
+
@torch.no_grad()
|
| 238 |
+
def log_images(self, batch, N=8, *args, **kwargs):
|
| 239 |
+
log = dict()
|
| 240 |
+
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
| 241 |
+
log['inputs'] = x
|
| 242 |
+
|
| 243 |
+
y = self.get_conditioning(batch)
|
| 244 |
+
|
| 245 |
+
if self.label_key == 'class_label':
|
| 246 |
+
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
| 247 |
+
log['labels'] = y
|
| 248 |
+
|
| 249 |
+
if ismap(y):
|
| 250 |
+
log['labels'] = self.diffusion_model.to_rgb(y)
|
| 251 |
+
|
| 252 |
+
for step in range(self.log_steps):
|
| 253 |
+
current_time = step * self.log_time_interval
|
| 254 |
+
|
| 255 |
+
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
| 256 |
+
|
| 257 |
+
log[f'inputs@t{current_time}'] = x_noisy
|
| 258 |
+
|
| 259 |
+
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
| 260 |
+
pred = rearrange(pred, 'b h w c -> b c h w')
|
| 261 |
+
|
| 262 |
+
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
|
| 263 |
+
|
| 264 |
+
for key in log:
|
| 265 |
+
log[key] = log[key][:N]
|
| 266 |
+
|
| 267 |
+
return log
|