diff --git a/ms-swift/silence_overlaps/700/test/silence_transcriptions_test.json b/ms-swift/silence_overlaps/700/test/silence_transcriptions_test.json new file mode 100644 index 0000000000000000000000000000000000000000..9ad01be74be93b46ab4bb0c3e8b6b392a0fad6d3 --- /dev/null +++ b/ms-swift/silence_overlaps/700/test/silence_transcriptions_test.json @@ -0,0 +1,27 @@ +[ + { + "key": "SODA_PROCESSED--train--873625", + "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--873625.wav", + "model_output": "[00:00 - 00:03] Speaker A: Hi, I'm here to test drive the car that I saw online.\n[00:04 - 00:06] Speaker B: Great! What model are you interested in?\n[00:07 - 00:11] Speaker A: The one that I was looking at was the new BMW M3.\n[00:11 - 00:17] Speaker B: That's a great choice! Let me grab the keys and we'll go over some of the features before heading out for the test drive.\n[00:17 - 00:24] Speaker A: Before we go, could you tell me a bit more about the car's features? I want to make sure it has everything I'm looking for.\n[00:25 - 00:33] Speaker B: Absolutely! The BMW M3 comes with a twin-turbo inline-six engine, adaptive suspension, and a premium interior with all the latest tech.\n[00:34 - 00:35] Speaker A: Sounds good to me.\n[00:36 - 00:39] Speaker B: This car is amazing! The acceleration is incredible!\n[00:39 - 00:46] Speaker A: Yeah, it's definitely a powerful car. But what do you think about the overall driving experience\n[00:45 - 00:52] Speaker B: Sorry to cut in, but I just noticed the steering wheel feels a bit stiff. Is that normal, or is it adjustable?\n[00:56 - 01:07] Speaker A: It's normal for the M3, but it does have adjustable settings to customize the steering feel to your preference. Now, as I was saying, what do you think about the handling overall?\n[01:07 - 01:12] Speaker B: It feels really good. It's responsive and precise. I love it!\n[01:12 - 01:15] Speaker A: Great! So you're interested in purchasing this car?\n[01:16 - 01:18] Speaker B: Yeah, I think I am. How much is it?\n[01:19 - 01:20] Speaker A: It's $60,000.\n[01:21 - 01:24] Speaker B: That's a lot of money. I'm not sure if I can afford that." + }, + { + "key": "SODA_PROCESSED--train--891432", + "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--891432.wav", + "model_output": "[00:00 - 00:02] Speaker A: I don't know, I just feel like something is\n[00:02 - 00:06] Speaker B: What do you mean? Are you sensing something specific or is it just a general feeling?\n[00:11 - 00:15] Speaker A: I don't know, I just have a bad feeling about this whole situation.\n[00:16 - 00:20] Speaker B: Do you want to talk about it? Maybe we can figure out what's bothering you.\n[00:20 - 00:28] Speaker A: I'm not sure. I just feel like there's something we're overlooking, something that could change everything if we realized it.\n[00:28 - 00:33] Speaker B: Something we're missing? Like a detail we overlooked or something more significant?\n[00:33 - 00:37] Speaker A: Yeah, exactly. But whatever it is, it's not good.\n[00:37 - 00:42] Speaker B: Why do you say that? Is there something that happened recently that triggered this feeling?\n[00:42 - 00:48] Speaker A: Because everything is just too perfect. It's like we're being set up for something, but I don't know what.\n[00:48 - 00:54] Speaker B: That does seem suspicious. But maybe we're just overthinking things, or is there something specific that makes you feel this way?\n[00:55 - 00:58] Speaker A: Maybe...I don't know, I just can't shake this feeling." + }, + { + "key": "SODA_PROCESSED--train--527126", + "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--527126.wav", + "model_output": "[00:00 - 00:07] Speaker A: Hey, Cherish. I'm starting to work on a new project and I could really use some help. Would you be interested in lending a hand?\n[00:08 - 00:10] Speaker B: Sure, Seneca. What's the project?\n[00:11 - 00:22] Speaker A: It's a bit complicated, but essentially I'm trying to create a detailed map of the world. Every country, every city, every landmark...I want to include it\n[00:21 - 00:26] Speaker B: Sorry to jump in, but how are you planning to gather all that information? It sounds like a massive undertaking.\n[00:31 - 00:54] Speaker A: You're right, it is. I was thinking of using a combination of online resources and some specialized databases. But I haven't mentioned the tools I need yet. It's going to be a lot of work, but I think it will be really cool once it's finished. We might need to use some advanced GIS software and perhaps even some AI tools to process the data efficiently.\n[00:54 - 01:02] Speaker B: That does sound like a lot of work. But if you're up for it, I'm happy to help out however I can. What kind of tools are you thinking of using exactly?" + }, + { + "key": "SODA_PROCESSED--train--243315", + "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--243315.wav", + "model_output": "[00:00 - 00:04] Speaker A: Hey, Mom. I finished the clothes you were helping me\n[00:03 - 00:05] Speaker B: Oh, let me see! How did they turn out?\n[00:11 - 00:15] Speaker A: Yeah, they turned out pretty good, if I do say so myself.\n[00:15 - 00:29] Speaker B: They most certainly did! You're getting so big and tall; it won't be long before those old clothes of yours won't fit you anymore and we'll need to completely revamp your wardrobe with stylish new outfits that match your growing personality. We'll have to go shopping for new ones soon.\n[00:29 - 00:33] Speaker A: I know, right? It's hard to believe how fast time is flying by.\n[00:34 - 00:39] Speaker B: It really is. Sometimes it feels like only yesterday you were born, and now you're already a young man.\n[00:40 - 00:45] Speaker A: So true. Well, I should probably get going now. Thanks for all your help, Mom.\n[00:46 - 00:48] Speaker B: Of course, sweetheart. Anytime." + }, + { + "key": "SODA_PROCESSED--train--201813", + "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--201813.wav", + "model_output": "[00:00 - 00:04] Speaker A: Hi, I'm looking for a new couch, coffee table, and rug.\n[00:04 - 00:10] Speaker B: Great! We have a wide selection of furniture to choose from. Let me show you some of our couches first.\n[00:10 - 00:11] Speaker A: OK.\n[00:11 - 00:20] Speaker B: We have a variety of styles and colors to choose from. What is your preferred style? Are you looking for something modern, traditional, or maybe something in between?\n[00:20 - 00:26] Speaker A: Sorry, before we get into that, can you tell me if you have any ongoing discounts or promotions?\n[00:26 - 00:33] Speaker B: Yes, we do have some promotions running right now. I was just about to ask about your budget, though. Do you have a specific number in mind?\n[00:34 - 00:37] Speaker A: I'm not really sure. Maybe around $500?\n[00:38 - 00:49] Speaker B: We have some great options within your budget. This couch here is only $499. It's a popular choice because it's very versatile and can be used in many different ways\n[00:48 - 00:54] Speaker A: Actually, I was also wondering about the durability of this couch. How long does it typically last?\n[01:00 - 01:05] Speaker B: It's made with high-quality materials, so it should last you several years with proper care. Would you like to see it?\n[01:06 - 01:08] Speaker A: Yes, that sounds perfect. I'll take it!" + } +] \ No newline at end of file diff --git a/ms-swift/swift/plugin/loss_scale/config/ignore_empty_think.json b/ms-swift/swift/plugin/loss_scale/config/ignore_empty_think.json new file mode 100644 index 0000000000000000000000000000000000000000..c7c2395fbb78294a543f09072620895e76ef1ea9 --- /dev/null +++ b/ms-swift/swift/plugin/loss_scale/config/ignore_empty_think.json @@ -0,0 +1,3 @@ +{ + "\n\n\n\n": [0.0] +} diff --git a/ms-swift/swift/trainers/__pycache__/arguments.cpython-310.pyc b/ms-swift/swift/trainers/__pycache__/arguments.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5515df1b3774102c5ae4cb671c11044fd814ff Binary files /dev/null and b/ms-swift/swift/trainers/__pycache__/arguments.cpython-310.pyc differ diff --git a/ms-swift/swift/trainers/__pycache__/mixin.cpython-310.pyc b/ms-swift/swift/trainers/__pycache__/mixin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26626f516705e4d1cf6c2206e779e0a9110036a0 Binary files /dev/null and b/ms-swift/swift/trainers/__pycache__/mixin.cpython-310.pyc differ diff --git a/ms-swift/swift/trainers/__pycache__/utils.cpython-310.pyc b/ms-swift/swift/trainers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4909d56844c1622ca1345b95d9a33f7766943433 Binary files /dev/null and b/ms-swift/swift/trainers/__pycache__/utils.cpython-310.pyc differ diff --git a/ms-swift/swift/trainers/arguments.py b/ms-swift/swift/trainers/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..14c98b5c1a7a14b6cd361565e3382688aeeddcb1 --- /dev/null +++ b/ms-swift/swift/trainers/arguments.py @@ -0,0 +1,214 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import math +import os +import platform +from dataclasses import dataclass, field +from functools import wraps +from typing import List, Literal, Optional, Union + +import torch +import torch.utils.checkpoint +from transformers.training_args import TrainingArguments as HfTrainingArguments +from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments + +from swift.utils import get_dist_setting, get_logger, is_liger_available, use_torchacc +from .optimizers.galore import GaLoreConfig + +logger = get_logger() + + +@dataclass +class TrainArgumentsMixin: + """ + check_model (bool): Flag to check the model is latest. Default is True. + acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'. + """ + per_device_train_batch_size: int = 1 + per_device_eval_batch_size: int = 1 + gradient_accumulation_steps: Optional[int] = None + + gradient_checkpointing: bool = True + gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None + logging_first_step: bool = True + logging_steps: int = 5 + + weight_decay: float = 0.1 + adam_beta2: float = 0.95 + lr_scheduler_type: str = 'cosine' + lr_scheduler_kwargs: Optional[Union[dict, str]] = None + report_to: List[str] = field(default_factory=lambda: ['tensorboard']) + dataloader_num_workers: Optional[int] = None + dataloader_prefetch_factor: Optional[int] = None + use_liger_kernel: bool = False + + # extra + check_model: bool = True + acc_strategy: Literal['token', 'seq'] = 'token' + train_dataloader_shuffle: bool = True + max_epochs: Optional[int] = None + + # torchacc + metric_warmup_step: Optional[float] = 0 + fsdp_num: int = 1 + acc_steps: int = 1 + + # train-eval loop args + eval_use_evalscope: bool = False + eval_datasets: List[str] = field(default_factory=list) + eval_limit: Optional[int] = None + eval_datasets_args: Optional[Union[str, dict]] = None + eval_generation_config: Optional[Union[str, dict]] = None + + def _fix_gradient_checkpointing(self): + # fix use_reentrant + if hasattr(torch.utils.checkpoint, '_old_checkpoint'): # avoid double patching + return + # Consistent with the default behavior of transformers. + use_reentrant_ = ( + self.gradient_checkpointing_kwargs.get('use_reentrant', True) + if self.gradient_checkpointing_kwargs else True) + _old_checkpoint = torch.utils.checkpoint.checkpoint + + @wraps(_old_checkpoint) + def _new_checkpoint(*args, use_reentrant=None, **kwargs): + return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs) + + torch.utils.checkpoint._old_checkpoint = _old_checkpoint + torch.utils.checkpoint.checkpoint = _new_checkpoint + try: + # Fix the old version of transformers. + import transformers.modeling_utils + transformers.modeling_utils.checkpoint = _new_checkpoint + except (ImportError, AttributeError): + pass + + def _init_liger(self): + if self.use_liger_kernel: + assert is_liger_available(), 'use_liger_kernel requires liger_kernels, try `pip install liger-kernel`' + + def __post_init__(self): + from swift.llm.argument.base_args.model_args import ModelArguments + if use_torchacc(): + self.dataloader_drop_last = True + if self.gradient_accumulation_steps is None: + world_size = get_dist_setting()[2] + self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size)) + logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}') + if self.lr_scheduler_kwargs: + self.lr_scheduler_kwargs = ModelArguments.parse_to_dict(self.lr_scheduler_kwargs) + if self.gradient_checkpointing_kwargs: + self.gradient_checkpointing_kwargs = ModelArguments.parse_to_dict(self.gradient_checkpointing_kwargs) + self._fix_gradient_checkpointing() + self._init_liger() + if self.dataloader_num_workers is None: + if platform.system() == 'Windows': + self.dataloader_num_workers = 0 + else: + self.dataloader_num_workers = 1 + logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}') + if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0: + self.dataloader_prefetch_factor = 10 + if self.eval_use_evalscope: + try: + import evalscope + except ImportError: + raise ImportError('evalscope is not installed, please install it by `pip install evalscope`') + self.eval_datasets_args = ModelArguments.parse_to_dict(self.eval_datasets_args) + self.eval_generation_config = ModelArguments.parse_to_dict(self.eval_generation_config) + + super().__post_init__() + + +@dataclass +class SwiftArgumentsMixin(TrainArgumentsMixin): + # Value copied from TrainArguments + train_type: Optional[str] = None + optimizer: Optional[str] = None + local_repo_path: Optional[str] = None + galore_config: Optional[GaLoreConfig] = None + + def __post_init__(self): + if hasattr(self, 'output_dir'): + self.output_dir = os.path.abspath(os.path.expanduser(self.output_dir)) + super().__post_init__() + + @property + def place_model_on_device(self): + return False if use_torchacc() else super().place_model_on_device + + +@dataclass +class GRPOArgumentsMixin: + epsilon: float = 0.2 + epsilon_high: Optional[float] = None + top_k: int = 50 + top_p: float = 0.9 + repetition_penalty: float = 1. + num_infer_workers: int = 1 + # vllm + vllm_device: List[str] = field(default_factory=lambda: ['auto']) + vllm_gpu_memory_utilization: float = 0.9 + vllm_max_model_len: Optional[int] = None + vllm_max_num_seqs: int = 256 + vllm_enforce_eager: bool = False + vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' + vllm_enable_prefix_caching: bool = True + # reward function args, see details in swift/plugin/orm.py + # cosine reward, https://arxiv.org/abs/2502.03373 + cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length. + cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length. + cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length. + cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length. + cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length + # repetition penalty, https://arxiv.org/abs/2502.03373 + repetition_n_grams: int = 3 + repetition_max_penalty: float = -1.0 + + reward_model: Optional[List[str]] = None + reward_model_plugin: Optional[List[str]] = None + # LMDeploy in GRPO + use_lmdeploy: bool = False + lmdeploy_device: Optional[str] = 'auto' + lmdeploy_session_len: Optional[int] = None + lmdeploy_cache_max_entry_count: float = 0.8 + + async_generate: bool = False + tensor_parallel_size: int = 1 + sleep_level: int = 0 + move_model_batches: Optional[int] = None + offload_optimizer: bool = False + offload_model: bool = False + gc_collect_after_offload: bool = False + multi_turn_func: Optional[str] = None + + # DAPO, https://arxiv.org/abs/2503.14476 + dynamic_sample: bool = False + max_resample_times: int = 3 + overlong_filter: bool = False + soft_max_length: Optional[int] = None + soft_cache_length: Optional[int] = None + + # Dr. GRPO, https://arxiv.org/abs/2503.20783 + scale_rewards: bool = True + + # compatible with trl main branch(0.17.0.dev0) + wandb_log_unique_prompts: Optional[bool] = None + + # external vllm + vllm_server_host: Optional[str] = None + vllm_server_port: int = 8000 + vllm_server_timeout: float = 240.0 + vllm_client = None + + # dataset + dataset_shuffle: Optional[bool] = True + + +@dataclass +class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments): + pass + + +@dataclass +class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments): + pass diff --git a/ms-swift/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc b/ms-swift/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b478ba503b1edb57cf46adfa203054a26b376830 Binary files /dev/null and b/ms-swift/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc differ diff --git a/ms-swift/swift/trainers/optimizers/galore/adamw.py b/ms-swift/swift/trainers/optimizers/galore/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..7396334a32d974a3631e30862a384f908a6816f4 --- /dev/null +++ b/ms-swift/swift/trainers/optimizers/galore/adamw.py @@ -0,0 +1,141 @@ +# copy dependencies from transformers/optimization.py +# code borrowed from https://github.com/jiaweizzhao/GaLore +import math +from typing import Callable, Iterable, Tuple + +import torch +from torch import nn +from torch.optim import Optimizer +from transformers.utils.versions import require_version + +from .galore_projector import GaLoreProjector + + +class AdamW(Optimizer): + """ + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*, defaults to 0.001): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-06): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0.0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). + """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + no_deprecation_warning: bool = False, + ): + require_version('torch>=1.5.0') # add_ with alpha + if lr < 0.0: + raise ValueError(f'Invalid learning rate: {lr} - should be >= 0.0') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps} - should be >= 0.0') + defaults = {'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, 'correct_bias': correct_bias} + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + + state = self.state[p] + + if 'step' not in state: + state['step'] = 0 + + # GaLore Projection + if 'rank' in group: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + proj_type=group['proj_type']) + + grad = state['projector'].project(grad, state['step']) + + # State initialization + if 'exp_avg' not in state: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(grad) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1**state['step'] + bias_correction2 = 1.0 - beta2**state['step'] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + # compute norm gradient + norm_grad = exp_avg / denom + + # GaLore Projection Back + if 'rank' in group: + norm_grad = state['projector'].project_back(norm_grad) + + p.add_(norm_grad, alpha=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group['weight_decay'] > 0.0: + p.add_(p, alpha=(-group['lr'] * group['weight_decay'])) + + return loss + + +GaLoreAdamW = AdamW diff --git a/ms-swift/swift/trainers/optimizers/galore/adamw8bit.py b/ms-swift/swift/trainers/optimizers/galore/adamw8bit.py new file mode 100644 index 0000000000000000000000000000000000000000..66b0c5b621369ec16577729df5251848a8796e90 --- /dev/null +++ b/ms-swift/swift/trainers/optimizers/galore/adamw8bit.py @@ -0,0 +1,112 @@ +# code borrowed from https://github.com/jiaweizzhao/GaLore +import torch +from bitsandbytes.optim.optimizer import Optimizer2State + +from .galore_projector import GaLoreProjector + + +class AdamW8bit(Optimizer2State): + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False): + super().__init__( + 'adam', + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + + # if self.is_paged: self.page_mng.prefetch_all() + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group['params']): + if p.grad is None: + continue + state = self.state[p] + + if 'step' not in state: + state['step'] = 0 + + # GaLore Projection + if 'rank' in group: + if 'projector' not in state: + state['projector'] = GaLoreProjector( + group['rank'], + update_proj_gap=group['update_proj_gap'], + scale=group['scale'], + proj_type=group['proj_type']) + + if 'weight_decay' in group and group['weight_decay'] > 0: + # ensure that the weight decay is not applied to the norm grad + group['weight_decay_saved'] = group['weight_decay'] + group['weight_decay'] = 0 + + grad = state['projector'].project(p.grad, state['step']) + + # suboptimal implementation + p.saved_data = p.data.clone() + p.data = grad.clone().to(p.data.dtype).to(p.data.device) + p.data.zero_() + p.grad = grad + + if 'state1' not in state: + self.init_state(group, p, gindex, pindex) + + self.prefetch_state(p) + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + + # GaLore Projection Back + if 'rank' in group: + p.data = p.saved_data.add_(state['projector'].project_back(p.data)) + + # apply weight decay + if 'weight_decay_saved' in group: + p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay_saved']) + group['weight_decay'] = group['weight_decay_saved'] + del group['weight_decay_saved'] + + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + + return loss + + +GaLoreAdamW8bit = AdamW8bit diff --git a/ms-swift/swift/trainers/rlhf_trainer/__init__.py b/ms-swift/swift/trainers/rlhf_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6d6a7fa3c254acb5ab1ae855de18b0c70ceaaa --- /dev/null +++ b/ms-swift/swift/trainers/rlhf_trainer/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from swift.utils.import_utils import _LazyModule + +if TYPE_CHECKING: + from .cpo_trainer import CPOTrainer + from .dpo_trainer import DPOTrainer + from .grpo_trainer import GRPOTrainer + from .kto_trainer import KTOTrainer + from .orpo_trainer import ORPOTrainer + from .ppo_trainer import PPOTrainer + from .reward_trainer import RewardTrainer + from .rlhf_mixin import RLHFTrainerMixin + from .utils import _split_into_mini_batches, patch_lora_merge, patch_lora_unmerge, round_robin +else: + _import_structure = { + 'cpo_trainer': ['CPOTrainer'], + 'dpo_trainer': ['DPOTrainer'], + 'grpo_trainer': ['GRPOTrainer'], + 'kto_trainer': ['KTOTrainer'], + 'orpo_trainer': ['ORPOTrainer'], + 'ppo_trainer': ['PPOTrainer'], + 'reward_trainer': ['RewardTrainer'], + 'rlhf_mixin': ['RLHFTrainerMixin'], + 'utils': ['_split_into_mini_batches', 'patch_lora_merge', 'patch_lora_unmerge', 'round_robin'], + } + + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/ms-swift/swift/trainers/rlhf_trainer/cpo_trainer.py b/ms-swift/swift/trainers/rlhf_trainer/cpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..25e4c93578d7d732e581ddfac46420bf5ffe6548 --- /dev/null +++ b/ms-swift/swift/trainers/rlhf_trainer/cpo_trainer.py @@ -0,0 +1,32 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import warnings +from typing import Optional, Union + +import torch.nn as nn +from transformers import PreTrainedModel +from trl import CPOTrainer as HFCPOTrainer + +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +del HFCPOTrainer.__init__ + + +class CPOTrainer(RLHFTrainerMixin, SwiftMixin, HFCPOTrainer): + + def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs): + ref_model = kwargs.get('ref_model') + assert ref_model is None, 'CPO/SimPO does not require a ref_model.' + + args = kwargs['args'] + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.cpo_alpha = args.cpo_alpha + if args.loss_type == 'simpo': + self.simpo_gamma = args.simpo_gamma + if self.cpo_alpha > 0: + warnings.warn('You are using CPO-SimPO method because you set a non-zero cpo_alpha. ' + 'This will result in the CPO-SimPO method ' + '(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). ' + 'If you want to use a pure SimPO method, please set cpo_alpha to 0.') + super().__init__(model, *_args, **kwargs) diff --git a/ms-swift/swift/trainers/rlhf_trainer/dpo_trainer.py b/ms-swift/swift/trainers/rlhf_trainer/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f03af82120fe16d29424383b3c68765d8e90355 --- /dev/null +++ b/ms-swift/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -0,0 +1,129 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from peft import PeftModel +from transformers import PreTrainedModel +from trl import DPOTrainer as HFDPOTrainer + +from ..mixin import DataLoaderMixin, SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin + +del HFDPOTrainer.__init__ + + +class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer): + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + *_args, + **kwargs): + from trl.trainer import FDivergenceConstants + args = kwargs['args'] + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.precompute_ref_log_probs = args.precompute_ref_log_probs + self.f_divergence_type = args.f_divergence_type + self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} + self.is_peft_model = isinstance(model, PeftModel) + + self.ref_adapter_name = args.ref_adapter_name + self.reference_free = args.reference_free + self.use_weighting = False + + super().__init__(model, ref_model, *_args, **kwargs) + + def get_nll_loss(self, logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + return loss_fct(logits, labels) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + batch = batch.copy() + num_examples = batch['labels'].shape[0] // 2 + labels = batch.pop('labels', None) + if self.is_encoder_decoder: + batch['labels'] = labels + + if self.aux_loss_enabled: + batch['output_router_logits'] = True + outputs = model(**batch, use_cache=False) + batch['labels'] = labels + if outputs.logits.shape[1] != labels.shape[1]: + # for llava, the model returns logits for the entire sequence, including the image tokens + # (placed before the text tokens) + outputs.logits = outputs.logits[:, -labels.shape[1]:] + for key in ['input_ids', 'attention_mask', 'labels']: + batch[f'concatenated_{key}'] = batch.pop(key, None) + if self.__class__.__name__ == 'ORPOTrainer': # Pass-through labels + batch['concatenated_input_ids'] = batch['concatenated_labels'] + + all_logits = outputs.logits + + if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]: + # for llava, the model returns logits for the entire sequence, + # including the image tokens (placed before the text tokens) + seq_len = batch['concatenated_labels'].shape[1] + all_logits = all_logits[:, -seq_len:] + + all_logps, size_completion = self.get_batch_logps( + all_logits, + batch['concatenated_labels'], + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + output = {} + + if self.args.rpo_alpha is not None: + labels = batch['concatenated_labels'].clone() + output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples]) + + if self.loss_type == 'ipo': + all_logps = all_logps / size_completion + + output['chosen_logps'] = all_logps[:num_examples] + output['rejected_logps'] = all_logps[num_examples:] + output['mean_chosen_logits'] = all_logits[:num_examples].mean() + output['mean_rejected_logits'] = all_logits[num_examples:].mean() + + if self.aux_loss_enabled: + output['aux_loss'] = outputs.aux_loss + + return output + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + if logits.shape[:-1] != labels.shape: + raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]}' + 'and labels must have the same shape {labels.shape}') + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + else: + labels = labels.clone() + + loss_mask = labels != label_pad_token_id + + labels[labels == label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) diff --git a/ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py b/ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0d300132380c246a7070bf5b77a1f27bff23cc31 --- /dev/null +++ b/ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -0,0 +1,1424 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from huggingface/trl. +import concurrent.futures +import inspect +import os +import re +import time +from collections import defaultdict, deque +from concurrent.futures import Future +from contextlib import contextmanager +from copy import copy, deepcopy +from dataclasses import asdict, dataclass, field +from math import ceil +from queue import Queue +from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import datasets +import numpy as np +import torch +import torch.nn as nn +import transformers +from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from packaging import version +from torch.nn import ModuleList +from torch.utils.data import DataLoader +from transformers import PreTrainedModel, TrainerCallback +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.trainer import Trainer +from transformers.trainer_utils import seed_worker +from trl import GRPOTrainer as HFGRPOTrainer +from trl.extras.profiling import profiling_decorator +from trl.models import prepare_deepspeed +from trl.trainer.grpo_trainer import nanmax, nanmin + +from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device +from swift.llm.infer.infer_engine import set_device_context +from swift.llm.template.template_inputs import StdTemplateInputs +from swift.plugin import multi_turns, orms, rm_plugins +from swift.utils import (JsonlWriter, gc_collect, get_device, get_device_count, get_dist_setting, get_logger, + get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available) +from ..mixin import SwiftMixin +from .rlhf_mixin import RLHFTrainerMixin +from .utils import patch_lora_merge, patch_lora_unmerge, round_robin + +del HFGRPOTrainer.__init__ +del HFGRPOTrainer.log + +logger = get_logger() +if is_wandb_available(): + import wandb + +InputsType = List[Dict[str, Union[torch.Tensor, Any]]] +OutputsType = List[List[Tuple[List[Dict], str]]] + + +@contextmanager +def unwrap_model_for_generation( + model, + accelerator, + gather_deepspeed3_params=True, + gather_parameters: List = None, +): + unwrapped_model = accelerator.unwrap_model(model) + if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: + if not gather_deepspeed3_params: + yield accelerator.unwrap_model(model) + else: + import deepspeed + parameters = [ + parameter for name, parameter in model.named_parameters() + if not gather_parameters or name in gather_parameters + ] + with deepspeed.zero.GatheredParameters(parameters): + from trl.models.utils import remove_hooks + remove_hooks(model) + yield accelerator.unwrap_model(model) + from trl.models.utils import add_hooks + add_hooks(model) + else: + yield unwrapped_model + + +class GRPOCallback(TrainerCallback): + + def __init__(self, trainer): + self.trainer = trainer + + # offload original_modules to cpu, to save memory + def on_train_begin(self, args, state, control, **kwargs): + self.trainer.queue = self.trainer.train_queue + train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader') + self.trainer._prefetch(train_dataloader) + + +@dataclass +class DataCache: + inputs: List[Dict] = field(default_factory=list) + outputs: List[Dict] = field(default_factory=list) + distributed_idx: List[List] = field(default_factory=list) + + +class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer): + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + def __init__(self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None, + reward_model: Optional[List[Union[PreTrainedModel, nn.Module]]] = None, + reward_funcs: Optional[List[Union[str, Callable]]] = None, + *_args, + **kwargs): + from swift.trainers.rlhf_arguments import GRPOConfig + args: GRPOConfig = kwargs['args'] + self.args = args + self.train_queue = Queue() + self.eval_queue = Queue() + self.processing_class = kwargs.get('template').tokenizer + self.offload_modules = {} + self.offload_states = {} + _, _, _, local_world_size = get_dist_setting() + + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin') + + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + + self.reward_model_plugins = [None] * len(self.reward_funcs) + + if reward_model is not None: + reward_template = kwargs.pop('reward_template') + reward_plugins = args.reward_model_plugin + if reward_plugins is None: + reward_plugins = ['default'] * len(reward_model) + assert len(reward_plugins) == len(reward_model), ( + f"The number of 'reward_model_plugin' ({len(reward_plugins)}) does not match " + f"the number of 'reward_model' ({len(reward_model)}). " + "Please provide a corresponding 'reward_model_plugin' for each 'reward_model'.") + for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_template): + # Set encoding mode train(see details in Template.encode). + # Set max_length to None to disable truncation, as the input length has already been truncated earlier. + rm_template.set_mode('train') + rm_template.max_length = None + if rm_plugin not in rm_plugins: + raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.llm.plugin') + self.reward_model_plugins.append(rm_plugins[rm_plugin](model=rm, template=rm_template)) + self.reward_funcs.append(rm) + self.reward_func_names.append(rm.config._name_or_path.split('/')[-1]) + + if not self.reward_funcs: + raise ValueError('You must specify reward_funcs or reward_model') + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' + f'functions ({len(reward_funcs)})') + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + self.multi_turn_func = None + if self.args.multi_turn_func: + if isinstance(self.args.multi_turn_func, str): + assert self.args.multi_turn_func in multi_turns + multi_turn_func = multi_turns[self.args.multi_turn_func] + self.multi_turn_func = multi_turn_func + else: + self.multi_turn_func = self.args.multi_turn_func + + self.num_generations = args.num_generations + self.temperature = args.temperature + self.loss_type = args.loss_type + model.warnings_issued['estimate_tokens'] = True + kwargs['data_collator'] = lambda features: features + self.shuffle_dataset = args.dataset_shuffle + + use_vllm = args.use_vllm + use_lmdeploy = args.use_lmdeploy + vllm_client = kwargs.pop('vllm_client') # for external vllm + if self.args.tensor_parallel_size > 1 and self.multi_turn_func: + import torch.distributed as dist + rank, _, _, _ = get_dist_setting() + for tp_group in self.tp_group_ranks(): + group = dist.new_group(tp_group) + if rank in tp_group: + self.group = group + + super().__init__(model, ref_model, *_args, **kwargs) + + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl')) + # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the + # final optimization step. + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps + self._textual_logs = { + 'prompt': deque(maxlen=maxlen), + 'completion': deque(maxlen=maxlen), + 'rewards': defaultdict(lambda: deque(maxlen=maxlen)), + } + + num_processes = self.accelerator.num_processes + self.effective_train_batch_size = effective_batch_size = \ + args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps + possible_values = [n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0] + + if self.num_generations not in possible_values: + raise ValueError( + f'The effective train batch size ({num_processes} x {args.per_device_train_batch_size} x ' + f'{args.gradient_accumulation_steps}) must be evenly divisible by the number of generations per ' + f'prompt ({self.num_generations}). Given the current effective train batch size, the valid values for ' + f'the number of generations are: {possible_values}.') + if self.args.eval_strategy != 'no': + effective_batch_size = args.per_device_eval_batch_size * num_processes + possible_values = [ + n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0 + ] + if self.num_generations not in possible_values: + raise ValueError( + f'The effective eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be ' + f'evenly divisible by the number of generations per prompt ({self.num_generations}). Given the ' + 'current effective eval batch size, the valid values for the number of generations are: ' + f'{possible_values}.') + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + self.parameter_groups, self.parameter_groups_no_lora = self.split_batches() + self.infer_device = None + self.use_fast_infer = use_vllm or use_lmdeploy # whether to use the PT backend + self.is_external_vllm = use_vllm and args.vllm_server_host is not None + if self.use_fast_infer: + if self.infer_rank >= 0: + fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device + if fast_infer_device[0] == 'auto': + if get_device_count() == 1: + fast_infer_device = [get_device()] # particular case when training with only 1 GPU: share it + else: + fast_infer_device = [] + for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()): + fast_infer_device.append(get_device(idx)) + + for _device in fast_infer_device: + # Check that the requested device is available + if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count(): + raise ValueError(f'The requested device for vllm ({_device}) is not available. ' + f'You are likely using vLLM ' + 'without restricting the number of GPUs for training. ' + 'Set the `--num_processes` argument to a ' + 'value lower than the number of GPUs available on your machine—typically, ' + 'reducing it by one is sufficient. ' + f'In your case: `--num_processes {get_device_count() - 1}`.') + + if use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. ' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.is_external_vllm: + self.vllm_client = vllm_client + else: + self.engine = self.prepare_vllm(model, fast_infer_device) + self.infer_device = fast_infer_device[self.local_infer_rank] + elif use_lmdeploy: + if not is_lmdeploy_available(): + raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.' + 'Please install LMDeploy with `pip install lmdeploy -U` to use it.') + from swift.llm import LmdeployEngine + from swift.tuners import Swift + with Swift.grpo_context(model, self.template.processor): + fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1]) + self.engine = LmdeployEngine( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + devices=[fast_infer_device], + session_len=args.lmdeploy_session_len, + cache_max_entry_count=args.lmdeploy_cache_max_entry_count, + reload_weights=True) + self.infer_device = fast_infer_device + from lmdeploy.turbomind.turbomind import TurboMind + lmdeploy_engine = self.engine.engine.engine + assert isinstance(lmdeploy_engine, TurboMind), ( + "Currently only LMDeploy's TurboMind backend is supported. " + 'The current model is incompatible - please use vLLM or PyTorch backend instead.') + if not self.is_external_vllm: + self.engine.default_template = copy(self.template) # Avoid thread-unsafe modifications of the mode. + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + from swift.llm import PtEngine + self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit + # Avoid thread-unsafe modifications of the mode. + self.request_config = RequestConfig( + max_tokens=args.max_completion_length, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + stop=args.stop_words, + ) + + if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1: + self.request_config.n = self.args.tensor_parallel_size + if self.infer_rank >= 0: + self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size + + self.model_accepts_loss_kwargs = False + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True) + + # Multi-step + self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + + # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + if self.args.async_generate: + self.add_callback(GRPOCallback(self)) + + if self.args.dynamic_sample: + self.resample_dataset = deepcopy(self.train_dataset) + + def cyclic_iter(iterable): + while True: + for x in iterable: + yield x + + self.resample_iterator = cyclic_iter(self.get_resample_dataloader()) + # flag indicating whether the evaluation has started + self.eval_flag = False + + @profiling_decorator + def _prepare_inputs( + self, accumulated_local_batch: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + mode = 'train' if self.model.training else 'eval' + if mode == 'train': + generate_every = self.args.gradient_accumulation_steps * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch) + self._buffered_inputs = accumulated_local_batch # < this is the change + inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] + self._step += 1 + else: + inputs = self._generate_and_score_completions(accumulated_local_batch) + return inputs + + def split_batches(self): + """Sync weights in batches + Only split LLM layers for now: + 1. N batches for layers + 2. other, embeds, lm_heads in one batch + 3. multi-modal components in one batch + """ + model = self.accelerator.unwrap_model(self.model) + if self.args.move_model_batches is None: + # All in one + return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None] + + model_arch = get_model_arch(model.model_meta.model_arch) + non_llm_parameters = [] + llm_embeds = [] + parameters = [] + pattern = r'\.(\d+)\.' + + layer_count = None + # Get the number of layers in LLM modules + for name, module in model.named_modules(): + if isinstance(module, ModuleList): + if model_arch is not None and isinstance(model_arch, MultiModelKeys): + llm = model_arch.language_model + vision_tower = model_arch.vision_tower + if any(vt in name for vt in vision_tower): + continue + if isinstance(llm, list): + llm = llm[0] + if name.startswith('base_model'): + name = name.replace('base_model.', '') + if llm in name: + layer_count = len(module) + else: + layer_count = len(module) + assert layer_count is not None, 'Cannot find ModuleList to split modules.' + + n_layers = ceil(layer_count / self.args.move_model_batches) + for _ in range(self.args.move_model_batches): + parameters.append([]) + + def replace_lora(name): + if 'lora_' in name: + return '' + else: + return name.replace('base_layer.', '') + + def remove_lora_and_prefix(names): + names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names]) + return [n for n in names if n] + + def split_llm(name): + match = re.search(pattern, name) + if match: + number = match.group(1) + group = int(number) // n_layers + parameters[group].append(name) + else: + llm_embeds.append(name) + + for name, parameter in model.named_parameters(): + if 'ref_model' in name: + continue + if model_arch is not None and isinstance(model_arch, MultiModelKeys): + llm = model_arch.language_model + vision_tower = model_arch.vision_tower + if any(vt in name for vt in vision_tower): + non_llm_parameters.append(name) + elif isinstance(llm, list): + llm = llm[0] + if llm in name: + split_llm(name) + else: + non_llm_parameters.append(name) + else: + split_llm(name) + + if llm_embeds: + parameters.append(llm_embeds) + if non_llm_parameters: + parameters.append(non_llm_parameters) + parameters = [p for p in parameters if p] + parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters] + return parameters, parameters_no_lora + + def prepare_vllm(self, model, fast_infer_device): + from swift.tuners import Swift + from swift.llm import VllmEngine + from swift.llm.infer.infer_engine import GRPOVllmEngine + _, _, _, local_world_size = get_dist_setting() + if self.args.tensor_parallel_size > 1: + vllm_kwargs = {'distributed_executor_backend': 'external_launcher'} + else: + vllm_kwargs = {} + if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1: + # Compatibility with TP + cls = GRPOVllmEngine + engine_kwargs = {'seed': 0} + else: + cls = VllmEngine + engine_kwargs = {} + with Swift.grpo_context(model, self.template.processor): + engine = cls( + model.model_dir, + model.model_info.torch_dtype, + model_type=model.model_meta.model_type, + device=fast_infer_device[self.local_infer_rank], + tensor_parallel_size=self.args.tensor_parallel_size, + gpu_memory_utilization=self.args.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=self.args.vllm_max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + num_infer_workers=self.args.num_infer_workers, + enable_sleep_mode=self.args.sleep_level > 0, + use_async_engine=False, + max_model_len=self.args.vllm_max_model_len, + engine_kwargs=engine_kwargs, + **vllm_kwargs) + engine.default_template = self.template + return engine + + @property + def infer_rank(self): + if self.is_external_vllm: + # When using external vLLM, only the main process (rank=0) acts as the client. + return 0 if self.accelerator.is_main_process else -1 + rank, local_rank, world_size, local_world_size = get_dist_setting() + node_rank = get_node_setting()[0] + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank: + return node_rank * self.args.num_infer_workers + _vllm_rank + if local_rank == -1: + return 0 + return -1 + + @property + def infer_rank_tp_0(self): + # whether is tp rank0, get data from this rank + # vllm needs all tp ranks inputs and sampling params are the same + rank, local_rank, world_size, local_world_size = get_dist_setting() + node_rank = get_node_setting()[0] + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank and _vllm_rank % self.args.tensor_parallel_size == 0: + return (node_rank * self.args.num_infer_workers + _vllm_rank // self.args.tensor_parallel_size) + if local_rank == -1: + return 0 + return -1 + + @property + def local_infer_rank(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + for _vllm_rank in range(self.args.num_infer_workers): + if local_rank == _vllm_rank: + return _vllm_rank + + return -1 + + def tp_group_ranks(self): + rank, local_rank, world_size, local_world_size = get_dist_setting() + return [ + list(range(0, world_size))[i:i + self.args.tensor_parallel_size] + for i in range(0, world_size, self.args.tensor_parallel_size) + ] + + @contextmanager + def _template_context(self, template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + mode = template.mode + if mode in {'vllm', 'pt', 'lmdeploy'}: + template.set_mode('train') + template.max_length = None + loss_scale = template.loss_scale + if self.multi_turn_func: + template.loss_scale = 'default' + try: + yield + finally: + template.loss_scale = loss_scale + template.set_mode(mode) + template.max_length = max_length + + @profiling_decorator + def _move_model_to_vllm_lmdeploy(self): + if self.is_external_vllm: + return super()._move_model_to_vllm() + + from accelerate.utils.other import is_compiled_module + + for i, parameter_group in enumerate(self.parameter_groups): + parameter_group_no_lora = self.parameter_groups_no_lora[i] + with unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + gather_parameters=parameter_group) as unwrapped_model: + + if is_compiled_module(unwrapped_model): + unwrapped_model = unwrapped_model._orig_mod + if is_peft_model(unwrapped_model): + with patch_lora_merge(unwrapped_model, parameter_group): + unwrapped_model.merge_adapter() + state_dict = unwrapped_model.state_dict() + # Remove base_model and base_layer prefixes + state_dict = { + k.removeprefix('base_model.model.').replace('.base_layer', ''): v + for k, v in state_dict.items() + } + # Remove values with adapter prefix (example: "_lora") + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} + # When module to save, remove its prefix and discard the original module + state_dict = { + k.replace('modules_to_save.default.', ''): v + for k, v in state_dict.items() if 'original_module' not in k + } + else: + state_dict = unwrapped_model.state_dict() + if parameter_group_no_lora: + parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora] + state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora} + assert len(state_dict) > 0 and all([state.shape != torch.Size([0]) for state in state_dict.values()]) + if self.infer_rank >= 0: + if self.args.async_generate: + self._wait_queue() + if self.args.use_vllm: + llm_model = self.engine.inner_model + else: + llm_model = self.engine.engine.engine + llm_model.load_weights(state_dict.items()) + del state_dict + gc_collect() + # Unmerge the adapter to restore the model to its original state. + # This must be done after loading weights to ensure they correspond to the merged state. + if is_peft_model(unwrapped_model): + with patch_lora_unmerge(unwrapped_model): + unwrapped_model.unmerge_adapter() + + if self.infer_rank >= 0 and self.args.use_vllm and self.args.vllm_enable_prefix_caching: + self.engine.engine.reset_prefix_cache() + + def _wait_queue(self): + while self._queue.empty(): + time.sleep(0.01) + + @staticmethod + def reorder_outputs(outputs, distributed_idx): + index_to_output = {} + current_position = 0 + for output_idx in distributed_idx: + for idx in output_idx: + index_to_output[idx] = outputs[current_position] + current_position += 1 + + return [index_to_output[idx] for idx in sorted(index_to_output.keys())] + + def _infer_multi_turn(self, inputs_slice: np.ndarray, request_config: RequestConfig) -> Union[OutputsType, List]: + """Perform multi-turn or single-turn inference with support for tensor parallelism. + + Args: + inputs_slice: Array of input requests + request_config: Inference configuration parameters + + Returns: + List of outputs where each entry contains: + - List of responses per prompt (length = tensor_parallel_size) + - Each response is a tuple of (message_history, finish_reason) + """ + from swift.llm.infer.protocol import ChatCompletionResponse + rank, _, _, _ = get_dist_setting() + request_config = copy(request_config) + results: List[ChatCompletionResponse] = self._engine_infer( + infer_requests=inputs_slice, request_config=request_config, use_tqdm=False) + prompt_lens = len(inputs_slice) + messages_list = [None] * (len(inputs_slice) * self.args.tensor_parallel_size) + if self.multi_turn_func: + remove_response = True + while len(inputs_slice) > 0: + request_config.n = 1 + if self.infer_rank_tp_0 >= 0 or not self.use_fast_infer: + inputs = [] + cnt = 0 + for i, output in enumerate(results): + for choice in output.choices: + _input: Dict = deepcopy(inputs_slice[i]) + if remove_response or _input['messages'][-1]['role'] != 'assistant' or not \ + _input['messages'][-1]['content']: + InferRequest.remove_response(_input['messages']) + _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) + else: + _input['messages'][-1]['content'] += choice.message.content + if 'index' not in _input: + _input['index'] = cnt + _input['finish_reason'] = choice.finish_reason + cnt += 1 + inputs.append(_input) + results: List[Dict] = self.multi_turn_func(inputs) # noqa + else: + length = sum([len(results[i].choices) for i in range(len(results))]) + results = [None] * length + + if self.args.tensor_parallel_size > 1: + # avoid duplicate calling in the same tensor parallel group + import torch.distributed as dist + if 'group_src' in inspect.signature(dist.broadcast_object_list).parameters: + dist.broadcast_object_list(results, group_src=0, group=self.group) + else: + global_src = dist.get_global_rank(self.group, 0) + dist.broadcast_object_list(results, src=global_src, group=self.group) + inputs_slice = [r for r in results if not r['finished']] + for idx, r in enumerate(results): + if r['finished'] or r['finish_reason'] == 'length': + messages_list[r['index']] = (r['messages'], r['finish_reason']) + if len(inputs_slice) > 0: + _input_std = [] + for _input in inputs_slice: + _input_std.append(StdTemplateInputs.from_dict(_input)) + # StdTemplateInputs will not remove responses in infer + results = self._engine_infer( + infer_requests=_input_std, request_config=request_config, use_tqdm=False) + # concat responses from the second loop + remove_response = False + + outputs = [] + assert not any([m is None for m in messages_list]) + for i in range(0, len(messages_list), self.args.tensor_parallel_size): + # reformat to [[x, x, x, x] [x, x, x, x]] + # this is the same format of sampling_params.n > 1 + outputs.append(messages_list[i:i + self.args.tensor_parallel_size]) + assert len(outputs) == prompt_lens + assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) + else: + # single turn + outputs = [] + for i, output in enumerate(results): + _choices = [] + for choice in output.choices: + _input: Dict = deepcopy(inputs_slice[i]) + InferRequest.remove_response(_input['messages']) + _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) + _choices.append((_input['messages'], choice.finish_reason)) + outputs.append(_choices) + assert len(outputs) == prompt_lens + assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) + + if self.args.tensor_parallel_size > 1: + if self.infer_rank_tp_0 < 0: + outputs = [] + else: + _outputs = [] + for tp_idx in range(self.args.tensor_parallel_size): + for prompt_idx in range(len(outputs)): + _outputs.append(outputs[prompt_idx][tp_idx]) + outputs = [_outputs] + + return outputs + + def async_infer(self, inputs, inputs_slice, distributed_idx): + + def infer_task(): + with set_device_context(self.infer_device), self.multi_turn_completion_length_context(): + return self._infer_multi_turn(inputs_slice, self.request_config) + + future: Future = self.executor.submit(infer_task) + # pre-fetch the queue to avoid switching back to eval_queue at the end of training sample sampling + current_queue = self._queue + + def done(_self): + current_queue.put(DataCache(inputs, _self.result(), distributed_idx)) + + future.add_done_callback(done) + + def _prefetch(self, dataloader: DataLoader): + inputs = next(iter(dataloader)) + all_inputs = gather_object(inputs) + nnodes = get_node_setting()[1] + distributed_idx = round_robin(len(all_inputs), nnodes * self.args.num_infer_workers) + if self.infer_rank >= 0: + _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]] + with self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(_input_slice, self.request_config) + self._queue.put(DataCache(inputs, outputs, distributed_idx)) + else: + self._queue.put(DataCache(inputs, [], distributed_idx)) + if self.accelerator.num_processes > 1: + self.accelerator.wait_for_everyone() + + def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: + """ + This function performs fast inference by managing model and optimizer offloading, + loading weights if necessary, distributing inputs among workers, and generating + completions using the vLLM/LMDeploy framework. It supports both synchronous and asynchronous + inference modes. + inputs: local inputs + """ + + if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0: + if self.args.offload_model: + self.offload_model() + if self.args.offload_optimizer: + self.offload_optimizer() + if self.args.gc_collect_after_offload: + gc_collect() + # Skip the first wake_up to avoid the warning "Executor is not sleeping" + if self.engine.inner_model_executor.is_sleeping: + self.engine.engine.wake_up() + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm_lmdeploy() + self._last_loaded_step = self.state.global_step + all_inputs = gather_object(inputs) + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + # Distribute inputs to different workers + # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker + # 1/3/5 dispatch to the second worker + # trying to shuffle and average the length + nnodes = get_node_setting()[1] + num_workers = 1 if self.is_external_vllm else nnodes + distributed_idx = round_robin(len(all_inputs), num_workers * self.args.num_infer_workers) + if self.infer_rank >= 0: + _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]] + if self.args.async_generate: + self.async_infer(inputs, _input_slice, distributed_idx) + data_cache = self._queue.get() + inputs = data_cache.inputs + outputs = data_cache.outputs + distributed_idx = data_cache.distributed_idx + else: + with set_device_context(self.infer_device): + request_config = copy(self.request_config) + if self.args.tensor_parallel_size > 1: + request_config.seed += self.state.global_step + with self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(_input_slice, self.request_config) + else: + if self.args.async_generate: + # using old model to generate, which will ignore the `clip` of advantages. + self._queue.put(DataCache(inputs, [], distributed_idx)) + data_cache = self._queue.get() + inputs = data_cache.inputs + distributed_idx = data_cache.distributed_idx + outputs = [] + outputs = gather_object(outputs) + if self.args.tensor_parallel_size > 1: + outputs = [[item] for output in outputs for item in output] + if not self.is_external_vllm: + outputs = self.reorder_outputs(outputs, distributed_idx) + if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0: + self.engine.engine.sleep(level=self.args.sleep_level) + if self.args.gc_collect_after_offload: + gc_collect() + if self.args.offload_model: + self.load_model() + if self.args.offload_optimizer: + self.load_optimizer() + return inputs, outputs + + def _generate_completions(self, inputs: InputsType) -> InputsType: + """Generate completions for given inputs using either fast inference or standard PyTorch inference. + + Args: + inputs: List of input examples containing conversation messages. + + Returns: + Modified inputs with generated completions added to the last message + and truncation flag set in 'is_truncated' field. + """ + mode = 'train' if self.model.training else 'eval' + if self.use_fast_infer: + inputs, outputs = self._fast_infer(inputs) + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + outputs = outputs[process_slice] + else: + # pt infer + is_multimodal = self.model.model_meta.is_multimodal + if is_multimodal: + models = self.template.remove_post_encode_hook() + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ), self.multi_turn_completion_length_context(): + outputs = self._infer_multi_turn(inputs, self.request_config) + if mode == 'train': + # In training mode, ensure the model is returned to train() mode after inference + # This is necessary as pt engines set the model to eval mode during generation + self.model.train() + if is_multimodal: + self.template.register_post_encode_hook(models) + if isinstance(outputs[0][0], list): + outputs = [output[0] for output in outputs] + + for i, output in enumerate(outputs): + inputs[i]['messages'] = output[0][0] + inputs[i]['is_truncated'] = output[0][1] == 'length' + + return inputs + + def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: + + inputs = self._generate_completions(inputs) + total_rewards_per_func, total_rewards, completions = self._score_completions(inputs) + mode = 'train' if self.model.training else 'eval' + + if self.args.dynamic_sample and mode == 'train': + # dynamic sampling for std=0 groups + inputs, total_rewards, total_rewards_per_func, completions = \ + self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions) + + # Prepare final outputs with advantages and other required fields + batch_encoded_inputs = self._prepare_batch_inputs(inputs, total_rewards) + # Log metrics + messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))] + + self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func) + + return batch_encoded_inputs + + def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + """Score completions using all reward functions + + Args: + inputs: List of input examples, each containing a 'messages' list with conversation history + + Returns: + Tuple containing: + - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards + - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards + - completions: List of generated completion strings + """ + device = self.accelerator.device + completions = [example['messages'][-1]['content'] for example in inputs] + rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) + + for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)): + # reward model + if isinstance(reward_func, nn.Module): + rewards_per_func[:, i] = reward_model_plugin(inputs=inputs) + # reward function + else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs = RowPreprocessor.rows_to_batched(inputs) + output_reward_func = reward_func(completions, **reward_kwargs) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + total_rewards_per_func = gather(rewards_per_func) + total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) + + return total_rewards_per_func, total_rewards, completions + + def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): + # DAPO https://arxiv.org/abs/2503.14476 + # Replaces samples with zero-reward-variance groups (std=0) + resample_count = 0 + valid_samples = [] + valid_rewards = [] + valid_rewards_per_func = [] + valid_completions = [] + + origin_data = (inputs, rewards, rewards_per_func, completions) + + while resample_count < self.args.max_resample_times: + grouped_rewards = rewards.view(-1, self.num_generations) + group_std = grouped_rewards.std(dim=1) + + valid_mask = (group_std > 0).repeat_interleave(self.num_generations) + all_inputs = gather_object(inputs) + valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask]) + valid_rewards.append(rewards[valid_mask]) + valid_rewards_per_func.append(rewards_per_func[valid_mask]) + valid_completions.extend( + [inp['messages'][-1]['content'] for inp, mask in zip(all_inputs, valid_mask) if mask]) + + if len(valid_samples) >= self.effective_train_batch_size: + break + + inputs = next(self.resample_iterator) + inputs = Trainer._prepare_inputs(self, inputs) + inputs = self._generate_completions(inputs) + rewards_per_func, rewards, completions = self._score_completions(inputs) + resample_count += 1 + + if len(valid_samples) >= self.effective_train_batch_size: + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + inputs = valid_samples[:self.effective_train_batch_size][process_slice] + rewards = torch.cat(valid_rewards)[:self.effective_train_batch_size] + rewards_per_func = torch.cat(valid_rewards_per_func)[:self.effective_train_batch_size] + completions = valid_completions[:self.effective_train_batch_size][process_slice] + else: + logger.warning(f'There are still std=0 groups present after {self.args.max_resample_times} retries.') + inputs, rewards, rewards_per_func, completions = origin_data + + return inputs, rewards, rewards_per_func, completions + + def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]: + """ + Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + + Args: + inputs (InputsType): List of input samples. Original shape is [gas*bs] where: + - gas: gradient accumulation steps + - bs: per-device batch size + rewards (torch.Tensor): Tensor of rewards corresponding to the inputs. + Shape should match the total number of samples (gas*bs*num_generations) + + Returns: + List[InputsType]: A list of prepared batch inputs, organized as [gas][bs] + """ + # Compute advantages + grouped_rewards = rewards.view(-1, self.num_generations) + mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) + advantages = (rewards - mean_grouped_rewards) + if self.args.scale_rewards: + advantages /= (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(inputs), + (self.accelerator.process_index + 1) * len(inputs), + ) + advantages = advantages[process_slice] + + mode = 'train' if self.model.training else 'eval' + bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + gas = self.args.gradient_accumulation_steps if mode == 'train' else 1 + + assert len(inputs) == bs * gas, f'Expected {bs * gas} inputs, got {len(inputs)}' + gas_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(gas)] + + ga_batch_encoded_inputs = [] + template = self.template + + # Split advantages by GAS chunks + advantage_chunks = torch.chunk(advantages, gas) + + for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): + # Encode and process each batch (size=bs) + with self._template_context(template): + batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch] + batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) + + # Process labels and masks + labels = batch_encoded_inputs.pop('labels') + logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + batch_encoded_inputs.update({ + 'completion_mask': + labels[:, -logits_to_keep:] != -100, + 'truncated_mask': + torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool), + 'logits_to_keep': + logits_to_keep, + 'advantages': + batch_advantages + }) + + with torch.no_grad(): + batch_encoded_inputs['old_per_token_logps'] = ( + self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None) + + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs) + batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps + + ga_batch_encoded_inputs.append(batch_encoded_inputs) + + return ga_batch_encoded_inputs + + def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func): + """Log training/evaluation metrics""" + mode = 'train' if self.model.training else 'eval' + device = self.accelerator.device + + # Calculate completion length metrics + agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs])) + + self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) + self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item()) + self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item()) + # Calculate clip ratio + agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device)) + + term_completion_mask = agg_completion_mask[agg_truncated_mask] + clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) + + self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio) + + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = rewards_per_func[:, i].mean().item() + self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards) + std_rewards = rewards_per_func[:, i].std().item() + self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards) + + # Log overall reward stats + grouped_rewards = rewards.view(-1, self.num_generations) + self._metrics[mode]['reward'].append(grouped_rewards.mean().item()) + self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item()) + + # Log prompt and completion texts + self._textual_logs['prompt'].extend(gather_object(messages)) + self._textual_logs['completion'].extend(gather_object(completions)) + for i, name in enumerate(self.reward_func_names): + self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist()) + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training + if isinstance(inputs, list): + assert len(inputs) == 1 + inputs = inputs[0] + completion_mask = inputs['completion_mask'] + truncated_mask = inputs['truncated_mask'] + # apply the completion_mask to exclude loss and metrics for overlong completions + if self.args.overlong_filter and any(truncated_mask): + if all(truncated_mask): + logger.info('All completions are overlong, loss and KL will be zero') + truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device) + completion_mask = completion_mask * (~truncated_mask) + + per_token_logps = self._get_per_token_logps(model, inputs) + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs['ref_per_token_logps'] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + advantages = inputs['advantages'] + old_per_token_logps = inputs['old_per_token_logps'] if self.old_policy else per_token_logps.detach() + coef_1 = torch.exp(per_token_logps - old_per_token_logps) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') + + # Log the metrics + mode = 'train' if self.model.training else 'eval' + + if self.beta != 0.0: + mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum() + self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum() + high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum() + clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum() + + gathered_low_clip = self.accelerator.gather_for_metrics(low_clip) + self._metrics[mode]['clip_ratio/low_mean'].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]['clip_ratio/low_min'].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather_for_metrics(high_clip) + self._metrics[mode]['clip_ratio/high_mean'].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]['clip_ratio/high_max'].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio) + self._metrics[mode]['clip_ratio/region_mean'].append(gathered_clip_ratio.nanmean().item()) + + return loss + + # Get the per-token log probabilities for the completions for the model and the reference model + @profiling_decorator + def _get_per_token_logps(self, model, inputs): + from trl.trainer.utils import selective_log_softmax + logits_to_keep = inputs['logits_to_keep'] + input_ids = inputs['input_ids'] + unwrapped_model = self.accelerator.unwrap_model(model) + if is_peft_model(unwrapped_model): + parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters + else: + parameters = inspect.signature(unwrapped_model.forward).parameters + if not unwrapped_model.model_meta.is_multimodal and 'logits_to_keep' in parameters: + # save memory + return super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep) + inputs = { + k: v + for k, v in inputs.items() if k not in [ + 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', + 'truncated_mask' + ] + } + with self._template_context(self.template): + logits = model(**inputs).logits + # exclude the last logit: it corresponds to the next token pred + logits = logits[:, -(logits_to_keep + 1):-1, :] + logits = logits / self.temperature + input_ids = input_ids[:, -logits_to_keep:] + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + + def evaluation_loop(self, dataloader, *args, **kwargs): + # Wait for the training rollout to complete + if self.args.async_generate: + while not self.is_async_generate_eval_rollout_done(): + time.sleep(0.1) + if self._queue.empty() and self.args.async_generate: + self._prefetch(dataloader) + metric_key_prefix = kwargs['metric_key_prefix'] + output = super().evaluation_loop(dataloader, *args, **kwargs) + metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()} + output.metrics.update(metrics) + self.eval_flag = True + return output + + def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch=None) -> torch.Tensor: + if self.args.async_generate: + # Wait for the eval rollout to complete + while not self.is_async_generate_eval_rollout_done(): + time.sleep(0.1) + return super().training_step(model, inputs, num_items_in_batch) + + def _engine_infer( + self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + *, + use_tqdm: Optional[bool] = None, + ): + if self.is_external_vllm: + self._process_infer_requests_images(infer_requests) + return self.vllm_client.infer(infer_requests.tolist(), asdict(request_config), use_tqdm=use_tqdm) + else: + return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) + + def _process_infer_requests_images(self, infer_requests: List[InferRequest]): + import base64 + if not any('images' in request for request in infer_requests): + return + for request in infer_requests: + if 'images' not in request: + continue + for i, img in enumerate(request['images']): + if 'bytes' in img and img['bytes']: + request['images'][i] = base64.b64encode(img['bytes']).decode('utf-8') + return + + @property + def old_policy(self): + return self.num_iterations > 1 + + @property + def _queue(self): + if self.control.should_evaluate: + return self.eval_queue + else: + return self.train_queue + + @torch.no_grad() + def offload_model(self): + if len(self.offload_modules) > 0: + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + for name, module in unwrapped_model.named_modules(): + if isinstance(module, torch.nn.Embedding): + self.offload_modules[name] = module.weight.device + module.to('cpu') + elif not hasattr(module, 'device'): + pass + elif module.device.type != 'cpu': + self.offload_modules[name] = module.device + module.to('cpu') + + @torch.no_grad() + def load_model(self): + if len(self.offload_modules) == 0: + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + for name, device in self.offload_modules.items(): + module = unwrapped_model.get_submodule(name) + if isinstance(module, torch.nn.Embedding): + module.weight.to(device) + else: + module.to(device) + self.offload_modules.clear() + + @torch.no_grad() + def offload_optimizer(self): + if len(self.offload_states) > 0: + return + if not self.optimizer.state: + return + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + state = self.optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + self.offload_states[key] = value.device + state[key] = value.to('cpu', non_blocking=True) + + @torch.no_grad() + def load_optimizer(self): + if len(self.offload_states) == 0: + return + if not self.optimizer.state: + return + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + state = self.optimizer.state[param] + for key, value in state.items(): + if isinstance(value, torch.Tensor): + state[key] = value.to(self.offload_states[key], non_blocking=True) + self.offload_states.clear() + + @contextmanager + def multi_turn_completion_length_context(self): + """ + Context manager that temporarily adjusts the engine's max length handling + for multi-turn generation scenarios. + + Ensures the total sequence length (prompt + completion) never exceeds: + min(original_max_len, prompt_tokens + max_completion_length) + """ + if not (self.multi_turn_func and self.infer_rank >= 0) or self.is_external_vllm: + yield + return + + original_fn = self.engine.set_default_max_tokens + original_max_len = self.engine.max_model_len + + def set_default_max_tokens(_self, request_config: RequestConfig, inputs: InputsType) -> None: + # Calculate required context window + original_max_len = _self.max_model_len or 8192 + if isinstance(inputs, dict): + inputs = [inputs] + prompt_tokens = max(_self._get_num_tokens(inp) for inp in inputs) + + if not hasattr(_self, 'set_grpo_max_model_len'): + # set max model len in first round + max_len = min(original_max_len, prompt_tokens + request_config.max_tokens) + _self.max_model_len = max_len + _self.set_grpo_max_model_len = True + else: + if _self.max_model_len <= prompt_tokens: + # modify max_model_len > prompt_tokens to avoid crash + num_tokens_avoid_crash = 10 + _self.max_model_len = (prompt_tokens + num_tokens_avoid_crash) + request_config.max_tokens = num_tokens_avoid_crash + + original_fn(request_config, inputs) + + try: + self.engine.set_default_max_tokens = MethodType(set_default_max_tokens, self.engine) + yield + finally: + self.engine.set_default_max_tokens = original_fn + self.engine.max_model_len = original_max_len + del self.engine.set_grpo_max_model_len + + def get_resample_dataloader(self) -> DataLoader: + resample_dataset = self.resample_dataset + data_collator = self.data_collator + if isinstance(resample_dataset, datasets.Dataset): + resample_dataset = self._remove_unused_columns(resample_dataset, description='training') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='training') + + dataloader_params = { + 'batch_size': self._train_batch_size * self.args.gradient_accumulation_steps, + 'collate_fn': data_collator, + 'num_workers': self.args.dataloader_num_workers, + 'pin_memory': self.args.dataloader_pin_memory, + 'persistent_workers': self.args.dataloader_persistent_workers, + } + + @contextmanager + def seed_context(self): + seed = self.args.seed + self.args.seed = seed + 1 + yield + self.args.seed = seed + + if not isinstance(resample_dataset, torch.utils.data.IterableDataset): + with seed_context(self): # Set a different seed for resampling than the train_dataset. + dataloader_params['sampler'] = self._get_train_sampler() + dataloader_params['drop_last'] = self.args.dataloader_drop_last + dataloader_params['worker_init_fn'] = seed_worker + dataloader_params['prefetch_factor'] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(resample_dataset, **dataloader_params)) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = 'train' if self.model.training else 'eval' + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == 'eval': + metrics = {f'eval_{key}': val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse('4.47.0.dev0'): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + table = { + 'step': [str(self.state.global_step)] * len(self._textual_logs['prompt']), + 'prompt': self._textual_logs['prompt'], + 'completion': self._textual_logs['completion'], + **self._textual_logs['rewards'], + } + self.jsonl_writer.append(table) + if self.args.report_to and 'wandb' in self.args.report_to and wandb.run is not None: + import pandas as pd + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=['prompt']) + wandb.log({'completions': wandb.Table(dataframe=df)}) + + def is_async_generate_eval_rollout_done(self): + return not self.eval_flag or not self.eval_queue.empty() + + def is_async_generate_train_rollout_done(self): + return not self.train_queue.empty() diff --git a/ms-swift/swift/trainers/sequence_parallel/__init__.py b/ms-swift/swift/trainers/sequence_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0602f84075025d705b8910801b030f2591e77804 --- /dev/null +++ b/ms-swift/swift/trainers/sequence_parallel/__init__.py @@ -0,0 +1,8 @@ +import os + +if os.environ.get('SEQUENCE_PARALLEL_IMPL', 'ulysses') == 'xtuner': + from .xtuner import XTuner + sequence_parallel = XTuner() +else: + from .ulysses import Ulysses + sequence_parallel = Ulysses() diff --git a/ms-swift/swift/tuners/__pycache__/base.cpython-310.pyc b/ms-swift/swift/tuners/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b84097a290c343f39a2be3f1c1e5152a083142a Binary files /dev/null and b/ms-swift/swift/tuners/__pycache__/base.cpython-310.pyc differ diff --git a/ms-swift/swift/tuners/base.py b/ms-swift/swift/tuners/base.py new file mode 100644 index 0000000000000000000000000000000000000000..fafc0883abce55d975055352ade4d9f5b3cbdd58 --- /dev/null +++ b/ms-swift/swift/tuners/base.py @@ -0,0 +1,926 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2023-present the HuggingFace Inc. team. +import os +import re +import shutil +import tempfile +from contextlib import contextmanager +from copy import copy +from functools import partial +from inspect import Parameter, Signature, signature +from types import MethodType +from typing import Dict, List, Literal, Optional, Union + +import json +import torch +from modelscope import snapshot_download +from peft.utils import CONFIG_NAME +from peft.utils.other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME +from torch import nn +from transformers import Trainer + +from swift.utils.constants import DEFAULT_ADAPTER, SWIFT_TYPE_KEY +from swift.utils.logger import get_logger +from ..utils.torch_utils import get_device_count +from .mapping import SwiftTuners +from .peft import PeftConfig, PeftModel, get_peft_model +from .utils import SwiftConfig, SwiftOutput + +logger = get_logger() + + +class SwiftModel(nn.Module): + """The Swift wrapper model. + + Args: + model (`Union[nn.Module, 'SwiftModel']`) A module to be tuned by Swift. + config (`Union[SwiftConfig, Dict[str, SwiftConfig]]`) A config or a dict of {adapter_name: SwiftConfig}. + If it's a config class, the adapter_name will be `default` + extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved. + inference_mode (bool, `optional`): Load model at inference mode, default False. + """ + + EXTRA_STATE_DIR = 'extra_states' + + def __init__(self, + model: Union[nn.Module, 'SwiftModel'], + config: Union[SwiftConfig, Dict[str, SwiftConfig]], + extra_state_keys: List[str] = None, + inference_mode: bool = False, + **kwargs): + super().__init__() + self.adapters = {} + self.active_adapters = set() + if isinstance(model, SwiftModel): + self.adapters = model.adapters + extra_state_keys = extra_state_keys or [] + extra_state_keys.extend(model.extra_state_keys) + self.active_adapters = model.active_adapters + model = model.base_model + + self.base_model = model + new_adapters = [] + if isinstance(config, SwiftConfig): + if DEFAULT_ADAPTER not in self.adapters: + all_parts = self._deactivate_all_parts() + self.adapters[DEFAULT_ADAPTER] = self._prepare_model(model, config, DEFAULT_ADAPTER) + for part in all_parts: + self.activate_adapter(part) + new_adapters.append(DEFAULT_ADAPTER) + if self.adapters[DEFAULT_ADAPTER].model is not None: + self.base_model = self.adapters[DEFAULT_ADAPTER].model + else: + logger.warn(f'Adapter {DEFAULT_ADAPTER} has been patched, skip.') + elif isinstance(config, dict): + assert (all(isinstance(c, SwiftConfig) for c in config.values())) + for adapter_name, _config in config.items(): + if adapter_name not in self.adapters: + all_parts = self._deactivate_all_parts() + self.adapters[adapter_name] = self._prepare_model(model, _config, adapter_name) + for part in all_parts: + self.activate_adapter(part) + new_adapters.append(adapter_name) + if self.adapters[adapter_name].model is not None: + self.base_model = self.adapters[adapter_name].model + else: + logger.warn(f'Adapter {adapter_name} has been patched, skip.') + + self.extra_state_keys = extra_state_keys or [] + self.has_additional_modules = any([c.config.has_additional_modules for c in self.adapters.values()]) + + def forward(self, *args, **kwargs): + return self.base_model(*args, **kwargs) + + _parameters = [Parameter('self', Parameter.POSITIONAL_ONLY)] + _parameters += list(signature(self.base_model.forward).parameters.values()) + forward.__signature__ = Signature(_parameters) + self.forward = MethodType(forward, self) + for adapter_name in new_adapters: + self.activate_adapter(adapter_name) + + if inference_mode: + self.eval() + else: + for key, output in self.adapters.items(): + if key in new_adapters: + output.mark_trainable_callback(model) + if self.extra_state_keys: + for n, p in model.named_parameters(): + if any(re.fullmatch(extra_key, n) for extra_key in self.extra_state_keys): + p.requires_grad = True + + @property + def model(self): + return self.base_model + + def _deactivate_all_parts(self): + deactivated = [] + for adapter in self.active_adapters: + output = self.adapters[adapter] + if output.config.swift_type == SwiftTuners.PART: + deactivated.append(adapter) + self.deactivate_adapter(adapter) + return deactivated + + def load_state_dict(self, state_dict, strict=True, adapter_name: str = None): + if adapter_name is not None: + output: SwiftOutput = self.adapters[adapter_name] + if getattr(output.config, 'modules_to_save', None): + for key, value in copy(state_dict).items(): + for module_name in output.config.modules_to_save: + if module_name in key: + state_dict.pop(key) + key = key.replace(module_name, f'{module_name}.modules_to_save.{adapter_name}') + break + state_dict[key] = value + + for key, value in copy(state_dict).items(): + if key.startswith('base_model.model.'): + state_dict.pop(key, None) + key = key[len('base_model.model.'):] + if f'lora_A.{adapter_name}.' not in key and 'lora_A' in key: + state_dict.pop(key, None) + key = key.replace('lora_A.', f'lora_A.{adapter_name}.') + if f'lora_B.{adapter_name}.' not in key and 'lora_B' in key: + state_dict.pop(key, None) + key = key.replace('lora_B.', f'lora_B.{adapter_name}.') + if f'lora_embedding_A.{adapter_name}.' not in key and 'lora_embedding_A' in key: + state_dict.pop(key, None) + key = key.replace('lora_embedding_A.', f'lora_embedding_A.{adapter_name}.') + if f'lora_embedding_B.{adapter_name}.' not in key and 'lora_embedding_B' in key: + state_dict.pop(key, None) + key = key.replace('lora_embedding_B.', f'lora_embedding_B.{adapter_name}.') + state_dict[key] = value + + if output.load_state_dict_callback: + state_dict = output.load_state_dict_callback(self.base_model, adapter_name, state_dict) + + incompatible_keys = self.base_model.load_state_dict(state_dict, False) + if incompatible_keys and len(incompatible_keys[1]) > 0: + logger.error(f'Load state dict with unexpected keys: {incompatible_keys[1]}') + + def state_dict(self, + *args, + destination=None, + prefix='', + keep_vars=False, + adapter_name: str = None, + peft_format: bool = False, + **kwargs): + """ + Args: + destination (`dict`, `optional`): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (`str`, `optional`): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (`bool`, `optional`): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + adapter_name (`str`, `optional`): The name of the adapter's parameters to be saved, + `None` input will save all adapters. + peft_format (`bool`, `optional`): Save with peft format (extra `base_model.model.` prefix) + **kwargs: + save_adapter(`bool`): Save adapters or not, default True + save_extra_states(`bool`): Save extra states or not, default True + Returns: + The state dict to be saved. + """ + state_dict = kwargs.get('state_dict') + if state_dict is None: + state_dict = self.base_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + state_dict = { + key[len('base_model.'):] if key.startswith('base_model.') else key: value + for key, value in state_dict.items() + } + if not self.has_additional_modules: + return state_dict + + state_dicts = {} + if kwargs.get('save_adapter', True): + for name, output in self.adapters.items(): + if (adapter_name == name or adapter_name is None) and output.config.has_additional_modules: # noqa + state_dicts.update(output.state_dict_callback(state_dict, name)) + modules_to_save_names = [ + sub_name for sub_name, _ in self.base_model.named_parameters() + if f'modules_to_save.{name}' in sub_name + ] + for module_name in modules_to_save_names: + if f'modules_to_save.{name}' in module_name: + state_dicts[module_name.replace(f'modules_to_save.{name}.', '')] = state_dict[module_name] + if kwargs.get('save_extra_states', True): + state_dicts.update({ + k: v + for k, v in state_dict.items() if any( + re.fullmatch(extra_key, k) for extra_key in self.extra_state_keys) + }) + if peft_format: + new_state_dict = {} + for key, value in state_dicts.items(): + if not key.startswith('base_model.model.'): + key = 'base_model.model.' + key + key = key.replace(f'lora_A.{adapter_name}.', 'lora_A.') + key = key.replace(f'lora_B.{adapter_name}.', 'lora_B.') + key = key.replace(f'lora_embedding_A.{adapter_name}.', 'lora_embedding_A.') + key = key.replace(f'lora_embedding_B.{adapter_name}.', 'lora_embedding_B.') + new_state_dict[key] = value + state_dicts = new_state_dict + return state_dicts + + def __getattr__(self, key: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(key) + except AttributeError: + if 'base_model' in dir(self): + return getattr(self.base_model, key) + raise + + @staticmethod + def load_state_file(path, device: Optional[str] = None): + """Load a state dict file by the input path. + + Args: + path: The local dir to load the state file. + + Returns: + The state dict. + """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + from safetensors.torch import load_file as safe_load_file + return safe_load_file(filename, device=device) + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + return torch.load(filename, map_location=device) + return None + + def create_optimizer_param_groups(self, **defaults): + all_param_names = set() + param_groups = [] + for output in self.adapters.values(): + if output.optimizer_group_callback: + param_names, param_group = output.optimizer_group_callback(self.model, **defaults) + if param_names and all_param_names & param_names: + raise ValueError('Cannot set one parameter to different param groups') + if param_names and param_group: + all_param_names.update(param_names) + param_groups.extend(param_group) + + decay_parameters = Trainer.get_decay_parameter_names(None, self.model) + param_groups.extend([ + { + 'params': [ + p for n, p in self.model.named_parameters() + if (n in decay_parameters and n not in all_param_names and p.requires_grad) + ], + 'weight_decay': + defaults['weight_decay'], + }, + { + 'params': [ + p for n, p in self.model.named_parameters() + if (n not in decay_parameters and n not in all_param_names and p.requires_grad) + ], + 'weight_decay': + 0.0, + }, + ]) + + return param_groups + + @classmethod + def from_pretrained(cls, + model: Union[nn.Module, 'SwiftModel'], + model_id: str = None, + adapter_name: Union[str, List[str], Dict[str, str]] = None, + inference_mode: bool = True, + revision: str = None, + **kwargs): + """Load a set of tuners and corresponding weights by a model_id. + + Args: + model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned, + if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped.. + model_id (`str`): The model_id or a local model dir of tuners to use to tune the model. + adapter_name (`Union[str, List[str], Dict[str, str]]`): The adapter_names saved in the model repo to load. + Default `None`, means load all tuners saved in the model_id + inference_mode (`bool`): Use in the inference mode or not. + revision (`str`): The model revision to use. + **kwargs: + extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved. + Other parameters will be passed to the device_map. + Returns: + The `SwiftModel` instance. + """ + adapters = {} + model_dir = model_id + if not os.path.exists(model_dir): + model_dir = snapshot_download(model_dir, revision=revision) + if os.path.isfile(model_dir): + raise ValueError(f'Please pass in a local dir or a model id, not a local file: {model_dir}') + extra_state_keys = kwargs.pop('extra_state_keys', None) + if extra_state_keys is None and os.path.isfile(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME)): + with open(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME), 'r', encoding='utf-8') as file: + _json = json.load(file) + extra_state_keys = _json.get('extra_state_keys') + if adapter_name is None: + adapter_name = [ + sub_dir for sub_dir in os.listdir(model_dir) + if os.path.isfile(os.path.join(model_dir, sub_dir, CONFIG_NAME)) and sub_dir != cls.EXTRA_STATE_DIR + ] + for _name in adapter_name if isinstance(adapter_name, + list) else [adapter_name] \ + if isinstance(adapter_name, str) else adapter_name.keys(): + sub_folder = os.path.join(model_dir, _name) + config_file = os.path.join(sub_folder, CONFIG_NAME) + + if not os.path.isfile(config_file): + logger.warning(f'{_name} is not a valid tuner') + continue + + with open(config_file, 'r', encoding='utf-8') as file: + json_object = json.load(file) + + if SWIFT_TYPE_KEY not in json_object: + raise ValueError('Mixed using with peft is not allowed now.') + else: + key = _name if not isinstance(adapter_name, dict) else adapter_name[_name] + adapters[key] = SwiftConfig.from_pretrained(sub_folder) + + self = SwiftModel(model, adapters, extra_state_keys, inference_mode, **kwargs) + for _name in adapter_name if isinstance(adapter_name, + list) else [adapter_name] \ + if isinstance(adapter_name, str) else adapter_name.keys(): + _adapter = _name if not isinstance(adapter_name, dict) else adapter_name[_name] + output: SwiftOutput = self.adapters[_adapter] + sub_folder = os.path.join(model_dir, _name) + if output.load_callback: + output.load_callback(self, sub_folder, _adapter) + continue + state_dict = cls.load_state_file(sub_folder) + if state_dict is not None: + if isinstance(adapter_name, dict): + # TODO this logic is fragile! replace `_name` may cause other parts replaced + state_dict = {key.replace(_name, adapter_name[_name]): value for key, value in state_dict.items()} + self.load_state_dict(state_dict, adapter_name=_adapter) + state_dict = cls.load_state_file(os.path.join(model_dir, self.EXTRA_STATE_DIR)) + if state_dict is not None: + self.load_state_dict(state_dict) + return self + + @classmethod + def _prepare_model( + cls, + model: nn.Module, + config: SwiftConfig, + adapter_name: str, + ): + assert (hasattr(config, SWIFT_TYPE_KEY)) + from .mapping import SWIFT_MAPPING + + adapter_cls = SWIFT_MAPPING[config.swift_type][1] + if adapter_cls.has_additional_modules() and not getattr(model, 'model_frozen', False): + for _, p in model.named_parameters(): + p.requires_grad = False + model.model_frozen = True + config.has_additional_modules = adapter_cls.has_additional_modules() + return adapter_cls.prepare_model(model, config, adapter_name) + + def create_or_update_model_card(self, output_dir: str): + """ + Updates or create the model card. + """ + if not os.path.exists(os.path.join(output_dir, 'README.md')): + lines = [] + else: + with open(os.path.join(output_dir, 'README.md'), 'r', encoding='utf-8') as f: + lines = f.readlines() + + quantization_config = None + if hasattr(self.base_model, 'config') and hasattr(self.base_model.config, 'quantization_config'): + if hasattr(self.base_model.config.quantization_config, 'to_dict'): + quantization_config = self.base_model.config.quantization_config.to_dict() + training_config_text = '' + # Adds quantization information if it was used + if quantization_config is not None: + training_config_text += '\nThe following `bitsandbytes` quantization config was used during training:\n' + training_config_text += '\n'.join([f'- {name}: {value}' for name, value in quantization_config.items()]) + training_config_text += '\n' + + training_procedure_heading = '## Training procedure\n' + if training_procedure_heading in lines: + lines.insert(lines.index(training_procedure_heading) + 2, training_config_text) + else: + lines.append(f'{training_procedure_heading}\n{training_config_text}') + + framework_block_heading = '### Framework versions\n' + from swift.version import __version__ + if framework_block_heading in lines: + lines.insert(lines.index(framework_block_heading) + 2, f'- SWIFT {__version__}\n') + else: + lines.append(f'{framework_block_heading}\n\n- SWIFT {__version__}\n') + + base_model_heading = '### Base model information\n' + lines.append(f'{base_model_heading}\n\n- BaseModel Class {self.base_model.__class__.__name__}\n') + + # write the lines back to README.md + with open(os.path.join(output_dir, 'README.md'), 'w', encoding='utf-8') as f: + f.writelines(lines) + + def add_weighted_adapter( + self, + adapters, + weights, + adapter_name, + combination_type='svd', + svd_rank=None, + svd_clamp=None, + svd_full_matrices=True, + svd_driver=None, + density=None, + majority_sign_method: Literal['total', 'frequency'] = 'total', + ): + """ + This method adds a new adapter by merging the given adapters with the given weights. + + When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to + the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM + errors. + + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. + combination_type (`str`): + The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, + `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat` + combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the + mixed adapter may be too big and result in OOM errors). + svd_rank (`int`, *optional*): + Rank of output adapter for svd. If None provided, will use max rank of merging adapters. + svd_clamp (`float`, *optional*): + A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform + clamping. Defaults to None. + svd_full_matrices (`bool`, *optional*): + Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned + tensors U and Vh. Defaults to True. + svd_driver (`str`, *optional*): + Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be + one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd` + documentation. Defaults to None. + density (`float`, *optional*): + Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used + with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, + `magnintude_prune`, `magnitude_prune_svd`] + majority_sign_method (`str`): + The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values. + Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`] + """ + from swift.tuners.lora import LoraModel + lora_model = LoraModel(self.model, None, '') + lora_model.peft_config = {key: value.config for key, value in self.adapters.items()} + from peft.tuners.lora import LoraLayer + lora_model.targeted_module_names = [ + key for key, value in self.model.named_modules() if isinstance(value, LoraLayer) + ] + lora_model.active_adapter = self.active_adapters + lora_model.add_weighted_adapter( + adapters=adapters, + weights=weights, + adapter_name=adapter_name, + combination_type=combination_type, + svd_rank=svd_rank, + svd_clamp=svd_clamp, + svd_full_matrices=svd_full_matrices, + svd_driver=svd_driver, + density=density, + majority_sign_method=majority_sign_method, + ) + + def state_dict_callback(state_dict, adapter_name, cfg): + from swift.tuners.lora_layers import lora_state_dict + return lora_state_dict(state_dict, adapter_name, cfg.bias) + + def mark_trainable_callback(model, cfg): + from swift.tuners.lora_layers import mark_lora_as_trainable + mark_lora_as_trainable(model, adapter_name, cfg.bias) + + cfg = lora_model.peft_config[adapter_name] + cfg.has_additional_modules = True + self.adapters[adapter_name] = SwiftOutput( + config=cfg, + state_dict_callback=partial(state_dict_callback, cfg=cfg), + mark_trainable_callback=partial(mark_trainable_callback, cfg=cfg), + optimizer_group_callback=None, + ) + + self.set_active_adapters(adapter_name) + + def save_pretrained(self, + save_directory: str, + safe_serialization: bool = False, + adapter_name: Union[str, List[str]] = None, + **kwargs): + """Save the adapters to a local directory. + + Args: + save_directory (`str`): The directory to use. + safe_serialization (`bool`): Use safe tensors to save the weights, default False. + adapter_name(`Union[str, List[str]]`): The adapters to be saved, default is `None` to save all. + """ + peft_format = kwargs.pop('peft_format', False) + if os.path.isfile(save_directory): + raise ValueError(f'Provided path ({save_directory}) should be a directory, not a file') + os.makedirs(save_directory, exist_ok=True) + if not self.has_additional_modules: + if hasattr(self.base_model, 'save_pretrained'): + self.base_model.save_pretrained(save_directory, safe_serialization=safe_serialization) + else: + self._save_state_dict(self.base_model.state_dict(), save_directory, safe_serialization) + self.create_or_update_model_card(save_directory) + else: + self.create_or_update_model_card(save_directory) + + adapter_names = adapter_name if isinstance(adapter_name, list) or adapter_name is None else [adapter_name] + + state_dict_kwargs = {} + state_dict = kwargs.get('state_dict') + if state_dict is not None: + state_dict_kwargs['state_dict'] = kwargs['state_dict'] + for adapter_name, output in self.adapters.items(): + if adapter_names is not None and adapter_name not in adapter_names: + continue + + save_to_peft = peft_format and output.config.swift_type == SwiftTuners.LORA + save_to_peft = save_to_peft and output.config.can_be_saved_to_peft() + if peft_format and not save_to_peft: + logger.error('You are using additional lora parameters, which is not compatible with peft,' + 'which is unable to save to peft format.') + output_dir = os.path.join(save_directory, + adapter_name) if adapter_name != 'default' or not save_to_peft else save_directory + + if save_to_peft: + config = output.config.to_peft_config() + config.save_pretrained(output_dir) + else: + output.config.save_pretrained(output_dir) + + if output.save_callback: + output.save_callback(self, output_dir, adapter_name) + continue + + # save only the trainable weights + output_state_dict = self.state_dict( + adapter_name=adapter_name, save_extra_states=False, peft_format=save_to_peft, **state_dict_kwargs) + os.makedirs(output_dir, exist_ok=True) + if output_state_dict and output.config.has_additional_modules: + self._save_state_dict(output_state_dict, output_dir, safe_serialization) + + output_state_dict = self.state_dict(save_extra_states=True, save_adapter=False, **state_dict_kwargs) + if len(output_state_dict) > 0: + if self.has_additional_modules: + os.makedirs(os.path.join(save_directory, self.EXTRA_STATE_DIR), exist_ok=True) + self._save_state_dict(output_state_dict, os.path.join(save_directory, self.EXTRA_STATE_DIR), + safe_serialization) + with open( + os.path.join(save_directory, self.EXTRA_STATE_DIR, CONFIG_NAME), 'w', encoding='utf-8') as file: + json.dump({'extra_state_keys': self.extra_state_keys}, file) + else: + logger.error('Full parameter training, save_extra_states will be ignored') + + if not os.path.exists(os.path.join(save_directory, 'configuration.json')): + with open(os.path.join(save_directory, 'configuration.json'), 'w', encoding='utf-8') as f: + f.write('{}') + + @staticmethod + def _save_state_dict(output_state_dict, save_directory, safe_serialization): + if safe_serialization: + from safetensors.torch import save_file as safe_save_file + safe_save_file( + output_state_dict, os.path.join(save_directory, SAFETENSORS_WEIGHTS_NAME), metadata={'format': 'pt'}) + else: + torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME)) + + @contextmanager + def disable_adapter(self): + try: + self.set_active_adapters(adapter_names=[]) + yield + finally: + self.set_active_adapters(adapter_names=self.adapters.keys()) + + def set_active_adapters(self, adapter_names: Union[List[str], str], offload: str = None): + """Set activated adapters + + Args: + adapter_names(`Union[List[str], str]`): The adapters needed to be activated + offload(`str`): Whether to offload the deactivated ones to `cpu` or `meta` device + """ + if not adapter_names: + adapter_names = [] + + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + adapter_names = set(adapter_names) + for adapter_name in (adapter_names & set(self.adapters.keys())): + self.activate_adapter(adapter_name) + + for adapter_name in (set(self.adapters.keys()) - adapter_names): + self.deactivate_adapter(adapter_name, offload) + + self.active_adapters = (adapter_names & set(self.adapters.keys())) + + def activate_adapter(self, adapter_name: str): + """Activate one adapter + + Args: + adapter_name(`str`): The adapter needed to be activated + """ + if adapter_name not in self.adapters: + logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}') + return + + from .mapping import SWIFT_MAPPING + SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ + .activate_adapter(self.base_model, adapter_name, True) + self.active_adapters = self.active_adapters | {adapter_name} + + def deactivate_adapter(self, adapter_name: str, offload: str = None): + """Deactivate one adapter + + Args: + adapter_name(`str`): The adapter needed to be activated + offload(`str`): Whether to offload to `cpu` or `meta` device + """ + if adapter_name not in self.adapters: + logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}') + return + + from .mapping import SWIFT_MAPPING + SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\ + .activate_adapter(self.base_model, adapter_name, False, offload=offload) + self.active_adapters = self.active_adapters - {adapter_name} + + def get_trainable_parameters(self): + """ + Get the content of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in self.base_model.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, 'ds_numel'): + num_params = param.ds_numel + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + return f'trainable params: {trainable_params:,d} || all params: {all_param:,d} ' \ + f'|| trainable%: {100 * trainable_params / all_param:.4f}' \ + '|| cuda memory: ' \ + f'{sum([torch.cuda.memory_allocated(i) for i in range(get_device_count())])/1024/1024/1024:.2f}' \ + 'GiB.' + + +class Swift: + """The Wrapper to use both Peft and Swift tuners.""" + + @staticmethod + def prepare_model(model: Union[nn.Module, SwiftModel], config: Union[SwiftConfig, PeftConfig, + Dict[str, SwiftConfig]], **kwargs): + """Prepare a model by the input config. + + Args: + model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned. + config(`Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]]`): The config or config dict, can be either + SwiftConfigs or PeftConfigs + **kwargs: + Extra kwargs needed by SwiftModel or PeftModel. + Returns: + The model wrapped by SwiftModel or PeftModel. + """ + + if isinstance(config, (SwiftConfig, dict)): + return SwiftModel(model, config, **kwargs) + else: + return get_peft_model(model, config, **kwargs) + + @staticmethod + def merge_and_unload(model: Union[PeftModel, SwiftModel], **kwargs): + """Merge tuners into the base model and unload them. + + Args: + model(`Union[PeftModel, SwiftModel]`): The model instance with tuners + kwargs: + adapter_name(`Union[str, List[str]]`): The adapter_name to unload, only supported in swift tuners. + + """ + from peft import PeftModel as _PeftModel + if isinstance(model, _PeftModel): + model.merge_and_unload() + elif isinstance(model, SwiftModel): + from swift import LoRAConfig + from swift.tuners import LoRA + adapter_name = kwargs.get('adapter_name', None) + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + for adapter, output in model.adapters.items(): + if isinstance(output.config, LoRAConfig) and (adapter_name is None or adapter in adapter_name): + LoRA.unpatch_lora(model, output.config, adapter) + + @staticmethod + @contextmanager + def grpo_context(model: Union[SwiftModel, torch.nn.Module], processor): + # Save the model and temporarily modify model.model_dir. + if not isinstance(model, SwiftModel): + yield + return + else: + assert len(model.adapters) == 1 + adapter = list(model.adapters.values())[0] + if adapter.config.swift_type == SwiftTuners.LLAMAPRO: + from modelscope.hub.utils.utils import get_cache_dir + temp_dir = tempfile.mkdtemp(dir=get_cache_dir()) + model_dir = model.model_dir + from transformers.integrations import is_deepspeed_zero3_enabled + if is_deepspeed_zero3_enabled(): + raise ValueError('DeepSpeed ZeRO3 not supported for LLaMAPro&GRPO currently.') + model.base_model.save_pretrained(temp_dir) + processor.save_pretrained(temp_dir) + model.model_dir = temp_dir + yield + if adapter.config.swift_type == SwiftTuners.LLAMAPRO: + model.model_dir = model_dir + shutil.rmtree(temp_dir) + + @staticmethod + def merge(model: Union[PeftModel, SwiftModel], **kwargs): + """Merge tuners into the base model, will not unload them. + + Args: + model(`Union[PeftModel, SwiftModel]`): The model instance with tuners + """ + from .lora_layers import LoraLayer, LoRALayer + for sub_module in model.modules(): + if isinstance(sub_module, (LoraLayer, LoRALayer)): + sub_module.merge(**kwargs) + + @staticmethod + def unmerge(model: Union[PeftModel, SwiftModel], **kwargs): + """Unmerge tuners from the base model + + Args: + model(`Union[PeftModel, SwiftModel]`): The model instance with tuners + """ + from .lora_layers import LoraLayer, LoRALayer + for sub_module in model.modules(): + if isinstance(sub_module, (LoraLayer, LoRALayer)): + sub_module.unmerge(**kwargs) + + @staticmethod + def save_to_peft_format(ckpt_dir: str, output_dir: str) -> None: + """Save swift format to peft format + + Args: + ckpt_dir(`str`): Original swift output dir + output_dir(`str`): Converted peft format dir + """ + assert ckpt_dir and output_dir, 'Please pass in valid ckpt_dir and output_dir.' + assert os.path.exists(ckpt_dir), f'ckpt_dir: {ckpt_dir} must exists in local disk.' + if os.path.exists(os.path.join(ckpt_dir, SwiftModel.EXTRA_STATE_DIR)): + raise AssertionError('Cannot transfer to peft format, because you are additional state dicts.') + + adapter_names = [ + sub_dir for sub_dir in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, sub_dir, CONFIG_NAME)) + ] + + def has_custom_content(_json): + if _json.get('swift_type', _json.get('peft_type')) != SwiftTuners.LORA: + logger.warn('Only LoRA can be converted to peft format') + return True + + from swift import LoRAConfig + return not LoRAConfig(**_json).can_be_saved_to_peft() + + for adapter in adapter_names: + with open(os.path.join(ckpt_dir, adapter, CONFIG_NAME), encoding='utf-8') as f: + _json = json.load(f) + if has_custom_content(_json): + raise AssertionError('Cannot transfer to peft format, ' + 'because you have special parameters or adapter types.') + + os.makedirs(output_dir, exist_ok=True) + if ckpt_dir != output_dir: + shutil.copytree(ckpt_dir, output_dir, dirs_exist_ok=True) + + for adapter in adapter_names: + safe_serialization = os.path.isfile(os.path.join(output_dir, adapter, SAFETENSORS_WEIGHTS_NAME)) + state_dict = SwiftModel.load_state_file(os.path.join(output_dir, adapter)) + new_state_dict = {} + for key, value in state_dict.items(): + if not key.startswith('base_model.model.'): + key = 'base_model.model.' + key + key = key.replace(f'lora_A.{adapter}.', 'lora_A.') + key = key.replace(f'lora_B.{adapter}.', 'lora_B.') + key = key.replace(f'lora_embedding_A.{adapter}.', 'lora_embedding_A.') + key = key.replace(f'lora_embedding_B.{adapter}.', 'lora_embedding_B.') + key = key.replace(f'lora_magnitude_vector.{adapter}', 'lora_magnitude_vector') + new_state_dict[key] = value + state_dict = new_state_dict + SwiftModel._save_state_dict(state_dict, os.path.join(output_dir, adapter), safe_serialization) + from swift import LoRAConfig + with open(os.path.join(output_dir, adapter, CONFIG_NAME), encoding='utf-8') as f: + _json = json.load(f) + peft_config = LoRAConfig(**_json).to_peft_config() + peft_config.save_pretrained(os.path.join(output_dir, adapter)) + + if 'default' in adapter_names: + shutil.move(os.path.join(output_dir, 'default', CONFIG_NAME), os.path.join(output_dir, CONFIG_NAME)) + state_dict = SwiftModel.load_state_file(os.path.join(output_dir, 'default')) + safe_serialization = os.path.isfile(os.path.join(output_dir, 'default', SAFETENSORS_WEIGHTS_NAME)) + SwiftModel._save_state_dict(state_dict, output_dir, safe_serialization) + shutil.rmtree(os.path.join(output_dir, 'default')) + + @staticmethod + def from_pretrained(model: Union[nn.Module, SwiftModel, PeftModel], + model_id: str = None, + adapter_name: Union[str, List[str], Dict[str, str]] = None, + revision: str = None, + **kwargs): + """Prepare a model by a model_id in the ModelScope hub or a local dir. + + Args: + model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned. + model_id(`str`): The model id of the modelhub or a local dir containing the configs/weights. + adapter_name(`str`, `optional`): The adapter_name to use. + revision(`str`, `optional`): The model revision if the model_id is a model id of the modelhub. + **kwargs: + Extra kwargs needed by ``SwiftModel.from_pretrained`` or ``PeftModel.from_pretrained``. + Returns: + The model wrapped by SwiftModel or PeftModel. + """ + if not os.path.exists(model_id): + model_id = snapshot_download(model_id, revision=revision) + is_peft_model = False + if os.path.exists(os.path.join(model_id, CONFIG_NAME)): + with open(os.path.join(model_id, CONFIG_NAME), 'r', encoding='utf-8') as f: + _json = json.load(f) + is_peft_model = SWIFT_TYPE_KEY not in _json + + _name = adapter_name if isinstance( + adapter_name, str) or adapter_name is None else adapter_name[0] \ + if isinstance(adapter_name, list) else list(adapter_name.keys())[0] + _name = _name or '' + if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)): + with open(os.path.join(model_id, _name, CONFIG_NAME), 'r', encoding='utf-8') as f: + _json = json.load(f) + is_peft_model = SWIFT_TYPE_KEY not in _json and 'extra_state_keys' not in _json + if is_peft_model: + + def load_peft_model(_model, _adapter_name, _new_name=None): + if not _new_name: + _new_name = _adapter_name + import peft + if not isinstance(_model, peft.PeftModel): + return PeftModel.from_pretrained( + _model, + os.path.join(model_id, _adapter_name) if _adapter_name != 'default' + and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id, + revision=revision, + adapter_name=_new_name, + **kwargs) + else: + _model.load_adapter( + os.path.join(model_id, _adapter_name) if _adapter_name != 'default' + and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id, _new_name) + return _model + + if not adapter_name: + peft_model = load_peft_model(model, 'default') + for _dir in os.listdir(model_id): + if os.path.isdir(os.path.join(model_id, _dir)) and \ + os.path.exists(os.path.join(model_id, _dir, CONFIG_NAME)): + peft_model = load_peft_model(peft_model, _dir) + elif isinstance(adapter_name, str): + return load_peft_model(model, adapter_name) + elif isinstance(adapter_name, list): + peft_model = model + for name in adapter_name: + peft_model = load_peft_model(peft_model, name) + else: + peft_model = model + for key, value in adapter_name.items(): + peft_model = load_peft_model(peft_model, key, value) + return peft_model + else: + return SwiftModel.from_pretrained(model, model_id, revision=revision, adapter_name=adapter_name, **kwargs) diff --git a/ms-swift/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cb20ea79ddff3447945e8f58d2d2ec5b394fcf6 Binary files /dev/null and b/ms-swift/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/tuners/longlora/llama.py b/ms-swift/swift/tuners/longlora/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..6c54abcc05c1b4a1d3c998cd9a1ed365ea08486f --- /dev/null +++ b/ms-swift/swift/tuners/longlora/llama.py @@ -0,0 +1,409 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from dvlab-research/LongLoRA. + +import math +from types import MethodType +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Cache, StaticCache +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +from swift.utils import get_logger + +logger = get_logger() + + +def _preprocess_qkv_fa2(attn_module, query_states, key_states, value_states, attention_mask): + if attn_module.training: + bsz, q_len = query_states.shape[:2] + group_size = int(q_len * attn_module.config.group_size_ratio) + if q_len % group_size != 0: + raise ValueError(f'The sequence length {q_len} should' + f'be able to be split by the group_ratio {attn_module.config.group_size_ratio}') + + num_group = q_len // group_size + + def shift(qkv, bsz, q_len, group_size, num_heads, head_dim): + qkv[:, :, num_heads // 2:] = qkv[:, :, num_heads // 2:].roll(-group_size // 2, dims=1) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim) + return qkv + + query_states = shift(query_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + key_states = shift(key_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + value_states = shift(value_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + if attention_mask is not None: + attention_mask = attention_mask[:, :group_size].repeat(num_group, 1) + + return query_states, key_states, value_states, attention_mask + + +def _preprocess_qkv(attn_module, query_states, key_states, value_states, attention_mask): + if attn_module.training: + bsz, _, q_len = query_states.shape[:3] + group_size = int(q_len * attn_module.config.group_size_ratio) + if q_len % group_size != 0: + raise ValueError(f'The sequence length {q_len} should' + f'be able to be split by the group_ratio {attn_module.config.group_size_ratio}') + + num_group = q_len // group_size + + def shift(qkv, bsz, q_len, group_size, num_heads, head_dim): + qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2) + qkv = qkv.transpose(1, 2) + qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim) + return qkv.transpose(1, 2) + + query_states = shift(query_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + key_states = shift(key_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + value_states = shift(value_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) + + return query_states, key_states, value_states, attention_mask + + +def _postprocess_qkv(attn_module, attn_output, q_len): + if attn_module.training: + group_size = int(q_len * attn_module.config.group_size_ratio) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(-1, q_len, attn_module.num_heads, attn_module.head_dim) + # shift back + attn_output_clone = attn_output.clone() + attn_output_clone[:, :, attn_module.num_heads // 2:] = attn_output[:, :, attn_module.num_heads // 2:].roll( + group_size // 2, dims=1) + attn_output = attn_output_clone + return attn_output.transpose(1, 2) + + +def _postprocess_qkv_fa2(attn_module, attn_output, q_len): + if attn_module.training: + group_size = int(q_len * attn_module.config.group_size_ratio) + attn_output = attn_output.reshape(-1, q_len, attn_module.num_heads, attn_module.head_dim) + attn_output_clone = attn_output.clone() + # shift back + attn_output_clone[:, :, attn_module.num_heads // 2:] = attn_output[:, :, attn_module.num_heads // 2:].roll( + group_size // 2, dims=1) + attn_output = attn_output_clone + return attn_output + + +# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa +def eager_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + 'The attention layers in this model are transitioning from computing the RoPE embeddings internally ' + 'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ' + '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be ' + 'removed and `position_embeddings` will be mandatory.') + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # patch position rolling + query_states, key_states, value_states, causal_mask = _preprocess_qkv(self, query_states, key_states, value_states, + attention_mask) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError(f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + # patch position unrolling + attn_output = _postprocess_qkv(self, attn_output, q_len) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa +def fa2_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` ' + 'make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers' + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + 'The attention layers in this model are transitioning from computing the RoPE embeddings internally ' + 'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ' + '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be ' + 'removed and `position_embeddings` will be mandatory.') + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in' + f' {target_dtype}.') + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # patch position rolling + query_states, key_states, value_states, attention_mask = _preprocess_qkv_fa2( + self, query_states, key_states, value_states, attention_mask) + from transformers.modeling_flash_attention_utils import _flash_attention_forward + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, 'sliding_window', None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + # patch position unrolling + attn_output = _postprocess_qkv_fa2(self, attn_output, q_len) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa +def sdpa_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + 'The attention layers in this model are transitioning from computing the RoPE embeddings internally ' + 'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ' + '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be ' + 'removed and `position_embeddings` will be mandatory.') + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + + if query_states.device.type == 'cuda' and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + # patch position rolling + query_states, key_states, value_states, causal_mask = _preprocess_qkv(self, query_states, key_states, value_states, + causal_mask) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + # patch position unrolling + attn_output = _postprocess_qkv(self, attn_output, q_len) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def replace_llama_attn(model: nn.Module): + layers = None + for module in model.modules(): + if isinstance(module, torch.nn.ModuleList): + layers = module + break + assert layers is not None + for idx, m in enumerate(layers): + if model.config._attn_implementation == 'flash_attention_2': + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + logger.warn( + 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.' # noqa + 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593') + m.self_attn.forward = MethodType(fa2_forward, m.self_attn) + elif model.config._attn_implementation == 'eager': + m.self_attn.forward = MethodType(eager_forward, m.self_attn) + elif model.config._attn_implementation == 'sdpa': + m.self_attn.forward = MethodType(sdpa_forward, m.self_attn) diff --git a/ms-swift/swift/tuners/reft.py b/ms-swift/swift/tuners/reft.py new file mode 100644 index 0000000000000000000000000000000000000000..8179b61ccda8b81241cd583ec039c70665e4077a --- /dev/null +++ b/ms-swift/swift/tuners/reft.py @@ -0,0 +1,215 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass +from types import MethodType +from typing import List, Literal, Optional + +import json +import torch +from torch import nn + +from swift.utils import get_logger, patch_getattr +from .utils import SwiftAdapter, SwiftConfig, SwiftOutput + +logger = get_logger() + + +@dataclass +class ReftConfig(SwiftConfig): + """ + Train a model with Reft. + Paper: https://arxiv.org/pdf/2404.03592 + + Args: + model_type(`Optional[str]`): The model_type to find down_proj/layers. + layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`. + layers (`Optional[List[int]]`): The layer number to inject. + r(`int`): The rank of Reft. + intervention_type (`Literal['NoreftIntervention', 'LoreftIntervention', + 'ConsreftIntervention', 'LobireftIntervention', + 'DireftIntervention', 'NodireftIntervention']`): The intervention type, + default LoreftIntervention + args (`Optional[str]`): Other reft_args in json-string format + """ + + model_type: Optional[str] = None + layer_key: Optional[str] = None + layers: Optional[List[int]] = None + r: int = 4 + intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention', + 'LobireftIntervention', 'DireftIntervention', + 'NodireftIntervention'] = 'LoreftIntervention' + args: Optional[str] = None + + def __post_init__(self): + from .mapping import SwiftTuners + self.swift_type = SwiftTuners.REFT + if self.args: + self.args = json.loads(self.args) + else: + self.args = {} + + +class Reft(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str): + from swift.utils.import_utils import is_pyreft_available + if not is_pyreft_available(): + raise ImportError('Please install pyreft before using ReFT: ' '`pip install pyreft`') + + import pyreft + from pyreft import ReftModel + from pyreft.interventions import LowRankRotateLayer + from pyreft import ( + NoreftIntervention, + LoreftIntervention, + ConsreftIntervention, + LobireftIntervention, + DireftIntervention, + NodireftIntervention, + ) + + intervention_mapping = { + 'NoreftIntervention': NoreftIntervention, + 'LoreftIntervention': LoreftIntervention, + 'ConsreftIntervention': ConsreftIntervention, + 'LobireftIntervention': LobireftIntervention, + 'DireftIntervention': DireftIntervention, + 'NodireftIntervention': NodireftIntervention, + } + + patch_getattr(ReftModel, 'model') + + def forward(self, x): + self.to(x.device) + return self.forward_origin(x) + + def forward2(self, base, source=None, subspaces=None): + self.to(base.device) + return self.forward_origin(base, source, subspaces) + + if not hasattr(LowRankRotateLayer, 'forward_origin'): + LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward + LowRankRotateLayer.forward = forward + NoreftIntervention.forward_origin = NoreftIntervention.forward + NoreftIntervention.forward = forward2 + LoreftIntervention.forward_origin = LoreftIntervention.forward + LoreftIntervention.forward = forward2 + ConsreftIntervention.forward_origin = ConsreftIntervention.forward + ConsreftIntervention.forward = forward2 + LobireftIntervention.forward_origin = LobireftIntervention.forward + LobireftIntervention.forward = forward2 + DireftIntervention.forward_origin = DireftIntervention.forward + DireftIntervention.forward = forward2 + NodireftIntervention.forward_origin = NodireftIntervention.forward + NodireftIntervention.forward = forward2 + + module_list_key = config.layer_key + if module_list_key is None: + model_key_mapping = Reft.get_model_key_mapping(config.model_type, config) + module_list_key = model_key_mapping.module_list + logger.info(f'Applying Reft to module: {module_list_key}') + module_list: nn.ModuleList = model.get_submodule(module_list_key) + representations = [] + for idx, layer in enumerate(module_list): + if config.layers and idx not in config.layers: + continue + intervention_config = { + 'layer': + idx, + 'component': + module_list_key + f'[{idx}].output', + 'low_rank_dimension': + config.r, + 'intervention': + intervention_mapping[config.intervention_type]( + embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args) + } + representations.append(intervention_config) + + reft_config = pyreft.ReftConfig(representations=representations) + reft_model = pyreft.get_reft_model(model, reft_config, set_device=False) + reft_model.reft_config = reft_model.config + reft_model.config = reft_model.model.config + + def _pre_forward_hook(module, args, kwargs): + if 'base' in kwargs: + return args, kwargs + + if 'input_ids' not in kwargs: + raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.') + # run intervened forward pass + unit_locations = None + if 'intervention_locations' in kwargs: + if kwargs['intervention_locations'].dim() == 3: + unit_locations = { + 'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist()) + } + else: + # this is dummy for lora only baseline + unit_locations = {'sources->base': (None, 0)} + kwargs = { + 'base': { + 'input_ids': kwargs['input_ids'], + 'attention_mask': kwargs['attention_mask'] + }, + 'unit_locations': unit_locations, + 'labels': kwargs['labels'], + 'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None + } + return args, kwargs + + def _post_forward_hook(module, args, kwargs, outputs): + return outputs[1] + + def _generate(self, **kwargs): + # run intervened forward pass + unit_locations = None + if 'intervention_locations' in kwargs: + if kwargs['intervention_locations'].dim() == 3: + unit_locations = { + 'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist()) + } + else: + # this is dummy for lora only baseline + unit_locations = {'sources->base': (None, 0)} + + _kwargs = { + 'base': { + 'input_ids': kwargs.pop('input_ids'), + 'attention_mask': kwargs.pop('attention_mask') + }, + 'unit_locations': unit_locations, + 'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None + } + _kwargs = {**_kwargs, **kwargs} + return self.generate_origin(**_kwargs)[1] + + reft_model.generate_origin = reft_model.generate + reft_model.generate = MethodType(_generate, reft_model) + reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True) + reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True) + + def save_callback(swift_model, model_dir, adapter_name): + reft_model.save_intervention(save_directory=model_dir, include_model=False) + + def mark_trainable_callback(model): + return + + def load_callback(swift_model, model_dir, adapter_name): + reft_model.load_intervention(model_dir, include_model=False) + + return SwiftOutput( + model=reft_model, + config=config, + mark_trainable_callback=mark_trainable_callback, + save_callback=save_callback, + load_callback=load_callback) + + @staticmethod + def has_additional_modules(): + return True + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + assert activate, 'ReFT does not support deactivate' diff --git a/ms-swift/swift/tuners/scetuning/scetuning.py b/ms-swift/swift/tuners/scetuning/scetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..c105cd1baef206f64d0f9ce82333eab1e94f5dfd --- /dev/null +++ b/ms-swift/swift/tuners/scetuning/scetuning.py @@ -0,0 +1,235 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import re +import types +from dataclasses import dataclass, field +from typing import List, Optional, Union + +import torch +from torch import nn + +from swift.tuners.utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput +from swift.utils import get_logger +from swift.utils.torch_utils import find_sub_module +from .scetuning_components import probe_output_hook + +logger = get_logger() + + +@dataclass +class SCETuningConfig(SwiftConfig): + """ + The configuration class for the SCEdit module. + + 'SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing' by Jiang et al.(2023) + See https://arxiv.org/abs/2312.11392 + + Args: + dims(`Union[List[int], int]`): The dimensions of the hidden states + target_modules(`Union[List[str], str]`): The target module to be replaced, can a regex string + hint_modules(`Union[List[str], str]`): The hint module to be replaced, can a regex string + tuner_mode(`str`): Location of tuner operation. + tuner_op(`str`): Tuner operation. + down_ratio(`float`): The dim down ratio of tuner hidden state. + """ + + dims: Optional[Union[List[int], int]] = field( + default=None, metadata={'help': 'The dimensions of the hidden states'}) + + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={'help': 'The target module to be replaced, can be a regex string or name list of full match format'}) + + hint_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={'help': 'The hint modules to be replaced, can be a regex string or name list of full match format'}) + + tuner_mode: str = field( + default='decoder', + metadata={'help': 'Location of tuner operation. The tuner mode choices: encoder, decoder, and identity'}) + + tuner_op: str = field(default='SCEAdapter', metadata={'help': 'The tuner ops choices: SCEAdapter'}) + + down_ratio: float = field(default=1.0, metadata={'help': 'The dim down ratio of tuner hidden state'}) + + def __post_init__(self): + from swift.tuners.mapping import SwiftTuners + self.swift_type = SwiftTuners.SCETUNING + + +class SCETuning(SwiftAdapter): + + @staticmethod + def prepare_model(model: nn.Module, config: SCETuningConfig, adapter_name: str) -> SwiftOutput: + """Prepare a model with `SCETuningConfig`""" + module_keys = [key for key, _ in model.named_modules()] + # 1. Matching the hint module + hint_module_ins_list = [] + if config.hint_modules: + if isinstance(config.hint_modules, list): + for module_key in config.hint_modules: + assert module_key in module_keys + h_module = model.get_submodule(module_key) + logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}') + if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(h_module)} may not be supported because of its customized forward') + h_module.register_forward_hook(probe_output_hook, with_kwargs=True) + hint_module_ins_list.append(h_module) + else: + for module_key in module_keys: + if re.fullmatch(config.hint_modules, module_key): + h_module = model.get_submodule(module_key) + logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}') + if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)): + logger.warning( + f'Type of {type(h_module)} may not be supported because of its customized forward') + h_module.register_forward_hook(probe_output_hook, with_kwargs=True) + hint_module_ins_list.append(h_module) + if len(hint_module_ins_list) == 0: + logger.error('Cannot match hint modules') + + def _get_module(module): + if isinstance(module, nn.ModuleList): + module = module[-1] + return _get_module(module) + return module + + # 2. Matching the target module + target_module_ins_list = [] + assert config.target_modules is not None + if isinstance(config.target_modules, list): + for module_key in config.target_modules: + assert module_key in module_keys + t_module = model.get_submodule(module_key) + logger.info(f'Matching target module [{module_key}] of type {type(t_module)}') + target_module_ins_list.append(_get_module(t_module)) + else: + for module_key in module_keys: + if re.fullmatch(config.target_modules, module_key): + t_module = model.get_submodule(module_key) + logger.info(f'Matching target module [{module_key}] of type {type(t_module)}') + target_module_ins_list.append(_get_module(t_module)) + if len(target_module_ins_list) == 0: + logger.error('Cannot match target modules') + if len(hint_module_ins_list) > 0 and not len(hint_module_ins_list) == len(target_module_ins_list): + logger.info("Target modules' length should be equal with hint modules.") + assert len(hint_module_ins_list) == len(target_module_ins_list) + if isinstance(config.dims, int): + dims = [config.dims for _ in target_module_ins_list] + else: + assert len(config.dims) == len(target_module_ins_list) + dims = config.dims + + # refactor forward function + def _forward_encoder_mode(self, *args, **kwargs): + args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs) + args_type = type(args) + if args_type is tuple: + args = args[0] + if hasattr(self, 'hint'): + hint_out = self.hint.probe_output_data + args_main = getattr(self, f'scetuner_{adapter_name}')(args, hint_out) + else: + args_main = getattr(self, f'scetuner_{adapter_name}')(args) + if args_type is tuple: + args_main = (args_main, ) + return args_main + + def _forward_decoder_mode(self, *args, **kwargs): + args_type = type(args) + if args_type is tuple: + args_sub_tuner = args[0] + args_sub_extra = args[1:] + tuner_module = getattr(self, f'scetuner_{adapter_name}') + args_hidden, args_res = torch.split(args_sub_tuner, args_sub_tuner.shape[1] - tuner_module.dim, 1) + if hasattr(self, 'hint'): + hint_out = self.hint.probe_output_data + args_res_new = tuner_module(args_res, hint_out) + else: + args_res_new = tuner_module(args_res) + args_sub_tuner_new = torch.cat([args_hidden, args_res_new], dim=1) + if args_type is tuple: + args_main = (args_sub_tuner_new, *args_sub_extra) + + args_main = getattr(self, f'forward_origin_{adapter_name}')(*args_main, **kwargs) + return args_main + + # 3. inject the tuners + for tuner_id, t_module in enumerate(target_module_ins_list): + setattr(t_module, f'forward_origin_{adapter_name}', getattr(t_module, 'forward')) + if config.tuner_mode in ('encoder', 'identity'): + _forward = _forward_encoder_mode + elif config.tuner_mode == 'decoder': + _forward = _forward_decoder_mode + else: + raise Exception(f'Error tuner_mode: {config.tuner_mode}') + setattr(t_module, 'forward', types.MethodType(_forward, t_module)) + tuner_op = SCETunerModule( + name=config.tuner_op, + adapter_name=adapter_name, + module_key=str(tuner_id), + dim=dims[tuner_id], + tuner_length=int(dims[tuner_id] * config.down_ratio)) + setattr(t_module, f'scetuner_{adapter_name}', tuner_op) + if len(hint_module_ins_list) > 0: + setattr(t_module, 'hint', hint_module_ins_list[tuner_id]) + + def state_dict_callback(state_dict, adapter_name, **kwargs): + state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key} + return state_dict_new + + def mark_trainable_callback(model): + return + + return SwiftOutput( + config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback) + + @staticmethod + def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None): + modules = find_sub_module(module, f'scetuner_{adapter_name}') + for _module in modules: + _module: ActivationMixin + _module: nn.Module + _module.set_activation(adapter_name, activate) + SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload) + + +class SCETunerModule(nn.Module, ActivationMixin): + + def __init__(self, + name, + adapter_name, + module_key, + dim, + tuner_length, + tuner_type=None, + tuner_weight=None, + act_layer=nn.GELU, + zero_init_last=True, + use_bias=True): + super(SCETunerModule, self).__init__() + super(nn.Module, self).__init__(module_key) + self.name = name + self.adapter_name = adapter_name + self.dim = dim + if name == 'SCEAdapter': + from .scetuning_components import SCEAdapter + self.tuner_op = SCEAdapter( + dim=dim, + adapter_length=tuner_length, + adapter_type=tuner_type, + adapter_weight=tuner_weight, + act_layer=act_layer) + else: + raise Exception(f'Error tuner op {name}') + self.mark_all_sub_modules_as_plugin() + + def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs): + if not self.is_activated(self.adapter_name): + return x + if self.name == 'SCEAdapter': + self.tuner_op.to(x.device) + out = self.tuner_op(x) + else: + raise Exception(f'Error tuner op {self.name}') + return out diff --git a/ms-swift/swift/ui/__init__.py b/ms-swift/swift/ui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3b0163fb48e49cef87c02087e58472af76e74f --- /dev/null +++ b/ms-swift/swift/ui/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .app import webui_main diff --git a/ms-swift/swift/ui/llm_eval/llm_eval.py b/ms-swift/swift/ui/llm_eval/llm_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..05824f1904756ca393678fed74957383665755b4 --- /dev/null +++ b/ms-swift/swift/ui/llm_eval/llm_eval.py @@ -0,0 +1,189 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import re +import sys +import time +from datetime import datetime +from functools import partial +from typing import Type + +import gradio as gr +import json +import torch +from json import JSONDecodeError +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + +from swift.llm import EvalArguments +from swift.ui.base import BaseUI +from swift.ui.llm_eval.eval import Eval +from swift.ui.llm_eval.model import Model +from swift.ui.llm_eval.runtime import EvalRuntime +from swift.utils import get_device_count + + +class LLMEval(BaseUI): + group = 'llm_eval' + + sub_ui = [Model, Eval, EvalRuntime] + + cmd = 'eval' + + locale_dict = { + 'llm_eval': { + 'label': { + 'zh': 'LLM评测', + 'en': 'LLM evaluation', + } + }, + 'more_params': { + 'label': { + 'zh': '更多参数', + 'en': 'More params' + }, + 'info': { + 'zh': '以json格式或--xxx xxx命令行格式填入', + 'en': 'Fill in with json format or --xxx xxx cmd format' + } + }, + 'evaluate': { + 'value': { + 'zh': '开始评测', + 'en': 'Begin Evaluation' + }, + }, + 'gpu_id': { + 'label': { + 'zh': '选择可用GPU', + 'en': 'Choose GPU' + }, + 'info': { + 'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU', + 'en': 'Select GPU to train' + } + }, + } + + choice_dict = BaseUI.get_choices_from_dataclass(EvalArguments) + default_dict = BaseUI.get_default_value_from_dataclass(EvalArguments) + arguments = BaseUI.get_argument_names(EvalArguments) + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.TabItem(elem_id='llm_eval', label=''): + default_device = 'cpu' + device_count = get_device_count() + if device_count > 0: + default_device = '0' + with gr.Blocks(): + Model.build_ui(base_tab) + Eval.build_ui(base_tab) + EvalRuntime.build_ui(base_tab) + with gr.Row(): + gr.Textbox(elem_id='more_params', lines=4, scale=20) + gr.Button(elem_id='evaluate', scale=2, variant='primary') + gr.Dropdown( + elem_id='gpu_id', + multiselect=True, + choices=[str(i) for i in range(device_count)] + ['cpu'], + value=default_device, + scale=8) + + cls.element('evaluate').click( + cls.eval_model, list(base_tab.valid_elements().values()), + [cls.element('runtime_tab'), cls.element('running_tasks')]) + + base_tab.element('running_tasks').change( + partial(EvalRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], + list(base_tab.valid_elements().values()) + [cls.element('log')]) + EvalRuntime.element('kill_task').click( + EvalRuntime.kill_task, + [EvalRuntime.element('running_tasks')], + [EvalRuntime.element('running_tasks')] + [EvalRuntime.element('log')], + ) + + @classmethod + def eval(cls, *args): + eval_args = cls.get_default_value_from_dataclass(EvalArguments) + kwargs = {} + kwargs_is_list = {} + other_kwargs = {} + more_params = {} + more_params_cmd = '' + keys = cls.valid_element_keys() + for key, value in zip(keys, args): + compare_value = eval_args.get(key) + compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value + compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value + if key in eval_args and compare_value_ui != compare_value_arg and value: + if isinstance(value, str) and re.fullmatch(cls.int_regex, value): + value = int(value) + elif isinstance(value, str) and re.fullmatch(cls.float_regex, value): + value = float(value) + elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value): + value = True if value.lower() == 'true' else False + kwargs[key] = value if not isinstance(value, list) else ' '.join(value) + kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False) + else: + other_kwargs[key] = value + if key == 'more_params' and value: + try: + more_params = json.loads(value) + except (JSONDecodeError or TypeError): + more_params_cmd = value + + kwargs.update(more_params) + model = kwargs.get('model') + if model and os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')): + kwargs['ckpt_dir'] = kwargs.pop('model') + + eval_args = EvalArguments( + **{ + key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value + for key, value in kwargs.items() + }) + params = '' + sep = f'{cls.quote} {cls.quote}' + for e in kwargs: + if isinstance(kwargs[e], list): + params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} ' + elif e in kwargs_is_list and kwargs_is_list[e]: + all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()] + params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} ' + else: + params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} ' + params += more_params_cmd + ' ' + devices = other_kwargs['gpu_id'] + devices = [d for d in devices if d] + assert (len(devices) == 1 or 'cpu' not in devices) + gpus = ','.join(devices) + cuda_param = '' + if gpus != 'cpu': + if is_torch_npu_available(): + cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}' + elif is_torch_cuda_available(): + cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}' + else: + cuda_param = '' + now = datetime.now() + time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}' + file_path = f'output/{eval_args.model_type}-{time_str}' + if not os.path.exists(file_path): + os.makedirs(file_path, exist_ok=True) + log_file = os.path.join(os.getcwd(), f'{file_path}/run_eval.log') + eval_args.log_file = log_file + params += f'--log_file "{log_file}" ' + params += '--ignore_args_error true ' + if sys.platform == 'win32': + if cuda_param: + cuda_param = f'set {cuda_param} && ' + run_command = f'{cuda_param}start /b swift eval {params} > {log_file} 2>&1' + else: + run_command = f'{cuda_param} nohup swift eval {params} > {log_file} 2>&1 &' + return run_command, eval_args, log_file + + @classmethod + def eval_model(cls, *args): + run_command, eval_args, log_file = cls.eval(*args) + os.system(run_command) + time.sleep(2) + return gr.update(open=True), EvalRuntime.refresh_tasks(log_file) diff --git a/ms-swift/swift/ui/llm_eval/runtime.py b/ms-swift/swift/ui/llm_eval/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..03c90b81b0dfd454562a9ed1786ef224e0f0c3ce --- /dev/null +++ b/ms-swift/swift/ui/llm_eval/runtime.py @@ -0,0 +1,108 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Type + +import gradio as gr +from packaging import version + +from swift.ui.base import BaseUI +from swift.ui.llm_infer.runtime import Runtime +from swift.utils import get_logger + +logger = get_logger() + + +class EvalRuntime(Runtime): + + group = 'llm_eval' + + cmd = 'eval' + + locale_dict = { + 'runtime_tab': { + 'label': { + 'zh': '运行时', + 'en': 'Runtime' + }, + }, + 'running_cmd': { + 'label': { + 'zh': '运行命令', + 'en': 'Command line' + }, + 'info': { + 'zh': '执行的实际命令', + 'en': 'The actual command' + } + }, + 'show_log': { + 'value': { + 'zh': '展示评测状态', + 'en': 'Show eval status' + }, + }, + 'stop_show_log': { + 'value': { + 'zh': '停止展示', + 'en': 'Stop showing running status' + }, + }, + 'log': { + 'label': { + 'zh': '日志输出', + 'en': 'Logging content' + }, + 'info': { + 'zh': '如果日志无更新请再次点击"展示日志内容"', + 'en': 'Please press "Show log" if the log content is not updating' + } + }, + 'running_tasks': { + 'label': { + 'zh': '运行中评测', + 'en': 'Running evaluation' + }, + 'info': { + 'zh': '所有的swift eval命令启动的任务', + 'en': 'All tasks started by swift eval' + } + }, + 'refresh_tasks': { + 'value': { + 'zh': '找回评测', + 'en': 'Find evaluation' + }, + }, + 'kill_task': { + 'value': { + 'zh': '杀死评测', + 'en': 'Kill evaluation' + }, + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Accordion(elem_id='runtime_tab', open=False, visible=True): + with gr.Blocks(): + with gr.Row(): + gr.Dropdown(elem_id='running_tasks', scale=10) + gr.Button(elem_id='refresh_tasks', scale=1, variant='primary') + gr.Button(elem_id='show_log', scale=1, variant='primary') + gr.Button(elem_id='stop_show_log', scale=1) + gr.Button(elem_id='kill_task', scale=1, size='lg') + with gr.Row(): + gr.Textbox(elem_id='log', lines=6, visible=False) + + concurrency_limit = {} + if version.parse(gr.__version__) >= version.parse('4.0.0'): + concurrency_limit = {'concurrency_limit': 5} + cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then( + cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit) + + base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], []) + + base_tab.element('refresh_tasks').click( + cls.refresh_tasks, + [base_tab.element('running_tasks')], + [base_tab.element('running_tasks')], + ) diff --git a/ms-swift/swift/ui/llm_export/__init__.py b/ms-swift/swift/ui/llm_export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b937315b6e719ae8289fee2908aa486222eb76c5 --- /dev/null +++ b/ms-swift/swift/ui/llm_export/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. diff --git a/ms-swift/swift/ui/llm_export/export.py b/ms-swift/swift/ui/llm_export/export.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4ee80c3bbefcbcb4b232fa146a25f9857b5169 --- /dev/null +++ b/ms-swift/swift/ui/llm_export/export.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Type + +import gradio as gr + +from swift.llm.dataset.register import get_dataset_list +from swift.ui.base import BaseUI + + +class Export(BaseUI): + + group = 'llm_export' + + locale_dict = { + 'merge_lora': { + 'label': { + 'zh': '合并lora', + 'en': 'Merge lora' + }, + 'info': { + 'zh': + 'lora合并的路径在填入的checkpoint同级目录,请查看运行时log获取更具体的信息', + 'en': + 'The output path is in the sibling directory as the input checkpoint. ' + 'Please refer to the runtime log for more specific information.' + }, + }, + 'device_map': { + 'label': { + 'zh': '合并lora使用的device_map', + 'en': 'The device_map when merge-lora' + }, + 'info': { + 'zh': '如果显存不够请填入cpu', + 'en': 'If GPU memory is not enough, fill in cpu' + }, + }, + 'quant_bits': { + 'label': { + 'zh': '量化比特数', + 'en': 'Quantize bits' + }, + }, + 'quant_method': { + 'label': { + 'zh': '量化方法', + 'en': 'Quantize method' + }, + }, + 'quant_n_samples': { + 'label': { + 'zh': '量化集采样数', + 'en': 'Sampled rows from calibration dataset' + }, + }, + 'max_length': { + 'label': { + 'zh': '量化集的max-length', + 'en': 'The quantize sequence length' + }, + }, + 'output_dir': { + 'label': { + 'zh': '输出路径', + 'en': 'Output dir' + }, + }, + 'dataset': { + 'label': { + 'zh': '校准数据集', + 'en': 'Calibration datasets' + }, + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Row(): + gr.Checkbox(elem_id='merge_lora', scale=10) + gr.Textbox(elem_id='device_map', scale=20) + with gr.Row(): + gr.Dropdown(elem_id='quant_bits', scale=20) + gr.Dropdown(elem_id='quant_method', scale=20) + gr.Textbox(elem_id='quant_n_samples', scale=20) + gr.Textbox(elem_id='max_length', scale=20) + with gr.Row(): + gr.Textbox(elem_id='output_dir', scale=20) + gr.Dropdown( + elem_id='dataset', multiselect=True, allow_custom_value=True, choices=get_dataset_list(), scale=20) diff --git a/ms-swift/swift/ui/llm_train/galore.py b/ms-swift/swift/ui/llm_train/galore.py new file mode 100644 index 0000000000000000000000000000000000000000..b16016e6cb4981e2c20c21a2311d94f06b9e38ea --- /dev/null +++ b/ms-swift/swift/ui/llm_train/galore.py @@ -0,0 +1,58 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Type + +import gradio as gr + +from swift.ui.base import BaseUI + + +class Galore(BaseUI): + + group = 'llm_train' + + locale_dict = { + 'galore_tab': { + 'label': { + 'zh': 'Galore参数设置', + 'en': 'Galore Settings' + }, + }, + 'use_galore': { + 'label': { + 'zh': '使用GaLore', + 'en': 'Use GaLore' + }, + 'info': { + 'zh': '使用Galore来减少全参数训练的显存消耗', + 'en': 'Use Galore to reduce GPU memory usage in full parameter training' + } + }, + 'galore_rank': { + 'label': { + 'zh': 'Galore的秩', + 'en': 'The rank of Galore' + }, + }, + 'galore_update_proj_gap': { + 'label': { + 'zh': 'Galore project matrix更新频率', + 'en': 'The updating gap of the project matrix' + }, + }, + 'galore_optim_per_parameter': { + 'label': { + 'zh': '为每个Galore Parameter创建单独的optimizer', + 'en': 'Create unique optimizer for per Galore parameter' + }, + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Accordion(elem_id='galore_tab', open=False): + with gr.Blocks(): + with gr.Row(): + gr.Checkbox(elem_id='use_galore', scale=4) + gr.Slider(elem_id='galore_rank', minimum=8, maximum=256, step=8, scale=4) + gr.Slider(elem_id='galore_update_proj_gap', minimum=10, maximum=1000, step=50, scale=4) + gr.Checkbox(elem_id='galore_optim_per_parameter', scale=4) diff --git a/ms-swift/swift/ui/llm_train/lisa.py b/ms-swift/swift/ui/llm_train/lisa.py new file mode 100644 index 0000000000000000000000000000000000000000..d547d7c1bea4233a66fff167e29965fe088fdff4 --- /dev/null +++ b/ms-swift/swift/ui/llm_train/lisa.py @@ -0,0 +1,44 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Type + +import gradio as gr + +from swift.ui.base import BaseUI + + +class Lisa(BaseUI): + + group = 'llm_train' + + locale_dict = { + 'lisa_tab': { + 'label': { + 'zh': 'LISA参数设置', + 'en': 'LISA settings' + }, + }, + 'lisa_activated_layers': { + 'label': { + 'zh': 'LISA激活层数', + 'en': 'LoRA activated layers' + }, + 'info': { + 'zh': 'LISA每次训练的模型层数,调整为正整数代表使用LISA', + 'en': 'Num of layers activated each time, a positive value means using lisa' + } + }, + 'lisa_step_interval': { + 'label': { + 'zh': 'LISA切换layers间隔', + 'en': 'The interval of lisa layers switching' + } + }, + } + + @classmethod + def do_build_ui(cls, base_tab: Type['BaseUI']): + with gr.Accordion(elem_id='lisa_tab', open=False): + with gr.Blocks(): + with gr.Row(): + gr.Textbox(elem_id='lisa_activated_layers') + gr.Textbox(elem_id='lisa_step_interval') diff --git a/ms-swift/swift/utils/__pycache__/__init__.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f8cf0ba089e4903ce3055bf68338b7a5906d1b2 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/__pycache__/constants.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2276b6506bc6051ecb16d09f0c7eddcf4ba62754 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/constants.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/__pycache__/env.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9c60f653f1888d2025a373e18a4fcb329acdd85 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/env.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/__pycache__/import_utils.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/import_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcf4928d4e61f2677808e618b29452121574ccd0 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/import_utils.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/__pycache__/io_utils.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/io_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63434b51f4b1a9b28b2c4a23da65fe65707e3cc4 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/io_utils.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/__pycache__/tb_utils.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/tb_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def9f97467c11d57e81242e4caa4133958a07e97 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/tb_utils.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/__pycache__/torchacc_utils.cpython-310.pyc b/ms-swift/swift/utils/__pycache__/torchacc_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2f3813efc4eb644493eb2796956990d942acaf5 Binary files /dev/null and b/ms-swift/swift/utils/__pycache__/torchacc_utils.cpython-310.pyc differ diff --git a/ms-swift/swift/utils/env.py b/ms-swift/swift/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..32265f07a8d50f696f40623d68dca5d0756fbe4f --- /dev/null +++ b/ms-swift/swift/utils/env.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +from transformers.utils import strtobool + +from .logger import get_logger + +logger = get_logger() + + +def use_hf_hub(): + return strtobool(os.environ.get('USE_HF', '0')) + + +def is_deepspeed_enabled(): + return strtobool(os.environ.get('ACCELERATE_USE_DEEPSPEED', '0')) + + +def use_torchacc() -> bool: + return strtobool(os.getenv('USE_TORCHACC', '0')) + + +def get_dist_setting() -> Tuple[int, int, int, int]: + """return rank, local_rank, world_size, local_world_size""" + rank = int(os.getenv('RANK', -1)) + local_rank = int(os.getenv('LOCAL_RANK', -1)) + world_size = int(os.getenv('WORLD_SIZE') or os.getenv('_PATCH_WORLD_SIZE') or 1) + # compat deepspeed launch + local_world_size = int(os.getenv('LOCAL_WORLD_SIZE', None) or os.getenv('LOCAL_SIZE', 1)) + return rank, local_rank, world_size, local_world_size + + +def get_node_setting(): + node_rank = int(os.getenv('NODE_RANK', 0)) + nnodes = int(os.getenv('NNODES', 1)) + return node_rank, nnodes + + +def is_local_master(): + local_rank = get_dist_setting()[1] + return local_rank in {-1, 0} + + +def is_master(): + rank = get_dist_setting()[0] + return rank in {-1, 0} + + +def torchacc_trim_graph(): + return strtobool(os.getenv('TORCHACC_TRIM_GRAPH', '0')) + + +def is_dist(): + """Determine if the training is distributed""" + if use_torchacc(): + return False + rank, local_rank, _, _ = get_dist_setting() + return rank >= 0 and local_rank >= 0 + + +def is_mp() -> bool: + if use_torchacc(): + return False + if strtobool(os.environ.get('USE_FAST_INFERENCE', 'false')): + return False + from swift.utils import get_device_count + n_gpu = get_device_count() + local_world_size = get_dist_setting()[3] + assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}' + if n_gpu // local_world_size >= 2: + return True + return False + + +def is_mp_ddp() -> bool: + # patch_mp_ddp will occur when `import swift`. + if is_dist() and is_mp(): + logger.info('Using MP(device_map) + DDP') + return True + return False + + +def is_dist_ta() -> bool: + """Determine if the TorchAcc training is distributed""" + _, _, world_size, _ = get_dist_setting() + if use_torchacc() and world_size > 1: + if not dist.is_initialized(): + import torchacc as ta + # Initialize in advance + dist.init_process_group(backend=ta.dist.BACKEND_NAME) + return True + else: + return False + + +def is_pai_training_job() -> bool: + return 'PAI_TRAINING_JOB_ID' in os.environ + + +def get_pai_tensorboard_dir() -> Optional[str]: + return os.environ.get('PAI_OUTPUT_TENSORBOARD') diff --git a/ms-swift/swift/utils/import_utils.py b/ms-swift/swift/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..831e3acd04f65e018b52918be3d89023a426516c --- /dev/null +++ b/ms-swift/swift/utils/import_utils.py @@ -0,0 +1,106 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2023-present the HuggingFace Inc. team. + +import importlib.util +import os +from itertools import chain +from types import ModuleType +from typing import Any + +from .logger import get_logger + +logger = get_logger() # pylint: disable=invalid-name + + +def is_vllm_available(): + return importlib.util.find_spec('vllm') is not None + + +def is_vllm_ascend_available(): + return importlib.util.find_spec('vllm_ascend') is not None + + +def is_lmdeploy_available(): + return importlib.util.find_spec('lmdeploy') is not None + + +def is_liger_available(): + return importlib.util.find_spec('liger_kernel') is not None + + +def is_swanlab_available(): + return importlib.util.find_spec('swanlab') is not None + + +def is_xtuner_available(): + return importlib.util.find_spec('xtuner') is not None + + +def is_megatron_available(): + return importlib.util.find_spec('megatron') is not None + + +def is_unsloth_available() -> bool: + return importlib.util.find_spec('unsloth') is not None + + +def is_pyreft_available() -> bool: + return importlib.util.find_spec('pyreft') is not None + + +def is_wandb_available() -> bool: + return importlib.util.find_spec('wandb') is not None + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f'module {self.__name__} has no attribute {name}') + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + return importlib.import_module('.' + module_name, self.__name__) + + def __reduce__(self): + return self.__class__, (self._name, self.__file__, self._import_structure) diff --git a/ms-swift/swift/utils/io_utils.py b/ms-swift/swift/utils/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d76c78640630ff4f6935e77871990f8dda0c6c24 --- /dev/null +++ b/ms-swift/swift/utils/io_utils.py @@ -0,0 +1,118 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from queue import Queue +from threading import Thread +from typing import Any, Dict, List, Literal, Union + +import json +import requests +import torch.distributed as dist +from accelerate.utils import gather_object +from modelscope.hub.api import ModelScopeConfig +from tqdm import tqdm + +from .env import is_master +from .logger import get_logger +from .utils import check_json_format + +logger = get_logger() + + +def download_ms_file(url: str, local_path: str, cookies=None) -> None: + if cookies is None: + cookies = ModelScopeConfig.get_cookies() + resp = requests.get(url, cookies=cookies, stream=True) + with open(local_path, 'wb') as f: + for data in tqdm(resp.iter_lines()): + f.write(data) + + +def read_from_jsonl(fpath: str, encoding: str = 'utf-8') -> List[Any]: + res: List[Any] = [] + with open(fpath, 'r', encoding=encoding) as f: + for line in f: + res.append(json.loads(line)) + return res + + +def write_to_jsonl(fpath: str, obj_list: List[Any], encoding: str = 'utf-8') -> None: + res: List[str] = [] + for obj in obj_list: + res.append(json.dumps(obj, ensure_ascii=False)) + with open(fpath, 'w', encoding=encoding) as f: + text = '\n'.join(res) + f.write(f'{text}\n') + + +class JsonlWriter: + + def __init__(self, fpath: str, *, encoding: str = 'utf-8', strict: bool = True, enable_async: bool = False): + self.fpath = os.path.abspath(os.path.expanduser(fpath)) if is_master() else None + self.encoding = encoding + self.strict = strict + self.enable_async = enable_async + self._queue = Queue() + self._thread = None + + def _append_worker(self): + while True: + item = self._queue.get() + self._append(**item) + + def _append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False): + if isinstance(obj, (list, tuple)) and all(isinstance(item, dict) for item in obj): + obj_list = obj + else: + obj_list = [obj] + if gather_obj and dist.is_initialized(): + obj_list = gather_object(obj_list) + if not is_master(): + return + obj_list = check_json_format(obj_list) + for i, _obj in enumerate(obj_list): + obj_list[i] = json.dumps(_obj, ensure_ascii=False) + '\n' + self._write_buffer(''.join(obj_list)) + + def append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False): + if self.enable_async: + if self._thread is None: + self._thread = Thread(target=self._append_worker, daemon=True) + self._thread.start() + self._queue.put({'obj': obj, 'gather_obj': gather_obj}) + else: + self._append(obj, gather_obj=gather_obj) + + def _write_buffer(self, text: str): + if not text: + return + assert is_master(), f'is_master(): {is_master()}' + try: + os.makedirs(os.path.dirname(self.fpath), exist_ok=True) + with open(self.fpath, 'a', encoding=self.encoding) as f: + f.write(text) + except Exception: + if self.strict: + raise + logger.error(f'Cannot write content to jsonl file. text: {text}') + + +def append_to_jsonl(fpath: str, obj: Union[Dict, List[Dict]], *, encoding: str = 'utf-8', strict: bool = True) -> None: + jsonl_writer = JsonlWriter(fpath, encoding=encoding, strict=strict) + jsonl_writer.append(obj) + + +def get_file_mm_type(file_name: str) -> Literal['image', 'video', 'audio']: + video_extensions = {'.mp4', '.mkv', '.mov', '.avi', '.wmv', '.flv', '.webm'} + audio_extensions = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a'} + image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'} + + _, ext = os.path.splitext(file_name) + + if ext.lower() in video_extensions: + return 'video' + elif ext.lower() in audio_extensions: + return 'audio' + elif ext.lower() in image_extensions: + return 'image' + else: + raise ValueError(f'file_name: {file_name}, ext: {ext}') diff --git a/ms-swift/swift/utils/np_utils.py b/ms-swift/swift/utils/np_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90148b1fd76a3a856a48ae68df6e028f1aa7d412 --- /dev/null +++ b/ms-swift/swift/utils/np_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd + + +def transform_jsonl_to_df(dict_list: List[Dict[str, Any]]) -> pd.DataFrame: + """Relevant function: `io_utils.read_from_jsonl()`""" + data_dict: Dict[str, List[Any]] = {} + for i, obj in enumerate(dict_list): + for k, v in obj.items(): + if k not in data_dict: + data_dict[k] = [None] * i + data_dict[k].append(v) + for k in set(data_dict.keys()) - set(obj.keys()): + data_dict[k].append(None) + return pd.DataFrame.from_dict(data_dict) + + +def get_seed(random_state: Optional[np.random.RandomState] = None) -> int: + if random_state is None: + random_state = np.random.RandomState() + seed_max = np.iinfo(np.int32).max + seed = random_state.randint(0, seed_max) + return seed + + +def stat_array(array: Union[np.ndarray, List[int], 'torch.Tensor']) -> Tuple[Dict[str, float], str]: + if isinstance(array, list): + array = np.array(array) + mean = array.mean().item() + std = array.std().item() + min_ = array.min().item() + max_ = array.max().item() + size = array.shape[0] + string = f'{mean:.6f}±{std:.6f}, min={min_:.6f}, max={max_:.6f}, size={size}' + return {'mean': mean, 'std': std, 'min': min_, 'max': max_, 'size': size}, string diff --git a/ms-swift/swift/utils/torchacc_utils.py b/ms-swift/swift/utils/torchacc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1084303c236365d2cfddaa204e8964769e719b88 --- /dev/null +++ b/ms-swift/swift/utils/torchacc_utils.py @@ -0,0 +1,917 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import sys +import types +from typing import List, Optional, Tuple + +import safetensors +import torch +import torch.nn.functional as F +import transformers +from packaging import version +from peft import PeftModel +from torch.utils.data import DataLoader +from transformers import PreTrainedModel, trainer +from transformers.modeling_utils import unwrap_model + +from swift.utils import get_logger, torchacc_trim_graph, use_torchacc + +logger = get_logger() + + +# DataLoader +def get_bucket_sizes(max_length: int) -> List[int]: + """Get the bucket sizes for TorchAcc. + You can set the environment variable TORCHACC_DATA_BUCKETS to specify + the bucket sizes. If not set, we use a normal distribution bucketing with + 8 buckets. + """ + padding_p_base = 2 + if os.getenv('TORCHACC_DATA_BUCKETS') is not None: + bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')] + bucket_sizes.append(max_length) + else: + if os.getenv('TORCHACC_CACHE_PATH') is not None: # padding strategy when persistent cache is enabled + padding_p_base = 1.4 + padding_p_base = os.getenv('TORCHACC_PADDING_P_BASE', padding_p_base) + try: + padding_p_base = float(padding_p_base) + except ValueError as e: + logger.error(f'Expect TORCHACC_PADDINF_P_BASE to be a float number, but encountered {padding_p_base}') + raise e + bucket_sizes = [16, 32, 48, 64, 96, 128] + base_size = 256 + while base_size < max_length: + bucket_sizes.append((int(base_size) + 127) // 128 * 128) + base_size *= padding_p_base + bucket_sizes.append(max_length) + + return bucket_sizes + + +def _get_closet_bucket(bucket_sizes, data_length): + """Select the one from bucket_sizes that is closest in distance to + data_length. This is required for TorchAcc. + """ + closest_length = sys.maxsize + for b in bucket_sizes: + if b == data_length or ((b < closest_length) and (b > data_length)): + closest_length = b + + if closest_length == sys.maxsize: + bucket_sizes.append(data_length) + closest_length = data_length + + return closest_length + + +def pad_and_split_batch(padding_to, input_ids, attention_mask, labels, loss_scale, max_length, tokenizer, rank, + world_size, padding_right): + if padding_to is None: + longest_len = input_ids.shape[-1] + bucket_sizes = get_bucket_sizes(max_length) + bucket_data_length = _get_closet_bucket(bucket_sizes, longest_len) + padding_length = bucket_data_length - input_ids.shape[1] + pad_tuple = (0, padding_length) if padding_right else (padding_length, 0) + input_ids = F.pad(input_ids, pad_tuple, 'constant', tokenizer.pad_token_id) + attention_mask = F.pad(attention_mask, pad_tuple, 'constant', 0) + if loss_scale: + loss_scale = F.pad(loss_scale, pad_tuple, 'constant', 0.) + labels = F.pad(labels, pad_tuple, 'constant', -100) + + # manually split the batch to different DP rank. + batch_size = input_ids.shape[0] // world_size + if batch_size > 0: + start = rank * batch_size + end = (rank + 1) * batch_size + input_ids = input_ids[start:end, :] + attention_mask = attention_mask[start:end, :] + labels = labels[start:end, :] + if loss_scale: + loss_scale = loss_scale[start:end, :] + return input_ids, attention_mask, labels, loss_scale + + +def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size): + # patch skip_first_batches for customized dataloader. + def acc_skip_first_batches(dataloader, num_batches=0): + from accelerate.data_loader import SkipBatchSampler + batch_sampler = SkipBatchSampler(dataloader._loader.batch_sampler, skip_batches=num_batches) + try: + dataset = dataloader.dataset + except AttributeError: + dataset = dataloader._loader.dataset + dataloader_params = { + 'collate_fn': data_collator, + 'num_workers': args.dataloader_num_workers, + 'pin_memory': args.dataloader_pin_memory, + 'persistent_workers': args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params['batch_sampler'] = batch_sampler + dataloader_params['worker_init_fn'] = trainer.seed_worker + + return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device) + + trainer.skip_first_batches = acc_skip_first_batches + + # dataloader for TorchAcc. + import torchacc as ta + + dataloader_params = { + 'batch_size': batch_size, + 'collate_fn': data_collator, + 'num_workers': args.dataloader_num_workers, + 'pin_memory': args.dataloader_pin_memory, + 'persistent_workers': args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params['sampler'] = sampler + dataloader_params['drop_last'] = args.dataloader_drop_last + dataloader_params['worker_init_fn'] = trainer.seed_worker + + return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device) + + +def ta_eval_dataloader(eval_dataset, data_collator, sampler, args): + import torchacc as ta + + dataloader_params = { + 'batch_size': args.eval_batch_size, + 'collate_fn': data_collator, + 'num_workers': args.dataloader_num_workers, + 'pin_memory': args.dataloader_pin_memory, + 'persistent_workers': args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params['sampler'] = sampler + dataloader_params['drop_last'] = args.dataloader_drop_last + + return ta.AsyncLoader(DataLoader(eval_dataset, **dataloader_params), args.device) + + +def ta_test_dataloader(test_dataset, data_collator, sampler, args): + import torchacc as ta + + dataloader_params = { + 'batch_size': args.eval_batch_size, + 'collate_fn': data_collator, + 'num_workers': args.dataloader_num_workers, + 'pin_memory': args.dataloader_pin_memory, + 'persistent_workers': args.dataloader_persistent_workers, + } + + if not isinstance(test_dataset, torch.utils.data.IterableDataset): + dataloader_params['sampler'] = sampler + dataloader_params['drop_last'] = args.dataloader_drop_last + + # We use the same batch_size as for eval. + return ta.AsyncLoader(DataLoader(test_dataset, **dataloader_params), args.device) + + +# Save/load checkpoint +def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir): + import torch_xla.core.xla_model as xm + xm.rendezvous('saving_optimizer_states') + xm.save(optimizer.state_dict(), os.path.join(output_dir, f'optimizer_{xm.get_ordinal()}.pt'), master_only=False) + xm.save(lr_scheduler.state_dict(), os.path.join(output_dir, f'scheduler_{xm.get_ordinal()}.pt'), master_only=False) + xm.rendezvous('saving_optimizer_states_done') + + +def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device): + import torch_xla.core.xla_model as xm + optimizer_state = torch.load(os.path.join(checkpoint, f'optimizer_{xm.get_ordinal()}.pt'), map_location='cpu') + lr_scheduler_state = torch.load(os.path.join(checkpoint, f'scheduler_{xm.get_ordinal()}.pt'), map_location='cpu') + xm.send_cpu_data_to_device(optimizer_state, device) + xm.send_cpu_data_to_device(lr_scheduler_state, device) + + optimizer.load_state_dict(optimizer_state) + lr_scheduler.load_state_dict(lr_scheduler_state) + return optimizer, lr_scheduler + + +def save_ta_ddp_checkpoint(self_model, tokenizer, args, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else args.output_dir + import torch_xla.core.xla_model as xm + + model = self_model + + if xm.is_master_ordinal(local=False): + os.makedirs(output_dir, exist_ok=True) + torch.save(args, os.path.join(output_dir, 'training_args.bin')) + + xm.mark_step() + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + supported_classes = (PreTrainedModel, PeftModel) + if not isinstance(model, supported_classes): + if isinstance(unwrap_model(model), supported_classes): + unwrap_model(model).save_pretrained( + output_dir, + is_main_process=args.should_save, + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), + save_function=xm.save, + safe_serialization=args.save_safetensors, + ) + else: + logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.') + state_dict = xm._maybe_convert_to_cpu(model.state_dict()) + if args.save_safetensors: + safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors')) + else: + torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin')) + else: + model.save_pretrained( + output_dir, + is_main_process=args.should_save, + save_function=xm.save, + safe_serialization=args.save_safetensors, + state_dict=xm._maybe_convert_to_cpu(model.state_dict())) + if tokenizer is not None and args.should_save: + tokenizer.save_pretrained(output_dir) + + +def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir): + import torch_xla.core.xla_model as xm + from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints + + xm.mark_step() + + if xm.is_master_ordinal(local=False): + os.makedirs(output_dir, exist_ok=True) + torch.save(args, os.path.join(output_dir, 'training_args.bin')) + + supported_classes = (PreTrainedModel, PeftModel) + model = self_model._get_underlay_model().module.module + unwrapped_model = unwrap_model(model) + + xm.rendezvous('saving_checkpoint') + ckpt = { + 'model': self_model._get_underlay_model().state_dict(), + 'shard_metadata': self_model._get_underlay_model().get_shard_metadata(), + } + if isinstance(model, PeftModel): + ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-adapter_model.bin') + else: + ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-pytorch_model.bin') + xm.save(ckpt, ckpt_path, master_only=False) + # Make sure all ranks have saved checkpoints + xm.rendezvous('save_full_checkpoints') + + if tokenizer is not None and args.should_save: + tokenizer.save_pretrained(output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save) + + # rank 0 consolidates and saves the whole checkpoint. + if xm.is_master_ordinal(local=False): + if isinstance(model, PeftModel): + ckpt_suffix = 'rank*-of-*-adapter_model.bin' + else: + ckpt_suffix = 'rank*-of-*-pytorch_model.bin' + full_state_dict, _ = consolidate_sharded_model_checkpoints( + ckpt_prefix=os.path.join(output_dir, ''), ckpt_suffix=ckpt_suffix, save_model=False) + + if isinstance(unwrapped_model, supported_classes): + unwrapped_model.save_pretrained( + output_dir, + state_dict=full_state_dict, + save_function=xm.save, + safe_serialization=args.save_safetensors, + ) + else: + logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.') + if args.save_safetensors: + safetensors.torch.save_file(full_state_dict, os.path.join(output_dir, 'model.safetensors')) + else: + torch.save(full_state_dict, os.path.join(output_dir, 'pytorch_model.bin')) + + xm.rendezvous('ckpt_consolidation') + # delete the sharded checkpoint. + os.remove(ckpt_path) + + +def ta_trim_graph(): + if use_torchacc() and torchacc_trim_graph(): + import torchacc as ta + ta.mark_step() + + +# Model patch +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + if position_ids is not None: + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + else: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def patch_acc_model(args, model): + if not args.use_flash_attn: + logger.warn('Currently use flash attn for torchacc.') + if args.model_type.startswith('qwen1half') or args.model_type.startswith('qwen2'): + model = patch_qwen2_model(model) + elif args.model_type.startswith('qwen'): + import torchacc as ta + model = ta.patch_qwen_model(model) + elif args.model_type.startswith('baichuan'): + model = patch_baichuan_model(model) + elif args.model_type.startswith('llama') or args.model_type.startswith('yi'): + model = patch_llama_model(model) + elif args.model_type.startswith('chatglm'): + model = patah_chatglm_model(model) + return model + + +def patch_llama_model(model): + + def update_causal_mask(self, *args, **kwargs): + # attention_mask is not supported in TorchAcc. + return None + + def llama_attn_forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + from torchacc.ops import flash_attn_varlen_xla + import einops + + bsz, q_len, _ = hidden_states.size() + + query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + key_states = ( + self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)) + value_states = ( + self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, 'past_key_value is not supported' + + if version.parse(transformers.__version__) >= version.parse('4.36'): + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + assert not output_attentions, 'output_attentions is not supported' + + if past_key_value is not None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = (key_states, value_states) if use_cache else None + + # See https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + # if attention_mask is not None: + # value_states = value_states * attention_mask.unsqueeze(1).unsqueeze(-1) + q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...') + k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...') + v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...') + max_s = q_len + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) + output = flash_attn_varlen_xla( + q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True) + output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz) + + return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value + + for layer in model.model.layers: + layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn) + + if version.parse(transformers.__version__) >= version.parse('4.38'): + model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model) + + return model + + +def patah_chatglm_model(model): + + def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + def chatglm_attn_forward(self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + **kwargs): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head)) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) + key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) + value_layer = value_layer.contiguous().view(value_layer.size()[:2] + + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head)) + + # ================================== + # core attention computation + # ================================== + + from torchacc.ops import flash_attn_varlen_qkvpacked_xla + import einops + + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + bsz, _, q_len, _ = query_layer.size() + qkv = torch.stack([query_layer, key_layer, value_layer], dim=2) + qkv = qkv.transpose(1, 3) + qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...') + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) + context_layer = flash_attn_varlen_qkvpacked_xla( + qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True) + context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz) + context_layer = context_layer.permute(1, 0, 2, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, ) + context_layer = context_layer.reshape(*new_context_layer_shape) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + def torchacc_swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]).to(x[0].dtype) * x[1] + + # patch attention + for layer in model.transformer.encoder.layers: + layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention) + layer.mlp.activation_func = torchacc_swiglu + + return model + + +def patch_baichuan_model(model): + + def baichuan_attn_forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + import einops + + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)) + query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + from torchacc.ops import flash_attn_varlen_xla + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]] + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) + output = flash_attn_varlen_xla( + q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True) + output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz) + output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')) + return output, None, past_key_value + + for layer in model.base_model.layers: + layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn) + + return model + + +def patch_qwen2_model(model): + + def update_causal_mask(self, *args, **kwargs): + # attention_mask is not supported in TorchAcc. + return None + + def qwen2_attn_forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + output_attentions=False, + use_cache=False, + cache_position=None, + position_embeddings=None, + **kwargs, + ): + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.') + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + rotary_seq_len = kv_seq_len + 1 + + if version.parse(transformers.__version__) >= version.parse('4.45'): + if position_embeddings is None: + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + else: + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reshape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + from torchacc.ops import flash_attn_varlen_xla + import einops + + q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]] + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) + + attn_output = flash_attn_varlen_xla( + q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def qwen2_forward(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds') + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + from transformers.modeling_outputs import BaseModelOutputWithPast + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + for layer in model.model.layers: + layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn) + + if version.parse(transformers.__version__) >= version.parse('4.43'): + model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model) + else: + model.model.forward = types.MethodType(qwen2_forward, model.model) + return model + + +def patch_clip_grad_norm(accelerator): + from accelerate.utils import DistributedType + from accelerate.optimizer import AcceleratedOptimizer + import torch_xla.core.xla_model as xm + + def clip_grad_norm_(self, parameters, max_norm, norm_type=2): + """ + Should be used in place of `torch.nn.utils.clip_grad_norm_`. + + Returns: + `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector). + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(gradient_accumulation_steps=2) + >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler) + + >>> for input, target in dataloader: + ... optimizer.zero_grad() + ... output = model(input) + ... loss = loss_func(output, target) + ... accelerator.backward(loss) + ... if accelerator.sync_gradients: + ... accelerator.clip_grad_norm_(model.parameters(), max_grad_norm) + ... optimizer.step() + ``` + """ + if self.distributed_type == DistributedType.FSDP: + self.unscale_gradients() + parameters = [p for p in parameters] + for model in self._models: + if parameters == [p for p in model.parameters()]: + return model.clip_grad_norm_(max_norm, norm_type) + elif self.distributed_type == DistributedType.DEEPSPEED: + # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + # We cannot return the gradient norm because DeepSpeed does it. + return None + elif self.distributed_type == DistributedType.XLA: + # Reduce gradients first for XLA + for acc_opt in self._optimizers: + if not acc_opt.gradient_state.is_xla_gradients_synced: + opt = acc_opt + while isinstance(opt, AcceleratedOptimizer): + opt = opt.optimizer + gradients = xm._fetch_gradients(opt) + # Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor + # one by one in self.reduce is non-inplace. + xm.all_reduce('sum', gradients, scale=1.0 / self.num_processes) + # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step. + acc_opt.gradient_state.is_xla_gradients_synced = True + if os.environ.get('ACCELERATE_USE_FSDP', 'false') == 'true': + self.unscale_gradients() + parameters = [p for p in parameters] + for model in self._models: + if parameters == [p for p in model.parameters()]: + return model._get_underlay_model().clip_grad_norm_(max_norm, norm_type) + self.unscale_gradients() + return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) + + # TODO(baole): This should be removed once accelerate is updated. + accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator) + return accelerator + + +def ta_accelerate(model, + fsdp_num, + layer_cls_name, + bf16=True, + fp16=False, + gradient_checkpointing=True, + fsdp_flatten_parameters=False): + """ accelerate LLM training using TorchAcc(only available internally). + """ + import torchacc as ta + assert layer_cls_name is not None + + def get_ta_config(): + config = ta.Config() + config.compute.fp16 = fp16 + config.compute.bf16 = bf16 + + config.memory.gc = gradient_checkpointing + if config.memory.gc: + config.memory.gc_cls = {layer_cls_name} + + config.dist.fsdp.size = fsdp_num + config.dist.fsdp.wrap_layer_cls = {layer_cls_name} + config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters + config.dist.dp.size = 1 + + if fsdp_num > 1: + os.environ['ACCELERATE_USE_FSDP'] = 'true' + + return config + + ta_config = get_ta_config() + model = ta.accelerate(model, config=ta_config) + return model diff --git a/ms-swift/tests/eval/test_eval.py b/ms-swift/tests/eval/test_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..596ebb5b225f06bcc25dc672099343032620d149 --- /dev/null +++ b/ms-swift/tests/eval/test_eval.py @@ -0,0 +1,66 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +infer_backend = 'vllm' + + +def test_eval_native(): + from swift.llm import EvalArguments, eval_main + eval_main( + EvalArguments( + model='Qwen/Qwen2.5-0.5B-Instruct', + eval_dataset='arc', + infer_backend=infer_backend, + eval_backend='Native', + eval_limit=10, + eval_generation_config={ + 'max_new_tokens': 128, + 'temperature': 0.1 + }, + extra_eval_args={ + 'stream': True, + 'ignore_errors': True + }, + )) + + +def test_eval_llm(): + from swift.llm import EvalArguments, eval_main + eval_main( + EvalArguments( + model='Qwen/Qwen2-7B-Instruct', + eval_dataset='arc_c', + infer_backend=infer_backend, + eval_backend='OpenCompass', + eval_limit=10)) + + +def test_eval_mllm(): + from swift.llm import EvalArguments, eval_main + eval_main( + EvalArguments( + model='Qwen/Qwen2.5-VL-3B-Instruct', + eval_dataset=['realWorldQA'], + infer_backend='pt', + eval_backend='VLMEvalKit', + eval_limit=10, + eval_generation_config={ + 'max_new_tokens': 128, + 'temperature': 0.1 + })) + + +def test_eval_url(): + from swift.llm import EvalArguments, eval_main, DeployArguments, run_deploy + deploy_args = DeployArguments(model='Qwen/Qwen2-VL-7B-Instruct', infer_backend=infer_backend, verbose=False) + + with run_deploy(deploy_args, return_url=True) as url: + eval_main(EvalArguments(model='Qwen2-VL-7B-Instruct', eval_url=url, eval_dataset=['arc_c'])) + + +if __name__ == '__main__': + # test_eval_llm() + test_eval_mllm() + # test_eval_url() + # test_eval_native() diff --git a/ms-swift/tests/export/test_quant.py b/ms-swift/tests/export/test_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..d1be16459d0e9217ba2f26bdaff2ec9aa81634fa --- /dev/null +++ b/ms-swift/tests/export/test_quant.py @@ -0,0 +1,69 @@ +import os +from typing import Literal + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def test_llm_quant(quant_method: Literal['gptq', 'awq'] = 'awq'): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + model='Qwen/Qwen2-7B-Instruct', + quant_bits=4, + dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000'], + quant_method=quant_method)) + + +def test_vlm_quant(quant_method: Literal['gptq', 'awq'] = 'awq'): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + model='Qwen/Qwen2-VL-7B-Instruct', + quant_bits=4, + dataset=['modelscope/coco_2014_caption:validation#1000'], + quant_method=quant_method)) + + +def test_audio_quant(quant_method: Literal['gptq', 'awq'] = 'awq'): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + model='Qwen/Qwen2-Audio-7B-Instruct', + quant_bits=4, + dataset=['speech_asr/speech_asr_aishell1_trainsets:validation#1000'], + quant_method=quant_method)) + + +def test_vlm_bnb_quant(): + from swift.llm import export_main, ExportArguments, infer_main, InferArguments + export_main(ExportArguments(model='Qwen/Qwen2-VL-7B-Instruct', quant_bits=4, quant_method='bnb')) + + # infer_main(InferArguments(ckpt_dir='Qwen/Qwen2-VL-7B-Instruct-bnb-int4')) + + +def test_bert(): + from swift.llm import export_main, ExportArguments + output_dir = 'output/swift_test_bert_merged' + export_main(ExportArguments(adapters='swift/test_bert', merge_lora=True, output_dir=output_dir)) + export_main( + ExportArguments(model=output_dir, load_data_args=True, quant_bits=4, quant_method='gptq', max_length=512)) + + +def test_reward_model(): + from swift.llm import export_main, ExportArguments + + export_main( + ExportArguments( + model='Shanghai_AI_Laboratory/internlm2-1_8b-reward', + dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000'], + quant_bits=4, + quant_method='gptq')) + + +if __name__ == '__main__': + # test_llm_quant('gptq') + # test_vlm_quant('gptq') + # test_audio_quant('gptq') + # test_vlm_bnb_quant() + # test_bert() + test_reward_model() diff --git a/ms-swift/tests/general/test_arch.py b/ms-swift/tests/general/test_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..bea7f3cd9f1e245be52a4e8e6886c3e3f9b8e24a --- /dev/null +++ b/ms-swift/tests/general/test_arch.py @@ -0,0 +1,44 @@ +def test_model_arch(): + from swift.llm import MODEL_MAPPING, safe_snapshot_download + from transformers import PretrainedConfig + from swift.utils import JsonlWriter + import random + jsonl_writer = JsonlWriter('model_arch.jsonl') + for i, (model_type, model_meta) in enumerate(MODEL_MAPPING.items()): + if i < 0: + continue + arch_list = model_meta.architectures + for model_group in model_meta.model_groups: + model = random.choice(model_group.models).ms_model_id + config_dict = None + try: + model_dir = safe_snapshot_download(model, download_model=False) + config_dict = PretrainedConfig.get_config_dict(model_dir)[0] + except Exception: + pass + finally: + msg = None + if config_dict: + arch = config_dict.get('architectures') + if arch and arch[0] not in arch_list: + msg = { + 'model_type': model_type, + 'model': model, + 'config_arch': arch, + 'architectures': arch_list + } + elif not arch and arch_list: + msg = { + 'model_type': model_type, + 'model': model, + 'config_arch': arch, + 'architectures': arch_list + } + else: + msg = {'msg': 'error', 'model_type': model_type, 'model': model, 'arch_list': arch_list} + if msg: + jsonl_writer.append(msg) + + +if __name__ == '__main__': + test_model_arch() diff --git a/ms-swift/tests/general/test_dataset.py b/ms-swift/tests/general/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..371401fbecc58fddfcab0338d0c88bbb716daee4 --- /dev/null +++ b/ms-swift/tests/general/test_dataset.py @@ -0,0 +1,90 @@ +from typing import List + +from swift.llm import load_dataset + + +def _test_dataset(datasets: List[str], num_proc: int = 1, strict: bool = False, **kwargs): + dataset = load_dataset(datasets, num_proc=num_proc, strict=strict, **kwargs) + print(f'dataset[0]: {dataset[0]}') + print(f'dataset[1]: {dataset[1]}') + + +def test_sft(): + # swift/SlimOrca swift/cosmopedia-100k + # _test_dataset(['lvjianjin/AdvertiseGen']) + # _test_dataset(['AI-ModelScope/Duet-v0.5']) + # _test_dataset(['swift/SlimOrca', 'swift/cosmopedia-100k']) + # _test_dataset(['OmniData/Zhihu-KOL-More-Than-100-Upvotes']) + # _test_dataset(['OmniData/Zhihu-KOL']) + _test_dataset([ + 'AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000', + 'AI-ModelScope/LongAlpaca-12k#1000' + ]) + # _test_dataset(['swift/Infinity-Instruct:all']) + # _test_dataset(['swift/sharegpt:all']) + # _test_dataset(['AI-ModelScope/sharegpt_gpt4:all']) + # _test_dataset(['iic/ms_bench']) + # _test_dataset(['swift/tagengo-gpt4']) + + +def test_mllm(): + # _test_dataset(['AI-ModelScope/ShareGPT4V:all']) + # _test_dataset(['AI-ModelScope/LLaVA-Pretrain']) + # _test_dataset(['swift/TextCaps']) + # _test_dataset(['swift/RLAIF-V-Dataset:all']) + # _test_dataset(['swift/OK-VQA_train']) + # _test_dataset(['swift/OCR-VQA']) + # _test_dataset(['swift/A-OKVQA']) + # _test_dataset(['AI-ModelScope/MovieChat-1K-test']) + _test_dataset([ + 'AI-ModelScope/LaTeX_OCR:all', 'modelscope/coco_2014_caption:validation', + 'speech_asr/speech_asr_aishell1_trainsets:validation' + ], + strict=False) + # _test_dataset(['swift/VideoChatGPT:all']) + # _test_dataset(['speech_asr/speech_asr_aishell1_trainsets:validation']) + # _test_dataset(['AI-ModelScope/captcha-images']) + # _test_dataset(['swift/gpt4v-dataset:all']) + # _test_dataset(['modelscope/coco_2014_caption:validation']) + # _test_dataset(['AI-ModelScope/LLaVA-Instruct-150K'], num_proc=16) + + +def test_agent(): + _test_dataset(['swift/ToolBench']) + # _test_dataset(['AI-ModelScope/ms_agent_for_agentfabric:all']) + + +def test_dpo(): + _test_dataset(['AI-ModelScope/orpo-dpo-mix-40k']) + _test_dataset(['AI-ModelScope/hh-rlhf:all']) + _test_dataset(['AI-ModelScope/hh_rlhf_cn:all']) + _test_dataset(['hjh0119/shareAI-Llama3-DPO-zh-en-emoji:all']) + + +def test_kto(): + _test_dataset(['AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto']) + + +def test_pretrain(): + _test_dataset(['AI-ModelScope/ruozhiba:all']) + + +def test_dataset_info(): + _test_dataset(['swift/self-cognition#500'], model_name='xiao huang', model_author='swift') + # _test_dataset(['codefuse-ai/CodeExercise-Python-27k']) + + +def test_cls(): + _test_dataset(['simpleai/HC3-Chinese:baike']) + _test_dataset(['simpleai/HC3-Chinese:baike_cls']) + + +if __name__ == '__main__': + # test_sft() + # test_agent() + # test_dpo() + # test_kto() + test_mllm() + # test_pretrain() + # test_dataset_info() + # test_cls() diff --git a/ms-swift/tests/general/test_model.py b/ms-swift/tests/general/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2558f9fd08cb57541d3732ab6767367f57a734 --- /dev/null +++ b/ms-swift/tests/general/test_model.py @@ -0,0 +1,30 @@ +import os + +import torch + +from swift.utils import get_device + +os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + + +def test_qwen2(): + import os + from swift.llm import get_model_tokenizer + model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False) + print(f'model: {model}, tokenizer: {tokenizer}') + # test hf + model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False, use_hf=True) + + model, tokenizer = get_model_tokenizer( + 'Qwen/Qwen2-7B-Instruct', torch.float32, device_map=get_device(), attn_impl='flash_attn') + print(f'model: {model}, tokenizer: {tokenizer}') + + +def test_modelscope_hub(): + from swift.llm import get_model_tokenizer + model, tokenizer = get_model_tokenizer('Qwen/Qwen2___5-Math-1___5B-Instruct/', load_model=False) + + +if __name__ == '__main__': + test_qwen2() + # test_modelscope_hub() diff --git a/ms-swift/tests/general/test_stream.py b/ms-swift/tests/general/test_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..ad206962270ba09064a5568295072d1d0171204d --- /dev/null +++ b/ms-swift/tests/general/test_stream.py @@ -0,0 +1,20 @@ +from swift.llm import load_dataset + + +def test_local_dataset(): + # please use git clone + from swift.llm import git_clone_github + model_dir = git_clone_github('https://www.modelscope.cn/datasets/swift/swift-sft-mixture.git') + dataset = load_dataset(datasets=[f'{model_dir}:firefly'], streaming=True)[0] + print(next(iter(dataset))) + + +def test_hub_dataset(): + local_dataset = 'swift/swift-sft-mixture:firefly' + dataset = load_dataset(datasets=[local_dataset], streaming=True)[0] + print(next(iter(dataset))) + + +if __name__ == '__main__': + test_local_dataset() + # test_hub_dataset() diff --git a/ms-swift/tests/general/test_template.py b/ms-swift/tests/general/test_template.py new file mode 100644 index 0000000000000000000000000000000000000000..e447f9c620d716942e0c0dbfe20ffbdd2eebeb93 --- /dev/null +++ b/ms-swift/tests/general/test_template.py @@ -0,0 +1,74 @@ +from datasets import Dataset + +from swift.llm import EncodePreprocessor, TemplateInputs, get_model_tokenizer, get_template, load_dataset + + +def test_template(): + _, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False) + template = get_template(tokenizer.model_meta.template, tokenizer) + template_inputs = TemplateInputs([{ + 'role': 'system', + 'content': 'AAA' + }, { + 'role': 'user', + 'content': 'BBB' + }, { + 'role': 'assistant', + 'content': 'CCC' + }, { + 'role': 'user', + 'content': 'DDD' + }]) + inputs = template.encode(template_inputs) + print(f'inputs.keys(): {inputs.keys()}') + print(tokenizer.decode(inputs['input_ids'])) + + +def test_mllm(): + _, tokenizer = get_model_tokenizer('Qwen/Qwen2-VL-7B-Instruct', load_model=False) + template = get_template(tokenizer.model_meta.template, tokenizer) + template_inputs = TemplateInputs([{ + 'role': 'system', + 'content': 'AAA' + }, { + 'role': 'user', + 'content': 'BBB' + }, { + 'role': 'assistant', + 'content': 'CCC' + }, { + 'role': 'user', + 'content': 'DDD' + }], + images=['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png']) + inputs = template.encode(template_inputs) + print(f'inputs.keys(): {inputs.keys()}') + print(template.safe_decode(inputs['input_ids'])) + + +def _test_dataset_map(model_id: str, dataset_id: str): + _, tokenizer = get_model_tokenizer(model_id, load_model=False) + template = get_template(tokenizer.model_meta.template, tokenizer) + dataset = load_dataset([dataset_id], num_proc=2)[0] + + # 1: 1500 + # 16: 10766.36 examples/s + new_dataset = EncodePreprocessor(template)(dataset, num_proc=4) + print(f'new_dataset: {new_dataset}') + print(template.safe_decode(new_dataset[0]['input_ids'])) + print(template.safe_decode(new_dataset[1]['input_ids'])) + + +def test_llm_dataset_map(): + _test_dataset_map('Qwen/Qwen2-7B-Instruct', 'AI-ModelScope/alpaca-gpt4-data-zh') + + +def test_mllm_dataset_map(): + _test_dataset_map('Qwen/Qwen2-VL-7B-Instruct', 'modelscope/coco_2014_caption:validation#100') + + +if __name__ == '__main__': + # test_template() + # test_mllm() + # test_llm_dataset_map() + test_mllm_dataset_map() diff --git a/ms-swift/tests/hub/__init__.py b/ms-swift/tests/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ms-swift/tests/hub/test_check_model.py b/ms-swift/tests/hub/test_check_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f929ed7bb70422ef3b841847e4803a3c6f85bb --- /dev/null +++ b/ms-swift/tests/hub/test_check_model.py @@ -0,0 +1,24 @@ +import os +import shutil +import tempfile +import unittest + +from modelscope import Model, check_local_model_is_latest + + +class TestCheckModel(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + import peft + shutil.rmtree(self.tmp_dir) + super().tearDown() + + def test_check_model(self): + model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base', revision='v1.0.0') + self.assertFalse(check_local_model_is_latest(model.model_dir)) diff --git a/ms-swift/tests/infer/test_infer.py b/ms-swift/tests/infer/test_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..b6546035f58a126af05778a46967c41179f5f873 --- /dev/null +++ b/ms-swift/tests/infer/test_infer.py @@ -0,0 +1,73 @@ +import os +from typing import Literal + +import torch + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']): + from swift.llm import InferRequest, get_template + if infer_backend == 'lmdeploy': + from swift.llm import LmdeployEngine + engine = LmdeployEngine('OpenGVLab/InternVL2_5-2B', torch.float32) + elif infer_backend == 'pt': + from swift.llm import PtEngine + engine = PtEngine('Qwen/Qwen2-7B-Instruct', max_batch_size=16) + elif infer_backend == 'vllm': + from swift.llm import VllmEngine + engine = VllmEngine('Qwen/Qwen2-7B-Instruct') + template = get_template(engine.model_meta.template, engine.tokenizer) + infer_requests = [ + # InferRequest([{'role': 'user', 'content': '晚上睡不着觉怎么办'}]) for i in range(100) + InferRequest([{ + 'role': 'user', + 'content': 'hello! who are you' + }]) for i in range(100) + ] + return engine, template, infer_requests + + +def test_infer(infer_backend): + from swift.llm import RequestConfig + from swift.plugin import InferStats + engine, template, infer_requests = _prepare(infer_backend=infer_backend) + request_config = RequestConfig(temperature=0) + infer_stats = InferStats() + + response_list = engine.infer( + infer_requests, template=template, request_config=request_config, metrics=[infer_stats]) + + for response in response_list[:2]: + print(response.choices[0].message.content) + print(infer_stats.compute()) + + +def test_stream(infer_backend): + from swift.llm import RequestConfig + from swift.plugin import InferStats + engine, template, infer_requests = _prepare(infer_backend=infer_backend) + infer_stats = InferStats() + request_config = RequestConfig(temperature=0, stream=True, logprobs=True) + + gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats]) + + for response in gen_list[0]: + if response is None: + continue + print(response.choices[0].delta.content, end='', flush=True) + print() + print(infer_stats.compute()) + + gen_list = engine.infer( + infer_requests, template=template, request_config=request_config, use_tqdm=True, metrics=[infer_stats]) + + for response in gen_list[0]: + pass + + print(infer_stats.compute()) + + +if __name__ == '__main__': + test_infer('pt') + # test_stream('pt') diff --git a/ms-swift/tests/infer/test_logprobs.py b/ms-swift/tests/infer/test_logprobs.py new file mode 100644 index 0000000000000000000000000000000000000000..c24add93a068fce6cbeac0e58d9608cdb07d2a44 --- /dev/null +++ b/ms-swift/tests/infer/test_logprobs.py @@ -0,0 +1,71 @@ +import os +from typing import Literal + +import torch + +if __name__ == '__main__': + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']): + from swift.llm import InferRequest, get_template + + if infer_backend == 'lmdeploy': + from swift.llm import LmdeployEngine + engine = LmdeployEngine('Qwen/Qwen2-7B-Instruct', torch.float32) + elif infer_backend == 'pt': + from swift.llm import PtEngine + engine = PtEngine('Qwen/Qwen2-7B-Instruct') + elif infer_backend == 'vllm': + from swift.llm import VllmEngine + engine = VllmEngine('Qwen/Qwen2-7B-Instruct') + template = get_template(engine.model_meta.template, engine.tokenizer) + infer_requests = [ + InferRequest([{ + 'role': 'user', + 'content': '晚上睡不着觉怎么办' + }]), + InferRequest([{ + 'role': 'user', + 'content': 'hello! who are you' + }]) + ] + return engine, template, infer_requests + + +def test_infer(engine, template, infer_requests): + from swift.llm import RequestConfig + from swift.plugin import InferStats + + request_config = RequestConfig(temperature=0, logprobs=True, top_logprobs=2) + infer_stats = InferStats() + + response_list = engine.infer( + infer_requests, template=template, request_config=request_config, metrics=[infer_stats]) + + for response in response_list[:2]: + print(response.choices[0].message.content) + print(infer_stats.compute()) + + +def test_stream(engine, template, infer_requests): + from swift.llm import RequestConfig + from swift.plugin import InferStats + + infer_stats = InferStats() + request_config = RequestConfig(temperature=0, stream=True, logprobs=True, top_logprobs=2) + + gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats]) + + for response in gen_list[0]: + if response is None: + continue + print(response.choices[0].delta.content, end='', flush=True) + + print(infer_stats.compute()) + + +if __name__ == '__main__': + engine, template, infer_requests = _prepare(infer_backend='pt') + test_infer(engine, template, infer_requests) + test_stream(engine, template, infer_requests) diff --git a/ms-swift/tests/infer/test_main.py b/ms-swift/tests/infer/test_main.py new file mode 100644 index 0000000000000000000000000000000000000000..a145ae73c008503cfd33edc4c8b0433399879856 --- /dev/null +++ b/ms-swift/tests/infer/test_main.py @@ -0,0 +1,73 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def test_cli(infer_backend): + from swift.llm import infer_main, InferArguments + args = InferArguments(model='Qwen/Qwen2-VL-7B-Instruct', infer_backend=infer_backend) + infer_main(args) + + +def test_cli_jinja(infer_backend): + from swift.llm import infer_main, InferArguments + args = InferArguments(model='Qwen/Qwen2-VL-7B-Instruct', infer_backend=infer_backend, template_backend='jinja') + infer_main(args) + + +def test_dataset(infer_backend): + from swift.llm import infer_main, InferArguments + args = InferArguments( + model='Qwen/Qwen2-7B-Instruct', + infer_backend=infer_backend, + val_dataset=['AI-ModelScope/alpaca-gpt4-data-zh#10'], + stream=True) + infer_main(args) + + +def test_mllm_dataset(infer_backend): + from swift.llm import infer_main, InferArguments + args = InferArguments( + model='Qwen/Qwen2-VL-7B-Instruct', + infer_backend=infer_backend, + val_dataset=['modelscope/coco_2014_caption:validation#1000'], + stream=True) + infer_main(args) + + +def test_dataset_ddp(): + os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' + from swift.llm import infer_main, InferArguments + args = InferArguments( + model='Qwen/Qwen2-7B-Instruct', max_batch_size=64, val_dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000']) + infer_main(args) + + +def test_dataset_mp_ddp(): + os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' + from swift.llm import infer_main, InferArguments + args = InferArguments( + model='Qwen/Qwen2-7B-Instruct', max_batch_size=64, val_dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000']) + infer_main(args) + + +def test_emu3_gen(infer_backend): + from swift.llm import infer_main, InferArguments + args = InferArguments( + model='BAAI/Emu3-Gen', + infer_backend=infer_backend, + stream=False, + use_chat_template=False, + top_k=2048, + max_new_tokens=40960) + infer_main(args) + + +if __name__ == '__main__': + # test_cli('pt') + # test_cli_jinja('pt') + # test_dataset('pt') + # test_mllm_dataset('pt') + # test_dataset_ddp() + # test_dataset_mp_ddp() + test_emu3_gen('pt') diff --git a/ms-swift/tests/infer/test_mllm.py b/ms-swift/tests/infer/test_mllm.py new file mode 100644 index 0000000000000000000000000000000000000000..9958d659da06c63705d457dff8e536363a0f1aa6 --- /dev/null +++ b/ms-swift/tests/infer/test_mllm.py @@ -0,0 +1,79 @@ +import os +from typing import Literal + +import torch + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']): + from swift.llm import InferRequest, get_template + if infer_backend == 'lmdeploy': + from swift.llm import LmdeployEngine + engine = LmdeployEngine('Qwen/Qwen-VL-Chat', torch.float32) + elif infer_backend == 'pt': + from swift.llm import PtEngine + engine = PtEngine('Qwen/Qwen2-VL-7B-Instruct') + elif infer_backend == 'vllm': + from swift.llm import VllmEngine + engine = VllmEngine('Qwen/Qwen2-VL-7B-Instruct') + template = get_template(engine.model_meta.template, engine.processor) + infer_requests = [ + InferRequest([{ + 'role': 'user', + 'content': '晚上睡不着觉怎么办' + }]), + InferRequest([{ + 'role': + 'user', + 'content': [{ + 'type': 'image_url', + 'image_url': 'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png' + }] + }]) + ] + return engine, template, infer_requests + + +def test_infer(engine, template, infer_requests): + from swift.llm import RequestConfig + from swift.plugin import InferStats + request_config = RequestConfig(temperature=0) + infer_stats = InferStats() + + response_list = engine.infer( + infer_requests, template=template, request_config=request_config, metrics=[infer_stats]) + + for response in response_list[:2]: + print(response.choices[0].message.content) + print(infer_stats.compute()) + + +def test_stream(engine, template, infer_requests): + from swift.llm import RequestConfig + from swift.plugin import InferStats + infer_stats = InferStats() + request_config = RequestConfig(temperature=0, stream=True, logprobs=True) + + gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats]) + + for response in gen_list[0]: + if response is None: + continue + print(response.choices[0].delta.content, end='', flush=True) + print() + print(infer_stats.compute()) + + gen_list = engine.infer( + infer_requests, template=template, request_config=request_config, use_tqdm=True, metrics=[infer_stats]) + + for response in gen_list[0]: + pass + + print(infer_stats.compute()) + + +if __name__ == '__main__': + engine, template, infer_requests = _prepare(infer_backend='pt') + test_infer(engine, template, infer_requests) + test_stream(engine, template, infer_requests) diff --git a/ms-swift/tests/llm/__init__.py b/ms-swift/tests/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ms-swift/tests/llm/config/infer.json b/ms-swift/tests/llm/config/infer.json new file mode 100644 index 0000000000000000000000000000000000000000..193476604050d0ba5018dd4e36369756799d9363 --- /dev/null +++ b/ms-swift/tests/llm/config/infer.json @@ -0,0 +1,5 @@ +{ + "ckpt_dir": "/mnt/workspace/yzhao/modelscope/swift/output/pai_test/checkpoint-6", + "val_dataset_sample": 2, + "load_dataset_config": true +} diff --git a/ms-swift/tests/llm/config/sft.json b/ms-swift/tests/llm/config/sft.json new file mode 100644 index 0000000000000000000000000000000000000000..da728a80fec1b95673379f4cde37af396c3b739e --- /dev/null +++ b/ms-swift/tests/llm/config/sft.json @@ -0,0 +1,7 @@ +{ + "model_type": "qwen-1_8b-chat", + "dataset": "jd-sentiment-zh", + "output_dir": "output/pai_test", + "train_dataset_sample": 100, + "eval_steps": 5 +} diff --git a/ms-swift/tests/llm/data/alpaca.csv b/ms-swift/tests/llm/data/alpaca.csv new file mode 100644 index 0000000000000000000000000000000000000000..bc956f052a36fccd7994e48b92486f690097395d --- /dev/null +++ b/ms-swift/tests/llm/data/alpaca.csv @@ -0,0 +1,4 @@ +system,instruction,input,output +00000,11111,22222,3.3 +,aaaaa,,ccccc +,AAAAA,BBBBB,CCCCC diff --git a/ms-swift/tests/llm/data/alpaca2.csv b/ms-swift/tests/llm/data/alpaca2.csv new file mode 100644 index 0000000000000000000000000000000000000000..cfdb441132b28345aed44adb9535cb07ee0ed13e --- /dev/null +++ b/ms-swift/tests/llm/data/alpaca2.csv @@ -0,0 +1,4 @@ +instruction,output +11111,33333 +aaaaa,ccccc +AAAAA,CCCCC diff --git a/ms-swift/tests/llm/data/chatml.jsonl b/ms-swift/tests/llm/data/chatml.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..1970637a11f049bf011c68d7faf418dd57a76d22 --- /dev/null +++ b/ms-swift/tests/llm/data/chatml.jsonl @@ -0,0 +1,3 @@ +{"messages": [{"role": "system", "content": "00000"}, {"role": "user", "content": "11111"}, {"role": "assistant", "content": "22222"}]} +{"messages": [{"role": "user", "content": "aaaaa"}, {"role": "assistant", "content": "bbbbb"}, {"role": "user", "content": "ccccc"}, {"role": "assistant", "content": "ddddd"}]} +{"messages": [{"role": "user", "content": "AAAAA"}, {"role": "assistant", "content": "BBBBB"}, {"role": "user", "content": "CCCCC"}, {"role": "assistant", "content": "DDDDD"}]} diff --git a/ms-swift/tests/llm/data/sharegpt.jsonl b/ms-swift/tests/llm/data/sharegpt.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..e3ef4954a9fcfa5652fd0b809d01bd94435262ad --- /dev/null +++ b/ms-swift/tests/llm/data/sharegpt.jsonl @@ -0,0 +1,3 @@ +{"system": "00000", "conversation": [{"human": "11111", "assistant": "22222"}]} +{"conversation": [{"human": "aaaaa", "assistant": "bbbbb"}]} +{"conversation": [{"human": "AAAAA", "assistant": "BBBBB"}, {"human": "CCCCC", "assistant": "DDDDD"}, {"human": "EEEEE", "assistant": "FFFFF"}]} diff --git a/ms-swift/tests/llm/data/swift_multi.jsonl b/ms-swift/tests/llm/data/swift_multi.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..5d78c48088ba210e5ad4fb3f48b755a639ec6e0c --- /dev/null +++ b/ms-swift/tests/llm/data/swift_multi.jsonl @@ -0,0 +1,3 @@ +{"system": "00000", "query": "55555", "response": "66666"} +{"query": "eeeee", "response": "fffff", "history": []} +{"query": "EEEEE", "response": "FFFFF", "history": [["AAAAA", "BBBBB"], ["CCCCC", "DDDDD"]]} diff --git a/ms-swift/tests/llm/data/swift_pre.csv b/ms-swift/tests/llm/data/swift_pre.csv new file mode 100644 index 0000000000000000000000000000000000000000..45bae8fcde78142c6301399dc65c0e767d2afad2 --- /dev/null +++ b/ms-swift/tests/llm/data/swift_pre.csv @@ -0,0 +1,4 @@ +response +11111 +aaaaa +AAAAA diff --git a/ms-swift/tests/llm/data/swift_single.csv b/ms-swift/tests/llm/data/swift_single.csv new file mode 100644 index 0000000000000000000000000000000000000000..8fa2dbce7eaab4dcedf28b07a09346e16851052b --- /dev/null +++ b/ms-swift/tests/llm/data/swift_single.csv @@ -0,0 +1,4 @@ +system,query,response +00000,11111,22222 +,aaaaa,bbbbb +,AAAAA,BBBBB diff --git a/ms-swift/tests/llm/data/swift_single.jsonl b/ms-swift/tests/llm/data/swift_single.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..ae08d8a37b12f4fc3d52a4c72c61d7905441618a --- /dev/null +++ b/ms-swift/tests/llm/data/swift_single.jsonl @@ -0,0 +1,3 @@ +{"system": "00000", "query": "11111", "response": "22222"} +{"query": "aaaaa", "response": "bbbbb"} +{"query": "AAAAA", "response": "BBBBB"} diff --git a/ms-swift/tests/llm/load_model.py b/ms-swift/tests/llm/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb814bde563a394ab83f990ea76be1e468f6c8d --- /dev/null +++ b/ms-swift/tests/llm/load_model.py @@ -0,0 +1,45 @@ +import argparse +from dataclasses import fields + +import torch + +from swift.llm import MODEL_ARCH_MAPPING, ModelKeys, get_model_tokenizer + + +def get_model_and_tokenizer(ms_model_id, model_arch=None): + try: + import transformers + print(f'Test model: {ms_model_id} with transformers version: {transformers.__version__}') + model_ins, tokenizer = get_model_tokenizer(ms_model_id) + model_ins: torch.nn.Module + if model_arch: + model_arch: ModelKeys = MODEL_ARCH_MAPPING[model_arch] + for f in fields(model_arch): + value = getattr(model_arch, f.name) + if value is not None and f.name != 'arch_name': + if isinstance(value, str): + value = [value] + for v in value: + v = v.replace('{}', '0') + model_ins.get_submodule(v) + except Exception: + import traceback + print(traceback.format_exc()) + raise + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--ms_model_id', + type=str, + required=True, + ) + parser.add_argument( + '--model_arch', + type=str, + required=True, + ) + args = parser.parse_args() + + get_model_and_tokenizer(args.ms_model_id, args.model_arch) diff --git a/ms-swift/tests/llm/load_template.py b/ms-swift/tests/llm/load_template.py new file mode 100644 index 0000000000000000000000000000000000000000..680ce15e741ffc22840a29cb5dea4f9e491efcd2 --- /dev/null +++ b/ms-swift/tests/llm/load_template.py @@ -0,0 +1,138 @@ +import argparse +from collections.abc import Mapping + +import json +import torch +from transformers import PreTrainedTokenizerBase + + +def to_list(input_ids): + if isinstance(input_ids, torch.Tensor): + input_ids = input_ids.cpu().numpy().tolist() + if isinstance(input_ids, list) and isinstance(input_ids[0], list): + input_ids = input_ids[0] + return input_ids + + +def load_ds(ds): + from swift.llm import load_dataset + train_dataset, val_dataset = load_dataset( + ds, + split_dataset_ratio=0.0, + strict=False, + num_proc=1, + model_name=['小黄', 'Xiao Huang'], + model_author=['魔搭', 'ModelScope']) + return train_dataset.select(range(1)) + + +def load_and_tokenize(ms_model_id, template): + from swift.llm import EncodePreprocessor, get_model_tokenizer, get_template + try: + vl_fields = ['vl', 'video', 'minicpmv', 'llava', 'vision', 'emu', 'florence'] + model_ins, tokenizer = get_model_tokenizer(ms_model_id, load_model='mplug' in ms_model_id.lower()) + template_ins = get_template(template, tokenizer) + if template_ins.use_model: + model_ins, _ = get_model_tokenizer(ms_model_id, load_model=True) + template_ins.model = model_ins + template_ins.set_mode('train') + if 'audio' in template_ins.__class__.__name__.lower(): + output = EncodePreprocessor(template_ins)( + load_ds('speech_asr/speech_asr_aishell1_trainsets:validation/test')) + input_ids = output[0].get('input_ids') + elif any([vl in template for vl in vl_fields]): + for row in load_ds('modelscope/coco_2014_caption:validation'): + output = template_ins.encode(row) + input_ids = output.get('input_ids') + # output = EncodePreprocessor(template_ins)(load_ds('swift/OK-VQA_train')) + if model_ins is not None and model_ins.model_meta.is_multimodal: + inputs = template_ins.pre_data_collator([output], model=model_ins) + _, output = template_ins.pre_forward_hook(model_ins, None, inputs) + else: + output = EncodePreprocessor(template_ins)(load_ds('modelscope/DuReader_robust-QG')) + input_ids = output[0].get('input_ids') + if isinstance(output, Mapping): + assert output.get('input_ids') is not None or output.get('inputs_embeds') is not None + else: + assert output[0].get('input_ids') is not None or output[0].get('inputs_embeds') is not None + input_ids = to_list(input_ids) + sent = '' + try: + if not isinstance(tokenizer, PreTrainedTokenizerBase) and hasattr(tokenizer, 'tokenizer'): + tokenizer = tokenizer.tokenizer + sent = tokenizer.decode(input_ids) + except Exception: + pass + return input_ids, sent + except Exception: + import traceback + print(traceback.format_exc()) + raise + + +def load_ds_old(ds): + from swift.llm import load_dataset + train_dataset, val_dataset = load_dataset(ds, split_dataset_ratio=0.0) + return train_dataset.select(range(1)) + + +def load_and_tokenize_old(ms_model_id, template): + model_type = None + model_info = None + from swift.llm import get_model_tokenizer + from swift.llm import get_template, MODEL_MAPPING + found = False + for model_type, model_info in MODEL_MAPPING.items(): + if model_info['model_id_or_path'].lower() == ms_model_id.lower(): + found = True + break + + if not found: + raise ValueError(f'No model_type found: {ms_model_id}') + + vl_fields = ['vl', 'video', 'minicpm-v', 'llava', 'vision', 'emu', 'florence'] + model_ins, tokenizer = get_model_tokenizer(model_type, load_model=True) + + if model_info['template'] == 'default-generation': + model_info['template'] = template.replace('_', '-') + template_ins = get_template(model_info['template'], tokenizer) + template_ins.model = model_ins + if 'audio' in model_info['template']: + output = template_ins.encode(load_ds_old('aishell1-zh-mini')[0]) + elif any([vl in model_info['template'] for vl in vl_fields]): + output = template_ins.encode(load_ds_old('coco-en-mini')[0]) + else: + output = template_ins.encode(load_ds_old('dureader-robust-zh')[0]) + input_ids = to_list(output[0]['input_ids']) + sent = '' + try: + sent = tokenizer.decode(input_ids) + except Exception: + pass + return input_ids, sent + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--ms_model_id', + type=str, + required=True, + ) + parser.add_argument( + '--template', + type=str, + required=True, + ) + parser.add_argument('--new', type=str, required=False, default='1') + args = parser.parse_args() + + is_new = args.new == '1' + if is_new: + input_ids, sent = load_and_tokenize(args.ms_model_id, args.template) + else: + input_ids, sent = load_and_tokenize_old(args.ms_model_id, args.template) + file = 'new_input_ids.txt' if is_new else 'old_input_ids.txt' + if input_ids is not None: + with open(file, 'w') as f: + json.dump({'input_ids': input_ids, 'sent': sent}, f) diff --git a/ms-swift/tests/llm/test_custom.py b/ms-swift/tests/llm/test_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb59c7e71c3e3ba6e63a6902f9db7d42a26baf7 --- /dev/null +++ b/ms-swift/tests/llm/test_custom.py @@ -0,0 +1,74 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest +from typing import Any, Dict, Optional + +import torch + +from swift.llm import (DatasetMeta, InferRequest, Model, ModelGroup, ModelMeta, PtEngine, RequestConfig, + ResponsePreprocessor, TemplateMeta, get_model_tokenizer_with_flash_attn, load_dataset, + register_dataset, register_model, register_template) + + +class CustomPreprocessor(ResponsePreprocessor): + prompt = """Task: Based on the given two sentences, provide a similarity score between 0.0 and 5.0. +Sentence 1: {text1} +Sentence 2: {text2} +Similarity score: """ + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + return super().preprocess({ + 'query': self.prompt.format(text1=row['text1'], text2=row['text2']), + 'response': f"{row['label']:.1f}" + }) + + +register_dataset( + DatasetMeta( + ms_dataset_id='swift/stsb', + hf_dataset_id='SetFit/stsb', + preprocess_func=CustomPreprocessor(), + )) + +register_template( + TemplateMeta( + template_type='custom', + prefix=['System\n{{SYSTEM}}\n'], + prompt=['User\n{{QUERY}}\nAssistant\n'], + chat_sep=['\n'])) + +register_model( + ModelMeta( + model_type='custom', + model_groups=[ + ModelGroup([Model('AI-ModelScope/Nemotron-Mini-4B-Instruct', 'nvidia/Nemotron-Mini-4B-Instruct')]) + ], + template='custom', + get_function=get_model_tokenizer_with_flash_attn, + ignore_patterns=['nemo'])) + + +class TestCustom(unittest.TestCase): + + def test_custom_model(self): + infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}]) + request_config = RequestConfig(max_tokens=512, temperature=0) + engine = PtEngine('AI-ModelScope/Nemotron-Mini-4B-Instruct', torch.float16) + response = engine.infer([infer_request], request_config) + swift_response = response[0].choices[0].message.content + + engine.default_template.template_backend = 'jinja' + response = engine.infer([infer_request], request_config) + jinja_response = response[0].choices[0].message.content + assert swift_response == jinja_response, (f'swift_response: {swift_response}\njinja_response: {jinja_response}') + print(f'response: {swift_response}') + + def test_custom_dataset(self): + dataset = load_dataset(['swift/stsb'])[0] + assert len(dataset) == 5749 + assert list(dataset[0].keys()) == ['messages'] + print(f'dataset: {dataset}') + print(f'dataset[0]: {dataset[0]}') + + +if __name__ == '__main__': + unittest.main() diff --git a/ms-swift/tests/llm/test_dataset.py b/ms-swift/tests/llm/test_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..00fa93e309402c480df3a4e18800cebaa995e543 --- /dev/null +++ b/ms-swift/tests/llm/test_dataset.py @@ -0,0 +1,19 @@ +import unittest + +from swift.llm import load_dataset + + +class TestDataset(unittest.TestCase): + + def test_load_v_dataset(self): + if not __name__ == '__main__': + # ignore citest error in github + return + + for ds in ['m3it#1000', 'mantis-instruct#1000', 'llava-med-zh-instruct#1000']: + ds = load_dataset(ds) + assert len(ds[0]) > 800 + + +if __name__ == '__main__': + unittest.main() diff --git a/ms-swift/tests/llm/test_ollama_export.py b/ms-swift/tests/llm/test_ollama_export.py new file mode 100644 index 0000000000000000000000000000000000000000..44a8ff775c4f47baec7f1fa27ca6698851fcb324 --- /dev/null +++ b/ms-swift/tests/llm/test_ollama_export.py @@ -0,0 +1,80 @@ +import os +import shutil +import tempfile +import unittest + +import transformers +from packaging import version + +from swift.llm import ExportArguments, export_main + +if __name__ == '__main__': + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +class TestTemplate(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + + def tearDown(self): + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skip('swift2.0') + def test_llama3(self): + args = ExportArguments(model_type='llama3-8b-instruct', to_ollama=True, ollama_output_dir=self.tmp_dir) + export_main(args) + + template = ('TEMPLATE """{{ if .System }}<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n' + '{{ .System }}<|eot_id|>{{ else }}<|begin_of_text|>{{ end }}{{ if .Prompt }}<|start_header_id|>user' + '<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' + '{{ end }}{{ .Response }}<|eot_id|>"""') + + stop = 'PARAMETER stop "<|eot_id|>"' + + with open(os.path.join(self.tmp_dir, 'Modelfile'), 'r') as f: + content = f.read() + self.assertTrue(template in content) + self.assertTrue(stop in content) + + @unittest.skip('swift2.0') + def test_glm4(self): + if version.parse(transformers.__version__) >= version.parse('4.45'): + return + + args = ExportArguments(model_type='glm4-9b-chat', to_ollama=True, ollama_output_dir=self.tmp_dir) + export_main(args) + + template = ('TEMPLATE """{{ if .System }}[gMASK] <|system|>\n{{ .System }}{{ else }}' + '[gMASK] {{ end }}{{ if .Prompt }}<|user|>\n{{ .Prompt }}<|assistant|>\n' + '{{ end }}{{ .Response }}<|user|>"""') + + stop = 'PARAMETER stop "<|user|>"' + + with open(os.path.join(self.tmp_dir, 'Modelfile'), 'r') as f: + content = f.read() + self.assertTrue(template in content) + self.assertTrue(stop in content) + + @unittest.skip('swift2.0') + def test_qwen2(self): + args = ExportArguments(model_type='qwen2-7b-instruct', to_ollama=True, ollama_output_dir=self.tmp_dir) + export_main(args) + + template = ('TEMPLATE """{{ if .System }}<|im_start|>system\n{{ .System }}<|im_end|>\n{{ else }}{{ end }}' + '{{ if .Prompt }}<|im_start|>user\n{{ .Prompt }}<|im_end|>\n<|im_start|>assistant\n' + '{{ end }}{{ .Response }}<|im_end|>"""') + + stop = 'PARAMETER stop "<|im_end|>"' + + with open(os.path.join(self.tmp_dir, 'Modelfile'), 'r') as f: + content = f.read() + self.assertTrue(template in content) + self.assertTrue(stop in content) + + +if __name__ == '__main__': + unittest.main() diff --git a/ms-swift/tests/llm/test_run.py b/ms-swift/tests/llm/test_run.py new file mode 100644 index 0000000000000000000000000000000000000000..becb6ef4a22f1244df327561096b82f709404bcd --- /dev/null +++ b/ms-swift/tests/llm/test_run.py @@ -0,0 +1,458 @@ +if __name__ == '__main__': + import os + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' + +import os +import shutil +import tempfile +import unittest +from functools import partial +from typing import Any, Dict, List + +import torch +from datasets import Dataset as HfDataset +from modelscope import Model, MsDataset, snapshot_download +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer + +from swift import Trainer, TrainingArguments, get_logger +from swift.llm import (InferArguments, ModelType, RLHFArguments, TrainArguments, infer_main, merge_lora, rlhf_main, + sft_main) + +NO_EVAL_HUMAN = True + +logger = get_logger() + +kwargs = { + 'per_device_train_batch_size': 2, + 'per_device_eval_batch_size': 2, + 'save_steps': 5, + 'gradient_accumulation_steps': 4, + 'num_train_epochs': 1, +} + + +class TestRun(unittest.TestCase): + + def setUp(self): + print(f'Testing {type(self).__name__}.{self._testMethodName}') + self._tmp_dir = tempfile.TemporaryDirectory() + self.tmp_dir = self._tmp_dir.name + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + + def test_template(self): + if not __name__ == '__main__': + # ignore citest error in github + return + torch.cuda.empty_cache() + output = sft_main( + TrainArguments( + model='Qwen/Qwen1.5-0.5B', + train_type='full', + dataset='DAMO_NLP/jd', + val_dataset='DAMO_NLP/jd#20', + streaming=True, + max_steps=12, + **kwargs)) + last_model_checkpoint = output['last_model_checkpoint'] + torch.cuda.empty_cache() + result = infer_main(InferArguments(model=last_model_checkpoint, load_data_args=True, val_dataset_sample=2)) + assert len(result[0]['response']) < 20 + + def test_hf_hub(self): + if not __name__ == '__main__': + # ignore citest error in github + return + torch.cuda.empty_cache() + train_dataset_fnames = [ + 'alpaca.csv', 'chatml.jsonl', 'swift_pre.jsonl', 'swift_single.csv', 'swift_multi.jsonl', + 'swift_multi.json#2' + ] + folder = os.path.join(os.path.dirname(__file__), 'data') + dataset = [ + 'llm-wizard/alpaca-gpt4-data-zh#20', + 'shibing624/alpaca-zh#20', + ] + [os.path.join(folder, fname) for fname in train_dataset_fnames] + output = sft_main( + TrainArguments( + model='Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4', train_type='lora', dataset=dataset, use_hf=True, **kwargs)) + last_model_checkpoint = output['last_model_checkpoint'] + torch.cuda.empty_cache() + infer_main(InferArguments(adapters=last_model_checkpoint, load_data_args=True, val_dataset_sample=2)) + + @unittest.skip('avoid ci error') + def test_basic(self): + output_dir = 'output' + quant_bits_list = [0, 4] + train_dataset_fnames = [ + 'alpaca.csv', 'chatml.jsonl', 'swift_pre.jsonl', 'swift_single.csv', 'swift_multi.jsonl', + 'swift_multi.json#2' + ] + folder = os.path.join(os.path.dirname(__file__), 'data') + dataset = [ + 'AI-ModelScope/alpaca-gpt4-data-zh#20', + 'hurner/alpaca-gpt4-data-zh#20', + ] + [os.path.join(folder, fname) for fname in train_dataset_fnames] + if not __name__ == '__main__': + output_dir = self.tmp_dir + quant_bits_list = [4] + dataset = dataset[:2] + for quant_bits in quant_bits_list: + if quant_bits == 0: + predict_with_generate = False + quant_method = None + else: + predict_with_generate = True + quant_method = 'bnb' + sft_args = TrainArguments( + model='Qwen/Qwen2-0.5B-Instruct', + quant_bits=quant_bits, + eval_steps=5, + adam_beta2=0.95, + quant_method=quant_method, + predict_with_generate=predict_with_generate, + dataset=dataset, + val_dataset='DAMO_NLP/jd#20', + output_dir=output_dir, + download_mode='force_redownload', + include_num_input_tokens_seen=True, + gradient_checkpointing=True, + **kwargs) + torch.cuda.empty_cache() + output = sft_main(sft_args) + print(output) + best_model_checkpoint = output['best_model_checkpoint'] + print(f'best_model_checkpoint: {best_model_checkpoint}') + if __name__ == '__main__': + infer_args = InferArguments( + adapters=best_model_checkpoint, + merge_lora={ + 0: True, + 4: False + }[quant_bits], + load_data_args=NO_EVAL_HUMAN, + val_dataset_sample=5) + torch.cuda.empty_cache() + result = infer_main(infer_args) + print(result) + # if __name__ == '__main__': + # app_ui_main(infer_args) + + def test_vl_audio(self): + output_dir = 'output' + if not __name__ == '__main__': + # ignore citest error in github + return + model_type_list = ['Qwen/Qwen-VL-Chat', 'Qwen/Qwen-Audio-Chat'] + dataset_list = [ + 'modelscope/coco_2014_caption:validation#100', 'speech_asr/speech_asr_aishell1_trainsets:validation#100' + ] + for model, dataset in zip(model_type_list, dataset_list): + sft_args = TrainArguments( + model=model, + eval_steps=5, + dataset=[dataset], + output_dir=output_dir, + gradient_checkpointing=True, + lazy_tokenize=True, + disable_tqdm=True, + **kwargs) + torch.cuda.empty_cache() + output = sft_main(sft_args) + print(output) + best_model_checkpoint = output['best_model_checkpoint'] + print(f'best_model_checkpoint: {best_model_checkpoint}') + infer_args = InferArguments( + adapters=best_model_checkpoint, + load_data_args=True, + stream={ + 'Qwen/Qwen-VL-Chat': True, + 'Qwen/Qwen-Audio-Chat': False + }[model], + val_dataset_sample=5) + torch.cuda.empty_cache() + result = infer_main(infer_args) + print(result) + + def test_custom_dataset(self): + if not __name__ == '__main__': + # ignore citest error in github + return + train_dataset_fnames = [ + 'alpaca.csv', 'chatml.jsonl', 'swift_pre.jsonl', 'swift_single.csv', 'swift_multi.jsonl', + 'swift_multi.json', 'sharegpt.jsonl' + ] + val_dataset_fnames = [ + 'alpaca.jsonl', + 'alpaca2.csv', + 'conversations.jsonl', + 'swift_pre.csv', + 'swift_single.jsonl', + # 'swift_#:#.jsonl#3' + ] + folder = os.path.join(os.path.dirname(__file__), 'data') + resume_from_checkpoint = None + train_kwargs = kwargs.copy() + train_kwargs.pop('num_train_epochs') + for num_train_epochs in [1, 2]: + sft_args = TrainArguments( + model='Qwen/Qwen-7B-Chat', + dataset=['swift/self-cognition#20'] + [os.path.join(folder, fname) for fname in train_dataset_fnames], + val_dataset=[os.path.join(folder, fname) for fname in val_dataset_fnames], + resume_from_checkpoint=resume_from_checkpoint, + num_train_epochs=num_train_epochs, + model_name='小黄', + model_author='魔搭', + **train_kwargs) + + torch.cuda.empty_cache() + result = sft_main(sft_args) + best_model_checkpoint = result['best_model_checkpoint'] + resume_from_checkpoint = result['last_model_checkpoint'] + + for load_args in [True, False]: + infer_kwargs = {} + if load_args is False: + args_json = os.path.join(best_model_checkpoint, 'args.json') + assert os.path.exists(args_json) + os.remove(args_json) + infer_kwargs = {'model': 'Qwen/Qwen-7B-Chat'} + infer_args = InferArguments( + adapters=best_model_checkpoint, + load_data_args=load_args and NO_EVAL_HUMAN, + merge_lora=load_args, + val_dataset=[os.path.join(folder, fname) for fname in val_dataset_fnames], + **infer_kwargs) + torch.cuda.empty_cache() + infer_main(infer_args) + + def test_rlhf(self): + if not __name__ == '__main__': + # ignore citest error in github + return + torch.cuda.empty_cache() + # llm rlhf + # + rlhf_types = ['dpo', 'orpo', 'simpo', 'kto', 'cpo', 'rm', 'ppo'] + for rlhf_type in rlhf_types: + dataset = ('AI-ModelScope/hh_rlhf_cn:harmless_base_cn#100' + if rlhf_type != 'kto' else 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#100') + train_kwargs = {} + if rlhf_type == 'ppo': + train_kwargs['reward_model'] = 'Qwen/Qwen2-1.5B-Instruct' + output = rlhf_main( + RLHFArguments( + rlhf_type=rlhf_type, + model='Qwen/Qwen2-1.5B-Instruct', + dataset=dataset, + eval_steps=5, + split_dataset_ratio=0.05, + **train_kwargs, + **kwargs)) + if rlhf_type == 'ppo': + model_checkpoint = output['last_model_checkpoint'] + else: + model_checkpoint = output['best_model_checkpoint'] + + torch.cuda.empty_cache() + infer_main(InferArguments(adapters=model_checkpoint, load_data_args=True)) + + # mllm rlhf + visual_rlhf_types = ['dpo', 'orpo', 'simpo', 'cpo', 'rm'] + test_model = [ + 'OpenGVLab/InternVL2-2B', 'Qwen/Qwen2-VL-2B-Instruct', 'llava-hf/llava-v1.6-mistral-7b-hf', + 'AI-ModelScope/Florence-2-base-ft' + ] # decoder only and encoder-decoder + for rlhf_type in visual_rlhf_types: + for model in test_model: + dataset_name = 'swift/RLAIF-V-Dataset#100' + output = rlhf_main( + RLHFArguments( + rlhf_type=rlhf_type, + model=model, + dataset=dataset_name, + eval_steps=5, + dataset_num_proc=16, + **kwargs)) + best_model_checkpoint = output['best_model_checkpoint'] + torch.cuda.empty_cache() + infer_main(InferArguments(adapters=best_model_checkpoint, load_data_args=True, val_dataset_sample=2)) + + def test_loss_matching(self): + output_dir = 'output' + if not __name__ == '__main__': + # ignore citest error in github + return + losses = [] + for use_swift_lora in [False, True]: + bool_var = use_swift_lora + torch.cuda.empty_cache() + output = sft_main([ + '--model', 'Qwen/Qwen-7B-Chat', '--save_steps', '5', '--dataset', + 'AI-ModelScope/leetcode-solutions-python#200', '--output_dir', output_dir, '--gradient_checkpointing', + 'true', '--max_new_tokens', '100', '--attn_impl', 'flash_attn', '--target_modules', 'all-linear', + '--seed', '0', '--lora_bias', 'all', '--modules_to_save', 'lm_head', '--use_swift_lora', + str(use_swift_lora), '--num_train_epochs', '1', '--gradient_accumulation_steps', '16' + ]) + best_model_checkpoint = output['best_model_checkpoint'] + print(f'best_model_checkpoint: {best_model_checkpoint}') + load_data_args = str(bool_var or NO_EVAL_HUMAN) + if load_data_args: + val_dataset_sample = 2 + else: + val_dataset_sample = -1 + torch.cuda.empty_cache() + infer_main([ + '--adapters', best_model_checkpoint, '--val_dataset_sample', + str(val_dataset_sample), '--max_new_tokens', '100', '--attn_impl', 'eager', '--merge_lora', + str(bool_var), '--load_data_args', + str(load_data_args) + ]) + loss = output['log_history'][-1]['train_loss'] + losses.append(loss) + self.assertTrue(abs(losses[0] - losses[1]) < 5e-4) + print(f'swift_loss: {losses[0]}') + print(f'peft_loss: {losses[1]}') + self.assertTrue(0.95 <= losses[0] <= 1) + + def test_pai_compat(self): + if not __name__ == '__main__': + # ignore citest error in github + return + from swift.llm import sft_main, infer_main + os.environ['PAI_TRAINING_JOB_ID'] = '123456' + folder = os.path.join(os.path.dirname(__file__), 'config') + tensorboard_dir = os.path.join('output/pai_test', 'pai_tensorboard') + os.environ['PAI_OUTPUT_TENSORBOARD'] = tensorboard_dir + sft_json = os.path.join(folder, 'sft.json') + infer_json = os.path.join(folder, 'infer.json') + torch.cuda.empty_cache() + output = sft_main([sft_json]) + print() + infer_args = { + 'adapters': output['best_model_checkpoint'], + 'val_dataset_sample': 2, + 'load_data_args': True, + } + import json + with open(infer_json, 'w') as f: + json.dump(infer_args, f, ensure_ascii=False, indent=4) + torch.cuda.empty_cache() + infer_main([infer_json]) + os.environ.pop('PAI_TRAINING_JOB_ID') + + +def data_collate_fn(batch: List[Dict[str, Any]], tokenizer) -> Dict[str, torch.Tensor]: + # text-classification + assert tokenizer.pad_token_id is not None + input_ids = [torch.tensor(b['input_ids']) for b in batch] + labels = torch.tensor([b['labels'] for b in batch]) + attention_mask = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))] + + input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) + attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0) + return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels} + + +class BertTrainer(Trainer): + + def compute_loss(self, model, inputs, return_outputs=False): + outputs = model(**inputs) + loss = outputs.loss + if loss is None: + logits, loss = list(outputs.logits) + return (loss, outputs) if return_outputs else loss + + +class TestTrainer(unittest.TestCase): + + def setUp(self): + self._tmp_dir = tempfile.TemporaryDirectory() + self.tmp_dir = self._tmp_dir.name + # self.tmp_dir = 'test' + logger.info(f'self.tmp_dir: {self.tmp_dir}') + + def tearDown(self): + if os.path.isdir(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + # api = HubApi() + # api.delete_model(self.hub_model_id) + # logger.info(f'delete model: {self.hub_model_id}') + + def test_trainer(self): + self.hub_model_id = 'test_trainer2' + logger.info(f'self.hub_model_id: {self.hub_model_id}') + self.tmp_dir = 'output/damo/nlp_structbert_backbone_base_std' + push_to_hub = True + if not __name__ == '__main__': + # ignore citest error in github + return + model_id = 'damo/nlp_structbert_backbone_base_std' + model_dir = snapshot_download(model_id, 'master') + tokenizer = AutoTokenizer.from_pretrained(model_dir) + dataset = MsDataset.load('clue', subset_name='tnews') + num_labels = max(dataset['train']['label']) + 1 + model = Model.from_pretrained(model_dir, task='text-classification', num_labels=num_labels) + train_dataset, val_dataset = dataset['train'].to_hf_dataset(), dataset['validation'].to_hf_dataset() + train_dataset: HfDataset = train_dataset.select(range(100)) + val_dataset: HfDataset = val_dataset.select(range(20)) + + # + def tokenize_func(examples): + data = tokenizer(examples['sentence'], return_attention_mask=False) + examples['input_ids'] = data['input_ids'] + examples['labels'] = examples['label'] + del examples['sentence'], examples['label'] + return examples + + train_dataset = train_dataset.map(tokenize_func) + val_dataset = val_dataset.map(tokenize_func) + + data_collator = partial(data_collate_fn, tokenizer=tokenizer) + for save_only_model in [True, False]: + trainer_args = TrainingArguments( + self.tmp_dir, + do_train=True, + do_eval=True, + num_train_epochs=1, + evaluation_strategy='steps', + save_strategy='steps', + per_device_train_batch_size=4, + per_device_eval_batch_size=4, + push_to_hub=push_to_hub, + hub_token=None, # use env var + hub_private_repo=True, + hub_strategy='every_save', + hub_model_id=self.hub_model_id, + overwrite_output_dir=True, + save_steps=10, + save_total_limit=2, + metric_for_best_model='loss', + greater_is_better=False, + report_to=['tensorboard'], + gradient_accumulation_steps=1, + logging_steps=5, + eval_steps=10, + save_safetensors=False, + save_only_model=save_only_model) + trainer_args._n_gpu = 1 + trainer = BertTrainer(model, trainer_args, data_collator, train_dataset, val_dataset, tokenizer) + self.hub_model_id = trainer_args.hub_model_id + trainer.train() + if trainer_args.push_to_hub: + trainer.push_to_hub() + + +if __name__ == '__main__': + # TestRun().test_template() + # TestRun().test_hf_hub() + # TestRun().test_basic() + # TestRun().test_custom_dataset() + # TestRun().test_vl_audio() + # TestRun().test_loss_matching() + # + # TestRun().test_rlhf() + unittest.main() diff --git a/ms-swift/tests/llm/test_run3.py b/ms-swift/tests/llm/test_run3.py new file mode 100644 index 0000000000000000000000000000000000000000..9590d36e702eb20bb53317676df36ce88893c935 --- /dev/null +++ b/ms-swift/tests/llm/test_run3.py @@ -0,0 +1,172 @@ +import os +import shutil +import tempfile +import unittest + +import json +import numpy as np + +from swift.llm import MODEL_MAPPING, load_dataset + + +class TestRun3(unittest.TestCase): + + def setUp(self): + print(f'Testing {type(self).__name__}.{self._testMethodName}') + self._tmp_dir = tempfile.TemporaryDirectory() + self.tmp_dir = self._tmp_dir.name + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + + def load_ds(self, ds): + train_dataset, val_dataset = load_dataset( + ds, + split_dataset_ratio=0.0, + strict=False, + num_proc=1, + model_name=['小黄', 'Xiao Huang'], + model_author=['魔搭', 'ModelScope']) + return train_dataset.select(range(min(50, len(train_dataset)))) + + # def test_model_load(self): + # if os.path.exists('./models.txt'): + # with open('./models.txt', 'r') as f: + # models = json.load(f) + # else: + # models = [] + # for model_name, model_meta in MODEL_MAPPING.items(): + # meta_requires = model_meta.requires or [] + # for group in model_meta.model_groups: + # model = group.models[0] + # if 'skip_test' in (group.tags or []) or model.ms_model_id in models: + # break + # requires = meta_requires + (group.requires or []) + # for req in requires: + # os.system(f'pip install "{req}"') + # if not any(['transformers' in req for req in requires]): + # os.system('pip install transformers -U') + # if not any(['accelerate' in req for req in requires]): + # os.system('pip install accelerate -U') + # try: + # model_arch_args = '' + # if model_meta.model_arch: + # model_arch_args = f'--model_arch {model_meta.model_arch}' + # cmd = ('PYTHONPATH=. python tests/llm/load_model.py ' + # f'--ms_model_id {model.ms_model_id} {model_arch_args}') + # if os.system(cmd) != 0: + # raise RuntimeError() + # except Exception: + # passed = False + # else: + # passed = True + # models.append(model.ms_model_id) + # finally: + # if passed: + # with open('./models.txt', 'w') as f: + # json.dump(models, f) + + # def test_template_load(self): + # if os.path.exists('./templates.txt'): + # with open('./templates.txt', 'r') as f: + # templates = json.load(f) + # else: + # templates = [] + # for model_name, model_meta in MODEL_MAPPING.items(): + # template = model_meta.template + # meta_requires = model_meta.requires or [] + # for group in model_meta.model_groups: + # model = group.models[0] + # if 'skip_test' in (group.tags or []) or template in templates: + # break + # requires = meta_requires + (group.requires or []) + # for req in requires: + # os.system(f'pip install "{req}"') + # if not any(['transformers' in req for req in requires]): + # os.system('pip install transformers -U') + # if not any(['accelerate' in req for req in requires]): + # os.system('pip install accelerate -U') + # try: + # cmd = ('PYTHONPATH=. python tests/llm/load_template.py ' + # f'--ms_model_id {model.ms_model_id} --template {template}') + # if os.system(cmd) != 0: + # raise RuntimeError() + # except Exception: + # import traceback + # print(traceback.format_exc()) + # passed = False + # else: + # passed = True + # templates.append(template) + # finally: + # if passed: + # with open('./templates.txt', 'w') as f: + # json.dump(templates, f) + + @unittest.skip('skip') + def test_template_compare(self): + if os.path.exists('./templates.txt'): + with open('./templates.txt', 'r') as f: + templates = json.load(f) + else: + templates = [] + skip_model_type = { + 'grok', 'deepseek_moe', 'deepseek_v2', 'deepseek_v2_5', 'llama3_1_omni', 'llava_next_qwen_hf', + 'llava1_6_yi', 'llava_next_qwen', 'mixtral', 'codefuse_codellama', 'wizardlm2', 'wizardlm2_awq', + 'openbuddy_deepseek', 'sus', 'openbuddy_mixtral', 'openbuddy_llama', 'dbrx', 'nenotron', 'reflection', + 'xverse_moe', 'qwen2_moe', 'yuan2', 'wizardlm2_moe', 'emu3_gen', 'llava1_6_mistral', 'mplug_owl3_241101', + 'llava1_6_yi_hf' + } + for model_name, model_meta in MODEL_MAPPING.items(): + if model_name in skip_model_type: + continue + template = model_meta.template + meta_requires = model_meta.requires or [] + for group in model_meta.model_groups: + model = group.models[0] + if 'awq' in model.ms_model_id.lower() or 'gptq' in model.ms_model_id.lower(): + break + if template in templates: + break + requires = meta_requires + (group.requires or []) + for req in requires: + os.system(f'pip install "{req}"') + if not any(['transformers' in req for req in requires]): + os.system('pip install transformers -U') + if not any(['accelerate' in req for req in requires]): + os.system('pip install accelerate -U') + try: + cmd = ('CUDA_VISIBLE_DEVICES=0 PYTHONPATH=. python tests/llm/load_template.py ' + f'--ms_model_id {model.ms_model_id} --template {template}') + if os.system(cmd) != 0: + raise RuntimeError() + cmd = ( + 'CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/mnt/workspace/yzhao/tastelikefeet/swift python tests/llm/load_template.py ' # noqa + f'--ms_model_id {model.ms_model_id} --template {template} --new 0') + if os.system(cmd) != 0: + raise RuntimeError() + with open('new_input_ids.txt', 'r') as f: + input_ids_new = json.load(f) + with open('old_input_ids.txt', 'r') as f: + input_ids_old = json.load(f) + print('model_id', model.ms_model_id, 'new:', input_ids_new, 'old:', input_ids_old) + self.assertTrue(np.allclose(input_ids_new['input_ids'], input_ids_old['input_ids'])) + except Exception: + import traceback + print(traceback.format_exc()) + passed = False + else: + passed = True + templates.append(template) + finally: + if passed: + with open('./templates.txt', 'w') as f: + json.dump(templates, f) + if os.path.exists('new_input_ids.txt'): + os.remove('new_input_ids.txt') + if os.path.exists('old_input_ids.txt'): + os.remove('old_input_ids.txt') + + +if __name__ == '__main__': + unittest.main() diff --git a/ms-swift/tests/llm/test_template.py b/ms-swift/tests/llm/test_template.py new file mode 100644 index 0000000000000000000000000000000000000000..9856ecdb0661f6de86aab42de738bcf0cfe28582 --- /dev/null +++ b/ms-swift/tests/llm/test_template.py @@ -0,0 +1,104 @@ +import os +import unittest + +from swift.llm import PtEngine, RequestConfig, get_model_tokenizer, get_template +from swift.utils import get_logger, seed_everything + +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ['SWIFT_DEBUG'] = '1' + +logger = get_logger() + + +def _infer_model(pt_engine, system=None, messages=None): + seed_everything(42) + request_config = RequestConfig(max_tokens=128, temperature=0) + if messages is None: + messages = [] + if system is not None: + messages += [{'role': 'system', 'content': system}] + messages += [{'role': 'user', 'content': '你好'}] + resp = pt_engine.infer([{'messages': messages}], request_config=request_config) + response = resp[0].choices[0].message.content + messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '这是什么'}] + resp = pt_engine.infer([{ + 'messages': messages, + }], request_config=request_config) + response = resp[0].choices[0].message.content + messages += [{'role': 'assistant', 'content': response}] + logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}') + return response + + +class TestTemplate(unittest.TestCase): + + def test_template(self): + pt_engine = PtEngine('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4') + response = _infer_model(pt_engine) + pt_engine.default_template.template_backend = 'jinja' + response2 = _infer_model(pt_engine) + assert response == response2 + + def test_tool_message_join(self): + from copy import deepcopy + + from swift.plugin import agent_templates + + messages = [ + # first round + { + 'role': 'user', + 'content': 'user1' + }, + { + 'role': 'assistant', + 'content': 'assistant1' + }, + { + 'role': 'assistant', + 'content': 'assistant2' + }, + { + 'role': 'tool', + 'content': 'tool1' + }, + # second round + { + 'role': 'assistant', + 'content': 'assistant3' + }, + { + 'role': 'tool', + 'content': 'tool2' + }, + { + 'role': 'tool', + 'content': 'tool3' + }, + ] + + # testing two template type. + tokenizer = get_model_tokenizer('Qwen/Qwen2.5-7B-Instruct', load_model=False)[1] + template = get_template(tokenizer.model_meta.template, tokenizer) + for agent_template_type in ('react_zh', 'qwen_zh'): + agent_template = agent_templates[agent_template_type]() + template.agent_template = agent_template + observation = agent_template.keyword.observation + test_messages = deepcopy(messages) + test_messages[2]['content'] = 'assistant2' + observation + test_messages[4]['content'] = ( + agent_template.keyword.action + agent_template.keyword.action_input + 'assistant3' + observation) + encoded = template.encode({'messages': test_messages}) + res = template.safe_decode(encoded['input_ids']) + + ground_truth = ( + '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' + '<|im_start|>user\nuser1<|im_end|>\n' + f'<|im_start|>assistant\nassistant1assistant2{observation}tool1' + f'{agent_template.keyword.action}{agent_template.keyword.action_input}assistant3' + f'{observation}tool2\n{observation}tool3\n') + assert res == ground_truth + + +if __name__ == '__main__': + unittest.main() diff --git a/ms-swift/tests/llm/test_utils.py b/ms-swift/tests/llm/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbc0f82945ede8b99a3ed41c92a6262c0578b00 --- /dev/null +++ b/ms-swift/tests/llm/test_utils.py @@ -0,0 +1,28 @@ +import unittest + +from swift.llm import load_dataset +from swift.utils import lower_bound + + +class TestLlmUtils(unittest.TestCase): + + def test_count_startswith(self): + arr = [-100] * 1000 + list(range(1000)) + self.assertTrue(lower_bound(0, len(arr), lambda i: arr[i] != -100) == 1000) + + def test_count_endswith(self): + arr = list(range(1000)) + [-100] * 1000 + self.assertTrue(lower_bound(0, len(arr), lambda i: arr[i] == -100) == 1000) + + @unittest.skip('avoid ci error') + def test_dataset(self): + dataset = load_dataset(['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#200'], + num_proc=4, + strict=False, + download_mode='force_redownload') + print(f'dataset[0]: {dataset[0]}') + print(f'dataset[1]: {dataset[1]}') + + +if __name__ == '__main__': + unittest.main() diff --git a/ms-swift/tests/megatron/test_export.py b/ms-swift/tests/megatron/test_export.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c7e437d1c593834f1b2703cdfe6f328b266589 --- /dev/null +++ b/ms-swift/tests/megatron/test_export.py @@ -0,0 +1,64 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def _infer_model(pt_engine, system=None, messages=None): + from swift.utils import seed_everything, get_logger + from swift.llm import RequestConfig + logger = get_logger() + seed_everything(42) + request_config = RequestConfig(max_tokens=128, temperature=0) + if messages is None: + messages = [] + if system is not None: + messages += [{'role': 'system', 'content': system}] + messages += [{'role': 'user', 'content': 'who are you?'}] + resp = pt_engine.infer([{'messages': messages}], request_config=request_config) + response = resp[0].choices[0].message.content + messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '这是什么'}] + else: + messages = messages.copy() + resp = pt_engine.infer([{ + 'messages': messages, + }], request_config=request_config) + response = resp[0].choices[0].message.content + messages += [{'role': 'assistant', 'content': response}] + logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}') + return response + + +model_id = 'Qwen/Qwen2-7B-Instruct' + + +def hf2mcore(): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + model=model_id, to_mcore=True, torch_dtype='bfloat16', exist_ok=True, test_convert_precision=True)) + + +def mcore2hf(): + from swift.llm import export_main, ExportArguments + export_main( + ExportArguments( + mcore_model='Qwen2-7B-Instruct-mcore', + to_hf=True, + torch_dtype='bfloat16', + exist_ok=True, + test_convert_precision=True)) + + +def infer_hf_align(): + from swift.llm import PtEngine + pt_engine = PtEngine(model_id) + response = _infer_model(pt_engine) + pt_engine = PtEngine('Qwen2-7B-Instruct-mcore-hf') + response2 = _infer_model(pt_engine) + assert response == response2 + + +if __name__ == '__main__': + # hf2mcore() + mcore2hf() + infer_hf_align() diff --git a/ms-swift/tests/model_tag.py b/ms-swift/tests/model_tag.py new file mode 100644 index 0000000000000000000000000000000000000000..fe732b23fb101d9fbebe8d3a9f0b7a1539f87c3b --- /dev/null +++ b/ms-swift/tests/model_tag.py @@ -0,0 +1,172 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import logging +import os + +import json +import requests + +from swift.version import __version__ + + +# 打标 +class ModelTag(object): + _URL = os.environ.get('MODEL_TAG_URL', None) + + # 模型测试结果 + BATCH_COMMIT_RESULT_URL = f'{_URL}/batchCommitResult' + # 测试阶段完成 + BATCH_REFRESH_STAGE_URL = f'{_URL}/batchRefreshStage' + # query_model_stage + QUERY_MODEL_STAGE_URL = f'{_URL}/queryModelStage' + + HEADER = {'Content-Type': 'application/json'} + + # 检测结果 + MODEL_SKIP = 0 + MODEL_FAIL = 1 + MODEL_PASS = 2 + + class ItemResult(object): + + def __init__(self): + self.result = 0 + self.name = '' + self.info = '' + + def to_json(self): + return {'name': self.name, 'result': self.result, 'info': self.info} + + def __init__(self): + self.job_name = '' + self.job_id = '' + self.model = '' + self.sdk_version = '' + self.image_version = '' + self.domain = '' + self.task = '' + self.source = '' + self.stage = '' + # ItemResult list + self.item_result = [] + + # 发送请求 + def _post_request(self, url, param): + try: + logging.info(url + ' query: ' + str(json.dumps(param, ensure_ascii=False))) + res = requests.post(url=url, headers=self.HEADER, data=json.dumps(param, ensure_ascii=False).encode('utf8')) + if res.status_code == 200: + logging.info(f'{url} post结果: ' + res.text) + res_json = json.loads(res.text) + if int(res_json['errorCode']) == 200: + return res_json['content'] + else: + logging.error(res.text) + else: + logging.error(res.text) + except Exception as e: + logging.error(e) + + return None + + # 提交模型测试结果 + def batch_commit_result(self): + try: + param = { + 'sdkVersion': + self.sdk_version, + 'imageVersion': + self.image_version, + 'source': + self.source, + 'jobName': + self.job_name, + 'jobId': + self.job_id, + 'modelList': [{ + 'model': self.model, + 'domain': self.domain, + 'task': self.task, + 'itemResult': self.item_result + }] + } + return self._post_request(self.BATCH_COMMIT_RESULT_URL, param) + + except Exception as e: + logging.error(e) + + return + + # 测试阶段完成 + def batch_refresh_stage(self): + try: + param = { + 'sdkVersion': self.sdk_version, + 'imageVersion': self.image_version, + 'source': self.source, + 'stage': self.stage, + 'modelList': [{ + 'model': self.model, + 'domain': self.domain, + 'task': self.task + }] + } + return self._post_request(self.BATCH_REFRESH_STAGE_URL, param) + + except Exception as e: + logging.error(e) + + return + + # 查询模型某个阶段的最新测试结果(只返回单个结果 + def query_model_stage(self): + try: + param = { + 'sdkVersion': self.sdk_version, + 'model': self.model, + 'stage': self.stage, + 'imageVersion': self.image_version + } + return self._post_request(self.QUERY_MODEL_STAGE_URL, param) + + except Exception as e: + logging.error(e) + + return None + + # 提交模型UT测试结果 + """ + model_tag = ModelTag() + model_tag.model = "XXX" + model_tag.sdk_version = "0.3.7" + model_tag.domain = "nlp" + model_tag.task = "word-segmentation" + item = model_tag.ItemResult() + item.result = model_tag.MODEL_PASS + item.name = "ALL" + item.info = "" + model_tag.item_result.append(item.to_json()) + """ + + def commit_ut_result(self): + if self._URL is not None and self._URL != '': + self.job_name = 'UT' + self.source = 'dev' + self.stage = 'integration' + + self.batch_commit_result() + self.batch_refresh_stage() + + +def commit_model_ut_result(model_name, ut_result): + model_tag = ModelTag() + model_tag.model = model_name.replace('damo/', '') + model_tag.sdk_version = __version__ + # model_tag.domain = "" + # model_tag.task = "" + item = model_tag.ItemResult() + item.result = ut_result + item.name = 'ALL' + item.info = '' + model_tag.item_result.append(item.to_json()) + model_tag.commit_ut_result() diff --git a/ms-swift/tests/test_utils.py b/ms-swift/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d6cc072a52c5008f1ab2426fe79f194c93511ab7 --- /dev/null +++ b/ms-swift/tests/test_utils.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import os +import pickle +import shutil +import socket +import subprocess +import sys +import tarfile +import tempfile +import unittest +from collections import OrderedDict +from collections.abc import Mapping +from os.path import expanduser + +import numpy as np +import requests +from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH + +TEST_LEVEL = 2 +TEST_LEVEL_STR = 'TEST_LEVEL' + +# for user citest and sdkdev +TEST_ACCESS_TOKEN1 = os.environ.get('TEST_ACCESS_TOKEN_CITEST', None) +TEST_ACCESS_TOKEN2 = os.environ.get('TEST_ACCESS_TOKEN_SDKDEV', None) + +TEST_MODEL_CHINESE_NAME = '内部测试模型' +TEST_MODEL_ORG = 'citest' + + +def delete_credential(): + path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) + shutil.rmtree(path_credential, ignore_errors=True) + + +def test_level(): + global TEST_LEVEL + if TEST_LEVEL_STR in os.environ: + TEST_LEVEL = int(os.environ[TEST_LEVEL_STR]) + + return TEST_LEVEL + + +def require_tf(test_case): + test_case = unittest.skip('test requires TensorFlow')(test_case) + return test_case + + +def require_torch(test_case): + return test_case + + +def set_test_level(level: int): + global TEST_LEVEL + TEST_LEVEL = level + + +class DummyTorchDataset: + + def __init__(self, feat, label, num) -> None: + self.feat = feat + self.label = label + self.num = num + + def __getitem__(self, index): + import torch + return {'feat': torch.Tensor(self.feat), 'labels': torch.Tensor(self.label)} + + def __len__(self): + return self.num + + +def create_dummy_test_dataset(feat, label, num): + return DummyTorchDataset(feat, label, num) + + +def download_and_untar(fpath, furl, dst) -> str: + if not os.path.exists(fpath): + r = requests.get(furl) + with open(fpath, 'wb') as f: + f.write(r.content) + + file_name = os.path.basename(fpath) + root_dir = os.path.dirname(fpath) + target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0] + target_dir_path = os.path.join(root_dir, target_dir_name) + + # untar the file + t = tarfile.open(fpath) + t.extractall(path=dst) + + return target_dir_path + + +def get_case_model_info(): + status_code, result = subprocess.getstatusoutput( + 'grep -rn "damo/" tests/ | grep -v ".pyc" | grep -v "Binary file" | grep -v run.py ') + lines = result.split('\n') + test_cases = OrderedDict() + model_cases = OrderedDict() + for line in lines: + # "tests/msdatasets/test_ms_dataset.py:92: model_id = 'damo/bert-base-sst2'" + line = line.strip() + elements = line.split(':') + test_file = elements[0] + model_pos = line.find('damo') + left_quote = line[model_pos - 1] + rquote_idx = line.rfind(left_quote) + model_name = line[model_pos:rquote_idx] + if test_file not in test_cases: + test_cases[test_file] = set() + model_info = test_cases[test_file] + model_info.add(model_name) + + if model_name not in model_cases: + model_cases[model_name] = set() + case_info = model_cases[model_name] + case_info.add(test_file.replace('tests/', '').replace('.py', '').replace('/', '.')) + + return model_cases + + +def compare_arguments_nested(print_content, arg1, arg2, rtol=1.e-3, atol=1.e-8, ignore_unknown_type=True): + type1 = type(arg1) + type2 = type(arg2) + if type1.__name__ != type2.__name__: + if print_content is not None: + print(f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}') + return False + + if arg1 is None: + return True + elif isinstance(arg1, (int, str, bool, np.bool_, np.integer, np.str_)): + if arg1 != arg2: + if print_content is not None: + print(f'{print_content}, arg1:{arg1}, arg2:{arg2}') + return False + return True + elif isinstance(arg1, (float, np.floating)): + if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True): + if print_content is not None: + print(f'{print_content}, arg1:{arg1}, arg2:{arg2}') + return False + return True + elif isinstance(arg1, (tuple, list)): + if len(arg1) != len(arg2): + if print_content is not None: + print(f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}') + return False + if not all([ + compare_arguments_nested(None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) + for sub_arg1, sub_arg2 in zip(arg1, arg2) + ]): + if print_content is not None: + print(f'{print_content}') + return False + return True + elif isinstance(arg1, Mapping): + keys1 = arg1.keys() + keys2 = arg2.keys() + if len(keys1) != len(keys2): + if print_content is not None: + print(f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}') + return False + if len(set(keys1) - set(keys2)) > 0: + if print_content is not None: + print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') + return False + if not all([compare_arguments_nested(None, arg1[key], arg2[key], rtol=rtol, atol=atol) for key in keys1]): + if print_content is not None: + print(f'{print_content}') + return False + return True + elif isinstance(arg1, np.ndarray): + arg1 = np.where(np.equal(arg1, None), np.NaN, arg1).astype(dtype=float) + arg2 = np.where(np.equal(arg2, None), np.NaN, arg2).astype(dtype=float) + if not all(np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True).flatten()): + if print_content is not None: + print(f'{print_content}') + return False + return True + else: + if ignore_unknown_type: + return True + else: + raise ValueError(f'type not supported: {type1}') + + +_DIST_SCRIPT_TEMPLATE = """ +import ast +import argparse +import pickle +import torch +from torch import distributed as dist +from modelscope.utils.torch_utils import get_dist_info +import {} + +parser = argparse.ArgumentParser() +parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results') +parser.add_argument('--save_file', type=str, help='save file') +parser.add_argument('--local_rank', type=int, default=0) +args = parser.parse_args() + + +def main(): + results = {}.{}({}) # module.func(params) + if args.save_all_ranks: + save_file = args.save_file + str(dist.get_rank()) + with open(save_file, 'wb') as f: + pickle.dump(results, f) + else: + rank, _ = get_dist_info() + if rank == 0: + with open(args.save_file, 'wb') as f: + pickle.dump(results, f) + + +if __name__ == '__main__': + main() +""" + + +class DistributedTestCase(unittest.TestCase): + """Distributed TestCase for test function with distributed mode. + Examples: + >>> import torch + >>> from torch import distributed as dist + >>> from modelscope.utils.torch_utils import init_dist + + >>> def _test_func(*args, **kwargs): + >>> init_dist(launcher='pytorch') + >>> rank = dist.get_rank() + >>> if rank == 0: + >>> value = torch.tensor(1.0).cuda() + >>> else: + >>> value = torch.tensor(2.0).cuda() + >>> dist.all_reduce(value) + >>> return value.cpu().numpy() + + >>> class DistTest(DistributedTestCase): + >>> def test_function_dist(self): + >>> args = () # args should be python builtin type + >>> kwargs = {} # kwargs should be python builtin type + >>> self.start( + >>> _test_func, + >>> num_gpus=2, + >>> assert_callback=lambda x: self.assertEqual(x, 3.0), + >>> *args, + >>> **kwargs, + >>> ) + """ + + def _start(self, dist_start_cmd, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs): + script_path = func.__code__.co_filename + script_dir, script_name = os.path.split(script_path) + script_name = os.path.splitext(script_name)[0] + func_name = func.__qualname__ + + func_params = [] + for arg in args: + if isinstance(arg, str): + arg = ('\'{}\''.format(arg)) + func_params.append(str(arg)) + + for k, v in kwargs.items(): + if isinstance(v, str): + v = ('\'{}\''.format(v)) + func_params.append('{}={}'.format(k, v)) + + func_params = ','.join(func_params).strip(',') + + tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name + tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name + + with open(tmp_run_file, 'w') as f: + print('save temporary run file to : {}'.format(tmp_run_file)) + print('save results to : {}'.format(tmp_res_file)) + run_file_content = _DIST_SCRIPT_TEMPLATE.format(script_name, script_name, func_name, func_params) + f.write(run_file_content) + + tmp_res_files = [] + if save_all_ranks: + for i in range(num_gpus): + tmp_res_files.append(tmp_res_file + str(i)) + else: + tmp_res_files = [tmp_res_file] + self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files) + + tmp_env = copy.deepcopy(os.environ) + tmp_env['PYTHONPATH'] = ':'.join((tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':') + # avoid distributed test hang + tmp_env['NCCL_P2P_DISABLE'] = '1' + script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks, tmp_res_file) + script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params) + print('script command: %s' % script_cmd) + res = subprocess.call(script_cmd, shell=True, env=tmp_env) + + script_res = [] + for res_file in tmp_res_files: + with open(res_file, 'rb') as f: + script_res.append(pickle.load(f)) + if not save_all_ranks: + script_res = script_res[0] + + if assert_callback: + assert_callback(script_res) + + self.assertEqual(res, 0, msg='The test function ``{}`` in ``{}`` run failed!'.format(func_name, script_name)) + + return script_res + + def start(self, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs): + from .torch_utils import _find_free_port + ip = socket.gethostbyname(socket.gethostname()) + if 'dist_start_cmd' in kwargs: + dist_start_cmd = kwargs.pop('dist_start_cmd') + else: + dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d ' \ + '--master_addr=\'%s\' --master_port=%s' % (sys.executable, num_gpus, ip, _find_free_port()) + + return self._start( + dist_start_cmd=dist_start_cmd, + func=func, + num_gpus=num_gpus, + assert_callback=assert_callback, + save_all_ranks=save_all_ranks, + *args, + **kwargs) + + def clean_tmp(self, tmp_file_list): + for file in tmp_file_list: + if os.path.exists(file): + if os.path.isdir(file): + shutil.rmtree(file) + else: + os.remove(file)