""" Lyra/Lune Flow-Matching Inference Space Author: AbstractPhil License: MIT SD1.5 and SDXL-based flow matching with geometric crystalline architectures. Supports Illustrious XL, standard SDXL, and SD1.5 variants. """ import os import torch import gradio as gr import numpy as np from PIL import Image from typing import Optional, Dict, Tuple import spaces from safetensors.torch import load_file as load_safetensors from diffusers import ( UNet2DConditionModel, AutoencoderKL, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler ) from diffusers.models import UNet2DConditionModel as DiffusersUNet from transformers import ( CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5EncoderModel, T5Tokenizer ) from huggingface_hub import hf_hub_download from geofractal.models.vae.vae_lyra_v2 import MultiModalVAE, MultiModalVAEConfig LYRA_AVAILABLE = True # ============================================================================ # CONSTANTS # ============================================================================ # Model architectures ARCH_SD15 = "sd15" ARCH_SDXL = "sdxl" # ComfyUI key prefixes for SDXL single-file checkpoints COMFYUI_UNET_PREFIX = "model.diffusion_model." COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer." COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model." COMFYUI_VAE_PREFIX = "first_stage_model." # ============================================================================ # MODEL LOADING UTILITIES # ============================================================================ def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: """Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint.""" components = { "unet": {}, "clip_l": {}, "clip_g": {}, "vae": {} } for key, value in state_dict.items(): if key.startswith(COMFYUI_UNET_PREFIX): new_key = key[len(COMFYUI_UNET_PREFIX):] components["unet"][new_key] = value elif key.startswith(COMFYUI_CLIP_L_PREFIX): new_key = key[len(COMFYUI_CLIP_L_PREFIX):] components["clip_l"][new_key] = value elif key.startswith(COMFYUI_CLIP_G_PREFIX): new_key = key[len(COMFYUI_CLIP_G_PREFIX):] components["clip_g"][new_key] = value elif key.startswith(COMFYUI_VAE_PREFIX): new_key = key[len(COMFYUI_VAE_PREFIX):] components["vae"][new_key] = value print(f" Extracted components:") print(f" UNet: {len(components['unet'])} keys") print(f" CLIP-L: {len(components['clip_l'])} keys") print(f" CLIP-G: {len(components['clip_g'])} keys") print(f" VAE: {len(components['vae'])} keys") return components def get_clip_hidden_state( model_output, clip_skip: int = 1, output_hidden_states: bool = True ) -> torch.Tensor: """Extract hidden state with clip_skip support.""" if clip_skip == 1 or not output_hidden_states: return model_output.last_hidden_state if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None: # hidden_states is tuple: (embedding, layer1, ..., layerN) # clip_skip=2 means penultimate layer = hidden_states[-2] return model_output.hidden_states[-clip_skip] return model_output.last_hidden_state # ============================================================================ # SDXL PIPELINE # ============================================================================ class SDXLFlowMatchingPipeline: """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders.""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, # CLIP-L text_encoder_2: CLIPTextModelWithProjection, # CLIP-G tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler, device: str = "cuda", t5_encoder: Optional[T5EncoderModel] = None, t5_tokenizer: Optional[T5Tokenizer] = None, lyra_model: Optional[any] = None, clip_skip: int = 1 ): self.vae = vae self.text_encoder = text_encoder self.text_encoder_2 = text_encoder_2 self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.unet = unet self.scheduler = scheduler self.device = device # Lyra components self.t5_encoder = t5_encoder self.t5_tokenizer = t5_tokenizer self.lyra_model = lyra_model # Settings self.clip_skip = clip_skip self.vae_scale_factor = 0.13025 # SDXL VAE scaling self.arch = ARCH_SDXL def encode_prompt( self, prompt: str, negative_prompt: str = "", clip_skip: int = 1 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Encode prompts using dual CLIP encoders for SDXL.""" # CLIP-L encoding text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.device) with torch.no_grad(): output_hidden_states = clip_skip > 1 clip_l_output = self.text_encoder( text_input_ids, output_hidden_states=output_hidden_states ) prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states) # CLIP-G encoding text_inputs_2 = self.tokenizer_2( prompt, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids_2 = text_inputs_2.input_ids.to(self.device) with torch.no_grad(): clip_g_output = self.text_encoder_2( text_input_ids_2, output_hidden_states=output_hidden_states ) prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states) # Get pooled output from CLIP-G pooled_prompt_embeds = clip_g_output.text_embeds # Concatenate CLIP-L and CLIP-G embeddings prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1) # Negative prompt if negative_prompt: uncond_inputs = self.tokenizer( negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_inputs.input_ids.to(self.device) uncond_inputs_2 = self.tokenizer_2( negative_prompt, padding="max_length", max_length=self.tokenizer_2.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device) with torch.no_grad(): uncond_output_l = self.text_encoder( uncond_input_ids, output_hidden_states=output_hidden_states ) negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states) uncond_output_g = self.text_encoder_2( uncond_input_ids_2, output_hidden_states=output_hidden_states ) negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states) negative_pooled = uncond_output_g.text_embeds negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1) else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled = torch.zeros_like(pooled_prompt_embeds) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled def encode_prompt_lyra( self, prompt: str, negative_prompt: str = "", clip_skip: int = 1 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Encode prompts using Lyra VAE fusion (CLIP + T5).""" if self.lyra_model is None or self.t5_encoder is None: raise ValueError("Lyra VAE components not initialized") # Get standard CLIP embeddings first prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt( prompt, negative_prompt, clip_skip ) # Get T5 embeddings t5_inputs = self.t5_tokenizer( prompt, max_length=77, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state # For SDXL, we need to handle the concatenated CLIP-L + CLIP-G embeddings # Split them, fuse CLIP-L through Lyra, then recombine clip_l_dim = 768 clip_g_dim = 1280 clip_l_embeds = prompt_embeds[..., :clip_l_dim] clip_g_embeds = prompt_embeds[..., clip_l_dim:] # Fuse CLIP-L through Lyra modality_inputs = { 'clip': clip_l_embeds, 't5': t5_embeds } with torch.no_grad(): reconstructions, mu, logvar = self.lyra_model( modality_inputs, target_modalities=['clip'] ) fused_clip_l = reconstructions['clip'] # Recombine with CLIP-G prompt_embeds_fused = torch.cat([fused_clip_l, clip_g_embeds], dim=-1) # Process negative prompt similarly if present if negative_prompt: t5_inputs_neg = self.t5_tokenizer( negative_prompt, max_length=77, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds_neg = self.t5_encoder(**t5_inputs_neg).last_hidden_state neg_clip_l = negative_prompt_embeds[..., :clip_l_dim] neg_clip_g = negative_prompt_embeds[..., clip_l_dim:] modality_inputs_neg = { 'clip': neg_clip_l, 't5': t5_embeds_neg } with torch.no_grad(): reconstructions_neg, _, _ = self.lyra_model( modality_inputs_neg, target_modalities=['clip'] ) fused_neg_clip_l = reconstructions_neg['clip'] negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, neg_clip_g], dim=-1) else: negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused) return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled def _get_add_time_ids( self, original_size: Tuple[int, int], crops_coords_top_left: Tuple[int, int], target_size: Tuple[int, int], dtype: torch.dtype ) -> torch.Tensor: """Create time embedding IDs for SDXL.""" add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device) return add_time_ids @torch.no_grad() def __call__( self, prompt: str, negative_prompt: str = "", height: int = 1024, width: int = 1024, num_inference_steps: int = 20, guidance_scale: float = 7.5, shift: float = 0.0, use_flow_matching: bool = False, prediction_type: str = "epsilon", seed: Optional[int] = None, use_lyra: bool = False, clip_skip: int = 1, progress_callback=None ): """Generate image using SDXL architecture.""" # Set seed if seed is not None: generator = torch.Generator(device=self.device).manual_seed(seed) else: generator = None # Encode prompts if use_lyra and self.lyra_model is not None: prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra( prompt, negative_prompt, clip_skip ) else: prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt( prompt, negative_prompt, clip_skip ) # Prepare latents latent_channels = 4 latent_height = height // 8 latent_width = width // 8 latents = torch.randn( (1, latent_channels, latent_height, latent_width), generator=generator, device=self.device, dtype=torch.float16 ) # Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.scheduler.timesteps # Scale initial latents if not use_flow_matching: latents = latents * self.scheduler.init_noise_sigma # Prepare added time embeddings for SDXL original_size = (height, width) target_size = (height, width) crops_coords_top_left = (0, 0) add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=torch.float16 ) negative_add_time_ids = add_time_ids # Same for negative # Denoising loop for i, t in enumerate(timesteps): if progress_callback: progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") # Expand for CFG latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents # Flow matching scaling if use_flow_matching and shift > 0: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) scaling = torch.sqrt(1 + sigma_shifted ** 2) latent_model_input = latent_model_input / scaling else: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # Prepare timestep timestep = t.expand(latent_model_input.shape[0]) # Prepare added conditions if guidance_scale > 1.0: text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) add_text_embeds = torch.cat([negative_pooled, pooled]) add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids]) else: text_embeds = prompt_embeds add_text_embeds = pooled add_time_ids_input = add_time_ids # Prepare added cond kwargs for SDXL UNet added_cond_kwargs = { "text_embeds": add_text_embeds, "time_ids": add_time_ids_input } # Predict noise noise_pred = self.unet( latent_model_input, timestep, encoder_hidden_states=text_embeds, added_cond_kwargs=added_cond_kwargs, return_dict=False )[0] # CFG if guidance_scale > 1.0: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # Step if use_flow_matching: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) if prediction_type == "v_prediction": v_pred = noise_pred alpha_t = torch.sqrt(1 - sigma_shifted ** 2) sigma_t = sigma_shifted noise_pred = alpha_t * v_pred + sigma_t * latents dt = -1.0 / num_inference_steps latents = latents + dt * noise_pred else: latents = self.scheduler.step( noise_pred, t, latents, return_dict=False )[0] # Decode latents = latents / self.vae_scale_factor with torch.no_grad(): image = self.vae.decode(latents.to(self.vae.dtype)).sample # Convert to PIL image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = (image * 255).round().astype("uint8") image = Image.fromarray(image[0]) return image # ============================================================================ # SD1.5 PIPELINE (Original) # ============================================================================ class SD15FlowMatchingPipeline: """Pipeline for SD1.5-based flow-matching inference.""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler, device: str = "cuda", t5_encoder: Optional[T5EncoderModel] = None, t5_tokenizer: Optional[T5Tokenizer] = None, lyra_model: Optional[any] = None ): self.vae = vae self.text_encoder = text_encoder self.tokenizer = tokenizer self.unet = unet self.scheduler = scheduler self.device = device self.t5_encoder = t5_encoder self.t5_tokenizer = t5_tokenizer self.lyra_model = lyra_model self.vae_scale_factor = 0.18215 self.arch = ARCH_SD15 self.is_lune_model = False def encode_prompt(self, prompt: str, negative_prompt: str = ""): """Encode text prompts to embeddings.""" text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.device) with torch.no_grad(): prompt_embeds = self.text_encoder(text_input_ids)[0] if negative_prompt: uncond_inputs = self.tokenizer( negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_inputs.input_ids.to(self.device) with torch.no_grad(): negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0] else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) return prompt_embeds, negative_prompt_embeds def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""): """Encode using Lyra VAE (CLIP + T5 fusion).""" if self.lyra_model is None or self.t5_encoder is None: raise ValueError("Lyra VAE components not initialized") # CLIP text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.device) with torch.no_grad(): clip_embeds = self.text_encoder(text_input_ids)[0] # T5 t5_inputs = self.t5_tokenizer( prompt, max_length=77, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state # Fuse modality_inputs = {'clip': clip_embeds, 't5': t5_embeds} with torch.no_grad(): reconstructions, mu, logvar = self.lyra_model( modality_inputs, target_modalities=['clip'] ) prompt_embeds = reconstructions['clip'] # Negative if negative_prompt: uncond_inputs = self.tokenizer( negative_prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_inputs.input_ids.to(self.device) with torch.no_grad(): clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0] t5_inputs_uncond = self.t5_tokenizer( negative_prompt, max_length=77, padding='max_length', truncation=True, return_tensors='pt' ).to(self.device) with torch.no_grad(): t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond} with torch.no_grad(): reconstructions_uncond, _, _ = self.lyra_model( modality_inputs_uncond, target_modalities=['clip'] ) negative_prompt_embeds = reconstructions_uncond['clip'] else: negative_prompt_embeds = torch.zeros_like(prompt_embeds) return prompt_embeds, negative_prompt_embeds @torch.no_grad() def __call__( self, prompt: str, negative_prompt: str = "", height: int = 512, width: int = 512, num_inference_steps: int = 20, guidance_scale: float = 7.5, shift: float = 2.5, use_flow_matching: bool = True, prediction_type: str = "epsilon", seed: Optional[int] = None, use_lyra: bool = False, clip_skip: int = 1, # Unused for SD1.5 but kept for API consistency progress_callback=None ): """Generate image.""" if seed is not None: generator = torch.Generator(device=self.device).manual_seed(seed) else: generator = None if use_lyra and self.lyra_model is not None: prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt) else: prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt) latent_channels = 4 latent_height = height // 8 latent_width = width // 8 latents = torch.randn( (1, latent_channels, latent_height, latent_width), generator=generator, device=self.device, dtype=torch.float32 ) self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.scheduler.timesteps if not use_flow_matching: latents = latents * self.scheduler.init_noise_sigma for i, t in enumerate(timesteps): if progress_callback: progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}") latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents if use_flow_matching and shift > 0: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) scaling = torch.sqrt(1 + sigma_shifted ** 2) latent_model_input = latent_model_input / scaling else: latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) timestep = t.expand(latent_model_input.shape[0]) text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds noise_pred = self.unet( latent_model_input, timestep, encoder_hidden_states=text_embeds, return_dict=False )[0] if guidance_scale > 1.0: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if use_flow_matching: sigma = t.float() / 1000.0 sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma) if prediction_type == "v_prediction": v_pred = noise_pred alpha_t = torch.sqrt(1 - sigma_shifted ** 2) sigma_t = sigma_shifted noise_pred = alpha_t * v_pred + sigma_t * latents dt = -1.0 / num_inference_steps latents = latents + dt * noise_pred else: latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] latents = latents / self.vae_scale_factor if self.is_lune_model: latents = latents * 5.52 with torch.no_grad(): image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = (image * 255).round().astype("uint8") image = Image.fromarray(image[0]) return image # ============================================================================ # MODEL LOADERS # ============================================================================ def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"): """Load Lune checkpoint from .pt file.""" print(f"📥 Downloading: {repo_id}/{filename}") checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") checkpoint = torch.load(checkpoint_path, map_location="cpu") print(f"🏗️ Initializing SD1.5 UNet...") unet = UNet2DConditionModel.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32 ) student_state_dict = checkpoint["student"] cleaned_dict = {} for key, value in student_state_dict.items(): if key.startswith("unet."): cleaned_dict[key[5:]] = value else: cleaned_dict[key] = value unet.load_state_dict(cleaned_dict, strict=False) step = checkpoint.get("gstep", "unknown") print(f"✅ Loaded Lune from step {step}") return unet.to(device) def load_illustrious_xl( repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious", filename: str = "illustriousXL_v01.safetensors", device: str = "cuda" ) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]: """Load Illustrious XL from single safetensors file.""" print(f"📥 Downloading Illustrious XL: {repo_id}/{filename}") checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model") print(f"✓ Downloaded: {checkpoint_path}") print("📦 Loading safetensors...") state_dict = load_safetensors(checkpoint_path) # Extract components components = extract_comfyui_components(state_dict) # Load UNet from SDXL base config, then load weights print("🏗️ Initializing SDXL UNet...") unet = UNet2DConditionModel.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16 ) if components["unet"]: missing, unexpected = unet.load_state_dict(components["unet"], strict=False) print(f" UNet: {len(missing)} missing, {len(unexpected)} unexpected keys") # Load VAE print("🏗️ Initializing SDXL VAE...") vae = AutoencoderKL.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", torch_dtype=torch.float16 ) if components["vae"]: missing, unexpected = vae.load_state_dict(components["vae"], strict=False) print(f" VAE: {len(missing)} missing, {len(unexpected)} unexpected keys") # Load CLIP-L print("🏗️ Loading CLIP-L...") text_encoder = CLIPTextModel.from_pretrained( "openai/clip-vit-large-patch14", torch_dtype=torch.float16 ) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") # Load CLIP-G print("🏗️ Loading CLIP-G...") text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=torch.float16 ) tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") # Move to device unet = unet.to(device) vae = vae.to(device) text_encoder = text_encoder.to(device) text_encoder_2 = text_encoder_2.to(device) print("✅ Illustrious XL loaded!") return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 def load_sdxl_base(device: str = "cuda"): """Load standard SDXL base model.""" print("📥 Loading SDXL Base 1.0...") unet = UNet2DConditionModel.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16 ).to(device) vae = AutoencoderKL.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", torch_dtype=torch.float16 ).to(device) text_encoder = CLIPTextModel.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", torch_dtype=torch.float16 ).to(device) text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2", torch_dtype=torch.float16 ).to(device) tokenizer = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer" ) tokenizer_2 = CLIPTokenizer.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2" ) print("✅ SDXL Base loaded!") return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"): """Load Lyra VAE (SD1.5 version) from HuggingFace.""" if not LYRA_AVAILABLE: print("⚠️ Lyra VAE not available") return None print(f"🎵 Loading Lyra VAE from {repo_id}...") try: checkpoint_path = hf_hub_download( repo_id=repo_id, filename="best_model.pt", repo_type="model" ) checkpoint = torch.load(checkpoint_path, map_location="cpu") if 'config' in checkpoint: config_dict = checkpoint['config'] else: config_dict = { 'modality_dims': {"clip": 768, "t5": 768}, 'latent_dim': 768, 'seq_len': 77, 'encoder_layers': 3, 'decoder_layers': 3, 'hidden_dim': 1024, 'dropout': 0.1, 'fusion_strategy': 'cantor', 'fusion_heads': 8, 'fusion_dropout': 0.1 } vae_config = MultiModalVAEConfig( modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}), latent_dim=config_dict.get('latent_dim', 768), seq_len=config_dict.get('seq_len', 77), encoder_layers=config_dict.get('encoder_layers', 3), decoder_layers=config_dict.get('decoder_layers', 3), hidden_dim=config_dict.get('hidden_dim', 1024), dropout=config_dict.get('dropout', 0.1), fusion_strategy=config_dict.get('fusion_strategy', 'cantor'), fusion_heads=config_dict.get('fusion_heads', 8), fusion_dropout=config_dict.get('fusion_dropout', 0.1) ) lyra_model = MultiModalVAE(vae_config) if 'model_state_dict' in checkpoint: lyra_model.load_state_dict(checkpoint['model_state_dict']) else: lyra_model.load_state_dict(checkpoint) lyra_model.to(device) lyra_model.eval() print(f"✅ Lyra VAE (SD1.5) loaded") return lyra_model except Exception as e: print(f"❌ Failed to load Lyra VAE: {e}") return None def load_lyra_vae_xl( repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious", device: str = "cuda" ): """Load Lyra VAE XL version for SDXL/Illustrious.""" if not LYRA_AVAILABLE: print("⚠️ Lyra VAE not available") return None print(f"🎵 Loading Lyra VAE XL from {repo_id}...") try: checkpoint_path = hf_hub_download( repo_id=repo_id, filename="best_model.pt", repo_type="model" ) checkpoint = torch.load(checkpoint_path, map_location="cpu") if 'config' in checkpoint: config_dict = checkpoint['config'] else: # XL defaults - note larger dimensions config_dict = { 'modality_dims': {"clip": 768, "t5": 2048}, # T5-XL 'latent_dim': 2048, 'seq_len': 77, 'encoder_layers': 4, 'decoder_layers': 4, 'hidden_dim': 2048, 'dropout': 0.1, 'fusion_strategy': 'adaptive_cantor', 'fusion_heads': 16, 'fusion_dropout': 0.1 } vae_config = MultiModalVAEConfig( modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 2048}), latent_dim=config_dict.get('latent_dim', 2048), seq_len=config_dict.get('seq_len', 77), encoder_layers=config_dict.get('encoder_layers', 4), decoder_layers=config_dict.get('decoder_layers', 4), hidden_dim=config_dict.get('hidden_dim', 2048), dropout=config_dict.get('dropout', 0.1), fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'), fusion_heads=config_dict.get('fusion_heads', 16), fusion_dropout=config_dict.get('fusion_dropout', 0.1) ) lyra_model = MultiModalVAE(vae_config) if 'model_state_dict' in checkpoint: lyra_model.load_state_dict(checkpoint['model_state_dict']) else: lyra_model.load_state_dict(checkpoint) lyra_model.to(device) lyra_model.eval() print(f"✅ Lyra VAE XL loaded") if 'global_step' in checkpoint: print(f" Step: {checkpoint['global_step']:,}") return lyra_model except Exception as e: print(f"❌ Failed to load Lyra VAE XL: {e}") return None # ============================================================================ # PIPELINE INITIALIZATION # ============================================================================ def initialize_pipeline(model_choice: str, device: str = "cuda"): """Initialize the complete pipeline based on model choice.""" print(f"🚀 Initializing {model_choice} pipeline...") # Determine architecture is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice is_lune = "Lune" in model_choice if is_sdxl: # SDXL-based models if "Illustrious" in model_choice: unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device) else: unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device) # T5-XL for Lyra print("Loading T5-XL encoder...") t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xl") t5_encoder = T5EncoderModel.from_pretrained( "google/t5-v1_1-xl", torch_dtype=torch.float16 ).to(device) t5_encoder.eval() print("✓ T5-XL loaded") # Lyra XL lyra_model = load_lyra_vae_xl(device=device) # Scheduler (epsilon for SDXL) scheduler = EulerDiscreteScheduler.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler" ) pipeline = SDXLFlowMatchingPipeline( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, device=device, t5_encoder=t5_encoder, t5_tokenizer=t5_tokenizer, lyra_model=lyra_model, clip_skip=1 ) else: # SD1.5-based models vae = AutoencoderKL.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.float32 ).to(device) text_encoder = CLIPTextModel.from_pretrained( "openai/clip-vit-large-patch14", torch_dtype=torch.float32 ).to(device) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") # T5-base for SD1.5 Lyra print("Loading T5-base encoder...") t5_tokenizer = T5Tokenizer.from_pretrained("t5-base") t5_encoder = T5EncoderModel.from_pretrained( "t5-base", torch_dtype=torch.float32 ).to(device) t5_encoder.eval() print("✓ T5-base loaded") # Lyra (SD1.5 version) lyra_model = load_lyra_vae(device=device) # Load UNet if is_lune: repo_id = "AbstractPhil/sd15-flow-lune" filename = "sd15_flow_lune_e34_s34000.pt" unet = load_lune_checkpoint(repo_id, filename, device) else: unet = UNet2DConditionModel.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32 ).to(device) scheduler = EulerDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler" ) pipeline = SD15FlowMatchingPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, device=device, t5_encoder=t5_encoder, t5_tokenizer=t5_tokenizer, lyra_model=lyra_model ) pipeline.is_lune_model = is_lune print("✅ Pipeline initialized!") return pipeline # ============================================================================ # GLOBAL STATE # ============================================================================ CURRENT_PIPELINE = None CURRENT_MODEL = None def get_pipeline(model_choice: str): """Get or create pipeline for selected model.""" global CURRENT_PIPELINE, CURRENT_MODEL if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice: CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda") CURRENT_MODEL = model_choice return CURRENT_PIPELINE # ============================================================================ # INFERENCE # ============================================================================ def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int: """Estimate GPU duration.""" base_time_per_step = 0.5 if is_sdxl else 0.3 resolution_factor = (width * height) / (512 * 512) estimated = num_steps * base_time_per_step * resolution_factor if use_lyra: estimated *= 2 estimated += 3 return int(estimated + 20) @spaces.GPU(duration=lambda *args: estimate_duration( args[4], args[6], args[7], args[10], "SDXL" in args[2] or "Illustrious" in args[2] )) def generate_image( prompt: str, negative_prompt: str, model_choice: str, clip_skip: int, num_steps: int, cfg_scale: float, width: int, height: int, shift: float, use_flow_matching: bool, use_lyra: bool, seed: int, randomize_seed: bool, progress=gr.Progress() ): """Generate image with ZeroGPU support.""" if randomize_seed: seed = np.random.randint(0, 2**32 - 1) def progress_callback(step, total, desc): progress((step + 1) / total, desc=desc) try: pipeline = get_pipeline(model_choice) # Determine prediction type based on model is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice prediction_type = "epsilon" # SDXL always uses epsilon if not is_sdxl and "Lune" in model_choice: prediction_type = "v_prediction" if not use_lyra or pipeline.lyra_model is None: progress(0.05, desc="Generating...") image = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, shift=shift, use_flow_matching=use_flow_matching, prediction_type=prediction_type, seed=seed, use_lyra=False, clip_skip=clip_skip, progress_callback=progress_callback ) progress(1.0, desc="Complete!") return image, None, seed else: progress(0.05, desc="Generating standard...") image_standard = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, shift=shift, use_flow_matching=use_flow_matching, prediction_type=prediction_type, seed=seed, use_lyra=False, clip_skip=clip_skip, progress_callback=lambda s, t, d: progress(0.05 + (s/t) * 0.45, desc=d) ) progress(0.5, desc="Generating Lyra fusion...") image_lyra = pipeline( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, shift=shift, use_flow_matching=use_flow_matching, prediction_type=prediction_type, seed=seed, use_lyra=True, clip_skip=clip_skip, progress_callback=lambda s, t, d: progress(0.5 + (s/t) * 0.45, desc=d) ) progress(1.0, desc="Complete!") return image_standard, image_lyra, seed except Exception as e: print(f"❌ Generation failed: {e}") raise e # ============================================================================ # GRADIO UI # ============================================================================ def create_demo(): """Create Gradio interface.""" with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🌙 Lyra/Lune Flow-Matching Image Generation **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil) Generate images using SD1.5 and SDXL-based models with geometric deep learning: | Model | Architecture | Best For | |-------|-------------|----------| | **Illustrious XL** | SDXL | Anime/illustration, high detail | | **SDXL Base** | SDXL | Photorealistic, general purpose | | **Flow-Lune** | SD1.5 | Fast flow matching (15-25 steps) | | **SD1.5 Base** | SD1.5 | Baseline comparison | Enable **Lyra VAE** for CLIP+T5 fusion comparison! """) with gr.Row(): with gr.Column(scale=1): prompt = gr.TextArea( label="Prompt", value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background", lines=3 ) negative_prompt = gr.TextArea( label="Negative Prompt", value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality", lines=2 ) model_choice = gr.Dropdown( label="Model", choices=[ "Illustrious XL", "SDXL Base", "Flow-Lune (SD1.5)", "SD1.5 Base" ], value="Illustrious XL" ) clip_skip = gr.Slider( label="CLIP Skip", minimum=1, maximum=4, value=2, step=1, info="2 recommended for Illustrious, 1 for others" ) use_lyra = gr.Checkbox( label="Enable Lyra VAE (CLIP+T5 Fusion)", value=False, info="Compare standard vs geometric fusion" ) with gr.Accordion("Generation Settings", open=True): num_steps = gr.Slider( label="Steps", minimum=1, maximum=50, value=25, step=1 ) cfg_scale = gr.Slider( label="CFG Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.5 ) with gr.Row(): width = gr.Slider( label="Width", minimum=512, maximum=1536, value=1024, step=64 ) height = gr.Slider( label="Height", minimum=512, maximum=1536, value=1024, step=64 ) seed = gr.Slider( label="Seed", minimum=0, maximum=2**32 - 1, value=42, step=1 ) randomize_seed = gr.Checkbox( label="Randomize Seed", value=True ) with gr.Accordion("Advanced (Flow Matching)", open=False): use_flow_matching = gr.Checkbox( label="Enable Flow Matching", value=False, info="Use flow matching ODE (for Lune only)" ) shift = gr.Slider( label="Shift", minimum=0.0, maximum=5.0, value=0.0, step=0.1, info="Flow matching shift (0=disabled)" ) generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg") with gr.Column(scale=1): with gr.Row(): output_image_standard = gr.Image( label="Generated Image", type="pil" ) output_image_lyra = gr.Image( label="Lyra Fusion 🎵", type="pil", visible=False ) output_seed = gr.Number(label="Seed", precision=0) gr.Markdown(""" ### Tips - **Illustrious XL**: Use CLIP skip 2, booru-style tags - **SDXL Base**: Natural language prompts work well - **Flow-Lune**: Enable flow matching, shift ~2.5, fewer steps - **Lyra**: Generates both standard and fused for comparison ### Model Info - SDXL models use **epsilon** prediction - Lune uses **v_prediction** with flow matching - Lyra fuses CLIP + T5 for richer semantics """) # Examples gr.Examples( examples=[ [ "masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background", "lowres, bad anatomy, worst quality, low quality", "Illustrious XL", 2, 25, 7.0, 1024, 1024, 0.0, False, False, 42, False ], [ "A majestic mountain landscape at golden hour, crystal clear lake, photorealistic, 8k", "blurry, low quality", "SDXL Base", 1, 30, 7.5, 1024, 1024, 0.0, False, False, 123, False ], [ "cyberpunk city at night, neon lights, rain, highly detailed", "low quality, blurry", "Flow-Lune (SD1.5)", 1, 20, 7.5, 512, 512, 2.5, True, False, 456, False ], ], inputs=[ prompt, negative_prompt, model_choice, clip_skip, num_steps, cfg_scale, width, height, shift, use_flow_matching, use_lyra, seed, randomize_seed ], outputs=[output_image_standard, output_image_lyra, output_seed], fn=generate_image, cache_examples=False ) # Event handlers def on_model_change(model_name): """Update defaults based on model.""" if "Illustrious" in model_name: return { clip_skip: gr.update(value=2), width: gr.update(value=1024), height: gr.update(value=1024), num_steps: gr.update(value=25), use_flow_matching: gr.update(value=False), shift: gr.update(value=0.0) } elif "SDXL" in model_name: return { clip_skip: gr.update(value=1), width: gr.update(value=1024), height: gr.update(value=1024), num_steps: gr.update(value=30), use_flow_matching: gr.update(value=False), shift: gr.update(value=0.0) } elif "Lune" in model_name: return { clip_skip: gr.update(value=1), width: gr.update(value=512), height: gr.update(value=512), num_steps: gr.update(value=20), use_flow_matching: gr.update(value=True), shift: gr.update(value=2.5) } else: # SD1.5 Base return { clip_skip: gr.update(value=1), width: gr.update(value=512), height: gr.update(value=512), num_steps: gr.update(value=30), use_flow_matching: gr.update(value=False), shift: gr.update(value=0.0) } def on_lyra_toggle(enabled): """Show/hide Lyra comparison.""" if enabled: return { output_image_standard: gr.update(visible=True, label="Standard"), output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵") } else: return { output_image_standard: gr.update(visible=True, label="Generated Image"), output_image_lyra: gr.update(visible=False) } model_choice.change( fn=on_model_change, inputs=[model_choice], outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift] ) use_lyra.change( fn=on_lyra_toggle, inputs=[use_lyra], outputs=[output_image_standard, output_image_lyra] ) generate_btn.click( fn=generate_image, inputs=[ prompt, negative_prompt, model_choice, clip_skip, num_steps, cfg_scale, width, height, shift, use_flow_matching, use_lyra, seed, randomize_seed ], outputs=[output_image_standard, output_image_lyra, output_seed] ) return demo # ============================================================================ # LAUNCH # ============================================================================ if __name__ == "__main__": demo = create_demo() demo.queue(max_size=20) demo.launch(show_api=False)