| | import logging |
| | import math |
| | import os |
| | import sys |
| | import shutil |
| | from copy import deepcopy |
| | from typing import Dict, List, Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from accelerate import Accelerator |
| | from torch import no_grad |
| | from torch.utils.data import DataLoader |
| | from tqdm import tqdm |
| |
|
| | from .io import create_dir, save_json |
| | from .utils import print_gpu_memory, prepare_calibration_input, auto_map, CUSTOM_FILE |
| | from .wrapper import HiddenStatesRecordWrapper |
| | from .super_weight import find_super_weights as detect_super_weights |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
@no_grad()
def get_layer_similarities(model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int, drop_norm: bool, target_layer: str, cache_file=None):
    """Score each decoder layer by the cosine similarity between its input and output hidden states.

    A high similarity means the targeted sub-module barely changes the hidden
    states, making that layer a cheap candidate for dropping. Results are
    loaded from / saved to ``cache_file`` when given. Layers already flagged in
    ``config.drop_{target_layer}_list`` keep a score of ``-inf`` so they are
    never selected again.

    Args:
        model: The (possibly accelerate-wrapped) causal LM to probe.
        dataloader: Calibration data used to collect hidden states.
        accelerator: Accelerate handle for device placement, logging and cross-process reduce.
        num_samples: Number of calibration samples run through each layer per device.
        drop_norm: If True, compare against the residual-stream input recorded *before*
            the pre-norm (adding the sub-module output back onto it); otherwise compare
            the norm's output against the sub-module's output directly.
        target_layer: 'mlp' or 'attn' — which sub-module of each decoder layer to score.
            'all' is rejected here; callers split it into two separate calls.
        cache_file: Optional path for loading/saving the similarity tensor.

    Returns:
        A 1-D tensor of length ``num_hidden_layers`` with one similarity per layer
        (``-inf`` for skipped/unscored layers).
    """
    device = accelerator.device

    if cache_file is not None and os.path.exists(cache_file):
        # Fast path: reuse a previously computed similarity tensor.
        accelerator.print(f"Loading cached model from {cache_file}")
        similarities = torch.load(cache_file, map_location=device)
    else:
        accelerator.print(f"No cached model found. Running model on {num_samples} samples for each device.")
        unwrapped_model = accelerator.unwrap_model(model)
        # KV-cache is useless for these one-shot forward passes and wastes memory.
        unwrapped_model.config.use_cache = False
        layers = unwrapped_model.model.layers

        accelerator.print("Getting features...")
        # Per-sample input/output buffers plus the auxiliary tensors each decoder layer expects.
        inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(unwrapped_model, dataloader, num_samples)

        num_layers = unwrapped_model.config.num_hidden_layers
        # NOTE(review): this covers every index, so the `if i in layer_indices`
        # guard below is always true and the else-branch is currently unreachable.
        layer_indices = list(range(num_layers))

        # Start everything at -inf so unscored layers are never picked for dropping.
        similarities = torch.full((num_layers,), -math.inf, device=device)
        # Layers already marked dropped in the config are skipped entirely.
        if hasattr(unwrapped_model.config, f'drop_{target_layer}_list'):
            skipped_layers = [idx for idx, v in enumerate(getattr(unwrapped_model.config, f'drop_{target_layer}_list', [])) if v]
        else:
            skipped_layers = []

        accelerator.print('Starting ...')
        for i in tqdm(range(num_layers), desc="Recording hidden states...", disable=not accelerator.is_main_process):
            if i in skipped_layers:
                # Redundant with the torch.full initialization, but keeps intent explicit.
                similarities[i] = -math.inf
                accelerator.print('Skip the dropped layer: ', i)
                continue
            sys.stderr.flush()
            torch.cuda.empty_cache()
            print_gpu_memory(accelerator)
            layer = layers[i]

            if i in layer_indices:
                # Select the normalization module feeding the target sub-module,
                # plus the sub-module itself.
                if target_layer == 'mlp':
                    module_pre_norm = layer.post_attention_layernorm
                    module = layer.mlp
                elif target_layer == 'attn':
                    module_pre_norm = layer.input_layernorm
                    module = layer.self_attn
                elif target_layer == 'all':
                    # 'all' must be decomposed into separate 'attn' and 'mlp' calls by the caller.
                    raise ValueError("Unsupported target_layer!")
                # NOTE(review): any other target_layer value falls through and later hits a
                # NameError on module_pre_norm — consider raising here as well.
                if drop_norm:
                    # Record the residual-stream tensor entering the norm ...
                    wrapped_module_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=True, record_output=False)
                else:
                    # ... or the normalized activations leaving it.
                    wrapped_module_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=False, record_output=True)
                wrapped_module = HiddenStatesRecordWrapper(module, record_input=False, record_output=True)

                # Forward hooks feed the wrappers; the wrappers decide what to keep
                # via the record_* flags set above.
                def record_module_pre_norm_states_hook(_, input, output):
                    wrapped_module_pre_norm.record(input[0].data, output[0].data)

                if target_layer == 'mlp':
                    def record_module_states_hook(_, input, output):
                        wrapped_module.record(input[0].data, output[0].data)
                elif target_layer == 'attn':
                    def record_module_states_hook(_, input, output):
                        # Only the attention output is recorded here; its input is not passed.
                        wrapped_module.record(None, output[0].data)
                else:
                    raise ValueError("Unsupported target_layer!")

                handles = []
                handles.append(module_pre_norm.register_forward_hook(record_module_pre_norm_states_hook))
                handles.append(module.register_forward_hook(record_module_states_hook))
                for j in range(num_samples):
                    # Llama decoder layers additionally require cache_position.
                    if getattr(unwrapped_model.config, "model_type", None) == "llama":
                        outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
                    else:
                        outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]
                for handle in handles:
                    handle.remove()

                # Accumulate in fp32 for a numerically stable cosine similarity.
                dtype = torch.float32

                if drop_norm:
                    # "Output" = residual input + sub-module output (the residual-stream view).
                    input_hidden_states = torch.cat(wrapped_module_pre_norm.input_hidden_states, dim=0).to(dtype).to(device)
                    output_hidden_states = input_hidden_states + torch.cat(wrapped_module.output_hidden_states, dim=0).to(dtype).to(device)
                else:
                    input_hidden_states = torch.cat(wrapped_module_pre_norm.output_hidden_states, dim=0).to(dtype).to(device)
                    output_hidden_states = torch.cat(wrapped_module.output_hidden_states, dim=0).to(dtype).to(device)

                # Mean token-wise cosine similarity, averaged across all processes.
                cos_sim = F.cosine_similarity(input_hidden_states, output_hidden_states, dim=-1)
                cos_sim = cos_sim.mean()
                cos_sim = accelerator.reduce(cos_sim, reduction="mean")
                accelerator.print(f'layer {i} similarity: {cos_sim.item()}')
                similarities[i] = cos_sim

            else:
                # Unreachable with the current layer_indices (see note above):
                # would run the layer without recording anything.
                for j in range(num_samples):
                    if getattr(unwrapped_model.config, "model_type", None) == "llama":
                        outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
                    else:
                        outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]

            # This layer's outputs become the next layer's inputs (buffer swap, no copy).
            inputs, outputs = outputs, inputs

        if cache_file is not None:
            if accelerator.is_main_process:
                create_dir(os.path.dirname(cache_file))
                torch.save(similarities.clone().cpu(), cache_file)
                print(f"Saving cached similarities to {cache_file}")
            accelerator.wait_for_everyone()

    accelerator.print("similarities\n", similarities)

    return similarities
| |
|
| | |
def discrete_layer_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
    """Pick the ``args.drop_n`` layers whose input/output similarity is highest.

    Layers with the highest similarity change their hidden states the least and
    are therefore treated as the most redundant. For ``target_layer == 'all'``
    the attention and MLP similarity vectors are computed separately and
    concatenated (attn scores occupy indices ``[0, L)``, mlp scores
    ``[L, 2L)``); otherwise a single vector is scored.

    Returns:
        List of the dropped layer indices, e.g. a discrete subset like
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> keep [0, 2, 6, 8, 9].
    """
    if args.target_layer == 'all':
        # Score attention and MLP sub-layers independently, each with its own cache file.
        attn_scores = get_layer_similarities(
            model, dataloader, accelerator, num_samples, args.layer_drop_norm,
            target_layer='attn', cache_file=args.similarity_cache_file.replace("all", "all_attn"))
        mlp_scores = get_layer_similarities(
            model, dataloader, accelerator, num_samples, args.layer_drop_norm,
            target_layer='mlp', cache_file=args.similarity_cache_file.replace("all", "all_mlp"))
        scores = torch.cat((attn_scores, mlp_scores), dim=0)
    else:
        scores = get_layer_similarities(
            model, dataloader, accelerator, num_samples, args.layer_drop_norm,
            target_layer=args.target_layer, cache_file=args.similarity_cache_file)

    # Highest similarity first: the most redundant layers get dropped.
    ranked_scores, ranked_ids = torch.sort(scores, dim=0, descending=True)
    dropped_layer_list = ranked_ids[:args.drop_n].tolist()
    accelerator.print(f"Dropped layer: {dropped_layer_list}, similarities: {ranked_scores[:args.drop_n].tolist()}")
    return dropped_layer_list
| |
|
| |
|
| | def _serialize_super_weights(super_weights: Dict[int, Tuple[int, int, float, float]]) -> Dict[str, List[float]]: |
| | """Convert the detected super weights into JSON-serializable dict.""" |
| | return { |
| | str(layer_idx): [int(row), int(col), float(weight_val), float(activation_val)] |
| | for layer_idx, (row, col, weight_val, activation_val) in super_weights.items() |
| | } |
| |
|
| |
|
| | def _super_weight_activation_delta( |
| | reference: Dict[int, Tuple[int, int, float, float]], |
| | candidate: Dict[int, Tuple[int, int, float, float]], |
| | ) -> float: |
| | """Compute total activation delta between reference and candidate super weights.""" |
| | if not reference: |
| | return 0.0 |
| | delta = 0.0 |
| | for layer_idx, (_, _, _, ref_activation) in reference.items(): |
| | cand = candidate.get(layer_idx) |
| | if cand is None: |
| | delta += abs(ref_activation) |
| | else: |
| | delta += abs(ref_activation - cand[3]) |
| | return float(delta) |
| |
|
| |
|
def super_weight_guided_attn_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
    """
    Drop attention layers sequentially in a greedy fashion based on the impact on super weight activations.

    At every step, temporarily drop each remaining attention layer, recompute the super
    weight activations, and pick the layer whose removal yields the smallest deviation
    from the current baseline activations. A JSON trace of the greedy search is written
    to ``args.prune_model_save_path`` when that path is set.

    Args:
        args: Namespace with at least ``target_layer`` and ``prune_model_save_path``;
            ``super_weight_threshold`` is read with a default of 3.0.
        model: The (possibly accelerate-wrapped) model; each layer must expose ``drop_attn``.
        dataloader: Calibration data for super-weight detection.
        accelerator: Accelerate handle used for unwrapping, logging and main-process gating.
        num_samples: Number of calibration samples per detection pass.

    Returns:
        Sorted list of all dropped attention layer indices (pre-existing + newly chosen).

    Raises:
        ValueError: If ``args.target_layer`` is not 'attn', the model exposes no layers,
            or a layer lacks the ``drop_attn`` attribute.
        RuntimeError: If a greedy step fails to select a layer (should not happen).
    """
    if args.target_layer != 'attn':
        raise ValueError("super_weight_guided layer dropping only supports target_layer='attn'.")

    accelerator.print("Running super-weight-guided attention dropping...")
    unwrapped_model = accelerator.unwrap_model(model)
    layers = getattr(unwrapped_model.model, "layers", None)
    if layers is None:
        raise ValueError("Unable to access model layers for super-weight-guided dropping.")

    num_layers = len(layers)
    # Normalize the config's drop list into a fresh, exactly num_layers-long list of flags.
    drop_flags = getattr(unwrapped_model.config, "drop_attn_list", None)
    if not drop_flags:
        drop_flags = [False] * num_layers
    else:
        drop_flags = list(drop_flags)
        if len(drop_flags) < num_layers:
            drop_flags.extend([False] * (num_layers - len(drop_flags)))
    # BUGFIX: always point the config at the list we mutate below. Previously, when the
    # existing list was already full-length, the config kept the ORIGINAL list while
    # _set_drop_flag mutated the local copy, so config.drop_attn_list silently diverged
    # from the drops actually applied.
    unwrapped_model.config.drop_attn_list = drop_flags

    # Mirror the flags onto the layer modules themselves.
    for idx, flag in enumerate(drop_flags):
        if hasattr(layers[idx], "drop_attn"):
            layers[idx].drop_attn = flag
        else:
            raise ValueError("Layer does not expose drop_attn attribute, cannot perform guided dropping.")

    initially_dropped = {idx for idx, flag in enumerate(drop_flags) if flag}
    remaining_layers = [idx for idx in range(num_layers) if idx not in initially_dropped]

    accelerator.print(f"Initial dropped attention layers: {sorted(initially_dropped)}")
    accelerator.print(f"Remaining attention layers to evaluate: {remaining_layers}")

    def _set_drop_flag(layer_idx: int, value: bool):
        # Keep the module attribute and the config-backed flag list in lockstep.
        layers[layer_idx].drop_attn = value
        drop_flags[layer_idx] = value

    def _detect_current_super_weights():
        # Re-detect super weights under the current drop configuration.
        # No caching: the configuration changes at every probe.
        return detect_super_weights(
            model,
            dataloader,
            accelerator,
            num_samples=num_samples,
            threshold=getattr(args, 'super_weight_threshold', 3.0),
            cache_file=None,
        )

    current_super_weights = _detect_current_super_weights()
    baseline_snapshot = _serialize_super_weights(current_super_weights)
    drop_history: List[Dict[str, object]] = []
    drop_order: List[int] = []

    def _save_trace():
        # Persist the greedy-search trace (main process only, and only when a save path is set).
        trace = {
            "initially_dropped_layers": sorted(initially_dropped),
            "initial_super_weights": baseline_snapshot,
            "drop_order": drop_order,
            "drop_history": drop_history,
        }
        if args.prune_model_save_path and accelerator.is_main_process:
            trace_path = os.path.join(args.prune_model_save_path, "super_weight_attn_drop_trace.json")
            save_json(trace, trace_path, indent=2)
            accelerator.print(f"Super-weight-guided drop trace saved to {trace_path}")

    if not remaining_layers:
        accelerator.print("All attention layers are already dropped. Nothing to do.")
        _save_trace()
        return sorted(initially_dropped)

    step = 0
    while remaining_layers:
        step += 1
        best_layer = None
        best_delta = math.inf
        best_candidate_weights = None

        accelerator.print(f"[SuperWeightDrop][Step {step}] evaluating {len(remaining_layers)} candidate layers...")
        for candidate_layer in remaining_layers:
            # Probe: temporarily drop the candidate, measure the activation shift, then restore.
            _set_drop_flag(candidate_layer, True)
            candidate_super_weights = _detect_current_super_weights()
            delta = _super_weight_activation_delta(current_super_weights, candidate_super_weights)
            accelerator.print(
                f"[SuperWeightDrop] Layer {candidate_layer} delta={delta:.6f} "
                f"(baseline count={len(current_super_weights)}, candidate count={len(candidate_super_weights)})"
            )

            if delta < best_delta:
                best_delta = delta
                best_layer = candidate_layer
                best_candidate_weights = candidate_super_weights

            _set_drop_flag(candidate_layer, False)

        if best_layer is None or best_candidate_weights is None:
            raise RuntimeError("Failed to identify the next attention layer to drop.")

        # Commit the least-disruptive drop and make its activations the new baseline.
        _set_drop_flag(best_layer, True)
        remaining_layers.remove(best_layer)
        current_super_weights = best_candidate_weights
        drop_order.append(best_layer)

        drop_history.append(
            {
                "step": step,
                "layer_index": best_layer,
                "activation_delta": best_delta,
                "super_weights": _serialize_super_weights(best_candidate_weights),
            }
        )

        accelerator.print(
            f"[SuperWeightDrop] Dropped layer {best_layer} at step {step} (delta={best_delta:.6f}). "
            f"{len(remaining_layers)} layers remaining."
        )

    _save_trace()
    return sorted(list(initially_dropped | set(drop_order)))
| |
|
| |
|
def post_layers_drop(prune_model_save_path, target_layer, model, tokenizer, reserved_layer_list, accelerator: Accelerator, only_update_config=False):
    """Write the pruned checkpoint to disk, recording which layers were dropped in the config.

    Every layer index NOT in ``reserved_layer_list`` is marked as dropped in the output
    config (``drop_attn_list`` / ``drop_mlp_list``). For ``target_layer == 'all'`` the
    indices live in a combined space where values >= num_hidden_layers refer to MLP
    sub-layers. The custom config/modeling source files for the model type are copied
    next to the weights.

    Args:
        prune_model_save_path: Destination directory.
            NOTE(review): assumed to already exist — the shutil.copy calls below do not
            create it; confirm callers create the directory first.
        target_layer: 'attn', 'mlp' or 'all' — the index space reserved_layer_list uses.
        model: Model to save (possibly accelerate-wrapped).
        tokenizer: Tokenizer saved alongside the model.
        reserved_layer_list: Indices of layers to KEEP; everything else is dropped.
        accelerator: Used for main-process gating and logging.
        only_update_config: If True, skip saving weights/tokenizer and only write config + custom files.
    """
    unwrapped_model = accelerator.unwrap_model(model)

    # All file writes happen on the main process only.
    if accelerator.is_main_process:
        out_cfg = deepcopy(unwrapped_model.config)
        model_type = getattr(unwrapped_model.config, "model_type", None)

        # Point the saved config at the custom modeling classes for this model type.
        if model_type in auto_map:
            out_cfg.auto_map = auto_map[model_type]
        else:
            raise ValueError("Unsupported model type!")
        dropped_attn_list = []
        dropped_mlp_list = []
        if target_layer == 'all':
            # Combined index space: [0, L) = attention, [L, 2L) = MLP.
            dropped_layer_list = list(set(list(range(out_cfg.num_hidden_layers * 2))) - set(reserved_layer_list))
            for idx in dropped_layer_list:
                if idx >= out_cfg.num_hidden_layers:
                    dropped_mlp_list.append(idx - out_cfg.num_hidden_layers)
                else:
                    dropped_attn_list.append(idx)
        elif target_layer == 'attn':
            dropped_attn_list = list(set(list(range(out_cfg.num_hidden_layers))) - set(reserved_layer_list))
        elif target_layer == 'mlp':
            dropped_mlp_list = list(set(list(range(out_cfg.num_hidden_layers))) - set(reserved_layer_list))
        else:
            raise ValueError("Unsupported target_layer!")

        # Merge with any layers the config already marks as dropped (boolean flag lists).
        # NOTE(review): if the pre-existing drop lists overlap the newly computed ones,
        # indices can appear twice — confirm downstream consumers tolerate duplicates.
        out_cfg.drop_mlp_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, f'drop_mlp_list', [])) if v] + dropped_mlp_list
        out_cfg.drop_attn_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, f'drop_attn_list', [])) if v] + dropped_attn_list

        accelerator.print(f"Dropped attention list: {dropped_attn_list}")
        accelerator.print(f"Dropped MLP list: {dropped_mlp_list}")

        accelerator.print("Saving...")
        # Ship the custom config/modeling source files alongside the checkpoint.
        shutil.copy(CUSTOM_FILE[out_cfg.model_type]["config"], prune_model_save_path)
        shutil.copy(CUSTOM_FILE[out_cfg.model_type]["model"], prune_model_save_path)
        if not only_update_config:
            model.save_pretrained(prune_model_save_path)
            tokenizer.save_pretrained(prune_model_save_path)
        out_cfg.save_pretrained(prune_model_save_path)
| |
|