Upload z-prompt-fusion-extension using SD-Hub
Browse files- z-prompt-fusion-extension/.gitignore +3 -0
- z-prompt-fusion-extension/LICENSE +21 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/ast_nodes.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/empty_cond.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/geometries.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/global_state.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/hijacker.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_functions.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_tensor.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/prompt_parser.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/t_scaler.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py +307 -0
- z-prompt-fusion-extension/lib_prompt_fusion/empty_cond.py +19 -0
- z-prompt-fusion-extension/lib_prompt_fusion/geometries.py +33 -0
- z-prompt-fusion-extension/lib_prompt_fusion/global_state.py +28 -0
- z-prompt-fusion-extension/lib_prompt_fusion/hijacker.py +34 -0
- z-prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py +87 -0
- z-prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py +249 -0
- z-prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py +378 -0
- z-prompt-fusion-extension/lib_prompt_fusion/t_scaler.py +38 -0
- z-prompt-fusion-extension/metadata.ini +2 -0
- z-prompt-fusion-extension/readme.md +95 -0
- z-prompt-fusion-extension/requirements.txt +1 -0
- z-prompt-fusion-extension/scripts/__pycache__/promptlang.cpython-310.pyc +0 -0
- z-prompt-fusion-extension/scripts/promptlang.py +354 -0
- z-prompt-fusion-extension/test/parser_tests.py +104 -0
- z-prompt-fusion-extension/test/run_all.py +7 -0
z-prompt-fusion-extension/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/venv/
|
| 2 |
+
/.idea/
|
| 3 |
+
__pycache__/
|
z-prompt-fusion-extension/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2003
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/ast_nodes.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/empty_cond.cpython-310.pyc
ADDED
|
Binary file (791 Bytes). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/geometries.cpython-310.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/global_state.cpython-310.pyc
ADDED
|
Binary file (1.23 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/hijacker.cpython-310.pyc
ADDED
|
Binary file (1.96 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_functions.cpython-310.pyc
ADDED
|
Binary file (3.06 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/interpolation_tensor.cpython-310.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/prompt_parser.cpython-310.pyc
ADDED
|
Binary file (9.9 kB). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/__pycache__/t_scaler.cpython-310.pyc
ADDED
|
Binary file (873 Bytes). View file
|
|
|
z-prompt-fusion-extension/lib_prompt_fusion/ast_nodes.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from lib_prompt_fusion import interpolation_functions
|
| 3 |
+
from lib_prompt_fusion.t_scaler import scale_t
|
| 4 |
+
from lib_prompt_fusion import interpolation_tensor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ListExpression:
    """A whitespace-separated sequence of sub-expressions."""

    def __init__(self, expressions):
        self.__expressions = expressions

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Extend the tensor with every sub-expression, joined by single spaces.

        An empty expression list contributes nothing.
        """
        for index, expression in enumerate(self.__expressions):
            if index > 0:
                tensor_builder.append(' ')
            expression.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
+
|
| 24 |
+
class InterpolationExpression:
    """Interpolates between several sub-expressions over a schedule of steps.

    Produced by the `[a:b:c, s1, s2, s3:curve]` prompt syntax; the conds of the
    sub-expressions are blended with the selected interpolation curve.
    """

    @staticmethod
    def create(exprs, steps, function_name):
        """Factory: dispatch to AverageExpression for 'mean', else build an interpolation.

        Trims `exprs` and `steps` to a common length so the constructor's
        length-equality assertion holds.
        """
        if function_name == "mean":
            return AverageExpression(exprs, steps)

        max_len = min(len(exprs), len(steps))
        exprs = exprs[:max_len]
        steps = steps[:max_len]

        return InterpolationExpression(exprs, steps, function_name)

    def __init__(self, expressions, steps, function_name=None):
        # `steps` entries may be None (meaning "infer this step"); see
        # get_interpolation_function for how Nones are filled in.
        assert len(expressions) >= 2
        assert len(steps) == len(expressions), 'the number of steps must be the same as the number of expressions'
        self.__expressions = expressions
        self.__steps = steps
        self.__function_name = function_name if function_name is not None else 'linear'

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Extrude one tensor axis per sub-expression, blended by the chosen curve."""
        def tensor_updater(expr):
            return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.extrude(
            [tensor_updater(expr) for expr in self.__expressions],
            self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))

    def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Resolve the step schedule and return a cond-interpolation callable.

        Resolution happens in three passes:
        1. default a missing first/last step to the boundaries of `steps_range`;
        2. evaluate every explicit step and convert it to an absolute step index
           (fractions scale by `total_steps`, ints are shifted to be 1-based);
        3. fill any remaining Nones by spacing them evenly between their
           nearest resolved neighbours.
        """
        steps = list(self.__steps)
        if steps[0] is None:
            steps[0] = LiftExpression(str(steps_range[0] - 1))
        if steps[-1] is None:
            steps[-1] = LiftExpression(str(steps_range[1] - 1))

        for i, step in enumerate(steps):
            if step is None:
                continue

            step = _eval_int_or_float(step, steps_range, total_steps, context, is_hires, use_old_scheduling)

            # Mirrors the webui's old/new prompt-editing scheduling semantics.
            if use_old_scheduling and 0 < step < 1:
                step *= total_steps
            elif not use_old_scheduling and isinstance(step, float):
                step = (step - int(is_hires)) * total_steps
            else:
                step += 1

            steps[i] = int(step)

        # Interior Nones: linearly interpolate between the surrounding resolved steps.
        i = 1
        while i < len(steps):
            none_len = 0
            while steps[i + none_len] is None:
                none_len += 1

            min_step, max_step = steps[i - 1], steps[i + none_len]

            for j in range(none_len):
                steps[i + j] = min_step + (max_step - min_step) * (j + 1) / (none_len + 1)

            i += 1 + none_len

        interpolation_function = {
            'linear': interpolation_functions.compute_linear,
            'bezier': interpolation_functions.compute_bezier,
            'catmull': interpolation_functions.compute_catmull,
        }[self.__function_name]

        def steps_scale_t(conds, params: interpolation_tensor.InterpolationParams):
            # Normalize the global t onto the [first, last] step window, then let
            # scale_t account for uneven spacing of the intermediate steps.
            scaled_t = (params.t * total_steps - steps[0]) / max(1, steps[-1] - steps[0])
            scaled_t = scale_t(scaled_t, steps)

            new_params = interpolation_tensor.InterpolationParams(scaled_t, *params[1:])
            return interpolation_function(conds, new_params)

        return steps_scale_t
| 101 |
+
|
| 102 |
+
class AverageExpression:
    """Weighted average of several sub-expressions' conditionings.

    Selected by the `:mean` curve name in InterpolationExpression.create.
    """

    def __init__(self, expressions, weights):
        # Each explicit weight belongs to one expression; expressions beyond the
        # weight list get an implicit uniform share, but a weight without an
        # expression is a user error. Report it with a meaningful message
        # instead of a bare ValueError.
        if len(expressions) < len(weights):
            raise ValueError(
                f'got {len(weights)} weights for only {len(expressions)} expressions')

        self.__expressions = expressions
        self.__weights = weights

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Extrude one tensor axis per sub-expression, combined by a weighted sum."""
        def tensor_updater(expr):
            return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.extrude(
            [tensor_updater(expr) for expr in self.__expressions],
            self.get_interpolation_function(steps_range, total_steps, context, is_hires, use_old_scheduling))

    def get_interpolation_function(self, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Return a callable computing the weighted mean of the sub-conds.

        Explicit weights are normalized against their own sum (then rescaled to
        the fraction of expressions they cover); missing weights default to a
        uniform 1/n share.
        """
        weights = [
            _eval_int_or_float(weight, steps_range, total_steps, context, is_hires, use_old_scheduling) if weight is not None else None
            for weight in self.__weights
        ]
        explicit_weights = [weight for weight in weights if weight is not None]
        weights = [
            weight / sum(explicit_weights) * len(explicit_weights) / len(self.__expressions)
            if weight is not None
            else 1 / len(self.__expressions)
            for weight in weights
        ]
        # Expressions beyond the weight list also get the uniform share.
        weights.extend(1 / len(self.__expressions) for _ in range(len(self.__expressions) - len(weights)))

        def interpolation_function(conds, _params):
            # Weighted sum; starting from None keeps the conds' own type.
            # NOTE(review): `cond *= weight` may mutate `cond` in place depending
            # on the cond wrapper's `__imul__` — confirm the wrappers return new
            # objects before relying on `conds` being unchanged after this call.
            total = None
            for cond, weight in zip(conds, weights):
                cond *= weight
                if total is None:
                    total = cond
                else:
                    total += cond

            return total

        return interpolation_function
| 145 |
+
|
| 146 |
+
class AlternationExpression:
    """`[a|b|...]` prompt alternation, optionally with a numeric speed.

    Without a speed the webui's native alternation syntax is emitted verbatim;
    with a speed the alternatives are interpolated cyclically instead.
    """

    def __init__(self, expressions, speed):
        self.__expressions = expressions
        self.__speed = speed  # expression evaluating to a number, or None

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Emit native `[a|b]` syntax when no speed is given, else extrude a cyclic interpolation."""
        if self.__speed is None:
            speed = None
        else:
            speed = _eval_int_or_float(self.__speed, steps_range, total_steps, context, is_hires, use_old_scheduling)

        if speed is None:
            # No speed: fall back to the stock webui alternation syntax.
            tensor_builder.append('[')
            for expr_i, expr in enumerate(self.__expressions):
                if expr_i >= 1:
                    tensor_builder.append('|')
                expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
            tensor_builder.append(']')
            return

        def tensor_updater(expr):
            return lambda t: expr.extend_tensor(t, steps_range, total_steps, context, is_hires, use_old_scheduling)

        # Repeat the first expression at the end so the interpolation cycle
        # wraps back to its starting point seamlessly.
        exprs = self.__expressions + [self.__expressions[0]]

        tensor_builder.extrude(
            [tensor_updater(expr) for expr in exprs],
            self.get_interpolation_function(speed, exprs, steps_range, total_steps))

    def get_interpolation_function(self, speed, exprs, steps_range, total_steps):
        """Build a wrapping linear interpolation over `exprs` at the given speed."""
        def compute_wrap(control_points, params: interpolation_tensor.InterpolationParams):
            # Map the global step position onto a repeating [0, 1) cycle.
            wrapped_t = math.fmod((params.t * total_steps - steps_range[0]) / (len(exprs) - 1) * speed, 1.0)
            if wrapped_t < 0:
                # math.fmod keeps the sign of its first operand; shift negatives into [0, 1).
                wrapped_t = wrapped_t + 1
            new_params = interpolation_tensor.InterpolationParams(wrapped_t, *params[1:])
            return interpolation_functions.compute_linear(control_points, new_params)

        return compute_wrap
+
|
| 186 |
+
class EditingExpression:
    """Prompt editing `[from:to:step]` (two expressions) or `[nested:step]` (one).

    The step may itself be a fusion expression; it is evaluated here and the
    resolved absolute step is re-emitted in the webui's native syntax.
    """

    def __init__(self, expressions, step):
        assert 1 <= len(expressions) <= 2
        self.__expressions = expressions
        self.__step = step  # expression evaluating to the switch step, or None

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Emit webui prompt-editing syntax, evaluating the switch step when present."""
        if self.__step is None:
            # No step given: emit the brackets verbatim and let the webui schedule it.
            tensor_builder.append('[')
            for expr_i, expr in enumerate(self.__expressions):
                expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
                tensor_builder.append(':')
            tensor_builder.append(']')
            return

        step = _eval_int_or_float(self.__step, steps_range, total_steps, context, is_hires, use_old_scheduling)
        step_int = step
        # Convert the user-facing step value into an absolute 1-based step index,
        # mirroring the webui's old/new scheduling semantics: fractions scale by
        # total_steps (hires passes subtract 1 first under new scheduling);
        # integers are 0-based and shifted by one.
        if use_old_scheduling and 0 < step < 1:
            step_int *= total_steps
        elif not use_old_scheduling and isinstance(step, float):
            step_int = (step_int - int(is_hires)) * total_steps
        else:
            step_int += 1

        step_int = int(step_int)

        tensor_builder.append('[')
        for expr_i, expr in enumerate(self.__expressions):
            # With two expressions, the first is active before the switch step
            # and the second from the switch step onwards; a single expression
            # is active from the switch step onwards.
            expr_steps_range = (steps_range[0], step_int) if expr_i == 0 and len(self.__expressions) >= 2 else (step_int, steps_range[1])
            expr.extend_tensor(tensor_builder, expr_steps_range, total_steps, context, is_hires, use_old_scheduling)
            tensor_builder.append(':')

        # Emit the original (unconverted) step value in the rebuilt syntax.
        tensor_builder.append(f'{step}]')
+
|
| 221 |
+
class WeightedExpression:
    """An attention block: `(nested:weight)` / `(nested)` when positive, `[nested]` otherwise."""

    def __init__(self, nested, weight=None, positive=True):
        self.__nested = nested
        if not positive:
            # Down-weighting square brackets take no explicit weight.
            assert weight is None

        self.__weight = weight
        self.__positive = positive

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Emit the bracketed prompt text, recursing into the nested expressions."""
        if self.__positive:
            brackets = '(', ')'
        else:
            brackets = '[', ']'

        tensor_builder.append(brackets[0])
        self.__nested.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)

        if self.__weight is not None:
            tensor_builder.append(':')
            self.__weight.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)

        tensor_builder.append(brackets[1])
+
|
| 242 |
+
class WeightInterpolationExpression:
    """Linearly ramps the attention weight of `nested` across a range of steps."""

    def __init__(self, nested, weight_begin, weight_end):
        self.__nested = nested
        # A missing endpoint defaults to a neutral weight of 1.0.
        self.__weight_begin = LiftExpression(str(1.)) if weight_begin is None else weight_begin
        self.__weight_end = LiftExpression(str(1.)) if weight_end is None else weight_end

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Emit one weighted, step-gated copy of the nested expression per step in range."""
        begin_step, end_step = steps_range
        range_size = end_step - begin_step

        start_weight = _eval_int_or_float(self.__weight_begin, steps_range, total_steps, context, is_hires, use_old_scheduling)
        stop_weight = _eval_int_or_float(self.__weight_end, steps_range, total_steps, context, is_hires, use_old_scheduling)
        denominator = max(range_size - 1, 1)

        for offset in range(range_size):
            step = offset + begin_step

            weight = start_weight + (stop_weight - start_weight) * (offset / denominator)
            step_expr = WeightedExpression(self.__nested, LiftExpression(str(weight)))
            if step > begin_step:
                # Only becomes active from `step` onwards.
                step_expr = EditingExpression([step_expr], LiftExpression(str(step - 1)))
            if step + 1 < end_step:
                # Switches to an empty prompt after this step.
                step_expr = EditingExpression([step_expr, ListExpression([])], LiftExpression(str(step)))

            step_expr.extend_tensor(tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
+
|
| 267 |
+
class DeclarationExpression:
    """A symbol declaration whose binding is visible while evaluating `target`."""

    def __init__(self, symbol, parameters, value, target):
        self.__symbol = symbol
        self.__parameters = parameters
        self.__value = value
        self.__target = target

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Bind the declared symbol in a copy of the context, then evaluate the target.

        The caller's `context` is never mutated.
        """
        extended_context = {**context, self.__symbol: (self.__value, self.__parameters)}
        self.__target.extend_tensor(tensor_builder, steps_range, total_steps, extended_context, is_hires, use_old_scheduling)
+
|
| 280 |
+
class SubstitutionExpression:
    """A use of a previously declared symbol, with optional call arguments."""

    def __init__(self, symbol, arguments):
        self.__symbol = symbol
        self.__arguments = arguments

    def extend_tensor(self, tensor_builder, steps_range, total_steps, context, is_hires, use_old_scheduling):
        """Look the symbol up in `context`, bind its parameters to the arguments, and expand its body.

        The caller's `context` is copied, never mutated; extra arguments or
        parameters beyond the shorter list are silently dropped (zip truncates).
        """
        body, parameter_names = context[self.__symbol]
        call_context = dict(context)
        call_context.update(
            (name, (argument, [])) for argument, name in zip(self.__arguments, parameter_names))
        body.extend_tensor(tensor_builder, steps_range, total_steps, call_context, is_hires, use_old_scheduling)
+
|
| 293 |
+
class LiftExpression:
    """Lifts a plain string into the expression interface.

    The terminal node of the AST: raw prompt text, numbers, etc.
    """

    def __init__(self, value):
        self.__value = value

    def extend_tensor(self, tensor_builder, *_args, **_kwargs):
        """Append the raw text; scheduling arguments are irrelevant for a literal."""
        tensor_builder.append(self.__value)
+
|
| 301 |
+
def _eval_int_or_float(expression, steps_range, total_steps, context, is_hires, use_old_scheduling):
    """Evaluate `expression` to a scalar by rendering it to text.

    The expression is rendered into a throwaway one-entry prompt database and
    the resulting string is parsed, preferring int over float. Raises
    ValueError if the rendered text is not numeric at all.
    """
    scratch_database = ['']
    scratch_builder = interpolation_tensor.InterpolationTensorBuilder(prompt_database=scratch_database)
    expression.extend_tensor(scratch_builder, steps_range, total_steps, context, is_hires, use_old_scheduling)
    rendered = scratch_database[0]
    try:
        return int(rendered)
    except ValueError:
        return float(rendered)
z-prompt-fusion-extension/lib_prompt_fusion/empty_cond.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_prompt_fusion import interpolation_tensor
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
_empty_cond = None
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get():
    """Return the cached empty-prompt conditioning (None until init() has been called)."""
    return _empty_cond
| 10 |
+
|
| 11 |
+
def init(model):
    """Compute and cache the conditioning of the empty prompt '' for `model`.

    The result is wrapped so that dict-shaped conds and plain tensor conds
    share one arithmetic interface (see interpolation_tensor's wrappers).
    """
    global _empty_cond
    cond = model.get_learned_conditioning([''])
    if isinstance(cond, dict):
        # Some models return a dict of conds — presumably keyed by cond type;
        # keep only the first (single-prompt batch) entry of each.
        cond = interpolation_tensor.DictCondWrapper({k: v[0] for k, v in cond.items()})
    else:
        cond = interpolation_tensor.TensorCondWrapper(cond[0])

    _empty_cond = cond
z-prompt-fusion-extension/lib_prompt_fusion/geometries.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from lib_prompt_fusion import interpolation_tensor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def slerp_geometry(control_points, params: interpolation_tensor.InterpolationParams):
    """Spherical interpolation between two conds, blended with lerp by `params.slerp_scale`.

    slerp_scale == 0 yields pure lerp, 1 yields pure slerp; values outside
    [0, 1] extrapolate. Falls back to lerp entirely when the vectors are too
    close to (anti-)parallel for the math to be stable.
    """
    p0, p1 = control_points
    p0_norm = torch.linalg.norm(p0)
    p1_norm = torch.linalg.norm(p1)

    # Cosine similarity of the two tensors, clamped into acos' domain.
    similarity = torch.sum((p0 / p0_norm) * (p1 / p1_norm))
    similarity = min(1., max(-1., float(similarity)))
    if similarity <= params.slerp_epsilon - 1 or similarity >= 1 - params.slerp_epsilon:
        # Nearly (anti-)parallel: tan(angle) below would blow up or vanish.
        return linear_geometry(control_points, params)

    # Half the angle between the two vectors.
    angle = math.acos(float(similarity)) / 2

    # Remap t through tan so it is symmetric about the midpoint — presumably to
    # even out the angular progression along the arc; verify against the tests.
    slerp_t = angle * (2 * params.t - 1)
    slerp_t = math.tan(slerp_t) / math.tan(angle)
    slerp_t = (slerp_t + 1) / 2

    # Interpolate direction at p0's magnitude, then rescale the result onto a
    # linearly interpolated magnitude between the two norms.
    normalized_p1 = p1 / p1_norm * p0_norm
    slerp_p = p0 + (normalized_p1 - p0) * slerp_t
    slerp_p = slerp_p / torch.linalg.norm(slerp_p) * (p0_norm + (p1_norm - p0_norm) * params.t)

    # Blend between the lerp and slerp results.
    lerp_p = linear_geometry(control_points, params)
    return lerp_p + (slerp_p - lerp_p) * params.slerp_scale
+
|
| 30 |
+
def linear_geometry(control_points, params: interpolation_tensor.InterpolationParams):
    """Plain linear interpolation between the two control points at `params.t`."""
    start, end = control_points
    return start + (end - start) * params.t
z-prompt-fusion-extension/lib_prompt_fusion/global_state.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from modules import shared, prompt_parser
|
| 3 |
+
from lib_prompt_fusion import empty_cond
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
old_webui_is_negative: bool = False
|
| 7 |
+
negative_schedules: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
|
| 8 |
+
negative_schedules_hires: Optional[List[prompt_parser.ScheduledPromptConditioning]] = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_origin_cond_at(step: int, is_hires: bool = False):
    """Return the negative-prompt cond scheduled at `step`, or the empty cond.

    Used as the slerp origin. Falls back to the empty cond when the
    'prompt_fusion_slerp_negative_origin' option is off, when no negative
    schedules were recorded, or when no schedule covers the requested step.
    """
    schedules = negative_schedules_hires if is_hires else negative_schedules
    if schedules and shared.opts.data.get('prompt_fusion_slerp_negative_origin', False):
        for scheduled in schedules:
            if scheduled.end_at_step >= step:
                return scheduled.cond

    return empty_cond.get()
+
|
| 23 |
+
def get_slerp_scale():
    """Blend factor between lerp (0.0) and slerp (1.0) geometry, from webui settings."""
    return shared.opts.data.get('prompt_fusion_slerp_scale', 0.0)
+
|
| 27 |
+
def get_slerp_epsilon():
    """Similarity tolerance under which slerp falls back to lerp, from webui settings."""
    return shared.opts.data.get('prompt_fusion_slerp_epsilon', 0.0001)
z-prompt-fusion-extension/lib_prompt_fusion/hijacker.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ModuleHijacker:
    """Monkey-patches functions on a module, remembering the originals so they can be restored."""

    def __init__(self, module):
        self.__module = module
        self.__original_functions = {}

    def hijack(self, attribute):
        """Decorator factory: replace `module.attribute` with the decorated function.

        The replacement is invoked with an extra `original_function` keyword
        argument bound to the pre-hijack implementation, so it can delegate.
        Hijacking the same attribute twice keeps the very first original.
        """
        if attribute not in self.__original_functions:
            self.__original_functions[attribute] = getattr(self.__module, attribute)

        def decorator(function):
            def wrapper(*args, **kwargs):
                return function(*args, **kwargs, original_function=self.__original_functions[attribute])

            setattr(self.__module, attribute, wrapper)
            return function

        return decorator

    def reset_module(self):
        """Restore every hijacked attribute to its original implementation."""
        for attribute, original in self.__original_functions.items():
            setattr(self.__module, attribute, original)
        self.__original_functions.clear()

    @staticmethod
    def install_or_get(module, hijacker_attribute, register_uninstall=lambda _callback: None):
        """Return the hijacker stored on `module`, creating and registering one if absent.

        `register_uninstall` receives callbacks that remove the marker attribute
        and undo the patches, in that order.
        """
        if hasattr(module, hijacker_attribute):
            return getattr(module, hijacker_attribute)

        module_hijacker = ModuleHijacker(module)
        setattr(module, hijacker_attribute, module_hijacker)
        register_uninstall(lambda: delattr(module, hijacker_attribute))
        register_uninstall(module_hijacker.reset_module)
        return module_hijacker
z-prompt-fusion-extension/lib_prompt_fusion/interpolation_functions.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
from lib_prompt_fusion import interpolation_tensor, geometries
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def compute_linear(control_points, params: interpolation_tensor.InterpolationParams):
    """Piecewise-linear interpolation across the control points at `params.t`.

    With one or two points this is a single (slerp-blended) segment; with more,
    t selects a segment and is rescaled to local [0, 1) coordinates inside it.
    """
    point_count = len(control_points)
    if point_count <= 2:
        return geometries.slerp_geometry(control_points, params)

    segment_count = point_count - 1
    segment = min(int(params.t * segment_count), segment_count)
    first = control_points[segment]
    if segment + 1 < point_count:
        second = control_points[segment + 1]
    else:
        # t == 1 lands past the last segment; degenerate to the final point.
        second = control_points[-1]
    local_t = math.fmod(params.t * segment_count, 1.)
    local_params = interpolation_tensor.InterpolationParams(local_t, *params[1:])
    return geometries.slerp_geometry([first, second], local_params)
+
|
| 17 |
+
def compute_bezier(control_points, params: interpolation_tensor.InterpolationParams):
    """Bezier interpolation over the control points at `params.t`.

    Uses De Casteljau's algorithm: repeatedly interpolate adjacent points
    until a single point remains. One point is a constant curve; two points
    reduce to a plain (slerp-blended) segment.
    """
    if len(control_points) == 1:
        return control_points[0]
    if len(control_points) == 2:
        return geometries.slerp_geometry(control_points, params)

    # De Casteljau mutates the working list in place, so operate on a deep copy.
    points = copy.deepcopy(control_points)
    remaining = len(points)
    while remaining > 1:
        for j in range(remaining - 1):
            points[j] = geometries.slerp_geometry([points[j], points[j + 1]], params)
        remaining -= 1
    return points[0]
+
|
| 33 |
+
def compute_catmull(control_points, params: interpolation_tensor.InterpolationParams):
    """Catmull-Rom-style spline interpolation across the control points at `params.t`."""
    if len(control_points) <= 2:
        # Not enough points for a spline; degrade to (piecewise) linear.
        return compute_linear(control_points, params)
    else:
        # Select the segment addressed by t, same scheme as compute_linear.
        target_curve = min(int(params.t * (len(control_points) - 1)), len(control_points) - 1)
        # Neighbouring guide points g0/g1; mirrored past the ends when out of range.
        g0 = control_points[target_curve - 1] if target_curve > 0 else control_points[0] * 2 - control_points[1]
        cp0 = control_points[target_curve]
        cp1 = control_points[target_curve + 1] if target_curve + 1 < len(control_points) else control_points[-1]
        g1 = control_points[target_curve + 2] if target_curve + 2 < len(control_points) else cp1 * 2 - cp0
        # Inner bezier handles derived from the guide points (spline-to-cubic-
        # bezier conversion; the 1/6 factor matches Catmull-Rom-style tangents).
        ip0 = cp0 + (cp1 - g0)/6
        ip1 = cp1 + (cp0 - g1)/6

        # Evaluate the equivalent cubic bezier at the local t within this segment.
        new_params = interpolation_tensor.InterpolationParams(math.fmod(params.t * (len(control_points) - 1), 1.), *params[1:])
        return compute_bezier([cp0, ip0, ip1, cp1], new_params)
+
|
| 49 |
+
if __name__ == '__main__':
    # Manual visual smoke test: plot compute_linear between two 2D points at
    # several slerp scales using turtle graphics. Run this file directly.
    import turtle as tr
    import torch
    size = 60  # number of sampled t values per curve
    turtle_tool = tr.Turtle()
    turtle_tool.speed(10)
    turtle_tool.up()  # pen up: draw dots only, no connecting lines

    points = torch.Tensor([[-2., -2.], [2., 2.]])
    origin = torch.Tensor([1.5, 1.6])

    def sample(slerp_scale, color):
        # Draw `size` interpolated points (relative to `origin`) in `color`.
        for i in range(size):
            t = i / size
            params = interpolation_tensor.InterpolationParams(t, i, size, slerp_scale, 0.0001)
            point = origin + compute_linear(points - origin, params)
            try:
                # World coords are scaled x100 to turtle-canvas pixels.
                turtle_tool.goto(tuple(float(p) * 100. for p in point))
                turtle_tool.dot(5, color)
                print(point)
            except ValueError:
                # Skip points the canvas cannot represent (e.g. NaN coordinates).
                pass

    sample(0, "black")
    sample(1, "green")
    sample(2, "blue")
    sample(-1, "purple")
    sample(-2, "orange")

    # Mark the control points (small) and the origin (large) in red.
    for point in points:
        turtle_tool.goto(tuple(float(p) * 100. for p in point))
        turtle_tool.dot(5, "red")

    turtle_tool.goto(tuple(float(p) * 100. for p in origin))
    turtle_tool.dot(10, "red")

    # Park the pen far off-screen before finishing.
    turtle_tool.goto(100000, 100000)
    turtle_tool.dot()
    tr.done()
z-prompt-fusion-extension/lib_prompt_fusion/interpolation_tensor.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import torch
|
| 3 |
+
from modules import prompt_parser
|
| 4 |
+
from typing import NamedTuple, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InterpolationParams(NamedTuple):
    """Per-sampling-step parameters passed through the interpolation tensor."""
    t: float              # interpolation progress in [0, 1] for the current step
    step: int             # current sampling step index
    total_steps: int      # total number of sampling steps
    slerp_scale: float    # 0 disables slerp; nonzero switches to origin-relative interpolation
    slerp_epsilon: float  # numerical tolerance used by the slerp computation
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class InterpolationTensor:
    """Recursive tree of conditionings.

    A leaf (``interpolation_function is None``) holds a prompt schedule; an
    inner node holds sub-tensors whose interpolated conds are combined by
    ``interpolation_function``.
    """
    def __init__(self, sub_tensors, interpolation_function):
        # For a leaf, sub_tensors is a list of ScheduledPromptConditioning;
        # for an inner node, a list of child InterpolationTensor instances.
        self.__sub_tensors = sub_tensors
        self.__interpolation_function = interpolation_function

    def interpolate(self, params: InterpolationParams, origin_cond, empty_cond):
        """Compute the conditioning for this step.

        When slerp is active (slerp_scale != 0) the recursion works on
        origin-relative deltas (see get_cond_point), so the origin is added
        back here and the dtype restored.
        """
        cond = self.interpolate_cond_rec(params, origin_cond, empty_cond)
        if params.slerp_scale != 0:
            cond = (cond + origin_cond.extend_like(cond, empty_cond)).to(dtype=origin_cond.dtype)
        return cond

    def interpolate_cond_rec(self, params: InterpolationParams, origin_cond, empty_cond):
        """Recursively interpolate: leaves select a scheduled cond, inner
        nodes blend their children's results component-wise."""
        if self.__interpolation_function is None:
            return self.get_cond_point(params.step, origin_cond, empty_cond, params.slerp_scale)

        control_points = [
            sub_tensor.interpolate_cond_rec(params, origin_cond, empty_cond)
            for sub_tensor in self.__sub_tensors
        ]

        # Split wrappers into raw per-component value lists so the
        # interpolation function works on plain tensors.
        CondWrapper, control_points_values = conds_to_cp_values(control_points)
        # NOTE(review): the generator variable deliberately shadows the outer
        # `control_points` list; each item is one component's control points.
        return CondWrapper.from_cp_values(self.__interpolation_function(control_points, params) for control_points in control_points_values)

    def get_cond_point(self, step, origin_cond, empty_cond, slerp_scale):
        """Pick the scheduled cond active at `step` (first schedule whose
        end_at_step >= step; falls back to the last schedule otherwise).

        NOTE(review): if self.__sub_tensors is empty, `schedule` stays None
        and the attribute access below raises — presumably leaves are never
        built empty; confirm against the builder.
        """
        schedule = None
        for schedule in self.__sub_tensors:
            if schedule.end_at_step >= step:
                break

        res = schedule.cond.extend_like(origin_cond, empty_cond)
        if slerp_scale != 0:
            # Slerp mode: return the delta from the origin cond, in float.
            res = res.to(dtype=torch.float) - origin_cond.extend_like(schedule.cond, empty_cond).to(dtype=torch.float)
        return res
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def conds_to_cp_values(conds):
    """Split cond wrappers into their wrapper type and transposed components.

    Each wrapper yields a list of component values via to_cp_values(); the
    result pairs the common wrapper type with one list per component, where
    each inner list holds that component across every cond, in order.

    All wrappers are assumed to share the same type and component count.
    """
    wrapper_type = type(conds[0])
    per_cond_values = [cond.to_cp_values() for cond in conds]
    component_count = len(per_cond_values[0])
    transposed = [
        [values[component] for values in per_cond_values]
        for component in range(component_count)
    ]
    return wrapper_type, transposed
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class InterpolationTensorBuilder:
    """Incrementally builds an InterpolationTensor while the prompt AST is
    traversed.

    It keeps a nested index tensor (ints indexing into the prompt database),
    the flat list of prompt variants, and a parallel structure of
    interpolation functions collected by extrude().
    """
    def __init__(self, tensor=None, prompt_database=None, interpolation_functions=None):
        # A leaf index (int) or nested lists of indices into __prompt_database.
        self.__indices_tensor = tensor if tensor is not None else 0
        # All prompt variants produced so far; starts as a single empty prompt.
        self.__prompt_database = prompt_database if prompt_database is not None else ['']
        # List of (interpolation_function, nested_functions) pairs, outermost first.
        self.__interpolation_functions = interpolation_functions if interpolation_functions is not None else []

    def append(self, suffix):
        """Append literal text to every prompt variant."""
        for i in range(len(self.__prompt_database)):
            self.__prompt_database[i] += suffix

    def extrude(self, tensor_updaters, interpolation_function):
        """Add one dimension to the tensor: each updater produces one branch,
        and `interpolation_function` blends the branches at build time."""
        extruded_indices_tensor = []
        extruded_prompt_database = []
        extruded_interpolation_functions = []

        for update_tensor in tensor_updaters:
            # Each branch starts from a copy of the current prompts/indices.
            nested_tensor_builder = InterpolationTensorBuilder(
                self.__indices_tensor,
                self.__prompt_database[:],
                interpolation_functions=[])

            update_tensor(nested_tensor_builder)

            # Shift the branch's indices past the prompts already collected
            # from earlier branches, then collect its prompts and functions.
            extruded_indices_tensor.append(InterpolationTensorBuilder.__offset_tensor(
                tensor=nested_tensor_builder.__indices_tensor,
                offset=len(extruded_prompt_database)))
            extruded_prompt_database.extend(nested_tensor_builder.__prompt_database)
            extruded_interpolation_functions.append(nested_tensor_builder.__interpolation_functions)

        self.__indices_tensor = extruded_indices_tensor
        # In-place slice assignment keeps any external references to the
        # database list valid.
        self.__prompt_database[:] = extruded_prompt_database
        # Outermost extrusion first.
        self.__interpolation_functions.insert(0, (interpolation_function, extruded_interpolation_functions))

    def get_prompt_database(self):
        """Return the flat list of prompt variants (shared, not copied)."""
        return self.__prompt_database

    @staticmethod
    def __offset_tensor(tensor, offset):
        """Add `offset` to every int in a possibly nested index tensor."""
        try:
            return tensor + offset

        except TypeError:
            # Not an int: recurse into the nested list.
            return [InterpolationTensorBuilder.__offset_tensor(e, offset) for e in tensor]

    def build(self, conds, empty_cond):
        """Materialize the InterpolationTensor from the per-prompt conds
        (one schedule list per entry of the prompt database)."""
        max_cond_size = self.__max_cond_size(conds)
        # Pad every cond to the same size so they can be blended together.
        conds = self.__resize_uniformly(conds, max_cond_size, empty_cond)
        return InterpolationTensorBuilder.__build_conditionings_tensor(self.__indices_tensor, self.__interpolation_functions, conds)

    @staticmethod
    def __build_conditionings_tensor(tensor, int_funcs, conds):
        """Mirror the index tensor into an InterpolationTensor of conds."""
        if type(tensor) is int:
            # Leaf: the schedule list for this prompt, no blending function.
            return InterpolationTensor(conds[tensor], None)
        else:
            int_func, nested_int_funcs = int_funcs[0]
            return InterpolationTensor(
                [
                    InterpolationTensorBuilder.__build_conditionings_tensor(sub_tensor, nested_int_funcs + int_funcs[1:], conds)
                    for sub_tensor, nested_int_funcs in zip(tensor, nested_int_funcs)
                ],
                int_func,
            )

    def __resize_uniformly(self, conds, max_cond_size: int, empty_cond):
        """Rebuild every schedule with its cond padded to max_cond_size."""
        return [
            [
                prompt_parser.ScheduledPromptConditioning(
                    cond=schedule.cond.resize_schedule(max_cond_size, empty_cond),
                    end_at_step=schedule.end_at_step
                )
                for schedule in schedules
            ]
            for schedules in conds
        ]

    @staticmethod
    def __max_cond_size(conds):
        """Largest first-dimension size among all scheduled conds."""
        return max(schedule.cond.size(0)
                   for schedules in conds
                   for schedule in schedules)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@dataclasses.dataclass
class DictCondWrapper:
    """Wraps a dict-shaped conditioning (SDXL-style, e.g. 'crossattn' and
    'vector' entries) behind the common cond-wrapper interface shared with
    TensorCondWrapper."""
    original_cond: dict

    @staticmethod
    def from_cp_values(cp_values):
        """Rebuild a wrapper from per-component values, in the fixed
        ('crossattn', 'vector') key order produced by to_cp_values()."""
        return DictCondWrapper({
            k: v
            for k, v in zip(('crossattn', 'vector'), cp_values)
        })

    def size(self, *args, **kwargs):
        # Size is measured on the crossattn tensor only.
        return self.original_cond['crossattn'].size(*args, **kwargs)

    def extend_like(self, that, empty):
        """Pad with copies of `empty` until at least as long as `that`.

        NOTE(review): only 'crossattn' is padded; 'vector' keeps its original
        value — presumably it has a fixed size. Confirm against callers.
        """
        # 77 appears to be the tokens-per-chunk size; sizes are assumed to be
        # multiples of it — TODO confirm.
        missing_size = max(0, that.size(0) - self.size(0)) // 77
        extended = DictCondWrapper(self.original_cond.copy())
        extended.original_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty.original_cond['crossattn']] * missing_size)
        return extended

    def resize_schedule(self, target_size, empty_cond):
        """Return a wrapper padded up to target_size (self when already big
        enough); padding appends empty_cond chunks to 'crossattn' only."""
        cond_missing_size = (target_size - self.size(0)) // 77
        if cond_missing_size <= 0:
            return self

        resized_cond = self.original_cond.copy()
        resized_cond['crossattn'] = torch.concatenate([self.original_cond['crossattn']] + [empty_cond.original_cond['crossattn']] * cond_missing_size)
        return DictCondWrapper(resized_cond)

    def to_cp_values(self):
        """Component values in dict insertion order (matches from_cp_values
        when the keys are 'crossattn' then 'vector')."""
        return list(self.original_cond.values())

    def to(self, dtype: Union[dict, torch.dtype]):
        """Cast every entry; a single dtype is broadcast to all keys."""
        if not isinstance(dtype, dict):
            dtype = {
                k: dtype
                for k in self.original_cond.keys()
            }
        return DictCondWrapper({
            k: v.to(dtype=dtype[k])
            for k, v in self.original_cond.items()
        })

    @property
    def dtype(self):
        # Per-key dtypes; feeds back into to() above.
        return {
            k: v.dtype
            for k, v in self.original_cond.items()
        }

    def __sub__(self, that):
        # Element-wise difference per key; assumes matching keys.
        return DictCondWrapper({
            k: v - that.original_cond[k]
            for k, v in self.original_cond.items()
        })

    def __add__(self, that):
        # Element-wise sum per key; assumes matching keys.
        return DictCondWrapper({
            k: v + that.original_cond[k]
            for k, v in self.original_cond.items()
        })

    def __eq__(self, that):
        # All entries equal element-wise (assumes identical key sets).
        return all((self.original_cond[k] == that.original_cond[k]).all() for k in self.original_cond.keys())
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@dataclasses.dataclass
class TensorCondWrapper:
    """Wraps a plain conditioning tensor behind the common cond-wrapper
    interface shared with DictCondWrapper (single component)."""
    original_cond: torch.Tensor

    @staticmethod
    def from_cp_values(cp_values):
        """Rebuild a wrapper from per-component values (exactly one here)."""
        first_component = next(iter(cp_values))
        return TensorCondWrapper(first_component)

    def size(self, *args, **kwargs):
        """Delegate to the underlying tensor's size()."""
        return self.original_cond.size(*args, **kwargs)

    def extend_like(self, that, empty):
        """Pad with copies of `empty` until at least as long as `that`.

        Sizes are assumed to be multiples of 77 (tokens per prompt chunk).
        """
        deficit_chunks = max(0, that.size(0) - self.original_cond.size(0)) // 77
        padding = [empty.original_cond] * deficit_chunks
        return TensorCondWrapper(torch.concatenate([self.original_cond, *padding]))

    def resize_schedule(self, target_size, empty_cond):
        """Return a wrapper padded up to target_size with empty_cond chunks;
        returns self unchanged when already at least that long."""
        missing_chunks = (target_size - self.original_cond.size(0)) // 77
        if missing_chunks <= 0:
            return self

        padding = [empty_cond.original_cond] * missing_chunks
        return TensorCondWrapper(torch.concatenate([self.original_cond, *padding]))

    def to_cp_values(self):
        """Single-component value list for conds_to_cp_values()."""
        return [self.original_cond]

    def to(self, dtype: torch.dtype):
        """Return a wrapper with the tensor cast to `dtype`."""
        converted = self.original_cond.to(dtype=dtype)
        return TensorCondWrapper(converted)

    @property
    def dtype(self):
        """dtype of the wrapped tensor; feeds back into to() above."""
        return self.original_cond.dtype

    def __sub__(self, that):
        difference = self.original_cond - that.original_cond
        return TensorCondWrapper(difference)

    def __add__(self, that):
        total = self.original_cond + that.original_cond
        return TensorCondWrapper(total)

    def __eq__(self, that):
        # Element-wise equality reduced over the whole tensor.
        return torch.eq(self.original_cond, that.original_cond).all()
|
z-prompt-fusion-extension/lib_prompt_fusion/prompt_parser.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import lib_prompt_fusion.ast_nodes as ast
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
ParseResult = namedtuple('ParseResult', ['prompt', 'expr'])
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_prompt(prompt):
    """Parse a complete prompt string into its AST (a ListExpression)."""
    _, expression = parse_list_expression(prompt.lstrip(), set())
    return expression
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_list_expression(prompt, stoppers):
    """Greedily parse consecutive expressions until none matches.

    Never raises: an immediate ValueError simply yields an empty
    ListExpression. `stoppers` are characters the nested parsers must not
    consume as plain text.
    """
    exprs = []
    try:
        while True:
            prompt, expr = parse_expression(prompt, stoppers)
            exprs.append(expr)
    except ValueError:
        return ParseResult(prompt=prompt, expr=ast.ListExpression(exprs))


def parse_expression(prompt, stoppers):
    """Try each parser in order; return the first success.

    Raises ValueError when no parser matches (the shared backtracking
    signal used throughout this module).
    """
    for parse in _parsers():
        try:
            return parse(prompt, stoppers)
        except ValueError:
            pass

    raise ValueError


def _parsers():
    # Order matters: more specific constructs are tried before the
    # catch-all unrestricted text parser at the end.
    return (
        parse_text,
        parse_declaration,
        parse_substitution,
        parse_positive_attention,
        parse_negative_attention,
        parse_editing,
        parse_alternation,
        parse_interpolation,
        parse_unrestricted_text,
    )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def parse_text(prompt, stoppers):
    """Parse plain text, additionally stopping at syntax openers [ ( $."""
    return parse_unrestricted_text(prompt, set_concat(stoppers, {'[', '(', '$'}))


def parse_unrestricted_text(prompt, stoppers):
    """Parse a run of literal text.

    The regex accepts any character that is not a stopper, backslash or
    whitespace; a '$' not followed by an identifier start; or any
    backslash-escaped character.
    """
    escaped_stoppers = ''.join(re.escape(stopper) for stopper in stoppers)
    regex = rf'(?:[^{escaped_stoppers}\\\s]|\$(?![a-zA-Z_])|\\.)+'
    prompt, expr = parse_token(prompt, whitespace_tail_regex(regex, stoppers))
    return ParseResult(prompt=prompt, expr=ast.LiftExpression(expr))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def parse_substitution(prompt, stoppers):
    """Parse a variable/function use: `$name` or `$name(arg, ...)`."""
    prompt, symbol = parse_symbol(prompt, stoppers)
    prompt, arguments = parse_arguments(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.SubstitutionExpression(symbol, arguments))


def parse_arguments(prompt, stoppers):
    """Parse an optional parenthesized argument list; no parens → [].

    NOTE(review): if the opening paren matches but the closing one does
    not, `prompt` keeps the partially consumed position while the
    arguments are discarded — confirm this is intended.
    """
    try:
        prompt, _ = parse_open_paren(prompt, stoppers)
        prompt, arguments = parse_inner_arguments(prompt, stoppers)
        prompt, _ = parse_close_paren(prompt, stoppers)
    except ValueError:
        arguments = []
    return ParseResult(prompt=prompt, expr=arguments)


def parse_inner_arguments(prompt, stoppers):
    """Parse comma-separated argument expressions.

    parse_list_expression never raises, so the loop terminates when the
    trailing comma is missing; the last expression parsed before that is
    still appended.
    """
    arguments = []
    try:
        while True:
            prompt, arg = parse_list_expression(prompt, {',', ')'})
            arguments.append(arg)
            prompt, _ = parse_comma(prompt, stoppers)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=arguments)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def parse_declaration(prompt, stoppers):
    """Parse `$name(params...) = value\\n rest` — a variable or function
    declaration whose scope is the remainder of the prompt."""
    prompt, symbol = parse_symbol(prompt, stoppers)
    prompt, parameters = parse_parameters(prompt, stoppers)
    prompt, _ = parse_equals(prompt, stoppers)
    # The declared value ends at the newline.
    prompt, value = parse_list_expression(prompt, set_concat(stoppers, '\n'))
    prompt, _ = parse_newline(prompt, stoppers)
    # Everything after the declaration is the expression it applies to.
    prompt, expr = parse_list_expression(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.DeclarationExpression(symbol, parameters, value, expr))


def parse_parameters(prompt, stoppers):
    """Parse an optional parenthesized parameter name list; no parens → []."""
    try:
        prompt, _ = parse_open_paren(prompt, stoppers)
        prompt, parameters = parse_inner_parameters(prompt, stoppers)
        prompt, _ = parse_close_paren(prompt, stoppers)
    except ValueError:
        parameters = []
    return ParseResult(prompt=prompt, expr=parameters)


def parse_inner_parameters(prompt, stoppers):
    """Parse comma-separated `$name` parameters (loop exits when either the
    symbol or the trailing comma fails to match)."""
    parameters = []
    try:
        while True:
            prompt, param = parse_symbol(prompt, stoppers)
            parameters.append(param)
            prompt, _ = parse_comma(prompt, stoppers)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=parameters)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def parse_interpolation(prompt, stoppers):
    """Parse `[expr1:expr2:...:steps,...:function]` prompt interpolation."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, exprs = parse_interpolation_exprs(prompt, stoppers)
    prompt, steps = parse_interpolation_steps(prompt, stoppers)
    prompt, function_name = parse_interpolation_function_name(prompt, stoppers)
    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.InterpolationExpression.create(exprs, steps, function_name))


def parse_interpolation_exprs(prompt, stoppers):
    """Parse the colon-separated expressions before the step list."""
    exprs = []

    try:
        while True:
            prompt_tmp, expr = parse_list_expression(prompt, {':', ']'})
            # Lookahead: if the next segment is the function name, this
            # expression actually belongs to the steps/function part, so
            # back out without consuming it.
            if parse_interpolation_function_name(prompt_tmp, stoppers).expr is not None:
                raise ValueError

            prompt, _ = parse_colon(prompt_tmp, stoppers)
            exprs.append(expr)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=exprs)


def parse_interpolation_function_name(prompt, stoppers):
    """Parse an optional `:name` where name is a known curve function.

    Returns expr=None (without consuming input) when absent.
    """
    try:
        prompt, _ = parse_colon(prompt, stoppers)
        function_names = ('linear', 'catmull', 'bezier', 'mean')
        return parse_token(prompt, whitespace_tail_regex('|'.join(function_names), stoppers))
    except ValueError:
        return ParseResult(prompt=prompt, expr=None)


def parse_interpolation_steps(prompt, stoppers):
    """Parse the comma-separated step values (each possibly omitted)."""
    steps = []

    try:
        while True:
            prompt, step = parse_interpolation_step(prompt, stoppers)
            steps.append(step)
            prompt, _ = parse_comma(prompt, stoppers)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=steps)


def parse_interpolation_step(prompt, stoppers):
    """Parse one step value; an omitted step (directly followed by a
    delimiter) yields expr=None.

    NOTE(review): `prompt[0]` raises IndexError on an empty prompt rather
    than ValueError — confirm callers never reach here with empty input.
    """
    try:
        return parse_step(prompt, stoppers)
    except ValueError:
        pass

    if prompt[0] in {',', ':', ']'}:
        return ParseResult(prompt=prompt, expr=None)

    raise ValueError
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def parse_alternation(prompt, stoppers):
    """Parse `[a|b|...:speed]` alternation syntax."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, exprs = parse_alternation_exprs(prompt, stoppers)
    prompt, speed = parse_alternation_speed(prompt, stoppers)
    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.AlternationExpression(exprs, speed))


def parse_alternation_exprs(prompt, stoppers):
    """Parse the |-separated alternatives; at least two are required,
    otherwise the pending ValueError is re-raised so another bracket
    construct can claim the input."""
    exprs = []

    try:
        while True:
            prompt, expr = parse_list_expression(prompt, {'|', ':', ']'})
            exprs.append(expr)
            prompt, _ = parse_vertical_bar(prompt, stoppers)
    except ValueError:
        if len(exprs) < 2:
            raise

    return ParseResult(prompt=prompt, expr=exprs)


def parse_alternation_speed(prompt, stoppers):
    """Parse an optional `:speed` suffix; absent → expr=None."""
    try:
        prompt, _ = parse_colon(prompt, stoppers)
        prompt, speed = parse_step(prompt, stoppers)
        return ParseResult(prompt=prompt, expr=speed)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=None)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def parse_editing(prompt, stoppers):
    """Parse `[from:to:step]` / `[to:step]` prompt editing syntax."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, exprs = parse_editing_exprs(prompt, stoppers)
    try:
        prompt, step = parse_step(prompt, stoppers)
    except ValueError:
        # Step may be omitted.
        step = None

    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.EditingExpression(exprs, step))


def parse_editing_exprs(prompt, stoppers):
    """Parse up to two colon-terminated expressions.

    Each expression is committed only after its trailing colon matches
    (prompt_tmp holds the tentative position until then).
    """
    exprs = []

    try:
        for _ in range(2):
            prompt_tmp, expr = parse_list_expression(prompt, {'|', ':', ']'})
            prompt, _ = parse_colon(prompt_tmp, stoppers)
            exprs.append(expr)
    except ValueError:
        pass

    return ParseResult(prompt=prompt, expr=exprs)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def parse_negative_attention(prompt, stoppers):
    """Parse `[expr]` — attention decreased by the webui's fixed factor."""
    prompt, _ = parse_open_square(prompt, stoppers)
    prompt, expr = parse_list_expression(prompt, set_concat(stoppers, {':', ']'}))
    prompt, _ = parse_close_square(prompt, stoppers)
    return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, positive=False))


def parse_positive_attention(prompt, stoppers):
    """Parse `(expr)`, `(expr:w)` or `(expr:w1,w2)`.

    Two or more weights produce a weight interpolation; zero or one weight
    produces a plain weighted expression (extra weights are ignored).
    """
    prompt, _ = parse_open_paren(prompt, stoppers)
    prompt, expr = parse_list_expression(prompt, {':', ')'})
    prompt, weight_exprs = parse_attention_weights(prompt, stoppers)
    prompt, _ = parse_close_paren(prompt, stoppers)
    if len(weight_exprs) >= 2:
        return ParseResult(prompt=prompt, expr=ast.WeightInterpolationExpression(expr, *weight_exprs[:2]))
    else:
        return ParseResult(prompt=prompt, expr=ast.WeightedExpression(expr, *weight_exprs[:1]))


def parse_attention_weights(prompt, stoppers):
    """Parse the optional `:w1,w2,...` weight list; no colon → []."""
    weights = []
    try:
        prompt, _ = parse_colon(prompt, stoppers)
    except ValueError:
        return ParseResult(prompt=prompt, expr=weights)

    while True:
        try:
            prompt, weight_expr = parse_weight(prompt, stoppers)
            weights.append(weight_expr)
            prompt, _ = parse_comma(prompt, stoppers)
        except ValueError:
            return ParseResult(prompt=prompt, expr=weights)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def parse_step(prompt, stoppers):
    """Parse a step value: integer first, then float, then a `$symbol`
    substitution as fallback."""
    try:
        prompt, step = parse_int_not_float(prompt, stoppers)
        return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
    except ValueError:
        pass
    try:
        prompt, step = parse_float(prompt, stoppers)
        return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
    except ValueError:
        pass

    return parse_substitution(prompt, stoppers)


def parse_weight(prompt, stoppers):
    """Parse an attention weight: a float literal or a `$symbol`
    substitution (unlike parse_step, no integer-only branch)."""
    try:
        prompt, step = parse_float(prompt, stoppers)
        return ParseResult(prompt=prompt, expr=ast.LiftExpression(step))
    except ValueError:
        pass

    return parse_substitution(prompt, stoppers)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def parse_symbol(prompt, stoppers):
    """Parse `$identifier` and return the identifier text."""
    prompt, _ = parse_dollar(prompt)
    return parse_symbol_text(prompt, stoppers)


def parse_symbol_text(prompt, stoppers):
    """Parse a bare identifier: letter or underscore, then alphanumerics."""
    return parse_token(prompt, whitespace_tail_regex('[a-zA-Z_][a-zA-Z0-9_]*', stoppers))


def parse_float(prompt, stoppers):
    """Parse a signed decimal number (integer or fractional form)."""
    return parse_token(prompt, whitespace_tail_regex(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)', stoppers))


def parse_int_not_float(prompt, stoppers):
    """Parse a signed integer that is not immediately followed by a dot
    (so `12.5` is left for parse_float)."""
    return parse_token(prompt, whitespace_tail_regex(r'[+-]?\d+(?!\.)', stoppers))


def parse_dollar(prompt):
    """Parse a literal `$` with no trailing whitespace consumption."""
    dollar_sign = re.escape('$')
    return parse_token(prompt, f'({dollar_sign})')
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# Single-character token parsers. Each matches its literal character plus
# trailing whitespace (as allowed by whitespace_tail_regex for the given
# stoppers) and returns the matched character.

def parse_equals(prompt, stoppers):
    """Parse a literal `=`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape('='), stoppers))


def parse_comma(prompt, stoppers):
    """Parse a literal `,`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape(','), stoppers))


def parse_colon(prompt, stoppers):
    """Parse a literal `:`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape(':'), stoppers))


def parse_vertical_bar(prompt, stoppers):
    """Parse a literal `|`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape('|'), stoppers))


def parse_open_square(prompt, stoppers):
    """Parse a literal `[`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape('['), stoppers))


def parse_close_square(prompt, stoppers):
    """Parse a literal `]`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape(']'), stoppers))


def parse_open_paren(prompt, stoppers):
    """Parse a literal `(`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape('('), stoppers))


def parse_close_paren(prompt, stoppers):
    """Parse a literal `)`."""
    return parse_token(prompt, whitespace_tail_regex(re.escape(')'), stoppers))


def parse_newline(prompt, stoppers):
    """Parse a newline; also matches end of input ($ in the regex)."""
    return parse_token(prompt, whitespace_tail_regex('\n|$', stoppers))
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def parse_token(prompt, regex):
    """Match `regex` at the start of `prompt`.

    Raises ValueError on no match (the module-wide backtracking signal).
    Returns the remaining prompt (whole match consumed, including any
    whitespace tail) and the LAST capture group — which is the payload
    group added by whitespace_tail_regex.
    """
    match = re.match(regex, prompt)
    if match is None:
        raise ValueError

    return ParseResult(prompt=prompt[len(match.group()):], expr=match.groups()[-1])
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def whitespace_tail_regex(regex, stoppers):
    r"""Wrap `regex` in a capture group and let it swallow trailing whitespace.

    When '\n' is a stopper, newlines must stay visible to the caller, so
    only horizontal whitespace is consumed; otherwise any whitespace is.
    """
    tail = r'[ \t\f\r]*' if '\n' in stoppers else r'\s*'
    return f'({regex}){tail}'


def set_concat(left, right):
    """Return a new set holding every element of `left` and of `right`."""
    return set(left) | set(right)
|
z-prompt-fusion-extension/lib_prompt_fusion/t_scaler.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def scale_t(t, positions):
    """Remap the interpolation parameter `t` in [0, 1] so that the segments
    between consecutive `positions` each cover an equal share of the output
    range, while `t` traverses them proportionally to their actual lengths.

    Values of `t` outside [0, 1] are clamped. Returns a float in [0, 1].
    """
    if t >= 1.:
        return 1.

    if t <= 0.:
        return 0.

    # Lengths of the segments between consecutive control positions.
    distances = [positions[i + 1] - positions[i] for i in range(len(positions) - 1)]
    total_distance = sum(distances)
    if total_distance == 0:
        # Degenerate schedule (fewer than two distinct positions): every
        # nonzero t has already passed the whole span, so snap to the end.
        # (The previous implementation divided by zero here.)
        return 1.

    # Cumulative breakpoints normalized to [0, 1]: distances[k] fraction of
    # the total span lies before breakpoint k+1.
    breakpoints = [0.0]
    running = 0.0
    for distance in distances:
        running += distance / total_distance
        breakpoints.append(running)

    # Find the segment containing t: largest breakpoint strictly below t.
    spline_index = 0
    for i, breakpoint_value in enumerate(breakpoints):
        if t > breakpoint_value:
            spline_index = i
        else:
            break

    if spline_index >= len(breakpoints) - 1:
        # t landed past the last breakpoint (float round-off): clamp.
        # Fixed to return a float for consistency with the other branches
        # (previously returned the int 1).
        return 1.

    # Linear position of t within its segment, then uniform re-spacing so
    # each segment maps to an equal slice of the output range.
    segment_span = breakpoints[spline_index + 1] - breakpoints[spline_index]
    local_ratio = (t - breakpoints[spline_index]) / segment_span
    return (spline_index + local_ratio) / (len(breakpoints) - 1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
    # Quick manual check: print the remapped t for every step of a
    # 20-step schedule with two control positions.
    total_steps = 20
    for i in range(total_steps):
        print(i, scale_t(i/total_steps, [9, 10]))
|
z-prompt-fusion-extension/metadata.ini
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[Extension]
|
| 2 |
+
Name = prompt-fusion
|
z-prompt-fusion-extension/readme.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Prompt Fusion
|
| 2 |
+
|
| 3 |
+
Prompt Fusion is an [auto1111 webui extension](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions) that adds more possibilities to the native prompt syntax. Among other additions, it allows you to interpolate continuously between the embeddings of different prompts:
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
# linear prompt interpolation
|
| 7 |
+
[night light:magical forest: 5, 15]
|
| 8 |
+
|
| 9 |
+
# catmull-rom curve prompt interpolation
|
| 10 |
+
[night light:magical forest:norvegian territory: 5, 15, 25:catmull]
|
| 11 |
+
|
| 12 |
+
# alternation interpolation
|
| 13 |
+
[ufo|a strange sight:0.5]
|
| 14 |
+
|
| 15 |
+
# linear attention interpolation
|
| 16 |
+
(fire extinguisher: 1.0, 2.0)
|
| 17 |
+
|
| 18 |
+
# prompt-editing-aware attention interpolation
|
| 19 |
+
[(fire extinguisher: 1.0, 2.0)::5]
|
| 20 |
+
|
| 21 |
+
# weighted sum of conditions
|
| 22 |
+
[space station : kitchen mixer :: mean]
|
| 23 |
+
|
| 24 |
+
# define functions and variables to simplify repeating patterns and use a consistent structure
|
| 25 |
+
$prompt($style, $quality, $character, $background) = (
|
| 26 |
+
A detailed picture in the style of $style,
|
| 27 |
+
$quality,
|
| 28 |
+
$character lying back,
|
| 29 |
+
$background in the background
|
| 30 |
+
:1)
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Features
|
| 34 |
+
|
| 35 |
+
- [Prompt interpolation using a curve function](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Interpolation)
|
| 36 |
+
- [Attention interpolation aware of contextual prompt editing](https://github.com/ljleb/prompt-fusion-extension/wiki/Attention-Interpolation)
|
| 37 |
+
- [Alternation interpolation](https://github.com/ljleb/prompt-fusion-extension/wiki/Alternation-interpolation)
|
| 38 |
+
- [Prompt weighted sum](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Average)
|
| 39 |
+
- [Prompt variables and functions](https://github.com/ljleb/prompt-fusion-extension/wiki/Prompt-Variables)
|
| 40 |
+
- Complete backwards compatibility with the builtin prompt syntax of the webui
|
| 41 |
+
|
| 42 |
+
The prompt interpolation feature is similar to [Prompt Travel](https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel), which allows to create videos of images generated by navigating the latent space iteratively. Unlike Prompt Travel however, instead of generating multiple images, Prompt Fusion allows you to travel during the sampling process of *a single image*. Also, instead of interpolating the latent space, it uses the embedding space to determine intermediate embeddings.
|
| 43 |
+
|
| 44 |
+
Prompt interpolation is also similar to [Prompt Blending](https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts). The main difference is that this extension calculates a new embedding for every step, as opposed to calculating it once and using that same one embedding for all the steps.
|
| 45 |
+
|
| 46 |
+
The attention interpolation feature is similar to [Shift Attention](https://github.com/yownas/shift-attention), which allows to generate multiple images with slight variations in the attention given to certain parts of the prompt. Unlike Shift Attention, instead of generating multiple images, Prompt Fusion allows to shift the attention of certain parts of a prompt during the sampling process of *a single image*.
|
| 47 |
+
|
| 48 |
+
## Usage
|
| 49 |
+
- Check the [wiki pages](https://github.com/ljleb/fusion/wiki) for the extension documentation.
|
| 50 |
+
|
| 51 |
+
## Examples
|
| 52 |
+
|
| 53 |
+
### 1. Influencing the beginning of the sampling process
|
| 54 |
+
|
| 55 |
+
Interpolate linearly (by default) from `lion` (step 0) to `bird` (step 8) to `girl` (step 11), and stay at `girl` for the rest of the sampling steps:
|
| 56 |
+
|
| 57 |
+
```
|
| 58 |
+
[lion:bird:girl: , 7, 10]
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+

|
| 62 |
+
|
| 63 |
+
### 2. Influencing the middle of the sampling process
|
| 64 |
+
|
| 65 |
+
Interpolate using a bezier curve from `fireball monster` (step 0) to `dragon monster` (step 12, because 30 steps * 0.4 = step 12), while using `seawater monster` as an intermediate control point to steer the curve away during interpolation and to get creative results:
|
| 66 |
+
|
| 67 |
+
```
|
| 68 |
+
[fireball:seawater:dragon: , .1, .4:bezier] monster
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+

|
| 72 |
+
|
| 73 |
+
## Webui supported releases
|
| 74 |
+
|
| 75 |
+
The following webui releases are officially supported:
|
| 76 |
+
- `v1.0.0-pre`
|
| 77 |
+
- `master` (fixes for issues introduced by rapid a1111 webui updates may lag slightly behind)
|
| 78 |
+
|
| 79 |
+
## Installation
|
| 80 |
+
1. Visit the **Extensions** tab of Automatic's WebUI.
|
| 81 |
+
2. Visit the **Available** subtab.
|
| 82 |
+
3. Look for **Prompt Fusion**.
|
| 83 |
+
4. Press the **Install** button.
|
| 84 |
+
5. Wait for the webui to finish downloading the extension.
|
| 85 |
+
6. Visit the **Installed** subtab.
|
| 86 |
+
7. Click on **Apply and restart UI**.
|
| 87 |
+
|
| 88 |
+
Alternatively, instead of steps 6 and 7, you can restart the webui completely.
|
| 89 |
+
|
| 90 |
+
## Related Projects
|
| 91 |
+
|
| 92 |
+
- Prompt Travel: https://github.com/Kahsolt/stable-diffusion-webui-prompt-travel
|
| 93 |
+
- Shift Attention: https://github.com/yownas/shift-attention
|
| 94 |
+
- Prompt Blending: https://github.com/amotile/stable-diffusion-backend/tree/master/src/process/implementations/automatic1111_scripts
|
| 95 |
+
|
z-prompt-fusion-extension/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
z-prompt-fusion-extension/scripts/__pycache__/promptlang.cpython-310.pyc
ADDED
|
Binary file (8.6 kB). View file
|
|
|
z-prompt-fusion-extension/scripts/promptlang.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from typing import List, Any
|
| 3 |
+
|
| 4 |
+
from lib_prompt_fusion import (
|
| 5 |
+
hijacker,
|
| 6 |
+
empty_cond,
|
| 7 |
+
global_state,
|
| 8 |
+
interpolation_tensor,
|
| 9 |
+
prompt_parser as prompt_fusion_parser,
|
| 10 |
+
)
|
| 11 |
+
from modules import scripts, script_callbacks, prompt_parser, shared
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# -----------------------------------------------------------------------------
|
| 15 |
+
# UI options
|
| 16 |
+
# -----------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def on_ui_settings():
    """Register this extension's options on the webui settings page.

    Options are registered in a fixed order under the "Prompt Fusion"
    section so they appear grouped together in the settings UI.
    """
    section = ("prompt-fusion", "Prompt Fusion")

    # (key, default, label, extra OptionInfo kwargs)
    option_specs = (
        (
            "prompt_fusion_enabled",
            True,
            "Enable prompt-fusion extension",
            {},
        ),
        (
            "prompt_fusion_slerp_scale",
            0.0,
            "Slerp scale (0 = linear geometry, 1 = slerp geometry)",
            {"component": gr.Number},
        ),
        (
            "prompt_fusion_slerp_negative_origin",
            True,
            "Use negative prompt schedule as slerp origin",
            {"component": gr.Checkbox},
        ),
        (
            "prompt_fusion_slerp_epsilon",
            1e-4,
            "Slerp epsilon (similarity clamp; too similar/too orthogonal -> fallback to linear)",
            {"component": gr.Number},
        ),
    )

    for key, default, label, extra in option_specs:
        shared.opts.add_option(
            key,
            shared.OptionInfo(default, label, section=section, **extra),
        )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Register the settings-page options with the webui at import time.
script_callbacks.on_ui_settings(on_ui_settings)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# -----------------------------------------------------------------------------
|
| 57 |
+
# Hijacker installation
|
| 58 |
+
# -----------------------------------------------------------------------------
|
| 59 |
+
|
| 60 |
+
# Marker attribute stored on the hijacked module so the hijacker is only
# installed once, even if this script is imported multiple times.
fusion_hijacker_attribute = "__fusion_hijacker"
# Install (or reuse) the hijacker over the webui's prompt_parser module.
# The uninstall hook is registered on script unload so the original
# functions are restored when the extension is removed/reloaded.
prompt_parser_hijacker = hijacker.ModuleHijacker.install_or_get(
    module=prompt_parser,
    hijacker_attribute=fusion_hijacker_attribute,
    register_uninstall=script_callbacks.on_script_unloaded,
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# -----------------------------------------------------------------------------
|
| 69 |
+
# Helpers
|
| 70 |
+
# -----------------------------------------------------------------------------
|
| 71 |
+
|
| 72 |
+
def _wrap_cond_any(raw_cond):
    """Wrap a raw cond in a wrapper type understood by the interpolation logic.

    A dict cond (e.g. ``dict[str, Tensor]``) becomes a ``DictCondWrapper``,
    anything else becomes a ``TensorCondWrapper``. The value is treated as a
    single cond — no indexing is performed here.
    """
    wrapper_cls = (
        interpolation_tensor.DictCondWrapper
        if isinstance(raw_cond, dict)
        else interpolation_tensor.TensorCondWrapper
    )
    return wrapper_cls(raw_cond)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _adapt_flattened_schedules(result: Any, total_steps: int):
    """
    Normalize the output of prompt_parser.get_learned_conditioning() to
    List[List[ScheduledPromptConditioning]].

    Supports:
    - an already-flattened list of schedule lists (returned unchanged);
    - a MulticondLearnedConditioning-like object with a .batch attribute
      (subprompts are combined as a weighted sum per boundary step);
    - a single cond (wrapped into one schedule spanning the whole range).
    """
    # 1) Already the format we need
    if isinstance(result, list) and (len(result) == 0 or isinstance(result[0], list)):
        return result

    # 2) Duck-typed detection of MulticondLearnedConditioning via .batch
    batch = getattr(result, "batch", None)
    if batch is not None:
        adapted = []
        # batch: List[List[ComposableScheduledPromptConditioning]]
        for composables in batch:
            # Collect every boundary step from every subcomponent's schedules.
            boundaries = set()
            for comp in composables:
                for entry in getattr(comp, "schedules", []) or []:
                    try:
                        boundaries.add(int(entry.end_at_step))
                    except Exception:
                        # Malformed entry: fall back to the final step.
                        boundaries.add(total_steps - 1)
            if not boundaries:
                boundaries = {total_steps - 1}
            sorted_bounds = sorted(boundaries)

            # Resolve the cond active at `step` for one composable.
            # NOTE(review): if `step` is past every boundary, this returns the
            # FIRST schedule's cond (idx stays 0); returning the last entry
            # would be the more usual fallback — verify intent.
            def cond_for_step(comp, step):
                # comp.schedules: List[ScheduledPromptConditioning]
                idx = 0
                for i, entry in enumerate(comp.schedules):
                    try:
                        end_at = int(entry.end_at_step)
                    except Exception:
                        end_at = total_steps - 1
                    if step <= end_at:
                        idx = i
                        break
                return comp.schedules[idx].cond

            # Weighted accumulation of conds; supports dict or Tensor values.
            def add_scaled(dst, src, w):
                if dst is None:
                    return {k: v * w for k, v in src.items()} if isinstance(src, dict) else src * w
                if isinstance(src, dict):
                    out = dict(dst)
                    for k, v in src.items():
                        out[k] = out.get(k, 0) + v * w
                    return out
                return dst + src * w

            sched = []
            for end_at in sorted_bounds:
                combined = None
                for comp in composables:
                    w = float(getattr(comp, "weight", 1.0))
                    combined = add_scaled(combined, cond_for_step(comp, end_at), w)
                sched.append(
                    prompt_parser.ScheduledPromptConditioning(
                        end_at_step=end_at, cond=combined
                    )
                )
            adapted.append(sched)
        return adapted

    # 3) Fallback: a single cond -> one schedule covering the whole step range
    return [[prompt_parser.ScheduledPromptConditioning(end_at_step=total_steps - 1, cond=result)]]
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _build_tensor_for_prompt(
    model,
    prompt_text: str,
    total_steps: int,
    is_hires: bool,
    use_old_scheduling: bool,
    original_function,
    *,
    original_prompts_proto=None,
):
    """
    Parse one prompt text with the extension's AST, build the flattened
    prompt database, encode it through the stock parser (original_function),
    normalize the resulting schedules, and assemble the interpolation tensor.

    Returns the final List[ScheduledPromptConditioning] for this prompt,
    with one segment per distinct interpolated cond (equal adjacent steps
    are merged).
    """
    # 1) Build a TensorBuilder from the parsed AST
    tensor_builder = interpolation_tensor.InterpolationTensorBuilder()
    expr = prompt_fusion_parser.parse_prompt(prompt_text)
    expr.extend_tensor(
        tensor_builder,
        [0, max(0, total_steps - 1)],
        total_steps,
        {},
        is_hires,
        use_old_scheduling,
    )

    # 2) Get the flat prompt database and encode it into conds
    flat_prompts = tensor_builder.get_prompt_database()

    # Preserve A1111 behavior: some webui versions attach an
    # `is_negative_prompt` flag to the prompts object. Mirror it onto a
    # list subclass so the original encoder sees the same flag.
    class _SdLike(list):
        pass

    if original_prompts_proto is not None and hasattr(original_prompts_proto, "is_negative_prompt"):
        sd_like = _SdLike(flat_prompts)
        sd_like.is_negative_prompt = original_prompts_proto.is_negative_prompt
        flat_input = sd_like
    else:
        flat_input = flat_prompts

    flattened = original_function(model, flat_input, total_steps)
    flattened = _adapt_flattened_schedules(flattened, total_steps)

    # 3) Wrap the raw cond values for the interpolation logic
    wrapped_conds: List[List[prompt_parser.ScheduledPromptConditioning]] = []
    for sched in flattened:
        wrapped_sched = [
            prompt_parser.ScheduledPromptConditioning(
                end_at_step=entry.end_at_step,
                cond=_wrap_cond_any(entry.cond),
            )
            for entry in sched
        ]
        wrapped_conds.append(wrapped_sched)

    # 4) Build the interpolation tensor and evaluate it per step
    tensor = tensor_builder.build(wrapped_conds, empty_cond.get())

    slerp_scale = global_state.get_slerp_scale()
    slerp_epsilon = global_state.get_slerp_epsilon()

    # Build the output schedule, merging equal adjacent segments.
    schedules: List[prompt_parser.ScheduledPromptConditioning] = []
    prev_wrapper = None
    for step in range(total_steps):
        params = interpolation_tensor.InterpolationParams(
            # t in [0, 1]; max(1, ...) guards against division by zero
            # when total_steps == 1.
            t=step / max(1, total_steps - 1),
            step=step,
            total_steps=total_steps,
            slerp_scale=slerp_scale,
            slerp_epsilon=slerp_epsilon,
        )
        origin = global_state.get_origin_cond_at(step, is_hires=is_hires)
        cond_wrapper = tensor.interpolate(params, origin, empty_cond.get())

        # Merge with the previous segment when equal (via wrapper.__eq__).
        # NOTE(review): this mutates `end_at_step` in place — assumes
        # ScheduledPromptConditioning is mutable (not a NamedTuple) in the
        # targeted webui versions; confirm.
        if prev_wrapper is not None and cond_wrapper == prev_wrapper:
            schedules[-1].end_at_step = step
        else:
            raw = cond_wrapper.original_cond if hasattr(cond_wrapper, "original_cond") else cond_wrapper
            schedules.append(
                prompt_parser.ScheduledPromptConditioning(end_at_step=step, cond=raw)
            )
        prev_wrapper = cond_wrapper

    # Extend the last segment to cover the final step.
    if schedules:
        schedules[-1].end_at_step = total_steps - 1

    return schedules
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# -----------------------------------------------------------------------------
|
| 255 |
+
# Hooks
|
| 256 |
+
# -----------------------------------------------------------------------------
|
| 257 |
+
|
| 258 |
+
@prompt_parser_hijacker.hijack("get_learned_conditioning")
def _fusion_get_learned_conditioning(
    model,
    prompts,
    total_steps: int,
    *args,
    original_function=None,
    **kwargs,
):
    """
    Main hook over prompt_parser.get_learned_conditioning.

    If the extension is disabled, the original function is called unchanged.
    Otherwise, a fusion interpolation tensor is built on top of the stock
    parser's encoding, one schedule list per batch item.
    """
    # Disabled -> strictly proxy the original
    if not shared.opts.data.get("prompt_fusion_enabled", True):
        return original_function(model, prompts, total_steps, *args, **kwargs)

    # Extract the hires flag / step-interpretation mode
    # (compatibility across A1111 versions with differing signatures).
    hires_steps = None
    use_old_scheduling = True
    if args:
        # Common positional order: (hires_steps, use_old_scheduling, ...).
        # NOTE(review): if only one extra positional arg is present, the
        # IndexError is swallowed and BOTH values keep their defaults.
        try:
            hires_steps, use_old_scheduling = args[0], args[1]
        except Exception:
            pass

    is_hires = hires_steps is not None
    real_total_steps = hires_steps if is_hires else total_steps

    # Determine whether this call is for the negative prompt.
    if hasattr(prompts, "is_negative_prompt"):
        is_negative_prompt = bool(prompts.is_negative_prompt)
    else:
        # Older builds: the first call is treated as negative,
        # the second as positive (see the multicond hook below).
        is_negative_prompt = global_state.old_webui_is_negative

    # Initialize the "empty" cond used as interpolation filler.
    empty_cond.init(model)

    # Output for the whole batch.
    out: List[List[prompt_parser.ScheduledPromptConditioning]] = []

    # Process each batch item independently.
    for prompt_text in prompts:
        sched = _build_tensor_for_prompt(
            model=model,
            prompt_text=prompt_text,
            total_steps=real_total_steps,
            is_hires=is_hires,
            use_old_scheduling=use_old_scheduling,
            original_function=original_function,
            original_prompts_proto=prompts,
        )
        out.append(sched)

    # Negative schedules are remembered so they can serve as slerp origin
    # for the following positive call.
    if is_negative_prompt:
        if is_hires:
            global_state.negative_schedules_hires = out[0] if out else None
        else:
            global_state.negative_schedules = out[0] if out else None

    return out
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@prompt_parser_hijacker.hijack("get_multicond_learned_conditioning")
def _fusion_get_multicond_learned_conditioning(
    *args,
    original_function=None,
    **kwargs,
):
    """Compatibility shim for the legacy A1111 call order.

    After this method runs, the next get_learned_conditioning call is no
    longer treated as the "negative" one. Arguments and the return value
    pass through unchanged.
    """
    # Flip the legacy negative-call flag, then delegate untouched.
    global_state.old_webui_is_negative = False
    result = original_function(*args, **kwargs)
    return result
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# -----------------------------------------------------------------------------
|
| 340 |
+
# WebUI script shim (чтобы сбрасывать «негативность» и кэш между генерациями)
|
| 341 |
+
# -----------------------------------------------------------------------------
|
| 342 |
+
|
| 343 |
+
class PromptFusionScript(scripts.Script):
    """Webui script shim that resets fusion state between generations."""

    def title(self):
        """Name shown in the webui script list."""
        return "Prompt Fusion"

    def show(self, is_img2img):
        """Stay active in both txt2img and img2img tabs."""
        return scripts.AlwaysVisible

    def process(self, p, *args):
        """Reset cached negative schedules before a generation starts.

        Before the pipeline begins, the next get_learned_conditioning call
        is assumed to be the "negative" one (legacy webui call order).
        """
        for cache_attr in ("negative_schedules", "negative_schedules_hires"):
            setattr(global_state, cache_attr, None)
        global_state.old_webui_is_negative = True
|
z-prompt-fusion-extension/test/parser_tests.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from lib_prompt_fusion.prompt_parser import parse_prompt
|
| 2 |
+
from lib_prompt_fusion.interpolation_tensor import InterpolationTensorBuilder
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def run_functional_tests(total_steps=100):
    """Parse every test prompt and check its flattened prompt database.

    An expected value that is a set means several flattened prompts are
    expected (order-insensitive); a string means exactly one prompt.
    """
    for given, expected in functional_parse_test_cases:
        builder = InterpolationTensorBuilder()
        parse_prompt(given).extend_tensor(
            builder, (0, total_steps), total_steps, dict(),
            is_hires=False, use_old_scheduling=False,
        )

        actual = builder.get_prompt_database()

        if isinstance(expected, set):
            assert set(actual) == expected, f"{actual} != {expected}"
        else:
            assert len(actual) == 1 and actual[0] == expected, f"'{actual[0]}' != '{expected}'"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# (given prompt, expected flattened prompt database) pairs consumed by
# run_functional_tests. `(s,)*2` means the prompt must flatten to itself;
# a set means several flattened prompts are expected (order-insensitive).
functional_parse_test_cases = [
    ('single',)*2,
    ('some space separated text',)*2,
    ('(legacy weighted prompt:-2.1)',)*2,
    ('mixed (legacy weight:3.6) and text',)*2,
    ('legacy [range begin:0] thingy',)*2,
    ('legacy [range end::3] thingy',)*2,
    ('legacy [[nested range::3]:2] thingy',)*2,
    ('legacy [[nested range:2]::3] thingy',)*2,
    ('sugar [range:,abc:3] thingy',)*2,
    ('sugar [[(weight interpolation:0,12):0]::1] thingy', 'sugar [[(weight interpolation:0.0):0]::1] thingy'),
    ('sugar [[(weight interpolation:0,12):0]::2] thingy', 'sugar [[[(weight interpolation:0.0)::1][(weight interpolation:12.0):1]:0]::2] thingy'),
    ('sugar [[(weight interpolation:0,12):0]::3] thingy', 'sugar [[[(weight interpolation:0.0)::1][[(weight interpolation:6.0):1]::2][(weight interpolation:12.0):2]:0]::3] thingy'),
    ('legacy [from:to:2] thingy',)*2,
    ('legacy [negative weight]',)*2,
    ('legacy (positive weight)',)*2,
    ('[abc:1girl:2]',)*2,
    ('[::]',)*2,
    ('[a:b:]',)*2,
    ('[[a:b:1,2]:b:]', {'[a:b:]', '[b:b:]'}),
    ('1girl',)*2,
    ('dashes-in-text',)*2,
    ('text, separated with, comas',)*2,
    ('{prompt}',)*2,
    ('[abc|def ghi|jkl]',)*2,
    ('merging this AND with this',)*2,
    (':',)*2,
    (r'portrait \(object\)',)*2,
    (r'\[escaped square\]',)*2,
    (r'\$var = abc',)*2,
    (r'\\$ arst',)*2,
    (r'$$ arst',)*2,
    ('$var = abc', ''),
    ('$a = prompt value\n$a', 'prompt value'),
    ('$a = prompt value\n$b = $a\n$b', 'prompt value'),
    ('$a = (multiline\nprompt\nvalue:1.0)\n$a', '(multiline prompt value:1.0)'),
    ('$a = ($aa = nested variable\nmultiline\n$aa:1.0)\n$a', '(multiline nested variable:1.0)'),
    ('a [b:c:-1, 10] d', {'a b d', 'a c d'}),
    ('a [b:c:5, 6] d', {'a b d', 'a c d'}),
    ('a [b:c:0.25, 0.5] d', {'a b d', 'a c d'}),
    ('a [b:c:.25, .5] d', {'a b d', 'a c d'}),
    ('a [b:c:,] d', {'a b d', 'a c d'}),
    ('0[1.0:1.1:,]2[3.0:3.1:,]4', {
        '0 1.0 2 3.0 4', '0 1.1 2 3.0 4',
        '0 1.0 2 3.1 4', '0 1.1 2 3.1 4',
    }),
    ('0[1.0:1.1:1.2:,.5,]2[3.0:3.1:,]4', {
        '0 1.0 2 3.0 4', '0 1.0 2 3.1 4',
        '0 1.1 2 3.0 4', '0 1.1 2 3.1 4',
        '0 1.2 2 3.0 4', '0 1.2 2 3.1 4',
    }),
    ('[0.0:0.1:,][1.0:1.1:,][2.0:2.1:,]', {
        '0.0 1.0 2.0', '0.0 1.0 2.1',
        '0.1 1.0 2.0', '0.1 1.0 2.1',
        '0.0 1.1 2.0', '0.0 1.1 2.1',
        '0.1 1.1 2.0', '0.1 1.1 2.1',
    }),
    ('[top level:interpolatin:lik a pro:1,3,5:linear]', {'top level', 'interpolatin', 'lik a pro'}),
    ('[[nested:expr:,]:abc:,]', {'nested', 'expr', 'abc'}),
    ('[(nested attention:2.0):abc:,]', {'(nested attention:2.0)', 'abc'}),
    ('[[nested editing:15]:abc:,]', {'[nested editing:15]', 'abc'}),
    ('[[nested interpolation:abc:,]:12]', {'[nested interpolation:12]', '[abc:12]'}),
    ('[[nested interpolation:abc:,]::7]', {'[nested interpolation::7]', '[abc::7]'}),
    ('$attention = 1.5\n(prompt:$attention)', '(prompt:1.5)'),
    ('$a = 0\n$b = 12\n[[(prompt:$a,$b):0]::2]', '[[[(prompt:0.0)::1][(prompt:12.0):1]:0]::2]'),
    ('$step = 5\n[legacy:editing:$step]', '[legacy:editing:5]'),
    ('$begin = 2\n$end = 7\n[prompt:interpolation:$begin, $end]', {'prompt', 'interpolation'}),
    ('$a($b, $c) = prompt with $b, prompt with $c\n$a(cat, dog)', 'prompt with cat , prompt with dog'),
    ('$a($b) = prompt with $b\n$c($d) = yeay $a($d)\n$c(dog)', 'yeay prompt with dog'),
    ('$a = a lot of animals\n$b($c) = I love $c\n$b($a)', 'I love a lot of animals'),
    ('$a($b) = prompt with $b\n$c($d) = yeay $d\n$a($c(dog))', 'prompt with yeay dog'),
    ('[a|b|c]', '[a|b|c]'),
    ('[a|b|c:]', '[a|b|c]'),
    ('[a|b|c:1]', {'a', 'b', 'c'}),
    ('[a|b|c:2]', {'a', 'b', 'c'}),
    ('[a|b|c:0.5]', {'a', 'b', 'c'}),
    ('[a|b|c:1.1]', {'a', 'b', 'c'}),
    ('[[[Imperial Yellow|Amber]:[Ruby|Plum|Bronze]:9]::39]',)*2,
    ('[a:b:c::mean]', {'a', 'b', 'c'}),
    ('[a:b:c:,,:mean]', {'a', 'b', 'c'}),
    ('[a:b:c: 1, 2, 3:mean]', {'a', 'b', 'c'}),
]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def run_tests():
    """Entry point: run every parser test suite in this module."""
    run_functional_tests()
|
z-prompt-fusion-extension/test/run_all.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
from pathlib import Path

# Make both the test directory (for `parser_tests`) and the extension root
# (for `lib_prompt_fusion`) importable regardless of the current working
# directory. The previous `sys.path.append('..')` was CWD-relative, so the
# runner only worked when launched from inside the test/ directory.
_TEST_DIR = Path(__file__).resolve().parent
sys.path.append(str(_TEST_DIR))
sys.path.append(str(_TEST_DIR.parent))

import parser_tests


if __name__ == '__main__':
    parser_tests.run_tests()
|