AbstractPhil committed
Commit 93038cf · verified · 1 Parent(s): 0c67338

Update app.py

Files changed (1): app.py +124 −46
app.py CHANGED
@@ -259,9 +259,17 @@ class SDXLFlowMatchingPipeline:
         prompt: str,
         negative_prompt: str = "",
         clip_skip: int = 1,
-        t5_summary: str = ""
+        t5_summary: str = "",
+        lyra_strength: float = 0.3
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode prompts using Lyra VAE v2 fusion (CLIP + T5)."""
+        """Encode prompts using Lyra VAE v2 fusion (CLIP + T5).
+
+        Uses cross-modal translation: encode T5 → decode to CLIP space,
+        then blend with original CLIP embeddings.
+
+        Args:
+            lyra_strength: Blend factor (0.0 = pure CLIP, 1.0 = pure Lyra reconstruction)
+        """
         if self.lyra_model is None or self.t5_encoder is None:
             raise ValueError("Lyra VAE components not initialized")

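The `lyra_strength` parameter added here is a plain linear interpolation between the original CLIP embeddings and Lyra's reconstruction of them. A minimal sketch of that blend, assuming SDXL-shaped sequence embeddings (the `blend_embeddings` helper is illustrative, not part of app.py):

```python
import torch

def blend_embeddings(clip: torch.Tensor, lyra: torch.Tensor, strength: float) -> torch.Tensor:
    """strength=0.0 returns CLIP unchanged; strength=1.0 returns the Lyra reconstruction."""
    return (1.0 - strength) * clip + strength * lyra

clip_l = torch.randn(1, 77, 768)  # original CLIP-L sequence embeddings
lyra_l = torch.randn(1, 77, 768)  # Lyra's reconstruction of the same sequence
fused = blend_embeddings(clip_l, lyra_l, strength=0.3)  # 0.3 is the new default
assert fused.shape == clip_l.shape
```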
@@ -271,18 +279,16 @@ class SDXLFlowMatchingPipeline:
         )

         # Format T5 input with pilcrow separator (¶)
-        # Training format was: "tags ¶ summary"
         SUMMARY_SEPARATOR = "¶"
         if t5_summary.strip():
             t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
         else:
-            # Fallback: duplicate prompt if no summary provided
             t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {prompt}"

         # Get T5 embeddings
         t5_inputs = self.t5_tokenizer(
             t5_prompt,
-            max_length=512,  # T5-XL uses 512
+            max_length=512,
             padding='max_length',
             truncation=True,
             return_tensors='pt'
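The pilcrow convention is pure string formatting: the T5 branch always receives `tags ¶ summary`, duplicating the prompt when no summary is supplied. The same logic as a free-standing sketch (`format_t5_prompt` is our name for it; app.py inlines this):

```python
SUMMARY_SEPARATOR = "¶"  # pilcrow, matching the training format "tags ¶ summary"

def format_t5_prompt(prompt: str, t5_summary: str = "") -> str:
    """Build the T5 input string; fall back to duplicating the prompt."""
    summary = t5_summary.strip() or prompt
    return f"{prompt} {SUMMARY_SEPARATOR} {summary}"

print(format_t5_prompt("1girl, blue hair", "A girl with blue hair."))
# -> "1girl, blue hair ¶ A girl with blue hair."
```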
@@ -291,40 +297,88 @@ class SDXLFlowMatchingPipeline:
         with torch.no_grad():
             t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state

-        # For SDXL, split the concatenated CLIP-L + CLIP-G embeddings
         clip_l_dim = 768
         clip_g_dim = 1280

         clip_l_embeds = prompt_embeds[..., :clip_l_dim]
         clip_g_embeds = prompt_embeds[..., clip_l_dim:]

-        # Lyra v2 expects these exact keys from training config:
-        # clip_l, clip_g, t5_xl_l, t5_xl_g
-        # Upcast inputs to float32 for Lyra (model is fp32 for stability)
-        modality_inputs = {
-            'clip_l': clip_l_embeds.float(),
-            'clip_g': clip_g_embeds.float(),
-            't5_xl_l': t5_embeds.float(),
-            't5_xl_g': t5_embeds.float()  # Same T5 embedding for both bindings
-        }
+        # Debug: print input stats
+        print(f"[Lyra Debug] CLIP-L input: shape={clip_l_embeds.shape}, mean={clip_l_embeds.mean():.4f}, std={clip_l_embeds.std():.4f}")
+        print(f"[Lyra Debug] CLIP-G input: shape={clip_g_embeds.shape}, mean={clip_g_embeds.mean():.4f}, std={clip_g_embeds.std():.4f}")
+        print(f"[Lyra Debug] T5 input: shape={t5_embeds.shape}, mean={t5_embeds.mean():.4f}, std={t5_embeds.std():.4f}")

         with torch.no_grad():
-            reconstructions, mu, logvar, _ = self.lyra_model(
-                modality_inputs,
-                target_modalities=['clip_l', 'clip_g']
-            )
-            # Cast outputs back to original dtype (float16)
-            fused_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
-            fused_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
+            # Try approach 1: Cross-modal - encode T5 only, decode to CLIP
+            # This uses T5's semantic understanding to generate CLIP-compatible embeddings
+            t5_only_inputs = {
+                't5_xl_l': t5_embeds.float(),
+                't5_xl_g': t5_embeds.float()
+            }
+
+            # Check if model has separate encode/decode methods
+            if hasattr(self.lyra_model, 'encode') and hasattr(self.lyra_model, 'decode'):
+                print("[Lyra Debug] Using separate encode/decode path")
+                # Encode T5 to latent space
+                mu, logvar = self.lyra_model.encode(t5_only_inputs)
+                z = mu  # Use mean for deterministic output
+                print(f"[Lyra Debug] Latent z: shape={z.shape}, mean={z.mean():.4f}, std={z.std():.4f}")
+
+                # Decode to CLIP space
+                reconstructions = self.lyra_model.decode(z, target_modalities=['clip_l', 'clip_g'])
+            else:
+                print("[Lyra Debug] Using forward pass with all modalities")
+                # Fall back to full forward pass with all modalities
+                modality_inputs = {
+                    'clip_l': clip_l_embeds.float(),
+                    'clip_g': clip_g_embeds.float(),
+                    't5_xl_l': t5_embeds.float(),
+                    't5_xl_g': t5_embeds.float()
+                }
+                reconstructions, mu, logvar, _ = self.lyra_model(
+                    modality_inputs,
+                    target_modalities=['clip_l', 'clip_g']
+                )
+                print(f"[Lyra Debug] Latent mu: shape={mu.shape}, mean={mu.mean():.4f}, std={mu.std():.4f}")
+
+            lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
+            lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
+
+            print(f"[Lyra Debug] Lyra CLIP-L output: mean={lyra_clip_l.mean():.4f}, std={lyra_clip_l.std():.4f}")
+            print(f"[Lyra Debug] Lyra CLIP-G output: mean={lyra_clip_g.mean():.4f}, std={lyra_clip_g.std():.4f}")
+
+            # Check if reconstruction stats are wildly different from input
+            # If so, we may need to normalize
+            clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
+            clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
+            print(f"[Lyra Debug] Std ratio CLIP-L: {clip_l_std_ratio:.4f}, CLIP-G: {clip_g_std_ratio:.4f}")
+
+            # Normalize reconstructions to match input statistics if needed
+            if clip_l_std_ratio > 2.0 or clip_l_std_ratio < 0.5:
+                print("[Lyra Debug] Normalizing CLIP-L reconstruction to match input stats")
+                lyra_clip_l = (lyra_clip_l - lyra_clip_l.mean()) / (lyra_clip_l.std() + 1e-8)
+                lyra_clip_l = lyra_clip_l * clip_l_embeds.std() + clip_l_embeds.mean()
+
+            if clip_g_std_ratio > 2.0 or clip_g_std_ratio < 0.5:
+                print("[Lyra Debug] Normalizing CLIP-G reconstruction to match input stats")
+                lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
+                lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
+
+            # Blend original CLIP with Lyra reconstruction
+            fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
+            fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
+
+            print(f"[Lyra Debug] Final fused CLIP-L: mean={fused_clip_l.mean():.4f}, std={fused_clip_l.std():.4f}")
+            print(f"[Lyra Debug] lyra_strength={lyra_strength}")

-        # Recombine fused CLIP-L and CLIP-G
         prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)

-        # Process negative prompt similarly if present
+        # Process negative prompt (simpler - just use original CLIP for negative)
         if negative_prompt:
-            # For negative, just use the negative prompt without summary
-            t5_neg_prompt = f"{negative_prompt} {SUMMARY_SEPARATOR} {negative_prompt}"
+            # For negative, blend less aggressively
+            neg_strength = lyra_strength * 0.5

+            t5_neg_prompt = f"{negative_prompt} {SUMMARY_SEPARATOR} {negative_prompt}"
             t5_inputs_neg = self.t5_tokenizer(
                 t5_neg_prompt,
                 max_length=512,
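The normalization step in this hunk is a mean/std re-standardization: when the reconstruction's standard deviation drifts outside a 0.5×–2.0× band of the input's, it is rescaled to the input's statistics before blending. Pulled out as a standalone helper for clarity (a sketch under our own names; app.py inlines this logic per modality):

```python
import torch

def match_stats(recon: torch.Tensor, ref: torch.Tensor,
                lo: float = 0.5, hi: float = 2.0, eps: float = 1e-8) -> torch.Tensor:
    """Rescale `recon` to ref's mean/std when their std ratio leaves [lo, hi]."""
    ratio = float(recon.std() / (ref.std() + eps))
    if lo <= ratio <= hi:
        return recon  # statistics already close; leave untouched
    standardized = (recon - recon.mean()) / (recon.std() + eps)
    return standardized * ref.std() + ref.mean()

ref = torch.randn(1, 77, 768)
recon = torch.randn(1, 77, 768) * 5.0  # std roughly 5x the reference
print(f"{recon.std():.2f} -> {match_stats(recon, ref).std():.2f}")  # pulled back toward ref's ~1.0
```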
@@ -339,22 +393,34 @@ class SDXLFlowMatchingPipeline:
             neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
             neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]

-            modality_inputs_neg = {
-                'clip_l': neg_clip_l.float(),
-                'clip_g': neg_clip_g.float(),
-                't5_xl_l': t5_embeds_neg.float(),
-                't5_xl_g': t5_embeds_neg.float()
-            }
+            if hasattr(self.lyra_model, 'encode') and hasattr(self.lyra_model, 'decode'):
+                t5_neg_inputs = {'t5_xl_l': t5_embeds_neg.float(), 't5_xl_g': t5_embeds_neg.float()}
+                mu_neg, _ = self.lyra_model.encode(t5_neg_inputs)
+                recon_neg = self.lyra_model.decode(mu_neg, target_modalities=['clip_l', 'clip_g'])
+            else:
+                modality_inputs_neg = {
+                    'clip_l': neg_clip_l.float(),
+                    'clip_g': neg_clip_g.float(),
+                    't5_xl_l': t5_embeds_neg.float(),
+                    't5_xl_g': t5_embeds_neg.float()
+                }
+                recon_neg, _, _, _ = self.lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])

-            with torch.no_grad():
-                reconstructions_neg, _, _, _ = self.lyra_model(
-                    modality_inputs_neg,
-                    target_modalities=['clip_l', 'clip_g']
-                )
-                fused_neg_clip_l = reconstructions_neg['clip_l'].to(negative_prompt_embeds.dtype)
-                fused_neg_clip_g = reconstructions_neg['clip_g'].to(negative_prompt_embeds.dtype)
+            lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype)
+            lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype)
+
+            # Normalize if needed
+            if lyra_neg_l.std() / (neg_clip_l.std() + 1e-8) > 2.0:
+                lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8)
+                lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean()
+            if lyra_neg_g.std() / (neg_clip_g.std() + 1e-8) > 2.0:
+                lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8)
+                lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean()

-            negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, fused_neg_clip_g], dim=-1)
+            fused_neg_l = (1 - neg_strength) * neg_clip_l + neg_strength * lyra_neg_l
+            fused_neg_g = (1 - neg_strength) * neg_clip_g + neg_strength * lyra_neg_g
+
+            negative_prompt_embeds_fused = torch.cat([fused_neg_l, fused_neg_g], dim=-1)
         else:
             negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)

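Both the positive and negative branches probe the Lyra model with `hasattr` before choosing a path: a dedicated `encode`/`decode` pair enables pure cross-modal translation (T5 in, CLIP out, taking the latent mean for determinism), while checkpoints without those methods fall back to the full four-modality forward pass. Condensed into one function, assuming the Lyra interface shown in this diff (the helper itself is hypothetical):

```python
import torch

def lyra_reconstruct(model, t5_embeds, clip_l, clip_g):
    """Prefer cross-modal encode->decode; otherwise run the full forward pass."""
    t5_inputs = {'t5_xl_l': t5_embeds.float(), 't5_xl_g': t5_embeds.float()}
    with torch.no_grad():
        if hasattr(model, 'encode') and hasattr(model, 'decode'):
            mu, _logvar = model.encode(t5_inputs)  # latent mean -> deterministic output
            return model.decode(mu, target_modalities=['clip_l', 'clip_g'])
        # Fallback: feed every modality, as the else branches above do
        all_inputs = {**t5_inputs, 'clip_l': clip_l.float(), 'clip_g': clip_g.float()}
        recon, _mu, _logvar, _ = model(all_inputs, target_modalities=['clip_l', 'clip_g'])
        return recon
```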
@@ -388,6 +454,7 @@ class SDXLFlowMatchingPipeline:
         use_lyra: bool = False,
         clip_skip: int = 1,
         t5_summary: str = "",
+        lyra_strength: float = 0.3,
         progress_callback=None
     ):
         """Generate image using SDXL architecture."""
@@ -401,7 +468,7 @@ class SDXLFlowMatchingPipeline:
         # Encode prompts
         if use_lyra and self.lyra_model is not None:
             prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
-                prompt, negative_prompt, clip_skip, t5_summary
+                prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
             )
         else:
             prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
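With the default of 0.3 the new argument is backward compatible: existing callers that omit it keep working. A hypothetical direct call against the updated signature (`pipe` stands in for an initialized `SDXLFlowMatchingPipeline` with `lyra_model` and `t5_encoder` loaded):

```python
embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt_lyra(
    prompt="cyberpunk city at night, neon lights, rain",
    negative_prompt="low quality, blurry",
    clip_skip=1,
    t5_summary="A neon-lit metropolis at night in the rain",
    lyra_strength=0.3,
)
```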
@@ -1234,6 +1301,7 @@ def generate_image(
     shift: float,
     use_flow_matching: bool,
     use_lyra: bool,
+    lyra_strength: float,
     seed: int,
     randomize_seed: bool,
     progress=gr.Progress()
@@ -1313,6 +1381,7 @@ def generate_image(
             use_lyra=True,
             clip_skip=clip_skip,
             t5_summary=t5_summary,
+            lyra_strength=lyra_strength,
             progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d)
         )

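The `progress_callback` lambda maps denoising step `s` of `t` onto the back portion of the Gradio progress bar, reserving the first half for the earlier loading and encoding stages and [0.5, 0.95] for sampling. The same mapping in isolation:

```python
def scaled_progress(step: int, total: int, start: float = 0.5, span: float = 0.45) -> float:
    """Map step/total onto the [start, start + span] slice of the overall bar."""
    return start + (step / total) * span

assert scaled_progress(0, 20) == 0.5                # first denoising step
assert abs(scaled_progress(20, 20) - 0.95) < 1e-9   # last step stops at 0.95
```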
@@ -1400,6 +1469,15 @@ def create_demo():
                 info="Compare standard vs geometric fusion"
             )

+            lyra_strength = gr.Slider(
+                label="Lyra Blend Strength",
+                minimum=0.0,
+                maximum=1.0,
+                value=0.3,
+                step=0.05,
+                info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction"
+            )
+
             with gr.Accordion("Generation Settings", open=True):
                 num_steps = gr.Slider(
                     label="Steps",
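The new slider follows the usual Gradio pattern: its value is forwarded positionally, so it must also be appended to every `inputs=` list that feeds `generate_image` (the hunks below do exactly that). A self-contained sketch of the same component wiring (`demo` and `echo` are illustrative, not from app.py):

```python
import gradio as gr

def echo(strength: float) -> str:
    return f"lyra_strength={strength}"

with gr.Blocks() as demo:
    lyra_strength = gr.Slider(
        label="Lyra Blend Strength",
        minimum=0.0, maximum=1.0, value=0.3, step=0.05,
        info="0.0 = pure CLIP, 1.0 = pure Lyra reconstruction",
    )
    out = gr.Textbox()
    lyra_strength.change(echo, inputs=[lyra_strength], outputs=[out])
```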
@@ -1500,27 +1578,27 @@ def create_demo():
                 "A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky",
                 "lowres, bad anatomy, worst quality, low quality",
                 "Illustrious XL",
-                2, 25, 7.0, 1024, 1024, 0.0, False, True, 42, False
+                2, 25, 7.0, 1024, 1024, 0.0, False, True, 0.3, 42, False
             ],
             [
                 "A majestic mountain landscape at golden hour, crystal clear lake, photorealistic, 8k",
                 "A breathtaking mountain vista bathed in warm golden light at sunset, with a perfectly still crystal clear lake reflecting the peaks",
                 "blurry, low quality",
                 "SDXL Base",
-                1, 30, 7.5, 1024, 1024, 0.0, False, True, 123, False
+                1, 30, 7.5, 1024, 1024, 0.0, False, True, 0.3, 123, False
             ],
             [
                 "cyberpunk city at night, neon lights, rain, highly detailed",
                 "A futuristic cyberpunk metropolis at night with vibrant neon lights reflecting off rain-slicked streets",
                 "low quality, blurry",
                 "Flow-Lune (SD1.5)",
-                1, 20, 7.5, 512, 512, 2.5, True, True, 456, False
+                1, 20, 7.5, 512, 512, 2.5, True, True, 0.3, 456, False
             ],
         ],
         inputs=[
             prompt, t5_summary, negative_prompt, model_choice, clip_skip,
             num_steps, cfg_scale, width, height, shift,
-            use_flow_matching, use_lyra, seed, randomize_seed
+            use_flow_matching, use_lyra, lyra_strength, seed, randomize_seed
         ],
         outputs=[output_image_standard, output_image_lyra, output_seed],
         fn=generate_image,
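Note that `gr.Examples` binds row values to `inputs=` strictly by position, which is why every example row above gains a `0.3` between the `use_lyra` flag and the seed: row length must track the inputs list exactly. The tail of the first row now lines up as:

```python
# shift, use_flow_matching, use_lyra, lyra_strength, seed, randomize_seed
tail = [0.0, False, True, 0.3, 42, False]  # values from the first example row
```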
@@ -1597,7 +1675,7 @@ def create_demo():
         inputs=[
             prompt, t5_summary, negative_prompt, model_choice, clip_skip,
             num_steps, cfg_scale, width, height, shift,
-            use_flow_matching, use_lyra, seed, randomize_seed
+            use_flow_matching, use_lyra, lyra_strength, seed, randomize_seed
         ],
         outputs=[output_image_standard, output_image_lyra, output_seed]
     )
 