import math
from typing import Type, Dict, Any, Tuple, Callable

import gradio as gr
import torch
import torch.nn.functional as F

from modules import script_callbacks
from modules import scripts
from modules import shared
|
|
def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
    """
    Resize a flattened token grid from (cur_h, cur_w) to (new_h, new_w).

    ``item`` is reshaped to (batch, cur_h, cur_w, channels), interpolated with
    ``F.interpolate`` using ``method``, and flattened back to
    (batch, new_h * new_w, channels).
    """
    batch = item.shape[0]

    # Tokens -> channels-first image layout expected by F.interpolate.
    grid = item.reshape(batch, cur_h, cur_w, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(grid, size=(new_h, new_w), mode=method)

    # Channels-last again, then flatten the spatial dims back into tokens.
    return resized.permute(0, 2, 3, 1).reshape(batch, new_h * new_w, -1)
|
|
def compute_merge(x: torch.Tensor, todo_info: Dict[str, Any]) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build the function that resizes the self-attention context for ``x``.

    The depth of the current block is inferred from how many tokens ``x``
    carries (``x.shape[1]``) relative to the size recorded in
    ``todo_info["size"]``: a ratio of 1, 2 or 4 selects ``downsample_factor``,
    ``downsample_factor_level_2`` or ``downsample_factor_level_3`` from
    ``todo_info["args"]``.

    Returns a single callable taking a token tensor. It is the identity when
    the selected factor is 1 or the ratio is not one of 1/2/4; otherwise it
    calls ``up_or_downsample`` with ``todo_info["args"]["downsample_method"]``.
    """
    original_h, original_w = todo_info["size"]
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))

    args = todo_info["args"]
    # Depth ratio of the current block -> factor configured for that depth.
    factor_by_depth = {
        1: args["downsample_factor"],
        2: args["downsample_factor_level_2"],
        4: args["downsample_factor_level_3"],
    }
    factor = factor_by_depth.get(downsample, 1)
    if factor == 1:
        return lambda v: v

    cur_h = math.ceil(original_h / downsample)
    cur_w = math.ceil(original_w / downsample)
    # Clamp to one token per axis: a factor larger than the grid would
    # otherwise truncate to 0, which F.interpolate rejects.
    new_h = max(1, int(cur_h / factor))
    new_w = max(1, int(cur_w / factor))
    method = args["downsample_method"]

    return lambda v: up_or_downsample(v, cur_w, cur_h, new_w, new_h, method)
|
|
class ToDo(scripts.Script):
    """Always-visible script that patches the loaded model with token downsampling during generation."""

    def title(self):
        """Return the display name of the script."""
        return "Token Downsampling"

    def show(self, is_img2img):
        """Report the script as always visible, regardless of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui_bk(self, *args, **kwargs):
        """
        Build an accordion with the enable checkbox, method dropdown and two factor sliders.

        NOTE(review): this is named ``ui_bk`` rather than ``ui``, so it is presumably
        not invoked by the web UI; ``process`` reads its values from ``shared.opts``.
        """
        with gr.Accordion(open=False, label=self.title()):
            enabled_box = gr.Checkbox(label='Enabled', value=False)
            method_dropdown = gr.Dropdown(
                label="Downsample method",
                choices=["nearest", "bilinear", "bicubic", "nearest-exact"],
                value="nearest-exact",
            )
            depth_1_slider = gr.Slider(
                label='Downsample Factor Depth 1',
                minimum=0.01, maximum=8.0, step=0.01, value=2.0,
            )
            depth_2_slider = gr.Slider(
                label='Downsample Factor Depth 2',
                minimum=0.01, maximum=8.0, step=0.01, value=1.0,
            )

        # Pairs of (component, infotext key or update function).
        self.infotext_fields = (
            (enabled_box, lambda d: gr.Checkbox.update(value="todo_enabled" in d)),
            (method_dropdown, "todo_downsample_method"),
            (depth_1_slider, "todo_downsample_factor_depth_1"),
            (depth_2_slider, "todo_downsample_factor_depth_2"),
        )
        return enabled_box, method_dropdown, depth_1_slider, depth_2_slider

    def process(self, p, *script_args, **kwargs):
        """Patch ``shared.sd_model`` from the ToDo options and record them on ``p``; no-op when disabled."""
        opts = shared.opts
        enabled = opts.todo_enable
        if not enabled:
            return

        method = opts.todo_downsample_method
        factor_depth_1 = opts.todo_downsample_factor_depth_1
        factor_depth_2 = opts.todo_downsample_factor_depth_2

        apply_patch(
            shared.sd_model,
            downsample_factor=factor_depth_1,
            downsample_factor_level_2=factor_depth_2,
            downsample_method=method,
        )

        # Store the settings in the generation parameters of this job.
        params = p.extra_generation_params
        params["todo_enabled"] = enabled
        params["todo_downsample_method"] = method
        params["todo_downsample_factor_depth_1"] = factor_depth_1
        params["todo_downsample_factor_depth_2"] = factor_depth_2

    def postprocess(self, p, processed, *args):
        """Undo the patch on ``shared.sd_model`` once processing is finished."""
        remove_patch(shared.sd_model)
|
|
def ext_on_ui_settings():
    """
    Register the four ToDo options with ``shared.opts`` under the ('todo', "ToDo") section.

    NOTE(review): ``ToDo.process`` stores generation parameters under keys such as
    "todo_enabled", which differ from the infotext names given here - confirm
    whether the two are meant to match.
    """
    def factor_slider_args():
        # A fresh dict per option, so no two options share the same object.
        return {"minimum": 0.01, "maximum": 8.0, "step": 0.01}

    todo_options = [
        (
            "todo_enable",
            shared.OptionInfo(False, "Enable Token downsampling", infotext="todo enable"),
        ),
        (
            "todo_downsample_method",
            shared.OptionInfo(
                "nearest-exact",
                "Downsampling method",
                gr.Dropdown,
                lambda: {"choices": ["nearest-exact", "bilinear", "bicubic", "nearest"]},
                infotext="todo downsample method",
            ),
        ),
        (
            "todo_downsample_factor_depth_1",
            shared.OptionInfo(
                1.0,
                "Downsample factor depth 1",
                gr.Slider,
                factor_slider_args(),
                infotext="todo downsample factor depth 1",
            ),
        ),
        (
            "todo_downsample_factor_depth_2",
            shared.OptionInfo(
                1.0,
                "Downsample factor depth 2",
                gr.Slider,
                factor_slider_args(),
                infotext="todo downsample factor depth 2",
            ),
        ),
    ]

    for option_name, option in todo_options:
        option.section = ('todo', "ToDo")
        shared.opts.add_option(option_name, option)
|
|
# Register the settings builder with the web UI. ``on_ui_settings`` is provided
# by ``modules.script_callbacks``; a bare ``on_ui_settings`` name is not defined
# in this module and would raise NameError at import time.
script_callbacks.on_ui_settings(ext_on_ui_settings)
|
|
def make_todo_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Make a patched class on the fly so we don't have to import any specific modules.
    This patch applies ToDo (token downsampling) to the forward function of the block.

    The returned subclass keeps ``block_class`` in ``_parent`` so the patch can
    be reverted by assigning that class back to the instance.
    """

    class ToDoBlock(block_class):
        # Original class, kept so the instance can be restored later.
        _parent = block_class

        def _forward(self, x: torch.Tensor, context: torch.Tensor = None, mask=None) -> torch.Tensor:
            # norm1(x) is both the attn1 query and (for self-attention) the
            # source of its context, so compute it once.
            normed = self.norm1(x)

            external = context if self.disable_self_attn else None
            if external is None:
                # Self-attention: resize the tokens attn1 attends to.
                attn1_context = compute_merge(x, self._todo_info)(normed)
            else:
                # attn1 attends to the external context here; it is passed on
                # as-is because the resize treats its input as an image token grid.
                attn1_context = external

            x = self.attn1(normed, context=attn1_context) + x
            x = self.attn2(self.norm2(x), context=context) + x
            x = self.ff(self.norm3(x)) + x
            return x

    return ToDoBlock
|
|
def hook_todo_model(model: torch.nn.Module):
    """
    Attach a forward pre-hook to ``model`` that records the input size.

    On every forward call the hook stores dimensions 2 and 3 of the first
    positional input in ``model._todo_info["size"]`` (presumably height and
    width - confirm against the caller). The hook handle is appended to
    ``model._todo_info["hooks"]`` so remove_patch can detach it.
    """
    def record_size(module, inputs):
        first_input = inputs[0]
        module._todo_info["size"] = (first_input.shape[2], first_input.shape[3])

    handle = model.register_forward_pre_hook(record_size)
    model._todo_info["hooks"].append(handle)
|
|
def apply_patch(
        model: torch.nn.Module,
        downsample_method: str = "nearest-exact",
        downsample_factor: float = 2,
        downsample_factor_level_2: float = 1,
        downsample_factor_level_3: float = 1,
):
    """
    Patch ``model.model.diffusion_model`` for token downsampling and return ``model``.

    Any existing patch is removed first. The diffusion model then receives a
    fresh ``_todo_info`` dict (size, hooks, args) and a size-recording hook, and
    every submodule with a class named "BasicTransformerBlock" in its ancestry
    is switched to a ToDoBlock subclass that shares that dict.
    """
    # Start from a clean state so repeated calls never stack patches.
    remove_patch(model)

    unet = model.model.diffusion_model
    unet._todo_info = {
        "size": None,
        "hooks": [],
        "args": dict(
            downsample_method=downsample_method,
            downsample_factor=downsample_factor,
            downsample_factor_level_2=downsample_factor_level_2,
            downsample_factor_level_3=downsample_factor_level_3,
        ),
    }
    hook_todo_model(unet)

    for block in unet.modules():
        if not isinstance_str(block, "BasicTransformerBlock"):
            continue

        block.__class__ = make_todo_block(block.__class__)
        block._todo_info = unet._todo_info

        # The patched _forward reads this attribute; default it when absent.
        if not hasattr(block, "disable_self_attn"):
            block.disable_self_attn = False

    return model
|
|
def remove_patch(model: torch.nn.Module):
    """
    Removes a ToDo patch from ``model`` and its submodules if one was applied.

    Hooks listed in any ``_todo_info["hooks"]`` are detached and the list is
    emptied; modules whose class is named "ToDoBlock" get their original
    class (``_parent``) back. Returns ``model``.
    """
    for submodule in model.modules():
        if hasattr(submodule, "_todo_info"):
            handles = submodule._todo_info["hooks"]
            for handle in handles:
                handle.remove()
            handles.clear()

        if submodule.__class__.__name__ == "ToDoBlock":
            submodule.__class__ = submodule._parent

    return model
|
|
def isinstance_str(x: object, cls_name: str):
    """
    Tell whether any class in the MRO of ``x``'s class is *named* cls_name.

    The comparison is by name only, so the class itself never has to be
    imported - handy when patching classes from other packages.
    """
    return any(klass.__name__ == cls_name for klass in x.__class__.__mro__)
|
|