# Source: Hugging Face Space app (LWZ19/flux_prune demo, commit f5267ae).
# (The file-viewer header lines — author, commit, "raw / history / blame",
# file size — were page chrome, not Python source, and have been removed.)
import copy
import gradio as gr
import numpy as np
import random
import pickle
import torch
import os
import sys
import spaces
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers import FluxPipeline
from diffusers.models import FluxTransformer2DModel
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
from safetensors.torch import load_file
# Import essential classes for unpickling pruned models
from utils import SparsityLinear, SkipConnection, AttentionSkipConnection
# Create a simple mock module for pickle imports
class MockModule:
    """Stand-in for the original ``sdib`` training package.

    The pruned checkpoints were pickled while classes such as
    ``SparsityLinear`` lived under ``sdib.utils...`` module paths. Unpickling
    looks those paths up in ``sys.modules`` and then fetches class attributes
    from the resulting object, so an instance of this class — exposing the
    same class names — is enough to satisfy pickle.
    """

    def __init__(self):
        # Expose every class the pickled checkpoints may reference, under
        # the exact attribute names pickle will ask for.
        for attr_name, attr_cls in (
            ("SparsityLinear", SparsityLinear),
            ("SkipConnection", SkipConnection),
            ("AttentionSkipConnection", AttentionSkipConnection),
        ):
            setattr(self, attr_name, attr_cls)
        # Self-reference so nested lookups (e.g. ``sdib.utils.utils``) resolve
        # back to this same object.
        self.utils = self
# Install one shared mock object under every module path the pickled
# checkpoints might reference, so pickle's import machinery finds it.
_sdib_mock = MockModule()
for _module_path in ("sdib", "sdib.utils", "sdib.utils.utils"):
    sys.modules[_module_path] = _sdib_mock
################################################################################
################################################################################
# Configuration
PRUNING_RATIOS = [25, 30]  # pruning percentages with preloaded checkpoints
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider / RNG
dtype = torch.bfloat16  # bf16 keeps memory low; Flux is trained/served in bf16

print("🚀 Loading base Flux dev pipeline...")
# Full FLUX.1-dev pipeline (text encoders, VAE, scheduler, transformer).
# Its transformer is swapped out for a pruned one before generation.
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype
)
print("✅ Base Flux dev pipeline loaded!")

# Global storage for all models: ratio -> pruned transformer (or None on
# load failure; consumers must check for None).
pruned_models = {}
print("📥 Preloading all pruned models...")
for ratio in PRUNING_RATIOS:
    try:
        print(f"Loading {ratio}% pruned model...")
        # Download the pickled, structurally pruned transformer from the Hub.
        model_file = hf_hub_download(
            repo_id="LWZ19/flux_prune",
            filename=f"dev/pruned_model_{ratio}.pkl"
        )
        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file. This is only acceptable because the checkpoint comes from this
        # project's own Hub repo; never point this at untrusted input.
        with open(model_file, "rb") as f:
            pruned_model = pickle.load(f)
        # Park on CPU in bf16; the model moves to GPU only when selected.
        pruned_model.to("cpu")
        pruned_model.to(dtype)
        pruned_models[ratio] = pruned_model
        print(f"✅ {ratio}% pruned model loaded!")
    except Exception as e:
        # Record the failure as None so later code can report "unavailable"
        # instead of crashing on a missing dict key.
        print(f"❌ Failed to load {ratio}% pruned model: {e}")
        pruned_models[ratio] = None
print("📥 Preloading all LoRA weights...")
for ratio in PRUNING_RATIOS:
    try:
        # Fetch only this ratio's LoRA folder from the retrain-weights repo.
        lora_repo_path = snapshot_download(
            repo_id="LWZ19/flux_retrain_weights",
            allow_patterns=[f"dev/lora/prune_{ratio}/*"]
        )
        lora_weights = load_file(os.path.join(lora_repo_path, "dev", "lora", f"prune_{ratio}", LORA_WEIGHT_NAME_SAFE))
        print("✅ LoRA checkpoint loaded!")
        # Temporarily set the pruned model as transformer so the pipeline's
        # LoRA machinery applies the adapters to the pruned weights rather
        # than the original base transformer.
        base_pipe.transformer = pruned_models[ratio]
        # Load the adapters, bake (fuse) them into the weights, then drop the
        # adapter modules so inference carries no LoRA overhead.
        base_pipe.load_lora_weights(lora_weights)
        base_pipe.fuse_lora()
        base_pipe.unload_lora_weights()
        # Store the merged model back
        pruned_models[ratio] = base_pipe.transformer
        print(f"✅ LoRA merged with {ratio}% pruned model!")
    except Exception as e:
        # On failure, pruned_models[ratio] keeps its pre-merge value (the raw
        # pruned model, or None if that load failed too).
        print(f"❌ Failed to load LoRA checkpoint: {e}")
# Model state: activate an initial transformer so the app is usable before
# the first dropdown change. Derive the default from PRUNING_RATIOS instead
# of hard-coding 25, so it stays in sync with the configuration above.
_initial_ratio = PRUNING_RATIOS[0]
if pruned_models.get(_initial_ratio) is None:
    # Fail fast with a clear message instead of an opaque AttributeError
    # from calling .to(...) on None.
    raise RuntimeError(
        f"{_initial_ratio}% pruned model failed to preload; cannot start the demo."
    )
base_pipe.transformer = pruned_models[_initial_ratio].to(device)
current_ratio = _initial_ratio
def load_model(ratio):
    """Swap the pipeline's transformer to the requested pruning ratio.

    Models are preloaded at startup; this only moves the selected one to the
    device and tracks the active ratio so repeated calls with the same ratio
    are a no-op.

    Args:
        ratio: Pruning percentage; expected to be one of PRUNING_RATIOS.

    Returns:
        A status string starting with "✅" on success or "❌" on failure —
        callers check for "❌" rather than catching exceptions.
    """
    global current_ratio
    try:
        model = pruned_models.get(ratio)
        if model is None:
            # Preloading failed (or an unknown ratio was passed): report a
            # clear message instead of crashing on `None.to(...)`.
            return f"❌ Failed to apply weights: {ratio}% pruned model is unavailable"
        # Switch to new pruned model if different ratio
        if current_ratio != ratio:
            base_pipe.transformer = model.to(device)
            current_ratio = ratio
        return f"✅ Ready with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
    except Exception as e:
        # Broad catch is deliberate: the UI expects a status string, never
        # an exception, from this function.
        return f"❌ Failed to apply weights: {str(e)}"
@spaces.GPU(duration=80)
def generate_image(
    ratio,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected pruned Flux transformer.

    Args:
        ratio: Pruning ratio (%) choosing which preloaded transformer to use.
        prompt: Text prompt for the diffusion model.
        seed: RNG seed; ignored when ``randomize_seed`` is True.
        randomize_seed: If True, draw a fresh seed in [0, MAX_SEED].
        width, height: Output resolution in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (forwards tqdm output to the UI).

    Returns:
        ``(image, seed, status)`` — ``image`` is None on failure so the UI
        can still show the seed and the error status.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    try:
        # Swap in the requested pruned transformer (no-op if already active).
        status = load_model(ratio)
        if "❌" in status:
            return None, seed, status
        # Move the pipeline to the GPU here — under ZeroGPU the device is
        # only attached inside this @spaces.GPU-decorated call.
        base_pipe.to(device)
        # int() guards against a float seed coming from the Gradio slider;
        # torch.Generator.manual_seed requires an integer.
        generator = torch.Generator(device).manual_seed(int(seed))
        # Generate with the base pipeline, already wired to the pruned model.
        image = base_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        # Release cached GPU memory between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        return image, seed, result_status
    except Exception as e:
        # Broad catch is deliberate: any failure must surface as a status
        # message in the UI rather than crash the Gradio worker.
        error_status = f"❌ Generation failed: {str(e)}\nPlease retry after a few minutes."
        return None, seed, error_status
examples = [
"A clock tower floating in a sea of clouds",
"A cozy library with a roaring fireplace",
"A cat playing football",
"A magical forest with glowing mushrooms",
"An astronaut riding a rainbow unicorn",
]
css = """
#col-container {
margin: 0 auto;
max-width: 720px;
}
"""
# Build the Gradio UI. Component creation order inside the context managers
# determines the on-page layout, so the ordering below is deliberate.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# EcoDiff Flux.1 [dev]: Memory-Efficient Diffusion")
        gr.Markdown("Generate images using pruned Flux.1 [dev] models with 25% and 30% pruning ratios, both LoRA retrained.")
        with gr.Row():
            # Single-line prompt box; the placeholder doubles as the label.
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row():
            # Which preloaded pruned transformer to generate with.
            ratio = gr.Dropdown(
                choices=PRUNING_RATIOS,
                value=25,
                label="Pruning Ratio (%)",
                info="Select pruning ratio",
                scale=1
            )
            generate_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        status_display = gr.Textbox(label="Status", interactive=False)
        # Collapsed by default; holds seed/size/sampler controls.
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
        gr.Examples(examples=examples, inputs=[prompt])
        gr.Markdown("""
### About EcoDiff Flux.1 [dev] Unified
This space showcases pruned Flux.1 [dev] models using learnable pruning techniques with LoRA fine-tuning.
- **Base Model**: Flux.1 [dev]
- **Pruning Ratios**: 25% and 30% of parameters removed
- **LoRA Enhancement**: Both models are retrained with LoRA weights for improved quality
""")
        # Wire the button to the generator; outputs also echo the seed so a
        # randomized seed can be inspected/reused after generation.
        generate_button.click(
            fn=generate_image,
            inputs=[
                ratio,
                prompt,
                seed,
                randomize_seed,
                width,
                height,
                guidance_scale,
                num_inference_steps,
            ],
            outputs=[result, seed, status_display],
        )
if __name__ == "__main__":
    # Launch the Gradio server when run as a script (the Spaces entry point).
    demo.launch()