LWZ19 committed on
Commit
f5267ae
·
1 Parent(s): c516bef

Update code

Browse files
Files changed (4) hide show
  1. README.md +8 -6
  2. app.py +285 -0
  3. requirements.txt +12 -0
  4. utils.py +459 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: Ecodiff Flux Dev High
3
- emoji: 📊
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 6.0.1
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: EcoDiff Flux.1 [dev] + LoRA
3
+ emoji: 🖼️
4
+ colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
10
+ short_description: Text-to-image with a 25-30% pruned Flux model + LoRA
11
+ startup_duration_timeout: 3h
12
  ---
13
 
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import gradio as gr
3
+ import numpy as np
4
+ import random
5
+ import pickle
6
+ import torch
7
+ import os
8
+ import sys
9
+ import spaces
10
+ from huggingface_hub import hf_hub_download, snapshot_download
11
+ from diffusers import FluxPipeline
12
+ from diffusers.models import FluxTransformer2DModel
13
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
14
+ from diffusers.loaders.lora_base import LORA_WEIGHT_NAME_SAFE
15
+ from safetensors.torch import load_file
16
+
17
+ # Import essential classes for unpickling pruned models
18
+ from utils import SparsityLinear, SkipConnection, AttentionSkipConnection
19
+
20
# Stand-in module so pickled checkpoints from the original "sdib" package unpickle here
class MockModule:
    """Stand-in for the original ``sdib`` training package.

    The pruned models were pickled from a repository whose module path
    (``sdib``, ``sdib.utils``, ...) does not exist in this Space, so the
    unpickler needs an object that resolves those attribute lookups to
    the local pruning classes imported from ``utils``.
    """

    def __init__(self):
        # Classes the pickle payload resolves by attribute lookup.
        self.SparsityLinear = SparsityLinear
        self.SkipConnection = SkipConnection
        self.AttentionSkipConnection = AttentionSkipConnection
        # Point ``.utils`` back at this object so arbitrarily nested
        # dotted paths (e.g. ``sdib.utils.utils``) all resolve.
        self.utils = self
29
+
30
# Register the mock module for all sdib import paths
# (pickle resolves classes by module path; every alias maps to the same mock)
mock = MockModule()
sys.modules['sdib'] = mock
sys.modules['sdib.utils'] = mock
sys.modules['sdib.utils.utils'] = mock


################################################################################
################################################################################


# Configuration
PRUNING_RATIOS = [25, 30]

device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
dtype = torch.bfloat16

# Load the stock Flux.1 [dev] pipeline once; its transformer is swapped
# for a pruned one below, but the VAE/text encoders are reused as-is.
print("🚀 Loading base Flux dev pipeline...")
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=dtype
)
print("✅ Base Flux dev pipeline loaded!")

# Global storage for all models
# Maps pruning ratio (int percent) -> transformer module, or None on load failure.
pruned_models = {}

print("📥 Preloading all pruned models...")
for ratio in PRUNING_RATIOS:
    try:
        print(f"Loading {ratio}% pruned model...")
        model_file = hf_hub_download(
            repo_id="LWZ19/flux_prune",
            filename=f"dev/pruned_model_{ratio}.pkl"
        )

        # NOTE(review): pickle.load executes arbitrary code from the
        # downloaded checkpoint — acceptable only because the repo is
        # owned by the Space author; confirm this stays first-party.
        with open(model_file, "rb") as f:
            pruned_model = pickle.load(f)
        # nn.Module.to() is in-place; keep on CPU until generation time.
        pruned_model.to("cpu")
        pruned_model.to(dtype)

        pruned_models[ratio] = pruned_model
        print(f"✅ {ratio}% pruned model loaded!")
    except Exception as e:
        print(f"❌ Failed to load {ratio}% pruned model: {e}")
        pruned_models[ratio] = None

print("📥 Preloading all LoRA weights...")
for ratio in PRUNING_RATIOS:
    try:
        lora_repo_path = snapshot_download(
            repo_id="LWZ19/flux_retrain_weights",
            allow_patterns=[f"dev/lora/prune_{ratio}/*"]
        )
        lora_weights = load_file(os.path.join(lora_repo_path, "dev", "lora", f"prune_{ratio}", LORA_WEIGHT_NAME_SAFE))
        print("✅ LoRA checkpoint loaded!")

        # Temporarily set the pruned model as transformer
        # NOTE(review): if the matching pruned model failed to load above,
        # this assigns None and the LoRA calls fail into the except below
        # — presumably intentional best-effort; confirm.
        base_pipe.transformer = pruned_models[ratio]

        # Load and merge LoRA weights
        base_pipe.load_lora_weights(lora_weights)
        base_pipe.fuse_lora()
        base_pipe.unload_lora_weights()

        # Store the merged model back
        pruned_models[ratio] = base_pipe.transformer

        print(f"✅ LoRA merged with {ratio}% pruned model!")

    except Exception as e:
        print(f"❌ Failed to load LoRA checkpoint: {e}")

# Model state
# NOTE(review): if the 25% model failed to load, pruned_models[25] is None
# and .to(device) raises AttributeError at startup — confirm fail-fast is intended.
base_pipe.transformer = pruned_models[25].to(device)
current_ratio = 25
107
+
108
+
109
def load_model(ratio):
    """Activate the pruned transformer for the requested pruning ratio.

    Uses the module-level ``current_ratio`` so repeated calls with the
    same ratio are no-ops. Errors are reported in the returned status
    string rather than raised, matching the UI's status display.

    Args:
        ratio: Pruning ratio (percent) keying into ``pruned_models``.

    Returns:
        Human-readable status string (✅ on success, ❌ on failure).
    """
    global current_ratio

    try:
        if ratio != current_ratio:
            # Move the requested transformer onto the device and swap it in.
            base_pipe.transformer = pruned_models[ratio].to(device)
            current_ratio = ratio

        return f"✅ Ready with {ratio}% pruned Flux.1 [dev] + LoRA retrained"

    except Exception as e:
        return f"❌ Failed to apply weights: {str(e)}"
123
+
124
+
125
@spaces.GPU(duration=80)
def generate_image(
    ratio,
    prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected pruned Flux model.

    Args:
        ratio: Pruning ratio (percent) selecting the preloaded transformer.
        prompt: Text prompt for generation.
        seed: RNG seed; replaced by a random draw when ``randomize_seed``.
        randomize_seed: Whether to draw a fresh random seed first.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker linked to tqdm.

    Returns:
        Tuple of (image or None, seed actually used, status string).
        Failures are reported in the status string, never raised, so the
        UI always gets a well-formed triple.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        # Swap in the requested pruned transformer (no-op if unchanged).
        status = load_model(ratio)
        if "❌" in status:
            return None, seed, status

        # Move pipeline to GPU for generation
        base_pipe.to(device)

        generator = torch.Generator(device).manual_seed(seed)

        # Generate image using base pipeline (already configured with pruned model)
        image = base_pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
        ).images[0]

        # Release cached GPU memory between requests. (Fix: this was a
        # conditional *expression* used purely for its side effect.)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        result_status = f"✅ Generated with {ratio}% pruned Flux.1 [dev] + LoRA retrained"
        return image, seed, result_status

    except Exception as e:
        error_status = f"❌ Generation failed: {str(e)}\nPlease retry after a few minutes."
        return None, seed, error_status
170
+
171
# Example prompts surfaced under the input box via gr.Examples.
examples = [
    "A clock tower floating in a sea of clouds",
    "A cozy library with a roaring fireplace",
    "A cat playing football",
    "A magical forest with glowing mushrooms",
    "An astronaut riding a rainbow unicorn",
]

# Center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 720px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# EcoDiff Flux.1 [dev]: Memory-Efficient Diffusion")
        gr.Markdown("Generate images using pruned Flux.1 [dev] models with 25% and 30% pruning ratios, both LoRA retrained.")

        # Prompt input row.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

        # Pruning-ratio selector (values must match PRUNING_RATIOS keys).
        with gr.Row():
            ratio = gr.Dropdown(
                choices=PRUNING_RATIOS,
                value=25,
                label="Pruning Ratio (%)",
                info="Select pruning ratio",
                scale=1
            )

        generate_button = gr.Button("Generate", variant="primary")
        result = gr.Image(label="Result", show_label=False)
        status_display = gr.Textbox(label="Status", interactive=False)

        # Generation knobs, collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=2048,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=50,
                )

        gr.Examples(examples=examples, inputs=[prompt])

        gr.Markdown("""
        ### About EcoDiff Flux.1 [dev] Unified
        This space showcases pruned Flux.1 [dev] models using learnable pruning techniques with LoRA fine-tuning.

        - **Base Model**: Flux.1 [dev]
        - **Pruning Ratios**: 25% and 30% of parameters removed
        - **LoRA Enhancement**: Both models are retrained with LoRA weights for improved quality
        """)

    # Wire the button to the generator; seed is both an input and an
    # output so the randomized value is reflected back in the UI.
    generate_button.click(
        fn=generate_image,
        inputs=[
            ratio,
            prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed, status_display],
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ diffusers==0.34.0
4
+ transformers
5
+ accelerate
6
+ safetensors
7
+ sentencepiece
8
+ peft
9
+ huggingface_hub
10
+ pillow
11
+ numpy
12
+ tqdm
utils.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # all utiles functions
2
+ import math
3
+ from typing import List, Optional
4
+
5
+ import torch
6
+ from diffusers.models.activations import GEGLU, GELU
7
+
8
+
9
def get_total_params(model, trainable: bool = True):
    """Count parameters of *model* filtered by their ``requires_grad`` flag.

    Args:
        model: Any ``torch.nn.Module``.
        trainable: Count trainable parameters when True (default),
            frozen parameters when False.

    Returns:
        Total number of elements across the matching parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad == trainable:
            total += param.numel()
    return total
11
+
12
+
13
def get_precision(precision: str):
    """Map a precision string to the corresponding torch dtype.

    Fix: the original asserted ``precision in ["fp16", "fp32", "bf16"]``
    while also containing an ``elif precision == "fp64"`` branch, which
    was therefore unreachable. "fp64" is now accepted; the three
    original keys behave exactly as before.

    Args:
        precision: One of "fp16", "bf16", "fp32" or "fp64".

    Returns:
        The matching ``torch.dtype``.

    Raises:
        AssertionError: If ``precision`` is not a supported key.
    """
    dtype_map = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
        # Previously dead code: the assert rejected "fp64" before this was reached.
        "fp64": torch.float64,
    }
    assert precision in dtype_map, "precision must be either fp16, fp32, bf16, fp64"
    return dtype_map[precision]
24
+
25
+
26
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """Measure how many mask entries ("lambdas") on *hooker* are active.

    Binary masks are summed directly; continuous masks count entries
    that reach ``threshold``.

    Args:
        hooker: Object exposing ``lambs`` (list of 1-D tensors) and,
            optionally, a ``binary`` flag (absent on ff_hooks).
        threshold: Activation cutoff; required iff the mask is not binary.

    Returns:
        Tuple of (total lambdas, active lambdas, active/total ratio).
    """
    # ff_hooks carry no ``binary`` attribute; treat that as non-binary.
    is_binary = getattr(hooker, "binary", None)
    total = 0
    active = 0
    for mask in hooker.lambs:
        total += mask.size(0)
        if is_binary:
            assert threshold is None, "threshold should be None for binary mask"
            active += mask.sum().item()
        else:
            assert threshold is not None, "threshold must be provided for non-binary mask"
            active += (mask >= threshold).sum().item()
    return total, active, active / total
39
+
40
+
41
def linear_layer_masking(module, lamb):
    """Soft-mask an attention module's K/Q/V and output projections per head.

    Each head's rows of ``to_k``/``to_q``/``to_v`` (and the matching
    columns of ``to_out[0]``) are scaled in place by that head's mask
    value. No parameters are removed, so the module's shapes are
    unchanged — this supports gradual pruning during training.

    Args:
        module: Attention module with ``to_k``, ``to_q``, ``to_v``,
            ``to_out`` and a ``heads`` count.
        lamb: Per-head mask values.

    Returns:
        The same module with masked weights.
    """
    head_dim = module.to_k.in_features // module.heads

    def _head_slice(head):
        # Rows/columns owned by one attention head.
        return slice(head * head_dim, (head + 1) * head_dim)

    # Scale each head's output rows of the K/Q/V projections (and biases).
    for proj in (module.to_k, module.to_q, module.to_v):
        for head, mask_value in enumerate(lamb):
            proj.weight.data[_head_slice(head), :] *= mask_value
            if proj.bias is not None:
                proj.bias.data[_head_slice(head)] *= mask_value

    # Scale the matching input columns of the output projection.
    for head, mask_value in enumerate(lamb):
        module.to_out[0].weight.data[:, _head_slice(head)] *= mask_value
    return module
68
+
69
+
70
# Identity stand-in used wherever a layer is pruned away entirely.
class SkipConnection(torch.nn.Module):
    """Identity replacement for a fully pruned layer.

    Echoes the first positional input back unchanged so the surrounding
    forward pass keeps working after the real layer is removed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *inputs, **kwargs):
        # Return the first positional argument (the hidden states).
        return inputs[0]
83
+
84
+
85
class AttentionSkipConnection(torch.nn.Module):
    """Identity replacement for a fully pruned attention block.

    SD3/FLUX attention blocks can return a (hidden, encoder_hidden)
    pair, while other architectures return a single tensor; this module
    mirrors whichever pattern ``model_type`` requires.

    Args:
        model_type: Architecture tag ("sd3", "flux", "flux_dev", ...).
    """

    # Architectures whose attention returns a pair when encoder states are given.
    _DUAL_OUTPUT_TYPES = ("sd3", "flux", "flux_dev")

    def __init__(self, model_type):
        super().__init__()
        self.model_type = model_type

    def forward(self, hidden_states=None, encoder_hidden_states=None, *args, **kwargs):
        dual = self.model_type in self._DUAL_OUTPUT_TYPES
        if dual and encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        return hidden_states
109
+
110
+
111
def linear_layer_pruning(module, lamb, model_type):
    """
    Physically prune attention layers by removing parameters for pruned heads.

    This function performs structural pruning through the following detailed steps:

    1. **Input Processing**: Latent features are fed into linear modules (to_k, to_q, to_v)
       with shape (cross_attn_dim, inner_kv_dim / inner_dim)

    2. **Head Division**: Inner features are divided into attention heads, where:
       - Query shape: [B, N, H, D] (batch, sequence, heads, head_dim)
       - New hidden dimension = inner_dim * (unmasked_heads / total_heads)
       - K, Q, V projections have shape [cross_attn_dim, inner_kv_dim / inner_dim]
       - Each head occupies (heads * inner_dim) rows in the weight matrix
       - **Important**: Input channels remain unchanged, only output rows are pruned

    3. **Attention Computation**: Updated latent features after scaled dot-product attention

    4. **Output Projection**: Final projection layer (to_out) from pruned inner_dim to original latent_dim
       - Pruned dimension changes from input (dim=0) to output (dim=1)
       - **Critical**: Output channels remain unchanged to maintain model compatibility

    Args:
        module: Attention module to prune (contains to_k, to_q, to_v, to_out)
        lamb: Learned mask values per attention head (1=keep, 0=prune)
        model_type: Model architecture type for skip connection handling

    Returns:
        module: Pruned attention module or AttentionSkipConnection if fully pruned

    Note:
        - Supports additional projections (add_k_proj, add_q_proj, add_v_proj) for certain architectures
        - Handles both to_out and to_add_out projection layers
        - Updates all relevant module parameters (inner_dim, query_dim, heads, etc.)
    """

    heads_to_keep = torch.nonzero(lamb).squeeze()
    if len(heads_to_keep.shape) == 0:
        # if only one head is kept, or none
        heads_to_keep = heads_to_keep.unsqueeze(0)

    modules_to_remove = [module.to_k, module.to_q, module.to_v]

    # Some architectures carry parallel "added" K/Q/V projections; prune them too.
    if getattr(module, "add_k_proj", None) is not None:
        modules_to_remove.extend([module.add_k_proj, module.add_q_proj, module.add_v_proj])

    # NOTE(review): lamb.sum() as a head count assumes lamb is a 0/1 mask
    # — confirm upstream binarization before relying on this.
    new_heads = int(lamb.sum().item())

    if new_heads == 0:
        return AttentionSkipConnection(model_type=model_type)

    for module_to_remove in modules_to_remove:
        # get head dimension
        inner_dim = module_to_remove.out_features // module.heads
        # place holder for the rows to keep
        rows_to_keep = torch.zeros(
            module_to_remove.out_features, dtype=torch.bool, device=module_to_remove.weight.device
        )

        for idx in heads_to_keep:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True

        # overwrite the inner projection with masked projection
        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(sum(rows_to_keep).item())

    # Also update the output projection layer if available, (for FLUXSingleAttnProcessor2_0)
    # with column masking, dim 1
    # NOTE(review): rows_to_keep below is the mask left over from the LAST
    # projection in modules_to_remove; this assumes every pruned projection
    # shares the same out_features layout — confirm for add_*_proj variants.
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(sum(rows_to_keep).item())

    if getattr(module, "to_add_out", None) is not None:
        module.to_add_out.weight.data = module.to_add_out.weight.data[:, rows_to_keep]
        module.to_add_out.in_features = int(sum(rows_to_keep).item())

    # update parameters in the attention module
    # NOTE(review): scaling cross_attention_dim (an *input* dimension) by the
    # kept-head ratio looks inconsistent with "input channels remain unchanged"
    # above — verify against diffusers' Attention internals.
    module.inner_dim = module.inner_dim // module.heads * new_heads
    module.query_dim = module.query_dim // module.heads * new_heads
    module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module
196
+
197
+
198
def update_flux_single_transformer_projection(parent_module, module, lamb, old_inner_dim):
    """
    Updates the proj_out module in a FluxSingleTransformerBlock after attention head pruning.

    FLUX models use a proj_out layer that takes concatenated input from both attention output
    and MLP hidden states: torch.cat([attn_output, mlp_hidden_states], dim=2). When attention
    heads are pruned, the attention dimension changes but the MLP dimension remains constant,
    requiring careful weight matrix reconstruction.

    Args:
        parent_module: FluxSingleTransformerBlock containing the proj_out layer
        module: Pruned attention module (or AttentionSkipConnection)
        lamb: Original mask values used for pruning decisions
        old_inner_dim: Original attention inner dimension before pruning

    Returns:
        parent_module: Updated parent module with corrected proj_out dimensions

    Note:
        - Handles skip connections when module is completely pruned
        - Preserves MLP weights while updating attention weights
        - Only modifies proj_out if dimensions actually changed
    """
    # Handle Skip Connection case (when module is completely pruned)
    if isinstance(module, AttentionSkipConnection):
        return parent_module

    if hasattr(parent_module, "proj_out"):
        # Calculate how much the attention dimension changed
        attention_dim_change = old_inner_dim - module.inner_dim

        if attention_dim_change > 0:  # Only update if dimensions actually changed
            # Get current weight matrix and dimensions
            old_weight = parent_module.proj_out.weight.data
            old_in_features = parent_module.proj_out.in_features

            # Calculate new input dimension
            new_in_features = old_in_features - attention_dim_change

            # Create new weight matrix
            new_weight = torch.zeros(
                old_weight.shape[0], new_in_features,
                device=old_weight.device, dtype=old_weight.dtype
            )

            # Calculate head dimensions
            old_head_dim = old_inner_dim // lamb.shape[0]

            # Create mask for attention columns to keep
            heads_to_keep = torch.nonzero(lamb).squeeze()
            if len(heads_to_keep.shape) == 0:
                heads_to_keep = heads_to_keep.unsqueeze(0)

            attn_cols_to_keep = torch.zeros(old_inner_dim, dtype=torch.bool, device=old_weight.device)
            for idx in heads_to_keep:
                attn_cols_to_keep[idx * old_head_dim : (idx + 1) * old_head_dim] = True

            # Copy weights for kept attention heads
            # NOTE(review): the per-column Python loop is O(inner_dim) tensor ops;
            # indexing once with kept_indices would be equivalent — confirm before
            # optimizing. The `i < module.inner_dim` guard presumably protects
            # against a mismatch between kept columns and the pruned inner_dim —
            # TODO confirm they always agree.
            kept_indices = torch.nonzero(attn_cols_to_keep).squeeze()
            for i, idx in enumerate(kept_indices):
                if i < module.inner_dim:
                    new_weight[:, i] = old_weight[:, idx]

            # Copy MLP weights (unchanged part)
            mlp_start = old_inner_dim
            if mlp_start < old_in_features:  # Ensure there's actually an MLP part
                new_weight[:, module.inner_dim:] = old_weight[:, mlp_start:]

            # Update the projection layer
            parent_module.proj_out.weight.data = new_weight
            parent_module.proj_out.in_features = new_in_features
    return parent_module
270
+
271
+
272
def ffn_linear_layer_pruning(module, lamb):
    """
    Prunes feed-forward network layers based on learned masks.

    Supports two FFN layouts:
      * standard ``module.net`` blocks whose first layer is a GELU or
        GEGLU projection followed by the down-projection at ``net[2]``;
      * FluxSingleTransformerBlock-style modules exposing ``proj_mlp``
        plus a ``proj_out`` over concatenated attention+MLP input.

    Bug fix: ``torch.nonzero(lamb).squeeze()`` collapses to a 0-d tensor
    when exactly one channel survives, so ``len()`` raised TypeError;
    ``reshape(-1)`` always yields a 1-D index tensor.

    Note: This function could potentially be merged with linear_layer_pruning
    for better code organization in future refactoring.

    Args:
        module: FFN module to prune
        lamb: Learned mask values for pruning decisions

    Returns:
        Pruned module or SkipConnection if fully pruned
    """
    # Keep indices 1-D even when 0 or 1 channels survive (see docstring).
    lambda_to_keep = torch.nonzero(lamb).reshape(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    if hasattr(module, "net") and len(module.net) >= 3:
        # Standard FFN blocks
        if isinstance(module.net[0], GELU):
            # linear layer weight remove before activation
            module.net[0].proj.weight.data = module.net[0].proj.weight.data[lambda_to_keep, :]
            module.net[0].proj.out_features = num_lambda
            if module.net[0].proj.bias is not None:
                module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]

            # Rebuild the activation wrapper around the pruned projection.
            update_act = GELU(module.net[0].proj.in_features, num_lambda)
            update_act.proj = module.net[0].proj
            module.net[0] = update_act
        elif isinstance(module.net[0], GEGLU):
            # GEGLU packs value and gate halves into one weight; prune both halves.
            output_feature = module.net[0].proj.out_features
            module.net[0].proj.weight.data = torch.cat(
                [
                    module.net[0].proj.weight.data[: output_feature // 2, :][lambda_to_keep, :],
                    module.net[0].proj.weight.data[output_feature // 2 :][lambda_to_keep, :],
                ],
                dim=0,
            )
            module.net[0].proj.out_features = num_lambda * 2
            if module.net[0].proj.bias is not None:
                module.net[0].proj.bias.data = torch.cat(
                    [
                        module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
                        module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
                    ]
                )

            update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
            update_act.proj = module.net[0].proj
            module.net[0] = update_act

        # proj weight after activation: drop the matching input columns.
        module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
        module.net[2].in_features = num_lambda

    elif hasattr(module, "proj_mlp") and hasattr(module, "proj_out"):
        # FFN For FluxSingleTransformerBlock
        module.proj_mlp.weight.data = module.proj_mlp.weight.data[lambda_to_keep, :]
        module.proj_mlp.out_features = num_lambda
        if module.proj_mlp.bias is not None:
            module.proj_mlp.bias.data = module.proj_mlp.bias.data[lambda_to_keep]

        # Update mlp_hidden_dim to reflect the new size
        old_mlp_hidden_dim = module.mlp_hidden_dim
        module.mlp_hidden_dim = num_lambda

        # The proj_out layer takes concatenated input from both attention output and MLP output
        # We need to keep the attention part unchanged but update the MLP part
        old_dim = module.proj_out.in_features
        attn_dim = old_dim - old_mlp_hidden_dim  # Attention dimension
        new_in_features = attn_dim + num_lambda

        new_weight = torch.zeros(
            module.proj_out.weight.shape[0], new_in_features,
            device=module.proj_out.weight.device, dtype=module.proj_out.weight.dtype
        )

        # Copy attention part (unchanged)
        new_weight[:, :attn_dim] = module.proj_out.weight.data[:, :attn_dim]

        # Copy selected MLP parts
        for i, idx in enumerate(lambda_to_keep):
            new_weight[:, attn_dim + i] = module.proj_out.weight.data[:, attn_dim + idx]

        # Update the projection layer
        module.proj_out.weight.data = new_weight
        module.proj_out.in_features = new_in_features

    return module
364
+
365
+
366
# create SparsityLinear module
class SparsityLinear(torch.nn.Module):
    """
    Sparse linear layer that maintains original output dimensions.

    This layer projects to a smaller intermediate dimension then expands
    back to the original size, placing values only at specified indices.
    Used for normalization layer pruning where output dimensions must match.

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension (original size)
        lambda_to_keep: Indices of features to keep active
        num_lambda: Number of active features (len(lambda_to_keep))
    """
    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super(SparsityLinear, self).__init__()
        self.sparse_proj = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep

    def forward(self, x):
        """Project to the kept channels, then scatter into zeros of the
        original output width.

        Generalized: the original built the output with ``x.size(0)`` and
        ``output[:, ...]``, so it only handled 2-D input; using ``...``
        indexing supports any number of leading (batch/sequence) dims
        while behaving identically for the 2-D case.
        """
        x = self.sparse_proj(x)
        output = torch.zeros(*x.shape[:-1], self.out_features, device=x.device, dtype=x.dtype)
        output[..., self.lambda_to_keep] = x
        return output
392
+
393
+
394
def norm_layer_pruning(module, lamb):
    """
    Pruning the layer normalization layer for FLUX model.

    Replaces ``module.linear`` with a SparsityLinear that computes only
    the kept output channels and scatters them back to the original
    width, so downstream consumers see an unchanged output shape.

    Bug fix: ``torch.nonzero(lamb).squeeze()`` collapses to a 0-d tensor
    when exactly one channel survives, making ``len()`` raise TypeError;
    ``reshape(-1)`` always yields a 1-D index tensor.

    Args:
        module: Module owning a ``linear`` projection (adaptive norm block).
        lamb: Learned mask values; zero entries are pruned.

    Returns:
        The module with ``linear`` replaced, or SkipConnection if fully pruned.
    """
    # Keep indices 1-D even when 0 or 1 channels survive (see docstring).
    lambda_to_keep = torch.nonzero(lamb).reshape(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    # get num_features of the original projection
    in_features = module.linear.in_features
    out_features = module.linear.out_features

    # Copy only the surviving rows into the compact projection.
    sparselinear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
    sparselinear.sparse_proj.weight.data = module.linear.weight.data[lambda_to_keep]
    sparselinear.sparse_proj.bias.data = module.linear.bias.data[lambda_to_keep]
    module.linear = sparselinear
    return module
413
+
414
+
415
+ def hard_concrete_distribution(
416
+ p, beta: float = 0.83, eps: float = 1e-8, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False
417
+ ):
418
+ u = torch.rand(p.shape).to(p.device)
419
+ if use_log:
420
+ p = torch.clamp(p, min=eps)
421
+ p = torch.log(p)
422
+ s = torch.sigmoid((torch.log(u + eps) - torch.log(1 - u + eps) + p) / beta)
423
+ s = s * (eta - gamma) + gamma
424
+ s = s.clamp(0, 1)
425
+ return s
426
+
427
+
428
def l0_complexity_loss(alpha, beta: float = 0.83, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False):
    """Expected L0 penalty of hard-concrete gates parameterized by ``alpha``.

    Each gate is open with probability sigmoid(alpha - beta * log(-gamma / eta));
    summing these probabilities gives the expected number of active gates.

    Note: ``use_log`` is accepted for signature parity with
    ``hard_concrete_distribution`` but is not used here.
    """
    offset = beta * math.log(-gamma / eta)
    open_probability = torch.sigmoid(alpha - offset)
    return open_probability.sum()
432
+
433
+
434
def calculate_reg_loss(
    loss_reg,
    lambs: List[torch.Tensor],
    p: int,
    use_log: bool = False,
    mean=True,
    reg=True,  # regularize the lambda with bounded value range
    reg_alpha=0.4,  # alpha for the regularizer, avoid gradient vanishing
    reg_beta=1,  # beta for shifting the lambda toward positive value (avoid gradient vanishing)
):
    """Accumulate a sparsity regularizer over all mask tensors.

    Args:
        loss_reg: Running regularization loss to accumulate into.
        lambs: Mask tensors, one per hooked layer.
        p: Norm order — 0 uses the L0 complexity surrogate, 1/2 use Lp norms.
        use_log: Forwarded to the L0 surrogate when ``p == 0``.
        mean: Average each tensor's norm by its length before summing.
        reg: Squash masks through a shifted/scaled sigmoid first.
        reg_alpha: Sigmoid input scale (keeps gradients alive).
        reg_beta: Sigmoid input shift toward positive values.

    Returns:
        The accumulated loss, averaged over the number of mask tensors.

    Raises:
        NotImplementedError: For any ``p`` outside {0, 1, 2}.
    """
    if p == 0:
        for mask in lambs:
            loss_reg += l0_complexity_loss(mask, use_log=use_log)
        loss_reg /= len(lambs)
    elif p in (1, 2):
        for mask in lambs:
            if reg:
                # Bounded reparameterization of the raw mask values.
                mask = torch.sigmoid(mask * reg_alpha + reg_beta)
            term = mask.norm(p)
            if mean:
                term = term / len(mask)
            loss_reg += term
        loss_reg /= len(lambs)
    else:
        raise NotImplementedError
    return loss_reg