Spaces:
Sleeping
Sleeping
| # all utiles functions | |
| import math | |
| from typing import List, Optional | |
| import torch | |
| from diffusers.models.activations import GEGLU, GELU | |
def get_total_params(model, trainable: bool = True):
    """
    Count model parameter elements by trainability.

    Args:
        model: torch.nn.Module whose parameters are inspected.
        trainable: when True, count parameters with requires_grad=True;
            when False, count the frozen ones instead.

    Returns:
        Total number of parameter elements matching the requested trainability.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad == trainable:
            total += param.numel()
    return total
def get_precision(precision: str) -> torch.dtype:
    """
    Map a precision string to the corresponding torch dtype.

    Args:
        precision: One of "fp16", "bf16", "fp32", "fp64".

    Returns:
        The matching torch.dtype.

    Raises:
        AssertionError: If precision is not a supported value.
    """
    # A single lookup table replaces the original if/elif chain, whose
    # "fp64" branch was unreachable: the assert rejected "fp64" up front.
    dtype_map = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
        "fp64": torch.float64,
    }
    assert precision in dtype_map, "precision must be either fp16, fp32, bf16, fp64"
    return dtype_map[precision]
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """
    Compute how many mask entries are active across a hooker's lambda tensors.

    Args:
        hooker: object exposing ``lambs`` (list of 1-d tensors) and, for
            binary masks, a truthy ``binary`` attribute. ff hooks have no
            ``binary`` attribute, in which case the mask is treated as soft.
        threshold: activation cutoff for soft masks; must be None for binary
            masks and provided for soft ones.

    Returns:
        (total_num_lambs, num_active_lambs, active_ratio)
    """
    is_binary = getattr(hooker, "binary", None)
    total = 0
    active = 0
    for lamb in hooker.lambs:
        total += lamb.size(0)
        if is_binary:
            assert threshold is None, "threshold should be None for binary mask"
            # binary mask: the sum is directly the active count
            active += lamb.sum().item()
        else:
            assert threshold is not None, "threshold must be provided for non-binary mask"
            # soft mask: count entries at or above the cutoff
            active += (lamb >= threshold).sum().item()
    return total, active, active / total
def linear_layer_masking(module, lamb):
    """
    Apply soft masking to attention layer weights (K, Q, V projections).

    This function multiplies attention layer weights by mask values without
    removing parameters, allowing for gradual pruning during training.

    Args:
        module: Attention module containing to_k, to_q, to_v, to_out and a
            ``heads`` attribute
        lamb: Per-head mask values to apply

    Returns:
        module: Modified module with masked weights
    """
    # Fix: head size must come from each projection's own out_features — the
    # rows being masked live in the output dimension. The original derived it
    # from to_k.in_features, which is wrong for cross-attention layers where
    # in_features != out_features (linear_layer_pruning already uses
    # out_features).
    projections = [module.to_k, module.to_q, module.to_v]
    for proj in projections:
        head_dim = proj.out_features // module.heads
        for idx, head_mask in enumerate(lamb):
            rows = slice(idx * head_dim, (idx + 1) * head_dim)
            proj.weight.data[rows, :] *= head_mask
            if proj.bias is not None:
                proj.bias.data[rows] *= head_mask
    # Mask the matching input columns of the output projection.
    out_head_dim = module.to_out[0].in_features // module.heads
    for idx, head_mask in enumerate(lamb):
        cols = slice(idx * out_head_dim, (idx + 1) * out_head_dim)
        module.to_out[0].weight.data[:, cols] *= head_mask
    return module
# create dummy module for skip connection
class SkipConnection(torch.nn.Module):
    """
    Skip connection module for completely pruned layers.

    When a layer is fully pruned, this module replaces it and simply
    returns the first positional input unchanged, maintaining the model's
    forward pass.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        # Fix: the original forward was declared without `self`, so `self`
        # was silently absorbed as args[0] and the code returned args[1].
        # With `self` bound explicitly, args[0] is the first real input —
        # behavior is identical for every call site.
        return args[0]
class AttentionSkipConnection(torch.nn.Module):
    """
    Model-specific skip connection for attention layers.

    Handles different return patterns based on model architecture:
    - SD3/FLUX models may return (hidden_states, encoder_hidden_states)
    - Other models return the hidden states alone

    Args:
        model_type: Type of diffusion model ("sd3", "flux", "flux_dev", etc.)
    """

    def __init__(self, model_type):
        super(AttentionSkipConnection, self).__init__()
        self.model_type = model_type

    def forward(self, hidden_states=None, encoder_hidden_states=None, *args, **kwargs):
        # SD3/FLUX dual-stream blocks expect both streams back when the
        # encoder stream is present; everything else gets hidden_states only.
        dual_stream_models = ("sd3", "flux", "flux_dev")
        if self.model_type in dual_stream_models and encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        return hidden_states
def linear_layer_pruning(module, lamb, model_type):
    """
    Physically prune attention layers by removing parameters for pruned heads.

    This function performs structural pruning through the following detailed steps:
    1. **Input Processing**: Latent features are fed into linear modules (to_k, to_q, to_v)
       with shape (cross_attn_dim, inner_kv_dim / inner_dim)
    2. **Head Division**: Inner features are divided into attention heads, where:
       - Query shape: [B, N, H, D] (batch, sequence, heads, head_dim)
       - New hidden dimension = inner_dim * (unmasked_heads / total_heads)
       - K, Q, V projections have shape [cross_attn_dim, inner_kv_dim / inner_dim]
       - Each head occupies (heads * inner_dim) rows in the weight matrix
       - **Important**: Input channels remain unchanged, only output rows are pruned
    3. **Attention Computation**: Updated latent features after scaled dot-product attention
    4. **Output Projection**: Final projection layer (to_out) from pruned inner_dim to original latent_dim
       - Pruned dimension changes from input (dim=0) to output (dim=1)
       - **Critical**: Output channels remain unchanged to maintain model compatibility

    Args:
        module: Attention module to prune (contains to_k, to_q, to_v, to_out)
        lamb: Learned mask values per attention head (1=keep, 0=prune)
        model_type: Model architecture type for skip connection handling

    Returns:
        module: Pruned attention module or AttentionSkipConnection if fully pruned

    Note:
        - Supports additional projections (add_k_proj, add_q_proj, add_v_proj) for certain architectures
        - Handles both to_out and to_add_out projection layers
        - Updates all relevant module parameters (inner_dim, query_dim, heads, etc.)
    """
    # Indices of heads with a nonzero mask (the survivors).
    heads_to_keep = torch.nonzero(lamb).squeeze()
    if len(heads_to_keep.shape) == 0:
        # if only one head is kept, or none: squeeze() produced a 0-d tensor,
        # restore a 1-d index tensor so the loops below work
        heads_to_keep = heads_to_keep.unsqueeze(0)
    modules_to_remove = [module.to_k, module.to_q, module.to_v]
    # Joint/dual-stream attention blocks carry extra context projections;
    # prune them with the same head indices.
    if getattr(module, "add_k_proj", None) is not None:
        modules_to_remove.extend([module.add_k_proj, module.add_q_proj, module.add_v_proj])
    # NOTE(review): assumes lamb is binary here; a soft mask would give a
    # fractional (wrong) head count — confirm callers binarize first.
    new_heads = int(lamb.sum().item())
    if new_heads == 0:
        # Every head pruned: replace the whole attention with a pass-through.
        return AttentionSkipConnection(model_type=model_type)
    for module_to_remove in modules_to_remove:
        # get head dimension (rows per head in this projection's output)
        inner_dim = module_to_remove.out_features // module.heads
        # place holder for the rows to keep
        rows_to_keep = torch.zeros(
            module_to_remove.out_features, dtype=torch.bool, device=module_to_remove.weight.device
        )
        for idx in heads_to_keep:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True
        # overwrite the inner projection with masked projection
        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(sum(rows_to_keep).item())
    # Also update the output projection layer if available, (for FLUXSingleAttnProcessor2_0)
    # with column masking, dim 1
    # NOTE(review): rows_to_keep from the LAST loop iteration is reused below;
    # valid only if all pruned projections share the same out_features — confirm
    # for architectures where add_*_proj dims differ from to_* dims.
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(sum(rows_to_keep).item())
    if getattr(module, "to_add_out", None) is not None:
        module.to_add_out.weight.data = module.to_add_out.weight.data[:, rows_to_keep]
        module.to_add_out.in_features = int(sum(rows_to_keep).item())
    # update parameters in the attention module: scale each dimension by
    # kept/total heads. module.heads must be updated LAST — the divisions
    # above rely on the old head count.
    module.inner_dim = module.inner_dim // module.heads * new_heads
    module.query_dim = module.query_dim // module.heads * new_heads
    module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module
def update_flux_single_transformer_projection(parent_module, module, lamb, old_inner_dim):
    """
    Updates the proj_out module in a FluxSingleTransformerBlock after attention head pruning.

    FLUX models use a proj_out layer that takes concatenated input from both attention output
    and MLP hidden states: torch.cat([attn_output, mlp_hidden_states], dim=2). When attention
    heads are pruned, the attention dimension changes but the MLP dimension remains constant,
    requiring careful weight matrix reconstruction.

    Args:
        parent_module: FluxSingleTransformerBlock containing the proj_out layer
        module: Pruned attention module (or AttentionSkipConnection)
        lamb: Original mask values used for pruning decisions
        old_inner_dim: Original attention inner dimension before pruning

    Returns:
        parent_module: Updated parent module with corrected proj_out dimensions

    Note:
        - Handles skip connections when module is completely pruned
        - Preserves MLP weights while updating attention weights
        - Only modifies proj_out if dimensions actually changed
    """
    # Handle Skip Connection case (when module is completely pruned)
    if isinstance(module, AttentionSkipConnection):
        return parent_module
    if hasattr(parent_module, "proj_out"):
        # Calculate how much the attention dimension changed
        attention_dim_change = old_inner_dim - module.inner_dim
        if attention_dim_change > 0:  # Only update if dimensions actually changed
            # Get current weight matrix and dimensions
            old_weight = parent_module.proj_out.weight.data
            old_in_features = parent_module.proj_out.in_features
            # Calculate new input dimension
            new_in_features = old_in_features - attention_dim_change
            # Create new weight matrix (filled below; any column not copied stays zero)
            new_weight = torch.zeros(
                old_weight.shape[0], new_in_features,
                device=old_weight.device, dtype=old_weight.dtype
            )
            # Calculate head dimensions
            # NOTE(review): assumes lamb has one entry per original head, so
            # lamb.shape[0] == original head count — confirm against callers.
            old_head_dim = old_inner_dim // lamb.shape[0]
            # Create mask for attention columns to keep
            heads_to_keep = torch.nonzero(lamb).squeeze()
            if len(heads_to_keep.shape) == 0:
                # single (or no) surviving head squeezes to 0-d; restore 1-d
                heads_to_keep = heads_to_keep.unsqueeze(0)
            attn_cols_to_keep = torch.zeros(old_inner_dim, dtype=torch.bool, device=old_weight.device)
            for idx in heads_to_keep:
                attn_cols_to_keep[idx * old_head_dim : (idx + 1) * old_head_dim] = True
            # Copy weights for kept attention heads
            # NOTE(review): the i < module.inner_dim guard looks defensive —
            # len(kept_indices) should already equal module.inner_dim after
            # linear_layer_pruning; confirm before simplifying.
            kept_indices = torch.nonzero(attn_cols_to_keep).squeeze()
            for i, idx in enumerate(kept_indices):
                if i < module.inner_dim:
                    new_weight[:, i] = old_weight[:, idx]
            # Copy MLP weights (unchanged part): everything past the old
            # attention span shifts left to follow the pruned attention columns
            mlp_start = old_inner_dim
            if mlp_start < old_in_features:  # Ensure there's actually an MLP part
                new_weight[:, module.inner_dim:] = old_weight[:, mlp_start:]
            # Update the projection layer
            parent_module.proj_out.weight.data = new_weight
            parent_module.proj_out.in_features = new_in_features
    return parent_module
def ffn_linear_layer_pruning(module, lamb):
    """
    Prunes feed-forward network layers based on learned masks.

    Two layouts are supported:
      - Standard diffusers FeedForward blocks (``module.net``): the input
        projection (GELU or GEGLU) loses output rows and the following linear
        layer loses the matching input columns.
      - FluxSingleTransformerBlock-style blocks (``proj_mlp`` + ``proj_out``):
        proj_out consumes torch.cat([attn_output, mlp_hidden_states]), so only
        the MLP columns of proj_out are pruned; the attention part is kept.

    Note: This function could potentially be merged with linear_layer_pruning
    for better code organization in future refactoring.

    Args:
        module: FFN module to prune
        lamb: Learned mask values for pruning decisions (nonzero = keep)

    Returns:
        Pruned module or SkipConnection if fully pruned
    """
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    # Fix: squeeze() yields a 0-d tensor when exactly one channel survives,
    # and len() of a 0-d tensor raises TypeError. linear_layer_pruning already
    # guards this case; normalize to 1-d before taking len().
    if lambda_to_keep.dim() == 0:
        lambda_to_keep = lambda_to_keep.unsqueeze(0)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
    num_lambda = len(lambda_to_keep)
    if hasattr(module, "net") and len(module.net) >= 3:
        # Standard FFN blocks
        if isinstance(module.net[0], GELU):
            # linear layer weight remove before activation
            module.net[0].proj.weight.data = module.net[0].proj.weight.data[lambda_to_keep, :]
            module.net[0].proj.out_features = num_lambda
            if module.net[0].proj.bias is not None:
                module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]
            # rebuild the activation wrapper around the pruned projection
            update_act = GELU(module.net[0].proj.in_features, num_lambda)
            update_act.proj = module.net[0].proj
            module.net[0] = update_act
        elif isinstance(module.net[0], GEGLU):
            # GEGLU stacks value and gate halves along dim 0; prune both halves
            # with the same indices so they stay aligned after the split.
            output_feature = module.net[0].proj.out_features
            module.net[0].proj.weight.data = torch.cat(
                [
                    module.net[0].proj.weight.data[: output_feature // 2, :][lambda_to_keep, :],
                    module.net[0].proj.weight.data[output_feature // 2 :][lambda_to_keep, :],
                ],
                dim=0,
            )
            module.net[0].proj.out_features = num_lambda * 2
            if module.net[0].proj.bias is not None:
                module.net[0].proj.bias.data = torch.cat(
                    [
                        module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
                        module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
                    ]
                )
            update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
            update_act.proj = module.net[0].proj
            module.net[0] = update_act
        # proj weight after activation: drop the matching input columns
        module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
        module.net[2].in_features = num_lambda
    elif hasattr(module, "proj_mlp") and hasattr(module, "proj_out"):
        # FFN For FluxSingleTransformerBlock
        module.proj_mlp.weight.data = module.proj_mlp.weight.data[lambda_to_keep, :]
        module.proj_mlp.out_features = num_lambda
        if module.proj_mlp.bias is not None:
            module.proj_mlp.bias.data = module.proj_mlp.bias.data[lambda_to_keep]
        # Update mlp_hidden_dim to reflect the new size
        old_mlp_hidden_dim = module.mlp_hidden_dim
        module.mlp_hidden_dim = num_lambda
        # The proj_out layer takes concatenated input from both attention output
        # and MLP output. Keep the attention part unchanged; update the MLP part.
        old_dim = module.proj_out.in_features
        attn_dim = old_dim - old_mlp_hidden_dim  # Attention dimension
        new_in_features = attn_dim + num_lambda
        new_weight = torch.zeros(
            module.proj_out.weight.shape[0], new_in_features,
            device=module.proj_out.weight.device, dtype=module.proj_out.weight.dtype
        )
        # Copy attention part (unchanged)
        new_weight[:, :attn_dim] = module.proj_out.weight.data[:, :attn_dim]
        # Copy selected MLP columns, shifted left behind the attention block
        for i, idx in enumerate(lambda_to_keep):
            new_weight[:, attn_dim + i] = module.proj_out.weight.data[:, attn_dim + idx]
        # Update the projection layer
        module.proj_out.weight.data = new_weight
        module.proj_out.in_features = new_in_features
    return module
# create SparsityLinear module
class SparsityLinear(torch.nn.Module):
    """
    Sparse linear layer that maintains original output dimensions.

    This layer projects to a smaller intermediate dimension then expands
    back to the original size, placing values only at specified indices.
    Used for normalization layer pruning where output dimensions must match.

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension (original size)
        lambda_to_keep: Indices of features to keep active
        num_lambda: Number of active features (len(lambda_to_keep))
    """

    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super(SparsityLinear, self).__init__()
        # Dense projection onto the surviving channels only.
        self.sparse_proj = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep

    def forward(self, x):
        """Project x and scatter the result into a zero tensor of the original width."""
        x = self.sparse_proj(x)
        # Generalized: support arbitrary leading (batch/sequence) dims. The
        # original zeros(x.size(0), ...) only handled strictly 2-d inputs;
        # behavior for 2-d inputs is unchanged.
        output = torch.zeros(*x.shape[:-1], self.out_features, device=x.device, dtype=x.dtype)
        output[..., self.lambda_to_keep] = x
        return output
def norm_layer_pruning(module, lamb):
    """
    Pruning the layer normalization layer for FLUX model.

    Replaces ``module.linear`` with a SparsityLinear that computes only the
    surviving output channels and scatters them back into the original width,
    so downstream consumers of the projection keep their expected shape.

    Args:
        module: Norm module exposing a ``linear`` projection.
        lamb: Learned mask values (nonzero = keep channel)

    Returns:
        module with the pruned projection, or SkipConnection if fully pruned.
    """
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    # Fix: a single surviving channel squeezes to a 0-d tensor and len()
    # of a 0-d tensor raises TypeError (same guard as linear_layer_pruning).
    if lambda_to_keep.dim() == 0:
        lambda_to_keep = lambda_to_keep.unsqueeze(0)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
    num_lambda = len(lambda_to_keep)
    # get num_features
    in_features = module.linear.in_features
    out_features = module.linear.out_features
    sparselinear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
    sparselinear.sparse_proj.weight.data = module.linear.weight.data[lambda_to_keep]
    # NOTE(review): assumes module.linear always has a bias — confirm for all
    # norm variants this is applied to.
    sparselinear.sparse_proj.bias.data = module.linear.bias.data[lambda_to_keep]
    module.linear = sparselinear
    return module
def hard_concrete_distribution(
    p, beta: float = 0.83, eps: float = 1e-8, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False
):
    """
    Sample from the hard-concrete distribution (stretched + clamped binary concrete).

    Draws uniform noise, applies the concrete (Gumbel-sigmoid) relaxation with
    temperature ``beta``, stretches the sample to [gamma, eta], and clamps the
    result to [0, 1].

    Args:
        p: location parameter (treated as log-space unless use_log=True)
        beta: temperature of the relaxation
        eps: numerical-stability constant inside the logs
        eta: upper stretch bound
        gamma: lower stretch bound
        use_log: if True, clamp p to eps and take log(p) first

    Returns:
        Tensor of samples in [0, 1] with the same shape as p.
    """
    u = torch.rand(p.shape).to(p.device)
    location = torch.log(torch.clamp(p, min=eps)) if use_log else p
    # logistic noise: log(u) - log(1 - u), stabilized with eps
    logistic_noise = torch.log(u + eps) - torch.log(1 - u + eps)
    sample = torch.sigmoid((logistic_noise + location) / beta)
    stretched = sample * (eta - gamma) + gamma
    return stretched.clamp(0, 1)
def l0_complexity_loss(alpha, beta: float = 0.83, eta: float = 1.1, gamma: float = -0.1, use_log: bool = False):
    """
    Expected L0 penalty of hard-concrete gates.

    Sums, over all gates, the probability that a gate is non-zero:
    sigmoid(alpha - beta * log(-gamma / eta)).

    Note: ``use_log`` is accepted for signature symmetry with
    hard_concrete_distribution but is not used here.
    """
    shift = beta * math.log(-gamma / eta)
    return torch.sigmoid(alpha - shift).sum()
def calculate_reg_loss(
    loss_reg,
    lambs: List[torch.Tensor],
    p: int,
    use_log: bool = False,
    mean=True,
    reg=True,  # regularize the lambda with bounded value range
    reg_alpha=0.4,  # alpha for the regularizer, avoid gradient vanishing
    reg_beta=1,  # beta for shifting the lambda toward positive value (avoid gradient vanishing)
):
    """
    Accumulate a sparsity regularization loss over a list of mask tensors.

    p == 0 uses the expected-L0 surrogate (l0_complexity_loss); p in {1, 2}
    uses the (optionally sigmoid-bounded) Lp norm. In all cases the result is
    averaged over the number of tensors.

    Args:
        loss_reg: running loss value/tensor to accumulate into
        lambs: list of mask tensors
        p: norm order (0, 1 or 2)
        use_log: forwarded to l0_complexity_loss when p == 0
        mean: divide each tensor's norm by its element count
        reg: squash each lambda through sigmoid(lamb * reg_alpha + reg_beta)
        reg_alpha: sigmoid slope, keeps gradients alive
        reg_beta: sigmoid shift toward positive values

    Returns:
        Accumulated (averaged) regularization loss.

    Raises:
        NotImplementedError: for any unsupported p.
    """
    if p == 0:
        for lamb in lambs:
            loss_reg += l0_complexity_loss(lamb, use_log=use_log)
        loss_reg /= len(lambs)
        return loss_reg
    if p in (1, 2):
        for lamb in lambs:
            values = torch.sigmoid(lamb * reg_alpha + reg_beta) if reg else lamb
            penalty = values.norm(p)
            loss_reg += penalty / len(values) if mean else penalty
        loss_reg /= len(lambs)
        return loss_reg
    raise NotImplementedError