dikdimon committed on
Commit
733b275
·
verified ·
1 Parent(s): 030cb3d

Upload z-patch_prompt-fusion-extension using SD-Hub

Browse files
z-patch_prompt-fusion-extension/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /venv/
2
+ /.idea/
3
+ __pycache__/
z-patch_prompt-fusion-extension/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2003
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
z-patch_prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from lib_prompt_fusion import interpolation_functions
3
+ from lib_prompt_fusion.t_scaler import scale_t
4
+ from lib_prompt_fusion import interpolation_tensor
5
+
6
+
7
class ListExpression:
    """A sequence of sibling expressions rendered one after another, separated by spaces."""

    def __init__(self, expressions):
        self.__expressions = expressions

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Extend the builder with every child expression, inserting ' ' between them."""
        for index, child in enumerate(self.__expressions):
            if index > 0:
                tensor_builder.append(' ')
            child.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
class InterpolationExpression:
    """AST node for `[a:b:...:steps:curve]` — morphs between sub-prompts over sampling steps."""

    @staticmethod
    def create(exprs, steps, function_name):
        # 'mean' is not a curve: it is handled by a dedicated weighted-average node.
        if function_name == "mean":
            return AverageExpression(exprs, steps)

        # Keep expressions and steps paired by truncating both to the shorter list.
        max_len = min(len(exprs), len(steps))
        exprs = exprs[:max_len]
        steps = steps[:max_len]

        return InterpolationExpression(exprs, steps, function_name)

    def __init__(self, expressions, steps, function_name=None):
        assert len(expressions) >= 2
        assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
        self.__expressions = expressions
        # Entries of `steps` may be None; they are defaulted/interpolated later.
        self.__steps = steps
        self.__function_name = function_name if function_name is not None else 'linear'

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        # Each sub-expression becomes one branch of a new interpolation axis.
        def tensor_updater(expr):
            return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.extrude(
            [tensor_updater(expr) for expr in self.__expressions],
            self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))

    def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Resolve step expressions to ints and return a (conds, params) -> cond sampler."""
        steps = list(self.__steps)
        # Missing endpoints default to the boundaries of the enclosing editing range.
        if steps[0] is None:
            steps[0] = LiftExpression(str(steps_range[0] - 1))
        if steps[-1] is None:
            steps[-1] = LiftExpression(str(steps_range[1] - 1))

        for i, step in enumerate(steps):
            if step is None:
                continue

            step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)

            # Fractional steps are relative to total_steps; which branch applies
            # depends on the webui scheduling mode in effect.
            if use_old_scheduling and 0 < step < 1:
                step *= total_steps
            elif not use_old_scheduling and isinstance(step, float):
                step = (step - int(is_hires)) * total_steps
            else:
                step += 1

            steps[i] = int(step)

        # Fill interior None entries by spacing them evenly between known neighbours.
        i = 1
        while i < len(steps):
            none_len = 0
            while steps[i + none_len] is None:
                none_len += 1

            min_step, max_step = steps[i - 1], steps[i + none_len]

            for j in range(none_len):
                steps[i + j] = min_step + (max_step - min_step) * (j + 1) / (none_len + 1)

            i += 1 + none_len

        interpolation_function = {
            'linear': interpolation_functions.compute_linear,
            'bezier': interpolation_functions.compute_bezier,
            'catmull': interpolation_functions.compute_catmull,
        }[self.__function_name]

        def steps_scale_t(conds, params: interpolation_tensor.InterpolationParams):
            # Remap the global t into the [steps[0], steps[-1]] window before sampling the curve.
            scaled_t = (params.t * total_steps - steps[0]) / max(1, steps[-1] - steps[0])
            scaled_t = scale_t(scaled_t, steps)

            new_params = interpolation_tensor.InterpolationParams(scaled_t, *params[1:])
            return interpolation_function(conds, new_params)

        return steps_scale_t
class AverageExpression:
    """Weighted average of sub-prompts (backs the `mean` interpolation function)."""

    def __init__(self, expressions, weights):
        if len(expressions) < len(weights):
            raise ValueError

        self.__expressions = expressions
        # Entries may be None (implicit weight); they are normalized later.
        self.__weights = weights

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        def tensor_updater(expr):
            return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.extrude(
            [tensor_updater(expr) for expr in self.__expressions],
            self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))

    def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Evaluate weights to floats and return a function blending conds by them."""
        weights = [
            _eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
            for weight in self.__weights
        ]
        explicit_weights = [weight for weight in weights if weight is not None]
        # Explicit weights are normalized against their own sum and scaled by the
        # share of explicit entries; every missing weight contributes an equal 1/n.
        weights = [
            weight / sum(explicit_weights) * len(explicit_weights) / len(self.__expressions)
            if weight is not None
            else 1 / len(self.__expressions)
            for weight in weights
        ]
        weights.extend(1 / len(self.__expressions) for _ in range(len(self.__expressions) - len(weights)))

        def interpolation_function(conds, _params):
            # Weighted sum of the control-point conds.
            total = None
            for cond, weight in zip(conds, weights):
                cond *= weight
                if total is None:
                    total = cond
                else:
                    total += cond

            return total

        return interpolation_function
+ def __init__(self, expressions, speed):
148
+ self.__expressions = expressions
149
+ self.__speed = speed
150
+
151
+ def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
152
+ if self.__speed is None:
153
+ speed = None
154
+ else:
155
+ speed = _eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling)
156
+
157
+ if speed is None:
158
+ tensor_builder.append('[')
159
+ for expr_i, expr in enumerate(self.__expressions):
160
+ if expr_i >= 1:
161
+ tensor_builder.append('|')
162
+ expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
163
+ tensor_builder.append(']')
164
+ return
165
+
166
+ def tensor_updater(expr):
167
+ return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)
168
+
169
+ exprs = self.__expressions + [self.__expressions[0]]
170
+
171
+ tensor_builder.extrude(
172
+ [tensor_updater(expr) for expr in exprs],
173
+ self.get_interpolation_function(speed, exprs, steps_range, total_steps))
174
+
175
+ def get_interpolation_function(self, speed, exprs, steps_range, total_steps):
176
+ def compute_wrap(control_points, params: interpolation_tensor.InterpolationParams):
177
+ wrapped_t = math.fmod((params.t * total_steps - steps_range[0]) / (len(exprs) - 1) * speed, 1.0)
178
+ if wrapped_t < 0:
179
+ wrapped_t = wrapped_t + 1
180
+ new_params = interpolation_tensor.InterpolationParams(wrapped_t, *params[1:])
181
+ return interpolation_functions.compute_linear(control_points, new_params)
182
+
183
+ return compute_wrap
184
+
185
+
186
class EditingExpression:
    """`[from:to:step]` node: switches between sub-prompts at a resolved step."""

    def __init__(self, expressions, step):
        assert 1 <= len(expressions) <= 2
        self.__expressions = expressions
        # None means "no explicit step": emitted as bare `[...:]` text.
        self.__step = step

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        if self.__step is None:
            tensor_builder.append('[')
            for expr_i, expr in enumerate(self.__expressions):
                expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
                tensor_builder.append(':')
            tensor_builder.append(']')
            return

        step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
        # `step_int` is the absolute step used to split the editing range; the
        # original `step` value is what gets written back into the prompt text below.
        step_int = step
        if use_old_scheduling and 0 < step < 1:
            step_int *= total_steps
        elif not use_old_scheduling and isinstance(step, float):
            step_int = (step_int - int(is_hires)) * total_steps
        else:
            step_int += 1

        step_int = int(step_int)

        tensor_builder.append('[')
        for expr_i, expr in enumerate(self.__expressions):
            # The first child is active before the switch step, the rest after it.
            expr_steps_range = (steps_range[0], step_int) if expr_i == 0 and len(self.__expressions) >= 2 else (step_int, steps_range[1])
            expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
            tensor_builder.append(':')

        tensor_builder.append(f'{step}]')
class WeightedExpression:
    """Attention node: renders `(nested:weight)` when positive, `[nested]` otherwise."""

    def __init__(self, nested, weight=None, positive=True):
        self.__nested = nested
        if not positive:
            # Negative attention (`[...]`) never carries an explicit weight.
            assert weight is None

        self.__weight = weight
        self.__positive = positive

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Wrap the nested expression in the appropriate bracket pair."""
        brackets = ('(', ')') if self.__positive else ('[', ']')
        tensor_builder.append(brackets[0])
        self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)

        if self.__weight is not None:
            tensor_builder.append(':')
            self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.append(brackets[1])
class WeightInterpolationExpression:
    """Ramps the attention weight of `nested` from `weight_begin` to `weight_end` over the range."""

    def __init__(self, nested, weight_begin, weight_end):
        self.__nested = nested
        # Missing endpoints default to the neutral weight 1.0.
        self.__weight_begin = weight_begin if weight_begin is not None else LiftExpression(str(1.))
        self.__weight_end = weight_end if weight_end is not None else LiftExpression(str(1.))

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        steps_range_size = steps_range[1] - steps_range[0]

        weight_begin = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
        weight_end = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)

        # Emit one weighted copy of the nested prompt per step, wrapped in editing
        # expressions so that only the copy belonging to that step is active.
        for i in range(steps_range_size):
            step = i + steps_range[0]

            # Linear ramp; max(..., 1) avoids dividing by zero for a single-step range.
            weight = weight_begin + (weight_end - weight_begin) * (i / max(steps_range_size - 1, 1))
            weight_step_expr = WeightedExpression(self.__nested, LiftExpression(str(weight)))
            if step > steps_range[0]:
                weight_step_expr = EditingExpression([weight_step_expr], LiftExpression(str(step - 1)))
            if step + 1 < steps_range[1]:
                weight_step_expr = EditingExpression([weight_step_expr, ListExpression([])], LiftExpression(str(step)))

            weight_step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
class DeclarationExpression:
    """`$symbol(params) = value` binding whose scope is limited to `target`."""

    def __init__(self, symbol, parameters, value, target):
        self.__symbol = symbol
        self.__value = value
        self.__target = target
        self.__parameters = parameters

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Render `target` with the declared symbol added to a copy of the context."""
        scoped_context = {**context, self.__symbol: (self.__value, self.__parameters)}
        self.__target.extend_tensor(tensor_builder, steps_range, total_steps, scoped_context, is_hires, use_old_scheduling)
class SubstitutionExpression:
    """`$symbol(args)` reference: expands the bound expression with arguments in scope."""

    def __init__(self, symbol, arguments):
        self.__symbol = symbol
        self.__arguments = arguments

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Look up the declaration and render its body with arguments bound to parameters."""
        body, parameters = context[self.__symbol]
        call_context = dict(context)
        # Each argument is stored like a parameterless declaration: (expr, []).
        call_context.update(
            (parameter, (argument, []))
            for argument, parameter in zip(self.__arguments, parameters))
        body.extend_tensor(tensor_builder, steps_range, total_steps, call_context, is_hires, use_old_scheduling)
class LiftExpression:
    """Leaf node that lifts a literal string into the expression tree."""

    def __init__(self, value):
        self.__value = value

    def extend_tensor(self, tensor_builder, *_args, **_kwargs):
        """Append the raw value to every prompt in the builder; scheduling args are ignored."""
        tensor_builder.append(self.__value)
def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
    """Render `expression` into a scratch prompt and parse the text as int, else float."""
    scratch = ['']
    builder = interpolation_tensor.InterpolationTensorBuilder(prompt_database=scratch)
    expression.extend_tensor(builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
    rendered = scratch[0]
    try:
        return int(rendered)
    except ValueError:
        # Not a plain integer literal — fall back to float parsing.
        return float(rendered)
z-patch_prompt-fusion-extension/lib_prompt_fusion/empty_cond.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion import interpolation_tensor
2
+
3
+
4
+ _empty_cond = None
5
+
6
+
7
def get():
    """Return the cached empty-prompt cond wrapper (populated by `init`, else None)."""
    return _empty_cond
9
+
10
+
11
def init(model):
    """Compute and cache the cond for the empty prompt using `model`.

    Wraps the result so dict-valued conds and plain tensor conds expose the
    same wrapper interface.
    """
    global _empty_cond
    cond = model.get_learned_conditioning([''])
    if isinstance(cond, dict):
        # Keep only the first batch entry of each component.
        cond = interpolation_tensor.DictCondWrapper({k: v[0] for k, v in cond.items()})
    else:
        cond = interpolation_tensor.TensorCondWrapper(cond[0])

    _empty_cond = cond
z-patch_prompt-fusion-extension/lib_prompt_fusion/geometries.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from lib_prompt_fusion import interpolation_tensor
4
+
5
+
6
def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
    """Blend between lerp and a norm-preserving slerp of the two control points.

    `params.slerp_scale` mixes the spherical path into the linear one;
    near-parallel or near-opposite directions fall back to pure lerp.
    """
    p0, p1 = control_points
    p0_norm = torch.linalg.norm(p0)
    p1_norm = torch.linalg.norm(p1)

    # Cosine similarity of the normalized points, clamped into [-1, 1].
    similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
    similarity = min(1., max(-1., float(similarity)))
    if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
        # Degenerate angle: slerp is ill-conditioned, use plain lerp.
        return linear_geometry(control_points, params)

    angle = math.acos(float(similarity)) / 2

    # Reparameterize t through the tangent so equal t increments sweep equal angles.
    slerp_t = angle * (2 * params.t - 1)
    slerp_t = math.tan(slerp_t) / math.tan(angle)
    slerp_t = (slerp_t + 1) / 2

    # Interpolate direction at p0's magnitude, then rescale the norm linearly in t.
    normalized_p1 = p1 / p1_norm * p0_norm
    slerp_p = p0 + (normalized_p1 - p0) * slerp_t
    slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)

    lerp_p = linear_geometry(control_points, params)
    return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
30
def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
    """Plain linear interpolation between the two control points at `params.t`."""
    start, end = control_points
    return start + (end - start) * params.t
z-patch_prompt-fusion-extension/lib_prompt_fusion/global_state.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Literal
2
+ import os
3
+ from modules import shared, prompt_parser
4
+ from lib_prompt_fusion import empty_cond
5
+
6
+
7
+ old_webui_is_negative: bool = False
8
+ negative_schedules: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
9
+ negative_schedules_hires: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
10
+
11
+
12
def get_origin_cond_at(step: int, is_hires: bool = False):
    """Return the negative-prompt cond scheduled at `step`, or the empty cond.

    The negative schedules are only consulted when the
    `prompt_fusion_slerp_negative_origin` option is enabled.
    """
    schedules = negative_schedules_hires if is_hires else negative_schedules
    if schedules and shared.opts.data.get('prompt_fusion_slerp_negative_origin', False):
        # The first schedule still active at this step wins.
        for schedule in schedules:
            if schedule.end_at_step >= step:
                return schedule.cond
    return empty_cond.get()
22
+
23
+
24
def get_slerp_scale():
    """Slerp mix factor from settings; 0.0 means no spherical blending."""
    return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
26
+
27
+
28
def get_slerp_epsilon():
    """Similarity tolerance used to detect degenerate slerp angles."""
    return shared.opts.data.get('prompt_fusion_slerp_epsilon', 0.0001)
30
+
31
def get_neg_origin_mode() -> Literal["first","linear","spherical"]:
    """
    Origin aggregation mode for the negative prompt:
    - first : current behaviour (the first sub-prompt)
    - linear : weighted linear averaging
    - spherical : "spherical" averaging (normalize, sum, normalize)
    Source: the UI setting first (if available), then the environment
    variable, defaulting to 'first'.
    """
    try:
        mode = (shared.opts.data.get("prompt_fusion_neg_origin_mode") or "first").lower()
    except Exception:
        # Settings unavailable (e.g. outside the webui) — fall back to the environment.
        mode = os.environ.get("PROMPT_FUSION_NEG_ORIGIN_MODE", "first").lower()
    if mode not in ("first","linear","spherical"):
        mode = "first"
    return mode
z-patch_prompt-fusion-extension/lib_prompt_fusion/hijacker.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class ModuleHijacker:
    """Swaps module attributes for wrappers while remembering the originals for restore."""

    def __init__(self, module):
        self.__module = module
        self.__original_functions = dict()

    def hijack(self, attribute):
        """Decorator factory replacing `module.attribute`.

        The decorated function receives the pre-hijack callable via the
        `original_function` keyword argument.
        """
        if attribute not in self.__original_functions:
            # Remember only the very first original, even across repeated hijacks.
            self.__original_functions[attribute] = getattr(self.__module, attribute)

        def decorator(function):
            def wrapper(*args, **kwargs):
                original = self.__original_functions[attribute]
                return function(*args, **kwargs, original_function=original)

            setattr(self.__module, attribute, wrapper)
            return function

        return decorator

    def reset_module(self):
        """Restore every hijacked attribute and forget the stored originals."""
        while self.__original_functions:
            attribute, original_function = self.__original_functions.popitem()
            setattr(self.__module, attribute, original_function)

    @staticmethod
    def install_or_get(module, hijacker_attribute, register_uninstall=lambda _callback: None):
        """Attach a hijacker to `module` under `hijacker_attribute`, or return the existing one."""
        if hasattr(module, hijacker_attribute):
            return getattr(module, hijacker_attribute)

        module_hijacker = ModuleHijacker(module)
        setattr(module, hijacker_attribute, module_hijacker)
        # Uninstall removes the marker attribute and restores hijacked functions.
        register_uninstall(lambda: delattr(module, hijacker_attribute))
        register_uninstall(module_hijacker.reset_module)
        return module_hijacker
z-patch_prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from lib_prompt_fusion import interpolation_tensor, geometries
4
+
5
+
6
+ def compute_linear(control_points, params: interpolation_tensor.InterpolationParams):
7
+ if len(control_points) <= 2:
8
+ return geometries.slerp_geometry(control_points, params)
9
+ else:
10
+ target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
11
+ cp0 = control_points[target_curve]
12
+ cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
13
+ new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
14
+ return geometries.slerp_geometry([cp0, cp1], new_params)
15
+
16
+
17
def compute_bezier(control_points, params: interpolation_tensor.InterpolationParams):
    """Bezier interpolation over `control_points` via De Casteljau's algorithm."""
    def compute_casteljau(ps, size):
        # Repeatedly interpolate adjacent points in place until one point remains.
        for i in reversed(range(1, size)):
            for j in range(i):
                ps[j] = geometries.slerp_geometry([ps[j], ps[j+1]], params)

        return ps[0]

    if len(control_points) == 1:
        return control_points[0]
    elif len(control_points) == 2:
        return geometries.slerp_geometry(control_points, params)
    # Deep copy because De Casteljau overwrites the list entries.
    copied_control_points = copy.deepcopy(control_points)
    return compute_casteljau(copied_control_points, len(copied_control_points))
33
def compute_catmull(control_points, params: interpolation_tensor.InterpolationParams):
    """Catmull-Rom-style spline through `control_points`, evaluated at `params.t`."""
    if len(control_points) <= 2:
        return compute_linear(control_points, params)
    else:
        # Locate the active segment plus its outer neighbours (mirrored at the ends).
        target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
        g0 = control_points[target_curve - 1] if target_curve > 0 else control_points[0] * 2 - control_points[1]
        cp0 = control_points[target_curve]
        cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
        g1 = control_points[target_curve + 2] if target_curve + 2 < len(control_points) else cp1 * 2 - cp0
        # Derive inner Bezier handles from the neighbouring points.
        ip0 = cp0 + (cp1 - g0)/6
        ip1 = cp1 + (cp0 - g1)/6

        # Evaluate the segment as a cubic Bezier with t re-normalized into it.
        new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
        return compute_bezier([cp0, ip0, ip1, cp1], new_params)
49
if __name__ == '__main__':
    # Visual smoke test: plot compute_linear at several slerp scales using turtle.
    import turtle as tr
    import torch
    size = 60
    turtle_tool = tr.Turtle()
    turtle_tool.speed(10)
    turtle_tool.up()

    # Two 2-D control points and an origin the curve is computed relative to.
    points = torch.Tensor([[-2., -2.], [2., 2.]])
    origin = torch.Tensor([1.5, 1.6])

    def sample(slerp_scale, color):
        # Draw `size` samples of the interpolation path for one slerp scale.
        for i in range(size):
            t = i / size
            params = interpolation_tensor.InterpolationParams(t, i, size, slerp_scale, 0.0001)
            point = origin + compute_linear(points - origin, params)
            try:
                turtle_tool.goto(tuple(float(p) * 100. for p in point))
                turtle_tool.dot(5, color)
                print(point)
            except ValueError:
                # Skip points turtle cannot plot.
                pass

    sample(0, "black")
    sample(1, "green")
    sample(2, "blue")
    sample(-1, "purple")
    sample(-2, "orange")

    # Mark the control points and the origin in red.
    for point in points:
        turtle_tool.goto(tuple(float(p) * 100. for p in point))
        turtle_tool.dot(5, "red")

    turtle_tool.goto(tuple(float(p) * 100. for p in origin))
    turtle_tool.dot(10, "red")

    # Park the turtle far off-screen and keep the window open.
    turtle_tool.goto(100000, 100000)
    turtle_tool.dot()
    tr.done()
z-patch_prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import torch
3
+ from modules import prompt_parser
4
+ from typing import NamedTuple, Union
5
+
6
+
7
class InterpolationParams(NamedTuple):
    """Per-sample parameters threaded through every interpolation/geometry function."""
    t: float  # normalized interpolation position
    step: int  # current sampling step
    total_steps: int  # total number of sampling steps
    slerp_scale: float  # mix factor between linear and spherical geometry
    slerp_epsilon: float  # similarity tolerance for the slerp fallback to lerp
15
class InterpolationTensor:
    """Tree of scheduled conds: leaves hold prompt schedules, inner nodes interpolate children."""

    def __init__(self, sub_tensors, interpolation_function):
        # For leaves, `interpolation_function` is None and `sub_tensors` is a
        # list of ScheduledPromptConditioning-like objects.
        self.__sub_tensors = sub_tensors
        self.__interpolation_function = interpolation_function

    def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
        """Evaluate the tensor at `params`, re-adding the origin when slerp is active."""
        cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
        if params.slerp_scale != 0:
            # Leaves subtracted the origin (see get_cond_point); add it back here.
            cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
        return cond

    def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
        # Leaf: simply pick the scheduled cond for this step.
        if self.__interpolation_function is None:
            return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)

        control_points = [
            sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
            for sub_tensor in self.__sub_tensors
        ]

        # Interpolate each cond component independently and rewrap the result.
        CondWrapper, control_points_values = conds_to_cp_values(control_points)
        return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)

    def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
        """Return the cond scheduled at `step`; origin-relative when slerp is active."""
        schedule = None
        # First schedule whose end_at_step covers `step` wins; falls through to
        # the last schedule otherwise.
        for schedule in self.__sub_tensors:
            if schedule.end_at_step >= step:
                break

        res = schedule.cond.extend_like(origin_cond, empty_cond)
        if slerp_scale != 0:
            # Work in float and relative to the origin for the slerp blending path.
            res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
        return res
50
def conds_to_cp_values(conds):
    """Transpose the conds' component lists.

    Returns the wrapper class of the first cond and, for each component index,
    the list of that component across all conds.
    """
    wrapper_class = type(conds[0])
    per_cond_values = [cond.to_cp_values() for cond in conds]
    component_count = len(per_cond_values[0])
    transposed = [
        [values[component] for values in per_cond_values]
        for component in range(component_count)
    ]
    return wrapper_class, transposed
62
class InterpolationTensorBuilder:
    """Accumulates prompt variants and a nested index tensor, then assembles an InterpolationTensor."""

    def __init__(self, tensor=None, prompt_database=None, interpolation_functions=None):
        # `__indices_tensor` is either an int (index into the prompt database)
        # or arbitrarily nested lists of ints mirroring the interpolation tree.
        self.__indices_tensor = tensor if tensor is not None else 0
        self.__prompt_database = prompt_database if prompt_database is not None else ['']
        self.__interpolation_functions = interpolation_functions if interpolation_functions is not None else []

    def append(self, suffix):
        """Append literal prompt text to every prompt variant."""
        for i in range(len(self.__prompt_database)):
            self.__prompt_database[i] += suffix

    def extrude(self, tensor_updaters, interpolation_function):
        """Fork the builder once per updater and merge the branches into a new axis."""
        extruded_indices_tensor = []
        extruded_prompt_database = []
        extruded_interpolation_functions = []

        for update_tensor in tensor_updaters:
            # Each branch starts from a copy of the current prompts and indices.
            nested_tensor_builder = InterpolationTensorBuilder(
                self.__indices_tensor,
                self.__prompt_database[:],
                interpolation_functions=[])

            update_tensor(nested_tensor_builder)

            # Shift the branch's indices so they address the merged database.
            extruded_indices_tensor.append(InterpolationTensorBuilder.__offset_tensor(
                tensor=nested_tensor_builder.__indices_tensor,
                offset=len(extruded_prompt_database)))
            extruded_prompt_database.extend(nested_tensor_builder.__prompt_database)
            extruded_interpolation_functions.append(nested_tensor_builder.__interpolation_functions)

        self.__indices_tensor = extruded_indices_tensor
        # Slice-assign so external holders of the database list see the update.
        self.__prompt_database[:] = extruded_prompt_database
        self.__interpolation_functions.insert(0, (interpolation_function, extruded_interpolation_functions))

    def get_prompt_database(self):
        """Return the (shared, mutable) list of accumulated prompt variants."""
        return self.__prompt_database

    @staticmethod
    def __offset_tensor(tensor, offset):
        # Ints shift directly; nested lists recurse (EAFP on `+`).
        try:
            return tensor + offset

        except TypeError:
            return [InterpolationTensorBuilder.__offset_tensor(e, offset) for e in tensor]

    def build(self, conds, empty_cond):
        """Assemble the InterpolationTensor from per-prompt cond schedules."""
        max_cond_size = self.__max_cond_size(conds)
        conds = self.__resize_uniformly(conds, max_cond_size, empty_cond)
        return InterpolationTensorBuilder.__build_conditionings_tensor(self.__indices_tensor, self.__interpolation_functions, conds)

    @staticmethod
    def __build_conditionings_tensor(tensor, int_funcs, conds):
        if type(tensor) is int:
            # Leaf: an index directly into the cond schedules.
            return InterpolationTensor(conds[tensor], None)
        else:
            # Inner node: consume the head interpolation function and recurse
            # with each branch's nested functions prepended.
            int_func, nested_int_funcs = int_funcs[0]
            return InterpolationTensor(
                [
                    InterpolationTensorBuilder.__build_conditionings_tensor(sub_tensor, nested_int_funcs + int_funcs[1:], conds)
                    for sub_tensor, nested_int_funcs in zip(tensor, nested_int_funcs)
                ],
                int_func,
            )

    def __resize_uniformly(self, conds, max_cond_size: int, empty_cond):
        # Pad every schedule's cond up to the common maximum size.
        return [
            [
                prompt_parser.ScheduledPromptConditioning(
                    cond=schedule.cond.resize_schedule(max_cond_size, empty_cond),
                    end_at_step=schedule.end_at_step
                )
                for schedule in schedules
            ]
            for schedules in conds
        ]

    @staticmethod
    def __max_cond_size(conds):
        return max(schedule.cond.size(0)
                   for schedules in conds
                   for schedule in schedules)
144
@dataclasses.dataclass
class DictCondWrapper:
    """Wraps a dict-valued cond (components keyed e.g. 'crossattn'/'vector') behind the common wrapper interface."""
    original_cond: dict

    @staticmethod
    def from_cp_values(cp_values):
        # Inverse of to_cp_values: rebuild the dict from ordered components.
        return DictCondWrapper({
            k: v
            for k, v in zip(('crossattn', 'vector'), cp_values)
        })

    def size(self, *args, **kwargs):
        # The wrapper's size is defined by the crossattn component.
        return self.original_cond['crossattn'].size(*args, **kwargs)

    def extend_like(self, that, empty):
        # Pad crossattn with whole empty-cond chunks (77 rows each) up to `that`'s size.
        missing_size = max(0, that.size(0) - self.size(0)) // 77
        extended = DictCondWrapper(self.original_cond.copy())
        extended.original_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty.original_cond['crossattn']] * missing_size)
        return extended

    def resize_schedule(self, target_size, empty_cond):
        """Return self, or a copy padded with empty-cond chunks up to `target_size`."""
        cond_missing_size = (target_size - self.size(0)) // 77
        if cond_missing_size <= 0:
            return self

        resized_cond = self.original_cond.copy()
        resized_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty_cond.original_cond['crossattn']] * cond_missing_size)
        return DictCondWrapper(resized_cond)

    def to_cp_values(self):
        # Ordered list of components; pairs with from_cp_values.
        return list(self.original_cond.values())

    def to(self, dtype: Union[dict, torch.dtype]):
        # Accept either one dtype for every component or a per-key mapping.
        if not isinstance(dtype, dict):
            dtype = {
                k: dtype
                for k in self.original_cond.keys()
            }
        return DictCondWrapper({
            k: v.to(dtype=dtype[k])
            for k, v in self.original_cond.items()
        })

    @property
    def dtype(self):
        # Per-key dtypes, mirroring the dict shape accepted by `to`.
        return {
            k: v.dtype
            for k, v in self.original_cond.items()
        }

    def __sub__(self, that):
        return DictCondWrapper({
            k: v - that.original_cond[k]
            for k, v in self.original_cond.items()
        })

    def __add__(self, that):
        return DictCondWrapper({
            k: v + that.original_cond[k]
            for k, v in self.original_cond.items()
        })

    def __eq__(self, that):
        return all((self.original_cond[k] == that.original_cond[k]).all() for k in self.original_cond.keys())
210
@dataclasses.dataclass
class TensorCondWrapper:
    """Wraps a plain conditioning tensor behind the common cond-wrapper interface."""
    original_cond: torch.Tensor

    @staticmethod
    def from_cp_values(cp_values):
        # A plain tensor cond has exactly one component.
        return TensorCondWrapper(next(iter(cp_values)))

    def size(self, *args, **kwargs):
        return self.original_cond.size(*args, **kwargs)

    def extend_like(self, that, empty):
        """Pad with copies of `empty`, one per missing 77-row chunk relative to `that`."""
        padding_chunks = max(0, that.size(0) - self.original_cond.size(0)) // 77
        padded = torch.concatenate([self.original_cond] + [empty.original_cond] * padding_chunks)
        return TensorCondWrapper(padded)

    def resize_schedule(self, target_size, empty_cond):
        """Return self, or a new wrapper padded with empty-cond chunks up to `target_size`."""
        missing_chunks = (target_size - self.original_cond.size(0)) // 77
        if missing_chunks <= 0:
            return self

        padded = torch.concatenate([self.original_cond] + [empty_cond.original_cond] * missing_chunks)
        return TensorCondWrapper(padded)

    def to_cp_values(self):
        # Single-component list; pairs with from_cp_values.
        return [self.original_cond]

    def to(self, dtype: torch.dtype):
        return TensorCondWrapper(self.original_cond.to(dtype=dtype))

    @property
    def dtype(self):
        return self.original_cond.dtype

    def __sub__(self, that):
        return TensorCondWrapper(self.original_cond - that.original_cond)

    def __add__(self, that):
        return TensorCondWrapper(self.original_cond + that.original_cond)

    def __eq__(self, that):
        return (self.original_cond == that.original_cond).all()
z-patch_prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib_prompt_fusion.ast_nodes as ast
2
+ from collections import namedtuple
3
+ import re
4
+
5
+
6
# Every parser returns this pair: the unconsumed remainder of the prompt
# plus the parsed expression (or token/value) for the consumed part.
ParseResult = namedtuple('ParseResult', ['prompt', 'expr'])
7
+
8
+
9
def parse_prompt(prompt):
    """Parse a full prompt string into a top-level ListExpression AST."""
    _remaining, expression = parse_list_expression(prompt.lstrip(), set())
    return expression
13
+
14
+
15
def parse_list_expression(prompt, stoppers):
    """Greedily parse consecutive expressions until none match.

    Never raises; the collected expressions are wrapped in a ListExpression.
    """
    parsed = []
    while True:
        try:
            prompt, expression = parse_expression(prompt, stoppers)
        except ValueError:
            return ParseResult(prompt=prompt, expr=ast.ListExpression(parsed))
        parsed.append(expression)
23
+
24
+
25
def parse_expression(prompt, stoppers):
    """Try each parser in precedence order; return the first success.

    Raises ValueError when no parser matches.
    """
    for candidate_parser in _parsers():
        try:
            return candidate_parser(prompt, stoppers)
        except ValueError:
            continue

    raise ValueError
33
+
34
+
35
def _parsers():
    # Parsers in precedence order; parse_expression tries them one by one.
    # parse_unrestricted_text is the catch-all fallback and must stay last.
    return (
        parse_text,
        parse_declaration,
        parse_substitution,
        parse_positive_attention,
        parse_negative_attention,
        parse_editing,
        parse_alternation,
        parse_interpolation,
        parse_unrestricted_text,
    )
47
+
48
+
49
def parse_text(prompt, stoppers):
    """Parse plain text; '[', '(' and '$' terminate it in addition to *stoppers*."""
    extended_stoppers = set_concat(stoppers, {'$', '(', '['})
    return parse_unrestricted_text(prompt, extended_stoppers)
51
+
52
+
53
def parse_unrestricted_text(prompt, stoppers):
    """Parse a run of raw text characters, honouring backslash escapes.

    A '$' only stops the text when followed by an identifier character,
    so literal dollar signs stay part of the text.
    """
    stopper_class = ''.join(map(re.escape, stoppers))
    text_regex = rf'(?:[^{stopper_class}\\\s]|\$(?![a-zA-Z_])|\\.)+'
    prompt, raw_text = parse_token(prompt, whitespace_tail_regex(text_regex, stoppers))
    return ParseResult(prompt=prompt, expr=ast.LiftExpression(raw_text))
58
+
59
+
60
def parse_substitution(prompt, stoppers):
    """Parse a variable/function use: '$name' optionally followed by '(args)'."""
    prompt, name = parse_symbol(prompt, stoppers)
    prompt, args = parse_arguments(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.SubstitutionExpression(name, args))
64
+
65
+
66
def parse_arguments(prompt, stoppers):
    """Parse an optional '(...)' argument list; absence yields [].

    On a partial failure any consumed input is kept (historical behaviour:
    the prompt variable is rebound as sub-parsers succeed).
    """
    args = []
    try:
        prompt, _ = parse_open_paren(prompt, stoppers)
        prompt, parsed_args = parse_inner_arguments(prompt, stoppers)
        prompt, _ = parse_close_paren(prompt, stoppers)
        args = parsed_args
    except ValueError:
        pass
    return ParseResult(prompt=prompt, expr=args)
74
+
75
+
76
def parse_inner_arguments(prompt, stoppers):
    """Parse comma-separated argument expressions; never raises."""
    args = []
    while True:
        prompt, argument = parse_list_expression(prompt, {',', ')'})
        args.append(argument)
        try:
            prompt, _ = parse_comma(prompt, stoppers)
        except ValueError:
            break

    return ParseResult(prompt=prompt, expr=args)
87
+
88
+
89
def parse_declaration(prompt, stoppers):
    """Parse '$name(params) = value<newline>rest' into a DeclarationExpression."""
    prompt, name = parse_symbol(prompt, stoppers)
    prompt, params = parse_parameters(prompt, stoppers)
    prompt, _ = parse_equals(prompt, stoppers)
    # The bound value ends at the newline; the rest of the prompt is the body
    # in which the declaration is visible.
    prompt, bound_value = parse_list_expression(prompt, set_concat(stoppers, '\n'))
    prompt, _ = parse_newline(prompt, stoppers)
    prompt, body = parse_list_expression(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.DeclarationExpression(name, params, bound_value, body))
97
+
98
+
99
def parse_parameters(prompt, stoppers):
    """Parse an optional '(...)' parameter-name list; absence yields [].

    Mirrors parse_arguments: partially consumed input is kept on failure.
    """
    params = []
    try:
        prompt, _ = parse_open_paren(prompt, stoppers)
        prompt, parsed_params = parse_inner_parameters(prompt, stoppers)
        prompt, _ = parse_close_paren(prompt, stoppers)
        params = parsed_params
    except ValueError:
        pass
    return ParseResult(prompt=prompt, expr=params)
107
+
108
+
109
def parse_inner_parameters(prompt, stoppers):
    """Parse comma-separated parameter symbols; never raises."""
    params = []
    while True:
        try:
            prompt, symbol = parse_symbol(prompt, stoppers)
        except ValueError:
            break
        params.append(symbol)
        try:
            prompt, _ = parse_comma(prompt, stoppers)
        except ValueError:
            break

    return ParseResult(prompt=prompt, expr=params)
120
+
121
+
122
def parse_interpolation(prompt, stoppers):
    """Parse '[a:b: steps[:curve]]' into an InterpolationExpression."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, sub_exprs = parse_interpolation_exprs(prompt, stoppers)
    prompt, steps = parse_interpolation_steps(prompt, stoppers)
    prompt, curve_name = parse_interpolation_function_name(prompt, stoppers)
    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(sub_exprs, steps, curve_name))
129
+
130
+
131
def parse_interpolation_exprs(prompt, stoppers):
    """Parse the colon-separated expressions at the head of an interpolation.

    Uses lookahead so a trailing ':curve' suffix is not swallowed as one
    more expression.
    """
    sub_exprs = []
    while True:
        lookahead, expression = parse_list_expression(prompt, {':', ']'})
        if parse_interpolation_function_name(lookahead, stoppers).expr is not None:
            break

        try:
            prompt, _ = parse_colon(lookahead, stoppers)
        except ValueError:
            break
        sub_exprs.append(expression)

    return ParseResult(prompt=prompt, expr=sub_exprs)
146
+
147
+
148
def parse_interpolation_function_name(prompt, stoppers):
    """Parse an optional ':curve' suffix of an interpolation expression.

    Returns expr=None when no curve name is present. Note: a consumed colon
    is not rolled back when no known curve name follows (the local `prompt`
    is already rebound); parse_interpolation_exprs' lookahead only inspects
    the expr value, so this does not affect it.
    """
    try:
        prompt, _ = parse_colon(prompt, stoppers)
        function_names = ('linear', 'catmull', 'bezier', 'mean')
        return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
    except ValueError:
        return ParseResult(prompt=prompt, expr=None)
155
+
156
+
157
def parse_interpolation_steps(prompt, stoppers):
    """Parse comma-separated interpolation step positions; never raises."""
    steps = []
    while True:
        try:
            prompt, step = parse_interpolation_step(prompt, stoppers)
        except ValueError:
            break
        steps.append(step)
        try:
            prompt, _ = parse_comma(prompt, stoppers)
        except ValueError:
            break

    return ParseResult(prompt=prompt, expr=steps)
169
+
170
+
171
def parse_interpolation_step(prompt, stoppers):
    """Parse one step value, or accept an omitted step (expr=None).

    An omitted step is recognized when the next character is one of
    ',', ':' or ']' (e.g. the empty slot in '[a:b: ,5]').
    """
    try:
        return parse_step(prompt, stoppers)
    except ValueError:
        pass

    # Guard the emptiness check: previously `prompt[0]` raised IndexError on
    # an exhausted prompt, escaping the parser's ValueError-based control flow.
    if prompt and prompt[0] in {',', ':', ']'}:
        return ParseResult(prompt=prompt, expr=None)

    raise ValueError
181
+
182
+
183
def parse_alternation(prompt, stoppers):
    """Parse '[a|b|...:speed]' into an AlternationExpression."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, alternatives = parse_alternation_exprs(prompt, stoppers)
    prompt, speed = parse_alternation_speed(prompt, stoppers)
    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.AlternationExpression(alternatives, speed))
189
+
190
+
191
def parse_alternation_exprs(prompt, stoppers):
    """Parse '|'-separated alternatives; at least two are required.

    Raises ValueError when fewer than two alternatives are found, so plain
    '[...]' syntax can fall through to other bracket parsers.
    """
    alternatives = []
    while True:
        prompt, expression = parse_list_expression(prompt, {'|', ':', ']'})
        alternatives.append(expression)
        try:
            prompt, _ = parse_vertical_bar(prompt, stoppers)
        except ValueError:
            if len(alternatives) < 2:
                raise
            break

    return ParseResult(prompt=prompt, expr=alternatives)
204
+
205
+
206
def parse_alternation_speed(prompt, stoppers):
    """Parse an optional ':speed' suffix of an alternation expression.

    Returns expr=None when absent. Note: a consumed colon is not rolled back
    when the speed itself fails to parse (the local `prompt` is already
    rebound) — historical behaviour preserved.
    """
    try:
        prompt, _ = parse_colon(prompt, stoppers)
        prompt, speed = parse_step(prompt, stoppers)
        return ParseResult(prompt=prompt, expr=speed)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=None)
215
+
216
+
217
def parse_editing(prompt, stoppers):
    """Parse prompt-editing syntax '[from:to:when]' into an EditingExpression."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, edited_exprs = parse_editing_exprs(prompt, stoppers)
    step = None
    try:
        prompt, step = parse_step(prompt, stoppers)
    except ValueError:
        pass

    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.EditingExpression(edited_exprs, step))
227
+
228
+
229
def parse_editing_exprs(prompt, stoppers):
    """Parse up to two ':'-terminated expressions for prompt editing.

    Each expression is only committed once its trailing colon parses, so a
    trailing step value is left for the caller.
    """
    committed = []
    for _ in (0, 1):
        lookahead, expression = parse_list_expression(prompt, {'|', ':', ']'})
        try:
            prompt, _ = parse_colon(lookahead, stoppers)
        except ValueError:
            break
        committed.append(expression)

    return ParseResult(prompt=prompt, expr=committed)
241
+
242
+
243
def parse_negative_attention(prompt, stoppers):
    """Parse '[...]' as attenuated (negative-weight) attention."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, inner_expr = parse_list_expression(prompt, set_concat(stoppers, {':', ']'}))
    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.WeightedExpression(inner_expr, positive=False))
248
+
249
+
250
def parse_positive_attention(prompt, stoppers):
    """Parse '(expr[:w1[,w2]])' into a weighted or weight-interpolated expression."""
    prompt, _ = parse_open_paren(prompt, stoppers)
    prompt, inner_expr = parse_list_expression(prompt, {':', ')'})
    prompt, weights = parse_attention_weights(prompt, stoppers)
    prompt, _ = parse_close_paren(prompt, stoppers)
    if len(weights) < 2:
        # Zero or one weight: plain emphasis.
        return ParseResult(prompt=prompt, expr=ast.WeightedExpression(inner_expr, *weights[:1]))
    # Two or more weights: interpolate between the first two.
    return ParseResult(prompt=prompt, expr=ast.WeightInterpolationExpression(inner_expr, *weights[:2]))
259
+
260
+
261
def parse_attention_weights(prompt, stoppers):
    """Parse ':w1,w2,...' after an attention expression; [] when no colon."""
    weights = []
    try:
        prompt, _ = parse_colon(prompt, stoppers)
    except ValueError:
        return ParseResult(prompt=prompt, expr=weights)

    while True:
        try:
            prompt, weight = parse_weight(prompt, stoppers)
        except ValueError:
            break
        weights.append(weight)
        try:
            prompt, _ = parse_comma(prompt, stoppers)
        except ValueError:
            break
    return ParseResult(prompt=prompt, expr=weights)
275
+
276
+
277
def parse_step(prompt, stoppers):
    """Parse a step value: an int, then a float, then a '$variable' substitution."""
    for number_parser in (parse_int_not_float, parse_float):
        try:
            remaining, value = number_parser(prompt, stoppers)
        except ValueError:
            continue
        return ParseResult(prompt=remaining, expr=ast.LiftExpression(value))

    return parse_substitution(prompt, stoppers)
290
+
291
+
292
def parse_weight(prompt, stoppers):
    """Parse an attention weight: a float literal or a '$variable' substitution."""
    try:
        remaining, value = parse_float(prompt, stoppers)
    except ValueError:
        return parse_substitution(prompt, stoppers)
    return ParseResult(prompt=remaining, expr=ast.LiftExpression(value))
300
+
301
+
302
def parse_symbol(prompt, stoppers):
    """Parse '$identifier' and return the identifier text."""
    after_dollar, _ = parse_dollar(prompt)
    return parse_symbol_text(after_dollar, stoppers)
305
+
306
+
307
def parse_symbol_text(prompt, stoppers):
    """Parse a bare identifier: a letter/underscore then word characters."""
    return parse_token(prompt, whitespace_tail_regex('[a-zA-Z_][a-zA-Z0-9_]*', stoppers))


def parse_float(prompt, stoppers):
    """Parse a signed decimal number ('1', '1.', '.5', '-2.1', ...)."""
    return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))


def parse_int_not_float(prompt, stoppers):
    """Parse a signed integer, rejecting it when a '.' follows (it is a float)."""
    return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))


def parse_dollar(prompt):
    """Parse a literal '$' without consuming trailing whitespace."""
    dollar_sign = re.escape('$')
    return parse_token(prompt, f'({dollar_sign})')
322
+
323
+
324
# Single-character token parsers. Each consumes the token plus trailing
# whitespace as defined by whitespace_tail_regex.

def parse_equals(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape('='), stoppers))


def parse_comma(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape(','), stoppers))


def parse_colon(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape(':'), stoppers))


def parse_vertical_bar(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape('|'), stoppers))


def parse_open_square(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape('['), stoppers))


def parse_close_square(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape(']'), stoppers))


def parse_open_paren(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape('('), stoppers))


def parse_close_paren(prompt, stoppers):
    return parse_token(prompt, whitespace_tail_regex(re.escape(')'), stoppers))


def parse_newline(prompt, stoppers):
    # '\n|$' also matches end-of-string so a final declaration line parses.
    return parse_token(prompt, whitespace_tail_regex('\n|$', stoppers))
358
+
359
+
360
def parse_token(prompt, regex):
    """Match *regex* at the start of *prompt*; raise ValueError on no match.

    Returns the remaining prompt and the last captured group.
    """
    matched = re.match(regex, prompt)
    if not matched:
        raise ValueError

    consumed = len(matched.group())
    return ParseResult(prompt=prompt[consumed:], expr=matched.groups()[-1])
366
+
367
+
368
def whitespace_tail_regex(regex, stoppers):
    """Wrap *regex* in a capture group and consume trailing whitespace.

    When '\\n' is a stopper, newlines must survive for the caller, so only
    intra-line whitespace is consumed.
    """
    tail = r'[ \t\f\r]*' if '\n' in stoppers else r'\s*'
    return rf'({regex}){tail}'
373
+
374
+
375
def set_concat(left, right):
    """Return a new set containing the elements of both iterables."""
    return set(left) | set(right)
z-patch_prompt-fusion-extension/lib_prompt_fusion/t_scaler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def scale_t(t, positions):
    """Remap normalized time *t* in [0, 1] according to keyframe *positions*.

    The spans between consecutive positions are normalized to fractions of
    the total range; *t* is located within its span and re-expressed as a
    uniform coordinate across the len(positions)-1 spans.
    """
    if t >= 1.:
        return 1.

    if t <= 0.:
        return 0.

    spans = [b - a for a, b in zip(positions, positions[1:])]
    total = sum(spans)

    # Cumulative normalized boundaries: [0.0, c1, c2, ..., ~1.0].
    boundaries = [0.0]
    running = 0.0
    for span in spans:
        running = running + span / total
        boundaries.append(running)

    # Find the last boundary strictly below t.
    segment = 0
    for i, boundary in enumerate(boundaries):
        if t > boundary:
            segment = i
        else:
            break

    if segment >= len(boundaries) - 1:
        return 1

    width = boundaries[segment + 1] - boundaries[segment]
    local_ratio = (t - boundaries[segment]) / width
    return (segment + local_ratio) / (len(boundaries) - 1)
33
+
34
+
35
if __name__ == "__main__":
    # Quick manual check: print the remapped t for every sampling step.
    total_steps = 20
    for step in range(total_steps):
        print(step, scale_t(step / total_steps, [9, 10]))
z-patch_prompt-fusion-extension/metadata.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [Extension]
2
+ Name = prompt-fusion
z-patch_prompt-fusion-extension/readme.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prompt Fusion
2
+
3
+ Prompt Fusion is an [auto1111 webui extension](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions) that adds more possibilities to the native prompt syntax. Among other additions, it allows you to interpolate continuously between the embeddings of different prompts:
4
+
5
+ ```
6
+ # linear prompt interpolation
7
+ [night light:magical forest: 5, 15]
8
+
9
+ # catmull-rom curve prompt interpolation
10
+ [night light:magical forest:norvegian territory: 5, 15, 25:catmull]
11
+
12
+ # alternation interpolation
13
+ [ufo|a strange sight:0.5]
14
+
15
+ # linear attention interpolation
16
+ (fire extinguisher: 1.0, 2.0)
17
+
18
+ # prompt-editing-aware attention interpolation
19
+ [(fire extinguisher: 1.0, 2.0)::5]
20
+
21
+ # weighted sum of conditions
22
+ [space station : kitchen mixer :: mean]
23
+
24
+ # define functions and variables to simplify repeating patterns and use a consistent structure
25
+ $prompt($style, $quality, $character, $background) = (
26
+ A detailed picture in the style of $style,
27
+ $quality,
28
+ $character lying back,
29
+ $background in the background
30
+ :1)
31
+ ```
32
+
33
+ ## Features
34
+
35
+ - [Prompt interpolation using a curve function](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Interpolation)
36
+ - [Attention interpolation aware of contextual prompt editing](https://github.com/ljleb/prompt-fusion-extension/wiki/Attention-Interpolation)
37
+ - [Alternation interpolation](https://github.com/ljleb/prompt-fusion-extension/wiki/Alternation-interpolation)
38
+ - [Prompt weighted sum](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Average)
39
+ - [Prompt variables and functions](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Variables)
40
+ - Complete backwards compatibility with the builtin prompt syntax of the webui
41
+
42
+ The prompt interpolation feature is similar to [Prompt Travel](https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel), which allows to create videos of images generated by navigating the latent space iteratively. Unlike Prompt Travel however, instead of generating multiple images, Prompt Fusion allows you to travel during the sampling process of *a single image*. Also, instead of interpolating the latent space, it uses the embedding space to determine intermediate embeddings.
43
+
44
+ Prompt interpolation is also similar to [Prompt Blending](https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts). The main difference is that this extension calculates a new embedding for every step, as opposed to calculating it once and using that same one embedding for all the steps.
45
+
46
+ The attention interpolation feature is similar to [Shift Attention](https://github.com/yownas/shift-attention), which allows to generate multiple images with slight variations in the attention given to certain parts of the prompt. Unlike Shift Attention, instead of generating multiple images, Prompt Fusion allows to shift the attention of certain parts of a prompt during the sampling process of *a single image*.
47
+
48
+ ## Usage
49
+ - Check the [wiki pages](https://github.com/ljleb/fusion/wiki) for the extension documentation.
50
+
51
+ ## Examples
52
+
53
+ ### 1. Influencing the beginning of the sampling process
54
+
55
+ Interpolate linearly (by default) from `lion` (step 0) to `bird` (step 8) to `girl` (step 11), and stay at `girl` for the rest of the sampling steps:
56
+
57
+ ```
58
+ [lion:bird:girl: , 7, 10]
59
+ ```
60
+
61
+ ![curve1](https://user-images.githubusercontent.com/32277961/214725976-b72bafc6-0c5d-4491-9c95-b73da41da082.gif)
62
+
63
+ ### 2. Influencing the middle of the sampling process
64
+
65
+ Interpolate using a bezier curve from `fireball monster` (step 0) to `dragon monster` (step 12, because 30 steps * 0.4 = step 12), while using `seawater monster` as an intermediate control point to steer the curve away during interpolation and to get creative results:
66
+
67
+ ```
68
+ [fireball:seawater:dragon: , .1, .4:bezier] monster
69
+ ```
70
+
71
+ ![curve2](https://user-images.githubusercontent.com/32277961/214941229-2dccad78-f856-42bb-ae6b-16b65b273cda.gif)
72
+
73
+ ## Webui supported releases
74
+
75
+ The following webui releases are officially supported:
76
+ - `v1.0.0-pre`
77
+ - `master` (fixes for issues caused by rapid a1111 webui updates may lag slightly behind)
78
+
79
+ ## Installation
80
+ 1. Visit the **Extensions** tab of Automatic's WebUI.
81
+ 2. Visit the **Available** subtab.
82
+ 3. Look for **Prompt Fusion**.
83
+ 4. Press the **Install** button.
84
+ 5. Wait for the webui to finish downloading the extension.
85
+ 6. Visit the **Installed** subtab.
86
+ 7. Click on **Apply and restart UI**.
87
+
88
+ Alternatively, instead of steps 6 and 7, you can restart the webui completely.
89
+
90
+ ## Related Projects
91
+
92
+ - Prompt Travel: https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel
93
+ - Shift Attention: https://github.com/yownas/shift-attention
94
+ - Prompt Blending: https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts
95
+
z-patch_prompt-fusion-extension/requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
z-patch_prompt-fusion-extension/scripts/promptlang.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch

from lib_prompt_fusion import hijacker, empty_cond, global_state, interpolation_tensor, prompt_parser as prompt_fusion_parser
from modules import scripts, script_callbacks, prompt_parser, shared
4
+
5
+
6
# Attribute used to mark prompt_parser as already hijacked, so repeated
# script reloads do not stack wrappers; the hook is removed on unload.
fusion_hijacker_attribute = '__fusion_hijacker'
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=prompt_parser,
    hijacker_attribute=fusion_hijacker_attribute,
    register_uninstall=script_callbacks.on_script_unloaded)
11
+
12
+
13
def on_ui_settings():
    """Register the prompt-fusion options in the webui settings panel."""
    section = ('prompt-fusion', 'Prompt Fusion')
    options = {
        'prompt_fusion_enabled': shared.OptionInfo(True, 'Enable prompt-fusion extension', section=section),
        'prompt_fusion_slerp_scale': shared.OptionInfo(0, 'Slerp scale (0 = linear geometry, 1 = slerp geometry)', component=gr.Number, section=section),
        'prompt_fusion_slerp_negative_origin': shared.OptionInfo(True, 'use negative prompt as slerp origin', section=section),
        'prompt_fusion_slerp_epsilon': shared.OptionInfo(0.0001, 'Slerp epsilon (fallback on linear geometry when conds are too similar. 0 = parallel, 1 = perpendicular)', component=gr.Number, section=section),
    }
    for option_key, option_info in options.items():
        shared.opts.add_option(option_key, option_info)
19
+
20
+
21
+ script_callbacks.on_ui_settings(on_ui_settings)
22
+
23
+
24
@prompt_parser_hijacker.hijack('get_learned_conditioning')
def _hijacked_get_learned_conditioning(model, prompts, total_steps, *args, original_function, **kwargs):
    """Replace webui's scheduled conditioning with fusion-interpolated schedules.

    Each prompt is parsed with the fusion grammar into an interpolation
    tensor builder; the original function computes conds for every flattened
    sub-prompt, and the tensors are then sampled once per step.
    """
    if not shared.opts.prompt_fusion_enabled:
        return original_function(model, prompts, total_steps, *args, **kwargs)

    # Newer webui passes (hires_steps, use_old_scheduling) positionally.
    hires_steps, use_old_scheduling, *_ = args if args else (None, True)
    is_hires = hires_steps is not None
    if is_hires:
        real_total_steps = hires_steps
    else:
        real_total_steps = total_steps

    # Newer webui tags the prompt list itself; older webui relies on the
    # flag toggled by PromptFusionScript.process.
    if hasattr(prompts, 'is_negative_prompt'):
        is_negative_prompt = prompts.is_negative_prompt
    else:
        is_negative_prompt = global_state.old_webui_is_negative

    empty_cond.init(model)

    tensor_builders = _parse_tensor_builders(prompts, real_total_steps, is_hires, use_old_scheduling)
    # Preserve SdConditioning metadata on the flat prompt list when the
    # class exists; otherwise fall back to a plain list.
    if hasattr(prompt_parser, 'SdConditioning'):
        empty_conditioning = prompt_parser.SdConditioning(prompts)
        empty_conditioning.clear()
    else:
        empty_conditioning = []

    flattened_prompts, consecutive_ranges = _get_flattened_prompts(tensor_builders, empty_conditioning)
    flattened_schedules = original_function(model, flattened_prompts, total_steps, *args, **kwargs)

    # SDXL conds are dicts of tensors; SD 1.x conds are plain tensors.
    if isinstance(flattened_schedules[0][0].cond, dict):  # sdxl
        CondWrapper = interpolation_tensor.DictCondWrapper
    else:
        CondWrapper = interpolation_tensor.TensorCondWrapper

    # Wrap every cond so interpolation treats both variants uniformly.
    flattened_schedules = [
        [
            prompt_parser.ScheduledPromptConditioning(cond=CondWrapper(schedule.cond), end_at_step=schedule.end_at_step)
            for schedule in subschedules
        ]
        for subschedules in flattened_schedules
    ]

    # One interpolation tensor per original prompt, built from its slice of
    # the flattened schedules.
    cond_tensors = (tensor_builder.build(flattened_schedules[begin:end], empty_cond.get())
                    for begin, end, tensor_builder
                    in zip(consecutive_ranges[:-1], consecutive_ranges[1:], tensor_builders))

    schedules = [_sample_tensor_schedules(cond_tensor, real_total_steps, is_hires)
                 for cond_tensor in cond_tensors]

    if is_negative_prompt:
        # Remember the negative schedule so it can later serve as the
        # interpolation origin (consumed via global_state).
        if hires_steps is not None:
            global_state.negative_schedules_hires = schedules[0]
        else:
            global_state.negative_schedules = schedules[0]

    # Unwrap back to raw conds before handing the schedules to the webui.
    schedules = [
        [
            prompt_parser.ScheduledPromptConditioning(cond=schedule.cond.original_cond, end_at_step=schedule.end_at_step)
            for schedule in subschedules
        ]
        for subschedules in schedules
    ]

    return schedules
88
+
89
+
90
+ def _pf_norm_lastdim(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
91
+ return x / (x.norm(dim=-1, keepdim=True) + eps)
92
+
93
+ def _pf_stack_weight(tensors, weights):
94
+ base = tensors[0]
95
+ w = torch.as_tensor(weights, device=base.device, dtype=base.dtype)
96
+ # -> [N, 1, 1, ...] для широковещательного умножения
97
+ w = w.view(-1, *([1] * base.dim()))
98
+ return torch.stack(tensors, dim=0), w
99
+
100
def _agg_tensors_linear(tensors, weights):
    """Weighted arithmetic mean of *tensors* (normalized by the weight sum)."""
    stacked, w = _pf_stack_weight(tensors, weights)
    weight_total = float(sum(weights)) + 1e-8
    return (stacked * w).sum(dim=0) / weight_total
103
+
104
def _agg_tensors_spherical(tensors, weights):
    """Weighted combination on the unit sphere: normalize, weighted-sum, re-normalize."""
    unit_tensors = [_pf_norm_lastdim(t) for t in tensors]
    stacked, w = _pf_stack_weight(unit_tensors, weights)
    weighted_sum = (stacked * w).sum(dim=0)
    return _pf_norm_lastdim(weighted_sum)
109
+
110
def _agg_cond_list(conds, weights, mode: str):
    """Aggregate a list of conds (plain tensors or SDXL dicts) with *weights*.

    mode: 'linear' for a weighted mean, anything else for spherical.
    SDXL conds are dicts (e.g. 'crossattn', pooled vector) whose values may
    themselves be lists of tensors, aggregated position-wise.
    """
    aggregate = _agg_tensors_linear if mode == "linear" else _agg_tensors_spherical

    if not isinstance(conds[0], dict):
        return aggregate(conds, weights)

    merged = {}
    for key in conds[0].keys():
        values = [cond[key] for cond in conds]
        if isinstance(values[0], list):
            # SDXL may store lists of tensors under one key.
            merged[key] = [
                aggregate([per_cond[i] for per_cond in values], weights)
                for i in range(len(values[0]))
            ]
        else:
            merged[key] = aggregate(values, weights)
    return merged
128
+
129
+ def _cond_at_step(schedule, step: int):
130
+ # schedule: List[ScheduledPromptConditioning]; step — граница сегмента
131
+ for sp in schedule:
132
+ if step <= sp.end_at_step:
133
+ return sp.cond
134
+ return schedule[-1].cond
135
+
136
def _aggregate_neg_schedule_for_prompt(conds_parts, schedules_flat, mode: str, total_steps: int, pp):
    """Aggregate the schedules of one (negative) prompt into a single schedule.

    conds_parts: [(flat_index, weight), ...] for ONE source prompt.
    schedules_flat: list of ScheduledPromptConditioning lists, indexed by flat_index.
    mode: 'linear' | 'spherical'.
    pp: the prompt_parser module (provides ScheduledPromptConditioning).
    Returns the aggregated schedule as a list of ScheduledPromptConditioning.
    """
    # Collect every segment boundary across all parts, plus the final step.
    step_boundaries = {total_steps}
    for flat_index, _weight in conds_parts:
        step_boundaries.update(int(sp.end_at_step) for sp in schedules_flat[flat_index])

    aggregated = []
    for end_step in sorted(step_boundaries):
        active_conds = []
        part_weights = []
        for flat_index, weight in conds_parts:
            active_conds.append(_cond_at_step(schedules_flat[flat_index], end_step))
            part_weights.append(float(weight))
        combined = _agg_cond_list(active_conds, part_weights, mode)
        aggregated.append(pp.ScheduledPromptConditioning(end_at_step=int(end_step), cond=combined))
    return aggregated
158
+
159
@prompt_parser_hijacker.hijack('get_multicond_learned_conditioning')
def _hijacked_get_multicond_learned_conditioning(model, prompts, steps, *args, original_function, **kwargs):
    """Build multicond conditioning through the fusion-aware single-cond path.

    Splits each prompt into AND-composable weighted parts, schedules every
    part via the (already hijacked) get_learned_conditioning, optionally
    aggregates the negative origin, and repacks the result as
    MulticondLearnedConditioning.
    """
    # If the extension is disabled, defer to the original implementation.
    if not shared.opts.prompt_fusion_enabled:
        res = original_function(model, prompts, steps, *args, **kwargs)
        global_state.old_webui_is_negative = False
        return res

    # hires/step handling (compatible with the webui signature)
    hires_steps, use_old_scheduling, *_ = args if args else (None, True)
    is_hires = hires_steps is not None
    real_total_steps = hires_steps if is_hires else steps

    # Split the multicond prompts into weighted parts, per webui version.
    try:
        conds_list, all_prompts, _ = prompt_parser.get_multicond_prompt_list(prompts)
    except Exception:
        # Fallback: treat every prompt as a single part with weight 1.
        conds_list = [[(i, 1.0)] for i in range(len(prompts))]
        all_prompts = list(prompts)

    # Build schedules for every part through the single-cond hijack
    # (fusion interpolation already happens inside).
    schedules_flat = prompt_parser.get_learned_conditioning(model, all_prompts, steps, *args, **kwargs)

    # --- Aggregate the negative origin according to the selected mode ---
    if global_state.old_webui_is_negative and conds_list:
        mode = global_state.get_neg_origin_mode()
        if mode in ("linear","spherical"):
            neg_parts = conds_list[0]  # [(flat_index, weight), ...]
            agg_sched = _aggregate_neg_schedule_for_prompt(
                conds_parts=neg_parts,
                schedules_flat=schedules_flat,
                mode=mode,
                total_steps=real_total_steps,
                pp=prompt_parser
            )
            if hires_steps is not None:
                global_state.negative_schedules_hires = agg_sched
            else:
                global_state.negative_schedules = agg_sched

    # Repack into the MulticondLearnedConditioning batch structure.
    batch = []
    for parts in conds_list:
        comp = []
        for flat_index, weight in parts:
            comp.append(prompt_parser.ComposableScheduledPromptConditioning(schedules_flat[flat_index], weight))
        batch.append(comp)

    # Pick the correct example shape for SDXL vs SD 1.x conds.
    if schedules_flat and schedules_flat[0]:
        example_cond = schedules_flat[0][0].cond
        if isinstance(example_cond, dict):  # SDXL
            ca = example_cond.get('crossattn')
            if isinstance(ca, list) and ca:
                shape = getattr(ca[0], 'shape', None) or (0,)
            else:
                shape = getattr(ca, 'shape', None) or (0,)
        else:  # SD 1.x
            shape = getattr(example_cond, 'shape', None) or (0,)
    else:
        shape = (0,)

    global_state.old_webui_is_negative = False
    return prompt_parser.MulticondLearnedConditioning(shape, batch)
223
+
224
+
225
def _parse_tensor_builders(prompts, total_steps, is_hires, use_old_scheduling):
    """Parse each prompt into an InterpolationTensorBuilder spanning all steps."""
    builders = []
    for prompt in prompts:
        parsed_expr = prompt_fusion_parser.parse_prompt(prompt)
        builder = interpolation_tensor.InterpolationTensorBuilder()
        parsed_expr.extend_tensor(builder, (0, total_steps), total_steps, dict(), is_hires, use_old_scheduling)
        builders.append(builder)

    return builders
235
+
236
+
237
+ def _get_flattened_prompts(tensor_builders, flattened_prompts=None):
238
+ if flattened_prompts is None:
239
+ flattened_prompts = []
240
+ consecutive_ranges = [0]
241
+
242
+ for tensor_builder in tensor_builders:
243
+ flattened_prompts.extend(tensor_builder.get_prompt_database())
244
+ consecutive_ranges.append(len(flattened_prompts))
245
+
246
+ return flattened_prompts, consecutive_ranges
247
+
248
+
249
def _sample_tensor_schedules(tensor, steps, is_hires):
    """Sample the interpolation tensor once per step, merging equal conds.

    Consecutive steps producing an equal cond are collapsed into one
    ScheduledPromptConditioning whose end_at_step keeps advancing.
    """
    schedules = []

    for step in range(steps):
        origin_cond = global_state.get_origin_cond_at(step, is_hires)
        params = interpolation_tensor.InterpolationParams(step / steps, step, steps, global_state.get_slerp_scale(), global_state.get_slerp_epsilon())
        schedule_cond = tensor.interpolate(params, origin_cond, empty_cond.get())
        if schedules and schedules[-1].cond == schedule_cond:
            # Same cond as the previous step: just extend the last segment.
            schedules[-1] = prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedules[-1].cond)
        else:
            schedules.append(prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=schedule_cond))

    return schedules
262
+
263
+
264
class PromptFusionScript(scripts.Script):
    """Always-on script hook that resets per-generation fusion state."""

    def title(self):
        return 'Prompt Fusion'

    def show(self, is_img2img):
        # Run for both txt2img and img2img, without exposing any UI.
        return scripts.AlwaysVisible

    def process(self, p, *args):
        # New generation: clear the cached negative schedule and mark that
        # the next conditioning call (old webui flow) is the negative prompt.
        global_state.negative_schedules = None
        global_state.old_webui_is_negative = True
z-patch_prompt-fusion-extension/test/parser_tests.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lib_prompt_fusion.prompt_parser import parse_prompt
2
+ from lib_prompt_fusion.interpolation_tensor import InterpolationTensorBuilder
3
+
4
+
5
+ def run_functional_tests(total_steps=100):
6
+ for i, (given, expected) in enumerate(functional_parse_test_cases):
7
+ expr = parse_prompt(given)
8
+ tensor_builder = InterpolationTensorBuilder()
9
+ expr.extend_tensor(tensor_builder, (0, total_steps), total_steps, dict(), is_hires=False, use_old_scheduling=False)
10
+
11
+ actual = tensor_builder.get_prompt_database()
12
+
13
+ if type(expected) is set:
14
+ assert set(actual) == expected, f"{actual} != {expected}"
15
+ else:
16
+ assert len(actual) == 1 and actual[0] == expected, f"'{actual[0]}' != '{expected}'"
17
+
18
+
19
+ functional_parse_test_cases = [
20
+ ('single',)*2,
21
+ ('some space separated text',)*2,
22
+ ('(legacy weighted prompt:-2.1)',)*2,
23
+ ('mixed (legacy weight:3.6) and text',)*2,
24
+ ('legacy [range begin:0] thingy',)*2,
25
+ ('legacy [range end::3] thingy',)*2,
26
+ ('legacy [[nested range::3]:2] thingy',)*2,
27
+ ('legacy [[nested range:2]::3] thingy',)*2,
28
+ ('sugar [range:,abc:3] thingy',)*2,
29
+ ('sugar [[(weight interpolation:0,12):0]::1] thingy', 'sugar [[(weight interpolation:0.0):0]::1] thingy'),
30
+ ('sugar [[(weight interpolation:0,12):0]::2] thingy', 'sugar [[[(weight interpolation:0.0)::1][(weight interpolation:12.0):1]:0]::2] thingy'),
31
+ ('sugar [[(weight interpolation:0,12):0]::3] thingy', 'sugar [[[(weight interpolation:0.0)::1][[(weight interpolation:6.0):1]::2][(weight interpolation:12.0):2]:0]::3] thingy'),
32
+ ('legacy [from:to:2] thingy',)*2,
33
+ ('legacy [negative weight]',)*2,
34
+ ('legacy (positive weight)',)*2,
35
+ ('[abc:1girl:2]',)*2,
36
+ ('[::]',)*2,
37
+ ('[a:b:]',)*2,
38
+ ('[[a:b:1,2]:b:]', {'[a:b:]', '[b:b:]'}),
39
+ ('1girl',)*2,
40
+ ('dashes-in-text',)*2,
41
+ ('text, separated with, comas',)*2,
42
+ ('{prompt}',)*2,
43
+ ('[abc|def ghi|jkl]',)*2,
44
+ ('merging this AND with this',)*2,
45
+ (':',)*2,
46
+ (r'portrait \(object\)',)*2,
47
+ (r'\[escaped square\]',)*2,
48
+ (r'\$var = abc',)*2,
49
+ (r'\\$ arst',)*2,
50
+ (r'$$ arst',)*2,
51
+ ('$var = abc', ''),
52
+ ('$a = prompt value\n$a', 'prompt value'),
53
+ ('$a = prompt value\n$b = $a\n$b', 'prompt value'),
54
+ ('$a = (multiline\nprompt\nvalue:1.0)\n$a', '(multiline prompt value:1.0)'),
55
+ ('$a = ($aa = nested variable\nmultiline\n$aa:1.0)\n$a', '(multiline nested variable:1.0)'),
56
+ ('a [b:c:-1, 10] d', {'a b d', 'a c d'}),
57
+ ('a [b:c:5, 6] d', {'a b d', 'a c d'}),
58
+ ('a [b:c:0.25, 0.5] d', {'a b d', 'a c d'}),
59
+ ('a [b:c:.25, .5] d', {'a b d', 'a c d'}),
60
+ ('a [b:c:,] d', {'a b d', 'a c d'}),
61
+ ('0[1.0:1.1:,]2[3.0:3.1:,]4', {
62
+ '0 1.0 2 3.0 4', '0 1.1 2 3.0 4',
63
+ '0 1.0 2 3.1 4', '0 1.1 2 3.1 4',
64
+ }),
65
+ ('0[1.0:1.1:1.2:,.5,]2[3.0:3.1:,]4', {
66
+ '0 1.0 2 3.0 4', '0 1.0 2 3.1 4',
67
+ '0 1.1 2 3.0 4', '0 1.1 2 3.1 4',
68
+ '0 1.2 2 3.0 4', '0 1.2 2 3.1 4',
69
+ }),
70
+ ('[0.0:0.1:,][1.0:1.1:,][2.0:2.1:,]', {
71
+ '0.0 1.0 2.0', '0.0 1.0 2.1',
72
+ '0.1 1.0 2.0', '0.1 1.0 2.1',
73
+ '0.0 1.1 2.0', '0.0 1.1 2.1',
74
+ '0.1 1.1 2.0', '0.1 1.1 2.1',
75
+ }),
76
+ ('[top level:interpolatin:lik a pro:1,3,5:linear]', {'top level', 'interpolatin', 'lik a pro'}),
77
+ ('[[nested:expr:,]:abc:,]', {'nested', 'expr', 'abc'}),
78
+ ('[(nested attention:2.0):abc:,]', {'(nested attention:2.0)', 'abc'}),
79
+ ('[[nested editing:15]:abc:,]', {'[nested editing:15]', 'abc'}),
80
+ ('[[nested interpolation:abc:,]:12]', {'[nested interpolation:12]', '[abc:12]'}),
81
+ ('[[nested interpolation:abc:,]::7]', {'[nested interpolation::7]', '[abc::7]'}),
82
+ ('$attention = 1.5\n(prompt:$attention)', '(prompt:1.5)'),
83
+ ('$a = 0\n$b = 12\n[[(prompt:$a,$b):0]::2]', '[[[(prompt:0.0)::1][(prompt:12.0):1]:0]::2]'),
84
+ ('$step = 5\n[legacy:editing:$step]', '[legacy:editing:5]'),
85
+ ('$begin = 2\n$end = 7\n[prompt:interpolation:$begin, $end]', {'prompt', 'interpolation'}),
86
+ ('$a($b, $c) = prompt with $b, prompt with $c\n$a(cat, dog)', 'prompt with cat , prompt with dog'),
87
+ ('$a($b) = prompt with $b\n$c($d) = yeay $a($d)\n$c(dog)', 'yeay prompt with dog'),
88
+ ('$a = a lot of animals\n$b($c) = I love $c\n$b($a)', 'I love a lot of animals'),
89
+ ('$a($b) = prompt with $b\n$c($d) = yeay $d\n$a($c(dog))', 'prompt with yeay dog'),
90
+ ('[a|b|c]', '[a|b|c]'),
91
+ ('[a|b|c:]', '[a|b|c]'),
92
+ ('[a|b|c:1]', {'a', 'b', 'c'}),
93
+ ('[a|b|c:2]', {'a', 'b', 'c'}),
94
+ ('[a|b|c:0.5]', {'a', 'b', 'c'}),
95
+ ('[a|b|c:1.1]', {'a', 'b', 'c'}),
96
+ ('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
97
+ ('[a:b:c::mean]', {'a', 'b', 'c'}),
98
+ ('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
99
+ ('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
100
+ ]
101
+
102
+
103
+ def run_tests():
104
+ run_functional_tests()
z-patch_prompt-fusion-extension/test/run_all.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append('..')
3
+ import parser_tests
4
+
5
+
6
+ if __name__ == '__main__':
7
+ parser_tests.run_tests()