dikdimon commited on
Commit
409a349
Β·
verified Β·
1 Parent(s): 7b7182e

Upload 4 files

Browse files
cfg-prompt-forge/README.md ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # πŸ”§ CFG & Prompt Forge v4
2
+
3
+ > Unified CFG control extension for **stock AUTOMATIC1111 Stable Diffusion WebUI**.
4
+ > Combines the best ideas from 8+ CFG-related extensions into one accordion.
5
+
6
+ ---
7
+
8
+ ## Sources & Credits
9
+
10
+ | Feature | Ported from |
11
+ |---|---|
12
+ | Step Ranges (Linear / Fixed / Default) | sd-webui-dycfg (guzuligo) |
13
+ | Skimmed CFG Classic & Smooth | sd-webui-skimmed_cfg (Extraltodeus) |
14
+ | Auto CFG, Lerp Uncond, LIR, Subtract Mean | ComfyUI-AutomaticCFG (Extraltodeus) |
15
+ | Smooth Skimmed (interpolated_scales) | pre_cfg_comfy_nodes (Extraltodeus) |
16
+ | Guidance Geometry Skew/Stretch/Squash | NRS β€” Negative Rejection Steering (Scrag3H1ll) |
17
+ | Reinhard Tonemap, Rescale CFG, Heuristic CFG | CFgfade (Panchovix) |
18
+ | HiRes-pass CFG override | HiresCFG (Panchovix) |
19
+ | Warmup/EndStep/LockAfterEnd, Apply to Pos/Neg, Pass Mode | SEGA v5.1 (continue-revolution) |
20
+ | Core opts direct override | cfg-prompt-forge v1 (original) |
21
+
22
+ ---
23
+
24
+ ## Features
25
+
26
+ ### πŸ“ Core Sampling Parameters
27
+ Per-generation `override_settings` override β€” auto-restored after generation.
28
+
29
+ | Control | `opts` key | Effect |
30
+ |---|---|---|
31
+ | Skip Early CFG | `skip_early_cond` | Fraction of early steps ignoring negative prompt |
32
+ | Prompt Word-Wrap | `comma_padding_backtrack` | CLIP 75-token chunk backtrack |
33
+ | NGMS | `s_min_uncond` | Skip negative when Οƒ < threshold |
34
+ | NGMS All Steps | `s_min_uncond_all` | Skip every matching step |
35
+
36
+ ### πŸ“Š CFG Step Ranges (DyCFG)
37
+ Three step windows with start / end / CFG value / interpolation.
38
+ Interp: **Default** (base), **Linear** (interpolate), **Fixed** (hold last value).
39
+
40
+ ### πŸ“ˆ CFG Schedule
41
+ 8 curve shapes multiplied on top of step-range CFG. Live SVG preview.
42
+ Off / Ramp Up / Ramp Down / Cosine In-Out / Bell / Inv. Bell / Step Up / Step Down.
43
+
44
+ ### πŸš€ End-step Boost
45
+ Ramp CFG multiplier from Γ—1.0 up to BoostΓ— during last N% of steps. Stacks with schedule.
46
+
47
+ ### 🎲 Per-step Jitter
48
+ Random Β±% perturbation each step for organic textures.
49
+
50
+ ### πŸ€– Auto CFG (ComfyUI-AutomaticCFG)
51
+ Per-latent-channel adaptive CFG. Methods: hard / soft / hard_squared / range.
52
+ When enabled, replaces the combine loop β€” NRS is bypassed.
53
+
54
+ ### πŸ”— Uncond Lerp
55
+ Blend `text_uncond` toward `text_cond` preserving norm.
56
+ Strength 1.0 = no change. <1 weakens negative. >1 exaggerates it.
57
+
58
+ ### πŸ”¬ Skimmed CFG
59
+ Text-embedding-level uncond modification:
60
+ - **Classic** β€” binary mask on over-aligned dimensions
61
+ - **Smooth** β€” importance-weighted blend (interpolated_scales)
62
+
63
+ ### ⚑ NRS Guidance Geometry
64
+ Skew / Stretch / Squash on cond noise predictions relative to uncond:
65
+ - **Stretch** β€” amplify perpendicular component (sharper positive)
66
+ - **Skew** β€” steer cond away from uncond rejection
67
+ - **Squash** β€” renormalise back to original cond magnitude
68
+
69
+ ### 🎚 Post-CFG Shaping (inside combine_denoised hook)
70
+
71
+ | Feature | Effect |
72
+ |---|---|
73
+ | Reinhard Tonemap | Smooth magnitude clamp on noise prediction β€” prevents "blowout" |
74
+ | Rescale CFG | Match result std to raw cond std β€” prevents latent inflation |
75
+ | Heuristic CFG | Quantile contrast normalisation β€” fights cartoon-ification at high CFG |
76
+ | Latent Intensity Rescale | Per-channel topk intensity target |
77
+ | Subtract Latent Mean | Removes spatial mean per batch item β€” reduces colour drift |
78
+
79
+ ### 🧭 Feature Activation Window (SEGA v5.1)
80
+
81
+ Controls **when** embedding features (Lerp, Skim) are active and **which branch** they modify.
82
+
83
+ | Control | Effect |
84
+ |---|---|
85
+ | **Warmup Steps** | Skip embedding features for first N steps (too noisy) |
86
+ | **End Step** | Deactivate after this step (0 = no limit) |
87
+ | **Lock after End** | Once End Step is crossed, stays off permanently β€” including HR pass |
88
+ | **Apply to Negative** | Lerp/Skim modifies `text_uncond` |
89
+ | **Apply to Positive** | Lerp/Skim also modifies `text_cond` |
90
+
91
+ **Lock after End** is the key SEGA idea: once the guidance window closes at End Step, the lock is latched (`lock_triggered = True`) and persists through the HiRes pass β€” so effects stop cleanly even mid-generation.
92
+
93
+ ### πŸ–Ό Pass Mode & HiRes Fix CFG
94
+
95
+ **Pass Mode** β€” which denoising pass activates CFG features:
96
+ - `Both passes` β€” normal (default)
97
+ - `First pass only` β€” CFG shaping only during base pass; HR pass uses plain defaults
98
+ - `HR pass only` β€” CFG shaping activates only for HiRes fix
99
+
100
+ **HR CFG Override** β€” independently swaps `p.cfg_scale` for the HR pass only.
101
+ Works regardless of Pass Mode. Lower values (1–4) often clean up HiRes upscaling.
102
+
103
+ ---
104
+
105
+ ## Installation
106
+
107
+ ```bash
108
+ cd stable-diffusion-webui/extensions
109
+ git clone <repo-url> cfg-prompt-forge
110
+ ```
111
+
112
+ Or **Extensions β†’ Install from URL** in WebUI.
113
+
114
+ ---
115
+
116
+ ## Recipes
117
+
118
+ ### High-quality portrait
119
+ ```
120
+ Skip Early CFG: 0.25 | NGMS: 0.50
121
+ Schedule: Ramp Down Γ—1.6 β†’ Γ—0.9
122
+ End Boost from 80% Γ—1.4
123
+ Reinhard: 4.0
124
+ Warmup: 5 steps | End Step: 0 (full)
125
+ Apply to Negative: βœ“
126
+ ```
127
+
128
+ ### Organic / painterly
129
+ ```
130
+ Jitter: Β±15% | Schedule: Bell Γ—0.6 β†’ Γ—1.5
131
+ Skim: Smooth scale=6.0
132
+ Subtract Mean: βœ“ | Apply to Negative: βœ“ | Apply to Positive: βœ“
133
+ Warmup: 8 | End Step: 25
134
+ ```
135
+
136
+ ### Speed run (30 steps)
137
+ ```
138
+ Skip Early CFG: 0.40 | NGMS: 1.20 | NGMS All Steps: βœ“
139
+ Auto CFG: hard ref=8 (NRS auto-disabled)
140
+ ```
141
+
142
+ ### Sharp upscale only during HR
143
+ ```
144
+ Pass Mode: HR pass only
145
+ End Boost from 85% Γ—1.3
146
+ Rescale CFG: 0.7
147
+ HR CFG Override: 2.0 (also active)
148
+ ```
149
+
150
+ ### Lock effects to base pass only, then clean HR
151
+ ```
152
+ Pass Mode: First pass only (schedule/boost off during HR)
153
+ Lock after End: βœ“ | End Step: [last step of base pass]
154
+ HR CFG Override: 3.0
155
+ ```
156
+
157
+ ---
158
+
159
+ ## Technical Notes
160
+
161
+ - **No model patching** β€” only standard A1111 hooks
162
+ - **`override_settings`** scoped to current generation, auto-restored
163
+ - **`combine_denoised` hook** installed/removed per generation; safe to stack
164
+ - **`lock_triggered`** state survives into HR pass (SEGA v5.1 pattern) β€” set in `_emb_features_active()`, cleared in `postprocess()`
165
+ - **`in_hr_pass`** flag set in `before_hr()` β€” allows Pass Mode to discriminate
166
+ - **SDXL compatible** β€” handles both plain tensor and `{"crossattn": ...}` dict conditioning
cfg-prompt-forge/javascript/cfg_prompt_forge.js ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * CFG & Prompt Forge v4 β€” UI enhancements
3
+ * β€’ Blue left-border on accordion when enabled
4
+ * β€’ Dims sub-sections when their parent Enable checkbox is off
5
+ * β€’ Status badge in accordion title
6
+ */
7
+ (function () {
8
+ "use strict";
9
+
10
+ function wireAccordion(accId) {
11
+ const acc = document.getElementById(accId);
12
+ if (!acc) return;
13
+ const mainCb = acc.querySelector('input[type="checkbox"]');
14
+ const btn = acc.querySelector("button[aria-expanded], button.label-wrap");
15
+ if (!mainCb || !btn) return;
16
+
17
+ function update() {
18
+ btn.style.borderLeft = mainCb.checked ? "3px solid #5bc8f5" : "";
19
+ btn.style.paddingLeft = mainCb.checked ? "9px" : "";
20
+ btn.style.background = mainCb.checked ? "rgba(91,200,245,.05)" : "";
21
+ }
22
+ mainCb.addEventListener("change", update);
23
+ update();
24
+ }
25
+
26
+ function wireDimming(accId) {
27
+ const acc = document.getElementById(accId);
28
+ if (!acc) return;
29
+
30
+ // For every checkbox whose label contains "Enable" (but not the master ✦)
31
+ acc.querySelectorAll('input[type="checkbox"]').forEach(cb => {
32
+ const lbl = (cb.closest("label") || cb.parentElement || {}).textContent || "";
33
+ if (!lbl.includes("Enable") || lbl.includes("✦")) return;
34
+
35
+ function apply() {
36
+ // Walk forward siblings until the next Enable checkbox group
37
+ let el = cb.closest('[class*="row"],[class*="form"],[class*="block"]')
38
+ || cb.parentElement;
39
+ if (!el) return;
40
+ let sib = el.nextElementSibling;
41
+ while (sib) {
42
+ const nextEn = sib.querySelector('input[type="checkbox"]');
43
+ if (nextEn) {
44
+ const nt = (nextEn.closest("label") || nextEn.parentElement || {}).textContent || "";
45
+ if (nt.includes("Enable")) break;
46
+ }
47
+ sib.style.opacity = cb.checked ? "1" : "0.4";
48
+ sib.style.pointerEvents = cb.checked ? "" : "none";
49
+ sib.style.transition = "opacity .15s";
50
+ sib = sib.nextElementSibling;
51
+ }
52
+ }
53
+ cb.addEventListener("change", apply);
54
+ apply();
55
+ });
56
+ }
57
+
58
+ function init() {
59
+ ["cpf_acc_t2i", "cpf_acc_i2i"].forEach(id => {
60
+ wireAccordion(id);
61
+ wireDimming(id);
62
+ });
63
+ }
64
+
65
+ let tries = 0;
66
+ const t = setInterval(() => {
67
+ if (document.getElementById("cpf_acc_t2i") ||
68
+ document.getElementById("cpf_acc_i2i") ||
69
+ ++tries > 100) {
70
+ clearInterval(t);
71
+ init();
72
+ }
73
+ }, 350);
74
+ })();
cfg-prompt-forge/scripts/__pycache__/cfg_prompt_forge.cpython-312.pyc ADDED
Binary file (90.9 kB). View file
 
cfg-prompt-forge/scripts/cfg_prompt_forge.py ADDED
@@ -0,0 +1,1926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════════════════════╗
3
+ β•‘ CFG & Prompt Forge v4 β€” Stable Diffusion WebUI Extension β•‘
4
+ β•‘ Stock A1111 only (NOT Forge-only APIs) β•‘
5
+ ╠══════════════════════════════════════════════════════════════════════════════╣
6
+ β•‘ Sources: β•‘
7
+ β•‘ β€’ sd-webui-dycfg β€” step-range combine hook β•‘
8
+ β•‘ β€’ sd-webui-skimmed_cfg β€” embedding-level uncond mask β•‘
9
+ β•‘ β€’ ComfyUI-AutomaticCFG β€” adaptive per-channel CFG, LIR, mean β•‘
10
+ β•‘ β€’ pre_cfg_comfy_nodes β€” lerp-uncond, smooth-skim math β•‘
11
+ β•‘ β€’ NRS β€” Skew/Stretch/Squash guidance geometry β•‘
12
+ β•‘ β€’ CFgfade β€” Reinhard, Rescale, Heuristic CFG β•‘
13
+ β•‘ β€’ HiresCFG β€” before_hr CFG override β•‘
14
+ β•‘ β€’ SEGA v5.1 β€” Warmup/EndStep/LockAfterEnd, Apply to β•‘
15
+ β•‘ Positive/Negative, Pass Mode β•‘
16
+ ╠══════════════════════════════════════════════════════════════════════════════╣
17
+ β•‘ HOOKS β•‘
18
+ β•‘ on_cfg_denoiser β€” schedule / boost / jitter / step-range / β•‘
19
+ β•‘ lerp-uncond / skimmed CFG (with SEGA window+lock) β•‘
20
+ β•‘ combine_denoised β€” Auto CFG / NRS / heuristic / Reinhard / Rescale / β•‘
21
+ β•‘ LIR / subtract-mean β•‘
22
+ β•‘ before_hr β€” HR-pass CFG swap + pass-mode arming β•‘
23
+ β•‘ postprocess_imageβ€” HR-pass CFG restore β•‘
24
+ β•‘ before_process β€” write override_settings, arm all state β•‘
25
+ β•‘ postprocess β€” disarm + unhook β•‘
26
+ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
27
+ """
28
+
29
+ import math
30
+ import random
31
+ from typing import Optional
32
+
33
+ import gradio as gr
34
+ import torch
35
+
36
+ from modules import scripts, script_callbacks
37
+ from modules.processing import StableDiffusionProcessingTxt2Img
38
+
39
+ try:
40
+ from modules.sd_samplers_cfg_denoiser import CFGDenoiser as _DCLS
41
+ except ImportError:
42
+ _DCLS = None
43
+
44
+ _EXT = "CPF"
45
+ _ORIG = "combine_denoised"
46
+ _SAVED = f"_{_EXT}_saved_{_ORIG}"
47
+
48
+
49
+ # ═══════════════════════════════════════════════════════════════════════════
50
+ # Schedule math
51
+ # ═══════════════════════════════════════════════════════════════════════════
52
+
53
+ _SCHED: dict = {
54
+ "off" : "Off (flat Γ— 1.0)",
55
+ "ramp_up" : "Ramp Up",
56
+ "ramp_down": "Ramp Down",
57
+ "cos_up" : "Cosine Ease-In",
58
+ "cos_down" : "Cosine Ease-Out",
59
+ "bell" : "Bell (low β†’ high β†’ low)",
60
+ "inv_bell" : "Inv. Bell (high β†’ low β†’ high)",
61
+ "step_up" : "Step Up (at midpoint)",
62
+ "step_down": "Step Down (at midpoint)",
63
+ }
64
+ _L2K = {v: k for k, v in _SCHED.items()}
65
+
66
+ # ── Step-schedule curves (used by multiple parameters) ────────────
67
+ _CURVES = [
68
+ "Off", "Linear Down", "Linear Up", "Cosine Down", "Cosine Up",
69
+ "Bell", "Inv. Bell", "Step Up", "Step Down",
70
+ "Power Down", "Power Up", "Repeating", "Sawtooth",
71
+ ]
72
+
73
+
74
+ def _sm(t, k, mn, mx):
75
+ t = max(0.0, min(1.0, t))
76
+ if k == "ramp_up" : return mn + (mx - mn) * t
77
+ if k == "ramp_down": return mx - (mx - mn) * t
78
+ if k == "cos_up" : return mn + (mx - mn) * (1 - math.cos(math.pi * t)) / 2
79
+ if k == "cos_down" : return mx - (mx - mn) * (1 - math.cos(math.pi * t)) / 2
80
+ if k == "bell" : return mn + (mx - mn) * math.sin(math.pi * t)
81
+ if k == "inv_bell" : return mx - (mx - mn) * math.sin(math.pi * t)
82
+ if k == "step_up" : return mn if t < 0.5 else mx
83
+ if k == "step_down": return mx if t < 0.5 else mn
84
+ return 1.0
85
+
86
+
87
+ def _schedule_value(base, step, total, curve, min_val, sched_val=2.0):
88
+ """Apply curve schedule to a float parameter over denoising steps.
89
+
90
+ Returns the parameter value at the given step.
91
+ curve='Off' β†’ returns base unchanged.
92
+ """
93
+ if curve == "Off":
94
+ return base
95
+ frac = step / max(total - 1, 1)
96
+ frac = max(0.0, min(1.0, frac))
97
+
98
+ if curve == "Linear Down": mult = 1.0 - frac
99
+ elif curve == "Linear Up": mult = frac
100
+ elif curve == "Cosine Down": mult = math.cos(frac * math.pi / 2)
101
+ elif curve == "Cosine Up": mult = 1.0 - math.cos(frac * math.pi / 2)
102
+ elif curve == "Bell": mult = math.sin(math.pi * frac)
103
+ elif curve == "Inv. Bell": mult = 1.0 - math.sin(math.pi * frac)
104
+ elif curve == "Step Up": mult = 0.0 if frac < 0.5 else 1.0
105
+ elif curve == "Step Down": mult = 1.0 if frac < 0.5 else 0.0
106
+ elif curve == "Power Down": mult = 1.0 - math.pow(frac, max(sched_val, 0.1))
107
+ elif curve == "Power Up": mult = math.pow(frac, max(sched_val, 0.1))
108
+ elif curve == "Repeating": mult = 1.0 - abs(2.0 * ((frac * sched_val) % 1.0) - 1.0)
109
+ elif curve == "Sawtooth": mult = (frac * sched_val) % 1.0
110
+ else:
111
+ return base
112
+ return min_val + (base - min_val) * mult
113
+
114
+
115
+ def _step_control_val(base_val, step, sc_enabled, r1_start, r1_end, r1_val,
116
+ r2_start, r2_end, r2_val, lock, locked, off_val):
117
+ """NRS-style Individual Step Control with Range1/Range2 + Lock after End.
118
+
119
+ Returns (value, new_locked_flag).
120
+ When locked β†’ (off_val, True).
121
+ When in range β†’ (range_val, locked unchanged).
122
+ Otherwise β†’ (base_val, locked unchanged).
123
+ """
124
+ if locked:
125
+ return off_val, True
126
+ if not sc_enabled:
127
+ return base_val, False
128
+
129
+ def _in(s, e):
130
+ return s > 0 and step >= s and (e == 0 or step <= e)
131
+
132
+ if _in(r1_start, r1_end):
133
+ return r1_val, False
134
+ if _in(r2_start, r2_end):
135
+ return r2_val, False
136
+
137
+ if lock:
138
+ # Only consider finite range ends (>0) for lock triggering.
139
+ ends = [e for e in (r1_end, r2_end) if e > 0]
140
+ if ends and step > max(ends):
141
+ return off_val, True
142
+ return base_val, False
143
+
144
+
145
+ # ═══════════════════════════════════════════════════════════════════════════
146
+ # Global state
147
+ # ═══════════════════════════════════════════════════════════════════════════
148
+
149
+ class _S:
150
+ enabled : bool = False
151
+ # core opts
152
+ skip_early : float = 0.0
153
+ word_wrap : int = 20
154
+ ngms : float = 0.0
155
+ ngms_all : bool = False
156
+ # step ranges
157
+ ranges_en : bool = False
158
+ range_scales : list = None
159
+ base_cfg : float = 7.0
160
+ # schedule
161
+ sched_en : bool = False
162
+ sched_k : str = "off"
163
+ sched_min : float = 0.5
164
+ sched_max : float = 1.5
165
+ # boost
166
+ boost_en : bool = False
167
+ boost_from : float = 0.80
168
+ boost_mul : float = 1.5
169
+ # jitter
170
+ jitter_en : bool = False
171
+ jitter_amt : float = 0.10
172
+ # auto cfg
173
+ autocfg_en : bool = False
174
+ autocfg_method : str = "hard"
175
+ autocfg_ref : float = 8.0
176
+ autocfg_topk : float = 0.25
177
+ # lerp uncond
178
+ lerp_en : bool = False
179
+ lerp_str : float = 1.0
180
+ # skimmed cfg
181
+ skim_en : bool = False
182
+ skim_mode : str = "Classic"
183
+ skim_scale : float = 7.0
184
+ skim_flip : bool = False
185
+ skim_target : str = "Uncond"
186
+ # NRS
187
+ nrs_en : bool = False
188
+ nrs_skew : float = 2.0
189
+ nrs_stretch : float = 5.0
190
+ nrs_squash : float = 0.75
191
+ # NRS v2 β€” Vector Softcap (feature 2)
192
+ nrs_softcap : float = 0.0
193
+ nrs_softcap_mode : str = "Per Sample"
194
+ # NRS v2 β€” Midpoint Refinement / Kohaku RK2 (feature 4)
195
+ nrs_midpoint : float = 0.0
196
+ nrs_midpoint_mode : str = "Classic"
197
+ nrs_midpoint_fh : bool = True
198
+ # TCFG β€” Tangential Damping via SVD (feature 3)
199
+ tcfg_en : bool = False
200
+ # Disagreement Gate β€” adaptive skew/stretch (feature 6)
201
+ dis_en : bool = False
202
+ dis_strength : float = 0.5
203
+ dis_threshold : float = 0.3
204
+ dis_metric : str = "Cosine"
205
+ # Inter-step smoothing β€” Momentum + EMA Delta (feature 7)
206
+ interp_mode : str = "None"
207
+ interp_mom : float = 0.0
208
+ interp_ema : float = 0.5
209
+ # Runtime context (NOT UI β€” populated in _on_cfg_denoiser, read in _custom_combine)
210
+ current_x : object = None # noisy latent params.x
211
+ current_sigma : object = None # params.sigma
212
+ current_step : int = 0
213
+ total_steps : int = 20
214
+ # Inter-step state (reset each generation and each HR pass)
215
+ _prev_nrs : object = None
216
+ _prev_vel : object = None
217
+ _prev_delta : object = None
218
+ # post-cfg
219
+ reinhard_en : bool = False
220
+ reinhard_ref : float = 4.0
221
+ rescale_en : bool = False
222
+ rescale_mult : float = 0.7
223
+ heuristic_en : bool = False
224
+ heuristic_cfg : float = 5.0
225
+ heuristic_hstart : float = 0.0
226
+ lir_en : bool = False
227
+ lir_cfg : float = 8.0
228
+ lir_method : str = "hard"
229
+ lir_topk : float = 0.25
230
+ mean_en : bool = False
231
+ # ── Step schedules (per-parameter curve over steps) ─────────────────────
232
+ nrs_sched_curve : str = "Off"
233
+ nrs_sched_min_skew : float = 0.0
234
+ nrs_sched_min_stretch : float = 0.0
235
+ skim_sched_curve : str = "Off"
236
+ skim_sched_min : float = 1.0
237
+ lerp_sched_curve : str = "Off"
238
+ lerp_sched_min : float = 0.5
239
+ heuristic_sched_curve : str = "Off"
240
+ heuristic_sched_min : float = 1.0
241
+ # Computed at each step by _on_cfg_denoiser
242
+ _sched_skim_scale : float = 0.0
243
+ _sched_lerp_str : float = 0.0
244
+ _sched_nrs_skew : float = 0.0
245
+ _sched_nrs_stretch : float = 0.0
246
+ _sched_heuristic_cfg : float = 0.0
247
+ # ── NRS Individual Step Control (Range1/Range2 + Lock after End) ────────
248
+ nrs_sc_enabled : bool = False
249
+ nrs_sc_r1_start : int = 0
250
+ nrs_sc_r1_end : int = 0
251
+ nrs_sc_r1_skew : float = 0.0
252
+ nrs_sc_r1_stretch: float = 0.0
253
+ nrs_sc_r2_start : int = 0
254
+ nrs_sc_r2_end : int = 0
255
+ nrs_sc_r2_skew : float = 0.0
256
+ nrs_sc_r2_stretch: float = 0.0
257
+ nrs_sc_lock : bool = False
258
+ nrs_sc_locked : bool = False # runtime flag
259
+ # ── SEGA-inspired activation window ─────────────────────────────────────
260
+ warmup_steps : int = 0 # first N steps: embedding features inactive
261
+ end_step : int = 0 # 0 = no limit; stop features after this step
262
+ lock_after_end : bool = False # once end_step crossed, stay off for HR pass too
263
+ lock_triggered : bool = False # runtime flag, NOT a UI control
264
+ apply_to_pos : bool = False # apply Lerp/Skim to text_cond as well
265
+ apply_to_neg : bool = True # apply Lerp/Skim to text_uncond
266
+ # ── pass mode ────────────────────────────────────────────────────────────
267
+ pass_mode : str = "Both passes" # "Both passes"/"First pass only"/"HR pass only"
268
+ in_hr_pass : bool = False # runtime flag
269
+ # HR CFG
270
+ hr_en : bool = False
271
+ hr_cfg : float = 1.0
272
+ hr_cfg_saved : Optional[float] = None
273
+
274
+
275
+ _st = _S()
276
+
277
+
278
+ # ═══════════════════════════════════════════════════════════════════════════
279
+ # SEGA-inspired activation-window helpers
280
+ # ═══════════════════════════════════════════════════════════════════════════
281
+
282
+ def _emb_features_active(step: int) -> bool:
283
+ """
284
+ Returns True if embedding-level features (Lerp, Skim) should run this step.
285
+
286
+ Rules (ported from SEGA v5.1 is_guidance_active + is_locked_after_end):
287
+ 1. Lock triggered (end_step crossed with lock_after_end) β†’ False permanently
288
+ 2. step < warmup_steps β†’ still warming up β†’ False
289
+ 3. end_step > 0 and step > end_step β†’ past window β†’ possibly lock, then False
290
+ """
291
+ if _st.lock_triggered:
292
+ return False
293
+ if step < _st.warmup_steps:
294
+ return False
295
+ if _st.end_step > 0 and step > _st.end_step:
296
+ if _st.lock_after_end:
297
+ _st.lock_triggered = True # permanent for HR pass too
298
+ return False
299
+ return True
300
+
301
+
302
+ def _pass_features_active() -> bool:
303
+ """
304
+ Returns True if CFG-multiplier features (schedule, boost, jitter, ranges)
305
+ should run given the current pass_mode.
306
+ """
307
+ if _st.pass_mode == "First pass only" and _st.in_hr_pass:
308
+ return False
309
+ if _st.pass_mode == "HR pass only" and not _st.in_hr_pass:
310
+ return False
311
+ return True
312
+
313
+
314
+ # ═══════════════════════════════════════════════════════════════════════════
315
+ # Step-range builder (DyCFG)
316
+ # ═══════════════════════════════════════════════════════════════════════════
317
+
318
+ def _build_ranges(total, base, rows):
319
+ scales = ["Default"] * total
320
+ for (s1, e1, v, intp) in rows:
321
+ s = max(1, int(s1)); e = int(e1) if int(e1) > 0 else total
322
+ if e < s: continue
323
+ v = max(1.0, float(v))
324
+ for i in range(s - 1, min(e, total)): scales[i] = v
325
+ for i in range(s - 2, -1, -1):
326
+ if isinstance(scales[i], str): scales[i] = intp
327
+ else: break
328
+ scales = ["Default"] + scales + ["Default"]
329
+ prev = base
330
+ for i in range(len(scales)):
331
+ if scales[i] == "Default": scales[i] = base
332
+ elif scales[i] == "Fixed": scales[i] = prev if isinstance(prev, float) else base
333
+ if isinstance(scales[i], float): prev = scales[i]
334
+ i = 0
335
+ while i < len(scales):
336
+ if scales[i] == "Linear":
337
+ pv = scales[i-1] if i > 0 and isinstance(scales[i-1], float) else base
338
+ j = i + 1
339
+ while j < len(scales) and scales[j] == "Linear": j += 1
340
+ nv = scales[j] if j < len(scales) and isinstance(scales[j], float) else base
341
+ d = (nv - pv) / (j - i + 1)
342
+ for k in range(i, j): scales[k] = pv + (k - i + 1) * d
343
+ i = j
344
+ else: i += 1
345
+ scales = scales[1:-1]
346
+ return [float(v) if isinstance(v, float) else base for v in scales]
347
+
348
+
349
+ # ═══════════════════════════════════════════════════════════════════════════
350
+ # AutoCFG helpers (AutomaticCFG)
351
+ # ═══════════════════════════════════════════════════════════════════════════
352
+
353
+ @torch.no_grad()
354
+ def _topk_range(flat_ch, method, topk):
355
+ k = max(1, int(flat_ch.numel() * topk))
356
+
357
+ data = flat_ch.float()
358
+ if method == "range":
359
+ data = data - data.mean()
360
+ mx = torch.topk(data, k, largest=True ).values.mean().item()
361
+ mna = torch.topk(data, k, largest=False).values.abs().mean().item()
362
+ else:
363
+ mx = torch.topk(data, k, largest=True ).values.mean().item()
364
+ mn_v = torch.topk(data, k, largest=False).values
365
+ mna = abs(mn_v.mean().item()) if method == "soft" else mn_v.abs().mean().item()
366
+
367
+ r = (mx + mna) / 2.0
368
+ return r * r if method == "hard_squared" else r
369
+
370
+
371
+ @torch.no_grad()
372
+ def _auto_cfg_combine(unc, conds_list, x_out, eff_scale, ref_cfg, method, topk):
373
+ target = eff_scale / 10.0
374
+ out = torch.zeros_like(unc)
375
+ for i, conds in enumerate(conds_list):
376
+ u = unc[i]
377
+ ref = u.clone()
378
+ for ci, w in conds: ref = ref + (x_out[ci] - u) * (w * ref_cfg)
379
+ C = ref.shape[0]
380
+ flat = ref.reshape(C, -1)
381
+ ranges = [_topk_range(flat[c], method, topk) for c in range(C)]
382
+ out[i] = u.clone()
383
+ for ci, w in conds:
384
+ delta = x_out[ci] - u
385
+ for c in range(C):
386
+ r = max(ranges[c], 1e-8)
387
+ out[i][c] = out[i][c] + delta[c] * (w * ref_cfg * target / r)
388
+ return out
389
+
390
+
391
+ # ═══════════════════════════════════════════════════════════════════════════
392
+ # NRS v2 (NRS Kohaku Enhanced v3.5 β€” sigma-aware latent-space math)
393
+ # Features: proper sigma conversion, Vector Softcap, Midpoint Refinement,
394
+ # TCFG (SVD tangential damping), Disagreement Gate, Momentum + EMA Delta
395
+ # ═══════════════════════════════════════════════════════════════════════════
396
+
397
+ def _nrs_log_skip(name, exc=None):
398
+ """Silent skip β€” no crash on NRS sub-feature error."""
399
+ pass
400
+
401
+
402
+ def _validate_4d(t, tag=""):
403
+ return (torch.is_tensor(t) and t.dim() == 4
404
+ and t.numel() > 0 and torch.isfinite(t).all())
405
+
406
+
407
+ @torch.no_grad()
408
+ def _nrs_softcap_vector(v: torch.Tensor, strength: float,
409
+ mode: str = "Per Sample") -> torch.Tensor:
410
+ """
411
+ Norm-based tanh softcap β€” preserves direction, compresses magnitude.
412
+ Source: NRS Kohaku Enhanced v3.5 _nrs_softcap_vector()
413
+ Near-zero: linear. Large values: saturate smoothly.
414
+ """
415
+ if strength <= 0.0:
416
+ return v
417
+ try:
418
+ eps = 1e-6
419
+ s = max(strength, eps)
420
+ if mode == "Per Channel":
421
+ v_abs = v.abs()
422
+ cap = (v_abs.mean(dim=(2, 3), keepdim=True) * s).clamp(min=eps)
423
+ ratio = v_abs / cap
424
+ scale = torch.tanh(ratio) / ratio.clamp(min=eps)
425
+ result = v * scale
426
+ elif mode == "Global Batch":
427
+ v_norm = v.norm(p=2, dim=(1, 2, 3), keepdim=True).clamp(min=eps)
428
+ cap = (v_norm.mean() * s).clamp(min=eps)
429
+ ratio = v_norm / cap
430
+ scale = torch.tanh(ratio) / ratio.clamp(min=eps)
431
+ result = v * scale
432
+ else: # "Per Sample" (default)
433
+ v_norm = v.norm(p=2, dim=1, keepdim=True).clamp(min=eps)
434
+ cap = (v_norm.mean(dim=(2, 3), keepdim=True) * s).clamp(min=eps)
435
+ ratio = v_norm / cap
436
+ scale = torch.tanh(ratio) / ratio.clamp(min=eps)
437
+ result = v * scale
438
+ return result if torch.isfinite(result).all() else v
439
+ except Exception:
440
+ return v
441
+
442
+
443
+ @torch.no_grad()
444
+ def _nrs_core(x_orig, cond, uncond, sigma,
445
+ skew: float, stretch: float, squash: float,
446
+ vector_softcap: float = 0.0,
447
+ vector_softcap_mode: str = "Per Sample") -> torch.Tensor:
448
+ """
449
+ Proper sigma-aware NRS kernel.
450
+ Source: NRS Kohaku Enhanced v3.5 _nrs_core()
451
+
452
+ Converts denoised predictions to NRS-space via sigma, applies
453
+ Skew/Stretch/Squash geometry, then converts back.
454
+ Supports v-prediction models (auto-detected via sd_model.parameterization).
455
+ """
456
+ if not _validate_4d(x_orig, "nrs_x") or not _validate_4d(cond, "nrs_c"):
457
+ return cond
458
+
459
+ try:
460
+ # Detect v-prediction model
461
+ from modules import shared as _shared
462
+ is_v_pred = (hasattr(_shared.sd_model, "parameterization") and
463
+ _shared.sd_model.parameterization == "v")
464
+ except Exception:
465
+ is_v_pred = False
466
+
467
+ # Normalise sigma to a [1,1,1,1] tensor
468
+ if isinstance(sigma, torch.Tensor):
469
+ sig = sigma.flatten()[0].to(cond.dtype)
470
+ else:
471
+ sig = torch.tensor(float(sigma), device=cond.device, dtype=cond.dtype)
472
+ sig = sig.view(1, 1, 1, 1).clamp(min=1e-6)
473
+ sig_root = (sig ** 2 + 1).sqrt()
474
+
475
+ # Convert to NRS geometry space
476
+ if is_v_pred:
477
+ nrs_c, nrs_u = cond, uncond
478
+ else:
479
+ x_div = x_orig / (sig ** 2 + 1)
480
+ factor = sig / sig_root
481
+ nrs_c = x_orig - (x_div - cond * factor)
482
+ nrs_u = x_orig - (x_div - uncond * factor)
483
+
484
+ eps = 1e-6
485
+
486
+ def _dot(a, b): return (a * b).sum(dim=1, keepdim=True)
487
+ def _nrm2(v): return _dot(v, v)
488
+
489
+ c2c = _nrm2(nrs_c) + eps
490
+ u2c = _dot(nrs_u, nrs_c)
491
+ u_on_c = (u2c / c2c) * nrs_c
492
+
493
+ proj_diff = nrs_c - u_on_c
494
+ if vector_softcap > 0.0:
495
+ proj_diff = _nrs_softcap_vector(proj_diff, vector_softcap, vector_softcap_mode)
496
+ stretched = nrs_c + stretch * proj_diff
497
+
498
+ u_rej_c = nrs_u - u_on_c
499
+ if vector_softcap > 0.0:
500
+ u_rej_c = _nrs_softcap_vector(u_rej_c, vector_softcap, vector_softcap_mode)
501
+ skewed = stretched - skew * u_rej_c
502
+
503
+ # Squash: normalise back to original cond length (L2)
504
+ c_len = nrs_c.norm(dim=1, keepdim=True)
505
+ s_len = skewed.norm(dim=1, keepdim=True) + eps
506
+ sq_scale = ((1.0 - squash) + squash * (c_len / s_len)).clamp(-10.0, 10.0)
507
+ x_final = skewed * sq_scale
508
+
509
+ if not torch.isfinite(x_final).all():
510
+ return cond
511
+
512
+ # Convert back to denoised space
513
+ if is_v_pred:
514
+ return x_final
515
+ else:
516
+ return (x_div - (x_orig - x_final)) * (sig_root / sig)
517
+
518
+
519
+ @torch.no_grad()
520
+ def _nrs_midpoint_refined(x_orig, cond, uncond, sigma,
521
+ skew, stretch, squash,
522
+ blend: float = 0.5,
523
+ mode: str = "Classic",
524
+ first_half_only: bool = True,
525
+ step: int = 0, total: int = 20,
526
+ softcap: float = 0.0,
527
+ softcap_mode: str = "Per Sample") -> torch.Tensor:
528
+ """
529
+ Kohaku RK2 midpoint refinement for NRS.
530
+ Source: NRS Kohaku Enhanced v3.5 calc_nrs_midpoint_refined()
531
+
532
+ nrs_direct = NRS(x, c, u)
533
+ x_mid = x + (nrs_direct - x) * blend * 0.5 ← midpoint shift
534
+ nrs_refined = NRS(x_mid, c, u)
535
+ result = (nrs_direct + nrs_refined) / 2 ← RK2 average
536
+ """
537
+ def _nrs(x, c, u):
538
+ return _nrs_core(x, c, u, sigma, skew, stretch, squash, softcap, softcap_mode)
539
+
540
+ nrs_direct = _nrs(x_orig, cond, uncond)
541
+
542
+ if blend <= 0.0:
543
+ return nrs_direct
544
+ if first_half_only and step > total / 2:
545
+ return nrs_direct
546
+
547
+ x_mid = x_orig + (nrs_direct - x_orig) * (blend * 0.5)
548
+
549
+ if mode == "Conservative":
550
+ nrs_ref = _nrs_core(x_mid, cond, uncond, sigma,
551
+ skew * 0.75, stretch * 0.75, squash,
552
+ softcap, softcap_mode)
553
+ else:
554
+ nrs_ref = _nrs(x_mid, cond, uncond)
555
+
556
+ if mode == "Directional Midpoint":
557
+ try:
558
+ correction = nrs_ref - nrs_direct
559
+ base_dir = nrs_direct - x_orig
560
+ dims = tuple(range(1, correction.ndim))
561
+ guide_norm = base_dir.norm(p=2, dim=dims, keepdim=True)
562
+ if torch.any(guide_norm > 1e-6):
563
+ # Project correction onto base_dir direction
564
+ c_dot_g = (correction * base_dir).sum(dim=dims, keepdim=True)
565
+ g_dot_g = (base_dir * base_dir).sum(dim=dims, keepdim=True) + 1e-9
566
+ c_para = (c_dot_g / g_dot_g) * base_dir
567
+ c_orth = correction - c_para
568
+ nrs_ref = nrs_direct + c_para + c_orth * 0.75
569
+ except Exception:
570
+ pass
571
+
572
+ return (nrs_direct + nrs_ref) * 0.5
573
+
574
+
575
+ # ── TCFG: Tangential Damping via SVD (NRS Kohaku #33) ───────────────────
576
+
577
+ @torch.no_grad()
578
+ def _tcfg_uncond(cond: torch.Tensor, uncond: torch.Tensor,
579
+ step: int = -1) -> torch.Tensor:
580
+ """
581
+ TCFG: filter tangential components from uncond that are misaligned with cond.
582
+ Source: arxiv 2503.18137 (Kwon et al., CVPR 2025) / NRS Kohaku apply_tcfg_uncond()
583
+
584
+ Algorithm (10 lines):
585
+ 1. Stack [cond, uncond] β†’ [B, 2, N]
586
+ 2. SVD β†’ Vh: [B, 2, N]
587
+ 3. Zero out second right singular vector (tangential)
588
+ 4. Project uncond through zeroed Vh
589
+ """
590
+ if not _validate_4d(cond, "tcfg_c") or not _validate_4d(uncond, "tcfg_u"):
591
+ return uncond
592
+ if cond.shape != uncond.shape:
593
+ return uncond
594
+ orig_dtype = uncond.dtype
595
+ orig_device = uncond.device
596
+ try:
597
+ if not torch.isfinite(cond).all() or not torch.isfinite(uncond).all():
598
+ return uncond
599
+ B = cond.shape[0]
600
+ cf = cond.float(); uf = uncond.float()
601
+ all_n = torch.stack((cf, uf), dim=1).reshape(B, 2, -1) # [B,2,N]
602
+ _U, _S, Vh = torch.linalg.svd(all_n, full_matrices=False) # Vh:[B,2,N]
603
+ Vh_mod = Vh.clone(); Vh_mod[:, 1, :] = 0.0
604
+ uf_flat = uf.reshape(B, 1, -1) # [B,1,N]
605
+ x_Vh = torch.matmul(uf_flat, Vh.transpose(-2, -1)) # [B,1,2]
606
+ result = torch.matmul(x_Vh, Vh_mod).reshape(*uncond.shape)
607
+ if not torch.isfinite(result).all():
608
+ return uncond
609
+ return result.to(device=orig_device, dtype=orig_dtype)
610
+ except Exception:
611
+ return uncond
612
+
613
+
614
+ # ── Disagreement Gate (NRS Kohaku #38) ──────────────────────────────────
615
+
616
+ @torch.no_grad()
617
+ def _disagreement_gate(cond: torch.Tensor, uncond: torch.Tensor,
618
+ skew: float, stretch: float,
619
+ strength: float = 0.5, threshold: float = 0.3,
620
+ metric: str = "Cosine") -> tuple:
621
+ """
622
+ Scale skew/stretch by how much cond and uncond disagree.
623
+ Source: NRS Kohaku Enhanced apply_disagreement_gate()
624
+
625
+ gate = strength + (1 - strength) * clamp(disagreement / threshold, 0, 1)
626
+ At zero disagreement β†’ gate = strength (minimum scaling)
627
+ At full disagreement β†’ gate = 1.0 (full skew/stretch)
628
+ """
629
+ if not torch.is_tensor(cond) or not torch.is_tensor(uncond):
630
+ return skew, stretch
631
+ try:
632
+ eps = 1e-6
633
+ c = cond.float().reshape(cond.shape[0], -1)
634
+ u = uncond.float().reshape(uncond.shape[0], -1)
635
+ if metric == "L2":
636
+ diff_norm = (c - u).norm(p=2, dim=1)
637
+ ref_norm = (c.norm(p=2, dim=1) + u.norm(p=2, dim=1)) / 2.0 + eps
638
+ raw_dist = (diff_norm / ref_norm).mean().clamp(0.0, 2.0) / 2.0
639
+ else: # Cosine
640
+ cn = torch.nn.functional.normalize(c, dim=1)
641
+ un = torch.nn.functional.normalize(u, dim=1)
642
+ raw_dist = ((1.0 - (cn * un).sum(dim=1).mean()) / 2.0).clamp(0.0, 1.0)
643
+ if not torch.isfinite(raw_dist):
644
+ return skew, stretch
645
+ thr = max(threshold, eps)
646
+ dis = float((raw_dist / thr).clamp(0.0, 1.0).item())
647
+ gate = strength + (1.0 - strength) * dis
648
+ return skew * gate, stretch * gate
649
+ except Exception:
650
+ return skew, stretch
651
+
652
+
653
+ # ── Inter-step smoothing: Momentum + EMA Delta (NRS Kohaku) ─────────────
654
+
655
+ @torch.no_grad()
656
+ def _apply_momentum(result, prev_result, prev_vel, momentum: float):
657
+ """
658
+ RES/Clybius momentum smoothing between denoising steps.
659
+ Source: NRS Kohaku apply_nrs_momentum()
660
+ Formula: vel = m*(1-m/2)*prev_vel + (1-m*(1-m/2))*curr_diff
661
+ result = prev_result + vel
662
+ Returns (smoothed_result, new_vel)
663
+ """
664
+ if momentum <= 0.0 or prev_result is None:
665
+ curr = result - prev_result if prev_result is not None else None
666
+ return result, curr
667
+ try:
668
+ curr_diff = result - prev_result
669
+ eff_m = momentum * (1.0 - momentum * 0.5)
670
+ vel = (eff_m * prev_vel + (1.0 - eff_m) * curr_diff
671
+ if prev_vel is not None else curr_diff)
672
+ return prev_result + vel, vel
673
+ except Exception:
674
+ return result, None
675
+
676
+
677
+ @torch.no_grad()
678
+ def _apply_ema_delta(result, prev_result, prev_delta, ema_alpha: float):
679
+ """
680
+ EMA smoothing in update-vector space.
681
+ Source: NRS Kohaku apply_nrs_ema_delta()
682
+ Smooths the step-to-step update delta, not the full result tensor.
683
+ Returns (smoothed_result, new_delta_state)
684
+ """
685
+ if prev_result is None:
686
+ return result, None
687
+ try:
688
+ curr_delta = result - prev_result
689
+ delta_state = (torch.lerp(prev_delta, curr_delta, ema_alpha)
690
+ if prev_delta is not None else curr_delta)
691
+ return prev_result + delta_state, delta_state
692
+ except Exception:
693
+ return result, None
694
+
695
+
696
+ # ═══════════════════════════════════════════════════════════════════════════
697
+ # Skimmed CFG (sd-webui-skimmed_cfg)
698
+ # ═══════════════════════════════════════════════════════════════════════════
699
+
700
+ @torch.no_grad()
701
+ def _skim_compute(c, u, cs, ss, no_flip):
702
+ """Compute skimmed CFG mask.
703
+ Returns (mask, cond_correction, uncond_correction) where:
704
+ - cond_correction is subtracted from cond to move it toward uncond
705
+ - uncond_correction is ADDED to uncond to move it toward cond
706
+ Formula matches original skimmed_CFG reference with x_orig = uncond."""
707
+ den = u - (cs * ((u - c) - (u - u)))
708
+ ms = (c - u).sign() == c.sign()
709
+ md = c.sign() == (c * cs - u * (cs - 1)).sign()
710
+ if no_flip:
711
+ mask = ms & md
712
+ else:
713
+ mask = ms & md & (den.sign() == (den - u).sign())
714
+ low = u - (ss * ((u - c) - (u - u)))
715
+ # correction for cond (subtract): (den - low) / cs
716
+ # correction for uncond (add): (den - low) / cs (same magnitude, opposite direction)
717
+ corr = (den - low) / cs
718
+ return mask, corr, corr
719
+
720
+ @torch.no_grad()
721
+ def _skim_smooth(u, c, cs, ss):
722
+ df = u - cs * c - (cs - 1) * u
723
+ ds = u - ss * c - (ss - 1) * u
724
+ ad = (df - ds).abs(); span = ad.max() - ad.min()
725
+ ad = (ad - ad.min()) / (span + 1e-8)
726
+ nr = (ss - 1) / max(cs - 1, 1e-6)
727
+ return (c * (1 - nr) + u * nr) * (1 - ad) + u * ad
728
+
729
+
730
+ def _get_emb(raw):
731
+ """Extract crossattn tensor from raw conditioning (handles SD1.x and SDXL dict)."""
732
+ if isinstance(raw, dict):
733
+ return raw.get("crossattn", None)
734
+ return raw if torch.is_tensor(raw) else None
735
+
736
+
737
+ def _set_emb(raw_ref, new_tensor):
738
+ """Write back modified tensor, preserving dict wrapper for SDXL.
739
+ NEVER changes the shape of the original β€” only replaces values in-place."""
740
+ if isinstance(raw_ref, dict):
741
+ out = raw_ref.__class__(raw_ref); out["crossattn"] = new_tensor; return out
742
+ return new_tensor
743
+
744
+
745
+ def _emb_window(c: torch.Tensor, u: torch.Tensor):
746
+ """Return slices of c and u at the MINIMUM shared sequence length.
747
+ This guarantees we NEVER extend either tensor β€” A1111 crashes if shape changes."""
748
+ sl = min(c.shape[1], u.shape[1])
749
+ return c[:, :sl], u[:, :sl], sl
750
+
751
+
752
+ def _apply_skimmed(params) -> None:
753
+ """Apply Skimmed CFG to text_uncond and/or text_cond.
754
+ SAFETY: only plain tensors or SDXL dicts are processed; never changes shape.
755
+ skim_target controls which tensor gets modified:
756
+ - "Uncond": modify uncond directly (port default, moves uncond toward cond)
757
+ - "Cond β†’ Uncond": modify cond and assign to uncond (original reference behavior,
758
+ makes condβ‰ˆuncond at mask positions β†’ zero guidance from those positions)
759
+ - "Both": modify both cond and uncond"""
760
+ if not _st.apply_to_neg and not _st.apply_to_pos:
761
+ return
762
+ rc = params.text_cond
763
+ ru = params.text_uncond
764
+ if ru is None:
765
+ return
766
+
767
+ c = _get_emb(rc)
768
+ u = _get_emb(ru)
769
+ if c is None or u is None:
770
+ return
771
+ if not torch.is_tensor(c) or not torch.is_tensor(u):
772
+ return
773
+ if c.dim() != 3 or u.dim() != 3:
774
+ return
775
+
776
+ c_w, u_w, sl = _emb_window(c, u)
777
+
778
+ d = params.denoiser
779
+ cfg_base = (_st.hr_cfg if (_st.in_hr_pass and _st.hr_en) else _st.base_cfg)
780
+ cs = max(d.cond_scale_miltiplier * cfg_base, 1e-6)
781
+
782
+ eff_scale = _st._sched_skim_scale
783
+ if _st.skim_mode == "Smooth":
784
+ # _skim_smooth always produces a modified uncond
785
+ new_u_w = _skim_smooth(u_w, c_w, cs, eff_scale)
786
+ if _st.apply_to_neg:
787
+ new_u = u.clone()
788
+ new_u[:, :sl] = new_u_w
789
+ params.text_uncond = _set_emb(ru, new_u)
790
+ if _st.apply_to_pos:
791
+ new_c = c.clone()
792
+ new_c[:, :sl] = _skim_smooth(c_w, u_w, cs, eff_scale)
793
+ params.text_cond = _set_emb(rc, new_c)
794
+ return
795
+
796
+ # Classic mode: compute mask + corrections
797
+ mask, corr_c, corr_u = _skim_compute(c_w, u_w, cs, eff_scale, _st.skim_flip)
798
+ target = _st.skim_target
799
+
800
+ if target == "Cond β†’ Uncond":
801
+ if _st.apply_to_neg:
802
+ new_c_for_u = c.clone()
803
+ new_c_for_u[:, :sl][mask] = c_w[mask] - corr_c[mask]
804
+ params.text_uncond = _set_emb(ru, new_c_for_u)
805
+ if _st.apply_to_pos:
806
+ new_c = c.clone()
807
+ new_c[:, :sl][mask] = c_w[mask] - corr_c[mask]
808
+ params.text_cond = _set_emb(rc, new_c)
809
+
810
+ elif target == "Uncond":
811
+ if _st.apply_to_neg:
812
+ new_u = u.clone()
813
+ new_u[:, :sl][mask] = u_w[mask] + corr_u[mask]
814
+ params.text_uncond = _set_emb(ru, new_u)
815
+ if _st.apply_to_pos:
816
+ new_c = c.clone()
817
+ new_c[:, :sl][mask] = c_w[mask] - corr_c[mask]
818
+ params.text_cond = _set_emb(rc, new_c)
819
+
820
+ elif target == "Both":
821
+ if _st.apply_to_neg:
822
+ new_u = u.clone()
823
+ new_u[:, :sl][mask] = u_w[mask] + corr_u[mask] * 0.5
824
+ params.text_uncond = _set_emb(ru, new_u)
825
+ if _st.apply_to_pos:
826
+ new_c = c.clone()
827
+ new_c[:, :sl][mask] = c_w[mask] - corr_c[mask] * 0.5
828
+ params.text_cond = _set_emb(rc, new_c)
829
+
830
+
831
+ # ═══════════════════════════════════════════════════════════════════════════
832
+ # Lerp Uncond (AutomaticCFG / pre_cfg_comfy)
833
+ # Blend text_cond/uncond toward each other, preserving original norm.
834
+ # ═══════════════════════════════════════════════════════════════════════════
835
+
836
+ @torch.no_grad()
837
+ def _apply_lerp_uncond(params) -> None:
838
+ rc = params.text_cond
839
+ ru = params.text_uncond
840
+ if ru is None:
841
+ return
842
+
843
+ c = _get_emb(rc)
844
+ u = _get_emb(ru)
845
+ if c is None or u is None:
846
+ return
847
+ if not torch.is_tensor(c) or not torch.is_tensor(u):
848
+ return
849
+ if c.dim() != 3 or u.dim() != 3:
850
+ return
851
+
852
+ # Work on MINIMUM shared length only β€” never extend
853
+ c_w, u_w, sl = _emb_window(c, u)
854
+
855
+ eff_lerp = _st._sched_lerp_str
856
+ if _st.apply_to_neg:
857
+ u_norm = u.norm()
858
+ nu_w = torch.lerp(c_w, u_w, eff_lerp)
859
+ nu_n = nu_w.norm()
860
+ if nu_n > 1e-8:
861
+ nu_w = nu_w * (u_norm / nu_n)
862
+ new_u = u.clone()
863
+ new_u[:, :sl] = nu_w
864
+ params.text_uncond = _set_emb(ru, new_u)
865
+
866
+ if _st.apply_to_pos:
867
+ c_norm = c.norm()
868
+ nc_w = torch.lerp(u_w, c_w, eff_lerp)
869
+ nc_n = nc_w.norm()
870
+ if nc_n > 1e-8:
871
+ nc_w = nc_w * (c_norm / nc_n)
872
+ new_c = c.clone()
873
+ new_c[:, :sl] = nc_w
874
+ params.text_cond = _set_emb(rc, new_c)
875
+
876
+
877
+ # ═══════════════════════════════════════════════════════════════════════════
878
+ # on_cfg_denoiser callback (registered once at import)
879
+ # ═══════════════════════════════════════════════════════════════════════════
880
+
881
+ def _cond_has_shape(v) -> bool:
882
+ """Return True if v is something A1111 can call .shape[1] on without crashing.
883
+
884
+ A1111 sd_samplers_cfg_denoiser.py line 243 does:
885
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
886
+ When pad_cond_uncond settings are OFF (A1111 default), this is the FIRST
887
+ place .shape is accessed after our callback. If v is a dict (SDXL-style
888
+ conditioning set by another extension) A1111 crashes with:
889
+ AttributeError: 'dict' object has no attribute 'shape'
890
+ We guard against that here.
891
+ """
892
+ if torch.is_tensor(v):
893
+ return True
894
+ if isinstance(v, dict):
895
+ ca = v.get("crossattn", None)
896
+ return torch.is_tensor(ca)
897
+ return hasattr(v, "shape")
898
+
899
+
900
+ def _on_cfg_denoiser(params) -> None:
901
+ if not _st.enabled:
902
+ return
903
+
904
+ d = params.denoiser
905
+ total = max(1, params.total_sampling_steps)
906
+ step = params.sampling_step
907
+ t = step / total
908
+
909
+ # Store context needed by _custom_combine (sigma-aware NRS, midpoint, inter-step)
910
+ _st.current_x = params.x
911
+ _st.current_sigma = params.sigma
912
+ _st.current_step = step
913
+ _st.total_steps = total
914
+
915
+ # ── Compute scheduled parameter values for this step ──────────────────
916
+ _st._sched_skim_scale = _schedule_value(
917
+ _st.skim_scale, step, total, _st.skim_sched_curve, _st.skim_sched_min)
918
+ _st._sched_lerp_str = _schedule_value(
919
+ _st.lerp_str, step, total, _st.lerp_sched_curve, _st.lerp_sched_min)
920
+ _st._sched_nrs_skew = _schedule_value(
921
+ _st.nrs_skew, step, total, _st.nrs_sched_curve, _st.nrs_sched_min_skew)
922
+ _st._sched_nrs_stretch = _schedule_value(
923
+ _st.nrs_stretch, step, total, _st.nrs_sched_curve, _st.nrs_sched_min_stretch)
924
+ _st._sched_heuristic_cfg = _schedule_value(
925
+ _st.heuristic_cfg, step, total, _st.heuristic_sched_curve, _st.heuristic_sched_min)
926
+
927
+ # ── NRS Individual Step Control overrides (Range1/Range2 + Lock) ──────
928
+ _st._sched_nrs_skew, _st.nrs_sc_locked = _step_control_val(
929
+ _st._sched_nrs_skew, step,
930
+ _st.nrs_sc_enabled,
931
+ _st.nrs_sc_r1_start, _st.nrs_sc_r1_end, _st.nrs_sc_r1_skew,
932
+ _st.nrs_sc_r2_start, _st.nrs_sc_r2_end, _st.nrs_sc_r2_skew,
933
+ _st.nrs_sc_lock, _st.nrs_sc_locked, off_val=0.0)
934
+ _st._sched_nrs_stretch, _ = _step_control_val(
935
+ _st._sched_nrs_stretch, step,
936
+ _st.nrs_sc_enabled,
937
+ _st.nrs_sc_r1_start, _st.nrs_sc_r1_end, _st.nrs_sc_r1_stretch,
938
+ _st.nrs_sc_r2_start, _st.nrs_sc_r2_end, _st.nrs_sc_r2_stretch,
939
+ _st.nrs_sc_lock, _st.nrs_sc_locked, off_val=0.0)
940
+
941
+ # ── Pass-mode gate (schedule / boost / jitter / ranges) ──────────────
942
+ if _pass_features_active():
943
+ d.cond_scale_miltiplier = 1.0
944
+
945
+ if _st.ranges_en and _st.range_scales and step < len(_st.range_scales):
946
+ d.cond_scale_miltiplier = _st.range_scales[step] / max(_st.base_cfg, 1e-4)
947
+
948
+ if _st.sched_en and _st.sched_k != "off":
949
+ d.cond_scale_miltiplier *= _sm(t, _st.sched_k, _st.sched_min, _st.sched_max)
950
+
951
+ if _st.boost_en and t >= _st.boost_from:
952
+ ramp = (t - _st.boost_from) / max(1e-9, 1.0 - _st.boost_from)
953
+ d.cond_scale_miltiplier *= 1.0 + (_st.boost_mul - 1.0) * ramp
954
+
955
+ if _st.jitter_en and _st.jitter_amt > 0:
956
+ d.cond_scale_miltiplier *= 1.0 + random.uniform(-_st.jitter_amt, _st.jitter_amt)
957
+
958
+ d.cond_scale_miltiplier = max(0.01, d.cond_scale_miltiplier)
959
+ else:
960
+ d.cond_scale_miltiplier = 1.0
961
+
962
+ # ── Embedding-level features (SEGA-inspired window + lock guard) ──────
963
+ if _emb_features_active(step) and _pass_features_active():
964
+ # SAVE originals β€” restored on any failure or bad output
965
+ _orig_cond = params.text_cond
966
+ _orig_uncond = params.text_uncond
967
+
968
+ # Skip embedding features if text_cond is not in a state we can work with.
969
+ # Another extension (e.g. NRS Kohaku) may have already turned text_cond
970
+ # into a plain dict without a "crossattn" key, which would crash A1111
971
+ # at line 243 of sd_samplers_cfg_denoiser.py regardless of our changes.
972
+ if not _cond_has_shape(_orig_cond) or not _cond_has_shape(_orig_uncond):
973
+ return # don't touch anything β€” let A1111 handle it
974
+
975
+ try:
976
+ if _st.lerp_en:
977
+ _apply_lerp_uncond(params)
978
+ if _st.skim_en:
979
+ _apply_skimmed(params)
980
+ except Exception:
981
+ # On any failure, restore to the guaranteed-valid originals
982
+ params.text_cond = _orig_cond
983
+ params.text_uncond = _orig_uncond
984
+ return
985
+
986
+ # Post-modification validation: ensure we haven't corrupted the shape.
987
+ # If something produced a non-.shape-able object, roll back.
988
+ if not _cond_has_shape(params.text_cond):
989
+ params.text_cond = _orig_cond
990
+ if not _cond_has_shape(params.text_uncond):
991
+ params.text_uncond = _orig_uncond
992
+
993
+
994
+ script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
995
+
996
+ # ── Late-fixup callback (registered lazily in before_process) ────────────────
997
+ # WHY THIS EXISTS:
998
+ # A1111 loads extensions alphabetically. "cfg-prompt-forge" loads before
999
+ # extensions starting with letters later than 'c' (e.g. "nrs_kohaku…").
1000
+ # Callbacks registered at import time fire in registration order, so NRS Kohaku's
1001
+ # callback runs AFTER ours. NRS Kohaku can wrap text_cond in an SDXL-style dict
1002
+ # even when the model is SD1.x, causing A1111 line-243 to crash:
1003
+ # if tensor.shape[1] == uncond.shape[1] β†’ AttributeError: dict has no .shape
1004
+ #
1005
+ # Solution: register _late_cond_fixup lazily inside before_process.
1006
+ # Because before_process runs AFTER all import-time registrations are done,
1007
+ # our fixup is added to the END of the callback list and fires LAST β€” after
1008
+ # every other extension that might corrupt the conditioning.
1009
+
1010
+ _fixup_registered: bool = False # module-level flag: only register once
1011
+
1012
+
1013
+ def _late_cond_fixup(params) -> None:
1014
+ """Last-resort fixup with explicit model architecture detection.
1015
+ Unwraps dict-style conditioning back to plain tensors for SD1.x/SD2.x
1016
+ models, while preserving the dict wrapper for SDXL/SD3 models that
1017
+ genuinely require it.
1018
+
1019
+ Previous heuristic checked for "vector" key, but some extensions add
1020
+ a "vector" key to SD1.x conditioning as well, causing false negatives.
1021
+ Using shared.sd_model.is_sdxl / is_sd3 is authoritative.
1022
+ """
1023
+ from modules import shared
1024
+
1025
+ m = getattr(shared, 'sd_model', None)
1026
+ expects_dict = getattr(m, 'is_sdxl', False) or getattr(m, 'is_sd3', False)
1027
+
1028
+ def _try_unwrap(v):
1029
+ if not isinstance(v, dict):
1030
+ return v
1031
+ ca = v.get("crossattn", None)
1032
+ if torch.is_tensor(ca) and not expects_dict:
1033
+ return ca
1034
+ return v
1035
+
1036
+ params.text_cond = _try_unwrap(params.text_cond)
1037
+ params.text_uncond = _try_unwrap(params.text_uncond)
1038
+
1039
+
1040
+ # ═════════════════════════════════════════���═════════════════════════════════
1041
+ # combine_denoised class-patch
1042
+ # ═══════════════════════════════════════════════════════════════════════════
1043
+
1044
+ def _hook_combine() -> None:
1045
+ if _DCLS is None: return
1046
+ if hasattr(_DCLS, _SAVED):
1047
+ setattr(_DCLS, _ORIG, getattr(_DCLS, _SAVED))
1048
+ delattr(_DCLS, _SAVED)
1049
+ orig = getattr(_DCLS, _ORIG)
1050
+ setattr(_DCLS, _SAVED, orig)
1051
+
1052
+ def _hooked(self, x_out, conds_list, uncond, cond_scale):
1053
+ if not _st.enabled:
1054
+ return orig(self, x_out, conds_list, uncond, cond_scale)
1055
+ try:
1056
+ return _custom_combine(orig, self, x_out, conds_list, uncond, cond_scale)
1057
+ except Exception:
1058
+ return orig(self, x_out, conds_list, uncond, cond_scale)
1059
+
1060
+ setattr(_DCLS, _ORIG, _hooked)
1061
+
1062
+
1063
+ def _unhook_combine() -> None:
1064
+ if _DCLS is None: return
1065
+ if hasattr(_DCLS, _SAVED):
1066
+ setattr(_DCLS, _ORIG, getattr(_DCLS, _SAVED))
1067
+ delattr(_DCLS, _SAVED)
1068
+
1069
+
1070
+ @torch.no_grad()
1071
+ def _custom_combine(orig, self, x_out, conds_list, uncond, cond_scale):
1072
+ # Respect pass_mode
1073
+ if not _pass_features_active():
1074
+ return orig(self, x_out, conds_list, uncond, cond_scale)
1075
+
1076
+ unc = x_out[-uncond.shape[0]:]
1077
+ # cond_scale already contains cond_scale_miltiplier β€” A1111 forward() calls
1078
+ # combine_denoised(x_out, conds_list, uncond, cond_scale * self.cond_scale_miltiplier)
1079
+ # so we must NOT multiply by cond_scale_miltiplier again here.
1080
+ eff_scale = cond_scale
1081
+
1082
+ if _st.autocfg_en and _st.autocfg_method != "None":
1083
+ denoised = _auto_cfg_combine(unc, conds_list, x_out, eff_scale,
1084
+ _st.autocfg_ref, _st.autocfg_method, _st.autocfg_topk)
1085
+ else:
1086
+ # ── TCFG: filter tangential uncond components before NRS ─────────
1087
+ unc_nrs = unc
1088
+ if _st.tcfg_en and len(conds_list) > 0:
1089
+ try:
1090
+ # Build representative cond tensor (one per batch item)
1091
+ rep_cond = torch.stack(
1092
+ [x_out[conds_list[i][0][0]] for i in range(unc.shape[0])
1093
+ if conds_list[i]], dim=0)
1094
+ unc_nrs = _tcfg_uncond(rep_cond, unc, step=_st.current_step)
1095
+ except Exception:
1096
+ unc_nrs = unc
1097
+
1098
+ denoised = torch.clone(unc)
1099
+ sigma = _st.current_sigma
1100
+ x_orig = _st.current_x # noisy latent = params.x = x_in
1101
+
1102
+ # x_in layout: [repeats[0]Γ—x[0], repeats[1]Γ—x[1], ..., x[0], x[1], ...]
1103
+ # The last batch_size entries are the original (unrepeat-ed) noisy latents,
1104
+ # one per batch item. Using x_orig[i] would give the wrong entry when
1105
+ # repeats[i] > 1 (AND composition). Use the tail slice instead.
1106
+ x_batch = (x_orig[-uncond.shape[0]:] if x_orig is not None else None)
1107
+
1108
+ for i, conds in enumerate(conds_list):
1109
+ for ci, w in conds:
1110
+ cp = x_out[ci]
1111
+ up = unc_nrs[i] if _st.tcfg_en else unc[i]
1112
+
1113
+ if _st.nrs_en:
1114
+ # ── Disagreement Gate: scale skew/stretch adaptively ──
1115
+ eff_skew, eff_stretch = _st._sched_nrs_skew, _st._sched_nrs_stretch
1116
+ if _st.dis_en:
1117
+ eff_skew, eff_stretch = _disagreement_gate(
1118
+ cp, up,
1119
+ eff_skew, eff_stretch,
1120
+ _st.dis_strength, _st.dis_threshold, _st.dis_metric)
1121
+
1122
+ # ── Proper sigma-aware NRS with optional Midpoint ─────
1123
+ if sigma is not None and x_batch is not None:
1124
+ xi = x_batch[i] if x_batch.shape[0] > i else x_batch[0]
1125
+ xi = xi.unsqueeze(0) if xi.dim() == 3 else xi
1126
+ cp_b = cp.unsqueeze(0) if cp.dim() == 3 else cp
1127
+ up_b = up.unsqueeze(0) if up.dim() == 3 else up
1128
+
1129
+ if _st.nrs_midpoint > 0.0:
1130
+ cp_b = _nrs_midpoint_refined(
1131
+ xi, cp_b, up_b, sigma,
1132
+ eff_skew, eff_stretch, _st.nrs_squash,
1133
+ blend=_st.nrs_midpoint,
1134
+ mode=_st.nrs_midpoint_mode,
1135
+ first_half_only=_st.nrs_midpoint_fh,
1136
+ step=_st.current_step, total=_st.total_steps,
1137
+ softcap=_st.nrs_softcap,
1138
+ softcap_mode=_st.nrs_softcap_mode)
1139
+ else:
1140
+ cp_b = _nrs_core(xi, cp_b, up_b, sigma,
1141
+ eff_skew, eff_stretch, _st.nrs_squash,
1142
+ vector_softcap=_st.nrs_softcap,
1143
+ vector_softcap_mode=_st.nrs_softcap_mode)
1144
+
1145
+ cp = cp_b.squeeze(0) if cp_b.shape[0] == 1 else cp_b
1146
+ else:
1147
+ # Fallback: simple geometry without sigma (no x_batch stored)
1148
+ eps = 1e-6
1149
+ B = 1; c = cp.reshape(1, -1); u = up.reshape(1, -1)
1150
+ c2c = (c * c).sum(dim=1, keepdim=True) + eps
1151
+ u2c = (u * c).sum(dim=1, keepdim=True)
1152
+ u_on_c = (u2c / c2c) * c
1153
+ stretched = c + eff_stretch * (c - u_on_c)
1154
+ skewed = stretched - eff_skew * (u - u_on_c)
1155
+ c_len = c.norm(dim=1, keepdim=True)
1156
+ s_len = skewed.norm(dim=1, keepdim=True) + eps
1157
+ cp = (skewed * ((1.0 - _st.nrs_squash)
1158
+ + _st.nrs_squash * (c_len / s_len))).reshape(cp.shape)
1159
+
1160
+ denoised[i] = denoised[i] + (cp - unc[i]) * (w * eff_scale)
1161
+
1162
+ if _st.heuristic_en:
1163
+ denoised = _heuristic(denoised, unc, eff_scale, _st._sched_heuristic_cfg)
1164
+ if _st.reinhard_en:
1165
+ denoised = _reinhard(denoised, unc, eff_scale, _st.reinhard_ref)
1166
+ if _st.rescale_en:
1167
+ denoised = _rescale(denoised, x_out, conds_list, _st.rescale_mult)
1168
+ if _st.lir_en:
1169
+ denoised = _lir(denoised, _st.lir_cfg, _st.lir_method, _st.lir_topk)
1170
+ if _st.mean_en:
1171
+ dims = list(range(1, denoised.dim()))
1172
+ denoised = denoised - denoised.mean(dim=dims, keepdim=True)
1173
+
1174
+ # ── Inter-step smoothing (Momentum / EMA Delta) ───────────────────────
1175
+ if _st.interp_mode == "Momentum" and _st.interp_mom > 0.0:
1176
+ denoised, _st._prev_vel = _apply_momentum(
1177
+ denoised, _st._prev_nrs, _st._prev_vel, _st.interp_mom)
1178
+ elif _st.interp_mode == "EMA Delta":
1179
+ denoised, _st._prev_delta = _apply_ema_delta(
1180
+ denoised, _st._prev_nrs, _st._prev_delta, _st.interp_ema)
1181
+
1182
+ _st._prev_nrs = denoised.detach().clone()
1183
+
1184
+ return denoised
1185
+
1186
+
1187
+ # ── Post-CFG shapers ──────────────────────────────────────────────────────
1188
+
1189
+ @torch.no_grad()
1190
+ def _reinhard(result, unc, cs, ref):
1191
+ if cs <= 0: return result
1192
+ noise = result - unc
1193
+ dims = list(range(1, noise.dim()))
1194
+ mag = torch.linalg.vector_norm(noise, dim=dims, keepdim=True) + 1e-8
1195
+ unit = noise / mag
1196
+ mean = mag.mean(); std = mag.std() if mag.numel() > 1 else torch.zeros_like(mean)
1197
+ top = (mean + 3.0 * std) * (ref / cs)
1198
+ return unc + unit * (top * mag / (mag + top)) * cs
1199
+
1200
+
1201
+ @torch.no_grad()
1202
+ def _rescale(result, x_out, conds_list, mult):
1203
+ try:
1204
+ ci = conds_list[0][0][0]
1205
+ ref = x_out[ci : ci + 1]
1206
+ dims = list(range(1, result.dim()))
1207
+ sc = ref.std(dim=dims, keepdim=True) + 1e-8
1208
+ sr = result.std(dim=dims, keepdim=True) + 1e-8
1209
+ return torch.lerp(result, result * (sc / sr), float(mult))
1210
+ except (IndexError, TypeError): return result
1211
+
1212
+
1213
+ @torch.no_grad()
1214
+ def _heuristic(result, unc, cs, h_cfg):
1215
+ if cs <= 0 or abs(cs - h_cfg) < 0.05:
1216
+ return result
1217
+ if _st.heuristic_hstart > 0.0 and _st.current_step < _st.heuristic_hstart * _st.total_steps:
1218
+ return result
1219
+ noise = result - unc
1220
+ h_res = unc + h_cfg * noise / cs
1221
+ rQ = torch.quantile((result - result.mean()).abs().float(), 0.99) + 1e-8
1222
+ hQ = torch.quantile((h_res - h_res .mean()).abs().float(), 0.99) + 1e-8
1223
+ return unc + noise * (hQ / rQ)
1224
+
1225
+
1226
+ @torch.no_grad()
1227
+ def _lir(result, target_cfg, method, topk):
1228
+ target = target_cfg / 10.0
1229
+ out = result.clone()
1230
+ for b in range(result.shape[0]):
1231
+ flat = result[b].reshape(result[b].shape[0], -1)
1232
+ for c in range(flat.shape[0]):
1233
+ r = max(_topk_range(flat[c], method, topk), 1e-8)
1234
+ out[b][c] = result[b][c] * (target / r)
1235
+ return out
1236
+
1237
+
1238
+ # ═══════════════════════════════════════════════════════════════════════════
1239
+ # SVG schedule preview
1240
+ # ═══════════════════════════════════════════════════════════════════════════
1241
+
1242
+ def _make_svg(kind, mn, mx, W=340, H=88):
1243
+ PAD = 16; w, h = W - 2*PAD, H - 2*PAD; N = 80
1244
+ y_lo = 0.0; y_hi = max(3.0, float(mx) + 0.4)
1245
+ def px(t, v): return f"{PAD+t*w:.1f},{PAD+(1-(v-y_lo)/(y_hi-y_lo))*h:.1f}"
1246
+ pts = " ".join(px(i/N, _sm(i/N, kind, float(mn), float(mx))) for i in range(N+1))
1247
+ by = PAD + (1 - (1.0 - y_lo)/(y_hi - y_lo)) * h
1248
+ _, ymn = map(float, px(0.0, float(mn)).split(","))
1249
+ _, ymx = map(float, px(0.5, float(mx)).split(","))
1250
+ return (f'<div style="margin:4px 0"><svg xmlns="http://www.w3.org/2000/svg" '
1251
+ f'width="{W}" height="{H}" style="background:#0b0b18;border-radius:8px;'
1252
+ f'border:1px solid #1e1e36;display:block">'
1253
+ f'<line x1="{PAD}" y1="{by:.1f}" x2="{PAD+w}" y2="{by:.1f}" '
1254
+ f'stroke="#3a3a5a" stroke-dasharray="4 3" stroke-width="1"/>'
1255
+ f'<text x="{PAD+3}" y="{by-3:.0f}" fill="#555" font-size="9" font-family="monospace">Γ—1.0</text>'
1256
+ f'<text x="{PAD+3}" y="{ymn+4:.0f}" fill="#6ab4ff" font-size="9" font-family="monospace" opacity=".8">Γ—{float(mn):.2f}</text>'
1257
+ f'<text x="{PAD+3}" y="{ymx-2:.0f}" fill="#6ab4ff" font-size="9" font-family="monospace" opacity=".8">Γ—{float(mx):.2f}</text>'
1258
+ f'<polyline points="{pts}" fill="none" stroke="#5bc8f5" stroke-width="2.2" '
1259
+ f'stroke-linecap="round" stroke-linejoin="round"/>'
1260
+ f'<text x="{PAD}" y="{H-2}" fill="#444" font-size="9" font-family="monospace">0</text>'
1261
+ f'<text x="{PAD+w-28}" y="{H-2}" fill="#444" font-size="9" font-family="monospace">final</text>'
1262
+ f'</svg></div>')
1263
+
1264
+
1265
+ def _h(txt):
1266
+ return (f"<div style='font-weight:700;color:#a8d4ff;margin:12px 0 2px;font-size:.9em;"
1267
+ f"letter-spacing:.03em'>{txt}</div>"
1268
+ f"<hr style='border:none;border-top:1px solid #1e1e36;margin:0 0 5px'/>")
1269
+
1270
+
1271
+ # ═══════════════════════════════════════════════════════════════════════════
1272
+ # Script
1273
+ # ═══════════════════════════════════════════════════════════════════════════
1274
+
1275
+ class CFGPromptForge(scripts.Script):
1276
+
1277
+ def title(self): return "CFG & Prompt Forge"
1278
+ def show(self, _): return scripts.AlwaysVisible
1279
+
1280
+ # ──────────────────────────────────────────────────────────────────────
1281
+ def ui(self, is_img2img):
1282
+ tab = "i2i" if is_img2img else "t2i"
1283
+
1284
+ with gr.Accordion("πŸ”§ CFG & Prompt Forge", open=False,
1285
+ elem_id=f"cpf_acc_{tab}"):
1286
+
1287
+ enabled = gr.Checkbox(label="✦ Enable CFG & Prompt Forge",
1288
+ value=False, elem_id=f"cpf_en_{tab}")
1289
+
1290
+ # ── 1. CORE OPTS ──────────────────────────────────────────────
1291
+ gr.HTML(_h("πŸ“ Core Sampling Parameters"))
1292
+ with gr.Row():
1293
+ skip_early = gr.Slider(
1294
+ label="Skip Early CFG", minimum=0.0, maximum=1.0,
1295
+ step=0.01, value=0.0,
1296
+ info="Fraction of early steps where negative prompt is ignored.")
1297
+ word_wrap = gr.Slider(
1298
+ label="Prompt Word-Wrap Limit (tokens)",
1299
+ minimum=0, maximum=74, step=1, value=20,
1300
+ info="CLIP 75-token chunk backtrack threshold.")
1301
+ with gr.Row():
1302
+ ngms = gr.Slider(
1303
+ label="Neg. Guidance Min Οƒ (NGMS)",
1304
+ minimum=0.0, maximum=15.0, step=0.01, value=0.0,
1305
+ info="Skip negative prompt when Οƒ < this. Try 0.5–1.0.")
1306
+ ngms_all = gr.Checkbox(
1307
+ label="NGMS: skip every matching step",
1308
+ value=False)
1309
+
1310
+ # ── 2. STEP RANGES ────────────────────────────────────────────
1311
+ gr.HTML(_h("πŸ“Š CFG Step Ranges (DyCFG)"))
1312
+ ranges_en = gr.Checkbox(label="Enable Step Ranges", value=False)
1313
+ range_rows = []
1314
+ for idx in range(3):
1315
+ with gr.Row():
1316
+ rs = gr.Number(label=f"#{idx+1} Start", value=0, precision=0, minimum=0, maximum=500)
1317
+ re = gr.Number(label="End (0=last)", value=0, precision=0, minimum=0, maximum=500)
1318
+ rv = gr.Number(label="CFG scale", value=7.0, precision=2, minimum=1.0, maximum=100.0)
1319
+ ri = gr.Radio(choices=["Default","Linear","Fixed"], value="Default", label="Interp")
1320
+ range_rows.extend([rs, re, rv, ri])
1321
+
1322
+ # ── 3. SCHEDULE ────────────────────────────────────────────────
1323
+ gr.HTML(_h("πŸ“ˆ CFG Schedule"))
1324
+ sched_en = gr.Checkbox(label="Enable Schedule", value=False)
1325
+ with gr.Row():
1326
+ sched_type = gr.Dropdown(label="Shape",
1327
+ choices=list(_SCHED.values()), value=_SCHED["off"])
1328
+ sched_min = gr.Slider(label="Min Γ—", minimum=0.0, maximum=2.0, step=0.05, value=0.5)
1329
+ sched_max = gr.Slider(label="Max Γ—", minimum=0.5, maximum=4.0, step=0.05, value=1.5)
1330
+ sched_svg = gr.HTML(_make_svg("off", 0.5, 1.5))
1331
+ def _upd(lbl, mn, mx): return _make_svg(_L2K.get(lbl, "off"), mn, mx)
1332
+ sched_type.change(_upd, [sched_type, sched_min, sched_max], sched_svg)
1333
+ sched_min .change(_upd, [sched_type, sched_min, sched_max], sched_svg)
1334
+ sched_max .change(_upd, [sched_type, sched_min, sched_max], sched_svg)
1335
+
1336
+ # ── 4. BOOST + JITTER ─────────────────────────────────────────
1337
+ gr.HTML(_h("πŸš€ End-step Boost & 🎲 Jitter"))
1338
+ with gr.Row():
1339
+ boost_en = gr.Checkbox(label="Enable End Boost", value=False)
1340
+ boost_from = gr.Slider(label="Boost From (step fraction)",
1341
+ minimum=0.5, maximum=0.95, step=0.01, value=0.80,
1342
+ info="e.g. 0.80 = activates in the last 20% of steps. "
1343
+ "Max capped at 0.95 β€” values near 1.0 may never fire "
1344
+ "on low step counts (with 20 steps, 0.95 fires only on "
1345
+ "the final step).")
1346
+ boost_mul = gr.Slider(label="Boost Γ—",
1347
+ minimum=1.0, maximum=5.0, step=0.05, value=1.5)
1348
+ with gr.Row():
1349
+ jitter_en = gr.Checkbox(label="Enable Jitter", value=False)
1350
+ jitter_amt = gr.Slider(label="Jitter Amount (relative Β±)",
1351
+ minimum=0.0, maximum=0.5, step=0.01, value=0.10)
1352
+
1353
+ # ── 5. AUTO CFG ────────────────────────────────────────────────
1354
+ gr.HTML(_h("πŸ€– Auto CFG (ComfyUI-AutomaticCFG)"))
1355
+ with gr.Row():
1356
+ autocfg_en = gr.Checkbox(label="Enable Auto CFG", value=False)
1357
+ autocfg_method = gr.Dropdown(label="Method",
1358
+ choices=["hard","soft","hard_squared","range"], value="hard")
1359
+ with gr.Row():
1360
+ autocfg_ref = gr.Slider(label="Reference CFG",
1361
+ minimum=1.0, maximum=30.0, step=0.5, value=8.0)
1362
+ autocfg_topk = gr.Slider(label="Top-k fraction",
1363
+ minimum=0.05, maximum=0.5, step=0.01, value=0.25)
1364
+
1365
+ # ── 6. LERP UNCOND ────────────────────────────────────────────
1366
+ gr.HTML(_h("πŸ”— Uncond Lerp (AutoCFG / pre_cfg)"))
1367
+ with gr.Row():
1368
+ lerp_en = gr.Checkbox(label="Enable Uncond Lerp", value=False)
1369
+ lerp_str = gr.Slider(label="Lerp Strength",
1370
+ minimum=0.0, maximum=3.0, step=0.05, value=1.0,
1371
+ info="1.0 = no change. < 1 pulls uncond toward cond.")
1372
+ with gr.Row():
1373
+ lerp_sched_curve = gr.Dropdown(label="Schedule",
1374
+ choices=_CURVES, value="Off",
1375
+ info="Curve shape over steps.")
1376
+ lerp_sched_min = gr.Slider(label="Min Lerp",
1377
+ minimum=0.0, maximum=3.0, step=0.05, value=0.5,
1378
+ info="Lerp strength reached at the curve minimum.")
1379
+
1380
+ # ── 7. SKIMMED CFG ────────────────────────────────────────────
1381
+ gr.HTML(_h("πŸ”¬ Skimmed CFG (sd-webui-skimmed_cfg)"))
1382
+ with gr.Row():
1383
+ skim_en = gr.Checkbox(label="Enable Skimmed CFG", value=False)
1384
+ skim_mode = gr.Radio(choices=["Classic","Smooth"], value="Classic", label="Mode")
1385
+ with gr.Row():
1386
+ skim_scale = gr.Slider(label="Skimming Scale",
1387
+ minimum=1.0, maximum=20.0, step=0.5, value=7.0)
1388
+ skim_flip = gr.Checkbox(label="Disable Flipping Filter (Classic only)", value=False)
1389
+ with gr.Row():
1390
+ skim_sched_curve = gr.Dropdown(label="Schedule",
1391
+ choices=_CURVES, value="Off",
1392
+ info="Curve shape over steps.")
1393
+ skim_sched_min = gr.Slider(label="Min Skim Scale",
1394
+ minimum=1.0, maximum=20.0, step=0.5, value=1.0,
1395
+ info="Skim scale reached at the curve minimum.")
1396
+ with gr.Row():
1397
+ skim_target = gr.Radio(
1398
+ choices=["Uncond", "Cond β†’ Uncond", "Both"],
1399
+ value="Uncond", label="Classic Target",
1400
+ info="'Uncond' β€” modify uncond directly (moves uncond toward cond). "
1401
+ "'Cond β†’ Uncond' β€” modify cond, put into uncond (original skimmed-CFG: zero guidance at mask). "
1402
+ "'Both' β€” modify both embeddings. Smooth mode always targets the branch selected via Apply-to flags.")
1403
+
1404
+ # ── 8. NRS ────────────────────────────────────────────────────
1405
+ gr.HTML(_h("⚑ Guidance Geometry β€” NRS v2 (NRS Kohaku Enhanced)"))
1406
+ gr.HTML(
1407
+ "<p style='color:#666;font-size:.82em;margin:0 0 5px'>"
1408
+ "Proper sigma-aware NRS on latent noise predictions. "
1409
+ "Includes <b>Vector Softcap</b> (prevent oversteering), "
1410
+ "<b>Midpoint Refinement</b> (Kohaku RK2), "
1411
+ "<b>TCFG</b> (SVD tangential damping), "
1412
+ "<b>Disagreement Gate</b> (adaptive intensity) and "
1413
+ "<b>Inter-step smoothing</b> (Momentum / EMA).</p>"
1414
+ )
1415
+ with gr.Row():
1416
+ nrs_en = gr.Checkbox(label="Enable NRS Geometry", value=False)
1417
+ nrs_skew = gr.Slider(label="Skew", minimum=-10.0, maximum=10.0, step=0.1, value=2.0)
1418
+ nrs_stretch = gr.Slider(label="Stretch", minimum=-10.0, maximum=10.0, step=0.1, value=5.0)
1419
+ nrs_squash = gr.Slider(label="Squash", minimum=0.0, maximum=1.0, step=0.05, value=0.75)
1420
+ with gr.Row():
1421
+ nrs_sched_curve = gr.Dropdown(label="Schedule",
1422
+ choices=_CURVES, value="Off",
1423
+ info="Curve shape for Skew/Stretch over steps.")
1424
+ nrs_sched_min_skew = gr.Slider(label="Min Skew",
1425
+ minimum=-10.0, maximum=10.0, step=0.1, value=0.0,
1426
+ info="Skew value at the curve minimum.")
1427
+ nrs_sched_min_stretch = gr.Slider(label="Min Stretch",
1428
+ minimum=-10.0, maximum=10.0, step=0.1, value=0.0,
1429
+ info="Stretch value at the curve minimum.")
1430
+
1431
+ gr.HTML("<b style='color:#8ab8e8'>Individual Step Control (NRS-style Range1/Range2 + Lock)</b>"
1432
+ "<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
1433
+ "Override Skew/Stretch with specific values in one or two step ranges. "
1434
+ "Lock after End: once the last range end is crossed, NRS stays disabled "
1435
+ "for the rest of the generation (including HR pass).</p>")
1436
+ with gr.Row():
1437
+ nrs_sc_enabled = gr.Checkbox(label="Enable Step Control", value=False)
1438
+ nrs_sc_lock = gr.Checkbox(label="Lock after End", value=False,
1439
+ info="Once past last range end, keep NRS disabled permanently.")
1440
+ with gr.Row():
1441
+ nrs_sc_r1_start = gr.Number(label="R1 Start", value=0, precision=0, minimum=0, maximum=500)
1442
+ nrs_sc_r1_end = gr.Number(label="R1 End (0=last)", value=0, precision=0, minimum=0, maximum=500)
1443
+ nrs_sc_r1_skew = gr.Slider(label="R1 Skew", minimum=-10.0, maximum=10.0, step=0.1, value=2.0)
1444
+ nrs_sc_r1_stretch = gr.Slider(label="R1 Stretch", minimum=-10.0, maximum=10.0, step=0.1, value=5.0)
1445
+ with gr.Row():
1446
+ nrs_sc_r2_start = gr.Number(label="R2 Start", value=0, precision=0, minimum=0, maximum=500)
1447
+ nrs_sc_r2_end = gr.Number(label="R2 End (0=last)", value=0, precision=0, minimum=0, maximum=500)
1448
+ nrs_sc_r2_skew = gr.Slider(label="R2 Skew", minimum=-10.0, maximum=10.0, step=0.1, value=1.0)
1449
+ nrs_sc_r2_stretch = gr.Slider(label="R2 Stretch", minimum=-10.0, maximum=10.0, step=0.1, value=2.0)
1450
+
1451
+ gr.HTML("<b style='color:#8ab8e8'>Vector Softcap</b>"
1452
+ "<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
1453
+ "tanh-based direction-preserving magnitude compression. "
1454
+ "Prevents oversteering at high Skew/Stretch values. 0 = disabled.</p>")
1455
+ with gr.Row():
1456
+ nrs_softcap = gr.Slider(label="Softcap Strength",
1457
+ minimum=0.0, maximum=5.0, step=0.1, value=0.0,
1458
+ info="0 = off. Try 1.0–2.0 as a starting point.")
1459
+ nrs_softcap_mode = gr.Dropdown(label="Softcap Mode",
1460
+ choices=["Per Sample", "Per Channel", "Global Batch"],
1461
+ value="Per Sample",
1462
+ info="Per Sample: recommended. Per Channel: best with high-detail. "
1463
+ "Global Batch: legacy behaviour.")
1464
+
1465
+ gr.HTML("<b style='color:#8ab8e8'>Midpoint Refinement (Kohaku RK2)</b>"
1466
+ "<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
1467
+ "Runge-Kutta 2nd-order midpoint correction β€” computes NRS twice "
1468
+ "and averages, similar to Kohaku sampler logic. 0 = disabled.</p>")
1469
+ with gr.Row():
1470
+ nrs_midpoint = gr.Slider(label="Midpoint Blend",
1471
+ minimum=0.0, maximum=1.0, step=0.05, value=0.0,
1472
+ info="How far to shift toward midpoint. Try 0.3–0.6.")
1473
+ nrs_midpoint_mode = gr.Dropdown(label="Midpoint Mode",
1474
+ choices=["Classic", "Conservative", "Directional Midpoint"],
1475
+ value="Classic",
1476
+ info="Classic: direct average. Conservative: soften 2nd-pass geometry. "
1477
+ "Directional: project correction onto base direction.")
1478
+ nrs_midpoint_fh = gr.Checkbox(
1479
+ label="First Half Only",
1480
+ value=True,
1481
+ info="Only refine during first half of steps (structure phase). "
1482
+ "Disable to refine all steps.")
1483
+
1484
+ gr.HTML("<b style='color:#8ab8e8'>TCFG β€” Tangential Damping (arxiv 2503.18137)</b>"
1485
+ "<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
1486
+ "SVD-based filter: removes uncond components that are tangential "
1487
+ "(misaligned) to cond. Reduces CFG manifold drift. No extra UNet calls.</p>")
1488
+ tcfg_en = gr.Checkbox(label="Enable TCFG", value=False)
1489
+
1490
+ gr.HTML("<b style='color:#8ab8e8'>Disagreement Gate</b>"
1491
+ "<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
1492
+ "Reduces effective Skew/Stretch when cond and uncond are already "
1493
+ "similar β€” avoids pointless NRS on well-aligned guidance.</p>")
1494
+ with gr.Row():
1495
+ dis_en = gr.Checkbox(label="Enable Disagreement Gate", value=False)
1496
+ dis_strength = gr.Slider(label="Gate Floor",
1497
+ minimum=0.0, maximum=1.0, step=0.05, value=0.5,
1498
+ info="Minimum gate at zero disagreement. 0 = fully blocks NRS when "
1499
+ "condβ‰ˆuncond. 1 = gate has no effect.")
1500
+ dis_threshold = gr.Slider(label="Gate Threshold",
1501
+ minimum=0.05, maximum=1.0, step=0.05, value=0.3,
1502
+ info="Disagreement level at which full skew/stretch is reached.")
1503
+ dis_metric = gr.Radio(choices=["Cosine", "L2"],
1504
+ value="Cosine", label="Metric",
1505
+ info="Cosine: direction-based. L2: magnitude-based.")
1506
+
1507
+ gr.HTML("<b style='color:#8ab8e8'>Inter-step Smoothing</b>"
1508
+ "<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
1509
+ "<b>Momentum</b>: RES/Clybius velocity blend between steps β€” "
1510
+ "smooths abrupt NRS jumps.<br>"
1511
+ "<b>EMA Delta</b>: exponential moving average of the step update vector β€” "
1512
+ "less temporal blurring than momentum.</p>")
1513
+ with gr.Row():
1514
+ interp_mode = gr.Radio(choices=["None", "Momentum", "EMA Delta"],
1515
+ value="None", label="Inter-step Mode")
1516
+ interp_mom = gr.Slider(label="Momentum",
1517
+ minimum=0.0, maximum=0.99, step=0.01, value=0.0,
1518
+ info="Higher = more temporal smoothing. 0.5–0.8 typical.")
1519
+ interp_ema = gr.Slider(label="EMA Alpha",
1520
+ minimum=0.01, maximum=0.99, step=0.01, value=0.5,
1521
+ info="Blend factor toward current step. 1.0 = no smoothing.")
1522
+
1523
+ # ── 9. POST-CFG ───────────────────────────────────────────────
1524
+ gr.HTML(_h("🎚 Post-CFG Shaping"))
1525
+ with gr.Row():
1526
+ reinhard_en = gr.Checkbox(label="Enable Reinhard Tonemap", value=False)
1527
+ reinhard_ref = gr.Slider(label="Reinhard Ref.",
1528
+ minimum=0.5, maximum=16.0, step=0.5, value=4.0)
1529
+ with gr.Row():
1530
+ rescale_en = gr.Checkbox(label="Enable Rescale CFG", value=False)
1531
+ rescale_mult = gr.Slider(label="Rescale Strength",
1532
+ minimum=0.0, maximum=1.0, step=0.05, value=0.7)
1533
+ with gr.Row():
1534
+ heuristic_en = gr.Checkbox(label="Enable Heuristic CFG", value=False)
1535
+ heuristic_cfg = gr.Slider(label="Heuristic Ref. CFG",
1536
+ minimum=1.0, maximum=20.0, step=0.5, value=5.0)
1537
+ heuristic_hstart = gr.Slider(label="Heuristic Start Step",
1538
+ minimum=0.0, maximum=1.0, step=0.01, value=0.0,
1539
+ info="Fraction of steps before heuristic activates. 0 = from step 0.")
1540
+ with gr.Row():
1541
+ heuristic_sched_curve = gr.Dropdown(label="Schedule",
1542
+ choices=_CURVES, value="Off",
1543
+ info="Curve shape for Heuristic CFG over steps.")
1544
+ heuristic_sched_min = gr.Slider(label="Min Heuristic CFG",
1545
+ minimum=0.5, maximum=20.0, step=0.5, value=1.0,
1546
+ info="Heuristic CFG at the curve minimum.")
1547
+ with gr.Row():
1548
+ lir_en = gr.Checkbox(label="Enable Latent Intensity Rescale", value=False)
1549
+ lir_cfg = gr.Slider(label="LIR Target CFG",
1550
+ minimum=1.0, maximum=30.0, step=0.5, value=8.0)
1551
+ lir_method = gr.Dropdown(label="LIR Method",
1552
+ choices=["hard","soft","hard_squared","range"], value="hard")
1553
+ lir_topk = gr.Slider(label="LIR Top-k",
1554
+ minimum=0.05, maximum=0.5, step=0.01, value=0.25)
1555
+ mean_en = gr.Checkbox(label="Enable Subtract Latent Mean", value=False)
1556
+
1557
+ # ── 10. SEGA-INSPIRED ACTIVATION WINDOW ──────────────────────
1558
+ gr.HTML(_h("🧭 Feature Activation Window (SEGA v5.1)"))
1559
+ gr.HTML(
1560
+ "<p style='color:#666;font-size:.82em;margin:0 0 5px'>"
1561
+ "Controls <em>when</em> embedding-level features (Lerp Uncond, "
1562
+ "Skimmed CFG) are active within a generation, and <em>which</em> "
1563
+ "prompt branch they modify. Inspired by SEGA's warmup, guidance "
1564
+ "end step, and lock-after-end mechanism.<br><br>"
1565
+ "<b>Warmup Steps</b> β€” embedding features inactive for the first N steps "
1566
+ "(early steps are too noisy for fine guidance).<br>"
1567
+ "<b>End Step</b> β€” deactivate embedding features after this step "
1568
+ "(0 = keep active until the end).<br>"
1569
+ "<b>Lock after End</b> β€” once End Step is crossed the features stay "
1570
+ "off for the rest of the generation, including the HiRes pass.<br>"
1571
+ "<b>Apply to Positive / Negative</b> β€” choose which prompt branch "
1572
+ "Lerp Uncond and Skimmed CFG modify.</p>"
1573
+ )
1574
+ with gr.Row():
1575
+ warmup_steps = gr.Slider(label="Warmup Steps",
1576
+ minimum=0, maximum=50, step=1, value=0,
1577
+ info="Skip embedding features for first N steps.")
1578
+ end_step = gr.Slider(label="End Step (0 = no limit)",
1579
+ minimum=0, maximum=150, step=1, value=0,
1580
+ info="Deactivate embedding features after this step.")
1581
+ lock_after_end = gr.Checkbox(
1582
+ label="Lock after End (stays off for HR pass too)",
1583
+ value=False)
1584
+ with gr.Row():
1585
+ apply_to_neg = gr.Checkbox(
1586
+ label="Apply to Negative (text_uncond)",
1587
+ value=True,
1588
+ info="Lerp Uncond and Skimmed CFG modify the negative/unconditional "
1589
+ "prompt embeddings.")
1590
+ apply_to_pos = gr.Checkbox(
1591
+ label="Apply to Positive (text_cond)",
1592
+ value=False,
1593
+ info="Lerp Uncond and Skimmed CFG also modify the positive/conditional "
1594
+ "prompt embeddings.")
1595
+
1596
+ # ── 11. PASS MODE + HIRES ────────────────────────────────────
1597
+ gr.HTML(_h("πŸ–Ό Pass Mode & HiRes Fix CFG"))
1598
+ gr.HTML(
1599
+ "<p style='color:#666;font-size:.82em;margin:0 0 5px'>"
1600
+ "<b>Pass Mode</b> β€” controls during which denoising pass CFG "
1601
+ "features (schedule, boost, jitter, step-ranges, embedding features) "
1602
+ "are active. Useful when you want aggressive shaping only on the "
1603
+ "base pass and clean defaults on the HiRes upscale pass, or vice versa.<br><br>"
1604
+ "<b>HR CFG Override</b> β€” separately swap <code>p.cfg_scale</code> "
1605
+ "for the HiRes fix second pass (txt2img only). Works independently "
1606
+ "of Pass Mode.</p>"
1607
+ )
1608
+ with gr.Row():
1609
+ pass_mode = gr.Radio(
1610
+ choices=["Both passes", "First pass only", "HR pass only"],
1611
+ value="Both passes",
1612
+ label="Pass Mode",
1613
+ elem_id=f"cpf_passmode_{tab}")
1614
+ with gr.Row():
1615
+ hr_en = gr.Checkbox(label="Override HR CFG", value=False)
1616
+ hr_cfg = gr.Slider(label="CFG for HR pass",
1617
+ minimum=1.0, maximum=30.0, step=0.1, value=1.0)
1618
+
1619
+ self.infotext_fields = [
1620
+ (skip_early, "CPF SkipEarlyCFG"),
1621
+ (ngms, "CPF NGMS"),
1622
+ (hr_cfg, "CPF HR_CFG"),
1623
+ ]
1624
+
1625
+ # !! ORDER must EXACTLY match _parse_args() !!
1626
+ return [
1627
+ enabled,
1628
+ skip_early, word_wrap, ngms, ngms_all,
1629
+ ranges_en, *range_rows, # 1 + 12
1630
+ sched_en, sched_type, sched_min, sched_max,
1631
+ boost_en, boost_from, boost_mul,
1632
+ jitter_en, jitter_amt,
1633
+ autocfg_en, autocfg_method, autocfg_ref, autocfg_topk,
1634
+ lerp_en, lerp_str,
1635
+ lerp_sched_curve, lerp_sched_min,
1636
+ skim_en, skim_mode, skim_scale, skim_flip, skim_target,
1637
+ skim_sched_curve, skim_sched_min,
1638
+ nrs_en, nrs_skew, nrs_stretch, nrs_squash,
1639
+ nrs_sched_curve, nrs_sched_min_skew, nrs_sched_min_stretch,
1640
+ nrs_sc_enabled, nrs_sc_lock,
1641
+ nrs_sc_r1_start, nrs_sc_r1_end, nrs_sc_r1_skew, nrs_sc_r1_stretch,
1642
+ nrs_sc_r2_start, nrs_sc_r2_end, nrs_sc_r2_skew, nrs_sc_r2_stretch,
1643
+ nrs_softcap, nrs_softcap_mode,
1644
+ nrs_midpoint, nrs_midpoint_mode, nrs_midpoint_fh,
1645
+ tcfg_en,
1646
+ dis_en, dis_strength, dis_threshold, dis_metric,
1647
+ interp_mode, interp_mom, interp_ema,
1648
+ reinhard_en, reinhard_ref,
1649
+ rescale_en, rescale_mult,
1650
+ heuristic_en, heuristic_cfg, heuristic_hstart,
1651
+ heuristic_sched_curve, heuristic_sched_min,
1652
+ lir_en, lir_cfg, lir_method, lir_topk,
1653
+ mean_en,
1654
+ warmup_steps, end_step, lock_after_end, # ← SEGA window
1655
+ apply_to_neg, apply_to_pos, # ← SEGA branch flags
1656
+ pass_mode, # ← pass mode
1657
+ hr_en, hr_cfg,
1658
+ ]
1659
+
1660
+ # ──────────────────────────────────────────────────────────────────────
1661
+ @staticmethod
1662
+ def _parse_args(raw):
1663
+ it = iter(raw); n = lambda: next(it)
1664
+ return dict(
1665
+ enabled = n(),
1666
+ skip_early = n(), word_wrap = n(), ngms = n(), ngms_all = n(),
1667
+ ranges_en = n(),
1668
+ ranges = [(int(n()), int(n()), float(n()), n()) for _ in range(3)],
1669
+ sched_en = n(), sched_type = n(), sched_min = n(), sched_max = n(),
1670
+ boost_en = n(), boost_from = n(), boost_mul = n(),
1671
+ jitter_en = n(), jitter_amt = n(),
1672
+ autocfg_en = n(), autocfg_method = n(), autocfg_ref = n(), autocfg_topk = n(),
1673
+ lerp_en = n(), lerp_str = n(),
1674
+ lerp_sched_curve = n(), lerp_sched_min = n(),
1675
+ skim_en = n(), skim_mode = n(), skim_scale = n(), skim_flip = n(), skim_target = n(),
1676
+ skim_sched_curve = n(), skim_sched_min = n(),
1677
+ nrs_en = n(), nrs_skew = n(), nrs_stretch = n(), nrs_squash = n(),
1678
+ nrs_sched_curve = n(), nrs_sched_min_skew = n(), nrs_sched_min_stretch = n(),
1679
+ nrs_sc_enabled = n(), nrs_sc_lock = n(),
1680
+ nrs_sc_r1_start = n(), nrs_sc_r1_end = n(), nrs_sc_r1_skew = n(), nrs_sc_r1_stretch = n(),
1681
+ nrs_sc_r2_start = n(), nrs_sc_r2_end = n(), nrs_sc_r2_skew = n(), nrs_sc_r2_stretch = n(),
1682
+ nrs_softcap = n(), nrs_softcap_mode = n(),
1683
+ nrs_midpoint = n(), nrs_midpoint_mode = n(), nrs_midpoint_fh = n(),
1684
+ tcfg_en = n(),
1685
+ dis_en = n(), dis_strength = n(), dis_threshold = n(), dis_metric = n(),
1686
+ interp_mode = n(), interp_mom = n(), interp_ema = n(),
1687
+ reinhard_en = n(), reinhard_ref = n(),
1688
+ rescale_en = n(), rescale_mult = n(),
1689
+ heuristic_en = n(), heuristic_cfg = n(), heuristic_hstart = n(),
1690
+ heuristic_sched_curve = n(), heuristic_sched_min = n(),
1691
+ lir_en = n(), lir_cfg = n(), lir_method = n(), lir_topk = n(),
1692
+ mean_en = n(),
1693
+ warmup_steps = n(), end_step = n(), lock_after_end = n(),
1694
+ apply_to_neg = n(), apply_to_pos = n(),
1695
+ pass_mode = n(),
1696
+ hr_en = n(), hr_cfg = n(),
1697
+ )
1698
+
1699
+ # ──────────────────────────────────────────────────────────────────────
1700
+ def before_process(self, p, *args):
1701
+ global _fixup_registered
1702
+ _st.enabled = False
1703
+ try:
1704
+ v = self._parse_args(args)
1705
+ except StopIteration:
1706
+ return
1707
+ if not v["enabled"]:
1708
+ return
1709
+
1710
+ # ── Register late fixup once per session ──────────────────────────────
1711
+ # Done here (not at import) so it's appended AFTER all other extensions'
1712
+ # import-time callbacks β†’ guaranteed to fire last every step.
1713
+ if not _fixup_registered:
1714
+ script_callbacks.on_cfg_denoiser(_late_cond_fixup)
1715
+ _fixup_registered = True
1716
+
1717
+ p.override_settings["skip_early_cond"] = float(v["skip_early"])
1718
+ p.override_settings["comma_padding_backtrack"] = int (v["word_wrap"])
1719
+ p.override_settings["s_min_uncond"] = float(v["ngms"])
1720
+ p.override_settings["s_min_uncond_all"] = bool (v["ngms_all"])
1721
+
1722
+ _st.base_cfg = float(p.cfg_scale)
1723
+ _st.ranges_en = bool(v["ranges_en"])
1724
+ _st.range_scales = None
1725
+ if _st.ranges_en:
1726
+ active = [(s, e, val, i) for (s, e, val, i) in v["ranges"] if int(s) > 0 or int(e) > 0]
1727
+ if active:
1728
+ _st.range_scales = _build_ranges(p.steps, _st.base_cfg, active)
1729
+
1730
+ _st.sched_en = bool(v["sched_en"])
1731
+ _st.sched_k = _L2K.get(v["sched_type"], "off")
1732
+ _st.sched_min = float(v["sched_min"])
1733
+ _st.sched_max = float(v["sched_max"])
1734
+
1735
+ _st.boost_en = bool (v["boost_en"])
1736
+ _st.boost_from = float(v["boost_from"])
1737
+ _st.boost_mul = float(v["boost_mul"])
1738
+ _st.jitter_en = bool (v["jitter_en"])
1739
+ _st.jitter_amt = float(v["jitter_amt"])
1740
+
1741
+ _st.autocfg_en = bool (v["autocfg_en"])
1742
+ _st.autocfg_method = v["autocfg_method"]
1743
+ _st.autocfg_ref = float(v["autocfg_ref"])
1744
+ _st.autocfg_topk = float(v["autocfg_topk"])
1745
+
1746
+ _st.lerp_en = bool (v["lerp_en"])
1747
+ _st.lerp_str = float(v["lerp_str"])
1748
+ _st.lerp_sched_curve = v.get("lerp_sched_curve", "Off")
1749
+ _st.lerp_sched_min = float(v.get("lerp_sched_min", 0.5))
1750
+
1751
+ _st.skim_en = bool (v["skim_en"])
1752
+ _st.skim_mode = v["skim_mode"]
1753
+ _st.skim_scale = float(v["skim_scale"])
1754
+ _st.skim_flip = bool (v["skim_flip"])
1755
+ _st.skim_target = v["skim_target"]
1756
+ _st.skim_sched_curve = v.get("skim_sched_curve", "Off")
1757
+ _st.skim_sched_min = float(v.get("skim_sched_min", 1.0))
1758
+
1759
+ _st.nrs_en = bool (v["nrs_en"])
1760
+ _st.nrs_skew = float(v["nrs_skew"])
1761
+ _st.nrs_stretch = float(v["nrs_stretch"])
1762
+ _st.nrs_squash = float(v["nrs_squash"])
1763
+ _st.nrs_sched_curve = v.get("nrs_sched_curve", "Off")
1764
+ _st.nrs_sched_min_skew = float(v.get("nrs_sched_min_skew", 0.0))
1765
+ _st.nrs_sched_min_stretch = float(v.get("nrs_sched_min_stretch", 0.0))
1766
+ _st.nrs_sc_enabled = bool(v.get("nrs_sc_enabled", False))
1767
+ _st.nrs_sc_lock = bool(v.get("nrs_sc_lock", False))
1768
+ _st.nrs_sc_locked = False # always reset at generation start
1769
+ _st.nrs_sc_r1_start = int(v.get("nrs_sc_r1_start", 0))
1770
+ _st.nrs_sc_r1_end = int(v.get("nrs_sc_r1_end", 0))
1771
+ _st.nrs_sc_r1_skew = float(v.get("nrs_sc_r1_skew", 2.0))
1772
+ _st.nrs_sc_r1_stretch= float(v.get("nrs_sc_r1_stretch", 5.0))
1773
+ _st.nrs_sc_r2_start = int(v.get("nrs_sc_r2_start", 0))
1774
+ _st.nrs_sc_r2_end = int(v.get("nrs_sc_r2_end", 0))
1775
+ _st.nrs_sc_r2_skew = float(v.get("nrs_sc_r2_skew", 1.0))
1776
+ _st.nrs_sc_r2_stretch= float(v.get("nrs_sc_r2_stretch", 2.0))
1777
+ # NRS v2
1778
+ _st.nrs_softcap = float(v["nrs_softcap"])
1779
+ _st.nrs_softcap_mode = v["nrs_softcap_mode"]
1780
+ _st.nrs_midpoint = float(v["nrs_midpoint"])
1781
+ _st.nrs_midpoint_mode = v["nrs_midpoint_mode"]
1782
+ _st.nrs_midpoint_fh = bool (v["nrs_midpoint_fh"])
1783
+ # TCFG
1784
+ _st.tcfg_en = bool(v["tcfg_en"])
1785
+ # Disagreement Gate
1786
+ _st.dis_en = bool (v["dis_en"])
1787
+ _st.dis_strength = float(v["dis_strength"])
1788
+ _st.dis_threshold = float(v["dis_threshold"])
1789
+ _st.dis_metric = v["dis_metric"]
1790
+ # Inter-step smoothing
1791
+ _st.interp_mode = v["interp_mode"]
1792
+ _st.interp_mom = float(v["interp_mom"])
1793
+ _st.interp_ema = float(v["interp_ema"])
1794
+ # Reset inter-step runtime state at start of generation
1795
+ _st._prev_nrs = None
1796
+ _st._prev_vel = None
1797
+ _st._prev_delta = None
1798
+ # Reset context (will be populated by on_cfg_denoiser)
1799
+ _st.current_x = None
1800
+ _st.current_sigma = None
1801
+
1802
+ _st.reinhard_en = bool (v["reinhard_en"])
1803
+ _st.reinhard_ref = float(v["reinhard_ref"])
1804
+ _st.rescale_en = bool (v["rescale_en"])
1805
+ _st.rescale_mult = float(v["rescale_mult"])
1806
+ _st.heuristic_en = bool (v["heuristic_en"])
1807
+ _st.heuristic_cfg = float(v["heuristic_cfg"])
1808
+ _st.heuristic_hstart = float(v["heuristic_hstart"])
1809
+ _st.heuristic_sched_curve = v.get("heuristic_sched_curve", "Off")
1810
+ _st.heuristic_sched_min = float(v.get("heuristic_sched_min", 1.0))
1811
+ _st.lir_en = bool (v["lir_en"])
1812
+ _st.lir_cfg = float(v["lir_cfg"])
1813
+ _st.lir_method = v["lir_method"]
1814
+ _st.lir_topk = float(v["lir_topk"])
1815
+ _st.mean_en = bool (v["mean_en"])
1816
+
1817
+ # SEGA-inspired
1818
+ _st.warmup_steps = int (v["warmup_steps"])
1819
+ _st.end_step = int (v["end_step"])
1820
+ _st.lock_after_end = bool (v["lock_after_end"])
1821
+ _st.lock_triggered = False # always reset at generation start
1822
+ _st.apply_to_neg = bool (v["apply_to_neg"])
1823
+ _st.apply_to_pos = bool (v["apply_to_pos"])
1824
+
1825
+ # Pass mode
1826
+ _st.pass_mode = v["pass_mode"]
1827
+ _st.in_hr_pass = False # first pass starts here
1828
+
1829
+ _st.hr_en = bool (v["hr_en"])
1830
+ _st.hr_cfg = float(v["hr_cfg"])
1831
+
1832
+ needs_hook = any([_st.nrs_en, _st.autocfg_en, _st.reinhard_en,
1833
+ _st.rescale_en, _st.heuristic_en, _st.lir_en,
1834
+ _st.mean_en, _st.tcfg_en,
1835
+ _st.interp_mode != "None"])
1836
+ if needs_hook:
1837
+ _hook_combine()
1838
+
1839
+ _st.enabled = True
1840
+
1841
+ gp = p.extra_generation_params
1842
+ gp["CPF SkipEarlyCFG"] = v["skip_early"]
1843
+ gp["CPF NGMS"] = v["ngms"]
1844
+ if _st.ranges_en and _st.range_scales:
1845
+ gp["CPF StepRanges"] = str([(s,e,va,i) for s,e,va,i in v["ranges"] if s or e])
1846
+ if _st.sched_en and _st.sched_k != "off":
1847
+ gp["CPF Schedule"] = f"{_SCHED.get(_st.sched_k)} Γ—{_st.sched_min:.2f}–×{_st.sched_max:.2f}"
1848
+ if _st.boost_en: gp["CPF Boost"] = f"Γ—{_st.boost_mul:.2f} from {_st.boost_from*100:.0f}%"
1849
+ if _st.jitter_en: gp["CPF Jitter"] = f"Β±{_st.jitter_amt*100:.0f}%"
1850
+ if _st.autocfg_en: gp["CPF AutoCFG"] = f"{_st.autocfg_method} ref={_st.autocfg_ref}"
1851
+ if _st.lerp_en:
1852
+ gp["CPF LerpUncond"] = str(_st.lerp_str)
1853
+ if _st.lerp_sched_curve != "Off":
1854
+ gp["CPF LerpSched"] = f"{_st.lerp_sched_curve} min={_st.lerp_sched_min}"
1855
+ if _st.skim_en:
1856
+ gp["CPF Skim"] = f"{_st.skim_mode} s={_st.skim_scale}"
1857
+ if _st.skim_sched_curve != "Off":
1858
+ gp["CPF SkimSched"] = f"{_st.skim_sched_curve} min={_st.skim_sched_min}"
1859
+ if _st.nrs_en:
1860
+ gp["CPF NRS"] = f"skew={_st.nrs_skew} str={_st.nrs_stretch} sq={_st.nrs_squash}"
1861
+ if _st.nrs_sched_curve != "Off":
1862
+ gp["CPF NRSSched"] = f"{_st.nrs_sched_curve} minSkew={_st.nrs_sched_min_skew} minStr={_st.nrs_sched_min_stretch}"
1863
+ if _st.nrs_sc_enabled:
1864
+ gp["CPF NRSSC"] = ("R1:({}-{})sk={}st={} "
1865
+ "R2:({}-{})sk={}st={} lock={}".format(
1866
+ _st.nrs_sc_r1_start, _st.nrs_sc_r1_end,
1867
+ _st.nrs_sc_r1_skew, _st.nrs_sc_r1_stretch,
1868
+ _st.nrs_sc_r2_start, _st.nrs_sc_r2_end,
1869
+ _st.nrs_sc_r2_skew, _st.nrs_sc_r2_stretch,
1870
+ _st.nrs_sc_lock))
1871
+ if _st.reinhard_en: gp["CPF Reinhard"] = str(_st.reinhard_ref)
1872
+ if _st.rescale_en: gp["CPF Rescale"] = str(_st.rescale_mult)
1873
+ if _st.heuristic_en:
1874
+ gp["CPF Heuristic"] = str(_st.heuristic_cfg)
1875
+ if _st.heuristic_hstart > 0: gp["CPF Heuristic Start"] = str(_st.heuristic_hstart)
1876
+ if _st.heuristic_sched_curve != "Off":
1877
+ gp["CPF HeuristicSched"] = f"{_st.heuristic_sched_curve} min={_st.heuristic_sched_min}"
1878
+ if _st.lir_en: gp["CPF LIR"] = f"{_st.lir_method} cfg={_st.lir_cfg}"
1879
+ if _st.mean_en: gp["CPF SubtractMean"] = "1"
1880
+ if _st.warmup_steps or _st.end_step:
1881
+ gp["CPF Window"] = f"warmup={_st.warmup_steps} end={_st.end_step} lock={_st.lock_after_end}"
1882
+ if _st.apply_to_pos or not _st.apply_to_neg:
1883
+ gp["CPF Apply"] = f"pos={_st.apply_to_pos} neg={_st.apply_to_neg}"
1884
+ if _st.pass_mode != "Both passes":
1885
+ gp["CPF PassMode"] = _st.pass_mode
1886
+
1887
+ # ──────────────────────────────────────────────────────────────────────
1888
+ def before_hr(self, p, *args):
1889
+ # Arm HR-pass flag β€” persists lock_triggered from first pass (SEGA lock)
1890
+ _st.in_hr_pass = True
1891
+ # Reset inter-step state: HR pass is a new sequence
1892
+ _st._prev_nrs = None
1893
+ _st._prev_vel = None
1894
+ _st._prev_delta = None
1895
+ # NOTE: nrs_sc_locked is NOT reset here β€” Lock after End persists across HR
1896
+ # pass, matching NRS behaviour. It is reset only in postprocess.
1897
+
1898
+ # HR CFG override (independent of pass_mode)
1899
+ if _st.enabled and _st.hr_en and isinstance(p, StableDiffusionProcessingTxt2Img):
1900
+ _st.hr_cfg_saved = p.cfg_scale
1901
+ p.cfg_scale = _st.hr_cfg
1902
+ p.extra_generation_params["CPF HR_CFG"] = _st.hr_cfg
1903
+
1904
+ def postprocess_image(self, p, pp, *args):
1905
+ # NOTE: intentionally NOT restoring hr_cfg_saved here.
1906
+ # postprocess_image is called for every image in a batch, so restoring
1907
+ # after the first image would leave the rest using base cfg instead of
1908
+ # hr_cfg. The restore is done once in postprocess() after all images.
1909
+ pass
1910
+
1911
+ def postprocess(self, p, processed, *args):
1912
+ _st.enabled = False
1913
+ _st.lock_triggered = False
1914
+ _st.nrs_sc_locked = False
1915
+ _st.in_hr_pass = False
1916
+ # Reset inter-step state
1917
+ _st._prev_nrs = None
1918
+ _st._prev_vel = None
1919
+ _st._prev_delta = None
1920
+ # Reset context
1921
+ _st.current_x = None
1922
+ _st.current_sigma = None
1923
+ _unhook_combine()
1924
+ if _st.hr_cfg_saved is not None:
1925
+ p.cfg_scale = _st.hr_cfg_saved
1926
+ _st.hr_cfg_saved = None