Leeps committed on
Commit
b1ced3d
·
1 Parent(s): 1cb7498

Add triangle prompt embedding mixer

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. app.py +183 -1
README.md CHANGED
@@ -22,6 +22,7 @@ A ZeroGPU-ready Gradio app for learning how Stable Diffusion works inside Diffus
22
  Instead of only calling `pipe(prompt)`, the app exposes a small custom denoising loop:
23
 
24
  - prompt embeddings can be used directly, averaged, or combined with vector arithmetic
 
25
  - initial latent noise can come from one seed or a blend of two seeds
26
  - classifier-free guidance can use standard CFG or a student-edited equation
27
  - intermediate latent snapshots show how the image emerges across denoising steps
 
22
  Instead of only calling `pipe(prompt)`, the app exposes a small custom denoising loop:
23
 
24
  - prompt embeddings can be used directly, averaged, or combined with vector arithmetic
25
+ - three prompt embeddings can be explored as a clickable triangle between concepts
26
  - initial latent noise can come from one seed or a blend of two seeds
27
  - classifier-free guidance can use standard CFG or a student-edited equation
28
  - intermediate latent snapshots show how the image emerges across denoising steps
app.py CHANGED
@@ -40,6 +40,13 @@ prompt_c, _ = encode_prompt(prompt_c, negative_prompt)
40
 
41
  if mode == "average":
42
  prompt_embeds = (1 - mix) * prompt_a + mix * prompt_b
 
 
 
 
 
 
 
43
  elif mode == "analogy":
44
  prompt_embeds = prompt_a + strength * (prompt_b - prompt_c)
45
  else:
@@ -120,6 +127,128 @@ def blank_image(message="Run generation to make an image."):
120
  return image
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  @lru_cache(maxsize=2)
124
  def load_pipe(model_id, scheduler_name, device_type):
125
  device = torch.device(device_type)
@@ -185,6 +314,9 @@ def mix_prompt_embeddings(
185
  negative_prompt,
186
  embedding_mode,
187
  prompt_mix,
 
 
 
188
  analogy_strength,
189
  renormalize_prompt,
190
  ):
@@ -195,6 +327,17 @@ def mix_prompt_embeddings(
195
  if embedding_mode == "Average prompt A and B":
196
  mixed = (1.0 - prompt_mix) * emb_a + prompt_mix * emb_b
197
  formula = f"prompt = {(1.0 - prompt_mix):.2f} * A + {prompt_mix:.2f} * B"
 
 
 
 
 
 
 
 
 
 
 
198
  elif embedding_mode == "Vector arithmetic: A + s * (B - C)":
199
  mixed = emb_a + analogy_strength * (emb_b - emb_c)
200
  formula = f"prompt = A + {analogy_strength:.2f} * (B - C)"
@@ -212,6 +355,7 @@ def mix_prompt_embeddings(
212
  ["cosine(A, B)", round(cosine_similarity(emb_a, emb_b), 4)],
213
  ["cosine(A, mixed)", round(cosine_similarity(emb_a, mixed), 4)],
214
  ["cosine(B, mixed)", round(cosine_similarity(emb_b, mixed), 4)],
 
215
  ["norm(A)", round(float(original_norm.cpu()), 3)],
216
  ["norm(mixed)", round(float(mixed.detach().float().norm().cpu()), 3)],
217
  ]
@@ -316,6 +460,9 @@ def generate(
316
  negative_prompt,
317
  embedding_mode,
318
  prompt_mix,
 
 
 
319
  analogy_strength,
320
  renormalize_prompt,
321
  seed_a,
@@ -356,6 +503,9 @@ def generate(
356
  negative_prompt or "",
357
  embedding_mode,
358
  float(prompt_mix),
 
 
 
359
  float(analogy_strength),
360
  bool(renormalize_prompt),
361
  )
@@ -486,12 +636,28 @@ def build_app():
486
  [
487
  "Prompt A only",
488
  "Average prompt A and B",
 
489
  "Vector arithmetic: A + s * (B - C)",
490
  ],
491
- value="Average prompt A and B",
492
  label="Embedding equation",
493
  )
494
  prompt_mix = gr.Slider(0, 1, value=0.5, step=0.05, label="Prompt B weight")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  analogy_strength = gr.Slider(-2, 2, value=0.8, step=0.1, label="Vector arithmetic strength")
496
  renormalize_prompt = gr.Checkbox(value=True, label="Keep mixed prompt embedding norm near prompt A")
497
 
@@ -569,6 +735,19 @@ def build_app():
569
  outputs=[seed_a, seed_b],
570
  show_progress="hidden",
571
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
572
  generate_button.click(
573
  generate,
574
  inputs=[
@@ -580,6 +759,9 @@ def build_app():
580
  negative_prompt,
581
  embedding_mode,
582
  prompt_mix,
 
 
 
583
  analogy_strength,
584
  renormalize_prompt,
585
  seed_a,
 
40
 
41
  if mode == "average":
42
  prompt_embeds = (1 - mix) * prompt_a + mix * prompt_b
43
+ elif mode == "triangle":
44
+ total = weight_a + weight_b + weight_c
45
+ prompt_embeds = (
46
+ (weight_a / total) * prompt_a
47
+ + (weight_b / total) * prompt_b
48
+ + (weight_c / total) * prompt_c
49
+ )
50
  elif mode == "analogy":
51
  prompt_embeds = prompt_a + strength * (prompt_b - prompt_c)
52
  else:
 
127
  return image
128
 
129
 
130
# Side length, in pixels, of the square canvas the triangle picker is drawn on.
TRIANGLE_SIZE = 360
# (x, y) canvas coordinates of the triangle corners; corner A is labelled with
# prompt A, corner B with prompt B and corner C with prompt C.
TRIANGLE_A = (180, 32)
TRIANGLE_B = (42, 300)
TRIANGLE_C = (318, 300)
134
+
135
+
136
def normalized_triangle_weights(weight_a, weight_b, weight_c):
    """Return the three weights clamped to be non-negative and scaled to sum to 1.

    Each argument is converted with ``float()``, so numeric strings are
    accepted. Negative values and NaN contribute nothing to the blend.

    Returns:
        A 3-tuple ``(a, b, c)`` of floats. When the clamped weights carry no
        usable ratio (they sum to zero, or the sum is infinite), the result
        falls back to ``(1.0, 0.0, 0.0)``, i.e. "A only".
    """
    clamped = []
    for raw in (weight_a, weight_b, weight_c):
        value = float(raw)
        # `value > 0.0` is False for NaN as well as for negatives, so both
        # are treated as "no contribution".
        clamped.append(value if value > 0.0 else 0.0)

    total = sum(clamped)
    # An infinite total would turn the division below into inf / inf = NaN,
    # which would then poison every downstream blend; treat it like the
    # all-zero case instead.
    if not 0.0 < total < float("inf"):
        return 1.0, 0.0, 0.0
    return tuple(value / total for value in clamped)
142
+
143
+
144
def weighted_triangle_point(weight_a, weight_b, weight_c):
    """Map a triple of blend weights to integer canvas coordinates.

    The weights are normalised first and then used as barycentric
    coefficients of the three triangle corners; the resulting position is
    rounded to the nearest integer ``(x, y)`` pair.
    """
    share_a, share_b, share_c = normalized_triangle_weights(weight_a, weight_b, weight_c)
    (ax, ay), (bx, by), (cx, cy) = TRIANGLE_A, TRIANGLE_B, TRIANGLE_C
    marker_x = share_a * ax + share_b * bx + share_c * cx
    marker_y = share_a * ay + share_b * by + share_c * cy
    return int(round(marker_x)), int(round(marker_y))
149
+
150
+
151
def short_corner_label(text, fallback):
    """Return a one-line label of at most 24 characters for a triangle corner.

    Runs of whitespace (including newlines) are collapsed to single spaces.
    ``fallback`` is used when ``text`` is falsy *or* contains nothing but
    whitespace, so a blank prompt never produces an empty label. Labels
    longer than 24 characters are cut and suffixed with ``"..."``.
    """
    collapsed = " ".join(str(text or "").split())
    if not collapsed:
        # Covers None / "" as well as whitespace-only prompts.
        collapsed = " ".join(str(fallback).split())
    if len(collapsed) > 24:
        return collapsed[:24] + "..."
    return collapsed
154
+
155
+
156
def make_triangle_picker(prompt_a, prompt_b, prompt_c, weight_a=1 / 3, weight_b=1 / 3, weight_c=1 / 3):
    """Render the triangle picker image with the current blend marked.

    Draws a filled triangle with a light guide grid, a coloured badge and a
    shortened prompt label at each corner, a red marker at the position given
    by the three weights, and a one-line usage hint.

    Args:
        prompt_a, prompt_b, prompt_c: Prompt texts used for the corner labels.
        weight_a, weight_b, weight_c: Blend weights that position the marker;
            they default to an equal three-way split.

    Returns:
        A new ``TRIANGLE_SIZE`` x ``TRIANGLE_SIZE`` RGB image.
    """

    def point_between(start, end, t):
        # Linear interpolation from `start` to `end`; int() truncates, which
        # is the rounding the grid has always used.
        return (
            int((1 - t) * start[0] + t * end[0]),
            int((1 - t) * start[1] + t * end[1]),
        )

    image = Image.new("RGB", (TRIANGLE_SIZE, TRIANGLE_SIZE), (248, 250, 252))
    draw = ImageDraw.Draw(image)
    draw.polygon([TRIANGLE_A, TRIANGLE_B, TRIANGLE_C], fill=(235, 244, 255), outline=(43, 74, 111))

    # Guide grid: six evenly spaced steps along every side.
    for i in range(1, 7):
        t = i / 7
        on_ab = point_between(TRIANGLE_A, TRIANGLE_B, t)
        on_ac = point_between(TRIANGLE_A, TRIANGLE_C, t)
        on_bc = point_between(TRIANGLE_B, TRIANGLE_C, t)

        # Line joining the two sides that meet at A, at the same fraction t.
        draw.line((on_ab, on_ac), fill=(190, 207, 225), width=1)
        # Fainter lines from each corner to a point on the opposite side.
        draw.line((TRIANGLE_B, on_ac), fill=(214, 224, 236), width=1)
        draw.line((TRIANGLE_C, on_ab), fill=(214, 224, 236), width=1)
        draw.line((on_bc, TRIANGLE_A), fill=(214, 224, 236), width=1)

    labels = [
        (TRIANGLE_A, "A", short_corner_label(prompt_a, "Prompt A"), (55, 94, 151)),
        (TRIANGLE_B, "B", short_corner_label(prompt_b, "Prompt B"), (5, 122, 85)),
        (TRIANGLE_C, "C", short_corner_label(prompt_c, "Prompt C"), (154, 72, 174)),
    ]
    for (x, y), letter, label, color in labels:
        draw.ellipse((x - 13, y - 13, x + 13, y + 13), fill=color, outline=(255, 255, 255), width=3)
        draw.text((x - 4, y - 7), letter, fill=(255, 255, 255))
        # Keep the label's left edge between 8 px and TRIANGLE_SIZE - 150 px,
        # and place it above corners in the top half, below the others.
        label_x = max(8, min(TRIANGLE_SIZE - 150, x - 70))
        label_y = y - 34 if y < TRIANGLE_SIZE / 2 else y + 18
        draw.text((label_x, label_y), label, fill=(30, 41, 59))

    # Marker for the currently selected blend.
    x, y = weighted_triangle_point(weight_a, weight_b, weight_c)
    draw.ellipse((x - 9, y - 9, x + 9, y + 9), fill=(239, 68, 68), outline=(15, 23, 42), width=2)
    draw.text((12, 12), "Click inside the triangle to choose the embedding blend.", fill=(51, 65, 85))
    return image
206
+
207
+
208
def barycentric_triangle_weights(x, y):
    """Convert a canvas position into normalised (A, B, C) blend weights.

    Computes the barycentric coordinates of ``(x, y)`` with respect to the
    triangle ``TRIANGLE_A``/``TRIANGLE_B``/``TRIANGLE_C`` and passes them
    through ``normalized_triangle_weights``. A position outside the triangle
    has at least one negative coordinate; that normalisation clamps it to
    zero and rescales the rest, so the result always lies on or inside the
    triangle.

    Returns:
        A 3-tuple of floats. If the corner constants describe a degenerate
        (zero-area) triangle, returns ``(1.0, 0.0, 0.0)``.
    """
    ax, ay = TRIANGLE_A
    bx, by = TRIANGLE_B
    cx, cy = TRIANGLE_C

    denominator = (by - cy) * (ax - cx) + (cx - bx) * (ay - cy)
    if denominator == 0:
        # Collinear corners: no barycentric coordinates exist.
        return 1.0, 0.0, 0.0

    weight_a = ((by - cy) * (x - cx) + (cx - bx) * (y - cy)) / denominator
    weight_b = ((cy - ay) * (x - cx) + (ax - cx) * (y - cy)) / denominator
    weight_c = 1.0 - weight_a - weight_b

    # A single normalisation handles inside and outside clicks alike; no
    # separate pass for negative weights is needed.
    return normalized_triangle_weights(weight_a, weight_b, weight_c)
224
+
225
+
226
def triangle_status(weight_a, weight_b, weight_c):
    """Format the normalised weights as a short status line, two decimals each."""
    shares = normalized_triangle_weights(weight_a, weight_b, weight_c)
    return "A: {:.2f} B: {:.2f} C: {:.2f}".format(*shares)
229
+
230
+
231
def update_triangle_from_weights(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Rebuild the picker outputs for the given prompts and weights.

    Returns:
        A 5-tuple ``(picker image, status text, weight A, weight B, weight C)``
        in which the three weights are the normalised values.
    """
    normalized = normalized_triangle_weights(weight_a, weight_b, weight_c)
    picker_image = make_triangle_picker(prompt_a, prompt_b, prompt_c, *normalized)
    status_text = triangle_status(*normalized)
    return (picker_image, status_text, *normalized)
240
+
241
+
242
def select_triangle_point(prompt_a, prompt_b, prompt_c, evt: gr.SelectData):
    """Handle a selection on the picker image and return the refreshed outputs.

    ``evt.index`` is read either as a mapping with ``"x"``/``"y"`` keys
    (missing keys default to corner A's coordinates) or as a sequence whose
    first two items are x and y. The position is converted to blend weights
    and forwarded to ``update_triangle_from_weights``.
    """
    position = evt.index
    if isinstance(position, dict):
        click_x = position.get("x", TRIANGLE_A[0])
        click_y = position.get("y", TRIANGLE_A[1])
    else:
        click_x, click_y = position[:2]

    weights = barycentric_triangle_weights(float(click_x), float(click_y))
    return update_triangle_from_weights(prompt_a, prompt_b, prompt_c, *weights)
250
+
251
+
252
  @lru_cache(maxsize=2)
253
  def load_pipe(model_id, scheduler_name, device_type):
254
  device = torch.device(device_type)
 
314
  negative_prompt,
315
  embedding_mode,
316
  prompt_mix,
317
+ triangle_weight_a,
318
+ triangle_weight_b,
319
+ triangle_weight_c,
320
  analogy_strength,
321
  renormalize_prompt,
322
  ):
 
327
  if embedding_mode == "Average prompt A and B":
328
  mixed = (1.0 - prompt_mix) * emb_a + prompt_mix * emb_b
329
  formula = f"prompt = {(1.0 - prompt_mix):.2f} * A + {prompt_mix:.2f} * B"
330
+ elif embedding_mode == "Triangle blend: A/B/C":
331
+ triangle_weight_a, triangle_weight_b, triangle_weight_c = normalized_triangle_weights(
332
+ triangle_weight_a,
333
+ triangle_weight_b,
334
+ triangle_weight_c,
335
+ )
336
+ mixed = triangle_weight_a * emb_a + triangle_weight_b * emb_b + triangle_weight_c * emb_c
337
+ formula = (
338
+ f"prompt = {triangle_weight_a:.2f} * A + {triangle_weight_b:.2f} * B "
339
+ f"+ {triangle_weight_c:.2f} * C"
340
+ )
341
  elif embedding_mode == "Vector arithmetic: A + s * (B - C)":
342
  mixed = emb_a + analogy_strength * (emb_b - emb_c)
343
  formula = f"prompt = A + {analogy_strength:.2f} * (B - C)"
 
355
  ["cosine(A, B)", round(cosine_similarity(emb_a, emb_b), 4)],
356
  ["cosine(A, mixed)", round(cosine_similarity(emb_a, mixed), 4)],
357
  ["cosine(B, mixed)", round(cosine_similarity(emb_b, mixed), 4)],
358
+ ["cosine(C, mixed)", round(cosine_similarity(emb_c, mixed), 4)],
359
  ["norm(A)", round(float(original_norm.cpu()), 3)],
360
  ["norm(mixed)", round(float(mixed.detach().float().norm().cpu()), 3)],
361
  ]
 
460
  negative_prompt,
461
  embedding_mode,
462
  prompt_mix,
463
+ triangle_weight_a,
464
+ triangle_weight_b,
465
+ triangle_weight_c,
466
  analogy_strength,
467
  renormalize_prompt,
468
  seed_a,
 
503
  negative_prompt or "",
504
  embedding_mode,
505
  float(prompt_mix),
506
+ float(triangle_weight_a),
507
+ float(triangle_weight_b),
508
+ float(triangle_weight_c),
509
  float(analogy_strength),
510
  bool(renormalize_prompt),
511
  )
 
636
  [
637
  "Prompt A only",
638
  "Average prompt A and B",
639
+ "Triangle blend: A/B/C",
640
  "Vector arithmetic: A + s * (B - C)",
641
  ],
642
+ value="Triangle blend: A/B/C",
643
  label="Embedding equation",
644
  )
645
  prompt_mix = gr.Slider(0, 1, value=0.5, step=0.05, label="Prompt B weight")
646
+ triangle_picker = gr.Image(
647
+ value=make_triangle_picker(DEFAULT_PROMPT_A, DEFAULT_PROMPT_B, DEFAULT_PROMPT_C),
648
+ label="Triangle embedding mixer",
649
+ type="pil",
650
+ interactive=False,
651
+ )
652
+ triangle_status_box = gr.Textbox(
653
+ value=triangle_status(1 / 3, 1 / 3, 1 / 3),
654
+ label="Triangle weights",
655
+ interactive=False,
656
+ )
657
+ with gr.Row():
658
+ triangle_weight_a = gr.Slider(0, 1, value=1 / 3, step=0.01, label="A weight")
659
+ triangle_weight_b = gr.Slider(0, 1, value=1 / 3, step=0.01, label="B weight")
660
+ triangle_weight_c = gr.Slider(0, 1, value=1 / 3, step=0.01, label="C weight")
661
  analogy_strength = gr.Slider(-2, 2, value=0.8, step=0.1, label="Vector arithmetic strength")
662
  renormalize_prompt = gr.Checkbox(value=True, label="Keep mixed prompt embedding norm near prompt A")
663
 
 
735
  outputs=[seed_a, seed_b],
736
  show_progress="hidden",
737
  )
738
+ triangle_picker.select(
739
+ select_triangle_point,
740
+ inputs=[prompt_a, prompt_b, prompt_c],
741
+ outputs=[triangle_picker, triangle_status_box, triangle_weight_a, triangle_weight_b, triangle_weight_c],
742
+ show_progress="hidden",
743
+ )
744
+ for triangle_input in [prompt_a, prompt_b, prompt_c, triangle_weight_a, triangle_weight_b, triangle_weight_c]:
745
+ triangle_input.change(
746
+ update_triangle_from_weights,
747
+ inputs=[prompt_a, prompt_b, prompt_c, triangle_weight_a, triangle_weight_b, triangle_weight_c],
748
+ outputs=[triangle_picker, triangle_status_box, triangle_weight_a, triangle_weight_b, triangle_weight_c],
749
+ show_progress="hidden",
750
+ )
751
  generate_button.click(
752
  generate,
753
  inputs=[
 
759
  negative_prompt,
760
  embedding_mode,
761
  prompt_mix,
762
+ triangle_weight_a,
763
+ triangle_weight_b,
764
+ triangle_weight_c,
765
  analogy_strength,
766
  renormalize_prompt,
767
  seed_a,