| | |
| |
|
| | import math |
| | from einops import rearrange |
| | |
| | from torch import randint |
| |
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """Return ``value // d`` for a randomly chosen divisor ``d`` of ``value``.

    The candidate divisors are the ``max_options`` smallest divisors of
    ``value`` that are greater than or equal to ``min_value``. One of them is
    picked uniformly using ``torch.randint`` and the corresponding quotient
    is returned.

    Args:
        value: Positive integer to divide.
        min_value: Smallest acceptable divisor. It is clamped to the range
            ``[1, value]`` so that ``value`` itself is always a candidate.
        max_options: Maximum number of candidate divisors to choose among.
            Values below 1 are treated as 1.

    Returns:
        ``value // d`` for the selected divisor ``d``.

    Raises:
        ValueError: If ``value`` is less than 1.
    """
    if value < 1:
        raise ValueError(f"value must be a positive integer, got {value}")

    # Clamp to [1, value]: a lower bound of 0 would trigger a modulo by zero,
    # and a bound above ``value`` would leave no candidate at all.
    min_value = max(1, min(min_value, value))

    # All divisors of ``value`` that are >= min_value, in ascending order.
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]

    # Always non-empty because ``value`` divides itself.
    ns = [value // i for i in divisors[:max(1, max_options)]]

    if len(ns) > 1:
        # torch.randint excludes ``high``, so pass len(ns) (not len(ns) - 1)
        # to make every candidate, including the last one, selectable.
        idx = int(randint(low=0, high=len(ns), size=(1,)).item())
    else:
        idx = 0

    return ns[idx]
| |
|
class HyperTile:
    """Node that patches a model's attn1 so the query is processed in tiles.

    ``patch`` registers two hooks on a clone of the model: an attn1 input
    patch that reshapes the query into ``nh * nw`` tiles folded into the
    batch dimension, and an attn1 output patch that restores the original
    token layout.
    """

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches/unet"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification (names, types, defaults, ranges)."""
        return {
            "required": {
                "model": ("MODEL",),
                "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
                "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
                "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
                "scale_depth": ("BOOLEAN", {"default": False}),
            }
        }

    def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
        """Return a one-tuple holding a clone of ``model`` with tiling hooks.

        Args:
            model: Object providing ``clone()``, ``set_model_attn1_patch()``
                and ``set_model_attn1_output_patch()``.
            tile_size: Requested tile size; values below 32 are raised to 32
                and the result is divided by 8 (presumably a pixel-to-latent
                conversion -- confirm against the caller).
            swap_size: Passed to ``random_divisor`` as the number of candidate
                divisors to choose among.
            max_depth: Tiling is applied when the query's token count equals
                the original resolution halved ``0..max_depth`` times.
            scale_depth: When true, the minimum tile size is multiplied by
                ``2 ** depth`` at the matching depth.

        Returns:
            ``(patched_model,)``.
        """
        latent_tile_size = max(32, tile_size) // 8
        # Hand-off slot between the input and output hooks: holds
        # (nh, nw, h, w) while a tiled attention call is in flight.
        self.temp = None

        def hypertile_in(q, k, v, extra_options):
            """Tile ``q`` when its token count matches a configured depth."""
            token_count = q.shape[-2]
            orig_shape = extra_options["original_shape"]

            # Token counts of the original resolution halved 0..max_depth
            # times along both trailing dimensions.
            apply_to = [
                (orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))
                for i in range(max_depth + 1)
            ]
            if token_count not in apply_to:
                return q, k, v

            aspect_ratio = orig_shape[-1] / orig_shape[-2]

            # NOTE(review): if original_shape ends in (height, width) and the
            # tokens are flattened row-major, ``h`` below is derived from the
            # width and ``w`` from the height; confirm the intended
            # orientation for non-square inputs.
            hw = q.size(1)
            h = round(math.sqrt(hw * aspect_ratio))
            w = round(math.sqrt(hw / aspect_ratio))

            factor = (2 ** apply_to.index(token_count)) if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, swap_size)
            nw = random_divisor(w, latent_tile_size * factor, swap_size)

            if nh * nw > 1:
                q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                # Remember the split so hypertile_out can undo it.
                self.temp = (nh, nw, h, w)
            return q, k, v

        def hypertile_out(out, extra_options):
            """Undo the tiling recorded by ``hypertile_in``, if any."""
            if self.temp is not None:
                nh, nw, h, w = self.temp
                # Clear first so the next attention call starts clean.
                self.temp = None
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
            return out

        m = model.clone()
        m.set_model_attn1_patch(hypertile_in)
        m.set_model_attn1_output_patch(hypertile_out)
        return (m, )
| |
|
# Maps each node name to the class that implements it.
NODE_CLASS_MAPPINGS = {"HyperTile": HyperTile}
| |
|