"""
Adept Sampler v5 for Automatic1111 WebUI.

Complete port with ALL custom samplers from ComfyUI.

Version: 5.0
"""
import torch
import numpy as np
import math
from tqdm import trange
from modules import scripts, shared, script_callbacks
import gradio as gr
import k_diffusion.sampling

# Try import torchvision for detail enhancement.
# gaussian_blur is only needed by the optional high-frequency paths; when it is
# missing those paths check TORCHVISION_AVAILABLE and are skipped.
try:
    from torchvision.transforms.functional import gaussian_blur
    TORCHVISION_AVAILABLE = True
except ImportError:
    TORCHVISION_AVAILABLE = False
    print("⚠️ torchvision not available - detail enhancement disabled")

# ============================================================================
# GLOBAL STATE
# ============================================================================

# Process-wide settings shared by the sampler code and the CFG runtime.
# NOTE(review): presumably populated from the UI before each generation;
# the writer is not visible here — confirm against the Script class.
ADEPT_STATE = {
    "enabled": False,
    "scale": 1.0,
    "shift": 0.0,
    "start_pct": 0.0,
    "end_pct": 1.0,
    "eta": 1.0,
    "s_noise": 1.0,
    "adaptive_eta": False,
    "scheduler": "Standard",
    "vae_reflection": False,
    # Custom sampler settings
    "use_custom_sampler": False,
    "custom_sampler": "Akashic Solver v2",
    "tau": 0.5,
    "phase_strength": 0.5,
    "solver_order": 2,
    "use_corrector": True,
    "phase_noise": False,
    "enhanced_derivative": False,
    "smea_strength": 0.0,
    "ndb_strength": 0.0,
    # NOTE(review): string-valued here ("Off"); how it maps to the boolean
    # eqvae_mode argument of the solver helpers is not visible — confirm.
    "eqvae_mode": "Off",
    # Mirror Correction Euler controls
    "mirror_correction_phase": 0.5,
    "mirror_smooth_phase": False,
    # CFG enhancement settings
    # cfg_drift_* are read by the post-CFG drift callback.
    "cfg_drift_enabled": False,
    "cfg_drift_method": "mean",
    "cfg_drift_intensity": 0.5,
    # spectral_* and phase_cfg_* are read by the CFG hook builders and the
    # CFGDenoiser monkey-patch.
    "spectral_cfg_enabled": False,
    "spectral_multiplier": 1.0,
    "spectral_percentile": 5.0,
    "phase_cfg_enabled": False,
    "phase_cfg_alpha": 2.0,
    "phase_cfg_beta": 2.0,
    # Written by configure_cfg_runtime() to record the active mode.
    "cfg_runtime_mode": "off",  # off | a1111-postcfg | a1111-monkeypatch | native-hook
    # Internal bookkeeping for phase-aware CFG progress tracking
    # (written by the cfg_denoiser callback, read by the native phase hook).
    "_cfg_step_idx": 0,
    "_cfg_total_steps": 1,
}

# Store original samplers
# NOTE(review): presumably maps sampler names to the k-diffusion functions
# they replace, so they can be restored — the writer is not visible here.
ORIGINAL_SAMPLERS = {}

# VAE Reflection state
# NOTE(review): the names suggest a flag plus saved per-module padding modes
# for restoring the VAE; the code that uses them is not visible here.
_vae_reflection_active = False
_vae_original_padding_modes = {}

# CFG hook / callback runtime state
# Callables registered through script_callbacks by install_a1111_cfg_callbacks(),
# kept so they can be removed again; None when not installed.
_ADEPT_CFG_AFTER_CB = None
_ADEPT_CFG_DENOISER_CB = None
# True while a native (Forge/reForge-style) sampler CFG hook installed by
# Adept is live on the model.
_ADEPT_NATIVE_CFG_HOOK_ACTIVE = False

# CFGDenoiser monkey-patch runtime state: the original methods captured before
# patching (so they can be restored) and whether the patch is active.
_CFGD_ORIG_COMBINE = None
_CFGD_ORIG_COMBINE_EDIT = None
_CFGD_ORIG_FORWARD = None
_CFGD_MONKEYPATCH_ACTIVE = False
# Attribute name used to stash per-call x/sigma context on CFGDenoiser instances.
_ADEPT_CFGDENOISER_CTX_ATTR = "_adept_cfg_ctx"


# ============================================================================
# BASIC UTILITY FUNCTIONS (from v3)
# ============================================================================

def to_d(x, sigma, denoised):
    """
    Convert a denoised prediction into the ODE derivative (x - denoised) / sigma.

    sigma is clamped to >= 1e-4 before dividing. If the largest derivative
    magnitude exceeds 1000 * (1 + sigma / 10), the derivative is clamped to
    that bound.

    sigma must be a tensor (torch.clamp is applied to it). The threshold test
    is a plain ``if`` on a tensor comparison, so sigma is presumably a
    scalar / 0-d tensor; a multi-element sigma would make that test ambiguous.
    """
    diff = x - denoised
    safe_sigma = torch.clamp(sigma, min=1e-4)
    derivative = diff / safe_sigma

    sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
    derivative_max = torch.abs(derivative).max()
    if derivative_max > sigma_adaptive_threshold:
        derivative = torch.clamp(derivative, -sigma_adaptive_threshold, sigma_adaptive_threshold)
    return derivative


def get_ancestral_step(sigma, sigma_next, eta=1.0):
    """
    Split a step from ``sigma`` to ``sigma_next`` into deterministic and noise parts.

    Returns:
        (sigma_down, sigma_up). sigma_up is the noise std to re-inject, capped
        at sigma_next. sigma_down is the sigma reached by the deterministic
        sub-step. Returns (0.0, 0.0) when sigma_next == 0.
    """
    if sigma_next == 0:
        return 0.0, 0.0
    sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
    sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


def compute_dynamic_scale(step_idx, total_steps, base_scale, start_pct, end_pct):
    """
    Compute the weight scale for the current step with a smooth fade-in/fade-out.

    The fade ramp is proportional to the active window width rather than a
    fixed 0.1 absolute value. For a narrow window (e.g. start=0.4, end=0.5), a
    hard-coded 0.1 ramp would consume the entire window and produce jarring or
    contradictory behaviour. Using 20 % of the window width (capped at 0.05)
    keeps the envelope sensible at any window size.

    Returns 1.0 (no-op) outside [start_pct, end_pct].
    """
    # Clamp / validate inputs so callers can pass raw UI values safely.
    start_pct = max(0.0, min(float(start_pct), 1.0))
    end_pct = max(start_pct, min(float(end_pct), 1.0))
    total_steps = max(total_steps, 1)

    progress = step_idx / max(total_steps - 1, 1)
    if progress < start_pct or progress > end_pct:
        return 1.0

    window = end_pct - start_pct
    if window < 1e-6:
        # Degenerate window — treat as fully active for that single step.
        return float(base_scale)

    # Fade ramp = 20 % of window, capped at 0.05 so it never feels sluggish
    # on very wide windows and never feels jarring on narrow ones.
    ramp = min(0.20 * window, 0.05)

    if ramp > 0 and progress < start_pct + ramp:
        fade = (progress - start_pct) / ramp
        return 1.0 + (base_scale - 1.0) * fade
    elif ramp > 0 and progress > end_pct - ramp:
        fade = (end_pct - progress) / ramp
        return 1.0 + (base_scale - 1.0) * fade
    else:
        return float(base_scale)


def default_noise_sampler(x):
    """Return a noise sampler that ignores sigmas and draws fresh N(0, 1) noise shaped like ``x``."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler


def get_noise_sampler(x):
    """Get a noise sampler for ``x`` (currently always the default sampler)."""
    return default_noise_sampler(x)


# ============================================================================
# ADVANCED UTILITY FUNCTIONS
# ============================================================================

def _safe_randn_like(tensor, generator=None):
    """Draw N(0, 1) noise like ``tensor``, using ``generator`` when one is usable."""
    if generator is None:
        return torch.randn_like(tensor)
    try:
        return torch.randn(tensor.shape, device=tensor.device, dtype=tensor.dtype, generator=generator)
    except (TypeError, AttributeError):
        return torch.randn_like(tensor)


def to_d_enhanced_ancestral(x, sigma, denoised, eta, progress, generator=None):
    """
    Enhanced derivative for ancestral sampling.

    Starts from (x - denoised) / max(sigma, 1e-4) and adds small random
    perturbations:

    * eta > 1 adds noise that grows with progress; eta < 1 subtracts noise
      that shrinks with progress.
    * A fixed small term is added early (progress < 0.3) or subtracted late
      (progress > 0.7).

    The result is then clamped to +/- 500 * (1 + sigma / 10) if it exceeds
    that bound.
    """
    diff = x - denoised
    safe_sigma = torch.clamp(sigma, min=1e-4)
    base_derivative = diff / safe_sigma

    if eta > 1.0:
        eta_correction = 0.02 * (eta - 1.0) * _safe_randn_like(diff, generator) * progress
        base_derivative = base_derivative + eta_correction
    elif eta < 1.0:
        eta_correction = 0.015 * (1.0 - eta) * _safe_randn_like(diff, generator) * (1.0 - progress)
        base_derivative = base_derivative - eta_correction

    if progress < 0.3:
        phase_correction = 0.01 * _safe_randn_like(diff, generator)
        base_derivative = base_derivative + phase_correction
    elif progress > 0.7:
        phase_correction = 0.008 * _safe_randn_like(diff, generator)
        base_derivative = base_derivative - phase_correction

    sigma_adaptive_threshold = 500.0 * (1.0 + sigma / 10.0)
    derivative_max = torch.abs(base_derivative).max()
    if derivative_max > sigma_adaptive_threshold:
        base_derivative = torch.clamp(base_derivative, -sigma_adaptive_threshold, sigma_adaptive_threshold)
    return base_derivative


def apply_dynamic_thresholding(x, percentile=0.995, clamp_range=1.0):
    """
    Dynamic thresholding for high CFG.

    Per batch item, s is the ``percentile`` quantile of |x| (taken as the k-th
    largest value, clamped to >= 1). Values with |x| > 2.5 * s are clamped to
    +/- 2.5 * s, and the whole tensor is then scaled by 0.98.

    Returns ``x`` unchanged when ``percentile >= 1`` or when every |x| < 5.

    Args:
        x: Tensor whose first dim is the batch.
        percentile: Quantile (0-1) used to pick the reference magnitude.
        clamp_range: Unused; kept for interface compatibility.
    """
    if percentile >= 1.0:
        return x

    try:
        batch_size = x.shape[0]
        # reshape rather than view: view raises on non-contiguous tensors,
        # which previously made thresholding silently do nothing.
        x_flat = x.reshape(batch_size, -1)
        abs_max = torch.abs(x_flat).max(dim=1, keepdim=True)[0]

        if abs_max.max() < 5.0:
            return x

        k = max(1, int(x_flat.shape[1] * (1.0 - percentile)))
        topk_vals = torch.topk(torch.abs(x_flat), k=k, dim=1, largest=True)[0]
        s = topk_vals[:, -1:].clamp(min=1.0)

        threshold = s * 2.5
        mask = torch.abs(x_flat) > threshold
        x_flat = torch.where(mask, torch.sign(x_flat) * threshold, x_flat)
        x_flat = x_flat * 0.98
        return x_flat.reshape(x.shape)
    except Exception as e:
        # Best-effort: keep sampling with the unthresholded tensor, but say so.
        print(f"⚠️ Dynamic thresholding failed: {e}")
        return x


def compute_compensation_ratio(r, step_idx, total_steps, base_ratio=1.0):
    """
    DC-Solver compensation ratio.

    Phase weight is 1.5 for progress < 0.3, 1.0 below 0.7, and 1.3 after that.
    It is multiplied by ``base_ratio`` and by a gentle tanh correction around
    r = 1.
    """
    progress = step_idx / max(total_steps - 1, 1)
    if progress < 0.3:
        phase_weight = 1.5
    elif progress < 0.7:
        phase_weight = 1.0
    else:
        phase_weight = 1.3
    return base_ratio * phase_weight * (1.0 + 0.1 * math.tanh(r - 1.0))


def compute_tau_eqvae(progress, base_tau=0.5, phase_strength=0.5):
    """Phase-aware tau for a standard VAE (breakpoints 0.30 / 0.60), clamped to [0, 1]."""
    if progress < 0.30:
        phase_factor = 1.0 + 0.2 * phase_strength
    elif progress < 0.60:
        phase_factor = 1.0 - 0.15 * phase_strength
    else:
        phase_factor = 1.0 - 0.3 * phase_strength
    return min(1.0, max(0.0, base_tau * phase_factor))


def compute_eqvae_tau(progress, base_tau, phase_strength):
    """EQ-VAE tau with shifted phases (breakpoints 0.25 / 0.55), clamped to [0, 1]."""
    if progress < 0.25:
        phase_factor = 1.0 + 0.10 * phase_strength
    elif progress < 0.55:
        phase_factor = 1.0 - 0.10 * phase_strength
    else:
        phase_factor = 1.0 - 0.20 * phase_strength
    return min(1.0, max(0.0, base_tau * phase_factor))


def compute_eqvae_noise_scale(base_s_noise, progress):
    """
    EQ-VAE noise scale.

    Returns base_s_noise * 0.88, times a phase factor that decays linearly
    from 1.05 to 1.0 over progress [0, 0.25), then from 1.0 to 0.95 over
    [0.25, 0.60), and stays at 0.95 afterwards.
    """
    eqvae_base_factor = 0.88
    if progress < 0.25:
        phase_factor = 1.0 + 0.05 * (1.0 - progress / 0.25)
    elif progress < 0.60:
        phase_factor = 1.0 - 0.05 * ((progress - 0.25) / 0.35)
    else:
        phase_factor = 0.95
    return base_s_noise * eqvae_base_factor * phase_factor


def compute_eqvae_ndb(progress, ndb_strength):
    """
    Native Detail Boost parameters for EQ-VAE.

    Returns:
        (blur_sigma, high_freq_boost). When ndb_strength <= 0 this is
        (0.5, 0.0); otherwise blur_sigma is 0.6 and the boost ramps
        piecewise-linearly over 0.03 -> 0.10 -> 0.20 (times ndb_strength)
        across the 0.30 / 0.60 breakpoints.
    """
    if ndb_strength <= 0:
        return 0.5, 0.0

    blur_sigma = 0.6
    if progress < 0.30:
        phase_progress = progress / 0.30
        high_freq_boost = 0.03 * ndb_strength * phase_progress
    elif progress < 0.60:
        phase_progress = (progress - 0.30) / 0.30
        high_freq_boost = (0.03 + 0.07 * phase_progress) * ndb_strength
    else:
        phase_progress = (progress - 0.60) / 0.40
        high_freq_boost = (0.10 + 0.10 * phase_progress) * ndb_strength
    return blur_sigma, high_freq_boost


def compute_native_detail_boost(progress, ndb_strength=0.0):
    """
    Native Detail Boost for a standard VAE.

    Returns:
        (1.0, high_freq_boost). The first element is a constant placeholder.
        The boost ramps 0.03 -> 0.10 -> 0.18 (times ndb_strength) across the
        0.30 / 0.60 breakpoints, and is 0 when ndb_strength <= 0.
    """
    if ndb_strength <= 0:
        return 1.0, 0.0

    if progress < 0.30:
        phase_progress = progress / 0.30
        high_freq_boost = 0.03 * ndb_strength * phase_progress
    elif progress < 0.60:
        phase_progress = (progress - 0.30) / 0.30
        high_freq_boost = (0.03 + 0.07 * phase_progress) * ndb_strength
    else:
        phase_progress = (progress - 0.60) / 0.40
        high_freq_boost = (0.10 + 0.08 * phase_progress) * ndb_strength
    return 1.0, high_freq_boost


def compute_smea_factor(progress, smea_strength=0.5):
    """
    SMEA coherency factor.

    Equals 1 - smea_strength at progress 0 and rises along a half-sine to 1.0
    at progress 1. Always 1.0 when smea_strength <= 0.
    """
    if smea_strength <= 0:
        return 1.0
    smea_interp = 0.5 * (1 + math.sin(math.pi * (progress - 0.5)))
    return 1.0 - smea_strength * (1.0 - smea_interp)


# ============================================================================
# ADVANCED CFG TECHNIQUES from reForge
============================================================================ def apply_spectral_modulation_clybius(noise_pred, multiplier=1.0, percentile=5.0): """ Clybius Spectral Modulation: Apply frequency-domain corrections to noise prediction. This is the correct implementation based on ComfyUI-Latent-Modifiers. It should be applied to noise_pred (cond - uncond), NOT to denoised latent. Args: noise_pred: The noise prediction tensor (cond - uncond) multiplier: Modulation strength (0=none, 1=full Clybius effect). Default: 1.0 percentile: Upper/lower percentile threshold. Default: 5.0 Returns: Spectrally modulated noise prediction """ if multiplier == 0 or percentile <= 0: return noise_pred try: # FFT fourier = torch.fft.fft2(noise_pred, dim=(-2, -1)) # Log amplitude (with small epsilon for numerical stability) log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2) + 1e-8) # Compute quantiles on absolute log amplitude log_amp_flat = log_amp.abs().flatten(2) quantile_low = torch.quantile(log_amp_flat, percentile * 0.01, dim=2) quantile_high = torch.quantile(log_amp_flat, 1 - percentile * 0.01, dim=2) # Expand quantiles back to log_amp shape quantile_low = quantile_low.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) quantile_high = quantile_high.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape) # Create masks (Clybius approach) # mask_low: boost values below low threshold (range 1.0 to 1.5) # mask_high: reduce values above high threshold (range 0.5 to 1.0) mask_low = ((log_amp < quantile_low).float() + 1).clamp_(max=1.5) mask_high = ((log_amp < quantile_high).float()).clamp_(min=0.5) # Apply modulation via exponentiation filtered_fourier = fourier * ((mask_low * mask_high) ** multiplier) # Inverse FFT result = torch.fft.ifft2(filtered_fourier, dim=(-2, -1)).real return result except Exception as e: print(f"⚠️ Spectral modulation failed: {e}") return noise_pred def create_spectral_modulation_cfg_hook(multiplier=1.0, percentile=5.0): """ Create a CFG 
hook that applies Clybius spectral modulation to noise prediction. This hooks into reForge's set_model_sampler_cfg_function to intercept the CFG calculation and apply spectral modulation at the correct point. Args: multiplier: Modulation strength (0=none, 1=full). Default: 1.0 percentile: Frequency percentile threshold. Default: 5.0 Returns: A hook function to pass to set_model_sampler_cfg_function """ def spectral_cfg_hook(args): cond = args["cond"] uncond = args["uncond"] cond_scale = args["cond_scale"] sigma = args["sigma"] x_orig = args["input"] # Reshape sigma for broadcasting sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) # Convert to v-pred space (from RescaleCFG reference) x = x_orig / (sigma * sigma + 1.0) cond_v = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) uncond_v = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) # Compute noise prediction noise_pred = cond_v - uncond_v # Apply Clybius spectral modulation to noise prediction noise_pred_modulated = apply_spectral_modulation_clybius(noise_pred, multiplier, percentile) # Compute CFG with modified noise prediction x_cfg = uncond_v + cond_scale * noise_pred_modulated # Convert back from v-pred space return x_orig - (x - x_cfg * sigma / (sigma * sigma + 1.0) ** 0.5) return spectral_cfg_hook def apply_combat_cfg_drift(latent, method='mean', intensity=1.0): """ Combat CFG Drift: Reduce mean drift from high CFG values. Based on ComfyUI-Latent-Modifiers. As CFG increases, the latent mean can drift away from 0, which causes color shifts and other artifacts. This technique reduces the drift proportionally based on intensity. Args: latent: The latent tensor to correct method: 'mean' or 'median'. Default: 'mean' intensity: How much drift to remove (0=none, 1=full). 
Default: 1.0 Returns: Drift-corrected latent """ if intensity <= 0: return latent try: if method == 'median': # Compute global median per batch (across all channels and spatial dims) center = latent.view(latent.shape[0], -1).median(dim=-1, keepdim=True)[0] center = center.view(latent.shape[0], 1, 1, 1) else: # Compute global mean per batch (across all channels and spatial dims) # This matches ComfyUI's PostCFGsubtractMeanNode implementation center = latent.mean(dim=(1, 2, 3), keepdim=True) # Remove drift proportionally based on intensity # intensity=1.0 removes all drift, intensity=0.5 removes half return latent - center * intensity except Exception as e: print(f"⚠️ Combat CFG drift failed: {e}") return latent def compute_phase_aware_cfg_scale(base_scale, progress, alpha=2.0, beta=2.0): """ Phase-Aware CFG Scaling: Adjust CFG scale based on sampling progress. Inspired by β-CFG (arXiv:2502.10574). CFG effectiveness varies by sampling phase: - Early: Lower CFG allows manifold exploration - Middle: Higher CFG for prompt adherence - Late: Lower CFG to stay on data manifold Args: base_scale: The user-specified CFG scale progress: Sampling progress (0.0 to 1.0) alpha: Beta distribution alpha parameter. Default: 2.0 beta: Beta distribution beta parameter. 
Default: 2.0 Returns: Adjusted CFG scale for the current step """ try: # Use a simple polynomial approximation of beta distribution # Beta(2,2) peaks at 0.5 with a smooth curve # f(x) = 6 * x * (1-x) for Beta(2,2), normalized to peak at 1 if alpha == 2.0 and beta == 2.0: # Simple case: symmetric peak at 0.5 scale_factor = 4.0 * progress * (1.0 - progress) # Peaks at 1.0 when progress=0.5 scale_factor = 0.7 + 0.6 * scale_factor # Range: 0.7 to 1.3 else: # General case: use polynomial approximation # Mode of Beta(a,b) is at (a-1)/(a+b-2) mode = (alpha - 1.0) / (alpha + beta - 2.0) if (alpha + beta) > 2 else 0.5 # Create a smooth curve that peaks at the mode dist_from_mode = abs(progress - mode) scale_factor = 1.0 - 0.3 * dist_from_mode * 2 # Simple linear falloff scale_factor = max(0.7, min(1.3, scale_factor)) return base_scale * scale_factor except Exception as e: print(f"⚠️ Phase-aware CFG scaling failed: {e}") return base_scale # apply_cfg_techniques() removed — was using legacy keys (akashic_combat_cfg_drift / # akashic_combat_drift_intensity) that no longer match the live CFG runtime, which # operates through configure_cfg_runtime() / adept_after_cfg_callback instead. # ============================================================================ # DUAL-MODE CFG RUNTIME (A1111 callbacks + optional native hook) # ============================================================================ def create_phase_aware_native_cfg_hook(base_hook=None, alpha=2.0, beta=2.0): """ Native CFG hook for Forge/reForge-like backends that support set_model_sampler_cfg_function(). Applies phase-aware CFG scaling, then optionally delegates to a downstream hook (e.g. spectral modulation). 
""" def hook(args): cond = args["cond"] uncond = args["uncond"] cond_scale = float(args["cond_scale"]) sigma = args["sigma"] x_orig = args["input"] total_steps = max(int(ADEPT_STATE.get("_cfg_total_steps", 1)), 1) step_idx = int(ADEPT_STATE.get("_cfg_step_idx", 0)) progress = min(max(step_idx / max(total_steps - 1, 1), 0.0), 1.0) phased_scale = compute_phase_aware_cfg_scale(cond_scale, progress, alpha=alpha, beta=beta) patched_args = dict(args) patched_args["cond_scale"] = phased_scale if base_hook is not None: return base_hook(patched_args) # Vanilla CFG combine with phased scale sigma_b = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) x = x_orig / (sigma_b * sigma_b + 1.0) cond_v = ((x - (x_orig - cond)) * (sigma_b ** 2 + 1.0) ** 0.5) / sigma_b uncond_v = ((x - (x_orig - uncond)) * (sigma_b ** 2 + 1.0) ** 0.5) / sigma_b x_cfg = uncond_v + phased_scale * (cond_v - uncond_v) return x_orig - (x - x_cfg * sigma_b / (sigma_b * sigma_b + 1.0) ** 0.5) return hook def create_combined_native_cfg_hook(): """ Build one composite native hook from whatever CFG features are enabled. Layer order: phase-aware scale → spectral modulation. Returns None if nothing is enabled (caller should clear the hook). """ base_hook = None if ADEPT_STATE.get("spectral_cfg_enabled", False): base_hook = create_spectral_modulation_cfg_hook( multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0), percentile=ADEPT_STATE.get("spectral_percentile", 5.0), ) if ADEPT_STATE.get("phase_cfg_enabled", False): return create_phase_aware_native_cfg_hook( base_hook=base_hook, alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0), beta=ADEPT_STATE.get("phase_cfg_beta", 2.0), ) return base_hook def adept_cfg_denoiser_callback(params): """ Official A1111 on_cfg_denoiser callback. Used only to track step counters for phase-aware progress bookkeeping; the public API here doesn't expose cond/uncond predictions so we can't do CFG math. 
""" ADEPT_STATE["_cfg_step_idx"] = int(getattr(params, "sampling_step", 0)) ADEPT_STATE["_cfg_total_steps"] = int(getattr(params, "total_sampling_steps", 1)) def adept_after_cfg_callback(params): """ Official A1111 on_cfg_after_cfg callback. Combat CFG Drift is the only technique that maps cleanly here, because AfterCFGCallbackParams only provides (x, sampling_step, total_sampling_steps) — no raw cond/uncond tensors. """ if not ADEPT_STATE.get("enabled", False): return if not ADEPT_STATE.get("cfg_drift_enabled", False): return try: params.x = apply_combat_cfg_drift( params.x, method=ADEPT_STATE.get("cfg_drift_method", "mean"), intensity=ADEPT_STATE.get("cfg_drift_intensity", 0.5), ) except Exception as e: print(f"⚠️ Adept post-CFG drift callback failed: {e}") def uninstall_a1111_cfg_callbacks(): global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB for cb in (_ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB): if cb is not None: try: script_callbacks.remove_callbacks_for_function(cb) except Exception: pass _ADEPT_CFG_AFTER_CB = None _ADEPT_CFG_DENOISER_CB = None def install_a1111_cfg_callbacks(): global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB uninstall_a1111_cfg_callbacks() _ADEPT_CFG_DENOISER_CB = adept_cfg_denoiser_callback _ADEPT_CFG_AFTER_CB = adept_after_cfg_callback script_callbacks.on_cfg_denoiser( _ADEPT_CFG_DENOISER_CB, name="adept_cfg_denoiser") script_callbacks.on_cfg_after_cfg( _ADEPT_CFG_AFTER_CB, name="adept_after_cfg") def _get_native_cfg_hook_target(): """ Locate a Forge/reForge-like model object that supports set_model_sampler_cfg_function(), if one exists. 
""" sd = getattr(shared, "sd_model", None) candidates = [ sd, getattr(sd, "forge_objects", None), getattr(sd, "model", None), getattr(getattr(sd, "model", None), "model", None) if sd else None, ] for obj in candidates: if obj is not None and hasattr(obj, "set_model_sampler_cfg_function"): return obj return None def uninstall_native_cfg_hook(): global _ADEPT_NATIVE_CFG_HOOK_ACTIVE target = _get_native_cfg_hook_target() if target is not None: try: target.set_model_sampler_cfg_function(None) except Exception: pass _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False def install_native_cfg_hook(): global _ADEPT_NATIVE_CFG_HOOK_ACTIVE target = _get_native_cfg_hook_target() if target is None: _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False return False hook = create_combined_native_cfg_hook() try: target.set_model_sampler_cfg_function(hook) # None clears it if nothing enabled _ADEPT_NATIVE_CFG_HOOK_ACTIVE = (hook is not None) return True except Exception as e: print(f"⚠️ Adept native CFG hook install failed: {e}") _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False return False # ============================================================================ # CFGDenoiser MONKEY-PATCH (stock A1111 fallback for spectral/phase CFG) # ============================================================================ def _adept_cfg_progress_from_denoiser(denoiser): """Compute sampling progress [0,1] from CFGDenoiser step counters.""" total_steps = max(int(getattr(denoiser, "total_steps", 1) or 1), 1) step_idx = int(getattr(denoiser, "step", 0)) return min(max(step_idx / max(total_steps - 1, 1), 0.0), 1.0) def _adept_nativeish_cfg_term(x_i, sigma_i, cond_i, uncond_i, scale): """ Approximate native hook behavior for one cond/uncond pair. Feeds x/sigma/cond/uncond into the same composite hook builder used in native-hook mode so stock A1111 gets as close as possible to Forge parity. Falls back to plain weighted delta if hook is unavailable or errors. 
""" if abs(float(scale)) < 1e-12: return torch.zeros_like(uncond_i) hook = create_combined_native_cfg_hook() if hook is None: return (cond_i - uncond_i) * float(scale) try: combined = hook({ "cond": cond_i, "uncond": uncond_i, "cond_scale": float(scale), "sigma": sigma_i, "input": x_i, }) return combined - uncond_i except Exception as e: print(f"⚠️ Adept native-ish CFG term fallback: {e}") return (cond_i - uncond_i) * float(scale) def patch_cfg_denoiser(): """ Stock A1111 fallback for Spectral Modulation + Phase-Aware CFG. Strategy: 1. Thin forward() wrapper that stashes x / sigma on the instance so the combine methods can reach them — original forward logic is untouched. 2. Patched combine_denoised() uses those values to call the same composite hook builder as native-hook mode, giving near-parity behaviour. 3. Patched combine_denoised_for_edit_model() does the same for pix2pix. This is intentionally safer than a full forward() rewrite: upstream A1111 changes to refiner/masking/skip-uncond logic remain unaffected. 
""" global _CFGD_ORIG_COMBINE, _CFGD_ORIG_COMBINE_EDIT, _CFGD_ORIG_FORWARD, _CFGD_MONKEYPATCH_ACTIVE try: from modules import sd_samplers_cfg_denoiser as sd_cfg except Exception as e: print(f"⚠️ Adept CFGDenoiser patch import failed: {e}") _CFGD_MONKEYPATCH_ACTIVE = False return False cls = sd_cfg.CFGDenoiser if _CFGD_ORIG_COMBINE is None: _CFGD_ORIG_COMBINE = cls.combine_denoised if _CFGD_ORIG_COMBINE_EDIT is None: _CFGD_ORIG_COMBINE_EDIT = cls.combine_denoised_for_edit_model if _CFGD_ORIG_FORWARD is None: _CFGD_ORIG_FORWARD = cls.forward # --- forward wrapper: stash x/sigma, then run original --- def adept_forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): setattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, { "x": x, "sigma": sigma, "uncond": uncond, "cond": cond, "cond_scale": float(cond_scale), }) try: return _CFGD_ORIG_FORWARD(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond) finally: try: delattr(self, _ADEPT_CFGDENOISER_CTX_ATTR) except AttributeError: pass # --- combine_denoised: per-cond native-ish or plain path --- def adept_combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) ctx = getattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, None) x_ctx = None if ctx is None else ctx.get("x", None) sigma_ctx = None if ctx is None else ctx.get("sigma", None) progress = _adept_cfg_progress_from_denoiser(self) eff_scale = float(cond_scale) if ADEPT_STATE.get("phase_cfg_enabled", False): eff_scale = compute_phase_aware_cfg_scale( eff_scale, progress, alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0), beta =ADEPT_STATE.get("phase_cfg_beta", 2.0), ) use_nativeish = ( (ADEPT_STATE.get("spectral_cfg_enabled", False) or ADEPT_STATE.get("phase_cfg_enabled", False)) and x_ctx is not None and sigma_ctx is not None ) for i, conds in enumerate(conds_list): for cond_index, weight in conds: cond_i = x_out[cond_index:cond_index + 1] uncond_i = denoised_uncond[i:i + 
1] if use_nativeish: term = _adept_nativeish_cfg_term( x_i = x_ctx[i:i + 1], sigma_i = sigma_ctx[i:i + 1], cond_i = cond_i, uncond_i = uncond_i, scale = float(weight) * eff_scale, ) denoised[i:i + 1] += term else: delta = cond_i - uncond_i if ADEPT_STATE.get("spectral_cfg_enabled", False): delta = apply_spectral_modulation_clybius( delta, multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0), percentile=ADEPT_STATE.get("spectral_percentile", 5.0), ) denoised[i:i + 1] += delta * (float(weight) * eff_scale) return denoised # --- combine_denoised_for_edit_model: pix2pix / instruct path --- def adept_combine_denoised_for_edit_model(self, x_out, cond_scale): out_cond, out_img_cond, out_uncond = x_out.chunk(3) ctx = getattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, None) x_ctx = None if ctx is None else ctx.get("x", None) sigma_ctx = None if ctx is None else ctx.get("sigma", None) progress = _adept_cfg_progress_from_denoiser(self) eff_scale = float(cond_scale) if ADEPT_STATE.get("phase_cfg_enabled", False): eff_scale = compute_phase_aware_cfg_scale( eff_scale, progress, alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0), beta =ADEPT_STATE.get("phase_cfg_beta", 2.0), ) # Native-ish path when context is available if (ADEPT_STATE.get("spectral_cfg_enabled", False) or ADEPT_STATE.get("phase_cfg_enabled", False)): if x_ctx is not None and sigma_ctx is not None: try: hook = create_combined_native_cfg_hook() if hook is not None: base = hook({ "cond": out_cond, "uncond": out_img_cond, "cond_scale": eff_scale, "sigma": sigma_ctx, "input": x_ctx, }) return base + self.image_cfg_scale * (out_img_cond - out_uncond) except Exception as e: print(f"⚠️ Adept edit-model native-ish fallback: {e}") # Plain path (no context or hook failed) delta = out_cond - out_img_cond if ADEPT_STATE.get("spectral_cfg_enabled", False): delta = apply_spectral_modulation_clybius( delta, multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0), percentile=ADEPT_STATE.get("spectral_percentile", 5.0), ) return out_uncond + 
eff_scale * delta + self.image_cfg_scale * (out_img_cond - out_uncond) try: cls.forward = adept_forward cls.combine_denoised = adept_combine_denoised cls.combine_denoised_for_edit_model = adept_combine_denoised_for_edit_model _CFGD_MONKEYPATCH_ACTIVE = True return True except Exception as e: print(f"⚠️ Adept CFGDenoiser patch failed: {e}") _CFGD_MONKEYPATCH_ACTIVE = False return False def unpatch_cfg_denoiser(): global _CFGD_MONKEYPATCH_ACTIVE try: from modules import sd_samplers_cfg_denoiser as sd_cfg except Exception: _CFGD_MONKEYPATCH_ACTIVE = False return False cls = sd_cfg.CFGDenoiser try: if _CFGD_ORIG_FORWARD is not None: cls.forward = _CFGD_ORIG_FORWARD if _CFGD_ORIG_COMBINE is not None: cls.combine_denoised = _CFGD_ORIG_COMBINE if _CFGD_ORIG_COMBINE_EDIT is not None: cls.combine_denoised_for_edit_model = _CFGD_ORIG_COMBINE_EDIT _CFGD_MONKEYPATCH_ACTIVE = False return True except Exception as e: print(f"⚠️ Adept CFGDenoiser unpatch failed: {e}") _CFGD_MONKEYPATCH_ACTIVE = False return False def configure_cfg_runtime(): """ Select and activate the right CFG runtime mode: off – nothing enabled; all hooks/callbacks/patches cleared a1111-postcfg – stock A1111; Combat CFG Drift only via official callback a1111-monkeypatch – stock A1111; Spectral + Phase-Aware via CFGDenoiser patch native-hook – Forge/reForge-like backend; all three via sampler CFG hook Returns the mode string so process() can log it. """ # If the extension is globally disabled, always tear down and return off. 
if not ADEPT_STATE.get("enabled", False): uninstall_a1111_cfg_callbacks() uninstall_native_cfg_hook() unpatch_cfg_denoiser() ADEPT_STATE["cfg_runtime_mode"] = "off" return "off" drift = ADEPT_STATE.get("cfg_drift_enabled", False) spectral = ADEPT_STATE.get("spectral_cfg_enabled", False) phase = ADEPT_STATE.get("phase_cfg_enabled", False) # Always tear down everything first for a clean slate uninstall_a1111_cfg_callbacks() uninstall_native_cfg_hook() unpatch_cfg_denoiser() if not (drift or spectral or phase): ADEPT_STATE["cfg_runtime_mode"] = "off" return "off" # Prefer native hook if backend supports it native_target = _get_native_cfg_hook_target() if native_target is not None: install_a1111_cfg_callbacks() # keeps drift working in native mode too install_native_cfg_hook() ADEPT_STATE["cfg_runtime_mode"] = "native-hook" return "native-hook" # Stock A1111: always install callbacks (drift) install_a1111_cfg_callbacks() if spectral or phase: ok = patch_cfg_denoiser() if ok: ADEPT_STATE["cfg_runtime_mode"] = "a1111-monkeypatch" print("✅ Adept: stock A1111 CFGDenoiser monkey-patch active (spectral/phase + drift enabled)") return "a1111-monkeypatch" print("⚠️ Adept: CFGDenoiser monkey-patch failed; falling back to post-CFG drift only") ADEPT_STATE["cfg_runtime_mode"] = "a1111-postcfg" return "a1111-postcfg" def sa_solver_step(x, d_history, sigma, sigma_next, tau, s_noise=1.0, noise_sampler=None, order=2, ndb_strength=0.0, progress=0.0, eqvae_mode=False, eqvae_blur_sigma=None): """SA-Solver step - CRITICAL for Akashic Solver.""" dt = sigma_next - sigma if len(d_history) >= 2 and order >= 2: sigma_cur, d_cur = d_history[-1] sigma_prev, d_prev = d_history[-2] h_prev = sigma_cur - sigma_prev r = abs(dt / (h_prev + 1e-8)) if abs(h_prev) > 1e-8 else 1.0 r = min(r, 2.0) if len(d_history) >= 3 and order >= 3: sigma_0, d_0 = d_history[-3] h_0 = sigma_prev - sigma_0 h_1 = h_prev if abs(h_0) > 1e-6 and abs(h_1) > 1e-6: r0 = min(abs(h_1 / h_0), 2.0) r1 = min(abs(dt / (h_1 + 1e-8)), 
2.0) tau_blend = 1.0 - tau c0_ab3 = 1.0 + (1.0 + r0) * r1 / 2.0 c1_ab3 = -(1.0 + r0) * r1 / 2.0 c2_ab3 = r0 * r1 / 2.0 c0 = tau_blend * c0_ab3 + (1.0 - tau_blend) * 1.0 c1 = tau_blend * c1_ab3 c2 = tau_blend * c2_ab3 c_sum = c0 + c1 + c2 if abs(c_sum) > 1e-8: c0 /= c_sum c1 /= c_sum c2 /= c_sum else: c0, c1, c2 = 1.0, 0.0, 0.0 d_interp = c0 * d_cur + c1 * d_prev + c2 * d_0 else: tau_blend = 1.0 - tau c1_ab2 = 1.0 + 0.5 * r c2_ab2 = -0.5 * r c1 = tau_blend * c1_ab2 + (1.0 - tau_blend) * 1.0 c2 = tau_blend * c2_ab2 c_sum = c1 + c2 if abs(c_sum) > 1e-8: c1 /= c_sum c2 /= c_sum d_interp = c1 * d_cur + c2 * d_prev else: tau_blend = 1.0 - tau c1_ab2 = 1.0 + 0.5 * r c2_ab2 = -0.5 * r c1 = tau_blend * c1_ab2 + (1.0 - tau_blend) * 1.0 c2 = tau_blend * c2_ab2 c_sum = c1 + c2 if abs(c_sum) > 1e-8: c1 /= c_sum c2 /= c_sum d_interp = c1 * d_cur + c2 * d_prev elif len(d_history) >= 1: d_interp = d_history[-1][1] else: d_interp = torch.zeros_like(x) # Compute sigma_up based on tau (controls stochasticity) sigma_up = 0.0 if tau > 0 and sigma_next > 0 and noise_sampler is not None: sigma_ancestral_sq = sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / (sigma ** 2 + 1e-8) sigma_ancestral = sigma_ancestral_sq ** 0.5 if sigma_ancestral_sq > 0 else 0.0 sigma_up = tau * sigma_ancestral sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5 dt_adjusted = sigma_down - sigma x_det = x + d_interp * dt_adjusted noise = noise_sampler(sigma, sigma_next) * s_noise * sigma_up # Apply Native Detail Boost if enabled if ndb_strength > 0 and TORCHVISION_AVAILABLE: # Use EQ-VAE optimized NDB parameters if in EQ-VAE mode if eqvae_mode: blur_sigma, high_freq_boost = compute_eqvae_ndb(progress, ndb_strength) else: _, high_freq_boost = compute_native_detail_boost(progress, ndb_strength) blur_sigma = 0.5 # Default blur sigma # Override blur_sigma if explicitly provided if eqvae_blur_sigma is not None: blur_sigma = eqvae_blur_sigma # Extract high-frequency component from noise using Gaussian blur try: 
low_freq_noise = gaussian_blur(noise, kernel_size=3, sigma=blur_sigma) high_freq_noise = noise - low_freq_noise noise = noise + high_freq_noise * high_freq_boost except Exception: pass # Fallback: use original noise if blur fails x_next = x_det + noise else: x_next = x + d_interp * dt return x_next, sigma_up def create_detail_enhanced_model(model, x, sigmas, settings): # NOTE: Detail Enhancement is currently an internal/experimental path. # It is not wired into the UI and callers always pass # use_detail_enhancement=False, so this function is never invoked at # runtime. Kept for future re-integration; do not rely on it. """Detail enhancement wrapper.""" if not TORCHVISION_AVAILABLE: return model base_strength = settings.get('detail_enhancement_strength', 0.05) radius = settings.get('detail_separation_radius', 0.5) total_steps = len(sigmas) - 1 class DetailEnhancer: def __init__(self): self.current_step = 0 def __call__(self, x_current, sigma, **kwargs): denoised = model(x_current, sigma, **kwargs) try: low_freq = gaussian_blur(denoised, kernel_size=3, sigma=radius) high_freq = denoised - low_freq progress = min(self.current_step / max(total_steps, 1), 1.0) strength = base_strength * (0.5 + progress) enhanced = denoised + high_freq * strength self.current_step += 1 return enhanced except Exception: return denoised return DetailEnhancer() # ============================================================================ # ============================================================================ # ============================================================================ # CUSTOM ADVANCED SAMPLERS (Complete port from ComfyUI) # ============================================================================ @torch.no_grad() def sample_adept_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, order=2, use_corrector=True, use_detail_enhancement=False, settings=None): """ Adept Solver: A unified training-free diffusion solver synthesizing improvements 
from: - DPM-Solver++ (data prediction, dynamic thresholding) - UniPC (unified predictor-corrector framework) - DEIS (exponential integrator) - DC-Solver (dynamic compensation) """ extra_args = {} if extra_args is None else extra_args settings = settings or {} s_in = x.new_ones([x.shape[0]]) order = max(1, min(order, 3)) print(f"🚀 Adept Solver active (Order: {order}, Corrector: {'On' if use_corrector else 'Off'})") active_model = model # use_detail_enhancement is always False from current call-sites; # the block below is preserved for future re-integration but is not # currently reachable via the UI. if use_detail_enhancement and TORCHVISION_AVAILABLE: active_model = create_detail_enhanced_model(model, x, sigmas, settings) model_outputs = [] for i in range(len(sigmas) - 1): sigma = sigmas[i] sigma_next = sigmas[i + 1] denoised = active_model(x, sigma * s_in, **extra_args) if extra_args.get('cond_scale', 1.0) > 7.0: denoised = apply_dynamic_thresholding(denoised, percentile=0.995) d = to_d(x, sigma, denoised) derivative_max = torch.abs(d).max() sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0) if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold: print(f"⚠️ Extreme derivative detected at step {i}/{len(sigmas)-1}. 
Clamping for stability.") d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold) if torch.isnan(d).any() or torch.isinf(d).any(): d = torch.zeros_like(d) model_outputs.append((sigma, d)) if len(model_outputs) > order: model_outputs.pop(0) dt = sigma_next - sigma if len(model_outputs) == 1 or order == 1: x_pred = x + d * dt elif len(model_outputs) == 2 and order >= 2: sigma_prev, d_prev = model_outputs[-2] d_cur = model_outputs[-1][1] h = sigma - sigma_prev compensation_ratio = compute_compensation_ratio(h.item() if torch.is_tensor(h) else float(h), i, len(sigmas)) d_interp = d_cur + compensation_ratio * (d_cur - d_prev) x_pred = x + d_interp * dt else: sigma_0, d_0 = model_outputs[-3] sigma_1, d_1 = model_outputs[-2] sigma_2, d_2 = model_outputs[-1] h_0 = sigma_2 - sigma_1 h_1 = sigma_1 - sigma_0 h_0_val = h_0.item() if torch.is_tensor(h_0) else float(h_0) h_1_val = h_1.item() if torch.is_tensor(h_1) else float(h_1) if abs(h_1_val) < 1e-6: compensation_ratio = compute_compensation_ratio(h_0_val, i, len(sigmas)) d_interp = d_2 + compensation_ratio * (d_2 - d_1) else: r0 = h_0_val / h_1_val c0 = 1.0 + r0 / 2.0 c1 = -r0 / 2.0 c2 = 0.0 c_sum = c0 + c1 + c2 c0 /= c_sum c1 /= c_sum c2 = 1.0 - c0 - c1 d_interp = c0 * d_2 + c1 * d_1 + c2 * d_0 x_pred = x + d_interp * dt if use_corrector and i < len(sigmas) - 2: denoised_pred = active_model(x_pred, sigma_next * s_in, **extra_args) if extra_args.get('cond_scale', 1.0) > 7.0: denoised_pred = apply_dynamic_thresholding(denoised_pred, percentile=0.995) d_pred = to_d(x_pred, sigma_next, denoised_pred) if torch.isnan(d_pred).any() or torch.isinf(d_pred).any() or torch.abs(d_pred).max() > 1000.0: d_pred = torch.clamp(d_pred, -100.0, 100.0) if torch.isnan(d_pred).any() or torch.isinf(d_pred).any(): d_pred = torch.zeros_like(d_pred) dt = sigma_next - sigma x = x + (d + d_pred) * dt * 0.5 else: x = x_pred if torch.isnan(x).any() or torch.isinf(x).any(): print(f"❌ CRITICAL: NaN/Inf detected at step 
{i}/{len(sigmas)-1}!") if i == 0: raise RuntimeError("NaN/Inf on first step - check model/inputs") print(" Attempting recovery with conservative Euler step...") denoised_safe = active_model(x, sigma * s_in, **extra_args) if torch.isnan(denoised_safe).any(): raise RuntimeError("Model producing NaN - check CFG scale and model") d_safe = to_d(x, sigma, denoised_safe) dt_safe = (sigma_next - sigma) * 0.5 x = x + d_safe * dt_safe use_corrector = False print(" Recovery successful. Corrector disabled for stability.") if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x @torch.no_grad() def sample_adept_ancestral_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, adaptive_eta=False, phase_noise=False, phase_strength=0.5, enhanced_derivative=False, use_detail_enhancement=False, settings=None): """ Enhanced Adept Ancestral Solver: Advanced ancestral sampling with phase-aware adaptations. Key innovations: 1. Adaptive ancestral step sizing that changes throughout sampling phases 2. Phase-aware noise injection (more noise early, less noise late) 3. Enhanced derivative computation with ancestral-specific corrections 4. Dynamic eta scheduling for better control """ extra_args = {} if extra_args is None else extra_args settings = settings or {} s_in = x.new_ones([x.shape[0]]) print(f"🚀 Enhanced Adept Ancestral Solver active (η: {eta:.2f}, s_noise: {s_noise:.2f})") print(f" Adaptive Eta: {adaptive_eta}, Phase Noise: {phase_noise}, Enhanced Derivative: {enhanced_derivative}") active_model = model # use_detail_enhancement is always False from current call-sites. 
if use_detail_enhancement and TORCHVISION_AVAILABLE: active_model = create_detail_enhanced_model(model, x, sigmas, settings) noise_sampler = get_noise_sampler(x) for i in range(len(sigmas) - 1): sigma = sigmas[i] sigma_next = sigmas[i + 1] progress = i / max(len(sigmas) - 1, 1) if adaptive_eta: if progress < 0.3: current_eta = eta * 1.08 elif progress < 0.7: current_eta = eta * 0.95 else: current_eta = eta * 1.02 else: current_eta = eta denoised = active_model(x, sigma * s_in, **extra_args) if extra_args.get('cond_scale', 1.0) > 7.0: denoised = apply_dynamic_thresholding(denoised, percentile=0.995) if enhanced_derivative: d = to_d_enhanced_ancestral(x, sigma, denoised, current_eta, progress, None) else: d = to_d(x, sigma, denoised) derivative_max = torch.abs(d).max() sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0) if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold: d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold) if torch.isnan(d).any() or torch.isinf(d).any(): d = torch.zeros_like(d) if sigma_next > 0: sigma_up = min(sigma_next, current_eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5) sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5 else: sigma_up = 0.0 sigma_down = 0.0 dt = sigma_down - sigma x_pred = x + d * dt if sigma_next > 0: if phase_noise: if progress < 0.25: target_multiplier = 1.0 + (0.05 * min(progress / 0.25, 1.0)) elif progress < 0.6: target_multiplier = 1.0 - (0.02 * min((progress - 0.25) / 0.35, 1.0)) else: target_multiplier = 1.0 - (0.05 * min((progress - 0.6) / 0.4, 1.0)) noise_multiplier = 1.0 + (target_multiplier - 1.0) * phase_strength adaptive_s_noise = s_noise * noise_multiplier else: adaptive_s_noise = s_noise noise = noise_sampler(sigma, sigma_next) * adaptive_s_noise * sigma_up x = x_pred + noise else: x = x_pred if torch.isnan(x).any() or torch.isinf(x).any(): print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!") if i 
== 0: raise RuntimeError("NaN/Inf on first step - check model/inputs") print(" Attempting recovery...") denoised_safe = active_model(x, sigma * s_in, **extra_args) if torch.isnan(denoised_safe).any(): raise RuntimeError("Model producing NaN - check CFG scale and model") d_safe = to_d(x, sigma, denoised_safe) dt_safe = (sigma_next - sigma) * 0.5 x = x + d_safe * dt_safe print(" Recovery successful.") if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x @torch.no_grad() def sample_mirror_correction_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, correction_phase=0.5, smooth_phase=False): """ Mirror Correction Euler: Euler Ancestral with a semantic reflection probe. In the first `correction_phase` fraction of steps, uses a 3-call Heun correction: x_probe = 2*D(x) - x (reflection of x through its own denoised prediction) The probe lies on the denoising trajectory, giving a curvature estimate for the Heun correction. Remaining steps: standard 1-call Euler Ancestral. Args: eta: Ancestral noise coefficient. 0=deterministic, 1=full ancestral. Default: 1.0 s_noise: Noise scale multiplier. Default: 1.0 correction_phase: Fraction of steps that receive the 3-call correction. Default: 0.5 smooth_phase: Use continuous log-sigma weighting instead of a binary cutoff. 
Default: False """ extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) print(f"🔮 Mirror Correction Euler active (η: {eta:.2f}, s_noise: {s_noise:.2f})") print(f" Correction Phase: {correction_phase:.2f}, Smooth Phase: {smooth_phase}") noise_sampler = get_noise_sampler(x) n_steps = len(sigmas) - 1 log_sigma_phase = None log_sigma_max = None smooth_denom = 1e-6 if smooth_phase and n_steps > 0: sigma_max_val = sigmas[0].clamp(min=1e-6) phase_idx = min(int(correction_phase * n_steps), n_steps - 1) sigma_phase_val = sigmas[phase_idx].clamp(min=1e-6) log_sigma_max = torch.log(sigma_max_val).item() log_sigma_phase = torch.log(sigma_phase_val).item() smooth_denom = max(log_sigma_max - log_sigma_phase, 1e-6) for i in range(n_steps): sigma = sigmas[i] sigma_next = sigmas[i + 1] progress = i / max(n_steps - 1, 1) denoised = model(x, sigma * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) d = to_d(x, sigma, denoised) if sigma_next > 0: sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5) sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5 else: sigma_up = 0.0 sigma_down = 0.0 dt = sigma_down - sigma if smooth_phase and log_sigma_phase is not None: log_sig = torch.log(sigma.clamp(min=1e-6)).item() t = max(0.0, min(1.0, (log_sig - log_sigma_phase) / smooth_denom)) correction_weight = t ** 0.5 if correction_weight > 1e-3 and sigma_next > 0: x_probe = 2 * denoised - x d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args)) d_diff_norm = (d - d_probe).norm() d_scale = (d.norm() + d_probe.norm()) / 2 + 1e-6 gradient_agreement = max(0.0, 1.0 - (d_diff_norm / d_scale).item()) effective_weight = correction_weight * gradient_agreement if effective_weight > 1e-3: x3 = x + ((d + d_probe) / 2) * dt d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args)) d_heun = (d + d3) / 2 if not (torch.isnan(d_heun).any() or 
torch.isinf(d_heun).any()): d = d + effective_weight * (d_heun - d) else: if progress < correction_phase and sigma_next > 0: x_probe = 2 * denoised - x d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args)) x3 = x + ((d + d_probe) / 2) * dt d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args)) d = (d + d3) / 2 if torch.isnan(d).any() or torch.isinf(d).any(): d = torch.zeros_like(d) x = x + d * dt if sigma_next > 0: x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up return x @torch.no_grad() def sample_akashic_solver(model, x, sigmas, extra_args=None, callback=None, disable=None, tau=0.5, eta=1.0, s_noise=1.0, adaptive_eta=True, phase_strength=0.5, order=2, smea_strength=0.0, ndb_strength=0.0, use_detail_enhancement=False, settings=None, eqvae_mode='Off'): """ AkashicSolver v2 [EXPERIMENTAL]: Advanced sampler optimized for EQ-VAE models. Combines: 1. SA-SOLVER BASE: Multi-step Adams-Bashforth integration with tau function 2. PHASE-AWARE SAMPLING: Three-phase approach with adaptive parameters 3. 
SMEA COHERENCY: Sine-based interpolation for high-resolution coherency Args: eqvae_mode: EQ-VAE optimization mode ('Off' or 'Balanced') """ extra_args = {} if extra_args is None else extra_args settings = settings or {} s_in = x.new_ones([x.shape[0]]) if isinstance(eqvae_mode, bool): eqvae_enabled = eqvae_mode else: eqvae_enabled = eqvae_mode == 'Balanced' if eqvae_enabled: print(f"🌀 AkashicSolver v2 [EQ-VAE BALANCED] active") print(f" Optimized for EQ-VAE's cleaner latent space") else: print(f"🌀 AkashicSolver v2 [EXPERIMENTAL] active") print(f" τ (tau): {tau:.2f}, η (eta): {eta:.2f}, s_noise: {s_noise:.2f}") print(f" Order: {order}, Adaptive Eta: {adaptive_eta}, Phase Strength: {phase_strength:.2f}") if smea_strength > 0: print(f" SMEA: {smea_strength:.2f} (high-res coherency)") if ndb_strength > 0: print(f" Native Detail Boost: {ndb_strength:.2f} (detail enhancement)") if not eqvae_enabled: print(f" ⚠️ Use external rescaleCFG (e.g., 0.7) for EQ-VAE models") active_model = model # use_detail_enhancement is always False from current call-sites. 
if use_detail_enhancement and TORCHVISION_AVAILABLE: active_model = create_detail_enhanced_model(model, x, sigmas, settings) noise_sampler = get_noise_sampler(x) total_steps = len(sigmas) - 1 d_history = [] for i in range(total_steps): sigma = sigmas[i] sigma_next = sigmas[i + 1] progress = i / max(total_steps - 1, 1) if adaptive_eta: if eqvae_enabled: current_tau = compute_eqvae_tau(progress, tau, phase_strength) else: current_tau = compute_tau_eqvae(progress, tau, phase_strength) else: current_tau = tau if adaptive_eta: if eqvae_enabled: if progress < 0.25: current_eta = eta * (1.0 + 0.03 * phase_strength) elif progress < 0.55: current_eta = eta * (1.0 - 0.03 * phase_strength) else: current_eta = eta * (1.0 + 0.02 * phase_strength) else: if progress < 0.30: current_eta = eta * (1.0 + 0.08 * phase_strength) elif progress < 0.60: current_eta = eta * (1.0 - 0.05 * phase_strength) else: current_eta = eta * (1.0 + 0.02 * phase_strength) else: current_eta = eta smea_factor = compute_smea_factor(progress, smea_strength) denoised = active_model(x, sigma * s_in, **extra_args) cfg_scale = extra_args.get('cond_scale', 1.0) if cfg_scale > 7.0: denoised = apply_dynamic_thresholding(denoised, percentile=0.995) d = to_d(x, sigma, denoised) derivative_max = torch.abs(d).max() sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0) if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold: d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold) if torch.isnan(d).any() or torch.isinf(d).any(): d = torch.zeros_like(d) d_history.append((sigma, d)) if len(d_history) > order: d_history.pop(0) effective_tau = current_tau if eqvae_enabled: effective_s_noise = compute_eqvae_noise_scale(s_noise * current_eta, progress) * smea_factor else: effective_s_noise = s_noise * current_eta * smea_factor if progress < 0.30: noise_multiplier = 1.0 + 0.03 * phase_strength elif progress < 0.60: noise_multiplier = 1.0 - 0.01 * phase_strength else: 
noise_multiplier = 1.0 - 0.02 * phase_strength effective_s_noise *= noise_multiplier if eqvae_enabled and ndb_strength > 0: eqvae_blur_sigma, _ = compute_eqvae_ndb(progress, ndb_strength) else: eqvae_blur_sigma = None x, sigma_up = sa_solver_step( x=x, d_history=d_history, sigma=sigma, sigma_next=sigma_next, tau=effective_tau, s_noise=effective_s_noise, noise_sampler=noise_sampler, order=order, ndb_strength=ndb_strength, progress=progress, eqvae_mode=eqvae_enabled, eqvae_blur_sigma=eqvae_blur_sigma ) if torch.isnan(x).any() or torch.isinf(x).any(): print(f"❌ AkashicSolver v2: NaN/Inf detected at step {i}/{total_steps}!") if i == 0: raise RuntimeError("NaN/Inf on first step - check model/inputs") print(" Attempting recovery...") denoised_safe = active_model(x, sigma * s_in, **extra_args) if torch.isnan(denoised_safe).any(): raise RuntimeError("Model producing NaN - reduce CFG scale or check model") d_safe = to_d(x, sigma, denoised_safe) dt_safe = (sigma_next - sigma) * 0.5 x = x + d_safe * dt_safe d_history.clear() print(" Recovery successful. 
Multi-step history cleared.") if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x # ============================================================================ # ============================================================================ # ============================================================================ # This file continues from PART1 + PART1B # ============================================================================ # ============================================================================ # WEIGHT PATCHER (from v3 - unchanged) # ============================================================================ def should_patch_weights(unet_model, scale, shift): """Return True if weight patching is actually needed for these parameters.""" return ( unet_model is not None and (abs(scale - 1.0) > 1e-6 or abs(shift) > 1e-6) ) class AdeptWeightPatcher: """Temporary weight scaling for UNet.""" def __init__(self, unet_model, scale=1.0, shift=0.0): self.unet_model = unet_model self.scale = scale self.shift = shift self.backups = {} self.target_layers = [] def __enter__(self): if self.unet_model is None or (abs(self.scale - 1.0) < 1e-6 and abs(self.shift) < 1e-6): return self self.target_layers.clear() self.backups.clear() try: for name, module in self.unet_model.named_modules(): if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): if hasattr(module, 'weight') and module.weight is not None: self.target_layers.append((name, module)) self.backups[name] = module.weight.data.clone() module.weight.data = module.weight.data * self.scale + self.shift except Exception as e: print(f"❌ Weight patcher failed: {e}") self.__exit__(None, None, None) return self def __exit__(self, exc_type, exc_val, exc_tb): try: for name, module in self.target_layers: if name in self.backups: module.weight.data.copy_(self.backups[name]) self.backups.clear() self.target_layers.clear() except Exception as e: 
print(f"❌ CRITICAL: Failed to restore weights: {e}") for name, backup_data in self.backups.items(): try: for n, m in self.target_layers: if n == name: m.weight.data.copy_(backup_data) except: pass return False # ============================================================================ # VAE REFLECTION PATCHER (from v3 - unchanged) # ============================================================================ class VAEReflectionPatcher: """Context manager for VAE reflection padding.""" def __init__(self, vae_model): self.vae_model = vae_model self.backups = {} def __enter__(self): global _vae_reflection_active, _vae_original_padding_modes if _vae_reflection_active or self.vae_model is None: return self _vae_original_padding_modes.clear() patched_count = 0 try: for name, module in self.vae_model.named_modules(): if isinstance(module, torch.nn.Conv2d): _vae_original_padding_modes[name] = module.padding_mode module.padding_mode = 'reflect' patched_count += 1 _vae_reflection_active = True print(f"🪞 VAE Reflection: Patched {patched_count} Conv2d layers") except Exception as e: print(f"❌ VAE Reflection failed: {e}") self.__exit__(None, None, None) return self def __exit__(self, exc_type, exc_val, exc_tb): global _vae_reflection_active, _vae_original_padding_modes if self.vae_model is None: _vae_reflection_active = False _vae_original_padding_modes.clear() return False restored_count = 0 try: for name, module in self.vae_model.named_modules(): if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes: module.padding_mode = _vae_original_padding_modes[name] restored_count += 1 _vae_reflection_active = False _vae_original_padding_modes.clear() print(f"🔄 VAE Reflection: Restored {restored_count} layers") except Exception as e: print(f"⚠️ VAE Reflection restore warning: {e}") return False def force_restore_vae_reflection(): """ Emergency / unload-path restore for VAE padding modes. 
Safe to call at any time — does nothing if VAE reflection was not active. """ global _vae_reflection_active, _vae_original_padding_modes if not _vae_reflection_active and not _vae_original_padding_modes: return try: sd = getattr(shared, "sd_model", None) vae = getattr(sd, "first_stage_model", None) if sd else None if vae is not None: restored = 0 for name, module in vae.named_modules(): if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes: module.padding_mode = _vae_original_padding_modes[name] restored += 1 if restored: print(f"🔄 VAE Reflection: force-restored {restored} layers") except Exception as e: print(f"⚠️ VAE Reflection force-restore warning: {e}") finally: _vae_reflection_active = False _vae_original_padding_modes.clear() # ALL SCHEDULERS (18 types) # ============================================================================ def create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """AOS-V (Anime-Optimized Schedule for v-prediction models).""" rho = 7.0 p1_steps = int(num_steps * 0.2) p2_steps = int(num_steps * 0.6) ramp = torch.empty(num_steps, device=device, dtype=torch.float32) if p1_steps > 0: torch.linspace(0, 1, p1_steps, out=ramp[:p1_steps]) ramp[:p1_steps].pow_(0.5).mul_(0.6) if p2_steps > p1_steps: torch.linspace(0.6, 0.9, p2_steps - p1_steps, out=ramp[p1_steps:p2_steps]) if num_steps > p2_steps: torch.linspace(0, 1, num_steps - p2_steps, out=ramp[p2_steps:]) ramp[p2_steps:].pow_(3).mul_(0.1).add_(0.9) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) ramp.mul_(min_inv_rho - max_inv_rho).add_(max_inv_rho).pow_(rho) return torch.cat([ramp, torch.zeros(1, device=device)]) def create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """AOS-ε (Anime-Optimized Schedule for epsilon-prediction models).""" rho = 7.0 p1_frac, p2_frac = 0.35, 0.7 ramp_p1_val, ramp_p2_val = 0.4, 0.75 p1_steps = int(num_steps * p1_frac) p2_steps = int(num_steps * p2_frac) phase1_ramp = 
torch.linspace(0, 1, p1_steps, device=device) ** 1.5 * ramp_p1_val phase2_ramp = torch.linspace(ramp_p1_val, ramp_p2_val, p2_steps - p1_steps, device=device) phase3_base = torch.linspace(0, 1, num_steps - p2_steps, device=device) ** 0.7 phase3_ramp = phase3_base * (1 - ramp_p2_val) + ramp_p2_val if p1_steps == 0: phase1_ramp = torch.empty(0, device=device) if p2_steps - p1_steps == 0: phase2_ramp = torch.empty(0, device=device) if num_steps - p2_steps == 0: phase3_ramp = torch.empty(0, device=device) ramp = torch.cat([phase1_ramp, phase2_ramp, phase3_ramp]) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """AkashicAOS v2: Detail-Progressive Schedule for EQ-VAE SDXL models.""" rho = 7.0 u = torch.linspace(0, 1, num_steps, device=device) detail_power = 0.85 u_progressive = u ** detail_power mid_boost_strength = 0.08 mid_boost = mid_boost_strength * torch.sin(math.pi * u) * (1 - u * 0.5) u_modulated = u_progressive + mid_boost u_min, u_max = u_modulated.min(), u_modulated.max() if u_max - u_min > 1e-8: u_modulated = (u_modulated - u_min) / (u_max - u_min) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho for i in range(1, len(sigmas)): if sigmas[i] >= sigmas[i-1]: sigmas[i] = sigmas[i-1] * 0.995 max_ratio = 1.5 if i > 0 and sigmas[i-1] / sigmas[i] > max_ratio: sigmas[i] = sigmas[i-1] / max_ratio return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device='cpu'): """Entropic power schedule.""" rho = 7.0 linear_ramp = torch.linspace(0, 1, num_steps, device=device) power_ramp = 1 - torch.linspace(1, 0, num_steps, device=device) ** power ramp = (linear_ramp + power_ramp) / 2.0 
min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Schedule optimized around log SNR = 0 region.""" rho = 7.0 log_snr_max = 2 * torch.log(sigma_max) log_snr_min = 2 * torch.log(sigma_min) t = torch.linspace(0, 1, num_steps, device=device) concentration_power = 3.0 sigmoid_t = torch.sigmoid(concentration_power * (t - 0.5)) linear_t = t blend_factor = 0.7 combined_t = blend_factor * sigmoid_t + (1 - blend_factor) * linear_t log_snr = log_snr_max + combined_t * (log_snr_min - log_snr_max) sigmas = torch.exp(log_snr / 2) return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Constant rate of distributional change.""" rho = 7.0 t = torch.linspace(0, 1, num_steps, device=device) corrected_t = t + 0.3 * torch.sin(math.pi * t) * (1 - t) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + corrected_t * (min_inv_rho - max_inv_rho)) ** rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Adaptive schedule combining multiple strategies.""" rho = 7.0 base_t = torch.linspace(0, 1, num_steps, device=device) strategies = [ lambda t: t, lambda t: t ** 0.8, lambda t: t + 0.2 * torch.sin(2 * math.pi * t) * (1 - t), lambda t: 1 / (1 + torch.exp(-3 * (t - 0.5))), ] weights = [0.2, 0.3, 0.2, 0.3] combined_t = sum(w * s(base_t) for w, s in zip(weights, strategies)) if (combined_t.max() - combined_t.min()) > 1e-6: combined_t = (combined_t - combined_t.min()) / (combined_t.max() - combined_t.min()) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + combined_t * (min_inv_rho - max_inv_rho)) ** 
rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_cosine_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Cosine-annealed schedule.""" rho = 7.0 u = torch.linspace(0, 1, num_steps, device=device) t = (1 - torch.cos(math.pi * u)) / 2 min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Uniform in log-SNR space.""" u = torch.linspace(0, 1, num_steps, device=device) log_snr_max = 2 * torch.log(sigma_max) log_snr_min = 2 * torch.log(sigma_min) log_snr = log_snr_max + u * (log_snr_min - log_snr_max) sigmas = torch.exp(log_snr / 2) return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device='cpu', k=4.0): """Concentrate steps near mid-range sigmas.""" rho = 7.0 u = torch.linspace(0, 1, num_steps, device=device) k_tensor = torch.tensor(k, device=device, dtype=u.dtype) t = 0.5 * (torch.tanh(k_tensor * (u - 0.5)) / torch.tanh(k_tensor / 2) + 1.0) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device='cpu', pivot=0.7, gamma=0.8, beta=5.0): """Faster early lock-in with extra resolution in final steps.""" rho = 7.0 u = torch.linspace(0, 1, num_steps, device=device) early_mask = u < pivot late_mask = ~early_mask t = torch.empty_like(u) t[early_mask] = (u[early_mask] / pivot) ** gamma * pivot late_u = u[late_mask] t[late_mask] = pivot + (1 - pivot) * (1 - torch.exp(-beta * (late_u - pivot) / (1 - pivot))) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho 
return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Karras schedule with controlled jitter.""" if num_steps <= 0: return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)]) rho = 7.0 indices = torch.arange(num_steps, device=device, dtype=torch.float32) denom = max(1, num_steps - 1) base = (indices + 0.5) / denom jitter_seed = torch.sin((indices + 1) * 2.3999632) jitter_strength = 0.35 jitter = jitter_seed * jitter_strength / denom u = torch.clamp(base + jitter, 0.0, 1.0) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device='cpu', noise_type='brownian', noise_scale=0.3, base_schedule='karras'): """Stochastic scheduler with controlled randomness.""" rho = 7.0 # Base schedule if base_schedule == 'karras': indices = torch.arange(num_steps, device=device, dtype=torch.float32) u = (indices / max(1, num_steps - 1)) ** (1 / rho) elif base_schedule == 'cosine': u = torch.linspace(0, 1, num_steps, device=device) u = (1 - torch.cos(math.pi * u)) / 2 else: # uniform u = torch.linspace(0, 1, num_steps, device=device) # Add noise if noise_type == 'brownian': noise = torch.randn(num_steps, device=device).cumsum(0) noise = noise / noise.std() elif noise_type == 'uniform': noise = torch.rand(num_steps, device=device) * 2 - 1 else: # normal noise = torch.randn(num_steps, device=device) u_noisy = u + noise * noise_scale / num_steps u_noisy = torch.clamp(u_noisy, 0, 1) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + u_noisy * (min_inv_rho - max_inv_rho)) ** rho # Sort the final sigmas descending so schedule is always noise→clean. 
# Sorting u_noisy descending before the transform gives wrong order # because the Karras mapping is monotone-decreasing in u. sigmas, _ = torch.sort(sigmas, descending=True) return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_jys_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """ JYS (Jump Your Steps) schedule using dynamically computed timestep sequences. Strategy: Large jumps early, dense clustering in detail region, fine steps at end. Ported from ComfyUI reference implementation. """ # _compute_jys_timesteps returns num_steps entries + a trailing 0. # Strip the trailing 0 so we get exactly num_steps timesteps; the # explicit zeros(1) terminator is appended below. jys_timesteps = _compute_jys_timesteps(num_steps) if jys_timesteps and jys_timesteps[-1] == 0: jys_timesteps = jys_timesteps[:-1] rho = 7.0 normalized_timesteps = [(1000 - t) / 1000.0 for t in jys_timesteps] t_tensor = torch.tensor(normalized_timesteps, device=device, dtype=torch.float32) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + t_tensor * (min_inv_rho - max_inv_rho)) ** rho sigmas, _ = torch.sort(sigmas, descending=True) return torch.cat([sigmas, torch.zeros(1, device=device)]) def _compute_jys_timesteps(num_steps): """Dynamically compute optimised JYS timestep sequence (0..1000 scale).""" if num_steps <= 0: return [0] if num_steps == 1: return [1000, 0] elif num_steps == 2: return [1000, 500, 0] elif num_steps == 3: return [1000, 600, 200, 0] early_steps = max(1, int(num_steps * 0.2)) final_steps = max(1, int(num_steps * 0.2)) middle_steps = max(1, num_steps - early_steps - final_steps) early_jump_size = max(50, (1000 - 600) // early_steps) early_timesteps = [] current_t = 1000 for _ in range(early_steps): early_timesteps.append(int(current_t)) current_t = max(600, current_t - early_jump_size) middle_timesteps = [] structure_steps = max(1, middle_steps // 2) structure_jump_size = max(10, (600 - 300) // structure_steps) 
current_t = 600 for _ in range(structure_steps): middle_timesteps.append(int(current_t)) current_t = max(300, current_t - structure_jump_size) detail_steps = middle_steps - structure_steps if detail_steps > 0: detail_jump_size = max(5, (300 - 200) // detail_steps) current_t = 300 for _ in range(detail_steps): middle_timesteps.append(int(current_t)) current_t = max(200, current_t - detail_jump_size) final_start = min(middle_timesteps) if middle_timesteps else 200 final_jump_size = max(5, final_start // final_steps) final_timesteps = [] current_t = final_start for _ in range(final_steps): final_timesteps.append(int(current_t)) current_t = max(0, current_t - final_jump_size) all_timesteps = early_timesteps + middle_timesteps + final_timesteps unique_timesteps = list(dict.fromkeys(all_timesteps)) unique_timesteps.sort(reverse=True) while len(unique_timesteps) < num_steps: for i in range(len(unique_timesteps) - 1): mid_point = (unique_timesteps[i] + unique_timesteps[i + 1]) // 2 if mid_point not in unique_timesteps: unique_timesteps.insert(i + 1, mid_point) if len(unique_timesteps) >= num_steps: break if len(unique_timesteps) > num_steps: unique_timesteps = unique_timesteps[:num_steps] if unique_timesteps[-1] != 0: unique_timesteps.append(0) return unique_timesteps def create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """Hybrid: JYS mid-phase with Karras locks.""" if num_steps <= 0: return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)]) rho = 7.0 jys_sigmas = create_jys_sigmas(sigma_max, sigma_min, num_steps, device=device)[:-1] indices = torch.arange(num_steps, device=device, dtype=torch.float32) denom = max(1, num_steps - 1) base = (indices + 0.5) / denom jitter_seed = torch.sin((indices + 1) * 2.3999632) jitter_strength = 0.35 jitter = jitter_seed * jitter_strength / denom u = torch.clamp(base + jitter, 0.0, 1.0) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) karras_sigmas = (max_inv_rho + u 
* (min_inv_rho - max_inv_rho)) ** rho positions = torch.linspace(0, 1, num_steps, device=device) jys_weight = torch.empty_like(positions) early_mask = positions < 0.3 mid_mask = (positions >= 0.3) & (positions < 0.8) late_mask = positions >= 0.8 jys_weight[early_mask] = 0.2 + 0.4 * (positions[early_mask] / 0.3) jys_weight[mid_mask] = 0.6 + 0.3 * ((positions[mid_mask] - 0.3) / 0.5) jys_weight[late_mask] = 0.9 jys_weight = jys_weight.clamp(0.2, 0.9) log_jys = torch.log(jys_sigmas.clamp_min(1e-6)) log_karras = torch.log(karras_sigmas.clamp_min(1e-6)) log_hybrid = torch.lerp(log_karras, log_jys, jys_weight) hybrid = torch.exp(log_hybrid) smoothing = 1.0 - 0.05 * (1 - positions) ** 2 hybrid = hybrid * smoothing for i in range(1, hybrid.shape[0]): if hybrid[i] > hybrid[i - 1]: hybrid[i] = hybrid[i - 1] * 0.999 return torch.cat([hybrid, torch.zeros(1, device=device)]) def create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """AYS (Align Your Steps) optimized for SDXL.""" AYS_SCHEDULES = { 10: [1.0000, 0.8751, 0.7502, 0.6254, 0.5004, 0.3755, 0.2506, 0.1253, 0.0502, 0.0000], 15: [1.0000, 0.9167, 0.8334, 0.7501, 0.6668, 0.5835, 0.5002, 0.4169, 0.3336, 0.2503, 0.1670, 0.0837, 0.0335, 0.0084, 0.0000], 20: [1.0000, 0.9375, 0.8750, 0.8125, 0.7500, 0.6875, 0.6250, 0.5625, 0.5000, 0.4375, 0.3750, 0.3125, 0.2500, 0.1875, 0.1250, 0.0625, 0.0313, 0.0156, 0.0039, 0.0000], 25: [1.0000, 0.9500, 0.9000, 0.8500, 0.8000, 0.7500, 0.7000, 0.6500, 0.6000, 0.5500, 0.5000, 0.4500, 0.4000, 0.3500, 0.3000, 0.2500, 0.2000, 0.1500, 0.1000, 0.0625, 0.0391, 0.0195, 0.0098, 0.0024, 0.0000], 30: [1.0000, 0.9583, 0.9167, 0.8750, 0.8333, 0.7917, 0.7500, 0.7083, 0.6667, 0.6250, 0.5833, 0.5417, 0.5000, 0.4583, 0.4167, 0.3750, 0.3333, 0.2917, 0.2500, 0.2083, 0.1667, 0.1250, 0.0833, 0.0521, 0.0326, 0.0163, 0.0081, 0.0041, 0.0010, 0.0000], } if num_steps in AYS_SCHEDULES: normalized = torch.tensor(AYS_SCHEDULES[num_steps], device=device, dtype=torch.float32) else: available_steps = 
sorted(AYS_SCHEDULES.keys()) if num_steps < available_steps[0]: ref_steps = available_steps[0] elif num_steps > available_steps[-1]: ref_steps = available_steps[-1] else: ref_steps = min([s for s in available_steps if s >= num_steps], default=available_steps[-1]) ref_schedule = np.array(AYS_SCHEDULES[ref_steps]) t_ref = np.linspace(0, 1, len(ref_schedule)) t_new = np.linspace(0, 1, num_steps + 1) log_ref = np.log(ref_schedule + 1e-8) log_ref[-1] = log_ref[-2] - 3.0 log_interp = np.interp(t_new, t_ref, log_ref) normalized_np = np.exp(log_interp) normalized_np[-1] = 0.0 normalized = torch.tensor(normalized_np, device=device, dtype=torch.float32) sigma_range = sigma_max - sigma_min sigmas = normalized * sigma_range + sigma_min sigmas[0] = sigma_max sigmas[-1] = 0.0 for i in range(1, len(sigmas) - 1): if sigmas[i] >= sigmas[i-1]: sigmas[i] = sigmas[i-1] * 0.999 # Append zero-terminator so output has num_steps+1 entries like all other schedulers. return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_aos_akashic_alt_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """ AkashicAOS Alt: Karras-based schedule with EQ-VAE-tuned warping. Stronger detail-progressive bias (power=0.78) and shifted tanh crossover at t=0.55. Adaptive rho scales with step count for multi-step solver stability. 
""" if num_steps <= 0: return torch.zeros(1, device=device) rho = min(11.0, max(7.0, 7.0 + 2.0 * (20.0 / max(num_steps, 10)))) u = torch.linspace(0, 1, num_steps, device=device) detail_power = 0.78 u_detail = u ** detail_power t_center = 0.55 beta = 0.07 gamma = 4.0 crossover = beta * torch.tanh(gamma * (u - t_center)) u_modulated = u_detail + crossover u_min, u_max = u_modulated.min(), u_modulated.max() if u_max - u_min > 1e-8: u_modulated = (u_modulated - u_min) / (u_max - u_min) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho max_ratio = 1.5 for i in range(1, len(sigmas)): if sigmas[i] >= sigmas[i - 1]: sigmas[i] = sigmas[i - 1] * 0.995 if sigmas[i - 1] / sigmas[i].clamp(min=1e-10) > max_ratio: sigmas[i] = sigmas[i - 1] / max_ratio return torch.cat([sigmas, torch.zeros(1, device=device)]) def create_akashic_eqflow_sigmas(sigma_max, sigma_min, num_steps, device='cpu'): """ AkashicEQFlow: Robust crossover-focused log-SNR schedule for EQ-VAE models. Concentrates steps around the structure-to-detail transition in logSNR space, blended with a Karras prior. Adaptive density width + ratio slew-rate limiting. 
""" if num_steps <= 0: return torch.zeros(1, device=device) lambda_min = -2.0 * math.log(max(float(sigma_max), 1e-10)) lambda_max = -2.0 * math.log(max(float(sigma_min), 1e-10)) lambda_range = max(lambda_max - lambda_min, 1e-8) step_factor = min(1.0, max(0.0, (num_steps - 16) / 30.0)) lambda_center = 0.20 + 0.15 * step_factor u_center = (lambda_center - lambda_min) / lambda_range u_center = float(min(0.88, max(0.12, u_center))) concentration = min(3.2, max(1.35, 1.1 + num_steps / 16.0)) base_width = min(0.30, max(0.18, 0.31 - 0.0028 * num_steps)) width_left = base_width * 1.06 width_right = base_width * 0.94 detail_side_gain = 1.08 + 0.04 * step_factor N = 1200 t = torch.linspace(0, 1, N, device=device) delta = t - u_center left_core = torch.exp(-((delta / width_left) ** 2) / 2.0) right_core = detail_side_gain * torch.exp(-((delta / width_right) ** 2) / 2.0) crossover_core = torch.where(delta <= 0, left_core, right_core) detail_floor = 0.08 * (t ** 1.4) composition_floor = 0.05 * ((1 - t) ** 1.7) density = 1.0 + concentration * crossover_core + detail_floor + composition_floor dt_val = 1.0 / (N - 1) cdf = torch.zeros(N, device=device) cdf[1:] = torch.cumsum((density[:-1] + density[1:]) * 0.5 * dt_val, dim=0) cdf = cdf / cdf[-1].clamp(min=1e-12) targets = torch.linspace(0, 1, num_steps, device=device) indices = torch.searchsorted(cdf, targets).clamp(1, N - 1) lo = indices - 1 hi = indices frac = (targets - cdf[lo]) / (cdf[hi] - cdf[lo]).clamp(min=1e-12) u_steps = t[lo] + frac * (t[hi] - t[lo]) lambdas_eqflow = lambda_min + u_steps * lambda_range rho = min(10.0, max(7.0, 7.0 + 1.5 * (22.0 / max(num_steps, 12)))) u_karras = torch.linspace(0, 1, num_steps, device=device) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas_karras = (max_inv_rho + u_karras * (min_inv_rho - max_inv_rho)) ** rho lambdas_karras = -2.0 * torch.log(sigmas_karras.clamp(min=1e-10)) blend_eqflow = min(0.60, max(0.35, 0.38 + num_steps / 200.0)) lambdas = (1.0 - 
blend_eqflow) * lambdas_karras + blend_eqflow * lambdas_eqflow sigmas = torch.exp(-lambdas / 2.0) if num_steps >= 40: max_ratio = 1.50 elif num_steps >= 28: max_ratio = 1.55 elif num_steps >= 18: max_ratio = 1.65 else: max_ratio = 1.85 ratio_slew = 1.18 prev_ratio = None sigmas[0] = sigma_max for i in range(1, len(sigmas)): if sigmas[i] >= sigmas[i - 1]: sigmas[i] = sigmas[i - 1] * 0.995 ratio = float((sigmas[i - 1] / sigmas[i].clamp(min=1e-10)).item()) ratio = min(ratio, max_ratio) if prev_ratio is not None: ratio = min(ratio, prev_ratio * ratio_slew) ratio = max(ratio, prev_ratio / ratio_slew) ratio = max(1.001, ratio) sigmas[i] = sigmas[i - 1] / ratio prev_ratio = ratio return torch.cat([sigmas, torch.zeros(1, device=device)]) def apply_custom_scheduler(sigmas, scheduler_type="Standard"): """ Apply a custom sigma schedule. sigma_min uses sigmas[-2] (last non-zero step), never the zero-terminator. Each scheduler is invoked via a lambda so keyword args with non-standard defaults (e.g. Entropic's `power`) are always passed correctly. """ if scheduler_type == "Standard" or len(sigmas) < 2: return sigmas sigma_max = sigmas[0] # Use the last non-zero sigma as sigma_min; sigmas[-1] is always 0. 
sigma_min = sigmas[-2] if len(sigmas) >= 2 else sigmas[0] if sigma_min <= 0: sigma_min = sigma_max * 1e-3 num_steps = len(sigmas) - 1 device = sigmas.device scheduler_map = { "AOS-V": lambda: create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device), "AOS-Epsilon": lambda: create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device), "AkashicAOS": lambda: create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device), "Entropic": lambda: create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device=device), "SNR-Optimized": lambda: create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device), "Constant-Rate": lambda: create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device), "Adaptive-Optimized": lambda: create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device), "Cosine-Annealed": lambda: create_cosine_sigmas(sigma_max, sigma_min, num_steps, device), "LogSNR-Uniform": lambda: create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device), "Tanh Mid-Boost": lambda: create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device), "Exponential Tail": lambda: create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device), "Jittered-Karras": lambda: create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device), "Stochastic": lambda: create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device=device), "JYS (Dynamic)": lambda: create_jys_sigmas(sigma_max, sigma_min, num_steps, device), "Hybrid JYS-Karras": lambda: create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device), "AYS-SDXL": lambda: create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device), "AkashicAOS Alt": lambda: create_aos_akashic_alt_sigmas(sigma_max, sigma_min, num_steps, device), "AkashicEQFlow": lambda: create_akashic_eqflow_sigmas(sigma_max, sigma_min, num_steps, device), } fn = scheduler_map.get(scheduler_type) if fn is not None: try: result = fn() if result is not None and len(result) > 1: 
return result print(f"⚠️ Scheduler {scheduler_type} returned empty/None, using standard") except Exception as e: print(f"⚠️ Scheduler {scheduler_type} failed: {e}, using standard") return sigmas # ============================================================================ # ============================================================================ # K-DIFFUSION SAMPLERS with Custom Sampler Integration # ============================================================================ @torch.no_grad() def sample_adept_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Euler sampler with Adept weight scaling OR custom sampler.""" # CUSTOM SAMPLER INTEGRATION if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') print(f"🌀 Redirecting to {custom_type}") # Apply scheduler to sigmas for custom samplers scheduler = ADEPT_STATE.get('scheduler', 'Standard') if scheduler != "Standard": sigmas = apply_custom_scheduler(sigmas, scheduler) print(f" 📊 Applied {scheduler} scheduler") if custom_type == "Akashic Solver v2": return sample_akashic_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off') ) elif custom_type == "Adept Solver": return sample_adept_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, order=ADEPT_STATE.get('solver_order', 2), 
use_corrector=ADEPT_STATE.get('use_corrector', True), use_detail_enhancement=False, settings={} ) elif custom_type == "Adept Ancestral Solver": return sample_adept_ancestral_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False), use_detail_enhancement=False, settings={} ) elif custom_type == "Mirror Correction Euler": return sample_mirror_correction_euler( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5), smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False) ) # STANDARD K-DIFFUSION MODE (from v3 - unchanged) if not ADEPT_STATE.get('enabled', False): global ORIGINAL_SAMPLERS if 'euler' in ORIGINAL_SAMPLERS: return ORIGINAL_SAMPLERS['euler'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise) return _basic_euler(model, x, sigmas, extra_args, callback, disable) # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap) _sched = ADEPT_STATE.get('scheduler', 'Standard') if _sched != 'Standard': sigmas = apply_custom_scheduler(sigmas, _sched) extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) base_scale = ADEPT_STATE.get('scale', 1.0) shift = ADEPT_STATE.get('shift', 0.0) start_pct = ADEPT_STATE.get('start_pct', 0.0) end_pct = ADEPT_STATE.get('end_pct', 1.0) try: unet_model = shared.sd_model.model.diffusion_model except AttributeError: unet_model = None total_steps = len(sigmas) - 1 for i in range(total_steps): sigma = sigmas[i] gamma = min(s_churn / 
total_steps, 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) with AdeptWeightPatcher(unet_model, current_scale, shift): eps = torch.randn_like(x) * s_noise if gamma > 0 else 0 sigma_hat = sigma * (gamma + 1) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) dt = sigmas[i + 1] - sigma_hat x = x + d * dt if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma_hat, 'denoised': denoised}) return x def sample_adept_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None): """Euler Ancestral with Adept weight scaling.""" # CUSTOM SAMPLER INTEGRATION if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') print(f"🌀 Redirecting to {custom_type}") # Apply scheduler to sigmas for custom samplers scheduler = ADEPT_STATE.get('scheduler', 'Standard') if scheduler != "Standard": sigmas = apply_custom_scheduler(sigmas, scheduler) print(f" 📊 Applied {scheduler} scheduler") if custom_type == "Akashic Solver v2": return sample_akashic_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off') ) elif custom_type == "Adept Solver": return sample_adept_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, 
disable=disable, order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True), use_detail_enhancement=False, settings={} ) elif custom_type == "Adept Ancestral Solver": return sample_adept_ancestral_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False), use_detail_enhancement=False, settings={} ) elif custom_type == "Mirror Correction Euler": return sample_mirror_correction_euler( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5), smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False) ) if not ADEPT_STATE.get('enabled', False): global ORIGINAL_SAMPLERS if 'euler_ancestral' in ORIGINAL_SAMPLERS: return ORIGINAL_SAMPLERS['euler_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) return _basic_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise) # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap) _sched = ADEPT_STATE.get('scheduler', 'Standard') if _sched != 'Standard': sigmas = apply_custom_scheduler(sigmas, _sched) extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) # Get settings base_scale = ADEPT_STATE.get('scale', 1.0) shift = ADEPT_STATE.get('shift', 0.0) start_pct = ADEPT_STATE.get('start_pct', 0.0) end_pct = ADEPT_STATE.get('end_pct', 1.0) use_adaptive_eta = ADEPT_STATE.get('adaptive_eta', False) current_eta = ADEPT_STATE.get('eta', eta) current_s_noise = 
ADEPT_STATE.get('s_noise', s_noise) # Get UNet try: unet_model = shared.sd_model.model.diffusion_model except AttributeError: unet_model = None if noise_sampler is None: noise_sampler = default_noise_sampler(x) total_steps = len(sigmas) - 1 print(f"✅ Adept Euler A active: scale={base_scale:.2f}, eta={current_eta:.2f}") for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Euler A"): sigma = sigmas[i] sigma_next = sigmas[i + 1] progress = i / max(total_steps, 1) # Adaptive eta if use_adaptive_eta: if progress < 0.3: current_eta = eta * 1.08 elif progress < 0.7: current_eta = eta * 0.95 else: current_eta = eta * 1.02 else: current_eta = eta # Dynamic scale current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) # Evaluate model with weight patching if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised = model(x, sigma * s_in, **extra_args) else: denoised = model(x, sigma * s_in, **extra_args) # Euler Ancestral step sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta) d = to_d(x, sigma, denoised) if torch.isnan(d).any() or torch.isinf(d).any(): d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0) dt = sigma_down - sigma x = x + d * dt if sigma_up > 0: noise = noise_sampler(sigma, sigma_next) * current_s_noise x = x + noise * sigma_up if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) return x def _basic_euler(model, x, sigmas, extra_args=None, callback=None, disable=None): """Fallback basic Euler (used when ORIGINAL_SAMPLERS has no 'euler' key).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): sigma = sigmas[i] denoised = model(x, sigma * s_in, **extra_args) d = to_d(x, sigma, denoised) dt = sigmas[i + 1] - sigma x = x + d * dt if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': 
denoised}) return x def _basic_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0): """Fallback basic Euler Ancestral.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) noise_sampler = default_noise_sampler(x) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta) d = to_d(x, sigmas[i], denoised) dt = sigma_down - sigmas[i] x = x + d * dt if sigma_up > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x @torch.no_grad() def sample_adept_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Heun sampler with Adept weight scaling.""" # CUSTOM SAMPLER INTEGRATION if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') print(f"🌀 Redirecting to {custom_type}") # Apply scheduler to sigmas for custom samplers scheduler = ADEPT_STATE.get('scheduler', 'Standard') if scheduler != "Standard": sigmas = apply_custom_scheduler(sigmas, scheduler) print(f" 📊 Applied {scheduler} scheduler") if custom_type == "Akashic Solver v2": return sample_akashic_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 
'Off') ) elif custom_type == "Adept Solver": return sample_adept_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True), use_detail_enhancement=False, settings={} ) elif custom_type == "Adept Ancestral Solver": return sample_adept_ancestral_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False), use_detail_enhancement=False, settings={} ) elif custom_type == "Mirror Correction Euler": return sample_mirror_correction_euler( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5), smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False) ) if not ADEPT_STATE.get('enabled', False): global ORIGINAL_SAMPLERS if 'heun' in ORIGINAL_SAMPLERS: return ORIGINAL_SAMPLERS['heun'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise) return _basic_heun(model, x, sigmas, extra_args, callback, disable) # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap) _sched = ADEPT_STATE.get('scheduler', 'Standard') if _sched != 'Standard': sigmas = apply_custom_scheduler(sigmas, _sched) extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) # Get settings base_scale = ADEPT_STATE.get('scale', 1.0) shift = ADEPT_STATE.get('shift', 0.0) start_pct = ADEPT_STATE.get('start_pct', 0.0) end_pct = ADEPT_STATE.get('end_pct', 1.0) # Get UNet try: unet_model = 
shared.sd_model.model.diffusion_model except AttributeError: unet_model = None total_steps = len(sigmas) - 1 print(f"✅ Adept Heun active: scale={base_scale:.2f}") for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Heun"): sigma = sigmas[i] sigma_next = sigmas[i + 1] # Dynamic scale current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) # First evaluation if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised = model(x, sigma * s_in, **extra_args) else: denoised = model(x, sigma * s_in, **extra_args) d = to_d(x, sigma, denoised) if torch.isnan(d).any() or torch.isinf(d).any(): d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0) dt = sigma_next - sigma if sigma_next == 0: # Last step x = x + d * dt else: # Heun's method: two-stage x_2 = x + d * dt # Second evaluation if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised_2 = model(x_2, sigma_next * s_in, **extra_args) else: denoised_2 = model(x_2, sigma_next * s_in, **extra_args) d_2 = to_d(x_2, sigma_next, denoised_2) if torch.isnan(d_2).any() or torch.isinf(d_2).any(): d_2 = torch.nan_to_num(d_2, nan=0.0, posinf=1.0, neginf=-1.0) # Average d_prime = (d + d_2) / 2 x = x + d_prime * dt if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) return x def _basic_heun(model, x, sigmas, extra_args=None, callback=None, disable=None): """Fallback basic Heun.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) d = to_d(x, sigmas[i], denoised) dt = sigmas[i + 1] - sigmas[i] if sigmas[i + 1] == 0: x = x + d * dt else: x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 x = x + 
d_prime * dt if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x @torch.no_grad() def sample_adept_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): """DPM++ 2M sampler with Adept weight scaling.""" # CUSTOM SAMPLER INTEGRATION if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') print(f"🌀 Redirecting to {custom_type}") # Apply scheduler to sigmas for custom samplers scheduler = ADEPT_STATE.get('scheduler', 'Standard') if scheduler != "Standard": sigmas = apply_custom_scheduler(sigmas, scheduler) print(f" 📊 Applied {scheduler} scheduler") if custom_type == "Akashic Solver v2": return sample_akashic_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off') ) elif custom_type == "Adept Solver": return sample_adept_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True), use_detail_enhancement=False, settings={} ) elif custom_type == "Adept Ancestral Solver": return sample_adept_ancestral_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False), 
phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False), use_detail_enhancement=False, settings={} ) elif custom_type == "Mirror Correction Euler": return sample_mirror_correction_euler( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5), smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False) ) if not ADEPT_STATE.get('enabled', False): global ORIGINAL_SAMPLERS if 'dpmpp_2m' in ORIGINAL_SAMPLERS: return ORIGINAL_SAMPLERS['dpmpp_2m'](model, x, sigmas, extra_args, callback, disable) return _basic_dpmpp_2m(model, x, sigmas, extra_args, callback, disable) # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap) _sched = ADEPT_STATE.get('scheduler', 'Standard') if _sched != 'Standard': sigmas = apply_custom_scheduler(sigmas, _sched) extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) # Get settings base_scale = ADEPT_STATE.get('scale', 1.0) shift = ADEPT_STATE.get('shift', 0.0) start_pct = ADEPT_STATE.get('start_pct', 0.0) end_pct = ADEPT_STATE.get('end_pct', 1.0) # Get UNet try: unet_model = shared.sd_model.model.diffusion_model except AttributeError: unet_model = None total_steps = len(sigmas) - 1 print(f"✅ Adept DPM++ 2M active: scale={base_scale:.2f}") old_denoised = None for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2M"): sigma = sigmas[i] sigma_next = sigmas[i + 1] # Dynamic scale current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) # Evaluate model with weight patching if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised = model(x, sigma * s_in, **extra_args) else: denoised = model(x, sigma * s_in, **extra_args) # DPM++ 2M 
step t, t_next = sigma, sigma_next h = t_next - t if old_denoised is None or sigma_next == 0: # First step (Euler) x = (sigma_next / sigma) * x - (-h).expm1() * denoised else: # Second order h_last = t - sigmas[i - 1] r = h_last / h denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised x = (sigma_next / sigma) * x - (-h).expm1() * denoised_d old_denoised = denoised if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) return x def _basic_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): """Fallback basic DPM++ 2M.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) old_denoised = None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) t, t_next = sigmas[i], sigmas[i + 1] h = t_next - t if old_denoised is None or sigmas[i + 1] == 0: x = (t_next / t) * x - (-h).expm1() * denoised else: h_last = t - sigmas[i - 1] r = h_last / h denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised x = (t_next / t) * x - (-h).expm1() * denoised_d old_denoised = denoised if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x @torch.no_grad() def sample_adept_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None): """DPM++ 2S Ancestral with Adept weight scaling.""" # CUSTOM SAMPLER INTEGRATION if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') print(f"🌀 Redirecting to {custom_type}") # Apply scheduler to sigmas for custom samplers scheduler = ADEPT_STATE.get('scheduler', 'Standard') if scheduler != "Standard": sigmas = apply_custom_scheduler(sigmas, scheduler) print(f" 📊 Applied {scheduler} scheduler") if custom_type == "Akashic Solver v2": return sample_akashic_solver( model=model, 
x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off') ) elif custom_type == "Adept Solver": return sample_adept_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True), use_detail_enhancement=False, settings={} ) elif custom_type == "Adept Ancestral Solver": return sample_adept_ancestral_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False), use_detail_enhancement=False, settings={} ) elif custom_type == "Mirror Correction Euler": return sample_mirror_correction_euler( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5), smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False) ) if not ADEPT_STATE.get('enabled', False): global ORIGINAL_SAMPLERS if 'dpmpp_2s_ancestral' in ORIGINAL_SAMPLERS: return ORIGINAL_SAMPLERS['dpmpp_2s_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) return 
_basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise) # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap) _sched = ADEPT_STATE.get('scheduler', 'Standard') if _sched != 'Standard': sigmas = apply_custom_scheduler(sigmas, _sched) extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) # Get settings base_scale = ADEPT_STATE.get('scale', 1.0) shift = ADEPT_STATE.get('shift', 0.0) start_pct = ADEPT_STATE.get('start_pct', 0.0) end_pct = ADEPT_STATE.get('end_pct', 1.0) current_eta = ADEPT_STATE.get('eta', eta) current_s_noise = ADEPT_STATE.get('s_noise', s_noise) # Get UNet try: unet_model = shared.sd_model.model.diffusion_model except AttributeError: unet_model = None if noise_sampler is None: noise_sampler = default_noise_sampler(x) total_steps = len(sigmas) - 1 print(f"✅ Adept DPM++ 2S A active: scale={base_scale:.2f}") for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2S A"): sigma = sigmas[i] sigma_next = sigmas[i + 1] # Dynamic scale current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) # First evaluation if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised = model(x, sigma * s_in, **extra_args) else: denoised = model(x, sigma * s_in, **extra_args) # DPM++ 2S step with ancestral noise sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta) if sigma_down == 0: d = to_d(x, sigma, denoised) x = x + d * (sigma_down - sigma) else: # Midpoint method t, t_next = sigma, sigma_down h = t_next - t s = t + h * 0.5 # Step to midpoint x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised # Evaluate at midpoint if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised_mid = model(x_mid, s * s_in, **extra_args) else: denoised_mid = model(x_mid, s * s_in, **extra_args) # 
Full step using midpoint x = (t_next / t) * x - (-h).expm1() * denoised_mid # Add ancestral noise if sigma_up > 0: noise = noise_sampler(sigma, sigma_next) * current_s_noise x = x + noise * sigma_up if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) return x def _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0): """Fallback basic DPM++ 2S Ancestral.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) noise_sampler = default_noise_sampler(x) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta) if sigma_down == 0: d = to_d(x, sigmas[i], denoised) x = x + d * (sigma_down - sigmas[i]) else: t, t_next = sigmas[i], sigma_down h = t_next - t s = t + h * 0.5 x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised denoised_mid = model(x_mid, s * s_in, **extra_args) x = (t_next / t) * x - (-h).expm1() * denoised_mid if sigma_up > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x @torch.no_grad() def sample_adept_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): """LMS sampler with Adept weight scaling.""" # CUSTOM SAMPLER INTEGRATION if ADEPT_STATE.get('enabled', False) and ADEPT_STATE.get('use_custom_sampler', False): custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2') print(f"🌀 Redirecting to {custom_type}") # Apply scheduler to sigmas for custom samplers scheduler = ADEPT_STATE.get('scheduler', 'Standard') if scheduler != "Standard": sigmas = apply_custom_scheduler(sigmas, scheduler) print(f" 📊 Applied {scheduler} scheduler") if custom_type == "Akashic Solver v2": return sample_akashic_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, 
callback=callback, disable=disable, tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2), smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0), use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off') ) elif custom_type == "Adept Solver": return sample_adept_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True), use_detail_enhancement=False, settings={} ) elif custom_type == "Adept Ancestral Solver": return sample_adept_ancestral_solver( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False), phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False), use_detail_enhancement=False, settings={} ) elif custom_type == "Mirror Correction Euler": return sample_mirror_correction_euler( model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0), correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5), smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False) ) if not ADEPT_STATE.get('enabled', False): global ORIGINAL_SAMPLERS if 'lms' in ORIGINAL_SAMPLERS: return ORIGINAL_SAMPLERS['lms'](model, x, sigmas, extra_args, callback, disable, order) return _basic_lms(model, x, sigmas, extra_args, callback, disable, order) # Apply custom scheduler deterministically (before the 
loop, not via p.sampler.model_wrap) _sched = ADEPT_STATE.get('scheduler', 'Standard') if _sched != 'Standard': sigmas = apply_custom_scheduler(sigmas, _sched) extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) # Get settings base_scale = ADEPT_STATE.get('scale', 1.0) shift = ADEPT_STATE.get('shift', 0.0) start_pct = ADEPT_STATE.get('start_pct', 0.0) end_pct = ADEPT_STATE.get('end_pct', 1.0) # Get UNet try: unet_model = shared.sd_model.model.diffusion_model except AttributeError: unet_model = None total_steps = len(sigmas) - 1 print(f"✅ Adept LMS active: scale={base_scale:.2f}, order={order}") ds = [] for i in trange(len(sigmas) - 1, disable=disable, desc="Adept LMS"): sigma = sigmas[i] # Dynamic scale current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct) # Evaluate model with weight patching if should_patch_weights(unet_model, current_scale, shift): with AdeptWeightPatcher(unet_model, current_scale, shift): denoised = model(x, sigma * s_in, **extra_args) else: denoised = model(x, sigma * s_in, **extra_args) d = to_d(x, sigma, denoised) ds.append(d) if len(ds) > order: ds.pop(0) # Linear multistep coefficients cur_order = min(i + 1, order) coeffs = [1.0] for j in range(1, cur_order): prod = 1.0 for k in range(cur_order): if k != j: prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k]) coeffs.append(prod) # Apply multistep d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:]))) dt = sigmas[i + 1] - sigma x = x + d_multistep * dt if callback is not None: callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised}) return x def _basic_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): """Fallback basic LMS.""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) ds = [] for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) d = to_d(x, sigmas[i], 
denoised) ds.append(d) if len(ds) > order: ds.pop(0) cur_order = min(i + 1, order) coeffs = [1.0] for j in range(1, cur_order): prod = 1.0 for k in range(cur_order): if k != j: prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k]) coeffs.append(prod) d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:]))) dt = sigmas[i + 1] - sigmas[i] x = x + d_multistep * dt if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised}) return x # ============================================================================ # MONKEY PATCHING # ============================================================================ def patch_k_diffusion(): """Apply monkey patches to ALL k-diffusion samplers.""" global ORIGINAL_SAMPLERS samplers_to_patch = { 'sample_euler': sample_adept_euler, 'sample_euler_ancestral': sample_adept_euler_ancestral, 'sample_heun': sample_adept_heun, 'sample_dpmpp_2m': sample_adept_dpmpp_2m, 'sample_dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral, 'sample_lms': sample_adept_lms, } patched_count = 0 for original_name, adept_func in samplers_to_patch.items(): if not hasattr(k_diffusion.sampling, original_name): continue key = original_name.replace('sample_', '') current_func = getattr(k_diffusion.sampling, original_name) # Only save original if we haven't stored it yet (avoid saving already-patched func) if key not in ORIGINAL_SAMPLERS: ORIGINAL_SAMPLERS[key] = current_func # Always (re-)apply our patch if it isn't already there if current_func is not adept_func: setattr(k_diffusion.sampling, original_name, adept_func) patched_count += 1 print(f"✅ Adept Sampler v5: Patched {patched_count} samplers") print(f" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS") print(f" Schedulers: 18 types available") def unpatch_k_diffusion(): """ Restore original k-diffusion samplers. Safe-unpatch strategy: before restoring we check whether the live slot still holds *our* wrapper. 
If another extension has wrapped us on top (i.e. live_func is not our adept_func but also not the original we saved), blindly restoring would silently remove *their* wrapper too. In that case we skip the restore for that slot and log a warning so the operator knows the coexistence situation. """ global ORIGINAL_SAMPLERS adept_wrappers = { 'euler': 'sample_euler', 'euler_ancestral': 'sample_euler_ancestral', 'heun': 'sample_heun', 'dpmpp_2m': 'sample_dpmpp_2m', 'dpmpp_2s_ancestral': 'sample_dpmpp_2s_ancestral', 'lms': 'sample_lms', } adept_funcs = { 'euler': sample_adept_euler, 'euler_ancestral': sample_adept_euler_ancestral, 'heun': sample_adept_heun, 'dpmpp_2m': sample_adept_dpmpp_2m, 'dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral, 'lms': sample_adept_lms, } restored_count = 0 skipped_count = 0 for key, attr_name in adept_wrappers.items(): if key not in ORIGINAL_SAMPLERS: continue live_func = getattr(k_diffusion.sampling, attr_name, None) our_func = adept_funcs[key] saved_original = ORIGINAL_SAMPLERS[key] if live_func is our_func: # Normal case: we still own the slot — safe to restore. setattr(k_diffusion.sampling, attr_name, saved_original) restored_count += 1 elif live_func is saved_original: # Already restored somehow — nothing to do. restored_count += 1 else: # Another extension wrapped us. Restoring would silently # remove their wrapper; skip and warn instead. print(f"⚠️ Adept unpatch: {attr_name} is currently owned by another " f"extension ({live_func!r}). 
Skipping restore to avoid breaking " f"their wrapper — you may need to reload the UI to fully unload.") skipped_count += 1 ORIGINAL_SAMPLERS.clear() print(f"🔄 Adept Sampler: Restored {restored_count} samplers" + (f", skipped {skipped_count} (foreign wrappers)" if skipped_count else "")) # ============================================================================ # A1111 EXTENSION SCRIPT # ============================================================================ class AdeptSamplerScript(scripts.Script): def title(self): return "Adept Sampler v5" def show(self, is_img2img): return scripts.AlwaysVisible def ui(self, is_img2img): with gr.Accordion("Adept Sampler v5", open=False): enabled = gr.Checkbox(label="Enable Adept Sampler", value=False, elem_id="adept_enabled") with gr.Row(): scale = gr.Slider(minimum=0.5, maximum=2.0, step=0.05, value=1.0, label="Weight Scale") shift = gr.Slider(minimum=-0.5, maximum=0.5, step=0.01, value=0.0, label="Weight Shift") with gr.Row(): start_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Start Percent") end_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=1.0, label="End Percent") gr.HTML("

" "⚠️ Weight Scale / Shift / Start–End apply to the 6 patched k-diffusion samplers only. " "Custom samplers (Akashic, Adept, Mirror) use their own internal parameters.

") with gr.Row(): eta = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Eta (Ancestral samplers)") s_noise = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="S-Noise") adaptive_eta = gr.Checkbox(label="Adaptive Eta (dynamic eta during sampling)", value=False) scheduler = gr.Dropdown( choices=["Standard", "AOS-V", "AOS-Epsilon", "AkashicAOS", "Entropic", "SNR-Optimized", "Constant-Rate", "Adaptive-Optimized", "Cosine-Annealed", "LogSNR-Uniform", "Tanh Mid-Boost", "Exponential Tail", "Jittered-Karras", "Stochastic", "JYS (Dynamic)", "Hybrid JYS-Karras", "AYS-SDXL", "AkashicAOS Alt", "AkashicEQFlow"], value="Standard", label="Scheduler Type" ) vae_reflection = gr.Checkbox(label="Enable VAE Reflection (fixes edge artifacts for EQ-VAE)", value=False) gr.HTML("
") gr.HTML("

🌀 Custom Advanced Samplers

") gr.HTML("

Enable to use Akashic/Adept/Ancestral samplers instead of k-diffusion

") use_custom = gr.Checkbox(label="Use Custom Sampler (overrides k-diffusion)", value=False) custom_type = gr.Dropdown( choices=["Akashic Solver v2", "Adept Solver", "Adept Ancestral Solver", "Mirror Correction Euler"], value="Akashic Solver v2", label="Custom Sampler Type" ) with gr.Accordion("⚙️ Akashic Solver Settings", open=False): tau = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Tau (0=ODE, 1=SDE)") phase_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Phase Strength") smea = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="SMEA (high-res coherency)") ndb = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="NDB (detail boost)") eqvae = gr.Dropdown(choices=["Off", "Balanced"], value="Off", label="EQ-VAE Mode") with gr.Accordion("⚙️ Adept Solver Settings", open=False): solver_order = gr.Slider(minimum=1, maximum=3, step=1, value=2, label="Order (1-3)") use_corrector = gr.Checkbox(value=True, label="Use Corrector") with gr.Accordion("⚙️ Ancestral Solver Settings", open=False): phase_noise = gr.Checkbox(value=False, label="Phase-Aware Noise") enhanced_deriv = gr.Checkbox(value=False, label="Enhanced Derivative") with gr.Accordion("⚙️ Mirror Correction Euler Settings", open=False): gr.HTML("

Active only when Custom Sampler = Mirror Correction Euler

") mirror_correction_phase = gr.Slider( minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Correction Phase (fraction of steps with 3-call Heun correction)" ) mirror_smooth_phase = gr.Checkbox( value=False, label="Smooth Phase (log-sigma blend instead of binary cutoff)" ) gr.HTML("
") gr.HTML("

🎛️ CFG Enhancements

") gr.HTML("

" "Combat CFG Drift works in stock A1111 via official callback. " "Spectral Modulation & Phase-Aware CFG use a native sampler hook on " "Forge/reForge-like backends, or a CFGDenoiser monkey-patch on stock A1111 " "(near-parity; active mode logged to console).

") with gr.Accordion("⚙️ CFG Enhancement Settings", open=False): cfg_drift_enabled = gr.Checkbox(value=False, label="Enable Combat CFG Drift") with gr.Row(): cfg_drift_method = gr.Dropdown( choices=["mean", "median"], value="mean", label="Drift Method" ) cfg_drift_intensity = gr.Slider( minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Drift Intensity" ) spectral_cfg_enabled = gr.Checkbox( value=False, label="Enable Spectral Modulation (native hook or A1111 monkey-patch)" ) with gr.Row(): spectral_multiplier = gr.Slider( minimum=0.0, maximum=2.0, step=0.05, value=1.0, label="Spectral Multiplier" ) spectral_percentile = gr.Slider( minimum=1.0, maximum=25.0, step=0.5, value=5.0, label="Spectral Percentile" ) phase_cfg_enabled = gr.Checkbox( value=False, label="Enable Phase-Aware CFG (native hook or A1111 monkey-patch)" ) with gr.Row(): phase_cfg_alpha = gr.Slider( minimum=1.1, maximum=4.0, step=0.1, value=2.0, label="Phase CFG Alpha" ) phase_cfg_beta = gr.Slider( minimum=1.1, maximum=4.0, step=0.1, value=2.0, label="Phase CFG Beta" ) return [enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection, use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector, phase_noise, enhanced_deriv, mirror_correction_phase, mirror_smooth_phase, cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity, spectral_cfg_enabled, spectral_multiplier, spectral_percentile, phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta] def process(self, p, enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection, use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector, phase_noise, enhanced_deriv, mirror_correction_phase, mirror_smooth_phase, cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity, spectral_cfg_enabled, spectral_multiplier, spectral_percentile, phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta): global ADEPT_STATE # Gate all 
sub-features through the master enabled switch. # This prevents CFG hooks, native patches, and VAE Reflection from # activating when the extension is globally toggled off. ADEPT_STATE.update({ "enabled": enabled, "scale": scale, "shift": shift, "start_pct": start_pct, "end_pct": end_pct, "eta": eta, "s_noise": s_noise, "adaptive_eta": adaptive_eta, "scheduler": scheduler, "vae_reflection": enabled and vae_reflection, # gated "use_custom_sampler": use_custom, "custom_sampler": custom_type, "tau": tau, "phase_strength": phase_strength, "smea_strength": smea, "ndb_strength": ndb, "eqvae_mode": eqvae, "solver_order": int(solver_order), "use_corrector": use_corrector, "phase_noise": phase_noise, "enhanced_derivative": enhanced_deriv, # Mirror Correction Euler "mirror_correction_phase": mirror_correction_phase, "mirror_smooth_phase": mirror_smooth_phase, # CFG enhancements — all gated through enabled "cfg_drift_enabled": enabled and cfg_drift_enabled, "cfg_drift_method": cfg_drift_method, "cfg_drift_intensity": cfg_drift_intensity, "spectral_cfg_enabled": enabled and spectral_cfg_enabled, "spectral_multiplier": spectral_multiplier, "spectral_percentile": spectral_percentile, "phase_cfg_enabled": enabled and phase_cfg_enabled, "phase_cfg_alpha": phase_cfg_alpha, "phase_cfg_beta": phase_cfg_beta, }) # Scheduler is now applied inside each patched sampler function, # so p.sampler.model_wrap patching is no longer needed here. # Always reconfigure CFG runtime — even when disabled — so any previously # installed native hook or A1111 callbacks get cleanly removed. 
runtime_mode = configure_cfg_runtime() if enabled: info = { "Adept Sampler": "v5", "Adept Scheduler": scheduler, "CFG Runtime": runtime_mode, } if use_custom: info["Adept Custom"] = custom_type if custom_type == "Akashic Solver v2": info["Adept Tau"] = tau info["Adept EQ-VAE"] = eqvae p.extra_generation_params.update(info) def process_batch(self, p, *args, **kwargs): """Apply VAE Reflection before batch processing.""" if ADEPT_STATE.get("enabled", False) and ADEPT_STATE.get("vae_reflection", False): try: vae_model = shared.sd_model.first_stage_model patcher = VAEReflectionPatcher(vae_model) patcher.__enter__() p.adept_vae_patcher = patcher except Exception as e: print(f"⚠️ VAE Reflection error: {e}") def postprocess_batch(self, p, *args, **kwargs): """Restore VAE padding modes after batch processing.""" if hasattr(p, 'adept_vae_patcher'): try: p.adept_vae_patcher.__exit__(None, None, None) delattr(p, 'adept_vae_patcher') except Exception as e: print(f"⚠️ VAE Reflection restore error: {e}") # Safety net: force-restore even if the patcher context failed force_restore_vae_reflection() # ============================================================================ # INITIALIZATION # ============================================================================ # # k-diffusion wrappers are installed via on_before_ui (fires after all # extensions are imported) rather than at bare module import time. # This reduces the risk of interacting badly with other extensions that # also wrap k_diffusion.sampling functions, because our wrappers are put # on last and therefore sit outermost in the call chain. # Uninstall happens in on_script_unloaded() as before. def _adept_deferred_init(): patch_k_diffusion() try: script_callbacks.on_before_ui(_adept_deferred_init) except Exception: # Fallback: if on_before_ui isn't available (older A1111), patch immediately. 
patch_k_diffusion() def on_script_unloaded(): try: force_restore_vae_reflection() except Exception: pass try: uninstall_a1111_cfg_callbacks() except Exception: pass try: uninstall_native_cfg_hook() except Exception: pass try: unpatch_cfg_denoiser() except Exception: pass try: unpatch_k_diffusion() except Exception: pass try: script_callbacks.on_script_unloaded(on_script_unloaded) except AttributeError: print("⚠️ Script unload callback not available") print("🚀 Adept Sampler v5 loaded!") print(" ✨ 4 Custom Samplers: Akashic v2, Adept Solver, Adept Ancestral, Mirror Correction Euler") print(" ⚡ 6 k-diffusion Samplers with weight scaling") print(" 📅 18 Schedulers (including AkashicAOS Alt, AkashicEQFlow)") print(" 🎨 VAE Reflection") print(" ✅ A1111 port of ComfyUI-Adept-Sampler")