Spaces: Running on Zero

Update app.py
app.py CHANGED
```diff
@@ -288,16 +288,14 @@ class SDXLFlowMatchingPipeline:
         clip_l_embeds = prompt_embeds[..., :clip_l_dim]
         clip_g_embeds = prompt_embeds[..., clip_l_dim:]
 
-        # Get Lyra model dtype and cast inputs to match
-        lyra_dtype = next(self.lyra_model.parameters()).dtype
-
         # Lyra v2 expects these exact keys from training config:
         # clip_l, clip_g, t5_xl_l, t5_xl_g
+        # Upcast inputs to float32 for Lyra (model is fp32 for stability)
         modality_inputs = {
-            'clip_l': clip_l_embeds.to(lyra_dtype),
-            'clip_g': clip_g_embeds.to(lyra_dtype),
-            't5_xl_l': t5_embeds.to(lyra_dtype),
-            't5_xl_g': t5_embeds.to(lyra_dtype),
+            'clip_l': clip_l_embeds.float(),
+            'clip_g': clip_g_embeds.float(),
+            't5_xl_l': t5_embeds.float(),
+            't5_xl_g': t5_embeds.float()  # Same T5 embedding for both bindings
         }
 
         with torch.no_grad():
@@ -305,6 +303,7 @@ class SDXLFlowMatchingPipeline:
             modality_inputs,
             target_modalities=['clip_l', 'clip_g']
         )
+        # Cast outputs back to original dtype (float16)
         fused_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
         fused_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
 
@@ -328,10 +327,10 @@ class SDXLFlowMatchingPipeline:
             neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
 
             modality_inputs_neg = {
-                'clip_l': neg_clip_l.to(lyra_dtype),
-                'clip_g': neg_clip_g.to(lyra_dtype),
-                't5_xl_l': t5_embeds_neg.to(lyra_dtype),
-                't5_xl_g': t5_embeds_neg.to(lyra_dtype),
+                'clip_l': neg_clip_l.float(),
+                'clip_g': neg_clip_g.float(),
+                't5_xl_l': t5_embeds_neg.float(),
+                't5_xl_g': t5_embeds_neg.float()
            }
 
             with torch.no_grad():
@@ -339,8 +338,8 @@ class SDXLFlowMatchingPipeline:
                 modality_inputs_neg,
                 target_modalities=['clip_l', 'clip_g']
             )
-            fused_neg_clip_l = reconstructions_neg['clip_l'].to(prompt_embeds.dtype)
-            fused_neg_clip_g = reconstructions_neg['clip_g'].to(prompt_embeds.dtype)
+            fused_neg_clip_l = reconstructions_neg['clip_l'].to(negative_prompt_embeds.dtype)
+            fused_neg_clip_g = reconstructions_neg['clip_g'].to(negative_prompt_embeds.dtype)
 
             negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, fused_neg_clip_g], dim=-1)
         else:
@@ -1029,8 +1028,8 @@ def load_lyra_vae_xl(
     else:
         lyra_model.load_state_dict(checkpoint)
 
-    #
-    lyra_model.to(device
+    # Keep Lyra in float32 for stability - inputs will be upcast
+    lyra_model.to(device)
    lyra_model.eval()
 
     print(f"✅ Lyra VAE v2 loaded")
```
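What the first four hunks change: the SDXL pipeline runs in float16, but the Lyra model is now kept in float32 for numerical stability, so the embeddings are upcast at the call boundary and the reconstructions are cast back to the pipeline dtype on the way out. Below is a minimal, self-contained sketch of that boundary pattern; the `TinyFusion` module, tensor shapes, and dict keys are illustrative stand-ins, not the real Lyra API.

```python
import torch
import torch.nn as nn

class TinyFusion(nn.Module):
    """Stand-in for the fp32 Lyra model: one shared projection."""
    def __init__(self, dim: int = 768):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # parameters default to float32

    def forward(self, inputs: dict) -> dict:
        return {name: self.proj(x) for name, x in inputs.items()}

model = TinyFusion().eval()

# fp16 tensor, as it would come out of a half-precision text encoder
clip_l = torch.randn(1, 77, 768, dtype=torch.float16)

# Upcast at the boundary, run the fp32 module under no_grad,
# then cast the result back so the fp16 pipeline is unaffected.
with torch.no_grad():
    out = model({'clip_l': clip_l.float()})
fused_clip_l = out['clip_l'].to(clip_l.dtype)

assert fused_clip_l.dtype == torch.float16
```

The per-call `.float()` copies cost a little memory traffic, but they avoid casting the whole model to half precision, where small activations are prone to underflow; only the call site changes, and everything downstream still sees float16.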
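The last hunk applies the same policy at load time: `lyra_model.to(device)` moves the model with no dtype argument, so its parameters stay float32 rather than being matched to the pipeline's half precision. A sketch of a loader in that style, assuming a hypothetical `load_fusion_model` helper and a checkpoint that may or may not be wrapped in a `state_dict` key:

```python
import torch
import torch.nn as nn

def load_fusion_model(model: nn.Module, ckpt_path: str,
                      device: str = 'cuda') -> nn.Module:
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    # Accept either a bare state dict or a {'state_dict': ...} wrapper
    state_dict = checkpoint.get('state_dict', checkpoint)
    model.load_state_dict(state_dict)
    # Device move only (no dtype argument, no .half()), so the
    # parameters stay float32; callers upcast fp16 inputs instead.
    model.to(device)
    model.eval()
    return model
```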