Upload 37 files
Browse files- neutral_prompt_patched/.gitignore +1 -0
- neutral_prompt_patched/LICENSE +21 -0
- neutral_prompt_patched/README.md +99 -0
- neutral_prompt_patched/lib_neutral_prompt/__init__.py +1 -0
- neutral_prompt_patched/lib_neutral_prompt/affine_transform.py +83 -0
- neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py +684 -0
- neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py +23 -0
- neutral_prompt_patched/lib_neutral_prompt/external_code/api.py +27 -0
- neutral_prompt_patched/lib_neutral_prompt/global_state.py +75 -0
- neutral_prompt_patched/lib_neutral_prompt/hijacker.py +34 -0
- neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py +382 -0
- neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py +134 -0
- neutral_prompt_patched/lib_neutral_prompt/ui.py +196 -0
- neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py +42 -0
- neutral_prompt_patched/scripts/neutral_prompt.py +94 -0
- neutral_prompt_patched/test/perp_parser/__init__.py +0 -0
- neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py +133 -0
- neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py +217 -0
- neutral_prompt_patched/test/perp_parser/test_basic_parser.py +122 -0
- neutral_prompt_patched/test/perp_parser/test_malicious_parser.py +182 -0
- prompt-fusion-extension-main/.gitignore +3 -0
- prompt-fusion-extension-main/LICENSE +21 -0
- prompt-fusion-extension-main/lib_prompt_fusion/ast_nodes.py +307 -0
- prompt-fusion-extension-main/lib_prompt_fusion/empty_cond.py +19 -0
- prompt-fusion-extension-main/lib_prompt_fusion/geometries.py +33 -0
- prompt-fusion-extension-main/lib_prompt_fusion/global_state.py +28 -0
- prompt-fusion-extension-main/lib_prompt_fusion/hijacker.py +34 -0
- prompt-fusion-extension-main/lib_prompt_fusion/interpolation_functions.py +87 -0
- prompt-fusion-extension-main/lib_prompt_fusion/interpolation_tensor.py +249 -0
- prompt-fusion-extension-main/lib_prompt_fusion/prompt_parser.py +378 -0
- prompt-fusion-extension-main/lib_prompt_fusion/t_scaler.py +38 -0
- prompt-fusion-extension-main/metadata.ini +2 -0
- prompt-fusion-extension-main/readme.md +95 -0
- prompt-fusion-extension-main/requirements.txt +1 -0
- prompt-fusion-extension-main/scripts/promptlang.py +355 -0
- prompt-fusion-extension-main/test/parser_tests.py +104 -0
- prompt-fusion-extension-main/test/run_all.py +7 -0
neutral_prompt_patched/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
neutral_prompt_patched/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 ljleb
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
neutral_prompt_patched/README.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# sd-webui-neutral-prompt — Ultimate Edition
|
| 2 |
+
|
| 3 |
+
> **Unified merge of the main experimental branches.**
|
| 4 |
+
> Most features from all branches are merged and work simultaneously.
|
| 5 |
+
> Some `life`/dev-only tooling (debug visualisations, mask images) was intentionally not carried over — only the production-safe subset is included.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## What's inside
|
| 10 |
+
|
| 11 |
+
| Feature | Origin branch | Keyword / setting |
|
| 12 |
+
|---|---|---|
|
| 13 |
+
| Perpendicular projection | `main` | `AND_PERP` |
|
| 14 |
+
| Saliency-guided mask | `main` | `AND_SALT` |
|
| 15 |
+
| Semantic guidance top-k | `main` | `AND_TOPK` |
|
| 16 |
+
| CFG rescale (mean-preserving) | `main` | CFG rescale φ slider |
|
| 17 |
+
| XYZ-grid CFG rescale axis | `main` | XYZ-grid option |
|
| 18 |
+
| External API override | `main` / `export_rescale_factor` | `override_cfg_rescale()` |
|
| 19 |
+
| CFG rescale factor export | `export_rescale_factor` | `get_last_cfg_rescale_factor()` |
|
| 20 |
+
| Affine spatial transforms | `affine` | `ROTATE / SLIDE / SCALE / SHEAR` prefix |
|
| 21 |
+
| Soft alignment blend | `alignment_blend` | `AND_ALIGN_D_S` |
|
| 22 |
+
| Binary alignment mask | `alignment_mask` | `AND_MASK_ALIGN_D_S` |
|
| 23 |
+
| Sharpened saliency maps (k=20) | `life` (production-safe subset) | used automatically by `AND_SALT` |
|
| 24 |
+
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
## Prompt syntax
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
positive prompt text
|
| 31 |
+
AND_PERP negative-direction text :0.8
|
| 32 |
+
AND_SALT competing concept :1.0
|
| 33 |
+
AND_TOPK fine detail tweak :0.5
|
| 34 |
+
AND_ALIGN_4_8 style blend :0.6
|
| 35 |
+
AND_MASK_ALIGN_4_8 structure-preserving style :0.6
|
| 36 |
+
AND_PERP ROTATE[0.125] rotated perpendicular concept :1.0
|
| 37 |
+
ROTATE[0.125] AND_PERP rotated perpendicular concept :1.0
|
| 38 |
+
AND_SALT SLIDE[0.05,0] spatially shifted concept :1.0
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### Affine transform keywords
|
| 42 |
+
Affine keywords are supported in **either order** relative to the conciliation keyword:
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
AND_PERP ROTATE[0.125] vivid colors :0.8 ← affine after keyword
|
| 46 |
+
ROTATE[0.125] AND_PERP vivid colors :0.8 ← affine before keyword (both work)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
| Keyword | Parameter | Effect |
|
| 50 |
+
|---|---|---|
|
| 51 |
+
| `ROTATE[angle]` | angle in turns (0–1) | Rotate latent contribution |
|
| 52 |
+
| `SLIDE[x,y]` | normalised offset | Translate latent contribution |
|
| 53 |
+
| `SCALE[x,y]` | scale factors | Scale latent contribution |
|
| 54 |
+
| `SHEAR[x,y]` | shear in turns | Shear latent contribution |
|
| 55 |
+
|
| 56 |
+
Multiple transforms can be chained: `ROTATE[0.125] SLIDE[0.1,0]`
|
| 57 |
+
|
| 58 |
+
### AND_ALIGN_D_S / AND_MASK_ALIGN_D_S
|
| 59 |
+
- **D** = detail kernel size (small → fine detail)
|
| 60 |
+
- **S** = structure kernel size (large → global composition)
|
| 61 |
+
- The child prompt is blended in proportionally to how much it alters detail *without* changing structure.
|
| 62 |
+
- `AND_ALIGN` uses a soft weight; `AND_MASK_ALIGN` uses a binary 0/1 mask.
|
| 63 |
+
- Supported range: D, S ∈ [2, 32], D ≠ S.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## External API
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
# In another extension or script:
|
| 71 |
+
from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
|
| 72 |
+
|
| 73 |
+
override_cfg_rescale(0.7) # override for next step only
|
| 74 |
+
factor = get_last_cfg_rescale_factor() # read after generation
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## Installation
|
| 80 |
+
|
| 81 |
+
1. Clone / copy this folder into `stable-diffusion-webui/extensions/`.
|
| 82 |
+
2. Restart the webui.
|
| 83 |
+
|
| 84 |
+
---
|
| 85 |
+
|
| 86 |
+
## CFG Rescale φ
|
| 87 |
+
|
| 88 |
+
When set to a value > 0 the extension rescales the CFG output to have the
|
| 89 |
+
same mean and a partially rescaled standard deviation as the raw predicted x0.
|
| 90 |
+
This reduces colour over-saturation at high CFG values without affecting
|
| 91 |
+
prompt adherence as much as simply lowering CFG.
|
| 92 |
+
|
| 93 |
+
The formula used is the **mean-preserving** variant from the `main` branch:
|
| 94 |
+
|
| 95 |
+
```
|
| 96 |
+
rescaled = rescale_mean + (cfg_cond − cfg_cond_mean) × rescale_factor
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
where `rescale_factor = φ × (std(cond)/std(cfg_cond) − 1) + 1`.
|
neutral_prompt_patched/lib_neutral_prompt/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# lib_neutral_prompt package
|
neutral_prompt_patched/lib_neutral_prompt/affine_transform.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Affine spatial transform utilities for latent tensors.
|
| 3 |
+
Extracted from the `affine` branch so that cfg_denoiser_hijack can stay lean.
|
| 4 |
+
|
| 5 |
+
Provides:
|
| 6 |
+
apply_affine_transform() – apply a 2×3 affine grid to a C×H×W tensor
|
| 7 |
+
apply_masked_transform() – apply affine + cosine-feathered mask blending
|
| 8 |
+
create_cosine_feathered_mask() – smooth circular weight mask
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from typing import Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def apply_affine_transform(
|
| 18 |
+
tensor: torch.Tensor,
|
| 19 |
+
affine: torch.Tensor,
|
| 20 |
+
mode: str = 'bilinear',
|
| 21 |
+
) -> torch.Tensor:
|
| 22 |
+
"""
|
| 23 |
+
Apply a 2×3 affine transform to a C×H×W tensor, preserving aspect ratio.
|
| 24 |
+
|
| 25 |
+
:param tensor: Input tensor of shape [C, H, W].
|
| 26 |
+
:param affine: 2×3 float32 affine matrix.
|
| 27 |
+
:param mode: Interpolation mode for grid_sample ('bilinear' default).
|
| 28 |
+
The original affine branch used bilinear for the pre-noise
|
| 29 |
+
inverse path (smoother) and nearest for post-combine.
|
| 30 |
+
Bilinear is the better default for both.
|
| 31 |
+
:return: Transformed tensor of shape [C, H, W].
|
| 32 |
+
"""
|
| 33 |
+
affine = affine.clone().to(tensor.device)
|
| 34 |
+
aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
|
| 35 |
+
affine[0, 1] *= aspect_ratio
|
| 36 |
+
affine[1, 0] /= aspect_ratio
|
| 37 |
+
|
| 38 |
+
grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
|
| 39 |
+
return F.grid_sample(tensor.unsqueeze(0), grid, mode=mode, align_corners=False).squeeze(0)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def apply_masked_transform(
|
| 43 |
+
tensor: torch.Tensor,
|
| 44 |
+
affine: torch.Tensor,
|
| 45 |
+
weight: float,
|
| 46 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 47 |
+
"""
|
| 48 |
+
Apply an affine transform with a cosine-feathered spatial weight mask.
|
| 49 |
+
|
| 50 |
+
The mask channel is appended to the tensor before transformation so that
|
| 51 |
+
the same spatial warp is applied to both content and mask consistently.
|
| 52 |
+
|
| 53 |
+
:param tensor: C×H×W latent tensor.
|
| 54 |
+
:param affine: 2×3 affine matrix.
|
| 55 |
+
:param weight: Global scalar weight for the mask.
|
| 56 |
+
:return: (transformed_content [C, H, W], transformed_mask [H, W])
|
| 57 |
+
"""
|
| 58 |
+
mask = create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)
|
| 59 |
+
tensor_with_mask = torch.cat([tensor, mask], dim=0) # [C+1, H, W]
|
| 60 |
+
transformed = apply_affine_transform(tensor_with_mask, affine)
|
| 61 |
+
return transformed[:-1], transformed[-1] # content, mask
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def create_cosine_feathered_mask(size: Tuple[int, int], weight: float) -> torch.Tensor:
|
| 65 |
+
"""
|
| 66 |
+
Create a circularly-clipped cosine-feathered mask of shape [H, W].
|
| 67 |
+
|
| 68 |
+
Values at the centre approach `weight`; values outside the unit circle
|
| 69 |
+
are exactly 0, with a smooth cosine fall-off in between.
|
| 70 |
+
|
| 71 |
+
:param size: (H, W) tuple.
|
| 72 |
+
:param weight: Peak mask value at the centre.
|
| 73 |
+
:return: Float32 mask tensor of shape [H, W].
|
| 74 |
+
"""
|
| 75 |
+
y, x = torch.meshgrid(
|
| 76 |
+
torch.linspace(-1, 1, size[0]),
|
| 77 |
+
torch.linspace(-1, 1, size[1]),
|
| 78 |
+
indexing='ij',
|
| 79 |
+
)
|
| 80 |
+
dist = torch.sqrt(x ** 2 + y ** 2)
|
| 81 |
+
mask = 0.5 * (1.0 + torch.cos(torch.pi * dist))
|
| 82 |
+
mask[dist > 1] = 0.0
|
| 83 |
+
return mask.float() * weight
|
neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py
ADDED
|
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified CFG Denoiser Hijack
|
| 3 |
+
===========================
|
| 4 |
+
Combines features from all sd-webui-neutral-prompt branches:
|
| 5 |
+
|
| 6 |
+
• Core PERP / SALT / TOPK strategies (main branch)
|
| 7 |
+
• cfg_rescale with mean-preserving formula (main branch – better than multiplicative)
|
| 8 |
+
• cfg_rescale_override (XYZ-grid / API support) (main branch)
|
| 9 |
+
• CFGRescaleFactorSingleton export (export_rescale_factor branch)
|
| 10 |
+
• Affine spatial transforms per-prompt (affine branch)
|
| 11 |
+
• AND_ALIGN_D_S – soft alignment blend (alignment_blend branch)
|
| 12 |
+
• AND_MASK_ALIGN_D_S – binary alignment mask (alignment_mask branch)
|
| 13 |
+
• Improved salience with sharpness parameter k (life branch – production-safe subset)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import dataclasses
|
| 19 |
+
import functools
|
| 20 |
+
import re
|
| 21 |
+
import sys
|
| 22 |
+
import textwrap
|
| 23 |
+
from typing import Dict, List, Tuple
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
from lib_neutral_prompt import affine_transform as affine_mod
|
| 29 |
+
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser
|
| 30 |
+
from modules import script_callbacks, sd_samplers, shared
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Pre-noise affine hook (affine branch)
|
| 35 |
+
#
|
| 36 |
+
# Applied BEFORE each CFG denoising step: the noisy latent slice for every
|
| 37 |
+
# cond prompt is warped by the INVERSE of that prompt's affine transform.
|
| 38 |
+
# This means the UNet sees the latent in the "locally-rotated" frame of each
|
| 39 |
+
# prompt. combine_denoised then applies the FORWARD transform to the
|
| 40 |
+
# resulting cond delta, restoring it to the global frame.
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
@dataclasses.dataclass
|
| 44 |
+
class _PreNoiseArgs:
|
| 45 |
+
x: torch.Tensor
|
| 46 |
+
cond_indices: List[Tuple[int, float]]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _GlobalToLocalAffineVisitor:
|
| 50 |
+
"""Build a {cond_index: inverse_affine} dict for every leaf prompt."""
|
| 51 |
+
|
| 52 |
+
def visit_leaf_prompt(
|
| 53 |
+
self,
|
| 54 |
+
that: neutral_prompt_parser.LeafPrompt,
|
| 55 |
+
args: _PreNoiseArgs,
|
| 56 |
+
index: int,
|
| 57 |
+
) -> Dict[int, torch.Tensor]:
|
| 58 |
+
cond_index = args.cond_indices[index][0]
|
| 59 |
+
if that.local_transform is not None:
|
| 60 |
+
# 2×3 → 3×3 → invert → back to 2×3
|
| 61 |
+
m3 = torch.vstack([that.local_transform,
|
| 62 |
+
torch.tensor([0.0, 0.0, 1.0])])
|
| 63 |
+
inv = torch.linalg.inv(m3)[:-1]
|
| 64 |
+
else:
|
| 65 |
+
inv = torch.eye(3)[:-1]
|
| 66 |
+
return {cond_index: inv}
|
| 67 |
+
|
| 68 |
+
def visit_composite_prompt(
|
| 69 |
+
self,
|
| 70 |
+
that: neutral_prompt_parser.CompositePrompt,
|
| 71 |
+
args: _PreNoiseArgs,
|
| 72 |
+
index: int,
|
| 73 |
+
) -> Dict[int, torch.Tensor]:
|
| 74 |
+
inv_transforms: Dict[int, torch.Tensor] = {}
|
| 75 |
+
|
| 76 |
+
for child in that.children:
|
| 77 |
+
inv_transforms.update(child.accept(_GlobalToLocalAffineVisitor(), args, index))
|
| 78 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 79 |
+
|
| 80 |
+
# Compose parent transform on top of all children
|
| 81 |
+
if that.local_transform is not None:
|
| 82 |
+
m3 = torch.vstack([that.local_transform,
|
| 83 |
+
torch.tensor([0.0, 0.0, 1.0])])
|
| 84 |
+
parent_inv = torch.linalg.inv(m3)
|
| 85 |
+
for inv in inv_transforms.values():
|
| 86 |
+
# inv is 2×3; extend to 3×3, multiply, trim back
|
| 87 |
+
inv3 = torch.vstack([inv, torch.tensor([0.0, 0.0, 1.0])])
|
| 88 |
+
inv[:] = (parent_inv @ inv3)[:-1]
|
| 89 |
+
|
| 90 |
+
return inv_transforms
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams) -> None:
|
| 94 |
+
"""Pre-noise hook: warp each cond latent slice by the inverse affine."""
|
| 95 |
+
if not global_state.is_enabled:
|
| 96 |
+
return
|
| 97 |
+
|
| 98 |
+
if not _batch_is_compatible(global_state.prompt_exprs, global_state.batch_cond_indices):
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
for prompt, cond_indices in zip(global_state.prompt_exprs,
|
| 102 |
+
global_state.batch_cond_indices):
|
| 103 |
+
args = _PreNoiseArgs(params.x, cond_indices)
|
| 104 |
+
inv_transforms = prompt.accept(_GlobalToLocalAffineVisitor(), args, 0)
|
| 105 |
+
for cond_index, _ in cond_indices:
|
| 106 |
+
params.x[cond_index] = affine_mod.apply_affine_transform(
|
| 107 |
+
params.x[cond_index], inv_transforms[cond_index]
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _flat_size(prompt: neutral_prompt_parser.PromptExpr) -> int:
|
| 116 |
+
return prompt.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _batch_is_compatible(
|
| 120 |
+
prompts: List[neutral_prompt_parser.PromptExpr],
|
| 121 |
+
batch_cond_indices: List[List[Tuple[int, float]]],
|
| 122 |
+
) -> bool:
|
| 123 |
+
if len(prompts) != len(batch_cond_indices):
|
| 124 |
+
_console_warn(f'''
|
| 125 |
+
Neutral Prompt batch mismatch:
|
| 126 |
+
prompt_exprs={len(prompts)} vs batch_cond_indices={len(batch_cond_indices)}
|
| 127 |
+
Falling back to original A1111 behavior for this step.
|
| 128 |
+
''')
|
| 129 |
+
return False
|
| 130 |
+
|
| 131 |
+
for i, (prompt, cond_indices) in enumerate(zip(prompts, batch_cond_indices)):
|
| 132 |
+
need = _flat_size(prompt)
|
| 133 |
+
got = len(cond_indices)
|
| 134 |
+
if need != got:
|
| 135 |
+
_console_warn(f'''
|
| 136 |
+
Neutral Prompt branch mismatch at prompt #{i}:
|
| 137 |
+
expected {need} branches, got {got}
|
| 138 |
+
This usually means another extension replaced prompt parsing hooks.
|
| 139 |
+
Falling back to original A1111 behavior for this step.
|
| 140 |
+
''')
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
return True
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ---------------------------------------------------------------------------
|
| 147 |
+
# Public entry point
|
| 148 |
+
# ---------------------------------------------------------------------------
|
| 149 |
+
|
| 150 |
+
def combine_denoised_hijack(
|
| 151 |
+
x_out: torch.Tensor,
|
| 152 |
+
batch_cond_indices: List[List[Tuple[int, float]]],
|
| 153 |
+
text_uncond: torch.Tensor,
|
| 154 |
+
cond_scale: float,
|
| 155 |
+
original_function,
|
| 156 |
+
) -> torch.Tensor:
|
| 157 |
+
if not global_state.is_enabled:
|
| 158 |
+
return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
|
| 159 |
+
|
| 160 |
+
if not _batch_is_compatible(global_state.prompt_exprs, batch_cond_indices):
|
| 161 |
+
return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
|
| 162 |
+
|
| 163 |
+
denoised = _get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
|
| 164 |
+
uncond = x_out[-text_uncond.shape[0]:]
|
| 165 |
+
|
| 166 |
+
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
|
| 167 |
+
args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
| 168 |
+
cond_delta = prompt.accept(_CondDeltaVisitor(), args, 0)
|
| 169 |
+
aux_cond_delta = prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, 0)
|
| 170 |
+
|
| 171 |
+
# Apply per-prompt affine transform to both deltas (affine branch)
|
| 172 |
+
if prompt.local_transform is not None:
|
| 173 |
+
cond_delta = affine_mod.apply_affine_transform(cond_delta, prompt.local_transform)
|
| 174 |
+
aux_cond_delta = affine_mod.apply_affine_transform(aux_cond_delta, prompt.local_transform)
|
| 175 |
+
|
| 176 |
+
cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
|
| 177 |
+
denoised[batch_i] = _cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
|
| 178 |
+
|
| 179 |
+
return denoised
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
# Internal helpers
|
| 184 |
+
# ---------------------------------------------------------------------------
|
| 185 |
+
|
| 186 |
+
def _get_webui_denoised(
|
| 187 |
+
x_out: torch.Tensor,
|
| 188 |
+
batch_cond_indices: List[List[Tuple[int, float]]],
|
| 189 |
+
text_uncond: torch.Tensor,
|
| 190 |
+
cond_scale: float,
|
| 191 |
+
original_function,
|
| 192 |
+
) -> torch.Tensor:
|
| 193 |
+
uncond = x_out[-text_uncond.shape[0]:]
|
| 194 |
+
sliced_batch_x_out: List[torch.Tensor] = []
|
| 195 |
+
sliced_batch_cond_indices: List[List[Tuple[int, float]]] = []
|
| 196 |
+
|
| 197 |
+
for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
|
| 198 |
+
args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
|
| 199 |
+
sliced_x_out, sliced_indices = _gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
|
| 200 |
+
if sliced_indices:
|
| 201 |
+
sliced_batch_cond_indices.append(sliced_indices)
|
| 202 |
+
sliced_batch_x_out.extend(sliced_x_out)
|
| 203 |
+
|
| 204 |
+
sliced_batch_x_out += list(uncond)
|
| 205 |
+
return original_function(
|
| 206 |
+
torch.stack(sliced_batch_x_out, dim=0),
|
| 207 |
+
sliced_batch_cond_indices,
|
| 208 |
+
text_uncond,
|
| 209 |
+
cond_scale,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 214 |
+
"""
|
| 215 |
+
Mean-preserving CFG rescale (main branch formula).
|
| 216 |
+
Override is applied first so XYZ-grid / external API overrides are never silently skipped.
|
| 217 |
+
Also stores the computed rescale factor in CFGRescaleFactorSingleton.
|
| 218 |
+
"""
|
| 219 |
+
# Clear last step's factor so get() returns None if rescaling is skipped
|
| 220 |
+
# this step (stale value bug fix).
|
| 221 |
+
global_state.CFGRescaleFactorSingleton.clear()
|
| 222 |
+
|
| 223 |
+
# Apply one-shot override BEFORE the early-exit check, otherwise a 0→nonzero
|
| 224 |
+
# override (from the external API) would be silently discarded.
|
| 225 |
+
global_state.apply_and_clear_cfg_rescale_override()
|
| 226 |
+
|
| 227 |
+
if global_state.cfg_rescale == 0:
|
| 228 |
+
return cfg_cond
|
| 229 |
+
|
| 230 |
+
cfg_std = cfg_cond.std()
|
| 231 |
+
if cfg_std == 0:
|
| 232 |
+
# Degenerate case: constant tensor – rescaling is a no-op.
|
| 233 |
+
return cfg_cond
|
| 234 |
+
|
| 235 |
+
cfg_cond_mean = cfg_cond.mean()
|
| 236 |
+
rescale_mean = (
|
| 237 |
+
(1 - global_state.cfg_rescale) * cfg_cond_mean
|
| 238 |
+
+ global_state.cfg_rescale * cond.mean()
|
| 239 |
+
)
|
| 240 |
+
rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_std - 1) + 1
|
| 241 |
+
|
| 242 |
+
# Export for external consumers (export_rescale_factor branch)
|
| 243 |
+
global_state.CFGRescaleFactorSingleton.set(
|
| 244 |
+
rescale_factor.item() if isinstance(rescale_factor, torch.Tensor) else float(rescale_factor)
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
return rescale_mean + (cfg_cond - cfg_cond_mean) * rescale_factor
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
@dataclasses.dataclass
|
| 251 |
+
class _DenoiseArgs:
|
| 252 |
+
x_out: torch.Tensor
|
| 253 |
+
uncond: torch.Tensor
|
| 254 |
+
cond_indices: List[Tuple[int, float]]
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ---------------------------------------------------------------------------
|
| 258 |
+
# Gather webui-style conditions (needed for the webui's own CFG path)
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
|
| 261 |
+
def _gather_webui_conds(
|
| 262 |
+
prompt: neutral_prompt_parser.CompositePrompt,
|
| 263 |
+
args: _DenoiseArgs,
|
| 264 |
+
index_in: int,
|
| 265 |
+
index_out: int,
|
| 266 |
+
) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
|
| 267 |
+
sliced_x_out: List[torch.Tensor] = []
|
| 268 |
+
sliced_cond_indices: List[Tuple[int, float]] = []
|
| 269 |
+
|
| 270 |
+
for child in prompt.children:
|
| 271 |
+
if child.conciliation is None:
|
| 272 |
+
if isinstance(child, neutral_prompt_parser.LeafPrompt) and child.local_transform is None:
|
| 273 |
+
child_x_out = args.x_out[args.cond_indices[index_in][0]]
|
| 274 |
+
child_weight = child.weight
|
| 275 |
+
else:
|
| 276 |
+
child_x_out, child_weight = _get_cond_delta_and_weight(child, args, index_in)
|
| 277 |
+
child_x_out = child_x_out + args.uncond
|
| 278 |
+
|
| 279 |
+
index_offset = index_out + len(sliced_x_out)
|
| 280 |
+
sliced_x_out.append(child_x_out)
|
| 281 |
+
sliced_cond_indices.append((index_offset, child_weight))
|
| 282 |
+
|
| 283 |
+
index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 284 |
+
|
| 285 |
+
return sliced_x_out, sliced_cond_indices
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _get_cond_delta_and_weight(
|
| 289 |
+
prompt: neutral_prompt_parser.PromptExpr,
|
| 290 |
+
args: _DenoiseArgs,
|
| 291 |
+
index: int,
|
| 292 |
+
) -> Tuple[torch.Tensor, float]:
|
| 293 |
+
"""Compute cond delta and effective weight, applying affine transform if present."""
|
| 294 |
+
cond_delta = prompt.accept(_CondDeltaVisitor(), args, index)
|
| 295 |
+
cond_delta = cond_delta + prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, index)
|
| 296 |
+
weight = prompt.weight
|
| 297 |
+
|
| 298 |
+
if prompt.local_transform is not None:
|
| 299 |
+
transformed, weight_tensor = affine_mod.apply_masked_transform(
|
| 300 |
+
cond_delta + args.uncond,
|
| 301 |
+
prompt.local_transform,
|
| 302 |
+
prompt.weight,
|
| 303 |
+
)
|
| 304 |
+
cond_delta = transformed - args.uncond
|
| 305 |
+
weight = weight_tensor # type: ignore[assignment]
|
| 306 |
+
|
| 307 |
+
return cond_delta, weight
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
# Visitor: CondDelta – weighted sum of leaf cond − uncond
|
| 312 |
+
# ---------------------------------------------------------------------------
|
| 313 |
+
|
| 314 |
+
class _CondDeltaVisitor:
|
| 315 |
+
def visit_leaf_prompt(
|
| 316 |
+
self,
|
| 317 |
+
that: neutral_prompt_parser.LeafPrompt,
|
| 318 |
+
args: _DenoiseArgs,
|
| 319 |
+
index: int,
|
| 320 |
+
) -> torch.Tensor:
|
| 321 |
+
cond_info = args.cond_indices[index]
|
| 322 |
+
if that.weight != cond_info[1]:
|
| 323 |
+
_console_warn(f'''
|
| 324 |
+
An unexpected noise weight was encountered at prompt #{index}
|
| 325 |
+
Expected :{that.weight}, but got :{cond_info[1]}
|
| 326 |
+
This is likely due to another extension also monkey patching `combine_denoised`.
|
| 327 |
+
Please open a bug report: https://github.com/ljleb/sd-webui-neutral-prompt/issues
|
| 328 |
+
''')
|
| 329 |
+
return args.x_out[cond_info[0]] - args.uncond
|
| 330 |
+
|
| 331 |
+
def visit_composite_prompt(
|
| 332 |
+
self,
|
| 333 |
+
that: neutral_prompt_parser.CompositePrompt,
|
| 334 |
+
args: _DenoiseArgs,
|
| 335 |
+
index: int,
|
| 336 |
+
) -> torch.Tensor:
|
| 337 |
+
cond_delta = torch.zeros_like(args.x_out[0])
|
| 338 |
+
for child in that.children:
|
| 339 |
+
if child.conciliation is None:
|
| 340 |
+
child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
|
| 341 |
+
cond_delta = cond_delta + child_weight * child_delta
|
| 342 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 343 |
+
return cond_delta
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ---------------------------------------------------------------------------
|
| 347 |
+
# Visitor: AuxCondDelta – all conciliation strategies
|
| 348 |
+
# ---------------------------------------------------------------------------
|
| 349 |
+
|
| 350 |
+
class _AuxCondDeltaVisitor:
|
| 351 |
+
def visit_leaf_prompt(
|
| 352 |
+
self,
|
| 353 |
+
that: neutral_prompt_parser.LeafPrompt,
|
| 354 |
+
args: _DenoiseArgs,
|
| 355 |
+
cond_delta: torch.Tensor,
|
| 356 |
+
index: int,
|
| 357 |
+
) -> torch.Tensor:
|
| 358 |
+
return torch.zeros_like(args.x_out[0])
|
| 359 |
+
|
| 360 |
+
def visit_composite_prompt(
|
| 361 |
+
self,
|
| 362 |
+
that: neutral_prompt_parser.CompositePrompt,
|
| 363 |
+
args: _DenoiseArgs,
|
| 364 |
+
cond_delta: torch.Tensor,
|
| 365 |
+
index: int,
|
| 366 |
+
) -> torch.Tensor:
|
| 367 |
+
aux_cond_delta = torch.zeros_like(args.x_out[0])
|
| 368 |
+
salient_cond_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT k=20
|
| 369 |
+
salient_wide_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT_WIDE k=1
|
| 370 |
+
salient_blob_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT_BLOB k=20 + morphology
|
| 371 |
+
align_blend_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
|
| 372 |
+
mask_align_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
|
| 373 |
+
|
| 374 |
+
for child in that.children:
|
| 375 |
+
if child.conciliation is not None:
|
| 376 |
+
child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
|
| 377 |
+
strat = child.conciliation
|
| 378 |
+
|
| 379 |
+
if strat == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
|
| 380 |
+
aux_cond_delta = aux_cond_delta + child_weight * _get_perpendicular_component(cond_delta, child_delta)
|
| 381 |
+
|
| 382 |
+
elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
|
| 383 |
+
salient_cond_deltas.append((child_delta, child_weight))
|
| 384 |
+
|
| 385 |
+
elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK_WIDE:
|
| 386 |
+
salient_wide_deltas.append((child_delta, child_weight))
|
| 387 |
+
|
| 388 |
+
elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK_BLOB:
|
| 389 |
+
salient_blob_deltas.append((child_delta, child_weight))
|
| 390 |
+
|
| 391 |
+
elif strat == neutral_prompt_parser.ConciliationStrategy.SEMANTIC_GUIDANCE:
|
| 392 |
+
aux_cond_delta = aux_cond_delta + child_weight * _filter_abs_top_k(child_delta, 0.05)
|
| 393 |
+
|
| 394 |
+
else:
|
| 395 |
+
# AND_ALIGN_D_S (soft alignment blend)
|
| 396 |
+
m = re.match(r'AND_ALIGN_(\d+)_(\d+)', strat.value)
|
| 397 |
+
if m:
|
| 398 |
+
align_blend_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
|
| 399 |
+
else:
|
| 400 |
+
# AND_MASK_ALIGN_D_S (binary alignment mask)
|
| 401 |
+
m = re.match(r'AND_MASK_ALIGN_(\d+)_(\d+)', strat.value)
|
| 402 |
+
if m:
|
| 403 |
+
mask_align_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
|
| 404 |
+
|
| 405 |
+
index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
|
| 406 |
+
|
| 407 |
+
aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_cond_deltas, child_k=20)
|
| 408 |
+
aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_wide_deltas, child_k=1)
|
| 409 |
+
aux_cond_delta = aux_cond_delta + _salient_blend_blob(cond_delta, salient_blob_deltas)
|
| 410 |
+
aux_cond_delta = aux_cond_delta + _alignment_blend(cond_delta, align_blend_deltas)
|
| 411 |
+
aux_cond_delta = aux_cond_delta + _alignment_mask_blend(cond_delta, mask_align_deltas)
|
| 412 |
+
return aux_cond_delta
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# ---------------------------------------------------------------------------
|
| 416 |
+
# Strategy implementations
|
| 417 |
+
# ---------------------------------------------------------------------------
|
| 418 |
+
|
| 419 |
+
def _get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
|
| 420 |
+
if (normal == 0).all():
|
| 421 |
+
if shared.state.sampling_step <= 0:
|
| 422 |
+
_warn_projection_not_found()
|
| 423 |
+
return vector
|
| 424 |
+
return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def _salient_blend(
|
| 428 |
+
normal: torch.Tensor,
|
| 429 |
+
vectors: List[Tuple[torch.Tensor, float]],
|
| 430 |
+
child_k: float = 20.0,
|
| 431 |
+
) -> torch.Tensor:
|
| 432 |
+
"""
|
| 433 |
+
Saliency-guided blend: each child prompt wins in the latent regions
|
| 434 |
+
where its absolute activation magnitude is strongest.
|
| 435 |
+
|
| 436 |
+
child_k controls how sharp/selective the child salience mask is:
|
| 437 |
+
child_k=20 → very sharp, 1-2 peak pixels (AND_SALT — life-branch style)
|
| 438 |
+
child_k=1 → broad, ~55% of pixels (AND_SALT_WIDE — original main-branch)
|
| 439 |
+
Parent always uses k=1 (diffuse reference).
|
| 440 |
+
"""
|
| 441 |
+
if not vectors:
|
| 442 |
+
return torch.zeros_like(normal)
|
| 443 |
+
|
| 444 |
+
salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=child_k) for v, _ in vectors]
|
| 445 |
+
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
| 446 |
+
|
| 447 |
+
result = torch.zeros_like(normal)
|
| 448 |
+
for mask_i, (vector, weight) in enumerate(vectors, start=1):
|
| 449 |
+
vector_mask = (mask == mask_i).float()
|
| 450 |
+
result = result + weight * vector_mask * (vector - normal)
|
| 451 |
+
return result
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def _salient_blend_blob(
|
| 455 |
+
normal: torch.Tensor,
|
| 456 |
+
vectors: List[Tuple[torch.Tensor, float]],
|
| 457 |
+
) -> torch.Tensor:
|
| 458 |
+
"""
|
| 459 |
+
AND_SALT_BLOB: life-branch dev.py algorithm (cleaned of debug scaffolding).
|
| 460 |
+
|
| 461 |
+
Pipeline per child:
|
| 462 |
+
1. k=20 softmax → very sharp salience seed (1-2 pixels)
|
| 463 |
+
2. _life_erode ×6 → erode to spatially dense core
|
| 464 |
+
3. _life_thickify ×2 → grow outward into a smooth blob
|
| 465 |
+
4. result += weight × blob_mask × (child − parent)
|
| 466 |
+
|
| 467 |
+
This is what the life-branch author was iterating toward in dev.py.
|
| 468 |
+
"""
|
| 469 |
+
if not vectors:
|
| 470 |
+
return torch.zeros_like(normal)
|
| 471 |
+
|
| 472 |
+
salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=20) for v, _ in vectors]
|
| 473 |
+
mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
|
| 474 |
+
|
| 475 |
+
result = torch.zeros_like(normal)
|
| 476 |
+
for mask_i, (vector, weight) in enumerate(vectors, start=1):
|
| 477 |
+
vector_mask = (mask == mask_i).float()
|
| 478 |
+
for _ in range(6):
|
| 479 |
+
vector_mask = _life_step(vector_mask, _erode_rule)
|
| 480 |
+
for _ in range(2):
|
| 481 |
+
vector_mask = _life_step(vector_mask, _thickify_rule)
|
| 482 |
+
result = result + weight * vector_mask * (vector - normal)
|
| 483 |
+
return result
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def _life_step(board: torch.Tensor, rule) -> torch.Tensor:
|
| 487 |
+
"""
|
| 488 |
+
One step of a cellular-automaton morphology on a C×H×W binary mask.
|
| 489 |
+
|
| 490 |
+
The conv3d trick from dev.py: concatenate [board, board[:-1]] along the
|
| 491 |
+
channel axis so a single 3-D convolution simultaneously sums the 3×3
|
| 492 |
+
spatial neighbourhood across *all* C channels. This keeps the operation
|
| 493 |
+
GPU-friendly and avoids an explicit loop over channels.
|
| 494 |
+
|
| 495 |
+
neighbors[c,h,w] = Σ_{dc,dh,dw} board_padded[c+dc, h+dh, w+dw]
|
| 496 |
+
minus the center pixel itself (board is subtracted).
|
| 497 |
+
"""
|
| 498 |
+
C = board.shape[0]
|
| 499 |
+
kernel = torch.ones((C, 3, 3), dtype=board.dtype, device=board.device)
|
| 500 |
+
kernel = kernel.unsqueeze(0).unsqueeze(0) # [1, 1, C, 3, 3]
|
| 501 |
+
|
| 502 |
+
# Pad spatially (left/right/top/bottom) but not along channel axis
|
| 503 |
+
padded = torch.cat([board.clone(), board[:-1].clone()], dim=0) # [2C-1, H, W]
|
| 504 |
+
padded = torch.nn.functional.pad(padded, (1, 1, 1, 1, 0, 0), value=0) # [2C-1, H+2, W+2]
|
| 505 |
+
|
| 506 |
+
neighbors = torch.nn.functional.conv3d(
|
| 507 |
+
padded.unsqueeze(0).unsqueeze(0), # [1, 1, 2C-1, H+2, W+2]
|
| 508 |
+
kernel, # [1, 1, C, 3, 3]
|
| 509 |
+
padding=0,
|
| 510 |
+
).squeeze(0).squeeze(0) # [C, H, W]
|
| 511 |
+
|
| 512 |
+
neighbors = neighbors - board # subtract center pixel
|
| 513 |
+
return rule(board, neighbors).float()
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def _erode_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
|
| 517 |
+
"""Keep a pixel only if it is set AND its C-channel neighbourhood is dense.
|
| 518 |
+
Threshold C*5 matches dev.py: for C=4 that means ≥20 out of 36 possible."""
|
| 519 |
+
C = board.shape[0]
|
| 520 |
+
return (board == 1) & (neighbors >= C * 5)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def _thickify_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
|
| 524 |
+
"""Grow: keep existing pixels OR add any pixel adjacent to the core.
|
| 525 |
+
population ≥ 4 ensures only pixels touching at least one set neighbour grow."""
|
| 526 |
+
population = board + neighbors
|
| 527 |
+
return (board == 1) | (population >= 4)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def _get_salience(vector: torch.Tensor, k: float = 1.0) -> torch.Tensor:
|
| 531 |
+
"""Softmax-based salience map. k > 1 → sharper, more selective mask."""
|
| 532 |
+
return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def _filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
|
| 536 |
+
"""Keep only the top k_ratio fraction of activations by absolute value."""
|
| 537 |
+
k = int(torch.numel(vector) * (1 - k_ratio))
|
| 538 |
+
k = max(1, k) # kthvalue requires k >= 1
|
| 539 |
+
threshold, _ = torch.kthvalue(torch.abs(vector.flatten()), k)
|
| 540 |
+
return vector * (torch.abs(vector) >= threshold).to(vector.dtype)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
# ---------------------------------------------------------------------------
|
| 544 |
+
# Alignment blend (alignment_blend branch)
|
| 545 |
+
# ---------------------------------------------------------------------------
|
| 546 |
+
|
| 547 |
+
def _compute_subregion_similarity_map(
|
| 548 |
+
child_vector: torch.Tensor,
|
| 549 |
+
parent_vector: torch.Tensor,
|
| 550 |
+
region_size: int = 2,
|
| 551 |
+
) -> torch.Tensor:
|
| 552 |
+
"""
|
| 553 |
+
Compute local average cosine similarity of (region_size × region_size) regions.
|
| 554 |
+
Returns a map of shape [C, H, W] with values in [-1, 1].
|
| 555 |
+
"""
|
| 556 |
+
C, H, W = child_vector.shape
|
| 557 |
+
parent = parent_vector.unsqueeze(0) # [1, C, H, W]
|
| 558 |
+
child = child_vector.unsqueeze(0)
|
| 559 |
+
|
| 560 |
+
region_radius = region_size // 2
|
| 561 |
+
if region_size % 2 == 1:
|
| 562 |
+
pad_size = (region_radius,) * 4
|
| 563 |
+
else:
|
| 564 |
+
pad_size = (region_radius - 1, region_radius) * 2
|
| 565 |
+
|
| 566 |
+
parent_reg = F.unfold(F.pad(parent, pad_size, 'constant', 0), kernel_size=region_size)
|
| 567 |
+
child_reg = F.unfold(F.pad(child, pad_size, 'constant', 0), kernel_size=region_size)
|
| 568 |
+
|
| 569 |
+
# [H*W, C, region_size, region_size]
|
| 570 |
+
# .contiguous() is required before .view() because .permute() produces a
|
| 571 |
+
# non-contiguous tensor and .view() raises RuntimeError on non-contiguous input.
|
| 572 |
+
parent_reg = parent_reg.view(1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
|
| 573 |
+
child_reg = child_reg.view( 1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
|
| 574 |
+
|
| 575 |
+
unfold2 = torch.nn.Unfold(kernel_size=2)
|
| 576 |
+
parent_sub = unfold2(parent_reg).view(H*W, C, 4, (region_size - 1)**2)
|
| 577 |
+
child_sub = unfold2(child_reg ).view(H*W, C, 4, (region_size - 1)**2)
|
| 578 |
+
|
| 579 |
+
parent_sub = F.normalize(parent_sub, p=2, dim=2)
|
| 580 |
+
child_sub = F.normalize(child_sub, p=2, dim=2)
|
| 581 |
+
sim = (parent_sub * child_sub).sum(dim=2) # [H*W, C, (r-1)^2]
|
| 582 |
+
return sim.mean(dim=2).permute(1, 0).contiguous().view(C, H, W) # [C, H, W]
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def _alignment_blend(
|
| 586 |
+
parent: torch.Tensor,
|
| 587 |
+
children: List[Tuple[torch.Tensor, float, int, int]],
|
| 588 |
+
) -> torch.Tensor:
|
| 589 |
+
"""
|
| 590 |
+
Soft alignment blend (AND_ALIGN_D_S).
|
| 591 |
+
Child contribution is weighted by max(0, structure_alignment − detail_alignment).
|
| 592 |
+
High weight where child changes detail without breaking structure.
|
| 593 |
+
"""
|
| 594 |
+
result = torch.zeros_like(parent)
|
| 595 |
+
for child, weight, detail_size, structure_size in children:
|
| 596 |
+
detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
|
| 597 |
+
structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
|
| 598 |
+
|
| 599 |
+
# Normalise by absolute max so that all-negative maps (anti-correlated prompts
|
| 600 |
+
# such as 'black' vs 'white') don't blow up to ±1e7 and clamp incorrectly to 1.
|
| 601 |
+
# Dividing by positive max would invert the sign ordering when all values are
|
| 602 |
+
# negative; dividing by abs-max preserves relative ordering in both cases.
|
| 603 |
+
d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
|
| 604 |
+
s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
|
| 605 |
+
detail_sim = detail_sim / d_abs_max
|
| 606 |
+
structure_sim = structure_sim / s_abs_max
|
| 607 |
+
|
| 608 |
+
alignment_weight = torch.clamp(structure_sim - detail_sim, min=0.0, max=1.0)
|
| 609 |
+
result = result + (child - parent) * weight * alignment_weight
|
| 610 |
+
return result
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def _alignment_mask_blend(
|
| 614 |
+
parent: torch.Tensor,
|
| 615 |
+
children: List[Tuple[torch.Tensor, float, int, int]],
|
| 616 |
+
) -> torch.Tensor:
|
| 617 |
+
"""
|
| 618 |
+
Binary alignment mask (AND_MASK_ALIGN_D_S).
|
| 619 |
+
Child receives full weight where structure_alignment > detail_alignment, zero elsewhere.
|
| 620 |
+
"""
|
| 621 |
+
result = torch.zeros_like(parent)
|
| 622 |
+
for child, weight, detail_size, structure_size in children:
|
| 623 |
+
detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
|
| 624 |
+
structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
|
| 625 |
+
|
| 626 |
+
d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
|
| 627 |
+
s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
|
| 628 |
+
detail_sim = detail_sim / d_abs_max
|
| 629 |
+
structure_sim = structure_sim / s_abs_max
|
| 630 |
+
|
| 631 |
+
alignment_mask = (structure_sim > detail_sim).to(child)
|
| 632 |
+
result = result + (child - parent) * weight * alignment_mask
|
| 633 |
+
return result
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
# ---------------------------------------------------------------------------
|
| 637 |
+
# Sampler hijack
|
| 638 |
+
# ---------------------------------------------------------------------------
|
| 639 |
+
|
| 640 |
+
sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
|
| 641 |
+
module=sd_samplers,
|
| 642 |
+
hijacker_attribute='__neutral_prompt_hijacker',
|
| 643 |
+
on_uninstall=script_callbacks.on_script_unloaded,
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
@sd_samplers_hijacker.hijack('create_sampler')
|
| 648 |
+
def create_sampler_hijack(name: str, model, original_function):
|
| 649 |
+
sampler = original_function(name, model)
|
| 650 |
+
if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
|
| 651 |
+
if global_state.is_enabled:
|
| 652 |
+
_warn_unsupported_sampler()
|
| 653 |
+
return sampler
|
| 654 |
+
|
| 655 |
+
sampler.model_wrap_cfg.combine_denoised = functools.partial(
|
| 656 |
+
combine_denoised_hijack,
|
| 657 |
+
original_function=sampler.model_wrap_cfg.combine_denoised,
|
| 658 |
+
)
|
| 659 |
+
return sampler
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
# ---------------------------------------------------------------------------
|
| 663 |
+
# Warnings / logging
|
| 664 |
+
# ---------------------------------------------------------------------------
|
| 665 |
+
|
| 666 |
+
def _warn_unsupported_sampler() -> None:
|
| 667 |
+
_console_warn('''
|
| 668 |
+
Neutral prompt relies on composition via AND, which the webui does not support
|
| 669 |
+
when using any of the DDIM, PLMS and UniPC samplers.
|
| 670 |
+
The sampler will NOT be patched – falling back on the original implementation.
|
| 671 |
+
''')
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def _warn_projection_not_found() -> None:
|
| 675 |
+
_console_warn('''
|
| 676 |
+
Could not find a projection for one or more AND_PERP prompts.
|
| 677 |
+
These prompts will NOT be made perpendicular.
|
| 678 |
+
''')
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def _console_warn(message: str) -> None:
|
| 682 |
+
if not global_state.verbose:
|
| 683 |
+
return
|
| 684 |
+
print(f'\n[sd-webui-neutral-prompt]{textwrap.dedent(message)}', file=sys.stderr)
|
neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@contextlib.contextmanager
|
| 5 |
+
def fix_path():
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
extension_path = str(Path(__file__).parent.parent.parent)
|
| 10 |
+
added = False
|
| 11 |
+
if extension_path not in sys.path:
|
| 12 |
+
sys.path.insert(0, extension_path)
|
| 13 |
+
added = True
|
| 14 |
+
|
| 15 |
+
yield
|
| 16 |
+
|
| 17 |
+
if added:
|
| 18 |
+
sys.path.remove(extension_path)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
with fix_path():
|
| 22 |
+
del fix_path, contextlib
|
| 23 |
+
from .api import *
|
neutral_prompt_patched/lib_neutral_prompt/external_code/api.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
External API for sd-webui-neutral-prompt.
|
| 3 |
+
|
| 4 |
+
Provides thin helpers that external extensions / scripts can import via:
|
| 5 |
+
|
| 6 |
+
from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
from lib_neutral_prompt import global_state
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def override_cfg_rescale(cfg_rescale: float) -> None:
|
| 15 |
+
"""
|
| 16 |
+
Override the CFG rescale value for the *next* generation step only.
|
| 17 |
+
After the step runs the value is automatically cleared.
|
| 18 |
+
"""
|
| 19 |
+
global_state.cfg_rescale_override = cfg_rescale
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_last_cfg_rescale_factor() -> Optional[float]:
|
| 23 |
+
"""
|
| 24 |
+
Return the CFG rescale factor computed during the most recent denoising
|
| 25 |
+
step, or None if rescaling was not active.
|
| 26 |
+
"""
|
| 27 |
+
return global_state.CFGRescaleFactorSingleton.get()
|
neutral_prompt_patched/lib_neutral_prompt/global_state.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Global mutable state shared across the extension.
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- is_enabled / prompt_exprs / verbose (all branches)
|
| 6 |
+
- cfg_rescale + cfg_rescale_override mechanism (main branch)
|
| 7 |
+
- batch_cond_indices (affine branch – needed by on_cfg_denoiser pre-noise hook)
|
| 8 |
+
- CFGRescaleFactorSingleton – exposes the last computed rescale factor
|
| 9 |
+
to external callers / API users (export_rescale_factor branch)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import threading
|
| 13 |
+
from typing import List, Optional, Tuple
|
| 14 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Runtime state
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
is_enabled: bool = False
|
| 22 |
+
prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []
|
| 23 |
+
# Populated by reconstruct_multicond_batch hook; used by the pre-noise
|
| 24 |
+
# affine transform hook (on_cfg_denoiser) to know which latent slices
|
| 25 |
+
# correspond to which prompt child.
|
| 26 |
+
batch_cond_indices: List[List[Tuple[int, float]]] = []
|
| 27 |
+
cfg_rescale: float = 0.0
|
| 28 |
+
verbose: bool = False # matches UI default; UI syncs this on settings load
|
| 29 |
+
|
| 30 |
+
# Set to a float value by XYZ-grid or external API to override cfg_rescale
|
| 31 |
+
# for a single generation step, then auto-cleared.
|
| 32 |
+
cfg_rescale_override: Optional[float] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def apply_and_clear_cfg_rescale_override() -> None:
|
| 36 |
+
"""Apply a one-shot cfg_rescale override and immediately clear it."""
|
| 37 |
+
global cfg_rescale, cfg_rescale_override
|
| 38 |
+
if cfg_rescale_override is not None:
|
| 39 |
+
cfg_rescale = cfg_rescale_override
|
| 40 |
+
cfg_rescale_override = None
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# CFG Rescale Factor Singleton (export_rescale_factor branch)
|
| 45 |
+
#
|
| 46 |
+
# Stores the last computed rescale factor so that external tools / scripts
|
| 47 |
+
# can read it without re-deriving it.
|
| 48 |
+
#
|
| 49 |
+
# Uses threading.local() so concurrent API workers each get their own slot.
|
| 50 |
+
# clear() must be called at the start of every _cfg_rescale() call so that
|
| 51 |
+
# get() correctly returns None when rescaling was skipped this step.
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
class CFGRescaleFactorSingleton:
|
| 55 |
+
"""Thread-local store for the most recently computed CFG rescale factor.
|
| 56 |
+
|
| 57 |
+
Lifecycle per denoising step:
|
| 58 |
+
1. _cfg_rescale() calls clear() at entry — value becomes None.
|
| 59 |
+
2. If rescaling is active, set() stores the computed factor.
|
| 60 |
+
3. External code calls get() after the step; receives the factor or None.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
_state = threading.local()
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def set(cls, value: float) -> None:
|
| 67 |
+
cls._state.cfg_rescale_factor = value
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def get(cls) -> Optional[float]:
|
| 71 |
+
return getattr(cls._state, 'cfg_rescale_factor', None)
|
| 72 |
+
|
| 73 |
+
@classmethod
|
| 74 |
+
def clear(cls) -> None:
|
| 75 |
+
cls._state.cfg_rescale_factor = None
|
neutral_prompt_patched/lib_neutral_prompt/hijacker.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ModuleHijacker:
|
| 5 |
+
def __init__(self, module):
|
| 6 |
+
self.__module = module
|
| 7 |
+
self.__original_functions = dict()
|
| 8 |
+
|
| 9 |
+
def hijack(self, attribute):
|
| 10 |
+
if attribute not in self.__original_functions:
|
| 11 |
+
self.__original_functions[attribute] = getattr(self.__module, attribute)
|
| 12 |
+
|
| 13 |
+
def decorator(function):
|
| 14 |
+
setattr(self.__module, attribute, functools.partial(function, original_function=self.__original_functions[attribute]))
|
| 15 |
+
return function
|
| 16 |
+
|
| 17 |
+
return decorator
|
| 18 |
+
|
| 19 |
+
def reset_module(self):
|
| 20 |
+
for attribute, original_function in self.__original_functions.items():
|
| 21 |
+
setattr(self.__module, attribute, original_function)
|
| 22 |
+
|
| 23 |
+
self.__original_functions.clear()
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
|
| 27 |
+
if not hasattr(module, hijacker_attribute):
|
| 28 |
+
module_hijacker = ModuleHijacker(module)
|
| 29 |
+
setattr(module, hijacker_attribute, module_hijacker)
|
| 30 |
+
on_uninstall(lambda: delattr(module, hijacker_attribute))
|
| 31 |
+
on_uninstall(module_hijacker.reset_module)
|
| 32 |
+
return module_hijacker
|
| 33 |
+
else:
|
| 34 |
+
return getattr(module, hijacker_attribute)
|
neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified Neutral Prompt Parser
|
| 3 |
+
Combines:
|
| 4 |
+
- Core AND / AND_PERP / AND_SALT / AND_TOPK strategies (main branch)
|
| 5 |
+
- Affine spatial transforms: ROTATE / SLIDE / SCALE / SHEAR (affine branch)
|
| 6 |
+
- Local alignment blend: AND_ALIGN_D_S / AND_MASK_ALIGN_D_S (alignment_blend branch)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import abc
|
| 10 |
+
import dataclasses
|
| 11 |
+
import math
|
| 12 |
+
import re
|
| 13 |
+
from enum import Enum
|
| 14 |
+
from itertools import product
|
| 15 |
+
from typing import Any, List, Optional, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Keyword registry
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
# Base conciliation keywords
|
| 25 |
+
_BASE_KEYWORD_MAP = {
|
| 26 |
+
'AND_PERP': 'PERPENDICULAR',
|
| 27 |
+
'AND_SALT': 'SALIENCE_MASK', # k=20 sharp mask (life-branch style)
|
| 28 |
+
'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE', # k=1 broad mask (original main-branch behaviour)
|
| 29 |
+
'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB', # k=20 + erode + thickify → smooth blob (dev.py intent)
|
| 30 |
+
'AND_TOPK': 'SEMANTIC_GUIDANCE',
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# Alignment blend: AND_ALIGN_D_S (soft weight, structure > detail)
|
| 34 |
+
_ALIGN_KEYWORD_MAP = {
|
| 35 |
+
f'AND_ALIGN_{i}_{j}': f'ALIGNMENT_BLEND_{i}_{j}'
|
| 36 |
+
for i, j in product(range(2, 33), repeat=2) if i != j
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
# Alignment mask: AND_MASK_ALIGN_D_S (binary mask, structure > detail)
|
| 40 |
+
_MASK_ALIGN_KEYWORD_MAP = {
|
| 41 |
+
f'AND_MASK_ALIGN_{i}_{j}': f'ALIGNMENT_MASK_BLEND_{i}_{j}'
|
| 42 |
+
for i, j in product(range(2, 33), repeat=2) if i != j
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
keyword_mapping = _BASE_KEYWORD_MAP | _ALIGN_KEYWORD_MAP | _MASK_ALIGN_KEYWORD_MAP
|
| 46 |
+
|
| 47 |
+
PromptKeyword = Enum('PromptKeyword', {'AND': 'AND', **{k: k for k in keyword_mapping}})
|
| 48 |
+
ConciliationStrategy = Enum('ConciliationStrategy', {v: k for k, v in keyword_mapping.items()})
|
| 49 |
+
|
| 50 |
+
# Lists kept for ordered iteration (tokenize uses list order for regex priority)
|
| 51 |
+
prompt_keywords = [e.value for e in PromptKeyword]
|
| 52 |
+
conciliation_strategies = [e.value for e in ConciliationStrategy]
|
| 53 |
+
|
| 54 |
+
# Sets for O(1) membership checks in parser hot-loops (1864-item list = 230x slower)
|
| 55 |
+
_prompt_keywords_set = frozenset(prompt_keywords)
|
| 56 |
+
_conciliation_strategies_set = frozenset(conciliation_strategies)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
# Affine transform definitions (from affine branch)
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
affine_transforms = {
|
| 64 |
+
'ROTATE': lambda t, angle=0, *_: t @ torch.tensor([
|
| 65 |
+
[math.cos(angle * 2 * math.pi), -math.sin(angle * 2 * math.pi), 0],
|
| 66 |
+
[math.sin(angle * 2 * math.pi), math.cos(angle * 2 * math.pi), 0],
|
| 67 |
+
[0, 0, 1],
|
| 68 |
+
], dtype=torch.float32),
|
| 69 |
+
'SLIDE': lambda t, x=0, y=0, *_: t @ torch.tensor([
|
| 70 |
+
[1, 0, float(x)],
|
| 71 |
+
[0, 1, float(y)],
|
| 72 |
+
[0, 0, 1],
|
| 73 |
+
], dtype=torch.float32),
|
| 74 |
+
'SCALE': lambda t, x=1, y=None, *_: t @ torch.tensor([
|
| 75 |
+
[float(x), 0, 0],
|
| 76 |
+
[0, float(y) if y is not None else float(x), 0],
|
| 77 |
+
[0, 0, 1],
|
| 78 |
+
], dtype=torch.float32),
|
| 79 |
+
'SHEAR': lambda t, x=0, y=None, *_: t @ torch.tensor([
|
| 80 |
+
[1, math.tan(float(x) * 2 * math.pi), 0],
|
| 81 |
+
[math.tan(float(y if y is not None else x) * 2 * math.pi), 1, 0],
|
| 82 |
+
[0, 0, 1],
|
| 83 |
+
], dtype=torch.float32),
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# AST node types
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
@dataclasses.dataclass
|
| 92 |
+
class PromptExpr(abc.ABC):
|
| 93 |
+
weight: float
|
| 94 |
+
conciliation: Optional[ConciliationStrategy]
|
| 95 |
+
local_transform: Optional[torch.Tensor] # 2×3 affine tensor or None
|
| 96 |
+
|
| 97 |
+
@abc.abstractmethod
|
| 98 |
+
def accept(self, visitor, *args, **kwargs) -> Any:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@dataclasses.dataclass
|
| 103 |
+
class LeafPrompt(PromptExpr):
|
| 104 |
+
prompt: str
|
| 105 |
+
|
| 106 |
+
def accept(self, visitor, *args, **kwargs):
|
| 107 |
+
return visitor.visit_leaf_prompt(self, *args, **kwargs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclasses.dataclass
|
| 111 |
+
class CompositePrompt(PromptExpr):
|
| 112 |
+
children: List[PromptExpr]
|
| 113 |
+
|
| 114 |
+
def accept(self, visitor, *args, **kwargs):
|
| 115 |
+
return visitor.visit_composite_prompt(self, *args, **kwargs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class FlatSizeVisitor:
|
| 119 |
+
def visit_leaf_prompt(self, that: LeafPrompt) -> int:
|
| 120 |
+
return 1
|
| 121 |
+
|
| 122 |
+
def visit_composite_prompt(self, that: CompositePrompt) -> int:
|
| 123 |
+
return sum(child.accept(self) for child in that.children) if that.children else 0
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ---------------------------------------------------------------------------
|
| 127 |
+
# Parser
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
|
| 130 |
+
def parse_root(string: str) -> CompositePrompt:
|
| 131 |
+
tokens = tokenize(string)
|
| 132 |
+
prompts = parse_prompts(tokens)
|
| 133 |
+
return CompositePrompt(1.0, None, None, prompts)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
|
| 137 |
+
prompts = [parse_prompt(tokens, first=True, nested=nested)]
|
| 138 |
+
while tokens:
|
| 139 |
+
if nested and tokens[0] == ']':
|
| 140 |
+
break
|
| 141 |
+
prompts.append(parse_prompt(tokens, first=False, nested=nested))
|
| 142 |
+
return prompts
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _compose_affine(
|
| 146 |
+
a: Optional[torch.Tensor],
|
| 147 |
+
b: Optional[torch.Tensor],
|
| 148 |
+
) -> Optional[torch.Tensor]:
|
| 149 |
+
"""
|
| 150 |
+
Compose two 2×3 affine matrices (a applied first, then b).
|
| 151 |
+
Returns None if both are None; returns the non-None one if only one exists.
|
| 152 |
+
"""
|
| 153 |
+
if a is None:
|
| 154 |
+
return b
|
| 155 |
+
if b is None:
|
| 156 |
+
return a
|
| 157 |
+
# Lift to 3×3, multiply, trim back to 2×3
|
| 158 |
+
a3 = torch.vstack([a, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
|
| 159 |
+
b3 = torch.vstack([b, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
|
| 160 |
+
return (b3 @ a3)[:-1]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _looks_like_leading_affine_prompt(tokens: List[str]) -> bool:
|
| 164 |
+
"""
|
| 165 |
+
Return True if `tokens` starts with one or more affine keywords followed
|
| 166 |
+
immediately by a prompt keyword, e.g.:
|
| 167 |
+
ROTATE[0.125] AND_PERP ...
|
| 168 |
+
SLIDE[0.05,0] AND_SALT ...
|
| 169 |
+
ROTATE[0.125] SLIDE[0.05,0] AND_PERP ...
|
| 170 |
+
|
| 171 |
+
This is a pure structural scan — it does NOT compute any affine matrices,
|
| 172 |
+
so it is safe to call as a look-ahead probe without mock-torch issues.
|
| 173 |
+
|
| 174 |
+
Used in two places:
|
| 175 |
+
- parse_prompt() → consume leading affine before keyword
|
| 176 |
+
- _parse_prompt_text() → stop current segment before this boundary
|
| 177 |
+
"""
|
| 178 |
+
pos = 0
|
| 179 |
+
n = len(tokens)
|
| 180 |
+
# skip a leading whitespace-only token
|
| 181 |
+
if pos < n and not tokens[pos].strip():
|
| 182 |
+
pos += 1
|
| 183 |
+
# must start with an affine keyword
|
| 184 |
+
if pos >= n or tokens[pos] not in affine_transforms:
|
| 185 |
+
return False
|
| 186 |
+
# consume one or more KEYWORD [ args ] blocks
|
| 187 |
+
while pos < n and tokens[pos] in affine_transforms:
|
| 188 |
+
pos += 1 # skip keyword
|
| 189 |
+
if pos >= n or tokens[pos] != '[':
|
| 190 |
+
return False
|
| 191 |
+
pos += 1 # skip '['
|
| 192 |
+
if pos < n and tokens[pos] != ']':
|
| 193 |
+
pos += 1 # skip args (may contain commas)
|
| 194 |
+
if pos >= n or tokens[pos] != ']':
|
| 195 |
+
return False
|
| 196 |
+
pos += 1 # skip ']'
|
| 197 |
+
if pos < n and not tokens[pos].strip():
|
| 198 |
+
pos += 1 # skip optional whitespace
|
| 199 |
+
# the very next token must be a prompt keyword
|
| 200 |
+
return pos < n and tokens[pos] in _prompt_keywords_set
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
|
| 204 |
+
# Support two affine-keyword orderings:
|
| 205 |
+
# (A) AND_PERP ROTATE[0.25] text ← trailing affine (original syntax)
|
| 206 |
+
# (B) ROTATE[0.25] AND_PERP text ← leading affine (documented syntax)
|
| 207 |
+
#
|
| 208 |
+
# Note: we check _looks_like_leading_affine_prompt unconditionally (not
|
| 209 |
+
# gated on `not first`) so that a prompt starting with e.g.
|
| 210 |
+
# "ROTATE[0.125] AND_PERP text" is handled correctly even as the first
|
| 211 |
+
# segment.
|
| 212 |
+
leading_affine: Optional[torch.Tensor] = None
|
| 213 |
+
if _looks_like_leading_affine_prompt(tokens):
|
| 214 |
+
probe = tokens.copy()
|
| 215 |
+
leading_affine = _parse_affine_transform(probe)
|
| 216 |
+
tokens[:] = probe
|
| 217 |
+
|
| 218 |
+
# After consuming a leading affine the next token is a keyword; accept it.
|
| 219 |
+
# We also accept a keyword as the very first token (first=True) without a
|
| 220 |
+
# leading affine — e.g. "AND_PERP vivid" at the start of a string should
|
| 221 |
+
# produce a single PERP child, not an empty default child followed by a PERP
|
| 222 |
+
# child. The original `not first` guard was an artefact of earlier grammar;
|
| 223 |
+
# all existing tests pass without it.
|
| 224 |
+
if tokens and tokens[0] in _prompt_keywords_set:
|
| 225 |
+
prompt_type = tokens.pop(0)
|
| 226 |
+
else:
|
| 227 |
+
prompt_type = PromptKeyword.AND.value
|
| 228 |
+
|
| 229 |
+
conciliation = ConciliationStrategy(prompt_type) if prompt_type in _conciliation_strategies_set else None
|
| 230 |
+
|
| 231 |
+
# Parse optional trailing affine transform chain (original syntax)
|
| 232 |
+
trailing_affine = _parse_affine_transform(tokens)
|
| 233 |
+
affine_transform = _compose_affine(leading_affine, trailing_affine)
|
| 234 |
+
|
| 235 |
+
# Try composite (bracketed) form
|
| 236 |
+
tokens_copy = tokens.copy()
|
| 237 |
+
if tokens_copy and tokens_copy[0] == '[':
|
| 238 |
+
tokens_copy.pop(0)
|
| 239 |
+
prompts = parse_prompts(tokens_copy, nested=True)
|
| 240 |
+
if tokens_copy:
|
| 241 |
+
assert tokens_copy.pop(0) == ']'
|
| 242 |
+
if len(prompts) > 1:
|
| 243 |
+
tokens[:] = tokens_copy
|
| 244 |
+
weight = _parse_weight(tokens)
|
| 245 |
+
return CompositePrompt(weight, conciliation, affine_transform, prompts)
|
| 246 |
+
|
| 247 |
+
prompt_text, weight = _parse_prompt_text(tokens, nested=nested)
|
| 248 |
+
return LeafPrompt(weight, conciliation, affine_transform, prompt_text)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
|
| 252 |
+
text = ''
|
| 253 |
+
depth = 0
|
| 254 |
+
weight = 1.0
|
| 255 |
+
while tokens:
|
| 256 |
+
tok = tokens[0]
|
| 257 |
+
if tok == ']':
|
| 258 |
+
if depth == 0:
|
| 259 |
+
if nested:
|
| 260 |
+
break
|
| 261 |
+
else:
|
| 262 |
+
depth -= 1
|
| 263 |
+
elif tok == '[':
|
| 264 |
+
depth += 1
|
| 265 |
+
elif tok == ':':
|
| 266 |
+
if len(tokens) >= 2 and _is_float(tokens[1].strip()):
|
| 267 |
+
if len(tokens) < 3 or tokens[2] in _prompt_keywords_set or (tokens[2] == ']' and depth == 0):
|
| 268 |
+
tokens.pop(0)
|
| 269 |
+
weight = float(tokens.pop(0).strip())
|
| 270 |
+
break
|
| 271 |
+
elif tok in _prompt_keywords_set:
|
| 272 |
+
break
|
| 273 |
+
# Stop before a leading-affine-then-keyword boundary, e.g.
|
| 274 |
+
# ROTATE[0.125] AND_PERP text
|
| 275 |
+
# so the affine is consumed by the *next* parse_prompt call, not
|
| 276 |
+
# swallowed as raw text by the current segment.
|
| 277 |
+
elif depth == 0 and _looks_like_leading_affine_prompt(tokens):
|
| 278 |
+
break
|
| 279 |
+
text += tokens.pop(0)
|
| 280 |
+
return text, weight
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _parse_affine_transform(tokens: List[str]) -> Optional[torch.Tensor]:
|
| 284 |
+
"""
|
| 285 |
+
Consume optional affine keywords like ROTATE[0.25] SLIDE[0.1,0] from the token stream.
|
| 286 |
+
Returns a 2×3 float32 tensor, or None if no affine tokens were found.
|
| 287 |
+
"""
|
| 288 |
+
tokens_copy = tokens.copy()
|
| 289 |
+
if tokens_copy and not tokens_copy[0].strip():
|
| 290 |
+
tokens_copy.pop(0)
|
| 291 |
+
|
| 292 |
+
affine_funcs = []
|
| 293 |
+
|
| 294 |
+
while tokens_copy and tokens_copy[0] in affine_transforms:
|
| 295 |
+
func = affine_transforms[tokens_copy.pop(0)]
|
| 296 |
+
args: List[float] = []
|
| 297 |
+
|
| 298 |
+
if not (tokens_copy and tokens_copy[0] == '['):
|
| 299 |
+
break
|
| 300 |
+
tokens_copy.pop(0) # consume '['
|
| 301 |
+
|
| 302 |
+
if tokens_copy and tokens_copy[0] != ']':
|
| 303 |
+
if tokens_copy[0].strip():
|
| 304 |
+
try:
|
| 305 |
+
args = [float(a.strip()) for a in tokens_copy.pop(0).split(',')]
|
| 306 |
+
except ValueError:
|
| 307 |
+
break
|
| 308 |
+
else:
|
| 309 |
+
tokens_copy.pop(0)
|
| 310 |
+
|
| 311 |
+
if not (tokens_copy and tokens_copy[0] == ']'):
|
| 312 |
+
break
|
| 313 |
+
tokens_copy.pop(0) # consume ']'
|
| 314 |
+
|
| 315 |
+
affine_funcs.append(lambda t, f=func, a=args: f(t, *a))
|
| 316 |
+
if tokens_copy and not tokens_copy[0].strip():
|
| 317 |
+
tokens_copy.pop(0)
|
| 318 |
+
tokens[:] = tokens_copy
|
| 319 |
+
|
| 320 |
+
if not affine_funcs:
|
| 321 |
+
return None
|
| 322 |
+
|
| 323 |
+
transform = torch.eye(3, dtype=torch.float32)[:-1] # 2×3
|
| 324 |
+
for fn in reversed(affine_funcs):
|
| 325 |
+
transform = fn(transform)
|
| 326 |
+
return transform
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def _parse_weight(tokens: List[str]) -> float:
|
| 330 |
+
if len(tokens) >= 2 and tokens[0] == ':' and _is_float(tokens[1]):
|
| 331 |
+
tokens.pop(0)
|
| 332 |
+
return float(tokens.pop(0))
|
| 333 |
+
return 1.0
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def tokenize(s: str) -> List[str]:
|
| 337 |
+
# Use general patterns for AND_ALIGN / AND_MASK_ALIGN families instead of
|
| 338 |
+
# enumerating all 1800+ combinations - that regex is 41k chars and 500x slower.
|
| 339 |
+
# Order matters: longer/more-specific patterns must come first.
|
| 340 |
+
# Use (?<!\w) / (?!\w) instead of \b because _ counts as \w in Python's \b,
|
| 341 |
+
# which would cause 'AND' to match inside 'XAND' (no boundary between \w chars).
|
| 342 |
+
affine_kw_pattern = '|'.join(
|
| 343 |
+
rf'(?<!\w){kw}(?!\w)' for kw in affine_transforms
|
| 344 |
+
)
|
| 345 |
+
keyword_pattern = (
|
| 346 |
+
r'(?<!\w)AND_MASK_ALIGN_\d+_\d+(?!\w)' # must come before AND_ALIGN and AND
|
| 347 |
+
r'|(?<!\w)AND_ALIGN_\d+_\d+(?!\w)'
|
| 348 |
+
r'|(?<!\w)AND_PERP(?!\w)'
|
| 349 |
+
r'|(?<!\w)AND_SALT_WIDE(?!\w)' # must come before AND_SALT
|
| 350 |
+
r'|(?<!\w)AND_SALT_BLOB(?!\w)' # must come before AND_SALT
|
| 351 |
+
r'|(?<!\w)AND_SALT(?!\w)'
|
| 352 |
+
r'|(?<!\w)AND_TOPK(?!\w)'
|
| 353 |
+
r'|(?<!\w)AND(?!\w)'
|
| 354 |
+
)
|
| 355 |
+
return [t for t in re.split(rf'(\[|\]|:|{keyword_pattern}|{affine_kw_pattern})', s) if t.strip()]
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _is_float(s: str) -> bool:
|
| 359 |
+
try:
|
| 360 |
+
float(s)
|
| 361 |
+
return True
|
| 362 |
+
except ValueError:
|
| 363 |
+
return False
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# ---------------------------------------------------------------------------
|
| 367 |
+
# Quick smoke-test
|
| 368 |
+
# ---------------------------------------------------------------------------
|
| 369 |
+
if __name__ == '__main__':
|
| 370 |
+
res = parse_root('''
|
| 371 |
+
hello
|
| 372 |
+
AND_PERP [
|
| 373 |
+
arst
|
| 374 |
+
AND defg : 2
|
| 375 |
+
AND_SALT [
|
| 376 |
+
very nested huh? what do you say :.0
|
| 377 |
+
]
|
| 378 |
+
]
|
| 379 |
+
AND_ALIGN_4_8 watercolor style :0.5
|
| 380 |
+
ROTATE[0.125] AND_PERP vibrant colors
|
| 381 |
+
''')
|
| 382 |
+
print(res)
|
neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
|
| 6 |
+
from modules import script_callbacks, prompt_parser
|
| 7 |
+
|
| 8 |
+
# ---------------------------------------------------------------------------
|
| 9 |
+
# Fix: prompt_parser_fixed escapes standalone '&' as a multicond separator.
|
| 10 |
+
# neutral_prompt_parser does NOT treat '&' as AND — it keeps it as literal text.
|
| 11 |
+
# So a leaf like "cat & dog" transpiles to "cat & dog :1.0",
|
| 12 |
+
# and the patched prompt_parser splits that into 3 multicond branches instead of 1,
|
| 13 |
+
# which shifts all batch_cond_indices and breaks PERP/SALT/affine.
|
| 14 |
+
#
|
| 15 |
+
# Solution: escape any standalone '&' inside leaf text with '\&' before handing
|
| 16 |
+
# the string to the webui parser. The patched prompt_parser correctly unescapes
|
| 17 |
+
# '\&' back to '&' during conditioning, so the model sees the original text.
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
_STANDALONE_AMP = re.compile(r'(?<!\\)(?<!\S)&(?!\S)')
|
| 21 |
+
|
| 22 |
+
# ── debug logging ──────────────────────────────────────────────────────────
|
| 23 |
+
# Set env variable NP_DEBUG=1 to enable verbose output in the A1111 console.
|
| 24 |
+
# Example (Linux/Mac): NP_DEBUG=1 python launch.py
|
| 25 |
+
# Example (Windows cmd): set NP_DEBUG=1 && python launch.py
|
| 26 |
+
import os as _os
|
| 27 |
+
_DEBUG = _os.getenv("NP_DEBUG", "0").strip() not in ("0", "", "false", "no", "off")
|
| 28 |
+
_log = logging.getLogger("neutral_prompt.hijack")
|
| 29 |
+
# ──────────────────────────────────────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _escape_leaf_ampersands(text: str) -> str:
|
| 33 |
+
"""Escape standalone '&' so patched prompt_parser doesn't split a single
|
| 34 |
+
Neutral Prompt leaf into extra multicond branches.
|
| 35 |
+
"cat & dog" -> "cat \\& dog"
|
| 36 |
+
"R&D" -> "R&D" (unchanged — not standalone)
|
| 37 |
+
"\\&" -> "\\&" (unchanged — already escaped)
|
| 38 |
+
"""
|
| 39 |
+
if not text or '&' not in text:
|
| 40 |
+
return text
|
| 41 |
+
return _STANDALONE_AMP.sub(r'\\&', text)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
|
| 45 |
+
module=prompt_parser,
|
| 46 |
+
hijacker_attribute='__neutral_prompt_hijacker',
|
| 47 |
+
on_uninstall=script_callbacks.on_script_unloaded,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@prompt_parser_hijacker.hijack('get_multicond_prompt_list')
|
| 52 |
+
def get_multicond_prompt_list_hijack(prompts, original_function):
|
| 53 |
+
if not global_state.is_enabled:
|
| 54 |
+
return original_function(prompts)
|
| 55 |
+
|
| 56 |
+
global_state.prompt_exprs = parse_prompts(prompts)
|
| 57 |
+
webui_prompts = transpile_exprs(global_state.prompt_exprs)
|
| 58 |
+
|
| 59 |
+
# ── debug: transpiled strings ──────────────────────────────────────────
|
| 60 |
+
if _DEBUG:
|
| 61 |
+
for i, (orig, transp) in enumerate(zip(prompts, webui_prompts)):
|
| 62 |
+
_log.warning(
|
| 63 |
+
"[NP_DEBUG] prompt[%d]\n"
|
| 64 |
+
" original : %r\n"
|
| 65 |
+
" transpiled : %r\n"
|
| 66 |
+
" branches : %d",
|
| 67 |
+
i, orig, transp,
|
| 68 |
+
# count multicond splits the same way prompt_parser_fixed does
|
| 69 |
+
len(re.split(r'(?:\bAND\b|(?<!\S)&(?!\S))(?!_PERP|_SALT|_TOPK)', transp))
|
| 70 |
+
)
|
| 71 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 72 |
+
|
| 73 |
+
if isinstance(prompts, getattr(prompt_parser, 'SdConditioning', type(None))):
|
| 74 |
+
webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)
|
| 75 |
+
|
| 76 |
+
result = original_function(webui_prompts)
|
| 77 |
+
|
| 78 |
+
# ── debug: multicond parts ─────────────────────────────────────────────
|
| 79 |
+
if _DEBUG:
|
| 80 |
+
conds_list, prompt_flat_list, prompt_indexes = result
|
| 81 |
+
_log.warning(
|
| 82 |
+
"[NP_DEBUG] get_multicond_prompt_list result\n"
|
| 83 |
+
" prompt_flat_list : %r\n"
|
| 84 |
+
" prompt_indexes : %r\n"
|
| 85 |
+
" conds_list : %r",
|
| 86 |
+
list(prompt_flat_list),
|
| 87 |
+
dict(prompt_indexes),
|
| 88 |
+
conds_list,
|
| 89 |
+
)
|
| 90 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 91 |
+
|
| 92 |
+
return result
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
|
| 96 |
+
exprs = []
|
| 97 |
+
for prompt in prompts:
|
| 98 |
+
expr = neutral_prompt_parser.parse_root(prompt)
|
| 99 |
+
exprs.append(expr)
|
| 100 |
+
return exprs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def transpile_exprs(exprs: neutral_prompt_parser.PromptExpr):
|
| 104 |
+
webui_prompts = []
|
| 105 |
+
for expr in exprs:
|
| 106 |
+
webui_prompts.append(expr.accept(WebuiPromptVisitor()))
|
| 107 |
+
return webui_prompts
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class WebuiPromptVisitor:
|
| 111 |
+
def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
|
| 112 |
+
prompt = _escape_leaf_ampersands(that.prompt)
|
| 113 |
+
return f'{prompt} :{that.weight}'
|
| 114 |
+
|
| 115 |
+
def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
|
| 116 |
+
return ' AND '.join(child.accept(self) for child in that.children)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@prompt_parser_hijacker.hijack('reconstruct_multicond_batch')
|
| 120 |
+
def reconstruct_multicond_batch_hijack(*args, original_function, **kwargs):
|
| 121 |
+
"""Store batch_cond_indices for the pre-noise affine hook (affine branch)."""
|
| 122 |
+
res = original_function(*args, **kwargs)
|
| 123 |
+
global_state.batch_cond_indices = res[0]
|
| 124 |
+
|
| 125 |
+
# ── debug: batch_cond_indices ──────────────────────────────────────────
|
| 126 |
+
if _DEBUG:
|
| 127 |
+
_log.warning(
|
| 128 |
+
"[NP_DEBUG] reconstruct_multicond_batch\n"
|
| 129 |
+
" batch_cond_indices : %r",
|
| 130 |
+
res[0],
|
| 131 |
+
)
|
| 132 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 133 |
+
|
| 134 |
+
return res
|
neutral_prompt_patched/lib_neutral_prompt/ui.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI for sd-webui-neutral-prompt (unified).
|
| 3 |
+
|
| 4 |
+
Combines:
|
| 5 |
+
- Core prompt types + tooltip (main branch)
|
| 6 |
+
- AND_ALIGN / AND_MASK_ALIGN entries (alignment_blend branch)
|
| 7 |
+
- elem_id on dropdown (main branch)
|
| 8 |
+
- CFG Rescale infotext / paste support (all branches)
|
| 9 |
+
- Affine-transform hint in tooltip (affine branch)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from itertools import product
|
| 15 |
+
from typing import Callable, Dict, List, Tuple
|
| 16 |
+
|
| 17 |
+
import dataclasses
|
| 18 |
+
import gradio as gr
|
| 19 |
+
|
| 20 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 21 |
+
from modules import script_callbacks, shared
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
txt2img_prompt_textbox = None
|
| 25 |
+
img2img_prompt_textbox = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Prompt-type registry shown in the UI dropdown
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
prompt_types: Dict[str, str] = {
|
| 33 |
+
'Perpendicular (AND_PERP)': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
|
| 34 |
+
'Saliency sharp (AND_SALT)': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
|
| 35 |
+
'Saliency blob (AND_SALT_BLOB)': neutral_prompt_parser.PromptKeyword.AND_SALT_BLOB.value,
|
| 36 |
+
'Saliency wide / classic (AND_SALT_WIDE)':neutral_prompt_parser.PromptKeyword.AND_SALT_WIDE.value,
|
| 37 |
+
'Semantic guidance top-k (AND_TOPK)': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# Add AND_ALIGN_D_S entries for commonly-used kernel size pairs
|
| 41 |
+
for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
|
| 42 |
+
_key = f'Alignment blend detail={_d} structure={_s} (AND_ALIGN_{_d}_{_s})'
|
| 43 |
+
prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_ALIGN_{_d}_{_s}').value
|
| 44 |
+
|
| 45 |
+
# Add AND_MASK_ALIGN_D_S entries
|
| 46 |
+
for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
|
| 47 |
+
_key = f'Alignment mask detail={_d} structure={_s} (AND_MASK_ALIGN_{_d}_{_s})'
|
| 48 |
+
prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_MASK_ALIGN_{_d}_{_s}').value
|
| 49 |
+
|
| 50 |
+
prompt_types_tooltip = '\n'.join([
|
| 51 |
+
'AND – add all prompt features equally (webui built-in)',
|
| 52 |
+
'AND_PERP – reduce contradicting prompt features via perpendicular projection',
|
| 53 |
+
'AND_SALT – sharp saliency mask: child wins only at its 1-2 peak pixels (surgical)',
|
| 54 |
+
'AND_SALT_BLOB – blob saliency mask: peak pixels eroded to core, then grown to smooth blob',
|
| 55 |
+
'AND_SALT_WIDE – broad saliency mask: child wins ~55% of pixels (classic main-branch)',
|
| 56 |
+
'AND_TOPK – small targeted changes (semantic guidance top-k)',
|
| 57 |
+
'AND_ALIGN_D_S – soft blend: child adds details without breaking structure',
|
| 58 |
+
'AND_MASK_ALIGN_D_S– binary mask blend: stricter version of AND_ALIGN',
|
| 59 |
+
'',
|
| 60 |
+
'Affine transforms (supported in either order for auxiliary segments):',
|
| 61 |
+
' ROTATE[angle] SLIDE[x,y] SCALE[x,y] SHEAR[x,y]',
|
| 62 |
+
' e.g.: ROTATE[0.125] AND_PERP vibrant colors :0.8',
|
| 63 |
+
' e.g.: AND_PERP ROTATE[0.125] vibrant colors :0.8',
|
| 64 |
+
])
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# UI component class
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
@dataclasses.dataclass
|
| 72 |
+
class AccordionInterface:
|
| 73 |
+
get_elem_id: Callable[[str], str]
|
| 74 |
+
|
| 75 |
+
def __post_init__(self):
|
| 76 |
+
self.is_rendered = False
|
| 77 |
+
|
| 78 |
+
self.cfg_rescale = gr.Slider(
|
| 79 |
+
label='CFG rescale φ',
|
| 80 |
+
minimum=0, maximum=1, value=0,
|
| 81 |
+
info='0 = disabled. Rescales the CFG output towards the predicted x0 to reduce over-saturation.',
|
| 82 |
+
)
|
| 83 |
+
self.neutral_prompt = gr.Textbox(
|
| 84 |
+
label='Neutral prompt',
|
| 85 |
+
show_label=False,
|
| 86 |
+
lines=3,
|
| 87 |
+
placeholder=(
|
| 88 |
+
'Neutral / auxiliary prompt '
|
| 89 |
+
'(press "Apply to prompt" to append with the chosen strategy keyword)'
|
| 90 |
+
),
|
| 91 |
+
)
|
| 92 |
+
self.neutral_cond_scale = gr.Slider(
|
| 93 |
+
label='Prompt weight',
|
| 94 |
+
minimum=-3, maximum=3, value=1,
|
| 95 |
+
)
|
| 96 |
+
self.aux_prompt_type = gr.Dropdown(
|
| 97 |
+
label='Prompt type',
|
| 98 |
+
choices=list(prompt_types.keys()),
|
| 99 |
+
value=next(iter(prompt_types.keys())),
|
| 100 |
+
tooltip=prompt_types_tooltip,
|
| 101 |
+
elem_id=self.get_elem_id('formatter_prompt_type'),
|
| 102 |
+
)
|
| 103 |
+
self.append_to_prompt_button = gr.Button(value='Apply to prompt')
|
| 104 |
+
|
| 105 |
+
# ------------------------------------------------------------------
|
| 106 |
+
|
| 107 |
+
def arrange_components(self, is_img2img: bool) -> None:
|
| 108 |
+
if self.is_rendered:
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
with gr.Accordion(label='Neutral Prompt', open=False):
|
| 112 |
+
self.cfg_rescale.render()
|
| 113 |
+
with gr.Accordion(label='Prompt formatter', open=False):
|
| 114 |
+
self.neutral_prompt.render()
|
| 115 |
+
self.neutral_cond_scale.render()
|
| 116 |
+
self.aux_prompt_type.render()
|
| 117 |
+
self.append_to_prompt_button.render()
|
| 118 |
+
|
| 119 |
+
def connect_events(self, is_img2img: bool) -> None:
|
| 120 |
+
if self.is_rendered:
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
|
| 124 |
+
self.append_to_prompt_button.click(
|
| 125 |
+
fn=lambda init_prompt, prompt, scale, prompt_type: (
|
| 126 |
+
f'{init_prompt}\n{prompt_types[prompt_type]} {prompt} :{scale}',
|
| 127 |
+
'',
|
| 128 |
+
),
|
| 129 |
+
inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale, self.aux_prompt_type],
|
| 130 |
+
outputs=[prompt_textbox, self.neutral_prompt],
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def set_rendered(self, value: bool = True) -> None:
|
| 134 |
+
self.is_rendered = value
|
| 135 |
+
|
| 136 |
+
# ------------------------------------------------------------------
|
| 137 |
+
# Script interface (args / infotext / paste)
|
| 138 |
+
|
| 139 |
+
def get_components(self) -> Tuple[gr.components.Component, ...]:
|
| 140 |
+
return (self.cfg_rescale,)
|
| 141 |
+
|
| 142 |
+
def get_infotext_fields(self) -> Tuple[Tuple[gr.components.Component, str], ...]:
|
| 143 |
+
return tuple(zip(self.get_components(), ('CFG Rescale phi',)))
|
| 144 |
+
|
| 145 |
+
def get_paste_field_names(self) -> List[str]:
|
| 146 |
+
return ['CFG Rescale phi']
|
| 147 |
+
|
| 148 |
+
def get_extra_generation_params(self, args: Dict) -> Dict:
|
| 149 |
+
return {'CFG Rescale phi': args['cfg_rescale']}
|
| 150 |
+
|
| 151 |
+
def unpack_processing_args(self, cfg_rescale: float) -> Dict:
|
| 152 |
+
return {'cfg_rescale': cfg_rescale}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# Settings & callbacks
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
def on_ui_settings() -> None:
|
| 160 |
+
section = ('neutral_prompt', 'Neutral Prompt')
|
| 161 |
+
shared.opts.add_option(
|
| 162 |
+
'neutral_prompt_enabled',
|
| 163 |
+
shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section),
|
| 164 |
+
)
|
| 165 |
+
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
| 166 |
+
shared.opts.onchange('neutral_prompt_enabled', _update_enabled)
|
| 167 |
+
|
| 168 |
+
shared.opts.add_option(
|
| 169 |
+
'neutral_prompt_verbose',
|
| 170 |
+
shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section),
|
| 171 |
+
)
|
| 172 |
+
global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
|
| 173 |
+
shared.opts.onchange('neutral_prompt_verbose', _update_verbose)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _update_enabled() -> None:
|
| 180 |
+
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _update_verbose() -> None:
|
| 184 |
+
global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def on_after_component(component, **_kwargs) -> None:
|
| 188 |
+
global txt2img_prompt_textbox, img2img_prompt_textbox
|
| 189 |
+
eid = getattr(component, 'elem_id', None)
|
| 190 |
+
if eid == 'txt2img_prompt':
|
| 191 |
+
txt2img_prompt_textbox = component
|
| 192 |
+
elif eid == 'img2img_prompt':
|
| 193 |
+
img2img_prompt_textbox = component
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
script_callbacks.on_after_component(on_after_component)
|
neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from types import ModuleType
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from modules import scripts
|
| 5 |
+
from lib_neutral_prompt import global_state
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def patch():
|
| 9 |
+
xyz_module = find_xyz_module()
|
| 10 |
+
if xyz_module is None:
|
| 11 |
+
print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
|
| 12 |
+
return
|
| 13 |
+
|
| 14 |
+
xyz_module.axis_options.extend([
|
| 15 |
+
xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
|
| 16 |
+
])
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class XyzFloat(float):
|
| 20 |
+
is_xyz: bool = True
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def apply_cfg_rescale():
|
| 24 |
+
def callback(_p, v, _vs):
|
| 25 |
+
global_state.cfg_rescale = XyzFloat(v)
|
| 26 |
+
|
| 27 |
+
return callback
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def int_or_float(string):
|
| 31 |
+
try:
|
| 32 |
+
return int(string)
|
| 33 |
+
except ValueError:
|
| 34 |
+
return float(string)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def find_xyz_module() -> Optional[ModuleType]:
|
| 38 |
+
for data in scripts.scripts_data:
|
| 39 |
+
if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
|
| 40 |
+
return data.module
|
| 41 |
+
|
| 42 |
+
return None
|
neutral_prompt_patched/scripts/neutral_prompt.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
|
| 2 |
+
from modules import scripts, processing, shared
|
| 3 |
+
from typing import Dict
|
| 4 |
+
import functools
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class NeutralPromptScript(scripts.Script):
|
| 8 |
+
def __init__(self):
|
| 9 |
+
self.accordion_interface = None
|
| 10 |
+
self._is_img2img = False
|
| 11 |
+
|
| 12 |
+
@property
|
| 13 |
+
def is_img2img(self):
|
| 14 |
+
return self._is_img2img
|
| 15 |
+
|
| 16 |
+
@is_img2img.setter
|
| 17 |
+
def is_img2img(self, is_img2img):
|
| 18 |
+
self._is_img2img = is_img2img
|
| 19 |
+
if self.accordion_interface is None:
|
| 20 |
+
self.accordion_interface = ui.AccordionInterface(self.elem_id)
|
| 21 |
+
|
| 22 |
+
def title(self) -> str:
|
| 23 |
+
return "Neutral Prompt"
|
| 24 |
+
|
| 25 |
+
def show(self, is_img2img: bool):
|
| 26 |
+
return scripts.AlwaysVisible
|
| 27 |
+
|
| 28 |
+
def ui(self, is_img2img: bool):
|
| 29 |
+
self.hijack_composable_lora(is_img2img)
|
| 30 |
+
|
| 31 |
+
self.accordion_interface.arrange_components(is_img2img)
|
| 32 |
+
self.accordion_interface.connect_events(is_img2img)
|
| 33 |
+
self.infotext_fields = self.accordion_interface.get_infotext_fields()
|
| 34 |
+
self.paste_field_names = self.accordion_interface.get_paste_field_names()
|
| 35 |
+
self.accordion_interface.set_rendered()
|
| 36 |
+
return self.accordion_interface.get_components()
|
| 37 |
+
|
| 38 |
+
def process(self, p: processing.StableDiffusionProcessing, *args):
|
| 39 |
+
args = self.accordion_interface.unpack_processing_args(*args)
|
| 40 |
+
|
| 41 |
+
self.update_global_state(args)
|
| 42 |
+
if global_state.is_enabled:
|
| 43 |
+
p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
|
| 44 |
+
|
| 45 |
+
def update_global_state(self, args: Dict):
|
| 46 |
+
if shared.state.job_no == 0:
|
| 47 |
+
global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
|
| 48 |
+
|
| 49 |
+
for k, v in args.items():
|
| 50 |
+
try:
|
| 51 |
+
getattr(global_state, k)
|
| 52 |
+
except AttributeError:
|
| 53 |
+
continue
|
| 54 |
+
|
| 55 |
+
if getattr(getattr(global_state, k), 'is_xyz', False):
|
| 56 |
+
xyz_attr = getattr(global_state, k)
|
| 57 |
+
xyz_attr.is_xyz = False
|
| 58 |
+
args[k] = xyz_attr
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
if shared.state.job_no > 0:
|
| 62 |
+
continue
|
| 63 |
+
|
| 64 |
+
setattr(global_state, k, v)
|
| 65 |
+
|
| 66 |
+
def hijack_composable_lora(self, is_img2img):
|
| 67 |
+
if self.accordion_interface.is_rendered:
|
| 68 |
+
return
|
| 69 |
+
|
| 70 |
+
lora_script = None
|
| 71 |
+
script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
|
| 72 |
+
|
| 73 |
+
for script in script_runner.alwayson_scripts:
|
| 74 |
+
if script.title().lower() == "composable lora":
|
| 75 |
+
lora_script = script
|
| 76 |
+
break
|
| 77 |
+
|
| 78 |
+
if lora_script is not None:
|
| 79 |
+
lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
|
| 83 |
+
if not global_state.is_enabled:
|
| 84 |
+
return original_function(p, *args, **kwargs)
|
| 85 |
+
|
| 86 |
+
exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
|
| 87 |
+
all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs)
|
| 88 |
+
res = original_function(p, *args, **kwargs)
|
| 89 |
+
# restore original prompts
|
| 90 |
+
p.all_prompts = all_prompts
|
| 91 |
+
return res
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
xyz_grid.patch()
|
neutral_prompt_patched/test/perp_parser/__init__.py
ADDED
|
File without changes
|
neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for affine-keyword ordering in parse_prompt.
|
| 3 |
+
|
| 4 |
+
Both orderings must work correctly:
|
| 5 |
+
(A) AND_PERP ROTATE[0.125] vivid colors ← trailing affine (original)
|
| 6 |
+
(B) ROTATE[0.125] AND_PERP vivid colors ← leading affine (documented)
|
| 7 |
+
"""
|
| 8 |
+
import unittest
|
| 9 |
+
|
| 10 |
+
from lib_neutral_prompt import neutral_prompt_parser as p
|
| 11 |
+
from lib_neutral_prompt.neutral_prompt_parser import ConciliationStrategy
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestAffineKeywordOrder(unittest.TestCase):
|
| 15 |
+
|
| 16 |
+
# ------------------------------------------------------------------ #
|
| 17 |
+
# trailing affine (AND_PERP ROTATE[...] text) #
|
| 18 |
+
# ------------------------------------------------------------------ #
|
| 19 |
+
|
| 20 |
+
def test_trailing_affine_basic(self):
|
| 21 |
+
root = p.parse_root("base AND_PERP ROTATE[0.125] vivid colors :0.8")
|
| 22 |
+
self.assertEqual(len(root.children), 2)
|
| 23 |
+
self.assertEqual(root.children[0].prompt.strip(), "base")
|
| 24 |
+
self.assertIsNone(root.children[0].local_transform)
|
| 25 |
+
self.assertIsNotNone(root.children[1].local_transform)
|
| 26 |
+
self.assertEqual(root.children[1].conciliation, ConciliationStrategy.PERPENDICULAR)
|
| 27 |
+
self.assertAlmostEqual(root.children[1].weight, 0.8)
|
| 28 |
+
|
| 29 |
+
# ------------------------------------------------------------------ #
|
| 30 |
+
# leading affine (ROTATE[...] AND_PERP text) #
|
| 31 |
+
# ------------------------------------------------------------------ #
|
| 32 |
+
|
| 33 |
+
def test_leading_affine_mid_prompt(self):
|
| 34 |
+
"""ROTATE[...] AND_PERP must NOT be consumed into the previous prompt's text."""
|
| 35 |
+
root = p.parse_root("base ROTATE[0.125] AND_PERP vivid colors :0.8")
|
| 36 |
+
self.assertEqual(len(root.children), 2,
|
| 37 |
+
f"Expected 2 children, got {len(root.children)}: "
|
| 38 |
+
f"{[getattr(c,'prompt','composite') for c in root.children]}")
|
| 39 |
+
first, second = root.children
|
| 40 |
+
# first segment = "base " with no transform
|
| 41 |
+
self.assertIn("base", first.prompt)
|
| 42 |
+
self.assertNotIn("ROTATE", first.prompt)
|
| 43 |
+
self.assertIsNone(first.local_transform)
|
| 44 |
+
# second segment gets the transform
|
| 45 |
+
self.assertIsNotNone(second.local_transform)
|
| 46 |
+
self.assertEqual(second.conciliation, ConciliationStrategy.PERPENDICULAR)
|
| 47 |
+
self.assertAlmostEqual(second.weight, 0.8)
|
| 48 |
+
|
| 49 |
+
def test_leading_affine_at_start(self):
|
| 50 |
+
"""ROTATE[...] AND_PERP as first (and only non-AND) segment."""
|
| 51 |
+
root = p.parse_root("ROTATE[0.125] AND_PERP vivid colors :0.8")
|
| 52 |
+
self.assertEqual(len(root.children), 1)
|
| 53 |
+
child = root.children[0]
|
| 54 |
+
self.assertIsNotNone(child.local_transform)
|
| 55 |
+
self.assertEqual(child.conciliation, ConciliationStrategy.PERPENDICULAR)
|
| 56 |
+
self.assertIn("vivid", child.prompt)
|
| 57 |
+
self.assertNotIn("ROTATE", child.prompt)
|
| 58 |
+
|
| 59 |
+
def test_leading_affine_with_and_salt(self):
|
| 60 |
+
root = p.parse_root("portrait SLIDE[0.05,0] AND_SALT vibrant :1.2")
|
| 61 |
+
self.assertEqual(len(root.children), 2)
|
| 62 |
+
second = root.children[1]
|
| 63 |
+
self.assertIsNotNone(second.local_transform)
|
| 64 |
+
self.assertEqual(second.conciliation, ConciliationStrategy.SALIENCE_MASK)
|
| 65 |
+
|
| 66 |
+
# ------------------------------------------------------------------ #
|
| 67 |
+
# Composed / chained affine #
|
| 68 |
+
# ------------------------------------------------------------------ #
|
| 69 |
+
|
| 70 |
+
def test_leading_and_trailing_compose(self):
|
| 71 |
+
"""ROTATE[a] AND_PERP SCALE[b,b] text — both affines composed."""
|
| 72 |
+
root = p.parse_root("base AND_PERP ROTATE[0.125] AND_PERP SCALE[1.5,1.5] vivid")
|
| 73 |
+
# Just check no crash and all children parse
|
| 74 |
+
self.assertGreaterEqual(len(root.children), 1)
|
| 75 |
+
|
| 76 |
+
# ------------------------------------------------------------------ #
|
| 77 |
+
# CFGRescaleFactorSingleton lifecycle #
|
| 78 |
+
# ------------------------------------------------------------------ #
|
| 79 |
+
|
| 80 |
+
def test_cfg_rescale_singleton_clear(self):
|
| 81 |
+
from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
|
| 82 |
+
S.clear()
|
| 83 |
+
self.assertIsNone(S.get())
|
| 84 |
+
S.set(1.23)
|
| 85 |
+
self.assertAlmostEqual(S.get(), 1.23)
|
| 86 |
+
S.clear()
|
| 87 |
+
self.assertIsNone(S.get())
|
| 88 |
+
|
| 89 |
+
def test_cfg_rescale_singleton_thread_local(self):
|
| 90 |
+
import threading
|
| 91 |
+
from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
|
| 92 |
+
results = {}
|
| 93 |
+
def worker(val, key):
|
| 94 |
+
S.clear()
|
| 95 |
+
S.set(val)
|
| 96 |
+
import time; time.sleep(0.01)
|
| 97 |
+
results[key] = S.get()
|
| 98 |
+
threads = [threading.Thread(target=worker, args=(i * 10.0, i)) for i in range(3)]
|
| 99 |
+
for t in threads: t.start()
|
| 100 |
+
for t in threads: t.join()
|
| 101 |
+
for i in range(3):
|
| 102 |
+
self.assertAlmostEqual(results[i], i * 10.0,
|
| 103 |
+
msg=f"Thread {i} got {results[i]} expected {i*10.0}")
|
| 104 |
+
|
| 105 |
+
# ------------------------------------------------------------------ #
|
| 106 |
+
# New AND_SALT variants parse correctly #
|
| 107 |
+
# ------------------------------------------------------------------ #
|
| 108 |
+
|
| 109 |
+
def test_and_salt_wide_parsed(self):
|
| 110 |
+
root = p.parse_root("base AND_SALT_WIDE vivid")
|
| 111 |
+
self.assertEqual(len(root.children), 2)
|
| 112 |
+
self.assertEqual(root.children[1].conciliation,
|
| 113 |
+
p.ConciliationStrategy.SALIENCE_MASK_WIDE)
|
| 114 |
+
|
| 115 |
+
def test_and_salt_blob_parsed(self):
|
| 116 |
+
root = p.parse_root("base AND_SALT_BLOB vivid")
|
| 117 |
+
self.assertEqual(len(root.children), 2)
|
| 118 |
+
self.assertEqual(root.children[1].conciliation,
|
| 119 |
+
p.ConciliationStrategy.SALIENCE_MASK_BLOB)
|
| 120 |
+
|
| 121 |
+
def test_and_align_parsed(self):
|
| 122 |
+
root = p.parse_root("base AND_ALIGN_4_8 vivid")
|
| 123 |
+
self.assertEqual(len(root.children), 2)
|
| 124 |
+
self.assertIn('AND_ALIGN_4_8', root.children[1].conciliation.value)
|
| 125 |
+
|
| 126 |
+
def test_and_mask_align_parsed(self):
|
| 127 |
+
root = p.parse_root("base AND_MASK_ALIGN_4_8 vivid")
|
| 128 |
+
self.assertEqual(len(root.children), 2)
|
| 129 |
+
self.assertIn('AND_MASK_ALIGN_4_8', root.children[1].conciliation.value)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
if __name__ == '__main__':
|
| 133 |
+
unittest.main()
|
neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration smoke tests for the pre-noise affine pipeline.
|
| 3 |
+
|
| 4 |
+
Requires PyTorch. If torch is not installed the entire module is skipped via
|
| 5 |
+
``unittest.skipUnless``, so ``unittest discover`` still exits cleanly.
|
| 6 |
+
|
| 7 |
+
Mocked: A1111 modules only (modules.script_callbacks, modules.sd_samplers,
|
| 8 |
+
modules.shared). torch itself is real — no sys.modules replacement.
|
| 9 |
+
|
| 10 |
+
Tests:
|
| 11 |
+
1. Hook is a no-op when global_state.is_enabled = False.
|
| 12 |
+
2. Identity transform (angle=0) leaves x numerically unchanged.
|
| 13 |
+
3. Non-identity transform calls apply_affine_transform (verified by spy).
|
| 14 |
+
4. Prompts with no affine leave x unchanged.
|
| 15 |
+
5. Empty state does not crash.
|
| 16 |
+
6. Batch with multiple prompts does not crash.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import dataclasses
|
| 20 |
+
import importlib
|
| 21 |
+
import math
|
| 22 |
+
import sys
|
| 23 |
+
import types
|
| 24 |
+
import unittest
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Detect torch availability — skip the whole module if absent
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
try:
|
| 30 |
+
import importlib.util as _ilu
|
| 31 |
+
_spec = _ilu.find_spec('torch')
|
| 32 |
+
_TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
|
| 33 |
+
except (ValueError, ModuleNotFoundError):
|
| 34 |
+
_TORCH_AVAILABLE = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# A1111 module stubs (installed unconditionally so cfg_denoiser_hijack can
|
| 39 |
+
# be imported even when running the skip path)
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
|
| 42 |
+
def _stub(name):
|
| 43 |
+
m = types.ModuleType(name)
|
| 44 |
+
sys.modules.setdefault(name, m)
|
| 45 |
+
return sys.modules[name]
|
| 46 |
+
|
| 47 |
+
for _mod_name in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
|
| 48 |
+
'modules.shared', 'modules.prompt_parser', 'gradio'):
|
| 49 |
+
_stub(_mod_name)
|
| 50 |
+
|
| 51 |
+
import modules.script_callbacks as _sc
|
| 52 |
+
if not hasattr(_sc, 'on_cfg_denoiser'):
|
| 53 |
+
_sc.on_cfg_denoiser = lambda fn: None
|
| 54 |
+
if not hasattr(_sc, 'on_script_unloaded'):
|
| 55 |
+
_sc.on_script_unloaded = lambda fn: None
|
| 56 |
+
|
| 57 |
+
import modules.shared as _sh
|
| 58 |
+
if not hasattr(_sh, 'opts'):
|
| 59 |
+
_sh.opts = types.SimpleNamespace(
|
| 60 |
+
data={},
|
| 61 |
+
add_option=lambda *a, **k: None,
|
| 62 |
+
onchange=lambda *a, **k: None,
|
| 63 |
+
OptionInfo=lambda *a, **k: None,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
import modules.sd_samplers as _ss
|
| 67 |
+
if not hasattr(_ss, 'cfg_denoiser'):
|
| 68 |
+
_ss.cfg_denoiser = None
|
| 69 |
+
# cfg_denoiser_hijack installs a hijack on create_sampler at import time
|
| 70 |
+
if not hasattr(_ss, 'create_sampler'):
|
| 71 |
+
_ss.create_sampler = lambda *a, **k: None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
# Repo root on sys.path
|
| 76 |
+
# ---------------------------------------------------------------------------
|
| 77 |
+
|
| 78 |
+
from pathlib import Path
|
| 79 |
+
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 80 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 81 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Conditional imports (only when torch is available)
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
if _TORCH_AVAILABLE:
|
| 89 |
+
import torch
|
| 90 |
+
from lib_neutral_prompt import global_state, neutral_prompt_parser
|
| 91 |
+
from lib_neutral_prompt.cfg_denoiser_hijack import _on_cfg_denoiser
|
| 92 |
+
import lib_neutral_prompt.affine_transform as _at_mod
|
| 93 |
+
|
| 94 |
+
# -----------------------------------------------------------------------
|
| 95 |
+
# Helpers
|
| 96 |
+
# -----------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def _leaf(text, angle=None):
|
| 99 |
+
"""LeafPrompt, optionally with a ROTATE affine (angle in turns)."""
|
| 100 |
+
transform = None
|
| 101 |
+
if angle is not None:
|
| 102 |
+
c = math.cos(angle * 2 * math.pi)
|
| 103 |
+
s = math.sin(angle * 2 * math.pi)
|
| 104 |
+
transform = torch.tensor([[c, -s, 0.0], [s, c, 0.0]])
|
| 105 |
+
return neutral_prompt_parser.LeafPrompt(1.0, None, transform, text)
|
| 106 |
+
|
| 107 |
+
def _comp(*children):
|
| 108 |
+
return neutral_prompt_parser.CompositePrompt(1.0, None, None, list(children))
|
| 109 |
+
|
| 110 |
+
@dataclasses.dataclass
|
| 111 |
+
class _FakeParams:
|
| 112 |
+
x: torch.Tensor
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# Test class
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
@unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping pipeline tests')
|
| 120 |
+
class TestOnCfgDenoiserSmokeTest(unittest.TestCase):
|
| 121 |
+
|
| 122 |
+
def setUp(self):
|
| 123 |
+
global_state.is_enabled = True
|
| 124 |
+
global_state.prompt_exprs = []
|
| 125 |
+
global_state.batch_cond_indices = []
|
| 126 |
+
|
| 127 |
+
def tearDown(self):
|
| 128 |
+
global_state.is_enabled = False
|
| 129 |
+
global_state.prompt_exprs = []
|
| 130 |
+
global_state.batch_cond_indices = []
|
| 131 |
+
|
| 132 |
+
# 1. disabled → x untouched
|
| 133 |
+
def test_disabled_leaves_x_unchanged(self):
|
| 134 |
+
global_state.is_enabled = False
|
| 135 |
+
x = torch.zeros(2, 4, 8, 8)
|
| 136 |
+
x[0] = 1.0
|
| 137 |
+
orig = x.clone()
|
| 138 |
+
params = _FakeParams(x=x.clone())
|
| 139 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.125))]
|
| 140 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 141 |
+
_on_cfg_denoiser(params)
|
| 142 |
+
self.assertTrue(torch.equal(params.x, orig), "Disabled hook must not modify x")
|
| 143 |
+
|
| 144 |
+
# 2. identity transform → x numerically unchanged
|
| 145 |
+
def test_identity_transform_no_change(self):
|
| 146 |
+
x = torch.randn(2, 4, 8, 8)
|
| 147 |
+
orig = x.clone()
|
| 148 |
+
params = _FakeParams(x=x.clone())
|
| 149 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0))]
|
| 150 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 151 |
+
_on_cfg_denoiser(params)
|
| 152 |
+
self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
|
| 153 |
+
"Identity affine must not change x")
|
| 154 |
+
|
| 155 |
+
# 3. non-identity → apply_affine_transform is called (spy)
|
| 156 |
+
def test_nonidentity_calls_apply_affine_transform(self):
|
| 157 |
+
x = torch.zeros(2, 4, 8, 8)
|
| 158 |
+
x[1] = torch.randn(4, 8, 8)
|
| 159 |
+
params = _FakeParams(x=x.clone())
|
| 160 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.25))]
|
| 161 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 162 |
+
|
| 163 |
+
_real = _at_mod.apply_affine_transform
|
| 164 |
+
calls = []
|
| 165 |
+
def _spy(tensor, affine, mode='bilinear'):
|
| 166 |
+
calls.append(True)
|
| 167 |
+
return _real(tensor, affine, mode)
|
| 168 |
+
|
| 169 |
+
_at_mod.apply_affine_transform = _spy
|
| 170 |
+
try:
|
| 171 |
+
_on_cfg_denoiser(params)
|
| 172 |
+
finally:
|
| 173 |
+
_at_mod.apply_affine_transform = _real
|
| 174 |
+
|
| 175 |
+
self.assertGreater(len(calls), 0,
|
| 176 |
+
"apply_affine_transform must be called for non-identity transform")
|
| 177 |
+
|
| 178 |
+
# 4. no affine on any child → x unchanged
|
| 179 |
+
def test_no_affine_no_change(self):
|
| 180 |
+
x = torch.randn(2, 4, 8, 8)
|
| 181 |
+
orig = x.clone()
|
| 182 |
+
params = _FakeParams(x=x.clone())
|
| 183 |
+
global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp"))]
|
| 184 |
+
global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
|
| 185 |
+
_on_cfg_denoiser(params)
|
| 186 |
+
self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
|
| 187 |
+
"No affine → x must not change")
|
| 188 |
+
|
| 189 |
+
# 5. empty state → no crash
|
| 190 |
+
def test_empty_state_no_crash(self):
|
| 191 |
+
params = _FakeParams(x=torch.randn(1, 4, 8, 8))
|
| 192 |
+
global_state.prompt_exprs = []
|
| 193 |
+
global_state.batch_cond_indices = []
|
| 194 |
+
try:
|
| 195 |
+
_on_cfg_denoiser(params)
|
| 196 |
+
except Exception as e:
|
| 197 |
+
self.fail(f"_on_cfg_denoiser raised with empty state: {e}")
|
| 198 |
+
|
| 199 |
+
# 6. two batch entries → no crash
|
| 200 |
+
def test_batch_multiple_prompts_no_crash(self):
|
| 201 |
+
params = _FakeParams(x=torch.zeros(4, 4, 8, 8))
|
| 202 |
+
global_state.prompt_exprs = [
|
| 203 |
+
_comp(_leaf("base_a"), _leaf("perp_a", 0.25)),
|
| 204 |
+
_comp(_leaf("base_b"), _leaf("perp_b", 0.125)),
|
| 205 |
+
]
|
| 206 |
+
global_state.batch_cond_indices = [
|
| 207 |
+
[(0, 1.0), (1, 1.0)],
|
| 208 |
+
[(2, 1.0), (3, 1.0)],
|
| 209 |
+
]
|
| 210 |
+
try:
|
| 211 |
+
_on_cfg_denoiser(params)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
self.fail(f"Multi-prompt batch raised: {e}")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == '__main__':
|
| 217 |
+
unittest.main()
|
neutral_prompt_patched/test/perp_parser/test_basic_parser.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import pathlib
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
| 5 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestPromptParser(unittest.TestCase):
|
| 9 |
+
def setUp(self):
|
| 10 |
+
self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
|
| 11 |
+
self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
|
| 12 |
+
self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
|
| 13 |
+
self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
|
| 14 |
+
self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
|
| 15 |
+
self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
|
| 16 |
+
self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
|
| 17 |
+
|
| 18 |
+
def test_simple_prompt_child_count(self):
|
| 19 |
+
self.assertEqual(len(self.simple_prompt.children), 1)
|
| 20 |
+
|
| 21 |
+
def test_simple_prompt_child_weight(self):
|
| 22 |
+
self.assertEqual(self.simple_prompt.children[0].weight, 1.0)
|
| 23 |
+
|
| 24 |
+
def test_simple_prompt_child_prompt(self):
|
| 25 |
+
self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
|
| 26 |
+
|
| 27 |
+
def test_square_weight_prompt(self):
|
| 28 |
+
prompt = "a [b c d e : f g h :1.5]"
|
| 29 |
+
parsed = neutral_prompt_parser.parse_root(prompt)
|
| 30 |
+
self.assertEqual(parsed.children[0].prompt, prompt)
|
| 31 |
+
|
| 32 |
+
composed_prompt = f"{prompt} AND_PERP other prompt"
|
| 33 |
+
parsed = neutral_prompt_parser.parse_root(composed_prompt)
|
| 34 |
+
self.assertEqual(parsed.children[0].prompt, prompt)
|
| 35 |
+
|
| 36 |
+
def test_and_prompt_child_count(self):
|
| 37 |
+
self.assertEqual(len(self.and_prompt.children), 2)
|
| 38 |
+
|
| 39 |
+
def test_and_prompt_child_weights_and_prompts(self):
|
| 40 |
+
self.assertEqual(self.and_prompt.children[0].weight, 1.0)
|
| 41 |
+
self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
|
| 42 |
+
self.assertEqual(self.and_prompt.children[1].weight, 2.0)
|
| 43 |
+
self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")
|
| 44 |
+
|
| 45 |
+
def test_and_perp_prompt_child_count(self):
|
| 46 |
+
self.assertEqual(len(self.and_perp_prompt.children), 2)
|
| 47 |
+
|
| 48 |
+
def test_and_perp_prompt_child_types(self):
|
| 49 |
+
self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 50 |
+
self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
|
| 51 |
+
|
| 52 |
+
def test_and_perp_prompt_nested_child(self):
|
| 53 |
+
nested_child = self.and_perp_prompt.children[1]
|
| 54 |
+
self.assertEqual(nested_child.weight, 2.0)
|
| 55 |
+
self.assertEqual(nested_child.prompt.strip(), "goodbye")
|
| 56 |
+
|
| 57 |
+
def test_nested_and_perp_prompt_child_count(self):
|
| 58 |
+
self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
|
| 59 |
+
|
| 60 |
+
def test_nested_and_perp_prompt_child_types(self):
|
| 61 |
+
self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 62 |
+
self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
|
| 63 |
+
|
| 64 |
+
def test_nested_and_perp_prompt_nested_child_types(self):
|
| 65 |
+
nested_child = self.nested_and_perp_prompt.children[1].children[0]
|
| 66 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 67 |
+
nested_child = self.nested_and_perp_prompt.children[1].children[1]
|
| 68 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 69 |
+
|
| 70 |
+
def test_nested_and_perp_prompt_nested_child(self):
|
| 71 |
+
nested_child = self.nested_and_perp_prompt.children[1].children[1]
|
| 72 |
+
self.assertEqual(nested_child.weight, 3.0)
|
| 73 |
+
self.assertEqual(nested_child.prompt.strip(), "welcome")
|
| 74 |
+
|
| 75 |
+
def test_invalid_weight_child_count(self):
|
| 76 |
+
self.assertEqual(len(self.invalid_weight.children), 1)
|
| 77 |
+
|
| 78 |
+
def test_invalid_weight_child_weight(self):
|
| 79 |
+
self.assertEqual(self.invalid_weight.children[0].weight, 1.0)
|
| 80 |
+
|
| 81 |
+
def test_invalid_weight_child_prompt(self):
|
| 82 |
+
self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")
|
| 83 |
+
|
| 84 |
+
def test_and_salt_prompt_child_count(self):
|
| 85 |
+
self.assertEqual(len(self.and_salt_prompt.children), 2)
|
| 86 |
+
|
| 87 |
+
def test_and_salt_prompt_child_types(self):
|
| 88 |
+
self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 89 |
+
self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
|
| 90 |
+
|
| 91 |
+
def test_and_salt_prompt_nested_child(self):
|
| 92 |
+
nested_child = self.and_salt_prompt.children[1]
|
| 93 |
+
self.assertEqual(nested_child.weight, 2.0)
|
| 94 |
+
self.assertEqual(nested_child.prompt.strip(), "goodbye")
|
| 95 |
+
|
| 96 |
+
def test_nested_and_salt_prompt_child_count(self):
|
| 97 |
+
self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
|
| 98 |
+
|
| 99 |
+
def test_nested_and_salt_prompt_child_types(self):
|
| 100 |
+
self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
|
| 101 |
+
self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
|
| 102 |
+
|
| 103 |
+
def test_nested_and_salt_prompt_nested_child_types(self):
|
| 104 |
+
nested_child = self.nested_and_salt_prompt.children[1].children[0]
|
| 105 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 106 |
+
nested_child = self.nested_and_salt_prompt.children[1].children[1]
|
| 107 |
+
self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
|
| 108 |
+
|
| 109 |
+
def test_nested_and_salt_prompt_nested_child(self):
|
| 110 |
+
nested_child = self.nested_and_salt_prompt.children[1].children[1]
|
| 111 |
+
self.assertEqual(nested_child.weight, 3.0)
|
| 112 |
+
self.assertEqual(nested_child.prompt.strip(), "welcome")
|
| 113 |
+
|
| 114 |
+
def test_start_with_prompt_editing(self):
|
| 115 |
+
prompt = "[(long shot:1.2):0.1] detail.."
|
| 116 |
+
res = neutral_prompt_parser.parse_root(prompt)
|
| 117 |
+
self.assertEqual(res.children[0].weight, 1.0)
|
| 118 |
+
self.assertEqual(res.children[0].prompt, prompt)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == '__main__':
|
| 122 |
+
unittest.main()
|
neutral_prompt_patched/test/perp_parser/test_malicious_parser.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import pathlib
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
| 5 |
+
from lib_neutral_prompt import neutral_prompt_parser
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestMaliciousPromptParser(unittest.TestCase):
|
| 9 |
+
def setUp(self):
|
| 10 |
+
self.parser = neutral_prompt_parser
|
| 11 |
+
|
| 12 |
+
def test_empty(self):
|
| 13 |
+
result = self.parser.parse_root("")
|
| 14 |
+
self.assertEqual(result.children[0].prompt, "")
|
| 15 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 16 |
+
|
| 17 |
+
def test_zero_weight(self):
|
| 18 |
+
result = self.parser.parse_root("hello :0.0")
|
| 19 |
+
self.assertEqual(result.children[0].weight, 0.0)
|
| 20 |
+
|
| 21 |
+
def test_mixed_positive_and_negative_weights(self):
|
| 22 |
+
result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
|
| 23 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 24 |
+
self.assertEqual(result.children[1].weight, -2.0)
|
| 25 |
+
|
| 26 |
+
def test_debalanced_square_brackets(self):
|
| 27 |
+
prompt = "a [ b " * 100
|
| 28 |
+
result = self.parser.parse_root(prompt)
|
| 29 |
+
self.assertEqual(result.children[0].prompt, prompt)
|
| 30 |
+
|
| 31 |
+
prompt = "a ] b " * 100
|
| 32 |
+
result = self.parser.parse_root(prompt)
|
| 33 |
+
self.assertEqual(result.children[0].prompt, prompt)
|
| 34 |
+
|
| 35 |
+
repeats = 10
|
| 36 |
+
prompt = "a [ [ b AND c ] " * repeats
|
| 37 |
+
result = self.parser.parse_root(prompt)
|
| 38 |
+
self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
|
| 39 |
+
|
| 40 |
+
repeats = 10
|
| 41 |
+
prompt = "a [ b AND c ] ] " * repeats
|
| 42 |
+
result = self.parser.parse_root(prompt)
|
| 43 |
+
self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
|
| 44 |
+
|
| 45 |
+
def test_erroneous_syntax(self):
|
| 46 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
|
| 47 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 48 |
+
self.assertEqual(result.children[1].prompt, "[goodbye ")
|
| 49 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 50 |
+
|
| 51 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
|
| 52 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 53 |
+
self.assertEqual(result.children[1].prompt, " goodbye ")
|
| 54 |
+
|
| 55 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
|
| 56 |
+
self.assertEqual(result.children[1].prompt, " goodbye]")
|
| 57 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 58 |
+
|
| 59 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
|
| 60 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 61 |
+
self.assertEqual(result.children[1].prompt, " a [ goodbye ")
|
| 62 |
+
|
| 63 |
+
result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
|
| 64 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 65 |
+
self.assertEqual(result.children[2].prompt, " goodbye ")
|
| 66 |
+
|
| 67 |
+
def test_huge_number_of_prompt_parts(self):
|
| 68 |
+
result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
|
| 69 |
+
self.assertEqual(len(result.children), 10**4)
|
| 70 |
+
|
| 71 |
+
def test_prompt_ending_with_weight(self):
|
| 72 |
+
result = self.parser.parse_root("hello :1.0 AND :2.0")
|
| 73 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 74 |
+
self.assertEqual(result.children[1].prompt, "")
|
| 75 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 76 |
+
|
| 77 |
+
def test_huge_input_string(self):
|
| 78 |
+
big_string = "hello :1.0 AND " * 10**4
|
| 79 |
+
result = self.parser.parse_root(big_string)
|
| 80 |
+
self.assertEqual(len(result.children), 10**4 + 1)
|
| 81 |
+
|
| 82 |
+
def test_deeply_nested_prompt(self):
|
| 83 |
+
deeply_nested_prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
|
| 84 |
+
result = self.parser.parse_root(deeply_nested_prompt)
|
| 85 |
+
self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
|
| 86 |
+
|
| 87 |
+
def test_complex_nested_prompts(self):
|
| 88 |
+
complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
|
| 89 |
+
result = self.parser.parse_root(complex_prompt)
|
| 90 |
+
self.assertEqual(result.children[0].weight, 1.0)
|
| 91 |
+
self.assertEqual(result.children[1].weight, 2.0)
|
| 92 |
+
self.assertEqual(result.children[2].children[0].weight, 3.0)
|
| 93 |
+
self.assertEqual(result.children[2].children[1].weight, 4.0)
|
| 94 |
+
self.assertEqual(result.children[2].children[2].weight, 5.0)
|
| 95 |
+
|
| 96 |
+
def test_string_with_random_characters(self):
|
| 97 |
+
random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
|
| 98 |
+
try:
|
| 99 |
+
self.parser.parse_root(random_chars)
|
| 100 |
+
except Exception:
|
| 101 |
+
self.fail("parse_root couldn't handle a string with random characters.")
|
| 102 |
+
|
| 103 |
+
def test_string_with_unexpected_symbols(self):
|
| 104 |
+
unexpected_symbols = "hello :1.0 AND $%^&*()goodbye :2.0"
|
| 105 |
+
try:
|
| 106 |
+
self.parser.parse_root(unexpected_symbols)
|
| 107 |
+
except Exception:
|
| 108 |
+
self.fail("parse_root couldn't handle a string with unexpected symbols.")
|
| 109 |
+
|
| 110 |
+
def test_string_with_unconventional_structure(self):
|
| 111 |
+
unconventional_structure = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
|
| 112 |
+
try:
|
| 113 |
+
self.parser.parse_root(unconventional_structure)
|
| 114 |
+
except Exception:
|
| 115 |
+
self.fail("parse_root couldn't handle a string with unconventional structure.")
|
| 116 |
+
|
| 117 |
+
def test_string_with_mixed_alphabets_and_numbers(self):
|
| 118 |
+
mixed_alphabets_and_numbers = "123hello :1.0 AND goodbye456 :2.0"
|
| 119 |
+
try:
|
| 120 |
+
self.parser.parse_root(mixed_alphabets_and_numbers)
|
| 121 |
+
except Exception:
|
| 122 |
+
self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")
|
| 123 |
+
|
| 124 |
+
def test_string_with_nested_brackets(self):
|
| 125 |
+
nested_brackets = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
|
| 126 |
+
try:
|
| 127 |
+
self.parser.parse_root(nested_brackets)
|
| 128 |
+
except Exception:
|
| 129 |
+
self.fail("parse_root couldn't handle a string with nested brackets.")
|
| 130 |
+
|
| 131 |
+
def test_unmatched_opening_braces(self):
|
| 132 |
+
unmatched_opening_braces = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
|
| 133 |
+
try:
|
| 134 |
+
self.parser.parse_root(unmatched_opening_braces)
|
| 135 |
+
except Exception:
|
| 136 |
+
self.fail("parse_root couldn't handle a string with unmatched opening braces.")
|
| 137 |
+
|
| 138 |
+
def test_unmatched_closing_braces(self):
|
| 139 |
+
unmatched_closing_braces = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
|
| 140 |
+
try:
|
| 141 |
+
self.parser.parse_root(unmatched_closing_braces)
|
| 142 |
+
except Exception:
|
| 143 |
+
self.fail("parse_root couldn't handle a string with unmatched closing braces.")
|
| 144 |
+
|
| 145 |
+
def test_repeating_colons(self):
|
| 146 |
+
repeating_colons = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
|
| 147 |
+
try:
|
| 148 |
+
self.parser.parse_root(repeating_colons)
|
| 149 |
+
except Exception:
|
| 150 |
+
self.fail("parse_root couldn't handle a string with repeating colons.")
|
| 151 |
+
|
| 152 |
+
def test_excessive_whitespace(self):
|
| 153 |
+
excessive_whitespace = "hello :1.0 AND_PERP goodbye :2.0"
|
| 154 |
+
try:
|
| 155 |
+
self.parser.parse_root(excessive_whitespace)
|
| 156 |
+
except Exception:
|
| 157 |
+
self.fail("parse_root couldn't handle a string with excessive whitespace.")
|
| 158 |
+
|
| 159 |
+
def test_repeating_AND_keyword(self):
|
| 160 |
+
repeating_AND_keyword = "hello :1.0 AND AND AND AND AND goodbye :2.0"
|
| 161 |
+
try:
|
| 162 |
+
self.parser.parse_root(repeating_AND_keyword)
|
| 163 |
+
except Exception:
|
| 164 |
+
self.fail("parse_root couldn't handle a string with repeating AND keyword.")
|
| 165 |
+
|
| 166 |
+
def test_repeating_AND_PERP_keyword(self):
|
| 167 |
+
repeating_AND_PERP_keyword = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
|
| 168 |
+
try:
|
| 169 |
+
self.parser.parse_root(repeating_AND_PERP_keyword)
|
| 170 |
+
except Exception:
|
| 171 |
+
self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")
|
| 172 |
+
|
| 173 |
+
def test_square_weight_prompt(self):
|
| 174 |
+
prompt = "AND_PERP [weighted] you thought it was the end"
|
| 175 |
+
try:
|
| 176 |
+
self.parser.parse_root(prompt)
|
| 177 |
+
except Exception:
|
| 178 |
+
self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
if __name__ == '__main__':
|
| 182 |
+
unittest.main()
|
prompt-fusion-extension-main/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/venv/
|
| 2 |
+
/.idea/
|
| 3 |
+
__pycache__/
|
prompt-fusion-extension-main/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2003
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
prompt-fusion-extension-main/lib_prompt_fusion/ast_nodes.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from lib_prompt_fusion import interpolation_functions
|
| 3 |
+
from lib_prompt_fusion.t_scaler import scale_t
|
| 4 |
+
from lib_prompt_fusion import interpolation_tensor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ListExpression:
|
| 8 |
+
def __init__(self, expressions):
|
| 9 |
+
self.__expressions = expressions
|
| 10 |
+
|
| 11 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 12 |
+
if not self.__expressions:
|
| 13 |
+
return
|
| 14 |
+
|
| 15 |
+
def expr_extend_tensor(expr):
|
| 16 |
+
expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 17 |
+
|
| 18 |
+
expr_extend_tensor(self.__expressions[0])
|
| 19 |
+
for expression in self.__expressions[1:]:
|
| 20 |
+
tensor_builder.append(' ')
|
| 21 |
+
expr_extend_tensor(expression)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class InterpolationExpression:
|
| 25 |
+
@staticmethod
|
| 26 |
+
def create(exprs, steps, function_name):
|
| 27 |
+
if function_name == "mean":
|
| 28 |
+
return AverageExpression(exprs, steps)
|
| 29 |
+
|
| 30 |
+
max_len = min(len(exprs), len(steps))
|
| 31 |
+
exprs = exprs[:max_len]
|
| 32 |
+
steps = steps[:max_len]
|
| 33 |
+
|
| 34 |
+
return InterpolationExpression(exprs, steps, function_name)
|
| 35 |
+
|
| 36 |
+
def __init__(self, expressions, steps, function_name=None):
|
| 37 |
+
assert len(expressions) >= 2
|
| 38 |
+
assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
|
| 39 |
+
self.__expressions = expressions
|
| 40 |
+
self.__steps = steps
|
| 41 |
+
self.__function_name = function_name if function_name is not None else 'linear'
|
| 42 |
+
|
| 43 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 44 |
+
def tensor_updater(expr):
|
| 45 |
+
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 46 |
+
|
| 47 |
+
tensor_builder.extrude(
|
| 48 |
+
[tensor_updater(expr) for expr in self.__expressions],
|
| 49 |
+
self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
|
| 50 |
+
|
| 51 |
+
def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 52 |
+
steps = list(self.__steps)
|
| 53 |
+
if steps[0] is None:
|
| 54 |
+
steps[0] = LiftExpression(str(steps_range[0] - 1))
|
| 55 |
+
if steps[-1] is None:
|
| 56 |
+
steps[-1] = LiftExpression(str(steps_range[1] - 1))
|
| 57 |
+
|
| 58 |
+
for i, step in enumerate(steps):
|
| 59 |
+
if step is None:
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 63 |
+
|
| 64 |
+
if use_old_scheduling and 0 < step < 1:
|
| 65 |
+
step *= total_steps
|
| 66 |
+
elif not use_old_scheduling and isinstance(step, float):
|
| 67 |
+
step = (step - int(is_hires)) * total_steps
|
| 68 |
+
else:
|
| 69 |
+
step += 1
|
| 70 |
+
|
| 71 |
+
steps[i] = int(step)
|
| 72 |
+
|
| 73 |
+
i = 1
|
| 74 |
+
while i < len(steps):
|
| 75 |
+
none_len = 0
|
| 76 |
+
while steps[i + none_len] is None:
|
| 77 |
+
none_len += 1
|
| 78 |
+
|
| 79 |
+
min_step, max_step = steps[i - 1], steps[i + none_len]
|
| 80 |
+
|
| 81 |
+
for j in range(none_len):
|
| 82 |
+
steps[i + j] = min_step + (max_step - min_step) * (j + 1) / (none_len + 1)
|
| 83 |
+
|
| 84 |
+
i += 1 + none_len
|
| 85 |
+
|
| 86 |
+
interpolation_function = {
|
| 87 |
+
'linear': interpolation_functions.compute_linear,
|
| 88 |
+
'bezier': interpolation_functions.compute_bezier,
|
| 89 |
+
'catmull': interpolation_functions.compute_catmull,
|
| 90 |
+
}[self.__function_name]
|
| 91 |
+
|
| 92 |
+
def steps_scale_t(conds, params: interpolation_tensor.InterpolationParams):
|
| 93 |
+
scaled_t = (params.t * total_steps - steps[0]) / max(1, steps[-1] - steps[0])
|
| 94 |
+
scaled_t = scale_t(scaled_t, steps)
|
| 95 |
+
|
| 96 |
+
new_params = interpolation_tensor.InterpolationParams(scaled_t, *params[1:])
|
| 97 |
+
return interpolation_function(conds, new_params)
|
| 98 |
+
|
| 99 |
+
return steps_scale_t
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class AverageExpression:
|
| 103 |
+
def __init__(self, expressions, weights):
|
| 104 |
+
if len(expressions) < len(weights):
|
| 105 |
+
raise ValueError
|
| 106 |
+
|
| 107 |
+
self.__expressions = expressions
|
| 108 |
+
self.__weights = weights
|
| 109 |
+
|
| 110 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 111 |
+
def tensor_updater(expr):
|
| 112 |
+
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 113 |
+
|
| 114 |
+
tensor_builder.extrude(
|
| 115 |
+
[tensor_updater(expr) for expr in self.__expressions],
|
| 116 |
+
self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
|
| 117 |
+
|
| 118 |
+
def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 119 |
+
weights = [
|
| 120 |
+
_eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
|
| 121 |
+
for weight in self.__weights
|
| 122 |
+
]
|
| 123 |
+
explicit_weights = [weight for weight in weights if weight is not None]
|
| 124 |
+
weights = [
|
| 125 |
+
weight / sum(explicit_weights) * len(explicit_weights) / len(self.__expressions)
|
| 126 |
+
if weight is not None
|
| 127 |
+
else 1 / len(self.__expressions)
|
| 128 |
+
for weight in weights
|
| 129 |
+
]
|
| 130 |
+
weights.extend(1 / len(self.__expressions) for _ in range(len(self.__expressions) - len(weights)))
|
| 131 |
+
|
| 132 |
+
def interpolation_function(conds, _params):
|
| 133 |
+
total = None
|
| 134 |
+
for cond, weight in zip(conds, weights):
|
| 135 |
+
cond *= weight
|
| 136 |
+
if total is None:
|
| 137 |
+
total = cond
|
| 138 |
+
else:
|
| 139 |
+
total += cond
|
| 140 |
+
|
| 141 |
+
return total
|
| 142 |
+
|
| 143 |
+
return interpolation_function
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class AlternationExpression:
|
| 147 |
+
def __init__(self, expressions, speed):
|
| 148 |
+
self.__expressions = expressions
|
| 149 |
+
self.__speed = speed
|
| 150 |
+
|
| 151 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 152 |
+
if self.__speed is None:
|
| 153 |
+
speed = None
|
| 154 |
+
else:
|
| 155 |
+
speed = _eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 156 |
+
|
| 157 |
+
if speed is None:
|
| 158 |
+
tensor_builder.append('[')
|
| 159 |
+
for expr_i, expr in enumerate(self.__expressions):
|
| 160 |
+
if expr_i >= 1:
|
| 161 |
+
tensor_builder.append('|')
|
| 162 |
+
expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 163 |
+
tensor_builder.append(']')
|
| 164 |
+
return
|
| 165 |
+
|
| 166 |
+
def tensor_updater(expr):
|
| 167 |
+
return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 168 |
+
|
| 169 |
+
exprs = self.__expressions + [self.__expressions[0]]
|
| 170 |
+
|
| 171 |
+
tensor_builder.extrude(
|
| 172 |
+
[tensor_updater(expr) for expr in exprs],
|
| 173 |
+
self.get_interpolation_function(speed, exprs, steps_range, total_steps))
|
| 174 |
+
|
| 175 |
+
def get_interpolation_function(self, speed, exprs, steps_range, total_steps):
|
| 176 |
+
def compute_wrap(control_points, params: interpolation_tensor.InterpolationParams):
|
| 177 |
+
wrapped_t = math.fmod((params.t * total_steps - steps_range[0]) / (len(exprs) - 1) * speed, 1.0)
|
| 178 |
+
if wrapped_t < 0:
|
| 179 |
+
wrapped_t = wrapped_t + 1
|
| 180 |
+
new_params = interpolation_tensor.InterpolationParams(wrapped_t, *params[1:])
|
| 181 |
+
return interpolation_functions.compute_linear(control_points, new_params)
|
| 182 |
+
|
| 183 |
+
return compute_wrap
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class EditingExpression:
|
| 187 |
+
def __init__(self, expressions, step):
|
| 188 |
+
assert 1 <= len(expressions) <= 2
|
| 189 |
+
self.__expressions = expressions
|
| 190 |
+
self.__step = step
|
| 191 |
+
|
| 192 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 193 |
+
if self.__step is None:
|
| 194 |
+
tensor_builder.append('[')
|
| 195 |
+
for expr_i, expr in enumerate(self.__expressions):
|
| 196 |
+
expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 197 |
+
tensor_builder.append(':')
|
| 198 |
+
tensor_builder.append(']')
|
| 199 |
+
return
|
| 200 |
+
|
| 201 |
+
step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 202 |
+
step_int = step
|
| 203 |
+
if use_old_scheduling and 0 < step < 1:
|
| 204 |
+
step_int *= total_steps
|
| 205 |
+
elif not use_old_scheduling and isinstance(step, float):
|
| 206 |
+
step_int = (step_int - int(is_hires)) * total_steps
|
| 207 |
+
else:
|
| 208 |
+
step_int += 1
|
| 209 |
+
|
| 210 |
+
step_int = int(step_int)
|
| 211 |
+
|
| 212 |
+
tensor_builder.append('[')
|
| 213 |
+
for expr_i, expr in enumerate(self.__expressions):
|
| 214 |
+
expr_steps_range = (steps_range[0], step_int) if expr_i == 0 and len(self.__expressions) >= 2 else (step_int, steps_range[1])
|
| 215 |
+
expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 216 |
+
tensor_builder.append(':')
|
| 217 |
+
|
| 218 |
+
tensor_builder.append(f'{step}]')
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class WeightedExpression:
|
| 222 |
+
def __init__(self, nested, weight=None, positive=True):
|
| 223 |
+
self.__nested = nested
|
| 224 |
+
if not positive:
|
| 225 |
+
assert weight is None
|
| 226 |
+
|
| 227 |
+
self.__weight = weight
|
| 228 |
+
self.__positive = positive
|
| 229 |
+
|
| 230 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 231 |
+
open_bracket, close_bracket = ('(', ')') if self.__positive else ('[', ']')
|
| 232 |
+
tensor_builder.append(open_bracket)
|
| 233 |
+
self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 234 |
+
|
| 235 |
+
if self.__weight is not None:
|
| 236 |
+
tensor_builder.append(':')
|
| 237 |
+
self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 238 |
+
|
| 239 |
+
tensor_builder.append(close_bracket)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class WeightInterpolationExpression:
|
| 243 |
+
def __init__(self, nested, weight_begin, weight_end):
|
| 244 |
+
self.__nested = nested
|
| 245 |
+
self.__weight_begin = weight_begin if weight_begin is not None else LiftExpression(str(1.))
|
| 246 |
+
self.__weight_end = weight_end if weight_end is not None else LiftExpression(str(1.))
|
| 247 |
+
|
| 248 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 249 |
+
steps_range_size = steps_range[1] - steps_range[0]
|
| 250 |
+
|
| 251 |
+
weight_begin = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 252 |
+
weight_end = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 253 |
+
|
| 254 |
+
for i in range(steps_range_size):
|
| 255 |
+
step = i + steps_range[0]
|
| 256 |
+
|
| 257 |
+
weight = weight_begin + (weight_end - weight_begin) * (i / max(steps_range_size - 1, 1))
|
| 258 |
+
weight_step_expr = WeightedExpression(self.__nested, LiftExpression(str(weight)))
|
| 259 |
+
if step > steps_range[0]:
|
| 260 |
+
weight_step_expr = EditingExpression([weight_step_expr], LiftExpression(str(step - 1)))
|
| 261 |
+
if step + 1 < steps_range[1]:
|
| 262 |
+
weight_step_expr = EditingExpression([weight_step_expr, ListExpression([])], LiftExpression(str(step)))
|
| 263 |
+
|
| 264 |
+
weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class DeclarationExpression:
|
| 268 |
+
def __init__(self, symbol, parameters, value, target):
|
| 269 |
+
self.__symbol = symbol
|
| 270 |
+
self.__value = value
|
| 271 |
+
self.__target = target
|
| 272 |
+
self.__parameters = parameters
|
| 273 |
+
|
| 274 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 275 |
+
updated_context = dict(context)
|
| 276 |
+
updated_context[self.__symbol] = (self.__value, self.__parameters)
|
| 277 |
+
self.__target.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class SubstitutionExpression:
|
| 281 |
+
def __init__(self, symbol, arguments):
|
| 282 |
+
self.__symbol = symbol
|
| 283 |
+
self.__arguments = arguments
|
| 284 |
+
|
| 285 |
+
def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 286 |
+
updated_context = dict(context)
|
| 287 |
+
nested, parameters = context[self.__symbol]
|
| 288 |
+
for argument, parameter in zip(self.__arguments, parameters):
|
| 289 |
+
updated_context[parameter] = argument, []
|
| 290 |
+
nested.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class LiftExpression:
|
| 294 |
+
def __init__(self, value):
|
| 295 |
+
self.__value = value
|
| 296 |
+
|
| 297 |
+
def extend_tensor(self, tensor_builder, *_args, **_kwargs):
|
| 298 |
+
tensor_builder.append(self.__value)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
|
| 302 |
+
mock_database = ['']
|
| 303 |
+
expression.extend_tensor(interpolation_tensor.InterpolationTensorBuilder(prompt_database=mock_database), steps_range, total_steps, context, is_hires, use_old_scheduling)
|
| 304 |
+
try:
|
| 305 |
+
return int(mock_database[0])
|
| 306 |
+
except ValueError:
|
| 307 |
+
return float(mock_database[0])
|
prompt-fusion-extension-main/lib_prompt_fusion/empty_cond.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_prompt_fusion import interpolation_tensor
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
_empty_cond = None
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get():
|
| 8 |
+
return _empty_cond
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def init(model):
|
| 12 |
+
global _empty_cond
|
| 13 |
+
cond = model.get_learned_conditioning([''])
|
| 14 |
+
if isinstance(cond, dict):
|
| 15 |
+
cond = interpolation_tensor.DictCondWrapper({k: v[0] for k, v in cond.items()})
|
| 16 |
+
else:
|
| 17 |
+
cond = interpolation_tensor.TensorCondWrapper(cond[0])
|
| 18 |
+
|
| 19 |
+
_empty_cond = cond
|
prompt-fusion-extension-main/lib_prompt_fusion/geometries.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from lib_prompt_fusion import interpolation_tensor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
|
| 7 |
+
p0, p1 = control_points
|
| 8 |
+
p0_norm = torch.linalg.norm(p0)
|
| 9 |
+
p1_norm = torch.linalg.norm(p1)
|
| 10 |
+
|
| 11 |
+
similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
|
| 12 |
+
similarity = min(1., max(-1., float(similarity)))
|
| 13 |
+
if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
|
| 14 |
+
return linear_geometry(control_points, params)
|
| 15 |
+
|
| 16 |
+
angle = math.acos(float(similarity)) / 2
|
| 17 |
+
|
| 18 |
+
slerp_t = angle * (2 * params.t - 1)
|
| 19 |
+
slerp_t = math.tan(slerp_t) / math.tan(angle)
|
| 20 |
+
slerp_t = (slerp_t + 1) / 2
|
| 21 |
+
|
| 22 |
+
normalized_p1 = p1 / p1_norm * p0_norm
|
| 23 |
+
slerp_p = p0 + (normalized_p1 - p0) * slerp_t
|
| 24 |
+
slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)
|
| 25 |
+
|
| 26 |
+
lerp_p = linear_geometry(control_points, params)
|
| 27 |
+
return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
|
| 31 |
+
p0, p1 = control_points
|
| 32 |
+
res = p0 + (p1 - p0) * params.t
|
| 33 |
+
return res
|
prompt-fusion-extension-main/lib_prompt_fusion/global_state.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from modules import shared, prompt_parser
|
| 3 |
+
from lib_prompt_fusion import empty_cond
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
old_webui_is_negative: bool = False
|
| 7 |
+
negative_schedules: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
|
| 8 |
+
negative_schedules_hires: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_origin_cond_at(step: int, is_hires: bool = False):
|
| 12 |
+
fallback_schedules = negative_schedules_hires if is_hires else negative_schedules
|
| 13 |
+
if not fallback_schedules or not shared.opts.data.get('prompt_fusion_slerp_negative_origin', False):
|
| 14 |
+
return empty_cond.get()
|
| 15 |
+
|
| 16 |
+
for schedule in fallback_schedules:
|
| 17 |
+
if schedule.end_at_step >= step:
|
| 18 |
+
return schedule.cond
|
| 19 |
+
|
| 20 |
+
return empty_cond.get()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_slerp_scale():
|
| 24 |
+
return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_slerp_epsilon():
|
| 28 |
+
return shared.opts.data.get('prompt_fusion_slerp_epsilon', 0.0001)
|
prompt-fusion-extension-main/lib_prompt_fusion/hijacker.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ModuleHijacker:
|
| 2 |
+
def __init__(self, module):
|
| 3 |
+
self.__module = module
|
| 4 |
+
self.__original_functions = dict()
|
| 5 |
+
|
| 6 |
+
def hijack(self, attribute):
|
| 7 |
+
if attribute not in self.__original_functions:
|
| 8 |
+
self.__original_functions[attribute] = getattr(self.__module, attribute)
|
| 9 |
+
|
| 10 |
+
def decorator(function):
|
| 11 |
+
def wrapper(*args, **kwargs):
|
| 12 |
+
return function(*args, **kwargs, original_function=self.__original_functions[attribute])
|
| 13 |
+
|
| 14 |
+
setattr(self.__module, attribute, wrapper)
|
| 15 |
+
return function
|
| 16 |
+
|
| 17 |
+
return decorator
|
| 18 |
+
|
| 19 |
+
def reset_module(self):
|
| 20 |
+
for attribute, original_function in self.__original_functions.items():
|
| 21 |
+
setattr(self.__module, attribute, original_function)
|
| 22 |
+
|
| 23 |
+
self.__original_functions.clear()
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def install_or_get(module, hijacker_attribute, register_uninstall=lambda _callback: None):
|
| 27 |
+
if not hasattr(module, hijacker_attribute):
|
| 28 |
+
module_hijacker = ModuleHijacker(module)
|
| 29 |
+
setattr(module, hijacker_attribute, module_hijacker)
|
| 30 |
+
register_uninstall(lambda: delattr(module, hijacker_attribute))
|
| 31 |
+
register_uninstall(module_hijacker.reset_module)
|
| 32 |
+
return module_hijacker
|
| 33 |
+
else:
|
| 34 |
+
return getattr(module, hijacker_attribute)
|
prompt-fusion-extension-main/lib_prompt_fusion/interpolation_functions.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
from lib_prompt_fusion import interpolation_tensor, geometries
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def compute_linear(control_points, params: interpolation_tensor.InterpolationParams):
|
| 7 |
+
if len(control_points) <= 2:
|
| 8 |
+
return geometries.slerp_geometry(control_points, params)
|
| 9 |
+
else:
|
| 10 |
+
target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
|
| 11 |
+
cp0 = control_points[target_curve]
|
| 12 |
+
cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
|
| 13 |
+
new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
|
| 14 |
+
return geometries.slerp_geometry([cp0, cp1], new_params)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def compute_bezier(control_points, params: interpolation_tensor.InterpolationParams):
|
| 18 |
+
def compute_casteljau(ps, size):
|
| 19 |
+
for i in reversed(range(1, size)):
|
| 20 |
+
for j in range(i):
|
| 21 |
+
ps[j] = geometries.slerp_geometry([ps[j], ps[j+1]], params)
|
| 22 |
+
|
| 23 |
+
return ps[0]
|
| 24 |
+
|
| 25 |
+
if len(control_points) == 1:
|
| 26 |
+
return control_points[0]
|
| 27 |
+
elif len(control_points) == 2:
|
| 28 |
+
return geometries.slerp_geometry(control_points, params)
|
| 29 |
+
copied_control_points = copy.deepcopy(control_points)
|
| 30 |
+
return compute_casteljau(copied_control_points, len(copied_control_points))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def compute_catmull(control_points, params: interpolation_tensor.InterpolationParams):
|
| 34 |
+
if len(control_points) <= 2:
|
| 35 |
+
return compute_linear(control_points, params)
|
| 36 |
+
else:
|
| 37 |
+
target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
|
| 38 |
+
g0 = control_points[target_curve - 1] if target_curve > 0 else control_points[0] * 2 - control_points[1]
|
| 39 |
+
cp0 = control_points[target_curve]
|
| 40 |
+
cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
|
| 41 |
+
g1 = control_points[target_curve + 2] if target_curve + 2 < len(control_points) else cp1 * 2 - cp0
|
| 42 |
+
ip0 = cp0 + (cp1 - g0)/6
|
| 43 |
+
ip1 = cp1 + (cp0 - g1)/6
|
| 44 |
+
|
| 45 |
+
new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
|
| 46 |
+
return compute_bezier([cp0, ip0, ip1, cp1], new_params)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == '__main__':
|
| 50 |
+
import turtle as tr
|
| 51 |
+
import torch
|
| 52 |
+
size = 60
|
| 53 |
+
turtle_tool = tr.Turtle()
|
| 54 |
+
turtle_tool.speed(10)
|
| 55 |
+
turtle_tool.up()
|
| 56 |
+
|
| 57 |
+
points = torch.Tensor([[-2., -2.], [2., 2.]])
|
| 58 |
+
origin = torch.Tensor([1.5, 1.6])
|
| 59 |
+
|
| 60 |
+
def sample(slerp_scale, color):
|
| 61 |
+
for i in range(size):
|
| 62 |
+
t = i / size
|
| 63 |
+
params = interpolation_tensor.InterpolationParams(t, i, size, slerp_scale, 0.0001)
|
| 64 |
+
point = origin + compute_linear(points - origin, params)
|
| 65 |
+
try:
|
| 66 |
+
turtle_tool.goto(tuple(float(p) * 100. for p in point))
|
| 67 |
+
turtle_tool.dot(5, color)
|
| 68 |
+
print(point)
|
| 69 |
+
except ValueError:
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
sample(0, "black")
|
| 73 |
+
sample(1, "green")
|
| 74 |
+
sample(2, "blue")
|
| 75 |
+
sample(-1, "purple")
|
| 76 |
+
sample(-2, "orange")
|
| 77 |
+
|
| 78 |
+
for point in points:
|
| 79 |
+
turtle_tool.goto(tuple(float(p) * 100. for p in point))
|
| 80 |
+
turtle_tool.dot(5, "red")
|
| 81 |
+
|
| 82 |
+
turtle_tool.goto(tuple(float(p) * 100. for p in origin))
|
| 83 |
+
turtle_tool.dot(10, "red")
|
| 84 |
+
|
| 85 |
+
turtle_tool.goto(100000, 100000)
|
| 86 |
+
turtle_tool.dot()
|
| 87 |
+
tr.done()
|
prompt-fusion-extension-main/lib_prompt_fusion/interpolation_tensor.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import torch
|
| 3 |
+
from modules import prompt_parser
|
| 4 |
+
from typing import NamedTuple, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InterpolationParams(NamedTuple):
|
| 8 |
+
t: float
|
| 9 |
+
step: int
|
| 10 |
+
total_steps: int
|
| 11 |
+
slerp_scale: float
|
| 12 |
+
slerp_epsilon: float
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class InterpolationTensor:
|
| 16 |
+
def __init__(self, sub_tensors, interpolation_function):
|
| 17 |
+
self.__sub_tensors = sub_tensors
|
| 18 |
+
self.__interpolation_function = interpolation_function
|
| 19 |
+
|
| 20 |
+
def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
|
| 21 |
+
cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
|
| 22 |
+
if params.slerp_scale != 0:
|
| 23 |
+
cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
|
| 24 |
+
return cond
|
| 25 |
+
|
| 26 |
+
def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
|
| 27 |
+
if self.__interpolation_function is None:
|
| 28 |
+
return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)
|
| 29 |
+
|
| 30 |
+
control_points = [
|
| 31 |
+
sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
|
| 32 |
+
for sub_tensor in self.__sub_tensors
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
CondWrapper, control_points_values = conds_to_cp_values(control_points)
|
| 36 |
+
return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)
|
| 37 |
+
|
| 38 |
+
def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
|
| 39 |
+
schedule = None
|
| 40 |
+
for schedule in self.__sub_tensors:
|
| 41 |
+
if schedule.end_at_step >= step:
|
| 42 |
+
break
|
| 43 |
+
|
| 44 |
+
res = schedule.cond.extend_like(origin_cond, empty_cond)
|
| 45 |
+
if slerp_scale != 0:
|
| 46 |
+
res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
|
| 47 |
+
return res
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def conds_to_cp_values(conds):
|
| 51 |
+
CondWrapper = type(conds[0])
|
| 52 |
+
cp_values = [
|
| 53 |
+
cond.to_cp_values()
|
| 54 |
+
for cond in conds
|
| 55 |
+
]
|
| 56 |
+
return CondWrapper, [
|
| 57 |
+
[v[i] for v in cp_values]
|
| 58 |
+
for i in range(len(cp_values[0]))
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class InterpolationTensorBuilder:
|
| 63 |
+
def __init__(self, tensor=None, prompt_database=None, interpolation_functions=None):
|
| 64 |
+
self.__indices_tensor = tensor if tensor is not None else 0
|
| 65 |
+
self.__prompt_database = prompt_database if prompt_database is not None else ['']
|
| 66 |
+
self.__interpolation_functions = interpolation_functions if interpolation_functions is not None else []
|
| 67 |
+
|
| 68 |
+
def append(self, suffix):
|
| 69 |
+
for i in range(len(self.__prompt_database)):
|
| 70 |
+
self.__prompt_database[i] += suffix
|
| 71 |
+
|
| 72 |
+
def extrude(self, tensor_updaters, interpolation_function):
|
| 73 |
+
extruded_indices_tensor = []
|
| 74 |
+
extruded_prompt_database = []
|
| 75 |
+
extruded_interpolation_functions = []
|
| 76 |
+
|
| 77 |
+
for update_tensor in tensor_updaters:
|
| 78 |
+
nested_tensor_builder = InterpolationTensorBuilder(
|
| 79 |
+
self.__indices_tensor,
|
| 80 |
+
self.__prompt_database[:],
|
| 81 |
+
interpolation_functions=[])
|
| 82 |
+
|
| 83 |
+
update_tensor(nested_tensor_builder)
|
| 84 |
+
|
| 85 |
+
extruded_indices_tensor.append(InterpolationTensorBuilder.__offset_tensor(
|
| 86 |
+
tensor=nested_tensor_builder.__indices_tensor,
|
| 87 |
+
offset=len(extruded_prompt_database)))
|
| 88 |
+
extruded_prompt_database.extend(nested_tensor_builder.__prompt_database)
|
| 89 |
+
extruded_interpolation_functions.append(nested_tensor_builder.__interpolation_functions)
|
| 90 |
+
|
| 91 |
+
self.__indices_tensor = extruded_indices_tensor
|
| 92 |
+
self.__prompt_database[:] = extruded_prompt_database
|
| 93 |
+
self.__interpolation_functions.insert(0, (interpolation_function, extruded_interpolation_functions))
|
| 94 |
+
|
| 95 |
+
def get_prompt_database(self):
|
| 96 |
+
return self.__prompt_database
|
| 97 |
+
|
| 98 |
+
@staticmethod
|
| 99 |
+
def __offset_tensor(tensor, offset):
|
| 100 |
+
try:
|
| 101 |
+
return tensor + offset
|
| 102 |
+
|
| 103 |
+
except TypeError:
|
| 104 |
+
return [InterpolationTensorBuilder.__offset_tensor(e, offset) for e in tensor]
|
| 105 |
+
|
| 106 |
+
def build(self, conds, empty_cond):
|
| 107 |
+
max_cond_size = self.__max_cond_size(conds)
|
| 108 |
+
conds = self.__resize_uniformly(conds, max_cond_size, empty_cond)
|
| 109 |
+
return InterpolationTensorBuilder.__build_conditionings_tensor(self.__indices_tensor, self.__interpolation_functions, conds)
|
| 110 |
+
|
| 111 |
+
@staticmethod
|
| 112 |
+
def __build_conditionings_tensor(tensor, int_funcs, conds):
|
| 113 |
+
if type(tensor) is int:
|
| 114 |
+
return InterpolationTensor(conds[tensor], None)
|
| 115 |
+
else:
|
| 116 |
+
int_func, nested_int_funcs = int_funcs[0]
|
| 117 |
+
return InterpolationTensor(
|
| 118 |
+
[
|
| 119 |
+
InterpolationTensorBuilder.__build_conditionings_tensor(sub_tensor, nested_int_funcs + int_funcs[1:], conds)
|
| 120 |
+
for sub_tensor, nested_int_funcs in zip(tensor, nested_int_funcs)
|
| 121 |
+
],
|
| 122 |
+
int_func,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
def __resize_uniformly(self, conds, max_cond_size: int, empty_cond):
|
| 126 |
+
return [
|
| 127 |
+
[
|
| 128 |
+
prompt_parser.ScheduledPromptConditioning(
|
| 129 |
+
cond=schedule.cond.resize_schedule(max_cond_size, empty_cond),
|
| 130 |
+
end_at_step=schedule.end_at_step
|
| 131 |
+
)
|
| 132 |
+
for schedule in schedules
|
| 133 |
+
]
|
| 134 |
+
for schedules in conds
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
@staticmethod
|
| 138 |
+
def __max_cond_size(conds):
|
| 139 |
+
return max(schedule.cond.size(0)
|
| 140 |
+
for schedules in conds
|
| 141 |
+
for schedule in schedules)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@dataclasses.dataclass
|
| 145 |
+
class DictCondWrapper:
|
| 146 |
+
original_cond: dict
|
| 147 |
+
|
| 148 |
+
@staticmethod
|
| 149 |
+
def from_cp_values(cp_values):
|
| 150 |
+
return DictCondWrapper({
|
| 151 |
+
k: v
|
| 152 |
+
for k, v in zip(('crossattn', 'vector'), cp_values)
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
def size(self, *args, **kwargs):
|
| 156 |
+
return self.original_cond['crossattn'].size(*args, **kwargs)
|
| 157 |
+
|
| 158 |
+
def extend_like(self, that, empty):
|
| 159 |
+
missing_size = max(0, that.size(0) - self.size(0)) // 77
|
| 160 |
+
extended = DictCondWrapper(self.original_cond.copy())
|
| 161 |
+
extended.original_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty.original_cond['crossattn']] * missing_size)
|
| 162 |
+
return extended
|
| 163 |
+
|
| 164 |
+
def resize_schedule(self, target_size, empty_cond):
|
| 165 |
+
cond_missing_size = (target_size - self.size(0)) // 77
|
| 166 |
+
if cond_missing_size <= 0:
|
| 167 |
+
return self
|
| 168 |
+
|
| 169 |
+
resized_cond = self.original_cond.copy()
|
| 170 |
+
resized_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty_cond.original_cond['crossattn']] * cond_missing_size)
|
| 171 |
+
return DictCondWrapper(resized_cond)
|
| 172 |
+
|
| 173 |
+
def to_cp_values(self):
|
| 174 |
+
return list(self.original_cond.values())
|
| 175 |
+
|
| 176 |
+
def to(self, dtype: Union[dict, torch.dtype]):
|
| 177 |
+
if not isinstance(dtype, dict):
|
| 178 |
+
dtype = {
|
| 179 |
+
k: dtype
|
| 180 |
+
for k in self.original_cond.keys()
|
| 181 |
+
}
|
| 182 |
+
return DictCondWrapper({
|
| 183 |
+
k: v.to(dtype=dtype[k])
|
| 184 |
+
for k, v in self.original_cond.items()
|
| 185 |
+
})
|
| 186 |
+
|
| 187 |
+
@property
|
| 188 |
+
def dtype(self):
|
| 189 |
+
return {
|
| 190 |
+
k: v.dtype
|
| 191 |
+
for k, v in self.original_cond.items()
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
def __sub__(self, that):
|
| 195 |
+
return DictCondWrapper({
|
| 196 |
+
k: v - that.original_cond[k]
|
| 197 |
+
for k, v in self.original_cond.items()
|
| 198 |
+
})
|
| 199 |
+
|
| 200 |
+
def __add__(self, that):
|
| 201 |
+
return DictCondWrapper({
|
| 202 |
+
k: v + that.original_cond[k]
|
| 203 |
+
for k, v in self.original_cond.items()
|
| 204 |
+
})
|
| 205 |
+
|
| 206 |
+
def __eq__(self, that):
|
| 207 |
+
return all((self.original_cond[k] == that.original_cond[k]).all() for k in self.original_cond.keys())
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dataclasses.dataclass
|
| 211 |
+
class TensorCondWrapper:
|
| 212 |
+
original_cond: torch.Tensor
|
| 213 |
+
|
| 214 |
+
@staticmethod
|
| 215 |
+
def from_cp_values(cp_values):
|
| 216 |
+
return TensorCondWrapper(next(iter(cp_values)))
|
| 217 |
+
|
| 218 |
+
def size(self, *args, **kwargs):
|
| 219 |
+
return self.original_cond.size(*args, **kwargs)
|
| 220 |
+
|
| 221 |
+
def extend_like(self, that, empty):
|
| 222 |
+
missing_size = max(0, that.size(0) - self.original_cond.size(0)) // 77
|
| 223 |
+
return TensorCondWrapper(torch.concatenate([self.original_cond] + [empty.original_cond] * missing_size))
|
| 224 |
+
|
| 225 |
+
def resize_schedule(self, target_size, empty_cond):
|
| 226 |
+
cond_missing_size = (target_size - self.original_cond.size(0)) // 77
|
| 227 |
+
if cond_missing_size <= 0:
|
| 228 |
+
return self
|
| 229 |
+
|
| 230 |
+
return TensorCondWrapper(torch.concatenate([self.original_cond] + [empty_cond.original_cond] * cond_missing_size))
|
| 231 |
+
|
| 232 |
+
def to_cp_values(self):
|
| 233 |
+
return [self.original_cond]
|
| 234 |
+
|
| 235 |
+
def to(self, dtype: torch.dtype):
|
| 236 |
+
return TensorCondWrapper(self.original_cond.to(dtype=dtype))
|
| 237 |
+
|
| 238 |
+
@property
|
| 239 |
+
def dtype(self):
|
| 240 |
+
return self.original_cond.dtype
|
| 241 |
+
|
| 242 |
+
def __sub__(self, that):
|
| 243 |
+
return TensorCondWrapper(self.original_cond - that.original_cond)
|
| 244 |
+
|
| 245 |
+
def __add__(self, that):
|
| 246 |
+
return TensorCondWrapper(self.original_cond + that.original_cond)
|
| 247 |
+
|
| 248 |
+
def __eq__(self, that):
|
| 249 |
+
return (self.original_cond == that.original_cond).all()
|
prompt-fusion-extension-main/lib_prompt_fusion/prompt_parser.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import lib_prompt_fusion.ast_nodes as ast
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
ParseResult = namedtuple('ParseResult', ['prompt', 'expr'])
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_prompt(prompt):
|
| 10 |
+
prompt = prompt.lstrip()
|
| 11 |
+
prompt, list_expr = parse_list_expression(prompt, set())
|
| 12 |
+
return list_expr
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_list_expression(prompt, stoppers):
|
| 16 |
+
exprs = []
|
| 17 |
+
try:
|
| 18 |
+
while True:
|
| 19 |
+
prompt, expr = parse_expression(prompt, stoppers)
|
| 20 |
+
exprs.append(expr)
|
| 21 |
+
except ValueError:
|
| 22 |
+
return ParseResult(prompt=prompt, expr=ast.ListExpression(exprs))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_expression(prompt, stoppers):
|
| 26 |
+
for parse in _parsers():
|
| 27 |
+
try:
|
| 28 |
+
return parse(prompt, stoppers)
|
| 29 |
+
except ValueError:
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
raise ValueError
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _parsers():
|
| 36 |
+
return (
|
| 37 |
+
parse_text,
|
| 38 |
+
parse_declaration,
|
| 39 |
+
parse_substitution,
|
| 40 |
+
parse_positive_attention,
|
| 41 |
+
parse_negative_attention,
|
| 42 |
+
parse_editing,
|
| 43 |
+
parse_alternation,
|
| 44 |
+
parse_interpolation,
|
| 45 |
+
parse_unrestricted_text,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def parse_text(prompt, stoppers):
|
| 50 |
+
return parse_unrestricted_text(prompt, set_concat(stoppers, {'[', '(', '$'}))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def parse_unrestricted_text(prompt, stoppers):
|
| 54 |
+
escaped_stoppers = ''.join(re.escape(stopper) for stopper in stoppers)
|
| 55 |
+
regex = rf'(?:[^{escaped_stoppers}\\\s]|\$(?![a-zA-Z_])|\\.)+'
|
| 56 |
+
prompt, expr = parse_token(prompt, whitespace_tail_regex(regex, stoppers))
|
| 57 |
+
return ParseResult(prompt=prompt, expr=ast.LiftExpression(expr))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def parse_substitution(prompt, stoppers):
|
| 61 |
+
prompt, symbol = parse_symbol(prompt, stoppers)
|
| 62 |
+
prompt, arguments = parse_arguments(prompt, stoppers)
|
| 63 |
+
return ParseResult(prompt=prompt, expr=ast.SubstitutionExpression(symbol, arguments))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def parse_arguments(prompt, stoppers):
|
| 67 |
+
try:
|
| 68 |
+
prompt, _ = parse_open_paren(prompt, stoppers)
|
| 69 |
+
prompt, arguments = parse_inner_arguments(prompt, stoppers)
|
| 70 |
+
prompt, _ = parse_close_paren(prompt, stoppers)
|
| 71 |
+
except ValueError:
|
| 72 |
+
arguments = []
|
| 73 |
+
return ParseResult(prompt=prompt, expr=arguments)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def parse_inner_arguments(prompt, stoppers):
|
| 77 |
+
arguments = []
|
| 78 |
+
try:
|
| 79 |
+
while True:
|
| 80 |
+
prompt, arg = parse_list_expression(prompt, {',', ')'})
|
| 81 |
+
arguments.append(arg)
|
| 82 |
+
prompt, _ = parse_comma(prompt, stoppers)
|
| 83 |
+
except ValueError:
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
return ParseResult(prompt=prompt, expr=arguments)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def parse_declaration(prompt, stoppers):
|
| 90 |
+
prompt, symbol = parse_symbol(prompt, stoppers)
|
| 91 |
+
prompt, parameters = parse_parameters(prompt, stoppers)
|
| 92 |
+
prompt, _ = parse_equals(prompt, stoppers)
|
| 93 |
+
prompt, value = parse_list_expression(prompt, set_concat(stoppers, '\n'))
|
| 94 |
+
prompt, _ = parse_newline(prompt, stoppers)
|
| 95 |
+
prompt, expr = parse_list_expression(prompt, stoppers)
|
| 96 |
+
return ParseResult(prompt=prompt, expr=ast.DeclarationExpression(symbol, parameters, value, expr))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def parse_parameters(prompt, stoppers):
|
| 100 |
+
try:
|
| 101 |
+
prompt, _ = parse_open_paren(prompt, stoppers)
|
| 102 |
+
prompt, parameters = parse_inner_parameters(prompt, stoppers)
|
| 103 |
+
prompt, _ = parse_close_paren(prompt, stoppers)
|
| 104 |
+
except ValueError:
|
| 105 |
+
parameters = []
|
| 106 |
+
return ParseResult(prompt=prompt, expr=parameters)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def parse_inner_parameters(prompt, stoppers):
|
| 110 |
+
parameters = []
|
| 111 |
+
try:
|
| 112 |
+
while True:
|
| 113 |
+
prompt, param = parse_symbol(prompt, stoppers)
|
| 114 |
+
parameters.append(param)
|
| 115 |
+
prompt, _ = parse_comma(prompt, stoppers)
|
| 116 |
+
except ValueError:
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
return ParseResult(prompt=prompt, expr=parameters)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def parse_interpolation(prompt, stoppers):
|
| 123 |
+
prompt, _ = parse_open_square(prompt, stoppers)
|
| 124 |
+
prompt, exprs = parse_interpolation_exprs(prompt, stoppers)
|
| 125 |
+
prompt, steps = parse_interpolation_steps(prompt, stoppers)
|
| 126 |
+
prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
|
| 127 |
+
prompt, _ = parse_close_square(prompt, stoppers)
|
| 128 |
+
return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def parse_interpolation_exprs(prompt, stoppers):
|
| 132 |
+
exprs = []
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
while True:
|
| 136 |
+
prompt_tmp, expr = parse_list_expression(prompt, {':', ']'})
|
| 137 |
+
if parse_interpolation_function_name(prompt_tmp, stoppers).expr is not None:
|
| 138 |
+
raise ValueError
|
| 139 |
+
|
| 140 |
+
prompt, _ = parse_colon(prompt_tmp, stoppers)
|
| 141 |
+
exprs.append(expr)
|
| 142 |
+
except ValueError:
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
return ParseResult(prompt=prompt, expr=exprs)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def parse_interpolation_function_name(prompt, stoppers):
|
| 149 |
+
try:
|
| 150 |
+
prompt, _ = parse_colon(prompt, stoppers)
|
| 151 |
+
function_names = ('linear', 'catmull', 'bezier', 'mean')
|
| 152 |
+
return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
|
| 153 |
+
except ValueError:
|
| 154 |
+
return ParseResult(prompt=prompt, expr=None)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def parse_interpolation_steps(prompt, stoppers):
|
| 158 |
+
steps = []
|
| 159 |
+
|
| 160 |
+
try:
|
| 161 |
+
while True:
|
| 162 |
+
prompt, step = parse_interpolation_step(prompt, stoppers)
|
| 163 |
+
steps.append(step)
|
| 164 |
+
prompt, _ = parse_comma(prompt, stoppers)
|
| 165 |
+
except ValueError:
|
| 166 |
+
pass
|
| 167 |
+
|
| 168 |
+
return ParseResult(prompt=prompt, expr=steps)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def parse_interpolation_step(prompt, stoppers):
|
| 172 |
+
try:
|
| 173 |
+
return parse_step(prompt, stoppers)
|
| 174 |
+
except ValueError:
|
| 175 |
+
pass
|
| 176 |
+
|
| 177 |
+
if prompt[0] in {',', ':', ']'}:
|
| 178 |
+
return ParseResult(prompt=prompt, expr=None)
|
| 179 |
+
|
| 180 |
+
raise ValueError
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def parse_alternation(prompt, stoppers):
|
| 184 |
+
prompt, _ = parse_open_square(prompt, stoppers)
|
| 185 |
+
prompt, exprs = parse_alternation_exprs(prompt, stoppers)
|
| 186 |
+
prompt, speed = parse_alternation_speed(prompt, stoppers)
|
| 187 |
+
prompt, _ = parse_close_square(prompt, stoppers)
|
| 188 |
+
return ParseResult(prompt=prompt, expr=ast.AlternationExpression(exprs, speed))
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def parse_alternation_exprs(prompt, stoppers):
|
| 192 |
+
exprs = []
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
while True:
|
| 196 |
+
prompt, expr = parse_list_expression(prompt, {'|', ':', ']'})
|
| 197 |
+
exprs.append(expr)
|
| 198 |
+
prompt, _ = parse_vertical_bar(prompt, stoppers)
|
| 199 |
+
except ValueError:
|
| 200 |
+
if len(exprs) < 2:
|
| 201 |
+
raise
|
| 202 |
+
|
| 203 |
+
return ParseResult(prompt=prompt, expr=exprs)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def parse_alternation_speed(prompt, stoppers):
|
| 207 |
+
try:
|
| 208 |
+
prompt, _ = parse_colon(prompt, stoppers)
|
| 209 |
+
prompt, speed = parse_step(prompt, stoppers)
|
| 210 |
+
return ParseResult(prompt=prompt, expr=speed)
|
| 211 |
+
except ValueError:
|
| 212 |
+
pass
|
| 213 |
+
|
| 214 |
+
return ParseResult(prompt=prompt, expr=None)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def parse_editing(prompt, stoppers):
|
| 218 |
+
prompt, _ = parse_open_square(prompt, stoppers)
|
| 219 |
+
prompt, exprs = parse_editing_exprs(prompt, stoppers)
|
| 220 |
+
try:
|
| 221 |
+
prompt, step = parse_step(prompt, stoppers)
|
| 222 |
+
except ValueError:
|
| 223 |
+
step = None
|
| 224 |
+
|
| 225 |
+
prompt, _ = parse_close_square(prompt, stoppers)
|
| 226 |
+
return ParseResult(prompt=prompt, expr=ast.EditingExpression(exprs, step))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def parse_editing_exprs(prompt, stoppers):
|
| 230 |
+
exprs = []
|
| 231 |
+
|
| 232 |
+
try:
|
| 233 |
+
for _ in range(2):
|
| 234 |
+
prompt_tmp, expr = parse_list_expression(prompt, {'|', ':', ']'})
|
| 235 |
+
prompt, _ = parse_colon(prompt_tmp, stoppers)
|
| 236 |
+
exprs.append(expr)
|
| 237 |
+
except ValueError:
|
| 238 |
+
pass
|
| 239 |
+
|
| 240 |
+
return ParseResult(prompt=prompt, expr=exprs)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def parse_negative_attention(prompt, stoppers):
|
| 244 |
+
prompt, _ = parse_open_square(prompt, stoppers)
|
| 245 |
+
prompt, expr = parse_list_expression(prompt, set_concat(stoppers, {':', ']'}))
|
| 246 |
+
prompt, _ = parse_close_square(prompt, stoppers)
|
| 247 |
+
return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, positive=False))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def parse_positive_attention(prompt, stoppers):
|
| 251 |
+
prompt, _ = parse_open_paren(prompt, stoppers)
|
| 252 |
+
prompt, expr = parse_list_expression(prompt, {':', ')'})
|
| 253 |
+
prompt, weight_exprs = parse_attention_weights(prompt, stoppers)
|
| 254 |
+
prompt, _ = parse_close_paren(prompt, stoppers)
|
| 255 |
+
if len(weight_exprs) >= 2:
|
| 256 |
+
return ParseResult(prompt=prompt, expr=ast.WeightInterpolationExpression(expr, *weight_exprs[:2]))
|
| 257 |
+
else:
|
| 258 |
+
return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, *weight_exprs[:1]))
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def parse_attention_weights(prompt, stoppers):
|
| 262 |
+
weights = []
|
| 263 |
+
try:
|
| 264 |
+
prompt, _ = parse_colon(prompt, stoppers)
|
| 265 |
+
except ValueError:
|
| 266 |
+
return ParseResult(prompt=prompt, expr=weights)
|
| 267 |
+
|
| 268 |
+
while True:
|
| 269 |
+
try:
|
| 270 |
+
prompt, weight_expr = parse_weight(prompt, stoppers)
|
| 271 |
+
weights.append(weight_expr)
|
| 272 |
+
prompt, _ = parse_comma(prompt, stoppers)
|
| 273 |
+
except ValueError:
|
| 274 |
+
return ParseResult(prompt=prompt, expr=weights)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def parse_step(prompt, stoppers):
|
| 278 |
+
try:
|
| 279 |
+
prompt, step = parse_int_not_float(prompt, stoppers)
|
| 280 |
+
return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
|
| 281 |
+
except ValueError:
|
| 282 |
+
pass
|
| 283 |
+
try:
|
| 284 |
+
prompt, step = parse_float(prompt, stoppers)
|
| 285 |
+
return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
|
| 286 |
+
except ValueError:
|
| 287 |
+
pass
|
| 288 |
+
|
| 289 |
+
return parse_substitution(prompt, stoppers)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def parse_weight(prompt, stoppers):
|
| 293 |
+
try:
|
| 294 |
+
prompt, step = parse_float(prompt, stoppers)
|
| 295 |
+
return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
|
| 296 |
+
except ValueError:
|
| 297 |
+
pass
|
| 298 |
+
|
| 299 |
+
return parse_substitution(prompt, stoppers)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def parse_symbol(prompt, stoppers):
|
| 303 |
+
prompt, _ = parse_dollar(prompt)
|
| 304 |
+
return parse_symbol_text(prompt, stoppers)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def parse_symbol_text(prompt, stoppers):
|
| 308 |
+
return parse_token(prompt, whitespace_tail_regex('[a-zA-Z_][a-zA-Z0-9_]*', stoppers))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def parse_float(prompt, stoppers):
|
| 312 |
+
return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def parse_int_not_float(prompt, stoppers):
|
| 316 |
+
return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def parse_dollar(prompt):
|
| 320 |
+
dollar_sign = re.escape('$')
|
| 321 |
+
return parse_token(prompt, f'({dollar_sign})')
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def parse_equals(prompt, stoppers):
|
| 325 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape('='), stoppers))
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def parse_comma(prompt, stoppers):
|
| 329 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape(','), stoppers))
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def parse_colon(prompt, stoppers):
|
| 333 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape(':'), stoppers))
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def parse_vertical_bar(prompt, stoppers):
|
| 337 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape('|'), stoppers))
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def parse_open_square(prompt, stoppers):
|
| 341 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape('['), stoppers))
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def parse_close_square(prompt, stoppers):
|
| 345 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape(']'), stoppers))
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def parse_open_paren(prompt, stoppers):
|
| 349 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape('('), stoppers))
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def parse_close_paren(prompt, stoppers):
|
| 353 |
+
return parse_token(prompt, whitespace_tail_regex(re.escape(')'), stoppers))
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def parse_newline(prompt, stoppers):
|
| 357 |
+
return parse_token(prompt, whitespace_tail_regex('\n|$', stoppers))
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def parse_token(prompt, regex):
|
| 361 |
+
match = re.match(regex, prompt)
|
| 362 |
+
if match is None:
|
| 363 |
+
raise ValueError
|
| 364 |
+
|
| 365 |
+
return ParseResult(prompt=prompt[len(match.group()):], expr=match.groups()[-1])
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def whitespace_tail_regex(regex, stoppers):
|
| 369 |
+
if '\n' in stoppers:
|
| 370 |
+
return rf'({regex})[ \t\f\r]*'
|
| 371 |
+
|
| 372 |
+
return rf'({regex})\s*'
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def set_concat(left, right):
|
| 376 |
+
result = set(left)
|
| 377 |
+
result.update(right)
|
| 378 |
+
return result
|
prompt-fusion-extension-main/lib_prompt_fusion/t_scaler.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def scale_t(t, positions):
|
| 2 |
+
if t >= 1.:
|
| 3 |
+
return 1.
|
| 4 |
+
|
| 5 |
+
if t <= 0.:
|
| 6 |
+
return 0.
|
| 7 |
+
|
| 8 |
+
distances = []
|
| 9 |
+
for i in range(len(positions)-1):
|
| 10 |
+
distances.append(positions[i+1] - positions[i])
|
| 11 |
+
|
| 12 |
+
total_distance = sum(distances)
|
| 13 |
+
for i in range(len(distances)):
|
| 14 |
+
distances[i] = distances[i]/total_distance
|
| 15 |
+
|
| 16 |
+
for i in range(len(distances)-1):
|
| 17 |
+
distances[i+1] = distances[i] + distances[i+1]
|
| 18 |
+
|
| 19 |
+
distances.insert(0, 0.0)
|
| 20 |
+
|
| 21 |
+
spline_index = 0
|
| 22 |
+
for i, distance in enumerate(distances):
|
| 23 |
+
if t > distance:
|
| 24 |
+
spline_index = i
|
| 25 |
+
else:
|
| 26 |
+
break
|
| 27 |
+
|
| 28 |
+
if spline_index >= len(distances) - 1:
|
| 29 |
+
return 1
|
| 30 |
+
|
| 31 |
+
local_ratio = (t - distances[spline_index]) / (distances[spline_index+1] - distances[spline_index])
|
| 32 |
+
return (spline_index + local_ratio)/(len(distances)-1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
total_steps = 20
|
| 37 |
+
for i in range(total_steps):
|
| 38 |
+
print(i, scale_t(i/total_steps, [9, 10]))
|
prompt-fusion-extension-main/metadata.ini
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Extension]
|
| 2 |
+
Name = prompt-fusion
|
prompt-fusion-extension-main/readme.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Prompt Fusion
|
| 2 |
+
|
| 3 |
+
Prompt Fusion is an [auto1111 webui extension](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions) that adds more possibilities to the native prompt syntax. Among other additions, it allows to interpolate between the embeddings of different prompts, continuously:
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
# linear prompt interpolation
|
| 7 |
+
[night light:magical forest: 5, 15]
|
| 8 |
+
|
| 9 |
+
# catmull-rom curve prompt interpolation
|
| 10 |
+
[night light:magical forest:norvegian territory: 5, 15, 25:catmull]
|
| 11 |
+
|
| 12 |
+
# alternation interpolation
|
| 13 |
+
[ufo|a strange sight:0.5]
|
| 14 |
+
|
| 15 |
+
# linear attention interpolation
|
| 16 |
+
(fire extinguisher: 1.0, 2.0)
|
| 17 |
+
|
| 18 |
+
# prompt-editing-aware attention interpolation
|
| 19 |
+
[(fire extinguisher: 1.0, 2.0)::5]
|
| 20 |
+
|
| 21 |
+
# weighted sum of conditions
|
| 22 |
+
[space station : kitchen mixer :: mean]
|
| 23 |
+
|
| 24 |
+
# define functions and variables to simplify repeating patterns and use a consistent structure
|
| 25 |
+
$prompt($style, $quality, $character, $background) = (
|
| 26 |
+
A detailed picture in the style of $style,
|
| 27 |
+
$quality,
|
| 28 |
+
$character lying back,
|
| 29 |
+
$background in the background
|
| 30 |
+
:1)
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Features
|
| 34 |
+
|
| 35 |
+
- [Prompt interpolation using a curve function](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Interpolation)
|
| 36 |
+
- [Attention interpolation aware of contextual prompt editing](https://github.com/ljleb/prompt-fusion-extension/wiki/Attention-Interpolation)
|
| 37 |
+
- [Alternation interpolation](https://github.com/ljleb/prompt-fusion-extension/wiki/Alternation-interpolation)
|
| 38 |
+
- [Prompt weighted sum](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Average)
|
| 39 |
+
- [Prompt variables and functions](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Variables)
|
| 40 |
+
- Complete backwards compatibility with the builtin prompt syntax of the webui
|
| 41 |
+
|
| 42 |
+
The prompt interpolation feature is similar to [Prompt Travel](https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel), which allows to create videos of images generated by navigating the latent space iteratively. Unlike Prompt Travel however, instead of generating multiple images, Prompt Fusion allows you to travel during the sampling process of *a single image*. Also, instead of interpolating the latent space, it uses the embedding space to determine intermediate embeddings.
|
| 43 |
+
|
| 44 |
+
Prompt interpolation is also similar to [Prompt Blending](https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts). The main difference is that this extension calculates a new embedding for every step, as opposed to calculating it once and using that same one embedding for all the steps.
|
| 45 |
+
|
| 46 |
+
The attention interpolation feature is similar to [Shift Attention](https://github.com/yownas/shift-attention), which allows to generate multiple images with slight variations in the attention given to certain parts of the prompt. Unlike Shift Attention, instead of generating multiple images, Prompt Fusion allows to shift the attention of certain parts of a prompt during the sampling process of *a single image*.
|
| 47 |
+
|
| 48 |
+
## Usage
|
| 49 |
+
- Check the [wiki pages](https://github.com/ljleb/fusion/wiki) for the extension documentation.
|
| 50 |
+
|
| 51 |
+
## Examples
|
| 52 |
+
|
| 53 |
+
### 1. Influencing the beginning of the sampling process
|
| 54 |
+
|
| 55 |
+
Interpolate linearly (by default) from `lion` (step 0) to `bird` (step 8) to `girl` (step 11), and stay at `girl` for the rest of the sampling steps:
|
| 56 |
+
|
| 57 |
+
```
|
| 58 |
+
[lion:bird:girl: , 7, 10]
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+

|
| 62 |
+
|
| 63 |
+
### 2. Influencing the middle of the sampling process
|
| 64 |
+
|
| 65 |
+
Interpolate using a bezier curve from `fireball monster` (step 0) to `dragon monster` (step 12, because 30 steps * 0.4 = step 12), while using `seawater monster` as an intermediate control point to steer the curve away during interpolation and to get creative results:
|
| 66 |
+
|
| 67 |
+
```
|
| 68 |
+
[fireball:seawater:dragon: , .1, .4:bezier] monster
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+

|
| 72 |
+
|
| 73 |
+
## Webui supported releases
|
| 74 |
+
|
| 75 |
+
The following webui releases are officially supported:
|
| 76 |
+
- `v1.0.0-pre`
|
| 77 |
+
- `master` (there may be a slight lag for issues arising during quick a1111 webui updates)
|
| 78 |
+
|
| 79 |
+
## Installation
|
| 80 |
+
1. Visit the **Extensions** tab of Automatic's WebUI.
|
| 81 |
+
2. Visit the **Available** subtab.
|
| 82 |
+
3. Look for **Prompt Fusion**.
|
| 83 |
+
4. Press the **Install** button.
|
| 84 |
+
5. Wait for the webui to finish downloading the extension.
|
| 85 |
+
6. Visit the **Installed** subtab.
|
| 86 |
+
7. click on **Apply and restart UI**.
|
| 87 |
+
|
| 88 |
+
Alternatively, instead of steps 6 and 7, you can restart the webui completely.
|
| 89 |
+
|
| 90 |
+
## Related Projects
|
| 91 |
+
|
| 92 |
+
- Prompt Travel: https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel
|
| 93 |
+
- Shift Attention: https://github.com/yownas/shift-attention
|
| 94 |
+
- Prompt Blending: https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts
|
| 95 |
+
|
prompt-fusion-extension-main/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
prompt-fusion-extension-main/scripts/promptlang.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import re
|
| 3 |
+
from lib_prompt_fusion import hijacker, empty_cond, global_state, interpolation_tensor, prompt_parser as prompt_fusion_parser
|
| 4 |
+
from modules import scripts, script_callbacks, prompt_parser, shared
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
fusion_hijacker_attribute = '__fusion_hijacker'
|
| 8 |
+
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
|
| 9 |
+
module=prompt_parser,
|
| 10 |
+
hijacker_attribute=fusion_hijacker_attribute,
|
| 11 |
+
register_uninstall=script_callbacks.on_script_unloaded)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _neutral_prompt_enabled() -> bool:
|
| 15 |
+
try:
|
| 16 |
+
from lib_neutral_prompt import global_state as np_state
|
| 17 |
+
return bool(getattr(np_state, "is_enabled", False))
|
| 18 |
+
except Exception:
|
| 19 |
+
return False
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def on_ui_settings():
|
| 24 |
+
section = ('prompt-fusion', 'Prompt Fusion')
|
| 25 |
+
shared.opts.add_option('prompt_fusion_enabled', shared.OptionInfo(True, 'Enable prompt-fusion extension', section=section))
|
| 26 |
+
shared.opts.add_option('prompt_fusion_slerp_scale', shared.OptionInfo(0, 'Slerp scale (0 = linear geometry, 1 = slerp geometry)', component=gr.Number, section=section))
|
| 27 |
+
shared.opts.add_option('prompt_fusion_slerp_negative_origin', shared.OptionInfo(True, 'use negative prompt as slerp origin', section=section))
|
| 28 |
+
shared.opts.add_option('prompt_fusion_slerp_epsilon', shared.OptionInfo(0.0001, 'Slerp epsilon (fallback on linear geometry when conds are too similar. 0 = parallel, 1 = perpendicular)', component=gr.Number, section=section))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
script_callbacks.on_ui_settings(on_ui_settings)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@prompt_parser_hijacker.hijack('get_learned_conditioning')
|
| 35 |
+
def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, original_function, **kwargs):
|
| 36 |
+
if _neutral_prompt_enabled():
|
| 37 |
+
return original_function(model, prompts, total_steps, *args, **kwargs)
|
| 38 |
+
|
| 39 |
+
if not shared.opts.prompt_fusion_enabled:
|
| 40 |
+
return original_function(model, prompts, total_steps, *args, **kwargs)
|
| 41 |
+
|
| 42 |
+
hires_steps, use_old_scheduling, *_ = args if args else (None, True)
|
| 43 |
+
is_hires = hires_steps is not None
|
| 44 |
+
if is_hires:
|
| 45 |
+
real_total_steps = hires_steps
|
| 46 |
+
else:
|
| 47 |
+
real_total_steps = total_steps
|
| 48 |
+
|
| 49 |
+
if hasattr(prompts, 'is_negative_prompt'):
|
| 50 |
+
is_negative_prompt = prompts.is_negative_prompt
|
| 51 |
+
else:
|
| 52 |
+
is_negative_prompt = global_state.old_webui_is_negative
|
| 53 |
+
|
| 54 |
+
empty_cond.init(model)
|
| 55 |
+
|
| 56 |
+
tensor_builders = _parse_tensor_builders(prompts, real_total_steps, is_hires, use_old_scheduling)
|
| 57 |
+
if hasattr(prompt_parser, 'SdConditioning'):
|
| 58 |
+
empty_conditioning = prompt_parser.SdConditioning(prompts)
|
| 59 |
+
empty_conditioning.clear()
|
| 60 |
+
else:
|
| 61 |
+
empty_conditioning = []
|
| 62 |
+
|
| 63 |
+
flattened_prompts, consecutive_ranges = _get_flattened_prompts(tensor_builders, empty_conditioning)
|
| 64 |
+
flattened_schedules = original_function(model, flattened_prompts, total_steps, *args, **kwargs)
|
| 65 |
+
|
| 66 |
+
if isinstance(flattened_schedules[0][0].cond, dict): # sdxl
|
| 67 |
+
CondWrapper = interpolation_tensor.DictCondWrapper
|
| 68 |
+
else:
|
| 69 |
+
CondWrapper = interpolation_tensor.TensorCondWrapper
|
| 70 |
+
|
| 71 |
+
flattened_schedules = [
|
| 72 |
+
[
|
| 73 |
+
prompt_parser.ScheduledPromptConditioning(cond=CondWrapper(schedule.cond), end_at_step=schedule.end_at_step)
|
| 74 |
+
for schedule in subschedules
|
| 75 |
+
]
|
| 76 |
+
for subschedules in flattened_schedules
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
cond_tensors = (tensor_builder.build(flattened_schedules[begin:end], empty_cond.get())
|
| 80 |
+
for begin, end, tensor_builder
|
| 81 |
+
in zip(consecutive_ranges[:-1], consecutive_ranges[1:], tensor_builders))
|
| 82 |
+
|
| 83 |
+
schedules = [_sample_tensor_schedules(cond_tensor, real_total_steps, is_hires)
|
| 84 |
+
for cond_tensor in cond_tensors]
|
| 85 |
+
|
| 86 |
+
if is_negative_prompt:
|
| 87 |
+
if hires_steps is not None:
|
| 88 |
+
global_state.negative_schedules_hires = schedules[0]
|
| 89 |
+
else:
|
| 90 |
+
global_state.negative_schedules = schedules[0]
|
| 91 |
+
|
| 92 |
+
schedules = [
|
| 93 |
+
[
|
| 94 |
+
prompt_parser.ScheduledPromptConditioning(cond=schedule.cond.original_cond, end_at_step=schedule.end_at_step)
|
| 95 |
+
for schedule in subschedules
|
| 96 |
+
]
|
| 97 |
+
for subschedules in schedules
|
| 98 |
+
]
|
| 99 |
+
|
| 100 |
+
return schedules
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@prompt_parser_hijacker.hijack('get_multicond_learned_conditioning')
|
| 104 |
+
def _hijacked_get_multicond_learned_conditioning(model, prompts, total_steps, *args, original_function, **kwargs):
|
| 105 |
+
if _neutral_prompt_enabled():
|
| 106 |
+
try:
|
| 107 |
+
return original_function(model, prompts, total_steps, *args, **kwargs)
|
| 108 |
+
finally:
|
| 109 |
+
global_state.old_webui_is_negative = False
|
| 110 |
+
|
| 111 |
+
if not shared.opts.prompt_fusion_enabled:
|
| 112 |
+
try:
|
| 113 |
+
return original_function(model, prompts, total_steps, *args, **kwargs)
|
| 114 |
+
finally:
|
| 115 |
+
global_state.old_webui_is_negative = False
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
hires_steps, use_old_scheduling, *_ = args if args else (None, True)
|
| 119 |
+
|
| 120 |
+
res_indexes, prompt_flat_list = _get_multicond_prompt_list(prompts)
|
| 121 |
+
learned_conditioning = prompt_parser.get_learned_conditioning(
|
| 122 |
+
model,
|
| 123 |
+
prompt_flat_list,
|
| 124 |
+
total_steps,
|
| 125 |
+
hires_steps,
|
| 126 |
+
use_old_scheduling,
|
| 127 |
+
**kwargs,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
batch = [
|
| 131 |
+
[prompt_parser.ComposableScheduledPromptConditioning(learned_conditioning[i], weight)
|
| 132 |
+
for i, weight in indexes]
|
| 133 |
+
for indexes in res_indexes
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
return prompt_parser.MulticondLearnedConditioning(
|
| 137 |
+
shape=_infer_multicond_shape(learned_conditioning),
|
| 138 |
+
batch=batch,
|
| 139 |
+
)
|
| 140 |
+
finally:
|
| 141 |
+
global_state.old_webui_is_negative = False
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _parse_tensor_builders(prompts, total_steps, is_hires, use_old_scheduling):
|
| 145 |
+
tensor_builders = []
|
| 146 |
+
|
| 147 |
+
for prompt in prompts:
|
| 148 |
+
expr = prompt_fusion_parser.parse_prompt(prompt)
|
| 149 |
+
tensor_builder = interpolation_tensor.InterpolationTensorBuilder()
|
| 150 |
+
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires, use_old_scheduling)
|
| 151 |
+
tensor_builders.append(tensor_builder)
|
| 152 |
+
|
| 153 |
+
return tensor_builders
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _get_flattened_prompts(tensor_builders, flattened_prompts=None):
|
| 157 |
+
if flattened_prompts is None:
|
| 158 |
+
flattened_prompts = []
|
| 159 |
+
consecutive_ranges = [0]
|
| 160 |
+
|
| 161 |
+
for tensor_builder in tensor_builders:
|
| 162 |
+
flattened_prompts.extend(tensor_builder.get_prompt_database())
|
| 163 |
+
consecutive_ranges.append(len(flattened_prompts))
|
| 164 |
+
|
| 165 |
+
return flattened_prompts, consecutive_ranges
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _sample_tensor_schedules(tensor, steps, is_hires):
|
| 169 |
+
schedules = []
|
| 170 |
+
|
| 171 |
+
for step in range(steps):
|
| 172 |
+
origin_cond = global_state.get_origin_cond_at(step, is_hires)
|
| 173 |
+
params = interpolation_tensor.InterpolationParams(step / steps, step, steps, global_state.get_slerp_scale(), global_state.get_slerp_epsilon())
|
| 174 |
+
schedule_cond = tensor.interpolate(params, origin_cond, empty_cond.get())
|
| 175 |
+
if schedules and schedules[-1].cond == schedule_cond:
|
| 176 |
+
schedules[-1] = prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedules[-1].cond)
|
| 177 |
+
else:
|
| 178 |
+
schedules.append(prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedule_cond))
|
| 179 |
+
|
| 180 |
+
return schedules
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class PromptFusionScript(scripts.Script):
|
| 184 |
+
def title(self):
|
| 185 |
+
return 'Prompt Fusion'
|
| 186 |
+
|
| 187 |
+
def show(self, is_img2img):
|
| 188 |
+
return scripts.AlwaysVisible
|
| 189 |
+
|
| 190 |
+
def process(self, p, *args):
|
| 191 |
+
global_state.negative_schedules = None
|
| 192 |
+
global_state.old_webui_is_negative = True
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _get_multicond_prompt_list(prompts):
|
| 196 |
+
prompt_indexes = {}
|
| 197 |
+
prompt_flat_list = prompt_parser.SdConditioning([], copy_from=prompts)
|
| 198 |
+
res_indexes = []
|
| 199 |
+
|
| 200 |
+
for prompt in prompts:
|
| 201 |
+
indexes = []
|
| 202 |
+
for text, weight in _split_multicond_prompt(prompt):
|
| 203 |
+
index = prompt_indexes.get(text)
|
| 204 |
+
if index is None:
|
| 205 |
+
index = len(prompt_flat_list)
|
| 206 |
+
prompt_flat_list.append(text)
|
| 207 |
+
prompt_indexes[text] = index
|
| 208 |
+
indexes.append((index, weight))
|
| 209 |
+
res_indexes.append(indexes)
|
| 210 |
+
|
| 211 |
+
return res_indexes, prompt_flat_list
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _split_multicond_prompt(prompt):
|
| 215 |
+
parts = []
|
| 216 |
+
buf = []
|
| 217 |
+
depth_paren = depth_brack = depth_brace = 0
|
| 218 |
+
escaped = False
|
| 219 |
+
i = 0
|
| 220 |
+
|
| 221 |
+
while i < len(prompt):
|
| 222 |
+
ch = prompt[i]
|
| 223 |
+
|
| 224 |
+
if escaped:
|
| 225 |
+
buf.append(ch)
|
| 226 |
+
escaped = False
|
| 227 |
+
i += 1
|
| 228 |
+
continue
|
| 229 |
+
|
| 230 |
+
if ch == '\\':
|
| 231 |
+
buf.append(ch)
|
| 232 |
+
escaped = True
|
| 233 |
+
i += 1
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
if ch == '(':
|
| 237 |
+
depth_paren += 1
|
| 238 |
+
elif ch == ')' and depth_paren > 0:
|
| 239 |
+
depth_paren -= 1
|
| 240 |
+
elif ch == '[':
|
| 241 |
+
depth_brack += 1
|
| 242 |
+
elif ch == ']' and depth_brack > 0:
|
| 243 |
+
depth_brack -= 1
|
| 244 |
+
elif ch == '{':
|
| 245 |
+
depth_brace += 1
|
| 246 |
+
elif ch == '}' and depth_brace > 0:
|
| 247 |
+
depth_brace -= 1
|
| 248 |
+
|
| 249 |
+
if depth_paren == 0 and depth_brack == 0 and depth_brace == 0:
|
| 250 |
+
if _matches_top_level_and(prompt, i):
|
| 251 |
+
_append_multicond_piece(parts, ''.join(buf))
|
| 252 |
+
buf = []
|
| 253 |
+
i += 3
|
| 254 |
+
continue
|
| 255 |
+
|
| 256 |
+
if _matches_top_level_amp(prompt, i):
|
| 257 |
+
_append_multicond_piece(parts, ''.join(buf))
|
| 258 |
+
buf = []
|
| 259 |
+
i += 1
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
buf.append(ch)
|
| 263 |
+
i += 1
|
| 264 |
+
|
| 265 |
+
_append_multicond_piece(parts, ''.join(buf))
|
| 266 |
+
|
| 267 |
+
return parts or [('', 1.0)]
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _append_multicond_piece(parts, piece):
|
| 271 |
+
text, weight = _split_tail_weight(piece)
|
| 272 |
+
parts.append((text, weight))
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def _split_tail_weight(piece):
|
| 276 |
+
piece = piece.strip()
|
| 277 |
+
if not piece:
|
| 278 |
+
return '', 1.0
|
| 279 |
+
|
| 280 |
+
depth_paren = depth_brack = depth_brace = 0
|
| 281 |
+
escaped = False
|
| 282 |
+
split_at = None
|
| 283 |
+
|
| 284 |
+
for i, ch in enumerate(piece):
|
| 285 |
+
if escaped:
|
| 286 |
+
escaped = False
|
| 287 |
+
continue
|
| 288 |
+
|
| 289 |
+
if ch == '\\':
|
| 290 |
+
escaped = True
|
| 291 |
+
continue
|
| 292 |
+
|
| 293 |
+
if ch == '(':
|
| 294 |
+
depth_paren += 1
|
| 295 |
+
elif ch == ')' and depth_paren > 0:
|
| 296 |
+
depth_paren -= 1
|
| 297 |
+
elif ch == '[':
|
| 298 |
+
depth_brack += 1
|
| 299 |
+
elif ch == ']' and depth_brack > 0:
|
| 300 |
+
depth_brack -= 1
|
| 301 |
+
elif ch == '{':
|
| 302 |
+
depth_brace += 1
|
| 303 |
+
elif ch == '}' and depth_brace > 0:
|
| 304 |
+
depth_brace -= 1
|
| 305 |
+
elif ch in '::' and depth_paren == 0 and depth_brack == 0 and depth_brace == 0:
|
| 306 |
+
split_at = i
|
| 307 |
+
|
| 308 |
+
if split_at is None:
|
| 309 |
+
return piece, 1.0
|
| 310 |
+
|
| 311 |
+
suffix = piece[split_at + 1:].strip()
|
| 312 |
+
if not re.fullmatch(r'[-+]?(?:\d+\.?|\d*\.\d+)', suffix):
|
| 313 |
+
return piece, 1.0
|
| 314 |
+
|
| 315 |
+
text = piece[:split_at].rstrip()
|
| 316 |
+
if not text:
|
| 317 |
+
return piece, 1.0
|
| 318 |
+
|
| 319 |
+
return text, float(suffix)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def _matches_top_level_and(prompt, index):
|
| 323 |
+
if prompt[index:index + 3] != 'AND':
|
| 324 |
+
return False
|
| 325 |
+
|
| 326 |
+
prev_ch = prompt[index - 1] if index > 0 else ''
|
| 327 |
+
next_ch = prompt[index + 3] if index + 3 < len(prompt) else ''
|
| 328 |
+
|
| 329 |
+
if (prev_ch.isalnum() or prev_ch == '_') or (next_ch.isalnum() or next_ch == '_'):
|
| 330 |
+
return False
|
| 331 |
+
|
| 332 |
+
return not any(prompt.startswith(f'AND{suffix}', index) for suffix in ('_PERP', '_SALT', '_TOPK'))
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _matches_top_level_amp(prompt, index):
|
| 336 |
+
if prompt[index] != '&':
|
| 337 |
+
return False
|
| 338 |
+
|
| 339 |
+
prev_ch = prompt[index - 1] if index > 0 else ' '
|
| 340 |
+
next_ch = prompt[index + 1] if index + 1 < len(prompt) else ' '
|
| 341 |
+
return prev_ch.isspace() and next_ch.isspace()
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _infer_multicond_shape(learned_conditioning):
|
| 345 |
+
if not learned_conditioning or not learned_conditioning[0]:
|
| 346 |
+
return (0,)
|
| 347 |
+
|
| 348 |
+
cond = learned_conditioning[0][0].cond
|
| 349 |
+
if isinstance(cond, dict):
|
| 350 |
+
crossattn = cond.get('crossattn')
|
| 351 |
+
if isinstance(crossattn, list) and crossattn:
|
| 352 |
+
return getattr(crossattn[0], 'shape', None) or (0,)
|
| 353 |
+
return getattr(crossattn, 'shape', None) or (0,)
|
| 354 |
+
|
| 355 |
+
return getattr(cond, 'shape', None) or (0,)
|
prompt-fusion-extension-main/test/parser_tests.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_prompt_fusion.prompt_parser import parse_prompt
|
| 2 |
+
from lib_prompt_fusion.interpolation_tensor import InterpolationTensorBuilder
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def run_functional_tests(total_steps=100):
|
| 6 |
+
for i, (given, expected) in enumerate(functional_parse_test_cases):
|
| 7 |
+
expr = parse_prompt(given)
|
| 8 |
+
tensor_builder = InterpolationTensorBuilder()
|
| 9 |
+
expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
|
| 10 |
+
|
| 11 |
+
actual = tensor_builder.get_prompt_database()
|
| 12 |
+
|
| 13 |
+
if type(expected) is set:
|
| 14 |
+
assert set(actual) == expected, f"{actual} != {expected}"
|
| 15 |
+
else:
|
| 16 |
+
assert len(actual) == 1 and actual[0] == expected, f"'{actual[0]}' != '{expected}'"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
functional_parse_test_cases = [
|
| 20 |
+
('single',)*2,
|
| 21 |
+
('some space separated text',)*2,
|
| 22 |
+
('(legacy weighted prompt:-2.1)',)*2,
|
| 23 |
+
('mixed (legacy weight:3.6) and text',)*2,
|
| 24 |
+
('legacy [range begin:0] thingy',)*2,
|
| 25 |
+
('legacy [range end::3] thingy',)*2,
|
| 26 |
+
('legacy [[nested range::3]:2] thingy',)*2,
|
| 27 |
+
('legacy [[nested range:2]::3] thingy',)*2,
|
| 28 |
+
('sugar [range:,abc:3] thingy',)*2,
|
| 29 |
+
('sugar [[(weight interpolation:0,12):0]::1] thingy', 'sugar [[(weight interpolation:0.0):0]::1] thingy'),
|
| 30 |
+
('sugar [[(weight interpolation:0,12):0]::2] thingy', 'sugar [[[(weight interpolation:0.0)::1][(weight interpolation:12.0):1]:0]::2] thingy'),
|
| 31 |
+
('sugar [[(weight interpolation:0,12):0]::3] thingy', 'sugar [[[(weight interpolation:0.0)::1][[(weight interpolation:6.0):1]::2][(weight interpolation:12.0):2]:0]::3] thingy'),
|
| 32 |
+
('legacy [from:to:2] thingy',)*2,
|
| 33 |
+
('legacy [negative weight]',)*2,
|
| 34 |
+
('legacy (positive weight)',)*2,
|
| 35 |
+
('[abc:1girl:2]',)*2,
|
| 36 |
+
('[::]',)*2,
|
| 37 |
+
('[a:b:]',)*2,
|
| 38 |
+
('[[a:b:1,2]:b:]', {'[a:b:]', '[b:b:]'}),
|
| 39 |
+
('1girl',)*2,
|
| 40 |
+
('dashes-in-text',)*2,
|
| 41 |
+
('text, separated with, comas',)*2,
|
| 42 |
+
('{prompt}',)*2,
|
| 43 |
+
('[abc|def ghi|jkl]',)*2,
|
| 44 |
+
('merging this AND with this',)*2,
|
| 45 |
+
(':',)*2,
|
| 46 |
+
(r'portrait \(object\)',)*2,
|
| 47 |
+
(r'\[escaped square\]',)*2,
|
| 48 |
+
(r'\$var = abc',)*2,
|
| 49 |
+
(r'\\$ arst',)*2,
|
| 50 |
+
(r'$$ arst',)*2,
|
| 51 |
+
('$var = abc', ''),
|
| 52 |
+
('$a = prompt value\n$a', 'prompt value'),
|
| 53 |
+
('$a = prompt value\n$b = $a\n$b', 'prompt value'),
|
| 54 |
+
('$a = (multiline\nprompt\nvalue:1.0)\n$a', '(multiline prompt value:1.0)'),
|
| 55 |
+
('$a = ($aa = nested variable\nmultiline\n$aa:1.0)\n$a', '(multiline nested variable:1.0)'),
|
| 56 |
+
('a [b:c:-1, 10] d', {'a b d', 'a c d'}),
|
| 57 |
+
('a [b:c:5, 6] d', {'a b d', 'a c d'}),
|
| 58 |
+
('a [b:c:0.25, 0.5] d', {'a b d', 'a c d'}),
|
| 59 |
+
('a [b:c:.25, .5] d', {'a b d', 'a c d'}),
|
| 60 |
+
('a [b:c:,] d', {'a b d', 'a c d'}),
|
| 61 |
+
('0[1.0:1.1:,]2[3.0:3.1:,]4', {
|
| 62 |
+
'0 1.0 2 3.0 4', '0 1.1 2 3.0 4',
|
| 63 |
+
'0 1.0 2 3.1 4', '0 1.1 2 3.1 4',
|
| 64 |
+
}),
|
| 65 |
+
('0[1.0:1.1:1.2:,.5,]2[3.0:3.1:,]4', {
|
| 66 |
+
'0 1.0 2 3.0 4', '0 1.0 2 3.1 4',
|
| 67 |
+
'0 1.1 2 3.0 4', '0 1.1 2 3.1 4',
|
| 68 |
+
'0 1.2 2 3.0 4', '0 1.2 2 3.1 4',
|
| 69 |
+
}),
|
| 70 |
+
('[0.0:0.1:,][1.0:1.1:,][2.0:2.1:,]', {
|
| 71 |
+
'0.0 1.0 2.0', '0.0 1.0 2.1',
|
| 72 |
+
'0.1 1.0 2.0', '0.1 1.0 2.1',
|
| 73 |
+
'0.0 1.1 2.0', '0.0 1.1 2.1',
|
| 74 |
+
'0.1 1.1 2.0', '0.1 1.1 2.1',
|
| 75 |
+
}),
|
| 76 |
+
('[top level:interpolatin:lik a pro:1,3,5:linear]', {'top level', 'interpolatin', 'lik a pro'}),
|
| 77 |
+
('[[nested:expr:,]:abc:,]', {'nested', 'expr', 'abc'}),
|
| 78 |
+
('[(nested attention:2.0):abc:,]', {'(nested attention:2.0)', 'abc'}),
|
| 79 |
+
('[[nested editing:15]:abc:,]', {'[nested editing:15]', 'abc'}),
|
| 80 |
+
('[[nested interpolation:abc:,]:12]', {'[nested interpolation:12]', '[abc:12]'}),
|
| 81 |
+
('[[nested interpolation:abc:,]::7]', {'[nested interpolation::7]', '[abc::7]'}),
|
| 82 |
+
('$attention = 1.5\n(prompt:$attention)', '(prompt:1.5)'),
|
| 83 |
+
('$a = 0\n$b = 12\n[[(prompt:$a,$b):0]::2]', '[[[(prompt:0.0)::1][(prompt:12.0):1]:0]::2]'),
|
| 84 |
+
('$step = 5\n[legacy:editing:$step]', '[legacy:editing:5]'),
|
| 85 |
+
('$begin = 2\n$end = 7\n[prompt:interpolation:$begin, $end]', {'prompt', 'interpolation'}),
|
| 86 |
+
('$a($b, $c) = prompt with $b, prompt with $c\n$a(cat, dog)', 'prompt with cat , prompt with dog'),
|
| 87 |
+
('$a($b) = prompt with $b\n$c($d) = yeay $a($d)\n$c(dog)', 'yeay prompt with dog'),
|
| 88 |
+
('$a = a lot of animals\n$b($c) = I love $c\n$b($a)', 'I love a lot of animals'),
|
| 89 |
+
('$a($b) = prompt with $b\n$c($d) = yeay $d\n$a($c(dog))', 'prompt with yeay dog'),
|
| 90 |
+
('[a|b|c]', '[a|b|c]'),
|
| 91 |
+
('[a|b|c:]', '[a|b|c]'),
|
| 92 |
+
('[a|b|c:1]', {'a', 'b', 'c'}),
|
| 93 |
+
('[a|b|c:2]', {'a', 'b', 'c'}),
|
| 94 |
+
('[a|b|c:0.5]', {'a', 'b', 'c'}),
|
| 95 |
+
('[a|b|c:1.1]', {'a', 'b', 'c'}),
|
| 96 |
+
('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
|
| 97 |
+
('[a:b:c::mean]', {'a', 'b', 'c'}),
|
| 98 |
+
('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
|
| 99 |
+
('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def run_tests():
|
| 104 |
+
run_functional_tests()
|
prompt-fusion-extension-main/test/run_all.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('..')
|
| 3 |
+
import parser_tests
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
if __name__ == '__main__':
|
| 7 |
+
parser_tests.run_tests()
|