dikdimon's picture
Upload extensions using SD-Hub extension
c336648 verified
# Copyright 2023 SLAPaper
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import gradio as gr
import torch
import launch
import modules.scripts as scripts
import modules.shared as shared
from modules.processing import Processed, StableDiffusionProcessing
class ToMe:
    """Bridge between the web UI settings and the ``tomesd`` patcher.

    Reads the ``tome_*`` values from ``shared.opts.data``, applies them to
    a model through ``tomesd.apply_patch`` and registers the matching
    controls in the settings UI.
    """

    def __init__(self):
        # Both lists start out empty; nothing in this class fills them.
        self.paste_field_names = []
        self.infotext_fields = []

    def patch_model(self, model: torch.nn.Module, is_hires: bool):
        """Apply the ToMe patch to ``model`` using the current settings.

        Args:
            model: Module handed to ``tomesd.apply_patch``.
            is_hires: Only affects the wording of the log line ("to hires
                pass" versus "globally").
        """
        settings = shared.opts.data

        # Keyword names match the ``tomesd.apply_patch`` parameters so the
        # same mapping can feed both the call and the log message.
        # NOTE(review): the fallback for 'tome_random' is True here, while
        # on_ui_settings_callback registers that option with default False
        # -- confirm which value is intended when the key is absent.
        patch_kwargs = {
            'ratio': settings.get('tome_merging_ratio', 0.5),
            'max_downsample': int(
                settings.get('tome_maximum_down_sampling', 1)),
            'sx': int(settings.get('tome_stride_x', 2)),
            'sy': int(settings.get('tome_stride_y', 2)),
            'use_rand': bool(settings.get('tome_random', True)),
            'merge_attn': bool(settings.get('tome_merge_attention', True)),
            'merge_crossattn': bool(
                settings.get('tome_merge_cross_attention', False)),
            'merge_mlp': bool(settings.get('tome_merge_mlp', False)),
        }

        # Imported lazily: tomesd is optional and may not be installed.
        import tomesd

        tomesd.apply_patch(model, **patch_kwargs)

        if is_hires:
            target = 'to hires pass'
        else:
            target = 'globally'

        message = (
            "Applying ToMe patch {target} with ratio[{ratio}], "
            "max_downsample[{max_downsample}], sx[{sx}], sy[{sy}], "
            "use_rand[{use_rand}], merge_attn[{merge_attn}], "
            "merge_crossattn[{merge_crossattn}], merge_mlp[{merge_mlp}]"
        ).format(target=target, **patch_kwargs)
        print(message, file=sys.stderr)

    def on_ui_settings_callback(self):
        """Register every ToMe option in the 'ToMe Settings' section."""
        section = ('tome', 'ToMe Settings')

        # One row per option: (key, default, label, component, extra), where
        # ``extra`` holds the optional component-arguments positional.
        option_table = [
            ("tome_merging_ratio", 0.5,
             "ToMe merging ratio, higher the faster, it should not go over 1-(1/(sx*sy)), which is 0.75 by default",
             gr.Slider,
             ({"minimum": 0, "maximum": 0.99, "step": 0.01},)),
            ("tome_min_x", 768,
             "Only activate tome if image width reach this value",
             gr.Slider,
             ({"minimum": 512, "maximum": 2048, "step": 256},)),
            ("tome_min_y", 768,
             "Only activate tome if image height reach this value",
             gr.Slider,
             ({"minimum": 512, "maximum": 2048, "step": 256},)),
            ("tome_random", False,
             "Use random perturbations - Disable if you see werid artifacts",
             gr.Checkbox,
             ()),
            ("tome_merge_attention", True,
             "Merge attention (recommended)",
             gr.Checkbox,
             ()),
            ("tome_merge_cross_attention", False,
             "Merge cross attention (not recommended)",
             gr.Checkbox,
             ()),
            ("tome_merge_mlp", False,
             "Merge mlp (very not recommended)",
             gr.Checkbox,
             ()),
            ("tome_maximum_down_sampling", "1",
             "Maximum down sampling",
             gr.Radio,
             ({"choices": ["1", "2", "4", "8"]},)),
            ("tome_stride_x", 2,
             "Stride - X",
             gr.Slider,
             ({"minimum": 2, "maximum": 8, "step": 2},)),
            ("tome_stride_y", 2,
             "Stride - Y",
             gr.Slider,
             ({"minimum": 2, "maximum": 8, "step": 2},)),
            ("tome_force_hires", True,
             "Enable during hires pass (Experimental)",
             gr.Checkbox,
             ()),
        ]

        for key, default, label, component, extra in option_table:
            shared.opts.add_option(
                key,
                shared.OptionInfo(default, label, component, *extra,
                                  section=section))
# Single module-level ToMe object; the Script hooks in this file call its
# patch_model() method.
_tome_instance = ToMe()
# Pass the settings callback to script_callbacks.on_ui_settings; presumably
# the host invokes it when it builds the Settings tab -- confirm against the
# web UI's script_callbacks module.
scripts.script_callbacks.on_ui_settings(_tome_instance.on_ui_settings_callback)
class Script(scripts.Script):
    """Use script interface to inject ToMe into the image generation loop.

    Every hook is a no-op when ``tomesd`` is not installed or when the
    first script argument is falsy. ``args[0]`` is presumably the value of
    the checkbox returned by ``ui()`` -- confirm against the host's script
    runner.
    """

    def title(self):
        """Return the display name of this script."""
        return "ToMe"

    def show(self, is_img2img: bool):
        """Return ``scripts.AlwaysVisible`` regardless of the tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Build the enable checkbox and return it as a one-element list.

        When ``tomesd`` is not installed the checkbox's info text is
        replaced with an error hint and it is made non-interactive.
        """
        enable_checkbox = gr.Checkbox(
            value=True,
            label="Enable ToMe optimization",
            info=
            "Use tomesd to boost generation speed. Other settings in Settings Tab.",
            interactive=True)

        if not launch.is_installed("tomesd"):
            enable_checkbox.info = "Import tomesd failed, please check your environment before using ToMe extension."
            enable_checkbox.interactive = False

        return [enable_checkbox]

    @staticmethod
    def _is_enabled(args) -> bool:
        """Return True when the first script argument is present and truthy."""
        return bool(args and args[0])

    @staticmethod
    def _record_generation_params(p, force_hires=None):
        """Copy the current ToMe settings into ``p.extra_generation_params``.

        Args:
            p: Processing object whose ``extra_generation_params`` mapping
                receives the values.
            force_hires: When not ``None``, additionally stored under
                ``'tome_force_hires'``.

        Keys are written in a fixed order so the resulting generation info
        is laid out the same way on every call.
        """
        settings = shared.opts.data
        params = p.extra_generation_params

        params['tome_merging_ratio'] = settings.get(
            'tome_merging_ratio', 0.5)
        params['tome_maximum_down_sampling'] = int(
            settings.get('tome_maximum_down_sampling', 1))
        params['tome_stride_x'] = int(settings.get('tome_stride_x', 2))
        params['tome_stride_y'] = int(settings.get('tome_stride_y', 2))
        # NOTE(review): fallback True differs from the default (False) the
        # settings callback registers for 'tome_random'; it is kept here so
        # the recorded value matches what ToMe.patch_model applies.
        params['tome_random'] = bool(settings.get('tome_random', True))
        params['tome_merge_attention'] = bool(
            settings.get('tome_merge_attention', True))
        params['tome_merge_cross_attention'] = bool(
            settings.get('tome_merge_cross_attention', False))
        params['tome_merge_mlp'] = bool(
            settings.get('tome_merge_mlp', False))

        if force_hires is not None:
            params['tome_force_hires'] = force_hires

    def process(self, p: StableDiffusionProcessing, *args):
        """Patch the model for the whole run when the image is large enough.

        The patch is removed again in ``postprocess``. Images smaller than
        the configured ``tome_min_x`` / ``tome_min_y`` thresholds are left
        unpatched and a notice is written to stderr.
        """
        if not launch.is_installed("tomesd"):
            print(
                "Cannot import tomesd, please install it manually following the instructions on https://github.com/dbolya/tomesd",
                file=sys.stderr)
            return

        if not self._is_enabled(args):
            return

        tome_min_x = shared.opts.data.get('tome_min_x', 768)
        tome_min_y = shared.opts.data.get('tome_min_y', 768)

        if p.width >= tome_min_x and p.height >= tome_min_y:
            _tome_instance.patch_model(p.sd_model, is_hires=False)
            self._record_generation_params(p)
            return

        # Reported only on the enabled path, where both thresholds are
        # guaranteed to be bound.
        print(
            f'Image size [{p.width}*{p.height}] < '
            f'Threashold [{tome_min_x}*{tome_min_y}]'
            ', no need to apply ToMe patch',
            file=sys.stderr)

    def postprocess(self, p: StableDiffusionProcessing, processed: Processed,
                    *args):
        """Remove the ToMe patch from ``p.sd_model`` after generation."""
        if not launch.is_installed("tomesd"):
            return

        if self._is_enabled(args):
            import tomesd
            tomesd.remove_patch(p.sd_model)

    def before_hires_pass(self, p, *args, **kwargs):
        """Patch the model for the hires pass when the option allows it.

        Requires both the ``tome_force_hires`` setting and the enable
        argument to be truthy. The applied settings, including
        ``tome_force_hires``, are added to the generation info.
        """
        if not launch.is_installed("tomesd"):
            return

        tome_force_hires: bool = bool(
            shared.opts.data.get('tome_force_hires', True))
        if not (tome_force_hires and self._is_enabled(args)):
            return

        _tome_instance.patch_model(p.sd_model, is_hires=True)
        self._record_generation_params(p, force_hires=tome_force_hires)

    def post_hires_pass(self, p, *args, **kwargs):
        """Remove the patch after the hires pass under the same conditions."""
        if not launch.is_installed("tomesd"):
            return

        tome_force_hires: bool = bool(
            shared.opts.data.get('tome_force_hires', True))
        if tome_force_hires and self._is_enabled(args):
            import tomesd
            tomesd.remove_patch(p.sd_model)