# Source path: sdas / adept-sampler-v5 / scripts / adept_sampler_v5.py
# Uploader: dikdimon
# Commit message: Upload adept_sampler_v5.py
# Commit: 5bdd43b (verified)
"""
Adept Sampler v5 for Automatic1111 WebUI
Complete port with ALL custom samplers from ComfyUI
Version: 5.0
"""
import torch
import numpy as np
import math
from tqdm import trange
from modules import scripts, shared, script_callbacks
import gradio as gr
import k_diffusion.sampling
# Try import torchvision for detail enhancement
try:
from torchvision.transforms.functional import gaussian_blur
TORCHVISION_AVAILABLE = True
except ImportError:
TORCHVISION_AVAILABLE = False
print("⚠️ torchvision not available - detail enhancement disabled")
# ============================================================================
# GLOBAL STATE
# ============================================================================
# Module-wide mutable settings dictionary. The CFG runtime functions in this
# file read their feature flags and parameters from it via ADEPT_STATE.get(...).
ADEPT_STATE = {
    "enabled": False,
    "scale": 1.0,
    "shift": 0.0,
    "start_pct": 0.0,
    "end_pct": 1.0,
    "eta": 1.0,
    "s_noise": 1.0,
    "adaptive_eta": False,
    "scheduler": "Standard",
    "vae_reflection": False,
    # Custom sampler settings
    "use_custom_sampler": False,
    "custom_sampler": "Akashic Solver v2",
    "tau": 0.5,
    "phase_strength": 0.5,
    "solver_order": 2,
    "use_corrector": True,
    "phase_noise": False,
    "enhanced_derivative": False,
    "smea_strength": 0.0,
    "ndb_strength": 0.0,
    "eqvae_mode": "Off",
    # Mirror Correction Euler controls
    "mirror_correction_phase": 0.5,
    "mirror_smooth_phase": False,
    # CFG enhancement settings
    "cfg_drift_enabled": False,
    "cfg_drift_method": "mean",
    "cfg_drift_intensity": 0.5,
    "spectral_cfg_enabled": False,
    "spectral_multiplier": 1.0,
    "spectral_percentile": 5.0,
    "phase_cfg_enabled": False,
    "phase_cfg_alpha": 2.0,
    "phase_cfg_beta": 2.0,
    # Written by configure_cfg_runtime() to record which mode it selected.
    "cfg_runtime_mode": "off",  # off | a1111-postcfg | a1111-monkeypatch | native-hook
    # Internal bookkeeping for phase-aware CFG progress tracking
    # (updated by adept_cfg_denoiser_callback, read by the native CFG hook).
    "_cfg_step_idx": 0,
    "_cfg_total_steps": 1,
}
# Store original samplers
ORIGINAL_SAMPLERS = {}
# VAE Reflection state
_vae_reflection_active = False
_vae_original_padding_modes = {}
# CFG hook / callback runtime state
# The two *_CB names hold the callables currently registered with
# script_callbacks (None when nothing is registered).
_ADEPT_CFG_AFTER_CB = None
_ADEPT_CFG_DENOISER_CB = None
_ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
# CFGDenoiser monkey-patch runtime state
# The _CFGD_ORIG_* names cache the unpatched CFGDenoiser methods so that
# unpatch_cfg_denoiser() can restore them.
_CFGD_ORIG_COMBINE = None
_CFGD_ORIG_COMBINE_EDIT = None
_CFGD_ORIG_FORWARD = None
_CFGD_MONKEYPATCH_ACTIVE = False
# Attribute name under which the patched forward() stashes its per-call
# context (x, sigma, ...) on the CFGDenoiser instance.
_ADEPT_CFGDENOISER_CTX_ATTR = "_adept_cfg_ctx"
# ============================================================================
# BASIC UTILITY FUNCTIONS (from v3)
# ============================================================================
def to_d(x, sigma, denoised):
    """Convert a denoised prediction into the ODE derivative ``(x - denoised) / sigma``.

    Args:
        x: Current noisy sample tensor.
        sigma: Noise level. May be a Python number, a 0-dim tensor, or a
            tensor with fewer dimensions than ``x`` (e.g. one value per batch
            item); trailing singleton dimensions are appended so it
            broadcasts over the leading axes of ``x``.
        denoised: Model prediction of the clean sample, same shape as ``x``.

    Returns:
        The derivative, clamped element-wise to
        ``±1000 * (1 + sigma / 10)`` to guard against blow-ups.
    """
    if not torch.is_tensor(sigma):
        # A 0-dim tensor does not take part in dtype promotion against x.
        sigma = torch.tensor(float(sigma), device=x.device)
    if 0 < sigma.ndim < x.ndim:
        # Align a per-batch sigma with the leading axes of x instead of
        # letting it broadcast against the trailing (spatial) axis.
        sigma = sigma.reshape(sigma.shape + (1,) * (x.ndim - sigma.ndim))

    # Floor sigma so the division stays finite at the end of the schedule.
    safe_sigma = torch.clamp(sigma, min=1e-4)
    derivative = (x - denoised) / safe_sigma

    # Clamping is the identity when every value is already inside the
    # limit, so it can be applied unconditionally; this also works when
    # the limit has more than one element.
    limit = 1000.0 * (1.0 + sigma / 10.0)
    return torch.clamp(derivative, min=-limit, max=limit)
def get_ancestral_step(sigma, sigma_next, eta=1.0):
    """Split a step from ``sigma`` to ``sigma_next`` into its two parts.

    Returns:
        ``(sigma_down, sigma_up)``: the noise level to step down to
        deterministically and the amount of fresh noise to add afterwards.
        Both are ``0.0`` when ``sigma_next`` is zero.
    """
    if sigma_next == 0:
        return 0.0, 0.0

    variance_term = sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2
    requested_up = eta * variance_term ** 0.5
    # Never inject more noise than the target level allows.
    sigma_up = min(sigma_next, requested_up)
    sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
def compute_dynamic_scale(step_idx, total_steps, base_scale, start_pct, end_pct):
    """Return the weight scale for one step, faded in and out at the window edges.

    The active window is ``[start_pct, end_pct]`` in sampling progress
    (``step_idx / (total_steps - 1)``). Outside it the result is ``1.0``
    (no-op). Inside it the scale ramps linearly from ``1.0`` to
    ``base_scale`` over a ramp that is 20 % of the window width, capped at
    ``0.05``, and ramps back down symmetrically at the end.

    Raw UI values may be passed: both percentages are clamped to ``[0, 1]``
    and ``end_pct`` is never allowed below ``start_pct``.
    """
    lo = max(0.0, min(float(start_pct), 1.0))
    hi = max(lo, min(float(end_pct), 1.0))
    step_count = max(total_steps, 1)
    progress = step_idx / max(step_count - 1, 1)

    if progress < lo or progress > hi:
        return 1.0

    span = hi - lo
    if span < 1e-6:
        # Zero-width window: the single matching step is fully active.
        return float(base_scale)

    # Proportional ramp keeps narrow windows from being all fade.
    ramp = min(0.20 * span, 0.05)
    if ramp > 0:
        if progress < lo + ramp:
            fade = (progress - lo) / ramp
            return 1.0 + (base_scale - 1.0) * fade
        if progress > hi - ramp:
            fade = (hi - progress) / ramp
            return 1.0 + (base_scale - 1.0) * fade
    return float(base_scale)
def default_noise_sampler(x):
    """Build a fallback noise sampler bound to the shape/dtype/device of ``x``.

    The returned callable takes ``(sigma, sigma_next)`` but ignores both and
    always returns fresh standard-normal noise shaped like ``x``.
    """
    return lambda sigma, sigma_next: torch.randn_like(x)
def get_noise_sampler(x):
    """Return the noise sampler to use for tensor ``x``.

    Currently this always delegates to ``default_noise_sampler``.
    """
    noise_fn = default_noise_sampler(x)
    return noise_fn
# ============================================================================
# ADVANCED UTILITY FUNCTIONS
# ============================================================================
def to_d_enhanced_ancestral(x, sigma, denoised, eta, progress, generator=None):
    """Derivative for ancestral sampling with small eta- and phase-dependent noise terms.

    Starts from ``(x - denoised) / max(sigma, 1e-4)`` and then:
      * for ``eta > 1`` adds noise scaled by ``0.02 * (eta - 1) * progress``;
        for ``eta < 1`` subtracts noise scaled by
        ``0.015 * (1 - eta) * (1 - progress)``;
      * for ``progress < 0.3`` adds noise scaled by ``0.01``; for
        ``progress > 0.7`` subtracts noise scaled by ``0.008``;
      * clamps the result to ``±500 * (1 + sigma / 10)`` if it exceeds that.

    Args:
        generator: Optional ``torch.Generator`` for the noise draws. If
            drawing with it raises ``TypeError``/``AttributeError`` the
            global RNG is used instead.
    """
    residual = x - denoised

    def draw_noise():
        # Noise shaped like the residual, honouring the generator if usable.
        if generator is None:
            return torch.randn_like(residual)
        try:
            return torch.randn(residual.shape, device=residual.device,
                               dtype=residual.dtype, generator=generator)
        except (TypeError, AttributeError):
            return torch.randn_like(residual)

    d = residual / torch.clamp(sigma, min=1e-4)

    # Eta-dependent perturbation.
    if eta > 1.0:
        d = d + 0.02 * (eta - 1.0) * draw_noise() * progress
    elif eta < 1.0:
        d = d - 0.015 * (1.0 - eta) * draw_noise() * (1.0 - progress)

    # Phase-dependent perturbation (early vs. late steps).
    if progress < 0.3:
        d = d + 0.01 * draw_noise()
    elif progress > 0.7:
        d = d - 0.008 * draw_noise()

    limit = 500.0 * (1.0 + sigma / 10.0)
    if torch.abs(d).max() > limit:
        d = torch.clamp(d, -limit, limit)
    return d
def apply_dynamic_thresholding(x, percentile=0.995, clamp_range=1.0):
    """Soft-limit extreme values in ``x`` (intended for high CFG scales).

    Per batch item, the ``(1 - percentile)`` fraction of largest magnitudes
    is located; the smallest of those (floored at 1.0) times 2.5 becomes the
    threshold. Values beyond the threshold are clipped to it and the whole
    tensor is then scaled by 0.98.

    The input is returned untouched when ``percentile >= 1.0``, when no
    magnitude in the whole tensor reaches 5.0, or when processing fails.

    Args:
        x: Tensor whose first dimension is the batch. May be non-contiguous.
        percentile: Fraction of values considered "normal" (0..1).
        clamp_range: Accepted for interface compatibility; not used.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if percentile >= 1.0:
        return x
    try:
        # reshape (not view) so non-contiguous inputs such as permuted or
        # channels-last tensors are thresholded instead of silently skipped.
        flat = x.reshape(x.shape[0], -1)
        magnitude = flat.abs()
        if magnitude.max() < 5.0:
            return x

        # Keep k inside [1, row length] so topk cannot be asked for more
        # elements than exist.
        k = min(max(1, int(flat.shape[1] * (1.0 - percentile))), flat.shape[1])
        top_values = torch.topk(magnitude, k=k, dim=1, largest=True)[0]
        threshold = top_values[:, -1:].clamp(min=1.0) * 2.5

        clipped = torch.where(magnitude > threshold,
                              torch.sign(flat) * threshold, flat)
        return (clipped * 0.98).reshape(x.shape)
    except Exception as e:
        # Best-effort: thresholding is optional, so fall back to the input
        # but report why instead of hiding the failure.
        print(f"⚠️ Dynamic thresholding failed: {e}")
        return x
def compute_compensation_ratio(r, step_idx, total_steps, base_ratio=1.0):
    """DC-Solver style compensation ratio for the current step.

    The ratio is ``base_ratio`` times a phase weight (1.5 for the first 30 %
    of sampling, 1.0 up to 70 %, 1.3 afterwards) times
    ``1 + 0.1 * tanh(r - 1)``.
    """
    progress = step_idx / max(total_steps - 1, 1)

    # (exclusive upper bound of the phase, weight); the last phase is the default.
    for upper_bound, weight in ((0.3, 1.5), (0.7, 1.0)):
        if progress < upper_bound:
            phase_weight = weight
            break
    else:
        phase_weight = 1.3

    return base_ratio * phase_weight * (1.0 + 0.1 * math.tanh(r - 1.0))
def compute_tau_eqvae(progress, base_tau=0.5, phase_strength=0.5):
    """Phase-aware tau for a standard VAE, clamped to ``[0, 1]``.

    ``base_tau`` is multiplied by ``1 + c * phase_strength`` where ``c`` is
    +0.2 before 30 % progress, -0.15 before 60 %, and -0.3 afterwards.
    """
    if progress < 0.30:
        coefficient = 0.2
    elif progress < 0.60:
        coefficient = -0.15
    else:
        coefficient = -0.3

    scaled_tau = base_tau * (1.0 + coefficient * phase_strength)
    return min(1.0, max(0.0, scaled_tau))
def compute_eqvae_tau(progress, base_tau, phase_strength):
    """EQ-VAE variant of the phase-aware tau, clamped to ``[0, 1]``.

    Same idea as the standard-VAE version but with earlier phase boundaries
    (25 % / 55 %) and gentler coefficients (+0.10 / -0.10 / -0.20).
    """
    if progress < 0.25:
        coefficient = 0.10
    elif progress < 0.55:
        coefficient = -0.10
    else:
        coefficient = -0.20

    scaled_tau = base_tau * (1.0 + coefficient * phase_strength)
    return min(1.0, max(0.0, scaled_tau))
def compute_eqvae_noise_scale(base_s_noise, progress):
    """Noise multiplier for EQ-VAE sampling.

    ``base_s_noise`` is scaled by a fixed 0.88 and by a phase factor that
    falls linearly from 1.05 to 1.0 over the first 25 % of sampling, from
    1.0 to 0.95 between 25 % and 60 %, and stays at 0.95 afterwards.
    """
    if progress < 0.25:
        phase_factor = 1.0 + 0.05 * (1.0 - progress / 0.25)
    elif progress < 0.60:
        phase_factor = 1.0 - 0.05 * ((progress - 0.25) / 0.35)
    else:
        phase_factor = 0.95

    return base_s_noise * 0.88 * phase_factor
def compute_eqvae_ndb(progress, ndb_strength):
    """Native Detail Boost parameters for EQ-VAE mode.

    Returns:
        ``(blur_sigma, high_freq_boost)``. With a non-positive strength the
        result is ``(0.5, 0.0)``. Otherwise ``blur_sigma`` is 0.6 and the
        boost (per unit strength) grows piecewise-linearly: 0 → 0.03 over
        the first 30 % of sampling, 0.03 → 0.10 up to 60 %, 0.10 → 0.20 up
        to 100 %.
    """
    if ndb_strength <= 0:
        return 0.5, 0.0

    if progress < 0.30:
        local = progress / 0.30
        boost = 0.03 * ndb_strength * local
    elif progress < 0.60:
        local = (progress - 0.30) / 0.30
        boost = (0.03 + 0.07 * local) * ndb_strength
    else:
        local = (progress - 0.60) / 0.40
        boost = (0.10 + 0.10 * local) * ndb_strength
    return 0.6, boost
def compute_native_detail_boost(progress, ndb_strength=0.0):
    """Native Detail Boost parameters for a standard VAE.

    Returns:
        ``(1.0, high_freq_boost)``. The first element is always 1.0. The
        boost is 0 for a non-positive strength; otherwise (per unit
        strength) it grows piecewise-linearly: 0 → 0.03 over the first 30 %
        of sampling, 0.03 → 0.10 up to 60 %, 0.10 → 0.18 up to 100 %.
    """
    if ndb_strength <= 0:
        return 1.0, 0.0

    if progress < 0.30:
        local = progress / 0.30
        boost = 0.03 * ndb_strength * local
    elif progress < 0.60:
        local = (progress - 0.30) / 0.30
        boost = (0.03 + 0.07 * local) * ndb_strength
    else:
        local = (progress - 0.60) / 0.40
        boost = (0.10 + 0.08 * local) * ndb_strength
    return 1.0, boost
def compute_smea_factor(progress, smea_strength=0.5):
    """SMEA coherency factor for the current sampling progress.

    Returns 1.0 when the strength is non-positive. Otherwise the factor
    follows a sine ease from ``1 - smea_strength`` at progress 0 to 1.0 at
    progress 1.
    """
    if smea_strength <= 0:
        return 1.0

    ease = 0.5 * (1 + math.sin(math.pi * (progress - 0.5)))
    remaining = 1.0 - ease
    return 1.0 - smea_strength * remaining
# ============================================================================
# ADVANCED CFG TECHNIQUES from reForge
# ============================================================================
def apply_spectral_modulation_clybius(noise_pred, multiplier=1.0, percentile=5.0):
    """
    Clybius Spectral Modulation: frequency-domain correction of a noise prediction.

    The tensor is transformed with a 2-D FFT over its last two dimensions.
    Frequency bins whose log-amplitude falls below the low quantile are
    boosted (x1.5) and bins at or above the high quantile are attenuated
    (x0.5); the combined mask is raised to ``multiplier`` before being
    applied, and the result is transformed back.

    It should be applied to noise_pred (cond - uncond), NOT to the denoised
    latent.

    Half-precision inputs (float16 / bfloat16) are processed in float32,
    because ``torch.quantile`` only accepts float32/float64; the result is
    cast back to the input dtype.

    Args:
        noise_pred: The noise prediction tensor (cond - uncond); the code
            flattens from dimension 2, so at least 3 dimensions are needed.
        multiplier: Modulation strength (0=none, 1=full effect). Default: 1.0
        percentile: Upper/lower percentile threshold. Default: 5.0

    Returns:
        Spectrally modulated noise prediction with the input's shape and
        dtype, or ``noise_pred`` itself when modulation is disabled or fails.
    """
    if multiplier == 0 or percentile <= 0:
        return noise_pred
    try:
        in_dtype = noise_pred.dtype
        if in_dtype in (torch.float32, torch.float64):
            work = noise_pred
        else:
            # quantile() (and, on some backends, fft2) reject reduced
            # precision; upcast for the computation only.
            work = noise_pred.float()

        fourier = torch.fft.fft2(work, dim=(-2, -1))
        # Log amplitude (with small epsilon for numerical stability)
        log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2) + 1e-8)

        # Quantiles are taken on the absolute log amplitude, per (batch, channel).
        log_amp_flat = log_amp.abs().flatten(2)
        fraction = percentile * 0.01
        quantile_low = torch.quantile(log_amp_flat, fraction, dim=2)
        quantile_high = torch.quantile(log_amp_flat, 1 - fraction, dim=2)
        quantile_low = quantile_low.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
        quantile_high = quantile_high.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)

        # mask_low:  1.5 where below the low threshold, else 1.0
        # mask_high: 1.0 where below the high threshold, else 0.5
        mask_low = ((log_amp < quantile_low).to(log_amp.dtype) + 1).clamp_(max=1.5)
        mask_high = (log_amp < quantile_high).to(log_amp.dtype).clamp_(min=0.5)

        filtered_fourier = fourier * ((mask_low * mask_high) ** multiplier)
        result = torch.fft.ifft2(filtered_fourier, dim=(-2, -1)).real
        return result.to(in_dtype)
    except Exception as e:
        print(f"⚠️ Spectral modulation failed: {e}")
        return noise_pred
def create_spectral_modulation_cfg_hook(multiplier=1.0, percentile=5.0):
    """
    Build a CFG hook that spectrally modulates the cond/uncond difference.

    The returned function is meant to be installed through a backend's
    set_model_sampler_cfg_function(). It expects an ``args`` dict with the
    keys "cond", "uncond", "cond_scale", "sigma" and "input".

    Args:
        multiplier: Modulation strength (0=none, 1=full). Default: 1.0
        percentile: Frequency percentile threshold. Default: 5.0
    Returns:
        The hook function.
    """
    def spectral_cfg_hook(args):
        cond = args["cond"]
        uncond = args["uncond"]
        cond_scale = args["cond_scale"]
        raw_sigma = args["sigma"]
        x_orig = args["input"]

        # Give sigma trailing singleton dims so it broadcasts against cond.
        sigma = raw_sigma.view(raw_sigma.shape[:1] + (1,) * (cond.ndim - 1))
        x = x_orig / (sigma * sigma + 1.0)

        def to_v_space(pred):
            # Same conversion for cond and uncond (from RescaleCFG reference).
            return ((x - (x_orig - pred)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)

        cond_v = to_v_space(cond)
        uncond_v = to_v_space(uncond)

        # Modulate only the guidance direction, then recombine with the scale.
        guidance = apply_spectral_modulation_clybius(cond_v - uncond_v, multiplier, percentile)
        x_cfg = uncond_v + cond_scale * guidance

        # Convert back from v-pred space.
        return x_orig - (x - x_cfg * sigma / (sigma * sigma + 1.0) ** 0.5)

    return spectral_cfg_hook
def apply_combat_cfg_drift(latent, method='mean', intensity=1.0):
    """
    Combat CFG Drift: pull each sample's latent centre back towards zero.

    For every batch item a single centre value is computed over all
    remaining dimensions (mean, or median when ``method == 'median'``) and
    ``center * intensity`` is subtracted from that item.

    Args:
        latent: Tensor whose first dimension is the batch. Any number of
            trailing dimensions is accepted, and the tensor may be
            non-contiguous.
        method: 'median' selects the median; any other value uses the mean.
            Default: 'mean'
        intensity: How much drift to remove (0=none, 1=full). Default: 1.0

    Returns:
        Drift-corrected latent, or ``latent`` itself when ``intensity <= 0``
        or the correction fails.
    """
    if intensity <= 0:
        return latent
    try:
        batch = latent.shape[0]
        # One centre per batch item, broadcastable over every other axis
        # (no 4-D assumption).
        center_shape = (batch,) + (1,) * (latent.ndim - 1)

        if method == 'median':
            # reshape instead of view: permuted / channels-last latents are
            # not viewable as (batch, -1).
            center = latent.reshape(batch, -1).median(dim=-1)[0]
            center = center.reshape(center_shape)
        else:
            center = latent.mean(dim=tuple(range(1, latent.ndim)), keepdim=True)

        # intensity=1.0 removes all drift, intensity=0.5 removes half
        return latent - center * intensity
    except Exception as e:
        print(f"⚠️ Combat CFG drift failed: {e}")
        return latent
def compute_phase_aware_cfg_scale(base_scale, progress, alpha=2.0, beta=2.0):
    """
    Phase-Aware CFG Scaling: vary the CFG scale over the course of sampling.
    Inspired by β-CFG (arXiv:2502.10574).

    Two cases are implemented:
      * ``alpha == beta == 2.0``: a parabola ``4 * p * (1 - p)`` mapped to a
        factor of 0.7 at the ends and 1.3 at the midpoint.
      * anything else: a linear falloff ``1 - 0.6 * |p - mode|`` around the
        Beta mode ``(alpha - 1) / (alpha + beta - 2)`` (0.5 when
        ``alpha + beta <= 2``), clamped to [0.7, 1.3]. Note this branch
        never exceeds 1.0.

    Args:
        base_scale: The user-specified CFG scale
        progress: Sampling progress (0.0 to 1.0)
        alpha: Beta distribution alpha parameter. Default: 2.0
        beta: Beta distribution beta parameter. Default: 2.0
    Returns:
        ``base_scale`` times the factor; ``base_scale`` unchanged if the
        computation raises.
    """
    try:
        symmetric_default = alpha == 2.0 and beta == 2.0
        if symmetric_default:
            bump = 4.0 * progress * (1.0 - progress)
            factor = 0.7 + 0.6 * bump
        else:
            if (alpha + beta) > 2:
                peak = (alpha - 1.0) / (alpha + beta - 2.0)
            else:
                peak = 0.5
            distance = abs(progress - peak)
            factor = 1.0 - 0.3 * distance * 2
            factor = max(0.7, min(1.3, factor))
        return base_scale * factor
    except Exception as e:
        print(f"⚠️ Phase-aware CFG scaling failed: {e}")
        return base_scale
# apply_cfg_techniques() removed — was using legacy keys (akashic_combat_cfg_drift /
# akashic_combat_drift_intensity) that no longer match the live CFG runtime, which
# operates through configure_cfg_runtime() / adept_after_cfg_callback instead.
# ============================================================================
# DUAL-MODE CFG RUNTIME (A1111 callbacks + optional native hook)
# ============================================================================
def create_phase_aware_native_cfg_hook(base_hook=None, alpha=2.0, beta=2.0):
    """
    Build a native CFG hook that rescales ``cond_scale`` by sampling phase.

    Intended for backends exposing set_model_sampler_cfg_function(). The
    hook reads the current step counters from ADEPT_STATE, derives the
    phase-adjusted scale, and then either hands a copy of ``args`` (with the
    adjusted "cond_scale") to ``base_hook`` or performs the plain CFG
    combine itself.
    """
    def hook(args):
        cond = args["cond"]
        uncond = args["uncond"]
        cond_scale = float(args["cond_scale"])
        sigma = args["sigma"]
        x_orig = args["input"]

        # Progress in [0, 1] from the bookkeeping counters.
        step_count = max(int(ADEPT_STATE.get("_cfg_total_steps", 1)), 1)
        current_step = int(ADEPT_STATE.get("_cfg_step_idx", 0))
        fraction = current_step / max(step_count - 1, 1)
        progress = min(max(fraction, 0.0), 1.0)

        phased_scale = compute_phase_aware_cfg_scale(cond_scale, progress,
                                                     alpha=alpha, beta=beta)

        forwarded = dict(args)
        forwarded["cond_scale"] = phased_scale
        if base_hook is not None:
            # Downstream hook (e.g. spectral modulation) does the combine.
            return base_hook(forwarded)

        # Vanilla CFG combine with phased scale.
        sigma_b = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
        x = x_orig / (sigma_b * sigma_b + 1.0)

        def to_v_space(pred):
            return ((x - (x_orig - pred)) * (sigma_b ** 2 + 1.0) ** 0.5) / sigma_b

        cond_v = to_v_space(cond)
        uncond_v = to_v_space(uncond)
        x_cfg = uncond_v + phased_scale * (cond_v - uncond_v)
        return x_orig - (x - x_cfg * sigma_b / (sigma_b * sigma_b + 1.0) ** 0.5)

    return hook
def create_combined_native_cfg_hook():
    """
    Assemble a single native CFG hook from the features enabled in ADEPT_STATE.

    When both features are on, the phase-aware hook wraps the spectral hook
    (phase-aware scale → spectral modulation). With only one enabled that
    hook is returned alone. Returns None if neither is enabled, in which
    case the caller should clear the hook.
    """
    spectral_hook = None
    if ADEPT_STATE.get("spectral_cfg_enabled", False):
        spectral_hook = create_spectral_modulation_cfg_hook(
            multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0),
            percentile=ADEPT_STATE.get("spectral_percentile", 5.0),
        )

    if not ADEPT_STATE.get("phase_cfg_enabled", False):
        return spectral_hook

    return create_phase_aware_native_cfg_hook(
        base_hook=spectral_hook,
        alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0),
        beta=ADEPT_STATE.get("phase_cfg_beta", 2.0),
    )
def adept_cfg_denoiser_callback(params):
    """
    on_cfg_denoiser callback: record the current step counters.

    Copies ``params.sampling_step`` and ``params.total_sampling_steps``
    (defaulting to 0 and 1 when absent) into ADEPT_STATE so the phase-aware
    CFG code can compute progress. No CFG math happens here.
    """
    step = getattr(params, "sampling_step", 0)
    total = getattr(params, "total_sampling_steps", 1)
    ADEPT_STATE["_cfg_step_idx"] = int(step)
    ADEPT_STATE["_cfg_total_steps"] = int(total)
def adept_after_cfg_callback(params):
    """
    on_cfg_after_cfg callback: apply Combat CFG Drift to ``params.x``.

    Does nothing unless both the extension and the drift feature are enabled
    in ADEPT_STATE. Only ``params.x`` is touched; failures are reported and
    otherwise ignored so sampling continues.
    """
    active = (ADEPT_STATE.get("enabled", False)
              and ADEPT_STATE.get("cfg_drift_enabled", False))
    if not active:
        return

    try:
        method = ADEPT_STATE.get("cfg_drift_method", "mean")
        intensity = ADEPT_STATE.get("cfg_drift_intensity", 0.5)
        params.x = apply_combat_cfg_drift(params.x, method=method, intensity=intensity)
    except Exception as e:
        print(f"⚠️ Adept post-CFG drift callback failed: {e}")
def uninstall_a1111_cfg_callbacks():
    """Unregister both Adept CFG callbacks (if any) and forget them.

    Removal errors are ignored; the tracking globals are reset to None
    regardless.
    """
    global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB

    registered = [cb for cb in (_ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB)
                  if cb is not None]
    for callback in registered:
        try:
            script_callbacks.remove_callbacks_for_function(callback)
        except Exception:
            pass

    _ADEPT_CFG_AFTER_CB = None
    _ADEPT_CFG_DENOISER_CB = None
def install_a1111_cfg_callbacks():
    """Register the step-tracking and post-CFG drift callbacks.

    Any previously registered Adept callbacks are removed first so repeated
    calls do not stack registrations.
    """
    global _ADEPT_CFG_AFTER_CB, _ADEPT_CFG_DENOISER_CB

    uninstall_a1111_cfg_callbacks()

    _ADEPT_CFG_DENOISER_CB = adept_cfg_denoiser_callback
    _ADEPT_CFG_AFTER_CB = adept_after_cfg_callback

    script_callbacks.on_cfg_denoiser(_ADEPT_CFG_DENOISER_CB, name="adept_cfg_denoiser")
    script_callbacks.on_cfg_after_cfg(_ADEPT_CFG_AFTER_CB, name="adept_after_cfg")
def _get_native_cfg_hook_target():
    """
    Find an object that exposes set_model_sampler_cfg_function().

    Probes, in order: ``shared.sd_model``, its ``forge_objects``, its
    ``model``, and that model's ``model``. Returns the first match, or None
    when none of them provides the method.
    """
    sd = getattr(shared, "sd_model", None)
    inner_model = getattr(sd, "model", None)

    candidates = (
        sd,
        getattr(sd, "forge_objects", None),
        inner_model,
        getattr(inner_model, "model", None) if sd else None,
    )
    matches = (
        obj for obj in candidates
        if obj is not None and hasattr(obj, "set_model_sampler_cfg_function")
    )
    return next(matches, None)
def uninstall_native_cfg_hook():
    """Clear the native sampler CFG hook (best effort) and mark it inactive.

    If a hook target exists, its CFG function is reset to None; errors while
    doing so are ignored. The active flag is cleared in every case.
    """
    global _ADEPT_NATIVE_CFG_HOOK_ACTIVE

    hook_target = _get_native_cfg_hook_target()
    if hook_target is not None:
        try:
            hook_target.set_model_sampler_cfg_function(None)
        except Exception:
            pass

    _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
def install_native_cfg_hook():
    """Install the composite CFG hook on the native backend, if there is one.

    Returns:
        True when the backend call succeeded (even if the composite hook was
        None, which simply clears any existing hook); False when no target
        exists or the call raised.
    """
    global _ADEPT_NATIVE_CFG_HOOK_ACTIVE

    hook_target = _get_native_cfg_hook_target()
    if hook_target is None:
        _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
        return False

    composite = create_combined_native_cfg_hook()
    try:
        # Passing None clears the hook when no CFG feature is enabled.
        hook_target.set_model_sampler_cfg_function(composite)
    except Exception as e:
        print(f"⚠️ Adept native CFG hook install failed: {e}")
        _ADEPT_NATIVE_CFG_HOOK_ACTIVE = False
        return False

    _ADEPT_NATIVE_CFG_HOOK_ACTIVE = composite is not None
    return True
# ============================================================================
# CFGDenoiser MONKEY-PATCH (stock A1111 fallback for spectral/phase CFG)
# ============================================================================
def _adept_cfg_progress_from_denoiser(denoiser):
"""Compute sampling progress [0,1] from CFGDenoiser step counters."""
total_steps = max(int(getattr(denoiser, "total_steps", 1) or 1), 1)
step_idx = int(getattr(denoiser, "step", 0))
return min(max(step_idx / max(total_steps - 1, 1), 0.0), 1.0)
def _adept_nativeish_cfg_term(x_i, sigma_i, cond_i, uncond_i, scale):
"""
Approximate native hook behavior for one cond/uncond pair.
Feeds x/sigma/cond/uncond into the same composite hook builder used in
native-hook mode so stock A1111 gets as close as possible to Forge parity.
Falls back to plain weighted delta if hook is unavailable or errors.
"""
if abs(float(scale)) < 1e-12:
return torch.zeros_like(uncond_i)
hook = create_combined_native_cfg_hook()
if hook is None:
return (cond_i - uncond_i) * float(scale)
try:
combined = hook({
"cond": cond_i,
"uncond": uncond_i,
"cond_scale": float(scale),
"sigma": sigma_i,
"input": x_i,
})
return combined - uncond_i
except Exception as e:
print(f"⚠️ Adept native-ish CFG term fallback: {e}")
return (cond_i - uncond_i) * float(scale)
def patch_cfg_denoiser():
    """
    Stock A1111 fallback for Spectral Modulation + Phase-Aware CFG.

    Strategy:
      1. Thin forward() wrapper that stashes x / sigma on the instance so the
         combine methods can reach them — original forward logic is untouched.
      2. Patched combine_denoised() uses those values to call the same
         composite hook builder as native-hook mode.
      3. Patched combine_denoised_for_edit_model() does the same for pix2pix.

    Phase-aware scaling is applied exactly once per step: on the hook
    ("native-ish") path the composite hook does the phase scaling itself, so
    the un-phased scale is handed to it; only the plain path uses the
    locally phased scale.

    The edit-model hook path is re-based onto ``out_uncond`` so that it
    matches the plain formula
    ``out_uncond + s * delta + image_cfg_scale * (out_img_cond - out_uncond)``.

    The original methods are cached in module globals on first call so
    unpatch_cfg_denoiser() can restore them.

    Returns:
        True if the class was patched, False if the CFGDenoiser module could
        not be imported or patching raised.
    """
    global _CFGD_ORIG_COMBINE, _CFGD_ORIG_COMBINE_EDIT, _CFGD_ORIG_FORWARD, _CFGD_MONKEYPATCH_ACTIVE
    try:
        from modules import sd_samplers_cfg_denoiser as sd_cfg
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser patch import failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False

    cls = sd_cfg.CFGDenoiser
    # Capture the originals only once so re-patching never caches our own wrappers.
    if _CFGD_ORIG_COMBINE is None:
        _CFGD_ORIG_COMBINE = cls.combine_denoised
    if _CFGD_ORIG_COMBINE_EDIT is None:
        _CFGD_ORIG_COMBINE_EDIT = cls.combine_denoised_for_edit_model
    if _CFGD_ORIG_FORWARD is None:
        _CFGD_ORIG_FORWARD = cls.forward

    def _resolve_scales(denoiser, cond_scale):
        # Return (un-phased scale, phase-adjusted scale) for this step.
        base_scale = float(cond_scale)
        phased_scale = base_scale
        if ADEPT_STATE.get("phase_cfg_enabled", False):
            phased_scale = compute_phase_aware_cfg_scale(
                base_scale, _adept_cfg_progress_from_denoiser(denoiser),
                alpha=ADEPT_STATE.get("phase_cfg_alpha", 2.0),
                beta=ADEPT_STATE.get("phase_cfg_beta", 2.0),
            )
        return base_scale, phased_scale

    def _stashed_context(denoiser):
        # Return the (x, sigma) stashed by adept_forward, or (None, None).
        ctx = getattr(denoiser, _ADEPT_CFGDENOISER_CTX_ATTR, None)
        if ctx is None:
            return None, None
        return ctx.get("x", None), ctx.get("sigma", None)

    def _spectral_delta(delta):
        # Apply spectral modulation to a plain cond-uncond delta if enabled.
        if ADEPT_STATE.get("spectral_cfg_enabled", False):
            delta = apply_spectral_modulation_clybius(
                delta,
                multiplier=ADEPT_STATE.get("spectral_multiplier", 1.0),
                percentile=ADEPT_STATE.get("spectral_percentile", 5.0),
            )
        return delta

    # --- forward wrapper: stash x/sigma, then run original ---
    def adept_forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        setattr(self, _ADEPT_CFGDENOISER_CTX_ATTR, {
            "x": x,
            "sigma": sigma,
            "uncond": uncond,
            "cond": cond,
            "cond_scale": float(cond_scale),
        })
        try:
            return _CFGD_ORIG_FORWARD(self, x, sigma, uncond, cond,
                                      cond_scale, s_min_uncond, image_cond)
        finally:
            # Never leave stale tensors attached to the denoiser.
            try:
                delattr(self, _ADEPT_CFGDENOISER_CTX_ATTR)
            except AttributeError:
                pass

    # --- combine_denoised: per-cond native-ish or plain path ---
    def adept_combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)
        x_ctx, sigma_ctx = _stashed_context(self)
        base_scale, phased_scale = _resolve_scales(self, cond_scale)

        use_nativeish = (
            (ADEPT_STATE.get("spectral_cfg_enabled", False) or
             ADEPT_STATE.get("phase_cfg_enabled", False))
            and x_ctx is not None and sigma_ctx is not None
        )

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                cond_i = x_out[cond_index:cond_index + 1]
                uncond_i = denoised_uncond[i:i + 1]
                if use_nativeish:
                    # The composite hook phases the scale itself; passing the
                    # already-phased value would apply the factor twice.
                    term = _adept_nativeish_cfg_term(
                        x_i=x_ctx[i:i + 1],
                        sigma_i=sigma_ctx[i:i + 1],
                        cond_i=cond_i,
                        uncond_i=uncond_i,
                        scale=float(weight) * base_scale,
                    )
                    denoised[i:i + 1] += term
                else:
                    delta = _spectral_delta(cond_i - uncond_i)
                    denoised[i:i + 1] += delta * (float(weight) * phased_scale)
        return denoised

    # --- combine_denoised_for_edit_model: pix2pix / instruct path ---
    def adept_combine_denoised_for_edit_model(self, x_out, cond_scale):
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        x_ctx, sigma_ctx = _stashed_context(self)
        base_scale, phased_scale = _resolve_scales(self, cond_scale)
        image_term = self.image_cfg_scale * (out_img_cond - out_uncond)

        # Native-ish path when context is available
        if (ADEPT_STATE.get("spectral_cfg_enabled", False) or
                ADEPT_STATE.get("phase_cfg_enabled", False)):
            if x_ctx is not None and sigma_ctx is not None:
                try:
                    hook = create_combined_native_cfg_hook()
                    if hook is not None:
                        combined = hook({
                            "cond": out_cond,
                            "uncond": out_img_cond,
                            # Un-phased: the hook applies the phase factor.
                            "cond_scale": base_scale,
                            "sigma": sigma_ctx,
                            "input": x_ctx,
                        })
                        # The hook's result is based on its "uncond" input
                        # (out_img_cond here); keep only the guidance term
                        # and re-base it on out_uncond like the plain path.
                        text_term = combined - out_img_cond
                        return out_uncond + text_term + image_term
                except Exception as e:
                    print(f"⚠️ Adept edit-model native-ish fallback: {e}")

        # Plain path (no context or hook failed)
        delta = _spectral_delta(out_cond - out_img_cond)
        return out_uncond + phased_scale * delta + image_term

    try:
        cls.forward = adept_forward
        cls.combine_denoised = adept_combine_denoised
        cls.combine_denoised_for_edit_model = adept_combine_denoised_for_edit_model
        _CFGD_MONKEYPATCH_ACTIVE = True
        return True
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser patch failed: {e}")
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False
def unpatch_cfg_denoiser():
    """Restore the cached original CFGDenoiser methods.

    Only methods whose original was cached are restored. The patch-active
    flag is cleared on every path.

    Returns:
        True when restoration completed, False when the CFGDenoiser module
        could not be imported or restoring raised.
    """
    global _CFGD_MONKEYPATCH_ACTIVE
    try:
        from modules import sd_samplers_cfg_denoiser as sd_cfg
    except Exception:
        _CFGD_MONKEYPATCH_ACTIVE = False
        return False

    denoiser_cls = sd_cfg.CFGDenoiser
    restored = True
    try:
        if _CFGD_ORIG_FORWARD is not None:
            denoiser_cls.forward = _CFGD_ORIG_FORWARD
        if _CFGD_ORIG_COMBINE is not None:
            denoiser_cls.combine_denoised = _CFGD_ORIG_COMBINE
        if _CFGD_ORIG_COMBINE_EDIT is not None:
            denoiser_cls.combine_denoised_for_edit_model = _CFGD_ORIG_COMBINE_EDIT
    except Exception as e:
        print(f"⚠️ Adept CFGDenoiser unpatch failed: {e}")
        restored = False

    _CFGD_MONKEYPATCH_ACTIVE = False
    return restored
def configure_cfg_runtime():
    """
    Tear down any existing CFG integration and activate the appropriate one.

    Modes:
      off               – extension or all CFG features disabled; everything cleared
      a1111-postcfg     – stock A1111; Combat CFG Drift only via official callback
      a1111-monkeypatch – stock A1111; Spectral + Phase-Aware via CFGDenoiser patch
      native-hook       – backend with a sampler CFG hook; callbacks installed too

    The chosen mode is stored in ADEPT_STATE["cfg_runtime_mode"] and
    returned so the caller can log it.
    """
    def finish(mode):
        ADEPT_STATE["cfg_runtime_mode"] = mode
        return mode

    # Every path starts from a clean slate.
    uninstall_a1111_cfg_callbacks()
    uninstall_native_cfg_hook()
    unpatch_cfg_denoiser()

    if not ADEPT_STATE.get("enabled", False):
        return finish("off")

    drift = ADEPT_STATE.get("cfg_drift_enabled", False)
    spectral = ADEPT_STATE.get("spectral_cfg_enabled", False)
    phase = ADEPT_STATE.get("phase_cfg_enabled", False)
    if not (drift or spectral or phase):
        return finish("off")

    # Prefer native hook if backend supports it
    if _get_native_cfg_hook_target() is not None:
        install_a1111_cfg_callbacks()  # keeps drift working in native mode too
        install_native_cfg_hook()
        return finish("native-hook")

    # Stock A1111: always install callbacks (drift)
    install_a1111_cfg_callbacks()
    if spectral or phase:
        if patch_cfg_denoiser():
            ADEPT_STATE["cfg_runtime_mode"] = "a1111-monkeypatch"
            print("✅ Adept: stock A1111 CFGDenoiser monkey-patch active (spectral/phase + drift enabled)")
            return "a1111-monkeypatch"
        print("⚠️ Adept: CFGDenoiser monkey-patch failed; falling back to post-CFG drift only")
    return finish("a1111-postcfg")
def sa_solver_step(x, d_history, sigma, sigma_next, tau, s_noise=1.0, noise_sampler=None,
                   order=2, ndb_strength=0.0, progress=0.0, eqvae_mode=False, eqvae_blur_sigma=None):
    """One SA-Solver step from ``sigma`` to ``sigma_next``.

    A derivative estimate is built from ``d_history`` (a sequence of
    ``(sigma, derivative)`` pairs, newest last): a three-point blend when at
    least three entries exist and ``order >= 3`` (falling back to two-point
    if the step sizes are degenerate), a two-point blend with two entries
    and ``order >= 2``, the latest derivative with one entry, and zeros with
    none. ``tau`` blends the multistep weights towards using only the latest
    derivative (``tau = 1``).

    When ``tau > 0``, ``sigma_next > 0`` and a ``noise_sampler`` is given,
    the step is stochastic: noise with level ``tau`` times the ancestral
    amount is added, optionally with its high frequencies boosted (Native
    Detail Boost). Otherwise a plain deterministic step is taken.

    Returns:
        ``(x_next, sigma_up)`` where ``sigma_up`` is ``0.0`` for a
        deterministic step.
    """
    def two_point_blend(step_ratio, d_new, d_old):
        # Adams-Bashforth-2 weights blended with "latest only" by tau.
        blend = 1.0 - tau
        w_new = blend * (1.0 + 0.5 * step_ratio) + (1.0 - blend) * 1.0
        w_old = blend * (-0.5 * step_ratio)
        total = w_new + w_old
        if abs(total) > 1e-8:
            w_new /= total
            w_old /= total
        return w_new * d_new + w_old * d_old

    dt = sigma_next - sigma

    if len(d_history) >= 2 and order >= 2:
        sigma_cur, d_cur = d_history[-1]
        sigma_prev, d_prev = d_history[-2]
        h_prev = sigma_cur - sigma_prev
        r = abs(dt / (h_prev + 1e-8)) if abs(h_prev) > 1e-8 else 1.0
        r = min(r, 2.0)

        d_interp = None
        if len(d_history) >= 3 and order >= 3:
            sigma_0, d_0 = d_history[-3]
            h_0 = sigma_prev - sigma_0
            h_1 = h_prev
            if abs(h_0) > 1e-6 and abs(h_1) > 1e-6:
                r0 = min(abs(h_1 / h_0), 2.0)
                r1 = min(abs(dt / (h_1 + 1e-8)), 2.0)
                blend = 1.0 - tau
                c0 = blend * (1.0 + (1.0 + r0) * r1 / 2.0) + (1.0 - blend) * 1.0
                c1 = blend * (-(1.0 + r0) * r1 / 2.0)
                c2 = blend * (r0 * r1 / 2.0)
                total = c0 + c1 + c2
                if abs(total) > 1e-8:
                    c0 /= total
                    c1 /= total
                    c2 /= total
                else:
                    c0, c1, c2 = 1.0, 0.0, 0.0
                d_interp = c0 * d_cur + c1 * d_prev + c2 * d_0
        if d_interp is None:
            # Either second order was requested or the third-order step
            # sizes were degenerate.
            d_interp = two_point_blend(r, d_cur, d_prev)
    elif len(d_history) >= 1:
        d_interp = d_history[-1][1]
    else:
        d_interp = torch.zeros_like(x)

    stochastic = tau > 0 and sigma_next > 0 and noise_sampler is not None
    if not stochastic:
        return x + d_interp * dt, 0.0

    # Compute sigma_up based on tau (controls stochasticity)
    ancestral_sq = sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / (sigma ** 2 + 1e-8)
    ancestral = ancestral_sq ** 0.5 if ancestral_sq > 0 else 0.0
    sigma_up = tau * ancestral
    sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
    x_det = x + d_interp * (sigma_down - sigma)
    noise = noise_sampler(sigma, sigma_next) * s_noise * sigma_up

    # Apply Native Detail Boost if enabled
    if ndb_strength > 0 and TORCHVISION_AVAILABLE:
        if eqvae_mode:
            blur_sigma, boost = compute_eqvae_ndb(progress, ndb_strength)
        else:
            boost = compute_native_detail_boost(progress, ndb_strength)[1]
            blur_sigma = 0.5  # Default blur sigma
        # An explicitly supplied blur sigma wins over both defaults.
        if eqvae_blur_sigma is not None:
            blur_sigma = eqvae_blur_sigma
        try:
            blurred = gaussian_blur(noise, kernel_size=3, sigma=blur_sigma)
            noise = noise + (noise - blurred) * boost
        except Exception:
            pass  # Fallback: use original noise if blur fails

    return x_det + noise, sigma_up
def create_detail_enhanced_model(model, x, sigmas, settings):
    """Wrap ``model`` so every denoised prediction gets a high-frequency boost.

    NOTE: Detail Enhancement is currently an internal/experimental path.
    It is not wired into the UI and callers always pass
    use_detail_enhancement=False, so this function is never invoked at
    runtime. Kept for future re-integration; do not rely on it.

    Args:
        model: Callable ``model(x, sigma, **kwargs)`` returning a denoised tensor.
        x: Unused here; kept so the signature matches the call-sites.
        sigmas: Sigma schedule; only its length is used (to derive progress).
        settings: Mapping read for ``detail_enhancement_strength`` (default 0.05)
            and ``detail_separation_radius`` (default 0.5).

    Returns:
        ``model`` itself when torchvision is unavailable, otherwise a callable
        ``DetailEnhancer`` instance that proxies ``model``.
    """
    if not TORCHVISION_AVAILABLE:
        return model

    base_strength = settings.get('detail_enhancement_strength', 0.05)
    radius = settings.get('detail_separation_radius', 0.5)
    total_steps = len(sigmas) - 1

    class DetailEnhancer:
        """Callable proxy that counts how many boosted predictions it produced."""

        def __init__(self):
            self.current_step = 0

        def __call__(self, x_current, sigma, **kwargs):
            denoised = model(x_current, sigma, **kwargs)
            try:
                # High-frequency residual = prediction minus its blurred copy.
                blurred = gaussian_blur(denoised, kernel_size=3, sigma=radius)
                residual = denoised - blurred
                # Boost grows from 0.5x to 1.5x of the base strength over the run.
                fraction_done = min(self.current_step / max(total_steps, 1), 1.0)
                gain = base_strength * (0.5 + fraction_done)
                boosted = denoised + residual * gain
            except Exception:
                # Best effort: fall back to the plain prediction (step not counted).
                return denoised
            self.current_step += 1
            return boosted

    return DetailEnhancer()
# ============================================================================
# ============================================================================
# ============================================================================
# CUSTOM ADVANCED SAMPLERS (Complete port from ComfyUI)
# ============================================================================
@torch.no_grad()
def sample_adept_solver(model, x, sigmas, extra_args=None, callback=None, disable=None,
                        order=2, use_corrector=True, use_detail_enhancement=False, settings=None):
    """
    Adept Solver: A unified training-free diffusion solver synthesizing improvements from:
    - DPM-Solver++ (data prediction, dynamic thresholding)
    - UniPC (unified predictor-corrector framework)
    - DEIS (exponential integrator)
    - DC-Solver (dynamic compensation)

    Per step: evaluate the model, turn the prediction into a derivative with
    ``to_d``, sanitise it, build a multistep predictor from the derivative
    history, and optionally apply a trapezoidal corrector that costs one extra
    model call.

    Args:
        model: Denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: Initial latent.
        sigmas: Sigma schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: Extra keyword arguments for ``model``; ``cond_scale`` > 7
            additionally enables dynamic thresholding of the prediction.
        callback: Optional callable receiving a dict with ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Unused in this function.
        order: Multistep order, clamped to 1..3.
        use_corrector: Enable the corrector pass (skipped on the final step).
        use_detail_enhancement: Wrap the model with the detail enhancer
            (requires torchvision).
        settings: Settings mapping forwarded to the detail enhancer.

    Returns:
        The latent after the final step.

    Raises:
        RuntimeError: If the state becomes NaN/Inf on the first step, if the
            model itself returns NaN during recovery, or if recovery from the
            last finite state still yields a non-finite latent.
    """
    extra_args = {} if extra_args is None else extra_args
    settings = settings or {}
    s_in = x.new_ones([x.shape[0]])
    order = max(1, min(order, 3))
    print(f"🚀 Adept Solver active (Order: {order}, Corrector: {'On' if use_corrector else 'Off'})")
    active_model = model
    # use_detail_enhancement is always False from current call-sites;
    # the block below is preserved for future re-integration but is not
    # currently reachable via the UI.
    if use_detail_enhancement and TORCHVISION_AVAILABLE:
        active_model = create_detail_enhanced_model(model, x, sigmas, settings)
    model_outputs = []
    for i in range(len(sigmas) - 1):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        # Snapshot of the state entering this step. ``x`` is only ever rebound
        # (never mutated in place), so keeping a reference is enough. Used by
        # the NaN/Inf recovery path below.
        x_prev = x
        denoised = active_model(x, sigma * s_in, **extra_args)
        if extra_args.get('cond_scale', 1.0) > 7.0:
            denoised = apply_dynamic_thresholding(denoised, percentile=0.995)
        d = to_d(x, sigma, denoised)
        derivative_max = torch.abs(d).max()
        sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
        if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold:
            print(f"⚠️ Extreme derivative detected at step {i}/{len(sigmas)-1}. Clamping for stability.")
            d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold)
            # clamp() keeps NaN, so anything still non-finite is zeroed out.
            if torch.isnan(d).any() or torch.isinf(d).any():
                d = torch.zeros_like(d)
        # Keep at most ``order`` derivatives for the multistep predictor.
        model_outputs.append((sigma, d))
        if len(model_outputs) > order:
            model_outputs.pop(0)
        dt = sigma_next - sigma
        if len(model_outputs) == 1 or order == 1:
            # Plain Euler predictor.
            x_pred = x + d * dt
        elif len(model_outputs) == 2 and order >= 2:
            # Two-point extrapolation weighted by the compensation ratio.
            sigma_prev, d_prev = model_outputs[-2]
            d_cur = model_outputs[-1][1]
            h = sigma - sigma_prev
            compensation_ratio = compute_compensation_ratio(h.item() if torch.is_tensor(h) else float(h), i, len(sigmas))
            d_interp = d_cur + compensation_ratio * (d_cur - d_prev)
            x_pred = x + d_interp * dt
        else:
            sigma_0, d_0 = model_outputs[-3]
            sigma_1, d_1 = model_outputs[-2]
            sigma_2, d_2 = model_outputs[-1]
            h_0 = sigma_2 - sigma_1
            h_1 = sigma_1 - sigma_0
            h_0_val = h_0.item() if torch.is_tensor(h_0) else float(h_0)
            h_1_val = h_1.item() if torch.is_tensor(h_1) else float(h_1)
            if abs(h_1_val) < 1e-6:
                # Degenerate previous step size: fall back to the two-point rule.
                compensation_ratio = compute_compensation_ratio(h_0_val, i, len(sigmas))
                d_interp = d_2 + compensation_ratio * (d_2 - d_1)
            else:
                r0 = h_0_val / h_1_val
                c0 = 1.0 + r0 / 2.0
                c1 = -r0 / 2.0
                # c2 starts at 0, so d_0 only receives the rounding remainder
                # left after normalisation.
                c2 = 0.0
                c_sum = c0 + c1 + c2
                c0 /= c_sum
                c1 /= c_sum
                c2 = 1.0 - c0 - c1
                d_interp = c0 * d_2 + c1 * d_1 + c2 * d_0
            x_pred = x + d_interp * dt
        if use_corrector and i < len(sigmas) - 2:
            # Corrector: re-evaluate at the predicted point and average slopes.
            denoised_pred = active_model(x_pred, sigma_next * s_in, **extra_args)
            if extra_args.get('cond_scale', 1.0) > 7.0:
                denoised_pred = apply_dynamic_thresholding(denoised_pred, percentile=0.995)
            d_pred = to_d(x_pred, sigma_next, denoised_pred)
            if torch.isnan(d_pred).any() or torch.isinf(d_pred).any() or torch.abs(d_pred).max() > 1000.0:
                d_pred = torch.clamp(d_pred, -100.0, 100.0)
                if torch.isnan(d_pred).any() or torch.isinf(d_pred).any():
                    d_pred = torch.zeros_like(d_pred)
            dt = sigma_next - sigma
            x = x + (d + d_pred) * dt * 0.5
        else:
            x = x_pred
        if torch.isnan(x).any() or torch.isinf(x).any():
            print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!")
            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")
            print(" Attempting recovery with conservative Euler step...")
            # Recover from the last finite state. Stepping from the corrupted
            # ``x`` cannot work: NaN/Inf plus any finite update stays NaN/Inf.
            denoised_safe = active_model(x_prev, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - check CFG scale and model")
            d_safe = to_d(x_prev, sigma, denoised_safe)
            dt_safe = (sigma_next - sigma) * 0.5
            x = x_prev + d_safe * dt_safe
            if not torch.isfinite(x).all():
                raise RuntimeError("Recovery failed: latent is still NaN/Inf - check CFG scale and model")
            use_corrector = False
            print(" Recovery successful. Corrector disabled for stability.")
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
    return x
@torch.no_grad()
def sample_adept_ancestral_solver(model, x, sigmas, extra_args=None, callback=None, disable=None,
                                  eta=1.0, s_noise=1.0, adaptive_eta=False, phase_noise=False,
                                  phase_strength=0.5, enhanced_derivative=False,
                                  use_detail_enhancement=False, settings=None):
    """
    Enhanced Adept Ancestral Solver: Advanced ancestral sampling with phase-aware adaptations.
    Key innovations:
    1. Adaptive ancestral step sizing that changes throughout sampling phases
    2. Phase-aware noise injection (more noise early, less noise late)
    3. Enhanced derivative computation with ancestral-specific corrections
    4. Dynamic eta scheduling for better control

    Args:
        model: Denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: Initial latent.
        sigmas: Sigma schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: Extra keyword arguments for ``model``; ``cond_scale`` > 7
            additionally enables dynamic thresholding of the prediction.
        callback: Optional callable receiving a dict with ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Unused in this function.
        eta: Base ancestral noise coefficient.
        s_noise: Base multiplier for the injected noise.
        adaptive_eta: Scale ``eta`` by 1.08 / 0.95 / 1.02 in the early /
            middle / late part of the run.
        phase_noise: Modulate ``s_noise`` by a progress-dependent multiplier.
        phase_strength: Blend factor for the phase noise multiplier.
        enhanced_derivative: Use ``to_d_enhanced_ancestral`` instead of ``to_d``.
        use_detail_enhancement: Wrap the model with the detail enhancer
            (requires torchvision).
        settings: Settings mapping forwarded to the detail enhancer.

    Returns:
        The latent after the final step.

    Raises:
        RuntimeError: If the state becomes NaN/Inf on the first step, if the
            model itself returns NaN during recovery, or if recovery from the
            last finite state still yields a non-finite latent.
    """
    extra_args = {} if extra_args is None else extra_args
    settings = settings or {}
    s_in = x.new_ones([x.shape[0]])
    print(f"🚀 Enhanced Adept Ancestral Solver active (η: {eta:.2f}, s_noise: {s_noise:.2f})")
    print(f" Adaptive Eta: {adaptive_eta}, Phase Noise: {phase_noise}, Enhanced Derivative: {enhanced_derivative}")
    active_model = model
    # use_detail_enhancement is always False from current call-sites.
    if use_detail_enhancement and TORCHVISION_AVAILABLE:
        active_model = create_detail_enhanced_model(model, x, sigmas, settings)
    noise_sampler = get_noise_sampler(x)
    for i in range(len(sigmas) - 1):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        # Snapshot of the state entering this step. ``x`` is only ever rebound
        # (never mutated in place), so keeping a reference is enough. Used by
        # the NaN/Inf recovery path below.
        x_prev = x
        progress = i / max(len(sigmas) - 1, 1)
        if adaptive_eta:
            if progress < 0.3:
                current_eta = eta * 1.08
            elif progress < 0.7:
                current_eta = eta * 0.95
            else:
                current_eta = eta * 1.02
        else:
            current_eta = eta
        denoised = active_model(x, sigma * s_in, **extra_args)
        if extra_args.get('cond_scale', 1.0) > 7.0:
            denoised = apply_dynamic_thresholding(denoised, percentile=0.995)
        if enhanced_derivative:
            d = to_d_enhanced_ancestral(x, sigma, denoised, current_eta, progress, None)
        else:
            d = to_d(x, sigma, denoised)
        derivative_max = torch.abs(d).max()
        sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
        if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold:
            d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold)
            # clamp() keeps NaN, so anything still non-finite is zeroed out.
            if torch.isnan(d).any() or torch.isinf(d).any():
                d = torch.zeros_like(d)
        if sigma_next > 0:
            # Split the move to sigma_next into a deterministic part
            # (sigma_down) and freshly injected noise (sigma_up).
            sigma_up = min(sigma_next, current_eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
            sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
        else:
            sigma_up = 0.0
            sigma_down = 0.0
        dt = sigma_down - sigma
        x_pred = x + d * dt
        if sigma_next > 0:
            if phase_noise:
                if progress < 0.25:
                    target_multiplier = 1.0 + (0.05 * min(progress / 0.25, 1.0))
                elif progress < 0.6:
                    target_multiplier = 1.0 - (0.02 * min((progress - 0.25) / 0.35, 1.0))
                else:
                    target_multiplier = 1.0 - (0.05 * min((progress - 0.6) / 0.4, 1.0))
                noise_multiplier = 1.0 + (target_multiplier - 1.0) * phase_strength
                adaptive_s_noise = s_noise * noise_multiplier
            else:
                adaptive_s_noise = s_noise
            noise = noise_sampler(sigma, sigma_next) * adaptive_s_noise * sigma_up
            x = x_pred + noise
        else:
            x = x_pred
        if torch.isnan(x).any() or torch.isinf(x).any():
            print(f"❌ CRITICAL: NaN/Inf detected at step {i}/{len(sigmas)-1}!")
            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")
            print(" Attempting recovery...")
            # Recover from the last finite state. Stepping from the corrupted
            # ``x`` cannot work: NaN/Inf plus any finite update stays NaN/Inf.
            denoised_safe = active_model(x_prev, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - check CFG scale and model")
            d_safe = to_d(x_prev, sigma, denoised_safe)
            dt_safe = (sigma_next - sigma) * 0.5
            x = x_prev + d_safe * dt_safe
            if not torch.isfinite(x).all():
                raise RuntimeError("Recovery failed: latent is still NaN/Inf - check CFG scale and model")
            print(" Recovery successful.")
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
    return x
@torch.no_grad()
def sample_mirror_correction_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
                                   eta=1.0, s_noise=1.0, correction_phase=0.5, smooth_phase=False):
    """
    Mirror Correction Euler: Euler Ancestral with a semantic reflection probe.
    In the first `correction_phase` fraction of steps, uses a 3-call Heun correction:
    x_probe = 2*D(x) - x (reflection of x through its own denoised prediction)
    The probe lies on the denoising trajectory, giving a curvature estimate for the
    Heun correction. Remaining steps: standard 1-call Euler Ancestral.
    Args:
    eta: Ancestral noise coefficient. 0=deterministic, 1=full ancestral. Default: 1.0
    s_noise: Noise scale multiplier. Default: 1.0
    correction_phase: Fraction of steps that receive the 3-call correction. Default: 0.5
    smooth_phase: Use continuous log-sigma weighting instead of a binary cutoff. Default: False

    The callback (if given) is invoked right after the first model call of each
    step, i.e. with the latent as it was before that step's update.
    `disable` is not used by this function.
    Returns the latent after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    print(f"🔮 Mirror Correction Euler active (η: {eta:.2f}, s_noise: {s_noise:.2f})")
    print(f" Correction Phase: {correction_phase:.2f}, Smooth Phase: {smooth_phase}")
    noise_sampler = get_noise_sampler(x)
    n_steps = len(sigmas) - 1
    # Smooth-phase bookkeeping: log-sigma at the start of the schedule and at
    # the step index where the correction phase ends.
    log_sigma_phase = None
    log_sigma_max = None
    smooth_denom = 1e-6
    if smooth_phase and n_steps > 0:
        sigma_max_val = sigmas[0].clamp(min=1e-6)
        phase_idx = min(int(correction_phase * n_steps), n_steps - 1)
        sigma_phase_val = sigmas[phase_idx].clamp(min=1e-6)
        log_sigma_max = torch.log(sigma_max_val).item()
        log_sigma_phase = torch.log(sigma_phase_val).item()
        # Floor the denominator so the weight below never divides by ~0.
        smooth_denom = max(log_sigma_max - log_sigma_phase, 1e-6)
    for i in range(n_steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        progress = i / max(n_steps - 1, 1)
        denoised = model(x, sigma * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        if sigma_next > 0:
            # Ancestral split: sigma_down is the deterministic target,
            # sigma_up the amount of fresh noise added afterwards.
            sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
            sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
        else:
            sigma_up = 0.0
            sigma_down = 0.0
        dt = sigma_down - sigma
        if smooth_phase and log_sigma_phase is not None:
            # Continuous weighting: 1 at the first sigma, falling to 0 at the
            # phase-boundary sigma (square-root shaped in normalised log-sigma).
            log_sig = torch.log(sigma.clamp(min=1e-6)).item()
            t = max(0.0, min(1.0, (log_sig - log_sigma_phase) / smooth_denom))
            correction_weight = t ** 0.5
            if correction_weight > 1e-3 and sigma_next > 0:
                # Probe: reflect x through its denoised prediction (2nd model call).
                x_probe = 2 * denoised - x
                d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args))
                # Agreement in [0, 1]: 1 minus the relative difference between
                # the two derivatives, floored at 0.
                d_diff_norm = (d - d_probe).norm()
                d_scale = (d.norm() + d_probe.norm()) / 2 + 1e-6
                gradient_agreement = max(0.0, 1.0 - (d_diff_norm / d_scale).item())
                effective_weight = correction_weight * gradient_agreement
                if effective_weight > 1e-3:
                    # 3rd model call at the trial point reached with the averaged slope.
                    x3 = x + ((d + d_probe) / 2) * dt
                    d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args))
                    d_heun = (d + d3) / 2
                    # Only blend towards the Heun slope when it is finite.
                    if not (torch.isnan(d_heun).any() or torch.isinf(d_heun).any()):
                        d = d + effective_weight * (d_heun - d)
        else:
            # Binary cutoff: full-strength correction while progress < correction_phase.
            if progress < correction_phase and sigma_next > 0:
                x_probe = 2 * denoised - x
                d_probe = to_d(x_probe, sigma, model(x_probe, sigma * s_in, **extra_args))
                x3 = x + ((d + d_probe) / 2) * dt
                d3 = to_d(x3, sigma, model(x3, sigma * s_in, **extra_args))
                d = (d + d3) / 2
        # Last-resort guard: a non-finite derivative turns the step into a no-op.
        if torch.isnan(d).any() or torch.isinf(d).any():
            d = torch.zeros_like(d)
        x = x + d * dt
        if sigma_next > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
    return x
@torch.no_grad()
def sample_akashic_solver(model, x, sigmas, extra_args=None, callback=None, disable=None,
                          tau=0.5, eta=1.0, s_noise=1.0, adaptive_eta=True, phase_strength=0.5,
                          order=2, smea_strength=0.0, ndb_strength=0.0,
                          use_detail_enhancement=False, settings=None, eqvae_mode='Off'):
    """
    AkashicSolver v2 [EXPERIMENTAL]: Advanced sampler optimized for EQ-VAE models.
    Combines:
    1. SA-SOLVER BASE: Multi-step Adams-Bashforth integration with tau function
    2. PHASE-AWARE SAMPLING: Three-phase approach with adaptive parameters
    3. SMEA COHERENCY: Sine-based interpolation for high-resolution coherency

    Args:
        model: Denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: Initial latent.
        sigmas: Sigma schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: Extra keyword arguments for ``model``; ``cond_scale`` > 7
            additionally enables dynamic thresholding of the prediction.
        callback: Optional callable receiving a dict with ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Unused in this function.
        tau: Base stochasticity control passed to ``sa_solver_step``.
        eta: Base noise coefficient (folded into the effective ``s_noise``).
        s_noise: Base multiplier for the injected noise.
        adaptive_eta: Make tau and eta depend on sampling progress.
        phase_strength: Strength of the progress-dependent adjustments.
        order: Number of past derivatives kept for the multistep update.
        smea_strength: Strength passed to ``compute_smea_factor``.
        ndb_strength: Native Detail Boost strength passed to ``sa_solver_step``.
        use_detail_enhancement: Wrap the model with the detail enhancer
            (requires torchvision).
        settings: Settings mapping forwarded to the detail enhancer.
        eqvae_mode: EQ-VAE optimization mode ('Off' or 'Balanced'); a bool is
            also accepted and used directly.

    Returns:
        The latent after the final step.

    Raises:
        RuntimeError: If the state becomes NaN/Inf on the first step, if the
            model itself returns NaN during recovery, or if recovery from the
            last finite state still yields a non-finite latent.
    """
    extra_args = {} if extra_args is None else extra_args
    settings = settings or {}
    s_in = x.new_ones([x.shape[0]])
    if isinstance(eqvae_mode, bool):
        eqvae_enabled = eqvae_mode
    else:
        eqvae_enabled = eqvae_mode == 'Balanced'
    if eqvae_enabled:
        print(f"🌀 AkashicSolver v2 [EQ-VAE BALANCED] active")
        print(f" Optimized for EQ-VAE's cleaner latent space")
    else:
        print(f"🌀 AkashicSolver v2 [EXPERIMENTAL] active")
    print(f" τ (tau): {tau:.2f}, η (eta): {eta:.2f}, s_noise: {s_noise:.2f}")
    print(f" Order: {order}, Adaptive Eta: {adaptive_eta}, Phase Strength: {phase_strength:.2f}")
    if smea_strength > 0:
        print(f" SMEA: {smea_strength:.2f} (high-res coherency)")
    if ndb_strength > 0:
        print(f" Native Detail Boost: {ndb_strength:.2f} (detail enhancement)")
    if not eqvae_enabled:
        print(f" ⚠️ Use external rescaleCFG (e.g., 0.7) for EQ-VAE models")
    active_model = model
    # use_detail_enhancement is always False from current call-sites.
    if use_detail_enhancement and TORCHVISION_AVAILABLE:
        active_model = create_detail_enhanced_model(model, x, sigmas, settings)
    noise_sampler = get_noise_sampler(x)
    total_steps = len(sigmas) - 1
    d_history = []
    for i in range(total_steps):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        # Snapshot of the state entering this step. ``x`` is only ever rebound
        # in this loop, so keeping a reference is enough. Used by the NaN/Inf
        # recovery path below.
        x_prev = x
        progress = i / max(total_steps - 1, 1)
        # Progress-dependent tau.
        if adaptive_eta:
            if eqvae_enabled:
                current_tau = compute_eqvae_tau(progress, tau, phase_strength)
            else:
                current_tau = compute_tau_eqvae(progress, tau, phase_strength)
        else:
            current_tau = tau
        # Progress-dependent eta; the EQ-VAE mode uses gentler phase factors.
        if adaptive_eta:
            if eqvae_enabled:
                if progress < 0.25:
                    current_eta = eta * (1.0 + 0.03 * phase_strength)
                elif progress < 0.55:
                    current_eta = eta * (1.0 - 0.03 * phase_strength)
                else:
                    current_eta = eta * (1.0 + 0.02 * phase_strength)
            else:
                if progress < 0.30:
                    current_eta = eta * (1.0 + 0.08 * phase_strength)
                elif progress < 0.60:
                    current_eta = eta * (1.0 - 0.05 * phase_strength)
                else:
                    current_eta = eta * (1.0 + 0.02 * phase_strength)
        else:
            current_eta = eta
        smea_factor = compute_smea_factor(progress, smea_strength)
        denoised = active_model(x, sigma * s_in, **extra_args)
        cfg_scale = extra_args.get('cond_scale', 1.0)
        if cfg_scale > 7.0:
            denoised = apply_dynamic_thresholding(denoised, percentile=0.995)
        d = to_d(x, sigma, denoised)
        derivative_max = torch.abs(d).max()
        sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
        if torch.isnan(d).any() or torch.isinf(d).any() or derivative_max > sigma_adaptive_threshold:
            d = torch.clamp(d, -sigma_adaptive_threshold, sigma_adaptive_threshold)
            # clamp() keeps NaN, so anything still non-finite is zeroed out.
            if torch.isnan(d).any() or torch.isinf(d).any():
                d = torch.zeros_like(d)
        # Keep at most ``order`` derivatives for the multistep update.
        d_history.append((sigma, d))
        if len(d_history) > order:
            d_history.pop(0)
        effective_tau = current_tau
        if eqvae_enabled:
            effective_s_noise = compute_eqvae_noise_scale(s_noise * current_eta, progress) * smea_factor
        else:
            effective_s_noise = s_noise * current_eta * smea_factor
            if progress < 0.30:
                noise_multiplier = 1.0 + 0.03 * phase_strength
            elif progress < 0.60:
                noise_multiplier = 1.0 - 0.01 * phase_strength
            else:
                noise_multiplier = 1.0 - 0.02 * phase_strength
            effective_s_noise *= noise_multiplier
        if eqvae_enabled and ndb_strength > 0:
            eqvae_blur_sigma, _ = compute_eqvae_ndb(progress, ndb_strength)
        else:
            eqvae_blur_sigma = None
        # The second return value (sigma_up) is not needed here.
        x, _ = sa_solver_step(
            x=x,
            d_history=d_history,
            sigma=sigma,
            sigma_next=sigma_next,
            tau=effective_tau,
            s_noise=effective_s_noise,
            noise_sampler=noise_sampler,
            order=order,
            ndb_strength=ndb_strength,
            progress=progress,
            eqvae_mode=eqvae_enabled,
            eqvae_blur_sigma=eqvae_blur_sigma
        )
        if torch.isnan(x).any() or torch.isinf(x).any():
            print(f"❌ AkashicSolver v2: NaN/Inf detected at step {i}/{total_steps}!")
            if i == 0:
                raise RuntimeError("NaN/Inf on first step - check model/inputs")
            print(" Attempting recovery...")
            # Recover from the last finite state. Stepping from the corrupted
            # ``x`` cannot work: NaN/Inf plus any finite update stays NaN/Inf.
            denoised_safe = active_model(x_prev, sigma * s_in, **extra_args)
            if torch.isnan(denoised_safe).any():
                raise RuntimeError("Model producing NaN - reduce CFG scale or check model")
            d_safe = to_d(x_prev, sigma, denoised_safe)
            dt_safe = (sigma_next - sigma) * 0.5
            x = x_prev + d_safe * dt_safe
            if not torch.isfinite(x).all():
                raise RuntimeError("Recovery failed: latent is still NaN/Inf - reduce CFG scale or check model")
            d_history.clear()
            print(" Recovery successful. Multi-step history cleared.")
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
    return x
# ============================================================================
# ============================================================================
# ============================================================================
# This file continues from PART1 + PART1B
# ============================================================================
# ============================================================================
# WEIGHT PATCHER (from v3 - unchanged)
# ============================================================================
def should_patch_weights(unet_model, scale, shift):
    """Tell whether patching would change anything.

    Patching is pointless without a model, or when ``scale`` is (within 1e-6)
    the identity 1.0 and ``shift`` is (within 1e-6) zero.
    """
    if unet_model is None:
        return False
    scale_is_active = abs(scale - 1.0) > 1e-6
    shift_is_active = abs(shift) > 1e-6
    return scale_is_active or shift_is_active
class AdeptWeightPatcher:
    """Temporary weight scaling for UNet.

    Context manager that rewrites the weights of every ``Conv2d`` / ``Linear``
    layer of ``unet_model`` to ``weight * scale + shift`` on entry and copies
    the saved originals back on exit. With no model, or with an identity
    transform (scale ~ 1, shift ~ 0), entering is a no-op.
    """

    def __init__(self, unet_model, scale=1.0, shift=0.0):
        self.unet_model = unet_model
        self.scale = scale
        self.shift = shift
        # name -> clone of the original weight tensor
        self.backups = {}
        # (name, module) pairs that were (or were about to be) patched
        self.target_layers = []

    def __enter__(self):
        if self.unet_model is None or (abs(self.scale - 1.0) < 1e-6 and abs(self.shift) < 1e-6):
            return self
        self.target_layers.clear()
        self.backups.clear()
        try:
            for name, module in self.unet_model.named_modules():
                if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                    continue
                if getattr(module, 'weight', None) is None:
                    continue
                self.target_layers.append((name, module))
                self.backups[name] = module.weight.data.clone()
                module.weight.data = module.weight.data * self.scale + self.shift
        except Exception as e:
            # Roll back whatever was already patched; the caller still gets a
            # usable (unpatched) model.
            print(f"❌ Weight patcher failed: {e}")
            self.__exit__(None, None, None)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore each layer independently so that one failing layer cannot
        # prevent the remaining layers from being restored.
        failed = set()
        for name, module in self.target_layers:
            backup = self.backups.get(name)
            if backup is None:
                continue
            try:
                module.weight.data.copy_(backup)
            except Exception as e:
                failed.add(name)
                print(f"❌ CRITICAL: Failed to restore weights: {e}")
        # Drop everything that was restored; keep failed entries so a later
        # __exit__ call can retry them.
        self.target_layers[:] = [(n, m) for n, m in self.target_layers if n in failed]
        for name in [n for n in self.backups if n not in failed]:
            del self.backups[name]
        # Never suppress exceptions raised inside the ``with`` body.
        return False
# ============================================================================
# VAE REFLECTION PATCHER (from v3 - unchanged)
# ============================================================================
class VAEReflectionPatcher:
    """Switch every Conv2d of a VAE to ``'reflect'`` padding for the ``with`` body.

    The original padding modes and the "currently patched" flag live in the
    module-level ``_vae_original_padding_modes`` / ``_vae_reflection_active``
    globals, not on the instance.
    """

    def __init__(self, vae_model):
        self.vae_model = vae_model
        self.backups = {}

    def __enter__(self):
        global _vae_reflection_active, _vae_original_padding_modes
        nothing_to_do = _vae_reflection_active or self.vae_model is None
        if nothing_to_do:
            return self

        _vae_original_padding_modes.clear()
        patched_count = 0
        try:
            conv_layers = (
                (layer_name, layer)
                for layer_name, layer in self.vae_model.named_modules()
                if isinstance(layer, torch.nn.Conv2d)
            )
            for layer_name, conv in conv_layers:
                # Remember the old mode first so a rollback can undo this layer.
                _vae_original_padding_modes[layer_name] = conv.padding_mode
                conv.padding_mode = 'reflect'
                patched_count += 1
            _vae_reflection_active = True
            print(f"🪞 VAE Reflection: Patched {patched_count} Conv2d layers")
        except Exception as e:
            print(f"❌ VAE Reflection failed: {e}")
            # Undo whatever part of the patch was already applied.
            self.__exit__(None, None, None)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _vae_reflection_active, _vae_original_padding_modes
        if self.vae_model is None:
            # No model to touch: just reset the shared bookkeeping.
            _vae_reflection_active = False
            _vae_original_padding_modes.clear()
            return False

        restored_count = 0
        try:
            for layer_name, layer in self.vae_model.named_modules():
                if not isinstance(layer, torch.nn.Conv2d):
                    continue
                if layer_name not in _vae_original_padding_modes:
                    continue
                layer.padding_mode = _vae_original_padding_modes[layer_name]
                restored_count += 1
            _vae_reflection_active = False
            _vae_original_padding_modes.clear()
            print(f"🔄 VAE Reflection: Restored {restored_count} layers")
        except Exception as e:
            print(f"⚠️ VAE Reflection restore warning: {e}")
        # Exceptions from the ``with`` body are never suppressed.
        return False
def force_restore_vae_reflection():
    """
    Emergency / unload-path restore for VAE padding modes.

    Returns immediately when reflection is not flagged active and no padding
    modes are recorded. Otherwise restores the recorded padding mode on each
    Conv2d of ``shared.sd_model.first_stage_model`` (when that model is
    reachable) and always resets the module-level bookkeeping afterwards,
    even if the restore itself raised.
    """
    global _vae_reflection_active, _vae_original_padding_modes
    if not (_vae_reflection_active or _vae_original_padding_modes):
        return

    try:
        sd = getattr(shared, "sd_model", None)
        vae = getattr(sd, "first_stage_model", None) if sd else None
        if vae is not None:
            restored = 0
            for layer_name, layer in vae.named_modules():
                if not isinstance(layer, torch.nn.Conv2d):
                    continue
                if layer_name not in _vae_original_padding_modes:
                    continue
                layer.padding_mode = _vae_original_padding_modes[layer_name]
                restored += 1
            if restored:
                print(f"🔄 VAE Reflection: force-restored {restored} layers")
    except Exception as e:
        print(f"⚠️ VAE Reflection force-restore warning: {e}")
    finally:
        _vae_reflection_active = False
        _vae_original_padding_modes.clear()
# ALL SCHEDULERS (18 types)
# ============================================================================
def create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AOS-V (Anime-Optimized Schedule for v-prediction models).

    Builds a three-segment ramp in [0, 1] (square-root rise to 0.6 over the
    first 20% of steps, linear 0.6 -> 0.9 up to 60%, cubic 0.9 -> 1.0 for the
    rest), maps it through the rho=7 interpolation between ``sigma_max`` and
    ``sigma_min``, and appends a terminating 0.
    """
    rho = 7.0
    phase1_end = int(num_steps * 0.2)
    phase2_end = int(num_steps * 0.6)
    ramp = torch.empty(num_steps, device=device, dtype=torch.float32)

    # Each segment is written straight into its slice (a view) of ``ramp``.
    if phase1_end > 0:
        head = ramp[:phase1_end]
        torch.linspace(0, 1, phase1_end, out=head)
        head.pow_(0.5).mul_(0.6)
    if phase2_end > phase1_end:
        torch.linspace(0.6, 0.9, phase2_end - phase1_end, out=ramp[phase1_end:phase2_end])
    if num_steps > phase2_end:
        tail = ramp[phase2_end:]
        torch.linspace(0, 1, num_steps - phase2_end, out=tail)
        tail.pow_(3).mul_(0.1).add_(0.9)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    # In-place: ramp -> (hi + ramp * (lo - hi)) ** rho
    ramp.mul_(lo - hi)
    ramp.add_(hi)
    ramp.pow_(rho)
    terminator = torch.zeros(1, device=device)
    return torch.cat([ramp, terminator])
def create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""AOS-ε (Anime-Optimized Schedule for epsilon-prediction models)."""
rho = 7.0
p1_frac, p2_frac = 0.35, 0.7
ramp_p1_val, ramp_p2_val = 0.4, 0.75
p1_steps = int(num_steps * p1_frac)
p2_steps = int(num_steps * p2_frac)
phase1_ramp = torch.linspace(0, 1, p1_steps, device=device) ** 1.5 * ramp_p1_val
phase2_ramp = torch.linspace(ramp_p1_val, ramp_p2_val, p2_steps - p1_steps, device=device)
phase3_base = torch.linspace(0, 1, num_steps - p2_steps, device=device) ** 0.7
phase3_ramp = phase3_base * (1 - ramp_p2_val) + ramp_p2_val
if p1_steps == 0: phase1_ramp = torch.empty(0, device=device)
if p2_steps - p1_steps == 0: phase2_ramp = torch.empty(0, device=device)
if num_steps - p2_steps == 0: phase3_ramp = torch.empty(0, device=device)
ramp = torch.cat([phase1_ramp, phase2_ramp, phase3_ramp])
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AkashicAOS v2: Detail-Progressive Schedule for EQ-VAE SDXL models.

    A power-0.85 ramp plus a sine-shaped mid boost is rescaled to [0, 1] and
    mapped through the rho=7 interpolation. A sequential pass then enforces a
    strictly decreasing schedule and caps the ratio between consecutive sigmas
    at 1.5. A terminating 0 is appended.
    """
    rho = 7.0
    detail_power = 0.85
    mid_boost_strength = 0.08
    max_ratio = 1.5

    u = torch.linspace(0, 1, num_steps, device=device)
    u_modulated = u ** detail_power + mid_boost_strength * torch.sin(math.pi * u) * (1 - u * 0.5)

    # Rescale to [0, 1] unless the ramp is (numerically) constant.
    lowest, highest = u_modulated.min(), u_modulated.max()
    if highest - lowest > 1e-8:
        u_modulated = (u_modulated - lowest) / (highest - lowest)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    sigmas = (hi + u_modulated * (lo - hi)) ** rho

    # Sequential fix-up: each entry is checked against its already-fixed predecessor.
    for idx in range(1, len(sigmas)):
        before = idx - 1
        if sigmas[idx] >= sigmas[before]:
            sigmas[idx] = sigmas[before] * 0.995
        if sigmas[before] / sigmas[idx] > max_ratio:
            sigmas[idx] = sigmas[before] / max_ratio

    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device='cpu'):
    """Entropic power schedule.

    The interpolation ramp is the average of a linear ramp and
    ``1 - (1 - u) ** power``; it is mapped through the rho=7 interpolation
    between ``sigma_max`` and ``sigma_min``, followed by a terminating 0.
    """
    rho = 7.0
    rising = torch.linspace(0, 1, num_steps, device=device)
    falling = torch.linspace(1, 0, num_steps, device=device)
    ramp = (rising + (1 - falling ** power)) / 2.0

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + ramp * (lo - hi)) ** rho
    terminator = torch.zeros(1, device=device)
    return torch.cat([schedule, terminator])
def create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Schedule optimized around log SNR = 0 region.

    Interpolates ``2 * log(sigma)`` between its values at ``sigma_max`` and
    ``sigma_min`` using a 70/30 blend of a sigmoid and a linear ramp, then
    appends a terminating 0.

    Args:
        sigma_max: Largest sigma; a Python number or a tensor.
        sigma_min: Smallest sigma; a Python number or a tensor.
        num_steps: Number of non-zero sigmas to generate.
        device: Device for the generated tensors.

    Returns:
        1-D tensor with ``num_steps + 1`` entries, the last one being 0.
    """
    # Accept plain floats as the sibling schedulers do; torch.log() rejects
    # non-tensor input.
    sigma_max = torch.as_tensor(sigma_max, device=device)
    sigma_min = torch.as_tensor(sigma_min, device=device)
    log_snr_max = 2 * torch.log(sigma_max)
    log_snr_min = 2 * torch.log(sigma_min)
    t = torch.linspace(0, 1, num_steps, device=device)
    concentration_power = 3.0
    sigmoid_t = torch.sigmoid(concentration_power * (t - 0.5))
    linear_t = t
    blend_factor = 0.7
    # NOTE(review): sigmoid(±1.5) is ~0.18 / ~0.82, so combined_t spans roughly
    # [0.128, 0.872] rather than [0, 1]; the first and last non-zero sigmas
    # therefore lie strictly inside (sigma_min, sigma_max). Confirm intended.
    combined_t = blend_factor * sigmoid_t + (1 - blend_factor) * linear_t
    log_snr = log_snr_max + combined_t * (log_snr_min - log_snr_max)
    sigmas = torch.exp(log_snr / 2)
    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Constant rate of distributional change.

    Warps a linear ramp ``u`` with ``u + 0.3 * sin(pi * u) * (1 - u)`` and
    feeds it to the rho=7 interpolation between ``sigma_max`` and
    ``sigma_min``; a terminating 0 is appended.
    """
    rho = 7.0
    u = torch.linspace(0, 1, num_steps, device=device)
    warped = u + 0.3 * torch.sin(math.pi * u) * (1 - u)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + warped * (lo - hi)) ** rho
    terminator = torch.zeros(1, device=device)
    return torch.cat([schedule, terminator])
def create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Adaptive schedule combining multiple strategies.

    Four warps of a linear ramp (identity, power 0.8, sine-corrected, logistic)
    are mixed with weights 0.2 / 0.3 / 0.2 / 0.3, rescaled to [0, 1] when the
    mix is not constant, and mapped through the rho=7 interpolation between
    ``sigma_max`` and ``sigma_min``. A terminating 0 is appended.
    """
    rho = 7.0
    base_t = torch.linspace(0, 1, num_steps, device=device)

    def _identity(t):
        return t

    def _power(t):
        return t ** 0.8

    def _sine_corrected(t):
        return t + 0.2 * torch.sin(2 * math.pi * t) * (1 - t)

    def _logistic(t):
        return 1 / (1 + torch.exp(-3 * (t - 0.5)))

    weighted_warps = (
        (0.2, _identity),
        (0.3, _power),
        (0.2, _sine_corrected),
        (0.3, _logistic),
    )
    # Accumulate left to right starting from 0, exactly like the built-in sum().
    combined_t = 0
    for weight, warp in weighted_warps:
        combined_t = combined_t + weight * warp(base_t)

    if (combined_t.max() - combined_t.min()) > 1e-6:
        combined_t = (combined_t - combined_t.min()) / (combined_t.max() - combined_t.min())

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + combined_t * (lo - hi)) ** rho
    return torch.cat([schedule, torch.zeros(1, device=device)])
def create_cosine_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Cosine-annealed schedule.

    Uses the half-cosine ramp ``(1 - cos(pi * u)) / 2`` as the interpolation
    parameter of the rho=7 mapping between ``sigma_max`` and ``sigma_min``,
    then appends a terminating 0.
    """
    rho = 7.0
    linear = torch.linspace(0, 1, num_steps, device=device)
    annealed = (1 - torch.cos(math.pi * linear)) / 2

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + annealed * (lo - hi)) ** rho
    terminator = torch.zeros(1, device=device)
    return torch.cat([schedule, terminator])
def create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Uniform in log-SNR space.

    Interpolates ``2 * log(sigma)`` linearly from its value at ``sigma_max``
    to its value at ``sigma_min`` (i.e. the sigmas form a geometric sequence)
    and appends a terminating 0.

    Args:
        sigma_max: Largest sigma; a Python number or a tensor.
        sigma_min: Smallest sigma; a Python number or a tensor.
        num_steps: Number of non-zero sigmas to generate.
        device: Device for the generated tensors.

    Returns:
        1-D tensor with ``num_steps + 1`` entries, the last one being 0.
    """
    # Accept plain floats as the sibling schedulers do; torch.log() rejects
    # non-tensor input.
    sigma_max = torch.as_tensor(sigma_max, device=device)
    sigma_min = torch.as_tensor(sigma_min, device=device)
    u = torch.linspace(0, 1, num_steps, device=device)
    log_snr_max = 2 * torch.log(sigma_max)
    log_snr_min = 2 * torch.log(sigma_min)
    log_snr = log_snr_max + u * (log_snr_min - log_snr_max)
    sigmas = torch.exp(log_snr / 2)
    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device='cpu', k=4.0):
    """Concentrate steps near mid-range sigmas.

    A linear ramp is passed through a tanh S-curve (steepness ``k``) that is
    normalised by ``tanh(k / 2)``; the result drives the rho=7 interpolation
    between ``sigma_max`` and ``sigma_min``. A terminating 0 is appended.
    """
    rho = 7.0
    u = torch.linspace(0, 1, num_steps, device=device)
    steepness = torch.tensor(k, device=device, dtype=u.dtype)
    s_curve = torch.tanh(steepness * (u - 0.5)) / torch.tanh(steepness / 2)
    t = 0.5 * (s_curve + 1.0)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + t * (lo - hi)) ** rho
    terminator = torch.zeros(1, device=device)
    return torch.cat([schedule, terminator])
def create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device='cpu', pivot=0.7, gamma=0.8, beta=5.0):
    """Faster early lock-in with extra resolution in final steps.

    Below ``pivot`` the linear ramp is warped by a power curve (exponent
    ``gamma``); from ``pivot`` onwards it approaches 1 exponentially with rate
    ``beta``. The warped ramp drives the rho=7 interpolation between
    ``sigma_max`` and ``sigma_min``; a terminating 0 is appended.
    """
    rho = 7.0
    u = torch.linspace(0, 1, num_steps, device=device)
    is_early = u < pivot
    is_late = ~is_early

    warped = torch.empty_like(u)
    # Early part: power curve scaled so that it reaches ``pivot`` at u == pivot.
    warped[is_early] = (u[is_early] / pivot) ** gamma * pivot
    # Late part: exponential saturation towards 1.
    tail_u = u[is_late]
    warped[is_late] = pivot + (1 - pivot) * (1 - torch.exp(-beta * (tail_u - pivot) / (1 - pivot)))

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + warped * (lo - hi)) ** rho
    return torch.cat([schedule, torch.zeros(1, device=device)])
def create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Karras schedule with controlled jitter.

    Step positions are the mid-points ``(i + 0.5) / (num_steps - 1)`` shifted
    by a deterministic sine-based jitter (so the result is reproducible),
    clamped to [0, 1], and mapped through the rho=7 interpolation between
    ``sigma_max`` and ``sigma_min``. A terminating 0 is appended.

    Args:
        sigma_max: Largest sigma; a Python number or a tensor.
        sigma_min: Smallest sigma; a Python number or a tensor.
        num_steps: Number of non-zero sigmas; for values <= 0 the result is
            ``[sigma_max, 0]``.
        device: Device for the generated tensors.

    Returns:
        1-D tensor of sigmas ending with 0.
    """
    if num_steps <= 0:
        # as_tensor lets a plain float sigma_max through (it has no .unsqueeze).
        start = torch.as_tensor(sigma_max, device=device).reshape(1)
        return torch.cat([start, torch.zeros(1, device=device)])
    rho = 7.0
    indices = torch.arange(num_steps, device=device, dtype=torch.float32)
    denom = max(1, num_steps - 1)
    base = (indices + 0.5) / denom
    # Deterministic pseudo-jitter from a fixed angular increment.
    jitter_seed = torch.sin((indices + 1) * 2.3999632)
    jitter_strength = 0.35
    jitter = jitter_seed * jitter_strength / denom
    u = torch.clamp(base + jitter, 0.0, 1.0)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device='cpu', noise_type='brownian', noise_scale=0.3, base_schedule='karras'):
    """Stochastic scheduler with controlled randomness.

    A base ramp in [0, 1] is perturbed with random noise (drawn from torch's
    global RNG), clamped, mapped through the rho=7 interpolation between
    ``sigma_max`` and ``sigma_min``, and finally sorted in descending order.
    A terminating 0 is appended.

    Args:
        sigma_max: Largest sigma.
        sigma_min: Smallest sigma.
        num_steps: Number of non-zero sigmas to generate.
        device: Device for the generated tensors.
        noise_type: ``'brownian'`` (cumulative sum of normal noise, scaled to
            unit std), ``'uniform'`` (in [-1, 1)), anything else: normal.
        noise_scale: Perturbation size; divided by ``num_steps`` before use.
        base_schedule: ``'karras'``, ``'cosine'``, anything else: linear.

    Returns:
        1-D tensor with ``num_steps + 1`` entries, the last one being 0.
    """
    rho = 7.0
    # Base schedule
    if base_schedule == 'karras':
        indices = torch.arange(num_steps, device=device, dtype=torch.float32)
        u = (indices / max(1, num_steps - 1)) ** (1 / rho)
    elif base_schedule == 'cosine':
        u = torch.linspace(0, 1, num_steps, device=device)
        u = (1 - torch.cos(math.pi * u)) / 2
    else:  # uniform
        u = torch.linspace(0, 1, num_steps, device=device)
    # Add noise
    if noise_type == 'brownian':
        noise = torch.randn(num_steps, device=device).cumsum(0)
        # The std of fewer than two samples is NaN and would poison the whole
        # schedule, so only normalise when a positive std exists.
        if num_steps > 1:
            noise_std = noise.std()
            if noise_std > 0:
                noise = noise / noise_std
    elif noise_type == 'uniform':
        noise = torch.rand(num_steps, device=device) * 2 - 1
    else:  # normal
        noise = torch.randn(num_steps, device=device)
    u_noisy = u + noise * noise_scale / num_steps
    u_noisy = torch.clamp(u_noisy, 0, 1)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + u_noisy * (min_inv_rho - max_inv_rho)) ** rho
    # Sort the final sigmas descending so schedule is always noise→clean.
    # Sorting u_noisy descending before the transform gives wrong order
    # because the Karras mapping is monotone-decreasing in u.
    sigmas, _ = torch.sort(sigmas, descending=True)
    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_jys_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """
    JYS (Jump Your Steps) schedule using dynamically computed timestep sequences.
    Strategy: Large jumps early, dense clustering in detail region, fine steps at end.
    Ported from ComfyUI reference implementation.

    Timesteps on a 0..1000 scale come from ``_compute_jys_timesteps`` and are
    converted to an interpolation parameter ``(1000 - t) / 1000`` for the rho=7
    mapping between ``sigma_max`` and ``sigma_min``. The sigmas are sorted in
    descending order and a terminating 0 is appended.
    """
    rho = 7.0
    timesteps = _compute_jys_timesteps(num_steps)
    # A trailing 0 timestep from the helper is dropped here; the schedule's
    # own zero terminator is appended explicitly at the end.
    if timesteps and timesteps[-1] == 0:
        timesteps = timesteps[:-1]

    fractions = [(1000 - step) / 1000.0 for step in timesteps]
    t_tensor = torch.tensor(fractions, device=device, dtype=torch.float32)

    lo = sigma_min ** (1 / rho)
    hi = sigma_max ** (1 / rho)
    schedule = (hi + t_tensor * (lo - hi)) ** rho
    ordered = torch.sort(schedule, descending=True)[0]
    return torch.cat([ordered, torch.zeros(1, device=device)])
def _compute_jys_timesteps(num_steps):
"""Dynamically compute optimised JYS timestep sequence (0..1000 scale)."""
if num_steps <= 0:
return [0]
if num_steps == 1:
return [1000, 0]
elif num_steps == 2:
return [1000, 500, 0]
elif num_steps == 3:
return [1000, 600, 200, 0]
early_steps = max(1, int(num_steps * 0.2))
final_steps = max(1, int(num_steps * 0.2))
middle_steps = max(1, num_steps - early_steps - final_steps)
early_jump_size = max(50, (1000 - 600) // early_steps)
early_timesteps = []
current_t = 1000
for _ in range(early_steps):
early_timesteps.append(int(current_t))
current_t = max(600, current_t - early_jump_size)
middle_timesteps = []
structure_steps = max(1, middle_steps // 2)
structure_jump_size = max(10, (600 - 300) // structure_steps)
current_t = 600
for _ in range(structure_steps):
middle_timesteps.append(int(current_t))
current_t = max(300, current_t - structure_jump_size)
detail_steps = middle_steps - structure_steps
if detail_steps > 0:
detail_jump_size = max(5, (300 - 200) // detail_steps)
current_t = 300
for _ in range(detail_steps):
middle_timesteps.append(int(current_t))
current_t = max(200, current_t - detail_jump_size)
final_start = min(middle_timesteps) if middle_timesteps else 200
final_jump_size = max(5, final_start // final_steps)
final_timesteps = []
current_t = final_start
for _ in range(final_steps):
final_timesteps.append(int(current_t))
current_t = max(0, current_t - final_jump_size)
all_timesteps = early_timesteps + middle_timesteps + final_timesteps
unique_timesteps = list(dict.fromkeys(all_timesteps))
unique_timesteps.sort(reverse=True)
while len(unique_timesteps) < num_steps:
for i in range(len(unique_timesteps) - 1):
mid_point = (unique_timesteps[i] + unique_timesteps[i + 1]) // 2
if mid_point not in unique_timesteps:
unique_timesteps.insert(i + 1, mid_point)
if len(unique_timesteps) >= num_steps:
break
if len(unique_timesteps) > num_steps:
unique_timesteps = unique_timesteps[:num_steps]
if unique_timesteps[-1] != 0:
unique_timesteps.append(0)
return unique_timesteps
def create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """Hybrid schedule: log-space blend of a JYS schedule and a jittered Karras schedule.

    The blend weight given to the JYS branch ramps from 0.2 at the start of the
    schedule to 0.9 at the end. The blended curve is damped slightly at the
    start and then forced to be non-increasing.

    Returns:
        1-D tensor of ``num_steps`` sigmas followed by a 0 terminator.
        For ``num_steps <= 0`` it returns ``[sigma_max, 0]``; that path calls
        ``sigma_max.unsqueeze`` and therefore needs ``sigma_max`` to be a tensor.
    """
    if num_steps <= 0:
        return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])

    karras_rho = 7.0
    jitter_strength = 0.35

    # JYS branch, minus its zero terminator.
    jys_branch = create_jys_sigmas(sigma_max, sigma_min, num_steps, device=device)[:-1]

    # Karras branch sampled at cell centres with a deterministic sine jitter.
    step_ids = torch.arange(num_steps, device=device, dtype=torch.float32)
    span = max(1, num_steps - 1)
    centred = (step_ids + 0.5) / span
    wobble = torch.sin((step_ids + 1) * 2.3999632) * jitter_strength / span
    u = torch.clamp(centred + wobble, 0.0, 1.0)
    inv_rho_low = sigma_min ** (1 / karras_rho)
    inv_rho_high = sigma_max ** (1 / karras_rho)
    karras_branch = (inv_rho_high + u * (inv_rho_low - inv_rho_high)) ** karras_rho

    # Piecewise JYS weight: 0.2 -> 0.6 on [0, 0.3), 0.6 -> 0.9 on [0.3, 0.8), 0.9 after.
    pos = torch.linspace(0, 1, num_steps, device=device)
    in_early = pos < 0.3
    in_late = pos >= 0.8
    in_mid = (pos >= 0.3) & (pos < 0.8)
    jys_weight = torch.empty_like(pos)
    jys_weight[in_late] = 0.9
    jys_weight[in_mid] = 0.6 + 0.3 * ((pos[in_mid] - 0.3) / 0.5)
    jys_weight[in_early] = 0.2 + 0.4 * (pos[in_early] / 0.3)
    jys_weight = jys_weight.clamp(0.2, 0.9)

    # Geometric (log-space) interpolation between the two branches.
    blended = torch.exp(
        torch.lerp(
            torch.log(karras_branch.clamp_min(1e-6)),
            torch.log(jys_branch.clamp_min(1e-6)),
            jys_weight,
        )
    )

    # Up to 5% damping at the very start, fading quadratically to none.
    blended = blended * (1.0 - 0.05 * (1 - pos) ** 2)

    # Repair any increase so the schedule never goes back up.
    for idx in range(1, blended.shape[0]):
        if blended[idx] > blended[idx - 1]:
            blended[idx] = blended[idx - 1] * 0.999

    return torch.cat([blended, torch.zeros(1, device=device)])
def create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """AYS (Align Your Steps) schedule optimised for SDXL.

    A normalised curve (1.0 -> 0.0) is taken from a built-in table when
    ``num_steps`` matches a table entry; otherwise the nearest larger table
    (or the largest/smallest available one) is resampled in log space. The
    curve is then mapped affinely onto ``[sigma_min, sigma_max]``.

    Args:
        sigma_max: highest sigma (first entry of the result).
        sigma_min: lowest non-zero sigma (last entry before the terminator).
        num_steps: number of sampling steps.
        device: device for the returned tensor.

    Returns:
        1-D tensor with exactly ``num_steps`` strictly decreasing, non-zero
        sigmas followed by one 0 terminator (``num_steps + 1`` entries).
        ``num_steps <= 0`` yields just the terminator.
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    AYS_SCHEDULES = {
        10: [1.0000, 0.8751, 0.7502, 0.6254, 0.5004, 0.3755, 0.2506, 0.1253, 0.0502, 0.0000],
        15: [1.0000, 0.9167, 0.8334, 0.7501, 0.6668, 0.5835, 0.5002, 0.4169, 0.3336,
             0.2503, 0.1670, 0.0837, 0.0335, 0.0084, 0.0000],
        20: [1.0000, 0.9375, 0.8750, 0.8125, 0.7500, 0.6875, 0.6250, 0.5625, 0.5000,
             0.4375, 0.3750, 0.3125, 0.2500, 0.1875, 0.1250, 0.0625, 0.0313, 0.0156,
             0.0039, 0.0000],
        25: [1.0000, 0.9500, 0.9000, 0.8500, 0.8000, 0.7500, 0.7000, 0.6500, 0.6000,
             0.5500, 0.5000, 0.4500, 0.4000, 0.3500, 0.3000, 0.2500, 0.2000, 0.1500,
             0.1000, 0.0625, 0.0391, 0.0195, 0.0098, 0.0024, 0.0000],
        30: [1.0000, 0.9583, 0.9167, 0.8750, 0.8333, 0.7917, 0.7500, 0.7083, 0.6667,
             0.6250, 0.5833, 0.5417, 0.5000, 0.4583, 0.4167, 0.3750, 0.3333, 0.2917,
             0.2500, 0.2083, 0.1667, 0.1250, 0.0833, 0.0521, 0.0326, 0.0163, 0.0081,
             0.0041, 0.0010, 0.0000],
    }

    if num_steps in AYS_SCHEDULES:
        normalized = torch.tensor(AYS_SCHEDULES[num_steps], device=device, dtype=torch.float32)
    else:
        available_steps = sorted(AYS_SCHEDULES)
        # Smallest table that is at least as long as requested, else the largest.
        ref_steps = next((s for s in available_steps if s >= num_steps), available_steps[-1])
        ref_schedule = np.array(AYS_SCHEDULES[ref_steps])
        t_ref = np.linspace(0, 1, len(ref_schedule))
        # Resample to exactly num_steps points; the zero terminator is added
        # separately below, so it must not be counted here.
        t_new = np.linspace(0, 1, num_steps)
        log_ref = np.log(ref_schedule + 1e-8)
        # The table ends in 0; give it a finite log value so interpolation
        # towards the end stays well-behaved.
        log_ref[-1] = log_ref[-2] - 3.0
        normalized_np = np.exp(np.interp(t_new, t_ref, log_ref))
        normalized_np[-1] = 0.0
        normalized = torch.tensor(normalized_np, device=device, dtype=torch.float32)

    # normalized == 0 maps to sigma_min, so the last sampling step stays
    # non-zero; forcing it to 0 would leave two trailing zeros and make the
    # samplers divide by sigma = 0 on the final step.
    sigmas = normalized * (sigma_max - sigma_min) + sigma_min
    sigmas[0] = sigma_max

    # Enforce a strictly decreasing schedule over every entry.
    for i in range(1, len(sigmas)):
        if sigmas[i] >= sigmas[i - 1]:
            sigmas[i] = sigmas[i - 1] * 0.999

    # Append zero-terminator so output has num_steps+1 entries like all other schedulers.
    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_aos_akashic_alt_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """
    AkashicAOS Alt: Karras-style schedule over a warped step coordinate.

    The uniform coordinate ``u`` is raised to the power 0.78 and shifted by a
    tanh term centred at u = 0.55, rescaled to [0, 1], and fed through a
    Karras rho-interpolation. ``rho`` depends on the step count (clamped to
    the range 7..11). Afterwards each sigma is forced below its predecessor
    and no closer than a 1.5x ratio allows, so the tail may stop short of
    ``sigma_min`` when few steps are used.

    Returns:
        1-D tensor of ``num_steps`` sigmas plus a 0 terminator; only the
        terminator when ``num_steps <= 0``.
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    detail_power = 0.78
    t_center, beta, gamma = 0.55, 0.07, 4.0
    max_ratio = 1.5
    rho = min(11.0, max(7.0, 7.0 + 2.0 * (20.0 / max(num_steps, 10))))

    u = torch.linspace(0, 1, num_steps, device=device)
    warped = u ** detail_power + beta * torch.tanh(gamma * (u - t_center))

    # Rescale to [0, 1] unless the warped curve is (numerically) flat.
    low, high = warped.min(), warped.max()
    if high - low > 1e-8:
        warped = (warped - low) / (high - low)

    inv_rho_low = sigma_min ** (1 / rho)
    inv_rho_high = sigma_max ** (1 / rho)
    sigmas = (inv_rho_high + warped * (inv_rho_low - inv_rho_high)) ** rho

    for idx in range(1, len(sigmas)):
        # Keep the schedule strictly decreasing ...
        if sigmas[idx] >= sigmas[idx - 1]:
            sigmas[idx] = sigmas[idx - 1] * 0.995
        # ... and cap how far a single step may drop.
        if sigmas[idx - 1] / sigmas[idx].clamp(min=1e-10) > max_ratio:
            sigmas[idx] = sigmas[idx - 1] / max_ratio

    return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_akashic_eqflow_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
    """
    AkashicEQFlow: density-driven log-SNR schedule blended with a Karras prior.

    Steps are placed by inverting the CDF of a density over a normalised
    log-SNR axis (lambda = -2 * ln(sigma)); the density has an asymmetric
    Gaussian bump around a step-count dependent centre plus small floors at
    both ends. The result is blended in lambda space with a Karras schedule,
    converted back to sigmas, and post-processed so that consecutive ratios
    stay within a cap and change by at most a factor of 1.18 from step to step.

    Returns:
        1-D tensor of ``num_steps`` strictly decreasing sigmas starting at
        ``sigma_max``, followed by a 0 terminator; only the terminator when
        ``num_steps <= 0``.
    """
    if num_steps <= 0:
        return torch.zeros(1, device=device)

    # log-SNR endpoints: noisy end first, clean end last.
    lam_start = -2.0 * math.log(max(float(sigma_max), 1e-10))
    lam_end = -2.0 * math.log(max(float(sigma_min), 1e-10))
    lam_span = max(lam_end - lam_start, 1e-8)

    # Step-count dependent tuning of the density bump.
    step_factor = min(1.0, max(0.0, (num_steps - 16) / 30.0))
    lam_center = 0.20 + 0.15 * step_factor
    center = (lam_center - lam_start) / lam_span
    center = float(min(0.88, max(0.12, center)))
    concentration = min(3.2, max(1.35, 1.1 + num_steps / 16.0))
    base_width = min(0.30, max(0.18, 0.31 - 0.0028 * num_steps))
    width_left = base_width * 1.06
    width_right = base_width * 0.94
    detail_side_gain = 1.08 + 0.04 * step_factor

    # Density on a fine grid over the normalised axis.
    grid_size = 1200
    grid = torch.linspace(0, 1, grid_size, device=device)
    offset = grid - center
    bump_left = torch.exp(-((offset / width_left) ** 2) / 2.0)
    bump_right = detail_side_gain * torch.exp(-((offset / width_right) ** 2) / 2.0)
    bump = torch.where(offset <= 0, bump_left, bump_right)
    detail_floor = 0.08 * (grid ** 1.4)
    composition_floor = 0.05 * ((1 - grid) ** 1.7)
    density = 1.0 + concentration * bump + detail_floor + composition_floor

    # Trapezoidal CDF, normalised to end at 1.
    cell = 1.0 / (grid_size - 1)
    cdf = torch.zeros(grid_size, device=device)
    cdf[1:] = torch.cumsum((density[:-1] + density[1:]) * 0.5 * cell, dim=0)
    cdf = cdf / cdf[-1].clamp(min=1e-12)

    # Invert the CDF at evenly spaced quantiles (linear interpolation).
    quantiles = torch.linspace(0, 1, num_steps, device=device)
    upper = torch.searchsorted(cdf, quantiles).clamp(1, grid_size - 1)
    lower = upper - 1
    weight = (quantiles - cdf[lower]) / (cdf[upper] - cdf[lower]).clamp(min=1e-12)
    u_steps = grid[lower] + weight * (grid[upper] - grid[lower])
    lambdas_eqflow = lam_start + u_steps * lam_span

    # Karras prior with a step-count dependent rho (clamped to 7..10).
    rho = min(10.0, max(7.0, 7.0 + 1.5 * (22.0 / max(num_steps, 12))))
    u_karras = torch.linspace(0, 1, num_steps, device=device)
    inv_rho_low = sigma_min ** (1 / rho)
    inv_rho_high = sigma_max ** (1 / rho)
    sigmas_karras = (inv_rho_high + u_karras * (inv_rho_low - inv_rho_high)) ** rho
    lambdas_karras = -2.0 * torch.log(sigmas_karras.clamp(min=1e-10))

    # Blend in lambda space, then return to sigma space.
    blend_eqflow = min(0.60, max(0.35, 0.38 + num_steps / 200.0))
    lambdas = (1.0 - blend_eqflow) * lambdas_karras + blend_eqflow * lambdas_eqflow
    sigmas = torch.exp(-lambdas / 2.0)

    # Largest allowed sigma ratio between neighbours, by step count.
    for threshold, cap in ((40, 1.50), (28, 1.55), (18, 1.65)):
        if num_steps >= threshold:
            max_ratio = cap
            break
    else:
        max_ratio = 1.85
    ratio_slew = 1.18

    sigmas[0] = sigma_max
    last_ratio = None
    for idx in range(1, len(sigmas)):
        if sigmas[idx] >= sigmas[idx - 1]:
            sigmas[idx] = sigmas[idx - 1] * 0.995
        step_ratio = float((sigmas[idx - 1] / sigmas[idx].clamp(min=1e-10)).item())
        step_ratio = min(step_ratio, max_ratio)
        if last_ratio is not None:
            # Slew-rate limit: the ratio may only drift by x1.18 per step.
            step_ratio = min(step_ratio, last_ratio * ratio_slew)
            step_ratio = max(step_ratio, last_ratio / ratio_slew)
        step_ratio = max(1.001, step_ratio)
        sigmas[idx] = sigmas[idx - 1] / step_ratio
        last_ratio = step_ratio

    return torch.cat([sigmas, torch.zeros(1, device=device)])
def apply_custom_scheduler(sigmas, scheduler_type="Standard"):
    """
    Replace ``sigmas`` with the schedule produced by the named custom scheduler.

    ``sigma_max`` is ``sigmas[0]``; ``sigma_min`` is ``sigmas[-2]`` (the entry
    before the final one), falling back to ``sigma_max * 1e-3`` when that entry
    is not positive. The step count is ``len(sigmas) - 1``.

    The input is returned unchanged when ``scheduler_type`` is "Standard" or
    unknown, when fewer than two sigmas are given, or when the chosen scheduler
    raises or returns None / a single-entry result (a warning is printed in the
    latter two cases).
    """
    if scheduler_type == "Standard" or len(sigmas) < 2:
        return sigmas

    sigma_max = sigmas[0]
    # len(sigmas) >= 2 is guaranteed by the guard above.
    sigma_min = sigmas[-2]
    if sigma_min <= 0:
        sigma_min = sigma_max * 1e-3
    num_steps = len(sigmas) - 1
    device = sigmas.device

    # Builders are wrapped in lambdas so each create_* name is resolved only
    # when its scheduler is actually selected, and so schedulers with extra
    # keyword arguments (e.g. Entropic's `power`) get them explicitly.
    builders = {
        "AOS-V": lambda mx, mn, n, dev: create_aos_v_sigmas(mx, mn, n, dev),
        "AOS-Epsilon": lambda mx, mn, n, dev: create_aos_e_sigmas(mx, mn, n, dev),
        "AkashicAOS": lambda mx, mn, n, dev: create_aos_akashic_sigmas(mx, mn, n, dev),
        "Entropic": lambda mx, mn, n, dev: create_entropic_sigmas(mx, mn, n, power=6.0, device=dev),
        "SNR-Optimized": lambda mx, mn, n, dev: create_snr_optimized_sigmas(mx, mn, n, dev),
        "Constant-Rate": lambda mx, mn, n, dev: create_constant_rate_sigmas(mx, mn, n, dev),
        "Adaptive-Optimized": lambda mx, mn, n, dev: create_adaptive_optimized_sigmas(mx, mn, n, dev),
        "Cosine-Annealed": lambda mx, mn, n, dev: create_cosine_sigmas(mx, mn, n, dev),
        "LogSNR-Uniform": lambda mx, mn, n, dev: create_logsnr_uniform_sigmas(mx, mn, n, dev),
        "Tanh Mid-Boost": lambda mx, mn, n, dev: create_tanh_midboost_sigmas(mx, mn, n, dev),
        "Exponential Tail": lambda mx, mn, n, dev: create_exponential_tail_sigmas(mx, mn, n, dev),
        "Jittered-Karras": lambda mx, mn, n, dev: create_jittered_karras_sigmas(mx, mn, n, dev),
        "Stochastic": lambda mx, mn, n, dev: create_stochastic_sigmas(mx, mn, n, device=dev),
        "JYS (Dynamic)": lambda mx, mn, n, dev: create_jys_sigmas(mx, mn, n, dev),
        "Hybrid JYS-Karras": lambda mx, mn, n, dev: create_hybrid_jys_karras_sigmas(mx, mn, n, dev),
        "AYS-SDXL": lambda mx, mn, n, dev: create_ays_sdxl_sigmas(mx, mn, n, dev),
        "AkashicAOS Alt": lambda mx, mn, n, dev: create_aos_akashic_alt_sigmas(mx, mn, n, dev),
        "AkashicEQFlow": lambda mx, mn, n, dev: create_akashic_eqflow_sigmas(mx, mn, n, dev),
    }

    builder = builders.get(scheduler_type)
    if builder is None:
        return sigmas

    try:
        candidate = builder(sigma_max, sigma_min, num_steps, device)
        if candidate is not None and len(candidate) > 1:
            return candidate
        print(f"⚠️ Scheduler {scheduler_type} returned empty/None, using standard")
    except Exception as exc:
        # Deliberate best-effort: a broken scheduler must not abort sampling.
        print(f"⚠️ Scheduler {scheduler_type} failed: {exc}, using standard")
    return sigmas
# ============================================================================
# ============================================================================
# K-DIFFUSION SAMPLERS with Custom Sampler Integration
# ============================================================================
@torch.no_grad()
def sample_adept_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Euler sampler with Adept weight scaling OR custom sampler.

    Behaviour is selected from ADEPT_STATE:
      * enabled and use_custom_sampler: ``sigmas`` is rescheduled (when a
        non-Standard scheduler is selected) and the call is handed to the
        selected custom sampler.
      * not enabled: delegates to the stored original 'euler' sampler, or to
        ``_basic_euler`` when none is stored.
      * otherwise: runs an Euler loop with optional churn noise, evaluating the
        model under ``AdeptWeightPatcher`` with a per-step scale from
        ``compute_dynamic_scale``.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: extra keyword arguments for ``model`` (None -> {}).
        callback: optional per-step callable receiving a dict with
            'x', 'i', 'sigma' and 'denoised'.
        disable: forwarded to the progress bar / delegated samplers.
        s_churn, s_tmin, s_tmax, s_noise: churn settings of the Euler sampler.

    Returns:
        The latent after the final step.
    """
    adept_enabled = ADEPT_STATE.get('enabled', False)
    # Remembers whether the custom scheduler already rewrote `sigmas`, so a
    # fall-through from the custom-sampler branch does not apply it twice.
    scheduler_applied = False

    # CUSTOM SAMPLER INTEGRATION
    if adept_enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")
        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")
        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5),
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5),
                order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0),
                ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False,
                settings={},
                eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2),
                use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False,
                settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False),
                phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5),
                enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False,
                settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # Unrecognised custom sampler name: fall through to the Adept loop.

    # STANDARD K-DIFFUSION MODE
    if not adept_enabled:
        if 'euler' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['euler'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
        return _basic_euler(model, x, sigmas, extra_args, callback, disable)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    if not scheduler_applied:
        _sched = ADEPT_STATE.get('scheduler', 'Standard')
        if _sched != 'Standard':
            sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        # Model layout without .model.diffusion_model: no weights to patch.
        unet_model = None

    total_steps = len(sigmas) - 1
    for i in trange(total_steps, disable=disable, desc="Adept Euler"):
        sigma = sigmas[i]
        gamma = min(s_churn / total_steps, 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # Churn: raise the noise level to sigma_hat before denoising.
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5

        # Same guard as the sibling Adept samplers: only enter the patcher
        # when should_patch_weights says patching applies.
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma_hat * s_in, **extra_args)
        else:
            denoised = model(x, sigma_hat * s_in, **extra_args)

        d = to_d(x, sigma_hat, denoised)
        dt = sigmas[i + 1] - sigma_hat
        x = x + d * dt
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_hat, 'denoised': denoised})
    return x
@torch.no_grad()
def sample_adept_euler_ancestral(model, x, sigmas, extra_args=None, callback=None,
                                 disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
    """Euler Ancestral with Adept weight scaling.

    Behaviour is selected from ADEPT_STATE:
      * enabled and use_custom_sampler: ``sigmas`` is rescheduled (when a
        non-Standard scheduler is selected) and the call is handed to the
        selected custom sampler.
      * not enabled: delegates to the stored original 'euler_ancestral'
        sampler, or to ``_basic_euler_ancestral`` when none is stored.
      * otherwise: runs an ancestral Euler loop with per-step weight patching.

    In the Adept loop the ADEPT_STATE values for 'eta' and 's_noise' take
    precedence over the ``eta`` / ``s_noise`` arguments, which only act as
    fallbacks when the keys are absent. With 'adaptive_eta' enabled, eta is
    multiplied by 1.08 / 0.95 / 1.02 in the first 30% / next 40% / last 30%
    of the steps.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: extra keyword arguments for ``model`` (None -> {}).
        callback: optional per-step callable receiving a dict with
            'x', 'i', 'sigma' and 'denoised'.
        disable: forwarded to the progress bar / delegated samplers.
        eta, s_noise: ancestral noise controls (see precedence note above).
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to
            ``default_noise_sampler(x)``.

    Returns:
        The latent after the final step.
    """
    adept_enabled = ADEPT_STATE.get('enabled', False)
    # Remembers whether the custom scheduler already rewrote `sigmas`, so a
    # fall-through from the custom-sampler branch does not apply it twice.
    scheduler_applied = False

    # CUSTOM SAMPLER INTEGRATION
    if adept_enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")
        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")
        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # Unrecognised custom sampler name: fall through to the Adept loop.

    if not adept_enabled:
        if 'euler_ancestral' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['euler_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
        return _basic_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    if not scheduler_applied:
        _sched = ADEPT_STATE.get('scheduler', 'Standard')
        if _sched != 'Standard':
            sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    use_adaptive_eta = ADEPT_STATE.get('adaptive_eta', False)
    # base_eta is the configured value; the per-step eta is derived from it so
    # the setting logged below is the one actually used by the loop.
    base_eta = ADEPT_STATE.get('eta', eta)
    current_s_noise = ADEPT_STATE.get('s_noise', s_noise)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        # Model layout without .model.diffusion_model: no weights to patch.
        unet_model = None

    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    total_steps = len(sigmas) - 1
    print(f"✅ Adept Euler A active: scale={base_scale:.2f}, eta={base_eta:.2f}")

    for i in trange(total_steps, disable=disable, desc="Adept Euler A"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        progress = i / max(total_steps, 1)

        # Adaptive eta
        if use_adaptive_eta:
            if progress < 0.3:
                current_eta = base_eta * 1.08
            elif progress < 0.7:
                current_eta = base_eta * 0.95
            else:
                current_eta = base_eta * 1.02
        else:
            current_eta = base_eta

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # Evaluate model with weight patching
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        # Euler Ancestral step
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
        d = to_d(x, sigma, denoised)
        if torch.isnan(d).any() or torch.isinf(d).any():
            # Keep one bad derivative from poisoning the whole latent.
            d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
        dt = sigma_down - sigma
        x = x + d * dt
        if sigma_up > 0:
            noise = noise_sampler(sigma, sigma_next) * current_s_noise
            x = x + noise * sigma_up

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
    return x
def _basic_euler(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Plain Euler sampler, used as fallback when ORIGINAL_SAMPLERS has no 'euler' entry.

    Takes ``len(sigmas) - 1`` steps, calling ``callback`` (if given) after each
    one with a dict holding 'x', 'i', 'sigma' and 'denoised', and returns the
    final latent.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    step_count = len(sigmas) - 1

    for step in trange(step_count, disable=disable):
        sigma_now = sigmas[step]
        denoised = model(x, sigma_now * batch_ones, **extra_args)
        derivative = to_d(x, sigma_now, denoised)
        # Explicit Euler update towards the next noise level.
        x = x + derivative * (sigmas[step + 1] - sigma_now)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_now, 'denoised': denoised})
    return x
def _basic_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
    """Plain Euler Ancestral sampler used as a fallback.

    Each step moves deterministically to ``sigma_down`` and then, when
    ``sigma_up`` is positive, adds noise from ``default_noise_sampler(x)``
    scaled by ``s_noise * sigma_up``. Returns the final latent.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    draw_noise = default_noise_sampler(x)
    step_count = len(sigmas) - 1

    for step in trange(step_count, disable=disable):
        sigma_now, sigma_after = sigmas[step], sigmas[step + 1]
        denoised = model(x, sigma_now * batch_ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_now, sigma_after, eta)
        derivative = to_d(x, sigma_now, denoised)
        x = x + derivative * (sigma_down - sigma_now)
        if sigma_up > 0:
            x = x + draw_noise(sigma_now, sigma_after) * s_noise * sigma_up
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_now, 'denoised': denoised})
    return x
@torch.no_grad()
def sample_adept_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun sampler with Adept weight scaling.

    Behaviour is selected from ADEPT_STATE:
      * enabled and use_custom_sampler: ``sigmas`` is rescheduled (when a
        non-Standard scheduler is selected) and the call is handed to the
        selected custom sampler.
      * not enabled: delegates to the stored original 'heun' sampler, or to
        ``_basic_heun`` when none is stored.
      * otherwise: runs a two-stage Heun loop with per-step weight patching.
        The churn arguments (``s_churn``, ``s_tmin``, ``s_tmax``, ``s_noise``)
        are not used by this loop; they are only forwarded to the original
        sampler.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: extra keyword arguments for ``model`` (None -> {}).
        callback: optional per-step callable receiving a dict with
            'x', 'i', 'sigma' and 'denoised' (the first-stage prediction).
        disable: forwarded to the progress bar / delegated samplers.

    Returns:
        The latent after the final step.
    """
    adept_enabled = ADEPT_STATE.get('enabled', False)
    # Remembers whether the custom scheduler already rewrote `sigmas`, so a
    # fall-through from the custom-sampler branch does not apply it twice.
    scheduler_applied = False

    # CUSTOM SAMPLER INTEGRATION
    if adept_enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")
        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")
        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # Unrecognised custom sampler name: fall through to the Adept loop.

    if not adept_enabled:
        if 'heun' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['heun'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
        return _basic_heun(model, x, sigmas, extra_args, callback, disable)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    if not scheduler_applied:
        _sched = ADEPT_STATE.get('scheduler', 'Standard')
        if _sched != 'Standard':
            sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        # Model layout without .model.diffusion_model: no weights to patch.
        unet_model = None

    def _derivative(latent, sig, scale):
        """Evaluate the model (patched when applicable); return (denoised, sanitised derivative)."""
        if should_patch_weights(unet_model, scale, shift):
            with AdeptWeightPatcher(unet_model, scale, shift):
                prediction = model(latent, sig * s_in, **extra_args)
        else:
            prediction = model(latent, sig * s_in, **extra_args)
        slope = to_d(latent, sig, prediction)
        if torch.isnan(slope).any() or torch.isinf(slope).any():
            # Keep one bad derivative from poisoning the whole latent.
            slope = torch.nan_to_num(slope, nan=0.0, posinf=1.0, neginf=-1.0)
        return prediction, slope

    total_steps = len(sigmas) - 1
    print(f"✅ Adept Heun active: scale={base_scale:.2f}")

    for i in trange(total_steps, disable=disable, desc="Adept Heun"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # First evaluation
        denoised, d = _derivative(x, sigma, current_scale)
        dt = sigma_next - sigma

        if sigma_next == 0:
            # Last step: plain Euler, since the model cannot be evaluated at sigma = 0.
            x = x + d * dt
        else:
            # Heun's method: predict with Euler, re-evaluate, average the slopes.
            x_2 = x + d * dt
            _, d_2 = _derivative(x_2, sigma_next, current_scale)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
    return x
def _basic_heun(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Plain Heun (two-stage) sampler used as a fallback.

    Every step except one that lands on sigma == 0 evaluates the model twice
    and averages the two derivatives; a step to sigma == 0 is a single Euler
    update. ``callback`` receives the first-stage 'denoised'. Returns the final
    latent.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    step_count = len(sigmas) - 1

    for step in trange(step_count, disable=disable):
        sigma_now, sigma_after = sigmas[step], sigmas[step + 1]
        denoised = model(x, sigma_now * batch_ones, **extra_args)
        slope = to_d(x, sigma_now, denoised)
        step_size = sigma_after - sigma_now

        if sigma_after != 0:
            # Predictor, second evaluation, then average the slopes.
            predicted = x + slope * step_size
            denoised_after = model(predicted, sigma_after * batch_ones, **extra_args)
            slope_after = to_d(predicted, sigma_after, denoised_after)
            x = x + ((slope + slope_after) / 2) * step_size
        else:
            x = x + slope * step_size

        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_now, 'denoised': denoised})
    return x
@torch.no_grad()
def sample_adept_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM++ 2M sampler with Adept weight scaling.

    Behaviour is selected from ADEPT_STATE:
      * enabled and use_custom_sampler: ``sigmas`` is rescheduled (when a
        non-Standard scheduler is selected) and the call is handed to the
        selected custom sampler.
      * not enabled: delegates to the stored original 'dpmpp_2m' sampler, or to
        ``_basic_dpmpp_2m`` when none is stored.
      * otherwise: runs the DPM-Solver++(2M) multistep update with per-step
        weight patching.

    The solver works in the time variable ``t = -log(sigma)``; step sizes ``h``
    are differences of ``t``, not of sigma. The first step and any step that
    lands on sigma == 0 use the first-order update.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: extra keyword arguments for ``model`` (None -> {}).
        callback: optional per-step callable receiving a dict with
            'x', 'i', 'sigma' and 'denoised'.
        disable: forwarded to the progress bar / delegated samplers.

    Returns:
        The latent after the final step.
    """
    adept_enabled = ADEPT_STATE.get('enabled', False)
    # Remembers whether the custom scheduler already rewrote `sigmas`, so a
    # fall-through from the custom-sampler branch does not apply it twice.
    scheduler_applied = False

    # CUSTOM SAMPLER INTEGRATION
    if adept_enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")
        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            scheduler_applied = True
            print(f" 📊 Applied {scheduler} scheduler")
        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )
        # Unrecognised custom sampler name: fall through to the Adept loop.

    if not adept_enabled:
        if 'dpmpp_2m' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['dpmpp_2m'](model, x, sigmas, extra_args, callback, disable)
        return _basic_dpmpp_2m(model, x, sigmas, extra_args, callback, disable)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    if not scheduler_applied:
        _sched = ADEPT_STATE.get('scheduler', 'Standard')
        if _sched != 'Standard':
            sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        # Model layout without .model.diffusion_model: no weights to patch.
        unet_model = None

    def t_fn(sig):
        """Map sigma to the solver's time variable t = -log(sigma)."""
        return sig.log().neg()

    total_steps = len(sigmas) - 1
    print(f"✅ Adept DPM++ 2M active: scale={base_scale:.2f}")
    old_denoised = None

    for i in trange(total_steps, disable=disable, desc="Adept DPM++ 2M"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # Evaluate model with weight patching
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        # DPM++ 2M step in t = -log(sigma). For sigma_next == 0, t_next is
        # +inf and expm1(-h) evaluates to -1, so x becomes `denoised`.
        t, t_next = t_fn(sigma), t_fn(sigma_next)
        h = t_next - t
        if old_denoised is None or sigma_next == 0:
            # First-order update (first step, or final step to sigma = 0).
            x = (sigma_next / sigma) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep update using the previous prediction.
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_next / sigma) * x - (-h).expm1() * denoised_d
        old_denoised = denoised

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
    return x
def _basic_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Fallback basic DPM++ 2M.

    The solver step is taken in log-sigma time ``t = -log(sigma)``, so the
    step size is ``h = log(sigma / sigma_next)``. Using raw sigma differences
    as the step size makes ``expm1(-h)`` a large positive number and the
    update diverges instead of moving ``x`` towards ``denoised``.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels, indexed as ``sigmas[i]``; elements must support
            tensor methods such as ``.log()`` (torch tensor expected).
        extra_args: Optional extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Forwarded to ``trange`` to silence the progress bar.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        # h = t_next - t with t = -log(sigma). For sigma_next == 0 this is
        # +inf, which makes expm1(-h) == -1 and the step lands on `denoised`.
        h = (sigma / sigma_next).log()
        if old_denoised is None or sigma_next == 0:
            # First or final step: first-order update.
            x = (sigma_next / sigma) * x - (-h).expm1() * denoised
        else:
            # Second-order multistep update using the previous prediction.
            h_last = (sigmas[i - 1] / sigma).log()
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_next / sigma) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
    return x
@torch.no_grad()
def sample_adept_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
    """DPM++ 2S Ancestral with Adept weight scaling.

    The path taken depends on ``ADEPT_STATE``:

    * ``enabled`` and ``use_custom_sampler``: the sigmas (re-scheduled when a
      non-"Standard" scheduler is selected) are handed to the selected custom
      sampler and its result is returned. A ``custom_sampler`` name that
      matches none of the branches falls through to the paths below.
    * not ``enabled``: the call is forwarded to the saved original in
      ``ORIGINAL_SAMPLERS`` or, if none was saved, to
      ``_basic_dpmpp_2s_ancestral``.
    * otherwise: the loop below runs, evaluating the model inside
      ``AdeptWeightPatcher`` whenever ``should_patch_weights`` says so.

    The midpoint step is taken in log-sigma time ``t = -log(sigma)``; using
    raw sigma differences as the step size makes the update diverge and
    evaluates the model at the wrong noise level.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels, indexed as ``sigmas[i]`` (torch tensor expected).
        extra_args: Optional extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Forwarded to ``trange`` to silence the progress bar.
        eta: Ancestral eta, used only when ``ADEPT_STATE`` has no ``eta`` entry
            (and forwarded as-is on the disabled path).
        s_noise: Noise multiplier, used only when ``ADEPT_STATE`` has no
            ``s_noise`` entry (and forwarded as-is on the disabled path).
        noise_sampler: Optional callable ``(sigma, sigma_next) -> noise``;
            defaults to ``default_noise_sampler(x)``.

    Returns:
        The latent after the last step.
    """
    adept_enabled = ADEPT_STATE.get('enabled', False)

    # CUSTOM SAMPLER INTEGRATION
    if adept_enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")
        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")
        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    if not adept_enabled:
        # Extension switched off: behave like the unpatched sampler.
        if 'dpmpp_2s_ancestral' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['dpmpp_2s_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
        return _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)
    current_eta = ADEPT_STATE.get('eta', eta)
    current_s_noise = ADEPT_STATE.get('s_noise', s_noise)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)

    total_steps = len(sigmas) - 1
    print(f"✅ Adept DPM++ 2S A active: scale={base_scale:.2f}")

    for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2S A"):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # First evaluation
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        # DPM++ 2S step with ancestral noise
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
        if sigma_down == 0:
            # Nothing left to integrate in log-sigma time: plain Euler step.
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_down - sigma)
        else:
            # Midpoint method in log-sigma time t = -log(sigma):
            # h = t_next - t = log(sigma / sigma_down).
            h = (sigma / sigma_down).log()
            # Noise level at the midpoint t + h / 2.
            sigma_mid = sigma * (-(h * 0.5)).exp()

            # Step to midpoint
            x_mid = (sigma_mid / sigma) * x - (-(h * 0.5)).expm1() * denoised

            # Evaluate at midpoint
            if should_patch_weights(unet_model, current_scale, shift):
                with AdeptWeightPatcher(unet_model, current_scale, shift):
                    denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            else:
                denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)

            # Full step using midpoint
            x = (sigma_down / sigma) * x - (-h).expm1() * denoised_mid

        # Add ancestral noise
        if sigma_up > 0:
            noise = noise_sampler(sigma, sigma_next) * current_s_noise
            x = x + noise * sigma_up

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
def _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
    """Fallback basic DPM++ 2S Ancestral.

    The midpoint step is taken in log-sigma time ``t = -log(sigma)``; using
    raw sigma differences as the step size makes the update diverge and
    evaluates the model at the wrong noise level.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels, indexed as ``sigmas[i]`` (torch tensor expected).
        extra_args: Optional extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Forwarded to ``trange`` to silence the progress bar.
        eta: Passed to ``get_ancestral_step``.
        s_noise: Multiplier applied to the sampled noise.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    noise_sampler = default_noise_sampler(x)
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta)
        if sigma_down == 0:
            # Nothing left to integrate in log-sigma time: plain Euler step.
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_down - sigma)
        else:
            # h = t_next - t with t = -log(sigma).
            h = (sigma / sigma_down).log()
            # Noise level at the midpoint t + h / 2.
            sigma_mid = sigma * (-(h * 0.5)).exp()
            x_mid = (sigma_mid / sigma) * x - (-(h * 0.5)).expm1() * denoised
            denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            x = (sigma_down / sigma) * x - (-h).expm1() * denoised_mid
        if sigma_up > 0:
            x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
    return x
@torch.no_grad()
def sample_adept_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """LMS sampler with Adept weight scaling.

    The path taken depends on ``ADEPT_STATE``:

    * ``enabled`` and ``use_custom_sampler``: the sigmas (re-scheduled when a
      non-"Standard" scheduler is selected) are handed to the selected custom
      sampler and its result is returned. A ``custom_sampler`` name that
      matches none of the branches falls through to the paths below.
    * not ``enabled``: the call is forwarded to the saved original in
      ``ORIGINAL_SAMPLERS`` or, if none was saved, to ``_basic_lms``.
    * otherwise: the loop below runs, evaluating the model inside
      ``AdeptWeightPatcher`` whenever ``should_patch_weights`` says so.

    The multistep weights are the integrals of the Lagrange basis polynomials
    over ``[sigma_i, sigma_{i+1}]``. Evaluating the basis at ``sigma_i``
    instead yields the weights ``[1, 0, 0, ...]`` and silently reduces the
    method to Euler.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels, indexed as ``sigmas[i]`` (torch tensor expected).
        extra_args: Optional extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Forwarded to ``trange`` to silence the progress bar.
        order: Maximum number of past derivatives combined per step.

    Returns:
        The latent after the last step.
    """
    adept_enabled = ADEPT_STATE.get('enabled', False)

    # CUSTOM SAMPLER INTEGRATION
    if adept_enabled and ADEPT_STATE.get('use_custom_sampler', False):
        custom_type = ADEPT_STATE.get('custom_sampler', 'Akashic Solver v2')
        print(f"🌀 Redirecting to {custom_type}")
        # Apply scheduler to sigmas for custom samplers
        scheduler = ADEPT_STATE.get('scheduler', 'Standard')
        if scheduler != "Standard":
            sigmas = apply_custom_scheduler(sigmas, scheduler)
            print(f" 📊 Applied {scheduler} scheduler")
        if custom_type == "Akashic Solver v2":
            return sample_akashic_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                tau=ADEPT_STATE.get('tau', 0.5), eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0), adaptive_eta=ADEPT_STATE.get('adaptive_eta', True),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), order=ADEPT_STATE.get('solver_order', 2),
                smea_strength=ADEPT_STATE.get('smea_strength', 0.0), ndb_strength=ADEPT_STATE.get('ndb_strength', 0.0),
                use_detail_enhancement=False, settings={}, eqvae_mode=ADEPT_STATE.get('eqvae_mode', 'Off')
            )
        elif custom_type == "Adept Solver":
            return sample_adept_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                order=ADEPT_STATE.get('solver_order', 2), use_corrector=ADEPT_STATE.get('use_corrector', True),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Adept Ancestral Solver":
            return sample_adept_ancestral_solver(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0), s_noise=ADEPT_STATE.get('s_noise', 1.0),
                adaptive_eta=ADEPT_STATE.get('adaptive_eta', False), phase_noise=ADEPT_STATE.get('phase_noise', False),
                phase_strength=ADEPT_STATE.get('phase_strength', 0.5), enhanced_derivative=ADEPT_STATE.get('enhanced_derivative', False),
                use_detail_enhancement=False, settings={}
            )
        elif custom_type == "Mirror Correction Euler":
            return sample_mirror_correction_euler(
                model=model, x=x, sigmas=sigmas, extra_args=extra_args, callback=callback, disable=disable,
                eta=ADEPT_STATE.get('eta', 1.0),
                s_noise=ADEPT_STATE.get('s_noise', 1.0),
                correction_phase=ADEPT_STATE.get('mirror_correction_phase', 0.5),
                smooth_phase=ADEPT_STATE.get('mirror_smooth_phase', False)
            )

    if not adept_enabled:
        # Extension switched off: behave like the unpatched sampler.
        if 'lms' in ORIGINAL_SAMPLERS:
            return ORIGINAL_SAMPLERS['lms'](model, x, sigmas, extra_args, callback, disable, order)
        return _basic_lms(model, x, sigmas, extra_args, callback, disable, order)

    # Apply custom scheduler deterministically (before the loop, not via p.sampler.model_wrap)
    _sched = ADEPT_STATE.get('scheduler', 'Standard')
    if _sched != 'Standard':
        sigmas = apply_custom_scheduler(sigmas, _sched)

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # Get settings
    base_scale = ADEPT_STATE.get('scale', 1.0)
    shift = ADEPT_STATE.get('shift', 0.0)
    start_pct = ADEPT_STATE.get('start_pct', 0.0)
    end_pct = ADEPT_STATE.get('end_pct', 1.0)

    # Get UNet
    try:
        unet_model = shared.sd_model.model.diffusion_model
    except AttributeError:
        unet_model = None

    total_steps = len(sigmas) - 1
    print(f"✅ Adept LMS active: scale={base_scale:.2f}, order={order}")

    # The multistep weights are computed on the CPU in float64.
    sigmas_cpu = torch.as_tensor(sigmas).detach().double().cpu().numpy()

    def lms_weight(n, step, j):
        """Integrate the j-th of n Lagrange basis polynomials over one step."""
        lo, hi = sigmas_cpu[step], sigmas_cpu[step + 1]
        if n == 1:
            return float(hi - lo)
        nodes = [sigmas_cpu[step - k] for k in range(n) if k != j]
        denom = 1.0
        for node in nodes:
            denom *= sigmas_cpu[step - j] - node
        poly = np.polynomial.polynomial
        antiderivative = poly.polyint(poly.polyfromroots(nodes))
        area = poly.polyval(hi, antiderivative) - poly.polyval(lo, antiderivative)
        return float(area / denom)

    ds = []
    for i in trange(len(sigmas) - 1, disable=disable, desc="Adept LMS"):
        sigma = sigmas[i]

        # Dynamic scale
        current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)

        # Evaluate model with weight patching
        if should_patch_weights(unet_model, current_scale, shift):
            with AdeptWeightPatcher(unet_model, current_scale, shift):
                denoised = model(x, sigma * s_in, **extra_args)
        else:
            denoised = model(x, sigma * s_in, **extra_args)

        d = to_d(x, sigma, denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)

        # Linear multistep coefficients
        cur_order = min(i + 1, order)
        window = sigmas_cpu[i + 1 - cur_order:i + 1]
        if cur_order > 1 and np.unique(window).size < cur_order:
            # Repeated sigma values make the interpolation ill-defined;
            # fall back to a single-derivative (Euler) step.
            cur_order = 1
        weights = [lms_weight(cur_order, i, j) for j in range(cur_order)]

        # Apply multistep: weights[0] pairs with the newest derivative.
        # The weights already contain the step length.
        x = x + sum(w * d_val for w, d_val in zip(weights, reversed(ds[-cur_order:])))

        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})

    return x
def _basic_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Fallback basic LMS.

    The multistep weights are the integrals of the Lagrange basis polynomials
    over ``[sigma_i, sigma_{i+1}]``. Evaluating the basis at ``sigma_i``
    instead yields the weights ``[1, 0, 0, ...]`` and silently reduces the
    method to Euler.

    Args:
        model: Denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: Starting latent; ``x.shape[0]`` is used as the batch size.
        sigmas: Noise levels, indexed as ``sigmas[i]`` (torch tensor expected).
        extra_args: Optional extra keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma`` and ``denoised`` after every step.
        disable: Forwarded to ``trange`` to silence the progress bar.
        order: Maximum number of past derivatives combined per step.

    Returns:
        The latent after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # The multistep weights are computed on the CPU in float64.
    sigmas_cpu = torch.as_tensor(sigmas).detach().double().cpu().numpy()

    def lms_weight(n, step, j):
        """Integrate the j-th of n Lagrange basis polynomials over one step."""
        lo, hi = sigmas_cpu[step], sigmas_cpu[step + 1]
        if n == 1:
            return float(hi - lo)
        nodes = [sigmas_cpu[step - k] for k in range(n) if k != j]
        denom = 1.0
        for node in nodes:
            denom *= sigmas_cpu[step - j] - node
        poly = np.polynomial.polynomial
        antiderivative = poly.polyint(poly.polyfromroots(nodes))
        area = poly.polyval(hi, antiderivative) - poly.polyval(lo, antiderivative)
        return float(area / denom)

    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)
        cur_order = min(i + 1, order)
        window = sigmas_cpu[i + 1 - cur_order:i + 1]
        if cur_order > 1 and np.unique(window).size < cur_order:
            # Repeated sigma values make the interpolation ill-defined;
            # fall back to a single-derivative (Euler) step.
            cur_order = 1
        weights = [lms_weight(cur_order, i, j) for j in range(cur_order)]
        # weights[0] pairs with the newest derivative; the weights already
        # contain the step length.
        x = x + sum(w * d_val for w, d_val in zip(weights, reversed(ds[-cur_order:])))
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
    return x
# ============================================================================
# MONKEY PATCHING
# ============================================================================
def patch_k_diffusion():
    """Apply monkey patches to ALL k-diffusion samplers.

    For every sampler name present on ``k_diffusion.sampling`` the function
    currently installed is remembered in ``ORIGINAL_SAMPLERS`` (first time
    only) and the slot is replaced by the matching Adept wrapper. Calling this
    again is harmless: slots that already hold our wrapper are left alone.
    """
    samplers_to_patch = {
        'sample_euler': sample_adept_euler,
        'sample_euler_ancestral': sample_adept_euler_ancestral,
        'sample_heun': sample_adept_heun,
        'sample_dpmpp_2m': sample_adept_dpmpp_2m,
        'sample_dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral,
        'sample_lms': sample_adept_lms,
    }
    patched_count = 0
    for original_name, adept_func in samplers_to_patch.items():
        if not hasattr(k_diffusion.sampling, original_name):
            continue
        key = original_name.replace('sample_', '')
        current_func = getattr(k_diffusion.sampling, original_name)
        if current_func is adept_func:
            # Slot already holds our wrapper. Never record the wrapper itself
            # as the "original": the disabled path would then call itself
            # forever.
            continue
        # Only save original if we haven't stored it yet
        if key not in ORIGINAL_SAMPLERS:
            ORIGINAL_SAMPLERS[key] = current_func
        setattr(k_diffusion.sampling, original_name, adept_func)
        patched_count += 1
    print(f"✅ Adept Sampler v5: Patched {patched_count} samplers")
    print(" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS")
    print(" Schedulers: 18 types available")
def unpatch_k_diffusion():
    """
    Restore original k-diffusion samplers.

    Safe-unpatch strategy: before restoring we check whether the live slot
    still holds *our* wrapper. If another extension has wrapped us on top
    (i.e. live_func is not our adept_func but also not the original we
    saved), blindly restoring would silently remove *their* wrapper too.
    In that case we skip the restore for that slot and log a warning so the
    operator knows the coexistence situation.

    Saved originals of skipped slots are kept in ``ORIGINAL_SAMPLERS``: our
    wrapper is still reachable through the foreign wrapper and looks its
    original up there when Adept is disabled. All other entries are dropped.
    """
    # key in ORIGINAL_SAMPLERS -> (attribute on k_diffusion.sampling, our wrapper)
    slots = {
        'euler': ('sample_euler', sample_adept_euler),
        'euler_ancestral': ('sample_euler_ancestral', sample_adept_euler_ancestral),
        'heun': ('sample_heun', sample_adept_heun),
        'dpmpp_2m': ('sample_dpmpp_2m', sample_adept_dpmpp_2m),
        'dpmpp_2s_ancestral': ('sample_dpmpp_2s_ancestral', sample_adept_dpmpp_2s_ancestral),
        'lms': ('sample_lms', sample_adept_lms),
    }
    restored_count = 0
    skipped_keys = set()
    for key, (attr_name, our_func) in slots.items():
        if key not in ORIGINAL_SAMPLERS:
            continue
        live_func = getattr(k_diffusion.sampling, attr_name, None)
        saved_original = ORIGINAL_SAMPLERS[key]
        if live_func is our_func:
            # Normal case: we still own the slot — safe to restore.
            setattr(k_diffusion.sampling, attr_name, saved_original)
            restored_count += 1
        elif live_func is saved_original:
            # Already restored somehow — nothing to do.
            restored_count += 1
        else:
            # Another extension wrapped us. Restoring would silently
            # remove their wrapper; skip and warn instead.
            print(f"⚠️ Adept unpatch: {attr_name} is currently owned by another "
                  f"extension ({live_func!r}). Skipping restore to avoid breaking "
                  f"their wrapper — you may need to reload the UI to fully unload.")
            skipped_keys.add(key)
    # Forget everything except the originals our still-live wrappers need.
    for key in list(ORIGINAL_SAMPLERS):
        if key not in skipped_keys:
            del ORIGINAL_SAMPLERS[key]
    skipped_count = len(skipped_keys)
    print(f"🔄 Adept Sampler: Restored {restored_count} samplers"
          + (f", skipped {skipped_count} (foreign wrappers)" if skipped_count else ""))
# ============================================================================
# A1111 EXTENSION SCRIPT
# ============================================================================
class AdeptSamplerScript(scripts.Script):
    """Always-visible WebUI script exposing the Adept Sampler controls.

    ``ui`` builds the Gradio components, ``process`` copies their values into
    the module-level ``ADEPT_STATE`` that the patched samplers read, and
    ``process_batch`` / ``postprocess_batch`` bracket each batch with the VAE
    reflection patcher.
    """

    def title(self):
        """Return the script name."""
        return "Adept Sampler v5"

    def show(self, is_img2img):
        """Return ``scripts.AlwaysVisible`` for both txt2img and img2img."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the settings accordion.

        Returns:
            The list of components whose values are passed, in the same
            order, to ``process`` after ``p``.
        """
        with gr.Accordion("Adept Sampler v5", open=False):
            enabled = gr.Checkbox(label="Enable Adept Sampler", value=False, elem_id="adept_enabled")
            with gr.Row():
                scale = gr.Slider(minimum=0.5, maximum=2.0, step=0.05, value=1.0, label="Weight Scale")
                shift = gr.Slider(minimum=-0.5, maximum=0.5, step=0.01, value=0.0, label="Weight Shift")
            with gr.Row():
                start_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Start Percent")
                end_pct = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=1.0, label="End Percent")
            gr.HTML("<p style='color: #888; font-size: 0.85em; margin: 2px 0 10px;'>"
                    "⚠️ Weight Scale / Shift / Start–End apply to the 6 patched k-diffusion samplers only. "
                    "Custom samplers (Akashic, Adept, Mirror) use their own internal parameters.</p>")
            with gr.Row():
                eta = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Eta (Ancestral samplers)")
                s_noise = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="S-Noise")
            adaptive_eta = gr.Checkbox(label="Adaptive Eta (dynamic eta during sampling)", value=False)
            scheduler = gr.Dropdown(
                choices=["Standard", "AOS-V", "AOS-Epsilon", "AkashicAOS", "Entropic", "SNR-Optimized",
                         "Constant-Rate", "Adaptive-Optimized", "Cosine-Annealed", "LogSNR-Uniform",
                         "Tanh Mid-Boost", "Exponential Tail", "Jittered-Karras", "Stochastic",
                         "JYS (Dynamic)", "Hybrid JYS-Karras", "AYS-SDXL",
                         "AkashicAOS Alt", "AkashicEQFlow"],
                value="Standard", label="Scheduler Type"
            )
            vae_reflection = gr.Checkbox(label="Enable VAE Reflection (fixes edge artifacts for EQ-VAE)", value=False)

            gr.HTML("<hr style='margin: 15px 0;'>")
            gr.HTML("<h3 style='margin: 10px 0;'>🌀 Custom Advanced Samplers</h3>")
            gr.HTML("<p style='color: #888; font-size: 0.9em;'>Enable to use Akashic/Adept/Ancestral samplers instead of k-diffusion</p>")
            use_custom = gr.Checkbox(label="Use Custom Sampler (overrides k-diffusion)", value=False)
            custom_type = gr.Dropdown(
                choices=["Akashic Solver v2", "Adept Solver", "Adept Ancestral Solver", "Mirror Correction Euler"],
                value="Akashic Solver v2", label="Custom Sampler Type"
            )
            with gr.Accordion("⚙️ Akashic Solver Settings", open=False):
                tau = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Tau (0=ODE, 1=SDE)")
                phase_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Phase Strength")
                smea = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="SMEA (high-res coherency)")
                ndb = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="NDB (detail boost)")
                eqvae = gr.Dropdown(choices=["Off", "Balanced"], value="Off", label="EQ-VAE Mode")
            with gr.Accordion("⚙️ Adept Solver Settings", open=False):
                solver_order = gr.Slider(minimum=1, maximum=3, step=1, value=2, label="Order (1-3)")
                use_corrector = gr.Checkbox(value=True, label="Use Corrector")
            with gr.Accordion("⚙️ Ancestral Solver Settings", open=False):
                phase_noise = gr.Checkbox(value=False, label="Phase-Aware Noise")
                enhanced_deriv = gr.Checkbox(value=False, label="Enhanced Derivative")
            with gr.Accordion("⚙️ Mirror Correction Euler Settings", open=False):
                gr.HTML("<p style='color: #888; font-size: 0.9em;'>Active only when Custom Sampler = Mirror Correction Euler</p>")
                mirror_correction_phase = gr.Slider(
                    minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                    label="Correction Phase (fraction of steps with 3-call Heun correction)"
                )
                mirror_smooth_phase = gr.Checkbox(
                    value=False,
                    label="Smooth Phase (log-sigma blend instead of binary cutoff)"
                )

            gr.HTML("<hr style='margin: 15px 0;'>")
            gr.HTML("<h3 style='margin: 10px 0;'>🎛️ CFG Enhancements</h3>")
            gr.HTML("<p style='color: #888; font-size: 0.9em;'>"
                    "Combat CFG Drift works in stock A1111 via official callback. "
                    "Spectral Modulation &amp; Phase-Aware CFG use a native sampler hook on "
                    "Forge/reForge-like backends, or a CFGDenoiser monkey-patch on stock A1111 "
                    "(near-parity; active mode logged to console).</p>")
            with gr.Accordion("⚙️ CFG Enhancement Settings", open=False):
                cfg_drift_enabled = gr.Checkbox(value=False, label="Enable Combat CFG Drift")
                with gr.Row():
                    cfg_drift_method = gr.Dropdown(
                        choices=["mean", "median"], value="mean",
                        label="Drift Method"
                    )
                    cfg_drift_intensity = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.05, value=0.5,
                        label="Drift Intensity"
                    )
                spectral_cfg_enabled = gr.Checkbox(
                    value=False, label="Enable Spectral Modulation (native hook or A1111 monkey-patch)"
                )
                with gr.Row():
                    spectral_multiplier = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.05, value=1.0,
                        label="Spectral Multiplier"
                    )
                    spectral_percentile = gr.Slider(
                        minimum=1.0, maximum=25.0, step=0.5, value=5.0,
                        label="Spectral Percentile"
                    )
                phase_cfg_enabled = gr.Checkbox(
                    value=False, label="Enable Phase-Aware CFG (native hook or A1111 monkey-patch)"
                )
                with gr.Row():
                    phase_cfg_alpha = gr.Slider(
                        minimum=1.1, maximum=4.0, step=0.1, value=2.0,
                        label="Phase CFG Alpha"
                    )
                    phase_cfg_beta = gr.Slider(
                        minimum=1.1, maximum=4.0, step=0.1, value=2.0,
                        label="Phase CFG Beta"
                    )

        # Order must match the positional parameters of process() after `p`.
        return [enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection,
                use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector,
                phase_noise, enhanced_deriv,
                mirror_correction_phase, mirror_smooth_phase,
                cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity,
                spectral_cfg_enabled, spectral_multiplier, spectral_percentile,
                phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta]

    def process(self, p, enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection,
                use_custom, custom_type, tau, phase_strength, smea, ndb, eqvae, solver_order, use_corrector,
                phase_noise, enhanced_deriv,
                mirror_correction_phase, mirror_smooth_phase,
                cfg_drift_enabled, cfg_drift_method, cfg_drift_intensity,
                spectral_cfg_enabled, spectral_multiplier, spectral_percentile,
                phase_cfg_enabled, phase_cfg_alpha, phase_cfg_beta):
        """Copy the UI values into ``ADEPT_STATE`` and tag the generation info.

        The positional arguments after ``p`` arrive in the order returned by
        ``ui``. When ``enabled`` is set, the chosen scheduler, CFG runtime
        mode and (if used) custom sampler are added to
        ``p.extra_generation_params``.
        """
        # Gate all sub-features through the master enabled switch.
        # This prevents CFG hooks, native patches, and VAE Reflection from
        # activating when the extension is globally toggled off.
        ADEPT_STATE.update({
            "enabled": enabled,
            "scale": scale,
            "shift": shift,
            "start_pct": start_pct,
            "end_pct": end_pct,
            "eta": eta,
            "s_noise": s_noise,
            "adaptive_eta": adaptive_eta,
            "scheduler": scheduler,
            "vae_reflection": enabled and vae_reflection,  # gated
            "use_custom_sampler": use_custom,
            "custom_sampler": custom_type,
            "tau": tau,
            "phase_strength": phase_strength,
            "smea_strength": smea,
            "ndb_strength": ndb,
            "eqvae_mode": eqvae,
            "solver_order": int(solver_order),
            "use_corrector": use_corrector,
            "phase_noise": phase_noise,
            "enhanced_derivative": enhanced_deriv,
            # Mirror Correction Euler
            "mirror_correction_phase": mirror_correction_phase,
            "mirror_smooth_phase": mirror_smooth_phase,
            # CFG enhancements — all gated through enabled
            "cfg_drift_enabled": enabled and cfg_drift_enabled,
            "cfg_drift_method": cfg_drift_method,
            "cfg_drift_intensity": cfg_drift_intensity,
            "spectral_cfg_enabled": enabled and spectral_cfg_enabled,
            "spectral_multiplier": spectral_multiplier,
            "spectral_percentile": spectral_percentile,
            "phase_cfg_enabled": enabled and phase_cfg_enabled,
            "phase_cfg_alpha": phase_cfg_alpha,
            "phase_cfg_beta": phase_cfg_beta,
        })
        # Scheduler is now applied inside each patched sampler function,
        # so p.sampler.model_wrap patching is no longer needed here.
        # Always reconfigure CFG runtime — even when disabled — so any previously
        # installed native hook or A1111 callbacks get cleanly removed.
        runtime_mode = configure_cfg_runtime()
        if enabled:
            info = {
                "Adept Sampler": "v5",
                "Adept Scheduler": scheduler,
                "CFG Runtime": runtime_mode,
            }
            if use_custom:
                info["Adept Custom"] = custom_type
                if custom_type == "Akashic Solver v2":
                    info["Adept Tau"] = tau
                    info["Adept EQ-VAE"] = eqvae
            p.extra_generation_params.update(info)

    def process_batch(self, p, *args, **kwargs):
        """Apply VAE Reflection before batch processing.

        The entered patcher is stored on ``p`` as ``adept_vae_patcher`` so
        that ``postprocess_batch`` can exit it. Errors are reported and
        otherwise ignored.
        """
        if ADEPT_STATE.get("enabled", False) and ADEPT_STATE.get("vae_reflection", False):
            try:
                vae_model = shared.sd_model.first_stage_model
                patcher = VAEReflectionPatcher(vae_model)
                patcher.__enter__()
                p.adept_vae_patcher = patcher
            except Exception as e:
                print(f"⚠️ VAE Reflection error: {e}")

    def postprocess_batch(self, p, *args, **kwargs):
        """Restore VAE padding modes after batch processing."""
        if hasattr(p, 'adept_vae_patcher'):
            try:
                p.adept_vae_patcher.__exit__(None, None, None)
            except Exception as e:
                print(f"⚠️ VAE Reflection restore error: {e}")
            finally:
                # Detach the patcher even when __exit__ raised, so a stale
                # patcher is never exited a second time on a later batch.
                delattr(p, 'adept_vae_patcher')
            # Safety net: force-restore even if the patcher context failed
            force_restore_vae_reflection()
# ============================================================================
# INITIALIZATION
# ============================================================================
#
# k-diffusion wrappers are installed via on_before_ui (fires after all
# extensions are imported) rather than at bare module import time.
# This reduces the risk of interacting badly with other extensions that
# also wrap k_diffusion.sampling functions, because our wrappers are put
# on last and therefore sit outermost in the call chain.
# Uninstall happens in on_script_unloaded() as before.
def _adept_deferred_init():
    """Install the Adept wrappers over the k-diffusion samplers."""
    patch_k_diffusion()


try:
    # Preferred path: let the WebUI invoke the installer via on_before_ui.
    script_callbacks.on_before_ui(_adept_deferred_init)
except Exception:
    # Fallback: if on_before_ui isn't available (older A1111), patch immediately.
    _adept_deferred_init()
def on_script_unloaded():
    """Undo every hook this extension installed; best-effort.

    The cleanup steps run in a fixed order and independently of each other:
    a failing step is reported on stdout and the remaining steps still run.
    Nothing is raised to the caller.
    """
    # Each step is wrapped in a lambda so that the helper name is resolved
    # inside the try block below; a missing helper is then reported like any
    # other failure instead of aborting the whole unload.
    cleanup_steps = (
        ("VAE reflection restore", lambda: force_restore_vae_reflection()),
        ("A1111 CFG callbacks", lambda: uninstall_a1111_cfg_callbacks()),
        ("native CFG hook", lambda: uninstall_native_cfg_hook()),
        ("CFGDenoiser patch", lambda: unpatch_cfg_denoiser()),
        ("k-diffusion samplers", lambda: unpatch_k_diffusion()),
    )
    for label, step in cleanup_steps:
        try:
            step()
        except Exception as e:
            print(f"⚠️ Adept unload: {label} cleanup failed: {e}")
# Register the unload hook where the WebUI offers one.
try:
    script_callbacks.on_script_unloaded(on_script_unloaded)
except AttributeError:
    print("⚠️ Script unload callback not available")

# Startup banner, one entry per console line.
print(
    "🚀 Adept Sampler v5 loaded!",
    " ✨ 4 Custom Samplers: Akashic v2, Adept Solver, Adept Ancestral, Mirror Correction Euler",
    " ⚡ 6 k-diffusion Samplers with weight scaling",
    " 📅 18 Schedulers (including AkashicAOS Alt, AkashicEQFlow)",
    " 🎨 VAE Reflection",
    " ✅ A1111 port of ComfyUI-Adept-Sampler",
    sep="\n",
)