"""EcoDiff Flux.1 [dev] Gradio demo.

Loads pruned FLUX.1 [dev] transformer checkpoints (10/15/20% pruning) plus an
optional LoRA retrained checkpoint for the 20% ratio, and serves an image
generation UI on Hugging Face Spaces (ZeroGPU via `spaces.GPU`).
"""

import copy
import os
import pickle
import random
import sys

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxPipeline
from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
from diffusers.models import FluxTransformer2DModel
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file

# Classes referenced by qualified name inside the pickled pruned checkpoints.
from utils import SparsityLinear, SkipConnection, AttentionSkipConnection


class MockModule:
    """Stand-in for the original ``sdib`` package used when the models were
    pickled, so :func:`pickle.load` can resolve the custom layer classes."""

    def __init__(self):
        # Classes the pickled checkpoints look up on the module object.
        self.SparsityLinear = SparsityLinear
        self.SkipConnection = SkipConnection
        self.AttentionSkipConnection = AttentionSkipConnection
        # Self-reference so nested paths like ``sdib.utils.utils`` resolve.
        self.utils = self


# Register the mock module under every import path pickle may request.
mock = MockModule()
sys.modules['sdib'] = mock
sys.modules['sdib.utils'] = mock
sys.modules['sdib.utils.utils'] = mock

################################################################################
################################################################################

# Configuration
PRUNING_RATIOS = [10, 15, 20]
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16

print("🚀 Loading base Flux dev pipeline...")
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype
)
print("✅ Base Flux dev pipeline loaded!")

# Global storage for all models; a failed download leaves ``None`` for a ratio.
pruned_models = {}
lora_weights = None

print("📥 Preloading all pruned models...")
for ratio in PRUNING_RATIOS:
    try:
        print(f"Loading {ratio}% pruned model...")
        model_file = hf_hub_download(
            repo_id="LWZ19/ecodiff_flux_prune",
            filename=f"dev/pruned_model_{ratio}.pkl"
        )
        # NOTE(security): pickle.load executes arbitrary code on load; this is
        # acceptable only because the checkpoint comes from a trusted
        # first-party repository.
        with open(model_file, "rb") as f:
            pruned_model = pickle.load(f)
        pruned_model.to("cpu")
        pruned_model.to(dtype)
        pruned_models[ratio] = pruned_model
        print(f"✅ {ratio}% pruned model loaded!")
    except Exception as e:
        print(f"❌ Failed to load {ratio}% pruned model: {e}")
        pruned_models[ratio] = None

print("📥 Preloading LoRA checkpoint for 20% pruning ratio...")
try:
    lora_repo_path = snapshot_download(
        repo_id="LWZ19/ecodiff_flux_retrain_weights",
        allow_patterns=["dev/lora/prune_20/*"]
    )
    lora_weights = load_file(
        os.path.join(lora_repo_path, "dev", "lora", "prune_20",
                     LORA_WEIGHT_NAME_SAFE)
    )
    print("✅ LoRA checkpoint loaded!")
except Exception as e:
    print(f"❌ Failed to load LoRA checkpoint: {e}")
    lora_weights = None

# Model state. Bug fix: the original unconditionally did
# ``pruned_models[10].to(device)`` and crashed with AttributeError at import
# time whenever the 10% checkpoint failed to download (value is None).
if pruned_models.get(10) is not None:
    base_pipe.transformer = pruned_models[10].to(device)
current_ratio = 10


def load_model(ratio, use_lora=False):
    """Apply the pruned transformer for ``ratio`` to the shared pipeline.

    Optionally loads the LoRA retrained weights (available for the 20% ratio
    only). Returns a human-readable status string beginning with ✅ on
    success or ❌ on failure; callers branch on the ❌ marker.
    """
    global current_ratio
    try:
        # Bug fix: guard against a checkpoint that failed to preload —
        # the original crashed with AttributeError on ``None.to(device)``.
        if pruned_models.get(ratio) is None:
            return f"❌ {ratio}% pruned model is not available"

        # Switch transformers only when the requested ratio changed.
        if current_ratio != ratio:
            base_pipe.transformer = pruned_models[ratio].to(device)
            current_ratio = ratio

        # Handle LoRA loading for the 20% ratio.
        if ratio == 20 and use_lora and lora_weights is not None:
            base_pipe.load_lora_weights(lora_weights)
            return f"✅ Ready with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        elif ratio == 20 and use_lora and lora_weights is None:
            return f"❌ LoRA weights not available for {ratio}% model"
        else:
            return f"✅ Ready with {ratio}% pruned Flux.1 [dev] (no retraining)"
    except Exception as e:
        return f"❌ Failed to apply weights: {str(e)}"


@spaces.GPU(duration=99)
def generate_image(
    ratio,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    use_lora=False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected pruned model.

    Returns ``(image | None, seed, status_string)`` — the tuple shape the
    Gradio click handler expects.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    lora_active = False
    try:
        # Apply model configuration (pruned transformer + optional LoRA).
        status = load_model(ratio, use_lora)
        if "❌" in status:
            return None, seed, status
        lora_active = ratio == 20 and use_lora and lora_weights is not None

        # Move pipeline to GPU for generation.
        base_pipe.to(device)
        generator = torch.Generator(device).manual_seed(seed)

        # Generate image using the base pipeline (already configured with
        # the pruned transformer).
        image = base_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

        if lora_active:
            result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        else:
            result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev]"
        return image, seed, result_status
    except Exception as e:
        error_status = f"❌ Generation failed: {str(e)}\nPlease retry after a few minutes."
        return None, seed, error_status
    finally:
        # Bug fix: the original unloaded LoRA only on the success path, so a
        # failed generation left the adapter attached for the next run (and
        # repeated load_lora_weights calls could stack adapters).
        if lora_active:
            base_pipe.unload_lora_weights()
        # Clean up GPU memory (plain statement instead of the original's
        # conditional-expression-as-statement).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


examples = [
    "A clock tower floating in a sea of clouds",
    "A cozy library with a roaring fireplace",
    "A cat playing football",
    "A magical forest with glowing mushrooms",
    "An astronaut riding a rainbow unicorn",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# EcoDiff Flux.1 [dev]: Memory-Efficient Diffusion")
        gr.Markdown(
            "Generate images using pruned Flux.1 [dev] models with multiple "
            "pruning ratios. For 20% pruning, optional LoRA retrained weights "
            "are available."
        )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

        with gr.Row():
            ratio = gr.Dropdown(
                choices=PRUNING_RATIOS,
                value=10,
                label="Pruning Ratio (%)",
                info="Select pruning ratio",
                scale=1
            )

        # Hidden unless the 20% ratio is selected (see update_lora_visibility).
        with gr.Row(visible=False) as lora_row:
            use_lora = gr.Checkbox(
                label="Use LoRA Retrained Model",
                value=False,
                info="Enable LoRA fine-tuned weights (only available for 20% pruning)"
            )

        generate_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        status_display = gr.Textbox(label="Status", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )

        gr.Examples(examples=examples, inputs=[prompt])

        gr.Markdown("""
        ### About EcoDiff Flux.1 [dev] Unified
        This space showcases multiple pruned Flux.1 [dev] models using learnable pruning techniques with optional LoRA fine-tuning.

        - **Base Model**: Flux.1 [dev]
        - **Pruning Ratios**: 10%, 15%, 20% of parameters removed
        - **LoRA Enhancement**: Available for 20% pruning ratio with retrained weights for improved quality
        """)

    def update_lora_visibility(ratio_value):
        """Show the LoRA checkbox row only for the 20% pruning ratio."""
        return gr.update(visible=(ratio_value == 20))

    ratio.change(
        fn=update_lora_visibility,
        inputs=[ratio],
        outputs=[lora_row]
    )

    generate_button.click(
        fn=generate_image,
        inputs=[
            ratio,
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            use_lora,
        ],
        outputs=[result, seed, status_display],
    )

if __name__ == "__main__":
    demo.launch()