| """ | |
| Lyra/Lune Flow-Matching Inference Space | |
| Author: AbstractPhil | |
| License: MIT | |
| SD1.5 and SDXL-based flow matching with geometric crystalline architectures. | |
| Supports Illustrious XL, standard SDXL, and SD1.5 variants. | |
| """ | |
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from typing import Optional, Dict, Tuple, Any

import spaces
from safetensors.torch import load_file as load_safetensors
from diffusers import (
    UNet2DConditionModel,
    AutoencoderKL,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler
)
from diffusers.models import UNet2DConditionModel as DiffusersUNet
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPTextModelWithProjection,
    T5EncoderModel,
    T5Tokenizer
)
from huggingface_hub import hf_hub_download
# Import Lyra VAE from geofractal
try:
    from geofractal.models.vae.vae_lyra import MultiModalVAE, MultiModalVAEConfig
    LYRA_AVAILABLE = True
except ImportError:
    try:
        from geofractal.train.model.vae.vae_lyra import MultiModalVAE, MultiModalVAEConfig
        LYRA_AVAILABLE = True
    except ImportError:
        print("⚠️ Lyra VAE not available - install geofractal")
        LYRA_AVAILABLE = False
# ============================================================================
# CONSTANTS
# ============================================================================

# Model architectures
ARCH_SD15 = "sd15"
ARCH_SDXL = "sdxl"

# ComfyUI key prefixes for SDXL single-file checkpoints
COMFYUI_UNET_PREFIX = "model.diffusion_model."
COMFYUI_CLIP_L_PREFIX = "conditioner.embedders.0.transformer."
COMFYUI_CLIP_G_PREFIX = "conditioner.embedders.1.model."
COMFYUI_VAE_PREFIX = "first_stage_model."
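
# Illustrative example (hypothetical key, not taken from a real checkpoint): an entry
# such as "model.diffusion_model.input_blocks.0.0.weight" has its prefix stripped by
# extract_comfyui_components() below and is stored under components["unet"] as
# "input_blocks.0.0.weight".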
# ============================================================================
# MODEL LOADING UTILITIES
# ============================================================================

def extract_comfyui_components(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
    """Extract UNet, CLIP-L, CLIP-G, and VAE from ComfyUI single-file checkpoint."""
    components = {
        "unet": {},
        "clip_l": {},
        "clip_g": {},
        "vae": {}
    }

    for key, value in state_dict.items():
        if key.startswith(COMFYUI_UNET_PREFIX):
            new_key = key[len(COMFYUI_UNET_PREFIX):]
            components["unet"][new_key] = value
        elif key.startswith(COMFYUI_CLIP_L_PREFIX):
            new_key = key[len(COMFYUI_CLIP_L_PREFIX):]
            components["clip_l"][new_key] = value
        elif key.startswith(COMFYUI_CLIP_G_PREFIX):
            new_key = key[len(COMFYUI_CLIP_G_PREFIX):]
            components["clip_g"][new_key] = value
        elif key.startswith(COMFYUI_VAE_PREFIX):
            new_key = key[len(COMFYUI_VAE_PREFIX):]
            components["vae"][new_key] = value

    print(f"  Extracted components:")
    print(f"    UNet: {len(components['unet'])} keys")
    print(f"    CLIP-L: {len(components['clip_l'])} keys")
    print(f"    CLIP-G: {len(components['clip_g'])} keys")
    print(f"    VAE: {len(components['vae'])} keys")
    return components
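
# Minimal usage sketch (the file name is illustrative, not a required path):
#   sd = load_safetensors("illustriousXL_v01.safetensors")
#   parts = extract_comfyui_components(sd)
#   # parts["unet"] / parts["vae"] can then be passed to load_state_dict(strict=False)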
def get_clip_hidden_state(
    model_output,
    clip_skip: int = 1,
    output_hidden_states: bool = True
) -> torch.Tensor:
    """Extract hidden state with clip_skip support."""
    if clip_skip == 1 or not output_hidden_states:
        return model_output.last_hidden_state

    if hasattr(model_output, 'hidden_states') and model_output.hidden_states is not None:
        # hidden_states is tuple: (embedding, layer1, ..., layerN)
        # clip_skip=2 means penultimate layer = hidden_states[-2]
        return model_output.hidden_states[-clip_skip]

    return model_output.last_hidden_state
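
# Note (added for clarity): this follows the common UI convention where clip_skip=1
# means the final encoder layer and clip_skip=2 means the penultimate layer
# (hidden_states[-2]), the usual setting for anime-style checkpoints such as Illustrious.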
# ============================================================================
# SDXL PIPELINE
# ============================================================================

class SDXLFlowMatchingPipeline:
    """Pipeline for SDXL-based flow-matching inference with dual CLIP encoders."""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,  # CLIP-L
        text_encoder_2: CLIPTextModelWithProjection,  # CLIP-G
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler,
        device: str = "cuda",
        t5_encoder: Optional[T5EncoderModel] = None,
        t5_tokenizer: Optional[T5Tokenizer] = None,
        lyra_model: Optional[Any] = None,
        clip_skip: int = 1
    ):
        self.vae = vae
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.unet = unet
        self.scheduler = scheduler
        self.device = device

        # Lyra components
        self.t5_encoder = t5_encoder
        self.t5_tokenizer = t5_tokenizer
        self.lyra_model = lyra_model

        # Settings
        self.clip_skip = clip_skip
        self.vae_scale_factor = 0.13025  # SDXL VAE scaling
        self.arch = ARCH_SDXL
    def encode_prompt(
        self,
        prompt: str,
        negative_prompt: str = "",
        clip_skip: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode prompts using dual CLIP encoders for SDXL."""
        # CLIP-L encoding
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)

        with torch.no_grad():
            output_hidden_states = clip_skip > 1
            clip_l_output = self.text_encoder(
                text_input_ids,
                output_hidden_states=output_hidden_states
            )
            prompt_embeds_l = get_clip_hidden_state(clip_l_output, clip_skip, output_hidden_states)

        # CLIP-G encoding
        text_inputs_2 = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_2.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids_2 = text_inputs_2.input_ids.to(self.device)

        with torch.no_grad():
            clip_g_output = self.text_encoder_2(
                text_input_ids_2,
                output_hidden_states=output_hidden_states
            )
            prompt_embeds_g = get_clip_hidden_state(clip_g_output, clip_skip, output_hidden_states)
            # Get pooled output from CLIP-G
            pooled_prompt_embeds = clip_g_output.text_embeds

        # Concatenate CLIP-L and CLIP-G embeddings
        prompt_embeds = torch.cat([prompt_embeds_l, prompt_embeds_g], dim=-1)

        # Negative prompt
        if negative_prompt:
            uncond_inputs = self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids = uncond_inputs.input_ids.to(self.device)

            uncond_inputs_2 = self.tokenizer_2(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer_2.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids_2 = uncond_inputs_2.input_ids.to(self.device)

            with torch.no_grad():
                uncond_output_l = self.text_encoder(
                    uncond_input_ids,
                    output_hidden_states=output_hidden_states
                )
                negative_embeds_l = get_clip_hidden_state(uncond_output_l, clip_skip, output_hidden_states)

                uncond_output_g = self.text_encoder_2(
                    uncond_input_ids_2,
                    output_hidden_states=output_hidden_states
                )
                negative_embeds_g = get_clip_hidden_state(uncond_output_g, clip_skip, output_hidden_states)
                negative_pooled = uncond_output_g.text_embeds

            negative_prompt_embeds = torch.cat([negative_embeds_l, negative_embeds_g], dim=-1)
        else:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled = torch.zeros_like(pooled_prompt_embeds)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled
    def encode_prompt_lyra(
        self,
        prompt: str,
        negative_prompt: str = "",
        clip_skip: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode prompts using Lyra VAE fusion (CLIP + T5)."""
        if self.lyra_model is None or self.t5_encoder is None:
            raise ValueError("Lyra VAE components not initialized")

        # Get standard CLIP embeddings first
        prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
            prompt, negative_prompt, clip_skip
        )

        # Get T5 embeddings
        t5_inputs = self.t5_tokenizer(
            prompt,
            max_length=77,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        with torch.no_grad():
            t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state

        # For SDXL, we need to handle the concatenated CLIP-L + CLIP-G embeddings:
        # split them, fuse CLIP-L through Lyra, then recombine
        clip_l_dim = 768
        clip_g_dim = 1280
        clip_l_embeds = prompt_embeds[..., :clip_l_dim]
        clip_g_embeds = prompt_embeds[..., clip_l_dim:]

        # Fuse CLIP-L through Lyra
        modality_inputs = {
            'clip': clip_l_embeds,
            't5': t5_embeds
        }
        with torch.no_grad():
            reconstructions, mu, logvar = self.lyra_model(
                modality_inputs,
                target_modalities=['clip']
            )
        fused_clip_l = reconstructions['clip']

        # Recombine with CLIP-G
        prompt_embeds_fused = torch.cat([fused_clip_l, clip_g_embeds], dim=-1)

        # Process negative prompt similarly if present
        if negative_prompt:
            t5_inputs_neg = self.t5_tokenizer(
                negative_prompt,
                max_length=77,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                t5_embeds_neg = self.t5_encoder(**t5_inputs_neg).last_hidden_state

            neg_clip_l = negative_prompt_embeds[..., :clip_l_dim]
            neg_clip_g = negative_prompt_embeds[..., clip_l_dim:]

            modality_inputs_neg = {
                'clip': neg_clip_l,
                't5': t5_embeds_neg
            }
            with torch.no_grad():
                reconstructions_neg, _, _ = self.lyra_model(
                    modality_inputs_neg,
                    target_modalities=['clip']
                )
            fused_neg_clip_l = reconstructions_neg['clip']
            negative_prompt_embeds_fused = torch.cat([fused_neg_clip_l, neg_clip_g], dim=-1)
        else:
            negative_prompt_embeds_fused = torch.zeros_like(prompt_embeds_fused)

        return prompt_embeds_fused, negative_prompt_embeds_fused, pooled, negative_pooled
    def _get_add_time_ids(
        self,
        original_size: Tuple[int, int],
        crops_coords_top_left: Tuple[int, int],
        target_size: Tuple[int, int],
        dtype: torch.dtype
    ) -> torch.Tensor:
        """Create time embedding IDs for SDXL."""
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
        return add_time_ids
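
    # Note (added for clarity): the six values packed above, (original_h, original_w,
    # crop_top, crop_left, target_h, target_w), are SDXL's micro-conditioning inputs;
    # they reach the UNet through added_cond_kwargs["time_ids"] in __call__ below.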
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        shift: float = 0.0,
        use_flow_matching: bool = False,
        prediction_type: str = "epsilon",
        seed: Optional[int] = None,
        use_lyra: bool = False,
        clip_skip: int = 1,
        progress_callback=None
    ):
        """Generate image using SDXL architecture."""
        # Set seed
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
        else:
            generator = None

        # Encode prompts
        if use_lyra and self.lyra_model is not None:
            prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt_lyra(
                prompt, negative_prompt, clip_skip
            )
        else:
            prompt_embeds, negative_prompt_embeds, pooled, negative_pooled = self.encode_prompt(
                prompt, negative_prompt, clip_skip
            )

        # Prepare latents
        latent_channels = 4
        latent_height = height // 8
        latent_width = width // 8
        latents = torch.randn(
            (1, latent_channels, latent_height, latent_width),
            generator=generator,
            device=self.device,
            dtype=torch.float16
        )

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # Scale initial latents
        if not use_flow_matching:
            latents = latents * self.scheduler.init_noise_sigma

        # Prepare added time embeddings for SDXL
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)
        add_time_ids = self._get_add_time_ids(
            original_size, crops_coords_top_left, target_size, dtype=torch.float16
        )
        negative_add_time_ids = add_time_ids  # Same for negative
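
        # Note on the sampling math in the loop below (added for clarity): the shift
        # formula sigma_shifted = shift*sigma / (1 + (shift-1)*sigma) is the SD3-style
        # timestep shift (shift > 1 biases sampling toward higher noise levels), and the
        # flow-matching branch integrates the predicted velocity with a plain Euler step,
        # latents += dt * prediction, where dt = -1/num_inference_steps.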
        # Denoising loop
        for i, t in enumerate(timesteps):
            if progress_callback:
                progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")

            # Expand for CFG
            latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

            # Flow matching scaling
            if use_flow_matching and shift > 0:
                sigma = t.float() / 1000.0
                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
                scaling = torch.sqrt(1 + sigma_shifted ** 2)
                latent_model_input = latent_model_input / scaling
            else:
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Prepare timestep
            timestep = t.expand(latent_model_input.shape[0])

            # Prepare added conditions
            if guidance_scale > 1.0:
                text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
                add_text_embeds = torch.cat([negative_pooled, pooled])
                add_time_ids_input = torch.cat([negative_add_time_ids, add_time_ids])
            else:
                text_embeds = prompt_embeds
                add_text_embeds = pooled
                add_time_ids_input = add_time_ids

            # Prepare added cond kwargs for SDXL UNet
            added_cond_kwargs = {
                "text_embeds": add_text_embeds,
                "time_ids": add_time_ids_input
            }

            # Predict noise
            noise_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeds,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False
            )[0]

            # CFG
            if guidance_scale > 1.0:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Step
            if use_flow_matching:
                sigma = t.float() / 1000.0
                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
                if prediction_type == "v_prediction":
                    v_pred = noise_pred
                    alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
                    sigma_t = sigma_shifted
                    noise_pred = alpha_t * v_pred + sigma_t * latents
                dt = -1.0 / num_inference_steps
                latents = latents + dt * noise_pred
            else:
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

        # Decode
        latents = latents / self.vae_scale_factor
        with torch.no_grad():
            image = self.vae.decode(latents.to(self.vae.dtype)).sample

        # Convert to PIL
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = Image.fromarray(image[0])

        return image
# ============================================================================
# SD1.5 PIPELINE (Original)
# ============================================================================

class SD15FlowMatchingPipeline:
    """Pipeline for SD1.5-based flow-matching inference."""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler,
        device: str = "cuda",
        t5_encoder: Optional[T5EncoderModel] = None,
        t5_tokenizer: Optional[T5Tokenizer] = None,
        lyra_model: Optional[Any] = None
    ):
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.unet = unet
        self.scheduler = scheduler
        self.device = device
        self.t5_encoder = t5_encoder
        self.t5_tokenizer = t5_tokenizer
        self.lyra_model = lyra_model
        self.vae_scale_factor = 0.18215
        self.arch = ARCH_SD15
        self.is_lune_model = False
    def encode_prompt(self, prompt: str, negative_prompt: str = ""):
        """Encode text prompts to embeddings."""
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        with torch.no_grad():
            prompt_embeds = self.text_encoder(text_input_ids)[0]

        if negative_prompt:
            uncond_inputs = self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids = uncond_inputs.input_ids.to(self.device)
            with torch.no_grad():
                negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0]
        else:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        return prompt_embeds, negative_prompt_embeds
    def encode_prompt_lyra(self, prompt: str, negative_prompt: str = ""):
        """Encode using Lyra VAE (CLIP + T5 fusion)."""
        if self.lyra_model is None or self.t5_encoder is None:
            raise ValueError("Lyra VAE components not initialized")

        # CLIP
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.device)
        with torch.no_grad():
            clip_embeds = self.text_encoder(text_input_ids)[0]

        # T5
        t5_inputs = self.t5_tokenizer(
            prompt,
            max_length=77,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        with torch.no_grad():
            t5_embeds = self.t5_encoder(**t5_inputs).last_hidden_state

        # Fuse
        modality_inputs = {'clip': clip_embeds, 't5': t5_embeds}
        with torch.no_grad():
            reconstructions, mu, logvar = self.lyra_model(
                modality_inputs,
                target_modalities=['clip']
            )
        prompt_embeds = reconstructions['clip']

        # Negative
        if negative_prompt:
            uncond_inputs = self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids = uncond_inputs.input_ids.to(self.device)
            with torch.no_grad():
                clip_embeds_uncond = self.text_encoder(uncond_input_ids)[0]

            t5_inputs_uncond = self.t5_tokenizer(
                negative_prompt,
                max_length=77,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            ).to(self.device)
            with torch.no_grad():
                t5_embeds_uncond = self.t5_encoder(**t5_inputs_uncond).last_hidden_state

            modality_inputs_uncond = {'clip': clip_embeds_uncond, 't5': t5_embeds_uncond}
            with torch.no_grad():
                reconstructions_uncond, _, _ = self.lyra_model(
                    modality_inputs_uncond,
                    target_modalities=['clip']
                )
            negative_prompt_embeds = reconstructions_uncond['clip']
        else:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

        return prompt_embeds, negative_prompt_embeds
    def __call__(
        self,
        prompt: str,
        negative_prompt: str = "",
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 20,
        guidance_scale: float = 7.5,
        shift: float = 2.5,
        use_flow_matching: bool = True,
        prediction_type: str = "epsilon",
        seed: Optional[int] = None,
        use_lyra: bool = False,
        clip_skip: int = 1,  # Unused for SD1.5 but kept for API consistency
        progress_callback=None
    ):
        """Generate image."""
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
        else:
            generator = None

        if use_lyra and self.lyra_model is not None:
            prompt_embeds, negative_prompt_embeds = self.encode_prompt_lyra(prompt, negative_prompt)
        else:
            prompt_embeds, negative_prompt_embeds = self.encode_prompt(prompt, negative_prompt)

        latent_channels = 4
        latent_height = height // 8
        latent_width = width // 8
        latents = torch.randn(
            (1, latent_channels, latent_height, latent_width),
            generator=generator,
            device=self.device,
            dtype=torch.float32
        )

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        if not use_flow_matching:
            latents = latents * self.scheduler.init_noise_sigma

        for i, t in enumerate(timesteps):
            if progress_callback:
                progress_callback(i, num_inference_steps, f"Step {i+1}/{num_inference_steps}")

            latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents

            if use_flow_matching and shift > 0:
                sigma = t.float() / 1000.0
                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
                scaling = torch.sqrt(1 + sigma_shifted ** 2)
                latent_model_input = latent_model_input / scaling
            else:
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timestep = t.expand(latent_model_input.shape[0])
            text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if guidance_scale > 1.0 else prompt_embeds

            noise_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeds,
                return_dict=False
            )[0]

            if guidance_scale > 1.0:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if use_flow_matching:
                sigma = t.float() / 1000.0
                sigma_shifted = (shift * sigma) / (1 + (shift - 1) * sigma)
                if prediction_type == "v_prediction":
                    v_pred = noise_pred
                    alpha_t = torch.sqrt(1 - sigma_shifted ** 2)
                    sigma_t = sigma_shifted
                    noise_pred = alpha_t * v_pred + sigma_t * latents
                dt = -1.0 / num_inference_steps
                latents = latents + dt * noise_pred
            else:
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        latents = latents / self.vae_scale_factor
        if self.is_lune_model:
            latents = latents * 5.52
        with torch.no_grad():
            image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        image = Image.fromarray(image[0])

        return image
# ============================================================================
# MODEL LOADERS
# ============================================================================

def load_lune_checkpoint(repo_id: str, filename: str, device: str = "cuda"):
    """Load Lune checkpoint from .pt file."""
    print(f"📥 Downloading: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    print(f"🏗️ Initializing SD1.5 UNet...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32
    )

    student_state_dict = checkpoint["student"]
    cleaned_dict = {}
    for key, value in student_state_dict.items():
        if key.startswith("unet."):
            cleaned_dict[key[5:]] = value
        else:
            cleaned_dict[key] = value

    unet.load_state_dict(cleaned_dict, strict=False)
    step = checkpoint.get("gstep", "unknown")
    print(f"✅ Loaded Lune from step {step}")

    return unet.to(device)
def load_illustrious_xl(
    repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
    filename: str = "illustriousXL_v01.safetensors",
    device: str = "cuda"
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
    """Load Illustrious XL from single safetensors file."""
    print(f"📥 Downloading Illustrious XL: {repo_id}/{filename}")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
    print(f"✓ Downloaded: {checkpoint_path}")

    print("📦 Loading safetensors...")
    state_dict = load_safetensors(checkpoint_path)

    # Extract components
    components = extract_comfyui_components(state_dict)

    # Load UNet from SDXL base config, then load weights
    print("🏗️ Initializing SDXL UNet...")
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
        torch_dtype=torch.float16
    )
    if components["unet"]:
        missing, unexpected = unet.load_state_dict(components["unet"], strict=False)
        print(f"  UNet: {len(missing)} missing, {len(unexpected)} unexpected keys")

    # Load VAE
    print("🏗️ Initializing SDXL VAE...")
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
        torch_dtype=torch.float16
    )
    if components["vae"]:
        missing, unexpected = vae.load_state_dict(components["vae"], strict=False)
        print(f"  VAE: {len(missing)} missing, {len(unexpected)} unexpected keys")
    # Load CLIP-L
    print("🏗️ Loading CLIP-L...")
    text_encoder = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        torch_dtype=torch.float16
    )
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Load CLIP-G
    print("🏗️ Loading CLIP-G...")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        torch_dtype=torch.float16
    )
    tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

    # Move to device
    unet = unet.to(device)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    text_encoder_2 = text_encoder_2.to(device)

    print("✅ Illustrious XL loaded!")
    return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
def load_sdxl_base(device: str = "cuda"):
    """Load standard SDXL base model."""
    print("📥 Loading SDXL Base 1.0...")

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
        torch_dtype=torch.float16
    ).to(device)
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
        torch_dtype=torch.float16
    ).to(device)
    text_encoder = CLIPTextModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="text_encoder",
        torch_dtype=torch.float16
    ).to(device)
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="text_encoder_2",
        torch_dtype=torch.float16
    ).to(device)
    tokenizer = CLIPTokenizer.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="tokenizer"
    )
    tokenizer_2 = CLIPTokenizer.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="tokenizer_2"
    )

    print("✅ SDXL Base loaded!")
    return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
def load_lyra_vae(repo_id: str = "AbstractPhil/vae-lyra", device: str = "cuda"):
    """Load Lyra VAE (SD1.5 version) from HuggingFace."""
    if not LYRA_AVAILABLE:
        print("⚠️ Lyra VAE not available")
        return None

    print(f"🎵 Loading Lyra VAE from {repo_id}...")
    try:
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename="best_model.pt",
            repo_type="model"
        )
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        if 'config' in checkpoint:
            config_dict = checkpoint['config']
        else:
            config_dict = {
                'modality_dims': {"clip": 768, "t5": 768},
                'latent_dim': 768,
                'seq_len': 77,
                'encoder_layers': 3,
                'decoder_layers': 3,
                'hidden_dim': 1024,
                'dropout': 0.1,
                'fusion_strategy': 'cantor',
                'fusion_heads': 8,
                'fusion_dropout': 0.1
            }

        vae_config = MultiModalVAEConfig(
            modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 768}),
            latent_dim=config_dict.get('latent_dim', 768),
            seq_len=config_dict.get('seq_len', 77),
            encoder_layers=config_dict.get('encoder_layers', 3),
            decoder_layers=config_dict.get('decoder_layers', 3),
            hidden_dim=config_dict.get('hidden_dim', 1024),
            dropout=config_dict.get('dropout', 0.1),
            fusion_strategy=config_dict.get('fusion_strategy', 'cantor'),
            fusion_heads=config_dict.get('fusion_heads', 8),
            fusion_dropout=config_dict.get('fusion_dropout', 0.1)
        )

        lyra_model = MultiModalVAE(vae_config)
        if 'model_state_dict' in checkpoint:
            lyra_model.load_state_dict(checkpoint['model_state_dict'])
        else:
            lyra_model.load_state_dict(checkpoint)

        lyra_model.to(device)
        lyra_model.eval()
        print(f"✅ Lyra VAE (SD1.5) loaded")
        return lyra_model
    except Exception as e:
        print(f"❌ Failed to load Lyra VAE: {e}")
        return None
def load_lyra_vae_xl(
    repo_id: str = "AbstractPhil/vae-lyra-xl-adaptive-cantor-illustrious",
    device: str = "cuda"
):
    """Load Lyra VAE XL version for SDXL/Illustrious."""
    if not LYRA_AVAILABLE:
        print("⚠️ Lyra VAE not available")
        return None

    print(f"🎵 Loading Lyra VAE XL from {repo_id}...")
    try:
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename="best_model.pt",
            repo_type="model"
        )
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        if 'config' in checkpoint:
            config_dict = checkpoint['config']
        else:
            # XL defaults - note larger dimensions
            config_dict = {
                'modality_dims': {"clip": 768, "t5": 2048},  # T5-XL
                'latent_dim': 2048,
                'seq_len': 77,
                'encoder_layers': 4,
                'decoder_layers': 4,
                'hidden_dim': 2048,
                'dropout': 0.1,
                'fusion_strategy': 'adaptive_cantor',
                'fusion_heads': 16,
                'fusion_dropout': 0.1
            }

        vae_config = MultiModalVAEConfig(
            modality_dims=config_dict.get('modality_dims', {"clip": 768, "t5": 2048}),
            latent_dim=config_dict.get('latent_dim', 2048),
            seq_len=config_dict.get('seq_len', 77),
            encoder_layers=config_dict.get('encoder_layers', 4),
            decoder_layers=config_dict.get('decoder_layers', 4),
            hidden_dim=config_dict.get('hidden_dim', 2048),
            dropout=config_dict.get('dropout', 0.1),
            fusion_strategy=config_dict.get('fusion_strategy', 'adaptive_cantor'),
            fusion_heads=config_dict.get('fusion_heads', 16),
            fusion_dropout=config_dict.get('fusion_dropout', 0.1)
        )

        lyra_model = MultiModalVAE(vae_config)
        if 'model_state_dict' in checkpoint:
            lyra_model.load_state_dict(checkpoint['model_state_dict'])
        else:
            lyra_model.load_state_dict(checkpoint)

        lyra_model.to(device)
        lyra_model.eval()
        print(f"✅ Lyra VAE XL loaded")
        if 'global_step' in checkpoint:
            print(f"  Step: {checkpoint['global_step']:,}")
        return lyra_model
    except Exception as e:
        print(f"❌ Failed to load Lyra VAE XL: {e}")
        return None
# ============================================================================
# PIPELINE INITIALIZATION
# ============================================================================

def initialize_pipeline(model_choice: str, device: str = "cuda"):
    """Initialize the complete pipeline based on model choice."""
    print(f"🚀 Initializing {model_choice} pipeline...")

    # Determine architecture
    is_sdxl = "Illustrious" in model_choice or "SDXL" in model_choice
    is_lune = "Lune" in model_choice

    if is_sdxl:
        # SDXL-based models
        if "Illustrious" in model_choice:
            unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_illustrious_xl(device=device)
        else:
            unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2 = load_sdxl_base(device=device)

        # T5-XL for Lyra
        print("Loading T5-XL encoder...")
        t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xl")
        t5_encoder = T5EncoderModel.from_pretrained(
            "google/t5-v1_1-xl",
            torch_dtype=torch.float16
        ).to(device)
        t5_encoder.eval()
        print("✓ T5-XL loaded")

        # Lyra XL
        lyra_model = load_lyra_vae_xl(device=device)

        # Scheduler (epsilon for SDXL)
        scheduler = EulerDiscreteScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler"
        )

        pipeline = SDXLFlowMatchingPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            device=device,
            t5_encoder=t5_encoder,
            t5_tokenizer=t5_tokenizer,
            lyra_model=lyra_model,
            clip_skip=1
        )
    else:
        # SD1.5-based models
        vae = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="vae",
            torch_dtype=torch.float32
        ).to(device)
        text_encoder = CLIPTextModel.from_pretrained(
            "openai/clip-vit-large-patch14",
            torch_dtype=torch.float32
        ).to(device)
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        # T5-base for SD1.5 Lyra
        print("Loading T5-base encoder...")
        t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
        t5_encoder = T5EncoderModel.from_pretrained(
            "t5-base",
            torch_dtype=torch.float32
        ).to(device)
        t5_encoder.eval()
        print("✓ T5-base loaded")

        # Lyra (SD1.5 version)
        lyra_model = load_lyra_vae(device=device)

        # Load UNet
        if is_lune:
            repo_id = "AbstractPhil/sd15-flow-lune"
            filename = "sd15_flow_lune_e34_s34000.pt"
            unet = load_lune_checkpoint(repo_id, filename, device)
        else:
            unet = UNet2DConditionModel.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                subfolder="unet",
                torch_dtype=torch.float32
            ).to(device)

        scheduler = EulerDiscreteScheduler.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="scheduler"
        )

        pipeline = SD15FlowMatchingPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            device=device,
            t5_encoder=t5_encoder,
            t5_tokenizer=t5_tokenizer,
            lyra_model=lyra_model
        )
        pipeline.is_lune_model = is_lune

    print("✅ Pipeline initialized!")
    return pipeline
# ============================================================================
# GLOBAL STATE
# ============================================================================

CURRENT_PIPELINE = None
CURRENT_MODEL = None

def get_pipeline(model_choice: str):
    """Get or create pipeline for selected model."""
    global CURRENT_PIPELINE, CURRENT_MODEL
    if CURRENT_PIPELINE is None or CURRENT_MODEL != model_choice:
        CURRENT_PIPELINE = initialize_pipeline(model_choice, device="cuda")
        CURRENT_MODEL = model_choice
    return CURRENT_PIPELINE
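
# Note (added for clarity): switching models simply replaces the cached pipeline; the
# previous model's GPU memory is only reclaimed once Python garbage-collects it, so an
# explicit del plus torch.cuda.empty_cache() may be needed on smaller GPUs.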
# ============================================================================
# INFERENCE
# ============================================================================

def estimate_duration(num_steps: int, width: int, height: int, use_lyra: bool = False, is_sdxl: bool = False) -> int:
    """Estimate GPU duration."""
    base_time_per_step = 0.5 if is_sdxl else 0.3
    resolution_factor = (width * height) / (512 * 512)
    estimated = num_steps * base_time_per_step * resolution_factor
    if use_lyra:
        estimated *= 2
        estimated += 3
    return int(estimated + 20)
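
# Note (added for clarity): estimate_duration is not referenced anywhere else in this
# file. On ZeroGPU it could inform the duration passed to spaces.GPU; recent versions of
# the spaces package also accept a callable for duration, which an estimate like this
# could feed.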
@spaces.GPU(duration=120)  # ZeroGPU: the GPU-bound entry point must be decorated; 120 s is an assumed upper bound
def generate_image(
    prompt: str,
    negative_prompt: str,
    model_choice: str,
    clip_skip: int,
    num_steps: int,
    cfg_scale: float,
    width: int,
    height: int,
    shift: float,
    use_flow_matching: bool,
    use_lyra: bool,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress()
):
    """Generate image with ZeroGPU support."""
    if randomize_seed:
        seed = np.random.randint(0, 2**32 - 1)

    def progress_callback(step, total, desc):
        progress((step + 1) / total, desc=desc)

    try:
        pipeline = get_pipeline(model_choice)

        # Determine prediction type based on model
        is_sdxl = "SDXL" in model_choice or "Illustrious" in model_choice
        prediction_type = "epsilon"  # SDXL always uses epsilon
        if not is_sdxl and "Lune" in model_choice:
            prediction_type = "v_prediction"

        if not use_lyra or pipeline.lyra_model is None:
            progress(0.05, desc="Generating...")
            image = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                shift=shift,
                use_flow_matching=use_flow_matching,
                prediction_type=prediction_type,
                seed=seed,
                use_lyra=False,
                clip_skip=clip_skip,
                progress_callback=progress_callback
            )
            progress(1.0, desc="Complete!")
            return image, None, seed
        else:
            progress(0.05, desc="Generating standard...")
            image_standard = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                shift=shift,
                use_flow_matching=use_flow_matching,
                prediction_type=prediction_type,
                seed=seed,
                use_lyra=False,
                clip_skip=clip_skip,
                progress_callback=lambda s, t, d: progress(0.05 + (s / t) * 0.45, desc=d)
            )

            progress(0.5, desc="Generating Lyra fusion...")
            image_lyra = pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=cfg_scale,
                shift=shift,
                use_flow_matching=use_flow_matching,
                prediction_type=prediction_type,
                seed=seed,
                use_lyra=True,
                clip_skip=clip_skip,
                progress_callback=lambda s, t, d: progress(0.5 + (s / t) * 0.45, desc=d)
            )

            progress(1.0, desc="Complete!")
            return image_standard, image_lyra, seed

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        raise e
# ============================================================================
# GRADIO UI
# ============================================================================

def create_demo():
    """Create Gradio interface."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🌙 Lyra/Lune Flow-Matching Image Generation

        **Geometric crystalline diffusion** by [AbstractPhil](https://huggingface.co/AbstractPhil)

        Generate images using SD1.5 and SDXL-based models with geometric deep learning:

        | Model | Architecture | Best For |
        |-------|-------------|----------|
        | **Illustrious XL** | SDXL | Anime/illustration, high detail |
        | **SDXL Base** | SDXL | Photorealistic, general purpose |
        | **Flow-Lune** | SD1.5 | Fast flow matching (15-25 steps) |
        | **SD1.5 Base** | SD1.5 | Baseline comparison |

        Enable **Lyra VAE** for CLIP+T5 fusion comparison!
        """)

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.TextArea(
                    label="Prompt",
                    value="masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
                    lines=3
                )
                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    value="lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality",
                    lines=2
                )

                model_choice = gr.Dropdown(
                    label="Model",
                    choices=[
                        "Illustrious XL",
                        "SDXL Base",
                        "Flow-Lune (SD1.5)",
                        "SD1.5 Base"
                    ],
                    value="Illustrious XL"
                )

                clip_skip = gr.Slider(
                    label="CLIP Skip",
                    minimum=1,
                    maximum=4,
                    value=2,
                    step=1,
                    info="2 recommended for Illustrious, 1 for others"
                )

                use_lyra = gr.Checkbox(
                    label="Enable Lyra VAE (CLIP+T5 Fusion)",
                    value=False,
                    info="Compare standard vs geometric fusion"
                )
| with gr.Accordion("Generation Settings", open=True): | |
| num_steps = gr.Slider( | |
| label="Steps", | |
| minimum=1, | |
| maximum=50, | |
| value=25, | |
| step=1 | |
| ) | |
| cfg_scale = gr.Slider( | |
| label="CFG Scale", | |
| minimum=1.0, | |
| maximum=20.0, | |
| value=7.0, | |
| step=0.5 | |
| ) | |
| with gr.Row(): | |
| width = gr.Slider( | |
| label="Width", | |
| minimum=512, | |
| maximum=1536, | |
| value=1024, | |
| step=64 | |
| ) | |
| height = gr.Slider( | |
| label="Height", | |
| minimum=512, | |
| maximum=1536, | |
| value=1024, | |
| step=64 | |
| ) | |
| seed = gr.Slider( | |
| label="Seed", | |
| minimum=0, | |
| maximum=2**32 - 1, | |
| value=42, | |
| step=1 | |
| ) | |
| randomize_seed = gr.Checkbox( | |
| label="Randomize Seed", | |
| value=True | |
| ) | |
| with gr.Accordion("Advanced (Flow Matching)", open=False): | |
| use_flow_matching = gr.Checkbox( | |
| label="Enable Flow Matching", | |
| value=False, | |
| info="Use flow matching ODE (for Lune only)" | |
| ) | |
| shift = gr.Slider( | |
| label="Shift", | |
| minimum=0.0, | |
| maximum=5.0, | |
| value=0.0, | |
| step=0.1, | |
| info="Flow matching shift (0=disabled)" | |
| ) | |
| generate_btn = gr.Button("🎨 Generate", variant="primary", size="lg") | |
            with gr.Column(scale=1):
                with gr.Row():
                    output_image_standard = gr.Image(
                        label="Generated Image",
                        type="pil"
                    )
                    output_image_lyra = gr.Image(
                        label="Lyra Fusion 🎵",
                        type="pil",
                        visible=False
                    )
                output_seed = gr.Number(label="Seed", precision=0)

                gr.Markdown("""
                ### Tips
                - **Illustrious XL**: Use CLIP skip 2, booru-style tags
                - **SDXL Base**: Natural language prompts work well
                - **Flow-Lune**: Enable flow matching, shift ~2.5, fewer steps
                - **Lyra**: Generates both standard and fused for comparison

                ### Model Info
                - SDXL models use **epsilon** prediction
                - Lune uses **v_prediction** with flow matching
                - Lyra fuses CLIP + T5 for richer semantics
                """)
        # Examples
        gr.Examples(
            examples=[
                [
                    "masterpiece, best quality, 1girl, blue hair, school uniform, cherry blossoms, detailed background",
                    "lowres, bad anatomy, worst quality, low quality",
                    "Illustrious XL",
                    2, 25, 7.0, 1024, 1024, 0.0, False, False, 42, False
                ],
                [
                    "A majestic mountain landscape at golden hour, crystal clear lake, photorealistic, 8k",
                    "blurry, low quality",
                    "SDXL Base",
                    1, 30, 7.5, 1024, 1024, 0.0, False, False, 123, False
                ],
                [
                    "cyberpunk city at night, neon lights, rain, highly detailed",
                    "low quality, blurry",
                    "Flow-Lune (SD1.5)",
                    1, 20, 7.5, 512, 512, 2.5, True, False, 456, False
                ],
            ],
            inputs=[
                prompt, negative_prompt, model_choice, clip_skip,
                num_steps, cfg_scale, width, height, shift,
                use_flow_matching, use_lyra, seed, randomize_seed
            ],
            outputs=[output_image_standard, output_image_lyra, output_seed],
            fn=generate_image,
            cache_examples=False
        )
        # Event handlers
        def on_model_change(model_name):
            """Update defaults based on model."""
            if "Illustrious" in model_name:
                return {
                    clip_skip: gr.update(value=2),
                    width: gr.update(value=1024),
                    height: gr.update(value=1024),
                    num_steps: gr.update(value=25),
                    use_flow_matching: gr.update(value=False),
                    shift: gr.update(value=0.0)
                }
            elif "SDXL" in model_name:
                return {
                    clip_skip: gr.update(value=1),
                    width: gr.update(value=1024),
                    height: gr.update(value=1024),
                    num_steps: gr.update(value=30),
                    use_flow_matching: gr.update(value=False),
                    shift: gr.update(value=0.0)
                }
            elif "Lune" in model_name:
                return {
                    clip_skip: gr.update(value=1),
                    width: gr.update(value=512),
                    height: gr.update(value=512),
                    num_steps: gr.update(value=20),
                    use_flow_matching: gr.update(value=True),
                    shift: gr.update(value=2.5)
                }
            else:  # SD1.5 Base
                return {
                    clip_skip: gr.update(value=1),
                    width: gr.update(value=512),
                    height: gr.update(value=512),
                    num_steps: gr.update(value=30),
                    use_flow_matching: gr.update(value=False),
                    shift: gr.update(value=0.0)
                }

        def on_lyra_toggle(enabled):
            """Show/hide Lyra comparison."""
            if enabled:
                return {
                    output_image_standard: gr.update(visible=True, label="Standard"),
                    output_image_lyra: gr.update(visible=True, label="Lyra Fusion 🎵")
                }
            else:
                return {
                    output_image_standard: gr.update(visible=True, label="Generated Image"),
                    output_image_lyra: gr.update(visible=False)
                }

        model_choice.change(
            fn=on_model_change,
            inputs=[model_choice],
            outputs=[clip_skip, width, height, num_steps, use_flow_matching, shift]
        )
        use_lyra.change(
            fn=on_lyra_toggle,
            inputs=[use_lyra],
            outputs=[output_image_standard, output_image_lyra]
        )
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt, negative_prompt, model_choice, clip_skip,
                num_steps, cfg_scale, width, height, shift,
                use_flow_matching, use_lyra, seed, randomize_seed
            ],
            outputs=[output_image_standard, output_image_lyra, output_seed]
        )

    return demo
# ============================================================================
# LAUNCH
# ============================================================================

if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=20)
    demo.launch(show_api=False)