# Source: fac / fc / sd-webui-todo / scripts / TODO.py
# Uploaded by dikdimon via the SD-Hub extension ("Upload fc using SD-Hub extension"),
# commit 50261d7 (verified).
import math
import torch
import torch.nn.functional as F
import gradio as gr
from modules import scripts
from modules import shared
from typing import Type, Dict, Any, Tuple, Callable
def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
    """Resize a flattened token sequence as if it were a 2-D feature map.

    ``item`` is viewed as ``(batch, cur_h, cur_w, channels)`` (row-major token
    order), moved to channels-first, resized to ``(new_h, new_w)`` with
    ``F.interpolate(mode=method)``, and flattened back to
    ``(batch, new_h * new_w, channels)``.
    """
    batch = item.shape[0]
    # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W) for F.interpolate
    grid = item.reshape(batch, cur_h, cur_w, -1).movedim(-1, 1)
    resized = F.interpolate(grid, size=(new_h, new_w), mode=method)
    # (B, C, H', W') -> (B, H', W', C) -> (B, H'*W', C)
    return resized.movedim(1, -1).reshape(batch, new_h * new_w, -1)
def compute_merge(x: torch.Tensor, todo_info: Dict[str, Any]) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build the key/value resampling function for one transformer block.

    Args:
        x: the block's hidden states; only ``x.shape[1]`` (token count) is read.
        todo_info: the shared dict built by ``apply_patch``. ``"size"`` holds the
            ``(H, W)`` recorded by the forward pre-hook and ``"args"`` holds the
            downsample method and per-depth factors.

    Returns:
        A callable mapping a ``(batch, tokens, channels)`` tensor either to
        itself (no factor configured for this depth, or factor == 1) or to a
        spatially resampled copy produced by ``up_or_downsample``.
    """
    original_h, original_w = todo_info["size"]
    original_tokens = original_h * original_w
    # Spatial reduction of ``x`` relative to the hooked input size (1, 2, 4, ...),
    # recovered from the ratio of token counts.
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))

    args = todo_info["args"]
    factor_by_depth = {
        1: args["downsample_factor"],
        2: args["downsample_factor_level_2"],
        4: args["downsample_factor_level_3"],
    }
    factor = factor_by_depth.get(downsample, 1)
    if factor == 1:
        return lambda v: v

    cur_h = math.ceil(original_h / downsample)
    cur_w = math.ceil(original_w / downsample)
    # Clamp to at least one token per axis: a large factor on a small latent
    # would otherwise request an empty output from F.interpolate.
    new_h = max(1, int(cur_h / factor))
    new_w = max(1, int(cur_w / factor))
    method = args["downsample_method"]
    return lambda v: up_or_downsample(v, cur_w, cur_h, new_w, new_h, method)
class ToDo(scripts.Script):
    """WebUI script that applies Token Downsampling (ToDo) to the loaded model.

    Configuration is read from the global ``todo_*`` options in ``shared.opts``.
    The model is patched in ``process`` and restored in ``postprocess``.
    """

    def title(self):
        return "Token Downsampling"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui_bk(self, *args, **kwargs):
        """Build per-generation ToDo controls inside a collapsed accordion.

        NOTE(review): this is named ``ui_bk`` rather than ``ui``, so it looks
        like a disabled copy of the UI hook; ``process`` reads ``shared.opts``
        instead of these controls. Confirm before re-enabling.
        """
        with gr.Accordion(open=False, label=self.title()):
            enabled_box = gr.Checkbox(label='Enabled', value=False)
            method_dropdown = gr.Dropdown(
                label="Downsample method",
                choices=["nearest", "bilinear", "bicubic", "nearest-exact"],
                value="nearest-exact",
            )
            depth_1_slider = gr.Slider(
                label='Downsample Factor Depth 1',
                minimum=0.01, maximum=8.0, step=0.01, value=2.0,
            )
            depth_2_slider = gr.Slider(
                label='Downsample Factor Depth 2',
                minimum=0.01, maximum=8.0, step=0.01, value=1.0,
            )

        # Pair each control with its infotext key (presumably used when pasting
        # generation parameters); the checkbox is checked whenever the key exists.
        self.infotext_fields = (
            (enabled_box, lambda d: gr.Checkbox.update(value="todo_enabled" in d)),
            (method_dropdown, "todo_downsample_method"),
            (depth_1_slider, "todo_downsample_factor_depth_1"),
            (depth_2_slider, "todo_downsample_factor_depth_2"),
        )
        return enabled_box, method_dropdown, depth_1_slider, depth_2_slider

    def process(self, p, *script_args, **kwargs):
        """Patch ``shared.sd_model`` for this job when ToDo is enabled in settings.

        When enabled, the current ``todo_*`` options are passed to ``apply_patch``
        and recorded in ``p.extra_generation_params``. ``script_args`` and
        ``kwargs`` are accepted but unused.
        """
        opts = shared.opts
        enabled = opts.todo_enable
        if enabled:
            method = opts.todo_downsample_method
            depth_1 = opts.todo_downsample_factor_depth_1
            depth_2 = opts.todo_downsample_factor_depth_2

            apply_patch(
                shared.sd_model,
                downsample_factor=depth_1,
                downsample_factor_level_2=depth_2,
                downsample_method=method,
            )

            p.extra_generation_params.update({
                "todo_enabled": enabled,
                "todo_downsample_method": method,
                "todo_downsample_factor_depth_1": depth_1,
                "todo_downsample_factor_depth_2": depth_2,
            })

    def postprocess(self, p, processed, *args):
        """Restore ``shared.sd_model`` after the job (no-op if it was never patched)."""
        remove_patch(shared.sd_model)
def ext_on_ui_settings():
    """Register the ToDo options in the WebUI settings under the "ToDo" section.

    ``ToDo.process`` reads these values back through ``shared.opts``.
    """
    options = {
        "todo_enable": shared.OptionInfo(
            False, "Enable Token downsampling", infotext="todo enable"),
        "todo_downsample_method": shared.OptionInfo(
            "nearest-exact", "Downsampling method", gr.Dropdown,
            lambda: {"choices": ["nearest-exact", "bilinear", "bicubic", "nearest"]},
            infotext="todo downsample method"),
        "todo_downsample_factor_depth_1": shared.OptionInfo(
            1.0, "Downsample factor depth 1", gr.Slider,
            {"minimum": 0.01, "maximum": 8.0, "step": 0.01},
            infotext="todo downsample factor depth 1"),
        "todo_downsample_factor_depth_2": shared.OptionInfo(
            1.0, "Downsample factor depth 2", gr.Slider,
            {"minimum": 0.01, "maximum": 8.0, "step": 0.01},
            infotext="todo downsample factor depth 2"),
    }
    for name, opt in options.items():
        opt.section = ('todo', "ToDo")
        shared.opts.add_option(name, opt)


# ``on_ui_settings`` lives in ``modules.script_callbacks``; calling it unqualified
# without an import raised NameError as soon as the WebUI loaded this script.
from modules import script_callbacks  # noqa: E402

script_callbacks.on_ui_settings(ext_on_ui_settings)
def make_todo_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """
    Make a patched class on the fly so we don't have to import any specific modules.

    The returned ``ToDoBlock`` subclass overrides ``_forward`` so that the keys/values
    of the block's self-attention (``attn1``) are spatially resampled by the function
    from ``compute_merge`` (Token Downsampling). The original class is stored in
    ``_parent`` so ``remove_patch`` can restore it.
    """

    class ToDoBlock(block_class):
        # Save for unpatching later
        _parent = block_class

        def _forward(self, x: torch.Tensor, context: torch.Tensor = None, mask=None) -> torch.Tensor:
            # ``mask`` is accepted for signature compatibility and is not used.
            normed = self.norm1(x)
            if self.disable_self_attn and context is not None:
                # attn1 attends to the external context here; those tokens do not
                # form a spatial grid, so they must be passed through untouched.
                attn1_context = context
            else:
                downsample_kv = compute_merge(x, self._todo_info)
                attn1_context = downsample_kv(normed)
            x = self.attn1(normed, context=attn1_context) + x
            x = self.attn2(self.norm2(x), context=context) + x
            x = self.ff(self.norm3(x)) + x
            return x

    return ToDoBlock
def hook_todo_model(model: torch.nn.Module):
    """Record the spatial size of the model input on every forward call.

    Registers a forward pre-hook that stores ``(args[0].shape[2], args[0].shape[3])``
    (presumably the latent ``(H, W)`` of a ``(B, C, H, W)`` input; confirm with
    callers) in ``module._todo_info["size"]``. The hook handle is appended to
    ``model._todo_info["hooks"]`` so ``remove_patch`` can detach it.
    """

    def _record_size(module, args):
        latent = args[0]
        module._todo_info["size"] = (latent.shape[2], latent.shape[3])
        # None leaves the forward inputs unchanged.
        return None

    handles = model._todo_info["hooks"]
    handles.append(model.register_forward_pre_hook(_record_size))
def apply_patch(
    model: torch.nn.Module,
    downsample_method: str = "nearest-exact",
    downsample_factor: float = 2,
    downsample_factor_level_2: float = 1,
    downsample_factor_level_3: float = 1,
):
    """Patch ``model.model.diffusion_model`` in place for Token Downsampling.

    Any existing patch is removed first. A single info dict (input size, hook
    handles, and the arguments below) is attached to the diffusion model and
    shared by every patched block. A size-recording pre-hook is then installed,
    and each submodule with a class named ``BasicTransformerBlock`` in its MRO is
    switched to a ``ToDoBlock`` subclass.

    Args:
        model: object exposing ``model.model.diffusion_model``.
        downsample_method: ``F.interpolate`` mode, e.g. "nearest", "bilinear",
            "bicubic" or "nearest-exact".
        downsample_factor: resize divisor at full resolution.
        downsample_factor_level_2: resize divisor at 1/2 resolution.
        downsample_factor_level_3: resize divisor at 1/4 resolution.

    Returns:
        The same ``model`` object.
    """
    # Start clean so repeated calls never stack patches or hooks.
    remove_patch(model)

    unet = model.model.diffusion_model
    shared_info = {
        "size": None,  # filled by the forward pre-hook on each call
        "hooks": [],
        "args": {
            "downsample_method": downsample_method,
            "downsample_factor": downsample_factor,
            "downsample_factor_level_2": downsample_factor_level_2,
            "downsample_factor_level_3": downsample_factor_level_3,
        },
    }
    unet._todo_info = shared_info
    hook_todo_model(unet)

    for module in unet.modules():
        if not isinstance_str(module, "BasicTransformerBlock"):
            continue
        module.__class__ = make_todo_block(module.__class__)
        module._todo_info = shared_info
        # Blocks that predate ``disable_self_attn`` get the default (plain self-attention).
        if not hasattr(module, "disable_self_attn"):
            module.disable_self_attn = False

    return model
def remove_patch(model: torch.nn.Module):
    """Undo ``apply_patch`` on ``model``; safe to call on an unpatched model.

    For every submodule (including ``model`` itself) that has ``_todo_info``, all
    stored hook handles are removed and the handle list is emptied. Any submodule
    whose class is named ``ToDoBlock`` is switched back to its ``_parent`` class.
    The ``_todo_info`` attributes themselves are left in place.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if not hasattr(module, "_todo_info"):
            continue
        handles = module._todo_info["hooks"]
        for handle in handles:
            handle.remove()
        handles.clear()
        if type(module).__name__ == "ToDoBlock":
            module.__class__ = module._parent
    return model
def isinstance_str(x: object, cls_name: str):
    """Return True if any class in ``x``'s MRO has ``__name__ == cls_name``.

    Matching by name avoids importing the class itself, which makes it handy for
    patching classes from modules that may not be importable here.
    """
    return any(klass.__name__ == cls_name for klass in x.__class__.__mro__)