import os
from copy import deepcopy
from typing import TYPE_CHECKING, List, Optional

from accelerate import Accelerator
from accelerate.state import AcceleratorState
from torch.utils.data import DataLoader

from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.model import load_model_and_tokenizer

from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
from .io import load_json, save_sparse_model, save_block_dropped_config, save_layer_dropped_config
from .block_drop import consecutive_block_dropping, discrete_block_dropping, post_block_drop
from .layer_drop import discrete_layer_dropping, post_layers_drop
from .nbl_linearize import apply_nbl_linearization
import torch

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from llmtuner.hparams import DataArguments, FinetuningArguments, ModelArguments, PruningArguments

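# πŸ” Dispatch tables mapping `pruning_args.layer_drop_method` / `pruning_args.block_drop_method`
# to their implementations. The "post_dropping" mode is handled separately inside `run_prune`.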
LAYER_DROP_METHODS_FUNC = {
    'discrete': discrete_layer_dropping,
}

BLOCK_DROP_METHODS_FUNC = {
    'consecutive': consecutive_block_dropping,
    'discrete': discrete_block_dropping,
}


# πŸ” Modified from src.llmtuner.compression.pt.workflow.run_pt
def run_prune(
        model_args: "ModelArguments",
        data_args: "DataArguments",
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        pruning_args: "PruningArguments",  # πŸ” for pruning
        callbacks: Optional[List["TrainerCallback"]] = None,
):
    """Workflow for pruning and decomposing."""
    # πŸ” accelerator
    accelerator = Accelerator()
    accelerator.print(f"{AcceleratorState()}")
    accelerator.print("Pruning Args:", pruning_args)
    accelerator.print("Model Args:", model_args)

    # πŸ” model & tokenizer
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
    
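    # πŸ” "post_dropping" applies a layer/block selection computed in a previous run: it reads
    # `reserved_layers.json` from the save path, writes the physically pruned model (or just
    # its updated config), and exits before any calibration data is loaded below.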
    if pruning_args.prune_method == "layer_drop" and pruning_args.layer_drop_method == "post_dropping":
        assert os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "false" and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "false"  # post-dropping requires a plain (non-DeepSpeed, non-FSDP) model
        reserved_layer_list = load_json(os.path.join(pruning_args.prune_model_save_path, "reserved_layers.json"))
        post_layers_drop(pruning_args.prune_model_save_path, pruning_args.target_layer, model, tokenizer, reserved_layer_list, accelerator, pruning_args.only_update_config)
        exit()
    if pruning_args.prune_method == "block_drop" and pruning_args.block_drop_method == "post_dropping":
        assert os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "false" and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "false"
        reserved_layer_list = load_json(os.path.join(pruning_args.prune_model_save_path, "reserved_layers.json"))
        post_block_drop(pruning_args.prune_model_save_path, model, tokenizer, reserved_layer_list, accelerator, pruning_args.only_update_config)
        exit()

    # πŸ” dataset & data collator & dataloader
    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage=pruning_args.prune_data_type)

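    # πŸ” Choose the collator by calibration data type: "pt" uses plain causal-LM batches,
    # while "sft" pads prompt/response pairs and masks padding labels with IGNORE_INDEX
    # when `ignore_pad_token_for_loss` is set.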
    if pruning_args.prune_data_type == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # concat all data to seq_length for each batch
    elif pruning_args.prune_data_type == "sft":
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
            label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        )
    else:
        raise NotImplementedError
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=data_collator, num_workers=8)  # batch size must be 1
    accelerator.print("Total Sample Num:", len(dataset))
    accelerator.print("Total Used Sample Num:", pruning_args.n_calibration_samples)
    accelerator.print("Max sequence Length:", data_args.cutoff_len)
    accelerator.print(f"Example Data (len = {len(dataset[0]['input_ids'])}):", dataset[0])

    if pruning_args.n_calibration_samples > len(dataset):
        raise ValueError("Number of calibration samples is greater than the number of samples in the dataset!")

    # πŸ” Prepare model & dataloader
    print("Preparing model...")
    model, dataloader = accelerator.prepare(model, dataloader)

    # πŸ” Distribute samples to each device for acceleration
    assert pruning_args.n_calibration_samples % accelerator.num_processes == 0  # calibration samples must split evenly across devices
    num_samples_each_device = pruning_args.n_calibration_samples // accelerator.num_processes
    accelerator.print("Number of samples per device:", len(dataloader))
    accelerator.print("Number of used samples per device:", num_samples_each_device)

    #######################################################################################################
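    # πŸ” Run the selected method: NBL linearization computes linear replacement weights for the
    # chosen layers, while layer/block dropping select which layers to remove and return the
    # dropped list for saving below.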
    if pruning_args.prune_method == "nbl_linearize":
        linearization_data = apply_nbl_linearization(
            model,
            dataloader,
            accelerator,
            num_samples_each_device,
            pruning_args.num_layers_to_linearize,
            pruning_args.nbl_metric_cache_file
        )
        if accelerator.is_main_process:
            # The actual layer replacement happens at model loading time.
            # Here we just need to save the config and the weights.
            unwrapped_model = accelerator.unwrap_model(model)
            config = unwrapped_model.config
            
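            # `linearization_data` is assumed to map each linearized layer index to its computed
            # linear replacement weights; the keys are recorded in the config and the tensors
            # are saved to disk below.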
            # Update config
            linearized_layers = list(linearization_data.keys())
            config.nbl_attn_list = linearized_layers
            
            # Save the model with updated config, and also save linearization weights
            if pruning_args.prune_model_save_path:
                os.makedirs(pruning_args.prune_model_save_path, exist_ok=True)
                
                # Save config
                config.save_pretrained(pruning_args.prune_model_save_path)
                
                # Save model weights. The linearized layers are not part of the state_dict yet.
                unwrapped_model.save_pretrained(pruning_args.prune_model_save_path)
                tokenizer.save_pretrained(pruning_args.prune_model_save_path)

                # Save the computed linear weights
                linear_weights_path = os.path.join(pruning_args.prune_model_save_path, "nbl_linear_weights.pt")
                torch.save(linearization_data, linear_weights_path)
                accelerator.print(f"Saved NBL config and weights to {pruning_args.prune_model_save_path}")

    elif pruning_args.prune_method == "layer_drop":
        dropped_layer_list = LAYER_DROP_METHODS_FUNC[pruning_args.layer_drop_method](pruning_args, model, dataloader, accelerator, num_samples_each_device)
    elif pruning_args.prune_method == "block_drop":
        dropped_layer_list = BLOCK_DROP_METHODS_FUNC[pruning_args.block_drop_method](pruning_args, model, dataloader, accelerator, num_samples_each_device)
    else:
        raise NotImplementedError
    #######################################################################################################
    accelerator.print(f"model: {model}")
    if pruning_args.prune_model_save_path is not None:
        if pruning_args.prune_method == "layer_drop":
            save_layer_dropped_config(pruning_args.target_layer, pruning_args.prune_model_save_path, model, tokenizer, accelerator, dropped_layer_list=dropped_layer_list)
        elif pruning_args.prune_method == "block_drop":
            save_block_dropped_config(pruning_args.prune_model_save_path, model, tokenizer, accelerator, dropped_layer_list=dropped_layer_list)
        elif pruning_args.prune_method == "nbl_linearize":
            pass  # Saving is handled inside the nbl_linearize block above
        else:
            # πŸ” Unreachable: unsupported prune methods already raise NotImplementedError above,
            # and this workflow produces no sparse `update_state_dict` to pass to `save_sparse_model`.
            raise NotImplementedError

    accelerator.print("All done!")