dikdimon commited on
Commit
7fbe17a
·
verified ·
1 Parent(s): 545a4ea

Upload 20 files

Browse files
neutral_prompt_patched/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
neutral_prompt_patched/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 ljleb
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
neutral_prompt_patched/README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sd-webui-neutral-prompt — Ultimate Edition
2
+
3
+ > **Unified merge of the main experimental branches.**
4
+ > Most features from all branches are merged and work simultaneously.
5
+ > Some `life`/dev-only tooling (debug visualisations, mask images) was intentionally not carried over — only the production-safe subset is included.
6
+
7
+ ---
8
+
9
+ ## What's inside
10
+
11
+ | Feature | Origin branch | Keyword / setting |
12
+ |---|---|---|
13
+ | Perpendicular projection | `main` | `AND_PERP` |
14
+ | Saliency-guided mask | `main` | `AND_SALT` |
15
+ | Semantic guidance top-k | `main` | `AND_TOPK` |
16
+ | CFG rescale (mean-preserving) | `main` | CFG rescale φ slider |
17
+ | XYZ-grid CFG rescale axis | `main` | XYZ-grid option |
18
+ | External API override | `main` / `export_rescale_factor` | `override_cfg_rescale()` |
19
+ | CFG rescale factor export | `export_rescale_factor` | `get_last_cfg_rescale_factor()` |
20
+ | Affine spatial transforms | `affine` | `ROTATE / SLIDE / SCALE / SHEAR` prefix |
21
+ | Soft alignment blend | `alignment_blend` | `AND_ALIGN_D_S` |
22
+ | Binary alignment mask | `alignment_mask` | `AND_MASK_ALIGN_D_S` |
23
+ | Sharpened saliency maps (k=20) | `life` (production-safe subset) | used automatically by `AND_SALT` |
24
+
25
+ ---
26
+
27
+ ## Prompt syntax
28
+
29
+ ```
30
+ positive prompt text
31
+ AND_PERP negative-direction text :0.8
32
+ AND_SALT competing concept :1.0
33
+ AND_TOPK fine detail tweak :0.5
34
+ AND_ALIGN_4_8 style blend :0.6
35
+ AND_MASK_ALIGN_4_8 structure-preserving style :0.6
36
+ AND_PERP ROTATE[0.125] rotated perpendicular concept :1.0
37
+ ROTATE[0.125] AND_PERP rotated perpendicular concept :1.0
38
+ AND_SALT SLIDE[0.05,0] spatially shifted concept :1.0
39
+ ```
40
+
41
+ ### Affine transform keywords
42
+ Affine keywords are supported in **either order** relative to the conciliation keyword:
43
+
44
+ ```
45
+ AND_PERP ROTATE[0.125] vivid colors :0.8 ← affine after keyword
46
+ ROTATE[0.125] AND_PERP vivid colors :0.8 ← affine before keyword (both work)
47
+ ```
48
+
49
+ | Keyword | Parameter | Effect |
50
+ |---|---|---|
51
+ | `ROTATE[angle]` | angle in turns (0–1) | Rotate latent contribution |
52
+ | `SLIDE[x,y]` | normalised offset | Translate latent contribution |
53
+ | `SCALE[x,y]` | scale factors | Scale latent contribution |
54
+ | `SHEAR[x,y]` | shear in turns | Shear latent contribution |
55
+
56
+ Multiple transforms can be chained: `ROTATE[0.125] SLIDE[0.1,0]`
57
+
58
+ ### AND_ALIGN_D_S / AND_MASK_ALIGN_D_S
59
+ - **D** = detail kernel size (small → fine detail)
60
+ - **S** = structure kernel size (large → global composition)
61
+ - The child prompt is blended in proportionally to how much it alters detail *without* changing structure.
62
+ - `AND_ALIGN` uses a soft weight; `AND_MASK_ALIGN` uses a binary 0/1 mask.
63
+ - Supported range: D, S ∈ [2, 32], D ≠ S.
64
+
65
+ ---
66
+
67
+ ## External API
68
+
69
+ ```python
70
+ # In another extension or script:
71
+ from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
72
+
73
+ override_cfg_rescale(0.7) # override for next step only
74
+ factor = get_last_cfg_rescale_factor() # read after generation
75
+ ```
76
+
77
+ ---
78
+
79
+ ## Installation
80
+
81
+ 1. Clone / copy this folder into `stable-diffusion-webui/extensions/`.
82
+ 2. Restart the webui.
83
+
84
+ ---
85
+
86
+ ## CFG Rescale φ
87
+
88
+ When set to a value > 0 the extension rescales the CFG output to have the
89
+ same mean and a partially rescaled standard deviation as the raw predicted x0.
90
+ This reduces colour over-saturation at high CFG values without affecting
91
+ prompt adherence as much as simply lowering CFG.
92
+
93
+ The formula used is the **mean-preserving** variant from the `main` branch:
94
+
95
+ ```
96
+ rescaled = rescale_mean + (cfg_cond − cfg_cond_mean) × rescale_factor
97
+ ```
98
+
99
+ where `rescale_factor = φ × (std(cond)/std(cfg_cond) − 1) + 1`.
neutral_prompt_patched/lib_neutral_prompt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # lib_neutral_prompt package
neutral_prompt_patched/lib_neutral_prompt/affine_transform.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Affine spatial transform utilities for latent tensors.
3
+ Extracted from the `affine` branch so that cfg_denoiser_hijack can stay lean.
4
+
5
+ Provides:
6
+ apply_affine_transform() – apply a 2×3 affine grid to a C×H×W tensor
7
+ apply_masked_transform() – apply affine + cosine-feathered mask blending
8
+ create_cosine_feathered_mask() – smooth circular weight mask
9
+ """
10
+
11
+ from typing import Tuple
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+
17
+ def apply_affine_transform(
18
+ tensor: torch.Tensor,
19
+ affine: torch.Tensor,
20
+ mode: str = 'bilinear',
21
+ ) -> torch.Tensor:
22
+ """
23
+ Apply a 2×3 affine transform to a C×H×W tensor, preserving aspect ratio.
24
+
25
+ :param tensor: Input tensor of shape [C, H, W].
26
+ :param affine: 2×3 float32 affine matrix.
27
+ :param mode: Interpolation mode for grid_sample ('bilinear' default).
28
+ The original affine branch used bilinear for the pre-noise
29
+ inverse path (smoother) and nearest for post-combine.
30
+ Bilinear is the better default for both.
31
+ :return: Transformed tensor of shape [C, H, W].
32
+ """
33
+ affine = affine.clone().to(tensor.device)
34
+ aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
35
+ affine[0, 1] *= aspect_ratio
36
+ affine[1, 0] /= aspect_ratio
37
+
38
+ grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
39
+ return F.grid_sample(tensor.unsqueeze(0), grid, mode=mode, align_corners=False).squeeze(0)
40
+
41
+
42
+ def apply_masked_transform(
43
+ tensor: torch.Tensor,
44
+ affine: torch.Tensor,
45
+ weight: float,
46
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
47
+ """
48
+ Apply an affine transform with a cosine-feathered spatial weight mask.
49
+
50
+ The mask channel is appended to the tensor before transformation so that
51
+ the same spatial warp is applied to both content and mask consistently.
52
+
53
+ :param tensor: C×H×W latent tensor.
54
+ :param affine: 2×3 affine matrix.
55
+ :param weight: Global scalar weight for the mask.
56
+ :return: (transformed_content [C, H, W], transformed_mask [H, W])
57
+ """
58
+ mask = create_cosine_feathered_mask(tensor.shape[-2:], weight).unsqueeze(0).to(tensor.device)
59
+ tensor_with_mask = torch.cat([tensor, mask], dim=0) # [C+1, H, W]
60
+ transformed = apply_affine_transform(tensor_with_mask, affine)
61
+ return transformed[:-1], transformed[-1] # content, mask
62
+
63
+
64
+ def create_cosine_feathered_mask(size: Tuple[int, int], weight: float) -> torch.Tensor:
65
+ """
66
+ Create a circularly-clipped cosine-feathered mask of shape [H, W].
67
+
68
+ Values at the centre approach `weight`; values outside the unit circle
69
+ are exactly 0, with a smooth cosine fall-off in between.
70
+
71
+ :param size: (H, W) tuple.
72
+ :param weight: Peak mask value at the centre.
73
+ :return: Float32 mask tensor of shape [H, W].
74
+ """
75
+ y, x = torch.meshgrid(
76
+ torch.linspace(-1, 1, size[0]),
77
+ torch.linspace(-1, 1, size[1]),
78
+ indexing='ij',
79
+ )
80
+ dist = torch.sqrt(x ** 2 + y ** 2)
81
+ mask = 0.5 * (1.0 + torch.cos(torch.pi * dist))
82
+ mask[dist > 1] = 0.0
83
+ return mask.float() * weight
neutral_prompt_patched/lib_neutral_prompt/cfg_denoiser_hijack.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified CFG Denoiser Hijack
3
+ ===========================
4
+ Combines features from all sd-webui-neutral-prompt branches:
5
+
6
+ • Core PERP / SALT / TOPK strategies (main branch)
7
+ • cfg_rescale with mean-preserving formula (main branch – better than multiplicative)
8
+ • cfg_rescale_override (XYZ-grid / API support) (main branch)
9
+ • CFGRescaleFactorSingleton export (export_rescale_factor branch)
10
+ • Affine spatial transforms per-prompt (affine branch)
11
+ • AND_ALIGN_D_S – soft alignment blend (alignment_blend branch)
12
+ • AND_MASK_ALIGN_D_S – binary alignment mask (alignment_mask branch)
13
+ • Improved salience with sharpness parameter k (life branch – production-safe subset)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import dataclasses
19
+ import functools
20
+ import re
21
+ import sys
22
+ import textwrap
23
+ from typing import Dict, List, Tuple
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+
28
+ from lib_neutral_prompt import affine_transform as affine_mod
29
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser
30
+ from modules import script_callbacks, sd_samplers, shared
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Pre-noise affine hook (affine branch)
35
+ #
36
+ # Applied BEFORE each CFG denoising step: the noisy latent slice for every
37
+ # cond prompt is warped by the INVERSE of that prompt's affine transform.
38
+ # This means the UNet sees the latent in the "locally-rotated" frame of each
39
+ # prompt. combine_denoised then applies the FORWARD transform to the
40
+ # resulting cond delta, restoring it to the global frame.
41
+ # ---------------------------------------------------------------------------
42
+
43
+ @dataclasses.dataclass
44
+ class _PreNoiseArgs:
45
+ x: torch.Tensor
46
+ cond_indices: List[Tuple[int, float]]
47
+
48
+
49
+ class _GlobalToLocalAffineVisitor:
50
+ """Build a {cond_index: inverse_affine} dict for every leaf prompt."""
51
+
52
+ def visit_leaf_prompt(
53
+ self,
54
+ that: neutral_prompt_parser.LeafPrompt,
55
+ args: _PreNoiseArgs,
56
+ index: int,
57
+ ) -> Dict[int, torch.Tensor]:
58
+ cond_index = args.cond_indices[index][0]
59
+ if that.local_transform is not None:
60
+ # 2×3 → 3×3 → invert → back to 2×3
61
+ m3 = torch.vstack([that.local_transform,
62
+ torch.tensor([0.0, 0.0, 1.0])])
63
+ inv = torch.linalg.inv(m3)[:-1]
64
+ else:
65
+ inv = torch.eye(3)[:-1]
66
+ return {cond_index: inv}
67
+
68
+ def visit_composite_prompt(
69
+ self,
70
+ that: neutral_prompt_parser.CompositePrompt,
71
+ args: _PreNoiseArgs,
72
+ index: int,
73
+ ) -> Dict[int, torch.Tensor]:
74
+ inv_transforms: Dict[int, torch.Tensor] = {}
75
+
76
+ for child in that.children:
77
+ inv_transforms.update(child.accept(_GlobalToLocalAffineVisitor(), args, index))
78
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
79
+
80
+ # Compose parent transform on top of all children
81
+ if that.local_transform is not None:
82
+ m3 = torch.vstack([that.local_transform,
83
+ torch.tensor([0.0, 0.0, 1.0])])
84
+ parent_inv = torch.linalg.inv(m3)
85
+ for inv in inv_transforms.values():
86
+ # inv is 2×3; extend to 3×3, multiply, trim back
87
+ inv3 = torch.vstack([inv, torch.tensor([0.0, 0.0, 1.0])])
88
+ inv[:] = (parent_inv @ inv3)[:-1]
89
+
90
+ return inv_transforms
91
+
92
+
93
+ def _on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams) -> None:
94
+ """Pre-noise hook: warp each cond latent slice by the inverse affine."""
95
+ if not global_state.is_enabled:
96
+ return
97
+
98
+ for prompt, cond_indices in zip(global_state.prompt_exprs,
99
+ global_state.batch_cond_indices):
100
+ args = _PreNoiseArgs(params.x, cond_indices)
101
+ inv_transforms = prompt.accept(_GlobalToLocalAffineVisitor(), args, 0)
102
+ for cond_index, _ in cond_indices:
103
+ params.x[cond_index] = affine_mod.apply_affine_transform(
104
+ params.x[cond_index], inv_transforms[cond_index]
105
+ )
106
+
107
+
108
+ script_callbacks.on_cfg_denoiser(_on_cfg_denoiser)
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # Public entry point
113
+ # ---------------------------------------------------------------------------
114
+
115
+ def combine_denoised_hijack(
116
+ x_out: torch.Tensor,
117
+ batch_cond_indices: List[List[Tuple[int, float]]],
118
+ text_uncond: torch.Tensor,
119
+ cond_scale: float,
120
+ original_function,
121
+ ) -> torch.Tensor:
122
+ if not global_state.is_enabled:
123
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
124
+
125
+ denoised = _get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
126
+ uncond = x_out[-text_uncond.shape[0]:]
127
+
128
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
129
+ args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
130
+ cond_delta = prompt.accept(_CondDeltaVisitor(), args, 0)
131
+ aux_cond_delta = prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, 0)
132
+
133
+ # Apply per-prompt affine transform to both deltas (affine branch)
134
+ if prompt.local_transform is not None:
135
+ cond_delta = affine_mod.apply_affine_transform(cond_delta, prompt.local_transform)
136
+ aux_cond_delta = affine_mod.apply_affine_transform(aux_cond_delta, prompt.local_transform)
137
+
138
+ cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
139
+ denoised[batch_i] = _cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
140
+
141
+ return denoised
142
+
143
+
144
+ # ---------------------------------------------------------------------------
145
+ # Internal helpers
146
+ # ---------------------------------------------------------------------------
147
+
148
+ def _get_webui_denoised(
149
+ x_out: torch.Tensor,
150
+ batch_cond_indices: List[List[Tuple[int, float]]],
151
+ text_uncond: torch.Tensor,
152
+ cond_scale: float,
153
+ original_function,
154
+ ) -> torch.Tensor:
155
+ uncond = x_out[-text_uncond.shape[0]:]
156
+ sliced_batch_x_out: List[torch.Tensor] = []
157
+ sliced_batch_cond_indices: List[List[Tuple[int, float]]] = []
158
+
159
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
160
+ args = _DenoiseArgs(x_out, uncond[batch_i], cond_indices)
161
+ sliced_x_out, sliced_indices = _gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
162
+ if sliced_indices:
163
+ sliced_batch_cond_indices.append(sliced_indices)
164
+ sliced_batch_x_out.extend(sliced_x_out)
165
+
166
+ sliced_batch_x_out += list(uncond)
167
+ return original_function(
168
+ torch.stack(sliced_batch_x_out, dim=0),
169
+ sliced_batch_cond_indices,
170
+ text_uncond,
171
+ cond_scale,
172
+ )
173
+
174
+
175
+ def _cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
176
+ """
177
+ Mean-preserving CFG rescale (main branch formula).
178
+ Override is applied first so XYZ-grid / external API overrides are never silently skipped.
179
+ Also stores the computed rescale factor in CFGRescaleFactorSingleton.
180
+ """
181
+ # Clear last step's factor so get() returns None if rescaling is skipped
182
+ # this step (stale value bug fix).
183
+ global_state.CFGRescaleFactorSingleton.clear()
184
+
185
+ # Apply one-shot override BEFORE the early-exit check, otherwise a 0→nonzero
186
+ # override (from the external API) would be silently discarded.
187
+ global_state.apply_and_clear_cfg_rescale_override()
188
+
189
+ if global_state.cfg_rescale == 0:
190
+ return cfg_cond
191
+
192
+ cfg_std = cfg_cond.std()
193
+ if cfg_std == 0:
194
+ # Degenerate case: constant tensor – rescaling is a no-op.
195
+ return cfg_cond
196
+
197
+ cfg_cond_mean = cfg_cond.mean()
198
+ rescale_mean = (
199
+ (1 - global_state.cfg_rescale) * cfg_cond_mean
200
+ + global_state.cfg_rescale * cond.mean()
201
+ )
202
+ rescale_factor = global_state.cfg_rescale * (cond.std() / cfg_std - 1) + 1
203
+
204
+ # Export for external consumers (export_rescale_factor branch)
205
+ global_state.CFGRescaleFactorSingleton.set(
206
+ rescale_factor.item() if isinstance(rescale_factor, torch.Tensor) else float(rescale_factor)
207
+ )
208
+
209
+ return rescale_mean + (cfg_cond - cfg_cond_mean) * rescale_factor
210
+
211
+
212
+ @dataclasses.dataclass
213
+ class _DenoiseArgs:
214
+ x_out: torch.Tensor
215
+ uncond: torch.Tensor
216
+ cond_indices: List[Tuple[int, float]]
217
+
218
+
219
+ # ---------------------------------------------------------------------------
220
+ # Gather webui-style conditions (needed for the webui's own CFG path)
221
+ # ---------------------------------------------------------------------------
222
+
223
+ def _gather_webui_conds(
224
+ prompt: neutral_prompt_parser.CompositePrompt,
225
+ args: _DenoiseArgs,
226
+ index_in: int,
227
+ index_out: int,
228
+ ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
229
+ sliced_x_out: List[torch.Tensor] = []
230
+ sliced_cond_indices: List[Tuple[int, float]] = []
231
+
232
+ for child in prompt.children:
233
+ if child.conciliation is None:
234
+ if isinstance(child, neutral_prompt_parser.LeafPrompt) and child.local_transform is None:
235
+ child_x_out = args.x_out[args.cond_indices[index_in][0]]
236
+ child_weight = child.weight
237
+ else:
238
+ child_x_out, child_weight = _get_cond_delta_and_weight(child, args, index_in)
239
+ child_x_out = child_x_out + args.uncond
240
+
241
+ index_offset = index_out + len(sliced_x_out)
242
+ sliced_x_out.append(child_x_out)
243
+ sliced_cond_indices.append((index_offset, child_weight))
244
+
245
+ index_in += child.accept(neutral_prompt_parser.FlatSizeVisitor())
246
+
247
+ return sliced_x_out, sliced_cond_indices
248
+
249
+
250
+ def _get_cond_delta_and_weight(
251
+ prompt: neutral_prompt_parser.PromptExpr,
252
+ args: _DenoiseArgs,
253
+ index: int,
254
+ ) -> Tuple[torch.Tensor, float]:
255
+ """Compute cond delta and effective weight, applying affine transform if present."""
256
+ cond_delta = prompt.accept(_CondDeltaVisitor(), args, index)
257
+ cond_delta = cond_delta + prompt.accept(_AuxCondDeltaVisitor(), args, cond_delta, index)
258
+ weight = prompt.weight
259
+
260
+ if prompt.local_transform is not None:
261
+ transformed, weight_tensor = affine_mod.apply_masked_transform(
262
+ cond_delta + args.uncond,
263
+ prompt.local_transform,
264
+ prompt.weight,
265
+ )
266
+ cond_delta = transformed - args.uncond
267
+ weight = weight_tensor # type: ignore[assignment]
268
+
269
+ return cond_delta, weight
270
+
271
+
272
+ # ---------------------------------------------------------------------------
273
+ # Visitor: CondDelta – weighted sum of leaf cond − uncond
274
+ # ---------------------------------------------------------------------------
275
+
276
+ class _CondDeltaVisitor:
277
+ def visit_leaf_prompt(
278
+ self,
279
+ that: neutral_prompt_parser.LeafPrompt,
280
+ args: _DenoiseArgs,
281
+ index: int,
282
+ ) -> torch.Tensor:
283
+ cond_info = args.cond_indices[index]
284
+ if that.weight != cond_info[1]:
285
+ _console_warn(f'''
286
+ An unexpected noise weight was encountered at prompt #{index}
287
+ Expected :{that.weight}, but got :{cond_info[1]}
288
+ This is likely due to another extension also monkey patching `combine_denoised`.
289
+ Please open a bug report: https://github.com/ljleb/sd-webui-neutral-prompt/issues
290
+ ''')
291
+ return args.x_out[cond_info[0]] - args.uncond
292
+
293
+ def visit_composite_prompt(
294
+ self,
295
+ that: neutral_prompt_parser.CompositePrompt,
296
+ args: _DenoiseArgs,
297
+ index: int,
298
+ ) -> torch.Tensor:
299
+ cond_delta = torch.zeros_like(args.x_out[0])
300
+ for child in that.children:
301
+ if child.conciliation is None:
302
+ child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
303
+ cond_delta = cond_delta + child_weight * child_delta
304
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
305
+ return cond_delta
306
+
307
+
308
+ # ---------------------------------------------------------------------------
309
+ # Visitor: AuxCondDelta – all conciliation strategies
310
+ # ---------------------------------------------------------------------------
311
+
312
+ class _AuxCondDeltaVisitor:
313
+ def visit_leaf_prompt(
314
+ self,
315
+ that: neutral_prompt_parser.LeafPrompt,
316
+ args: _DenoiseArgs,
317
+ cond_delta: torch.Tensor,
318
+ index: int,
319
+ ) -> torch.Tensor:
320
+ return torch.zeros_like(args.x_out[0])
321
+
322
+ def visit_composite_prompt(
323
+ self,
324
+ that: neutral_prompt_parser.CompositePrompt,
325
+ args: _DenoiseArgs,
326
+ cond_delta: torch.Tensor,
327
+ index: int,
328
+ ) -> torch.Tensor:
329
+ aux_cond_delta = torch.zeros_like(args.x_out[0])
330
+ salient_cond_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT k=20
331
+ salient_wide_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT_WIDE k=1
332
+ salient_blob_deltas: List[Tuple[torch.Tensor, float]] = [] # AND_SALT_BLOB k=20 + morphology
333
+ align_blend_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
334
+ mask_align_deltas: List[Tuple[torch.Tensor, float, int, int]] = []
335
+
336
+ for child in that.children:
337
+ if child.conciliation is not None:
338
+ child_delta, child_weight = _get_cond_delta_and_weight(child, args, index)
339
+ strat = child.conciliation
340
+
341
+ if strat == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
342
+ aux_cond_delta = aux_cond_delta + child_weight * _get_perpendicular_component(cond_delta, child_delta)
343
+
344
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
345
+ salient_cond_deltas.append((child_delta, child_weight))
346
+
347
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK_WIDE:
348
+ salient_wide_deltas.append((child_delta, child_weight))
349
+
350
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK_BLOB:
351
+ salient_blob_deltas.append((child_delta, child_weight))
352
+
353
+ elif strat == neutral_prompt_parser.ConciliationStrategy.SEMANTIC_GUIDANCE:
354
+ aux_cond_delta = aux_cond_delta + child_weight * _filter_abs_top_k(child_delta, 0.05)
355
+
356
+ else:
357
+ # AND_ALIGN_D_S (soft alignment blend)
358
+ m = re.match(r'AND_ALIGN_(\d+)_(\d+)', strat.value)
359
+ if m:
360
+ align_blend_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
361
+ else:
362
+ # AND_MASK_ALIGN_D_S (binary alignment mask)
363
+ m = re.match(r'AND_MASK_ALIGN_(\d+)_(\d+)', strat.value)
364
+ if m:
365
+ mask_align_deltas.append((child_delta, child_weight, int(m.group(1)), int(m.group(2))))
366
+
367
+ index += child.accept(neutral_prompt_parser.FlatSizeVisitor())
368
+
369
+ aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_cond_deltas, child_k=20)
370
+ aux_cond_delta = aux_cond_delta + _salient_blend(cond_delta, salient_wide_deltas, child_k=1)
371
+ aux_cond_delta = aux_cond_delta + _salient_blend_blob(cond_delta, salient_blob_deltas)
372
+ aux_cond_delta = aux_cond_delta + _alignment_blend(cond_delta, align_blend_deltas)
373
+ aux_cond_delta = aux_cond_delta + _alignment_mask_blend(cond_delta, mask_align_deltas)
374
+ return aux_cond_delta
375
+
376
+
377
+ # ---------------------------------------------------------------------------
378
+ # Strategy implementations
379
+ # ---------------------------------------------------------------------------
380
+
381
+ def _get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
382
+ if (normal == 0).all():
383
+ if shared.state.sampling_step <= 0:
384
+ _warn_projection_not_found()
385
+ return vector
386
+ return vector - normal * torch.sum(normal * vector) / torch.norm(normal) ** 2
387
+
388
+
389
+ def _salient_blend(
390
+ normal: torch.Tensor,
391
+ vectors: List[Tuple[torch.Tensor, float]],
392
+ child_k: float = 20.0,
393
+ ) -> torch.Tensor:
394
+ """
395
+ Saliency-guided blend: each child prompt wins in the latent regions
396
+ where its absolute activation magnitude is strongest.
397
+
398
+ child_k controls how sharp/selective the child salience mask is:
399
+ child_k=20 → very sharp, 1-2 peak pixels (AND_SALT — life-branch style)
400
+ child_k=1 → broad, ~55% of pixels (AND_SALT_WIDE — original main-branch)
401
+ Parent always uses k=1 (diffuse reference).
402
+ """
403
+ if not vectors:
404
+ return torch.zeros_like(normal)
405
+
406
+ salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=child_k) for v, _ in vectors]
407
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
408
+
409
+ result = torch.zeros_like(normal)
410
+ for mask_i, (vector, weight) in enumerate(vectors, start=1):
411
+ vector_mask = (mask == mask_i).float()
412
+ result = result + weight * vector_mask * (vector - normal)
413
+ return result
414
+
415
+
416
+ def _salient_blend_blob(
417
+ normal: torch.Tensor,
418
+ vectors: List[Tuple[torch.Tensor, float]],
419
+ ) -> torch.Tensor:
420
+ """
421
+ AND_SALT_BLOB: life-branch dev.py algorithm (cleaned of debug scaffolding).
422
+
423
+ Pipeline per child:
424
+ 1. k=20 softmax → very sharp salience seed (1-2 pixels)
425
+ 2. _life_erode ×6 → erode to spatially dense core
426
+ 3. _life_thickify ×2 → grow outward into a smooth blob
427
+ 4. result += weight × blob_mask × (child − parent)
428
+
429
+ This is what the life-branch author was iterating toward in dev.py.
430
+ """
431
+ if not vectors:
432
+ return torch.zeros_like(normal)
433
+
434
+ salience_maps = [_get_salience(normal, k=1)] + [_get_salience(v, k=20) for v, _ in vectors]
435
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
436
+
437
+ result = torch.zeros_like(normal)
438
+ for mask_i, (vector, weight) in enumerate(vectors, start=1):
439
+ vector_mask = (mask == mask_i).float()
440
+ for _ in range(6):
441
+ vector_mask = _life_step(vector_mask, _erode_rule)
442
+ for _ in range(2):
443
+ vector_mask = _life_step(vector_mask, _thickify_rule)
444
+ result = result + weight * vector_mask * (vector - normal)
445
+ return result
446
+
447
+
448
+ def _life_step(board: torch.Tensor, rule) -> torch.Tensor:
449
+ """
450
+ One step of a cellular-automaton morphology on a C×H×W binary mask.
451
+
452
+ The conv3d trick from dev.py: concatenate [board, board[:-1]] along the
453
+ channel axis so a single 3-D convolution simultaneously sums the 3×3
454
+ spatial neighbourhood across *all* C channels. This keeps the operation
455
+ GPU-friendly and avoids an explicit loop over channels.
456
+
457
+ neighbors[c,h,w] = Σ_{dc,dh,dw} board_padded[c+dc, h+dh, w+dw]
458
+ minus the center pixel itself (board is subtracted).
459
+ """
460
+ C = board.shape[0]
461
+ kernel = torch.ones((C, 3, 3), dtype=board.dtype, device=board.device)
462
+ kernel = kernel.unsqueeze(0).unsqueeze(0) # [1, 1, C, 3, 3]
463
+
464
+ # Pad spatially (left/right/top/bottom) but not along channel axis
465
+ padded = torch.cat([board.clone(), board[:-1].clone()], dim=0) # [2C-1, H, W]
466
+ padded = torch.nn.functional.pad(padded, (1, 1, 1, 1, 0, 0), value=0) # [2C-1, H+2, W+2]
467
+
468
+ neighbors = torch.nn.functional.conv3d(
469
+ padded.unsqueeze(0).unsqueeze(0), # [1, 1, 2C-1, H+2, W+2]
470
+ kernel, # [1, 1, C, 3, 3]
471
+ padding=0,
472
+ ).squeeze(0).squeeze(0) # [C, H, W]
473
+
474
+ neighbors = neighbors - board # subtract center pixel
475
+ return rule(board, neighbors).float()
476
+
477
+
478
+ def _erode_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
479
+ """Keep a pixel only if it is set AND its C-channel neighbourhood is dense.
480
+ Threshold C*5 matches dev.py: for C=4 that means ≥20 out of 36 possible."""
481
+ C = board.shape[0]
482
+ return (board == 1) & (neighbors >= C * 5)
483
+
484
+
485
+ def _thickify_rule(board: torch.Tensor, neighbors: torch.Tensor) -> torch.Tensor:
486
+ """Grow: keep existing pixels OR add any pixel adjacent to the core.
487
+ population ≥ 4 ensures only pixels touching at least one set neighbour grow."""
488
+ population = board + neighbors
489
+ return (board == 1) | (population >= 4)
490
+
491
+
492
+ def _get_salience(vector: torch.Tensor, k: float = 1.0) -> torch.Tensor:
493
+ """Softmax-based salience map. k > 1 → sharper, more selective mask."""
494
+ return torch.softmax(k * torch.abs(vector).flatten(), dim=0).reshape_as(vector)
495
+
496
+
497
+ def _filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
498
+ """Keep only the top k_ratio fraction of activations by absolute value."""
499
+ k = int(torch.numel(vector) * (1 - k_ratio))
500
+ k = max(1, k) # kthvalue requires k >= 1
501
+ threshold, _ = torch.kthvalue(torch.abs(vector.flatten()), k)
502
+ return vector * (torch.abs(vector) >= threshold).to(vector.dtype)
503
+
504
+
505
+ # ---------------------------------------------------------------------------
506
+ # Alignment blend (alignment_blend branch)
507
+ # ---------------------------------------------------------------------------
508
+
509
+ def _compute_subregion_similarity_map(
510
+ child_vector: torch.Tensor,
511
+ parent_vector: torch.Tensor,
512
+ region_size: int = 2,
513
+ ) -> torch.Tensor:
514
+ """
515
+ Compute local average cosine similarity of (region_size × region_size) regions.
516
+ Returns a map of shape [C, H, W] with values in [-1, 1].
517
+ """
518
+ C, H, W = child_vector.shape
519
+ parent = parent_vector.unsqueeze(0) # [1, C, H, W]
520
+ child = child_vector.unsqueeze(0)
521
+
522
+ region_radius = region_size // 2
523
+ if region_size % 2 == 1:
524
+ pad_size = (region_radius,) * 4
525
+ else:
526
+ pad_size = (region_radius - 1, region_radius) * 2
527
+
528
+ parent_reg = F.unfold(F.pad(parent, pad_size, 'constant', 0), kernel_size=region_size)
529
+ child_reg = F.unfold(F.pad(child, pad_size, 'constant', 0), kernel_size=region_size)
530
+
531
+ # [H*W, C, region_size, region_size]
532
+ # .contiguous() is required before .view() because .permute() produces a
533
+ # non-contiguous tensor and .view() raises RuntimeError on non-contiguous input.
534
+ parent_reg = parent_reg.view(1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
535
+ child_reg = child_reg.view( 1, C, region_size**2, H*W).permute(3, 1, 2, 0).contiguous().view(H*W, C, region_size, region_size)
536
+
537
+ unfold2 = torch.nn.Unfold(kernel_size=2)
538
+ parent_sub = unfold2(parent_reg).view(H*W, C, 4, (region_size - 1)**2)
539
+ child_sub = unfold2(child_reg ).view(H*W, C, 4, (region_size - 1)**2)
540
+
541
+ parent_sub = F.normalize(parent_sub, p=2, dim=2)
542
+ child_sub = F.normalize(child_sub, p=2, dim=2)
543
+ sim = (parent_sub * child_sub).sum(dim=2) # [H*W, C, (r-1)^2]
544
+ return sim.mean(dim=2).permute(1, 0).contiguous().view(C, H, W) # [C, H, W]
545
+
546
+
547
+ def _alignment_blend(
548
+ parent: torch.Tensor,
549
+ children: List[Tuple[torch.Tensor, float, int, int]],
550
+ ) -> torch.Tensor:
551
+ """
552
+ Soft alignment blend (AND_ALIGN_D_S).
553
+ Child contribution is weighted by max(0, structure_alignment − detail_alignment).
554
+ High weight where child changes detail without breaking structure.
555
+ """
556
+ result = torch.zeros_like(parent)
557
+ for child, weight, detail_size, structure_size in children:
558
+ detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
559
+ structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
560
+
561
+ # Normalise by absolute max so that all-negative maps (anti-correlated prompts
562
+ # such as 'black' vs 'white') don't blow up to ±1e7 and clamp incorrectly to 1.
563
+ # Dividing by positive max would invert the sign ordering when all values are
564
+ # negative; dividing by abs-max preserves relative ordering in both cases.
565
+ d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
566
+ s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
567
+ detail_sim = detail_sim / d_abs_max
568
+ structure_sim = structure_sim / s_abs_max
569
+
570
+ alignment_weight = torch.clamp(structure_sim - detail_sim, min=0.0, max=1.0)
571
+ result = result + (child - parent) * weight * alignment_weight
572
+ return result
573
+
574
+
575
+ def _alignment_mask_blend(
576
+ parent: torch.Tensor,
577
+ children: List[Tuple[torch.Tensor, float, int, int]],
578
+ ) -> torch.Tensor:
579
+ """
580
+ Binary alignment mask (AND_MASK_ALIGN_D_S).
581
+ Child receives full weight where structure_alignment > detail_alignment, zero elsewhere.
582
+ """
583
+ result = torch.zeros_like(parent)
584
+ for child, weight, detail_size, structure_size in children:
585
+ detail_sim = _compute_subregion_similarity_map(child, parent, detail_size)
586
+ structure_sim = _compute_subregion_similarity_map(child, parent, structure_size)
587
+
588
+ d_abs_max = detail_sim.abs().max().clamp(min=1e-8)
589
+ s_abs_max = structure_sim.abs().max().clamp(min=1e-8)
590
+ detail_sim = detail_sim / d_abs_max
591
+ structure_sim = structure_sim / s_abs_max
592
+
593
+ alignment_mask = (structure_sim > detail_sim).to(child)
594
+ result = result + (child - parent) * weight * alignment_mask
595
+ return result
596
+
597
+
598
+ # ---------------------------------------------------------------------------
599
+ # Sampler hijack
600
+ # ---------------------------------------------------------------------------
601
+
602
+ sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
603
+ module=sd_samplers,
604
+ hijacker_attribute='__neutral_prompt_hijacker',
605
+ on_uninstall=script_callbacks.on_script_unloaded,
606
+ )
607
+
608
+
609
+ @sd_samplers_hijacker.hijack('create_sampler')
610
+ def create_sampler_hijack(name: str, model, original_function):
611
+ sampler = original_function(name, model)
612
+ if not hasattr(sampler, 'model_wrap_cfg') or not hasattr(sampler.model_wrap_cfg, 'combine_denoised'):
613
+ if global_state.is_enabled:
614
+ _warn_unsupported_sampler()
615
+ return sampler
616
+
617
+ sampler.model_wrap_cfg.combine_denoised = functools.partial(
618
+ combine_denoised_hijack,
619
+ original_function=sampler.model_wrap_cfg.combine_denoised,
620
+ )
621
+ return sampler
622
+
623
+
624
+ # ---------------------------------------------------------------------------
625
+ # Warnings / logging
626
+ # ---------------------------------------------------------------------------
627
+
628
+ def _warn_unsupported_sampler() -> None:
629
+ _console_warn('''
630
+ Neutral prompt relies on composition via AND, which the webui does not support
631
+ when using any of the DDIM, PLMS and UniPC samplers.
632
+ The sampler will NOT be patched – falling back on the original implementation.
633
+ ''')
634
+
635
+
636
+ def _warn_projection_not_found() -> None:
637
+ _console_warn('''
638
+ Could not find a projection for one or more AND_PERP prompts.
639
+ These prompts will NOT be made perpendicular.
640
+ ''')
641
+
642
+
643
+ def _console_warn(message: str) -> None:
644
+ if not global_state.verbose:
645
+ return
646
+ print(f'\n[sd-webui-neutral-prompt]{textwrap.dedent(message)}', file=sys.stderr)
neutral_prompt_patched/lib_neutral_prompt/external_code/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+
4
+ @contextlib.contextmanager
5
+ def fix_path():
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ extension_path = str(Path(__file__).parent.parent.parent)
10
+ added = False
11
+ if extension_path not in sys.path:
12
+ sys.path.insert(0, extension_path)
13
+ added = True
14
+
15
+ yield
16
+
17
+ if added:
18
+ sys.path.remove(extension_path)
19
+
20
+
21
+ with fix_path():
22
+ del fix_path, contextlib
23
+ from .api import *
neutral_prompt_patched/lib_neutral_prompt/external_code/api.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ External API for sd-webui-neutral-prompt.
3
+
4
+ Provides thin helpers that external extensions / scripts can import via:
5
+
6
+ from lib_neutral_prompt.external_code import override_cfg_rescale, get_last_cfg_rescale_factor
7
+ """
8
+
9
+ from typing import Optional
10
+
11
+ from lib_neutral_prompt import global_state
12
+
13
+
14
+ def override_cfg_rescale(cfg_rescale: float) -> None:
15
+ """
16
+ Override the CFG rescale value for the *next* generation step only.
17
+ After the step runs the value is automatically cleared.
18
+ """
19
+ global_state.cfg_rescale_override = cfg_rescale
20
+
21
+
22
+ def get_last_cfg_rescale_factor() -> Optional[float]:
23
+ """
24
+ Return the CFG rescale factor computed during the most recent denoising
25
+ step, or None if rescaling was not active.
26
+ """
27
+ return global_state.CFGRescaleFactorSingleton.get()
neutral_prompt_patched/lib_neutral_prompt/global_state.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Global mutable state shared across the extension.
3
+
4
+ Combines:
5
+ - is_enabled / prompt_exprs / verbose (all branches)
6
+ - cfg_rescale + cfg_rescale_override mechanism (main branch)
7
+ - batch_cond_indices (affine branch – needed by on_cfg_denoiser pre-noise hook)
8
+ - CFGRescaleFactorSingleton – exposes the last computed rescale factor
9
+ to external callers / API users (export_rescale_factor branch)
10
+ """
11
+
12
+ import threading
13
+ from typing import List, Optional, Tuple
14
+ from lib_neutral_prompt import neutral_prompt_parser
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Runtime state
19
+ # ---------------------------------------------------------------------------
20
+
21
+ is_enabled: bool = False
22
+ prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []
23
+ # Populated by reconstruct_multicond_batch hook; used by the pre-noise
24
+ # affine transform hook (on_cfg_denoiser) to know which latent slices
25
+ # correspond to which prompt child.
26
+ batch_cond_indices: List[List[Tuple[int, float]]] = []
27
+ cfg_rescale: float = 0.0
28
+ verbose: bool = False # matches UI default; UI syncs this on settings load
29
+
30
+ # Set to a float value by XYZ-grid or external API to override cfg_rescale
31
+ # for a single generation step, then auto-cleared.
32
+ cfg_rescale_override: Optional[float] = None
33
+
34
+
35
+ def apply_and_clear_cfg_rescale_override() -> None:
36
+ """Apply a one-shot cfg_rescale override and immediately clear it."""
37
+ global cfg_rescale, cfg_rescale_override
38
+ if cfg_rescale_override is not None:
39
+ cfg_rescale = cfg_rescale_override
40
+ cfg_rescale_override = None
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # CFG Rescale Factor Singleton (export_rescale_factor branch)
45
+ #
46
+ # Stores the last computed rescale factor so that external tools / scripts
47
+ # can read it without re-deriving it.
48
+ #
49
+ # Uses threading.local() so concurrent API workers each get their own slot.
50
+ # clear() must be called at the start of every _cfg_rescale() call so that
51
+ # get() correctly returns None when rescaling was skipped this step.
52
+ # ---------------------------------------------------------------------------
53
+
54
+ class CFGRescaleFactorSingleton:
55
+ """Thread-local store for the most recently computed CFG rescale factor.
56
+
57
+ Lifecycle per denoising step:
58
+ 1. _cfg_rescale() calls clear() at entry — value becomes None.
59
+ 2. If rescaling is active, set() stores the computed factor.
60
+ 3. External code calls get() after the step; receives the factor or None.
61
+ """
62
+
63
+ _state = threading.local()
64
+
65
+ @classmethod
66
+ def set(cls, value: float) -> None:
67
+ cls._state.cfg_rescale_factor = value
68
+
69
+ @classmethod
70
+ def get(cls) -> Optional[float]:
71
+ return getattr(cls._state, 'cfg_rescale_factor', None)
72
+
73
+ @classmethod
74
+ def clear(cls) -> None:
75
+ cls._state.cfg_rescale_factor = None
neutral_prompt_patched/lib_neutral_prompt/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+
4
+ class ModuleHijacker:
5
+ def __init__(self, module):
6
+ self.__module = module
7
+ self.__original_functions = dict()
8
+
9
+ def hijack(self, attribute):
10
+ if attribute not in self.__original_functions:
11
+ self.__original_functions[attribute] = getattr(self.__module, attribute)
12
+
13
+ def decorator(function):
14
+ setattr(self.__module, attribute, functools.partial(function, original_function=self.__original_functions[attribute]))
15
+ return function
16
+
17
+ return decorator
18
+
19
+ def reset_module(self):
20
+ for attribute, original_function in self.__original_functions.items():
21
+ setattr(self.__module, attribute, original_function)
22
+
23
+ self.__original_functions.clear()
24
+
25
+ @staticmethod
26
+ def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
27
+ if not hasattr(module, hijacker_attribute):
28
+ module_hijacker = ModuleHijacker(module)
29
+ setattr(module, hijacker_attribute, module_hijacker)
30
+ on_uninstall(lambda: delattr(module, hijacker_attribute))
31
+ on_uninstall(module_hijacker.reset_module)
32
+ return module_hijacker
33
+ else:
34
+ return getattr(module, hijacker_attribute)
neutral_prompt_patched/lib_neutral_prompt/neutral_prompt_parser.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified Neutral Prompt Parser
3
+ Combines:
4
+ - Core AND / AND_PERP / AND_SALT / AND_TOPK strategies (main branch)
5
+ - Affine spatial transforms: ROTATE / SLIDE / SCALE / SHEAR (affine branch)
6
+ - Local alignment blend: AND_ALIGN_D_S / AND_MASK_ALIGN_D_S (alignment_blend branch)
7
+ """
8
+
9
+ import abc
10
+ import dataclasses
11
+ import math
12
+ import re
13
+ from enum import Enum
14
+ from itertools import product
15
+ from typing import Any, List, Optional, Tuple
16
+
17
+ import torch
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Keyword registry
22
+ # ---------------------------------------------------------------------------
23
+
24
+ # Base conciliation keywords
25
+ _BASE_KEYWORD_MAP = {
26
+ 'AND_PERP': 'PERPENDICULAR',
27
+ 'AND_SALT': 'SALIENCE_MASK', # k=20 sharp mask (life-branch style)
28
+ 'AND_SALT_WIDE': 'SALIENCE_MASK_WIDE', # k=1 broad mask (original main-branch behaviour)
29
+ 'AND_SALT_BLOB': 'SALIENCE_MASK_BLOB', # k=20 + erode + thickify → smooth blob (dev.py intent)
30
+ 'AND_TOPK': 'SEMANTIC_GUIDANCE',
31
+ }
32
+
33
+ # Alignment blend: AND_ALIGN_D_S (soft weight, structure > detail)
34
+ _ALIGN_KEYWORD_MAP = {
35
+ f'AND_ALIGN_{i}_{j}': f'ALIGNMENT_BLEND_{i}_{j}'
36
+ for i, j in product(range(2, 33), repeat=2) if i != j
37
+ }
38
+
39
+ # Alignment mask: AND_MASK_ALIGN_D_S (binary mask, structure > detail)
40
+ _MASK_ALIGN_KEYWORD_MAP = {
41
+ f'AND_MASK_ALIGN_{i}_{j}': f'ALIGNMENT_MASK_BLEND_{i}_{j}'
42
+ for i, j in product(range(2, 33), repeat=2) if i != j
43
+ }
44
+
45
+ keyword_mapping = _BASE_KEYWORD_MAP | _ALIGN_KEYWORD_MAP | _MASK_ALIGN_KEYWORD_MAP
46
+
47
+ PromptKeyword = Enum('PromptKeyword', {'AND': 'AND', **{k: k for k in keyword_mapping}})
48
+ ConciliationStrategy = Enum('ConciliationStrategy', {v: k for k, v in keyword_mapping.items()})
49
+
50
+ # Lists kept for ordered iteration (tokenize uses list order for regex priority)
51
+ prompt_keywords = [e.value for e in PromptKeyword]
52
+ conciliation_strategies = [e.value for e in ConciliationStrategy]
53
+
54
+ # Sets for O(1) membership checks in parser hot-loops (1864-item list = 230x slower)
55
+ _prompt_keywords_set = frozenset(prompt_keywords)
56
+ _conciliation_strategies_set = frozenset(conciliation_strategies)
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Affine transform definitions (from affine branch)
61
+ # ---------------------------------------------------------------------------
62
+
63
+ affine_transforms = {
64
+ 'ROTATE': lambda t, angle=0, *_: t @ torch.tensor([
65
+ [math.cos(angle * 2 * math.pi), -math.sin(angle * 2 * math.pi), 0],
66
+ [math.sin(angle * 2 * math.pi), math.cos(angle * 2 * math.pi), 0],
67
+ [0, 0, 1],
68
+ ], dtype=torch.float32),
69
+ 'SLIDE': lambda t, x=0, y=0, *_: t @ torch.tensor([
70
+ [1, 0, float(x)],
71
+ [0, 1, float(y)],
72
+ [0, 0, 1],
73
+ ], dtype=torch.float32),
74
+ 'SCALE': lambda t, x=1, y=None, *_: t @ torch.tensor([
75
+ [float(x), 0, 0],
76
+ [0, float(y) if y is not None else float(x), 0],
77
+ [0, 0, 1],
78
+ ], dtype=torch.float32),
79
+ 'SHEAR': lambda t, x=0, y=None, *_: t @ torch.tensor([
80
+ [1, math.tan(float(x) * 2 * math.pi), 0],
81
+ [math.tan(float(y if y is not None else x) * 2 * math.pi), 1, 0],
82
+ [0, 0, 1],
83
+ ], dtype=torch.float32),
84
+ }
85
+
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # AST node types
89
+ # ---------------------------------------------------------------------------
90
+
91
+ @dataclasses.dataclass
92
+ class PromptExpr(abc.ABC):
93
+ weight: float
94
+ conciliation: Optional[ConciliationStrategy]
95
+ local_transform: Optional[torch.Tensor] # 2×3 affine tensor or None
96
+
97
+ @abc.abstractmethod
98
+ def accept(self, visitor, *args, **kwargs) -> Any:
99
+ pass
100
+
101
+
102
+ @dataclasses.dataclass
103
+ class LeafPrompt(PromptExpr):
104
+ prompt: str
105
+
106
+ def accept(self, visitor, *args, **kwargs):
107
+ return visitor.visit_leaf_prompt(self, *args, **kwargs)
108
+
109
+
110
+ @dataclasses.dataclass
111
+ class CompositePrompt(PromptExpr):
112
+ children: List[PromptExpr]
113
+
114
+ def accept(self, visitor, *args, **kwargs):
115
+ return visitor.visit_composite_prompt(self, *args, **kwargs)
116
+
117
+
118
+ class FlatSizeVisitor:
119
+ def visit_leaf_prompt(self, that: LeafPrompt) -> int:
120
+ return 1
121
+
122
+ def visit_composite_prompt(self, that: CompositePrompt) -> int:
123
+ return sum(child.accept(self) for child in that.children) if that.children else 0
124
+
125
+
126
+ # ---------------------------------------------------------------------------
127
+ # Parser
128
+ # ---------------------------------------------------------------------------
129
+
130
+ def parse_root(string: str) -> CompositePrompt:
131
+ tokens = tokenize(string)
132
+ prompts = parse_prompts(tokens)
133
+ return CompositePrompt(1.0, None, None, prompts)
134
+
135
+
136
+ def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
137
+ prompts = [parse_prompt(tokens, first=True, nested=nested)]
138
+ while tokens:
139
+ if nested and tokens[0] == ']':
140
+ break
141
+ prompts.append(parse_prompt(tokens, first=False, nested=nested))
142
+ return prompts
143
+
144
+
145
+ def _compose_affine(
146
+ a: Optional[torch.Tensor],
147
+ b: Optional[torch.Tensor],
148
+ ) -> Optional[torch.Tensor]:
149
+ """
150
+ Compose two 2×3 affine matrices (a applied first, then b).
151
+ Returns None if both are None; returns the non-None one if only one exists.
152
+ """
153
+ if a is None:
154
+ return b
155
+ if b is None:
156
+ return a
157
+ # Lift to 3×3, multiply, trim back to 2×3
158
+ a3 = torch.vstack([a, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
159
+ b3 = torch.vstack([b, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32)])
160
+ return (b3 @ a3)[:-1]
161
+
162
+
163
+ def _looks_like_leading_affine_prompt(tokens: List[str]) -> bool:
164
+ """
165
+ Return True if `tokens` starts with one or more affine keywords followed
166
+ immediately by a prompt keyword, e.g.:
167
+ ROTATE[0.125] AND_PERP ...
168
+ SLIDE[0.05,0] AND_SALT ...
169
+ ROTATE[0.125] SLIDE[0.05,0] AND_PERP ...
170
+
171
+ This is a pure structural scan — it does NOT compute any affine matrices,
172
+ so it is safe to call as a look-ahead probe without mock-torch issues.
173
+
174
+ Used in two places:
175
+ - parse_prompt() → consume leading affine before keyword
176
+ - _parse_prompt_text() → stop current segment before this boundary
177
+ """
178
+ pos = 0
179
+ n = len(tokens)
180
+ # skip a leading whitespace-only token
181
+ if pos < n and not tokens[pos].strip():
182
+ pos += 1
183
+ # must start with an affine keyword
184
+ if pos >= n or tokens[pos] not in affine_transforms:
185
+ return False
186
+ # consume one or more KEYWORD [ args ] blocks
187
+ while pos < n and tokens[pos] in affine_transforms:
188
+ pos += 1 # skip keyword
189
+ if pos >= n or tokens[pos] != '[':
190
+ return False
191
+ pos += 1 # skip '['
192
+ if pos < n and tokens[pos] != ']':
193
+ pos += 1 # skip args (may contain commas)
194
+ if pos >= n or tokens[pos] != ']':
195
+ return False
196
+ pos += 1 # skip ']'
197
+ if pos < n and not tokens[pos].strip():
198
+ pos += 1 # skip optional whitespace
199
+ # the very next token must be a prompt keyword
200
+ return pos < n and tokens[pos] in _prompt_keywords_set
201
+
202
+
203
+ def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
204
+ # Support two affine-keyword orderings:
205
+ # (A) AND_PERP ROTATE[0.25] text ← trailing affine (original syntax)
206
+ # (B) ROTATE[0.25] AND_PERP text ← leading affine (documented syntax)
207
+ #
208
+ # Note: we check _looks_like_leading_affine_prompt unconditionally (not
209
+ # gated on `not first`) so that a prompt starting with e.g.
210
+ # "ROTATE[0.125] AND_PERP text" is handled correctly even as the first
211
+ # segment.
212
+ leading_affine: Optional[torch.Tensor] = None
213
+ if _looks_like_leading_affine_prompt(tokens):
214
+ probe = tokens.copy()
215
+ leading_affine = _parse_affine_transform(probe)
216
+ tokens[:] = probe
217
+
218
+ # After consuming a leading affine the next token is a keyword; accept it.
219
+ # We also accept a keyword as the very first token (first=True) without a
220
+ # leading affine — e.g. "AND_PERP vivid" at the start of a string should
221
+ # produce a single PERP child, not an empty default child followed by a PERP
222
+ # child. The original `not first` guard was an artefact of earlier grammar;
223
+ # all existing tests pass without it.
224
+ if tokens and tokens[0] in _prompt_keywords_set:
225
+ prompt_type = tokens.pop(0)
226
+ else:
227
+ prompt_type = PromptKeyword.AND.value
228
+
229
+ conciliation = ConciliationStrategy(prompt_type) if prompt_type in _conciliation_strategies_set else None
230
+
231
+ # Parse optional trailing affine transform chain (original syntax)
232
+ trailing_affine = _parse_affine_transform(tokens)
233
+ affine_transform = _compose_affine(leading_affine, trailing_affine)
234
+
235
+ # Try composite (bracketed) form
236
+ tokens_copy = tokens.copy()
237
+ if tokens_copy and tokens_copy[0] == '[':
238
+ tokens_copy.pop(0)
239
+ prompts = parse_prompts(tokens_copy, nested=True)
240
+ if tokens_copy:
241
+ assert tokens_copy.pop(0) == ']'
242
+ if len(prompts) > 1:
243
+ tokens[:] = tokens_copy
244
+ weight = _parse_weight(tokens)
245
+ return CompositePrompt(weight, conciliation, affine_transform, prompts)
246
+
247
+ prompt_text, weight = _parse_prompt_text(tokens, nested=nested)
248
+ return LeafPrompt(weight, conciliation, affine_transform, prompt_text)
249
+
250
+
251
+ def _parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
252
+ text = ''
253
+ depth = 0
254
+ weight = 1.0
255
+ while tokens:
256
+ tok = tokens[0]
257
+ if tok == ']':
258
+ if depth == 0:
259
+ if nested:
260
+ break
261
+ else:
262
+ depth -= 1
263
+ elif tok == '[':
264
+ depth += 1
265
+ elif tok == ':':
266
+ if len(tokens) >= 2 and _is_float(tokens[1].strip()):
267
+ if len(tokens) < 3 or tokens[2] in _prompt_keywords_set or (tokens[2] == ']' and depth == 0):
268
+ tokens.pop(0)
269
+ weight = float(tokens.pop(0).strip())
270
+ break
271
+ elif tok in _prompt_keywords_set:
272
+ break
273
+ # Stop before a leading-affine-then-keyword boundary, e.g.
274
+ # ROTATE[0.125] AND_PERP text
275
+ # so the affine is consumed by the *next* parse_prompt call, not
276
+ # swallowed as raw text by the current segment.
277
+ elif depth == 0 and _looks_like_leading_affine_prompt(tokens):
278
+ break
279
+ text += tokens.pop(0)
280
+ return text, weight
281
+
282
+
283
+ def _parse_affine_transform(tokens: List[str]) -> Optional[torch.Tensor]:
284
+ """
285
+ Consume optional affine keywords like ROTATE[0.25] SLIDE[0.1,0] from the token stream.
286
+ Returns a 2×3 float32 tensor, or None if no affine tokens were found.
287
+ """
288
+ tokens_copy = tokens.copy()
289
+ if tokens_copy and not tokens_copy[0].strip():
290
+ tokens_copy.pop(0)
291
+
292
+ affine_funcs = []
293
+
294
+ while tokens_copy and tokens_copy[0] in affine_transforms:
295
+ func = affine_transforms[tokens_copy.pop(0)]
296
+ args: List[float] = []
297
+
298
+ if not (tokens_copy and tokens_copy[0] == '['):
299
+ break
300
+ tokens_copy.pop(0) # consume '['
301
+
302
+ if tokens_copy and tokens_copy[0] != ']':
303
+ if tokens_copy[0].strip():
304
+ try:
305
+ args = [float(a.strip()) for a in tokens_copy.pop(0).split(',')]
306
+ except ValueError:
307
+ break
308
+ else:
309
+ tokens_copy.pop(0)
310
+
311
+ if not (tokens_copy and tokens_copy[0] == ']'):
312
+ break
313
+ tokens_copy.pop(0) # consume ']'
314
+
315
+ affine_funcs.append(lambda t, f=func, a=args: f(t, *a))
316
+ if tokens_copy and not tokens_copy[0].strip():
317
+ tokens_copy.pop(0)
318
+ tokens[:] = tokens_copy
319
+
320
+ if not affine_funcs:
321
+ return None
322
+
323
+ transform = torch.eye(3, dtype=torch.float32)[:-1] # 2×3
324
+ for fn in reversed(affine_funcs):
325
+ transform = fn(transform)
326
+ return transform
327
+
328
+
329
+ def _parse_weight(tokens: List[str]) -> float:
330
+ if len(tokens) >= 2 and tokens[0] == ':' and _is_float(tokens[1]):
331
+ tokens.pop(0)
332
+ return float(tokens.pop(0))
333
+ return 1.0
334
+
335
+
336
+ def tokenize(s: str) -> List[str]:
337
+ # Use general patterns for AND_ALIGN / AND_MASK_ALIGN families instead of
338
+ # enumerating all 1800+ combinations - that regex is 41k chars and 500x slower.
339
+ # Order matters: longer/more-specific patterns must come first.
340
+ # Use (?<!\w) / (?!\w) instead of \b because _ counts as \w in Python's \b,
341
+ # which would cause 'AND' to match inside 'XAND' (no boundary between \w chars).
342
+ affine_kw_pattern = '|'.join(
343
+ rf'(?<!\w){kw}(?!\w)' for kw in affine_transforms
344
+ )
345
+ keyword_pattern = (
346
+ r'(?<!\w)AND_MASK_ALIGN_\d+_\d+(?!\w)' # must come before AND_ALIGN and AND
347
+ r'|(?<!\w)AND_ALIGN_\d+_\d+(?!\w)'
348
+ r'|(?<!\w)AND_PERP(?!\w)'
349
+ r'|(?<!\w)AND_SALT_WIDE(?!\w)' # must come before AND_SALT
350
+ r'|(?<!\w)AND_SALT_BLOB(?!\w)' # must come before AND_SALT
351
+ r'|(?<!\w)AND_SALT(?!\w)'
352
+ r'|(?<!\w)AND_TOPK(?!\w)'
353
+ r'|(?<!\w)AND(?!\w)'
354
+ )
355
+ return [t for t in re.split(rf'(\[|\]|:|{keyword_pattern}|{affine_kw_pattern})', s) if t.strip()]
356
+
357
+
358
+ def _is_float(s: str) -> bool:
359
+ try:
360
+ float(s)
361
+ return True
362
+ except ValueError:
363
+ return False
364
+
365
+
366
+ # ---------------------------------------------------------------------------
367
+ # Quick smoke-test
368
+ # ---------------------------------------------------------------------------
369
+ if __name__ == '__main__':
370
+ res = parse_root('''
371
+ hello
372
+ AND_PERP [
373
+ arst
374
+ AND defg : 2
375
+ AND_SALT [
376
+ very nested huh? what do you say :.0
377
+ ]
378
+ ]
379
+ AND_ALIGN_4_8 watercolor style :0.5
380
+ ROTATE[0.125] AND_PERP vibrant colors
381
+ ''')
382
+ print(res)
neutral_prompt_patched/lib_neutral_prompt/prompt_parser_hijack.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+
4
+ from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
5
+ from modules import script_callbacks, prompt_parser
6
+
7
+ # ---------------------------------------------------------------------------
8
+ # Fix: prompt_parser_fixed escapes standalone '&' as a multicond separator.
9
+ # neutral_prompt_parser does NOT treat '&' as AND — it keeps it as literal text.
10
+ # So a leaf like "cat & dog" transpiles to "cat & dog :1.0",
11
+ # and the patched prompt_parser splits that into 3 multicond branches instead of 1,
12
+ # which shifts all batch_cond_indices and breaks PERP/SALT/affine.
13
+ #
14
+ # Solution: escape any standalone '&' inside leaf text with '\&' before handing
15
+ # the string to the webui parser. The patched prompt_parser correctly unescapes
16
+ # '\&' back to '&' during conditioning, so the model sees the original text.
17
+ #
18
+ # "Standalone &" = surrounded by whitespace (or at start/end of string).
19
+ # Examples:
20
+ # "cat & dog" -> "cat \& dog" <- was splitting, now safe
21
+ # "R&D" -> "R&D" <- was NOT splitting, unchanged
22
+ # "\&" -> "\&" <- already escaped, not double-escaped
23
+ # ---------------------------------------------------------------------------
24
+ _STANDALONE_AMP = re.compile(r'(?<!\\)(?<!\S)&(?!\S)')
25
+
26
+
27
+ def _escape_leaf_ampersands(text: str) -> str:
28
+ """Escape standalone '&' in a leaf prompt so patched prompt_parser
29
+ does not treat it as a multicond AND separator."""
30
+ if not text or '&' not in text:
31
+ return text
32
+ return _STANDALONE_AMP.sub(r'\\&', text)
33
+
34
+
35
+ prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
36
+ module=prompt_parser,
37
+ hijacker_attribute='__neutral_prompt_hijacker',
38
+ on_uninstall=script_callbacks.on_script_unloaded,
39
+ )
40
+
41
+
42
+ @prompt_parser_hijacker.hijack('get_multicond_prompt_list')
43
+ def get_multicond_prompt_list_hijack(prompts, original_function):
44
+ if not global_state.is_enabled:
45
+ return original_function(prompts)
46
+
47
+ global_state.prompt_exprs = parse_prompts(prompts)
48
+ webui_prompts = transpile_exprs(global_state.prompt_exprs)
49
+ if isinstance(prompts, getattr(prompt_parser, 'SdConditioning', type(None))):
50
+ webui_prompts = prompt_parser.SdConditioning(webui_prompts, copy_from=prompts)
51
+
52
+ return original_function(webui_prompts)
53
+
54
+
55
+ def parse_prompts(prompts: List[str]) -> List[neutral_prompt_parser.PromptExpr]:
56
+ exprs = []
57
+ for prompt in prompts:
58
+ expr = neutral_prompt_parser.parse_root(prompt)
59
+ exprs.append(expr)
60
+ return exprs
61
+
62
+
63
+ def transpile_exprs(exprs: neutral_prompt_parser.PromptExpr):
64
+ webui_prompts = []
65
+ for expr in exprs:
66
+ webui_prompts.append(expr.accept(WebuiPromptVisitor()))
67
+ return webui_prompts
68
+
69
+
70
+ class WebuiPromptVisitor:
71
+ def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
72
+ # Escape standalone '&' so patched prompt_parser doesn't split the leaf
73
+ # into extra multicond branches. '\&' is correctly unescaped downstream.
74
+ prompt = _escape_leaf_ampersands(that.prompt)
75
+ return f'{prompt} :{that.weight}'
76
+
77
+ def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
78
+ return ' AND '.join(child.accept(self) for child in that.children)
79
+
80
+
81
+ @prompt_parser_hijacker.hijack('reconstruct_multicond_batch')
82
+ def reconstruct_multicond_batch_hijack(*args, original_function, **kwargs):
83
+ """Store batch_cond_indices for the pre-noise affine hook (affine branch)."""
84
+ res = original_function(*args, **kwargs)
85
+ global_state.batch_cond_indices = res[0]
86
+ return res
neutral_prompt_patched/lib_neutral_prompt/ui.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI for sd-webui-neutral-prompt (unified).
3
+
4
+ Combines:
5
+ - Core prompt types + tooltip (main branch)
6
+ - AND_ALIGN / AND_MASK_ALIGN entries (alignment_blend branch)
7
+ - elem_id on dropdown (main branch)
8
+ - CFG Rescale infotext / paste support (all branches)
9
+ - Affine-transform hint in tooltip (affine branch)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from itertools import product
15
+ from typing import Callable, Dict, List, Tuple
16
+
17
+ import dataclasses
18
+ import gradio as gr
19
+
20
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
21
+ from modules import script_callbacks, shared
22
+
23
+
24
+ txt2img_prompt_textbox = None
25
+ img2img_prompt_textbox = None
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Prompt-type registry shown in the UI dropdown
30
+ # ---------------------------------------------------------------------------
31
+
32
+ prompt_types: Dict[str, str] = {
33
+ 'Perpendicular (AND_PERP)': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
34
+ 'Saliency sharp (AND_SALT)': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
35
+ 'Saliency blob (AND_SALT_BLOB)': neutral_prompt_parser.PromptKeyword.AND_SALT_BLOB.value,
36
+ 'Saliency wide / classic (AND_SALT_WIDE)':neutral_prompt_parser.PromptKeyword.AND_SALT_WIDE.value,
37
+ 'Semantic guidance top-k (AND_TOPK)': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
38
+ }
39
+
40
+ # Add AND_ALIGN_D_S entries for commonly-used kernel size pairs
41
+ for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
42
+ _key = f'Alignment blend detail={_d} structure={_s} (AND_ALIGN_{_d}_{_s})'
43
+ prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_ALIGN_{_d}_{_s}').value
44
+
45
+ # Add AND_MASK_ALIGN_D_S entries
46
+ for _d, _s in ((2, 4), (2, 8), (4, 8), (4, 16), (8, 16), (8, 32)):
47
+ _key = f'Alignment mask detail={_d} structure={_s} (AND_MASK_ALIGN_{_d}_{_s})'
48
+ prompt_types[_key] = getattr(neutral_prompt_parser.PromptKeyword, f'AND_MASK_ALIGN_{_d}_{_s}').value
49
+
50
+ prompt_types_tooltip = '\n'.join([
51
+ 'AND – add all prompt features equally (webui built-in)',
52
+ 'AND_PERP – reduce contradicting prompt features via perpendicular projection',
53
+ 'AND_SALT – sharp saliency mask: child wins only at its 1-2 peak pixels (surgical)',
54
+ 'AND_SALT_BLOB – blob saliency mask: peak pixels eroded to core, then grown to smooth blob',
55
+ 'AND_SALT_WIDE – broad saliency mask: child wins ~55% of pixels (classic main-branch)',
56
+ 'AND_TOPK – small targeted changes (semantic guidance top-k)',
57
+ 'AND_ALIGN_D_S – soft blend: child adds details without breaking structure',
58
+ 'AND_MASK_ALIGN_D_S– binary mask blend: stricter version of AND_ALIGN',
59
+ '',
60
+ 'Affine transforms (supported in either order for auxiliary segments):',
61
+ ' ROTATE[angle] SLIDE[x,y] SCALE[x,y] SHEAR[x,y]',
62
+ ' e.g.: ROTATE[0.125] AND_PERP vibrant colors :0.8',
63
+ ' e.g.: AND_PERP ROTATE[0.125] vibrant colors :0.8',
64
+ ])
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # UI component class
69
+ # ---------------------------------------------------------------------------
70
+
71
+ @dataclasses.dataclass
72
+ class AccordionInterface:
73
+ get_elem_id: Callable[[str], str]
74
+
75
+ def __post_init__(self):
76
+ self.is_rendered = False
77
+
78
+ self.cfg_rescale = gr.Slider(
79
+ label='CFG rescale φ',
80
+ minimum=0, maximum=1, value=0,
81
+ info='0 = disabled. Rescales the CFG output towards the predicted x0 to reduce over-saturation.',
82
+ )
83
+ self.neutral_prompt = gr.Textbox(
84
+ label='Neutral prompt',
85
+ show_label=False,
86
+ lines=3,
87
+ placeholder=(
88
+ 'Neutral / auxiliary prompt '
89
+ '(press "Apply to prompt" to append with the chosen strategy keyword)'
90
+ ),
91
+ )
92
+ self.neutral_cond_scale = gr.Slider(
93
+ label='Prompt weight',
94
+ minimum=-3, maximum=3, value=1,
95
+ )
96
+ self.aux_prompt_type = gr.Dropdown(
97
+ label='Prompt type',
98
+ choices=list(prompt_types.keys()),
99
+ value=next(iter(prompt_types.keys())),
100
+ tooltip=prompt_types_tooltip,
101
+ elem_id=self.get_elem_id('formatter_prompt_type'),
102
+ )
103
+ self.append_to_prompt_button = gr.Button(value='Apply to prompt')
104
+
105
+ # ------------------------------------------------------------------
106
+
107
+ def arrange_components(self, is_img2img: bool) -> None:
108
+ if self.is_rendered:
109
+ return
110
+
111
+ with gr.Accordion(label='Neutral Prompt', open=False):
112
+ self.cfg_rescale.render()
113
+ with gr.Accordion(label='Prompt formatter', open=False):
114
+ self.neutral_prompt.render()
115
+ self.neutral_cond_scale.render()
116
+ self.aux_prompt_type.render()
117
+ self.append_to_prompt_button.render()
118
+
119
+ def connect_events(self, is_img2img: bool) -> None:
120
+ if self.is_rendered:
121
+ return
122
+
123
+ prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
124
+ self.append_to_prompt_button.click(
125
+ fn=lambda init_prompt, prompt, scale, prompt_type: (
126
+ f'{init_prompt}\n{prompt_types[prompt_type]} {prompt} :{scale}',
127
+ '',
128
+ ),
129
+ inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale, self.aux_prompt_type],
130
+ outputs=[prompt_textbox, self.neutral_prompt],
131
+ )
132
+
133
+ def set_rendered(self, value: bool = True) -> None:
134
+ self.is_rendered = value
135
+
136
+ # ------------------------------------------------------------------
137
+ # Script interface (args / infotext / paste)
138
+
139
+ def get_components(self) -> Tuple[gr.components.Component, ...]:
140
+ return (self.cfg_rescale,)
141
+
142
+ def get_infotext_fields(self) -> Tuple[Tuple[gr.components.Component, str], ...]:
143
+ return tuple(zip(self.get_components(), ('CFG Rescale phi',)))
144
+
145
+ def get_paste_field_names(self) -> List[str]:
146
+ return ['CFG Rescale phi']
147
+
148
+ def get_extra_generation_params(self, args: Dict) -> Dict:
149
+ return {'CFG Rescale phi': args['cfg_rescale']}
150
+
151
+ def unpack_processing_args(self, cfg_rescale: float) -> Dict:
152
+ return {'cfg_rescale': cfg_rescale}
153
+
154
+
155
+ # ---------------------------------------------------------------------------
156
+ # Settings & callbacks
157
+ # ---------------------------------------------------------------------------
158
+
159
+ def on_ui_settings() -> None:
160
+ section = ('neutral_prompt', 'Neutral Prompt')
161
+ shared.opts.add_option(
162
+ 'neutral_prompt_enabled',
163
+ shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section),
164
+ )
165
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
166
+ shared.opts.onchange('neutral_prompt_enabled', _update_enabled)
167
+
168
+ shared.opts.add_option(
169
+ 'neutral_prompt_verbose',
170
+ shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section),
171
+ )
172
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
173
+ shared.opts.onchange('neutral_prompt_verbose', _update_verbose)
174
+
175
+
176
+ script_callbacks.on_ui_settings(on_ui_settings)
177
+
178
+
179
+ def _update_enabled() -> None:
180
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
181
+
182
+
183
+ def _update_verbose() -> None:
184
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
185
+
186
+
187
+ def on_after_component(component, **_kwargs) -> None:
188
+ global txt2img_prompt_textbox, img2img_prompt_textbox
189
+ eid = getattr(component, 'elem_id', None)
190
+ if eid == 'txt2img_prompt':
191
+ txt2img_prompt_textbox = component
192
+ elif eid == 'img2img_prompt':
193
+ img2img_prompt_textbox = component
194
+
195
+
196
+ script_callbacks.on_after_component(on_after_component)
neutral_prompt_patched/lib_neutral_prompt/xyz_grid.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from types import ModuleType
3
+ from typing import Optional
4
+ from modules import scripts
5
+ from lib_neutral_prompt import global_state
6
+
7
+
8
+ def patch():
9
+ xyz_module = find_xyz_module()
10
+ if xyz_module is None:
11
+ print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
12
+ return
13
+
14
+ xyz_module.axis_options.extend([
15
+ xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
16
+ ])
17
+
18
+
19
+ class XyzFloat(float):
20
+ is_xyz: bool = True
21
+
22
+
23
+ def apply_cfg_rescale():
24
+ def callback(_p, v, _vs):
25
+ global_state.cfg_rescale = XyzFloat(v)
26
+
27
+ return callback
28
+
29
+
30
+ def int_or_float(string):
31
+ try:
32
+ return int(string)
33
+ except ValueError:
34
+ return float(string)
35
+
36
+
37
+ def find_xyz_module() -> Optional[ModuleType]:
38
+ for data in scripts.scripts_data:
39
+ if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
40
+ return data.module
41
+
42
+ return None
neutral_prompt_patched/scripts/neutral_prompt.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
2
+ from modules import scripts, processing, shared
3
+ from typing import Dict
4
+ import functools
5
+
6
+
7
+ class NeutralPromptScript(scripts.Script):
8
+ def __init__(self):
9
+ self.accordion_interface = None
10
+ self._is_img2img = False
11
+
12
+ @property
13
+ def is_img2img(self):
14
+ return self._is_img2img
15
+
16
+ @is_img2img.setter
17
+ def is_img2img(self, is_img2img):
18
+ self._is_img2img = is_img2img
19
+ if self.accordion_interface is None:
20
+ self.accordion_interface = ui.AccordionInterface(self.elem_id)
21
+
22
+ def title(self) -> str:
23
+ return "Neutral Prompt"
24
+
25
+ def show(self, is_img2img: bool):
26
+ return scripts.AlwaysVisible
27
+
28
+ def ui(self, is_img2img: bool):
29
+ self.hijack_composable_lora(is_img2img)
30
+
31
+ self.accordion_interface.arrange_components(is_img2img)
32
+ self.accordion_interface.connect_events(is_img2img)
33
+ self.infotext_fields = self.accordion_interface.get_infotext_fields()
34
+ self.paste_field_names = self.accordion_interface.get_paste_field_names()
35
+ self.accordion_interface.set_rendered()
36
+ return self.accordion_interface.get_components()
37
+
38
+ def process(self, p: processing.StableDiffusionProcessing, *args):
39
+ args = self.accordion_interface.unpack_processing_args(*args)
40
+
41
+ self.update_global_state(args)
42
+ if global_state.is_enabled:
43
+ p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
44
+
45
+ def update_global_state(self, args: Dict):
46
+ if shared.state.job_no == 0:
47
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
48
+
49
+ for k, v in args.items():
50
+ try:
51
+ getattr(global_state, k)
52
+ except AttributeError:
53
+ continue
54
+
55
+ if getattr(getattr(global_state, k), 'is_xyz', False):
56
+ xyz_attr = getattr(global_state, k)
57
+ xyz_attr.is_xyz = False
58
+ args[k] = xyz_attr
59
+ continue
60
+
61
+ if shared.state.job_no > 0:
62
+ continue
63
+
64
+ setattr(global_state, k, v)
65
+
66
+ def hijack_composable_lora(self, is_img2img):
67
+ if self.accordion_interface.is_rendered:
68
+ return
69
+
70
+ lora_script = None
71
+ script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
72
+
73
+ for script in script_runner.alwayson_scripts:
74
+ if script.title().lower() == "composable lora":
75
+ lora_script = script
76
+ break
77
+
78
+ if lora_script is not None:
79
+ lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
80
+
81
+
82
+ def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
83
+ if not global_state.is_enabled:
84
+ return original_function(p, *args, **kwargs)
85
+
86
+ exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
87
+ all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs)
88
+ res = original_function(p, *args, **kwargs)
89
+ # restore original prompts
90
+ p.all_prompts = all_prompts
91
+ return res
92
+
93
+
94
+ xyz_grid.patch()
neutral_prompt_patched/test/perp_parser/__init__.py ADDED
File without changes
neutral_prompt_patched/test/perp_parser/test_affine_keyword_order.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for affine-keyword ordering in parse_prompt.
3
+
4
+ Both orderings must work correctly:
5
+ (A) AND_PERP ROTATE[0.125] vivid colors ← trailing affine (original)
6
+ (B) ROTATE[0.125] AND_PERP vivid colors ← leading affine (documented)
7
+ """
8
+ import unittest
9
+
10
+ from lib_neutral_prompt import neutral_prompt_parser as p
11
+ from lib_neutral_prompt.neutral_prompt_parser import ConciliationStrategy
12
+
13
+
14
+ class TestAffineKeywordOrder(unittest.TestCase):
15
+
16
+ # ------------------------------------------------------------------ #
17
+ # trailing affine (AND_PERP ROTATE[...] text) #
18
+ # ------------------------------------------------------------------ #
19
+
20
+ def test_trailing_affine_basic(self):
21
+ root = p.parse_root("base AND_PERP ROTATE[0.125] vivid colors :0.8")
22
+ self.assertEqual(len(root.children), 2)
23
+ self.assertEqual(root.children[0].prompt.strip(), "base")
24
+ self.assertIsNone(root.children[0].local_transform)
25
+ self.assertIsNotNone(root.children[1].local_transform)
26
+ self.assertEqual(root.children[1].conciliation, ConciliationStrategy.PERPENDICULAR)
27
+ self.assertAlmostEqual(root.children[1].weight, 0.8)
28
+
29
+ # ------------------------------------------------------------------ #
30
+ # leading affine (ROTATE[...] AND_PERP text) #
31
+ # ------------------------------------------------------------------ #
32
+
33
+ def test_leading_affine_mid_prompt(self):
34
+ """ROTATE[...] AND_PERP must NOT be consumed into the previous prompt's text."""
35
+ root = p.parse_root("base ROTATE[0.125] AND_PERP vivid colors :0.8")
36
+ self.assertEqual(len(root.children), 2,
37
+ f"Expected 2 children, got {len(root.children)}: "
38
+ f"{[getattr(c,'prompt','composite') for c in root.children]}")
39
+ first, second = root.children
40
+ # first segment = "base " with no transform
41
+ self.assertIn("base", first.prompt)
42
+ self.assertNotIn("ROTATE", first.prompt)
43
+ self.assertIsNone(first.local_transform)
44
+ # second segment gets the transform
45
+ self.assertIsNotNone(second.local_transform)
46
+ self.assertEqual(second.conciliation, ConciliationStrategy.PERPENDICULAR)
47
+ self.assertAlmostEqual(second.weight, 0.8)
48
+
49
+ def test_leading_affine_at_start(self):
50
+ """ROTATE[...] AND_PERP as first (and only non-AND) segment."""
51
+ root = p.parse_root("ROTATE[0.125] AND_PERP vivid colors :0.8")
52
+ self.assertEqual(len(root.children), 1)
53
+ child = root.children[0]
54
+ self.assertIsNotNone(child.local_transform)
55
+ self.assertEqual(child.conciliation, ConciliationStrategy.PERPENDICULAR)
56
+ self.assertIn("vivid", child.prompt)
57
+ self.assertNotIn("ROTATE", child.prompt)
58
+
59
+ def test_leading_affine_with_and_salt(self):
60
+ root = p.parse_root("portrait SLIDE[0.05,0] AND_SALT vibrant :1.2")
61
+ self.assertEqual(len(root.children), 2)
62
+ second = root.children[1]
63
+ self.assertIsNotNone(second.local_transform)
64
+ self.assertEqual(second.conciliation, ConciliationStrategy.SALIENCE_MASK)
65
+
66
+ # ------------------------------------------------------------------ #
67
+ # Composed / chained affine #
68
+ # ------------------------------------------------------------------ #
69
+
70
+ def test_leading_and_trailing_compose(self):
71
+ """ROTATE[a] AND_PERP SCALE[b,b] text — both affines composed."""
72
+ root = p.parse_root("base AND_PERP ROTATE[0.125] AND_PERP SCALE[1.5,1.5] vivid")
73
+ # Just check no crash and all children parse
74
+ self.assertGreaterEqual(len(root.children), 1)
75
+
76
+ # ------------------------------------------------------------------ #
77
+ # CFGRescaleFactorSingleton lifecycle #
78
+ # ------------------------------------------------------------------ #
79
+
80
+ def test_cfg_rescale_singleton_clear(self):
81
+ from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
82
+ S.clear()
83
+ self.assertIsNone(S.get())
84
+ S.set(1.23)
85
+ self.assertAlmostEqual(S.get(), 1.23)
86
+ S.clear()
87
+ self.assertIsNone(S.get())
88
+
89
+ def test_cfg_rescale_singleton_thread_local(self):
90
+ import threading
91
+ from lib_neutral_prompt.global_state import CFGRescaleFactorSingleton as S
92
+ results = {}
93
+ def worker(val, key):
94
+ S.clear()
95
+ S.set(val)
96
+ import time; time.sleep(0.01)
97
+ results[key] = S.get()
98
+ threads = [threading.Thread(target=worker, args=(i * 10.0, i)) for i in range(3)]
99
+ for t in threads: t.start()
100
+ for t in threads: t.join()
101
+ for i in range(3):
102
+ self.assertAlmostEqual(results[i], i * 10.0,
103
+ msg=f"Thread {i} got {results[i]} expected {i*10.0}")
104
+
105
+ # ------------------------------------------------------------------ #
106
+ # New AND_SALT variants parse correctly #
107
+ # ------------------------------------------------------------------ #
108
+
109
+ def test_and_salt_wide_parsed(self):
110
+ root = p.parse_root("base AND_SALT_WIDE vivid")
111
+ self.assertEqual(len(root.children), 2)
112
+ self.assertEqual(root.children[1].conciliation,
113
+ p.ConciliationStrategy.SALIENCE_MASK_WIDE)
114
+
115
+ def test_and_salt_blob_parsed(self):
116
+ root = p.parse_root("base AND_SALT_BLOB vivid")
117
+ self.assertEqual(len(root.children), 2)
118
+ self.assertEqual(root.children[1].conciliation,
119
+ p.ConciliationStrategy.SALIENCE_MASK_BLOB)
120
+
121
+ def test_and_align_parsed(self):
122
+ root = p.parse_root("base AND_ALIGN_4_8 vivid")
123
+ self.assertEqual(len(root.children), 2)
124
+ self.assertIn('AND_ALIGN_4_8', root.children[1].conciliation.value)
125
+
126
+ def test_and_mask_align_parsed(self):
127
+ root = p.parse_root("base AND_MASK_ALIGN_4_8 vivid")
128
+ self.assertEqual(len(root.children), 2)
129
+ self.assertIn('AND_MASK_ALIGN_4_8', root.children[1].conciliation.value)
130
+
131
+
132
+ if __name__ == '__main__':
133
+ unittest.main()
neutral_prompt_patched/test/perp_parser/test_affine_pipeline.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration smoke tests for the pre-noise affine pipeline.
3
+
4
+ Requires PyTorch. If torch is not installed the entire module is skipped via
5
+ ``unittest.skipUnless``, so ``unittest discover`` still exits cleanly.
6
+
7
+ Mocked: A1111 modules only (modules.script_callbacks, modules.sd_samplers,
8
+ modules.shared). torch itself is real — no sys.modules replacement.
9
+
10
+ Tests:
11
+ 1. Hook is a no-op when global_state.is_enabled = False.
12
+ 2. Identity transform (angle=0) leaves x numerically unchanged.
13
+ 3. Non-identity transform calls apply_affine_transform (verified by spy).
14
+ 4. Prompts with no affine leave x unchanged.
15
+ 5. Empty state does not crash.
16
+ 6. Batch with multiple prompts does not crash.
17
+ """
18
+
19
+ import dataclasses
20
+ import importlib
21
+ import math
22
+ import sys
23
+ import types
24
+ import unittest
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Detect torch availability — skip the whole module if absent
28
+ # ---------------------------------------------------------------------------
29
+ try:
30
+ import importlib.util as _ilu
31
+ _spec = _ilu.find_spec('torch')
32
+ _TORCH_AVAILABLE = _spec is not None and getattr(_spec, 'origin', None) is not None
33
+ except (ValueError, ModuleNotFoundError):
34
+ _TORCH_AVAILABLE = False
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # A1111 module stubs (installed unconditionally so cfg_denoiser_hijack can
39
+ # be imported even when running the skip path)
40
+ # ---------------------------------------------------------------------------
41
+
42
+ def _stub(name):
43
+ m = types.ModuleType(name)
44
+ sys.modules.setdefault(name, m)
45
+ return sys.modules[name]
46
+
47
+ for _mod_name in ('modules', 'modules.script_callbacks', 'modules.sd_samplers',
48
+ 'modules.shared', 'modules.prompt_parser', 'gradio'):
49
+ _stub(_mod_name)
50
+
51
+ import modules.script_callbacks as _sc
52
+ if not hasattr(_sc, 'on_cfg_denoiser'):
53
+ _sc.on_cfg_denoiser = lambda fn: None
54
+ if not hasattr(_sc, 'on_script_unloaded'):
55
+ _sc.on_script_unloaded = lambda fn: None
56
+
57
+ import modules.shared as _sh
58
+ if not hasattr(_sh, 'opts'):
59
+ _sh.opts = types.SimpleNamespace(
60
+ data={},
61
+ add_option=lambda *a, **k: None,
62
+ onchange=lambda *a, **k: None,
63
+ OptionInfo=lambda *a, **k: None,
64
+ )
65
+
66
+ import modules.sd_samplers as _ss
67
+ if not hasattr(_ss, 'cfg_denoiser'):
68
+ _ss.cfg_denoiser = None
69
+ # cfg_denoiser_hijack installs a hijack on create_sampler at import time
70
+ if not hasattr(_ss, 'create_sampler'):
71
+ _ss.create_sampler = lambda *a, **k: None
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # Repo root on sys.path
76
+ # ---------------------------------------------------------------------------
77
+
78
+ from pathlib import Path
79
+ _REPO_ROOT = Path(__file__).resolve().parents[2]
80
+ if str(_REPO_ROOT) not in sys.path:
81
+ sys.path.insert(0, str(_REPO_ROOT))
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Conditional imports (only when torch is available)
86
+ # ---------------------------------------------------------------------------
87
+
88
+ if _TORCH_AVAILABLE:
89
+ import torch
90
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
91
+ from lib_neutral_prompt.cfg_denoiser_hijack import _on_cfg_denoiser
92
+ import lib_neutral_prompt.affine_transform as _at_mod
93
+
94
+ # -----------------------------------------------------------------------
95
+ # Helpers
96
+ # -----------------------------------------------------------------------
97
+
98
+ def _leaf(text, angle=None):
99
+ """LeafPrompt, optionally with a ROTATE affine (angle in turns)."""
100
+ transform = None
101
+ if angle is not None:
102
+ c = math.cos(angle * 2 * math.pi)
103
+ s = math.sin(angle * 2 * math.pi)
104
+ transform = torch.tensor([[c, -s, 0.0], [s, c, 0.0]])
105
+ return neutral_prompt_parser.LeafPrompt(1.0, None, transform, text)
106
+
107
+ def _comp(*children):
108
+ return neutral_prompt_parser.CompositePrompt(1.0, None, None, list(children))
109
+
110
+ @dataclasses.dataclass
111
+ class _FakeParams:
112
+ x: torch.Tensor
113
+
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Test class
117
+ # ---------------------------------------------------------------------------
118
+
119
+ @unittest.skipUnless(_TORCH_AVAILABLE, 'PyTorch is not installed — skipping pipeline tests')
120
+ class TestOnCfgDenoiserSmokeTest(unittest.TestCase):
121
+
122
+ def setUp(self):
123
+ global_state.is_enabled = True
124
+ global_state.prompt_exprs = []
125
+ global_state.batch_cond_indices = []
126
+
127
+ def tearDown(self):
128
+ global_state.is_enabled = False
129
+ global_state.prompt_exprs = []
130
+ global_state.batch_cond_indices = []
131
+
132
+ # 1. disabled → x untouched
133
+ def test_disabled_leaves_x_unchanged(self):
134
+ global_state.is_enabled = False
135
+ x = torch.zeros(2, 4, 8, 8)
136
+ x[0] = 1.0
137
+ orig = x.clone()
138
+ params = _FakeParams(x=x.clone())
139
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.125))]
140
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
141
+ _on_cfg_denoiser(params)
142
+ self.assertTrue(torch.equal(params.x, orig), "Disabled hook must not modify x")
143
+
144
+ # 2. identity transform → x numerically unchanged
145
+ def test_identity_transform_no_change(self):
146
+ x = torch.randn(2, 4, 8, 8)
147
+ orig = x.clone()
148
+ params = _FakeParams(x=x.clone())
149
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0))]
150
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
151
+ _on_cfg_denoiser(params)
152
+ self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
153
+ "Identity affine must not change x")
154
+
155
+ # 3. non-identity → apply_affine_transform is called (spy)
156
+ def test_nonidentity_calls_apply_affine_transform(self):
157
+ x = torch.zeros(2, 4, 8, 8)
158
+ x[1] = torch.randn(4, 8, 8)
159
+ params = _FakeParams(x=x.clone())
160
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp", 0.25))]
161
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
162
+
163
+ _real = _at_mod.apply_affine_transform
164
+ calls = []
165
+ def _spy(tensor, affine, mode='bilinear'):
166
+ calls.append(True)
167
+ return _real(tensor, affine, mode)
168
+
169
+ _at_mod.apply_affine_transform = _spy
170
+ try:
171
+ _on_cfg_denoiser(params)
172
+ finally:
173
+ _at_mod.apply_affine_transform = _real
174
+
175
+ self.assertGreater(len(calls), 0,
176
+ "apply_affine_transform must be called for non-identity transform")
177
+
178
+ # 4. no affine on any child → x unchanged
179
+ def test_no_affine_no_change(self):
180
+ x = torch.randn(2, 4, 8, 8)
181
+ orig = x.clone()
182
+ params = _FakeParams(x=x.clone())
183
+ global_state.prompt_exprs = [_comp(_leaf("base"), _leaf("perp"))]
184
+ global_state.batch_cond_indices = [[(0, 1.0), (1, 1.0)]]
185
+ _on_cfg_denoiser(params)
186
+ self.assertTrue(torch.allclose(params.x, orig, atol=1e-5),
187
+ "No affine → x must not change")
188
+
189
+ # 5. empty state → no crash
190
+ def test_empty_state_no_crash(self):
191
+ params = _FakeParams(x=torch.randn(1, 4, 8, 8))
192
+ global_state.prompt_exprs = []
193
+ global_state.batch_cond_indices = []
194
+ try:
195
+ _on_cfg_denoiser(params)
196
+ except Exception as e:
197
+ self.fail(f"_on_cfg_denoiser raised with empty state: {e}")
198
+
199
+ # 6. two batch entries → no crash
200
+ def test_batch_multiple_prompts_no_crash(self):
201
+ params = _FakeParams(x=torch.zeros(4, 4, 8, 8))
202
+ global_state.prompt_exprs = [
203
+ _comp(_leaf("base_a"), _leaf("perp_a", 0.25)),
204
+ _comp(_leaf("base_b"), _leaf("perp_b", 0.125)),
205
+ ]
206
+ global_state.batch_cond_indices = [
207
+ [(0, 1.0), (1, 1.0)],
208
+ [(2, 1.0), (3, 1.0)],
209
+ ]
210
+ try:
211
+ _on_cfg_denoiser(params)
212
+ except Exception as e:
213
+ self.fail(f"Multi-prompt batch raised: {e}")
214
+
215
+
216
+ if __name__ == '__main__':
217
+ unittest.main()
neutral_prompt_patched/test/perp_parser/test_basic_parser.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
+ class TestPromptParser(unittest.TestCase):
9
+ def setUp(self):
10
+ self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
11
+ self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
12
+ self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
13
+ self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
14
+ self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
15
+ self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
16
+ self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
17
+
18
+ def test_simple_prompt_child_count(self):
19
+ self.assertEqual(len(self.simple_prompt.children), 1)
20
+
21
+ def test_simple_prompt_child_weight(self):
22
+ self.assertEqual(self.simple_prompt.children[0].weight, 1.0)
23
+
24
+ def test_simple_prompt_child_prompt(self):
25
+ self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
26
+
27
+ def test_square_weight_prompt(self):
28
+ prompt = "a [b c d e : f g h :1.5]"
29
+ parsed = neutral_prompt_parser.parse_root(prompt)
30
+ self.assertEqual(parsed.children[0].prompt, prompt)
31
+
32
+ composed_prompt = f"{prompt} AND_PERP other prompt"
33
+ parsed = neutral_prompt_parser.parse_root(composed_prompt)
34
+ self.assertEqual(parsed.children[0].prompt, prompt)
35
+
36
+ def test_and_prompt_child_count(self):
37
+ self.assertEqual(len(self.and_prompt.children), 2)
38
+
39
+ def test_and_prompt_child_weights_and_prompts(self):
40
+ self.assertEqual(self.and_prompt.children[0].weight, 1.0)
41
+ self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
42
+ self.assertEqual(self.and_prompt.children[1].weight, 2.0)
43
+ self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")
44
+
45
+ def test_and_perp_prompt_child_count(self):
46
+ self.assertEqual(len(self.and_perp_prompt.children), 2)
47
+
48
+ def test_and_perp_prompt_child_types(self):
49
+ self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
50
+ self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
51
+
52
+ def test_and_perp_prompt_nested_child(self):
53
+ nested_child = self.and_perp_prompt.children[1]
54
+ self.assertEqual(nested_child.weight, 2.0)
55
+ self.assertEqual(nested_child.prompt.strip(), "goodbye")
56
+
57
+ def test_nested_and_perp_prompt_child_count(self):
58
+ self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
59
+
60
+ def test_nested_and_perp_prompt_child_types(self):
61
+ self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
62
+ self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
63
+
64
+ def test_nested_and_perp_prompt_nested_child_types(self):
65
+ nested_child = self.nested_and_perp_prompt.children[1].children[0]
66
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
67
+ nested_child = self.nested_and_perp_prompt.children[1].children[1]
68
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
69
+
70
+ def test_nested_and_perp_prompt_nested_child(self):
71
+ nested_child = self.nested_and_perp_prompt.children[1].children[1]
72
+ self.assertEqual(nested_child.weight, 3.0)
73
+ self.assertEqual(nested_child.prompt.strip(), "welcome")
74
+
75
+ def test_invalid_weight_child_count(self):
76
+ self.assertEqual(len(self.invalid_weight.children), 1)
77
+
78
+ def test_invalid_weight_child_weight(self):
79
+ self.assertEqual(self.invalid_weight.children[0].weight, 1.0)
80
+
81
+ def test_invalid_weight_child_prompt(self):
82
+ self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")
83
+
84
+ def test_and_salt_prompt_child_count(self):
85
+ self.assertEqual(len(self.and_salt_prompt.children), 2)
86
+
87
+ def test_and_salt_prompt_child_types(self):
88
+ self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
89
+ self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
90
+
91
+ def test_and_salt_prompt_nested_child(self):
92
+ nested_child = self.and_salt_prompt.children[1]
93
+ self.assertEqual(nested_child.weight, 2.0)
94
+ self.assertEqual(nested_child.prompt.strip(), "goodbye")
95
+
96
+ def test_nested_and_salt_prompt_child_count(self):
97
+ self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
98
+
99
+ def test_nested_and_salt_prompt_child_types(self):
100
+ self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
101
+ self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
102
+
103
+ def test_nested_and_salt_prompt_nested_child_types(self):
104
+ nested_child = self.nested_and_salt_prompt.children[1].children[0]
105
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
106
+ nested_child = self.nested_and_salt_prompt.children[1].children[1]
107
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
108
+
109
+ def test_nested_and_salt_prompt_nested_child(self):
110
+ nested_child = self.nested_and_salt_prompt.children[1].children[1]
111
+ self.assertEqual(nested_child.weight, 3.0)
112
+ self.assertEqual(nested_child.prompt.strip(), "welcome")
113
+
114
+ def test_start_with_prompt_editing(self):
115
+ prompt = "[(long shot:1.2):0.1] detail.."
116
+ res = neutral_prompt_parser.parse_root(prompt)
117
+ self.assertEqual(res.children[0].weight, 1.0)
118
+ self.assertEqual(res.children[0].prompt, prompt)
119
+
120
+
121
+ if __name__ == '__main__':
122
+ unittest.main()
neutral_prompt_patched/test/perp_parser/test_malicious_parser.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
+ class TestMaliciousPromptParser(unittest.TestCase):
9
+ def setUp(self):
10
+ self.parser = neutral_prompt_parser
11
+
12
+ def test_empty(self):
13
+ result = self.parser.parse_root("")
14
+ self.assertEqual(result.children[0].prompt, "")
15
+ self.assertEqual(result.children[0].weight, 1.0)
16
+
17
+ def test_zero_weight(self):
18
+ result = self.parser.parse_root("hello :0.0")
19
+ self.assertEqual(result.children[0].weight, 0.0)
20
+
21
+ def test_mixed_positive_and_negative_weights(self):
22
+ result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
23
+ self.assertEqual(result.children[0].weight, 1.0)
24
+ self.assertEqual(result.children[1].weight, -2.0)
25
+
26
+ def test_debalanced_square_brackets(self):
27
+ prompt = "a [ b " * 100
28
+ result = self.parser.parse_root(prompt)
29
+ self.assertEqual(result.children[0].prompt, prompt)
30
+
31
+ prompt = "a ] b " * 100
32
+ result = self.parser.parse_root(prompt)
33
+ self.assertEqual(result.children[0].prompt, prompt)
34
+
35
+ repeats = 10
36
+ prompt = "a [ [ b AND c ] " * repeats
37
+ result = self.parser.parse_root(prompt)
38
+ self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
39
+
40
+ repeats = 10
41
+ prompt = "a [ b AND c ] ] " * repeats
42
+ result = self.parser.parse_root(prompt)
43
+ self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
44
+
45
+ def test_erroneous_syntax(self):
46
+ result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
47
+ self.assertEqual(result.children[0].weight, 1.0)
48
+ self.assertEqual(result.children[1].prompt, "[goodbye ")
49
+ self.assertEqual(result.children[1].weight, 2.0)
50
+
51
+ result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
52
+ self.assertEqual(result.children[0].weight, 1.0)
53
+ self.assertEqual(result.children[1].prompt, " goodbye ")
54
+
55
+ result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
56
+ self.assertEqual(result.children[1].prompt, " goodbye]")
57
+ self.assertEqual(result.children[1].weight, 2.0)
58
+
59
+ result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
60
+ self.assertEqual(result.children[1].weight, 2.0)
61
+ self.assertEqual(result.children[1].prompt, " a [ goodbye ")
62
+
63
+ result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
64
+ self.assertEqual(result.children[0].weight, 1.0)
65
+ self.assertEqual(result.children[2].prompt, " goodbye ")
66
+
67
+ def test_huge_number_of_prompt_parts(self):
68
+ result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
69
+ self.assertEqual(len(result.children), 10**4)
70
+
71
+ def test_prompt_ending_with_weight(self):
72
+ result = self.parser.parse_root("hello :1.0 AND :2.0")
73
+ self.assertEqual(result.children[0].weight, 1.0)
74
+ self.assertEqual(result.children[1].prompt, "")
75
+ self.assertEqual(result.children[1].weight, 2.0)
76
+
77
+ def test_huge_input_string(self):
78
+ big_string = "hello :1.0 AND " * 10**4
79
+ result = self.parser.parse_root(big_string)
80
+ self.assertEqual(len(result.children), 10**4 + 1)
81
+
82
+ def test_deeply_nested_prompt(self):
83
+ deeply_nested_prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
84
+ result = self.parser.parse_root(deeply_nested_prompt)
85
+ self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
86
+
87
+ def test_complex_nested_prompts(self):
88
+ complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
89
+ result = self.parser.parse_root(complex_prompt)
90
+ self.assertEqual(result.children[0].weight, 1.0)
91
+ self.assertEqual(result.children[1].weight, 2.0)
92
+ self.assertEqual(result.children[2].children[0].weight, 3.0)
93
+ self.assertEqual(result.children[2].children[1].weight, 4.0)
94
+ self.assertEqual(result.children[2].children[2].weight, 5.0)
95
+
96
+ def test_string_with_random_characters(self):
97
+ random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
98
+ try:
99
+ self.parser.parse_root(random_chars)
100
+ except Exception:
101
+ self.fail("parse_root couldn't handle a string with random characters.")
102
+
103
+ def test_string_with_unexpected_symbols(self):
104
+ unexpected_symbols = "hello :1.0 AND $%^&*()goodbye :2.0"
105
+ try:
106
+ self.parser.parse_root(unexpected_symbols)
107
+ except Exception:
108
+ self.fail("parse_root couldn't handle a string with unexpected symbols.")
109
+
110
+ def test_string_with_unconventional_structure(self):
111
+ unconventional_structure = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
112
+ try:
113
+ self.parser.parse_root(unconventional_structure)
114
+ except Exception:
115
+ self.fail("parse_root couldn't handle a string with unconventional structure.")
116
+
117
+ def test_string_with_mixed_alphabets_and_numbers(self):
118
+ mixed_alphabets_and_numbers = "123hello :1.0 AND goodbye456 :2.0"
119
+ try:
120
+ self.parser.parse_root(mixed_alphabets_and_numbers)
121
+ except Exception:
122
+ self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")
123
+
124
+ def test_string_with_nested_brackets(self):
125
+ nested_brackets = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
126
+ try:
127
+ self.parser.parse_root(nested_brackets)
128
+ except Exception:
129
+ self.fail("parse_root couldn't handle a string with nested brackets.")
130
+
131
+ def test_unmatched_opening_braces(self):
132
+ unmatched_opening_braces = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
133
+ try:
134
+ self.parser.parse_root(unmatched_opening_braces)
135
+ except Exception:
136
+ self.fail("parse_root couldn't handle a string with unmatched opening braces.")
137
+
138
+ def test_unmatched_closing_braces(self):
139
+ unmatched_closing_braces = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
140
+ try:
141
+ self.parser.parse_root(unmatched_closing_braces)
142
+ except Exception:
143
+ self.fail("parse_root couldn't handle a string with unmatched closing braces.")
144
+
145
+ def test_repeating_colons(self):
146
+ repeating_colons = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
147
+ try:
148
+ self.parser.parse_root(repeating_colons)
149
+ except Exception:
150
+ self.fail("parse_root couldn't handle a string with repeating colons.")
151
+
152
+ def test_excessive_whitespace(self):
153
+ excessive_whitespace = "hello :1.0 AND_PERP goodbye :2.0"
154
+ try:
155
+ self.parser.parse_root(excessive_whitespace)
156
+ except Exception:
157
+ self.fail("parse_root couldn't handle a string with excessive whitespace.")
158
+
159
+ def test_repeating_AND_keyword(self):
160
+ repeating_AND_keyword = "hello :1.0 AND AND AND AND AND goodbye :2.0"
161
+ try:
162
+ self.parser.parse_root(repeating_AND_keyword)
163
+ except Exception:
164
+ self.fail("parse_root couldn't handle a string with repeating AND keyword.")
165
+
166
+ def test_repeating_AND_PERP_keyword(self):
167
+ repeating_AND_PERP_keyword = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
168
+ try:
169
+ self.parser.parse_root(repeating_AND_PERP_keyword)
170
+ except Exception:
171
+ self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")
172
+
173
+ def test_square_weight_prompt(self):
174
+ prompt = "AND_PERP [weighted] you thought it was the end"
175
+ try:
176
+ self.parser.parse_root(prompt)
177
+ except Exception:
178
+ self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
179
+
180
+
181
+ if __name__ == '__main__':
182
+ unittest.main()