# Hugging Face Spaces page metadata (scrape artifact — not part of the program):
# Space status was "Sleeping".
| import copy | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import pickle | |
| import torch | |
| import os | |
| import sys | |
| import spaces | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from diffusers import FluxPipeline | |
| from diffusers.models import FluxTransformer2DModel | |
| from diffusers.utils import SAFETENSORS_WEIGHTS_NAME | |
| from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE | |
| from safetensors.torch import load_file | |
| # Import essential classes for unpickling pruned models | |
| from utils import SparsityLinear, SkipConnection, AttentionSkipConnection | |
# Create a simple mock module for pickle imports
class MockModule:
    """Stand-in module whose attributes resolve the classes the pickled checkpoints reference."""

    def __init__(self):
        # Map each class name that may appear inside the pickles to its real implementation.
        for attr_name, impl in (
            ("SparsityLinear", SparsityLinear),
            ("SkipConnection", SkipConnection),
            ("AttentionSkipConnection", AttentionSkipConnection),
        ):
            setattr(self, attr_name, impl)
        # Self-reference so nested lookups such as `sdib.utils.utils` land back here.
        self.utils = self
# Register the mock module for all sdib import paths
mock = MockModule()
for _alias in ("sdib", "sdib.utils", "sdib.utils.utils"):
    sys.modules[_alias] = mock
################################################################################
################################################################################
# Configuration
PRUNING_RATIOS = [10, 15, 20]  # pruning ratios (%) with preloaded checkpoints; drives the UI dropdown
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider
dtype = torch.bfloat16  # low-memory dtype; assumes bf16-capable hardware — confirm on target GPU
# Load the shared Flux.1 [dev] pipeline once at startup; pruned transformers
# are swapped into this single pipeline later rather than building one per ratio.
print("🚀 Loading base Flux dev pipeline...")
base_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype)
print("✅ Base Flux dev pipeline loaded!")
# Global storage for all models
pruned_models = {}
lora_weights = None

print("📥 Preloading all pruned models...")


def _fetch_pruned_model(prune_ratio):
    """Download, unpickle, and return one pruned transformer on CPU in `dtype`."""
    checkpoint_path = hf_hub_download(
        repo_id="LWZ19/ecodiff_flux_prune",
        filename=f"dev/pruned_model_{prune_ratio}.pkl",
    )
    # NOTE(review): pickle.load on a downloaded artifact executes arbitrary code;
    # acceptable only because the repo above is trusted.
    with open(checkpoint_path, "rb") as ckpt_file:
        model = pickle.load(ckpt_file)
    model.to("cpu")
    model.to(dtype)
    return model


for ratio in PRUNING_RATIOS:
    try:
        print(f"Loading {ratio}% pruned model...")
        pruned_models[ratio] = _fetch_pruned_model(ratio)
        print(f"✅ {ratio}% pruned model loaded!")
    except Exception as e:
        # A failed download/unpickle leaves a None placeholder; consumers must check.
        print(f"❌ Failed to load {ratio}% pruned model: {e}")
        pruned_models[ratio] = None
print("📥 Preloading LoRA checkpoint for 20% pruning ratio...")
try:
    lora_repo_path = snapshot_download(
        repo_id="LWZ19/ecodiff_flux_retrain_weights",
        allow_patterns=["dev/lora/prune_20/*"],
    )
    _lora_file = os.path.join(lora_repo_path, "dev", "lora", "prune_20", LORA_WEIGHT_NAME_SAFE)
    lora_weights = load_file(_lora_file)
    print("✅ LoRA checkpoint loaded!")
except Exception as e:
    # LoRA is optional: on failure the 20% model simply runs without retrained weights.
    print(f"❌ Failed to load LoRA checkpoint: {e}")
    lora_weights = None
# Model state: install the default (10%) pruned transformer into the pipeline.
# Guard against a failed preload — the original unconditional
# `pruned_models[10].to(device)` crashed the whole app at import time with
# AttributeError when the 10% checkpoint could not be downloaded.
current_ratio = 10
if pruned_models.get(current_ratio) is not None:
    base_pipe.transformer = pruned_models[current_ratio].to(device)
else:
    print("⚠️ Default 10% pruned model unavailable; keeping the original transformer.")
def load_model(ratio, use_lora=False):
    """Apply the requested pruned transformer to the pipeline, optionally with LoRA.

    Args:
        ratio: Pruning ratio in percent; must be a key of ``pruned_models``.
        use_lora: When True and ``ratio == 20``, also load the retrained LoRA weights.

    Returns:
        A human-readable status string starting with "✅" on success or "❌" on failure
        (callers dispatch on that prefix).
    """
    global current_ratio
    try:
        # Fail early with a clear message when the checkpoint never loaded;
        # previously this fell through to `None.to(device)` and surfaced a
        # confusing "'NoneType' object has no attribute 'to'" status.
        if pruned_models.get(ratio) is None:
            return f"❌ {ratio}% pruned model is not available"
        # Switch to new pruned model if different ratio
        if current_ratio != ratio:
            base_pipe.transformer = pruned_models[ratio].to(device)
            current_ratio = ratio
        # Handle LoRA loading for 20% ratio
        if ratio == 20 and use_lora and lora_weights is not None:
            base_pipe.load_lora_weights(lora_weights)
            return f"✅ Ready with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        elif ratio == 20 and use_lora and lora_weights is None:
            return f"❌ LoRA weights not available for {ratio}% model"
        else:
            return f"✅ Ready with {ratio}% pruned Flux.1 [dev] (no retraining)"
    except Exception as e:
        return f"❌ Failed to apply weights: {str(e)}"
def generate_image(
    ratio,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    use_lora=False,
    progress=gr.Progress(track_tqdm=True),  # gradio requires the default-arg pattern here
):
    """Generate one image with the selected pruned model.

    Returns:
        (image, seed, status): the PIL image (or None on failure), the seed that
        was actually used, and a "✅"/"❌" status string for the UI.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Decide up front whether LoRA will be attached so the cleanup in `finally`
    # can always detach it — the original only unloaded on the success path,
    # leaving LoRA weights active for later non-LoRA requests after an error.
    lora_active = ratio == 20 and use_lora and lora_weights is not None
    try:
        # Apply model configuration
        status = load_model(ratio, use_lora)
        if "❌" in status:
            return None, seed, status
        # Move pipeline to GPU for generation
        base_pipe.to(device)
        generator = torch.Generator(device).manual_seed(seed)
        # Generate image using base pipeline (already configured with pruned model)
        image = base_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]
        if lora_active:
            result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        else:
            result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev]"
        return image, seed, result_status
    except Exception as e:
        error_status = f"❌ Generation failed: {str(e)}\nPlease retry after a few minutes."
        return None, seed, error_status
    finally:
        # Clean up GPU memory on both success and failure paths (the original
        # `empty_cache() if ... else None` conditional-expression-as-statement
        # only ran on success).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if lora_active:
            try:
                base_pipe.unload_lora_weights()
            except Exception:
                pass  # best-effort cleanup; never mask the real error/result
# Example prompts offered below the prompt box (clicking one fills the input).
examples = [
    "A clock tower floating in a sea of clouds",
    "A cozy library with a roaring fireplace",
    "A cat playing football",
    "A magical forest with glowing mushrooms",
    "An astronaut riding a rainbow unicorn",
]

# Page-level CSS: center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""
# UI definition. Component creation order and context-manager nesting determine
# the rendered layout, so this block is layout-sensitive.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# EcoDiff Flux.1 [dev]: Memory-Efficient Diffusion")
        gr.Markdown("Generate images using pruned Flux.1 [dev] models with multiple pruning ratios. For 20% pruning, optional LoRA retrained weights are available.")
        # Prompt input row
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
        # Pruning-ratio selector; choices mirror the preloaded checkpoints.
        with gr.Row():
            ratio = gr.Dropdown(
                choices=PRUNING_RATIOS,
                value=10,
                label="Pruning Ratio (%)",
                info="Select pruning ratio",
                scale=1
            )
        # LoRA toggle row — hidden unless the 20% ratio is selected.
        with gr.Row(visible=False) as lora_row:
            use_lora = gr.Checkbox(
                label="Use LoRA Retrained Model",
                value=False,
                info="Enable LoRA fine-tuned weights (only available for 20% pruning)"
            )
        generate_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        status_display = gr.Textbox(label="Status", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=2048,
                    step=32,  # multiples of 32 keep dimensions latent-grid aligned
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )
        gr.Examples(examples=examples, inputs=[prompt])
        gr.Markdown("""
        ### About EcoDiff Flux.1 [dev] Unified
        This space showcases multiple pruned Flux.1 [dev] models using learnable pruning techniques with optional LoRA fine-tuning.
        - **Base Model**: Flux.1 [dev]
        - **Pruning Ratios**: 10%, 15%, 20% of parameters removed
        - **LoRA Enhancement**: Available for 20% pruning ratio with retrained weights for improved quality
        """)

    def update_lora_visibility(ratio_value):
        # Show the LoRA row only for the 20% ratio — the only one with retrained weights.
        return gr.update(visible=(ratio_value == 20))

    # Event wiring (must happen inside the Blocks context).
    ratio.change(
        fn=update_lora_visibility,
        inputs=[ratio],
        outputs=[lora_row]
    )
    generate_button.click(
        fn=generate_image,
        inputs=[
            ratio,
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            use_lora,
        ],
        outputs=[result, seed, status_display],
    )
if __name__ == "__main__":
    # Start the Gradio server (blocking call).
    demo.launch()