NGain commited on
Commit
a0967dd
·
verified ·
1 Parent(s): 199b0ea

Delete utils

Browse files
utils/__pycache__/devices.cpython-310.pyc DELETED
Binary file (4.14 kB)
 
utils/__pycache__/devices.cpython-37.pyc DELETED
Binary file (4.09 kB)
 
utils/__pycache__/devices.cpython-38.pyc DELETED
Binary file (4.09 kB)
 
utils/__pycache__/img_util.cpython-310.pyc DELETED
Binary file (1.26 kB)
 
utils/__pycache__/img_util.cpython-37.pyc DELETED
Binary file (1.25 kB)
 
utils/__pycache__/misc.cpython-310.pyc DELETED
Binary file (2.01 kB)
 
utils/__pycache__/misc.cpython-37.pyc DELETED
Binary file (1.95 kB)
 
utils/__pycache__/misc.cpython-38.pyc DELETED
Binary file (1.94 kB)
 
utils/__pycache__/vaehook.cpython-310.pyc DELETED
Binary file (19.5 kB)
 
utils/__pycache__/vaehook.cpython-37.pyc DELETED
Binary file (18.9 kB)
 
utils/__pycache__/vaehook.cpython-38.pyc DELETED
Binary file (18.8 kB)
 
utils/__pycache__/wavelet_color_fix.cpython-310.pyc DELETED
Binary file (3.72 kB)
 
utils/__pycache__/wavelet_color_fix.cpython-38.pyc DELETED
Binary file (3.79 kB)
 
utils/devices.py DELETED
@@ -1,138 +0,0 @@
1
- import sys
2
- import contextlib
3
- from functools import lru_cache
4
-
5
- import torch
6
- #from modules import errors
7
-
8
- if sys.platform == "darwin":
9
- from modules import mac_specific
10
-
11
-
12
- def has_mps() -> bool:
13
- if sys.platform != "darwin":
14
- return False
15
- else:
16
- return mac_specific.has_mps
17
-
18
-
19
- def get_cuda_device_string():
20
- return "cuda"
21
-
22
-
23
- def get_optimal_device_name():
24
- if torch.cuda.is_available():
25
- return get_cuda_device_string()
26
-
27
- if has_mps():
28
- return "mps"
29
-
30
- return "cpu"
31
-
32
-
33
- def get_optimal_device():
34
- return torch.device(get_optimal_device_name())
35
-
36
-
37
- def get_device_for(task):
38
- return get_optimal_device()
39
-
40
-
41
- def torch_gc():
42
-
43
- if torch.cuda.is_available():
44
- with torch.cuda.device(get_cuda_device_string()):
45
- torch.cuda.empty_cache()
46
- torch.cuda.ipc_collect()
47
-
48
- if has_mps():
49
- mac_specific.torch_mps_gc()
50
-
51
-
52
- def enable_tf32():
53
- if torch.cuda.is_available():
54
-
55
- # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
56
- # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
57
- if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
58
- torch.backends.cudnn.benchmark = True
59
-
60
- torch.backends.cuda.matmul.allow_tf32 = True
61
- torch.backends.cudnn.allow_tf32 = True
62
-
63
-
64
- enable_tf32()
65
- #errors.run(enable_tf32, "Enabling TF32")
66
-
67
- cpu = torch.device("cpu")
68
- device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
69
- dtype = torch.float16
70
- dtype_vae = torch.float16
71
- dtype_unet = torch.float16
72
- unet_needs_upcast = False
73
-
74
-
75
- def cond_cast_unet(input):
76
- return input.to(dtype_unet) if unet_needs_upcast else input
77
-
78
-
79
- def cond_cast_float(input):
80
- return input.float() if unet_needs_upcast else input
81
-
82
-
83
- def randn(seed, shape):
84
- torch.manual_seed(seed)
85
- return torch.randn(shape, device=device)
86
-
87
-
88
- def randn_without_seed(shape):
89
- return torch.randn(shape, device=device)
90
-
91
-
92
- def autocast(disable=False):
93
- if disable:
94
- return contextlib.nullcontext()
95
-
96
- return torch.autocast("cuda")
97
-
98
-
99
- def without_autocast(disable=False):
100
- return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
101
-
102
-
103
- class NansException(Exception):
104
- pass
105
-
106
-
107
- def test_for_nans(x, where):
108
- if not torch.all(torch.isnan(x)).item():
109
- return
110
-
111
- if where == "unet":
112
- message = "A tensor with all NaNs was produced in Unet."
113
-
114
- elif where == "vae":
115
- message = "A tensor with all NaNs was produced in VAE."
116
-
117
- else:
118
- message = "A tensor with all NaNs was produced."
119
-
120
- message += " Use --disable-nan-check commandline argument to disable this check."
121
-
122
- raise NansException(message)
123
-
124
-
125
- @lru_cache
126
- def first_time_calculation():
127
- """
128
- just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
129
- spends about 2.7 seconds doing that, at least wih NVidia.
130
- """
131
-
132
- x = torch.zeros((1, 1)).to(device, dtype)
133
- linear = torch.nn.Linear(1, 1).to(device, dtype)
134
- linear(x)
135
-
136
- x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
137
- conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
138
- conv2d(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/img_util.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- import PIL
3
- import cv2
4
- import math
5
- import numpy as np
6
- import torch
7
- import torchvision
8
- import imageio
9
-
10
- from einops import rearrange
11
-
12
- def save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0):
13
- videos = rearrange(videos, "b c t h w -> t b c h w").cpu()
14
- outputs = []
15
- for x in videos:
16
- x = torchvision.utils.make_grid(x, nrow=n_rows)
17
- x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
18
- if rescale:
19
- x = (x / 2.0 + 0.5).clamp(0, 1) # -1,1 -> 0,1
20
- x = (x * 255).numpy().astype(np.uint8)
21
- #x = adjust_gamma(x, 0.5)
22
- outputs.append(x)
23
-
24
- outputs = outputs[discardN:]
25
-
26
- if path is not None:
27
- #os.makedirs(os.path.dirname(path), exist_ok=True)
28
- imageio.mimsave(path, outputs, duration=1000/fps, loop=0)
29
-
30
- return outputs
31
-
32
- def convert_image_to_fn(img_type, minsize, image, eps=0.02):
33
- width, height = image.size
34
- if min(width, height) < minsize:
35
- scale = minsize/min(width, height) + eps
36
- image = image.resize((math.ceil(width*scale), math.ceil(height*scale)))
37
-
38
- if image.mode != img_type:
39
- return image.convert(img_type)
40
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/metrics.py DELETED
@@ -1,65 +0,0 @@
1
- import os
2
- import pyiqa
3
- import argparse
4
- from tqdm import tqdm
5
-
6
- def test_image_quality(image_dir, metrics, weight_paths):
7
- """
8
- 测试指定文件夹中所有 PNG 图像的质量指标。
9
-
10
- Args:
11
- image_dir (str): 包含 PNG 图像的文件夹路径。
12
- metrics (list): 需要测试的指标列表,例如 ['musiq', 'maniqa', 'clipiqa'].
13
- weight_paths (dict): 每个指标的本地权重文件路径。
14
- """
15
- # 初始化指标模型
16
- metric_models = {}
17
- for metric in metrics:
18
- if metric in weight_paths:
19
- # 如果提供了本地权重路径,则加载本地权重
20
- model = pyiqa.create_metric(metric, pretrained_model_path=weight_paths[metric])
21
- else:
22
- # 否则使用默认权重(需要网络下载)
23
- model = pyiqa.create_metric(metric)
24
- metric_models[metric] = model
25
-
26
- # 获取所有 PNG 图像路径
27
- image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png')]
28
- if not image_paths:
29
- print(f"未找到 PNG 图像:{image_dir}")
30
- return
31
- # image_paths = sorted(image_paths)[:28]
32
- print(image_paths)
33
-
34
- # 遍历图像并计算指标
35
- results = {metric: [] for metric in metrics}
36
- for image_path in tqdm(image_paths, desc="Processing images"):
37
- for metric, model in metric_models.items():
38
- score = model(image_path) # 计算指标分数
39
- results[metric].append(score.item()) # 将分数添加到结果中
40
-
41
- # 打印结果
42
- print("\n测试结果:")
43
- for metric, scores in results.items():
44
- avg_score = sum(scores) / len(scores)
45
- # print(f"{metric.upper()} - 平均分数: {avg_score:.4f}")
46
- print(avg_score)
47
- # print(f"{metric.upper()} - 单张图像分数: {scores}")
48
-
49
- if __name__ == "__main__":
50
- # 解析命令行参数
51
- parser = argparse.ArgumentParser(description="测试图像质量指标")
52
- parser.add_argument("--image_dir", type=str, required=True, help="包含 PNG 图像的文件夹路径")
53
- args = parser.parse_args()
54
-
55
- # 需要测试的指标
56
- metrics_to_test = ['musiq', 'maniqa', 'clipiqa']
57
-
58
- # 每个指标的本地权重文件路径
59
- weight_paths = {
60
- 'musiq': '/media/ssd8T/wyw/Pretrained/musiq/musiq_koniq_ckpt-e95806b9.pth',
61
- 'maniqa': '/media/ssd8T/wyw/Pretrained/clipiqa/ckpt_koniq10k.pt',
62
- }
63
-
64
- # 运行测试
65
- test_image_quality(args.image_dir, metrics_to_test, weight_paths)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/metrics_off.py DELETED
@@ -1,313 +0,0 @@
1
- import torch
2
- import os
3
- import pyiqa
4
- import cv2
5
- import numpy as np
6
- from PIL import Image
7
-
8
-
9
-
10
- def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
11
- """Calculate PSNR (Peak Signal-to-Noise Ratio).
12
-
13
- Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
14
-
15
- Args:
16
- img1 (ndarray): Images with range [0, 255].
17
- img2 (ndarray): Images with range [0, 255].
18
- crop_border (int): Cropped pixels in each edge of an image. These
19
- pixels are not involved in the PSNR calculation.
20
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
21
- Default: 'HWC'.
22
- test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
23
-
24
- Returns:
25
- float: psnr result.
26
- """
27
-
28
- assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
29
- if input_order not in ['HWC', 'CHW']:
30
- raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
31
- img1 = reorder_image(img1, input_order=input_order)
32
- img2 = reorder_image(img2, input_order=input_order)
33
- img1 = img1.astype(np.float64)
34
- img2 = img2.astype(np.float64)
35
-
36
- if crop_border != 0:
37
- img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
38
- img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
39
-
40
- if test_y_channel:
41
- img1 = to_y_channel(img1)
42
- img2 = to_y_channel(img2)
43
-
44
- mse = np.mean((img1 - img2) ** 2)
45
- if mse == 0:
46
- return float('inf')
47
- return 20. * np.log10(255. / np.sqrt(mse))
48
-
49
-
50
- def _ssim(img1, img2):
51
- """Calculate SSIM (structural similarity) for one channel images.
52
-
53
- It is called by func:`calculate_ssim`.
54
-
55
- Args:
56
- img1 (ndarray): Images with range [0, 255] with order 'HWC'.
57
- img2 (ndarray): Images with range [0, 255] with order 'HWC'.
58
-
59
- Returns:
60
- float: ssim result.
61
- """
62
-
63
- C1 = (0.01 * 255) ** 2
64
- C2 = (0.03 * 255) ** 2
65
-
66
- img1 = img1.astype(np.float64)
67
- img2 = img2.astype(np.float64)
68
- kernel = cv2.getGaussianKernel(11, 1.5)
69
- window = np.outer(kernel, kernel.transpose())
70
-
71
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
72
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
73
- mu1_sq = mu1 ** 2
74
- mu2_sq = mu2 ** 2
75
- mu1_mu2 = mu1 * mu2
76
- sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
77
- sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
78
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
79
-
80
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
81
- return ssim_map.mean()
82
-
83
-
84
- def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
85
- """Calculate SSIM (structural similarity).
86
-
87
- Ref:
88
- Image quality assessment: From error visibility to structural similarity
89
-
90
- The results are the same as that of the official released MATLAB code in
91
- https://ece.uwaterloo.ca/~z70wang/research/ssim/.
92
-
93
- For three-channel images, SSIM is calculated for each channel and then
94
- averaged.
95
-
96
- Args:
97
- img1 (ndarray): Images with range [0, 255].
98
- img2 (ndarray): Images with range [0, 255].
99
- crop_border (int): Cropped pixels in each edge of an image. These
100
- pixels are not involved in the SSIM calculation.
101
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
102
- Default: 'HWC'.
103
- test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
104
-
105
- Returns:
106
- float: ssim result.
107
- """
108
-
109
- assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
110
- if input_order not in ['HWC', 'CHW']:
111
- raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
112
- img1 = reorder_image(img1, input_order=input_order)
113
- img2 = reorder_image(img2, input_order=input_order)
114
- img1 = img1.astype(np.float64)
115
- img2 = img2.astype(np.float64)
116
-
117
- if crop_border != 0:
118
- img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
119
- img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
120
-
121
- if test_y_channel:
122
- img1 = to_y_channel(img1)
123
- img2 = to_y_channel(img2)
124
-
125
- ssims = []
126
- for i in range(img1.shape[2]):
127
- ssims.append(_ssim(img1[..., i], img2[..., i]))
128
- return np.array(ssims).mean()
129
-
130
-
131
- def reorder_image(img, input_order='HWC'):
132
- """Reorder images to 'HWC' order.
133
-
134
- If the input_order is (h, w), return (h, w, 1);
135
- If the input_order is (c, h, w), return (h, w, c);
136
- If the input_order is (h, w, c), return as it is.
137
-
138
- Args:
139
- img (ndarray): Input image.
140
- input_order (str): Whether the input order is 'HWC' or 'CHW'.
141
- If the input image shape is (h, w), input_order will not have
142
- effects. Default: 'HWC'.
143
-
144
- Returns:
145
- ndarray: reordered image.
146
- """
147
-
148
- if input_order not in ['HWC', 'CHW']:
149
- raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
150
- if len(img.shape) == 2:
151
- img = img[..., None]
152
- if input_order == 'CHW':
153
- img = img.transpose(1, 2, 0)
154
- return img
155
-
156
-
157
- def to_y_channel(img):
158
- """Change to Y channel of YCbCr.
159
-
160
- Args:
161
- img (ndarray): Images with range [0, 255].
162
-
163
- Returns:
164
- (ndarray): Images with range [0, 255] (float type) without round.
165
- """
166
- img = img.astype(np.float32) / 255.
167
- if img.ndim == 3 and img.shape[2] == 3:
168
- img = bgr2ycbcr(img, y_only=True)
169
- img = img[..., None]
170
- return img * 255.
171
-
172
-
173
- def _convert_input_type_range(img):
174
- """Convert the type and range of the input image.
175
-
176
- It converts the input image to np.float32 type and range of [0, 1].
177
- It is mainly used for pre-processing the input image in colorspace
178
- convertion functions such as rgb2ycbcr and ycbcr2rgb.
179
-
180
- Args:
181
- img (ndarray): The input image. It accepts:
182
- 1. np.uint8 type with range [0, 255];
183
- 2. np.float32 type with range [0, 1].
184
-
185
- Returns:
186
- (ndarray): The converted image with type of np.float32 and range of
187
- [0, 1].
188
- """
189
- img_type = img.dtype
190
- img = img.astype(np.float32)
191
- if img_type == np.float32:
192
- pass
193
- elif img_type == np.uint8:
194
- img /= 255.
195
- else:
196
- raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
197
- return img
198
-
199
-
200
- def _convert_output_type_range(img, dst_type):
201
- """Convert the type and range of the image according to dst_type.
202
-
203
- It converts the image to desired type and range. If `dst_type` is np.uint8,
204
- images will be converted to np.uint8 type with range [0, 255]. If
205
- `dst_type` is np.float32, it converts the image to np.float32 type with
206
- range [0, 1].
207
- It is mainly used for post-processing images in colorspace convertion
208
- functions such as rgb2ycbcr and ycbcr2rgb.
209
-
210
- Args:
211
- img (ndarray): The image to be converted with np.float32 type and
212
- range [0, 255].
213
- dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
214
- converts the image to np.uint8 type with range [0, 255]. If
215
- dst_type is np.float32, it converts the image to np.float32 type
216
- with range [0, 1].
217
-
218
- Returns:
219
- (ndarray): The converted image with desired type and range.
220
- """
221
- if dst_type not in (np.uint8, np.float32):
222
- raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
223
- if dst_type == np.uint8:
224
- img = img.round()
225
- else:
226
- img /= 255.
227
- return img.astype(dst_type)
228
-
229
-
230
- def bgr2ycbcr(img, y_only=False):
231
- """Convert a BGR image to YCbCr image.
232
-
233
- The bgr version of rgb2ycbcr.
234
- It implements the ITU-R BT.601 conversion for standard-definition
235
- television. See more details in
236
- https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
237
-
238
- It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
239
- In OpenCV, it implements a JPEG conversion. See more details in
240
- https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
241
-
242
- Args:
243
- img (ndarray): The input image. It accepts:
244
- 1. np.uint8 type with range [0, 255];
245
- 2. np.float32 type with range [0, 1].
246
- y_only (bool): Whether to only return Y channel. Default: False.
247
-
248
- Returns:
249
- ndarray: The converted YCbCr image. The output image has the same type
250
- and range as input image.
251
- """
252
- img_type = img.dtype
253
- img = _convert_input_type_range(img)
254
- if y_only:
255
- out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
256
- else:
257
- out_img = np.matmul(
258
- img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
259
- out_img = _convert_output_type_range(out_img, img_type)
260
- return out_img
261
-
262
-
263
-
264
- def metric(input_file_list_w, metric_types):
265
- # Ensure the file numbers are the same for reference-based metrics
266
-
267
- # Initiate score pool
268
- psnrs = 0
269
- ssims = 0
270
-
271
- pyiqa_types = metric_types
272
- pyiqa_metrics = {}
273
- pyiqa_results = {}
274
- for m in pyiqa_types:
275
- pyiqa_metrics[m] = pyiqa.create_metric(m, device='cpu')
276
- pyiqa_results[m] = 0
277
-
278
-
279
- file_num_w = len(input_file_list_w)
280
- print("the number of submitted wild", file_num_w)
281
- for idx in range(file_num_w):
282
- for m in pyiqa_types:
283
- if 'lpips' not in m:
284
- pyiqa_results[m] += pyiqa_metrics[m](input_file_list_w[idx]).detach().cpu().squeeze().item()
285
-
286
- for m in pyiqa_types:
287
- pyiqa_results[m] /= file_num_w
288
-
289
- return pyiqa_results
290
-
291
- import sys
292
- import glob
293
-
294
- submit_dir = '/media/ssd8T/wyw/Data/NTIRE2025/SeeSR_test/sam_10000/wild_noise/sample00'
295
-
296
- img_ext = ['png', 'jpg']
297
-
298
- input_list_w = []
299
-
300
-
301
- for ext in img_ext:
302
- input_list_w.extend(glob.glob(os.path.join(submit_dir, f'*.{ext}')))
303
-
304
-
305
- input_list_w.sort()
306
-
307
- # metrics used in pyiqa
308
- pyiqa_metrics = ['musiq', 'maniqa', 'clipiqa']
309
-
310
- pyiqa_all = metric(input_list_w, pyiqa_metrics)
311
-
312
- score = 10*pyiqa_all['maniqa']+10*pyiqa_all['clipiqa']+0.1*pyiqa_all['musiq']
313
- print('FinalScore:{} MUSIQ:{} ManIQA:{} CLIPIQA:{}'.format(score, str(pyiqa_all['musiq']), str(pyiqa_all['maniqa']), str(pyiqa_all['clipiqa'])))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/misc.py DELETED
@@ -1,58 +0,0 @@
1
- import os
2
- import binascii
3
- from safetensors import safe_open
4
-
5
- import torch
6
-
7
- from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint
8
-
9
- def rand_name(length=8, suffix=''):
10
- name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
11
- if suffix:
12
- if not suffix.startswith('.'):
13
- suffix = '.' + suffix
14
- name += suffix
15
- return name
16
-
17
- def cycle(dl):
18
- while True:
19
- for data in dl:
20
- yield data
21
-
22
- def exists(x):
23
- return x is not None
24
-
25
- def identity(x):
26
- return x
27
-
28
- def load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=""):
29
- if model_path is None: return unet
30
-
31
- if model_path.endswith(".ckpt"):
32
- base_state_dict = torch.load(model_path)['state_dict']
33
- elif model_path.endswith(".safetensors"):
34
- state_dict = {}
35
- with safe_open(model_path, framework="pt", device="cpu") as f:
36
- for key in f.keys():
37
- state_dict[key] = f.get_tensor(key)
38
-
39
- is_lora = all("lora" in k for k in state_dict.keys())
40
- if not is_lora:
41
- base_state_dict = state_dict
42
- else:
43
- base_state_dict = {}
44
- with safe_open(model_base, framework="pt", device="cpu") as f:
45
- for key in f.keys():
46
- base_state_dict[key] = f.get_tensor(key)
47
-
48
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config)
49
- unet_state_dict = unet.state_dict()
50
- for key in converted_unet_checkpoint:
51
- converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key]
52
- unet.load_state_dict(converted_unet_checkpoint, strict=False)
53
-
54
- if vae is not None:
55
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config)
56
- vae.load_state_dict(converted_vae_checkpoint)
57
-
58
- return unet, vae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/vaehook.py DELETED
@@ -1,828 +0,0 @@
1
- # ------------------------------------------------------------------------
2
- #
3
- # Ultimate VAE Tile Optimization
4
- #
5
- # Introducing a revolutionary new optimization designed to make
6
- # the VAE work with giant images on limited VRAM!
7
- # Say goodbye to the frustration of OOM and hello to seamless output!
8
- #
9
- # ------------------------------------------------------------------------
10
- #
11
- # This script is a wild hack that splits the image into tiles,
12
- # encodes each tile separately, and merges the result back together.
13
- #
14
- # Advantages:
15
- # - The VAE can now work with giant images on limited VRAM
16
- # (~10 GB for 8K images!)
17
- # - The merged output is completely seamless without any post-processing.
18
- #
19
- # Drawbacks:
20
- # - Giant RAM needed. To store the intermediate results for a 4096x4096
21
- # images, you need 32 GB RAM it consumes ~20GB); for 8192x8192
22
- # you need 128 GB RAM machine (it consumes ~100 GB)
23
- # - NaNs always appear in for 8k images when you use fp16 (half) VAE
24
- # You must use --no-half-vae to disable half VAE for that giant image.
25
- # - Slow speed. With default tile size, it takes around 50/200 seconds
26
- # to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
27
- # a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
28
- # - The gradient calculation is not compatible with this hack. It
29
- # will break any backward() or torch.autograd.grad() that passes VAE.
30
- # (But you can still use the VAE to generate training data.)
31
- #
32
- # How it works:
33
- # 1) The image is split into tiles.
34
- # - To ensure perfect results, each tile is padded with 32 pixels
35
- # on each side.
36
- # - Then the conv2d/silu/upsample/downsample can produce identical
37
- # results to the original image without splitting.
38
- # 2) The original forward is decomposed into a task queue and a task worker.
39
- # - The task queue is a list of functions that will be executed in order.
40
- # - The task worker is a loop that executes the tasks in the queue.
41
- # 3) The task queue is executed for each tile.
42
- # - Current tile is sent to GPU.
43
- # - local operations are directly executed.
44
- # - Group norm calculation is temporarily suspended until the mean
45
- # and var of all tiles are calculated.
46
- # - The residual is pre-calculated and stored and addded back later.
47
- # - When need to go to the next tile, the current tile is send to cpu.
48
- # 4) After all tiles are processed, tiles are merged on cpu and return.
49
- #
50
- # Enjoy!
51
- #
52
- # @author: LI YI @ Nanyang Technological University - Singapore
53
- # @date: 2023-03-02
54
- # @license: MIT License
55
- #
56
- # Please give me a star if you like this project!
57
- #
58
- # -------------------------------------------------------------------------
59
-
60
- import gc
61
- from time import time
62
- import math
63
- from tqdm import tqdm
64
-
65
- import torch
66
- import torch.version
67
- import torch.nn.functional as F
68
- from einops import rearrange
69
- import os
70
- import sys
71
- sys.path.append(os.getcwd())
72
- import utils.devices as devices
73
-
74
- try:
75
- import xformers
76
- import xformers.ops
77
- except ImportError:
78
- pass
79
-
80
- sd_flag = False
81
-
82
- def get_recommend_encoder_tile_size():
83
- if torch.cuda.is_available():
84
- total_memory = torch.cuda.get_device_properties(
85
- devices.device).total_memory // 2**20
86
- if total_memory > 16*1000:
87
- ENCODER_TILE_SIZE = 3072
88
- elif total_memory > 12*1000:
89
- ENCODER_TILE_SIZE = 2048
90
- elif total_memory > 8*1000:
91
- ENCODER_TILE_SIZE = 1536
92
- else:
93
- ENCODER_TILE_SIZE = 960
94
- else:
95
- ENCODER_TILE_SIZE = 512
96
- return ENCODER_TILE_SIZE
97
-
98
-
99
- def get_recommend_decoder_tile_size():
100
- if torch.cuda.is_available():
101
- total_memory = torch.cuda.get_device_properties(
102
- devices.device).total_memory // 2**20
103
- if total_memory > 30*1000:
104
- DECODER_TILE_SIZE = 256
105
- elif total_memory > 16*1000:
106
- DECODER_TILE_SIZE = 192
107
- elif total_memory > 12*1000:
108
- DECODER_TILE_SIZE = 128
109
- elif total_memory > 8*1000:
110
- DECODER_TILE_SIZE = 96
111
- else:
112
- DECODER_TILE_SIZE = 64
113
- else:
114
- DECODER_TILE_SIZE = 64
115
- return DECODER_TILE_SIZE
116
-
117
-
118
- if 'global const':
119
- DEFAULT_ENABLED = False
120
- DEFAULT_MOVE_TO_GPU = False
121
- DEFAULT_FAST_ENCODER = True
122
- DEFAULT_FAST_DECODER = True
123
- DEFAULT_COLOR_FIX = 0
124
- DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
125
- DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
126
-
127
-
128
- # inplace version of silu
129
- def inplace_nonlinearity(x):
130
- # Test: fix for Nans
131
- return F.silu(x, inplace=True)
132
-
133
- # extracted from ldm.modules.diffusionmodules.model
134
-
135
- # from diffusers lib
136
- def attn_forward_new(self, h_):
137
- batch_size, channel, height, width = h_.shape
138
- hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)
139
-
140
- attention_mask = None
141
- encoder_hidden_states = None
142
- batch_size, sequence_length, _ = hidden_states.shape
143
- attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
144
-
145
- query = self.to_q(hidden_states)
146
-
147
- if encoder_hidden_states is None:
148
- encoder_hidden_states = hidden_states
149
- elif self.norm_cross:
150
- encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)
151
-
152
- key = self.to_k(encoder_hidden_states)
153
- value = self.to_v(encoder_hidden_states)
154
-
155
- query = self.head_to_batch_dim(query)
156
- key = self.head_to_batch_dim(key)
157
- value = self.head_to_batch_dim(value)
158
-
159
- attention_probs = self.get_attention_scores(query, key, attention_mask)
160
- hidden_states = torch.bmm(attention_probs, value)
161
- hidden_states = self.batch_to_head_dim(hidden_states)
162
-
163
- # linear proj
164
- hidden_states = self.to_out[0](hidden_states)
165
- # dropout
166
- hidden_states = self.to_out[1](hidden_states)
167
-
168
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
169
-
170
- return hidden_states
171
-
172
- def attn_forward(self, h_):
173
- q = self.q(h_)
174
- k = self.k(h_)
175
- v = self.v(h_)
176
-
177
- # compute attention
178
- b, c, h, w = q.shape
179
- q = q.reshape(b, c, h*w)
180
- q = q.permute(0, 2, 1) # b,hw,c
181
- k = k.reshape(b, c, h*w) # b,c,hw
182
- w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
183
- w_ = w_ * (int(c)**(-0.5))
184
- w_ = torch.nn.functional.softmax(w_, dim=2)
185
-
186
- # attend to values
187
- v = v.reshape(b, c, h*w)
188
- w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
189
- # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
190
- h_ = torch.bmm(v, w_)
191
- h_ = h_.reshape(b, c, h, w)
192
-
193
- h_ = self.proj_out(h_)
194
-
195
- return h_
196
-
197
-
198
- def xformer_attn_forward(self, h_):
199
- q = self.q(h_)
200
- k = self.k(h_)
201
- v = self.v(h_)
202
-
203
- # compute attention
204
- B, C, H, W = q.shape
205
- q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
206
-
207
- q, k, v = map(
208
- lambda t: t.unsqueeze(3)
209
- .reshape(B, t.shape[1], 1, C)
210
- .permute(0, 2, 1, 3)
211
- .reshape(B * 1, t.shape[1], C)
212
- .contiguous(),
213
- (q, k, v),
214
- )
215
- out = xformers.ops.memory_efficient_attention(
216
- q, k, v, attn_bias=None, op=self.attention_op)
217
-
218
- out = (
219
- out.unsqueeze(0)
220
- .reshape(B, 1, out.shape[1], C)
221
- .permute(0, 2, 1, 3)
222
- .reshape(B, out.shape[1], C)
223
- )
224
- out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
225
- out = self.proj_out(out)
226
- return out
227
-
228
-
229
- def attn2task(task_queue, net):
230
- if False: #isinstance(net, AttnBlock):
231
- task_queue.append(('store_res', lambda x: x))
232
- task_queue.append(('pre_norm', net.norm))
233
- task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))
234
- task_queue.append(['add_res', None])
235
- elif False: #isinstance(net, MemoryEfficientAttnBlock):
236
- task_queue.append(('store_res', lambda x: x))
237
- task_queue.append(('pre_norm', net.norm))
238
- task_queue.append(
239
- ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
240
- task_queue.append(['add_res', None])
241
- else:
242
- task_queue.append(('store_res', lambda x: x))
243
- task_queue.append(('pre_norm', net.group_norm))
244
- task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
245
- task_queue.append(['add_res', None])
246
-
247
- def resblock2task(queue, block):
248
- """
249
- Turn a ResNetBlock into a sequence of tasks and append to the task queue
250
-
251
- @param queue: the target task queue
252
- @param block: ResNetBlock
253
-
254
- """
255
- if block.in_channels != block.out_channels:
256
- if sd_flag:
257
- if block.use_conv_shortcut:
258
- queue.append(('store_res', block.conv_shortcut))
259
- else:
260
- queue.append(('store_res', block.nin_shortcut))
261
- else:
262
- if block.use_in_shortcut:
263
- queue.append(('store_res', block.conv_shortcut))
264
- else:
265
- queue.append(('store_res', block.nin_shortcut))
266
-
267
- else:
268
- queue.append(('store_res', lambda x: x))
269
- queue.append(('pre_norm', block.norm1))
270
- queue.append(('silu', inplace_nonlinearity))
271
- queue.append(('conv1', block.conv1))
272
- queue.append(('pre_norm', block.norm2))
273
- queue.append(('silu', inplace_nonlinearity))
274
- queue.append(('conv2', block.conv2))
275
- queue.append(['add_res', None])
276
-
277
-
278
-
279
def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue

    The decoder queues the middle block first and then its up blocks; the
    encoder queues its down blocks first and then the middle block.

    @param task_queue: the target task queue (mutated in place)
    @param net: the network; reads mid_block plus up_blocks (decoder) or
                down_blocks (encoder)
    @param is_decoder: currently building decoder or encoder
    """
    def queue_mid_block():
        # resnet -> attention -> resnet
        resblock2task(task_queue, net.mid_block.resnets[0])
        attn2task(task_queue, net.mid_block.attentions[0])
        resblock2task(task_queue, net.mid_block.resnets[1])

    if is_decoder:
        queue_mid_block()
        stages = net.up_blocks
        resnets_per_stage = 2 + 1
        sampler_attr = 'upsamplers'
    else:
        stages = net.down_blocks
        resnets_per_stage = 2
        sampler_attr = 'downsamplers'

    # every stage except the last one is followed by its resampling layer
    last_stage = len(stages) - 1
    for level in range(len(stages)):
        for block_idx in range(resnets_per_stage):
            resblock2task(task_queue, stages[level].resnets[block_idx])
        if level != last_stage:
            # the task name doubles as the attribute that holds the sampler list
            task_queue.append((sampler_attr, getattr(stages[level], sampler_attr)[0]))

    if not is_decoder:
        queue_mid_block()
329
-
330
-
331
def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder

    Side effect: for a decoder with the global sd_flag unset, `net.give_pre_end`
    and `net.tanh_out` are overwritten with False before they are read.

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = [('conv_in', net.conv_in)]

    # encoder and decoder share the same architecture, so the sampling part
    # is produced by a common helper
    build_sampling(task_queue, net, is_decoder)

    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    # a decoder asked for its pre-end activations gets no output head
    if is_decoder and net.give_pre_end:
        return task_queue

    out_norm = net.norm_out if sd_flag else net.conv_norm_out
    task_queue.append(('pre_norm', out_norm))
    task_queue.append(('silu', inplace_nonlinearity))
    task_queue.append(('conv_out', net.conv_out))
    if is_decoder and net.tanh_out:
        task_queue.append(('tanh', torch.tanh))
    return task_queue
359
-
360
-
361
def clone_task_queue(task_queue):
    """
    Clone a task queue

    Every task is copied into a fresh list, so slots of the clone can be
    reassigned without touching the source queue. The items inside each task
    are shared, not copied.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [list(task) for task in task_queue]
368
-
369
-
370
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get var and mean for group norm

    @param input: tensor with two spatial dimensions after (batch, channel);
                  the reduction below addresses exactly five dims after regrouping
    @param num_groups: number of channel groups
    @param eps: accepted for signature compatibility; not used here
    @return: (var, mean), one value per (sample, group) pair, biased variance
    """
    batch, channels = input.size(0), input.size(1)
    spatial = input.size()[2:]
    # fold batch and groups into one axis so a single reduction covers each group
    grouped = input.contiguous().view(
        1, int(batch * num_groups), int(channels / num_groups), *spatial)
    return torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
381
-
382
-
383
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    @param input: input tensor
    @param num_groups: number of groups
    @param mean: mean, one value per (sample, group) pair
    @param var: var, one value per (sample, group) pair
    @param weight: optional per-channel scale applied after normalization
    @param bias: optional per-channel shift applied after normalization
    @param eps: epsilon, by default eps = 1e-6

    @return: normalized tensor
    """
    batch, channels = input.size(0), input.size(1)
    spatial = input.size()[2:]
    grouped = input.contiguous().view(
        1, int(batch * num_groups), int(channels / num_groups), *spatial)

    # batch_norm in eval mode normalizes each "channel" (here: each group)
    # with exactly the statistics that were passed in
    normed = F.batch_norm(grouped, mean, var, weight=None, bias=None,
                          training=False, momentum=0, eps=eps)
    normed = normed.view(batch, channels, *spatial)

    # post affine transform
    if weight is not None:
        normed *= weight.view(1, -1, 1, 1)
    if bias is not None:
        normed += bias.view(1, -1, 1, 1)
    return normed
413
-
414
-
415
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from the tile

    @param x: processed tile, indexed as (batch, channel, height, width)
    @param input_bbox: bounding box [x1, x2, y1, y2] the tile was cut from
    @param target_bbox: bounding box [x1, x2, y1, y2] to keep, in output coordinates
    @param is_decoder: True multiplies input_bbox by 8, False floor-divides it by 8
    @return: cropped tile
    """
    if is_decoder:
        scaled_bbox = [coord * 8 for coord in input_bbox]
    else:
        scaled_bbox = [coord // 8 for coord in input_bbox]
    # left/top are offsets from the start; right/bottom are (non-positive)
    # offsets from the end of the tile
    left, right, top, bottom = [target_bbox[k] - scaled_bbox[k] for k in range(4)]
    return x[:, :, top:x.size(2) + bottom, left:x.size(3) + right]
427
-
428
- # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
429
-
430
-
431
def perfcount(fn):
    """
    Decorator that reports wall-clock time (and peak CUDA memory, when CUDA is
    available) of every call to the wrapped function.

    Caches are flushed with devices.torch_gc() and gc.collect() before and
    after the call.

    @param fn: the function to measure
    @return: a wrapper with the same call signature and return value as fn
    """
    # local import keeps this block self-contained
    from functools import wraps

    @wraps(fn)  # preserve __name__ / __doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            # bytes -> MiB
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
454
-
455
- # copy end :)
456
-
457
-
458
class GroupNormParam:
    """
    Accumulates per-tile group-norm statistics (always with 32 groups) so that
    all tiles can later be normalized with one shared mean and variance.
    """

    def __init__(self):
        self.var_list, self.mean_list, self.pixel_list = [], [], []
        self.weight = self.bias = None

    def add_tile(self, tile, layer):
        """
        Record the statistics of one tile

        @param tile: the tile about to enter the norm layer
        @param layer: the norm layer; its weight/bias are remembered when present
        """
        var, mean = get_var_mean(tile, 32)
        # For giant images the variance can exceed the float16 range;
        # in that case recompute on a float32 copy
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
        self.var_list.append(var)
        self.mean_list.append(mean)
        # pixel count of the tile, used as its weight in summary()
        self.pixel_list.append(tile.shape[2] * tile.shape[3])

        has_affine = hasattr(layer, 'weight')
        self.weight = layer.weight if has_affine else None
        self.bias = layer.bias if has_affine else None

    def summary(self):
        """
        summarize the mean and var and return a function
        that apply group norm on each tile

        @return: None when no tile was added, otherwise a one-argument callable
        """
        if not self.var_list:
            return None
        stacked_var = torch.vstack(self.var_list)
        stacked_mean = torch.vstack(self.mean_list)

        # weight every tile by its pixel count; scale by the largest count
        # first, then normalize so the weights sum to one
        largest = max(self.pixel_list)
        share = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / largest
        total = torch.sum(share)
        share = share.unsqueeze(1) / total

        var = torch.sum(stacked_var * share, dim=0)
        mean = torch.sum(stacked_mean * share, dim=0)

        def apply_norm(x):
            # weight/bias are read from self at call time
            return custom_group_norm(x, 32, mean, var, self.weight, self.bias)
        return apply_norm

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary

        @param tile: the tile whose statistics are frozen into the function
        @param norm: the norm layer; its weight/bias are used when present
        @return: a callable that applies the fixed-statistics group norm
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
            # on 'mps' devices convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000).half()
                mean = mean.half()

        has_affine = hasattr(norm, 'weight')
        weight = norm.weight if has_affine else None
        bias = norm.bias if has_affine else None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
534
-
535
-
536
class VAEHook:
    """
    Callable that runs a VAE encoder or decoder tile by tile.

    Inputs small enough for a single tile are passed to the network's
    `original_forward`; larger inputs go through `vae_tile_forward`.
    """

    # Task names that mark a down-sampling step. build_sampling() labels these
    # tasks 'downsamplers'; 'downsample' is kept for queues using the older name.
    _DOWNSAMPLE_TASKS = ('downsample', 'downsamplers')

    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        """
        @param net: encoder | decoder
        @param tile_size: edge length of a tile, in input coordinates
        @param is_decoder: True when `net` is the decoder
        @param fast_decoder: enable fast mode when hooking a decoder
        @param fast_encoder: enable fast mode when hooking an encoder
        @param color_fix: requested color fix; only honoured for the encoder
        @param to_gpu: move `net` to the optimal device for the duration of a call
        """
        self.net = net  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        # overlap added around every tile, in input coordinates
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        """
        Run the hooked network on x, tiling only when the input is large.

        @param x: 4-D input tensor
        @return: the network output
        """
        _, _, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            # always restore the network to where it was before the call
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory

        Rounds lowerbound up to a multiple of 32, 16, ... 2 (largest divider
        first) as long as the result does not exceed upperbound.

        @param lowerbound: the minimal tile size
        @param upperbound: the maximal tile size
        @return: the chosen tile size; lowerbound if no candidate fits
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(math.ceil((h - 2 * pad) / tile_size), 1)
        num_width_tiles = max(math.ceil((w - 2 * pad) / tile_size), 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # padding is unnecessary at the image borders, so the first
                # tile starts directly at (pad, pad)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # expand the input bbox by pad pixels on every side
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        """
        Run z through the task queue and replace every 'pre_norm' task with an
        'apply_norm' task whose statistics are frozen from z.

        @param z: the (downsampled) input used for estimation
        @param task_queue: the queue to rewrite in place
        @param color_fix: stop at the first down-sampling task and mark the
                          remaining 'store_res' tasks as 'store_res_cpu'
        @return: True on success, False if the NaN check failed
        @raise ValueError: if the queue has no usable 'pre_norm' task
        """
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    # the matching add_res lies beyond the estimated range
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] in self._DOWNSAMPLE_TASKS:
                # FIX: the queue labels these tasks 'downsamplers', so comparing
                # against 'downsample' alone made this branch unreachable
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                devices.test_for_nans(tile, "vae")
            except Exception:
                # any failure of the check disables fast mode; interpreter-exit
                # signals (KeyboardInterrupt, SystemExit) are no longer swallowed
                print('Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Run the hooked encoder or decoder on z in a tiled manner.

        All tiles advance in lock-step: each tile runs until its next
        'pre_norm' task, the statistics of all tiles are merged, and the merged
        norm is then applied to every tile.

        @param z: input tensor
        @return: the assembled output tensor
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input; tiles wait on the CPU
        tiles = [
            z[:, :, bbox[2]:bbox[3], bbox[0]:bbox[1]].cpu()
            for bbox in in_bboxes
        ]

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            # use nearest-exact to keep statistics as close as possible
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # allocated lazily once the first tile finishes
        result = None
        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")
        try:
            # execute the task back and forth when switch tiles so that we always
            # keep one tile on the GPU to reduce unnecessary data transfer
            forward = True
            while True:
                group_norm_param = GroupNormParam()
                for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                    tile = tiles[i].to(device)
                    task_queue = task_queues[i]

                    while len(task_queue) > 0:
                        task = task_queue.pop(0)
                        if task[0] == 'pre_norm':
                            # pause this tile until all tiles reached the norm
                            group_norm_param.add_tile(tile, task[1])
                            break
                        elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                            task_id = 0
                            res = task[1](tile)
                            if not self.fast_mode or task[0] == 'store_res_cpu':
                                res = res.cpu()
                            # hand the residual to the next add_res task
                            while task_queue[task_id][0] != 'add_res':
                                task_id += 1
                            task_queue[task_id][1] = res
                        elif task[0] == 'add_res':
                            tile += task[1].to(device)
                            task[1] = None
                        else:
                            tile = task[1](tile)
                        pbar.update(1)

                    if len(task_queue) == 0:
                        tiles[i] = None
                        num_completed += 1
                        if result is None:  # NOTE: dim C varies from different cases, can only be inited dynamically
                            result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                        result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                        del tile
                    elif i == num_tiles - 1 and forward:
                        # turning point: keep this tile on the device
                        forward = False
                        tiles[i] = tile
                    elif i == 0 and not forward:
                        forward = True
                        tiles[i] = tile
                    else:
                        tiles[i] = tile.cpu()
                        del tile

                if num_completed == num_tiles:
                    break

                # insert the group norm task to the head of each task queue
                group_norm_func = group_norm_param.summary()
                if group_norm_func is not None:
                    for queue in task_queues:
                        queue.insert(0, ('apply_norm', group_norm_func))
        finally:
            # close the progress bar even when a task raises
            pbar.close()

        # Done! The loop above only exits once every tile has been written.
        return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/wavelet_color_fix.py DELETED
@@ -1,119 +0,0 @@
1
- '''
2
- # --------------------------------------------------------------------------------
3
- # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4
- # --------------------------------------------------------------------------------
5
- '''
6
-
7
- import torch
8
- from PIL import Image
9
- from torch import Tensor
10
- from torch.nn import functional as F
11
-
12
- from torchvision.transforms import ToTensor, ToPILImage
13
-
14
def adain_color_fix(target: Image, source: Image):
    """
    Color-correct `target` with adaptive instance normalization against `source`.

    @param target: the image to correct
    @param source: the image supplying the reference statistics
    @return: the corrected image produced by ToPILImage
    """
    as_tensor = ToTensor()
    # add a leading batch axis for the normalization helper
    content = as_tensor(target).unsqueeze(0)
    style = as_tensor(source).unsqueeze(0)

    fixed = adaptive_instance_normalization(content, style)

    # drop the batch axis and clamp into [0, 1] before converting back
    as_image = ToPILImage()
    return as_image(fixed.squeeze(0).clamp_(0.0, 1.0))
28
-
29
def wavelet_color_fix(target: Image, source: Image):
    """
    Color-correct `target` by wavelet reconstruction against `source`.

    @param target: the image to correct
    @param source: the image supplying the reference frequencies
    @return: the corrected image produced by ToPILImage
    """
    as_tensor = ToTensor()
    # add a leading batch axis for the reconstruction helper
    content = as_tensor(target).unsqueeze(0)
    style = as_tensor(source).unsqueeze(0)

    fixed = wavelet_reconstruction(content, style)

    # drop the batch axis and clamp into [0, 1] before converting back
    as_image = ToPILImage()
    return as_image(fixed.squeeze(0).clamp_(0.0, 1.0))
43
-
44
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate per-channel mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple: (mean, std), each reshaped to (b, c, 1, 1).
    """
    assert feat.dim() == 4, 'The input feature should be 4D tensor.'
    b, c = feat.size()[:2]
    # flatten the spatial axes so each (sample, channel) is one row
    flat = feat.reshape(b, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().reshape(b, c, 1, 1)
    feat_mean = flat.mean(dim=2).reshape(b, c, 1, 1)
    return feat_mean, feat_std
58
-
59
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
    """Adaptive instance normalization.

    Re-scale the content features so that their per-channel mean and std
    match those of the style features.

    Args:
        content_feat (Tensor): The features to be re-colored.
        style_feat (Tensor): The features supplying the target statistics.
    """
    shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    # whiten with the content statistics ...
    centered = content_feat - content_mean.expand(shape)
    whitened = centered / content_std.expand(shape)
    # ... then re-color with the style statistics
    return whitened * style_std.expand(shape) + style_mean.expand(shape)
72
-
73
def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.

    Each channel is filtered independently with a fixed 3x3 kernel whose taps
    are `radius` pixels apart; borders are replicated so the output keeps the
    input's height and width.

    @param image: input of shape (N, C, H, W); any channel count C is accepted
    @param radius: dilation of the kernel and width of the border padding
    @return: blurred tensor with the same shape as `image`
    """
    # one kernel copy per input channel (previously hard-coded to 3 channels)
    channels = image.shape[1]
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(channels, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # groups == channels: every channel is convolved with its own kernel copy
    output = F.conv2d(image, kernel, groups=channels, dilation=radius)
    return output
93
-
94
def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.

    @param image: input tensor accepted by wavelet_blur
    @param levels: number of blur passes; pass i uses radius 2 ** i
    @return: (high_freq, low_freq); their sum equals the input
    """
    high_freq = torch.zeros_like(image)
    # with levels == 0 nothing is extracted: the whole input is "low frequency"
    # (previously low_freq was unbound in that case)
    low_freq = image
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        # accumulate the detail removed by this pass
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq
107
-
108
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.

    @param content_feat: tensor supplying the high-frequency part
    @param style_feat: tensor supplying the low-frequency part
    @return: content high frequency plus style low frequency
    """
    # keep only the high-frequency half of the content decomposition
    content_high_freq = wavelet_decomposition(content_feat)[0]
    # keep only the low-frequency half of the style decomposition
    style_low_freq = wavelet_decomposition(style_feat)[1]
    # combine the content's high frequency with the style's low frequency
    return content_high_freq + style_low_freq