| """ |
| Adept Sampler v5 for Automatic1111 WebUI |
| Complete port with ALL custom samplers from ComfyUI |
| |
| Version: 5.0 |
| """ |
|
|
| import torch |
| import numpy as np |
| import math |
| from tqdm import trange |
| from modules import scripts, shared, script_callbacks |
| import gradio as gr |
| import k_diffusion.sampling |
|
|
| |
| try: |
| from torchvision.transforms.functional import gaussian_blur |
| TORCHVISION_AVAILABLE = True |
| except ImportError: |
| TORCHVISION_AVAILABLE = False |
| print("⚠️ torchvision not available - detail enhancement disabled") |
|
|
| |
| |
| |
# Shared runtime configuration read by the samplers and CFG hooks in this
# module. Keys prefixed with "_cfg_" are step counters written by
# adept_cfg_denoiser_callback().
ADEPT_STATE = dict(
    # Core scaling window and noise controls
    enabled=False,
    scale=1.0,
    shift=0.0,
    start_pct=0.0,
    end_pct=1.0,
    eta=1.0,
    s_noise=1.0,
    adaptive_eta=False,
    scheduler="Standard",
    vae_reflection=False,
    # Custom sampler selection and its tuning knobs
    use_custom_sampler=False,
    custom_sampler="Akashic Solver v2",
    tau=0.5,
    phase_strength=0.5,
    solver_order=2,
    use_corrector=True,
    phase_noise=False,
    enhanced_derivative=False,
    smea_strength=0.0,
    ndb_strength=0.0,
    eqvae_mode="Off",
    # Mirror correction
    mirror_correction_phase=0.5,
    mirror_smooth_phase=False,
    # CFG post-processing features and the runtime mode chosen for them
    cfg_drift_enabled=False,
    cfg_drift_method="mean",
    cfg_drift_intensity=0.5,
    spectral_cfg_enabled=False,
    spectral_multiplier=1.0,
    spectral_percentile=5.0,
    phase_cfg_enabled=False,
    phase_cfg_alpha=2.0,
    phase_cfg_beta=2.0,
    cfg_runtime_mode="off",
    # Step bookkeeping used for phase-aware progress
    _cfg_step_idx=0,
    _cfg_total_steps=1,
)

# Storage for sampler functions (not populated in the visible code).
ORIGINAL_SAMPLERS = {}

# Bookkeeping for the VAE reflection-padding option.
_vae_reflection_active = False
_vae_original_padding_modes = {}

# Currently registered A1111 CFG callbacks, and whether a native
# (set_model_sampler_cfg_function) hook is installed.
_ADEPT_CFG_AFTER_CB = None
_ADEPT_CFG_DENOISER_CB = None
_ADEPT_NATIVE_CFG_HOOK_ACTIVE = False

# Original CFGDenoiser methods captured by patch_cfg_denoiser() so that
# unpatch_cfg_denoiser() can restore them.
_CFGD_ORIG_COMBINE = None
_CFGD_ORIG_COMBINE_EDIT = None
_CFGD_ORIG_FORWARD = None
_CFGD_MONKEYPATCH_ACTIVE = False
# Instance attribute used to hand x/sigma from forward() to the combine methods.
_ADEPT_CFGDENOISER_CTX_ATTR = "_adept_cfg_ctx"
|
|
| |
| |
| |
|
|
def to_d(x, sigma, denoised):
    """Return the ODE derivative ``(x - denoised) / sigma``.

    ``sigma`` is clamped to at least 1e-4 before dividing. If the largest
    magnitude of the result exceeds ``1000 * (1 + sigma / 10)``, every element
    is clipped to that symmetric bound. ``sigma`` must be a tensor because
    ``torch.clamp`` is applied to it (presumably a scalar tensor; confirm
    with callers).
    """
    limit = 1000.0 * (1.0 + sigma / 10.0)
    derivative = (x - denoised) / torch.clamp(sigma, min=1e-4)
    if torch.abs(derivative).max() > limit:
        derivative = torch.clamp(derivative, -limit, limit)
    return derivative
|
|
|
|
def get_ancestral_step(sigma, sigma_next, eta=1.0):
    """Split the move to ``sigma_next`` into deterministic and noise parts.

    Returns ``(sigma_down, sigma_up)``. ``sigma_up`` is
    ``eta * sqrt(sigma_next^2 * (sigma^2 - sigma_next^2) / sigma^2)``, capped at
    ``sigma_next``, and ``sigma_down = sqrt(sigma_next^2 - sigma_up^2)``.
    When ``sigma_next == 0`` both values are ``0.0``.
    """
    if sigma_next == 0:
        return 0.0, 0.0
    next_sq = sigma_next ** 2
    up_candidate = eta * (next_sq * (sigma ** 2 - next_sq) / sigma ** 2) ** 0.5
    sigma_up = min(sigma_next, up_candidate)
    sigma_down = (next_sq - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
|
|
|
|
def compute_dynamic_scale(step_idx, total_steps, base_scale, start_pct, end_pct):
    """
    Return the weight scale for one step, with linear fade-in/fade-out edges.

    Progress is ``step_idx / max(total_steps - 1, 1)``. Outside the clamped
    window ``[start_pct, end_pct]`` the result is 1.0 (no-op). Inside it, the
    scale ramps linearly between 1.0 and ``base_scale`` over a ramp of width
    ``min(0.20 * window, 0.05)``. The ramp is 20 % of the window for narrow
    windows, capped at 0.05 of total progress. A degenerate window
    (width < 1e-6) returns ``base_scale`` directly.
    """
    lo = max(0.0, min(float(start_pct), 1.0))
    hi = max(lo, min(float(end_pct), 1.0))
    steps = max(total_steps, 1)
    t = step_idx / max(steps - 1, 1)

    if not (lo <= t <= hi):
        return 1.0

    span = hi - lo
    if span < 1e-6:
        return float(base_scale)

    ramp = min(0.20 * span, 0.05)
    if ramp > 0:
        if t < lo + ramp:
            return 1.0 + (base_scale - 1.0) * ((t - lo) / ramp)
        if t > hi - ramp:
            return 1.0 + (base_scale - 1.0) * ((hi - t) / ramp)
    return float(base_scale)
|
|
|
|
def default_noise_sampler(x):
    """Return a ``(sigma, sigma_next) -> noise`` callable drawing ``randn_like(x)``.

    The sigma arguments are accepted for API compatibility and ignored.
    """
    def draw(_sigma, _sigma_next):
        return torch.randn_like(x)

    return draw


def get_noise_sampler(x):
    """Return the noise sampler used for ``x`` (currently the default one)."""
    factory = default_noise_sampler
    return factory(x)
|
|
|
|
| |
| |
| |
|
|
def to_d_enhanced_ancestral(x, sigma, denoised, eta, progress, generator=None):
    """Enhanced derivative for ancestral sampling.

    Starts from ``(x - denoised) / max(sigma, 1e-4)`` and adds small random
    perturbations:
      * ``eta > 1``: ``+0.02 * (eta - 1) * noise * progress``
      * ``eta < 1``: ``-0.015 * (1 - eta) * noise * (1 - progress)``
      * ``progress < 0.3``: ``+0.01 * noise``; ``progress > 0.7``: ``-0.008 * noise``
    The result is clipped to ``±500 * (1 + sigma / 10)`` if it exceeds that bound.

    Args:
        x: Current latent.
        sigma: Current noise level (a tensor; ``torch.clamp`` is applied to it).
        denoised: Model's denoised prediction.
        eta: Ancestral eta.
        progress: Sampling progress in [0, 1].
        generator: Optional ``torch.Generator`` for reproducible noise. If it
            cannot be used for the tensor (wrong type or device), global RNG
            is used instead.

    Returns:
        The perturbed, clipped derivative tensor.
    """
    diff = x - denoised
    safe_sigma = torch.clamp(sigma, min=1e-4)
    base_derivative = diff / safe_sigma

    def safe_randn_like(tensor, generator=None):
        if generator is None:
            return torch.randn_like(tensor)
        try:
            return torch.randn(tensor.shape, device=tensor.device, dtype=tensor.dtype, generator=generator)
        except (TypeError, AttributeError, RuntimeError):
            # RuntimeError covers a generator on a different device than the
            # tensor (e.g. CPU generator with a CUDA latent).
            return torch.randn_like(tensor)

    # Eta-dependent perturbation.
    if eta > 1.0:
        eta_correction = 0.02 * (eta - 1.0) * safe_randn_like(diff, generator) * progress
        base_derivative = base_derivative + eta_correction
    elif eta < 1.0:
        eta_correction = 0.015 * (1.0 - eta) * safe_randn_like(diff, generator) * (1.0 - progress)
        base_derivative = base_derivative - eta_correction

    # Phase-dependent perturbation (early: add, late: subtract).
    if progress < 0.3:
        phase_correction = 0.01 * safe_randn_like(diff, generator)
        base_derivative = base_derivative + phase_correction
    elif progress > 0.7:
        phase_correction = 0.008 * safe_randn_like(diff, generator)
        base_derivative = base_derivative - phase_correction

    # Clip only when the sigma-dependent bound is exceeded.
    sigma_adaptive_threshold = 500.0 * (1.0 + sigma / 10.0)
    derivative_max = torch.abs(base_derivative).max()
    if derivative_max > sigma_adaptive_threshold:
        base_derivative = torch.clamp(base_derivative, -sigma_adaptive_threshold, sigma_adaptive_threshold)

    return base_derivative
|
|
|
|
def apply_dynamic_thresholding(x, percentile=0.995, clamp_range=1.0):
    """Dynamic thresholding for high CFG.

    No-op when ``percentile >= 1`` or when every per-sample max magnitude is
    below 5. Otherwise, per sample, take the k-th largest magnitude
    (``k = max(1, int(N * (1 - percentile)))``, floored at 1.0). Clip values
    to ``±2.5`` times that, then scale everything by 0.98. Any error returns
    ``x`` unchanged. ``clamp_range`` is accepted but unused.
    """
    if percentile >= 1.0:
        return x

    try:
        flat = x.view(x.shape[0], -1)
        if torch.abs(flat).max(dim=1, keepdim=True)[0].max() < 5.0:
            return x

        magnitudes = torch.abs(flat)
        k = max(1, int(flat.shape[1] * (1.0 - percentile)))
        kth_largest = torch.topk(magnitudes, k=k, dim=1, largest=True)[0][:, -1:]
        limit = kth_largest.clamp(min=1.0) * 2.5

        flat = torch.where(magnitudes > limit, torch.sign(flat) * limit, flat)
        return (flat * 0.98).view(x.shape)
    except Exception:
        return x
|
|
|
|
def compute_compensation_ratio(r, step_idx, total_steps, base_ratio=1.0):
    """DC-Solver style compensation ratio.

    The phase weight is 1.5 before 30 % progress, 1.0 until 70 %, then 1.3.
    It is modulated by ``1 + 0.1 * tanh(r - 1)``.
    """
    progress = step_idx / max(total_steps - 1, 1)
    phase_weight = 1.5 if progress < 0.3 else (1.0 if progress < 0.7 else 1.3)
    return base_ratio * phase_weight * (1.0 + 0.1 * math.tanh(r - 1.0))
|
|
|
|
def compute_tau_eqvae(progress, base_tau=0.5, phase_strength=0.5):
    """Phase-aware tau for a standard VAE, clamped to [0, 1].

    The multiplier on ``base_tau`` is ``1 + 0.2*ps`` before 30 % progress,
    ``1 - 0.15*ps`` before 60 %, and ``1 - 0.3*ps`` afterwards
    (``ps`` = ``phase_strength``).
    """
    if progress < 0.30:
        adjust = 0.2 * phase_strength
    elif progress < 0.60:
        adjust = -0.15 * phase_strength
    else:
        adjust = -0.3 * phase_strength
    return min(1.0, max(0.0, base_tau * (1.0 + adjust)))
|
|
|
|
def compute_eqvae_tau(progress, base_tau, phase_strength):
    """EQ-VAE tau with shifted phase boundaries (25 % / 55 %), clamped to [0, 1].

    The multiplier on ``base_tau`` is ``1 + 0.10*ps``, then ``1 - 0.10*ps``,
    then ``1 - 0.20*ps`` (``ps`` = ``phase_strength``).
    """
    if progress < 0.25:
        adjust = 0.10 * phase_strength
    elif progress < 0.55:
        adjust = -0.10 * phase_strength
    else:
        adjust = -0.20 * phase_strength
    return min(1.0, max(0.0, base_tau * (1.0 + adjust)))
|
|
|
|
def compute_eqvae_noise_scale(base_s_noise, progress):
    """EQ-VAE noise scale: ``base_s_noise * 0.88 * phase_factor``.

    The phase factor falls linearly from 1.05 to 1.0 over the first 25 % of
    progress, then from 1.0 to 0.95 by 60 %, and stays at 0.95 afterwards.
    """
    damping = 0.88
    if progress < 0.25:
        return base_s_noise * damping * (1.0 + 0.05 * (1.0 - progress / 0.25))
    if progress < 0.60:
        return base_s_noise * damping * (1.0 - 0.05 * ((progress - 0.25) / 0.35))
    return base_s_noise * damping * 0.95
|
|
|
|
def compute_eqvae_ndb(progress, ndb_strength):
    """Native Detail Boost parameters for EQ-VAE.

    Returns ``(blur_sigma, high_freq_boost)``. When ``ndb_strength <= 0`` the
    result is ``(0.5, 0.0)``. Otherwise ``blur_sigma`` is 0.6 and the boost
    ramps 0 -> 0.03 (first 30 %), 0.03 -> 0.10 (to 60 %) and 0.10 -> 0.20 (end),
    each multiplied by ``ndb_strength``.
    """
    if ndb_strength <= 0:
        return 0.5, 0.0

    if progress < 0.30:
        boost = 0.03 * ndb_strength * (progress / 0.30)
    elif progress < 0.60:
        boost = (0.03 + 0.07 * ((progress - 0.30) / 0.30)) * ndb_strength
    else:
        boost = (0.10 + 0.10 * ((progress - 0.60) / 0.40)) * ndb_strength
    return 0.6, boost
|
|
|
|
def compute_native_detail_boost(progress, ndb_strength=0.0):
    """Native Detail Boost parameters for a standard VAE.

    Returns ``(1.0, high_freq_boost)``. The boost is 0 when
    ``ndb_strength <= 0``. Otherwise it ramps 0 -> 0.03 (first 30 %),
    0.03 -> 0.10 (to 60 %) and 0.10 -> 0.18 (end), times ``ndb_strength``.
    """
    if ndb_strength <= 0:
        return 1.0, 0.0

    if progress < 0.30:
        boost = 0.03 * ndb_strength * (progress / 0.30)
    elif progress < 0.60:
        boost = (0.03 + 0.07 * ((progress - 0.30) / 0.30)) * ndb_strength
    else:
        boost = (0.10 + 0.08 * ((progress - 0.60) / 0.40)) * ndb_strength
    return 1.0, boost
|
|
|
|
def compute_smea_factor(progress, smea_strength=0.5):
    """SMEA coherency factor.

    Follows a half-sine from ``1 - smea_strength`` at progress 0 up to 1.0 at
    progress 1. Returns 1.0 when ``smea_strength <= 0``.
    """
    if smea_strength <= 0:
        return 1.0
    blend = 0.5 * (1 + math.sin(math.pi * (progress - 0.5)))
    return 1.0 - smea_strength * (1.0 - blend)
|
|
|
|
|
|
|
|
| |
| |
| |
|
|
def apply_spectral_modulation_clybius(noise_pred, multiplier=1.0, percentile=5.0):
    """
    Clybius Spectral Modulation: apply frequency-domain corrections to a noise prediction.

    Based on ComfyUI-Latent-Modifiers. It should be applied to the guidance
    difference (cond - uncond), NOT to the denoised latent.

    Fourier components whose log-amplitude falls below the ``percentile``
    quantile are boosted (x1.5). Those at or above the ``100 - percentile``
    quantile are damped (x0.5). The mask is raised to ``multiplier``.
    Quantiles are taken per (batch, channel) over the last two dims of
    ``|log_amp|``.

    Dtypes other than float32/float64 (e.g. fp16/bf16) are processed in
    float32 and cast back. ``torch.quantile`` only accepts float/double, and
    half-precision FFT support is limited. Without the upcast such inputs
    always fell back to the unmodulated tensor.

    Args:
        noise_pred: Guidance tensor (cond - uncond); presumably (B, C, H, W),
            since ``flatten(2)`` followed by two ``unsqueeze`` calls must
            restore its rank.
        multiplier: Modulation strength (0 = none, 1 = full Clybius effect). Default: 1.0
        percentile: Upper/lower percentile threshold. Default: 5.0

    Returns:
        Spectrally modulated tensor with ``noise_pred``'s dtype, or
        ``noise_pred`` itself when disabled or on error.
    """
    if multiplier == 0 or percentile <= 0:
        return noise_pred

    try:
        orig_dtype = noise_pred.dtype
        work = noise_pred if orig_dtype in (torch.float32, torch.float64) else noise_pred.float()

        fourier = torch.fft.fft2(work, dim=(-2, -1))
        log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2) + 1e-8)

        log_amp_flat = log_amp.abs().flatten(2)
        quantile_low = torch.quantile(log_amp_flat, percentile * 0.01, dim=2)
        quantile_high = torch.quantile(log_amp_flat, 1 - percentile * 0.01, dim=2)

        # Broadcast the per-(batch, channel) thresholds back over H and W.
        quantile_low = quantile_low.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
        quantile_high = quantile_high.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)

        # 1.5 where below the low quantile, else 1.0.
        mask_low = ((log_amp < quantile_low).float() + 1).clamp_(max=1.5)
        # 1.0 where below the high quantile, else 0.5.
        mask_high = ((log_amp < quantile_high).float()).clamp_(min=0.5)

        filtered_fourier = fourier * ((mask_low * mask_high) ** multiplier)
        result = torch.fft.ifft2(filtered_fourier, dim=(-2, -1)).real
        return result.to(orig_dtype)

    except Exception as e:
        print(f"⚠️ Spectral modulation failed: {e}")
        return noise_pred
|
|
|
|
|
|
def create_spectral_modulation_cfg_hook(multiplier=1.0, percentile=5.0):
    """
    Build a sampler-CFG hook that applies Clybius spectral modulation.

    The returned callable is meant for reForge/Forge-style
    ``set_model_sampler_cfg_function``. Steps:

    1. Map cond/uncond through a sigma-dependent transform.
    2. Pass their difference through ``apply_spectral_modulation_clybius``.
    3. Recombine with ``args["cond_scale"]``.
    4. Map the result back.

    Args:
        multiplier: Modulation strength forwarded to the spectral filter. Default: 1.0
        percentile: Frequency percentile threshold forwarded to the filter. Default: 5.0

    Returns:
        A callable taking the CFG ``args`` dict (keys ``cond``, ``uncond``,
        ``cond_scale``, ``sigma``, ``input``).
    """
    def spectral_cfg_hook(args):
        cond = args["cond"]
        uncond = args["uncond"]
        cond_scale = args["cond_scale"]
        raw_sigma = args["sigma"]
        x_orig = args["input"]

        # Reshape sigma to (batch, 1, 1, ...) so it broadcasts against cond.
        sigma = raw_sigma.view(raw_sigma.shape[:1] + (1,) * (cond.ndim - 1))
        x = x_orig / (sigma * sigma + 1.0)

        def _to_v(pred):
            return ((x - (x_orig - pred)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)

        cond_v = _to_v(cond)
        uncond_v = _to_v(uncond)

        modulated = apply_spectral_modulation_clybius(cond_v - uncond_v, multiplier, percentile)
        x_cfg = uncond_v + cond_scale * modulated

        return x_orig - (x - x_cfg * sigma / (sigma * sigma + 1.0) ** 0.5)

    return spectral_cfg_hook
|
|
|
|
|
|
def apply_combat_cfg_drift(latent, method='mean', intensity=1.0):
    """
    Combat CFG Drift: reduce mean drift from high CFG values.

    Based on ComfyUI-Latent-Modifiers. A per-sample centre is computed over
    every non-batch dimension and subtracted, scaled by ``intensity``. It is
    the mean, or with ``method='median'`` the ``torch.median`` (the lower
    middle value for even counts).

    Works for any tensor with at least two dimensions, not only 4-D latents.
    Non-contiguous tensors are handled via ``reshape``. The previous
    hard-coded ``dim=(1, 2, 3)`` / ``.view`` silently returned the input
    unchanged in those cases.

    Args:
        latent: Tensor whose first dimension is the batch.
        method: 'mean' or 'median'. Default: 'mean'
        intensity: How much drift to remove (0=none, 1=full). Default: 1.0

    Returns:
        Drift-corrected latent. The input itself is returned when
        ``intensity <= 0``, when it has fewer than two dims, or on error.
    """
    if intensity <= 0:
        return latent
    if latent.ndim < 2:
        # No non-batch dimensions to take a statistic over.
        return latent

    try:
        batch = latent.shape[0]
        if method == 'median':
            broadcast_shape = (batch,) + (1,) * (latent.ndim - 1)
            center = latent.reshape(batch, -1).median(dim=-1, keepdim=True)[0]
            center = center.view(broadcast_shape)
        else:
            center = latent.mean(dim=tuple(range(1, latent.ndim)), keepdim=True)

        return latent - center * intensity

    except Exception as e:
        print(f"⚠️ Combat CFG drift failed: {e}")
        return latent
|
|
|
|
|
|
def compute_phase_aware_cfg_scale(base_scale, progress, alpha=2.0, beta=2.0):
    """
    Phase-Aware CFG Scaling: multiply the CFG scale by a progress-dependent factor.

    Inspired by β-CFG (arXiv:2502.10574).

    * ``alpha == beta == 2`` (default): factor ``0.7 + 0.6 * 4p(1-p)``. This is
      0.7 at both ends and 1.3 at p = 0.5.
    * Any other pair: factor ``1 - 0.6 * |p - mode|`` clamped to [0.7, 1.3].
      Here ``mode = (alpha-1)/(alpha+beta-2)``, or 0.5 if ``alpha + beta <= 2``.
      This branch is not continuous with the default one.

    Args:
        base_scale: The user-specified CFG scale.
        progress: Sampling progress (0.0 to 1.0).
        alpha, beta: Shape parameters selecting the curve. Default: 2.0 each.

    Returns:
        The adjusted CFG scale, or ``base_scale`` if the computation raises.
    """
    try:
        if alpha == 2.0 and beta == 2.0:
            factor = 0.7 + 0.6 * (4.0 * progress * (1.0 - progress))
        else:
            peak = (alpha - 1.0) / (alpha + beta - 2.0) if (alpha + beta) > 2 else 0.5
            factor = max(0.7, min(1.3, 1.0 - 0.3 * abs(progress - peak) * 2))
        return base_scale * factor
    except Exception as e:
        print(f"⚠️ Phase-aware CFG scaling failed: {e}")
        return base_scale
|
|
|
|
|
|
| |
| |
| |
|
|
|
|
| |
| |
| |
|
|
def create_phase_aware_native_cfg_hook(base_hook=None, alpha=2.0, beta=2.0):
    """
    Build a native sampler-CFG hook (Forge/reForge ``set_model_sampler_cfg_function``).

    The hook rescales ``cond_scale`` with ``compute_phase_aware_cfg_scale``.
    Progress comes from ``ADEPT_STATE["_cfg_step_idx"]`` and
    ``ADEPT_STATE["_cfg_total_steps"]``. If ``base_hook`` is given, the hook
    delegates to it with the rescaled ``cond_scale``. Otherwise it performs
    the CFG combination itself in the same sigma-transformed space the
    spectral hook uses.
    """
    def hook(args):
        cond = args["cond"]
        uncond = args["uncond"]
        cond_scale = float(args["cond_scale"])
        sigma = args["sigma"]
        x_orig = args["input"]

        steps = max(int(ADEPT_STATE.get("_cfg_total_steps", 1)), 1)
        idx = int(ADEPT_STATE.get("_cfg_step_idx", 0))
        progress = min(max(idx / max(steps - 1, 1), 0.0), 1.0)
        phased_scale = compute_phase_aware_cfg_scale(cond_scale, progress, alpha=alpha, beta=beta)

        forwarded = dict(args)
        forwarded["cond_scale"] = phased_scale
        if base_hook is not None:
            return base_hook(forwarded)

        # Plain CFG with the phased scale.
        s = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
        x = x_orig / (s * s + 1.0)
        cond_v = ((x - (x_orig - cond)) * (s ** 2 + 1.0) ** 0.5) / s
        uncond_v = ((x - (x_orig - uncond)) * (s ** 2 + 1.0) ** 0.5) / s
        x_cfg = uncond_v + phased_scale * (cond_v - uncond_v)
        return x_orig - (x - x_cfg * s / (s * s + 1.0) ** 0.5)

    return hook
|
|
|
|
def create_combined_native_cfg_hook():
    """
    Compose one native CFG hook from the CFG features enabled in ADEPT_STATE.

    Layering: the phase-aware hook (outer) rescales ``cond_scale`` and hands
    off to the spectral hook (inner). Either layer may be absent. Returns
    ``None`` when neither feature is enabled.
    """
    spectral_hook = None
    if ADEPT_STATE.get("spectral_cfg_enabled", False):
        spectral_hook = create_spectral_modulation_cfg_hook(
            multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0),
            percentile=ADEPT_STATE.get("spectral_percentile", 5.0),
        )

    if not ADEPT_STATE.get("phase_cfg_enabled", False):
        return spectral_hook

    return create_phase_aware_native_cfg_hook(
        base_hook=spectral_hook,
        alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0),
        beta=ADEPT_STATE.get("phase_cfg_beta", 2.0),
    )
|
|
|
|
def adept_cfg_denoiser_callback(params):
    """
    A1111 ``on_cfg_denoiser`` callback: record the current step counters.

    The values are stored in ADEPT_STATE for the phase-aware native hook.
    Cond/uncond predictions are not used here.
    """
    step = int(getattr(params, "sampling_step", 0))
    ADEPT_STATE["_cfg_step_idx"] = step
    total = int(getattr(params, "total_sampling_steps", 1))
    ADEPT_STATE["_cfg_total_steps"] = total
|
|
|
|
def adept_after_cfg_callback(params):
    """
    A1111 ``on_cfg_after_cfg`` callback applying Combat CFG Drift to ``params.x``.

    Only drift correction is done here. The after-CFG parameters carry the
    combined result, not separate cond/uncond tensors. Does nothing unless
    both ``enabled`` and ``cfg_drift_enabled`` are set. Errors are printed
    and ignored.
    """
    active = ADEPT_STATE.get("enabled", False) and ADEPT_STATE.get("cfg_drift_enabled", False)
    if not active:
        return
    try:
        corrected = apply_combat_cfg_drift(
            params.x,
            method=ADEPT_STATE.get("cfg_drift_method", "mean"),
            intensity=ADEPT_STATE.get("cfg_drift_intensity", 0.5),
        )
        params.x = corrected
    except Exception as e:
        print(f"⚠️ Adept post-CFG drift callback failed: {e}")
|
|
|
|
def uninstall_a1111_cfg_callbacks():
    """Unregister Adept's CFG callbacks (best effort) and forget their handles."""
    global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB
    for registered in (_ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB):
        if registered is None:
            continue
        try:
            script_callbacks.remove_callbacks_for_function(registered)
        except Exception:
            # Deliberately best-effort: removal failures are ignored.
            pass
    _ADEPT_CFG_AFTER_CB = None
    _ADEPT_CFG_DENOISER_CB = None
|
|
|
|
def install_a1111_cfg_callbacks():
    """(Re)register the step-counter and after-CFG callbacks with A1111."""
    global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB
    # Drop any previous registration first so callbacks are never duplicated.
    uninstall_a1111_cfg_callbacks()
    _ADEPT_CFG_DENOISER_CB, _ADEPT_CFG_AFTER_CB = adept_cfg_denoiser_callback, adept_after_cfg_callback
    script_callbacks.on_cfg_denoiser(_ADEPT_CFG_DENOISER_CB, name="adept_cfg_denoiser")
    script_callbacks.on_cfg_after_cfg(_ADEPT_CFG_AFTER_CB, name="adept_after_cfg")
|
|
|
|
def _get_native_cfg_hook_target():
    """
    Locate a Forge/reForge-like model object that supports
    ``set_model_sampler_cfg_function()``, if one exists.

    Candidates are checked in this order; the first that has the method wins:
      1. ``shared.sd_model``
      2. ``sd_model.forge_objects``
      3. ``sd_model.model``
      4. ``sd_model.model.model``
      5. ``sd_model.forge_objects.unet``

    Candidate 5 was added because Forge-style backends appear to hang the
    UNet ModelPatcher (which owns the method) off ``forge_objects.unet``.
    Without it, native-hook mode could never be selected there. This is an
    assumption about the backend layout; verify on the target backend.

    Returns:
        The first matching object, or None.
    """
    sd = getattr(shared, "sd_model", None)
    forge_objects = getattr(sd, "forge_objects", None)
    inner_model = getattr(sd, "model", None)
    candidates = (
        sd,
        forge_objects,
        inner_model,
        getattr(inner_model, "model", None),
        getattr(forge_objects, "unet", None),
    )
    return next(
        (obj for obj in candidates
         if obj is not None and hasattr(obj, "set_model_sampler_cfg_function")),
        None,
    )
|
|
|
|
def uninstall_native_cfg_hook():
    """
    Clear the native sampler-CFG hook on the detected target and mark it inactive.

    First tries ``target.set_model_sampler_cfg_function(None)``. Some
    backends reject ``None``; ComfyUI-derived patchers appear to inspect the
    argument's signature. In that case the error used to be swallowed and an
    Adept hook kept running. Now, if Adept had installed a hook, the entry is
    removed from the target's ``model_options`` dict directly.
    """
    global _ADEPT_NATIVE_CFG_HOOK_ACTIVE
    target = _get_native_cfg_hook_target()
    if target is not None:
        try:
            target.set_model_sampler_cfg_function(None)
        except Exception:
            # Assumes a ComfyUI-style ``model_options`` dict keyed by
            # "sampler_cfg_function"; verify on the target backend. Only
            # touched when Adept installed the hook, so another extension's
            # hook is not cleared.
            model_options = getattr(target, "model_options", None)
            if _ADEPT_NATIVE_CFG_HOOK_ACTIVE and isinstance(model_options, dict):
                model_options.pop("sampler_cfg_function", None)
    _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
|
|
|
|
def install_native_cfg_hook():
    """
    Install the composite Adept CFG hook on the native target, if one exists.

    Returns:
        False if there is no target or installing the hook raised. True
        otherwise, including when no CFG hook feature is enabled and there is
        nothing to install. In that case (e.g. only Combat CFG Drift is on,
        which runs via the A1111 after-CFG callback), ``None`` is passed
        best-effort. A backend rejecting ``None`` is no longer reported as an
        install failure.
    """
    global _ADEPT_NATIVE_CFG_HOOK_ACTIVE
    target = _get_native_cfg_hook_target()
    if target is None:
        _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
        return False

    hook = create_combined_native_cfg_hook()
    if hook is None:
        try:
            target.set_model_sampler_cfg_function(None)
        except Exception:
            # Some backends refuse None; nothing of ours to install anyway.
            pass
        _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
        return True

    try:
        target.set_model_sampler_cfg_function(hook)
        _ADEPT_NATIVE_CFG_HOOK_ACTIVE = True
        return True
    except Exception as e:
        print(f"⚠️ Adept native CFG hook install failed: {e}")
        _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
        return False
|
|
|
|
| |
| |
| |
|
|
def _adept_cfg_progress_from_denoiser(denoiser):
    """Return sampling progress in [0, 1] from a CFGDenoiser's ``step`` / ``total_steps``.

    A missing or falsy ``total_steps`` is treated as 1. A missing ``step`` is
    treated as 0.
    """
    total = getattr(denoiser, "total_steps", 1) or 1
    total = max(int(total), 1)
    current = int(getattr(denoiser, "step", 0))
    ratio = current / max(total - 1, 1)
    return min(max(ratio, 0.0), 1.0)
|
|
|
|
def _adept_nativeish_cfg_term(x_i, sigma_i, cond_i, uncond_i, scale):
    """
    Return the guidance term to add to ``uncond_i`` for one cond/uncond pair.

    ``scale`` is the final guidance weight. The patched ``combine_denoised``
    passes ``weight * eff_scale``, where ``eff_scale`` already includes the
    phase-aware factor. So only spectral modulation is applied here.
    Previously this went through ``create_combined_native_cfg_hook()``, whose
    phase-aware layer applied the phase factor a second time.

    When spectral modulation is enabled, x/sigma/cond/uncond are fed into the
    same spectral hook used in native-hook mode, for near-parity with Forge.
    Otherwise, or if the hook raises, the plain weighted delta
    ``(cond_i - uncond_i) * scale`` is returned.
    """
    scale = float(scale)
    if abs(scale) < 1e-12:
        return torch.zeros_like(uncond_i)
    if not ADEPT_STATE.get("spectral_cfg_enabled", False):
        return (cond_i - uncond_i) * scale

    hook = create_spectral_modulation_cfg_hook(
        multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0),
        percentile=ADEPT_STATE.get("spectral_percentile", 5.0),
    )
    try:
        combined = hook({
            "cond": cond_i,
            "uncond": uncond_i,
            "cond_scale": scale,
            "sigma": sigma_i,
            "input": x_i,
        })
        # The hook returns uncond + guidance; keep only the guidance part.
        return combined - uncond_i
    except Exception as e:
        print(f"⚠️ Adept native-ish CFG term fallback: {e}")
        return (cond_i - uncond_i) * scale
|
|
|
|
def patch_cfg_denoiser():
    """
    Stock A1111 fallback for Spectral Modulation + Phase-Aware CFG.

    Strategy:
      1. A thin forward() wrapper stashes x / sigma on the instance so the
         combine methods can reach them. The original forward logic is untouched.
      2. The patched combine_denoised() applies the phase-aware factor exactly
         once, from the denoiser's own step counters. When spectral modulation
         is enabled, it routes each cond/uncond pair through
         _adept_nativeish_cfg_term() for near-parity with native-hook mode.
      3. The patched combine_denoised_for_edit_model() does the same for
         pix2pix. Like stock A1111, it builds the result on ``out_uncond``.

    Fixes over the previous version:
      * The edit-model path no longer re-applies phase scaling through the
        composite native hook.
      * The edit-model path no longer substitutes ``out_img_cond`` for the
        ``out_uncond`` base.
      * Phase-only mode no longer goes through the native-ish hook at all.

    Original methods are captured once into module globals so that
    unpatch_cfg_denoiser() can restore them.

    Returns:
        True if CFGDenoiser was patched, False if the import or patch failed.
    """
    global _CFGD_ORIG_COMBINE, _CFGD_ORIG_COMBINE_EDIT, _CFGD_ORIG_FORWARD, _CFGD_MONKEYPATCH_ACTIVE
    try:
        from modules import sd_samplers_cfg_denoiser as sd_cfg
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser patch import failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False

    cls = sd_cfg.CFGDenoiser

    # Capture originals only once so repeated patching never wraps our own wrappers.
    if _CFGD_ORIG_COMBINE is None:
        _CFGD_ORIG_COMBINE = cls.combine_denoised
    if _CFGD_ORIG_COMBINE_EDIT is None:
        _CFGD_ORIG_COMBINE_EDIT = cls.combine_denoised_for_edit_model
    if _CFGD_ORIG_FORWARD is None:
        _CFGD_ORIG_FORWARD = cls.forward

    def adept_forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        # Expose this call's inputs to the combine methods for its duration only.
        setattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, {
            "x": x,
            "sigma": sigma,
            "uncond": uncond,
            "cond": cond,
            "cond_scale": float(cond_scale),
        })
        try:
            return _CFGD_ORIG_FORWARD(self, x, sigma, uncond, cond,
                                      cond_scale, s_min_uncond, image_cond)
        finally:
            try:
                delattr(self, _ADEPT_CFGDENOISER_CTX_ATTR)
            except AttributeError:
                pass

    def _effective_scale(denoiser, cond_scale):
        """Return ``cond_scale`` with the phase-aware factor applied once, if enabled."""
        eff_scale = float(cond_scale)
        if ADEPT_STATE.get("phase_cfg_enabled", False):
            eff_scale = compute_phase_aware_cfg_scale(
                eff_scale, _adept_cfg_progress_from_denoiser(denoiser),
                alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0),
                beta=ADEPT_STATE.get("phase_cfg_beta", 2.0),
            )
        return eff_scale

    def _spectral_settings():
        """Return the spectral-modulation keyword arguments from ADEPT_STATE."""
        return {
            "multiplier": ADEPT_STATE.get("spectral_multiplier", 1.0),
            "percentile": ADEPT_STATE.get("spectral_percentile", 5.0),
        }

    def adept_combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
        ctx = getattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, None) or {}
        x_ctx = ctx.get("x", None)
        sigma_ctx = ctx.get("sigma", None)

        eff_scale = _effective_scale(self, cond_scale)
        spectral = ADEPT_STATE.get("spectral_cfg_enabled", False)
        # Phase scaling is fully handled by eff_scale; only spectral modulation
        # benefits from the x/sigma context.
        use_nativeish = spectral and x_ctx is not None and sigma_ctx is not None

        for i, conds in enumerate(conds_list):
            uncond_i = denoised_uncond[i:i + 1]
            for cond_index, weight in conds:
                cond_i = x_out[cond_index:cond_index + 1]
                scale = float(weight) * eff_scale
                if use_nativeish:
                    denoised[i:i + 1] += _adept_nativeish_cfg_term(
                        x_i=x_ctx[i:i + 1],
                        sigma_i=sigma_ctx[i:i + 1],
                        cond_i=cond_i,
                        uncond_i=uncond_i,
                        scale=scale,
                    )
                else:
                    delta = cond_i - uncond_i
                    if spectral:
                        delta = apply_spectral_modulation_clybius(delta, **_spectral_settings())
                    denoised[i:i + 1] += delta * scale

        return denoised

    def adept_combine_denoised_for_edit_model(self, x_out, cond_scale):
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        ctx = getattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, None) or {}
        x_ctx = ctx.get("x", None)
        sigma_ctx = ctx.get("sigma", None)

        eff_scale = _effective_scale(self, cond_scale)
        spectral = ADEPT_STATE.get("spectral_cfg_enabled", False)

        if spectral and x_ctx is not None and sigma_ctx is not None:
            try:
                hook = create_spectral_modulation_cfg_hook(**_spectral_settings())
                base = hook({
                    "cond": out_cond,
                    "uncond": out_img_cond,
                    "cond_scale": eff_scale,
                    "sigma": sigma_ctx,
                    "input": x_ctx,
                })
                # The hook returns out_img_cond + guidance; stock A1111 builds
                # the result on out_uncond, so keep only the guidance part.
                guidance = base - out_img_cond
                return out_uncond + guidance + self.image_cfg_scale * (out_img_cond - out_uncond)
            except Exception as e:
                print(f"⚠️ Adept edit-model native-ish fallback: {e}")

        delta = out_cond - out_img_cond
        if spectral:
            delta = apply_spectral_modulation_clybius(delta, **_spectral_settings())
        return out_uncond + eff_scale * delta + self.image_cfg_scale * (out_img_cond - out_uncond)

    try:
        cls.forward = adept_forward
        cls.combine_denoised = adept_combine_denoised
        cls.combine_denoised_for_edit_model = adept_combine_denoised_for_edit_model
        _CFGD_MONKEYPATCH_ACTIVE = True
        return True
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser patch failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False
|
|
|
|
def unpatch_cfg_denoiser():
    """Restore the CFGDenoiser methods captured by patch_cfg_denoiser().

    Only methods that were actually captured are restored. Returns True on
    success, False if the module cannot be imported or restoring fails.
    The monkeypatch flag is cleared in every case.
    """
    global _CFGD_MONKEYPATCH_ACTIVE
    try:
        from modules import sd_samplers_cfg_denoiser as sd_cfg
    except Exception:
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False

    cls = sd_cfg.CFGDenoiser
    originals = (
        ("forward", _CFGD_ORIG_FORWARD),
        ("combine_denoised", _CFGD_ORIG_COMBINE),
        ("combine_denoised_for_edit_model", _CFGD_ORIG_COMBINE_EDIT),
    )
    try:
        for attr_name, original in originals:
            if original is not None:
                setattr(cls, attr_name, original)
        _CFGD_MONKEYPATCH_ACTIVE = False
        return True
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser unpatch failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False
|
|
|
|
def configure_cfg_runtime():
    """
    Choose and activate the CFG runtime mode, record it in ADEPT_STATE and return it.

    Modes:
      off               - Adept disabled or no CFG feature enabled; everything cleared
      native-hook       - a set_model_sampler_cfg_function target exists; A1111
                          callbacks plus the native sampler CFG hook
      a1111-monkeypatch - stock A1111 with spectral/phase enabled and the
                          CFGDenoiser patch applied (drift via callback)
      a1111-postcfg     - stock A1111; Combat CFG Drift via the after-CFG callback
                          only (also used if the CFGDenoiser patch fails)
    """
    def _settle(mode):
        ADEPT_STATE["cfg_runtime_mode"] = mode
        return mode

    # Always start from a clean slate so toggling features never leaks state.
    uninstall_a1111_cfg_callbacks()
    uninstall_native_cfg_hook()
    unpatch_cfg_denoiser()

    if not ADEPT_STATE.get("enabled", False):
        return _settle("off")

    drift = ADEPT_STATE.get("cfg_drift_enabled", False)
    spectral = ADEPT_STATE.get("spectral_cfg_enabled", False)
    phase = ADEPT_STATE.get("phase_cfg_enabled", False)
    if not (drift or spectral or phase):
        return _settle("off")

    # Every remaining mode uses the A1111 callbacks (step counters + drift).
    install_a1111_cfg_callbacks()

    if _get_native_cfg_hook_target() is not None:
        install_native_cfg_hook()
        return _settle("native-hook")

    if spectral or phase:
        if patch_cfg_denoiser():
            mode = _settle("a1111-monkeypatch")
            print("✅ Adept: stock A1111 CFGDenoiser monkey-patch active (spectral/phase + drift enabled)")
            return mode
        print("⚠️ Adept: CFGDenoiser monkey-patch failed; falling back to post-CFG drift only")

    return _settle("a1111-postcfg")
|
|
|
|
def _sa_ab2_blend(d_cur, d_prev, r, tau):
    """Blend a variable-step AB2 extrapolation with plain Euler, weighted by ``1 - tau``.

    The two coefficients are renormalised to sum to one unless that sum is ~0.
    """
    tau_blend = 1.0 - tau
    c1 = tau_blend * (1.0 + 0.5 * r) + (1.0 - tau_blend) * 1.0
    c2 = tau_blend * (-0.5 * r)
    c_sum = c1 + c2
    if abs(c_sum) > 1e-8:
        c1 /= c_sum
        c2 /= c_sum
    return c1 * d_cur + c2 * d_prev


def sa_solver_step(x, d_history, sigma, sigma_next, tau, s_noise=1.0, noise_sampler=None,
                   order=2, ndb_strength=0.0, progress=0.0, eqvae_mode=False, eqvae_blur_sigma=None):
    """SA-Solver step - CRITICAL for Akashic Solver.

    Derivative estimate from ``d_history`` (a list of ``(sigma, derivative)``,
    newest last):
      * >= 3 entries and ``order >= 3``: tau-blended 3-term combination, when
        both previous step sizes exceed 1e-6;
      * otherwise >= 2 entries and ``order >= 2``: tau-blended AB2;
      * 1 entry: the latest derivative; none: zeros.

    If ``tau > 0``, ``sigma_next > 0`` and a ``noise_sampler`` is given, the
    step is ancestral: ``sigma_up = tau * sigma_ancestral``, capped at
    ``sigma_next``. The cap keeps ``sigma_down`` real for ``tau > 1``;
    without it, ``sigma_down`` became complex for Python floats or NaN for
    tensors. Noise is then added, optionally with a Native Detail Boost on its
    high-frequency part (needs torchvision). Otherwise a plain Euler step to
    ``sigma_next`` is taken.

    Returns:
        ``(x_next, sigma_up)``. ``sigma_up`` is 0.0 for deterministic steps.
    """
    dt = sigma_next - sigma

    if len(d_history) >= 2 and order >= 2:
        sigma_cur, d_cur = d_history[-1]
        sigma_prev, d_prev = d_history[-2]

        h_prev = sigma_cur - sigma_prev
        r = abs(dt / (h_prev + 1e-8)) if abs(h_prev) > 1e-8 else 1.0
        r = min(r, 2.0)

        d_interp = None
        if len(d_history) >= 3 and order >= 3:
            sigma_0, d_0 = d_history[-3]
            h_0 = sigma_prev - sigma_0
            h_1 = h_prev

            if abs(h_0) > 1e-6 and abs(h_1) > 1e-6:
                r0 = min(abs(h_1 / h_0), 2.0)
                r1 = min(abs(dt / (h_1 + 1e-8)), 2.0)

                tau_blend = 1.0 - tau
                c0_ab3 = 1.0 + (1.0 + r0) * r1 / 2.0
                c1_ab3 = -(1.0 + r0) * r1 / 2.0
                c2_ab3 = r0 * r1 / 2.0
                c0 = tau_blend * c0_ab3 + (1.0 - tau_blend) * 1.0
                c1 = tau_blend * c1_ab3
                c2 = tau_blend * c2_ab3

                c_sum = c0 + c1 + c2
                if abs(c_sum) > 1e-8:
                    c0 /= c_sum
                    c1 /= c_sum
                    c2 /= c_sum
                else:
                    c0, c1, c2 = 1.0, 0.0, 0.0

                d_interp = c0 * d_cur + c1 * d_prev + c2 * d_0

        if d_interp is None:
            # AB2 fallback: order 2, only two entries, or degenerate step sizes.
            d_interp = _sa_ab2_blend(d_cur, d_prev, r, tau)
    elif len(d_history) >= 1:
        d_interp = d_history[-1][1]
    else:
        d_interp = torch.zeros_like(x)

    sigma_up = 0.0
    if tau > 0 and sigma_next > 0 and noise_sampler is not None:
        sigma_ancestral_sq = sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / (sigma ** 2 + 1e-8)
        sigma_ancestral = sigma_ancestral_sq ** 0.5 if sigma_ancestral_sq > 0 else 0.0
        # No-op for tau <= 1; for tau > 1 keeps sigma_down = sqrt(next^2 - up^2) real.
        sigma_up = min(tau * sigma_ancestral, sigma_next)

        sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
        dt_adjusted = sigma_down - sigma

        x_det = x + d_interp * dt_adjusted
        noise = noise_sampler(sigma, sigma_next) * s_noise * sigma_up

        # Native Detail Boost: amplify the high-frequency part of the noise.
        if ndb_strength > 0 and TORCHVISION_AVAILABLE:
            if eqvae_mode:
                blur_sigma, high_freq_boost = compute_eqvae_ndb(progress, ndb_strength)
            else:
                _, high_freq_boost = compute_native_detail_boost(progress, ndb_strength)
                blur_sigma = 0.5

            if eqvae_blur_sigma is not None:
                blur_sigma = eqvae_blur_sigma

            try:
                low_freq_noise = gaussian_blur(noise, kernel_size=3, sigma=blur_sigma)
                high_freq_noise = noise - low_freq_noise
                noise = noise + high_freq_noise * high_freq_boost
            except Exception:
                # Best effort: keep the unboosted noise if blurring fails.
                pass

        x_next = x_det + noise
    else:
        x_next = x + d_interp * dt

    return x_next, sigma_up
|
|
|
|
def create_detail_enhanced_model(model, x, sigmas, settings):
    """Detail enhancement wrapper.

    Wraps ``model`` so every denoised prediction gets a high-frequency boost.
    The prediction is split with a 3x3 Gaussian blur of std ``radius``.
    ``high_freq * strength`` is added back, with ``strength`` ramping from
    ``0.5 * base_strength`` at the start of the schedule towards
    ``1.5 * base_strength`` at the end.

    Progress comes from where ``sigma`` sits in ``sigmas``, not from a call
    counter. Samplers that evaluate the model several times per step
    (predictor-corrector, NaN-recovery retries) therefore still get the
    intended ramp instead of one that finishes halfway through sampling.

    Args:
        model: denoiser callable ``model(x, sigma, **kwargs) -> denoised``.
        x: unused; kept for interface compatibility.
        sigmas: the sampling schedule (``len(sigmas) - 1`` steps).
        settings: dict; reads ``detail_enhancement_strength`` (default 0.05)
            and ``detail_separation_radius`` (default 0.5).

    Returns:
        ``model`` itself when torchvision is unavailable, otherwise a
        ``DetailEnhancer`` callable. If enhancement raises, the wrapper falls
        back to the plain prediction.
    """
    if not TORCHVISION_AVAILABLE:
        return model

    base_strength = settings.get('detail_enhancement_strength', 0.05)
    radius = settings.get('detail_separation_radius', 0.5)
    total_steps = len(sigmas) - 1
    # CPU copy of the schedule, used to locate the current step from sigma.
    schedule = torch.as_tensor(sigmas).detach().float().cpu()

    def _progress_for(sigma):
        """Return the index of the schedule entry nearest ``sigma`` divided by the step count, capped at 1."""
        sigma_value = float(sigma.reshape(-1)[0]) if torch.is_tensor(sigma) else float(sigma)
        step_idx = int(torch.argmin((schedule - sigma_value).abs()))
        return min(step_idx / max(total_steps, 1), 1.0)

    class DetailEnhancer:
        def __init__(self):
            # Number of model evaluations so far (informational only).
            self.current_step = 0

        def __call__(self, x_current, sigma, **kwargs):
            denoised = model(x_current, sigma, **kwargs)
            self.current_step += 1

            try:
                low_freq = gaussian_blur(denoised, kernel_size=3, sigma=radius)
                high_freq = denoised - low_freq
                strength = base_strength * (0.5 + _progress_for(sigma))
                return denoised + high_freq * strength
            except Exception:
                # Best effort: enhancement is optional, so keep the raw prediction.
                return denoised

    return DetailEnhancer()
|
|
|
|
| |
| |
| |
| |
| |
|
|
@torch.no_grad()
def sample_adept_solver(model, x, sigmas, extra_args=None, callback=None, disable=None,
                        order=2, use_corrector=True, use_detail_enhancement=False, settings=None):
    """
    Adept Solver: A unified training-free diffusion solver synthesizing improvements from:
    - DPM-Solver++ (data prediction, dynamic thresholding)
    - UniPC (unified predictor-corrector framework)
    - DEIS (exponential integrator)
    - DC-Solver (dynamic compensation)

    Each step evaluates the model at ``sigma``, converts the prediction to a
    derivative with ``to_d``, and takes a predictor step:
    - plain Euler on the first step or when ``order == 1``;
    - otherwise a compensated extrapolation over the derivative history.

    With ``use_corrector``, the predictor is refined by a Heun-style average
    of ``d`` and the derivative re-evaluated at ``sigma_next``. The corrector
    is skipped on the final step.

    Args:
        model: denoiser ``model(x, sigma, **extra_args) -> denoised``.
        x: starting latent.
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: forwarded to ``model``. A ``cond_scale`` above 7.0 enables
            dynamic thresholding of every prediction.
        callback: called after each step with ``x``, ``i``, ``sigma`` and ``denoised``.
        disable: accepted for k-diffusion API compatibility; unused.
        order: derivative-history length, clamped to [1, 3].
        use_corrector: enable the Heun-style correction. It is switched off
            after a NaN/Inf recovery.
        use_detail_enhancement: wrap ``model`` with ``create_detail_enhanced_model``.
        settings: options dict for the detail-enhancement wrapper.

    Returns:
        The final latent.

    Raises:
        RuntimeError: NaN/Inf on the first step, a NaN prediction during
            recovery, or a recovery that still yields a non-finite latent.
    """
    extra_args = {} if extra_args is None else extra_args
    settings = settings or {}
    s_in = x.new_ones([x.shape[0]])
    # Dynamic thresholding is only applied at high CFG scales.
    use_thresholding = extra_args.get('cond_scale', 1.0) > 7.0

    order = max(1, min(order, 3))

    print(f"🚀 Adept Solver active (Order: {order}, Corrector: {'On' if use_corrector else 'Off'})")

    active_model = model
    if use_detail_enhancement and TORCHVISION_AVAILABLE:
        active_model = create_detail_enhanced_model(model, x, sigmas, settings)

    # (sigma, derivative) history, newest last, at most `order` entries.
    model_outputs = []

    for i in range(len(sigmas) - 1):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        # Latent at the start of this step; the NaN/Inf recovery restarts from it.
        x_start = x

        denoised = active_model(x, sigma * s_in, **extra_args)
        if use_thresholding:
            denoised = apply_dynamic_thresholding(denoised, percentile=0.995)

        d = to_d(x, sigma, denoised)

        # Clamp extreme or non-finite derivatives; the allowed magnitude grows with sigma.
        derivative_max = torch.abs(d).max()
        sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
        if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold:
            print(f"⚠️ Extreme derivative detected at step {i}/{len(sigmas)-1}. Clamping for stability.")
            d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold)
            if torch.isnan(d).any() or torch.isinf(d).any():
                d = torch.zeros_like(d)

        model_outputs.append((sigma, d))
        if len(model_outputs) > order:
            model_outputs.pop(0)

        dt = sigma_next - sigma

        if len(model_outputs) == 1 or order == 1:
            x_pred = x + d * dt
        elif len(model_outputs) == 2 and order >= 2:
            sigma_prev, d_prev = model_outputs[-2]
            d_cur = model_outputs[-1][1]

            h = sigma - sigma_prev
            compensation_ratio = compute_compensation_ratio(h.item() if torch.is_tensor(h) else float(h), i, len(sigmas))

            d_interp = d_cur + compensation_ratio * (d_cur - d_prev)
            x_pred = x + d_interp * dt
        else:
            sigma_0, d_0 = model_outputs[-3]
            sigma_1, d_1 = model_outputs[-2]
            sigma_2, d_2 = model_outputs[-1]

            h_0 = sigma_2 - sigma_1
            h_1 = sigma_1 - sigma_0

            h_0_val = h_0.item() if torch.is_tensor(h_0) else float(h_0)
            h_1_val = h_1.item() if torch.is_tensor(h_1) else float(h_1)

            if abs(h_1_val) < 1e-6:
                compensation_ratio = compute_compensation_ratio(h_0_val, i, len(sigmas))
                d_interp = d_2 + compensation_ratio * (d_2 - d_1)
            else:
                # NOTE(review): c0 + c1 == 1 by construction, so the normalisation
                # below is a no-op and c2 comes out (numerically) 0. d_0 therefore
                # does not contribute; this is a 2-point extrapolation with ratio h_0 / h_1.
                r0 = h_0_val / h_1_val
                c0 = 1.0 + r0 / 2.0
                c1 = -r0 / 2.0
                c2 = 0.0

                c_sum = c0 + c1 + c2
                c0 /= c_sum
                c1 /= c_sum
                c2 = 1.0 - c0 - c1

                d_interp = c0 * d_2 + c1 * d_1 + c2 * d_0

            x_pred = x + d_interp * dt

        if use_corrector and i < len(sigmas) - 2:
            denoised_pred = active_model(x_pred, sigma_next * s_in, **extra_args)
            if use_thresholding:
                denoised_pred = apply_dynamic_thresholding(denoised_pred, percentile=0.995)

            d_pred = to_d(x_pred, sigma_next, denoised_pred)

            if torch.isnan(d_pred).any() or torch.isinf(d_pred).any() or torch.abs(d_pred).max() > 1000.0:
                d_pred = torch.clamp(d_pred, -100.0, 100.0)
                if torch.isnan(d_pred).any() or torch.isinf(d_pred).any():
                    d_pred = torch.zeros_like(d_pred)

            # Heun-style trapezoid from the step start.
            x = x + (d + d_pred) * dt * 0.5
        else:
            x = x_pred

        if not torch.isfinite(x).all():
            print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!")
            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")

            print(" Attempting recovery with conservative Euler step...")
            # Restart from the finite step-start latent; re-evaluating the model on
            # the NaN/Inf latent just produced could only return NaN again.
            denoised_safe = active_model(x_start, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - check CFG scale and model")

            d_safe = to_d(x_start, sigma, denoised_safe)
            # Half-length Euler step, as before.
            dt_safe = (sigma_next - sigma) * 0.5
            x = x_start + d_safe * dt_safe
            if not torch.isfinite(x).all():
                raise RuntimeError("Recovery failed - latent still contains NaN/Inf")
            use_corrector = False
            print(" Recovery successful. Corrector disabled for stability.")

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x
|
|
|
|
@torch.no_grad()
def sample_adept_ancestral_solver(model, x, sigmas, extra_args=None, callback=None, disable=None,
                                  eta=1.0, s_noise=1.0, adaptive_eta=False, phase_noise=False,
                                  phase_strength=0.5, enhanced_derivative=False,
                                  use_detail_enhancement=False, settings=None):
    """
    Enhanced Adept Ancestral Solver: Advanced ancestral sampling with phase-aware adaptations.

    Key innovations:
    1. Adaptive ancestral step sizing that changes throughout sampling phases
    2. Phase-aware noise injection (more noise early, less noise late)
    3. Enhanced derivative computation with ancestral-specific corrections
    4. Dynamic eta scheduling for better control

    Each step is an Euler step down to ``sigma_down``, followed by fresh noise
    of scale ``sigma_up``:
    - ``sigma_up`` is ``eta`` times the standard ancestral amount, capped at ``sigma_next``;
    - ``sigma_down = sqrt(sigma_next**2 - sigma_up**2)``.

    Args:
        eta: ancestral noise coefficient. With ``adaptive_eta`` it is scaled by
            1.08 / 0.95 / 1.02 when progress is < 0.3 / < 0.7 / otherwise.
        s_noise: multiplier on the injected noise.
        phase_noise: scale ``s_noise`` by a progress-dependent factor. The factor
            rises to +5% early and falls to -5% late, blended in by ``phase_strength``.
        enhanced_derivative: use ``to_d_enhanced_ancestral`` instead of ``to_d``.
        use_detail_enhancement: wrap ``model`` with ``create_detail_enhanced_model``
            (configured by ``settings``).
        disable: accepted for k-diffusion API compatibility; unused.

    Returns:
        The final latent.

    Raises:
        RuntimeError: NaN/Inf on the first step, a NaN prediction during
            recovery, or a recovery that still yields a non-finite latent.
    """
    extra_args = {} if extra_args is None else extra_args
    settings = settings or {}
    s_in = x.new_ones([x.shape[0]])
    # Dynamic thresholding is only applied at high CFG scales.
    use_thresholding = extra_args.get('cond_scale', 1.0) > 7.0

    print(f"🚀 Enhanced Adept Ancestral Solver active (η: {eta:.2f}, s_noise: {s_noise:.2f})")
    print(f" Adaptive Eta: {adaptive_eta}, Phase Noise: {phase_noise}, Enhanced Derivative: {enhanced_derivative}")

    active_model = model
    if use_detail_enhancement and TORCHVISION_AVAILABLE:
        active_model = create_detail_enhanced_model(model, x, sigmas, settings)

    noise_sampler = get_noise_sampler(x)
    for i in range(len(sigmas) - 1):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        # Latent at the start of this step; the NaN/Inf recovery restarts from it.
        x_start = x

        progress = i / max(len(sigmas) - 1, 1)

        if adaptive_eta:
            if progress < 0.3:
                current_eta = eta * 1.08
            elif progress < 0.7:
                current_eta = eta * 0.95
            else:
                current_eta = eta * 1.02
        else:
            current_eta = eta

        denoised = active_model(x, sigma * s_in, **extra_args)
        if use_thresholding:
            denoised = apply_dynamic_thresholding(denoised, percentile=0.995)

        if enhanced_derivative:
            d = to_d_enhanced_ancestral(x, sigma, denoised, current_eta, progress, None)
        else:
            d = to_d(x, sigma, denoised)

        # Clamp extreme or non-finite derivatives; the allowed magnitude grows with sigma.
        derivative_max = torch.abs(d).max()
        sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
        if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold:
            d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold)
            if torch.isnan(d).any() or torch.isinf(d).any():
                d = torch.zeros_like(d)

        if sigma_next > 0:
            sigma_up = min(sigma_next, current_eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
            sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
        else:
            sigma_up = 0.0
            sigma_down = 0.0

        dt = sigma_down - sigma
        x_pred = x + d * dt

        if sigma_next > 0:
            if phase_noise:
                if progress < 0.25:
                    target_multiplier = 1.0 + (0.05 * min(progress / 0.25, 1.0))
                elif progress < 0.6:
                    target_multiplier = 1.0 - (0.02 * min((progress - 0.25) / 0.35, 1.0))
                else:
                    target_multiplier = 1.0 - (0.05 * min((progress - 0.6) / 0.4, 1.0))

                # phase_strength blends between no adjustment (0) and the full target (1).
                noise_multiplier = 1.0 + (target_multiplier - 1.0) * phase_strength
                adaptive_s_noise = s_noise * noise_multiplier
            else:
                adaptive_s_noise = s_noise

            noise = noise_sampler(sigma, sigma_next) * adaptive_s_noise * sigma_up
            x = x_pred + noise
        else:
            x = x_pred

        if not torch.isfinite(x).all():
            print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!")
            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")

            print(" Attempting recovery...")
            # Restart from the finite step-start latent; the model cannot recover
            # from the NaN/Inf latent just produced.
            denoised_safe = active_model(x_start, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - check CFG scale and model")

            d_safe = to_d(x_start, sigma, denoised_safe)
            dt_safe = (sigma_next - sigma) * 0.5
            x = x_start + d_safe * dt_safe
            if not torch.isfinite(x).all():
                raise RuntimeError("Recovery failed - latent still contains NaN/Inf")
            print(" Recovery successful.")

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x
|
|
|
|
@torch.no_grad()
def sample_mirror_correction_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
                                   eta=1.0, s_noise=1.0, correction_phase=0.5, smooth_phase=False):
    """
    Mirror Correction Euler: Euler Ancestral with a semantic reflection probe.

    In the first `correction_phase` fraction of steps, uses a 3-call Heun correction:
    x_probe = 2*D(x) - x (reflection of x through its own denoised prediction)
    The probe lies on the denoising trajectory, giving a curvature estimate for the
    Heun correction. Remaining steps: standard 1-call Euler Ancestral.

    Both extra model calls (the probe and the averaged point ``x3``) are
    evaluated at the current ``sigma``. The callback runs right after the
    first model call of each step, i.e. before ``x`` is updated.

    Args:
        eta: Ancestral noise coefficient. 0=deterministic, 1=full ancestral. Default: 1.0
        s_noise: Noise scale multiplier. Default: 1.0
        correction_phase: Fraction of steps that receive the 3-call correction. Default: 0.5
        smooth_phase: Use continuous log-sigma weighting instead of a binary cutoff. Default: False.
            In smooth mode the correction weight is sqrt of the position of
            log(sigma) between log(sigmas[phase_idx]) and log(sigmas[0]),
            further scaled by how closely ``d`` and the probe derivative agree.
        disable: accepted for k-diffusion API compatibility; unused.

    Returns:
        The final latent.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    print(f"🔮 Mirror Correction Euler active (η: {eta:.2f}, s_noise: {s_noise:.2f})")
    print(f" Correction Phase: {correction_phase:.2f}, Smooth Phase: {smooth_phase}")

    noise_sampler = get_noise_sampler(x)
    n_steps = len(sigmas) - 1

    # Smooth-phase setup: log-sigma bounds of the corrected region
    # (from sigmas[0] down to the sigma at the phase cutoff index).
    log_sigma_phase = None
    log_sigma_max = None
    smooth_denom = 1e-6
    if smooth_phase and n_steps > 0:
        sigma_max_val = sigmas[0].clamp(min=1e-6)
        phase_idx = min(int(correction_phase * n_steps), n_steps - 1)
        sigma_phase_val = sigmas[phase_idx].clamp(min=1e-6)
        log_sigma_max = torch.log(sigma_max_val).item()
        log_sigma_phase = torch.log(sigma_phase_val).item()
        smooth_denom = max(log_sigma_max - log_sigma_phase, 1e-6)

    for i in range(n_steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        progress = i / max(n_steps - 1, 1)

        # First (always-performed) model evaluation of the step.
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

        d = to_d(x, sigma, denoised)

        # Ancestral split: deterministic move to sigma_down, then noise of scale sigma_up.
        if sigma_next > 0:
            sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
            sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
        else:
            sigma_up = 0.0
            sigma_down = 0.0
        dt = sigma_down - sigma

        if smooth_phase and log_sigma_phase is not None:
            # t = 1 at sigmas[0], 0 at (and below) the phase sigma.
            log_sig = torch.log(sigma.clamp(min=1e-6)).item()
            t = max(0.0, min(1.0, (log_sig - log_sigma_phase) / smooth_denom))
            correction_weight = t ** 0.5

            if correction_weight > 1e-3 and sigma_next > 0:
                # Reflect x through its denoised prediction and take the derivative there.
                x_probe = 2 * denoised - x
                d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args))

                # Down-weight the correction when the two derivatives disagree
                # (relative L2 difference over the whole batch).
                d_diff_norm = (d - d_probe).norm()
                d_scale = (d.norm() + d_probe.norm()) / 2 + 1e-6
                gradient_agreement = max(0.0, 1.0 - (d_diff_norm / d_scale).item())
                effective_weight = correction_weight * gradient_agreement

                if effective_weight > 1e-3:
                    x3 = x + ((d + d_probe) / 2) * dt
                    d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args))
                    d_heun = (d + d3) / 2
                    # Only blend in a finite Heun derivative.
                    if not (torch.isnan(d_heun).any() or torch.isinf(d_heun).any()):
                        d = d + effective_weight * (d_heun - d)
        else:
            # Binary mode: full 3-call correction for progress < correction_phase.
            if progress < correction_phase and sigma_next > 0:
                x_probe = 2 * denoised - x
                d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args))
                x3 = x + ((d + d_probe) / 2) * dt
                d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args))
                d = (d + d3) / 2
                if torch.isnan(d).any() or torch.isinf(d).any():
                    d = torch.zeros_like(d)

        x = x + d * dt
        if sigma_next > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up

    return x
|
|
|
|
@torch.no_grad()
def sample_akashic_solver(model, x, sigmas, extra_args=None, callback=None, disable=None,
                          tau=0.5, eta=1.0, s_noise=1.0, adaptive_eta=True, phase_strength=0.5,
                          order=2, smea_strength=0.0, ndb_strength=0.0,
                          use_detail_enhancement=False, settings=None, eqvae_mode='Off'):
    """
    AkashicSolver v2 [EXPERIMENTAL]: Advanced sampler optimized for EQ-VAE models.

    Combines:
    1. SA-SOLVER BASE: Multi-step Adams-Bashforth integration with tau function
    2. PHASE-AWARE SAMPLING: Three-phase approach with adaptive parameters
    3. SMEA COHERENCY: Sine-based interpolation for high-resolution coherency

    Each step delegates the update to ``sa_solver_step`` with a derivative
    history of up to ``order`` entries.

    Args:
        eqvae_mode: EQ-VAE optimization mode ('Off' or 'Balanced'); a bool is also accepted.
        tau, eta, phase_strength: when ``adaptive_eta`` is set, tau and eta are
            adjusted per phase of ``progress``.
        s_noise: base noise scale, combined with eta, the SMEA factor and a
            phase multiplier before being passed to ``sa_solver_step``.
        smea_strength / ndb_strength: strengths for ``compute_smea_factor`` and
            the native detail boost inside ``sa_solver_step``.
        disable: accepted for k-diffusion API compatibility; unused.

    Returns:
        The final latent.

    Raises:
        RuntimeError: NaN/Inf on the first step, a NaN prediction during
            recovery, or a recovery that still yields a non-finite latent.
    """
    extra_args = {} if extra_args is None else extra_args
    settings = settings or {}
    s_in = x.new_ones([x.shape[0]])
    # Dynamic thresholding is only applied at high CFG scales.
    use_thresholding = extra_args.get('cond_scale', 1.0) > 7.0

    if isinstance(eqvae_mode, bool):
        eqvae_enabled = eqvae_mode
    else:
        eqvae_enabled = eqvae_mode == 'Balanced'

    if eqvae_enabled:
        print(f"🌀 AkashicSolver v2 [EQ-VAE BALANCED] active")
        print(f" Optimized for EQ-VAE's cleaner latent space")
    else:
        print(f"🌀 AkashicSolver v2 [EXPERIMENTAL] active")
    print(f" τ (tau): {tau:.2f}, η (eta): {eta:.2f}, s_noise: {s_noise:.2f}")
    print(f" Order: {order}, Adaptive Eta: {adaptive_eta}, Phase Strength: {phase_strength:.2f}")
    if smea_strength > 0:
        print(f" SMEA: {smea_strength:.2f} (high-res coherency)")
    if ndb_strength > 0:
        print(f" Native Detail Boost: {ndb_strength:.2f} (detail enhancement)")
    if not eqvae_enabled:
        print(f" ⚠️ Use external rescaleCFG (e.g., 0.7) for EQ-VAE models")

    active_model = model
    if use_detail_enhancement and TORCHVISION_AVAILABLE:
        active_model = create_detail_enhanced_model(model, x, sigmas, settings)

    noise_sampler = get_noise_sampler(x)
    total_steps = len(sigmas) - 1
    d_history = []

    for i in range(total_steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        # Latent at the start of this step; the NaN/Inf recovery restarts from it.
        x_start = x

        progress = i / max(total_steps - 1, 1)

        if adaptive_eta:
            if eqvae_enabled:
                current_tau = compute_eqvae_tau(progress, tau, phase_strength)
            else:
                current_tau = compute_tau_eqvae(progress, tau, phase_strength)
        else:
            current_tau = tau

        if adaptive_eta:
            if eqvae_enabled:
                if progress < 0.25:
                    current_eta = eta * (1.0 + 0.03 * phase_strength)
                elif progress < 0.55:
                    current_eta = eta * (1.0 - 0.03 * phase_strength)
                else:
                    current_eta = eta * (1.0 + 0.02 * phase_strength)
            else:
                if progress < 0.30:
                    current_eta = eta * (1.0 + 0.08 * phase_strength)
                elif progress < 0.60:
                    current_eta = eta * (1.0 - 0.05 * phase_strength)
                else:
                    current_eta = eta * (1.0 + 0.02 * phase_strength)
        else:
            current_eta = eta

        smea_factor = compute_smea_factor(progress, smea_strength)

        denoised = active_model(x, sigma * s_in, **extra_args)
        if use_thresholding:
            denoised = apply_dynamic_thresholding(denoised, percentile=0.995)

        d = to_d(x, sigma, denoised)

        # Clamp extreme or non-finite derivatives; the allowed magnitude grows with sigma.
        derivative_max = torch.abs(d).max()
        sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
        if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold:
            d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold)
            if torch.isnan(d).any() or torch.isinf(d).any():
                d = torch.zeros_like(d)

        d_history.append((sigma, d))
        if len(d_history) > order:
            d_history.pop(0)

        if eqvae_enabled:
            effective_s_noise = compute_eqvae_noise_scale(s_noise * current_eta, progress) * smea_factor
        else:
            effective_s_noise = s_noise * current_eta * smea_factor

        # Phase-dependent noise multiplier (applied regardless of adaptive_eta).
        if progress < 0.30:
            noise_multiplier = 1.0 + 0.03 * phase_strength
        elif progress < 0.60:
            noise_multiplier = 1.0 - 0.01 * phase_strength
        else:
            noise_multiplier = 1.0 - 0.02 * phase_strength
        effective_s_noise *= noise_multiplier

        if eqvae_enabled and ndb_strength > 0:
            eqvae_blur_sigma, _ = compute_eqvae_ndb(progress, ndb_strength)
        else:
            eqvae_blur_sigma = None

        x, _ = sa_solver_step(
            x=x,
            d_history=d_history,
            sigma=sigma,
            sigma_next=sigma_next,
            tau=current_tau,
            s_noise=effective_s_noise,
            noise_sampler=noise_sampler,
            order=order,
            ndb_strength=ndb_strength,
            progress=progress,
            eqvae_mode=eqvae_enabled,
            eqvae_blur_sigma=eqvae_blur_sigma
        )

        if not torch.isfinite(x).all():
            print(f"❌ AkashicSolver v2: NaN/Inf detected at step {i}/{total_steps}!")

            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")

            print(" Attempting recovery...")
            # `x` now holds the NaN/Inf solver output; restart from the finite
            # step-start latent instead of feeding NaN back into the model.
            denoised_safe = active_model(x_start, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - reduce CFG scale or check model")

            d_safe = to_d(x_start, sigma, denoised_safe)
            dt_safe = (sigma_next - sigma) * 0.5
            x = x_start + d_safe * dt_safe
            if not torch.isfinite(x).all():
                raise RuntimeError("Recovery failed - latent still contains NaN/Inf")

            # The history may contain derivatives that led to the blow-up.
            d_history.clear()
            print(" Recovery successful. Multi-step history cleared.")

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
def should_patch_weights(unet_model, scale, shift):
    """Return True if weight patching is actually needed for these parameters.

    Patching is pointless without a model, or when ``scale`` is 1 and
    ``shift`` is 0 (each within a 1e-6 tolerance).
    """
    if unet_model is None:
        return False
    scale_changes_weights = abs(scale - 1.0) > 1e-6
    shift_changes_weights = abs(shift) > 1e-6
    return scale_changes_weights or shift_changes_weights
|
|
|
|
class AdeptWeightPatcher:
    """Temporary weight scaling for UNet.

    Used as a context manager: ``__enter__`` replaces the weight of every
    ``Conv2d``/``Linear`` submodule with ``weight * scale + shift`` (keeping a
    clone of the original), and ``__exit__`` copies the originals back.
    Nothing is patched when ``unet_model`` is None or the parameters are the
    identity (``scale`` == 1 and ``shift`` == 0 within 1e-6).
    """

    def __init__(self, unet_model, scale=1.0, shift=0.0):
        self.unet_model = unet_model
        self.scale = scale
        self.shift = shift
        self.backups = {}        # layer name -> clone of its original weight
        self.target_layers = []  # (name, module) pairs selected for patching

    def __enter__(self):
        if self.unet_model is None or (abs(self.scale - 1.0) < 1e-6 and abs(self.shift) < 1e-6):
            return self

        self.target_layers.clear()
        self.backups.clear()

        try:
            for name, module in self.unet_model.named_modules():
                if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                    continue
                if getattr(module, 'weight', None) is None:
                    continue
                self.target_layers.append((name, module))
                self.backups[name] = module.weight.data.clone()
                module.weight.data = module.weight.data * self.scale + self.shift
        except Exception as e:
            print(f"❌ Weight patcher failed: {e}")
            # Undo whatever was patched before the failure.
            self.__exit__(None, None, None)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        failed = []
        for name, module in self.target_layers:
            backup = self.backups.get(name)
            if backup is None:
                # Patching stopped before this layer was backed up (and modified).
                continue
            try:
                module.weight.data.copy_(backup)
            except Exception:
                try:
                    # In-place copy failed; rebinding needs no new allocation.
                    module.weight.data = backup
                except Exception as e:
                    failed.append(name)
                    print(f"❌ CRITICAL: Failed to restore weights for layer '{name}': {e}")

        # Keep backups only for layers that could not be restored, so a full
        # second copy of the UNet weights is not kept alive after exit.
        self.backups = {name: self.backups[name] for name in failed}
        self.target_layers.clear()

        # Never suppress an exception raised inside the `with` body.
        return False
|
|
|
|
| |
| |
| |
|
|
class VAEReflectionPatcher:
    """Context manager for VAE reflection padding.

    ``__enter__`` switches every ``Conv2d`` in ``vae_model`` to
    ``padding_mode='reflect'``. It records the original modes in the
    module-level ``_vae_original_padding_modes`` dict and sets the
    module-level ``_vae_reflection_active`` flag. ``__exit__`` restores the
    recorded modes and clears that state.

    If a patch is already active when ``__enter__`` runs, this instance
    defers to it and its ``__exit__`` does nothing. This prevents a nested or
    second patcher from undoing the outer one. Calling ``__exit__`` on an
    instance that was never entered still restores from the shared record.
    """

    def __init__(self, vae_model):
        self.vae_model = vae_model
        self.backups = {}  # not used by this class; kept for compatibility
        # True when __enter__ found another patch active and left it in place.
        self._deferred = False

    def __enter__(self):
        global _vae_reflection_active, _vae_original_padding_modes

        self._deferred = False
        if _vae_reflection_active:
            # Someone else owns the current patch; our exit must not restore it.
            self._deferred = True
            return self
        if self.vae_model is None:
            return self

        _vae_original_padding_modes.clear()
        patched_count = 0

        try:
            for name, module in self.vae_model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    _vae_original_padding_modes[name] = module.padding_mode
                    module.padding_mode = 'reflect'
                    patched_count += 1

            _vae_reflection_active = True
            print(f"🪞 VAE Reflection: Patched {patched_count} Conv2d layers")
        except Exception as e:
            print(f"❌ VAE Reflection failed: {e}")
            # Roll back the layers patched before the failure.
            self.__exit__(None, None, None)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _vae_reflection_active, _vae_original_padding_modes

        if getattr(self, '_deferred', False):
            self._deferred = False
            return False

        if self.vae_model is None:
            _vae_reflection_active = False
            _vae_original_padding_modes.clear()
            return False

        restored_count = 0
        try:
            for name, module in self.vae_model.named_modules():
                if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes:
                    module.padding_mode = _vae_original_padding_modes[name]
                    restored_count += 1

            _vae_reflection_active = False
            _vae_original_padding_modes.clear()
            print(f"🔄 VAE Reflection: Restored {restored_count} layers")
        except Exception as e:
            print(f"⚠️ VAE Reflection restore warning: {e}")

        return False
|
|
|
|
def force_restore_vae_reflection():
    """
    Emergency / unload-path restore for VAE padding modes.
    Safe to call at any time — does nothing if VAE reflection was not active.

    Looks up the VAE as ``shared.sd_model.first_stage_model``. It restores the
    recorded padding mode on every ``Conv2d`` it finds in the saved record,
    and always resets the module-level reflection state afterwards, even if
    the lookup or restore raises.
    """
    global _vae_reflection_active, _vae_original_padding_modes

    if not (_vae_reflection_active or _vae_original_padding_modes):
        return

    try:
        sd_model = getattr(shared, "sd_model", None)
        vae = getattr(sd_model, "first_stage_model", None) if sd_model else None
        if vae is None:
            return  # the finally block still resets the state

        restored = 0
        for layer_name, layer in vae.named_modules():
            if not isinstance(layer, torch.nn.Conv2d):
                continue
            if layer_name not in _vae_original_padding_modes:
                continue
            layer.padding_mode = _vae_original_padding_modes[layer_name]
            restored += 1

        if restored:
            print(f"🔄 VAE Reflection: force-restored {restored} layers")
    except Exception as e:
        print(f"⚠️ VAE Reflection force-restore warning: {e}")
    finally:
        _vae_reflection_active = False
        _vae_original_padding_modes.clear()
|
|
|
|
| |
| |
|
|
def create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AOS-V (Anime-Optimized Schedule for v-prediction models).

    Builds a ``num_steps`` ramp in [0, 1] from three phases:

    * first 20% of steps: sqrt-shaped ramp covering [0, 0.6)
    * up to 60% of steps: linear ramp covering [0.6, 0.9)
    * remaining steps:    cubic ramp covering [0.9, 1.0]

    The ramp is mapped through rho=7 interpolation between ``sigma_max`` and
    ``sigma_min``, and a trailing 0 is appended.

    Phases 1 and 2 exclude their upper endpoint so that no two consecutive
    sigmas are equal. A zero-length step wastes a model call, and multistep
    samplers that divide by the previous step size (e.g. DPM++ 2M) turn it
    into NaN.

    Returns:
        float32 tensor of length ``num_steps + 1``.
    """
    rho = 7.0

    p1_steps = int(num_steps * 0.2)
    p2_steps = int(num_steps * 0.6)
    n_phase2 = max(p2_steps - p1_steps, 0)
    n_phase3 = max(num_steps - p2_steps, 0)

    f32 = torch.float32
    # `linspace(a, b, n + 1)[:-1]` gives n points in [a, b) (empty for n == 0).
    phase1 = torch.linspace(0, 1, p1_steps + 1, device=device, dtype=f32)[:-1].pow(0.5).mul(0.6)
    phase2 = torch.linspace(0.6, 0.9, n_phase2 + 1, device=device, dtype=f32)[:-1]
    phase3 = torch.linspace(0, 1, n_phase3, device=device, dtype=f32).pow(3).mul(0.1).add(0.9)
    ramp = torch.cat([phase1, phase2, phase3])

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AOS-ε (Anime-Optimized Schedule for epsilon-prediction models).

    Three-phase ramp in [0, 1]:

    * first 35% of steps: ``x ** 1.5`` ramp covering [0, 0.4)
    * up to 70% of steps: linear ramp covering [0.4, 0.75)
    * remaining steps:    ``x ** 0.7`` ramp covering [0.75, 1.0]

    The ramp is mapped through rho=7 interpolation between ``sigma_max`` and
    ``sigma_min``, and a trailing 0 is appended.

    Phases 1 and 2 exclude their upper endpoint so consecutive sigmas never
    repeat. Repeated sigmas waste model calls and produce NaN in multistep
    samplers such as DPM++ 2M.
    """
    rho = 7.0

    p1_frac, p2_frac = 0.35, 0.7
    ramp_p1_val, ramp_p2_val = 0.4, 0.75

    p1_steps = int(num_steps * p1_frac)
    p2_steps = int(num_steps * p2_frac)
    n_phase2 = max(p2_steps - p1_steps, 0)
    n_phase3 = max(num_steps - p2_steps, 0)

    # `linspace(a, b, n + 1)[:-1]` gives n points in [a, b) (empty for n == 0).
    phase1_ramp = torch.linspace(0, 1, p1_steps + 1, device=device)[:-1] ** 1.5 * ramp_p1_val
    phase2_ramp = torch.linspace(ramp_p1_val, ramp_p2_val, n_phase2 + 1, device=device)[:-1]
    phase3_base = torch.linspace(0, 1, n_phase3, device=device) ** 0.7
    phase3_ramp = phase3_base * (1 - ramp_p2_val) + ramp_p2_val

    ramp = torch.cat([phase1_ramp, phase2_ramp, phase3_ramp])

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AkashicAOS v2: Detail-Progressive Schedule for EQ-VAE SDXL models.

    A uniform grid is warped by ``grid ** 0.85`` plus a sine-shaped mid-range
    bump and rescaled to [0, 1]. It is then mapped through rho=7 interpolation
    between ``sigma_max`` and ``sigma_min``.

    A sequential pass then enforces the schedule shape:
    - an entry that is not below its predecessor becomes 0.995x the predecessor;
    - the ratio between neighbours is capped at 1.5.

    A trailing 0 is appended.
    """
    rho = 7.0
    detail_power = 0.85
    mid_boost_strength = 0.08
    max_ratio = 1.5

    grid = torch.linspace(0, 1, num_steps, device=device)
    bump = mid_boost_strength * torch.sin(math.pi * grid) * (1 - grid * 0.5)
    warped = grid ** detail_power + bump

    lowest, highest = warped.min(), warped.max()
    if highest - lowest > 1e-8:
        warped = (warped - lowest) / (highest - lowest)

    root_min = sigma_min ** (1 / rho)
    root_max = sigma_max ** (1 / rho)
    sigmas = (root_max + warped * (root_min - root_max)) ** rho

    # Sequential fix-up: each entry depends on the already-adjusted previous one.
    for idx in range(1, sigmas.shape[0]):
        prev = sigmas[idx - 1]
        if sigmas[idx] >= prev:
            sigmas[idx] = prev * 0.995
        if prev / sigmas[idx] > max_ratio:
            sigmas[idx] = prev / max_ratio

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device='cpu'):
    """Entropic power schedule.

    Averages a linear ramp with an ease-out ramp ``1 - (1 - u) ** power``
    (``1 - u`` taken from a descending linspace). The result is mapped through
    rho=7 interpolation between ``sigma_max`` and ``sigma_min``, and a
    trailing 0 is appended.
    """
    rho = 7.0

    ascending = torch.linspace(0, 1, num_steps, device=device)
    descending = torch.linspace(1, 0, num_steps, device=device)
    eased = 1 - descending ** power
    blended = (ascending + eased) / 2.0

    root_min = sigma_min ** (1 / rho)
    root_max = sigma_max ** (1 / rho)
    schedule = (root_max + blended * (root_min - root_max)) ** rho

    terminal = torch.zeros(1, device=device)
    return torch.cat([schedule, terminal])
|
|
|
|
def create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Schedule optimized around log SNR = 0 region.

    Interpolates ``2 * log(sigma)`` (log-SNR up to sign) between ``sigma_max``
    and ``sigma_min``. The positions blend a sigmoid centred at the midpoint
    (weight 0.7) with a linear ramp (weight 0.3), so steps concentrate in the
    middle of the log-sigma range.

    The blend is rescaled to [0, 1], as in ``create_adaptive_optimized_sigmas``.
    Without the rescale the sigmoid's endpoints (about 0.13 and 0.87) would
    make the schedule start below ``sigma_max`` and stop well above
    ``sigma_min``.

    ``sigma_max``/``sigma_min`` may be Python numbers or tensors; they are
    moved to ``device``. A trailing 0 is appended.
    """
    sigma_max = torch.as_tensor(sigma_max, device=device)
    sigma_min = torch.as_tensor(sigma_min, device=device)

    log_snr_max = 2 * torch.log(sigma_max)
    log_snr_min = 2 * torch.log(sigma_min)

    t = torch.linspace(0, 1, num_steps, device=device)

    concentration_power = 3.0
    sigmoid_t = torch.sigmoid(concentration_power * (t - 0.5))

    blend_factor = 0.7
    combined_t = blend_factor * sigmoid_t + (1 - blend_factor) * t

    if combined_t.numel() > 1:
        lowest = combined_t.min()
        spread = combined_t.max() - lowest
        if spread > 1e-6:
            combined_t = (combined_t - lowest) / spread

    log_snr = log_snr_max + combined_t * (log_snr_min - log_snr_max)
    sigmas = torch.exp(log_snr / 2)

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Constant rate of distributional change.

    Warps a uniform grid ``t`` to ``t + 0.3 * sin(pi * t) * (1 - t)``, which
    keeps both endpoints fixed and advances faster early on. The warped grid
    is mapped through rho=7 interpolation between ``sigma_max`` and
    ``sigma_min``, and a trailing 0 is appended.
    """
    rho = 7.0

    grid = torch.linspace(0, 1, num_steps, device=device)
    warp = 0.3 * torch.sin(math.pi * grid) * (1 - grid)
    position = grid + warp

    root_min, root_max = (bound ** (1 / rho) for bound in (sigma_min, sigma_max))
    sigmas = (root_max + position * (root_min - root_max)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Adaptive schedule combining multiple strategies.

    Blends four ramps over a uniform grid ``t``:
    - linear (weight 0.2);
    - ``t ** 0.8`` (weight 0.3);
    - a damped-sine warp (weight 0.2);
    - a logistic curve centred at 0.5 (weight 0.3).

    The blend is rescaled to [0, 1] when it has any spread and mapped through
    rho=7 interpolation between ``sigma_max`` and ``sigma_min``. A trailing
    0 is appended.
    """
    rho = 7.0

    t = torch.linspace(0, 1, num_steps, device=device)

    linear = t
    power_curve = t ** 0.8
    sine_warp = t + 0.2 * torch.sin(2 * math.pi * t) * (1 - t)
    logistic = 1 / (1 + torch.exp(-3 * (t - 0.5)))

    blend = 0.2 * linear + 0.3 * power_curve + 0.2 * sine_warp + 0.3 * logistic

    lowest = blend.min()
    spread = blend.max() - lowest
    if spread > 1e-6:
        blend = (blend - lowest) / spread

    root_min = sigma_min ** (1 / rho)
    root_max = sigma_max ** (1 / rho)
    sigmas = (root_max + blend * (root_min - root_max)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_cosine_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Cosine-annealed schedule.

    Positions follow ``(1 - cos(pi * u)) / 2`` over a uniform grid ``u``
    (slow at both ends, fastest in the middle). They are mapped through rho=7
    interpolation between ``sigma_max`` and ``sigma_min``, and a trailing 0 is
    appended.
    """
    rho = 7.0
    grid = torch.linspace(0, 1, num_steps, device=device)
    eased = (1 - torch.cos(math.pi * grid)) / 2

    root_max = sigma_max ** (1 / rho)
    root_min = sigma_min ** (1 / rho)
    schedule = (root_max + eased * (root_min - root_max)) ** rho

    return torch.cat([schedule, torch.zeros(1, device=device)])
|
|
|
|
def create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Uniform in log-SNR space.

    Sigmas are spaced geometrically from ``sigma_max`` to ``sigma_min``
    (equal steps in ``2 * log(sigma)``, i.e. log-SNR up to sign), followed by
    a trailing 0.

    ``sigma_max``/``sigma_min`` may be Python numbers or tensors. They are
    converted with ``torch.as_tensor`` and moved to ``device``; previously a
    Python float made ``torch.log`` raise TypeError.
    """
    sigma_max = torch.as_tensor(sigma_max, device=device)
    sigma_min = torch.as_tensor(sigma_min, device=device)

    u = torch.linspace(0, 1, num_steps, device=device)
    log_snr_max = 2 * torch.log(sigma_max)
    log_snr_min = 2 * torch.log(sigma_min)
    log_snr = log_snr_max + u * (log_snr_min - log_snr_max)
    sigmas = torch.exp(log_snr / 2)
    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device='cpu', k=4.0):
    """Concentrate steps near mid-range sigmas.

    Positions follow a tanh S-curve of steepness ``k`` centred at 0.5 and
    normalised so that 0 and 1 map to themselves. They are mapped through
    rho=7 interpolation between ``sigma_max`` and ``sigma_min``, and a
    trailing 0 is appended.
    """
    rho = 7.0
    grid = torch.linspace(0, 1, num_steps, device=device)
    steepness = torch.tensor(k, device=device, dtype=grid.dtype)

    s_curve = torch.tanh(steepness * (grid - 0.5)) / torch.tanh(steepness / 2)
    position = 0.5 * (s_curve + 1.0)

    root_min = sigma_min ** (1 / rho)
    root_max = sigma_max ** (1 / rho)
    schedule = (root_max + position * (root_min - root_max)) ** rho
    return torch.cat([schedule, torch.zeros(1, device=device)])
|
|
|
|
def create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device='cpu', pivot=0.7, gamma=0.8, beta=5.0):
    """Faster early lock-in with extra resolution in final steps.

    Positions below ``pivot`` follow ``(u / pivot) ** gamma * pivot``. Above
    it they approach 1 as ``1 - exp(-beta * ...)``, reaching
    ``pivot + (1 - pivot) * (1 - exp(-beta))`` at the last step. Positions
    are mapped through rho=7 interpolation between ``sigma_max`` and
    ``sigma_min``, and a trailing 0 is appended.

    ``pivot >= 1`` is handled without dividing by ``1 - pivot == 0``, which
    used to turn the final sigma into NaN.
    """
    rho = 7.0
    u = torch.linspace(0, 1, num_steps, device=device)

    early_mask = u < pivot
    late_mask = ~early_mask

    # Width of the late phase; guarded so pivot >= 1 cannot divide by zero
    # (the late phase is then at most the single point u == pivot).
    late_span = max(1 - pivot, 1e-8)

    t = torch.empty_like(u)
    t[early_mask] = (u[early_mask] / pivot) ** gamma * pivot
    late_u = u[late_mask]
    t[late_mask] = pivot + (1 - pivot) * (1 - torch.exp(-beta * (late_u - pivot) / late_span))

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Karras schedule with controlled jitter.

    Ramp positions ``(i + 0.5) / (num_steps - 1)`` get a deterministic
    sine-based jitter of amplitude ``0.35 / (num_steps - 1)``. The positions
    are clamped to [0, 1], mapped through rho=7 interpolation between
    ``sigma_max`` and ``sigma_min``, and a trailing 0 is appended.

    For ``num_steps <= 0`` the result is ``[sigma_max, 0]``. ``sigma_max`` may
    be a Python number or a tensor there; previously ``.unsqueeze`` failed on
    floats.
    """
    if num_steps <= 0:
        head = torch.as_tensor(sigma_max, device=device).reshape(1)
        return torch.cat([head, torch.zeros(1, device=device)])

    rho = 7.0
    indices = torch.arange(num_steps, device=device, dtype=torch.float32)
    denom = max(1, num_steps - 1)

    base = (indices + 0.5) / denom
    # Fixed pseudo-random pattern (no RNG), so the schedule is reproducible.
    jitter_seed = torch.sin((indices + 1) * 2.3999632)
    jitter_strength = 0.35
    jitter = jitter_seed * jitter_strength / denom

    u = torch.clamp(base + jitter, 0.0, 1.0)

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device='cpu', noise_type='brownian', noise_scale=0.3, base_schedule='karras'):
    """Karras-style (rho=7) schedule whose step positions are randomly perturbed.

    Base positions (``base_schedule``):
      * 'karras': ``(i / (n - 1)) ** (1 / rho)``. This is an extra 1/rho power
        on top of the Karras mapping applied below.
      * 'cosine': ``(1 - cos(pi * u)) / 2`` over a linear ramp.
      * anything else: a linear ramp.

    Perturbation (``noise_type``), scaled by ``noise_scale / num_steps``:
      * 'brownian': cumulative sum of normal noise, normalised to unit std.
      * 'uniform':  uniform noise in [-1, 1).
      * anything else: standard normal noise.

    Positions are clamped to [0, 1] and the first one is re-anchored to its
    base value (0 for every base). This keeps the schedule starting at
    ``sigma_max``, the noise level the incoming latent already has. The
    sigmas are then sorted descending.

    Uses the global torch RNG, so the schedule is not tied to the
    generation seed.

    Returns:
        1-D tensor of ``num_steps + 1`` sigmas with a terminal zero.
    """
    rho = 7.0

    if base_schedule == 'karras':
        indices = torch.arange(num_steps, device=device, dtype=torch.float32)
        u = (indices / max(1, num_steps - 1)) ** (1 / rho)
    elif base_schedule == 'cosine':
        u = torch.linspace(0, 1, num_steps, device=device)
        u = (1 - torch.cos(math.pi * u)) / 2
    else:
        u = torch.linspace(0, 1, num_steps, device=device)

    if noise_type == 'brownian':
        noise = torch.randn(num_steps, device=device).cumsum(0)
        # The unbiased std of a single sample is NaN, which would turn every
        # sigma into NaN (e.g. 1-step img2img). Only normalise when there is a
        # spread to measure.
        if num_steps > 1:
            noise = noise / noise.std().clamp_min(1e-8)
    elif noise_type == 'uniform':
        noise = torch.rand(num_steps, device=device) * 2 - 1
    else:
        noise = torch.randn(num_steps, device=device)

    u_noisy = u + noise * noise_scale / num_steps
    u_noisy = torch.clamp(u_noisy, 0, 1)
    if num_steps > 0:
        # The first step must stay at the latent's actual noise level.
        u_noisy[0] = u[0]

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + u_noisy * (min_inv_rho - max_inv_rho)) ** rho

    # Perturbation can reorder neighbouring positions; restore a descending order.
    sigmas, _ = torch.sort(sigmas, descending=True)

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_jys_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """JYS ("Jump Your Steps") schedule built from ``_compute_jys_timesteps``.

    Strategy: large jumps early, dense clustering in the detail region, fine
    steps at the end (ported from the ComfyUI reference implementation).

    Each timestep ``ts`` on the 0..1000 scale becomes a position
    ``(1000 - ts) / 1000``. That position interpolates between ``sigma_max``
    and ``sigma_min`` in ``sigma ** (1/rho)`` space (rho=7). The result is
    sorted descending and gets a terminal zero.
    """
    timesteps = _compute_jys_timesteps(num_steps)
    # The helper ends its list with timestep 0; drop it because the zero sigma
    # is appended explicitly below.
    if timesteps and timesteps[-1] == 0:
        timesteps = timesteps[:-1]

    rho = 7.0
    positions = torch.tensor(
        [(1000 - ts) / 1000.0 for ts in timesteps],
        device=device,
        dtype=torch.float32,
    )

    inv_lo = sigma_min ** (1 / rho)
    inv_hi = sigma_max ** (1 / rho)
    raw = (inv_hi + positions * (inv_lo - inv_hi)) ** rho

    ordered = torch.sort(raw, descending=True).values
    return torch.cat([ordered, torch.zeros(1, device=device)])
|
|
|
|
| def _compute_jys_timesteps(num_steps): |
| """Dynamically compute optimised JYS timestep sequence (0..1000 scale).""" |
| if num_steps <= 0: |
| return [0] |
| if num_steps == 1: |
| return [1000, 0] |
| elif num_steps == 2: |
| return [1000, 500, 0] |
| elif num_steps == 3: |
| return [1000, 600, 200, 0] |
|
|
| early_steps = max(1, int(num_steps * 0.2)) |
| final_steps = max(1, int(num_steps * 0.2)) |
| middle_steps = max(1, num_steps - early_steps - final_steps) |
|
|
| early_jump_size = max(50, (1000 - 600) // early_steps) |
| early_timesteps = [] |
| current_t = 1000 |
| for _ in range(early_steps): |
| early_timesteps.append(int(current_t)) |
| current_t = max(600, current_t - early_jump_size) |
|
|
| middle_timesteps = [] |
| structure_steps = max(1, middle_steps // 2) |
| structure_jump_size = max(10, (600 - 300) // structure_steps) |
| current_t = 600 |
| for _ in range(structure_steps): |
| middle_timesteps.append(int(current_t)) |
| current_t = max(300, current_t - structure_jump_size) |
|
|
| detail_steps = middle_steps - structure_steps |
| if detail_steps > 0: |
| detail_jump_size = max(5, (300 - 200) // detail_steps) |
| current_t = 300 |
| for _ in range(detail_steps): |
| middle_timesteps.append(int(current_t)) |
| current_t = max(200, current_t - detail_jump_size) |
|
|
| final_start = min(middle_timesteps) if middle_timesteps else 200 |
| final_jump_size = max(5, final_start // final_steps) |
| final_timesteps = [] |
| current_t = final_start |
| for _ in range(final_steps): |
| final_timesteps.append(int(current_t)) |
| current_t = max(0, current_t - final_jump_size) |
|
|
| all_timesteps = early_timesteps + middle_timesteps + final_timesteps |
| unique_timesteps = list(dict.fromkeys(all_timesteps)) |
| unique_timesteps.sort(reverse=True) |
|
|
| while len(unique_timesteps) < num_steps: |
| for i in range(len(unique_timesteps) - 1): |
| mid_point = (unique_timesteps[i] + unique_timesteps[i + 1]) // 2 |
| if mid_point not in unique_timesteps: |
| unique_timesteps.insert(i + 1, mid_point) |
| if len(unique_timesteps) >= num_steps: |
| break |
|
|
| if len(unique_timesteps) > num_steps: |
| unique_timesteps = unique_timesteps[:num_steps] |
|
|
| if unique_timesteps[-1] != 0: |
| unique_timesteps.append(0) |
|
|
| return unique_timesteps |
|
|
|
|
def create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Hybrid schedule: log-space blend of JYS sigmas and jittered Karras sigmas.

    The JYS weight ramps 0.2 -> 0.6 over the first 30% of positions and
    0.6 -> 0.9 over the next 50%, then stays at 0.9. The blended sigmas are
    scaled by ``1 - 0.05 * (1 - position) ** 2`` (up to 5% lower early) and
    forced to be strictly decreasing.

    The first sigma is pinned to ``sigma_max``. The sampler's input latent is
    already noised to that level, and both the Karras jitter and the early
    smoothing would otherwise start the schedule below it.

    Returns:
        1-D tensor of ``num_steps + 1`` sigmas with a terminal zero. For
        ``num_steps <= 0`` it returns ``[sigma_max, 0]``.
    """
    if num_steps <= 0:
        # as_tensor accepts plain floats as well as 0-d tensors.
        start = torch.as_tensor(sigma_max, device=device).reshape(1)
        return torch.cat([start, torch.zeros(1, device=device)])

    rho = 7.0

    # JYS component without its terminal zero (num_steps entries).
    jys_sigmas = create_jys_sigmas(sigma_max, sigma_min, num_steps, device=device)[:-1]

    # Jittered Karras component: offset positions plus golden-angle sine jitter.
    indices = torch.arange(num_steps, device=device, dtype=torch.float32)
    denom = max(1, num_steps - 1)
    base = (indices + 0.5) / denom
    jitter_seed = torch.sin((indices + 1) * 2.3999632)
    jitter_strength = 0.35
    jitter = jitter_seed * jitter_strength / denom
    u = torch.clamp(base + jitter, 0.0, 1.0)

    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    karras_sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho

    # Piecewise-linear JYS weight over normalised position.
    positions = torch.linspace(0, 1, num_steps, device=device)
    jys_weight = torch.empty_like(positions)
    early_mask = positions < 0.3
    mid_mask = (positions >= 0.3) & (positions < 0.8)
    late_mask = positions >= 0.8
    jys_weight[early_mask] = 0.2 + 0.4 * (positions[early_mask] / 0.3)
    jys_weight[mid_mask] = 0.6 + 0.3 * ((positions[mid_mask] - 0.3) / 0.5)
    jys_weight[late_mask] = 0.9
    jys_weight = jys_weight.clamp(0.2, 0.9)

    # Blend in log space so the mix is geometric rather than arithmetic.
    log_jys = torch.log(jys_sigmas.clamp_min(1e-6))
    log_karras = torch.log(karras_sigmas.clamp_min(1e-6))
    log_hybrid = torch.lerp(log_karras, log_jys, jys_weight)
    hybrid = torch.exp(log_hybrid)

    smoothing = 1.0 - 0.05 * (1 - positions) ** 2
    hybrid = hybrid * smoothing

    # Start exactly at the latent's noise level.
    hybrid[0] = sigma_max

    # Strictly decreasing: equal neighbours would produce a zero-length step.
    for i in range(1, hybrid.shape[0]):
        if hybrid[i] >= hybrid[i - 1]:
            hybrid[i] = hybrid[i - 1] * 0.999

    return torch.cat([hybrid, torch.zeros(1, device=device)])
|
|
|
|
def create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AYS (Align Your Steps)-style schedule for SDXL.

    Each table lists ``N`` normalised levels for an ``N``-step run, from 1
    (``sigma_max``) down to 0 (``sigma_min``). Other step counts are
    interpolated in log space from the nearest table. Levels are mapped
    linearly onto ``[sigma_min, sigma_max]``, the first sigma is pinned to
    ``sigma_max``, the sequence is forced strictly decreasing, and a single
    terminal zero is appended.

    Previously the table's trailing level was overwritten with 0 and then a
    second 0 was appended. The interpolated path also produced
    ``num_steps + 1`` levels. Both gave a final 0 -> 0 step, where
    ``(x - denoised) / sigma`` divides by zero and the sample turns NaN.

    Returns:
        1-D tensor of exactly ``num_steps + 1`` sigmas ending in one zero
        (``tensor([0.])`` when ``num_steps <= 0``).
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    AYS_SCHEDULES = {
        10: [1.0000, 0.8751, 0.7502, 0.6254, 0.5004, 0.3755, 0.2506, 0.1253, 0.0502, 0.0000],
        15: [1.0000, 0.9167, 0.8334, 0.7501, 0.6668, 0.5835, 0.5002, 0.4169, 0.3336,
             0.2503, 0.1670, 0.0837, 0.0335, 0.0084, 0.0000],
        20: [1.0000, 0.9375, 0.8750, 0.8125, 0.7500, 0.6875, 0.6250, 0.5625, 0.5000,
             0.4375, 0.3750, 0.3125, 0.2500, 0.1875, 0.1250, 0.0625, 0.0313, 0.0156,
             0.0039, 0.0000],
        25: [1.0000, 0.9500, 0.9000, 0.8500, 0.8000, 0.7500, 0.7000, 0.6500, 0.6000,
             0.5500, 0.5000, 0.4500, 0.4000, 0.3500, 0.3000, 0.2500, 0.2000, 0.1500,
             0.1000, 0.0625, 0.0391, 0.0195, 0.0098, 0.0024, 0.0000],
        30: [1.0000, 0.9583, 0.9167, 0.8750, 0.8333, 0.7917, 0.7500, 0.7083, 0.6667,
             0.6250, 0.5833, 0.5417, 0.5000, 0.4583, 0.4167, 0.3750, 0.3333, 0.2917,
             0.2500, 0.2083, 0.1667, 0.1250, 0.0833, 0.0521, 0.0326, 0.0163, 0.0081,
             0.0041, 0.0010, 0.0000],
    }

    if num_steps in AYS_SCHEDULES:
        # Table length == num_steps; its final 0 maps to sigma_min.
        normalized = torch.tensor(AYS_SCHEDULES[num_steps], device=device, dtype=torch.float32)
    else:
        available_steps = sorted(AYS_SCHEDULES.keys())

        if num_steps < available_steps[0]:
            ref_steps = available_steps[0]
        elif num_steps > available_steps[-1]:
            ref_steps = available_steps[-1]
        else:
            ref_steps = min([s for s in available_steps if s >= num_steps], default=available_steps[-1])

        ref_schedule = np.array(AYS_SCHEDULES[ref_steps])

        t_ref = np.linspace(0, 1, len(ref_schedule))
        # One level per step, matching the table layout (terminal zero added below).
        t_new = np.linspace(0, 1, num_steps)

        log_ref = np.log(ref_schedule + 1e-8)
        # log(0) is unusable for interpolation. Place the last level three
        # e-folds below its neighbour, then pin the endpoint to 0 afterwards.
        log_ref[-1] = log_ref[-2] - 3.0

        log_interp = np.interp(t_new, t_ref, log_ref)
        normalized_np = np.exp(log_interp)
        normalized_np[-1] = 0.0

        normalized = torch.tensor(normalized_np, device=device, dtype=torch.float32)

    sigma_range = sigma_max - sigma_min
    sigmas = normalized * sigma_range + sigma_min

    sigmas[0] = sigma_max

    for i in range(1, len(sigmas)):
        if sigmas[i] >= sigmas[i - 1]:
            sigmas[i] = sigmas[i - 1] * 0.999

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_aos_akashic_alt_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AkashicAOS Alt: warped Karras-style schedule with ratio limiting.

    Step positions are ``u ** 0.78 + 0.07 * tanh(4 * (u - 0.55))``: a
    detail-progressive power warp plus a tanh crossover centred at u = 0.55.
    They are rescaled back to [0, 1] and mapped through Karras
    interpolation. ``rho = 7 + 40 / max(num_steps, 10)``, clamped to
    [7, 11], so lower step counts use a larger rho.

    Afterwards each sigma is forced below its predecessor (x0.995 if not)
    and the step ratio ``sigma[i-1] / sigma[i]`` is capped at 1.5. The cap
    can keep the final sigma above ``sigma_min``.

    Returns:
        1-D tensor of ``num_steps + 1`` sigmas with a terminal zero, or
        ``tensor([0.])`` when ``num_steps <= 0``.
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    rho = min(11.0, max(7.0, 7.0 + 2.0 * (20.0 / max(num_steps, 10))))
    ramp = torch.linspace(0, 1, num_steps, device=device)

    power, centre, amplitude, slope = 0.78, 0.55, 0.07, 4.0
    shaped = ramp ** power + amplitude * torch.tanh(slope * (ramp - centre))

    lowest = shaped.min()
    highest = shaped.max()
    if highest - lowest > 1e-8:
        shaped = (shaped - lowest) / (highest - lowest)

    inv_lo = sigma_min ** (1 / rho)
    inv_hi = sigma_max ** (1 / rho)
    sigmas = (inv_hi + shaped * (inv_lo - inv_hi)) ** rho

    ratio_cap = 1.5
    for idx in range(1, len(sigmas)):
        prev = sigmas[idx - 1]
        if sigmas[idx] >= prev:
            sigmas[idx] = prev * 0.995
        if prev / sigmas[idx].clamp(min=1e-10) > ratio_cap:
            sigmas[idx] = prev / ratio_cap

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def create_akashic_eqflow_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """
    AkashicEQFlow: Robust crossover-focused log-SNR schedule for EQ-VAE models.
    Concentrates steps around the structure-to-detail transition in logSNR space,
    blended with a Karras prior. Adaptive density width + ratio slew-rate limiting.

    Args:
        sigma_max: First (largest) sigma; float or 0-d tensor.
        sigma_min: Smallest non-zero sigma; float or 0-d tensor.
        num_steps: Number of non-zero sigmas to generate.
        device: Torch device for intermediate and returned tensors.

    Returns:
        1-D tensor of ``num_steps + 1`` sigmas: starts exactly at ``sigma_max``,
        strictly decreasing (each step ratio >= 1.001), terminal zero appended.
        ``tensor([0.])`` when ``num_steps <= 0``.
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    # Log-SNR is lambda = -2 * log(sigma), so the noisiest end (sigma_max)
    # has the smallest lambda. Sigmas are floored at 1e-10 before the log.
    lambda_min = -2.0 * math.log(max(float(sigma_max), 1e-10))
    lambda_max = -2.0 * math.log(max(float(sigma_min), 1e-10))
    lambda_range = max(lambda_max - lambda_min, 1e-8)

    # step_factor ramps 0 -> 1 as num_steps goes 16 -> 46. It shifts the
    # crossover centre from lambda = 0.20 to lambda = 0.35.
    step_factor = min(1.0, max(0.0, (num_steps - 16) / 30.0))
    lambda_center = 0.20 + 0.15 * step_factor
    # Crossover position on the normalised [0, 1] log-SNR axis, kept away from both ends.
    u_center = (lambda_center - lambda_min) / lambda_range
    u_center = float(min(0.88, max(0.12, u_center)))

    # Bump height grows with step count (clamped to [1.35, 3.2]) and its width
    # shrinks with step count (clamped to [0.18, 0.30]). The bump is slightly
    # wider on the left (noisier) side and taller on the right (detail) side.
    concentration = min(3.2, max(1.35, 1.1 + num_steps / 16.0))
    base_width = min(0.30, max(0.18, 0.31 - 0.0028 * num_steps))
    width_left = base_width * 1.06
    width_right = base_width * 0.94
    detail_side_gain = 1.08 + 0.04 * step_factor

    # Evaluate an asymmetric Gaussian bump on a fine N-point grid over [0, 1].
    N = 1200
    t = torch.linspace(0, 1, N, device=device)
    delta = t - u_center
    left_core = torch.exp(-((delta / width_left) ** 2) / 2.0)
    right_core = detail_side_gain * torch.exp(-((delta / width_right) ** 2) / 2.0)
    crossover_core = torch.where(delta <= 0, left_core, right_core)

    # Small floors keep some step density at both ends of the range.
    detail_floor = 0.08 * (t ** 1.4)
    composition_floor = 0.05 * ((1 - t) ** 1.7)
    density = 1.0 + concentration * crossover_core + detail_floor + composition_floor

    # Trapezoidal cumulative integral of the density, normalised to a CDF on [0, 1].
    dt_val = 1.0 / (N - 1)
    cdf = torch.zeros(N, device=device)
    cdf[1:] = torch.cumsum((density[:-1] + density[1:]) * 0.5 * dt_val, dim=0)
    cdf = cdf / cdf[-1].clamp(min=1e-12)

    # Inverse-CDF sampling with linear interpolation between bracketing grid
    # points. Evenly spaced CDF targets land closer together where the
    # density is high, so steps cluster around the crossover.
    targets = torch.linspace(0, 1, num_steps, device=device)
    indices = torch.searchsorted(cdf, targets).clamp(1, N - 1)
    lo = indices - 1
    hi = indices
    frac = (targets - cdf[lo]) / (cdf[hi] - cdf[lo]).clamp(min=1e-12)
    u_steps = t[lo] + frac * (t[hi] - t[lo])

    # Back to log-SNR: u = 0 -> sigma_max end, u = 1 -> sigma_min end.
    lambdas_eqflow = lambda_min + u_steps * lambda_range

    # Karras prior with rho = 7 + 33 / max(num_steps, 12), clamped to [7, 10].
    rho = min(10.0, max(7.0, 7.0 + 1.5 * (22.0 / max(num_steps, 12))))
    u_karras = torch.linspace(0, 1, num_steps, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas_karras = (max_inv_rho + u_karras * (min_inv_rho - max_inv_rho)) ** rho
    lambdas_karras = -2.0 * torch.log(sigmas_karras.clamp(min=1e-10))

    # Linear blend in log-SNR space. The EQFlow weight is
    # 0.38 + num_steps / 200, clamped to [0.35, 0.60].
    blend_eqflow = min(0.60, max(0.35, 0.38 + num_steps / 200.0))
    lambdas = (1.0 - blend_eqflow) * lambdas_karras + blend_eqflow * lambdas_eqflow
    sigmas = torch.exp(-lambdas / 2.0)

    # Upper bound on the per-step ratio sigma[i-1] / sigma[i]; tighter with more steps.
    if num_steps >= 40:
        max_ratio = 1.50
    elif num_steps >= 28:
        max_ratio = 1.55
    elif num_steps >= 18:
        max_ratio = 1.65
    else:
        max_ratio = 1.85
    # Max multiplicative change of that ratio between consecutive steps.
    ratio_slew = 1.18
    prev_ratio = None

    # Pin the start, then rebuild every sigma from its predecessor using the
    # capped and slew-limited ratio. Because this loop rewrites each entry,
    # its order matters. The 1.001 floor keeps the sequence strictly decreasing.
    sigmas[0] = sigma_max
    for i in range(1, len(sigmas)):
        if sigmas[i] >= sigmas[i - 1]:
            sigmas[i] = sigmas[i - 1] * 0.995
        ratio = float((sigmas[i - 1] / sigmas[i].clamp(min=1e-10)).item())
        ratio = min(ratio, max_ratio)
        if prev_ratio is not None:
            ratio = min(ratio, prev_ratio * ratio_slew)
            ratio = max(ratio, prev_ratio / ratio_slew)
        ratio = max(1.001, ratio)
        sigmas[i] = sigmas[i - 1] / ratio
        prev_ratio = ratio

    return torch.cat([sigmas, torch.zeros(1, device=device)])
|
|
|
|
def apply_custom_scheduler(sigmas, scheduler_type="Standard"):
    """Rebuild ``sigmas`` with a named custom schedule over the same range.

    The replacement spans ``sigmas[0]`` down to ``sigmas[-2]``, the last
    non-zero entry. The trailing zero terminator is never used as the floor.
    If that floor is not positive, ``sigmas[0] * 1e-3`` is used instead. The
    new schedule has ``len(sigmas) - 1`` steps. Each generator is wrapped in
    a lambda so per-scheduler keyword arguments (e.g. Entropic's ``power``)
    are always passed explicitly.

    The input is returned unchanged when:
      * the name is "Standard" or ``sigmas`` has fewer than two entries;
      * the name is not recognised;
      * the generator raises, or returns None / fewer than two values.
        A warning is printed in this case.
    """
    if scheduler_type == "Standard" or len(sigmas) < 2:
        return sigmas

    s_max = sigmas[0]
    s_min = sigmas[-2]
    if s_min <= 0:
        s_min = s_max * 1e-3
    n = len(sigmas) - 1
    dev = sigmas.device

    builders = {
        "AOS-V": lambda: create_aos_v_sigmas(s_max, s_min, n, dev),
        "AOS-Epsilon": lambda: create_aos_e_sigmas(s_max, s_min, n, dev),
        "AkashicAOS": lambda: create_aos_akashic_sigmas(s_max, s_min, n, dev),
        "Entropic": lambda: create_entropic_sigmas(s_max, s_min, n, power=6.0, device=dev),
        "SNR-Optimized": lambda: create_snr_optimized_sigmas(s_max, s_min, n, dev),
        "Constant-Rate": lambda: create_constant_rate_sigmas(s_max, s_min, n, dev),
        "Adaptive-Optimized": lambda: create_adaptive_optimized_sigmas(s_max, s_min, n, dev),
        "Cosine-Annealed": lambda: create_cosine_sigmas(s_max, s_min, n, dev),
        "LogSNR-Uniform": lambda: create_logsnr_uniform_sigmas(s_max, s_min, n, dev),
        "Tanh Mid-Boost": lambda: create_tanh_midboost_sigmas(s_max, s_min, n, dev),
        "Exponential Tail": lambda: create_exponential_tail_sigmas(s_max, s_min, n, dev),
        "Jittered-Karras": lambda: create_jittered_karras_sigmas(s_max, s_min, n, dev),
        "Stochastic": lambda: create_stochastic_sigmas(s_max, s_min, n, device=dev),
        "JYS (Dynamic)": lambda: create_jys_sigmas(s_max, s_min, n, dev),
        "Hybrid JYS-Karras": lambda: create_hybrid_jys_karras_sigmas(s_max, s_min, n, dev),
        "AYS-SDXL": lambda: create_ays_sdxl_sigmas(s_max, s_min, n, dev),
        "AkashicAOS Alt": lambda: create_aos_akashic_alt_sigmas(s_max, s_min, n, dev),
        "AkashicEQFlow": lambda: create_akashic_eqflow_sigmas(s_max, s_min, n, dev),
    }

    build = builders.get(scheduler_type)
    if build is None:
        return sigmas

    try:
        candidate = build()
        usable = candidate is not None and len(candidate) > 1
    except Exception as e:
        print(f"⚠️ Scheduler {scheduler_type} failed: {e}, using standard")
        return sigmas

    if usable:
        return candidate
    print(f"⚠️ Scheduler {scheduler_type} returned empty/None, using standard")
    return sigmas
|
|
|
|
| |
| |
| |
| |
|
|
@torch.no_grad()
def sample_adept_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Euler sampler with Adept weight scaling, or a redirect to a custom sampler.

    Dispatch:
      1. Adept enabled and a custom sampler selected: apply the configured
         scheduler and return the chosen custom solver's result.
      2. Adept disabled: delegate to the captured original ``'euler'``
         sampler, or ``_basic_euler`` when none was captured.
      3. Otherwise: Euler with optional churn (``s_churn``, ``s_tmin``,
         ``s_tmax``, ``s_noise``). Each step runs inside
         ``AdeptWeightPatcher`` with the per-step scale from
         ``compute_dynamic_scale``.

    The configured scheduler is applied at most once. This holds even when
    an unrecognised custom sampler name falls through to path 3.
    ``disable`` controls the progress bar.
    """
    enabled = ADEPT_STATE.get('enabled', False)
    # True once the redirect branch has rescheduled ``sigmas``. It stops the
    # fall-through path from rescheduling an already custom schedule.
    scheduler_applied = False

    if enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5),
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5),
                order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0),
                ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False,
                settings={},
                eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2),
                use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False,
                settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False),
                phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5),
                enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False,
                settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # No custom solver matched: continue with Adept Euler below.
        print(f"⚠️ Unknown custom sampler {custom_type}, falling back to Adept Euler")

    if not enabled:
        if 'euler' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['euler'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
        return _basic_euler(model, x, sigmas, extra_args, callback, disable)

    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard' and not scheduler_applied:
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1

    for i in trange(total_steps, disable=disable, desc="Adept Euler"):
        sigma = sigmas[i]
        # Churn only inside [s_tmin, s_tmax]; the gamma increment is capped at sqrt(2) - 1.
        gamma = min(s_churn / total_steps, 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        with AdeptWeightPatcher(unet_model, current_scale, shift):
            eps = torch.randn_like(x) * s_noise if gamma > 0 else 0
            sigma_hat = sigma * (gamma + 1)
            if gamma > 0:
                # Add just enough noise to raise the level from sigma to sigma_hat.
                x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5

            denoised = model(x, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            dt = sigmas[i + 1] - sigma_hat
            x = x + d * dt

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_hat, 'denoised': denoised})

    return x
|
|
@torch.no_grad()
def sample_adept_euler_ancestral(model, x, sigmas, extra_args=None, callback=None,
                                 disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
    """Euler Ancestral with Adept weight scaling, or a redirect to a custom sampler.

    Dispatch:
      1. Adept enabled and a custom sampler selected: apply the configured
         scheduler and return the chosen custom solver's result.
      2. Adept disabled: delegate to the captured original
         ``'euler_ancestral'`` sampler, or ``_basic_euler_ancestral``.
      3. Otherwise: Euler Ancestral steps. The model call is wrapped in
         ``AdeptWeightPatcher`` whenever ``should_patch_weights`` says so.

    In path 3 the base eta and s_noise come from ADEPT_STATE, falling back to
    the ``eta`` / ``s_noise`` arguments. Previously the loop ignored the
    ADEPT_STATE eta it printed. With adaptive eta the base is multiplied by
    1.08 for progress < 0.3, 0.95 for progress < 0.7, and 1.02 for the
    remaining steps. The configured scheduler is applied at most once.
    """
    enabled = ADEPT_STATE.get('enabled', False)
    # True once the redirect branch has rescheduled ``sigmas``.
    scheduler_applied = False

    if enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # No custom solver matched: continue with Adept Euler A below.
        print(f"⚠️ Unknown custom sampler {custom_type}, falling back to Adept Euler A")

    if not enabled:
        if 'euler_ancestral' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['euler_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
        return _basic_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)

    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard' and not scheduler_applied:
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    use_adaptive_eta = ADEPT_STATE.get('adaptive_eta', False)
    # The configured eta is the per-step base. The ``eta`` argument is only
    # the fallback when ADEPT_STATE has no value.
    base_eta = ADEPT_STATE.get('eta', eta)
    current_s_noise = ADEPT_STATE.get('s_noise', s_noise)

    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    total_steps = len(sigmas) - 1
    print(f"✅ Adept Euler A active: scale={base_scale:.2f}, eta={base_eta:.2f}")

    for i in trange(total_steps, disable=disable, desc="Adept Euler A"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        progress = i / max(total_steps, 1)

        if use_adaptive_eta:
            if progress < 0.3:
                current_eta = base_eta * 1.08
            elif progress < 0.7:
                current_eta = base_eta * 0.95
            else:
                current_eta = base_eta * 1.02
        else:
            current_eta = base_eta

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
        d = to_d(x, sigma, denoised)

        # Guard against non-finite derivatives poisoning the latent.
        if torch.isnan(d).any() or torch.isinf(d).any():
            d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)

        dt = sigma_down - sigma
        x = x + d * dt

        if sigma_up > 0:
            noise = noise_sampler(sigma, sigma_next) * current_s_noise
            x = x + noise * sigma_up

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
def _basic_euler(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """First-order Euler sampler; fallback when ORIGINAL_SAMPLERS has no 'euler' key.

    Each step evaluates the model at ``sigmas[step]``, converts the result to
    a derivative with ``to_d``, and moves ``x`` to ``sigmas[step + 1]``. The
    callback receives the post-step latent.
    """
    kwargs = extra_args if extra_args is not None else {}
    ones = x.new_ones([x.shape[0]])
    for step in trange(len(sigmas) - 1, disable=disable):
        sig = sigmas[step]
        sig_next = sigmas[step + 1]
        denoised = model(x, sig * ones, **kwargs)
        slope = to_d(x, sig, denoised)
        x = x + slope * (sig_next - sig)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sig, 'denoised': denoised})
    return x
|
|
|
|
def _basic_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
    """Plain Euler Ancestral sampler; fallback when no original was captured.

    Each step takes a deterministic Euler move to ``sigma_down``. When
    ``sigma_up > 0`` it then adds fresh noise scaled by
    ``s_noise * sigma_up``. Both values come from ``get_ancestral_step``
    with the given ``eta``.
    """
    kwargs = {} if extra_args is None else extra_args
    ones = x.new_ones([x.shape[0]])
    sample_noise = default_noise_sampler(x)

    for step in trange(len(sigmas) - 1, disable=disable):
        sig, sig_next = sigmas[step], sigmas[step + 1]
        denoised = model(x, sig * ones, **kwargs)
        sigma_down, sigma_up = get_ancestral_step(sig, sig_next, eta)
        slope = to_d(x, sig, denoised)
        x = x + slope * (sigma_down - sig)
        if sigma_up > 0:
            x = x + sample_noise(sig, sig_next) * s_noise * sigma_up
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sig, 'denoised': denoised})

    return x
|
|
|
|
@torch.no_grad()
def sample_adept_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun (2nd-order) sampler with Adept weight scaling, or a custom-sampler redirect.

    Dispatch:
      1. Adept enabled and a custom sampler selected: apply the configured
         scheduler and return the chosen custom solver's result.
      2. Adept disabled: delegate to the captured original ``'heun'``
         sampler, or ``_basic_heun`` when none was captured.
      3. Otherwise: Heun steps, with a plain Euler step on the final move to
         sigma 0. Every model call is wrapped in ``AdeptWeightPatcher``
         whenever ``should_patch_weights`` says so.

    ``s_churn``, ``s_tmin``, ``s_tmax`` and ``s_noise`` are only forwarded to
    the original sampler in path 2; path 3 does not use them. The configured
    scheduler is applied at most once, even when an unrecognised custom
    sampler name falls through to path 3.
    """
    enabled = ADEPT_STATE.get('enabled', False)
    # True once the redirect branch has rescheduled ``sigmas``.
    scheduler_applied = False

    if enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # No custom solver matched: continue with Adept Heun below.
        print(f"⚠️ Unknown custom sampler {custom_type}, falling back to Adept Heun")

    if not enabled:
        if 'heun' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['heun'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
        return _basic_heun(model, x, sigmas, extra_args, callback, disable)

    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard' and not scheduler_applied:
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1
    print(f"✅ Adept Heun active: scale={base_scale:.2f}")

    for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Heun"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        d = to_d(x, sigma, denoised)

        if torch.isnan(d).any() or torch.isinf(d).any():
            d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)

        dt = sigma_next - sigma

        if sigma_next == 0:
            # The final step to sigma 0 cannot evaluate a second slope there: plain Euler.
            x = x + d * dt
        else:
            # Predictor: Euler step to sigma_next.
            x_2 = x + d * dt

            if should_patch_weights(unet_model, current_scale, shift):
                with AdeptWeightPatcher(unet_model, current_scale, shift):
                    denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
            else:
                denoised_2 = model(x_2, sigma_next * s_in, **extra_args)

            d_2 = to_d(x_2, sigma_next, denoised_2)

            if torch.isnan(d_2).any() or torch.isinf(d_2).any():
                d_2 = torch.nan_to_num(d_2, nan=0.0, posinf=1.0, neginf=-1.0)

            # Corrector: average the two slopes (trapezoidal rule).
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
def _basic_heun(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Fallback basic Heun.

    Plain second-order Heun sampler without Adept weight scaling. It is used
    only when the original k-diffusion ``sample_heun`` was not captured.
    Each step predicts with an Euler step, re-evaluates the model at the
    predicted point and averages the two slopes. The last step
    (``sigma_next == 0``) is a single Euler step.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])

    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur, sigma_nxt = sigmas[step], sigmas[step + 1]
        denoised = model(x, sigma_cur * ones, **extra_args)
        slope = to_d(x, sigma_cur, denoised)
        delta = sigma_nxt - sigma_cur

        if sigma_nxt != 0:
            # Predictor (Euler) followed by a trapezoidal corrector.
            x_pred = x + slope * delta
            denoised_pred = model(x_pred, sigma_nxt * ones, **extra_args)
            slope_pred = to_d(x_pred, sigma_nxt, denoised_pred)
            x = x + (slope + slope_pred) / 2 * delta
        else:
            x = x + slope * delta

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'denoised': denoised})

    return x
|
|
|
|
@torch.no_grad()
def sample_adept_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM++ 2M sampler with Adept weight scaling.

    Dispatch:
      * Adept enabled and "Use Custom Sampler" on: apply the selected scheduler
        and hand off to the chosen custom sampler.
      * Adept disabled: delegate to the original k-diffusion sampler captured by
        ``patch_k_diffusion`` (or the local ``_basic_dpmpp_2m`` fallback).
      * Otherwise: run DPM-Solver++(2M) here. Each model call is wrapped in
        ``AdeptWeightPatcher`` whenever ``should_patch_weights`` allows it.

    DPM-Solver++ steps in ``t = -log(sigma)``. The step size is therefore
    ``h = log(sigma / sigma_next)`` (positive), and a first-order step is
    ``x_next = (sigma_next / sigma) * x + (1 - exp(-h)) * denoised``.
    For ``sigma_next == 0`` this gives ``h = inf`` and lands exactly on
    ``denoised``.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent.
        sigmas: Noise levels, presumably a decreasing 1-D tensor — TODO confirm
            for custom schedulers.
        extra_args: Extra keyword arguments forwarded to ``model``.
        callback: Optional per-step callback; receives a dict with ``x``, ``i``,
            ``sigma`` and ``denoised``.
        disable: Forwarded to ``trange`` to hide the progress bar.

    Returns:
        The final latent.
    """
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # NOTE(review): an unrecognized custom_type falls through to the Adept
        # path below, where the scheduler is applied a second time.

    # Adept disabled: behave like the sampler we replaced.
    if not ADEPT_STATE.get('enabled', False):
        if 'dpmpp_2m' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['dpmpp_2m'](model, x, sigmas, extra_args, callback, disable)
        return _basic_dpmpp_2m(model, x, sigmas, extra_args, callback, disable)

    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # The UNet whose weights AdeptWeightPatcher scales; None if unavailable.
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1
    print(f"✅ Adept DPM++ 2M active: scale={base_scale:.2f}")

    old_denoised = None

    for i in trange(total_steps, disable=disable, desc="Adept DPM++ 2M"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        # Step size in t = -log(sigma): h = t_next - t = log(sigma / sigma_next).
        # For sigma_next == 0 this is +inf, so -expm1(-h) == 1.
        h = (sigma / sigma_next).log()
        # sigma_next / sigma == exp(-h); use the ratio directly (no exp/log round-off).
        ratio = sigma_next / sigma

        if old_denoised is None or sigma_next == 0:
            # First-order step (no history yet, or final step to sigma 0).
            x = ratio * x - (-h).expm1() * denoised
        else:
            # Second-order multistep correction using the previous prediction.
            h_last = (sigmas[i - 1] / sigma).log()
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = ratio * x - (-h).expm1() * denoised_d

        old_denoised = denoised

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
def _basic_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Fallback basic DPM++ 2M.

    DPM-Solver++(2M) without Adept weight scaling. It is used only when the
    original k-diffusion ``sample_dpmpp_2m`` was not captured.

    Steps are taken in ``t = -log(sigma)``, so ``h = log(sigma / sigma_next)``.
    A first-order step is
    ``x_next = (sigma_next / sigma) * x + (1 - exp(-h)) * denoised``,
    which lands exactly on ``denoised`` when ``sigma_next == 0`` (h = inf).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)

        # t = -log(sigma); h = t_next - t; ratio = exp(-h).
        h = (sigma / sigma_next).log()
        ratio = sigma_next / sigma

        if old_denoised is None or sigma_next == 0:
            x = ratio * x - (-h).expm1() * denoised
        else:
            h_last = (sigmas[i - 1] / sigma).log()
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = ratio * x - (-h).expm1() * denoised_d

        old_denoised = denoised

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
@torch.no_grad()
def sample_adept_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
    """DPM++ 2S Ancestral with Adept weight scaling.

    Dispatch mirrors the other Adept wrappers:
      * custom sampler redirect when enabled;
      * the captured original (or ``_basic_dpmpp_2s_ancestral``) when disabled;
      * otherwise the Adept loop below.

    Each Adept step does two things. First it moves deterministically from
    ``sigma`` to ``sigma_down`` with DPM-Solver++(2S). That update works in
    ``t = -log(sigma)``, with its intermediate model evaluation halfway in
    ``t``, i.e. at the geometric mean ``sqrt(sigma * sigma_down)``. Then it
    adds ``sigma_up``-scaled noise.

    Note that ``ADEPT_STATE['eta']`` / ``['s_noise']`` take precedence over the
    ``eta`` / ``s_noise`` arguments on the Adept path.
    """
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # NOTE(review): an unrecognized custom_type falls through to the Adept
        # path below, where the scheduler is applied a second time.

    # Adept disabled: behave like the sampler we replaced.
    if not ADEPT_STATE.get('enabled', False):
        if 'dpmpp_2s_ancestral' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['dpmpp_2s_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
        return _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)

    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    current_eta = ADEPT_STATE.get('eta', eta)
    current_s_noise = ADEPT_STATE.get('s_noise', s_noise)

    # The UNet whose weights AdeptWeightPatcher scales; None if unavailable.
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    total_steps = len(sigmas) - 1
    print(f"✅ Adept DPM++ 2S A active: scale={base_scale:.2f}")

    for i in trange(total_steps, disable=disable, desc="Adept DPM++ 2S A"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
        patch_weights = should_patch_weights(unet_model, current_scale, shift)

        if patch_weights:
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)

        if sigma_down == 0:
            # Final step: plain Euler straight to sigma_down (== 0).
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_down - sigma)
        else:
            # DPM-Solver++(2S) in t = -log(sigma): h = log(sigma / sigma_down) > 0.
            # The intermediate point is halfway in t, i.e. sigma * exp(-h / 2).
            h = (sigma / sigma_down).log()
            sigma_mid = sigma * (-0.5 * h).exp()

            x_mid = (sigma_mid / sigma) * x - (-0.5 * h).expm1() * denoised

            if patch_weights:
                with AdeptWeightPatcher(unet_model, current_scale, shift):
                    denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            else:
                denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)

            x = (sigma_down / sigma) * x - (-h).expm1() * denoised_mid

        # Ancestral noise injection.
        if sigma_up > 0:
            noise = noise_sampler(sigma, sigma_next) * current_s_noise
            x = x + noise * sigma_up

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
def _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
    """Fallback basic DPM++ 2S Ancestral.

    DPM-Solver++(2S) ancestral sampler without Adept weight scaling. It is used
    only when the original k-diffusion ``sample_dpmpp_2s_ancestral`` was not
    captured.

    Each step moves deterministically from ``sigma`` to ``sigma_down`` (from
    ``get_ancestral_step``) in ``t = -log(sigma)``. The intermediate evaluation
    sits halfway in ``t``. The step then adds ``noise * s_noise * sigma_up``.

    Args:
        eta: Ancestral noise amount forwarded to ``get_ancestral_step``.
        s_noise: Multiplier on the injected noise.
        noise_sampler: Optional ``noise_sampler(sigma, sigma_next)`` callable;
            defaults to ``default_noise_sampler(x)``.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta)

        if sigma_down == 0:
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_down - sigma)
        else:
            # h = t_next - t with t = -log(sigma); midpoint at sigma * exp(-h / 2).
            h = (sigma / sigma_down).log()
            sigma_mid = sigma * (-0.5 * h).exp()
            x_mid = (sigma_mid / sigma) * x - (-0.5 * h).expm1() * denoised
            denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            x = (sigma_down / sigma) * x - (-h).expm1() * denoised_mid

        if sigma_up > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
@torch.no_grad()
def sample_adept_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """LMS sampler with Adept weight scaling.

    Linear multistep: ``x_{i+1} = x_i + sum_j c_j * d_{i-j}``.

    Each weight ``c_j`` is the integral, from ``sigma_i`` to ``sigma_{i+1}``,
    of the Lagrange basis polynomial through the last ``cur_order`` sigmas that
    equals 1 at ``sigma_{i-j}``. The basis must be integrated, not evaluated at
    ``sigma_i``. Evaluating it there makes every ``c_j`` with ``j >= 1``
    vanish, which silently reduces LMS to Euler.

    Dispatch mirrors the other Adept wrappers (custom redirect / original when
    disabled / Adept loop).

    Args:
        order: Maximum multistep order (values below 1 are treated as 1).
    """
    if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")

        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")

        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # NOTE(review): an unrecognized custom_type falls through to the Adept
        # path below, where the scheduler is applied a second time.

    # Adept disabled: behave like the sampler we replaced.
    if not ADEPT_STATE.get('enabled', False):
        if 'lms' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['lms'](model, x, sigmas, extra_args, callback, disable, order)
        return _basic_lms(model, x, sigmas, extra_args, callback, disable, order)

    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # The UNet whose weights AdeptWeightPatcher scales; None if unavailable.
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1
    print(f"✅ Adept LMS active: scale={base_scale:.2f}, order={order}")
    order = max(1, int(order))

    # Host-side copy of the noise levels; the LMS weights depend only on these.
    if torch.is_tensor(sigmas):
        sigma_values = sigmas.detach().cpu().tolist()
    else:
        sigma_values = [float(s) for s in sigmas]

    def _lms_coefficients(step, cur_order):
        """Return the ``cur_order`` LMS weights for ``step`` (index j = step - j).

        The integrand has degree ``cur_order - 1``, so Gauss-Legendre
        quadrature with ``cur_order`` nodes integrates it exactly.
        """
        nodes, weights = np.polynomial.legendre.leggauss(cur_order)
        lo, hi = sigma_values[step], sigma_values[step + 1]
        half_width = 0.5 * (hi - lo)
        taus = half_width * nodes + 0.5 * (hi + lo)
        coeffs = []
        for j in range(cur_order):
            basis = np.ones_like(taus)
            for k in range(cur_order):
                if k != j:
                    basis *= (taus - sigma_values[step - k]) / (sigma_values[step - j] - sigma_values[step - k])
            coeffs.append(float(half_width * np.dot(weights, basis)))
        return coeffs

    ds = []

    for i in trange(total_steps, disable=disable, desc="Adept LMS"):
        sigma = sigmas[i]

        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        d = to_d(x, sigma, denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)

        cur_order = min(i + 1, order)
        coeffs = _lms_coefficients(i, cur_order)
        # ds is oldest-first; weight j pairs with the derivative from step i - j.
        # The weights already include the step length, so no extra "* dt".
        x = x + sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:])))

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
|
|
|
|
def _basic_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Fallback basic LMS.

    Linear multistep sampler without Adept weight scaling. It is used only when
    the original k-diffusion ``sample_lms`` was not captured.

    Weight ``c_j`` is the integral, from ``sigma_i`` to ``sigma_{i+1}``, of the
    Lagrange basis polynomial through the last ``cur_order`` sigmas that equals
    1 at ``sigma_{i-j}``. The integral is computed exactly with Gauss-Legendre
    quadrature.

    Args:
        order: Maximum multistep order (values below 1 are treated as 1).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    order = max(1, int(order))

    if torch.is_tensor(sigmas):
        sigma_values = sigmas.detach().cpu().tolist()
    else:
        sigma_values = [float(s) for s in sigmas]

    def _lms_coefficients(step, cur_order):
        """Integrate each Lagrange basis polynomial over [sigma_step, sigma_step+1]."""
        # Degree cur_order - 1 integrand: cur_order Gauss-Legendre nodes are exact.
        nodes, weights = np.polynomial.legendre.leggauss(cur_order)
        lo, hi = sigma_values[step], sigma_values[step + 1]
        half_width = 0.5 * (hi - lo)
        taus = half_width * nodes + 0.5 * (hi + lo)
        coeffs = []
        for j in range(cur_order):
            basis = np.ones_like(taus)
            for k in range(cur_order):
                if k != j:
                    basis *= (taus - sigma_values[step - k]) / (sigma_values[step - j] - sigma_values[step - k])
            coeffs.append(float(half_width * np.dot(weights, basis)))
        return coeffs

    ds = []

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)

        if len(ds) > order:
            ds.pop(0)

        cur_order = min(i + 1, order)
        coeffs = _lms_coefficients(i, cur_order)
        # Weights already integrate over the step, so they multiply d directly.
        x = x + sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:])))

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})

    return x
|
|
|
|
| |
| |
| |
|
|
def patch_k_diffusion():
    """Apply monkey patches to ALL k-diffusion samplers.

    Replaces each supported ``k_diffusion.sampling.sample_*`` attribute with its
    Adept wrapper. The first function ever seen in a slot is remembered in
    ``ORIGINAL_SAMPLERS`` under the name without its ``sample_`` prefix. The
    wrappers delegate to that saved function while Adept is disabled, and
    ``unpatch_k_diffusion`` restores it.

    Repeated calls are safe: a slot that already holds our wrapper is not
    counted again, and the first saved original is never overwritten.
    """
    replacements = (
        ('sample_euler', sample_adept_euler),
        ('sample_euler_ancestral', sample_adept_euler_ancestral),
        ('sample_heun', sample_adept_heun),
        ('sample_dpmpp_2m', sample_adept_dpmpp_2m),
        ('sample_dpmpp_2s_ancestral', sample_adept_dpmpp_2s_ancestral),
        ('sample_lms', sample_adept_lms),
    )

    patched = 0
    for attr_name, wrapper in replacements:
        if not hasattr(k_diffusion.sampling, attr_name):
            continue
        live = getattr(k_diffusion.sampling, attr_name)
        ORIGINAL_SAMPLERS.setdefault(attr_name[len('sample_'):], live)
        if live is wrapper:
            continue
        setattr(k_diffusion.sampling, attr_name, wrapper)
        patched += 1

    print(f"✅ Adept Sampler v5: Patched {patched} samplers")
    print(" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS")
    print(" Schedulers: 18 types available")
| print(f" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS") |
| print(f" Schedulers: 18 types available") |
|
|
|
|
def unpatch_k_diffusion():
    """
    Restore original k-diffusion samplers.

    Safe-unpatch strategy: before restoring we check whether the live slot
    still holds *our* wrapper. If another extension has wrapped us on top
    (i.e. live_func is not our adept_func but also not the original we
    saved), blindly restoring would silently remove *their* wrapper too.
    In that case we skip the restore for that slot and log a warning so the
    operator knows the coexistence situation.

    Saved originals are forgotten only for slots that are back on the
    original. For a skipped slot the entry is kept, because the foreign
    wrapper may still call into our Adept wrapper. That wrapper looks the
    original up in ORIGINAL_SAMPLERS to delegate while Adept is disabled;
    dropping the entry would silently downgrade it to the ``_basic_*``
    fallback.
    """
    # key in ORIGINAL_SAMPLERS -> (k_diffusion.sampling attribute, our wrapper)
    slots = {
        'euler': ('sample_euler', sample_adept_euler),
        'euler_ancestral': ('sample_euler_ancestral', sample_adept_euler_ancestral),
        'heun': ('sample_heun', sample_adept_heun),
        'dpmpp_2m': ('sample_dpmpp_2m', sample_adept_dpmpp_2m),
        'dpmpp_2s_ancestral': ('sample_dpmpp_2s_ancestral', sample_adept_dpmpp_2s_ancestral),
        'lms': ('sample_lms', sample_adept_lms),
    }

    restored_count = 0
    skipped_count = 0
    for key, (attr_name, our_func) in slots.items():
        if key not in ORIGINAL_SAMPLERS:
            continue
        saved_original = ORIGINAL_SAMPLERS[key]
        live_func = getattr(k_diffusion.sampling, attr_name, None)

        if live_func is our_func:
            # We still own the slot: put the original back.
            setattr(k_diffusion.sampling, attr_name, saved_original)
        elif live_func is not saved_original:
            # Someone else owns the slot; leave it (and our saved original) alone.
            print(f"⚠️ Adept unpatch: {attr_name} is currently owned by another "
                  f"extension ({live_func!r}). Skipping restore to avoid breaking "
                  f"their wrapper — you may need to reload the UI to fully unload.")
            skipped_count += 1
            continue

        # Slot is (now) the original: the saved reference is no longer needed.
        restored_count += 1
        del ORIGINAL_SAMPLERS[key]

    print(f"🔄 Adept Sampler: Restored {restored_count} samplers"
          + (f", skipped {skipped_count} (foreign wrappers)" if skipped_count else ""))
|
|
|
|
| |
| |
| |
|
|
class AdeptSamplerScript(scripts.Script):
    """WebUI script that exposes the Adept Sampler controls and publishes them.

    The sampler wrappers installed by ``patch_k_diffusion`` read their settings
    from the module-level ``ADEPT_STATE`` dict. ``process`` refreshes that dict
    from the UI values, and ``process_batch`` / ``postprocess_batch`` bracket
    each batch with the optional VAE reflection patch.
    """

    def title(self):
        # Display name of this script in the WebUI.
        return "Adept Sampler v5"

    def show(self, is_img2img):
        # Shown for both txt2img and img2img (is_img2img is ignored).
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the Adept Sampler accordion.

        Returns:
            The input components, in exactly the order ``process`` expects its
            positional arguments after ``p``. Keep the two lists in sync.
        """
        with gr.Accordion("Adept Sampler v5", open=False):
            enabled = gr.Checkbox(label="Enable Adept Sampler", value=False, elem_id="adept_enabled")

            # Scale/shift are handed to AdeptWeightPatcher by the patched samplers.
            with gr.Row():
                scale = gr.Slider(minimum=0.5, maximum=2.0, step=0.05, value=1.0, label="Weight Scale")
                shift = gr.Slider(minimum=-0.5, maximum=0.5, step=0.01, value=0.0, label="Weight Shift")

            # Passed to compute_dynamic_scale by the patched samplers.
            with gr.Row():
                start_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Start Percent")
                end_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=1.0, label="End Percent")

            gr.HTML("<p style='color: #888; font-size: 0.85em; margin: 2px 0 10px;'>"
                    "⚠️ Weight Scale / Shift / Start–End apply to the 6 patched k-diffusion samplers only. "
                    "Custom samplers (Akashic, Adept, Mirror) use their own internal parameters.</p>")

            # Stored as ADEPT_STATE['eta'] / ['s_noise']; read by the ancestral
            # paths and forwarded to the custom samplers.
            with gr.Row():
                eta = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Eta (Ancestral samplers)")
                s_noise = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="S-Noise")

            adaptive_eta = gr.Checkbox(label="Adaptive Eta (dynamic eta during sampling)", value=False)

            # Any choice other than "Standard" is applied to the sigma schedule
            # through apply_custom_scheduler by the patched samplers.
            scheduler = gr.Dropdown(
                choices=["Standard", "AOS-V", "AOS-Epsilon", "AkashicAOS", "Entropic", "SNR-Optimized",
                         "Constant-Rate", "Adaptive-Optimized", "Cosine-Annealed", "LogSNR-Uniform",
                         "Tanh Mid-Boost", "Exponential Tail", "Jittered-Karras", "Stochastic",
                         "JYS (Dynamic)", "Hybrid JYS-Karras", "AYS-SDXL",
                         "AkashicAOS Alt", "AkashicEQFlow"],
                value="Standard", label="Scheduler Type"
            )

            vae_reflection = gr.Checkbox(label="Enable VAE Reflection (fixes edge artifacts for EQ-VAE)", value=False)

            gr.HTML("<hr style='margin: 15px 0;'>")
            gr.HTML("<h3 style='margin: 10px 0;'>🌀 Custom Advanced Samplers</h3>")
            gr.HTML("<p style='color: #888; font-size: 0.9em;'>Enable to use Akashic/Adept/Ancestral samplers instead of k-diffusion</p>")

            # With Adept enabled, the patched samplers redirect to custom_type.
            use_custom = gr.Checkbox(label="Use Custom Sampler (overrides k-diffusion)", value=False)
            custom_type = gr.Dropdown(
                choices=["Akashic Solver v2", "Adept Solver", "Adept Ancestral Solver", "Mirror Correction Euler"],
                value="Akashic Solver v2", label="Custom Sampler Type"
            )

            # The following accordions hold settings forwarded only to the
            # matching custom sampler.
            with gr.Accordion("⚙️ Akashic Solver Settings", open=False):
                tau = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Tau (0=ODE, 1=SDE)")
                phase_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Phase Strength")
                smea = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="SMEA (high-res coherency)")
                ndb = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="NDB (detail boost)")
                eqvae = gr.Dropdown(choices=["Off", "Balanced"], value="Off", label="EQ-VAE Mode")

            with gr.Accordion("⚙️ Adept Solver Settings", open=False):
                solver_order = gr.Slider(minimum=1, maximum=3, step=1, value=2, label="Order (1-3)")
                use_corrector = gr.Checkbox(value=True, label="Use Corrector")

            with gr.Accordion("⚙️ Ancestral Solver Settings", open=False):
                phase_noise = gr.Checkbox(value=False, label="Phase-Aware Noise")
                enhanced_deriv = gr.Checkbox(value=False, label="Enhanced Derivative")

            with gr.Accordion("⚙️ Mirror Correction Euler Settings", open=False):
                gr.HTML("<p style='color: #888; font-size: 0.9em;'>Active only when Custom Sampler = Mirror Correction Euler</p>")
                mirror_correction_phase = gr.Slider(
                    minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                    label="Correction Phase (fraction of steps with 3-call Heun correction)"
                )
                mirror_smooth_phase = gr.Checkbox(
                    value=False,
                    label="Smooth Phase (log-sigma blend instead of binary cutoff)"
                )

            gr.HTML("<hr style='margin: 15px 0;'>")
            gr.HTML("<h3 style='margin: 10px 0;'>🎛️ CFG Enhancements</h3>")
            gr.HTML("<p style='color: #888; font-size: 0.9em;'>"
                    "Combat CFG Drift works in stock A1111 via official callback. "
                    "Spectral Modulation & Phase-Aware CFG use a native sampler hook on "
                    "Forge/reForge-like backends, or a CFGDenoiser monkey-patch on stock A1111 "
                    "(near-parity; active mode logged to console).</p>")

            # The *_enabled flags below are ANDed with `enabled` in process().
            with gr.Accordion("⚙️ CFG Enhancement Settings", open=False):
                cfg_drift_enabled = gr.Checkbox(value=False, label="Enable Combat CFG Drift")
                with gr.Row():
                    cfg_drift_method = gr.Dropdown(
                        choices=["mean", "median"], value="mean",
                        label="Drift Method"
                    )
                    cfg_drift_intensity = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                        label="Drift Intensity"
                    )
                spectral_cfg_enabled = gr.Checkbox(
                    value=False, label="Enable Spectral Modulation (native hook or A1111 monkey-patch)"
                )
                with gr.Row():
                    spectral_multiplier = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                        label="Spectral Multiplier"
                    )
                    spectral_percentile = gr.Slider(
                        minimum=1.0, maximum=25.0, step=0.5, value=5.0,
                        label="Spectral Percentile"
                    )
                phase_cfg_enabled = gr.Checkbox(
                    value=False, label="Enable Phase-Aware CFG (native hook or A1111 monkey-patch)"
                )
                with gr.Row():
                    phase_cfg_alpha = gr.Slider(
                        minimum=1.1, maximum=4.0, step=0.1, value=2.0,
                        label="Phase CFG Alpha"
                    )
                    phase_cfg_beta = gr.Slider(
                        minimum=1.1, maximum=4.0, step=0.1, value=2.0,
                        label="Phase CFG Beta"
                    )

        # Order must match the positional parameters of process().
        return [enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection,
                use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector,
                phase_noise, enhanced_deriv,
                mirror_correction_phase, mirror_smooth_phase,
                cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity,
                spectral_cfg_enabled, spectral_multiplier, spectral_percentile,
                phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta]

    def process(self, p, enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection,
                use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector,
                phase_noise, enhanced_deriv,
                mirror_correction_phase, mirror_smooth_phase,
                cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity,
                spectral_cfg_enabled, spectral_multiplier, spectral_percentile,
                phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta):
        """Publish the UI values to ADEPT_STATE for this generation.

        The state is overwritten unconditionally, so settings from a previous
        run never leak into this one, even when Adept is disabled. The VAE
        reflection flag and the three CFG-enhancement flags are ANDed with
        ``enabled``. When Adept is enabled, a summary is also added to
        ``p.extra_generation_params``.
        """
        global ADEPT_STATE

        ADEPT_STATE.update({
            "enabled": enabled,
            "scale": scale,
            "shift": shift,
            "start_pct": start_pct,
            "end_pct": end_pct,
            "eta": eta,
            "s_noise": s_noise,
            "adaptive_eta": adaptive_eta,
            "scheduler": scheduler,
            "vae_reflection": enabled and vae_reflection,
            "use_custom_sampler": use_custom,
            "custom_sampler": custom_type,
            "tau": tau,
            "phase_strength": phase_strength,
            "smea_strength": smea,
            "ndb_strength": ndb,
            "eqvae_mode": eqvae,
            # The solvers expect an integer order; the slider value is coerced here.
            "solver_order": int(solver_order),
            "use_corrector": use_corrector,
            "phase_noise": phase_noise,
            "enhanced_derivative": enhanced_deriv,

            "mirror_correction_phase": mirror_correction_phase,
            "mirror_smooth_phase": mirror_smooth_phase,

            "cfg_drift_enabled": enabled and cfg_drift_enabled,
            "cfg_drift_method": cfg_drift_method,
            "cfg_drift_intensity": cfg_drift_intensity,
            "spectral_cfg_enabled": enabled and spectral_cfg_enabled,
            "spectral_multiplier": spectral_multiplier,
            "spectral_percentile": spectral_percentile,
            "phase_cfg_enabled": enabled and phase_cfg_enabled,
            "phase_cfg_alpha": phase_cfg_alpha,
            "phase_cfg_beta": phase_cfg_beta,
        })

        # Presumably installs/removes the CFG hooks according to ADEPT_STATE
        # (defined elsewhere); its return value is recorded as "CFG Runtime".
        runtime_mode = configure_cfg_runtime()

        if enabled:
            info = {
                "Adept Sampler": "v5",
                "Adept Scheduler": scheduler,
                "CFG Runtime": runtime_mode,
            }
            if use_custom:
                info["Adept Custom"] = custom_type
                if custom_type == "Akashic Solver v2":
                    info["Adept Tau"] = tau
                    info["Adept EQ-VAE"] = eqvae
            p.extra_generation_params.update(info)

    def process_batch(self, p, *args, **kwargs):
        """Apply VAE Reflection before batch processing."""
        if ADEPT_STATE.get("enabled", False) and ADEPT_STATE.get("vae_reflection", False):
            try:
                vae_model = shared.sd_model.first_stage_model
                # Entered manually; stays active until postprocess_batch exits it.
                patcher = VAEReflectionPatcher(vae_model)
                patcher.__enter__()
                p.adept_vae_patcher = patcher
            except Exception as e:
                # Reported, but never allowed to abort generation.
                print(f"⚠️ VAE Reflection error: {e}")

    def postprocess_batch(self, p, *args, **kwargs):
        """Restore VAE padding modes after batch processing."""
        if hasattr(p, 'adept_vae_patcher'):
            try:
                p.adept_vae_patcher.__exit__(None, None, None)
                delattr(p, 'adept_vae_patcher')
            except Exception as e:
                print(f"⚠️ VAE Reflection restore error: {e}")

        # Called unconditionally, presumably as a safety net in case a patcher
        # was never exited — TODO confirm against force_restore_vae_reflection.
        force_restore_vae_reflection()
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def _adept_deferred_init():
    """Install the Adept sampler patches (registered as a before-UI callback)."""
    return patch_k_diffusion()


# Defer patching until the WebUI is about to build its UI. If the hook cannot
# be registered on this WebUI build, patch immediately instead.
try:
    script_callbacks.on_before_ui(_adept_deferred_init)
except Exception:
    _adept_deferred_init()
|
|
def on_script_unloaded():
    """Undo every runtime patch this script installed (script-unload callback).

    Teardown is best-effort. Each step runs independently, and a failing step
    is reported on the console instead of being silently swallowed. The
    remaining steps still run, so one broken hook cannot leave the others
    installed.

    Each helper is looked up only when its own step runs (via the lambdas), so
    a missing helper fails just that step, as with the original per-step
    ``try`` blocks.
    """
    teardown_steps = (
        ("VAE reflection restore", lambda: force_restore_vae_reflection()),
        ("A1111 CFG callbacks", lambda: uninstall_a1111_cfg_callbacks()),
        ("native CFG hook", lambda: uninstall_native_cfg_hook()),
        ("CFGDenoiser patch", lambda: unpatch_cfg_denoiser()),
        ("k-diffusion samplers", lambda: unpatch_k_diffusion()),
    )
    for label, step in teardown_steps:
        try:
            step()
        except Exception as e:  # best-effort teardown: report and keep going
            print(f"⚠️ Adept unload: {label} failed: {e}")
|
|
# Undo our patches when the WebUI unloads this script; the hook is
# unavailable on some WebUI builds (AttributeError).
try:
    script_callbacks.on_script_unloaded(on_script_unloaded)
except AttributeError:
    print("⚠️ Script unload callback not available")

# Startup banner (one line per entry, same output as individual prints).
print("\n".join((
    "🚀 Adept Sampler v5 loaded!",
    " ✨ 4 Custom Samplers: Akashic v2, Adept Solver, Adept Ancestral, Mirror Correction Euler",
    " ⚡ 6 k-diffusion Samplers with weight scaling",
    " 📅 18 Schedulers (including AkashicAOS Alt, AkashicEQFlow)",
    " 🎨 VAE Reflection",
    " ✅ A1111 port of ComfyUI-Adept-Sampler",
)))
|
|
|
|