Upload 4 files
Browse files
cfg-prompt-forge/README.md
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# π§ CFG & Prompt Forge v4
|
| 2 |
+
|
| 3 |
+
> Unified CFG control extension for **stock AUTOMATIC1111 Stable Diffusion WebUI**.
|
| 4 |
+
> Combines the best ideas from 8+ CFG-related extensions into one accordion.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Sources & Credits
|
| 9 |
+
|
| 10 |
+
| Feature | Ported from |
|
| 11 |
+
|---|---|
|
| 12 |
+
| Step Ranges (Linear / Fixed / Default) | sd-webui-dycfg (guzuligo) |
|
| 13 |
+
| Skimmed CFG Classic & Smooth | sd-webui-skimmed_cfg (Extraltodeus) |
|
| 14 |
+
| Auto CFG, Lerp Uncond, LIR, Subtract Mean | ComfyUI-AutomaticCFG (Extraltodeus) |
|
| 15 |
+
| Smooth Skimmed (interpolated_scales) | pre_cfg_comfy_nodes (Extraltodeus) |
|
| 16 |
+
| Guidance Geometry Skew/Stretch/Squash | NRS β Negative Rejection Steering (Scrag3H1ll) |
|
| 17 |
+
| Reinhard Tonemap, Rescale CFG, Heuristic CFG | CFgfade (Panchovix) |
|
| 18 |
+
| HiRes-pass CFG override | HiresCFG (Panchovix) |
|
| 19 |
+
| Warmup/EndStep/LockAfterEnd, Apply to Pos/Neg, Pass Mode | SEGA v5.1 (continue-revolution) |
|
| 20 |
+
| Core opts direct override | cfg-prompt-forge v1 (original) |
|
| 21 |
+
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
## Features
|
| 25 |
+
|
| 26 |
+
### π Core Sampling Parameters
|
| 27 |
+
Per-generation `override_settings` override β auto-restored after generation.
|
| 28 |
+
|
| 29 |
+
| Control | `opts` key | Effect |
|
| 30 |
+
|---|---|---|
|
| 31 |
+
| Skip Early CFG | `skip_early_cond` | Fraction of early steps ignoring negative prompt |
|
| 32 |
+
| Prompt Word-Wrap | `comma_padding_backtrack` | CLIP 75-token chunk backtrack |
|
| 33 |
+
| NGMS | `s_min_uncond` | Skip negative when Ο < threshold |
|
| 34 |
+
| NGMS All Steps | `s_min_uncond_all` | Skip every matching step |
|
| 35 |
+
|
| 36 |
+
### π CFG Step Ranges (DyCFG)
|
| 37 |
+
Three step windows with start / end / CFG value / interpolation.
|
| 38 |
+
Interp: **Default** (base), **Linear** (interpolate), **Fixed** (hold last value).
|
| 39 |
+
|
| 40 |
+
### π CFG Schedule
|
| 41 |
+
8 curve shapes multiplied on top of step-range CFG. Live SVG preview.
|
| 42 |
+
Off / Ramp Up / Ramp Down / Cosine In-Out / Bell / Inv. Bell / Step Up / Step Down.
|
| 43 |
+
|
| 44 |
+
### π End-step Boost
|
| 45 |
+
Ramp CFG multiplier from Γ1.0 up to BoostΓ during last N% of steps. Stacks with schedule.
|
| 46 |
+
|
| 47 |
+
### π² Per-step Jitter
|
| 48 |
+
Random Β±% perturbation each step for organic textures.
|
| 49 |
+
|
| 50 |
+
### π€ Auto CFG (ComfyUI-AutomaticCFG)
|
| 51 |
+
Per-latent-channel adaptive CFG. Methods: hard / soft / hard_squared / range.
|
| 52 |
+
When enabled, replaces the combine loop β NRS is bypassed.
|
| 53 |
+
|
| 54 |
+
### π Uncond Lerp
|
| 55 |
+
Blend `text_uncond` toward `text_cond` preserving norm.
|
| 56 |
+
Strength 1.0 = no change. <1 weakens negative. >1 exaggerates it.
|
| 57 |
+
|
| 58 |
+
### π¬ Skimmed CFG
|
| 59 |
+
Text-embedding-level uncond modification:
|
| 60 |
+
- **Classic** β binary mask on over-aligned dimensions
|
| 61 |
+
- **Smooth** β importance-weighted blend (interpolated_scales)
|
| 62 |
+
|
| 63 |
+
### β‘ NRS Guidance Geometry
|
| 64 |
+
Skew / Stretch / Squash on cond noise predictions relative to uncond:
|
| 65 |
+
- **Stretch** β amplify perpendicular component (sharper positive)
|
| 66 |
+
- **Skew** β steer cond away from uncond rejection
|
| 67 |
+
- **Squash** β renormalise back to original cond magnitude
|
| 68 |
+
|
| 69 |
+
### π Post-CFG Shaping (inside combine_denoised hook)
|
| 70 |
+
|
| 71 |
+
| Feature | Effect |
|
| 72 |
+
|---|---|
|
| 73 |
+
| Reinhard Tonemap | Smooth magnitude clamp on noise prediction β prevents "blowout" |
|
| 74 |
+
| Rescale CFG | Match result std to raw cond std β prevents latent inflation |
|
| 75 |
+
| Heuristic CFG | Quantile contrast normalisation β fights cartoon-ification at high CFG |
|
| 76 |
+
| Latent Intensity Rescale | Per-channel topk intensity target |
|
| 77 |
+
| Subtract Latent Mean | Removes spatial mean per batch item β reduces colour drift |
|
| 78 |
+
|
| 79 |
+
### π§ Feature Activation Window (SEGA v5.1)
|
| 80 |
+
|
| 81 |
+
Controls **when** embedding features (Lerp, Skim) are active and **which branch** they modify.
|
| 82 |
+
|
| 83 |
+
| Control | Effect |
|
| 84 |
+
|---|---|
|
| 85 |
+
| **Warmup Steps** | Skip embedding features for first N steps (too noisy) |
|
| 86 |
+
| **End Step** | Deactivate after this step (0 = no limit) |
|
| 87 |
+
| **Lock after End** | Once End Step is crossed, stays off permanently β including HR pass |
|
| 88 |
+
| **Apply to Negative** | Lerp/Skim modifies `text_uncond` |
|
| 89 |
+
| **Apply to Positive** | Lerp/Skim also modifies `text_cond` |
|
| 90 |
+
|
| 91 |
+
**Lock after End** is the key SEGA idea: once the guidance window closes at End Step, the lock is latched (`lock_triggered = True`) and persists through the HiRes pass β so effects stop cleanly even mid-generation.
|
| 92 |
+
|
| 93 |
+
### πΌ Pass Mode & HiRes Fix CFG
|
| 94 |
+
|
| 95 |
+
**Pass Mode** β which denoising pass activates CFG features:
|
| 96 |
+
- `Both passes` β normal (default)
|
| 97 |
+
- `First pass only` β CFG shaping only during base pass; HR pass uses plain defaults
|
| 98 |
+
- `HR pass only` β CFG shaping activates only for HiRes fix
|
| 99 |
+
|
| 100 |
+
**HR CFG Override** β independently swaps `p.cfg_scale` for the HR pass only.
|
| 101 |
+
Works regardless of Pass Mode. Lower values (1β4) often clean up HiRes upscaling.
|
| 102 |
+
|
| 103 |
+
---
|
| 104 |
+
|
| 105 |
+
## Installation
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
cd stable-diffusion-webui/extensions
|
| 109 |
+
git clone <repo-url> cfg-prompt-forge
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
Or **Extensions β Install from URL** in WebUI.
|
| 113 |
+
|
| 114 |
+
---
|
| 115 |
+
|
| 116 |
+
## Recipes
|
| 117 |
+
|
| 118 |
+
### High-quality portrait
|
| 119 |
+
```
|
| 120 |
+
Skip Early CFG: 0.25 | NGMS: 0.50
|
| 121 |
+
Schedule: Ramp Down Γ1.6 β Γ0.9
|
| 122 |
+
End Boost from 80% Γ1.4
|
| 123 |
+
Reinhard: 4.0
|
| 124 |
+
Warmup: 5 steps | End Step: 0 (full)
|
| 125 |
+
Apply to Negative: β
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Organic / painterly
|
| 129 |
+
```
|
| 130 |
+
Jitter: Β±15% | Schedule: Bell Γ0.6 β Γ1.5
|
| 131 |
+
Skim: Smooth scale=6.0
|
| 132 |
+
Subtract Mean: β | Apply to Negative: β | Apply to Positive: β
|
| 133 |
+
Warmup: 8 | End Step: 25
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
### Speed run (30 steps)
|
| 137 |
+
```
|
| 138 |
+
Skip Early CFG: 0.40 | NGMS: 1.20 | NGMS All Steps: β
|
| 139 |
+
Auto CFG: hard ref=8 (NRS auto-disabled)
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### Sharp upscale only during HR
|
| 143 |
+
```
|
| 144 |
+
Pass Mode: HR pass only
|
| 145 |
+
End Boost from 85% Γ1.3
|
| 146 |
+
Rescale CFG: 0.7
|
| 147 |
+
HR CFG Override: 2.0 (also active)
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Lock effects to base pass only, then clean HR
|
| 151 |
+
```
|
| 152 |
+
Pass Mode: First pass only (schedule/boost off during HR)
|
| 153 |
+
Lock after End: β | End Step: [last step of base pass]
|
| 154 |
+
HR CFG Override: 3.0
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## Technical Notes
|
| 160 |
+
|
| 161 |
+
- **No model patching** β only standard A1111 hooks
|
| 162 |
+
- **`override_settings`** scoped to current generation, auto-restored
|
| 163 |
+
- **`combine_denoised` hook** installed/removed per generation; safe to stack
|
| 164 |
+
- **`lock_triggered`** state survives into HR pass (SEGA v5.1 pattern) β set in `_emb_features_active()`, cleared in `postprocess()`
|
| 165 |
+
- **`in_hr_pass`** flag set in `before_hr()` β allows Pass Mode to discriminate
|
| 166 |
+
- **SDXL compatible** β handles both plain tensor and `{"crossattn": ...}` dict conditioning
|
cfg-prompt-forge/javascript/cfg_prompt_forge.js
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* CFG & Prompt Forge v4 β UI enhancements
|
| 3 |
+
* β’ Blue left-border on accordion when enabled
|
| 4 |
+
* β’ Dims sub-sections when their parent Enable checkbox is off
|
| 5 |
+
* β’ Status badge in accordion title
|
| 6 |
+
*/
|
| 7 |
+
(function () {
|
| 8 |
+
"use strict";
|
| 9 |
+
|
| 10 |
+
function wireAccordion(accId) {
|
| 11 |
+
const acc = document.getElementById(accId);
|
| 12 |
+
if (!acc) return;
|
| 13 |
+
const mainCb = acc.querySelector('input[type="checkbox"]');
|
| 14 |
+
const btn = acc.querySelector("button[aria-expanded], button.label-wrap");
|
| 15 |
+
if (!mainCb || !btn) return;
|
| 16 |
+
|
| 17 |
+
function update() {
|
| 18 |
+
btn.style.borderLeft = mainCb.checked ? "3px solid #5bc8f5" : "";
|
| 19 |
+
btn.style.paddingLeft = mainCb.checked ? "9px" : "";
|
| 20 |
+
btn.style.background = mainCb.checked ? "rgba(91,200,245,.05)" : "";
|
| 21 |
+
}
|
| 22 |
+
mainCb.addEventListener("change", update);
|
| 23 |
+
update();
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
function wireDimming(accId) {
|
| 27 |
+
const acc = document.getElementById(accId);
|
| 28 |
+
if (!acc) return;
|
| 29 |
+
|
| 30 |
+
// For every checkbox whose label contains "Enable" (but not the master β¦)
|
| 31 |
+
acc.querySelectorAll('input[type="checkbox"]').forEach(cb => {
|
| 32 |
+
const lbl = (cb.closest("label") || cb.parentElement || {}).textContent || "";
|
| 33 |
+
if (!lbl.includes("Enable") || lbl.includes("β¦")) return;
|
| 34 |
+
|
| 35 |
+
function apply() {
|
| 36 |
+
// Walk forward siblings until the next Enable checkbox group
|
| 37 |
+
let el = cb.closest('[class*="row"],[class*="form"],[class*="block"]')
|
| 38 |
+
|| cb.parentElement;
|
| 39 |
+
if (!el) return;
|
| 40 |
+
let sib = el.nextElementSibling;
|
| 41 |
+
while (sib) {
|
| 42 |
+
const nextEn = sib.querySelector('input[type="checkbox"]');
|
| 43 |
+
if (nextEn) {
|
| 44 |
+
const nt = (nextEn.closest("label") || nextEn.parentElement || {}).textContent || "";
|
| 45 |
+
if (nt.includes("Enable")) break;
|
| 46 |
+
}
|
| 47 |
+
sib.style.opacity = cb.checked ? "1" : "0.4";
|
| 48 |
+
sib.style.pointerEvents = cb.checked ? "" : "none";
|
| 49 |
+
sib.style.transition = "opacity .15s";
|
| 50 |
+
sib = sib.nextElementSibling;
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
cb.addEventListener("change", apply);
|
| 54 |
+
apply();
|
| 55 |
+
});
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
function init() {
|
| 59 |
+
["cpf_acc_t2i", "cpf_acc_i2i"].forEach(id => {
|
| 60 |
+
wireAccordion(id);
|
| 61 |
+
wireDimming(id);
|
| 62 |
+
});
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
let tries = 0;
|
| 66 |
+
const t = setInterval(() => {
|
| 67 |
+
if (document.getElementById("cpf_acc_t2i") ||
|
| 68 |
+
document.getElementById("cpf_acc_i2i") ||
|
| 69 |
+
++tries > 100) {
|
| 70 |
+
clearInterval(t);
|
| 71 |
+
init();
|
| 72 |
+
}
|
| 73 |
+
}, 350);
|
| 74 |
+
})();
|
cfg-prompt-forge/scripts/__pycache__/cfg_prompt_forge.cpython-312.pyc
ADDED
|
Binary file (90.9 kB). View file
|
|
|
cfg-prompt-forge/scripts/cfg_prompt_forge.py
ADDED
|
@@ -0,0 +1,1926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 3 |
+
β CFG & Prompt Forge v4 β Stable Diffusion WebUI Extension β
|
| 4 |
+
β Stock A1111 only (NOT Forge-only APIs) β
|
| 5 |
+
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ£
|
| 6 |
+
β Sources: β
|
| 7 |
+
β β’ sd-webui-dycfg β step-range combine hook β
|
| 8 |
+
β β’ sd-webui-skimmed_cfg β embedding-level uncond mask β
|
| 9 |
+
β β’ ComfyUI-AutomaticCFG β adaptive per-channel CFG, LIR, mean β
|
| 10 |
+
β β’ pre_cfg_comfy_nodes β lerp-uncond, smooth-skim math β
|
| 11 |
+
β β’ NRS β Skew/Stretch/Squash guidance geometry β
|
| 12 |
+
β β’ CFgfade β Reinhard, Rescale, Heuristic CFG β
|
| 13 |
+
β β’ HiresCFG β before_hr CFG override β
|
| 14 |
+
β β’ SEGA v5.1 β Warmup/EndStep/LockAfterEnd, Apply to β
|
| 15 |
+
β Positive/Negative, Pass Mode β
|
| 16 |
+
β βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ£
|
| 17 |
+
β HOOKS β
|
| 18 |
+
β on_cfg_denoiser β schedule / boost / jitter / step-range / β
|
| 19 |
+
β lerp-uncond / skimmed CFG (with SEGA window+lock) β
|
| 20 |
+
β combine_denoised β Auto CFG / NRS / heuristic / Reinhard / Rescale / β
|
| 21 |
+
β LIR / subtract-mean β
|
| 22 |
+
β before_hr β HR-pass CFG swap + pass-mode arming β
|
| 23 |
+
β postprocess_imageβ HR-pass CFG restore β
|
| 24 |
+
β before_process β write override_settings, arm all state β
|
| 25 |
+
β postprocess β disarm + unhook β
|
| 26 |
+
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import math
|
| 30 |
+
import random
|
| 31 |
+
from typing import Optional
|
| 32 |
+
|
| 33 |
+
import gradio as gr
|
| 34 |
+
import torch
|
| 35 |
+
|
| 36 |
+
from modules import scripts, script_callbacks
|
| 37 |
+
from modules.processing import StableDiffusionProcessingTxt2Img
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
from modules.sd_samplers_cfg_denoiser import CFGDenoiser as _DCLS
|
| 41 |
+
except ImportError:
|
| 42 |
+
_DCLS = None
|
| 43 |
+
|
| 44 |
+
_EXT = "CPF"
|
| 45 |
+
_ORIG = "combine_denoised"
|
| 46 |
+
_SAVED = f"_{_EXT}_saved_{_ORIG}"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 50 |
+
# Schedule math
|
| 51 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
|
| 53 |
+
_SCHED: dict = {
|
| 54 |
+
"off" : "Off (flat Γ 1.0)",
|
| 55 |
+
"ramp_up" : "Ramp Up",
|
| 56 |
+
"ramp_down": "Ramp Down",
|
| 57 |
+
"cos_up" : "Cosine Ease-In",
|
| 58 |
+
"cos_down" : "Cosine Ease-Out",
|
| 59 |
+
"bell" : "Bell (low β high β low)",
|
| 60 |
+
"inv_bell" : "Inv. Bell (high β low β high)",
|
| 61 |
+
"step_up" : "Step Up (at midpoint)",
|
| 62 |
+
"step_down": "Step Down (at midpoint)",
|
| 63 |
+
}
|
| 64 |
+
_L2K = {v: k for k, v in _SCHED.items()}
|
| 65 |
+
|
| 66 |
+
# ββ Step-schedule curves (used by multiple parameters) ββββββββββββ
|
| 67 |
+
_CURVES = [
|
| 68 |
+
"Off", "Linear Down", "Linear Up", "Cosine Down", "Cosine Up",
|
| 69 |
+
"Bell", "Inv. Bell", "Step Up", "Step Down",
|
| 70 |
+
"Power Down", "Power Up", "Repeating", "Sawtooth",
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _sm(t, k, mn, mx):
|
| 75 |
+
t = max(0.0, min(1.0, t))
|
| 76 |
+
if k == "ramp_up" : return mn + (mx - mn) * t
|
| 77 |
+
if k == "ramp_down": return mx - (mx - mn) * t
|
| 78 |
+
if k == "cos_up" : return mn + (mx - mn) * (1 - math.cos(math.pi * t)) / 2
|
| 79 |
+
if k == "cos_down" : return mx - (mx - mn) * (1 - math.cos(math.pi * t)) / 2
|
| 80 |
+
if k == "bell" : return mn + (mx - mn) * math.sin(math.pi * t)
|
| 81 |
+
if k == "inv_bell" : return mx - (mx - mn) * math.sin(math.pi * t)
|
| 82 |
+
if k == "step_up" : return mn if t < 0.5 else mx
|
| 83 |
+
if k == "step_down": return mx if t < 0.5 else mn
|
| 84 |
+
return 1.0
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _schedule_value(base, step, total, curve, min_val, sched_val=2.0):
|
| 88 |
+
"""Apply curve schedule to a float parameter over denoising steps.
|
| 89 |
+
|
| 90 |
+
Returns the parameter value at the given step.
|
| 91 |
+
curve='Off' β returns base unchanged.
|
| 92 |
+
"""
|
| 93 |
+
if curve == "Off":
|
| 94 |
+
return base
|
| 95 |
+
frac = step / max(total - 1, 1)
|
| 96 |
+
frac = max(0.0, min(1.0, frac))
|
| 97 |
+
|
| 98 |
+
if curve == "Linear Down": mult = 1.0 - frac
|
| 99 |
+
elif curve == "Linear Up": mult = frac
|
| 100 |
+
elif curve == "Cosine Down": mult = math.cos(frac * math.pi / 2)
|
| 101 |
+
elif curve == "Cosine Up": mult = 1.0 - math.cos(frac * math.pi / 2)
|
| 102 |
+
elif curve == "Bell": mult = math.sin(math.pi * frac)
|
| 103 |
+
elif curve == "Inv. Bell": mult = 1.0 - math.sin(math.pi * frac)
|
| 104 |
+
elif curve == "Step Up": mult = 0.0 if frac < 0.5 else 1.0
|
| 105 |
+
elif curve == "Step Down": mult = 1.0 if frac < 0.5 else 0.0
|
| 106 |
+
elif curve == "Power Down": mult = 1.0 - math.pow(frac, max(sched_val, 0.1))
|
| 107 |
+
elif curve == "Power Up": mult = math.pow(frac, max(sched_val, 0.1))
|
| 108 |
+
elif curve == "Repeating": mult = 1.0 - abs(2.0 * ((frac * sched_val) % 1.0) - 1.0)
|
| 109 |
+
elif curve == "Sawtooth": mult = (frac * sched_val) % 1.0
|
| 110 |
+
else:
|
| 111 |
+
return base
|
| 112 |
+
return min_val + (base - min_val) * mult
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _step_control_val(base_val, step, sc_enabled, r1_start, r1_end, r1_val,
|
| 116 |
+
r2_start, r2_end, r2_val, lock, locked, off_val):
|
| 117 |
+
"""NRS-style Individual Step Control with Range1/Range2 + Lock after End.
|
| 118 |
+
|
| 119 |
+
Returns (value, new_locked_flag).
|
| 120 |
+
When locked β (off_val, True).
|
| 121 |
+
When in range β (range_val, locked unchanged).
|
| 122 |
+
Otherwise β (base_val, locked unchanged).
|
| 123 |
+
"""
|
| 124 |
+
if locked:
|
| 125 |
+
return off_val, True
|
| 126 |
+
if not sc_enabled:
|
| 127 |
+
return base_val, False
|
| 128 |
+
|
| 129 |
+
def _in(s, e):
|
| 130 |
+
return s > 0 and step >= s and (e == 0 or step <= e)
|
| 131 |
+
|
| 132 |
+
if _in(r1_start, r1_end):
|
| 133 |
+
return r1_val, False
|
| 134 |
+
if _in(r2_start, r2_end):
|
| 135 |
+
return r2_val, False
|
| 136 |
+
|
| 137 |
+
if lock:
|
| 138 |
+
# Only consider finite range ends (>0) for lock triggering.
|
| 139 |
+
ends = [e for e in (r1_end, r2_end) if e > 0]
|
| 140 |
+
if ends and step > max(ends):
|
| 141 |
+
return off_val, True
|
| 142 |
+
return base_val, False
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 146 |
+
# Global state
|
| 147 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
|
| 149 |
+
class _S:
|
| 150 |
+
enabled : bool = False
|
| 151 |
+
# core opts
|
| 152 |
+
skip_early : float = 0.0
|
| 153 |
+
word_wrap : int = 20
|
| 154 |
+
ngms : float = 0.0
|
| 155 |
+
ngms_all : bool = False
|
| 156 |
+
# step ranges
|
| 157 |
+
ranges_en : bool = False
|
| 158 |
+
range_scales : list = None
|
| 159 |
+
base_cfg : float = 7.0
|
| 160 |
+
# schedule
|
| 161 |
+
sched_en : bool = False
|
| 162 |
+
sched_k : str = "off"
|
| 163 |
+
sched_min : float = 0.5
|
| 164 |
+
sched_max : float = 1.5
|
| 165 |
+
# boost
|
| 166 |
+
boost_en : bool = False
|
| 167 |
+
boost_from : float = 0.80
|
| 168 |
+
boost_mul : float = 1.5
|
| 169 |
+
# jitter
|
| 170 |
+
jitter_en : bool = False
|
| 171 |
+
jitter_amt : float = 0.10
|
| 172 |
+
# auto cfg
|
| 173 |
+
autocfg_en : bool = False
|
| 174 |
+
autocfg_method : str = "hard"
|
| 175 |
+
autocfg_ref : float = 8.0
|
| 176 |
+
autocfg_topk : float = 0.25
|
| 177 |
+
# lerp uncond
|
| 178 |
+
lerp_en : bool = False
|
| 179 |
+
lerp_str : float = 1.0
|
| 180 |
+
# skimmed cfg
|
| 181 |
+
skim_en : bool = False
|
| 182 |
+
skim_mode : str = "Classic"
|
| 183 |
+
skim_scale : float = 7.0
|
| 184 |
+
skim_flip : bool = False
|
| 185 |
+
skim_target : str = "Uncond"
|
| 186 |
+
# NRS
|
| 187 |
+
nrs_en : bool = False
|
| 188 |
+
nrs_skew : float = 2.0
|
| 189 |
+
nrs_stretch : float = 5.0
|
| 190 |
+
nrs_squash : float = 0.75
|
| 191 |
+
# NRS v2 β Vector Softcap (feature 2)
|
| 192 |
+
nrs_softcap : float = 0.0
|
| 193 |
+
nrs_softcap_mode : str = "Per Sample"
|
| 194 |
+
# NRS v2 β Midpoint Refinement / Kohaku RK2 (feature 4)
|
| 195 |
+
nrs_midpoint : float = 0.0
|
| 196 |
+
nrs_midpoint_mode : str = "Classic"
|
| 197 |
+
nrs_midpoint_fh : bool = True
|
| 198 |
+
# TCFG β Tangential Damping via SVD (feature 3)
|
| 199 |
+
tcfg_en : bool = False
|
| 200 |
+
# Disagreement Gate β adaptive skew/stretch (feature 6)
|
| 201 |
+
dis_en : bool = False
|
| 202 |
+
dis_strength : float = 0.5
|
| 203 |
+
dis_threshold : float = 0.3
|
| 204 |
+
dis_metric : str = "Cosine"
|
| 205 |
+
# Inter-step smoothing β Momentum + EMA Delta (feature 7)
|
| 206 |
+
interp_mode : str = "None"
|
| 207 |
+
interp_mom : float = 0.0
|
| 208 |
+
interp_ema : float = 0.5
|
| 209 |
+
# Runtime context (NOT UI β populated in _on_cfg_denoiser, read in _custom_combine)
|
| 210 |
+
current_x : object = None # noisy latent params.x
|
| 211 |
+
current_sigma : object = None # params.sigma
|
| 212 |
+
current_step : int = 0
|
| 213 |
+
total_steps : int = 20
|
| 214 |
+
# Inter-step state (reset each generation and each HR pass)
|
| 215 |
+
_prev_nrs : object = None
|
| 216 |
+
_prev_vel : object = None
|
| 217 |
+
_prev_delta : object = None
|
| 218 |
+
# post-cfg
|
| 219 |
+
reinhard_en : bool = False
|
| 220 |
+
reinhard_ref : float = 4.0
|
| 221 |
+
rescale_en : bool = False
|
| 222 |
+
rescale_mult : float = 0.7
|
| 223 |
+
heuristic_en : bool = False
|
| 224 |
+
heuristic_cfg : float = 5.0
|
| 225 |
+
heuristic_hstart : float = 0.0
|
| 226 |
+
lir_en : bool = False
|
| 227 |
+
lir_cfg : float = 8.0
|
| 228 |
+
lir_method : str = "hard"
|
| 229 |
+
lir_topk : float = 0.25
|
| 230 |
+
mean_en : bool = False
|
| 231 |
+
# ββ Step schedules (per-parameter curve over steps) βββββββββββββββββββββ
|
| 232 |
+
nrs_sched_curve : str = "Off"
|
| 233 |
+
nrs_sched_min_skew : float = 0.0
|
| 234 |
+
nrs_sched_min_stretch : float = 0.0
|
| 235 |
+
skim_sched_curve : str = "Off"
|
| 236 |
+
skim_sched_min : float = 1.0
|
| 237 |
+
lerp_sched_curve : str = "Off"
|
| 238 |
+
lerp_sched_min : float = 0.5
|
| 239 |
+
heuristic_sched_curve : str = "Off"
|
| 240 |
+
heuristic_sched_min : float = 1.0
|
| 241 |
+
# Computed at each step by _on_cfg_denoiser
|
| 242 |
+
_sched_skim_scale : float = 0.0
|
| 243 |
+
_sched_lerp_str : float = 0.0
|
| 244 |
+
_sched_nrs_skew : float = 0.0
|
| 245 |
+
_sched_nrs_stretch : float = 0.0
|
| 246 |
+
_sched_heuristic_cfg : float = 0.0
|
| 247 |
+
# ββ NRS Individual Step Control (Range1/Range2 + Lock after End) ββββββββ
|
| 248 |
+
nrs_sc_enabled : bool = False
|
| 249 |
+
nrs_sc_r1_start : int = 0
|
| 250 |
+
nrs_sc_r1_end : int = 0
|
| 251 |
+
nrs_sc_r1_skew : float = 0.0
|
| 252 |
+
nrs_sc_r1_stretch: float = 0.0
|
| 253 |
+
nrs_sc_r2_start : int = 0
|
| 254 |
+
nrs_sc_r2_end : int = 0
|
| 255 |
+
nrs_sc_r2_skew : float = 0.0
|
| 256 |
+
nrs_sc_r2_stretch: float = 0.0
|
| 257 |
+
nrs_sc_lock : bool = False
|
| 258 |
+
nrs_sc_locked : bool = False # runtime flag
|
| 259 |
+
# ββ SEGA-inspired activation window βββββββββββββββββββββββββββββββββββββ
|
| 260 |
+
warmup_steps : int = 0 # first N steps: embedding features inactive
|
| 261 |
+
end_step : int = 0 # 0 = no limit; stop features after this step
|
| 262 |
+
lock_after_end : bool = False # once end_step crossed, stay off for HR pass too
|
| 263 |
+
lock_triggered : bool = False # runtime flag, NOT a UI control
|
| 264 |
+
apply_to_pos : bool = False # apply Lerp/Skim to text_cond as well
|
| 265 |
+
apply_to_neg : bool = True # apply Lerp/Skim to text_uncond
|
| 266 |
+
# ββ pass mode ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 267 |
+
pass_mode : str = "Both passes" # "Both passes"/"First pass only"/"HR pass only"
|
| 268 |
+
in_hr_pass : bool = False # runtime flag
|
| 269 |
+
# HR CFG
|
| 270 |
+
hr_en : bool = False
|
| 271 |
+
hr_cfg : float = 1.0
|
| 272 |
+
hr_cfg_saved : Optional[float] = None
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
_st = _S()
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 279 |
+
# SEGA-inspired activation-window helpers
|
| 280 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 281 |
+
|
| 282 |
+
def _emb_features_active(step: int) -> bool:
|
| 283 |
+
"""
|
| 284 |
+
Returns True if embedding-level features (Lerp, Skim) should run this step.
|
| 285 |
+
|
| 286 |
+
Rules (ported from SEGA v5.1 is_guidance_active + is_locked_after_end):
|
| 287 |
+
1. Lock triggered (end_step crossed with lock_after_end) β False permanently
|
| 288 |
+
2. step < warmup_steps β still warming up β False
|
| 289 |
+
3. end_step > 0 and step > end_step β past window β possibly lock, then False
|
| 290 |
+
"""
|
| 291 |
+
if _st.lock_triggered:
|
| 292 |
+
return False
|
| 293 |
+
if step < _st.warmup_steps:
|
| 294 |
+
return False
|
| 295 |
+
if _st.end_step > 0 and step > _st.end_step:
|
| 296 |
+
if _st.lock_after_end:
|
| 297 |
+
_st.lock_triggered = True # permanent for HR pass too
|
| 298 |
+
return False
|
| 299 |
+
return True
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def _pass_features_active() -> bool:
|
| 303 |
+
"""
|
| 304 |
+
Returns True if CFG-multiplier features (schedule, boost, jitter, ranges)
|
| 305 |
+
should run given the current pass_mode.
|
| 306 |
+
"""
|
| 307 |
+
if _st.pass_mode == "First pass only" and _st.in_hr_pass:
|
| 308 |
+
return False
|
| 309 |
+
if _st.pass_mode == "HR pass only" and not _st.in_hr_pass:
|
| 310 |
+
return False
|
| 311 |
+
return True
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 315 |
+
# Step-range builder (DyCFG)
|
| 316 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
+
|
| 318 |
+
def _build_ranges(total, base, rows):
|
| 319 |
+
scales = ["Default"] * total
|
| 320 |
+
for (s1, e1, v, intp) in rows:
|
| 321 |
+
s = max(1, int(s1)); e = int(e1) if int(e1) > 0 else total
|
| 322 |
+
if e < s: continue
|
| 323 |
+
v = max(1.0, float(v))
|
| 324 |
+
for i in range(s - 1, min(e, total)): scales[i] = v
|
| 325 |
+
for i in range(s - 2, -1, -1):
|
| 326 |
+
if isinstance(scales[i], str): scales[i] = intp
|
| 327 |
+
else: break
|
| 328 |
+
scales = ["Default"] + scales + ["Default"]
|
| 329 |
+
prev = base
|
| 330 |
+
for i in range(len(scales)):
|
| 331 |
+
if scales[i] == "Default": scales[i] = base
|
| 332 |
+
elif scales[i] == "Fixed": scales[i] = prev if isinstance(prev, float) else base
|
| 333 |
+
if isinstance(scales[i], float): prev = scales[i]
|
| 334 |
+
i = 0
|
| 335 |
+
while i < len(scales):
|
| 336 |
+
if scales[i] == "Linear":
|
| 337 |
+
pv = scales[i-1] if i > 0 and isinstance(scales[i-1], float) else base
|
| 338 |
+
j = i + 1
|
| 339 |
+
while j < len(scales) and scales[j] == "Linear": j += 1
|
| 340 |
+
nv = scales[j] if j < len(scales) and isinstance(scales[j], float) else base
|
| 341 |
+
d = (nv - pv) / (j - i + 1)
|
| 342 |
+
for k in range(i, j): scales[k] = pv + (k - i + 1) * d
|
| 343 |
+
i = j
|
| 344 |
+
else: i += 1
|
| 345 |
+
scales = scales[1:-1]
|
| 346 |
+
return [float(v) if isinstance(v, float) else base for v in scales]
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 350 |
+
# AutoCFG helpers (AutomaticCFG)
|
| 351 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 352 |
+
|
| 353 |
+
@torch.no_grad()
|
| 354 |
+
def _topk_range(flat_ch, method, topk):
|
| 355 |
+
k = max(1, int(flat_ch.numel() * topk))
|
| 356 |
+
|
| 357 |
+
data = flat_ch.float()
|
| 358 |
+
if method == "range":
|
| 359 |
+
data = data - data.mean()
|
| 360 |
+
mx = torch.topk(data, k, largest=True ).values.mean().item()
|
| 361 |
+
mna = torch.topk(data, k, largest=False).values.abs().mean().item()
|
| 362 |
+
else:
|
| 363 |
+
mx = torch.topk(data, k, largest=True ).values.mean().item()
|
| 364 |
+
mn_v = torch.topk(data, k, largest=False).values
|
| 365 |
+
mna = abs(mn_v.mean().item()) if method == "soft" else mn_v.abs().mean().item()
|
| 366 |
+
|
| 367 |
+
r = (mx + mna) / 2.0
|
| 368 |
+
return r * r if method == "hard_squared" else r
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
@torch.no_grad()
|
| 372 |
+
def _auto_cfg_combine(unc, conds_list, x_out, eff_scale, ref_cfg, method, topk):
|
| 373 |
+
target = eff_scale / 10.0
|
| 374 |
+
out = torch.zeros_like(unc)
|
| 375 |
+
for i, conds in enumerate(conds_list):
|
| 376 |
+
u = unc[i]
|
| 377 |
+
ref = u.clone()
|
| 378 |
+
for ci, w in conds: ref = ref + (x_out[ci] - u) * (w * ref_cfg)
|
| 379 |
+
C = ref.shape[0]
|
| 380 |
+
flat = ref.reshape(C, -1)
|
| 381 |
+
ranges = [_topk_range(flat[c], method, topk) for c in range(C)]
|
| 382 |
+
out[i] = u.clone()
|
| 383 |
+
for ci, w in conds:
|
| 384 |
+
delta = x_out[ci] - u
|
| 385 |
+
for c in range(C):
|
| 386 |
+
r = max(ranges[c], 1e-8)
|
| 387 |
+
out[i][c] = out[i][c] + delta[c] * (w * ref_cfg * target / r)
|
| 388 |
+
return out
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 392 |
+
# NRS v2 (NRS Kohaku Enhanced v3.5 β sigma-aware latent-space math)
|
| 393 |
+
# Features: proper sigma conversion, Vector Softcap, Midpoint Refinement,
|
| 394 |
+
# TCFG (SVD tangential damping), Disagreement Gate, Momentum + EMA Delta
|
| 395 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 396 |
+
|
| 397 |
+
def _nrs_log_skip(name, exc=None):
|
| 398 |
+
"""Silent skip β no crash on NRS sub-feature error."""
|
| 399 |
+
pass
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _validate_4d(t, tag=""):
|
| 403 |
+
return (torch.is_tensor(t) and t.dim() == 4
|
| 404 |
+
and t.numel() > 0 and torch.isfinite(t).all())
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
@torch.no_grad()
|
| 408 |
+
def _nrs_softcap_vector(v: torch.Tensor, strength: float,
|
| 409 |
+
mode: str = "Per Sample") -> torch.Tensor:
|
| 410 |
+
"""
|
| 411 |
+
Norm-based tanh softcap β preserves direction, compresses magnitude.
|
| 412 |
+
Source: NRS Kohaku Enhanced v3.5 _nrs_softcap_vector()
|
| 413 |
+
Near-zero: linear. Large values: saturate smoothly.
|
| 414 |
+
"""
|
| 415 |
+
if strength <= 0.0:
|
| 416 |
+
return v
|
| 417 |
+
try:
|
| 418 |
+
eps = 1e-6
|
| 419 |
+
s = max(strength, eps)
|
| 420 |
+
if mode == "Per Channel":
|
| 421 |
+
v_abs = v.abs()
|
| 422 |
+
cap = (v_abs.mean(dim=(2, 3), keepdim=True) * s).clamp(min=eps)
|
| 423 |
+
ratio = v_abs / cap
|
| 424 |
+
scale = torch.tanh(ratio) / ratio.clamp(min=eps)
|
| 425 |
+
result = v * scale
|
| 426 |
+
elif mode == "Global Batch":
|
| 427 |
+
v_norm = v.norm(p=2, dim=(1, 2, 3), keepdim=True).clamp(min=eps)
|
| 428 |
+
cap = (v_norm.mean() * s).clamp(min=eps)
|
| 429 |
+
ratio = v_norm / cap
|
| 430 |
+
scale = torch.tanh(ratio) / ratio.clamp(min=eps)
|
| 431 |
+
result = v * scale
|
| 432 |
+
else: # "Per Sample" (default)
|
| 433 |
+
v_norm = v.norm(p=2, dim=1, keepdim=True).clamp(min=eps)
|
| 434 |
+
cap = (v_norm.mean(dim=(2, 3), keepdim=True) * s).clamp(min=eps)
|
| 435 |
+
ratio = v_norm / cap
|
| 436 |
+
scale = torch.tanh(ratio) / ratio.clamp(min=eps)
|
| 437 |
+
result = v * scale
|
| 438 |
+
return result if torch.isfinite(result).all() else v
|
| 439 |
+
except Exception:
|
| 440 |
+
return v
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
@torch.no_grad()
|
| 444 |
+
def _nrs_core(x_orig, cond, uncond, sigma,
|
| 445 |
+
skew: float, stretch: float, squash: float,
|
| 446 |
+
vector_softcap: float = 0.0,
|
| 447 |
+
vector_softcap_mode: str = "Per Sample") -> torch.Tensor:
|
| 448 |
+
"""
|
| 449 |
+
Proper sigma-aware NRS kernel.
|
| 450 |
+
Source: NRS Kohaku Enhanced v3.5 _nrs_core()
|
| 451 |
+
|
| 452 |
+
Converts denoised predictions to NRS-space via sigma, applies
|
| 453 |
+
Skew/Stretch/Squash geometry, then converts back.
|
| 454 |
+
Supports v-prediction models (auto-detected via sd_model.parameterization).
|
| 455 |
+
"""
|
| 456 |
+
if not _validate_4d(x_orig, "nrs_x") or not _validate_4d(cond, "nrs_c"):
|
| 457 |
+
return cond
|
| 458 |
+
|
| 459 |
+
try:
|
| 460 |
+
# Detect v-prediction model
|
| 461 |
+
from modules import shared as _shared
|
| 462 |
+
is_v_pred = (hasattr(_shared.sd_model, "parameterization") and
|
| 463 |
+
_shared.sd_model.parameterization == "v")
|
| 464 |
+
except Exception:
|
| 465 |
+
is_v_pred = False
|
| 466 |
+
|
| 467 |
+
# Normalise sigma to a [1,1,1,1] tensor
|
| 468 |
+
if isinstance(sigma, torch.Tensor):
|
| 469 |
+
sig = sigma.flatten()[0].to(cond.dtype)
|
| 470 |
+
else:
|
| 471 |
+
sig = torch.tensor(float(sigma), device=cond.device, dtype=cond.dtype)
|
| 472 |
+
sig = sig.view(1, 1, 1, 1).clamp(min=1e-6)
|
| 473 |
+
sig_root = (sig ** 2 + 1).sqrt()
|
| 474 |
+
|
| 475 |
+
# Convert to NRS geometry space
|
| 476 |
+
if is_v_pred:
|
| 477 |
+
nrs_c, nrs_u = cond, uncond
|
| 478 |
+
else:
|
| 479 |
+
x_div = x_orig / (sig ** 2 + 1)
|
| 480 |
+
factor = sig / sig_root
|
| 481 |
+
nrs_c = x_orig - (x_div - cond * factor)
|
| 482 |
+
nrs_u = x_orig - (x_div - uncond * factor)
|
| 483 |
+
|
| 484 |
+
eps = 1e-6
|
| 485 |
+
|
| 486 |
+
def _dot(a, b): return (a * b).sum(dim=1, keepdim=True)
|
| 487 |
+
def _nrm2(v): return _dot(v, v)
|
| 488 |
+
|
| 489 |
+
c2c = _nrm2(nrs_c) + eps
|
| 490 |
+
u2c = _dot(nrs_u, nrs_c)
|
| 491 |
+
u_on_c = (u2c / c2c) * nrs_c
|
| 492 |
+
|
| 493 |
+
proj_diff = nrs_c - u_on_c
|
| 494 |
+
if vector_softcap > 0.0:
|
| 495 |
+
proj_diff = _nrs_softcap_vector(proj_diff, vector_softcap, vector_softcap_mode)
|
| 496 |
+
stretched = nrs_c + stretch * proj_diff
|
| 497 |
+
|
| 498 |
+
u_rej_c = nrs_u - u_on_c
|
| 499 |
+
if vector_softcap > 0.0:
|
| 500 |
+
u_rej_c = _nrs_softcap_vector(u_rej_c, vector_softcap, vector_softcap_mode)
|
| 501 |
+
skewed = stretched - skew * u_rej_c
|
| 502 |
+
|
| 503 |
+
# Squash: normalise back to original cond length (L2)
|
| 504 |
+
c_len = nrs_c.norm(dim=1, keepdim=True)
|
| 505 |
+
s_len = skewed.norm(dim=1, keepdim=True) + eps
|
| 506 |
+
sq_scale = ((1.0 - squash) + squash * (c_len / s_len)).clamp(-10.0, 10.0)
|
| 507 |
+
x_final = skewed * sq_scale
|
| 508 |
+
|
| 509 |
+
if not torch.isfinite(x_final).all():
|
| 510 |
+
return cond
|
| 511 |
+
|
| 512 |
+
# Convert back to denoised space
|
| 513 |
+
if is_v_pred:
|
| 514 |
+
return x_final
|
| 515 |
+
else:
|
| 516 |
+
return (x_div - (x_orig - x_final)) * (sig_root / sig)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
@torch.no_grad()
|
| 520 |
+
def _nrs_midpoint_refined(x_orig, cond, uncond, sigma,
|
| 521 |
+
skew, stretch, squash,
|
| 522 |
+
blend: float = 0.5,
|
| 523 |
+
mode: str = "Classic",
|
| 524 |
+
first_half_only: bool = True,
|
| 525 |
+
step: int = 0, total: int = 20,
|
| 526 |
+
softcap: float = 0.0,
|
| 527 |
+
softcap_mode: str = "Per Sample") -> torch.Tensor:
|
| 528 |
+
"""
|
| 529 |
+
Kohaku RK2 midpoint refinement for NRS.
|
| 530 |
+
Source: NRS Kohaku Enhanced v3.5 calc_nrs_midpoint_refined()
|
| 531 |
+
|
| 532 |
+
nrs_direct = NRS(x, c, u)
|
| 533 |
+
x_mid = x + (nrs_direct - x) * blend * 0.5 β midpoint shift
|
| 534 |
+
nrs_refined = NRS(x_mid, c, u)
|
| 535 |
+
result = (nrs_direct + nrs_refined) / 2 β RK2 average
|
| 536 |
+
"""
|
| 537 |
+
def _nrs(x, c, u):
|
| 538 |
+
return _nrs_core(x, c, u, sigma, skew, stretch, squash, softcap, softcap_mode)
|
| 539 |
+
|
| 540 |
+
nrs_direct = _nrs(x_orig, cond, uncond)
|
| 541 |
+
|
| 542 |
+
if blend <= 0.0:
|
| 543 |
+
return nrs_direct
|
| 544 |
+
if first_half_only and step > total / 2:
|
| 545 |
+
return nrs_direct
|
| 546 |
+
|
| 547 |
+
x_mid = x_orig + (nrs_direct - x_orig) * (blend * 0.5)
|
| 548 |
+
|
| 549 |
+
if mode == "Conservative":
|
| 550 |
+
nrs_ref = _nrs_core(x_mid, cond, uncond, sigma,
|
| 551 |
+
skew * 0.75, stretch * 0.75, squash,
|
| 552 |
+
softcap, softcap_mode)
|
| 553 |
+
else:
|
| 554 |
+
nrs_ref = _nrs(x_mid, cond, uncond)
|
| 555 |
+
|
| 556 |
+
if mode == "Directional Midpoint":
|
| 557 |
+
try:
|
| 558 |
+
correction = nrs_ref - nrs_direct
|
| 559 |
+
base_dir = nrs_direct - x_orig
|
| 560 |
+
dims = tuple(range(1, correction.ndim))
|
| 561 |
+
guide_norm = base_dir.norm(p=2, dim=dims, keepdim=True)
|
| 562 |
+
if torch.any(guide_norm > 1e-6):
|
| 563 |
+
# Project correction onto base_dir direction
|
| 564 |
+
c_dot_g = (correction * base_dir).sum(dim=dims, keepdim=True)
|
| 565 |
+
g_dot_g = (base_dir * base_dir).sum(dim=dims, keepdim=True) + 1e-9
|
| 566 |
+
c_para = (c_dot_g / g_dot_g) * base_dir
|
| 567 |
+
c_orth = correction - c_para
|
| 568 |
+
nrs_ref = nrs_direct + c_para + c_orth * 0.75
|
| 569 |
+
except Exception:
|
| 570 |
+
pass
|
| 571 |
+
|
| 572 |
+
return (nrs_direct + nrs_ref) * 0.5
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
# ββ TCFG: Tangential Damping via SVD (NRS Kohaku #33) βββββββββββββββββββ
|
| 576 |
+
|
| 577 |
+
@torch.no_grad()
|
| 578 |
+
def _tcfg_uncond(cond: torch.Tensor, uncond: torch.Tensor,
|
| 579 |
+
step: int = -1) -> torch.Tensor:
|
| 580 |
+
"""
|
| 581 |
+
TCFG: filter tangential components from uncond that are misaligned with cond.
|
| 582 |
+
Source: arxiv 2503.18137 (Kwon et al., CVPR 2025) / NRS Kohaku apply_tcfg_uncond()
|
| 583 |
+
|
| 584 |
+
Algorithm (10 lines):
|
| 585 |
+
1. Stack [cond, uncond] β [B, 2, N]
|
| 586 |
+
2. SVD β Vh: [B, 2, N]
|
| 587 |
+
3. Zero out second right singular vector (tangential)
|
| 588 |
+
4. Project uncond through zeroed Vh
|
| 589 |
+
"""
|
| 590 |
+
if not _validate_4d(cond, "tcfg_c") or not _validate_4d(uncond, "tcfg_u"):
|
| 591 |
+
return uncond
|
| 592 |
+
if cond.shape != uncond.shape:
|
| 593 |
+
return uncond
|
| 594 |
+
orig_dtype = uncond.dtype
|
| 595 |
+
orig_device = uncond.device
|
| 596 |
+
try:
|
| 597 |
+
if not torch.isfinite(cond).all() or not torch.isfinite(uncond).all():
|
| 598 |
+
return uncond
|
| 599 |
+
B = cond.shape[0]
|
| 600 |
+
cf = cond.float(); uf = uncond.float()
|
| 601 |
+
all_n = torch.stack((cf, uf), dim=1).reshape(B, 2, -1) # [B,2,N]
|
| 602 |
+
_U, _S, Vh = torch.linalg.svd(all_n, full_matrices=False) # Vh:[B,2,N]
|
| 603 |
+
Vh_mod = Vh.clone(); Vh_mod[:, 1, :] = 0.0
|
| 604 |
+
uf_flat = uf.reshape(B, 1, -1) # [B,1,N]
|
| 605 |
+
x_Vh = torch.matmul(uf_flat, Vh.transpose(-2, -1)) # [B,1,2]
|
| 606 |
+
result = torch.matmul(x_Vh, Vh_mod).reshape(*uncond.shape)
|
| 607 |
+
if not torch.isfinite(result).all():
|
| 608 |
+
return uncond
|
| 609 |
+
return result.to(device=orig_device, dtype=orig_dtype)
|
| 610 |
+
except Exception:
|
| 611 |
+
return uncond
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
# ββ Disagreement Gate (NRS Kohaku #38) ββββββββββββββββββββββββββββββββββ
|
| 615 |
+
|
| 616 |
+
@torch.no_grad()
|
| 617 |
+
def _disagreement_gate(cond: torch.Tensor, uncond: torch.Tensor,
|
| 618 |
+
skew: float, stretch: float,
|
| 619 |
+
strength: float = 0.5, threshold: float = 0.3,
|
| 620 |
+
metric: str = "Cosine") -> tuple:
|
| 621 |
+
"""
|
| 622 |
+
Scale skew/stretch by how much cond and uncond disagree.
|
| 623 |
+
Source: NRS Kohaku Enhanced apply_disagreement_gate()
|
| 624 |
+
|
| 625 |
+
gate = strength + (1 - strength) * clamp(disagreement / threshold, 0, 1)
|
| 626 |
+
At zero disagreement β gate = strength (minimum scaling)
|
| 627 |
+
At full disagreement β gate = 1.0 (full skew/stretch)
|
| 628 |
+
"""
|
| 629 |
+
if not torch.is_tensor(cond) or not torch.is_tensor(uncond):
|
| 630 |
+
return skew, stretch
|
| 631 |
+
try:
|
| 632 |
+
eps = 1e-6
|
| 633 |
+
c = cond.float().reshape(cond.shape[0], -1)
|
| 634 |
+
u = uncond.float().reshape(uncond.shape[0], -1)
|
| 635 |
+
if metric == "L2":
|
| 636 |
+
diff_norm = (c - u).norm(p=2, dim=1)
|
| 637 |
+
ref_norm = (c.norm(p=2, dim=1) + u.norm(p=2, dim=1)) / 2.0 + eps
|
| 638 |
+
raw_dist = (diff_norm / ref_norm).mean().clamp(0.0, 2.0) / 2.0
|
| 639 |
+
else: # Cosine
|
| 640 |
+
cn = torch.nn.functional.normalize(c, dim=1)
|
| 641 |
+
un = torch.nn.functional.normalize(u, dim=1)
|
| 642 |
+
raw_dist = ((1.0 - (cn * un).sum(dim=1).mean()) / 2.0).clamp(0.0, 1.0)
|
| 643 |
+
if not torch.isfinite(raw_dist):
|
| 644 |
+
return skew, stretch
|
| 645 |
+
thr = max(threshold, eps)
|
| 646 |
+
dis = float((raw_dist / thr).clamp(0.0, 1.0).item())
|
| 647 |
+
gate = strength + (1.0 - strength) * dis
|
| 648 |
+
return skew * gate, stretch * gate
|
| 649 |
+
except Exception:
|
| 650 |
+
return skew, stretch
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
# ββ Inter-step smoothing: Momentum + EMA Delta (NRS Kohaku) βββββββββββββ
|
| 654 |
+
|
| 655 |
+
@torch.no_grad()
|
| 656 |
+
def _apply_momentum(result, prev_result, prev_vel, momentum: float):
|
| 657 |
+
"""
|
| 658 |
+
RES/Clybius momentum smoothing between denoising steps.
|
| 659 |
+
Source: NRS Kohaku apply_nrs_momentum()
|
| 660 |
+
Formula: vel = m*(1-m/2)*prev_vel + (1-m*(1-m/2))*curr_diff
|
| 661 |
+
result = prev_result + vel
|
| 662 |
+
Returns (smoothed_result, new_vel)
|
| 663 |
+
"""
|
| 664 |
+
if momentum <= 0.0 or prev_result is None:
|
| 665 |
+
curr = result - prev_result if prev_result is not None else None
|
| 666 |
+
return result, curr
|
| 667 |
+
try:
|
| 668 |
+
curr_diff = result - prev_result
|
| 669 |
+
eff_m = momentum * (1.0 - momentum * 0.5)
|
| 670 |
+
vel = (eff_m * prev_vel + (1.0 - eff_m) * curr_diff
|
| 671 |
+
if prev_vel is not None else curr_diff)
|
| 672 |
+
return prev_result + vel, vel
|
| 673 |
+
except Exception:
|
| 674 |
+
return result, None
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
@torch.no_grad()
|
| 678 |
+
def _apply_ema_delta(result, prev_result, prev_delta, ema_alpha: float):
|
| 679 |
+
"""
|
| 680 |
+
EMA smoothing in update-vector space.
|
| 681 |
+
Source: NRS Kohaku apply_nrs_ema_delta()
|
| 682 |
+
Smooths the step-to-step update delta, not the full result tensor.
|
| 683 |
+
Returns (smoothed_result, new_delta_state)
|
| 684 |
+
"""
|
| 685 |
+
if prev_result is None:
|
| 686 |
+
return result, None
|
| 687 |
+
try:
|
| 688 |
+
curr_delta = result - prev_result
|
| 689 |
+
delta_state = (torch.lerp(prev_delta, curr_delta, ema_alpha)
|
| 690 |
+
if prev_delta is not None else curr_delta)
|
| 691 |
+
return prev_result + delta_state, delta_state
|
| 692 |
+
except Exception:
|
| 693 |
+
return result, None
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 697 |
+
# Skimmed CFG (sd-webui-skimmed_cfg)
|
| 698 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 699 |
+
|
| 700 |
+
@torch.no_grad()
|
| 701 |
+
def _skim_compute(c, u, cs, ss, no_flip):
|
| 702 |
+
"""Compute skimmed CFG mask.
|
| 703 |
+
Returns (mask, cond_correction, uncond_correction) where:
|
| 704 |
+
- cond_correction is subtracted from cond to move it toward uncond
|
| 705 |
+
- uncond_correction is ADDED to uncond to move it toward cond
|
| 706 |
+
Formula matches original skimmed_CFG reference with x_orig = uncond."""
|
| 707 |
+
den = u - (cs * ((u - c) - (u - u)))
|
| 708 |
+
ms = (c - u).sign() == c.sign()
|
| 709 |
+
md = c.sign() == (c * cs - u * (cs - 1)).sign()
|
| 710 |
+
if no_flip:
|
| 711 |
+
mask = ms & md
|
| 712 |
+
else:
|
| 713 |
+
mask = ms & md & (den.sign() == (den - u).sign())
|
| 714 |
+
low = u - (ss * ((u - c) - (u - u)))
|
| 715 |
+
# correction for cond (subtract): (den - low) / cs
|
| 716 |
+
# correction for uncond (add): (den - low) / cs (same magnitude, opposite direction)
|
| 717 |
+
corr = (den - low) / cs
|
| 718 |
+
return mask, corr, corr
|
| 719 |
+
|
| 720 |
+
@torch.no_grad()
|
| 721 |
+
def _skim_smooth(u, c, cs, ss):
|
| 722 |
+
df = u - cs * c - (cs - 1) * u
|
| 723 |
+
ds = u - ss * c - (ss - 1) * u
|
| 724 |
+
ad = (df - ds).abs(); span = ad.max() - ad.min()
|
| 725 |
+
ad = (ad - ad.min()) / (span + 1e-8)
|
| 726 |
+
nr = (ss - 1) / max(cs - 1, 1e-6)
|
| 727 |
+
return (c * (1 - nr) + u * nr) * (1 - ad) + u * ad
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
def _get_emb(raw):
|
| 731 |
+
"""Extract crossattn tensor from raw conditioning (handles SD1.x and SDXL dict)."""
|
| 732 |
+
if isinstance(raw, dict):
|
| 733 |
+
return raw.get("crossattn", None)
|
| 734 |
+
return raw if torch.is_tensor(raw) else None
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def _set_emb(raw_ref, new_tensor):
|
| 738 |
+
"""Write back modified tensor, preserving dict wrapper for SDXL.
|
| 739 |
+
NEVER changes the shape of the original β only replaces values in-place."""
|
| 740 |
+
if isinstance(raw_ref, dict):
|
| 741 |
+
out = raw_ref.__class__(raw_ref); out["crossattn"] = new_tensor; return out
|
| 742 |
+
return new_tensor
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
def _emb_window(c: torch.Tensor, u: torch.Tensor):
|
| 746 |
+
"""Return slices of c and u at the MINIMUM shared sequence length.
|
| 747 |
+
This guarantees we NEVER extend either tensor β A1111 crashes if shape changes."""
|
| 748 |
+
sl = min(c.shape[1], u.shape[1])
|
| 749 |
+
return c[:, :sl], u[:, :sl], sl
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def _apply_skimmed(params) -> None:
|
| 753 |
+
"""Apply Skimmed CFG to text_uncond and/or text_cond.
|
| 754 |
+
SAFETY: only plain tensors or SDXL dicts are processed; never changes shape.
|
| 755 |
+
skim_target controls which tensor gets modified:
|
| 756 |
+
- "Uncond": modify uncond directly (port default, moves uncond toward cond)
|
| 757 |
+
- "Cond β Uncond": modify cond and assign to uncond (original reference behavior,
|
| 758 |
+
makes condβuncond at mask positions β zero guidance from those positions)
|
| 759 |
+
- "Both": modify both cond and uncond"""
|
| 760 |
+
if not _st.apply_to_neg and not _st.apply_to_pos:
|
| 761 |
+
return
|
| 762 |
+
rc = params.text_cond
|
| 763 |
+
ru = params.text_uncond
|
| 764 |
+
if ru is None:
|
| 765 |
+
return
|
| 766 |
+
|
| 767 |
+
c = _get_emb(rc)
|
| 768 |
+
u = _get_emb(ru)
|
| 769 |
+
if c is None or u is None:
|
| 770 |
+
return
|
| 771 |
+
if not torch.is_tensor(c) or not torch.is_tensor(u):
|
| 772 |
+
return
|
| 773 |
+
if c.dim() != 3 or u.dim() != 3:
|
| 774 |
+
return
|
| 775 |
+
|
| 776 |
+
c_w, u_w, sl = _emb_window(c, u)
|
| 777 |
+
|
| 778 |
+
d = params.denoiser
|
| 779 |
+
cfg_base = (_st.hr_cfg if (_st.in_hr_pass and _st.hr_en) else _st.base_cfg)
|
| 780 |
+
cs = max(d.cond_scale_miltiplier * cfg_base, 1e-6)
|
| 781 |
+
|
| 782 |
+
eff_scale = _st._sched_skim_scale
|
| 783 |
+
if _st.skim_mode == "Smooth":
|
| 784 |
+
# _skim_smooth always produces a modified uncond
|
| 785 |
+
new_u_w = _skim_smooth(u_w, c_w, cs, eff_scale)
|
| 786 |
+
if _st.apply_to_neg:
|
| 787 |
+
new_u = u.clone()
|
| 788 |
+
new_u[:, :sl] = new_u_w
|
| 789 |
+
params.text_uncond = _set_emb(ru, new_u)
|
| 790 |
+
if _st.apply_to_pos:
|
| 791 |
+
new_c = c.clone()
|
| 792 |
+
new_c[:, :sl] = _skim_smooth(c_w, u_w, cs, eff_scale)
|
| 793 |
+
params.text_cond = _set_emb(rc, new_c)
|
| 794 |
+
return
|
| 795 |
+
|
| 796 |
+
# Classic mode: compute mask + corrections
|
| 797 |
+
mask, corr_c, corr_u = _skim_compute(c_w, u_w, cs, eff_scale, _st.skim_flip)
|
| 798 |
+
target = _st.skim_target
|
| 799 |
+
|
| 800 |
+
if target == "Cond β Uncond":
|
| 801 |
+
if _st.apply_to_neg:
|
| 802 |
+
new_c_for_u = c.clone()
|
| 803 |
+
new_c_for_u[:, :sl][mask] = c_w[mask] - corr_c[mask]
|
| 804 |
+
params.text_uncond = _set_emb(ru, new_c_for_u)
|
| 805 |
+
if _st.apply_to_pos:
|
| 806 |
+
new_c = c.clone()
|
| 807 |
+
new_c[:, :sl][mask] = c_w[mask] - corr_c[mask]
|
| 808 |
+
params.text_cond = _set_emb(rc, new_c)
|
| 809 |
+
|
| 810 |
+
elif target == "Uncond":
|
| 811 |
+
if _st.apply_to_neg:
|
| 812 |
+
new_u = u.clone()
|
| 813 |
+
new_u[:, :sl][mask] = u_w[mask] + corr_u[mask]
|
| 814 |
+
params.text_uncond = _set_emb(ru, new_u)
|
| 815 |
+
if _st.apply_to_pos:
|
| 816 |
+
new_c = c.clone()
|
| 817 |
+
new_c[:, :sl][mask] = c_w[mask] - corr_c[mask]
|
| 818 |
+
params.text_cond = _set_emb(rc, new_c)
|
| 819 |
+
|
| 820 |
+
elif target == "Both":
|
| 821 |
+
if _st.apply_to_neg:
|
| 822 |
+
new_u = u.clone()
|
| 823 |
+
new_u[:, :sl][mask] = u_w[mask] + corr_u[mask] * 0.5
|
| 824 |
+
params.text_uncond = _set_emb(ru, new_u)
|
| 825 |
+
if _st.apply_to_pos:
|
| 826 |
+
new_c = c.clone()
|
| 827 |
+
new_c[:, :sl][mask] = c_w[mask] - corr_c[mask] * 0.5
|
| 828 |
+
params.text_cond = _set_emb(rc, new_c)
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 832 |
+
# Lerp Uncond (AutomaticCFG / pre_cfg_comfy)
|
| 833 |
+
# Blend text_cond/uncond toward each other, preserving original norm.
|
| 834 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 835 |
+
|
| 836 |
+
@torch.no_grad()
|
| 837 |
+
def _apply_lerp_uncond(params) -> None:
|
| 838 |
+
rc = params.text_cond
|
| 839 |
+
ru = params.text_uncond
|
| 840 |
+
if ru is None:
|
| 841 |
+
return
|
| 842 |
+
|
| 843 |
+
c = _get_emb(rc)
|
| 844 |
+
u = _get_emb(ru)
|
| 845 |
+
if c is None or u is None:
|
| 846 |
+
return
|
| 847 |
+
if not torch.is_tensor(c) or not torch.is_tensor(u):
|
| 848 |
+
return
|
| 849 |
+
if c.dim() != 3 or u.dim() != 3:
|
| 850 |
+
return
|
| 851 |
+
|
| 852 |
+
# Work on MINIMUM shared length only β never extend
|
| 853 |
+
c_w, u_w, sl = _emb_window(c, u)
|
| 854 |
+
|
| 855 |
+
eff_lerp = _st._sched_lerp_str
|
| 856 |
+
if _st.apply_to_neg:
|
| 857 |
+
u_norm = u.norm()
|
| 858 |
+
nu_w = torch.lerp(c_w, u_w, eff_lerp)
|
| 859 |
+
nu_n = nu_w.norm()
|
| 860 |
+
if nu_n > 1e-8:
|
| 861 |
+
nu_w = nu_w * (u_norm / nu_n)
|
| 862 |
+
new_u = u.clone()
|
| 863 |
+
new_u[:, :sl] = nu_w
|
| 864 |
+
params.text_uncond = _set_emb(ru, new_u)
|
| 865 |
+
|
| 866 |
+
if _st.apply_to_pos:
|
| 867 |
+
c_norm = c.norm()
|
| 868 |
+
nc_w = torch.lerp(u_w, c_w, eff_lerp)
|
| 869 |
+
nc_n = nc_w.norm()
|
| 870 |
+
if nc_n > 1e-8:
|
| 871 |
+
nc_w = nc_w * (c_norm / nc_n)
|
| 872 |
+
new_c = c.clone()
|
| 873 |
+
new_c[:, :sl] = nc_w
|
| 874 |
+
params.text_cond = _set_emb(rc, new_c)
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 878 |
+
# on_cfg_denoiser callback (registered once at import)
|
| 879 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 880 |
+
|
| 881 |
+
def _cond_has_shape(v) -> bool:
|
| 882 |
+
"""Return True if v is something A1111 can call .shape[1] on without crashing.
|
| 883 |
+
|
| 884 |
+
A1111 sd_samplers_cfg_denoiser.py line 243 does:
|
| 885 |
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
| 886 |
+
When pad_cond_uncond settings are OFF (A1111 default), this is the FIRST
|
| 887 |
+
place .shape is accessed after our callback. If v is a dict (SDXL-style
|
| 888 |
+
conditioning set by another extension) A1111 crashes with:
|
| 889 |
+
AttributeError: 'dict' object has no attribute 'shape'
|
| 890 |
+
We guard against that here.
|
| 891 |
+
"""
|
| 892 |
+
if torch.is_tensor(v):
|
| 893 |
+
return True
|
| 894 |
+
if isinstance(v, dict):
|
| 895 |
+
ca = v.get("crossattn", None)
|
| 896 |
+
return torch.is_tensor(ca)
|
| 897 |
+
return hasattr(v, "shape")
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
def _on_cfg_denoiser(params) -> None:
|
| 901 |
+
if not _st.enabled:
|
| 902 |
+
return
|
| 903 |
+
|
| 904 |
+
d = params.denoiser
|
| 905 |
+
total = max(1, params.total_sampling_steps)
|
| 906 |
+
step = params.sampling_step
|
| 907 |
+
t = step / total
|
| 908 |
+
|
| 909 |
+
# Store context needed by _custom_combine (sigma-aware NRS, midpoint, inter-step)
|
| 910 |
+
_st.current_x = params.x
|
| 911 |
+
_st.current_sigma = params.sigma
|
| 912 |
+
_st.current_step = step
|
| 913 |
+
_st.total_steps = total
|
| 914 |
+
|
| 915 |
+
# ββ Compute scheduled parameter values for this step ββββββββββββββββββ
|
| 916 |
+
_st._sched_skim_scale = _schedule_value(
|
| 917 |
+
_st.skim_scale, step, total, _st.skim_sched_curve, _st.skim_sched_min)
|
| 918 |
+
_st._sched_lerp_str = _schedule_value(
|
| 919 |
+
_st.lerp_str, step, total, _st.lerp_sched_curve, _st.lerp_sched_min)
|
| 920 |
+
_st._sched_nrs_skew = _schedule_value(
|
| 921 |
+
_st.nrs_skew, step, total, _st.nrs_sched_curve, _st.nrs_sched_min_skew)
|
| 922 |
+
_st._sched_nrs_stretch = _schedule_value(
|
| 923 |
+
_st.nrs_stretch, step, total, _st.nrs_sched_curve, _st.nrs_sched_min_stretch)
|
| 924 |
+
_st._sched_heuristic_cfg = _schedule_value(
|
| 925 |
+
_st.heuristic_cfg, step, total, _st.heuristic_sched_curve, _st.heuristic_sched_min)
|
| 926 |
+
|
| 927 |
+
# ββ NRS Individual Step Control overrides (Range1/Range2 + Lock) ββββββ
|
| 928 |
+
_st._sched_nrs_skew, _st.nrs_sc_locked = _step_control_val(
|
| 929 |
+
_st._sched_nrs_skew, step,
|
| 930 |
+
_st.nrs_sc_enabled,
|
| 931 |
+
_st.nrs_sc_r1_start, _st.nrs_sc_r1_end, _st.nrs_sc_r1_skew,
|
| 932 |
+
_st.nrs_sc_r2_start, _st.nrs_sc_r2_end, _st.nrs_sc_r2_skew,
|
| 933 |
+
_st.nrs_sc_lock, _st.nrs_sc_locked, off_val=0.0)
|
| 934 |
+
_st._sched_nrs_stretch, _ = _step_control_val(
|
| 935 |
+
_st._sched_nrs_stretch, step,
|
| 936 |
+
_st.nrs_sc_enabled,
|
| 937 |
+
_st.nrs_sc_r1_start, _st.nrs_sc_r1_end, _st.nrs_sc_r1_stretch,
|
| 938 |
+
_st.nrs_sc_r2_start, _st.nrs_sc_r2_end, _st.nrs_sc_r2_stretch,
|
| 939 |
+
_st.nrs_sc_lock, _st.nrs_sc_locked, off_val=0.0)
|
| 940 |
+
|
| 941 |
+
# ββ Pass-mode gate (schedule / boost / jitter / ranges) ββββββββββββββ
|
| 942 |
+
if _pass_features_active():
|
| 943 |
+
d.cond_scale_miltiplier = 1.0
|
| 944 |
+
|
| 945 |
+
if _st.ranges_en and _st.range_scales and step < len(_st.range_scales):
|
| 946 |
+
d.cond_scale_miltiplier = _st.range_scales[step] / max(_st.base_cfg, 1e-4)
|
| 947 |
+
|
| 948 |
+
if _st.sched_en and _st.sched_k != "off":
|
| 949 |
+
d.cond_scale_miltiplier *= _sm(t, _st.sched_k, _st.sched_min, _st.sched_max)
|
| 950 |
+
|
| 951 |
+
if _st.boost_en and t >= _st.boost_from:
|
| 952 |
+
ramp = (t - _st.boost_from) / max(1e-9, 1.0 - _st.boost_from)
|
| 953 |
+
d.cond_scale_miltiplier *= 1.0 + (_st.boost_mul - 1.0) * ramp
|
| 954 |
+
|
| 955 |
+
if _st.jitter_en and _st.jitter_amt > 0:
|
| 956 |
+
d.cond_scale_miltiplier *= 1.0 + random.uniform(-_st.jitter_amt, _st.jitter_amt)
|
| 957 |
+
|
| 958 |
+
d.cond_scale_miltiplier = max(0.01, d.cond_scale_miltiplier)
|
| 959 |
+
else:
|
| 960 |
+
d.cond_scale_miltiplier = 1.0
|
| 961 |
+
|
| 962 |
+
# ββ Embedding-level features (SEGA-inspired window + lock guard) ββββββ
|
| 963 |
+
if _emb_features_active(step) and _pass_features_active():
|
| 964 |
+
# SAVE originals β restored on any failure or bad output
|
| 965 |
+
_orig_cond = params.text_cond
|
| 966 |
+
_orig_uncond = params.text_uncond
|
| 967 |
+
|
| 968 |
+
# Skip embedding features if text_cond is not in a state we can work with.
|
| 969 |
+
# Another extension (e.g. NRS Kohaku) may have already turned text_cond
|
| 970 |
+
# into a plain dict without a "crossattn" key, which would crash A1111
|
| 971 |
+
# at line 243 of sd_samplers_cfg_denoiser.py regardless of our changes.
|
| 972 |
+
if not _cond_has_shape(_orig_cond) or not _cond_has_shape(_orig_uncond):
|
| 973 |
+
return # don't touch anything β let A1111 handle it
|
| 974 |
+
|
| 975 |
+
try:
|
| 976 |
+
if _st.lerp_en:
|
| 977 |
+
_apply_lerp_uncond(params)
|
| 978 |
+
if _st.skim_en:
|
| 979 |
+
_apply_skimmed(params)
|
| 980 |
+
except Exception:
|
| 981 |
+
# On any failure, restore to the guaranteed-valid originals
|
| 982 |
+
params.text_cond = _orig_cond
|
| 983 |
+
params.text_uncond = _orig_uncond
|
| 984 |
+
return
|
| 985 |
+
|
| 986 |
+
# Post-modification validation: ensure we haven't corrupted the shape.
|
| 987 |
+
# If something produced a non-.shape-able object, roll back.
|
| 988 |
+
if not _cond_has_shape(params.text_cond):
|
| 989 |
+
params.text_cond = _orig_cond
|
| 990 |
+
if not _cond_has_shape(params.text_uncond):
|
| 991 |
+
params.text_uncond = _orig_uncond
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
|
| 995 |
+
|
| 996 |
+
# ββ Late-fixup callback (registered lazily in before_process) ββββββββββββββββ
|
| 997 |
+
# WHY THIS EXISTS:
|
| 998 |
+
# A1111 loads extensions alphabetically. "cfg-prompt-forge" loads before
|
| 999 |
+
# extensions starting with letters later than 'c' (e.g. "nrs_kohakuβ¦").
|
| 1000 |
+
# Callbacks registered at import time fire in registration order, so NRS Kohaku's
|
| 1001 |
+
# callback runs AFTER ours. NRS Kohaku can wrap text_cond in an SDXL-style dict
|
| 1002 |
+
# even when the model is SD1.x, causing A1111 line-243 to crash:
|
| 1003 |
+
# if tensor.shape[1] == uncond.shape[1] β AttributeError: dict has no .shape
|
| 1004 |
+
#
|
| 1005 |
+
# Solution: register _late_cond_fixup lazily inside before_process.
|
| 1006 |
+
# Because before_process runs AFTER all import-time registrations are done,
|
| 1007 |
+
# our fixup is added to the END of the callback list and fires LAST β after
|
| 1008 |
+
# every other extension that might corrupt the conditioning.
|
| 1009 |
+
|
| 1010 |
+
_fixup_registered: bool = False # module-level flag: only register once
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def _late_cond_fixup(params) -> None:
|
| 1014 |
+
"""Last-resort fixup with explicit model architecture detection.
|
| 1015 |
+
Unwraps dict-style conditioning back to plain tensors for SD1.x/SD2.x
|
| 1016 |
+
models, while preserving the dict wrapper for SDXL/SD3 models that
|
| 1017 |
+
genuinely require it.
|
| 1018 |
+
|
| 1019 |
+
Previous heuristic checked for "vector" key, but some extensions add
|
| 1020 |
+
a "vector" key to SD1.x conditioning as well, causing false negatives.
|
| 1021 |
+
Using shared.sd_model.is_sdxl / is_sd3 is authoritative.
|
| 1022 |
+
"""
|
| 1023 |
+
from modules import shared
|
| 1024 |
+
|
| 1025 |
+
m = getattr(shared, 'sd_model', None)
|
| 1026 |
+
expects_dict = getattr(m, 'is_sdxl', False) or getattr(m, 'is_sd3', False)
|
| 1027 |
+
|
| 1028 |
+
def _try_unwrap(v):
|
| 1029 |
+
if not isinstance(v, dict):
|
| 1030 |
+
return v
|
| 1031 |
+
ca = v.get("crossattn", None)
|
| 1032 |
+
if torch.is_tensor(ca) and not expects_dict:
|
| 1033 |
+
return ca
|
| 1034 |
+
return v
|
| 1035 |
+
|
| 1036 |
+
params.text_cond = _try_unwrap(params.text_cond)
|
| 1037 |
+
params.text_uncond = _try_unwrap(params.text_uncond)
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
# βββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββ
|
| 1041 |
+
# combine_denoised class-patch
|
| 1042 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1043 |
+
|
| 1044 |
+
def _hook_combine() -> None:
|
| 1045 |
+
if _DCLS is None: return
|
| 1046 |
+
if hasattr(_DCLS, _SAVED):
|
| 1047 |
+
setattr(_DCLS, _ORIG, getattr(_DCLS, _SAVED))
|
| 1048 |
+
delattr(_DCLS, _SAVED)
|
| 1049 |
+
orig = getattr(_DCLS, _ORIG)
|
| 1050 |
+
setattr(_DCLS, _SAVED, orig)
|
| 1051 |
+
|
| 1052 |
+
def _hooked(self, x_out, conds_list, uncond, cond_scale):
|
| 1053 |
+
if not _st.enabled:
|
| 1054 |
+
return orig(self, x_out, conds_list, uncond, cond_scale)
|
| 1055 |
+
try:
|
| 1056 |
+
return _custom_combine(orig, self, x_out, conds_list, uncond, cond_scale)
|
| 1057 |
+
except Exception:
|
| 1058 |
+
return orig(self, x_out, conds_list, uncond, cond_scale)
|
| 1059 |
+
|
| 1060 |
+
setattr(_DCLS, _ORIG, _hooked)
|
| 1061 |
+
|
| 1062 |
+
|
| 1063 |
+
def _unhook_combine() -> None:
|
| 1064 |
+
if _DCLS is None: return
|
| 1065 |
+
if hasattr(_DCLS, _SAVED):
|
| 1066 |
+
setattr(_DCLS, _ORIG, getattr(_DCLS, _SAVED))
|
| 1067 |
+
delattr(_DCLS, _SAVED)
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
@torch.no_grad()
|
| 1071 |
+
def _custom_combine(orig, self, x_out, conds_list, uncond, cond_scale):
|
| 1072 |
+
# Respect pass_mode
|
| 1073 |
+
if not _pass_features_active():
|
| 1074 |
+
return orig(self, x_out, conds_list, uncond, cond_scale)
|
| 1075 |
+
|
| 1076 |
+
unc = x_out[-uncond.shape[0]:]
|
| 1077 |
+
# cond_scale already contains cond_scale_miltiplier β A1111 forward() calls
|
| 1078 |
+
# combine_denoised(x_out, conds_list, uncond, cond_scale * self.cond_scale_miltiplier)
|
| 1079 |
+
# so we must NOT multiply by cond_scale_miltiplier again here.
|
| 1080 |
+
eff_scale = cond_scale
|
| 1081 |
+
|
| 1082 |
+
if _st.autocfg_en and _st.autocfg_method != "None":
|
| 1083 |
+
denoised = _auto_cfg_combine(unc, conds_list, x_out, eff_scale,
|
| 1084 |
+
_st.autocfg_ref, _st.autocfg_method, _st.autocfg_topk)
|
| 1085 |
+
else:
|
| 1086 |
+
# ββ TCFG: filter tangential uncond components before NRS βββββββββ
|
| 1087 |
+
unc_nrs = unc
|
| 1088 |
+
if _st.tcfg_en and len(conds_list) > 0:
|
| 1089 |
+
try:
|
| 1090 |
+
# Build representative cond tensor (one per batch item)
|
| 1091 |
+
rep_cond = torch.stack(
|
| 1092 |
+
[x_out[conds_list[i][0][0]] for i in range(unc.shape[0])
|
| 1093 |
+
if conds_list[i]], dim=0)
|
| 1094 |
+
unc_nrs = _tcfg_uncond(rep_cond, unc, step=_st.current_step)
|
| 1095 |
+
except Exception:
|
| 1096 |
+
unc_nrs = unc
|
| 1097 |
+
|
| 1098 |
+
denoised = torch.clone(unc)
|
| 1099 |
+
sigma = _st.current_sigma
|
| 1100 |
+
x_orig = _st.current_x # noisy latent = params.x = x_in
|
| 1101 |
+
|
| 1102 |
+
# x_in layout: [repeats[0]Γx[0], repeats[1]Γx[1], ..., x[0], x[1], ...]
|
| 1103 |
+
# The last batch_size entries are the original (unrepeat-ed) noisy latents,
|
| 1104 |
+
# one per batch item. Using x_orig[i] would give the wrong entry when
|
| 1105 |
+
# repeats[i] > 1 (AND composition). Use the tail slice instead.
|
| 1106 |
+
x_batch = (x_orig[-uncond.shape[0]:] if x_orig is not None else None)
|
| 1107 |
+
|
| 1108 |
+
for i, conds in enumerate(conds_list):
|
| 1109 |
+
for ci, w in conds:
|
| 1110 |
+
cp = x_out[ci]
|
| 1111 |
+
up = unc_nrs[i] if _st.tcfg_en else unc[i]
|
| 1112 |
+
|
| 1113 |
+
if _st.nrs_en:
|
| 1114 |
+
# ββ Disagreement Gate: scale skew/stretch adaptively ββ
|
| 1115 |
+
eff_skew, eff_stretch = _st._sched_nrs_skew, _st._sched_nrs_stretch
|
| 1116 |
+
if _st.dis_en:
|
| 1117 |
+
eff_skew, eff_stretch = _disagreement_gate(
|
| 1118 |
+
cp, up,
|
| 1119 |
+
eff_skew, eff_stretch,
|
| 1120 |
+
_st.dis_strength, _st.dis_threshold, _st.dis_metric)
|
| 1121 |
+
|
| 1122 |
+
# ββ Proper sigma-aware NRS with optional Midpoint βββββ
|
| 1123 |
+
if sigma is not None and x_batch is not None:
|
| 1124 |
+
xi = x_batch[i] if x_batch.shape[0] > i else x_batch[0]
|
| 1125 |
+
xi = xi.unsqueeze(0) if xi.dim() == 3 else xi
|
| 1126 |
+
cp_b = cp.unsqueeze(0) if cp.dim() == 3 else cp
|
| 1127 |
+
up_b = up.unsqueeze(0) if up.dim() == 3 else up
|
| 1128 |
+
|
| 1129 |
+
if _st.nrs_midpoint > 0.0:
|
| 1130 |
+
cp_b = _nrs_midpoint_refined(
|
| 1131 |
+
xi, cp_b, up_b, sigma,
|
| 1132 |
+
eff_skew, eff_stretch, _st.nrs_squash,
|
| 1133 |
+
blend=_st.nrs_midpoint,
|
| 1134 |
+
mode=_st.nrs_midpoint_mode,
|
| 1135 |
+
first_half_only=_st.nrs_midpoint_fh,
|
| 1136 |
+
step=_st.current_step, total=_st.total_steps,
|
| 1137 |
+
softcap=_st.nrs_softcap,
|
| 1138 |
+
softcap_mode=_st.nrs_softcap_mode)
|
| 1139 |
+
else:
|
| 1140 |
+
cp_b = _nrs_core(xi, cp_b, up_b, sigma,
|
| 1141 |
+
eff_skew, eff_stretch, _st.nrs_squash,
|
| 1142 |
+
vector_softcap=_st.nrs_softcap,
|
| 1143 |
+
vector_softcap_mode=_st.nrs_softcap_mode)
|
| 1144 |
+
|
| 1145 |
+
cp = cp_b.squeeze(0) if cp_b.shape[0] == 1 else cp_b
|
| 1146 |
+
else:
|
| 1147 |
+
# Fallback: simple geometry without sigma (no x_batch stored)
|
| 1148 |
+
eps = 1e-6
|
| 1149 |
+
B = 1; c = cp.reshape(1, -1); u = up.reshape(1, -1)
|
| 1150 |
+
c2c = (c * c).sum(dim=1, keepdim=True) + eps
|
| 1151 |
+
u2c = (u * c).sum(dim=1, keepdim=True)
|
| 1152 |
+
u_on_c = (u2c / c2c) * c
|
| 1153 |
+
stretched = c + eff_stretch * (c - u_on_c)
|
| 1154 |
+
skewed = stretched - eff_skew * (u - u_on_c)
|
| 1155 |
+
c_len = c.norm(dim=1, keepdim=True)
|
| 1156 |
+
s_len = skewed.norm(dim=1, keepdim=True) + eps
|
| 1157 |
+
cp = (skewed * ((1.0 - _st.nrs_squash)
|
| 1158 |
+
+ _st.nrs_squash * (c_len / s_len))).reshape(cp.shape)
|
| 1159 |
+
|
| 1160 |
+
denoised[i] = denoised[i] + (cp - unc[i]) * (w * eff_scale)
|
| 1161 |
+
|
| 1162 |
+
if _st.heuristic_en:
|
| 1163 |
+
denoised = _heuristic(denoised, unc, eff_scale, _st._sched_heuristic_cfg)
|
| 1164 |
+
if _st.reinhard_en:
|
| 1165 |
+
denoised = _reinhard(denoised, unc, eff_scale, _st.reinhard_ref)
|
| 1166 |
+
if _st.rescale_en:
|
| 1167 |
+
denoised = _rescale(denoised, x_out, conds_list, _st.rescale_mult)
|
| 1168 |
+
if _st.lir_en:
|
| 1169 |
+
denoised = _lir(denoised, _st.lir_cfg, _st.lir_method, _st.lir_topk)
|
| 1170 |
+
if _st.mean_en:
|
| 1171 |
+
dims = list(range(1, denoised.dim()))
|
| 1172 |
+
denoised = denoised - denoised.mean(dim=dims, keepdim=True)
|
| 1173 |
+
|
| 1174 |
+
# ββ Inter-step smoothing (Momentum / EMA Delta) βββββββββββββββββββββββ
|
| 1175 |
+
if _st.interp_mode == "Momentum" and _st.interp_mom > 0.0:
|
| 1176 |
+
denoised, _st._prev_vel = _apply_momentum(
|
| 1177 |
+
denoised, _st._prev_nrs, _st._prev_vel, _st.interp_mom)
|
| 1178 |
+
elif _st.interp_mode == "EMA Delta":
|
| 1179 |
+
denoised, _st._prev_delta = _apply_ema_delta(
|
| 1180 |
+
denoised, _st._prev_nrs, _st._prev_delta, _st.interp_ema)
|
| 1181 |
+
|
| 1182 |
+
_st._prev_nrs = denoised.detach().clone()
|
| 1183 |
+
|
| 1184 |
+
return denoised
|
| 1185 |
+
|
| 1186 |
+
|
| 1187 |
+
# ββ Post-CFG shapers ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1188 |
+
|
| 1189 |
+
@torch.no_grad()
|
| 1190 |
+
def _reinhard(result, unc, cs, ref):
|
| 1191 |
+
if cs <= 0: return result
|
| 1192 |
+
noise = result - unc
|
| 1193 |
+
dims = list(range(1, noise.dim()))
|
| 1194 |
+
mag = torch.linalg.vector_norm(noise, dim=dims, keepdim=True) + 1e-8
|
| 1195 |
+
unit = noise / mag
|
| 1196 |
+
mean = mag.mean(); std = mag.std() if mag.numel() > 1 else torch.zeros_like(mean)
|
| 1197 |
+
top = (mean + 3.0 * std) * (ref / cs)
|
| 1198 |
+
return unc + unit * (top * mag / (mag + top)) * cs
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
@torch.no_grad()
|
| 1202 |
+
def _rescale(result, x_out, conds_list, mult):
|
| 1203 |
+
try:
|
| 1204 |
+
ci = conds_list[0][0][0]
|
| 1205 |
+
ref = x_out[ci : ci + 1]
|
| 1206 |
+
dims = list(range(1, result.dim()))
|
| 1207 |
+
sc = ref.std(dim=dims, keepdim=True) + 1e-8
|
| 1208 |
+
sr = result.std(dim=dims, keepdim=True) + 1e-8
|
| 1209 |
+
return torch.lerp(result, result * (sc / sr), float(mult))
|
| 1210 |
+
except (IndexError, TypeError): return result
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
@torch.no_grad()
|
| 1214 |
+
def _heuristic(result, unc, cs, h_cfg):
|
| 1215 |
+
if cs <= 0 or abs(cs - h_cfg) < 0.05:
|
| 1216 |
+
return result
|
| 1217 |
+
if _st.heuristic_hstart > 0.0 and _st.current_step < _st.heuristic_hstart * _st.total_steps:
|
| 1218 |
+
return result
|
| 1219 |
+
noise = result - unc
|
| 1220 |
+
h_res = unc + h_cfg * noise / cs
|
| 1221 |
+
rQ = torch.quantile((result - result.mean()).abs().float(), 0.99) + 1e-8
|
| 1222 |
+
hQ = torch.quantile((h_res - h_res .mean()).abs().float(), 0.99) + 1e-8
|
| 1223 |
+
return unc + noise * (hQ / rQ)
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
@torch.no_grad()
|
| 1227 |
+
def _lir(result, target_cfg, method, topk):
|
| 1228 |
+
target = target_cfg / 10.0
|
| 1229 |
+
out = result.clone()
|
| 1230 |
+
for b in range(result.shape[0]):
|
| 1231 |
+
flat = result[b].reshape(result[b].shape[0], -1)
|
| 1232 |
+
for c in range(flat.shape[0]):
|
| 1233 |
+
r = max(_topk_range(flat[c], method, topk), 1e-8)
|
| 1234 |
+
out[b][c] = result[b][c] * (target / r)
|
| 1235 |
+
return out
|
| 1236 |
+
|
| 1237 |
+
|
| 1238 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1239 |
+
# SVG schedule preview
|
| 1240 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1241 |
+
|
| 1242 |
+
def _make_svg(kind, mn, mx, W=340, H=88):
|
| 1243 |
+
PAD = 16; w, h = W - 2*PAD, H - 2*PAD; N = 80
|
| 1244 |
+
y_lo = 0.0; y_hi = max(3.0, float(mx) + 0.4)
|
| 1245 |
+
def px(t, v): return f"{PAD+t*w:.1f},{PAD+(1-(v-y_lo)/(y_hi-y_lo))*h:.1f}"
|
| 1246 |
+
pts = " ".join(px(i/N, _sm(i/N, kind, float(mn), float(mx))) for i in range(N+1))
|
| 1247 |
+
by = PAD + (1 - (1.0 - y_lo)/(y_hi - y_lo)) * h
|
| 1248 |
+
_, ymn = map(float, px(0.0, float(mn)).split(","))
|
| 1249 |
+
_, ymx = map(float, px(0.5, float(mx)).split(","))
|
| 1250 |
+
return (f'<div style="margin:4px 0"><svg xmlns="http://www.w3.org/2000/svg" '
|
| 1251 |
+
f'width="{W}" height="{H}" style="background:#0b0b18;border-radius:8px;'
|
| 1252 |
+
f'border:1px solid #1e1e36;display:block">'
|
| 1253 |
+
f'<line x1="{PAD}" y1="{by:.1f}" x2="{PAD+w}" y2="{by:.1f}" '
|
| 1254 |
+
f'stroke="#3a3a5a" stroke-dasharray="4 3" stroke-width="1"/>'
|
| 1255 |
+
f'<text x="{PAD+3}" y="{by-3:.0f}" fill="#555" font-size="9" font-family="monospace">Γ1.0</text>'
|
| 1256 |
+
f'<text x="{PAD+3}" y="{ymn+4:.0f}" fill="#6ab4ff" font-size="9" font-family="monospace" opacity=".8">Γ{float(mn):.2f}</text>'
|
| 1257 |
+
f'<text x="{PAD+3}" y="{ymx-2:.0f}" fill="#6ab4ff" font-size="9" font-family="monospace" opacity=".8">Γ{float(mx):.2f}</text>'
|
| 1258 |
+
f'<polyline points="{pts}" fill="none" stroke="#5bc8f5" stroke-width="2.2" '
|
| 1259 |
+
f'stroke-linecap="round" stroke-linejoin="round"/>'
|
| 1260 |
+
f'<text x="{PAD}" y="{H-2}" fill="#444" font-size="9" font-family="monospace">0</text>'
|
| 1261 |
+
f'<text x="{PAD+w-28}" y="{H-2}" fill="#444" font-size="9" font-family="monospace">final</text>'
|
| 1262 |
+
f'</svg></div>')
|
| 1263 |
+
|
| 1264 |
+
|
| 1265 |
+
def _h(txt):
|
| 1266 |
+
return (f"<div style='font-weight:700;color:#a8d4ff;margin:12px 0 2px;font-size:.9em;"
|
| 1267 |
+
f"letter-spacing:.03em'>{txt}</div>"
|
| 1268 |
+
f"<hr style='border:none;border-top:1px solid #1e1e36;margin:0 0 5px'/>")
|
| 1269 |
+
|
| 1270 |
+
|
| 1271 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1272 |
+
# Script
|
| 1273 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1274 |
+
|
| 1275 |
+
class CFGPromptForge(scripts.Script):
|
| 1276 |
+
|
| 1277 |
+
def title(self): return "CFG & Prompt Forge"
|
| 1278 |
+
def show(self, _): return scripts.AlwaysVisible
|
| 1279 |
+
|
| 1280 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1281 |
+
def ui(self, is_img2img):
|
| 1282 |
+
tab = "i2i" if is_img2img else "t2i"
|
| 1283 |
+
|
| 1284 |
+
with gr.Accordion("π§ CFG & Prompt Forge", open=False,
|
| 1285 |
+
elem_id=f"cpf_acc_{tab}"):
|
| 1286 |
+
|
| 1287 |
+
enabled = gr.Checkbox(label="β¦ Enable CFG & Prompt Forge",
|
| 1288 |
+
value=False, elem_id=f"cpf_en_{tab}")
|
| 1289 |
+
|
| 1290 |
+
# ββ 1. CORE OPTS ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1291 |
+
gr.HTML(_h("π Core Sampling Parameters"))
|
| 1292 |
+
with gr.Row():
|
| 1293 |
+
skip_early = gr.Slider(
|
| 1294 |
+
label="Skip Early CFG", minimum=0.0, maximum=1.0,
|
| 1295 |
+
step=0.01, value=0.0,
|
| 1296 |
+
info="Fraction of early steps where negative prompt is ignored.")
|
| 1297 |
+
word_wrap = gr.Slider(
|
| 1298 |
+
label="Prompt Word-Wrap Limit (tokens)",
|
| 1299 |
+
minimum=0, maximum=74, step=1, value=20,
|
| 1300 |
+
info="CLIP 75-token chunk backtrack threshold.")
|
| 1301 |
+
with gr.Row():
|
| 1302 |
+
ngms = gr.Slider(
|
| 1303 |
+
label="Neg. Guidance Min Ο (NGMS)",
|
| 1304 |
+
minimum=0.0, maximum=15.0, step=0.01, value=0.0,
|
| 1305 |
+
info="Skip negative prompt when Ο < this. Try 0.5β1.0.")
|
| 1306 |
+
ngms_all = gr.Checkbox(
|
| 1307 |
+
label="NGMS: skip every matching step",
|
| 1308 |
+
value=False)
|
| 1309 |
+
|
| 1310 |
+
# ββ 2. STEP RANGES ββββββββββββββββββββββββββββββββββββββββββββ
|
| 1311 |
+
gr.HTML(_h("π CFG Step Ranges (DyCFG)"))
|
| 1312 |
+
ranges_en = gr.Checkbox(label="Enable Step Ranges", value=False)
|
| 1313 |
+
range_rows = []
|
| 1314 |
+
for idx in range(3):
|
| 1315 |
+
with gr.Row():
|
| 1316 |
+
rs = gr.Number(label=f"#{idx+1} Start", value=0, precision=0, minimum=0, maximum=500)
|
| 1317 |
+
re = gr.Number(label="End (0=last)", value=0, precision=0, minimum=0, maximum=500)
|
| 1318 |
+
rv = gr.Number(label="CFG scale", value=7.0, precision=2, minimum=1.0, maximum=100.0)
|
| 1319 |
+
ri = gr.Radio(choices=["Default","Linear","Fixed"], value="Default", label="Interp")
|
| 1320 |
+
range_rows.extend([rs, re, rv, ri])
|
| 1321 |
+
|
| 1322 |
+
# ββ 3. SCHEDULE ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1323 |
+
gr.HTML(_h("π CFG Schedule"))
|
| 1324 |
+
sched_en = gr.Checkbox(label="Enable Schedule", value=False)
|
| 1325 |
+
with gr.Row():
|
| 1326 |
+
sched_type = gr.Dropdown(label="Shape",
|
| 1327 |
+
choices=list(_SCHED.values()), value=_SCHED["off"])
|
| 1328 |
+
sched_min = gr.Slider(label="Min Γ", minimum=0.0, maximum=2.0, step=0.05, value=0.5)
|
| 1329 |
+
sched_max = gr.Slider(label="Max Γ", minimum=0.5, maximum=4.0, step=0.05, value=1.5)
|
| 1330 |
+
sched_svg = gr.HTML(_make_svg("off", 0.5, 1.5))
|
| 1331 |
+
def _upd(lbl, mn, mx): return _make_svg(_L2K.get(lbl, "off"), mn, mx)
|
| 1332 |
+
sched_type.change(_upd, [sched_type, sched_min, sched_max], sched_svg)
|
| 1333 |
+
sched_min .change(_upd, [sched_type, sched_min, sched_max], sched_svg)
|
| 1334 |
+
sched_max .change(_upd, [sched_type, sched_min, sched_max], sched_svg)
|
| 1335 |
+
|
| 1336 |
+
# ββ 4. BOOST + JITTER βββββββββββββββββββββββββββββββββββββββββ
|
| 1337 |
+
gr.HTML(_h("π End-step Boost & π² Jitter"))
|
| 1338 |
+
with gr.Row():
|
| 1339 |
+
boost_en = gr.Checkbox(label="Enable End Boost", value=False)
|
| 1340 |
+
boost_from = gr.Slider(label="Boost From (step fraction)",
|
| 1341 |
+
minimum=0.5, maximum=0.95, step=0.01, value=0.80,
|
| 1342 |
+
info="e.g. 0.80 = activates in the last 20% of steps. "
|
| 1343 |
+
"Max capped at 0.95 β values near 1.0 may never fire "
|
| 1344 |
+
"on low step counts (with 20 steps, 0.95 fires only on "
|
| 1345 |
+
"the final step).")
|
| 1346 |
+
boost_mul = gr.Slider(label="Boost Γ",
|
| 1347 |
+
minimum=1.0, maximum=5.0, step=0.05, value=1.5)
|
| 1348 |
+
with gr.Row():
|
| 1349 |
+
jitter_en = gr.Checkbox(label="Enable Jitter", value=False)
|
| 1350 |
+
jitter_amt = gr.Slider(label="Jitter Amount (relative Β±)",
|
| 1351 |
+
minimum=0.0, maximum=0.5, step=0.01, value=0.10)
|
| 1352 |
+
|
| 1353 |
+
# ββ 5. AUTO CFG ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1354 |
+
gr.HTML(_h("π€ Auto CFG (ComfyUI-AutomaticCFG)"))
|
| 1355 |
+
with gr.Row():
|
| 1356 |
+
autocfg_en = gr.Checkbox(label="Enable Auto CFG", value=False)
|
| 1357 |
+
autocfg_method = gr.Dropdown(label="Method",
|
| 1358 |
+
choices=["hard","soft","hard_squared","range"], value="hard")
|
| 1359 |
+
with gr.Row():
|
| 1360 |
+
autocfg_ref = gr.Slider(label="Reference CFG",
|
| 1361 |
+
minimum=1.0, maximum=30.0, step=0.5, value=8.0)
|
| 1362 |
+
autocfg_topk = gr.Slider(label="Top-k fraction",
|
| 1363 |
+
minimum=0.05, maximum=0.5, step=0.01, value=0.25)
|
| 1364 |
+
|
| 1365 |
+
# ββ 6. LERP UNCOND ββββββββββββββββββββββββββββββββββββββββββββ
|
| 1366 |
+
gr.HTML(_h("π Uncond Lerp (AutoCFG / pre_cfg)"))
|
| 1367 |
+
with gr.Row():
|
| 1368 |
+
lerp_en = gr.Checkbox(label="Enable Uncond Lerp", value=False)
|
| 1369 |
+
lerp_str = gr.Slider(label="Lerp Strength",
|
| 1370 |
+
minimum=0.0, maximum=3.0, step=0.05, value=1.0,
|
| 1371 |
+
info="1.0 = no change. < 1 pulls uncond toward cond.")
|
| 1372 |
+
with gr.Row():
|
| 1373 |
+
lerp_sched_curve = gr.Dropdown(label="Schedule",
|
| 1374 |
+
choices=_CURVES, value="Off",
|
| 1375 |
+
info="Curve shape over steps.")
|
| 1376 |
+
lerp_sched_min = gr.Slider(label="Min Lerp",
|
| 1377 |
+
minimum=0.0, maximum=3.0, step=0.05, value=0.5,
|
| 1378 |
+
info="Lerp strength reached at the curve minimum.")
|
| 1379 |
+
|
| 1380 |
+
# ββ 7. SKIMMED CFG ββββββββββββββββββββββββββββββββββββββββββββ
|
| 1381 |
+
gr.HTML(_h("π¬ Skimmed CFG (sd-webui-skimmed_cfg)"))
|
| 1382 |
+
with gr.Row():
|
| 1383 |
+
skim_en = gr.Checkbox(label="Enable Skimmed CFG", value=False)
|
| 1384 |
+
skim_mode = gr.Radio(choices=["Classic","Smooth"], value="Classic", label="Mode")
|
| 1385 |
+
with gr.Row():
|
| 1386 |
+
skim_scale = gr.Slider(label="Skimming Scale",
|
| 1387 |
+
minimum=1.0, maximum=20.0, step=0.5, value=7.0)
|
| 1388 |
+
skim_flip = gr.Checkbox(label="Disable Flipping Filter (Classic only)", value=False)
|
| 1389 |
+
with gr.Row():
|
| 1390 |
+
skim_sched_curve = gr.Dropdown(label="Schedule",
|
| 1391 |
+
choices=_CURVES, value="Off",
|
| 1392 |
+
info="Curve shape over steps.")
|
| 1393 |
+
skim_sched_min = gr.Slider(label="Min Skim Scale",
|
| 1394 |
+
minimum=1.0, maximum=20.0, step=0.5, value=1.0,
|
| 1395 |
+
info="Skim scale reached at the curve minimum.")
|
| 1396 |
+
with gr.Row():
|
| 1397 |
+
skim_target = gr.Radio(
|
| 1398 |
+
choices=["Uncond", "Cond β Uncond", "Both"],
|
| 1399 |
+
value="Uncond", label="Classic Target",
|
| 1400 |
+
info="'Uncond' β modify uncond directly (moves uncond toward cond). "
|
| 1401 |
+
"'Cond β Uncond' β modify cond, put into uncond (original skimmed-CFG: zero guidance at mask). "
|
| 1402 |
+
"'Both' β modify both embeddings. Smooth mode always targets the branch selected via Apply-to flags.")
|
| 1403 |
+
|
| 1404 |
+
# ββ 8. NRS ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1405 |
+
gr.HTML(_h("β‘ Guidance Geometry β NRS v2 (NRS Kohaku Enhanced)"))
|
| 1406 |
+
gr.HTML(
|
| 1407 |
+
"<p style='color:#666;font-size:.82em;margin:0 0 5px'>"
|
| 1408 |
+
"Proper sigma-aware NRS on latent noise predictions. "
|
| 1409 |
+
"Includes <b>Vector Softcap</b> (prevent oversteering), "
|
| 1410 |
+
"<b>Midpoint Refinement</b> (Kohaku RK2), "
|
| 1411 |
+
"<b>TCFG</b> (SVD tangential damping), "
|
| 1412 |
+
"<b>Disagreement Gate</b> (adaptive intensity) and "
|
| 1413 |
+
"<b>Inter-step smoothing</b> (Momentum / EMA).</p>"
|
| 1414 |
+
)
|
| 1415 |
+
with gr.Row():
|
| 1416 |
+
nrs_en = gr.Checkbox(label="Enable NRS Geometry", value=False)
|
| 1417 |
+
nrs_skew = gr.Slider(label="Skew", minimum=-10.0, maximum=10.0, step=0.1, value=2.0)
|
| 1418 |
+
nrs_stretch = gr.Slider(label="Stretch", minimum=-10.0, maximum=10.0, step=0.1, value=5.0)
|
| 1419 |
+
nrs_squash = gr.Slider(label="Squash", minimum=0.0, maximum=1.0, step=0.05, value=0.75)
|
| 1420 |
+
with gr.Row():
|
| 1421 |
+
nrs_sched_curve = gr.Dropdown(label="Schedule",
|
| 1422 |
+
choices=_CURVES, value="Off",
|
| 1423 |
+
info="Curve shape for Skew/Stretch over steps.")
|
| 1424 |
+
nrs_sched_min_skew = gr.Slider(label="Min Skew",
|
| 1425 |
+
minimum=-10.0, maximum=10.0, step=0.1, value=0.0,
|
| 1426 |
+
info="Skew value at the curve minimum.")
|
| 1427 |
+
nrs_sched_min_stretch = gr.Slider(label="Min Stretch",
|
| 1428 |
+
minimum=-10.0, maximum=10.0, step=0.1, value=0.0,
|
| 1429 |
+
info="Stretch value at the curve minimum.")
|
| 1430 |
+
|
| 1431 |
+
gr.HTML("<b style='color:#8ab8e8'>Individual Step Control (NRS-style Range1/Range2 + Lock)</b>"
|
| 1432 |
+
"<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
|
| 1433 |
+
"Override Skew/Stretch with specific values in one or two step ranges. "
|
| 1434 |
+
"Lock after End: once the last range end is crossed, NRS stays disabled "
|
| 1435 |
+
"for the rest of the generation (including HR pass).</p>")
|
| 1436 |
+
with gr.Row():
|
| 1437 |
+
nrs_sc_enabled = gr.Checkbox(label="Enable Step Control", value=False)
|
| 1438 |
+
nrs_sc_lock = gr.Checkbox(label="Lock after End", value=False,
|
| 1439 |
+
info="Once past last range end, keep NRS disabled permanently.")
|
| 1440 |
+
with gr.Row():
|
| 1441 |
+
nrs_sc_r1_start = gr.Number(label="R1 Start", value=0, precision=0, minimum=0, maximum=500)
|
| 1442 |
+
nrs_sc_r1_end = gr.Number(label="R1 End (0=last)", value=0, precision=0, minimum=0, maximum=500)
|
| 1443 |
+
nrs_sc_r1_skew = gr.Slider(label="R1 Skew", minimum=-10.0, maximum=10.0, step=0.1, value=2.0)
|
| 1444 |
+
nrs_sc_r1_stretch = gr.Slider(label="R1 Stretch", minimum=-10.0, maximum=10.0, step=0.1, value=5.0)
|
| 1445 |
+
with gr.Row():
|
| 1446 |
+
nrs_sc_r2_start = gr.Number(label="R2 Start", value=0, precision=0, minimum=0, maximum=500)
|
| 1447 |
+
nrs_sc_r2_end = gr.Number(label="R2 End (0=last)", value=0, precision=0, minimum=0, maximum=500)
|
| 1448 |
+
nrs_sc_r2_skew = gr.Slider(label="R2 Skew", minimum=-10.0, maximum=10.0, step=0.1, value=1.0)
|
| 1449 |
+
nrs_sc_r2_stretch = gr.Slider(label="R2 Stretch", minimum=-10.0, maximum=10.0, step=0.1, value=2.0)
|
| 1450 |
+
|
| 1451 |
+
gr.HTML("<b style='color:#8ab8e8'>Vector Softcap</b>"
|
| 1452 |
+
"<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
|
| 1453 |
+
"tanh-based direction-preserving magnitude compression. "
|
| 1454 |
+
"Prevents oversteering at high Skew/Stretch values. 0 = disabled.</p>")
|
| 1455 |
+
with gr.Row():
|
| 1456 |
+
nrs_softcap = gr.Slider(label="Softcap Strength",
|
| 1457 |
+
minimum=0.0, maximum=5.0, step=0.1, value=0.0,
|
| 1458 |
+
info="0 = off. Try 1.0β2.0 as a starting point.")
|
| 1459 |
+
nrs_softcap_mode = gr.Dropdown(label="Softcap Mode",
|
| 1460 |
+
choices=["Per Sample", "Per Channel", "Global Batch"],
|
| 1461 |
+
value="Per Sample",
|
| 1462 |
+
info="Per Sample: recommended. Per Channel: best with high-detail. "
|
| 1463 |
+
"Global Batch: legacy behaviour.")
|
| 1464 |
+
|
| 1465 |
+
gr.HTML("<b style='color:#8ab8e8'>Midpoint Refinement (Kohaku RK2)</b>"
|
| 1466 |
+
"<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
|
| 1467 |
+
"Runge-Kutta 2nd-order midpoint correction β computes NRS twice "
|
| 1468 |
+
"and averages, similar to Kohaku sampler logic. 0 = disabled.</p>")
|
| 1469 |
+
with gr.Row():
|
| 1470 |
+
nrs_midpoint = gr.Slider(label="Midpoint Blend",
|
| 1471 |
+
minimum=0.0, maximum=1.0, step=0.05, value=0.0,
|
| 1472 |
+
info="How far to shift toward midpoint. Try 0.3β0.6.")
|
| 1473 |
+
nrs_midpoint_mode = gr.Dropdown(label="Midpoint Mode",
|
| 1474 |
+
choices=["Classic", "Conservative", "Directional Midpoint"],
|
| 1475 |
+
value="Classic",
|
| 1476 |
+
info="Classic: direct average. Conservative: soften 2nd-pass geometry. "
|
| 1477 |
+
"Directional: project correction onto base direction.")
|
| 1478 |
+
nrs_midpoint_fh = gr.Checkbox(
|
| 1479 |
+
label="First Half Only",
|
| 1480 |
+
value=True,
|
| 1481 |
+
info="Only refine during first half of steps (structure phase). "
|
| 1482 |
+
"Disable to refine all steps.")
|
| 1483 |
+
|
| 1484 |
+
gr.HTML("<b style='color:#8ab8e8'>TCFG β Tangential Damping (arxiv 2503.18137)</b>"
|
| 1485 |
+
"<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
|
| 1486 |
+
"SVD-based filter: removes uncond components that are tangential "
|
| 1487 |
+
"(misaligned) to cond. Reduces CFG manifold drift. No extra UNet calls.</p>")
|
| 1488 |
+
tcfg_en = gr.Checkbox(label="Enable TCFG", value=False)
|
| 1489 |
+
|
| 1490 |
+
gr.HTML("<b style='color:#8ab8e8'>Disagreement Gate</b>"
|
| 1491 |
+
"<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
|
| 1492 |
+
"Reduces effective Skew/Stretch when cond and uncond are already "
|
| 1493 |
+
"similar β avoids pointless NRS on well-aligned guidance.</p>")
|
| 1494 |
+
with gr.Row():
|
| 1495 |
+
dis_en = gr.Checkbox(label="Enable Disagreement Gate", value=False)
|
| 1496 |
+
dis_strength = gr.Slider(label="Gate Floor",
|
| 1497 |
+
minimum=0.0, maximum=1.0, step=0.05, value=0.5,
|
| 1498 |
+
info="Minimum gate at zero disagreement. 0 = fully blocks NRS when "
|
| 1499 |
+
"condβuncond. 1 = gate has no effect.")
|
| 1500 |
+
dis_threshold = gr.Slider(label="Gate Threshold",
|
| 1501 |
+
minimum=0.05, maximum=1.0, step=0.05, value=0.3,
|
| 1502 |
+
info="Disagreement level at which full skew/stretch is reached.")
|
| 1503 |
+
dis_metric = gr.Radio(choices=["Cosine", "L2"],
|
| 1504 |
+
value="Cosine", label="Metric",
|
| 1505 |
+
info="Cosine: direction-based. L2: magnitude-based.")
|
| 1506 |
+
|
| 1507 |
+
gr.HTML("<b style='color:#8ab8e8'>Inter-step Smoothing</b>"
|
| 1508 |
+
"<p style='color:#666;font-size:.81em;margin:2px 0 4px'>"
|
| 1509 |
+
"<b>Momentum</b>: RES/Clybius velocity blend between steps β "
|
| 1510 |
+
"smooths abrupt NRS jumps.<br>"
|
| 1511 |
+
"<b>EMA Delta</b>: exponential moving average of the step update vector β "
|
| 1512 |
+
"less temporal blurring than momentum.</p>")
|
| 1513 |
+
with gr.Row():
|
| 1514 |
+
interp_mode = gr.Radio(choices=["None", "Momentum", "EMA Delta"],
|
| 1515 |
+
value="None", label="Inter-step Mode")
|
| 1516 |
+
interp_mom = gr.Slider(label="Momentum",
|
| 1517 |
+
minimum=0.0, maximum=0.99, step=0.01, value=0.0,
|
| 1518 |
+
info="Higher = more temporal smoothing. 0.5β0.8 typical.")
|
| 1519 |
+
interp_ema = gr.Slider(label="EMA Alpha",
|
| 1520 |
+
minimum=0.01, maximum=0.99, step=0.01, value=0.5,
|
| 1521 |
+
info="Blend factor toward current step. 1.0 = no smoothing.")
|
| 1522 |
+
|
| 1523 |
+
# ββ 9. POST-CFG βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1524 |
+
gr.HTML(_h("π Post-CFG Shaping"))
|
| 1525 |
+
with gr.Row():
|
| 1526 |
+
reinhard_en = gr.Checkbox(label="Enable Reinhard Tonemap", value=False)
|
| 1527 |
+
reinhard_ref = gr.Slider(label="Reinhard Ref.",
|
| 1528 |
+
minimum=0.5, maximum=16.0, step=0.5, value=4.0)
|
| 1529 |
+
with gr.Row():
|
| 1530 |
+
rescale_en = gr.Checkbox(label="Enable Rescale CFG", value=False)
|
| 1531 |
+
rescale_mult = gr.Slider(label="Rescale Strength",
|
| 1532 |
+
minimum=0.0, maximum=1.0, step=0.05, value=0.7)
|
| 1533 |
+
with gr.Row():
|
| 1534 |
+
heuristic_en = gr.Checkbox(label="Enable Heuristic CFG", value=False)
|
| 1535 |
+
heuristic_cfg = gr.Slider(label="Heuristic Ref. CFG",
|
| 1536 |
+
minimum=1.0, maximum=20.0, step=0.5, value=5.0)
|
| 1537 |
+
heuristic_hstart = gr.Slider(label="Heuristic Start Step",
|
| 1538 |
+
minimum=0.0, maximum=1.0, step=0.01, value=0.0,
|
| 1539 |
+
info="Fraction of steps before heuristic activates. 0 = from step 0.")
|
| 1540 |
+
with gr.Row():
|
| 1541 |
+
heuristic_sched_curve = gr.Dropdown(label="Schedule",
|
| 1542 |
+
choices=_CURVES, value="Off",
|
| 1543 |
+
info="Curve shape for Heuristic CFG over steps.")
|
| 1544 |
+
heuristic_sched_min = gr.Slider(label="Min Heuristic CFG",
|
| 1545 |
+
minimum=0.5, maximum=20.0, step=0.5, value=1.0,
|
| 1546 |
+
info="Heuristic CFG at the curve minimum.")
|
| 1547 |
+
with gr.Row():
|
| 1548 |
+
lir_en = gr.Checkbox(label="Enable Latent Intensity Rescale", value=False)
|
| 1549 |
+
lir_cfg = gr.Slider(label="LIR Target CFG",
|
| 1550 |
+
minimum=1.0, maximum=30.0, step=0.5, value=8.0)
|
| 1551 |
+
lir_method = gr.Dropdown(label="LIR Method",
|
| 1552 |
+
choices=["hard","soft","hard_squared","range"], value="hard")
|
| 1553 |
+
lir_topk = gr.Slider(label="LIR Top-k",
|
| 1554 |
+
minimum=0.05, maximum=0.5, step=0.01, value=0.25)
|
| 1555 |
+
mean_en = gr.Checkbox(label="Enable Subtract Latent Mean", value=False)
|
| 1556 |
+
|
| 1557 |
+
# ββ 10. SEGA-INSPIRED ACTIVATION WINDOW ββββββββββββββββββββββ
|
| 1558 |
+
gr.HTML(_h("π§ Feature Activation Window (SEGA v5.1)"))
|
| 1559 |
+
gr.HTML(
|
| 1560 |
+
"<p style='color:#666;font-size:.82em;margin:0 0 5px'>"
|
| 1561 |
+
"Controls <em>when</em> embedding-level features (Lerp Uncond, "
|
| 1562 |
+
"Skimmed CFG) are active within a generation, and <em>which</em> "
|
| 1563 |
+
"prompt branch they modify. Inspired by SEGA's warmup, guidance "
|
| 1564 |
+
"end step, and lock-after-end mechanism.<br><br>"
|
| 1565 |
+
"<b>Warmup Steps</b> β embedding features inactive for the first N steps "
|
| 1566 |
+
"(early steps are too noisy for fine guidance).<br>"
|
| 1567 |
+
"<b>End Step</b> β deactivate embedding features after this step "
|
| 1568 |
+
"(0 = keep active until the end).<br>"
|
| 1569 |
+
"<b>Lock after End</b> β once End Step is crossed the features stay "
|
| 1570 |
+
"off for the rest of the generation, including the HiRes pass.<br>"
|
| 1571 |
+
"<b>Apply to Positive / Negative</b> β choose which prompt branch "
|
| 1572 |
+
"Lerp Uncond and Skimmed CFG modify.</p>"
|
| 1573 |
+
)
|
| 1574 |
+
with gr.Row():
|
| 1575 |
+
warmup_steps = gr.Slider(label="Warmup Steps",
|
| 1576 |
+
minimum=0, maximum=50, step=1, value=0,
|
| 1577 |
+
info="Skip embedding features for first N steps.")
|
| 1578 |
+
end_step = gr.Slider(label="End Step (0 = no limit)",
|
| 1579 |
+
minimum=0, maximum=150, step=1, value=0,
|
| 1580 |
+
info="Deactivate embedding features after this step.")
|
| 1581 |
+
lock_after_end = gr.Checkbox(
|
| 1582 |
+
label="Lock after End (stays off for HR pass too)",
|
| 1583 |
+
value=False)
|
| 1584 |
+
with gr.Row():
|
| 1585 |
+
apply_to_neg = gr.Checkbox(
|
| 1586 |
+
label="Apply to Negative (text_uncond)",
|
| 1587 |
+
value=True,
|
| 1588 |
+
info="Lerp Uncond and Skimmed CFG modify the negative/unconditional "
|
| 1589 |
+
"prompt embeddings.")
|
| 1590 |
+
apply_to_pos = gr.Checkbox(
|
| 1591 |
+
label="Apply to Positive (text_cond)",
|
| 1592 |
+
value=False,
|
| 1593 |
+
info="Lerp Uncond and Skimmed CFG also modify the positive/conditional "
|
| 1594 |
+
"prompt embeddings.")
|
| 1595 |
+
|
| 1596 |
+
# ββ 11. PASS MODE + HIRES ββββββββββββββββββββββββββββββββββββ
|
| 1597 |
+
gr.HTML(_h("πΌ Pass Mode & HiRes Fix CFG"))
|
| 1598 |
+
gr.HTML(
|
| 1599 |
+
"<p style='color:#666;font-size:.82em;margin:0 0 5px'>"
|
| 1600 |
+
"<b>Pass Mode</b> β controls during which denoising pass CFG "
|
| 1601 |
+
"features (schedule, boost, jitter, step-ranges, embedding features) "
|
| 1602 |
+
"are active. Useful when you want aggressive shaping only on the "
|
| 1603 |
+
"base pass and clean defaults on the HiRes upscale pass, or vice versa.<br><br>"
|
| 1604 |
+
"<b>HR CFG Override</b> β separately swap <code>p.cfg_scale</code> "
|
| 1605 |
+
"for the HiRes fix second pass (txt2img only). Works independently "
|
| 1606 |
+
"of Pass Mode.</p>"
|
| 1607 |
+
)
|
| 1608 |
+
with gr.Row():
|
| 1609 |
+
pass_mode = gr.Radio(
|
| 1610 |
+
choices=["Both passes", "First pass only", "HR pass only"],
|
| 1611 |
+
value="Both passes",
|
| 1612 |
+
label="Pass Mode",
|
| 1613 |
+
elem_id=f"cpf_passmode_{tab}")
|
| 1614 |
+
with gr.Row():
|
| 1615 |
+
hr_en = gr.Checkbox(label="Override HR CFG", value=False)
|
| 1616 |
+
hr_cfg = gr.Slider(label="CFG for HR pass",
|
| 1617 |
+
minimum=1.0, maximum=30.0, step=0.1, value=1.0)
|
| 1618 |
+
|
| 1619 |
+
self.infotext_fields = [
|
| 1620 |
+
(skip_early, "CPF SkipEarlyCFG"),
|
| 1621 |
+
(ngms, "CPF NGMS"),
|
| 1622 |
+
(hr_cfg, "CPF HR_CFG"),
|
| 1623 |
+
]
|
| 1624 |
+
|
| 1625 |
+
# !! ORDER must EXACTLY match _parse_args() !!
|
| 1626 |
+
return [
|
| 1627 |
+
enabled,
|
| 1628 |
+
skip_early, word_wrap, ngms, ngms_all,
|
| 1629 |
+
ranges_en, *range_rows, # 1 + 12
|
| 1630 |
+
sched_en, sched_type, sched_min, sched_max,
|
| 1631 |
+
boost_en, boost_from, boost_mul,
|
| 1632 |
+
jitter_en, jitter_amt,
|
| 1633 |
+
autocfg_en, autocfg_method, autocfg_ref, autocfg_topk,
|
| 1634 |
+
lerp_en, lerp_str,
|
| 1635 |
+
lerp_sched_curve, lerp_sched_min,
|
| 1636 |
+
skim_en, skim_mode, skim_scale, skim_flip, skim_target,
|
| 1637 |
+
skim_sched_curve, skim_sched_min,
|
| 1638 |
+
nrs_en, nrs_skew, nrs_stretch, nrs_squash,
|
| 1639 |
+
nrs_sched_curve, nrs_sched_min_skew, nrs_sched_min_stretch,
|
| 1640 |
+
nrs_sc_enabled, nrs_sc_lock,
|
| 1641 |
+
nrs_sc_r1_start, nrs_sc_r1_end, nrs_sc_r1_skew, nrs_sc_r1_stretch,
|
| 1642 |
+
nrs_sc_r2_start, nrs_sc_r2_end, nrs_sc_r2_skew, nrs_sc_r2_stretch,
|
| 1643 |
+
nrs_softcap, nrs_softcap_mode,
|
| 1644 |
+
nrs_midpoint, nrs_midpoint_mode, nrs_midpoint_fh,
|
| 1645 |
+
tcfg_en,
|
| 1646 |
+
dis_en, dis_strength, dis_threshold, dis_metric,
|
| 1647 |
+
interp_mode, interp_mom, interp_ema,
|
| 1648 |
+
reinhard_en, reinhard_ref,
|
| 1649 |
+
rescale_en, rescale_mult,
|
| 1650 |
+
heuristic_en, heuristic_cfg, heuristic_hstart,
|
| 1651 |
+
heuristic_sched_curve, heuristic_sched_min,
|
| 1652 |
+
lir_en, lir_cfg, lir_method, lir_topk,
|
| 1653 |
+
mean_en,
|
| 1654 |
+
warmup_steps, end_step, lock_after_end, # β SEGA window
|
| 1655 |
+
apply_to_neg, apply_to_pos, # β SEGA branch flags
|
| 1656 |
+
pass_mode, # β pass mode
|
| 1657 |
+
hr_en, hr_cfg,
|
| 1658 |
+
]
|
| 1659 |
+
|
| 1660 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1661 |
+
@staticmethod
|
| 1662 |
+
def _parse_args(raw):
|
| 1663 |
+
it = iter(raw); n = lambda: next(it)
|
| 1664 |
+
return dict(
|
| 1665 |
+
enabled = n(),
|
| 1666 |
+
skip_early = n(), word_wrap = n(), ngms = n(), ngms_all = n(),
|
| 1667 |
+
ranges_en = n(),
|
| 1668 |
+
ranges = [(int(n()), int(n()), float(n()), n()) for _ in range(3)],
|
| 1669 |
+
sched_en = n(), sched_type = n(), sched_min = n(), sched_max = n(),
|
| 1670 |
+
boost_en = n(), boost_from = n(), boost_mul = n(),
|
| 1671 |
+
jitter_en = n(), jitter_amt = n(),
|
| 1672 |
+
autocfg_en = n(), autocfg_method = n(), autocfg_ref = n(), autocfg_topk = n(),
|
| 1673 |
+
lerp_en = n(), lerp_str = n(),
|
| 1674 |
+
lerp_sched_curve = n(), lerp_sched_min = n(),
|
| 1675 |
+
skim_en = n(), skim_mode = n(), skim_scale = n(), skim_flip = n(), skim_target = n(),
|
| 1676 |
+
skim_sched_curve = n(), skim_sched_min = n(),
|
| 1677 |
+
nrs_en = n(), nrs_skew = n(), nrs_stretch = n(), nrs_squash = n(),
|
| 1678 |
+
nrs_sched_curve = n(), nrs_sched_min_skew = n(), nrs_sched_min_stretch = n(),
|
| 1679 |
+
nrs_sc_enabled = n(), nrs_sc_lock = n(),
|
| 1680 |
+
nrs_sc_r1_start = n(), nrs_sc_r1_end = n(), nrs_sc_r1_skew = n(), nrs_sc_r1_stretch = n(),
|
| 1681 |
+
nrs_sc_r2_start = n(), nrs_sc_r2_end = n(), nrs_sc_r2_skew = n(), nrs_sc_r2_stretch = n(),
|
| 1682 |
+
nrs_softcap = n(), nrs_softcap_mode = n(),
|
| 1683 |
+
nrs_midpoint = n(), nrs_midpoint_mode = n(), nrs_midpoint_fh = n(),
|
| 1684 |
+
tcfg_en = n(),
|
| 1685 |
+
dis_en = n(), dis_strength = n(), dis_threshold = n(), dis_metric = n(),
|
| 1686 |
+
interp_mode = n(), interp_mom = n(), interp_ema = n(),
|
| 1687 |
+
reinhard_en = n(), reinhard_ref = n(),
|
| 1688 |
+
rescale_en = n(), rescale_mult = n(),
|
| 1689 |
+
heuristic_en = n(), heuristic_cfg = n(), heuristic_hstart = n(),
|
| 1690 |
+
heuristic_sched_curve = n(), heuristic_sched_min = n(),
|
| 1691 |
+
lir_en = n(), lir_cfg = n(), lir_method = n(), lir_topk = n(),
|
| 1692 |
+
mean_en = n(),
|
| 1693 |
+
warmup_steps = n(), end_step = n(), lock_after_end = n(),
|
| 1694 |
+
apply_to_neg = n(), apply_to_pos = n(),
|
| 1695 |
+
pass_mode = n(),
|
| 1696 |
+
hr_en = n(), hr_cfg = n(),
|
| 1697 |
+
)
|
| 1698 |
+
|
| 1699 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1700 |
+
def before_process(self, p, *args):
|
| 1701 |
+
global _fixup_registered
|
| 1702 |
+
_st.enabled = False
|
| 1703 |
+
try:
|
| 1704 |
+
v = self._parse_args(args)
|
| 1705 |
+
except StopIteration:
|
| 1706 |
+
return
|
| 1707 |
+
if not v["enabled"]:
|
| 1708 |
+
return
|
| 1709 |
+
|
| 1710 |
+
# ββ Register late fixup once per session ββββββββββββββββββββββββββββββ
|
| 1711 |
+
# Done here (not at import) so it's appended AFTER all other extensions'
|
| 1712 |
+
# import-time callbacks β guaranteed to fire last every step.
|
| 1713 |
+
if not _fixup_registered:
|
| 1714 |
+
script_callbacks.on_cfg_denoiser(_late_cond_fixup)
|
| 1715 |
+
_fixup_registered = True
|
| 1716 |
+
|
| 1717 |
+
p.override_settings["skip_early_cond"] = float(v["skip_early"])
|
| 1718 |
+
p.override_settings["comma_padding_backtrack"] = int (v["word_wrap"])
|
| 1719 |
+
p.override_settings["s_min_uncond"] = float(v["ngms"])
|
| 1720 |
+
p.override_settings["s_min_uncond_all"] = bool (v["ngms_all"])
|
| 1721 |
+
|
| 1722 |
+
_st.base_cfg = float(p.cfg_scale)
|
| 1723 |
+
_st.ranges_en = bool(v["ranges_en"])
|
| 1724 |
+
_st.range_scales = None
|
| 1725 |
+
if _st.ranges_en:
|
| 1726 |
+
active = [(s, e, val, i) for (s, e, val, i) in v["ranges"] if int(s) > 0 or int(e) > 0]
|
| 1727 |
+
if active:
|
| 1728 |
+
_st.range_scales = _build_ranges(p.steps, _st.base_cfg, active)
|
| 1729 |
+
|
| 1730 |
+
_st.sched_en = bool(v["sched_en"])
|
| 1731 |
+
_st.sched_k = _L2K.get(v["sched_type"], "off")
|
| 1732 |
+
_st.sched_min = float(v["sched_min"])
|
| 1733 |
+
_st.sched_max = float(v["sched_max"])
|
| 1734 |
+
|
| 1735 |
+
_st.boost_en = bool (v["boost_en"])
|
| 1736 |
+
_st.boost_from = float(v["boost_from"])
|
| 1737 |
+
_st.boost_mul = float(v["boost_mul"])
|
| 1738 |
+
_st.jitter_en = bool (v["jitter_en"])
|
| 1739 |
+
_st.jitter_amt = float(v["jitter_amt"])
|
| 1740 |
+
|
| 1741 |
+
_st.autocfg_en = bool (v["autocfg_en"])
|
| 1742 |
+
_st.autocfg_method = v["autocfg_method"]
|
| 1743 |
+
_st.autocfg_ref = float(v["autocfg_ref"])
|
| 1744 |
+
_st.autocfg_topk = float(v["autocfg_topk"])
|
| 1745 |
+
|
| 1746 |
+
_st.lerp_en = bool (v["lerp_en"])
|
| 1747 |
+
_st.lerp_str = float(v["lerp_str"])
|
| 1748 |
+
_st.lerp_sched_curve = v.get("lerp_sched_curve", "Off")
|
| 1749 |
+
_st.lerp_sched_min = float(v.get("lerp_sched_min", 0.5))
|
| 1750 |
+
|
| 1751 |
+
_st.skim_en = bool (v["skim_en"])
|
| 1752 |
+
_st.skim_mode = v["skim_mode"]
|
| 1753 |
+
_st.skim_scale = float(v["skim_scale"])
|
| 1754 |
+
_st.skim_flip = bool (v["skim_flip"])
|
| 1755 |
+
_st.skim_target = v["skim_target"]
|
| 1756 |
+
_st.skim_sched_curve = v.get("skim_sched_curve", "Off")
|
| 1757 |
+
_st.skim_sched_min = float(v.get("skim_sched_min", 1.0))
|
| 1758 |
+
|
| 1759 |
+
_st.nrs_en = bool (v["nrs_en"])
|
| 1760 |
+
_st.nrs_skew = float(v["nrs_skew"])
|
| 1761 |
+
_st.nrs_stretch = float(v["nrs_stretch"])
|
| 1762 |
+
_st.nrs_squash = float(v["nrs_squash"])
|
| 1763 |
+
_st.nrs_sched_curve = v.get("nrs_sched_curve", "Off")
|
| 1764 |
+
_st.nrs_sched_min_skew = float(v.get("nrs_sched_min_skew", 0.0))
|
| 1765 |
+
_st.nrs_sched_min_stretch = float(v.get("nrs_sched_min_stretch", 0.0))
|
| 1766 |
+
_st.nrs_sc_enabled = bool(v.get("nrs_sc_enabled", False))
|
| 1767 |
+
_st.nrs_sc_lock = bool(v.get("nrs_sc_lock", False))
|
| 1768 |
+
_st.nrs_sc_locked = False # always reset at generation start
|
| 1769 |
+
_st.nrs_sc_r1_start = int(v.get("nrs_sc_r1_start", 0))
|
| 1770 |
+
_st.nrs_sc_r1_end = int(v.get("nrs_sc_r1_end", 0))
|
| 1771 |
+
_st.nrs_sc_r1_skew = float(v.get("nrs_sc_r1_skew", 2.0))
|
| 1772 |
+
_st.nrs_sc_r1_stretch= float(v.get("nrs_sc_r1_stretch", 5.0))
|
| 1773 |
+
_st.nrs_sc_r2_start = int(v.get("nrs_sc_r2_start", 0))
|
| 1774 |
+
_st.nrs_sc_r2_end = int(v.get("nrs_sc_r2_end", 0))
|
| 1775 |
+
_st.nrs_sc_r2_skew = float(v.get("nrs_sc_r2_skew", 1.0))
|
| 1776 |
+
_st.nrs_sc_r2_stretch= float(v.get("nrs_sc_r2_stretch", 2.0))
|
| 1777 |
+
# NRS v2
|
| 1778 |
+
_st.nrs_softcap = float(v["nrs_softcap"])
|
| 1779 |
+
_st.nrs_softcap_mode = v["nrs_softcap_mode"]
|
| 1780 |
+
_st.nrs_midpoint = float(v["nrs_midpoint"])
|
| 1781 |
+
_st.nrs_midpoint_mode = v["nrs_midpoint_mode"]
|
| 1782 |
+
_st.nrs_midpoint_fh = bool (v["nrs_midpoint_fh"])
|
| 1783 |
+
# TCFG
|
| 1784 |
+
_st.tcfg_en = bool(v["tcfg_en"])
|
| 1785 |
+
# Disagreement Gate
|
| 1786 |
+
_st.dis_en = bool (v["dis_en"])
|
| 1787 |
+
_st.dis_strength = float(v["dis_strength"])
|
| 1788 |
+
_st.dis_threshold = float(v["dis_threshold"])
|
| 1789 |
+
_st.dis_metric = v["dis_metric"]
|
| 1790 |
+
# Inter-step smoothing
|
| 1791 |
+
_st.interp_mode = v["interp_mode"]
|
| 1792 |
+
_st.interp_mom = float(v["interp_mom"])
|
| 1793 |
+
_st.interp_ema = float(v["interp_ema"])
|
| 1794 |
+
# Reset inter-step runtime state at start of generation
|
| 1795 |
+
_st._prev_nrs = None
|
| 1796 |
+
_st._prev_vel = None
|
| 1797 |
+
_st._prev_delta = None
|
| 1798 |
+
# Reset context (will be populated by on_cfg_denoiser)
|
| 1799 |
+
_st.current_x = None
|
| 1800 |
+
_st.current_sigma = None
|
| 1801 |
+
|
| 1802 |
+
_st.reinhard_en = bool (v["reinhard_en"])
|
| 1803 |
+
_st.reinhard_ref = float(v["reinhard_ref"])
|
| 1804 |
+
_st.rescale_en = bool (v["rescale_en"])
|
| 1805 |
+
_st.rescale_mult = float(v["rescale_mult"])
|
| 1806 |
+
_st.heuristic_en = bool (v["heuristic_en"])
|
| 1807 |
+
_st.heuristic_cfg = float(v["heuristic_cfg"])
|
| 1808 |
+
_st.heuristic_hstart = float(v["heuristic_hstart"])
|
| 1809 |
+
_st.heuristic_sched_curve = v.get("heuristic_sched_curve", "Off")
|
| 1810 |
+
_st.heuristic_sched_min = float(v.get("heuristic_sched_min", 1.0))
|
| 1811 |
+
_st.lir_en = bool (v["lir_en"])
|
| 1812 |
+
_st.lir_cfg = float(v["lir_cfg"])
|
| 1813 |
+
_st.lir_method = v["lir_method"]
|
| 1814 |
+
_st.lir_topk = float(v["lir_topk"])
|
| 1815 |
+
_st.mean_en = bool (v["mean_en"])
|
| 1816 |
+
|
| 1817 |
+
# SEGA-inspired
|
| 1818 |
+
_st.warmup_steps = int (v["warmup_steps"])
|
| 1819 |
+
_st.end_step = int (v["end_step"])
|
| 1820 |
+
_st.lock_after_end = bool (v["lock_after_end"])
|
| 1821 |
+
_st.lock_triggered = False # always reset at generation start
|
| 1822 |
+
_st.apply_to_neg = bool (v["apply_to_neg"])
|
| 1823 |
+
_st.apply_to_pos = bool (v["apply_to_pos"])
|
| 1824 |
+
|
| 1825 |
+
# Pass mode
|
| 1826 |
+
_st.pass_mode = v["pass_mode"]
|
| 1827 |
+
_st.in_hr_pass = False # first pass starts here
|
| 1828 |
+
|
| 1829 |
+
_st.hr_en = bool (v["hr_en"])
|
| 1830 |
+
_st.hr_cfg = float(v["hr_cfg"])
|
| 1831 |
+
|
| 1832 |
+
needs_hook = any([_st.nrs_en, _st.autocfg_en, _st.reinhard_en,
|
| 1833 |
+
_st.rescale_en, _st.heuristic_en, _st.lir_en,
|
| 1834 |
+
_st.mean_en, _st.tcfg_en,
|
| 1835 |
+
_st.interp_mode != "None"])
|
| 1836 |
+
if needs_hook:
|
| 1837 |
+
_hook_combine()
|
| 1838 |
+
|
| 1839 |
+
_st.enabled = True
|
| 1840 |
+
|
| 1841 |
+
gp = p.extra_generation_params
|
| 1842 |
+
gp["CPF SkipEarlyCFG"] = v["skip_early"]
|
| 1843 |
+
gp["CPF NGMS"] = v["ngms"]
|
| 1844 |
+
if _st.ranges_en and _st.range_scales:
|
| 1845 |
+
gp["CPF StepRanges"] = str([(s,e,va,i) for s,e,va,i in v["ranges"] if s or e])
|
| 1846 |
+
if _st.sched_en and _st.sched_k != "off":
|
| 1847 |
+
gp["CPF Schedule"] = f"{_SCHED.get(_st.sched_k)} Γ{_st.sched_min:.2f}βΓ{_st.sched_max:.2f}"
|
| 1848 |
+
if _st.boost_en: gp["CPF Boost"] = f"Γ{_st.boost_mul:.2f} from {_st.boost_from*100:.0f}%"
|
| 1849 |
+
if _st.jitter_en: gp["CPF Jitter"] = f"Β±{_st.jitter_amt*100:.0f}%"
|
| 1850 |
+
if _st.autocfg_en: gp["CPF AutoCFG"] = f"{_st.autocfg_method} ref={_st.autocfg_ref}"
|
| 1851 |
+
if _st.lerp_en:
|
| 1852 |
+
gp["CPF LerpUncond"] = str(_st.lerp_str)
|
| 1853 |
+
if _st.lerp_sched_curve != "Off":
|
| 1854 |
+
gp["CPF LerpSched"] = f"{_st.lerp_sched_curve} min={_st.lerp_sched_min}"
|
| 1855 |
+
if _st.skim_en:
|
| 1856 |
+
gp["CPF Skim"] = f"{_st.skim_mode} s={_st.skim_scale}"
|
| 1857 |
+
if _st.skim_sched_curve != "Off":
|
| 1858 |
+
gp["CPF SkimSched"] = f"{_st.skim_sched_curve} min={_st.skim_sched_min}"
|
| 1859 |
+
if _st.nrs_en:
|
| 1860 |
+
gp["CPF NRS"] = f"skew={_st.nrs_skew} str={_st.nrs_stretch} sq={_st.nrs_squash}"
|
| 1861 |
+
if _st.nrs_sched_curve != "Off":
|
| 1862 |
+
gp["CPF NRSSched"] = f"{_st.nrs_sched_curve} minSkew={_st.nrs_sched_min_skew} minStr={_st.nrs_sched_min_stretch}"
|
| 1863 |
+
if _st.nrs_sc_enabled:
|
| 1864 |
+
gp["CPF NRSSC"] = ("R1:({}-{})sk={}st={} "
|
| 1865 |
+
"R2:({}-{})sk={}st={} lock={}".format(
|
| 1866 |
+
_st.nrs_sc_r1_start, _st.nrs_sc_r1_end,
|
| 1867 |
+
_st.nrs_sc_r1_skew, _st.nrs_sc_r1_stretch,
|
| 1868 |
+
_st.nrs_sc_r2_start, _st.nrs_sc_r2_end,
|
| 1869 |
+
_st.nrs_sc_r2_skew, _st.nrs_sc_r2_stretch,
|
| 1870 |
+
_st.nrs_sc_lock))
|
| 1871 |
+
if _st.reinhard_en: gp["CPF Reinhard"] = str(_st.reinhard_ref)
|
| 1872 |
+
if _st.rescale_en: gp["CPF Rescale"] = str(_st.rescale_mult)
|
| 1873 |
+
if _st.heuristic_en:
|
| 1874 |
+
gp["CPF Heuristic"] = str(_st.heuristic_cfg)
|
| 1875 |
+
if _st.heuristic_hstart > 0: gp["CPF Heuristic Start"] = str(_st.heuristic_hstart)
|
| 1876 |
+
if _st.heuristic_sched_curve != "Off":
|
| 1877 |
+
gp["CPF HeuristicSched"] = f"{_st.heuristic_sched_curve} min={_st.heuristic_sched_min}"
|
| 1878 |
+
if _st.lir_en: gp["CPF LIR"] = f"{_st.lir_method} cfg={_st.lir_cfg}"
|
| 1879 |
+
if _st.mean_en: gp["CPF SubtractMean"] = "1"
|
| 1880 |
+
if _st.warmup_steps or _st.end_step:
|
| 1881 |
+
gp["CPF Window"] = f"warmup={_st.warmup_steps} end={_st.end_step} lock={_st.lock_after_end}"
|
| 1882 |
+
if _st.apply_to_pos or not _st.apply_to_neg:
|
| 1883 |
+
gp["CPF Apply"] = f"pos={_st.apply_to_pos} neg={_st.apply_to_neg}"
|
| 1884 |
+
if _st.pass_mode != "Both passes":
|
| 1885 |
+
gp["CPF PassMode"] = _st.pass_mode
|
| 1886 |
+
|
| 1887 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 1888 |
+
def before_hr(self, p, *args):
|
| 1889 |
+
# Arm HR-pass flag β persists lock_triggered from first pass (SEGA lock)
|
| 1890 |
+
_st.in_hr_pass = True
|
| 1891 |
+
# Reset inter-step state: HR pass is a new sequence
|
| 1892 |
+
_st._prev_nrs = None
|
| 1893 |
+
_st._prev_vel = None
|
| 1894 |
+
_st._prev_delta = None
|
| 1895 |
+
# NOTE: nrs_sc_locked is NOT reset here β Lock after End persists across HR
|
| 1896 |
+
# pass, matching NRS behaviour. It is reset only in postprocess.
|
| 1897 |
+
|
| 1898 |
+
# HR CFG override (independent of pass_mode)
|
| 1899 |
+
if _st.enabled and _st.hr_en and isinstance(p, StableDiffusionProcessingTxt2Img):
|
| 1900 |
+
_st.hr_cfg_saved = p.cfg_scale
|
| 1901 |
+
p.cfg_scale = _st.hr_cfg
|
| 1902 |
+
p.extra_generation_params["CPF HR_CFG"] = _st.hr_cfg
|
| 1903 |
+
|
| 1904 |
+
def postprocess_image(self, p, pp, *args):
|
| 1905 |
+
# NOTE: intentionally NOT restoring hr_cfg_saved here.
|
| 1906 |
+
# postprocess_image is called for every image in a batch, so restoring
|
| 1907 |
+
# after the first image would leave the rest using base cfg instead of
|
| 1908 |
+
# hr_cfg. The restore is done once in postprocess() after all images.
|
| 1909 |
+
pass
|
| 1910 |
+
|
| 1911 |
+
def postprocess(self, p, processed, *args):
|
| 1912 |
+
_st.enabled = False
|
| 1913 |
+
_st.lock_triggered = False
|
| 1914 |
+
_st.nrs_sc_locked = False
|
| 1915 |
+
_st.in_hr_pass = False
|
| 1916 |
+
# Reset inter-step state
|
| 1917 |
+
_st._prev_nrs = None
|
| 1918 |
+
_st._prev_vel = None
|
| 1919 |
+
_st._prev_delta = None
|
| 1920 |
+
# Reset context
|
| 1921 |
+
_st.current_x = None
|
| 1922 |
+
_st.current_sigma = None
|
| 1923 |
+
_unhook_combine()
|
| 1924 |
+
if _st.hr_cfg_saved is not None:
|
| 1925 |
+
p.cfg_scale = _st.hr_cfg_saved
|
| 1926 |
+
_st.hr_cfg_saved = None
|