"""Classifier-Free Guidance implementation."""
import math
import logging
import torch
from src.cond import cond, cond_util
def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
"""Apply classifier-free guidance to predictions."""
# Dynamic CFG rescaling
if "cfg_guider" in model_options:
guider = model_options["cfg_guider"]
if hasattr(guider, "dynamic_cfg_rescaling") and guider.dynamic_cfg_rescaling:
cond_scale = guider._apply_dynamic_cfg_rescaling(cond_pred, uncond_pred, cond_scale)
# Custom sampler CFG
if "sampler_cfg_function" in model_options:
cfg_result = x - model_options["sampler_cfg_function"]({
"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
"timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred,
"model": model, "model_options": model_options,
})
elif math.isclose(cond_scale, 1.0):
cfg_result = cond_pred
else:
cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)
# Post-CFG functions
for fn in model_options.get("sampler_post_cfg_function", []):
cfg_result = fn({
"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model,
"uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x,
})
return cfg_result
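# Hedged illustration (not called by the sampler): the torch.lerp branch above
# computes the textbook CFG combination uncond + cond_scale * (cond - uncond).
# Shapes and values below are arbitrary example inputs.
def _example_cfg_lerp():
    cond_pred = torch.randn(1, 4, 8, 8)
    uncond_pred = torch.randn(1, 4, 8, 8)
    cond_scale = 7.0
    via_lerp = torch.lerp(uncond_pred, cond_pred, cond_scale)
    via_formula = uncond_pred + cond_scale * (cond_pred - uncond_pred)
    assert torch.allclose(via_lerp, via_formula, atol=1e-5)
    return via_lerp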
def sampling_function(model, x, timestep, uncond, condo, cond_scale, model_options={}, seed=None):
    """Perform CFG-guided sampling; `condo` is the positive conditioning (named to avoid shadowing the imported `cond` module)."""
uncond_ = None if (math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False)) else uncond
cond_outputs = cond.calc_cond_batch(model, [condo, uncond_], x, timestep, model_options)
# Pre-CFG functions
for fn in model_options.get("sampler_pre_cfg_function", []):
cond_outputs = fn({
"conds": [condo, uncond_], "conds_out": cond_outputs, "cond_scale": cond_scale,
"timestep": timestep, "input": x, "sigma": timestep,
"model": model, "model_options": model_options,
})
return cfg_function(model, cond_outputs[0], cond_outputs[1], cond_scale, x, timestep, model_options, condo, uncond_)
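# Hedged sketch of a "sampler_pre_cfg_function" hook in the shape consumed by
# sampling_function above: it receives a dict and returns a replacement
# [cond_out, uncond_out] list. The 0.1 blend factor is purely illustrative.
# Register it via model_options["sampler_pre_cfg_function"] = [_example_pre_cfg_hook].
def _example_pre_cfg_hook(args):
    cond_out, uncond_out = args["conds_out"]
    if uncond_out is not None:
        # Nudge the unconditional prediction toward the conditional one.
        uncond_out = torch.lerp(uncond_out, cond_out, 0.1)
    return [cond_out, uncond_out]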
class CFGGuider:
"""Guidance with Classifier-Free Guidance."""
def __init__(self, model_patcher, flux=False, dynamic_cfg_rescaling=False, dynamic_cfg_method="variance",
dynamic_cfg_percentile=95.0, dynamic_cfg_target_scale=7.0,
adaptive_noise_enabled=False, adaptive_noise_method="complexity"):
self.model_patcher = model_patcher
inner_model = getattr(model_patcher, "model", model_patcher)
self.model_options = getattr(
model_patcher,
"model_options",
getattr(inner_model, "model_options", {}),
)
self.original_conds = {}
self.cfg = 1.0
self.cfg_free_enabled = False
self.cfg_free_start_percent = 70.0
self.original_cfg = 1.0
self.sigmas = None
self.flux = flux # Flag for FLUX model behavior
self.dynamic_cfg_rescaling = dynamic_cfg_rescaling
self.dynamic_cfg_method = dynamic_cfg_method
self.dynamic_cfg_percentile = dynamic_cfg_percentile
self.dynamic_cfg_target_scale = dynamic_cfg_target_scale
self.adaptive_noise_enabled = adaptive_noise_enabled
self.adaptive_noise_method = adaptive_noise_method
self.complexity_history = []
self.base_sigmas = None
def set_conds(self, positive, negative):
self.inner_set_conds({"positive": positive, "negative": negative})
def set_cfg(self, cfg):
self.cfg = cfg
self.original_cfg = cfg
def set_cfg_free_params(self, enabled=False, start_percent=70.0):
self.cfg_free_enabled = enabled
self.cfg_free_start_percent = max(0.0, min(100.0, start_percent))
if enabled:
print(f"CFG-Free sampling ACTIVE: CFG will reduce to 0 starting at {start_percent:.0f}% of steps")
def _update_cfg_for_sigma(self, sigma):
"""Update CFG based on current sigma for CFG-free sampling."""
if not self.cfg_free_enabled or self.sigmas is None or len(self.sigmas) <= 1:
return
total_steps = len(self.sigmas) - 1
current_step = min(range(len(self.sigmas)), key=lambda i: abs(float(self.sigmas[i]) - float(sigma)))
progress = (current_step / total_steps) * 100.0 if total_steps > 0 else 0
if progress >= self.cfg_free_start_percent:
remaining = 100.0 - self.cfg_free_start_percent
if remaining > 0:
self.cfg = max(0.0, self.original_cfg * (1.0 - (progress - self.cfg_free_start_percent) / remaining))
else:
self.cfg = self.original_cfg
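    # Worked example of the ramp above, assuming original_cfg=8.0 and
    # cfg_free_start_percent=70.0: at 70% progress cfg stays 8.0, at 85% it is
    # 8.0 * (1 - 15/30) = 4.0, and at 100% it reaches 0.0.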
def _apply_dynamic_cfg_rescaling(self, cond_pred, uncond_pred, cond_scale):
"""Apply dynamic CFG rescaling."""
if not self.dynamic_cfg_rescaling:
return cond_scale
diff = cond_pred - uncond_pred
if self.dynamic_cfg_method == "variance":
variance = min(torch.var(diff).item() / 0.1, 10.0)
adjusted = cond_scale / (1.0 + variance * 0.1)
elif self.dynamic_cfg_method == "range":
low = torch.quantile(diff.flatten(), (100 - self.dynamic_cfg_percentile) / 100).item()
high = torch.quantile(diff.flatten(), self.dynamic_cfg_percentile / 100).item()
adjusted = min(cond_scale / max(high - low, 0.01), self.dynamic_cfg_target_scale)
else:
adjusted = cond_scale
return max(1.0, min(adjusted, 20.0))
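    # Worked example of the "variance" branch, assuming cond_scale=8.0 and
    # torch.var(diff)=0.05: variance = min(0.05 / 0.1, 10.0) = 0.5, so
    # adjusted = 8.0 / (1.0 + 0.5 * 0.1) ~= 7.62, inside the [1.0, 20.0] clamp.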
def inner_set_conds(self, conds):
for k in conds:
self.original_conds[k] = cond.convert_cond(conds[k])
def __call__(self, *args, **kwargs):
return self.predict_noise(*args, **kwargs)
def predict_noise(self, x, timestep, model_options={}, seed=None):
if self.cfg_free_enabled:
self._update_cfg_for_sigma(timestep)
opts = {**model_options, "cfg_guider": self}
result = sampling_function(self.inner_model, x, timestep,
self.conds.get("negative"), self.conds.get("positive"),
self.cfg, model_options=opts, seed=seed)
if self.adaptive_noise_enabled:
self.complexity_history.append(self._calc_complexity(result))
return result
def _calc_complexity(self, prediction):
"""Calculate complexity for adaptive noise."""
if self.adaptive_noise_method == "complexity":
dx = prediction[:, :, :, 1:] - prediction[:, :, :, :-1]
dy = prediction[:, :, 1:, :] - prediction[:, :, :-1, :]
h, w = min(dx.shape[2], dy.shape[2]), min(dx.shape[3], dy.shape[3])
return (dx[:, :, :h, :w].abs() + dy[:, :, :h, :w].abs()).mean().item()
return prediction.var(dim=[2, 3]).mean().item()
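    # Note: the "complexity" branch above is a mean-absolute-gradient
    # (total-variation-style) measure; a flat prediction scores near 0 while a
    # noisy one scores high. The fallback is mean per-channel spatial variance.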
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, pipeline=False):
if latent_image is not None and torch.count_nonzero(latent_image) > 0:
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = cond.process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
self.sigmas = sigmas
if self.adaptive_noise_enabled and len(self.complexity_history) > 0:
if self.base_sigmas is None:
self.base_sigmas = sigmas.clone()
avg = sum(self.complexity_history) / len(self.complexity_history)
sigmas = self.base_sigmas * (1.0 + (avg / max(0.01, avg + 0.1)) * 0.5)
samples = sampler.sample(self, sigmas, {"model_options": self.model_options, "seed": seed},
callback, noise, latent_image, denoise_mask, disable_pbar, pipeline=pipeline)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None,
disable_pbar=False, seed=None, pipeline=False):
self.conds = {k: [a.copy() for a in v] for k, v in self.original_conds.items()}
self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
self.model_patcher, noise.shape, self.conds)
inner_patcher = getattr(self.model_patcher, "model", self.model_patcher)
device = getattr(self.model_patcher, "load_device", getattr(inner_patcher, "load_device", None))
# Handle mock objects in tests
if not isinstance(device, (torch.device, str)):
from src.Device import Device
device = Device.get_torch_device()
output = self.inner_sample(noise.to(device), latent_image.to(device), device, sampler,
sigmas.to(device), denoise_mask, callback, disable_pbar, seed, pipeline)
from src.Device.ModelCache import get_model_cache
get_model_cache().prevent_model_cleanup(self.conds, self.loaded_models)
del self.inner_model, self.conds, self.loaded_models
return output
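# Hedged usage sketch (assumes a prepared model_patcher plus encoded positive/
# negative conds and a sampler/sigmas pair from the surrounding codebase; the
# variable names here are illustrative, not a stable API):
#
#   guider = CFGGuider(model_patcher, dynamic_cfg_rescaling=True)
#   guider.set_conds(positive_cond, negative_cond)
#   guider.set_cfg(7.5)
#   guider.set_cfg_free_params(enabled=True, start_percent=70.0)
#   latent = guider.sample(noise, latent_image, sampler, sigmas, seed=42)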