| """Classifier-Free Guidance implementation.""" | |
| import math | |
| import logging | |
| import torch | |
| from src.cond import cond, cond_util | |


def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep,
                 model_options={}, cond=None, uncond=None):
    """Apply classifier-free guidance to predictions.

    Note: the ``cond`` parameter shadows the imported ``cond`` module inside
    this function; the module is not used here, so the shadowing is harmless.
    """
    # Dynamic CFG rescaling: let an attached guider adjust the scale per step.
    if "cfg_guider" in model_options:
        guider = model_options["cfg_guider"]
        if hasattr(guider, "dynamic_cfg_rescaling") and guider.dynamic_cfg_rescaling:
            cond_scale = guider._apply_dynamic_cfg_rescaling(cond_pred, uncond_pred, cond_scale)
    # Custom sampler CFG. The hook operates in noise-prediction space:
    # x - denoised recovers the predicted noise, and the hook's output is
    # mapped back to denoised space the same way.
    if "sampler_cfg_function" in model_options:
        cfg_result = x - model_options["sampler_cfg_function"]({
            "cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale,
            "timestep": timestep, "input": x, "sigma": timestep,
            "cond_denoised": cond_pred, "uncond_denoised": uncond_pred,
            "model": model, "model_options": model_options,
        })
    elif math.isclose(cond_scale, 1.0):
        # At a scale of 1.0 CFG is a no-op, so the conditional prediction is used as-is.
        cfg_result = cond_pred
    else:
        # torch.lerp(a, b, s) == a + s * (b - a): the standard CFG combination.
        cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)
    # Post-CFG functions
    for fn in model_options.get("sampler_post_cfg_function", []):
        cfg_result = fn({
            "denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model,
            "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
            "sigma": timestep, "model_options": model_options, "input": x,
        })
    return cfg_result
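
# Worked example of the combination above: with cond_scale = 7.5,
# torch.lerp(uncond_pred, cond_pred, 7.5) == uncond_pred + 7.5 * (cond_pred - uncond_pred),
# so the result is extrapolated well past cond_pred along the cond - uncond direction.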


def sampling_function(model, x, timestep, uncond, condo, cond_scale, model_options={}, seed=None):
    """Perform sampling with CFG.

    The positive conditioning parameter is named ``condo`` to avoid shadowing
    the imported ``cond`` module, which is used below.
    """
    # Skip the unconditional pass entirely when CFG is a no-op at scale 1.0.
    uncond_ = None if (math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False)) else uncond
    cond_outputs = cond.calc_cond_batch(model, [condo, uncond_], x, timestep, model_options)
    # Pre-CFG functions
    for fn in model_options.get("sampler_pre_cfg_function", []):
        cond_outputs = fn({
            "conds": [condo, uncond_], "conds_out": cond_outputs, "cond_scale": cond_scale,
            "timestep": timestep, "input": x, "sigma": timestep,
            "model": model, "model_options": model_options,
        })
    return cfg_function(model, cond_outputs[0], cond_outputs[1], cond_scale, x, timestep, model_options, condo, uncond_)
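
# Hook order within a single denoising step, as implemented above:
#   1. sampler_pre_cfg_function  -- may rewrite the raw cond/uncond model outputs
#   2. sampler_cfg_function      -- may replace the CFG combination itself
#   3. sampler_post_cfg_function -- may rewrite the combined (denoised) result
# Each hook receives a single dict argument with the keys shown in the calls above.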


class CFGGuider:
    """Sampling guider implementing Classifier-Free Guidance."""

    def __init__(self, model_patcher, flux=False, dynamic_cfg_rescaling=False, dynamic_cfg_method="variance",
                 dynamic_cfg_percentile=95.0, dynamic_cfg_target_scale=7.0,
                 adaptive_noise_enabled=False, adaptive_noise_method="complexity"):
        self.model_patcher = model_patcher
        inner_model = getattr(model_patcher, "model", model_patcher)
        self.model_options = getattr(
            model_patcher,
            "model_options",
            getattr(inner_model, "model_options", {}),
        )
        self.original_conds = {}
        self.cfg = 1.0
        self.cfg_free_enabled = False
        self.cfg_free_start_percent = 70.0
        self.original_cfg = 1.0
        self.sigmas = None
        self.flux = flux  # Flag for FLUX model behavior
        self.dynamic_cfg_rescaling = dynamic_cfg_rescaling
        self.dynamic_cfg_method = dynamic_cfg_method
        self.dynamic_cfg_percentile = dynamic_cfg_percentile
        self.dynamic_cfg_target_scale = dynamic_cfg_target_scale
        self.adaptive_noise_enabled = adaptive_noise_enabled
        self.adaptive_noise_method = adaptive_noise_method
        self.complexity_history = []
        self.base_sigmas = None
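
    # Typical lifecycle, based on the methods below: construct with a model
    # patcher, then set_conds(positive, negative) and set_cfg(scale), and
    # finally call sample(noise, latent_image, sampler, sigmas, ...).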

    def set_conds(self, positive, negative):
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        self.cfg = cfg
        self.original_cfg = cfg

    def set_cfg_free_params(self, enabled=False, start_percent=70.0):
        self.cfg_free_enabled = enabled
        self.cfg_free_start_percent = max(0.0, min(100.0, start_percent))
        if enabled:
            logging.info("CFG-Free sampling ACTIVE: CFG will reduce to 0 starting at %.0f%% of steps", start_percent)

    def _update_cfg_for_sigma(self, sigma):
        """Update CFG based on current sigma for CFG-free sampling."""
        if not self.cfg_free_enabled or self.sigmas is None or len(self.sigmas) <= 1:
            return
        total_steps = len(self.sigmas) - 1
        # Locate the current step as the index of the sigma closest to `sigma`.
        current_step = min(range(len(self.sigmas)), key=lambda i: abs(float(self.sigmas[i]) - float(sigma)))
        progress = (current_step / total_steps) * 100.0 if total_steps > 0 else 0.0
        if progress >= self.cfg_free_start_percent:
            # Ramp the scale linearly from original_cfg down to 0 over the
            # remaining portion of the schedule.
            remaining = 100.0 - self.cfg_free_start_percent
            if remaining > 0:
                self.cfg = max(0.0, self.original_cfg * (1.0 - (progress - self.cfg_free_start_percent) / remaining))
        else:
            self.cfg = self.original_cfg
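    # Worked example: with 21 sigmas (20 steps) and cfg_free_start_percent = 70,
    # progress stays below 70% through step 13 so cfg == original_cfg there; the
    # ramp starts at step 14 (70%) and cfg falls linearly to 0.0 at step 20 (100%).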

    def _apply_dynamic_cfg_rescaling(self, cond_pred, uncond_pred, cond_scale):
        """Apply dynamic CFG rescaling."""
        if not self.dynamic_cfg_rescaling:
            return cond_scale
        diff = cond_pred - uncond_pred
        if self.dynamic_cfg_method == "variance":
            # Damp the scale as the cond/uncond disagreement grows.
            variance = min(torch.var(diff).item() / 0.1, 10.0)
            adjusted = cond_scale / (1.0 + variance * 0.1)
        elif self.dynamic_cfg_method == "range":
            # Normalize the scale by the inter-percentile spread of the difference.
            low = torch.quantile(diff.flatten(), (100 - self.dynamic_cfg_percentile) / 100).item()
            high = torch.quantile(diff.flatten(), self.dynamic_cfg_percentile / 100).item()
            adjusted = min(cond_scale / max(high - low, 0.01), self.dynamic_cfg_target_scale)
        else:
            adjusted = cond_scale
        return max(1.0, min(adjusted, 20.0))
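    # Worked example (variance method): if torch.var(cond_pred - uncond_pred) is
    # 0.2, the normalized variance is min(0.2 / 0.1, 10.0) = 2.0, so a requested
    # scale of 7.0 becomes 7.0 / (1.0 + 2.0 * 0.1) ~= 5.83 before the final
    # clamp to [1.0, 20.0].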

    def inner_set_conds(self, conds):
        for k in conds:
            self.original_conds[k] = cond.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        if self.cfg_free_enabled:
            self._update_cfg_for_sigma(timestep)
        # Expose this guider to cfg_function via model_options without mutating
        # the caller's dict.
        opts = {**model_options, "cfg_guider": self}
        # Note the argument order: sampling_function takes uncond before cond.
        result = sampling_function(self.inner_model, x, timestep,
                                   self.conds.get("negative"), self.conds.get("positive"),
                                   self.cfg, model_options=opts, seed=seed)
        if self.adaptive_noise_enabled:
            self.complexity_history.append(self._calc_complexity(result))
        return result

    def _calc_complexity(self, prediction):
        """Calculate complexity for adaptive noise."""
        if self.adaptive_noise_method == "complexity":
            # Finite differences along width (dx) and height (dy).
            dx = prediction[:, :, :, 1:] - prediction[:, :, :, :-1]
            dy = prediction[:, :, 1:, :] - prediction[:, :, :-1, :]
            # Crop both to the common spatial extent before combining.
            h, w = min(dx.shape[2], dy.shape[2]), min(dx.shape[3], dy.shape[3])
            return (dx[:, :, :h, :w].abs() + dy[:, :, :h, :w].abs()).mean().item()
        return prediction.var(dim=[2, 3]).mean().item()
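    # The "complexity" method above is a mean absolute finite difference along x
    # and y -- effectively a total-variation-style measure of the prediction;
    # the fallback is the mean per-channel spatial variance.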

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, pipeline=False):
        if latent_image is not None and torch.count_nonzero(latent_image) > 0:
            latent_image = self.inner_model.process_latent_in(latent_image)
        self.conds = cond.process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
        self.sigmas = sigmas
        if self.adaptive_noise_enabled and len(self.complexity_history) > 0:
            if self.base_sigmas is None:
                self.base_sigmas = sigmas.clone()
            avg = sum(self.complexity_history) / len(self.complexity_history)
            # avg / (avg + 0.1) lies in (0, 1), so sigmas scale by at most 1.5x.
            sigmas = self.base_sigmas * (1.0 + (avg / max(0.01, avg + 0.1)) * 0.5)
        samples = sampler.sample(self, sigmas, {"model_options": self.model_options, "seed": seed},
                                 callback, noise, latent_image, denoise_mask, disable_pbar, pipeline=pipeline)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None,
               disable_pbar=False, seed=None, pipeline=False):
        # Work on copies so repeated sampling starts from the original conds.
        self.conds = {k: [a.copy() for a in v] for k, v in self.original_conds.items()}
        self.inner_model, self.conds, self.loaded_models = cond_util.prepare_sampling(
            self.model_patcher, noise.shape, self.conds)
        inner_patcher = getattr(self.model_patcher, "model", self.model_patcher)
        device = getattr(self.model_patcher, "load_device", getattr(inner_patcher, "load_device", None))
        # Handle mock objects in tests
        if not isinstance(device, (torch.device, str)):
            from src.Device import Device
            device = Device.get_torch_device()
        output = self.inner_sample(noise.to(device), latent_image.to(device), device, sampler,
                                   sigmas.to(device), denoise_mask, callback, disable_pbar, seed, pipeline)
        from src.Device.ModelCache import get_model_cache
        get_model_cache().prevent_model_cleanup(self.conds, self.loaded_models)
        del self.inner_model, self.conds, self.loaded_models
        return output
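

if __name__ == "__main__":
    # Minimal smoke test for cfg_function on dummy tensors -- a sketch, not a
    # full sampling run. It only exercises the pure-tensor path (no model, no
    # hooks in model_options), so it runs as long as the src.cond imports at
    # the top of this file resolve.
    x = torch.randn(1, 4, 8, 8)
    cond_pred = torch.randn(1, 4, 8, 8)
    uncond_pred = torch.randn(1, 4, 8, 8)
    out = cfg_function(None, cond_pred, uncond_pred, 7.5, x, timestep=None)
    expected = uncond_pred + 7.5 * (cond_pred - uncond_pred)
    assert torch.allclose(out, expected)
    print("cfg_function smoke test passed:", tuple(out.shape))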