# ecodiff_flux_dev/utils.py
# all utility functions
import math
from typing import List, Optional
import torch
from diffusers.models.activations import GEGLU, GELU
def get_total_params(model, trainable: bool = True):
return sum(p.numel() for p in model.parameters() if p.requires_grad == trainable)
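
# Illustrative usage sketch (hypothetical helper added for clarity, not part of the
# original module): count trainable vs. frozen parameters of a tiny layer.
def _example_get_total_params():
    layer = torch.nn.Linear(4, 2)
    layer.bias.requires_grad_(False)
    assert get_total_params(layer, trainable=True) == 8  # weight only
    assert get_total_params(layer, trainable=False) == 2  # frozen bias
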
def get_precision(precision: str):
    assert precision in ["fp16", "fp32", "bf16", "fp64"], "precision must be one of fp16, fp32, bf16, fp64"
if precision == "fp16":
torch_dtype = torch.float16
elif precision == "bf16":
torch_dtype = torch.bfloat16
elif precision == "fp32":
torch_dtype = torch.float32
elif precision == "fp64":
torch_dtype = torch.float64
return torch_dtype
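
# Illustrative usage sketch (hypothetical helper, not part of the original module):
# resolve a CLI-style precision flag to a torch dtype.
def _example_get_precision():
    assert get_precision("bf16") is torch.bfloat16
    assert get_precision("fp16") is torch.float16
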
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
total_num_lambs = 0
num_activate_lambs = 0
binary = getattr(hooker, "binary", None) # if binary is not present, it will return None for ff_hooks
for lamb in hooker.lambs:
total_num_lambs += lamb.size(0)
if binary:
assert threshold is None, "threshold should be None for binary mask"
num_activate_lambs += lamb.sum().item()
else:
assert threshold is not None, "threshold must be provided for non-binary mask"
num_activate_lambs += (lamb >= threshold).sum().item()
return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs
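
# Illustrative usage sketch (hypothetical stand-in; the real hook objects come from the
# EcoDiff training code and are assumed to expose `lambs` and, optionally, `binary`).
def _example_calculate_mask_sparsity():
    class _ToyHooker:
        def __init__(self, lambs, binary=None):
            self.lambs = lambs
            self.binary = binary

    hooker = _ToyHooker([torch.tensor([0.9, 0.1, 0.7]), torch.tensor([0.2, 0.8])])
    total, active, ratio = calculate_mask_sparsity(hooker, threshold=0.5)
    assert (total, active) == (5, 3) and abs(ratio - 0.6) < 1e-6
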
def linear_layer_masking(module, lamb):
"""
Apply soft masking to attention layer weights (K, Q, V projections).
This function multiplies attention layer weights by mask values without
removing parameters, allowing for gradual pruning during training.
Args:
module: Attention module containing to_k, to_q, to_v, and to_out
lamb: Per-head mask values to apply
Returns:
module: Modified module with masked weights
"""
    # apply the per-head mask to the K, Q, V projection weight rows (soft masking; no parameters are removed)
inner_dim = module.to_k.in_features // module.heads
modules_to_remove = [module.to_k, module.to_q, module.to_v]
for module_to_remove in modules_to_remove:
for idx, head_mask in enumerate(lamb):
module_to_remove.weight.data[idx * inner_dim : (idx + 1) * inner_dim, :] *= head_mask
if module_to_remove.bias is not None:
module_to_remove.bias.data[idx * inner_dim : (idx + 1) * inner_dim] *= head_mask
# perform masking on the output
for idx, head_mask in enumerate(lamb):
module.to_out[0].weight.data[:, idx * inner_dim : (idx + 1) * inner_dim] *= head_mask
return module
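
# Illustrative usage sketch (hypothetical stand-in, not a real diffusers Attention module):
# soft-mask head 1 of 2 without removing any parameters.
def _example_linear_layer_masking():
    class _ToyAttention(torch.nn.Module):
        def __init__(self, dim=8, heads=2):
            super().__init__()
            self.heads = heads
            self.to_q = torch.nn.Linear(dim, dim)
            self.to_k = torch.nn.Linear(dim, dim)
            self.to_v = torch.nn.Linear(dim, dim)
            self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim, dim), torch.nn.Dropout(0.0)])

    attn = linear_layer_masking(_ToyAttention(), lamb=torch.tensor([1.0, 0.0]))
    assert torch.all(attn.to_q.weight[4:] == 0)  # rows of the masked head are zeroed
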
# create dummy module for skip connection
class SkipConnection(torch.nn.Module):
"""
Skip connection module for completely pruned layers.
When a layer is fully pruned, this module replaces it and simply
returns the input unchanged, maintaining the model's forward pass.
"""
def __init__(self):
super(SkipConnection, self).__init__()
    def forward(*args, **kwargs):
        # `self` is intentionally folded into *args: args[0] is the module itself,
        # so args[1] is the first positional input, which is returned unchanged
        return args[1]
class AttentionSkipConnection(torch.nn.Module):
"""
Model-specific skip connection for attention layers.
Handles different return patterns based on model architecture:
- SD3/FLUX models may return multiple values
- Other models return single hidden states
Args:
model_type: Type of diffusion model ("sd3", "flux", "flux_dev", etc.)
"""
def __init__(self, model_type):
super(AttentionSkipConnection, self).__init__()
self.model_type = model_type
def forward(self, hidden_states=None, encoder_hidden_states=None, *args, **kwargs):
        # Non-SD3/FLUX blocks expect a single tensor back; SD3/FLUX blocks may also expect the encoder hidden states
if self.model_type not in ["sd3", "flux", "flux_dev"]:
return hidden_states
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
return hidden_states
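
# Illustrative usage sketch (hypothetical): a fully pruned FLUX attention block simply
# passes its inputs through, preserving the (hidden_states, encoder_hidden_states) pattern.
def _example_attention_skip_connection():
    skip = AttentionSkipConnection(model_type="flux_dev")
    h, enc = torch.randn(1, 4, 8), torch.randn(1, 2, 8)
    out_h, out_enc = skip(hidden_states=h, encoder_hidden_states=enc)
    assert out_h is h and out_enc is enc
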
def linear_layer_pruning(module, lamb, model_type):
"""
Physically prune attention layers by removing parameters for pruned heads.
    This function performs structural pruning through the following steps:
    1. **Input Processing**: Latent features are fed into the linear projections (to_k, to_q, to_v),
       which map cross_attn_dim to inner_kv_dim / inner_dim
    2. **Head Division**: The projection outputs are split into attention heads, where:
       - Query shape: [B, N, H, D] (batch, sequence, heads, head_dim)
       - New hidden dimension = inner_dim * (unmasked_heads / total_heads)
       - K, Q, V weight matrices have shape [inner_kv_dim / inner_dim, cross_attn_dim]
       - Each head occupies head_dim = inner_dim / heads consecutive rows of the weight matrix
       - **Important**: Input channels remain unchanged; only output rows are pruned
    3. **Attention Computation**: Latent features are updated by scaled dot-product attention
       over the remaining heads
    4. **Output Projection**: The final projection (to_out) maps the pruned inner_dim back to the
       original latent_dim
       - Here the pruned dimension is the input (columns, dim=1) rather than the output (rows, dim=0)
       - **Critical**: Output channels remain unchanged to maintain model compatibility
Args:
module: Attention module to prune (contains to_k, to_q, to_v, to_out)
lamb: Learned mask values per attention head (1=keep, 0=prune)
model_type: Model architecture type for skip connection handling
Returns:
module: Pruned attention module or AttentionSkipConnection if fully pruned
Note:
- Supports additional projections (add_k_proj, add_q_proj, add_v_proj) for certain architectures
- Handles both to_out and to_add_out projection layers
- Updates all relevant module parameters (inner_dim, query_dim, heads, etc.)
"""
heads_to_keep = torch.nonzero(lamb).squeeze()
if len(heads_to_keep.shape) == 0:
        # nonzero().squeeze() yields a 0-dim tensor when exactly one head is kept
heads_to_keep = heads_to_keep.unsqueeze(0)
modules_to_remove = [module.to_k, module.to_q, module.to_v]
if getattr(module, "add_k_proj", None) is not None:
modules_to_remove.extend([module.add_k_proj, module.add_q_proj, module.add_v_proj])
new_heads = int(lamb.sum().item())
if new_heads == 0:
return AttentionSkipConnection(model_type=model_type)
for module_to_remove in modules_to_remove:
# get head dimension
inner_dim = module_to_remove.out_features // module.heads
        # placeholder for the rows to keep
rows_to_keep = torch.zeros(
module_to_remove.out_features, dtype=torch.bool, device=module_to_remove.weight.device
)
for idx in heads_to_keep:
rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True
# overwrite the inner projection with masked projection
module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
if module_to_remove.bias is not None:
module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
module_to_remove.out_features = int(sum(rows_to_keep).item())
    # Also update the output projection layer, if present (e.g. for FLUXSingleAttnProcessor2_0),
    # by masking its columns (dim 1)
if getattr(module, "to_out", None) is not None:
module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
module.to_out[0].in_features = int(sum(rows_to_keep).item())
if getattr(module, "to_add_out", None) is not None:
module.to_add_out.weight.data = module.to_add_out.weight.data[:, rows_to_keep]
module.to_add_out.in_features = int(sum(rows_to_keep).item())
# update parameters in the attention module
module.inner_dim = module.inner_dim // module.heads * new_heads
module.query_dim = module.query_dim // module.heads * new_heads
module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
module.heads = new_heads
return module
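
# Illustrative usage sketch (hypothetical stand-in, not a real diffusers Attention module):
# it only exposes the attributes this helper touches, so the row/column pruning can be
# demonstrated in isolation.
def _example_linear_layer_pruning():
    class _ToyAttention(torch.nn.Module):
        def __init__(self, query_dim=16, heads=4):
            super().__init__()
            self.heads = heads
            self.inner_dim = query_dim
            self.inner_kv_dim = query_dim
            self.query_dim = query_dim
            self.cross_attention_dim = query_dim
            self.to_q = torch.nn.Linear(query_dim, query_dim)
            self.to_k = torch.nn.Linear(query_dim, query_dim)
            self.to_v = torch.nn.Linear(query_dim, query_dim)
            self.to_out = torch.nn.ModuleList([torch.nn.Linear(query_dim, query_dim), torch.nn.Dropout(0.0)])

    lamb = torch.tensor([1.0, 0.0, 1.0, 1.0])  # drop head 1 of 4
    attn = linear_layer_pruning(_ToyAttention(), lamb, model_type="flux_dev")
    assert attn.heads == 3 and attn.to_q.weight.shape[0] == 12  # 3 heads * head_dim 4
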
def update_flux_single_transformer_projection(parent_module, module, lamb, old_inner_dim):
"""
Updates the proj_out module in a FluxSingleTransformerBlock after attention head pruning.
FLUX models use a proj_out layer that takes concatenated input from both attention output
and MLP hidden states: torch.cat([attn_output, mlp_hidden_states], dim=2). When attention
heads are pruned, the attention dimension changes but the MLP dimension remains constant,
requiring careful weight matrix reconstruction.
Args:
parent_module: FluxSingleTransformerBlock containing the proj_out layer
module: Pruned attention module (or AttentionSkipConnection)
lamb: Original mask values used for pruning decisions
old_inner_dim: Original attention inner dimension before pruning
Returns:
parent_module: Updated parent module with corrected proj_out dimensions
Note:
- Handles skip connections when module is completely pruned
- Preserves MLP weights while updating attention weights
- Only modifies proj_out if dimensions actually changed
"""
# Handle Skip Connection case (when module is completely pruned)
if isinstance(module, AttentionSkipConnection):
return parent_module
if hasattr(parent_module, "proj_out"):
# Calculate how much the attention dimension changed
attention_dim_change = old_inner_dim - module.inner_dim
if attention_dim_change > 0: # Only update if dimensions actually changed
# Get current weight matrix and dimensions
old_weight = parent_module.proj_out.weight.data
old_in_features = parent_module.proj_out.in_features
# Calculate new input dimension
new_in_features = old_in_features - attention_dim_change
# Create new weight matrix
new_weight = torch.zeros(
old_weight.shape[0], new_in_features,
device=old_weight.device, dtype=old_weight.dtype
)
# Calculate head dimensions
old_head_dim = old_inner_dim // lamb.shape[0]
# Create mask for attention columns to keep
heads_to_keep = torch.nonzero(lamb).squeeze()
if len(heads_to_keep.shape) == 0:
heads_to_keep = heads_to_keep.unsqueeze(0)
attn_cols_to_keep = torch.zeros(old_inner_dim, dtype=torch.bool, device=old_weight.device)
for idx in heads_to_keep:
attn_cols_to_keep[idx * old_head_dim : (idx + 1) * old_head_dim] = True
# Copy weights for kept attention heads
kept_indices = torch.nonzero(attn_cols_to_keep).squeeze()
for i, idx in enumerate(kept_indices):
if i < module.inner_dim:
new_weight[:, i] = old_weight[:, idx]
# Copy MLP weights (unchanged part)
mlp_start = old_inner_dim
if mlp_start < old_in_features: # Ensure there's actually an MLP part
new_weight[:, module.inner_dim:] = old_weight[:, mlp_start:]
# Update the projection layer
parent_module.proj_out.weight.data = new_weight
parent_module.proj_out.in_features = new_in_features
return parent_module
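
# Illustrative usage sketch (hypothetical stand-ins, not real FLUX modules): a block whose
# proj_out consumes the concatenation of an 8-dim attention output and a 16-dim MLP output,
# after the attention side has been pruned from 4 heads down to 3 (inner_dim 8 -> 6).
def _example_update_flux_single_transformer_projection():
    class _ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj_out = torch.nn.Linear(8 + 16, 8)

    class _ToyPrunedAttn(torch.nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.inner_dim = inner_dim

    lamb = torch.tensor([1.0, 0.0, 1.0, 1.0])  # 4 heads, head_dim 2, keep 3
    block = update_flux_single_transformer_projection(
        _ToyBlock(), _ToyPrunedAttn(inner_dim=6), lamb, old_inner_dim=8
    )
    assert block.proj_out.in_features == 6 + 16
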
def ffn_linear_layer_pruning(module, lamb):
"""
Prunes feed-forward network layers based on learned masks.
Note: This function could potentially be merged with linear_layer_pruning
for better code organization in future refactoring.
Args:
module: FFN module to prune
lamb: Learned mask values for pruning decisions
Returns:
Pruned module or SkipConnection if fully pruned
"""
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    if lambda_to_keep.dim() == 0:
        # a single kept channel squeezes to a 0-dim tensor; make it 1-dim so len() works
        lambda_to_keep = lambda_to_keep.unsqueeze(0)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
num_lambda = len(lambda_to_keep)
if hasattr(module, "net") and len(module.net) >= 3:
# Standard FFN blocks
if isinstance(module.net[0], GELU):
# linear layer weight remove before activation
module.net[0].proj.weight.data = module.net[0].proj.weight.data[lambda_to_keep, :]
module.net[0].proj.out_features = num_lambda
if module.net[0].proj.bias is not None:
module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]
update_act = GELU(module.net[0].proj.in_features, num_lambda)
update_act.proj = module.net[0].proj
module.net[0] = update_act
elif isinstance(module.net[0], GEGLU):
output_feature = module.net[0].proj.out_features
module.net[0].proj.weight.data = torch.cat(
[
module.net[0].proj.weight.data[: output_feature // 2, :][lambda_to_keep, :],
module.net[0].proj.weight.data[output_feature // 2 :][lambda_to_keep, :],
],
dim=0,
)
module.net[0].proj.out_features = num_lambda * 2
if module.net[0].proj.bias is not None:
module.net[0].proj.bias.data = torch.cat(
[
module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
]
)
update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
update_act.proj = module.net[0].proj
module.net[0] = update_act
# proj weight after activation
module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
module.net[2].in_features = num_lambda
elif hasattr(module, "proj_mlp") and hasattr(module, "proj_out"):
# FFN For FluxSingleTransformerBlock
module.proj_mlp.weight.data = module.proj_mlp.weight.data[lambda_to_keep, :]
module.proj_mlp.out_features = num_lambda
if module.proj_mlp.bias is not None:
module.proj_mlp.bias.data = module.proj_mlp.bias.data[lambda_to_keep]
# Update mlp_hidden_dim to reflect the new size
old_mlp_hidden_dim = module.mlp_hidden_dim
module.mlp_hidden_dim = num_lambda
# The proj_out layer takes concatenated input from both attention output and MLP output
# We need to keep the attention part unchanged but update the MLP part
old_dim = module.proj_out.in_features
attn_dim = old_dim - old_mlp_hidden_dim # Attention dimension
new_in_features = attn_dim + num_lambda
new_weight = torch.zeros(
module.proj_out.weight.shape[0], new_in_features,
device=module.proj_out.weight.device, dtype=module.proj_out.weight.dtype
)
# Copy attention part (unchanged)
new_weight[:, :attn_dim] = module.proj_out.weight.data[:, :attn_dim]
# Copy selected MLP parts
for i, idx in enumerate(lambda_to_keep):
new_weight[:, attn_dim + i] = module.proj_out.weight.data[:, attn_dim + idx]
# Update the projection layer
module.proj_out.weight.data = new_weight
module.proj_out.in_features = new_in_features
return module
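
# Illustrative usage sketch (hypothetical stand-in, not a real diffusers FeedForward):
# the toy FFN follows the [activation, dropout, linear] layout of `net` assumed above.
def _example_ffn_linear_layer_pruning():
    class _ToyFFN(torch.nn.Module):
        def __init__(self, dim=8, hidden=16):
            super().__init__()
            self.net = torch.nn.ModuleList(
                [GELU(dim, hidden), torch.nn.Dropout(0.0), torch.nn.Linear(hidden, dim)]
            )

    lamb = torch.zeros(16)
    lamb[:6] = 1.0  # keep 6 of 16 hidden channels
    ffn = ffn_linear_layer_pruning(_ToyFFN(), lamb)
    assert ffn.net[0].proj.out_features == 6 and ffn.net[2].in_features == 6
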
# create SparsityLinear module
class SparsityLinear(torch.nn.Module):
"""
Sparse linear layer that maintains original output dimensions.
This layer projects to a smaller intermediate dimension then expands
back to the original size, placing values only at specified indices.
Used for normalization layer pruning where output dimensions must match.
Args:
in_features: Input feature dimension
out_features: Output feature dimension (original size)
lambda_to_keep: Indices of features to keep active
num_lambda: Number of active features (len(lambda_to_keep))
"""
def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
super(SparsityLinear, self).__init__()
self.sparse_proj = torch.nn.Linear(in_features, num_lambda)
self.out_features = out_features
self.lambda_to_keep = lambda_to_keep
def forward(self, x):
x = self.sparse_proj(x)
output = torch.zeros(x.size(0), self.out_features, device=x.device, dtype=x.dtype)
output[:, self.lambda_to_keep] = x
return output
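
# Illustrative usage sketch (hypothetical): project to 2 active features, then scatter them
# back into a 6-dim output so downstream shapes stay unchanged.
def _example_sparsity_linear():
    layer = SparsityLinear(in_features=4, out_features=6, lambda_to_keep=torch.tensor([1, 4]), num_lambda=2)
    out = layer(torch.randn(3, 4))
    assert out.shape == (3, 6)
    assert torch.all(out[:, [0, 2, 3, 5]] == 0)  # pruned positions stay zero
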
def norm_layer_pruning(module, lamb):
"""
    Prune the adaptive layer-normalization projection (module.linear) for the FLUX model.
"""
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    if lambda_to_keep.dim() == 0:
        # a single kept channel squeezes to a 0-dim tensor; make it 1-dim so len() works
        lambda_to_keep = lambda_to_keep.unsqueeze(0)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
num_lambda = len(lambda_to_keep)
# get num_features
in_features = module.linear.in_features
out_features = module.linear.out_features
sparselinear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
sparselinear.sparse_proj.weight.data = module.linear.weight.data[lambda_to_keep]
sparselinear.sparse_proj.bias.data = module.linear.bias.data[lambda_to_keep]
module.linear = sparselinear
return module
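
# Illustrative usage sketch (hypothetical stand-in for a FLUX AdaLayerNorm-style module):
# only the conditioning projection `linear` is pruned; its output width is preserved.
def _example_norm_layer_pruning():
    class _ToyAdaNorm(torch.nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.linear = torch.nn.Linear(dim, 3 * dim)

    lamb = torch.zeros(24)
    lamb[:10] = 1.0
    norm = norm_layer_pruning(_ToyAdaNorm(), lamb)
    assert isinstance(norm.linear, SparsityLinear)
    assert norm.linear(torch.randn(2, 8)).shape == (2, 24)
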
def hard_concrete_distribution(
p, beta: float = 0.83, eps: float = 1e-8, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False
):
u = torch.rand(p.shape).to(p.device)
if use_log:
p = torch.clamp(p, min=eps)
p = torch.log(p)
s = torch.sigmoid((torch.log(u + eps) - torch.log(1 - u + eps) + p) / beta)
s = s * (eta - gamma) + gamma
s = s.clamp(0, 1)
return s
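
# Illustrative usage sketch (hypothetical): draw stochastic hard-concrete gates in [0, 1]
# for a vector of mask logits.
def _example_hard_concrete_distribution():
    p = torch.zeros(5)
    s = hard_concrete_distribution(p)
    assert s.shape == p.shape and bool((s >= 0).all()) and bool((s <= 1).all())
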
def l0_complexity_loss(alpha, beta: float = 0.83, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False):
offset = beta * math.log(-gamma / eta)
loss = torch.sigmoid(alpha - offset).sum()
return loss
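
# Illustrative usage sketch (hypothetical): the expected-L0 surrogate is a sum of sigmoids
# over the gate logits, so it grows with the number of active gates.
def _example_l0_complexity_loss():
    alpha = torch.zeros(5)
    assert l0_complexity_loss(alpha).item() > 0
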
def calculate_reg_loss(
loss_reg,
lambs: List[torch.Tensor],
p: int,
use_log: bool = False,
mean=True,
reg=True, # regularize the lambda with bounded value range
reg_alpha=0.4, # alpha for the regularizer, avoid gradient vanishing
reg_beta=1, # beta for shifting the lambda toward positive value (avoid gradient vanishing)
):
if p == 0:
for lamb in lambs:
loss_reg += l0_complexity_loss(lamb, use_log=use_log)
loss_reg /= len(lambs)
elif p == 1 or p == 2:
for lamb in lambs:
if reg:
lamb = torch.sigmoid(lamb * reg_alpha + reg_beta)
if mean:
loss_reg += lamb.norm(p) / len(lamb)
else:
loss_reg += lamb.norm(p)
loss_reg /= len(lambs)
else:
raise NotImplementedError
return loss_reg
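
# Illustrative usage sketch (hypothetical): accumulate an L1 penalty over two lambda tensors,
# starting from a zero scalar as the surrounding training loop is assumed to do.
def _example_calculate_reg_loss():
    lambs = [torch.tensor([0.5, -0.5, 1.0]), torch.tensor([0.0, 2.0])]
    loss = calculate_reg_loss(torch.tensor(0.0), lambs, p=1)
    assert loss.item() > 0
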