|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
|
|
|
import launch |
|
|
import modules.scripts as scripts |
|
|
import modules.shared as shared |
|
|
from modules.processing import Processed, StableDiffusionProcessing |
|
|
|
|
|
|
|
|
class ToMe:
    """Wrapper around tomesd: registers this extension's options in the
    Settings tab and applies the Token Merging patch to a model using them.
    """

    def __init__(self):
        # Kept for interface parity with other WebUI extensions that expose
        # these lists; nothing in this file appends to them.
        self.infotext_fields = []
        self.paste_field_names = []

    @staticmethod
    def get_settings() -> dict:
        """Read the tomesd patch settings from ``shared.opts``.

        Returns a dict whose keys match ``tomesd.apply_patch`` keyword
        arguments.  Fallback values mirror the defaults registered in
        :meth:`on_ui_settings_callback` so behavior is consistent before the
        options are first saved (the previous code fell back to ``True`` for
        ``tome_random`` while the registered UI default is ``False``).
        """
        data = shared.opts.data
        return {
            'ratio': float(data.get('tome_merging_ratio', 0.5)),
            'max_downsample': int(data.get('tome_maximum_down_sampling', 1)),
            'sx': int(data.get('tome_stride_x', 2)),
            'sy': int(data.get('tome_stride_y', 2)),
            'use_rand': bool(data.get('tome_random', False)),
            'merge_attn': bool(data.get('tome_merge_attention', True)),
            'merge_crossattn': bool(data.get('tome_merge_cross_attention', False)),
            'merge_mlp': bool(data.get('tome_merge_mlp', False)),
        }

    def patch_model(self, model: torch.nn.Module, is_hires: bool):
        """Apply the tomesd patch to *model* with the current settings.

        ``is_hires`` only affects the log message; the same settings are
        used for both the first and the hires pass.
        """
        settings = self.get_settings()

        # Deferred import: tomesd may be missing.  Callers guard with
        # launch.is_installed("tomesd") before invoking this method.
        import tomesd
        tomesd.apply_patch(model, **settings)

        target = 'to hires pass' if is_hires else 'globally'
        detail = ', '.join(f'{key}[{value}]' for key, value in settings.items())
        print(f"Applying ToMe patch {target} with {detail}", file=sys.stderr)

    def on_ui_settings_callback(self):
        """Register every ToMe option under its own settings section."""
        section = ('tome', 'ToMe Settings')
        shared.opts.add_option(
            "tome_merging_ratio",
            shared.OptionInfo(
                0.5,
                "ToMe merging ratio, higher the faster, it should not go over 1-(1/(sx*sy)), which is 0.75 by default",
                gr.Slider, {"minimum": 0, "maximum": 0.99, "step": 0.01},
                section=section))
        shared.opts.add_option(
            "tome_min_x",
            shared.OptionInfo(
                768,
                "Only activate tome if image width reach this value",
                gr.Slider, {"minimum": 512, "maximum": 2048, "step": 256},
                section=section))
        shared.opts.add_option(
            "tome_min_y",
            shared.OptionInfo(
                768,
                "Only activate tome if image height reach this value",
                gr.Slider, {"minimum": 512, "maximum": 2048, "step": 256},
                section=section))
        shared.opts.add_option(
            "tome_random",
            shared.OptionInfo(
                False,
                "Use random perturbations - Disable if you see weird artifacts",
                gr.Checkbox,
                section=section))
        shared.opts.add_option(
            "tome_merge_attention",
            shared.OptionInfo(
                True,
                "Merge attention (recommended)",
                gr.Checkbox,
                section=section))
        shared.opts.add_option(
            "tome_merge_cross_attention",
            shared.OptionInfo(
                False,
                "Merge cross attention (not recommended)",
                gr.Checkbox,
                section=section))
        shared.opts.add_option(
            "tome_merge_mlp",
            shared.OptionInfo(
                False,
                "Merge mlp (very not recommended)",
                gr.Checkbox,
                section=section))
        shared.opts.add_option(
            "tome_maximum_down_sampling",
            shared.OptionInfo(
                "1",
                "Maximum down sampling",
                gr.Radio, {"choices": ["1", "2", "4", "8"]},
                section=section))
        shared.opts.add_option(
            "tome_stride_x",
            shared.OptionInfo(
                2,
                "Stride - X",
                gr.Slider, {"minimum": 2, "maximum": 8, "step": 2},
                section=section))
        shared.opts.add_option(
            "tome_stride_y",
            shared.OptionInfo(
                2,
                "Stride - Y",
                gr.Slider, {"minimum": 2, "maximum": 8, "step": 2},
                section=section))
        shared.opts.add_option(
            "tome_force_hires",
            shared.OptionInfo(
                True,
                "Enable during hires pass (Experimental)",
                gr.Checkbox,
                section=section))
|
|
|
|
|
|
|
|
# Module-level singleton; the Script hooks below use it to (un)patch models,
# and its settings callback is registered once at import time.
_tome_instance = ToMe()


scripts.script_callbacks.on_ui_settings(_tome_instance.on_ui_settings_callback)
|
|
|
|
|
|
|
|
class Script(scripts.Script):
    """Use the script interface to inject ToMe into the generation loop."""

    def title(self):
        return "ToMe"

    def show(self, is_img2img: bool):
        # Always visible so the enable checkbox appears for both txt2img
        # and img2img.
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Build the per-generation UI: a single enable checkbox.

        Availability is decided *before* constructing the component: the
        previous code mutated ``info``/``interactive`` after creation, which
        gradio may ignore once the component has been registered.
        """
        tomesd_available = launch.is_installed("tomesd")
        enable_checkbox = gr.Checkbox(
            value=True,
            label="Enable ToMe optimization",
            info=(
                "Use tomesd to boost generation speed. Other settings in Settings Tab."
                if tomesd_available else
                "Import tomesd failed, please check your environment before using ToMe extension."
            ),
            interactive=tomesd_available)
        return [enable_checkbox]

    @staticmethod
    def _record_params(p: StableDiffusionProcessing):
        """Copy the active ToMe settings into ``p.extra_generation_params``
        so they are embedded in the generated image's infotext.

        Fallback values mirror the defaults registered in the settings
        section (note ``tome_random`` defaults to False there).
        """
        data = shared.opts.data
        p.extra_generation_params.update({
            'tome_merging_ratio': data.get('tome_merging_ratio', 0.5),
            'tome_maximum_down_sampling': int(data.get('tome_maximum_down_sampling', 1)),
            'tome_stride_x': int(data.get('tome_stride_x', 2)),
            'tome_stride_y': int(data.get('tome_stride_y', 2)),
            'tome_random': bool(data.get('tome_random', False)),
            'tome_merge_attention': bool(data.get('tome_merge_attention', True)),
            'tome_merge_cross_attention': bool(data.get('tome_merge_cross_attention', False)),
            'tome_merge_mlp': bool(data.get('tome_merge_mlp', False)),
        })

    def process(self, p: StableDiffusionProcessing, *args):
        """Before the first sampling pass: apply the ToMe patch when the
        checkbox is enabled and the target resolution meets the configured
        minimum width/height thresholds."""
        if not launch.is_installed("tomesd"):
            print(
                "Cannot import tomesd, please install it manually following the instructions on https://github.com/dbolya/tomesd",
                file=sys.stderr)
            return

        # args[0] is the enable checkbox value from ui().
        if not (args and args[0]):
            return

        tome_min_x = shared.opts.data.get('tome_min_x', 768)
        tome_min_y = shared.opts.data.get('tome_min_y', 768)

        if p.width >= tome_min_x and p.height >= tome_min_y:
            _tome_instance.patch_model(p.sd_model, is_hires=False)
            self._record_params(p)
        else:
            print(
                f'Image size [{p.width}*{p.height}] < '
                f'Threshold [{tome_min_x}*{tome_min_y}]'
                ', no need to apply ToMe patch',
                file=sys.stderr)

    def postprocess(self, p: StableDiffusionProcessing, processed: Processed,
                    *args):
        """After generation: remove the patch so it does not leak into a
        later job that has ToMe disabled."""
        if not launch.is_installed("tomesd"):
            return

        if args and args[0]:
            import tomesd
            tomesd.remove_patch(p.sd_model)

    def before_hires_pass(self, p, *args, **kwargs):
        """Re-apply the patch for the hires pass when the experimental
        'tome_force_hires' option is enabled."""
        if not launch.is_installed("tomesd"):
            return

        tome_force_hires: bool = bool(
            shared.opts.data.get('tome_force_hires', True))
        if tome_force_hires and args and args[0]:
            _tome_instance.patch_model(p.sd_model, is_hires=True)
            self._record_params(p)
            p.extra_generation_params['tome_force_hires'] = tome_force_hires

    def post_hires_pass(self, p, *args, **kwargs):
        """After the hires pass: remove the patch applied in
        before_hires_pass."""
        if not launch.is_installed("tomesd"):
            return

        tome_force_hires: bool = bool(
            shared.opts.data.get('tome_force_hires', True))
        if tome_force_hires and args and args[0]:
            import tomesd
            tomesd.remove_patch(p.sd_model)
|
|