sdas / adept-sampler-v3 /scripts /adept_sampler_v3_FULL.py
dikdimon's picture
Upload adept-sampler-v3 using SD-Hub
8022862 verified
"""
Adept Sampler FULL PORT for Automatic1111 WebUI
Ported from ComfyUI/reForge extension
COMPLETE VERSION with:
- ALL Schedulers (16 types)
- ALL Samplers (Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S, LMS)
- VAE Reflection
- Dynamic Weight Scaling
Version: 3.0 FULL
"""
import torch
import numpy as np
import math
from tqdm import trange
from modules import scripts, shared, script_callbacks
import gradio as gr
import k_diffusion.sampling
# ============================================================================
# GLOBAL STATE
# ============================================================================
ADEPT_STATE = {
"enabled": False,
"scale": 1.0,
"shift": 0.0,
"start_pct": 0.0,
"end_pct": 1.0,
"eta": 1.0,
"s_noise": 1.0,
"adaptive_eta": False,
"scheduler": "Standard",
"vae_reflection": False,
}
# Store original samplers
ORIGINAL_SAMPLERS = {}
# VAE Reflection state
_vae_reflection_active = False
_vae_original_padding_modes = {}
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def to_d(x, sigma, denoised):
"""Convert denoised prediction to derivative."""
diff = x - denoised
safe_sigma = torch.clamp(sigma, min=1e-4)
derivative = diff / safe_sigma
sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
derivative_max = torch.abs(derivative).max()
if derivative_max > sigma_adaptive_threshold:
derivative = torch.clamp(derivative, -sigma_adaptive_threshold, sigma_adaptive_threshold)
return derivative
def get_ancestral_step(sigma, sigma_next, eta=1.0):
"""Calculate ancestral step sizes."""
if sigma_next == 0:
return 0.0, 0.0
sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
return sigma_down, sigma_up
def compute_dynamic_scale(step_idx, total_steps, base_scale, start_pct, end_pct):
"""Compute dynamic scale based on progress."""
progress = step_idx / max(total_steps - 1, 1)
if progress < start_pct or progress > end_pct:
return 1.0
# Smooth fade in/out
if progress < start_pct + 0.1:
fade = (progress - start_pct) / 0.1
return 1.0 + (base_scale - 1.0) * fade
elif progress > end_pct - 0.1:
fade = (end_pct - progress) / 0.1
return 1.0 + (base_scale - 1.0) * fade
else:
return base_scale
def default_noise_sampler(x):
"""Simple noise sampler fallback."""
def sampler(sigma, sigma_next):
return torch.randn_like(x)
return sampler
# ============================================================================
# WEIGHT PATCHER
# ============================================================================
class AdeptWeightPatcher:
"""Context manager for safe model weight modification."""
def __init__(self, model, scale, shift):
self.model = model
self.scale = scale
self.shift = shift
self.backups = {}
self.target_layers = []
# Cache target layers
for name, module in model.named_modules():
if any(block in name for block in ['input_blocks', 'middle_block', 'output_blocks']):
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if hasattr(module, 'weight') and module.weight is not None:
self.target_layers.append((name, module))
def __enter__(self):
if abs(self.scale - 1.0) < 1e-6 and abs(self.shift) < 1e-6:
return self
try:
for name, module in self.target_layers:
self.backups[name] = module.weight.data.clone()
module.weight.data = module.weight.data * self.scale + self.shift
except Exception as e:
print(f"⚠️ Weight patching failed: {e}")
self.__exit__(None, None, None)
raise
return self
def __exit__(self, exc_type, exc_val, exc_tb):
try:
for name, module in self.target_layers:
if name in self.backups:
module.weight.data.copy_(self.backups[name])
self.backups.clear()
except Exception as e:
print(f"❌ CRITICAL: Failed to restore weights: {e}")
for name, backup_data in self.backups.items():
try:
for n, m in self.target_layers:
if n == name:
m.weight.data.copy_(backup_data)
except:
pass
return False
# ============================================================================
# VAE REFLECTION PATCHER
# ============================================================================
class VAEReflectionPatcher:
"""Context manager for VAE reflection padding."""
def __init__(self, vae_model):
self.vae_model = vae_model
self.backups = {}
def __enter__(self):
global _vae_reflection_active, _vae_original_padding_modes
if _vae_reflection_active or self.vae_model is None:
return self
_vae_original_padding_modes.clear()
patched_count = 0
try:
for name, module in self.vae_model.named_modules():
if isinstance(module, torch.nn.Conv2d):
_vae_original_padding_modes[name] = module.padding_mode
module.padding_mode = 'reflect'
patched_count += 1
_vae_reflection_active = True
print(f"🪞 VAE Reflection: Patched {patched_count} Conv2d layers")
except Exception as e:
print(f"❌ VAE Reflection failed: {e}")
self.__exit__(None, None, None)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _vae_reflection_active, _vae_original_padding_modes
if self.vae_model is None:
_vae_reflection_active = False
_vae_original_padding_modes.clear()
return False
restored_count = 0
try:
for name, module in self.vae_model.named_modules():
if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes:
module.padding_mode = _vae_original_padding_modes[name]
restored_count += 1
_vae_reflection_active = False
_vae_original_padding_modes.clear()
print(f"🔄 VAE Reflection: Restored {restored_count} layers")
except Exception as e:
print(f"⚠️ VAE Reflection restore warning: {e}")
return False
# ============================================================================
# ALL SCHEDULERS (16 types)
# ============================================================================
def create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""AOS-V (Anime-Optimized Schedule for v-prediction models)."""
rho = 7.0
p1_steps = int(num_steps * 0.2)
p2_steps = int(num_steps * 0.6)
ramp = torch.empty(num_steps, device=device, dtype=torch.float32)
if p1_steps > 0:
torch.linspace(0, 1, p1_steps, out=ramp[:p1_steps])
ramp[:p1_steps].pow_(0.5).mul_(0.6)
if p2_steps > p1_steps:
torch.linspace(0.6, 0.9, p2_steps - p1_steps, out=ramp[p1_steps:p2_steps])
if num_steps > p2_steps:
torch.linspace(0, 1, num_steps - p2_steps, out=ramp[p2_steps:])
ramp[p2_steps:].pow_(3).mul_(0.1).add_(0.9)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
ramp.mul_(min_inv_rho - max_inv_rho).add_(max_inv_rho).pow_(rho)
return torch.cat([ramp, torch.zeros(1, device=device)])
def create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""AOS-ε (Anime-Optimized Schedule for epsilon-prediction models)."""
rho = 7.0
p1_frac, p2_frac = 0.35, 0.7
ramp_p1_val, ramp_p2_val = 0.4, 0.75
p1_steps = int(num_steps * p1_frac)
p2_steps = int(num_steps * p2_frac)
phase1_ramp = torch.linspace(0, 1, p1_steps, device=device) ** 1.5 * ramp_p1_val
phase2_ramp = torch.linspace(ramp_p1_val, ramp_p2_val, p2_steps - p1_steps, device=device)
phase3_base = torch.linspace(0, 1, num_steps - p2_steps, device=device) ** 0.7
phase3_ramp = phase3_base * (1 - ramp_p2_val) + ramp_p2_val
if p1_steps == 0: phase1_ramp = torch.empty(0, device=device)
if p2_steps - p1_steps == 0: phase2_ramp = torch.empty(0, device=device)
if num_steps - p2_steps == 0: phase3_ramp = torch.empty(0, device=device)
ramp = torch.cat([phase1_ramp, phase2_ramp, phase3_ramp])
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""AkashicAOS v2: Detail-Progressive Schedule for EQ-VAE SDXL models."""
rho = 7.0
u = torch.linspace(0, 1, num_steps, device=device)
detail_power = 0.85
u_progressive = u ** detail_power
mid_boost_strength = 0.08
mid_boost = mid_boost_strength * torch.sin(math.pi * u) * (1 - u * 0.5)
u_modulated = u_progressive + mid_boost
u_min, u_max = u_modulated.min(), u_modulated.max()
if u_max - u_min > 1e-8:
u_modulated = (u_modulated - u_min) / (u_max - u_min)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho
for i in range(1, len(sigmas)):
if sigmas[i] >= sigmas[i-1]:
sigmas[i] = sigmas[i-1] * 0.995
max_ratio = 1.5
if i > 0 and sigmas[i-1] / sigmas[i] > max_ratio:
sigmas[i] = sigmas[i-1] / max_ratio
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device='cpu'):
"""Entropic power schedule."""
rho = 7.0
linear_ramp = torch.linspace(0, 1, num_steps, device=device)
power_ramp = 1 - torch.linspace(1, 0, num_steps, device=device) ** power
ramp = (linear_ramp + power_ramp) / 2.0
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Schedule optimized around log SNR = 0 region."""
rho = 7.0
log_snr_max = 2 * torch.log(sigma_max)
log_snr_min = 2 * torch.log(sigma_min)
t = torch.linspace(0, 1, num_steps, device=device)
concentration_power = 3.0
sigmoid_t = torch.sigmoid(concentration_power * (t - 0.5))
linear_t = t
blend_factor = 0.7
combined_t = blend_factor * sigmoid_t + (1 - blend_factor) * linear_t
log_snr = log_snr_max + combined_t * (log_snr_min - log_snr_max)
sigmas = torch.exp(log_snr / 2)
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Constant rate of distributional change."""
rho = 7.0
t = torch.linspace(0, 1, num_steps, device=device)
corrected_t = t + 0.3 * torch.sin(math.pi * t) * (1 - t)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + corrected_t * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Adaptive schedule combining multiple strategies."""
rho = 7.0
base_t = torch.linspace(0, 1, num_steps, device=device)
strategies = [
lambda t: t,
lambda t: t ** 0.8,
lambda t: t + 0.2 * torch.sin(2 * math.pi * t) * (1 - t),
lambda t: 1 / (1 + torch.exp(-3 * (t - 0.5))),
]
weights = [0.2, 0.3, 0.2, 0.3]
combined_t = sum(w * s(base_t) for w, s in zip(weights, strategies))
if (combined_t.max() - combined_t.min()) > 1e-6:
combined_t = (combined_t - combined_t.min()) / (combined_t.max() - combined_t.min())
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + combined_t * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_cosine_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Cosine-annealed schedule."""
rho = 7.0
u = torch.linspace(0, 1, num_steps, device=device)
t = (1 - torch.cos(math.pi * u)) / 2
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Uniform in log-SNR space."""
u = torch.linspace(0, 1, num_steps, device=device)
log_snr_max = 2 * torch.log(sigma_max)
log_snr_min = 2 * torch.log(sigma_min)
log_snr = log_snr_max + u * (log_snr_min - log_snr_max)
sigmas = torch.exp(log_snr / 2)
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device='cpu', k=4.0):
"""Concentrate steps near mid-range sigmas."""
rho = 7.0
u = torch.linspace(0, 1, num_steps, device=device)
k_tensor = torch.tensor(k, device=device, dtype=u.dtype)
t = 0.5 * (torch.tanh(k_tensor * (u - 0.5)) / torch.tanh(k_tensor / 2) + 1.0)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device='cpu', pivot=0.7, gamma=0.8, beta=5.0):
"""Faster early lock-in with extra resolution in final steps."""
rho = 7.0
u = torch.linspace(0, 1, num_steps, device=device)
early_mask = u < pivot
late_mask = ~early_mask
t = torch.empty_like(u)
t[early_mask] = (u[early_mask] / pivot) ** gamma * pivot
late_u = u[late_mask]
t[late_mask] = pivot + (1 - pivot) * (1 - torch.exp(-beta * (late_u - pivot) / (1 - pivot)))
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Karras schedule with controlled jitter."""
if num_steps <= 0:
return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])
rho = 7.0
indices = torch.arange(num_steps, device=device, dtype=torch.float32)
denom = max(1, num_steps - 1)
base = (indices + 0.5) / denom
jitter_seed = torch.sin((indices + 1) * 2.3999632)
jitter_strength = 0.35
jitter = jitter_seed * jitter_strength / denom
u = torch.clamp(base + jitter, 0.0, 1.0)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device='cpu', noise_type='brownian', noise_scale=0.3, base_schedule='karras'):
"""Stochastic scheduler with controlled randomness."""
rho = 7.0
# Base schedule
if base_schedule == 'karras':
indices = torch.arange(num_steps, device=device, dtype=torch.float32)
u = (indices / max(1, num_steps - 1)) ** (1 / rho)
elif base_schedule == 'cosine':
u = torch.linspace(0, 1, num_steps, device=device)
u = (1 - torch.cos(math.pi * u)) / 2
else: # uniform
u = torch.linspace(0, 1, num_steps, device=device)
# Add noise
if noise_type == 'brownian':
noise = torch.randn(num_steps, device=device).cumsum(0)
noise = noise / noise.std()
elif noise_type == 'uniform':
noise = torch.rand(num_steps, device=device) * 2 - 1
else: # normal
noise = torch.randn(num_steps, device=device)
u_noisy = u + noise * noise_scale / num_steps
u_noisy = torch.clamp(u_noisy, 0, 1)
# Sort to maintain monotonicity
u_noisy, _ = torch.sort(u_noisy, descending=True)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + u_noisy * (min_inv_rho - max_inv_rho)) ** rho
return torch.cat([sigmas, torch.zeros(1, device=device)])
def create_jys_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""JYS (Jump Your Steps) dynamic scheduler."""
if num_steps <= 0:
return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])
if num_steps == 1:
return torch.tensor([sigma_max.item(), 0.0], device=device)
elif num_steps == 2:
mid = (sigma_max + sigma_min) / 2
return torch.tensor([sigma_max.item(), mid.item(), 0.0], device=device)
# Dynamic phase-based distribution
early_steps = max(1, int(num_steps * 0.2))
final_steps = max(1, int(num_steps * 0.2))
middle_steps = max(1, num_steps - early_steps - final_steps)
sigma_max_val = sigma_max.item() if torch.is_tensor(sigma_max) else float(sigma_max)
# Early phase (foundation)
early_jump_size = max(50, (sigma_max_val - 600) // early_steps)
early_sigmas = []
current_sigma = sigma_max_val
for _ in range(early_steps):
early_sigmas.append(current_sigma)
current_sigma = max(600, current_sigma - early_jump_size)
# Middle phase (structure + detail)
middle_sigmas = []
structure_steps = max(1, middle_steps // 2)
structure_jump = max(10, (600 - 300) // structure_steps)
current_sigma = 600
for _ in range(structure_steps):
middle_sigmas.append(current_sigma)
current_sigma = max(300, current_sigma - structure_jump)
detail_steps = middle_steps - structure_steps
if detail_steps > 0:
detail_jump = max(5, (300 - 200) // detail_steps)
current_sigma = 300
for _ in range(detail_steps):
middle_sigmas.append(current_sigma)
current_sigma = max(200, current_sigma - detail_jump)
# Final phase (refinement)
final_start = min(middle_sigmas) if middle_sigmas else 200
final_jump = max(5, final_start // final_steps)
final_sigmas = []
current_sigma = final_start
for _ in range(final_steps):
final_sigmas.append(current_sigma)
current_sigma = max(0, current_sigma - final_jump)
all_sigmas = early_sigmas + middle_sigmas + final_sigmas
unique_sigmas = list(dict.fromkeys(all_sigmas))
unique_sigmas.sort(reverse=True)
# Pad if needed
while len(unique_sigmas) < num_steps:
for i in range(len(unique_sigmas) - 1):
mid = (unique_sigmas[i] + unique_sigmas[i + 1]) / 2
if mid not in unique_sigmas:
unique_sigmas.insert(i + 1, mid)
if len(unique_sigmas) >= num_steps:
break
if len(unique_sigmas) > num_steps:
unique_sigmas = unique_sigmas[:num_steps]
if unique_sigmas[-1] != 0:
unique_sigmas.append(0)
return torch.tensor(unique_sigmas, device=device, dtype=torch.float32)
def create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""Hybrid: JYS mid-phase with Karras locks."""
if num_steps <= 0:
return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])
rho = 7.0
jys_sigmas = create_jys_sigmas(sigma_max, sigma_min, num_steps, device=device)[:-1]
indices = torch.arange(num_steps, device=device, dtype=torch.float32)
denom = max(1, num_steps - 1)
base = (indices + 0.5) / denom
jitter_seed = torch.sin((indices + 1) * 2.3999632)
jitter_strength = 0.35
jitter = jitter_seed * jitter_strength / denom
u = torch.clamp(base + jitter, 0.0, 1.0)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
karras_sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho
positions = torch.linspace(0, 1, num_steps, device=device)
jys_weight = torch.empty_like(positions)
early_mask = positions < 0.3
mid_mask = (positions >= 0.3) & (positions < 0.8)
late_mask = positions >= 0.8
jys_weight[early_mask] = 0.2 + 0.4 * (positions[early_mask] / 0.3)
jys_weight[mid_mask] = 0.6 + 0.3 * ((positions[mid_mask] - 0.3) / 0.5)
jys_weight[late_mask] = 0.9
jys_weight = jys_weight.clamp(0.2, 0.9)
log_jys = torch.log(jys_sigmas.clamp_min(1e-6))
log_karras = torch.log(karras_sigmas.clamp_min(1e-6))
log_hybrid = torch.lerp(log_karras, log_jys, jys_weight)
hybrid = torch.exp(log_hybrid)
smoothing = 1.0 - 0.05 * (1 - positions) ** 2
hybrid = hybrid * smoothing
for i in range(1, hybrid.shape[0]):
if hybrid[i] > hybrid[i - 1]:
hybrid[i] = hybrid[i - 1] * 0.999
return torch.cat([hybrid, torch.zeros(1, device=device)])
def create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
"""AYS (Align Your Steps) optimized for SDXL."""
AYS_SCHEDULES = {
10: [1.0000, 0.8751, 0.7502, 0.6254, 0.5004, 0.3755, 0.2506, 0.1253, 0.0502, 0.0000],
15: [1.0000, 0.9167, 0.8334, 0.7501, 0.6668, 0.5835, 0.5002, 0.4169, 0.3336,
0.2503, 0.1670, 0.0837, 0.0335, 0.0084, 0.0000],
20: [1.0000, 0.9375, 0.8750, 0.8125, 0.7500, 0.6875, 0.6250, 0.5625, 0.5000,
0.4375, 0.3750, 0.3125, 0.2500, 0.1875, 0.1250, 0.0625, 0.0313, 0.0156,
0.0039, 0.0000],
25: [1.0000, 0.9500, 0.9000, 0.8500, 0.8000, 0.7500, 0.7000, 0.6500, 0.6000,
0.5500, 0.5000, 0.4500, 0.4000, 0.3500, 0.3000, 0.2500, 0.2000, 0.1500,
0.1000, 0.0625, 0.0391, 0.0195, 0.0098, 0.0024, 0.0000],
30: [1.0000, 0.9583, 0.9167, 0.8750, 0.8333, 0.7917, 0.7500, 0.7083, 0.6667,
0.6250, 0.5833, 0.5417, 0.5000, 0.4583, 0.4167, 0.3750, 0.3333, 0.2917,
0.2500, 0.2083, 0.1667, 0.1250, 0.0833, 0.0521, 0.0326, 0.0163, 0.0081,
0.0041, 0.0010, 0.0000],
}
if num_steps in AYS_SCHEDULES:
normalized = torch.tensor(AYS_SCHEDULES[num_steps], device=device, dtype=torch.float32)
else:
available_steps = sorted(AYS_SCHEDULES.keys())
if num_steps < available_steps[0]:
ref_steps = available_steps[0]
elif num_steps > available_steps[-1]:
ref_steps = available_steps[-1]
else:
ref_steps = min([s for s in available_steps if s >= num_steps], default=available_steps[-1])
ref_schedule = np.array(AYS_SCHEDULES[ref_steps])
t_ref = np.linspace(0, 1, len(ref_schedule))
t_new = np.linspace(0, 1, num_steps + 1)
log_ref = np.log(ref_schedule + 1e-8)
log_ref[-1] = log_ref[-2] - 3.0
log_interp = np.interp(t_new, t_ref, log_ref)
normalized_np = np.exp(log_interp)
normalized_np[-1] = 0.0
normalized = torch.tensor(normalized_np, device=device, dtype=torch.float32)
sigma_range = sigma_max - sigma_min
sigmas = normalized * sigma_range + sigma_min
sigmas[0] = sigma_max
sigmas[-1] = 0.0
for i in range(1, len(sigmas) - 1):
if sigmas[i] >= sigmas[i-1]:
sigmas[i] = sigmas[i-1] * 0.999
return sigmas
def apply_custom_scheduler(sigmas, scheduler_type="Standard"):
"""Apply custom scheduler to sigma schedule."""
if scheduler_type == "Standard" or len(sigmas) < 2:
return sigmas
sigma_min = sigmas[-1] if sigmas[-1] > 0 else sigmas[-2] * 0.001
sigma_max = sigmas[0]
steps = len(sigmas) - 1
device = sigmas.device
scheduler_map = {
"AOS-V": create_aos_v_sigmas,
"AOS-Epsilon": create_aos_e_sigmas,
"AkashicAOS": create_aos_akashic_sigmas,
"Entropic": create_entropic_sigmas,
"SNR-Optimized": create_snr_optimized_sigmas,
"Constant-Rate": create_constant_rate_sigmas,
"Adaptive-Optimized": create_adaptive_optimized_sigmas,
"Cosine-Annealed": create_cosine_sigmas,
"LogSNR-Uniform": create_logsnr_uniform_sigmas,
"Tanh Mid-Boost": create_tanh_midboost_sigmas,
"Exponential Tail": create_exponential_tail_sigmas,
"Jittered-Karras": create_jittered_karras_sigmas,
"Stochastic": create_stochastic_sigmas,
"JYS (Dynamic)": create_jys_sigmas,
"Hybrid JYS-Karras": create_hybrid_jys_karras_sigmas,
"AYS-SDXL": create_ays_sdxl_sigmas,
}
if scheduler_type in scheduler_map:
try:
return scheduler_map[scheduler_type](sigma_max, sigma_min, steps, device)
except Exception as e:
print(f"⚠️ Scheduler {scheduler_type} failed: {e}, using standard")
return sigmas
return sigmas
# ============================================================================
# SAMPLER IMPLEMENTATIONS
# ============================================================================
@torch.no_grad()
def sample_adept_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Euler sampler with Adept weight scaling."""
if not ADEPT_STATE.get('enabled', False):
global ORIGINAL_SAMPLERS
if 'euler' in ORIGINAL_SAMPLERS:
return ORIGINAL_SAMPLERS['euler'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
return _basic_euler(model, x, sigmas, extra_args, callback, disable)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# Get settings
base_scale = ADEPT_STATE.get('scale', 1.0)
shift = ADEPT_STATE.get('shift', 0.0)
start_pct = ADEPT_STATE.get('start_pct', 0.0)
end_pct = ADEPT_STATE.get('end_pct', 1.0)
# Get UNet
try:
unet_model = shared.sd_model.model.diffusion_model
except AttributeError:
unet_model = None
total_steps = len(sigmas) - 1
print(f"✅ Adept Euler active: scale={base_scale:.2f}")
for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Euler"):
sigma = sigmas[i]
# Dynamic scale
current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
# Evaluate model with weight patching
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised = model(x, sigma * s_in, **extra_args)
else:
denoised = model(x, sigma * s_in, **extra_args)
# Euler step
d = to_d(x, sigma, denoised)
if torch.isnan(d).any() or torch.isinf(d).any():
d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
dt = sigmas[i + 1] - sigma
x = x + d * dt
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
return x
def _basic_euler(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Fallback basic Euler."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
return x
@torch.no_grad()
def sample_adept_euler_ancestral(model, x, sigmas, extra_args=None, callback=None,
disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
"""Euler Ancestral with Adept weight scaling."""
if not ADEPT_STATE.get('enabled', False):
global ORIGINAL_SAMPLERS
if 'euler_ancestral' in ORIGINAL_SAMPLERS:
return ORIGINAL_SAMPLERS['euler_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
return _basic_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# Get settings
base_scale = ADEPT_STATE.get('scale', 1.0)
shift = ADEPT_STATE.get('shift', 0.0)
start_pct = ADEPT_STATE.get('start_pct', 0.0)
end_pct = ADEPT_STATE.get('end_pct', 1.0)
use_adaptive_eta = ADEPT_STATE.get('adaptive_eta', False)
current_eta = ADEPT_STATE.get('eta', eta)
current_s_noise = ADEPT_STATE.get('s_noise', s_noise)
# Get UNet
try:
unet_model = shared.sd_model.model.diffusion_model
except AttributeError:
unet_model = None
if noise_sampler is None:
noise_sampler = default_noise_sampler(x)
total_steps = len(sigmas) - 1
print(f"✅ Adept Euler A active: scale={base_scale:.2f}, eta={current_eta:.2f}")
for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Euler A"):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
progress = i / max(total_steps, 1)
# Adaptive eta
if use_adaptive_eta:
if progress < 0.3:
current_eta = eta * 1.08
elif progress < 0.7:
current_eta = eta * 0.95
else:
current_eta = eta * 1.02
else:
current_eta = eta
# Dynamic scale
current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
# Evaluate model with weight patching
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised = model(x, sigma * s_in, **extra_args)
else:
denoised = model(x, sigma * s_in, **extra_args)
# Euler Ancestral step
sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
d = to_d(x, sigma, denoised)
if torch.isnan(d).any() or torch.isinf(d).any():
d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
dt = sigma_down - sigma
x = x + d * dt
if sigma_up > 0:
noise = noise_sampler(sigma, sigma_next) * current_s_noise
x = x + noise * sigma_up
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
return x
def _basic_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
"""Fallback basic Euler Ancestral."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
noise_sampler = default_noise_sampler(x)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta)
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigma_up > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
return x
@torch.no_grad()
def sample_adept_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Heun sampler with Adept weight scaling."""
if not ADEPT_STATE.get('enabled', False):
global ORIGINAL_SAMPLERS
if 'heun' in ORIGINAL_SAMPLERS:
return ORIGINAL_SAMPLERS['heun'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
return _basic_heun(model, x, sigmas, extra_args, callback, disable)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# Get settings
base_scale = ADEPT_STATE.get('scale', 1.0)
shift = ADEPT_STATE.get('shift', 0.0)
start_pct = ADEPT_STATE.get('start_pct', 0.0)
end_pct = ADEPT_STATE.get('end_pct', 1.0)
# Get UNet
try:
unet_model = shared.sd_model.model.diffusion_model
except AttributeError:
unet_model = None
total_steps = len(sigmas) - 1
print(f"✅ Adept Heun active: scale={base_scale:.2f}")
for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Heun"):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
# Dynamic scale
current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
# First evaluation
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised = model(x, sigma * s_in, **extra_args)
else:
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
if torch.isnan(d).any() or torch.isinf(d).any():
d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
dt = sigma_next - sigma
if sigma_next == 0:
# Last step
x = x + d * dt
else:
# Heun's method: two-stage
x_2 = x + d * dt
# Second evaluation
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
else:
denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
d_2 = to_d(x_2, sigma_next, denoised_2)
if torch.isnan(d_2).any() or torch.isinf(d_2).any():
d_2 = torch.nan_to_num(d_2, nan=0.0, posinf=1.0, neginf=-1.0)
# Average
d_prime = (d + d_2) / 2
x = x + d_prime * dt
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
return x
def _basic_heun(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Fallback basic Heun."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
if sigmas[i + 1] == 0:
x = x + d * dt
else:
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
return x
@torch.no_grad()
def sample_adept_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM++ 2M sampler with Adept weight scaling."""
if not ADEPT_STATE.get('enabled', False):
global ORIGINAL_SAMPLERS
if 'dpmpp_2m' in ORIGINAL_SAMPLERS:
return ORIGINAL_SAMPLERS['dpmpp_2m'](model, x, sigmas, extra_args, callback, disable)
return _basic_dpmpp_2m(model, x, sigmas, extra_args, callback, disable)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# Get settings
base_scale = ADEPT_STATE.get('scale', 1.0)
shift = ADEPT_STATE.get('shift', 0.0)
start_pct = ADEPT_STATE.get('start_pct', 0.0)
end_pct = ADEPT_STATE.get('end_pct', 1.0)
# Get UNet
try:
unet_model = shared.sd_model.model.diffusion_model
except AttributeError:
unet_model = None
total_steps = len(sigmas) - 1
print(f"✅ Adept DPM++ 2M active: scale={base_scale:.2f}")
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2M"):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
# Dynamic scale
current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
# Evaluate model with weight patching
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised = model(x, sigma * s_in, **extra_args)
else:
denoised = model(x, sigma * s_in, **extra_args)
# DPM++ 2M step
t, t_next = sigma, sigma_next
h = t_next - t
if old_denoised is None or sigma_next == 0:
# First step (Euler)
x = (sigma_next / sigma) * x - (-h).expm1() * denoised
else:
# Second order
h_last = t - sigmas[i - 1]
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_next / sigma) * x - (-h).expm1() * denoised_d
old_denoised = denoised
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
return x
def _basic_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""Fallback basic DPM++ 2M."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
t, t_next = sigmas[i], sigmas[i + 1]
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (t_next / t) * x - (-h).expm1() * denoised
else:
h_last = t - sigmas[i - 1]
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (t_next / t) * x - (-h).expm1() * denoised_d
old_denoised = denoised
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
return x
@torch.no_grad()
def sample_adept_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
"""DPM++ 2S Ancestral with Adept weight scaling."""
if not ADEPT_STATE.get('enabled', False):
global ORIGINAL_SAMPLERS
if 'dpmpp_2s_ancestral' in ORIGINAL_SAMPLERS:
return ORIGINAL_SAMPLERS['dpmpp_2s_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
return _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# Get settings
base_scale = ADEPT_STATE.get('scale', 1.0)
shift = ADEPT_STATE.get('shift', 0.0)
start_pct = ADEPT_STATE.get('start_pct', 0.0)
end_pct = ADEPT_STATE.get('end_pct', 1.0)
current_eta = ADEPT_STATE.get('eta', eta)
current_s_noise = ADEPT_STATE.get('s_noise', s_noise)
# Get UNet
try:
unet_model = shared.sd_model.model.diffusion_model
except AttributeError:
unet_model = None
if noise_sampler is None:
noise_sampler = default_noise_sampler(x)
total_steps = len(sigmas) - 1
print(f"✅ Adept DPM++ 2S A active: scale={base_scale:.2f}")
for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2S A"):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
# Dynamic scale
current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
# First evaluation
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised = model(x, sigma * s_in, **extra_args)
else:
denoised = model(x, sigma * s_in, **extra_args)
# DPM++ 2S step with ancestral noise
sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
if sigma_down == 0:
d = to_d(x, sigma, denoised)
x = x + d * (sigma_down - sigma)
else:
# Midpoint method
t, t_next = sigma, sigma_down
h = t_next - t
s = t + h * 0.5
# Step to midpoint
x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised
# Evaluate at midpoint
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised_mid = model(x_mid, s * s_in, **extra_args)
else:
denoised_mid = model(x_mid, s * s_in, **extra_args)
# Full step using midpoint
x = (t_next / t) * x - (-h).expm1() * denoised_mid
# Add ancestral noise
if sigma_up > 0:
noise = noise_sampler(sigma, sigma_next) * current_s_noise
x = x + noise * sigma_up
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
return x
def _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
"""Fallback basic DPM++ 2S Ancestral."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
noise_sampler = default_noise_sampler(x)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta)
if sigma_down == 0:
d = to_d(x, sigmas[i], denoised)
x = x + d * (sigma_down - sigmas[i])
else:
t, t_next = sigmas[i], sigma_down
h = t_next - t
s = t + h * 0.5
x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised
denoised_mid = model(x_mid, s * s_in, **extra_args)
x = (t_next / t) * x - (-h).expm1() * denoised_mid
if sigma_up > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
return x
@torch.no_grad()
def sample_adept_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
"""LMS sampler with Adept weight scaling."""
if not ADEPT_STATE.get('enabled', False):
global ORIGINAL_SAMPLERS
if 'lms' in ORIGINAL_SAMPLERS:
return ORIGINAL_SAMPLERS['lms'](model, x, sigmas, extra_args, callback, disable, order)
return _basic_lms(model, x, sigmas, extra_args, callback, disable, order)
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
# Get settings
base_scale = ADEPT_STATE.get('scale', 1.0)
shift = ADEPT_STATE.get('shift', 0.0)
start_pct = ADEPT_STATE.get('start_pct', 0.0)
end_pct = ADEPT_STATE.get('end_pct', 1.0)
# Get UNet
try:
unet_model = shared.sd_model.model.diffusion_model
except AttributeError:
unet_model = None
total_steps = len(sigmas) - 1
print(f"✅ Adept LMS active: scale={base_scale:.2f}, order={order}")
ds = []
for i in trange(len(sigmas) - 1, disable=disable, desc="Adept LMS"):
sigma = sigmas[i]
# Dynamic scale
current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
# Evaluate model with weight patching
if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
with AdeptWeightPatcher(unet_model, current_scale, shift):
denoised = model(x, sigma * s_in, **extra_args)
else:
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
# Linear multistep coefficients
cur_order = min(i + 1, order)
coeffs = [1.0]
for j in range(1, cur_order):
prod = 1.0
for k in range(cur_order):
if k != j:
prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k])
coeffs.append(prod)
# Apply multistep
d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:])))
dt = sigmas[i + 1] - sigma
x = x + d_multistep * dt
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
return x
def _basic_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
"""Fallback basic LMS."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
cur_order = min(i + 1, order)
coeffs = [1.0]
for j in range(1, cur_order):
prod = 1.0
for k in range(cur_order):
if k != j:
prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k])
coeffs.append(prod)
d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:])))
dt = sigmas[i + 1] - sigmas[i]
x = x + d_multistep * dt
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
return x
# ============================================================================
# MONKEY PATCHING
# ============================================================================
def patch_k_diffusion():
"""Apply monkey patches to ALL k-diffusion samplers."""
global ORIGINAL_SAMPLERS
samplers_to_patch = {
'sample_euler': sample_adept_euler,
'sample_euler_ancestral': sample_adept_euler_ancestral,
'sample_heun': sample_adept_heun,
'sample_dpmpp_2m': sample_adept_dpmpp_2m,
'sample_dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral,
'sample_lms': sample_adept_lms,
}
patched_count = 0
for original_name, adept_func in samplers_to_patch.items():
if hasattr(k_diffusion.sampling, original_name):
# Save original
if original_name not in ORIGINAL_SAMPLERS:
original_func = getattr(k_diffusion.sampling, original_name)
ORIGINAL_SAMPLERS[original_name.replace('sample_', '')] = original_func
# Apply patch
setattr(k_diffusion.sampling, original_name, adept_func)
patched_count += 1
print(f"✅ Adept Sampler v3 FULL: Patched {patched_count} samplers")
print(f" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS")
print(f" Schedulers: 16 types available")
def unpatch_k_diffusion():
"""Restore original k-diffusion samplers."""
global ORIGINAL_SAMPLERS
samplers_to_restore = {
'euler': 'sample_euler',
'euler_ancestral': 'sample_euler_ancestral',
'heun': 'sample_heun',
'dpmpp_2m': 'sample_dpmpp_2m',
'dpmpp_2s_ancestral': 'sample_dpmpp_2s_ancestral',
'lms': 'sample_lms',
}
restored_count = 0
for key, attr_name in samplers_to_restore.items():
if key in ORIGINAL_SAMPLERS:
setattr(k_diffusion.sampling, attr_name, ORIGINAL_SAMPLERS[key])
restored_count += 1
print(f"🔄 Adept Sampler: Restored {restored_count} original samplers")
# ============================================================================
# A1111 EXTENSION SCRIPT
# ============================================================================
class AdeptSamplerScript(scripts.Script):
"""Adept Sampler FULL extension for A1111."""
def title(self):
return "Adept Sampler v3 FULL"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
"""Create UI elements."""
with gr.Accordion("Adept Sampler v3 FULL", open=False):
enabled = gr.Checkbox(
label="Enable Adept Sampler",
value=False,
elem_id="adept_enabled"
)
gr.HTML("<p style='color: #888;'>Works with: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS</p>")
with gr.Row():
scale = gr.Slider(
minimum=0.5,
maximum=2.0,
step=0.05,
value=1.0,
label="Weight Scale",
elem_id="adept_scale"
)
shift = gr.Slider(
minimum=-0.5,
maximum=0.5,
step=0.01,
value=0.0,
label="Weight Shift",
elem_id="adept_shift"
)
with gr.Row():
start_pct = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.0,
label="Start Percent",
elem_id="adept_start"
)
end_pct = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.05,
value=1.0,
label="End Percent",
elem_id="adept_end"
)
with gr.Row():
eta = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.01,
value=1.0,
label="Eta (Ancestral samplers)",
elem_id="adept_eta"
)
s_noise = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.01,
value=1.0,
label="S-Noise",
elem_id="adept_s_noise"
)
adaptive_eta = gr.Checkbox(
label="Adaptive Eta (dynamic eta during sampling)",
value=False,
elem_id="adept_adaptive_eta"
)
scheduler = gr.Dropdown(
choices=[
"Standard",
"AOS-V",
"AOS-Epsilon",
"AkashicAOS",
"Entropic",
"SNR-Optimized",
"Constant-Rate",
"Adaptive-Optimized",
"Cosine-Annealed",
"LogSNR-Uniform",
"Tanh Mid-Boost",
"Exponential Tail",
"Jittered-Karras",
"Stochastic",
"JYS (Dynamic)",
"Hybrid JYS-Karras",
"AYS-SDXL",
],
value="Standard",
label="Scheduler Type",
elem_id="adept_scheduler"
)
vae_reflection = gr.Checkbox(
label="Enable VAE Reflection (fixes edge artifacts for EQ-VAE)",
value=False,
elem_id="adept_vae_reflection"
)
return [enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection]
def process(self, p, enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection):
"""Process parameters and update global state."""
global ADEPT_STATE
# Apply scheduler to sigmas
if enabled and scheduler != "Standard":
# Get original sigmas
original_sigmas = p.sampler.model_wrap.sigmas
# Apply custom scheduler
new_sigmas = apply_custom_scheduler(original_sigmas, scheduler)
# Update sigmas
p.sampler.model_wrap.sigmas = new_sigmas
print(f"📊 Applied scheduler: {scheduler}")
# Update global state
ADEPT_STATE.update({
"enabled": enabled,
"scale": scale,
"shift": shift,
"start_pct": start_pct,
"end_pct": end_pct,
"eta": eta,
"s_noise": s_noise,
"adaptive_eta": adaptive_eta,
"scheduler": scheduler,
"vae_reflection": vae_reflection,
})
# Add to generation info
if enabled:
p.extra_generation_params.update({
"Adept Sampler": "v3 FULL",
"Adept Scale": scale,
"Adept Shift": shift,
"Adept Range": f"{start_pct:.0%}-{end_pct:.0%}",
"Adept Eta": eta,
"Adept S-Noise": s_noise,
"Adept Adaptive Eta": adaptive_eta,
"Adept Scheduler": scheduler,
"Adept VAE Reflection": vae_reflection,
})
def process_batch(self, p, *args, **kwargs):
"""Wrap entire batch in VAE Reflection if enabled."""
if ADEPT_STATE.get('vae_reflection', False):
try:
vae_model = shared.sd_model.first_stage_model
with VAEReflectionPatcher(vae_model):
# VAE reflection active during this batch
pass
except Exception as e:
print(f"⚠️ VAE Reflection error: {e}")
# ============================================================================
# INITIALIZATION
# ============================================================================
# Apply patches on load
patch_k_diffusion()
# Register cleanup
def on_script_unloaded():
unpatch_k_diffusion()
try:
script_callbacks.on_script_unloaded(on_script_unloaded)
except AttributeError:
print("⚠️ Script unload callback not available")
print("🚀 Adept Sampler v3 FULL loaded!")
print(" - 6 Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS")
print(" - 16 Schedulers: AOS-V, AOS-Epsilon, AkashicAOS, Entropic, SNR-Optimized,")
print(" Constant-Rate, Adaptive-Optimized, Cosine-Annealed, LogSNR-Uniform,")
print(" Tanh Mid-Boost, Exponential Tail, Jittered-Karras, Stochastic,")
print(" JYS (Dynamic), Hybrid JYS-Karras, AYS-SDXL")
print(" - VAE Reflection support")
print(" - Dynamic Weight Scaling")