|
|
import logging |
|
|
from os import environ |
|
|
import modules.scripts as scripts |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from collections import OrderedDict |
|
|
from typing import Union |
|
|
|
|
|
from modules import script_callbacks |
|
|
from modules.script_callbacks import CFGDenoiserParams |
|
|
try: |
|
|
from modules.rng import randn_like |
|
|
except ImportError: |
|
|
from torch import randn_like |
|
|
|
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO)) |
|
|
|
|
|
""" |
|
|
|
|
|
An implementation of CADS: Unleashing the Diversity of Diffusion Models through Condition-Annealed Sampling for Automatic1111 Webui |
|
|
|
|
|
@inproceedings{ |
|
|
sadat2024cads, |
|
|
title={{CADS}: Unleashing the Diversity of Diffusion Models through Condition-Annealed Sampling}, |
|
|
author={Seyedmorteza Sadat and Jakob Buhmann and Derek Bradley and Otmar Hilliges and Romann M. Weber}, |
|
|
booktitle={The Twelfth International Conference on Learning Representations}, |
|
|
year={2024}, |
|
|
url={https://openreview.net/forum?id=zMoNrajk2X} |
|
|
} |
|
|
|
|
|
Author: v0xie |
|
|
GitHub URL: https://github.com/v0xie/sd-webui-cads |
|
|
|
|
|
""" |
|
|
class CADSExtensionScript(scripts.Script): |
|
|
|
|
|
def title(self): |
|
|
return "CADS" |
|
|
|
|
|
|
|
|
def show(self, is_img2img): |
|
|
return scripts.AlwaysVisible |
|
|
|
|
|
|
|
|
def ui(self, is_img2img): |
|
|
with gr.Accordion('CADS', open=False): |
|
|
active = gr.Checkbox(value=False, default=False, label="Active", elem_id='cads_active') |
|
|
rescale = gr.Checkbox(value=True, default=True, label="Rescale CFG", elem_id = 'cads_rescale') |
|
|
with gr.Row(): |
|
|
step_start = gr.Slider(value = 0, minimum = 0, maximum = 50, step = 1, label="Tau 1 Step", elem_id = 'cads_tau1_step', info="Step to start. (0=disable)") |
|
|
step_stop = gr.Slider(value = 0, minimum = 0, maximum = 50, step = 1, label="Tau 2 Step", elem_id = 'cads_tau1_step', info="Step to start. (0=disable)") |
|
|
with gr.Row(): |
|
|
t1 = gr.Slider(value = 0.6, minimum = 0.0, maximum = 1.0, step = 0.05, label="Tau 1", elem_id = 'cads_tau1', info="Step to start interpolating from full strength. Default 0.6") |
|
|
t2 = gr.Slider(value = 0.9, minimum = 0.0, maximum = 1.0, step = 0.05, label="Tau 2", elem_id = 'cads_tau2', info="Step to stop affecting image. Default 0.9") |
|
|
with gr.Row(): |
|
|
noise_scale = gr.Slider(value = 0.25, minimum = 0.0, maximum = 1.0, step = 0.01, label="Noise Scale", elem_id = 'cads_noise_scale', info='Scale of noise injected at every time step, default 0.25, recommended <= 0.3') |
|
|
mixing_factor= gr.Slider(value = 1.0, minimum = 0.0, maximum = 1.0, step = 0.01, label="Mixing Factor", elem_id = 'cads_mixing_factor', info='Regularization factor, lowering this will increase the diversity of the images with more chance of divergence, default 1.0') |
|
|
with gr.Accordion('Experimental', open=False): |
|
|
apply_to_hr_pass = gr.Checkbox(value=False, default=False, label="Apply to Hires. Fix", elem_id='cads_hr_fix_active', info='Requires a very high denoising value to work. Default False') |
|
|
active.do_not_save_to_config = True |
|
|
rescale.do_not_save_to_config = True |
|
|
t1.do_not_save_to_config = True |
|
|
t2.do_not_save_to_config = True |
|
|
noise_scale.do_not_save_to_config = True |
|
|
mixing_factor.do_not_save_to_config = True |
|
|
apply_to_hr_pass.do_not_save_to_config = True |
|
|
self.infotext_fields = [ |
|
|
(active, lambda d: gr.Checkbox.update(value='CADS Active' in d)), |
|
|
(rescale, 'CADS Rescale'), |
|
|
(step_start, 'CADS Tau 1 Step'), |
|
|
(step_stop, 'CADS Tau 2 Step'), |
|
|
(t1, 'CADS Tau 1'), |
|
|
(t2, 'CADS Tau 2'), |
|
|
(noise_scale, 'CADS Noise Scale'), |
|
|
(mixing_factor, 'CADS Mixing Factor'), |
|
|
(apply_to_hr_pass, 'CADS Apply To Hires. Fix'), |
|
|
] |
|
|
self.paste_field_names = [ |
|
|
'cads_active', |
|
|
'cads_rescale', |
|
|
'cads_tau1', |
|
|
'cads_tau2', |
|
|
'cads_noise_scale', |
|
|
'cads_mixing_factor', |
|
|
'cads_hr_fix_active', |
|
|
] |
|
|
return [active, step_start, step_stop, t1, t2, noise_scale, mixing_factor, rescale, apply_to_hr_pass] |
|
|
|
|
|
def before_process_batch(self, p, active, step_start, step_stop, t1, t2, noise_scale, mixing_factor, rescale, apply_to_hr_pass, *args, **kwargs): |
|
|
self.unhook_callbacks() |
|
|
active = getattr(p, "cads_active", active) |
|
|
if active is False: |
|
|
return |
|
|
steps = getattr(p, "steps", -1) |
|
|
if step_start != 0: |
|
|
step_start = getattr(p, "cads_tau1_step", step_start) |
|
|
t1 = max(min(step_start / steps, 1.0), 0.0) |
|
|
else: |
|
|
t1 = getattr(p, "cads_tau1", t1) |
|
|
|
|
|
if step_stop != 0: |
|
|
step_stop = getattr(p, "cads_tau2_step", step_stop) |
|
|
t2 = max(min(step_stop / steps, 1.0), 0.0) |
|
|
else: |
|
|
t2 = getattr(p, "cads_tau2", t2) |
|
|
|
|
|
noise_scale = getattr(p, "cads_noise_scale", noise_scale) |
|
|
mixing_factor = getattr(p, "cads_mixing_factor", mixing_factor) |
|
|
rescale = getattr(p, "cads_rescale", rescale) |
|
|
apply_to_hr_pass = getattr(p, "cads_hr_fix_active", apply_to_hr_pass) |
|
|
|
|
|
first_pass_steps = getattr(p, "steps", -1) |
|
|
if first_pass_steps <= 0: |
|
|
logger.error("Steps not set, disabling CADS") |
|
|
return |
|
|
|
|
|
p.extra_generation_params.update({ |
|
|
"CADS Active": active, |
|
|
"CADS Tau 1 Step": step_start, |
|
|
"CADS Tau 2 Step": step_stop, |
|
|
"CADS Tau 1": t1, |
|
|
"CADS Tau 2": t2, |
|
|
"CADS Noise Scale": noise_scale, |
|
|
"CADS Mixing Factor": mixing_factor, |
|
|
"CADS Rescale": rescale, |
|
|
"CADS Apply To Hires. Fix": apply_to_hr_pass, |
|
|
}) |
|
|
self.create_hook(p, active, t1, t2, noise_scale, mixing_factor, rescale, first_pass_steps) |
|
|
|
|
|
def create_hook(self, p, active, t1, t2, noise_scale, mixing_factor, rescale, total_sampling_steps, *args, **kwargs): |
|
|
|
|
|
y = lambda params: self.on_cfg_denoiser_callback(params, t1=t1, t2=t2, noise_scale=noise_scale, mixing_factor=mixing_factor, rescale=rescale, total_sampling_steps=total_sampling_steps) |
|
|
|
|
|
logger.debug('Hooked callbacks') |
|
|
script_callbacks.on_cfg_denoiser(y) |
|
|
script_callbacks.on_script_unloaded(self.unhook_callbacks) |
|
|
|
|
|
def postprocess_batch(self, p, active, t1, t2, noise_scale, mixing_factor, rescale, apply_to_hr_pass, *args, **kwargs): |
|
|
self.unhook_callbacks() |
|
|
|
|
|
def unhook_callbacks(self): |
|
|
logger.debug('Unhooked callbacks') |
|
|
script_callbacks.remove_current_script_callbacks() |
|
|
|
|
|
def cads_linear_schedule(self, t, tau1, tau2): |
|
|
""" CADS annealing schedule function """ |
|
|
if t <= tau1: |
|
|
return 1.0 |
|
|
if t>= tau2: |
|
|
return 0.0 |
|
|
gamma = (tau2-t)/(tau2-tau1) |
|
|
return gamma |
|
|
|
|
|
def add_noise(self, y, gamma, noise_scale, psi, rescale=False): |
|
|
""" CADS adding noise to the condition |
|
|
|
|
|
Arguments: |
|
|
y: Input conditioning |
|
|
gamma: Noise level w.r.t t |
|
|
noise_scale (float): Noise scale |
|
|
psi (float): Rescaling factor |
|
|
rescale (bool): Rescale the condition |
|
|
""" |
|
|
y_mean, y_std = torch.mean(y), torch.std(y) |
|
|
y = np.sqrt(gamma) * y + noise_scale * np.sqrt(1-gamma) * randn_like(y) |
|
|
if rescale: |
|
|
y_scaled = (y - torch.mean(y)) / torch.std(y) * y_std + y_mean |
|
|
if not torch.isnan(y_scaled).any(): |
|
|
y = psi * y_scaled + (1 - psi) * y |
|
|
else: |
|
|
logger.debug("Warning: NaN encountered in rescaling") |
|
|
return y |
|
|
|
|
|
def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, t1, t2, noise_scale, mixing_factor, rescale, total_sampling_steps): |
|
|
sampling_step = params.sampling_step |
|
|
total_sampling_step = total_sampling_steps |
|
|
text_cond = params.text_cond |
|
|
text_uncond = params.text_uncond |
|
|
|
|
|
t = 1.0 - max(min(sampling_step / total_sampling_step, 1.0), 0.0) |
|
|
gamma = self.cads_linear_schedule(t, t1, t2) |
|
|
|
|
|
if isinstance(text_cond, torch.Tensor) and isinstance(text_uncond, torch.Tensor): |
|
|
params.text_cond = self.add_noise(text_cond, gamma, noise_scale, mixing_factor, rescale) |
|
|
params.text_uncond = self.add_noise(text_uncond, gamma, noise_scale, mixing_factor, rescale) |
|
|
|
|
|
elif isinstance(text_cond, Union[dict, OrderedDict]) and isinstance(text_uncond, Union[dict, OrderedDict]): |
|
|
params.text_cond['crossattn'] = self.add_noise(text_cond['crossattn'], gamma, noise_scale, mixing_factor, rescale) |
|
|
params.text_uncond['crossattn'] = self.add_noise(text_uncond['crossattn'], gamma, noise_scale, mixing_factor, rescale) |
|
|
params.text_cond['vector'] = self.add_noise(text_cond['vector'], gamma, noise_scale, mixing_factor, rescale) |
|
|
params.text_uncond['vector'] = self.add_noise(text_uncond['vector'], gamma, noise_scale, mixing_factor, rescale) |
|
|
else: |
|
|
logger.error('Unknown text_cond type') |
|
|
pass |
|
|
|
|
|
def before_hr(self, p, *args): |
|
|
self.unhook_callbacks() |
|
|
|
|
|
params = getattr(p, "extra_generation_params", None) |
|
|
if not params: |
|
|
logger.error("Missing attribute extra_generation_params") |
|
|
return |
|
|
|
|
|
active = params.get("CADS Active", False) |
|
|
if active is False: |
|
|
return |
|
|
|
|
|
apply_to_hr_pass = params.get("CADS Apply To Hires. Fix", False) |
|
|
if apply_to_hr_pass is False: |
|
|
logger.debug("Disabled for hires. fix") |
|
|
return |
|
|
|
|
|
t1 = params.get("CADS Tau 1", None) |
|
|
t2 = params.get("CADS Tau 2", None) |
|
|
noise_scale = params.get("CADS Noise Scale", None) |
|
|
mixing_factor = params.get("CADS Mixing Factor", None) |
|
|
rescale = params.get("CADS Rescale", None) |
|
|
|
|
|
if t1 is None or t2 is None or noise_scale is None or mixing_factor is None or rescale is None: |
|
|
logger.error("Missing needed parameters for Hires. fix") |
|
|
return |
|
|
|
|
|
hr_pass_steps = getattr(p, "hr_second_pass_steps", -1) |
|
|
if hr_pass_steps < 0: |
|
|
logger.error("Attribute hr_second_pass_steps not found") |
|
|
return |
|
|
if hr_pass_steps == 0: |
|
|
logger.debug("Using first pass step count for hires. fix") |
|
|
hr_pass_steps = getattr(p, "steps", -1) |
|
|
|
|
|
logger.debug("Enabled for hi-res fix with %i steps, re-hooking CADS", hr_pass_steps) |
|
|
self.create_hook(p, active, t1, t2, noise_scale, mixing_factor, rescale, hr_pass_steps) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cads_apply_override(field, boolean: bool = False): |
|
|
def fun(p, x, xs): |
|
|
if boolean: |
|
|
x = True if x.lower() == "true" else False |
|
|
setattr(p, field, x) |
|
|
return fun |
|
|
|
|
|
def cads_apply_field(field): |
|
|
def fun(p, x, xs): |
|
|
if not hasattr(p, "cads_active"): |
|
|
setattr(p, "cads_active", True) |
|
|
setattr(p, field, x) |
|
|
|
|
|
return fun |
|
|
|
|
|
def make_axis_options(): |
|
|
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module |
|
|
|
|
|
if not hasattr(xyz_grid, "boolean_choice"): |
|
|
xyz_grid.boolean_choice = lambda reverse=False: ["True", "False"] if not reverse else ["False", "True"] |
|
|
extra_axis_options = { |
|
|
xyz_grid.AxisOption("[CADS] Active", str, cads_apply_override('cads_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), |
|
|
xyz_grid.AxisOption("[CADS] Rescale CFG", str, cads_apply_override('cads_rescale', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), |
|
|
xyz_grid.AxisOption("[CADS] Tau 1", float, cads_apply_field("cads_tau1")), |
|
|
xyz_grid.AxisOption("[CADS] Tau 2", float, cads_apply_field("cads_tau2")), |
|
|
xyz_grid.AxisOption("[CADS] Noise Scale", float, cads_apply_field("cads_noise_scale")), |
|
|
xyz_grid.AxisOption("[CADS] Mixing Factor", float, cads_apply_field("cads_mixing_factor")), |
|
|
xyz_grid.AxisOption("[CADS] Apply to Hires. Fix", str, cads_apply_override('cads_hr_fix_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), |
|
|
} |
|
|
if not any("[CADS]" in x.label for x in xyz_grid.axis_options): |
|
|
xyz_grid.axis_options.extend(extra_axis_options) |
|
|
|
|
|
def callback_before_ui(): |
|
|
try: |
|
|
make_axis_options() |
|
|
except: |
|
|
logger.exception("CADS: Error while making axis options") |
|
|
|
|
|
script_callbacks.on_before_ui(callback_before_ui) |
|
|
|