# ecodiff_sdxl / app.py
import copy
import gradio as gr
import numpy as np
import random
import pickle
import torch
import os
import sys
import spaces
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
from safetensors.torch import load_file
# Import essential classes for unpickling pruned models
from utils import SparsityLinear, SkipConnection, AttentionSkipConnection
# Create a simple mock module for pickle imports
class MockModule:
def __init__(self):
# Add all the classes that pickle might need
self.SparsityLinear = SparsityLinear
self.SkipConnection = SkipConnection
self.AttentionSkipConnection = AttentionSkipConnection
# Self-reference for nested imports
self.utils = self
# Register the mock module for all sdib import paths
mock = MockModule()
sys.modules['sdib'] = mock
sys.modules['sdib.utils'] = mock
sys.modules['sdib.utils.utils'] = mock
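# pickle resolves classes by their original module path, so pruned checkpoints
# saved under sdib.* are mapped onto the local utils classes registered above.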
# Configuration
PRUNING_RATIOS = [20, 25, 30, 35, 40, 50]
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16
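# bfloat16 halves memory relative to fp32; the pruned UNets are cast to this dtype after loading.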
print("🚀 Loading base SDXL pipeline...")
base_pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=dtype,
)
print("✅ Base SDXL pipeline loaded!")
# Global storage for all models and weights
pruned_models = {}
pruned_weights = {}
lora_weights = {}
full_weights = {}
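# All four dicts are keyed by pruning ratio:
#   pruned_models:  unpickled pruned UNet modules
#   pruned_weights: clean copies of each pruned UNet's state_dict, used to undo full fine-tuning
#   lora_weights:   local directory holding the LoRA safetensors for each ratio
#   full_weights:   fully fine-tuned UNet state_dicts loaded from safetensors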
print("📥 Preloading all pruned models...")
for ratio in PRUNING_RATIOS:
try:
print(f"Loading {ratio}% pruned model...")
model_file = hf_hub_download(
repo_id="LWZ19/ecodiff_sdxl_prune",
filename=f"pruned_model_{ratio}.pkl"
)
with open(model_file, "rb") as f:
pruned_model = pickle.load(f)
pruned_model.to(dtype)
pruned_models[ratio] = pruned_model
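        # Keep a pristine copy of the pruned state_dict so the UNet can be
        # restored after full fine-tuned weights are applied (see load_weights)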
pruned_weights[ratio] = copy.deepcopy(pruned_model.state_dict())
print(f"✅ {ratio}% pruned model loaded!")
except Exception as e:
print(f"❌ Failed to load {ratio}% pruned model: {e}")
pruned_models[ratio] = None
pruned_weights[ratio] = None
print("📥 Preloading all weights...")
for ratio in PRUNING_RATIOS:
# Preload LoRA weights
lora_weights[ratio] = None
try:
print(f"Loading {ratio}% LoRA weights...")
lora_file = hf_hub_download(
repo_id="LWZ19/ecodiff_sdxl_retrain_weights",
filename=f"lora/prune_{ratio}/{LORA_WEIGHT_NAME_SAFE}"
)
lora_weights[ratio] = os.path.dirname(lora_file)
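        # Only the directory is stored; load_lora_weights is later pointed at it
        # and picks up the LoRA safetensors file inside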
print(f"✅ {ratio}% LoRA weights loaded!")
except Exception as e:
print(f"❌ Failed to preload {ratio}% LoRA weights: {e}")
# Preload full weights
full_weights[ratio] = None
try:
print(f"Loading {ratio}% full weights...")
full_file = hf_hub_download(
repo_id="LWZ19/ecodiff_sdxl_retrain_weights",
filename=f"full/prune_{ratio}/unet/{SAFETENSORS_WEIGHTS_NAME}"
)
full_weights[ratio] = load_file(full_file)
print(f"✅ {ratio}% full weights loaded!")
except Exception as e:
print(f"❌ Failed to preload {ratio}% full weights: {e}")
print("✅ All assets preloaded!")
# Pipeline state: start with the 20% pruned UNet and track what is currently
# applied, so repeated generations avoid redundant weight reloads
base_pipe.unet = pruned_models[20]
current_ratio = 20
current_retraining_type = "none"
def load_weights(ratio, retraining_type):
"""Apply specified model and weights to the pipeline"""
global current_ratio, current_retraining_type
try:
        # Clear previously loaded LoRA weights when switching away from LoRA or to a different ratio
        if current_retraining_type == "lora" and (retraining_type != "lora" or current_ratio != ratio):
            base_pipe.unload_lora_weights()
        # If full fine-tuned weights were applied, restore the previous ratio's cached
        # UNet to its original pruned weights so the cached model is not left modified
        if current_retraining_type == "full" and (retraining_type != "full" or current_ratio != ratio):
            pruned_models[current_ratio].load_state_dict(pruned_weights[current_ratio], strict=False)
        # Switch to the pruned model for the requested ratio
        if current_ratio != ratio:
            base_pipe.unet = pruned_models[ratio]
            current_ratio = ratio
# Apply weights
if retraining_type == "lora" and lora_weights[ratio]:
base_pipe.load_lora_weights(lora_weights[ratio])
print(f"✅ Applied {ratio}% LoRA weights")
elif retraining_type == "full" and full_weights[ratio]:
base_pipe.unet.load_state_dict(full_weights[ratio], strict=False)
print(f"✅ Applied {ratio}% full weights")
elif retraining_type != "none":
return f"❌ {retraining_type} weights not available for {ratio}%"
current_retraining_type = retraining_type
if retraining_type == "none":
return f"✅ Ready with {ratio}% pruned SDXL (no retraining)"
else:
return f"✅ Ready with {ratio}% pruned SDXL + {retraining_type} retraining"
except Exception as e:
return f"❌ Failed to apply weights: {str(e)}"
@spaces.GPU(duration=45)
def generate_image(
ratio,
retraining_type,
prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
progress=gr.Progress(track_tqdm=True),
):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
try:
# Apply model configuration
status = load_weights(ratio, retraining_type)
if "❌" in status:
return None, seed, status
# Move pipeline to GPU for generation
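        # (on ZeroGPU, CUDA is only available inside the @spaces.GPU-decorated
        # function, so the move to the GPU happens per call rather than at startup)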
base_pipe.to(device)
generator = torch.Generator(device).manual_seed(seed)
# Generate image using base pipeline (already configured with pruned model + weights)
image = base_pipe(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=generator,
).images[0]
        # Free cached GPU memory between requests
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
result_status = f"✅ Generated with {ratio}% pruned SDXL + {retraining_type}"
return image, seed, result_status
except Exception as e:
error_status = f"❌ Generation failed: {str(e)}"
return None, seed, error_status
examples = [
"A clock tower floating in a sea of clouds",
"A cozy library with a roaring fireplace",
"A cat playing football",
"A magical forest with glowing mushrooms",
"An astronaut riding a rainbow unicorn",
]
css = """
#col-container {
margin: 0 auto;
max-width: 720px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# EcoDiff SDXL: Memory-Efficient Diffusion")
gr.Markdown("Generate images using pruned SDXL models with multiple pruning ratios and retraining options")
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
with gr.Row():
ratio = gr.Dropdown(
choices=PRUNING_RATIOS,
value=20,
label="Pruning Ratio (%)",
info="Select pruning ratio",
scale=1
)
retraining_type = gr.Dropdown(
choices=["none", "lora", "full"],
value="none",
label="Retraining Type",
info="Choose retraining method",
scale=1
)
generate_button = gr.Button("Generate", variant="primary")
result = gr.Image(label="Result", show_label=False)
status_display = gr.Textbox(label="Status", interactive=False)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=1024,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=1024,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=2.0,
maximum=15.0,
step=0.1,
value=5.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=50,
)
gr.Examples(examples=examples, inputs=[prompt])
gr.Markdown("""
### About EcoDiff SDXL Unified
This space showcases multiple pruned SDXL models using learnable pruning techniques.
- **Base Model**: Stable Diffusion XL Base 1.0
- **Pruning Ratios**: 20%, 25%, 30%, 35%, 40%, 50% of parameters removed
- **Retraining Options**: None (pruned only), LoRA (parameter efficient fine-tuning), Full (complete fine-tuning)
""")
generate_button.click(
fn=generate_image,
inputs=[
ratio,
retraining_type,
prompt,
seed,
randomize_seed,
width,
height,
guidance_scale,
num_inference_steps,
],
outputs=[result, seed, status_display],
)
if __name__ == "__main__":
demo.launch()