Delete SUPIR
Browse files- SUPIR/__init__.py +0 -0
- SUPIR/models/SUPIR_model.py +0 -195
- SUPIR/models/__init__.py +0 -0
- SUPIR/modules/SUPIR_v0.py +0 -718
- SUPIR/modules/__init__.py +0 -11
- SUPIR/util.py +0 -179
- SUPIR/utils/__init__.py +0 -0
- SUPIR/utils/colorfix.py +0 -120
- SUPIR/utils/devices.py +0 -138
- SUPIR/utils/face_restoration_helper.py +0 -514
- SUPIR/utils/file.py +0 -79
- SUPIR/utils/tilevae.py +0 -971
SUPIR/__init__.py
DELETED
|
File without changes
|
SUPIR/models/SUPIR_model.py
DELETED
|
@@ -1,195 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from sgm.models.diffusion import DiffusionEngine
|
| 3 |
-
from sgm.util import instantiate_from_config
|
| 4 |
-
import copy
|
| 5 |
-
from sgm.modules.distributions.distributions import DiagonalGaussianDistribution
|
| 6 |
-
import random
|
| 7 |
-
from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
|
| 8 |
-
from pytorch_lightning import seed_everything
|
| 9 |
-
from torch.nn.functional import interpolate
|
| 10 |
-
from SUPIR.utils.tilevae import VAEHook
|
| 11 |
-
|
| 12 |
-
class SUPIRModel(DiffusionEngine):
    """DiffusionEngine variant with an attached control model and denoise encoder.

    On construction the control model built from ``control_stage_config`` is
    loaded into ``self.model`` and a deep copy of the first-stage encoder is
    stored as ``first_stage_model.denoise_encoder``.
    """

    def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
        """Build the engine and resolve the precision settings.

        Args:
            control_stage_config: Config passed to ``instantiate_from_config``
                to create the control model.
            ae_dtype: Autocast precision for the autoencoder: ``'fp32'`` or
                ``'bf16'`` (``'fp16'`` is rejected).
            diffusion_dtype: ``'fp32'``, ``'fp16'`` or ``'bf16'``; stored on
                ``self.model.dtype``.
            p_p: Default text appended to every prompt in ``prepare_condition``.
            n_p: Default text used for the unconditional batch.
            *args: Forwarded to ``DiffusionEngine``.
            **kwargs: Forwarded to ``DiffusionEngine``; must contain
                ``sampler_config``.

        Raises:
            RuntimeError: If ``ae_dtype`` is ``'fp16'``.
        """
        super().__init__(*args, **kwargs)
        control_model = instantiate_from_config(control_stage_config)
        self.model.load_control_model(control_model)
        # The denoise encoder starts out as an independent copy of the encoder.
        self.first_stage_model.denoise_encoder = copy.deepcopy(self.first_stage_model.encoder)
        # Kept so that batchify_sample can rebuild the sampler on every call.
        self.sampler_config = kwargs['sampler_config']

        assert (ae_dtype in ['fp32', 'fp16', 'bf16']) and (diffusion_dtype in ['fp32', 'fp16', 'bf16'])
        if ae_dtype == 'fp32':
            ae_dtype = torch.float32
        elif ae_dtype == 'fp16':
            raise RuntimeError('fp16 cause NaN in AE')
        elif ae_dtype == 'bf16':
            ae_dtype = torch.bfloat16

        if diffusion_dtype == 'fp32':
            diffusion_dtype = torch.float32
        elif diffusion_dtype == 'fp16':
            diffusion_dtype = torch.float16
        elif diffusion_dtype == 'bf16':
            diffusion_dtype = torch.bfloat16

        self.ae_dtype = ae_dtype
        self.model.dtype = diffusion_dtype

        self.p_p = p_p
        self.n_p = n_p

    @torch.no_grad()
    def encode_first_stage(self, x):
        """Encode ``x`` with the first-stage model and apply ``scale_factor``."""
        with torch.autocast("cuda", dtype=self.ae_dtype):
            z = self.first_stage_model.encode(x)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
        """Encode ``x`` through the denoise encoder and return a scaled latent.

        Args:
            x: Input passed to the denoise encoder.
            use_sample: If true, sample from the posterior; otherwise take its mode.
            is_stage1: If true, use ``first_stage_model.denoise_encoder_s1``
                (not created in this class -- it must be attached elsewhere)
                instead of ``denoise_encoder``.
        """
        with torch.autocast("cuda", dtype=self.ae_dtype):
            if is_stage1:
                h = self.first_stage_model.denoise_encoder_s1(x)
            else:
                h = self.first_stage_model.denoise_encoder(x)
            moments = self.first_stage_model.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            if use_sample:
                z = posterior.sample()
            else:
                z = posterior.mode()
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo ``scale_factor`` and decode ``z``; the result is cast to float32."""
        z = 1.0 / self.scale_factor * z
        with torch.autocast("cuda", dtype=self.ae_dtype):
            out = self.first_stage_model.decode(z)
        return out.float()

    @torch.no_grad()
    def batchify_denoise(self, x, is_stage1=False):
        '''
        [N, C, H, W], [-1, 1], RGB

        Encode with the denoise encoder (posterior mode) and decode again.
        '''
        x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
        return self.decode_first_stage(x)

    @torch.no_grad()
    def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, restoration_scale=4.0, s_churn=0, s_noise=1.003, cfg_scale=4.0, seed=-1,
                        num_samples=1, control_scale=1, color_fix_type='None', use_linear_CFG=False, use_linear_control_scale=False,
                        cfg_scale_start=1.0, control_scale_start=0.0, **kwargs):
        '''
        [N, C, H, W], [-1, 1], RGB

        Run the full sampling pipeline on ``x`` with one prompt per image.

        Args:
            x: Input batch; ``len(x)`` must equal ``len(p)``.
            p: List of prompts (or, for a single image, a list holding one
                list of per-tile prompts -- see ``prepare_condition``).
            p_p, n_p: Prompt suffix / unconditional text; ``'default'`` falls
                back to the values given at construction.
            num_steps, restoration_scale, s_churn, s_noise, cfg_scale:
                Written into ``self.sampler_config`` before the sampler is rebuilt.
            seed: ``-1`` draws a random seed in ``[0, 65535]``.
            num_samples: If ``> 1``, the single input is repeated that many times.
            control_scale, use_linear_control_scale, control_scale_start:
                Forwarded to the sampler.
            color_fix_type: ``'Wavelet'``, ``'AdaIn'`` or ``'None'``.
            use_linear_CFG, cfg_scale_start: With linear CFG the guider scale
                is set to ``cfg_scale_start`` and ``scale_min`` to ``cfg_scale``.
            **kwargs: Forwarded to ``self.denoiser`` on every call.

        Returns:
            Decoded samples, optionally colour-corrected against the stage-1
            reconstruction.
        '''
        assert len(x) == len(p)
        assert color_fix_type in ['Wavelet', 'AdaIn', 'None']

        N = len(x)
        if num_samples > 1:
            assert N == 1
            N = num_samples
            x = x.repeat(N, 1, 1, 1)
            p = p * N

        if p_p == 'default':
            p_p = self.p_p
        if n_p == 'default':
            n_p = self.n_p

        # NOTE: the stored sampler config is mutated in place on every call.
        self.sampler_config.params.num_steps = num_steps
        if use_linear_CFG:
            self.sampler_config.params.guider_config.params.scale_min = cfg_scale
            self.sampler_config.params.guider_config.params.scale = cfg_scale_start
        else:
            self.sampler_config.params.guider_config.params.scale_min = cfg_scale
            self.sampler_config.params.guider_config.params.scale = cfg_scale
        self.sampler_config.params.restore_cfg = restoration_scale
        self.sampler_config.params.s_churn = s_churn
        self.sampler_config.params.s_noise = s_noise
        self.sampler = instantiate_from_config(self.sampler_config)

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        _z = self.encode_first_stage_with_denoise(x, use_sample=False)
        x_stage1 = self.decode_first_stage(_z)
        z_stage1 = self.encode_first_stage(x_stage1)

        c, uc = self.prepare_condition(_z, p, p_p, n_p, N)

        # Parameter names are kept as in the original lambda in case the
        # sampler passes them by keyword.
        def denoiser(input, sigma, c, control_scale):
            return self.denoiser(self.model, input, sigma, c, control_scale, **kwargs)

        # randn_like already allocates on _z's device, so no extra .to() is needed.
        noised_z = torch.randn_like(_z)

        _samples = self.sampler(denoiser, noised_z, cond=c, uc=uc, x_center=z_stage1, control_scale=control_scale,
                                use_linear_control_scale=use_linear_control_scale, control_scale_start=control_scale_start)
        samples = self.decode_first_stage(_samples)
        if color_fix_type == 'Wavelet':
            samples = wavelet_reconstruction(samples, x_stage1)
        elif color_fix_type == 'AdaIn':
            samples = adaptive_instance_normalization(samples, x_stage1)
        return samples

    def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
        """Replace the VAE encoder/decoder ``forward`` methods with ``VAEHook`` wrappers.

        The original ``forward`` of each module is kept on the module as
        ``original_forward``.
        """
        self.first_stage_model.denoise_encoder.original_forward = self.first_stage_model.denoise_encoder.forward
        self.first_stage_model.encoder.original_forward = self.first_stage_model.encoder.forward
        self.first_stage_model.decoder.original_forward = self.first_stage_model.decoder.forward
        self.first_stage_model.denoise_encoder.forward = VAEHook(
            self.first_stage_model.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)
        self.first_stage_model.encoder.forward = VAEHook(
            self.first_stage_model.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)
        self.first_stage_model.decoder.forward = VAEHook(
            self.first_stage_model.decoder, decoder_tile_size, is_decoder=True, fast_decoder=False,
            fast_encoder=False, color_fix=False, to_gpu=True)

    def prepare_condition(self, _z, p, p_p, n_p, N):
        """Build the conditional and unconditional conditioning for sampling.

        Args:
            _z: Control latent; stored under ``batch['control']``.
            p: Either a list of prompt strings, or a list with exactly one
                list of per-tile prompt strings.
            p_p: Text appended to each prompt.
            n_p: Text used for every entry of the unconditional batch.
            N: Batch size used to repeat the size/crop/score tensors.

        Returns:
            ``(c, uc)``. For per-tile prompts ``c`` is a list with one
            conditioning per tile, and ``uc`` comes from the first tile only.
        """
        batch = {}
        batch['original_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
        batch['crop_coords_top_left'] = torch.tensor([0, 0]).repeat(N, 1).to(_z.device)
        batch['target_size_as_tuple'] = torch.tensor([1024, 1024]).repeat(N, 1).to(_z.device)
        batch['aesthetic_score'] = torch.tensor([9.0]).repeat(N, 1).to(_z.device)
        batch['control'] = _z

        batch_uc = copy.deepcopy(batch)
        batch_uc['txt'] = [n_p for _ in p]

        if not isinstance(p[0], list):
            batch['txt'] = [''.join([_p, p_p]) for _p in p]
            # torch.autocast("cuda", ...) matches the other methods of this
            # class and avoids the deprecated torch.cuda.amp.autocast alias.
            with torch.autocast("cuda", dtype=self.ae_dtype):
                c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
        else:
            assert len(p) == 1, 'Support bs=1 only for local prompt conditioning.'
            p_tiles = p[0]
            c = []
            for i, p_tile in enumerate(p_tiles):
                batch['txt'] = [''.join([p_tile, p_p])]
                with torch.autocast("cuda", dtype=self.ae_dtype):
                    if i == 0:
                        # The unconditional branch only needs computing once.
                        _c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
                    else:
                        _c, _ = self.conditioner.get_unconditional_conditioning(batch, None)
                c.append(_c)
        return c, uc
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
if __name__ == '__main__':
    # Manual smoke test: build the model from a YAML config, load two
    # checkpoints non-strictly, move it to the GPU and run one sampling call
    # on random input. The paths below are hard-coded to a specific machine.
    from SUPIR.util import create_model, load_state_dict

    model = create_model('../../options/dev/SUPIR_paper_version.yaml')

    SDXL_CKPT = '/opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors'
    SUPIR_CKPT = '/opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-paper.ckpt'
    # strict=False: each checkpoint is allowed to cover only part of the model.
    model.load_state_dict(load_state_dict(SDXL_CKPT), strict=False)
    model.load_state_dict(load_state_dict(SUPIR_CKPT), strict=False)
    model = model.cuda()

    # Random 512x512 input with a single prompt.
    x = torch.randn(1, 3, 512, 512).cuda()
    p = ['a professional, detailed, high-quality photo']
    samples = model.batchify_sample(x, p, num_steps=50, restoration_scale=4.0, s_churn=0, cfg_scale=4.0, seed=-1, num_samples=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/models/__init__.py
DELETED
|
File without changes
|
SUPIR/modules/SUPIR_v0.py
DELETED
|
@@ -1,718 +0,0 @@
|
|
| 1 |
-
# from einops._torch_specific import allow_ops_in_compiled_graph
|
| 2 |
-
# allow_ops_in_compiled_graph()
|
| 3 |
-
import einops
|
| 4 |
-
import torch
|
| 5 |
-
import torch as th
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
from einops import rearrange, repeat
|
| 8 |
-
|
| 9 |
-
from sgm.modules.diffusionmodules.util import (
|
| 10 |
-
avg_pool_nd,
|
| 11 |
-
checkpoint,
|
| 12 |
-
conv_nd,
|
| 13 |
-
linear,
|
| 14 |
-
normalization,
|
| 15 |
-
timestep_embedding,
|
| 16 |
-
zero_module,
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
from sgm.modules.diffusionmodules.openaimodel import Downsample, Upsample, UNetModel, Timestep, \
|
| 20 |
-
TimestepEmbedSequential, ResBlock, AttentionBlock, TimestepBlock
|
| 21 |
-
from sgm.modules.attention import SpatialTransformer, MemoryEfficientCrossAttention, CrossAttention
|
| 22 |
-
from sgm.util import default, log_txt_as_img, exists, instantiate_from_config
|
| 23 |
-
import re
|
| 24 |
-
import torch
|
| 25 |
-
from functools import partial
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# xformers is optional: record whether it can be imported so attention
# modules can choose an implementation. Catch Exception rather than using a
# bare except so KeyboardInterrupt/SystemExit still propagate, while a broken
# install (which may raise something other than ImportError) still falls back.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    XFORMERS_IS_AVAILBLE = False
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# dummy replace
|
| 37 |
-
def convert_module_to_f16(x):
    """Placeholder that does nothing with ``x`` and returns ``None``."""
    return None
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def convert_module_to_f32(x):
    """Placeholder that does nothing with ``x`` and returns ``None``."""
    return None
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class ZeroConv(nn.Module):
    """Adds a zero-initialised 1x1 convolution of ``c`` onto ``h``."""

    def __init__(self, label_nc, norm_nc, mask=False):
        """Create the projection from ``label_nc`` to ``norm_nc`` channels.

        Args:
            label_nc: Channel count of the conditioning input ``c``.
            norm_nc: Channel count of the feature map ``h``.
            mask: When truthy, the projected term is multiplied by zeros.
        """
        super().__init__()
        projection = conv_nd(2, label_nc, norm_nc, 1, 1, 0)
        self.zero_conv = zero_module(projection)
        self.mask = mask

    def forward(self, c, h, h_ori=None):
        """Return ``h`` plus the projected ``c``, optionally prefixed by ``h_ori``.

        If ``h_ori`` is given it is concatenated in front of the result along
        dim 1.
        """
        projected = self.zero_conv(c)
        if self.mask:
            projected = projected * torch.zeros_like(h)
        merged = h + projected
        if h_ori is None:
            return merged
        return th.cat([h_ori, merged], dim=1)
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
class ZeroSFT(nn.Module):
    """Modulates features ``h`` with scale/shift maps predicted from ``c``.

    A zero-initialised 1x1 convolution of ``c`` is added to ``h``; the result
    is normalised and then scaled and shifted by two zero-initialised 3x3
    convolutions applied to a shared hidden activation of ``c``.
    """

    def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
        """Create the normalisation and projection layers.

        Args:
            label_nc: Channel count of the conditioning input ``c``.
            norm_nc: Channel count of the feature map ``h``.
            concat_channels: Channel count of ``h_ori`` when it is concatenated
                before normalisation; ``0`` turns pre-concatenation off.
            norm: Use ``normalization(...)`` when truthy, else ``nn.Identity``.
            mask: Must be ``False`` for ``forward`` (it asserts this).
        """
        super().__init__()

        kernel = 3
        pad = kernel // 2
        hidden = 128
        out_nc = norm_nc + concat_channels

        self.norm = norm
        self.param_free_norm = normalization(out_nc) if self.norm else nn.Identity()

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden, kernel_size=kernel, padding=pad),
            nn.SiLU()
        )
        self.zero_mul = zero_module(nn.Conv2d(hidden, out_nc, kernel_size=kernel, padding=pad))
        self.zero_add = zero_module(nn.Conv2d(hidden, out_nc, kernel_size=kernel, padding=pad))

        self.zero_conv = zero_module(conv_nd(2, label_nc, norm_nc, 1, 1, 0))
        self.pre_concat = bool(concat_channels != 0)
        self.mask = mask

    def forward(self, c, h, h_ori=None, control_scale=1):
        """Blend the modulated features with the unmodulated ones.

        Args:
            c: Conditioning feature map.
            h: Feature map to modulate.
            h_ori: Optional features concatenated along dim 1, before
                normalisation if ``pre_concat`` is set, otherwise after.
            control_scale: Weight of the modulated result; the unmodulated
                input gets weight ``1 - control_scale``.
        """
        assert self.mask is False
        has_skip = h_ori is not None
        concat_before = has_skip and self.pre_concat

        # Unmodulated reference used for the final blend.
        h_raw = th.cat([h_ori, h], dim=1) if concat_before else h

        residual = self.zero_conv(c)
        if self.mask:
            residual = residual * torch.zeros_like(h)
        h = h + residual
        if concat_before:
            h = th.cat([h_ori, h], dim=1)

        actv = self.mlp_shared(c)
        gamma = self.zero_mul(actv)
        beta = self.zero_add(actv)
        if self.mask:
            gamma = gamma * torch.zeros_like(gamma)
            beta = beta * torch.zeros_like(beta)
        h = self.param_free_norm(h) * (gamma + 1) + beta

        if has_skip and not self.pre_concat:
            h = th.cat([h_ori, h], dim=1)
        return h * control_scale + h_raw * (1 - control_scale)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class ZeroCrossAttn(nn.Module):
    """Residual cross-attention from a feature map ``x`` to a ``context`` map."""

    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention
    }

    def __init__(self, context_dim, query_dim, zero_out=True, mask=False):
        """Create the attention layer and the two input normalisations.

        Args:
            context_dim: Channel count of ``context``.
            query_dim: Channel count of ``x``; heads are ``query_dim // 64``
                with 64 channels each.
            zero_out: Accepted but currently unused.
            mask: Must be ``False`` for ``forward`` (it asserts this).
        """
        super().__init__()
        # Use the xformers implementation whenever the import probe succeeded.
        if XFORMERS_IS_AVAILBLE:
            attn_mode = "softmax-xformers"
        else:
            attn_mode = "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attention_class = self.ATTENTION_MODES[attn_mode]
        self.attn = attention_class(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=query_dim // 64,
            dim_head=64,
        )
        self.norm1 = normalization(query_dim)
        self.norm2 = normalization(context_dim)

        self.mask = mask

    def forward(self, context, x, control_scale=1):
        """Return ``x`` plus ``control_scale`` times the attention output.

        Both inputs are normalised, flattened over their spatial dims for the
        attention call, and the result is reshaped back to ``x``'s layout.
        """
        assert self.mask is False
        residual = x
        queries = self.norm1(x)
        keys = self.norm2(context)
        _, _, height, width = queries.shape
        queries = rearrange(queries, 'b c h w -> b (h w) c').contiguous()
        keys = rearrange(keys, 'b c h w -> b (h w) c').contiguous()
        attended = self.attn(queries, keys)
        attended = rearrange(attended, 'b (h w) c -> b c h w', h=height, w=width).contiguous()
        if self.mask:
            attended = attended * torch.zeros_like(attended)
        return residual + attended * control_scale
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
class GLVControl(nn.Module):
    """Encoder half of a UNet (input blocks + middle block) used as a control branch.

    ``forward`` returns the activation of every input block followed by the
    middle-block activation, rather than a single output tensor.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        spatial_transformer_attn_type="softmax",
        adm_in_channels=None,
        use_fairscale_checkpoint=False,
        offload_to_cpu=False,
        transformer_depth_middle=None,
        input_upscale=1,
    ):
        """Build the time/label embeddings, input blocks, middle block and hint block.

        Notable arguments (the rest are stored or forwarded to the sub-blocks):
            num_res_blocks: An int (same for every level) or a per-level
                sequence with the same length as ``channel_mult``.
            num_classes: ``None``, an int, or one of ``"continuous"``,
                ``"timestep"``, ``"sequential"``; selects the label embedding.
            num_heads / num_head_channels: At least one must differ from -1.
            use_fp16: Ignored; only prints a warning.
            use_fairscale_checkpoint / offload_to_cpu: Validated against
                ``use_checkpoint`` but otherwise without effect, because
                ``self.use_fairscale_checkpoint`` is forced to ``False``.
            input_upscale: Scale factor applied to ``x`` in ``forward`` when
                it is not 1.

        Raises:
            ValueError: If ``num_res_blocks`` is a sequence of the wrong
                length, or ``num_classes`` has an unsupported value.
        """
        super().__init__()
        from omegaconf.listconfig import ListConfig

        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        elif isinstance(transformer_depth, ListConfig):
            transformer_depth = list(transformer_depth)
        transformer_depth_middle = default(
            transformer_depth_middle, transformer_depth[-1]
        )

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )  # todo: convert to warning

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        if use_fp16:
            print("WARNING: use_fp16 was dropped and has no effect anymore.")
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        assert use_fairscale_checkpoint != use_checkpoint or not (
            use_checkpoint or use_fairscale_checkpoint
        )

        # Fairscale checkpointing is hard-disabled, so the wrapper below is
        # always the identity and `checkpoint_wrapper` is never evaluated.
        self.use_fairscale_checkpoint = False
        checkpoint_wrapper_fn = (
            partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
            if self.use_fairscale_checkpoint
            else lambda x: x
        )

        time_embed_dim = model_channels * 4
        self.time_embed = checkpoint_wrapper_fn(
            nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "timestep":
                self.label_emb = checkpoint_wrapper_fn(
                    nn.Sequential(
                        Timestep(model_channels),
                        nn.Sequential(
                            linear(model_channels, time_embed_dim),
                            nn.SiLU(),
                            linear(time_embed_dim, time_embed_dim),
                        ),
                    )
                )
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor, doubled after every level
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    checkpoint_wrapper_fn(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=mult * model_channels,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        )
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if (
                        not exists(num_attention_blocks)
                        or nr < num_attention_blocks[level]
                    ):
                        layers.append(
                            checkpoint_wrapper_fn(
                                AttentionBlock(
                                    ch,
                                    use_checkpoint=use_checkpoint,
                                    num_heads=num_heads,
                                    num_head_channels=dim_head,
                                    use_new_attention_order=use_new_attention_order,
                                )
                            )
                            if not use_spatial_transformer
                            else checkpoint_wrapper_fn(
                                SpatialTransformer(
                                    ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth[level],
                                    context_dim=context_dim,
                                    disable_self_attn=disabled_sa,
                                    use_linear=use_linear_in_transformer,
                                    attn_type=spatial_transformer_attn_type,
                                    use_checkpoint=use_checkpoint,
                                )
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Every level except the last ends with a downsampling block.
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        checkpoint_wrapper_fn(
                            ResBlock(
                                ch,
                                time_embed_dim,
                                dropout,
                                out_channels=out_ch,
                                dims=dims,
                                use_checkpoint=use_checkpoint,
                                use_scale_shift_norm=use_scale_shift_norm,
                                down=True,
                            )
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
            checkpoint_wrapper_fn(
                AttentionBlock(
                    ch,
                    use_checkpoint=use_checkpoint,
                    num_heads=num_heads,
                    num_head_channels=dim_head,
                    use_new_attention_order=use_new_attention_order,
                )
            )
            if not use_spatial_transformer
            else checkpoint_wrapper_fn(
                SpatialTransformer(  # always uses a self-attn
                    ch,
                    num_heads,
                    dim_head,
                    depth=transformer_depth_middle,
                    context_dim=context_dim,
                    disable_self_attn=disable_middle_self_attn,
                    use_linear=use_linear_in_transformer,
                    attn_type=spatial_transformer_attn_type,
                    use_checkpoint=use_checkpoint,
                )
            ),
            checkpoint_wrapper_fn(
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ),
        )

        self.input_upscale = input_upscale
        # Zero-initialised projection of the control input; its output is
        # added to the activation of the first input block in forward().
        self.input_hint_block = TimestepEmbedSequential(
            zero_module(conv_nd(dims, in_channels, model_channels, 3, padding=1))
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.

        NOTE: with the no-op ``convert_module_to_f16`` defined in this file,
        this call currently changes nothing.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.

        NOTE: with the no-op ``convert_module_to_f32`` defined in this file,
        this call currently changes nothing.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
        """Run the control branch and collect every intermediate activation.

        Args:
            x: Control input; fed through ``input_hint_block`` (after optional
                upscaling) and used as the reference dtype.
            timesteps: Timesteps passed to ``timestep_embedding``.
            xt: Tensor fed through the input blocks.
            context: Optional context passed to every block.
            y: Label input; must be given if and only if the model was built
                with ``num_classes``.
            **kwargs: Accepted and ignored.

        Returns:
            List with the output of each input block followed by the
            middle-block output.
        """
        # Bring the inputs to x's dtype. `context` and `y` default to None, so
        # they are only cast when present; a missing `y` is then reported by
        # the assertion below instead of an AttributeError.
        xt = xt.to(x.dtype)
        if context is not None:
            context = context.to(x.dtype)
        if y is not None:
            y = y.to(x.dtype)

        if self.input_upscale != 1:
            x = nn.functional.interpolate(x, scale_factor=self.input_upscale, mode='bilinear', antialias=True)
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == xt.shape[0]
            emb = emb + self.label_emb(y)

        guided_hint = self.input_hint_block(x, emb, context)

        h = xt
        for module in self.input_blocks:
            if guided_hint is not None:
                # The hint is injected once, right after the first input block.
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        hs.append(h)
        return hs
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
class LightGLVUNet(UNetModel):
|
| 544 |
-
def __init__(self, mode='', project_type='ZeroSFT', project_channel_scale=1,
|
| 545 |
-
*args, **kwargs):
|
| 546 |
-
super().__init__(*args, **kwargs)
|
| 547 |
-
if mode == 'XL-base':
|
| 548 |
-
cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
|
| 549 |
-
project_channels = [160] * 4 + [320] * 3 + [640] * 3
|
| 550 |
-
concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
|
| 551 |
-
cross_attn_insert_idx = [6, 3]
|
| 552 |
-
self.progressive_mask_nums = [0, 3, 7, 11]
|
| 553 |
-
elif mode == 'XL-refine':
|
| 554 |
-
cond_output_channels = [384] * 4 + [768] * 3 + [1536] * 6
|
| 555 |
-
project_channels = [192] * 4 + [384] * 3 + [768] * 6
|
| 556 |
-
concat_channels = [384] * 2 + [768] * 3 + [1536] * 7 + [0]
|
| 557 |
-
cross_attn_insert_idx = [9, 6, 3]
|
| 558 |
-
self.progressive_mask_nums = [0, 3, 6, 10, 14]
|
| 559 |
-
else:
|
| 560 |
-
raise NotImplementedError
|
| 561 |
-
|
| 562 |
-
project_channels = [int(c * project_channel_scale) for c in project_channels]
|
| 563 |
-
|
| 564 |
-
self.project_modules = nn.ModuleList()
|
| 565 |
-
for i in range(len(cond_output_channels)):
|
| 566 |
-
# if i == len(cond_output_channels) - 1:
|
| 567 |
-
# _project_type = 'ZeroCrossAttn'
|
| 568 |
-
# else:
|
| 569 |
-
# _project_type = project_type
|
| 570 |
-
_project_type = project_type
|
| 571 |
-
if _project_type == 'ZeroSFT':
|
| 572 |
-
self.project_modules.append(ZeroSFT(project_channels[i], cond_output_channels[i],
|
| 573 |
-
concat_channels=concat_channels[i]))
|
| 574 |
-
elif _project_type == 'ZeroCrossAttn':
|
| 575 |
-
self.project_modules.append(ZeroCrossAttn(cond_output_channels[i], project_channels[i]))
|
| 576 |
-
else:
|
| 577 |
-
raise NotImplementedError
|
| 578 |
-
|
| 579 |
-
for i in cross_attn_insert_idx:
|
| 580 |
-
self.project_modules.insert(i, ZeroCrossAttn(cond_output_channels[i], concat_channels[i]))
|
| 581 |
-
# print(self.project_modules[i])
|
| 582 |
-
|
| 583 |
-
def step_progressive_mask(self):
|
| 584 |
-
if len(self.progressive_mask_nums) > 0:
|
| 585 |
-
mask_num = self.progressive_mask_nums.pop()
|
| 586 |
-
for i in range(len(self.project_modules)):
|
| 587 |
-
if i < mask_num:
|
| 588 |
-
self.project_modules[i].mask = True
|
| 589 |
-
else:
|
| 590 |
-
self.project_modules[i].mask = False
|
| 591 |
-
return
|
| 592 |
-
# print(f'step_progressive_mask, current masked layers: {mask_num}')
|
| 593 |
-
else:
|
| 594 |
-
return
|
| 595 |
-
# print('step_progressive_mask, no more masked layers')
|
| 596 |
-
# for i in range(len(self.project_modules)):
|
| 597 |
-
# print(self.project_modules[i].mask)
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
def forward(self, x, timesteps=None, context=None, y=None, control=None, control_scale=1, **kwargs):
|
| 601 |
-
"""
|
| 602 |
-
Apply the model to an input batch.
|
| 603 |
-
:param x: an [N x C x ...] Tensor of inputs.
|
| 604 |
-
:param timesteps: a 1-D batch of timesteps.
|
| 605 |
-
:param context: conditioning plugged in via crossattn
|
| 606 |
-
:param y: an [N] Tensor of labels, if class-conditional.
|
| 607 |
-
:return: an [N x C x ...] Tensor of outputs.
|
| 608 |
-
"""
|
| 609 |
-
assert (y is not None) == (
|
| 610 |
-
self.num_classes is not None
|
| 611 |
-
), "must specify y if and only if the model is class-conditional"
|
| 612 |
-
hs = []
|
| 613 |
-
|
| 614 |
-
_dtype = control[0].dtype
|
| 615 |
-
x, context, y = x.to(_dtype), context.to(_dtype), y.to(_dtype)
|
| 616 |
-
|
| 617 |
-
with torch.no_grad():
|
| 618 |
-
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
| 619 |
-
emb = self.time_embed(t_emb)
|
| 620 |
-
|
| 621 |
-
if self.num_classes is not None:
|
| 622 |
-
assert y.shape[0] == x.shape[0]
|
| 623 |
-
emb = emb + self.label_emb(y)
|
| 624 |
-
|
| 625 |
-
# h = x.type(self.dtype)
|
| 626 |
-
h = x
|
| 627 |
-
for module in self.input_blocks:
|
| 628 |
-
h = module(h, emb, context)
|
| 629 |
-
hs.append(h)
|
| 630 |
-
|
| 631 |
-
adapter_idx = len(self.project_modules) - 1
|
| 632 |
-
control_idx = len(control) - 1
|
| 633 |
-
h = self.middle_block(h, emb, context)
|
| 634 |
-
h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
|
| 635 |
-
adapter_idx -= 1
|
| 636 |
-
control_idx -= 1
|
| 637 |
-
|
| 638 |
-
for i, module in enumerate(self.output_blocks):
|
| 639 |
-
_h = hs.pop()
|
| 640 |
-
h = self.project_modules[adapter_idx](control[control_idx], _h, h, control_scale=control_scale)
|
| 641 |
-
adapter_idx -= 1
|
| 642 |
-
# h = th.cat([h, _h], dim=1)
|
| 643 |
-
if len(module) == 3:
|
| 644 |
-
assert isinstance(module[2], Upsample)
|
| 645 |
-
for layer in module[:2]:
|
| 646 |
-
if isinstance(layer, TimestepBlock):
|
| 647 |
-
h = layer(h, emb)
|
| 648 |
-
elif isinstance(layer, SpatialTransformer):
|
| 649 |
-
h = layer(h, context)
|
| 650 |
-
else:
|
| 651 |
-
h = layer(h)
|
| 652 |
-
# print('cross_attn_here')
|
| 653 |
-
h = self.project_modules[adapter_idx](control[control_idx], h, control_scale=control_scale)
|
| 654 |
-
adapter_idx -= 1
|
| 655 |
-
h = module[2](h)
|
| 656 |
-
else:
|
| 657 |
-
h = module(h, emb, context)
|
| 658 |
-
control_idx -= 1
|
| 659 |
-
# print(module)
|
| 660 |
-
# print(h.shape)
|
| 661 |
-
|
| 662 |
-
h = h.type(x.dtype)
|
| 663 |
-
if self.predict_codebook_ids:
|
| 664 |
-
assert False, "not supported anymore. what the f*** are you doing?"
|
| 665 |
-
else:
|
| 666 |
-
return self.out(h)
|
| 667 |
-
|
| 668 |
-
if __name__ == '__main__':
|
| 669 |
-
from omegaconf import OmegaConf
|
| 670 |
-
|
| 671 |
-
# refiner
|
| 672 |
-
# opt = OmegaConf.load('../../options/train/debug_p2_xl.yaml')
|
| 673 |
-
#
|
| 674 |
-
# model = instantiate_from_config(opt.model.params.control_stage_config)
|
| 675 |
-
# hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
|
| 676 |
-
# hint = [h.cuda() for h in hint]
|
| 677 |
-
# print(sum(map(lambda hint: hint.numel(), model.parameters())))
|
| 678 |
-
#
|
| 679 |
-
# unet = instantiate_from_config(opt.model.params.network_config)
|
| 680 |
-
# unet = unet.cuda()
|
| 681 |
-
#
|
| 682 |
-
# _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
|
| 683 |
-
# torch.randn([1, 2560]).cuda(), hint)
|
| 684 |
-
# print(sum(map(lambda _output: _output.numel(), unet.parameters())))
|
| 685 |
-
|
| 686 |
-
# base
|
| 687 |
-
with torch.no_grad():
|
| 688 |
-
opt = OmegaConf.load('../../options/dev/SUPIR_tmp.yaml')
|
| 689 |
-
|
| 690 |
-
model = instantiate_from_config(opt.model.params.control_stage_config)
|
| 691 |
-
model = model.cuda()
|
| 692 |
-
|
| 693 |
-
hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 2048]).cuda(),
|
| 694 |
-
torch.randn([1, 2816]).cuda())
|
| 695 |
-
|
| 696 |
-
for h in hint:
|
| 697 |
-
print(h.shape)
|
| 698 |
-
#
|
| 699 |
-
unet = instantiate_from_config(opt.model.params.network_config)
|
| 700 |
-
unet = unet.cuda()
|
| 701 |
-
_output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 2048]).cuda(),
|
| 702 |
-
torch.randn([1, 2816]).cuda(), hint)
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
# model = instantiate_from_config(opt.model.params.control_stage_config)
|
| 706 |
-
# model = model.cuda()
|
| 707 |
-
# # hint = model(torch.randn([1, 4, 64, 64]), torch.randn([1]), torch.randn([1, 4, 64, 64]))
|
| 708 |
-
# hint = model(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1, 77, 1280]).cuda(),
|
| 709 |
-
# torch.randn([1, 2560]).cuda())
|
| 710 |
-
# # hint = [h.cuda() for h in hint]
|
| 711 |
-
#
|
| 712 |
-
# for h in hint:
|
| 713 |
-
# print(h.shape)
|
| 714 |
-
#
|
| 715 |
-
# unet = instantiate_from_config(opt.model.params.network_config)
|
| 716 |
-
# unet = unet.cuda()
|
| 717 |
-
# _output = unet(torch.randn([1, 4, 64, 64]).cuda(), torch.randn([1]).cuda(), torch.randn([1, 77, 1280]).cuda(),
|
| 718 |
-
# torch.randn([1, 2560]).cuda(), hint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/modules/__init__.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
SDXL_BASE_CHANNEL_DICT = {
|
| 2 |
-
'cond_output_channels': [320] * 4 + [640] * 3 + [1280] * 3,
|
| 3 |
-
'project_channels': [160] * 4 + [320] * 3 + [640] * 3,
|
| 4 |
-
'concat_channels': [320] * 2 + [640] * 3 + [1280] * 4 + [0]
|
| 5 |
-
}
|
| 6 |
-
|
| 7 |
-
SDXL_REFINE_CHANNEL_DICT = {
|
| 8 |
-
'cond_output_channels': [384] * 4 + [768] * 3 + [1536] * 6,
|
| 9 |
-
'project_channels': [192] * 4 + [384] * 3 + [768] * 6,
|
| 10 |
-
'concat_channels': [384] * 2 + [768] * 3 + [1536] * 7 + [0]
|
| 11 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/util.py
DELETED
|
@@ -1,179 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import numpy as np
|
| 4 |
-
import cv2
|
| 5 |
-
from PIL import Image
|
| 6 |
-
from torch.nn.functional import interpolate
|
| 7 |
-
from omegaconf import OmegaConf
|
| 8 |
-
from sgm.util import instantiate_from_config
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def get_state_dict(d):
|
| 12 |
-
return d.get('state_dict', d)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def load_state_dict(ckpt_path, location='cpu'):
|
| 16 |
-
_, extension = os.path.splitext(ckpt_path)
|
| 17 |
-
if extension.lower() == ".safetensors":
|
| 18 |
-
import safetensors.torch
|
| 19 |
-
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
| 20 |
-
else:
|
| 21 |
-
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
| 22 |
-
state_dict = get_state_dict(state_dict)
|
| 23 |
-
print(f'Loaded state_dict from [{ckpt_path}]')
|
| 24 |
-
return state_dict
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def create_model(config_path):
|
| 28 |
-
config = OmegaConf.load(config_path)
|
| 29 |
-
model = instantiate_from_config(config.model).cpu()
|
| 30 |
-
print(f'Loaded model config from [{config_path}]')
|
| 31 |
-
return model
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
|
| 35 |
-
config = OmegaConf.load(config_path)
|
| 36 |
-
model = instantiate_from_config(config.model).cpu()
|
| 37 |
-
print(f'Loaded model config from [{config_path}]')
|
| 38 |
-
if config.SDXL_CKPT is not None:
|
| 39 |
-
model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
|
| 40 |
-
if config.SUPIR_CKPT is not None:
|
| 41 |
-
model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False)
|
| 42 |
-
if SUPIR_sign is not None:
|
| 43 |
-
assert SUPIR_sign in ['F', 'Q']
|
| 44 |
-
if SUPIR_sign == 'F':
|
| 45 |
-
model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
|
| 46 |
-
elif SUPIR_sign == 'Q':
|
| 47 |
-
model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)
|
| 48 |
-
if load_default_setting:
|
| 49 |
-
default_setting = config.default_setting
|
| 50 |
-
return model, default_setting
|
| 51 |
-
return model
|
| 52 |
-
|
| 53 |
-
def load_QF_ckpt(config_path):
|
| 54 |
-
config = OmegaConf.load(config_path)
|
| 55 |
-
ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location='cpu')
|
| 56 |
-
ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location='cpu')
|
| 57 |
-
return ckpt_Q, ckpt_F
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
|
| 61 |
-
'''
|
| 62 |
-
PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
|
| 63 |
-
'''
|
| 64 |
-
# size
|
| 65 |
-
w, h = img.size
|
| 66 |
-
w *= upsacle
|
| 67 |
-
h *= upsacle
|
| 68 |
-
w0, h0 = round(w), round(h)
|
| 69 |
-
if min(w, h) < min_size:
|
| 70 |
-
_upsacle = min_size / min(w, h)
|
| 71 |
-
w *= _upsacle
|
| 72 |
-
h *= _upsacle
|
| 73 |
-
if fix_resize is not None:
|
| 74 |
-
_upsacle = fix_resize / min(w, h)
|
| 75 |
-
w *= _upsacle
|
| 76 |
-
h *= _upsacle
|
| 77 |
-
w0, h0 = round(w), round(h)
|
| 78 |
-
w = int(np.round(w / 64.0)) * 64
|
| 79 |
-
h = int(np.round(h / 64.0)) * 64
|
| 80 |
-
x = img.resize((w, h), Image.BICUBIC)
|
| 81 |
-
x = np.array(x).round().clip(0, 255).astype(np.uint8)
|
| 82 |
-
x = x / 255 * 2 - 1
|
| 83 |
-
x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
|
| 84 |
-
return x, h0, w0
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def Tensor2PIL(x, h0, w0):
|
| 88 |
-
'''
|
| 89 |
-
Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
|
| 90 |
-
'''
|
| 91 |
-
x = x.unsqueeze(0)
|
| 92 |
-
x = interpolate(x, size=(h0, w0), mode='bicubic')
|
| 93 |
-
x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 94 |
-
return Image.fromarray(x)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def HWC3(x):
|
| 98 |
-
assert x.dtype == np.uint8
|
| 99 |
-
if x.ndim == 2:
|
| 100 |
-
x = x[:, :, None]
|
| 101 |
-
assert x.ndim == 3
|
| 102 |
-
H, W, C = x.shape
|
| 103 |
-
assert C == 1 or C == 3 or C == 4
|
| 104 |
-
if C == 3:
|
| 105 |
-
return x
|
| 106 |
-
if C == 1:
|
| 107 |
-
return np.concatenate([x, x, x], axis=2)
|
| 108 |
-
if C == 4:
|
| 109 |
-
color = x[:, :, 0:3].astype(np.float32)
|
| 110 |
-
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
| 111 |
-
y = color * alpha + 255.0 * (1.0 - alpha)
|
| 112 |
-
y = y.clip(0, 255).astype(np.uint8)
|
| 113 |
-
return y
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
|
| 117 |
-
H, W, C = input_image.shape
|
| 118 |
-
H = float(H)
|
| 119 |
-
W = float(W)
|
| 120 |
-
H *= upscale
|
| 121 |
-
W *= upscale
|
| 122 |
-
if min_size is not None:
|
| 123 |
-
if min(H, W) < min_size:
|
| 124 |
-
_upsacle = min_size / min(W, H)
|
| 125 |
-
W *= _upsacle
|
| 126 |
-
H *= _upsacle
|
| 127 |
-
H = int(np.round(H / unit_resolution)) * unit_resolution
|
| 128 |
-
W = int(np.round(W / unit_resolution)) * unit_resolution
|
| 129 |
-
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
|
| 130 |
-
img = img.round().clip(0, 255).astype(np.uint8)
|
| 131 |
-
return img
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def fix_resize(input_image, size=512, unit_resolution=64):
|
| 135 |
-
H, W, C = input_image.shape
|
| 136 |
-
H = float(H)
|
| 137 |
-
W = float(W)
|
| 138 |
-
upscale = size / min(H, W)
|
| 139 |
-
H *= upscale
|
| 140 |
-
W *= upscale
|
| 141 |
-
H = int(np.round(H / unit_resolution)) * unit_resolution
|
| 142 |
-
W = int(np.round(W / unit_resolution)) * unit_resolution
|
| 143 |
-
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
|
| 144 |
-
img = img.round().clip(0, 255).astype(np.uint8)
|
| 145 |
-
return img
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def Numpy2Tensor(img):
|
| 150 |
-
'''
|
| 151 |
-
np.array[H, w, C] [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
|
| 152 |
-
'''
|
| 153 |
-
# size
|
| 154 |
-
img = np.array(img) / 255 * 2 - 1
|
| 155 |
-
img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
|
| 156 |
-
return img
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
def Tensor2Numpy(x, h0=None, w0=None):
|
| 160 |
-
'''
|
| 161 |
-
Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
|
| 162 |
-
'''
|
| 163 |
-
if h0 is not None and w0 is not None:
|
| 164 |
-
x = x.unsqueeze(0)
|
| 165 |
-
x = interpolate(x, size=(h0, w0), mode='bicubic')
|
| 166 |
-
x = x.squeeze(0)
|
| 167 |
-
x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 168 |
-
return x
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
def convert_dtype(dtype_str):
|
| 172 |
-
if dtype_str == 'fp32':
|
| 173 |
-
return torch.float32
|
| 174 |
-
elif dtype_str == 'fp16':
|
| 175 |
-
return torch.float16
|
| 176 |
-
elif dtype_str == 'bf16':
|
| 177 |
-
return torch.bfloat16
|
| 178 |
-
else:
|
| 179 |
-
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/utils/__init__.py
DELETED
|
File without changes
|
SUPIR/utils/colorfix.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
'''
|
| 2 |
-
# --------------------------------------------------------------------------------
|
| 3 |
-
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
|
| 4 |
-
# --------------------------------------------------------------------------------
|
| 5 |
-
'''
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
from PIL import Image
|
| 9 |
-
from torch import Tensor
|
| 10 |
-
from torch.nn import functional as F
|
| 11 |
-
|
| 12 |
-
from torchvision.transforms import ToTensor, ToPILImage
|
| 13 |
-
|
| 14 |
-
def adain_color_fix(target: Image, source: Image):
|
| 15 |
-
# Convert images to tensors
|
| 16 |
-
to_tensor = ToTensor()
|
| 17 |
-
target_tensor = to_tensor(target).unsqueeze(0)
|
| 18 |
-
source_tensor = to_tensor(source).unsqueeze(0)
|
| 19 |
-
|
| 20 |
-
# Apply adaptive instance normalization
|
| 21 |
-
result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
|
| 22 |
-
|
| 23 |
-
# Convert tensor back to image
|
| 24 |
-
to_image = ToPILImage()
|
| 25 |
-
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
| 26 |
-
|
| 27 |
-
return result_image
|
| 28 |
-
|
| 29 |
-
def wavelet_color_fix(target: Image, source: Image):
|
| 30 |
-
# Convert images to tensors
|
| 31 |
-
to_tensor = ToTensor()
|
| 32 |
-
target_tensor = to_tensor(target).unsqueeze(0)
|
| 33 |
-
source_tensor = to_tensor(source).unsqueeze(0)
|
| 34 |
-
|
| 35 |
-
# Apply wavelet reconstruction
|
| 36 |
-
result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
|
| 37 |
-
|
| 38 |
-
# Convert tensor back to image
|
| 39 |
-
to_image = ToPILImage()
|
| 40 |
-
result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
|
| 41 |
-
|
| 42 |
-
return result_image
|
| 43 |
-
|
| 44 |
-
def calc_mean_std(feat: Tensor, eps=1e-5):
|
| 45 |
-
"""Calculate mean and std for adaptive_instance_normalization.
|
| 46 |
-
Args:
|
| 47 |
-
feat (Tensor): 4D tensor.
|
| 48 |
-
eps (float): A small value added to the variance to avoid
|
| 49 |
-
divide-by-zero. Default: 1e-5.
|
| 50 |
-
"""
|
| 51 |
-
size = feat.size()
|
| 52 |
-
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
| 53 |
-
b, c = size[:2]
|
| 54 |
-
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
|
| 55 |
-
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
|
| 56 |
-
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
|
| 57 |
-
return feat_mean, feat_std
|
| 58 |
-
|
| 59 |
-
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
|
| 60 |
-
"""Adaptive instance normalization.
|
| 61 |
-
Adjust the reference features to have the similar color and illuminations
|
| 62 |
-
as those in the degradate features.
|
| 63 |
-
Args:
|
| 64 |
-
content_feat (Tensor): The reference feature.
|
| 65 |
-
style_feat (Tensor): The degradate features.
|
| 66 |
-
"""
|
| 67 |
-
size = content_feat.size()
|
| 68 |
-
style_mean, style_std = calc_mean_std(style_feat)
|
| 69 |
-
content_mean, content_std = calc_mean_std(content_feat)
|
| 70 |
-
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
| 71 |
-
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
| 72 |
-
|
| 73 |
-
def wavelet_blur(image: Tensor, radius: int):
|
| 74 |
-
"""
|
| 75 |
-
Apply wavelet blur to the input tensor.
|
| 76 |
-
"""
|
| 77 |
-
# input shape: (1, 3, H, W)
|
| 78 |
-
# convolution kernel
|
| 79 |
-
kernel_vals = [
|
| 80 |
-
[0.0625, 0.125, 0.0625],
|
| 81 |
-
[0.125, 0.25, 0.125],
|
| 82 |
-
[0.0625, 0.125, 0.0625],
|
| 83 |
-
]
|
| 84 |
-
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
| 85 |
-
# add channel dimensions to the kernel to make it a 4D tensor
|
| 86 |
-
kernel = kernel[None, None]
|
| 87 |
-
# repeat the kernel across all input channels
|
| 88 |
-
kernel = kernel.repeat(3, 1, 1, 1)
|
| 89 |
-
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
| 90 |
-
# apply convolution
|
| 91 |
-
output = F.conv2d(image, kernel, groups=3, dilation=radius)
|
| 92 |
-
return output
|
| 93 |
-
|
| 94 |
-
def wavelet_decomposition(image: Tensor, levels=5):
|
| 95 |
-
"""
|
| 96 |
-
Apply wavelet decomposition to the input tensor.
|
| 97 |
-
This function only returns the low frequency & the high frequency.
|
| 98 |
-
"""
|
| 99 |
-
high_freq = torch.zeros_like(image)
|
| 100 |
-
for i in range(levels):
|
| 101 |
-
radius = 2 ** i
|
| 102 |
-
low_freq = wavelet_blur(image, radius)
|
| 103 |
-
high_freq += (image - low_freq)
|
| 104 |
-
image = low_freq
|
| 105 |
-
|
| 106 |
-
return high_freq, low_freq
|
| 107 |
-
|
| 108 |
-
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
|
| 109 |
-
"""
|
| 110 |
-
Apply wavelet decomposition, so that the content will have the same color as the style.
|
| 111 |
-
"""
|
| 112 |
-
# calculate the wavelet decomposition of the content feature
|
| 113 |
-
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
| 114 |
-
del content_low_freq
|
| 115 |
-
# calculate the wavelet decomposition of the style feature
|
| 116 |
-
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
| 117 |
-
del style_high_freq
|
| 118 |
-
# reconstruct the content feature with the style's high frequency
|
| 119 |
-
return content_high_freq + style_low_freq
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/utils/devices.py
DELETED
|
@@ -1,138 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
import contextlib
|
| 3 |
-
from functools import lru_cache
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
#from modules import errors
|
| 7 |
-
|
| 8 |
-
if sys.platform == "darwin":
|
| 9 |
-
from modules import mac_specific
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
def has_mps() -> bool:
|
| 13 |
-
if sys.platform != "darwin":
|
| 14 |
-
return False
|
| 15 |
-
else:
|
| 16 |
-
return mac_specific.has_mps
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def get_cuda_device_string():
|
| 20 |
-
return "cuda"
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def get_optimal_device_name():
|
| 24 |
-
if torch.cuda.is_available():
|
| 25 |
-
return get_cuda_device_string()
|
| 26 |
-
|
| 27 |
-
if has_mps():
|
| 28 |
-
return "mps"
|
| 29 |
-
|
| 30 |
-
return "cpu"
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def get_optimal_device():
|
| 34 |
-
return torch.device(get_optimal_device_name())
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def get_device_for(task):
|
| 38 |
-
return get_optimal_device()
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def torch_gc():
|
| 42 |
-
|
| 43 |
-
if torch.cuda.is_available():
|
| 44 |
-
with torch.cuda.device(get_cuda_device_string()):
|
| 45 |
-
torch.cuda.empty_cache()
|
| 46 |
-
torch.cuda.ipc_collect()
|
| 47 |
-
|
| 48 |
-
if has_mps():
|
| 49 |
-
mac_specific.torch_mps_gc()
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def enable_tf32():
|
| 53 |
-
if torch.cuda.is_available():
|
| 54 |
-
|
| 55 |
-
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
| 56 |
-
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
| 57 |
-
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
| 58 |
-
torch.backends.cudnn.benchmark = True
|
| 59 |
-
|
| 60 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 61 |
-
torch.backends.cudnn.allow_tf32 = True
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
enable_tf32()
|
| 65 |
-
#errors.run(enable_tf32, "Enabling TF32")
|
| 66 |
-
|
| 67 |
-
cpu = torch.device("cpu")
|
| 68 |
-
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
|
| 69 |
-
dtype = torch.float16
|
| 70 |
-
dtype_vae = torch.float16
|
| 71 |
-
dtype_unet = torch.float16
|
| 72 |
-
unet_needs_upcast = False
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def cond_cast_unet(input):
|
| 76 |
-
return input.to(dtype_unet) if unet_needs_upcast else input
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def cond_cast_float(input):
|
| 80 |
-
return input.float() if unet_needs_upcast else input
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def randn(seed, shape):
|
| 84 |
-
torch.manual_seed(seed)
|
| 85 |
-
return torch.randn(shape, device=device)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def randn_without_seed(shape):
|
| 89 |
-
return torch.randn(shape, device=device)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def autocast(disable=False):
|
| 93 |
-
if disable:
|
| 94 |
-
return contextlib.nullcontext()
|
| 95 |
-
|
| 96 |
-
return torch.autocast("cuda")
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def without_autocast(disable=False):
|
| 100 |
-
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
class NansException(Exception):
|
| 104 |
-
pass
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def test_for_nans(x, where):
|
| 108 |
-
if not torch.all(torch.isnan(x)).item():
|
| 109 |
-
return
|
| 110 |
-
|
| 111 |
-
if where == "unet":
|
| 112 |
-
message = "A tensor with all NaNs was produced in Unet."
|
| 113 |
-
|
| 114 |
-
elif where == "vae":
|
| 115 |
-
message = "A tensor with all NaNs was produced in VAE."
|
| 116 |
-
|
| 117 |
-
else:
|
| 118 |
-
message = "A tensor with all NaNs was produced."
|
| 119 |
-
|
| 120 |
-
message += " Use --disable-nan-check commandline argument to disable this check."
|
| 121 |
-
|
| 122 |
-
raise NansException(message)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
@lru_cache
def first_time_calculation():
    """
    Run one tiny Linear and one tiny Conv2d forward pass on ``device``/``dtype``.

    Per the original author's note, the first calculation with pytorch layers
    allocates about 700MB of memory and takes about 2.7 seconds, at least with
    NVidia. ``lru_cache`` makes every later call return the cached result
    without repeating the work.
    """

    linear_input = torch.zeros((1, 1)).to(device, dtype)
    linear_layer = torch.nn.Linear(1, 1).to(device, dtype)
    linear_layer(linear_input)

    conv_input = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv_layer = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv_layer(conv_input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/utils/face_restoration_helper.py
DELETED
|
@@ -1,514 +0,0 @@
|
|
| 1 |
-
import cv2
|
| 2 |
-
import numpy as np
|
| 3 |
-
import os
|
| 4 |
-
import torch
|
| 5 |
-
from torchvision.transforms.functional import normalize
|
| 6 |
-
|
| 7 |
-
from facexlib.detection import init_detection_model
|
| 8 |
-
from facexlib.parsing import init_parsing_model
|
| 9 |
-
from facexlib.utils.misc import img2tensor, imwrite
|
| 10 |
-
|
| 11 |
-
from .file import load_file_from_url
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def get_largest_face(det_faces, h, w):
    """Pick the detection whose box, clipped to the image, has the largest area.

    Args:
        det_faces: sequence of boxes indexed as (x1, y1, x2, y2, ...).
        h: image height used to clip y coordinates.
        w: image width used to clip x coordinates.

    Returns:
        Tuple ``(face, index)`` for the first box with the maximum area.
    """

    def clip(coord, upper):
        # Clamp a coordinate into [0, upper].
        if coord < 0:
            return 0
        return upper if coord > upper else coord

    areas = [
        (clip(face[2], w) - clip(face[0], w)) * (clip(face[3], h) - clip(face[1], h))
        for face in det_faces
    ]
    best = areas.index(max(areas))
    return det_faces[best], best
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def get_center_face(det_faces, h=0, w=0, center=None):
    """Pick the detection whose box centre is closest to a reference point.

    Args:
        det_faces: sequence of boxes indexed as (x1, y1, x2, y2, ...).
        h: image height; used for the default reference point.
        w: image width; used for the default reference point.
        center: optional (x, y) reference point; defaults to (w / 2, h / 2).

    Returns:
        Tuple ``(face, index)`` for the first box at the minimum distance.
    """
    target = np.array([w / 2, h / 2]) if center is None else np.array(center)

    distances = [
        np.linalg.norm(np.array([(face[0] + face[2]) / 2, (face[1] + face[3]) / 2]) - target)
        for face in det_faces
    ]
    nearest = distances.index(min(distances))
    return det_faces[nearest], nearest
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class).

    Typical flow, as implemented by the methods below: ``read_image`` ->
    ``get_face_landmarks_5`` -> ``align_warp_face`` -> (restore each cropped
    face externally and register it with ``add_restored_face``) ->
    ``get_inverse_affine`` -> ``paste_faces_to_input_image``. Call
    ``clean_all`` before processing the next image.

    Args:
        upscale_factor: output scale relative to the input image (stored as int).
        face_size: side length of the square face crop before ``crop_ratio``.
        crop_ratio: (h, w) multipliers (each >= 1) enlarging the crop.
        det_model: detection model name passed to ``init_detection_model``;
            ``'dlib'`` selects the dlib template and landmark path.
        save_ext: file extension used when saving crops / results.
        template_3points: use a 3-point instead of the 5-point template.
        pad_blur: pad and blur the input around each face before cropping;
            forces ``template_3points`` off.
        use_parse: refine the paste mask with the face parsing model.
        device: torch device; defaults to CUDA when available, else CPU.
    """

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = int(upscale_factor)
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
        # face_size is stored as (width, height), the order cv2.warpAffine expects
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
        self.det_model = det_model

        if self.det_model == 'dlib':
            # standard 5 landmarks for FFHQ faces with 1024 x 1024
            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
                                           [513.58415842, 678.5049505]])
            self.face_template = self.face_template / (1024 // face_size)
        elif self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            # facexlib
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])

            # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
            # [198.22603, 372.82502], [313.91018, 372.75659]])

        # rescale the template to the requested face size, then shift it so the
        # face stays centred inside an enlarged (crop_ratio > 1) crop
        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            # the pad/blur code reads landmarks 3 and 4, so it needs all five points
            self.template_3points = False

        # per-image state, reset by clean_all()
        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # self.device = get_device()
        else:
            self.device = device

        # init face detection model
        self.face_detector = init_detection_model(det_model, half=False, device=self.device)

        # init face parsing model (loaded regardless of use_parse)
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)

    def set_upscale_factor(self, upscale_factor):
        """Replace the upscale factor (stored as given, without int conversion)."""
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be image path or cv2 loaded image."""
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img
        # self.is_gray = is_gray(img, threshold=10)
        # if self.is_gray:
        # print('Grayscale input: True')

        # upscale small inputs so the shorter side is at least 512 px
        if min(self.input_img.shape[:2]) < 512:
            f = 512.0 / min(self.input_img.shape[:2])
            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)

    def init_dlib(self, detection_path, landmark5_path):
        """Initialize the dlib detectors and predictors.

        Args:
            detection_path: URL of the dlib CNN face detector weights.
            landmark5_path: URL of the dlib 5-point shape predictor weights.

        Returns:
            Tuple ``(face_detector, shape_predictor_5)``.

        Raises:
            ImportError: if dlib is not installed.
        """
        try:
            import dlib
        except ImportError as err:
            # Previously this only printed the hint and then failed on the
            # unbound `dlib` name a few lines later.
            raise ImportError('Please install dlib by running: conda install -c conda-forge dlib') from err
        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        return face_detector, shape_predictor_5

    def get_face_landmarks_5_dlib(self,
                                  only_keep_largest=False,
                                  scale=1):
        """Detect faces and their 5 landmarks through the dlib path.

        Returns the number of landmark sets accumulated in
        ``self.all_landmarks_5`` (0 when no face is detected).

        NOTE(review): reads ``self.shape_predictor_5``, which is not assigned
        anywhere in this class; presumably the caller sets it from
        ``init_dlib`` -- confirm.
        """
        det_faces = self.face_detector(self.input_img, scale)

        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            return 0
        else:
            if only_keep_largest:
                print('Detect several faces and only keep the largest.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces

        if len(self.det_faces) == 0:
            return 0

        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)

        return len(self.all_landmarks_5)

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        """Detect faces in ``self.input_img`` and collect their landmarks.

        Args:
            only_keep_largest: keep only the face with the largest box.
            only_center_face: keep only the face closest to the image centre
                (ignored when ``only_keep_largest`` is set).
            resize: if given, upscale the image for detection so its shorter
                side is at least this many pixels (never downscales).
            blur_ratio: box-filter size for pad_blur, as a fraction of the
                oriented crop side length.
            eye_dist_threshold: drop faces whose eye distance is below this.

        Returns:
            Number of landmark sets in ``self.all_landmarks_5``; 0 if none.
        """
        if self.det_model == 'dlib':
            return self.get_face_landmarks_5_dlib(only_keep_largest)

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            scale = max(1, scale)  # always scale up
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        with torch.no_grad():
            bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        else:
            # map detections back to the coordinates of self.input_img
            bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            # Landmarks are stored as (x, y) pairs from index 5 on (see the
            # range(5, ..., 2) extraction below), so the two eyes are (5, 6)
            # and (7, 8). The previous indices (6-8, 7-9) mixed y with x and
            # an eye with the third landmark.
            eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        # NOTE(review): in both branches below self.det_faces becomes a single
        # bbox rather than a one-element list (the dlib path keeps a list).
        # Left unchanged because external callers may rely on it -- confirm.
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                # pad[2] is the quad's max x, so it is compared with the image
                # width (shape[1]); pad[3] is the max y, compared with the
                # height (shape[0]). These were swapped, which mis-sized the
                # right/bottom padding on non-square images.
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[1] + border, 1),
                    max(pad[3] - self.input_img.shape[0] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords (in place, so the stored landmarks
                    # now refer to the padded image)
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with face template.

        Appends one affine matrix and one cropped face per landmark set.

        Args:
            save_cropped_path: if given, each crop is saved next to this path
                with an ``_{idx:02d}`` suffix and ``self.save_ext``.
            border_mode: ``'constant'``, ``'reflect101'`` or ``'reflect'``.
        """
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            # (after the first iteration border_mode already holds the cv2
            # constant, so none of these string comparisons match again)
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix.

        Appends, for every stored affine matrix, its inverse multiplied by
        ``self.upscale_factor``; optionally saves each one with ``torch.save``.
        """
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        """Register one restored face; ``input_face`` is currently unused."""
        # if self.is_gray:
        # restored_face = bgr2gray(restored_face) # convert img into grayscale
        # if input_face is not None:
        # restored_face = adain_npy(restored_face, input_face) # transfer the color
        self.restored_faces.append(restored_face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        """Warp the restored faces back and blend them into the upscaled image.

        Args:
            save_path: if given, the result is written there (extension
                replaced by ``self.save_ext``).
            upsample_img: pre-upscaled background; when None the input image
                is resized linearly.
            draw_box: draw a green border around every pasted face.
            face_upsampler: optional object whose ``enhance(img, outscale=...)``
                upsamples each restored face before pasting.

        Returns:
            The blended image.
        """
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')

        inv_mask_borders = []
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            # Work on a copy: the adjustments below used to modify the stored
            # matrix in place, so calling this method twice pasted the faces
            # at a different position the second time.
            inverse_affine = inverse_affine.copy()
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # always use square mask
            mask = np.ones(face_size, dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                h, w = face_size
                mask_border = np.ones((h, w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:h - border, border:w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area ** 0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
            inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                # map each parsing class index to 0 (excluded) or 255 (kept)
                parse_mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    parse_mask[out == idx] = color
                # blur the mask
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                parse_mask[:thres, :] = 0
                parse_mask[-thres:, :] = 0
                parse_mask[:, :thres] = 0
                parse_mask[:, -thres:] = 0
                parse_mask = parse_mask / 255.

                parse_mask = cv2.resize(parse_mask, face_size)
                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_parse_mask = parse_mask[:, :, None]
                # pasted_face = inv_restored
                # per pixel, keep the smaller of the parse mask and the square mask
                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        """Reset all per-image state so the helper can process another image."""
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/utils/file.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from typing import List, Tuple
|
| 3 |
-
|
| 4 |
-
from urllib.parse import urlparse
|
| 5 |
-
from torch.hub import download_url_to_file, get_dir
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def load_file_list(file_list_path: str) -> List[str]:
    """Read a file list: one image path per line, blank lines skipped.

    Each line is stripped of surrounding whitespace before being kept.
    """
    with open(file_list_path, "r") as list_file:
        stripped = (raw.strip() for raw in list_file)
        return [entry for entry in stripped if entry]
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def list_image_files(
    img_dir: str,
    exts: Tuple[str]=(".jpg", ".png", ".jpeg"),
    follow_links: bool=False,
    log_progress: bool=False,
    log_every_n_files: int=10000,
    max_size: int=-1
) -> List[str]:
    """Recursively collect image files under ``img_dir``.

    Args:
        img_dir: root directory to walk.
        exts: accepted extensions, matched against the lower-cased suffix.
        follow_links: forwarded to ``os.walk(followlinks=...)``.
        log_progress: print a progress line every ``log_every_n_files`` hits.
        log_every_n_files: progress interval.
        max_size: stop once this many files are collected; negative = no limit.

    Returns:
        List of matching file paths (directory joined with file name).
    """
    found = []
    for root, _, names in os.walk(img_dir, followlinks=follow_links):
        for name in names:
            suffix = os.path.splitext(name)[1].lower()
            if suffix not in exts:
                continue
            # The limit is only checked when another match turns up.
            if max_size >= 0 and len(found) >= max_size:
                return found
            found.append(os.path.join(root, name))
            if log_progress and len(found) % log_every_n_files == 0:
                print(f"find {len(found)} images in {img_dir}")
    return found
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def get_file_name_parts(file_path: str) -> Tuple[str, str, str]:
    """Split a path into ``(parent directory, stem, extension)``."""
    directory, base_name = os.path.split(file_path)
    return (directory, *os.path.splitext(base_name))
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
|
| 50 |
-
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.

    Raises:
        ValueError: If no target file name can be determined (``file_name``
            is empty/None and the URL path has no final component).
    """
    filename = file_name if file_name is not None else os.path.basename(urlparse(url).path)
    # An empty name would make the target path equal to `model_dir` itself,
    # which exists, so the directory would be returned as the "cached file".
    if not filename:
        raise ValueError(
            f'Cannot derive a file name from url "{url}"; pass file_name explicitly.')

    if model_dir is None:  # use the pytorch hub_dir
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    # Only hit the network when the file is not already cached.
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUPIR/utils/tilevae.py
DELETED
|
@@ -1,971 +0,0 @@
|
|
| 1 |
-
# ------------------------------------------------------------------------
|
| 2 |
-
#
|
| 3 |
-
# Ultimate VAE Tile Optimization
|
| 4 |
-
#
|
| 5 |
-
# Introducing a revolutionary new optimization designed to make
|
| 6 |
-
# the VAE work with giant images on limited VRAM!
|
| 7 |
-
# Say goodbye to the frustration of OOM and hello to seamless output!
|
| 8 |
-
#
|
| 9 |
-
# ------------------------------------------------------------------------
|
| 10 |
-
#
|
| 11 |
-
# This script is a wild hack that splits the image into tiles,
|
| 12 |
-
# encodes each tile separately, and merges the result back together.
|
| 13 |
-
#
|
| 14 |
-
# Advantages:
|
| 15 |
-
# - The VAE can now work with giant images on limited VRAM
|
| 16 |
-
# (~10 GB for 8K images!)
|
| 17 |
-
# - The merged output is completely seamless without any post-processing.
|
| 18 |
-
#
|
| 19 |
-
# Drawbacks:
|
| 20 |
-
# - Giant RAM needed. To store the intermediate results for a 4096x4096
|
| 21 |
-
# images, you need 32 GB RAM (it consumes ~20GB); for 8192x8192
|
| 22 |
-
# you need 128 GB RAM machine (it consumes ~100 GB)
|
| 23 |
-
# - NaNs always appear for 8k images when you use fp16 (half) VAE
|
| 24 |
-
# You must use --no-half-vae to disable half VAE for that giant image.
|
| 25 |
-
# - Slow speed. With default tile size, it takes around 50/200 seconds
|
| 26 |
-
# to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode
|
| 27 |
-
# a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
|
| 28 |
-
# - The gradient calculation is not compatible with this hack. It
|
| 29 |
-
# will break any backward() or torch.autograd.grad() that passes VAE.
|
| 30 |
-
# (But you can still use the VAE to generate training data.)
|
| 31 |
-
#
|
| 32 |
-
# How it works:
|
| 33 |
-
# 1) The image is split into tiles.
|
| 34 |
-
# - To ensure perfect results, each tile is padded with 32 pixels
|
| 35 |
-
# on each side.
|
| 36 |
-
# - Then the conv2d/silu/upsample/downsample can produce identical
|
| 37 |
-
# results to the original image without splitting.
|
| 38 |
-
# 2) The original forward is decomposed into a task queue and a task worker.
|
| 39 |
-
# - The task queue is a list of functions that will be executed in order.
|
| 40 |
-
# - The task worker is a loop that executes the tasks in the queue.
|
| 41 |
-
# 3) The task queue is executed for each tile.
|
| 42 |
-
# - Current tile is sent to GPU.
|
| 43 |
-
# - local operations are directly executed.
|
| 44 |
-
# - Group norm calculation is temporarily suspended until the mean
|
| 45 |
-
# and var of all tiles are calculated.
|
| 46 |
-
# - The residual is pre-calculated and stored and added back later.
|
| 47 |
-
# - When we need to go to the next tile, the current tile is sent to cpu.
|
| 48 |
-
# 4) After all tiles are processed, tiles are merged on cpu and return.
|
| 49 |
-
#
|
| 50 |
-
# Enjoy!
|
| 51 |
-
#
|
| 52 |
-
# @author: LI YI @ Nanyang Technological University - Singapore
|
| 53 |
-
# @date: 2023-03-02
|
| 54 |
-
# @license: MIT License
|
| 55 |
-
#
|
| 56 |
-
# Please give me a star if you like this project!
|
| 57 |
-
#
|
| 58 |
-
# -------------------------------------------------------------------------
|
| 59 |
-
|
| 60 |
-
import gc
|
| 61 |
-
from time import time
|
| 62 |
-
import math
|
| 63 |
-
from tqdm import tqdm
|
| 64 |
-
|
| 65 |
-
import torch
|
| 66 |
-
import torch.version
|
| 67 |
-
import torch.nn.functional as F
|
| 68 |
-
from einops import rearrange
|
| 69 |
-
from diffusers.utils.import_utils import is_xformers_available
|
| 70 |
-
|
| 71 |
-
import SUPIR.utils.devices as devices
|
| 72 |
-
|
| 73 |
-
try:
|
| 74 |
-
import xformers
|
| 75 |
-
import xformers.ops
|
| 76 |
-
except ImportError:
|
| 77 |
-
pass
|
| 78 |
-
|
| 79 |
-
# Selects which VAE attribute layout the task-queue builders walk: when truthy
# they read `net.mid` / `net.up` / `net.down`; otherwise they read
# `net.mid_block` / `net.up_blocks` / `net.down_blocks`.
sd_flag = True
|
| 80 |
-
|
| 81 |
-
def get_recommend_encoder_tile_size():
    """
    Pick an encoder tile size from the total memory of the configured CUDA device.

    @return: 512 when CUDA is unavailable; otherwise 3072 / 2048 / 1536 / 960
             depending on how many MiB the device reports
    """
    if not torch.cuda.is_available():
        return 512

    device_props = torch.cuda.get_device_properties(devices.device)
    total_memory = device_props.total_memory // 2**20  # bytes -> MiB

    # (lower bound in thousands of MiB, tile size), checked from largest to smallest
    for bound, tile_size in ((16, 3072), (12, 2048), (8, 1536)):
        if total_memory > bound * 1000:
            return tile_size
    return 960
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def get_recommend_decoder_tile_size():
    """
    Pick a decoder tile size from the total memory of the configured CUDA device.

    @return: 64 when CUDA is unavailable; otherwise 256 / 192 / 128 / 96 / 64
             depending on how many MiB the device reports
    """
    if not torch.cuda.is_available():
        return 64

    device_props = torch.cuda.get_device_properties(devices.device)
    total_memory = device_props.total_memory // 2**20  # bytes -> MiB

    # (lower bound in thousands of MiB, tile size), checked from largest to smallest
    for bound, tile_size in ((30, 256), (16, 192), (12, 128), (8, 96)):
        if total_memory > bound * 1000:
            return tile_size
    return 64
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# ---- global const ----
# Module-level defaults; the two tile sizes are computed once at import time.
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# inplace version of silu
|
| 128 |
-
def inplace_nonlinearity(x):
    """
    Apply SiLU to `x` in place.

    @param x: tensor to activate; it is overwritten
    @return: the result of the in-place SiLU call
    """
    # Test: fix for Nans
    activated = F.silu(x, inplace=True)
    return activated
|
| 131 |
-
|
| 132 |
-
# extracted from ldm.modules.diffusionmodules.model
|
| 133 |
-
|
| 134 |
-
# from diffusers lib
|
| 135 |
-
def attn_forward_new(self, h_):
    """
    Self-attention over a feature map through the module's to_q/to_k/to_v projections.

    @param self: attention module providing `prepare_attention_mask`, `to_q`, `to_k`,
                 `to_v`, `head_to_batch_dim`, `get_attention_scores`,
                 `batch_to_head_dim` and `to_out`
    @param h_: 4-D input whose last two dims are flattened into one token axis
    @return: tensor reshaped back to the shape of `h_`
    """
    n, c, rows, cols = h_.shape
    tokens = h_.view(n, c, rows * cols).transpose(1, 2)

    # No mask and no separate context are supplied on this path, so keys and
    # values are projected from the same tokens as the queries.
    n, token_count, _ = tokens.shape
    mask = self.prepare_attention_mask(None, token_count, n)

    query = self.to_q(tokens)
    key = self.to_k(tokens)
    value = self.to_v(tokens)

    query = self.head_to_batch_dim(query)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    probs = self.get_attention_scores(query, key, mask)
    attended = self.batch_to_head_dim(torch.bmm(probs, value))

    # to_out[0] is the linear projection, to_out[1] the dropout
    projected = self.to_out[1](self.to_out[0](attended))

    return projected.transpose(-1, -2).reshape(n, c, rows, cols)
|
| 170 |
-
|
| 171 |
-
def attn_forward_new_pt2_0(self, hidden_states):
    """
    Self-attention computed with `F.scaled_dot_product_attention`.

    @param self: attention module providing `heads`, `group_norm`, `to_q`, `to_k`,
                 `to_v` (each called with a `scale` keyword) and `to_out`
    @param hidden_states: 4-D feature map or 3-D token tensor
    @return: attention output; reshaped back to 4-D when the input was 4-D
    """
    scale = 1
    was_feature_map = hidden_states.ndim == 4

    if was_feature_map:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    # No attention mask and no separate context on this path.
    batch_size, _, _ = hidden_states.shape

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    head_dim = key.shape[-1] // self.heads

    def split_heads(t):
        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        return t.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    query, key, value = split_heads(query), split_heads(key), split_heads(value)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    attended = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    attended = attended.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
    attended = attended.to(query.dtype)

    # to_out[0] is the linear projection, to_out[1] the dropout
    attended = self.to_out[0](attended, scale=scale)
    attended = self.to_out[1](attended)

    if was_feature_map:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return attended
|
| 231 |
-
|
| 232 |
-
def attn_forward_new_xformers(self, hidden_states):
    """
    Self-attention computed with `xformers.ops.memory_efficient_attention`.

    @param self: attention module providing `prepare_attention_mask`, `group_norm`,
                 `to_q`, `to_k`, `to_v` (each called with a `scale` keyword),
                 `head_to_batch_dim`, `batch_to_head_dim` and `to_out`
    @param hidden_states: 4-D feature map or 3-D token tensor
    @return: attention output; reshaped back to 4-D when the input was 4-D
    """
    scale = 1
    was_feature_map = hidden_states.ndim == 4

    if was_feature_map:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    # No separate context on this path, so the key count comes from the input itself.
    batch_size, key_tokens, _ = hidden_states.shape

    attention_mask = self.prepare_attention_mask(None, key_tokens, batch_size)
    if attention_mask is not None:
        # expand our mask's singleton query_tokens dimension:
        #   [batch*heads,            1, key_tokens] ->
        #   [batch*heads, query_tokens, key_tokens]
        # so that it can be added as a bias onto the attention scores that xformers computes.
        # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
        query_tokens = hidden_states.shape[1]
        attention_mask = attention_mask.expand(-1, query_tokens, -1)

    if self.group_norm is not None:
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    query = self.head_to_batch_dim(query).contiguous()
    key = self.head_to_batch_dim(key).contiguous()
    value = self.head_to_batch_dim(value).contiguous()

    attended = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=None
    )
    attended = attended.to(query.dtype)
    attended = self.batch_to_head_dim(attended)

    # to_out[0] is the linear projection, to_out[1] the dropout
    attended = self.to_out[0](attended, scale=scale)
    attended = self.to_out[1](attended)

    if was_feature_map:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return attended
|
| 291 |
-
|
| 292 |
-
def attn_forward(self, h_):
    """
    Single-head attention over all spatial positions of a feature map.

    @param self: block providing the `q`, `k`, `v` and `proj_out` callables
    @param h_: input feature map; the projections must return 4-D tensors
    @return: `proj_out` applied to the attended feature map
    """
    queries, keys, values = self.q(h_), self.k(h_), self.v(h_)

    b, c, rows, cols = queries.shape
    positions = rows * cols

    q_flat = queries.reshape(b, c, positions).transpose(1, 2)  # b,hw,c
    k_flat = keys.reshape(b, c, positions)  # b,c,hw
    # scores[b,i,j] = sum_c q[b,i,c] * k[b,c,j], scaled by c ** -0.5
    scores = torch.bmm(q_flat, k_flat) * (int(c)**(-0.5))
    probs = torch.nn.functional.softmax(scores, dim=2)

    # attend to values: out[b,c,j] = sum_i v[b,c,i] * probs[b,j,i]
    v_flat = values.reshape(b, c, positions)
    attended = torch.bmm(v_flat, probs.transpose(1, 2))

    return self.proj_out(attended.reshape(b, c, rows, cols))
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
def xformer_attn_forward(self, h_):
    """
    Single-head attention over all spatial positions, computed by xformers.

    @param self: block providing `q`, `k`, `v`, `proj_out` and `attention_op`
    @param h_: input feature map; the projections must return 4-D tensors
    @return: `proj_out` applied to the attended feature map
    """
    queries = self.q(h_)
    keys = self.k(h_)
    values = self.v(h_)

    B, C, H, W = queries.shape

    def to_token_layout(t):
        # b c h w -> contiguous (b * 1 head, h*w tokens, c)
        t = rearrange(t, 'b c h w -> b (h w) c')
        token_count = t.shape[1]
        t = t.unsqueeze(3).reshape(B, token_count, 1, C)
        t = t.permute(0, 2, 1, 3).reshape(B * 1, token_count, C)
        return t.contiguous()

    queries, keys, values = (
        to_token_layout(queries),
        to_token_layout(keys),
        to_token_layout(values),
    )

    out = xformers.ops.memory_efficient_attention(
        queries, keys, values, attn_bias=None, op=self.attention_op)

    # undo the single-head packing, then restore the spatial layout
    token_count = out.shape[1]
    out = out.unsqueeze(0).reshape(B, 1, token_count, C)
    out = out.permute(0, 2, 1, 3).reshape(B, token_count, C)
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    return self.proj_out(out)
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
def attn2task(task_queue, net):
    """
    Turn an attention block into a sequence of tasks and append to the task queue

    The appended tasks are: store the residual, apply `net.norm`, run attention,
    add the residual back.

    @param task_queue: the target task queue
    @param net: attention block; must expose `norm`

    """
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))
    # `is_xformers_available` is a function: it has to be called. Testing the
    # bare function object is always truthy, which made the two fallback
    # branches below unreachable.
    # NOTE(review): the xformers path reads `net.q/k/v/proj_out`, while the
    # fallback paths read `net.to_q/to_k/to_v/to_out` -- confirm the block
    # type passed here supports the fallback before relying on it.
    if is_xformers_available():
        # task_queue.append(('attn', lambda x, net=net: attn_forward_new_xformers(net, x)))
        task_queue.append(
            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
    elif hasattr(F, "scaled_dot_product_attention"):
        task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
    else:
        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
    task_queue.append(['add_res', None])
|
| 373 |
-
|
| 374 |
-
def resblock2task(queue, block):
    """
    Append the tasks equivalent to one ResNetBlock forward pass to the queue

    The residual branch is the identity when input and output channel counts
    match; otherwise it is the block's `conv_shortcut` or `nin_shortcut`.

    @param queue: the target task queue
    @param block: ResNetBlock

    """
    if block.in_channels == block.out_channels:
        shortcut = lambda x: x
    else:
        # the flag that selects the shortcut is named differently per layout
        if sd_flag:
            use_conv = block.use_conv_shortcut
        else:
            use_conv = block.use_in_shortcut
        shortcut = block.conv_shortcut if use_conv else block.nin_shortcut

    queue.extend([
        ('store_res', shortcut),
        ('pre_norm', block.norm1),
        ('silu', inplace_nonlinearity),
        ('conv1', block.conv1),
        ('pre_norm', block.norm2),
        ('silu', inplace_nonlinearity),
        ('conv2', block.conv2),
        ['add_res', None],
    ])
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue

    For a decoder the middle blocks are queued first, followed by the
    up-sampling levels; for an encoder the down-sampling levels are queued
    first, followed by the middle blocks.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            # (a leftover debug `print(task_queue)` used to sit here)
            resblock2task(task_queue, net.mid.block_2)
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
            resolution_iter = (range(len(net.up_blocks)))  # net.num_resolutions = 3
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        if sd_flag:
            resolution_iter = range(net.num_resolutions)
            block_ids = net.num_res_blocks
            condition = net.num_resolutions - 1
            module = net.down
            func_name = 'downsample'
        else:
            resolution_iter = range(len(net.down_blocks))
            block_ids = 2
            condition = len(net.down_blocks) - 1
            module = net.down_blocks
            func_name = 'downsamplers'

    for i_level in resolution_iter:
        for i_block in range(block_ids):
            if sd_flag:
                resblock2task(task_queue, module[i_level].block[i_block])
            else:
                resblock2task(task_queue, module[i_level].resnets[i_block])
        # `condition` is the one level that has no resampling layer after it
        if i_level != condition:
            if sd_flag:
                task_queue.append((func_name, getattr(module[i_level], func_name)))
            else:
                if is_decoder:
                    task_queue.append((func_name, module[i_level].upsamplers[0]))
                else:
                    task_queue.append((func_name, module[i_level].downsamplers[0]))

    if not is_decoder:
        if sd_flag:
            resblock2task(task_queue, net.mid.block_1)
            attn2task(task_queue, net.mid.attn_1)
            resblock2task(task_queue, net.mid.block_2)
        else:
            resblock2task(task_queue, net.mid_block.resnets[0])
            attn2task(task_queue, net.mid_block.attentions[0])
            resblock2task(task_queue, net.mid_block.resnets[1])
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
def build_task_queue(net, is_decoder):
    """
    Assemble the complete task queue for one VAE encoder or decoder
    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = [('conv_in', net.conv_in)]

    # encoder and decoder share the same overall structure, so the sampling
    # stages are produced by one shared builder
    build_sampling(task_queue, net, is_decoder)

    if is_decoder and not sd_flag:
        # this layout gets both flags written onto the network here
        net.give_pre_end = False
        net.tanh_out = False

    emit_output_head = (not is_decoder) or (not net.give_pre_end)
    if emit_output_head:
        final_norm = net.norm_out if sd_flag else net.conv_norm_out
        task_queue.append(('pre_norm', final_norm))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
def clone_task_queue(task_queue):
    """
    Copy a task queue one level deep
    @param task_queue: the task queue to be cloned
    @return: a new list in which every task is a new list holding the same items
    """
    cloned = []
    for task in task_queue:
        cloned.append(list(task))
    return cloned
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Compute the biased variance and the mean of every (sample, group) pair

    @param input: tensor with batch on dim 0, channels on dim 1 and three trailing dims
                  after regrouping (i.e. a 4-D input)
    @param num_groups: number of channel groups
    @param eps: accepted but not used by this function
    @return: (var, mean), each holding one value per (sample, group) pair
    """
    batch, channels = input.size(0), input.size(1)
    group_width = int(channels/num_groups)
    spatial_dims = input.size()[2:]

    grouped = input.contiguous().view(
        1, int(batch * num_groups), group_width, *spatial_dims)
    # reduce over everything except the (sample, group) axis
    return torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Group norm that uses externally supplied statistics

    @param input: input tensor
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor
    """
    batch, channels = input.size(0), input.size(1)
    spatial_dims = input.size()[2:]
    group_width = int(channels/num_groups)

    # Fold (sample, group) into the channel axis so batch_norm in eval mode
    # normalizes each group with the given statistics.
    grouped = input.contiguous().view(
        1, int(batch * num_groups), group_width, *spatial_dims)
    normalized = F.batch_norm(grouped, mean, var, weight=None, bias=None,
                              training=False, momentum=0, eps=eps)
    normalized = normalized.view(batch, channels, *spatial_dims)

    # post affine transform, applied in place per channel
    if weight is not None:
        normalized.mul_(weight.view(1, -1, 1, 1))
    if bias is not None:
        normalized.add_(bias.view(1, -1, 1, 1))
    return normalized
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Cut the valid output region out of a processed tile
    @param x: input tile
    @param input_bbox: original input bounding box
    @param target_bbox: output bounding box
    @param is_decoder: when true the input box is multiplied by 8, otherwise
                       it is floor-divided by 8, before comparing with target_bbox
    @return: cropped tile
    """
    if is_decoder:
        padded_bbox = [coord * 8 for coord in input_bbox]
    else:
        padded_bbox = [coord // 8 for coord in input_bbox]

    # offsets of the target box relative to the rescaled input box
    left, right, top, bottom = (
        target_bbox[idx] - padded_bbox[idx] for idx in range(4))

    row_end = x.size(2) + bottom
    col_end = x.size(3) + right
    return x[:, :, top:row_end, left:col_end]
|
| 568 |
-
|
| 569 |
-
# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
def perfcount(fn):
    """
    Decorator that reports wall time (and peak CUDA memory when available)

    Before and after the wrapped call it runs `devices.torch_gc()` and
    `gc.collect()`; afterwards it prints a `[Tiled VAE]` summary line.

    @param fn: the function to measure
    @return: a wrapper that forwards all arguments and returns fn's result
    """
    # Local import keeps this block self-contained; `wraps` preserves the
    # wrapped function's name and docstring on the returned wrapper.
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
|
| 595 |
-
|
| 596 |
-
# copy end :)
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
class GroupNormParam:
    """
    Collects per-tile group-norm statistics and turns them into a norm function.

    Tiles are registered with `add_tile`; `summary` merges their statistics,
    weighted by tile pixel count, into one callable. `from_tile` builds the
    callable from a single tile directly.
    """

    def __init__(self):
        # one entry per registered tile
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        # affine parameters of the most recently registered layer
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        """
        Record the statistics of one tile and remember the layer's affine parameters
        """
        var, mean = get_var_mean(tile, 32)
        # For giant images, the variance can be larger than max float16
        # In this case we create a copy to float32
        overflowed = var.dtype == torch.float16 and var.isinf().any()
        if overflowed:
            var, mean = get_var_mean(tile.float(), 32)

        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2]*tile.shape[3])

        if not hasattr(layer, 'weight'):
            self.weight = None
            self.bias = None
        else:
            self.weight = layer.weight
            self.bias = layer.bias

    def summary(self):
        """
        summarize the mean and var and return a function
        that apply group norm on each tile
        """
        if not self.var_list:
            return None

        stacked_var = torch.vstack(self.var_list)
        stacked_mean = torch.vstack(self.mean_list)

        # Per-tile weights: pixel counts are first divided by the largest
        # count, then normalized so that they sum to one.
        largest = max(self.pixel_list)
        shares = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / largest
        shares = shares.unsqueeze(1) / torch.sum(shares)

        var = torch.sum(stacked_var * shares, dim=0)
        mean = torch.sum(stacked_mean * shares, dim=0)

        def apply_group_norm(x):
            # weight/bias are read from the instance at call time
            return custom_group_norm(x, 32, mean, var, self.weight, self.bias)
        return apply_group_norm

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
            # if it is a macbook, we need to convert back to float16
            if var.device.type == 'mps':
                # clamp to avoid overflow
                var = torch.clamp(var, 0, 60000).half()
                mean = mean.half()

        if not hasattr(norm, 'weight'):
            weight = None
            bias = None
        else:
            weight = norm.weight
            bias = norm.bias

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
class VAEHook:
|
| 678 |
-
def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
|
| 679 |
-
self.net = net # encoder | decoder
|
| 680 |
-
self.tile_size = tile_size
|
| 681 |
-
self.is_decoder = is_decoder
|
| 682 |
-
self.fast_mode = (fast_encoder and not is_decoder) or (
|
| 683 |
-
fast_decoder and is_decoder)
|
| 684 |
-
self.color_fix = color_fix and not is_decoder
|
| 685 |
-
self.to_gpu = to_gpu
|
| 686 |
-
self.pad = 11 if is_decoder else 32
|
| 687 |
-
|
| 688 |
-
def __call__(self, x):
|
| 689 |
-
B, C, H, W = x.shape
|
| 690 |
-
original_device = next(self.net.parameters()).device
|
| 691 |
-
try:
|
| 692 |
-
if self.to_gpu:
|
| 693 |
-
self.net.to(devices.get_optimal_device())
|
| 694 |
-
if max(H, W) <= self.pad * 2 + self.tile_size:
|
| 695 |
-
print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
|
| 696 |
-
return self.net.original_forward(x)
|
| 697 |
-
else:
|
| 698 |
-
return self.vae_tile_forward(x)
|
| 699 |
-
finally:
|
| 700 |
-
self.net.to(original_device)
|
| 701 |
-
|
| 702 |
-
def get_best_tile_size(self, lowerbound, upperbound):
    """
    Round a tile size up to a multiple of 32, 16, 8, 4 or 2 within a bound

    Divisors are tried from 32 downwards. If `lowerbound` is already a
    multiple of the current divisor it is returned unchanged; otherwise the
    next multiple is returned when it does not exceed `upperbound`.

    @param lowerbound: the smallest acceptable tile size
    @param upperbound: the largest acceptable tile size
    @return: the chosen tile size, or `lowerbound` if no divisor fits
    """
    for divider in (32, 16, 8, 4, 2):
        remainder = lowerbound % divider
        if remainder == 0:
            return lowerbound
        rounded_up = lowerbound - remainder + divider
        if rounded_up <= upperbound:
            return rounded_up
    return lowerbound
|
| 716 |
-
|
| 717 |
-
def split_tiles(self, h, w):
    """
    Lay out the tile grid for an input of size h x w.

    @param h: height of the image
    @param w: width of the image
    @return: (tile_input_bboxes, tile_output_bboxes), two lists of equal length
             in row-major order; every bbox is [x1, x2, y1, y2]. Input bboxes
             are in input coordinates, output bboxes are multiplied by 8 when
             `self.is_decoder` is set and floor-divided by 8 otherwise.
    """
    pad = self.pad
    nominal = self.tile_size
    inner_h = h - 2 * pad
    inner_w = w - 2 * pad

    # Always keep at least one tile per axis, even when the interior left
    # after removing the padding is empty (long and thin images).
    rows = max(math.ceil(inner_h / nominal), 1)
    cols = max(math.ceil(inner_w / nominal), 1)

    # Spread the interior evenly over the grid (auto-shrinks the tile), then
    # let get_best_tile_size adjust each edge length.
    tile_h = self.get_best_tile_size(math.ceil(inner_h / rows), nominal)
    tile_w = self.get_best_tile_size(math.ceil(inner_w / cols), nominal)

    print(f'[Tiled VAE]: split to {rows}x{cols} = {rows*cols} tiles. ' +
          f'Optimal tile size {tile_w}x{tile_h}, original tile size {nominal}x{nominal}')

    def core_span(index, step, extent):
        # Unpadded extent of tile `index` along one axis; the first tile
        # starts at `pad` because image borders need no extra padding.
        return pad + index * step, min(pad + (index + 1) * step, extent)

    def snap_to_border(lo, hi, extent):
        # A span that reaches into the border margin is stretched to the border.
        return (lo if lo > pad else 0), (hi if hi < extent - pad else extent)

    def grow(lo, hi, extent):
        # Widen a span by `pad` on both sides without leaving the image.
        return max(0, lo - pad), min(extent, hi + pad)

    input_boxes, output_boxes = [], []
    for row in range(rows):
        y1, y2 = core_span(row, tile_h, h)
        for col in range(cols):
            x1, x2 = core_span(col, tile_w, w)

            out_x1, out_x2 = snap_to_border(x1, x2, w)
            out_y1, out_y2 = snap_to_border(y1, y2, h)
            snapped = (out_x1, out_x2, out_y1, out_y2)
            # Convert to output coordinates.
            if self.is_decoder:
                output_boxes.append([v * 8 for v in snapped])
            else:
                output_boxes.append([v // 8 for v in snapped])

            in_x1, in_x2 = grow(x1, x2, w)
            in_y1, in_y2 = grow(y1, y2, h)
            input_boxes.append([in_x1, in_x2, in_y1, in_y2])

    return input_boxes, output_boxes
|
| 775 |
-
|
| 776 |
-
@torch.no_grad()
def estimate_group_norm(self, z, task_queue, color_fix):
    """
    Run `task_queue` on `z` up to its last 'pre_norm' task and freeze every
    'pre_norm' it passes into an 'apply_norm' task.

    The queue is edited in place: each visited 'pre_norm' entry is replaced by
    ('apply_norm', func) where func comes from GroupNormParam.from_tile, and
    'store_res' results are written into the matching 'add_res' entry.

    @param z: tensor the tasks are run on ('add_res' tasks add to it in place)
    @param task_queue: list of tasks; task[0] is the task name, task[1] its payload
    @param color_fix: when truthy, stop at the first 'downsample' task after
                      re-tagging the remaining 'store_res' tasks (up to the last
                      'pre_norm') as 'store_res_cpu'
    @return: True when estimation finished, False when the NaN check on an
             intermediate result failed
    @raise ValueError: when the queue holds no 'pre_norm' task
    """
    device = z.device
    tile = z

    # Index of the last 'pre_norm' task, or -1 when the queue has none.
    last_id = len(task_queue) - 1
    while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
        last_id -= 1
    # NOTE: the check used to be `last_id <= 0`, which also rejected a queue
    # whose only 'pre_norm' sits at index 0.
    if last_id < 0:
        raise ValueError('No group norm found in the task queue')

    # estimate until the last group norm
    for i in range(last_id + 1):
        task = task_queue[i]
        kind = task[0]
        if kind == 'pre_norm':
            group_norm_func = GroupNormParam.from_tile(tile, task[1])
            task_queue[i] = ('apply_norm', group_norm_func)
            if i == last_id:
                return True
            tile = group_norm_func(tile)
        elif kind == 'store_res':
            # Find the 'add_res' that consumes this shortcut; shortcuts whose
            # consumer lies at or beyond the last group norm are not needed.
            task_id = i + 1
            while task_id < last_id and task_queue[task_id][0] != 'add_res':
                task_id += 1
            if task_id >= last_id:
                continue
            task_queue[task_id][1] = task[1](tile)
        elif kind == 'add_res':
            tile += task[1].to(device)
            task[1] = None  # drop the reference to the stored shortcut
        elif color_fix and kind == 'downsample':
            for j in range(i, last_id + 1):
                if task_queue[j][0] == 'store_res':
                    task_queue[j] = ('store_res_cpu', task_queue[j][1])
            return True
        else:
            tile = task[1](tile)

        try:
            devices.test_for_nans(tile, "vae")
        except Exception:
            # Any failure of the check disables fast mode; unlike a bare
            # `except:` this lets KeyboardInterrupt / SystemExit propagate.
            print('Nan detected in fast mode estimation. Fast mode disabled.')
            return False

    raise IndexError('Should not reach here')
|
| 818 |
-
|
| 819 |
-
@perfcount
@torch.no_grad()
def vae_tile_forward(self, z):
    """
    Run the wrapped network on `z` tile by tile and stitch the tile results together.

    The same routine serves both directions: the stitched result is 8x the
    spatial size of `z` when `self.is_decoder` is set and 1/8 of it otherwise.

    @param z: 4-D input tensor; dims 2 and 3 are read as height and width
    @return: the stitched result, cast to the dtype of `z`
    """
    device = next(self.net.parameters()).device
    dtype = z.dtype
    net = self.net
    tile_size = self.tile_size
    is_decoder = self.is_decoder

    z = z.detach()  # detach the input to avoid backprop

    N, height, width = z.shape[0], z.shape[2], z.shape[3]
    net.last_z_shape = z.shape

    # Split the input into tiles and build a task queue for each tile
    print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

    in_bboxes, out_bboxes = self.split_tiles(height, width)

    # Cut out the (padded) tiles and park them on the CPU until they are processed.
    tiles = [z[:, :, y1:y2, x1:x2].cpu() for x1, x2, y1, y2 in in_bboxes]

    num_tiles = len(tiles)
    num_completed = 0

    # Build task queues
    single_task_queue = build_task_queue(net, is_decoder)
    if self.fast_mode:
        # Fast mode: downsample the input image to the tile size,
        # then estimate the group norm parameters on the downsampled image
        scale_factor = tile_size / max(height, width)
        z = z.to(device)
        # use nearest-exact to keep statistics as close as possible
        downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
        print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

        # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
        # The downsampling will heavily distort its mean and std, so we need to recover it.
        std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
        std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
        downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
        del std_old, mean_old, std_new, mean_new
        # occasionally the std_new is too small or too large, which exceeds the range of float16
        # so we need to clamp it to max z's range.
        downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
        estimate_task_queue = clone_task_queue(single_task_queue)
        if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
            # Estimation succeeded: every tile reuses the pre-computed norms.
            single_task_queue = estimate_task_queue
        del downsampled_z

    task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

    # Allocated lazily: the channel count is only known once a tile has finished.
    result = None

    # Free memory of input latent tensor
    del z

    # Task queue execution
    pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

    try:
        # execute the task back and forth when switch tiles so that we always
        # keep one tile on the GPU to reduce unnecessary data transfer
        forward = True
        while True:
            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                tile = tiles[i].to(device)
                task_queue = task_queues[i]

                while len(task_queue) > 0:
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        # Statistics are needed from every tile first: record
                        # this tile's contribution and pause its queue here.
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        # Hand the shortcut to the next 'add_res' in this queue.
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if len(task_queue) == 0:
                    # This tile is finished: paste its valid region into the result.
                    tiles[i] = None
                    num_completed += 1
                    if result is None:  # NOTE: dim C varies from different cases, can only be inited dynamically
                        result = torch.zeros(
                            (N, tile.shape[1],
                             height * 8 if is_decoder else height // 8,
                             width * 8 if is_decoder else width // 8),
                            device=device, requires_grad=False)
                    out_x1, out_x2, out_y1, out_y2 = out_bboxes[i]
                    result[:, :, out_y1:out_y2, out_x1:out_x2] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    # End of a forward sweep: keep this tile on the device,
                    # it is the first one of the backward sweep.
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if num_completed == num_tiles:
                break

            # insert the group norm task to the head of each task queue
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for task_queue in task_queues:
                    task_queue.insert(0, ('apply_norm', group_norm_func))
    finally:
        # Close the progress bar even when a task raised.
        pbar.close()

    # Done! The loop above only exits after every tile completed, and the
    # first completed tile allocates `result`.
    return result.to(dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|