import logging
import math
import os
import sys
from copy import deepcopy
import shutil
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from torch import no_grad
from torch.utils.data import DataLoader
from tqdm import tqdm
from .io import create_dir
from .utils import prepare_calibration_input, print_gpu_memory, auto_map, CUSTOM_FILE
from .wrapper import HiddenStatesRecordWrapper
logger = logging.getLogger(__name__)
@no_grad()
def get_block_similarities(model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int, cache_file=None):
device = accelerator.device
if cache_file is not None and os.path.exists(cache_file):
# use cached file
accelerator.print(f"Loading cached model from {cache_file}")
similarities = torch.load(cache_file, map_location=device)
else:
# calculate similarities
accelerator.print(f"No cached model found. Running model on {num_samples} samples for each device.")
        unwrapped_model = accelerator.unwrap_model(model)  # 🔍 unwrap the model first
unwrapped_model.config.use_cache = False
layers = unwrapped_model.model.layers
accelerator.print("Getting features...")
        inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(unwrapped_model, dataloader, num_samples)  # 🔍
num_layers = unwrapped_model.config.num_hidden_layers
# πŸ” Initialize the similarities.
# Row: each layer
# Column: similarity to the next n layer
# Example: [ [0.5], [0.5], [0.5], [0.5], [0.5], [0.5]] # shape(6, 1)
similarities = torch.full((len(layers), 1), -math.inf, device=device)
accelerator.print('Starting ...')
dtype = torch.float32
for i in tqdm(range(num_layers), desc="Recording hidden states...", disable=not accelerator.is_main_process):
sys.stderr.flush()
torch.cuda.empty_cache()
print_gpu_memory(accelerator)
layer = layers[i]
# Wrap layer
            wrapped_layer = HiddenStatesRecordWrapper(layer, record_input=True, record_output=True)  # 🔍 Wrap layer
# Forward hook for recording hidden states
def record_states_hook(_, input, output):
wrapped_layer.record(input[0].data, output[0].data)
# Get states
handle = layer.register_forward_hook(record_states_hook)
            for j in range(num_samples):
                # Pass cache_position for llama models, mirroring get_block_similarities_consecutive below
                if getattr(unwrapped_model.config, "model_type", None) == "llama":
                    outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
                else:
                    outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]
handle.remove()
# Update inputs & outputs
inputs, outputs = outputs, inputs
print_gpu_memory(accelerator)
input_hidden_states = torch.cat(wrapped_layer.input_hidden_states, dim=0).to(dtype).to(device)
output_hidden_states = torch.cat(wrapped_layer.output_hidden_states, dim=0).to(dtype).to(device)
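            # Illustrative note (hypothetical values, not executed here):
            # F.cosine_similarity(x, y, dim=-1) on (total_tokens, hidden_size)
            # tensors yields one similarity per token; the mean over tokens
            # becomes this block's score, e.g.:
            #   x = torch.randn(4, 8)
            #   F.cosine_similarity(x, x, dim=-1).mean()   # -> 1.0
            #   F.cosine_similarity(x, -x, dim=-1).mean()  # -> -1.0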
cos_sim = F.cosine_similarity(input_hidden_states, output_hidden_states, dim=-1) # (total_token_num)
cos_sim = cos_sim.mean()
            cos_sim = accelerator.reduce(cos_sim, reduction="mean")  # 🔍 All-reduce across devices
similarities[i, 0] = cos_sim
layer.to("cpu") # πŸ”
# Save to the cache file
if cache_file is not None:
if accelerator.is_main_process:
create_dir(os.path.dirname(cache_file))
torch.save(similarities.clone().cpu(), cache_file)
print(f"Saving cached similarities to {cache_file}")
accelerator.wait_for_everyone()
accelerator.print("similarities\n", similarities)
return similarities
@no_grad()
def get_block_similarities_consecutive(model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int, cache_file=None):
device = accelerator.device
if cache_file is not None and os.path.exists(cache_file):
# use cached file
accelerator.print(f"Loading cached model from {cache_file}")
similarities = torch.load(cache_file, map_location=device)
else:
# calculate similarities
accelerator.print(f"No cached model found. Running model on {num_samples} samples for each device.")
        unwrapped_model = accelerator.unwrap_model(model)  # 🔍 unwrap the model first
unwrapped_model.config.use_cache = False
layers = unwrapped_model.model.layers
accelerator.print("Getting features...")
        inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(unwrapped_model, dataloader, num_samples)  # 🔍
        # 🔍 Number of hidden layers
num_layers = unwrapped_model.config.num_hidden_layers
# πŸ” Initialize the similarities.
# Row: each layer
# Column: similarity to the next n layer
# Example: [[ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
# [ 0.5, 0.5, 0.5, 0.5, 0.5, -inf],
# [ 0.5, 0.5, 0.5, 0.5, -inf, -inf],
# [ 0.5, 0.5, 0.5, -inf, -inf, -inf],
# [ 0.5, 0.5, -inf, -inf, -inf, -inf],
# [ 0.5, -inf, -inf, -inf, -inf, -inf]] # shape(6, 6)
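        # Reading rule (follows from the fill loop below): similarities[i, n - 1]
        # compares the hidden state entering block i with the state n blocks
        # later, i.e., it scores dropping the consecutive blocks i, ..., i + n - 1.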
similarities = torch.full((len(layers), len(layers)), -math.inf, device=device)
accelerator.print('Starting ...')
wrapped_layers = []
for i in tqdm(range(num_layers), desc="Recording hidden states...", disable=not accelerator.is_main_process):
sys.stderr.flush()
torch.cuda.empty_cache()
print_gpu_memory(accelerator)
layer = layers[i]
# Wrap layer
            wrapped_layer = HiddenStatesRecordWrapper(layer, record_input=True, record_output=(i == len(layers) - 1))  # 🔍 Wrap layer; only the last block's output is needed
wrapped_layers.append(wrapped_layer)
# Forward hook for recording hidden states
def record_states_hook(_, input, output):
wrapped_layer.record(input[0].data, output[0].data)
# Get states
handle = layer.register_forward_hook(record_states_hook)
for j in range(num_samples):
if getattr(unwrapped_model.config, "model_type", None) == "llama":
outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
else:
outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]
handle.remove()
# Update inputs & outputs
inputs, outputs = outputs, inputs
print_gpu_memory(accelerator)
dtype = torch.float32
all_hidden_states = []
for i in tqdm(range(len(layers)), desc="Concatenating hidden states...", disable=not accelerator.is_main_process):
all_hidden_states.append(torch.cat(wrapped_layers[i].input_hidden_states, dim=0).to(dtype)) # (total_token_num, hidden_size)
all_hidden_states.append(torch.cat(wrapped_layers[-1].output_hidden_states, dim=0).to(dtype))
accelerator.print(f'Total {len(all_hidden_states)} hidden states concatenated.')
for i in tqdm(range(len(all_hidden_states)), desc="Calculating similarities...", disable=not accelerator.is_main_process):
            packed_hidden_states_layer_i = all_hidden_states[i].to(device)  # hoisted: constant across the inner loop
            for j in range(i + 1, len(all_hidden_states)):
                packed_hidden_states_layer_j = all_hidden_states[j].to(device)
                index_gap = j - i
cos_sim = F.cosine_similarity(packed_hidden_states_layer_i, packed_hidden_states_layer_j, dim=-1) # (total_token_num)
cos_sim = cos_sim.mean()
                cos_sim = accelerator.reduce(cos_sim, reduction="mean")  # 🔍 All-reduce across devices
similarities[i, index_gap - 1] = cos_sim
# Save to the cache file
if cache_file is not None:
if accelerator.is_main_process:
create_dir(os.path.dirname(cache_file))
torch.save(similarities.clone().cpu(), cache_file)
print(f"Saving cached similarities to {cache_file}")
accelerator.wait_for_everyone()
accelerator.print("similarities\n", similarities)
return similarities
def max_with_tolerance(similarities: torch.Tensor, tolerance: float):
    """Return the max similarity and the earliest index within `tolerance` of it."""
max_value, _ = torch.max(similarities, dim=0)
close_indices = torch.where(torch.abs(similarities - max_value) < tolerance)[0]
begin_layer_id = close_indices[0]
return max_value, begin_layer_id
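# Minimal illustration of max_with_tolerance (hypothetical values): with
#   similarities = torch.tensor([0.90, 0.98, 0.979, 0.50]) and tolerance = 0.01,
# indices 1 and 2 are both within tolerance of the max, and the earliest one (1)
# is returned, so near-ties break toward earlier layers.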
def get_top_k(similarities, k, tolerance):
dropped_layer_list = []
dropped_sim_list = []
for _ in range(k):
max_value, max_index = max_with_tolerance(similarities, tolerance=tolerance)
dropped_layer_list.append(max_index.item())
dropped_sim_list.append(max_value.item())
        similarities[max_index] = -math.inf  # 🔍 exclude from later rounds (mutates the caller's tensor)
return dropped_sim_list, dropped_layer_list
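# Sketch of a get_top_k run (hypothetical values): starting from
#   similarities = torch.tensor([0.7, 0.9, 0.8]) with k=2 and a small tolerance,
# round 1 picks layer 1 (0.9) and masks it out; round 2 picks layer 2 (0.8),
# returning ([0.9, 0.8], [1, 2]). Note that the caller's tensor is mutated.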
def consecutive_block_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
"""
πŸ” Prune blocks in a consecutive order.
E.g., [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> [0, 1, 7, 8, 9]
"""
drop_n = args.drop_n
similarities = get_block_similarities_consecutive(model, dataloader, accelerator, num_samples, cache_file=args.similarity_cache_file)
    similarities_drop_n = similarities[:, drop_n - 1].view(-1)  # 🔍 column n-1 scores dropping n consecutive blocks
max_similarity, begin_layer_id = torch.max(similarities_drop_n, dim=0)
accelerator.print(f"similarities_drop_n: {similarities_drop_n}")
accelerator.print(f"max_similarity: {max_similarity}, begin_layer_id: {begin_layer_id}")
    begin_layer_id = begin_layer_id.item()
    end_layer_id = begin_layer_id + drop_n
    dropped_layer_list = list(range(begin_layer_id, end_layer_id))
accelerator.print(f"Dropped layer: {dropped_layer_list}, max_similarity: {max_similarity}")
return dropped_layer_list
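# Worked example (hypothetical): for the docstring's 10-layer model with
# drop_n = 5, column 4 of `similarities` holds sim(state_i, state_{i+5}); if it
# peaks at i = 2, blocks [2, 3, 4, 5, 6] are dropped, leaving [0, 1, 7, 8, 9].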
def discrete_block_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
"""
πŸ” Prune blocks in a discrete order.
E.g., [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> [0, 2, 6, 8, 9]
"""
drop_n = args.drop_n
similarities = get_block_similarities(model, dataloader, accelerator, num_samples, cache_file=args.similarity_cache_file)
similarities_drop_1 = similarities[:, 0].view(-1)
sorted_similarities, sorted_layer_id = torch.sort(similarities_drop_1, dim=0, descending=True)
accelerator.print(f"similarities_drop_1: {similarities_drop_1}")
dropped_layer_list = sorted_layer_id[:drop_n].tolist()
accelerator.print(f"Dropped layer: {dropped_layer_list}, similarities: {sorted_similarities[:drop_n].tolist()}")
return dropped_layer_list
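# Worked example (hypothetical): with similarities_drop_1 = [0.91, 0.97, 0.88, 0.95]
# and drop_n = 2, the descending sort yields layer ids [1, 3, 0, 2], so the
# non-adjacent blocks [1, 3] are dropped, as in the docstring's pattern.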
def post_block_drop(prune_model_save_path, model, tokenizer, reserved_layer_list, accelerator: Accelerator, only_update_config=False):
    unwrapped_model = accelerator.unwrap_model(model)  # 🔍 unwrap the model first
if accelerator.is_main_process:
out_cfg = deepcopy(unwrapped_model.config)
model_type = getattr(unwrapped_model.config, "model_type", None)
if model_type in auto_map:
out_cfg.auto_map = auto_map[model_type]
else:
raise ValueError("Unsupported model type!")
        dropped_attn_list = dropped_mlp_list = sorted(set(range(out_cfg.num_hidden_layers)) - set(reserved_layer_list))
        # Merge blocks already flagged in the source config with the newly dropped ones
        out_cfg.drop_mlp_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, 'drop_mlp_list', [])) if v] + dropped_mlp_list
        out_cfg.drop_attn_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, 'drop_attn_list', [])) if v] + dropped_attn_list
accelerator.print(f"Dropped attention list: {dropped_attn_list}")
accelerator.print(f"Dropped MLP list: {dropped_mlp_list}")
accelerator.print("Saving...")
shutil.copy(CUSTOM_FILE[out_cfg.model_type]["config"], prune_model_save_path)
shutil.copy(CUSTOM_FILE[out_cfg.model_type]["model"], prune_model_save_path)
if not only_update_config:
model.save_pretrained(prune_model_save_path)
tokenizer.save_pretrained(prune_model_save_path)
out_cfg.save_pretrained(prune_model_save_path)
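# End-to-end usage sketch (assumptions: an `args` namespace providing `drop_n`
# and `similarity_cache_file`, a calibration `dataloader`, and a loaded
# `model`/`tokenizer`; none of these are defined in this module):
#
#   accelerator = Accelerator()
#   dropped = discrete_block_dropping(args, model, dataloader, accelerator, num_samples=128)
#   reserved = [i for i in range(model.config.num_hidden_layers) if i not in dropped]
#   post_block_drop("./pruned_model", model, tokenizer, reserved, accelerator)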