| | import torch |
| | import numpy as np |
| | import os |
| |
|
| | from hf_models.opt.modeling_opt import OPTForCausalLM |
| | from hf_models.llama.modeling_llama import LlamaForCausalLM |
| | from transformers import AutoTokenizer |
| |
|
| | |
| | from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import build_data_loader |
| | from HybridTensor.utils.utils import extract_model_name |
| |
|
| | from datasets import load_dataset |
| | import json |
| |
|
| | from tqdm import tqdm |
| | import argparse |
| | from HybridTensor.utils.activations import MODELS |
| |
|
| |
|
def load_layer_data(data_dir, layer_idx, data_type):
    """
    Load collected activation data for a specific layer and data type.

    Args:
        data_dir (str): Directory containing metadata.json and the .dat files.
        layer_idx (int): Layer index (0-based).
        data_type (str): One of 'hidden_states', 'mlp_activations', 'attn_norms'.

    Returns:
        np.ndarray: Array of shape (sample_count, feature_size), copied out of
        the on-disk float16 memmap so the file handle can be released.

    Raises:
        ValueError: If layer_idx is out of range or data_type is unknown.
    """
    metadata_filename = os.path.join(data_dir, 'metadata.json')
    with open(metadata_filename, 'r') as f:
        metadata = json.load(f)

    num_layers = metadata['num_layers']
    hidden_size = metadata['hidden_size']
    num_heads = metadata['num_heads']
    max_samples = metadata['max_samples']

    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Invalid layer_idx: {layer_idx}. Must be between 0 and {num_layers - 1}.")

    if data_type == 'hidden_states':
        sample_count = metadata['hidden_states_counters'][layer_idx]
        feature_size = hidden_size
    elif data_type == 'mlp_activations':
        sample_count = metadata['mlp_activations_counters'][layer_idx]
        # BUGFIX: the MLP feature width is the model's intermediate size
        # (num_neurons), which is 4*hidden_size for OPT but NOT for LLaMA.
        # Prefer the value recorded in metadata; fall back to the legacy
        # 4x assumption for older metadata files that lack the key.
        feature_size = metadata.get('num_neurons', hidden_size * 4)
    elif data_type == 'attn_norms':
        sample_count = metadata['attn_norms_counters'][layer_idx]
        feature_size = num_heads
    else:
        raise ValueError(f"Invalid data_type: {data_type}. Must be 'hidden_states', 'mlp_activations', or 'attn_norms'.")

    filename = os.path.join(data_dir, f'{data_type}_layer_{layer_idx}.dat')
    data_mmap = np.memmap(filename, dtype='float16', mode='r', shape=(max_samples, feature_size))

    # Copy only the rows that were actually written, then drop the mapping.
    data = np.array(data_mmap[:sample_count])
    del data_mmap
    return data
| |
|
def initialize_data_structures(data_dir, num_layers, hidden_size, num_heads, num_neurons, max_samples,
                               mlp_activation=True, attn_norm=True):
    """
    Create per-layer float16 memmap buffers and zeroed row counters.

    Args:
        data_dir (str): Directory where the .dat files are created.
        num_layers (int): Number of transformer layers.
        hidden_size (int): Hidden size of the model (hidden-state width).
        num_heads (int): Number of attention heads (attn-norm width).
        num_neurons (int): MLP intermediate size (activation width).
        max_samples (int): Row capacity of each buffer.
        mlp_activation (bool): Whether to allocate MLP-activation buffers.
        attn_norm (bool): Whether to allocate attention-norm buffers.

    Returns:
        tuple: (hidden_states_files, hidden_states_counters,
                mlp_activations_files, mlp_activations_counters,
                attn_norms_files, attn_norms_counters). Lists for a disabled
        data type are left empty.
    """
    hs_files, hs_counts = [], []
    mlp_files, mlp_counts = [], []
    attn_files, attn_counts = [], []

    def _new_buffer(prefix, layer, width):
        # One writable memmap per (data type, layer) pair.
        path = os.path.join(data_dir, f'{prefix}_layer_{layer}.dat')
        return np.memmap(path, dtype='float16', mode='w+', shape=(max_samples, width))

    for layer in range(num_layers):
        hs_files.append(_new_buffer('hidden_states', layer, hidden_size))
        hs_counts.append(0)

        if mlp_activation:
            mlp_files.append(_new_buffer('mlp_activations', layer, num_neurons))
            mlp_counts.append(0)

        if attn_norm:
            attn_files.append(_new_buffer('attn_norms', layer, num_heads))
            attn_counts.append(0)

    return hs_files, hs_counts, mlp_files, mlp_counts, attn_files, attn_counts
| |
|
def process_hidden_states(layer_idx, hidden_states_layer, valid_token_indices, hidden_size, hidden_states_files, hidden_states_counters):
    """
    Append the valid-token hidden states of one layer to its on-disk buffer.

    Flattens (batch, seq, hidden) to (batch*seq, hidden), keeps only the rows
    at `valid_token_indices`, writes them as float16 starting at the layer's
    current counter, and advances the counter in place.
    """
    flat = hidden_states_layer.view(-1, hidden_size)
    selected = flat[valid_token_indices.cpu()]
    n_rows = selected.shape[0]
    start = hidden_states_counters[layer_idx]
    dest = hidden_states_files[layer_idx]
    dest[start:start + n_rows, :] = selected.cpu().numpy().astype('float16')
    hidden_states_counters[layer_idx] = start + n_rows
| |
|
def process_mlp_activations(layer_idx, mlp_activations_layer, valid_token_indices, hidden_size, mlp_activations_files, mlp_activations_counters):
    """
    Append the valid-token MLP activations of one layer to its on-disk buffer.

    The neuron width is taken from the tensor's last dimension (not from
    `hidden_size`), so models whose intermediate size is not 4*hidden work too.
    """
    width = mlp_activations_layer.shape[-1]
    flat = mlp_activations_layer.view(-1, width)
    rows = flat[valid_token_indices.cpu()]
    start = mlp_activations_counters[layer_idx]
    dest = mlp_activations_files[layer_idx]
    dest[start:start + rows.shape[0], :] = rows.cpu().numpy().astype('float16')
    mlp_activations_counters[layer_idx] = start + rows.shape[0]
| |
|
def process_attn_norms(layer_idx, attn_outputs_layer, valid_token_indices, num_heads, attn_norms_files, attn_norms_counters):
    """
    Append per-head L2 norms of attention outputs to the layer's buffer.

    Takes the norm over the last (head-dim) axis, reshapes to
    (tokens, num_heads), keeps only the valid-token rows, writes them as
    float16 at the layer's current counter, and advances the counter.
    """
    head_norms = torch.norm(attn_outputs_layer, dim=-1).view(-1, num_heads)
    rows = head_norms[valid_token_indices.cpu()]
    start = attn_norms_counters[layer_idx]
    dest = attn_norms_files[layer_idx]
    dest[start:start + rows.shape[0], :] = rows.cpu().numpy().astype('float16')
    attn_norms_counters[layer_idx] = start + rows.shape[0]
| |
|
def process_batch(
    outputs,
    input_ids,
    attention_mask,
    total_samples,
    max_samples,
    num_layers,
    hidden_size,
    num_heads,
    hidden_states_files,
    hidden_states_counters,
    mlp_activations_files,
    mlp_activations_counters,
    attn_norms_files,
    attn_norms_counters,
    args
):
    """
    Store one batch's activations into the per-layer memmap buffers.

    Selects the rows for non-padding tokens (attention_mask == 1), truncates
    them so the global budget `max_samples` is not exceeded, and appends them
    layer by layer via the process_* helpers.

    Returns:
        tuple: (total_samples, reached_max_samples, num_valid_tokens) —
        the updated running total, whether the budget is now exhausted, and
        the count of valid tokens in this batch BEFORE truncation.
    """
    hidden_states = outputs['router_inputs']
    mlp_activations = outputs['mlp_activations'] if args.mlp_activation else None
    attn_outputs = outputs['attn_outputs'] if args.attn_norm else None

    batch_size, seq_len = input_ids.shape

    # Flatten the mask so token rows line up with the flattened activations.
    mask_flat = attention_mask.view(-1).bool()
    num_valid_tokens = mask_flat.sum().item()

    # Cap this batch's contribution at the remaining sample budget.
    if total_samples + num_valid_tokens >= max_samples:
        tokens_to_process = max_samples - total_samples
        total_samples = max_samples
        reached_max_samples = True
    else:
        tokens_to_process = num_valid_tokens
        total_samples += num_valid_tokens
        reached_max_samples = False

    valid_token_indices = mask_flat.nonzero(as_tuple=False).view(-1)[:tokens_to_process]

    for layer_idx in range(num_layers):
        process_hidden_states(
            layer_idx,
            hidden_states[layer_idx],
            valid_token_indices,
            hidden_size,
            hidden_states_files,
            hidden_states_counters
        )

        if args.mlp_activation:
            process_mlp_activations(
                layer_idx,
                mlp_activations[layer_idx],
                valid_token_indices,
                hidden_size,
                mlp_activations_files,
                mlp_activations_counters
            )

        if args.attn_norm:
            process_attn_norms(
                layer_idx,
                attn_outputs[layer_idx],
                valid_token_indices,
                num_heads,
                attn_norms_files,
                attn_norms_counters
            )

    return total_samples, reached_max_samples, num_valid_tokens
| |
|
def finalize_data_collection(
    data_dir,
    num_layers,
    hidden_size,
    num_heads,
    max_samples,
    hidden_states_files,
    mlp_activations_files,
    attn_norms_files,
    hidden_states_counters,
    mlp_activations_counters,
    attn_norms_counters,
    args,
    num_neurons=None
):
    """
    Flush all mmap buffers to disk and write metadata.json.

    Args:
        data_dir (str): Directory where data is stored.
        num_layers (int): Number of transformer layers.
        hidden_size (int): Hidden size of the model.
        num_heads (int): Number of attention heads.
        max_samples (int): Row capacity of each buffer.
        hidden_states_files (list): Mmap files for hidden states.
        mlp_activations_files (list): Mmap files for MLP activations.
        attn_norms_files (list): Mmap files for attention norms.
        hidden_states_counters (list): Rows written per layer (hidden states).
        mlp_activations_counters (list): Rows written per layer (MLP acts).
        attn_norms_counters (list): Rows written per layer (attention norms).
        args: Namespace with `mlp_activation` / `attn_norm` flags controlling
            which buffer lists are populated.
        num_neurons (int, optional): MLP intermediate size. When provided it
            is recorded in metadata so loaders need not assume 4*hidden_size
            (which is wrong for LLaMA models). Optional for backward
            compatibility with existing callers.
    """
    for layer_idx in range(num_layers):
        hs_file = hidden_states_files[layer_idx]
        hs_file.flush()
        del hs_file  # drop the reference so the mapping can be released

        if args.mlp_activation:
            mlp_file = mlp_activations_files[layer_idx]
            mlp_file.flush()
            del mlp_file

        if args.attn_norm:
            attn_file = attn_norms_files[layer_idx]
            attn_file.flush()
            del attn_file

    # Persist the bookkeeping needed by load_layer_data().
    metadata = {
        'num_layers': num_layers,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'max_samples': max_samples,
        'hidden_states_counters': hidden_states_counters,
        'mlp_activations_counters': mlp_activations_counters,
        'attn_norms_counters': attn_norms_counters
    }
    if num_neurons is not None:
        metadata['num_neurons'] = num_neurons

    metadata_filename = os.path.join(data_dir, 'metadata.json')
    with open(metadata_filename, 'w') as f:
        json.dump(metadata, f)

    print("Finalization complete. Metadata saved.")
| |
|
def arg_parser():
    """
    Parse command-line arguments for activation-data collection.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    def _str2bool(value):
        # BUGFIX: the original used `type=bool`, which converts ANY non-empty
        # string (including "False") to True. Parse common spellings instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', '1', 'yes', 'y'):
            return True
        if lowered in ('false', '0', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
    parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--data_collection', type=_str2bool, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='cuda:0', help='Device to use for evaluation')
    parser.add_argument('--interactive', type=_str2bool, default=False, help='Interactive mode for model selection')
    parser.add_argument('--data_dir', type=str, default='<PATH_TO_DATA_DIR>', help='Directory to store generated data')
    parser.add_argument('--max_samples', type=int, default=5000, help='Maximum number of samples to collect')
    parser.add_argument('--model_family', type=str, default='opt', choices=["opt", "llama"], help='Model family to evaluate')
    parser.add_argument('--mlp_activation', type=_str2bool, default=False, help='Collect MLP activations')
    parser.add_argument('--attn_norm', type=_str2bool, default=True, help='Collect attention norms')

    return parser.parse_args()
| |
|
if __name__ =="__main__":
    # Entry point: collect router-training data (hidden states, optional MLP
    # activations and attention-output norms) from a causal LM over wikitext.
    args = arg_parser()
    # NOTE(review): model_index is 1-based on the CLI; MODELS is 0-indexed.
    model_name = MODELS[args.model_index-1]
    batch_size = args.batch_size
    max_length = args.max_length
    data_collection = args.data_collection  # parsed but not used below
    device_map = args.device_map

    # Load the (project-local) model variant; num_neurons is the MLP
    # intermediate width, named differently per family.
    if args.model_family == 'opt':
        model = OPTForCausalLM.from_pretrained(
            model_name, device_map=device_map, torch_dtype=torch.float16,
            attn_implementation="flash_attention_2", output_hidden_states=True, output_attentions=True,
            return_dict=True
        )
        num_neurons = model.config.ffn_dim

    elif args.model_family == 'llama':
        model = LlamaForCausalLM.from_pretrained(
            model_name, device_map=device_map, torch_dtype=torch.float16,
            attn_implementation="flash_attention_2", output_hidden_states=True, output_attentions=True,
            return_dict=True
        )
        num_neurons = model.config.intermediate_size
    # NOTE(review): no else branch — argparse `choices` restricts the family,
    # so num_neurons/model cannot be unbound here.

    data_loader = build_data_loader(
        model_name, "wikitext", "wikitext-2-raw-v1", batch_size, max_length, split='train'
    )
    # Pick the device inputs are moved to; fall back to CPU when CUDA is absent.
    if args.device_map == "auto":
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device_map if torch.cuda.is_available() else 'cpu')

    # Output directory: <data_dir>/<model>_act_data
    model_name_clean = extract_model_name(model_name)
    folder_name = f"{model_name_clean}_act_data"
    data_dir = os.path.join(args.data_dir, folder_name)

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    num_layers = model.config.num_hidden_layers
    hidden_size = model.config.hidden_size
    num_heads = model.config.num_attention_heads
    max_samples = args.max_samples

    print(f"Collecting data for model: {model_name}")
    print(f"Data directory: {data_dir}")
    print(f"Number of layers: {num_layers}")
    print(f"Hidden size: {hidden_size}")
    print(f"Number of heads: {num_heads}")
    print(f"Number of neurons: {num_neurons}")
    print(f"Max samples: {max_samples}")
    print(f"Collecting MLP activations: {args.mlp_activation}")
    print(f"Collecting attention norms: {args.attn_norm}")

    # Allocate per-layer memmap buffers and zeroed counters.
    (hidden_states_files, hidden_states_counters, mlp_activations_files,
     mlp_activations_counters, attn_norms_files, attn_norms_counters) = initialize_data_structures(data_dir, num_layers,
                                                                                                   hidden_size, num_heads, num_neurons, max_samples,
                                                                                                   mlp_activation=args.mlp_activation, attn_norm=args.attn_norm)

    total_samples = 0  # valid (non-padding) tokens accumulated so far

    with torch.no_grad():
        with tqdm(total=max_samples, desc="Router training data collection") as pbar:
            for batch in data_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                # NOTE(review): output_mlp_activation / output_attn_output /
                # output_router_inputs are custom kwargs of the hf_models
                # forks; the outputs dict is assumed to expose
                # 'router_inputs', 'mlp_activations', 'attn_outputs' keys —
                # confirm against the forked modeling code.
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True,
                                output_attentions=False, return_dict=True, output_mlp_activation=args.mlp_activation,
                                output_attn_output=args.attn_norm, output_router_inputs=True)

                total_samples, reached_max_samples, num_valid_tokens = process_batch(
                    outputs, input_ids, attention_mask,
                    total_samples, max_samples, num_layers,
                    hidden_size, num_heads, hidden_states_files,
                    hidden_states_counters, mlp_activations_files, mlp_activations_counters,
                    attn_norms_files, attn_norms_counters,
                    args=args)

                # NOTE(review): on the final (truncated) batch this advances
                # by the untruncated count, so the bar can slightly overshoot.
                pbar.update(num_valid_tokens)

                if reached_max_samples:
                    break

    # Flush buffers and persist metadata.json for load_layer_data().
    # NOTE(review): if data_loader yields no batches, reached_max_samples
    # below is unbound — assumed non-empty wikitext loader.
    finalize_data_collection(
        data_dir, num_layers, hidden_size,
        num_heads, max_samples, hidden_states_files,
        mlp_activations_files, attn_norms_files, hidden_states_counters,
        mlp_activations_counters, attn_norms_counters,
        args
    )

    if reached_max_samples:
        print(f"Reached maximum number of samples. total_samples = {total_samples}")
    print(f"Data collection complete. Data saved to {data_dir}")