# NOTE(review): removed non-Python web-view extraction artifacts here
# (file-size banner, commit hash, line-number gutter) — they are not source
# code and would make the module unparseable.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional, Literal
# Supported strategies for each structured-compression mode. These tuples are
# referenced both as argparse `choices` and to build the help strings below.
EXPERT_DROP_METHODS = ('global_pruning', 'layerwise_pruning', 'progressive_pruning', 'dynamic_skipping', 'post_dropping')
LAYER_DROP_METHODS = ('consecutive', 'discrete', 'post_dropping')
BLOCK_DROP_METHODS = ('consecutive', 'discrete', 'post_dropping')


@dataclass
class PruningArguments:
    r"""
    Arguments controlling model pruning / compression: sparsification,
    decomposition, expert dropping, layer/block dropping, linearization,
    and gate remapping.
    """
    # Seed used when sampling calibration data, for reproducibility.
    prune_seed: Optional[int] = field(
        default=42,
        metadata={"help": "Seed for sampling the calibration data."},
    )
    # Which pruning / compression algorithm to run.
    prune_method: Optional[str] = field(
        default="wanda",
        metadata={"help": "Pruning / compression method to apply.",
                  "choices": ["wanda", "sparsegpt", "gradient-first", "gradient-zeroth", "magnitude", "remap_gate", "decompose_moe", "expert_drop", "block_drop", "layer_drop", "nbl_linearize"]},
    )
    prune_model_save_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save the pruned model."},
    )
    n_calibration_samples: Optional[int] = field(
        default=128,
        metadata={"help": "Number of calibration samples."},
    )
    # Format of the calibration dataset (pretrain / SFT / reward / PPO style).
    prune_data_type: Literal["pt", "sft", "rm", "ppo"] = field(
        default="sft",
        metadata={"choices": ["pt", "sft", "rm", "ppo"],
                  "help": "Type of the calibration data."},
    )

    # --- Sparsity pruning (wanda / sparsegpt / magnitude / gradient-*) ---
    sparsity_ratio: Optional[float] = field(  # doubles as the "parameter_ratio" for decomposition
        default=0.5,
        metadata={"help": "Sparsity Level."},
    )
    sparsity_type: Optional[Literal["structured", "unstructured", "4:8", "2:4"]] = field(
        default="unstructured",
        metadata={"help": "Sparsity pattern (structured, unstructured, or N:M semi-structured).",
                  "choices": ["structured", "unstructured", "4:8", "2:4"]},
    )
    use_variant: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use the variant for Wanda."},
    )

    # --- MoE decomposition (decompose_moe) ---
    # Granularity at which the decomposition is applied.
    level: Optional[str] = field(
        default="expert",
        metadata={"help": "Granularity of the decomposition.",
                  "choices": ["expert", "layer", "model"]},
    )
    has_sparse: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to keep a sparse component alongside the decomposition."},
    )
    do_permute: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to permute weights before decomposition."},
    )
    use_svd: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use SVD for the decomposition."},
    )
    top_scores: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to keep only the top-scoring components."},
    )

    # --- Expert drop (expert_drop) ---
    expert_drop_method: Optional[str] = field(
        default="layerwise_pruning",
        metadata={"help": ' '.join(['Supported dropping methods:'] + list(EXPERT_DROP_METHODS)),
                  "choices": EXPERT_DROP_METHODS},
    )
    # Number of experts retained per MoE layer after dropping.
    r: Optional[int] = field(
        default=4,
        metadata={"help": 'Number of experts to preserve'}
    )

    # --- Layer drop & block drop (layer_drop / block_drop) ---
    layer_drop_method: Optional[str] = field(
        default="discrete",
        metadata={"help": ' '.join(['Supported dropping methods:'] + list(LAYER_DROP_METHODS)),
                  "choices": LAYER_DROP_METHODS},
    )
    block_drop_method: Optional[str] = field(
        default="discrete",
        metadata={"help": ' '.join(['Supported dropping methods:'] + list(BLOCK_DROP_METHODS)),
                  "choices": BLOCK_DROP_METHODS},
    )
    drop_n: Optional[int] = field(
        default=4,
        metadata={"help": 'Number of blocks to drop'}
    )
    layer_drop_norm: Optional[bool] = field(
        default=True,
        metadata={"help": 'determine whether to consider norm when calculating similarity. If True, use the hidden states before norm to calculate similarity.'}
    )
    target_layer: Optional[str] = field(
        default=None,
        metadata={"help": 'determine which type of layer is dropped when layer_drop. ',
                  "choices": ["mlp", "attn", "all"]},
    )
    only_update_config: Optional[bool] = field(
        default=False,
        metadata={"help": 'Only output the config file without saving model weights. '}
    )
    similarity_cache_file: Optional[str] = field(
        default=None,
        metadata={"help": 'Cached file storing the similarity scores across layers to reduce the computation consumption. '
                          'If the file does not exist, it will be created.'},
    )

    # --- NBL linearization (nbl_linearize) ---
    num_layers_to_linearize: Optional[int] = field(
        default=4,
        metadata={"help": "Number of attention layers to linearize."},
    )
    nbl_metric_cache_file: Optional[str] = field(
        default=None,
        metadata={"help": "Cached file for NBL metrics (NMSE scores)."},
    )

    # --- Gate remapping (remap_gate) ---
    pruned_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the pruned model. (Only for Gate-Remapping)"},
    )

    def to_dict(self) -> Dict[str, Any]:
        """Return all arguments as a plain ``dict``.

        NOTE(review): an earlier version popped ``max_new_tokens`` /
        ``max_length`` — logic copied from a generation-arguments class.
        Neither key is a field of this dataclass, so both pops were dead
        no-ops; removing them does not change the returned dict.
        """
        return asdict(self)