Fabrice-TIERCELIN committed on
Commit
e6d98f9
·
verified ·
1 Parent(s): 8c19f42

Delete SUPIR

Browse files
SUPIR/__init__.py DELETED
File without changes
SUPIR/models/SUPIR_model.py DELETED
@@ -1,195 +0,0 @@
1
- import torch
2
- from sgm.models.diffusion import DiffusionEngine
3
- from sgm.util import instantiate_from_config
4
- import copy
5
- from sgm.modules.distributions.distributions import DiagonalGaussianDistribution
6
- import random
7
- from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
8
- from pytorch_lightning import seed_everything
9
- from torch.nn.functional import interpolate
10
- from SUPIR.utils.tilevae import VAEHook
11
-
12
class SUPIRModel(DiffusionEngine):
    """DiffusionEngine subclass that adds a control model and a second
    ("denoise") copy of the first-stage encoder, plus helpers for encoding,
    decoding and sampling.
    """

    def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
        """Build the engine, attach the control model and record dtypes/prompts.

        :param control_stage_config: config passed to ``instantiate_from_config``
            to build the control model.
        :param ae_dtype: one of ``'fp32'``, ``'fp16'``, ``'bf16'``; used as the
            autocast dtype around the first-stage model. ``'fp16'`` is rejected.
        :param diffusion_dtype: one of ``'fp32'``, ``'fp16'``, ``'bf16'``;
            stored on ``self.model.dtype``.
        :param p_p: default text appended to every prompt.
        :param n_p: default text used for the unconditional branch.
        :raises RuntimeError: if ``ae_dtype == 'fp16'``.
        :raises KeyError: if ``sampler_config`` is not among ``kwargs``.
        """
        super().__init__(*args, **kwargs)
        control_model = instantiate_from_config(control_stage_config)
        self.model.load_control_model(control_model)
        # The denoise encoder starts out as an independent copy of the encoder.
        self.first_stage_model.denoise_encoder = copy.deepcopy(self.first_stage_model.encoder)
        self.sampler_config = kwargs['sampler_config']

        assert (ae_dtype in ['fp32', 'fp16', 'bf16']) and (diffusion_dtype in ['fp32', 'fp16', 'bf16'])
        if ae_dtype == 'fp16':
            raise RuntimeError('fp16 cause NaN in AE')

        dtype_by_name = {
            'fp32': torch.float32,
            'fp16': torch.float16,
            'bf16': torch.bfloat16,
        }
        self.ae_dtype = dtype_by_name[ae_dtype]
        self.model.dtype = dtype_by_name[diffusion_dtype]

        self.p_p = p_p
        self.n_p = n_p

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Encode ``x`` with the first-stage model and multiply by ``scale_factor``."""
        with torch.autocast("cuda", dtype=self.ae_dtype):
            z = self.first_stage_model.encode(x)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
        """Encode ``x`` through a denoise encoder and return a scaled latent.

        :param use_sample: draw ``posterior.sample()`` when true, otherwise use
            ``posterior.mode()``.
        :param is_stage1: use ``denoise_encoder_s1`` instead of
            ``denoise_encoder`` (the former is not created in this class).
        """
        # NOTE(review): the extent of the autocast scope is reconstructed from
        # a source without indentation -- confirm against the upstream file.
        with torch.autocast("cuda", dtype=self.ae_dtype):
            if is_stage1:
                h = self.first_stage_model.denoise_encoder_s1(x)
            else:
                h = self.first_stage_model.denoise_encoder(x)
            moments = self.first_stage_model.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            if use_sample:
                z = posterior.sample()
            else:
                z = posterior.mode()
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo the latent scaling, decode, and return the result as float32."""
        z = 1.0 / self.scale_factor * z
        with torch.autocast("cuda", dtype=self.ae_dtype):
            out = self.first_stage_model.decode(z)
        return out.float()

    @torch.no_grad()
    def batchify_denoise(self, x, is_stage1=False):
        '''
        [N, C, H, W], [-1, 1], RGB

        Encode with the denoise encoder (posterior mode) and decode again.
        '''
        x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
        return self.decode_first_stage(x)

    @torch.no_grad()
    def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, restoration_scale=4.0, s_churn=0, s_noise=1.003, cfg_scale=4.0, seed=-1,
                        num_samples=1, control_scale=1, color_fix_type='None', use_linear_CFG=False, use_linear_control_scale=False,
                        cfg_scale_start=1.0, control_scale_start=0.0, **kwargs):
        '''
        [N, C, H, W], [-1, 1], RGB

        Run the sampler on a batch of images ``x`` with one prompt per image.

        :param p: list of prompts, ``len(p) == len(x)``.
        :param p_p / n_p: prompt suffix / negative prompt; ``'default'`` falls
            back to the values stored at construction time.
        :param seed: ``-1`` picks a random seed in ``[0, 65535]``.
        :param num_samples: when > 1 the batch must hold a single image, which
            is repeated ``num_samples`` times together with its prompt.
        :param color_fix_type: ``'Wavelet'``, ``'AdaIn'`` or ``'None'``.
        :param kwargs: forwarded to ``self.denoiser``.
        :return: decoded samples, optionally colour-corrected against the
            stage-1 reconstruction.

        Note: this rewrites fields of ``self.sampler_config`` and replaces
        ``self.sampler`` on every call.
        '''
        assert len(x) == len(p)
        assert color_fix_type in ['Wavelet', 'AdaIn', 'None']

        N = len(x)
        if num_samples > 1:
            assert N == 1
            N = num_samples
            x = x.repeat(N, 1, 1, 1)
            p = p * N

        if p_p == 'default':
            p_p = self.p_p
        if n_p == 'default':
            n_p = self.n_p

        self.sampler_config.params.num_steps = num_steps
        # scale_min is cfg_scale in both modes; only the starting scale differs.
        self.sampler_config.params.guider_config.params.scale_min = cfg_scale
        if use_linear_CFG:
            self.sampler_config.params.guider_config.params.scale = cfg_scale_start
        else:
            self.sampler_config.params.guider_config.params.scale = cfg_scale
        self.sampler_config.params.restore_cfg = restoration_scale
        self.sampler_config.params.s_churn = s_churn
        self.sampler_config.params.s_noise = s_noise
        self.sampler = instantiate_from_config(self.sampler_config)

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        _z = self.encode_first_stage_with_denoise(x, use_sample=False)
        x_stage1 = self.decode_first_stage(_z)
        z_stage1 = self.encode_first_stage(x_stage1)

        c, uc = self.prepare_condition(_z, p, p_p, n_p, N)

        def denoiser(input, sigma, c, control_scale):
            return self.denoiser(self.model, input, sigma, c, control_scale, **kwargs)

        # randn_like already allocates on _z's device with _z's dtype.
        noised_z = torch.randn_like(_z)

        _samples = self.sampler(denoiser, noised_z, cond=c, uc=uc, x_center=z_stage1, control_scale=control_scale,
                                use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
        samples = self.decode_first_stage(_samples)
        if color_fix_type == 'Wavelet':
            samples = wavelet_reconstruction(samples, x_stage1)
        elif color_fix_type == 'AdaIn':
            samples = adaptive_instance_normalization(samples, x_stage1)
        return samples

    def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
        """Replace the forward of both encoders and the decoder with ``VAEHook``
        wrappers, keeping the original callables under ``original_forward``.
        """
        self.first_stage_model.denoise_encoder.original_forward = self.first_stage_model.denoise_encoder.forward
        self.first_stage_model.encoder.original_forward = self.first_stage_model.encoder.forward
        self.first_stage_model.decoder.original_forward = self.first_stage_model.decoder.forward
        self.first_stage_model.denoise_encoder.forward = VAEHook(
            self.first_stage_model.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)
        self.first_stage_model.encoder.forward = VAEHook(
            self.first_stage_model.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)
        self.first_stage_model.decoder.forward = VAEHook(
            self.first_stage_model.decoder, decoder_tile_size, is_decoder=True, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)

    def prepare_condition(self, _z, p, p_p, n_p, N):
        """Build conditional / unconditional conditioning for the sampler.

        :param _z: control latent; stored under ``batch['control']`` and used
            for device placement of the size/score tensors.
        :param p: list of prompts, or a single-element list holding a list of
            per-tile prompts.
        :param p_p: suffix concatenated to every prompt.
        :param n_p: text used for every unconditional entry.
        :param N: batch size used to repeat the fixed size/score tensors.
        :return: ``(c, uc)``; ``c`` is a list with one entry per tile when
            per-tile prompts are given.
        """
        batch = {}
        batch['original_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
        batch['crop_coords_top_left'] = torch.tensor([0, 0]).repeat(N, 1).to(_z.device)
        batch['target_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
        batch['aesthetic_score'] = torch.tensor([9.0]).repeat(N, 1).to(_z.device)
        batch['control'] = _z

        batch_uc = copy.deepcopy(batch)
        batch_uc['txt'] = [n_p for _ in p]

        if not isinstance(p[0], list):
            batch['txt'] = [''.join([_p, p_p]) for _p in p]
            # torch.autocast("cuda", ...) matches the other methods of this
            # class and replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast("cuda", dtype=self.ae_dtype):
                c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
        else:
            assert len(p) == 1, 'Support bs=1 only for local prompt conditioning.'
            p_tiles = p[0]
            c = []
            for i, p_tile in enumerate(p_tiles):
                batch['txt'] = [''.join([p_tile, p_p])]
                with torch.autocast("cuda", dtype=self.ae_dtype):
                    if i == 0:
                        # The unconditional branch is computed once, with the first tile.
                        _c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
                    else:
                        _c, _ = self.conditioner.get_unconditional_conditioning(batch, None)
                c.append(_c)
        return c, uc
180
-
181
-
182
if __name__ == '__main__':
    # Manual smoke test: build the model, load both checkpoints, and sample
    # once from random input. Paths are machine-specific.
    from SUPIR.util import create_model, load_state_dict

    model = create_model('../../options/dev/SUPIR_paper_version.yaml')

    SDXL_CKPT = '/opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors'
    SUPIR_CKPT = '/opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-paper.ckpt'
    # Load order matters: SDXL weights first, then the SUPIR checkpoint on top.
    for ckpt_path in (SDXL_CKPT, SUPIR_CKPT):
        model.load_state_dict(load_state_dict(ckpt_path), strict=False)
    model = model.cuda()

    x = torch.randn(1, 3, 512, 512).cuda()
    p = ['a professional, detailed, high-quality photo']
    samples = model.batchify_sample(
        x,
        p,
        num_steps=50,
        restoration_scale=4.0,
        s_churn=0,
        cfg_scale=4.0,
        seed=-1,
        num_samples=1,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/models/__init__.py DELETED
File without changes
SUPIR/modules/SUPIR_v0.py DELETED
@@ -1,718 +0,0 @@
1
- # from einops._torch_specific import allow_ops_in_compiled_graph
2
- # allow_ops_in_compiled_graph()
3
- import einops
4
- import torch
5
- import torch as th
6
- import torch.nn as nn
7
- from einops import rearrange, repeat
8
-
9
- from sgm.modules.diffusionmodules.util import (
10
- avg_pool_nd,
11
- checkpoint,
12
- conv_nd,
13
- linear,
14
- normalization,
15
- timestep_embedding,
16
- zero_module,
17
- )
18
-
19
- from sgm.modules.diffusionmodules.openaimodel import Downsample, Upsample, UNetModel, Timestep, \
20
- TimestepEmbedSequential, ResBlock, AttentionBlock, TimestepBlock
21
- from sgm.modules.attention import SpatialTransformer, MemoryEfficientCrossAttention, CrossAttention
22
- from sgm.util import default, log_txt_as_img, exists, instantiate_from_config
23
- import re
24
- import torch
25
- from functools import partial
26
-
27
-
28
# Probe for the optional xformers package. Any ordinary failure (missing
# package, broken binary build) falls back to the plain attention path, while
# KeyboardInterrupt / SystemExit are no longer swallowed.
# The misspelled name is kept because other code in this file reads it.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
34
-
35
-
36
- # dummy replace
37
def convert_module_to_f16(x):
    """Placeholder kept for API compatibility; does nothing and returns None."""
    return None
39
-
40
-
41
def convert_module_to_f32(x):
    """Placeholder kept for API compatibility; does nothing and returns None."""
    return None
43
-
44
-
45
class ZeroConv(nn.Module):
    """Add a projection of a control feature ``c`` to a feature map ``h``.

    The projection is ``zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))``
    (presumably a zero-initialised 1x1 convolution -- confirm against the
    ``sgm`` helpers).
    """

    def __init__(self, label_nc, norm_nc, mask=False):
        super().__init__()
        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.mask = mask

    def forward(self, c, h, h_ori=None):
        """Return ``h`` plus the projected ``c``; prepend ``h_ori`` on the
        channel axis when it is given. With ``mask`` set the projection is
        multiplied by zeros.
        """
        projected = self.zero_conv(c)
        if self.mask:
            projected = projected * torch.zeros_like(h)
        h = h + projected
        if h_ori is None:
            return h
        return th.cat([h_ori, h], dim=1)
60
-
61
-
62
class ZeroSFT(nn.Module):
    """Modulate a feature map with scale/shift maps predicted from a control
    feature, then blend with the unmodulated input by ``control_scale``.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
        super().__init__()

        kernel = 3
        pad = kernel // 2
        hidden_nc = 128
        out_nc = norm_nc + concat_channels

        self.norm = norm
        self.param_free_norm = normalization(out_nc) if self.norm else nn.Identity()

        # Shared trunk feeding the two zero-wrapped heads below.
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden_nc, kernel_size=kernel, padding=pad),
            nn.SiLU(),
        )
        self.zero_mul = zero_module(nn.Conv2d(hidden_nc, out_nc, kernel_size=kernel, padding=pad))
        self.zero_add = zero_module(nn.Conv2d(hidden_nc, out_nc, kernel_size=kernel, padding=pad))

        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.pre_concat = bool(concat_channels != 0)
        self.mask = mask

    def forward(self, c, h, h_ori=None, control_scale=1):
        """Apply the control ``c`` to ``h``.

        :param c: control feature map.
        :param h: feature map to modulate.
        :param h_ori: optional map concatenated in front of ``h`` on dim 1,
            before modulation when ``pre_concat`` is set, otherwise after.
        :param control_scale: blend weight between modulated and raw features.
        """
        assert self.mask is False
        concat_first = h_ori is not None and self.pre_concat
        h_raw = th.cat([h_ori, h], dim=1) if concat_first else h

        shift = self.zero_conv(c)
        if self.mask:
            shift = shift * torch.zeros_like(h)
        h = h + shift
        if concat_first:
            h = th.cat([h_ori, h], dim=1)

        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)
        if self.mask:
            gamma = gamma * torch.zeros_like(gamma)
            beta = beta * torch.zeros_like(beta)
        h = self.param_free_norm(h) * (gamma + 1) + beta

        if h_ori is not None and not self.pre_concat:
            h = th.cat([h_ori, h], dim=1)
        return h * control_scale + h_raw * (1 - control_scale)
114
-
115
-
116
class ZeroCrossAttn(nn.Module):
    """Residual cross-attention from a feature map ``x`` to a control map."""

    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(self, context_dim, query_dim, zero_out=True, mask=False):
        # ``zero_out`` is accepted but not used by this implementation.
        super().__init__()
        if XFORMERS_IS_AVAILBLE:
            attn_mode = "softmax-xformers"
        else:
            attn_mode = "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.attn = attn_cls(query_dim=query_dim, context_dim=context_dim, heads=query_dim//64, dim_head=64)
        self.norm1 = normalization(query_dim)
        self.norm2 = normalization(context_dim)

        self.mask = mask

    def forward(self, context, x, control_scale=1):
        """Return ``x + attn(norm1(x), norm2(context)) * control_scale``.

        Both inputs are flattened from ``b c h w`` to ``b (h w) c`` for the
        attention call and the result is reshaped back to ``x``'s spatial size.
        """
        assert self.mask is False
        residual = x
        queries = self.norm1(x)
        keys = self.norm2(context)
        _, _, height, width = queries.shape
        queries = rearrange(queries, 'b c h w -> b (h w) c').contiguous()
        keys = rearrange(keys, 'b c h w -> b (h w) c').contiguous()
        attended = self.attn(queries, keys)
        attended = rearrange(attended, 'b (h w) c -> b c h w', h=height, w=width).contiguous()
        if self.mask:
            attended = attended * torch.zeros_like(attended)
        return residual + attended * control_scale
153
-
154
-
155
class GLVControl(nn.Module):
    """Control encoder: the input-block and middle-block half of a UNet.

    ``forward`` returns the list of intermediate feature maps (one per input
    block plus the middle block) instead of a single output tensor.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        spatial_transformer_attn_type="softmax",
        adm_in_channels=None,
        use_fairscale_checkpoint=False,
        offload_to_cpu=False,
        transformer_depth_middle=None,
        input_upscale=1,
    ):
        """Build time/label embeddings, input blocks, middle block and hint block.

        ``num_res_blocks`` and ``transformer_depth`` may be a single int or a
        per-level sequence matching ``channel_mult``. ``num_classes`` may be
        ``None``, an int, or one of ``"continuous"``, ``"timestep"``,
        ``"sequential"``. ``use_fp16`` is accepted but has no effect, and
        ``use_fairscale_checkpoint`` is validated but then forced off.

        :raises ValueError: for a mis-sized ``num_res_blocks`` sequence or an
            unknown ``num_classes`` value.
        """
        super().__init__()
        from omegaconf.listconfig import ListConfig

        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        elif isinstance(transformer_depth, ListConfig):
            transformer_depth = list(transformer_depth)
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        if use_fp16:
            print("WARNING: use_fp16 was dropped and has no effect anymore.")
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        assert use_fairscale_checkpoint != use_checkpoint or not (
            use_checkpoint or use_fairscale_checkpoint
        )

        # Fairscale checkpointing is hard-disabled here, so the wrapper below
        # is always the identity.
        # NOTE(review): ``checkpoint_wrapper`` is not imported in this file;
        # the partial(...) branch would raise NameError if it were ever enabled.
        self.use_fairscale_checkpoint = False
        checkpoint_wrapper_fn = (
            partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
            if self.use_fairscale_checkpoint
            else lambda x: x
        )

        time_embed_dim = model_channels * 4
        self.time_embed = checkpoint_wrapper_fn(
            nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = checkpoint_wrapper_fn(
                    nn.Sequential(
                        Timestep(model_channels),
                        nn.Sequential(
                            linear(model_channels, time_embed_dim),
                            nn.SiLU(),
                            linear(time_embed_dim, time_embed_dim),
                        ),
                    )
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor, doubled after every downsample stage
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    checkpoint_wrapper_fn(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            checkpoint_wrapper_fn(
                                AttentionBlock(
                                    ch,
                                    use_checkpoint=use_checkpoint,
                                    num_heads=num_heads,
                                    num_head_channels=dim_head,
                                    use_new_attention_order=use_new_attention_order,
                                )
                            )
                            if not use_spatial_transformer
                            else checkpoint_wrapper_fn(
                                SpatialTransformer(
                                    ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth[level],
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_transformer,
                                    attn_type=spatial_transformer_attn_type,
                                    use_checkpoint=use_checkpoint,
                                )
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Every level except the last ends with a downsampling stage.
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        checkpoint_wrapper_fn(
                            ResBlock(
                                ch,
                                time_embed_dim,
                                dropout,
                                out_channels=out_ch,
                                dims=dims,
                                use_checkpoint=use_checkpoint,
                                use_scale_shift_norm=use_scale_shift_norm,
                                down=True,
                            )
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
            checkpoint_wrapper_fn(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                )
            )
            if not use_spatial_transformer
            else checkpoint_wrapper_fn(
                SpatialTransformer(  # always uses a self-attn
                    ch,
                    num_heads,
                    dim_head,
                    depth=transformer_depth_middle,
                    context_dim=context_dim,
                    disable_self_attn=disable_middle_self_attn,
                    use_linear=use_linear_in_transformer,
                    attn_type=spatial_transformer_attn_type,
                    use_checkpoint=use_checkpoint,
                )
            ),
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
        )

        self.input_upscale = input_upscale
        self.input_hint_block = TimestepEmbedSequential(
            zero_module(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.

        The module-level ``convert_module_to_f16`` is a no-op in this file, so
        this call currently changes nothing.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.

        The module-level ``convert_module_to_f32`` is a no-op in this file, so
        this call currently changes nothing.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        """Return the feature maps of every input block plus the middle block.

        :param x: control input; fed through ``input_hint_block`` and added to
            the output of the first input block.
        :param timesteps: timesteps passed to ``timestep_embedding``.
        :param xt: tensor run through the input blocks; cast to ``x.dtype``.
        :param context: optional cross-attention conditioning; cast to
            ``x.dtype`` when given.
        :param y: optional label conditioning; must be given exactly when the
            model was built with ``num_classes``.
        :return: list of tensors, one per input block followed by the middle
            block output.
        """
        xt = xt.to(x.dtype)
        # ``context`` and ``y`` default to None; cast only when present so the
        # assertion below (not an AttributeError) reports a missing ``y``.
        if context is not None:
            context = context.to(x.dtype)
        if y is not None:
            y = y.to(x.dtype)

        if self.input_upscale != 1:
            x = nn.functional.interpolate(x, scale_factor=self.input_upscale, mode='bilinear', antialias=True)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == xt.shape[0]
            emb = emb + self.label_emb(y)

        guided_hint = self.input_hint_block(x, emb, context)

        h = xt
        for module in self.input_blocks:
            if guided_hint is not None:
                # The hint is injected once, right after the first input block.
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        hs.append(h)
        return hs
541
-
542
-
543
class LightGLVUNet(UNetModel):
    """UNetModel whose decoder path is modulated by control features through
    a list of ``project_modules`` (ZeroSFT / ZeroCrossAttn adapters).
    """

    def __init__(self, mode='', project_type='ZeroSFT', project_channel_scale=1,
                 *args, **kwargs):
        """Build the base UNet, then one adapter per control feature.

        :param mode: ``'XL-base'`` or ``'XL-refine'``; selects the hard-coded
            channel layout.
        :param project_type: ``'ZeroSFT'`` or ``'ZeroCrossAttn'``.
        :param project_channel_scale: multiplier applied to the control
            channel counts (truncated to int).
        :raises NotImplementedError: for an unknown ``mode`` or ``project_type``.
        """
        super().__init__(*args, **kwargs)
        if mode == 'XL-base':
            cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
            project_channels = [160] * 4 + [320] * 3 + [640] * 3
            concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
            cross_attn_insert_idx = [6, 3]
            self.progressive_mask_nums = [0, 3, 7, 11]
        elif mode == 'XL-refine':
            cond_output_channels = [384] * 4 + [768] * 3 + [1536] * 6
            project_channels = [192] * 4 + [384] * 3 + [768] * 6
            concat_channels = [384] * 2 + [768] * 3 + [1536] * 7 + [0]
            cross_attn_insert_idx = [9, 6, 3]
            self.progressive_mask_nums = [0, 3, 6, 10, 14]
        else:
            raise NotImplementedError

        project_channels = [int(width * project_channel_scale) for width in project_channels]

        self.project_modules = nn.ModuleList()
        # The three channel lists have equal length in both supported modes.
        for proj_nc, cond_nc, cat_nc in zip(project_channels, cond_output_channels, concat_channels):
            if project_type == 'ZeroSFT':
                adapter = ZeroSFT(proj_nc, cond_nc, concat_channels=cat_nc)
            elif project_type == 'ZeroCrossAttn':
                adapter = ZeroCrossAttn(cond_nc, proj_nc)
            else:
                raise NotImplementedError
            self.project_modules.append(adapter)

        # Extra cross-attention adapters are spliced in at fixed positions;
        # indices are processed in the listed (descending) order.
        for pos in cross_attn_insert_idx:
            self.project_modules.insert(pos, ZeroCrossAttn(cond_output_channels[pos], concat_channels[pos]))

    def step_progressive_mask(self):
        """Pop the next mask count and mask the first ``mask_num`` adapters.

        Does nothing once ``progressive_mask_nums`` is empty.
        """
        if not self.progressive_mask_nums:
            return
        mask_num = self.progressive_mask_nums.pop()
        for idx, adapter in enumerate(self.project_modules):
            adapter.mask = idx < mask_num

    def forward(self, x, timesteps=None, context=None, y=None, control=None, control_scale=1, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param control: sequence of control feature maps, consumed from last
            to first; its first element also fixes the working dtype.
        :param control_scale: strength forwarded to every adapter.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        skips = []

        work_dtype = control[0].dtype
        x, context, y = x.to(work_dtype), context.to(work_dtype), y.to(work_dtype)

        # NOTE(review): the extent of this no_grad scope is reconstructed from
        # a source without indentation (embeddings + input blocks) -- confirm.
        with torch.no_grad():
            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
            emb = self.time_embed(t_emb)

            if self.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.label_emb(y)

            h = x
            for block in self.input_blocks:
                h = block(h, emb, context)
                skips.append(h)

        # Adapters and control features are both walked from the end.
        adapter_idx = len(self.project_modules) - 1
        control_idx = len(control) - 1
        h = self.middle_block(h, emb, context)
        h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
        adapter_idx -= 1
        control_idx -= 1

        for block in self.output_blocks:
            skip = skips.pop()
            h = self.project_modules[adapter_idx](control[control_idx], skip, h, control_scale=control_scale)
            adapter_idx -= 1
            if len(block) == 3:
                # Three-layer blocks end in an Upsample; an extra adapter runs
                # between the first two layers and the upsampling.
                assert isinstance(block[2], Upsample)
                for layer in block[:2]:
                    if isinstance(layer, TimestepBlock):
                        h = layer(h, emb)
                    elif isinstance(layer, SpatialTransformer):
                        h = layer(h, context)
                    else:
                        h = layer(h)
                h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
                adapter_idx -= 1
                h = block[2](h)
            else:
                h = block(h, emb, context)
            control_idx -= 1

        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            assert False, "not supported anymore. what the f*** are you doing?"
        else:
            return self.out(h)
667
-
668
if __name__ == '__main__':
    from omegaconf import OmegaConf

    # Manual smoke test for the "base" configuration: run the control model,
    # print its feature shapes, then feed them to the UNet.
    # The original file also kept a commented-out "refiner" variant of the
    # same test using context width 1280 and label width 2560.
    # NOTE(review): everything below is assumed to sit inside the no_grad
    # scope; the source has no indentation -- confirm.
    with torch.no_grad():
        opt = OmegaConf.load('../../options/dev/SUPIR_tmp.yaml')

        model = instantiate_from_config(opt.model.params.control_stage_config)
        model = model.cuda()

        hint = model(
            torch.randn([1, 4, 64, 64]).cuda(),
            torch.randn([1]).cuda(),
            torch.randn([1, 4, 64, 64]).cuda(),
            torch.randn([1, 77, 2048]).cuda(),
            torch.randn([1, 2816]).cuda(),
        )

        for h in hint:
            print(h.shape)

        unet = instantiate_from_config(opt.model.params.network_config)
        unet = unet.cuda()
        _output = unet(
            torch.randn([1, 4, 64, 64]).cuda(),
            torch.randn([1]).cuda(),
            torch.randn([1, 77, 2048]).cuda(),
            torch.randn([1, 2816]).cuda(),
            hint,
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/modules/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- SDXL_BASE_CHANNEL_DICT = {
2
- 'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3,
3
- 'project_channels': [160] * 4 + [320] * 3 + [640] * 3,
4
- 'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0]
5
- }
6
-
7
- SDXL_REFINE_CHANNEL_DICT = {
8
- 'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6,
9
- 'project_channels': [192] * 4 + [384] * 3 + [768] * 6,
10
- 'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0]
11
- }
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/util.py DELETED
@@ -1,179 +0,0 @@
1
- import os
2
- import torch
3
- import numpy as np
4
- import cv2
5
- from PIL import Image
6
- from torch.nn.functional import interpolate
7
- from omegaconf import OmegaConf
8
- from sgm.util import instantiate_from_config
9
-
10
-
11
- def get_state_dict(d):
12
- return d.get('state_dict', d)
13
-
14
-
15
- def load_state_dict(ckpt_path, location='cpu'):
16
- _, extension = os.path.splitext(ckpt_path)
17
- if extension.lower() == ".safetensors":
18
- import safetensors.torch
19
- state_dict = safetensors.torch.load_file(ckpt_path, device=location)
20
- else:
21
- state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
22
- state_dict = get_state_dict(state_dict)
23
- print(f'Loaded state_dict from [{ckpt_path}]')
24
- return state_dict
25
-
26
-
27
- def create_model(config_path):
28
- config = OmegaConf.load(config_path)
29
- model = instantiate_from_config(config.model).cpu()
30
- print(f'Loaded model config from [{config_path}]')
31
- return model
32
-
33
-
34
- def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
35
- config = OmegaConf.load(config_path)
36
- model = instantiate_from_config(config.model).cpu()
37
- print(f'Loaded model config from [{config_path}]')
38
- if config.SDXL_CKPT is not None:
39
- model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
40
- if config.SUPIR_CKPT is not None:
41
- model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False)
42
- if SUPIR_sign is not None:
43
- assert SUPIR_sign in ['F', 'Q']
44
- if SUPIR_sign == 'F':
45
- model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
46
- elif SUPIR_sign == 'Q':
47
- model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)
48
- if load_default_setting:
49
- default_setting = config.default_setting
50
- return model, default_setting
51
- return model
52
-
53
- def load_QF_ckpt(config_path):
54
- config = OmegaConf.load(config_path)
55
- ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location='cpu')
56
- ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location='cpu')
57
- return ckpt_Q, ckpt_F
58
-
59
-
60
- def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
61
- '''
62
- PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
63
- '''
64
- # size
65
- w, h = img.size
66
- w *= upsacle
67
- h *= upsacle
68
- w0, h0 = round(w), round(h)
69
- if min(w, h) < min_size:
70
- _upsacle = min_size / min(w, h)
71
- w *= _upsacle
72
- h *= _upsacle
73
- if fix_resize is not None:
74
- _upsacle = fix_resize / min(w, h)
75
- w *= _upsacle
76
- h *= _upsacle
77
- w0, h0 = round(w), round(h)
78
- w = int(np.round(w / 64.0)) * 64
79
- h = int(np.round(h / 64.0)) * 64
80
- x = img.resize((w, h), Image.BICUBIC)
81
- x = np.array(x).round().clip(0, 255).astype(np.uint8)
82
- x = x / 255 * 2 - 1
83
- x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
84
- return x, h0, w0
85
-
86
-
87
- def Tensor2PIL(x, h0, w0):
88
- '''
89
- Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
90
- '''
91
- x = x.unsqueeze(0)
92
- x = interpolate(x, size=(h0, w0), mode='bicubic')
93
- x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
94
- return Image.fromarray(x)
95
-
96
-
97
- def HWC3(x):
98
- assert x.dtype == np.uint8
99
- if x.ndim == 2:
100
- x = x[:, :, None]
101
- assert x.ndim == 3
102
- H, W, C = x.shape
103
- assert C == 1 or C == 3 or C == 4
104
- if C == 3:
105
- return x
106
- if C == 1:
107
- return np.concatenate([x, x, x], axis=2)
108
- if C == 4:
109
- color = x[:, :, 0:3].astype(np.float32)
110
- alpha = x[:, :, 3:4].astype(np.float32) / 255.0
111
- y = color * alpha + 255.0 * (1.0 - alpha)
112
- y = y.clip(0, 255).astype(np.uint8)
113
- return y
114
-
115
-
116
- def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
117
- H, W, C = input_image.shape
118
- H = float(H)
119
- W = float(W)
120
- H *= upscale
121
- W *= upscale
122
- if min_size is not None:
123
- if min(H, W) < min_size:
124
- _upsacle = min_size / min(W, H)
125
- W *= _upsacle
126
- H *= _upsacle
127
- H = int(np.round(H / unit_resolution)) * unit_resolution
128
- W = int(np.round(W / unit_resolution)) * unit_resolution
129
- img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
130
- img = img.round().clip(0, 255).astype(np.uint8)
131
- return img
132
-
133
-
134
- def fix_resize(input_image, size=512, unit_resolution=64):
135
- H, W, C = input_image.shape
136
- H = float(H)
137
- W = float(W)
138
- upscale = size / min(H, W)
139
- H *= upscale
140
- W *= upscale
141
- H = int(np.round(H / unit_resolution)) * unit_resolution
142
- W = int(np.round(W / unit_resolution)) * unit_resolution
143
- img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
144
- img = img.round().clip(0, 255).astype(np.uint8)
145
- return img
146
-
147
-
148
-
149
- def Numpy2Tensor(img):
150
- '''
151
- np.array[H, w, C] [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
152
- '''
153
- # size
154
- img = np.array(img) / 255 * 2 - 1
155
- img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
156
- return img
157
-
158
-
159
- def Tensor2Numpy(x, h0=None, w0=None):
160
- '''
161
- Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
162
- '''
163
- if h0 is not None and w0 is not None:
164
- x = x.unsqueeze(0)
165
- x = interpolate(x, size=(h0, w0), mode='bicubic')
166
- x = x.squeeze(0)
167
- x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
168
- return x
169
-
170
-
171
- def convert_dtype(dtype_str):
172
- if dtype_str == 'fp32':
173
- return torch.float32
174
- elif dtype_str == 'fp16':
175
- return torch.float16
176
- elif dtype_str == 'bf16':
177
- return torch.bfloat16
178
- else:
179
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/utils/__init__.py DELETED
File without changes
SUPIR/utils/colorfix.py DELETED
@@ -1,120 +0,0 @@
1
- '''
2
- # --------------------------------------------------------------------------------
3
- # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4
- # --------------------------------------------------------------------------------
5
- '''
6
-
7
- import torch
8
- from PIL import Image
9
- from torch import Tensor
10
- from torch.nn import functional as F
11
-
12
- from torchvision.transforms import ToTensor, ToPILImage
13
-
14
- def adain_color_fix(target: Image, source: Image):
15
- # Convert images to tensors
16
- to_tensor = ToTensor()
17
- target_tensor = to_tensor(target).unsqueeze(0)
18
- source_tensor = to_tensor(source).unsqueeze(0)
19
-
20
- # Apply adaptive instance normalization
21
- result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
22
-
23
- # Convert tensor back to image
24
- to_image = ToPILImage()
25
- result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
26
-
27
- return result_image
28
-
29
- def wavelet_color_fix(target: Image, source: Image):
30
- # Convert images to tensors
31
- to_tensor = ToTensor()
32
- target_tensor = to_tensor(target).unsqueeze(0)
33
- source_tensor = to_tensor(source).unsqueeze(0)
34
-
35
- # Apply wavelet reconstruction
36
- result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
37
-
38
- # Convert tensor back to image
39
- to_image = ToPILImage()
40
- result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
41
-
42
- return result_image
43
-
44
- def calc_mean_std(feat: Tensor, eps=1e-5):
45
- """Calculate mean and std for adaptive_instance_normalization.
46
- Args:
47
- feat (Tensor): 4D tensor.
48
- eps (float): A small value added to the variance to avoid
49
- divide-by-zero. Default: 1e-5.
50
- """
51
- size = feat.size()
52
- assert len(size) == 4, 'The input feature should be 4D tensor.'
53
- b, c = size[:2]
54
- feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
55
- feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
56
- feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
57
- return feat_mean, feat_std
58
-
59
- def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
60
- """Adaptive instance normalization.
61
- Adjust the reference features to have the similar color and illuminations
62
- as those in the degradate features.
63
- Args:
64
- content_feat (Tensor): The reference feature.
65
- style_feat (Tensor): The degradate features.
66
- """
67
- size = content_feat.size()
68
- style_mean, style_std = calc_mean_std(style_feat)
69
- content_mean, content_std = calc_mean_std(content_feat)
70
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
71
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
72
-
73
- def wavelet_blur(image: Tensor, radius: int):
74
- """
75
- Apply wavelet blur to the input tensor.
76
- """
77
- # input shape: (1, 3, H, W)
78
- # convolution kernel
79
- kernel_vals = [
80
- [0.0625, 0.125, 0.0625],
81
- [0.125, 0.25, 0.125],
82
- [0.0625, 0.125, 0.0625],
83
- ]
84
- kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
85
- # add channel dimensions to the kernel to make it a 4D tensor
86
- kernel = kernel[None, None]
87
- # repeat the kernel across all input channels
88
- kernel = kernel.repeat(3, 1, 1, 1)
89
- image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
90
- # apply convolution
91
- output = F.conv2d(image, kernel, groups=3, dilation=radius)
92
- return output
93
-
94
- def wavelet_decomposition(image: Tensor, levels=5):
95
- """
96
- Apply wavelet decomposition to the input tensor.
97
- This function only returns the low frequency & the high frequency.
98
- """
99
- high_freq = torch.zeros_like(image)
100
- for i in range(levels):
101
- radius = 2 ** i
102
- low_freq = wavelet_blur(image, radius)
103
- high_freq += (image - low_freq)
104
- image = low_freq
105
-
106
- return high_freq, low_freq
107
-
108
- def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
109
- """
110
- Apply wavelet decomposition, so that the content will have the same color as the style.
111
- """
112
- # calculate the wavelet decomposition of the content feature
113
- content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
114
- del content_low_freq
115
- # calculate the wavelet decomposition of the style feature
116
- style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
117
- del style_high_freq
118
- # reconstruct the content feature with the style's high frequency
119
- return content_high_freq + style_low_freq
120
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/utils/devices.py DELETED
@@ -1,138 +0,0 @@
1
- import sys
2
- import contextlib
3
- from functools import lru_cache
4
-
5
- import torch
6
- #from modules import errors
7
-
8
- if sys.platform == "darwin":
9
- from modules import mac_specific
10
-
11
-
12
- def has_mps() -> bool:
13
- if sys.platform != "darwin":
14
- return False
15
- else:
16
- return mac_specific.has_mps
17
-
18
-
19
- def get_cuda_device_string():
20
- return "cuda"
21
-
22
-
23
- def get_optimal_device_name():
24
- if torch.cuda.is_available():
25
- return get_cuda_device_string()
26
-
27
- if has_mps():
28
- return "mps"
29
-
30
- return "cpu"
31
-
32
-
33
- def get_optimal_device():
34
- return torch.device(get_optimal_device_name())
35
-
36
-
37
- def get_device_for(task):
38
- return get_optimal_device()
39
-
40
-
41
- def torch_gc():
42
-
43
- if torch.cuda.is_available():
44
- with torch.cuda.device(get_cuda_device_string()):
45
- torch.cuda.empty_cache()
46
- torch.cuda.ipc_collect()
47
-
48
- if has_mps():
49
- mac_specific.torch_mps_gc()
50
-
51
-
52
- def enable_tf32():
53
- if torch.cuda.is_available():
54
-
55
- # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
56
- # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
57
- if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
58
- torch.backends.cudnn.benchmark = True
59
-
60
- torch.backends.cuda.matmul.allow_tf32 = True
61
- torch.backends.cudnn.allow_tf32 = True
62
-
63
-
64
- enable_tf32()
65
- #errors.run(enable_tf32, "Enabling TF32")
66
-
67
- cpu = torch.device("cpu")
68
- device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
69
- dtype = torch.float16
70
- dtype_vae = torch.float16
71
- dtype_unet = torch.float16
72
- unet_needs_upcast = False
73
-
74
-
75
- def cond_cast_unet(input):
76
- return input.to(dtype_unet) if unet_needs_upcast else input
77
-
78
-
79
- def cond_cast_float(input):
80
- return input.float() if unet_needs_upcast else input
81
-
82
-
83
- def randn(seed, shape):
84
- torch.manual_seed(seed)
85
- return torch.randn(shape, device=device)
86
-
87
-
88
- def randn_without_seed(shape):
89
- return torch.randn(shape, device=device)
90
-
91
-
92
- def autocast(disable=False):
93
- if disable:
94
- return contextlib.nullcontext()
95
-
96
- return torch.autocast("cuda")
97
-
98
-
99
- def without_autocast(disable=False):
100
- return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
101
-
102
-
103
- class NansException(Exception):
104
- pass
105
-
106
-
107
- def test_for_nans(x, where):
108
- if not torch.all(torch.isnan(x)).item():
109
- return
110
-
111
- if where == "unet":
112
- message = "A tensor with all NaNs was produced in Unet."
113
-
114
- elif where == "vae":
115
- message = "A tensor with all NaNs was produced in VAE."
116
-
117
- else:
118
- message = "A tensor with all NaNs was produced."
119
-
120
- message += " Use --disable-nan-check commandline argument to disable this check."
121
-
122
- raise NansException(message)
123
-
124
-
125
- @lru_cache
126
- def first_time_calculation():
127
- """
128
- just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
129
- spends about 2.7 seconds doing that, at least wih NVidia.
130
- """
131
-
132
- x = torch.zeros((1, 1)).to(device, dtype)
133
- linear = torch.nn.Linear(1, 1).to(device, dtype)
134
- linear(x)
135
-
136
- x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
137
- conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
138
- conv2d(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/utils/face_restoration_helper.py DELETED
@@ -1,514 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import os
4
- import torch
5
- from torchvision.transforms.functional import normalize
6
-
7
- from facexlib.detection import init_detection_model
8
- from facexlib.parsing import init_parsing_model
9
- from facexlib.utils.misc import img2tensor, imwrite
10
-
11
- from .file import load_file_from_url
12
-
13
-
14
- def get_largest_face(det_faces, h, w):
15
- def get_location(val, length):
16
- if val < 0:
17
- return 0
18
- elif val > length:
19
- return length
20
- else:
21
- return val
22
-
23
- face_areas = []
24
- for det_face in det_faces:
25
- left = get_location(det_face[0], w)
26
- right = get_location(det_face[2], w)
27
- top = get_location(det_face[1], h)
28
- bottom = get_location(det_face[3], h)
29
- face_area = (right - left) * (bottom - top)
30
- face_areas.append(face_area)
31
- largest_idx = face_areas.index(max(face_areas))
32
- return det_faces[largest_idx], largest_idx
33
-
34
-
35
- def get_center_face(det_faces, h=0, w=0, center=None):
36
- if center is not None:
37
- center = np.array(center)
38
- else:
39
- center = np.array([w / 2, h / 2])
40
- center_dist = []
41
- for det_face in det_faces:
42
- face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
43
- dist = np.linalg.norm(face_center - center)
44
- center_dist.append(dist)
45
- center_idx = center_dist.index(min(center_dist))
46
- return det_faces[center_idx], center_idx
47
-
48
-
49
- class FaceRestoreHelper(object):
50
- """Helper for the face restoration pipeline (base class)."""
51
-
52
- def __init__(self,
53
- upscale_factor,
54
- face_size=512,
55
- crop_ratio=(1, 1),
56
- det_model='retinaface_resnet50',
57
- save_ext='png',
58
- template_3points=False,
59
- pad_blur=False,
60
- use_parse=False,
61
- device=None):
62
- self.template_3points = template_3points # improve robustness
63
- self.upscale_factor = int(upscale_factor)
64
- # the cropped face ratio based on the square face
65
- self.crop_ratio = crop_ratio # (h, w)
66
- assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
67
- self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
68
- self.det_model = det_model
69
-
70
- if self.det_model == 'dlib':
71
- # standard 5 landmarks for FFHQ faces with 1024 x 1024
72
- self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
73
- [337.91089109, 488.38613861], [437.95049505, 493.51485149],
74
- [513.58415842, 678.5049505]])
75
- self.face_template = self.face_template / (1024 // face_size)
76
- elif self.template_3points:
77
- self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
78
- else:
79
- # standard 5 landmarks for FFHQ faces with 512 x 512
80
- # facexlib
81
- self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
82
- [201.26117, 371.41043], [313.08905, 371.15118]])
83
-
84
- # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
85
- # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
86
- # [198.22603, 372.82502], [313.91018, 372.75659]])
87
-
88
- self.face_template = self.face_template * (face_size / 512.0)
89
- if self.crop_ratio[0] > 1:
90
- self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
91
- if self.crop_ratio[1] > 1:
92
- self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
93
- self.save_ext = save_ext
94
- self.pad_blur = pad_blur
95
- if self.pad_blur is True:
96
- self.template_3points = False
97
-
98
- self.all_landmarks_5 = []
99
- self.det_faces = []
100
- self.affine_matrices = []
101
- self.inverse_affine_matrices = []
102
- self.cropped_faces = []
103
- self.restored_faces = []
104
- self.pad_input_imgs = []
105
-
106
- if device is None:
107
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
108
- # self.device = get_device()
109
- else:
110
- self.device = device
111
-
112
- # init face detection model
113
- self.face_detector = init_detection_model(det_model, half=False, device=self.device)
114
-
115
- # init face parsing model
116
- self.use_parse = use_parse
117
- self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
118
-
119
- def set_upscale_factor(self, upscale_factor):
120
- self.upscale_factor = upscale_factor
121
-
122
- def read_image(self, img):
123
- """img can be image path or cv2 loaded image."""
124
- # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
125
- if isinstance(img, str):
126
- img = cv2.imread(img)
127
-
128
- if np.max(img) > 256: # 16-bit image
129
- img = img / 65535 * 255
130
- if len(img.shape) == 2: # gray image
131
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
132
- elif img.shape[2] == 4: # BGRA image with alpha channel
133
- img = img[:, :, 0:3]
134
-
135
- self.input_img = img
136
- # self.is_gray = is_gray(img, threshold=10)
137
- # if self.is_gray:
138
- # print('Grayscale input: True')
139
-
140
- if min(self.input_img.shape[:2]) < 512:
141
- f = 512.0 / min(self.input_img.shape[:2])
142
- self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
143
-
144
- def init_dlib(self, detection_path, landmark5_path):
145
- """Initialize the dlib detectors and predictors."""
146
- try:
147
- import dlib
148
- except ImportError:
149
- print('Please install dlib by running:' 'conda install -c conda-forge dlib')
150
- detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
151
- landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
152
- face_detector = dlib.cnn_face_detection_model_v1(detection_path)
153
- shape_predictor_5 = dlib.shape_predictor(landmark5_path)
154
- return face_detector, shape_predictor_5
155
-
156
- def get_face_landmarks_5_dlib(self,
157
- only_keep_largest=False,
158
- scale=1):
159
- det_faces = self.face_detector(self.input_img, scale)
160
-
161
- if len(det_faces) == 0:
162
- print('No face detected. Try to increase upsample_num_times.')
163
- return 0
164
- else:
165
- if only_keep_largest:
166
- print('Detect several faces and only keep the largest.')
167
- face_areas = []
168
- for i in range(len(det_faces)):
169
- face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
170
- det_faces[i].rect.bottom() - det_faces[i].rect.top())
171
- face_areas.append(face_area)
172
- largest_idx = face_areas.index(max(face_areas))
173
- self.det_faces = [det_faces[largest_idx]]
174
- else:
175
- self.det_faces = det_faces
176
-
177
- if len(self.det_faces) == 0:
178
- return 0
179
-
180
- for face in self.det_faces:
181
- shape = self.shape_predictor_5(self.input_img, face.rect)
182
- landmark = np.array([[part.x, part.y] for part in shape.parts()])
183
- self.all_landmarks_5.append(landmark)
184
-
185
- return len(self.all_landmarks_5)
186
-
187
- def get_face_landmarks_5(self,
188
- only_keep_largest=False,
189
- only_center_face=False,
190
- resize=None,
191
- blur_ratio=0.01,
192
- eye_dist_threshold=None):
193
- if self.det_model == 'dlib':
194
- return self.get_face_landmarks_5_dlib(only_keep_largest)
195
-
196
- if resize is None:
197
- scale = 1
198
- input_img = self.input_img
199
- else:
200
- h, w = self.input_img.shape[0:2]
201
- scale = resize / min(h, w)
202
- scale = max(1, scale) # always scale up
203
- h, w = int(h * scale), int(w * scale)
204
- interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
205
- input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
206
-
207
- with torch.no_grad():
208
- bboxes = self.face_detector.detect_faces(input_img)
209
-
210
- if bboxes is None or bboxes.shape[0] == 0:
211
- return 0
212
- else:
213
- bboxes = bboxes / scale
214
-
215
- for bbox in bboxes:
216
- # remove faces with too small eye distance: side faces or too small faces
217
- eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
218
- if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
219
- continue
220
-
221
- if self.template_3points:
222
- landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
223
- else:
224
- landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
225
- self.all_landmarks_5.append(landmark)
226
- self.det_faces.append(bbox[0:5])
227
-
228
- if len(self.det_faces) == 0:
229
- return 0
230
- if only_keep_largest:
231
- h, w, _ = self.input_img.shape
232
- self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
233
- self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
234
- elif only_center_face:
235
- h, w, _ = self.input_img.shape
236
- self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
237
- self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
238
-
239
- # pad blurry images
240
- if self.pad_blur:
241
- self.pad_input_imgs = []
242
- for landmarks in self.all_landmarks_5:
243
- # get landmarks
244
- eye_left = landmarks[0, :]
245
- eye_right = landmarks[1, :]
246
- eye_avg = (eye_left + eye_right) * 0.5
247
- mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
248
- eye_to_eye = eye_right - eye_left
249
- eye_to_mouth = mouth_avg - eye_avg
250
-
251
- # Get the oriented crop rectangle
252
- # x: half width of the oriented crop rectangle
253
- x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
254
- # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
255
- # norm with the hypotenuse: get the direction
256
- x /= np.hypot(*x) # get the hypotenuse of a right triangle
257
- rect_scale = 1.5
258
- x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
259
- # y: half height of the oriented crop rectangle
260
- y = np.flipud(x) * [-1, 1]
261
-
262
- # c: center
263
- c = eye_avg + eye_to_mouth * 0.1
264
- # quad: (left_top, left_bottom, right_bottom, right_top)
265
- quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
266
- # qsize: side length of the square
267
- qsize = np.hypot(*x) * 2
268
- border = max(int(np.rint(qsize * 0.1)), 3)
269
-
270
- # get pad
271
- # pad: (width_left, height_top, width_right, height_bottom)
272
- pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
273
- int(np.ceil(max(quad[:, 1]))))
274
- pad = [
275
- max(-pad[0] + border, 1),
276
- max(-pad[1] + border, 1),
277
- max(pad[2] - self.input_img.shape[0] + border, 1),
278
- max(pad[3] - self.input_img.shape[1] + border, 1)
279
- ]
280
-
281
- if max(pad) > 1:
282
- # pad image
283
- pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
284
- # modify landmark coords
285
- landmarks[:, 0] += pad[0]
286
- landmarks[:, 1] += pad[1]
287
- # blur pad images
288
- h, w, _ = pad_img.shape
289
- y, x, _ = np.ogrid[:h, :w, :1]
290
- mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
291
- np.float32(w - 1 - x) / pad[2]),
292
- 1.0 - np.minimum(np.float32(y) / pad[1],
293
- np.float32(h - 1 - y) / pad[3]))
294
- blur = int(qsize * blur_ratio)
295
- if blur % 2 == 0:
296
- blur += 1
297
- blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
298
- # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
299
-
300
- pad_img = pad_img.astype('float32')
301
- pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
302
- pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
303
- pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
304
- self.pad_input_imgs.append(pad_img)
305
- else:
306
- self.pad_input_imgs.append(np.copy(self.input_img))
307
-
308
- return len(self.all_landmarks_5)
309
-
310
- def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
311
- """Align and warp faces with face template.
312
- """
313
- if self.pad_blur:
314
- assert len(self.pad_input_imgs) == len(
315
- self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
316
- for idx, landmark in enumerate(self.all_landmarks_5):
317
- # use 5 landmarks to get affine matrix
318
- # use cv2.LMEDS method for the equivalence to skimage transform
319
- # ref: https://blog.csdn.net/yichxi/article/details/115827338
320
- affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
321
- self.affine_matrices.append(affine_matrix)
322
- # warp and crop faces
323
- if border_mode == 'constant':
324
- border_mode = cv2.BORDER_CONSTANT
325
- elif border_mode == 'reflect101':
326
- border_mode = cv2.BORDER_REFLECT101
327
- elif border_mode == 'reflect':
328
- border_mode = cv2.BORDER_REFLECT
329
- if self.pad_blur:
330
- input_img = self.pad_input_imgs[idx]
331
- else:
332
- input_img = self.input_img
333
- cropped_face = cv2.warpAffine(
334
- input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
335
- self.cropped_faces.append(cropped_face)
336
- # save the cropped face
337
- if save_cropped_path is not None:
338
- path = os.path.splitext(save_cropped_path)[0]
339
- save_path = f'{path}_{idx:02d}.{self.save_ext}'
340
- imwrite(cropped_face, save_path)
341
-
342
def get_inverse_affine(self, save_inverse_affine_path=None):
    """Invert every stored affine matrix and scale it by the upscale factor.

    The results are appended to ``self.inverse_affine_matrices``.

    Args:
        save_inverse_affine_path: Optional path; when given, each inverse
            matrix is saved with ``torch.save`` to
            ``<path without extension>_<idx:02d>.pth``.
    """
    for face_idx, matrix in enumerate(self.affine_matrices):
        inverted = cv2.invertAffineTransform(matrix)
        inverted *= self.upscale_factor
        self.inverse_affine_matrices.append(inverted)

        if save_inverse_affine_path is None:
            continue
        stem, _ = os.path.splitext(save_inverse_affine_path)
        torch.save(inverted, f'{stem}_{face_idx:02d}.pth')
353
-
354
def add_restored_face(self, restored_face, input_face=None):
    """Append one restored face to ``self.restored_faces``.

    ``input_face`` is accepted but currently unused (the grayscale / colour
    transfer post-processing that consumed it is disabled).
    """
    self.restored_faces.append(restored_face)
360
-
361
def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
    """Warp the restored faces back and blend them into the upscaled image.

    Args:
        save_path: Optional output path; the extension is replaced by
            ``self.save_ext`` and the result is written with ``imwrite``.
        upsample_img: Optional background image. It is resized to the target
            size; when ``None`` the input image is resized bilinearly instead.
        draw_box: When true, a green frame is blended in around each face.
        face_upsampler: Optional object with an
            ``enhance(img, outscale=...)`` method applied to each restored
            face before pasting.

    Returns:
        The blended image.

    Raises:
        AssertionError: If the number of restored faces differs from the
            number of inverse affine matrices.
    """
    h, w, _ = self.input_img.shape
    h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

    if upsample_img is None:
        # simply resize the background
        upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
    else:
        upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

    assert len(self.restored_faces) == len(
        self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')

    inv_mask_borders = []
    for restored_face, stored_inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
        # Work on a private copy: the adjustments below are in-place, and
        # applying them to the stored matrix would corrupt it for any later
        # call of this method (offsets/scales would accumulate).
        inverse_affine = np.array(stored_inverse_affine, copy=True)

        if face_upsampler is not None:
            restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
            # The face itself is now upscaled, so undo the scale on the linear
            # part while keeping the translation in upscaled coordinates.
            inverse_affine /= self.upscale_factor
            inverse_affine[:, 2] *= self.upscale_factor
            face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
        else:
            # Add an offset to inverse affine matrix, for more precise back alignment
            if self.upscale_factor > 1:
                extra_offset = 0.5 * self.upscale_factor
            else:
                extra_offset = 0
            inverse_affine[:, 2] += extra_offset
            face_size = self.face_size
        inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

        # always use square mask
        mask = np.ones(face_size, dtype=np.float32)
        inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
        # remove the black borders
        inv_mask_erosion = cv2.erode(
            inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
        pasted_face = inv_mask_erosion[:, :, None] * inv_restored
        total_face_area = np.sum(inv_mask_erosion)  # // 3
        # add border
        if draw_box:
            # Separate names so the input-image h/w above are not shadowed.
            box_h, box_w = face_size
            mask_border = np.ones((box_h, box_w, 3), dtype=np.float32)
            border = int(1400 / np.sqrt(total_face_area))
            mask_border[border:box_h - border, border:box_w - border, :] = 0
            inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
            inv_mask_borders.append(inv_mask_border)
        # compute the fusion edge based on the area of face
        w_edge = int(total_face_area ** 0.5) // 20
        erosion_radius = w_edge * 2
        inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
        blur_size = w_edge * 2
        inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
        if len(upsample_img.shape) == 2:  # upsample_img is gray image
            upsample_img = upsample_img[:, :, None]
        inv_soft_mask = inv_soft_mask[:, :, None]

        # parse mask
        if self.use_parse:
            # inference
            face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
            normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            face_input = torch.unsqueeze(face_input, 0).to(self.device)
            with torch.no_grad():
                out = self.face_parse(face_input)[0]
            out = out.argmax(dim=1).squeeze().cpu().numpy()

            parse_mask = np.zeros(out.shape)
            # Per-class mask value indexed by the parser's class id.
            MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
            for idx, color in enumerate(MASK_COLORMAP):
                parse_mask[out == idx] = color
            # blur the mask
            parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
            parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
            # remove the black borders
            thres = 10
            parse_mask[:thres, :] = 0
            parse_mask[-thres:, :] = 0
            parse_mask[:, :thres] = 0
            parse_mask[:, -thres:] = 0
            parse_mask = parse_mask / 255.

            parse_mask = cv2.resize(parse_mask, face_size)
            parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
            inv_soft_parse_mask = parse_mask[:, :, None]
            # Element-wise minimum of the parse mask and the square soft mask.
            fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
            inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

        if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
            alpha = upsample_img[:, :, 3:]
            upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
            upsample_img = np.concatenate((upsample_img, alpha), axis=2)
        else:
            upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

    if np.max(upsample_img) > 256:  # 16-bit image
        upsample_img = upsample_img.astype(np.uint16)
    else:
        upsample_img = upsample_img.astype(np.uint8)

    # draw bounding box
    if draw_box:
        img_color = np.ones([*upsample_img.shape], dtype=np.float32)
        img_color[:, :, 0] = 0
        img_color[:, :, 1] = 255
        img_color[:, :, 2] = 0
        for inv_mask_border in inv_mask_borders:
            upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img

    if save_path is not None:
        path = os.path.splitext(save_path)[0]
        save_path = f'{path}.{self.save_ext}'
        imwrite(upsample_img, save_path)
    return upsample_img
506
-
507
def clean_all(self):
    """Reset every per-image result list to a fresh empty list."""
    per_image_lists = (
        'all_landmarks_5',
        'restored_faces',
        'affine_matrices',
        'cropped_faces',
        'inverse_affine_matrices',
        'det_faces',
        'pad_input_imgs',
    )
    for attr_name in per_image_lists:
        # A new list per attribute, so the attributes never alias each other.
        setattr(self, attr_name, [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/utils/file.py DELETED
@@ -1,79 +0,0 @@
1
- import os
2
- from typing import List, Tuple
3
-
4
- from urllib.parse import urlparse
5
- from torch.hub import download_url_to_file, get_dir
6
-
7
-
8
def load_file_list(file_list_path: str) -> List[str]:
    """Read a list file and return its non-blank lines, whitespace-stripped.

    Each line of the file is expected to hold the path of one image.
    """
    with open(file_list_path, "r") as fin:
        stripped_lines = (raw_line.strip() for raw_line in fin)
        return [entry for entry in stripped_lines if entry]
17
-
18
-
19
def list_image_files(
    img_dir: str,
    exts: Tuple[str]=(".jpg", ".png", ".jpeg"),
    follow_links: bool=False,
    log_progress: bool=False,
    log_every_n_files: int=10000,
    max_size: int=-1
) -> List[str]:
    """Recursively collect image files below ``img_dir``.

    Args:
        img_dir: Root directory to walk.
        exts: Accepted extensions, compared against the lower-cased file
            extension (including the leading dot).
        follow_links: Passed to ``os.walk`` as ``followlinks``.
        log_progress: Print a progress line every ``log_every_n_files`` hits.
        log_every_n_files: Progress interval used when ``log_progress`` is set.
        max_size: Maximum number of files to return; negative means no limit.

    Returns:
        Paths of the matching files, in ``os.walk`` order.
    """
    found = []
    limited = max_size >= 0
    for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links):
        for file_name in file_names:
            extension = os.path.splitext(file_name)[1].lower()
            if extension not in exts:
                continue
            # The cap is checked when the next match shows up, then the whole
            # walk is abandoned.
            if limited and len(found) >= max_size:
                return found
            found.append(os.path.join(dir_path, file_name))
            if log_progress and len(found) % log_every_n_files == 0:
                print(f"find {len(found)} images in {img_dir}")
    return found
41
-
42
-
43
def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
    """Split a path into ``(parent directory, stem, extension)``.

    The extension keeps its leading dot, as returned by ``os.path.splitext``.
    """
    directory = os.path.dirname(file_path)
    stem, ext = os.path.splitext(os.path.basename(file_path))
    return directory, stem, ext
47
-
48
-
49
- # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
50
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load a file from an http url, downloading it first if it is not cached.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): Directory in which the file is stored. Should be a
            full path. If None, ``<pytorch hub dir>/checkpoints`` is used.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): Name for the local file. If None, the last component
            of the url path is used. Default: None.

    Returns:
        str: The absolute path to the (possibly just downloaded) file.
    """
    if model_dir is None:
        # fall back to the pytorch hub directory
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    if file_name is not None:
        local_name = file_name
    else:
        local_name = os.path.basename(urlparse(url).path)

    cached_file = os.path.abspath(os.path.join(model_dir, local_name))
    if os.path.exists(cached_file):
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
SUPIR/utils/tilevae.py DELETED
@@ -1,971 +0,0 @@
1
- # ------------------------------------------------------------------------
2
- #
3
- # Ultimate VAE Tile Optimization
4
- #
5
- # Introducing a revolutionary new optimization designed to make
6
- # the VAE work with giant images on limited VRAM!
7
- # Say goodbye to the frustration of OOM and hello to seamless output!
8
- #
9
- # ------------------------------------------------------------------------
10
- #
11
- # This script is a wild hack that splits the image into tiles,
12
- # encodes each tile separately, and merges the result back together.
13
- #
14
- # Advantages:
15
- # - The VAE can now work with giant images on limited VRAM
16
- # (~10 GB for 8K images!)
17
- # - The merged output is completely seamless without any post-processing.
18
- #
19
- # Drawbacks:
20
- # - Giant RAM needed. To store the intermediate results for a 4096x4096
21
- # images, you need 32 GB RAM (it consumes ~20GB); for 8192x8192
22
- # you need 128 GB RAM machine (it consumes ~100 GB)
23
- # - NaNs always appear for 8k images when you use fp16 (half) VAE
24
- # You must use --no-half-vae to disable half VAE for that giant image.
25
- # - Slow speed. With default tile size, it takes around 50/200 seconds
26
- # to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
27
- # a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
28
- # - The gradient calculation is not compatible with this hack. It
29
- # will break any backward() or torch.autograd.grad() that passes VAE.
30
- # (But you can still use the VAE to generate training data.)
31
- #
32
- # How it works:
33
- # 1) The image is split into tiles.
34
- # - To ensure perfect results, each tile is padded with 32 pixels
35
- # on each side.
36
- # - Then the conv2d/silu/upsample/downsample can produce identical
37
- # results to the original image without splitting.
38
- # 2) The original forward is decomposed into a task queue and a task worker.
39
- # - The task queue is a list of functions that will be executed in order.
40
- # - The task worker is a loop that executes the tasks in the queue.
41
- # 3) The task queue is executed for each tile.
42
- # - Current tile is sent to GPU.
43
- # - local operations are directly executed.
44
- # - Group norm calculation is temporarily suspended until the mean
45
- # and var of all tiles are calculated.
46
- # - The residual is pre-calculated and stored and added back later.
47
- # - When moving on to the next tile, the current tile is sent to cpu.
48
- # 4) After all tiles are processed, tiles are merged on cpu and return.
49
- #
50
- # Enjoy!
51
- #
52
- # @author: LI YI @ Nanyang Technological University - Singapore
53
- # @date: 2023-03-02
54
- # @license: MIT License
55
- #
56
- # Please give me a star if you like this project!
57
- #
58
- # -------------------------------------------------------------------------
59
-
60
- import gc
61
- from time import time
62
- import math
63
- from tqdm import tqdm
64
-
65
- import torch
66
- import torch.version
67
- import torch.nn.functional as F
68
- from einops import rearrange
69
- from diffusers.utils.import_utils import is_xformers_available
70
-
71
- import SUPIR.utils.devices as devices
72
-
73
- try:
74
- import xformers
75
- import xformers.ops
76
- except ImportError:
77
- pass
78
-
79
- sd_flag = True
80
-
81
def get_recommend_encoder_tile_size():
    """Pick an encoder tile size from the total memory of the CUDA device.

    Returns 512 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 512

    # Total device memory in MiB.
    total_memory = torch.cuda.get_device_properties(
        devices.device).total_memory // 2**20
    # (lower bound in units of 1000 MiB, tile size), largest first.
    tiers = ((16, 3072), (12, 2048), (8, 1536))
    for lower_bound, tile_size in tiers:
        if total_memory > lower_bound * 1000:
            return tile_size
    return 960
96
-
97
-
98
def get_recommend_decoder_tile_size():
    """Pick a decoder tile size from the total memory of the CUDA device.

    Returns 64 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 64

    # Total device memory in MiB.
    total_memory = torch.cuda.get_device_properties(
        devices.device).total_memory // 2**20
    # (lower bound in units of 1000 MiB, tile size), largest first.
    tiers = ((30, 256), (16, 192), (12, 128), (8, 96))
    for lower_bound, tile_size in tiers:
        if total_memory > lower_bound * 1000:
            return tile_size
    return 64
115
-
116
-
117
# ---- global constants ----
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
# Tile sizes are chosen once, at import time, from the available device memory.
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
125
-
126
-
127
- # inplace version of silu
128
def inplace_nonlinearity(x):
    """Apply SiLU to ``x`` in place and return the result."""
    # Test: fix for Nans
    activated = F.silu(x, inplace=True)
    return activated
131
-
132
- # extracted from ldm.modules.diffusionmodules.model
133
-
134
- # from diffusers lib
135
def attn_forward_new(self, h_):
    """Self-attention over a feature map using the attention module ``self``.

    ``h_`` is viewed as ``(batch, channel, height * width)``, attended over the
    flattened spatial axis via the module's ``to_q``/``to_k``/``to_v``,
    ``get_attention_scores`` and ``to_out`` members, and reshaped back to the
    input's 4-D shape. No mask and no separate encoder states are used.
    """
    batch_size, channel, height, width = h_.shape
    tokens = h_.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = tokens.shape
    attention_mask = self.prepare_attention_mask(None, sequence_length, batch_size)

    # Pure self-attention: queries, keys and values all come from `tokens`.
    query = self.head_to_batch_dim(self.to_q(tokens))
    key = self.to_k(tokens)
    value = self.to_v(tokens)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    attention_probs = self.get_attention_scores(query, key, attention_mask)
    attended = self.batch_to_head_dim(torch.bmm(attention_probs, value))

    # to_out[0]: linear projection, to_out[1]: dropout
    attended = self.to_out[1](self.to_out[0](attended))

    return attended.transpose(-1, -2).reshape(batch_size, channel, height, width)
170
-
171
def attn_forward_new_pt2_0(self, hidden_states,):
    """Self-attention via ``F.scaled_dot_product_attention``.

    A 4-D input is flattened over its two trailing axes before attention and
    restored to its original shape afterwards; a 3-D input is returned 3-D.
    No attention mask and no separate encoder states are used. The module's
    projections are called with ``scale=1``.
    """
    scale = 1
    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, _, _ = hidden_states.shape

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    # Pure self-attention: all three projections read the same tensor.
    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    head_dim = key.shape[-1] // self.heads

    def split_heads(tensor):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return tensor.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    query, key, value = split_heads(query), split_heads(key), split_heads(value)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    attended = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    attended = attended.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
    attended = attended.to(query.dtype)

    # to_out[0]: linear projection, to_out[1]: dropout
    attended = self.to_out[0](attended, scale=scale)
    attended = self.to_out[1](attended)

    if input_ndim == 4:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return attended
231
-
232
def attn_forward_new_xformers(self, hidden_states):
    """Self-attention via ``xformers.ops.memory_efficient_attention``.

    A 4-D input is flattened over its two trailing axes before attention and
    restored to its original shape afterwards. No separate encoder states are
    used; the mask handed to ``prepare_attention_mask`` is ``None``. The
    module's projections are called with ``scale=1``.
    """
    scale = 1
    attention_op = None
    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, key_tokens, _ = hidden_states.shape

    attention_mask = self.prepare_attention_mask(None, key_tokens, batch_size)
    if attention_mask is not None:
        # xformers does not broadcast the singleton query dimension, so expand
        # [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens]
        # to let the mask act as a bias on the attention scores.
        query_tokens = hidden_states.shape[1]
        attention_mask = attention_mask.expand(-1, query_tokens, -1)

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    # Pure self-attention: all three projections read the same tensor.
    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    query = self.head_to_batch_dim(query).contiguous()
    key = self.head_to_batch_dim(key).contiguous()
    value = self.head_to_batch_dim(value).contiguous()

    attended = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=attention_op
    )
    attended = self.batch_to_head_dim(attended.to(query.dtype))

    # to_out[0]: linear projection, to_out[1]: dropout
    attended = self.to_out[0](attended, scale=scale)
    attended = self.to_out[1](attended)

    if input_ndim == 4:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return attended
291
-
292
def attn_forward(self, h_):
    """Plain single-head spatial self-attention using ``self.q/k/v/proj_out``.

    Every spatial position attends to every other position; scores are scaled
    by ``channels ** -0.5`` and soft-maxed over the key axis. The result has
    the shape of the ``q`` projection and is passed through ``proj_out``.
    """
    query = self.q(h_)
    key = self.k(h_)
    value = self.v(h_)

    batch, channels, height, width = query.shape
    positions = height * width

    # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
    query = query.reshape(batch, channels, positions).permute(0, 2, 1)  # b,hw,c
    key = key.reshape(batch, channels, positions)  # b,c,hw
    scores = torch.bmm(query, key) * (int(channels)**(-0.5))
    weights = torch.nn.functional.softmax(scores, dim=2)

    # out[b, c, j] = sum_i value[b, c, i] * weights[b, j, i]
    value = value.reshape(batch, channels, positions)
    attended = torch.bmm(value, weights.permute(0, 2, 1))
    attended = attended.reshape(batch, channels, height, width)

    return self.proj_out(attended)
316
-
317
-
318
def xformer_attn_forward(self, h_):
    """Single-head spatial self-attention computed with xformers.

    Uses the module's ``q``/``k``/``v`` projections, its ``attention_op`` and
    ``proj_out``; the result has the 4-D shape of the ``q`` projection.
    """
    query = self.q(h_)
    key = self.k(h_)
    value = self.v(h_)

    B, C, H, W = query.shape

    def as_sequence(feature_map):
        # (B, C, H, W) -> contiguous (B, H*W, C) with a single (dummy) head.
        seq = rearrange(feature_map, 'b c h w -> b (h w) c')
        return (
            seq.unsqueeze(3)
            .reshape(B, seq.shape[1], 1, C)
            .permute(0, 2, 1, 3)
            .reshape(B * 1, seq.shape[1], C)
            .contiguous()
        )

    query, key, value = as_sequence(query), as_sequence(key), as_sequence(value)

    attended = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=None, op=self.attention_op)

    # Drop the dummy head axis again and go back to (B, C, H, W).
    attended = (
        attended.unsqueeze(0)
        .reshape(B, 1, attended.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, attended.shape[1], C)
    )
    attended = rearrange(attended, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    return self.proj_out(attended)
347
-
348
-
349
def attn2task(task_queue, net):
    """
    Turn an attention block into a sequence of tasks and append to the task queue

    The appended tasks are: store the residual, pre-norm with ``net.norm``,
    run attention, add the residual back.

    @param task_queue: the target task queue
    @param net: the attention block
    """
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))

    # `is_xformers_available` is a function and must be *called*: the bare
    # name is always truthy, which forced the xformers path even when the
    # xformers import at the top of this file had failed.
    if is_xformers_available():
        attn_fn = xformer_attn_forward
    elif hasattr(F, "scaled_dot_product_attention"):
        attn_fn = attn_forward_new_pt2_0
    else:
        attn_fn = attn_forward_new

    # Bind net/attn_fn as defaults so the task does not depend on later
    # rebinding of these names.
    task_queue.append(('attn', lambda x, net=net, attn_fn=attn_fn: attn_fn(net, x)))
    task_queue.append(['add_res', None])
373
-
374
def resblock2task(queue, block):
    """
    Append the tasks that make up one ResNetBlock to the task queue

    @param queue: the target task queue
    @param block: ResNetBlock
    """
    if block.in_channels == block.out_channels:
        # Same width: the residual is the input itself.
        shortcut = lambda x: x
    else:
        # Which flag selects the conv shortcut depends on the block flavour.
        use_conv = block.use_conv_shortcut if sd_flag else block.use_in_shortcut
        shortcut = block.conv_shortcut if use_conv else block.nin_shortcut

    queue.append(('store_res', shortcut))
    queue.extend([
        ('pre_norm', block.norm1),
        ('silu', inplace_nonlinearity),
        ('conv1', block.conv1),
        ('pre_norm', block.norm2),
        ('silu', inplace_nonlinearity),
        ('conv2', block.conv2),
    ])
    queue.append(['add_res', None])
403
-
404
-
405
def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue

    For a decoder the mid blocks come first and the up-sampling levels follow;
    for an encoder the down-sampling levels come first and the mid blocks
    follow. The module-level ``sd_flag`` selects which attribute layout of
    ``net`` is used.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
            resolution_iter = (range(len(net.up_blocks)))  # net.num_resolutions = 3
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        if sd_flag:
            resolution_iter = range(net.num_resolutions)
            block_ids = net.num_res_blocks
            condition = net.num_resolutions - 1
            module = net.down
            func_name = 'downsample'
        else:
            resolution_iter = range(len(net.down_blocks))
            block_ids = 2
            condition = len(net.down_blocks) - 1
            module = net.down_blocks
            func_name = 'downsamplers'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            if sd_flag:
                resblock2task(task_queue, module[i_level].block[i_block])
            else:
                resblock2task(task_queue, module[i_level].resnets[i_block])
        # The level equal to `condition` gets no resampling task.
        if i_level != condition:
            if sd_flag:
                task_queue.append((func_name, getattr(module[i_level], func_name)))
            else:
                if is_decoder:
                    task_queue.append((func_name, module[i_level].upsamplers[0]))
                else:
                    task_queue.append((func_name, module[i_level].downsamplers[0]))

    if not is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
470
-
471
-
472
def build_task_queue(net, is_decoder):
    """
    Assemble the complete task queue of a VAE encoder or decoder

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = [('conv_in', net.conv_in)]

    # encoder and decoder share the same architecture, so the sampling part
    # is built by a common helper
    build_sampling(task_queue, net, is_decoder)

    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    # A decoder asked to give its pre-end features gets no output head.
    if is_decoder and net.give_pre_end:
        return task_queue

    final_norm = net.norm_out if sd_flag else net.conv_norm_out
    task_queue.extend([
        ('pre_norm', final_norm),
        ('silu', inplace_nonlinearity),
        ('conv_out', net.conv_out),
    ])
    if is_decoder and net.tanh_out:
        task_queue.append(('tanh', torch.tanh))

    return task_queue
500
-
501
-
502
def clone_task_queue(task_queue):
    """
    Copy a task queue one level deep

    Every task becomes a new list; the items inside a task are shared with
    the original.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [list(task) for task in task_queue]
509
-
510
-
511
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Compute the per-group (biased) variance and mean used by group norm

    @param input: input tensor with batch and channel as its first two axes
        (the reduction below expects two further trailing axes)
    @param num_groups: number of groups the channels are split into
    @param eps: not used by this function
    @return: (var, mean), one entry per (sample, group) pair
    """
    batch, channels = input.size(0), input.size(1)
    channels_per_group = int(channels / num_groups)
    # Put (sample, group) pairs on axis 1 so one reduction covers each group.
    grouped = input.contiguous().view(
        1, int(batch * num_groups), channels_per_group, *input.size()[2:])
    return torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
522
-
523
-
524
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Group norm that uses externally supplied statistics

    @param input: input tensor
    @param num_groups: number of groups
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: optional per-channel scale, should be fetched from the original group norm
    @param bias: optional per-channel shift, should be fetched from the original group norm
    @param eps: epsilon added to the variance, 1e-6 by default

    @return: normalized tensor
    """
    batch, channels = input.size(0), input.size(1)
    channels_per_group = int(channels / num_groups)
    spatial = input.size()[2:]
    # Put (sample, group) pairs on the channel axis so batch_norm in eval
    # mode normalizes each group with the given mean/var.
    grouped = input.contiguous().view(
        1, int(batch * num_groups), channels_per_group, *spatial)

    normalized = F.batch_norm(grouped, mean, var, weight=None, bias=None,
                              training=False, momentum=0, eps=eps)
    normalized = normalized.view(batch, channels, *spatial)

    # per-channel affine transform, applied in place
    if weight is not None:
        normalized *= weight.view(1, -1, 1, 1)
    if bias is not None:
        normalized += bias.view(1, -1, 1, 1)
    return normalized
554
-
555
-
556
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid region from a processed tile.

    The input bbox is first mapped to output coordinates (multiplied by 8 for
    the decoder, floor-divided by 8 for the encoder). The difference between
    the target bbox and that mapped bbox gives the margins to trim.

    @param x: processed tile, indexed as (batch, channel, height, width)
    @param input_bbox: original input bounding box, [x1, x2, y1, y2]
    @param target_bbox: output bounding box, [x1, x2, y1, y2]
    @param is_decoder: True to scale the input bbox up by 8, False to scale it down by 8
    @return: cropped tile
    """
    if is_decoder:
        scaled_bbox = [edge * 8 for edge in input_bbox]
    else:
        scaled_bbox = [edge // 8 for edge in input_bbox]
    left, right, top, bottom = (target_bbox[k] - scaled_bbox[k] for k in range(4))
    return x[:, :, top:x.size(2) + bottom, left:x.size(3) + right]
568
-
569
- # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
570
-
571
-
572
def perfcount(fn):
    """
    Decorator that reports wall-clock time (and peak CUDA memory, if available)
    of every call to `fn`.

    Before and after the call the CUDA cache and the Python garbage collector
    are flushed, so the reported peak allocation belongs to the wrapped call.

    @param fn: the callable to measure
    @return: a wrapper with the same signature and return value as `fn`
    """
    # local import so the block does not depend on the file-level import list
    import functools

    @functools.wraps(fn)  # keep fn's __name__/__doc__ visible on the wrapper
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            # bytes -> MiB
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
595
-
596
- # copy end :)
597
-
598
-
599
class GroupNormParam:
    """
    Accumulates group-norm statistics over several tiles and turns them into a
    single normalization function shared by all tiles.

    Statistics are always computed with 32 groups.
    """

    def __init__(self):
        self.var_list = []    # per-tile variance tensors
        self.mean_list = []   # per-tile mean tensors
        self.pixel_list = []  # per-tile pixel counts (H * W), used as weights
        self.weight = None    # affine weight of the most recently added layer
        self.bias = None      # affine bias of the most recently added layer

    def add_tile(self, tile, layer):
        """
        Record the statistics of one tile.

        @param tile: tile tensor, indexed as (batch, channel, height, width)
        @param layer: the norm layer; its weight/bias are remembered if it has a weight
        """
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16.
        # In this case the statistics are recomputed on a float32 copy.
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def _weighted_stats(self):
        """
        Combine the recorded per-tile statistics into one (var, mean) pair,
        weighting every tile by its share of the total pixel count.

        @return: (var, mean) tensors
        """
        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        max_value = max(self.pixel_list)
        # Build the weights on the device of the statistics themselves, so the
        # products below cannot fail on a device mismatch.
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=var.device) / max_value
        pixels = pixels.unsqueeze(1) / torch.sum(pixels)
        return torch.sum(var * pixels, dim=0), torch.sum(mean * pixels, dim=0)

    def summary(self):
        """
        Summarize the mean and var and return a function
        that applies group norm on each tile.

        @return: None if no tile was added, otherwise a callable taking a tile
        """
        if len(self.var_list) == 0:
            return None
        var, mean = self._weighted_stats()
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        Create a group norm function from a single tile without summary.

        @param tile: tile tensor whose statistics are used
        @param norm: the norm layer; its weight/bias are used if it has a weight
        @return: a callable that normalizes a tensor with the tile's statistics
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # on the 'mps' device the statistics are converted back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        # defaults bind the values computed above to the returned function
        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
675
-
676
-
677
class VAEHook:
    """
    Callable that runs a VAE encoder or decoder on large inputs tile by tile.

    Small inputs are passed straight to `net.original_forward`; larger ones are
    split into padded tiles that are processed through a per-tile task queue,
    with group-norm statistics shared across all tiles.
    """

    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        """
        @param net: the encoder or decoder network to run
        @param tile_size: target tile edge length, in input coordinates
        @param is_decoder: True if `net` is a decoder, False for an encoder
        @param fast_decoder: enable fast mode when `net` is a decoder
        @param fast_encoder: enable fast mode when `net` is an encoder
        @param color_fix: only effective for the encoder
        @param to_gpu: move `net` to devices.get_optimal_device() for the call
        """
        self.net = net  # encoder | decoder
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        # overlap added around every tile, in input coordinates
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        """
        Run the network on `x`, tiled if the input is large enough.

        The network is always moved back to the device it was on at entry.

        @param x: 4-D input tensor, indexed as (batch, channel, height, width)
        @return: the network output
        """
        _, _, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory.

        Tries to round `lowerbound` up to a multiple of 32, 16, 8, 4, then 2
        without exceeding `upperbound`. If `lowerbound` is already a multiple of
        the divider being tried, or nothing fits, it is returned unchanged.

        @param lowerbound: minimal tile size
        @param upperbound: maximal tile size
        @return: the chosen tile size
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles.

        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes; each bbox is
                 [x1, x2, y1, y2]. Input bboxes are in input coordinates and
                 include the padding; output bboxes are scaled by 8 (multiplied
                 for the decoder, floor-divided for the encoder).
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # If any of the numbers are 0, we let it be 1
        # This is to deal with long and thin images
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # bbox: [x1, x2, y1, y2]
                # padding is unnecessary at the image borders, so the tile grid
                # starts directly at (pad, pad)
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # if the output bbox is close to the image boundary, we extend it to the image boundary
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # scale to get the final output bbox
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # expand the input bbox by pad pixels on every side, clamped to the image
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        """
        Run `z` through the task queue up to the last 'pre_norm' task and replace
        every 'pre_norm' task on the way with an 'apply_norm' task whose
        statistics are taken from `z`. The queue is modified in place.

        @param z: the tensor used for estimation
        @param task_queue: the task queue to rewrite; tasks must be mutable lists
        @param color_fix: if set, stop at the first 'downsample' task and mark the
                          remaining 'store_res' tasks as 'store_res_cpu'
        @return: True if the queue was rewritten, False if the NaN check failed
        @raise ValueError: if there is no 'pre_norm' task after index 0
        """
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        # estimate until the last group norm
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                # find the matching 'add_res' and hand it the residual
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    # the residual is consumed after the last norm; not needed here
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None  # release the residual
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            try:
                devices.test_for_nans(tile, "vae")
            except Exception:
                # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate
                print('Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Run the network (encoder or decoder) on `z` in a tiled manner.

        Tiles are processed in rounds: every tile runs its task queue until it
        hits a 'pre_norm' task, the statistics of all tiles are merged, and the
        merged norm is pushed to the head of every queue for the next round.

        @param z: input tensor, indexed as (batch, channel, height, width)
        @return: the assembled output, cast to the dtype of `z`
        @raise RuntimeError: if no tile produced any output
        """
        device = next(self.net.parameters()).device
        dtype = z.dtype
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input; tiles wait on the CPU
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            # use nearest-exact to keep statistics as close as possible
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Allocated lazily once the first tile finishes (channel count unknown before)
        result = None
        # Free memory of input latent tensor
        del z

        # Task queue execution
        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # execute the task back and forth when switch tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        # Hook for external cancellation; nothing in this method sets it to True.
        interrupted = False
        while True:
            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                tile = tiles[i].to(device)
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        # stop here: the norm needs statistics from every tile
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        # hand the residual to the next 'add_res' task
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None  # release the residual
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if interrupted:
                    break

                if len(task_queue) == 0:
                    # tile finished: write its valid region into the result
                    tiles[i] = None
                    num_completed += 1
                    if result is None:  # NOTE: dim C varies from different cases, can only be inited dynamically
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    # last tile of a forward sweep: keep it on the device and reverse
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    # last tile of a backward sweep: keep it on the device and reverse
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted:
                break
            if num_completed == num_tiles:
                break

            # insert the group norm task to the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        # Done!
        pbar.close()
        if result is None:
            # No fallback output exists, so fail with an explicit message.
            raise RuntimeError('[Tiled VAE]: no tile was completed, no output available')
        return result.to(dtype)