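"""Layer-similarity measurement and discrete layer dropping for decoder LLMs.

Hidden states are recorded around each attention / MLP block on a calibration
set, a CCA-based similarity score is computed per block in a subprocess, and
the most redundant blocks are recorded as dropped in the model config.
"""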
import gc
import logging
import math
import os
import pickle
import shutil
import subprocess
import sys
from copy import deepcopy

import torch
from accelerate import Accelerator
from torch import no_grad
from torch.utils.data import DataLoader
from tqdm import tqdm

from .io import create_dir
from .utils import print_gpu_memory, prepare_calibration_input, auto_map, CUSTOM_FILE
from .wrapper import HiddenStatesRecordWrapper
logger = logging.getLogger(__name__)
# πŸ” compute similarity
@no_grad()
def get_layer_similarities(model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int, drop_norm: bool, target_layer: str, cache_file=None):
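    """Score each attention/MLP block by the similarity between its input and
    output hidden states on the calibration set.

    Returns a (num_layers,) tensor; skipped (already-dropped) layers get -inf.
    Results are cached to `cache_file` when one is given.
    """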
device = accelerator.device
if cache_file is not None and os.path.exists(cache_file):
        # Use cached similarities
        accelerator.print(f"Loading cached similarities from {cache_file}")
        similarities = torch.load(cache_file, map_location=device)
else:
        # Calculate similarities
        accelerator.print(f"No cached similarities found. Running the model on {num_samples} samples per device.")
unwrapped_model = accelerator.unwrap_model(model) # πŸ” unwrap model first
unwrapped_model.config.use_cache = False
unwrapped_model.config.output_attentions = True
layers = unwrapped_model.model.layers
accelerator.print("Getting features...")
inputs, outputs, attention_mask, position_ids, cache_position = prepare_calibration_input(unwrapped_model, dataloader, num_samples) # πŸ”
# πŸ” Get layer ids
num_layers = unwrapped_model.config.num_hidden_layers
layer_indices = list(range(num_layers))
# πŸ” Initialize the similarities.
# Row: each layer
# Column: similarity to the next n layer
# Example: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] # shape(6)
similarities = torch.full((num_layers,), -math.inf, device=device)
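        # -inf keeps skipped/already-dropped layers at the bottom of the
        # descending sort used later when picking layers to drop.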
        # Layers dropped in a previous round are recorded as booleans in the config.
        skipped_layers = [idx for idx, v in enumerate(getattr(unwrapped_model.config, f'drop_{target_layer}_list', [])) if v]
accelerator.print('Starting ...')
for i in tqdm(range(num_layers), desc="Recording hidden states...", disable=not accelerator.is_main_process):
            if i in skipped_layers:
                similarities[i] = -math.inf
                accelerator.print('Skipping already-dropped layer:', i)
                continue
sys.stderr.flush()
torch.cuda.empty_cache()
print_gpu_memory(accelerator)
layer = layers[i]
if i in layer_indices:
if target_layer == 'mlp':
module_pre_norm = layer.post_attention_layernorm
module = layer.mlp
elif target_layer == 'attn':
module_pre_norm = layer.input_layernorm
module = layer.self_attn
                elif target_layer == 'all':
                    # 'all' is handled by two separate 'attn' and 'mlp' passes in discrete_layer_dropping
                    raise ValueError("Unsupported target_layer!")
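                # 🔍 With drop_norm, record the raw residual-stream input of the
                # pre-norm; otherwise record the normalized output it produces.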
if drop_norm:
wrapped_module_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=True, record_output=False) # πŸ” Wrap layer
else:
wrapped_module_pre_norm = HiddenStatesRecordWrapper(module_pre_norm, record_input=False, record_output=True) # πŸ” Wrap layer
wrapped_module = HiddenStatesRecordWrapper(module, record_input=False, record_output=True) # πŸ” Wrap layer
# Forward hook for recording hidden states
def record_module_pre_norm_states_hook(_, input, output):
wrapped_module_pre_norm.record(input[0].data, output[0].data)
if target_layer == 'mlp':
def record_module_states_hook(_, input, output):
wrapped_module.record(input[0].data, output[0].data)
elif target_layer == 'attn':
def record_module_states_hook(_, input, output):
wrapped_module.record(None, output[0].data)
else:
raise ValueError("Unsupported target_layer!")
# Get hidden states
handles = []
handles.append(module_pre_norm.register_forward_hook(record_module_pre_norm_states_hook))
handles.append(module.register_forward_hook(record_module_states_hook))
for j in range(num_samples):
if getattr(unwrapped_model.config, "model_type", None) == "llama":
outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
else:
outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], output_attentions=False)[0]
for handle in handles:
handle.remove()
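                # Accumulate the recorded states in fp16 to bound memory use.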
dtype = torch.float16
if drop_norm:
input_hidden_states = torch.cat(wrapped_module_pre_norm.input_hidden_states, dim=0).to(dtype).to(device)
output_hidden_states = torch.cat(wrapped_module.output_hidden_states, dim=0).to(dtype).to(device)
else:
input_hidden_states = torch.cat(wrapped_module_pre_norm.output_hidden_states, dim=0).to(dtype).to(device)
output_hidden_states = torch.cat(wrapped_module.output_hidden_states, dim=0).to(dtype).to(device)
# πŸ” Calculate similarity (output+input due to residual connection)
#cos_sim = F.cosine_similarity(input_hidden_states, input_hidden_states + output_hidden_states, dim=-1) # (total_token_num)
#cos_sim = cos_sim.mean()
#cos_sim = accelerator.reduce(cos_sim, reduction="mean") # πŸ” All reduce across devices
#accelerator.print(f'layer {i} similarity: {cos_sim.item()}')
#similarities[i] = cos_sim
                # Hand (hidden_dim, total_tokens) matrices to the external CCA script.
                X = input_hidden_states.to("cpu").type(dtype).T
                Y = output_hidden_states.to("cpu").type(dtype).T
                os.makedirs("./llm_variables", exist_ok=True)
                with open("./llm_variables/xlayer_objs.pkl", "wb") as f:
                    pickle.dump([X, Y], f)
del X, Y
del input_hidden_states, output_hidden_states
gc.collect()
torch.cuda.empty_cache()
accelerator.free_memory()
                # Run the CCA computation in a separate process so its memory is
                # fully released between layers; pass the layer index and scratch dir.
                script_path = "./src/calculate_cca.py"
                subprocess.run([sys.executable, script_path, str(i), "./llm_variables"], check=True)
                accelerator.print("CCA calculation ended")
with open(f"./llm_variables/similarity_scores.pkl", "rb") as f:
b = pickle.load(f)
similarities[i] = 1/(b[0][1].item())
print(similarities)
else:
for j in range(num_samples):
if getattr(unwrapped_model.config, "model_type", None) == "llama":
outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j], cache_position=cache_position[j])[0]
else:
outputs[j] = layer(inputs[j], attention_mask=attention_mask[j], position_ids=position_ids[j])[0]
# Update inputs & outputs
inputs, outputs = outputs, inputs
# Save to the cache file
if cache_file is not None:
if accelerator.is_main_process:
create_dir(os.path.dirname(cache_file))
torch.save(similarities.clone().cpu(), cache_file)
print(f"Saving cached similarities to {cache_file}")
accelerator.wait_for_everyone()
accelerator.print("similarities\n", similarities)
return similarities
# πŸ” find indices of dropped layers
def discrete_layer_dropping(args, model, dataloader: DataLoader, accelerator: Accelerator, num_samples: int):
"""
πŸ” Prune mlp layers in a discrete order.
E.g., [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -> [0, 2, 6, 8, 9]
"""
drop_n = args.drop_n
if args.target_layer == 'all':
similarities_attn = get_layer_similarities(model, dataloader, accelerator, num_samples, args.layer_drop_norm, target_layer='attn', cache_file=args.similarity_cache_file.replace("all", "all_attn"))
similarities_mlp = get_layer_similarities(model, dataloader, accelerator, num_samples, args.layer_drop_norm, target_layer='mlp', cache_file=args.similarity_cache_file.replace("all", "all_mlp"))
similarities = torch.cat((similarities_attn, similarities_mlp), dim=0)
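        # Index convention for 'all': [0, num_layers) are attention blocks,
        # [num_layers, 2 * num_layers) are MLP blocks (decoded in post_layers_drop).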
else:
similarities = get_layer_similarities(model, dataloader, accelerator, num_samples, args.layer_drop_norm, target_layer=args.target_layer, cache_file=args.similarity_cache_file)
sorted_similarities, sorted_layer_id = torch.sort(similarities, dim=0, descending=True)
dropped_layer_list = sorted_layer_id[:drop_n].tolist()
accelerator.print(f"Dropped layer: {dropped_layer_list}, similarities: {sorted_similarities[:drop_n].tolist()}")
return dropped_layer_list
def post_layers_drop(prune_model_save_path, target_layer, model, tokenizer, reserved_layer_list, accelerator: Accelerator, only_update_config=False):
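    """Record the dropped attn/MLP indices in the config and save the pruned checkpoint.

    `reserved_layer_list` holds the indices to KEEP; for target_layer='all' the
    index space is doubled: [0, L) = attn, [L, 2L) = mlp.
    """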
unwrapped_model = accelerator.unwrap_model(model) # πŸ” unwrap model first
if accelerator.is_main_process:
out_cfg = deepcopy(unwrapped_model.config)
model_type = getattr(unwrapped_model.config, "model_type", None)
if model_type in auto_map:
out_cfg.auto_map = auto_map[model_type]
else:
raise ValueError("Unsupported model type!")
dropped_attn_list = []
dropped_mlp_list = []
if target_layer == 'all':
            dropped_layer_list = sorted(set(range(out_cfg.num_hidden_layers * 2)) - set(reserved_layer_list))
for idx in dropped_layer_list:
if idx >= out_cfg.num_hidden_layers:
dropped_mlp_list.append(idx - out_cfg.num_hidden_layers)
else:
dropped_attn_list.append(idx)
elif target_layer == 'attn':
            dropped_attn_list = sorted(set(range(out_cfg.num_hidden_layers)) - set(reserved_layer_list))
elif target_layer == 'mlp':
            dropped_mlp_list = sorted(set(range(out_cfg.num_hidden_layers)) - set(reserved_layer_list))
else:
raise ValueError("Unsupported target_layer!")
        out_cfg.drop_mlp_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, 'drop_mlp_list', [])) if v] + dropped_mlp_list
        out_cfg.drop_attn_list = [idx for idx, v in enumerate(getattr(unwrapped_model.config, 'drop_attn_list', [])) if v] + dropped_attn_list
accelerator.print(f"Dropped attention list: {dropped_attn_list}")
accelerator.print(f"Dropped MLP list: {dropped_mlp_list}")
accelerator.print("Saving...")
shutil.copy(CUSTOM_FILE[out_cfg.model_type]["config"], prune_model_save_path)
shutil.copy(CUSTOM_FILE[out_cfg.model_type]["model"], prune_model_save_path)
if not only_update_config:
model.save_pretrained(prune_model_save_path)
tokenizer.save_pretrained(prune_model_save_path)
out_cfg.save_pretrained(prune_model_save_path)
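# --- Usage sketch ------------------------------------------------------------
# A minimal driver, assuming an `args` namespace with the fields referenced
# above (drop_n, target_layer, layer_drop_norm, similarity_cache_file); the
# variable names here are illustrative, not part of this module's API.
#
#   accelerator = Accelerator()
#   model, dataloader = accelerator.prepare(model, dataloader)
#   dropped = discrete_layer_dropping(args, model, dataloader, accelerator, num_samples=128)
#   reserved = [i for i in range(model.config.num_hidden_layers) if i not in dropped]
#   post_layers_drop(save_path, args.target_layer, model, tokenizer, reserved, accelerator)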