AbstractPhil committed on
Commit
edab745
·
verified ·
1 Parent(s): 9cda443

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -288,16 +288,14 @@ class SDXLFlowMatchingPipeline:
288
  clip_l_embeds = prompt_embeds[..., :clip_l_dim]
289
  clip_g_embeds = prompt_embeds[..., clip_l_dim:]
290
 
291
- # Get Lyra model dtype and cast inputs to match
292
- lyra_dtype = next(self.lyra_model.parameters()).dtype
293
-
294
  # Lyra v2 expects these exact keys from training config:
295
  # clip_l, clip_g, t5_xl_l, t5_xl_g
 
296
  modality_inputs = {
297
- 'clip_l': clip_l_embeds.to(lyra_dtype),
298
- 'clip_g': clip_g_embeds.to(lyra_dtype),
299
- 't5_xl_l': t5_embeds.to(lyra_dtype),
300
- 't5_xl_g': t5_embeds.to(lyra_dtype) # Same T5 embedding for both bindings
301
  }
302
 
303
  with torch.no_grad():
@@ -305,6 +303,7 @@ class SDXLFlowMatchingPipeline:
305
  modality_inputs,
306
  target_modalities=['clip_l', 'clip_g']
307
  )
 
308
  fused_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
309
  fused_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
310
 
@@ -328,10 +327,10 @@ class SDXLFlowMatchingPipeline:
328
  neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
329
 
330
  modality_inputs_neg = {
331
- 'clip_l': neg_clip_l.to(lyra_dtype),
332
- 'clip_g': neg_clip_g.to(lyra_dtype),
333
- 't5_xl_l': t5_embeds_neg.to(lyra_dtype),
334
- 't5_xl_g': t5_embeds_neg.to(lyra_dtype)
335
  }
336
 
337
  with torch.no_grad():
@@ -339,8 +338,8 @@ class SDXLFlowMatchingPipeline:
339
  modality_inputs_neg,
340
  target_modalities=['clip_l', 'clip_g']
341
  )
342
- fused_neg_clip_l = reconstructions_neg['clip_l'].to(prompt_embeds.dtype)
343
- fused_neg_clip_g = reconstructions_neg['clip_g'].to(prompt_embeds.dtype)
344
 
345
  negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, fused_neg_clip_g], dim=-1)
346
  else:
@@ -1029,8 +1028,8 @@ def load_lyra_vae_xl(
1029
  else:
1030
  lyra_model.load_state_dict(checkpoint)
1031
 
1032
- # Use float16 to match SDXL pipeline
1033
- lyra_model.to(device, dtype=torch.float16)
1034
  lyra_model.eval()
1035
 
1036
  print(f"✅ Lyra VAE v2 loaded")
 
288
  clip_l_embeds = prompt_embeds[..., :clip_l_dim]
289
  clip_g_embeds = prompt_embeds[..., clip_l_dim:]
290
 
 
 
 
291
  # Lyra v2 expects these exact keys from training config:
292
  # clip_l, clip_g, t5_xl_l, t5_xl_g
293
+ # Upcast inputs to float32 for Lyra (model is fp32 for stability)
294
  modality_inputs = {
295
+ 'clip_l': clip_l_embeds.float(),
296
+ 'clip_g': clip_g_embeds.float(),
297
+ 't5_xl_l': t5_embeds.float(),
298
+ 't5_xl_g': t5_embeds.float() # Same T5 embedding for both bindings
299
  }
300
 
301
  with torch.no_grad():
 
303
  modality_inputs,
304
  target_modalities=['clip_l', 'clip_g']
305
  )
306
+ # Cast outputs back to original dtype (float16)
307
  fused_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
308
  fused_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
309
 
 
327
  neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
328
 
329
  modality_inputs_neg = {
330
+ 'clip_l': neg_clip_l.float(),
331
+ 'clip_g': neg_clip_g.float(),
332
+ 't5_xl_l': t5_embeds_neg.float(),
333
+ 't5_xl_g': t5_embeds_neg.float()
334
  }
335
 
336
  with torch.no_grad():
 
338
  modality_inputs_neg,
339
  target_modalities=['clip_l', 'clip_g']
340
  )
341
+ fused_neg_clip_l = reconstructions_neg['clip_l'].to(negative_prompt_embeds.dtype)
342
+ fused_neg_clip_g = reconstructions_neg['clip_g'].to(negative_prompt_embeds.dtype)
343
 
344
  negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, fused_neg_clip_g], dim=-1)
345
  else:
 
1028
  else:
1029
  lyra_model.load_state_dict(checkpoint)
1030
 
1031
+ # Keep Lyra in float32 for stability - inputs will be upcast
1032
+ lyra_model.to(device)
1033
  lyra_model.eval()
1034
 
1035
  print(f"✅ Lyra VAE v2 loaded")