"""
MiniMind Pruning Toolkit

Structured and unstructured pruning for model compression.
"""

from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class PruningMethod(Enum):
    """Supported pruning methods."""

    MAGNITUDE = "magnitude"
    STRUCTURED = "structured"
    MOVEMENT = "movement"  # reserved; not yet handled by Mind2Pruner.prune()
    WANDA = "wanda"


@dataclass
class PruningConfig:
    """Configuration for pruning."""

    method: PruningMethod = PruningMethod.MAGNITUDE
    sparsity: float = 0.5
    structured: bool = False
    prune_heads: bool = True
    prune_experts: bool = True
    prune_ffn: bool = True
    min_heads: int = 2
    min_experts: int = 2
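

# Illustrative preset (an assumption for demonstration, not part of the
# original API): an aggressive structured-pruning configuration that still
# keeps at least four attention heads per layer.
AGGRESSIVE_STRUCTURED = PruningConfig(
    method=PruningMethod.STRUCTURED,
    sparsity=0.75,
    prune_experts=False,
    min_heads=4,
)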


class Mind2Pruner:
    """Pruner for MiniMind models."""

    def __init__(self, config: Optional[PruningConfig] = None):
        self.config = config or PruningConfig()

    def prune(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> nn.Module:
        """
        Prune the model in place and return it.

        Args:
            model: Model to prune
            calibration_data: Data for importance estimation (required for
                WANDA; optional for structured pruning)

        Returns:
            Pruned model
        """
        if self.config.method == PruningMethod.MAGNITUDE:
            return self._magnitude_pruning(model)
        elif self.config.method == PruningMethod.STRUCTURED:
            return self._structured_pruning(model, calibration_data)
        elif self.config.method == PruningMethod.WANDA:
            return self._wanda_pruning(model, calibration_data)
        else:
            # PruningMethod.MOVEMENT is declared but has no implementation yet.
            raise ValueError(f"Unsupported pruning method: {self.config.method}")

    def _magnitude_pruning(self, model: nn.Module) -> nn.Module:
        """Apply unstructured magnitude pruning."""
        modules_to_prune = []

        # Collect the weight tensor of every Linear layer so one global
        # threshold is applied across the whole model.
        for _, module in model.named_modules():
            if isinstance(module, nn.Linear):
                modules_to_prune.append((module, "weight"))

        # Zero the globally smallest weights by absolute value.
        prune.global_unstructured(
            modules_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=self.config.sparsity,
        )

        # Make the pruning permanent: fold each mask into its weight and drop
        # the reparametrization added by torch.nn.utils.prune.
        for module, _ in modules_to_prune:
            prune.remove(module, "weight")

        return model

    def _structured_pruning(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> nn.Module:
        """Apply structured pruning (attention heads, FFN neurons, MoE experts)."""
        importance_scores = self._compute_importance(model, calibration_data)

        if self.config.prune_heads:
            model = self._prune_attention_heads(model, importance_scores)

        if self.config.prune_ffn:
            model = self._prune_ffn_neurons(model, importance_scores)

        if self.config.prune_experts:
            model = self._prune_experts(model, importance_scores)

        return model

    def _compute_importance(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute importance scores for different components.

        Note: ``calibration_data`` is accepted for interface symmetry but is
        currently unused; all scores are derived from weight norms alone.
        """
        importance = {}

        # Attention heads: score each head by the L2 norm of its slice of the
        # query projection.
        for name, module in model.named_modules():
            if hasattr(module, "num_heads"):
                q_proj = getattr(module, "q_proj", None)
                if q_proj is not None:
                    weight = q_proj.weight.data
                    num_heads = module.num_heads
                    head_dim = weight.shape[0] // num_heads

                    head_importance = torch.zeros(num_heads)
                    for h in range(num_heads):
                        start = h * head_dim
                        end = (h + 1) * head_dim
                        head_importance[h] = weight[start:end].norm()

                    importance[f"{name}.heads"] = head_importance

        # FFN neurons: score each intermediate neuron by the norm of its
        # gate-projection row.
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and "gate_proj" in name:
                weight = module.weight.data
                neuron_importance = weight.norm(dim=1)
                importance[f"{name}.neurons"] = neuron_importance

        # MoE experts: score each expert by its average per-parameter norm.
        for name, module in model.named_modules():
            if hasattr(module, "experts"):
                expert_importance = torch.zeros(len(module.experts))
                for i, expert in enumerate(module.experts):
                    expert_params = sum(p.numel() for p in expert.parameters())
                    expert_norm = sum(p.data.norm() for p in expert.parameters())
                    expert_importance[i] = expert_norm / max(1, expert_params)

                importance[f"{name}.experts"] = expert_importance

        return importance

    def _prune_attention_heads(
        self,
        model: nn.Module,
        importance: Dict[str, torch.Tensor],
    ) -> nn.Module:
        """Prune (zero out) the least important attention heads."""
        for name, module in model.named_modules():
            if hasattr(module, "num_heads"):
                head_key = f"{name}.heads"
                if head_key in importance:
                    scores = importance[head_key]
                    num_heads = len(scores)
                    num_prune = int(num_heads * self.config.sparsity)
                    # Never prune below the configured minimum head count.
                    num_keep = max(self.config.min_heads, num_heads - num_prune)

                    _, keep_indices = torch.topk(scores, num_keep)
                    keep_indices = keep_indices.sort()[0]

                    # Build a mask over the head_dim rows belonging to each
                    # kept head.
                    head_dim = module.head_dim
                    mask = torch.zeros(num_heads * head_dim)
                    for idx in keep_indices:
                        start = idx * head_dim
                        end = (idx + 1) * head_dim
                        mask[start:end] = 1

                    # Zero the pruned heads' output rows in q_proj and the
                    # matching input columns in o_proj; k_proj and v_proj are
                    # left intact. The module structure is unchanged, but the
                    # pruned heads' contribution to the block output becomes
                    # zero via the masked o_proj columns.
                    for proj_name in ["q_proj", "o_proj"]:
                        proj = getattr(module, proj_name, None)
                        if proj is not None:
                            if proj_name == "q_proj":
                                proj.weight.data *= mask.unsqueeze(1).to(proj.weight.device)
                            else:
                                proj.weight.data *= mask.unsqueeze(0).to(proj.weight.device)

        return model

    def _prune_ffn_neurons(
        self,
        model: nn.Module,
        importance: Dict[str, torch.Tensor],
    ) -> nn.Module:
        """Prune (zero out) the least important FFN neurons."""
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and "gate_proj" in name:
                neuron_key = f"{name}.neurons"
                if neuron_key in importance:
                    scores = importance[neuron_key]
                    num_neurons = len(scores)
                    num_prune = int(num_neurons * self.config.sparsity)
                    num_keep = num_neurons - num_prune

                    _, keep_indices = torch.topk(scores, num_keep)

                    mask = torch.zeros(num_neurons)
                    mask[keep_indices] = 1

                    # In a gated FFN, zeroing a gate_proj row is enough to
                    # silence the whole neuron: with a zero-preserving
                    # activation such as SiLU, act(0) = 0, so the gated
                    # product (and hence the down_proj input) vanishes.
                    module.weight.data *= mask.unsqueeze(1).to(module.weight.device)

        return model

    def _prune_experts(
        self,
        model: nn.Module,
        importance: Dict[str, torch.Tensor],
    ) -> nn.Module:
        """Prune (zero out) the least important MoE experts."""
        for name, module in model.named_modules():
            if hasattr(module, "experts"):
                expert_key = f"{name}.experts"
                if expert_key in importance:
                    scores = importance[expert_key]
                    num_experts = len(scores)
                    num_prune = int(num_experts * self.config.sparsity)
                    num_keep = max(self.config.min_experts, num_experts - num_prune)

                    _, keep_indices = torch.topk(scores, num_keep)
                    keep_indices = keep_indices.sort()[0].tolist()

                    # Zero the pruned experts' parameters instead of deleting
                    # them, so router indices and the module structure stay
                    # valid.
                    for i, expert in enumerate(module.experts):
                        if i not in keep_indices:
                            for param in expert.parameters():
                                param.data.zero_()

                    print(f"Pruned experts in {name}: keeping {keep_indices}")

        return model

    def _wanda_pruning(
        self,
        model: nn.Module,
        calibration_data: Optional[torch.Tensor] = None,
    ) -> nn.Module:
        """
        Apply WANDA (Weights AND Activations) pruning.

        Scores each weight by its magnitude times the average magnitude of
        the input activation feeding its column, then zeroes the
        lowest-scoring fraction of weights in each layer.
        """
        if calibration_data is None:
            print("Warning: WANDA requires calibration data, falling back to magnitude pruning")
            return self._magnitude_pruning(model)

        model.eval()
        activation_norms = {}

        # Forward hooks record the mean absolute input activation per feature,
        # averaged over the batch and sequence dimensions.
        def hook_fn(name):
            def hook(module, inputs, output):
                if isinstance(inputs, tuple):
                    inputs = inputs[0]
                activation_norms[name] = inputs.abs().mean(dim=(0, 1))
            return hook

        handles = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                handles.append(module.register_forward_hook(hook_fn(name)))

        # One calibration pass to populate activation_norms.
        with torch.no_grad():
            model(calibration_data)

        for handle in handles:
            handle.remove()

        # Score and prune each Linear layer independently.
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and name in activation_norms:
                weight = module.weight.data
                act_norm = activation_norms[name].to(weight.device)

                # WANDA score: |W_ij| * mean|x_j|.
                wanda_score = weight.abs() * act_norm.unsqueeze(0)

                # Keep weights at or above the per-layer sparsity quantile.
                threshold = torch.quantile(wanda_score.flatten(), self.config.sparsity)
                mask = (wanda_score >= threshold).float()
                module.weight.data *= mask

        return model

    def compute_sparsity(self, model: nn.Module) -> Dict[str, float]:
        """Compute the actual (exact-zero) sparsity of the model's Linear layers."""
        total_params = 0
        zero_params = 0
        layer_sparsity = {}

        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                params = module.weight.numel()
                zeros = (module.weight == 0).sum().item()
                total_params += params
                zero_params += zeros
                layer_sparsity[name] = zeros / params

        return {
            "total_sparsity": zero_params / max(1, total_params),
            "layer_sparsity": layer_sparsity,
        }
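

def _example_magnitude_pruning() -> None:
    """Illustrative sketch, not part of the original API: magnitude-prune a
    toy two-layer model and report the achieved sparsity. The Sequential
    stand-in is an assumption; any module containing nn.Linear layers works."""
    toy = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
    pruner = Mind2Pruner(PruningConfig(method=PruningMethod.MAGNITUDE, sparsity=0.5))
    pruned = pruner.prune(toy)
    stats = pruner.compute_sparsity(pruned)
    # The global L1 threshold targets ~50% zeros across all Linear weights
    # combined; individual layers can land above or below 0.5.
    print(f"total sparsity: {stats['total_sparsity']:.2%}")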


def prune_model(
    model: nn.Module,
    sparsity: float = 0.5,
    method: str = "magnitude",
    calibration_data: Optional[torch.Tensor] = None,
) -> nn.Module:
    """
    Convenience function to prune a model.

    Args:
        model: Model to prune
        sparsity: Target sparsity ratio
        method: Pruning method ("magnitude", "structured", or "wanda")
        calibration_data: Calibration data for importance estimation

    Returns:
        Pruned model
    """
    config = PruningConfig(
        method=PruningMethod(method),
        sparsity=sparsity,
    )
    pruner = Mind2Pruner(config)
    return pruner.prune(model, calibration_data)
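

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming a toy Linear stack as a stand-in
    # for a real MiniMind checkpoint. WANDA needs a calibration batch; the
    # random (batch, seq_len, hidden) tensor below matches the 3-D input
    # shape the activation hooks average over.
    demo = nn.Sequential(nn.Linear(32, 64), nn.SiLU(), nn.Linear(64, 32))
    calib = torch.randn(4, 16, 32)
    demo = prune_model(demo, sparsity=0.5, method="wanda", calibration_data=calib)
    print(Mind2Pruner().compute_sparsity(demo)["total_sparsity"])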