|
|
|
|
| import math
|
| from einops import rearrange
|
|
|
| from torch import randint
|
|
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """Pick a random tile count for a dimension of size ``value``.

    The divisors ``d`` of ``value`` with ``d >= min(min_value, value)`` are
    enumerated in ascending order and the first ``max_options`` of them are
    kept. Each kept divisor is turned into the quotient ``value // d`` and
    one of those quotients is returned, drawn uniformly with ``torch.randint``.

    Args:
        value: Size being divided. Expected to be positive, in which case
            ``value`` itself always qualifies as a divisor and the candidate
            list is never empty.
        min_value: Smallest divisor to consider; clamped down to ``value``.
            Expected to be positive so the modulo below never divides by zero.
        max_options: How many of the smallest qualifying divisors take part
            in the random draw. Expected to be at least 1.

    Returns:
        ``value // d`` for the selected divisor ``d``.
    """
    min_value = min(min_value, value)

    # Ascending divisors of `value`, starting at the (clamped) lower bound.
    divisors = [d for d in range(min_value, value + 1) if value % d == 0]

    # Smallest divisors come first, so the largest quotients come first.
    ns = [value // d for d in divisors[:max_options]]

    if len(ns) > 1:
        # torch.randint excludes `high`, so the bound must be len(ns) (not
        # len(ns) - 1) for the last candidate to be reachable.
        idx = randint(low=0, high=len(ns), size=(1,)).item()
    else:
        # Only one candidate: skip the draw so no random number is consumed.
        idx = 0

    return ns[idx]
|
|
|
class HyperTile:
    """Node that installs tiling patches on a model's ``attn1`` hooks.

    ``patch`` clones the given model and registers two callbacks on the
    clone: one that may split the query tensor into tiles before attention,
    and one that merges the attention output back into the original layout.
    The upper-case class attributes and ``INPUT_TYPES`` are presumably read
    by the host node framework; that consumer is not visible in this file.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: names mapped to type tags and options."""
        return {
            "required": {
                "model": ("MODEL",),
                "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
                "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
                "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
                "scale_depth": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches/unet"

    def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
        """Clone ``model`` and attach the tiling callbacks to the clone.

        Args:
            model: Object providing ``clone()``; the clone must provide
                ``set_model_attn1_patch`` and ``set_model_attn1_output_patch``.
            tile_size: Requested tile size. Values below 32 are raised to 32
                and the result is integer-divided by 8 (presumably the
                pixel-to-latent scale; confirm against the host).
            swap_size: Forwarded to ``random_divisor`` as ``max_options``,
                i.e. how many candidate tile counts take part in each draw.
            max_depth: Tiling is applied when the token count equals the
                original height*width divided by ``4 ** i`` for some ``i``
                in ``0..max_depth``.
            scale_depth: When true, the minimum tile size is multiplied by
                ``2 ** i`` for the matched depth ``i``.

        Returns:
            A one-element tuple holding the patched clone.
        """
        latent_tile_size = max(32, tile_size) // 8

        # Hand-off slot between the two callbacks: set by hypertile_in when it
        # tiles q, consumed (and cleared) by the next hypertile_out call.
        self.temp = None

        def hypertile_in(q, k, v, extra_options):
            """Split ``q`` into tiles when its token count matches a target depth."""
            token_count = q.shape[-2]
            shape = extra_options["original_shape"]

            # Token counts eligible for tiling, one per depth level. These are
            # floats; the membership test below relies on int == float equality.
            apply_to = [
                (shape[-2] / (2 ** depth)) * (shape[-1] / (2 ** depth))
                for depth in range(max_depth + 1)
            ]
            if token_count not in apply_to:
                return q, k, v

            aspect_ratio = shape[-1] / shape[-2]

            # Recover the two spatial extents from the flattened token axis.
            # NOTE(review): aspect_ratio is last-dim / second-to-last-dim, so
            # `h` evaluates to the extent matching shape[-1]. If original_shape
            # is (..., height, width) and tokens are flattened row-major, h and
            # w are swapped for non-square inputs -- confirm against the caller.
            hw = q.size(1)
            h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

            factor = (2 ** apply_to.index(token_count)) if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, swap_size)
            nw = random_divisor(w, latent_tile_size * factor, swap_size)

            if nh * nw > 1:
                # Fold the tile grid into the batch axis; only q is rearranged,
                # k and v are passed through untouched.
                q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                self.temp = (nh, nw, h, w)

            return q, k, v

        def hypertile_out(out, extra_options):
            """Undo the tiling recorded by ``hypertile_in``, if any."""
            if self.temp is not None:
                nh, nw, h, w = self.temp
                # Clear first so a later layer that was not tiled is left alone.
                self.temp = None
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
            return out

        m = model.clone()
        m.set_model_attn1_patch(hypertile_in)
        m.set_model_attn1_output_patch(hypertile_out)
        return (m, )
|
|
|
# Registry of node classes exported by this module, keyed by node name.
NODE_CLASS_MAPPINGS = {"HyperTile": HyperTile}
|
|
|