dikdimon commited on
Commit
ce3ecbe
·
verified ·
1 Parent(s): 9786d3d

Upload nrs_kohaku_enhanced_v2 (1).py

Browse files
negative_rejection_steering/scripts/nrs_kohaku_enhanced_v2 (1).py ADDED
@@ -0,0 +1,1220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import gradio as gr
4
+ from modules import scripts, script_callbacks, sd_samplers_cfg_denoiser, shared
5
+
6
+ # ==============================================================================
7
+ # NRS + KOHAKU ENHANCED — Version 2.0
8
+ #
9
+ # Improvements over v1:
10
+ # 1. Midpoint Refinement (replaces flawed Antipodal — correct Kohaku principle)
11
+ # 2. Curve Scheduling (12 curves: Constant/Linear/Cosine/Power/Repeating/Sawtooth)
12
+ # 3. CADS Trapezoidal Schedule (tau1/tau2 annealing)
13
+ # 4. Adaptive Phases (Euler → DPM++ → Detail, from adaptive_progressive.py)
14
+ # 5. Per-Channel NRS (independent processing per latent channel)
15
+ # 6. AD Normalization (Absolute Deviation, more robust than L2)
16
+ # 7. Variance Preserving Rescale (phi blend)
17
+ # 8. Interpolate Phi (NRS ↔ plain CFG blend)
18
+ # 9. CFG Drift Correction (mean/median centering, from adept_sampler_v4)
19
+ # 10. Momentum smoothing (from res_solver / clybius)
20
+ # 11. GE-Gamma Extrapolation (from gradient_estimation.py)
21
+ # 12. Native Detail Boost (Gaussian HF enhancement, from adept_sampler_v4)
22
+ # 13. Spectral Modulation (FFT frequency correction, from adept_sampler_v4)
23
+ # 14. Uncond Noise & Scale (from forge_condBlast)
24
+ # 15. Output Clamp (adaptive sigma-based, from adept_sampler_v4)
25
+ # ==============================================================================
26
+
27
+ CURVE_CHOICES = [
28
+ "Constant",
29
+ "Linear Down", "Linear Up",
30
+ "Cosine Down", "Cosine Up",
31
+ "Half Cosine Down", "Half Cosine Up",
32
+ "Power Down", "Power Up",
33
+ "Linear Repeating", "Cosine Repeating",
34
+ "Sawtooth",
35
+ ]
36
+
37
+ SCHED_MODES = ["Off", "Individual Curves", "CADS Anneal", "Adaptive Phases"]
38
+ INTER_STEP_MODES = ["Off", "Momentum", "GE-Gamma"]
39
+ DRIFT_METHODS = ["mean", "median"]
40
+
41
+
42
+ # ==============================================================================
43
+ # ЧАСТЬ 1: МАТЕМАТИЧЕСКОЕ ЯДРО
44
+ # ==============================================================================
45
+
46
+ def _nrs_core(x_orig, cond, uncond, sigma, skew, stretch, squash, use_ad_norm=False):
47
+ """
48
+ Base NRS math kernel.
49
+ use_ad_norm: use Absolute Deviation norm for squash (more robust to outliers).
50
+ Source: dynthres_core (3).py variability_measure='AD'
51
+ """
52
+ is_v_pred = False
53
+ if hasattr(shared.sd_model, 'parameterization'):
54
+ is_v_pred = shared.sd_model.parameterization == "v"
55
+
56
+ if isinstance(sigma, torch.Tensor):
57
+ sig_tens = sigma[0]
58
+ else:
59
+ sig_tens = torch.tensor(sigma, device=cond.device, dtype=cond.dtype)
60
+ if sig_tens.dtype != cond.dtype:
61
+ sig_tens = sig_tens.to(dtype=cond.dtype)
62
+
63
+ sig_tens = sig_tens.view(1, 1, 1, 1)
64
+ sig_root = (sig_tens ** 2 + 1).sqrt()
65
+
66
+ if is_v_pred:
67
+ nrs_cond, nrs_uncond = cond, uncond
68
+ x_div = None
69
+ else:
70
+ x_div = x_orig / (sig_tens ** 2 + 1)
71
+ factor = sig_tens / sig_root
72
+ nrs_cond = x_orig - (x_div - cond * factor)
73
+ nrs_uncond = x_orig - (x_div - uncond * factor)
74
+
75
+ def _dot(a, b):
76
+ return (a * b).sum(dim=1, keepdim=True)
77
+
78
+ def _nrm2(v):
79
+ return _dot(v, v)
80
+
81
+ eps_safe = 1e-6
82
+
83
+ c_dot_c = _nrm2(nrs_cond) + eps_safe
84
+ u_dot_c = _dot(nrs_uncond, nrs_cond)
85
+ u_on_c = (u_dot_c / c_dot_c) * nrs_cond
86
+
87
+ proj_diff = nrs_cond - u_on_c
88
+ stretched = nrs_cond + (stretch * proj_diff)
89
+
90
+ u_rej_c = nrs_uncond - u_on_c
91
+ skewed = stretched - (skew * u_rej_c)
92
+
93
+ if use_ad_norm:
94
+ # AD: Mean Absolute Deviation per channel, then average across channels
95
+ # Source: dynthres_core sep_feat_channels=True, variability_measure='AD'
96
+ cond_len = nrs_cond.abs().mean(dim=(2, 3), keepdim=True).mean(dim=1, keepdim=True)
97
+ nrs_len = skewed.abs().mean(dim=(2, 3), keepdim=True).mean(dim=1, keepdim=True) + eps_safe
98
+ else:
99
+ cond_len = nrs_cond.norm(dim=1, keepdim=True)
100
+ nrs_len = skewed.norm(dim=1, keepdim=True) + eps_safe
101
+
102
+ squash_scale = (1 - squash) + (squash * (cond_len / nrs_len))
103
+ x_final = skewed * squash_scale
104
+
105
+ if is_v_pred:
106
+ return x_final
107
+ else:
108
+ return (x_div - (x_orig - x_final)) * (sig_root / sig_tens)
109
+
110
+
111
+ def calc_nrs(x_orig, cond, uncond, sigma, skew, stretch, squash):
112
+ """Backward-compatible wrapper."""
113
+ return _nrs_core(x_orig, cond, uncond, sigma, skew, stretch, squash)
114
+
115
+
116
+ def calc_nrs_per_channel(x_orig, cond, uncond, sigma, skew, stretch, squash, use_ad_norm=False):
117
+ """
118
+ Per-channel NRS: process each latent channel independently.
119
+ Source: dynthres_core (3).py sep_feat_channels=True
120
+ Per-channel norms use dim=(2,3) spatial only, preventing cross-channel influence.
121
+ """
122
+ results = []
123
+ for ch in range(cond.shape[1]):
124
+ r = _nrs_core(
125
+ x_orig[:, ch:ch+1],
126
+ cond[:, ch:ch+1],
127
+ uncond[:, ch:ch+1],
128
+ sigma, skew, stretch, squash, use_ad_norm
129
+ )
130
+ results.append(r)
131
+ return torch.cat(results, dim=1)
132
+
133
+
134
+ def calc_nrs_midpoint_refined(x_orig, cond, uncond, sigma, skew, stretch, squash,
135
+ refine_blend=0.0, first_half_only=True,
136
+ current_step=0, total_steps=20,
137
+ use_per_channel=False, use_ad_norm=False):
138
+ """
139
+ Correct Kohaku midpoint refinement for NRS.
140
+
141
+ Kohaku_LoNyu_Yog (sampler, smea_sampling_46.py):
142
+ d = to_d(x, sigma, model(x))
143
+ x3 = x + (d + d2) / 2 * dt # midpoint from averaged direction
144
+ d3 = to_d(x3, sigma, model(x3))
145
+ real_d = (d + d3) / 2 # Runge-Kutta 2nd order average
146
+
147
+ NRS adaptation (no extra model calls needed):
148
+ nrs_direct = NRS(x_orig, cond, uncond)
149
+ x_mid = x_orig + (nrs_direct - x_orig) * blend * 0.5 # shifted latent
150
+ nrs_refined = NRS(x_mid, cond, uncond)
151
+ result = (nrs_direct + nrs_refined) / 2
152
+ """
153
+ _fn = calc_nrs_per_channel if use_per_channel else _nrs_core
154
+
155
+ nrs_direct = _fn(x_orig, cond, uncond, sigma, skew, stretch, squash, use_ad_norm)
156
+
157
+ if refine_blend <= 0.0:
158
+ return nrs_direct
159
+
160
+ if first_half_only and current_step > total_steps / 2:
161
+ return nrs_direct
162
+
163
+ # Midpoint in x-space (between noisy x_orig and denoised nrs_direct)
164
+ x_mid = x_orig + (nrs_direct - x_orig) * (refine_blend * 0.5)
165
+
166
+ nrs_refined = _fn(x_mid, cond, uncond, sigma, skew, stretch, squash, use_ad_norm)
167
+
168
+ # Runge-Kutta style average (Kohaku: real_d = (d + d3) / 2)
169
+ return (nrs_direct + nrs_refined) * 0.5
170
+
171
+
172
+ # ==============================================================================
173
+ # ЧАСТЬ 2: РАСПИСАНИЕ ПАРАМЕТРОВ (SCHEDULING)
174
+ # ==============================================================================
175
+
176
+ def nrs_schedule_value(base_value, step, total_steps, curve="Constant",
177
+ min_value=0.0, sched_val=2.0):
178
+ """
179
+ Apply curve to parameter over sampling steps.
180
+ Source: dynthres_core (3).py interpret_scale() + khrfix (26).py curve_progress()
181
+ """
182
+ if curve == "Constant":
183
+ return base_value
184
+
185
+ frac = step / max(total_steps - 1, 1)
186
+ frac = max(0.0, min(1.0, frac))
187
+ scale = base_value - min_value
188
+
189
+ if curve == "Linear Down":
190
+ val = 1.0 - frac
191
+ elif curve == "Linear Up":
192
+ val = frac
193
+ elif curve == "Cosine Down":
194
+ # Source: dynthres_core cos(frac * pi/2) — от 1.0 до ~0
195
+ val = math.cos(frac * 1.5707963)
196
+ elif curve == "Cosine Up":
197
+ # Source: dynthres_core 1 - cos(frac * pi/2)
198
+ val = 1.0 - math.cos(frac * 1.5707963)
199
+ elif curve == "Half Cosine Down":
200
+ # Source: dynthres_core + khrfix → cos(frac), НЕ cos(frac*pi/2)
201
+ val = math.cos(frac)
202
+ elif curve == "Half Cosine Up":
203
+ # Source: dynthres_core + khrfix → 1 - cos(frac)
204
+ val = 1.0 - math.cos(frac)
205
+ elif curve == "Power Down":
206
+ val = 1.0 - math.pow(frac, max(sched_val, 0.1))
207
+ elif curve == "Power Up":
208
+ val = math.pow(frac, max(sched_val, 0.1))
209
+ elif curve == "Linear Repeating":
210
+ sv = max(sched_val, 0.1)
211
+ portion = (frac * sv) % 1.0
212
+ val = 1.0 - abs(2.0 * portion - 1.0)
213
+ elif curve == "Cosine Repeating":
214
+ sv = max(sched_val, 0.1)
215
+ val = math.cos(2.0 * math.pi * frac * sv) * 0.5 + 0.5
216
+ elif curve == "Sawtooth":
217
+ sv = max(sched_val, 0.1)
218
+ val = (frac * sv) % 1.0
219
+ else:
220
+ val = 1.0
221
+
222
+ return min_value + max(0.0, scale * val)
223
+
224
+
225
+ def nrs_cads_schedule(step, total_steps, tau1=0.6, tau2=0.9):
226
+ """
227
+ CADS trapezoidal NRS strength schedule.
228
+ Source: cads__6__fixed.py cads_linear_schedule(), "Hold after full" mode.
229
+
230
+ t = 1 - step/total (descends from 1 to 0 during sampling)
231
+ - t > tau2 (early steps): gamma = 0 (NRS off)
232
+ - tau1 < t < tau2 (ramp): gamma linearly rises 0→1
233
+ - t <= tau1 (late steps): gamma = 1 (NRS full strength)
234
+
235
+ Defaults tau1=0.6, tau2=0.9 → NRS activates at ~10% of steps,
236
+ reaches full strength at ~40%, stays full for remainder.
237
+ """
238
+ t = 1.0 - step / max(total_steps - 1, 1)
239
+ t = max(0.0, min(1.0, t))
240
+ tau1 = max(0.0, min(1.0, tau1))
241
+ tau2 = max(0.0, min(1.0, tau2))
242
+
243
+ if tau1 >= tau2:
244
+ return 1.0 if t <= tau1 else 0.0
245
+ if t >= tau2:
246
+ return 0.0
247
+ if t <= tau1:
248
+ return 1.0
249
+ return (tau2 - t) / (tau2 - tau1)
250
+
251
+
252
+ def calc_adaptive_nrs_params(base_skew, base_stretch, base_squash, progress,
253
+ euler_end=0.35, dpm_end=0.70):
254
+ """
255
+ Phase-based parameter adjustment.
256
+ Source: adaptive_progressive.py calc_phase_bounds() + phase weight logic.
257
+
258
+ Phase 1 (0 → euler_end): Structural — max skew, moderate stretch
259
+ Phase 2 (euler_end → dpm_end): Transition — decreasing skew, rising squash
260
+ Phase 3 (dpm_end → 1.0): Detail — minimal skew, max squash
261
+ """
262
+ euler_end = max(0.0, min(1.0, euler_end))
263
+ dpm_end = max(euler_end + 0.05, min(1.0, dpm_end))
264
+
265
+ if progress < euler_end:
266
+ skew_f, stretch_f, squash_add = 1.0, 0.8, 0.0
267
+ elif progress < dpm_end:
268
+ ph = (progress - euler_end) / max(dpm_end - euler_end, 1e-8)
269
+ w_euler = max(0.0, 1.0 - ph * 2.5)
270
+ skew_f = w_euler + (1.0 - w_euler) * 0.3
271
+ stretch_f = w_euler * 0.8 + (1.0 - w_euler) * 1.0
272
+ squash_add = (1.0 - w_euler) * 0.3
273
+ else:
274
+ ph = (progress - dpm_end) / max(1.0 - dpm_end, 1e-8)
275
+ skew_f = max(0.0, 0.3 - ph * 0.3)
276
+ stretch_f = 0.8
277
+ squash_add = 0.3 + ph * 0.4
278
+
279
+ return (
280
+ base_skew * skew_f,
281
+ base_stretch * stretch_f,
282
+ min(1.0, base_squash + (1.0 - base_squash) * squash_add),
283
+ )
284
+
285
+
286
+ # ==============================================================================
287
+ # ЧАСТЬ 3: POST-PROCESSING ФУНКЦИИ
288
+ # ==============================================================================
289
+
290
+ def apply_nrs_drift_correction(tensor, intensity=0.0, method='mean'):
291
+ """
292
+ Remove CFG mean/median drift.
293
+ Source: adept_sampler_v4_COMPLETE (2).py apply_combat_cfg_drift()
294
+ Based on ComfyUI-Latent-Modifiers.
295
+ """
296
+ if intensity <= 0.0:
297
+ return tensor
298
+ try:
299
+ if method == 'median':
300
+ center = tensor.view(tensor.shape[0], -1).median(dim=-1, keepdim=True)[0]
301
+ center = center.view(tensor.shape[0], 1, 1, 1)
302
+ else:
303
+ center = tensor.mean(dim=(1, 2, 3), keepdim=True)
304
+ return tensor - center * intensity
305
+ except Exception:
306
+ return tensor
307
+
308
+
309
+ def apply_variance_preserving_rescale(nrs_result, cond_reference, phi=0.0):
310
+ """
311
+ Scale NRS result to match std of reference cond.
312
+ Source: dynthres_core (3).py interpolation logic + variance concept.
313
+ """
314
+ if phi <= 0.0:
315
+ return nrs_result
316
+ try:
317
+ std_ref = cond_reference.std()
318
+ std_nrs = nrs_result.std()
319
+ if std_nrs < 1e-8:
320
+ return nrs_result
321
+ rescaled = nrs_result * (std_ref / std_nrs)
322
+ return phi * rescaled + (1.0 - phi) * nrs_result
323
+ except Exception:
324
+ return nrs_result
325
+
326
+
327
+ def apply_blend_phi(nrs_result, plain_cfg_result, phi=1.0):
328
+ """
329
+ Blend NRS output with standard CFG output.
330
+ Source: dynthres_core (3).py interpolate_phi.
331
+ phi=1.0 → pure NRS, phi=0.0 → pure CFG.
332
+ """
333
+ if phi >= 1.0:
334
+ return nrs_result
335
+ if phi <= 0.0:
336
+ return plain_cfg_result
337
+ return phi * nrs_result + (1.0 - phi) * plain_cfg_result
338
+
339
+
340
+ def apply_nrs_momentum(nrs_result, prev_result, prev_vel, momentum=0.0):
341
+ """
342
+ Full momentum smoothing between steps.
343
+ Source: res_solver (11).py + clybius_dpmpp_4m_sde (7).py momentum_func()
344
+ Formula: vel = m*(1-m/2)*prev_vel + (1-m*(1-m/2))*curr_diff
345
+ result = prev_result + vel
346
+ Returns: (smoothed_result, new_vel)
347
+ """
348
+ if momentum <= 0.0 or prev_result is None:
349
+ curr_diff = nrs_result - prev_result if prev_result is not None else None
350
+ return nrs_result, curr_diff
351
+ try:
352
+ curr_diff = nrs_result - prev_result
353
+ eff_m = momentum * (1.0 - momentum * 0.5)
354
+ if prev_vel is None:
355
+ # First step: velocity = current diff (no history)
356
+ vel = curr_diff
357
+ else:
358
+ # Full RES/Clybius formula: blend prev velocity with current diff
359
+ vel = eff_m * prev_vel + (1.0 - eff_m) * curr_diff
360
+ smoothed = prev_result + vel
361
+ return smoothed, vel
362
+ except Exception:
363
+ return nrs_result, None
364
+
365
+
366
+ def apply_nrs_ge_extrapolation(nrs_result, prev_result, prev_diff, ge_gamma=1.0):
367
+ """
368
+ Gradient Estimation extrapolation between steps.
369
+ Source: gradient_estimation (5).py
370
+ Formula: d_bar = ge_gamma * (d - old_d) + old_d
371
+ ge_gamma=1.0 → standard, >1.0 → extrapolation, <1.0 → smoothing.
372
+ """
373
+ if ge_gamma == 1.0 or prev_result is None or prev_diff is None:
374
+ return nrs_result
375
+ try:
376
+ d = nrs_result - prev_result
377
+ d_bar = ge_gamma * (d - prev_diff) + prev_diff
378
+ return prev_result + d_bar
379
+ except Exception:
380
+ return nrs_result
381
+
382
+
383
+ def apply_nrs_detail_boost(nrs_result, progress, boost_strength=0.0):
384
+ """
385
+ Progressive high-frequency detail enhancement.
386
+ Source: adept_sampler_v4_COMPLETE (2).py compute_native_detail_boost()
387
+ Three phases: early (gentle) → mid → late (strong).
388
+ """
389
+ if boost_strength <= 0.0:
390
+ return nrs_result
391
+ try:
392
+ import torch.nn.functional as F
393
+
394
+ if progress < 0.30:
395
+ hf_boost = 0.03 * boost_strength * (progress / 0.30)
396
+ elif progress < 0.60:
397
+ hf_boost = (0.03 + 0.07 * (progress - 0.30) / 0.30) * boost_strength
398
+ else:
399
+ hf_boost = (0.10 + 0.08 * (progress - 0.60) / 0.40) * boost_strength
400
+
401
+ if hf_boost <= 1e-6:
402
+ return nrs_result
403
+
404
+ # Gaussian kernel for low-freq extraction
405
+ sigma_g = 0.5
406
+ ks = 3
407
+ x_k = torch.linspace(-(ks - 1) / 2, (ks - 1) / 2, ks,
408
+ device=nrs_result.device, dtype=nrs_result.dtype)
409
+ gauss = torch.exp(-0.5 * (x_k / sigma_g) ** 2)
410
+ gauss = gauss / gauss.sum()
411
+ kernel = torch.mm(gauss[:, None], gauss[None, :])
412
+ # .contiguous() required: .expand() creates non-contiguous view, F.conv2d needs contiguous weight
413
+ kernel = kernel.expand(nrs_result.shape[1], 1, ks, ks).contiguous()
414
+
415
+ padded = F.pad(nrs_result, (1, 1, 1, 1), mode='reflect')
416
+ low_freq = F.conv2d(padded, kernel, groups=nrs_result.shape[1])
417
+ high_freq = nrs_result - low_freq
418
+ return nrs_result + high_freq * hf_boost
419
+ except Exception:
420
+ return nrs_result
421
+
422
+
423
+ def apply_spectral_modulation(noise_pred, multiplier=0.0, percentile=5.0):
424
+ """
425
+ Clybius spectral modulation on noise_pred = (cond_x0 - uncond_x0).
426
+ Source: adept_sampler_v4_COMPLETE (2).py apply_spectral_modulation_clybius()
427
+ Boosts low-freq, suppresses extreme high-freq outliers.
428
+ Applied BEFORE NRS computation.
429
+ """
430
+ if multiplier == 0.0 or percentile <= 0:
431
+ return noise_pred
432
+ try:
433
+ fourier = torch.fft.fft2(noise_pred, dim=(-2, -1))
434
+ log_amp = torch.log(torch.sqrt(fourier.real ** 2 + fourier.imag ** 2) + 1e-8)
435
+ flat = log_amp.abs().flatten(2) # [B, C, H*W]
436
+
437
+ q_lo = torch.quantile(flat, percentile * 0.01, dim=2)
438
+ q_hi = torch.quantile(flat, 1.0 - percentile * 0.01, dim=2)
439
+
440
+ # Expand to [B, C, H, W]
441
+ q_lo = q_lo.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
442
+ q_hi = q_hi.unsqueeze(-1).unsqueeze(-1).expand(log_amp.shape)
443
+
444
+ # mask_low: boost frequencies below lower threshold (1.0–1.5 range)
445
+ # mask_high: reduce frequencies above upper threshold (0.5–1.0 range)
446
+ mask_low = ((log_amp < q_lo).float() + 1.0).clamp_(max=1.5)
447
+ mask_high = (log_amp < q_hi).float().clamp_(min=0.5)
448
+
449
+ filtered = fourier * ((mask_low * mask_high) ** multiplier)
450
+ return torch.fft.ifft2(filtered, dim=(-2, -1)).real
451
+ except Exception:
452
+ return noise_pred
453
+
454
+
455
+ def apply_uncond_modifications(uncond, noise_strength=0.0, uncond_scale=1.0):
456
+ """
457
+ Add noise to uncond and/or scale it.
458
+ Source: forge_condBlast (6).py
459
+ noise: lerp(uncond, randn*uncond.std(), strength)
460
+ scale: lerp(zeros, uncond, scale)
461
+ """
462
+ if noise_strength <= 0.0 and uncond_scale == 1.0:
463
+ return uncond
464
+ try:
465
+ result = uncond.clone()
466
+ if noise_strength > 0.0:
467
+ noise = torch.randn_like(result) * result.std()
468
+ result = torch.lerp(result, noise, noise_strength)
469
+ if uncond_scale != 1.0:
470
+ result = torch.lerp(torch.zeros_like(result), result, uncond_scale)
471
+ return result
472
+ except Exception:
473
+ return uncond
474
+
475
+
476
+ def apply_nrs_output_clamp(nrs_result, sigma, clamp_multiplier=0.0):
477
+ """
478
+ Adaptive output clamping based on sigma.
479
+ Source: adept_sampler_v4_COMPLETE (2).py apply_dynamic_thresholding().
480
+ threshold = clamp * (1 + sigma/10)
481
+ """
482
+ if clamp_multiplier <= 0.0:
483
+ return nrs_result
484
+ try:
485
+ sigma_val = sigma[0].item() if isinstance(sigma, torch.Tensor) else float(sigma)
486
+ threshold = clamp_multiplier * (1.0 + sigma_val / 10.0)
487
+ return torch.clamp(nrs_result, -threshold, threshold)
488
+ except Exception:
489
+ return nrs_result
490
+
491
+
492
+ # ==============================================================================
493
+ # ЧАСТЬ 4: STEP CONTROL (сохранено из оригинала)
494
+ # ==============================================================================
495
+
496
+ def should_apply_at_step(current_step, total_steps, start_step, end_step,
497
+ start_frac, end_frac, step_mode):
498
+ if step_mode == "Absolute Steps":
499
+ eff_start = max(0, start_step)
500
+ eff_end = min(total_steps, end_step) if end_step > 0 else total_steps
501
+ return eff_start <= current_step < eff_end
502
+ else:
503
+ eff_start = int(total_steps * max(0.0, min(1.0, start_frac)))
504
+ eff_end = int(total_steps * max(0.0, min(1.0, end_frac)))
505
+ if eff_end == 0:
506
+ eff_end = total_steps
507
+ return eff_start <= current_step < eff_end
508
+
509
+
510
+ def get_param_value_at_step(base_value, current_step, total_steps, start_step, end_step,
511
+ start_frac, end_frac, step_mode, enabled):
512
+ if not enabled:
513
+ return base_value
514
+ if should_apply_at_step(current_step, total_steps, start_step, end_step,
515
+ start_frac, end_frac, step_mode):
516
+ return base_value
517
+ return 0.0
518
+
519
+
520
+ # ==============================================================================
521
+ # ЧАСТЬ 5: HOOKS
522
+ # ==============================================================================
523
+
524
+ def hook_cfg_denoiser_params(params):
525
+ if hasattr(params.denoiser, 'p') and getattr(params.denoiser.p, '_nrs_enabled', False):
526
+ params.denoiser.p._nrs_current_sigma = params.sigma
527
+ params.denoiser.p._nrs_current_x_in = params.x
528
+ if hasattr(params, 'sampling_step'):
529
+ params.denoiser.p._nrs_current_step = params.sampling_step
530
+ elif hasattr(params.denoiser, 'step'):
531
+ params.denoiser.p._nrs_current_step = params.denoiser.step
532
+ else:
533
+ params.denoiser.p._nrs_current_step = getattr(
534
+ params.denoiser.p, '_nrs_current_step', 0)
535
+
536
+
537
+ script_callbacks.on_cfg_denoiser(hook_cfg_denoiser_params)
538
+
539
+ if not hasattr(sd_samplers_cfg_denoiser.CFGDenoiser, 'original_combine_denoised_nrs_backup'):
540
+ sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup = \
541
+ sd_samplers_cfg_denoiser.CFGDenoiser.combine_denoised
542
+
543
+
544
+ def hijacked_combine_denoised(self, x_out, conds_list, uncond, cond_scale):
545
+ _orig = sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup
546
+
547
+ if not getattr(self, 'p', None) or not getattr(self.p, '_nrs_enabled', False):
548
+ return _orig(self, x_out, conds_list, uncond, cond_scale)
549
+
550
+ if not hasattr(self.p, '_nrs_current_sigma') or not hasattr(self.p, '_nrs_current_x_in'):
551
+ return _orig(self, x_out, conds_list, uncond, cond_scale)
552
+
553
+ try:
554
+ p = self.p
555
+ sigma = p._nrs_current_sigma
556
+ base_skew, base_stretch, base_squash = p._nrs_params
557
+
558
+ # ── Step Control ──────────────────────────────────────────────────────
559
+ current_step = getattr(p, '_nrs_current_step', 0)
560
+ total_steps = getattr(p, 'steps', 20)
561
+ step_ctrl = getattr(p, '_nrs_step_control_enabled', False)
562
+ step_mode_global = getattr(p, '_nrs_step_control_mode', 'Global')
563
+
564
+ if step_ctrl:
565
+ if step_mode_global == 'Global':
566
+ gs = getattr(p, '_nrs_global_step_settings', {})
567
+ if not should_apply_at_step(
568
+ current_step, total_steps,
569
+ gs.get('start_step', 0), gs.get('end_step', total_steps),
570
+ gs.get('start_frac', 0.0), gs.get('end_frac', 1.0),
571
+ gs.get('step_mode', 'Absolute Steps')):
572
+ return _orig(self, x_out, conds_list, uncond, cond_scale)
573
+ skew, stretch, squash = base_skew, base_stretch, base_squash
574
+ else:
575
+ ind = getattr(p, '_nrs_individual_step_settings', {})
576
+ sk = ind.get('skew', {})
577
+ st = ind.get('stretch', {})
578
+ sq = ind.get('squash', {})
579
+ skew = get_param_value_at_step(
580
+ base_skew, current_step, total_steps,
581
+ sk.get('start_step', 0), sk.get('end_step', total_steps),
582
+ sk.get('start_frac', 0.0), sk.get('end_frac', 1.0),
583
+ sk.get('step_mode', 'Absolute Steps'), sk.get('enabled', True))
584
+ stretch = get_param_value_at_step(
585
+ base_stretch, current_step, total_steps,
586
+ st.get('start_step', 0), st.get('end_step', total_steps),
587
+ st.get('start_frac', 0.0), st.get('end_frac', 1.0),
588
+ st.get('step_mode', 'Absolute Steps'), st.get('enabled', True))
589
+ squash = get_param_value_at_step(
590
+ base_squash, current_step, total_steps,
591
+ sq.get('start_step', 0), sq.get('end_step', total_steps),
592
+ sq.get('start_frac', 0.0), sq.get('end_frac', 1.0),
593
+ sq.get('step_mode', 'Absolute Steps'), sq.get('enabled', True))
594
+ else:
595
+ skew, stretch, squash = base_skew, base_stretch, base_squash
596
+
597
+ # ── Scheduling ────────────────────────────────────────────────────────
598
+ sched_mode = getattr(p, '_nrs_sched_mode', 'Off')
599
+ progress = current_step / max(total_steps - 1, 1)
600
+
601
+ if sched_mode == 'Individual Curves':
602
+ sched_val = getattr(p, '_nrs_sched_val', 2.0)
603
+ skew = nrs_schedule_value(
604
+ skew, current_step, total_steps,
605
+ getattr(p, '_nrs_skew_curve', 'Constant'),
606
+ getattr(p, '_nrs_skew_curve_min', 0.0), sched_val)
607
+ stretch = nrs_schedule_value(
608
+ stretch, current_step, total_steps,
609
+ getattr(p, '_nrs_stretch_curve', 'Constant'),
610
+ getattr(p, '_nrs_stretch_curve_min', 0.0), sched_val)
611
+ squash = nrs_schedule_value(
612
+ squash, current_step, total_steps,
613
+ getattr(p, '_nrs_squash_curve', 'Constant'),
614
+ getattr(p, '_nrs_squash_curve_min', 0.0), sched_val)
615
+
616
+ elif sched_mode == 'CADS Anneal':
617
+ tau1 = getattr(p, '_nrs_cads_tau1', 0.6)
618
+ tau2 = getattr(p, '_nrs_cads_tau2', 0.9)
619
+ cads_scale = nrs_cads_schedule(current_step, total_steps, tau1, tau2)
620
+ skew *= cads_scale
621
+ stretch *= cads_scale
622
+ # squash stays at base — CADS doesn't affect the clamp
623
+
624
+ elif sched_mode == 'Adaptive Phases':
625
+ euler_end = getattr(p, '_nrs_adaptive_euler_end', 0.35)
626
+ dpm_end = getattr(p, '_nrs_adaptive_dpm_end', 0.70)
627
+ skew, stretch, squash = calc_adaptive_nrs_params(
628
+ skew, stretch, squash, progress, euler_end, dpm_end)
629
+
630
+ # ── Feature flags ─────────────────────────────────────────────────────
631
+ per_channel = getattr(p, '_nrs_per_channel', False)
632
+ use_ad_norm = getattr(p, '_nrs_ad_norm', False)
633
+ refine_blend = getattr(p, '_nrs_refine_blend', 0.0)
634
+ refine_first_half = getattr(p, '_nrs_refine_first_half', True)
635
+ blend_phi = getattr(p, '_nrs_blend_phi', 1.0)
636
+ variance_phi = getattr(p, '_nrs_variance_phi', 0.0)
637
+ drift_intensity = getattr(p, '_nrs_drift_intensity', 0.0)
638
+ drift_method = getattr(p, '_nrs_drift_method', 'mean')
639
+ output_clamp = getattr(p, '_nrs_output_clamp', 0.0)
640
+ inter_step_mode = getattr(p, '_nrs_inter_step_mode', 'Off')
641
+ momentum = getattr(p, '_nrs_momentum', 0.0)
642
+ ge_gamma = getattr(p, '_nrs_ge_gamma', 1.0)
643
+ detail_boost = getattr(p, '_nrs_detail_boost', 0.0)
644
+ spectral_mod = getattr(p, '_nrs_spectral_mod', 0.0)
645
+ spectral_pct = getattr(p, '_nrs_spectral_pct', 5.0)
646
+ uncond_noise = getattr(p, '_nrs_uncond_noise', 0.0)
647
+ uncond_scale = getattr(p, '_nrs_uncond_scale', 1.0)
648
+
649
+ # Inter-step state
650
+ prev_results = getattr(p, '_nrs_prev_results', {})
651
+ prev_diffs = getattr(p, '_nrs_prev_diffs', {})
652
+
653
+ # ── Prepare tensors ───────────────────────────────────────────────────
654
+ denoised_uncond = x_out[-uncond.shape[0]:]
655
+ denoised = torch.clone(denoised_uncond)
656
+ x_orig_uncond = p._nrs_current_x_in[-uncond.shape[0]:]
657
+
658
+ # ── Main per-item loop ────────────────────────────────────────────────
659
+ for i, conds in enumerate(conds_list):
660
+ for idx, (cond_index, weight) in enumerate(conds):
661
+ current_cond = x_out[cond_index]
662
+ if idx != 0:
663
+ denoised[i] += (current_cond - denoised_uncond[i]) * (weight * cond_scale)
664
+ continue
665
+
666
+ x_orig_i = x_orig_uncond[i].unsqueeze(0)
667
+ c_in = current_cond.unsqueeze(0) # original, before any modifications
668
+ u_in = denoised_uncond[i].unsqueeze(0)
669
+
670
+ # 1. Uncond modifications
671
+ if uncond_noise > 0.0 or uncond_scale != 1.0:
672
+ u_in = apply_uncond_modifications(u_in, uncond_noise, uncond_scale)
673
+
674
+ # 2. Spectral modulation on noise_pred BEFORE NRS
675
+ # Applied to c_in_mod only — original c_in kept for blend_phi and variance_phi
676
+ c_in_for_nrs = c_in
677
+ if spectral_mod > 0.0:
678
+ noise_pred = c_in - u_in
679
+ noise_pred_mod = apply_spectral_modulation(noise_pred, spectral_mod, spectral_pct)
680
+ c_in_for_nrs = u_in + noise_pred_mod
681
+
682
+ # 3. Core NRS computation
683
+ nrs_result = calc_nrs_midpoint_refined(
684
+ x_orig_i, c_in_for_nrs, u_in, sigma,
685
+ skew, stretch, squash,
686
+ refine_blend=refine_blend,
687
+ first_half_only=refine_first_half,
688
+ current_step=current_step,
689
+ total_steps=total_steps,
690
+ use_per_channel=per_channel,
691
+ use_ad_norm=use_ad_norm,
692
+ )
693
+
694
+ # 4. Variance preserving rescale — use ORIGINAL c_in as reference
695
+ if variance_phi > 0.0:
696
+ nrs_result = apply_variance_preserving_rescale(nrs_result, c_in, variance_phi)
697
+
698
+ # 5. Drift correction
699
+ if drift_intensity > 0.0:
700
+ nrs_result = apply_nrs_drift_correction(nrs_result, drift_intensity, drift_method)
701
+
702
+ # 6. Blend phi (NRS ↔ plain CFG) — plain_cfg uses ORIGINAL c_in
703
+ if blend_phi < 1.0:
704
+ plain_cfg = u_in + (c_in - u_in) * cond_scale
705
+ nrs_result = apply_blend_phi(nrs_result, plain_cfg, blend_phi)
706
+
707
+ # 7. Inter-step: Momentum or GE-Gamma
708
+ prev_r = prev_results.get(i, None)
709
+ if inter_step_mode == 'Momentum' and momentum > 0.0:
710
+ # Full RES/Clybius momentum with velocity tracking
711
+ prev_vel = prev_diffs.get(i, None)
712
+ nrs_result, new_vel = apply_nrs_momentum(nrs_result, prev_r, prev_vel, momentum)
713
+ if new_vel is not None:
714
+ prev_diffs[i] = new_vel.detach().clone()
715
+ elif inter_step_mode == 'GE-Gamma' and ge_gamma != 1.0:
716
+ prev_d = prev_diffs.get(i, None)
717
+ # Save RAW diff BEFORE extrapolation (this is old_d for next step)
718
+ if prev_r is not None:
719
+ raw_diff = (nrs_result - prev_r).detach().clone()
720
+ nrs_result = apply_nrs_ge_extrapolation(nrs_result, prev_r, prev_d, ge_gamma)
721
+ if prev_r is not None:
722
+ prev_diffs[i] = raw_diff # store pre-extrapolation diff
723
+
724
+ # Update inter-step state
725
+ prev_results[i] = nrs_result.detach().clone()
726
+
727
+ # 8. Detail boost
728
+ if detail_boost > 0.0:
729
+ nrs_result = apply_nrs_detail_boost(nrs_result, progress, detail_boost)
730
+
731
+ # 9. Output clamp
732
+ if output_clamp > 0.0:
733
+ nrs_result = apply_nrs_output_clamp(nrs_result, sigma, output_clamp)
734
+
735
+ # Write result
736
+ if len(conds) == 1:
737
+ denoised[i] = nrs_result.squeeze(0)
738
+ else:
739
+ delta = nrs_result.squeeze(0) - denoised_uncond[i]
740
+ denoised[i] += delta * weight
741
+
742
+ # Save inter-step state
743
+ p._nrs_prev_results = prev_results
744
+ p._nrs_prev_diffs = prev_diffs
745
+
746
+ return denoised
747
+
748
+ except Exception as e:
749
+ print(f"!!! NRS Enhanced Error (Fallback): {e}")
750
+ import traceback
751
+ traceback.print_exc()
752
+ return sd_samplers_cfg_denoiser.CFGDenoiser.original_combine_denoised_nrs_backup(
753
+ self, x_out, conds_list, uncond, cond_scale)
754
+
755
+
756
+ # ==============================================================================
757
+ # ЧАСТЬ 6: UI
758
+ # ==============================================================================
759
+
760
+ class NRSScript(scripts.Script):
761
+ def title(self):
762
+ return "NRS + Kohaku Enhanced"
763
+
764
+ def show(self, is_img2img):
765
+ return scripts.AlwaysVisible
766
+
767
+ def ui(self, is_img2img):
768
+ with gr.Accordion("NRS + Kohaku Enhanced", open=False):
769
+ enabled = gr.Checkbox(label="Включить NRS (Enable)", value=False)
770
+
771
+ # ── Инструкция ────────────────────────────────────────────────────
772
+ with gr.Accordion("❓ Инструкция / Help", open=False):
773
+ gr.Markdown("""
774
+ ### NRS + Kohaku Enhanced v2.0
775
+
776
+ **NRS (Negative Rejection Steering)** — замена стандартному CFG с 3 параметрами:
777
+ - **Skew** — отталкивание от Negative prompt (аналог силы CFG для структуры). Старт: 3–5
778
+ - **Stretch** — притяжение к Positive prompt (усиление цветов/стиля). Старт: 2–7
779
+ - **Squash** — ограничитель (0=максимум, 1=мягко+детали). Старт: 0.0
780
+
781
+ ### 🔮 Midpoint Refinement (исправленный Kohaku)
782
+ Правильная адаптация Kohaku_LoNyu_Yog: вычисляет NRS в промежуточной точке и усредняет результаты (Runge-Kutta 2-го порядка). Даёт более точное направление к целевой области.
783
+
784
+ ### 📐 Scheduling
785
+ - **Individual Curves**: каждый параметр меняется по своей кривой (Linear/Cosine/Power/...)
786
+ - **CADS Anneal**: NRS нарастает через несколько шагов (tau1/tau2 трапеция)
787
+ - **Adaptive Phases**: автоматические фазы Euler→DPM→Detail
788
+
789
+ ### 🔬 Advanced Math
790
+ - **Per-Channel**: независимая обработка каждого латентного канала
791
+ - **AD Norm**: Absolute Deviation вместо L2 (устойчивее к выбросам)
792
+ - **Blend Phi**: смешение NRS↔CFG (1.0=чистый NRS, 0.0=чистый CFG)
793
+ - **Variance Phi**: сохранение дисперсии после NRS
794
+
795
+ ### 🔁 Inter-Step
796
+ - **Momentum**: сглаживание NRS-векторов между шагами
797
+ - **GE-Gamma**: экстраполяция направления (>1 усиливает тренд)
798
+ """)
799
+
800
+ # ── Основные параметры ────────────────────────────────────────────
801
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
802
+ " font-size:0.9em; opacity:0.8;'>Основные параметры</div>")
803
+ with gr.Row():
804
+ skew = gr.Slider(label="Skew (Композиция)", minimum=-30.0, maximum=30.0,
805
+ step=0.05, value=4.0,
806
+ info="Отклонение от Neg prompt. Рекомендуется: 3–5")
807
+ stretch = gr.Slider(label="Stretch (Цвета/Стиль)", minimum=-30.0, maximum=30.0,
808
+ step=0.05, value=2.0,
809
+ info="Притяжение к Pos prompt. Рекомендуется: 2–7")
810
+ squash = gr.Slider(label="Squash (Защита от пережарки)", minimum=0.0, maximum=1.0,
811
+ step=0.01, value=0.0,
812
+ info="0=максимальный эффект, 1=больше деталей/мягче")
813
+
814
+ # ── Midpoint Refinement ───────────────────────────────────────────
815
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
816
+ " font-size:0.9em; opacity:0.8;'>🔮 Midpoint Refinement (Kohaku)</div>")
817
+ refine_blend = gr.Slider(
818
+ label="Refinement Blend", minimum=0.0, maximum=1.0, step=0.01, value=0.0,
819
+ info="0=выкл, 0.5=рекомендуется. Runge-Kutta уточнение NRS-вектора")
820
+ refine_first_half = gr.Checkbox(
821
+ label="Only first half of steps (как в оригинале Kohaku)",
822
+ value=True,
823
+ info="Применять refinement только на первой половине шагов")
824
+
825
+ # ── Scheduling ────────────────────────────────────────────────────
826
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
827
+ " font-size:0.9em; opacity:0.8;'>📐 Parameter Scheduling</div>")
828
+ sched_mode = gr.Radio(
829
+ label="Режим расписания", choices=SCHED_MODES, value="Off")
830
+
831
+ with gr.Group(visible=False) as curves_group:
832
+ gr.HTML("<div style='font-size:0.85em; opacity:0.7; margin:0.3em 0;'>"
833
+ "Кривые применяются к базовым значениям независимо</div>")
834
+ with gr.Row():
835
+ skew_curve = gr.Dropdown(label="Skew Curve", choices=CURVE_CHOICES,
836
+ value="Constant")
837
+ skew_curve_min = gr.Slider(label="Skew Min", minimum=-30.0, maximum=30.0,
838
+ step=0.05, value=0.0)
839
+ with gr.Row():
840
+ stretch_curve = gr.Dropdown(label="Stretch Curve", choices=CURVE_CHOICES,
841
+ value="Constant")
842
+ stretch_curve_min = gr.Slider(label="Stretch Min", minimum=-30.0, maximum=30.0,
843
+ step=0.05, value=0.0)
844
+ with gr.Row():
845
+ squash_curve = gr.Dropdown(label="Squash Curve", choices=CURVE_CHOICES,
846
+ value="Constant")
847
+ squash_curve_min = gr.Slider(label="Squash Min", minimum=0.0, maximum=1.0,
848
+ step=0.01, value=0.0)
849
+ sched_val = gr.Slider(
850
+ label="Sched Value (для Power/Repeating кривых)",
851
+ minimum=0.1, maximum=8.0, step=0.1, value=2.0)
852
+
853
+ with gr.Group(visible=False) as cads_group:
854
+ gr.HTML("<div style='font-size:0.85em; opacity:0.7; margin:0.3em 0;'>"
855
+ "Трапецеидальное нарастание силы NRS. "
856
+ "tau1=0.6, tau2=0.9: NRS включается на ~10% шагов, "
857
+ "полная сила с ~40%</div>")
858
+ with gr.Row():
859
+ cads_tau1 = gr.Slider(label="Tau 1 (полная сила)", minimum=0.0, maximum=1.0,
860
+ step=0.05, value=0.6)
861
+ cads_tau2 = gr.Slider(label="Tau 2 (начало нарастания)", minimum=0.0, maximum=1.0,
862
+ step=0.05, value=0.9)
863
+
864
+ with gr.Group(visible=False) as adaptive_group:
865
+ gr.HTML("<div style='font-size:0.85em; opacity:0.7; margin:0.3em 0;'>"
866
+ "Euler Phase: макс. Skew. "
867
+ "DPM Phase: переход. "
868
+ "Detail Phase: минимум Skew, максимум Squash</div>")
869
+ with gr.Row():
870
+ adaptive_euler_end = gr.Slider(label="Euler Phase End", minimum=0.0, maximum=1.0,
871
+ step=0.05, value=0.35)
872
+ adaptive_dpm_end = gr.Slider(label="DPM Phase End", minimum=0.0, maximum=1.0,
873
+ step=0.05, value=0.70)
874
+
875
+ def update_sched_groups(mode):
876
+ return {
877
+ curves_group: gr.update(visible=(mode == "Individual Curves")),
878
+ cads_group: gr.update(visible=(mode == "CADS Anneal")),
879
+ adaptive_group: gr.update(visible=(mode == "Adaptive Phases")),
880
+ }
881
+
882
+ sched_mode.change(fn=update_sched_groups, inputs=[sched_mode],
883
+ outputs=[curves_group, cads_group, adaptive_group])
884
+
885
+ # ── Advanced Math ─────────────────────────────────────────────────
886
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
887
+ " font-size:0.9em; opacity:0.8;'>🔬 Advanced Math</div>")
888
+ with gr.Row():
889
+ per_channel = gr.Checkbox(
890
+ label="Per-Channel Processing",
891
+ value=False,
892
+ info="Обрабатывать каждый латентный канал независимо")
893
+ ad_norm = gr.Checkbox(
894
+ label="AD Normalization",
895
+ value=False,
896
+ info="Absolute Deviation вместо L2 (устойчивее к выбросам)")
897
+ with gr.Row():
898
+ blend_phi = gr.Slider(
899
+ label="Blend Phi (NRS↔CFG)", minimum=0.0, maximum=1.0, step=0.01, value=1.0,
900
+ info="1.0=чистый NRS, 0.0=чистый CFG, между — смесь")
901
+ variance_phi = gr.Slider(
902
+ label="Variance Rescale Phi", minimum=0.0, maximum=1.0, step=0.01, value=0.0,
903
+ info="0=выкл. Нормирует дисперсию NRS-результата к дисперсии cond")
904
+
905
+ # ── Post-Processing ───────────────────────────────────────────────
906
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
907
+ " font-size:0.9em; opacity:0.8;'>📡 Post-Processing</div>")
908
+ with gr.Row():
909
+ drift_intensity = gr.Slider(
910
+ label="Drift Correction", minimum=0.0, maximum=1.0, step=0.01, value=0.0,
911
+ info="Убирает смещение mean/median от высокого CFG")
912
+ drift_method = gr.Radio(
913
+ label="Метод", choices=DRIFT_METHODS, value="mean")
914
+ output_clamp = gr.Slider(
915
+ label="Output Clamp (0=выкл)", minimum=0.0, maximum=200.0, step=0.5, value=0.0,
916
+ info="Адаптивное ограничение экстремальных значений. threshold = clamp*(1+sigma/10)")
917
+
918
+ # ── Inter-Step ────────────────────────────────────────────────────
919
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
920
+ " font-size:0.9em; opacity:0.8;'>🔁 Inter-Step</div>")
921
+ inter_step_mode = gr.Radio(
922
+ label="Режим", choices=INTER_STEP_MODES, value="Off")
923
+ with gr.Row():
924
+ momentum_slider = gr.Slider(
925
+ label="Momentum", minimum=0.0, maximum=0.95, step=0.01, value=0.5,
926
+ visible=False,
927
+ info="Сглаживание NRS между шагами (0=выкл, 0.5=рекомендуется)")
928
+ ge_gamma_slider = gr.Slider(
929
+ label="GE Gamma", minimum=0.1, maximum=4.0, step=0.05, value=1.5,
930
+ visible=False,
931
+ info=">1=экстраполяция тренда, 1=стандарт, <1=сглаживание")
932
+
933
+ def update_inter_step(mode):
934
+ return {
935
+ momentum_slider: gr.update(visible=(mode == "Momentum")),
936
+ ge_gamma_slider: gr.update(visible=(mode == "GE-Gamma")),
937
+ }
938
+
939
+ inter_step_mode.change(fn=update_inter_step, inputs=[inter_step_mode],
940
+ outputs=[momentum_slider, ge_gamma_slider])
941
+
942
+ # ── Enhancements ──────────────────────────────────────────────────
943
+ gr.HTML("<div style='margin:0.6em 0 0.4em; border-bottom:1px solid #555;"
944
+ " font-size:0.9em; opacity:0.8;'>✨ Enhancements</div>")
945
+ with gr.Row():
946
+ detail_boost = gr.Slider(
947
+ label="Detail Boost (0=выкл)", minimum=0.0, maximum=3.0, step=0.05, value=0.0,
948
+ info="Усиление высокочастотных деталей на поздних шагах")
949
+ spectral_mod = gr.Slider(
950
+ label="Spectral Modulation (0=выкл)", minimum=0.0, maximum=2.0, step=0.05, value=0.0,
951
+ info="FFT-коррекция частот noise_pred перед NRS")
952
+ spectral_pct = gr.Slider(
953
+ label="Spectral Percentile", minimum=1.0, maximum=20.0, step=0.5, value=5.0,
954
+ info="Процентиль для частотной маски (меньше = агрессивнее)")
955
+ with gr.Row():
956
+ uncond_noise = gr.Slider(
957
+ label="Uncond Noise (0=выкл)", minimum=0.0, maximum=0.5, step=0.01, value=0.0,
958
+ info="Добавить шум к uncond (увеличивает разнообразие)")
959
+ uncond_scale = gr.Slider(
960
+ label="Uncond Scale", minimum=0.1, maximum=2.0, step=0.01, value=1.0,
961
+ info="Масштаб uncond (1.0=стандарт, <1=ослабить neg)")
962
+
963
+ # ── Step Control ──────────────────────────────────────────────────
964
+ with gr.Accordion("⏱️ Step Control", open=False):
965
+ with gr.Row():
966
+ step_control_enabled = gr.Checkbox(label="Включить Step Control", value=False)
967
+ step_control_mode = gr.Radio(
968
+ label="Режим", choices=["Global", "Individual"], value="Global")
969
+
970
+ with gr.Group(visible=True) as global_group:
971
+ gr.HTML("<div style='font-weight:bold; margin:0.4em 0;'>Глобальные настройки</div>")
972
+ global_step_mode = gr.Radio(
973
+ label="Режим шагов",
974
+ choices=["Absolute Steps", "Fraction of Steps"],
975
+ value="Absolute Steps")
976
+ with gr.Row():
977
+ global_start_step = gr.Slider(label="Start Step", minimum=0,
978
+ maximum=150, step=1, value=0, visible=True)
979
+ global_end_step = gr.Slider(label="End Step (0=конец)", minimum=0,
980
+ maximum=150, step=1, value=0, visible=True)
981
+ with gr.Row():
982
+ global_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
983
+ maximum=1.0, step=0.01, value=0.0, visible=False)
984
+ global_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
985
+ maximum=1.0, step=0.01, value=1.0, visible=False)
986
+
987
+ with gr.Group(visible=False) as individual_group:
988
+ gr.HTML("<div style='font-weight:bold; margin:0.4em 0;'>Индивидуальные настройки</div>")
989
+ with gr.Accordion("Skew — Step Settings", open=False):
990
+ skew_step_enabled = gr.Checkbox(label="Включить для Skew", value=True)
991
+ skew_step_mode = gr.Radio(label="Режим",
992
+ choices=["Absolute Steps", "Fraction of Steps"],
993
+ value="Absolute Steps")
994
+ with gr.Row():
995
+ skew_start_step = gr.Slider(label="Start Step", minimum=0,
996
+ maximum=150, step=1, value=0, visible=True)
997
+ skew_end_step = gr.Slider(label="End Step", minimum=0,
998
+ maximum=150, step=1, value=0, visible=True)
999
+ with gr.Row():
1000
+ skew_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
1001
+ maximum=1.0, step=0.01, value=0.0, visible=False)
1002
+ skew_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
1003
+ maximum=1.0, step=0.01, value=1.0, visible=False)
1004
+ with gr.Accordion("Stretch — Step Settings", open=False):
1005
+ stretch_step_enabled = gr.Checkbox(label="Включить для Stretch", value=True)
1006
+ stretch_step_mode = gr.Radio(label="Режим",
1007
+ choices=["Absolute Steps", "Fraction of Steps"],
1008
+ value="Absolute Steps")
1009
+ with gr.Row():
1010
+ stretch_start_step = gr.Slider(label="Start Step", minimum=0,
1011
+ maximum=150, step=1, value=0, visible=True)
1012
+ stretch_end_step = gr.Slider(label="End Step", minimum=0,
1013
+ maximum=150, step=1, value=0, visible=True)
1014
+ with gr.Row():
1015
+ stretch_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
1016
+ maximum=1.0, step=0.01, value=0.0, visible=False)
1017
+ stretch_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
1018
+ maximum=1.0, step=0.01, value=1.0, visible=False)
1019
+ with gr.Accordion("Squash — Step Settings", open=False):
1020
+ squash_step_enabled = gr.Checkbox(label="Включить для Squash", value=True)
1021
+ squash_step_mode = gr.Radio(label="Режим",
1022
+ choices=["Absolute Steps", "Fraction of Steps"],
1023
+ value="Absolute Steps")
1024
+ with gr.Row():
1025
+ squash_start_step = gr.Slider(label="Start Step", minimum=0,
1026
+ maximum=150, step=1, value=0, visible=True)
1027
+ squash_end_step = gr.Slider(label="End Step", minimum=0,
1028
+ maximum=150, step=1, value=0, visible=True)
1029
+ with gr.Row():
1030
+ squash_start_frac = gr.Slider(label="Start (fraction)", minimum=0.0,
1031
+ maximum=1.0, step=0.01, value=0.0, visible=False)
1032
+ squash_end_frac = gr.Slider(label="End (fraction)", minimum=0.0,
1033
+ maximum=1.0, step=0.01, value=1.0, visible=False)
1034
+
1035
+ def update_sc_groups(mode):
1036
+ return {
1037
+ global_group: gr.update(visible=(mode == "Global")),
1038
+ individual_group: gr.update(visible=(mode == "Individual")),
1039
+ }
1040
+
1041
+ step_control_mode.change(fn=update_sc_groups, inputs=[step_control_mode],
1042
+ outputs=[global_group, individual_group])
1043
+
1044
+ def _tog(mode):
1045
+ a = mode == "Absolute Steps"
1046
+ return (gr.update(visible=a), gr.update(visible=a),
1047
+ gr.update(visible=not a), gr.update(visible=not a))
1048
+
1049
+ global_step_mode.change(fn=_tog, inputs=[global_step_mode],
1050
+ outputs=[global_start_step, global_end_step,
1051
+ global_start_frac, global_end_frac])
1052
+ skew_step_mode.change(fn=_tog, inputs=[skew_step_mode],
1053
+ outputs=[skew_start_step, skew_end_step,
1054
+ skew_start_frac, skew_end_frac])
1055
+ stretch_step_mode.change(fn=_tog, inputs=[stretch_step_mode],
1056
+ outputs=[stretch_start_step, stretch_end_step,
1057
+ stretch_start_frac, stretch_end_frac])
1058
+ squash_step_mode.change(fn=_tog, inputs=[squash_step_mode],
1059
+ outputs=[squash_start_step, squash_end_step,
1060
+ squash_start_frac, squash_end_frac])
1061
+
1062
+ return [
1063
+ # Core
1064
+ enabled, skew, stretch, squash,
1065
+ # Midpoint Refinement
1066
+ refine_blend, refine_first_half,
1067
+ # Scheduling
1068
+ sched_mode,
1069
+ skew_curve, skew_curve_min,
1070
+ stretch_curve, stretch_curve_min,
1071
+ squash_curve, squash_curve_min,
1072
+ sched_val,
1073
+ cads_tau1, cads_tau2,
1074
+ adaptive_euler_end, adaptive_dpm_end,
1075
+ # Advanced Math
1076
+ per_channel, ad_norm,
1077
+ blend_phi, variance_phi,
1078
+ # Post-Processing
1079
+ drift_intensity, drift_method,
1080
+ output_clamp,
1081
+ # Inter-Step
1082
+ inter_step_mode, momentum_slider, ge_gamma_slider,
1083
+ # Enhancements
1084
+ detail_boost,
1085
+ spectral_mod, spectral_pct,
1086
+ uncond_noise, uncond_scale,
1087
+ # Step Control
1088
+ step_control_enabled, step_control_mode,
1089
+ global_step_mode, global_start_step, global_end_step,
1090
+ global_start_frac, global_end_frac,
1091
+ skew_step_enabled, skew_step_mode, skew_start_step, skew_end_step,
1092
+ skew_start_frac, skew_end_frac,
1093
+ stretch_step_enabled, stretch_step_mode, stretch_start_step, stretch_end_step,
1094
+ stretch_start_frac, stretch_end_frac,
1095
+ squash_step_enabled, squash_step_mode, squash_start_step, squash_end_step,
1096
+ squash_start_frac, squash_end_frac,
1097
+ ]
1098
+
1099
+ def process(self, p,
1100
+ # Core
1101
+ enabled, skew, stretch, squash,
1102
+ # Midpoint Refinement
1103
+ refine_blend, refine_first_half,
1104
+ # Scheduling
1105
+ sched_mode,
1106
+ skew_curve, skew_curve_min,
1107
+ stretch_curve, stretch_curve_min,
1108
+ squash_curve, squash_curve_min,
1109
+ sched_val,
1110
+ cads_tau1, cads_tau2,
1111
+ adaptive_euler_end, adaptive_dpm_end,
1112
+ # Advanced Math
1113
+ per_channel, ad_norm,
1114
+ blend_phi, variance_phi,
1115
+ # Post-Processing
1116
+ drift_intensity, drift_method,
1117
+ output_clamp,
1118
+ # Inter-Step
1119
+ inter_step_mode, momentum, ge_gamma,
1120
+ # Enhancements
1121
+ detail_boost,
1122
+ spectral_mod, spectral_pct,
1123
+ uncond_noise, uncond_scale,
1124
+ # Step Control
1125
+ step_control_enabled, step_control_mode,
1126
+ global_step_mode, global_start_step, global_end_step,
1127
+ global_start_frac, global_end_frac,
1128
+ skew_step_enabled, skew_step_mode, skew_start_step, skew_end_step,
1129
+ skew_start_frac, skew_end_frac,
1130
+ stretch_step_enabled, stretch_step_mode, stretch_start_step, stretch_end_step,
1131
+ stretch_start_frac, stretch_end_frac,
1132
+ squash_step_enabled, squash_step_mode, squash_start_step, squash_end_step,
1133
+ squash_start_frac, squash_end_frac):
1134
+
1135
+ p._nrs_enabled = enabled
1136
+ if not enabled:
1137
+ return
1138
+
1139
+ # Core params
1140
+ p._nrs_params = (skew, stretch, squash)
1141
+
1142
+ # Midpoint Refinement
1143
+ p._nrs_refine_blend = refine_blend
1144
+ p._nrs_refine_first_half = refine_first_half
1145
+
1146
+ # Scheduling
1147
+ p._nrs_sched_mode = sched_mode
1148
+ p._nrs_skew_curve = skew_curve
1149
+ p._nrs_skew_curve_min = skew_curve_min
1150
+ p._nrs_stretch_curve = stretch_curve
1151
+ p._nrs_stretch_curve_min = stretch_curve_min
1152
+ p._nrs_squash_curve = squash_curve
1153
+ p._nrs_squash_curve_min = squash_curve_min
1154
+ p._nrs_sched_val = sched_val
1155
+ p._nrs_cads_tau1 = cads_tau1
1156
+ p._nrs_cads_tau2 = cads_tau2
1157
+ p._nrs_adaptive_euler_end = adaptive_euler_end
1158
+ p._nrs_adaptive_dpm_end = adaptive_dpm_end
1159
+
1160
+ # Advanced Math
1161
+ p._nrs_per_channel = per_channel
1162
+ p._nrs_ad_norm = ad_norm
1163
+ p._nrs_blend_phi = blend_phi
1164
+ p._nrs_variance_phi = variance_phi
1165
+
1166
+ # Post-Processing
1167
+ p._nrs_drift_intensity = drift_intensity
1168
+ p._nrs_drift_method = drift_method
1169
+ p._nrs_output_clamp = output_clamp
1170
+
1171
+ # Inter-Step
1172
+ p._nrs_inter_step_mode = inter_step_mode
1173
+ p._nrs_momentum = momentum
1174
+ p._nrs_ge_gamma = ge_gamma
1175
+ p._nrs_prev_results = {}
1176
+ p._nrs_prev_diffs = {}
1177
+
1178
+ # Enhancements
1179
+ p._nrs_detail_boost = detail_boost
1180
+ p._nrs_spectral_mod = spectral_mod
1181
+ p._nrs_spectral_pct = spectral_pct
1182
+ p._nrs_uncond_noise = uncond_noise
1183
+ p._nrs_uncond_scale = uncond_scale
1184
+
1185
+ # Step Control
1186
+ p._nrs_step_control_enabled = step_control_enabled
1187
+ p._nrs_step_control_mode = step_control_mode
1188
+ p._nrs_global_step_settings = {
1189
+ 'step_mode': global_step_mode,
1190
+ 'start_step': global_start_step,
1191
+ 'end_step': global_end_step,
1192
+ 'start_frac': global_start_frac,
1193
+ 'end_frac': global_end_frac,
1194
+ }
1195
+ p._nrs_individual_step_settings = {
1196
+ 'skew': {
1197
+ 'enabled': skew_step_enabled, 'step_mode': skew_step_mode,
1198
+ 'start_step': skew_start_step, 'end_step': skew_end_step,
1199
+ 'start_frac': skew_start_frac, 'end_frac': skew_end_frac,
1200
+ },
1201
+ 'stretch': {
1202
+ 'enabled': stretch_step_enabled, 'step_mode': stretch_step_mode,
1203
+ 'start_step': stretch_start_step, 'end_step': stretch_end_step,
1204
+ 'start_frac': stretch_start_frac, 'end_frac': stretch_end_frac,
1205
+ },
1206
+ 'squash': {
1207
+ 'enabled': squash_step_enabled, 'step_mode': squash_step_mode,
1208
+ 'start_step': squash_start_step, 'end_step': squash_end_step,
1209
+ 'start_frac': squash_start_frac, 'end_frac': squash_end_frac,
1210
+ },
1211
+ }
1212
+
1213
+ p._nrs_current_step = 0
1214
+ sd_samplers_cfg_denoiser.CFGDenoiser.combine_denoised = hijacked_combine_denoised
1215
+
1216
+ def postprocess(self, p, processed, *args):
1217
+ # Clean up inter-step state to avoid memory leaks between generations
1218
+ for attr in ('_nrs_prev_results', '_nrs_prev_diffs'):
1219
+ if hasattr(p, attr):
1220
+ delattr(p, attr)