dikdimon commited on
Commit
d70eb82
·
verified ·
1 Parent(s): a3b01e8

Upload 37 files

Browse files
Files changed (37) hide show
  1. neutral_prompt_patched/.gitignore +1 -0
  2. neutral_prompt_patched/LICENSE +21 -0
  3. neutral_prompt_patched/README.md +99 -0
  4. neutral_prompt_patched/lib_neutral_prompt/__init__.py +1 -0
  5. neutral_prompt_patched/lib_neutral_prompt/affine_transform.py +83 -0
  6. neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py +684 -0
  7. neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py +23 -0
  8. neutral_prompt_patched/lib_neutral_prompt/external_code/api.py +27 -0
  9. neutral_prompt_patched/lib_neutral_prompt/global_state.py +75 -0
  10. neutral_prompt_patched/lib_neutral_prompt/hijacker.py +34 -0
  11. neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py +382 -0
  12. neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py +134 -0
  13. neutral_prompt_patched/lib_neutral_prompt/ui.py +196 -0
  14. neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py +42 -0
  15. neutral_prompt_patched/scripts/neutral_prompt.py +94 -0
  16. neutral_prompt_patched/test/perp_parser/__init__.py +0 -0
  17. neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py +133 -0
  18. neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py +217 -0
  19. neutral_prompt_patched/test/perp_parser/test_basic_parser.py +122 -0
  20. neutral_prompt_patched/test/perp_parser/test_malicious_parser.py +182 -0
  21. prompt-fusion-extension-main/.gitignore +3 -0
  22. prompt-fusion-extension-main/LICENSE +21 -0
  23. prompt-fusion-extension-main/lib_prompt_fusion/ast_nodes.py +307 -0
  24. prompt-fusion-extension-main/lib_prompt_fusion/empty_cond.py +19 -0
  25. prompt-fusion-extension-main/lib_prompt_fusion/geometries.py +33 -0
  26. prompt-fusion-extension-main/lib_prompt_fusion/global_state.py +28 -0
  27. prompt-fusion-extension-main/lib_prompt_fusion/hijacker.py +34 -0
  28. prompt-fusion-extension-main/lib_prompt_fusion/interpolation_functions.py +87 -0
  29. prompt-fusion-extension-main/lib_prompt_fusion/interpolation_tensor.py +249 -0
  30. prompt-fusion-extension-main/lib_prompt_fusion/prompt_parser.py +378 -0
  31. prompt-fusion-extension-main/lib_prompt_fusion/t_scaler.py +38 -0
  32. prompt-fusion-extension-main/metadata.ini +2 -0
  33. prompt-fusion-extension-main/readme.md +95 -0
  34. prompt-fusion-extension-main/requirements.txt +1 -0
  35. prompt-fusion-extension-main/scripts/promptlang.py +355 -0
  36. prompt-fusion-extension-main/test/parser_tests.py +104 -0
  37. prompt-fusion-extension-main/test/run_all.py +7 -0
neutral_prompt_patched/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
neutral_prompt_patched/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 ljleb
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
neutral_prompt_patched/README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sd-webui-neutral-prompt — Ultimate Edition
2
+
3
+ > **Unified merge of the main experimental branches.**
4
+ > Most features from all branches are merged and work simultaneously.
5
+ > Some `life`/dev-only tooling (debug visualisations, mask images) was intentionally not carried over — only the production-safe subset is included.
6
+
7
+ ---
8
+
9
+ ## What's inside
10
+
11
+ | Feature | Origin branch | Keyword / setting |
12
+ |---|---|---|
13
+ | Perpendicular projection | `main` | `AND_PERP` |
14
+ | Saliency-guided mask | `main` | `AND_SALT` |
15
+ | Semantic guidance top-k | `main` | `AND_TOPK` |
16
+ | CFG rescale (mean-preserving) | `main` | CFG rescale φ slider |
17
+ | XYZ-grid CFG rescale axis | `main` | XYZ-grid option |
18
+ | External API override | `main` / `export_rescale_factor` | `override_cfg_rescale()` |
19
+ | CFG rescale factor export | `export_rescale_factor` | `get_last_cfg_rescale_factor()` |
20
+ | Affine spatial transforms | `affine` | `ROTATE / SLIDE / SCALE / SHEAR` prefix |
21
+ | Soft alignment blend | `alignment_blend` | `AND_ALIGN_D_S` |
22
+ | Binary alignment mask | `alignment_mask` | `AND_MASK_ALIGN_D_S` |
23
+ | Sharpened saliency maps (k=20) | `life` (production-safe subset) | used automatically by `AND_SALT` |
24
+
25
+ ---
26
+
27
+ ## Prompt syntax
28
+
29
+ ```
30
+ positive prompt text
31
+ AND_PERP negative-direction text :0.8
32
+ AND_SALT competing concept :1.0
33
+ AND_TOPK fine detail tweak :0.5
34
+ AND_ALIGN_4_8 style blend :0.6
35
+ AND_MASK_ALIGN_4_8 structure-preserving style :0.6
36
+ AND_PERP ROTATE[0.125] rotated perpendicular concept :1.0
37
+ ROTATE[0.125] AND_PERP rotated perpendicular concept :1.0
38
+ AND_SALT SLIDE[0.05,0] spatially shifted concept :1.0
39
+ ```
40
+
41
+ ### Affine transform keywords
42
+ Affine keywords are supported in **either order** relative to the conciliation keyword:
43
+
44
+ ```
45
+ AND_PERP ROTATE[0.125] vivid colors :0.8 ← affine after keyword
46
+ ROTATE[0.125] AND_PERP vivid colors :0.8 ← affine before keyword (both work)
47
+ ```
48
+
49
+ | Keyword | Parameter | Effect |
50
+ |---|---|---|
51
+ | `ROTATE[angle]` | angle in turns (0–1) | Rotate latent contribution |
52
+ | `SLIDE[x,y]` | normalised offset | Translate latent contribution |
53
+ | `SCALE[x,y]` | scale factors | Scale latent contribution |
54
+ | `SHEAR[x,y]` | shear in turns | Shear latent contribution |
55
+
56
+ Multiple transforms can be chained: `ROTATE[0.125] SLIDE[0.1,0]`
57
+
58
+ ### AND_ALIGN_D_S / AND_MASK_ALIGN_D_S
59
+ - **D** = detail kernel size (small → fine detail)
60
+ - **S** = structure kernel size (large → global composition)
61
+ - The child prompt is blended in proportionally to how much it alters detail *without* changing structure.
62
+ - `AND_ALIGN` uses a soft weight; `AND_MASK_ALIGN` uses a binary 0/1 mask.
63
+ - Supported range: D, S ∈ [2, 32], D ≠ S.
64
+
65
+ ---
66
+
67
+ ## External API
68
+
69
+ ```python
70
+ # In another extension or script:
71
+ from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
72
+
73
+ override_cfg_rescale(0.7) # override for next step only
74
+ factor = get_last_cfg_rescale_factor() # read after generation
75
+ ```
76
+
77
+ ---
78
+
79
+ ## Installation
80
+
81
+ 1. Clone / copy this folder into `stable-diffusion-webui/extensions/`.
82
+ 2. Restart the webui.
83
+
84
+ ---
85
+
86
+ ## CFG Rescale φ
87
+
88
+ When set to a value > 0 the extension rescales the CFG output to have the
89
+ same mean and a partially rescaled standard deviation as the raw predicted x0.
90
+ This reduces colour over-saturation at high CFG values without affecting
91
+ prompt adherence as much as simply lowering CFG.
92
+
93
+ The formula used is the **mean-preserving** variant from the `main` branch:
94
+
95
+ ```
96
+ rescaled = rescale_mean + (cfg_cond − cfg_cond_mean) × rescale_factor
97
+ ```
98
+
99
+ where `rescale_factor = φ × (std(cond)/std(cfg_cond) − 1) + 1`.
neutral_prompt_patched/lib_neutral_prompt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # lib_neutral_prompt package
neutral_prompt_patched/lib_neutral_prompt/affine_transform.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Affine spatial transform utilities for latent tensors.
3
+ Extracted from the `affine` branch so that cfg_denoiser_hijack can stay lean.
4
+
5
+ Provides:
6
+ apply_affine_transform() – apply a 2×3 affine grid to a C×H×W tensor
7
+ apply_masked_transform() – apply affine + cosine-feathered mask blending
8
+ create_cosine_feathered_mask() – smooth circular weight mask
9
+ """
10
+
11
+ from typing import Tuple
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+
17
+ def apply_affine_transform(
18
+ tensor: torch.Tensor,
19
+ affine: torch.Tensor,
20
+ mode: str = 'bilinear',
21
+ ) -> torch.Tensor:
22
+ """
23
+ Apply a 2×3 affine transform to a C×H×W tensor, preserving aspect ratio.
24
+
25
+ :param tensor: Input tensor of shape [C, H, W].
26
+ :param affine: 2×3 float32 affine matrix.
27
+ :param mode: Interpolation mode for grid_sample ('bilinear' default).
28
+ The original affine branch used bilinear for the pre-noise
29
+ inverse path (smoother) and nearest for post-combine.
30
+ Bilinear is the better default for both.
31
+ :return: Transformed tensor of shape [C, H, W].
32
+ """
33
+ affine = affine.clone().to(tensor.device)
34
+ aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
35
+ affine[0, 1] *= aspect_ratio
36
+ affine[1, 0] /= aspect_ratio
37
+
38
+ grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
39
+ return F.grid_sample(tensor.unsqueeze(0), grid, mode=mode, align_corners=False).squeeze(0)
40
+
41
+
42
+ def apply_masked_transform(
43
+ tensor: torch.Tensor,
44
+ affine: torch.Tensor,
45
+ weight: float,
46
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
47
+ """
48
+ Apply an affine transform with a cosine-feathered spatial weight mask.
49
+
50
+ The mask channel is appended to the tensor before transformation so that
51
+ the same spatial warp is applied to both content and mask consistently.
52
+
53
+ :param tensor: C×H×W latent tensor.
54
+ :param affine: 2×3 affine matrix.
55
+ :param weight: Global scalar weight for the mask.
56
+ :return: (transformed_content [C, H, W], transformed_mask [H, W])
57
+ """
58
+ mask = create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)
59
+ tensor_with_mask = torch.cat([tensor, mask], dim=0) # [C+1, H, W]
60
+ transformed = apply_affine_transform(tensor_with_mask, affine)
61
+ return transformed[:-1], transformed[-1] # content, mask
62
+
63
+
64
+ def create_cosine_feathered_mask(size: Tuple[int, int], weight: float) -> torch.Tensor:
65
+ """
66
+ Create a circularly-clipped cosine-feathered mask of shape [H, W].
67
+
68
+ Values at the centre approach `weight`; values outside the unit circle
69
+ are exactly 0, with a smooth cosine fall-off in between.
70
+
71
+ :param size: (H, W) tuple.
72
+ :param weight: Peak mask value at the centre.
73
+ :return: Float32 mask tensor of shape [H, W].
74
+ """
75
+ y, x = torch.meshgrid(
76
+ torch.linspace(-1, 1, size[0]),
77
+ torch.linspace(-1, 1, size[1]),
78
+ indexing='ij',
79
+ )
80
+ dist = torch.sqrt(x ** 2 + y ** 2)
81
+ mask = 0.5 * (1.0 + torch.cos(torch.pi * dist))
82
+ mask[dist > 1] = 0.0
83
+ return mask.float() * weight
neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified CFG Denoiser Hijack
3
+ ===========================
4
+ Combines features from all sd-webui-neutral-prompt branches:
5
+
6
+ • Core PERP / SALT / TOPK strategies (main branch)
7
+ • cfg_rescale with mean-preserving formula (main branch – better than multiplicative)
8
+ • cfg_rescale_override (XYZ-grid / API support) (main branch)
9
+ • CFGRescaleFactorSingleton export (export_rescale_factor branch)
10
+ • Affine spatial transforms per-prompt (affine branch)
11
+ • AND_ALIGN_D_S – soft alignment blend (alignment_blend branch)
12
+ • AND_MASK_ALIGN_D_S – binary alignment mask (alignment_mask branch)
13
+ • Improved salience with sharpness parameter k (life branch – production-safe subset)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import dataclasses
19
+ import functools
20
+ import re
21
+ import sys
22
+ import textwrap
23
+ from typing import Dict, List, Tuple
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+
28
+ from lib_neutral_prompt import affine_transform as affine_mod
29
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser
30
+ from modules import script_callbacks, sd_samplers, shared
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Pre-noise affine hook (affine branch)
35
+ #
36
+ # Applied BEFORE each CFG denoising step: the noisy latent slice for every
37
+ # cond prompt is warped by the INVERSE of that prompt's affine transform.
38
+ # This means the UNet sees the latent in the "locally-rotated" frame of each
39
+ # prompt. combine_denoised then applies the FORWARD transform to the
40
+ # resulting cond delta, restoring it to the global frame.
41
+ # ---------------------------------------------------------------------------
42
+
43
+ @dataclasses.dataclass
44
+ class _PreNoiseArgs:
45
+ x: torch.Tensor
46
+ cond_indices: List[Tuple[int, float]]
47
+
48
+
49
+ class _GlobalToLocalAffineVisitor:
50
+ """Build a {cond_index: inverse_affine} dict for every leaf prompt."""
51
+
52
+ def visit_leaf_prompt(
53
+ self,
54
+ that: neutral_prompt_parser.LeafPrompt,
55
+ args: _PreNoiseArgs,
56
+ index: int,
57
+ ) -> Dict[int, torch.Tensor]:
58
+ cond_index = args.cond_indices[index][0]
59
+ if that.local_transform is not None:
60
+ # 2×3 → 3×3 → invert → back to 2×3
61
+ m3 = torch.vstack([that.local_transform,
62
+ torch.tensor([0.0, 0.0, 1.0])])
63
+ inv = torch.linalg.inv(m3)[:-1]
64
+ else:
65
+ inv = torch.eye(3)[:-1]
66
+ return {cond_index: inv}
67
+
68
+ def visit_composite_prompt(
69
+ self,
70
+ that: neutral_prompt_parser.CompositePrompt,
71
+ args: _PreNoiseArgs,
72
+ index: int,
73
+ ) -> Dict[int, torch.Tensor]:
74
+ inv_transforms: Dict[int, torch.Tensor] = {}
75
+
76
+ for child in that.children:
77
+ inv_transforms.update(child.accept(_GlobalToLocalAffineVisitor(), args, index))
78
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
79
+
80
+ # Compose parent transform on top of all children
81
+ if that.local_transform is not None:
82
+ m3 = torch.vstack([that.local_transform,
83
+ torch.tensor([0.0, 0.0, 1.0])])
84
+ parent_inv = torch.linalg.inv(m3)
85
+ for inv in inv_transforms.values():
86
+ # inv is 2×3; extend to 3×3, multiply, trim back
87
+ inv3 = torch.vstack([inv, torch.tensor([0.0, 0.0, 1.0])])
88
+ inv[:] = (parent_inv @ inv3)[:-1]
89
+
90
+ return inv_transforms
91
+
92
+
93
+ def _on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams) -> None:
94
+ """Pre-noise hook: warp each cond latent slice by the inverse affine."""
95
+ if not global_state.is_enabled:
96
+ return
97
+
98
+ if not _batch_is_compatible(global_state.prompt_exprs, global_state.batch_cond_indices):
99
+ return
100
+
101
+ for prompt, cond_indices in zip(global_state.prompt_exprs,
102
+ global_state.batch_cond_indices):
103
+ args = _PreNoiseArgs(params.x, cond_indices)
104
+ inv_transforms = prompt.accept(_GlobalToLocalAffineVisitor(), args, 0)
105
+ for cond_index, _ in cond_indices:
106
+ params.x[cond_index] = affine_mod.apply_affine_transform(
107
+ params.x[cond_index], inv_transforms[cond_index]
108
+ )
109
+
110
+
111
+ script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
112
+
113
+
114
+
115
+ def _flat_size(prompt: neutral_prompt_parser.PromptExpr) -> int:
116
+ return prompt.accept(neutral_prompt_parser.FlatSizeVisitor())
117
+
118
+
119
+ def _batch_is_compatible(
120
+ prompts: List[neutral_prompt_parser.PromptExpr],
121
+ batch_cond_indices: List[List[Tuple[int, float]]],
122
+ ) -> bool:
123
+ if len(prompts) != len(batch_cond_indices):
124
+ _console_warn(f'''
125
+ Neutral Prompt batch mismatch:
126
+ prompt_exprs={len(prompts)} vs batch_cond_indices={len(batch_cond_indices)}
127
+ Falling back to original A1111 behavior for this step.
128
+ ''')
129
+ return False
130
+
131
+ for i, (prompt, cond_indices) in enumerate(zip(prompts, batch_cond_indices)):
132
+ need = _flat_size(prompt)
133
+ got = len(cond_indices)
134
+ if need != got:
135
+ _console_warn(f'''
136
+ Neutral Prompt branch mismatch at prompt #{i}:
137
+ expected {need} branches, got {got}
138
+ This usually means another extension replaced prompt parsing hooks.
139
+ Falling back to original A1111 behavior for this step.
140
+ ''')
141
+ return False
142
+
143
+ return True
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # Public entry point
148
+ # ---------------------------------------------------------------------------
149
+
150
+ def combine_denoised_hijack(
151
+ x_out: torch.Tensor,
152
+ batch_cond_indices: List[List[Tuple[int, float]]],
153
+ text_uncond: torch.Tensor,
154
+ cond_scale: float,
155
+ original_function,
156
+ ) -> torch.Tensor:
157
+ if not global_state.is_enabled:
158
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
159
+
160
+ if not _batch_is_compatible(global_state.prompt_exprs, batch_cond_indices):
161
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
162
+
163
+ denoised = _get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
164
+ uncond = x_out[-text_uncond.shape[0]:]
165
+
166
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
167
+ args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
168
+ cond_delta = prompt.accept(_CondDeltaVisitor(), args, 0)
169
+ aux_cond_delta = prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, 0)
170
+
171
+ # Apply per-prompt affine transform to both deltas (affine branch)
172
+ if prompt.local_transform is not None:
173
+ cond_delta = affine_mod.apply_affine_transform(cond_delta, prompt.local_transform)
174
+ aux_cond_delta = affine_mod.apply_affine_transform(aux_cond_delta, prompt.local_transform)
175
+
176
+ cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
177
+ denoised[batch_i] = _cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
178
+
179
+ return denoised
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Internal helpers
184
+ # ---------------------------------------------------------------------------
185
+
186
+ def _get_webui_denoised(
187
+ x_out: torch.Tensor,
188
+ batch_cond_indices: List[List[Tuple[int, float]]],
189
+ text_uncond: torch.Tensor,
190
+ cond_scale: float,
191
+ original_function,
192
+ ) -> torch.Tensor:
193
+ uncond = x_out[-text_uncond.shape[0]:]
194
+ sliced_batch_x_out: List[torch.Tensor] = []
195
+ sliced_batch_cond_indices: List[List[Tuple[int, float]]] = []
196
+
197
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
198
+ args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
199
+ sliced_x_out, sliced_indices = _gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
200
+ if sliced_indices:
201
+ sliced_batch_cond_indices.append(sliced_indices)
202
+ sliced_batch_x_out.extend(sliced_x_out)
203
+
204
+ sliced_batch_x_out += list(uncond)
205
+ return original_function(
206
+ torch.stack(sliced_batch_x_out, dim=0),
207
+ sliced_batch_cond_indices,
208
+ text_uncond,
209
+ cond_scale,
210
+ )
211
+
212
+
213
+ def _cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
214
+ """
215
+ Mean-preserving CFG rescale (main branch formula).
216
+ Override is applied first so XYZ-grid / external API overrides are never silently skipped.
217
+ Also stores the computed rescale factor in CFGRescaleFactorSingleton.
218
+ """
219
+ # Clear last step's factor so get() returns None if rescaling is skipped
220
+ # this step (stale value bug fix).
221
+ global_state.CFGRescaleFactorSingleton.clear()
222
+
223
+ # Apply one-shot override BEFORE the early-exit check, otherwise a 0→nonzero
224
+ # override (from the external API) would be silently discarded.
225
+ global_state.apply_and_clear_cfg_rescale_override()
226
+
227
+ if global_state.cfg_rescale == 0:
228
+ return cfg_cond
229
+
230
+ cfg_std = cfg_cond.std()
231
+ if cfg_std == 0:
232
+ # Degenerate case: constant tensor – rescaling is a no-op.
233
+ return cfg_cond
234
+
235
+ cfg_cond_mean = cfg_cond.mean()
236
+ rescale_mean = (
237
+ (1 - global_state.cfg_rescale) * cfg_cond_mean
238
+ + global_state.cfg_rescale * cond.mean()
239
+ )
240
+ rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_std - 1) + 1
241
+
242
+ # Export for external consumers (export_rescale_factor branch)
243
+ global_state.CFGRescaleFactorSingleton.set(
244
+ rescale_factor.item() if isinstance(rescale_factor, torch.Tensor) else float(rescale_factor)
245
+ )
246
+
247
+ return rescale_mean + (cfg_cond - cfg_cond_mean) * rescale_factor
248
+
249
+
250
+ @dataclasses.dataclass
251
+ class _DenoiseArgs:
252
+ x_out: torch.Tensor
253
+ uncond: torch.Tensor
254
+ cond_indices: List[Tuple[int, float]]
255
+
256
+
257
+ # ---------------------------------------------------------------------------
258
+ # Gather webui-style conditions (needed for the webui's own CFG path)
259
+ # ---------------------------------------------------------------------------
260
+
261
+ def _gather_webui_conds(
262
+ prompt: neutral_prompt_parser.CompositePrompt,
263
+ args: _DenoiseArgs,
264
+ index_in: int,
265
+ index_out: int,
266
+ ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
267
+ sliced_x_out: List[torch.Tensor] = []
268
+ sliced_cond_indices: List[Tuple[int, float]] = []
269
+
270
+ for child in prompt.children:
271
+ if child.conciliation is None:
272
+ if isinstance(child, neutral_prompt_parser.LeafPrompt) and child.local_transform is None:
273
+ child_x_out = args.x_out[args.cond_indices[index_in][0]]
274
+ child_weight = child.weight
275
+ else:
276
+ child_x_out, child_weight = _get_cond_delta_and_weight(child, args, index_in)
277
+ child_x_out = child_x_out + args.uncond
278
+
279
+ index_offset = index_out + len(sliced_x_out)
280
+ sliced_x_out.append(child_x_out)
281
+ sliced_cond_indices.append((index_offset, child_weight))
282
+
283
+ index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
284
+
285
+ return sliced_x_out, sliced_cond_indices
286
+
287
+
288
+ def _get_cond_delta_and_weight(
289
+ prompt: neutral_prompt_parser.PromptExpr,
290
+ args: _DenoiseArgs,
291
+ index: int,
292
+ ) -> Tuple[torch.Tensor, float]:
293
+ """Compute cond delta and effective weight, applying affine transform if present."""
294
+ cond_delta = prompt.accept(_CondDeltaVisitor(), args, index)
295
+ cond_delta = cond_delta + prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, index)
296
+ weight = prompt.weight
297
+
298
+ if prompt.local_transform is not None:
299
+ transformed, weight_tensor = affine_mod.apply_masked_transform(
300
+ cond_delta + args.uncond,
301
+ prompt.local_transform,
302
+ prompt.weight,
303
+ )
304
+ cond_delta = transformed - args.uncond
305
+ weight = weight_tensor # type: ignore[assignment]
306
+
307
+ return cond_delta, weight
308
+
309
+
310
+ # ---------------------------------------------------------------------------
311
+ # Visitor: CondDelta – weighted sum of leaf cond − uncond
312
+ # ---------------------------------------------------------------------------
313
+
314
+ class _CondDeltaVisitor:
315
+ def visit_leaf_prompt(
316
+ self,
317
+ that: neutral_prompt_parser.LeafPrompt,
318
+ args: _DenoiseArgs,
319
+ index: int,
320
+ ) -> torch.Tensor:
321
+ cond_info = args.cond_indices[index]
322
+ if that.weight != cond_info[1]:
323
+ _console_warn(f'''
324
+ An unexpected noise weight was encountered at prompt #{index}
325
+ Expected :{that.weight}, but got :{cond_info[1]}
326
+ This is likely due to another extension also monkey patching `combine_denoised`.
327
+ Please open a bug report: https://github.com/ljleb/sd-webui-neutral-prompt/issues
328
+ ''')
329
+ return args.x_out[cond_info[0]] - args.uncond
330
+
331
+ def visit_composite_prompt(
332
+ self,
333
+ that: neutral_prompt_parser.CompositePrompt,
334
+ args: _DenoiseArgs,
335
+ index: int,
336
+ ) -> torch.Tensor:
337
+ cond_delta = torch.zeros_like(args.x_out[0])
338
+ for child in that.children:
339
+ if child.conciliation is None:
340
+ child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
341
+ cond_delta = cond_delta + child_weight * child_delta
342
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
343
+ return cond_delta
344
+
345
+
346
+ # ---------------------------------------------------------------------------
347
+ # Visitor: AuxCondDelta – all conciliation strategies
348
+ # ---------------------------------------------------------------------------
349
+
350
+ class _AuxCondDeltaVisitor:
351
+ def visit_leaf_prompt(
352
+ self,
353
+ that: neutral_prompt_parser.LeafPrompt,
354
+ args: _DenoiseArgs,
355
+ cond_delta: torch.Tensor,
356
+ index: int,
357
+ ) -> torch.Tensor:
358
+ return torch.zeros_like(args.x_out[0])
359
+
360
+ def visit_composite_prompt(
361
+ self,
362
+ that: neutral_prompt_parser.CompositePrompt,
363
+ args: _DenoiseArgs,
364
+ cond_delta: torch.Tensor,
365
+ index: int,
366
+ ) -> torch.Tensor:
367
+ aux_cond_delta = torch.zeros_like(args.x_out[0])
368
+ salient_cond_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT k=20
369
+ salient_wide_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT_WIDE k=1
370
+ salient_blob_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT_BLOB k=20 + morphology
371
+ align_blend_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
372
+ mask_align_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
373
+
374
+ for child in that.children:
375
+ if child.conciliation is not None:
376
+ child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
377
+ strat = child.conciliation
378
+
379
+ if strat == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
380
+ aux_cond_delta = aux_cond_delta + child_weight * _get_perpendicular_component(cond_delta, child_delta)
381
+
382
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
383
+ salient_cond_deltas.append((child_delta, child_weight))
384
+
385
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK_WIDE:
386
+ salient_wide_deltas.append((child_delta, child_weight))
387
+
388
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK_BLOB:
389
+ salient_blob_deltas.append((child_delta, child_weight))
390
+
391
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SEMANTIC_GUIDANCE:
392
+ aux_cond_delta = aux_cond_delta + child_weight * _filter_abs_top_k(child_delta, 0.05)
393
+
394
+ else:
395
+ # AND_ALIGN_D_S (soft alignment blend)
396
+ m = re.match(r'AND_ALIGN_(\d+)_(\d+)', strat.value)
397
+ if m:
398
+ align_blend_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
399
+ else:
400
+ # AND_MASK_ALIGN_D_S (binary alignment mask)
401
+ m = re.match(r'AND_MASK_ALIGN_(\d+)_(\d+)', strat.value)
402
+ if m:
403
+ mask_align_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
404
+
405
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
406
+
407
+ aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_cond_deltas, child_k=20)
408
+ aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_wide_deltas, child_k=1)
409
+ aux_cond_delta = aux_cond_delta + _salient_blend_blob(cond_delta, salient_blob_deltas)
410
+ aux_cond_delta = aux_cond_delta + _alignment_blend(cond_delta, align_blend_deltas)
411
+ aux_cond_delta = aux_cond_delta + _alignment_mask_blend(cond_delta, mask_align_deltas)
412
+ return aux_cond_delta
413
+
414
+
415
+ # ---------------------------------------------------------------------------
416
+ # Strategy implementations
417
+ # ---------------------------------------------------------------------------
418
+
419
+ def _get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
420
+ if (normal == 0).all():
421
+ if shared.state.sampling_step <= 0:
422
+ _warn_projection_not_found()
423
+ return vector
424
+ return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
425
+
426
+
427
+ def _salient_blend(
428
+ normal: torch.Tensor,
429
+ vectors: List[Tuple[torch.Tensor, float]],
430
+ child_k: float = 20.0,
431
+ ) -> torch.Tensor:
432
+ """
433
+ Saliency-guided blend: each child prompt wins in the latent regions
434
+ where its absolute activation magnitude is strongest.
435
+
436
+ child_k controls how sharp/selective the child salience mask is:
437
+ child_k=20 → very sharp, 1-2 peak pixels (AND_SALT — life-branch style)
438
+ child_k=1 → broad, ~55% of pixels (AND_SALT_WIDE — original main-branch)
439
+ Parent always uses k=1 (diffuse reference).
440
+ """
441
+ if not vectors:
442
+ return torch.zeros_like(normal)
443
+
444
+ salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=child_k) for v, _ in vectors]
445
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
446
+
447
+ result = torch.zeros_like(normal)
448
+ for mask_i, (vector, weight) in enumerate(vectors, start=1):
449
+ vector_mask = (mask == mask_i).float()
450
+ result = result + weight * vector_mask * (vector - normal)
451
+ return result
452
+
453
+
454
+ def _salient_blend_blob(
455
+ normal: torch.Tensor,
456
+ vectors: List[Tuple[torch.Tensor, float]],
457
+ ) -> torch.Tensor:
458
+ """
459
+ AND_SALT_BLOB: life-branch dev.py algorithm (cleaned of debug scaffolding).
460
+
461
+ Pipeline per child:
462
+ 1. k=20 softmax → very sharp salience seed (1-2 pixels)
463
+ 2. _life_erode ×6 → erode to spatially dense core
464
+ 3. _life_thickify ×2 → grow outward into a smooth blob
465
+ 4. result += weight × blob_mask × (child − parent)
466
+
467
+ This is what the life-branch author was iterating toward in dev.py.
468
+ """
469
+ if not vectors:
470
+ return torch.zeros_like(normal)
471
+
472
+ salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=20) for v, _ in vectors]
473
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
474
+
475
+ result = torch.zeros_like(normal)
476
+ for mask_i, (vector, weight) in enumerate(vectors, start=1):
477
+ vector_mask = (mask == mask_i).float()
478
+ for _ in range(6):
479
+ vector_mask = _life_step(vector_mask, _erode_rule)
480
+ for _ in range(2):
481
+ vector_mask = _life_step(vector_mask, _thickify_rule)
482
+ result = result + weight * vector_mask * (vector - normal)
483
+ return result
484
+
485
+
486
+ def _life_step(board: torch.Tensor, rule) -> torch.Tensor:
487
+ """
488
+ One step of a cellular-automaton morphology on a C×H×W binary mask.
489
+
490
+ The conv3d trick from dev.py: concatenate [board, board[:-1]] along the
491
+ channel axis so a single 3-D convolution simultaneously sums the 3×3
492
+ spatial neighbourhood across *all* C channels. This keeps the operation
493
+ GPU-friendly and avoids an explicit loop over channels.
494
+
495
+ neighbors[c,h,w] = Σ_{dc,dh,dw} board_padded[c+dc, h+dh, w+dw]
496
+ minus the center pixel itself (board is subtracted).
497
+ """
498
+ C = board.shape[0]
499
+ kernel = torch.ones((C, 3, 3), dtype=board.dtype, device=board.device)
500
+ kernel = kernel.unsqueeze(0).unsqueeze(0) # [1, 1, C, 3, 3]
501
+
502
+ # Pad spatially (left/right/top/bottom) but not along channel axis
503
+ padded = torch.cat([board.clone(), board[:-1].clone()], dim=0) # [2C-1, H, W]
504
+ padded = torch.nn.functional.pad(padded, (1, 1, 1, 1, 0, 0), value=0) # [2C-1, H+2, W+2]
505
+
506
+ neighbors = torch.nn.functional.conv3d(
507
+ padded.unsqueeze(0).unsqueeze(0), # [1, 1, 2C-1, H+2, W+2]
508
+ kernel, # [1, 1, C, 3, 3]
509
+ padding=0,
510
+ ).squeeze(0).squeeze(0) # [C, H, W]
511
+
512
+ neighbors = neighbors - board # subtract center pixel
513
+ return rule(board, neighbors).float()
514
+
515
+
516
+ def _erode_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
517
+ """Keep a pixel only if it is set AND its C-channel neighbourhood is dense.
518
+ Threshold C*5 matches dev.py: for C=4 that means ≥20 out of 36 possible."""
519
+ C = board.shape[0]
520
+ return (board == 1) & (neighbors >= C * 5)
521
+
522
+
523
+ def _thickify_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
524
+ """Grow: keep existing pixels OR add any pixel adjacent to the core.
525
+ population ≥ 4 ensures only pixels touching at least one set neighbour grow."""
526
+ population = board + neighbors
527
+ return (board == 1) | (population >= 4)
528
+
529
+
530
+ def _get_salience(vector: torch.Tensor, k: float = 1.0) -> torch.Tensor:
531
+ """Softmax-based salience map. k > 1 → sharper, more selective mask."""
532
+ return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)
533
+
534
+
535
+ def _filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
536
+ """Keep only the top k_ratio fraction of activations by absolute value."""
537
+ k = int(torch.numel(vector) * (1 - k_ratio))
538
+ k = max(1, k) # kthvalue requires k >= 1
539
+ threshold, _ = torch.kthvalue(torch.abs(vector.flatten()), k)
540
+ return vector * (torch.abs(vector) >= threshold).to(vector.dtype)
541
+
542
+
543
+ # ---------------------------------------------------------------------------
544
+ # Alignment blend (alignment_blend branch)
545
+ # ---------------------------------------------------------------------------
546
+
547
+ def _compute_subregion_similarity_map(
548
+ child_vector: torch.Tensor,
549
+ parent_vector: torch.Tensor,
550
+ region_size: int = 2,
551
+ ) -> torch.Tensor:
552
+ """
553
+ Compute local average cosine similarity of (region_size × region_size) regions.
554
+ Returns a map of shape [C, H, W] with values in [-1, 1].
555
+ """
556
+ C, H, W = child_vector.shape
557
+ parent = parent_vector.unsqueeze(0) # [1, C, H, W]
558
+ child = child_vector.unsqueeze(0)
559
+
560
+ region_radius = region_size // 2
561
+ if region_size % 2 == 1:
562
+ pad_size = (region_radius,) * 4
563
+ else:
564
+ pad_size = (region_radius - 1, region_radius) * 2
565
+
566
+ parent_reg = F.unfold(F.pad(parent, pad_size, 'constant', 0), kernel_size=region_size)
567
+ child_reg = F.unfold(F.pad(child, pad_size, 'constant', 0), kernel_size=region_size)
568
+
569
+ # [H*W, C, region_size, region_size]
570
+ # .contiguous() is required before .view() because .permute() produces a
571
+ # non-contiguous tensor and .view() raises RuntimeError on non-contiguous input.
572
+ parent_reg = parent_reg.view(1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
573
+ child_reg = child_reg.view( 1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
574
+
575
+ unfold2 = torch.nn.Unfold(kernel_size=2)
576
+ parent_sub = unfold2(parent_reg).view(H*W, C, 4, (region_size - 1)**2)
577
+ child_sub = unfold2(child_reg ).view(H*W, C, 4, (region_size - 1)**2)
578
+
579
+ parent_sub = F.normalize(parent_sub, p=2, dim=2)
580
+ child_sub = F.normalize(child_sub, p=2, dim=2)
581
+ sim = (parent_sub * child_sub).sum(dim=2) # [H*W, C, (r-1)^2]
582
+ return sim.mean(dim=2).permute(1, 0).contiguous().view(C, H, W) # [C, H, W]
583
+
584
+
585
+ def _alignment_blend(
586
+ parent: torch.Tensor,
587
+ children: List[Tuple[torch.Tensor, float, int, int]],
588
+ ) -> torch.Tensor:
589
+ """
590
+ Soft alignment blend (AND_ALIGN_D_S).
591
+ Child contribution is weighted by max(0, structure_alignment − detail_alignment).
592
+ High weight where child changes detail without breaking structure.
593
+ """
594
+ result = torch.zeros_like(parent)
595
+ for child, weight, detail_size, structure_size in children:
596
+ detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
597
+ structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
598
+
599
+ # Normalise by absolute max so that all-negative maps (anti-correlated prompts
600
+ # such as 'black' vs 'white') don't blow up to ±1e7 and clamp incorrectly to 1.
601
+ # Dividing by positive max would invert the sign ordering when all values are
602
+ # negative; dividing by abs-max preserves relative ordering in both cases.
603
+ d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
604
+ s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
605
+ detail_sim = detail_sim / d_abs_max
606
+ structure_sim = structure_sim / s_abs_max
607
+
608
+ alignment_weight = torch.clamp(structure_sim - detail_sim, min=0.0, max=1.0)
609
+ result = result + (child - parent) * weight * alignment_weight
610
+ return result
611
+
612
+
613
+ def _alignment_mask_blend(
614
+ parent: torch.Tensor,
615
+ children: List[Tuple[torch.Tensor, float, int, int]],
616
+ ) -> torch.Tensor:
617
+ """
618
+ Binary alignment mask (AND_MASK_ALIGN_D_S).
619
+ Child receives full weight where structure_alignment > detail_alignment, zero elsewhere.
620
+ """
621
+ result = torch.zeros_like(parent)
622
+ for child, weight, detail_size, structure_size in children:
623
+ detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
624
+ structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
625
+
626
+ d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
627
+ s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
628
+ detail_sim = detail_sim / d_abs_max
629
+ structure_sim = structure_sim / s_abs_max
630
+
631
+ alignment_mask = (structure_sim > detail_sim).to(child)
632
+ result = result + (child - parent) * weight * alignment_mask
633
+ return result
634
+
635
+
636
+ # ---------------------------------------------------------------------------
637
+ # Sampler hijack
638
+ # ---------------------------------------------------------------------------
639
+
640
+ sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
641
+ module=sd_samplers,
642
+ hijacker_attribute='__neutral_prompt_hijacker',
643
+ on_uninstall=script_callbacks.on_script_unloaded,
644
+ )
645
+
646
+
647
+ @sd_samplers_hijacker.hijack('create_sampler')
648
+ def create_sampler_hijack(name: str, model, original_function):
649
+ sampler = original_function(name, model)
650
+ if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
651
+ if global_state.is_enabled:
652
+ _warn_unsupported_sampler()
653
+ return sampler
654
+
655
+ sampler.model_wrap_cfg.combine_denoised = functools.partial(
656
+ combine_denoised_hijack,
657
+ original_function=sampler.model_wrap_cfg.combine_denoised,
658
+ )
659
+ return sampler
660
+
661
+
662
+ # ---------------------------------------------------------------------------
663
+ # Warnings / logging
664
+ # ---------------------------------------------------------------------------
665
+
666
+ def _warn_unsupported_sampler() -> None:
667
+ _console_warn('''
668
+ Neutral prompt relies on composition via AND, which the webui does not support
669
+ when using any of the DDIM, PLMS and UniPC samplers.
670
+ The sampler will NOT be patched – falling back on the original implementation.
671
+ ''')
672
+
673
+
674
+ def _warn_projection_not_found() -> None:
675
+ _console_warn('''
676
+ Could not find a projection for one or more AND_PERP prompts.
677
+ These prompts will NOT be made perpendicular.
678
+ ''')
679
+
680
+
681
+ def _console_warn(message: str) -> None:
682
+ if not global_state.verbose:
683
+ return
684
+ print(f'\n[sd-webui-neutral-prompt]{textwrap.dedent(message)}', file=sys.stderr)
neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+
4
+ @contextlib.contextmanager
5
+ def fix_path():
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ extension_path = str(Path(__file__).parent.parent.parent)
10
+ added = False
11
+ if extension_path not in sys.path:
12
+ sys.path.insert(0, extension_path)
13
+ added = True
14
+
15
+ yield
16
+
17
+ if added:
18
+ sys.path.remove(extension_path)
19
+
20
+
21
+ with fix_path():
22
+ del fix_path, contextlib
23
+ from .api import *
neutral_prompt_patched/lib_neutral_prompt/external_code/api.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ External API for sd-webui-neutral-prompt.
3
+
4
+ Provides thin helpers that external extensions / scripts can import via:
5
+
6
+ from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ from lib_neutral_prompt import global_state
12
+
13
+
14
+ def override_cfg_rescale(cfg_rescale: float) -> None:
15
+ """
16
+ Override the CFG rescale value for the *next* generation step only.
17
+ After the step runs the value is automatically cleared.
18
+ """
19
+ global_state.cfg_rescale_override = cfg_rescale
20
+
21
+
22
+ def get_last_cfg_rescale_factor() -> Optional[float]:
23
+ """
24
+ Return the CFG rescale factor computed during the most recent denoising
25
+ step, or None if rescaling was not active.
26
+ """
27
+ return global_state.CFGRescaleFactorSingleton.get()
neutral_prompt_patched/lib_neutral_prompt/global_state.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Global mutable state shared across the extension.
3
+
4
+ Combines:
5
+ - is_enabled / prompt_exprs / verbose (all branches)
6
+ - cfg_rescale + cfg_rescale_override mechanism (main branch)
7
+ - batch_cond_indices (affine branch – needed by on_cfg_denoiser pre-noise hook)
8
+ - CFGRescaleFactorSingleton – exposes the last computed rescale factor
9
+ to external callers / API users (export_rescale_factor branch)
10
+ """
11
+
12
+ import threading
13
+ from typing import List, Optional, Tuple
14
+ from lib_neutral_prompt import neutral_prompt_parser
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Runtime state
19
+ # ---------------------------------------------------------------------------
20
+
21
+ is_enabled: bool = False
22
+ prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []
23
+ # Populated by reconstruct_multicond_batch hook; used by the pre-noise
24
+ # affine transform hook (on_cfg_denoiser) to know which latent slices
25
+ # correspond to which prompt child.
26
+ batch_cond_indices: List[List[Tuple[int, float]]] = []
27
+ cfg_rescale: float = 0.0
28
+ verbose: bool = False # matches UI default; UI syncs this on settings load
29
+
30
+ # Set to a float value by XYZ-grid or external API to override cfg_rescale
31
+ # for a single generation step, then auto-cleared.
32
+ cfg_rescale_override: Optional[float] = None
33
+
34
+
35
+ def apply_and_clear_cfg_rescale_override() -> None:
36
+ """Apply a one-shot cfg_rescale override and immediately clear it."""
37
+ global cfg_rescale, cfg_rescale_override
38
+ if cfg_rescale_override is not None:
39
+ cfg_rescale = cfg_rescale_override
40
+ cfg_rescale_override = None
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # CFG Rescale Factor Singleton (export_rescale_factor branch)
45
+ #
46
+ # Stores the last computed rescale factor so that external tools / scripts
47
+ # can read it without re-deriving it.
48
+ #
49
+ # Uses threading.local() so concurrent API workers each get their own slot.
50
+ # clear() must be called at the start of every _cfg_rescale() call so that
51
+ # get() correctly returns None when rescaling was skipped this step.
52
+ # ---------------------------------------------------------------------------
53
+
54
+ class CFGRescaleFactorSingleton:
55
+ """Thread-local store for the most recently computed CFG rescale factor.
56
+
57
+ Lifecycle per denoising step:
58
+ 1. _cfg_rescale() calls clear() at entry — value becomes None.
59
+ 2. If rescaling is active, set() stores the computed factor.
60
+ 3. External code calls get() after the step; receives the factor or None.
61
+ """
62
+
63
+ _state = threading.local()
64
+
65
+ @classmethod
66
+ def set(cls, value: float) -> None:
67
+ cls._state.cfg_rescale_factor = value
68
+
69
+ @classmethod
70
+ def get(cls) -> Optional[float]:
71
+ return getattr(cls._state, 'cfg_rescale_factor', None)
72
+
73
+ @classmethod
74
+ def clear(cls) -> None:
75
+ cls._state.cfg_rescale_factor = None
neutral_prompt_patched/lib_neutral_prompt/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+
4
+ class ModuleHijacker:
5
+ def __init__(self, module):
6
+ self.__module = module
7
+ self.__original_functions = dict()
8
+
9
+ def hijack(self, attribute):
10
+ if attribute not in self.__original_functions:
11
+ self.__original_functions[attribute] = getattr(self.__module, attribute)
12
+
13
+ def decorator(function):
14
+ setattr(self.__module, attribute, functools.partial(function, original_function=self.__original_functions[attribute]))
15
+ return function
16
+
17
+ return decorator
18
+
19
+ def reset_module(self):
20
+ for attribute, original_function in self.__original_functions.items():
21
+ setattr(self.__module, attribute, original_function)
22
+
23
+ self.__original_functions.clear()
24
+
25
+ @staticmethod
26
+ def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
27
+ if not hasattr(module, hijacker_attribute):
28
+ module_hijacker = ModuleHijacker(module)
29
+ setattr(module, hijacker_attribute, module_hijacker)
30
+ on_uninstall(lambda: delattr(module, hijacker_attribute))
31
+ on_uninstall(module_hijacker.reset_module)
32
+ return module_hijacker
33
+ else:
34
+ return getattr(module, hijacker_attribute)
neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Neutral Prompt Parser
3
+ Combines:
4
+ - Core AND / AND_PERP / AND_SALT / AND_TOPK strategies (main branch)
5
+ - Affine spatial transforms: ROTATE / SLIDE / SCALE / SHEAR (affine branch)
6
+ - Local alignment blend: AND_ALIGN_D_S / AND_MASK_ALIGN_D_S (alignment_blend branch)
7
+ """
8
+
9
+ import abc
10
+ import dataclasses
11
+ import math
12
+ import re
13
+ from enum import Enum
14
+ from itertools import product
15
+ from typing import Any, List, Optional, Tuple
16
+
17
+ import torch
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Keyword registry
22
+ # ---------------------------------------------------------------------------
23
+
24
+ # Base conciliation keywords
25
+ _BASE_KEYWORD_MAP = {
26
+ 'AND_PERP': 'PERPENDICULAR',
27
+ 'AND_SALT': 'SALIENCE_MASK', # k=20 sharp mask (life-branch style)
28
+ 'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE', # k=1 broad mask (original main-branch behaviour)
29
+ 'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB', # k=20 + erode + thickify → smooth blob (dev.py intent)
30
+ 'AND_TOPK': 'SEMANTIC_GUIDANCE',
31
+ }
32
+
33
+ # Alignment blend: AND_ALIGN_D_S (soft weight, structure > detail)
34
+ _ALIGN_KEYWORD_MAP = {
35
+ f'AND_ALIGN_{i}_{j}': f'ALIGNMENT_BLEND_{i}_{j}'
36
+ for i, j in product(range(2, 33), repeat=2) if i != j
37
+ }
38
+
39
+ # Alignment mask: AND_MASK_ALIGN_D_S (binary mask, structure > detail)
40
+ _MASK_ALIGN_KEYWORD_MAP = {
41
+ f'AND_MASK_ALIGN_{i}_{j}': f'ALIGNMENT_MASK_BLEND_{i}_{j}'
42
+ for i, j in product(range(2, 33), repeat=2) if i != j
43
+ }
44
+
45
+ keyword_mapping = _BASE_KEYWORD_MAP | _ALIGN_KEYWORD_MAP | _MASK_ALIGN_KEYWORD_MAP
46
+
47
+ PromptKeyword = Enum('PromptKeyword', {'AND': 'AND', **{k: k for k in keyword_mapping}})
48
+ ConciliationStrategy = Enum('ConciliationStrategy', {v: k for k, v in keyword_mapping.items()})
49
+
50
+ # Lists kept for ordered iteration (tokenize uses list order for regex priority)
51
+ prompt_keywords = [e.value for e in PromptKeyword]
52
+ conciliation_strategies = [e.value for e in ConciliationStrategy]
53
+
54
+ # Sets for O(1) membership checks in parser hot-loops (1864-item list = 230x slower)
55
+ _prompt_keywords_set = frozenset(prompt_keywords)
56
+ _conciliation_strategies_set = frozenset(conciliation_strategies)
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Affine transform definitions (from affine branch)
61
+ # ---------------------------------------------------------------------------
62
+
63
+ affine_transforms = {
64
+ 'ROTATE': lambda t, angle=0, *_: t @ torch.tensor([
65
+ [math.cos(angle * 2 * math.pi), -math.sin(angle * 2 * math.pi), 0],
66
+ [math.sin(angle * 2 * math.pi), math.cos(angle * 2 * math.pi), 0],
67
+ [0, 0, 1],
68
+ ], dtype=torch.float32),
69
+ 'SLIDE': lambda t, x=0, y=0, *_: t @ torch.tensor([
70
+ [1, 0, float(x)],
71
+ [0, 1, float(y)],
72
+ [0, 0, 1],
73
+ ], dtype=torch.float32),
74
+ 'SCALE': lambda t, x=1, y=None, *_: t @ torch.tensor([
75
+ [float(x), 0, 0],
76
+ [0, float(y) if y is not None else float(x), 0],
77
+ [0, 0, 1],
78
+ ], dtype=torch.float32),
79
+ 'SHEAR': lambda t, x=0, y=None, *_: t @ torch.tensor([
80
+ [1, math.tan(float(x) * 2 * math.pi), 0],
81
+ [math.tan(float(y if y is not None else x) * 2 * math.pi), 1, 0],
82
+ [0, 0, 1],
83
+ ], dtype=torch.float32),
84
+ }
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # AST node types
89
+ # ---------------------------------------------------------------------------
90
+
91
+ @dataclasses.dataclass
92
+ class PromptExpr(abc.ABC):
93
+ weight: float
94
+ conciliation: Optional[ConciliationStrategy]
95
+ local_transform: Optional[torch.Tensor] # 2×3 affine tensor or None
96
+
97
+ @abc.abstractmethod
98
+ def accept(self, visitor, *args, **kwargs) -> Any:
99
+ pass
100
+
101
+
102
+ @dataclasses.dataclass
103
+ class LeafPrompt(PromptExpr):
104
+ prompt: str
105
+
106
+ def accept(self, visitor, *args, **kwargs):
107
+ return visitor.visit_leaf_prompt(self, *args, **kwargs)
108
+
109
+
110
+ @dataclasses.dataclass
111
+ class CompositePrompt(PromptExpr):
112
+ children: List[PromptExpr]
113
+
114
+ def accept(self, visitor, *args, **kwargs):
115
+ return visitor.visit_composite_prompt(self, *args, **kwargs)
116
+
117
+
118
+ class FlatSizeVisitor:
119
+ def visit_leaf_prompt(self, that: LeafPrompt) -> int:
120
+ return 1
121
+
122
+ def visit_composite_prompt(self, that: CompositePrompt) -> int:
123
+ return sum(child.accept(self) for child in that.children) if that.children else 0
124
+
125
+
126
+ # ---------------------------------------------------------------------------
127
+ # Parser
128
+ # ---------------------------------------------------------------------------
129
+
130
+ def parse_root(string: str) -> CompositePrompt:
131
+ tokens = tokenize(string)
132
+ prompts = parse_prompts(tokens)
133
+ return CompositePrompt(1.0, None, None, prompts)
134
+
135
+
136
+ def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
137
+ prompts = [parse_prompt(tokens, first=True, nested=nested)]
138
+ while tokens:
139
+ if nested and tokens[0] == ']':
140
+ break
141
+ prompts.append(parse_prompt(tokens, first=False, nested=nested))
142
+ return prompts
143
+
144
+
145
+ def _compose_affine(
146
+ a: Optional[torch.Tensor],
147
+ b: Optional[torch.Tensor],
148
+ ) -> Optional[torch.Tensor]:
149
+ """
150
+ Compose two 2×3 affine matrices (a applied first, then b).
151
+ Returns None if both are None; returns the non-None one if only one exists.
152
+ """
153
+ if a is None:
154
+ return b
155
+ if b is None:
156
+ return a
157
+ # Lift to 3×3, multiply, trim back to 2×3
158
+ a3 = torch.vstack([a, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
159
+ b3 = torch.vstack([b, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
160
+ return (b3 @ a3)[:-1]
161
+
162
+
163
+ def _looks_like_leading_affine_prompt(tokens: List[str]) -> bool:
164
+ """
165
+ Return True if `tokens` starts with one or more affine keywords followed
166
+ immediately by a prompt keyword, e.g.:
167
+ ROTATE[0.125] AND_PERP ...
168
+ SLIDE[0.05,0] AND_SALT ...
169
+ ROTATE[0.125] SLIDE[0.05,0] AND_PERP ...
170
+
171
+ This is a pure structural scan — it does NOT compute any affine matrices,
172
+ so it is safe to call as a look-ahead probe without mock-torch issues.
173
+
174
+ Used in two places:
175
+ - parse_prompt() → consume leading affine before keyword
176
+ - _parse_prompt_text() → stop current segment before this boundary
177
+ """
178
+ pos = 0
179
+ n = len(tokens)
180
+ # skip a leading whitespace-only token
181
+ if pos < n and not tokens[pos].strip():
182
+ pos += 1
183
+ # must start with an affine keyword
184
+ if pos >= n or tokens[pos] not in affine_transforms:
185
+ return False
186
+ # consume one or more KEYWORD [ args ] blocks
187
+ while pos < n and tokens[pos] in affine_transforms:
188
+ pos += 1 # skip keyword
189
+ if pos >= n or tokens[pos] != '[':
190
+ return False
191
+ pos += 1 # skip '['
192
+ if pos < n and tokens[pos] != ']':
193
+ pos += 1 # skip args (may contain commas)
194
+ if pos >= n or tokens[pos] != ']':
195
+ return False
196
+ pos += 1 # skip ']'
197
+ if pos < n and not tokens[pos].strip():
198
+ pos += 1 # skip optional whitespace
199
+ # the very next token must be a prompt keyword
200
+ return pos < n and tokens[pos] in _prompt_keywords_set
201
+
202
+
203
+ def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
204
+ # Support two affine-keyword orderings:
205
+ # (A) AND_PERP ROTATE[0.25] text ← trailing affine (original syntax)
206
+ # (B) ROTATE[0.25] AND_PERP text ← leading affine (documented syntax)
207
+ #
208
+ # Note: we check _looks_like_leading_affine_prompt unconditionally (not
209
+ # gated on `not first`) so that a prompt starting with e.g.
210
+ # "ROTATE[0.125] AND_PERP text" is handled correctly even as the first
211
+ # segment.
212
+ leading_affine: Optional[torch.Tensor] = None
213
+ if _looks_like_leading_affine_prompt(tokens):
214
+ probe = tokens.copy()
215
+ leading_affine = _parse_affine_transform(probe)
216
+ tokens[:] = probe
217
+
218
+ # After consuming a leading affine the next token is a keyword; accept it.
219
+ # We also accept a keyword as the very first token (first=True) without a
220
+ # leading affine — e.g. "AND_PERP vivid" at the start of a string should
221
+ # produce a single PERP child, not an empty default child followed by a PERP
222
+ # child. The original `not first` guard was an artefact of earlier grammar;
223
+ # all existing tests pass without it.
224
+ if tokens and tokens[0] in _prompt_keywords_set:
225
+ prompt_type = tokens.pop(0)
226
+ else:
227
+ prompt_type = PromptKeyword.AND.value
228
+
229
+ conciliation = ConciliationStrategy(prompt_type) if prompt_type in _conciliation_strategies_set else None
230
+
231
+ # Parse optional trailing affine transform chain (original syntax)
232
+ trailing_affine = _parse_affine_transform(tokens)
233
+ affine_transform = _compose_affine(leading_affine, trailing_affine)
234
+
235
+ # Try composite (bracketed) form
236
+ tokens_copy = tokens.copy()
237
+ if tokens_copy and tokens_copy[0] == '[':
238
+ tokens_copy.pop(0)
239
+ prompts = parse_prompts(tokens_copy, nested=True)
240
+ if tokens_copy:
241
+ assert tokens_copy.pop(0) == ']'
242
+ if len(prompts) > 1:
243
+ tokens[:] = tokens_copy
244
+ weight = _parse_weight(tokens)
245
+ return CompositePrompt(weight, conciliation, affine_transform, prompts)
246
+
247
+ prompt_text, weight = _parse_prompt_text(tokens, nested=nested)
248
+ return LeafPrompt(weight, conciliation, affine_transform, prompt_text)
249
+
250
+
251
+ def _parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
252
+ text = ''
253
+ depth = 0
254
+ weight = 1.0
255
+ while tokens:
256
+ tok = tokens[0]
257
+ if tok == ']':
258
+ if depth == 0:
259
+ if nested:
260
+ break
261
+ else:
262
+ depth -= 1
263
+ elif tok == '[':
264
+ depth += 1
265
+ elif tok == ':':
266
+ if len(tokens) >= 2 and _is_float(tokens[1].strip()):
267
+ if len(tokens) < 3 or tokens[2] in _prompt_keywords_set or (tokens[2] == ']' and depth == 0):
268
+ tokens.pop(0)
269
+ weight = float(tokens.pop(0).strip())
270
+ break
271
+ elif tok in _prompt_keywords_set:
272
+ break
273
+ # Stop before a leading-affine-then-keyword boundary, e.g.
274
+ # ROTATE[0.125] AND_PERP text
275
+ # so the affine is consumed by the *next* parse_prompt call, not
276
+ # swallowed as raw text by the current segment.
277
+ elif depth == 0 and _looks_like_leading_affine_prompt(tokens):
278
+ break
279
+ text += tokens.pop(0)
280
+ return text, weight
281
+
282
+
283
+ def _parse_affine_transform(tokens: List[str]) -> Optional[torch.Tensor]:
284
+ """
285
+ Consume optional affine keywords like ROTATE[0.25] SLIDE[0.1,0] from the token stream.
286
+ Returns a 2×3 float32 tensor, or None if no affine tokens were found.
287
+ """
288
+ tokens_copy = tokens.copy()
289
+ if tokens_copy and not tokens_copy[0].strip():
290
+ tokens_copy.pop(0)
291
+
292
+ affine_funcs = []
293
+
294
+ while tokens_copy and tokens_copy[0] in affine_transforms:
295
+ func = affine_transforms[tokens_copy.pop(0)]
296
+ args: List[float] = []
297
+
298
+ if not (tokens_copy and tokens_copy[0] == '['):
299
+ break
300
+ tokens_copy.pop(0) # consume '['
301
+
302
+ if tokens_copy and tokens_copy[0] != ']':
303
+ if tokens_copy[0].strip():
304
+ try:
305
+ args = [float(a.strip()) for a in tokens_copy.pop(0).split(',')]
306
+ except ValueError:
307
+ break
308
+ else:
309
+ tokens_copy.pop(0)
310
+
311
+ if not (tokens_copy and tokens_copy[0] == ']'):
312
+ break
313
+ tokens_copy.pop(0) # consume ']'
314
+
315
+ affine_funcs.append(lambda t, f=func, a=args: f(t, *a))
316
+ if tokens_copy and not tokens_copy[0].strip():
317
+ tokens_copy.pop(0)
318
+ tokens[:] = tokens_copy
319
+
320
+ if not affine_funcs:
321
+ return None
322
+
323
+ transform = torch.eye(3, dtype=torch.float32)[:-1] # 2×3
324
+ for fn in reversed(affine_funcs):
325
+ transform = fn(transform)
326
+ return transform
327
+
328
+
329
+ def _parse_weight(tokens: List[str]) -> float:
330
+ if len(tokens) >= 2 and tokens[0] == ':' and _is_float(tokens[1]):
331
+ tokens.pop(0)
332
+ return float(tokens.pop(0))
333
+ return 1.0
334
+
335
+
336
+ def tokenize(s: str) -> List[str]:
337
+ # Use general patterns for AND_ALIGN / AND_MASK_ALIGN families instead of
338
+ # enumerating all 1800+ combinations - that regex is 41k chars and 500x slower.
339
+ # Order matters: longer/more-specific patterns must come first.
340
+ # Use (?<!\w) / (?!\w) instead of \b because _ counts as \w in Python's \b,
341
+ # which would cause 'AND' to match inside 'XAND' (no boundary between \w chars).
342
+ affine_kw_pattern = '|'.join(
343
+ rf'(?<!\w){kw}(?!\w)' for kw in affine_transforms
344
+ )
345
+ keyword_pattern = (
346
+ r'(?<!\w)AND_MASK_ALIGN_\d+_\d+(?!\w)' # must come before AND_ALIGN and AND
347
+ r'|(?<!\w)AND_ALIGN_\d+_\d+(?!\w)'
348
+ r'|(?<!\w)AND_PERP(?!\w)'
349
+ r'|(?<!\w)AND_SALT_WIDE(?!\w)' # must come before AND_SALT
350
+ r'|(?<!\w)AND_SALT_BLOB(?!\w)' # must come before AND_SALT
351
+ r'|(?<!\w)AND_SALT(?!\w)'
352
+ r'|(?<!\w)AND_TOPK(?!\w)'
353
+ r'|(?<!\w)AND(?!\w)'
354
+ )
355
+ return [t for t in re.split(rf'(\[|\]|:|{keyword_pattern}|{affine_kw_pattern})', s) if t.strip()]
356
+
357
+
358
+ def _is_float(s: str) -> bool:
359
+ try:
360
+ float(s)
361
+ return True
362
+ except ValueError:
363
+ return False
364
+
365
+
366
+ # ---------------------------------------------------------------------------
367
+ # Quick smoke-test
368
+ # ---------------------------------------------------------------------------
369
+ if __name__ == '__main__':
370
+ res = parse_root('''
371
+ hello
372
+ AND_PERP [
373
+ arst
374
+ AND defg : 2
375
+ AND_SALT [
376
+ very nested huh? what do you say :.0
377
+ ]
378
+ ]
379
+ AND_ALIGN_4_8 watercolor style :0.5
380
+ ROTATE[0.125] AND_PERP vibrant colors
381
+ ''')
382
+ print(res)
neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import logging
3
+ from typing import List
4
+
5
+ from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
6
+ from modules import script_callbacks, prompt_parser
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Fix: prompt_parser_fixed escapes standalone '&' as a multicond separator.
10
+ # neutral_prompt_parser does NOT treat '&' as AND — it keeps it as literal text.
11
+ # So a leaf like "cat & dog" transpiles to "cat & dog :1.0",
12
+ # and the patched prompt_parser splits that into 3 multicond branches instead of 1,
13
+ # which shifts all batch_cond_indices and breaks PERP/SALT/affine.
14
+ #
15
+ # Solution: escape any standalone '&' inside leaf text with '\&' before handing
16
+ # the string to the webui parser. The patched prompt_parser correctly unescapes
17
+ # '\&' back to '&' during conditioning, so the model sees the original text.
18
+ # ---------------------------------------------------------------------------
19
+
20
+ _STANDALONE_AMP = re.compile(r'(?<!\\)(?<!\S)&(?!\S)')
21
+
22
+ # ── debug logging ──────────────────────────────────────────────────────────
23
+ # Set env variable NP_DEBUG=1 to enable verbose output in the A1111 console.
24
+ # Example (Linux/Mac): NP_DEBUG=1 python launch.py
25
+ # Example (Windows cmd): set NP_DEBUG=1 && python launch.py
26
+ import os as _os
27
+ _DEBUG = _os.getenv("NP_DEBUG", "0").strip() not in ("0", "", "false", "no", "off")
28
+ _log = logging.getLogger("neutral_prompt.hijack")
29
+ # ──────────────────────────────────────────────────────────────────────────
30
+
31
+
32
+ def _escape_leaf_ampersands(text: str) -> str:
33
+ """Escape standalone '&' so patched prompt_parser doesn't split a single
34
+ Neutral Prompt leaf into extra multicond branches.
35
+ "cat & dog" -> "cat \\& dog"
36
+ "R&D" -> "R&D" (unchanged — not standalone)
37
+ "\\&" -> "\\&" (unchanged — already escaped)
38
+ """
39
+ if not text or '&' not in text:
40
+ return text
41
+ return _STANDALONE_AMP.sub(r'\\&', text)
42
+
43
+
44
+ prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
45
+ module=prompt_parser,
46
+ hijacker_attribute='__neutral_prompt_hijacker',
47
+ on_uninstall=script_callbacks.on_script_unloaded,
48
+ )
49
+
50
+
51
+ @prompt_parser_hijacker.hijack('get_multicond_prompt_list')
52
+ def get_multicond_prompt_list_hijack(prompts, original_function):
53
+ if not global_state.is_enabled:
54
+ return original_function(prompts)
55
+
56
+ global_state.prompt_exprs = parse_prompts(prompts)
57
+ webui_prompts = transpile_exprs(global_state.prompt_exprs)
58
+
59
+ # ── debug: transpiled strings ──────────────────────────────────────────
60
+ if _DEBUG:
61
+ for i, (orig, transp) in enumerate(zip(prompts, webui_prompts)):
62
+ _log.warning(
63
+ "[NP_DEBUG] prompt[%d]\n"
64
+ " original : %r\n"
65
+ " transpiled : %r\n"
66
+ " branches : %d",
67
+ i, orig, transp,
68
+ # count multicond splits the same way prompt_parser_fixed does
69
+ len(re.split(r'(?:\bAND\b|(?<!\S)&(?!\S))(?!_PERP|_SALT|_TOPK)', transp))
70
+ )
71
+ # ──────────────────────────────────────────────────────────────────────
72
+
73
+ if isinstance(prompts, getattr(prompt_parser, 'SdConditioning', type(None))):
74
+ webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)
75
+
76
+ result = original_function(webui_prompts)
77
+
78
+ # ── debug: multicond parts ─────────────────────────────────────────────
79
+ if _DEBUG:
80
+ conds_list, prompt_flat_list, prompt_indexes = result
81
+ _log.warning(
82
+ "[NP_DEBUG] get_multicond_prompt_list result\n"
83
+ " prompt_flat_list : %r\n"
84
+ " prompt_indexes : %r\n"
85
+ " conds_list : %r",
86
+ list(prompt_flat_list),
87
+ dict(prompt_indexes),
88
+ conds_list,
89
+ )
90
+ # ──────────────────────────────────────────────────────────────────────
91
+
92
+ return result
93
+
94
+
95
+ def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
96
+ exprs = []
97
+ for prompt in prompts:
98
+ expr = neutral_prompt_parser.parse_root(prompt)
99
+ exprs.append(expr)
100
+ return exprs
101
+
102
+
103
+ def transpile_exprs(exprs: neutral_prompt_parser.PromptExpr):
104
+ webui_prompts = []
105
+ for expr in exprs:
106
+ webui_prompts.append(expr.accept(WebuiPromptVisitor()))
107
+ return webui_prompts
108
+
109
+
110
+ class WebuiPromptVisitor:
111
+ def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
112
+ prompt = _escape_leaf_ampersands(that.prompt)
113
+ return f'{prompt} :{that.weight}'
114
+
115
+ def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
116
+ return ' AND '.join(child.accept(self) for child in that.children)
117
+
118
+
119
+ @prompt_parser_hijacker.hijack('reconstruct_multicond_batch')
120
+ def reconstruct_multicond_batch_hijack(*args, original_function, **kwargs):
121
+ """Store batch_cond_indices for the pre-noise affine hook (affine branch)."""
122
+ res = original_function(*args, **kwargs)
123
+ global_state.batch_cond_indices = res[0]
124
+
125
+ # ── debug: batch_cond_indices ──────────────────────────────────────────
126
+ if _DEBUG:
127
+ _log.warning(
128
+ "[NP_DEBUG] reconstruct_multicond_batch\n"
129
+ " batch_cond_indices : %r",
130
+ res[0],
131
+ )
132
+ # ──────────────────────────────────────────────────────────────────────
133
+
134
+ return res
neutral_prompt_patched/lib_neutral_prompt/ui.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI for sd-webui-neutral-prompt (unified).
3
+
4
+ Combines:
5
+ - Core prompt types + tooltip (main branch)
6
+ - AND_ALIGN / AND_MASK_ALIGN entries (alignment_blend branch)
7
+ - elem_id on dropdown (main branch)
8
+ - CFG Rescale infotext / paste support (all branches)
9
+ - Affine-transform hint in tooltip (affine branch)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from itertools import product
15
+ from typing import Callable, Dict, List, Tuple
16
+
17
+ import dataclasses
18
+ import gradio as gr
19
+
20
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
21
+ from modules import script_callbacks, shared
22
+
23
+
24
+ txt2img_prompt_textbox = None
25
+ img2img_prompt_textbox = None
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Prompt-type registry shown in the UI dropdown
30
+ # ---------------------------------------------------------------------------
31
+
32
+ prompt_types: Dict[str, str] = {
33
+ 'Perpendicular (AND_PERP)': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
34
+ 'Saliency sharp (AND_SALT)': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
35
+ 'Saliency blob (AND_SALT_BLOB)': neutral_prompt_parser.PromptKeyword.AND_SALT_BLOB.value,
36
+ 'Saliency wide / classic (AND_SALT_WIDE)':neutral_prompt_parser.PromptKeyword.AND_SALT_WIDE.value,
37
+ 'Semantic guidance top-k (AND_TOPK)': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
38
+ }
39
+
40
+ # Add AND_ALIGN_D_S entries for commonly-used kernel size pairs
41
+ for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
42
+ _key = f'Alignment blend detail={_d} structure={_s} (AND_ALIGN_{_d}_{_s})'
43
+ prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_ALIGN_{_d}_{_s}').value
44
+
45
+ # Add AND_MASK_ALIGN_D_S entries
46
+ for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
47
+ _key = f'Alignment mask detail={_d} structure={_s} (AND_MASK_ALIGN_{_d}_{_s})'
48
+ prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_MASK_ALIGN_{_d}_{_s}').value
49
+
50
+ prompt_types_tooltip = '\n'.join([
51
+ 'AND – add all prompt features equally (webui built-in)',
52
+ 'AND_PERP – reduce contradicting prompt features via perpendicular projection',
53
+ 'AND_SALT – sharp saliency mask: child wins only at its 1-2 peak pixels (surgical)',
54
+ 'AND_SALT_BLOB – blob saliency mask: peak pixels eroded to core, then grown to smooth blob',
55
+ 'AND_SALT_WIDE – broad saliency mask: child wins ~55% of pixels (classic main-branch)',
56
+ 'AND_TOPK – small targeted changes (semantic guidance top-k)',
57
+ 'AND_ALIGN_D_S – soft blend: child adds details without breaking structure',
58
+ 'AND_MASK_ALIGN_D_S– binary mask blend: stricter version of AND_ALIGN',
59
+ '',
60
+ 'Affine transforms (supported in either order for auxiliary segments):',
61
+ ' ROTATE[angle] SLIDE[x,y] SCALE[x,y] SHEAR[x,y]',
62
+ ' e.g.: ROTATE[0.125] AND_PERP vibrant colors :0.8',
63
+ ' e.g.: AND_PERP ROTATE[0.125] vibrant colors :0.8',
64
+ ])
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # UI component class
69
+ # ---------------------------------------------------------------------------
70
+
71
+ @dataclasses.dataclass
72
+ class AccordionInterface:
73
+ get_elem_id: Callable[[str], str]
74
+
75
+ def __post_init__(self):
76
+ self.is_rendered = False
77
+
78
+ self.cfg_rescale = gr.Slider(
79
+ label='CFG rescale φ',
80
+ minimum=0, maximum=1, value=0,
81
+ info='0 = disabled. Rescales the CFG output towards the predicted x0 to reduce over-saturation.',
82
+ )
83
+ self.neutral_prompt = gr.Textbox(
84
+ label='Neutral prompt',
85
+ show_label=False,
86
+ lines=3,
87
+ placeholder=(
88
+ 'Neutral / auxiliary prompt '
89
+ '(press "Apply to prompt" to append with the chosen strategy keyword)'
90
+ ),
91
+ )
92
+ self.neutral_cond_scale = gr.Slider(
93
+ label='Prompt weight',
94
+ minimum=-3, maximum=3, value=1,
95
+ )
96
+ self.aux_prompt_type = gr.Dropdown(
97
+ label='Prompt type',
98
+ choices=list(prompt_types.keys()),
99
+ value=next(iter(prompt_types.keys())),
100
+ tooltip=prompt_types_tooltip,
101
+ elem_id=self.get_elem_id('formatter_prompt_type'),
102
+ )
103
+ self.append_to_prompt_button = gr.Button(value='Apply to prompt')
104
+
105
+ # ------------------------------------------------------------------
106
+
107
+ def arrange_components(self, is_img2img: bool) -> None:
108
+ if self.is_rendered:
109
+ return
110
+
111
+ with gr.Accordion(label='Neutral Prompt', open=False):
112
+ self.cfg_rescale.render()
113
+ with gr.Accordion(label='Prompt formatter', open=False):
114
+ self.neutral_prompt.render()
115
+ self.neutral_cond_scale.render()
116
+ self.aux_prompt_type.render()
117
+ self.append_to_prompt_button.render()
118
+
119
+ def connect_events(self, is_img2img: bool) -> None:
120
+ if self.is_rendered:
121
+ return
122
+
123
+ prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
124
+ self.append_to_prompt_button.click(
125
+ fn=lambda init_prompt, prompt, scale, prompt_type: (
126
+ f'{init_prompt}\n{prompt_types[prompt_type]} {prompt} :{scale}',
127
+ '',
128
+ ),
129
+ inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale, self.aux_prompt_type],
130
+ outputs=[prompt_textbox, self.neutral_prompt],
131
+ )
132
+
133
+ def set_rendered(self, value: bool = True) -> None:
134
+ self.is_rendered = value
135
+
136
+ # ------------------------------------------------------------------
137
+ # Script interface (args / infotext / paste)
138
+
139
+ def get_components(self) -> Tuple[gr.components.Component, ...]:
140
+ return (self.cfg_rescale,)
141
+
142
+ def get_infotext_fields(self) -> Tuple[Tuple[gr.components.Component, str], ...]:
143
+ return tuple(zip(self.get_components(), ('CFG Rescale phi',)))
144
+
145
+ def get_paste_field_names(self) -> List[str]:
146
+ return ['CFG Rescale phi']
147
+
148
+ def get_extra_generation_params(self, args: Dict) -> Dict:
149
+ return {'CFG Rescale phi': args['cfg_rescale']}
150
+
151
+ def unpack_processing_args(self, cfg_rescale: float) -> Dict:
152
+ return {'cfg_rescale': cfg_rescale}
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Settings & callbacks
157
+ # ---------------------------------------------------------------------------
158
+
159
+ def on_ui_settings() -> None:
160
+ section = ('neutral_prompt', 'Neutral Prompt')
161
+ shared.opts.add_option(
162
+ 'neutral_prompt_enabled',
163
+ shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section),
164
+ )
165
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
166
+ shared.opts.onchange('neutral_prompt_enabled', _update_enabled)
167
+
168
+ shared.opts.add_option(
169
+ 'neutral_prompt_verbose',
170
+ shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section),
171
+ )
172
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
173
+ shared.opts.onchange('neutral_prompt_verbose', _update_verbose)
174
+
175
+
176
+ script_callbacks.on_ui_settings(on_ui_settings)
177
+
178
+
179
+ def _update_enabled() -> None:
180
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
181
+
182
+
183
+ def _update_verbose() -> None:
184
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
185
+
186
+
187
+ def on_after_component(component, **_kwargs) -> None:
188
+ global txt2img_prompt_textbox, img2img_prompt_textbox
189
+ eid = getattr(component, 'elem_id', None)
190
+ if eid == 'txt2img_prompt':
191
+ txt2img_prompt_textbox = component
192
+ elif eid == 'img2img_prompt':
193
+ img2img_prompt_textbox = component
194
+
195
+
196
+ script_callbacks.on_after_component(on_after_component)
neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from types import ModuleType
3
+ from typing import Optional
4
+ from modules import scripts
5
+ from lib_neutral_prompt import global_state
6
+
7
+
8
+ def patch():
9
+ xyz_module = find_xyz_module()
10
+ if xyz_module is None:
11
+ print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
12
+ return
13
+
14
+ xyz_module.axis_options.extend([
15
+ xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
16
+ ])
17
+
18
+
19
+ class XyzFloat(float):
20
+ is_xyz: bool = True
21
+
22
+
23
+ def apply_cfg_rescale():
24
+ def callback(_p, v, _vs):
25
+ global_state.cfg_rescale = XyzFloat(v)
26
+
27
+ return callback
28
+
29
+
30
+ def int_or_float(string):
31
+ try:
32
+ return int(string)
33
+ except ValueError:
34
+ return float(string)
35
+
36
+
37
+ def find_xyz_module() -> Optional[ModuleType]:
38
+ for data in scripts.scripts_data:
39
+ if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
40
+ return data.module
41
+
42
+ return None
neutral_prompt_patched/scripts/neutral_prompt.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
2
+ from modules import scripts, processing, shared
3
+ from typing import Dict
4
+ import functools
5
+
6
+
7
+ class NeutralPromptScript(scripts.Script):
8
+ def __init__(self):
9
+ self.accordion_interface = None
10
+ self._is_img2img = False
11
+
12
+ @property
13
+ def is_img2img(self):
14
+ return self._is_img2img
15
+
16
+ @is_img2img.setter
17
+ def is_img2img(self, is_img2img):
18
+ self._is_img2img = is_img2img
19
+ if self.accordion_interface is None:
20
+ self.accordion_interface = ui.AccordionInterface(self.elem_id)
21
+
22
+ def title(self) -> str:
23
+ return "Neutral Prompt"
24
+
25
+ def show(self, is_img2img: bool):
26
+ return scripts.AlwaysVisible
27
+
28
+ def ui(self, is_img2img: bool):
29
+ self.hijack_composable_lora(is_img2img)
30
+
31
+ self.accordion_interface.arrange_components(is_img2img)
32
+ self.accordion_interface.connect_events(is_img2img)
33
+ self.infotext_fields = self.accordion_interface.get_infotext_fields()
34
+ self.paste_field_names = self.accordion_interface.get_paste_field_names()
35
+ self.accordion_interface.set_rendered()
36
+ return self.accordion_interface.get_components()
37
+
38
+ def process(self, p: processing.StableDiffusionProcessing, *args):
39
+ args = self.accordion_interface.unpack_processing_args(*args)
40
+
41
+ self.update_global_state(args)
42
+ if global_state.is_enabled:
43
+ p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
44
+
45
+ def update_global_state(self, args: Dict):
46
+ if shared.state.job_no == 0:
47
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
48
+
49
+ for k, v in args.items():
50
+ try:
51
+ getattr(global_state, k)
52
+ except AttributeError:
53
+ continue
54
+
55
+ if getattr(getattr(global_state, k), 'is_xyz', False):
56
+ xyz_attr = getattr(global_state, k)
57
+ xyz_attr.is_xyz = False
58
+ args[k] = xyz_attr
59
+ continue
60
+
61
+ if shared.state.job_no > 0:
62
+ continue
63
+
64
+ setattr(global_state, k, v)
65
+
66
+ def hijack_composable_lora(self, is_img2img):
67
+ if self.accordion_interface.is_rendered:
68
+ return
69
+
70
+ lora_script = None
71
+ script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
72
+
73
+ for script in script_runner.alwayson_scripts:
74
+ if script.title().lower() == "composable lora":
75
+ lora_script = script
76
+ break
77
+
78
+ if lora_script is not None:
79
+ lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
80
+
81
+
82
+ def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
83
+ if not global_state.is_enabled:
84
+ return original_function(p, *args, **kwargs)
85
+
86
+ exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
87
+ all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs)
88
+ res = original_function(p, *args, **kwargs)
89
+ # restore original prompts
90
+ p.all_prompts = all_prompts
91
+ return res
92
+
93
+
94
+ xyz_grid.patch()
neutral_prompt_patched/test/perp_parser/__init__.py ADDED
File without changes
neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for affine-keyword ordering in parse_prompt.
3
+
4
+ Both orderings must work correctly:
5
+ (A) AND_PERP ROTATE[0.125] vivid colors ← trailing affine (original)
6
+ (B) ROTATE[0.125] AND_PERP vivid colors ← leading affine (documented)
7
+ """
8
+ import unittest
9
+
10
+ from lib_neutral_prompt import neutral_prompt_parser as p
11
+ from lib_neutral_prompt.neutral_prompt_parser import ConciliationStrategy
12
+
13
+
14
+ class TestAffineKeywordOrder(unittest.TestCase):
15
+
16
+ # ------------------------------------------------------------------ #
17
+ # trailing affine (AND_PERP ROTATE[...] text) #
18
+ # ------------------------------------------------------------------ #
19
+
20
+ def test_trailing_affine_basic(self):
21
+ root = p.parse_root("base AND_PERP ROTATE[0.125] vivid colors :0.8")
22
+ self.assertEqual(len(root.children), 2)
23
+ self.assertEqual(root.children[0].prompt.strip(), "base")
24
+ self.assertIsNone(root.children[0].local_transform)
25
+ self.assertIsNotNone(root.children[1].local_transform)
26
+ self.assertEqual(root.children[1].conciliation, ConciliationStrategy.PERPENDICULAR)
27
+ self.assertAlmostEqual(root.children[1].weight, 0.8)
28
+
29
+ # ------------------------------------------------------------------ #
30
+ # leading affine (ROTATE[...] AND_PERP text) #
31
+ # ------------------------------------------------------------------ #
32
+
33
+ def test_leading_affine_mid_prompt(self):
34
+ """ROTATE[...] AND_PERP must NOT be consumed into the previous prompt's text."""
35
+ root = p.parse_root("base ROTATE[0.125] AND_PERP vivid colors :0.8")
36
+ self.assertEqual(len(root.children), 2,
37
+ f"Expected 2 children, got {len(root.children)}: "
38
+ f"{[getattr(c,'prompt','composite') for c in root.children]}")
39
+ first, second = root.children
40
+ # first segment = "base " with no transform
41
+ self.assertIn("base", first.prompt)
42
+ self.assertNotIn("ROTATE", first.prompt)
43
+ self.assertIsNone(first.local_transform)
44
+ # second segment gets the transform
45
+ self.assertIsNotNone(second.local_transform)
46
+ self.assertEqual(second.conciliation, ConciliationStrategy.PERPENDICULAR)
47
+ self.assertAlmostEqual(second.weight, 0.8)
48
+
49
+ def test_leading_affine_at_start(self):
50
+ """ROTATE[...] AND_PERP as first (and only non-AND) segment."""
51
+ root = p.parse_root("ROTATE[0.125] AND_PERP vivid colors :0.8")
52
+ self.assertEqual(len(root.children), 1)
53
+ child = root.children[0]
54
+ self.assertIsNotNone(child.local_transform)
55
+ self.assertEqual(child.conciliation, ConciliationStrategy.PERPENDICULAR)
56
+ self.assertIn("vivid", child.prompt)
57
+ self.assertNotIn("ROTATE", child.prompt)
58
+
59
+ def test_leading_affine_with_and_salt(self):
60
+ root = p.parse_root("portrait SLIDE[0.05,0] AND_SALT vibrant :1.2")
61
+ self.assertEqual(len(root.children), 2)
62
+ second = root.children[1]
63
+ self.assertIsNotNone(second.local_transform)
64
+ self.assertEqual(second.conciliation, ConciliationStrategy.SALIENCE_MASK)
65
+
66
+ # ------------------------------------------------------------------ #
67
+ # Composed / chained affine #
68
+ # ------------------------------------------------------------------ #
69
+
70
+ def test_leading_and_trailing_compose(self):
71
+ """ROTATE[a] AND_PERP SCALE[b,b] text — both affines composed."""
72
+ root = p.parse_root("base AND_PERP ROTATE[0.125] AND_PERP SCALE[1.5,1.5] vivid")
73
+ # Just check no crash and all children parse
74
+ self.assertGreaterEqual(len(root.children), 1)
75
+
76
+ # ------------------------------------------------------------------ #
77
+ # CFGRescaleFactorSingleton lifecycle #
78
+ # ------------------------------------------------------------------ #
79
+
80
+ def test_cfg_rescale_singleton_clear(self):
81
+ from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
82
+ S.clear()
83
+ self.assertIsNone(S.get())
84
+ S.set(1.23)
85
+ self.assertAlmostEqual(S.get(), 1.23)
86
+ S.clear()
87
+ self.assertIsNone(S.get())
88
+
89
+ def test_cfg_rescale_singleton_thread_local(self):
90
+ import threading
91
+ from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
92
+ results = {}
93
+ def worker(val, key):
94
+ S.clear()
95
+ S.set(val)
96
+ import time; time.sleep(0.01)
97
+ results[key] = S.get()
98
+ threads = [threading.Thread(target=worker, args=(i * 10.0, i)) for i in range(3)]
99
+ for t in threads: t.start()
100
+ for t in threads: t.join()
101
+ for i in range(3):
102
+ self.assertAlmostEqual(results[i], i * 10.0,
103
+ msg=f"Thread {i} got {results[i]} expected {i*10.0}")
104
+
105
+ # ------------------------------------------------------------------ #
106
+ # New AND_SALT variants parse correctly #
107
+ # ------------------------------------------------------------------ #
108
+
109
+ def test_and_salt_wide_parsed(self):
110
+ root = p.parse_root("base AND_SALT_WIDE vivid")
111
+ self.assertEqual(len(root.children), 2)
112
+ self.assertEqual(root.children[1].conciliation,
113
+ p.ConciliationStrategy.SALIENCE_MASK_WIDE)
114
+
115
+ def test_and_salt_blob_parsed(self):
116
+ root = p.parse_root("base AND_SALT_BLOB vivid")
117
+ self.assertEqual(len(root.children), 2)
118
+ self.assertEqual(root.children[1].conciliation,
119
+ p.ConciliationStrategy.SALIENCE_MASK_BLOB)
120
+
121
+ def test_and_align_parsed(self):
122
+ root = p.parse_root("base AND_ALIGN_4_8 vivid")
123
+ self.assertEqual(len(root.children), 2)
124
+ self.assertIn('AND_ALIGN_4_8', root.children[1].conciliation.value)
125
+
126
+ def test_and_mask_align_parsed(self):
127
+ root = p.parse_root("base AND_MASK_ALIGN_4_8 vivid")
128
+ self.assertEqual(len(root.children), 2)
129
+ self.assertIn('AND_MASK_ALIGN_4_8', root.children[1].conciliation.value)
130
+
131
+
132
+ if __name__ == '__main__':
133
+ unittest.main()
neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration smoke tests for the pre-noise affine pipeline.
3
+
4
+ Requires PyTorch. If torch is not installed the entire module is skipped via
5
+ ``unittest.skipUnless``, so ``unittest discover`` still exits cleanly.
6
+
7
+ Mocked: A1111 modules only (modules.script_callbacks, modules.sd_samplers,
8
+ modules.shared). torch itself is real — no sys.modules replacement.
9
+
10
+ Tests:
11
+ 1. Hook is a no-op when global_state.is_enabled = False.
12
+ 2. Identity transform (angle=0) leaves x numerically unchanged.
13
+ 3. Non-identity transform calls apply_affine_transform (verified by spy).
14
+ 4. Prompts with no affine leave x unchanged.
15
+ 5. Empty state does not crash.
16
+ 6. Batch with multiple prompts does not crash.
17
+ """
18
+
19
+ import dataclasses
20
+ import importlib
21
+ import math
22
+ import sys
23
+ import types
24
+ import unittest
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Detect torch availability — skip the whole module if absent
28
+ # ---------------------------------------------------------------------------
29
+ try:
30
+ import importlib.util as _ilu
31
+ _spec = _ilu.find_spec('torch')
32
+ _TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
33
+ except (ValueError, ModuleNotFoundError):
34
+ _TORCH_AVAILABLE = False
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # A1111 module stubs (installed unconditionally so cfg_denoiser_hijack can
39
+ # be imported even when running the skip path)
40
+ # ---------------------------------------------------------------------------
41
+
42
+ def _stub(name):
43
+ m = types.ModuleType(name)
44
+ sys.modules.setdefault(name, m)
45
+ return sys.modules[name]
46
+
47
+ for _mod_name in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
48
+ 'modules.shared', 'modules.prompt_parser', 'gradio'):
49
+ _stub(_mod_name)
50
+
51
+ import modules.script_callbacks as _sc
52
+ if not hasattr(_sc, 'on_cfg_denoiser'):
53
+ _sc.on_cfg_denoiser = lambda fn: None
54
+ if not hasattr(_sc, 'on_script_unloaded'):
55
+ _sc.on_script_unloaded = lambda fn: None
56
+
57
+ import modules.shared as _sh
58
+ if not hasattr(_sh, 'opts'):
59
+ _sh.opts = types.SimpleNamespace(
60
+ data={},
61
+ add_option=lambda *a, **k: None,
62
+ onchange=lambda *a, **k: None,
63
+ OptionInfo=lambda *a, **k: None,
64
+ )
65
+
66
+ import modules.sd_samplers as _ss
67
+ if not hasattr(_ss, 'cfg_denoiser'):
68
+ _ss.cfg_denoiser = None
69
+ # cfg_denoiser_hijack installs a hijack on create_sampler at import time
70
+ if not hasattr(_ss, 'create_sampler'):
71
+ _ss.create_sampler = lambda *a, **k: None
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Repo root on sys.path
76
+ # ---------------------------------------------------------------------------
77
+
78
+ from pathlib import Path
79
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
80
+ if str(_REPO_ROOT) not in sys.path:
81
+ sys.path.insert(0, str(_REPO_ROOT))
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Conditional imports (only when torch is available)
86
+ # ---------------------------------------------------------------------------
87
+
88
+ if _TORCH_AVAILABLE:
89
+ import torch
90
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
91
+ from lib_neutral_prompt.cfg_denoiser_hijack import _on_cfg_denoiser
92
+ import lib_neutral_prompt.affine_transform as _at_mod
93
+
94
+ # -----------------------------------------------------------------------
95
+ # Helpers
96
+ # -----------------------------------------------------------------------
97
+
98
+ def _leaf(text, angle=None):
99
+ """LeafPrompt, optionally with a ROTATE affine (angle in turns)."""
100
+ transform = None
101
+ if angle is not None:
102
+ c = math.cos(angle * 2 * math.pi)
103
+ s = math.sin(angle * 2 * math.pi)
104
+ transform = torch.tensor([[c, -s, 0.0], [s, c, 0.0]])
105
+ return neutral_prompt_parser.LeafPrompt(1.0, None, transform, text)
106
+
107
+ def _comp(*children):
108
+ return neutral_prompt_parser.CompositePrompt(1.0, None, None, list(children))
109
+
110
+ @dataclasses.dataclass
111
+ class _FakeParams:
112
+ x: torch.Tensor
113
+
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Test class
117
+ # ---------------------------------------------------------------------------
118
+
119
+ @unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping pipeline tests')
120
+ class TestOnCfgDenoiserSmokeTest(unittest.TestCase):
121
+
122
+ def setUp(self):
123
+ global_state.is_enabled = True
124
+ global_state.prompt_exprs = []
125
+ global_state.batch_cond_indices = []
126
+
127
+ def tearDown(self):
128
+ global_state.is_enabled = False
129
+ global_state.prompt_exprs = []
130
+ global_state.batch_cond_indices = []
131
+
132
+ # 1. disabled → x untouched
133
+ def test_disabled_leaves_x_unchanged(self):
134
+ global_state.is_enabled = False
135
+ x = torch.zeros(2, 4, 8, 8)
136
+ x[0] = 1.0
137
+ orig = x.clone()
138
+ params = _FakeParams(x=x.clone())
139
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.125))]
140
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
141
+ _on_cfg_denoiser(params)
142
+ self.assertTrue(torch.equal(params.x, orig), "Disabled hook must not modify x")
143
+
144
+ # 2. identity transform → x numerically unchanged
145
+ def test_identity_transform_no_change(self):
146
+ x = torch.randn(2, 4, 8, 8)
147
+ orig = x.clone()
148
+ params = _FakeParams(x=x.clone())
149
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0))]
150
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
151
+ _on_cfg_denoiser(params)
152
+ self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
153
+ "Identity affine must not change x")
154
+
155
+ # 3. non-identity → apply_affine_transform is called (spy)
156
+ def test_nonidentity_calls_apply_affine_transform(self):
157
+ x = torch.zeros(2, 4, 8, 8)
158
+ x[1] = torch.randn(4, 8, 8)
159
+ params = _FakeParams(x=x.clone())
160
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.25))]
161
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
162
+
163
+ _real = _at_mod.apply_affine_transform
164
+ calls = []
165
+ def _spy(tensor, affine, mode='bilinear'):
166
+ calls.append(True)
167
+ return _real(tensor, affine, mode)
168
+
169
+ _at_mod.apply_affine_transform = _spy
170
+ try:
171
+ _on_cfg_denoiser(params)
172
+ finally:
173
+ _at_mod.apply_affine_transform = _real
174
+
175
+ self.assertGreater(len(calls), 0,
176
+ "apply_affine_transform must be called for non-identity transform")
177
+
178
+ # 4. no affine on any child → x unchanged
179
+ def test_no_affine_no_change(self):
180
+ x = torch.randn(2, 4, 8, 8)
181
+ orig = x.clone()
182
+ params = _FakeParams(x=x.clone())
183
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp"))]
184
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
185
+ _on_cfg_denoiser(params)
186
+ self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
187
+ "No affine → x must not change")
188
+
189
+ # 5. empty state → no crash
190
+ def test_empty_state_no_crash(self):
191
+ params = _FakeParams(x=torch.randn(1, 4, 8, 8))
192
+ global_state.prompt_exprs = []
193
+ global_state.batch_cond_indices = []
194
+ try:
195
+ _on_cfg_denoiser(params)
196
+ except Exception as e:
197
+ self.fail(f"_on_cfg_denoiser raised with empty state: {e}")
198
+
199
+ # 6. two batch entries → no crash
200
+ def test_batch_multiple_prompts_no_crash(self):
201
+ params = _FakeParams(x=torch.zeros(4, 4, 8, 8))
202
+ global_state.prompt_exprs = [
203
+ _comp(_leaf("base_a"), _leaf("perp_a", 0.25)),
204
+ _comp(_leaf("base_b"), _leaf("perp_b", 0.125)),
205
+ ]
206
+ global_state.batch_cond_indices = [
207
+ [(0, 1.0), (1, 1.0)],
208
+ [(2, 1.0), (3, 1.0)],
209
+ ]
210
+ try:
211
+ _on_cfg_denoiser(params)
212
+ except Exception as e:
213
+ self.fail(f"Multi-prompt batch raised: {e}")
214
+
215
+
216
+ if __name__ == '__main__':
217
+ unittest.main()
neutral_prompt_patched/test/perp_parser/test_basic_parser.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
+ class TestPromptParser(unittest.TestCase):
9
+ def setUp(self):
10
+ self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
11
+ self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
12
+ self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
13
+ self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
14
+ self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
15
+ self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
16
+ self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
17
+
18
+ def test_simple_prompt_child_count(self):
19
+ self.assertEqual(len(self.simple_prompt.children), 1)
20
+
21
+ def test_simple_prompt_child_weight(self):
22
+ self.assertEqual(self.simple_prompt.children[0].weight, 1.0)
23
+
24
+ def test_simple_prompt_child_prompt(self):
25
+ self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
26
+
27
+ def test_square_weight_prompt(self):
28
+ prompt = "a [b c d e : f g h :1.5]"
29
+ parsed = neutral_prompt_parser.parse_root(prompt)
30
+ self.assertEqual(parsed.children[0].prompt, prompt)
31
+
32
+ composed_prompt = f"{prompt} AND_PERP other prompt"
33
+ parsed = neutral_prompt_parser.parse_root(composed_prompt)
34
+ self.assertEqual(parsed.children[0].prompt, prompt)
35
+
36
+ def test_and_prompt_child_count(self):
37
+ self.assertEqual(len(self.and_prompt.children), 2)
38
+
39
+ def test_and_prompt_child_weights_and_prompts(self):
40
+ self.assertEqual(self.and_prompt.children[0].weight, 1.0)
41
+ self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
42
+ self.assertEqual(self.and_prompt.children[1].weight, 2.0)
43
+ self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")
44
+
45
+ def test_and_perp_prompt_child_count(self):
46
+ self.assertEqual(len(self.and_perp_prompt.children), 2)
47
+
48
+ def test_and_perp_prompt_child_types(self):
49
+ self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
50
+ self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
51
+
52
+ def test_and_perp_prompt_nested_child(self):
53
+ nested_child = self.and_perp_prompt.children[1]
54
+ self.assertEqual(nested_child.weight, 2.0)
55
+ self.assertEqual(nested_child.prompt.strip(), "goodbye")
56
+
57
+ def test_nested_and_perp_prompt_child_count(self):
58
+ self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
59
+
60
+ def test_nested_and_perp_prompt_child_types(self):
61
+ self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
62
+ self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
63
+
64
+ def test_nested_and_perp_prompt_nested_child_types(self):
65
+ nested_child = self.nested_and_perp_prompt.children[1].children[0]
66
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
67
+ nested_child = self.nested_and_perp_prompt.children[1].children[1]
68
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
69
+
70
+ def test_nested_and_perp_prompt_nested_child(self):
71
+ nested_child = self.nested_and_perp_prompt.children[1].children[1]
72
+ self.assertEqual(nested_child.weight, 3.0)
73
+ self.assertEqual(nested_child.prompt.strip(), "welcome")
74
+
75
+ def test_invalid_weight_child_count(self):
76
+ self.assertEqual(len(self.invalid_weight.children), 1)
77
+
78
+ def test_invalid_weight_child_weight(self):
79
+ self.assertEqual(self.invalid_weight.children[0].weight, 1.0)
80
+
81
+ def test_invalid_weight_child_prompt(self):
82
+ self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")
83
+
84
+ def test_and_salt_prompt_child_count(self):
85
+ self.assertEqual(len(self.and_salt_prompt.children), 2)
86
+
87
+ def test_and_salt_prompt_child_types(self):
88
+ self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
89
+ self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
90
+
91
+ def test_and_salt_prompt_nested_child(self):
92
+ nested_child = self.and_salt_prompt.children[1]
93
+ self.assertEqual(nested_child.weight, 2.0)
94
+ self.assertEqual(nested_child.prompt.strip(), "goodbye")
95
+
96
+ def test_nested_and_salt_prompt_child_count(self):
97
+ self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
98
+
99
+ def test_nested_and_salt_prompt_child_types(self):
100
+ self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
101
+ self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
102
+
103
+ def test_nested_and_salt_prompt_nested_child_types(self):
104
+ nested_child = self.nested_and_salt_prompt.children[1].children[0]
105
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
106
+ nested_child = self.nested_and_salt_prompt.children[1].children[1]
107
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
108
+
109
+ def test_nested_and_salt_prompt_nested_child(self):
110
+ nested_child = self.nested_and_salt_prompt.children[1].children[1]
111
+ self.assertEqual(nested_child.weight, 3.0)
112
+ self.assertEqual(nested_child.prompt.strip(), "welcome")
113
+
114
+ def test_start_with_prompt_editing(self):
115
+ prompt = "[(long shot:1.2):0.1] detail.."
116
+ res = neutral_prompt_parser.parse_root(prompt)
117
+ self.assertEqual(res.children[0].weight, 1.0)
118
+ self.assertEqual(res.children[0].prompt, prompt)
119
+
120
+
121
+ if __name__ == '__main__':
122
+ unittest.main()
neutral_prompt_patched/test/perp_parser/test_malicious_parser.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
+ class TestMaliciousPromptParser(unittest.TestCase):
9
+ def setUp(self):
10
+ self.parser = neutral_prompt_parser
11
+
12
+ def test_empty(self):
13
+ result = self.parser.parse_root("")
14
+ self.assertEqual(result.children[0].prompt, "")
15
+ self.assertEqual(result.children[0].weight, 1.0)
16
+
17
+ def test_zero_weight(self):
18
+ result = self.parser.parse_root("hello :0.0")
19
+ self.assertEqual(result.children[0].weight, 0.0)
20
+
21
+ def test_mixed_positive_and_negative_weights(self):
22
+ result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
23
+ self.assertEqual(result.children[0].weight, 1.0)
24
+ self.assertEqual(result.children[1].weight, -2.0)
25
+
26
+ def test_debalanced_square_brackets(self):
27
+ prompt = "a [ b " * 100
28
+ result = self.parser.parse_root(prompt)
29
+ self.assertEqual(result.children[0].prompt, prompt)
30
+
31
+ prompt = "a ] b " * 100
32
+ result = self.parser.parse_root(prompt)
33
+ self.assertEqual(result.children[0].prompt, prompt)
34
+
35
+ repeats = 10
36
+ prompt = "a [ [ b AND c ] " * repeats
37
+ result = self.parser.parse_root(prompt)
38
+ self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
39
+
40
+ repeats = 10
41
+ prompt = "a [ b AND c ] ] " * repeats
42
+ result = self.parser.parse_root(prompt)
43
+ self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
44
+
45
+ def test_erroneous_syntax(self):
46
+ result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
47
+ self.assertEqual(result.children[0].weight, 1.0)
48
+ self.assertEqual(result.children[1].prompt, "[goodbye ")
49
+ self.assertEqual(result.children[1].weight, 2.0)
50
+
51
+ result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
52
+ self.assertEqual(result.children[0].weight, 1.0)
53
+ self.assertEqual(result.children[1].prompt, " goodbye ")
54
+
55
+ result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
56
+ self.assertEqual(result.children[1].prompt, " goodbye]")
57
+ self.assertEqual(result.children[1].weight, 2.0)
58
+
59
+ result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
60
+ self.assertEqual(result.children[1].weight, 2.0)
61
+ self.assertEqual(result.children[1].prompt, " a [ goodbye ")
62
+
63
+ result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
64
+ self.assertEqual(result.children[0].weight, 1.0)
65
+ self.assertEqual(result.children[2].prompt, " goodbye ")
66
+
67
+ def test_huge_number_of_prompt_parts(self):
68
+ result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
69
+ self.assertEqual(len(result.children), 10**4)
70
+
71
+ def test_prompt_ending_with_weight(self):
72
+ result = self.parser.parse_root("hello :1.0 AND :2.0")
73
+ self.assertEqual(result.children[0].weight, 1.0)
74
+ self.assertEqual(result.children[1].prompt, "")
75
+ self.assertEqual(result.children[1].weight, 2.0)
76
+
77
+ def test_huge_input_string(self):
78
+ big_string = "hello :1.0 AND " * 10**4
79
+ result = self.parser.parse_root(big_string)
80
+ self.assertEqual(len(result.children), 10**4 + 1)
81
+
82
+ def test_deeply_nested_prompt(self):
83
+ deeply_nested_prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
84
+ result = self.parser.parse_root(deeply_nested_prompt)
85
+ self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
86
+
87
+ def test_complex_nested_prompts(self):
88
+ complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
89
+ result = self.parser.parse_root(complex_prompt)
90
+ self.assertEqual(result.children[0].weight, 1.0)
91
+ self.assertEqual(result.children[1].weight, 2.0)
92
+ self.assertEqual(result.children[2].children[0].weight, 3.0)
93
+ self.assertEqual(result.children[2].children[1].weight, 4.0)
94
+ self.assertEqual(result.children[2].children[2].weight, 5.0)
95
+
96
+ def test_string_with_random_characters(self):
97
+ random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
98
+ try:
99
+ self.parser.parse_root(random_chars)
100
+ except Exception:
101
+ self.fail("parse_root couldn't handle a string with random characters.")
102
+
103
+ def test_string_with_unexpected_symbols(self):
104
+ unexpected_symbols = "hello :1.0 AND $%^&*()goodbye :2.0"
105
+ try:
106
+ self.parser.parse_root(unexpected_symbols)
107
+ except Exception:
108
+ self.fail("parse_root couldn't handle a string with unexpected symbols.")
109
+
110
+ def test_string_with_unconventional_structure(self):
111
+ unconventional_structure = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
112
+ try:
113
+ self.parser.parse_root(unconventional_structure)
114
+ except Exception:
115
+ self.fail("parse_root couldn't handle a string with unconventional structure.")
116
+
117
+ def test_string_with_mixed_alphabets_and_numbers(self):
118
+ mixed_alphabets_and_numbers = "123hello :1.0 AND goodbye456 :2.0"
119
+ try:
120
+ self.parser.parse_root(mixed_alphabets_and_numbers)
121
+ except Exception:
122
+ self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")
123
+
124
+ def test_string_with_nested_brackets(self):
125
+ nested_brackets = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
126
+ try:
127
+ self.parser.parse_root(nested_brackets)
128
+ except Exception:
129
+ self.fail("parse_root couldn't handle a string with nested brackets.")
130
+
131
+ def test_unmatched_opening_braces(self):
132
+ unmatched_opening_braces = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
133
+ try:
134
+ self.parser.parse_root(unmatched_opening_braces)
135
+ except Exception:
136
+ self.fail("parse_root couldn't handle a string with unmatched opening braces.")
137
+
138
+ def test_unmatched_closing_braces(self):
139
+ unmatched_closing_braces = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
140
+ try:
141
+ self.parser.parse_root(unmatched_closing_braces)
142
+ except Exception:
143
+ self.fail("parse_root couldn't handle a string with unmatched closing braces.")
144
+
145
+ def test_repeating_colons(self):
146
+ repeating_colons = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
147
+ try:
148
+ self.parser.parse_root(repeating_colons)
149
+ except Exception:
150
+ self.fail("parse_root couldn't handle a string with repeating colons.")
151
+
152
+ def test_excessive_whitespace(self):
153
+ excessive_whitespace = "hello :1.0 AND_PERP goodbye :2.0"
154
+ try:
155
+ self.parser.parse_root(excessive_whitespace)
156
+ except Exception:
157
+ self.fail("parse_root couldn't handle a string with excessive whitespace.")
158
+
159
+ def test_repeating_AND_keyword(self):
160
+ repeating_AND_keyword = "hello :1.0 AND AND AND AND AND goodbye :2.0"
161
+ try:
162
+ self.parser.parse_root(repeating_AND_keyword)
163
+ except Exception:
164
+ self.fail("parse_root couldn't handle a string with repeating AND keyword.")
165
+
166
+ def test_repeating_AND_PERP_keyword(self):
167
+ repeating_AND_PERP_keyword = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
168
+ try:
169
+ self.parser.parse_root(repeating_AND_PERP_keyword)
170
+ except Exception:
171
+ self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")
172
+
173
+ def test_square_weight_prompt(self):
174
+ prompt = "AND_PERP [weighted] you thought it was the end"
175
+ try:
176
+ self.parser.parse_root(prompt)
177
+ except Exception:
178
+ self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
179
+
180
+
181
+ if __name__ == '__main__':
182
+ unittest.main()
prompt-fusion-extension-main/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /venv/
2
+ /.idea/
3
+ __pycache__/
prompt-fusion-extension-main/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2003
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
prompt-fusion-extension-main/lib_prompt_fusion/ast_nodes.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from lib_prompt_fusion import interpolation_functions
3
+ from lib_prompt_fusion.t_scaler import scale_t
4
+ from lib_prompt_fusion import interpolation_tensor
5
+
6
+
7
+ class ListExpression:
8
+ def __init__(self, expressions):
9
+ self.__expressions = expressions
10
+
11
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
12
+ if not self.__expressions:
13
+ return
14
+
15
+ def expr_extend_tensor(expr):
16
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
17
+
18
+ expr_extend_tensor(self.__expressions[0])
19
+ for expression in self.__expressions[1:]:
20
+ tensor_builder.append(' ')
21
+ expr_extend_tensor(expression)
22
+
23
+
24
+ class InterpolationExpression:
25
+ @staticmethod
26
+ def create(exprs, steps, function_name):
27
+ if function_name == "mean":
28
+ return AverageExpression(exprs, steps)
29
+
30
+ max_len = min(len(exprs), len(steps))
31
+ exprs = exprs[:max_len]
32
+ steps = steps[:max_len]
33
+
34
+ return InterpolationExpression(exprs, steps, function_name)
35
+
36
+ def __init__(self, expressions, steps, function_name=None):
37
+ assert len(expressions) >= 2
38
+ assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
39
+ self.__expressions = expressions
40
+ self.__steps = steps
41
+ self.__function_name = function_name if function_name is not None else 'linear'
42
+
43
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
44
+ def tensor_updater(expr):
45
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
46
+
47
+ tensor_builder.extrude(
48
+ [tensor_updater(expr) for expr in self.__expressions],
49
+ self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
50
+
51
+ def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
52
+ steps = list(self.__steps)
53
+ if steps[0] is None:
54
+ steps[0] = LiftExpression(str(steps_range[0] - 1))
55
+ if steps[-1] is None:
56
+ steps[-1] = LiftExpression(str(steps_range[1] - 1))
57
+
58
+ for i, step in enumerate(steps):
59
+ if step is None:
60
+ continue
61
+
62
+ step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)
63
+
64
+ if use_old_scheduling and 0 < step < 1:
65
+ step *= total_steps
66
+ elif not use_old_scheduling and isinstance(step, float):
67
+ step = (step - int(is_hires)) * total_steps
68
+ else:
69
+ step += 1
70
+
71
+ steps[i] = int(step)
72
+
73
+ i = 1
74
+ while i < len(steps):
75
+ none_len = 0
76
+ while steps[i + none_len] is None:
77
+ none_len += 1
78
+
79
+ min_step, max_step = steps[i - 1], steps[i + none_len]
80
+
81
+ for j in range(none_len):
82
+ steps[i + j] = min_step + (max_step - min_step) * (j + 1) / (none_len + 1)
83
+
84
+ i += 1 + none_len
85
+
86
+ interpolation_function = {
87
+ 'linear': interpolation_functions.compute_linear,
88
+ 'bezier': interpolation_functions.compute_bezier,
89
+ 'catmull': interpolation_functions.compute_catmull,
90
+ }[self.__function_name]
91
+
92
+ def steps_scale_t(conds, params: interpolation_tensor.InterpolationParams):
93
+ scaled_t = (params.t * total_steps - steps[0]) / max(1, steps[-1] - steps[0])
94
+ scaled_t = scale_t(scaled_t, steps)
95
+
96
+ new_params = interpolation_tensor.InterpolationParams(scaled_t, *params[1:])
97
+ return interpolation_function(conds, new_params)
98
+
99
+ return steps_scale_t
100
+
101
+
102
+ class AverageExpression:
103
+ def __init__(self, expressions, weights):
104
+ if len(expressions) < len(weights):
105
+ raise ValueError
106
+
107
+ self.__expressions = expressions
108
+ self.__weights = weights
109
+
110
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
111
+ def tensor_updater(expr):
112
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
113
+
114
+ tensor_builder.extrude(
115
+ [tensor_updater(expr) for expr in self.__expressions],
116
+ self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
117
+
118
+ def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
119
+ weights = [
120
+ _eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
121
+ for weight in self.__weights
122
+ ]
123
+ explicit_weights = [weight for weight in weights if weight is not None]
124
+ weights = [
125
+ weight / sum(explicit_weights) * len(explicit_weights) / len(self.__expressions)
126
+ if weight is not None
127
+ else 1 / len(self.__expressions)
128
+ for weight in weights
129
+ ]
130
+ weights.extend(1 / len(self.__expressions) for _ in range(len(self.__expressions) - len(weights)))
131
+
132
+ def interpolation_function(conds, _params):
133
+ total = None
134
+ for cond, weight in zip(conds, weights):
135
+ cond *= weight
136
+ if total is None:
137
+ total = cond
138
+ else:
139
+ total += cond
140
+
141
+ return total
142
+
143
+ return interpolation_function
144
+
145
+
146
+ class AlternationExpression:
147
+ def __init__(self, expressions, speed):
148
+ self.__expressions = expressions
149
+ self.__speed = speed
150
+
151
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
152
+ if self.__speed is None:
153
+ speed = None
154
+ else:
155
+ speed = _eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling)
156
+
157
+ if speed is None:
158
+ tensor_builder.append('[')
159
+ for expr_i, expr in enumerate(self.__expressions):
160
+ if expr_i >= 1:
161
+ tensor_builder.append('|')
162
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
163
+ tensor_builder.append(']')
164
+ return
165
+
166
+ def tensor_updater(expr):
167
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
168
+
169
+ exprs = self.__expressions + [self.__expressions[0]]
170
+
171
+ tensor_builder.extrude(
172
+ [tensor_updater(expr) for expr in exprs],
173
+ self.get_interpolation_function(speed, exprs, steps_range, total_steps))
174
+
175
+ def get_interpolation_function(self, speed, exprs, steps_range, total_steps):
176
+ def compute_wrap(control_points, params: interpolation_tensor.InterpolationParams):
177
+ wrapped_t = math.fmod((params.t * total_steps - steps_range[0]) / (len(exprs) - 1) * speed, 1.0)
178
+ if wrapped_t < 0:
179
+ wrapped_t = wrapped_t + 1
180
+ new_params = interpolation_tensor.InterpolationParams(wrapped_t, *params[1:])
181
+ return interpolation_functions.compute_linear(control_points, new_params)
182
+
183
+ return compute_wrap
184
+
185
+
186
+ class EditingExpression:
187
+ def __init__(self, expressions, step):
188
+ assert 1 <= len(expressions) <= 2
189
+ self.__expressions = expressions
190
+ self.__step = step
191
+
192
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
193
+ if self.__step is None:
194
+ tensor_builder.append('[')
195
+ for expr_i, expr in enumerate(self.__expressions):
196
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
197
+ tensor_builder.append(':')
198
+ tensor_builder.append(']')
199
+ return
200
+
201
+ step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
202
+ step_int = step
203
+ if use_old_scheduling and 0 < step < 1:
204
+ step_int *= total_steps
205
+ elif not use_old_scheduling and isinstance(step, float):
206
+ step_int = (step_int - int(is_hires)) * total_steps
207
+ else:
208
+ step_int += 1
209
+
210
+ step_int = int(step_int)
211
+
212
+ tensor_builder.append('[')
213
+ for expr_i, expr in enumerate(self.__expressions):
214
+ expr_steps_range = (steps_range[0], step_int) if expr_i == 0 and len(self.__expressions) >= 2 else (step_int, steps_range[1])
215
+ expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
216
+ tensor_builder.append(':')
217
+
218
+ tensor_builder.append(f'{step}]')
219
+
220
+
221
+ class WeightedExpression:
222
+ def __init__(self, nested, weight=None, positive=True):
223
+ self.__nested = nested
224
+ if not positive:
225
+ assert weight is None
226
+
227
+ self.__weight = weight
228
+ self.__positive = positive
229
+
230
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
231
+ open_bracket, close_bracket = ('(', ')') if self.__positive else ('[', ']')
232
+ tensor_builder.append(open_bracket)
233
+ self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
234
+
235
+ if self.__weight is not None:
236
+ tensor_builder.append(':')
237
+ self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
238
+
239
+ tensor_builder.append(close_bracket)
240
+
241
+
242
+ class WeightInterpolationExpression:
243
+ def __init__(self, nested, weight_begin, weight_end):
244
+ self.__nested = nested
245
+ self.__weight_begin = weight_begin if weight_begin is not None else LiftExpression(str(1.))
246
+ self.__weight_end = weight_end if weight_end is not None else LiftExpression(str(1.))
247
+
248
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
249
+ steps_range_size = steps_range[1] - steps_range[0]
250
+
251
+ weight_begin = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
252
+ weight_end = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)
253
+
254
+ for i in range(steps_range_size):
255
+ step = i + steps_range[0]
256
+
257
+ weight = weight_begin + (weight_end - weight_begin) * (i / max(steps_range_size - 1, 1))
258
+ weight_step_expr = WeightedExpression(self.__nested, LiftExpression(str(weight)))
259
+ if step > steps_range[0]:
260
+ weight_step_expr = EditingExpression([weight_step_expr], LiftExpression(str(step - 1)))
261
+ if step + 1 < steps_range[1]:
262
+ weight_step_expr = EditingExpression([weight_step_expr, ListExpression([])], LiftExpression(str(step)))
263
+
264
+ weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
265
+
266
+
267
+ class DeclarationExpression:
268
+ def __init__(self, symbol, parameters, value, target):
269
+ self.__symbol = symbol
270
+ self.__value = value
271
+ self.__target = target
272
+ self.__parameters = parameters
273
+
274
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
275
+ updated_context = dict(context)
276
+ updated_context[self.__symbol] = (self.__value, self.__parameters)
277
+ self.__target.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
278
+
279
+
280
+ class SubstitutionExpression:
281
+ def __init__(self, symbol, arguments):
282
+ self.__symbol = symbol
283
+ self.__arguments = arguments
284
+
285
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
286
+ updated_context = dict(context)
287
+ nested, parameters = context[self.__symbol]
288
+ for argument, parameter in zip(self.__arguments, parameters):
289
+ updated_context[parameter] = argument, []
290
+ nested.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
291
+
292
+
293
+ class LiftExpression:
294
+ def __init__(self, value):
295
+ self.__value = value
296
+
297
+ def extend_tensor(self, tensor_builder, *_args, **_kwargs):
298
+ tensor_builder.append(self.__value)
299
+
300
+
301
+ def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
302
+ mock_database = ['']
303
+ expression.extend_tensor(interpolation_tensor.InterpolationTensorBuilder(prompt_database=mock_database), steps_range, total_steps, context, is_hires, use_old_scheduling)
304
+ try:
305
+ return int(mock_database[0])
306
+ except ValueError:
307
+ return float(mock_database[0])
prompt-fusion-extension-main/lib_prompt_fusion/empty_cond.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion import interpolation_tensor
2
+
3
+
4
+ _empty_cond = None
5
+
6
+
7
+ def get():
8
+ return _empty_cond
9
+
10
+
11
+ def init(model):
12
+ global _empty_cond
13
+ cond = model.get_learned_conditioning([''])
14
+ if isinstance(cond, dict):
15
+ cond = interpolation_tensor.DictCondWrapper({k: v[0] for k, v in cond.items()})
16
+ else:
17
+ cond = interpolation_tensor.TensorCondWrapper(cond[0])
18
+
19
+ _empty_cond = cond
prompt-fusion-extension-main/lib_prompt_fusion/geometries.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from lib_prompt_fusion import interpolation_tensor
4
+
5
+
6
+ def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
7
+ p0, p1 = control_points
8
+ p0_norm = torch.linalg.norm(p0)
9
+ p1_norm = torch.linalg.norm(p1)
10
+
11
+ similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
12
+ similarity = min(1., max(-1., float(similarity)))
13
+ if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
14
+ return linear_geometry(control_points, params)
15
+
16
+ angle = math.acos(float(similarity)) / 2
17
+
18
+ slerp_t = angle * (2 * params.t - 1)
19
+ slerp_t = math.tan(slerp_t) / math.tan(angle)
20
+ slerp_t = (slerp_t + 1) / 2
21
+
22
+ normalized_p1 = p1 / p1_norm * p0_norm
23
+ slerp_p = p0 + (normalized_p1 - p0) * slerp_t
24
+ slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)
25
+
26
+ lerp_p = linear_geometry(control_points, params)
27
+ return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
28
+
29
+
30
+ def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
31
+ p0, p1 = control_points
32
+ res = p0 + (p1 - p0) * params.t
33
+ return res
prompt-fusion-extension-main/lib_prompt_fusion/global_state.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from modules import shared, prompt_parser
3
+ from lib_prompt_fusion import empty_cond
4
+
5
+
6
+ old_webui_is_negative: bool = False
7
+ negative_schedules: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
8
+ negative_schedules_hires: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
9
+
10
+
11
+ def get_origin_cond_at(step: int, is_hires: bool = False):
12
+ fallback_schedules = negative_schedules_hires if is_hires else negative_schedules
13
+ if not fallback_schedules or not shared.opts.data.get('prompt_fusion_slerp_negative_origin', False):
14
+ return empty_cond.get()
15
+
16
+ for schedule in fallback_schedules:
17
+ if schedule.end_at_step >= step:
18
+ return schedule.cond
19
+
20
+ return empty_cond.get()
21
+
22
+
23
+ def get_slerp_scale():
24
+ return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
25
+
26
+
27
+ def get_slerp_epsilon():
28
+ return shared.opts.data.get('prompt_fusion_slerp_epsilon', 0.0001)
prompt-fusion-extension-main/lib_prompt_fusion/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModuleHijacker:
2
+ def __init__(self, module):
3
+ self.__module = module
4
+ self.__original_functions = dict()
5
+
6
+ def hijack(self, attribute):
7
+ if attribute not in self.__original_functions:
8
+ self.__original_functions[attribute] = getattr(self.__module, attribute)
9
+
10
+ def decorator(function):
11
+ def wrapper(*args, **kwargs):
12
+ return function(*args, **kwargs, original_function=self.__original_functions[attribute])
13
+
14
+ setattr(self.__module, attribute, wrapper)
15
+ return function
16
+
17
+ return decorator
18
+
19
+ def reset_module(self):
20
+ for attribute, original_function in self.__original_functions.items():
21
+ setattr(self.__module, attribute, original_function)
22
+
23
+ self.__original_functions.clear()
24
+
25
+ @staticmethod
26
+ def install_or_get(module, hijacker_attribute, register_uninstall=lambda _callback: None):
27
+ if not hasattr(module, hijacker_attribute):
28
+ module_hijacker = ModuleHijacker(module)
29
+ setattr(module, hijacker_attribute, module_hijacker)
30
+ register_uninstall(lambda: delattr(module, hijacker_attribute))
31
+ register_uninstall(module_hijacker.reset_module)
32
+ return module_hijacker
33
+ else:
34
+ return getattr(module, hijacker_attribute)
prompt-fusion-extension-main/lib_prompt_fusion/interpolation_functions.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from lib_prompt_fusion import interpolation_tensor, geometries
4
+
5
+
6
+ def compute_linear(control_points, params: interpolation_tensor.InterpolationParams):
7
+ if len(control_points) <= 2:
8
+ return geometries.slerp_geometry(control_points, params)
9
+ else:
10
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
11
+ cp0 = control_points[target_curve]
12
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
13
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
14
+ return geometries.slerp_geometry([cp0, cp1], new_params)
15
+
16
+
17
+ def compute_bezier(control_points, params: interpolation_tensor.InterpolationParams):
18
+ def compute_casteljau(ps, size):
19
+ for i in reversed(range(1, size)):
20
+ for j in range(i):
21
+ ps[j] = geometries.slerp_geometry([ps[j], ps[j+1]], params)
22
+
23
+ return ps[0]
24
+
25
+ if len(control_points) == 1:
26
+ return control_points[0]
27
+ elif len(control_points) == 2:
28
+ return geometries.slerp_geometry(control_points, params)
29
+ copied_control_points = copy.deepcopy(control_points)
30
+ return compute_casteljau(copied_control_points, len(copied_control_points))
31
+
32
+
33
+ def compute_catmull(control_points, params: interpolation_tensor.InterpolationParams):
34
+ if len(control_points) <= 2:
35
+ return compute_linear(control_points, params)
36
+ else:
37
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
38
+ g0 = control_points[target_curve - 1] if target_curve > 0 else control_points[0] * 2 - control_points[1]
39
+ cp0 = control_points[target_curve]
40
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
41
+ g1 = control_points[target_curve + 2] if target_curve + 2 < len(control_points) else cp1 * 2 - cp0
42
+ ip0 = cp0 + (cp1 - g0)/6
43
+ ip1 = cp1 + (cp0 - g1)/6
44
+
45
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
46
+ return compute_bezier([cp0, ip0, ip1, cp1], new_params)
47
+
48
+
49
+ if __name__ == '__main__':
50
+ import turtle as tr
51
+ import torch
52
+ size = 60
53
+ turtle_tool = tr.Turtle()
54
+ turtle_tool.speed(10)
55
+ turtle_tool.up()
56
+
57
+ points = torch.Tensor([[-2., -2.], [2., 2.]])
58
+ origin = torch.Tensor([1.5, 1.6])
59
+
60
+ def sample(slerp_scale, color):
61
+ for i in range(size):
62
+ t = i / size
63
+ params = interpolation_tensor.InterpolationParams(t, i, size, slerp_scale, 0.0001)
64
+ point = origin + compute_linear(points - origin, params)
65
+ try:
66
+ turtle_tool.goto(tuple(float(p) * 100. for p in point))
67
+ turtle_tool.dot(5, color)
68
+ print(point)
69
+ except ValueError:
70
+ pass
71
+
72
+ sample(0, "black")
73
+ sample(1, "green")
74
+ sample(2, "blue")
75
+ sample(-1, "purple")
76
+ sample(-2, "orange")
77
+
78
+ for point in points:
79
+ turtle_tool.goto(tuple(float(p) * 100. for p in point))
80
+ turtle_tool.dot(5, "red")
81
+
82
+ turtle_tool.goto(tuple(float(p) * 100. for p in origin))
83
+ turtle_tool.dot(10, "red")
84
+
85
+ turtle_tool.goto(100000, 100000)
86
+ turtle_tool.dot()
87
+ tr.done()
prompt-fusion-extension-main/lib_prompt_fusion/interpolation_tensor.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ from modules import prompt_parser
4
+ from typing import NamedTuple, Union
5
+
6
+
7
+ class InterpolationParams(NamedTuple):
8
+ t: float
9
+ step: int
10
+ total_steps: int
11
+ slerp_scale: float
12
+ slerp_epsilon: float
13
+
14
+
15
+ class InterpolationTensor:
16
+ def __init__(self, sub_tensors, interpolation_function):
17
+ self.__sub_tensors = sub_tensors
18
+ self.__interpolation_function = interpolation_function
19
+
20
+ def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
21
+ cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
22
+ if params.slerp_scale != 0:
23
+ cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
24
+ return cond
25
+
26
+ def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
27
+ if self.__interpolation_function is None:
28
+ return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)
29
+
30
+ control_points = [
31
+ sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
32
+ for sub_tensor in self.__sub_tensors
33
+ ]
34
+
35
+ CondWrapper, control_points_values = conds_to_cp_values(control_points)
36
+ return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)
37
+
38
+ def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
39
+ schedule = None
40
+ for schedule in self.__sub_tensors:
41
+ if schedule.end_at_step >= step:
42
+ break
43
+
44
+ res = schedule.cond.extend_like(origin_cond, empty_cond)
45
+ if slerp_scale != 0:
46
+ res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
47
+ return res
48
+
49
+
50
+ def conds_to_cp_values(conds):
51
+ CondWrapper = type(conds[0])
52
+ cp_values = [
53
+ cond.to_cp_values()
54
+ for cond in conds
55
+ ]
56
+ return CondWrapper, [
57
+ [v[i] for v in cp_values]
58
+ for i in range(len(cp_values[0]))
59
+ ]
60
+
61
+
62
+ class InterpolationTensorBuilder:
63
+ def __init__(self, tensor=None, prompt_database=None, interpolation_functions=None):
64
+ self.__indices_tensor = tensor if tensor is not None else 0
65
+ self.__prompt_database = prompt_database if prompt_database is not None else ['']
66
+ self.__interpolation_functions = interpolation_functions if interpolation_functions is not None else []
67
+
68
+ def append(self, suffix):
69
+ for i in range(len(self.__prompt_database)):
70
+ self.__prompt_database[i] += suffix
71
+
72
+ def extrude(self, tensor_updaters, interpolation_function):
73
+ extruded_indices_tensor = []
74
+ extruded_prompt_database = []
75
+ extruded_interpolation_functions = []
76
+
77
+ for update_tensor in tensor_updaters:
78
+ nested_tensor_builder = InterpolationTensorBuilder(
79
+ self.__indices_tensor,
80
+ self.__prompt_database[:],
81
+ interpolation_functions=[])
82
+
83
+ update_tensor(nested_tensor_builder)
84
+
85
+ extruded_indices_tensor.append(InterpolationTensorBuilder.__offset_tensor(
86
+ tensor=nested_tensor_builder.__indices_tensor,
87
+ offset=len(extruded_prompt_database)))
88
+ extruded_prompt_database.extend(nested_tensor_builder.__prompt_database)
89
+ extruded_interpolation_functions.append(nested_tensor_builder.__interpolation_functions)
90
+
91
+ self.__indices_tensor = extruded_indices_tensor
92
+ self.__prompt_database[:] = extruded_prompt_database
93
+ self.__interpolation_functions.insert(0, (interpolation_function, extruded_interpolation_functions))
94
+
95
+ def get_prompt_database(self):
96
+ return self.__prompt_database
97
+
98
+ @staticmethod
99
+ def __offset_tensor(tensor, offset):
100
+ try:
101
+ return tensor + offset
102
+
103
+ except TypeError:
104
+ return [InterpolationTensorBuilder.__offset_tensor(e, offset) for e in tensor]
105
+
106
+ def build(self, conds, empty_cond):
107
+ max_cond_size = self.__max_cond_size(conds)
108
+ conds = self.__resize_uniformly(conds, max_cond_size, empty_cond)
109
+ return InterpolationTensorBuilder.__build_conditionings_tensor(self.__indices_tensor, self.__interpolation_functions, conds)
110
+
111
+ @staticmethod
112
+ def __build_conditionings_tensor(tensor, int_funcs, conds):
113
+ if type(tensor) is int:
114
+ return InterpolationTensor(conds[tensor], None)
115
+ else:
116
+ int_func, nested_int_funcs = int_funcs[0]
117
+ return InterpolationTensor(
118
+ [
119
+ InterpolationTensorBuilder.__build_conditionings_tensor(sub_tensor, nested_int_funcs + int_funcs[1:], conds)
120
+ for sub_tensor, nested_int_funcs in zip(tensor, nested_int_funcs)
121
+ ],
122
+ int_func,
123
+ )
124
+
125
+ def __resize_uniformly(self, conds, max_cond_size: int, empty_cond):
126
+ return [
127
+ [
128
+ prompt_parser.ScheduledPromptConditioning(
129
+ cond=schedule.cond.resize_schedule(max_cond_size, empty_cond),
130
+ end_at_step=schedule.end_at_step
131
+ )
132
+ for schedule in schedules
133
+ ]
134
+ for schedules in conds
135
+ ]
136
+
137
+ @staticmethod
138
+ def __max_cond_size(conds):
139
+ return max(schedule.cond.size(0)
140
+ for schedules in conds
141
+ for schedule in schedules)
142
+
143
+
144
+ @dataclasses.dataclass
145
+ class DictCondWrapper:
146
+ original_cond: dict
147
+
148
+ @staticmethod
149
+ def from_cp_values(cp_values):
150
+ return DictCondWrapper({
151
+ k: v
152
+ for k, v in zip(('crossattn', 'vector'), cp_values)
153
+ })
154
+
155
+ def size(self, *args, **kwargs):
156
+ return self.original_cond['crossattn'].size(*args, **kwargs)
157
+
158
+ def extend_like(self, that, empty):
159
+ missing_size = max(0, that.size(0) - self.size(0)) // 77
160
+ extended = DictCondWrapper(self.original_cond.copy())
161
+ extended.original_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty.original_cond['crossattn']] * missing_size)
162
+ return extended
163
+
164
+ def resize_schedule(self, target_size, empty_cond):
165
+ cond_missing_size = (target_size - self.size(0)) // 77
166
+ if cond_missing_size <= 0:
167
+ return self
168
+
169
+ resized_cond = self.original_cond.copy()
170
+ resized_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty_cond.original_cond['crossattn']] * cond_missing_size)
171
+ return DictCondWrapper(resized_cond)
172
+
173
+ def to_cp_values(self):
174
+ return list(self.original_cond.values())
175
+
176
+ def to(self, dtype: Union[dict, torch.dtype]):
177
+ if not isinstance(dtype, dict):
178
+ dtype = {
179
+ k: dtype
180
+ for k in self.original_cond.keys()
181
+ }
182
+ return DictCondWrapper({
183
+ k: v.to(dtype=dtype[k])
184
+ for k, v in self.original_cond.items()
185
+ })
186
+
187
+ @property
188
+ def dtype(self):
189
+ return {
190
+ k: v.dtype
191
+ for k, v in self.original_cond.items()
192
+ }
193
+
194
+ def __sub__(self, that):
195
+ return DictCondWrapper({
196
+ k: v - that.original_cond[k]
197
+ for k, v in self.original_cond.items()
198
+ })
199
+
200
+ def __add__(self, that):
201
+ return DictCondWrapper({
202
+ k: v + that.original_cond[k]
203
+ for k, v in self.original_cond.items()
204
+ })
205
+
206
+ def __eq__(self, that):
207
+ return all((self.original_cond[k] == that.original_cond[k]).all() for k in self.original_cond.keys())
208
+
209
+
210
+ @dataclasses.dataclass
211
+ class TensorCondWrapper:
212
+ original_cond: torch.Tensor
213
+
214
+ @staticmethod
215
+ def from_cp_values(cp_values):
216
+ return TensorCondWrapper(next(iter(cp_values)))
217
+
218
+ def size(self, *args, **kwargs):
219
+ return self.original_cond.size(*args, **kwargs)
220
+
221
+ def extend_like(self, that, empty):
222
+ missing_size = max(0, that.size(0) - self.original_cond.size(0)) // 77
223
+ return TensorCondWrapper(torch.concatenate([self.original_cond] + [empty.original_cond] * missing_size))
224
+
225
+ def resize_schedule(self, target_size, empty_cond):
226
+ cond_missing_size = (target_size - self.original_cond.size(0)) // 77
227
+ if cond_missing_size <= 0:
228
+ return self
229
+
230
+ return TensorCondWrapper(torch.concatenate([self.original_cond] + [empty_cond.original_cond] * cond_missing_size))
231
+
232
+ def to_cp_values(self):
233
+ return [self.original_cond]
234
+
235
+ def to(self, dtype: torch.dtype):
236
+ return TensorCondWrapper(self.original_cond.to(dtype=dtype))
237
+
238
+ @property
239
+ def dtype(self):
240
+ return self.original_cond.dtype
241
+
242
+ def __sub__(self, that):
243
+ return TensorCondWrapper(self.original_cond - that.original_cond)
244
+
245
+ def __add__(self, that):
246
+ return TensorCondWrapper(self.original_cond + that.original_cond)
247
+
248
+ def __eq__(self, that):
249
+ return (self.original_cond == that.original_cond).all()
prompt-fusion-extension-main/lib_prompt_fusion/prompt_parser.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib_prompt_fusion.ast_nodes as ast
2
+ from collections import namedtuple
3
+ import re
4
+
5
+
6
+ ParseResult = namedtuple('ParseResult', ['prompt', 'expr'])
7
+
8
+
9
+ def parse_prompt(prompt):
10
+ prompt = prompt.lstrip()
11
+ prompt, list_expr = parse_list_expression(prompt, set())
12
+ return list_expr
13
+
14
+
15
+ def parse_list_expression(prompt, stoppers):
16
+ exprs = []
17
+ try:
18
+ while True:
19
+ prompt, expr = parse_expression(prompt, stoppers)
20
+ exprs.append(expr)
21
+ except ValueError:
22
+ return ParseResult(prompt=prompt, expr=ast.ListExpression(exprs))
23
+
24
+
25
+ def parse_expression(prompt, stoppers):
26
+ for parse in _parsers():
27
+ try:
28
+ return parse(prompt, stoppers)
29
+ except ValueError:
30
+ pass
31
+
32
+ raise ValueError
33
+
34
+
35
+ def _parsers():
36
+ return (
37
+ parse_text,
38
+ parse_declaration,
39
+ parse_substitution,
40
+ parse_positive_attention,
41
+ parse_negative_attention,
42
+ parse_editing,
43
+ parse_alternation,
44
+ parse_interpolation,
45
+ parse_unrestricted_text,
46
+ )
47
+
48
+
49
+ def parse_text(prompt, stoppers):
50
+ return parse_unrestricted_text(prompt, set_concat(stoppers, {'[', '(', '$'}))
51
+
52
+
53
+ def parse_unrestricted_text(prompt, stoppers):
54
+ escaped_stoppers = ''.join(re.escape(stopper) for stopper in stoppers)
55
+ regex = rf'(?:[^{escaped_stoppers}\\\s]|\$(?![a-zA-Z_])|\\.)+'
56
+ prompt, expr = parse_token(prompt, whitespace_tail_regex(regex, stoppers))
57
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(expr))
58
+
59
+
60
+ def parse_substitution(prompt, stoppers):
61
+ prompt, symbol = parse_symbol(prompt, stoppers)
62
+ prompt, arguments = parse_arguments(prompt, stoppers)
63
+ return ParseResult(prompt=prompt, expr=ast.SubstitutionExpression(symbol, arguments))
64
+
65
+
66
+ def parse_arguments(prompt, stoppers):
67
+ try:
68
+ prompt, _ = parse_open_paren(prompt, stoppers)
69
+ prompt, arguments = parse_inner_arguments(prompt, stoppers)
70
+ prompt, _ = parse_close_paren(prompt, stoppers)
71
+ except ValueError:
72
+ arguments = []
73
+ return ParseResult(prompt=prompt, expr=arguments)
74
+
75
+
76
+ def parse_inner_arguments(prompt, stoppers):
77
+ arguments = []
78
+ try:
79
+ while True:
80
+ prompt, arg = parse_list_expression(prompt, {',', ')'})
81
+ arguments.append(arg)
82
+ prompt, _ = parse_comma(prompt, stoppers)
83
+ except ValueError:
84
+ pass
85
+
86
+ return ParseResult(prompt=prompt, expr=arguments)
87
+
88
+
89
+ def parse_declaration(prompt, stoppers):
90
+ prompt, symbol = parse_symbol(prompt, stoppers)
91
+ prompt, parameters = parse_parameters(prompt, stoppers)
92
+ prompt, _ = parse_equals(prompt, stoppers)
93
+ prompt, value = parse_list_expression(prompt, set_concat(stoppers, '\n'))
94
+ prompt, _ = parse_newline(prompt, stoppers)
95
+ prompt, expr = parse_list_expression(prompt, stoppers)
96
+ return ParseResult(prompt=prompt, expr=ast.DeclarationExpression(symbol, parameters, value, expr))
97
+
98
+
99
+ def parse_parameters(prompt, stoppers):
100
+ try:
101
+ prompt, _ = parse_open_paren(prompt, stoppers)
102
+ prompt, parameters = parse_inner_parameters(prompt, stoppers)
103
+ prompt, _ = parse_close_paren(prompt, stoppers)
104
+ except ValueError:
105
+ parameters = []
106
+ return ParseResult(prompt=prompt, expr=parameters)
107
+
108
+
109
+ def parse_inner_parameters(prompt, stoppers):
110
+ parameters = []
111
+ try:
112
+ while True:
113
+ prompt, param = parse_symbol(prompt, stoppers)
114
+ parameters.append(param)
115
+ prompt, _ = parse_comma(prompt, stoppers)
116
+ except ValueError:
117
+ pass
118
+
119
+ return ParseResult(prompt=prompt, expr=parameters)
120
+
121
+
122
+ def parse_interpolation(prompt, stoppers):
123
+ prompt, _ = parse_open_square(prompt, stoppers)
124
+ prompt, exprs = parse_interpolation_exprs(prompt, stoppers)
125
+ prompt, steps = parse_interpolation_steps(prompt, stoppers)
126
+ prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
127
+ prompt, _ = parse_close_square(prompt, stoppers)
128
+ return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))
129
+
130
+
131
+ def parse_interpolation_exprs(prompt, stoppers):
132
+ exprs = []
133
+
134
+ try:
135
+ while True:
136
+ prompt_tmp, expr = parse_list_expression(prompt, {':', ']'})
137
+ if parse_interpolation_function_name(prompt_tmp, stoppers).expr is not None:
138
+ raise ValueError
139
+
140
+ prompt, _ = parse_colon(prompt_tmp, stoppers)
141
+ exprs.append(expr)
142
+ except ValueError:
143
+ pass
144
+
145
+ return ParseResult(prompt=prompt, expr=exprs)
146
+
147
+
148
+ def parse_interpolation_function_name(prompt, stoppers):
149
+ try:
150
+ prompt, _ = parse_colon(prompt, stoppers)
151
+ function_names = ('linear', 'catmull', 'bezier', 'mean')
152
+ return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
153
+ except ValueError:
154
+ return ParseResult(prompt=prompt, expr=None)
155
+
156
+
157
+ def parse_interpolation_steps(prompt, stoppers):
158
+ steps = []
159
+
160
+ try:
161
+ while True:
162
+ prompt, step = parse_interpolation_step(prompt, stoppers)
163
+ steps.append(step)
164
+ prompt, _ = parse_comma(prompt, stoppers)
165
+ except ValueError:
166
+ pass
167
+
168
+ return ParseResult(prompt=prompt, expr=steps)
169
+
170
+
171
+ def parse_interpolation_step(prompt, stoppers):
172
+ try:
173
+ return parse_step(prompt, stoppers)
174
+ except ValueError:
175
+ pass
176
+
177
+ if prompt[0] in {',', ':', ']'}:
178
+ return ParseResult(prompt=prompt, expr=None)
179
+
180
+ raise ValueError
181
+
182
+
183
+ def parse_alternation(prompt, stoppers):
184
+ prompt, _ = parse_open_square(prompt, stoppers)
185
+ prompt, exprs = parse_alternation_exprs(prompt, stoppers)
186
+ prompt, speed = parse_alternation_speed(prompt, stoppers)
187
+ prompt, _ = parse_close_square(prompt, stoppers)
188
+ return ParseResult(prompt=prompt, expr=ast.AlternationExpression(exprs, speed))
189
+
190
+
191
+ def parse_alternation_exprs(prompt, stoppers):
192
+ exprs = []
193
+
194
+ try:
195
+ while True:
196
+ prompt, expr = parse_list_expression(prompt, {'|', ':', ']'})
197
+ exprs.append(expr)
198
+ prompt, _ = parse_vertical_bar(prompt, stoppers)
199
+ except ValueError:
200
+ if len(exprs) < 2:
201
+ raise
202
+
203
+ return ParseResult(prompt=prompt, expr=exprs)
204
+
205
+
206
+ def parse_alternation_speed(prompt, stoppers):
207
+ try:
208
+ prompt, _ = parse_colon(prompt, stoppers)
209
+ prompt, speed = parse_step(prompt, stoppers)
210
+ return ParseResult(prompt=prompt, expr=speed)
211
+ except ValueError:
212
+ pass
213
+
214
+ return ParseResult(prompt=prompt, expr=None)
215
+
216
+
217
+ def parse_editing(prompt, stoppers):
218
+ prompt, _ = parse_open_square(prompt, stoppers)
219
+ prompt, exprs = parse_editing_exprs(prompt, stoppers)
220
+ try:
221
+ prompt, step = parse_step(prompt, stoppers)
222
+ except ValueError:
223
+ step = None
224
+
225
+ prompt, _ = parse_close_square(prompt, stoppers)
226
+ return ParseResult(prompt=prompt, expr=ast.EditingExpression(exprs, step))
227
+
228
+
229
+ def parse_editing_exprs(prompt, stoppers):
230
+ exprs = []
231
+
232
+ try:
233
+ for _ in range(2):
234
+ prompt_tmp, expr = parse_list_expression(prompt, {'|', ':', ']'})
235
+ prompt, _ = parse_colon(prompt_tmp, stoppers)
236
+ exprs.append(expr)
237
+ except ValueError:
238
+ pass
239
+
240
+ return ParseResult(prompt=prompt, expr=exprs)
241
+
242
+
243
+ def parse_negative_attention(prompt, stoppers):
244
+ prompt, _ = parse_open_square(prompt, stoppers)
245
+ prompt, expr = parse_list_expression(prompt, set_concat(stoppers, {':', ']'}))
246
+ prompt, _ = parse_close_square(prompt, stoppers)
247
+ return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, positive=False))
248
+
249
+
250
+ def parse_positive_attention(prompt, stoppers):
251
+ prompt, _ = parse_open_paren(prompt, stoppers)
252
+ prompt, expr = parse_list_expression(prompt, {':', ')'})
253
+ prompt, weight_exprs = parse_attention_weights(prompt, stoppers)
254
+ prompt, _ = parse_close_paren(prompt, stoppers)
255
+ if len(weight_exprs) >= 2:
256
+ return ParseResult(prompt=prompt, expr=ast.WeightInterpolationExpression(expr, *weight_exprs[:2]))
257
+ else:
258
+ return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, *weight_exprs[:1]))
259
+
260
+
261
+ def parse_attention_weights(prompt, stoppers):
262
+ weights = []
263
+ try:
264
+ prompt, _ = parse_colon(prompt, stoppers)
265
+ except ValueError:
266
+ return ParseResult(prompt=prompt, expr=weights)
267
+
268
+ while True:
269
+ try:
270
+ prompt, weight_expr = parse_weight(prompt, stoppers)
271
+ weights.append(weight_expr)
272
+ prompt, _ = parse_comma(prompt, stoppers)
273
+ except ValueError:
274
+ return ParseResult(prompt=prompt, expr=weights)
275
+
276
+
277
+ def parse_step(prompt, stoppers):
278
+ try:
279
+ prompt, step = parse_int_not_float(prompt, stoppers)
280
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
281
+ except ValueError:
282
+ pass
283
+ try:
284
+ prompt, step = parse_float(prompt, stoppers)
285
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
286
+ except ValueError:
287
+ pass
288
+
289
+ return parse_substitution(prompt, stoppers)
290
+
291
+
292
+ def parse_weight(prompt, stoppers):
293
+ try:
294
+ prompt, step = parse_float(prompt, stoppers)
295
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
296
+ except ValueError:
297
+ pass
298
+
299
+ return parse_substitution(prompt, stoppers)
300
+
301
+
302
+ def parse_symbol(prompt, stoppers):
303
+ prompt, _ = parse_dollar(prompt)
304
+ return parse_symbol_text(prompt, stoppers)
305
+
306
+
307
+ def parse_symbol_text(prompt, stoppers):
308
+ return parse_token(prompt, whitespace_tail_regex('[a-zA-Z_][a-zA-Z0-9_]*', stoppers))
309
+
310
+
311
+ def parse_float(prompt, stoppers):
312
+ return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))
313
+
314
+
315
+ def parse_int_not_float(prompt, stoppers):
316
+ return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))
317
+
318
+
319
+ def parse_dollar(prompt):
320
+ dollar_sign = re.escape('$')
321
+ return parse_token(prompt, f'({dollar_sign})')
322
+
323
+
324
+ def parse_equals(prompt, stoppers):
325
+ return parse_token(prompt, whitespace_tail_regex(re.escape('='), stoppers))
326
+
327
+
328
+ def parse_comma(prompt, stoppers):
329
+ return parse_token(prompt, whitespace_tail_regex(re.escape(','), stoppers))
330
+
331
+
332
+ def parse_colon(prompt, stoppers):
333
+ return parse_token(prompt, whitespace_tail_regex(re.escape(':'), stoppers))
334
+
335
+
336
+ def parse_vertical_bar(prompt, stoppers):
337
+ return parse_token(prompt, whitespace_tail_regex(re.escape('|'), stoppers))
338
+
339
+
340
+ def parse_open_square(prompt, stoppers):
341
+ return parse_token(prompt, whitespace_tail_regex(re.escape('['), stoppers))
342
+
343
+
344
+ def parse_close_square(prompt, stoppers):
345
+ return parse_token(prompt, whitespace_tail_regex(re.escape(']'), stoppers))
346
+
347
+
348
+ def parse_open_paren(prompt, stoppers):
349
+ return parse_token(prompt, whitespace_tail_regex(re.escape('('), stoppers))
350
+
351
+
352
+ def parse_close_paren(prompt, stoppers):
353
+ return parse_token(prompt, whitespace_tail_regex(re.escape(')'), stoppers))
354
+
355
+
356
+ def parse_newline(prompt, stoppers):
357
+ return parse_token(prompt, whitespace_tail_regex('\n|$', stoppers))
358
+
359
+
360
+ def parse_token(prompt, regex):
361
+ match = re.match(regex, prompt)
362
+ if match is None:
363
+ raise ValueError
364
+
365
+ return ParseResult(prompt=prompt[len(match.group()):], expr=match.groups()[-1])
366
+
367
+
368
+ def whitespace_tail_regex(regex, stoppers):
369
+ if '\n' in stoppers:
370
+ return rf'({regex})[ \t\f\r]*'
371
+
372
+ return rf'({regex})\s*'
373
+
374
+
375
+ def set_concat(left, right):
376
+ result = set(left)
377
+ result.update(right)
378
+ return result
prompt-fusion-extension-main/lib_prompt_fusion/t_scaler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def scale_t(t, positions):
2
+ if t >= 1.:
3
+ return 1.
4
+
5
+ if t <= 0.:
6
+ return 0.
7
+
8
+ distances = []
9
+ for i in range(len(positions)-1):
10
+ distances.append(positions[i+1] - positions[i])
11
+
12
+ total_distance = sum(distances)
13
+ for i in range(len(distances)):
14
+ distances[i] = distances[i]/total_distance
15
+
16
+ for i in range(len(distances)-1):
17
+ distances[i+1] = distances[i] + distances[i+1]
18
+
19
+ distances.insert(0, 0.0)
20
+
21
+ spline_index = 0
22
+ for i, distance in enumerate(distances):
23
+ if t > distance:
24
+ spline_index = i
25
+ else:
26
+ break
27
+
28
+ if spline_index >= len(distances) - 1:
29
+ return 1
30
+
31
+ local_ratio = (t - distances[spline_index]) / (distances[spline_index+1] - distances[spline_index])
32
+ return (spline_index + local_ratio)/(len(distances)-1)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ total_steps = 20
37
+ for i in range(total_steps):
38
+ print(i, scale_t(i/total_steps, [9, 10]))
prompt-fusion-extension-main/metadata.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [Extension]
2
+ Name = prompt-fusion
prompt-fusion-extension-main/readme.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prompt Fusion
2
+
3
+ Prompt Fusion is an [auto1111 webui extension](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions) that adds more possibilities to the native prompt syntax. Among other additions, it allows to interpolate between the embeddings of different prompts, continuously:
4
+
5
+ ```
6
+ # linear prompt interpolation
7
+ [night light:magical forest: 5, 15]
8
+
9
+ # catmull-rom curve prompt interpolation
10
+ [night light:magical forest:norvegian territory: 5, 15, 25:catmull]
11
+
12
+ # alternation interpolation
13
+ [ufo|a strange sight:0.5]
14
+
15
+ # linear attention interpolation
16
+ (fire extinguisher: 1.0, 2.0)
17
+
18
+ # prompt-editing-aware attention interpolation
19
+ [(fire extinguisher: 1.0, 2.0)::5]
20
+
21
+ # weighted sum of conditions
22
+ [space station : kitchen mixer :: mean]
23
+
24
+ # define functions and variables to simplify repeating patterns and use a consistent structure
25
+ $prompt($style, $quality, $character, $background) = (
26
+ A detailed picture in the style of $style,
27
+ $quality,
28
+ $character lying back,
29
+ $background in the background
30
+ :1)
31
+ ```
32
+
33
+ ## Features
34
+
35
+ - [Prompt interpolation using a curve function](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Interpolation)
36
+ - [Attention interpolation aware of contextual prompt editing](https://github.com/ljleb/prompt-fusion-extension/wiki/Attention-Interpolation)
37
+ - [Alternation interpolation](https://github.com/ljleb/prompt-fusion-extension/wiki/Alternation-interpolation)
38
+ - [Prompt weighted sum](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Average)
39
+ - [Prompt variables and functions](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Variables)
40
+ - Complete backwards compatibility with the builtin prompt syntax of the webui
41
+
42
+ The prompt interpolation feature is similar to [Prompt Travel](https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel), which allows to create videos of images generated by navigating the latent space iteratively. Unlike Prompt Travel however, instead of generating multiple images, Prompt Fusion allows you to travel during the sampling process of *a single image*. Also, instead of interpolating the latent space, it uses the embedding space to determine intermediate embeddings.
43
+
44
+ Prompt interpolation is also similar to [Prompt Blending](https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts). The main difference is that this extension calculates a new embedding for every step, as opposed to calculating it once and using that same one embedding for all the steps.
45
+
46
+ The attention interpolation feature is similar to [Shift Attention](https://github.com/yownas/shift-attention), which allows to generate multiple images with slight variations in the attention given to certain parts of the prompt. Unlike Shift Attention, instead of generating multiple images, Prompt Fusion allows to shift the attention of certain parts of a prompt during the sampling process of *a single image*.
47
+
48
+ ## Usage
49
+ - Check the [wiki pages](https://github.com/ljleb/fusion/wiki) for the extension documentation.
50
+
51
+ ## Examples
52
+
53
+ ### 1. Influencing the beginning of the sampling process
54
+
55
+ Interpolate linearly (by default) from `lion` (step 0) to `bird` (step 8) to `girl` (step 11), and stay at `girl` for the rest of the sampling steps:
56
+
57
+ ```
58
+ [lion:bird:girl: , 7, 10]
59
+ ```
60
+
61
+ ![curve1](https://user-images.githubusercontent.com/32277961/214725976-b72bafc6-0c5d-4491-9c95-b73da41da082.gif)
62
+
63
+ ### 2. Influencing the middle of the sampling process
64
+
65
+ Interpolate using a bezier curve from `fireball monster` (step 0) to `dragon monster` (step 12, because 30 steps * 0.4 = step 12), while using `seawater monster` as an intermediate control point to steer the curve away during interpolation and to get creative results:
66
+
67
+ ```
68
+ [fireball:seawater:dragon: , .1, .4:bezier] monster
69
+ ```
70
+
71
+ ![curve2](https://user-images.githubusercontent.com/32277961/214941229-2dccad78-f856-42bb-ae6b-16b65b273cda.gif)
72
+
73
+ ## Webui supported releases
74
+
75
+ The following webui releases are officially supported:
76
+ - `v1.0.0-pre`
77
+ - `master` (there may be a slight lag for issues arising during quick a1111 webui updates)
78
+
79
+ ## Installation
80
+ 1. Visit the **Extensions** tab of Automatic's WebUI.
81
+ 2. Visit the **Available** subtab.
82
+ 3. Look for **Prompt Fusion**.
83
+ 4. Press the **Install** button.
84
+ 5. Wait for the webui to finish downloading the extension.
85
+ 6. Visit the **Installed** subtab.
86
+ 7. click on **Apply and restart UI**.
87
+
88
+ Alternatively, instead of steps 6 and 7, you can restart the webui completely.
89
+
90
+ ## Related Projects
91
+
92
+ - Prompt Travel: https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel
93
+ - Shift Attention: https://github.com/yownas/shift-attention
94
+ - Prompt Blending: https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts
95
+
prompt-fusion-extension-main/requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
prompt-fusion-extension-main/scripts/promptlang.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import re
3
+ from lib_prompt_fusion import hijacker, empty_cond, global_state, interpolation_tensor, prompt_parser as prompt_fusion_parser
4
+ from modules import scripts, script_callbacks, prompt_parser, shared
5
+
6
+
7
+ fusion_hijacker_attribute = '__fusion_hijacker'
8
+ prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
9
+ module=prompt_parser,
10
+ hijacker_attribute=fusion_hijacker_attribute,
11
+ register_uninstall=script_callbacks.on_script_unloaded)
12
+
13
+
14
+ def _neutral_prompt_enabled() -> bool:
15
+ try:
16
+ from lib_neutral_prompt import global_state as np_state
17
+ return bool(getattr(np_state, "is_enabled", False))
18
+ except Exception:
19
+ return False
20
+
21
+
22
+
23
+ def on_ui_settings():
24
+ section = ('prompt-fusion', 'Prompt Fusion')
25
+ shared.opts.add_option('prompt_fusion_enabled', shared.OptionInfo(True, 'Enable prompt-fusion extension', section=section))
26
+ shared.opts.add_option('prompt_fusion_slerp_scale', shared.OptionInfo(0, 'Slerp scale (0 = linear geometry, 1 = slerp geometry)', component=gr.Number, section=section))
27
+ shared.opts.add_option('prompt_fusion_slerp_negative_origin', shared.OptionInfo(True, 'use negative prompt as slerp origin', section=section))
28
+ shared.opts.add_option('prompt_fusion_slerp_epsilon', shared.OptionInfo(0.0001, 'Slerp epsilon (fallback on linear geometry when conds are too similar. 0 = parallel, 1 = perpendicular)', component=gr.Number, section=section))
29
+
30
+
31
+ script_callbacks.on_ui_settings(on_ui_settings)
32
+
33
+
34
+ @prompt_parser_hijacker.hijack('get_learned_conditioning')
35
+ def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, original_function, **kwargs):
36
+ if _neutral_prompt_enabled():
37
+ return original_function(model, prompts, total_steps, *args, **kwargs)
38
+
39
+ if not shared.opts.prompt_fusion_enabled:
40
+ return original_function(model, prompts, total_steps, *args, **kwargs)
41
+
42
+ hires_steps, use_old_scheduling, *_ = args if args else (None, True)
43
+ is_hires = hires_steps is not None
44
+ if is_hires:
45
+ real_total_steps = hires_steps
46
+ else:
47
+ real_total_steps = total_steps
48
+
49
+ if hasattr(prompts, 'is_negative_prompt'):
50
+ is_negative_prompt = prompts.is_negative_prompt
51
+ else:
52
+ is_negative_prompt = global_state.old_webui_is_negative
53
+
54
+ empty_cond.init(model)
55
+
56
+ tensor_builders = _parse_tensor_builders(prompts, real_total_steps, is_hires, use_old_scheduling)
57
+ if hasattr(prompt_parser, 'SdConditioning'):
58
+ empty_conditioning = prompt_parser.SdConditioning(prompts)
59
+ empty_conditioning.clear()
60
+ else:
61
+ empty_conditioning = []
62
+
63
+ flattened_prompts, consecutive_ranges = _get_flattened_prompts(tensor_builders, empty_conditioning)
64
+ flattened_schedules = original_function(model, flattened_prompts, total_steps, *args, **kwargs)
65
+
66
+ if isinstance(flattened_schedules[0][0].cond, dict): # sdxl
67
+ CondWrapper = interpolation_tensor.DictCondWrapper
68
+ else:
69
+ CondWrapper = interpolation_tensor.TensorCondWrapper
70
+
71
+ flattened_schedules = [
72
+ [
73
+ prompt_parser.ScheduledPromptConditioning(cond=CondWrapper(schedule.cond), end_at_step=schedule.end_at_step)
74
+ for schedule in subschedules
75
+ ]
76
+ for subschedules in flattened_schedules
77
+ ]
78
+
79
+ cond_tensors = (tensor_builder.build(flattened_schedules[begin:end], empty_cond.get())
80
+ for begin, end, tensor_builder
81
+ in zip(consecutive_ranges[:-1], consecutive_ranges[1:], tensor_builders))
82
+
83
+ schedules = [_sample_tensor_schedules(cond_tensor, real_total_steps, is_hires)
84
+ for cond_tensor in cond_tensors]
85
+
86
+ if is_negative_prompt:
87
+ if hires_steps is not None:
88
+ global_state.negative_schedules_hires = schedules[0]
89
+ else:
90
+ global_state.negative_schedules = schedules[0]
91
+
92
+ schedules = [
93
+ [
94
+ prompt_parser.ScheduledPromptConditioning(cond=schedule.cond.original_cond, end_at_step=schedule.end_at_step)
95
+ for schedule in subschedules
96
+ ]
97
+ for subschedules in schedules
98
+ ]
99
+
100
+ return schedules
101
+
102
+
103
+ @prompt_parser_hijacker.hijack('get_multicond_learned_conditioning')
104
+ def _hijacked_get_multicond_learned_conditioning(model, prompts, total_steps, *args, original_function, **kwargs):
105
+ if _neutral_prompt_enabled():
106
+ try:
107
+ return original_function(model, prompts, total_steps, *args, **kwargs)
108
+ finally:
109
+ global_state.old_webui_is_negative = False
110
+
111
+ if not shared.opts.prompt_fusion_enabled:
112
+ try:
113
+ return original_function(model, prompts, total_steps, *args, **kwargs)
114
+ finally:
115
+ global_state.old_webui_is_negative = False
116
+
117
+ try:
118
+ hires_steps, use_old_scheduling, *_ = args if args else (None, True)
119
+
120
+ res_indexes, prompt_flat_list = _get_multicond_prompt_list(prompts)
121
+ learned_conditioning = prompt_parser.get_learned_conditioning(
122
+ model,
123
+ prompt_flat_list,
124
+ total_steps,
125
+ hires_steps,
126
+ use_old_scheduling,
127
+ **kwargs,
128
+ )
129
+
130
+ batch = [
131
+ [prompt_parser.ComposableScheduledPromptConditioning(learned_conditioning[i], weight)
132
+ for i, weight in indexes]
133
+ for indexes in res_indexes
134
+ ]
135
+
136
+ return prompt_parser.MulticondLearnedConditioning(
137
+ shape=_infer_multicond_shape(learned_conditioning),
138
+ batch=batch,
139
+ )
140
+ finally:
141
+ global_state.old_webui_is_negative = False
142
+
143
+
144
+ def _parse_tensor_builders(prompts, total_steps, is_hires, use_old_scheduling):
145
+ tensor_builders = []
146
+
147
+ for prompt in prompts:
148
+ expr = prompt_fusion_parser.parse_prompt(prompt)
149
+ tensor_builder = interpolation_tensor.InterpolationTensorBuilder()
150
+ expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires, use_old_scheduling)
151
+ tensor_builders.append(tensor_builder)
152
+
153
+ return tensor_builders
154
+
155
+
156
+ def _get_flattened_prompts(tensor_builders, flattened_prompts=None):
157
+ if flattened_prompts is None:
158
+ flattened_prompts = []
159
+ consecutive_ranges = [0]
160
+
161
+ for tensor_builder in tensor_builders:
162
+ flattened_prompts.extend(tensor_builder.get_prompt_database())
163
+ consecutive_ranges.append(len(flattened_prompts))
164
+
165
+ return flattened_prompts, consecutive_ranges
166
+
167
+
168
+ def _sample_tensor_schedules(tensor, steps, is_hires):
169
+ schedules = []
170
+
171
+ for step in range(steps):
172
+ origin_cond = global_state.get_origin_cond_at(step, is_hires)
173
+ params = interpolation_tensor.InterpolationParams(step / steps, step, steps, global_state.get_slerp_scale(), global_state.get_slerp_epsilon())
174
+ schedule_cond = tensor.interpolate(params, origin_cond, empty_cond.get())
175
+ if schedules and schedules[-1].cond == schedule_cond:
176
+ schedules[-1] = prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedules[-1].cond)
177
+ else:
178
+ schedules.append(prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedule_cond))
179
+
180
+ return schedules
181
+
182
+
183
+ class PromptFusionScript(scripts.Script):
184
+ def title(self):
185
+ return 'Prompt Fusion'
186
+
187
+ def show(self, is_img2img):
188
+ return scripts.AlwaysVisible
189
+
190
+ def process(self, p, *args):
191
+ global_state.negative_schedules = None
192
+ global_state.old_webui_is_negative = True
193
+
194
+
195
+ def _get_multicond_prompt_list(prompts):
196
+ prompt_indexes = {}
197
+ prompt_flat_list = prompt_parser.SdConditioning([], copy_from=prompts)
198
+ res_indexes = []
199
+
200
+ for prompt in prompts:
201
+ indexes = []
202
+ for text, weight in _split_multicond_prompt(prompt):
203
+ index = prompt_indexes.get(text)
204
+ if index is None:
205
+ index = len(prompt_flat_list)
206
+ prompt_flat_list.append(text)
207
+ prompt_indexes[text] = index
208
+ indexes.append((index, weight))
209
+ res_indexes.append(indexes)
210
+
211
+ return res_indexes, prompt_flat_list
212
+
213
+
214
+ def _split_multicond_prompt(prompt):
215
+ parts = []
216
+ buf = []
217
+ depth_paren = depth_brack = depth_brace = 0
218
+ escaped = False
219
+ i = 0
220
+
221
+ while i < len(prompt):
222
+ ch = prompt[i]
223
+
224
+ if escaped:
225
+ buf.append(ch)
226
+ escaped = False
227
+ i += 1
228
+ continue
229
+
230
+ if ch == '\\':
231
+ buf.append(ch)
232
+ escaped = True
233
+ i += 1
234
+ continue
235
+
236
+ if ch == '(':
237
+ depth_paren += 1
238
+ elif ch == ')' and depth_paren > 0:
239
+ depth_paren -= 1
240
+ elif ch == '[':
241
+ depth_brack += 1
242
+ elif ch == ']' and depth_brack > 0:
243
+ depth_brack -= 1
244
+ elif ch == '{':
245
+ depth_brace += 1
246
+ elif ch == '}' and depth_brace > 0:
247
+ depth_brace -= 1
248
+
249
+ if depth_paren == 0 and depth_brack == 0 and depth_brace == 0:
250
+ if _matches_top_level_and(prompt, i):
251
+ _append_multicond_piece(parts, ''.join(buf))
252
+ buf = []
253
+ i += 3
254
+ continue
255
+
256
+ if _matches_top_level_amp(prompt, i):
257
+ _append_multicond_piece(parts, ''.join(buf))
258
+ buf = []
259
+ i += 1
260
+ continue
261
+
262
+ buf.append(ch)
263
+ i += 1
264
+
265
+ _append_multicond_piece(parts, ''.join(buf))
266
+
267
+ return parts or [('', 1.0)]
268
+
269
+
270
+ def _append_multicond_piece(parts, piece):
271
+ text, weight = _split_tail_weight(piece)
272
+ parts.append((text, weight))
273
+
274
+
275
+ def _split_tail_weight(piece):
276
+ piece = piece.strip()
277
+ if not piece:
278
+ return '', 1.0
279
+
280
+ depth_paren = depth_brack = depth_brace = 0
281
+ escaped = False
282
+ split_at = None
283
+
284
+ for i, ch in enumerate(piece):
285
+ if escaped:
286
+ escaped = False
287
+ continue
288
+
289
+ if ch == '\\':
290
+ escaped = True
291
+ continue
292
+
293
+ if ch == '(':
294
+ depth_paren += 1
295
+ elif ch == ')' and depth_paren > 0:
296
+ depth_paren -= 1
297
+ elif ch == '[':
298
+ depth_brack += 1
299
+ elif ch == ']' and depth_brack > 0:
300
+ depth_brack -= 1
301
+ elif ch == '{':
302
+ depth_brace += 1
303
+ elif ch == '}' and depth_brace > 0:
304
+ depth_brace -= 1
305
+ elif ch in '::' and depth_paren == 0 and depth_brack == 0 and depth_brace == 0:
306
+ split_at = i
307
+
308
+ if split_at is None:
309
+ return piece, 1.0
310
+
311
+ suffix = piece[split_at + 1:].strip()
312
+ if not re.fullmatch(r'[-+]?(?:\d+\.?|\d*\.\d+)', suffix):
313
+ return piece, 1.0
314
+
315
+ text = piece[:split_at].rstrip()
316
+ if not text:
317
+ return piece, 1.0
318
+
319
+ return text, float(suffix)
320
+
321
+
322
+ def _matches_top_level_and(prompt, index):
323
+ if prompt[index:index + 3] != 'AND':
324
+ return False
325
+
326
+ prev_ch = prompt[index - 1] if index > 0 else ''
327
+ next_ch = prompt[index + 3] if index + 3 < len(prompt) else ''
328
+
329
+ if (prev_ch.isalnum() or prev_ch == '_') or (next_ch.isalnum() or next_ch == '_'):
330
+ return False
331
+
332
+ return not any(prompt.startswith(f'AND{suffix}', index) for suffix in ('_PERP', '_SALT', '_TOPK'))
333
+
334
+
335
+ def _matches_top_level_amp(prompt, index):
336
+ if prompt[index] != '&':
337
+ return False
338
+
339
+ prev_ch = prompt[index - 1] if index > 0 else ' '
340
+ next_ch = prompt[index + 1] if index + 1 < len(prompt) else ' '
341
+ return prev_ch.isspace() and next_ch.isspace()
342
+
343
+
344
+ def _infer_multicond_shape(learned_conditioning):
345
+ if not learned_conditioning or not learned_conditioning[0]:
346
+ return (0,)
347
+
348
+ cond = learned_conditioning[0][0].cond
349
+ if isinstance(cond, dict):
350
+ crossattn = cond.get('crossattn')
351
+ if isinstance(crossattn, list) and crossattn:
352
+ return getattr(crossattn[0], 'shape', None) or (0,)
353
+ return getattr(crossattn, 'shape', None) or (0,)
354
+
355
+ return getattr(cond, 'shape', None) or (0,)
prompt-fusion-extension-main/test/parser_tests.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion.prompt_parser import parse_prompt
2
+ from lib_prompt_fusion.interpolation_tensor import InterpolationTensorBuilder
3
+
4
+
5
+ def run_functional_tests(total_steps=100):
6
+ for i, (given, expected) in enumerate(functional_parse_test_cases):
7
+ expr = parse_prompt(given)
8
+ tensor_builder = InterpolationTensorBuilder()
9
+ expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
10
+
11
+ actual = tensor_builder.get_prompt_database()
12
+
13
+ if type(expected) is set:
14
+ assert set(actual) == expected, f"{actual} != {expected}"
15
+ else:
16
+ assert len(actual) == 1 and actual[0] == expected, f"'{actual[0]}' != '{expected}'"
17
+
18
+
19
+ functional_parse_test_cases = [
20
+ ('single',)*2,
21
+ ('some space separated text',)*2,
22
+ ('(legacy weighted prompt:-2.1)',)*2,
23
+ ('mixed (legacy weight:3.6) and text',)*2,
24
+ ('legacy [range begin:0] thingy',)*2,
25
+ ('legacy [range end::3] thingy',)*2,
26
+ ('legacy [[nested range::3]:2] thingy',)*2,
27
+ ('legacy [[nested range:2]::3] thingy',)*2,
28
+ ('sugar [range:,abc:3] thingy',)*2,
29
+ ('sugar [[(weight interpolation:0,12):0]::1] thingy', 'sugar [[(weight interpolation:0.0):0]::1] thingy'),
30
+ ('sugar [[(weight interpolation:0,12):0]::2] thingy', 'sugar [[[(weight interpolation:0.0)::1][(weight interpolation:12.0):1]:0]::2] thingy'),
31
+ ('sugar [[(weight interpolation:0,12):0]::3] thingy', 'sugar [[[(weight interpolation:0.0)::1][[(weight interpolation:6.0):1]::2][(weight interpolation:12.0):2]:0]::3] thingy'),
32
+ ('legacy [from:to:2] thingy',)*2,
33
+ ('legacy [negative weight]',)*2,
34
+ ('legacy (positive weight)',)*2,
35
+ ('[abc:1girl:2]',)*2,
36
+ ('[::]',)*2,
37
+ ('[a:b:]',)*2,
38
+ ('[[a:b:1,2]:b:]', {'[a:b:]', '[b:b:]'}),
39
+ ('1girl',)*2,
40
+ ('dashes-in-text',)*2,
41
+ ('text, separated with, comas',)*2,
42
+ ('{prompt}',)*2,
43
+ ('[abc|def ghi|jkl]',)*2,
44
+ ('merging this AND with this',)*2,
45
+ (':',)*2,
46
+ (r'portrait \(object\)',)*2,
47
+ (r'\[escaped square\]',)*2,
48
+ (r'\$var = abc',)*2,
49
+ (r'\\$ arst',)*2,
50
+ (r'$$ arst',)*2,
51
+ ('$var = abc', ''),
52
+ ('$a = prompt value\n$a', 'prompt value'),
53
+ ('$a = prompt value\n$b = $a\n$b', 'prompt value'),
54
+ ('$a = (multiline\nprompt\nvalue:1.0)\n$a', '(multiline prompt value:1.0)'),
55
+ ('$a = ($aa = nested variable\nmultiline\n$aa:1.0)\n$a', '(multiline nested variable:1.0)'),
56
+ ('a [b:c:-1, 10] d', {'a b d', 'a c d'}),
57
+ ('a [b:c:5, 6] d', {'a b d', 'a c d'}),
58
+ ('a [b:c:0.25, 0.5] d', {'a b d', 'a c d'}),
59
+ ('a [b:c:.25, .5] d', {'a b d', 'a c d'}),
60
+ ('a [b:c:,] d', {'a b d', 'a c d'}),
61
+ ('0[1.0:1.1:,]2[3.0:3.1:,]4', {
62
+ '0 1.0 2 3.0 4', '0 1.1 2 3.0 4',
63
+ '0 1.0 2 3.1 4', '0 1.1 2 3.1 4',
64
+ }),
65
+ ('0[1.0:1.1:1.2:,.5,]2[3.0:3.1:,]4', {
66
+ '0 1.0 2 3.0 4', '0 1.0 2 3.1 4',
67
+ '0 1.1 2 3.0 4', '0 1.1 2 3.1 4',
68
+ '0 1.2 2 3.0 4', '0 1.2 2 3.1 4',
69
+ }),
70
+ ('[0.0:0.1:,][1.0:1.1:,][2.0:2.1:,]', {
71
+ '0.0 1.0 2.0', '0.0 1.0 2.1',
72
+ '0.1 1.0 2.0', '0.1 1.0 2.1',
73
+ '0.0 1.1 2.0', '0.0 1.1 2.1',
74
+ '0.1 1.1 2.0', '0.1 1.1 2.1',
75
+ }),
76
+ ('[top level:interpolatin:lik a pro:1,3,5:linear]', {'top level', 'interpolatin', 'lik a pro'}),
77
+ ('[[nested:expr:,]:abc:,]', {'nested', 'expr', 'abc'}),
78
+ ('[(nested attention:2.0):abc:,]', {'(nested attention:2.0)', 'abc'}),
79
+ ('[[nested editing:15]:abc:,]', {'[nested editing:15]', 'abc'}),
80
+ ('[[nested interpolation:abc:,]:12]', {'[nested interpolation:12]', '[abc:12]'}),
81
+ ('[[nested interpolation:abc:,]::7]', {'[nested interpolation::7]', '[abc::7]'}),
82
+ ('$attention = 1.5\n(prompt:$attention)', '(prompt:1.5)'),
83
+ ('$a = 0\n$b = 12\n[[(prompt:$a,$b):0]::2]', '[[[(prompt:0.0)::1][(prompt:12.0):1]:0]::2]'),
84
+ ('$step = 5\n[legacy:editing:$step]', '[legacy:editing:5]'),
85
+ ('$begin = 2\n$end = 7\n[prompt:interpolation:$begin, $end]', {'prompt', 'interpolation'}),
86
+ ('$a($b, $c) = prompt with $b, prompt with $c\n$a(cat, dog)', 'prompt with cat , prompt with dog'),
87
+ ('$a($b) = prompt with $b\n$c($d) = yeay $a($d)\n$c(dog)', 'yeay prompt with dog'),
88
+ ('$a = a lot of animals\n$b($c) = I love $c\n$b($a)', 'I love a lot of animals'),
89
+ ('$a($b) = prompt with $b\n$c($d) = yeay $d\n$a($c(dog))', 'prompt with yeay dog'),
90
+ ('[a|b|c]', '[a|b|c]'),
91
+ ('[a|b|c:]', '[a|b|c]'),
92
+ ('[a|b|c:1]', {'a', 'b', 'c'}),
93
+ ('[a|b|c:2]', {'a', 'b', 'c'}),
94
+ ('[a|b|c:0.5]', {'a', 'b', 'c'}),
95
+ ('[a|b|c:1.1]', {'a', 'b', 'c'}),
96
+ ('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
97
+ ('[a:b:c::mean]', {'a', 'b', 'c'}),
98
+ ('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
99
+ ('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
100
+ ]
101
+
102
+
103
+ def run_tests():
104
+ run_functional_tests()
prompt-fusion-extension-main/test/run_all.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('..')
3
+ import parser_tests
4
+
5
+
6
+ if __name__ == '__main__':
7
+ parser_tests.run_tests()