import gradio as gr
import torch
import spaces
from diffusers import (
    DiffusionPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    DPMSolverMultistepScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import hf_hub_download, list_models
import numpy as np
from PIL import Image
import io
import base64
import json
import os
import gc
import logging
import requests
import tempfile
from typing import Optional, List, Dict, Any
import random


# Configuration
class ModelConfig:
    """Static configuration: model repo ids and default generation parameters."""

    def __init__(self):
        self.t2v_model = "Wan-AI/Wan2.1-T2V-14B"
        self.i2v_model = "Wan-AI/Wan2.1-I2V-14B-480P"
        self.ti2v_model = "Wan-AI/Wan2.1-TI2V-14B-720P"
        self.animate_model = "Wan-AI/Wan2.1-Animate-14B"
        self.default_width = 1024
        self.default_height = 1024
        self.max_seed = 2**32 - 1
        self.default_steps = 25
        self.default_guidance = 7.5


# LoRA Manager with ZeroGPU optimizations
class LoRAManager:
    """Downloads LoRA weight files (HF Hub or CivitAI), caches the local
    paths, and applies them to a diffusers pipeline."""

    def __init__(self):
        # key -> local file path of an already-downloaded LoRA
        self.loaded_loras: Dict[str, str] = {}
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Setup logging configuration"""
        logger = logging.getLogger("LoRAManager")
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def get_popular_loras(self) -> List[Dict[str, str]]:
        """Get a list of popular LoRA models including Lightning LoRA"""
        return [
            {"name": "None", "repo_id": "", "filename": "", "type": "none"},
            {"name": "Lightning LoRA (4-step)", "repo_id": "ByteDance/SDXL-Lightning",
             "filename": "sdxl_lightning_4step_lora.safetensors", "type": "lightning"},
            {"name": "Lightning LoRA (8-step)", "repo_id": "ByteDance/SDXL-Lightning",
             "filename": "sdxl_lightning_8step_lora.safetensors", "type": "lightning"},
            {"name": "Anime Style", "repo_id": "Linaqruf/anime-detailer-xl-lora",
             "filename": "anime-detailer-xl.safetensors", "type": "style"},
            {"name": "Pixel Art", "repo_id": "nerijs/pixel-art-xl",
             "filename": "pixel-art-xl.safetensors", "type": "style"},
            {"name": "Watercolor", "repo_id": "SG161222/RealVisXL_V4.0_Lightning",
             "filename": "RealVisXL_V4.0_Lightning.safetensors", "type": "style"},
            {"name": "Photorealistic", "repo_id": "SG161222/RealVisXL_V4.0",
             "filename": "RealVisXL_V4.0.safetensors", "type": "style"},
            {"name": "Relighting LoRA", "repo_id": "Wan-AI/Wan2.2-Animate-14B",
             "filename": "relighting_lora.safetensors", "type": "relighting"},
        ]

    def load_lora_from_hub(self, repo_id: str, filename: str, weight: float = 1.0) -> Optional[str]:
        """Load LoRA from Hugging Face Hub.

        Returns the local path of the downloaded file, or None on failure.
        Results are cached per (repo_id, filename).
        """
        if not repo_id or not filename:
            return None
        # BUGFIX: the cache key (and log messages) previously interpolated a
        # literal "(unknown)" instead of the filename, so every LoRA from the
        # same repo collided on one cache entry.
        key = f"{repo_id}/{filename}"
        if key in self.loaded_loras:
            return self.loaded_loras[key]
        try:
            lora_path = hf_hub_download(repo_id=repo_id, filename=filename)
            self.loaded_loras[key] = lora_path
            self.logger.info(f"Downloaded LoRA: {repo_id}/{filename}")
            return lora_path
        except Exception as e:
            self.logger.error(f"Failed to load LoRA {repo_id}/{filename}: {e}")
            return None

    def load_lora_from_civitai(self, model_id: str, weight: float = 1.0) -> Optional[str]:
        """Load LoRA from CivitAI using direct download.

        Returns the local temp-file path, or None on failure.
        """
        if not model_id:
            return None
        key = f"civitai/{model_id}"
        if key in self.loaded_loras:
            return self.loaded_loras[key]
        try:
            civitai_url = f"https://civitai.com/api/download/models/{model_id}"
            with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as tmp_file:
                self.logger.info(f"Downloading LoRA from CivitAI: {model_id}")
                headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
                # BUGFIX: a missing timeout could hang the worker forever on a
                # stalled connection; 120s is generous for a LoRA download.
                response = requests.get(civitai_url, headers=headers, timeout=120)
                response.raise_for_status()
                tmp_file.write(response.content)
                self.loaded_loras[key] = tmp_file.name
                self.logger.info(f"Downloaded CivitAI LoRA: {model_id}")
                return tmp_file.name
        except Exception as e:
            self.logger.error(f"Failed to load LoRA from CivitAI {model_id}: {e}")
            return None

    def apply_lora_to_pipeline(self, pipeline, lora_path: str, weight: float = 1.0) -> bool:
        """Apply LoRA to pipeline. Returns True on success, False otherwise.

        NOTE(review): `weight` is logged but not actually passed to the
        pipeline here; scaling would need `cross_attention_kwargs={"scale": w}`
        at inference time or `set_adapters` — confirm intended behavior.
        """
        try:
            if lora_path and os.path.exists(lora_path):
                pipeline.load_lora_weights(lora_path)
                self.logger.info(f"Applied LoRA: {lora_path} with weight {weight}")
                return True
            # BUGFIX: previously fell through and returned None implicitly
            # when the path was missing; make the failure explicit.
            return False
        except Exception as e:
            self.logger.error(f"Failed to apply LoRA: {e}")
            return False


# Global configuration and managers
model_config = ModelConfig()
lora_manager = LoRAManager()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Global variables for models (loaded on GPU when needed)
txt2img_pipe = None
img2img_pipe = None


def calculate_duration(prompt: str = "", width: int = 1024, height: int = 1024,
                       num_inference_steps: int = 25, lora_selection: str = "None") -> int:
    """
    Calculate dynamic duration based on generation parameters for ZeroGPU optimization.
    Base duration: ~3.75 seconds per step for SDXL on H200
    """
    base_duration = num_inference_steps * 3.75
    resolution_factor = (width * height) / (1024 * 1024)
    resolution_overhead = base_duration * (resolution_factor - 1) * 0.2
    lora_overhead = 15 if lora_selection != "None" else 0
    total_duration = base_duration + resolution_overhead + lora_overhead + 20
    return max(60, min(300, int(total_duration)))  # Clamp between 60-300 seconds


# BUGFIX: `@spaces.GPU(duration=fn)` calls `fn` with the *decorated function's*
# arguments. The original code attached `calculate_duration` directly, but its
# parameter order does not match either generate function (e.g. the caller's
# `negative_prompt` would land in `width`) and load_models() takes no args at
# all — both would raise at call time. These adapters map each generate
# function's signature onto calculate_duration.
def _t2i_duration(prompt="", negative_prompt="", width=1024, height=1024,
                  guidance_scale=7.5, num_inference_steps=25, seed=-1,
                  lora_selection="None", *args, **kwargs) -> int:
    """Duration estimator matching generate_text_to_image's signature."""
    return calculate_duration(prompt, width, height, num_inference_steps, lora_selection)


def _i2i_duration(prompt="", input_image=None, negative_prompt="", strength=0.7,
                  guidance_scale=7.5, num_inference_steps=25, seed=-1,
                  lora_selection="None", *args, **kwargs) -> int:
    """Duration estimator matching generate_image_to_image's signature.

    Img2img always runs at the configured default resolution (inputs are
    resized in generate_image_to_image), so use that for the estimate.
    """
    return calculate_duration(prompt, model_config.default_width,
                              model_config.default_height,
                              num_inference_steps, lora_selection)


def load_models():
    """Load models on GPU with ZeroGPU optimization.

    Not decorated with @spaces.GPU itself: it is only ever called from inside
    the already-GPU-decorated generate functions, so CUDA is available here.
    """
    global txt2img_pipe, img2img_pipe
    if txt2img_pipe is None:
        logger.info("Loading SDXL text-to-image model on ZeroGPU...")
        txt2img_pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        )
        txt2img_pipe.scheduler = DPMSolverMultistepScheduler.from_config(txt2img_pipe.scheduler.config)
        txt2img_pipe.to("cuda")
        txt2img_pipe.enable_attention_slicing()
        txt2img_pipe.enable_vae_slicing()
    if img2img_pipe is None:
        logger.info("Loading SDXL image-to-image model on ZeroGPU...")
        img2img_pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        )
        img2img_pipe.scheduler = DPMSolverMultistepScheduler.from_config(img2img_pipe.scheduler.config)
        img2img_pipe.to("cuda")
        img2img_pipe.enable_attention_slicing()
        img2img_pipe.enable_vae_slicing()


def clear_memory():
    """Clear memory cache for ZeroGPU"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _reset_pipeline_loras(pipeline):
    """Best-effort removal of LoRA weights left over from a previous request.

    The pipelines are module-level singletons, so without this a LoRA chosen
    by one user would keep affecting every later generation.
    """
    try:
        pipeline.unload_lora_weights()
    except Exception:
        # No LoRA loaded (or pipeline doesn't support it) — nothing to do.
        pass


@spaces.GPU(duration=_t2i_duration)
def generate_text_to_image(prompt, negative_prompt="", width=1024, height=1024,
                           guidance_scale=7.5, num_inference_steps=25, seed=-1,
                           lora_selection="None", lora_weight=1.0,
                           custom_lora_repo="", custom_lora_filename=""):
    """Generate image from text prompt with LoRA support using ZeroGPU.

    Returns a (PIL.Image | None, info-string) pair; on any failure the image
    is None and the info string starts with "Error:".
    """
    if not prompt or not prompt.strip():
        return None, "Error: Prompt cannot be empty"

    # Validate parameters
    if width < 256 or width > 2048 or height < 256 or height > 2048:
        return None, f"Error: Width and height must be between 256 and 2048, got {width}x{height}"
    if guidance_scale < 0.1 or guidance_scale > 20:
        return None, f"Error: Guidance scale must be between 0.1 and 20, got {guidance_scale}"
    if num_inference_steps < 1 or num_inference_steps > 100:
        return None, f"Error: Inference steps must be between 1 and 100, got {num_inference_steps}"

    try:
        # Load models on GPU
        load_models()

        # Generate seed if needed
        if seed == -1:
            seed = torch.randint(0, model_config.max_seed, (1,)).item()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Handle LoRA loading
        current_pipe = txt2img_pipe
        _reset_pipeline_loras(current_pipe)
        lora_info = ""
        lightning_used = False

        # Load custom LoRA if specified
        if custom_lora_repo and custom_lora_filename:
            if "civitai.com" in custom_lora_repo:
                model_id = custom_lora_repo.split("/")[-1]
                lora_path = lora_manager.load_lora_from_civitai(model_id, lora_weight)
            else:
                lora_path = lora_manager.load_lora_from_hub(custom_lora_repo, custom_lora_filename, lora_weight)
            if lora_path:
                lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                lora_info = f", Custom LoRA: {custom_lora_repo}/{custom_lora_filename}"
        # Load predefined LoRA if selected
        elif lora_selection != "None":
            popular_loras = lora_manager.get_popular_loras()
            selected_lora = next((lora for lora in popular_loras if lora["name"] == lora_selection), None)
            if selected_lora and selected_lora["repo_id"]:
                lora_path = lora_manager.load_lora_from_hub(selected_lora["repo_id"], selected_lora["filename"], lora_weight)
                if lora_path:
                    lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                    lora_info = f", LoRA: {lora_selection}"
                    # Lightning LoRAs are distilled for a fixed low step count;
                    # override whatever the user chose.
                    if selected_lora.get("type") == "lightning":
                        if "4step" in selected_lora["filename"]:
                            num_inference_steps = 4
                        elif "8step" in selected_lora["filename"]:
                            num_inference_steps = 8
                        lightning_used = True

        # Generate image
        result = current_pipe(
            prompt=prompt.strip(),
            negative_prompt=negative_prompt.strip(),
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        )

        lightning_info = " (Lightning optimized)" if lightning_used else ""
        logger.info(f"Text-to-image generation completed on ZeroGPU")
        return result.images[0], f"Generated with seed: {seed}, steps: {num_inference_steps}{lightning_info}{lora_info}"

    except Exception as e:
        clear_memory()
        logger.error(f"Text-to-image generation failed: {e}")
        return None, f"Error: {str(e)}"


@spaces.GPU(duration=_i2i_duration)
def generate_image_to_image(prompt, input_image, negative_prompt="", strength=0.7,
                            guidance_scale=7.5, num_inference_steps=25, seed=-1,
                            lora_selection="None", lora_weight=1.0,
                            custom_lora_repo="", custom_lora_filename=""):
    """Generate image from input image and prompt with LoRA support using ZeroGPU.

    Returns a (PIL.Image | None, info-string) pair; on any failure the image
    is None and the info string starts with "Error:".
    """
    if input_image is None:
        return None, "Error: Please provide an input image"
    if not prompt or not prompt.strip():
        return None, "Error: Prompt cannot be empty"

    # Validate parameters
    if not 0.1 <= strength <= 1.0:
        return None, f"Error: Strength must be between 0.1 and 1.0, got {strength}"
    if guidance_scale < 0.1 or guidance_scale > 20:
        return None, f"Error: Guidance scale must be between 0.1 and 20, got {guidance_scale}"
    if num_inference_steps < 1 or num_inference_steps > 100:
        return None, f"Error: Inference steps must be between 1 and 100, got {num_inference_steps}"

    try:
        # Load models on GPU
        load_models()

        # Generate seed if needed
        if seed == -1:
            seed = torch.randint(0, model_config.max_seed, (1,)).item()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Resize input image if needed.
        # NOTE(review): this forces a square resize and does not preserve the
        # input's aspect ratio — confirm this is the intended behavior.
        target_size = (model_config.default_width, model_config.default_height)
        if input_image.size != target_size:
            input_image = input_image.resize(target_size)

        # Handle LoRA loading
        current_pipe = img2img_pipe
        _reset_pipeline_loras(current_pipe)
        lora_info = ""

        # Load custom LoRA if specified
        if custom_lora_repo and custom_lora_filename:
            lora_path = lora_manager.load_lora_from_hub(custom_lora_repo, custom_lora_filename, lora_weight)
            if lora_path:
                lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                lora_info = f", Custom LoRA: {custom_lora_repo}/{custom_lora_filename}"
        # Load predefined LoRA if selected
        elif lora_selection != "None":
            popular_loras = lora_manager.get_popular_loras()
            selected_lora = next((lora for lora in popular_loras if lora["name"] == lora_selection), None)
            if selected_lora and selected_lora["repo_id"]:
                lora_path = lora_manager.load_lora_from_hub(selected_lora["repo_id"], selected_lora["filename"], lora_weight)
                if lora_path:
                    lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                    lora_info = f", LoRA: {lora_selection}"

        # Generate image
        image = current_pipe(
            prompt=prompt.strip(),
            image=input_image,
            negative_prompt=negative_prompt.strip(),
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]

        logger.info(f"Image-to-image generation completed on ZeroGPU")
        return image, f"Generated with seed: {seed}, strength: {strength}{lora_info}"

    except Exception as e:
        clear_memory()
        logger.error(f"Image-to-image generation failed: {e}")
        return None, f"Error: {str(e)}"


def image_to_base64(image):
    """Convert PIL image to base64 string (PNG data-URL), or None if no image."""
    if image is None:
        return None
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"


# API endpoint functions
def api_text_to_image(prompt, negative_prompt="", width=1024, height=1024,
                      guidance_scale=7.5, num_inference_steps=25, seed=-1,
                      lora_repo="", lora_filename="", lora_weight=1.0):
    """API endpoint for text-to-image generation.

    Returns a JSON-serializable dict: {"success", "image" (data-URL), "info"}
    on success, {"success": False, "error"} on failure.
    """
    image, info = generate_text_to_image(
        prompt, negative_prompt, width, height, guidance_scale,
        num_inference_steps, seed, "None", lora_weight, lora_repo, lora_filename
    )
    if image:
        return {
            "success": True,
            "image": image_to_base64(image),
            "info": info
        }
    else:
        return {
            "success": False,
            "error": info
        }


def api_image_to_image(prompt, image_data, negative_prompt="", strength=0.7,
                       guidance_scale=7.5, num_inference_steps=25, seed=-1,
                       lora_repo="", lora_filename="", lora_weight=1.0):
    """API endpoint for image-to-image generation.

    `image_data` is a base64 string, optionally a full "data:image/..." URL.
    Returns the same dict shape as api_text_to_image.
    """
    try:
        # Decode base64 image
        if image_data.startswith('data:image'):
            image_data = image_data.split(',')[1]
        image_bytes = base64.b64decode(image_data)
        input_image = Image.open(io.BytesIO(image_bytes))

        image, info = generate_image_to_image(
            prompt, input_image, negative_prompt, strength, guidance_scale,
            num_inference_steps, seed, "None", lora_weight, lora_repo, lora_filename
        )
        if image:
            return {
                "success": True,
                "image": image_to_base64(image),
                "info": info
            }
        else:
            return {
                "success": False,
                "error": info
            }
    except Exception as e:
        # SOURCE was truncated mid-return here; completed to mirror
        # api_text_to_image's failure shape.
        return {
            "success": False,
            "error": str(e)
        }