Commit ·
d0bd9ea
1
Parent(s): 3a509a0
Upload 3 files
Browse files- VCM07_style.pt +3 -0
- VCM07_style2.pt +3 -0
- prompt_blending.py +183 -0
VCM07_style.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8de463223c45d273ec77449808d47bce0b6987678ccc71cf3d413beba6ad3a17
|
| 3 |
+
size 25515
|
VCM07_style2.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2f1aec4732e93aa0943a30f1d8c8ec666abc9e94b0d0136d43c58740bd3d510f
|
| 3 |
+
size 25515
|
prompt_blending.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import modules.scripts as scripts
|
| 2 |
+
import modules.prompt_parser as prompt_parser
|
| 3 |
+
import itertools
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def hijacked_get_learned_conditioning(model, prompts, steps):
|
| 8 |
+
global real_get_learned_conditioning
|
| 9 |
+
|
| 10 |
+
if not hasattr(model, '__hacked'):
|
| 11 |
+
real_model_func = model.get_learned_conditioning
|
| 12 |
+
|
| 13 |
+
def hijacked_model_func(texts):
|
| 14 |
+
weighted_prompts = list(map(lambda t: get_weighted_prompt((t, 1)), texts))
|
| 15 |
+
all_texts = []
|
| 16 |
+
for weighted_prompt in weighted_prompts:
|
| 17 |
+
for (prompt, weight) in weighted_prompt:
|
| 18 |
+
all_texts.append(prompt)
|
| 19 |
+
|
| 20 |
+
if len(all_texts) > len(texts):
|
| 21 |
+
all_conds = real_model_func(all_texts)
|
| 22 |
+
offset = 0
|
| 23 |
+
|
| 24 |
+
conds = []
|
| 25 |
+
|
| 26 |
+
for weighted_prompt in weighted_prompts:
|
| 27 |
+
c = torch.zeros_like(all_conds[offset])
|
| 28 |
+
for (i, (prompt, weight)) in enumerate(weighted_prompt):
|
| 29 |
+
c = torch.add(c, all_conds[i+offset], alpha=weight)
|
| 30 |
+
conds.append(c)
|
| 31 |
+
offset += len(weighted_prompt)
|
| 32 |
+
return conds
|
| 33 |
+
else:
|
| 34 |
+
return real_model_func(texts)
|
| 35 |
+
|
| 36 |
+
model.get_learned_conditioning = hijacked_model_func
|
| 37 |
+
model.__hacked = True
|
| 38 |
+
|
| 39 |
+
switched_prompts = list(map(lambda p: switch_syntax(p), prompts))
|
| 40 |
+
return real_get_learned_conditioning(model, switched_prompts, steps)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
real_get_learned_conditioning = hijacked_get_learned_conditioning # no really, overriden below
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Script(scripts.Script):
|
| 47 |
+
def title(self):
|
| 48 |
+
return "Prompt Blending"
|
| 49 |
+
|
| 50 |
+
def show(self, is_img2img):
|
| 51 |
+
global real_get_learned_conditioning
|
| 52 |
+
if real_get_learned_conditioning == hijacked_get_learned_conditioning:
|
| 53 |
+
real_get_learned_conditioning = prompt_parser.get_learned_conditioning
|
| 54 |
+
prompt_parser.get_learned_conditioning = hijacked_get_learned_conditioning
|
| 55 |
+
return False
|
| 56 |
+
|
| 57 |
+
def ui(self, is_img2img):
|
| 58 |
+
return []
|
| 59 |
+
|
| 60 |
+
def run(self, p, seeds):
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
OPEN = '{'
|
| 65 |
+
CLOSE = '}'
|
| 66 |
+
SEPARATE = '|'
|
| 67 |
+
MARK = '@'
|
| 68 |
+
REAL_MARK = ':'
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def combine(left, right):
|
| 72 |
+
return map(lambda p: (p[0][0] + p[1][0], p[0][1] * p[1][1]), itertools.product(left, right))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_weighted_prompt(prompt_weight):
|
| 76 |
+
(prompt, full_weight) = prompt_weight
|
| 77 |
+
results = [('', full_weight)]
|
| 78 |
+
alts = []
|
| 79 |
+
start = 0
|
| 80 |
+
mark = -1
|
| 81 |
+
open_count = 0
|
| 82 |
+
first_open = 0
|
| 83 |
+
nested = False
|
| 84 |
+
|
| 85 |
+
for i, c in enumerate(prompt):
|
| 86 |
+
add_alt = False
|
| 87 |
+
do_combine = False
|
| 88 |
+
if c == OPEN:
|
| 89 |
+
open_count += 1
|
| 90 |
+
if open_count == 1:
|
| 91 |
+
first_open = i
|
| 92 |
+
results = list(combine(results, [(prompt[start:i], 1)]))
|
| 93 |
+
start = i + 1
|
| 94 |
+
else:
|
| 95 |
+
nested = True
|
| 96 |
+
|
| 97 |
+
if c == MARK and open_count == 1:
|
| 98 |
+
mark = i
|
| 99 |
+
|
| 100 |
+
if c == SEPARATE and open_count == 1:
|
| 101 |
+
add_alt = True
|
| 102 |
+
|
| 103 |
+
if c == CLOSE:
|
| 104 |
+
open_count -= 1
|
| 105 |
+
if open_count == 0:
|
| 106 |
+
add_alt = True
|
| 107 |
+
do_combine = True
|
| 108 |
+
if i == len(prompt) - 1 and open_count > 0:
|
| 109 |
+
add_alt = True
|
| 110 |
+
do_combine = True
|
| 111 |
+
|
| 112 |
+
if add_alt:
|
| 113 |
+
end = i
|
| 114 |
+
weight = 1
|
| 115 |
+
if mark != -1:
|
| 116 |
+
weight_str = prompt[mark + 1:i]
|
| 117 |
+
try:
|
| 118 |
+
weight = float(weight_str)
|
| 119 |
+
end = mark
|
| 120 |
+
except ValueError:
|
| 121 |
+
print("warning, not a number:", weight_str)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
alt = (prompt[start:end], weight)
|
| 126 |
+
alts += get_weighted_prompt(alt) if nested else [alt]
|
| 127 |
+
nested = False
|
| 128 |
+
mark = -1
|
| 129 |
+
start = i + 1
|
| 130 |
+
|
| 131 |
+
if do_combine:
|
| 132 |
+
if len(alts) <= 1:
|
| 133 |
+
alts = [(prompt[first_open:i + 1], 1)]
|
| 134 |
+
|
| 135 |
+
results = list(combine(results, alts))
|
| 136 |
+
alts = []
|
| 137 |
+
|
| 138 |
+
# rest of the prompt
|
| 139 |
+
results = list(combine(results, [(prompt[start:], 1)]))
|
| 140 |
+
weight_sum = sum(map(lambda r: r[1], results))
|
| 141 |
+
results = list(map(lambda p: (p[0], p[1] / weight_sum * full_weight), results))
|
| 142 |
+
|
| 143 |
+
return results
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def switch_syntax(prompt):
|
| 147 |
+
p = list(prompt)
|
| 148 |
+
stack = []
|
| 149 |
+
for i, c in enumerate(p):
|
| 150 |
+
if c == '{' or c == '[' or c == '(':
|
| 151 |
+
stack.append(c)
|
| 152 |
+
|
| 153 |
+
if len(stack) > 0:
|
| 154 |
+
if c == '}' or c == ']' or c == ')':
|
| 155 |
+
stack.pop()
|
| 156 |
+
|
| 157 |
+
if c == REAL_MARK and stack[-1] == '{':
|
| 158 |
+
p[i] = MARK
|
| 159 |
+
|
| 160 |
+
return "".join(p)
|
| 161 |
+
|
| 162 |
+
# def test(p, w=1):
|
| 163 |
+
# print('')
|
| 164 |
+
# print(p)
|
| 165 |
+
# result = get_weighted_prompt((p, w))
|
| 166 |
+
# print(result)
|
| 167 |
+
# print(sum(map(lambda x: x[1], result)))
|
| 168 |
+
#
|
| 169 |
+
#
|
| 170 |
+
# test("fantasy landscape")
|
| 171 |
+
# test("fantasy {landscape|city}, dark")
|
| 172 |
+
# test("fantasy {landscape|city}, {fire|ice} ")
|
| 173 |
+
# test("fantasy {landscape|city}, {fire|ice}, {dark|light} ")
|
| 174 |
+
# test("fantasy landscape, {{fire|lava}|ice}")
|
| 175 |
+
# test("fantasy landscape, {{fire@4|lava@1}|ice@2}")
|
| 176 |
+
# test("fantasy landscape, {{fire@error|lava@1}|ice@2}")
|
| 177 |
+
# test("fantasy landscape, {{fire|lava}|ice@2")
|
| 178 |
+
# test("fantasy landscape, {fire|lava} {cool} {ice,water}")
|
| 179 |
+
# test("fantasy landscape, {fire|lava} {cool} {ice,water")
|
| 180 |
+
# test("{lava|ice|water@5}")
|
| 181 |
+
# test("{fire@4|lava@1}", 5)
|
| 182 |
+
# test("{{fire@4|lava@1}|ice@2|water@5}")
|
| 183 |
+
# test("{fire|lava@3.5}")
|