# PolarSparsity / HybridTensor / benchmarks / model_perplexity.py
# Uploaded via huggingface_hub by Susav (commit b3a3b15, verified).
# python -m HybridTensor.benchmarks.model_perplexity --model_index 14 --batch_size 4 --max_length 512 --attn_th 1 --static_thresholds True
import sys
import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopkAttn
from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn
from HybridTensor.utils.activations import ActivationThresholds, identify_model_type, MODELS, CONFIGS
from HybridTensor.utils.utils import extract_model_name, compute_perplexity
import argparse
from datasets import load_dataset
import json
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import (_update_model_attn_thresholds,
build_data_loader,
compute_sparse_perplexity,
compute_perplexity_data_collection,
display_model_menu,
_interactive_mode,
arg_parser,
)
results_dir = "results/activations"
def compute_attn_layer_sparsity(model_name, min_th, critical_th, attn_sparsity):
    """Compute a per-layer attention activation-fraction ("sparsity") map.

    Layers whose importance score is >= ``critical_th`` are kept fully dense
    (fraction 1.0). The remaining sparse layers share an activation budget of
    ``attn_sparsity * num_sparse_layers``, distributed proportionally to each
    layer's importance and clamped to ``[min_th, 1.0]``.

    Args:
        model_name: Key into ``CONFIGS``; the config must provide
            'num_layer' and 'layer_imp' (path to a JSON file containing an
            'importance_scores' list).
        min_th: Lower clamp for any sparse layer's activation fraction.
        critical_th: Importance score at or above which a layer stays dense.
        attn_sparsity: Target average activation fraction for sparse layers.

    Returns:
        dict mapping layer index -> activation fraction in (0, 1].
    """
    model_config = CONFIGS[model_name]
    num_layers = model_config['num_layer']

    # Load per-layer importance scores from the path given in the config.
    with open(model_config['layer_imp'], 'r') as f:
        attn_layer_imp = json.load(f)
    layer_importance = attn_layer_imp['importance_scores']

    # Partition layers: critical as a set for O(1) membership checks below.
    critical_layers = {i for i, imp in enumerate(layer_importance) if imp >= critical_th}
    sparse_layers = [i for i, imp in enumerate(layer_importance) if imp < critical_th]

    # Total importance mass of sparse layers, and the budget to split among them.
    sum_sparse_importance = sum(layer_importance[i] for i in sparse_layers)
    attn_val = attn_sparsity * len(sparse_layers)

    layer_sparsity_map = {}
    for layer_idx in range(num_layers):
        if layer_idx in critical_layers:
            layer_sparsity_map[layer_idx] = 1.0  # critical layers stay fully dense
        else:
            if sum_sparse_importance > 0:
                raw_fraction = (layer_importance[layer_idx] / sum_sparse_importance) * attn_val
            else:
                # Degenerate case: no importance mass — fall back to a uniform budget.
                raw_fraction = attn_sparsity
            # Clamp the fraction to [min_th, 1.0].
            layer_sparsity_map[layer_idx] = min(max(raw_fraction, min_th), 1.0)
    return layer_sparsity_map
def compute_average_activation(layer_sparsity_map):
    """Return the mean activation fraction across all layers in the map.

    Args:
        layer_sparsity_map: dict mapping layer index -> activation fraction.

    Returns:
        Arithmetic mean of the fractions.
    """
    return sum(layer_sparsity_map.values()) / len(layer_sparsity_map)
def compute_sparse_perplexity(model_name='facebook/opt-125m',
                              dataset_name='wikitext',
                              dataset_config='wikitext-2-raw-v1',
                              batch_size=8,
                              max_length=512,
                              attn_th=0.0,
                              static_thresholds=True,
                              device_map="cuda:0"):
    """Load a sparse-attention model and compute its perplexity on a dataset.

    Args:
        model_name: Model identifier; must be a key in ``CONFIGS``.
        dataset_name: HF datasets name passed to ``build_data_loader``.
        dataset_config: Dataset configuration/subset name.
        batch_size: Evaluation batch size.
        max_length: Maximum sequence length.
        attn_th: Uniform attention activation threshold.
        static_thresholds: If True, use the uniform ``attn_th`` for every
            layer; if False, derive per-layer thresholds from layer
            importance scores via ``compute_attn_layer_sparsity``.
        device_map: Device map forwarded to ``from_pretrained``.

    Returns:
        Perplexity value from ``compute_perplexity``.

    Raises:
        ValueError: If the model type is neither 'OPT' nor 'Llama'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Initialize activation thresholds (uniform attn_th for all layers).
    num_layers = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th)
    print(f"Static attention activations: {sp_thresholds.activation_threshold}")

    if not static_thresholds:
        # Replace uniform thresholds with importance-weighted per-layer values.
        attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th)
        sp_thresholds.load_thresholds(attn_sparsity_map)
        average_act = compute_average_activation(attn_sparsity_map)
        print(f"Layer importance weights attention activations {sp_thresholds.activation_threshold}")
        print(f"Average activation: {average_act:.2f}")

    # Load the appropriate sparse model variant for the architecture.
    model_type = identify_model_type(model_name)
    if model_type == 'OPT':
        print(f"Loading OPT model: {model_name}")
        model = SparseOPTTopkAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    elif model_type == 'Llama':
        print(f"Loading Llama model: {model_name}")
        model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    else:
        # Previously `model` was left unbound here, producing a confusing
        # NameError at model.eval(); fail fast with a clear message instead.
        raise ValueError(f"Unsupported model type '{model_type}' for model '{model_name}'")
    model.eval()

    dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)
    perplexity = compute_perplexity(model, dataloader, device)
    return perplexity
def _str2bool(value):
    """Parse a CLI boolean string: accepts true/false, 1/0, yes/no, t/f, y/n."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes', 'y', 't'):
        return True
    if lowered in ('false', '0', 'no', 'n', 'f'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def arg_parser():
    """Parse command-line arguments for the sparse perplexity evaluation.

    Note: boolean flags use `_str2bool` rather than `type=bool`, because
    `bool("False")` is True — with `type=bool`, `--static_thresholds False`
    would silently enable static thresholds. `_str2bool` keeps the documented
    `--flag True/False` CLI shape while parsing values correctly.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
    parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers')
    parser.add_argument('--data_collection', type=_str2bool, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='auto', help='Device to use for evaluation')
    parser.add_argument('--interactive', type=_str2bool, default=False, help='Interactive mode for model selection')
    parser.add_argument('--static_thresholds', type=_str2bool, default=False, help='Use static thresholds for attention layers')
    return parser.parse_args()
def main():
    """Entry point: parse CLI options and run the perplexity evaluation.

    Either sweeps activation thresholds (data-collection mode) or computes a
    single sparse perplexity for the selected model.
    """
    print("=== OPT Models Perplexity Evaluation ===\n")
    args = arg_parser()

    if args.interactive:
        (selected_model, batch_size, max_length,
         data_collection, device_map, attn_th) = _interactive_mode()
    else:
        selected_model = MODELS[args.model_index - 1]
        batch_size = args.batch_size
        max_length = args.max_length
        data_collection = args.data_collection
        device_map = args.device_map
        attn_th = args.attn_th

    print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}")

    if data_collection:
        print("\nStarting data collection...\n")
        compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map)
        print("\nData collection complete.\n")
    else:
        print("\nStarting perplexity computation...\n")
        perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length,
                                               attn_th=attn_th,
                                               device_map=device_map,
                                               static_thresholds=args.static_thresholds)
        print(f"\n=== Perplexity Results ===")
        print(f"Model: {selected_model}")
        print(f"Perplexity: {perplexity:.2f}\n")
if __name__ == "__main__":
main()