Spaces:
Running on Zero
Running on Zero
File size: 8,089 Bytes
d066167 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | import dataclasses
import torch
import k_diffusion
import inspect
from types import SimpleNamespace
from refnet.util import default
from .scheduler import schedulers, schedulers_map
from .denoiser import CFGDenoiser
# Module-wide fallback settings. ``KDiffusionSampler.initialize`` reads the
# eta and s_* entries from here with ``getattr(defaults, name, fallback)``.
defaults = SimpleNamespace(
    eta_ddim=0.0,
    eta_ancestral=1.0,
    ddim_discretize="uniform",
    s_churn=0.0,
    s_tmin=0.0,
    s_noise=1.0,
    k_sched_type="Automatic",
    sigma_min=0.0,
    sigma_max=0.0,
    rho=0.0,
    eta_noise_seed_delta=0,
    always_discard_next_to_last_sigma=False,
)
@dataclasses.dataclass
class Sampler:
    """Static description of one selectable sampler."""

    # Display name; also the lookup key callers use to select the sampler.
    label: str
    # Name of the sampling function. NOTE(review): consumers also accept a
    # callable here (they test ``callable(funcname)``) -- confirm whether the
    # annotation should be widened.
    funcname: str
    # Alternative identifiers for the same sampler.
    aliases: list
    # Per-sampler flags such as 'scheduler', 'brownian_noise', 'solver_type'
    # and 'discard_next_to_last_sigma'.
    options: dict
# Table of every supported sampler, in the order it is listed to callers.
# ``options`` keys used below: 'scheduler', 'second_order', 'brownian_noise',
# 'solver_type', 'uses_ensd' and 'discard_next_to_last_sigma'.
samplers_k_diffusion = [
    Sampler(
        label='DPM++ 2M',
        funcname='sample_dpmpp_2m',
        aliases=['k_dpmpp_2m'],
        options=dict(scheduler='karras'),
    ),
    Sampler(
        label='DPM++ SDE',
        funcname='sample_dpmpp_sde',
        aliases=['k_dpmpp_sde'],
        options=dict(scheduler='karras', second_order=True, brownian_noise=True),
    ),
    Sampler(
        label='DPM++ 2M SDE',
        funcname='sample_dpmpp_2m_sde',
        aliases=['k_dpmpp_2m_sde'],
        options=dict(scheduler='exponential', brownian_noise=True),
    ),
    Sampler(
        label='DPM++ 2M SDE Heun',
        funcname='sample_dpmpp_2m_sde',
        aliases=['k_dpmpp_2m_sde_heun'],
        options=dict(scheduler='exponential', brownian_noise=True, solver_type='heun'),
    ),
    Sampler(
        label='DPM++ 2S a',
        funcname='sample_dpmpp_2s_ancestral',
        aliases=['k_dpmpp_2s_a'],
        options=dict(scheduler='karras', uses_ensd=True, second_order=True),
    ),
    Sampler(
        label='DPM++ 3M SDE',
        funcname='sample_dpmpp_3m_sde',
        aliases=['k_dpmpp_3m_sde'],
        options=dict(scheduler='exponential', discard_next_to_last_sigma=True, brownian_noise=True),
    ),
    Sampler(
        label='Euler a',
        funcname='sample_euler_ancestral',
        aliases=['k_euler_a', 'k_euler_ancestral'],
        options=dict(uses_ensd=True),
    ),
    Sampler(
        label='Euler',
        funcname='sample_euler',
        aliases=['k_euler'],
        options=dict(),
    ),
    Sampler(
        label='LMS',
        funcname='sample_lms',
        aliases=['k_lms'],
        options=dict(),
    ),
    Sampler(
        label='Heun',
        funcname='sample_heun',
        aliases=['k_heun'],
        options=dict(second_order=True),
    ),
    Sampler(
        label='DPM2',
        funcname='sample_dpm_2',
        aliases=['k_dpm_2'],
        options=dict(scheduler='karras', discard_next_to_last_sigma=True, second_order=True),
    ),
    Sampler(
        label='DPM2 a',
        funcname='sample_dpm_2_ancestral',
        aliases=['k_dpm_2_a'],
        options=dict(scheduler='karras', discard_next_to_last_sigma=True, uses_ensd=True, second_order=True),
    ),
    Sampler(
        label='DPM fast',
        funcname='sample_dpm_fast',
        aliases=['k_dpm_fast'],
        options=dict(uses_ensd=True),
    ),
    Sampler(
        label='DPM adaptive',
        funcname='sample_dpm_adaptive',
        aliases=['k_dpm_ad'],
        options=dict(uses_ensd=True),
    ),
]
# Maps a sampling-function name to the optional s_* keyword arguments that
# may be forwarded to it. Each entry gets its own list object.
sampler_extra_params = {
    # Samplers that take the full churn/tmin/tmax/noise set.
    **{
        funcname: ['s_churn', 's_tmin', 's_tmax', 's_noise']
        for funcname in ('sample_euler', 'sample_heun', 'sample_dpm_2')
    },
    # Samplers that only take the noise multiplier.
    **{
        funcname: ['s_noise']
        for funcname in (
            'sample_dpm_fast',
            'sample_dpm_2_ancestral',
            'sample_dpmpp_2s_ancestral',
            'sample_dpmpp_sde',
            'sample_dpmpp_2m_sde',
            'sample_dpmpp_3m_sde',
        )
    },
}
def kdiffusion_sampler_list():
    """Return the labels of all entries in ``samplers_k_diffusion``, in table order."""
    labels = []
    for entry in samplers_k_diffusion:
        labels.append(entry.label)
    return labels
# Sampler description keyed by its label.
k_diffusion_samplers_map = dict((entry.label, entry) for entry in samplers_k_diffusion)
# Scheduler function keyed by scheduler name.
k_diffusion_scheduler = dict((entry.name, entry.function) for entry in schedulers)
def exists(v):
    """Return ``False`` only when ``v`` is ``None``; falsy values such as 0 still count."""
    if v is None:
        return False
    return True
class KDiffusionSampler:
    """Callable wrapper around one k-diffusion sampling function.

    The sampler is selected by label from ``k_diffusion_samplers_map``. The
    model is wrapped in ``CFGDenoiser`` and that wrapper is what gets handed
    to the sampling function when the instance is called.
    """

    def __init__(self, sampler, scheduler, sd, device):
        """Resolve the sampling function and wrap the model.

        Args:
            sampler: Label of an entry in ``k_diffusion_samplers_map``.
            scheduler: Scheduler name looked up in ``schedulers_map`` by
                ``get_sigmas``; ``'Automatic'`` defers to the sampler's own
                ``'scheduler'`` option.
            sd: Model passed to ``CFGDenoiser``. Its ``sigma_max`` and
                ``sigma_min`` attributes are read here (either may be None).
            device: Passed to ``CFGDenoiser`` and to the scheduler function.

        Raises:
            KeyError: If ``sampler`` is not a known label.
        """
        self.config = k_diffusion_samplers_map[sampler]
        funcname = self.config.funcname
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, funcname)
        self.scheduler_name = scheduler

        self.sd = CFGDenoiser(sd, device)
        self.model_wrap = self.sd.model_wrap
        self.device = device

        self.s_min_uncond = None
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0
        self.eta = None

        # Optional s_* arguments this sampling function may receive. These
        # were previously hard-coded to [], which left the s_* handling in
        # ``initialize`` unreachable and ``sampler_extra_params`` unused.
        if isinstance(funcname, str):
            self.extra_params = list(sampler_extra_params.get(funcname, ()))
        else:
            self.extra_params = []

        # When the model supplies both bounds, overwrite the two ends of the
        # wrapper's sigma table with them.
        if exists(sd.sigma_max) and exists(sd.sigma_min):
            self.model_wrap.sigmas[-1] = sd.sigma_max
            self.model_wrap.sigmas[0] = sd.sigma_min

    def initialize(self):
        """Collect the optional keyword arguments for the sampling function.

        Refreshes ``self.eta`` from ``defaults`` and, for each s_* value whose
        entry in ``defaults`` differs from the instance attribute, stores the
        new value on the instance as well.

        Returns:
            dict: ``eta`` (if the function accepts it) plus every name in
            ``self.extra_params`` that the function accepts.
        """
        self.eta = getattr(defaults, self.eta_option_field, 0.0)

        # Inspect the signature once instead of once per parameter.
        accepted = inspect.signature(self.func).parameters

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if param_name in accepted:
                extra_params_kwargs[param_name] = getattr(self, param_name)

        if 'eta' in accepted:
            extra_params_kwargs['eta'] = self.eta

        if self.extra_params:
            configured = {
                's_churn': getattr(defaults, 's_churn', self.s_churn),
                's_tmin': getattr(defaults, 's_tmin', self.s_tmin),
                # A falsy s_tmax (0) means "no upper bound": keep the current value.
                's_tmax': getattr(defaults, 's_tmax', self.s_tmax) or self.s_tmax,
                's_noise': getattr(defaults, 's_noise', self.s_noise),
            }
            for name, value in configured.items():
                if name in extra_params_kwargs and value != getattr(self, name):
                    extra_params_kwargs[name] = value
                    setattr(self, name, value)

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, seed):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        from k_diffusion.sampling import BrownianTreeNoiseSampler

        # The smallest strictly positive sigma is used as the lower bound.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed)

    def get_sigmas(self, steps, sigmas_min=None, sigmas_max=None):
        """Build the sigma schedule for ``steps`` sampling steps.

        Args:
            steps: Number of steps requested.
            sigmas_min: Optional lower bound; passed through ``default`` with
                ``self.model_wrap.sigma_min`` as the alternative.
            sigmas_max: Optional upper bound; passed through ``default`` with
                ``self.model_wrap.sigma_max`` as the alternative.

        Returns:
            The sigmas from the selected scheduler function, or from
            ``self.model_wrap.get_sigmas`` when no scheduler function is
            available.
        """
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
        if discard_next_to_last_sigma:
            # One extra step is requested so the length is unchanged after
            # the second-to-last sigma is dropped below.
            steps += 1

        # 'Automatic' is resolved to the sampler's preferred scheduler and
        # stored back on the instance.
        if self.scheduler_name == 'Automatic':
            self.scheduler_name = self.config.options.get('scheduler', None)

        scheduler = schedulers_map.get(self.scheduler_name)

        sigma_min = default(sigmas_min, self.model_wrap.sigma_min)
        sigma_max = default(sigmas_max, self.model_wrap.sigma_max)

        if scheduler is None or scheduler.function is None:
            sigmas = self.model_wrap.get_sigmas(steps)
        else:
            sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
            if scheduler.need_inner_model:
                sigmas_kwargs['inner_model'] = self.model_wrap
            sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=self.device)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas

    def __call__(self, x, sigmas, sampler_extra_args, seed, deterministic, steps=None):
        """Run the sampling function on ``x`` scaled by the first sigma.

        Args:
            x: Initial input; multiplied by ``sigmas[0]`` before sampling.
            sigmas: Sigma schedule, e.g. the result of ``get_sigmas``.
            sampler_extra_args: Forwarded as ``extra_args``.
            seed: Seed for the Brownian noise sampler (only used when
                ``deterministic`` is true and the sampler has the
                'brownian_noise' option).
            deterministic: Whether to build an explicit noise sampler.
            steps: Step count for functions that take ``n``. When omitted it
                is derived as ``len(sigmas) - 1``.

        Returns:
            Whatever the sampling function returns.
        """
        x = x * sigmas[0]

        extra_params_kwargs = self.initialize()
        parameters = inspect.signature(self.func).parameters

        if 'n' in parameters:
            # Never forward n=None: fall back to one step per sigma interval.
            # NOTE(review): assumes ``sigmas`` holds steps + 1 values -- confirm
            # against the schedulers in use.
            extra_params_kwargs['n'] = len(sigmas) - 1 if steps is None else steps

        if 'sigma_min' in parameters:
            extra_params_kwargs['sigma_min'] = sigmas[sigmas > 0].min()
            extra_params_kwargs['sigma_max'] = sigmas.max()

        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigmas

        if self.config.options.get('brownian_noise', False):
            # None is passed through when no explicit noise sampler is requested.
            noise_sampler = self.create_noise_sampler(x, sigmas, seed) if deterministic else None
            extra_params_kwargs['noise_sampler'] = noise_sampler

        if self.config.options.get('solver_type', None) == 'heun':
            extra_params_kwargs['solver_type'] = 'heun'

        return self.func(self.sd, x, extra_args=sampler_extra_args, disable=False, **extra_params_kwargs)
|