File size: 8,341 Bytes
b3a3b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# python -m HybridTensor.benchmarks.model_perplexity --model_index 14 --batch_size 4 --max_length 512 --attn_th 1 --static_thresholds True

import sys
import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from hf_models.opt.modeling_sparse_opt_topk import SparseOPTForCausalLM as SparseOPTTopkAttn
from hf_models.llama.modeling_sparse_llama_mha_topk import SparseLlamaForCausalLM as SparseLlamaTopKAttn
from HybridTensor.utils.activations import ActivationThresholds, identify_model_type, MODELS, CONFIGS
from HybridTensor.utils.utils import extract_model_name, compute_perplexity
import argparse
from datasets import load_dataset
import json
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd


from HybridTensor.benchmarks.opt_attn_sparse_topk_perplexity import (_update_model_attn_thresholds,
                                                                    build_data_loader,
                                                                    compute_sparse_perplexity,
                                                                    compute_perplexity_data_collection,
                                                                    display_model_menu,
                                                                    _interactive_mode,
                                                                    arg_parser,
                                                                     )


# Output directory for activation/perplexity artifacts (relative to the repo root).
# NOTE(review): defined here but not referenced in this file — presumably used by
# the sibling benchmark modules; confirm before removing.
results_dir = "results/activations"

def compute_attn_layer_sparsity(model_name, min_th, critical_th, attn_sparsity):
    """Build a per-layer attention sparsity (activation-fraction) map.

    Layers whose importance score is >= ``critical_th`` are kept fully dense
    (fraction 1.0). The remaining "sparse" layers share a total activation
    budget of ``attn_sparsity * num_sparse_layers``, distributed proportionally
    to each layer's importance and clamped to ``[min_th, 1.0]``.

    Args:
        model_name: key into ``CONFIGS`` (provides layer count and the path to
            the importance-score JSON file).
        min_th: lower clamp on any sparse layer's activation fraction.
        critical_th: importance score at/above which a layer stays dense.
        attn_sparsity: target average activation fraction for sparse layers.

    Returns:
        dict mapping layer index -> activation fraction.
    """
    cfg = CONFIGS[model_name]
    n_layers = cfg['num_layer']

    # Importance scores are precomputed and stored alongside the model config.
    with open(cfg['layer_imp'], 'r') as fh:
        importance = json.load(fh)['importance_scores']

    # Partition layers: dense ("critical") vs. budget-sharing ("sparse").
    critical = {idx for idx, score in enumerate(importance) if score >= critical_th}
    sparse = [idx for idx, score in enumerate(importance) if score < critical_th]

    sparse_total = sum(importance[idx] for idx in sparse)
    budget = attn_sparsity * len(sparse)

    sparsity_map = {}
    for idx in range(n_layers):
        if idx in critical:
            sparsity_map[idx] = 1.0  # critical layers stay fully dense
            continue
        if sparse_total > 0:
            raw = (importance[idx] / sparse_total) * budget
        else:
            # Degenerate case: no importance mass among sparse layers —
            # fall back to the flat target fraction.
            raw = attn_sparsity
        # Clamp into [min_th, 1.0].
        sparsity_map[idx] = min(max(raw, min_th), 1.0)

    return sparsity_map

def compute_average_activation(layer_sparsity_map):
    """Return the mean activation fraction over all layers in the map.

    Args:
        layer_sparsity_map: dict mapping layer index -> activation fraction
            (as produced by ``compute_attn_layer_sparsity``).

    Returns:
        float: the arithmetic mean of the fractions, or 0.0 for an empty map
        (the original implementation raised ZeroDivisionError in that case).
    """
    if not layer_sparsity_map:
        return 0.0
    return sum(layer_sparsity_map.values()) / len(layer_sparsity_map)

def compute_sparse_perplexity(model_name='facebook/opt-125m',
                            dataset_name='wikitext',
                            dataset_config='wikitext-2-raw-v1',
                            batch_size=8,
                            max_length=512,
                            attn_th=0.0,
                            static_thresholds=True,
                            device_map="cuda:0"):
    """Evaluate a sparse top-k attention model's perplexity on a dataset.

    NOTE(review): this definition intentionally shadows the
    ``compute_sparse_perplexity`` imported from
    ``opt_attn_sparse_topk_perplexity`` above — confirm that is deliberate.

    Args:
        model_name: HF model id; must be a key in ``CONFIGS``.
        dataset_name / dataset_config: dataset spec passed to the data loader.
        batch_size: evaluation batch size.
        max_length: maximum token sequence length.
        attn_th: attention activation threshold (0.0 = dense).
        static_thresholds: if False, redistribute ``attn_th`` per layer
            according to layer-importance scores.
        device_map: device placement passed to ``from_pretrained``.

    Returns:
        The perplexity value computed by ``compute_perplexity``.

    Raises:
        ValueError: if ``identify_model_type`` returns an unsupported type
            (previously this fell through to a NameError at ``model.eval()``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Build per-layer activation thresholds (uniform attn_th by default).
    num_layers = CONFIGS[model_name]['num_layer']
    sp_thresholds = ActivationThresholds(num_layers=num_layers, attn_th=attn_th)

    print(f"Static attention activations: {sp_thresholds.activation_threshold}")
    if not static_thresholds:
        # Redistribute the threshold budget across layers by importance.
        attn_sparsity_map = compute_attn_layer_sparsity(model_name=model_name, min_th=0.2, critical_th=0.3, attn_sparsity=attn_th)
        sp_thresholds.load_thresholds(attn_sparsity_map)
        average_act = compute_average_activation(attn_sparsity_map)
        print(f"Layer importance weights attention activations {sp_thresholds.activation_threshold}")
        print(f"Average activation: {average_act:.2f}")

    # Load the matching sparse model variant in fp16 with FlashAttention-2.
    model_type = identify_model_type(model_name)
    if model_type == 'OPT':
        print(f"Loading OPT model: {model_name}")
        model = SparseOPTTopkAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    elif model_type == 'Llama':
        print(f"Loading Llama model: {model_name}")
        model = SparseLlamaTopKAttn.from_pretrained(model_name, device_map = device_map, torch_dtype=torch.float16, sp_thresholds = sp_thresholds.activation_threshold, attn_implementation="flash_attention_2")
    else:
        # Fail fast instead of hitting a NameError on the undefined `model`.
        raise ValueError(f"Unsupported model type '{model_type}' for model '{model_name}'; expected 'OPT' or 'Llama'")
    model.eval()

    # Build the dataset loader and run the perplexity computation.
    dataloader = build_data_loader(model_name, dataset_name, dataset_config, batch_size, max_length)
    perplexity = compute_perplexity(model, dataloader, device)
    return perplexity


def _str2bool(value):
    """Parse a boolean CLI value ('True'/'False', case-insensitive, plus common aliases)."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got '{value}'")


def arg_parser():
    """Parse CLI arguments for the sparse perplexity evaluation.

    Boolean flags use ``_str2bool`` rather than ``type=bool``: argparse with
    ``type=bool`` treats any non-empty string as True, so the documented
    invocation ``--static_thresholds False`` (or ``--data_collection False``)
    would silently enable the flag. ``--flag True`` keeps working unchanged.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Sparse Perplexity Evaluation')
    parser.add_argument('--model_index', type=int, default=5, help='Index of the model to evaluate')
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--max_length', type=int, default=512, help='Maximum sequence length')
    parser.add_argument('--attn_th', type=float, default=0.0, help='Activation threshold for attention layers')
    parser.add_argument('--data_collection', type=_str2bool, default=False, help='Collect data for different activation thresholds')
    parser.add_argument('--device_map', type=str, default='auto', help='Device to use for evaluation')
    parser.add_argument('--interactive', type=_str2bool, default=False, help='Interactive mode for model selection')
    parser.add_argument('--static_thresholds', type=_str2bool, default=False, help='Use static thresholds for attention layers')

    return parser.parse_args()

def main():
    """Entry point: parse CLI options (or prompt interactively), then run either
    threshold data collection or a single perplexity evaluation."""
    print("=== OPT Models Perplexity Evaluation ===\n")
    args = arg_parser()

    if args.interactive:
        (selected_model, batch_size, max_length,
         data_collection, device_map, attn_th) = _interactive_mode()
    else:
        # Non-interactive: take everything from the parsed CLI arguments.
        selected_model = MODELS[args.model_index - 1]
        batch_size = args.batch_size
        max_length = args.max_length
        data_collection = args.data_collection
        device_map = args.device_map
        attn_th = args.attn_th
        print(f"Selected model: {selected_model}, batch size: {batch_size}, max length: {max_length}, attn_th: {attn_th}, data_collection: {data_collection}, device: {device_map}")

    if data_collection:
        # Sweep mode: gather perplexity data across activation thresholds.
        print("\nStarting data collection...\n")
        compute_perplexity_data_collection(model_name=selected_model, batch_size=batch_size, max_length=max_length, device_map=device_map)
        print("\nData collection complete.\n")
    else:
        # Single run at the requested threshold.
        print("\nStarting perplexity computation...\n")
        perplexity = compute_sparse_perplexity(model_name=selected_model, batch_size=batch_size, max_length=max_length,
                                               attn_th=attn_th,
                                               device_map=device_map,
                                               static_thresholds=args.static_thresholds)
        print(f"\n=== Perplexity Results ===")
        print(f"Model: {selected_model}")
        print(f"Perplexity: {perplexity:.2f}\n")

if __name__ == "__main__":
    main()