dikdimon commited on
Commit
33b9772
·
verified ·
1 Parent(s): bee3485

Upload z-prompt-fusion-extension using SD-Hub

Browse files
Files changed (27) hide show
  1. z-prompt-fusion-extension/.gitignore +3 -0
  2. z-prompt-fusion-extension/LICENSE +21 -0
  3. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/ast_nodes.cpython-310.pyc +0 -0
  4. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/empty_cond.cpython-310.pyc +0 -0
  5. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/geometries.cpython-310.pyc +0 -0
  6. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/global_state.cpython-310.pyc +0 -0
  7. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/hijacker.cpython-310.pyc +0 -0
  8. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_functions.cpython-310.pyc +0 -0
  9. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_tensor.cpython-310.pyc +0 -0
  10. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/prompt_parser.cpython-310.pyc +0 -0
  11. z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/t_scaler.cpython-310.pyc +0 -0
  12. z-prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py +307 -0
  13. z-prompt-fusion-extension/lib_prompt_fusion/empty_cond.py +19 -0
  14. z-prompt-fusion-extension/lib_prompt_fusion/geometries.py +33 -0
  15. z-prompt-fusion-extension/lib_prompt_fusion/global_state.py +28 -0
  16. z-prompt-fusion-extension/lib_prompt_fusion/hijacker.py +34 -0
  17. z-prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py +87 -0
  18. z-prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py +249 -0
  19. z-prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py +378 -0
  20. z-prompt-fusion-extension/lib_prompt_fusion/t_scaler.py +38 -0
  21. z-prompt-fusion-extension/metadata.ini +2 -0
  22. z-prompt-fusion-extension/readme.md +95 -0
  23. z-prompt-fusion-extension/requirements.txt +1 -0
  24. z-prompt-fusion-extension/scripts/__pycache__/promptlang.cpython-310.pyc +0 -0
  25. z-prompt-fusion-extension/scripts/promptlang.py +354 -0
  26. z-prompt-fusion-extension/test/parser_tests.py +104 -0
  27. z-prompt-fusion-extension/test/run_all.py +7 -0
z-prompt-fusion-extension/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /venv/
2
+ /.idea/
3
+ __pycache__/
z-prompt-fusion-extension/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2003
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/ast_nodes.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/empty_cond.cpython-310.pyc ADDED
Binary file (791 Bytes). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/geometries.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/global_state.cpython-310.pyc ADDED
Binary file (1.23 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/hijacker.cpython-310.pyc ADDED
Binary file (1.96 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_functions.cpython-310.pyc ADDED
Binary file (3.06 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_tensor.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/prompt_parser.cpython-310.pyc ADDED
Binary file (9.9 kB). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/t_scaler.cpython-310.pyc ADDED
Binary file (873 Bytes). View file
 
z-prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from lib_prompt_fusion import interpolation_functions
3
+ from lib_prompt_fusion.t_scaler import scale_t
4
+ from lib_prompt_fusion import interpolation_tensor
5
+
6
+
7
+ class ListExpression:
8
+ def __init__(self, expressions):
9
+ self.__expressions = expressions
10
+
11
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
12
+ if not self.__expressions:
13
+ return
14
+
15
+ def expr_extend_tensor(expr):
16
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
17
+
18
+ expr_extend_tensor(self.__expressions[0])
19
+ for expression in self.__expressions[1:]:
20
+ tensor_builder.append(' ')
21
+ expr_extend_tensor(expression)
22
+
23
+
24
+ class InterpolationExpression:
25
+ @staticmethod
26
+ def create(exprs, steps, function_name):
27
+ if function_name == "mean":
28
+ return AverageExpression(exprs, steps)
29
+
30
+ max_len = min(len(exprs), len(steps))
31
+ exprs = exprs[:max_len]
32
+ steps = steps[:max_len]
33
+
34
+ return InterpolationExpression(exprs, steps, function_name)
35
+
36
+ def __init__(self, expressions, steps, function_name=None):
37
+ assert len(expressions) >= 2
38
+ assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
39
+ self.__expressions = expressions
40
+ self.__steps = steps
41
+ self.__function_name = function_name if function_name is not None else 'linear'
42
+
43
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
44
+ def tensor_updater(expr):
45
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
46
+
47
+ tensor_builder.extrude(
48
+ [tensor_updater(expr) for expr in self.__expressions],
49
+ self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
50
+
51
+ def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
52
+ steps = list(self.__steps)
53
+ if steps[0] is None:
54
+ steps[0] = LiftExpression(str(steps_range[0] - 1))
55
+ if steps[-1] is None:
56
+ steps[-1] = LiftExpression(str(steps_range[1] - 1))
57
+
58
+ for i, step in enumerate(steps):
59
+ if step is None:
60
+ continue
61
+
62
+ step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)
63
+
64
+ if use_old_scheduling and 0 < step < 1:
65
+ step *= total_steps
66
+ elif not use_old_scheduling and isinstance(step, float):
67
+ step = (step - int(is_hires)) * total_steps
68
+ else:
69
+ step += 1
70
+
71
+ steps[i] = int(step)
72
+
73
+ i = 1
74
+ while i < len(steps):
75
+ none_len = 0
76
+ while steps[i + none_len] is None:
77
+ none_len += 1
78
+
79
+ min_step, max_step = steps[i - 1], steps[i + none_len]
80
+
81
+ for j in range(none_len):
82
+ steps[i + j] = min_step + (max_step - min_step) * (j + 1) / (none_len + 1)
83
+
84
+ i += 1 + none_len
85
+
86
+ interpolation_function = {
87
+ 'linear': interpolation_functions.compute_linear,
88
+ 'bezier': interpolation_functions.compute_bezier,
89
+ 'catmull': interpolation_functions.compute_catmull,
90
+ }[self.__function_name]
91
+
92
+ def steps_scale_t(conds, params: interpolation_tensor.InterpolationParams):
93
+ scaled_t = (params.t * total_steps - steps[0]) / max(1, steps[-1] - steps[0])
94
+ scaled_t = scale_t(scaled_t, steps)
95
+
96
+ new_params = interpolation_tensor.InterpolationParams(scaled_t, *params[1:])
97
+ return interpolation_function(conds, new_params)
98
+
99
+ return steps_scale_t
100
+
101
+
102
+ class AverageExpression:
103
+ def __init__(self, expressions, weights):
104
+ if len(expressions) < len(weights):
105
+ raise ValueError
106
+
107
+ self.__expressions = expressions
108
+ self.__weights = weights
109
+
110
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
111
+ def tensor_updater(expr):
112
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
113
+
114
+ tensor_builder.extrude(
115
+ [tensor_updater(expr) for expr in self.__expressions],
116
+ self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))
117
+
118
+ def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
119
+ weights = [
120
+ _eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
121
+ for weight in self.__weights
122
+ ]
123
+ explicit_weights = [weight for weight in weights if weight is not None]
124
+ weights = [
125
+ weight / sum(explicit_weights) * len(explicit_weights) / len(self.__expressions)
126
+ if weight is not None
127
+ else 1 / len(self.__expressions)
128
+ for weight in weights
129
+ ]
130
+ weights.extend(1 / len(self.__expressions) for _ in range(len(self.__expressions) - len(weights)))
131
+
132
+ def interpolation_function(conds, _params):
133
+ total = None
134
+ for cond, weight in zip(conds, weights):
135
+ cond *= weight
136
+ if total is None:
137
+ total = cond
138
+ else:
139
+ total += cond
140
+
141
+ return total
142
+
143
+ return interpolation_function
144
+
145
+
146
+ class AlternationExpression:
147
+ def __init__(self, expressions, speed):
148
+ self.__expressions = expressions
149
+ self.__speed = speed
150
+
151
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
152
+ if self.__speed is None:
153
+ speed = None
154
+ else:
155
+ speed = _eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling)
156
+
157
+ if speed is None:
158
+ tensor_builder.append('[')
159
+ for expr_i, expr in enumerate(self.__expressions):
160
+ if expr_i >= 1:
161
+ tensor_builder.append('|')
162
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
163
+ tensor_builder.append(']')
164
+ return
165
+
166
+ def tensor_updater(expr):
167
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
168
+
169
+ exprs = self.__expressions + [self.__expressions[0]]
170
+
171
+ tensor_builder.extrude(
172
+ [tensor_updater(expr) for expr in exprs],
173
+ self.get_interpolation_function(speed, exprs, steps_range, total_steps))
174
+
175
+ def get_interpolation_function(self, speed, exprs, steps_range, total_steps):
176
+ def compute_wrap(control_points, params: interpolation_tensor.InterpolationParams):
177
+ wrapped_t = math.fmod((params.t * total_steps - steps_range[0]) / (len(exprs) - 1) * speed, 1.0)
178
+ if wrapped_t < 0:
179
+ wrapped_t = wrapped_t + 1
180
+ new_params = interpolation_tensor.InterpolationParams(wrapped_t, *params[1:])
181
+ return interpolation_functions.compute_linear(control_points, new_params)
182
+
183
+ return compute_wrap
184
+
185
+
186
+ class EditingExpression:
187
+ def __init__(self, expressions, step):
188
+ assert 1 <= len(expressions) <= 2
189
+ self.__expressions = expressions
190
+ self.__step = step
191
+
192
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
193
+ if self.__step is None:
194
+ tensor_builder.append('[')
195
+ for expr_i, expr in enumerate(self.__expressions):
196
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
197
+ tensor_builder.append(':')
198
+ tensor_builder.append(']')
199
+ return
200
+
201
+ step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
202
+ step_int = step
203
+ if use_old_scheduling and 0 < step < 1:
204
+ step_int *= total_steps
205
+ elif not use_old_scheduling and isinstance(step, float):
206
+ step_int = (step_int - int(is_hires)) * total_steps
207
+ else:
208
+ step_int += 1
209
+
210
+ step_int = int(step_int)
211
+
212
+ tensor_builder.append('[')
213
+ for expr_i, expr in enumerate(self.__expressions):
214
+ expr_steps_range = (steps_range[0], step_int) if expr_i == 0 and len(self.__expressions) >= 2 else (step_int, steps_range[1])
215
+ expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
216
+ tensor_builder.append(':')
217
+
218
+ tensor_builder.append(f'{step}]')
219
+
220
+
221
+ class WeightedExpression:
222
+ def __init__(self, nested, weight=None, positive=True):
223
+ self.__nested = nested
224
+ if not positive:
225
+ assert weight is None
226
+
227
+ self.__weight = weight
228
+ self.__positive = positive
229
+
230
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
231
+ open_bracket, close_bracket = ('(', ')') if self.__positive else ('[', ']')
232
+ tensor_builder.append(open_bracket)
233
+ self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
234
+
235
+ if self.__weight is not None:
236
+ tensor_builder.append(':')
237
+ self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
238
+
239
+ tensor_builder.append(close_bracket)
240
+
241
+
242
+ class WeightInterpolationExpression:
243
+ def __init__(self, nested, weight_begin, weight_end):
244
+ self.__nested = nested
245
+ self.__weight_begin = weight_begin if weight_begin is not None else LiftExpression(str(1.))
246
+ self.__weight_end = weight_end if weight_end is not None else LiftExpression(str(1.))
247
+
248
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
249
+ steps_range_size = steps_range[1] - steps_range[0]
250
+
251
+ weight_begin = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
252
+ weight_end = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)
253
+
254
+ for i in range(steps_range_size):
255
+ step = i + steps_range[0]
256
+
257
+ weight = weight_begin + (weight_end - weight_begin) * (i / max(steps_range_size - 1, 1))
258
+ weight_step_expr = WeightedExpression(self.__nested, LiftExpression(str(weight)))
259
+ if step > steps_range[0]:
260
+ weight_step_expr = EditingExpression([weight_step_expr], LiftExpression(str(step - 1)))
261
+ if step + 1 < steps_range[1]:
262
+ weight_step_expr = EditingExpression([weight_step_expr, ListExpression([])], LiftExpression(str(step)))
263
+
264
+ weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
265
+
266
+
267
+ class DeclarationExpression:
268
+ def __init__(self, symbol, parameters, value, target):
269
+ self.__symbol = symbol
270
+ self.__value = value
271
+ self.__target = target
272
+ self.__parameters = parameters
273
+
274
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
275
+ updated_context = dict(context)
276
+ updated_context[self.__symbol] = (self.__value, self.__parameters)
277
+ self.__target.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
278
+
279
+
280
+ class SubstitutionExpression:
281
+ def __init__(self, symbol, arguments):
282
+ self.__symbol = symbol
283
+ self.__arguments = arguments
284
+
285
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
286
+ updated_context = dict(context)
287
+ nested, parameters = context[self.__symbol]
288
+ for argument, parameter in zip(self.__arguments, parameters):
289
+ updated_context[parameter] = argument, []
290
+ nested.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
291
+
292
+
293
+ class LiftExpression:
294
+ def __init__(self, value):
295
+ self.__value = value
296
+
297
+ def extend_tensor(self, tensor_builder, *_args, **_kwargs):
298
+ tensor_builder.append(self.__value)
299
+
300
+
301
+ def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
302
+ mock_database = ['']
303
+ expression.extend_tensor(interpolation_tensor.InterpolationTensorBuilder(prompt_database=mock_database), steps_range, total_steps, context, is_hires, use_old_scheduling)
304
+ try:
305
+ return int(mock_database[0])
306
+ except ValueError:
307
+ return float(mock_database[0])
z-prompt-fusion-extension/lib_prompt_fusion/empty_cond.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion import interpolation_tensor
2
+
3
+
4
+ _empty_cond = None
5
+
6
+
7
+ def get():
8
+ return _empty_cond
9
+
10
+
11
+ def init(model):
12
+ global _empty_cond
13
+ cond = model.get_learned_conditioning([''])
14
+ if isinstance(cond, dict):
15
+ cond = interpolation_tensor.DictCondWrapper({k: v[0] for k, v in cond.items()})
16
+ else:
17
+ cond = interpolation_tensor.TensorCondWrapper(cond[0])
18
+
19
+ _empty_cond = cond
z-prompt-fusion-extension/lib_prompt_fusion/geometries.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from lib_prompt_fusion import interpolation_tensor
4
+
5
+
6
+ def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
7
+ p0, p1 = control_points
8
+ p0_norm = torch.linalg.norm(p0)
9
+ p1_norm = torch.linalg.norm(p1)
10
+
11
+ similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
12
+ similarity = min(1., max(-1., float(similarity)))
13
+ if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
14
+ return linear_geometry(control_points, params)
15
+
16
+ angle = math.acos(float(similarity)) / 2
17
+
18
+ slerp_t = angle * (2 * params.t - 1)
19
+ slerp_t = math.tan(slerp_t) / math.tan(angle)
20
+ slerp_t = (slerp_t + 1) / 2
21
+
22
+ normalized_p1 = p1 / p1_norm * p0_norm
23
+ slerp_p = p0 + (normalized_p1 - p0) * slerp_t
24
+ slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)
25
+
26
+ lerp_p = linear_geometry(control_points, params)
27
+ return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
28
+
29
+
30
+ def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
31
+ p0, p1 = control_points
32
+ res = p0 + (p1 - p0) * params.t
33
+ return res
z-prompt-fusion-extension/lib_prompt_fusion/global_state.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from modules import shared, prompt_parser
3
+ from lib_prompt_fusion import empty_cond
4
+
5
+
6
+ old_webui_is_negative: bool = False
7
+ negative_schedules: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
8
+ negative_schedules_hires: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
9
+
10
+
11
+ def get_origin_cond_at(step: int, is_hires: bool = False):
12
+ fallback_schedules = negative_schedules_hires if is_hires else negative_schedules
13
+ if not fallback_schedules or not shared.opts.data.get('prompt_fusion_slerp_negative_origin', False):
14
+ return empty_cond.get()
15
+
16
+ for schedule in fallback_schedules:
17
+ if schedule.end_at_step >= step:
18
+ return schedule.cond
19
+
20
+ return empty_cond.get()
21
+
22
+
23
+ def get_slerp_scale():
24
+ return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
25
+
26
+
27
+ def get_slerp_epsilon():
28
+ return shared.opts.data.get('prompt_fusion_slerp_epsilon', 0.0001)
z-prompt-fusion-extension/lib_prompt_fusion/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModuleHijacker:
2
+ def __init__(self, module):
3
+ self.__module = module
4
+ self.__original_functions = dict()
5
+
6
+ def hijack(self, attribute):
7
+ if attribute not in self.__original_functions:
8
+ self.__original_functions[attribute] = getattr(self.__module, attribute)
9
+
10
+ def decorator(function):
11
+ def wrapper(*args, **kwargs):
12
+ return function(*args, **kwargs, original_function=self.__original_functions[attribute])
13
+
14
+ setattr(self.__module, attribute, wrapper)
15
+ return function
16
+
17
+ return decorator
18
+
19
+ def reset_module(self):
20
+ for attribute, original_function in self.__original_functions.items():
21
+ setattr(self.__module, attribute, original_function)
22
+
23
+ self.__original_functions.clear()
24
+
25
+ @staticmethod
26
+ def install_or_get(module, hijacker_attribute, register_uninstall=lambda _callback: None):
27
+ if not hasattr(module, hijacker_attribute):
28
+ module_hijacker = ModuleHijacker(module)
29
+ setattr(module, hijacker_attribute, module_hijacker)
30
+ register_uninstall(lambda: delattr(module, hijacker_attribute))
31
+ register_uninstall(module_hijacker.reset_module)
32
+ return module_hijacker
33
+ else:
34
+ return getattr(module, hijacker_attribute)
z-prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from lib_prompt_fusion import interpolation_tensor, geometries
4
+
5
+
6
+ def compute_linear(control_points, params: interpolation_tensor.InterpolationParams):
7
+ if len(control_points) <= 2:
8
+ return geometries.slerp_geometry(control_points, params)
9
+ else:
10
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
11
+ cp0 = control_points[target_curve]
12
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
13
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
14
+ return geometries.slerp_geometry([cp0, cp1], new_params)
15
+
16
+
17
+ def compute_bezier(control_points, params: interpolation_tensor.InterpolationParams):
18
+ def compute_casteljau(ps, size):
19
+ for i in reversed(range(1, size)):
20
+ for j in range(i):
21
+ ps[j] = geometries.slerp_geometry([ps[j], ps[j+1]], params)
22
+
23
+ return ps[0]
24
+
25
+ if len(control_points) == 1:
26
+ return control_points[0]
27
+ elif len(control_points) == 2:
28
+ return geometries.slerp_geometry(control_points, params)
29
+ copied_control_points = copy.deepcopy(control_points)
30
+ return compute_casteljau(copied_control_points, len(copied_control_points))
31
+
32
+
33
+ def compute_catmull(control_points, params: interpolation_tensor.InterpolationParams):
34
+ if len(control_points) <= 2:
35
+ return compute_linear(control_points, params)
36
+ else:
37
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
38
+ g0 = control_points[target_curve - 1] if target_curve > 0 else control_points[0] * 2 - control_points[1]
39
+ cp0 = control_points[target_curve]
40
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
41
+ g1 = control_points[target_curve + 2] if target_curve + 2 < len(control_points) else cp1 * 2 - cp0
42
+ ip0 = cp0 + (cp1 - g0)/6
43
+ ip1 = cp1 + (cp0 - g1)/6
44
+
45
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
46
+ return compute_bezier([cp0, ip0, ip1, cp1], new_params)
47
+
48
+
49
+ if __name__ == '__main__':
50
+ import turtle as tr
51
+ import torch
52
+ size = 60
53
+ turtle_tool = tr.Turtle()
54
+ turtle_tool.speed(10)
55
+ turtle_tool.up()
56
+
57
+ points = torch.Tensor([[-2., -2.], [2., 2.]])
58
+ origin = torch.Tensor([1.5, 1.6])
59
+
60
+ def sample(slerp_scale, color):
61
+ for i in range(size):
62
+ t = i / size
63
+ params = interpolation_tensor.InterpolationParams(t, i, size, slerp_scale, 0.0001)
64
+ point = origin + compute_linear(points - origin, params)
65
+ try:
66
+ turtle_tool.goto(tuple(float(p) * 100. for p in point))
67
+ turtle_tool.dot(5, color)
68
+ print(point)
69
+ except ValueError:
70
+ pass
71
+
72
+ sample(0, "black")
73
+ sample(1, "green")
74
+ sample(2, "blue")
75
+ sample(-1, "purple")
76
+ sample(-2, "orange")
77
+
78
+ for point in points:
79
+ turtle_tool.goto(tuple(float(p) * 100. for p in point))
80
+ turtle_tool.dot(5, "red")
81
+
82
+ turtle_tool.goto(tuple(float(p) * 100. for p in origin))
83
+ turtle_tool.dot(10, "red")
84
+
85
+ turtle_tool.goto(100000, 100000)
86
+ turtle_tool.dot()
87
+ tr.done()
z-prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ from modules import prompt_parser
4
+ from typing import NamedTuple, Union
5
+
6
+
7
+ class InterpolationParams(NamedTuple):
8
+ t: float
9
+ step: int
10
+ total_steps: int
11
+ slerp_scale: float
12
+ slerp_epsilon: float
13
+
14
+
15
+ class InterpolationTensor:
16
+ def __init__(self, sub_tensors, interpolation_function):
17
+ self.__sub_tensors = sub_tensors
18
+ self.__interpolation_function = interpolation_function
19
+
20
+ def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
21
+ cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
22
+ if params.slerp_scale != 0:
23
+ cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
24
+ return cond
25
+
26
+ def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
27
+ if self.__interpolation_function is None:
28
+ return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)
29
+
30
+ control_points = [
31
+ sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
32
+ for sub_tensor in self.__sub_tensors
33
+ ]
34
+
35
+ CondWrapper, control_points_values = conds_to_cp_values(control_points)
36
+ return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)
37
+
38
+ def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
39
+ schedule = None
40
+ for schedule in self.__sub_tensors:
41
+ if schedule.end_at_step >= step:
42
+ break
43
+
44
+ res = schedule.cond.extend_like(origin_cond, empty_cond)
45
+ if slerp_scale != 0:
46
+ res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
47
+ return res
48
+
49
+
50
+ def conds_to_cp_values(conds):
51
+ CondWrapper = type(conds[0])
52
+ cp_values = [
53
+ cond.to_cp_values()
54
+ for cond in conds
55
+ ]
56
+ return CondWrapper, [
57
+ [v[i] for v in cp_values]
58
+ for i in range(len(cp_values[0]))
59
+ ]
60
+
61
+
62
+ class InterpolationTensorBuilder:
63
+ def __init__(self, tensor=None, prompt_database=None, interpolation_functions=None):
64
+ self.__indices_tensor = tensor if tensor is not None else 0
65
+ self.__prompt_database = prompt_database if prompt_database is not None else ['']
66
+ self.__interpolation_functions = interpolation_functions if interpolation_functions is not None else []
67
+
68
+ def append(self, suffix):
69
+ for i in range(len(self.__prompt_database)):
70
+ self.__prompt_database[i] += suffix
71
+
72
+ def extrude(self, tensor_updaters, interpolation_function):
73
+ extruded_indices_tensor = []
74
+ extruded_prompt_database = []
75
+ extruded_interpolation_functions = []
76
+
77
+ for update_tensor in tensor_updaters:
78
+ nested_tensor_builder = InterpolationTensorBuilder(
79
+ self.__indices_tensor,
80
+ self.__prompt_database[:],
81
+ interpolation_functions=[])
82
+
83
+ update_tensor(nested_tensor_builder)
84
+
85
+ extruded_indices_tensor.append(InterpolationTensorBuilder.__offset_tensor(
86
+ tensor=nested_tensor_builder.__indices_tensor,
87
+ offset=len(extruded_prompt_database)))
88
+ extruded_prompt_database.extend(nested_tensor_builder.__prompt_database)
89
+ extruded_interpolation_functions.append(nested_tensor_builder.__interpolation_functions)
90
+
91
+ self.__indices_tensor = extruded_indices_tensor
92
+ self.__prompt_database[:] = extruded_prompt_database
93
+ self.__interpolation_functions.insert(0, (interpolation_function, extruded_interpolation_functions))
94
+
95
+ def get_prompt_database(self):
96
+ return self.__prompt_database
97
+
98
+ @staticmethod
99
+ def __offset_tensor(tensor, offset):
100
+ try:
101
+ return tensor + offset
102
+
103
+ except TypeError:
104
+ return [InterpolationTensorBuilder.__offset_tensor(e, offset) for e in tensor]
105
+
106
+ def build(self, conds, empty_cond):
107
+ max_cond_size = self.__max_cond_size(conds)
108
+ conds = self.__resize_uniformly(conds, max_cond_size, empty_cond)
109
+ return InterpolationTensorBuilder.__build_conditionings_tensor(self.__indices_tensor, self.__interpolation_functions, conds)
110
+
111
+ @staticmethod
112
+ def __build_conditionings_tensor(tensor, int_funcs, conds):
113
+ if type(tensor) is int:
114
+ return InterpolationTensor(conds[tensor], None)
115
+ else:
116
+ int_func, nested_int_funcs = int_funcs[0]
117
+ return InterpolationTensor(
118
+ [
119
+ InterpolationTensorBuilder.__build_conditionings_tensor(sub_tensor, nested_int_funcs + int_funcs[1:], conds)
120
+ for sub_tensor, nested_int_funcs in zip(tensor, nested_int_funcs)
121
+ ],
122
+ int_func,
123
+ )
124
+
125
+ def __resize_uniformly(self, conds, max_cond_size: int, empty_cond):
126
+ return [
127
+ [
128
+ prompt_parser.ScheduledPromptConditioning(
129
+ cond=schedule.cond.resize_schedule(max_cond_size, empty_cond),
130
+ end_at_step=schedule.end_at_step
131
+ )
132
+ for schedule in schedules
133
+ ]
134
+ for schedules in conds
135
+ ]
136
+
137
+ @staticmethod
138
+ def __max_cond_size(conds):
139
+ return max(schedule.cond.size(0)
140
+ for schedules in conds
141
+ for schedule in schedules)
142
+
143
+
144
+ @dataclasses.dataclass
145
+ class DictCondWrapper:
146
+ original_cond: dict
147
+
148
+ @staticmethod
149
+ def from_cp_values(cp_values):
150
+ return DictCondWrapper({
151
+ k: v
152
+ for k, v in zip(('crossattn', 'vector'), cp_values)
153
+ })
154
+
155
+ def size(self, *args, **kwargs):
156
+ return self.original_cond['crossattn'].size(*args, **kwargs)
157
+
158
+ def extend_like(self, that, empty):
159
+ missing_size = max(0, that.size(0) - self.size(0)) // 77
160
+ extended = DictCondWrapper(self.original_cond.copy())
161
+ extended.original_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty.original_cond['crossattn']] * missing_size)
162
+ return extended
163
+
164
+ def resize_schedule(self, target_size, empty_cond):
165
+ cond_missing_size = (target_size - self.size(0)) // 77
166
+ if cond_missing_size <= 0:
167
+ return self
168
+
169
+ resized_cond = self.original_cond.copy()
170
+ resized_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty_cond.original_cond['crossattn']] * cond_missing_size)
171
+ return DictCondWrapper(resized_cond)
172
+
173
+ def to_cp_values(self):
174
+ return list(self.original_cond.values())
175
+
176
+ def to(self, dtype: Union[dict, torch.dtype]):
177
+ if not isinstance(dtype, dict):
178
+ dtype = {
179
+ k: dtype
180
+ for k in self.original_cond.keys()
181
+ }
182
+ return DictCondWrapper({
183
+ k: v.to(dtype=dtype[k])
184
+ for k, v in self.original_cond.items()
185
+ })
186
+
187
+ @property
188
+ def dtype(self):
189
+ return {
190
+ k: v.dtype
191
+ for k, v in self.original_cond.items()
192
+ }
193
+
194
+ def __sub__(self, that):
195
+ return DictCondWrapper({
196
+ k: v - that.original_cond[k]
197
+ for k, v in self.original_cond.items()
198
+ })
199
+
200
+ def __add__(self, that):
201
+ return DictCondWrapper({
202
+ k: v + that.original_cond[k]
203
+ for k, v in self.original_cond.items()
204
+ })
205
+
206
+ def __eq__(self, that):
207
+ return all((self.original_cond[k] == that.original_cond[k]).all() for k in self.original_cond.keys())
208
+
209
+
210
+ @dataclasses.dataclass
211
+ class TensorCondWrapper:
212
+ original_cond: torch.Tensor
213
+
214
+ @staticmethod
215
+ def from_cp_values(cp_values):
216
+ return TensorCondWrapper(next(iter(cp_values)))
217
+
218
+ def size(self, *args, **kwargs):
219
+ return self.original_cond.size(*args, **kwargs)
220
+
221
+ def extend_like(self, that, empty):
222
+ missing_size = max(0, that.size(0) - self.original_cond.size(0)) // 77
223
+ return TensorCondWrapper(torch.concatenate([self.original_cond] + [empty.original_cond] * missing_size))
224
+
225
+ def resize_schedule(self, target_size, empty_cond):
226
+ cond_missing_size = (target_size - self.original_cond.size(0)) // 77
227
+ if cond_missing_size <= 0:
228
+ return self
229
+
230
+ return TensorCondWrapper(torch.concatenate([self.original_cond] + [empty_cond.original_cond] * cond_missing_size))
231
+
232
+ def to_cp_values(self):
233
+ return [self.original_cond]
234
+
235
+ def to(self, dtype: torch.dtype):
236
+ return TensorCondWrapper(self.original_cond.to(dtype=dtype))
237
+
238
+ @property
239
+ def dtype(self):
240
+ return self.original_cond.dtype
241
+
242
+ def __sub__(self, that):
243
+ return TensorCondWrapper(self.original_cond - that.original_cond)
244
+
245
+ def __add__(self, that):
246
+ return TensorCondWrapper(self.original_cond + that.original_cond)
247
+
248
+ def __eq__(self, that):
249
+ return (self.original_cond == that.original_cond).all()
z-prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib_prompt_fusion.ast_nodes as ast
2
+ from collections import namedtuple
3
+ import re
4
+
5
+
6
+ ParseResult = namedtuple('ParseResult', ['prompt', 'expr'])
7
+
8
+
9
+ def parse_prompt(prompt):
10
+ prompt = prompt.lstrip()
11
+ prompt, list_expr = parse_list_expression(prompt, set())
12
+ return list_expr
13
+
14
+
15
+ def parse_list_expression(prompt, stoppers):
16
+ exprs = []
17
+ try:
18
+ while True:
19
+ prompt, expr = parse_expression(prompt, stoppers)
20
+ exprs.append(expr)
21
+ except ValueError:
22
+ return ParseResult(prompt=prompt, expr=ast.ListExpression(exprs))
23
+
24
+
25
+ def parse_expression(prompt, stoppers):
26
+ for parse in _parsers():
27
+ try:
28
+ return parse(prompt, stoppers)
29
+ except ValueError:
30
+ pass
31
+
32
+ raise ValueError
33
+
34
+
35
+ def _parsers():
36
+ return (
37
+ parse_text,
38
+ parse_declaration,
39
+ parse_substitution,
40
+ parse_positive_attention,
41
+ parse_negative_attention,
42
+ parse_editing,
43
+ parse_alternation,
44
+ parse_interpolation,
45
+ parse_unrestricted_text,
46
+ )
47
+
48
+
49
+ def parse_text(prompt, stoppers):
50
+ return parse_unrestricted_text(prompt, set_concat(stoppers, {'[', '(', '$'}))
51
+
52
+
53
+ def parse_unrestricted_text(prompt, stoppers):
54
+ escaped_stoppers = ''.join(re.escape(stopper) for stopper in stoppers)
55
+ regex = rf'(?:[^{escaped_stoppers}\\\s]|\$(?![a-zA-Z_])|\\.)+'
56
+ prompt, expr = parse_token(prompt, whitespace_tail_regex(regex, stoppers))
57
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(expr))
58
+
59
+
60
+ def parse_substitution(prompt, stoppers):
61
+ prompt, symbol = parse_symbol(prompt, stoppers)
62
+ prompt, arguments = parse_arguments(prompt, stoppers)
63
+ return ParseResult(prompt=prompt, expr=ast.SubstitutionExpression(symbol, arguments))
64
+
65
+
66
+ def parse_arguments(prompt, stoppers):
67
+ try:
68
+ prompt, _ = parse_open_paren(prompt, stoppers)
69
+ prompt, arguments = parse_inner_arguments(prompt, stoppers)
70
+ prompt, _ = parse_close_paren(prompt, stoppers)
71
+ except ValueError:
72
+ arguments = []
73
+ return ParseResult(prompt=prompt, expr=arguments)
74
+
75
+
76
+ def parse_inner_arguments(prompt, stoppers):
77
+ arguments = []
78
+ try:
79
+ while True:
80
+ prompt, arg = parse_list_expression(prompt, {',', ')'})
81
+ arguments.append(arg)
82
+ prompt, _ = parse_comma(prompt, stoppers)
83
+ except ValueError:
84
+ pass
85
+
86
+ return ParseResult(prompt=prompt, expr=arguments)
87
+
88
+
89
+ def parse_declaration(prompt, stoppers):
90
+ prompt, symbol = parse_symbol(prompt, stoppers)
91
+ prompt, parameters = parse_parameters(prompt, stoppers)
92
+ prompt, _ = parse_equals(prompt, stoppers)
93
+ prompt, value = parse_list_expression(prompt, set_concat(stoppers, '\n'))
94
+ prompt, _ = parse_newline(prompt, stoppers)
95
+ prompt, expr = parse_list_expression(prompt, stoppers)
96
+ return ParseResult(prompt=prompt, expr=ast.DeclarationExpression(symbol, parameters, value, expr))
97
+
98
+
99
+ def parse_parameters(prompt, stoppers):
100
+ try:
101
+ prompt, _ = parse_open_paren(prompt, stoppers)
102
+ prompt, parameters = parse_inner_parameters(prompt, stoppers)
103
+ prompt, _ = parse_close_paren(prompt, stoppers)
104
+ except ValueError:
105
+ parameters = []
106
+ return ParseResult(prompt=prompt, expr=parameters)
107
+
108
+
109
+ def parse_inner_parameters(prompt, stoppers):
110
+ parameters = []
111
+ try:
112
+ while True:
113
+ prompt, param = parse_symbol(prompt, stoppers)
114
+ parameters.append(param)
115
+ prompt, _ = parse_comma(prompt, stoppers)
116
+ except ValueError:
117
+ pass
118
+
119
+ return ParseResult(prompt=prompt, expr=parameters)
120
+
121
+
122
+ def parse_interpolation(prompt, stoppers):
123
+ prompt, _ = parse_open_square(prompt, stoppers)
124
+ prompt, exprs = parse_interpolation_exprs(prompt, stoppers)
125
+ prompt, steps = parse_interpolation_steps(prompt, stoppers)
126
+ prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
127
+ prompt, _ = parse_close_square(prompt, stoppers)
128
+ return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))
129
+
130
+
131
+ def parse_interpolation_exprs(prompt, stoppers):
132
+ exprs = []
133
+
134
+ try:
135
+ while True:
136
+ prompt_tmp, expr = parse_list_expression(prompt, {':', ']'})
137
+ if parse_interpolation_function_name(prompt_tmp, stoppers).expr is not None:
138
+ raise ValueError
139
+
140
+ prompt, _ = parse_colon(prompt_tmp, stoppers)
141
+ exprs.append(expr)
142
+ except ValueError:
143
+ pass
144
+
145
+ return ParseResult(prompt=prompt, expr=exprs)
146
+
147
+
148
+ def parse_interpolation_function_name(prompt, stoppers):
149
+ try:
150
+ prompt, _ = parse_colon(prompt, stoppers)
151
+ function_names = ('linear', 'catmull', 'bezier', 'mean')
152
+ return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
153
+ except ValueError:
154
+ return ParseResult(prompt=prompt, expr=None)
155
+
156
+
157
+ def parse_interpolation_steps(prompt, stoppers):
158
+ steps = []
159
+
160
+ try:
161
+ while True:
162
+ prompt, step = parse_interpolation_step(prompt, stoppers)
163
+ steps.append(step)
164
+ prompt, _ = parse_comma(prompt, stoppers)
165
+ except ValueError:
166
+ pass
167
+
168
+ return ParseResult(prompt=prompt, expr=steps)
169
+
170
+
171
+ def parse_interpolation_step(prompt, stoppers):
172
+ try:
173
+ return parse_step(prompt, stoppers)
174
+ except ValueError:
175
+ pass
176
+
177
+ if prompt[0] in {',', ':', ']'}:
178
+ return ParseResult(prompt=prompt, expr=None)
179
+
180
+ raise ValueError
181
+
182
+
183
+ def parse_alternation(prompt, stoppers):
184
+ prompt, _ = parse_open_square(prompt, stoppers)
185
+ prompt, exprs = parse_alternation_exprs(prompt, stoppers)
186
+ prompt, speed = parse_alternation_speed(prompt, stoppers)
187
+ prompt, _ = parse_close_square(prompt, stoppers)
188
+ return ParseResult(prompt=prompt, expr=ast.AlternationExpression(exprs, speed))
189
+
190
+
191
+ def parse_alternation_exprs(prompt, stoppers):
192
+ exprs = []
193
+
194
+ try:
195
+ while True:
196
+ prompt, expr = parse_list_expression(prompt, {'|', ':', ']'})
197
+ exprs.append(expr)
198
+ prompt, _ = parse_vertical_bar(prompt, stoppers)
199
+ except ValueError:
200
+ if len(exprs) < 2:
201
+ raise
202
+
203
+ return ParseResult(prompt=prompt, expr=exprs)
204
+
205
+
206
+ def parse_alternation_speed(prompt, stoppers):
207
+ try:
208
+ prompt, _ = parse_colon(prompt, stoppers)
209
+ prompt, speed = parse_step(prompt, stoppers)
210
+ return ParseResult(prompt=prompt, expr=speed)
211
+ except ValueError:
212
+ pass
213
+
214
+ return ParseResult(prompt=prompt, expr=None)
215
+
216
+
217
+ def parse_editing(prompt, stoppers):
218
+ prompt, _ = parse_open_square(prompt, stoppers)
219
+ prompt, exprs = parse_editing_exprs(prompt, stoppers)
220
+ try:
221
+ prompt, step = parse_step(prompt, stoppers)
222
+ except ValueError:
223
+ step = None
224
+
225
+ prompt, _ = parse_close_square(prompt, stoppers)
226
+ return ParseResult(prompt=prompt, expr=ast.EditingExpression(exprs, step))
227
+
228
+
229
+ def parse_editing_exprs(prompt, stoppers):
230
+ exprs = []
231
+
232
+ try:
233
+ for _ in range(2):
234
+ prompt_tmp, expr = parse_list_expression(prompt, {'|', ':', ']'})
235
+ prompt, _ = parse_colon(prompt_tmp, stoppers)
236
+ exprs.append(expr)
237
+ except ValueError:
238
+ pass
239
+
240
+ return ParseResult(prompt=prompt, expr=exprs)
241
+
242
+
243
+ def parse_negative_attention(prompt, stoppers):
244
+ prompt, _ = parse_open_square(prompt, stoppers)
245
+ prompt, expr = parse_list_expression(prompt, set_concat(stoppers, {':', ']'}))
246
+ prompt, _ = parse_close_square(prompt, stoppers)
247
+ return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, positive=False))
248
+
249
+
250
+ def parse_positive_attention(prompt, stoppers):
251
+ prompt, _ = parse_open_paren(prompt, stoppers)
252
+ prompt, expr = parse_list_expression(prompt, {':', ')'})
253
+ prompt, weight_exprs = parse_attention_weights(prompt, stoppers)
254
+ prompt, _ = parse_close_paren(prompt, stoppers)
255
+ if len(weight_exprs) >= 2:
256
+ return ParseResult(prompt=prompt, expr=ast.WeightInterpolationExpression(expr, *weight_exprs[:2]))
257
+ else:
258
+ return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, *weight_exprs[:1]))
259
+
260
+
261
+ def parse_attention_weights(prompt, stoppers):
262
+ weights = []
263
+ try:
264
+ prompt, _ = parse_colon(prompt, stoppers)
265
+ except ValueError:
266
+ return ParseResult(prompt=prompt, expr=weights)
267
+
268
+ while True:
269
+ try:
270
+ prompt, weight_expr = parse_weight(prompt, stoppers)
271
+ weights.append(weight_expr)
272
+ prompt, _ = parse_comma(prompt, stoppers)
273
+ except ValueError:
274
+ return ParseResult(prompt=prompt, expr=weights)
275
+
276
+
277
+ def parse_step(prompt, stoppers):
278
+ try:
279
+ prompt, step = parse_int_not_float(prompt, stoppers)
280
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
281
+ except ValueError:
282
+ pass
283
+ try:
284
+ prompt, step = parse_float(prompt, stoppers)
285
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
286
+ except ValueError:
287
+ pass
288
+
289
+ return parse_substitution(prompt, stoppers)
290
+
291
+
292
+ def parse_weight(prompt, stoppers):
293
+ try:
294
+ prompt, step = parse_float(prompt, stoppers)
295
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
296
+ except ValueError:
297
+ pass
298
+
299
+ return parse_substitution(prompt, stoppers)
300
+
301
+
302
+ def parse_symbol(prompt, stoppers):
303
+ prompt, _ = parse_dollar(prompt)
304
+ return parse_symbol_text(prompt, stoppers)
305
+
306
+
307
+ def parse_symbol_text(prompt, stoppers):
308
+ return parse_token(prompt, whitespace_tail_regex('[a-zA-Z_][a-zA-Z0-9_]*', stoppers))
309
+
310
+
311
+ def parse_float(prompt, stoppers):
312
+ return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))
313
+
314
+
315
+ def parse_int_not_float(prompt, stoppers):
316
+ return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))
317
+
318
+
319
+ def parse_dollar(prompt):
320
+ dollar_sign = re.escape('$')
321
+ return parse_token(prompt, f'({dollar_sign})')
322
+
323
+
324
+ def parse_equals(prompt, stoppers):
325
+ return parse_token(prompt, whitespace_tail_regex(re.escape('='), stoppers))
326
+
327
+
328
+ def parse_comma(prompt, stoppers):
329
+ return parse_token(prompt, whitespace_tail_regex(re.escape(','), stoppers))
330
+
331
+
332
+ def parse_colon(prompt, stoppers):
333
+ return parse_token(prompt, whitespace_tail_regex(re.escape(':'), stoppers))
334
+
335
+
336
+ def parse_vertical_bar(prompt, stoppers):
337
+ return parse_token(prompt, whitespace_tail_regex(re.escape('|'), stoppers))
338
+
339
+
340
+ def parse_open_square(prompt, stoppers):
341
+ return parse_token(prompt, whitespace_tail_regex(re.escape('['), stoppers))
342
+
343
+
344
+ def parse_close_square(prompt, stoppers):
345
+ return parse_token(prompt, whitespace_tail_regex(re.escape(']'), stoppers))
346
+
347
+
348
+ def parse_open_paren(prompt, stoppers):
349
+ return parse_token(prompt, whitespace_tail_regex(re.escape('('), stoppers))
350
+
351
+
352
+ def parse_close_paren(prompt, stoppers):
353
+ return parse_token(prompt, whitespace_tail_regex(re.escape(')'), stoppers))
354
+
355
+
356
+ def parse_newline(prompt, stoppers):
357
+ return parse_token(prompt, whitespace_tail_regex('\n|$', stoppers))
358
+
359
+
360
+ def parse_token(prompt, regex):
361
+ match = re.match(regex, prompt)
362
+ if match is None:
363
+ raise ValueError
364
+
365
+ return ParseResult(prompt=prompt[len(match.group()):], expr=match.groups()[-1])
366
+
367
+
368
+ def whitespace_tail_regex(regex, stoppers):
369
+ if '\n' in stoppers:
370
+ return rf'({regex})[ \t\f\r]*'
371
+
372
+ return rf'({regex})\s*'
373
+
374
+
375
+ def set_concat(left, right):
376
+ result = set(left)
377
+ result.update(right)
378
+ return result
z-prompt-fusion-extension/lib_prompt_fusion/t_scaler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def scale_t(t, positions):
2
+ if t >= 1.:
3
+ return 1.
4
+
5
+ if t <= 0.:
6
+ return 0.
7
+
8
+ distances = []
9
+ for i in range(len(positions)-1):
10
+ distances.append(positions[i+1] - positions[i])
11
+
12
+ total_distance = sum(distances)
13
+ for i in range(len(distances)):
14
+ distances[i] = distances[i]/total_distance
15
+
16
+ for i in range(len(distances)-1):
17
+ distances[i+1] = distances[i] + distances[i+1]
18
+
19
+ distances.insert(0, 0.0)
20
+
21
+ spline_index = 0
22
+ for i, distance in enumerate(distances):
23
+ if t > distance:
24
+ spline_index = i
25
+ else:
26
+ break
27
+
28
+ if spline_index >= len(distances) - 1:
29
+ return 1
30
+
31
+ local_ratio = (t - distances[spline_index]) / (distances[spline_index+1] - distances[spline_index])
32
+ return (spline_index + local_ratio)/(len(distances)-1)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ total_steps = 20
37
+ for i in range(total_steps):
38
+ print(i, scale_t(i/total_steps, [9, 10]))
z-prompt-fusion-extension/metadata.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [Extension]
2
+ Name = prompt-fusion
z-prompt-fusion-extension/readme.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prompt Fusion
2
+
3
+ Prompt Fusion is an [auto1111 webui extension](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions) that adds more possibilities to the native prompt syntax. Among other additions, it allows to interpolate between the embeddings of different prompts, continuously:
4
+
5
+ ```
6
+ # linear prompt interpolation
7
+ [night light:magical forest: 5, 15]
8
+
9
+ # catmull-rom curve prompt interpolation
10
+ [night light:magical forest:norvegian territory: 5, 15, 25:catmull]
11
+
12
+ # alternation interpolation
13
+ [ufo|a strange sight:0.5]
14
+
15
+ # linear attention interpolation
16
+ (fire extinguisher: 1.0, 2.0)
17
+
18
+ # prompt-editing-aware attention interpolation
19
+ [(fire extinguisher: 1.0, 2.0)::5]
20
+
21
+ # weighted sum of conditions
22
+ [space station : kitchen mixer :: mean]
23
+
24
+ # define functions and variables to simplify repeating patterns and use a consistent structure
25
+ $prompt($style, $quality, $character, $background) = (
26
+ A detailed picture in the style of $style,
27
+ $quality,
28
+ $character lying back,
29
+ $background in the background
30
+ :1)
31
+ ```
32
+
33
+ ## Features
34
+
35
+ - [Prompt interpolation using a curve function](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Interpolation)
36
+ - [Attention interpolation aware of contextual prompt editing](https://github.com/ljleb/prompt-fusion-extension/wiki/Attention-Interpolation)
37
+ - [Alternation interpolation](https://github.com/ljleb/prompt-fusion-extension/wiki/Alternation-interpolation)
38
+ - [Prompt weighted sum](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Average)
39
+ - [Prompt variables and functions](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Variables)
40
+ - Complete backwards compatibility with the builtin prompt syntax of the webui
41
+
42
+ The prompt interpolation feature is similar to [Prompt Travel](https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel), which allows to create videos of images generated by navigating the latent space iteratively. Unlike Prompt Travel however, instead of generating multiple images, Prompt Fusion allows you to travel during the sampling process of *a single image*. Also, instead of interpolating the latent space, it uses the embedding space to determine intermediate embeddings.
43
+
44
+ Prompt interpolation is also similar to [Prompt Blending](https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts). The main difference is that this extension calculates a new embedding for every step, as opposed to calculating it once and using that same one embedding for all the steps.
45
+
46
+ The attention interpolation feature is similar to [Shift Attention](https://github.com/yownas/shift-attention), which allows to generate multiple images with slight variations in the attention given to certain parts of the prompt. Unlike Shift Attention, instead of generating multiple images, Prompt Fusion allows to shift the attention of certain parts of a prompt during the sampling process of *a single image*.
47
+
48
+ ## Usage
49
+ - Check the [wiki pages](https://github.com/ljleb/fusion/wiki) for the extension documentation.
50
+
51
+ ## Examples
52
+
53
+ ### 1. Influencing the beginning of the sampling process
54
+
55
+ Interpolate linearly (by default) from `lion` (step 0) to `bird` (step 8) to `girl` (step 11), and stay at `girl` for the rest of the sampling steps:
56
+
57
+ ```
58
+ [lion:bird:girl: , 7, 10]
59
+ ```
60
+
61
+ ![curve1](https://user-images.githubusercontent.com/32277961/214725976-b72bafc6-0c5d-4491-9c95-b73da41da082.gif)
62
+
63
+ ### 2. Influencing the middle of the sampling process
64
+
65
+ Interpolate using a bezier curve from `fireball monster` (step 0) to `dragon monster` (step 12, because 30 steps * 0.4 = step 12), while using `seawater monster` as an intermediate control point to steer the curve away during interpolation and to get creative results:
66
+
67
+ ```
68
+ [fireball:seawater:dragon: , .1, .4:bezier] monster
69
+ ```
70
+
71
+ ![curve2](https://user-images.githubusercontent.com/32277961/214941229-2dccad78-f856-42bb-ae6b-16b65b273cda.gif)
72
+
73
+ ## Webui supported releases
74
+
75
+ The following webui releases are officially supported:
76
+ - `v1.0.0-pre`
77
+ - `master` (there may be a slight lag for issues arising during quick a1111 webui updates)
78
+
79
+ ## Installation
80
+ 1. Visit the **Extensions** tab of Automatic's WebUI.
81
+ 2. Visit the **Available** subtab.
82
+ 3. Look for **Prompt Fusion**.
83
+ 4. Press the **Install** button.
84
+ 5. Wait for the webui to finish downloading the extension.
85
+ 6. Visit the **Installed** subtab.
86
+ 7. click on **Apply and restart UI**.
87
+
88
+ Alternatively, instead of steps 6 and 7, you can restart the webui completely.
89
+
90
+ ## Related Projects
91
+
92
+ - Prompt Travel: https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel
93
+ - Shift Attention: https://github.com/yownas/shift-attention
94
+ - Prompt Blending: https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts
95
+
z-prompt-fusion-extension/requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
z-prompt-fusion-extension/scripts/__pycache__/promptlang.cpython-310.pyc ADDED
Binary file (8.6 kB). View file
 
z-prompt-fusion-extension/scripts/promptlang.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import List, Any
3
+
4
+ from lib_prompt_fusion import (
5
+ hijacker,
6
+ empty_cond,
7
+ global_state,
8
+ interpolation_tensor,
9
+ prompt_parser as prompt_fusion_parser,
10
+ )
11
+ from modules import scripts, script_callbacks, prompt_parser, shared
12
+
13
+
14
+ # -----------------------------------------------------------------------------
15
+ # UI options
16
+ # -----------------------------------------------------------------------------
17
+
18
+ def on_ui_settings():
19
+ section = ("prompt-fusion", "Prompt Fusion")
20
+ shared.opts.add_option(
21
+ "prompt_fusion_enabled",
22
+ shared.OptionInfo(True, "Enable prompt-fusion extension", section=section),
23
+ )
24
+ shared.opts.add_option(
25
+ "prompt_fusion_slerp_scale",
26
+ shared.OptionInfo(
27
+ 0.0,
28
+ "Slerp scale (0 = linear geometry, 1 = slerp geometry)",
29
+ component=gr.Number,
30
+ section=section,
31
+ ),
32
+ )
33
+ shared.opts.add_option(
34
+ "prompt_fusion_slerp_negative_origin",
35
+ shared.OptionInfo(
36
+ True,
37
+ "Use negative prompt schedule as slerp origin",
38
+ component=gr.Checkbox,
39
+ section=section,
40
+ ),
41
+ )
42
+ shared.opts.add_option(
43
+ "prompt_fusion_slerp_epsilon",
44
+ shared.OptionInfo(
45
+ 1e-4,
46
+ "Slerp epsilon (similarity clamp; too similar/too orthogonal -> fallback to linear)",
47
+ component=gr.Number,
48
+ section=section,
49
+ ),
50
+ )
51
+
52
+
53
+ script_callbacks.on_ui_settings(on_ui_settings)
54
+
55
+
56
+ # -----------------------------------------------------------------------------
57
+ # Hijacker installation
58
+ # -----------------------------------------------------------------------------
59
+
60
+ fusion_hijacker_attribute = "__fusion_hijacker"
61
+ prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
62
+ module=prompt_parser,
63
+ hijacker_attribute=fusion_hijacker_attribute,
64
+ register_uninstall=script_callbacks.on_script_unloaded,
65
+ )
66
+
67
+
68
+ # -----------------------------------------------------------------------------
69
+ # Helpers
70
+ # -----------------------------------------------------------------------------
71
+
72
+ def _wrap_cond_any(raw_cond):
73
+ """
74
+ Заворачивает сырое cond (dict[str, Tensor] или Tensor) в обёртки,
75
+ которые понимает логика интерполяции (DictCondWrapper/TensorCondWrapper).
76
+ Ничего не индексируем — считаем, что пришёл «одиночный» cond.
77
+ """
78
+ if isinstance(raw_cond, dict):
79
+ return interpolation_tensor.DictCondWrapper(raw_cond)
80
+ else:
81
+ return interpolation_tensor.TensorCondWrapper(raw_cond)
82
+
83
+
84
+ def _adapt_flattened_schedules(result: Any, total_steps: int):
85
+ """
86
+ Нормализует вывод prompt_parser.get_learned_conditioning() к
87
+ List[List[ScheduledPromptConditioning]].
88
+
89
+ Поддерживает:
90
+ - уже готовый список списков расписаний;
91
+ - MulticondLearnedConditioning c .batch (комбинируем по весам);
92
+ - одиночный cond (оборачиваем в расписание «весь диапазон»).
93
+ """
94
+ # 1) Уже тот формат, который нужен
95
+ if isinstance(result, list) and (len(result) == 0 or isinstance(result[0], list)):
96
+ return result
97
+
98
+ # 2) Попытка «утиная типизация» MulticondLearnedConditioning
99
+ batch = getattr(result, "batch", None)
100
+ if batch is not None:
101
+ adapted = []
102
+ # batch: List[List[ComposableScheduledPromptConditioning]]
103
+ for composables in batch:
104
+ # Собираем все границы из всех сабкомпонент
105
+ boundaries = set()
106
+ for comp in composables:
107
+ for entry in getattr(comp, "schedules", []) or []:
108
+ try:
109
+ boundaries.add(int(entry.end_at_step))
110
+ except Exception:
111
+ boundaries.add(total_steps - 1)
112
+ if not boundaries:
113
+ boundaries = {total_steps - 1}
114
+ sorted_bounds = sorted(boundaries)
115
+
116
+ # Вычислитель cond на шаге для одного composable
117
+ def cond_for_step(comp, step):
118
+ # comp.schedules: List[ScheduledPromptConditioning]
119
+ idx = 0
120
+ for i, entry in enumerate(comp.schedules):
121
+ try:
122
+ end_at = int(entry.end_at_step)
123
+ except Exception:
124
+ end_at = total_steps - 1
125
+ if step <= end_at:
126
+ idx = i
127
+ break
128
+ return comp.schedules[idx].cond
129
+
130
+ # Сумма cond с весами; поддерживаем dict или Tensor
131
+ def add_scaled(dst, src, w):
132
+ if dst is None:
133
+ return {k: v * w for k, v in src.items()} if isinstance(src, dict) else src * w
134
+ if isinstance(src, dict):
135
+ out = dict(dst)
136
+ for k, v in src.items():
137
+ out[k] = out.get(k, 0) + v * w
138
+ return out
139
+ return dst + src * w
140
+
141
+ sched = []
142
+ for end_at in sorted_bounds:
143
+ combined = None
144
+ for comp in composables:
145
+ w = float(getattr(comp, "weight", 1.0))
146
+ combined = add_scaled(combined, cond_for_step(comp, end_at), w)
147
+ sched.append(
148
+ prompt_parser.ScheduledPromptConditioning(
149
+ end_at_step=end_at, cond=combined
150
+ )
151
+ )
152
+ adapted.append(sched)
153
+ return adapted
154
+
155
+ # 3) Фолбэк: одиночный cond -> одно расписание на весь диапазон
156
+ return [[prompt_parser.ScheduledPromptConditioning(end_at_step=total_steps - 1, cond=result)]]
157
+
158
+
159
+ def _build_tensor_for_prompt(
160
+ model,
161
+ prompt_text: str,
162
+ total_steps: int,
163
+ is_hires: bool,
164
+ use_old_scheduling: bool,
165
+ original_function,
166
+ *,
167
+ original_prompts_proto=None,
168
+ ):
169
+ """
170
+ Разбирает один текст промпта своим AST, строит базу строк, кодирует их
171
+ через «модульный»/стандартный парсер (original_function),
172
+ нормализует расписания и собирает интерполяционный тензор.
173
+ Возвращает готовый список ScheduledPromptConditioning для этого промпта.
174
+ """
175
+ # 1) Построить TensorBuilder из AST
176
+ tensor_builder = interpolation_tensor.InterpolationTensorBuilder()
177
+ expr = prompt_fusion_parser.parse_prompt(prompt_text)
178
+ expr.extend_tensor(
179
+ tensor_builder,
180
+ [0, max(0, total_steps - 1)],
181
+ total_steps,
182
+ {},
183
+ is_hires,
184
+ use_old_scheduling,
185
+ )
186
+
187
+ # 2) Получить «плоскую» базу строк и закодировать их в cond
188
+ flat_prompts = tensor_builder.get_prompt_database()
189
+
190
+ # Сохраняем поведение A1111: в некоторых версиях промпт-объект содержит флаг негативности.
191
+ # Сделаем list-обёртку с тем же атрибутом, если он есть у исходного prompts.
192
+ class _SdLike(list):
193
+ pass
194
+
195
+ if original_prompts_proto is not None and hasattr(original_prompts_proto, "is_negative_prompt"):
196
+ sd_like = _SdLike(flat_prompts)
197
+ sd_like.is_negative_prompt = original_prompts_proto.is_negative_prompt
198
+ flat_input = sd_like
199
+ else:
200
+ flat_input = flat_prompts
201
+
202
+ flattened = original_function(model, flat_input, total_steps)
203
+ flattened = _adapt_flattened_schedules(flattened, total_steps)
204
+
205
+ # 3) Обернуть cond-значения
206
+ wrapped_conds: List[List[prompt_parser.ScheduledPromptConditioning]] = []
207
+ for sched in flattened:
208
+ wrapped_sched = [
209
+ prompt_parser.ScheduledPromptConditioning(
210
+ end_at_step=entry.end_at_step,
211
+ cond=_wrap_cond_any(entry.cond),
212
+ )
213
+ for entry in sched
214
+ ]
215
+ wrapped_conds.append(wrapped_sched)
216
+
217
+ # 4) Собрать интерполяционный тензор и вычислить расписания по шагам
218
+ tensor = tensor_builder.build(wrapped_conds, empty_cond.get())
219
+
220
+ slerp_scale = global_state.get_slerp_scale()
221
+ slerp_epsilon = global_state.get_slerp_epsilon()
222
+
223
+ # Сформировать расписание, склеивая одинаковые соседние сегменты
224
+ schedules: List[prompt_parser.ScheduledPromptConditioning] = []
225
+ prev_wrapper = None
226
+ for step in range(total_steps):
227
+ params = interpolation_tensor.InterpolationParams(
228
+ t=step / max(1, total_steps - 1),
229
+ step=step,
230
+ total_steps=total_steps,
231
+ slerp_scale=slerp_scale,
232
+ slerp_epsilon=slerp_epsilon,
233
+ )
234
+ origin = global_state.get_origin_cond_at(step, is_hires=is_hires)
235
+ cond_wrapper = tensor.interpolate(params, origin, empty_cond.get())
236
+
237
+ # Склейка соседних, если равны (по wrapper.__eq__)
238
+ if prev_wrapper is not None and cond_wrapper == prev_wrapper:
239
+ schedules[-1].end_at_step = step
240
+ else:
241
+ raw = cond_wrapper.original_cond if hasattr(cond_wrapper, "original_cond") else cond_wrapper
242
+ schedules.append(
243
+ prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=raw)
244
+ )
245
+ prev_wrapper = cond_wrapper
246
+
247
+ # Финализировать последний сегмент на последний шаг
248
+ if schedules:
249
+ schedules[-1].end_at_step = total_steps - 1
250
+
251
+ return schedules
252
+
253
+
254
+ # -----------------------------------------------------------------------------
255
+ # Hooks
256
+ # -----------------------------------------------------------------------------
257
+
258
+ @prompt_parser_hijacker.hijack("get_learned_conditioning")
259
+ def _fusion_get_learned_conditioning(
260
+ model,
261
+ prompts,
262
+ total_steps: int,
263
+ *args,
264
+ original_function=None,
265
+ **kwargs,
266
+ ):
267
+ """
268
+ Главный перехват. Если расширение выключено — просто проксируем оригинал.
269
+ Иначе: строим собственный интерполяционный тензор поверх энкодинга «модульного» парсера.
270
+ """
271
+ # В офф — строго оригинал
272
+ if not shared.opts.data.get("prompt_fusion_enabled", True):
273
+ return original_function(model, prompts, total_steps, *args, **kwargs)
274
+
275
+ # Извлекаем hires-флаг/режим интерпретации шагов (совместимость со старыми A1111)
276
+ hires_steps = None
277
+ use_old_scheduling = True
278
+ if args:
279
+ # Популярный порядок: (hires_steps, use_old_scheduling, ...)
280
+ try:
281
+ hires_steps, use_old_scheduling = args[0], args[1]
282
+ except Exception:
283
+ pass
284
+
285
+ is_hires = hires_steps is not None
286
+ real_total_steps = hires_steps if is_hires else total_steps
287
+
288
+ # Определяем, это негативный вызов или нет
289
+ if hasattr(prompts, "is_negative_prompt"):
290
+ is_negative_prompt = bool(prompts.is_negative_prompt)
291
+ else:
292
+ # Старые билды: первый вызов «считаем негативным», второй — положительным
293
+ is_negative_prompt = global_state.old_webui_is_negative
294
+
295
+ # Инициируем «пустой» cond
296
+ empty_cond.init(model)
297
+
298
+ # Готовим выход для всего батча
299
+ out: List[List[prompt_parser.ScheduledPromptConditioning]] = []
300
+
301
+ # Пробегаем по элементам батча
302
+ for prompt_text in prompts:
303
+ sched = _build_tensor_for_prompt(
304
+ model=model,
305
+ prompt_text=prompt_text,
306
+ total_steps=real_total_steps,
307
+ is_hires=is_hires,
308
+ use_old_scheduling=use_old_scheduling,
309
+ original_function=original_function,
310
+ original_prompts_proto=prompts,
311
+ )
312
+ out.append(sched)
313
+
314
+ # Если это негативные расписания — сохраним для slerp-origin
315
+ if is_negative_prompt:
316
+ if is_hires:
317
+ global_state.negative_schedules_hires = out[0] if out else None
318
+ else:
319
+ global_state.negative_schedules = out[0] if out else None
320
+
321
+ return out
322
+
323
+
324
+ @prompt_parser_hijacker.hijack("get_multicond_learned_conditioning")
325
+ def _fusion_get_multicond_learned_conditioning(
326
+ *args,
327
+ original_function=None,
328
+ **kwargs,
329
+ ):
330
+ """
331
+ Совместимость со старым порядком вызовов A1111:
332
+ после вызова этого метода следующий get_learned_conditioning — уже не «негативный».
333
+ Ничего не меняем в значениях — возвращаем как есть.
334
+ """
335
+ global_state.old_webui_is_negative = False
336
+ return original_function(*args, **kwargs)
337
+
338
+
339
+ # -----------------------------------------------------------------------------
340
+ # WebUI script shim (чтобы сбрасывать «негативность» и кэш между генерациями)
341
+ # -----------------------------------------------------------------------------
342
+
343
+ class PromptFusionScript(scripts.Script):
344
+ def title(self):
345
+ return "Prompt Fusion"
346
+
347
+ def show(self, is_img2img):
348
+ return scripts.AlwaysVisible
349
+
350
+ def process(self, p, *args):
351
+ # Перед началом пайпа считаем, что следующий get_learned_conditioning — «негативный»
352
+ global_state.negative_schedules = None
353
+ global_state.negative_schedules_hires = None
354
+ global_state.old_webui_is_negative = True
z-prompt-fusion-extension/test/parser_tests.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion.prompt_parser import parse_prompt
2
+ from lib_prompt_fusion.interpolation_tensor import InterpolationTensorBuilder
3
+
4
+
5
+ def run_functional_tests(total_steps=100):
6
+ for i, (given, expected) in enumerate(functional_parse_test_cases):
7
+ expr = parse_prompt(given)
8
+ tensor_builder = InterpolationTensorBuilder()
9
+ expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
10
+
11
+ actual = tensor_builder.get_prompt_database()
12
+
13
+ if type(expected) is set:
14
+ assert set(actual) == expected, f"{actual} != {expected}"
15
+ else:
16
+ assert len(actual) == 1 and actual[0] == expected, f"'{actual[0]}' != '{expected}'"
17
+
18
+
19
+ functional_parse_test_cases = [
20
+ ('single',)*2,
21
+ ('some space separated text',)*2,
22
+ ('(legacy weighted prompt:-2.1)',)*2,
23
+ ('mixed (legacy weight:3.6) and text',)*2,
24
+ ('legacy [range begin:0] thingy',)*2,
25
+ ('legacy [range end::3] thingy',)*2,
26
+ ('legacy [[nested range::3]:2] thingy',)*2,
27
+ ('legacy [[nested range:2]::3] thingy',)*2,
28
+ ('sugar [range:,abc:3] thingy',)*2,
29
+ ('sugar [[(weight interpolation:0,12):0]::1] thingy', 'sugar [[(weight interpolation:0.0):0]::1] thingy'),
30
+ ('sugar [[(weight interpolation:0,12):0]::2] thingy', 'sugar [[[(weight interpolation:0.0)::1][(weight interpolation:12.0):1]:0]::2] thingy'),
31
+ ('sugar [[(weight interpolation:0,12):0]::3] thingy', 'sugar [[[(weight interpolation:0.0)::1][[(weight interpolation:6.0):1]::2][(weight interpolation:12.0):2]:0]::3] thingy'),
32
+ ('legacy [from:to:2] thingy',)*2,
33
+ ('legacy [negative weight]',)*2,
34
+ ('legacy (positive weight)',)*2,
35
+ ('[abc:1girl:2]',)*2,
36
+ ('[::]',)*2,
37
+ ('[a:b:]',)*2,
38
+ ('[[a:b:1,2]:b:]', {'[a:b:]', '[b:b:]'}),
39
+ ('1girl',)*2,
40
+ ('dashes-in-text',)*2,
41
+ ('text, separated with, comas',)*2,
42
+ ('{prompt}',)*2,
43
+ ('[abc|def ghi|jkl]',)*2,
44
+ ('merging this AND with this',)*2,
45
+ (':',)*2,
46
+ (r'portrait \(object\)',)*2,
47
+ (r'\[escaped square\]',)*2,
48
+ (r'\$var = abc',)*2,
49
+ (r'\\$ arst',)*2,
50
+ (r'$$ arst',)*2,
51
+ ('$var = abc', ''),
52
+ ('$a = prompt value\n$a', 'prompt value'),
53
+ ('$a = prompt value\n$b = $a\n$b', 'prompt value'),
54
+ ('$a = (multiline\nprompt\nvalue:1.0)\n$a', '(multiline prompt value:1.0)'),
55
+ ('$a = ($aa = nested variable\nmultiline\n$aa:1.0)\n$a', '(multiline nested variable:1.0)'),
56
+ ('a [b:c:-1, 10] d', {'a b d', 'a c d'}),
57
+ ('a [b:c:5, 6] d', {'a b d', 'a c d'}),
58
+ ('a [b:c:0.25, 0.5] d', {'a b d', 'a c d'}),
59
+ ('a [b:c:.25, .5] d', {'a b d', 'a c d'}),
60
+ ('a [b:c:,] d', {'a b d', 'a c d'}),
61
+ ('0[1.0:1.1:,]2[3.0:3.1:,]4', {
62
+ '0 1.0 2 3.0 4', '0 1.1 2 3.0 4',
63
+ '0 1.0 2 3.1 4', '0 1.1 2 3.1 4',
64
+ }),
65
+ ('0[1.0:1.1:1.2:,.5,]2[3.0:3.1:,]4', {
66
+ '0 1.0 2 3.0 4', '0 1.0 2 3.1 4',
67
+ '0 1.1 2 3.0 4', '0 1.1 2 3.1 4',
68
+ '0 1.2 2 3.0 4', '0 1.2 2 3.1 4',
69
+ }),
70
+ ('[0.0:0.1:,][1.0:1.1:,][2.0:2.1:,]', {
71
+ '0.0 1.0 2.0', '0.0 1.0 2.1',
72
+ '0.1 1.0 2.0', '0.1 1.0 2.1',
73
+ '0.0 1.1 2.0', '0.0 1.1 2.1',
74
+ '0.1 1.1 2.0', '0.1 1.1 2.1',
75
+ }),
76
+ ('[top level:interpolatin:lik a pro:1,3,5:linear]', {'top level', 'interpolatin', 'lik a pro'}),
77
+ ('[[nested:expr:,]:abc:,]', {'nested', 'expr', 'abc'}),
78
+ ('[(nested attention:2.0):abc:,]', {'(nested attention:2.0)', 'abc'}),
79
+ ('[[nested editing:15]:abc:,]', {'[nested editing:15]', 'abc'}),
80
+ ('[[nested interpolation:abc:,]:12]', {'[nested interpolation:12]', '[abc:12]'}),
81
+ ('[[nested interpolation:abc:,]::7]', {'[nested interpolation::7]', '[abc::7]'}),
82
+ ('$attention = 1.5\n(prompt:$attention)', '(prompt:1.5)'),
83
+ ('$a = 0\n$b = 12\n[[(prompt:$a,$b):0]::2]', '[[[(prompt:0.0)::1][(prompt:12.0):1]:0]::2]'),
84
+ ('$step = 5\n[legacy:editing:$step]', '[legacy:editing:5]'),
85
+ ('$begin = 2\n$end = 7\n[prompt:interpolation:$begin, $end]', {'prompt', 'interpolation'}),
86
+ ('$a($b, $c) = prompt with $b, prompt with $c\n$a(cat, dog)', 'prompt with cat , prompt with dog'),
87
+ ('$a($b) = prompt with $b\n$c($d) = yeay $a($d)\n$c(dog)', 'yeay prompt with dog'),
88
+ ('$a = a lot of animals\n$b($c) = I love $c\n$b($a)', 'I love a lot of animals'),
89
+ ('$a($b) = prompt with $b\n$c($d) = yeay $d\n$a($c(dog))', 'prompt with yeay dog'),
90
+ ('[a|b|c]', '[a|b|c]'),
91
+ ('[a|b|c:]', '[a|b|c]'),
92
+ ('[a|b|c:1]', {'a', 'b', 'c'}),
93
+ ('[a|b|c:2]', {'a', 'b', 'c'}),
94
+ ('[a|b|c:0.5]', {'a', 'b', 'c'}),
95
+ ('[a|b|c:1.1]', {'a', 'b', 'c'}),
96
+ ('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
97
+ ('[a:b:c::mean]', {'a', 'b', 'c'}),
98
+ ('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
99
+ ('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
100
+ ]
101
+
102
+
103
+ def run_tests():
104
+ run_functional_tests()
z-prompt-fusion-extension/test/run_all.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('..')
3
+ import parser_tests
4
+
5
+
6
+ if __name__ == '__main__':
7
+ parser_tests.run_tests()