"""
Lyra/Lune Flow-Matching Inference Space
Author: AbstractPhil
License: MIT
SD1.5 and SDXL-based flow matching with geometric crystalline architectures.
Supports Illustrious XL, standard SDXL, and SD1.5 variants.
"""
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from typing import Any, Optional, Dict, Tuple
import spaces
from safetensors.torch import load_file as load_safetensors
from diffusers import (
UNet2DConditionModel,
AutoencoderKL,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler
)
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPTextModelWithProjection,
T5EncoderModel,
T5Tokenizer
)
from huggingface_hub import hf_hub_download
try:
    from geofractal.models.vae.vae_lyra_v2 import MultiModalVAE, MultiModalVAEConfig
    LYRA_AVAILABLE = True
except ImportError:
    MultiModalVAE = None
    MultiModalVAEConfig = None
    LYRA_AVAILABLE = False
# ============================================================================
# CONSTANTS
# ============================================================================
# Model architectures
ARCH_SD15 = "sd15"
ARCH_SDXL = "sdxl"
# ComfyUI key prefixes for SDXL single-file checkpoints
COMFYUI_UNET_PREFIX = "model.diffusion_model."
COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer."
COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model."
COMFYUI_VAE_PREFIX = "first_stage_model."
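# For reference, the prefix stripping below maps checkpoint keys like
#   "model.diffusion_model.input_blocks.0.0.weight"
# to bare UNet keys such as "input_blocks.0.0.weight" (key names are
# illustrative of SDXL single-file checkpoints).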
# ============================================================================
# MODEL LOADING UTILITIES
# ============================================================================
def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
"""Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint."""
components = {
"unet": {},
"clip_l": {},
"clip_g": {},
"vae": {}
}
for key, value in state_dict.items():
if key.startswith(COMFYUI_UNET_PREFIX):
new_key = key[len(COMFYUI_UNET_PREFIX):]
components["unet"][new_key] = value
elif key.startswith(COMFYUI_CLIP_L_PREFIX):
new_key = key[len(COMFYUI_CLIP_L_PREFIX):]
components["clip_l"][new_key] = value
elif key.startswith(COMFYUI_CLIP_G_PREFIX):
new_key = key[len(COMFYUI_CLIP_G_PREFIX):]
components["clip_g"][new_key] = value
elif key.startswith(COMFYUI_VAE_PREFIX):
new_key = key[len(COMFYUI_VAE_PREFIX):]
components["vae"][new_key] = value
print(f" Extracted components:")
print(f" UNet: {len(components['unet'])} keys")
print(f" CLIP-L: {len(components['clip_l'])} keys")
print(f" CLIP-G: {len(components['clip_g'])} keys")
print(f" VAE: {len(components['vae'])} keys")
return components
def get_clip_hidden_state(
model_output,
clip_skip: int = 1,
output_hidden_states: bool = True
) -> torch.Tensor:
"""Extract hidden state with clip_skip support."""
if clip_skip == 1 or not output_hidden_states:
return model_output.last_hidden_state
if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
# hidden_states is tuple: (embedding, layer1, ..., layerN)
# clip_skip=2 means penultimate layer = hidden_states[-2]
return model_output.hidden_states[-clip_skip]
return model_output.last_hidden_state
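# Indexing sketch for get_clip_hidden_state: with hidden_states laid out as
# (embeddings, layer_1, ..., layer_N),
#   clip_skip=1 -> last_hidden_state (final layer output),
#   clip_skip=2 -> hidden_states[-2] (penultimate layer, the usual
#                  "CLIP skip 2" convention for anime-style checkpoints).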
# ============================================================================
# SDXL PIPELINE
# ============================================================================
class SDXLFlowMatchingPipeline:
"""Pipeline for SDXL-based flow-matching inference with dual CLIP encoders."""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel, # CLIP-L
text_encoder_2: CLIPTextModelWithProjection, # CLIP-G
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler,
device: str = "cuda",
t5_encoder: Optional[T5EncoderModel] = None,
t5_tokenizer: Optional[T5Tokenizer] = None,
        lyra_model: Optional[Any] = None,
clip_skip: int = 1
):
self.vae = vae
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.unet = unet
self.scheduler = scheduler
self.device = device
# Lyra components
self.t5_encoder = t5_encoder
self.t5_tokenizer = t5_tokenizer
self.lyra_model = lyra_model
# Settings
self.clip_skip = clip_skip
        self.vae_scale_factor = 0.13025  # SDXL latent scaling_factor (not the spatial /8 factor)
self.arch = ARCH_SDXL
def encode_prompt(
self,
prompt: str,
negative_prompt: str = "",
clip_skip: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompts using dual CLIP encoders for SDXL."""
# CLIP-L encoding
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
output_hidden_states = clip_skip > 1
clip_l_output = self.text_encoder(
text_input_ids,
output_hidden_states=output_hidden_states
)
prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states)
# CLIP-G encoding
text_inputs_2 = self.tokenizer_2(
prompt,
padding="max_length",
max_length=self.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids_2 = text_inputs_2.input_ids.to(self.device)
with torch.no_grad():
clip_g_output = self.text_encoder_2(
text_input_ids_2,
output_hidden_states=output_hidden_states
)
prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
        # Pooled output comes from CLIP-G's projection head (text_embeds, 1280-dim)
pooled_prompt_embeds = clip_g_output.text_embeds
# Concatenate CLIP-L and CLIP-G embeddings
prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)
# Negative prompt
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
uncond_inputs_2 = self.tokenizer_2(
negative_prompt,
padding="max_length",
max_length=self.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device)
with torch.no_grad():
uncond_output_l = self.text_encoder(
uncond_input_ids,
output_hidden_states=output_hidden_states
)
negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states)
uncond_output_g = self.text_encoder_2(
uncond_input_ids_2,
output_hidden_states=output_hidden_states
)
negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states)
negative_pooled = uncond_output_g.text_embeds
negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled = torch.zeros_like(pooled_prompt_embeds)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled
def encode_prompt_lyra(
self,
prompt: str,
negative_prompt: str = "",
clip_skip: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode prompts using Lyra VAE fusion (CLIP + T5)."""
if self.lyra_model is None or self.t5_encoder is None:
raise ValueError("Lyra VAE components not initialized")
# Get standard CLIP embeddings first
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
prompt, negative_prompt, clip_skip
)
# Get T5 embeddings
t5_inputs = self.t5_tokenizer(
prompt,
max_length=77,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
# For SDXL, we need to handle the concatenated CLIP-L + CLIP-G embeddings
# Split them, fuse CLIP-L through Lyra, then recombine
clip_l_dim = 768
clip_g_dim = 1280
clip_l_embeds = prompt_embeds[..., :clip_l_dim]
clip_g_embeds = prompt_embeds[..., clip_l_dim:]
# Fuse CLIP-L through Lyra
modality_inputs = {
'clip': clip_l_embeds,
't5': t5_embeds
}
with torch.no_grad():
reconstructions, mu, logvar = self.lyra_model(
modality_inputs,
target_modalities=['clip']
)
fused_clip_l = reconstructions['clip']
# Recombine with CLIP-G
prompt_embeds_fused = torch.cat([fused_clip_l, clip_g_embeds], dim=-1)
# Process negative prompt similarly if present
if negative_prompt:
t5_inputs_neg = self.t5_tokenizer(
negative_prompt,
max_length=77,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds_neg = self.t5_encoder(**t5_inputs_neg).last_hidden_state
neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]
modality_inputs_neg = {
'clip': neg_clip_l,
't5': t5_embeds_neg
}
with torch.no_grad():
reconstructions_neg, _, _ = self.lyra_model(
modality_inputs_neg,
target_modalities=['clip']
)
fused_neg_clip_l = reconstructions_neg['clip']
negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, neg_clip_g], dim=-1)
else:
negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)
return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled
def _get_add_time_ids(
self,
original_size: Tuple[int, int],
crops_coords_top_left: Tuple[int, int],
target_size: Tuple[int, int],
dtype: torch.dtype
) -> torch.Tensor:
"""Create time embedding IDs for SDXL."""
add_time_ids = list(original_size + crops_coords_top_left + target_size)
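        # Example: a 1024x1024 image with no cropping yields
        # [1024, 1024, 0, 0, 1024, 1024], i.e. (orig_h, orig_w, crop_top,
        # crop_left, target_h, target_w): the six micro-conditioning values
        # SDXL's UNet consumes via added_cond_kwargs["time_ids"].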
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
return add_time_ids
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 20,
guidance_scale: float = 7.5,
shift: float = 0.0,
use_flow_matching: bool = False,
prediction_type: str = "epsilon",
seed: Optional[int] = None,
use_lyra: bool = False,
clip_skip: int = 1,
progress_callback=None
):
"""Generate image using SDXL architecture."""
# Set seed
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
else:
generator = None
# Encode prompts
if use_lyra and self.lyra_model is not None:
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
prompt, negative_prompt, clip_skip
)
else:
prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
prompt, negative_prompt, clip_skip
)
# Prepare latents
latent_channels = 4
latent_height = height // 8
latent_width = width // 8
latents = torch.randn(
(1, latent_channels, latent_height, latent_width),
generator=generator,
device=self.device,
dtype=torch.float16
)
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
# Scale initial latents
if not use_flow_matching:
latents = latents * self.scheduler.init_noise_sigma
# Prepare added time embeddings for SDXL
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=torch.float16
)
negative_add_time_ids = add_time_ids # Same for negative
# Denoising loop
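        # The loop below supports two update rules:
        #   * the scheduler's standard step() for epsilon prediction, and
        #   * a flow-matching Euler step, x <- x + dt * v with dt = -1/N,
        #     using a shifted sigma, sigma' = shift*sigma / (1 + (shift-1)*sigma),
        #     which biases timesteps toward the noisier end when shift > 1.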
for i, t in enumerate(timesteps):
if progress_callback:
progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
# Expand for CFG
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
# Flow matching scaling
if use_flow_matching and shift > 0:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
scaling = torch.sqrt(1 + sigma_shifted ** 2)
latent_model_input = latent_model_input / scaling
else:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# Prepare timestep
timestep = t.expand(latent_model_input.shape[0])
# Prepare added conditions
if guidance_scale > 1.0:
text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
add_text_embeds = torch.cat([negative_pooled, pooled])
add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids])
else:
text_embeds = prompt_embeds
add_text_embeds = pooled
add_time_ids_input = add_time_ids
# Prepare added cond kwargs for SDXL UNet
added_cond_kwargs = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids_input
}
# Predict noise
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeds,
added_cond_kwargs=added_cond_kwargs,
return_dict=False
)[0]
            # Classifier-free guidance: uncond + scale * (text - uncond)
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Step
if use_flow_matching:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
if prediction_type == "v_prediction":
v_pred = noise_pred
alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
sigma_t = sigma_shifted
noise_pred = alpha_t * v_pred + sigma_t * latents
dt = -1.0 / num_inference_steps
latents = latents + dt * noise_pred
else:
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False
)[0]
# Decode
latents = latents / self.vae_scale_factor
with torch.no_grad():
image = self.vae.decode(latents.to(self.vae.dtype)).sample
# Convert to PIL
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image[0])
return image
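# Usage sketch (assumes a fully constructed SDXLFlowMatchingPipeline `pipe`):
#   image = pipe(
#       "a castle at dusk, dramatic lighting",
#       negative_prompt="blurry, low quality",
#       num_inference_steps=25, guidance_scale=7.0, seed=42,
#   )
#   image.save("out.png")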
# ============================================================================
# SD1.5 PIPELINE (Original)
# ============================================================================
class SD15FlowMatchingPipeline:
"""Pipeline for SD1.5-based flow-matching inference."""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler,
device: str = "cuda",
t5_encoder: Optional[T5EncoderModel] = None,
t5_tokenizer: Optional[T5Tokenizer] = None,
        lyra_model: Optional[Any] = None
):
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.unet = unet
self.scheduler = scheduler
self.device = device
self.t5_encoder = t5_encoder
self.t5_tokenizer = t5_tokenizer
self.lyra_model = lyra_model
        self.vae_scale_factor = 0.18215  # SD1.5 latent scaling_factor
        self.arch = ARCH_SD15
        self.is_lune_model = False  # set by initialize_pipeline for Lune checkpoints
def encode_prompt(self, prompt: str, negative_prompt: str = ""):
"""Encode text prompts to embeddings."""
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
prompt_embeds = self.text_encoder(text_input_ids)[0]
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
with torch.no_grad():
negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0]
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
return prompt_embeds, negative_prompt_embeds
def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
"""Encode using Lyra VAE (CLIP + T5 fusion)."""
if self.lyra_model is None or self.t5_encoder is None:
raise ValueError("Lyra VAE components not initialized")
# CLIP
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.device)
with torch.no_grad():
clip_embeds = self.text_encoder(text_input_ids)[0]
# T5
t5_inputs = self.t5_tokenizer(
prompt,
max_length=77,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state
# Fuse
modality_inputs = {'clip': clip_embeds, 't5': t5_embeds}
with torch.no_grad():
reconstructions, mu, logvar = self.lyra_model(
modality_inputs,
target_modalities=['clip']
)
prompt_embeds = reconstructions['clip']
# Negative
if negative_prompt:
uncond_inputs = self.tokenizer(
negative_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
uncond_input_ids = uncond_inputs.input_ids.to(self.device)
with torch.no_grad():
clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]
t5_inputs_uncond = self.t5_tokenizer(
negative_prompt,
max_length=77,
padding='max_length',
truncation=True,
return_tensors='pt'
).to(self.device)
with torch.no_grad():
t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state
modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond}
with torch.no_grad():
reconstructions_uncond, _, _ = self.lyra_model(
modality_inputs_uncond,
target_modalities=['clip']
)
negative_prompt_embeds = reconstructions_uncond['clip']
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
return prompt_embeds, negative_prompt_embeds
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
height: int = 512,
width: int = 512,
num_inference_steps: int = 20,
guidance_scale: float = 7.5,
shift: float = 2.5,
use_flow_matching: bool = True,
prediction_type: str = "epsilon",
seed: Optional[int] = None,
use_lyra: bool = False,
clip_skip: int = 1, # Unused for SD1.5 but kept for API consistency
progress_callback=None
):
"""Generate image."""
if seed is not None:
generator = torch.Generator(device=self.device).manual_seed(seed)
else:
generator = None
if use_lyra and self.lyra_model is not None:
prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt)
latent_channels = 4
latent_height = height // 8
latent_width = width // 8
latents = torch.randn(
(1, latent_channels, latent_height, latent_width),
generator=generator,
device=self.device,
dtype=torch.float32
)
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps = self.scheduler.timesteps
if not use_flow_matching:
latents = latents * self.scheduler.init_noise_sigma
for i, t in enumerate(timesteps):
if progress_callback:
progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
if use_flow_matching and shift > 0:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
scaling = torch.sqrt(1 + sigma_shifted ** 2)
latent_model_input = latent_model_input / scaling
else:
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds
noise_pred = self.unet(
latent_model_input,
timestep,
encoder_hidden_states=text_embeds,
return_dict=False
)[0]
if guidance_scale > 1.0:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if use_flow_matching:
sigma = t.float() / 1000.0
sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
if prediction_type == "v_prediction":
v_pred = noise_pred
alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
sigma_t = sigma_shifted
noise_pred = alpha_t * v_pred + sigma_t * latents
dt = -1.0 / num_inference_steps
latents = latents + dt * noise_pred
else:
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = latents / self.vae_scale_factor
        if self.is_lune_model:
            # Lune-specific latent rescale; appears to be an empirical
            # output-range calibration for this checkpoint.
            latents = latents * 5.52
with torch.no_grad():
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = (image * 255).round().astype("uint8")
image = Image.fromarray(image[0])
return image
# ============================================================================
# MODEL LOADERS
# ============================================================================
def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
"""Load Lune checkpoint from .pt file."""
print(f"📥 Downloading: {repo_id}/{filename}")
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"🏗️ Initializing SD1.5 UNet...")
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
torch_dtype=torch.float32
)
student_state_dict = checkpoint["student"]
    cleaned_dict = {}
    for key, value in student_state_dict.items():
        # Strip the training-time "unet." prefix if present
        if key.startswith("unet."):
            cleaned_dict[key[len("unet."):]] = value
        else:
            cleaned_dict[key] = value
    missing, unexpected = unet.load_state_dict(cleaned_dict, strict=False)
    print(f"  UNet: {len(missing)} missing, {len(unexpected)} unexpected keys")
step = checkpoint.get("gstep", "unknown")
print(f"✅ Loaded Lune from step {step}")
return unet.to(device)
def load_illustrious_xl(
repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
filename: str = "illustriousXL_v01.safetensors",
device: str = "cuda"
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
"""Load Illustrious XL from single safetensors file."""
print(f"📥 Downloading Illustrious XL: {repo_id}/{filename}")
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
print(f"✓ Downloaded: {checkpoint_path}")
print("📦 Loading safetensors...")
state_dict = load_safetensors(checkpoint_path)
# Extract components
components = extract_comfyui_components(state_dict)
# Load UNet from SDXL base config, then load weights
print("🏗️ Initializing SDXL UNet...")
unet = UNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="unet",
torch_dtype=torch.float16
)
if components["unet"]:
missing, unexpected = unet.load_state_dict(components["unet"], strict=False)
print(f" UNet: {len(missing)} missing, {len(unexpected)} unexpected keys")
# Load VAE
print("🏗️ Initializing SDXL VAE...")
vae = AutoencoderKL.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="vae",
torch_dtype=torch.float16
)
if components["vae"]:
missing, unexpected = vae.load_state_dict(components["vae"], strict=False)
print(f" VAE: {len(missing)} missing, {len(unexpected)} unexpected keys")
# Load CLIP-L
print("🏗️ Loading CLIP-L...")
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14",
torch_dtype=torch.float16
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# Load CLIP-G
print("🏗️ Loading CLIP-G...")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
torch_dtype=torch.float16
)
tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
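    # NOTE: the extracted "clip_l"/"clip_g" weights are left in ComfyUI /
    # open_clip key layout and are not remapped onto these HF encoders;
    # both text encoders above come fresh from their base repos.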
# Move to device
unet = unet.to(device)
vae = vae.to(device)
text_encoder = text_encoder.to(device)
text_encoder_2 = text_encoder_2.to(device)
print("✅ Illustrious XL loaded!")
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
def load_sdxl_base(device: str = "cuda"):
"""Load standard SDXL base model."""
print("📥 Loading SDXL Base 1.0...")
unet = UNet2DConditionModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="unet",
torch_dtype=torch.float16
).to(device)
vae = AutoencoderKL.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="vae",
torch_dtype=torch.float16
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="text_encoder",
torch_dtype=torch.float16
).to(device)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="text_encoder_2",
torch_dtype=torch.float16
).to(device)
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="tokenizer"
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="tokenizer_2"
)
print("✅ SDXL Base loaded!")
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
"""Load Lyra VAE (SD1.5 version) from HuggingFace."""
if not LYRA_AVAILABLE:
print("⚠️ Lyra VAE not available")
return None
print(f"🎵 Loading Lyra VAE from {repo_id}...")
try:
checkpoint_path = hf_hub_download(
repo_id=repo_id,
filename="best_model.pt",
repo_type="model"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if 'config' in checkpoint:
config_dict = checkpoint['config']
else:
config_dict = {
'modality_dims': {"clip": 768, "t5": 768},
'latent_dim': 768,
'seq_len': 77,
'encoder_layers': 3,
'decoder_layers': 3,
'hidden_dim': 1024,
'dropout': 0.1,
'fusion_strategy': 'cantor',
'fusion_heads': 8,
'fusion_dropout': 0.1
}
vae_config = MultiModalVAEConfig(
modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}),
latent_dim=config_dict.get('latent_dim', 768),
seq_len=config_dict.get('seq_len', 77),
encoder_layers=config_dict.get('encoder_layers', 3),
decoder_layers=config_dict.get('decoder_layers', 3),
hidden_dim=config_dict.get('hidden_dim', 1024),
dropout=config_dict.get('dropout', 0.1),
fusion_strategy=config_dict.get('fusion_strategy', 'cantor'),
fusion_heads=config_dict.get('fusion_heads', 8),
fusion_dropout=config_dict.get('fusion_dropout', 0.1)
)
lyra_model = MultiModalVAE(vae_config)
if 'model_state_dict' in checkpoint:
lyra_model.load_state_dict(checkpoint['model_state_dict'])
else:
lyra_model.load_state_dict(checkpoint)
lyra_model.to(device)
lyra_model.eval()
print(f"✅ Lyra VAE (SD1.5) loaded")
return lyra_model
except Exception as e:
print(f"❌ Failed to load Lyra VAE: {e}")
return None
def load_lyra_vae_xl(
repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
device: str = "cuda"
):
"""Load Lyra VAE XL version for SDXL/Illustrious."""
if not LYRA_AVAILABLE:
print("⚠️ Lyra VAE not available")
return None
print(f"🎵 Loading Lyra VAE XL from {repo_id}...")
try:
checkpoint_path = hf_hub_download(
repo_id=repo_id,
filename="best_model.pt",
repo_type="model"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
if 'config' in checkpoint:
config_dict = checkpoint['config']
else:
# XL defaults - note larger dimensions
config_dict = {
'modality_dims': {"clip": 768, "t5": 2048}, # T5-XL
'latent_dim': 2048,
'seq_len': 77,
'encoder_layers': 4,
'decoder_layers': 4,
'hidden_dim': 2048,
'dropout': 0.1,
'fusion_strategy': 'adaptive_cantor',
'fusion_heads': 16,
'fusion_dropout': 0.1
}
vae_config = MultiModalVAEConfig(
modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 2048}),
latent_dim=config_dict.get('latent_dim', 2048),
seq_len=config_dict.get('seq_len', 77),
encoder_layers=config_dict.get('encoder_layers', 4),
decoder_layers=config_dict.get('decoder_layers', 4),
hidden_dim=config_dict.get('hidden_dim', 2048),
dropout=config_dict.get('dropout', 0.1),
fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
fusion_heads=config_dict.get('fusion_heads', 16),
fusion_dropout=config_dict.get('fusion_dropout', 0.1)
)
lyra_model = MultiModalVAE(vae_config)
if 'model_state_dict' in checkpoint:
lyra_model.load_state_dict(checkpoint['model_state_dict'])
else:
lyra_model.load_state_dict(checkpoint)
lyra_model.to(device)
lyra_model.eval()
print(f"✅ Lyra VAE XL loaded")
if 'global_step' in checkpoint:
print(f" Step: {checkpoint['global_step']:,}")
return lyra_model
except Exception as e:
print(f"❌ Failed to load Lyra VAE XL: {e}")
return None
# ============================================================================
# PIPELINE INITIALIZATION
# ============================================================================
def initialize_pipeline(model_choice: str, device: str = "cuda"):
"""Initialize the complete pipeline based on model choice."""
print(f"🚀 Initializing {model_choice} pipeline...")
# Determine architecture
is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
is_lune = "Lune" in model_choice
if is_sdxl:
# SDXL-based models
if "Illustrious" in model_choice:
unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
else:
unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)
# T5-XL for Lyra
print("Loading T5-XL encoder...")
t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xl")
t5_encoder = T5EncoderModel.from_pretrained(
"google/t5-v1_1-xl",
torch_dtype=torch.float16
).to(device)
t5_encoder.eval()
print("✓ T5-XL loaded")
# Lyra XL
lyra_model = load_lyra_vae_xl(device=device)
# Scheduler (epsilon for SDXL)
scheduler = EulerDiscreteScheduler.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
subfolder="scheduler"
)
pipeline = SDXLFlowMatchingPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
device=device,
t5_encoder=t5_encoder,
t5_tokenizer=t5_tokenizer,
lyra_model=lyra_model,
clip_skip=1
)
else:
# SD1.5-based models
vae = AutoencoderKL.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="vae",
torch_dtype=torch.float32
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14",
torch_dtype=torch.float32
).to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# T5-base for SD1.5 Lyra
print("Loading T5-base encoder...")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
t5_encoder = T5EncoderModel.from_pretrained(
"t5-base",
torch_dtype=torch.float32
).to(device)
t5_encoder.eval()
print("✓ T5-base loaded")
# Lyra (SD1.5 version)
lyra_model = load_lyra_vae(device=device)
# Load UNet
if is_lune:
repo_id = "AbstractPhil/sd15-flow-lune"
filename = "sd15_flow_lune_e34_s34000.pt"
unet = load_lune_checkpoint(repo_id, filename, device)
else:
unet = UNet2DConditionModel.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="unet",
torch_dtype=torch.float32
).to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="scheduler"
)
pipeline = SD15FlowMatchingPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
device=device,
t5_encoder=t5_encoder,
t5_tokenizer=t5_tokenizer,
lyra_model=lyra_model
)
pipeline.is_lune_model = is_lune
print("✅ Pipeline initialized!")
return pipeline
# ============================================================================
# GLOBAL STATE
# ============================================================================
CURRENT_PIPELINE = None
CURRENT_MODEL = None
def get_pipeline(model_choice: str):
    """Get or create the pipeline for the selected model."""
    global CURRENT_PIPELINE, CURRENT_MODEL
    if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
        # Drop the previous pipeline first so two full model stacks
        # never coexist in GPU memory when switching models.
        if CURRENT_PIPELINE is not None:
            CURRENT_PIPELINE = None
            torch.cuda.empty_cache()
        CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
        CURRENT_MODEL = model_choice
    return CURRENT_PIPELINE
# ============================================================================
# INFERENCE
# ============================================================================
def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int:
"""Estimate GPU duration."""
base_time_per_step = 0.5 if is_sdxl else 0.3
resolution_factor = (width * height) / (512 * 512)
estimated = num_steps * base_time_per_step * resolution_factor
if use_lyra:
estimated *= 2
estimated += 3
return int(estimated + 20)
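# Worked example: SDXL at 1024x1024 for 25 steps without Lyra gives
# resolution_factor = 4 and estimated = 25 * 0.5 * 4 = 50, so the
# function returns 50 + 20 = 70 seconds.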
@spaces.GPU(duration=lambda *args: estimate_duration(
args[4], args[6], args[7], args[10],
"SDXL" in args[2] or "Illustrious" in args[2]
))
def generate_image(
prompt: str,
negative_prompt: str,
model_choice: str,
clip_skip: int,
num_steps: int,
cfg_scale: float,
width: int,
height: int,
shift: float,
use_flow_matching: bool,
use_lyra: bool,
seed: int,
randomize_seed: bool,
progress=gr.Progress()
):
"""Generate image with ZeroGPU support."""
if randomize_seed:
seed = np.random.randint(0, 2**32 - 1)
def progress_callback(step, total, desc):
progress((step + 1) / total, desc=desc)
try:
pipeline = get_pipeline(model_choice)
# Determine prediction type based on model
is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
prediction_type = "epsilon" # SDXL always uses epsilon
if not is_sdxl and "Lune" in model_choice:
prediction_type = "v_prediction"
if not use_lyra or pipeline.lyra_model is None:
progress(0.05, desc="Generating...")
image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
shift=shift,
use_flow_matching=use_flow_matching,
prediction_type=prediction_type,
seed=seed,
use_lyra=False,
clip_skip=clip_skip,
progress_callback=progress_callback
)
progress(1.0, desc="Complete!")
return image, None, seed
else:
progress(0.05, desc="Generating standard...")
image_standard = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
shift=shift,
use_flow_matching=use_flow_matching,
prediction_type=prediction_type,
seed=seed,
use_lyra=False,
clip_skip=clip_skip,
progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d)
)
progress(0.5, desc="Generating Lyra fusion...")
image_lyra = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_steps,
guidance_scale=cfg_scale,
shift=shift,
use_flow_matching=use_flow_matching,
prediction_type=prediction_type,
seed=seed,
use_lyra=True,
clip_skip=clip_skip,
progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d)
)
progress(1.0, desc="Complete!")
return image_standard, image_lyra, seed
    except Exception as e:
        print(f"❌ Generation failed: {e}")
        raise
# ============================================================================
# GRADIO UI
# ============================================================================
def create_demo():
"""Create Gradio interface."""
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🌙 Lyra/Lune Flow-Matching Image Generation
**Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)
Generate images using SD1.5 and SDXL-based models with geometric deep learning:
| Model | Architecture | Best For |
|-------|-------------|----------|
| **Illustrious XL** | SDXL | Anime/illustration, high detail |
| **SDXL Base** | SDXL | Photorealistic, general purpose |
| **Flow-Lune** | SD1.5 | Fast flow matching (15-25 steps) |
| **SD1.5 Base** | SD1.5 | Baseline comparison |
Enable **Lyra VAE** for CLIP+T5 fusion comparison!
""")
with gr.Row():
with gr.Column(scale=1):
prompt = gr.TextArea(
label="Prompt",
value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
lines=3
)
negative_prompt = gr.TextArea(
label="Negative Prompt",
value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality",
lines=2
)
model_choice = gr.Dropdown(
label="Model",
choices=[
"Illustrious XL",
"SDXL Base",
"Flow-Lune (SD1.5)",
"SD1.5 Base"
],
value="Illustrious XL"
)
clip_skip = gr.Slider(
label="CLIP Skip",
minimum=1,
maximum=4,
value=2,
step=1,
info="2 recommended for Illustrious, 1 for others"
)
use_lyra = gr.Checkbox(
label="Enable Lyra VAE (CLIP+T5 Fusion)",
value=False,
info="Compare standard vs geometric fusion"
)
with gr.Accordion("Generation Settings", open=True):
num_steps = gr.Slider(
label="Steps",
minimum=1,
maximum=50,
value=25,
step=1
)
cfg_scale = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=20.0,
value=7.0,
step=0.5
)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=1536,
value=1024,
step=64
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=1536,
value=1024,
step=64
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=2**32 - 1,
value=42,
step=1
)
randomize_seed = gr.Checkbox(
label="Randomize Seed",
value=True
)
with gr.Accordion("Advanced (Flow Matching)", open=False):
use_flow_matching = gr.Checkbox(
label="Enable Flow Matching",
value=False,
info="Use flow matching ODE (for Lune only)"
)
shift = gr.Slider(
label="Shift",
minimum=0.0,
maximum=5.0,
value=0.0,
step=0.1,
info="Flow matching shift (0=disabled)"
)
generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
with gr.Column(scale=1):
with gr.Row():
output_image_standard = gr.Image(
label="Generated Image",
type="pil"
)
output_image_lyra = gr.Image(
label="Lyra Fusion 🎵",
type="pil",
visible=False
)
output_seed = gr.Number(label="Seed", precision=0)
gr.Markdown("""
### Tips
- **Illustrious XL**: Use CLIP skip 2, booru-style tags
- **SDXL Base**: Natural language prompts work well
- **Flow-Lune**: Enable flow matching, shift ~2.5, fewer steps
- **Lyra**: Generates both standard and fused for comparison
### Model Info
- SDXL models use **epsilon** prediction
- Lune uses **v_prediction** with flow matching
- Lyra fuses CLIP + T5 for richer semantics
""")
# Examples
gr.Examples(
examples=[
[
"masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
"lowres, bad anatomy, worst quality, low quality",
"Illustrious XL",
2, 25, 7.0, 1024, 1024, 0.0, False, False, 42, False
],
[
"A majestic mountain landscape at golden hour, crystal clear lake, photorealistic, 8k",
"blurry, low quality",
"SDXL Base",
1, 30, 7.5, 1024, 1024, 0.0, False, False, 123, False
],
[
"cyberpunk city at night, neon lights, rain, highly detailed",
"low quality, blurry",
"Flow-Lune (SD1.5)",
1, 20, 7.5, 512, 512, 2.5, True, False, 456, False
],
],
inputs=[
prompt, negative_prompt, model_choice, clip_skip,
num_steps, cfg_scale, width, height, shift,
use_flow_matching, use_lyra, seed, randomize_seed
],
outputs=[output_image_standard, output_image_lyra, output_seed],
fn=generate_image,
cache_examples=False
)
# Event handlers
def on_model_change(model_name):
"""Update defaults based on model."""
if "Illustrious" in model_name:
return {
clip_skip: gr.update(value=2),
width: gr.update(value=1024),
height: gr.update(value=1024),
num_steps: gr.update(value=25),
use_flow_matching: gr.update(value=False),
shift: gr.update(value=0.0)
}
elif "SDXL" in model_name:
return {
clip_skip: gr.update(value=1),
width: gr.update(value=1024),
height: gr.update(value=1024),
num_steps: gr.update(value=30),
use_flow_matching: gr.update(value=False),
shift: gr.update(value=0.0)
}
elif "Lune" in model_name:
return {
clip_skip: gr.update(value=1),
width: gr.update(value=512),
height: gr.update(value=512),
num_steps: gr.update(value=20),
use_flow_matching: gr.update(value=True),
shift: gr.update(value=2.5)
}
else: # SD1.5 Base
return {
clip_skip: gr.update(value=1),
width: gr.update(value=512),
height: gr.update(value=512),
num_steps: gr.update(value=30),
use_flow_matching: gr.update(value=False),
shift: gr.update(value=0.0)
}
def on_lyra_toggle(enabled):
"""Show/hide Lyra comparison."""
if enabled:
return {
output_image_standard: gr.update(visible=True, label="Standard"),
output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵")
}
else:
return {
output_image_standard: gr.update(visible=True, label="Generated Image"),
output_image_lyra: gr.update(visible=False)
}
model_choice.change(
fn=on_model_change,
inputs=[model_choice],
outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift]
)
use_lyra.change(
fn=on_lyra_toggle,
inputs=[use_lyra],
outputs=[output_image_standard, output_image_lyra]
)
generate_btn.click(
fn=generate_image,
inputs=[
prompt, negative_prompt, model_choice, clip_skip,
num_steps, cfg_scale, width, height, shift,
use_flow_matching, use_lyra, seed, randomize_seed
],
outputs=[output_image_standard, output_image_lyra, output_seed]
)
return demo
# ============================================================================
# LAUNCH
# ============================================================================
if __name__ == "__main__":
demo = create_demo()
demo.queue(max_size=20)
demo.launch(show_api=False)