"""Classifier-Free Guidance implementation."""
import math
import logging
import torch
from src.cond import cond, cond_util


def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options=None, cond=None, uncond=None):
    """Apply classifier-free guidance to a pair of denoised predictions.

    Args:
        model: Model being sampled; only forwarded to the optional hooks.
        cond_pred: Denoised prediction for the positive conditioning.
        uncond_pred: Denoised prediction for the negative conditioning. When
            ``sampling_function`` skipped the unconditional pass (cfg == 1.0),
            this is whatever ``cond.calc_cond_batch`` returned for a ``None``
            entry rather than a real prediction.
        cond_scale: Guidance scale.
        x: Current noisy input; used to convert denoised predictions to the
            ``x - denoised`` form expected by ``sampler_cfg_function``.
        timestep: Current timestep, passed to the hooks as both ``"timestep"``
            and ``"sigma"``.
        model_options: Option dict (``None`` means empty). Keys read here:
            ``"cfg_guider"`` (object exposing dynamic CFG rescaling),
            ``"sampler_cfg_function"`` (replaces the default guidance formula)
            and ``"sampler_post_cfg_function"`` (list of callables applied to
            the guided result in order).
        cond: Positive conditioning list, forwarded to post-CFG hooks.
        uncond: Negative conditioning list, forwarded to post-CFG hooks;
            ``None`` when the unconditional pass was skipped.

    Returns:
        The guided denoised prediction.
    """
    if model_options is None:
        model_options = {}

    # Dynamic CFG rescaling. Skipped when the unconditional pass was not run
    # (no uncond conditioning at cfg == 1.0): rescaling against a placeholder
    # uncond_pred could raise the scale above 1.0 and amplify cond_pred.
    guider = model_options.get("cfg_guider")
    uncond_skipped = uncond is None and math.isclose(cond_scale, 1.0)
    if getattr(guider, "dynamic_cfg_rescaling", False) and not uncond_skipped:
        cond_scale = guider._apply_dynamic_cfg_rescaling(cond_pred, uncond_pred, cond_scale)

    if "sampler_cfg_function" in model_options:
        # Custom CFG hooks work in "noise" space (x - denoised) and return the
        # guided noise, hence the conversions on both sides.
        cfg_result = x - model_options["sampler_cfg_function"]({
            "cond": x - cond_pred,
            "uncond": x - uncond_pred,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "cond_denoised": cond_pred,
            "uncond_denoised": uncond_pred,
            "model": model,
            "model_options": model_options,
        })
    elif math.isclose(cond_scale, 1.0):
        # Scale 1.0 is exactly the conditional prediction.
        cfg_result = cond_pred
    else:
        # lerp(u, c, s) == u + s * (c - u): the standard CFG combination.
        cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)

    # Post-CFG hooks each receive the current result and return a replacement.
    for fn in model_options.get("sampler_post_cfg_function", []):
        cfg_result = fn({
            "denoised": cfg_result,
            "cond": cond,
            "uncond": uncond,
            "model": model,
            "uncond_denoised": uncond_pred,
            "cond_denoised": cond_pred,
            "sigma": timestep,
            "model_options": model_options,
            "input": x,
        })
    return cfg_result


def sampling_function(model, x, timestep, uncond, condo, cond_scale, model_options=None, seed=None):
    """Run the model on positive/negative conditioning and apply CFG.

    The negative (``uncond``) pass is skipped when ``cond_scale`` is 1.0 unless
    ``model_options["disable_cfg1_optimization"]`` is truthy.

    Args:
        model: Model passed to ``cond.calc_cond_batch`` and the hooks.
        x: Current noisy input.
        timestep: Current timestep/sigma.
        uncond: Negative conditioning list.
        condo: Positive conditioning list.
        cond_scale: Guidance scale.
        model_options: Option dict (``None`` means empty). Also read:
            ``"sampler_pre_cfg_function"`` (list of callables that may replace
            the raw per-conditioning outputs before guidance).
        seed: Accepted for interface compatibility; not used here.

    Returns:
        The guided denoised prediction from ``cfg_function``.
    """
    if model_options is None:
        model_options = {}

    skip_uncond = math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False)
    uncond_ = None if skip_uncond else uncond
    cond_outputs = cond.calc_cond_batch(model, [condo, uncond_], x, timestep, model_options)

    # Pre-CFG hooks may rewrite the raw [cond, uncond] outputs.
    for fn in model_options.get("sampler_pre_cfg_function", []):
        cond_outputs = fn({
            "conds": [condo, uncond_],
            "conds_out": cond_outputs,
            "cond_scale": cond_scale,
            "timestep": timestep,
            "input": x,
            "sigma": timestep,
            "model": model,
            "model_options": model_options,
        })
    return cfg_function(model, cond_outputs[0], cond_outputs[1], cond_scale, x, timestep, model_options, condo, uncond_)


class CFGGuider:
    """Guidance with Classifier-Free Guidance.

    Optional features:
      * CFG-free sampling: linearly decays the scale to 0 over the last part of
        the schedule (``set_cfg_free_params``).
      * Dynamic CFG rescaling: adjusts the scale per step from the spread of
        ``cond_pred - uncond_pred`` (``_apply_dynamic_cfg_rescaling``).
      * Adaptive noise: scales the sigma schedule by a factor derived from the
        complexity of predictions recorded during earlier sampling calls.
    """

    def __init__(self, model_patcher, flux=False, dynamic_cfg_rescaling=False,
                 dynamic_cfg_method="variance", dynamic_cfg_percentile=95.0,
                 dynamic_cfg_target_scale=7.0, adaptive_noise_enabled=False,
                 adaptive_noise_method="complexity"):
        self.model_patcher = model_patcher
        inner_model = getattr(model_patcher, "model", model_patcher)
        # Prefer the patcher's options, fall back to the wrapped model's.
        self.model_options = getattr(
            model_patcher, "model_options",
            getattr(inner_model, "model_options", {}),
        )
        self.original_conds = {}  # name ("positive"/"negative") -> converted conds
        self.cfg = 1.0  # scale used for the current step
        self.cfg_free_enabled = False
        self.cfg_free_start_percent = 70.0
        self.original_cfg = 1.0  # scale set via set_cfg; CFG-free decays from it
        self.sigmas = None  # schedule of the sampling run in progress
        self.flux = flux  # Flag for FLUX model behavior
        self.dynamic_cfg_rescaling = dynamic_cfg_rescaling
        self.dynamic_cfg_method = dynamic_cfg_method
        self.dynamic_cfg_percentile = dynamic_cfg_percentile
        self.dynamic_cfg_target_scale = dynamic_cfg_target_scale
        self.adaptive_noise_enabled = adaptive_noise_enabled
        self.adaptive_noise_method = adaptive_noise_method
        # One complexity value per predict_noise call; never cleared here.
        self.complexity_history = []
        self.base_sigmas = None  # unscaled schedule of the latest adaptive run

    def set_conds(self, positive, negative):
        """Store positive and negative conditioning for later sampling."""
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        """Set the guidance scale (also the starting point for CFG-free decay)."""
        self.cfg = cfg
        self.original_cfg = cfg

    def set_cfg_free_params(self, enabled=False, start_percent=70.0):
        """Configure CFG-free sampling.

        Args:
            enabled: Whether to decay the scale toward 0 late in sampling.
            start_percent: Schedule progress (clamped to [0, 100]) at which
                the decay starts.
        """
        self.cfg_free_enabled = enabled
        self.cfg_free_start_percent = max(0.0, min(100.0, start_percent))
        if enabled:
            print(f"CFG-Free sampling ACTIVE: CFG will reduce to 0 starting at {self.cfg_free_start_percent:.0f}% of steps")

    def _update_cfg_for_sigma(self, sigma):
        """Update ``self.cfg`` from the current step's position in the schedule.

        The step is the index of the schedule entry nearest to ``sigma``. Before
        ``cfg_free_start_percent`` the original scale is restored; after it the
        scale decays linearly toward 0.

        NOTE(review): a scale of 0 makes ``cfg_function`` return ``uncond_pred``
        (lerp weight 0). With dynamic rescaling enabled, the decay effectively
        stops at 1.0 because ``_apply_dynamic_cfg_rescaling`` clamps to >= 1.0.
        """
        if not self.cfg_free_enabled or self.sigmas is None or len(self.sigmas) <= 1:
            return
        total_steps = len(self.sigmas) - 1

        # Pull the schedule to host once instead of syncing per element.
        if isinstance(self.sigmas, torch.Tensor):
            schedule = self.sigmas.detach().reshape(-1).tolist()
        else:
            schedule = [float(s) for s in self.sigmas]
        # The timestep may carry one entry per batch item; float() on a
        # multi-element tensor raises, so read the first entry.
        if isinstance(sigma, torch.Tensor):
            target = float(sigma.detach().reshape(-1)[0])
        else:
            target = float(sigma)

        current_step = min(range(len(schedule)), key=lambda i: abs(schedule[i] - target))
        progress = (current_step / total_steps) * 100.0

        if progress >= self.cfg_free_start_percent:
            remaining = 100.0 - self.cfg_free_start_percent
            if remaining > 0:
                fraction = (progress - self.cfg_free_start_percent) / remaining
                self.cfg = max(0.0, self.original_cfg * (1.0 - fraction))
        else:
            self.cfg = self.original_cfg

    def _apply_dynamic_cfg_rescaling(self, cond_pred, uncond_pred, cond_scale):
        """Return an adjusted guidance scale, clamped to [1.0, 20.0].

        Methods:
            ``"variance"``: shrink the scale as the variance of
                ``cond_pred - uncond_pred`` grows.
            ``"range"``: divide the scale by the inter-percentile range of the
                difference, capped at ``dynamic_cfg_target_scale``.
            anything else: keep the scale (still clamped).
        """
        if not self.dynamic_cfg_rescaling:
            return cond_scale
        diff = cond_pred - uncond_pred
        # torch.quantile only accepts float32/float64; promote half precision.
        if diff.dtype not in (torch.float32, torch.float64):
            diff = diff.float()
        if self.dynamic_cfg_method == "variance":
            variance = min(torch.var(diff).item() / 0.1, 10.0)
            adjusted = cond_scale / (1.0 + variance * 0.1)
        elif self.dynamic_cfg_method == "range":
            flat = diff.flatten()
            qs = torch.tensor(
                [(100 - self.dynamic_cfg_percentile) / 100, self.dynamic_cfg_percentile / 100],
                dtype=flat.dtype, device=flat.device,
            )
            # One quantile call computes both bounds.
            low, high = torch.quantile(flat, qs).tolist()
            adjusted = min(cond_scale / max(high - low, 0.01), self.dynamic_cfg_target_scale)
        else:
            adjusted = cond_scale
        return max(1.0, min(adjusted, 20.0))

    def inner_set_conds(self, conds):
        """Convert and store each named conditioning via ``cond.convert_cond``."""
        for k in conds:
            self.original_conds[k] = cond.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        """Alias for ``predict_noise``."""
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options=None, seed=None):
        """Return the guided prediction for ``x`` at ``timestep``.

        Args:
            x: Current noisy input.
            timestep: Current timestep/sigma.
            model_options: Extra options merged over by ``"cfg_guider"``;
                ``None`` means empty.
            seed: Forwarded to ``sampling_function``.
        """
        if model_options is None:
            model_options = {}
        if self.cfg_free_enabled:
            self._update_cfg_for_sigma(timestep)
        # Expose this guider so cfg_function can apply dynamic CFG rescaling.
        opts = {**model_options, "cfg_guider": self}
        result = sampling_function(self.inner_model, x, timestep,
                                   self.conds.get("negative"), self.conds.get("positive"),
                                   self.cfg, model_options=opts, seed=seed)
        if self.adaptive_noise_enabled:
            self.complexity_history.append(self._calc_complexity(result))
        return result

    def _calc_complexity(self, prediction):
        """Return a scalar complexity measure of ``prediction``.

        ``"complexity"``: mean absolute finite difference along dims 3 and 2.
        Otherwise: mean of the per-(dim 0, dim 1) variance over dims 2 and 3.
        Indexing assumes a 4-D (batch, channel, height, width) layout — TODO
        confirm for other latent ranks.
        """
        if self.adaptive_noise_method == "complexity":
            dx = prediction[:, :, :, 1:] - prediction[:, :, :, :-1]
            dy = prediction[:, :, 1:, :] - prediction[:, :, :-1, :]
            # Crop both gradients to their common spatial extent.
            h, w = min(dx.shape[2], dy.shape[2]), min(dx.shape[3], dy.shape[3])
            return (dx[:, :, :h, :w].abs() + dy[:, :, :h, :w].abs()).mean().item()
        return prediction.var(dim=[2, 3]).mean().item()

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, pipeline=False):
        """Prepare conds/latents, optionally rescale sigmas, and run the sampler.

        Returns:
            ``process_latent_out`` of the sampler output, cast to float32 first.
        """
        # An all-zero (or missing) latent is passed through unprocessed.
        if latent_image is not None and torch.count_nonzero(latent_image) > 0:
            latent_image = self.inner_model.process_latent_in(latent_image)
        self.conds = cond.process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        if self.adaptive_noise_enabled and self.complexity_history:
            # Rebase on this call's schedule: a base captured on an earlier call
            # may differ in length, device or denoise range.
            self.base_sigmas = sigmas.clone()
            avg = sum(self.complexity_history) / len(self.complexity_history)
            sigmas = self.base_sigmas * (1.0 + (avg / max(0.01, avg + 0.1)) * 0.5)

        # Record the schedule the sampler actually uses so the CFG-free
        # nearest-step lookup compares like with like.
        self.sigmas = sigmas
        samples = sampler.sample(self, sigmas, {"model_options": self.model_options, "seed": seed},
                                 callback, noise, latent_image, denoise_mask, disable_pbar, pipeline=pipeline)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, pipeline=False):
        """Sample with the stored conditioning.

        Copies the stored conds, prepares the model via
        ``cond_util.prepare_sampling``, moves inputs to the load device and runs
        ``inner_sample``. Per-run references are released even if sampling
        raises.
        """
        self.conds = {k: [a.copy() for a in v] for k, v in self.original_conds.items()}
        self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
            self.model_patcher, noise.shape, self.conds)
        try:
            inner_patcher = getattr(self.model_patcher, "model", self.model_patcher)
            device = getattr(self.model_patcher, "load_device", getattr(inner_patcher, "load_device", None))
            # Handle mock objects in tests
            if not isinstance(device, (torch.device, str)):
                from src.Device import Device
                device = Device.get_torch_device()

            # inner_sample supports a missing latent; don't call .to() on None.
            if latent_image is not None:
                latent_image = latent_image.to(device)
            output = self.inner_sample(noise.to(device), latent_image, device, sampler,
                                       sigmas.to(device), denoise_mask, callback, disable_pbar, seed, pipeline)

            from src.Device.ModelCache import get_model_cache
            get_model_cache().prevent_model_cleanup(self.conds, self.loaded_models)
        finally:
            # Drop per-run references so a failed run doesn't pin the model.
            del self.inner_model, self.conds, self.loaded_models
        return output