Upload 20 files
Browse files- neutral_prompt_patched/.gitignore +1 -0
- neutral_prompt_patched/LICENSE +21 -0
- neutral_prompt_patched/README.md +99 -0
- neutral_prompt_patched/lib_neutral_prompt/__init__.py +1 -0
- neutral_prompt_patched/lib_neutral_prompt/affine_transform.py +83 -0
- neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py +646 -0
- neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py +23 -0
- neutral_prompt_patched/lib_neutral_prompt/external_code/api.py +27 -0
- neutral_prompt_patched/lib_neutral_prompt/global_state.py +75 -0
- neutral_prompt_patched/lib_neutral_prompt/hijacker.py +34 -0
- neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py +382 -0
- neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py +86 -0
- neutral_prompt_patched/lib_neutral_prompt/ui.py +196 -0
- neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py +42 -0
- neutral_prompt_patched/scripts/neutral_prompt.py +94 -0
- neutral_prompt_patched/test/perp_parser/__init__.py +0 -0
- neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py +133 -0
- neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py +217 -0
- neutral_prompt_patched/test/perp_parser/test_basic_parser.py +122 -0
- neutral_prompt_patched/test/perp_parser/test_malicious_parser.py +182 -0
neutral_prompt_patched/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
neutral_prompt_patched/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 ljleb
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
neutral_prompt_patched/README.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sd-webui-neutral-prompt — Ultimate Edition
|
| 2 |
+
|
| 3 |
+
> **Unified merge of the main experimental branches.**
|
| 4 |
+
> Most features from all branches are merged and work simultaneously.
|
| 5 |
+
> Some `life`/dev-only tooling (debug visualisations, mask images) was intentionally not carried over — only the production-safe subset is included.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## What's inside
|
| 10 |
+
|
| 11 |
+
| Feature | Origin branch | Keyword / setting |
|
| 12 |
+
|---|---|---|
|
| 13 |
+
| Perpendicular projection | `main` | `AND_PERP` |
|
| 14 |
+
| Saliency-guided mask | `main` | `AND_SALT` |
|
| 15 |
+
| Semantic guidance top-k | `main` | `AND_TOPK` |
|
| 16 |
+
| CFG rescale (mean-preserving) | `main` | CFG rescale φ slider |
|
| 17 |
+
| XYZ-grid CFG rescale axis | `main` | XYZ-grid option |
|
| 18 |
+
| External API override | `main` / `export_rescale_factor` | `override_cfg_rescale()` |
|
| 19 |
+
| CFG rescale factor export | `export_rescale_factor` | `get_last_cfg_rescale_factor()` |
|
| 20 |
+
| Affine spatial transforms | `affine` | `ROTATE / SLIDE / SCALE / SHEAR` prefix |
|
| 21 |
+
| Soft alignment blend | `alignment_blend` | `AND_ALIGN_D_S` |
|
| 22 |
+
| Binary alignment mask | `alignment_mask` | `AND_MASK_ALIGN_D_S` |
|
| 23 |
+
| Sharpened saliency maps (k=20) | `life` (production-safe subset) | used automatically by `AND_SALT` |
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## Prompt syntax
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
positive prompt text
|
| 31 |
+
AND_PERP negative-direction text :0.8
|
| 32 |
+
AND_SALT competing concept :1.0
|
| 33 |
+
AND_TOPK fine detail tweak :0.5
|
| 34 |
+
AND_ALIGN_4_8 style blend :0.6
|
| 35 |
+
AND_MASK_ALIGN_4_8 structure-preserving style :0.6
|
| 36 |
+
AND_PERP ROTATE[0.125] rotated perpendicular concept :1.0
|
| 37 |
+
ROTATE[0.125] AND_PERP rotated perpendicular concept :1.0
|
| 38 |
+
AND_SALT SLIDE[0.05,0] spatially shifted concept :1.0
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Affine transform keywords
|
| 42 |
+
Affine keywords are supported in **either order** relative to the conciliation keyword:
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
AND_PERP ROTATE[0.125] vivid colors :0.8 ← affine after keyword
|
| 46 |
+
ROTATE[0.125] AND_PERP vivid colors :0.8 ← affine before keyword (both work)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
| Keyword | Parameter | Effect |
|
| 50 |
+
|---|---|---|
|
| 51 |
+
| `ROTATE[angle]` | angle in turns (0–1) | Rotate latent contribution |
|
| 52 |
+
| `SLIDE[x,y]` | normalised offset | Translate latent contribution |
|
| 53 |
+
| `SCALE[x,y]` | scale factors | Scale latent contribution |
|
| 54 |
+
| `SHEAR[x,y]` | shear in turns | Shear latent contribution |
|
| 55 |
+
|
| 56 |
+
Multiple transforms can be chained: `ROTATE[0.125] SLIDE[0.1,0]`
|
| 57 |
+
|
| 58 |
+
### AND_ALIGN_D_S / AND_MASK_ALIGN_D_S
|
| 59 |
+
- **D** = detail kernel size (small → fine detail)
|
| 60 |
+
- **S** = structure kernel size (large → global composition)
|
| 61 |
+
- The child prompt is blended in proportionally to how much it alters detail *without* changing structure.
|
| 62 |
+
- `AND_ALIGN` uses a soft weight; `AND_MASK_ALIGN` uses a binary 0/1 mask.
|
| 63 |
+
- Supported range: D, S ∈ [2, 32], D ≠ S.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## External API
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
# In another extension or script:
|
| 71 |
+
from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
|
| 72 |
+
|
| 73 |
+
override_cfg_rescale(0.7) # override for next step only
|
| 74 |
+
factor = get_last_cfg_rescale_factor() # read after generation
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## Installation
|
| 80 |
+
|
| 81 |
+
1. Clone / copy this folder into `stable-diffusion-webui/extensions/`.
|
| 82 |
+
2. Restart the webui.
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## CFG Rescale φ
|
| 87 |
+
|
| 88 |
+
When set to a value > 0 the extension pulls the mean and standard deviation of
|
| 89 |
+
the CFG output towards those of the plain conditional prediction, in proportion to φ.
|
| 90 |
+
This reduces colour over-saturation at high CFG values without affecting
|
| 91 |
+
prompt adherence as much as simply lowering CFG.
|
| 92 |
+
|
| 93 |
+
The formula used is the **mean-preserving** variant from the `main` branch:
|
| 94 |
+
|
| 95 |
+
```
|
| 96 |
+
rescaled = rescale_mean + (cfg_cond − cfg_cond_mean) × rescale_factor
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
where `rescale_mean = (1 − φ) × mean(cfg_cond) + φ × mean(cond)` and `rescale_factor = φ × (std(cond)/std(cfg_cond) − 1) + 1`.
|
neutral_prompt_patched/lib_neutral_prompt/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# lib_neutral_prompt package
|
neutral_prompt_patched/lib_neutral_prompt/affine_transform.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Affine spatial transform utilities for latent tensors.
|
| 3 |
+
Extracted from the `affine` branch so that cfg_denoiser_hijack can stay lean.
|
| 4 |
+
|
| 5 |
+
Provides:
|
| 6 |
+
apply_affine_transform() – apply a 2×3 affine grid to a C×H×W tensor
|
| 7 |
+
apply_masked_transform() – apply affine + cosine-feathered mask blending
|
| 8 |
+
create_cosine_feathered_mask() – smooth circular weight mask
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def apply_affine_transform(
    tensor: torch.Tensor,
    affine: torch.Tensor,
    mode: str = 'bilinear',
) -> torch.Tensor:
    """
    Apply a 2×3 affine transform to a C×H×W tensor, preserving aspect ratio.

    The sampling grid is computed in at least float32 precision and then cast
    to the dtype of ``tensor``, so the function works for any floating-point
    latent dtype (``grid_sample`` requires input and grid to share a dtype).

    :param tensor: Input tensor of shape [C, H, W].
    :param affine: 2×3 affine matrix. It is never modified; a private copy is
                   moved to the device of ``tensor`` before use.
    :param mode:   Interpolation mode for grid_sample ('bilinear' default).
                   The original affine branch used bilinear for the pre-noise
                   inverse path (smoother) and nearest for post-combine.
                   Bilinear is the better default for both.
    :return: Transformed tensor of shape [C, H, W], same dtype as ``tensor``.
    """
    # Never build the grid below float32: a half-precision grid would quantise
    # the sampling coordinates. float64 inputs keep a float64 grid.
    grid_dtype = torch.promote_types(tensor.dtype, torch.float32)

    # clone() so the aspect-ratio correction below cannot leak into the
    # caller's matrix.
    affine = affine.clone().to(device=tensor.device, dtype=grid_dtype)
    aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
    affine[0, 1] *= aspect_ratio
    affine[1, 0] /= aspect_ratio

    batched = tensor.unsqueeze(0)  # grid_sample expects [N, C, H, W]
    grid = F.affine_grid(affine.unsqueeze(0), batched.size(), align_corners=False)
    grid = grid.to(tensor.dtype)
    return F.grid_sample(batched, grid, mode=mode, align_corners=False).squeeze(0)


def apply_masked_transform(
    tensor: torch.Tensor,
    affine: torch.Tensor,
    weight: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply an affine transform with a cosine-feathered spatial weight mask.

    The mask channel is appended to the tensor before transformation so that
    the same spatial warp is applied to both content and mask consistently.
    The mask is created in the dtype of ``tensor`` so that concatenation does
    not change the dtype of the returned content.

    :param tensor: C×H×W latent tensor.
    :param affine: 2×3 affine matrix.
    :param weight: Global scalar weight for the mask.
    :return: (transformed_content [C, H, W], transformed_mask [H, W])
    """
    mask = create_cosine_feathered_mask(tensor.shape[-2:], weight)
    mask = mask.unsqueeze(0).to(device=tensor.device, dtype=tensor.dtype)
    stacked = torch.cat([tensor, mask], dim=0)  # [C+1, H, W]
    warped = apply_affine_transform(stacked, affine)
    return warped[:-1], warped[-1]  # content, mask


def create_cosine_feathered_mask(size: Tuple[int, int], weight: float) -> torch.Tensor:
    """
    Create a circularly-clipped cosine-feathered mask of shape [H, W].

    Coordinates span [-1, 1] along both axes. The mask equals
    ``weight * 0.5 * (1 + cos(pi * r))`` for radius ``r <= 1`` and exactly 0
    for ``r > 1``, so it peaks at ``weight`` in the centre and fades smoothly
    to 0 on the unit circle.

    :param size:   (H, W) tuple.
    :param weight: Peak mask value at the centre.
    :return: Float32 mask tensor of shape [H, W] (created on the CPU).
    """
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, size[0]),
        torch.linspace(-1, 1, size[1]),
        indexing='ij',
    )
    dist = torch.sqrt(x ** 2 + y ** 2)
    mask = 0.5 * (1.0 + torch.cos(torch.pi * dist))
    # The raised cosine rises again beyond r = 1, so clip the corners.
    mask[dist > 1] = 0.0
    return mask.float() * weight
|
neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py
ADDED
|
@@ -0,0 +1,646 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified CFG Denoiser Hijack
|
| 3 |
+
===========================
|
| 4 |
+
Combines features from all sd-webui-neutral-prompt branches:
|
| 5 |
+
|
| 6 |
+
• Core PERP / SALT / TOPK strategies (main branch)
|
| 7 |
+
• cfg_rescale with mean-preserving formula (main branch – better than multiplicative)
|
| 8 |
+
• cfg_rescale_override (XYZ-grid / API support) (main branch)
|
| 9 |
+
• CFGRescaleFactorSingleton export (export_rescale_factor branch)
|
| 10 |
+
• Affine spatial transforms per-prompt (affine branch)
|
| 11 |
+
• AND_ALIGN_D_S – soft alignment blend (alignment_blend branch)
|
| 12 |
+
• AND_MASK_ALIGN_D_S – binary alignment mask (alignment_mask branch)
|
| 13 |
+
• Improved salience with sharpness parameter k (life branch – production-safe subset)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import dataclasses
|
| 19 |
+
import functools
|
| 20 |
+
import re
|
| 21 |
+
import sys
|
| 22 |
+
import textwrap
|
| 23 |
+
from typing import Dict, List, Tuple
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
from lib_neutral_prompt import affine_transform as affine_mod
|
| 29 |
+
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser
|
| 30 |
+
from modules import script_callbacks, sd_samplers, shared
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Pre-noise affine hook (affine branch)
|
| 35 |
+
#
|
| 36 |
+
# Applied BEFORE each CFG denoising step: the noisy latent slice for every
|
| 37 |
+
# cond prompt is warped by the INVERSE of that prompt's affine transform.
|
| 38 |
+
# This means the UNet sees the latent in the "locally-rotated" frame of each
|
| 39 |
+
# prompt. combine_denoised then applies the FORWARD transform to the
|
| 40 |
+
# resulting cond delta, restoring it to the global frame.
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
@dataclasses.dataclass
class _PreNoiseArgs:
    """Arguments threaded through ``_GlobalToLocalAffineVisitor`` for one prompt."""
    # Latent batch handed to the pre-noise hook (``params.x``).
    x: torch.Tensor
    # One pair per flattened leaf prompt; the visitor reads element [0] of a
    # pair as the row index into ``x``. The second element is presumably the
    # prompt weight (it is ignored by the pre-noise path).
    cond_indices: List[Tuple[int, float]]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _GlobalToLocalAffineVisitor:
    """Map every leaf prompt's cond index to the inverse of its affine transform.

    Both visit methods return a ``{cond_index: 2×3 inverse affine}`` dict.
    Leaves without a transform map to the 2×3 identity.
    """

    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: _PreNoiseArgs,
        index: int,
    ) -> Dict[int, torch.Tensor]:
        """Return the inverse transform of a single leaf, keyed by its cond index."""
        cond_index = args.cond_indices[index][0]
        if that.local_transform is None:
            inverse = torch.eye(3)[:-1]
        else:
            # Promote 2×3 to a homogeneous 3×3, invert, then drop the last row.
            homogeneous = torch.vstack([that.local_transform,
                                        torch.tensor([0.0, 0.0, 1.0])])
            inverse = torch.linalg.inv(homogeneous)[:-1]
        return {cond_index: inverse}

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: _PreNoiseArgs,
        index: int,
    ) -> Dict[int, torch.Tensor]:
        """Gather the children's inverses and pre-multiply them by this node's inverse."""
        collected: Dict[int, torch.Tensor] = {}

        cursor = index
        for child in that.children:
            collected.update(child.accept(_GlobalToLocalAffineVisitor(), args, cursor))
            # Each child occupies as many flat slots as FlatSizeVisitor reports.
            cursor += child.accept(neutral_prompt_parser.FlatSizeVisitor())

        if that.local_transform is None:
            return collected

        parent_homogeneous = torch.vstack([that.local_transform,
                                           torch.tensor([0.0, 0.0, 1.0])])
        parent_inverse = torch.linalg.inv(parent_homogeneous)
        for child_inverse in collected.values():
            child_homogeneous = torch.vstack([child_inverse, torch.tensor([0.0, 0.0, 1.0])])
            # Written in place so the dict entry itself is updated.
            child_inverse[:] = (parent_inverse @ child_homogeneous)[:-1]

        return collected
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams) -> None:
    """Pre-noise hook: warp each cond latent slice by its prompt's inverse affine.

    Does nothing while the extension is disabled. Otherwise every row of
    ``params.x`` referenced by ``global_state.batch_cond_indices`` is
    overwritten in place with its warped version.
    """
    if not global_state.is_enabled:
        return

    batches = zip(global_state.prompt_exprs, global_state.batch_cond_indices)
    for prompt, cond_indices in batches:
        inverse_by_index = prompt.accept(
            _GlobalToLocalAffineVisitor(),
            _PreNoiseArgs(params.x, cond_indices),
            0,
        )
        for cond_index, _weight in cond_indices:
            warped = affine_mod.apply_affine_transform(
                params.x[cond_index], inverse_by_index[cond_index]
            )
            params.x[cond_index] = warped


script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# Public entry point
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
def combine_denoised_hijack(
    x_out: torch.Tensor,
    batch_cond_indices: List[List[Tuple[int, float]]],
    text_uncond: torch.Tensor,
    cond_scale: float,
    original_function,
) -> torch.Tensor:
    """Replacement for the webui ``combine_denoised`` step.

    While the extension is disabled the call is forwarded untouched to
    ``original_function``. Otherwise the webui result is computed from the
    prompts that carry no conciliation keyword, and for each batch item the
    auxiliary (conciliated) delta is added on top and the result is passed
    through ``_cfg_rescale``.

    :param x_out: Denoiser output; its last ``text_uncond.shape[0]`` rows are
                  used as the unconditional predictions.
    :param batch_cond_indices: Per batch item, the (index, weight) pairs of its prompts.
    :param text_uncond: Only its first dimension is read here (uncond count).
    :param cond_scale: CFG scale applied to the auxiliary delta.
    :param original_function: The un-hijacked ``combine_denoised``.
    :return: Combined denoised tensor, one row per batch item.
    """
    if not global_state.is_enabled:
        return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)

    denoised = _get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
    uncond = x_out[-text_uncond.shape[0]:]

    per_item = zip(global_state.prompt_exprs, batch_cond_indices)
    for batch_i, (prompt, cond_indices) in enumerate(per_item):
        args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
        cond_delta = prompt.accept(_CondDeltaVisitor(), args, 0)
        aux_cond_delta = prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, 0)

        # Forward affine: bring both deltas back to the global frame.
        transform = prompt.local_transform
        if transform is not None:
            cond_delta = affine_mod.apply_affine_transform(cond_delta, transform)
            aux_cond_delta = affine_mod.apply_affine_transform(aux_cond_delta, transform)

        cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
        plain_cond = uncond[batch_i] + cond_delta + aux_cond_delta
        denoised[batch_i] = _cfg_rescale(cfg_cond, plain_cond)

    return denoised
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ---------------------------------------------------------------------------
|
| 145 |
+
# Internal helpers
|
| 146 |
+
# ---------------------------------------------------------------------------
|
| 147 |
+
|
| 148 |
+
def _get_webui_denoised(
    x_out: torch.Tensor,
    batch_cond_indices: List[List[Tuple[int, float]]],
    text_uncond: torch.Tensor,
    cond_scale: float,
    original_function,
) -> torch.Tensor:
    """Run the webui's own ``combine_denoised`` on the non-conciliated prompts only.

    A reduced ``x_out`` is rebuilt from the rows returned by
    ``_gather_webui_conds`` for every batch item, followed by the uncond rows,
    and handed to ``original_function`` together with the re-numbered indices.
    """
    uncond = x_out[-text_uncond.shape[0]:]
    gathered_rows: List[torch.Tensor] = []
    gathered_indices: List[List[Tuple[int, float]]] = []

    per_item = zip(global_state.prompt_exprs, batch_cond_indices)
    for batch_i, (prompt, cond_indices) in enumerate(per_item):
        args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
        rows, indices = _gather_webui_conds(prompt, args, 0, len(gathered_rows))
        # NOTE(review): a batch item with no plain prompt contributes no entry,
        # which shifts later entries one position up — confirm the webui's
        # combine_denoised tolerates this for mixed batches.
        if indices:
            gathered_indices.append(indices)
        gathered_rows.extend(rows)

    # The uncond rows go last, mirroring the layout of the incoming x_out.
    gathered_rows.extend(uncond)
    reduced_x_out = torch.stack(gathered_rows, dim=0)
    return original_function(reduced_x_out, gathered_indices, text_uncond, cond_scale)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    """
    Mean-preserving CFG rescale (main branch formula).

    With φ = ``global_state.cfg_rescale`` the result is
    ``rescale_mean + (cfg_cond - mean(cfg_cond)) * rescale_factor`` where
    ``rescale_mean = (1 - φ) * mean(cfg_cond) + φ * mean(cond)`` and
    ``rescale_factor = φ * (std(cond) / std(cfg_cond) - 1) + 1``.

    The computed factor is stored in ``CFGRescaleFactorSingleton``. When
    rescaling is skipped (φ == 0 or a zero-std ``cfg_cond``) the singleton is
    left cleared and ``cfg_cond`` is returned unchanged.
    """
    # Drop the previous step's factor first so a skipped step never exposes a
    # stale value.
    global_state.CFGRescaleFactorSingleton.clear()

    # The one-shot override has to be consumed before φ is inspected;
    # otherwise a 0→nonzero override would be discarded by the early exit.
    global_state.apply_and_clear_cfg_rescale_override()

    phi = global_state.cfg_rescale
    if phi == 0:
        return cfg_cond

    cfg_std = cfg_cond.std()
    if cfg_std == 0:
        # Constant tensor: there is no spread to rescale.
        return cfg_cond

    cfg_cond_mean = cfg_cond.mean()
    rescale_mean = (1 - phi) * cfg_cond_mean + phi * cond.mean()
    rescale_factor = phi * (cond.std() / cfg_std - 1) + 1

    # Export a plain Python float for external consumers.
    if isinstance(rescale_factor, torch.Tensor):
        exported = rescale_factor.item()
    else:
        exported = float(rescale_factor)
    global_state.CFGRescaleFactorSingleton.set(exported)

    return rescale_mean + (cfg_cond - cfg_cond_mean) * rescale_factor
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclasses.dataclass
class _DenoiseArgs:
    """Per-batch-item arguments shared by the cond-delta visitors and helpers."""
    # Full denoiser output; visitors index it with the first element of a
    # ``cond_indices`` pair.
    x_out: torch.Tensor
    # Unconditional prediction of the batch item being processed
    # (callers pass ``uncond[batch_i]``).
    uncond: torch.Tensor
    # (row index into x_out, weight) pairs, one per flattened leaf prompt.
    cond_indices: List[Tuple[int, float]]
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ---------------------------------------------------------------------------
|
| 220 |
+
# Gather webui-style conditions (needed for the webui's own CFG path)
|
| 221 |
+
# ---------------------------------------------------------------------------
|
| 222 |
+
|
| 223 |
+
def _gather_webui_conds(
    prompt: neutral_prompt_parser.CompositePrompt,
    args: _DenoiseArgs,
    index_in: int,
    index_out: int,
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
    """Collect the rows and weights of the direct children without a conciliation keyword.

    :param prompt: Composite prompt whose direct children are inspected.
    :param args: Denoise arguments of the current batch item.
    :param index_in: Flat index of the first child inside ``args.cond_indices``.
    :param index_out: Index the first collected row will have in the rebuilt batch.
    :return: (rows, [(row index in rebuilt batch, weight), ...]).
    """
    rows: List[torch.Tensor] = []
    row_indices: List[Tuple[int, float]] = []

    cursor = index_in
    for child in prompt.children:
        if child.conciliation is None:
            is_plain_leaf = (
                isinstance(child, neutral_prompt_parser.LeafPrompt)
                and child.local_transform is None
            )
            if is_plain_leaf:
                # Untransformed leaf: reuse the denoiser row as is.
                row = args.x_out[args.cond_indices[cursor][0]]
                row_weight = child.weight
            else:
                # Nested or transformed prompt: rebuild an absolute prediction
                # from its delta.
                delta, row_weight = _get_cond_delta_and_weight(child, args, cursor)
                row = delta + args.uncond

            row_indices.append((index_out + len(rows), row_weight))
            rows.append(row)

        # Conciliated children are skipped but still consume flat slots.
        cursor += child.accept(neutral_prompt_parser.FlatSizeVisitor())

    return rows, row_indices
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _get_cond_delta_and_weight(
    prompt: neutral_prompt_parser.PromptExpr,
    args: _DenoiseArgs,
    index: int,
) -> Tuple[torch.Tensor, float]:
    """Return the total delta (plain + auxiliary) of ``prompt`` and its weight.

    Without a local transform the weight is ``prompt.weight``. With one, the
    absolute prediction (delta + uncond) is warped together with a feathered
    mask, and the warped mask is returned in place of the scalar weight.
    """
    plain_delta = prompt.accept(_CondDeltaVisitor(), args, index)
    total_delta = plain_delta + prompt.accept(_AuxCondDeltaVisitor(), args, plain_delta, index)

    if prompt.local_transform is None:
        return total_delta, prompt.weight

    warped, weight_mask = affine_mod.apply_masked_transform(
        total_delta + args.uncond,
        prompt.local_transform,
        prompt.weight,
    )
    # The mask is a tensor even though the signature advertises a float.
    return warped - args.uncond, weight_mask  # type: ignore[return-value]
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
# Visitor: CondDelta – weighted sum of leaf cond − uncond
|
| 274 |
+
# ---------------------------------------------------------------------------
|
| 275 |
+
|
| 276 |
+
class _CondDeltaVisitor:
    """Compute ``cond - uncond`` deltas over the non-conciliated part of a prompt tree."""

    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: _DenoiseArgs,
        index: int,
    ) -> torch.Tensor:
        """Return the leaf's delta, warning when the weight recorded for it differs."""
        row_index, row_weight = args.cond_indices[index][0], args.cond_indices[index][1]
        cond_info = (row_index, row_weight)
        if that.weight != cond_info[1]:
            _console_warn(f'''
                An unexpected noise weight was encountered at prompt #{index}
                Expected :{that.weight}, but got :{cond_info[1]}
                This is likely due to another extension also monkey patching `combine_denoised`.
                Please open a bug report: https://github.com/ljleb/sd-webui-neutral-prompt/issues
            ''')
        return args.x_out[row_index] - args.uncond

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: _DenoiseArgs,
        index: int,
    ) -> torch.Tensor:
        """Return the weighted sum of the deltas of all children without a conciliation keyword."""
        total = torch.zeros_like(args.x_out[0])
        cursor = index
        for child in that.children:
            if child.conciliation is None:
                delta, weight = _get_cond_delta_and_weight(child, args, cursor)
                total = total + weight * delta
            # Every child, conciliated or not, advances the flat index.
            cursor += child.accept(neutral_prompt_parser.FlatSizeVisitor())
        return total
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# Visitor: AuxCondDelta – all conciliation strategies
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
class _AuxCondDeltaVisitor:
    """Visitor that accumulates the contribution of conciliated children.

    Handles every child whose ``conciliation`` is set (AND_PERP, AND_SALT*,
    AND_TOPK, AND_ALIGN_D_S, AND_MASK_ALIGN_D_S).  Children without a
    conciliation are ignored apart from advancing the running index.
    """

    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: _DenoiseArgs,
        cond_delta: torch.Tensor,
        index: int,
    ) -> torch.Tensor:
        """A leaf has no children, so its auxiliary delta is all zeros."""
        return torch.zeros_like(args.x_out[0])

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: _DenoiseArgs,
        cond_delta: torch.Tensor,
        index: int,
    ) -> torch.Tensor:
        """Return the summed auxiliary delta of the conciliated children of ``that``.

        ``cond_delta`` is the parent delta every strategy is measured against.
        PERPENDICULAR and SEMANTIC_GUIDANCE children are added immediately; the
        salience and alignment families are collected first because their
        blend functions need all competing children at once.
        """
        strategies = neutral_prompt_parser.ConciliationStrategy
        flat_size = neutral_prompt_parser.FlatSizeVisitor()

        accumulated = torch.zeros_like(args.x_out[0])
        sharp_salience: List[Tuple[torch.Tensor, float]] = []  # AND_SALT (k=20)
        wide_salience: List[Tuple[torch.Tensor, float]] = []  # AND_SALT_WIDE (k=1)
        blob_salience: List[Tuple[torch.Tensor, float]] = []  # AND_SALT_BLOB (k=20 + morphology)
        soft_alignment: List[Tuple[torch.Tensor, float, int, int]] = []
        masked_alignment: List[Tuple[torch.Tensor, float, int, int]] = []

        salience_buckets = {
            strategies.SALIENCE_MASK: sharp_salience,
            strategies.SALIENCE_MASK_WIDE: wide_salience,
            strategies.SALIENCE_MASK_BLOB: blob_salience,
        }

        cursor = index
        for child in that.children:
            strategy = child.conciliation
            if strategy is not None:
                delta, weight = _get_cond_delta_and_weight(child, args, cursor)

                if strategy == strategies.PERPENDICULAR:
                    accumulated = accumulated + weight * _get_perpendicular_component(cond_delta, delta)
                elif strategy in salience_buckets:
                    salience_buckets[strategy].append((delta, weight))
                elif strategy == strategies.SEMANTIC_GUIDANCE:
                    accumulated = accumulated + weight * _filter_abs_top_k(delta, 0.05)
                else:
                    # The two sizes are encoded in the keyword itself:
                    # AND_ALIGN_<detail>_<structure> / AND_MASK_ALIGN_<detail>_<structure>.
                    keyword = strategy.value
                    soft_match = re.match(r'AND_ALIGN_(\d+)_(\d+)', keyword)
                    if soft_match:
                        detail_size, structure_size = map(int, soft_match.groups())
                        soft_alignment.append((delta, weight, detail_size, structure_size))
                    else:
                        mask_match = re.match(r'AND_MASK_ALIGN_(\d+)_(\d+)', keyword)
                        if mask_match:
                            detail_size, structure_size = map(int, mask_match.groups())
                            masked_alignment.append((delta, weight, detail_size, structure_size))

            cursor += child.accept(flat_size)

        # Deferred strategies, added in a fixed order.
        deferred = (
            _salient_blend(cond_delta, sharp_salience, child_k=20),
            _salient_blend(cond_delta, wide_salience, child_k=1),
            _salient_blend_blob(cond_delta, blob_salience),
            _alignment_blend(cond_delta, soft_alignment),
            _alignment_mask_blend(cond_delta, masked_alignment),
        )
        for contribution in deferred:
            accumulated = accumulated + contribution
        return accumulated
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# ---------------------------------------------------------------------------
|
| 378 |
+
# Strategy implementations
|
| 379 |
+
# ---------------------------------------------------------------------------
|
| 380 |
+
|
| 381 |
+
def _get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
|
| 382 |
+
if (normal == 0).all():
|
| 383 |
+
if shared.state.sampling_step <= 0:
|
| 384 |
+
_warn_projection_not_found()
|
| 385 |
+
return vector
|
| 386 |
+
return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def _salient_blend(
|
| 390 |
+
normal: torch.Tensor,
|
| 391 |
+
vectors: List[Tuple[torch.Tensor, float]],
|
| 392 |
+
child_k: float = 20.0,
|
| 393 |
+
) -> torch.Tensor:
|
| 394 |
+
"""
|
| 395 |
+
Saliency-guided blend: each child prompt wins in the latent regions
|
| 396 |
+
where its absolute activation magnitude is strongest.
|
| 397 |
+
|
| 398 |
+
child_k controls how sharp/selective the child salience mask is:
|
| 399 |
+
child_k=20 → very sharp, 1-2 peak pixels (AND_SALT — life-branch style)
|
| 400 |
+
child_k=1 → broad, ~55% of pixels (AND_SALT_WIDE — original main-branch)
|
| 401 |
+
Parent always uses k=1 (diffuse reference).
|
| 402 |
+
"""
|
| 403 |
+
if not vectors:
|
| 404 |
+
return torch.zeros_like(normal)
|
| 405 |
+
|
| 406 |
+
salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=child_k) for v, _ in vectors]
|
| 407 |
+
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
| 408 |
+
|
| 409 |
+
result = torch.zeros_like(normal)
|
| 410 |
+
for mask_i, (vector, weight) in enumerate(vectors, start=1):
|
| 411 |
+
vector_mask = (mask == mask_i).float()
|
| 412 |
+
result = result + weight * vector_mask * (vector - normal)
|
| 413 |
+
return result
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def _salient_blend_blob(
|
| 417 |
+
normal: torch.Tensor,
|
| 418 |
+
vectors: List[Tuple[torch.Tensor, float]],
|
| 419 |
+
) -> torch.Tensor:
|
| 420 |
+
"""
|
| 421 |
+
AND_SALT_BLOB: life-branch dev.py algorithm (cleaned of debug scaffolding).
|
| 422 |
+
|
| 423 |
+
Pipeline per child:
|
| 424 |
+
1. k=20 softmax → very sharp salience seed (1-2 pixels)
|
| 425 |
+
2. _life_erode ×6 → erode to spatially dense core
|
| 426 |
+
3. _life_thickify ×2 → grow outward into a smooth blob
|
| 427 |
+
4. result += weight × blob_mask × (child − parent)
|
| 428 |
+
|
| 429 |
+
This is what the life-branch author was iterating toward in dev.py.
|
| 430 |
+
"""
|
| 431 |
+
if not vectors:
|
| 432 |
+
return torch.zeros_like(normal)
|
| 433 |
+
|
| 434 |
+
salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=20) for v, _ in vectors]
|
| 435 |
+
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
| 436 |
+
|
| 437 |
+
result = torch.zeros_like(normal)
|
| 438 |
+
for mask_i, (vector, weight) in enumerate(vectors, start=1):
|
| 439 |
+
vector_mask = (mask == mask_i).float()
|
| 440 |
+
for _ in range(6):
|
| 441 |
+
vector_mask = _life_step(vector_mask, _erode_rule)
|
| 442 |
+
for _ in range(2):
|
| 443 |
+
vector_mask = _life_step(vector_mask, _thickify_rule)
|
| 444 |
+
result = result + weight * vector_mask * (vector - normal)
|
| 445 |
+
return result
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _life_step(board: torch.Tensor, rule) -> torch.Tensor:
|
| 449 |
+
"""
|
| 450 |
+
One step of a cellular-automaton morphology on a C×H×W binary mask.
|
| 451 |
+
|
| 452 |
+
The conv3d trick from dev.py: concatenate [board, board[:-1]] along the
|
| 453 |
+
channel axis so a single 3-D convolution simultaneously sums the 3×3
|
| 454 |
+
spatial neighbourhood across *all* C channels. This keeps the operation
|
| 455 |
+
GPU-friendly and avoids an explicit loop over channels.
|
| 456 |
+
|
| 457 |
+
neighbors[c,h,w] = Σ_{dc,dh,dw} board_padded[c+dc, h+dh, w+dw]
|
| 458 |
+
minus the center pixel itself (board is subtracted).
|
| 459 |
+
"""
|
| 460 |
+
C = board.shape[0]
|
| 461 |
+
kernel = torch.ones((C, 3, 3), dtype=board.dtype, device=board.device)
|
| 462 |
+
kernel = kernel.unsqueeze(0).unsqueeze(0) # [1, 1, C, 3, 3]
|
| 463 |
+
|
| 464 |
+
# Pad spatially (left/right/top/bottom) but not along channel axis
|
| 465 |
+
padded = torch.cat([board.clone(), board[:-1].clone()], dim=0) # [2C-1, H, W]
|
| 466 |
+
padded = torch.nn.functional.pad(padded, (1, 1, 1, 1, 0, 0), value=0) # [2C-1, H+2, W+2]
|
| 467 |
+
|
| 468 |
+
neighbors = torch.nn.functional.conv3d(
|
| 469 |
+
padded.unsqueeze(0).unsqueeze(0), # [1, 1, 2C-1, H+2, W+2]
|
| 470 |
+
kernel, # [1, 1, C, 3, 3]
|
| 471 |
+
padding=0,
|
| 472 |
+
).squeeze(0).squeeze(0) # [C, H, W]
|
| 473 |
+
|
| 474 |
+
neighbors = neighbors - board # subtract center pixel
|
| 475 |
+
return rule(board, neighbors).float()
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _erode_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
|
| 479 |
+
"""Keep a pixel only if it is set AND its C-channel neighbourhood is dense.
|
| 480 |
+
Threshold C*5 matches dev.py: for C=4 that means ≥20 out of 36 possible."""
|
| 481 |
+
C = board.shape[0]
|
| 482 |
+
return (board == 1) & (neighbors >= C * 5)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _thickify_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
|
| 486 |
+
"""Grow: keep existing pixels OR add any pixel adjacent to the core.
|
| 487 |
+
population ≥ 4 ensures only pixels touching at least one set neighbour grow."""
|
| 488 |
+
population = board + neighbors
|
| 489 |
+
return (board == 1) | (population >= 4)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def _get_salience(vector: torch.Tensor, k: float = 1.0) -> torch.Tensor:
|
| 493 |
+
"""Softmax-based salience map. k > 1 → sharper, more selective mask."""
|
| 494 |
+
return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def _filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
|
| 498 |
+
"""Keep only the top k_ratio fraction of activations by absolute value."""
|
| 499 |
+
k = int(torch.numel(vector) * (1 - k_ratio))
|
| 500 |
+
k = max(1, k) # kthvalue requires k >= 1
|
| 501 |
+
threshold, _ = torch.kthvalue(torch.abs(vector.flatten()), k)
|
| 502 |
+
return vector * (torch.abs(vector) >= threshold).to(vector.dtype)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
# ---------------------------------------------------------------------------
|
| 506 |
+
# Alignment blend (alignment_blend branch)
|
| 507 |
+
# ---------------------------------------------------------------------------
|
| 508 |
+
|
| 509 |
+
def _compute_subregion_similarity_map(
|
| 510 |
+
child_vector: torch.Tensor,
|
| 511 |
+
parent_vector: torch.Tensor,
|
| 512 |
+
region_size: int = 2,
|
| 513 |
+
) -> torch.Tensor:
|
| 514 |
+
"""
|
| 515 |
+
Compute local average cosine similarity of (region_size × region_size) regions.
|
| 516 |
+
Returns a map of shape [C, H, W] with values in [-1, 1].
|
| 517 |
+
"""
|
| 518 |
+
C, H, W = child_vector.shape
|
| 519 |
+
parent = parent_vector.unsqueeze(0) # [1, C, H, W]
|
| 520 |
+
child = child_vector.unsqueeze(0)
|
| 521 |
+
|
| 522 |
+
region_radius = region_size // 2
|
| 523 |
+
if region_size % 2 == 1:
|
| 524 |
+
pad_size = (region_radius,) * 4
|
| 525 |
+
else:
|
| 526 |
+
pad_size = (region_radius - 1, region_radius) * 2
|
| 527 |
+
|
| 528 |
+
parent_reg = F.unfold(F.pad(parent, pad_size, 'constant', 0), kernel_size=region_size)
|
| 529 |
+
child_reg = F.unfold(F.pad(child, pad_size, 'constant', 0), kernel_size=region_size)
|
| 530 |
+
|
| 531 |
+
# [H*W, C, region_size, region_size]
|
| 532 |
+
# .contiguous() is required before .view() because .permute() produces a
|
| 533 |
+
# non-contiguous tensor and .view() raises RuntimeError on non-contiguous input.
|
| 534 |
+
parent_reg = parent_reg.view(1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
|
| 535 |
+
child_reg = child_reg.view( 1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
|
| 536 |
+
|
| 537 |
+
unfold2 = torch.nn.Unfold(kernel_size=2)
|
| 538 |
+
parent_sub = unfold2(parent_reg).view(H*W, C, 4, (region_size - 1)**2)
|
| 539 |
+
child_sub = unfold2(child_reg ).view(H*W, C, 4, (region_size - 1)**2)
|
| 540 |
+
|
| 541 |
+
parent_sub = F.normalize(parent_sub, p=2, dim=2)
|
| 542 |
+
child_sub = F.normalize(child_sub, p=2, dim=2)
|
| 543 |
+
sim = (parent_sub * child_sub).sum(dim=2) # [H*W, C, (r-1)^2]
|
| 544 |
+
return sim.mean(dim=2).permute(1, 0).contiguous().view(C, H, W) # [C, H, W]
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def _alignment_blend(
|
| 548 |
+
parent: torch.Tensor,
|
| 549 |
+
children: List[Tuple[torch.Tensor, float, int, int]],
|
| 550 |
+
) -> torch.Tensor:
|
| 551 |
+
"""
|
| 552 |
+
Soft alignment blend (AND_ALIGN_D_S).
|
| 553 |
+
Child contribution is weighted by max(0, structure_alignment − detail_alignment).
|
| 554 |
+
High weight where child changes detail without breaking structure.
|
| 555 |
+
"""
|
| 556 |
+
result = torch.zeros_like(parent)
|
| 557 |
+
for child, weight, detail_size, structure_size in children:
|
| 558 |
+
detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
|
| 559 |
+
structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
|
| 560 |
+
|
| 561 |
+
# Normalise by absolute max so that all-negative maps (anti-correlated prompts
|
| 562 |
+
# such as 'black' vs 'white') don't blow up to ±1e7 and clamp incorrectly to 1.
|
| 563 |
+
# Dividing by positive max would invert the sign ordering when all values are
|
| 564 |
+
# negative; dividing by abs-max preserves relative ordering in both cases.
|
| 565 |
+
d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
|
| 566 |
+
s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
|
| 567 |
+
detail_sim = detail_sim / d_abs_max
|
| 568 |
+
structure_sim = structure_sim / s_abs_max
|
| 569 |
+
|
| 570 |
+
alignment_weight = torch.clamp(structure_sim - detail_sim, min=0.0, max=1.0)
|
| 571 |
+
result = result + (child - parent) * weight * alignment_weight
|
| 572 |
+
return result
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def _alignment_mask_blend(
|
| 576 |
+
parent: torch.Tensor,
|
| 577 |
+
children: List[Tuple[torch.Tensor, float, int, int]],
|
| 578 |
+
) -> torch.Tensor:
|
| 579 |
+
"""
|
| 580 |
+
Binary alignment mask (AND_MASK_ALIGN_D_S).
|
| 581 |
+
Child receives full weight where structure_alignment > detail_alignment, zero elsewhere.
|
| 582 |
+
"""
|
| 583 |
+
result = torch.zeros_like(parent)
|
| 584 |
+
for child, weight, detail_size, structure_size in children:
|
| 585 |
+
detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
|
| 586 |
+
structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
|
| 587 |
+
|
| 588 |
+
d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
|
| 589 |
+
s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
|
| 590 |
+
detail_sim = detail_sim / d_abs_max
|
| 591 |
+
structure_sim = structure_sim / s_abs_max
|
| 592 |
+
|
| 593 |
+
alignment_mask = (structure_sim > detail_sim).to(child)
|
| 594 |
+
result = result + (child - parent) * weight * alignment_mask
|
| 595 |
+
return result
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
# ---------------------------------------------------------------------------
|
| 599 |
+
# Sampler hijack
|
| 600 |
+
# ---------------------------------------------------------------------------
|
| 601 |
+
|
| 602 |
+
# One hijacker per `sd_samplers` module: install_or_get() stores it on the
# module under the given attribute name and registers its clean-up through
# `script_callbacks.on_script_unloaded`.
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=sd_samplers,
    hijacker_attribute='__neutral_prompt_hijacker',
    on_uninstall=script_callbacks.on_script_unloaded,
)


@sd_samplers_hijacker.hijack('create_sampler')
def create_sampler_hijack(name: str, model, original_function):
    """Create the sampler as usual, then redirect its `combine_denoised`.

    Samplers without `model_wrap_cfg.combine_denoised` are returned unpatched;
    in that case a warning is emitted only when the extension is enabled.
    The replaced method is handed to `combine_denoised_hijack` as
    ``original_function``.
    """
    sampler = original_function(name, model)

    can_patch = hasattr(sampler, 'model_wrap_cfg') and hasattr(sampler.model_wrap_cfg, 'combine_denoised')
    if can_patch:
        cfg_denoiser = sampler.model_wrap_cfg
        cfg_denoiser.combine_denoised = functools.partial(
            combine_denoised_hijack,
            original_function=cfg_denoiser.combine_denoised,
        )
    elif global_state.is_enabled:
        _warn_unsupported_sampler()

    return sampler
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
# ---------------------------------------------------------------------------
|
| 625 |
+
# Warnings / logging
|
| 626 |
+
# ---------------------------------------------------------------------------
|
| 627 |
+
|
| 628 |
+
def _warn_unsupported_sampler() -> None:
    """Tell the user that the selected sampler is left unpatched."""
    # Indentation inside the literal is stripped by textwrap.dedent in _console_warn.
    message = '''
        Neutral prompt relies on composition via AND, which the webui does not support
        when using any of the DDIM, PLMS and UniPC samplers.
        The sampler will NOT be patched – falling back on the original implementation.
    '''
    _console_warn(message)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def _warn_projection_not_found() -> None:
    """Tell the user that AND_PERP prompts could not be made perpendicular."""
    # Indentation inside the literal is stripped by textwrap.dedent in _console_warn.
    message = '''
        Could not find a projection for one or more AND_PERP prompts.
        These prompts will NOT be made perpendicular.
    '''
    _console_warn(message)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def _console_warn(message: str) -> None:
    """Print ``message`` to stderr with the extension prefix when verbose mode is on.

    The message is passed through ``textwrap.dedent`` first, so callers can
    indent their triple-quoted text to match the surrounding code.
    """
    if global_state.verbose:
        print(f'\n[sd-webui-neutral-prompt]{textwrap.dedent(message)}', file=sys.stderr)
|
neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@contextlib.contextmanager
def fix_path():
    """Temporarily make the extension root importable.

    The directory three levels above this file is inserted at the front of
    ``sys.path`` for the duration of the ``with`` block, unless it is already
    present.  An entry that was already there is left untouched.

    The entry added here is removed again even when the body raises (for
    example a failing import), so ``sys.path`` is never left modified.
    """
    # Function-scope imports keep `sys` and `Path` out of the package namespace.
    import sys
    from pathlib import Path

    extension_path = str(Path(__file__).parent.parent.parent)
    added = extension_path not in sys.path
    if added:
        sys.path.insert(0, extension_path)

    try:
        yield
    finally:
        # Without try/finally an exception thrown into the generator at the
        # yield would skip this clean-up.
        if added:
            sys.path.remove(extension_path)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Import the public API while the extension root is on sys.path.  The
# bootstrap names are deleted first so they are not exported by this package;
# the context manager is already running, so dropping the name is safe.
with fix_path():
    del contextlib, fix_path
    from .api import *
|
neutral_prompt_patched/lib_neutral_prompt/external_code/api.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
External API for sd-webui-neutral-prompt.
|
| 3 |
+
|
| 4 |
+
Provides thin helpers that external extensions / scripts can import via:
|
| 5 |
+
|
| 6 |
+
from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from lib_neutral_prompt import global_state
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def override_cfg_rescale(cfg_rescale: float) -> None:
    """
    Request a one-shot CFG rescale value.

    The value is stored in ``global_state.cfg_rescale_override``.  It takes
    effect when ``global_state.apply_and_clear_cfg_rescale_override()`` copies
    it into ``global_state.cfg_rescale``, which also resets the override to
    None.
    """
    pending_value = cfg_rescale
    global_state.cfg_rescale_override = pending_value
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_last_cfg_rescale_factor() -> Optional[float]:
    """
    Return the value currently held by ``global_state.CFGRescaleFactorSingleton``.

    This is the factor most recently stored for the calling thread, or None
    when nothing has been stored or the slot has been cleared.
    """
    factor = global_state.CFGRescaleFactorSingleton.get()
    return factor
|
neutral_prompt_patched/lib_neutral_prompt/global_state.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global mutable state shared across the extension.
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- is_enabled / prompt_exprs / verbose (all branches)
|
| 6 |
+
- cfg_rescale + cfg_rescale_override mechanism (main branch)
|
| 7 |
+
- batch_cond_indices (affine branch – needed by on_cfg_denoiser pre-noise hook)
|
| 8 |
+
- CFGRescaleFactorSingleton – exposes the last computed rescale factor
|
| 9 |
+
to external callers / API users (export_rescale_factor branch)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import threading
|
| 13 |
+
from typing import List, Optional, Tuple
|
| 14 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Runtime state
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
is_enabled: bool = False
prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []

# Filled in by the reconstruct_multicond_batch hook and read by the pre-noise
# affine transform hook (on_cfg_denoiser), which needs to know which latent
# slice belongs to which prompt child.
batch_cond_indices: List[List[Tuple[int, float]]] = []

cfg_rescale: float = 0.0

# Same as the UI default; the UI syncs this value when settings are loaded.
verbose: bool = False

# One-shot override for cfg_rescale, set by the XYZ grid or the external API.
# None means "no override pending"; apply_and_clear_cfg_rescale_override()
# consumes it.
cfg_rescale_override: Optional[float] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def apply_and_clear_cfg_rescale_override() -> None:
    """Consume a pending cfg_rescale override, if any.

    When ``cfg_rescale_override`` is not None its value is copied into
    ``cfg_rescale`` and the override is reset to None.  Otherwise nothing
    changes.
    """
    global cfg_rescale, cfg_rescale_override

    pending = cfg_rescale_override
    if pending is None:
        return

    cfg_rescale = pending
    cfg_rescale_override = None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# CFG Rescale Factor Singleton (export_rescale_factor branch)
|
| 45 |
+
#
|
| 46 |
+
# Stores the last computed rescale factor so that external tools / scripts
|
| 47 |
+
# can read it without re-deriving it.
|
| 48 |
+
#
|
| 49 |
+
# Uses threading.local() so concurrent API workers each get their own slot.
|
| 50 |
+
# clear() must be called at the start of every _cfg_rescale() call so that
|
| 51 |
+
# get() correctly returns None when rescaling was skipped this step.
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
class CFGRescaleFactorSingleton:
    """Per-thread slot holding the most recently stored CFG rescale factor.

    The value lives on a ``threading.local`` object, so each thread reads and
    writes its own copy: a value stored from one thread is not visible from
    another.

    Intended use per denoising step:
      1. clear() resets the slot to None.
      2. set() stores the factor if rescaling was performed.
      3. get() returns the stored factor, or None.
    """

    _state = threading.local()

    @classmethod
    def set(cls, value: float) -> None:
        """Store ``value`` for the calling thread."""
        setattr(cls._state, 'cfg_rescale_factor', value)

    @classmethod
    def get(cls) -> Optional[float]:
        """Return the calling thread's value, or None if it was never set."""
        try:
            return cls._state.cfg_rescale_factor
        except AttributeError:
            return None

    @classmethod
    def clear(cls) -> None:
        """Reset the calling thread's value to None."""
        setattr(cls._state, 'cfg_rescale_factor', None)
|
neutral_prompt_patched/lib_neutral_prompt/hijacker.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ModuleHijacker:
    """Replaces attributes of a module with wrappers and can restore them.

    Every wrapper receives the attribute's original value through the
    ``original_function`` keyword argument.
    """

    def __init__(self, module):
        self.__module = module
        # attribute name -> value the module had before the first hijack
        self.__original_functions = {}

    def hijack(self, attribute):
        """Return a decorator that installs the decorated function as ``module.<attribute>``.

        The original value is recorded only the first time an attribute is
        hijacked, so repeated hijacks keep wrapping the true original.  The
        decorated function itself is returned unchanged.
        """
        originals = self.__original_functions
        if attribute not in originals:
            originals[attribute] = getattr(self.__module, attribute)

        def decorator(function):
            wrapper = functools.partial(function, original_function=originals[attribute])
            setattr(self.__module, attribute, wrapper)
            return function

        return decorator

    def reset_module(self):
        """Restore every hijacked attribute and forget the recorded originals."""
        for attribute, original_function in self.__original_functions.items():
            setattr(self.__module, attribute, original_function)

        self.__original_functions.clear()

    @staticmethod
    def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
        """Return the hijacker stored on ``module``, creating it on first use.

        A new hijacker is stored under ``hijacker_attribute`` and two clean-up
        callbacks are handed to ``on_uninstall``: one deleting that attribute
        and one calling `reset_module`.
        """
        if hasattr(module, hijacker_attribute):
            return getattr(module, hijacker_attribute)

        module_hijacker = ModuleHijacker(module)
        setattr(module, hijacker_attribute, module_hijacker)
        on_uninstall(lambda: delattr(module, hijacker_attribute))
        on_uninstall(module_hijacker.reset_module)
        return module_hijacker
|
neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Neutral Prompt Parser
|
| 3 |
+
Combines:
|
| 4 |
+
- Core AND / AND_PERP / AND_SALT / AND_TOPK strategies (main branch)
|
| 5 |
+
- Affine spatial transforms: ROTATE / SLIDE / SCALE / SHEAR (affine branch)
|
| 6 |
+
- Local alignment blend: AND_ALIGN_D_S / AND_MASK_ALIGN_D_S (alignment_blend branch)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import abc
|
| 10 |
+
import dataclasses
|
| 11 |
+
import math
|
| 12 |
+
import re
|
| 13 |
+
from enum import Enum
|
| 14 |
+
from itertools import product
|
| 15 |
+
from typing import Any, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Keyword registry
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
# Base conciliation keywords: prompt keyword -> ConciliationStrategy member name
_BASE_KEYWORD_MAP = {
    'AND_PERP': 'PERPENDICULAR',
    'AND_SALT': 'SALIENCE_MASK',            # k=20 sharp mask (life-branch style)
    'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE',  # k=1 broad mask (original main-branch behaviour)
    'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB',  # k=20 + erode + thickify -> smooth blob (dev.py intent)
    'AND_TOPK': 'SEMANTIC_GUIDANCE',
}

# Alignment blend: AND_ALIGN_D_S (soft weight, structure > detail).
# D and S each range over 2..32 and must differ.
_ALIGN_KEYWORD_MAP = {
    f'AND_ALIGN_{detail}_{structure}': f'ALIGNMENT_BLEND_{detail}_{structure}'
    for detail in range(2, 33)
    for structure in range(2, 33)
    if detail != structure
}

# Alignment mask: AND_MASK_ALIGN_D_S (binary mask, structure > detail).
# Same D/S domain as above.
_MASK_ALIGN_KEYWORD_MAP = {
    f'AND_MASK_ALIGN_{detail}_{structure}': f'ALIGNMENT_MASK_BLEND_{detail}_{structure}'
    for detail in range(2, 33)
    for structure in range(2, 33)
    if detail != structure
}

# Insertion order is base, then align, then mask-align.
keyword_mapping = {**_BASE_KEYWORD_MAP, **_ALIGN_KEYWORD_MAP, **_MASK_ALIGN_KEYWORD_MAP}

# PromptKeyword: name == value == the keyword text, with plain 'AND' first.
PromptKeyword = Enum('PromptKeyword', {name: name for name in ('AND', *keyword_mapping)})
# ConciliationStrategy: name is the strategy, value is the keyword text.
ConciliationStrategy = Enum(
    'ConciliationStrategy',
    {strategy: keyword for keyword, strategy in keyword_mapping.items()},
)

# Lists kept for ordered iteration (tokenize uses list order for regex priority)
prompt_keywords = [member.value for member in PromptKeyword]
conciliation_strategies = [member.value for member in ConciliationStrategy]

# Sets for O(1) membership checks in parser hot-loops; the lists above hold
# 1866 keywords and 1865 strategies, so linear scans would be costly.
_prompt_keywords_set = frozenset(prompt_keywords)
_conciliation_strategies_set = frozenset(conciliation_strategies)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
# Affine transform definitions (from affine branch)
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
def _affine_rotate(t, angle=0, *_):
    """Right-multiply ``t`` by a rotation of ``angle`` turns (1.0 is a full turn)."""
    # float() coercion matches SLIDE/SCALE/SHEAR, so numeric strings are
    # accepted here as well instead of raising TypeError.
    theta = float(angle) * 2 * math.pi
    return t @ torch.tensor([
        [math.cos(theta), -math.sin(theta), 0],
        [math.sin(theta), math.cos(theta), 0],
        [0, 0, 1],
    ], dtype=torch.float32)


def _affine_slide(t, x=0, y=0, *_):
    """Right-multiply ``t`` by a translation of ``(x, y)``."""
    return t @ torch.tensor([
        [1, 0, float(x)],
        [0, 1, float(y)],
        [0, 0, 1],
    ], dtype=torch.float32)


def _affine_scale(t, x=1, y=None, *_):
    """Right-multiply ``t`` by a scaling of ``(x, y)``; ``y`` defaults to ``x``."""
    return t @ torch.tensor([
        [float(x), 0, 0],
        [0, float(y) if y is not None else float(x), 0],
        [0, 0, 1],
    ], dtype=torch.float32)


def _affine_shear(t, x=0, y=None, *_):
    """Right-multiply ``t`` by a shear; ``x``/``y`` are in turns, ``y`` defaults to ``x``."""
    return t @ torch.tensor([
        [1, math.tan(float(x) * 2 * math.pi), 0],
        [math.tan(float(y if y is not None else x) * 2 * math.pi), 1, 0],
        [0, 0, 1],
    ], dtype=torch.float32)


# Keyword -> callable(t, *params).  Each callable multiplies the running
# transform ``t`` on the right by a 3x3 float32 matrix and ignores surplus
# positional parameters.
affine_transforms = {
    'ROTATE': _affine_rotate,
    'SLIDE': _affine_slide,
    'SCALE': _affine_scale,
    'SHEAR': _affine_shear,
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# AST node types
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
@dataclasses.dataclass
class PromptExpr(abc.ABC):
    """Abstract base of every node in the parsed prompt tree."""

    # Weight attached to this node.
    weight: float
    # Strategy used to merge this node into its parent; None when the node has
    # no conciliation keyword.
    conciliation: Optional[ConciliationStrategy]
    # 2×3 affine tensor or None
    local_transform: Optional[torch.Tensor]

    @abc.abstractmethod
    def accept(self, visitor, *args, **kwargs) -> Any:
        """Dispatch to the visitor method matching the concrete node type."""
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@dataclasses.dataclass
class LeafPrompt(PromptExpr):
    """Tree node holding one prompt string."""

    prompt: str

    def accept(self, visitor, *args, **kwargs):
        """Call ``visitor.visit_leaf_prompt(self, ...)`` and return its result."""
        handler = visitor.visit_leaf_prompt
        return handler(self, *args, **kwargs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclasses.dataclass
class CompositePrompt(PromptExpr):
    """Tree node grouping an ordered list of child expressions."""

    children: List[PromptExpr]

    def accept(self, visitor, *args, **kwargs):
        """Call ``visitor.visit_composite_prompt(self, ...)`` and return its result."""
        handler = visitor.visit_composite_prompt
        return handler(self, *args, **kwargs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class FlatSizeVisitor:
    """Visitor counting the leaf prompts below a node."""

    def visit_leaf_prompt(self, that: LeafPrompt) -> int:
        """A leaf counts as exactly one prompt."""
        return 1

    def visit_composite_prompt(self, that: CompositePrompt) -> int:
        """Return the total leaf count of all children; 0 when there are none."""
        if not that.children:
            return 0
        return sum(child.accept(self) for child in that.children)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
# Parser
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
|
| 130 |
+
def parse_root(string: str) -> CompositePrompt:
    """Parse a complete prompt string into its root composite node.

    The root always has weight 1.0, no conciliation strategy and no affine
    transform; its children are the top-level segments of ``string``.
    """
    top_level_segments = parse_prompts(tokenize(string))
    return CompositePrompt(1.0, None, None, top_level_segments)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
    """Parse consecutive prompt segments from ``tokens`` (consumed in place).

    At least one segment is always parsed. When ``nested`` is true, parsing
    stops in front of a closing ``]`` so the caller can consume it.
    """
    parsed = [parse_prompt(tokens, first=True, nested=nested)]
    while tokens and not (nested and tokens[0] == ']'):
        parsed.append(parse_prompt(tokens, first=False, nested=nested))
    return parsed
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _compose_affine(
|
| 146 |
+
a: Optional[torch.Tensor],
|
| 147 |
+
b: Optional[torch.Tensor],
|
| 148 |
+
) -> Optional[torch.Tensor]:
|
| 149 |
+
"""
|
| 150 |
+
Compose two 2×3 affine matrices (a applied first, then b).
|
| 151 |
+
Returns None if both are None; returns the non-None one if only one exists.
|
| 152 |
+
"""
|
| 153 |
+
if a is None:
|
| 154 |
+
return b
|
| 155 |
+
if b is None:
|
| 156 |
+
return a
|
| 157 |
+
# Lift to 3×3, multiply, trim back to 2×3
|
| 158 |
+
a3 = torch.vstack([a, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
|
| 159 |
+
b3 = torch.vstack([b, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
|
| 160 |
+
return (b3 @ a3)[:-1]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _looks_like_leading_affine_prompt(tokens: List[str]) -> bool:
    """
    Return True if `tokens` starts with one or more affine keywords followed
    immediately by a prompt keyword, e.g.:
        ROTATE[0.125] AND_PERP ...
        SLIDE[0.05,0] AND_SALT ...
        ROTATE[0.125] SLIDE[0.05,0] AND_PERP ...

    This is a look-ahead probe: it never mutates `tokens` and never builds
    an affine matrix.

    The argument token between the brackets must be acceptable to
    `_parse_affine_transform` (whitespace-only, or comma-separated floats).
    Without that check, input such as `ROTATE[abc] AND_PERP x` would be
    reported as a boundary that the affine parser then refuses to consume,
    and `_parse_prompt_text` would stop in front of it without making
    progress, so the caller would loop forever.

    Used in two places:
      - parse_prompt()        → consume leading affine before keyword
      - _parse_prompt_text()  → stop current segment before this boundary
    """
    pos = 0
    n = len(tokens)
    # skip a leading whitespace-only token
    if pos < n and not tokens[pos].strip():
        pos += 1
    # must start with an affine keyword
    if pos >= n or tokens[pos] not in affine_transforms:
        return False
    # consume one or more KEYWORD [ args ] blocks
    while pos < n and tokens[pos] in affine_transforms:
        pos += 1  # skip keyword
        if pos >= n or tokens[pos] != '[':
            return False
        pos += 1  # skip '['
        if pos < n and tokens[pos] != ']':
            arg_token = tokens[pos]
            if arg_token.strip():
                # Mirror the numeric parsing done by _parse_affine_transform.
                try:
                    for part in arg_token.split(','):
                        float(part.strip())
                except ValueError:
                    return False
            pos += 1  # skip args (may contain commas)
        if pos >= n or tokens[pos] != ']':
            return False
        pos += 1  # skip ']'
        if pos < n and not tokens[pos].strip():
            pos += 1  # skip optional whitespace
    # the very next token must be a prompt keyword
    return pos < n and tokens[pos] in _prompt_keywords_set
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
    """Parse one prompt segment from the front of ``tokens`` (consumed in place).

    Two affine-keyword orderings are supported:
      (A) AND_PERP ROTATE[0.25] text   ← trailing affine (original syntax)
      (B) ROTATE[0.25] AND_PERP text   ← leading affine  (documented syntax)

    ``first`` is accepted for interface compatibility; the body does not read
    it, so a keyword or a leading affine is handled the same way in the first
    segment as in any later one. ``nested`` is forwarded to the text parser so
    it stops in front of the ``]`` that closes an enclosing group.

    Returns a CompositePrompt when the segment is a bracketed group holding
    more than one child, otherwise a LeafPrompt.
    """
    leading_affine: Optional[torch.Tensor] = None
    if _looks_like_leading_affine_prompt(tokens):
        probe = tokens.copy()
        leading_affine = _parse_affine_transform(probe)
        tokens[:] = probe

    # Take an explicit keyword if one is next; otherwise the segment is a
    # plain AND segment.
    if tokens and tokens[0] in _prompt_keywords_set:
        prompt_type = tokens.pop(0)
    else:
        prompt_type = PromptKeyword.AND.value

    conciliation = ConciliationStrategy(prompt_type) if prompt_type in _conciliation_strategies_set else None

    # Parse optional trailing affine transform chain (original syntax)
    trailing_affine = _parse_affine_transform(tokens)
    affine_transform = _compose_affine(leading_affine, trailing_affine)

    # Try composite (bracketed) form on a copy, so that a group with a single
    # child can fall back to being parsed as plain text below.
    tokens_copy = tokens.copy()
    if tokens_copy and tokens_copy[0] == '[':
        tokens_copy.pop(0)
        prompts = parse_prompts(tokens_copy, nested=True)
        if tokens_copy:
            # Pop outside the assert: with `python -O` asserts are removed,
            # and the closing bracket must still be consumed.
            closing = tokens_copy.pop(0)
            assert closing == ']'
        if len(prompts) > 1:
            tokens[:] = tokens_copy
            weight = _parse_weight(tokens)
            return CompositePrompt(weight, conciliation, affine_transform, prompts)

    prompt_text, weight = _parse_prompt_text(tokens, nested=nested)
    return LeafPrompt(weight, conciliation, affine_transform, prompt_text)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
    """Collect the raw text of one leaf segment from ``tokens`` (consumed in place).

    Returns ``(text, weight)``. ``weight`` is 1.0 unless the segment ends with
    ``: <float>``. A ``: <float>`` pair counts as the segment weight only when
    it is followed by the end of input, a prompt keyword, or (at bracket
    depth 0) a closing ``]`` or a leading-affine boundary such as
    ``ROTATE[0.125] AND_PERP``. The last case keeps
    ``text :0.5 ROTATE[0.125] AND_PERP other`` from swallowing ``:0.5`` into
    the text.

    Collection stops, without consuming the stopping token, in front of a
    prompt keyword, a leading-affine boundary at depth 0, or (when ``nested``)
    the ``]`` that closes the enclosing group.
    """
    text = ''
    depth = 0
    weight = 1.0
    while tokens:
        tok = tokens[0]
        if tok == ']':
            if depth == 0:
                if nested:
                    break
            else:
                depth -= 1
        elif tok == '[':
            depth += 1
        elif tok == ':':
            if len(tokens) >= 2 and _is_float(tokens[1].strip()):
                ends_segment = (
                    len(tokens) < 3
                    or tokens[2] in _prompt_keywords_set
                    or (tokens[2] == ']' and depth == 0)
                    # A leading affine that introduces the next segment also
                    # terminates this one.
                    or (depth == 0 and _looks_like_leading_affine_prompt(tokens[2:]))
                )
                if ends_segment:
                    tokens.pop(0)
                    weight = float(tokens.pop(0).strip())
                    break
        elif tok in _prompt_keywords_set:
            break
        # Stop before a leading-affine-then-keyword boundary, e.g.
        #   ROTATE[0.125] AND_PERP text
        # so the affine is consumed by the *next* parse_prompt call, not
        # swallowed as raw text by the current segment.
        elif depth == 0 and _looks_like_leading_affine_prompt(tokens):
            break
        text += tokens.pop(0)
    return text, weight
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _parse_affine_transform(tokens: List[str]) -> Optional[torch.Tensor]:
    """
    Consume optional affine keywords like ROTATE[0.25] SLIDE[0.1,0] from the
    front of ``tokens``.

    Only complete ``KEYWORD [ args ]`` blocks are removed from ``tokens``; a
    malformed block is left in place together with everything after it.
    Returns a 2×3 float32 tensor, or None if no affine block was consumed.
    """
    total = len(tokens)
    cursor = 0
    # A single leading whitespace-only token is tolerated.
    if total and not tokens[0].strip():
        cursor = 1

    parsed_blocks = []   # (transform function, numeric args) per block
    consumed = None      # number of tokens covered by complete blocks

    while cursor < total and tokens[cursor] in affine_transforms:
        func = affine_transforms[tokens[cursor]]
        cursor += 1

        if cursor >= total or tokens[cursor] != '[':
            break
        cursor += 1

        args: List[float] = []
        if cursor < total and tokens[cursor] != ']':
            raw_args = tokens[cursor]
            cursor += 1
            # A whitespace-only argument token means "no arguments".
            if raw_args.strip():
                try:
                    args = [float(part.strip()) for part in raw_args.split(',')]
                except ValueError:
                    break

        if cursor >= total or tokens[cursor] != ']':
            break
        cursor += 1

        parsed_blocks.append((func, args))
        if cursor < total and not tokens[cursor].strip():
            cursor += 1
        consumed = cursor

    if consumed is not None:
        # Drop the consumed prefix in place so the caller sees the change.
        del tokens[:consumed]

    if not parsed_blocks:
        return None

    # Start from the identity and apply the blocks from last to first.
    transform = torch.eye(3, dtype=torch.float32)[:-1]  # 2×3
    for func, args in reversed(parsed_blocks):
        transform = func(transform, *args)
    return transform
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def _parse_weight(tokens: List[str]) -> float:
    """Consume a leading ``: <float>`` pair from ``tokens`` and return the float.

    Returns 1.0 and leaves ``tokens`` untouched when no such pair is present.
    """
    has_weight = len(tokens) >= 2 and tokens[0] == ':' and _is_float(tokens[1])
    if not has_weight:
        return 1.0
    del tokens[0]  # the ':' separator
    return float(tokens.pop(0))
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def tokenize(s: str) -> List[str]:
    """Split a prompt string into brackets, colons, keywords and text runs.

    Whitespace-only pieces are dropped; every other piece is returned
    verbatim, including surrounding whitespace.
    """
    # The AND_ALIGN / AND_MASK_ALIGN families are matched with general
    # patterns instead of enumerating every numeric combination.
    # Order matters: longer / more specific alternatives come first.
    # Each alternative is guarded with (?<!\w) / (?!\w) so a keyword is only
    # recognised when it is not part of a longer word.
    prompt_keyword_alternatives = (
        r'(?<!\w)AND_MASK_ALIGN_\d+_\d+(?!\w)',  # before AND_ALIGN and AND
        r'(?<!\w)AND_ALIGN_\d+_\d+(?!\w)',
        r'(?<!\w)AND_PERP(?!\w)',
        r'(?<!\w)AND_SALT_WIDE(?!\w)',  # before AND_SALT
        r'(?<!\w)AND_SALT_BLOB(?!\w)',  # before AND_SALT
        r'(?<!\w)AND_SALT(?!\w)',
        r'(?<!\w)AND_TOPK(?!\w)',
        r'(?<!\w)AND(?!\w)',
    )
    keyword_pattern = '|'.join(prompt_keyword_alternatives)
    affine_kw_pattern = '|'.join(
        rf'(?<!\w){kw}(?!\w)' for kw in affine_transforms
    )
    splitter = rf'(\[|\]|:|{keyword_pattern}|{affine_kw_pattern})'
    pieces = re.split(splitter, s)
    return [piece for piece in pieces if piece.strip()]
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _is_float(s: str) -> bool:
|
| 359 |
+
try:
|
| 360 |
+
float(s)
|
| 361 |
+
return True
|
| 362 |
+
except ValueError:
|
| 363 |
+
return False
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ---------------------------------------------------------------------------
|
| 367 |
+
# Quick smoke-test
|
| 368 |
+
# ---------------------------------------------------------------------------
|
| 369 |
+
if __name__ == '__main__':
|
| 370 |
+
res = parse_root('''
|
| 371 |
+
hello
|
| 372 |
+
AND_PERP [
|
| 373 |
+
arst
|
| 374 |
+
AND defg : 2
|
| 375 |
+
AND_SALT [
|
| 376 |
+
very nested huh? what do you say :.0
|
| 377 |
+
]
|
| 378 |
+
]
|
| 379 |
+
AND_ALIGN_4_8 watercolor style :0.5
|
| 380 |
+
ROTATE[0.125] AND_PERP vibrant colors
|
| 381 |
+
''')
|
| 382 |
+
print(res)
|
neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
|
| 5 |
+
from modules import script_callbacks, prompt_parser
|
| 6 |
+
|
| 7 |
+
# ---------------------------------------------------------------------------
|
| 8 |
+
# Fix: prompt_parser_fixed treats a standalone '&' as a multicond separator.
|
| 9 |
+
# neutral_prompt_parser does NOT treat '&' as AND — it keeps it as literal text.
|
| 10 |
+
# So a leaf like "cat & dog" transpiles to "cat & dog :1.0",
|
| 11 |
+
# and the patched prompt_parser splits that into 3 multicond branches instead of 1,
|
| 12 |
+
# which shifts all batch_cond_indices and breaks PERP/SALT/affine.
|
| 13 |
+
#
|
| 14 |
+
# Solution: escape any standalone '&' inside leaf text with '\&' before handing
|
| 15 |
+
# the string to the webui parser. The patched prompt_parser correctly unescapes
|
| 16 |
+
# '\&' back to '&' during conditioning, so the model sees the original text.
|
| 17 |
+
#
|
| 18 |
+
# "Standalone &" = surrounded by whitespace (or at start/end of string).
|
| 19 |
+
# Examples:
|
| 20 |
+
# "cat & dog" -> "cat \& dog" <- was splitting, now safe
|
| 21 |
+
# "R&D" -> "R&D" <- was NOT splitting, unchanged
|
| 22 |
+
# "\&" -> "\&" <- already escaped, not double-escaped
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
_STANDALONE_AMP = re.compile(r'(?<!\\)(?<!\S)&(?!\S)')
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _escape_leaf_ampersands(text: str) -> str:
|
| 28 |
+
"""Escape standalone '&' in a leaf prompt so patched prompt_parser
|
| 29 |
+
does not treat it as a multicond AND separator."""
|
| 30 |
+
if not text or '&' not in text:
|
| 31 |
+
return text
|
| 32 |
+
return _STANDALONE_AMP.sub(r'\\&', text)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
|
| 36 |
+
module=prompt_parser,
|
| 37 |
+
hijacker_attribute='__neutral_prompt_hijacker',
|
| 38 |
+
on_uninstall=script_callbacks.on_script_unloaded,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@prompt_parser_hijacker.hijack('get_multicond_prompt_list')
def get_multicond_prompt_list_hijack(prompts, original_function):
    """Parse the prompts, remember the trees, and hand plain webui prompts on.

    When the extension is disabled the prompts are passed through untouched.
    """
    if not global_state.is_enabled:
        return original_function(prompts)

    parsed = parse_prompts(prompts)
    global_state.prompt_exprs = parsed
    webui_prompts = transpile_exprs(parsed)

    # If the webui provides SdConditioning and the input is one, rebuild the
    # transpiled list as an SdConditioning copied from the original.
    conditioning_type = getattr(prompt_parser, 'SdConditioning', type(None))
    if isinstance(prompts, conditioning_type):
        webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)

    return original_function(webui_prompts)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
    """Parse every prompt string into its own root expression tree."""
    return [neutral_prompt_parser.parse_root(prompt) for prompt in prompts]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def transpile_exprs(exprs: List[neutral_prompt_parser.PromptExpr]) -> List[str]:
    """Render each parsed expression tree as a plain webui prompt string.

    ``exprs`` is a list with one root expression per prompt; the result holds
    the rendered strings in the same order.
    """
    # The visitor keeps no state, so one instance serves every expression.
    visitor = WebuiPromptVisitor()
    return [expr.accept(visitor) for expr in exprs]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class WebuiPromptVisitor:
    """Visitor that renders an expression tree as a flat webui prompt string."""

    def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
        # Standalone '&' is escaped so the patched prompt_parser does not
        # split this leaf into extra multicond branches.
        prompt = _escape_leaf_ampersands(that.prompt)
        return f'{prompt} :{that.weight}'

    def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
        # Children are rendered recursively and joined with the webui AND.
        rendered_children = [child.accept(self) for child in that.children]
        return ' AND '.join(rendered_children)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@prompt_parser_hijacker.hijack('reconstruct_multicond_batch')
def reconstruct_multicond_batch_hijack(*args, original_function, **kwargs):
    """Call the original function and record the first item of its result.

    That item is stored as ``global_state.batch_cond_indices`` for the
    pre-noise affine hook (affine branch); the full result is returned as is.
    """
    result = original_function(*args, **kwargs)
    global_state.batch_cond_indices = result[0]
    return result
|
neutral_prompt_patched/lib_neutral_prompt/ui.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI for sd-webui-neutral-prompt (unified).
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- Core prompt types + tooltip (main branch)
|
| 6 |
+
- AND_ALIGN / AND_MASK_ALIGN entries (alignment_blend branch)
|
| 7 |
+
- elem_id on dropdown (main branch)
|
| 8 |
+
- CFG Rescale infotext / paste support (all branches)
|
| 9 |
+
- Affine-transform hint in tooltip (affine branch)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from itertools import product
|
| 15 |
+
from typing import Callable, Dict, List, Tuple
|
| 16 |
+
|
| 17 |
+
import dataclasses
|
| 18 |
+
import gradio as gr
|
| 19 |
+
|
| 20 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 21 |
+
from modules import script_callbacks, shared
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
txt2img_prompt_textbox = None
|
| 25 |
+
img2img_prompt_textbox = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Prompt-type registry shown in the UI dropdown
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
prompt_types: Dict[str, str] = {
|
| 33 |
+
'Perpendicular (AND_PERP)': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
|
| 34 |
+
'Saliency sharp (AND_SALT)': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
|
| 35 |
+
'Saliency blob (AND_SALT_BLOB)': neutral_prompt_parser.PromptKeyword.AND_SALT_BLOB.value,
|
| 36 |
+
'Saliency wide / classic (AND_SALT_WIDE)':neutral_prompt_parser.PromptKeyword.AND_SALT_WIDE.value,
|
| 37 |
+
'Semantic guidance top-k (AND_TOPK)': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# Add AND_ALIGN_D_S entries for commonly-used kernel size pairs
|
| 41 |
+
for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
|
| 42 |
+
_key = f'Alignment blend detail={_d} structure={_s} (AND_ALIGN_{_d}_{_s})'
|
| 43 |
+
prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_ALIGN_{_d}_{_s}').value
|
| 44 |
+
|
| 45 |
+
# Add AND_MASK_ALIGN_D_S entries
|
| 46 |
+
for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
|
| 47 |
+
_key = f'Alignment mask detail={_d} structure={_s} (AND_MASK_ALIGN_{_d}_{_s})'
|
| 48 |
+
prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_MASK_ALIGN_{_d}_{_s}').value
|
| 49 |
+
|
| 50 |
+
prompt_types_tooltip = '\n'.join([
|
| 51 |
+
'AND – add all prompt features equally (webui built-in)',
|
| 52 |
+
'AND_PERP – reduce contradicting prompt features via perpendicular projection',
|
| 53 |
+
'AND_SALT – sharp saliency mask: child wins only at its 1-2 peak pixels (surgical)',
|
| 54 |
+
'AND_SALT_BLOB – blob saliency mask: peak pixels eroded to core, then grown to smooth blob',
|
| 55 |
+
'AND_SALT_WIDE – broad saliency mask: child wins ~55% of pixels (classic main-branch)',
|
| 56 |
+
'AND_TOPK – small targeted changes (semantic guidance top-k)',
|
| 57 |
+
'AND_ALIGN_D_S – soft blend: child adds details without breaking structure',
|
| 58 |
+
'AND_MASK_ALIGN_D_S– binary mask blend: stricter version of AND_ALIGN',
|
| 59 |
+
'',
|
| 60 |
+
'Affine transforms (supported in either order for auxiliary segments):',
|
| 61 |
+
' ROTATE[angle] SLIDE[x,y] SCALE[x,y] SHEAR[x,y]',
|
| 62 |
+
' e.g.: ROTATE[0.125] AND_PERP vibrant colors :0.8',
|
| 63 |
+
' e.g.: AND_PERP ROTATE[0.125] vibrant colors :0.8',
|
| 64 |
+
])
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# UI component class
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
@dataclasses.dataclass
|
| 72 |
+
class AccordionInterface:
|
| 73 |
+
get_elem_id: Callable[[str], str]
|
| 74 |
+
|
| 75 |
+
def __post_init__(self):
|
| 76 |
+
self.is_rendered = False
|
| 77 |
+
|
| 78 |
+
self.cfg_rescale = gr.Slider(
|
| 79 |
+
label='CFG rescale φ',
|
| 80 |
+
minimum=0, maximum=1, value=0,
|
| 81 |
+
info='0 = disabled. Rescales the CFG output towards the predicted x0 to reduce over-saturation.',
|
| 82 |
+
)
|
| 83 |
+
self.neutral_prompt = gr.Textbox(
|
| 84 |
+
label='Neutral prompt',
|
| 85 |
+
show_label=False,
|
| 86 |
+
lines=3,
|
| 87 |
+
placeholder=(
|
| 88 |
+
'Neutral / auxiliary prompt '
|
| 89 |
+
'(press "Apply to prompt" to append with the chosen strategy keyword)'
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
+
self.neutral_cond_scale = gr.Slider(
|
| 93 |
+
label='Prompt weight',
|
| 94 |
+
minimum=-3, maximum=3, value=1,
|
| 95 |
+
)
|
| 96 |
+
self.aux_prompt_type = gr.Dropdown(
|
| 97 |
+
label='Prompt type',
|
| 98 |
+
choices=list(prompt_types.keys()),
|
| 99 |
+
value=next(iter(prompt_types.keys())),
|
| 100 |
+
tooltip=prompt_types_tooltip,
|
| 101 |
+
elem_id=self.get_elem_id('formatter_prompt_type'),
|
| 102 |
+
)
|
| 103 |
+
self.append_to_prompt_button = gr.Button(value='Apply to prompt')
|
| 104 |
+
|
| 105 |
+
# ------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
def arrange_components(self, is_img2img: bool) -> None:
|
| 108 |
+
if self.is_rendered:
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
with gr.Accordion(label='Neutral Prompt', open=False):
|
| 112 |
+
self.cfg_rescale.render()
|
| 113 |
+
with gr.Accordion(label='Prompt formatter', open=False):
|
| 114 |
+
self.neutral_prompt.render()
|
| 115 |
+
self.neutral_cond_scale.render()
|
| 116 |
+
self.aux_prompt_type.render()
|
| 117 |
+
self.append_to_prompt_button.render()
|
| 118 |
+
|
| 119 |
+
def connect_events(self, is_img2img: bool) -> None:
|
| 120 |
+
if self.is_rendered:
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
|
| 124 |
+
self.append_to_prompt_button.click(
|
| 125 |
+
fn=lambda init_prompt, prompt, scale, prompt_type: (
|
| 126 |
+
f'{init_prompt}\n{prompt_types[prompt_type]} {prompt} :{scale}',
|
| 127 |
+
'',
|
| 128 |
+
),
|
| 129 |
+
inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale, self.aux_prompt_type],
|
| 130 |
+
outputs=[prompt_textbox, self.neutral_prompt],
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def set_rendered(self, value: bool = True) -> None:
|
| 134 |
+
self.is_rendered = value
|
| 135 |
+
|
| 136 |
+
# ------------------------------------------------------------------
|
| 137 |
+
# Script interface (args / infotext / paste)
|
| 138 |
+
|
| 139 |
+
def get_components(self) -> Tuple[gr.components.Component, ...]:
|
| 140 |
+
return (self.cfg_rescale,)
|
| 141 |
+
|
| 142 |
+
def get_infotext_fields(self) -> Tuple[Tuple[gr.components.Component, str], ...]:
|
| 143 |
+
return tuple(zip(self.get_components(), ('CFG Rescale phi',)))
|
| 144 |
+
|
| 145 |
+
def get_paste_field_names(self) -> List[str]:
|
| 146 |
+
return ['CFG Rescale phi']
|
| 147 |
+
|
| 148 |
+
def get_extra_generation_params(self, args: Dict) -> Dict:
|
| 149 |
+
return {'CFG Rescale phi': args['cfg_rescale']}
|
| 150 |
+
|
| 151 |
+
def unpack_processing_args(self, cfg_rescale: float) -> Dict:
|
| 152 |
+
return {'cfg_rescale': cfg_rescale}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# Settings & callbacks
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
def on_ui_settings() -> None:
    """Register the extension options and seed ``global_state`` from them."""
    section = ('neutral_prompt', 'Neutral Prompt')
    opts = shared.opts

    # Master switch for the extension.
    enabled_info = shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section)
    opts.add_option('neutral_prompt_enabled', enabled_info)
    global_state.is_enabled = opts.data.get('neutral_prompt_enabled', True)
    opts.onchange('neutral_prompt_enabled', _update_enabled)

    # Verbose debugging switch.
    verbose_info = shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section)
    opts.add_option('neutral_prompt_verbose', verbose_info)
    global_state.verbose = opts.data.get('neutral_prompt_verbose', False)
    opts.onchange('neutral_prompt_verbose', _update_verbose)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _update_enabled() -> None:
    """Copy the current 'neutral_prompt_enabled' option into ``global_state``."""
    enabled = shared.opts.data.get('neutral_prompt_enabled', True)
    global_state.is_enabled = enabled
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _update_verbose() -> None:
    """Copy the current 'neutral_prompt_verbose' option into ``global_state``."""
    verbose = shared.opts.data.get('neutral_prompt_verbose', False)
    global_state.verbose = verbose
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def on_after_component(component, **_kwargs) -> None:
    """Remember the txt2img / img2img prompt textboxes as they are created.

    Components are recognised by ``elem_id``; anything else is ignored.
    """
    global txt2img_prompt_textbox, img2img_prompt_textbox
    elem_id = getattr(component, 'elem_id', None)
    if elem_id == 'txt2img_prompt':
        txt2img_prompt_textbox = component
        return
    if elem_id == 'img2img_prompt':
        img2img_prompt_textbox = component
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
script_callbacks.on_after_component(on_after_component)
|
neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from types import ModuleType
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from modules import scripts
|
| 5 |
+
from lib_neutral_prompt import global_state
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def patch():
    """Add the CFG-rescale axis to the XYZ grid script, if that script is loaded."""
    xyz_module = find_xyz_module()
    if xyz_module is None:
        print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
        return

    axis_options = xyz_module.axis_options
    cfg_rescale_axis = xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale())
    axis_options.extend([cfg_rescale_axis])
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class XyzFloat(float):
    """Float tagged with ``is_xyz`` so consumers can tell it apart from a plain float."""

    # Class-level default; instances may override it.
    is_xyz: bool = True
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def apply_cfg_rescale():
    """Build the axis callback that stores the axis value as the CFG rescale.

    The callback takes three positional arguments and uses only the second:
    it wraps that value in ``XyzFloat`` and assigns it to
    ``global_state.cfg_rescale``.
    """
    def callback(_processing, value, _all_values):
        global_state.cfg_rescale = XyzFloat(value)

    return callback
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def int_or_float(string):
    """Convert ``string`` to an int when possible, otherwise to a float.

    Raises ValueError when the text is neither.
    """
    try:
        number = int(string)
    except ValueError:
        number = float(string)
    return number
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def find_xyz_module() -> Optional[ModuleType]:
    """Return the module of the loaded XYZ grid script, or None if absent.

    A script entry matches when its class comes from ``xyz_grid.py`` or
    ``xy_grid.py`` and the entry exposes a ``module`` attribute.
    """
    grid_script_names = {"xyz_grid.py", "xy_grid.py"}
    candidates = (
        data.module
        for data in scripts.scripts_data
        if data.script_class.__module__ in grid_script_names and hasattr(data, "module")
    )
    return next(candidates, None)
|
neutral_prompt_patched/scripts/neutral_prompt.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
|
| 2 |
+
from modules import scripts, processing, shared
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class NeutralPromptScript(scripts.Script):
    """Always-visible webui script that wires up the Neutral Prompt UI and state."""

    def __init__(self):
        self.accordion_interface = None
        self._is_img2img = False

    @property
    def is_img2img(self):
        return self._is_img2img

    @is_img2img.setter
    def is_img2img(self, is_img2img):
        # Assigning this attribute also creates the accordion on first use.
        self._is_img2img = is_img2img
        if self.accordion_interface is None:
            self.accordion_interface = ui.AccordionInterface(self.elem_id)

    def title(self) -> str:
        return "Neutral Prompt"

    def show(self, is_img2img: bool):
        return scripts.AlwaysVisible

    def ui(self, is_img2img: bool):
        """Render the accordion, connect its events and return its components."""
        self.hijack_composable_lora(is_img2img)

        accordion = self.accordion_interface
        accordion.arrange_components(is_img2img)
        accordion.connect_events(is_img2img)
        self.infotext_fields = accordion.get_infotext_fields()
        self.paste_field_names = accordion.get_paste_field_names()
        accordion.set_rendered()
        return accordion.get_components()

    def process(self, p: processing.StableDiffusionProcessing, *args):
        """Sync the UI values into ``global_state`` and record them as generation params."""
        args = self.accordion_interface.unpack_processing_args(*args)

        self.update_global_state(args)
        if global_state.is_enabled:
            extra_params = self.accordion_interface.get_extra_generation_params(args)
            p.extra_generation_params.update(extra_params)

    def update_global_state(self, args: Dict):
        """Copy UI values into ``global_state``; values flagged ``is_xyz`` win.

        A flagged value already present in ``global_state`` is written back
        into ``args`` with its flag cleared. UI values are written to
        ``global_state`` only while the job number is not above 0.
        """
        job_no = shared.state.job_no
        if job_no == 0:
            global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)

        missing = object()
        for name, value in args.items():
            current = getattr(global_state, name, missing)
            if current is missing:
                # Not a known state attribute: ignore it.
                continue

            if getattr(current, 'is_xyz', False):
                current.is_xyz = False
                args[name] = current
            elif not job_no > 0:
                setattr(global_state, name, value)

    def hijack_composable_lora(self, is_img2img):
        """Wrap the ``process`` method of the "composable lora" script, if present."""
        if self.accordion_interface.is_rendered:
            return

        script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
        lora_script = next(
            (
                script
                for script in script_runner.alwayson_scripts
                if script.title().lower() == "composable lora"
            ),
            None,
        )

        if lora_script is not None:
            lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
    """Run the wrapped ``process`` against transpiled prompts.

    When the extension is disabled the call is forwarded untouched.
    Otherwise ``p.all_prompts`` is parsed with
    ``prompt_parser_hijack.parse_prompts``, temporarily replaced by the
    result of ``prompt_parser_hijack.transpile_exprs``, and the wrapped
    function is invoked with the substituted prompts.

    Args:
        p: processing object whose ``all_prompts`` is temporarily swapped.
        *args: positional arguments forwarded to ``original_function``.
        original_function: the ``process`` callable being wrapped.
        **kwargs: keyword arguments forwarded to ``original_function``.

    Returns:
        Whatever ``original_function`` returns.
    """
    if not global_state.is_enabled:
        return original_function(p, *args, **kwargs)

    exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
    original_prompts = p.all_prompts
    p.all_prompts = prompt_parser_hijack.transpile_exprs(exprs)
    try:
        return original_function(p, *args, **kwargs)
    finally:
        # Restore the original prompts even when the wrapped process raises,
        # so the transpiled prompts never leak into later processing steps.
        p.all_prompts = original_prompts
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Module-level call: runs xyz_grid.patch() every time this script module is executed.
# NOTE(review): what patch() registers is defined in the xyz_grid module, not visible here.
xyz_grid.patch()
|
neutral_prompt_patched/test/perp_parser/__init__.py
ADDED
|
File without changes
|
neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for affine-keyword ordering in parse_prompt.
|
| 3 |
+
|
| 4 |
+
Both orderings must work correctly:
|
| 5 |
+
(A) AND_PERP ROTATE[0.125] vivid colors ← trailing affine (original)
|
| 6 |
+
(B) ROTATE[0.125] AND_PERP vivid colors ← leading affine (documented)
|
| 7 |
+
"""
|
| 8 |
+
import unittest
|
| 9 |
+
|
| 10 |
+
from lib_neutral_prompt import neutral_prompt_parser as p
|
| 11 |
+
from lib_neutral_prompt.neutral_prompt_parser import ConciliationStrategy
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestAffineKeywordOrder(unittest.TestCase):
    """Parser checks for affine keywords written after or before a conciliation keyword."""

    # ---- trailing affine: AND_PERP ROTATE[...] text --------------------

    def test_trailing_affine_basic(self):
        """An affine written after AND_PERP lands on the second child only."""
        parsed = p.parse_root("base AND_PERP ROTATE[0.125] vivid colors :0.8")
        self.assertEqual(len(parsed.children), 2)
        base_child, perp_child = parsed.children
        self.assertEqual(base_child.prompt.strip(), "base")
        self.assertIsNone(base_child.local_transform)
        self.assertIsNotNone(perp_child.local_transform)
        self.assertEqual(perp_child.conciliation, ConciliationStrategy.PERPENDICULAR)
        self.assertAlmostEqual(perp_child.weight, 0.8)

    # ---- leading affine: ROTATE[...] AND_PERP text ---------------------

    def test_leading_affine_mid_prompt(self):
        """A leading ROTATE[...] must belong to the AND_PERP segment, not to the text before it."""
        parsed = p.parse_root("base ROTATE[0.125] AND_PERP vivid colors :0.8")
        described = [getattr(c, 'prompt', 'composite') for c in parsed.children]
        self.assertEqual(
            len(parsed.children),
            2,
            f"Expected 2 children, got {len(parsed.children)}: {described}",
        )
        first, second = parsed.children
        # The first segment keeps only the plain text and carries no transform.
        self.assertIn("base", first.prompt)
        self.assertNotIn("ROTATE", first.prompt)
        self.assertIsNone(first.local_transform)
        # The transform is attached to the second segment.
        self.assertIsNotNone(second.local_transform)
        self.assertEqual(second.conciliation, ConciliationStrategy.PERPENDICULAR)
        self.assertAlmostEqual(second.weight, 0.8)

    def test_leading_affine_at_start(self):
        """A prompt that opens with ROTATE[...] AND_PERP yields a single transformed child."""
        parsed = p.parse_root("ROTATE[0.125] AND_PERP vivid colors :0.8")
        self.assertEqual(len(parsed.children), 1)
        only_child = parsed.children[0]
        self.assertIsNotNone(only_child.local_transform)
        self.assertEqual(only_child.conciliation, ConciliationStrategy.PERPENDICULAR)
        self.assertIn("vivid", only_child.prompt)
        self.assertNotIn("ROTATE", only_child.prompt)

    def test_leading_affine_with_and_salt(self):
        """A leading SLIDE[...] combines with AND_SALT on the second child."""
        parsed = p.parse_root("portrait SLIDE[0.05,0] AND_SALT vibrant :1.2")
        self.assertEqual(len(parsed.children), 2)
        salt_child = parsed.children[1]
        self.assertIsNotNone(salt_child.local_transform)
        self.assertEqual(salt_child.conciliation, ConciliationStrategy.SALIENCE_MASK)

    # ---- composed / chained affine -------------------------------------

    def test_leading_and_trailing_compose(self):
        """Chained affine and AND_PERP keywords parse without raising and yield children."""
        parsed = p.parse_root("base AND_PERP ROTATE[0.125] AND_PERP SCALE[1.5,1.5] vivid")
        # Only a crash-free parse with at least one child is required here.
        self.assertGreaterEqual(len(parsed.children), 1)

    # ---- CFGRescaleFactorSingleton lifecycle ---------------------------

    def test_cfg_rescale_singleton_clear(self):
        """clear() resets the stored factor to None; set() makes it readable via get()."""
        from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S

        S.clear()
        self.assertIsNone(S.get())

        S.set(1.23)
        self.assertAlmostEqual(S.get(), 1.23)

        S.clear()
        self.assertIsNone(S.get())

    def test_cfg_rescale_singleton_thread_local(self):
        """Each thread reads back the value it set itself, despite concurrent writers."""
        import threading
        import time
        from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S

        observed = {}

        def worker(val, key):
            S.clear()
            S.set(val)
            # Give the other threads time to write their own value.
            time.sleep(0.01)
            observed[key] = S.get()

        workers = []
        for idx in range(3):
            workers.append(threading.Thread(target=worker, args=(idx * 10.0, idx)))
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        for idx in range(3):
            self.assertAlmostEqual(
                observed[idx],
                idx * 10.0,
                msg=f"Thread {idx} got {observed[idx]} expected {idx*10.0}",
            )

    # ---- newer conciliation keywords -----------------------------------

    def test_and_salt_wide_parsed(self):
        """AND_SALT_WIDE maps to SALIENCE_MASK_WIDE."""
        parsed = p.parse_root("base AND_SALT_WIDE vivid")
        self.assertEqual(len(parsed.children), 2)
        strategy = parsed.children[1].conciliation
        self.assertEqual(strategy, p.ConciliationStrategy.SALIENCE_MASK_WIDE)

    def test_and_salt_blob_parsed(self):
        """AND_SALT_BLOB maps to SALIENCE_MASK_BLOB."""
        parsed = p.parse_root("base AND_SALT_BLOB vivid")
        self.assertEqual(len(parsed.children), 2)
        strategy = parsed.children[1].conciliation
        self.assertEqual(strategy, p.ConciliationStrategy.SALIENCE_MASK_BLOB)

    def test_and_align_parsed(self):
        """AND_ALIGN_4_8 appears in the conciliation value of the second child."""
        parsed = p.parse_root("base AND_ALIGN_4_8 vivid")
        self.assertEqual(len(parsed.children), 2)
        strategy_value = parsed.children[1].conciliation.value
        self.assertIn('AND_ALIGN_4_8', strategy_value)

    def test_and_mask_align_parsed(self):
        """AND_MASK_ALIGN_4_8 appears in the conciliation value of the second child."""
        parsed = p.parse_root("base AND_MASK_ALIGN_4_8 vivid")
        self.assertEqual(len(parsed.children), 2)
        strategy_value = parsed.children[1].conciliation.value
        self.assertIn('AND_MASK_ALIGN_4_8', strategy_value)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
|
| 133 |
+
unittest.main()
|
neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration smoke tests for the pre-noise affine pipeline.
|
| 3 |
+
|
| 4 |
+
Requires PyTorch. If torch is not installed the entire module is skipped via
|
| 5 |
+
``unittest.skipUnless``, so ``unittest discover`` still exits cleanly.
|
| 6 |
+
|
| 7 |
+
Mocked: A1111 modules only (modules.script_callbacks, modules.sd_samplers,
|
| 8 |
+
modules.shared). torch itself is real — no sys.modules replacement.
|
| 9 |
+
|
| 10 |
+
Tests:
|
| 11 |
+
1. Hook is a no-op when global_state.is_enabled = False.
|
| 12 |
+
2. Identity transform (angle=0) leaves x numerically unchanged.
|
| 13 |
+
3. Non-identity transform calls apply_affine_transform (verified by spy).
|
| 14 |
+
4. Prompts with no affine leave x unchanged.
|
| 15 |
+
5. Empty state does not crash.
|
| 16 |
+
6. Batch with multiple prompts does not crash.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import dataclasses
|
| 20 |
+
import importlib
|
| 21 |
+
import math
|
| 22 |
+
import sys
|
| 23 |
+
import types
|
| 24 |
+
import unittest
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Detect torch availability — skip the whole module if absent
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# torch counts as available only when an import spec with a concrete origin
# can be located; a failed lookup is treated as "not installed".
try:
    import importlib.util as _ilu
    _spec = _ilu.find_spec('torch')
except (ValueError, ModuleNotFoundError):
    _TORCH_AVAILABLE = False
else:
    _TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# A1111 module stubs (installed unconditionally so cfg_denoiser_hijack can
|
| 39 |
+
# be imported even when running the skip path)
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def _stub(name):
    """Return ``sys.modules[name]``, registering an empty module first if the name is absent.

    An entry that already exists under ``name`` is never replaced.
    """
    if name not in sys.modules:
        sys.modules[name] = types.ModuleType(name)
    return sys.modules[name]
|
| 46 |
+
|
| 47 |
+
# Make every A1111 / gradio module name used below importable, without
# replacing anything that is already present in sys.modules.
for _mod_name in (
    'modules',
    'modules.script_callbacks',
    'modules.sd_samplers',
    'modules.shared',
    'modules.prompt_parser',
    'gradio',
):
    _stub(_mod_name)

import modules.script_callbacks as _sc

# Callback-registration entry points fall back to no-ops when missing.
for _hook_name in ('on_cfg_denoiser', 'on_script_unloaded'):
    if not hasattr(_sc, _hook_name):
        setattr(_sc, _hook_name, lambda fn: None)

import modules.shared as _sh

if not hasattr(_sh, 'opts'):
    # Minimal options object: an empty data dict plus do-nothing callables.
    _sh.opts = types.SimpleNamespace(
        data={},
        add_option=lambda *a, **k: None,
        onchange=lambda *a, **k: None,
        OptionInfo=lambda *a, **k: None,
    )

import modules.sd_samplers as _ss

if not hasattr(_ss, 'cfg_denoiser'):
    _ss.cfg_denoiser = None

# cfg_denoiser_hijack installs a hijack on create_sampler at import time
if not hasattr(_ss, 'create_sampler'):
    _ss.create_sampler = lambda *a, **k: None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# Repo root on sys.path
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
from pathlib import Path
|
| 79 |
+
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 80 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 81 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Conditional imports (only when torch is available)
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
if _TORCH_AVAILABLE:
    import torch
    from lib_neutral_prompt import global_state, neutral_prompt_parser
    from lib_neutral_prompt.cfg_denoiser_hijack import _on_cfg_denoiser
    import lib_neutral_prompt.affine_transform as _at_mod

    # ---- helpers for building prompt trees and hook parameters ----------

    def _leaf(text, angle=None):
        """Build a weight-1.0 LeafPrompt for ``text``.

        When ``angle`` is given (in turns, i.e. multiplied by 2*pi), the leaf
        carries a 2x3 rotation matrix as its transform; otherwise the
        transform is ``None``.
        """
        if angle is None:
            transform = None
        else:
            theta = angle * 2 * math.pi
            cos_t = math.cos(theta)
            sin_t = math.sin(theta)
            transform = torch.tensor([[cos_t, -sin_t, 0.0], [sin_t, cos_t, 0.0]])
        return neutral_prompt_parser.LeafPrompt(1.0, None, transform, text)

    def _comp(*children):
        """Build a weight-1.0 CompositePrompt holding ``children`` as a list."""
        members = list(children)
        return neutral_prompt_parser.CompositePrompt(1.0, None, None, members)

    @dataclasses.dataclass
    class _FakeParams:
        """Stand-in for the hook's parameter object; only the ``x`` tensor is provided."""

        x: torch.Tensor
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# Test class
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping pipeline tests')
class TestOnCfgDenoiserSmokeTest(unittest.TestCase):
    """Smoke tests that call ``_on_cfg_denoiser`` with hand-built prompt trees."""

    @staticmethod
    def _reset_state(enabled):
        """Set the enabled flag and empty both prompt-related lists on global_state."""
        global_state.is_enabled = enabled
        global_state.prompt_exprs = []
        global_state.batch_cond_indices = []

    def setUp(self):
        self._reset_state(True)

    def tearDown(self):
        self._reset_state(False)

    # 1. disabled -> x untouched
    def test_disabled_leaves_x_unchanged(self):
        """With the extension disabled the hook must leave ``params.x`` bit-identical."""
        global_state.is_enabled = False
        latent = torch.zeros(2, 4, 8, 8)
        latent[0] = 1.0
        expected = latent.clone()
        params = _FakeParams(x=latent.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.125))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]

        _on_cfg_denoiser(params)

        unchanged = torch.equal(params.x, expected)
        self.assertTrue(unchanged, "Disabled hook must not modify x")

    # 2. identity transform -> x numerically unchanged
    def test_identity_transform_no_change(self):
        """A zero-angle rotation must leave ``params.x`` numerically unchanged."""
        latent = torch.randn(2, 4, 8, 8)
        expected = latent.clone()
        params = _FakeParams(x=latent.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]

        _on_cfg_denoiser(params)

        close_enough = torch.allclose(params.x, expected, atol=1e-5)
        self.assertTrue(close_enough, "Identity affine must not change x")

    # 3. non-identity -> apply_affine_transform is called (spy)
    def test_nonidentity_calls_apply_affine_transform(self):
        """A quarter-turn rotation must route through ``apply_affine_transform``."""
        latent = torch.zeros(2, 4, 8, 8)
        latent[1] = torch.randn(4, 8, 8)
        params = _FakeParams(x=latent.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.25))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]

        genuine = _at_mod.apply_affine_transform
        calls = []

        def _spy(tensor, affine, mode='bilinear'):
            # Record the call, then defer to the real implementation.
            calls.append(True)
            return genuine(tensor, affine, mode)

        _at_mod.apply_affine_transform = _spy
        try:
            _on_cfg_denoiser(params)
        finally:
            # Always put the real function back, even if the hook raised.
            _at_mod.apply_affine_transform = genuine

        self.assertGreater(
            len(calls),
            0,
            "apply_affine_transform must be called for non-identity transform",
        )

    # 4. no affine on any child -> x unchanged
    def test_no_affine_no_change(self):
        """Prompt trees without any transform must leave ``params.x`` numerically unchanged."""
        latent = torch.randn(2, 4, 8, 8)
        expected = latent.clone()
        params = _FakeParams(x=latent.clone())
        global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp"))]
        global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]

        _on_cfg_denoiser(params)

        close_enough = torch.allclose(params.x, expected, atol=1e-5)
        self.assertTrue(close_enough, "No affine → x must not change")

    # 5. empty state -> no crash
    def test_empty_state_no_crash(self):
        """The hook must not raise when no prompts and no indices are registered."""
        params = _FakeParams(x=torch.randn(1, 4, 8, 8))
        global_state.prompt_exprs = []
        global_state.batch_cond_indices = []
        try:
            _on_cfg_denoiser(params)
        except Exception as e:
            self.fail(f"_on_cfg_denoiser raised with empty state: {e}")

    # 6. two batch entries -> no crash
    def test_batch_multiple_prompts_no_crash(self):
        """The hook must not raise for two prompt trees, each with its own index pair."""
        params = _FakeParams(x=torch.zeros(4, 4, 8, 8))
        first_prompt = _comp(_leaf("base_a"), _leaf("perp_a", 0.25))
        second_prompt = _comp(_leaf("base_b"), _leaf("perp_b", 0.125))
        global_state.prompt_exprs = [first_prompt, second_prompt]
        global_state.batch_cond_indices = [
            [(0, 1.0), (1, 1.0)],
            [(2, 1.0), (3, 1.0)],
        ]
        try:
            _on_cfg_denoiser(params)
        except Exception as e:
            self.fail(f"Multi-prompt batch raised: {e}")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == '__main__':
|
| 217 |
+
unittest.main()
|
neutral_prompt_patched/test/perp_parser/test_basic_parser.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import pathlib
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
| 5 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestPromptParser(unittest.TestCase):
    """Structural checks on the prompt trees returned by ``parse_root``."""

    def setUp(self):
        """Parse one fixture prompt per scenario exercised by the tests below."""
        parse = neutral_prompt_parser.parse_root
        self.simple_prompt = parse("hello :1.0")
        self.and_prompt = parse("hello AND goodbye :2.0")
        self.and_perp_prompt = parse("hello :1.0 AND_PERP goodbye :2.0")
        self.and_salt_prompt = parse("hello :1.0 AND_SALT goodbye :2.0")
        self.nested_and_perp_prompt = parse("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
        self.nested_and_salt_prompt = parse("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
        self.invalid_weight = parse("hello :not_a_float")

    # ---- single prompt --------------------------------------------------

    def test_simple_prompt_child_count(self):
        children = self.simple_prompt.children
        self.assertEqual(len(children), 1)

    def test_simple_prompt_child_weight(self):
        only_child = self.simple_prompt.children[0]
        self.assertEqual(only_child.weight, 1.0)

    def test_simple_prompt_child_prompt(self):
        only_child = self.simple_prompt.children[0]
        self.assertEqual(only_child.prompt, "hello ")

    def test_square_weight_prompt(self):
        """Square-bracket syntax containing colons is kept verbatim as prompt text."""
        prompt = "a [b c d e : f g h :1.5]"
        alone = neutral_prompt_parser.parse_root(prompt)
        self.assertEqual(alone.children[0].prompt, prompt)

        composed_prompt = f"{prompt} AND_PERP other prompt"
        composed = neutral_prompt_parser.parse_root(composed_prompt)
        self.assertEqual(composed.children[0].prompt, prompt)

    # ---- AND ------------------------------------------------------------

    def test_and_prompt_child_count(self):
        children = self.and_prompt.children
        self.assertEqual(len(children), 2)

    def test_and_prompt_child_weights_and_prompts(self):
        left = self.and_prompt.children[0]
        self.assertEqual(left.weight, 1.0)
        self.assertEqual(left.prompt, "hello ")
        right = self.and_prompt.children[1]
        self.assertEqual(right.weight, 2.0)
        self.assertEqual(right.prompt, " goodbye ")

    # ---- AND_PERP -------------------------------------------------------

    def test_and_perp_prompt_child_count(self):
        children = self.and_perp_prompt.children
        self.assertEqual(len(children), 2)

    def test_and_perp_prompt_child_types(self):
        for index in (0, 1):
            child = self.and_perp_prompt.children[index]
            self.assertIsInstance(child, neutral_prompt_parser.LeafPrompt)

    def test_and_perp_prompt_nested_child(self):
        perp_child = self.and_perp_prompt.children[1]
        self.assertEqual(perp_child.weight, 2.0)
        self.assertEqual(perp_child.prompt.strip(), "goodbye")

    def test_nested_and_perp_prompt_child_count(self):
        children = self.nested_and_perp_prompt.children
        self.assertEqual(len(children), 2)

    def test_nested_and_perp_prompt_child_types(self):
        leaf = self.nested_and_perp_prompt.children[0]
        self.assertIsInstance(leaf, neutral_prompt_parser.LeafPrompt)
        group = self.nested_and_perp_prompt.children[1]
        self.assertIsInstance(group, neutral_prompt_parser.CompositePrompt)

    def test_nested_and_perp_prompt_nested_child_types(self):
        for index in (0, 1):
            inner = self.nested_and_perp_prompt.children[1].children[index]
            self.assertIsInstance(inner, neutral_prompt_parser.LeafPrompt)

    def test_nested_and_perp_prompt_nested_child(self):
        innermost = self.nested_and_perp_prompt.children[1].children[1]
        self.assertEqual(innermost.weight, 3.0)
        self.assertEqual(innermost.prompt.strip(), "welcome")

    # ---- invalid weight -------------------------------------------------

    def test_invalid_weight_child_count(self):
        children = self.invalid_weight.children
        self.assertEqual(len(children), 1)

    def test_invalid_weight_child_weight(self):
        only_child = self.invalid_weight.children[0]
        self.assertEqual(only_child.weight, 1.0)

    def test_invalid_weight_child_prompt(self):
        only_child = self.invalid_weight.children[0]
        self.assertEqual(only_child.prompt, "hello :not_a_float")

    # ---- AND_SALT -------------------------------------------------------

    def test_and_salt_prompt_child_count(self):
        children = self.and_salt_prompt.children
        self.assertEqual(len(children), 2)

    def test_and_salt_prompt_child_types(self):
        for index in (0, 1):
            child = self.and_salt_prompt.children[index]
            self.assertIsInstance(child, neutral_prompt_parser.LeafPrompt)

    def test_and_salt_prompt_nested_child(self):
        salt_child = self.and_salt_prompt.children[1]
        self.assertEqual(salt_child.weight, 2.0)
        self.assertEqual(salt_child.prompt.strip(), "goodbye")

    def test_nested_and_salt_prompt_child_count(self):
        children = self.nested_and_salt_prompt.children
        self.assertEqual(len(children), 2)

    def test_nested_and_salt_prompt_child_types(self):
        leaf = self.nested_and_salt_prompt.children[0]
        self.assertIsInstance(leaf, neutral_prompt_parser.LeafPrompt)
        group = self.nested_and_salt_prompt.children[1]
        self.assertIsInstance(group, neutral_prompt_parser.CompositePrompt)

    def test_nested_and_salt_prompt_nested_child_types(self):
        for index in (0, 1):
            inner = self.nested_and_salt_prompt.children[1].children[index]
            self.assertIsInstance(inner, neutral_prompt_parser.LeafPrompt)

    def test_nested_and_salt_prompt_nested_child(self):
        innermost = self.nested_and_salt_prompt.children[1].children[1]
        self.assertEqual(innermost.weight, 3.0)
        self.assertEqual(innermost.prompt.strip(), "welcome")

    # ---- prompt editing syntax ------------------------------------------

    def test_start_with_prompt_editing(self):
        """A prompt opening with prompt-editing brackets stays one weight-1.0 leaf."""
        prompt = "[(long shot:1.2):0.1] detail.."
        parsed = neutral_prompt_parser.parse_root(prompt)
        first_child = parsed.children[0]
        self.assertEqual(first_child.weight, 1.0)
        self.assertEqual(first_child.prompt, prompt)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == '__main__':
|
| 122 |
+
unittest.main()
|
neutral_prompt_patched/test/perp_parser/test_malicious_parser.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import pathlib
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
| 5 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestMaliciousPromptParser(unittest.TestCase):
|
| 9 |
+
def setUp(self):
|
| 10 |
+
self.parser = neutral_prompt_parser
|
| 11 |
+
|
| 12 |
+
def test_empty(self):
|
| 13 |
+
result = self.parser.parse_root("")
|
| 14 |
+
self.assertEqual(result.children[0].prompt, "")
|
| 15 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 16 |
+
|
| 17 |
+
def test_zero_weight(self):
|
| 18 |
+
result = self.parser.parse_root("hello :0.0")
|
| 19 |
+
self.assertEqual(result.children[0].weight, 0.0)
|
| 20 |
+
|
| 21 |
+
def test_mixed_positive_and_negative_weights(self):
|
| 22 |
+
result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
|
| 23 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 24 |
+
self.assertEqual(result.children[1].weight, -2.0)
|
| 25 |
+
|
| 26 |
+
def test_debalanced_square_brackets(self):
|
| 27 |
+
prompt = "a [ b " * 100
|
| 28 |
+
result = self.parser.parse_root(prompt)
|
| 29 |
+
self.assertEqual(result.children[0].prompt, prompt)
|
| 30 |
+
|
| 31 |
+
prompt = "a ] b " * 100
|
| 32 |
+
result = self.parser.parse_root(prompt)
|
| 33 |
+
self.assertEqual(result.children[0].prompt, prompt)
|
| 34 |
+
|
| 35 |
+
repeats = 10
|
| 36 |
+
prompt = "a [ [ b AND c ] " * repeats
|
| 37 |
+
result = self.parser.parse_root(prompt)
|
| 38 |
+
self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
|
| 39 |
+
|
| 40 |
+
repeats = 10
|
| 41 |
+
prompt = "a [ b AND c ] ] " * repeats
|
| 42 |
+
result = self.parser.parse_root(prompt)
|
| 43 |
+
self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
|
| 44 |
+
|
| 45 |
+
def test_erroneous_syntax(self):
|
| 46 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
|
| 47 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 48 |
+
self.assertEqual(result.children[1].prompt, "[goodbye ")
|
| 49 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 50 |
+
|
| 51 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
|
| 52 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 53 |
+
self.assertEqual(result.children[1].prompt, " goodbye ")
|
| 54 |
+
|
| 55 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
|
| 56 |
+
self.assertEqual(result.children[1].prompt, " goodbye]")
|
| 57 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 58 |
+
|
| 59 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
|
| 60 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 61 |
+
self.assertEqual(result.children[1].prompt, " a [ goodbye ")
|
| 62 |
+
|
| 63 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
|
| 64 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 65 |
+
self.assertEqual(result.children[2].prompt, " goodbye ")
|
| 66 |
+
|
| 67 |
+
def test_huge_number_of_prompt_parts(self):
|
| 68 |
+
result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
|
| 69 |
+
self.assertEqual(len(result.children), 10**4)
|
| 70 |
+
|
| 71 |
+
def test_prompt_ending_with_weight(self):
|
| 72 |
+
result = self.parser.parse_root("hello :1.0 AND :2.0")
|
| 73 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 74 |
+
self.assertEqual(result.children[1].prompt, "")
|
| 75 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 76 |
+
|
| 77 |
+
def test_huge_input_string(self):
|
| 78 |
+
big_string = "hello :1.0 AND " * 10**4
|
| 79 |
+
result = self.parser.parse_root(big_string)
|
| 80 |
+
self.assertEqual(len(result.children), 10**4 + 1)
|
| 81 |
+
|
| 82 |
+
def test_deeply_nested_prompt(self):
    # 100 nested AND_PERP groups followed by 100 closing brackets: the second
    # top-level child must be a composite prompt.
    depth = 100
    openings = " AND_PERP [goodbye :2.0" * depth
    closings = "]" * depth
    tree = self.parser.parse_root("hello :1.0" + openings + closings)
    self.assertIsInstance(tree.children[1], neutral_prompt_parser.CompositePrompt)
|
| 86 |
+
|
| 87 |
+
def test_complex_nested_prompts(self):
    # Two plain prompts followed by a bracketed AND_PERP group holding three
    # weighted prompts; every weight must end up on the matching node.
    tree = self.parser.parse_root(
        "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
    )
    self.assertEqual(tree.children[0].weight, 1.0)
    self.assertEqual(tree.children[1].weight, 2.0)

    # Weights inside the bracketed group, checked in order.
    group = tree.children[2]
    for position, expected_weight in enumerate((3.0, 4.0, 5.0)):
        self.assertEqual(group.children[position].weight, expected_weight)
|
| 95 |
+
|
| 96 |
+
def test_string_with_random_characters(self):
    # Smoke test: a mix of letters, punctuation, digits, brackets and an
    # embedded AND_PERP must be parsed without raising.
    prompt = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with random characters.")
|
| 102 |
+
|
| 103 |
+
def test_string_with_unexpected_symbols(self):
    # Smoke test: punctuation glued to the front of a prompt must not make
    # parse_root raise.
    prompt = "hello :1.0 AND $%^&*()goodbye :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with unexpected symbols.")
|
| 109 |
+
|
| 110 |
+
def test_string_with_unconventional_structure(self):
    # Smoke test: a weight-only AND_PERP segment followed by a bracketed AND
    # segment must not make parse_root raise.
    prompt = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with unconventional structure.")
|
| 116 |
+
|
| 117 |
+
def test_string_with_mixed_alphabets_and_numbers(self):
    # Smoke test: digits attached to the prompt words must not make
    # parse_root raise.
    prompt = "123hello :1.0 AND goodbye456 :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")
|
| 123 |
+
|
| 124 |
+
def test_string_with_nested_brackets(self):
    # Smoke test: several levels of balanced square brackets must not make
    # parse_root raise.
    prompt = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with nested brackets.")
|
| 130 |
+
|
| 131 |
+
def test_unmatched_opening_braces(self):
    # Smoke test: a run of opening brackets that are never closed must not
    # make parse_root raise.
    prompt = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with unmatched opening braces.")
|
| 137 |
+
|
| 138 |
+
def test_unmatched_closing_braces(self):
    # Smoke test: a run of closing brackets with no matching openers must not
    # make parse_root raise.
    prompt = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with unmatched closing braces.")
|
| 144 |
+
|
| 145 |
+
def test_repeating_colons(self):
    # Smoke test: runs of consecutive colons ahead of the real weights must
    # not make parse_root raise.
    prompt = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with repeating colons.")
|
| 151 |
+
|
| 152 |
+
def test_excessive_whitespace(self):
    # Smoke test: long runs of spaces around prompts, weights and the
    # AND_PERP keyword must not make parse_root raise.
    #
    # The input deliberately contains multi-space runs. A single-spaced
    # string would pass trivially and would not exercise the behaviour this
    # test is named for.
    excessive_whitespace = "hello     :1.0     AND_PERP     goodbye     :2.0"
    try:
        self.parser.parse_root(excessive_whitespace)
    except Exception:
        self.fail("parse_root couldn't handle a string with excessive whitespace.")
|
| 158 |
+
|
| 159 |
+
def test_repeating_AND_keyword(self):
    # Smoke test: several AND keywords in a row, with nothing between them,
    # must not make parse_root raise.
    prompt = "hello :1.0 AND AND AND AND AND goodbye :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with repeating AND keyword.")
|
| 165 |
+
|
| 166 |
+
def test_repeating_AND_PERP_keyword(self):
    # Smoke test: several AND_PERP keywords in a row, with nothing between
    # them, must not make parse_root raise.
    prompt = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
    try:
        self.parser.parse_root(prompt)
    except Exception:
        self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")
|
| 172 |
+
|
| 173 |
+
def test_square_weight_prompt(self):
    # Smoke test: a prompt that opens with AND_PERP and a bracketed group,
    # followed by trailing text, must not make parse_root raise.
    text = "AND_PERP [weighted] you thought it was the end"
    try:
        self.parser.parse_root(text)
    except Exception:
        self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|