dikdimon commited on
Commit
02bf2ae
·
verified ·
1 Parent(s): c55e495

Upload z-sd-webui-neutral-prompt-workYEAH4 using SD-Hub

Browse files
Files changed (24) hide show
  1. z-sd-webui-neutral-prompt-workYEAH4/.gitignore +1 -0
  2. z-sd-webui-neutral-prompt-workYEAH4/LICENSE +21 -0
  3. z-sd-webui-neutral-prompt-workYEAH4/README.md +113 -0
  4. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/cfg_denoiser_hijack.cpython-310.pyc +0 -0
  5. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/global_state.cpython-310.pyc +0 -0
  6. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/hijacker.cpython-310.pyc +0 -0
  7. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/neutral_prompt_parser.cpython-310.pyc +0 -0
  8. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/prompt_parser_hijack.cpython-310.pyc +0 -0
  9. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/ui.cpython-310.pyc +0 -0
  10. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/xyz_grid.cpython-310.pyc +0 -0
  11. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/cfg_denoiser_hijack.py +440 -0
  12. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/external_code/__init__.py +23 -0
  13. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/external_code/api.py +5 -0
  14. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/global_state.py +16 -0
  15. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/hijacker.py +34 -0
  16. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/neutral_prompt_parser.py +225 -0
  17. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/prompt_parser_hijack.py +207 -0
  18. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/ui.py +109 -0
  19. z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/xyz_grid.py +42 -0
  20. z-sd-webui-neutral-prompt-workYEAH4/scripts/__pycache__/neutral_prompt.cpython-310.pyc +0 -0
  21. z-sd-webui-neutral-prompt-workYEAH4/scripts/neutral_prompt.py +94 -0
  22. z-sd-webui-neutral-prompt-workYEAH4/test/perp_parser/__init__.py +0 -0
  23. z-sd-webui-neutral-prompt-workYEAH4/test/perp_parser/basic_test.py +122 -0
  24. z-sd-webui-neutral-prompt-workYEAH4/test/perp_parser/malicious_test.py +182 -0
z-sd-webui-neutral-prompt-workYEAH4/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
z-sd-webui-neutral-prompt-workYEAH4/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 ljleb
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
z-sd-webui-neutral-prompt-workYEAH4/README.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neutral Prompt
2
+
3
+ Neutral prompt is an a1111 webui extension that adds alternative composable diffusion keywords to the prompt language. It enhances the original implementation using more recent research.
4
+
5
+ ## Features
6
+
7
+ - [Perp-Neg](https://perp-neg.github.io/) orthogonal prompts, invoked using the `AND_PERP` keyword
8
+ - saliency-aware noise blending, invoked using the `AND_SALT` keyword (credits to [Magic Fusion](https://magicfusion.github.io/) for the algorithm used to determine SNB maps from epsilons)
9
+ - semantic guidance top-k filtering, invoked using the `AND_TOPK` keyword (reference: https://arxiv.org/abs/2301.12247)
10
+ - standard deviation based CFG rescaling (Reference: https://arxiv.org/abs/2305.08891, section 3.4)
11
+
12
+ ## Usage
13
+
14
+ *Disclaimer: some sections of the readme have been generated by GPT-4. If anything is unclear, feel free to ask for clarifications in the [discussions](https://github.com/ljleb/sd-webui-neutral-prompt/discussions).*
15
+
16
+ ### Keyword `AND_PERP`
17
+
18
+ The `AND_PERP` keyword, standing for "PERPendicular `AND`", integrates the orthogonalization process described in the Perp-Neg paper. Essentially, `AND_PERP` allows for prompting concepts that highly overlap with regular prompts, by negating contradicting concepts.
19
+
20
+ You could visualize it as such: if `AND` prompts are "greedy" (taking as much space as possible in the output), `AND_PERP` prompts are opposite, relinquishing control as soon as there is a disagreement in the generated output.
21
+
22
+ ### Keyword `AND_SALT`
23
+
24
+ Saliency-aware blending is made possible using the `AND_SALT` keyword, shorthand for "SALienT `AND`". In essence, `AND_SALT` keeps the highest activation pixels at each denoising step.
25
+
26
+ Think of it as a territorial dispute: the image generated by the `AND` prompts is one country, and the images generated by `AND_SALT` prompts represent neighbouring nations. They're all vying for the same land - whoever strikes the strongest at a given time (denoising step) and location (latent pixel) claims it.
27
+
28
+ ### Keyword `AND_TOPK`
29
+
30
+ The `AND_TOPK` keyword refers to "TOP-K filtering". It keeps only the "k" highest activation latent pixels in the noise map and discards the rest. It works similarly to `AND_SALT`, except that the high-activation regions are simply added instead of replacing previous content.
31
+
32
+ Currently, k is constantly 5% of all latent pixels, meaning 95% of the weakest latent pixel values at each step are discarded.
33
+
34
+ Top-k filtering is useful when you want to have a more targeted effect on the generated image. It should work best with smaller objects and details.
35
+
36
+ ## Examples
37
+
38
+ ### Using the `AND_PERP` Keyword
39
+
40
+ Here is an example to illustrate one use case of the `AND_PREP` keyword. Prompt:
41
+
42
+ `beautiful castle landscape AND monster house castle :-1`
43
+
44
+ This is an XY grid with prompt S/R `AND, AND_PERP`:
45
+
46
+ ![image](https://github.com/ljleb/sd-webui-neutral-prompt/assets/32277961/29f3cf34-2ed4-45d2-b73a-b6fadec21d61)
47
+
48
+ Key observations:
49
+
50
+ - The `AND_PERP` images exhibit a higher dynamic range compared to the `AND` images.
51
+ - Since the prompts have a lot of overlap, the `AND` images sometimes struggle to depict a castle. This isn't a problem for the `AND_PERP` images.
52
+ - The `AND` images tend to lean towards a purple color, because this was the path of least resistance between the two opposing prompts during generation. In contrast, the `AND_PERP` images, free from this tug-of-war, present a clearer representation.
53
+
54
+ ### Using the `AND_SALT` Keyword
55
+
56
+ The `AND_SALT` keyword can be used to invoke saliency-aware blending. It spotlights and accentuates areas of high-activation in the output.
57
+
58
+ Consider this example prompt utilizing `AND_SALT`:
59
+
60
+ ```
61
+ a vibrant rainforest with lush green foliage
62
+ AND_SALT the glimmering rays of a golden sunset piercing through the trees
63
+ ```
64
+
65
+ In this case, the extension identifies and isolates the most salient regions in the sunset prompt. Then, the extension applies this marsked image to the rainforest prompt. Only the portions of the rainforest prompt that coincide with the salient areas of the sunset prompt are affected. These areas are replaced by pixels from the sunset prompt.
66
+
67
+ This is an XY grid with prompt S/R `AND_SALT, AND, AND_PERP`:
68
+
69
+ ![xyz_grid-0008-1564977627-a vibrant rainforest with lush green foliage_AND_SALT the glimmering rays of a golden sunset piercing through the trees](https://github.com/ljleb/sd-webui-neutral-prompt/assets/32277961/2404f20b-47f6-457f-b4c5-76b9fd919345)
70
+
71
+ Key observations:
72
+
73
+ - `AND_SALT` behaves more diplomatically, enhancing areas where its impact makes the most sense and aligning with high activity regions in the output
74
+ - `AND` gives equal weight to both prompts, creating a blended result
75
+ - `AND_PERP` will find its way through anything not blocked by the regular prompt
76
+
77
+ ## Advanced Features
78
+
79
+ ### Nesting prompts
80
+
81
+ The extension supports nesting of all prompt keywords including `AND`, allowing greater flexibility and control over the final output. Here's an example of how these keywords can be combined:
82
+
83
+ ```
84
+ magical tree forests, eternal city
85
+ AND_PERP [
86
+ electrical pole voyage
87
+ AND_SALT small nocturne companion
88
+ ]
89
+ AND_SALT [
90
+ electrical tornado
91
+ AND_SALT electric arcs, bzzz, sparks
92
+ ]
93
+ ```
94
+
95
+ To generate the final image from the diffusion model:
96
+
97
+ 1. The extension first processes the root `AND` prompts. In this case, it's just `magical tree forests, eternal city`
98
+ 2. It then processes the `AND_SALT` prompt `small nocturne companion` in the context of `electrical pole voyage`. This enhances salient features in the `electrical pole voyage` image
99
+ 3. This new image is orthogonalized with the image from `magical tree forests, eternal city`, blending the details of the 'electrical pole voyage' into the main scene without creating conflicts
100
+ 4. The extension then turns to the second `AND_SALT` group. It processes `electric arcs, bzzz, sparks` in the context of `electrical tornado`, amplifying salient features in the electrical tornado image
101
+ 5. The image from this `AND_SALT` group is then combined with the `magical tree forests, eternal city` image. The final output retains the strongest features from both the `electrical tornado` (enhanced by 'electric arcs, bzzz, sparks') and the earlier 'magical tree forests, eternal city' scene influenced by the 'electrical pole voyage'
102
+
103
+ Each keyword can define a distinct denoising space within its square brackets `[...]`. Prompts inside it merge into a single image before further processing down the prompt tree.
104
+
105
+ While there's no strict limit on the depth of nesting, experimental evidence suggests that going beyond a depth of 2 is generally unnecessary. We're still exploring the added precision from deeper nesting. If you discover innovative ways of controlling the generations using nested prompts, please share in the discussions!
106
+
107
+ ![image](https://github.com/ljleb/sd-webui-neutral-prompt/assets/32277961/f16587fe-2244-4832-a253-98f819a9e2e0)
108
+
109
+ ## Special Mentions
110
+
111
+ Special thanks to these people for helping make this extension possible:
112
+
113
+ - [Ai-Casanova](https://github.com/AI-Casanova) : for sharing mathematical knowledge, time, and conducting proof-testing to enhance the robustness of this extension
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/cfg_denoiser_hijack.cpython-310.pyc ADDED
Binary file (12.7 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/global_state.cpython-310.pyc ADDED
Binary file (701 Bytes). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/hijacker.cpython-310.pyc ADDED
Binary file (1.81 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/neutral_prompt_parser.cpython-310.pyc ADDED
Binary file (7.06 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/prompt_parser_hijack.cpython-310.pyc ADDED
Binary file (5.68 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/ui.cpython-310.pyc ADDED
Binary file (5.34 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/__pycache__/xyz_grid.cpython-310.pyc ADDED
Binary file (1.65 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/cfg_denoiser_hijack.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lib_neutral_prompt/cfg_denoiser_hijack.py
2
+
3
+ from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
4
+ from lib_neutral_prompt.prompt_parser_hijack import WebuiPromptVisitor
5
+ from modules import script_callbacks, sd_samplers, shared, prompt_parser
6
+ from typing import Tuple, List
7
+ import dataclasses
8
+ import functools
9
+ import torch
10
+ import sys
11
+ import textwrap
12
+
13
+
14
+ # -----------------------------
15
+ # Utils / warnings
16
+ # -----------------------------
17
+
18
+ def console_warn(message: str):
19
+ if not getattr(global_state, "verbose", False):
20
+ return
21
+ print(
22
+ f"\n[sd-webui-neutral-prompt] WARNING: {textwrap.dedent(message)}\n"
23
+ f"Please check for conflicts with other prompt-parsing extensions or report at "
24
+ f"https://github.com/ljleb/sd-webui-neutral-prompt/issues",
25
+ file=sys.stderr,
26
+ )
27
+
28
+
29
+ def warn_unsupported_sampler():
30
+ console_warn(
31
+ """
32
+ Neutral Prompt composition via AND is unsupported by DDIM / PLMS / UniPC in webui.
33
+ The sampler will NOT be patched — falling back to original combine_denoised...
34
+ """
35
+ )
36
+
37
+
38
+ def warn_projection_not_found():
39
+ console_warn(
40
+ """
41
+ Could not find a stable projection for one or more AND_PERP prompts.
42
+ Those prompts will NOT be made perpendicular at the very first steps.
43
+ """
44
+ )
45
+
46
+
47
+ # -----------------------------
48
+ # Math helpers used by visitors
49
+ # -----------------------------
50
+
51
+ def get_salience(vector: torch.Tensor) -> torch.Tensor:
52
+ # softmax по абсолютным значениям; возвращаем «карту заметности» формы vector
53
+ return torch.softmax(torch.abs(vector).flatten(), dim=0).reshape_as(vector)
54
+
55
+
56
+ def filter_abs_top_k(vector: torch.Tensor, k_ratio: float) -> torch.Tensor:
57
+ # пропускаем только самые большие по модулю компоненты
58
+ k = int(torch.numel(vector) * (1 - k_ratio))
59
+ top_k, _ = torch.kthvalue(torch.abs(torch.flatten(vector)), max(k, 1))
60
+ return vector * (torch.abs(vector) >= top_k).to(vector.dtype)
61
+
62
+
63
+ def get_perpendicular_component(normal: torch.Tensor, vector: torch.Tensor) -> torch.Tensor:
64
+ # проецируем vector на подпространство, перпендикулярное normal
65
+ if (normal == 0).all():
66
+ # на первых шагам проекция может быть нулевой — не шумим логом чаще одного раза в начало семплинга
67
+ if getattr(shared.state, "sampling_step", 0) <= 0:
68
+ warn_projection_not_found()
69
+ return vector
70
+ return vector - normal * torch.sum(normal * vector) / (torch.norm(normal) ** 2)
71
+
72
+
73
+ def salient_blend(normal: torch.Tensor, vectors: List[Tuple[torch.Tensor, float]]) -> torch.Tensor:
74
+ # пиксельно выбираем «самый заметный» вектор; складываем дельты относительно normal
75
+ salience_maps = [get_salience(normal)] + [get_salience(v) for v, _ in vectors]
76
+ mask = torch.argmax(torch.stack(salience_maps, dim=0), dim=0)
77
+ result = torch.zeros_like(normal)
78
+ for mask_i, (vector, weight) in enumerate(vectors, start=1):
79
+ vector_mask = (mask == mask_i).float()
80
+ result += weight * vector_mask * (vector - normal)
81
+ return result
82
+
83
+
84
+ # -----------------------------
85
+ # Glue with webui
86
+ # -----------------------------
87
+
88
+ sd_samplers_hijacker = hijacker.ModuleHijacker.install_or_get(
89
+ module=sd_samplers,
90
+ hijacker_attribute="__neutral_prompt_hijacker",
91
+ on_uninstall=script_callbacks.on_script_unloaded,
92
+ )
93
+
94
+
95
+ @sd_samplers_hijacker.hijack("create_sampler")
96
+ def create_sampler_hijack(name: str, model, original_function):
97
+ sampler = original_function(name, model)
98
+ # только те, у кого есть model_wrap_cfg.combine_denoised
99
+ if not hasattr(sampler, "model_wrap_cfg") or not hasattr(sampler.model_wrap_cfg, "combine_denoised"):
100
+ if getattr(global_state, "is_enabled", False):
101
+ warn_unsupported_sampler()
102
+ return sampler
103
+
104
+ sampler.model_wrap_cfg.combine_denoised = functools.partial(
105
+ combine_denoised_hijack, original_function=sampler.model_wrap_cfg.combine_denoised
106
+ )
107
+ return sampler
108
+
109
+
110
+ # -----------------------------
111
+ # Scheduling helpers
112
+ # -----------------------------
113
+
114
+ def convert_to_prompt_expr_from_multicond(c, prompts):
115
+ # fallback: когда global_state.prompt_exprs ещё пуст, а webui уже дал индексы/веса
116
+ exprs = []
117
+ for batch_conds in c:
118
+ for idx, weight in batch_conds:
119
+ prompt_text = prompts[idx] if prompts and idx < len(prompts) else f"prompt_{idx}"
120
+ exprs.append(
121
+ neutral_prompt_parser.LeafPrompt(
122
+ weight=float(weight) if isinstance(weight, (int, float)) else 1.0,
123
+ conciliation=None,
124
+ prompt=prompt_text,
125
+ )
126
+ )
127
+ return exprs
128
+
129
+
130
+ def get_active_prompt(schedule, current_step):
131
+ # выбираем активный текст по текущему шагу
132
+ for end_step, text in schedule:
133
+ if current_step <= end_step:
134
+ return text
135
+ return schedule[-1][1] if schedule else "empty_prompt"
136
+
137
+
138
+ # -----------------------------
139
+ # Main entry
140
+ # -----------------------------
141
+
142
+ def combine_denoised_hijack(
143
+ x_out: torch.Tensor,
144
+ batch_cond_indices: List[List[Tuple[int, float]]],
145
+ text_uncond: torch.Tensor,
146
+ cond_scale: float,
147
+ original_function,
148
+ ) -> torch.Tensor:
149
+ """
150
+ Патчим combine_denoised так, чтобы он поддерживал Neutral Prompt-композиции.
151
+ """
152
+ if not getattr(global_state, "is_enabled", False):
153
+ return original_function(x_out, batch_cond_indices, text_uncond, cond_scale)
154
+
155
+ # если кто-то вызвал раньше, а exprs ещё нет — построим из webui conds
156
+ if not getattr(global_state, "prompt_exprs", []) and batch_cond_indices:
157
+ global_state.prompt_exprs = convert_to_prompt_expr_from_multicond(batch_cond_indices, [])
158
+
159
+ # поддержка расписаний webui (prompt travel / динамика по шагам)
160
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
161
+ getattr(global_state, "raw_prompts", [expr.accept(WebuiPromptVisitor()) for expr in global_state.prompt_exprs]),
162
+ shared.state.sampling_steps,
163
+ )
164
+ active_prompts = [get_active_prompt(s, shared.state.sampling_step) for s in prompt_schedules]
165
+ global_state.prompt_exprs = [neutral_prompt_parser.parse_root(p) for p in active_prompts]
166
+
167
+ # считаем «чистый» denoised от webui на переупакованном батче
168
+ denoised = get_webui_denoised(x_out, batch_cond_indices, text_uncond, cond_scale, original_function)
169
+
170
+ # затем добавим наши корректировки (aux_cond_delta) и CFG-rescale
171
+ uncond = x_out[-text_uncond.shape[0] :]
172
+
173
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
174
+ args = CombineDenoiseArgs(x_out=x_out, uncond=uncond[batch_i], cond_indices=cond_indices)
175
+ cond_delta = prompt.accept(CondDeltaVisitor(), args, 0)
176
+ aux_cond_delta = prompt.accept(AuxCondDeltaVisitor(), args, cond_delta, 0)
177
+
178
+ # добавляем только aux часть в cfg_cond (основанный на webui результате)
179
+ cfg_cond = denoised[batch_i] + aux_cond_delta * cond_scale
180
+ # а rescale делаем относительно полной «цели»: uncond + cond_delta + aux_cond_delta
181
+ denoised[batch_i] = cfg_rescale(cfg_cond, uncond[batch_i] + cond_delta + aux_cond_delta)
182
+
183
+ return denoised
184
+
185
+
186
+ def get_webui_denoised(
187
+ x_out: torch.Tensor,
188
+ batch_cond_indices: List[List[Tuple[int, float]]],
189
+ text_uncond: torch.Tensor,
190
+ cond_scale: float,
191
+ original_function,
192
+ ) -> torch.Tensor:
193
+ """
194
+ Переупаковываем батч под оригинальный combine_denoised:
195
+ - строим список «условных срезов» (с учётом сложных композиций)
196
+ - отдаём webui оригинальные индексы/веса, но уже на наш переупакованный батч
197
+ - webui вернёт по одному denoised на картинку (а не на срез!)
198
+ """
199
+ uncond = x_out[-text_uncond.shape[0] :]
200
+ sliced_batch_x_out: List[torch.Tensor] = []
201
+ sliced_batch_cond_indices: List[List[Tuple[int, float]]] = []
202
+
203
+ for batch_i, (prompt, cond_indices) in enumerate(zip(global_state.prompt_exprs, batch_cond_indices)):
204
+ args = CombineDenoiseArgs(x_out=x_out, uncond=uncond[batch_i], cond_indices=cond_indices)
205
+
206
+ # sanity-check: число базовых листьев должно совпадать с количеством webui cond_indices
207
+ base_size = prompt.accept(BaseSizeVisitor())
208
+ if base_size != len(cond_indices) and getattr(global_state, "verbose", False):
209
+ console_warn(
210
+ f"BaseSize({base_size}) != cond_indices({len(cond_indices)}) at batch {batch_i}. "
211
+ f"Another extension may also modify parsing."
212
+ )
213
+
214
+ sliced_x_out, sliced_cond_indices = gather_webui_conds(prompt, args, 0, len(sliced_batch_x_out))
215
+ if sliced_cond_indices:
216
+ sliced_batch_cond_indices.append(sliced_cond_indices)
217
+ sliced_batch_x_out.extend(sliced_x_out)
218
+
219
+ # добавляем uncond-часть всего батча (webui ожидает её в хвосте)
220
+ sliced_batch_x_out += list(uncond)
221
+ sliced_batch_x_out = torch.stack(sliced_batch_x_out, dim=0)
222
+
223
+ # вызвать оригинальный combine_denoised на нашем переупакованном батче
224
+ return original_function(sliced_batch_x_out, sliced_batch_cond_indices, text_uncond, cond_scale)
225
+
226
+
227
+ def cfg_rescale(cfg_cond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
228
+ """
229
+ CFG Rescale (как в SDXL/SD WebUI) — нормируем дисперсию cond к cond, но сохраняем среднее в смеси.
230
+ """
231
+ if getattr(global_state, "cfg_rescale", 0) == 0:
232
+ return cfg_cond
233
+
234
+ global_state.apply_and_clear_cfg_rescale_override()
235
+ cfg_cond_mean = cfg_cond.mean()
236
+ cfg_rescale_mean = (1 - global_state.cfg_rescale) * cfg_cond_mean + global_state.cfg_rescale * cond.mean()
237
+ cfg_rescale_factor = global_state.cfg_rescale * (cond.std() / (cfg_cond.std() + 1e-8) - 1) + 1
238
+ return cfg_rescale_mean + (cfg_cond - cfg_cond_mean) * cfg_rescale_factor
239
+
240
+
241
+ # -----------------------------
242
+ # Data passed around visitors
243
+ # -----------------------------
244
+
245
+ @dataclasses.dataclass
246
+ class CombineDenoiseArgs:
247
+ x_out: torch.Tensor
248
+ uncond: torch.Tensor
249
+ cond_indices: List[Tuple[int, float]] # (index_in_x_out, weight_from_webui)
250
+
251
+
252
+ # -----------------------------
253
+ # Size visitor for REAL webui indices
254
+ # -----------------------------
255
+
256
+ class BaseSizeVisitor:
257
+ """
258
+ Считает ТОЛЬКО базовые листья (conciliation is None), т.е. столько,
259
+ сколько webui реально выдаёт cond_indices.
260
+ """
261
+ def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> int:
262
+ return 1 if getattr(that, "conciliation", None) is None else 0
263
+
264
+ def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> int:
265
+ if not getattr(that, "children", None):
266
+ return 0
267
+ return sum(child.accept(self) for child in that.children)
268
+
269
+
270
+ # -----------------------------
271
+ # Visitors producing deltas
272
+ # -----------------------------
273
+
274
+ class CondDeltaVisitor:
275
+ def visit_leaf_prompt(
276
+ self,
277
+ that: neutral_prompt_parser.LeafPrompt,
278
+ args: CombineDenoiseArgs,
279
+ index: int,
280
+ ) -> torch.Tensor:
281
+ # индекс тратится ТОЛЬКО на базовых листьях; подстрахуемся от выхода за границу
282
+ if index >= len(args.cond_indices):
283
+ return torch.zeros_like(args.x_out[0])
284
+
285
+ cond_info = args.cond_indices[index]
286
+ if that.weight != cond_info[1] and getattr(global_state, "verbose", False):
287
+ console_warn(
288
+ f"""
289
+ An unexpected noise weight was encountered at prompt #{index}
290
+ Expected: {that.weight}, but got: {cond_info[1]}
291
+ This may happen if another extension monkey-patches combine_denoised.
292
+ """
293
+ )
294
+ # разница между cond и uncond (базовая CFG-дельта)
295
+ return args.x_out[cond_info[0]] - args.uncond
296
+
297
+ def visit_composite_prompt(
298
+ self,
299
+ that: neutral_prompt_parser.CompositePrompt,
300
+ args: CombineDenoiseArgs,
301
+ index: int,
302
+ ) -> torch.Tensor:
303
+ cond_delta = torch.zeros_like(args.x_out[0])
304
+
305
+ for child in that.children:
306
+ child_cond_delta = child.accept(self, args, index)
307
+
308
+ # композиционные стратегии, влияющие на сложение delta
309
+ cs = getattr(that, "conciliation", None)
310
+ if cs in (
311
+ neutral_prompt_parser.ConciliationStrategy.ALTERNATE,
312
+ neutral_prompt_parser.ConciliationStrategy.ALTERNATE1,
313
+ neutral_prompt_parser.ConciliationStrategy.ALTERNATE2,
314
+ neutral_prompt_parser.ConciliationStrategy.SEQUENCE,
315
+ neutral_prompt_parser.ConciliationStrategy.NESTED_SEQUENCE,
316
+ neutral_prompt_parser.ConciliationStrategy.TOP_LEVEL_SEQUENCE,
317
+ ):
318
+ cond_delta += child_cond_delta
319
+ elif cs == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
320
+ cond_delta += salient_blend(cond_delta, [(child_cond_delta, child.weight)])
321
+ else:
322
+ cond_delta += child.weight * child_cond_delta
323
+
324
+ # шаг по количеству БАЗОВЫХ листьев внутри ребёнка
325
+ index += child.accept(BaseSizeVisitor())
326
+
327
+ return cond_delta
328
+
329
+
330
+ class AuxCondDeltaVisitor:
331
+ def visit_leaf_prompt(
332
+ self,
333
+ that: neutral_prompt_parser.LeafPrompt,
334
+ args: CombineDenoiseArgs,
335
+ cond_delta: torch.Tensor,
336
+ index: int,
337
+ ) -> torch.Tensor:
338
+ # для одиночного листа нет доп. коррекции
339
+ return torch.zeros_like(args.x_out[0])
340
+
341
+ def visit_composite_prompt(
342
+ self,
343
+ that: neutral_prompt_parser.CompositePrompt,
344
+ args: CombineDenoiseArgs,
345
+ cond_delta: torch.Tensor,
346
+ index: int,
347
+ ) -> torch.Tensor:
348
+ aux_cond_delta = torch.zeros_like(args.x_out[0])
349
+ salient_cond_deltas: List[Tuple[torch.Tensor, float]] = []
350
+
351
+ for child in that.children:
352
+ if getattr(child, "conciliation", None) is not None:
353
+ # рекурсивно собираем child deltas
354
+ child_cond_delta = child.accept(CondDeltaVisitor(), args, index)
355
+ child_cond_delta += child.accept(self, args, child_cond_delta, index)
356
+
357
+ if child.conciliation == neutral_prompt_parser.ConciliationStrategy.PERPENDICULAR:
358
+ aux_cond_delta += child.weight * get_perpendicular_component(cond_delta, child_cond_delta)
359
+ elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SALIENCE_MASK:
360
+ salient_cond_deltas.append((child_cond_delta, child.weight))
361
+ elif child.conciliation == neutral_prompt_parser.ConciliationStrategy.SEMANTIC_GUIDANCE:
362
+ aux_cond_delta += child.weight * filter_abs_top_k(child_cond_delta, 0.05)
363
+ elif child.conciliation in (
364
+ neutral_prompt_parser.ConciliationStrategy.ALTERNATE,
365
+ neutral_prompt_parser.ConciliationStrategy.ALTERNATE1,
366
+ neutral_prompt_parser.ConciliationStrategy.ALTERNATE2,
367
+ neutral_prompt_parser.ConciliationStrategy.SEQUENCE,
368
+ neutral_prompt_parser.ConciliationStrategy.NESTED_SEQUENCE,
369
+ neutral_prompt_parser.ConciliationStrategy.TOP_LEVEL_SEQUENCE,
370
+ ):
371
+ aux_cond_delta += child_cond_delta
372
+
373
+ # шаг только по БАЗОВЫМ листьям
374
+ index += child.accept(BaseSizeVisitor())
375
+
376
+ aux_cond_delta += salient_blend(cond_delta, salient_cond_deltas)
377
+ return aux_cond_delta
378
+
379
+
380
+ # -----------------------------
381
+ # Slice builder for webui call
382
+ # -----------------------------
383
+
384
+ def gather_webui_conds(
385
+ prompt: neutral_prompt_parser.PromptExpr,
386
+ args: CombineDenoiseArgs,
387
+ index_in: int,
388
+ index_out: int,
389
+ ) -> Tuple[List[torch.Tensor], List[Tuple[int, float]]]:
390
+ """
391
+ Формируем список «условных срезов» (x_out) и соответствующие пары (index, weight),
392
+ которые поймёт оригинальный webui combine_denoised.
393
+ - Для базовых листьев берём реальные webui cond_indices (и тратим индекс)
394
+ - Для «вторичных» листьев и композитов синтезируем канал: uncond + (cond_delta + aux_delta)
395
+ """
396
+ sliced_x_out: List[torch.Tensor] = []
397
+ sliced_cond_indices: List[Tuple[int, float]] = []
398
+
399
+ # Лист
400
+ if isinstance(prompt, neutral_prompt_parser.LeafPrompt):
401
+ if getattr(prompt, "conciliation", None) is None:
402
+ # базовый лист — берём реальный cond канал из x_out по webui индексу
403
+ if index_in >= len(args.cond_indices):
404
+ return sliced_x_out, sliced_cond_indices
405
+ real_idx, _w = args.cond_indices[index_in]
406
+ child_x_out = args.x_out[real_idx]
407
+ else:
408
+ # вторичный лист — синтезируем «условный канал»
409
+ child_cond = prompt.accept(CondDeltaVisitor(), args, index_in)
410
+ child_cond += prompt.accept(AuxCondDeltaVisitor(), args, child_cond, index_in)
411
+ child_x_out = args.uncond + child_cond
412
+
413
+ index_offset = index_out + len(sliced_x_out)
414
+ sliced_x_out.append(child_x_out)
415
+ sliced_cond_indices.append((index_offset, prompt.weight))
416
+ return sliced_x_out, sliced_cond_indices
417
+
418
+ # Композит
419
+ if isinstance(prompt, neutral_prompt_parser.CompositePrompt):
420
+ for child in prompt.children:
421
+ if isinstance(child, neutral_prompt_parser.LeafPrompt) and getattr(child, "conciliation", None) is None:
422
+ # базовый лист — реальный индекс webui
423
+ if index_in >= len(args.cond_indices):
424
+ break
425
+ real_idx, _w = args.cond_indices[index_in]
426
+ child_x_out = args.x_out[real_idx]
427
+ else:
428
+ # всё остальное — синтезируем канал
429
+ child_cond = child.accept(CondDeltaVisitor(), args, index_in)
430
+ child_cond += child.accept(AuxCondDeltaVisitor(), args, child_cond, index_in)
431
+ child_x_out = args.uncond + child_cond
432
+
433
+ index_offset = index_out + len(sliced_x_out)
434
+ sliced_x_out.append(child_x_out)
435
+ sliced_cond_indices.append((index_offset, child.weight))
436
+
437
+ # критично: шагать по ЧИСЛУ БАЗОВЫХ листьев внутри ребёнка
438
+ index_in += child.accept(BaseSizeVisitor())
439
+
440
+ return sliced_x_out, sliced_cond_indices
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/external_code/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+
3
+
4
+ @contextlib.contextmanager
5
+ def fix_path():
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ extension_path = str(Path(__file__).parent.parent.parent)
10
+ added = False
11
+ if extension_path not in sys.path:
12
+ sys.path.insert(0, extension_path)
13
+ added = True
14
+
15
+ yield
16
+
17
+ if added:
18
+ sys.path.remove(extension_path)
19
+
20
+
21
+ with fix_path():
22
+ del fix_path, contextlib
23
+ from .api import *
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/external_code/api.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from lib_neutral_prompt import global_state
2
+
3
+
4
+ def override_cfg_rescale(cfg_rescale: float):
5
+ global_state.cfg_rescale_override = cfg_rescale
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/global_state.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from lib_neutral_prompt import neutral_prompt_parser
3
+
4
+
5
+ is_enabled: bool = False
6
+ prompt_exprs: List[neutral_prompt_parser.PromptExpr] = []
7
+ cfg_rescale: float = 0.0
8
+ verbose: bool = True
9
+ cfg_rescale_override: Optional[float] = None
10
+
11
+
12
+ def apply_and_clear_cfg_rescale_override():
13
+ global cfg_rescale, cfg_rescale_override
14
+ if cfg_rescale_override is not None:
15
+ cfg_rescale = cfg_rescale_override
16
+ cfg_rescale_override = None
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+
4
+ class ModuleHijacker:
5
+ def __init__(self, module):
6
+ self.__module = module
7
+ self.__original_functions = dict()
8
+
9
+ def hijack(self, attribute):
10
+ if attribute not in self.__original_functions:
11
+ self.__original_functions[attribute] = getattr(self.__module, attribute)
12
+
13
+ def decorator(function):
14
+ setattr(self.__module, attribute, functools.partial(function, original_function=self.__original_functions[attribute]))
15
+ return function
16
+
17
+ return decorator
18
+
19
+ def reset_module(self):
20
+ for attribute, original_function in self.__original_functions.items():
21
+ setattr(self.__module, attribute, original_function)
22
+
23
+ self.__original_functions.clear()
24
+
25
+ @staticmethod
26
+ def install_or_get(module, hijacker_attribute, on_uninstall=lambda _callback: None):
27
+ if not hasattr(module, hijacker_attribute):
28
+ module_hijacker = ModuleHijacker(module)
29
+ setattr(module, hijacker_attribute, module_hijacker)
30
+ on_uninstall(lambda: delattr(module, hijacker_attribute))
31
+ on_uninstall(module_hijacker.reset_module)
32
+ return module_hijacker
33
+ else:
34
+ return getattr(module, hijacker_attribute)
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/neutral_prompt_parser.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import dataclasses
3
+ import re
4
+ from enum import Enum
5
+ from typing import List, Tuple, Any, Optional
6
+
7
+ class PromptKeyword(Enum):
8
+ AND = 'AND'
9
+ AND_PERP = 'AND_PERP'
10
+ AND_SALT = 'AND_SALT'
11
+ AND_TOPK = 'AND_TOPK'
12
+ ALTERNATE = '|'
13
+ AND_RULE = '&'
14
+ GROUPED = '{'
15
+ SEQUENCE = '::'
16
+ ALTERNATE1 = 'alternate1'
17
+ ALTERNATE2 = 'alternate2'
18
+ TOP_LEVEL_SEQUENCE = 'top_level_sequence'
19
+ NESTED_SEQUENCE = 'nested_sequence'
20
+
21
+ prompt_keywords = [e.value for e in PromptKeyword]
22
+
23
+ class ConciliationStrategy(Enum):
24
+ PERPENDICULAR = PromptKeyword.AND_PERP.value
25
+ SALIENCE_MASK = PromptKeyword.AND_SALT.value
26
+ SEMANTIC_GUIDANCE = PromptKeyword.AND_TOPK.value
27
+ ALTERNATE = PromptKeyword.ALTERNATE.value
28
+ AND_RULE = PromptKeyword.AND_RULE.value
29
+ GROUPED = PromptKeyword.GROUPED.value
30
+ SEQUENCE = PromptKeyword.SEQUENCE.value
31
+ ALTERNATE1 = PromptKeyword.ALTERNATE1.value
32
+ ALTERNATE2 = PromptKeyword.ALTERNATE2.value
33
+ TOP_LEVEL_SEQUENCE = PromptKeyword.TOP_LEVEL_SEQUENCE.value
34
+ NESTED_SEQUENCE = PromptKeyword.NESTED_SEQUENCE.value
35
+
36
+ conciliation_strategies = [e.value for e in ConciliationStrategy]
37
+
38
+ @dataclasses.dataclass
39
+ class PromptExpr(abc.ABC):
40
+ weight: float
41
+ conciliation: Optional[ConciliationStrategy]
42
+
43
+ @abc.abstractmethod
44
+ def accept(self, visitor, *args, **kwargs) -> Any:
45
+ pass
46
+
47
+ @dataclasses.dataclass
48
+ class LeafPrompt(PromptExpr):
49
+ prompt: str
50
+
51
+ def accept(self, visitor, *args, **kwargs):
52
+ return visitor.visit_leaf_prompt(self, *args, **kwargs)
53
+
54
+ @dataclasses.dataclass
55
+ class CompositePrompt(PromptExpr):
56
+ children: List[PromptExpr]
57
+
58
+ def accept(self, visitor, *args, **kwargs):
59
+ return visitor.visit_composite_prompt(self, *args, **kwargs)
60
+
61
+ class FlatSizeVisitor:
62
+ def visit_leaf_prompt(self, that: LeafPrompt) -> int:
63
+ return 1
64
+
65
+ def visit_composite_prompt(self, that: CompositePrompt) -> int:
66
+ return sum(child.accept(self) for child in that.children) if that.children else 0
67
+
68
+ def parse_root(string: str) -> CompositePrompt:
69
+ tokens = tokenize(string)
70
+ prompts = parse_prompts(tokens)
71
+ return CompositePrompt(1.0, None, prompts)
72
+
73
+ def parse_prompts(tokens: List[str], *, nested: bool = False) -> List[PromptExpr]:
74
+ prompts = []
75
+ while tokens:
76
+ if nested and tokens[0] in [']', '}']:
77
+ break
78
+ prompts.append(parse_prompt(tokens, first=len(prompts) == 0, nested=nested))
79
+ return prompts
80
+
81
+ def parse_prompt(tokens: List[str], *, first: bool, nested: bool = False) -> PromptExpr:
82
+ if not first and tokens[0] in prompt_keywords:
83
+ prompt_type = tokens.pop(0)
84
+ else:
85
+ prompt_type = PromptKeyword.AND.value
86
+
87
+ tokens_copy = tokens.copy()
88
+ # Поддержка [prompt1|prompt2] и [prompt1:prompt2:step]
89
+ if tokens_copy and tokens_copy[0] == '[':
90
+ tokens_copy.pop(0)
91
+
92
+ # 🔧 ВАЖНО: дефолтное значение, иначе UnboundLocalError
93
+ conciliation = None
94
+
95
+ prompts = [parse_prompt(tokens_copy, first=True, nested=True)]
96
+ if tokens_copy and tokens_copy[0] == '|': # [a|b]
97
+ tokens_copy.pop(0)
98
+ prompts.append(parse_prompt(tokens_copy, first=False, nested=True))
99
+ conciliation = ConciliationStrategy.ALTERNATE
100
+ if tokens_copy and tokens_copy[0] == '!': # [a|b!]
101
+ tokens_copy.pop(0)
102
+ conciliation = ConciliationStrategy.ALTERNATE1
103
+
104
+ elif tokens_copy and tokens_copy[0] == ':': # [a:b(:step)]
105
+ tokens_copy.pop(0)
106
+ prompts.append(parse_prompt(tokens_copy, first=False, nested=True))
107
+ if tokens_copy and tokens_copy[0] == ':':
108
+ tokens_copy.pop(0)
109
+ if tokens_copy and is_float(tokens_copy[0]):
110
+ tokens_copy.pop(0) # step игнорим на этом уровне
111
+
112
+ if tokens_copy and tokens_copy[0] == ']':
113
+ tokens_copy.pop(0)
114
+
115
+ tokens[:] = tokens_copy
116
+ weight = parse_weight(tokens)
117
+ return CompositePrompt(weight, conciliation, prompts)
118
+
119
+
120
+ # Поддержка (prompt:weight)
121
+ if tokens_copy and tokens_copy[0] == '(':
122
+ tokens_copy.pop(0)
123
+ prompt_text, weight = parse_prompt_text(tokens_copy, nested=True)
124
+ if tokens_copy and tokens_copy[0] == ':':
125
+ tokens_copy.pop(0)
126
+ if tokens_copy and is_float(tokens_copy[0]):
127
+ weight = float(tokens_copy.pop(0))
128
+ if tokens_copy and tokens_copy[0] == ')':
129
+ tokens_copy.pop(0)
130
+ tokens[:] = tokens_copy
131
+ return LeafPrompt(weight, None, prompt_text)
132
+
133
+ # Поддержка {grouped}
134
+ if tokens_copy and tokens_copy[0] == '{':
135
+ tokens_copy.pop(0)
136
+ prompts = parse_prompts(tokens_copy, nested=True)
137
+ if tokens_copy and tokens_copy[0] == '}':
138
+ tokens_copy.pop(0)
139
+ tokens[:] = tokens_copy
140
+ weight = parse_weight(tokens)
141
+ return CompositePrompt(weight, ConciliationStrategy.GROUPED, prompts)
142
+
143
+ # Поддержка & и ::
144
+ if tokens_copy and tokens_copy[0] in ['&', '::']:
145
+ op = tokens_copy.pop(0)
146
+ prompts = parse_prompts(tokens_copy, nested=True)
147
+ tokens[:] = tokens_copy
148
+ weight = parse_weight(tokens)
149
+ conciliation = ConciliationStrategy.AND_RULE if op == '&' else ConciliationStrategy.SEQUENCE
150
+ return CompositePrompt(weight, conciliation, prompts)
151
+
152
+ # Поддержка top_level_sequence и nested_sequence
153
+ if tokens_copy and tokens_copy[0] == '::' and tokens_copy[-1] in ['!', ';']:
154
+ tokens_copy.pop(0)
155
+ prompts = parse_prompts(tokens_copy, nested=True)
156
+ if tokens_copy and tokens_copy[0] in ['!', ';']:
157
+ tokens_copy.pop(0)
158
+ tokens[:] = tokens_copy
159
+ weight = parse_weight(tokens)
160
+ return CompositePrompt(weight, ConciliationStrategy.NESTED_SEQUENCE, prompts)
161
+
162
+ if tokens_copy and '!!' in tokens_copy:
163
+ idx = tokens_copy.index('!!')
164
+ owner = parse_prompt(tokens_copy[:idx], first=True, nested=nested)
165
+ tokens_copy = tokens_copy[idx+1:]
166
+ prompts = [owner] + parse_prompts(tokens_copy, nested=True)
167
+ tokens[:] = tokens_copy
168
+ weight = parse_weight(tokens)
169
+ return CompositePrompt(weight, ConciliationStrategy.TOP_LEVEL_SEQUENCE, prompts)
170
+
171
+ prompt_text, weight = parse_prompt_text(tokens, nested=nested)
172
+ return LeafPrompt(weight, ConciliationStrategy(prompt_type) if prompt_type in conciliation_strategies else None, prompt_text)
173
+
174
+ def parse_prompt_text(tokens: List[str], *, nested: bool = False) -> Tuple[str, float]:
175
+ text = ''
176
+ depth = 0
177
+ weight = 1.0
178
+ while tokens:
179
+ if tokens[0] in [']', '}'] and depth == 0 and nested:
180
+ break
181
+ elif tokens[0] in ['[', '{', '(']:
182
+ depth += 1
183
+ elif tokens[0] in [')', ']']:
184
+ depth -= 1
185
+ elif tokens[0] == ':' and len(tokens) >= 2 and is_float(tokens[1].strip()):
186
+ if len(tokens) < 3 or tokens[2] in prompt_keywords or tokens[2] in [']', ')', '}'] and depth == 0:
187
+ tokens.pop(0)
188
+ weight = float(tokens.pop(0).strip())
189
+ break
190
+ elif tokens[0] in prompt_keywords and depth == 0:
191
+ break
192
+ text += tokens.pop(0)
193
+ return text, weight
194
+
195
+ def parse_weight(tokens: List[str]) -> float:
196
+ weight = 1.0
197
+ if len(tokens) >= 2 and tokens[0] == ':' and is_float(tokens[1]):
198
+ tokens.pop(0)
199
+ weight = float(tokens.pop(0))
200
+ return weight
201
+
202
+ def tokenize(s: str):
203
+ prompt_keywords_regex = '|'.join(r'\b' + re.escape(keyword) + r'\b' for keyword in prompt_keywords)
204
+ pattern = r'(\[|\]|\(|\)|:|\{|\}|&|\||' + prompt_keywords_regex + r')'
205
+ return [s for s in re.split(pattern, s) if s.strip()]
206
+
207
+ def is_float(string: str) -> bool:
208
+ try:
209
+ float(string)
210
+ return True
211
+ except ValueError:
212
+ return False
213
+
214
+ if __name__ == '__main__':
215
+ res = parse_root('''
216
+ hello
217
+ AND_PERP [
218
+ arst
219
+ AND defg : 2
220
+ AND_SALT [
221
+ very nested huh? what do you say :.0
222
+ ]
223
+ ]
224
+ ''')
225
+ pass
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/prompt_parser_hijack.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import List
3
+ import logging
4
+ import lark
5
+
6
+ from lib_neutral_prompt import hijacker, global_state, neutral_prompt_parser
7
+ from modules import script_callbacks, prompt_parser
8
+ from modules.prompt_parser import schedule_parser, resolve_tree
9
+
10
+ logging.basicConfig(level=logging.WARNING)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
14
+ module=prompt_parser,
15
+ hijacker_attribute='__neutral_prompt_hijacker',
16
+ on_uninstall=script_callbacks.on_script_unloaded,
17
+ )
18
+
19
+ def convert_to_prompt_expr(prompts: List[str], conds_list=None):
20
+ exprs: List[neutral_prompt_parser.PromptExpr] = []
21
+ if conds_list: # берём индексы/веса из оригинального парсера, если он их дал
22
+ for batch_conds in conds_list:
23
+ for idx, weight in batch_conds:
24
+ prompt_text = prompts[idx] if idx < len(prompts) else f"prompt_{idx}"
25
+ exprs.append(
26
+ neutral_prompt_parser.LeafPrompt(
27
+ weight=float(weight) if isinstance(weight, (int, float)) else 1.0,
28
+ conciliation=None,
29
+ prompt=prompt_text
30
+ )
31
+ )
32
+ return exprs
33
+
34
+ # fallback — разбираем сами деревом
35
+ for prompt in prompts:
36
+ try:
37
+ tree = schedule_parser.parse(prompt)
38
+ expr = tree_to_prompt_expr(tree)
39
+ exprs.append(expr)
40
+ except lark.exceptions.LarkError as e:
41
+ logger.warning("Failed to parse prompt '%s': %s", prompt, e)
42
+ exprs.append(
43
+ neutral_prompt_parser.LeafPrompt(
44
+ weight=1.0, conciliation=None, prompt=prompt
45
+ )
46
+ )
47
+ return exprs
48
+
49
+
50
+ def tree_to_prompt_expr(tree: lark.Tree | lark.Token) -> neutral_prompt_parser.PromptExpr:
51
+ """Преобразует lark.Tree в PromptExpr Neutral Prompt."""
52
+ # Листовой токен
53
+ if isinstance(tree, lark.Token):
54
+ return neutral_prompt_parser.LeafPrompt(weight=1.0, conciliation=None, prompt=str(tree))
55
+
56
+ data = getattr(tree, "data", None)
57
+
58
+ if data == "plain":
59
+ # plain всегда Token внутри грамматики
60
+ return neutral_prompt_parser.LeafPrompt(weight=1.0, conciliation=None, prompt=tree.children[0].value)
61
+
62
+ elif data in ("alternate", "alternate1", "alternate2"):
63
+ # фильтруем пробелы безопасно
64
+ children = [
65
+ tree_to_prompt_expr(child)
66
+ for child in tree.children
67
+ if not (isinstance(child, lark.Token) and child.type == "WHITESPACE")
68
+ ]
69
+ # Храним как один тип ALTERNATE — выбор сделает визитор при транспиляции
70
+ return neutral_prompt_parser.CompositePrompt(
71
+ weight=1.0,
72
+ conciliation=neutral_prompt_parser.ConciliationStrategy.ALTERNATE,
73
+ children=children
74
+ )
75
+
76
+ elif data == "grouped":
77
+ children = [
78
+ tree_to_prompt_expr(child)
79
+ for child in tree.children
80
+ if not (isinstance(child, lark.Token) and child.type == "WHITESPACE")
81
+ ]
82
+ return neutral_prompt_parser.CompositePrompt(
83
+ weight=1.0,
84
+ conciliation=neutral_prompt_parser.ConciliationStrategy.GROUPED,
85
+ children=children
86
+ )
87
+
88
+ elif data == "and_rule":
89
+ children = [
90
+ tree_to_prompt_expr(child)
91
+ for child in tree.children
92
+ if not (isinstance(child, lark.Token) and child.type == "WHITESPACE")
93
+ ]
94
+ return neutral_prompt_parser.CompositePrompt(
95
+ weight=1.0,
96
+ conciliation=neutral_prompt_parser.ConciliationStrategy.AND_RULE,
97
+ children=children
98
+ )
99
+
100
+ elif data == "sequence":
101
+ children = [
102
+ tree_to_prompt_expr(child)
103
+ for child in tree.children
104
+ if not (isinstance(child, lark.Token) and child.type == "WHITESPACE")
105
+ ]
106
+ return neutral_prompt_parser.CompositePrompt(
107
+ weight=1.0,
108
+ conciliation=neutral_prompt_parser.ConciliationStrategy.SEQUENCE,
109
+ children=children
110
+ )
111
+
112
+ elif data == "emphasized":
113
+ # безопасно читаем вес: второй элемент может быть Token/Tree/отсутствовать
114
+ prompt_expr = tree_to_prompt_expr(tree.children[0])
115
+ w = 1.1
116
+ if len(tree.children) > 1 and isinstance(tree.children[1], lark.Token) and tree.children[1].type == "NUMBER":
117
+ try:
118
+ w = float(tree.children[1])
119
+ except (ValueError, TypeError):
120
+ w = 1.0
121
+ prompt_expr.weight = w
122
+ return prompt_expr
123
+
124
+ elif data == "numbered":
125
+ quantity = int(tree.children[0])
126
+ distinct = (str(tree.children[1]) == "!") if len(tree.children) > 1 else False # корректная проверка
127
+ target = tree_to_prompt_expr(tree.children[-1])
128
+
129
+ # В текущей модели PromptExpr нет «мешка выбора», просто дублируем target.
130
+ # (Если нужно «distinct», это можно отразить внутри transpile в будущем.)
131
+ return neutral_prompt_parser.CompositePrompt(
132
+ weight=1.0,
133
+ conciliation=None,
134
+ children=[target] * quantity
135
+ )
136
+
137
+ # по умолчанию — разворачиваем дерево в строку
138
+ return neutral_prompt_parser.LeafPrompt(weight=1.0, conciliation=None, prompt=resolve_tree(tree))
139
+
140
+
141
+ @prompt_parser_hijacker.hijack('get_multicond_prompt_list')
142
+ def get_multicond_prompt_list_hijack(prompts, original_function):
143
+ if not global_state.is_enabled:
144
+ return original_function(prompts)
145
+
146
+ # Если это уже SdConditioning — разворачиваем в строки совместимые с webui
147
+ if isinstance(prompts, prompt_parser.SdConditioning):
148
+ webui_prompts = [prompt_parser.resolve_tree(p, keep_spacing=True) for p in prompts]
149
+ else:
150
+ webui_prompts = prompts
151
+
152
+ # Совместимость: оригинал может вернуть 2 или 3 элемента; нам нужен только первый (indexes/weights)
153
+ original_result = original_function(webui_prompts)
154
+ if isinstance(original_result, tuple) and len(original_result) >= 1:
155
+ conds_list = original_result[0]
156
+ global_state.prompt_exprs = convert_to_prompt_expr(webui_prompts, conds_list)
157
+ else:
158
+ global_state.prompt_exprs = convert_to_prompt_expr(webui_prompts)
159
+
160
+ transformed_prompts = transpile_exprs(global_state.prompt_exprs)
161
+ if isinstance(prompts, prompt_parser.SdConditioning):
162
+ transformed_prompts = prompt_parser.SdConditioning(transformed_prompts, copy_from=prompts)
163
+
164
+ return original_function(transformed_prompts)
165
+
166
+
167
+ def transpile_exprs(exprs: List[neutral_prompt_parser.PromptExpr]) -> List[str]:
168
+ return [expr.accept(WebuiPromptVisitor()) for expr in exprs]
169
+
170
+
171
+ class WebuiPromptVisitor:
172
+ # LeafPrompt -> "text:1.2" если weight != 1.0
173
+ def visit_leaf_prompt(self, that: neutral_prompt_parser.LeafPrompt) -> str:
174
+ return f"{that.prompt}:{that.weight}" if that.weight != 1.0 else that.prompt
175
+
176
+ def visit_composite_prompt(self, that: neutral_prompt_parser.CompositePrompt) -> str:
177
+ children_rendered = [child.accept(self) for child in that.children]
178
+
179
+ cs = that.conciliation
180
+ if cs == neutral_prompt_parser.ConciliationStrategy.ALTERNATE:
181
+ # используем стандартный синтаксис webui: [a|b|c]
182
+ return "[" + "|".join(children_rendered) + "]"
183
+
184
+ if cs == neutral_prompt_parser.ConciliationStrategy.SEQUENCE:
185
+ # owner :: d1, d2, d3
186
+ if not children_rendered:
187
+ return ""
188
+ owner = children_rendered[0]
189
+ tail = ", ".join(children_rendered[1:]) if len(children_rendered) > 1 else ""
190
+ return f"{owner}::{tail}" if tail else owner
191
+
192
+ if cs == neutral_prompt_parser.ConciliationStrategy.AND_RULE:
193
+ # a & b & c
194
+ return " & ".join(children_rendered)
195
+
196
+ if cs == neutral_prompt_parser.ConciliationStrategy.GROUPED:
197
+ # {a, b, c}
198
+ return "{" + ", ".join(children_rendered) + "}"
199
+
200
+ # неизвестная стратегия — вернём просто перечисление
201
+ return ", ".join(children_rendered)
202
+
203
+
204
+ # простой экспорт, если где-то вызывают prompt_parser_hijack.parse_prompts(...)
205
+ def parse_prompts(prompts: List[str]) -> List[str]:
206
+ exprs = convert_to_prompt_expr(prompts)
207
+ return transpile_exprs(exprs)
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/ui.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_neutral_prompt import global_state, neutral_prompt_parser
2
+ from modules import script_callbacks, shared
3
+ from typing import Dict, Tuple, List, Callable
4
+ import gradio as gr
5
+ import dataclasses
6
+
7
+ txt2img_prompt_textbox = None
8
+ img2img_prompt_textbox = None
9
+
10
+ prompt_types = {
11
+ 'Perpendicular': neutral_prompt_parser.PromptKeyword.AND_PERP.value,
12
+ 'Saliency-aware': neutral_prompt_parser.PromptKeyword.AND_SALT.value,
13
+ 'Semantic guidance top-k': neutral_prompt_parser.PromptKeyword.AND_TOPK.value,
14
+ 'Alternate': neutral_prompt_parser.PromptKeyword.ALTERNATE.value,
15
+ 'And Rule': neutral_prompt_parser.PromptKeyword.AND_RULE.value,
16
+ 'Grouped': neutral_prompt_parser.PromptKeyword.GROUPED.value,
17
+ 'Sequence': neutral_prompt_parser.PromptKeyword.SEQUENCE.value,
18
+ }
19
+ prompt_types_tooltip = '\n'.join([
20
+ 'AND - add all prompt features equally (webui builtin)',
21
+ 'Perpendicular - reduce the impact of contradicting prompt features',
22
+ 'Saliency-aware - strongest prompt features win',
23
+ 'Semantic guidance top-k - small targeted changes',
24
+ 'Alternate - select one option from a list (e.g., [cat|dog])',
25
+ 'And Rule - combine prompts with & (e.g., cat & dog)',
26
+ 'Grouped - group prompts together (e.g., {cat,dog})',
27
+ 'Sequence - sequential prompt application (e.g., cat::dog)',
28
+ ])
29
+
30
+ @dataclasses.dataclass
31
+ class AccordionInterface:
32
+ get_elem_id: Callable
33
+
34
+ def __post_init__(self):
35
+ self.is_rendered = False
36
+ self.cfg_rescale = gr.Slider(label='CFG rescale', minimum=0, maximum=1, value=0)
37
+ self.neutral_prompt = gr.Textbox(label='Neutral prompt', show_label=False, lines=3, placeholder='Neutral prompt (click on apply below to append this to the positive prompt textbox)')
38
+ self.neutral_cond_scale = gr.Slider(label='Prompt weight', minimum=-3, maximum=3, value=1)
39
+ self.aux_prompt_type = gr.Dropdown(
40
+ label='Prompt type',
41
+ choices=list(prompt_types.keys()),
42
+ value=next(iter(prompt_types.keys())),
43
+ tooltip=prompt_types_tooltip,
44
+ elem_id=self.get_elem_id('formatter_prompt_type')
45
+ )
46
+ self.append_to_prompt_button = gr.Button(value='Apply to prompt')
47
+
48
+ def arrange_components(self, is_img2img: bool):
49
+ if self.is_rendered:
50
+ return
51
+ with gr.Accordion(label='Neutral Prompt', open=False):
52
+ self.cfg_rescale.render()
53
+ with gr.Accordion(label='Prompt formatter', open=False):
54
+ self.neutral_prompt.render()
55
+ self.neutral_cond_scale.render()
56
+ self.aux_prompt_type.render()
57
+ self.append_to_prompt_button.render()
58
+
59
+ def connect_events(self, is_img2img: bool):
60
+ if self.is_rendered:
61
+ return
62
+ prompt_textbox = img2img_prompt_textbox if is_img2img else txt2img_prompt_textbox
63
+ self.append_to_prompt_button.click(
64
+ fn=lambda init_prompt, prompt, scale, prompt_type: (
65
+ f'{init_prompt}\n{prompt_types[prompt_type]} {prompt} :{scale}', ''
66
+ ),
67
+ inputs=[prompt_textbox, self.neutral_prompt, self.neutral_cond_scale, self.aux_prompt_type],
68
+ outputs=[prompt_textbox, self.neutral_prompt]
69
+ )
70
+
71
+ def set_rendered(self, value: bool = True):
72
+ self.is_rendered = value
73
+
74
+ def get_components(self) -> Tuple[gr.components.Component]:
75
+ return (self.cfg_rescale,)
76
+
77
+ def get_infotext_fields(self) -> Tuple[Tuple[gr.components.Component, str]]:
78
+ return tuple(zip(self.get_components(), ('CFG Rescale phi',)))
79
+
80
+ def get_paste_field_names(self) -> List[str]:
81
+ return ['CFG Rescale phi']
82
+
83
+ def get_extra_generation_params(self, args: Dict) -> Dict:
84
+ return {'CFG Rescale phi': args['cfg_rescale']}
85
+
86
+ def unpack_processing_args(self, cfg_rescale: float) -> Dict:
87
+ return {'cfg_rescale': cfg_rescale}
88
+
89
+ def on_ui_settings():
90
+ section = ('neutral_prompt', 'Neutral Prompt')
91
+ shared.opts.add_option('neutral_prompt_enabled', shared.OptionInfo(True, 'Enable neutral-prompt extension', section=section))
92
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
93
+ shared.opts.add_option('neutral_prompt_verbose', shared.OptionInfo(False, 'Enable verbose debugging for neutral-prompt', section=section))
94
+ shared.opts.onchange('neutral_prompt_verbose', update_verbose)
95
+
96
+ script_callbacks.on_ui_settings(on_ui_settings)
97
+
98
+ def update_verbose():
99
+ global_state.verbose = shared.opts.data.get('neutral_prompt_verbose', False)
100
+
101
+ def on_after_component(component, **_kwargs):
102
+ if getattr(component, 'elem_id', None) == 'txt2img_prompt':
103
+ global txt2img_prompt_textbox
104
+ txt2img_prompt_textbox = component
105
+ if getattr(component, 'elem_id', None) == 'img2img_prompt':
106
+ global img2img_prompt_textbox
107
+ img2img_prompt_textbox = component
108
+
109
+ script_callbacks.on_after_component(on_after_component)
z-sd-webui-neutral-prompt-workYEAH4/lib_neutral_prompt/xyz_grid.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from types import ModuleType
3
+ from typing import Optional
4
+ from modules import scripts
5
+ from lib_neutral_prompt import global_state
6
+
7
+
8
+ def patch():
9
+ xyz_module = find_xyz_module()
10
+ if xyz_module is None:
11
+ print("[sd-webui-neutral-prompt]", "xyz_grid.py not found.", file=sys.stderr)
12
+ return
13
+
14
+ xyz_module.axis_options.extend([
15
+ xyz_module.AxisOption("[Neutral Prompt] CFG Rescale", int_or_float, apply_cfg_rescale()),
16
+ ])
17
+
18
+
19
+ class XyzFloat(float):
20
+ is_xyz: bool = True
21
+
22
+
23
+ def apply_cfg_rescale():
24
+ def callback(_p, v, _vs):
25
+ global_state.cfg_rescale = XyzFloat(v)
26
+
27
+ return callback
28
+
29
+
30
+ def int_or_float(string):
31
+ try:
32
+ return int(string)
33
+ except ValueError:
34
+ return float(string)
35
+
36
+
37
+ def find_xyz_module() -> Optional[ModuleType]:
38
+ for data in scripts.scripts_data:
39
+ if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
40
+ return data.module
41
+
42
+ return None
z-sd-webui-neutral-prompt-workYEAH4/scripts/__pycache__/neutral_prompt.cpython-310.pyc ADDED
Binary file (3.57 kB). View file
 
z-sd-webui-neutral-prompt-workYEAH4/scripts/neutral_prompt.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
2
+ from modules import scripts, processing, shared
3
+ from typing import Dict
4
+ import functools
5
+
6
+
7
+ class NeutralPromptScript(scripts.Script):
8
+ def __init__(self):
9
+ self.accordion_interface = None
10
+ self._is_img2img = False
11
+
12
+ @property
13
+ def is_img2img(self):
14
+ return self._is_img2img
15
+
16
+ @is_img2img.setter
17
+ def is_img2img(self, is_img2img):
18
+ self._is_img2img = is_img2img
19
+ if self.accordion_interface is None:
20
+ self.accordion_interface = ui.AccordionInterface(self.elem_id)
21
+
22
+ def title(self) -> str:
23
+ return "Neutral Prompt"
24
+
25
+ def show(self, is_img2img: bool):
26
+ return scripts.AlwaysVisible
27
+
28
+ def ui(self, is_img2img: bool):
29
+ self.hijack_composable_lora(is_img2img)
30
+
31
+ self.accordion_interface.arrange_components(is_img2img)
32
+ self.accordion_interface.connect_events(is_img2img)
33
+ self.infotext_fields = self.accordion_interface.get_infotext_fields()
34
+ self.paste_field_names = self.accordion_interface.get_paste_field_names()
35
+ self.accordion_interface.set_rendered()
36
+ return self.accordion_interface.get_components()
37
+
38
+ def process(self, p: processing.StableDiffusionProcessing, *args):
39
+ args = self.accordion_interface.unpack_processing_args(*args)
40
+
41
+ self.update_global_state(args)
42
+ if global_state.is_enabled:
43
+ p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))
44
+
45
+ def update_global_state(self, args: Dict):
46
+ if shared.state.job_no == 0:
47
+ global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)
48
+
49
+ for k, v in args.items():
50
+ try:
51
+ getattr(global_state, k)
52
+ except AttributeError:
53
+ continue
54
+
55
+ if getattr(getattr(global_state, k), 'is_xyz', False):
56
+ xyz_attr = getattr(global_state, k)
57
+ xyz_attr.is_xyz = False
58
+ args[k] = xyz_attr
59
+ continue
60
+
61
+ if shared.state.job_no > 0:
62
+ continue
63
+
64
+ setattr(global_state, k, v)
65
+
66
+ def hijack_composable_lora(self, is_img2img):
67
+ if self.accordion_interface.is_rendered:
68
+ return
69
+
70
+ lora_script = None
71
+ script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img
72
+
73
+ for script in script_runner.alwayson_scripts:
74
+ if script.title().lower() == "composable lora":
75
+ lora_script = script
76
+ break
77
+
78
+ if lora_script is not None:
79
+ lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)
80
+
81
+
82
+ def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
83
+ if not global_state.is_enabled:
84
+ return original_function(p, *args, **kwargs)
85
+
86
+ exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
87
+ all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs)
88
+ res = original_function(p, *args, **kwargs)
89
+ # restore original prompts
90
+ p.all_prompts = all_prompts
91
+ return res
92
+
93
+
94
+ xyz_grid.patch()
z-sd-webui-neutral-prompt-workYEAH4/test/perp_parser/__init__.py ADDED
File without changes
z-sd-webui-neutral-prompt-workYEAH4/test/perp_parser/basic_test.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
+ class TestPromptParser(unittest.TestCase):
9
+ def setUp(self):
10
+ self.simple_prompt = neutral_prompt_parser.parse_root("hello :1.0")
11
+ self.and_prompt = neutral_prompt_parser.parse_root("hello AND goodbye :2.0")
12
+ self.and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP goodbye :2.0")
13
+ self.and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT goodbye :2.0")
14
+ self.nested_and_perp_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0 AND_PERP welcome :3.0]")
15
+ self.nested_and_salt_prompt = neutral_prompt_parser.parse_root("hello :1.0 AND_SALT [goodbye :2.0 AND_SALT welcome :3.0]")
16
+ self.invalid_weight = neutral_prompt_parser.parse_root("hello :not_a_float")
17
+
18
+ def test_simple_prompt_child_count(self):
19
+ self.assertEqual(len(self.simple_prompt.children), 1)
20
+
21
+ def test_simple_prompt_child_weight(self):
22
+ self.assertEqual(self.simple_prompt.children[0].weight, 1.0)
23
+
24
+ def test_simple_prompt_child_prompt(self):
25
+ self.assertEqual(self.simple_prompt.children[0].prompt, "hello ")
26
+
27
+ def test_square_weight_prompt(self):
28
+ prompt = "a [b c d e : f g h :1.5]"
29
+ parsed = neutral_prompt_parser.parse_root(prompt)
30
+ self.assertEqual(parsed.children[0].prompt, prompt)
31
+
32
+ composed_prompt = f"{prompt} AND_PERP other prompt"
33
+ parsed = neutral_prompt_parser.parse_root(composed_prompt)
34
+ self.assertEqual(parsed.children[0].prompt, prompt)
35
+
36
+ def test_and_prompt_child_count(self):
37
+ self.assertEqual(len(self.and_prompt.children), 2)
38
+
39
+ def test_and_prompt_child_weights_and_prompts(self):
40
+ self.assertEqual(self.and_prompt.children[0].weight, 1.0)
41
+ self.assertEqual(self.and_prompt.children[0].prompt, "hello ")
42
+ self.assertEqual(self.and_prompt.children[1].weight, 2.0)
43
+ self.assertEqual(self.and_prompt.children[1].prompt, " goodbye ")
44
+
45
+ def test_and_perp_prompt_child_count(self):
46
+ self.assertEqual(len(self.and_perp_prompt.children), 2)
47
+
48
+ def test_and_perp_prompt_child_types(self):
49
+ self.assertIsInstance(self.and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
50
+ self.assertIsInstance(self.and_perp_prompt.children[1], neutral_prompt_parser.LeafPrompt)
51
+
52
+ def test_and_perp_prompt_nested_child(self):
53
+ nested_child = self.and_perp_prompt.children[1]
54
+ self.assertEqual(nested_child.weight, 2.0)
55
+ self.assertEqual(nested_child.prompt.strip(), "goodbye")
56
+
57
+ def test_nested_and_perp_prompt_child_count(self):
58
+ self.assertEqual(len(self.nested_and_perp_prompt.children), 2)
59
+
60
+ def test_nested_and_perp_prompt_child_types(self):
61
+ self.assertIsInstance(self.nested_and_perp_prompt.children[0], neutral_prompt_parser.LeafPrompt)
62
+ self.assertIsInstance(self.nested_and_perp_prompt.children[1], neutral_prompt_parser.CompositePrompt)
63
+
64
+ def test_nested_and_perp_prompt_nested_child_types(self):
65
+ nested_child = self.nested_and_perp_prompt.children[1].children[0]
66
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
67
+ nested_child = self.nested_and_perp_prompt.children[1].children[1]
68
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
69
+
70
+ def test_nested_and_perp_prompt_nested_child(self):
71
+ nested_child = self.nested_and_perp_prompt.children[1].children[1]
72
+ self.assertEqual(nested_child.weight, 3.0)
73
+ self.assertEqual(nested_child.prompt.strip(), "welcome")
74
+
75
+ def test_invalid_weight_child_count(self):
76
+ self.assertEqual(len(self.invalid_weight.children), 1)
77
+
78
+ def test_invalid_weight_child_weight(self):
79
+ self.assertEqual(self.invalid_weight.children[0].weight, 1.0)
80
+
81
+ def test_invalid_weight_child_prompt(self):
82
+ self.assertEqual(self.invalid_weight.children[0].prompt, "hello :not_a_float")
83
+
84
+ def test_and_salt_prompt_child_count(self):
85
+ self.assertEqual(len(self.and_salt_prompt.children), 2)
86
+
87
+ def test_and_salt_prompt_child_types(self):
88
+ self.assertIsInstance(self.and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
89
+ self.assertIsInstance(self.and_salt_prompt.children[1], neutral_prompt_parser.LeafPrompt)
90
+
91
+ def test_and_salt_prompt_nested_child(self):
92
+ nested_child = self.and_salt_prompt.children[1]
93
+ self.assertEqual(nested_child.weight, 2.0)
94
+ self.assertEqual(nested_child.prompt.strip(), "goodbye")
95
+
96
+ def test_nested_and_salt_prompt_child_count(self):
97
+ self.assertEqual(len(self.nested_and_salt_prompt.children), 2)
98
+
99
+ def test_nested_and_salt_prompt_child_types(self):
100
+ self.assertIsInstance(self.nested_and_salt_prompt.children[0], neutral_prompt_parser.LeafPrompt)
101
+ self.assertIsInstance(self.nested_and_salt_prompt.children[1], neutral_prompt_parser.CompositePrompt)
102
+
103
+ def test_nested_and_salt_prompt_nested_child_types(self):
104
+ nested_child = self.nested_and_salt_prompt.children[1].children[0]
105
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
106
+ nested_child = self.nested_and_salt_prompt.children[1].children[1]
107
+ self.assertIsInstance(nested_child, neutral_prompt_parser.LeafPrompt)
108
+
109
+ def test_nested_and_salt_prompt_nested_child(self):
110
+ nested_child = self.nested_and_salt_prompt.children[1].children[1]
111
+ self.assertEqual(nested_child.weight, 3.0)
112
+ self.assertEqual(nested_child.prompt.strip(), "welcome")
113
+
114
+ def test_start_with_prompt_editing(self):
115
+ prompt = "[(long shot:1.2):0.1] detail.."
116
+ res = neutral_prompt_parser.parse_root(prompt)
117
+ self.assertEqual(res.children[0].weight, 1.0)
118
+ self.assertEqual(res.children[0].prompt, prompt)
119
+
120
+
121
+ if __name__ == '__main__':
122
+ unittest.main()
z-sd-webui-neutral-prompt-workYEAH4/test/perp_parser/malicious_test.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import pathlib
3
+ import sys
4
+ sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
5
+ from lib_neutral_prompt import neutral_prompt_parser
6
+
7
+
8
+ class TestMaliciousPromptParser(unittest.TestCase):
9
+ def setUp(self):
10
+ self.parser = neutral_prompt_parser
11
+
12
+ def test_empty(self):
13
+ result = self.parser.parse_root("")
14
+ self.assertEqual(result.children[0].prompt, "")
15
+ self.assertEqual(result.children[0].weight, 1.0)
16
+
17
+ def test_zero_weight(self):
18
+ result = self.parser.parse_root("hello :0.0")
19
+ self.assertEqual(result.children[0].weight, 0.0)
20
+
21
+ def test_mixed_positive_and_negative_weights(self):
22
+ result = self.parser.parse_root("hello :1.0 AND goodbye :-2.0")
23
+ self.assertEqual(result.children[0].weight, 1.0)
24
+ self.assertEqual(result.children[1].weight, -2.0)
25
+
26
+ def test_debalanced_square_brackets(self):
27
+ prompt = "a [ b " * 100
28
+ result = self.parser.parse_root(prompt)
29
+ self.assertEqual(result.children[0].prompt, prompt)
30
+
31
+ prompt = "a ] b " * 100
32
+ result = self.parser.parse_root(prompt)
33
+ self.assertEqual(result.children[0].prompt, prompt)
34
+
35
+ repeats = 10
36
+ prompt = "a [ [ b AND c ] " * repeats
37
+ result = self.parser.parse_root(prompt)
38
+ self.assertEqual([x.prompt for x in result.children], ["a [[ b ", *[" c ] a [[ b "] * (repeats - 1), " c ]"])
39
+
40
+ repeats = 10
41
+ prompt = "a [ b AND c ] ] " * repeats
42
+ result = self.parser.parse_root(prompt)
43
+ self.assertEqual([x.prompt for x in result.children], ["a [ b ", *[" c ]] a [ b "] * (repeats - 1), " c ]]"])
44
+
45
+ def test_erroneous_syntax(self):
46
+ result = self.parser.parse_root("hello :1.0 AND_PERP [goodbye :2.0")
47
+ self.assertEqual(result.children[0].weight, 1.0)
48
+ self.assertEqual(result.children[1].prompt, "[goodbye ")
49
+ self.assertEqual(result.children[1].weight, 2.0)
50
+
51
+ result = self.parser.parse_root("hello :1.0 AND_PERP goodbye :2.0]")
52
+ self.assertEqual(result.children[0].weight, 1.0)
53
+ self.assertEqual(result.children[1].prompt, " goodbye ")
54
+
55
+ result = self.parser.parse_root("hello :1.0 AND_PERP goodbye] :2.0")
56
+ self.assertEqual(result.children[1].prompt, " goodbye]")
57
+ self.assertEqual(result.children[1].weight, 2.0)
58
+
59
+ result = self.parser.parse_root("hello :1.0 AND_PERP a [ goodbye :2.0")
60
+ self.assertEqual(result.children[1].weight, 2.0)
61
+ self.assertEqual(result.children[1].prompt, " a [ goodbye ")
62
+
63
+ result = self.parser.parse_root("hello :1.0 AND_PERP AND goodbye :2.0")
64
+ self.assertEqual(result.children[0].weight, 1.0)
65
+ self.assertEqual(result.children[2].prompt, " goodbye ")
66
+
67
+ def test_huge_number_of_prompt_parts(self):
68
+ result = self.parser.parse_root(" AND ".join(f"hello{i} :{i}" for i in range(10**4)))
69
+ self.assertEqual(len(result.children), 10**4)
70
+
71
+ def test_prompt_ending_with_weight(self):
72
+ result = self.parser.parse_root("hello :1.0 AND :2.0")
73
+ self.assertEqual(result.children[0].weight, 1.0)
74
+ self.assertEqual(result.children[1].prompt, "")
75
+ self.assertEqual(result.children[1].weight, 2.0)
76
+
77
+ def test_huge_input_string(self):
78
+ big_string = "hello :1.0 AND " * 10**4
79
+ result = self.parser.parse_root(big_string)
80
+ self.assertEqual(len(result.children), 10**4 + 1)
81
+
82
+ def test_deeply_nested_prompt(self):
83
+ deeply_nested_prompt = "hello :1.0" + " AND_PERP [goodbye :2.0" * 100 + "]" * 100
84
+ result = self.parser.parse_root(deeply_nested_prompt)
85
+ self.assertIsInstance(result.children[1], neutral_prompt_parser.CompositePrompt)
86
+
87
+ def test_complex_nested_prompts(self):
88
+ complex_prompt = "hello :1.0 AND goodbye :2.0 AND_PERP [welcome :3.0 AND farewell :4.0 AND_PERP greetings:5.0]"
89
+ result = self.parser.parse_root(complex_prompt)
90
+ self.assertEqual(result.children[0].weight, 1.0)
91
+ self.assertEqual(result.children[1].weight, 2.0)
92
+ self.assertEqual(result.children[2].children[0].weight, 3.0)
93
+ self.assertEqual(result.children[2].children[1].weight, 4.0)
94
+ self.assertEqual(result.children[2].children[2].weight, 5.0)
95
+
96
+ def test_string_with_random_characters(self):
97
+ random_chars = "ASDFGHJKL:@#$/.,|}{><~`12[3]456AND_PERP7890"
98
+ try:
99
+ self.parser.parse_root(random_chars)
100
+ except Exception:
101
+ self.fail("parse_root couldn't handle a string with random characters.")
102
+
103
+ def test_string_with_unexpected_symbols(self):
104
+ unexpected_symbols = "hello :1.0 AND $%^&*()goodbye :2.0"
105
+ try:
106
+ self.parser.parse_root(unexpected_symbols)
107
+ except Exception:
108
+ self.fail("parse_root couldn't handle a string with unexpected symbols.")
109
+
110
+ def test_string_with_unconventional_structure(self):
111
+ unconventional_structure = "hello :1.0 AND_PERP :2.0 AND [goodbye]"
112
+ try:
113
+ self.parser.parse_root(unconventional_structure)
114
+ except Exception:
115
+ self.fail("parse_root couldn't handle a string with unconventional structure.")
116
+
117
+ def test_string_with_mixed_alphabets_and_numbers(self):
118
+ mixed_alphabets_and_numbers = "123hello :1.0 AND goodbye456 :2.0"
119
+ try:
120
+ self.parser.parse_root(mixed_alphabets_and_numbers)
121
+ except Exception:
122
+ self.fail("parse_root couldn't handle a string with mixed alphabets and numbers.")
123
+
124
+ def test_string_with_nested_brackets(self):
125
+ nested_brackets = "hello :1.0 AND [goodbye :2.0 AND [[welcome :3.0]]]"
126
+ try:
127
+ self.parser.parse_root(nested_brackets)
128
+ except Exception:
129
+ self.fail("parse_root couldn't handle a string with nested brackets.")
130
+
131
+ def test_unmatched_opening_braces(self):
132
+ unmatched_opening_braces = "hello [[[[[[[[[ :1.0 AND_PERP goodbye :2.0"
133
+ try:
134
+ self.parser.parse_root(unmatched_opening_braces)
135
+ except Exception:
136
+ self.fail("parse_root couldn't handle a string with unmatched opening braces.")
137
+
138
+ def test_unmatched_closing_braces(self):
139
+ unmatched_closing_braces = "hello :1.0 AND_PERP goodbye ]]]]]]]]] :2.0"
140
+ try:
141
+ self.parser.parse_root(unmatched_closing_braces)
142
+ except Exception:
143
+ self.fail("parse_root couldn't handle a string with unmatched closing braces.")
144
+
145
+ def test_repeating_colons(self):
146
+ repeating_colons = "hello ::::::: :1.0 AND_PERP goodbye :::: :2.0"
147
+ try:
148
+ self.parser.parse_root(repeating_colons)
149
+ except Exception:
150
+ self.fail("parse_root couldn't handle a string with repeating colons.")
151
+
152
+ def test_excessive_whitespace(self):
153
+ excessive_whitespace = "hello :1.0 AND_PERP goodbye :2.0"
154
+ try:
155
+ self.parser.parse_root(excessive_whitespace)
156
+ except Exception:
157
+ self.fail("parse_root couldn't handle a string with excessive whitespace.")
158
+
159
+ def test_repeating_AND_keyword(self):
160
+ repeating_AND_keyword = "hello :1.0 AND AND AND AND AND goodbye :2.0"
161
+ try:
162
+ self.parser.parse_root(repeating_AND_keyword)
163
+ except Exception:
164
+ self.fail("parse_root couldn't handle a string with repeating AND keyword.")
165
+
166
+ def test_repeating_AND_PERP_keyword(self):
167
+ repeating_AND_PERP_keyword = "hello :1.0 AND_PERP AND_PERP AND_PERP AND_PERP goodbye :2.0"
168
+ try:
169
+ self.parser.parse_root(repeating_AND_PERP_keyword)
170
+ except Exception:
171
+ self.fail("parse_root couldn't handle a string with repeating AND_PERP keyword.")
172
+
173
+ def test_square_weight_prompt(self):
174
+ prompt = "AND_PERP [weighted] you thought it was the end"
175
+ try:
176
+ self.parser.parse_root(prompt)
177
+ except Exception:
178
+ self.fail("parse_root couldn't handle a string starting with a square-weighted sub-prompt.")
179
+
180
+
181
+ if __name__ == '__main__':
182
+ unittest.main()