|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import gc
|
| from time import time
|
| import math
|
| from tqdm import tqdm
|
|
|
| import torch |
| import torch.version |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
| import SUPIR.utils.devices as devices |
|
|
| try: |
| import xformers |
| import xformers.ops |
| HAS_XFORMERS = True |
| except ImportError: |
| HAS_XFORMERS = False |
|
|
# Selects which module layout the task-queue builders in this file read from.
# True  -> attributes such as `net.mid.block_1`, `net.down` / `net.up`,
#          `block.use_conv_shortcut` and `net.norm_out` are used.
# False -> attributes such as `net.mid_block.resnets`, `net.down_blocks` /
#          `net.up_blocks`, `block.use_in_shortcut` and `net.conv_norm_out`
#          are used instead.
sd_flag = True
|
|
|
def get_recommend_encoder_tile_size():
    """
    Suggest an encoder tile size from the total memory of the CUDA device

    Without CUDA the fixed value 512 is returned. Otherwise the total memory
    of `devices.device` is converted to MiB (bytes // 2**20) and compared
    against strict lower bounds, largest first.

    @return: the recommended encoder tile size
    """
    if not torch.cuda.is_available():
        return 512

    properties = torch.cuda.get_device_properties(devices.device)
    total_mib = properties.total_memory // 2**20

    # (exclusive lower bound in MiB, tile size), ordered from largest to smallest
    tiers = (
        (16 * 1000, 3072),
        (12 * 1000, 2048),
        (8 * 1000, 1536),
    )
    for lower_bound, tile_size in tiers:
        if total_mib > lower_bound:
            return tile_size
    return 960
|
|
|
|
|
def get_recommend_decoder_tile_size():
    """
    Suggest a decoder tile size from the total memory of the CUDA device

    Without CUDA the fixed value 64 is returned. Otherwise the total memory
    of `devices.device` is converted to MiB (bytes // 2**20) and compared
    against strict lower bounds, largest first.

    @return: the recommended decoder tile size
    """
    if not torch.cuda.is_available():
        return 64

    properties = torch.cuda.get_device_properties(devices.device)
    total_mib = properties.total_memory // 2**20

    # (exclusive lower bound in MiB, tile size), ordered from largest to smallest
    tiers = (
        (30 * 1000, 256),
        (16 * 1000, 192),
        (12 * 1000, 128),
        (8 * 1000, 96),
    )
    for lower_bound, tile_size in tiers:
        if total_mib > lower_bound:
            return tile_size
    return 64
|
|
|
|
|
# Global constants: module-level defaults for the Tiled VAE options.
# The two tile sizes are computed once, at import time, from the helpers above.
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
|
|
|
|
|
|
|
def inplace_nonlinearity(x):
    """
    Apply SiLU to a tensor in place

    @param x: input tensor, overwritten with the activation result
    @return: the result of `F.silu(x, inplace=True)`
    """
    activated = F.silu(x, inplace=True)
    return activated
|
|
|
|
|
|
|
|
|
def attn_forward_new(self, h_):
    """
    Self-attention over a 4D feature map using explicit attention scores

    The spatial grid is flattened into a token axis, query/key/value are all
    projected from the same tokens, and the result is reshaped back to the
    input layout. No attention mask and no separate encoder states are used.

    @param self: attention module; this function calls its
                 `prepare_attention_mask`, `to_q`, `to_k`, `to_v`,
                 `head_to_batch_dim`, `get_attention_scores`,
                 `batch_to_head_dim` and `to_out[0]` / `to_out[1]`
    @param h_: input tensor of shape (batch, channel, height, width)
    @return: tensor reshaped to (batch, channel, height, width)
    """
    batch_size, channel, height, width = h_.shape
    # (B, C, H, W) -> (B, H*W, C)
    hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, sequence_length, _ = hidden_states.shape
    # No mask is supplied; the module still gets the chance to normalise it.
    attention_mask = self.prepare_attention_mask(None, sequence_length, batch_size)

    # Pure self-attention: query, key and value share the same source tokens.
    query = self.head_to_batch_dim(self.to_q(hidden_states))
    key = self.head_to_batch_dim(self.to_k(hidden_states))
    value = self.head_to_batch_dim(self.to_v(hidden_states))

    attention_probs = self.get_attention_scores(query, key, attention_mask)
    hidden_states = torch.bmm(attention_probs, value)
    hidden_states = self.batch_to_head_dim(hidden_states)

    # to_out is applied as two consecutive stages.
    hidden_states = self.to_out[0](hidden_states)
    hidden_states = self.to_out[1](hidden_states)

    # (B, H*W, C) -> (B, C, H, W)
    return hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
|
def attn_forward_new_pt2_0(self, hidden_states,):
    """
    Self-attention using `F.scaled_dot_product_attention`

    Accepts either a 4D feature map, which is flattened to tokens and
    restored on return, or a 3D token tensor. No attention mask and no
    separate encoder states are used, so key and value are projected from
    the same (optionally group-normed) tokens as the query.

    @param self: attention module; this function reads `group_norm`, `heads`
                 and calls `to_q`, `to_k`, `to_v`, `to_out[0]` (all with a
                 `scale` keyword) and `to_out[1]`
    @param hidden_states: tensor of shape (batch, channel, height, width)
                          or (batch, tokens, channel)
    @return: tensor in the same layout as the input
    """
    scale = 1
    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        # (B, C, H, W) -> (B, H*W, C)
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    # Unpacking three values also enforces that the tokens are 3D here.
    batch_size, _, _ = hidden_states.shape

    if self.group_norm is not None:
        # group_norm is applied channel-first, so transpose around the call.
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // self.heads

    # (B, tokens, heads * head_dim) -> (B, heads, tokens, head_dim)
    query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
    )

    # (B, heads, tokens, head_dim) -> (B, tokens, heads * head_dim)
    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
    hidden_states = hidden_states.to(query.dtype)

    hidden_states = self.to_out[0](hidden_states, scale=scale)
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states
|
|
|
def attn_forward_new_xformers(self, hidden_states):
    """
    Self-attention using `xformers.ops.memory_efficient_attention`

    Accepts either a 4D feature map, which is flattened to tokens and
    restored on return, or a 3D token tensor. No separate encoder states are
    used, so key and value are projected from the same (optionally
    group-normed) tokens as the query. The attention op is left to xformers
    (`op=None`).

    @param self: attention module; this function reads `group_norm` and calls
                 `prepare_attention_mask`, `to_q`, `to_k`, `to_v`,
                 `to_out[0]` (with a `scale` keyword), `to_out[1]`,
                 `head_to_batch_dim` and `batch_to_head_dim`
    @param hidden_states: tensor of shape (batch, channel, height, width)
                          or (batch, tokens, channel)
    @return: tensor in the same layout as the input
    """
    scale = 1
    input_ndim = hidden_states.ndim

    if input_ndim == 4:
        batch_size, channel, height, width = hidden_states.shape
        # (B, C, H, W) -> (B, H*W, C)
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    batch_size, key_tokens, _ = hidden_states.shape

    # No mask is supplied; whatever the module returns is forwarded as attn_bias.
    attention_mask = self.prepare_attention_mask(None, key_tokens, batch_size)
    if attention_mask is not None:
        # Expand the mask along the query-token axis.
        _, query_tokens, _ = hidden_states.shape
        attention_mask = attention_mask.expand(-1, query_tokens, -1)

    if self.group_norm is not None:
        # group_norm is applied channel-first, so transpose around the call.
        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = self.to_q(hidden_states, scale=scale)
    key = self.to_k(hidden_states, scale=scale)
    value = self.to_v(hidden_states, scale=scale)

    query = self.head_to_batch_dim(query).contiguous()
    key = self.head_to_batch_dim(key).contiguous()
    value = self.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=None
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = self.batch_to_head_dim(hidden_states)

    hidden_states = self.to_out[0](hidden_states, scale=scale)
    hidden_states = self.to_out[1](hidden_states)

    if input_ndim == 4:
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

    return hidden_states
|
|
|
def attn_forward(self, h_):
    """
    Single-head self-attention over a 4D feature map using plain `torch.bmm`

    @param self: module providing the `q`, `k`, `v` and `proj_out` callables
    @param h_: input feature map; the q/k/v outputs must be 4D (b, c, h, w)
    @return: `proj_out` applied to the attended feature map
    """
    query = self.q(h_)
    key = self.k(h_)
    value = self.v(h_)

    b, c, h, w = query.shape
    tokens = h * w

    # Attention weights: (b, hw, c) @ (b, c, hw) -> (b, hw, hw)
    query = query.reshape(b, c, tokens).permute(0, 2, 1)
    key = key.reshape(b, c, tokens)
    weights = torch.bmm(query, key)
    weights = weights * (int(c) ** (-0.5))
    weights = F.softmax(weights, dim=2)

    # Attend to the values: (b, c, hw) @ (b, hw, hw) -> (b, c, hw)
    value = value.reshape(b, c, tokens)
    weights = weights.permute(0, 2, 1)
    attended = torch.bmm(value, weights).reshape(b, c, h, w)

    return self.proj_out(attended)
|
|
|
|
|
def xformer_attn_forward(self, h_):
    """
    Single-head self-attention over a 4D feature map using xformers

    @param self: module providing the `q`, `k`, `v`, `proj_out` callables and
                 an `attention_op` attribute forwarded to xformers as `op`
    @param h_: input feature map; the q/k/v outputs must be 4D (b, c, h, w)
    @return: `proj_out` applied to the attended feature map
    """
    query = self.q(h_)
    key = self.k(h_)
    value = self.v(h_)

    B, C, H, W = query.shape

    def to_tokens(t):
        # (b, c, h, w) -> (b, h*w, c), then route through an explicit
        # single-head axis before handing a contiguous tensor to xformers.
        t = rearrange(t, 'b c h w -> b (h w) c')
        t = t.unsqueeze(3).reshape(B, t.shape[1], 1, C)
        t = t.permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C)
        return t.contiguous()

    query, key, value = to_tokens(query), to_tokens(key), to_tokens(value)

    attended = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=None, op=self.attention_op)

    # Undo the single-head routing: back to (b, h*w, c).
    attended = attended.unsqueeze(0).reshape(B, 1, attended.shape[1], C)
    attended = attended.permute(0, 2, 1, 3).reshape(B, attended.shape[1], C)

    attended = rearrange(attended, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    return self.proj_out(attended)
|
|
|
|
|
def attn2task(task_queue, net):
    """
    Turn an attention block into a sequence of tasks and append to the task queue

    The appended sequence is: store the residual, pre-norm with `net.norm`,
    run attention, add the residual back.

    @param task_queue: the target task queue
    @param net: the attention block
    """
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.norm))

    # NOTE(review): the xformers path reads `net.q/k/v/proj_out/attention_op`,
    # while the two fallbacks read `net.to_q/to_k/to_v/to_out`. Confirm that
    # the block passed in exposes the attribute set of the chosen path.
    if HAS_XFORMERS:
        task_queue.append(
            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))
    elif hasattr(F, "scaled_dot_product_attention"):
        task_queue.append(('attn', lambda x, net=net: attn_forward_new_pt2_0(net, x)))
    else:
        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))

    # A list, not a tuple: the stored residual is written into slot 1 later.
    task_queue.append(['add_res', None])
|
|
|
def resblock2task(queue, block):
    """
    Append the tasks of one ResNetBlock to a task queue

    The residual branch is stored first (through a shortcut layer when the
    channel count changes, unchanged otherwise), followed by two
    norm -> SiLU -> conv stages and the residual addition.

    @param queue: the target task queue
    @param block: ResNetBlock
    """
    if block.in_channels == block.out_channels:
        queue.append(('store_res', lambda x: x))
    else:
        # The flag choosing between the two shortcut layers has a different
        # attribute name depending on sd_flag.
        use_conv = block.use_conv_shortcut if sd_flag else block.use_in_shortcut
        if use_conv:
            queue.append(('store_res', block.conv_shortcut))
        else:
            queue.append(('store_res', block.nin_shortcut))

    for norm, conv_name, conv in (
        (block.norm1, 'conv1', block.conv1),
        (block.norm2, 'conv2', block.conv2),
    ):
        queue.append(('pre_norm', norm))
        queue.append(('silu', inplace_nonlinearity))
        queue.append((conv_name, conv))

    # A list, not a tuple: the stored residual is written into slot 1 later.
    queue.append(['add_res', None])
|
|
|
|
|
def _mid_block2task(task_queue, net):
    """Append the middle section (resblock, attention, resblock) of `net`."""
    if sd_flag:
        resblock2task(task_queue, net.mid.block_1)
        attn2task(task_queue, net.mid.attn_1)
        resblock2task(task_queue, net.mid.block_2)
    else:
        resblock2task(task_queue, net.mid_block.resnets[0])
        attn2task(task_queue, net.mid_block.attentions[0])
        resblock2task(task_queue, net.mid_block.resnets[1])


def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue

    For a decoder the middle section is appended first, followed by the
    resolution levels; for an encoder the resolution levels come first and
    the middle section last. Every level except the one equal to `condition`
    is followed by its up/down-sampling layer.

    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    if is_decoder:
        _mid_block2task(task_queue, net)
        if sd_flag:
            resolution_iter = reversed(range(net.num_resolutions))
            block_ids = net.num_res_blocks + 1
            condition = 0
            module = net.up
            func_name = 'upsample'
        else:
            resolution_iter = range(len(net.up_blocks))
            block_ids = 2 + 1
            condition = len(net.up_blocks) - 1
            module = net.up_blocks
            func_name = 'upsamplers'
    else:
        if sd_flag:
            resolution_iter = range(net.num_resolutions)
            block_ids = net.num_res_blocks
            condition = net.num_resolutions - 1
            module = net.down
            func_name = 'downsample'
        else:
            resolution_iter = range(len(net.down_blocks))
            block_ids = 2
            condition = len(net.down_blocks) - 1
            module = net.down_blocks
            func_name = 'downsamplers'

    for i_level in resolution_iter:
        level = module[i_level]
        resblocks = level.block if sd_flag else level.resnets
        for i_block in range(block_ids):
            resblock2task(task_queue, resblocks[i_block])
        if i_level != condition:
            sampler = getattr(level, func_name)
            if not sd_flag:
                # `upsamplers` / `downsamplers` are indexable; entry 0 is used.
                sampler = sampler[0]
            task_queue.append((func_name, sampler))

    if not is_decoder:
        _mid_block2task(task_queue, net)
|
|
|
|
|
def build_task_queue(net, is_decoder):
    """
    Assemble the complete task queue of an encoder or a decoder

    The queue starts with `conv_in`, continues with the sampling section and,
    unless a decoder has `give_pre_end` set, ends with norm -> SiLU ->
    `conv_out` (plus `tanh` for a decoder with `tanh_out` set). When
    `sd_flag` is off, a decoder's `give_pre_end` and `tanh_out` attributes
    are overwritten with False before they are read.

    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = [('conv_in', net.conv_in)]
    build_sampling(task_queue, net, is_decoder)

    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    # Only a decoder can request the pre-end output and skip the tail.
    if is_decoder and net.give_pre_end:
        return task_queue

    final_norm = net.norm_out if sd_flag else net.conv_norm_out
    task_queue.append(('pre_norm', final_norm))
    task_queue.append(('silu', inplace_nonlinearity))
    task_queue.append(('conv_out', net.conv_out))
    if is_decoder and net.tanh_out:
        task_queue.append(('tanh', torch.tanh))

    return task_queue
|
|
|
|
|
def clone_task_queue(task_queue):
    """
    Copy a task queue one level deep

    Every task becomes a fresh list, so entries of the clone can be
    reassigned without touching the source queue. The items inside each task
    are shared, not copied.

    @param task_queue: the task queue to be cloned
    @return: the cloned task queue
    """
    return [list(task) for task in task_queue]
|
|
|
|
|
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Get mean and var for group norm

    The input is viewed as (1, b * num_groups, c // num_groups, *spatial) and
    reduced over every axis except the second, so one biased (population)
    variance and one mean are produced per (sample, group) pair. Any number
    of trailing spatial dimensions is supported.

    @param input: input tensor of shape (b, c, *spatial)
    @param num_groups: number of groups; c must be divisible by it
    @param eps: unused, kept for interface compatibility
    @return: (var, mean), each of shape (b * num_groups,)
    """
    b, c = input.size(0), input.size(1)
    channel_in_group = c // num_groups
    input_reshaped = input.contiguous().view(
        1, b * num_groups, channel_in_group, *input.size()[2:])
    # Reduce over everything except axis 1 (the per-sample group axis).
    reduce_dims = [0] + list(range(2, input_reshaped.dim()))
    var, mean = torch.var_mean(input_reshaped, dim=reduce_dims, unbiased=False)
    return var, mean
|
|
|
|
|
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Custom group norm with fixed mean and var

    The input is viewed as (1, b * num_groups, c // num_groups, *spatial) and
    normalised with `F.batch_norm` in inference mode, so the supplied
    statistics are used as-is. The optional affine parameters are applied
    afterwards, broadcast over any number of spatial dimensions.

    @param input: input tensor of shape (b, c, *spatial)
    @param num_groups: number of groups. by default, num_groups = 32
    @param mean: mean, must be pre-calculated by get_var_mean
    @param var: var, must be pre-calculated by get_var_mean
    @param weight: weight, should be fetched from the original group norm
    @param bias: bias, should be fetched from the original group norm
    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm

    @return: normalized tensor with the same shape as the input
    """
    b, c = input.size(0), input.size(1)
    spatial = input.size()[2:]
    channel_in_group = c // num_groups
    input_reshaped = input.contiguous().view(
        1, b * num_groups, channel_in_group, *spatial)

    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,
                       training=False, momentum=0, eps=eps)
    out = out.view(b, c, *spatial)

    # Per-channel affine parameters, broadcast over batch and spatial axes.
    affine_shape = (1, -1) + (1,) * len(spatial)
    if weight is not None:
        out *= weight.view(affine_shape)
    if bias is not None:
        out += bias.view(affine_shape)
    return out
|
|
|
|
|
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Cut the valid output region out of a processed tile

    The input bounding box is first mapped into output coordinates
    (multiplied by 8 for a decoder, floor-divided by 8 otherwise); the
    difference to the target box then gives the margins to trim.

    @param x: processed tile; its last two axes are indexed as (rows, cols)
    @param input_bbox: original input bounding box, indexed as
                       [col_start, col_end, row_start, row_end]
    @param target_bbox: output bounding box, same ordering
    @param is_decoder: whether the tile comes from the decoder
    @return: cropped tile
    """
    if is_decoder:
        padded_bbox = [edge * 8 for edge in input_bbox]
    else:
        padded_bbox = [edge // 8 for edge in input_bbox]

    left, right, top, bottom = (target_bbox[k] - padded_bbox[k] for k in range(4))
    row_end = x.size(2) + bottom
    col_end = x.size(3) + right
    return x[:, :, top:row_end, left:col_end]
|
|
|
|
|
|
|
|
|
def perfcount(fn):
    """
    Decorator that reports wall-clock time (and peak CUDA memory) of a call

    Before the call, CUDA peak-memory statistics are reset (when CUDA is
    available), `devices.torch_gc()` is called and Python garbage is
    collected. After the call the same cleanup runs again and a
    '[Tiled VAE]: Done in ...' line is printed.

    @param fn: the callable to measure
    @return: a wrapper with the metadata of `fn` that returns what `fn` returns
    """
    # Imported locally so the block does not depend on a file-level import.
    from functools import wraps

    @wraps(fn)
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            # Peak allocation in MiB since the reset performed before the call.
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
|
|
|
|
|
|
|
|
|
class GroupNormParam:
    """
    Accumulates per-tile group-norm statistics and turns them into a norm function

    Each call to `add_tile` records the variance, mean and pixel count of one
    tile; `summary` combines them into a pixel-weighted average and returns a
    callable that applies group norm with those fixed statistics. The number
    of groups is fixed to 32.
    """

    def __init__(self):
        # Per-tile statistics, appended in the same order by add_tile.
        self.var_list = []
        self.mean_list = []
        self.pixel_list = []
        # Affine parameters of the most recently added norm layer.
        self.weight = None
        self.bias = None

    def add_tile(self, tile, layer):
        """
        Record the statistics of one tile

        @param tile: tile tensor; its axes 2 and 3 are counted as pixels
        @param layer: the norm layer; its `weight` / `bias` are remembered
                      when it has a `weight` attribute
        """
        var, mean = get_var_mean(tile, 32)
        # A float16 variance may overflow to inf; recompute from a float32 copy.
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)

        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function
        that apply group norm on each tile

        @return: None when no tile was added, otherwise a callable taking a
                 tensor and returning the group-normed tensor
        """
        if len(self.var_list) == 0:
            return None

        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)

        # Weight every tile by its share of the total pixel count. The counts
        # are divided by their maximum before the sum is taken. The weights
        # are created on the statistics' own device so the products below
        # never mix devices.
        max_value = max(self.pixel_list)
        pixels = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=var.device) / max_value
        sum_pixels = torch.sum(pixels)
        pixels = pixels.unsqueeze(1) / sum_pixels

        var = torch.sum(var * pixels, dim=0)
        mean = torch.sum(mean * pixels, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary

        @param tile: tile tensor the statistics are computed from
        @param norm: the norm layer; its `weight` / `bias` are used when it
                     has a `weight` attribute
        @return: a callable applying group norm with the tile's statistics
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            fp32_tile = tile.float()
            var, mean = get_var_mean(fp32_tile, 32)
            # NOTE(review): the source had lost its indentation; this branch is
            # reconstructed as nested here (convert back to half only after the
            # float32 recomputation) - confirm against the project history.
            if var.device.type == 'mps':
                # Clamp before casting so the value fits into float16.
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()

        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        # Defaults bind the statistics at definition time.
        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
|
|
|
|
|
class VAEHook:
    """
    Callable replacement for a VAE encoder/decoder forward pass that works tile by tile

    Small inputs are passed straight to `net.original_forward`; larger ones
    are split into padded tiles that are pushed through a task queue built
    from the network's layers, with group-norm statistics shared across
    tiles.
    """

    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        """
        @param net: the encoder or decoder network to wrap
        @param tile_size: requested tile edge length
        @param is_decoder: True when `net` is a decoder
        @param fast_decoder: enable fast mode when wrapping a decoder
        @param fast_encoder: enable fast mode when wrapping an encoder
        @param color_fix: color-fix option; only kept for encoders
        @param to_gpu: move `net` to `devices.get_optimal_device()` for the call
        """
        self.net = net
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        # Padding added around every tile, in input units.
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        """
        Run the wrapped network on `x`, tiling only when the input is large

        The network is moved back to the device it was on before the call,
        even when the forward pass raises.

        @param x: 4D input tensor
        @return: the network output
        """
        _, _, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory

        Tries to round `lowerbound` up to a multiple of 32, 16, 8, 4 and 2 in
        turn and returns the first candidate that does not exceed
        `upperbound`. `lowerbound` is returned unchanged as soon as it is
        already a multiple of the divider being tried, or when no candidate
        fits.

        @param lowerbound: minimum tile size
        @param upperbound: maximum tile size
        @return: the chosen tile size
        """
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles

        Boxes are stored as [x_start, x_end, y_start, y_end].

        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # Always produce at least one tile per axis.
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Spread the area evenly over the tiles, then round the size up.
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # Unpadded tile area, clipped to the image.
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # Tiles touching the border also own the border padding.
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # Map the output box into output coordinates.
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # Padded input box, clipped to the image.
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        """
        Replace the 'pre_norm' tasks of a queue with fixed-statistics norms

        The queue is executed on `z` up to its last 'pre_norm' task; every
        'pre_norm' entry is rewritten in place into an 'apply_norm' entry
        whose statistics come from the tensor at that point.

        @param z: tensor the statistics are estimated on
        @param task_queue: task queue (list of lists), modified in place
        @param color_fix: when truthy, stop at the first 'downsample' task and
                          mark the remaining 'store_res' tasks as 'store_res_cpu'
        @return: True on success, False when the NaN test raised
        """
        device = z.device
        tile = z
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')

        # Estimate until the last group norm.
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                # Hand the residual to the next 'add_res' before the last norm.
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)

            try:
                devices.test_for_nans(tile, "vae")
            except Exception:
                # Narrowed from a bare except so KeyboardInterrupt / SystemExit
                # are no longer swallowed.
                print('Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Run the wrapped encoder/decoder on `z` in a tiled manner.

        @param z: 4D input tensor (latent for a decoder, image for an encoder)
        @return: the assembled output, cast to the dtype of `z`
        """
        device = next(self.net.parameters()).device
        dtype = z.dtype
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile.
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Input tiles are parked on the CPU until they are processed.
        tiles = [
            z[:, :, bbox[2]:bbox[3], bbox[0]:bbox[1]].cpu()
            for bbox in in_bboxes
        ]

        num_tiles = len(tiles)
        num_completed = 0

        single_task_queue = build_task_queue(net, is_decoder)

        if self.fast_mode:
            # Fast mode: estimate the group norm parameters once, on a
            # downsampled copy of the whole input.
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')

            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # Re-align the per-channel mean/std of the downsampled copy with
            # those of the full input.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new

            # Keep the rescaled values inside the value range of the input.
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        # Allocated lazily: the channel count is only known from the first
        # finished tile.
        result = None

        # Free the reference to the input tensor.
        del z

        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # Tiles are visited back and forth: the last tile of one pass is the
        # first tile of the next and is not moved to the CPU in between.
        forward = True
        try:
            while True:
                group_norm_param = GroupNormParam()
                # The order is fixed before the loop; `forward` may flip inside.
                order = range(num_tiles) if forward else reversed(range(num_tiles))
                for i in order:
                    tile = tiles[i].to(device)
                    task_queue = task_queues[i]

                    while len(task_queue) > 0:
                        task = task_queue.pop(0)
                        if task[0] == 'pre_norm':
                            # Statistics of all tiles are needed first: record
                            # this tile and pause its queue.
                            group_norm_param.add_tile(tile, task[1])
                            break
                        elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                            task_id = 0
                            res = task[1](tile)
                            if not self.fast_mode or task[0] == 'store_res_cpu':
                                res = res.cpu()
                            while task_queue[task_id][0] != 'add_res':
                                task_id += 1
                            task_queue[task_id][1] = res
                        elif task[0] == 'add_res':
                            tile += task[1].to(device)
                            task[1] = None
                        else:
                            tile = task[1](tile)
                        pbar.update(1)

                    if len(task_queue) == 0:
                        # Tile finished: paste its valid region into the result.
                        tiles[i] = None
                        num_completed += 1
                        if result is None:
                            result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                        result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                        del tile
                    elif i == num_tiles - 1 and forward:
                        # Turn around; this tile is stored without the CPU copy.
                        forward = False
                        tiles[i] = tile
                    elif i == 0 and not forward:
                        forward = True
                        tiles[i] = tile
                    else:
                        tiles[i] = tile.cpu()
                        del tile

                if num_completed == num_tiles:
                    break

                # Put the shared group norm at the head of every queue.
                group_norm_func = group_norm_param.summary()
                if group_norm_func is not None:
                    for i in range(num_tiles):
                        task_queues[i].insert(0, ('apply_norm', group_norm_func))
        finally:
            # Close the progress bar even when a task raises.
            pbar.close()

        return result.to(dtype)
|
|