import inspect
import os
from typing import List, Union

import torch
import torch.nn as nn
import transformers
from packaging import version
from transformers import TrainingArguments

from swift.llm import TrainArguments, deep_getattr, get_model_arch
from swift.plugin import Tuner, extra_tuners
from swift.tuners import Swift
from swift.utils import (activate_parameters, find_all_linears, find_embedding, find_norm, freeze_parameters,
                         get_logger, use_torchacc)

logger = get_logger()


def apply_liger(model_type: str):
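    """Patch the model's modeling code with Liger kernels, dispatching on `model_type`."""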
    from liger_kernel.transformers import (apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral,
                                           apply_liger_kernel_to_mixtral, apply_liger_kernel_to_gemma,
                                           apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen3,
                                           apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl,
                                           apply_liger_kernel_to_phi3, apply_liger_kernel_to_mllama)
    from swift.llm import ModelType
    if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2):
        apply_liger_kernel_to_llama()
    elif model_type == ModelType.mistral:
        apply_liger_kernel_to_mistral()
    elif model_type == ModelType.mixtral:
        apply_liger_kernel_to_mixtral()
    elif model_type in (ModelType.gemma, ModelType.gemma2):
        apply_liger_kernel_to_gemma()
    elif model_type in (ModelType.qwen2, ModelType.qwen2_5):
        apply_liger_kernel_to_qwen2()
    elif model_type == ModelType.qwen3:
        apply_liger_kernel_to_qwen3()
    elif model_type == ModelType.phi3:
        apply_liger_kernel_to_phi3()
    elif model_type == ModelType.llama3_2_vision:
        apply_liger_kernel_to_mllama()
    elif model_type == ModelType.qwen2_vl:
        apply_liger_kernel_to_qwen2_vl()
    elif model_type == ModelType.qwen2_5_vl:
        apply_liger_kernel_to_qwen2_5_vl()
    else:
        raise ValueError(f'Unsupported liger model_type: {model_type}')


def get_multimodal_target_regex(
        model,
        *,
        freeze_llm: bool = False,
        freeze_vit: bool = True,
        freeze_aligner: bool = True,
        include_embedding: bool = False,
) -> str:
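    """Build a regex over parameter names that covers the trainable sub-towers.

    A hypothetical example: with only the LLM trainable, the returned pattern could
    look like '^(language_model.*(q_proj|k_proj|v_proj|o_proj))$'; the actual module
    names depend on the model architecture.
    """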
    model_arch = get_model_arch(model.model_meta.model_arch)
    modules = []
    if not freeze_llm:
        modules += model_arch.language_model
    if not freeze_vit:
        modules += model_arch.vision_tower
    if not freeze_aligner:
        modules += model_arch.aligner
    assert len(modules) > 0, f'modules: {modules}'

    extra_layers = []
    if include_embedding:
        extra_layers.append(nn.Embedding)
    res = []
    for module in modules:
        rejected_modules = []
        if not freeze_vit:
            # The aligner may be nested inside the vision tower; exclude it from the
            # vision tower's pattern so it is only matched via its own entry (if trainable).
            for aligner in model_arch.aligner:
                if aligner.startswith(f'{module}.'):
                    rejected_modules.append(aligner)

        sub_module = deep_getattr(model, module)
        target_modules = find_all_linears(sub_module, model_arch, extra_layers)
        target_modules = [tm for tm in target_modules if tm]
        target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else ''
        rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else ''
        res.append(rf'{rejected_pattern}{module}{target_pattern}')

    return rf'^({"|".join(res)})$'


def get_target_modules(args, model) -> Union[str, List[str]]:
| """Replace all-linear to actual modules""" |
| model_meta = model.model_meta |
| if isinstance(args.target_modules, str): |
| return args.target_modules |
| target_modules = args.target_modules.copy() |
| if 'all-linear' in target_modules: |
| if model_meta.is_multimodal: |
| return get_multimodal_target_regex( |
| model, |
| freeze_llm=args.freeze_llm, |
| freeze_vit=args.freeze_vit, |
| freeze_aligner=args.freeze_aligner, |
| include_embedding='all-embedding' in target_modules) |
| else: |
| target_modules.remove('all-linear') |
| target_modules += find_all_linears(model) |
| if 'all-embedding' in target_modules: |
| target_modules.remove('all-embedding') |
| target_modules += find_embedding(model) |
| return target_modules |
|
|
|
|
| def get_modules_to_save(args, model, task_type=None): |
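    """Expand the 'all-embedding' / 'all-norm' placeholders in `args.modules_to_save`."""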
    modules_to_save = args.modules_to_save.copy()
    if 'all-embedding' in args.modules_to_save:
        modules_to_save.remove('all-embedding')
        modules_to_save += find_embedding(model)
    if 'all-norm' in args.modules_to_save:
        modules_to_save.remove('all-norm')
        modules_to_save += find_norm(model)
    if task_type and task_type.lower() == 'seq_cls':
        modules_to_save.append('v_head')
    return modules_to_save


def get_vera_target_modules(model, config):
| """This function is only useful on the vera tuner""" |
| target_modules = config.target_modules |
| modules_dict = { |
| name: module.weight.shape |
| for name, module in model.named_modules() |
| if isinstance(module, torch.nn.Linear) and any([t in name for t in target_modules]) |
| } |
| if len(set(modules_dict.values())) > 1: |
| v = [t for t in target_modules if 'v' in t] |
| if not v: |
| raise ValueError('Please manually pass in `vera_target_modules`, do not use `all-linear`,' |
| 'because Vera need all target linears to be the same size.') |
| v = v[0] |
| shape = [shape for name, shape in modules_dict.items() if v in name][0] |
| names = [_name for _name, _shape in modules_dict.items() if _shape == shape] |
| config.target_modules = [t for t in target_modules if any([t in name for name in names])] |
| return config |
|
|
|
|
| def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset=None, task_type=None): |
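    """Wrap `model` with the adapter selected by `args.train_type` and return it."""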
    from swift.tuners import (AdaLoraConfig, AdapterConfig, BOFTConfig, LLaMAProConfig, LongLoRAModelType, LoraConfig,
                              LoRAConfig, ReftConfig, Swift, VeraConfig)
    task_type = (task_type or args.task_type).upper()
    target_modules = get_target_modules(args, model)
    modules_to_save = get_modules_to_save(args, model, task_type)
    lora_kwargs = {
        'r': args.lora_rank,
        'target_modules': target_modules,
        'lora_alpha': args.lora_alpha,
        'lora_dropout': args.lora_dropout,
        'bias': args.lora_bias,
        'modules_to_save': modules_to_save,
        'use_rslora': args.use_rslora,
        'use_dora': args.use_dora,
        'lorap_lr_ratio': args.lorap_lr_ratio,
        'init_lora_weights': args.init_weights,
    }
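    # Each branch below builds a tuner config and wraps the model via Swift.prepare_model.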
    if args.train_type in ('lora', 'longlora'):
        if args.use_swift_lora:
            lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs)
            model = Swift.prepare_model(model, lora_config)
            logger.info(f'lora_config: {lora_config}')
        elif args.tuner_backend == 'peft':
            if task_type == 'EMBEDDING':
                # peft's TaskType has no EMBEDDING member.
                task_type = None
            lora_config = LoraConfig(task_type=task_type, lora_dtype=args.lora_dtype, **lora_kwargs)
            if args.init_weights == 'lora-ga':
                try:
                    import lora_ga
                except ImportError as e:
                    error_message = """
                    Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub.
                    Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.
                    """
                    logger.info(error_message)
                    raise RuntimeError(error_message) from e
                model = lora_ga.entrypoint.get_lora_ga_model(
                    model=model,
                    data_collator=template.data_collator,
                    dataset=train_dataset,
                    batch_size=args.lora_ga_batch_size,
                    num_iters=args.lora_ga_iters,
                    max_length=args.lora_ga_max_length,
                    direction=args.lora_ga_direction,
                    dtype=args.lora_dtype,
                    scale=args.lora_ga_scale,
                    stable_gamma=args.lora_ga_stable_gamma,
                )
            else:
                model = Swift.prepare_model(model, lora_config)
                logger.info(f'lora_config: {lora_config}')
        elif args.tuner_backend == 'unsloth':
            if args.resume_from_checkpoint is None:
                if args.model_meta.is_multimodal:
                    from unsloth import FastVisionModel as UnslothModel
                else:
                    from unsloth import FastLanguageModel as UnslothModel
                assert args.train_type == 'lora', 'Unsloth does not support LongLoRA'
                lora_kwargs.pop('lorap_lr_ratio')  # not supported by unsloth
                model = UnslothModel.get_peft_model(
                    model,
                    use_gradient_checkpointing='unsloth',
                    max_seq_length=args.max_length or 2048,
                    **lora_kwargs,
                )
                logger.info(f'unsloth_config: {lora_kwargs}')
        if args.train_type == 'longlora':
            assert LongLoRAModelType.LLAMA in args.model_type
            assert version.parse(transformers.__version__) >= version.parse('4.39.3')
            from swift.tuners.longlora.llama import replace_llama_attn
            replace_llama_attn(model)
            model.config.group_size_ratio = 0.25
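    # AdaLoRA reallocates rank during training, so it needs the total step count up front.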
    elif args.train_type == 'adalora':
        lora_kwargs.pop('lorap_lr_ratio', None)
        lora_kwargs['rank_pattern'] = None
        from swift.plugin.optimizer import calculate_max_steps
        adalora_config = AdaLoraConfig(
            task_type=task_type,
            **lora_kwargs,
            target_r=args.adalora_target_r,
            init_r=args.adalora_init_r,
            tinit=args.adalora_tinit,
            tfinal=args.adalora_tfinal,
            deltaT=args.adalora_deltaT,
            beta1=args.adalora_beta1,
            beta2=args.adalora_beta2,
            orth_reg_weight=args.adalora_orth_reg_weight,
            total_step=calculate_max_steps(args.training_args, train_dataset),
        )
        model = Swift.prepare_model(model, adalora_config)
        logger.info(f'adalora_config: {adalora_config}')
    elif args.train_type == 'llamapro':
        llamapro_config = LLaMAProConfig(
            model_type=model.model_meta.model_arch,
            num_new_blocks=args.llamapro_num_new_blocks,
            num_groups=args.llamapro_num_groups)
        model = Swift.prepare_model(model, llamapro_config)
        logger.info(f'llamapro_config: {llamapro_config}')
    elif args.train_type == 'adapter':
        model_arch = get_model_arch(model.model_meta.model_arch)
        mlp_key = model_arch.mlp
        mlp_key = mlp_key.split('.{}.')[1]
        adapter_config = AdapterConfig(
            dim=model.config.hidden_size,
            target_modules=[mlp_key],
            hidden_pos=0,
            adapter_length=args.adapter_length,
            act_layer=args.adapter_act)
        model = Swift.prepare_model(model, adapter_config)
        logger.info(f'adapter_config: {adapter_config}')
    elif args.train_type == 'vera':
        vera_config = VeraConfig(
            r=args.vera_rank,
            target_modules=target_modules,
            projection_prng_key=args.vera_projection_prng_key,
            vera_dropout=args.vera_dropout,
            d_initial=args.vera_d_initial,
            modules_to_save=args.modules_to_save,
        )
        vera_config = get_vera_target_modules(model, vera_config)
        model = Swift.prepare_model(model, vera_config)
        logger.info(f'vera_config: {vera_config}')
    elif args.train_type == 'boft':
        boft_config = BOFTConfig(
            boft_block_size=args.boft_block_size,
            boft_block_num=args.boft_block_num,
            boft_n_butterfly_factor=args.boft_n_butterfly_factor,
            target_modules=target_modules,
            boft_dropout=args.boft_dropout,
            modules_to_save=args.modules_to_save,
        )
        model = Swift.prepare_model(model, boft_config)
        logger.info(f'boft_config: {boft_config}')
    elif args.train_type == 'fourierft':
        from peft import FourierFTConfig
        fourier_config = FourierFTConfig(
            target_modules=target_modules,
            modules_to_save=args.modules_to_save,
            n_frequency=args.fourier_n_frequency,
            scaling=args.fourier_scaling,
        )
        model = Swift.prepare_model(model, fourier_config)
        logger.info(f'fourier_config: {fourier_config}')
    elif args.train_type == 'reft':
        reft_config = ReftConfig(
            model_type=model.model_meta.model_arch,
            layer_key=args.reft_layer_key,
            r=args.reft_rank,
            layers=args.reft_layers,
            intervention_type=args.reft_intervention_type,
            args=args.reft_args,
        )
        logger.info(f'reft config: {reft_config}')
        model = Swift.prepare_model(model, {'reft': reft_config})
    elif args.train_type == 'bone':
        from peft import BoneConfig
        bone_config = BoneConfig(
            target_modules=target_modules,
            r=args.reft_rank,
            init_weights=args.init_weights,
        )
        logger.info(f'bone config: {bone_config}')
        model = Swift.prepare_model(model, bone_config)
    return model


def torchacc_resume_from_checkpoint(args, model):
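    """Reload weights when resuming a TorchAcc run, handling single-file and sharded checkpoints."""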
    import safetensors.torch
    weights_file = os.path.join(args.resume_from_checkpoint, 'pytorch_model.bin')
    safe_weights_file = os.path.join(args.resume_from_checkpoint, 'model.safetensors')
    if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
        if args.save_safetensors and os.path.isfile(safe_weights_file):
            state_dict = safetensors.torch.load_file(safe_weights_file, device='cpu')
        else:
            state_dict = torch.load(weights_file, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
        del state_dict
    else:
        from transformers.modeling_utils import load_sharded_checkpoint
        load_result = load_sharded_checkpoint(
            model, args.resume_from_checkpoint, strict=False, prefer_safe=args.save_safetensors)
        if len(load_result.missing_keys) != 0:
            if model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                    model._keys_to_ignore_on_save):
                model.tie_weights()
            else:
                logger.warning(f'There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.')
        if len(load_result.unexpected_keys) != 0:
            logger.warning(f'There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.')


class TunerMixin:

    @classmethod
    def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None):
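        """Prepare `model` for training: adapter branch, full-parameter branch, GaLore and sequence parallel.

        A minimal usage sketch (assuming prepared `args`, `model` and `template`):

            model = TunerMixin.prepare_model(args, model, template=template, train_dataset=train_dataset)
        """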
        if args.use_liger_kernel and 'use_liger_kernel' not in inspect.signature(TrainingArguments).parameters:
            # Older transformers versions have no `use_liger_kernel` argument, so patch the model manually.
            apply_liger(args.model_type)

        if args.is_adapter:
            if args.tuner_backend != 'unsloth' and args.train_type not in extra_tuners:
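                # Freeze the whole base model first; the tuner re-enables grads on its own parameters.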
                model.requires_grad_(False)
            if args.resume_from_checkpoint:
                if args.train_type in extra_tuners:
                    tuner: Tuner = extra_tuners[args.train_type]
                else:
                    tuner = Swift
                kwargs = {}
                if use_torchacc():
                    kwargs = {'adapter_name': 'default'}
                model = tuner.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True, **kwargs)
            else:
                if args.train_type in extra_tuners:
                    tuner: Tuner = extra_tuners[args.train_type]
                    model = tuner.prepare_model(args, model)
                else:
                    model = prepare_adapter(
                        args, model, template=template, train_dataset=train_dataset, task_type=task_type)

            # Trainable fp16 parameters are unstable to optimize; promote them to fp32.
            for p in model.parameters():
                if p.requires_grad and p.dtype == torch.float16:
                    logger.info_once('Convert trainable parameters from fp16 to fp32.')
                    p.data = p.data.to(dtype=torch.float32)
        elif args.train_type == 'full':
            model.train()
            model.requires_grad_(True)

            # Optionally freeze part of the model even in full-parameter training.
            freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters,
                              args.freeze_parameters_regex)
            if len(args.trainable_parameters) > 0 or args.trainable_parameters_regex is not None:
                activate_parameters(model, args.trainable_parameters, args.trainable_parameters_regex)
            if use_torchacc() and args.resume_from_checkpoint:
                torchacc_resume_from_checkpoint(args, model)
        else:
            raise ValueError(f'args.train_type: {args.train_type}')

        if args.resume_only_model:
            args.training_args.resume_from_checkpoint = None
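        # GaLore: build the gradient low-rank projection config and attach it to the trainer args.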
        if args.use_galore:
            from swift.trainers.optimizers.galore import GaLoreConfig
            if args.galore_target_modules is None:
                args.galore_target_modules = find_all_linears(model)
            if args.galore_with_embedding:
                args.galore_target_modules += find_embedding(model)
            args.galore_config = GaLoreConfig(
                target_modules=args.galore_target_modules,
                rank=args.galore_rank,
                update_proj_gap=args.galore_update_proj_gap,
                galore_scale=args.galore_scale,
                proj_type=args.galore_proj_type,
                optim_per_parameter=args.galore_optim_per_parameter,
                quantize=args.galore_quantization,
                proj_quant=args.galore_proj_quant,
                proj_bits=args.galore_proj_bits,
                proj_group_size=args.galore_proj_group_size,
                cos_threshold=args.galore_cos_threshold,
                gamma_proj=args.galore_gamma_proj,
                queue_size=args.galore_queue_size,
            )
            args.training_args.galore_config = args.galore_config

        if args.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            if hasattr(model, 'model_meta'):
                is_multimodal = model.model_meta.is_multimodal
            else:
                # e.g. a wrapped model: model_meta lives on the inner module.
                is_multimodal = model.model.model_meta.is_multimodal
            sequence_parallel.prepare_model(model, template.tokenizer, split_in_forward=is_multimodal)

        return model