# app.py — EcoDiff pruned Flux.1 [dev] demo (Hugging Face Space by LWZ19, rev 92be97a)
import copy
import gradio as gr
import numpy as np
import random
import pickle
import torch
import os
import sys
import spaces
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers import FluxPipeline
from diffusers.models import FluxTransformer2DModel
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
from safetensors.torch import load_file
# Import essential classes for unpickling pruned models
from utils import SparsityLinear, SkipConnection, AttentionSkipConnection
# Create a simple mock module for pickle imports
class MockModule:
    """Stand-in object registered in sys.modules so pickle can resolve
    class references stored inside the pruned-model checkpoints."""

    def __init__(self):
        # Expose every class the checkpoints may request during unpickling.
        for attr_name, attr_cls in (
            ("SparsityLinear", SparsityLinear),
            ("SkipConnection", SkipConnection),
            ("AttentionSkipConnection", AttentionSkipConnection),
        ):
            setattr(self, attr_name, attr_cls)
        # Self-reference so nested dotted lookups (e.g. `utils.utils`)
        # resolve back to this same object.
        self.utils = self
# Register one shared mock instance under every import path the pickled
# models may reference during deserialization.
mock = MockModule()
for _alias in ("sdib", "sdib.utils", "sdib.utils.utils"):
    sys.modules[_alias] = mock
################################################################################
################################################################################
# Configuration
# Pruning ratios (% of transformer parameters removed) offered in the demo.
PRUNING_RATIOS = [10, 15, 20]
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider
dtype = torch.bfloat16  # pruned checkpoints are cast to this dtype below
print("🚀 Loading base Flux dev pipeline...")
# Full FLUX.1-dev pipeline; its transformer is swapped for a pruned one below.
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype
)
print("✅ Base Flux dev pipeline loaded!")
# Global storage for all models
pruned_models = {}  # ratio (%) -> pruned transformer, or None if loading failed
lora_weights = None  # populated below if the 20% LoRA checkpoint downloads
print("📥 Preloading all pruned models...")
for ratio in PRUNING_RATIOS:
    try:
        print(f"Loading {ratio}% pruned model...")
        model_file = hf_hub_download(
            repo_id="LWZ19/ecodiff_flux_prune",
            filename=f"dev/pruned_model_{ratio}.pkl"
        )
        # NOTE(review): pickle.load executes arbitrary code from the downloaded
        # file — acceptable only because the repo is trusted first-party.
        with open(model_file, "rb") as f:
            pruned_model = pickle.load(f)
        pruned_model.to("cpu")  # keep off-GPU until this ratio is selected
        pruned_model.to(dtype)
        pruned_models[ratio] = pruned_model
        print(f"✅ {ratio}% pruned model loaded!")
    except Exception as e:
        print(f"❌ Failed to load {ratio}% pruned model: {e}")
        pruned_models[ratio] = None  # sentinel: this ratio is unavailable
print("📥 Preloading LoRA checkpoint for 20% pruning ratio...")
# Retrained LoRA weights exist only for the 20% pruning ratio.
try:
    lora_repo_path = snapshot_download(
        repo_id="LWZ19/ecodiff_flux_retrain_weights",
        # Plain string — the original had a stray f-prefix with no placeholder.
        allow_patterns=["dev/lora/prune_20/*"],  # fetch just the needed subtree
    )
    lora_weights = load_file(
        os.path.join(lora_repo_path, "dev", "lora", "prune_20", LORA_WEIGHT_NAME_SAFE)
    )
    print("✅ LoRA checkpoint loaded!")
except Exception as e:
    # Non-fatal: the UI simply runs the 20% model without LoRA retraining.
    print(f"❌ Failed to load LoRA checkpoint: {e}")
    lora_weights = None
# Model state
# Fail fast with a clear message if the default model did not preload —
# otherwise the .to(device) below dies with an opaque
# AttributeError: 'NoneType' object has no attribute 'to'.
if pruned_models[10] is None:
    raise RuntimeError("Default 10% pruned model failed to load; cannot start the demo.")
base_pipe.transformer = pruned_models[10].to(device)
current_ratio = 10  # ratio currently installed in base_pipe.transformer
def load_model(ratio, use_lora=False):
    """Install the pruned transformer for *ratio* into the pipeline.

    Optionally attaches the retrained LoRA weights (20% ratio only).

    Args:
        ratio: Pruning ratio in percent; one of PRUNING_RATIOS.
        use_lora: Attach the preloaded LoRA weights (only honored for 20%).

    Returns:
        A human-readable status string, prefixed with "❌" on failure.
    """
    global current_ratio
    try:
        # Guard against models that failed to preload (stored as None) —
        # without this, .to(device) would raise an opaque AttributeError.
        if pruned_models.get(ratio) is None:
            return f"❌ {ratio}% pruned model is not available"
        # Switch transformers only when the ratio actually changes.
        if current_ratio != ratio:
            base_pipe.transformer = pruned_models[ratio].to(device)
            current_ratio = ratio
        # Handle LoRA loading for 20% ratio
        if ratio == 20 and use_lora and lora_weights is not None:
            base_pipe.load_lora_weights(lora_weights)
            return f"✅ Ready with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        elif ratio == 20 and use_lora and lora_weights is None:
            return f"❌ LoRA weights not available for {ratio}% model"
        else:
            return f"✅ Ready with {ratio}% pruned Flux.1 [dev] (no retraining)"
    except Exception as e:
        return f"❌ Failed to apply weights: {str(e)}"
@spaces.GPU(duration=99)
def generate_image(
    ratio,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    use_lora=False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected pruned model.

    Args:
        ratio: Pruning ratio in percent (see PRUNING_RATIOS).
        prompt: Text prompt for the diffusion pipeline.
        seed / randomize_seed: Generator seed; a random one is drawn if requested.
        width, height: Output resolution in pixels.
        guidance_scale, num_inference_steps: Standard diffusion sampling knobs.
        use_lora: Attach retrained LoRA weights (20% ratio only).
        progress: Gradio progress tracker (tqdm-hooked).

    Returns:
        (image or None, seed, status_message).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    try:
        # Apply model configuration; bail out early on a reported failure.
        status = load_model(ratio, use_lora)
        if "❌" in status:
            return None, seed, status
        # Move pipeline to GPU for generation
        base_pipe.to(device)
        generator = torch.Generator(device).manual_seed(seed)
        lora_active = ratio == 20 and use_lora and lora_weights is not None
        try:
            image = base_pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]
        finally:
            # Always detach LoRA (even when generation raises) so a retry or a
            # different ratio never runs with stale adapters attached, then
            # release cached GPU memory.
            if lora_active:
                base_pipe.unload_lora_weights()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        if lora_active:
            result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        else:
            result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev]"
        return image, seed, result_status
    except Exception as e:
        error_status = f"❌ Generation failed: {str(e)}\nPlease retry after a few minutes."
        return None, seed, error_status
# Example prompts surfaced in the UI via gr.Examples.
examples = [
    "A clock tower floating in a sea of clouds",
    "A cozy library with a roaring fireplace",
    "A cat playing football",
    "A magical forest with glowing mushrooms",
    "An astronaut riding a rainbow unicorn",
]
# Center the demo column and cap its width (string content left as-is).
css = """
#col-container {
margin: 0 auto;
max-width: 720px;
}
"""
# ---------------------------------------------------------------------------
# Gradio UI: layout, event wiring, and app entry point.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# EcoDiff Flux.1 [dev]: Memory-Efficient Diffusion")
        gr.Markdown("Generate images using pruned Flux.1 [dev] models with multiple pruning ratios. For 20% pruning, optional LoRA retrained weights are available.")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row():
            ratio = gr.Dropdown(
                choices=PRUNING_RATIOS,
                value=10,
                label="Pruning Ratio (%)",
                info="Select pruning ratio",
                scale=1
            )
        # Hidden by default; revealed only when the 20% ratio is selected
        # (see update_lora_visibility below).
        with gr.Row(visible=False) as lora_row:
            use_lora = gr.Checkbox(
                label="Use LoRA Retrained Model",
                value=False,
                info="Enable LoRA fine-tuned weights (only available for 20% pruning)"
            )
        generate_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        status_display = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
        gr.Examples(examples=examples, inputs=[prompt])
        gr.Markdown("""
### About EcoDiff Flux.1 [dev] Unified
This space showcases multiple pruned Flux.1 [dev] models using learnable pruning techniques with optional LoRA fine-tuning.
- **Base Model**: Flux.1 [dev]
- **Pruning Ratios**: 10%, 15%, 20% of parameters removed
- **LoRA Enhancement**: Available for 20% pruning ratio with retrained weights for improved quality
""")

    def update_lora_visibility(ratio_value):
        """Show the LoRA checkbox row only for the 20% pruning ratio."""
        return gr.update(visible=(ratio_value == 20))

    # Toggle LoRA row visibility whenever the ratio dropdown changes.
    ratio.change(
        fn=update_lora_visibility,
        inputs=[ratio],
        outputs=[lora_row]
    )
    # Main action: run generation and surface image, seed, and status.
    generate_button.click(
        fn=generate_image,
        inputs=[
            ratio,
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            use_lora,
        ],
        outputs=[result, seed, status_display],
    )

if __name__ == "__main__":
    demo.launch()