LWZ19 committed on
Commit
f090370
·
1 Parent(s): e4f79b8
Files changed (4) hide show
  1. README.md +7 -5
  2. app.py +258 -0
  3. requirements.txt +12 -0
  4. utils.py +459 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Ecodiff Flux Dev
3
- emoji: 💻
4
- colorFrom: indigo
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: EcoDiff Flux.1 [dev]
3
+ emoji: 🖼️
4
+ colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Generate images from text prompts using a pruned model
11
+ startup_duration_timeout: 3h
12
  ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import gradio as gr
3
+ import numpy as np
4
+ import random
5
+ import pickle
6
+ import torch
7
+ import os
8
+ import sys
9
+ import spaces
10
+ from huggingface_hub import hf_hub_download, snapshot_download
11
+ from diffusers import FluxPipeline
12
+ from diffusers.models import FluxTransformer2DModel
13
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
14
+ from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
15
+ from safetensors.torch import load_file
16
+
17
+ # Import essential classes for unpickling pruned models
18
+ from utils import SparsityLinear, SkipConnection, AttentionSkipConnection
19
+
20
# Lightweight stand-in module so pickle can resolve classes that were
# serialized under the original training package path ("sdib.*").
class MockModule:
    """Fake module object exposing the classes pickled checkpoints reference."""

    def __init__(self):
        # Expose every class the pickled pruned models may need to unpickle.
        for attr_name, attr_cls in (
            ("SparsityLinear", SparsityLinear),
            ("SkipConnection", SkipConnection),
            ("AttentionSkipConnection", AttentionSkipConnection),
        ):
            setattr(self, attr_name, attr_cls)
        # Point nested "utils" lookups (e.g. sdib.utils.utils) back at this object.
        self.utils = self
29
+
30
# Register the mock module for all sdib import paths
# The pruned-model pickles were serialized inside a package named "sdib", so
# unpickling resolves classes via "sdib", "sdib.utils", and "sdib.utils.utils".
# Aliasing all three entries to one MockModule instance lets pickle.load
# succeed without the original training package installed.
mock = MockModule()
sys.modules['sdib'] = mock
sys.modules['sdib.utils'] = mock
sys.modules['sdib.utils.utils'] = mock
35
+
36
+
37
################################################################################
################################################################################


# Configuration
# Percent-of-parameters-removed checkpoints available in the weights repo.
PRUNING_RATIOS = [10, 15, 20]

device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider
dtype = torch.bfloat16

print("🚀 Loading base Flux dev pipeline...")
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype
)
print("✅ Base Flux dev pipeline loaded!")

# Global storage for all models
# Maps pruning ratio (int percent) -> pruned transformer, or None on failure.
pruned_models = {}

print("📥 Preloading all pruned models...")
for ratio in PRUNING_RATIOS:
    try:
        print(f"Loading {ratio}% pruned model...")
        model_file = hf_hub_download(
            repo_id="LWZ19/flux_dev_prune",
            filename=f"pruned_model_{ratio}.pkl"
        )

        # NOTE(review): pickle.load executes arbitrary code from the file;
        # acceptable only because the checkpoint repo is first-party/trusted.
        with open(model_file, "rb") as f:
            pruned_model = pickle.load(f)
        pruned_model.to(dtype)

        pruned_models[ratio] = pruned_model
        print(f"✅ {ratio}% pruned model loaded!")
    except Exception as e:
        print(f"❌ Failed to load {ratio}% pruned model: {e}")
        pruned_models[ratio] = None  # consumers must treat None as unavailable


# Model state
# Start with the 10% pruned transformer swapped into the pipeline.
# NOTE(review): if the 10% checkpoint failed to preload this installs None —
# confirm the Space is meant to surface that at first generation.
base_pipe.transformer = pruned_models[10]
current_ratio = 10
81
+
82
+
83
def load_model(ratio):
    """Swap the pruned transformer for *ratio* into the shared pipeline.

    Args:
        ratio: Pruning ratio in percent; expected to be one of PRUNING_RATIOS.

    Returns:
        A status string. Failure messages start with "❌" — callers
        (generate_image) check for that marker before generating.
    """
    global current_ratio

    try:
        target = pruned_models.get(ratio)
        # A checkpoint that failed to preload is stored as None; report it
        # instead of silently installing a None transformer and claiming success.
        if target is None:
            return f"❌ {ratio}% pruned model is not available (failed to load at startup)"

        # Switch to new pruned model if different ratio
        if current_ratio != ratio:
            base_pipe.transformer = target
            current_ratio = ratio

        return f"✅ Ready with {ratio}% pruned Flux.1 [dev] (no retraining)"

    except Exception as e:
        return f"❌ Failed to apply weights: {str(e)}"
97
+
98
+
99
@spaces.GPU(duration=99)
def generate_image(
    ratio,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one text-to-image generation with the selected pruned model.

    Executed on a ZeroGPU worker for up to 99 s (the @spaces.GPU duration).

    Returns:
        (image, seed, status): image is None and status starts with "❌"
        on any failure; seed is the one actually used (possibly randomized).
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        # Apply model configuration
        status = load_model(ratio)
        if "❌" in status:
            return None, seed, status

        # Move pipeline to GPU for generation
        base_pipe.to(device)

        generator = torch.Generator(device).manual_seed(seed)

        # Generate image using base pipeline (already configured with pruned model)
        image = base_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

        # Clean up GPU memory
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev]"
        return image, seed, result_status

    except Exception as e:
        # Surface the error through the status textbox rather than crashing
        # the Gradio event handler.
        error_status = f"❌ Generation failed: {str(e)}"
        return None, seed, error_status
144
+
145
# Sample prompts shown under the prompt box via gr.Examples.
examples = [
    "A clock tower floating in a sea of clouds",
    "A cozy library with a roaring fireplace",
    "A cat playing football",
    "A magical forest with glowing mushrooms",
    "An astronaut riding a rainbow unicorn",
]

# Center the main UI column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""
159
+
160
# ---------------------------------------------------------------------------
# Gradio UI: one column with prompt, ratio selector, result image, status,
# and an accordion of advanced sampling settings.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# EcoDiff Flux.1 [dev]: Memory-Efficient Diffusion")
        gr.Markdown("Generate images using pruned Flux.1 [dev] models with multiple pruning ratios")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

        with gr.Row():
            # Choice of pruned checkpoint; values must match PRUNING_RATIOS.
            ratio = gr.Dropdown(
                choices=PRUNING_RATIOS,
                value=10,
                label="Pruning Ratio (%)",
                info="Select pruning ratio",
                scale=1
            )

        generate_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        status_display = gr.Textbox(label="Status", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )

        gr.Examples(examples=examples, inputs=[prompt])

        gr.Markdown("""
### About EcoDiff Flux.1 [dev] Unified
This space showcases multiple pruned Flux.1 [dev] models using learnable pruning techniques.

- **Base Model**: Flux.1 [dev]
- **Pruning Ratios**: 10%, 15%, 20% of parameters removed
""")

    # Wire the button to generate_image; outputs feed the image, the
    # (possibly randomized) seed slider, and the status textbox.
    generate_button.click(
        fn=generate_image,
        inputs=[
            ratio,
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed, status_display],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ diffusers==0.34.0
4
+ transformers
5
+ accelerate
6
+ safetensors
7
+ sentencepiece
8
+ peft
9
+ huggingface_hub
10
+ pillow
11
+ numpy
12
+ tqdm
utils.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # all utility functions
2
+ import math
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+ from diffusers.models.activations import GEGLU, GELU
7
+
8
+
9
def get_total_params(model, trainable: bool = True):
    """Count parameters whose ``requires_grad`` flag equals *trainable*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad == trainable:
            total += param.numel()
    return total
11
+
12
+
13
def get_precision(precision: str):
    """Map a precision name to the corresponding torch dtype.

    Args:
        precision: One of "fp16", "bf16", "fp32", or "fp64".

    Returns:
        The matching ``torch.dtype``.

    Raises:
        AssertionError: If *precision* is not a supported name.
    """
    # The original assert rejected "fp64" even though a branch below handled
    # it, leaving that branch unreachable; "fp64" is now accepted as well
    # (backward compatible: previously it only raised).
    assert precision in ["fp16", "fp32", "bf16", "fp64"], "precision must be either fp16, fp32, bf16, fp64"
    if precision == "fp16":
        torch_dtype = torch.float16
    elif precision == "bf16":
        torch_dtype = torch.bfloat16
    elif precision == "fp32":
        torch_dtype = torch.float32
    elif precision == "fp64":
        torch_dtype = torch.float64
    return torch_dtype
24
+
25
+
26
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """Summarize how many mask entries ("lambdas") of *hooker* are active.

    For binary hookers an entry counts via its 0/1 value directly; otherwise
    an entry is active when it is >= *threshold*.

    Returns:
        (total entries, active entries, active/total ratio).
    """
    # ff_hooks carry no ``binary`` attribute, so a missing flag means False.
    is_binary = getattr(hooker, "binary", None)
    total = 0
    active = 0
    for lamb in hooker.lambs:
        total += lamb.size(0)
        if is_binary:
            assert threshold is None, "threshold should be None for binary mask"
            active += lamb.sum().item()
        else:
            assert threshold is not None, "threshold must be provided for non-binary mask"
            active += (lamb >= threshold).sum().item()
    return total, active, active / total
39
+
40
+
41
def linear_layer_masking(module, lamb):
    """
    Soft-mask attention weights (K, Q, V and the output projection) per head.

    Weights are scaled in place by each head's mask value; no parameters are
    removed, so the module keeps its original shapes. Useful for gradual
    pruning during training.

    Args:
        module: Attention module containing to_k, to_q, to_v, and to_out.
        lamb: Per-head mask values to apply.

    Returns:
        module: The same module with scaled weights.
    """
    head_dim = module.to_k.in_features // module.heads

    # Each head owns a contiguous row band of every input projection.
    for proj in (module.to_k, module.to_q, module.to_v):
        for head_idx, head_mask in enumerate(lamb):
            rows = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
            proj.weight.data[rows, :] *= head_mask
            if proj.bias is not None:
                proj.bias.data[rows] *= head_mask

    # The output projection consumes head features along its columns.
    for head_idx, head_mask in enumerate(lamb):
        cols = slice(head_idx * head_dim, (head_idx + 1) * head_dim)
        module.to_out[0].weight.data[:, cols] *= head_mask
    return module
68
+
69
+
70
# Dummy module standing in for a fully pruned layer.
class SkipConnection(torch.nn.Module):
    """
    Identity replacement for a completely pruned layer.

    When a layer is fully pruned this module takes its place and simply
    echoes back the first positional input, keeping the forward pass intact.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *inputs, **kwargs):
        # Ignore any extra positional/keyword arguments; pass-through only.
        return inputs[0]
83
+
84
+
85
class AttentionSkipConnection(torch.nn.Module):
    """
    Skip connection replacing a fully pruned attention layer.

    Mirrors the return shape of the replaced attention module: SD3/FLUX
    attention may return a (hidden_states, encoder_hidden_states) pair,
    while other architectures return hidden states only.

    Args:
        model_type: Diffusion model family ("sd3", "flux", "flux_dev", etc.).
    """

    def __init__(self, model_type):
        super().__init__()
        self.model_type = model_type

    def forward(self, hidden_states=None, encoder_hidden_states=None, *args, **kwargs):
        """Echo the inputs in the shape the pruned attention layer used."""
        dual_return = self.model_type in ["sd3", "flux", "flux_dev"]
        if dual_return and encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        return hidden_states
109
+
110
+
111
def linear_layer_pruning(module, lamb, model_type):
    """
    Physically prune attention layers by removing parameters for pruned heads.

    This function performs structural pruning through the following detailed steps:

    1. **Input Processing**: Latent features are fed into linear modules (to_k, to_q, to_v)
       with shape (cross_attn_dim, inner_kv_dim / inner_dim)

    2. **Head Division**: Inner features are divided into attention heads, where:
       - Query shape: [B, N, H, D] (batch, sequence, heads, head_dim)
       - New hidden dimension = inner_dim * (unmasked_heads / total_heads)
       - K, Q, V projections have shape [cross_attn_dim, inner_kv_dim / inner_dim]
       - Each head occupies (heads * inner_dim) rows in the weight matrix
       - **Important**: Input channels remain unchanged, only output rows are pruned

    3. **Attention Computation**: Updated latent features after scaled dot-product attention

    4. **Output Projection**: Final projection layer (to_out) from pruned inner_dim to original latent_dim
       - Pruned dimension changes from input (dim=0) to output (dim=1)
       - **Critical**: Output channels remain unchanged to maintain model compatibility

    Args:
        module: Attention module to prune (contains to_k, to_q, to_v, to_out)
        lamb: Learned mask values per attention head (1=keep, 0=prune)
        model_type: Model architecture type for skip connection handling

    Returns:
        module: Pruned attention module or AttentionSkipConnection if fully pruned

    Note:
        - Supports additional projections (add_k_proj, add_q_proj, add_v_proj) for certain architectures
        - Handles both to_out and to_add_out projection layers
        - Updates all relevant module parameters (inner_dim, query_dim, heads, etc.)
    """

    heads_to_keep = torch.nonzero(lamb).squeeze()
    if len(heads_to_keep.shape) == 0:
        # if only one head is kept, or none: squeeze() produced a 0-d tensor,
        # restore the 1-d index shape so iteration below works
        heads_to_keep = heads_to_keep.unsqueeze(0)

    modules_to_remove = [module.to_k, module.to_q, module.to_v]

    # Joint-attention variants carry extra context projections to prune too.
    if getattr(module, "add_k_proj", None) is not None:
        modules_to_remove.extend([module.add_k_proj, module.add_q_proj, module.add_v_proj])

    new_heads = int(lamb.sum().item())

    # Every head pruned: the whole attention layer degenerates to a skip.
    if new_heads == 0:
        return AttentionSkipConnection(model_type=model_type)

    for module_to_remove in modules_to_remove:
        # get head dimension
        inner_dim = module_to_remove.out_features // module.heads
        # place holder for the rows to keep
        rows_to_keep = torch.zeros(
            module_to_remove.out_features, dtype=torch.bool, device=module_to_remove.weight.device
        )

        for idx in heads_to_keep:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True

        # overwrite the inner projection with masked projection
        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(sum(rows_to_keep).item())

    # Also update the output projection layer if available, (for FLUXSingleAttnProcessor2_0)
    # with column masking, dim 1
    # NOTE(review): rows_to_keep below is the mask from the LAST projection
    # processed in the loop above — this assumes all pruned projections share
    # out_features; confirm for modules where add_*_proj dimensions differ.
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(sum(rows_to_keep).item())

    if getattr(module, "to_add_out", None) is not None:
        module.to_add_out.weight.data = module.to_add_out.weight.data[:, rows_to_keep]
        module.to_add_out.in_features = int(sum(rows_to_keep).item())

    # update parameters in the attention module
    # NOTE(review): assumes inner_dim / query_dim / inner_kv_dim /
    # cross_attention_dim are all divisible by the original head count —
    # TODO confirm for every attention variant this is applied to.
    module.inner_dim = module.inner_dim // module.heads * new_heads
    module.query_dim = module.query_dim // module.heads * new_heads
    module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module
196
+
197
+
198
def update_flux_single_transformer_projection(parent_module, module, lamb, old_inner_dim):
    """
    Updates the proj_out module in a FluxSingleTransformerBlock after attention head pruning.

    FLUX models use a proj_out layer that takes concatenated input from both attention output
    and MLP hidden states: torch.cat([attn_output, mlp_hidden_states], dim=2). When attention
    heads are pruned, the attention dimension changes but the MLP dimension remains constant,
    requiring careful weight matrix reconstruction.

    Args:
        parent_module: FluxSingleTransformerBlock containing the proj_out layer
        module: Pruned attention module (or AttentionSkipConnection)
        lamb: Original mask values used for pruning decisions
        old_inner_dim: Original attention inner dimension before pruning

    Returns:
        parent_module: Updated parent module with corrected proj_out dimensions

    Note:
        - Handles skip connections when module is completely pruned
        - Preserves MLP weights while updating attention weights
        - Only modifies proj_out if dimensions actually changed
    """
    # Handle Skip Connection case (when module is completely pruned)
    if isinstance(module, AttentionSkipConnection):
        return parent_module

    if hasattr(parent_module, "proj_out"):
        # Calculate how much the attention dimension changed
        attention_dim_change = old_inner_dim - module.inner_dim

        if attention_dim_change > 0:  # Only update if dimensions actually changed
            # Get current weight matrix and dimensions
            old_weight = parent_module.proj_out.weight.data
            old_in_features = parent_module.proj_out.in_features

            # Calculate new input dimension
            new_in_features = old_in_features - attention_dim_change

            # Create new weight matrix
            new_weight = torch.zeros(
                old_weight.shape[0], new_in_features,
                device=old_weight.device, dtype=old_weight.dtype
            )

            # Calculate head dimensions
            old_head_dim = old_inner_dim // lamb.shape[0]

            # Create mask for attention columns to keep
            heads_to_keep = torch.nonzero(lamb).squeeze()
            if len(heads_to_keep.shape) == 0:
                # single surviving head: squeeze() gave a 0-d tensor
                heads_to_keep = heads_to_keep.unsqueeze(0)

            attn_cols_to_keep = torch.zeros(old_inner_dim, dtype=torch.bool, device=old_weight.device)
            for idx in heads_to_keep:
                attn_cols_to_keep[idx * old_head_dim : (idx + 1) * old_head_dim] = True

            # Copy weights for kept attention heads
            # NOTE(review): the i < module.inner_dim guard should always hold
            # since kept column count equals the pruned inner_dim; the
            # column-by-column copy could be vectorized but is kept as-is.
            kept_indices = torch.nonzero(attn_cols_to_keep).squeeze()
            for i, idx in enumerate(kept_indices):
                if i < module.inner_dim:
                    new_weight[:, i] = old_weight[:, idx]

            # Copy MLP weights (unchanged part)
            mlp_start = old_inner_dim
            if mlp_start < old_in_features:  # Ensure there's actually an MLP part
                new_weight[:, module.inner_dim:] = old_weight[:, mlp_start:]

            # Update the projection layer
            parent_module.proj_out.weight.data = new_weight
            parent_module.proj_out.in_features = new_in_features
    return parent_module
270
+
271
+
272
def ffn_linear_layer_pruning(module, lamb):
    """
    Prunes feed-forward network layers based on learned masks.

    Supports two FFN layouts:
      * Standard ``module.net`` blocks whose first layer is GELU or GEGLU.
      * FluxSingleTransformerBlock-style modules with ``proj_mlp``/``proj_out``.

    Note: This function could potentially be merged with linear_layer_pruning
    for better code organization in future refactoring.

    Args:
        module: FFN module to prune
        lamb: Learned 1-D mask values (one per hidden channel)

    Returns:
        Pruned module or SkipConnection if fully pruned
    """
    # reshape(-1) keeps the index tensor 1-D even when exactly one channel
    # survives; the previous .squeeze() produced a 0-d tensor in that case
    # and len() below raised TypeError (linear_layer_pruning guards the same
    # edge case explicitly with unsqueeze).
    lambda_to_keep = torch.nonzero(lamb).reshape(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    if hasattr(module, "net") and len(module.net) >= 3:
        # Standard FFN blocks
        if isinstance(module.net[0], GELU):
            # linear layer weight remove before activation
            module.net[0].proj.weight.data = module.net[0].proj.weight.data[lambda_to_keep, :]
            module.net[0].proj.out_features = num_lambda
            if module.net[0].proj.bias is not None:
                module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]

            # Rebuild the activation wrapper around the shrunken projection.
            update_act = GELU(module.net[0].proj.in_features, num_lambda)
            update_act.proj = module.net[0].proj
            module.net[0] = update_act
        elif isinstance(module.net[0], GEGLU):
            # GEGLU stacks value and gate halves along dim 0; prune the same
            # channels from both halves.
            output_feature = module.net[0].proj.out_features
            module.net[0].proj.weight.data = torch.cat(
                [
                    module.net[0].proj.weight.data[: output_feature // 2, :][lambda_to_keep, :],
                    module.net[0].proj.weight.data[output_feature // 2 :][lambda_to_keep, :],
                ],
                dim=0,
            )
            module.net[0].proj.out_features = num_lambda * 2
            if module.net[0].proj.bias is not None:
                module.net[0].proj.bias.data = torch.cat(
                    [
                        module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
                        module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
                    ]
                )

            update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
            update_act.proj = module.net[0].proj
            module.net[0] = update_act

        # proj weight after activation
        module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
        module.net[2].in_features = num_lambda

    elif hasattr(module, "proj_mlp") and hasattr(module, "proj_out"):
        # FFN For FluxSingleTransformerBlock
        module.proj_mlp.weight.data = module.proj_mlp.weight.data[lambda_to_keep, :]
        module.proj_mlp.out_features = num_lambda
        if module.proj_mlp.bias is not None:
            module.proj_mlp.bias.data = module.proj_mlp.bias.data[lambda_to_keep]

        # Update mlp_hidden_dim to reflect the new size
        old_mlp_hidden_dim = module.mlp_hidden_dim
        module.mlp_hidden_dim = num_lambda

        # The proj_out layer takes concatenated input from both attention output and MLP output
        # We need to keep the attention part unchanged but update the MLP part
        old_dim = module.proj_out.in_features
        attn_dim = old_dim - old_mlp_hidden_dim  # Attention dimension
        new_in_features = attn_dim + num_lambda

        new_weight = torch.zeros(
            module.proj_out.weight.shape[0], new_in_features,
            device=module.proj_out.weight.device, dtype=module.proj_out.weight.dtype
        )

        # Copy attention part (unchanged)
        new_weight[:, :attn_dim] = module.proj_out.weight.data[:, :attn_dim]

        # Copy selected MLP parts
        for i, idx in enumerate(lambda_to_keep):
            new_weight[:, attn_dim + i] = module.proj_out.weight.data[:, attn_dim + idx]

        # Update the projection layer
        module.proj_out.weight.data = new_weight
        module.proj_out.in_features = new_in_features

    return module
364
+
365
+
366
# SparsityLinear: a linear layer computing only surviving output channels.
class SparsityLinear(torch.nn.Module):
    """
    Linear layer that computes only selected output features.

    Internally projects to ``num_lambda`` features, then scatters the result
    into a zero tensor of the original ``out_features`` width at the indices
    in ``lambda_to_keep``. Used for normalization-layer pruning where the
    output dimension must stay unchanged.

    Args:
        in_features: Input feature dimension.
        out_features: Output feature dimension (original size).
        lambda_to_keep: Indices of output features that stay active.
        num_lambda: Number of active features (``len(lambda_to_keep)``).
    """

    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super().__init__()
        self.sparse_proj = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep

    def forward(self, x):
        # NOTE(review): assumes 2-D input (batch, features) — confirm callers.
        projected = self.sparse_proj(x)
        expanded = torch.zeros(x.size(0), self.out_features, device=x.device, dtype=x.dtype)
        expanded[:, self.lambda_to_keep] = projected
        return expanded
392
+
393
+
394
def norm_layer_pruning(module, lamb):
    """
    Prune the layer-normalization projection of a FLUX model.

    Replaces ``module.linear`` with a SparsityLinear that only computes the
    surviving output channels while keeping the layer's output width.

    Args:
        module: Norm module owning a ``linear`` projection (with bias).
        lamb: Learned 1-D mask values (one per output channel).

    Returns:
        The module with its ``linear`` swapped, or SkipConnection if every
        channel is pruned.
    """
    # reshape(-1) keeps the index tensor 1-D even when exactly one channel
    # survives; the previous .squeeze() produced a 0-d tensor in that case
    # and len() below raised TypeError (same fix as in linear_layer_pruning's
    # explicit unsqueeze guard).
    lambda_to_keep = torch.nonzero(lamb).reshape(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    # get num_features
    in_features = module.linear.in_features
    out_features = module.linear.out_features

    sparselinear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
    sparselinear.sparse_proj.weight.data = module.linear.weight.data[lambda_to_keep]
    sparselinear.sparse_proj.bias.data = module.linear.bias.data[lambda_to_keep]
    module.linear = sparselinear
    return module
413
+
414
+
415
def hard_concrete_distribution(
    p, beta: float = 0.83, eps: float = 1e-8, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False
):
    """Sample a hard-concrete relaxation of Bernoulli gates, one per entry of *p*.

    Draws logistic noise, shifts it by the (optionally log-transformed)
    parameters *p*, stretches the sigmoid output to [gamma, eta], and clips
    to [0, 1] so exact 0/1 gate values occur with non-zero probability.
    """
    noise = torch.rand(p.shape).to(p.device)
    if use_log:
        p = torch.log(torch.clamp(p, min=eps))
    logistic = torch.log(noise + eps) - torch.log(1 - noise + eps)
    stretched = torch.sigmoid((logistic + p) / beta) * (eta - gamma) + gamma
    return stretched.clamp(0, 1)
426
+
427
+
428
def l0_complexity_loss(alpha, beta: float = 0.83, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False):
    """Expected L0 penalty of hard-concrete gates with parameters *alpha*.

    Uses the closed-form probability that a stretched gate is non-zero:
    sigmoid(alpha - beta * log(-gamma / eta)), summed over all gates.
    (``use_log`` is accepted for signature parity but unused here.)
    """
    shift = beta * math.log(-gamma / eta)
    return torch.sigmoid(alpha - shift).sum()
432
+
433
+
434
def calculate_reg_loss(
    loss_reg,
    lambs: List[torch.Tensor],
    p: int,
    use_log: bool = False,
    mean=True,
    reg=True,  # regularize the lambda with bounded value range
    reg_alpha=0.4,  # alpha for the regularizer, avoid gradient vanishing
    reg_beta=1,  # beta for shifting the lambda toward positive value (avoid gradient vanishing)
):
    """Accumulate a sparsity regularization term over all mask tensors.

    p=0 uses the hard-concrete L0 surrogate; p=1 or 2 uses (optionally
    sigmoid-bounded, optionally length-normalized) Lp norms. The summed
    penalty is divided by the number of mask tensors and accumulated into
    the incoming *loss_reg* (in place when it is a tensor, via += / /=).

    Raises:
        NotImplementedError: For any other value of *p*.
    """
    if p not in (0, 1, 2):
        raise NotImplementedError

    if p == 0:
        for lamb in lambs:
            loss_reg += l0_complexity_loss(lamb, use_log=use_log)
    else:
        for lamb in lambs:
            # Bound each lambda into (0, 1) before taking the norm so large
            # magnitudes cannot dominate the penalty.
            bounded = torch.sigmoid(lamb * reg_alpha + reg_beta) if reg else lamb
            term = bounded.norm(p)
            if mean:
                term = term / len(bounded)
            loss_reg += term
    loss_reg /= len(lambs)
    return loss_reg