dikdimon commited on
Commit
8022862
·
verified ·
1 Parent(s): 989e83c

Upload adept-sampler-v3 using SD-Hub

Browse files
adept-sampler-v3/scripts/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.14 kB). View file
 
adept-sampler-v3/scripts/__pycache__/adept_sampler_v3_FULL.cpython-310.pyc ADDED
Binary file (36.3 kB). View file
 
adept-sampler-v3/scripts/adept_sampler_v3_FULL.py ADDED
@@ -0,0 +1,1554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adept Sampler FULL PORT for Automatic1111 WebUI
3
+ Ported from ComfyUI/reForge extension
4
+
5
+ COMPLETE VERSION with:
6
+ - ALL Schedulers (16 types)
7
+ - ALL Samplers (Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S, LMS)
8
+ - VAE Reflection
9
+ - Dynamic Weight Scaling
10
+
11
+ Version: 3.0 FULL
12
+ """
13
+
14
+ import torch
15
+ import numpy as np
16
+ import math
17
+ from tqdm import trange
18
+ from modules import scripts, shared, script_callbacks
19
+ import gradio as gr
20
+ import k_diffusion.sampling
21
+
22
+ # ============================================================================
23
+ # GLOBAL STATE
24
+ # ============================================================================
25
+ ADEPT_STATE = {
26
+ "enabled": False,
27
+ "scale": 1.0,
28
+ "shift": 0.0,
29
+ "start_pct": 0.0,
30
+ "end_pct": 1.0,
31
+ "eta": 1.0,
32
+ "s_noise": 1.0,
33
+ "adaptive_eta": False,
34
+ "scheduler": "Standard",
35
+ "vae_reflection": False,
36
+ }
37
+
38
+ # Store original samplers
39
+ ORIGINAL_SAMPLERS = {}
40
+
41
+ # VAE Reflection state
42
+ _vae_reflection_active = False
43
+ _vae_original_padding_modes = {}
44
+
45
+ # ============================================================================
46
+ # UTILITY FUNCTIONS
47
+ # ============================================================================
48
+
49
+ def to_d(x, sigma, denoised):
50
+ """Convert denoised prediction to derivative."""
51
+ diff = x - denoised
52
+ safe_sigma = torch.clamp(sigma, min=1e-4)
53
+ derivative = diff / safe_sigma
54
+
55
+ sigma_adaptive_threshold = 1000.0 * (1.0 + sigma / 10.0)
56
+ derivative_max = torch.abs(derivative).max()
57
+ if derivative_max > sigma_adaptive_threshold:
58
+ derivative = torch.clamp(derivative, -sigma_adaptive_threshold, sigma_adaptive_threshold)
59
+
60
+ return derivative
61
+
62
+
63
+ def get_ancestral_step(sigma, sigma_next, eta=1.0):
64
+ """Calculate ancestral step sizes."""
65
+ if sigma_next == 0:
66
+ return 0.0, 0.0
67
+ sigma_up = min(sigma_next, eta * (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2) ** 0.5)
68
+ sigma_down = (sigma_next ** 2 - sigma_up ** 2) ** 0.5
69
+ return sigma_down, sigma_up
70
+
71
+
72
+ def compute_dynamic_scale(step_idx, total_steps, base_scale, start_pct, end_pct):
73
+ """Compute dynamic scale based on progress."""
74
+ progress = step_idx / max(total_steps - 1, 1)
75
+
76
+ if progress < start_pct or progress > end_pct:
77
+ return 1.0
78
+
79
+ # Smooth fade in/out
80
+ if progress < start_pct + 0.1:
81
+ fade = (progress - start_pct) / 0.1
82
+ return 1.0 + (base_scale - 1.0) * fade
83
+ elif progress > end_pct - 0.1:
84
+ fade = (end_pct - progress) / 0.1
85
+ return 1.0 + (base_scale - 1.0) * fade
86
+ else:
87
+ return base_scale
88
+
89
+
90
+ def default_noise_sampler(x):
91
+ """Simple noise sampler fallback."""
92
+ def sampler(sigma, sigma_next):
93
+ return torch.randn_like(x)
94
+ return sampler
95
+
96
+
97
+ # ============================================================================
98
+ # WEIGHT PATCHER
99
+ # ============================================================================
100
+
101
+ class AdeptWeightPatcher:
102
+ """Context manager for safe model weight modification."""
103
+
104
+ def __init__(self, model, scale, shift):
105
+ self.model = model
106
+ self.scale = scale
107
+ self.shift = shift
108
+ self.backups = {}
109
+ self.target_layers = []
110
+
111
+ # Cache target layers
112
+ for name, module in model.named_modules():
113
+ if any(block in name for block in ['input_blocks', 'middle_block', 'output_blocks']):
114
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
115
+ if hasattr(module, 'weight') and module.weight is not None:
116
+ self.target_layers.append((name, module))
117
+
118
+ def __enter__(self):
119
+ if abs(self.scale - 1.0) < 1e-6 and abs(self.shift) < 1e-6:
120
+ return self
121
+
122
+ try:
123
+ for name, module in self.target_layers:
124
+ self.backups[name] = module.weight.data.clone()
125
+ module.weight.data = module.weight.data * self.scale + self.shift
126
+ except Exception as e:
127
+ print(f"⚠️ Weight patching failed: {e}")
128
+ self.__exit__(None, None, None)
129
+ raise
130
+
131
+ return self
132
+
133
+ def __exit__(self, exc_type, exc_val, exc_tb):
134
+ try:
135
+ for name, module in self.target_layers:
136
+ if name in self.backups:
137
+ module.weight.data.copy_(self.backups[name])
138
+ self.backups.clear()
139
+ except Exception as e:
140
+ print(f"❌ CRITICAL: Failed to restore weights: {e}")
141
+ for name, backup_data in self.backups.items():
142
+ try:
143
+ for n, m in self.target_layers:
144
+ if n == name:
145
+ m.weight.data.copy_(backup_data)
146
+ except:
147
+ pass
148
+
149
+ return False
150
+
151
+
152
+ # ============================================================================
153
+ # VAE REFLECTION PATCHER
154
+ # ============================================================================
155
+
156
+ class VAEReflectionPatcher:
157
+ """Context manager for VAE reflection padding."""
158
+
159
+ def __init__(self, vae_model):
160
+ self.vae_model = vae_model
161
+ self.backups = {}
162
+
163
+ def __enter__(self):
164
+ global _vae_reflection_active, _vae_original_padding_modes
165
+
166
+ if _vae_reflection_active or self.vae_model is None:
167
+ return self
168
+
169
+ _vae_original_padding_modes.clear()
170
+ patched_count = 0
171
+
172
+ try:
173
+ for name, module in self.vae_model.named_modules():
174
+ if isinstance(module, torch.nn.Conv2d):
175
+ _vae_original_padding_modes[name] = module.padding_mode
176
+ module.padding_mode = 'reflect'
177
+ patched_count += 1
178
+
179
+ _vae_reflection_active = True
180
+ print(f"🪞 VAE Reflection: Patched {patched_count} Conv2d layers")
181
+ except Exception as e:
182
+ print(f"❌ VAE Reflection failed: {e}")
183
+ self.__exit__(None, None, None)
184
+
185
+ return self
186
+
187
+ def __exit__(self, exc_type, exc_val, exc_tb):
188
+ global _vae_reflection_active, _vae_original_padding_modes
189
+
190
+ if self.vae_model is None:
191
+ _vae_reflection_active = False
192
+ _vae_original_padding_modes.clear()
193
+ return False
194
+
195
+ restored_count = 0
196
+ try:
197
+ for name, module in self.vae_model.named_modules():
198
+ if isinstance(module, torch.nn.Conv2d) and name in _vae_original_padding_modes:
199
+ module.padding_mode = _vae_original_padding_modes[name]
200
+ restored_count += 1
201
+
202
+ _vae_reflection_active = False
203
+ _vae_original_padding_modes.clear()
204
+ print(f"🔄 VAE Reflection: Restored {restored_count} layers")
205
+ except Exception as e:
206
+ print(f"⚠️ VAE Reflection restore warning: {e}")
207
+
208
+ return False
209
+
210
+
211
+ # ============================================================================
212
+ # ALL SCHEDULERS (16 types)
213
+ # ============================================================================
214
+
215
+ def create_aos_v_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
216
+ """AOS-V (Anime-Optimized Schedule for v-prediction models)."""
217
+ rho = 7.0
218
+
219
+ p1_steps = int(num_steps * 0.2)
220
+ p2_steps = int(num_steps * 0.6)
221
+
222
+ ramp = torch.empty(num_steps, device=device, dtype=torch.float32)
223
+
224
+ if p1_steps > 0:
225
+ torch.linspace(0, 1, p1_steps, out=ramp[:p1_steps])
226
+ ramp[:p1_steps].pow_(0.5).mul_(0.6)
227
+
228
+ if p2_steps > p1_steps:
229
+ torch.linspace(0.6, 0.9, p2_steps - p1_steps, out=ramp[p1_steps:p2_steps])
230
+
231
+ if num_steps > p2_steps:
232
+ torch.linspace(0, 1, num_steps - p2_steps, out=ramp[p2_steps:])
233
+ ramp[p2_steps:].pow_(3).mul_(0.1).add_(0.9)
234
+
235
+ min_inv_rho = sigma_min ** (1 / rho)
236
+ max_inv_rho = sigma_max ** (1 / rho)
237
+ ramp.mul_(min_inv_rho - max_inv_rho).add_(max_inv_rho).pow_(rho)
238
+
239
+ return torch.cat([ramp, torch.zeros(1, device=device)])
240
+
241
+
242
+ def create_aos_e_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
243
+ """AOS-ε (Anime-Optimized Schedule for epsilon-prediction models)."""
244
+ rho = 7.0
245
+
246
+ p1_frac, p2_frac = 0.35, 0.7
247
+ ramp_p1_val, ramp_p2_val = 0.4, 0.75
248
+
249
+ p1_steps = int(num_steps * p1_frac)
250
+ p2_steps = int(num_steps * p2_frac)
251
+
252
+ phase1_ramp = torch.linspace(0, 1, p1_steps, device=device) ** 1.5 * ramp_p1_val
253
+ phase2_ramp = torch.linspace(ramp_p1_val, ramp_p2_val, p2_steps - p1_steps, device=device)
254
+ phase3_base = torch.linspace(0, 1, num_steps - p2_steps, device=device) ** 0.7
255
+ phase3_ramp = phase3_base * (1 - ramp_p2_val) + ramp_p2_val
256
+
257
+ if p1_steps == 0: phase1_ramp = torch.empty(0, device=device)
258
+ if p2_steps - p1_steps == 0: phase2_ramp = torch.empty(0, device=device)
259
+ if num_steps - p2_steps == 0: phase3_ramp = torch.empty(0, device=device)
260
+
261
+ ramp = torch.cat([phase1_ramp, phase2_ramp, phase3_ramp])
262
+
263
+ min_inv_rho = sigma_min ** (1 / rho)
264
+ max_inv_rho = sigma_max ** (1 / rho)
265
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
266
+
267
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
268
+
269
+
270
+ def create_aos_akashic_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
271
+ """AkashicAOS v2: Detail-Progressive Schedule for EQ-VAE SDXL models."""
272
+ rho = 7.0
273
+
274
+ u = torch.linspace(0, 1, num_steps, device=device)
275
+
276
+ detail_power = 0.85
277
+ u_progressive = u ** detail_power
278
+
279
+ mid_boost_strength = 0.08
280
+ mid_boost = mid_boost_strength * torch.sin(math.pi * u) * (1 - u * 0.5)
281
+
282
+ u_modulated = u_progressive + mid_boost
283
+
284
+ u_min, u_max = u_modulated.min(), u_modulated.max()
285
+ if u_max - u_min > 1e-8:
286
+ u_modulated = (u_modulated - u_min) / (u_max - u_min)
287
+
288
+ min_inv_rho = sigma_min ** (1 / rho)
289
+ max_inv_rho = sigma_max ** (1 / rho)
290
+ sigmas = (max_inv_rho + u_modulated * (min_inv_rho - max_inv_rho)) ** rho
291
+
292
+ for i in range(1, len(sigmas)):
293
+ if sigmas[i] >= sigmas[i-1]:
294
+ sigmas[i] = sigmas[i-1] * 0.995
295
+ max_ratio = 1.5
296
+ if i > 0 and sigmas[i-1] / sigmas[i] > max_ratio:
297
+ sigmas[i] = sigmas[i-1] / max_ratio
298
+
299
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
300
+
301
+
302
+ def create_entropic_sigmas(sigma_max, sigma_min, num_steps, power=6.0, device='cpu'):
303
+ """Entropic power schedule."""
304
+ rho = 7.0
305
+
306
+ linear_ramp = torch.linspace(0, 1, num_steps, device=device)
307
+ power_ramp = 1 - torch.linspace(1, 0, num_steps, device=device) ** power
308
+
309
+ ramp = (linear_ramp + power_ramp) / 2.0
310
+
311
+ min_inv_rho = sigma_min ** (1 / rho)
312
+ max_inv_rho = sigma_max ** (1 / rho)
313
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
314
+
315
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
316
+
317
+
318
+ def create_snr_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
319
+ """Schedule optimized around log SNR = 0 region."""
320
+ rho = 7.0
321
+
322
+ log_snr_max = 2 * torch.log(sigma_max)
323
+ log_snr_min = 2 * torch.log(sigma_min)
324
+
325
+ t = torch.linspace(0, 1, num_steps, device=device)
326
+
327
+ concentration_power = 3.0
328
+ sigmoid_t = torch.sigmoid(concentration_power * (t - 0.5))
329
+
330
+ linear_t = t
331
+ blend_factor = 0.7
332
+ combined_t = blend_factor * sigmoid_t + (1 - blend_factor) * linear_t
333
+
334
+ log_snr = log_snr_max + combined_t * (log_snr_min - log_snr_max)
335
+ sigmas = torch.exp(log_snr / 2)
336
+
337
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
338
+
339
+
340
+ def create_constant_rate_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
341
+ """Constant rate of distributional change."""
342
+ rho = 7.0
343
+
344
+ t = torch.linspace(0, 1, num_steps, device=device)
345
+ corrected_t = t + 0.3 * torch.sin(math.pi * t) * (1 - t)
346
+
347
+ min_inv_rho = sigma_min ** (1 / rho)
348
+ max_inv_rho = sigma_max ** (1 / rho)
349
+ sigmas = (max_inv_rho + corrected_t * (min_inv_rho - max_inv_rho)) ** rho
350
+
351
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
352
+
353
+
354
+ def create_adaptive_optimized_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
355
+ """Adaptive schedule combining multiple strategies."""
356
+ rho = 7.0
357
+
358
+ base_t = torch.linspace(0, 1, num_steps, device=device)
359
+
360
+ strategies = [
361
+ lambda t: t,
362
+ lambda t: t ** 0.8,
363
+ lambda t: t + 0.2 * torch.sin(2 * math.pi * t) * (1 - t),
364
+ lambda t: 1 / (1 + torch.exp(-3 * (t - 0.5))),
365
+ ]
366
+
367
+ weights = [0.2, 0.3, 0.2, 0.3]
368
+ combined_t = sum(w * s(base_t) for w, s in zip(weights, strategies))
369
+
370
+ if (combined_t.max() - combined_t.min()) > 1e-6:
371
+ combined_t = (combined_t - combined_t.min()) / (combined_t.max() - combined_t.min())
372
+
373
+ min_inv_rho = sigma_min ** (1 / rho)
374
+ max_inv_rho = sigma_max ** (1 / rho)
375
+ sigmas = (max_inv_rho + combined_t * (min_inv_rho - max_inv_rho)) ** rho
376
+
377
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
378
+
379
+
380
+ def create_cosine_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
381
+ """Cosine-annealed schedule."""
382
+ rho = 7.0
383
+ u = torch.linspace(0, 1, num_steps, device=device)
384
+ t = (1 - torch.cos(math.pi * u)) / 2
385
+ min_inv_rho = sigma_min ** (1 / rho)
386
+ max_inv_rho = sigma_max ** (1 / rho)
387
+ sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho
388
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
389
+
390
+
391
+ def create_logsnr_uniform_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
392
+ """Uniform in log-SNR space."""
393
+ u = torch.linspace(0, 1, num_steps, device=device)
394
+ log_snr_max = 2 * torch.log(sigma_max)
395
+ log_snr_min = 2 * torch.log(sigma_min)
396
+ log_snr = log_snr_max + u * (log_snr_min - log_snr_max)
397
+ sigmas = torch.exp(log_snr / 2)
398
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
399
+
400
+
401
+ def create_tanh_midboost_sigmas(sigma_max, sigma_min, num_steps, device='cpu', k=4.0):
402
+ """Concentrate steps near mid-range sigmas."""
403
+ rho = 7.0
404
+ u = torch.linspace(0, 1, num_steps, device=device)
405
+ k_tensor = torch.tensor(k, device=device, dtype=u.dtype)
406
+ t = 0.5 * (torch.tanh(k_tensor * (u - 0.5)) / torch.tanh(k_tensor / 2) + 1.0)
407
+ min_inv_rho = sigma_min ** (1 / rho)
408
+ max_inv_rho = sigma_max ** (1 / rho)
409
+ sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho
410
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
411
+
412
+
413
+ def create_exponential_tail_sigmas(sigma_max, sigma_min, num_steps, device='cpu', pivot=0.7, gamma=0.8, beta=5.0):
414
+ """Faster early lock-in with extra resolution in final steps."""
415
+ rho = 7.0
416
+ u = torch.linspace(0, 1, num_steps, device=device)
417
+
418
+ early_mask = u < pivot
419
+ late_mask = ~early_mask
420
+
421
+ t = torch.empty_like(u)
422
+ t[early_mask] = (u[early_mask] / pivot) ** gamma * pivot
423
+ late_u = u[late_mask]
424
+ t[late_mask] = pivot + (1 - pivot) * (1 - torch.exp(-beta * (late_u - pivot) / (1 - pivot)))
425
+
426
+ min_inv_rho = sigma_min ** (1 / rho)
427
+ max_inv_rho = sigma_max ** (1 / rho)
428
+ sigmas = (max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho
429
+
430
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
431
+
432
+
433
+ def create_jittered_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
434
+ """Karras schedule with controlled jitter."""
435
+ if num_steps <= 0:
436
+ return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])
437
+
438
+ rho = 7.0
439
+ indices = torch.arange(num_steps, device=device, dtype=torch.float32)
440
+ denom = max(1, num_steps - 1)
441
+
442
+ base = (indices + 0.5) / denom
443
+ jitter_seed = torch.sin((indices + 1) * 2.3999632)
444
+ jitter_strength = 0.35
445
+ jitter = jitter_seed * jitter_strength / denom
446
+
447
+ u = torch.clamp(base + jitter, 0.0, 1.0)
448
+
449
+ min_inv_rho = sigma_min ** (1 / rho)
450
+ max_inv_rho = sigma_max ** (1 / rho)
451
+ sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho
452
+
453
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
454
+
455
+
456
+ def create_stochastic_sigmas(sigma_max, sigma_min, num_steps, device='cpu', noise_type='brownian', noise_scale=0.3, base_schedule='karras'):
457
+ """Stochastic scheduler with controlled randomness."""
458
+ rho = 7.0
459
+
460
+ # Base schedule
461
+ if base_schedule == 'karras':
462
+ indices = torch.arange(num_steps, device=device, dtype=torch.float32)
463
+ u = (indices / max(1, num_steps - 1)) ** (1 / rho)
464
+ elif base_schedule == 'cosine':
465
+ u = torch.linspace(0, 1, num_steps, device=device)
466
+ u = (1 - torch.cos(math.pi * u)) / 2
467
+ else: # uniform
468
+ u = torch.linspace(0, 1, num_steps, device=device)
469
+
470
+ # Add noise
471
+ if noise_type == 'brownian':
472
+ noise = torch.randn(num_steps, device=device).cumsum(0)
473
+ noise = noise / noise.std()
474
+ elif noise_type == 'uniform':
475
+ noise = torch.rand(num_steps, device=device) * 2 - 1
476
+ else: # normal
477
+ noise = torch.randn(num_steps, device=device)
478
+
479
+ u_noisy = u + noise * noise_scale / num_steps
480
+ u_noisy = torch.clamp(u_noisy, 0, 1)
481
+
482
+ # Sort to maintain monotonicity
483
+ u_noisy, _ = torch.sort(u_noisy, descending=True)
484
+
485
+ min_inv_rho = sigma_min ** (1 / rho)
486
+ max_inv_rho = sigma_max ** (1 / rho)
487
+ sigmas = (max_inv_rho + u_noisy * (min_inv_rho - max_inv_rho)) ** rho
488
+
489
+ return torch.cat([sigmas, torch.zeros(1, device=device)])
490
+
491
+
492
+ def create_jys_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
493
+ """JYS (Jump Your Steps) dynamic scheduler."""
494
+ if num_steps <= 0:
495
+ return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])
496
+ if num_steps == 1:
497
+ return torch.tensor([sigma_max.item(), 0.0], device=device)
498
+ elif num_steps == 2:
499
+ mid = (sigma_max + sigma_min) / 2
500
+ return torch.tensor([sigma_max.item(), mid.item(), 0.0], device=device)
501
+
502
+ # Dynamic phase-based distribution
503
+ early_steps = max(1, int(num_steps * 0.2))
504
+ final_steps = max(1, int(num_steps * 0.2))
505
+ middle_steps = max(1, num_steps - early_steps - final_steps)
506
+
507
+ sigma_max_val = sigma_max.item() if torch.is_tensor(sigma_max) else float(sigma_max)
508
+
509
+ # Early phase (foundation)
510
+ early_jump_size = max(50, (sigma_max_val - 600) // early_steps)
511
+ early_sigmas = []
512
+ current_sigma = sigma_max_val
513
+ for _ in range(early_steps):
514
+ early_sigmas.append(current_sigma)
515
+ current_sigma = max(600, current_sigma - early_jump_size)
516
+
517
+ # Middle phase (structure + detail)
518
+ middle_sigmas = []
519
+ structure_steps = max(1, middle_steps // 2)
520
+ structure_jump = max(10, (600 - 300) // structure_steps)
521
+ current_sigma = 600
522
+ for _ in range(structure_steps):
523
+ middle_sigmas.append(current_sigma)
524
+ current_sigma = max(300, current_sigma - structure_jump)
525
+
526
+ detail_steps = middle_steps - structure_steps
527
+ if detail_steps > 0:
528
+ detail_jump = max(5, (300 - 200) // detail_steps)
529
+ current_sigma = 300
530
+ for _ in range(detail_steps):
531
+ middle_sigmas.append(current_sigma)
532
+ current_sigma = max(200, current_sigma - detail_jump)
533
+
534
+ # Final phase (refinement)
535
+ final_start = min(middle_sigmas) if middle_sigmas else 200
536
+ final_jump = max(5, final_start // final_steps)
537
+ final_sigmas = []
538
+ current_sigma = final_start
539
+ for _ in range(final_steps):
540
+ final_sigmas.append(current_sigma)
541
+ current_sigma = max(0, current_sigma - final_jump)
542
+
543
+ all_sigmas = early_sigmas + middle_sigmas + final_sigmas
544
+ unique_sigmas = list(dict.fromkeys(all_sigmas))
545
+ unique_sigmas.sort(reverse=True)
546
+
547
+ # Pad if needed
548
+ while len(unique_sigmas) < num_steps:
549
+ for i in range(len(unique_sigmas) - 1):
550
+ mid = (unique_sigmas[i] + unique_sigmas[i + 1]) / 2
551
+ if mid not in unique_sigmas:
552
+ unique_sigmas.insert(i + 1, mid)
553
+ if len(unique_sigmas) >= num_steps:
554
+ break
555
+
556
+ if len(unique_sigmas) > num_steps:
557
+ unique_sigmas = unique_sigmas[:num_steps]
558
+
559
+ if unique_sigmas[-1] != 0:
560
+ unique_sigmas.append(0)
561
+
562
+ return torch.tensor(unique_sigmas, device=device, dtype=torch.float32)
563
+
564
+
565
+ def create_hybrid_jys_karras_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
566
+ """Hybrid: JYS mid-phase with Karras locks."""
567
+ if num_steps <= 0:
568
+ return torch.cat([sigma_max.unsqueeze(0), torch.zeros(1, device=device)])
569
+
570
+ rho = 7.0
571
+
572
+ jys_sigmas = create_jys_sigmas(sigma_max, sigma_min, num_steps, device=device)[:-1]
573
+
574
+ indices = torch.arange(num_steps, device=device, dtype=torch.float32)
575
+ denom = max(1, num_steps - 1)
576
+ base = (indices + 0.5) / denom
577
+ jitter_seed = torch.sin((indices + 1) * 2.3999632)
578
+ jitter_strength = 0.35
579
+ jitter = jitter_seed * jitter_strength / denom
580
+ u = torch.clamp(base + jitter, 0.0, 1.0)
581
+
582
+ min_inv_rho = sigma_min ** (1 / rho)
583
+ max_inv_rho = sigma_max ** (1 / rho)
584
+ karras_sigmas = (max_inv_rho + u * (min_inv_rho - max_inv_rho)) ** rho
585
+
586
+ positions = torch.linspace(0, 1, num_steps, device=device)
587
+ jys_weight = torch.empty_like(positions)
588
+ early_mask = positions < 0.3
589
+ mid_mask = (positions >= 0.3) & (positions < 0.8)
590
+ late_mask = positions >= 0.8
591
+ jys_weight[early_mask] = 0.2 + 0.4 * (positions[early_mask] / 0.3)
592
+ jys_weight[mid_mask] = 0.6 + 0.3 * ((positions[mid_mask] - 0.3) / 0.5)
593
+ jys_weight[late_mask] = 0.9
594
+ jys_weight = jys_weight.clamp(0.2, 0.9)
595
+
596
+ log_jys = torch.log(jys_sigmas.clamp_min(1e-6))
597
+ log_karras = torch.log(karras_sigmas.clamp_min(1e-6))
598
+ log_hybrid = torch.lerp(log_karras, log_jys, jys_weight)
599
+
600
+ hybrid = torch.exp(log_hybrid)
601
+
602
+ smoothing = 1.0 - 0.05 * (1 - positions) ** 2
603
+ hybrid = hybrid * smoothing
604
+
605
+ for i in range(1, hybrid.shape[0]):
606
+ if hybrid[i] > hybrid[i - 1]:
607
+ hybrid[i] = hybrid[i - 1] * 0.999
608
+
609
+ return torch.cat([hybrid, torch.zeros(1, device=device)])
610
+
611
+
612
+ def create_ays_sdxl_sigmas(sigma_max, sigma_min, num_steps, device='cpu'):
613
+ """AYS (Align Your Steps) optimized for SDXL."""
614
+
615
+ AYS_SCHEDULES = {
616
+ 10: [1.0000, 0.8751, 0.7502, 0.6254, 0.5004, 0.3755, 0.2506, 0.1253, 0.0502, 0.0000],
617
+ 15: [1.0000, 0.9167, 0.8334, 0.7501, 0.6668, 0.5835, 0.5002, 0.4169, 0.3336,
618
+ 0.2503, 0.1670, 0.0837, 0.0335, 0.0084, 0.0000],
619
+ 20: [1.0000, 0.9375, 0.8750, 0.8125, 0.7500, 0.6875, 0.6250, 0.5625, 0.5000,
620
+ 0.4375, 0.3750, 0.3125, 0.2500, 0.1875, 0.1250, 0.0625, 0.0313, 0.0156,
621
+ 0.0039, 0.0000],
622
+ 25: [1.0000, 0.9500, 0.9000, 0.8500, 0.8000, 0.7500, 0.7000, 0.6500, 0.6000,
623
+ 0.5500, 0.5000, 0.4500, 0.4000, 0.3500, 0.3000, 0.2500, 0.2000, 0.1500,
624
+ 0.1000, 0.0625, 0.0391, 0.0195, 0.0098, 0.0024, 0.0000],
625
+ 30: [1.0000, 0.9583, 0.9167, 0.8750, 0.8333, 0.7917, 0.7500, 0.7083, 0.6667,
626
+ 0.6250, 0.5833, 0.5417, 0.5000, 0.4583, 0.4167, 0.3750, 0.3333, 0.2917,
627
+ 0.2500, 0.2083, 0.1667, 0.1250, 0.0833, 0.0521, 0.0326, 0.0163, 0.0081,
628
+ 0.0041, 0.0010, 0.0000],
629
+ }
630
+
631
+ if num_steps in AYS_SCHEDULES:
632
+ normalized = torch.tensor(AYS_SCHEDULES[num_steps], device=device, dtype=torch.float32)
633
+ else:
634
+ available_steps = sorted(AYS_SCHEDULES.keys())
635
+
636
+ if num_steps < available_steps[0]:
637
+ ref_steps = available_steps[0]
638
+ elif num_steps > available_steps[-1]:
639
+ ref_steps = available_steps[-1]
640
+ else:
641
+ ref_steps = min([s for s in available_steps if s >= num_steps], default=available_steps[-1])
642
+
643
+ ref_schedule = np.array(AYS_SCHEDULES[ref_steps])
644
+
645
+ t_ref = np.linspace(0, 1, len(ref_schedule))
646
+ t_new = np.linspace(0, 1, num_steps + 1)
647
+
648
+ log_ref = np.log(ref_schedule + 1e-8)
649
+ log_ref[-1] = log_ref[-2] - 3.0
650
+
651
+ log_interp = np.interp(t_new, t_ref, log_ref)
652
+ normalized_np = np.exp(log_interp)
653
+ normalized_np[-1] = 0.0
654
+
655
+ normalized = torch.tensor(normalized_np, device=device, dtype=torch.float32)
656
+
657
+ sigma_range = sigma_max - sigma_min
658
+ sigmas = normalized * sigma_range + sigma_min
659
+
660
+ sigmas[0] = sigma_max
661
+ sigmas[-1] = 0.0
662
+
663
+ for i in range(1, len(sigmas) - 1):
664
+ if sigmas[i] >= sigmas[i-1]:
665
+ sigmas[i] = sigmas[i-1] * 0.999
666
+
667
+ return sigmas
668
+
669
+
670
+ def apply_custom_scheduler(sigmas, scheduler_type="Standard"):
671
+ """Apply custom scheduler to sigma schedule."""
672
+ if scheduler_type == "Standard" or len(sigmas) < 2:
673
+ return sigmas
674
+
675
+ sigma_min = sigmas[-1] if sigmas[-1] > 0 else sigmas[-2] * 0.001
676
+ sigma_max = sigmas[0]
677
+ steps = len(sigmas) - 1
678
+ device = sigmas.device
679
+
680
+ scheduler_map = {
681
+ "AOS-V": create_aos_v_sigmas,
682
+ "AOS-Epsilon": create_aos_e_sigmas,
683
+ "AkashicAOS": create_aos_akashic_sigmas,
684
+ "Entropic": create_entropic_sigmas,
685
+ "SNR-Optimized": create_snr_optimized_sigmas,
686
+ "Constant-Rate": create_constant_rate_sigmas,
687
+ "Adaptive-Optimized": create_adaptive_optimized_sigmas,
688
+ "Cosine-Annealed": create_cosine_sigmas,
689
+ "LogSNR-Uniform": create_logsnr_uniform_sigmas,
690
+ "Tanh Mid-Boost": create_tanh_midboost_sigmas,
691
+ "Exponential Tail": create_exponential_tail_sigmas,
692
+ "Jittered-Karras": create_jittered_karras_sigmas,
693
+ "Stochastic": create_stochastic_sigmas,
694
+ "JYS (Dynamic)": create_jys_sigmas,
695
+ "Hybrid JYS-Karras": create_hybrid_jys_karras_sigmas,
696
+ "AYS-SDXL": create_ays_sdxl_sigmas,
697
+ }
698
+
699
+ if scheduler_type in scheduler_map:
700
+ try:
701
+ return scheduler_map[scheduler_type](sigma_max, sigma_min, steps, device)
702
+ except Exception as e:
703
+ print(f"⚠️ Scheduler {scheduler_type} failed: {e}, using standard")
704
+ return sigmas
705
+
706
+ return sigmas
707
+
708
+
709
+ # ============================================================================
710
+ # SAMPLER IMPLEMENTATIONS
711
+ # ============================================================================
712
+
713
+ @torch.no_grad()
714
+ def sample_adept_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
715
+ """Euler sampler with Adept weight scaling."""
716
+
717
+ if not ADEPT_STATE.get('enabled', False):
718
+ global ORIGINAL_SAMPLERS
719
+ if 'euler' in ORIGINAL_SAMPLERS:
720
+ return ORIGINAL_SAMPLERS['euler'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
721
+ return _basic_euler(model, x, sigmas, extra_args, callback, disable)
722
+
723
+ extra_args = {} if extra_args is None else extra_args
724
+ s_in = x.new_ones([x.shape[0]])
725
+
726
+ # Get settings
727
+ base_scale = ADEPT_STATE.get('scale', 1.0)
728
+ shift = ADEPT_STATE.get('shift', 0.0)
729
+ start_pct = ADEPT_STATE.get('start_pct', 0.0)
730
+ end_pct = ADEPT_STATE.get('end_pct', 1.0)
731
+
732
+ # Get UNet
733
+ try:
734
+ unet_model = shared.sd_model.model.diffusion_model
735
+ except AttributeError:
736
+ unet_model = None
737
+
738
+ total_steps = len(sigmas) - 1
739
+ print(f"✅ Adept Euler active: scale={base_scale:.2f}")
740
+
741
+ for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Euler"):
742
+ sigma = sigmas[i]
743
+
744
+ # Dynamic scale
745
+ current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
746
+
747
+ # Evaluate model with weight patching
748
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
749
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
750
+ denoised = model(x, sigma * s_in, **extra_args)
751
+ else:
752
+ denoised = model(x, sigma * s_in, **extra_args)
753
+
754
+ # Euler step
755
+ d = to_d(x, sigma, denoised)
756
+
757
+ if torch.isnan(d).any() or torch.isinf(d).any():
758
+ d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
759
+
760
+ dt = sigmas[i + 1] - sigma
761
+ x = x + d * dt
762
+
763
+ if callback is not None:
764
+ callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
765
+
766
+ return x
767
+
768
+
769
+ def _basic_euler(model, x, sigmas, extra_args=None, callback=None, disable=None):
770
+ """Fallback basic Euler."""
771
+ extra_args = {} if extra_args is None else extra_args
772
+ s_in = x.new_ones([x.shape[0]])
773
+
774
+ for i in trange(len(sigmas) - 1, disable=disable):
775
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
776
+ d = to_d(x, sigmas[i], denoised)
777
+ dt = sigmas[i + 1] - sigmas[i]
778
+ x = x + d * dt
779
+ if callback is not None:
780
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
781
+
782
+ return x
783
+
784
+
785
+ @torch.no_grad()
786
+ def sample_adept_euler_ancestral(model, x, sigmas, extra_args=None, callback=None,
787
+ disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
788
+ """Euler Ancestral with Adept weight scaling."""
789
+
790
+ if not ADEPT_STATE.get('enabled', False):
791
+ global ORIGINAL_SAMPLERS
792
+ if 'euler_ancestral' in ORIGINAL_SAMPLERS:
793
+ return ORIGINAL_SAMPLERS['euler_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
794
+ return _basic_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)
795
+
796
+ extra_args = {} if extra_args is None else extra_args
797
+ s_in = x.new_ones([x.shape[0]])
798
+
799
+ # Get settings
800
+ base_scale = ADEPT_STATE.get('scale', 1.0)
801
+ shift = ADEPT_STATE.get('shift', 0.0)
802
+ start_pct = ADEPT_STATE.get('start_pct', 0.0)
803
+ end_pct = ADEPT_STATE.get('end_pct', 1.0)
804
+ use_adaptive_eta = ADEPT_STATE.get('adaptive_eta', False)
805
+ current_eta = ADEPT_STATE.get('eta', eta)
806
+ current_s_noise = ADEPT_STATE.get('s_noise', s_noise)
807
+
808
+ # Get UNet
809
+ try:
810
+ unet_model = shared.sd_model.model.diffusion_model
811
+ except AttributeError:
812
+ unet_model = None
813
+
814
+ if noise_sampler is None:
815
+ noise_sampler = default_noise_sampler(x)
816
+
817
+ total_steps = len(sigmas) - 1
818
+ print(f"✅ Adept Euler A active: scale={base_scale:.2f}, eta={current_eta:.2f}")
819
+
820
+ for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Euler A"):
821
+ sigma = sigmas[i]
822
+ sigma_next = sigmas[i + 1]
823
+
824
+ progress = i / max(total_steps, 1)
825
+
826
+ # Adaptive eta
827
+ if use_adaptive_eta:
828
+ if progress < 0.3:
829
+ current_eta = eta * 1.08
830
+ elif progress < 0.7:
831
+ current_eta = eta * 0.95
832
+ else:
833
+ current_eta = eta * 1.02
834
+ else:
835
+ current_eta = eta
836
+
837
+ # Dynamic scale
838
+ current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
839
+
840
+ # Evaluate model with weight patching
841
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
842
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
843
+ denoised = model(x, sigma * s_in, **extra_args)
844
+ else:
845
+ denoised = model(x, sigma * s_in, **extra_args)
846
+
847
+ # Euler Ancestral step
848
+ sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
849
+ d = to_d(x, sigma, denoised)
850
+
851
+ if torch.isnan(d).any() or torch.isinf(d).any():
852
+ d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
853
+
854
+ dt = sigma_down - sigma
855
+ x = x + d * dt
856
+
857
+ if sigma_up > 0:
858
+ noise = noise_sampler(sigma, sigma_next) * current_s_noise
859
+ x = x + noise * sigma_up
860
+
861
+ if callback is not None:
862
+ callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
863
+
864
+ return x
865
+
866
+
867
+ def _basic_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
868
+ """Fallback basic Euler Ancestral."""
869
+ extra_args = {} if extra_args is None else extra_args
870
+ s_in = x.new_ones([x.shape[0]])
871
+ noise_sampler = default_noise_sampler(x)
872
+
873
+ for i in trange(len(sigmas) - 1, disable=disable):
874
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
875
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta)
876
+ d = to_d(x, sigmas[i], denoised)
877
+ dt = sigma_down - sigmas[i]
878
+ x = x + d * dt
879
+ if sigma_up > 0:
880
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
881
+ if callback is not None:
882
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
883
+
884
+ return x
885
+
886
+
887
+ @torch.no_grad()
888
+ def sample_adept_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
889
+ """Heun sampler with Adept weight scaling."""
890
+
891
+ if not ADEPT_STATE.get('enabled', False):
892
+ global ORIGINAL_SAMPLERS
893
+ if 'heun' in ORIGINAL_SAMPLERS:
894
+ return ORIGINAL_SAMPLERS['heun'](model, x, sigmas, extra_args, callback, disable, s_churn, s_tmin, s_tmax, s_noise)
895
+ return _basic_heun(model, x, sigmas, extra_args, callback, disable)
896
+
897
+ extra_args = {} if extra_args is None else extra_args
898
+ s_in = x.new_ones([x.shape[0]])
899
+
900
+ # Get settings
901
+ base_scale = ADEPT_STATE.get('scale', 1.0)
902
+ shift = ADEPT_STATE.get('shift', 0.0)
903
+ start_pct = ADEPT_STATE.get('start_pct', 0.0)
904
+ end_pct = ADEPT_STATE.get('end_pct', 1.0)
905
+
906
+ # Get UNet
907
+ try:
908
+ unet_model = shared.sd_model.model.diffusion_model
909
+ except AttributeError:
910
+ unet_model = None
911
+
912
+ total_steps = len(sigmas) - 1
913
+ print(f"✅ Adept Heun active: scale={base_scale:.2f}")
914
+
915
+ for i in trange(len(sigmas) - 1, disable=disable, desc="Adept Heun"):
916
+ sigma = sigmas[i]
917
+ sigma_next = sigmas[i + 1]
918
+
919
+ # Dynamic scale
920
+ current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
921
+
922
+ # First evaluation
923
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
924
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
925
+ denoised = model(x, sigma * s_in, **extra_args)
926
+ else:
927
+ denoised = model(x, sigma * s_in, **extra_args)
928
+
929
+ d = to_d(x, sigma, denoised)
930
+
931
+ if torch.isnan(d).any() or torch.isinf(d).any():
932
+ d = torch.nan_to_num(d, nan=0.0, posinf=1.0, neginf=-1.0)
933
+
934
+ dt = sigma_next - sigma
935
+
936
+ if sigma_next == 0:
937
+ # Last step
938
+ x = x + d * dt
939
+ else:
940
+ # Heun's method: two-stage
941
+ x_2 = x + d * dt
942
+
943
+ # Second evaluation
944
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
945
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
946
+ denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
947
+ else:
948
+ denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
949
+
950
+ d_2 = to_d(x_2, sigma_next, denoised_2)
951
+
952
+ if torch.isnan(d_2).any() or torch.isinf(d_2).any():
953
+ d_2 = torch.nan_to_num(d_2, nan=0.0, posinf=1.0, neginf=-1.0)
954
+
955
+ # Average
956
+ d_prime = (d + d_2) / 2
957
+ x = x + d_prime * dt
958
+
959
+ if callback is not None:
960
+ callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
961
+
962
+ return x
963
+
964
+
965
+ def _basic_heun(model, x, sigmas, extra_args=None, callback=None, disable=None):
966
+ """Fallback basic Heun."""
967
+ extra_args = {} if extra_args is None else extra_args
968
+ s_in = x.new_ones([x.shape[0]])
969
+
970
+ for i in trange(len(sigmas) - 1, disable=disable):
971
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
972
+ d = to_d(x, sigmas[i], denoised)
973
+ dt = sigmas[i + 1] - sigmas[i]
974
+
975
+ if sigmas[i + 1] == 0:
976
+ x = x + d * dt
977
+ else:
978
+ x_2 = x + d * dt
979
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
980
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
981
+ d_prime = (d + d_2) / 2
982
+ x = x + d_prime * dt
983
+
984
+ if callback is not None:
985
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
986
+
987
+ return x
988
+
989
+
990
+ @torch.no_grad()
991
+ def sample_adept_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
992
+ """DPM++ 2M sampler with Adept weight scaling."""
993
+
994
+ if not ADEPT_STATE.get('enabled', False):
995
+ global ORIGINAL_SAMPLERS
996
+ if 'dpmpp_2m' in ORIGINAL_SAMPLERS:
997
+ return ORIGINAL_SAMPLERS['dpmpp_2m'](model, x, sigmas, extra_args, callback, disable)
998
+ return _basic_dpmpp_2m(model, x, sigmas, extra_args, callback, disable)
999
+
1000
+ extra_args = {} if extra_args is None else extra_args
1001
+ s_in = x.new_ones([x.shape[0]])
1002
+
1003
+ # Get settings
1004
+ base_scale = ADEPT_STATE.get('scale', 1.0)
1005
+ shift = ADEPT_STATE.get('shift', 0.0)
1006
+ start_pct = ADEPT_STATE.get('start_pct', 0.0)
1007
+ end_pct = ADEPT_STATE.get('end_pct', 1.0)
1008
+
1009
+ # Get UNet
1010
+ try:
1011
+ unet_model = shared.sd_model.model.diffusion_model
1012
+ except AttributeError:
1013
+ unet_model = None
1014
+
1015
+ total_steps = len(sigmas) - 1
1016
+ print(f"✅ Adept DPM++ 2M active: scale={base_scale:.2f}")
1017
+
1018
+ old_denoised = None
1019
+
1020
+ for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2M"):
1021
+ sigma = sigmas[i]
1022
+ sigma_next = sigmas[i + 1]
1023
+
1024
+ # Dynamic scale
1025
+ current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
1026
+
1027
+ # Evaluate model with weight patching
1028
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
1029
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
1030
+ denoised = model(x, sigma * s_in, **extra_args)
1031
+ else:
1032
+ denoised = model(x, sigma * s_in, **extra_args)
1033
+
1034
+ # DPM++ 2M step
1035
+ t, t_next = sigma, sigma_next
1036
+ h = t_next - t
1037
+
1038
+ if old_denoised is None or sigma_next == 0:
1039
+ # First step (Euler)
1040
+ x = (sigma_next / sigma) * x - (-h).expm1() * denoised
1041
+ else:
1042
+ # Second order
1043
+ h_last = t - sigmas[i - 1]
1044
+ r = h_last / h
1045
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
1046
+ x = (sigma_next / sigma) * x - (-h).expm1() * denoised_d
1047
+
1048
+ old_denoised = denoised
1049
+
1050
+ if callback is not None:
1051
+ callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
1052
+
1053
+ return x
1054
+
1055
+
1056
+ def _basic_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
1057
+ """Fallback basic DPM++ 2M."""
1058
+ extra_args = {} if extra_args is None else extra_args
1059
+ s_in = x.new_ones([x.shape[0]])
1060
+ old_denoised = None
1061
+
1062
+ for i in trange(len(sigmas) - 1, disable=disable):
1063
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1064
+ t, t_next = sigmas[i], sigmas[i + 1]
1065
+ h = t_next - t
1066
+
1067
+ if old_denoised is None or sigmas[i + 1] == 0:
1068
+ x = (t_next / t) * x - (-h).expm1() * denoised
1069
+ else:
1070
+ h_last = t - sigmas[i - 1]
1071
+ r = h_last / h
1072
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
1073
+ x = (t_next / t) * x - (-h).expm1() * denoised_d
1074
+
1075
+ old_denoised = denoised
1076
+
1077
+ if callback is not None:
1078
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
1079
+
1080
+ return x
1081
+
1082
+
1083
+ @torch.no_grad()
1084
+ def sample_adept_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None):
1085
+ """DPM++ 2S Ancestral with Adept weight scaling."""
1086
+
1087
+ if not ADEPT_STATE.get('enabled', False):
1088
+ global ORIGINAL_SAMPLERS
1089
+ if 'dpmpp_2s_ancestral' in ORIGINAL_SAMPLERS:
1090
+ return ORIGINAL_SAMPLERS['dpmpp_2s_ancestral'](model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
1091
+ return _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise)
1092
+
1093
+ extra_args = {} if extra_args is None else extra_args
1094
+ s_in = x.new_ones([x.shape[0]])
1095
+
1096
+ # Get settings
1097
+ base_scale = ADEPT_STATE.get('scale', 1.0)
1098
+ shift = ADEPT_STATE.get('shift', 0.0)
1099
+ start_pct = ADEPT_STATE.get('start_pct', 0.0)
1100
+ end_pct = ADEPT_STATE.get('end_pct', 1.0)
1101
+ current_eta = ADEPT_STATE.get('eta', eta)
1102
+ current_s_noise = ADEPT_STATE.get('s_noise', s_noise)
1103
+
1104
+ # Get UNet
1105
+ try:
1106
+ unet_model = shared.sd_model.model.diffusion_model
1107
+ except AttributeError:
1108
+ unet_model = None
1109
+
1110
+ if noise_sampler is None:
1111
+ noise_sampler = default_noise_sampler(x)
1112
+
1113
+ total_steps = len(sigmas) - 1
1114
+ print(f"✅ Adept DPM++ 2S A active: scale={base_scale:.2f}")
1115
+
1116
+ for i in trange(len(sigmas) - 1, disable=disable, desc="Adept DPM++ 2S A"):
1117
+ sigma = sigmas[i]
1118
+ sigma_next = sigmas[i + 1]
1119
+
1120
+ # Dynamic scale
1121
+ current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
1122
+
1123
+ # First evaluation
1124
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
1125
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
1126
+ denoised = model(x, sigma * s_in, **extra_args)
1127
+ else:
1128
+ denoised = model(x, sigma * s_in, **extra_args)
1129
+
1130
+ # DPM++ 2S step with ancestral noise
1131
+ sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, current_eta)
1132
+
1133
+ if sigma_down == 0:
1134
+ d = to_d(x, sigma, denoised)
1135
+ x = x + d * (sigma_down - sigma)
1136
+ else:
1137
+ # Midpoint method
1138
+ t, t_next = sigma, sigma_down
1139
+ h = t_next - t
1140
+ s = t + h * 0.5
1141
+
1142
+ # Step to midpoint
1143
+ x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised
1144
+
1145
+ # Evaluate at midpoint
1146
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
1147
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
1148
+ denoised_mid = model(x_mid, s * s_in, **extra_args)
1149
+ else:
1150
+ denoised_mid = model(x_mid, s * s_in, **extra_args)
1151
+
1152
+ # Full step using midpoint
1153
+ x = (t_next / t) * x - (-h).expm1() * denoised_mid
1154
+
1155
+ # Add ancestral noise
1156
+ if sigma_up > 0:
1157
+ noise = noise_sampler(sigma, sigma_next) * current_s_noise
1158
+ x = x + noise * sigma_up
1159
+
1160
+ if callback is not None:
1161
+ callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
1162
+
1163
+ return x
1164
+
1165
+
1166
+ def _basic_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0):
1167
+ """Fallback basic DPM++ 2S Ancestral."""
1168
+ extra_args = {} if extra_args is None else extra_args
1169
+ s_in = x.new_ones([x.shape[0]])
1170
+ noise_sampler = default_noise_sampler(x)
1171
+
1172
+ for i in trange(len(sigmas) - 1, disable=disable):
1173
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1174
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta)
1175
+
1176
+ if sigma_down == 0:
1177
+ d = to_d(x, sigmas[i], denoised)
1178
+ x = x + d * (sigma_down - sigmas[i])
1179
+ else:
1180
+ t, t_next = sigmas[i], sigma_down
1181
+ h = t_next - t
1182
+ s = t + h * 0.5
1183
+ x_mid = (s / t) * x - (-(h * 0.5)).expm1() * denoised
1184
+ denoised_mid = model(x_mid, s * s_in, **extra_args)
1185
+ x = (t_next / t) * x - (-h).expm1() * denoised_mid
1186
+
1187
+ if sigma_up > 0:
1188
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
1189
+
1190
+ if callback is not None:
1191
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
1192
+
1193
+ return x
1194
+
1195
+
1196
+ @torch.no_grad()
1197
+ def sample_adept_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
1198
+ """LMS sampler with Adept weight scaling."""
1199
+
1200
+ if not ADEPT_STATE.get('enabled', False):
1201
+ global ORIGINAL_SAMPLERS
1202
+ if 'lms' in ORIGINAL_SAMPLERS:
1203
+ return ORIGINAL_SAMPLERS['lms'](model, x, sigmas, extra_args, callback, disable, order)
1204
+ return _basic_lms(model, x, sigmas, extra_args, callback, disable, order)
1205
+
1206
+ extra_args = {} if extra_args is None else extra_args
1207
+ s_in = x.new_ones([x.shape[0]])
1208
+
1209
+ # Get settings
1210
+ base_scale = ADEPT_STATE.get('scale', 1.0)
1211
+ shift = ADEPT_STATE.get('shift', 0.0)
1212
+ start_pct = ADEPT_STATE.get('start_pct', 0.0)
1213
+ end_pct = ADEPT_STATE.get('end_pct', 1.0)
1214
+
1215
+ # Get UNet
1216
+ try:
1217
+ unet_model = shared.sd_model.model.diffusion_model
1218
+ except AttributeError:
1219
+ unet_model = None
1220
+
1221
+ total_steps = len(sigmas) - 1
1222
+ print(f"✅ Adept LMS active: scale={base_scale:.2f}, order={order}")
1223
+
1224
+ ds = []
1225
+
1226
+ for i in trange(len(sigmas) - 1, disable=disable, desc="Adept LMS"):
1227
+ sigma = sigmas[i]
1228
+
1229
+ # Dynamic scale
1230
+ current_scale = compute_dynamic_scale(i, total_steps, base_scale, start_pct, end_pct)
1231
+
1232
+ # Evaluate model with weight patching
1233
+ if unet_model is not None and abs(current_scale - 1.0) > 1e-6:
1234
+ with AdeptWeightPatcher(unet_model, current_scale, shift):
1235
+ denoised = model(x, sigma * s_in, **extra_args)
1236
+ else:
1237
+ denoised = model(x, sigma * s_in, **extra_args)
1238
+
1239
+ d = to_d(x, sigma, denoised)
1240
+ ds.append(d)
1241
+
1242
+ if len(ds) > order:
1243
+ ds.pop(0)
1244
+
1245
+ # Linear multistep coefficients
1246
+ cur_order = min(i + 1, order)
1247
+ coeffs = [1.0]
1248
+
1249
+ for j in range(1, cur_order):
1250
+ prod = 1.0
1251
+ for k in range(cur_order):
1252
+ if k != j:
1253
+ prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k])
1254
+ coeffs.append(prod)
1255
+
1256
+ # Apply multistep
1257
+ d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:])))
1258
+
1259
+ dt = sigmas[i + 1] - sigma
1260
+ x = x + d_multistep * dt
1261
+
1262
+ if callback is not None:
1263
+ callback({'x': x, 'i': i, 'sigma': sigma, 'denoised': denoised})
1264
+
1265
+ return x
1266
+
1267
+
1268
+ def _basic_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
1269
+ """Fallback basic LMS."""
1270
+ extra_args = {} if extra_args is None else extra_args
1271
+ s_in = x.new_ones([x.shape[0]])
1272
+ ds = []
1273
+
1274
+ for i in trange(len(sigmas) - 1, disable=disable):
1275
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1276
+ d = to_d(x, sigmas[i], denoised)
1277
+ ds.append(d)
1278
+
1279
+ if len(ds) > order:
1280
+ ds.pop(0)
1281
+
1282
+ cur_order = min(i + 1, order)
1283
+ coeffs = [1.0]
1284
+ for j in range(1, cur_order):
1285
+ prod = 1.0
1286
+ for k in range(cur_order):
1287
+ if k != j:
1288
+ prod *= (sigmas[i] - sigmas[i - k]) / (sigmas[i - j] - sigmas[i - k])
1289
+ coeffs.append(prod)
1290
+
1291
+ d_multistep = sum(c * d_val for c, d_val in zip(coeffs, reversed(ds[-cur_order:])))
1292
+ dt = sigmas[i + 1] - sigmas[i]
1293
+ x = x + d_multistep * dt
1294
+
1295
+ if callback is not None:
1296
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'denoised': denoised})
1297
+
1298
+ return x
1299
+
1300
+
1301
+ # ============================================================================
1302
+ # MONKEY PATCHING
1303
+ # ============================================================================
1304
+
1305
+ def patch_k_diffusion():
1306
+ """Apply monkey patches to ALL k-diffusion samplers."""
1307
+ global ORIGINAL_SAMPLERS
1308
+
1309
+ samplers_to_patch = {
1310
+ 'sample_euler': sample_adept_euler,
1311
+ 'sample_euler_ancestral': sample_adept_euler_ancestral,
1312
+ 'sample_heun': sample_adept_heun,
1313
+ 'sample_dpmpp_2m': sample_adept_dpmpp_2m,
1314
+ 'sample_dpmpp_2s_ancestral': sample_adept_dpmpp_2s_ancestral,
1315
+ 'sample_lms': sample_adept_lms,
1316
+ }
1317
+
1318
+ patched_count = 0
1319
+ for original_name, adept_func in samplers_to_patch.items():
1320
+ if hasattr(k_diffusion.sampling, original_name):
1321
+ # Save original
1322
+ if original_name not in ORIGINAL_SAMPLERS:
1323
+ original_func = getattr(k_diffusion.sampling, original_name)
1324
+ ORIGINAL_SAMPLERS[original_name.replace('sample_', '')] = original_func
1325
+
1326
+ # Apply patch
1327
+ setattr(k_diffusion.sampling, original_name, adept_func)
1328
+ patched_count += 1
1329
+
1330
+ print(f"✅ Adept Sampler v3 FULL: Patched {patched_count} samplers")
1331
+ print(f" Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS")
1332
+ print(f" Schedulers: 16 types available")
1333
+
1334
+
1335
+ def unpatch_k_diffusion():
1336
+ """Restore original k-diffusion samplers."""
1337
+ global ORIGINAL_SAMPLERS
1338
+
1339
+ samplers_to_restore = {
1340
+ 'euler': 'sample_euler',
1341
+ 'euler_ancestral': 'sample_euler_ancestral',
1342
+ 'heun': 'sample_heun',
1343
+ 'dpmpp_2m': 'sample_dpmpp_2m',
1344
+ 'dpmpp_2s_ancestral': 'sample_dpmpp_2s_ancestral',
1345
+ 'lms': 'sample_lms',
1346
+ }
1347
+
1348
+ restored_count = 0
1349
+ for key, attr_name in samplers_to_restore.items():
1350
+ if key in ORIGINAL_SAMPLERS:
1351
+ setattr(k_diffusion.sampling, attr_name, ORIGINAL_SAMPLERS[key])
1352
+ restored_count += 1
1353
+
1354
+ print(f"🔄 Adept Sampler: Restored {restored_count} original samplers")
1355
+
1356
+
1357
+ # ============================================================================
1358
+ # A1111 EXTENSION SCRIPT
1359
+ # ============================================================================
1360
+
1361
+ class AdeptSamplerScript(scripts.Script):
1362
+ """Adept Sampler FULL extension for A1111."""
1363
+
1364
+ def title(self):
1365
+ return "Adept Sampler v3 FULL"
1366
+
1367
+ def show(self, is_img2img):
1368
+ return scripts.AlwaysVisible
1369
+
1370
+ def ui(self, is_img2img):
1371
+ """Create UI elements."""
1372
+ with gr.Accordion("Adept Sampler v3 FULL", open=False):
1373
+ enabled = gr.Checkbox(
1374
+ label="Enable Adept Sampler",
1375
+ value=False,
1376
+ elem_id="adept_enabled"
1377
+ )
1378
+
1379
+ gr.HTML("<p style='color: #888;'>Works with: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS</p>")
1380
+
1381
+ with gr.Row():
1382
+ scale = gr.Slider(
1383
+ minimum=0.5,
1384
+ maximum=2.0,
1385
+ step=0.05,
1386
+ value=1.0,
1387
+ label="Weight Scale",
1388
+ elem_id="adept_scale"
1389
+ )
1390
+ shift = gr.Slider(
1391
+ minimum=-0.5,
1392
+ maximum=0.5,
1393
+ step=0.01,
1394
+ value=0.0,
1395
+ label="Weight Shift",
1396
+ elem_id="adept_shift"
1397
+ )
1398
+
1399
+ with gr.Row():
1400
+ start_pct = gr.Slider(
1401
+ minimum=0.0,
1402
+ maximum=1.0,
1403
+ step=0.05,
1404
+ value=0.0,
1405
+ label="Start Percent",
1406
+ elem_id="adept_start"
1407
+ )
1408
+ end_pct = gr.Slider(
1409
+ minimum=0.0,
1410
+ maximum=1.0,
1411
+ step=0.05,
1412
+ value=1.0,
1413
+ label="End Percent",
1414
+ elem_id="adept_end"
1415
+ )
1416
+
1417
+ with gr.Row():
1418
+ eta = gr.Slider(
1419
+ minimum=0.0,
1420
+ maximum=2.0,
1421
+ step=0.01,
1422
+ value=1.0,
1423
+ label="Eta (Ancestral samplers)",
1424
+ elem_id="adept_eta"
1425
+ )
1426
+ s_noise = gr.Slider(
1427
+ minimum=0.0,
1428
+ maximum=2.0,
1429
+ step=0.01,
1430
+ value=1.0,
1431
+ label="S-Noise",
1432
+ elem_id="adept_s_noise"
1433
+ )
1434
+
1435
+ adaptive_eta = gr.Checkbox(
1436
+ label="Adaptive Eta (dynamic eta during sampling)",
1437
+ value=False,
1438
+ elem_id="adept_adaptive_eta"
1439
+ )
1440
+
1441
+ scheduler = gr.Dropdown(
1442
+ choices=[
1443
+ "Standard",
1444
+ "AOS-V",
1445
+ "AOS-Epsilon",
1446
+ "AkashicAOS",
1447
+ "Entropic",
1448
+ "SNR-Optimized",
1449
+ "Constant-Rate",
1450
+ "Adaptive-Optimized",
1451
+ "Cosine-Annealed",
1452
+ "LogSNR-Uniform",
1453
+ "Tanh Mid-Boost",
1454
+ "Exponential Tail",
1455
+ "Jittered-Karras",
1456
+ "Stochastic",
1457
+ "JYS (Dynamic)",
1458
+ "Hybrid JYS-Karras",
1459
+ "AYS-SDXL",
1460
+ ],
1461
+ value="Standard",
1462
+ label="Scheduler Type",
1463
+ elem_id="adept_scheduler"
1464
+ )
1465
+
1466
+ vae_reflection = gr.Checkbox(
1467
+ label="Enable VAE Reflection (fixes edge artifacts for EQ-VAE)",
1468
+ value=False,
1469
+ elem_id="adept_vae_reflection"
1470
+ )
1471
+
1472
+ return [enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection]
1473
+
1474
+ def process(self, p, enabled, scale, shift, start_pct, end_pct, eta, s_noise, adaptive_eta, scheduler, vae_reflection):
1475
+ """Process parameters and update global state."""
1476
+ global ADEPT_STATE
1477
+
1478
+ # Apply scheduler to sigmas
1479
+ if enabled and scheduler != "Standard":
1480
+ # Get original sigmas
1481
+ original_sigmas = p.sampler.model_wrap.sigmas
1482
+
1483
+ # Apply custom scheduler
1484
+ new_sigmas = apply_custom_scheduler(original_sigmas, scheduler)
1485
+
1486
+ # Update sigmas
1487
+ p.sampler.model_wrap.sigmas = new_sigmas
1488
+
1489
+ print(f"📊 Applied scheduler: {scheduler}")
1490
+
1491
+ # Update global state
1492
+ ADEPT_STATE.update({
1493
+ "enabled": enabled,
1494
+ "scale": scale,
1495
+ "shift": shift,
1496
+ "start_pct": start_pct,
1497
+ "end_pct": end_pct,
1498
+ "eta": eta,
1499
+ "s_noise": s_noise,
1500
+ "adaptive_eta": adaptive_eta,
1501
+ "scheduler": scheduler,
1502
+ "vae_reflection": vae_reflection,
1503
+ })
1504
+
1505
+ # Add to generation info
1506
+ if enabled:
1507
+ p.extra_generation_params.update({
1508
+ "Adept Sampler": "v3 FULL",
1509
+ "Adept Scale": scale,
1510
+ "Adept Shift": shift,
1511
+ "Adept Range": f"{start_pct:.0%}-{end_pct:.0%}",
1512
+ "Adept Eta": eta,
1513
+ "Adept S-Noise": s_noise,
1514
+ "Adept Adaptive Eta": adaptive_eta,
1515
+ "Adept Scheduler": scheduler,
1516
+ "Adept VAE Reflection": vae_reflection,
1517
+ })
1518
+
1519
+ def process_batch(self, p, *args, **kwargs):
1520
+ """Wrap entire batch in VAE Reflection if enabled."""
1521
+ if ADEPT_STATE.get('vae_reflection', False):
1522
+ try:
1523
+ vae_model = shared.sd_model.first_stage_model
1524
+ with VAEReflectionPatcher(vae_model):
1525
+ # VAE reflection active during this batch
1526
+ pass
1527
+ except Exception as e:
1528
+ print(f"⚠️ VAE Reflection error: {e}")
1529
+
1530
+
1531
+ # ============================================================================
1532
+ # INITIALIZATION
1533
+ # ============================================================================
1534
+
1535
+ # Apply patches on load
1536
+ patch_k_diffusion()
1537
+
1538
+ # Register cleanup
1539
+ def on_script_unloaded():
1540
+ unpatch_k_diffusion()
1541
+
1542
+ try:
1543
+ script_callbacks.on_script_unloaded(on_script_unloaded)
1544
+ except AttributeError:
1545
+ print("⚠️ Script unload callback not available")
1546
+
1547
+ print("🚀 Adept Sampler v3 FULL loaded!")
1548
+ print(" - 6 Samplers: Euler, Euler A, Heun, DPM++ 2M, DPM++ 2S A, LMS")
1549
+ print(" - 16 Schedulers: AOS-V, AOS-Epsilon, AkashicAOS, Entropic, SNR-Optimized,")
1550
+ print(" Constant-Rate, Adaptive-Optimized, Cosine-Annealed, LogSNR-Uniform,")
1551
+ print(" Tanh Mid-Boost, Exponential Tail, Jittered-Karras, Stochastic,")
1552
+ print(" JYS (Dynamic), Hybrid JYS-Karras, AYS-SDXL")
1553
+ print(" - VAE Reflection support")
1554
+ print(" - Dynamic Weight Scaling")