# jblast94's picture
# Update app.py
# 982b0f1 verified
import gradio as gr
import torch
import spaces
from diffusers import (
DiffusionPipeline,
AutoPipelineForText2Image,
AutoPipelineForImage2Image,
DPMSolverMultistepScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import hf_hub_download, list_models
import numpy as np
from PIL import Image
import io
import base64
import json
import os
import gc
import logging
import requests
import tempfile
from typing import Optional, List, Dict, Any
import random
# Configuration
class ModelConfig:
    """Central configuration: model repository IDs and generation defaults."""

    def __init__(self):
        # Wan 2.1 model repositories for the various video tasks.
        self.t2v_model = "Wan-AI/Wan2.1-T2V-14B"
        self.i2v_model = "Wan-AI/Wan2.1-I2V-14B-480P"
        self.ti2v_model = "Wan-AI/Wan2.1-TI2V-14B-720P"
        self.animate_model = "Wan-AI/Wan2.1-Animate-14B"
        # Default image-generation parameters.
        self.default_width, self.default_height = 1024, 1024
        self.max_seed = 2**32 - 1  # largest value accepted for a 32-bit seed
        self.default_steps, self.default_guidance = 25, 7.5
# LoRA Manager with ZeroGPU optimizations
class LoRAManager:
    """Download, cache, and apply LoRA weight files (Hub or CivitAI).

    Cache keys are "repo_id/filename" or "civitai/model_id" and map to the
    local file path of the downloaded weights.
    """

    def __init__(self):
        # key -> local path of a previously downloaded LoRA file
        self.loaded_loras = {}
        self.logger = self._setup_logger()

    def _setup_logger(self) -> logging.Logger:
        """Setup logging configuration (idempotent: adds a handler only once)."""
        logger = logging.getLogger("LoRAManager")
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    def get_popular_loras(self) -> List[Dict[str, str]]:
        """Get a list of popular LoRA models including Lightning LoRA."""
        return [
            {"name": "None", "repo_id": "", "filename": "", "type": "none"},
            {"name": "Lightning LoRA (4-step)", "repo_id": "ByteDance/SDXL-Lightning", "filename": "sdxl_lightning_4step_lora.safetensors", "type": "lightning"},
            {"name": "Lightning LoRA (8-step)", "repo_id": "ByteDance/SDXL-Lightning", "filename": "sdxl_lightning_8step_lora.safetensors", "type": "lightning"},
            {"name": "Anime Style", "repo_id": "Linaqruf/anime-detailer-xl-lora", "filename": "anime-detailer-xl.safetensors", "type": "style"},
            {"name": "Pixel Art", "repo_id": "nerijs/pixel-art-xl", "filename": "pixel-art-xl.safetensors", "type": "style"},
            {"name": "Watercolor", "repo_id": "SG161222/RealVisXL_V4.0_Lightning", "filename": "RealVisXL_V4.0_Lightning.safetensors", "type": "style"},
            {"name": "Photorealistic", "repo_id": "SG161222/RealVisXL_V4.0", "filename": "RealVisXL_V4.0.safetensors", "type": "style"},
            {"name": "Relighting LoRA", "repo_id": "Wan-AI/Wan2.2-Animate-14B", "filename": "relighting_lora.safetensors", "type": "relighting"},
        ]

    def load_lora_from_hub(self, repo_id: str, filename: str, weight: float = 1.0) -> Optional[str]:
        """Load LoRA from Hugging Face Hub; returns the local path or None.

        ``weight`` is unused here — it is applied later by
        apply_lora_to_pipeline(); kept for interface compatibility.
        """
        if not repo_id or not filename:
            return None
        # BUG FIX: the cache key previously interpolated the literal string
        # "(unknown)" instead of the filename, so every file from the same
        # repo collided on one cache entry (and logs were useless).
        key = f"{repo_id}/{filename}"
        if key in self.loaded_loras:
            return self.loaded_loras[key]
        try:
            lora_path = hf_hub_download(repo_id=repo_id, filename=filename)
            self.loaded_loras[key] = lora_path
            self.logger.info(f"Downloaded LoRA: {repo_id}/{filename}")
            return lora_path
        except Exception as e:
            self.logger.error(f"Failed to load LoRA {repo_id}/{filename}: {e}")
            return None

    def load_lora_from_civitai(self, model_id: str, weight: float = 1.0) -> Optional[str]:
        """Load LoRA from CivitAI using direct download; returns path or None."""
        if not model_id:
            return None
        key = f"civitai/{model_id}"
        if key in self.loaded_loras:
            return self.loaded_loras[key]
        try:
            civitai_url = f"https://civitai.com/api/download/models/{model_id}"
            # delete=False: the file must outlive this context manager so the
            # pipeline can read it later.
            with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as tmp_file:
                self.logger.info(f"Downloading LoRA from CivitAI: {model_id}")
                headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
                # timeout added: an unresponsive server should not hang the app forever
                response = requests.get(civitai_url, headers=headers, timeout=300)
                response.raise_for_status()
                tmp_file.write(response.content)
                self.loaded_loras[key] = tmp_file.name
                self.logger.info(f"Downloaded CivitAI LoRA: {model_id}")
                return tmp_file.name
        except Exception as e:
            self.logger.error(f"Failed to load LoRA from CivitAI {model_id}: {e}")
            return None

    def apply_lora_to_pipeline(self, pipeline, lora_path: str, weight: float = 1.0) -> bool:
        """Apply LoRA weights to a pipeline; returns True on success.

        NOTE(review): ``weight`` is only logged, not passed to
        ``load_lora_weights`` — presumably intended for an adapter-scale call;
        confirm before relying on it.
        """
        try:
            if lora_path and os.path.exists(lora_path):
                pipeline.load_lora_weights(lora_path)
                self.logger.info(f"Applied LoRA: {lora_path} with weight {weight}")
                return True
            # BUG FIX: previously fell through and returned None implicitly.
            return False
        except Exception as e:
            self.logger.error(f"Failed to apply LoRA: {e}")
            return False
# Global configuration and managers: module-level singletons shared by every request.
model_config = ModelConfig()
lora_manager = LoRAManager()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Global variables for models (loaded on GPU when needed).
txt2img_pipe = None  # SDXL text-to-image pipeline; populated lazily by load_models()
img2img_pipe = None  # SDXL image-to-image pipeline; populated lazily by load_models()
def calculate_duration(prompt="", width=1024, height=1024,
                       num_inference_steps=25, lora_selection="None",
                       *args, **kwargs) -> int:
    """Estimate a ZeroGPU reservation (seconds) from generation parameters.

    Base duration: ~3.75 seconds per step for SDXL on H200; clamped to 60-300 s.

    BUG FIX: ``spaces.GPU(duration=...)`` invokes this callable with the
    decorated function's own arguments.  The decorated functions have
    different signatures (``load_models`` takes no arguments at all, and
    ``generate_image_to_image`` has ``input_image`` in the position read here
    as ``width``), so the original strict 5-parameter signature either raised
    a TypeError or crashed on non-numeric arithmetic.  All parameters now
    have defaults, extra arguments are ignored, and non-numeric values fall
    back to the defaults.
    """
    # Coerce anything that isn't a number back to a sane default.
    if not isinstance(num_inference_steps, (int, float)):
        num_inference_steps = 25
    if not isinstance(width, (int, float)):
        width = 1024
    if not isinstance(height, (int, float)):
        height = 1024
    base_duration = num_inference_steps * 3.75
    # Scale the step cost by how much larger than 1024x1024 the canvas is.
    resolution_factor = (width * height) / (1024 * 1024)
    resolution_overhead = base_duration * (resolution_factor - 1) * 0.2
    # One-off cost of downloading/applying a LoRA.
    lora_overhead = 15 if lora_selection != "None" else 0
    total_duration = base_duration + resolution_overhead + lora_overhead + 20
    return max(60, min(300, int(total_duration)))  # clamp between 60-300 seconds
@spaces.GPU(duration=calculate_duration)
def load_models():
    """Lazily initialize the shared SDXL pipelines on the ZeroGPU device.

    Populates the module-level ``txt2img_pipe`` / ``img2img_pipe`` singletons
    on first call; later calls are no-ops.
    """
    global txt2img_pipe, img2img_pipe

    def _prepare(pipe):
        # Shared post-load setup: multistep scheduler, CUDA placement,
        # and memory-saving slicing modes.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()
        return pipe

    if txt2img_pipe is None:
        logger.info("Loading SDXL text-to-image model on ZeroGPU...")
        txt2img_pipe = _prepare(DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ))

    if img2img_pipe is None:
        logger.info("Loading SDXL image-to-image model on ZeroGPU...")
        img2img_pipe = _prepare(AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ))
def clear_memory():
    """Release Python garbage and, when CUDA is present, the GPU cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
@spaces.GPU(duration=calculate_duration)
def generate_text_to_image(prompt, negative_prompt="", width=1024, height=1024,
                           guidance_scale=7.5, num_inference_steps=25, seed=-1,
                           lora_selection="None", lora_weight=1.0,
                           custom_lora_repo="", custom_lora_filename=""):
    """Generate an image from a text prompt with LoRA support using ZeroGPU.

    Returns:
        (PIL.Image | None, str): the generated image (None on failure) and a
        human-readable status/info string.
    """
    # --- Input validation (fail fast, before touching the GPU) ---
    if not prompt or not prompt.strip():
        return None, "Error: Prompt cannot be empty"
    if width < 256 or width > 2048 or height < 256 or height > 2048:
        return None, f"Error: Width and height must be between 256 and 2048, got {width}x{height}"
    if guidance_scale < 0.1 or guidance_scale > 20:
        return None, f"Error: Guidance scale must be between 0.1 and 20, got {guidance_scale}"
    if num_inference_steps < 1 or num_inference_steps > 100:
        return None, f"Error: Inference steps must be between 1 and 100, got {num_inference_steps}"

    lora_applied = False
    try:
        # Load models on GPU (lazy; no-op once loaded)
        load_models()

        # Pick a random seed when the caller asked for one (-1)
        if seed == -1:
            seed = torch.randint(0, model_config.max_seed, (1,)).item()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        current_pipe = txt2img_pipe
        lora_info = ""
        lightning_used = False

        # Custom LoRA (Hub repo or CivitAI URL) takes priority over presets.
        if custom_lora_repo and custom_lora_filename:
            if "civitai.com" in custom_lora_repo:
                model_id = custom_lora_repo.split("/")[-1]
                lora_path = lora_manager.load_lora_from_civitai(model_id, lora_weight)
            else:
                lora_path = lora_manager.load_lora_from_hub(custom_lora_repo, custom_lora_filename, lora_weight)
            if lora_path:
                lora_applied = bool(lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight))
                lora_info = f", Custom LoRA: {custom_lora_repo}/{custom_lora_filename}"
        # Otherwise load a predefined LoRA if one was selected.
        elif lora_selection != "None":
            popular_loras = lora_manager.get_popular_loras()
            selected_lora = next((lora for lora in popular_loras if lora["name"] == lora_selection), None)
            if selected_lora and selected_lora["repo_id"]:
                lora_path = lora_manager.load_lora_from_hub(selected_lora["repo_id"], selected_lora["filename"], lora_weight)
                if lora_path:
                    lora_applied = bool(lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight))
                    lora_info = f", LoRA: {lora_selection}"
                    # Lightning LoRAs are distilled for a fixed low step count;
                    # override the user's step setting accordingly.
                    if selected_lora.get("type") == "lightning":
                        if "4step" in selected_lora["filename"]:
                            num_inference_steps = 4
                        elif "8step" in selected_lora["filename"]:
                            num_inference_steps = 8
                        lightning_used = True

        # Generate image
        result = current_pipe(
            prompt=prompt.strip(),
            negative_prompt=negative_prompt.strip(),
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        )
        lightning_info = " (Lightning optimized)" if lightning_used else ""
        logger.info("Text-to-image generation completed on ZeroGPU")
        return result.images[0], f"Generated with seed: {seed}, steps: {num_inference_steps}{lightning_info}{lora_info}"
    except Exception as e:
        clear_memory()
        logger.error(f"Text-to-image generation failed: {e}")
        return None, f"Error: {str(e)}"
    finally:
        # BUG FIX: the pipeline is a shared module-level singleton, so a LoRA
        # left loaded here previously bled into every subsequent request
        # (including ones with lora_selection="None").  Unload best-effort.
        if lora_applied:
            try:
                txt2img_pipe.unload_lora_weights()
            except Exception as unload_err:
                logger.warning(f"Failed to unload LoRA weights: {unload_err}")
@spaces.GPU(duration=calculate_duration)
def generate_image_to_image(prompt, input_image, negative_prompt="", strength=0.7,
                            guidance_scale=7.5, num_inference_steps=25, seed=-1,
                            lora_selection="None", lora_weight=1.0,
                            custom_lora_repo="", custom_lora_filename=""):
    """Generate an image from an input image and prompt with LoRA support (ZeroGPU).

    Returns:
        (PIL.Image | None, str): the generated image (None on failure) and a
        human-readable status/info string.
    """
    # --- Input validation (fail fast, before touching the GPU) ---
    if input_image is None:
        return None, "Error: Please provide an input image"
    if not prompt or not prompt.strip():
        return None, "Error: Prompt cannot be empty"
    if not 0.1 <= strength <= 1.0:
        return None, f"Error: Strength must be between 0.1 and 1.0, got {strength}"
    if guidance_scale < 0.1 or guidance_scale > 20:
        return None, f"Error: Guidance scale must be between 0.1 and 20, got {guidance_scale}"
    if num_inference_steps < 1 or num_inference_steps > 100:
        return None, f"Error: Inference steps must be between 1 and 100, got {num_inference_steps}"

    lora_applied = False
    try:
        # Load models on GPU (lazy; no-op once loaded)
        load_models()

        # Pick a random seed when the caller asked for one (-1)
        if seed == -1:
            seed = torch.randint(0, model_config.max_seed, (1,)).item()
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # Resize the input to the pipeline's working resolution.
        # NOTE(review): this ignores aspect ratio; non-square inputs distort.
        target_size = (model_config.default_width, model_config.default_height)
        if input_image.size != target_size:
            input_image = input_image.resize(target_size)

        current_pipe = img2img_pipe
        lora_info = ""

        # Custom LoRA takes priority over presets.
        # CONSISTENCY FIX: generate_text_to_image already handled CivitAI URLs;
        # this function previously treated them as Hub repo IDs and failed.
        if custom_lora_repo and custom_lora_filename:
            if "civitai.com" in custom_lora_repo:
                model_id = custom_lora_repo.split("/")[-1]
                lora_path = lora_manager.load_lora_from_civitai(model_id, lora_weight)
            else:
                lora_path = lora_manager.load_lora_from_hub(custom_lora_repo, custom_lora_filename, lora_weight)
            if lora_path:
                lora_applied = bool(lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight))
                lora_info = f", Custom LoRA: {custom_lora_repo}/{custom_lora_filename}"
        # Otherwise load a predefined LoRA if one was selected.
        elif lora_selection != "None":
            popular_loras = lora_manager.get_popular_loras()
            selected_lora = next((lora for lora in popular_loras if lora["name"] == lora_selection), None)
            if selected_lora and selected_lora["repo_id"]:
                lora_path = lora_manager.load_lora_from_hub(selected_lora["repo_id"], selected_lora["filename"], lora_weight)
                if lora_path:
                    lora_applied = bool(lora_manager.apply_lora_to_pipeline(current_pipe, lora_path, lora_weight))
                    lora_info = f", LoRA: {lora_selection}"

        # Generate image
        image = current_pipe(
            prompt=prompt.strip(),
            image=input_image,
            negative_prompt=negative_prompt.strip(),
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]
        logger.info("Image-to-image generation completed on ZeroGPU")
        return image, f"Generated with seed: {seed}, strength: {strength}{lora_info}"
    except Exception as e:
        clear_memory()
        logger.error(f"Image-to-image generation failed: {e}")
        return None, f"Error: {str(e)}"
    finally:
        # BUG FIX: unload best-effort so a LoRA applied to the shared global
        # pipeline does not leak into subsequent requests.
        if lora_applied:
            try:
                img2img_pipe.unload_lora_weights()
            except Exception as unload_err:
                logger.warning(f"Failed to unload LoRA weights: {unload_err}")
def image_to_base64(image):
    """Encode a PIL image as a ``data:image/png;base64,...`` URI.

    Returns None when given None, so callers can pass through failures.
    """
    if image is None:
        return None
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
# API endpoint functions
def api_text_to_image(prompt, negative_prompt="", width=1024, height=1024, guidance_scale=7.5, num_inference_steps=25, seed=-1, lora_repo="", lora_filename="", lora_weight=1.0):
"""API endpoint for text-to-image generation"""
image, info = generate_text_to_image(
prompt, negative_prompt, width, height, guidance_scale,
num_inference_steps, seed, "None", lora_weight, lora_repo, lora_filename
)
if image:
return {
"success": True,
"image": image_to_base64(image),
"info": info
}
else:
return {
"success": False,
"error": info
}
def api_image_to_image(prompt, image_data, negative_prompt="", strength=0.7, guidance_scale=7.5, num_inference_steps=25, seed=-1, lora_repo="", lora_filename="", lora_weight=1.0):
"""API endpoint for image-to-image generation"""
try:
# Decode base64 image
if image_data.startswith('data:image'):
image_data = image_data.split(',')[1]
image_bytes = base64.b64decode(image_data)
input_image = Image.open(io.BytesIO(image_bytes))
image, info = generate_image_to_image(
prompt, input_image, negative_prompt, strength, guidance_scale,
num_inference_steps, seed, "None", lora_weight, lora_repo, lora_filename
)
if image:
return {
"success": True,
"image": image_to_base64(image),
"info": info
}
else:
return {
"success": False,
"error": info
}
except Exception as e:
return {
"success": False,