import math
from typing import Union, Callable, List

import torch
import torch.nn.functional as F
import torchvision.transforms.functional
from torch import nn, Tensor
from einops import rearrange
from PIL import Image
from modules.processing import StableDiffusionProcessing, slerp as Slerp

from scripts.sdhook import (
    SDHook,
    each_unet_attn_layers,
    each_unet_transformers,
    each_unet_resblock
)

class Upscaler:

    def __init__(self, mode: str, aa: bool):
        mode = {
            'nearest': 'nearest-exact',
            'bilinear': 'bilinear',
            'bicubic': 'bicubic',
        }.get(mode.lower(), mode)
        self.mode = mode
        self.aa = bool(aa)

    @property
    def name(self):
        s = self.mode
        if self.aa: s += '-aa'
        return s

    def __call__(self, x: Tensor, scale: float = 2.0):
        return F.interpolate(x, scale_factor=scale, mode=self.mode, antialias=self.aa)
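
# Note: current PyTorch only accepts antialias=True for the 'bilinear' and
# 'bicubic' modes of F.interpolate, so enabling aa together with 'nearest'
# ('nearest-exact') would raise an error. A minimal usage sketch
# (hypothetical values):
#
#   up = Upscaler('bicubic', aa=True)
#   y = up(x, scale=2.0)   # x: (B, C, H, W) -> y: (B, C, 2H, 2W)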


class Downscaler:

    def __init__(self, mode: str, aa: bool):
        self._name = mode.lower()
        # each UI name maps to (function, interpolation mode); the pooling modes
        # use dedicated pooling functions instead of F.interpolate
        intp, mode = {
            'nearest': (F.interpolate, 'nearest-exact'),
            'bilinear': (F.interpolate, 'bilinear'),
            'bicubic': (F.interpolate, 'bicubic'),
            'area': (F.interpolate, 'area'),
            'pooling max': (F.max_pool2d, ''),
            'pooling avg': (F.avg_pool2d, ''),
        }[mode.lower()]
        self.intp = intp
        self.mode = mode
        self.aa = bool(aa)

    @property
    def name(self):
        s = self._name
        if self.aa: s += '-aa'
        return s

    def __call__(self, x: Tensor, scale: float = 2.0):
        # accept the scale either as a factor (<= 1) or as a divisor (> 1)
        if scale <= 1:
            scale = float(scale)
            scale_inv = 1 / scale
        else:
            scale_inv = float(scale)
            scale = 1 / scale_inv
        assert scale <= 1
        assert 1 <= scale_inv

        kwargs = {}
        if len(self.mode) != 0:
            # F.interpolate path
            kwargs['scale_factor'] = scale
            kwargs['mode'] = self.mode
            kwargs['antialias'] = self.aa
        else:
            # pooling path: the kernel size also acts as the stride
            kwargs['kernel_size'] = int(scale_inv)
        return self.intp(x, **kwargs)
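
# A minimal sketch of the pooling path (hypothetical values): max_pool2d and
# avg_pool2d default their stride to kernel_size, so an integer kernel of
# scale_inv shrinks each spatial dimension by that factor:
#
#   down = Downscaler('pooling max', aa=False)
#   y = down(x, 2.0)   # same as F.max_pool2d(x, kernel_size=2)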


def lerp(v0, v1, t):
    return torch.lerp(v0, v1, t)


def slerp(v0, v1, t):
    # spherical interpolation can produce NaNs (e.g. for zero-length or
    # parallel vectors), so fall back to plain lerp in that case
    v = Slerp(t, v0, v1)
    if torch.any(torch.isnan(v)).item():
        v = lerp(v0, v1, t)
    return v


class Hooker(SDHook):

    def __init__(
        self,
        enabled: bool,
        multiply: int,
        weight: float,
        noise_strength: float,
        layers: Union[list, None],
        apply_to: List[str],
        start_steps: int,
        max_steps: int,
        up_fn: Callable[[Tensor, float], Tensor],
        down_fn: Callable[[Tensor, float], Tensor],
        intp: str,
        x: float,
        y: float,
        force_float: bool,
        mask_image: Union[Image.Image, None],
    ):
        super().__init__(enabled)
        self.multiply = int(multiply)
        self.weight = float(weight)
        self.noise_strength = float(noise_strength)
        self.layers = layers
        self.apply_to = apply_to
        self.start_steps = int(start_steps)
        self.max_steps = int(max_steps)
        self.up = up_fn
        self.down = down_fn
        # (x, y): left/top edge of the focused region as a fraction of width/height
        self.x0 = x
        self.y0 = y
        self.force_float = force_float
        self.mask_image = mask_image

        if intp == 'lerp':
            self.intp = lerp
        elif intp == 'slerp':
            self.intp = slerp
        else:
            raise ValueError(f'invalid interpolation method: {intp}')

        if not (1 <= self.multiply and (self.multiply & (self.multiply - 1) == 0)):
            raise ValueError(f'multiplier must be a power of 2, got: {self.multiply}')

        if mask_image is not None:
            if mask_image.mode != 'L':
                raise ValueError(f'mask image must be grayscale (mode L), got: {mask_image.mode}')
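
    # A minimal construction sketch (hypothetical values, not taken from the
    # accompanying script; the actual extension fills these in from its UI):
    #
    #   hooker = Hooker(
    #       enabled=True, multiply=2, weight=0.15, noise_strength=0.0,
    #       layers=None, apply_to=['s. attn.', 'x. attn.'],
    #       start_steps=5, max_steps=30,
    #       up_fn=Upscaler('bilinear', aa=False),
    #       down_fn=Downscaler('bilinear', aa=False),
    #       intp='lerp', x=0.25, y=0.25,
    #       force_float=False, mask_image=None,
    #   )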

    def hook_unet(self, p: StableDiffusionProcessing, unet: nn.Module):
        step = 0

        # count denoising steps with a pre-forward hook on the whole U-Net
        def hook_step_pre(*args, **kwargs):
            nonlocal step
            step += 1

        self.hook_layer_pre(unet, hook_step_pre)

        start_step = self.start_steps
        max_steps = self.max_steps
        M = self.multiply

        def create_pre_hook(name: str, ctx: dict):
            def pre_hook(module: nn.Module, inputs: list):
                ctx['skipped'] = True

                if step < start_step or max_steps < step:
                    return

                x, *rest = inputs
                dim = x.dim()
                if dim == 3:
                    # attention layers: x is (batch, tokens, channels)
                    bi, ni, chi = x.shape
                    wi, hi, Ni = self.get_size(p, ni)
                    x = rearrange(x, 'b (h w) c -> b c h w', w=wi)
                    if len(rest) != 0:
                        # duplicate the context to match the doubled batch below
                        rest[0] = torch.concat((rest[0], rest[0]), dim=0)
                elif dim == 4:
                    # resblocks / transformers / out: x is (batch, channels, height, width)
                    bi, chi, hi, wi = x.shape
                    if 0 < len(rest):
                        # duplicate the extra input (time embedding or context) as well
                        t_emb = rest[0]
                        rest[0] = torch.concat((t_emb, t_emb), dim=0)
                    else:
                        # the `out` layer has no extra inputs
                        pass
                else:
                    return

                # size of the region to enlarge
                w, h = wi // M, hi // M
                if w == 0 or h == 0:
                    # feature map too small to crop
                    return

                # crop rectangle, clamped to the feature map
                s0, t0 = int(wi * self.x0), int(hi * self.y0)
                s1, t1 = s0 + w, t0 + h
                if wi < s1:
                    s1 = wi
                    s0 = s1 - w
                if hi < t1:
                    t1 = hi
                    t0 = t1 - h

                if s0 < 0 or t0 < 0:
                    raise ValueError(f'LLuL failed to process: s=({s0},{s1}), t=({t0},{t1})')

                x1 = x[:, :, t0:t1, s0:s1]

                # upscale the cropped region
                x1 = self.up(x1, M)

                # optionally mix in gaussian noise
                if self.noise_strength > 0:
                    noise = torch.randn_like(x1)
                    x1 = (1.0 - self.noise_strength) * x1 + self.noise_strength * noise

                # pad if the upscaled crop is smaller than the full feature map
                if x1.shape[-1] < x.shape[-1] or x1.shape[-2] < x.shape[-2]:
                    dx = x.shape[-1] - x1.shape[-1]
                    dx1 = dx // 2
                    dx2 = dx - dx1
                    dy = x.shape[-2] - x1.shape[-2]
                    dy1 = dy // 2
                    dy2 = dy - dy1
                    x1 = F.pad(x1, (dx1, dx2, dy1, dy2), 'replicate')

                # process the original and the enlarged version in one batch
                x = torch.concat((x, x1), dim=0)
                if dim == 3:
                    x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

                ctx['skipped'] = False
                return x, *rest
            return pre_hook

        def create_post_hook(name: str, ctx: dict):
            def post_hook(module: nn.Module, inputs: list, output: Tensor):
                if step < start_step or max_steps < step:
                    return

                if ctx['skipped']:
                    return

                x = output
                dim = x.dim()
                if dim == 3:
                    bo, no, cho = x.shape
                    wo, ho, No = self.get_size(p, no)
                    x = rearrange(x, 'b (h w) c -> b c h w', w=wo)
                elif dim == 4:
                    bo, cho, ho, wo = x.shape
                else:
                    return

                # the batch was doubled in pre_hook: first half is the original
                # output, second half is the output for the enlarged region
                assert bo % 2 == 0
                x, x1 = x[:bo//2], x[bo//2:]

                # shrink the enlarged output back to the region's size
                x1 = self.down(x1, M)

                # recompute the crop rectangle, clamped as in pre_hook
                w, h = x1.shape[-1], x1.shape[-2]
                s0, t0 = int(wo * self.x0), int(ho * self.y0)
                s1, t1 = s0 + w, t0 + h
                if wo < s1:
                    s1 = wo
                    s0 = s1 - w
                if ho < t1:
                    t1 = ho
                    t0 = t1 - h

                if s0 < 0 or t0 < 0:
                    raise ValueError(f'LLuL failed to process: s=({s0},{s1}), t=({t0},{t1})')

                # blend the detailed result back into the original feature map
                x[:, :, t0:t1, s0:s1] = self.interpolate(x[:, :, t0:t1, s0:s1], x1, self.weight)

                if dim == 3:
                    x = rearrange(x, 'b c h w -> b (h w) c').contiguous()

                return x
            return post_hook

        def create_hook(name: str, **kwargs):
            ctx = dict()
            ctx.update(kwargs)
            return (
                create_pre_hook(name, ctx),
                create_post_hook(name, ctx)
            )

        def wrap_for_xattn(pre, post):
            # cross-attention is hooked by wrapping forward() itself, so the
            # pre/post pair can rewrite both positional and keyword arguments
            def f(module: nn.Module, o: Callable, *args, **kwargs):
                inputs = list(args) + list(kwargs.values())
                inputs_ = pre(module, inputs)
                if inputs_ is not None:
                    inputs = inputs_
                output = o(*inputs)
                output_ = post(module, inputs, output)
                if output_ is not None:
                    output = output_
                return output
            return f

        #
        # attention layers
        #

        for name, attn in each_unet_attn_layers(unet):
            if self.layers is not None:
                if not any(layer in name for layer in self.layers):
                    continue

            q_in = attn.to_q.in_features
            k_in = attn.to_k.in_features
            if q_in == k_in:
                # self-attention: query and key have the same input dimension
                if 's. attn.' in self.apply_to:
                    pre, post = create_hook(name)
                    self.hook_layer_pre(attn, pre)
                    self.hook_layer(attn, post)
            else:
                # cross-attention: key/value come from the text conditioning
                if 'x. attn.' in self.apply_to:
                    pre, post = create_hook(name)
                    self.hook_forward(attn, wrap_for_xattn(pre, post))

        #
        # residual blocks
        #

        for name, res in each_unet_resblock(unet):
            if 'resblock' not in self.apply_to:
                continue

            if self.layers is not None:
                if not any(layer in name for layer in self.layers):
                    continue

            pre, post = create_hook(name)
            self.hook_layer_pre(res, pre)
            self.hook_layer(res, post)

        #
        # transformer blocks
        #

        for name, res in each_unet_transformers(unet):
            if 'transformer' not in self.apply_to:
                continue

            if self.layers is not None:
                if not any(layer in name for layer in self.layers):
                    continue

            pre, post = create_hook(name)
            self.hook_layer_pre(res, pre)
            self.hook_layer(res, post)

        #
        # final output layer
        #

        if 'out' in self.apply_to:
            out = unet.out
            pre, post = create_hook('out')
            self.hook_layer_pre(out, pre)
            self.hook_layer(out, post)

    def get_size(self, p: StableDiffusionProcessing, n: int):
        # recover the feature-map size (w, h) and the total downscale factor N
        # from the token count n, assuming width and height are scaled equally
        wh = p.width * p.height
        N2 = wh // n
        N = int(math.sqrt(N2))
        assert N*N == N2, f'N={N}, N2={N2}'
        assert p.width % N == 0, f'width={p.width}, N={N}'
        assert p.height % N == 0, f'height={p.height}, N={N}'
        w, h = p.width // N, p.height // N
        assert w * h == n, f'w={w}, h={h}, N={N}, n={n}'
        return w, h, N
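
    # Worked example: a 512x512 image at the 8x-downscaled attention layers has
    # n = 64*64 = 4096 tokens, so N2 = 262144 // 4096 = 64, N = 8, and the
    # feature map is w = h = 512 // 8 = 64.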

    def interpolate(self, v1: Tensor, v2: Tensor, t: float):
        dtype = v1.dtype

        # v1 and v2 may have different dtypes; align v2 with v1 here, unless
        # force_float promotes both to float32 below anyway
        if v2.dtype != v1.dtype and not self.force_float:
            v2 = v2.to(v1.dtype)

        if self.force_float:
            v1 = v1.float()
            v2 = v2.float()

        if self.mask_image is None:
            v = self.intp(v1, v2, t)
        else:
            # use the grayscale mask, resized to the region, as a per-pixel weight
            to_w, to_h = v1.shape[-1], v1.shape[-2]
            resized_image = self.mask_image.resize((to_w, to_h), Image.BILINEAR)
            mask = torchvision.transforms.functional.to_tensor(resized_image).to(device=v1.device, dtype=v1.dtype)
            mask.unsqueeze_(0)
            mask.mul_(t)
            v = self.intp(v1, v2, mask)

        if self.force_float:
            v = v.to(dtype)

        return v
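
    # Note on the mask path: to_tensor() turns a mode-'L' image into a (1, H, W)
    # tensor, so after unsqueeze_(0) the mask is (1, 1, H, W) and broadcasts
    # against the (B, C, H, W) feature maps in torch.lerp; scaling it by t makes
    # the configured weight the maximum blend strength, modulated per pixel.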