File size: 5,550 Bytes
d73500e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional, Literal

# Supported strategies for each pruning family, referenced by the
# `choices` metadata of the corresponding PruningArguments fields.
EXPERT_DROP_METHODS = (
    'global_pruning',
    'layerwise_pruning',
    'progressive_pruning',
    'dynamic_skipping',
    'post_dropping',
)
LAYER_DROP_METHODS = (
    'consecutive',
    'discrete',
    'post_dropping',
    'super_weight_guided',
)
BLOCK_DROP_METHODS = (
    'consecutive',
    'discrete',
    'post_dropping',
)
SUPER_WEIGHT_METHODS = (
    'analysis',
    'pruning',
)


@dataclass
class PruningArguments:
    r"""
    Arguments pertaining to specify the decoding parameters.
    """
    prune_seed: Optional[int] = field(
        default=42,
        metadata={"help": "Seed for sampling the calibration data."},
    )
    prune_method: Optional[str] = field(
        default="wanda",
        metadata={"choices": ["wanda", "sparsegpt", "gradient-first", "gradient-zeroth", "magnitude", "remap_gate", "decompose_moe", "expert_drop", "block_drop", "layer_drop", "super_weight"]},
    )
    prune_model_save_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save the pruned model."},
    )
    n_calibration_samples: Optional[int] = field(
        default=128,
        metadata={"help": "Number of calibration samples."},
    )
    prune_data_type: Literal["pt", "sft", "rm", "ppo"] = field(
        default="sft",
        metadata={"choices": ["pt", "sft", "rm", "ppo"],
                  "help": "Path to save the pruned model."},
    )

    # πŸ” For pruning
    sparsity_ratio: Optional[float] = field(  # this term denotes the "parameter_ratio" for decomposition
        default=0.5,
        metadata={"help": "Sparsity Level."},
    )
    sparsity_type: Optional[Literal["structured", "unstructured", "4:8", "2:4"]] = field(
        default="unstructured",
        metadata={"choices": ["structured", "unstructured", "4:8", "2:4"]},
    )
    use_variant: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use the variant for Wanda."},
    )

    # πŸ” For decomposition
    level: Optional[str] = field(
        default="expert",
        metadata={"choices": ["expert", "layer", "model"]},
    )
    has_sparse: Optional[bool] = field(
        default=True,
    )
    do_permute: Optional[bool] = field(
        default=True,
    )
    use_svd: Optional[bool] = field(
        default=True,
    )
    top_scores: Optional[bool] = field(
        default=True,
    )

    # πŸ” For expert drop
    expert_drop_method: Optional[str] = field(
        default="layerwise_pruning",
        metadata={"help": ' '.join(['Supported dropping methods:'] + list(EXPERT_DROP_METHODS)),
                  "choices": EXPERT_DROP_METHODS},
    )
    r: Optional[int] = field(
        default=4,
        metadata={"help": 'Number of experts to preserve'}
    )

    # πŸ” For layer drop & block drop
    layer_drop_method: Optional[str] = field(
        default="discrete",
        metadata={"help": ' '.join(['Supported dropping methods:'] + list(LAYER_DROP_METHODS)),
                  "choices": LAYER_DROP_METHODS},
    )
    block_drop_method: Optional[str] = field(
        default="discrete",
        metadata={"help": ' '.join(['Supported dropping methods:'] + list(BLOCK_DROP_METHODS)),
                  "choices": BLOCK_DROP_METHODS},
    )
    drop_n: Optional[int] = field(
        default=4,
        metadata={"help": 'Number of blocks to drop'}
    )
    layer_drop_norm: Optional[bool] = field(
        default=True,
        metadata={"help": 'determine whether to consider norm when calculating similarity. If True, use the hidden states before norm to calculate similarity.'}
    )
    target_layer: Optional[str] = field(
        default=None,
        metadata={"help": 'determine which type of layer is dropped when layer_drop. ',
            "choices": ["mlp", "attn", "all"]},
    )
    only_update_config: Optional[bool] = field(
        default=False,
        metadata={"help": 'Only output the config file without saving model weights. '}
    )
    similarity_cache_file: Optional[str] = field(
        default=None,
        metadata={"help": 'Cached file storing the similarity scores across layers to reduce the computation consumption. '
                          'If the file does not exist, it will be created.'},
    )

    # πŸ” For Super Weight
    super_weight_method: Optional[str] = field(
        default="analysis",
        metadata={"help": ' '.join(['Supported super weight methods:'] + list(SUPER_WEIGHT_METHODS)),
                  "choices": SUPER_WEIGHT_METHODS},
    )
    super_weight_threshold: Optional[float] = field(
        default=3.0,
        metadata={"help": 'Threshold for detecting super weights based on activation magnitude.'},
    )
    super_weight_cache_file: Optional[str] = field(
        default=None,
        metadata={"help": 'Cached file storing the super weight detection results. '
                          'If the file does not exist, it will be created.'},
    )
    prune_super_weight_n: Optional[int] = field(
        default=0,
        metadata={"help": 'Number of super weights to prune. If 0, only detection is performed.'},
    )

    # πŸ” For gate-remapping
    pruned_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the pruned model. (Only for Gate-Remapping)"},
    )

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        if args.get("max_new_tokens", -1) > 0:
            args.pop("max_length", None)
        else:
            args.pop("max_new_tokens", None)
        return args