| import gradio as gr |
| import torch |
| import spaces |
| from diffusers import ( |
| DiffusionPipeline, |
| AutoPipelineForText2Image, |
| AutoPipelineForImage2Image, |
| DPMSolverMultistepScheduler |
| ) |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from huggingface_hub import hf_hub_download, list_models |
| import numpy as np |
| from PIL import Image |
| import io |
| import base64 |
| import json |
| import os |
| import gc |
| import logging |
| import requests |
| import tempfile |
| from typing import Optional, List, Dict, Any |
| import random |
|
|
| |
class ModelConfig:
    """Central configuration: model repository ids plus default generation settings."""

    def __init__(self):
        # Wan model repositories (text-to-video / image-to-video variants).
        self.t2v_model = "Wan-AI/Wan2.1-T2V-14B"
        self.i2v_model = "Wan-AI/Wan2.1-I2V-14B-480P"
        self.ti2v_model = "Wan-AI/Wan2.1-TI2V-14B-720P"
        self.animate_model = "Wan-AI/Wan2.1-Animate-14B"
        # Default generation parameters used by the UI and resize logic.
        self.default_width = self.default_height = 1024
        self.default_steps = 25
        self.default_guidance = 7.5
        # Largest seed value: maximum unsigned 32-bit integer.
        self.max_seed = 2**32 - 1
|
|
| |
class LoRAManager:
    """Downloads LoRA weight files (from the Hugging Face Hub or CivitAI)
    and applies them to diffusers pipelines, caching local paths per session."""

    def __init__(self):
        # Cache of "<repo_id>/<filename>" (or "civitai/<model_id>") -> local file path.
        self.loaded_loras: Dict[str, str] = {}
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Setup logging configuration"""
        logger = logging.getLogger("LoRAManager")
        logger.setLevel(logging.INFO)
        # Guard against attaching duplicate handlers if the manager is re-created.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def get_popular_loras(self) -> List[Dict[str, str]]:
        """Get a list of popular LoRA models including Lightning LoRA"""
        return [
            {"name": "None", "repo_id": "", "filename": "", "type": "none"},
            {"name": "Lightning LoRA (4-step)", "repo_id": "ByteDance/SDXL-Lightning", "filename": "sdxl_lightning_4step_lora.safetensors", "type": "lightning"},
            {"name": "Lightning LoRA (8-step)", "repo_id": "ByteDance/SDXL-Lightning", "filename": "sdxl_lightning_8step_lora.safetensors", "type": "lightning"},
            {"name": "Anime Style", "repo_id": "Linaqruf/anime-detailer-xl-lora", "filename": "anime-detailer-xl.safetensors", "type": "style"},
            {"name": "Pixel Art", "repo_id": "nerijs/pixel-art-xl", "filename": "pixel-art-xl.safetensors", "type": "style"},
            {"name": "Watercolor", "repo_id": "SG161222/RealVisXL_V4.0_Lightning", "filename": "RealVisXL_V4.0_Lightning.safetensors", "type": "style"},
            {"name": "Photorealistic", "repo_id": "SG161222/RealVisXL_V4.0", "filename": "RealVisXL_V4.0.safetensors", "type": "style"},
            {"name": "Relighting LoRA", "repo_id": "Wan-AI/Wan2.2-Animate-14B", "filename": "relighting_lora.safetensors", "type": "relighting"},
        ]

    def load_lora_from_hub(self, repo_id: str, filename: str, weight: float = 1.0) -> Optional[str]:
        """Download a LoRA from the Hugging Face Hub.

        Returns the local file path, or None on failure / empty arguments.
        Results are cached so repeated requests skip the download.
        """
        if not repo_id or not filename:
            return None

        # BUG FIX: the cache key previously omitted the filename, so different
        # files from the same repo collided on a single cache entry.
        key = f"{repo_id}/{filename}"
        if key in self.loaded_loras:
            return self.loaded_loras[key]

        try:
            lora_path = hf_hub_download(repo_id=repo_id, filename=filename)
            self.loaded_loras[key] = lora_path
            self.logger.info(f"Downloaded LoRA: {key}")
            return lora_path
        except Exception as e:
            self.logger.error(f"Failed to load LoRA {key}: {e}")
            return None

    def load_lora_from_civitai(self, model_id: str, weight: float = 1.0) -> Optional[str]:
        """Download a LoRA from CivitAI by model-version id.

        Returns the local temp-file path, or None on failure / empty id.
        """
        if not model_id:
            return None

        key = f"civitai/{model_id}"
        if key in self.loaded_loras:
            return self.loaded_loras[key]

        try:
            civitai_url = f"https://civitai.com/api/download/models/{model_id}"
            with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as tmp_file:
                self.logger.info(f"Downloading LoRA from CivitAI: {model_id}")
                headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
                # Stream with a timeout so a stalled download cannot hang the
                # worker, and so multi-GB files are not buffered in memory.
                with requests.get(civitai_url, headers=headers, stream=True, timeout=(10, 300)) as response:
                    response.raise_for_status()
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        tmp_file.write(chunk)
            self.loaded_loras[key] = tmp_file.name
            self.logger.info(f"Downloaded CivitAI LoRA: {model_id}")
            return tmp_file.name
        except Exception as e:
            self.logger.error(f"Failed to load LoRA from CivitAI {model_id}: {e}")
            return None

    def apply_lora_to_pipeline(self, pipeline, lora_path: str, weight: float = 1.0) -> bool:
        """Load LoRA weights into the given pipeline; True on success.

        NOTE(review): `weight` is only logged here, not applied — diffusers
        applies LoRA scale at inference time (cross_attention_kwargs) or via
        set_adapters; confirm which mechanism the callers intend.
        """
        try:
            if lora_path and os.path.exists(lora_path):
                pipeline.load_lora_weights(lora_path)
                self.logger.info(f"Applied LoRA: {lora_path} with weight {weight}")
                return True
            # BUG FIX: previously fell through and implicitly returned None
            # for a missing path; report the failure explicitly instead.
            self.logger.error(f"LoRA path does not exist: {lora_path}")
            return False
        except Exception as e:
            self.logger.error(f"Failed to apply LoRA: {e}")
            return False
|
|
| |
# Module-level singletons shared by every handler below.
model_config = ModelConfig()
lora_manager = LoRAManager()
# NOTE(review): basicConfig adds a root stream handler at INFO; combined with
# LoRAManager's own handler, LoRAManager records may be emitted twice.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
|
| |
# Lazily-initialized SDXL pipelines; populated on first use by load_models().
txt2img_pipe = None
img2img_pipe = None
|
|
def calculate_duration(prompt: str, width: int, height: int, num_inference_steps: int, lora_selection: str) -> int:
    """
    Calculate dynamic duration based on generation parameters for ZeroGPU optimization.
    Base duration: ~3.75 seconds per step for SDXL on H200
    """
    SECONDS_PER_STEP = 3.75
    step_cost = num_inference_steps * SECONDS_PER_STEP
    # Scale by 20% of how far the pixel count deviates from the 1024x1024 baseline.
    pixel_ratio = (width * height) / (1024 * 1024)
    resolution_cost = step_cost * (pixel_ratio - 1) * 0.2
    # Loading/applying a LoRA adds a roughly fixed 15s; +20s general slack.
    lora_cost = 0 if lora_selection == "None" else 15
    estimate = step_cost + resolution_cost + lora_cost + 20
    # Clamp to ZeroGPU's practical 60..300 second window.
    return min(300, max(60, int(estimate)))
|
|
@spaces.GPU(duration=120)
def load_models():
    """Lazily load the SDXL text2img and img2img pipelines onto the GPU.

    BUG FIX: this was decorated with ``duration=calculate_duration``, but
    spaces invokes the duration callable with the decorated function's own
    arguments — load_models() takes none, so every call would raise a
    TypeError inside the decorator. A fixed 120s budget covers the one-time
    load; subsequent calls are no-ops and return immediately.
    """
    global txt2img_pipe, img2img_pipe

    if txt2img_pipe is None:
        logger.info("Loading SDXL text-to-image model on ZeroGPU...")
        txt2img_pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        )
        txt2img_pipe.scheduler = DPMSolverMultistepScheduler.from_config(txt2img_pipe.scheduler.config)
        txt2img_pipe.to("cuda")
        # Slicing trades a little speed for a lower peak-VRAM footprint.
        txt2img_pipe.enable_attention_slicing()
        txt2img_pipe.enable_vae_slicing()

    if img2img_pipe is None:
        logger.info("Loading SDXL image-to-image model on ZeroGPU...")
        img2img_pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16"
        )
        img2img_pipe.scheduler = DPMSolverMultistepScheduler.from_config(img2img_pipe.scheduler.config)
        img2img_pipe.to("cuda")
        img2img_pipe.enable_attention_slicing()
        img2img_pipe.enable_vae_slicing()
|
|
def clear_memory():
    """Run Python GC and, when CUDA is present, release cached GPU memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
def _t2i_duration(prompt, negative_prompt="", width=1024, height=1024,
                  guidance_scale=7.5, num_inference_steps=25, seed=-1,
                  lora_selection="None", lora_weight=1.0,
                  custom_lora_repo="", custom_lora_filename=""):
    """Duration adapter for @spaces.GPU.

    BUG FIX: spaces calls the duration callable with the decorated function's
    own arguments; passing calculate_duration directly (as before) bound
    negative_prompt to its width parameter and crashed. This shim forwards
    only the relevant arguments in the right order.
    """
    return calculate_duration(prompt, width, height, num_inference_steps, lora_selection)


@spaces.GPU(duration=_t2i_duration)
def generate_text_to_image(prompt, negative_prompt="", width=1024, height=1024, guidance_scale=7.5, num_inference_steps=25, seed=-1, lora_selection="None", lora_weight=1.0, custom_lora_repo="", custom_lora_filename=""):
    """Generate image from text prompt with LoRA support using ZeroGPU.

    Returns (image, status): a PIL image and a description on success, or
    (None, "Error: ...") on validation failure or generation error.
    """
    if not prompt or not prompt.strip():
        return None, "Error: Prompt cannot be empty"

    # Validate parameters against the ranges the UI exposes.
    if width < 256 or width > 2048 or height < 256 or height > 2048:
        return None, f"Error: Width and height must be between 256 and 2048, got {width}x{height}"
    if guidance_scale < 0.1 or guidance_scale > 20:
        return None, f"Error: Guidance scale must be between 0.1 and 20, got {guidance_scale}"
    if num_inference_steps < 1 or num_inference_steps > 100:
        return None, f"Error: Inference steps must be between 1 and 100, got {num_inference_steps}"

    try:
        load_models()

        # Seed of -1 means "pick a random seed"; report the chosen value.
        if seed == -1:
            seed = torch.randint(0, model_config.max_seed, (1,)).item()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        current_pipe = txt2img_pipe
        lora_info = ""
        lightning_used = False

        # BUG FIX: drop any LoRA left loaded by a previous request so
        # adapters from earlier generations don't silently stack.
        try:
            current_pipe.unload_lora_weights()
        except Exception:
            pass

        # Custom LoRA: either an HF repo/filename pair or a CivitAI URL.
        if custom_lora_repo and custom_lora_filename:
            if "civitai.com" in custom_lora_repo:
                # Strip query parameters before taking the trailing model id.
                model_id = custom_lora_repo.split("?")[0].split("/")[-1]
                lora_path = lora_manager.load_lora_from_civitai(model_id, lora_weight)
            else:
                lora_path = lora_manager.load_lora_from_hub(custom_lora_repo, custom_lora_filename, lora_weight)
            if lora_path:
                lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                lora_info = f", Custom LoRA: {custom_lora_repo}/{custom_lora_filename}"

        # Curated LoRA chosen from the dropdown.
        elif lora_selection != "None":
            popular_loras = lora_manager.get_popular_loras()
            selected_lora = next((lora for lora in popular_loras if lora["name"] == lora_selection), None)
            if selected_lora and selected_lora["repo_id"]:
                lora_path = lora_manager.load_lora_from_hub(selected_lora["repo_id"], selected_lora["filename"], lora_weight)
                if lora_path:
                    lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                    lora_info = f", LoRA: {lora_selection}"
                    # Lightning LoRAs are distilled for a fixed tiny step count.
                    if selected_lora.get("type") == "lightning":
                        if "4step" in selected_lora["filename"]:
                            num_inference_steps = 4
                        elif "8step" in selected_lora["filename"]:
                            num_inference_steps = 8
                        lightning_used = True
                        # NOTE(review): Lightning checkpoints are typically run
                        # with guidance_scale ~0-1; consider lowering it here.

        result = current_pipe(
            prompt=prompt.strip(),
            negative_prompt=negative_prompt.strip(),
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        )

        lightning_info = " (Lightning optimized)" if lightning_used else ""
        logger.info("Text-to-image generation completed on ZeroGPU")
        return result.images[0], f"Generated with seed: {seed}, steps: {num_inference_steps}{lightning_info}{lora_info}"

    except Exception as e:
        clear_memory()
        logger.error(f"Text-to-image generation failed: {e}")
        return None, f"Error: {str(e)}"
|
|
def _i2i_duration(prompt, input_image=None, negative_prompt="", strength=0.7,
                  guidance_scale=7.5, num_inference_steps=25, seed=-1,
                  lora_selection="None", lora_weight=1.0,
                  custom_lora_repo="", custom_lora_filename=""):
    """Duration adapter for @spaces.GPU.

    BUG FIX: spaces calls the duration callable with the decorated function's
    own arguments; calculate_duration's parameter order differs, so passing it
    directly bound input_image/negative_prompt to width/height and crashed.
    img2img always runs at the configured default resolution (inputs are
    resized below), so that size is used for the estimate.
    """
    return calculate_duration(prompt, model_config.default_width, model_config.default_height,
                              num_inference_steps, lora_selection)


@spaces.GPU(duration=_i2i_duration)
def generate_image_to_image(prompt, input_image, negative_prompt="", strength=0.7, guidance_scale=7.5, num_inference_steps=25, seed=-1, lora_selection="None", lora_weight=1.0, custom_lora_repo="", custom_lora_filename=""):
    """Generate image from input image and prompt with LoRA support using ZeroGPU.

    Returns (image, status): a PIL image and a description on success, or
    (None, "Error: ...") on validation failure or generation error.
    """
    if input_image is None:
        return None, "Error: Please provide an input image"
    if not prompt or not prompt.strip():
        return None, "Error: Prompt cannot be empty"

    # Validate parameters against the ranges the UI exposes.
    if not 0.1 <= strength <= 1.0:
        return None, f"Error: Strength must be between 0.1 and 1.0, got {strength}"
    if guidance_scale < 0.1 or guidance_scale > 20:
        return None, f"Error: Guidance scale must be between 0.1 and 20, got {guidance_scale}"
    if num_inference_steps < 1 or num_inference_steps > 100:
        return None, f"Error: Inference steps must be between 1 and 100, got {num_inference_steps}"

    try:
        load_models()

        # Seed of -1 means "pick a random seed"; report the chosen value.
        if seed == -1:
            seed = torch.randint(0, model_config.max_seed, (1,)).item()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Resize to the model's native resolution.
        # NOTE(review): a plain resize distorts non-square inputs; consider an
        # aspect-preserving resize plus crop/pad if fidelity matters.
        target_size = (model_config.default_width, model_config.default_height)
        if input_image.size != target_size:
            input_image = input_image.resize(target_size)

        current_pipe = img2img_pipe
        lora_info = ""

        # BUG FIX: drop any LoRA left loaded by a previous request so
        # adapters from earlier generations don't silently stack.
        try:
            current_pipe.unload_lora_weights()
        except Exception:
            pass

        # Custom LoRA. CONSISTENCY FIX: support CivitAI URLs here exactly as
        # generate_text_to_image does (previously a CivitAI URL was passed to
        # the Hub loader and always failed).
        if custom_lora_repo and custom_lora_filename:
            if "civitai.com" in custom_lora_repo:
                # Strip query parameters before taking the trailing model id.
                model_id = custom_lora_repo.split("?")[0].split("/")[-1]
                lora_path = lora_manager.load_lora_from_civitai(model_id, lora_weight)
            else:
                lora_path = lora_manager.load_lora_from_hub(custom_lora_repo, custom_lora_filename, lora_weight)
            if lora_path:
                lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                lora_info = f", Custom LoRA: {custom_lora_repo}/{custom_lora_filename}"

        # Curated LoRA chosen from the dropdown.
        elif lora_selection != "None":
            popular_loras = lora_manager.get_popular_loras()
            selected_lora = next((lora for lora in popular_loras if lora["name"] == lora_selection), None)
            if selected_lora and selected_lora["repo_id"]:
                lora_path = lora_manager.load_lora_from_hub(selected_lora["repo_id"], selected_lora["filename"], lora_weight)
                if lora_path:
                    lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight)
                    lora_info = f", LoRA: {lora_selection}"

        image = current_pipe(
            prompt=prompt.strip(),
            image=input_image,
            negative_prompt=negative_prompt.strip(),
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]

        logger.info("Image-to-image generation completed on ZeroGPU")
        return image, f"Generated with seed: {seed}, strength: {strength}{lora_info}"

    except Exception as e:
        clear_memory()
        logger.error(f"Image-to-image generation failed: {e}")
        return None, f"Error: {str(e)}"
|
|
def image_to_base64(image):
    """Serialize a PIL image to a PNG data-URI string; None passes through."""
    if image is None:
        return None
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
|
|
| |
def api_text_to_image(prompt, negative_prompt="", width=1024, height=1024, guidance_scale=7.5, num_inference_steps=25, seed=-1, lora_repo="", lora_filename="", lora_weight=1.0):
    """API endpoint for text-to-image generation.

    Returns a JSON-serializable dict: {"success": True, "image": <base64 PNG
    data URI>, "info": <status>} on success, otherwise {"success": False,
    "error": <message>}.
    """
    image, info = generate_text_to_image(
        prompt, negative_prompt, width, height, guidance_scale,
        num_inference_steps, seed, "None", lora_weight, lora_repo, lora_filename
    )
    # BUG FIX: use an explicit None check — relying on image truthiness is
    # fragile (e.g. numpy arrays raise on bool(), other image types may
    # define __bool__ unexpectedly).
    if image is not None:
        return {
            "success": True,
            "image": image_to_base64(image),
            "info": info
        }
    return {
        "success": False,
        "error": info
    }
|
|
| def api_image_to_image(prompt, image_data, negative_prompt="", strength=0.7, guidance_scale=7.5, num_inference_steps=25, seed=-1, lora_repo="", lora_filename="", lora_weight=1.0): |
| """API endpoint for image-to-image generation""" |
| try: |
| |
| if image_data.startswith('data:image'): |
| image_data = image_data.split(',')[1] |
| image_bytes = base64.b64decode(image_data) |
| input_image = Image.open(io.BytesIO(image_bytes)) |
| |
| image, info = generate_image_to_image( |
| prompt, input_image, negative_prompt, strength, guidance_scale, |
| num_inference_steps, seed, "None", lora_weight, lora_repo, lora_filename |
| ) |
| if image: |
| return { |
| "success": True, |
| "image": image_to_base64(image), |
| "info": info |
| } |
| else: |
| return { |
| "success": False, |
| "error": info |
| } |
| except Exception as e: |
| return { |
| "success": False, |