dikdimon commited on
Commit
7f6efce
·
verified ·
1 Parent(s): baaf54d

Upload z-2-prompt-fusion-extension using SD-Hub

Browse files
Files changed (28) hide show
  1. z-2-prompt-fusion-extension/.gitignore +3 -0
  2. z-2-prompt-fusion-extension/LICENSE +21 -0
  3. z-2-prompt-fusion-extension/geometries.py +42 -0
  4. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/ast_nodes.cpython-310.pyc +0 -0
  5. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/empty_cond.cpython-310.pyc +0 -0
  6. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/geometries.cpython-310.pyc +0 -0
  7. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/global_state.cpython-310.pyc +0 -0
  8. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/hijacker.cpython-310.pyc +0 -0
  9. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_functions.cpython-310.pyc +0 -0
  10. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_tensor.cpython-310.pyc +0 -0
  11. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/prompt_parser.cpython-310.pyc +0 -0
  12. z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/t_scaler.cpython-310.pyc +0 -0
  13. z-2-prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py +296 -0
  14. z-2-prompt-fusion-extension/lib_prompt_fusion/empty_cond.py +19 -0
  15. z-2-prompt-fusion-extension/lib_prompt_fusion/geometries.py +42 -0
  16. z-2-prompt-fusion-extension/lib_prompt_fusion/global_state.py +28 -0
  17. z-2-prompt-fusion-extension/lib_prompt_fusion/hijacker.py +34 -0
  18. z-2-prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py +87 -0
  19. z-2-prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py +253 -0
  20. z-2-prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py +358 -0
  21. z-2-prompt-fusion-extension/lib_prompt_fusion/t_scaler.py +38 -0
  22. z-2-prompt-fusion-extension/metadata.ini +2 -0
  23. z-2-prompt-fusion-extension/readme.md +95 -0
  24. z-2-prompt-fusion-extension/requirements.txt +1 -0
  25. z-2-prompt-fusion-extension/scripts/__pycache__/promptlang.cpython-310.pyc +0 -0
  26. z-2-prompt-fusion-extension/scripts/promptlang.py +147 -0
  27. z-2-prompt-fusion-extension/test/parser_tests.py +104 -0
  28. z-2-prompt-fusion-extension/test/run_all.py +7 -0
z-2-prompt-fusion-extension/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /venv/
2
+ /.idea/
3
+ __pycache__/
z-2-prompt-fusion-extension/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2003
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
z-2-prompt-fusion-extension/geometries.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from lib_prompt_fusion import interpolation_tensor
4
+
5
+
6
+ def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
7
+ p0, p1 = control_points
8
+
9
+ # Нормы и защита от нулевой нормы (fallback на линейную геометрию)
10
+ p0_norm = torch.linalg.norm(p0)
11
+ p1_norm = torch.linalg.norm(p1)
12
+ if float(p0_norm) == 0.0 or float(p1_norm) == 0.0:
13
+ return linear_geometry(control_points, params)
14
+
15
+ # Косинусная близость в [-1, 1] с аккуратным clamping
16
+ similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
17
+ similarity = min(1.0, max(-1.0, float(similarity)))
18
+
19
+ # Если почти параллельны/антипараллельны — надёжный fallback на линейную
20
+ if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
21
+ return linear_geometry(control_points, params)
22
+
23
+ # Полуугол используем для симметричного т-ремапа
24
+ angle = math.acos(float(similarity)) / 2.0
25
+
26
+ slerp_t = angle * (2 * params.t - 1)
27
+ slerp_t = math.tan(slerp_t) / math.tan(angle)
28
+ slerp_t = (slerp_t + 1.0) / 2.0
29
+
30
+ # Выравниваем норму p1 под норму p0 для корректной дуги
31
+ normalized_p1 = p1 / p1_norm * p0_norm
32
+ slerp_p = p0 + (normalized_p1 - p0) * slerp_t
33
+ slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)
34
+
35
+ # Смешиваем с линейной геометрией согласно slerp_scale
36
+ lerp_p = linear_geometry(control_points, params)
37
+ return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
38
+
39
+
40
+ def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
41
+ p0, p1 = control_points
42
+ return p0 + (p1 - p0) * params.t
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/ast_nodes.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/empty_cond.cpython-310.pyc ADDED
Binary file (795 Bytes). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/geometries.cpython-310.pyc ADDED
Binary file (1.15 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/global_state.cpython-310.pyc ADDED
Binary file (1.23 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/hijacker.cpython-310.pyc ADDED
Binary file (1.96 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_functions.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_tensor.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/prompt_parser.cpython-310.pyc ADDED
Binary file (9.88 kB). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/__pycache__/t_scaler.cpython-310.pyc ADDED
Binary file (877 Bytes). View file
 
z-2-prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from lib_prompt_fusion import interpolation_functions
3
+ from lib_prompt_fusion.t_scaler import scale_t
4
+ from lib_prompt_fusion import interpolation_tensor
5
+
6
+
7
+ class ListExpression:
8
+ def __init__(self, expressions):
9
+ self.__expressions = expressions
10
+
11
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
12
+ if not self.__expressions:
13
+ return
14
+
15
+ def expr_extend_tensor(expr):
16
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
17
+
18
+ expr_extend_tensor(self.__expressions[0])
19
+ for expression in self.__expressions[1:]:
20
+ tensor_builder.append(' ')
21
+ expr_extend_tensor(expression)
22
+
23
+
24
+ class InterpolationExpression:
25
+ @staticmethod
26
+ def create(exprs, steps, function_name):
27
+ if function_name == "mean":
28
+ return AverageExpression(exprs, steps)
29
+
30
+ max_len = min(len(exprs), len(steps))
31
+ exprs = exprs[:max_len]
32
+ steps = steps[:max_len]
33
+
34
+ return InterpolationExpression(exprs, steps, function_name)
35
+
36
+ def __init__(self, expressions, steps, function_name=None):
37
+ assert len(expressions) >= 2
38
+ assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
39
+ self.__expressions = expressions
40
+ self.__steps = steps
41
+ self.__function_name = function_name if function_name is not None else 'linear'
42
+
43
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
44
+ def tensor_updater(expr):
45
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
46
+
47
+ tensor_builder.extrude(
48
+ [tensor_updater(expr) for expr in self.__expressions],
49
+ self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling)
50
+ )
51
+
52
+ def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
53
+ steps = list(self.__steps)
54
+ if steps[0] is None:
55
+ steps[0] = LiftExpression(str(steps_range[0] - 1))
56
+ if steps[-1] is None:
57
+ steps[-1] = LiftExpression(str(steps_range[1] - 1))
58
+
59
+ for i, step in enumerate(steps):
60
+ if step is None:
61
+ continue
62
+
63
+ step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)
64
+ if use_old_scheduling and 0 < step < 1:
65
+ step *= total_steps
66
+ elif not use_old_scheduling and isinstance(step, float):
67
+ step = (step - int(is_hires)) * total_steps
68
+ else:
69
+ step += 1
70
+ steps[i] = int(step)
71
+
72
+ i = 1
73
+ while i < len(steps):
74
+ none_len = 0
75
+ while steps[i + none_len] is None:
76
+ none_len += 1
77
+ min_step, max_step = steps[i - 1], steps[i + none_len]
78
+ for j in range(none_len):
79
+ steps[i + j] = min_step + (max_step - min_step) * (j + 1) / (none_len + 1)
80
+ i += 1 + none_len
81
+
82
+ interpolation_function = {
83
+ 'linear': interpolation_functions.compute_linear,
84
+ 'bezier': interpolation_functions.compute_bezier,
85
+ 'catmull': interpolation_functions.compute_catmull,
86
+ }[self.__function_name]
87
+
88
+ def steps_scale_t(conds, params: interpolation_tensor.InterpolationParams):
89
+ scaled_t = (params.t * total_steps - steps[0]) / max(1, steps[-1] - steps[0])
90
+ scaled_t = scale_t(scaled_t, steps)
91
+ new_params = interpolation_tensor.InterpolationParams(scaled_t, *params[1:])
92
+ return interpolation_function(conds, new_params)
93
+
94
+ return steps_scale_t
95
+
96
+
97
+ class AverageExpression:
98
+ def __init__(self, expressions, weights):
99
+ # Делаем «mean» устойчивым к лишним весам
100
+ if len(expressions) < len(weights):
101
+ weights = weights[:len(expressions)]
102
+ self.__expressions = expressions
103
+ self.__weights = weights
104
+
105
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
106
+ def tensor_updater(expr):
107
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
108
+
109
+ tensor_builder.extrude(
110
+ [tensor_updater(expr) for expr in self.__expressions],
111
+ self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling)
112
+ )
113
+
114
+ def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
115
+ weights = [
116
+ _eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
117
+ for weight in self.__weights
118
+ ]
119
+ explicit_weights = [w for w in weights if w is not None]
120
+ explicit_sum = sum(explicit_weights) if explicit_weights else 0.0
121
+
122
+ # Нормализация: если нет явных весов или сумма == 0 — равномерно
123
+ norm_weights = []
124
+ for w in weights:
125
+ if not explicit_weights or explicit_sum == 0.0 or w is None:
126
+ norm_weights.append(1.0 / len(self.__expressions))
127
+ else:
128
+ norm_weights.append((w / explicit_sum) * len(explicit_weights) / len(self.__expressions))
129
+
130
+ # Заполняем недостающие веса равномерно
131
+ while len(norm_weights) < len(self.__expressions):
132
+ norm_weights.append(1.0 / len(self.__expressions))
133
+
134
+ def interpolation_function(conds, _params):
135
+ total = None
136
+ for cond, weight in zip(conds, norm_weights):
137
+ cond *= weight
138
+ total = cond if total is None else total + cond
139
+ return total
140
+
141
+ return interpolation_function
142
+
143
+
144
+ class AlternationExpression:
145
+ def __init__(self, expressions, speed):
146
+ self.__expressions = expressions
147
+ self.__speed = speed
148
+
149
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
150
+ if self.__speed is None:
151
+ tensor_builder.append('[')
152
+ for expr_i, expr in enumerate(self.__expressions):
153
+ if expr_i >= 1:
154
+ tensor_builder.append('|')
155
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
156
+ tensor_builder.append(']')
157
+ return
158
+
159
+ def tensor_updater(expr):
160
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
161
+
162
+ exprs = self.__expressions + [self.__expressions[0]]
163
+
164
+ tensor_builder.extrude(
165
+ [tensor_updater(expr) for expr in exprs],
166
+ self.get_interpolation_function(_eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling),
167
+ exprs, steps_range, total_steps)
168
+ )
169
+
170
+ def get_interpolation_function(self, speed, exprs, steps_range, total_steps):
171
+ def compute_wrap(control_points, params: interpolation_tensor.InterpolationParams):
172
+ wrapped_t = math.fmod((params.t * total_steps - steps_range[0]) / (len(exprs) - 1) * speed, 1.0)
173
+ if wrapped_t < 0:
174
+ wrapped_t += 1.0
175
+ new_params = interpolation_tensor.InterpolationParams(wrapped_t, *params[1:])
176
+ return interpolation_functions.compute_linear(control_points, new_params)
177
+
178
+ return compute_wrap
179
+
180
+
181
+ class EditingExpression:
182
+ def __init__(self, expressions, step):
183
+ assert 1 <= len(expressions) <= 2
184
+ self.__expressions = expressions
185
+ self.__step = step
186
+
187
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
188
+ if self.__step is None:
189
+ tensor_builder.append('[')
190
+ for expr_i, expr in enumerate(self.__expressions):
191
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
192
+ tensor_builder.append(':')
193
+ tensor_builder.append(']')
194
+ return
195
+
196
+ step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
197
+ step_int = step
198
+ if use_old_scheduling and 0 < step < 1:
199
+ step_int *= total_steps
200
+ elif not use_old_scheduling and isinstance(step, float):
201
+ step_int = (step_int - int(is_hires)) * total_steps
202
+ else:
203
+ step_int += 1
204
+ step_int = int(step_int)
205
+
206
+ tensor_builder.append('[')
207
+ for expr_i, expr in enumerate(self.__expressions):
208
+ expr_steps_range = (steps_range[0], step_int) if expr_i == 0 and len(self.__expressions) >= 2 else (step_int, steps_range[1])
209
+ expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
210
+ tensor_builder.append(':')
211
+
212
+ tensor_builder.append(f'{step}]')
213
+
214
+
215
+ class WeightedExpression:
216
+ def __init__(self, nested, weight=None, positive=True):
217
+ self.__nested = nested
218
+ if not positive:
219
+ assert weight is None
220
+ self.__weight = weight
221
+ self.__positive = positive
222
+
223
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
224
+ open_bracket, close_bracket = ('(', ')') if self.__positive else ('[', ']')
225
+ tensor_builder.append(open_bracket)
226
+ self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
227
+ if self.__weight is not None:
228
+ tensor_builder.append(':')
229
+ self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
230
+ tensor_builder.append(close_bracket)
231
+
232
+
233
+ class WeightInterpolationExpression:
234
+ def __init__(self, nested, weight_begin, weight_end):
235
+ self.__nested = nested
236
+ self.__weight_begin = weight_begin if weight_begin is not None else LiftExpression(str(1.0))
237
+ self.__weight_end = weight_end if weight_end is not None else LiftExpression(str(1.0))
238
+
239
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
240
+ steps_range_size = steps_range[1] - steps_range[0]
241
+ weight_begin = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
242
+ weight_end = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)
243
+
244
+ for i in range(steps_range_size):
245
+ step = i + steps_range[0]
246
+ weight = weight_begin + (weight_end - weight_begin) * (i / max(steps_range_size - 1, 1))
247
+ weight_step_expr = WeightedExpression(self.__nested, LiftExpression(str(weight)))
248
+ if step > steps_range[0]:
249
+ weight_step_expr = EditingExpression([weight_step_expr], LiftExpression(str(step - 1)))
250
+ if step + 1 < steps_range[1]:
251
+ weight_step_expr = EditingExpression([weight_step_expr, ListExpression([])], LiftExpression(str(step)))
252
+ weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
253
+
254
+
255
+ class DeclarationExpression:
256
+ def __init__(self, symbol, parameters, value, target):
257
+ self.__symbol = symbol
258
+ self.__value = value
259
+ self.__target = target
260
+ self.__parameters = parameters
261
+
262
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
263
+ updated_context = dict(context)
264
+ updated_context[self.__symbol] = (self.__value, self.__parameters)
265
+ self.__target.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
266
+
267
+
268
+ class SubstitutionExpression:
269
+ def __init__(self, symbol, arguments):
270
+ self.__symbol = symbol
271
+ self.__arguments = arguments
272
+
273
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
274
+ updated_context = dict(context)
275
+ nested, parameters = context[self.__symbol]
276
+ for argument, parameter in zip(self.__arguments, parameters):
277
+ updated_context[parameter] = argument, []
278
+ nested.extend_tensor(tensor_builder, steps_range, total_steps, updated_context, is_hires, use_old_scheduling)
279
+
280
+
281
+ class LiftExpression:
282
+ def __init__(self, value):
283
+ self.__value = value
284
+
285
+ def extend_tensor(self, tensor_builder, *_args, **_kwargs):
286
+ tensor_builder.append(self.__value)
287
+
288
+
289
+ def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
290
+ mock_database = ['']
291
+ expression.extend_tensor(interpolation_tensor.InterpolationTensorBuilder(prompt_database=mock_database),
292
+ steps_range, total_steps, context, is_hires, use_old_scheduling)
293
+ try:
294
+ return int(mock_database[0])
295
+ except ValueError:
296
+ return float(mock_database[0])
z-2-prompt-fusion-extension/lib_prompt_fusion/empty_cond.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion import interpolation_tensor
2
+
3
+
4
+ _empty_cond = None
5
+
6
+
7
+ def get():
8
+ return _empty_cond
9
+
10
+
11
+ def init(model):
12
+ global _empty_cond
13
+ cond = model.get_learned_conditioning([''])
14
+ if isinstance(cond, dict):
15
+ cond = interpolation_tensor.DictCondWrapper({k: v[0] for k, v in cond.items()})
16
+ else:
17
+ cond = interpolation_tensor.TensorCondWrapper(cond[0])
18
+
19
+ _empty_cond = cond
z-2-prompt-fusion-extension/lib_prompt_fusion/geometries.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from lib_prompt_fusion import interpolation_tensor
4
+
5
+
6
+ def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
7
+ p0, p1 = control_points
8
+
9
+ # Нормы и защита от нулевой нормы (fallback на линейную геометрию)
10
+ p0_norm = torch.linalg.norm(p0)
11
+ p1_norm = torch.linalg.norm(p1)
12
+ if float(p0_norm) == 0.0 or float(p1_norm) == 0.0:
13
+ return linear_geometry(control_points, params)
14
+
15
+ # Косинусная близость в [-1, 1] с аккуратным clamping
16
+ similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
17
+ similarity = min(1.0, max(-1.0, float(similarity)))
18
+
19
+ # Если почти параллельны/антипараллельны — надёжный fallback на линейную
20
+ if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
21
+ return linear_geometry(control_points, params)
22
+
23
+ # Полуугол используем для симметричного т-ремапа
24
+ angle = math.acos(float(similarity)) / 2.0
25
+
26
+ slerp_t = angle * (2 * params.t - 1)
27
+ slerp_t = math.tan(slerp_t) / math.tan(angle)
28
+ slerp_t = (slerp_t + 1.0) / 2.0
29
+
30
+ # Выравниваем норму p1 под норму p0 для корректной дуги
31
+ normalized_p1 = p1 / p1_norm * p0_norm
32
+ slerp_p = p0 + (normalized_p1 - p0) * slerp_t
33
+ slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)
34
+
35
+ # Смешиваем с линейной геометрией согласно slerp_scale
36
+ lerp_p = linear_geometry(control_points, params)
37
+ return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
38
+
39
+
40
+ def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
41
+ p0, p1 = control_points
42
+ return p0 + (p1 - p0) * params.t
z-2-prompt-fusion-extension/lib_prompt_fusion/global_state.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from modules import shared, prompt_parser
3
+ from lib_prompt_fusion import empty_cond
4
+
5
+
6
+ old_webui_is_negative: bool = False
7
+ negative_schedules: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
8
+ negative_schedules_hires: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
9
+
10
+
11
+ def get_origin_cond_at(step: int, is_hires: bool = False):
12
+ fallback_schedules = negative_schedules_hires if is_hires else negative_schedules
13
+ if not fallback_schedules or not shared.opts.data.get('prompt_fusion_slerp_negative_origin', False):
14
+ return empty_cond.get()
15
+
16
+ for schedule in fallback_schedules:
17
+ if schedule.end_at_step >= step:
18
+ return schedule.cond
19
+
20
+ return empty_cond.get()
21
+
22
+
23
+ def get_slerp_scale():
24
+ return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
25
+
26
+
27
+ def get_slerp_epsilon():
28
+ return shared.opts.data.get('prompt_fusion_slerp_epsilon', 0.0001)
z-2-prompt-fusion-extension/lib_prompt_fusion/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class ModuleHijacker:
2
+ def __init__(self, module):
3
+ self.__module = module
4
+ self.__original_functions = dict()
5
+
6
+ def hijack(self, attribute):
7
+ if attribute not in self.__original_functions:
8
+ self.__original_functions[attribute] = getattr(self.__module, attribute)
9
+
10
+ def decorator(function):
11
+ def wrapper(*args, **kwargs):
12
+ return function(*args, **kwargs, original_function=self.__original_functions[attribute])
13
+
14
+ setattr(self.__module, attribute, wrapper)
15
+ return function
16
+
17
+ return decorator
18
+
19
+ def reset_module(self):
20
+ for attribute, original_function in self.__original_functions.items():
21
+ setattr(self.__module, attribute, original_function)
22
+
23
+ self.__original_functions.clear()
24
+
25
+ @staticmethod
26
+ def install_or_get(module, hijacker_attribute, register_uninstall=lambda _callback: None):
27
+ if not hasattr(module, hijacker_attribute):
28
+ module_hijacker = ModuleHijacker(module)
29
+ setattr(module, hijacker_attribute, module_hijacker)
30
+ register_uninstall(lambda: delattr(module, hijacker_attribute))
31
+ register_uninstall(module_hijacker.reset_module)
32
+ return module_hijacker
33
+ else:
34
+ return getattr(module, hijacker_attribute)
z-2-prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from lib_prompt_fusion import interpolation_tensor, geometries
4
+
5
+
6
+ def compute_linear(control_points, params: interpolation_tensor.InterpolationParams):
7
+ if len(control_points) <= 2:
8
+ return geometries.slerp_geometry(control_points, params)
9
+ else:
10
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
11
+ cp0 = control_points[target_curve]
12
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
13
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
14
+ return geometries.slerp_geometry([cp0, cp1], new_params)
15
+
16
+
17
+ def compute_bezier(control_points, params: interpolation_tensor.InterpolationParams):
18
+ def compute_casteljau(ps, size):
19
+ for i in reversed(range(1, size)):
20
+ for j in range(i):
21
+ ps[j] = geometries.slerp_geometry([ps[j], ps[j+1]], params)
22
+
23
+ return ps[0]
24
+
25
+ if len(control_points) == 1:
26
+ return control_points[0]
27
+ elif len(control_points) == 2:
28
+ return geometries.slerp_geometry(control_points, params)
29
+ copied_control_points = copy.deepcopy(control_points)
30
+ return compute_casteljau(copied_control_points, len(copied_control_points))
31
+
32
+
33
+ def compute_catmull(control_points, params: interpolation_tensor.InterpolationParams):
34
+ if len(control_points) <= 2:
35
+ return compute_linear(control_points, params)
36
+ else:
37
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
38
+ g0 = control_points[target_curve - 1] if target_curve > 0 else control_points[0] * 2 - control_points[1]
39
+ cp0 = control_points[target_curve]
40
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
41
+ g1 = control_points[target_curve + 2] if target_curve + 2 < len(control_points) else cp1 * 2 - cp0
42
+ ip0 = cp0 + (cp1 - g0)/6
43
+ ip1 = cp1 + (cp0 - g1)/6
44
+
45
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
46
+ return compute_bezier([cp0, ip0, ip1, cp1], new_params)
47
+
48
+
49
+ if __name__ == '__main__':
50
+ import turtle as tr
51
+ import torch
52
+ size = 60
53
+ turtle_tool = tr.Turtle()
54
+ turtle_tool.speed(10)
55
+ turtle_tool.up()
56
+
57
+ points = torch.Tensor([[-2., -2.], [2., 2.]])
58
+ origin = torch.Tensor([1.5, 1.6])
59
+
60
+ def sample(slerp_scale, color):
61
+ for i in range(size):
62
+ t = i / size
63
+ params = interpolation_tensor.InterpolationParams(t, i, size, slerp_scale, 0.0001)
64
+ point = origin + compute_linear(points - origin, params)
65
+ try:
66
+ turtle_tool.goto(tuple(float(p) * 100. for p in point))
67
+ turtle_tool.dot(5, color)
68
+ print(point)
69
+ except ValueError:
70
+ pass
71
+
72
+ sample(0, "black")
73
+ sample(1, "green")
74
+ sample(2, "blue")
75
+ sample(-1, "purple")
76
+ sample(-2, "orange")
77
+
78
+ for point in points:
79
+ turtle_tool.goto(tuple(float(p) * 100. for p in point))
80
+ turtle_tool.dot(5, "red")
81
+
82
+ turtle_tool.goto(tuple(float(p) * 100. for p in origin))
83
+ turtle_tool.dot(10, "red")
84
+
85
+ turtle_tool.goto(100000, 100000)
86
+ turtle_tool.dot()
87
+ tr.done()
z-2-prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ from modules import prompt_parser
4
+ from typing import NamedTuple, Union, ClassVar
5
+
6
+
7
+ class InterpolationParams(NamedTuple):
8
+ t: float
9
+ step: int
10
+ total_steps: int
11
+ slerp_scale: float
12
+ slerp_epsilon: float
13
+
14
+
15
+ class InterpolationTensor:
16
+ def __init__(self, sub_tensors, interpolation_function):
17
+ self.__sub_tensors = sub_tensors
18
+ self.__interpolation_function = interpolation_function
19
+
20
+ def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
21
+ cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
22
+ if params.slerp_scale != 0:
23
+ cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
24
+ return cond
25
+
26
+ def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
27
+ if self.__interpolation_function is None:
28
+ return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)
29
+
30
+ control_points = [
31
+ sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
32
+ for sub_tensor in self.__sub_tensors
33
+ ]
34
+
35
+ CondWrapper, control_points_values = conds_to_cp_values(control_points)
36
+ return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)
37
+
38
+ def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
39
+ schedule = None
40
+ for schedule in self.__sub_tensors:
41
+ if schedule.end_at_step >= step:
42
+ break
43
+
44
+ res = schedule.cond.extend_like(origin_cond, empty_cond)
45
+ if slerp_scale != 0:
46
+ res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
47
+ return res
48
+
49
+
50
+ def conds_to_cp_values(conds):
51
+ CondWrapper = type(conds[0])
52
+ cp_values = [cond.to_cp_values() for cond in conds]
53
+ # Если словарные кондишены — зафиксируем порядок ключей первого элемента
54
+ if CondWrapper is DictCondWrapper:
55
+ DictCondWrapper.set_keys(list(conds[0].original_cond.keys()))
56
+ transposed = [[v[i] for v in cp_values] for i in range(len(cp_values[0]))]
57
+ return CondWrapper, transposed
58
+
59
+
60
+ class InterpolationTensorBuilder:
61
+ def __init__(self, tensor=None, prompt_database=None, interpolation_functions=None):
62
+ self.__indices_tensor = tensor if tensor is not None else 0
63
+ self.__prompt_database = prompt_database if prompt_database is not None else ['']
64
+ self.__interpolation_functions = interpolation_functions if interpolation_functions is not None else []
65
+
66
+ def append(self, suffix):
67
+ for i in range(len(self.__prompt_database)):
68
+ self.__prompt_database[i] += suffix
69
+
70
+ def extrude(self, tensor_updaters, interpolation_function):
71
+ extruded_indices_tensor = []
72
+ extruded_prompt_database = []
73
+ extruded_interpolation_functions = []
74
+
75
+ for update_tensor in tensor_updaters:
76
+ nested_tensor_builder = InterpolationTensorBuilder(
77
+ self.__indices_tensor,
78
+ self.__prompt_database[:],
79
+ interpolation_functions=[])
80
+
81
+ update_tensor(nested_tensor_builder)
82
+
83
+ extruded_indices_tensor.append(InterpolationTensorBuilder.__offset_tensor(
84
+ tensor=nested_tensor_builder.__indices_tensor,
85
+ offset=len(extruded_prompt_database)))
86
+ extruded_prompt_database.extend(nested_tensor_builder.__prompt_database)
87
+ extruded_interpolation_functions.append(nested_tensor_builder.__interpolation_functions)
88
+
89
+ self.__indices_tensor = extruded_indices_tensor
90
+ self.__prompt_database[:] = extruded_prompt_database
91
+ self.__interpolation_functions.insert(0, (interpolation_function, extruded_interpolation_functions))
92
+
93
+ def get_prompt_database(self):
94
+ return self.__prompt_database
95
+
96
+ @staticmethod
97
+ def __offset_tensor(tensor, offset):
98
+ try:
99
+ return tensor + offset
100
+
101
+ except TypeError:
102
+ return [InterpolationTensorBuilder.__offset_tensor(e, offset) for e in tensor]
103
+
104
+ def build(self, conds, empty_cond):
105
+ max_cond_size = self.__max_cond_size(conds)
106
+ conds = self.__resize_uniformly(conds, max_cond_size, empty_cond)
107
+ return InterpolationTensorBuilder.__build_conditionings_tensor(self.__indices_tensor, self.__interpolation_functions, conds)
108
+
109
+ @staticmethod
110
+ def __build_conditionings_tensor(tensor, int_funcs, conds):
111
+ if type(tensor) is int:
112
+ return InterpolationTensor(conds[tensor], None)
113
+ else:
114
+ int_func, nested_int_funcs = int_funcs[0]
115
+ return InterpolationTensor(
116
+ [
117
+ InterpolationTensorBuilder.__build_conditionings_tensor(sub_tensor, nested_int_funcs + int_funcs[1:], conds)
118
+ for sub_tensor, nested_int_funcs in zip(tensor, nested_int_funcs)
119
+ ],
120
+ int_func,
121
+ )
122
+
123
+ def __resize_uniformly(self, conds, max_cond_size: int, empty_cond):
124
+ return [
125
+ [
126
+ prompt_parser.ScheduledPromptConditioning(
127
+ cond=schedule.cond.resize_schedule(max_cond_size, empty_cond),
128
+ end_at_step=schedule.end_at_step
129
+ )
130
+ for schedule in schedules
131
+ ]
132
+ for schedules in conds
133
+ ]
134
+
135
+ @staticmethod
136
+ def __max_cond_size(conds):
137
+ return max(schedule.cond.size(0)
138
+ for schedules in conds
139
+ for schedule in schedules)
140
+
141
+
142
+ @dataclasses.dataclass
143
+ class DictCondWrapper:
144
+ original_cond: dict
145
+
146
+ # Порядок ключей для восстановления из контрольных значений
147
+ _keys: ClassVar[list] = None
148
+
149
+ @classmethod
150
+ def set_keys(cls, keys: list):
151
+ cls._keys = list(keys)
152
+
153
+ @classmethod
154
+ def from_cp_values(cls, cp_values):
155
+ keys = cls._keys or ['crossattn', 'vector']
156
+ return cls({k: v for k, v in zip(keys, cp_values)})
157
+
158
+ def size(self, *args, **kwargs):
159
+ return self.original_cond['crossattn'].size(*args, **kwargs)
160
+
161
+ def extend_like(self, that, empty):
162
+ missing_size = max(0, that.size(0) - self.size(0)) // 77
163
+ extended = DictCondWrapper(self.original_cond.copy())
164
+ extended.original_cond['crossattn'] = torch.cat([self.original_cond['crossattn']] + [empty.original_cond['crossattn']] * missing_size)
165
+ return extended
166
+
167
+ def resize_schedule(self, target_size, empty_cond):
168
+ cond_missing_size = (target_size - self.size(0)) // 77
169
+ if cond_missing_size <= 0:
170
+ return self
171
+
172
+ resized_cond = self.original_cond.copy()
173
+ resized_cond['crossattn'] = torch.cat([self.original_cond['crossattn']] + [empty_cond.original_cond['crossattn']] * cond_missing_size)
174
+ return DictCondWrapper(resized_cond)
175
+
176
+ def to_cp_values(self):
177
+ keys = self._keys or list(self.original_cond.keys())
178
+ return [self.original_cond[k] for k in keys]
179
+
180
+ def to(self, dtype: Union[dict, torch.dtype]):
181
+ if not isinstance(dtype, dict):
182
+ dtype = {
183
+ k: dtype
184
+ for k in self.original_cond.keys()
185
+ }
186
+ return DictCondWrapper({
187
+ k: v.to(dtype=dtype[k])
188
+ for k, v in self.original_cond.items()
189
+ })
190
+
191
+ @property
192
+ def dtype(self):
193
+ return {
194
+ k: v.dtype
195
+ for k, v in self.original_cond.items()
196
+ }
197
+
198
+ def __sub__(self, that):
199
+ return DictCondWrapper({
200
+ k: v - that.original_cond[k]
201
+ for k, v in self.original_cond.items()
202
+ })
203
+
204
+ def __add__(self, that):
205
+ return DictCondWrapper({
206
+ k: v + that.original_cond[k]
207
+ for k, v in self.original_cond.items()
208
+ })
209
+
210
+ def __eq__(self, that):
211
+ return all(torch.equal(self.original_cond[k], that.original_cond[k]) for k in self.original_cond.keys())
212
+
213
+
214
+ @dataclasses.dataclass
215
+ class TensorCondWrapper:
216
+ original_cond: torch.Tensor
217
+
218
+ @staticmethod
219
+ def from_cp_values(cp_values):
220
+ return TensorCondWrapper(next(iter(cp_values)))
221
+
222
+ def size(self, *args, **kwargs):
223
+ return self.original_cond.size(*args, **kwargs)
224
+
225
+ def extend_like(self, that, empty):
226
+ missing_size = max(0, that.size(0) - self.original_cond.size(0)) // 77
227
+ return TensorCondWrapper(torch.cat([self.original_cond] + [empty.original_cond] * missing_size))
228
+
229
+ def resize_schedule(self, target_size, empty_cond):
230
+ cond_missing_size = (target_size - self.original_cond.size(0)) // 77
231
+ if cond_missing_size <= 0:
232
+ return self
233
+
234
+ return TensorCondWrapper(torch.cat([self.original_cond] + [empty_cond.original_cond] * cond_missing_size))
235
+
236
+ def to_cp_values(self):
237
+ return [self.original_cond]
238
+
239
+ def to(self, dtype: torch.dtype):
240
+ return TensorCondWrapper(self.original_cond.to(dtype=dtype))
241
+
242
+ @property
243
+ def dtype(self):
244
+ return self.original_cond.dtype
245
+
246
+ def __sub__(self, that):
247
+ return TensorCondWrapper(self.original_cond - that.original_cond)
248
+
249
+ def __add__(self, that):
250
+ return TensorCondWrapper(self.original_cond + that.original_cond)
251
+
252
+ def __eq__(self, that):
253
+ return torch.equal(self.original_cond, that.original_cond)
z-2-prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib_prompt_fusion.ast_nodes as ast
2
+ from collections import namedtuple
3
+ import re
4
+
5
+
6
+ ParseResult = namedtuple('ParseResult', ['prompt', 'expr'])
7
+
8
+
9
+ def parse_prompt(prompt):
10
+ prompt = prompt.lstrip()
11
+ prompt, list_expr = parse_list_expression(prompt, set())
12
+ return list_expr
13
+
14
+
15
+ def parse_list_expression(prompt, stoppers):
16
+ exprs = []
17
+ try:
18
+ while True:
19
+ prompt, expr = parse_expression(prompt, stoppers)
20
+ exprs.append(expr)
21
+ except ValueError:
22
+ return ParseResult(prompt=prompt, expr=ast.ListExpression(exprs))
23
+
24
+
25
+ def parse_expression(prompt, stoppers):
26
+ for parse in _parsers():
27
+ try:
28
+ return parse(prompt, stoppers)
29
+ except ValueError:
30
+ pass
31
+ raise ValueError
32
+
33
+
34
+ def _parsers():
35
+ return (
36
+ parse_text,
37
+ parse_declaration,
38
+ parse_substitution,
39
+ parse_positive_attention,
40
+ parse_negative_attention,
41
+ parse_editing,
42
+ parse_alternation,
43
+ parse_interpolation,
44
+ parse_unrestricted_text,
45
+ )
46
+
47
+
48
+ def parse_text(prompt, stoppers):
49
+ return parse_unrestricted_text(prompt, set_concat(stoppers, {'[', '(', '$'}))
50
+
51
+
52
+ def parse_unrestricted_text(prompt, stoppers):
53
+ escaped_stoppers = ''.join(re.escape(stopper) for stopper in stoppers)
54
+ regex = rf'(?:[^{escaped_stoppers}\\\s]|\$(?![a-zA-Z_])|\\.)+'
55
+ prompt, expr = parse_token(prompt, whitespace_tail_regex(regex, stoppers))
56
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(expr))
57
+
58
+
59
+ def parse_substitution(prompt, stoppers):
60
+ prompt, symbol = parse_symbol(prompt, stoppers)
61
+ prompt, arguments = parse_arguments(prompt, stoppers)
62
+ return ParseResult(prompt=prompt, expr=ast.SubstitutionExpression(symbol, arguments))
63
+
64
+
65
+ def parse_arguments(prompt, stoppers):
66
+ try:
67
+ prompt, _ = parse_open_paren(prompt, stoppers)
68
+ prompt, arguments = parse_inner_arguments(prompt, stoppers)
69
+ prompt, _ = parse_close_paren(prompt, stoppers)
70
+ except ValueError:
71
+ arguments = []
72
+ return ParseResult(prompt=prompt, expr=arguments)
73
+
74
+
75
+ def parse_inner_arguments(prompt, stoppers):
76
+ arguments = []
77
+ try:
78
+ while True:
79
+ prompt, arg = parse_list_expression(prompt, {',', ')'})
80
+ arguments.append(arg)
81
+ prompt, _ = parse_comma(prompt, stoppers)
82
+ except ValueError:
83
+ pass
84
+ return ParseResult(prompt=prompt, expr=arguments)
85
+
86
+
87
+ def parse_declaration(prompt, stoppers):
88
+ prompt, symbol = parse_symbol(prompt, stoppers)
89
+ prompt, parameters = parse_parameters(prompt, stoppers)
90
+ prompt, _ = parse_equals(prompt, stoppers)
91
+ prompt, value = parse_list_expression(prompt, set_concat(stoppers, '\n'))
92
+ prompt, _ = parse_newline(prompt, stoppers)
93
+ prompt, expr = parse_list_expression(prompt, stoppers)
94
+ return ParseResult(prompt=prompt, expr=ast.DeclarationExpression(symbol, parameters, value, expr))
95
+
96
+
97
+ def parse_parameters(prompt, stoppers):
98
+ try:
99
+ prompt, _ = parse_open_paren(prompt, stoppers)
100
+ prompt, parameters = parse_inner_parameters(prompt, stoppers)
101
+ prompt, _ = parse_close_paren(prompt, stoppers)
102
+ except ValueError:
103
+ parameters = []
104
+ return ParseResult(prompt=prompt, expr=parameters)
105
+
106
+
107
+ def parse_inner_parameters(prompt, stoppers):
108
+ parameters = []
109
+ try:
110
+ while True:
111
+ prompt, param = parse_symbol(prompt, stoppers)
112
+ parameters.append(param)
113
+ prompt, _ = parse_comma(prompt, stoppers)
114
+ except ValueError:
115
+ pass
116
+ return ParseResult(prompt=prompt, expr=parameters)
117
+
118
+
119
+ def parse_interpolation(prompt, stoppers):
120
+ prompt, _ = parse_open_square(prompt, stoppers)
121
+ prompt, exprs = parse_interpolation_exprs(prompt, stoppers)
122
+ prompt, steps = parse_interpolation_steps(prompt, stoppers)
123
+ prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
124
+ prompt, _ = parse_close_square(prompt, stoppers)
125
+ return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))
126
+
127
+
128
+ def parse_interpolation_exprs(prompt, stoppers):
129
+ exprs = []
130
+ try:
131
+ while True:
132
+ prompt_tmp, expr = parse_list_expression(prompt, {':', ']'})
133
+ if parse_interpolation_function_name(prompt_tmp, stoppers).expr is not None:
134
+ raise ValueError
135
+ prompt, _ = parse_colon(prompt_tmp, stoppers)
136
+ exprs.append(expr)
137
+ except ValueError:
138
+ pass
139
+ return ParseResult(prompt=prompt, expr=exprs)
140
+
141
+
142
+ def parse_interpolation_function_name(prompt, stoppers):
143
+ try:
144
+ prompt, _ = parse_colon(prompt, stoppers)
145
+ function_names = ('linear', 'catmull', 'bezier', 'mean')
146
+ return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
147
+ except ValueError:
148
+ return ParseResult(prompt=prompt, expr=None)
149
+
150
+
151
+ def parse_interpolation_steps(prompt, stoppers):
152
+ steps = []
153
+ try:
154
+ while True:
155
+ prompt, step = parse_interpolation_step(prompt, stoppers)
156
+ steps.append(step)
157
+ prompt, _ = parse_comma(prompt, stoppers)
158
+ except ValueError:
159
+ pass
160
+ return ParseResult(prompt=prompt, expr=steps)
161
+
162
+
163
+ def parse_interpolation_step(prompt, stoppers):
164
+ try:
165
+ return parse_step(prompt, stoppers)
166
+ except ValueError:
167
+ pass
168
+ # Безопасная проверка границы
169
+ if not prompt or prompt[0] in {',', ':', ']'}:
170
+ return ParseResult(prompt=prompt, expr=None)
171
+ raise ValueError
172
+
173
+
174
+ def parse_alternation(prompt, stoppers):
175
+ prompt, _ = parse_open_square(prompt, stoppers)
176
+ prompt, exprs = parse_alternation_exprs(prompt, stoppers)
177
+ prompt, speed = parse_alternation_speed(prompt, stoppers)
178
+ prompt, _ = parse_close_square(prompt, stoppers)
179
+ return ParseResult(prompt=prompt, expr=ast.AlternationExpression(exprs, speed))
180
+
181
+
182
+ def parse_alternation_exprs(prompt, stoppers):
183
+ exprs = []
184
+ try:
185
+ while True:
186
+ prompt, expr = parse_list_expression(prompt, {'|', ':', ']'})
187
+ exprs.append(expr)
188
+ prompt, _ = parse_vertical_bar(prompt, stoppers)
189
+ except ValueError:
190
+ if len(exprs) < 2:
191
+ raise
192
+ return ParseResult(prompt=prompt, expr=exprs)
193
+
194
+
195
+ def parse_alternation_speed(prompt, stoppers):
196
+ try:
197
+ prompt, _ = parse_colon(prompt, stoppers)
198
+ prompt, speed = parse_step(prompt, stoppers)
199
+ return ParseResult(prompt=prompt, expr=speed)
200
+ except ValueError:
201
+ pass
202
+ return ParseResult(prompt=prompt, expr=None)
203
+
204
+
205
+ def parse_editing(prompt, stoppers):
206
+ prompt, _ = parse_open_square(prompt, stoppers)
207
+ prompt, exprs = parse_editing_exprs(prompt, stoppers)
208
+ try:
209
+ prompt, step = parse_step(prompt, stoppers)
210
+ except ValueError:
211
+ step = None
212
+ prompt, _ = parse_close_square(prompt, stoppers)
213
+ return ParseResult(prompt=prompt, expr=ast.EditingExpression(exprs, step))
214
+
215
+
216
+ def parse_editing_exprs(prompt, stoppers):
217
+ exprs = []
218
+ try:
219
+ for _ in range(2):
220
+ prompt_tmp, expr = parse_list_expression(prompt, {'|', ':', ']'})
221
+ prompt, _ = parse_colon(prompt_tmp, stoppers)
222
+ exprs.append(expr)
223
+ except ValueError:
224
+ pass
225
+ return ParseResult(prompt=prompt, expr=exprs)
226
+
227
+
228
+ def parse_negative_attention(prompt, stoppers):
229
+ prompt, _ = parse_open_square(prompt, stoppers)
230
+ prompt, expr = parse_list_expression(prompt, set_concat(stoppers, {':', ']'}))
231
+ prompt, _ = parse_close_square(prompt, stoppers)
232
+ return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, positive=False))
233
+
234
+
235
+ def parse_positive_attention(prompt, stoppers):
236
+ prompt, _ = parse_open_paren(prompt, stoppers)
237
+ prompt, expr = parse_list_expression(prompt, {':', ')'})
238
+ prompt, weight_exprs = parse_attention_weights(prompt, stoppers)
239
+ prompt, _ = parse_close_paren(prompt, stoppers)
240
+ if len(weight_exprs) >= 2:
241
+ return ParseResult(prompt=prompt, expr=ast.WeightInterpolationExpression(expr, *weight_exprs[:2]))
242
+ else:
243
+ return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, *weight_exprs[:1]))
244
+
245
+
246
+ def parse_attention_weights(prompt, stoppers):
247
+ weights = []
248
+ try:
249
+ prompt, _ = parse_colon(prompt, stoppers)
250
+ except ValueError:
251
+ return ParseResult(prompt=prompt, expr=weights)
252
+ while True:
253
+ try:
254
+ prompt, weight_expr = parse_weight(prompt, stoppers)
255
+ weights.append(weight_expr)
256
+ prompt, _ = parse_comma(prompt, stoppers)
257
+ except ValueError:
258
+ return ParseResult(prompt=prompt, expr=weights)
259
+
260
+
261
+ def parse_step(prompt, stoppers):
262
+ try:
263
+ prompt, step = parse_int_not_float(prompt, stoppers)
264
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
265
+ except ValueError:
266
+ pass
267
+ try:
268
+ prompt, step = parse_float(prompt, stoppers)
269
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
270
+ except ValueError:
271
+ pass
272
+ return parse_substitution(prompt, stoppers)
273
+
274
+
275
+ def parse_weight(prompt, stoppers):
276
+ try:
277
+ prompt, step = parse_float(prompt, stoppers)
278
+ return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
279
+ except ValueError:
280
+ pass
281
+ return parse_substitution(prompt, stoppers)
282
+
283
+
284
+ def parse_symbol(prompt, stoppers):
285
+ prompt, _ = parse_dollar(prompt)
286
+ return parse_symbol_text(prompt, stoppers)
287
+
288
+
289
+ def parse_symbol_text(prompt, stoppers):
290
+ return parse_token(prompt, whitespace_tail_regex('[a-zA-Z_][a-zA-Z0-9_]*', stoppers))
291
+
292
+
293
+ def parse_float(prompt, stoppers):
294
+ return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))
295
+
296
+
297
+ def parse_int_not_float(prompt, stoppers):
298
+ return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))
299
+
300
+
301
+ def parse_dollar(prompt):
302
+ dollar_sign = re.escape('$')
303
+ return parse_token(prompt, f'({dollar_sign})')
304
+
305
+
306
+ def parse_equals(prompt, stoppers):
307
+ return parse_token(prompt, whitespace_tail_regex(re.escape('='), stoppers))
308
+
309
+
310
+ def parse_comma(prompt, stoppers):
311
+ return parse_token(prompt, whitespace_tail_regex(re.escape(','), stoppers))
312
+
313
+
314
+ def parse_colon(prompt, stoppers):
315
+ return parse_token(prompt, whitespace_tail_regex(re.escape(':'), stoppers))
316
+
317
+
318
+ def parse_vertical_bar(prompt, stoppers):
319
+ return parse_token(prompt, whitespace_tail_regex(re.escape('|'), stoppers))
320
+
321
+
322
+ def parse_open_square(prompt, stoppers):
323
+ return parse_token(prompt, whitespace_tail_regex(re.escape('['), stoppers))
324
+
325
+
326
+ def parse_close_square(prompt, stoppers):
327
+ return parse_token(prompt, whitespace_tail_regex(re.escape(']'), stoppers))
328
+
329
+
330
+ def parse_open_paren(prompt, stoppers):
331
+ return parse_token(prompt, whitespace_tail_regex(re.escape('('), stoppers))
332
+
333
+
334
+ def parse_close_paren(prompt, stoppers):
335
+ return parse_token(prompt, whitespace_tail_regex(re.escape(')'), stoppers))
336
+
337
+
338
+ def parse_newline(prompt, stoppers):
339
+ return parse_token(prompt, whitespace_tail_regex('\n|$', stoppers))
340
+
341
+
342
+ def parse_token(prompt, regex):
343
+ match = re.match(regex, prompt)
344
+ if match is None:
345
+ raise ValueError
346
+ return ParseResult(prompt=prompt[len(match.group()):], expr=match.groups()[-1])
347
+
348
+
349
+ def whitespace_tail_regex(regex, stoppers):
350
+ if '\n' in stoppers:
351
+ return rf'({regex})[ \t\f\r]*'
352
+ return rf'({regex})\s*'
353
+
354
+
355
+ def set_concat(left, right):
356
+ result = set(left)
357
+ result.update(right)
358
+ return result
z-2-prompt-fusion-extension/lib_prompt_fusion/t_scaler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def scale_t(t, positions):
2
+ if t >= 1.:
3
+ return 1.
4
+
5
+ if t <= 0.:
6
+ return 0.
7
+
8
+ distances = []
9
+ for i in range(len(positions)-1):
10
+ distances.append(positions[i+1] - positions[i])
11
+
12
+ total_distance = sum(distances)
13
+ for i in range(len(distances)):
14
+ distances[i] = distances[i]/total_distance
15
+
16
+ for i in range(len(distances)-1):
17
+ distances[i+1] = distances[i] + distances[i+1]
18
+
19
+ distances.insert(0, 0.0)
20
+
21
+ spline_index = 0
22
+ for i, distance in enumerate(distances):
23
+ if t > distance:
24
+ spline_index = i
25
+ else:
26
+ break
27
+
28
+ if spline_index >= len(distances) - 1:
29
+ return 1
30
+
31
+ local_ratio = (t - distances[spline_index]) / (distances[spline_index+1] - distances[spline_index])
32
+ return (spline_index + local_ratio)/(len(distances)-1)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ total_steps = 20
37
+ for i in range(total_steps):
38
+ print(i, scale_t(i/total_steps, [9, 10]))
z-2-prompt-fusion-extension/metadata.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [Extension]
2
+ Name = prompt-fusion
z-2-prompt-fusion-extension/readme.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prompt Fusion
2
+
3
+ Prompt Fusion is an [auto1111 webui extension](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions) that adds more possibilities to the native prompt syntax. Among other additions, it allows to interpolate between the embeddings of different prompts, continuously:
4
+
5
+ ```
6
+ # linear prompt interpolation
7
+ [night light:magical forest: 5, 15]
8
+
9
+ # catmull-rom curve prompt interpolation
10
+ [night light:magical forest:norvegian territory: 5, 15, 25:catmull]
11
+
12
+ # alternation interpolation
13
+ [ufo|a strange sight:0.5]
14
+
15
+ # linear attention interpolation
16
+ (fire extinguisher: 1.0, 2.0)
17
+
18
+ # prompt-editing-aware attention interpolation
19
+ [(fire extinguisher: 1.0, 2.0)::5]
20
+
21
+ # weighted sum of conditions
22
+ [space station : kitchen mixer :: mean]
23
+
24
+ # define functions and variables to simplify repeating patterns and use a consistent structure
25
+ $prompt($style, $quality, $character, $background) = (
26
+ A detailed picture in the style of $style,
27
+ $quality,
28
+ $character lying back,
29
+ $background in the background
30
+ :1)
31
+ ```
32
+
33
+ ## Features
34
+
35
+ - [Prompt interpolation using a curve function](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Interpolation)
36
+ - [Attention interpolation aware of contextual prompt editing](https://github.com/ljleb/prompt-fusion-extension/wiki/Attention-Interpolation)
37
+ - [Alternation interpolation](https://github.com/ljleb/prompt-fusion-extension/wiki/Alternation-interpolation)
38
+ - [Prompt weighted sum](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Average)
39
+ - [Prompt variables and functions](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Variables)
40
+ - Complete backwards compatibility with the builtin prompt syntax of the webui
41
+
42
+ The prompt interpolation feature is similar to [Prompt Travel](https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel), which allows to create videos of images generated by navigating the latent space iteratively. Unlike Prompt Travel however, instead of generating multiple images, Prompt Fusion allows you to travel during the sampling process of *a single image*. Also, instead of interpolating the latent space, it uses the embedding space to determine intermediate embeddings.
43
+
44
+ Prompt interpolation is also similar to [Prompt Blending](https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts). The main difference is that this extension calculates a new embedding for every step, as opposed to calculating it once and using that same one embedding for all the steps.
45
+
46
+ The attention interpolation feature is similar to [Shift Attention](https://github.com/yownas/shift-attention), which allows to generate multiple images with slight variations in the attention given to certain parts of the prompt. Unlike Shift Attention, instead of generating multiple images, Prompt Fusion allows to shift the attention of certain parts of a prompt during the sampling process of *a single image*.
47
+
48
+ ## Usage
49
+ - Check the [wiki pages](https://github.com/ljleb/fusion/wiki) for the extension documentation.
50
+
51
+ ## Examples
52
+
53
+ ### 1. Influencing the beginning of the sampling process
54
+
55
+ Interpolate linearly (by default) from `lion` (step 0) to `bird` (step 8) to `girl` (step 11), and stay at `girl` for the rest of the sampling steps:
56
+
57
+ ```
58
+ [lion:bird:girl: , 7, 10]
59
+ ```
60
+
61
+ ![curve1](https://user-images.githubusercontent.com/32277961/214725976-b72bafc6-0c5d-4491-9c95-b73da41da082.gif)
62
+
63
+ ### 2. Influencing the middle of the sampling process
64
+
65
+ Interpolate using a bezier curve from `fireball monster` (step 0) to `dragon monster` (step 12, because 30 steps * 0.4 = step 12), while using `seawater monster` as an intermediate control point to steer the curve away during interpolation and to get creative results:
66
+
67
+ ```
68
+ [fireball:seawater:dragon: , .1, .4:bezier] monster
69
+ ```
70
+
71
+ ![curve2](https://user-images.githubusercontent.com/32277961/214941229-2dccad78-f856-42bb-ae6b-16b65b273cda.gif)
72
+
73
+ ## Webui supported releases
74
+
75
+ The following webui releases are officially supported:
76
+ - `v1.0.0-pre`
77
+ - `master` (there may be a slight lag for issues arising during quick a1111 webui updates)
78
+
79
+ ## Installation
80
+ 1. Visit the **Extensions** tab of Automatic's WebUI.
81
+ 2. Visit the **Available** subtab.
82
+ 3. Look for **Prompt Fusion**.
83
+ 4. Press the **Install** button.
84
+ 5. Wait for the webui to finish downloading the extension.
85
+ 6. Visit the **Installed** subtab.
86
+ 7. click on **Apply and restart UI**.
87
+
88
+ Alternatively, instead of steps 6 and 7, you can restart the webui completely.
89
+
90
+ ## Related Projects
91
+
92
+ - Prompt Travel: https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel
93
+ - Shift Attention: https://github.com/yownas/shift-attention
94
+ - Prompt Blending: https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts
95
+
z-2-prompt-fusion-extension/requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
z-2-prompt-fusion-extension/scripts/__pycache__/promptlang.cpython-310.pyc ADDED
Binary file (5.58 kB). View file
 
z-2-prompt-fusion-extension/scripts/promptlang.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from lib_prompt_fusion import hijacker, empty_cond, global_state, interpolation_tensor, prompt_parser as prompt_fusion_parser
3
+ from modules import scripts, script_callbacks, prompt_parser, shared
4
+
5
+
6
+ fusion_hijacker_attribute = '__fusion_hijacker'
7
+ prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
8
+ module=prompt_parser,
9
+ hijacker_attribute=fusion_hijacker_attribute,
10
+ register_uninstall=script_callbacks.on_script_unloaded)
11
+
12
+
13
+ def on_ui_settings():
14
+ section = ('prompt-fusion', 'Prompt Fusion')
15
+ shared.opts.add_option('prompt_fusion_enabled', shared.OptionInfo(True, 'Enable prompt-fusion extension', section=section))
16
+ shared.opts.add_option('prompt_fusion_slerp_scale', shared.OptionInfo(0, 'Slerp scale (0 = linear geometry, 1 = slerp geometry)', component=gr.Number, section=section))
17
+ shared.opts.add_option('prompt_fusion_slerp_negative_origin', shared.OptionInfo(True, 'use negative prompt as slerp origin', section=section))
18
+ shared.opts.add_option('prompt_fusion_slerp_epsilon', shared.OptionInfo(0.0001, 'Slerp epsilon (fallback on linear geometry when conds are too similar. 0 = parallel, 1 = perpendicular)', component=gr.Number, section=section))
19
+
20
+
21
+ script_callbacks.on_ui_settings(on_ui_settings)
22
+
23
+
24
+ @prompt_parser_hijacker.hijack('get_learned_conditioning')
25
+ def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, original_function, **kwargs):
26
+ if not shared.opts.prompt_fusion_enabled:
27
+ return original_function(model, prompts, total_steps, *args, **kwargs)
28
+
29
+ hires_steps = args[0] if len(args) >= 1 else None
30
+ use_old_scheduling = args[1] if len(args) >= 2 else True
31
+ is_hires = hires_steps is not None
32
+ if is_hires:
33
+ real_total_steps = hires_steps
34
+ else:
35
+ real_total_steps = total_steps
36
+
37
+ if hasattr(prompts, 'is_negative_prompt'):
38
+ is_negative_prompt = prompts.is_negative_prompt
39
+ else:
40
+ is_negative_prompt = global_state.old_webui_is_negative
41
+
42
+ empty_cond.init(model)
43
+
44
+ tensor_builders = _parse_tensor_builders(prompts, real_total_steps, is_hires, use_old_scheduling)
45
+ if hasattr(prompt_parser, 'SdConditioning'):
46
+ empty_conditioning = prompt_parser.SdConditioning(prompts)
47
+ empty_conditioning.clear()
48
+ else:
49
+ empty_conditioning = []
50
+
51
+ flattened_prompts, consecutive_ranges = _get_flattened_prompts(tensor_builders, empty_conditioning)
52
+ extra_args = args[2:] if len(args) > 2 else ()
53
+ flattened_schedules = original_function(model, flattened_prompts, total_steps, hires_steps, use_old_scheduling, *extra_args, **kwargs)
54
+
55
+ if isinstance(flattened_schedules[0][0].cond, dict): # sdxl
56
+ CondWrapper = interpolation_tensor.DictCondWrapper
57
+ else:
58
+ CondWrapper = interpolation_tensor.TensorCondWrapper
59
+
60
+ flattened_schedules = [
61
+ [
62
+ prompt_parser.ScheduledPromptConditioning(cond=CondWrapper(schedule.cond), end_at_step=schedule.end_at_step)
63
+ for schedule in subschedules
64
+ ]
65
+ for subschedules in flattened_schedules
66
+ ]
67
+
68
+ cond_tensors = (tensor_builder.build(flattened_schedules[begin:end], empty_cond.get())
69
+ for begin, end, tensor_builder
70
+ in zip(consecutive_ranges[:-1], consecutive_ranges[1:], tensor_builders))
71
+
72
+ schedules = [_sample_tensor_schedules(cond_tensor, real_total_steps, is_hires)
73
+ for cond_tensor in cond_tensors]
74
+
75
+ if is_negative_prompt:
76
+ if hires_steps is not None:
77
+ global_state.negative_schedules_hires = schedules[0]
78
+ else:
79
+ global_state.negative_schedules = schedules[0]
80
+
81
+ schedules = [
82
+ [
83
+ prompt_parser.ScheduledPromptConditioning(cond=schedule.cond.original_cond, end_at_step=schedule.end_at_step)
84
+ for schedule in subschedules
85
+ ]
86
+ for subschedules in schedules
87
+ ]
88
+
89
+ return schedules
90
+
91
+
92
+ @prompt_parser_hijacker.hijack('get_multicond_learned_conditioning')
93
+ def _hijacked_get_multicond_learned_conditioning(*args, original_function, **kwargs):
94
+ res = original_function(*args, **kwargs)
95
+ global_state.old_webui_is_negative = False
96
+ return res
97
+
98
+
99
+ def _parse_tensor_builders(prompts, total_steps, is_hires, use_old_scheduling):
100
+ tensor_builders = []
101
+
102
+ for prompt in prompts:
103
+ expr = prompt_fusion_parser.parse_prompt(prompt)
104
+ tensor_builder = interpolation_tensor.InterpolationTensorBuilder()
105
+ expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires, use_old_scheduling)
106
+ tensor_builders.append(tensor_builder)
107
+
108
+ return tensor_builders
109
+
110
+
111
+ def _get_flattened_prompts(tensor_builders, flattened_prompts=None):
112
+ if flattened_prompts is None:
113
+ flattened_prompts = []
114
+ consecutive_ranges = [0]
115
+
116
+ for tensor_builder in tensor_builders:
117
+ flattened_prompts.extend(tensor_builder.get_prompt_database())
118
+ consecutive_ranges.append(len(flattened_prompts))
119
+
120
+ return flattened_prompts, consecutive_ranges
121
+
122
+
123
+ def _sample_tensor_schedules(tensor, steps, is_hires):
124
+ schedules = []
125
+
126
+ for step in range(steps):
127
+ origin_cond = global_state.get_origin_cond_at(step, is_hires)
128
+ params = interpolation_tensor.InterpolationParams(step / steps, step, steps, global_state.get_slerp_scale(), global_state.get_slerp_epsilon())
129
+ schedule_cond = tensor.interpolate(params, origin_cond, empty_cond.get())
130
+ if schedules and schedules[-1].cond == schedule_cond:
131
+ schedules[-1] = prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedules[-1].cond)
132
+ else:
133
+ schedules.append(prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedule_cond))
134
+
135
+ return schedules
136
+
137
+
138
+ class PromptFusionScript(scripts.Script):
139
+ def title(self):
140
+ return 'Prompt Fusion'
141
+
142
+ def show(self, is_img2img):
143
+ return scripts.AlwaysVisible
144
+
145
+ def process(self, p, *args):
146
+ global_state.negative_schedules = None
147
+ global_state.old_webui_is_negative = True
z-2-prompt-fusion-extension/test/parser_tests.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion.prompt_parser import parse_prompt
2
+ from lib_prompt_fusion.interpolation_tensor import InterpolationTensorBuilder
3
+
4
+
5
+ def run_functional_tests(total_steps=100):
6
+ for i, (given, expected) in enumerate(functional_parse_test_cases):
7
+ expr = parse_prompt(given)
8
+ tensor_builder = InterpolationTensorBuilder()
9
+ expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
10
+
11
+ actual = tensor_builder.get_prompt_database()
12
+
13
+ if type(expected) is set:
14
+ assert set(actual) == expected, f"{actual} != {expected}"
15
+ else:
16
+ assert len(actual) == 1 and actual[0] == expected, f"'{actual[0]}' != '{expected}'"
17
+
18
+
19
+ functional_parse_test_cases = [
20
+ ('single',)*2,
21
+ ('some space separated text',)*2,
22
+ ('(legacy weighted prompt:-2.1)',)*2,
23
+ ('mixed (legacy weight:3.6) and text',)*2,
24
+ ('legacy [range begin:0] thingy',)*2,
25
+ ('legacy [range end::3] thingy',)*2,
26
+ ('legacy [[nested range::3]:2] thingy',)*2,
27
+ ('legacy [[nested range:2]::3] thingy',)*2,
28
+ ('sugar [range:,abc:3] thingy',)*2,
29
+ ('sugar [[(weight interpolation:0,12):0]::1] thingy', 'sugar [[(weight interpolation:0.0):0]::1] thingy'),
30
+ ('sugar [[(weight interpolation:0,12):0]::2] thingy', 'sugar [[[(weight interpolation:0.0)::1][(weight interpolation:12.0):1]:0]::2] thingy'),
31
+ ('sugar [[(weight interpolation:0,12):0]::3] thingy', 'sugar [[[(weight interpolation:0.0)::1][[(weight interpolation:6.0):1]::2][(weight interpolation:12.0):2]:0]::3] thingy'),
32
+ ('legacy [from:to:2] thingy',)*2,
33
+ ('legacy [negative weight]',)*2,
34
+ ('legacy (positive weight)',)*2,
35
+ ('[abc:1girl:2]',)*2,
36
+ ('[::]',)*2,
37
+ ('[a:b:]',)*2,
38
+ ('[[a:b:1,2]:b:]', {'[a:b:]', '[b:b:]'}),
39
+ ('1girl',)*2,
40
+ ('dashes-in-text',)*2,
41
+ ('text, separated with, comas',)*2,
42
+ ('{prompt}',)*2,
43
+ ('[abc|def ghi|jkl]',)*2,
44
+ ('merging this AND with this',)*2,
45
+ (':',)*2,
46
+ (r'portrait \(object\)',)*2,
47
+ (r'\[escaped square\]',)*2,
48
+ (r'\$var = abc',)*2,
49
+ (r'\\$ arst',)*2,
50
+ (r'$$ arst',)*2,
51
+ ('$var = abc', ''),
52
+ ('$a = prompt value\n$a', 'prompt value'),
53
+ ('$a = prompt value\n$b = $a\n$b', 'prompt value'),
54
+ ('$a = (multiline\nprompt\nvalue:1.0)\n$a', '(multiline prompt value:1.0)'),
55
+ ('$a = ($aa = nested variable\nmultiline\n$aa:1.0)\n$a', '(multiline nested variable:1.0)'),
56
+ ('a [b:c:-1, 10] d', {'a b d', 'a c d'}),
57
+ ('a [b:c:5, 6] d', {'a b d', 'a c d'}),
58
+ ('a [b:c:0.25, 0.5] d', {'a b d', 'a c d'}),
59
+ ('a [b:c:.25, .5] d', {'a b d', 'a c d'}),
60
+ ('a [b:c:,] d', {'a b d', 'a c d'}),
61
+ ('0[1.0:1.1:,]2[3.0:3.1:,]4', {
62
+ '0 1.0 2 3.0 4', '0 1.1 2 3.0 4',
63
+ '0 1.0 2 3.1 4', '0 1.1 2 3.1 4',
64
+ }),
65
+ ('0[1.0:1.1:1.2:,.5,]2[3.0:3.1:,]4', {
66
+ '0 1.0 2 3.0 4', '0 1.0 2 3.1 4',
67
+ '0 1.1 2 3.0 4', '0 1.1 2 3.1 4',
68
+ '0 1.2 2 3.0 4', '0 1.2 2 3.1 4',
69
+ }),
70
+ ('[0.0:0.1:,][1.0:1.1:,][2.0:2.1:,]', {
71
+ '0.0 1.0 2.0', '0.0 1.0 2.1',
72
+ '0.1 1.0 2.0', '0.1 1.0 2.1',
73
+ '0.0 1.1 2.0', '0.0 1.1 2.1',
74
+ '0.1 1.1 2.0', '0.1 1.1 2.1',
75
+ }),
76
+ ('[top level:interpolatin:lik a pro:1,3,5:linear]', {'top level', 'interpolatin', 'lik a pro'}),
77
+ ('[[nested:expr:,]:abc:,]', {'nested', 'expr', 'abc'}),
78
+ ('[(nested attention:2.0):abc:,]', {'(nested attention:2.0)', 'abc'}),
79
+ ('[[nested editing:15]:abc:,]', {'[nested editing:15]', 'abc'}),
80
+ ('[[nested interpolation:abc:,]:12]', {'[nested interpolation:12]', '[abc:12]'}),
81
+ ('[[nested interpolation:abc:,]::7]', {'[nested interpolation::7]', '[abc::7]'}),
82
+ ('$attention = 1.5\n(prompt:$attention)', '(prompt:1.5)'),
83
+ ('$a = 0\n$b = 12\n[[(prompt:$a,$b):0]::2]', '[[[(prompt:0.0)::1][(prompt:12.0):1]:0]::2]'),
84
+ ('$step = 5\n[legacy:editing:$step]', '[legacy:editing:5]'),
85
+ ('$begin = 2\n$end = 7\n[prompt:interpolation:$begin, $end]', {'prompt', 'interpolation'}),
86
+ ('$a($b, $c) = prompt with $b, prompt with $c\n$a(cat, dog)', 'prompt with cat , prompt with dog'),
87
+ ('$a($b) = prompt with $b\n$c($d) = yeay $a($d)\n$c(dog)', 'yeay prompt with dog'),
88
+ ('$a = a lot of animals\n$b($c) = I love $c\n$b($a)', 'I love a lot of animals'),
89
+ ('$a($b) = prompt with $b\n$c($d) = yeay $d\n$a($c(dog))', 'prompt with yeay dog'),
90
+ ('[a|b|c]', '[a|b|c]'),
91
+ ('[a|b|c:]', '[a|b|c]'),
92
+ ('[a|b|c:1]', {'a', 'b', 'c'}),
93
+ ('[a|b|c:2]', {'a', 'b', 'c'}),
94
+ ('[a|b|c:0.5]', {'a', 'b', 'c'}),
95
+ ('[a|b|c:1.1]', {'a', 'b', 'c'}),
96
+ ('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
97
+ ('[a:b:c::mean]', {'a', 'b', 'c'}),
98
+ ('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
99
+ ('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
100
+ ]
101
+
102
+
103
+ def run_tests():
104
+ run_functional_tests()
z-2-prompt-fusion-extension/test/run_all.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('..')
3
+ import parser_tests
4
+
5
+
6
+ if __name__ == '__main__':
7
+ parser_tests.run_tests()