# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import List, Literal, Optional

from transformers.utils import strtobool

from swift.llm import get_model_arch
from swift.utils import get_logger

logger = get_logger()


@dataclass
class TunerArguments:
"""
TunerArguments is a dataclass that holds configuration for various tuners.
Args:
target_modules (List[str]): List of target modules for tuning. Default is ['all-linear'].
target_regex (Optional[str]): Regular expression to match target modules. Default is None.
modules_to_save (List[str]): List of modules to save. Default is an empty list.
lora_rank (int): Rank for LoRA. Default is 8.
lora_alpha (int): Alpha value for LoRA. Default is 32.
lora_dropout (float): Dropout rate for LoRA. Default is 0.05.
        lora_bias (Literal): Bias type for LoRA. Default is 'none'. Allowed values are 'none', 'all'.
        lora_dtype (Literal): Data type for LoRA. Default is None.
            Allowed values are 'float16', 'bfloat16', 'float32', None.
        lorap_lr_ratio (Optional[float]): Learning rate ratio for LoRA+. Default is None.
        use_rslora (bool): Flag to indicate if RS-LoRA is used. Default is False.
        use_dora (bool): Flag to indicate if DoRA is used. Default is False.
init_weights (str): Initialization method for weights of supported tuners. Default is 'true'.
lora_ga_batch_size (int): Batch size used for estimating gradients during initialization in LoRA-GA.
Default value is 2.
lora_ga_iters (int): Number of iterations for estimating gradients during initialization in LoRA-GA.
Default value is 2.
lora_ga_max_length (int): Maximum input length for estimating gradients during initialization in LoRA-GA.
Default value is 1024.
lora_ga_direction (str): Initial direction used for gradient estimation during initialization in LoRA-GA.
Default value is `ArB2r`. Allowed: `ArBr`, `A2rBr`, `ArB2r`, and `random`.
lora_ga_scale (str): The scaling method for initialization in LoRA-GA.
Default value is `stable`. Allowed values are: `gd`, `unit`, `stable`, and `weightS`.
lora_ga_stable_gamma (int): The gamma value when choosing `stable` scaling for initialization.
Default value is 16.
fourier_n_frequency (int): Number of frequencies for FourierFT. Default is 2000.
fourier_scaling (float): Scaling factor for FourierFT. Default is 300.0.
boft_block_size (int): Block size for BOFT. Default is 4.
boft_block_num (int): Number of blocks for BOFT. Default is 0.
boft_n_butterfly_factor (int): Butterfly factor for BOFT. Default is 1.
boft_dropout (float): Dropout rate for BOFT. Default is 0.0.
        vera_rank (int): Rank for VeRA. Default is 256.
        vera_projection_prng_key (int): PRNG key for the VeRA projection. Default is 0.
        vera_dropout (float): Dropout rate for VeRA. Default is 0.0.
        vera_d_initial (float): Initial value for the VeRA d vector. Default is 0.1.
adapter_act (str): Activation function for adapter. Default is 'gelu'.
adapter_length (int): Length of the adapter. Default is 128.
        use_galore (bool): Flag to indicate if GaLore is used. Default is False.
        galore_target_modules (Optional[List[str]]): List of target modules for GaLore. Default is None.
        galore_rank (int): Rank for GaLore. Default is 128.
        galore_update_proj_gap (int): Update projection gap for GaLore. Default is 50.
        galore_scale (float): Scaling factor for GaLore. Default is 1.0.
        galore_proj_type (str): Projection type for GaLore. Default is 'std'.
        galore_optim_per_parameter (bool): Flag to indicate if a separate optimizer is used per parameter
            in GaLore. Default is False.
        galore_with_embedding (bool): Flag to indicate if embedding layers are tuned with GaLore.
            Default is False.
        galore_quantization (bool): Flag to indicate if Q-GaLore is used. Default is False.
        galore_proj_quant (bool): Flag to indicate if projection quantization is used for GaLore. Default is False.
        galore_proj_bits (int): Number of bits for projection quantization. Default is 4.
        galore_proj_group_size (int): Group size for projection quantization. Default is 256.
        galore_cos_threshold (float): Cosine threshold for projection quantization. Default is 0.4.
        galore_gamma_proj (int): Gamma for projection quantization. Default is 2.
        galore_queue_size (int): Queue size for projection quantization. Default is 5.
adalora_target_r (int): Target rank for AdaLoRA. Default is 8.
adalora_init_r (int): Initial rank for AdaLoRA. Default is 12.
        adalora_tinit (int): Initial T value for AdaLoRA. Default is 0.
        adalora_tfinal (int): Final T value for AdaLoRA. Default is 0.
        adalora_deltaT (int): Delta T value for AdaLoRA. Default is 1.
adalora_beta1 (float): Beta1 value for AdaLoRA. Default is 0.85.
adalora_beta2 (float): Beta2 value for AdaLoRA. Default is 0.85.
adalora_orth_reg_weight (float): Orthogonal regularization weight for AdaLoRA. Default is 0.5.
llamapro_num_new_blocks (int): Number of new blocks for LLaMAPro. Default is 4.
llamapro_num_groups (Optional[int]): Number of groups for LLaMAPro. Default is None.
lisa_activated_layers (int): Number of activated layers for LISA. Default is 0.
lisa_step_interval (int): Step interval for LISA activation. Default is 20.
reft_layer_key (Optional[str]): Key identifier for ReFT layer. Default is None.
reft_layers (Optional[List[int]]): List of layers involved in ReFT. Default is None.
reft_rank (int): Rank parameter for ReFT. Default is 4.
reft_intervention_type (Literal): Type of intervention for ReFT. Default is 'LoreftIntervention'.
reft_args (Optional[str]): Additional arguments for ReFT. Default is None.
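
    Example:
        A minimal sketch of setting a few of these options through the ``swift``
        CLI (the model id below is illustrative, not a default of this class)::

            swift sft \
                --model Qwen/Qwen2.5-7B-Instruct \
                --train_type lora \
                --target_modules all-linear \
                --lora_rank 8 \
                --lora_alpha 32 \
                --lora_dropout 0.05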
"""
# full
freeze_parameters: List[str] = field(default_factory=list)
freeze_parameters_regex: Optional[str] = None
freeze_parameters_ratio: float = 0. # 0 ~ 1
trainable_parameters: List[str] = field(default_factory=list)
trainable_parameters_regex: Optional[str] = None
# lora or full
freeze_llm: bool = False
freeze_vit: bool = True
freeze_aligner: bool = True
# tuners
target_modules: List[str] = field(default_factory=lambda: ['all-linear'])
target_regex: Optional[str] = None
# e.g. ['wte', 'ln_1', 'ln_2', 'ln_f', 'lm_head']
modules_to_save: List[str] = field(default_factory=list)
# lora
lora_rank: int = 8
lora_alpha: int = 32
lora_dropout: float = 0.05
lora_bias: Literal['none', 'all'] = 'none'
lora_dtype: Literal['float16', 'bfloat16', 'float32', None] = None
lorap_lr_ratio: Optional[float] = None
use_rslora: bool = False
use_dora: bool = False
    # lora-ga
    lora_ga_batch_size: int = 2
    lora_ga_iters: int = 2
    lora_ga_max_length: int = 1024
    lora_ga_direction: str = 'ArB2r'
    lora_ga_scale: str = 'stable'
    lora_ga_stable_gamma: int = 16
    # init_weights accepts tuner-specific values:
    # Lora: Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false', 'lora-ga']
    # Bone: Literal['bat', 'true', 'false']
    init_weights: str = 'true'
# fourierft
fourier_n_frequency: int = 2000
fourier_scaling: float = 300.0
# BOFT
boft_block_size: int = 4
boft_block_num: int = 0
boft_n_butterfly_factor: int = 1
boft_dropout: float = 0.0
# Vera
vera_rank: int = 256
vera_projection_prng_key: int = 0
vera_dropout: float = 0.0
vera_d_initial: float = 0.1
# adapter
adapter_act: str = 'gelu'
adapter_length: int = 128
# galore
use_galore: bool = False
galore_target_modules: Optional[List[str]] = None
galore_rank: int = 128
galore_update_proj_gap: int = 50
galore_scale: float = 1.0
galore_proj_type: str = 'std'
galore_optim_per_parameter: bool = False
galore_with_embedding: bool = False
galore_quantization: bool = False
galore_proj_quant: bool = False
galore_proj_bits: int = 4
galore_proj_group_size: int = 256
galore_cos_threshold: float = 0.4
galore_gamma_proj: int = 2
galore_queue_size: int = 5
# adalora
adalora_target_r: int = 8
adalora_init_r: int = 12
adalora_tinit: int = 0
adalora_tfinal: int = 0
adalora_deltaT: int = 1
adalora_beta1: float = 0.85
adalora_beta2: float = 0.85
adalora_orth_reg_weight: float = 0.5
# llamapro
llamapro_num_new_blocks: int = 4
llamapro_num_groups: Optional[int] = None
# lisa
lisa_activated_layers: int = 0
lisa_step_interval: int = 20
# reft
reft_layer_key: Optional[str] = None
reft_layers: Optional[List[int]] = None
reft_rank: int = 4
reft_intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
'LobireftIntervention', 'DireftIntervention',
'NodireftIntervention'] = 'LoreftIntervention'
reft_args: Optional[str] = None

    def __post_init__(self):
        if isinstance(self.init_weights, str) and self.init_weights.lower() in {'true', 'false'}:
            self.init_weights = bool(strtobool(self.init_weights))
        self._init_multimodal_full()
        if self.target_regex:
            # A regex takes precedence; peft's `target_modules` accepts either a
            # list of module names or a single regex string.
            self.target_modules = self.target_regex

    def _init_multimodal_full(self):
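        # `model_meta` is expected to be provided by the surrounding arguments
        # class (e.g. via ModelArguments) when this dataclass is mixed in.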
model_arch = get_model_arch(self.model_meta.model_arch)
if not self.model_meta.is_multimodal or not model_arch:
return
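        # Translate the coarse freeze_llm / freeze_vit / freeze_aligner switches
        # into the concrete parameter prefixes of this architecture.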
if self.freeze_llm:
self.freeze_parameters += model_arch.language_model
if self.freeze_vit:
self.freeze_parameters += model_arch.vision_tower
if self.freeze_aligner:
self.freeze_parameters += model_arch.aligner
else:
self.trainable_parameters += model_arch.aligner
self.freeze_parameters += model_arch.generator
if self.freeze_parameters:
logger.info(f'freeze_parameters: {self.freeze_parameters}')
if self.trainable_parameters:
logger.info(f'additional trainable_parameters: {self.trainable_parameters}')