AbstractPhil committed (verified)
Commit dda3cf4 · Parent(s): a7aafe6

Update app.py

Files changed (1): app.py (+1062 −325)

app.py CHANGED
@@ -9,6 +9,11 @@ Supports Illustrious XL, standard SDXL, and SD1.5 variants.
 Lyra VAE Versions:
 - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra
 - v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2
 """
 
 import os
@@ -17,7 +22,7 @@ import torch
 import gradio as gr
 import numpy as np
 from PIL import Image
-from typing import Optional, Dict, Tuple
 import spaces
 from safetensors.torch import load_file as load_safetensors
 
@@ -38,35 +43,29 @@ from transformers import (
 )
 from huggingface_hub import hf_hub_download
 
-# Lazy imports for Lyra
-LYRA_V1_AVAILABLE = False
-LYRA_V2_AVAILABLE = False
-LyraV1 = None
-LyraV1Config = None
-LyraV2 = None
-LyraV2Config = None
 
-def _load_lyra_imports():
-    """Lazy load Lyra VAE modules."""
-    global LYRA_V1_AVAILABLE, LYRA_V2_AVAILABLE
-    global LyraV1, LyraV1Config, LyraV2, LyraV2Config
-
-    try:
-        from geofractal.model.vae.vae_lyra import MultiModalVAE as _LyraV1, MultiModalVAEConfig as _LyraV1Config
-        LyraV1 = _LyraV1
-        LyraV1Config = _LyraV1Config
-        LYRA_V1_AVAILABLE = True
-    except ImportError:
-        print("⚠️ Lyra VAE v1 not available")
-
-    try:
-        from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as _LyraV2, MultiModalVAEConfig as _LyraV2Config
-        LyraV2 = _LyraV2
-        LyraV2Config = _LyraV2Config
-        LYRA_V2_AVAILABLE = True
-    except ImportError:
-        print("⚠️ Lyra VAE v2 not available")
 
 # ============================================================================
@@ -76,66 +75,32 @@ def _load_lyra_imports():
 ARCH_SD15 = "sd15"
 ARCH_SDXL = "sdxl"
 
-# Scheduler options
 SCHEDULER_EULER_A = "Euler Ancestral"
 SCHEDULER_EULER = "Euler"
 SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE"
 SCHEDULER_DPM_2M = "DPM++ 2M"
 
-SDXL_SCHEDULERS = [SCHEDULER_EULER_A, SCHEDULER_EULER, SCHEDULER_DPM_2M_SDE, SCHEDULER_DPM_2M]
-
-
-# ============================================================================
-# SCHEDULER FACTORY
-# ============================================================================
-
-def get_scheduler(scheduler_name: str, config_path: str = "stabilityai/stable-diffusion-xl-base-1.0"):
-    """Create scheduler by name."""
-
-    if scheduler_name == SCHEDULER_EULER_A:
-        return EulerAncestralDiscreteScheduler.from_pretrained(
-            config_path, subfolder="scheduler"
-        )
-    elif scheduler_name == SCHEDULER_EULER:
-        return EulerDiscreteScheduler.from_pretrained(
-            config_path, subfolder="scheduler"
-        )
-    elif scheduler_name == SCHEDULER_DPM_2M_SDE:
-        return DPMSolverSDEScheduler.from_pretrained(
-            config_path, subfolder="scheduler",
-            algorithm_type="sde-dpmsolver++",
-            solver_order=2,
-        )
-    elif scheduler_name == SCHEDULER_DPM_2M:
-        return DPMSolverMultistepScheduler.from_pretrained(
-            config_path, subfolder="scheduler",
-            algorithm_type="dpmsolver++",
-            solver_order=2,
-        )
-    else:
-        # Default to Euler Ancestral
-        return EulerAncestralDiscreteScheduler.from_pretrained(
-            config_path, subfolder="scheduler"
-        )
 
-# ============================================================================
-# MODEL LOADING UTILITIES
-# ============================================================================
 
-def get_clip_hidden_state(
-    model_output,
-    clip_skip: int = 1,
-    output_hidden_states: bool = True
-) -> torch.Tensor:
-    """Extract hidden state with clip_skip support."""
-    if clip_skip == 1 or not output_hidden_states:
-        return model_output.last_hidden_state
-
-    if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
-        return model_output.hidden_states[-clip_skip]
-
-    return model_output.last_hidden_state
 
 
 # ============================================================================
@@ -143,168 +108,281 @@ def get_clip_hidden_state(
 # ============================================================================
 
 class LazyT5Encoder:
-    """Lazy loader for T5 encoder - only loads when first accessed."""
 
-    def __init__(self, model_name: str = "google/flan-t5-xl", device: str = "cuda"):
         self.model_name = model_name
         self.device = device
         self._encoder = None
         self._tokenizer = None
 
     @property
-    def encoder(self):
         if self._encoder is None:
-            print(f"📥 Loading T5 encoder: {self.model_name}...")
             self._encoder = T5EncoderModel.from_pretrained(
                 self.model_name,
-                torch_dtype=torch.float16
             ).to(self.device)
             self._encoder.eval()
-            print("✓ T5 encoder loaded")
         return self._encoder
 
     @property
-    def tokenizer(self):
         if self._tokenizer is None:
             print(f"📥 Loading T5 tokenizer: {self.model_name}...")
             self._tokenizer = T5Tokenizer.from_pretrained(self.model_name)
             print("✓ T5 tokenizer loaded")
         return self._tokenizer
 
-    def is_loaded(self):
-        return self._encoder is not None
 
 
 class LazyLyraModel:
-    """Lazy loader for Lyra VAE - only loads when first accessed."""
 
-    def __init__(self, repo_id: str, device: str = "cuda", version: int = 2):
         self.repo_id = repo_id
         self.device = device
-        self.version = version
         self._model = None
 
     @property
     def model(self):
         if self._model is None:
-            _load_lyra_imports()
 
-            if self.version == 2:
-                self._model = self._load_v2()
             else:
-                self._model = self._load_v1()
         return self._model
 
-    def _load_v2(self):
         if not LYRA_V2_AVAILABLE:
-            print("⚠️ Lyra VAE v2 not available")
-            return None
 
-        print(f"🎵 Loading Lyra VAE v2 from {self.repo_id}...")
 
-        try:
-            from huggingface_hub import list_repo_files
-
-            config_path = hf_hub_download(
-                repo_id=self.repo_id,
-                filename="config.json",
-                repo_type="model"
-            )
-
-            with open(config_path, 'r') as f:
-                config_dict = json.load(f)
-
-            print(f"  ✓ Config: {config_dict.get('fusion_strategy', 'unknown')} fusion")
-
-            # Auto-detect checkpoint
-            repo_files = list_repo_files(self.repo_id, repo_type="model")
-            checkpoint_files = [f for f in repo_files if f.endswith('.pt')]
-            checkpoint_files = [f for f in checkpoint_files if 'checkpoint' in f.lower()]
-
-            if not checkpoint_files:
-                raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
-
-            import re
-            def extract_step(name):
-                match = re.search(r'(\d+)\.pt', name)
-                return int(match.group(1)) if match else 0
-
-            checkpoint_files.sort(key=extract_step, reverse=True)
-            checkpoint_filename = checkpoint_files[0]
-            print(f"  ✓ Using: {checkpoint_filename}")
-
-            checkpoint_path = hf_hub_download(
-                repo_id=self.repo_id,
-                filename=checkpoint_filename,
-                repo_type="model"
-            )
-
             checkpoint = torch.load(checkpoint_path, map_location="cpu")
-
-            vae_config = LyraV2Config(
-                modality_dims=config_dict.get('modality_dims', {
-                    "clip_l": 768, "clip_g": 1280,
-                    "t5_xl_l": 2048, "t5_xl_g": 2048
-                }),
-                modality_seq_lens=config_dict.get('modality_seq_lens', {
-                    "clip_l": 77, "clip_g": 77,
-                    "t5_xl_l": 512, "t5_xl_g": 512
-                }),
-                binding_config=config_dict.get('binding_config', {
-                    "clip_l": {"t5_xl_l": 0.3},
-                    "clip_g": {"t5_xl_g": 0.3},
-                    "t5_xl_l": {},
-                    "t5_xl_g": {}
-                }),
-                latent_dim=config_dict.get('latent_dim', 2048),
-                seq_len=config_dict.get('seq_len', 77),
-                encoder_layers=config_dict.get('encoder_layers', 3),
-                decoder_layers=config_dict.get('decoder_layers', 3),
-                hidden_dim=config_dict.get('hidden_dim', 2048),
-                dropout=config_dict.get('dropout', 0.1),
-                fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
-                fusion_heads=config_dict.get('fusion_heads', 8),
-                fusion_dropout=config_dict.get('fusion_dropout', 0.1),
-                cantor_depth=config_dict.get('cantor_depth', 8),
-                cantor_local_window=config_dict.get('cantor_local_window', 3),
-                alpha_init=config_dict.get('alpha_init', 1.0),
-                beta_init=config_dict.get('beta_init', 0.3),
-            )
-
-            lyra_model = LyraV2(vae_config)
-
             state_dict = checkpoint.get('model_state_dict', checkpoint)
-            missing, unexpected = lyra_model.load_state_dict(state_dict, strict=False)
-
-            if missing:
-                print(f"  ⚠️ Missing keys: {len(missing)}")
-            if unexpected:
-                print(f"  ⚠️ Unexpected keys: {len(unexpected)}")
-
-            lyra_model.to(self.device)
-            lyra_model.eval()
-
-            total_params = sum(p.numel() for p in lyra_model.parameters())
-            print(f"✅ Lyra VAE v2 loaded ({total_params/1e6:.1f}M params)")
-
-            return lyra_model
-
-        except Exception as e:
-            print(f"❌ Failed to load Lyra VAE v2: {e}")
-            import traceback
-            traceback.print_exc()
-            return None
-
-    def _load_v1(self):
-        if not LYRA_V1_AVAILABLE:
-            print("⚠️ Lyra VAE v1 not available")
-            return None
 
-        # Similar implementation for v1...
-        return None
 
-    def is_loaded(self):
-        return self._model is not None
 
 
 # ============================================================================
@@ -312,7 +390,10 @@ class LazyLyraModel:
 # ============================================================================
 
 class SDXLFlowMatchingPipeline:
-    """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders."""
 
     def __init__(
         self,
@@ -337,7 +418,7 @@ class SDXLFlowMatchingPipeline:
         self.scheduler = scheduler
         self.device = device
 
-        # Lazy loaders
         self.t5_loader = t5_loader
         self.lyra_loader = lyra_loader
 
@@ -345,23 +426,41 @@ class SDXLFlowMatchingPipeline:
         self.clip_skip = clip_skip
         self.vae_scale_factor = 0.13025
         self.arch = ARCH_SDXL
 
     def set_scheduler(self, scheduler_name: str):
-        """Switch scheduler."""
-        self.scheduler = get_scheduler(scheduler_name)
 
     @property
-    def t5_encoder(self):
         return self.t5_loader.encoder if self.t5_loader else None
 
     @property
-    def t5_tokenizer(self):
         return self.t5_loader.tokenizer if self.t5_loader else None
 
     @property
     def lyra_model(self):
         return self.lyra_loader.model if self.lyra_loader else None
 
     def encode_prompt(
         self,
         prompt: str,
@@ -406,6 +505,7 @@ class SDXLFlowMatchingPipeline:
         prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
         pooled_prompt_embeds = clip_g_output.text_embeds
 
         prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
 
         # Negative prompt
@@ -457,17 +557,24 @@ class SDXLFlowMatchingPipeline:
         t5_summary: str = "",
         lyra_strength: float = 0.3
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode prompts using Lyra VAE v2 fusion (CLIP + T5)."""
 
-        if self.lyra_model is None or self.t5_encoder is None:
-            raise ValueError("Lyra VAE components not initialized")
 
         # Get standard CLIP embeddings first
         prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
             prompt, negative_prompt, clip_skip
         )
 
-        # Format T5 input
         SUMMARY_SEPARATOR = "¶"
         if t5_summary.strip():
             t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
@@ -475,7 +582,7 @@ class SDXLFlowMatchingPipeline:
             t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {prompt}"
 
         # Get T5 embeddings
-        t5_inputs = self.t5_tokenizer(
             t5_prompt,
             max_length=512,
             padding='max_length',
@@ -484,9 +591,11 @@ class SDXLFlowMatchingPipeline:
         ).to(self.device)
 
         with torch.no_grad():
-            t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
 
         clip_l_dim = 768
         clip_l_embeds = prompt_embeds[..., :clip_l_dim]
         clip_g_embeds = prompt_embeds[..., clip_l_dim:]
 
@@ -497,7 +606,7 @@ class SDXLFlowMatchingPipeline:
             't5_xl_l': t5_embeds.float(),
             't5_xl_g': t5_embeds.float()
         }
-        reconstructions, mu, logvar, _ = self.lyra_model(
             modality_inputs,
             target_modalities=['clip_l', 'clip_g']
         )
@@ -505,7 +614,7 @@ class SDXLFlowMatchingPipeline:
         lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
         lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
 
-        # Normalize if stats are off
        clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
        clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
 
@@ -517,14 +626,60 @@ class SDXLFlowMatchingPipeline:
             lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
             lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
 
-        # Blend
         fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
         fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
 
         prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
 
-        # Negative prompt - just use original CLIP
-        return prompt_embeds_fused, negative_prompt_embeds, pooled, negative_pooled
 
     def _get_add_time_ids(
         self,
@@ -545,11 +700,14 @@ class SDXLFlowMatchingPipeline:
         negative_prompt: str = "",
         height: int = 1024,
         width: int = 1024,
-        num_inference_steps: int = 25,
-        guidance_scale: float = 7.0,
         seed: Optional[int] = None,
         use_lyra: bool = False,
-        clip_skip: int = 2,
         t5_summary: str = "",
         lyra_strength: float = 1.0,
         progress_callback=None
@@ -561,8 +719,8 @@ class SDXLFlowMatchingPipeline:
         else:
             generator = None
 
-        # Encode prompts
-        if use_lyra and self.lyra_loader is not None:
             prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
                 prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
             )
@@ -587,9 +745,10 @@ class SDXLFlowMatchingPipeline:
         self.scheduler.set_timesteps(num_inference_steps, device=self.device)
         timesteps = self.scheduler.timesteps
 
-        latents = latents * self.scheduler.init_noise_sigma
 
-        # Time embeddings for SDXL
         original_size = (height, width)
         target_size = (height, width)
         crops_coords_top_left = (0, 0)
@@ -605,7 +764,14 @@ class SDXLFlowMatchingPipeline:
             progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
 
             latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             timestep = t.expand(latent_model_input.shape[0])
 
@@ -635,7 +801,22 @@ class SDXLFlowMatchingPipeline:
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
         # Decode
         latents = latents / self.vae_scale_factor
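# Illustration (not part of the commit): the classifier-free guidance line above
# extrapolates from the unconditional toward the text-conditioned prediction,
# pred = uncond + g * (text - uncond). A toy numeric check:
import torch
uncond = torch.tensor([0.2])
text = torch.tensor([1.0])
print(uncond + 7.0 * (text - uncond))   # tensor([5.8000]) for guidance_scale = 7.0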
@@ -651,12 +832,310 @@ class SDXLFlowMatchingPipeline:
         return image
 
 
 # ============================================================================
 # MODEL LOADERS
 # ============================================================================
 
 def load_illustrious_xl(
-    repo_id: str = "AbstractPhil/illustrious-xl-v1",
     filename: str = "illustriousXL_v01.safetensors",
     device: str = "cuda"
 ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
@@ -668,7 +1147,7 @@ def load_illustrious_xl(
     checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
     print(f"✓ Downloaded: {checkpoint_path}")
 
-    print("📦 Loading pipeline...")
     pipe = StableDiffusionXLPipeline.from_single_file(
         checkpoint_path,
         torch_dtype=torch.float16,
@@ -686,6 +1165,51 @@ def load_illustrious_xl(
     torch.cuda.empty_cache()
 
     print("✅ Illustrious XL loaded!")
 
     return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
 
@@ -694,60 +1218,111 @@ def load_illustrious_xl(
 # PIPELINE INITIALIZATION
 # ============================================================================
 
-def initialize_sdxl_pipeline(
-    model_choice: str,
-    scheduler_name: str = SCHEDULER_EULER_A,
-    device: str = "cuda"
-):
-    """Initialize SDXL pipeline with lazy T5/Lyra loading."""
 
     print(f"🚀 Initializing {model_choice} pipeline...")
 
-    # Load base model
-    if "Illustrious" in model_choice:
-        unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
     else:
-        # SDXL Base
-        from diffusers import StableDiffusionXLPipeline
-        pipe = StableDiffusionXLPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0",
-            torch_dtype=torch.float16,
         )
-        unet = pipe.unet.to(device)
-        vae = pipe.vae.to(device)
-        text_encoder = pipe.text_encoder.to(device)
-        text_encoder_2 = pipe.text_encoder_2.to(device)
-        tokenizer = pipe.tokenizer
-        tokenizer_2 = pipe.tokenizer_2
-        del pipe
-        torch.cuda.empty_cache()
-
-    # Create lazy loaders (don't download yet)
-    t5_loader = LazyT5Encoder(model_name="google/flan-t5-xl", device=device)
-    lyra_loader = LazyLyraModel(
-        repo_id="AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
-        device=device,
-        version=2
-    )
-
-    # Get scheduler
-    scheduler = get_scheduler(scheduler_name)
-
-    pipeline = SDXLFlowMatchingPipeline(
-        vae=vae,
-        text_encoder=text_encoder,
-        text_encoder_2=text_encoder_2,
-        tokenizer=tokenizer,
-        tokenizer_2=tokenizer_2,
-        unet=unet,
-        scheduler=scheduler,
-        device=device,
-        t5_loader=t5_loader,
-        lyra_loader=lyra_loader,
-        clip_skip=2
-    )
 
-    print("✅ Pipeline initialized (T5/Lyra will load on first use)")
     return pipeline
 
@@ -757,20 +1332,15 @@ def initialize_sdxl_pipeline(
 CURRENT_PIPELINE = None
 CURRENT_MODEL = None
-CURRENT_SCHEDULER = None
 
 
-def get_pipeline(model_choice: str, scheduler_name: str = SCHEDULER_EULER_A):
     """Get or create pipeline for selected model."""
-    global CURRENT_PIPELINE, CURRENT_MODEL, CURRENT_SCHEDULER
 
     if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
-        CURRENT_PIPELINE = initialize_sdxl_pipeline(model_choice, scheduler_name, device="cuda")
         CURRENT_MODEL = model_choice
-        CURRENT_SCHEDULER = scheduler_name
-    elif CURRENT_SCHEDULER != scheduler_name:
-        CURRENT_PIPELINE.set_scheduler(scheduler_name)
-        CURRENT_SCHEDULER = scheduler_name
 
     return CURRENT_PIPELINE
 
@@ -779,18 +1349,36 @@ def get_pipeline(model_choice: str, scheduler_name: str = SCHEDULER_EULER_A):
 # INFERENCE
 # ============================================================================
 
-@spaces.GPU(duration=120)
 def generate_image(
     prompt: str,
     t5_summary: str,
     negative_prompt: str,
     model_choice: str,
-    scheduler_name: str,
     clip_skip: int,
     num_steps: int,
     cfg_scale: float,
     width: int,
     height: int,
     use_lyra: bool,
     lyra_strength: float,
     seed: int,
@@ -806,9 +1394,18 @@ def generate_image(
         progress((step + 1) / total, desc=desc)
 
     try:
-        pipeline = get_pipeline(model_choice, scheduler_name)
 
-        if not use_lyra or pipeline.lyra_loader is None:
             progress(0.05, desc="Generating...")
 
             image = pipeline(
@@ -818,6 +1415,9 @@ def generate_image(
                 width=width,
                 num_inference_steps=num_steps,
                 guidance_scale=cfg_scale,
                 seed=seed,
                 use_lyra=False,
                 clip_skip=clip_skip,
@@ -828,6 +1428,7 @@ def generate_image(
             return image, None, seed
 
         else:
             progress(0.05, desc="Generating standard...")
 
             image_standard = pipeline(
@@ -837,13 +1438,16 @@ def generate_image(
                 width=width,
                 num_inference_steps=num_steps,
                 guidance_scale=cfg_scale,
                 seed=seed,
                 use_lyra=False,
                 clip_skip=clip_skip,
                 progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
             )
 
-            progress(0.5, desc="Loading Lyra + T5 (first run only)...")
 
             image_lyra = pipeline(
                 prompt=prompt,
@@ -852,6 +1456,9 @@ def generate_image(
                 width=width,
                 num_inference_steps=num_steps,
                 guidance_scale=cfg_scale,
                 seed=seed,
                 use_lyra=True,
                 clip_skip=clip_skip,
@@ -879,93 +1486,217 @@ def create_demo():
 
     with gr.Blocks() as demo:
         gr.Markdown("""
-        # 🌙 Lyra/Illustrious XL Image Generation
 
         **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
 
         | Model | Architecture | Lyra Version | Best For |
         |-------|-------------|--------------|----------|
         | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
         | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |
 
-        **Lyra VAE** fuses CLIP + T5-XL embeddings using adaptive Cantor attention.
-        T5 and Lyra only load when you enable the Lyra checkbox!
         """)
 
         with gr.Row():
             with gr.Column(scale=1):
                 prompt = gr.TextArea(
-                    label="Prompt",
                     value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
                     lines=3
                 )
 
                 t5_summary = gr.TextArea(
-                    label="T5 Summary (for Lyra)",
-                    value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms",
                     lines=2,
-                    info="Natural language description for T5. Leave empty to use prompt."
                 )
 
                 negative_prompt = gr.TextArea(
                     label="Negative Prompt",
-                    value="lowres, bad anatomy, bad hands, text, error, worst quality, low quality",
                     lines=2
                 )
 
-                with gr.Row():
-                    model_choice = gr.Dropdown(
-                        label="Model",
-                        choices=["Illustrious XL", "SDXL Base"],
-                        value="Illustrious XL"
-                    )
-
-                    scheduler_name = gr.Dropdown(
-                        label="Scheduler",
-                        choices=SDXL_SCHEDULERS,
-                        value=SCHEDULER_EULER_A
-                    )
 
                 clip_skip = gr.Slider(
                     label="CLIP Skip",
-                    minimum=1, maximum=4, value=2, step=1,
-                    info="2 recommended for Illustrious"
                 )
 
                 use_lyra = gr.Checkbox(
-                    label="Enable Lyra VAE (loads T5-XL on first use)",
                     value=False,
-                    info="Compare standard vs geometric fusion"
                 )
 
                 lyra_strength = gr.Slider(
                     label="Lyra Blend Strength",
-                    minimum=0.0, maximum=2.0, value=1.0, step=0.05,
-                    info="0.0 = pure CLIP, 1.0 = pure Lyra"
                 )
 
                 with gr.Accordion("Generation Settings", open=True):
-                    num_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=25, step=1)
-                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1.0, maximum=15.0, value=7.0, step=0.5)
 
                     with gr.Row():
-                        width = gr.Slider(label="Width", minimum=512, maximum=1536, value=1024, step=64)
-                        height = gr.Slider(label="Height", minimum=512, maximum=1536, value=1024, step=64)
 
-                    seed = gr.Slider(label="Seed", minimum=0, maximum=2**32 - 1, value=42, step=1)
-                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
                 generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
 
             with gr.Column(scale=1):
                 with gr.Row():
-                    output_image_standard = gr.Image(label="Standard", type="pil")
-                    output_image_lyra = gr.Image(label="Lyra Fusion 🎵", type="pil", visible=True)
 
                 output_seed = gr.Number(label="Seed", precision=0)
 
             # Event handlers
             def on_lyra_toggle(enabled):
                 if enabled:
                     return {
                         output_image_standard: gr.update(visible=True, label="Standard"),
@@ -977,6 +1708,12 @@ def create_demo():
                         output_image_lyra: gr.update(visible=False)
                     }
 
             use_lyra.change(
                 fn=on_lyra_toggle,
                 inputs=[use_lyra],
@@ -986,9 +1723,9 @@ def create_demo():
             generate_btn.click(
                 fn=generate_image,
                 inputs=[
-                    prompt, t5_summary, negative_prompt, model_choice, scheduler_name,
-                    clip_skip, num_steps, cfg_scale, width, height,
-                    use_lyra, lyra_strength, seed, randomize_seed
                 ],
                 outputs=[output_image_standard, output_image_lyra, output_seed]
             )
 Lyra VAE Versions:
 - v1: SD1.5 (768 dim CLIP + T5-base) - geofractal.model.vae.vae_lyra
 - v2: SDXL/Illustrious (768 CLIP-L + 1280 CLIP-G + 2048 T5-XL) - geofractal.model.vae.vae_lyra_v2
+
+Features:
+- Lazy loading: T5 and Lyra only download when first used
+- Multiple schedulers: Euler Ancestral, Euler, DPM++ 2M SDE, DPM++ 2M
+- Integrated loader module for automatic version detection
 """
 
 import os
 
 import gradio as gr
 import numpy as np
 from PIL import Image
+from typing import Optional, Dict, Tuple, Union
 import spaces
 from safetensors.torch import load_file as load_safetensors
 
 )
 from huggingface_hub import hf_hub_download
 
+# Import Lyra VAE v1 (SD1.5) from geofractal
+try:
+    from geofractal.model.vae.vae_lyra import MultiModalVAE as LyraV1, MultiModalVAEConfig as LyraV1Config
+    LYRA_V1_AVAILABLE = True
+except ImportError:
+    print("⚠️ Lyra VAE v1 not available")
+    LYRA_V1_AVAILABLE = False
+
+# Import Lyra VAE v2 (SDXL/Illustrious) from geofractal
+try:
+    from geofractal.model.vae.vae_lyra_v2 import MultiModalVAE as LyraV2, MultiModalVAEConfig as LyraV2Config
+    LYRA_V2_AVAILABLE = True
+except ImportError:
+    print("⚠️ Lyra VAE v2 not available")
+    LYRA_V2_AVAILABLE = False
+
+# Import Lyra loader module
+try:
+    from geofractal.model.vae.load_lyra import load_vae_lyra, load_lyra_illustrious
+    LYRA_LOADER_AVAILABLE = True
+except ImportError:
+    print("⚠️ Lyra loader module not available, using fallback")
+    LYRA_LOADER_AVAILABLE = False
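# Illustration (not part of the commit): the module-level flags above let the
# rest of the app degrade gracefully when geofractal is missing. A hypothetical
# status helper built only on those flags:
def lyra_status() -> str:
    if LYRA_V2_AVAILABLE and LYRA_LOADER_AVAILABLE:
        return "Lyra v2 + loader module"
    if LYRA_V2_AVAILABLE:
        return "Lyra v2 (manual fallback loading)"
    return "CLIP-only (Lyra disabled)"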
 # ============================================================================
 
 ARCH_SD15 = "sd15"
 ARCH_SDXL = "sdxl"
 
+# Scheduler names
 SCHEDULER_EULER_A = "Euler Ancestral"
 SCHEDULER_EULER = "Euler"
 SCHEDULER_DPM_2M_SDE = "DPM++ 2M SDE"
 SCHEDULER_DPM_2M = "DPM++ 2M"
 
+SCHEDULER_CHOICES = [
+    SCHEDULER_EULER_A,
+    SCHEDULER_EULER,
+    SCHEDULER_DPM_2M_SDE,
+    SCHEDULER_DPM_2M,
+]
 
+# ComfyUI key prefixes for SDXL single-file checkpoints
+COMFYUI_UNET_PREFIX = "model.diffusion_model."
+COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer."
+COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model."
+COMFYUI_VAE_PREFIX = "first_stage_model."
 
+# Lyra repos
+LYRA_ILLUSTRIOUS_REPO = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious"
+LYRA_SD15_REPO = "AbstractPhil/vae-lyra"
 
+# T5 model - use flan-t5-xl (what Lyra was trained on)
+T5_XL_MODEL = "google/flan-t5-xl"
+T5_BASE_MODEL = "google/flan-t5-base"
 
 
 # ============================================================================
 # ============================================================================
 
 class LazyT5Encoder:
+    """Lazy loader for T5 encoder - only downloads/loads when first accessed."""
 
+    def __init__(self, model_name: str = T5_XL_MODEL, device: str = "cuda", dtype=torch.float16):
         self.model_name = model_name
         self.device = device
+        self.dtype = dtype
         self._encoder = None
         self._tokenizer = None
+        self._loaded = False
 
     @property
+    def encoder(self) -> T5EncoderModel:
         if self._encoder is None:
+            print(f"📥 Lazy loading T5 encoder: {self.model_name}...")
             self._encoder = T5EncoderModel.from_pretrained(
                 self.model_name,
+                torch_dtype=self.dtype
             ).to(self.device)
             self._encoder.eval()
+            print(f"✓ T5 encoder loaded ({sum(p.numel() for p in self._encoder.parameters())/1e6:.1f}M params)")
+            self._loaded = True
         return self._encoder
 
     @property
+    def tokenizer(self) -> T5Tokenizer:
         if self._tokenizer is None:
             print(f"📥 Loading T5 tokenizer: {self.model_name}...")
             self._tokenizer = T5Tokenizer.from_pretrained(self.model_name)
             print("✓ T5 tokenizer loaded")
         return self._tokenizer
 
+    @property
+    def is_loaded(self) -> bool:
+        return self._loaded
+
+    def unload(self):
+        """Free VRAM by unloading the encoder."""
+        if self._encoder is not None:
+            del self._encoder
+            self._encoder = None
+            self._loaded = False
+            torch.cuda.empty_cache()
+            print("🗑️ T5 encoder unloaded")
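# Illustration (not part of the commit): nothing downloads at construction time;
# the first attribute access does, and unload() frees VRAM afterwards.
t5 = LazyT5Encoder(model_name=T5_XL_MODEL, device="cuda")
assert not t5.is_loaded                                    # no weights fetched yet
ids = t5.tokenizer("a test prompt", return_tensors="pt")   # tokenizer only
emb = t5.encoder(**ids.to("cuda")).last_hidden_state       # triggers the full load
assert t5.is_loaded
t5.unload()                                                # release VRAM between runs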
 
 class LazyLyraModel:
+    """Lazy loader for Lyra VAE - only downloads/loads when first accessed."""
 
+    def __init__(
+        self,
+        repo_id: str = LYRA_ILLUSTRIOUS_REPO,
+        device: str = "cuda",
+        checkpoint: Optional[str] = None
+    ):
         self.repo_id = repo_id
         self.device = device
+        self.checkpoint = checkpoint
         self._model = None
+        self._info = None
+        self._loaded = False
 
     @property
     def model(self):
         if self._model is None:
+            print(f"📥 Lazy loading Lyra VAE: {self.repo_id}...")
 
+            if LYRA_LOADER_AVAILABLE:
+                # Use the loader module
+                self._model, self._info = load_vae_lyra(
+                    self.repo_id,
+                    checkpoint=self.checkpoint,
+                    device=self.device,
+                    return_info=True
+                )
             else:
+                # Fallback to manual loading
+                self._model = self._load_fallback()
+                self._info = {"repo_id": self.repo_id, "version": "v2"}
+
+            self._model.eval()
+            self._loaded = True
+            print(f"✓ Lyra VAE loaded")
         return self._model
 
+    @property
+    def info(self) -> Optional[Dict]:
+        if self._info is None and self._model is not None:
+            return {"repo_id": self.repo_id}
+        return self._info
+
+    @property
+    def is_loaded(self) -> bool:
+        return self._loaded
+
+    def _load_fallback(self):
+        """Fallback loading if loader module not available."""
         if not LYRA_V2_AVAILABLE:
+            raise ImportError("Lyra VAE v2 not available")
 
+        config_path = hf_hub_download(
+            repo_id=self.repo_id,
+            filename="config.json",
+            repo_type="model"
+        )
 
+        with open(config_path, 'r') as f:
+            config_dict = json.load(f)
+
+        # Find checkpoint
+        from huggingface_hub import list_repo_files
+        import re
+
+        repo_files = list_repo_files(self.repo_id, repo_type="model")
+        checkpoint_files = [f for f in repo_files if f.endswith('.safetensors') or f.endswith('.pt')]
+
+        # Prefer weights/ folder
+        weights_files = [f for f in checkpoint_files if f.startswith('weights/')]
+        if weights_files:
+            checkpoint_file = sorted(weights_files)[-1]  # Latest
+        elif checkpoint_files:
+            checkpoint_file = checkpoint_files[0]
+        else:
+            raise FileNotFoundError(f"No checkpoint found in {self.repo_id}")
+
+        checkpoint_path = hf_hub_download(
+            repo_id=self.repo_id,
+            filename=checkpoint_file,
+            repo_type="model"
+        )
+
+        # Load weights
+        if checkpoint_file.endswith('.safetensors'):
+            state_dict = load_safetensors(checkpoint_path, device="cpu")
+        else:
             checkpoint = torch.load(checkpoint_path, map_location="cpu")
             state_dict = checkpoint.get('model_state_dict', checkpoint)
 
+        # Build config
+        vae_config = LyraV2Config(
+            modality_dims=config_dict.get('modality_dims'),
+            modality_seq_lens=config_dict.get('modality_seq_lens'),
+            binding_config=config_dict.get('binding_config'),
+            latent_dim=config_dict.get('latent_dim', 2048),
+            hidden_dim=config_dict.get('hidden_dim', 2048),
+            fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
+        )
+
+        model = LyraV2(vae_config)
+        model.load_state_dict(state_dict, strict=False)
+        model.to(self.device)
+
+        return model
+
+    def unload(self):
+        """Free VRAM by unloading the model."""
+        if self._model is not None:
+            del self._model
+            self._model = None
+            self._info = None
+            self._loaded = False
+            torch.cuda.empty_cache()
+            print("🗑️ Lyra VAE unloaded")
 
+# ============================================================================
+# SCHEDULER FACTORY
+# ============================================================================
+
+def get_scheduler(
+    scheduler_name: str,
+    config_source: str = "stabilityai/stable-diffusion-xl-base-1.0",
+    is_sdxl: bool = True
+):
+    """Create scheduler by name.
+
+    Args:
+        scheduler_name: One of SCHEDULER_CHOICES
+        config_source: HF repo to load scheduler config from
+        is_sdxl: Whether this is for SDXL (affects some defaults)
+
+    Returns:
+        Configured scheduler instance
+    """
+    subfolder = "scheduler"
+
+    if scheduler_name == SCHEDULER_EULER_A:
+        return EulerAncestralDiscreteScheduler.from_pretrained(
+            config_source,
+            subfolder=subfolder
+        )
+
+    elif scheduler_name == SCHEDULER_EULER:
+        return EulerDiscreteScheduler.from_pretrained(
+            config_source,
+            subfolder=subfolder
+        )
+
+    elif scheduler_name == SCHEDULER_DPM_2M_SDE:
+        # DPM++ 2M SDE - good for detailed images
+        return DPMSolverSDEScheduler.from_pretrained(
+            config_source,
+            subfolder=subfolder,
+            algorithm_type="sde-dpmsolver++",
+            solver_order=2,
+            use_karras_sigmas=True,
+        )
+
+    elif scheduler_name == SCHEDULER_DPM_2M:
+        # DPM++ 2M - fast with good quality
+        return DPMSolverMultistepScheduler.from_pretrained(
+            config_source,
+            subfolder=subfolder,
+            algorithm_type="dpmsolver++",
+            solver_order=2,
+            use_karras_sigmas=True,
+        )
+
+    else:
+        print(f"⚠️ Unknown scheduler '{scheduler_name}', defaulting to Euler Ancestral")
+        return EulerAncestralDiscreteScheduler.from_pretrained(
+            config_source,
+            subfolder=subfolder
+        )
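# Illustration (not part of the commit): build a sampler by name and inspect its
# schedule; the config is pulled from the SDXL base repo on first call.
sched = get_scheduler(SCHEDULER_DPM_2M_SDE)
sched.set_timesteps(20, device="cpu")
print(sched.timesteps[:3])   # descending noise levels, Karras-spaced sigmas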
 
+# ============================================================================
+# UTILITIES
+# ============================================================================
+
+def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
+    """Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint."""
+
+    components = {
+        "unet": {},
+        "clip_l": {},
+        "clip_g": {},
+        "vae": {}
+    }
+
+    for key, value in state_dict.items():
+        if key.startswith(COMFYUI_UNET_PREFIX):
+            new_key = key[len(COMFYUI_UNET_PREFIX):]
+            components["unet"][new_key] = value
+        elif key.startswith(COMFYUI_CLIP_L_PREFIX):
+            new_key = key[len(COMFYUI_CLIP_L_PREFIX):]
+            components["clip_l"][new_key] = value
+        elif key.startswith(COMFYUI_CLIP_G_PREFIX):
+            new_key = key[len(COMFYUI_CLIP_G_PREFIX):]
+            components["clip_g"][new_key] = value
+        elif key.startswith(COMFYUI_VAE_PREFIX):
+            new_key = key[len(COMFYUI_VAE_PREFIX):]
+            components["vae"][new_key] = value
+
+    print(f"  Extracted components:")
+    print(f"    UNet: {len(components['unet'])} keys")
+    print(f"    CLIP-L: {len(components['clip_l'])} keys")
+    print(f"    CLIP-G: {len(components['clip_g'])} keys")
+    print(f"    VAE: {len(components['vae'])} keys")
+
+    return components
+
+
+def get_clip_hidden_state(
+    model_output,
+    clip_skip: int = 1,
+    output_hidden_states: bool = True
+) -> torch.Tensor:
+    """Extract hidden state with clip_skip support."""
+    if clip_skip == 1 or not output_hidden_states:
+        return model_output.last_hidden_state
+
+    if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
+        return model_output.hidden_states[-clip_skip]
+
+    return model_output.last_hidden_state
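# Illustration (not part of the commit): clip_skip indexes layers from the end,
# so clip_skip=2 returns the penultimate hidden state (the A1111 convention).
from types import SimpleNamespace
layers = [torch.full((1, 77, 768), float(i)) for i in range(12)]
out = SimpleNamespace(last_hidden_state=layers[-1], hidden_states=layers)
assert get_clip_hidden_state(out, clip_skip=1)[0, 0, 0] == 11.0  # last layer
assert get_clip_hidden_state(out, clip_skip=2)[0, 0, 0] == 10.0  # penultimate layer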
 # ============================================================================
 
 # ============================================================================
 
 class SDXLFlowMatchingPipeline:
+    """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders.
+
+    Uses lazy loading for T5 and Lyra - they're only downloaded when actually used.
+    """
 
     def __init__(
         self,
 
         self.scheduler = scheduler
         self.device = device
 
+        # Lazy loaders for Lyra components
         self.t5_loader = t5_loader
         self.lyra_loader = lyra_loader
 
         self.clip_skip = clip_skip
         self.vae_scale_factor = 0.13025
         self.arch = ARCH_SDXL
+
+        # Track current scheduler name for UI
+        self._scheduler_name = SCHEDULER_EULER_A
 
     def set_scheduler(self, scheduler_name: str):
+        """Switch scheduler without reloading model."""
+        if scheduler_name != self._scheduler_name:
+            self.scheduler = get_scheduler(
+                scheduler_name,
+                config_source="stabilityai/stable-diffusion-xl-base-1.0",
+                is_sdxl=True
+            )
+            self._scheduler_name = scheduler_name
+            print(f"✓ Scheduler changed to: {scheduler_name}")
 
     @property
+    def t5_encoder(self) -> Optional[T5EncoderModel]:
+        """Access T5 encoder (triggers lazy load if needed)."""
         return self.t5_loader.encoder if self.t5_loader else None
 
     @property
+    def t5_tokenizer(self) -> Optional[T5Tokenizer]:
+        """Access T5 tokenizer (triggers lazy load if needed)."""
         return self.t5_loader.tokenizer if self.t5_loader else None
 
     @property
     def lyra_model(self):
+        """Access Lyra model (triggers lazy load if needed)."""
         return self.lyra_loader.model if self.lyra_loader else None
 
+    @property
+    def lyra_available(self) -> bool:
+        """Check if Lyra components are configured (not necessarily loaded)."""
+        return self.t5_loader is not None and self.lyra_loader is not None
+
     def encode_prompt(
         self,
         prompt: str,
 
         prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
         pooled_prompt_embeds = clip_g_output.text_embeds
 
+        # Concatenate CLIP-L and CLIP-G embeddings
         prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
 
         # Negative prompt
 
         t5_summary: str = "",
         lyra_strength: float = 0.3
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode prompts using Lyra VAE v2 fusion (CLIP + T5).
+
+        This triggers lazy loading of T5 and Lyra if not already loaded.
+        """
+        if not self.lyra_available:
+            raise ValueError("Lyra VAE components not configured")
+
+        # Accessing the properties triggers the lazy load
+        t5_encoder = self.t5_encoder
+        t5_tokenizer = self.t5_tokenizer
+        lyra_model = self.lyra_model
 
         # Get standard CLIP embeddings first
         prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
             prompt, negative_prompt, clip_skip
         )
 
+        # Format T5 input with pilcrow separator (¶)
         SUMMARY_SEPARATOR = "¶"
         if t5_summary.strip():
             t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {t5_summary}"
             t5_prompt = f"{prompt} {SUMMARY_SEPARATOR} {prompt}"
 
         # Get T5 embeddings
+        t5_inputs = t5_tokenizer(
             t5_prompt,
             max_length=512,
             padding='max_length',
         ).to(self.device)
 
         with torch.no_grad():
+            t5_embeds = t5_encoder(**t5_inputs).last_hidden_state
 
         clip_l_dim = 768
+        clip_g_dim = 1280
+
         clip_l_embeds = prompt_embeds[..., :clip_l_dim]
         clip_g_embeds = prompt_embeds[..., clip_l_dim:]
 
             't5_xl_l': t5_embeds.float(),
             't5_xl_g': t5_embeds.float()
         }
+        reconstructions, mu, logvar, _ = lyra_model(
             modality_inputs,
             target_modalities=['clip_l', 'clip_g']
         )
 
         lyra_clip_l = reconstructions['clip_l'].to(prompt_embeds.dtype)
         lyra_clip_g = reconstructions['clip_g'].to(prompt_embeds.dtype)
 
+        # Normalize reconstructions to match input statistics
         clip_l_std_ratio = lyra_clip_l.std() / (clip_l_embeds.std() + 1e-8)
         clip_g_std_ratio = lyra_clip_g.std() / (clip_g_embeds.std() + 1e-8)
 
             lyra_clip_g = (lyra_clip_g - lyra_clip_g.mean()) / (lyra_clip_g.std() + 1e-8)
             lyra_clip_g = lyra_clip_g * clip_g_embeds.std() + clip_g_embeds.mean()
 
+        # Blend original CLIP with Lyra reconstruction
         fused_clip_l = (1 - lyra_strength) * clip_l_embeds + lyra_strength * lyra_clip_l
         fused_clip_g = (1 - lyra_strength) * clip_g_embeds + lyra_strength * lyra_clip_g
 
         prompt_embeds_fused = torch.cat([fused_clip_l, fused_clip_g], dim=-1)
 
+        # Process negative prompt
+        if negative_prompt:
+            neg_strength = lyra_strength * 0.5  # Less aggressive for negative
+
+            t5_neg_prompt = f"{negative_prompt} {SUMMARY_SEPARATOR} {negative_prompt}"
+            t5_inputs_neg = t5_tokenizer(
+                t5_neg_prompt,
+                max_length=512,
+                padding='max_length',
+                truncation=True,
+                return_tensors='pt'
+            ).to(self.device)
+
+            with torch.no_grad():
+                t5_embeds_neg = t5_encoder(**t5_inputs_neg).last_hidden_state
+
+            neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
+            neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
+
+            modality_inputs_neg = {
+                'clip_l': neg_clip_l.float(),
+                'clip_g': neg_clip_g.float(),
+                't5_xl_l': t5_embeds_neg.float(),
+                't5_xl_g': t5_embeds_neg.float()
+            }
+            recon_neg, _, _, _ = lyra_model(modality_inputs_neg, target_modalities=['clip_l', 'clip_g'])
+
+            lyra_neg_l = recon_neg['clip_l'].to(negative_prompt_embeds.dtype)
+            lyra_neg_g = recon_neg['clip_g'].to(negative_prompt_embeds.dtype)
+
+            # Normalize
+            neg_l_ratio = lyra_neg_l.std() / (neg_clip_l.std() + 1e-8)
+            neg_g_ratio = lyra_neg_g.std() / (neg_clip_g.std() + 1e-8)
+            if neg_l_ratio > 2.0 or neg_l_ratio < 0.5:
+                lyra_neg_l = (lyra_neg_l - lyra_neg_l.mean()) / (lyra_neg_l.std() + 1e-8)
+                lyra_neg_l = lyra_neg_l * neg_clip_l.std() + neg_clip_l.mean()
+            if neg_g_ratio > 2.0 or neg_g_ratio < 0.5:
+                lyra_neg_g = (lyra_neg_g - lyra_neg_g.mean()) / (lyra_neg_g.std() + 1e-8)
+                lyra_neg_g = lyra_neg_g * neg_clip_g.std() + neg_clip_g.mean()
+
+            fused_neg_l = (1 - neg_strength) * neg_clip_l + neg_strength * lyra_neg_l
+            fused_neg_g = (1 - neg_strength) * neg_clip_g + neg_strength * lyra_neg_g
+
+            negative_prompt_embeds_fused = torch.cat([fused_neg_l, fused_neg_g], dim=-1)
+        else:
+            negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)
+
+        return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled
 
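# Illustration (not part of the commit): the std-ratio guard and blend above in
# miniature, with random tensors standing in for CLIP-L and its reconstruction.
clip = torch.randn(1, 77, 768)
lyra = torch.randn(1, 77, 768) * 4.0        # reconstruction with inflated scale
if not (0.5 <= (lyra.std() / (clip.std() + 1e-8)) <= 2.0):
    lyra = (lyra - lyra.mean()) / (lyra.std() + 1e-8)
    lyra = lyra * clip.std() + clip.mean()  # re-match CLIP's mean/std
fused = 0.7 * clip + 0.3 * lyra             # lyra_strength = 0.3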
     def _get_add_time_ids(
         self,
 
         negative_prompt: str = "",
         height: int = 1024,
         width: int = 1024,
+        num_inference_steps: int = 20,
+        guidance_scale: float = 7.5,
+        shift: float = 0.0,
+        use_flow_matching: bool = False,
+        prediction_type: str = "epsilon",
         seed: Optional[int] = None,
         use_lyra: bool = False,
+        clip_skip: int = 1,
         t5_summary: str = "",
         lyra_strength: float = 1.0,
         progress_callback=None
 
         else:
             generator = None
 
+        # Encode prompts (Lyra triggers lazy load only if use_lyra=True)
+        if use_lyra and self.lyra_available:
             prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
                 prompt, negative_prompt, clip_skip, t5_summary, lyra_strength
             )
 
         self.scheduler.set_timesteps(num_inference_steps, device=self.device)
         timesteps = self.scheduler.timesteps
 
+        if not use_flow_matching:
+            latents = latents * self.scheduler.init_noise_sigma
 
+        # Prepare added time embeddings for SDXL
         original_size = (height, width)
         target_size = (height, width)
         crops_coords_top_left = (0, 0)
 
             progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
 
             latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
+
+            if use_flow_matching and shift > 0:
+                sigma = t.float() / 1000.0
+                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
+                scaling = torch.sqrt(1 + sigma_shifted ** 2)
+                latent_model_input = latent_model_input / scaling
+            else:
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
             timestep = t.expand(latent_model_input.shape[0])
 
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+            if use_flow_matching:
+                sigma = t.float() / 1000.0
+                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
+
+                if prediction_type == "v_prediction":
+                    v_pred = noise_pred
+                    alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
+                    sigma_t = sigma_shifted
+                    noise_pred = alpha_t * v_pred + sigma_t * latents
+
+                dt = -1.0 / num_inference_steps
+                latents = latents + dt * noise_pred
+            else:
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, return_dict=False
+                )[0]
 
         # Decode
         latents = latents / self.vae_scale_factor
 
         return image
 
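# Illustration (not part of the commit): the timestep shift used above,
# sigma_shifted = shift*sigma / (1 + (shift-1)*sigma), resembles the SD3-style
# resolution shift; shift > 1 pushes mid-range sigmas toward the noisy end.
shift = 2.5
for sigma in (0.1, 0.5, 0.9):
    print(sigma, "->", round((shift * sigma) / (1 + (shift - 1) * sigma), 3))
# 0.1 -> 0.217, 0.5 -> 0.714, 0.9 -> 0.957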
 
835
+ # ============================================================================
836
+ # SD1.5 PIPELINE
837
+ # ============================================================================
838
+
839
+ class SD15FlowMatchingPipeline:
840
+ """Pipeline for SD1.5-based flow-matching inference."""
841
+
842
+ def __init__(
843
+ self,
844
+ vae: AutoencoderKL,
845
+ text_encoder: CLIPTextModel,
846
+ tokenizer: CLIPTokenizer,
847
+ unet: UNet2DConditionModel,
848
+ scheduler,
849
+ device: str = "cuda",
850
+ t5_loader: Optional[LazyT5Encoder] = None,
851
+ lyra_loader: Optional[LazyLyraModel] = None,
852
+ ):
853
+ self.vae = vae
854
+ self.text_encoder = text_encoder
855
+ self.tokenizer = tokenizer
856
+ self.unet = unet
857
+ self.scheduler = scheduler
858
+ self.device = device
859
+
860
+ self.t5_loader = t5_loader
861
+ self.lyra_loader = lyra_loader
862
+
863
+ self.vae_scale_factor = 0.18215
864
+ self.arch = ARCH_SD15
865
+ self.is_lune_model = False
866
+
867
+ @property
868
+ def t5_encoder(self):
869
+ return self.t5_loader.encoder if self.t5_loader else None
870
+
871
+ @property
872
+ def t5_tokenizer(self):
873
+ return self.t5_loader.tokenizer if self.t5_loader else None
874
+
875
+ @property
876
+ def lyra_model(self):
877
+ return self.lyra_loader.model if self.lyra_loader else None
878
+
879
+ @property
880
+ def lyra_available(self) -> bool:
881
+ return self.t5_loader is not None and self.lyra_loader is not None
882
+
883
+ def encode_prompt(self, prompt: str, negative_prompt: str = ""):
884
+ """Encode text prompts to embeddings."""
885
+ text_inputs = self.tokenizer(
886
+ prompt,
887
+ padding="max_length",
888
+ max_length=self.tokenizer.model_max_length,
889
+ truncation=True,
890
+ return_tensors="pt",
891
+ )
892
+ text_input_ids = text_inputs.input_ids.to(self.device)
893
+
894
+ with torch.no_grad():
895
+ prompt_embeds = self.text_encoder(text_input_ids)[0]
896
+
897
+ if negative_prompt:
898
+ uncond_inputs = self.tokenizer(
899
+ negative_prompt,
900
+ padding="max_length",
901
+ max_length=self.tokenizer.model_max_length,
902
+ truncation=True,
903
+ return_tensors="pt",
904
+ )
905
+ uncond_input_ids = uncond_inputs.input_ids.to(self.device)
906
+
907
+ with torch.no_grad():
908
+ negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0]
909
+ else:
910
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
911
+
912
+ return prompt_embeds, negative_prompt_embeds
913
+
914
+ def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
915
+ """Encode using Lyra VAE (CLIP + T5 fusion)."""
916
+ if not self.lyra_available:
917
+ raise ValueError("Lyra VAE components not configured")
918
+
919
+ t5_encoder = self.t5_encoder
920
+ t5_tokenizer = self.t5_tokenizer
921
+ lyra_model = self.lyra_model
922
+
923
+ # CLIP
924
+ text_inputs = self.tokenizer(
925
+ prompt,
926
+ padding="max_length",
927
+ max_length=self.tokenizer.model_max_length,
928
+ truncation=True,
929
+ return_tensors="pt",
930
+ )
931
+ text_input_ids = text_inputs.input_ids.to(self.device)
932
+
933
+ with torch.no_grad():
934
+ clip_embeds = self.text_encoder(text_input_ids)[0]
935
+
936
+ # T5
937
+ t5_inputs = t5_tokenizer(
938
+ prompt,
939
+ max_length=77,
940
+ padding='max_length',
941
+ truncation=True,
942
+ return_tensors='pt'
943
+ ).to(self.device)
944
+
945
+ with torch.no_grad():
946
+ t5_embeds = t5_encoder(**t5_inputs).last_hidden_state
947
+
948
+ # Fuse
949
+ modality_inputs = {'clip': clip_embeds, 't5': t5_embeds}
950
+
951
+ with torch.no_grad():
952
+ reconstructions, mu, logvar = lyra_model(
953
+ modality_inputs,
954
+ target_modalities=['clip']
955
+ )
956
+ prompt_embeds = reconstructions['clip']
957
+
958
+ # Negative
959
+ if negative_prompt:
960
+ uncond_inputs = self.tokenizer(
961
+ negative_prompt,
962
+ padding="max_length",
963
+ max_length=self.tokenizer.model_max_length,
964
+ truncation=True,
965
+ return_tensors="pt",
966
+ )
967
+ uncond_input_ids = uncond_inputs.input_ids.to(self.device)
968
+
969
+ with torch.no_grad():
970
+ clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]
971
+
972
+ t5_inputs_uncond = t5_tokenizer(
973
+ negative_prompt,
974
+ max_length=77,
975
+ padding='max_length',
976
+ truncation=True,
977
+ return_tensors='pt'
978
+ ).to(self.device)
979
+
980
+ with torch.no_grad():
981
+ t5_embeds_uncond = t5_encoder(**t5_inputs_uncond).last_hidden_state
982
+
983
+ modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond}
984
+
985
+ with torch.no_grad():
986
+ reconstructions_uncond, _, _ = lyra_model(
987
+ modality_inputs_uncond,
988
+ target_modalities=['clip']
989
+ )
990
+ negative_prompt_embeds = reconstructions_uncond['clip']
991
+ else:
992
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
993
+
994
+ return prompt_embeds, negative_prompt_embeds
995
+
996
+ @torch.no_grad()
997
+ def __call__(
998
+ self,
999
+ prompt: str,
1000
+ negative_prompt: str = "",
1001
+ height: int = 512,
1002
+ width: int = 512,
1003
+ num_inference_steps: int = 20,
1004
+ guidance_scale: float = 7.5,
1005
+ shift: float = 2.5,
1006
+ use_flow_matching: bool = True,
1007
+ prediction_type: str = "epsilon",
1008
+ seed: Optional[int] = None,
1009
+ use_lyra: bool = False,
1010
+ clip_skip: int = 1,
1011
+ t5_summary: str = "",
1012
+ lyra_strength: float = 1.0,
1013
+ progress_callback=None
1014
+ ):
1015
+ """Generate image."""
1016
+
1017
+ if seed is not None:
1018
+ generator = torch.Generator(device=self.device).manual_seed(seed)
1019
+ else:
1020
+ generator = None
1021
+
1022
+ if use_lyra and self.lyra_available:
1023
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt)
1024
+ else:
1025
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt)
1026
+
1027
+ latent_channels = 4
1028
+ latent_height = height // 8
1029
+ latent_width = width // 8
1030
+
1031
+ latents = torch.randn(
1032
+ (1, latent_channels, latent_height, latent_width),
1033
+ generator=generator,
1034
+ device=self.device,
1035
+ dtype=torch.float32
1036
+ )
1037
+
1038
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device)
1039
+ timesteps = self.scheduler.timesteps
1040
+
1041
+ if not use_flow_matching:
1042
+ latents = latents * self.scheduler.init_noise_sigma
1043
+
1044
+ for i, t in enumerate(timesteps):
1045
+ if progress_callback:
1046
+ progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
1047
+
1048
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
1049
+
1050
+ if use_flow_matching and shift > 0:
1051
+ sigma = t.float() / 1000.0
1052
+ sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
1053
+ scaling = torch.sqrt(1 + sigma_shifted ** 2)
1054
+ latent_model_input = latent_model_input / scaling
1055
+ else:
1056
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1057
+
1058
+ timestep = t.expand(latent_model_input.shape[0])
1059
+ text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds
1060
+
1061
+ noise_pred = self.unet(
1062
+ latent_model_input,
1063
+ timestep,
1064
+ encoder_hidden_states=text_embeds,
1065
+ return_dict=False
1066
+ )[0]
1067
+
1068
+ if guidance_scale > 1.0:
1069
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1070
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1071
+
1072
+ if use_flow_matching:
1073
+ sigma = t.float() / 1000.0
1074
+ sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
1075
+
1076
+ if prediction_type == "v_prediction":
1077
+ v_pred = noise_pred
1078
+ alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
1079
+ sigma_t = sigma_shifted
1080
+ noise_pred = alpha_t * v_pred + sigma_t * latents
1081
+
1082
+ dt = -1.0 / num_inference_steps
1083
+ latents = latents + dt * noise_pred
+             else:
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ 
+         latents = latents / self.vae_scale_factor
+ 
+         if self.is_lune_model:
+             latents = latents * 5.52
+ 
+         with torch.no_grad():
+             image = self.vae.decode(latents).sample
+ 
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+         image = (image * 255).round().astype("uint8")
+         image = Image.fromarray(image[0])
+ 
+         return image
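+ 
+ # Usage sketch (illustrative only — assumes an already-constructed
+ # SD15FlowMatchingPipeline instance named `pipe`):
+ #
+ #     image = pipe(
+ #         "a watercolor fox in a snowy forest",
+ #         num_inference_steps=20,
+ #         guidance_scale=7.5,
+ #         shift=2.5,
+ #         use_flow_matching=True,
+ #         prediction_type="v_prediction",
+ #         seed=42,
+ #     )
+ #     image.save("fox.png")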
+ 
+ 
  # ============================================================================
  # MODEL LOADERS
  # ============================================================================
 
+ def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
+     """Load Lune checkpoint from .pt file."""
+     print(f"📥 Downloading: {repo_id}/{filename}")
+ 
+     checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
+     checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ 
+     print("🏗️ Initializing SD1.5 UNet...")
+     unet = UNet2DConditionModel.from_pretrained(
+         "runwayml/stable-diffusion-v1-5",
+         subfolder="unet",
+         torch_dtype=torch.float32
+     )
+ 
+     student_state_dict = checkpoint["student"]
+     cleaned_dict = {}
+     for key, value in student_state_dict.items():
+         if key.startswith("unet."):
+             cleaned_dict[key[5:]] = value
+         else:
+             cleaned_dict[key] = value
+ 
+     unet.load_state_dict(cleaned_dict, strict=False)
+ 
+     step = checkpoint.get("gstep", "unknown")
+     print(f"✅ Loaded Lune from step {step}")
+ 
+     return unet.to(device)
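+ 
+ # Expected checkpoint layout (an assumption inferred from the loader above):
+ #     {"student": {"unet.conv_in.weight": ..., ...}, "gstep": 34000}
+ # The "unet." prefix is stripped so keys line up with diffusers'
+ # UNet2DConditionModel state dict; strict=False tolerates extra entries.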
+ 
+ 
  def load_illustrious_xl(
+     repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
      filename: str = "illustriousXL_v01.safetensors",
      device: str = "cuda"
  ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
 
      checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
      print(f"✓ Downloaded: {checkpoint_path}")
 
+     print("📦 Loading with StableDiffusionXLPipeline.from_single_file()...")
      pipe = StableDiffusionXLPipeline.from_single_file(
          checkpoint_path,
          torch_dtype=torch.float16,
 
      torch.cuda.empty_cache()
 
      print("✅ Illustrious XL loaded!")
+     print(f"   UNet params: {sum(p.numel() for p in unet.parameters()):,}")
+     print(f"   VAE params:  {sum(p.numel() for p in vae.parameters()):,}")
+ 
+     return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
+ 
+ 
+ def load_sdxl_base(device: str = "cuda"):
+     """Load standard SDXL base model."""
+     print("📥 Loading SDXL Base 1.0...")
+ 
+     unet = UNet2DConditionModel.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         subfolder="unet",
+         torch_dtype=torch.float16
+     ).to(device)
+ 
+     vae = AutoencoderKL.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         subfolder="vae",
+         torch_dtype=torch.float16
+     ).to(device)
+ 
+     text_encoder = CLIPTextModel.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         subfolder="text_encoder",
+         torch_dtype=torch.float16
+     ).to(device)
+ 
+     text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         subfolder="text_encoder_2",
+         torch_dtype=torch.float16
+     ).to(device)
+ 
+     tokenizer = CLIPTokenizer.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         subfolder="tokenizer"
+     )
+ 
+     tokenizer_2 = CLIPTokenizer.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0",
+         subfolder="tokenizer_2"
+     )
+ 
+     print("✅ SDXL Base loaded!")
 
      return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
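+ 
+ # Hedged note: StableDiffusionXLPipeline.from_pretrained(
+ #     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+ # would fetch the same six components in one call; they are loaded
+ # individually here so the custom pipeline class can own them directly.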
 
 
  # PIPELINE INITIALIZATION
  # ============================================================================
 
+ def initialize_pipeline(model_choice: str, device: str = "cuda"):
+     """Initialize the complete pipeline for the selected model.
+ 
+     T5 and Lyra use lazy loaders, so nothing is downloaded until first use.
+     """
 
      print(f"🚀 Initializing {model_choice} pipeline...")
 
+     is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
+     is_lune = "Lune" in model_choice
+ 
+     if is_sdxl:
+         # SDXL-based models
+         if "Illustrious" in model_choice:
+             unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
+         else:
+             unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
+ 
+         # Create LAZY loaders for T5 and Lyra (no download yet!)
+         print("📋 Configuring lazy loaders for T5-XL and Lyra VAE (will download on first use)")
+         t5_loader = LazyT5Encoder(
+             model_name=T5_XL_MODEL,  # google/flan-t5-xl
+             device=device,
+             dtype=torch.float16
+         )
+         lyra_loader = LazyLyraModel(
+             repo_id=LYRA_ILLUSTRIOUS_REPO,
+             device=device
+         )
+ 
+         # Default scheduler: Euler Ancestral
+         scheduler = get_scheduler(SCHEDULER_EULER_A, is_sdxl=True)
+ 
+         pipeline = SDXLFlowMatchingPipeline(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             unet=unet,
+             scheduler=scheduler,
+             device=device,
+             t5_loader=t5_loader,
+             lyra_loader=lyra_loader,
+             clip_skip=1
+         )
+ 
      else:
+         # SD1.5-based models
+         vae = AutoencoderKL.from_pretrained(
+             "runwayml/stable-diffusion-v1-5",
+             subfolder="vae",
+             torch_dtype=torch.float32
+         ).to(device)
+ 
+         text_encoder = CLIPTextModel.from_pretrained(
+             "openai/clip-vit-large-patch14",
+             torch_dtype=torch.float32
+         ).to(device)
+ 
+         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+ 
+         # Lazy loaders for SD1.5 Lyra (T5-base)
+         print("📋 Configuring lazy loaders for T5-base and Lyra VAE v1 (will download on first use)")
+         t5_loader = LazyT5Encoder(
+             model_name=T5_BASE_MODEL,  # google/flan-t5-base
+             device=device,
+             dtype=torch.float32
          )
+         lyra_loader = LazyLyraModel(
+             repo_id=LYRA_SD15_REPO,
+             device=device
+         )
+ 
+         # Load UNet
+         if is_lune:
+             repo_id = "AbstractPhil/sd15-flow-lune"
+             filename = "sd15_flow_lune_e34_s34000.pt"
+             unet = load_lune_checkpoint(repo_id, filename, device)
+         else:
+             unet = UNet2DConditionModel.from_pretrained(
+                 "runwayml/stable-diffusion-v1-5",
+                 subfolder="unet",
+                 torch_dtype=torch.float32
+             ).to(device)
+ 
+         scheduler = EulerDiscreteScheduler.from_pretrained(
+             "runwayml/stable-diffusion-v1-5",
+             subfolder="scheduler"
+         )
+ 
+         pipeline = SD15FlowMatchingPipeline(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             device=device,
+             t5_loader=t5_loader,
+             lyra_loader=lyra_loader,
+         )
+ 
+     pipeline.is_lune_model = is_lune
 
+     print("✅ Pipeline initialized! (T5 and Lyra will load on first use)")
      return pipeline
 
 
  CURRENT_PIPELINE = None
  CURRENT_MODEL = None
 
 
+ def get_pipeline(model_choice: str):
      """Get or create pipeline for selected model."""
+     global CURRENT_PIPELINE, CURRENT_MODEL
 
      if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
+         CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
          CURRENT_MODEL = model_choice
 
      return CURRENT_PIPELINE
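+ 
+ # Only one pipeline is kept resident at a time: switching models in the UI
+ # rebuilds everything from scratch, trading reload time for GPU memory.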
 
 
  # INFERENCE
  # ============================================================================
 
+ def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int:
+     """Estimate the GPU time budget (in seconds) for a generation request."""
+     base_time_per_step = 0.5 if is_sdxl else 0.3
+     resolution_factor = (width * height) / (512 * 512)
+     estimated = num_steps * base_time_per_step * resolution_factor
+ 
+     if use_lyra:
+         estimated *= 2  # Side-by-side comparison runs two generations
+         estimated += 10  # Extra time for lazy loading on first use
+ 
+     return int(estimated + 20)
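+ 
+ # Worked example with the constants above: 25 SDXL steps at 1024x1024 give
+ # 25 * 0.5 * 4.0 = 50s; enabling Lyra doubles that to 100s and adds 10s for
+ # lazy loading, so with the flat 20s margin the request totals 130 seconds.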
+ 
+ 
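+ # The duration lambda below indexes generate_image's positional args:
+ # args[3]=model_choice, args[6]=num_steps, args[8]=width, args[9]=height,
+ # args[12]=use_lyra. Keep it in sync with the generate_btn.click inputs list.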
+ @spaces.GPU(duration=lambda *args: estimate_duration(
+     args[6], args[8], args[9], args[12],
+     "SDXL" in args[3] or "Illustrious" in args[3]
+ ))
  def generate_image(
      prompt: str,
      t5_summary: str,
      negative_prompt: str,
      model_choice: str,
+     scheduler_choice: str,
      clip_skip: int,
      num_steps: int,
      cfg_scale: float,
      width: int,
      height: int,
+     shift: float,
+     use_flow_matching: bool,
      use_lyra: bool,
      lyra_strength: float,
      seed: int,
 
          progress((step + 1) / total, desc=desc)
 
      try:
+         pipeline = get_pipeline(model_choice)
+ 
+         # Update scheduler if requested (SDXL pipelines expose set_scheduler)
+         is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
+         if is_sdxl and hasattr(pipeline, 'set_scheduler'):
+             pipeline.set_scheduler(scheduler_choice)
 
+         # Lune samples with v-prediction; everything else stays on epsilon
+         prediction_type = "epsilon"
+         if not is_sdxl and "Lune" in model_choice:
+             prediction_type = "v_prediction"
+ 
+         if not use_lyra or not pipeline.lyra_available:
              progress(0.05, desc="Generating...")
 
              image = pipeline(
 
                  width=width,
                  num_inference_steps=num_steps,
                  guidance_scale=cfg_scale,
+                 shift=shift,
+                 use_flow_matching=use_flow_matching,
+                 prediction_type=prediction_type,
                  seed=seed,
                  use_lyra=False,
                  clip_skip=clip_skip,
 
              return image, None, seed
 
          else:
+             # Side-by-side comparison: same settings with and without Lyra
              progress(0.05, desc="Generating standard...")
 
              image_standard = pipeline(
 
                  width=width,
                  num_inference_steps=num_steps,
                  guidance_scale=cfg_scale,
+                 shift=shift,
+                 use_flow_matching=use_flow_matching,
+                 prediction_type=prediction_type,
                  seed=seed,
                  use_lyra=False,
                  clip_skip=clip_skip,
                  progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
              )
 
+             progress(0.5, desc="Generating Lyra fusion (loading T5 + Lyra if needed)...")
 
              image_lyra = pipeline(
                  prompt=prompt,
 
                  width=width,
                  num_inference_steps=num_steps,
                  guidance_scale=cfg_scale,
+                 shift=shift,
+                 use_flow_matching=use_flow_matching,
+                 prediction_type=prediction_type,
                  seed=seed,
                  use_lyra=True,
                  clip_skip=clip_skip,
 
 
  with gr.Blocks() as demo:
      gr.Markdown("""
+     # 🌙 Lyra/Lune Flow-Matching Image Generation
 
      **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
 
+     Generate images with SD1.5- and SDXL-based models built on geometric deep learning:
+ 
      | Model | Architecture | Lyra Version | Best For |
      |-------|-------------|--------------|----------|
      | **Illustrious XL** | SDXL | v2 (T5-XL) | Anime/illustration, high detail |
      | **SDXL Base** | SDXL | v2 (T5-XL) | Photorealistic, general purpose |
+     | **Flow-Lune** | SD1.5 | v1 (T5-base) | Fast flow matching (15-25 steps) |
+     | **SD1.5 Base** | SD1.5 | v1 (T5-base) | Baseline comparison |
 
+     **Lazy Loading**: T5 and the Lyra VAE are downloaded only when Lyra fusion is enabled.
      """)
 
      with gr.Row():
          with gr.Column(scale=1):
              prompt = gr.TextArea(
+                 label="Prompt (Tags for CLIP)",
                  value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
                  lines=3
              )
 
              t5_summary = gr.TextArea(
+                 label="T5 Summary (Natural Language for Lyra)",
+                 value="A beautiful anime girl with flowing blue hair wearing a school uniform, surrounded by delicate pink cherry blossoms against a bright sky",
                  lines=2,
+                 info="Appended after a separator for the T5 encoder. Leave empty to use the tag prompt only."
              )
 
              negative_prompt = gr.TextArea(
                  label="Negative Prompt",
+                 value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality",
                  lines=2
              )
 
+             model_choice = gr.Dropdown(
+                 label="Model",
+                 choices=[
+                     "Illustrious XL",
+                     "SDXL Base",
+                     "Flow-Lune (SD1.5)",
+                     "SD1.5 Base"
+                 ],
+                 value="Illustrious XL"
+             )
+ 
+             scheduler_choice = gr.Dropdown(
+                 label="Scheduler (SDXL only)",
+                 choices=SCHEDULER_CHOICES,
+                 value=SCHEDULER_EULER_A,
+                 info="Euler Ancestral recommended for Illustrious"
+             )
 
              clip_skip = gr.Slider(
                  label="CLIP Skip",
+                 minimum=1,
+                 maximum=4,
+                 value=2,
+                 step=1,
+                 info="2 recommended for Illustrious, 1 for others"
              )
 
              use_lyra = gr.Checkbox(
+                 label="Enable Lyra VAE (CLIP+T5 Fusion)",
                  value=False,
+                 info="T5 and the Lyra VAE are downloaded lazily on first use"
              )
 
              lyra_strength = gr.Slider(
                  label="Lyra Blend Strength",
+                 minimum=0.0,
+                 maximum=3.0,
+                 value=1.0,
+                 step=0.05,
+                 info="0.0 = pure CLIP, 1.0 = full Lyra reconstruction, >1.0 extrapolates"
              )
 
              with gr.Accordion("Generation Settings", open=True):
+                 num_steps = gr.Slider(
+                     label="Steps",
+                     minimum=1,
+                     maximum=50,
+                     value=25,
+                     step=1
+                 )
+ 
+                 cfg_scale = gr.Slider(
+                     label="CFG Scale",
+                     minimum=1.0,
+                     maximum=20.0,
+                     value=7.0,
+                     step=0.5
+                 )
 
              with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=512,
+                     maximum=1536,
+                     value=1024,
+                     step=64
+                 )
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=512,
+                     maximum=1536,
+                     value=1024,
+                     step=64
+                 )
 
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=2**32 - 1,
+                 value=42,
+                 step=1
+             )
+ 
+             randomize_seed = gr.Checkbox(
+                 label="Randomize Seed",
+                 value=True
+             )
+ 
+             with gr.Accordion("Advanced (Flow Matching)", open=False):
+                 use_flow_matching = gr.Checkbox(
+                     label="Enable Flow Matching",
+                     value=False,
+                     info="Integrate the flow-matching ODE (Lune models only)"
+                 )
+ 
+                 shift = gr.Slider(
+                     label="Shift",
+                     minimum=0.0,
+                     maximum=5.0,
+                     value=0.0,
+                     step=0.1,
+                     info="Timestep shift toward high noise (0 = disabled)"
+                 )
 
              generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
 
          with gr.Column(scale=1):
              with gr.Row():
+                 output_image_standard = gr.Image(
+                     label="Generated Image",
+                     type="pil"
+                 )
+                 output_image_lyra = gr.Image(
+                     label="Lyra Fusion 🎵",
+                     type="pil",
+                     visible=False
+                 )
 
              output_seed = gr.Number(label="Seed", precision=0)
+ 
+             gr.Markdown("""
+             ### Tips
+             - **Lazy Loading**: T5-XL (~3 GB) and the Lyra VAE download only once Lyra is enabled
+             - **Illustrious XL**: use CLIP skip 2 with the Euler Ancestral scheduler
+             - **Schedulers**: DPM++ 2M SDE for detail, Euler A for speed
+             - **Lyra v2**: uses `google/flan-t5-xl` for richer semantics
+             """)
 
      # Event handlers
+     def on_model_change(model_name):
+         """Update control defaults to match the selected model."""
+         if "Illustrious" in model_name:
+             return {
+                 clip_skip: gr.update(value=2),
+                 width: gr.update(value=1024),
+                 height: gr.update(value=1024),
+                 num_steps: gr.update(value=25),
+                 use_flow_matching: gr.update(value=False),
+                 shift: gr.update(value=0.0),
+                 scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A)
+             }
+         elif "SDXL" in model_name:
+             return {
+                 clip_skip: gr.update(value=1),
+                 width: gr.update(value=1024),
+                 height: gr.update(value=1024),
+                 num_steps: gr.update(value=30),
+                 use_flow_matching: gr.update(value=False),
+                 shift: gr.update(value=0.0),
+                 scheduler_choice: gr.update(visible=True, value=SCHEDULER_EULER_A)
+             }
+         elif "Lune" in model_name:
+             return {
+                 clip_skip: gr.update(value=1),
+                 width: gr.update(value=512),
+                 height: gr.update(value=512),
+                 num_steps: gr.update(value=20),
+                 use_flow_matching: gr.update(value=True),
+                 shift: gr.update(value=2.5),
+                 scheduler_choice: gr.update(visible=False)
+             }
+         else:  # SD1.5 Base
+             return {
+                 clip_skip: gr.update(value=1),
+                 width: gr.update(value=512),
+                 height: gr.update(value=512),
+                 num_steps: gr.update(value=30),
+                 use_flow_matching: gr.update(value=False),
+                 shift: gr.update(value=0.0),
+                 scheduler_choice: gr.update(visible=False)
+             }
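+ 
+     # Returning a dict keyed by component lets one callback update several
+     # controls at once; every key must also appear in the outputs list passed
+     # to model_choice.change() below.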
+ 
      def on_lyra_toggle(enabled):
+         """Show or hide the Lyra comparison image."""
          if enabled:
              return {
                  output_image_standard: gr.update(visible=True, label="Standard"),
 
                  output_image_lyra: gr.update(visible=False)
              }
 
+     model_choice.change(
+         fn=on_model_change,
+         inputs=[model_choice],
+         outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift, scheduler_choice]
+     )
+ 
      use_lyra.change(
          fn=on_lyra_toggle,
          inputs=[use_lyra],
 
      generate_btn.click(
          fn=generate_image,
          inputs=[
+             prompt, t5_summary, negative_prompt, model_choice, scheduler_choice, clip_skip,
+             num_steps, cfg_scale, width, height, shift,
+             use_flow_matching, use_lyra, lyra_strength, seed, randomize_seed
          ],
          outputs=[output_image_standard, output_image_lyra, output_seed]
      )