import logging
import math
import os
import shutil
import sys
from copy import deepcopy

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from torch import no_grad
from torch.utils.data import DataLoader
from tqdm import tqdm

from .io import create_dir
from .utils import prepare_calibration_input, print_gpu_memory, auto_map, CUSTOM_FILE
from .wrapper import HiddenStatesRecordWrapper

logger = logging.getLogger(__name__)


@no_grad()
def get_block_similarities(model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int, cache_file=None):
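    """
    Score each transformer block by the cosine similarity between its input and
    output hidden states on calibration data. A similarity close to 1 means the
    block barely transforms its input, making it a candidate for dropping.
    Returns a (num_layers, 1) tensor of per-block similarities.
    """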
    device = accelerator.device

    if cache_file is not None and os.path.exists(cache_file):
        accelerator.print(f"Loading cached similarities from {cache_file}")
        similarities = torch.load(cache_file, map_location=device)

    else:
        accelerator.print(f"No cached similarities found. Running model on {num_samples} samples per device.")
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.config.use_cache = False
        layers = unwrapped_model.model.layers

        accelerator.print("Getting features...")
        inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(unwrapped_model, dataloader, num_samples)
        num_layers = unwrapped_model.config.num_hidden_layers

        # One score per block: the cosine similarity between its input and its output.
        similarities = torch.full((len(layers), 1), -math.inf, device=device)

        accelerator.print("Starting ...")
        dtype = torch.float32

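        # Stream the calibration samples through one block at a time, recording each
        # block's input/output pair and scoring it immediately to bound memory use.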
        for i in tqdm(range(num_layers), desc="Recording hidden states...", disable=not accelerator.is_main_process):
            sys.stderr.flush()
            torch.cuda.empty_cache()
            print_gpu_memory(accelerator)
            layer = layers[i]

            wrapped_layer = HiddenStatesRecordWrapper(layer, record_input=True, record_output=True)

            # Record the hidden states entering and leaving this block.
            def record_states_hook(_, input, output):
                wrapped_layer.record(input[0].data, output[0].data)

            handle = layer.register_forward_hook(record_states_hook)
            for j in range(num_samples):
                outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]
            handle.remove()

            # This layer's outputs become the next layer's inputs; reuse the buffers.
            inputs, outputs = outputs, inputs
            print_gpu_memory(accelerator)

            input_hidden_states = torch.cat(wrapped_layer.input_hidden_states, dim=0).to(dtype).to(device)
            output_hidden_states = torch.cat(wrapped_layer.output_hidden_states, dim=0).to(dtype).to(device)
            cos_sim = F.cosine_similarity(input_hidden_states, output_hidden_states, dim=-1)
            cos_sim = cos_sim.mean()
            cos_sim = accelerator.reduce(cos_sim, reduction="mean")
            similarities[i, 0] = cos_sim
            layer.to("cpu")

        if cache_file is not None:
            if accelerator.is_main_process:
                create_dir(os.path.dirname(cache_file))
                torch.save(similarities.clone().cpu(), cache_file)
                print(f"Saved cached similarities to {cache_file}")
            accelerator.wait_for_everyone()

    accelerator.print("similarities\n", similarities)

    return similarities


@no_grad()
def get_block_similarities_consecutive(model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int, cache_file=None):
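    """
    Score every window of consecutive blocks by the cosine similarity between
    the hidden states at the window's entry and exit. similarities[i, n - 1]
    is the score for dropping the n consecutive layers i .. i + n - 1; unused
    entries stay at -inf. Returns a (num_layers, num_layers) tensor.
    """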
    device = accelerator.device

    if cache_file is not None and os.path.exists(cache_file):
        accelerator.print(f"Loading cached similarities from {cache_file}")
        similarities = torch.load(cache_file, map_location=device)

    else:
        accelerator.print(f"No cached similarities found. Running model on {num_samples} samples per device.")
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.config.use_cache = False
        layers = unwrapped_model.model.layers

        accelerator.print("Getting features...")
        inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(unwrapped_model, dataloader, num_samples)

        num_layers = unwrapped_model.config.num_hidden_layers
        # Row i, column n - 1: similarity between the input of layer i and the hidden
        # state n layers later, i.e. the score for dropping layers i .. i + n - 1.
        similarities = torch.full((len(layers), len(layers)), -math.inf, device=device)
        accelerator.print("Starting ...")
        wrapped_layers = []

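        # Stream the calibration samples through the blocks, keeping every block's
        # input hidden states so any (entry, exit) pair can be compared afterwards.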
        for i in tqdm(range(num_layers), desc="Recording hidden states...", disable=not accelerator.is_main_process):
            sys.stderr.flush()
            torch.cuda.empty_cache()
            print_gpu_memory(accelerator)
            layer = layers[i]

            # Only the last layer's output is recorded; every other layer's output
            # is the next layer's input and gets recorded there instead.
            wrapped_layer = HiddenStatesRecordWrapper(layer, record_input=True, record_output=(i == len(layers) - 1))
            wrapped_layers.append(wrapped_layer)

            def record_states_hook(_, input, output):
                wrapped_layer.record(input[0].data, output[0].data)

            handle = layer.register_forward_hook(record_states_hook)
            for j in range(num_samples):
                if getattr(unwrapped_model.config, "model_type", None) == "llama":
                    outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
                else:
                    outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]
            handle.remove()

            # This layer's outputs become the next layer's inputs; reuse the buffers.
            inputs, outputs = outputs, inputs
            print_gpu_memory(accelerator)

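        # Compare every pair of recorded hidden states; the gap j - i is the number
        # of consecutive layers whose removal the pair's similarity scores.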
        dtype = torch.float32
        all_hidden_states = []
        for i in tqdm(range(len(layers)), desc="Concatenating hidden states...", disable=not accelerator.is_main_process):
            all_hidden_states.append(torch.cat(wrapped_layers[i].input_hidden_states, dim=0).to(dtype))
        all_hidden_states.append(torch.cat(wrapped_layers[-1].output_hidden_states, dim=0).to(dtype))
        accelerator.print(f"Total {len(all_hidden_states)} hidden states concatenated.")

        for i in tqdm(range(len(all_hidden_states)), desc="Calculating similarities...", disable=not accelerator.is_main_process):
            # Move the i-th hidden states to the device once, not once per inner iteration.
            packed_hidden_states_layer_i = all_hidden_states[i].to(device)
            for j in range(i + 1, len(all_hidden_states)):
                packed_hidden_states_layer_j = all_hidden_states[j].to(device)
                index_gap = j - i

                cos_sim = F.cosine_similarity(packed_hidden_states_layer_i, packed_hidden_states_layer_j, dim=-1)
                cos_sim = cos_sim.mean()
                cos_sim = accelerator.reduce(cos_sim, reduction="mean")

                similarities[i, index_gap - 1] = cos_sim

        if cache_file is not None:
            if accelerator.is_main_process:
                create_dir(os.path.dirname(cache_file))
                torch.save(similarities.clone().cpu(), cache_file)
                print(f"Saved cached similarities to {cache_file}")
            accelerator.wait_for_everyone()

    accelerator.print("similarities\n", similarities)

    return similarities


def max_with_tolerance(similarities: torch.Tensor, tolerance: float):
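    """
    Return the maximum of `similarities` together with the earliest index whose
    value lies within `tolerance` of that maximum.
    """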
    max_value, _ = torch.max(similarities, dim=0)
    close_indices = torch.where(torch.abs(similarities - max_value) < tolerance)[0]
    begin_layer_id = close_indices[0]

    return max_value, begin_layer_id


def get_top_k(similarities, k, tolerance):
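    """
    Greedily select the k highest-similarity entries, using `max_with_tolerance`
    to break near-ties in favor of earlier layers. Note that `similarities` is
    masked in place, so pass a clone if the caller still needs the original.
    """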
    dropped_layer_list = []
    dropped_sim_list = []
    for _ in range(k):
        max_value, max_index = max_with_tolerance(similarities, tolerance=tolerance)
        dropped_layer_list.append(max_index.item())
        dropped_sim_list.append(max_value.item())
        # Mask with -inf rather than 0 so the entry stays excluded even if some
        # similarities are negative.
        similarities[max_index] = -math.inf
    return dropped_sim_list, dropped_layer_list


def consecutive_block_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
    """
    Prune blocks in consecutive order.
    E.g., [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> [0, 1, 7, 8, 9]
    """
    drop_n = args.drop_n

    similarities = get_block_similarities_consecutive(model, dataloader, accelerator, num_samples, cache_file=args.similarity_cache_file)
    # Column drop_n - 1 holds the scores for dropping exactly drop_n consecutive layers.
    similarities_drop_n = similarities[:, drop_n - 1].view(-1)
    max_similarity, begin_layer_id = torch.max(similarities_drop_n, dim=0)
    accelerator.print(f"similarities_drop_n: {similarities_drop_n}")
    accelerator.print(f"max_similarity: {max_similarity}, begin_layer_id: {begin_layer_id}")

    begin_layer_id = begin_layer_id.item()
    end_layer_id = begin_layer_id + drop_n
    dropped_layer_list = list(range(begin_layer_id, end_layer_id))

    accelerator.print(f"Dropped layers: {dropped_layer_list}, max_similarity: {max_similarity}")
    return dropped_layer_list


def discrete_block_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
    """
    Prune blocks in discrete (non-consecutive) order.
    E.g., [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> [0, 2, 6, 8, 9]
    """
    drop_n = args.drop_n

    similarities = get_block_similarities(model, dataloader, accelerator, num_samples, cache_file=args.similarity_cache_file)

    # Rank single-block similarities and drop the drop_n most redundant blocks.
    similarities_drop_1 = similarities[:, 0].view(-1)
    sorted_similarities, sorted_layer_id = torch.sort(similarities_drop_1, dim=0, descending=True)
    accelerator.print(f"similarities_drop_1: {similarities_drop_1}")

    dropped_layer_list = sorted_layer_id[:drop_n].tolist()
    accelerator.print(f"Dropped layers: {dropped_layer_list}, similarities: {sorted_similarities[:drop_n].tolist()}")
    return dropped_layer_list


def post_block_drop(prune_model_save_path, model, tokenizer, reserved_layer_list, accelerator: Accelerator, only_update_config=False):
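    """
    Persist the pruned model to `prune_model_save_path`: record the dropped
    layer ids in the config, copy over the custom config/model source files,
    and, unless `only_update_config` is set, save the weights and tokenizer.
    """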
    unwrapped_model = accelerator.unwrap_model(model)

    if accelerator.is_main_process:
        out_cfg = deepcopy(unwrapped_model.config)
        model_type = getattr(unwrapped_model.config, "model_type", None)

        if model_type in auto_map:
            out_cfg.auto_map = auto_map[model_type]
        else:
            raise ValueError("Unsupported model type!")

        # Whole blocks are dropped, so the attention and MLP drop lists are identical.
        dropped_attn_list = dropped_mlp_list = sorted(set(range(out_cfg.num_hidden_layers)) - set(reserved_layer_list))
        out_cfg.drop_mlp_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, "drop_mlp_list", [])) if v] + dropped_mlp_list
        out_cfg.drop_attn_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, "drop_attn_list", [])) if v] + dropped_attn_list

        accelerator.print(f"Dropped attention list: {dropped_attn_list}")
        accelerator.print(f"Dropped MLP list: {dropped_mlp_list}")

| accelerator.print("Saving...") |
| shutil.copy(CUSTOM_FILE[out_cfg.model_type]["config"], prune_model_save_path) |
| shutil.copy(CUSTOM_FILE[out_cfg.model_type]["model"], prune_model_save_path) |
| if not only_update_config: |
| model.save_pretrained(prune_model_save_path) |
| tokenizer.save_pretrained(prune_model_save_path) |
| out_cfg.save_pretrained(prune_model_save_path) |