Student0809 committed on
Commit
ac35f70
·
verified ·
1 Parent(s): 2742ed8

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. ms-swift/silence_overlaps/700/test/silence_transcriptions_test.json +27 -0
  2. ms-swift/swift/plugin/loss_scale/config/ignore_empty_think.json +3 -0
  3. ms-swift/swift/trainers/__pycache__/arguments.cpython-310.pyc +0 -0
  4. ms-swift/swift/trainers/__pycache__/mixin.cpython-310.pyc +0 -0
  5. ms-swift/swift/trainers/__pycache__/utils.cpython-310.pyc +0 -0
  6. ms-swift/swift/trainers/arguments.py +214 -0
  7. ms-swift/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc +0 -0
  8. ms-swift/swift/trainers/optimizers/galore/adamw.py +141 -0
  9. ms-swift/swift/trainers/optimizers/galore/adamw8bit.py +112 -0
  10. ms-swift/swift/trainers/rlhf_trainer/__init__.py +37 -0
  11. ms-swift/swift/trainers/rlhf_trainer/cpo_trainer.py +32 -0
  12. ms-swift/swift/trainers/rlhf_trainer/dpo_trainer.py +129 -0
  13. ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py +1424 -0
  14. ms-swift/swift/trainers/sequence_parallel/__init__.py +8 -0
  15. ms-swift/swift/tuners/__pycache__/base.cpython-310.pyc +0 -0
  16. ms-swift/swift/tuners/base.py +926 -0
  17. ms-swift/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc +0 -0
  18. ms-swift/swift/tuners/longlora/llama.py +409 -0
  19. ms-swift/swift/tuners/reft.py +215 -0
  20. ms-swift/swift/tuners/scetuning/scetuning.py +235 -0
  21. ms-swift/swift/ui/__init__.py +2 -0
  22. ms-swift/swift/ui/llm_eval/llm_eval.py +189 -0
  23. ms-swift/swift/ui/llm_eval/runtime.py +108 -0
  24. ms-swift/swift/ui/llm_export/__init__.py +1 -0
  25. ms-swift/swift/ui/llm_export/export.py +89 -0
  26. ms-swift/swift/ui/llm_train/galore.py +58 -0
  27. ms-swift/swift/ui/llm_train/lisa.py +44 -0
  28. ms-swift/swift/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  29. ms-swift/swift/utils/__pycache__/constants.cpython-310.pyc +0 -0
  30. ms-swift/swift/utils/__pycache__/env.cpython-310.pyc +0 -0
  31. ms-swift/swift/utils/__pycache__/import_utils.cpython-310.pyc +0 -0
  32. ms-swift/swift/utils/__pycache__/io_utils.cpython-310.pyc +0 -0
  33. ms-swift/swift/utils/__pycache__/tb_utils.cpython-310.pyc +0 -0
  34. ms-swift/swift/utils/__pycache__/torchacc_utils.cpython-310.pyc +0 -0
  35. ms-swift/swift/utils/env.py +104 -0
  36. ms-swift/swift/utils/import_utils.py +106 -0
  37. ms-swift/swift/utils/io_utils.py +118 -0
  38. ms-swift/swift/utils/np_utils.py +38 -0
  39. ms-swift/swift/utils/torchacc_utils.py +917 -0
  40. ms-swift/tests/eval/test_eval.py +66 -0
  41. ms-swift/tests/export/test_quant.py +69 -0
  42. ms-swift/tests/general/test_arch.py +44 -0
  43. ms-swift/tests/general/test_dataset.py +90 -0
  44. ms-swift/tests/general/test_model.py +30 -0
  45. ms-swift/tests/general/test_stream.py +20 -0
  46. ms-swift/tests/general/test_template.py +74 -0
  47. ms-swift/tests/hub/__init__.py +0 -0
  48. ms-swift/tests/hub/test_check_model.py +24 -0
  49. ms-swift/tests/infer/test_infer.py +73 -0
  50. ms-swift/tests/infer/test_logprobs.py +71 -0
ms-swift/silence_overlaps/700/test/silence_transcriptions_test.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "key": "SODA_PROCESSED--train--873625",
4
+ "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--873625.wav",
5
+ "model_output": "[00:00 - 00:03] Speaker A: Hi, I'm here to test drive the car that I saw online.\n[00:04 - 00:06] Speaker B: Great! What model are you interested in?\n[00:07 - 00:11] Speaker A: The one that I was looking at was the new BMW M3.\n[00:11 - 00:17] Speaker B: That's a great choice! Let me grab the keys and we'll go over some of the features before heading out for the test drive.\n[00:17 - 00:24] Speaker A: Before we go, could you tell me a bit more about the car's features? I want to make sure it has everything I'm looking for.\n[00:25 - 00:33] Speaker B: Absolutely! The BMW M3 comes with a twin-turbo inline-six engine, adaptive suspension, and a premium interior with all the latest tech.\n[00:34 - 00:35] Speaker A: Sounds good to me.\n[00:36 - 00:39] Speaker B: This car is amazing! The acceleration is incredible!\n[00:39 - 00:46] Speaker A: Yeah, it's definitely a powerful car. But what do you think about the overall driving experience\n[00:45 - 00:52] Speaker B: Sorry to cut in, but I just noticed the steering wheel feels a bit stiff. Is that normal, or is it adjustable?\n[00:56 - 01:07] Speaker A: It's normal for the M3, but it does have adjustable settings to customize the steering feel to your preference. Now, as I was saying, what do you think about the handling overall?\n[01:07 - 01:12] Speaker B: It feels really good. It's responsive and precise. I love it!\n[01:12 - 01:15] Speaker A: Great! So you're interested in purchasing this car?\n[01:16 - 01:18] Speaker B: Yeah, I think I am. How much is it?\n[01:19 - 01:20] Speaker A: It's $60,000.\n[01:21 - 01:24] Speaker B: That's a lot of money. I'm not sure if I can afford that."
6
+ },
7
+ {
8
+ "key": "SODA_PROCESSED--train--891432",
9
+ "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--891432.wav",
10
+ "model_output": "[00:00 - 00:02] Speaker A: I don't know, I just feel like something is\n[00:02 - 00:06] Speaker B: What do you mean? Are you sensing something specific or is it just a general feeling?\n[00:11 - 00:15] Speaker A: I don't know, I just have a bad feeling about this whole situation.\n[00:16 - 00:20] Speaker B: Do you want to talk about it? Maybe we can figure out what's bothering you.\n[00:20 - 00:28] Speaker A: I'm not sure. I just feel like there's something we're overlooking, something that could change everything if we realized it.\n[00:28 - 00:33] Speaker B: Something we're missing? Like a detail we overlooked or something more significant?\n[00:33 - 00:37] Speaker A: Yeah, exactly. But whatever it is, it's not good.\n[00:37 - 00:42] Speaker B: Why do you say that? Is there something that happened recently that triggered this feeling?\n[00:42 - 00:48] Speaker A: Because everything is just too perfect. It's like we're being set up for something, but I don't know what.\n[00:48 - 00:54] Speaker B: That does seem suspicious. But maybe we're just overthinking things, or is there something specific that makes you feel this way?\n[00:55 - 00:58] Speaker A: Maybe...I don't know, I just can't shake this feeling."
11
+ },
12
+ {
13
+ "key": "SODA_PROCESSED--train--527126",
14
+ "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--527126.wav",
15
+ "model_output": "[00:00 - 00:07] Speaker A: Hey, Cherish. I'm starting to work on a new project and I could really use some help. Would you be interested in lending a hand?\n[00:08 - 00:10] Speaker B: Sure, Seneca. What's the project?\n[00:11 - 00:22] Speaker A: It's a bit complicated, but essentially I'm trying to create a detailed map of the world. Every country, every city, every landmark...I want to include it\n[00:21 - 00:26] Speaker B: Sorry to jump in, but how are you planning to gather all that information? It sounds like a massive undertaking.\n[00:31 - 00:54] Speaker A: You're right, it is. I was thinking of using a combination of online resources and some specialized databases. But I haven't mentioned the tools I need yet. It's going to be a lot of work, but I think it will be really cool once it's finished. We might need to use some advanced GIS software and perhaps even some AI tools to process the data efficiently.\n[00:54 - 01:02] Speaker B: That does sound like a lot of work. But if you're up for it, I'm happy to help out however I can. What kind of tools are you thinking of using exactly?"
16
+ },
17
+ {
18
+ "key": "SODA_PROCESSED--train--243315",
19
+ "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--243315.wav",
20
+ "model_output": "[00:00 - 00:04] Speaker A: Hey, Mom. I finished the clothes you were helping me\n[00:03 - 00:05] Speaker B: Oh, let me see! How did they turn out?\n[00:11 - 00:15] Speaker A: Yeah, they turned out pretty good, if I do say so myself.\n[00:15 - 00:29] Speaker B: They most certainly did! You're getting so big and tall; it won't be long before those old clothes of yours won't fit you anymore and we'll need to completely revamp your wardrobe with stylish new outfits that match your growing personality. We'll have to go shopping for new ones soon.\n[00:29 - 00:33] Speaker A: I know, right? It's hard to believe how fast time is flying by.\n[00:34 - 00:39] Speaker B: It really is. Sometimes it feels like only yesterday you were born, and now you're already a young man.\n[00:40 - 00:45] Speaker A: So true. Well, I should probably get going now. Thanks for all your help, Mom.\n[00:46 - 00:48] Speaker B: Of course, sweetheart. Anytime."
21
+ },
22
+ {
23
+ "key": "SODA_PROCESSED--train--201813",
24
+ "audio_url": "/root/autodl-tmp/output_silence45/json/silence/SODA_PROCESSED--train--201813.wav",
25
+ "model_output": "[00:00 - 00:04] Speaker A: Hi, I'm looking for a new couch, coffee table, and rug.\n[00:04 - 00:10] Speaker B: Great! We have a wide selection of furniture to choose from. Let me show you some of our couches first.\n[00:10 - 00:11] Speaker A: OK.\n[00:11 - 00:20] Speaker B: We have a variety of styles and colors to choose from. What is your preferred style? Are you looking for something modern, traditional, or maybe something in between?\n[00:20 - 00:26] Speaker A: Sorry, before we get into that, can you tell me if you have any ongoing discounts or promotions?\n[00:26 - 00:33] Speaker B: Yes, we do have some promotions running right now. I was just about to ask about your budget, though. Do you have a specific number in mind?\n[00:34 - 00:37] Speaker A: I'm not really sure. Maybe around $500?\n[00:38 - 00:49] Speaker B: We have some great options within your budget. This couch here is only $499. It's a popular choice because it's very versatile and can be used in many different ways\n[00:48 - 00:54] Speaker A: Actually, I was also wondering about the durability of this couch. How long does it typically last?\n[01:00 - 01:05] Speaker B: It's made with high-quality materials, so it should last you several years with proper care. Would you like to see it?\n[01:06 - 01:08] Speaker A: Yes, that sounds perfect. I'll take it!"
26
+ }
27
+ ]
ms-swift/swift/plugin/loss_scale/config/ignore_empty_think.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<think>\n\n</think>\n\n": [0.0]
3
+ }
ms-swift/swift/trainers/__pycache__/arguments.cpython-310.pyc ADDED
Binary file (7.91 kB). View file
 
ms-swift/swift/trainers/__pycache__/mixin.cpython-310.pyc ADDED
Binary file (18.5 kB). View file
 
ms-swift/swift/trainers/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.97 kB). View file
 
ms-swift/swift/trainers/arguments.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import math
3
+ import os
4
+ import platform
5
+ from dataclasses import dataclass, field
6
+ from functools import wraps
7
+ from typing import List, Literal, Optional, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from transformers.training_args import TrainingArguments as HfTrainingArguments
12
+ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments
13
+
14
+ from swift.utils import get_dist_setting, get_logger, is_liger_available, use_torchacc
15
+ from .optimizers.galore import GaLoreConfig
16
+
17
+ logger = get_logger()
18
+
19
+
20
@dataclass
class TrainArgumentsMixin:
    """Swift-specific defaults and extras layered on top of HF ``TrainingArguments``.

    Overrides several HF defaults (batch sizes, scheduler, weight decay, logging)
    and adds Swift-only options for torchacc and evalscope-based in-loop evaluation.

    check_model (bool): Flag to check the model is latest. Default is True.
    acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'.
    """
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    # None -> derived in __post_init__ from a target global batch size of 16.
    gradient_accumulation_steps: Optional[int] = None

    gradient_checkpointing: bool = True
    # May arrive as a JSON string from the CLI; parsed to a dict in __post_init__.
    gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
    logging_first_step: bool = True
    logging_steps: int = 5

    weight_decay: float = 0.1
    adam_beta2: float = 0.95
    lr_scheduler_type: str = 'cosine'
    # May arrive as a JSON string from the CLI; parsed to a dict in __post_init__.
    lr_scheduler_kwargs: Optional[Union[dict, str]] = None
    report_to: List[str] = field(default_factory=lambda: ['tensorboard'])
    # None -> resolved in __post_init__ (0 on Windows, 1 elsewhere).
    dataloader_num_workers: Optional[int] = None
    # None -> defaulted to 10 in __post_init__ when workers > 0.
    dataloader_prefetch_factor: Optional[int] = None
    use_liger_kernel: bool = False

    # extra
    check_model: bool = True
    acc_strategy: Literal['token', 'seq'] = 'token'
    train_dataloader_shuffle: bool = True
    max_epochs: Optional[int] = None

    # torchacc
    metric_warmup_step: Optional[float] = 0
    fsdp_num: int = 1
    acc_steps: int = 1

    # train-eval loop args (evalscope integration)
    eval_use_evalscope: bool = False
    eval_datasets: List[str] = field(default_factory=list)
    eval_limit: Optional[int] = None
    # May arrive as JSON strings from the CLI; parsed to dicts in __post_init__.
    eval_datasets_args: Optional[Union[str, dict]] = None
    eval_generation_config: Optional[Union[str, dict]] = None

    def _fix_gradient_checkpointing(self):
        """Monkey-patch ``torch.utils.checkpoint.checkpoint`` so ``use_reentrant``
        is always the value resolved from ``gradient_checkpointing_kwargs``,
        regardless of what callers pass."""
        # fix use_reentrant
        if hasattr(torch.utils.checkpoint, '_old_checkpoint'):  # avoid double patching
            return
        # Consistent with the default behavior of transformers.
        use_reentrant_ = (
            self.gradient_checkpointing_kwargs.get('use_reentrant', True)
            if self.gradient_checkpointing_kwargs else True)
        _old_checkpoint = torch.utils.checkpoint.checkpoint

        @wraps(_old_checkpoint)
        def _new_checkpoint(*args, use_reentrant=None, **kwargs):
            # Ignore the caller-supplied use_reentrant and pin the configured one.
            return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs)

        # Stash the original (also serves as the double-patch sentinel above).
        torch.utils.checkpoint._old_checkpoint = _old_checkpoint
        torch.utils.checkpoint.checkpoint = _new_checkpoint
        try:
            # Fix the old version of transformers.
            import transformers.modeling_utils
            transformers.modeling_utils.checkpoint = _new_checkpoint
        except (ImportError, AttributeError):
            pass

    def _init_liger(self):
        """Fail fast if liger kernels were requested but are not installed."""
        if self.use_liger_kernel:
            assert is_liger_available(), 'use_liger_kernel requires liger_kernels, try `pip install liger-kernel`'

    def __post_init__(self):
        """Resolve defaults (grad accumulation, dataloader workers), parse
        CLI-string dicts, and apply the checkpoint/liger setup."""
        from swift.llm.argument.base_args.model_args import ModelArguments
        if use_torchacc():
            self.dataloader_drop_last = True
        if self.gradient_accumulation_steps is None:
            # Aim for an effective global batch size of 16.
            world_size = get_dist_setting()[2]
            self.gradient_accumulation_steps = max(1, math.ceil(16 / self.per_device_train_batch_size / world_size))
            logger.info(f'Setting args.gradient_accumulation_steps: {self.gradient_accumulation_steps}')
        if self.lr_scheduler_kwargs:
            self.lr_scheduler_kwargs = ModelArguments.parse_to_dict(self.lr_scheduler_kwargs)
        if self.gradient_checkpointing_kwargs:
            self.gradient_checkpointing_kwargs = ModelArguments.parse_to_dict(self.gradient_checkpointing_kwargs)
        self._fix_gradient_checkpointing()
        self._init_liger()
        if self.dataloader_num_workers is None:
            if platform.system() == 'Windows':
                # Worker processes are fragile on Windows; stay in-process.
                self.dataloader_num_workers = 0
            else:
                self.dataloader_num_workers = 1
            logger.info(f'Setting args.dataloader_num_workers: {self.dataloader_num_workers}')
        if self.dataloader_prefetch_factor is None and self.dataloader_num_workers > 0:
            self.dataloader_prefetch_factor = 10
        if self.eval_use_evalscope:
            try:
                import evalscope
            except ImportError:
                raise ImportError('evalscope is not installed, please install it by `pip install evalscope`')
            self.eval_datasets_args = ModelArguments.parse_to_dict(self.eval_datasets_args)
            self.eval_generation_config = ModelArguments.parse_to_dict(self.eval_generation_config)

        super().__post_init__()
120
+
121
+
122
@dataclass
class SwiftArgumentsMixin(TrainArgumentsMixin):
    """Adds Swift tuner/optimizer selections on top of TrainArgumentsMixin."""
    # Value copied from TrainArguments
    train_type: Optional[str] = None
    optimizer: Optional[str] = None
    local_repo_path: Optional[str] = None
    galore_config: Optional[GaLoreConfig] = None

    def __post_init__(self):
        """Canonicalize output_dir to an absolute path before chaining up."""
        if hasattr(self, 'output_dir'):
            expanded = os.path.expanduser(self.output_dir)
            self.output_dir = os.path.abspath(expanded)
        super().__post_init__()

    @property
    def place_model_on_device(self):
        # torchacc manages device placement itself, so opt out there.
        if use_torchacc():
            return False
        return super().place_model_on_device
138
+
139
+
140
@dataclass
class GRPOArgumentsMixin:
    """Extra options for GRPO training: sampling parameters, vLLM/LMDeploy
    rollout backends, reward-shaping knobs, and DAPO / Dr. GRPO variants."""
    # PPO-style clipping range; epsilon_high enables an asymmetric upper bound.
    epsilon: float = 0.2
    epsilon_high: Optional[float] = None
    # Sampling parameters for rollout generation.
    top_k: int = 50
    top_p: float = 0.9
    repetition_penalty: float = 1.
    num_infer_workers: int = 1
    # vllm
    vllm_device: List[str] = field(default_factory=lambda: ['auto'])
    vllm_gpu_memory_utilization: float = 0.9
    vllm_max_model_len: Optional[int] = None
    vllm_max_num_seqs: int = 256
    vllm_enforce_eager: bool = False
    vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None  # '{"image": 5, "video": 2}'
    vllm_enable_prefix_caching: bool = True
    # reward function args, see details in swift/plugin/orm.py
    # cosine reward, https://arxiv.org/abs/2502.03373
    cosine_min_len_value_wrong: float = -0.5  # r^w_0 in paper, Reward for wrong answers with zero completion length.
    cosine_max_len_value_wrong: float = 0.0  # r^w_L in paper, Reward for wrong answers with max completion length.
    cosine_min_len_value_correct: float = 1.0  # r^c_0 in paper, Reward for correct answers with zero completion length.
    cosine_max_len_value_correct: float = 0.5  # r^c_L in paper, Reward for correct answers with max completion length.
    cosine_max_len: Optional[int] = None  # Lmax in paper, default equal to max_completion_length
    # repetition penalty, https://arxiv.org/abs/2502.03373
    repetition_n_grams: int = 3
    repetition_max_penalty: float = -1.0

    reward_model: Optional[List[str]] = None
    reward_model_plugin: Optional[List[str]] = None
    # LMDeploy in GRPO
    use_lmdeploy: bool = False
    lmdeploy_device: Optional[str] = 'auto'
    lmdeploy_session_len: Optional[int] = None
    lmdeploy_cache_max_entry_count: float = 0.8

    async_generate: bool = False
    tensor_parallel_size: int = 1
    sleep_level: int = 0
    move_model_batches: Optional[int] = None
    offload_optimizer: bool = False
    offload_model: bool = False
    gc_collect_after_offload: bool = False
    multi_turn_func: Optional[str] = None

    # DAPO, https://arxiv.org/abs/2503.14476
    dynamic_sample: bool = False
    max_resample_times: int = 3
    overlong_filter: bool = False
    soft_max_length: Optional[int] = None
    soft_cache_length: Optional[int] = None

    # Dr. GRPO, https://arxiv.org/abs/2503.20783
    scale_rewards: bool = True

    # compatible with trl main branch(0.17.0.dev0)
    wandb_log_unique_prompts: Optional[bool] = None

    # external vllm
    vllm_server_host: Optional[str] = None
    vllm_server_port: int = 8000
    vllm_server_timeout: float = 240.0
    # NOTE(review): unannotated, so this is a plain class attribute rather than a
    # dataclass field — presumably assigned at runtime; confirm against callers.
    vllm_client = None

    # dataset
    dataset_shuffle: Optional[bool] = True
205
+
206
+
207
@dataclass
class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
    """HF ``TrainingArguments`` combined with Swift's argument mixins."""
210
+
211
+
212
@dataclass
class Seq2SeqTrainingArguments(SwiftArgumentsMixin, HfSeq2SeqTrainingArguments):
    """HF ``Seq2SeqTrainingArguments`` combined with Swift's argument mixins."""
ms-swift/swift/trainers/optimizers/galore/__pycache__/utils.cpython-310.pyc ADDED
Binary file (7.17 kB). View file
 
ms-swift/swift/trainers/optimizers/galore/adamw.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy dependencies from transformers/optimization.py
2
+ # code borrowed from https://github.com/jiaweizzhao/GaLore
3
+ import math
4
+ from typing import Callable, Iterable, Tuple
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.optim import Optimizer
9
+ from transformers.utils.versions import require_version
10
+
11
+ from .galore_projector import GaLoreProjector
12
+
13
+
14
+ class AdamW(Optimizer):
15
+ """
16
+ Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
17
+ Regularization](https://arxiv.org/abs/1711.05101).
18
+
19
+ Parameters:
20
+ params (`Iterable[nn.parameter.Parameter]`):
21
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
22
+ lr (`float`, *optional*, defaults to 0.001):
23
+ The learning rate to use.
24
+ betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
25
+ Adam's betas parameters (b1, b2).
26
+ eps (`float`, *optional*, defaults to 1e-06):
27
+ Adam's epsilon for numerical stability.
28
+ weight_decay (`float`, *optional*, defaults to 0.0):
29
+ Decoupled weight decay to apply.
30
+ correct_bias (`bool`, *optional*, defaults to `True`):
31
+ Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
32
+ no_deprecation_warning (`bool`, *optional*, defaults to `False`):
33
+ A flag used to disable the deprecation warning (set to `True` to disable the warning).
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ params: Iterable[nn.parameter.Parameter],
39
+ lr: float = 1e-3,
40
+ betas: Tuple[float, float] = (0.9, 0.999),
41
+ eps: float = 1e-6,
42
+ weight_decay: float = 0.0,
43
+ correct_bias: bool = True,
44
+ no_deprecation_warning: bool = False,
45
+ ):
46
+ require_version('torch>=1.5.0') # add_ with alpha
47
+ if lr < 0.0:
48
+ raise ValueError(f'Invalid learning rate: {lr} - should be >= 0.0')
49
+ if not 0.0 <= betas[0] < 1.0:
50
+ raise ValueError(f'Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)')
51
+ if not 0.0 <= betas[1] < 1.0:
52
+ raise ValueError(f'Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)')
53
+ if not 0.0 <= eps:
54
+ raise ValueError(f'Invalid epsilon value: {eps} - should be >= 0.0')
55
+ defaults = {'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay, 'correct_bias': correct_bias}
56
+ super().__init__(params, defaults)
57
+
58
+ @torch.no_grad()
59
+ def step(self, closure: Callable = None):
60
+ """
61
+ Performs a single optimization step.
62
+
63
+ Arguments:
64
+ closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
65
+ """
66
+ loss = None
67
+ if closure is not None:
68
+ loss = closure()
69
+
70
+ for group in self.param_groups:
71
+ for p in group['params']:
72
+ if p.grad is None:
73
+ continue
74
+ grad = p.grad
75
+ if grad.is_sparse:
76
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
77
+
78
+ state = self.state[p]
79
+
80
+ if 'step' not in state:
81
+ state['step'] = 0
82
+
83
+ # GaLore Projection
84
+ if 'rank' in group:
85
+ if 'projector' not in state:
86
+ state['projector'] = GaLoreProjector(
87
+ group['rank'],
88
+ update_proj_gap=group['update_proj_gap'],
89
+ scale=group['scale'],
90
+ proj_type=group['proj_type'])
91
+
92
+ grad = state['projector'].project(grad, state['step'])
93
+
94
+ # State initialization
95
+ if 'exp_avg' not in state:
96
+ # Exponential moving average of gradient values
97
+ state['exp_avg'] = torch.zeros_like(grad)
98
+ # Exponential moving average of squared gradient values
99
+ state['exp_avg_sq'] = torch.zeros_like(grad)
100
+
101
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
102
+ beta1, beta2 = group['betas']
103
+
104
+ state['step'] += 1
105
+
106
+ # Decay the first and second moment running average coefficient
107
+ # In-place operations to update the averages at the same time
108
+ exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
109
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
110
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
111
+
112
+ step_size = group['lr']
113
+ if group['correct_bias']: # No bias correction for Bert
114
+ bias_correction1 = 1.0 - beta1**state['step']
115
+ bias_correction2 = 1.0 - beta2**state['step']
116
+ step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
117
+
118
+ # compute norm gradient
119
+ norm_grad = exp_avg / denom
120
+
121
+ # GaLore Projection Back
122
+ if 'rank' in group:
123
+ norm_grad = state['projector'].project_back(norm_grad)
124
+
125
+ p.add_(norm_grad, alpha=-step_size)
126
+
127
+ # Just adding the square of the weights to the loss function is *not*
128
+ # the correct way of using L2 regularization/weight decay with Adam,
129
+ # since that will interact with the m and v parameters in strange ways.
130
+ #
131
+ # Instead we want to decay the weights in a manner that doesn't interact
132
+ # with the m/v parameters. This is equivalent to adding the square
133
+ # of the weights to the loss with plain (non-momentum) SGD.
134
+ # Add weight decay at the end (fixed version)
135
+ if group['weight_decay'] > 0.0:
136
+ p.add_(p, alpha=(-group['lr'] * group['weight_decay']))
137
+
138
+ return loss
139
+
140
+
141
+ GaLoreAdamW = AdamW
ms-swift/swift/trainers/optimizers/galore/adamw8bit.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code borrowed from https://github.com/jiaweizzhao/GaLore
2
+ import torch
3
+ from bitsandbytes.optim.optimizer import Optimizer2State
4
+
5
+ from .galore_projector import GaLoreProjector
6
+
7
+
8
class AdamW8bit(Optimizer2State):
    """8-bit AdamW (bitsandbytes Optimizer2State) with optional GaLore low-rank
    gradient projection for param groups that carry a 'rank' key.

    NOTE(review): `amsgrad` and `optim_bits` are accepted for signature
    compatibility but not forwarded — the parent is always constructed with
    8-bit state (the literal ``8`` below); confirm this is intentional.
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=1e-2,
                 amsgrad=False,
                 optim_bits=32,
                 args=None,
                 min_8bit_size=4096,
                 percentile_clipping=100,
                 block_wise=True,
                 is_paged=False):
        super().__init__(
            'adam',
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            is_paged=is_paged)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            # Closure may run forward/backward, so grad must be re-enabled here.
            with torch.enable_grad():
                loss = closure()

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        # if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0

                # GaLore Projection
                if 'rank' in group:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            proj_type=group['proj_type'])

                    if 'weight_decay' in group and group['weight_decay'] > 0:
                        # ensure that the weight decay is not applied to the norm grad
                        group['weight_decay_saved'] = group['weight_decay']
                        group['weight_decay'] = 0

                    grad = state['projector'].project(p.grad, state['step'])

                    # suboptimal implementation: temporarily swap the projected
                    # gradient into p.data so the 8-bit update runs in the
                    # low-rank space, keeping the real weights in saved_data.
                    p.saved_data = p.data.clone()
                    p.data = grad.clone().to(p.data.dtype).to(p.data.device)
                    p.data.zero_()
                    p.grad = grad

                if 'state1' not in state:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                torch.cuda.synchronize()

                # GaLore Projection Back
                if 'rank' in group:
                    # Restore the real weights and apply the back-projected update.
                    p.data = p.saved_data.add_(state['projector'].project_back(p.data))

                    # apply weight decay (restored after being zeroed above)
                    if 'weight_decay_saved' in group:
                        p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay_saved'])
                        group['weight_decay'] = group['weight_decay_saved']
                        del group['weight_decay_saved']

        if self.is_paged:
            # all paged operation are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss


# Public alias matching the GaLore naming used elsewhere in swift.
GaLoreAdamW8bit = AdamW8bit
ms-swift/swift/trainers/rlhf_trainer/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Lazy-loading package init for the RLHF trainers.

At runtime this module replaces itself with a ``_LazyModule`` so that heavy
trainer submodules are only imported on first attribute access; type checkers
still see the direct imports in the TYPE_CHECKING branch.
"""
from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
    # Static imports for IDEs/type checkers only — never executed at runtime.
    from .cpo_trainer import CPOTrainer
    from .dpo_trainer import DPOTrainer
    from .grpo_trainer import GRPOTrainer
    from .kto_trainer import KTOTrainer
    from .orpo_trainer import ORPOTrainer
    from .ppo_trainer import PPOTrainer
    from .reward_trainer import RewardTrainer
    from .rlhf_mixin import RLHFTrainerMixin
    from .utils import _split_into_mini_batches, patch_lora_merge, patch_lora_unmerge, round_robin
else:
    # Maps submodule name -> public names it provides, consumed by _LazyModule.
    _import_structure = {
        'cpo_trainer': ['CPOTrainer'],
        'dpo_trainer': ['DPOTrainer'],
        'grpo_trainer': ['GRPOTrainer'],
        'kto_trainer': ['KTOTrainer'],
        'orpo_trainer': ['ORPOTrainer'],
        'ppo_trainer': ['PPOTrainer'],
        'reward_trainer': ['RewardTrainer'],
        'rlhf_mixin': ['RLHFTrainerMixin'],
        'utils': ['_split_into_mini_batches', 'patch_lora_merge', 'patch_lora_unmerge', 'round_robin'],
    }

    import sys

    # Swap this module object for a lazy proxy; submodules import on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
ms-swift/swift/trainers/rlhf_trainer/cpo_trainer.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import warnings
3
+ from typing import Optional, Union
4
+
5
+ import torch.nn as nn
6
+ from transformers import PreTrainedModel
7
+ from trl import CPOTrainer as HFCPOTrainer
8
+
9
+ from ..mixin import SwiftMixin
10
+ from .rlhf_mixin import RLHFTrainerMixin
11
+
12
# Drop trl's __init__ so the MRO resolves to the Swift mixin chain instead.
del HFCPOTrainer.__init__


class CPOTrainer(RLHFTrainerMixin, SwiftMixin, HFCPOTrainer):
    """CPO/SimPO trainer that wires trl's ``CPOTrainer`` into Swift's mixins."""

    def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
        """Configure CPO/SimPO loss attributes from ``kwargs['args']`` before
        delegating to the mixin initialization chain.

        Raises:
            ValueError: if a ``ref_model`` is supplied (CPO/SimPO is reference-free).
        """
        ref_model = kwargs.get('ref_model')
        if ref_model is not None:
            # Explicit exception instead of `assert`: asserts are stripped under `python -O`.
            raise ValueError('CPO/SimPO does not require a ref_model.')

        args = kwargs['args']
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type
        self.cpo_alpha = args.cpo_alpha
        if args.loss_type == 'simpo':
            self.simpo_gamma = args.simpo_gamma
            if self.cpo_alpha > 0:
                # stacklevel=2 points the warning at the constructing caller.
                warnings.warn(
                    'You are using CPO-SimPO method because you set a non-zero cpo_alpha. '
                    'This will result in the CPO-SimPO method '
                    '(https://github.com/fe1ixxu/CPO_SIMPO/tree/main). '
                    'If you want to use a pure SimPO method, please set cpo_alpha to 0.',
                    stacklevel=2)
        super().__init__(model, *_args, **kwargs)
ms-swift/swift/trainers/rlhf_trainer/dpo_trainer.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from peft import PeftModel
7
+ from transformers import PreTrainedModel
8
+ from trl import DPOTrainer as HFDPOTrainer
9
+
10
+ from ..mixin import DataLoaderMixin, SwiftMixin
11
+ from .rlhf_mixin import RLHFTrainerMixin
12
+
13
# Remove trl's DPOTrainer.__init__ so the cooperative super().__init__ chain in the
# subclass below resolves to the swift mixins instead — presumably intentional; confirm
# that RLHFTrainerMixin/SwiftMixin provide the full initialization.
del HFDPOTrainer.__init__
14
+
15
+
16
class DPOTrainer(RLHFTrainerMixin, SwiftMixin, DataLoaderMixin, HFDPOTrainer):
    """DPO trainer adapting trl's ``DPOTrainer`` to ms-swift's template/mixin stack.

    Overrides ``concatenated_forward``/``get_batch_logps`` so that inputs already
    encoded by a swift template (single concatenated chosen+rejected batch) are
    handled directly, including multimodal models whose logits cover extra
    leading (e.g. image) tokens.
    """

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        """Read DPO hyperparameters from ``kwargs['args']`` before the mixin chain runs.

        Args:
            model: Policy model (or model id) to optimize.
            ref_model: Frozen reference model; may be None when reference-free.
        """
        from trl.trainer import FDivergenceConstants
        args = kwargs['args']
        self.label_smoothing = args.label_smoothing
        self.loss_type = args.loss_type
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.f_divergence_type = args.f_divergence_type
        self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
        self.is_peft_model = isinstance(model, PeftModel)

        self.ref_adapter_name = args.ref_adapter_name
        self.reference_free = args.reference_free
        self.use_weighting = False

        super().__init__(model, ref_model, *_args, **kwargs)

    def get_nll_loss(self, logits, labels):
        """Return the mean token-level cross-entropy of ``logits`` against ``labels``.

        For decoder-only models the sequences are shifted so tokens < n predict n;
        positions equal to ``label_pad_token_id`` are ignored.
        """
        if not self.is_encoder_decoder:
            # Shift so that tokens < n predict n
            logits = logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)
        # Enable model parallelism
        labels = labels.to(logits.device)
        return loss_fct(logits, labels)

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run one forward pass on the concatenated chosen+rejected batch.

        The first half of the batch holds chosen examples, the second half the
        rejected ones. Returns a dict with per-sequence log-probs, mean logits,
        and optionally ``nll_loss`` (RPO) and ``aux_loss`` (router aux loss).
        """
        batch = batch.copy()
        # Chosen and rejected are stacked along the batch dim: first half is chosen.
        num_examples = batch['labels'].shape[0] // 2
        labels = batch.pop('labels', None)
        if self.is_encoder_decoder:
            batch['labels'] = labels

        if self.aux_loss_enabled:
            batch['output_router_logits'] = True
        outputs = model(**batch, use_cache=False)
        batch['labels'] = labels
        if outputs.logits.shape[1] != labels.shape[1]:
            # for llava, the model returns logits for the entire sequence, including the image tokens
            # (placed before the text tokens)
            outputs.logits = outputs.logits[:, -labels.shape[1]:]
        # Rename keys to the `concatenated_*` convention expected downstream.
        for key in ['input_ids', 'attention_mask', 'labels']:
            batch[f'concatenated_{key}'] = batch.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
            batch['concatenated_input_ids'] = batch['concatenated_labels']

        all_logits = outputs.logits

        if all_logits.shape[:2] != batch['concatenated_labels'].shape[:2]:
            # for llava, the model returns logits for the entire sequence,
            # including the image tokens (placed before the text tokens)
            seq_len = batch['concatenated_labels'].shape[1]
            all_logits = all_logits[:, -seq_len:]

        all_logps, size_completion = self.get_batch_logps(
            all_logits,
            batch['concatenated_labels'],
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        output = {}

        if self.args.rpo_alpha is not None:
            # RPO adds an NLL term computed on the chosen half only.
            labels = batch['concatenated_labels'].clone()
            output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples])

        if self.loss_type == 'ipo':
            # IPO uses length-averaged log-probs.
            all_logps = all_logps / size_completion

        output['chosen_logps'] = all_logps[:num_examples]
        output['rejected_logps'] = all_logps[num_examples:]
        output['mean_chosen_logits'] = all_logits[:num_examples].mean()
        output['mean_rejected_logits'] = all_logits[num_examples:].mean()

        if self.aux_loss_enabled:
            output['aux_loss'] = outputs.aux_loss

        return output

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Return (sum of per-token log-probs, number of unmasked tokens) per sequence.

        Positions whose label equals ``label_pad_token_id`` are masked out; for
        decoder-only models logits/labels are shifted by one first.
        """
        if logits.shape[:-1] != labels.shape:
            # Bug fix: the second fragment previously lacked the `f` prefix (so
            # `{labels.shape}` was printed literally) and the joining space.
            raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]} '
                             f'and labels must have the same shape {labels.shape}')
        if not is_encoder_decoder:
            labels = labels[:, 1:].clone()
            logits = logits[:, :-1, :]
        else:
            labels = labels.clone()

        loss_mask = labels != label_pad_token_id

        # Replace padded labels with a valid index before gather; they are masked below.
        labels[labels == label_pad_token_id] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
ms-swift/swift/trainers/rlhf_trainer/grpo_trainer.py ADDED
@@ -0,0 +1,1424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Part of the implementation is borrowed from huggingface/trl.
3
+ import concurrent.futures
4
+ import inspect
5
+ import os
6
+ import re
7
+ import time
8
+ from collections import defaultdict, deque
9
+ from concurrent.futures import Future
10
+ from contextlib import contextmanager
11
+ from copy import copy, deepcopy
12
+ from dataclasses import asdict, dataclass, field
13
+ from math import ceil
14
+ from queue import Queue
15
+ from types import MethodType
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import datasets
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import transformers
23
+ from accelerate.utils import gather, gather_object, is_peft_model, set_seed
24
+ from packaging import version
25
+ from torch.nn import ModuleList
26
+ from torch.utils.data import DataLoader
27
+ from transformers import PreTrainedModel, TrainerCallback
28
+ from transformers.integrations import is_deepspeed_zero3_enabled
29
+ from transformers.trainer import Trainer
30
+ from transformers.trainer_utils import seed_worker
31
+ from trl import GRPOTrainer as HFGRPOTrainer
32
+ from trl.extras.profiling import profiling_decorator
33
+ from trl.models import prepare_deepspeed
34
+ from trl.trainer.grpo_trainer import nanmax, nanmin
35
+
36
+ from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device
37
+ from swift.llm.infer.infer_engine import set_device_context
38
+ from swift.llm.template.template_inputs import StdTemplateInputs
39
+ from swift.plugin import multi_turns, orms, rm_plugins
40
+ from swift.utils import (JsonlWriter, gc_collect, get_device, get_device_count, get_dist_setting, get_logger,
41
+ get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available)
42
+ from ..mixin import SwiftMixin
43
+ from .rlhf_mixin import RLHFTrainerMixin
44
+ from .utils import patch_lora_merge, patch_lora_unmerge, round_robin
45
+
46
# Remove trl's implementations so they do not shadow the mixin-provided ones in the
# MRO of the GRPOTrainer subclass below — presumably the swift mixins supply
# `__init__` and `log`; confirm against RLHFTrainerMixin/SwiftMixin.
del HFGRPOTrainer.__init__
del HFGRPOTrainer.log

logger = get_logger()
# wandb is optional; imported only when installed/enabled.
if is_wandb_available():
    import wandb

# Type aliases for the generate-and-score pipeline:
# a batch of encoded inputs ...
InputsType = List[Dict[str, Union[torch.Tensor, Any]]]
# ... and, per sample, a list of (messages, finish_reason-like string) pairs.
OutputsType = List[List[Tuple[List[Dict], str]]]
55
+
56
+
57
@contextmanager
def unwrap_model_for_generation(
    model,
    accelerator,
    gather_deepspeed3_params=True,
    gather_parameters: List = None,
):
    """Yield the accelerator-unwrapped model, ready for generation / weight export.

    Under DeepSpeed ZeRO-3 (and unless ``gather_deepspeed3_params`` is False) the
    sharded parameters are gathered for the duration of the ``with`` body.

    Args:
        model: The (possibly wrapped) training model.
        accelerator: The ``accelerate.Accelerator`` instance.
        gather_deepspeed3_params: If False, skip gathering under ZeRO-3 (saves
            memory, but the yielded model keeps sharded weights).
        gather_parameters: Parameter names to gather; falsy gathers all parameters.
    """
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
        else:
            import deepspeed
            # Gather only the requested subset (or everything when no subset is given).
            parameters = [
                parameter for name, parameter in model.named_parameters()
                if not gather_parameters or name in gather_parameters
            ]
            with deepspeed.zero.GatheredParameters(parameters):
                # Drop DeepSpeed's forward hooks while the full weights are materialized,
                # and restore them after the gather context exits.
                from trl.models.utils import remove_hooks
                remove_hooks(model)
                yield accelerator.unwrap_model(model)
            from trl.models.utils import add_hooks
            add_hooks(model)
    else:
        yield unwrapped_model
82
+
83
+
84
class GRPOCallback(TrainerCallback):
    """Callback that starts the async generation prefetch when training begins."""

    def __init__(self, trainer):
        self.trainer = trainer

    def on_train_begin(self, args, state, control, **kwargs):
        """Point the trainer at its training queue and prefetch from the dataloader."""
        trainer = self.trainer
        trainer.queue = trainer.train_queue
        # The dataloader may arrive either on the state or in the callback kwargs.
        dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
        trainer._prefetch(dataloader)
94
+
95
+
96
@dataclass
class DataCache:
    """Container for one prefetched generation round in async GRPO rollout."""

    # Encoded request inputs queued for generation.
    inputs: List[Dict] = field(default_factory=list)
    # Generation outputs corresponding to `inputs`.
    outputs: List[Dict] = field(default_factory=list)
    # Per-worker index lists describing how samples were distributed across processes.
    distributed_idx: List[List] = field(default_factory=list)
101
+
102
+
103
+ class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
104
+ executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
105
+
106
+ def __init__(self,
107
+ model: Optional[Union[PreTrainedModel, nn.Module]] = None,
108
+ ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
109
+ reward_model: Optional[List[Union[PreTrainedModel, nn.Module]]] = None,
110
+ reward_funcs: Optional[List[Union[str, Callable]]] = None,
111
+ *_args,
112
+ **kwargs):
113
+ from swift.trainers.rlhf_arguments import GRPOConfig
114
+ args: GRPOConfig = kwargs['args']
115
+ self.args = args
116
+ self.train_queue = Queue()
117
+ self.eval_queue = Queue()
118
+ self.processing_class = kwargs.get('template').tokenizer
119
+ self.offload_modules = {}
120
+ self.offload_states = {}
121
+ _, _, _, local_world_size = get_dist_setting()
122
+
123
+ if not isinstance(reward_funcs, list):
124
+ reward_funcs = [reward_funcs]
125
+
126
+ if reward_funcs:
127
+ for i, reward_func in enumerate(reward_funcs):
128
+ if reward_func in orms:
129
+ reward_func_class = orms[reward_func]
130
+ reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
131
+ reward_func_kwargs = {
132
+ key: getattr(args, key)
133
+ for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
134
+ }
135
+ if 'tokenizer' in reward_func_args:
136
+ reward_func_kwargs['tokenizer'] = self.processing_class
137
+ reward_funcs[i] = reward_func_class(**reward_func_kwargs)
138
+ elif not callable(reward_func):
139
+ raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin')
140
+
141
+ self.reward_funcs = reward_funcs
142
+ self.reward_func_names = []
143
+ for reward_func in reward_funcs:
144
+ if inspect.isfunction(reward_func):
145
+ reward_func_name = reward_func.__name__
146
+ else:
147
+ reward_func_name = reward_func.__class__.__name__
148
+ self.reward_func_names.append(reward_func_name)
149
+
150
+ self.reward_model_plugins = [None] * len(self.reward_funcs)
151
+
152
+ if reward_model is not None:
153
+ reward_template = kwargs.pop('reward_template')
154
+ reward_plugins = args.reward_model_plugin
155
+ if reward_plugins is None:
156
+ reward_plugins = ['default'] * len(reward_model)
157
+ assert len(reward_plugins) == len(reward_model), (
158
+ f"The number of 'reward_model_plugin' ({len(reward_plugins)}) does not match "
159
+ f"the number of 'reward_model' ({len(reward_model)}). "
160
+ "Please provide a corresponding 'reward_model_plugin' for each 'reward_model'.")
161
+ for rm, rm_plugin, rm_template in zip(reward_model, reward_plugins, reward_template):
162
+ # Set encoding mode train(see details in Template.encode).
163
+ # Set max_length to None to disable truncation, as the input length has already been truncated earlier.
164
+ rm_template.set_mode('train')
165
+ rm_template.max_length = None
166
+ if rm_plugin not in rm_plugins:
167
+ raise ValueError(f'rm_plugin {rm_plugin} is not implemented in swift.llm.plugin')
168
+ self.reward_model_plugins.append(rm_plugins[rm_plugin](model=rm, template=rm_template))
169
+ self.reward_funcs.append(rm)
170
+ self.reward_func_names.append(rm.config._name_or_path.split('/')[-1])
171
+
172
+ if not self.reward_funcs:
173
+ raise ValueError('You must specify reward_funcs or reward_model')
174
+
175
+ # Reward weights
176
+ if args.reward_weights is not None:
177
+ if len(args.reward_weights) != len(reward_funcs):
178
+ raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward '
179
+ f'functions ({len(reward_funcs)})')
180
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
181
+ else:
182
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
183
+
184
+ self.multi_turn_func = None
185
+ if self.args.multi_turn_func:
186
+ if isinstance(self.args.multi_turn_func, str):
187
+ assert self.args.multi_turn_func in multi_turns
188
+ multi_turn_func = multi_turns[self.args.multi_turn_func]
189
+ self.multi_turn_func = multi_turn_func
190
+ else:
191
+ self.multi_turn_func = self.args.multi_turn_func
192
+
193
+ self.num_generations = args.num_generations
194
+ self.temperature = args.temperature
195
+ self.loss_type = args.loss_type
196
+ model.warnings_issued['estimate_tokens'] = True
197
+ kwargs['data_collator'] = lambda features: features
198
+ self.shuffle_dataset = args.dataset_shuffle
199
+
200
+ use_vllm = args.use_vllm
201
+ use_lmdeploy = args.use_lmdeploy
202
+ vllm_client = kwargs.pop('vllm_client') # for external vllm
203
+ if self.args.tensor_parallel_size > 1 and self.multi_turn_func:
204
+ import torch.distributed as dist
205
+ rank, _, _, _ = get_dist_setting()
206
+ for tp_group in self.tp_group_ranks():
207
+ group = dist.new_group(tp_group)
208
+ if rank in tp_group:
209
+ self.group = group
210
+
211
+ super().__init__(model, ref_model, *_args, **kwargs)
212
+
213
+ self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
214
+ self.log_completions = args.log_completions
215
+ self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
216
+ self.num_completions_to_print = args.num_completions_to_print
217
+ self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))
218
+ # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
219
+ # final optimization step.
220
+ maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
221
+ self._textual_logs = {
222
+ 'prompt': deque(maxlen=maxlen),
223
+ 'completion': deque(maxlen=maxlen),
224
+ 'rewards': defaultdict(lambda: deque(maxlen=maxlen)),
225
+ }
226
+
227
+ num_processes = self.accelerator.num_processes
228
+ self.effective_train_batch_size = effective_batch_size = \
229
+ args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps
230
+ possible_values = [n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0]
231
+
232
+ if self.num_generations not in possible_values:
233
+ raise ValueError(
234
+ f'The effective train batch size ({num_processes} x {args.per_device_train_batch_size} x '
235
+ f'{args.gradient_accumulation_steps}) must be evenly divisible by the number of generations per '
236
+ f'prompt ({self.num_generations}). Given the current effective train batch size, the valid values for '
237
+ f'the number of generations are: {possible_values}.')
238
+ if self.args.eval_strategy != 'no':
239
+ effective_batch_size = args.per_device_eval_batch_size * num_processes
240
+ possible_values = [
241
+ n_gen for n_gen in range(2, effective_batch_size + 1) if (effective_batch_size) % n_gen == 0
242
+ ]
243
+ if self.num_generations not in possible_values:
244
+ raise ValueError(
245
+ f'The effective eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be '
246
+ f'evenly divisible by the number of generations per prompt ({self.num_generations}). Given the '
247
+ 'current effective eval batch size, the valid values for the number of generations are: '
248
+ f'{possible_values}.')
249
+
250
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
251
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
252
+ # it's safer to set it in all cases.
253
+ set_seed(args.seed, device_specific=True)
254
+ self.parameter_groups, self.parameter_groups_no_lora = self.split_batches()
255
+ self.infer_device = None
256
+ self.use_fast_infer = use_vllm or use_lmdeploy # whether to use the PT backend
257
+ self.is_external_vllm = use_vllm and args.vllm_server_host is not None
258
+ if self.use_fast_infer:
259
+ if self.infer_rank >= 0:
260
+ fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device
261
+ if fast_infer_device[0] == 'auto':
262
+ if get_device_count() == 1:
263
+ fast_infer_device = [get_device()] # particular case when training with only 1 GPU: share it
264
+ else:
265
+ fast_infer_device = []
266
+ for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()):
267
+ fast_infer_device.append(get_device(idx))
268
+
269
+ for _device in fast_infer_device:
270
+ # Check that the requested device is available
271
+ if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count():
272
+ raise ValueError(f'The requested device for vllm ({_device}) is not available. '
273
+ f'You are likely using vLLM '
274
+ 'without restricting the number of GPUs for training. '
275
+ 'Set the `--num_processes` argument to a '
276
+ 'value lower than the number of GPUs available on your machine—typically, '
277
+ 'reducing it by one is sufficient. '
278
+ f'In your case: `--num_processes {get_device_count() - 1}`.')
279
+
280
+ if use_vllm:
281
+ if not is_vllm_available():
282
+ raise ImportError('vLLM is not available and `use_vllm` is set to True. '
283
+ 'Please install vLLM with `pip install vllm -U` to use it.')
284
+ if self.is_external_vllm:
285
+ self.vllm_client = vllm_client
286
+ else:
287
+ self.engine = self.prepare_vllm(model, fast_infer_device)
288
+ self.infer_device = fast_infer_device[self.local_infer_rank]
289
+ elif use_lmdeploy:
290
+ if not is_lmdeploy_available():
291
+ raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.'
292
+ 'Please install LMDeploy with `pip install lmdeploy -U` to use it.')
293
+ from swift.llm import LmdeployEngine
294
+ from swift.tuners import Swift
295
+ with Swift.grpo_context(model, self.template.processor):
296
+ fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1])
297
+ self.engine = LmdeployEngine(
298
+ model.model_dir,
299
+ model.model_info.torch_dtype,
300
+ model_type=model.model_meta.model_type,
301
+ devices=[fast_infer_device],
302
+ session_len=args.lmdeploy_session_len,
303
+ cache_max_entry_count=args.lmdeploy_cache_max_entry_count,
304
+ reload_weights=True)
305
+ self.infer_device = fast_infer_device
306
+ from lmdeploy.turbomind.turbomind import TurboMind
307
+ lmdeploy_engine = self.engine.engine.engine
308
+ assert isinstance(lmdeploy_engine, TurboMind), (
309
+ "Currently only LMDeploy's TurboMind backend is supported. "
310
+ 'The current model is incompatible - please use vLLM or PyTorch backend instead.')
311
+ if not self.is_external_vllm:
312
+ self.engine.default_template = copy(self.template) # Avoid thread-unsafe modifications of the mode.
313
+ self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
314
+
315
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
316
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
317
+ # synchronize all processes after vLLM has been fully initialized.
318
+ self.accelerator.wait_for_everyone()
319
+ else:
320
+ from swift.llm import PtEngine
321
+ self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit
322
+ # Avoid thread-unsafe modifications of the mode.
323
+ self.request_config = RequestConfig(
324
+ max_tokens=args.max_completion_length,
325
+ temperature=args.temperature,
326
+ top_p=args.top_p,
327
+ top_k=args.top_k,
328
+ repetition_penalty=args.repetition_penalty,
329
+ stop=args.stop_words,
330
+ )
331
+
332
+ if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
333
+ self.request_config.n = self.args.tensor_parallel_size
334
+ if self.infer_rank >= 0:
335
+ self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size
336
+
337
+ self.model_accepts_loss_kwargs = False
338
+
339
+ for i, reward_func in enumerate(self.reward_funcs):
340
+ if isinstance(reward_func, PreTrainedModel):
341
+ if self.is_deepspeed_enabled:
342
+ self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
343
+ else:
344
+ self.reward_funcs[i] = self.accelerator.prepare_model(
345
+ reward_func, evaluation_mode=True, device_placement=True)
346
+
347
+ # Multi-step
348
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
349
+ self.epsilon_low = args.epsilon
350
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
351
+
352
+ # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa
353
+ self._step = 0
354
+ # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
355
+ # `_get_train_sampler` and `_prepare_inputs`.
356
+ self._buffered_inputs = None
357
+ if self.args.async_generate:
358
+ self.add_callback(GRPOCallback(self))
359
+
360
+ if self.args.dynamic_sample:
361
+ self.resample_dataset = deepcopy(self.train_dataset)
362
+
363
+ def cyclic_iter(iterable):
364
+ while True:
365
+ for x in iterable:
366
+ yield x
367
+
368
+ self.resample_iterator = cyclic_iter(self.get_resample_dataloader())
369
+ # flag indicating whether the evaluation has started
370
+ self.eval_flag = False
371
+
372
+ @profiling_decorator
373
+ def _prepare_inputs(
374
+ self, accumulated_local_batch: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
375
+ mode = 'train' if self.model.training else 'eval'
376
+ if mode == 'train':
377
+ generate_every = self.args.gradient_accumulation_steps * self.num_iterations
378
+ if self._step % generate_every == 0 or self._buffered_inputs is None:
379
+ accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
380
+ self._buffered_inputs = accumulated_local_batch # < this is the change
381
+ inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
382
+ self._step += 1
383
+ else:
384
+ inputs = self._generate_and_score_completions(accumulated_local_batch)
385
+ return inputs
386
+
387
    def split_batches(self):
        """Sync weights in batches
        Only split LLM layers for now:
        1. N batches for layers
        2. other, embeds, lm_heads in one batch
        3. multi-modal components in one batch

        Returns a pair ``(parameter_groups, parameter_groups_no_lora)`` where the
        second list holds the same names with LoRA adapter entries removed and
        base-layer prefixes stripped.
        """
        model = self.accelerator.unwrap_model(self.model)
        if self.args.move_model_batches is None:
            # All in one
            return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None]

        model_arch = get_model_arch(model.model_meta.model_arch)
        non_llm_parameters = []
        llm_embeds = []
        parameters = []
        # Matches the layer index in names like "model.layers.<i>.self_attn...".
        pattern = r'\.(\d+)\.'

        layer_count = None
        # Get the number of layers in LLM modules
        for name, module in model.named_modules():
            if isinstance(module, ModuleList):
                if model_arch is not None and isinstance(model_arch, MultiModelKeys):
                    llm = model_arch.language_model
                    vision_tower = model_arch.vision_tower
                    if any(vt in name for vt in vision_tower):
                        continue
                    if isinstance(llm, list):
                        llm = llm[0]
                    if name.startswith('base_model'):
                        name = name.replace('base_model.', '')
                    if llm in name:
                        layer_count = len(module)
                else:
                    layer_count = len(module)
        assert layer_count is not None, 'Cannot find ModuleList to split modules.'

        # Layers per sync batch (last batch may be smaller).
        n_layers = ceil(layer_count / self.args.move_model_batches)
        for _ in range(self.args.move_model_batches):
            parameters.append([])

        def replace_lora(name):
            # LoRA adapter weights are dropped entirely; base layers keep their name.
            if 'lora_' in name:
                return ''
            else:
                return name.replace('base_layer.', '')

        def remove_lora_and_prefix(names):
            names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names])
            return [n for n in names if n]

        def split_llm(name):
            # Bucket by layer index; names without an index (embeddings, lm_head, ...)
            # go to the shared llm_embeds batch.
            match = re.search(pattern, name)
            if match:
                number = match.group(1)
                group = int(number) // n_layers
                parameters[group].append(name)
            else:
                llm_embeds.append(name)

        for name, parameter in model.named_parameters():
            if 'ref_model' in name:
                continue
            if model_arch is not None and isinstance(model_arch, MultiModelKeys):
                llm = model_arch.language_model
                vision_tower = model_arch.vision_tower
                if any(vt in name for vt in vision_tower):
                    # NOTE(review): vision-tower params fall through to the `llm in name`
                    # test below and are typically appended a second time — confirm
                    # whether the duplicate entry in non_llm_parameters is intended.
                    non_llm_parameters.append(name)
                elif isinstance(llm, list):
                    llm = llm[0]
                if llm in name:
                    split_llm(name)
                else:
                    non_llm_parameters.append(name)
            else:
                split_llm(name)

        if llm_embeds:
            parameters.append(llm_embeds)
        if non_llm_parameters:
            parameters.append(non_llm_parameters)
        parameters = [p for p in parameters if p]
        parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters]
        return parameters, parameters_no_lora
471
+
472
    def prepare_vllm(self, model, fast_infer_device):
        """Build and return the colocated vLLM engine for rollout generation.

        Chooses ``GRPOVllmEngine`` (fixed seed, TP-compatible) when every local
        process hosts an inference worker, otherwise the plain ``VllmEngine``.
        The engine is constructed inside ``Swift.grpo_context`` and reuses this
        trainer's template as its default.
        """
        from swift.tuners import Swift
        from swift.llm import VllmEngine
        from swift.llm.infer.infer_engine import GRPOVllmEngine
        _, _, _, local_world_size = get_dist_setting()
        if self.args.tensor_parallel_size > 1:
            # TP > 1: processes are launched externally (torchrun), not by vLLM itself.
            vllm_kwargs = {'distributed_executor_backend': 'external_launcher'}
        else:
            vllm_kwargs = {}
        if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
            # Compatibility with TP
            cls = GRPOVllmEngine
            engine_kwargs = {'seed': 0}
        else:
            cls = VllmEngine
            engine_kwargs = {}
        with Swift.grpo_context(model, self.template.processor):
            engine = cls(
                model.model_dir,
                model.model_info.torch_dtype,
                model_type=model.model_meta.model_type,
                device=fast_infer_device[self.local_infer_rank],
                tensor_parallel_size=self.args.tensor_parallel_size,
                gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                enable_prefix_caching=self.args.vllm_enable_prefix_caching,
                max_num_seqs=self.args.vllm_max_num_seqs,
                enforce_eager=self.args.vllm_enforce_eager,
                limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
                num_infer_workers=self.args.num_infer_workers,
                enable_sleep_mode=self.args.sleep_level > 0,
                use_async_engine=False,
                max_model_len=self.args.vllm_max_model_len,
                engine_kwargs=engine_kwargs,
                **vllm_kwargs)
        engine.default_template = self.template
        return engine
508
+
509
+ @property
510
+ def infer_rank(self):
511
+ if self.is_external_vllm:
512
+ # When using external vLLM, only the main process (rank=0) acts as the client.
513
+ return 0 if self.accelerator.is_main_process else -1
514
+ rank, local_rank, world_size, local_world_size = get_dist_setting()
515
+ node_rank = get_node_setting()[0]
516
+ for _vllm_rank in range(self.args.num_infer_workers):
517
+ if local_rank == _vllm_rank:
518
+ return node_rank * self.args.num_infer_workers + _vllm_rank
519
+ if local_rank == -1:
520
+ return 0
521
+ return -1
522
+
523
+ @property
524
+ def infer_rank_tp_0(self):
525
+ # whether is tp rank0, get data from this rank
526
+ # vllm needs all tp ranks inputs and sampling params are the same
527
+ rank, local_rank, world_size, local_world_size = get_dist_setting()
528
+ node_rank = get_node_setting()[0]
529
+ for _vllm_rank in range(self.args.num_infer_workers):
530
+ if local_rank == _vllm_rank and _vllm_rank % self.args.tensor_parallel_size == 0:
531
+ return (node_rank * self.args.num_infer_workers + _vllm_rank // self.args.tensor_parallel_size)
532
+ if local_rank == -1:
533
+ return 0
534
+ return -1
535
+
536
+ @property
537
+ def local_infer_rank(self):
538
+ rank, local_rank, world_size, local_world_size = get_dist_setting()
539
+ for _vllm_rank in range(self.args.num_infer_workers):
540
+ if local_rank == _vllm_rank:
541
+ return _vllm_rank
542
+
543
+ return -1
544
+
545
+ def tp_group_ranks(self):
546
+ rank, local_rank, world_size, local_world_size = get_dist_setting()
547
+ return [
548
+ list(range(0, world_size))[i:i + self.args.tensor_parallel_size]
549
+ for i in range(0, world_size, self.args.tensor_parallel_size)
550
+ ]
551
+
552
    @contextmanager
    def _template_context(self, template):
        """Temporarily switch *template* into train mode for (re-)encoding.

        Saves ``mode``/``max_length``/``loss_scale``, forces 'train' mode with no
        length cap, and restores everything on exit even if encoding raises.
        """
        # The max_length for prompt and completion has already been restricted, so there is no need for max_length here.
        max_length = template.max_length
        mode = template.mode
        if mode in {'vllm', 'pt', 'lmdeploy'}:
            template.set_mode('train')
        template.max_length = None
        loss_scale = template.loss_scale
        if self.multi_turn_func:
            # Multi-turn rollouts keep the default loss scaling for all turns.
            template.loss_scale = 'default'
        try:
            yield
        finally:
            # Restore saved template state regardless of encoding outcome.
            template.loss_scale = loss_scale
            template.set_mode(mode)
            template.max_length = max_length
569
+
570
    @profiling_decorator
    def _move_model_to_vllm_lmdeploy(self):
        """Sync the (possibly LoRA-merged) policy weights into the inference engine.

        For each parameter group: gather the weights (DeepSpeed ZeRO-3 aware),
        merge LoRA adapters if present, normalize the state-dict key names, push
        the weights into the colocated vLLM/LMDeploy engine on infer ranks, then
        unmerge the adapters to restore the training model.
        """
        if self.is_external_vllm:
            # External server mode: defer to the base-class weight sync.
            return super()._move_model_to_vllm()

        from accelerate.utils.other import is_compiled_module

        for i, parameter_group in enumerate(self.parameter_groups):
            parameter_group_no_lora = self.parameter_groups_no_lora[i]
            with unwrap_model_for_generation(
                    self.model,
                    self.accelerator,
                    gather_deepspeed3_params=self.args.ds3_gather_for_generation,
                    gather_parameters=parameter_group) as unwrapped_model:

                if is_compiled_module(unwrapped_model):
                    # torch.compile wraps the real module in _orig_mod.
                    unwrapped_model = unwrapped_model._orig_mod
                if is_peft_model(unwrapped_model):
                    with patch_lora_merge(unwrapped_model, parameter_group):
                        unwrapped_model.merge_adapter()
                    state_dict = unwrapped_model.state_dict()
                    # Remove base_model and base_layer prefixes
                    state_dict = {
                        k.removeprefix('base_model.model.').replace('.base_layer', ''): v
                        for k, v in state_dict.items()
                    }
                    # Remove values with adapter prefix (example: "_lora")
                    state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
                    # When module to save, remove its prefix and discard the original module
                    state_dict = {
                        k.replace('modules_to_save.default.', ''): v
                        for k, v in state_dict.items() if 'original_module' not in k
                    }
                else:
                    state_dict = unwrapped_model.state_dict()
                if parameter_group_no_lora:
                    parameter_group_no_lora = [n.replace('base_model.model.', '') for n in parameter_group_no_lora]
                    state_dict = {k: v for k, v in state_dict.items() if k in parameter_group_no_lora}
                # Guard against empty or ZeRO-3 placeholder (shape-0) tensors.
                assert len(state_dict) > 0 and all([state.shape != torch.Size([0]) for state in state_dict.values()])
                if self.infer_rank >= 0:
                    if self.args.async_generate:
                        # Don't swap weights while an async rollout is still in flight.
                        self._wait_queue()
                    if self.args.use_vllm:
                        llm_model = self.engine.inner_model
                    else:
                        llm_model = self.engine.engine.engine
                    llm_model.load_weights(state_dict.items())
                del state_dict
                gc_collect()
                # Unmerge the adapter to restore the model to its original state.
                # This must be done after loading weights to ensure they correspond to the merged state.
                if is_peft_model(unwrapped_model):
                    with patch_lora_unmerge(unwrapped_model):
                        unwrapped_model.unmerge_adapter()

        if self.infer_rank >= 0 and self.args.use_vllm and self.args.vllm_enable_prefix_caching:
            # Cached prefixes were computed with the old weights; invalidate them.
            self.engine.engine.reset_prefix_cache()
627
+
628
+ def _wait_queue(self):
629
+ while self._queue.empty():
630
+ time.sleep(0.01)
631
+
632
+ @staticmethod
633
+ def reorder_outputs(outputs, distributed_idx):
634
+ index_to_output = {}
635
+ current_position = 0
636
+ for output_idx in distributed_idx:
637
+ for idx in output_idx:
638
+ index_to_output[idx] = outputs[current_position]
639
+ current_position += 1
640
+
641
+ return [index_to_output[idx] for idx in sorted(index_to_output.keys())]
642
+
643
    def _infer_multi_turn(self, inputs_slice: np.ndarray, request_config: RequestConfig) -> Union[OutputsType, List]:
        """Perform multi-turn or single-turn inference with support for tensor parallelism.

        Args:
            inputs_slice: Array of input requests
            request_config: Inference configuration parameters

        Returns:
            List of outputs where each entry contains:
            - List of responses per prompt (length = tensor_parallel_size)
            - Each response is a tuple of (message_history, finish_reason)
        """
        from swift.llm.infer.protocol import ChatCompletionResponse
        rank, _, _, _ = get_dist_setting()
        # Copy so the per-turn mutation (request_config.n = 1) doesn't leak to the caller.
        request_config = copy(request_config)
        results: List[ChatCompletionResponse] = self._engine_infer(
            infer_requests=inputs_slice, request_config=request_config, use_tqdm=False)
        prompt_lens = len(inputs_slice)
        # One slot per (prompt, tp-rank) pair; filled as conversations finish.
        messages_list = [None] * (len(inputs_slice) * self.args.tensor_parallel_size)
        if self.multi_turn_func:
            # Only the very first turn replaces a pre-existing assistant response.
            remove_response = True
            while len(inputs_slice) > 0:
                request_config.n = 1
                if self.infer_rank_tp_0 >= 0 or not self.use_fast_infer:
                    inputs = []
                    cnt = 0
                    for i, output in enumerate(results):
                        for choice in output.choices:
                            _input: Dict = deepcopy(inputs_slice[i])
                            if remove_response or _input['messages'][-1]['role'] != 'assistant' or not \
                                    _input['messages'][-1]['content']:
                                InferRequest.remove_response(_input['messages'])
                                _input['messages'].append({'role': 'assistant', 'content': choice.message.content})
                            else:
                                # Continuation of a partial assistant turn: append text in place.
                                _input['messages'][-1]['content'] += choice.message.content
                            if 'index' not in _input:
                                # 'index' pins this rollout to its slot in messages_list.
                                _input['index'] = cnt
                            _input['finish_reason'] = choice.finish_reason
                            cnt += 1
                            inputs.append(_input)
                    results: List[Dict] = self.multi_turn_func(inputs)  # noqa
                else:
                    # Non-TP-rank-0 workers supply placeholders; real results arrive via broadcast.
                    length = sum([len(results[i].choices) for i in range(len(results))])
                    results = [None] * length

                if self.args.tensor_parallel_size > 1:
                    # avoid duplicate calling in the same tensor parallel group
                    import torch.distributed as dist
                    if 'group_src' in inspect.signature(dist.broadcast_object_list).parameters:
                        dist.broadcast_object_list(results, group_src=0, group=self.group)
                    else:
                        # Older torch: translate group-local rank 0 to a global rank.
                        global_src = dist.get_global_rank(self.group, 0)
                        dist.broadcast_object_list(results, src=global_src, group=self.group)
                inputs_slice = [r for r in results if not r['finished']]
                for idx, r in enumerate(results):
                    if r['finished'] or r['finish_reason'] == 'length':
                        messages_list[r['index']] = (r['messages'], r['finish_reason'])
                if len(inputs_slice) > 0:
                    _input_std = []
                    for _input in inputs_slice:
                        _input_std.append(StdTemplateInputs.from_dict(_input))
                    # StdTemplateInputs will not remove responses in infer
                    results = self._engine_infer(
                        infer_requests=_input_std, request_config=request_config, use_tqdm=False)
                # concat responses from the second loop
                remove_response = False

            outputs = []
            assert not any([m is None for m in messages_list])
            for i in range(0, len(messages_list), self.args.tensor_parallel_size):
                # reformat to [[x, x, x, x] [x, x, x, x]]
                # this is the same format of sampling_params.n > 1
                outputs.append(messages_list[i:i + self.args.tensor_parallel_size])
            assert len(outputs) == prompt_lens
            assert all([len(o) == self.args.tensor_parallel_size for o in outputs])
        else:
            # single turn
            outputs = []
            for i, output in enumerate(results):
                _choices = []
                for choice in output.choices:
                    _input: Dict = deepcopy(inputs_slice[i])
                    InferRequest.remove_response(_input['messages'])
                    _input['messages'].append({'role': 'assistant', 'content': choice.message.content})
                    _choices.append((_input['messages'], choice.finish_reason))
                outputs.append(_choices)
            assert len(outputs) == prompt_lens
            assert all([len(o) == self.args.tensor_parallel_size for o in outputs])

        if self.args.tensor_parallel_size > 1:
            if self.infer_rank_tp_0 < 0:
                # Non-lead TP ranks contribute nothing to the gathered outputs.
                outputs = []
            else:
                # Flatten (prompt, tp) pairs into a single worker-level list.
                _outputs = []
                for tp_idx in range(self.args.tensor_parallel_size):
                    for prompt_idx in range(len(outputs)):
                        _outputs.append(outputs[prompt_idx][tp_idx])
                outputs = [_outputs]

        return outputs
743
+
744
    def async_infer(self, inputs, inputs_slice, distributed_idx):
        """Launch rollout generation in a background thread.

        The result is wrapped in a DataCache and pushed onto the queue that was
        current at submission time, so a phase switch (train/eval) between
        submission and completion cannot misroute the result.
        """

        def infer_task():
            with set_device_context(self.infer_device), self.multi_turn_completion_length_context():
                return self._infer_multi_turn(inputs_slice, self.request_config)

        future: Future = self.executor.submit(infer_task)
        # pre-fetch the queue to avoid switching back to eval_queue at the end of training sample sampling
        current_queue = self._queue

        def done(_self):
            # _self is the completed Future; deliver its result to the captured queue.
            current_queue.put(DataCache(inputs, _self.result(), distributed_idx))

        future.add_done_callback(done)
758
+
759
    def _prefetch(self, dataloader: DataLoader):
        """Generate one batch of rollouts ahead of time and stash it in the queue.

        Every rank enqueues a DataCache (non-infer ranks with empty outputs) so
        the queue protocol stays uniform across processes.
        """
        inputs = next(iter(dataloader))
        all_inputs = gather_object(inputs)
        nnodes = get_node_setting()[1]
        # Round-robin samples across all inference workers on all nodes.
        distributed_idx = round_robin(len(all_inputs), nnodes * self.args.num_infer_workers)
        if self.infer_rank >= 0:
            _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
            with self.multi_turn_completion_length_context():
                outputs = self._infer_multi_turn(_input_slice, self.request_config)
            self._queue.put(DataCache(inputs, outputs, distributed_idx))
        else:
            self._queue.put(DataCache(inputs, [], distributed_idx))
        if self.accelerator.num_processes > 1:
            self.accelerator.wait_for_everyone()
773
+
774
+ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]:
775
+ """
776
+ This function performs fast inference by managing model and optimizer offloading,
777
+ loading weights if necessary, distributing inputs among workers, and generating
778
+ completions using the vLLM/LMDeploy framework. It supports both synchronous and asynchronous
779
+ inference modes.
780
+ inputs: local inputs
781
+ """
782
+
783
+ if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0:
784
+ if self.args.offload_model:
785
+ self.offload_model()
786
+ if self.args.offload_optimizer:
787
+ self.offload_optimizer()
788
+ if self.args.gc_collect_after_offload:
789
+ gc_collect()
790
+ # Skip the first wake_up to avoid the warning "Executor is not sleeping"
791
+ if self.engine.inner_model_executor.is_sleeping:
792
+ self.engine.engine.wake_up()
793
+ # First, have main process load weights if needed
794
+ if self.state.global_step != self._last_loaded_step:
795
+ self._move_model_to_vllm_lmdeploy()
796
+ self._last_loaded_step = self.state.global_step
797
+ all_inputs = gather_object(inputs)
798
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
799
+ # Distribute inputs to different workers
800
+ # for example, 2 workers, 6 inputs, 0/2/4 dispatch to the first worker
801
+ # 1/3/5 dispatch to the second worker
802
+ # trying to shuffle and average the length
803
+ nnodes = get_node_setting()[1]
804
+ num_workers = 1 if self.is_external_vllm else nnodes
805
+ distributed_idx = round_robin(len(all_inputs), num_workers * self.args.num_infer_workers)
806
+ if self.infer_rank >= 0:
807
+ _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
808
+ if self.args.async_generate:
809
+ self.async_infer(inputs, _input_slice, distributed_idx)
810
+ data_cache = self._queue.get()
811
+ inputs = data_cache.inputs
812
+ outputs = data_cache.outputs
813
+ distributed_idx = data_cache.distributed_idx
814
+ else:
815
+ with set_device_context(self.infer_device):
816
+ request_config = copy(self.request_config)
817
+ if self.args.tensor_parallel_size > 1:
818
+ request_config.seed += self.state.global_step
819
+ with self.multi_turn_completion_length_context():
820
+ outputs = self._infer_multi_turn(_input_slice, self.request_config)
821
+ else:
822
+ if self.args.async_generate:
823
+ # using old model to generate, which will ignore the `clip` of advantages.
824
+ self._queue.put(DataCache(inputs, [], distributed_idx))
825
+ data_cache = self._queue.get()
826
+ inputs = data_cache.inputs
827
+ distributed_idx = data_cache.distributed_idx
828
+ outputs = []
829
+ outputs = gather_object(outputs)
830
+ if self.args.tensor_parallel_size > 1:
831
+ outputs = [[item] for output in outputs for item in output]
832
+ if not self.is_external_vllm:
833
+ outputs = self.reorder_outputs(outputs, distributed_idx)
834
+ if not self.is_external_vllm and self.args.sleep_level > 0 and self.infer_rank >= 0:
835
+ self.engine.engine.sleep(level=self.args.sleep_level)
836
+ if self.args.gc_collect_after_offload:
837
+ gc_collect()
838
+ if self.args.offload_model:
839
+ self.load_model()
840
+ if self.args.offload_optimizer:
841
+ self.load_optimizer()
842
+ return inputs, outputs
843
+
844
    def _generate_completions(self, inputs: InputsType) -> InputsType:
        """Generate completions for given inputs using either fast inference or standard PyTorch inference.

        Args:
            inputs: List of input examples containing conversation messages.

        Returns:
            Modified inputs with generated completions added to the last message
            and truncation flag set in 'is_truncated' field.
        """
        mode = 'train' if self.model.training else 'eval'
        if self.use_fast_infer:
            inputs, outputs = self._fast_infer(inputs)
            # Slice to keep only the local part of the data
            process_slice = slice(
                self.accelerator.process_index * len(inputs),
                (self.accelerator.process_index + 1) * len(inputs),
            )
            outputs = outputs[process_slice]
        else:
            # pt infer
            is_multimodal = self.model.model_meta.is_multimodal
            if is_multimodal:
                # Post-encode hooks would interfere with generation; detach them.
                models = self.template.remove_post_encode_hook()
            with unwrap_model_for_generation(
                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
            ), self.multi_turn_completion_length_context():
                outputs = self._infer_multi_turn(inputs, self.request_config)
            if mode == 'train':
                # In training mode, ensure the model is returned to train() mode after inference
                # This is necessary as pt engines set the model to eval mode during generation
                self.model.train()
            if is_multimodal:
                self.template.register_post_encode_hook(models)
            if isinstance(outputs[0][0], list):
                # Normalize nested pt-engine output to the fast-infer shape.
                outputs = [output[0] for output in outputs]

        for i, output in enumerate(outputs):
            # output[0] is (message_history, finish_reason) for the first choice.
            inputs[i]['messages'] = output[0][0]
            inputs[i]['is_truncated'] = output[0][1] == 'length'

        return inputs
886
+
887
    def _generate_and_score_completions(self, inputs: InputsType) -> InputsType:
        """Full rollout pipeline: generate, score, (optionally) resample, encode, log.

        Returns the batch-encoded inputs (with advantages and cached logps)
        ready for loss computation.
        """

        inputs = self._generate_completions(inputs)
        total_rewards_per_func, total_rewards, completions = self._score_completions(inputs)
        mode = 'train' if self.model.training else 'eval'

        if self.args.dynamic_sample and mode == 'train':
            # dynamic sampling for std=0 groups
            inputs, total_rewards, total_rewards_per_func, completions = \
                self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions)

        # Prepare final outputs with advantages and other required fields
        batch_encoded_inputs = self._prepare_batch_inputs(inputs, total_rewards)
        # Log metrics
        # Drop the final (assistant) message so only the prompt is logged.
        messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))]

        self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func)

        return batch_encoded_inputs
906
+
907
    def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
        """Score completions using all reward functions

        Args:
            inputs: List of input examples, each containing a 'messages' list with conversation history

        Returns:
            Tuple containing:
            - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards
            - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards
            - completions: List of generated completion strings
        """
        device = self.accelerator.device
        # The generated completion is the last (assistant) message of each sample.
        completions = [example['messages'][-1]['content'] for example in inputs]
        rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device)

        for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)):
            # reward model
            if isinstance(reward_func, nn.Module):
                rewards_per_func[:, i] = reward_model_plugin(inputs=inputs)
            # reward function
            else:
                # Repeat all input columns (but "messages" and "completion") to match the number of generations
                reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
                output_reward_func = reward_func(completions, **reward_kwargs)
                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

        # Gather rewards across processes so every rank sees the global batch,
        # then combine per-function rewards with their configured weights.
        total_rewards_per_func = gather(rewards_per_func)
        total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

        return total_rewards_per_func, total_rewards, completions
938
+
939
    def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions):
        """Resample until enough groups have non-zero reward variance (DAPO).

        Groups whose rewards all tie (std == 0) carry no learning signal under
        GRPO; keep only non-degenerate groups and regenerate replacements up to
        ``max_resample_times`` times.  Falls back to the original batch if the
        quota is never met.
        """
        # DAPO https://arxiv.org/abs/2503.14476
        # Replaces samples with zero-reward-variance groups (std=0)
        resample_count = 0
        valid_samples = []
        valid_rewards = []
        valid_rewards_per_func = []
        valid_completions = []

        # Keep the originals so we can fall back if resampling never succeeds.
        origin_data = (inputs, rewards, rewards_per_func, completions)

        while resample_count < self.args.max_resample_times:
            grouped_rewards = rewards.view(-1, self.num_generations)
            group_std = grouped_rewards.std(dim=1)

            # Expand the per-group verdict to per-sample granularity.
            valid_mask = (group_std > 0).repeat_interleave(self.num_generations)
            all_inputs = gather_object(inputs)
            valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask])
            valid_rewards.append(rewards[valid_mask])
            valid_rewards_per_func.append(rewards_per_func[valid_mask])
            valid_completions.extend(
                [inp['messages'][-1]['content'] for inp, mask in zip(all_inputs, valid_mask) if mask])

            if len(valid_samples) >= self.effective_train_batch_size:
                break

            # Draw a fresh batch and roll out / score it again.
            inputs = next(self.resample_iterator)
            inputs = Trainer._prepare_inputs(self, inputs)
            inputs = self._generate_completions(inputs)
            rewards_per_func, rewards, completions = self._score_completions(inputs)
            resample_count += 1

        if len(valid_samples) >= self.effective_train_batch_size:
            # Take this process's shard of the accumulated valid samples.
            process_slice = slice(
                self.accelerator.process_index * len(inputs),
                (self.accelerator.process_index + 1) * len(inputs),
            )
            inputs = valid_samples[:self.effective_train_batch_size][process_slice]
            rewards = torch.cat(valid_rewards)[:self.effective_train_batch_size]
            rewards_per_func = torch.cat(valid_rewards_per_func)[:self.effective_train_batch_size]
            completions = valid_completions[:self.effective_train_batch_size][process_slice]
        else:
            logger.warning(f'There are still std=0 groups present after {self.args.max_resample_times} retries.')
            inputs, rewards, rewards_per_func, completions = origin_data

        return inputs, rewards, rewards_per_func, completions
985
+
986
    def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]:
        """
        Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training.

        Args:
            inputs (InputsType): List of input samples. Original shape is [gas*bs] where:
                - gas: gradient accumulation steps
                - bs: per-device batch size
            rewards (torch.Tensor): Tensor of rewards corresponding to the inputs.
                Shape should match the total number of samples (gas*bs*num_generations)

        Returns:
            List[InputsType]: A list of prepared batch inputs, organized as [gas][bs]
        """
        # Compute advantages
        grouped_rewards = rewards.view(-1, self.num_generations)
        mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards)
        if self.args.scale_rewards:
            # Small epsilon guards against division by zero for std=0 groups.
            advantages /= (std_grouped_rewards + 1e-4)

        # Slice to keep only the local part of the data
        process_slice = slice(
            self.accelerator.process_index * len(inputs),
            (self.accelerator.process_index + 1) * len(inputs),
        )
        advantages = advantages[process_slice]

        mode = 'train' if self.model.training else 'eval'
        bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
        gas = self.args.gradient_accumulation_steps if mode == 'train' else 1

        assert len(inputs) == bs * gas, f'Expected {bs * gas} inputs, got {len(inputs)}'
        gas_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(gas)]

        ga_batch_encoded_inputs = []
        template = self.template

        # Split advantages by GAS chunks
        advantage_chunks = torch.chunk(advantages, gas)

        for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)):
            # Encode and process each batch (size=bs)
            with self._template_context(template):
                batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch]
                batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device)

            # Process labels and masks
            labels = batch_encoded_inputs.pop('labels')
            # Longest completion in the batch: tokens from the first non -100 label onward.
            logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item()
            batch_encoded_inputs.update({
                'completion_mask':
                labels[:, -logits_to_keep:] != -100,
                'truncated_mask':
                torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool),
                'logits_to_keep':
                logits_to_keep,
                'advantages':
                batch_advantages
            })

            with torch.no_grad():
                # Cache old-policy logps only when rollouts are reused across iterations.
                batch_encoded_inputs['old_per_token_logps'] = (
                    self._get_per_token_logps(self.model, batch_encoded_inputs) if self.old_policy else None)

                if self.beta == 0.0:
                    # No KL term -> reference logps are unnecessary.
                    ref_per_token_logps = None
                elif self.ref_model is not None:
                    ref_per_token_logps = self._get_per_token_logps(self.ref_model, batch_encoded_inputs)
                else:
                    # LoRA case: the base model with adapters disabled is the reference.
                    with self.accelerator.unwrap_model(self.model).disable_adapter():
                        ref_per_token_logps = self._get_per_token_logps(self.model, batch_encoded_inputs)
                batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps

            ga_batch_encoded_inputs.append(batch_encoded_inputs)

        return ga_batch_encoded_inputs
1064
+
1065
    def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func):
        """Log training/evaluation metrics"""
        mode = 'train' if self.model.training else 'eval'
        device = self.accelerator.device

        # Calculate completion length metrics
        agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs]))

        self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item())
        self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item())
        self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item())
        # Calculate clip ratio
        agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device))

        # Fraction of completions that hit the length limit.
        term_completion_mask = agg_completion_mask[agg_truncated_mask]
        clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask)

        self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio)

        # Per reward function mean/std over the global batch.
        for i, reward_func_name in enumerate(self.reward_func_names):
            mean_rewards = rewards_per_func[:, i].mean().item()
            self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards)
            std_rewards = rewards_per_func[:, i].std().item()
            self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards)

        # Log overall reward stats
        grouped_rewards = rewards.view(-1, self.num_generations)
        self._metrics[mode]['reward'].append(grouped_rewards.mean().item())
        self._metrics[mode]['reward_std'].append(grouped_rewards.std(dim=1).mean().item())

        # Log prompt and completion texts
        self._textual_logs['prompt'].extend(gather_object(messages))
        self._textual_logs['completion'].extend(gather_object(completions))
        for i, name in enumerate(self.reward_func_names):
            self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist())
1100
+
1101
    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the GRPO policy loss (clipped ratio objective, optional KL term).

        Also records KL and clip-ratio metrics for the current mode.
        """
        # Compute the per-token log probabilities for the model, return_outputs=True in mini-batch training
        if isinstance(inputs, list):
            assert len(inputs) == 1
            inputs = inputs[0]
        completion_mask = inputs['completion_mask']
        truncated_mask = inputs['truncated_mask']
        # apply the completion_mask to exclude loss and metrics for overlong completions
        if self.args.overlong_filter and any(truncated_mask):
            if all(truncated_mask):
                logger.info('All completions are overlong, loss and KL will be zero')
            truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device)
            completion_mask = completion_mask * (~truncated_mask)

        per_token_logps = self._get_per_token_logps(model, inputs)

        # Compute the KL divergence between the model and the reference model
        if self.beta != 0.0:
            ref_per_token_logps = inputs['ref_per_token_logps']
            # k3 estimator: exp(d) - d - 1 with d = ref_logp - logp (non-negative).
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1)

        advantages = inputs['advantages']
        # On-policy (single iteration): ratio is exp(0) = 1, but the gradient still flows.
        old_per_token_logps = inputs['old_per_token_logps'] if self.old_policy else per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        if self.loss_type == 'grpo':
            # Per-sequence mean, then batch mean.
            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
        elif self.loss_type == 'bnpo':
            # Global token mean over the batch.
            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
        elif self.loss_type == 'dr_grpo':
            # Normalize by a fixed max length to remove length bias.
            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
        else:
            raise ValueError(f'Unknown loss type: {self.loss_type}')

        # Log the metrics
        mode = 'train' if self.model.training else 'eval'

        if self.beta != 0.0:
            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
            self._metrics[mode]['kl'].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())

        # Compute the clipped probability ratios
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
        high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
        clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()

        gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
        self._metrics[mode]['clip_ratio/low_mean'].append(gathered_low_clip.nanmean().item())
        self._metrics[mode]['clip_ratio/low_min'].append(nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
        self._metrics[mode]['clip_ratio/high_mean'].append(gathered_high_clip.nanmean().item())
        self._metrics[mode]['clip_ratio/high_max'].append(nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
        self._metrics[mode]['clip_ratio/region_mean'].append(gathered_clip_ratio.nanmean().item())

        return loss
1169
+
1170
    # Get the per-token log probabilities for the completions for the model and the reference model
    @profiling_decorator
    def _get_per_token_logps(self, model, inputs):
        """Return per-token logps for the last ``logits_to_keep`` tokens of each sample."""
        from trl.trainer.utils import selective_log_softmax
        logits_to_keep = inputs['logits_to_keep']
        input_ids = inputs['input_ids']
        unwrapped_model = self.accelerator.unwrap_model(model)
        if is_peft_model(unwrapped_model):
            parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters
        else:
            parameters = inspect.signature(unwrapped_model.forward).parameters
        if not unwrapped_model.model_meta.is_multimodal and 'logits_to_keep' in parameters:
            # save memory
            return super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep)
        # Strip trainer-only fields before forwarding to the model.
        inputs = {
            k: v
            for k, v in inputs.items() if k not in [
                'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
                'truncated_mask'
            ]
        }
        with self._template_context(self.template):
            logits = model(**inputs).logits
        # exclude the last logit: it corresponds to the next token pred
        logits = logits[:, -(logits_to_keep + 1):-1, :]
        # Match the sampling temperature used at generation time.
        logits = logits / self.temperature
        input_ids = input_ids[:, -logits_to_keep:]
        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
1198
+
1199
    def evaluation_loop(self, dataloader, *args, **kwargs):
        """Run the base evaluation loop, prefetching rollouts in async mode and
        folding accumulated eval metrics into the output."""
        # Wait for the training rollout to complete
        # NOTE(review): the comment above says "training rollout" but the call below
        # checks the *eval* rollout flag — confirm against the helper's definition.
        if self.args.async_generate:
            while not self.is_async_generate_eval_rollout_done():
                time.sleep(0.1)
        if self._queue.empty() and self.args.async_generate:
            self._prefetch(dataloader)
        metric_key_prefix = kwargs['metric_key_prefix']
        output = super().evaluation_loop(dataloader, *args, **kwargs)
        # Average each accumulated eval metric and merge into the trainer output.
        metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()}
        output.metrics.update(metrics)
        self.eval_flag = True
        return output
1212
+
1213
    def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch=None) -> torch.Tensor:
        """Standard training step, but in async mode first wait until the pending
        eval rollout has finished so queues are not interleaved."""
        if self.args.async_generate:
            # Wait for the eval rollout to complete
            while not self.is_async_generate_eval_rollout_done():
                time.sleep(0.1)
        return super().training_step(model, inputs, num_items_in_batch)
1219
+
1220
    def _engine_infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        *,
        use_tqdm: Optional[bool] = None,
    ):
        """Dispatch inference to the external vLLM server or the colocated engine."""
        if self.is_external_vllm:
            # Images must be base64-encoded before being serialized to the server.
            self._process_infer_requests_images(infer_requests)
            return self.vllm_client.infer(infer_requests.tolist(), asdict(request_config), use_tqdm=use_tqdm)
        else:
            return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
1232
+
1233
+ def _process_infer_requests_images(self, infer_requests: List[InferRequest]):
1234
+ import base64
1235
+ if not any('images' in request for request in infer_requests):
1236
+ return
1237
+ for request in infer_requests:
1238
+ if 'images' not in request:
1239
+ continue
1240
+ for i, img in enumerate(request['images']):
1241
+ if 'bytes' in img and img['bytes']:
1242
+ request['images'][i] = base64.b64encode(img['bytes']).decode('utf-8')
1243
+ return
1244
+
1245
+ @property
1246
+ def old_policy(self):
1247
+ return self.num_iterations > 1
1248
+
1249
+ @property
1250
+ def _queue(self):
1251
+ if self.control.should_evaluate:
1252
+ return self.eval_queue
1253
+ else:
1254
+ return self.train_queue
1255
+
1256
+ @torch.no_grad()
1257
+ def offload_model(self):
1258
+ if len(self.offload_modules) > 0:
1259
+ return
1260
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
1261
+ for name, module in unwrapped_model.named_modules():
1262
+ if isinstance(module, torch.nn.Embedding):
1263
+ self.offload_modules[name] = module.weight.device
1264
+ module.to('cpu')
1265
+ elif not hasattr(module, 'device'):
1266
+ pass
1267
+ elif module.device.type != 'cpu':
1268
+ self.offload_modules[name] = module.device
1269
+ module.to('cpu')
1270
+
1271
+ @torch.no_grad()
1272
+ def load_model(self):
1273
+ if len(self.offload_modules) == 0:
1274
+ return
1275
+ unwrapped_model = self.accelerator.unwrap_model(self.model)
1276
+ for name, device in self.offload_modules.items():
1277
+ module = unwrapped_model.get_submodule(name)
1278
+ if isinstance(module, torch.nn.Embedding):
1279
+ module.weight.to(device)
1280
+ else:
1281
+ module.to(device)
1282
+ self.offload_modules.clear()
1283
+
1284
+ @torch.no_grad()
1285
+ def offload_optimizer(self):
1286
+ if len(self.offload_states) > 0:
1287
+ return
1288
+ if not self.optimizer.state:
1289
+ return
1290
+ for param_group in self.optimizer.param_groups:
1291
+ for param in param_group['params']:
1292
+ state = self.optimizer.state[param]
1293
+ for key, value in state.items():
1294
+ if isinstance(value, torch.Tensor):
1295
+ self.offload_states[key] = value.device
1296
+ state[key] = value.to('cpu', non_blocking=True)
1297
+
1298
+ @torch.no_grad()
1299
+ def load_optimizer(self):
1300
+ if len(self.offload_states) == 0:
1301
+ return
1302
+ if not self.optimizer.state:
1303
+ return
1304
+ for param_group in self.optimizer.param_groups:
1305
+ for param in param_group['params']:
1306
+ state = self.optimizer.state[param]
1307
+ for key, value in state.items():
1308
+ if isinstance(value, torch.Tensor):
1309
+ state[key] = value.to(self.offload_states[key], non_blocking=True)
1310
+ self.offload_states.clear()
1311
+
1312
+ @contextmanager
1313
+ def multi_turn_completion_length_context(self):
1314
+ """
1315
+ Context manager that temporarily adjusts the engine's max length handling
1316
+ for multi-turn generation scenarios.
1317
+
1318
+ Ensures the total sequence length (prompt + completion) never exceeds:
1319
+ min(original_max_len, prompt_tokens + max_completion_length)
1320
+ """
1321
+ if not (self.multi_turn_func and self.infer_rank >= 0) or self.is_external_vllm:
1322
+ yield
1323
+ return
1324
+
1325
+ original_fn = self.engine.set_default_max_tokens
1326
+ original_max_len = self.engine.max_model_len
1327
+
1328
+ def set_default_max_tokens(_self, request_config: RequestConfig, inputs: InputsType) -> None:
1329
+ # Calculate required context window
1330
+ original_max_len = _self.max_model_len or 8192
1331
+ if isinstance(inputs, dict):
1332
+ inputs = [inputs]
1333
+ prompt_tokens = max(_self._get_num_tokens(inp) for inp in inputs)
1334
+
1335
+ if not hasattr(_self, 'set_grpo_max_model_len'):
1336
+ # set max model len in first round
1337
+ max_len = min(original_max_len, prompt_tokens + request_config.max_tokens)
1338
+ _self.max_model_len = max_len
1339
+ _self.set_grpo_max_model_len = True
1340
+ else:
1341
+ if _self.max_model_len <= prompt_tokens:
1342
+ # modify max_model_len > prompt_tokens to avoid crash
1343
+ num_tokens_avoid_crash = 10
1344
+ _self.max_model_len = (prompt_tokens + num_tokens_avoid_crash)
1345
+ request_config.max_tokens = num_tokens_avoid_crash
1346
+
1347
+ original_fn(request_config, inputs)
1348
+
1349
+ try:
1350
+ self.engine.set_default_max_tokens = MethodType(set_default_max_tokens, self.engine)
1351
+ yield
1352
+ finally:
1353
+ self.engine.set_default_max_tokens = original_fn
1354
+ self.engine.max_model_len = original_max_len
1355
+ del self.engine.set_grpo_max_model_len
1356
+
1357
+ def get_resample_dataloader(self) -> DataLoader:
1358
+ resample_dataset = self.resample_dataset
1359
+ data_collator = self.data_collator
1360
+ if isinstance(resample_dataset, datasets.Dataset):
1361
+ resample_dataset = self._remove_unused_columns(resample_dataset, description='training')
1362
+ else:
1363
+ data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
1364
+
1365
+ dataloader_params = {
1366
+ 'batch_size': self._train_batch_size * self.args.gradient_accumulation_steps,
1367
+ 'collate_fn': data_collator,
1368
+ 'num_workers': self.args.dataloader_num_workers,
1369
+ 'pin_memory': self.args.dataloader_pin_memory,
1370
+ 'persistent_workers': self.args.dataloader_persistent_workers,
1371
+ }
1372
+
1373
+ @contextmanager
1374
+ def seed_context(self):
1375
+ seed = self.args.seed
1376
+ self.args.seed = seed + 1
1377
+ yield
1378
+ self.args.seed = seed
1379
+
1380
+ if not isinstance(resample_dataset, torch.utils.data.IterableDataset):
1381
+ with seed_context(self): # Set a different seed for resampling than the train_dataset.
1382
+ dataloader_params['sampler'] = self._get_train_sampler()
1383
+ dataloader_params['drop_last'] = self.args.dataloader_drop_last
1384
+ dataloader_params['worker_init_fn'] = seed_worker
1385
+ dataloader_params['prefetch_factor'] = self.args.dataloader_prefetch_factor
1386
+
1387
+ return self.accelerator.prepare(DataLoader(resample_dataset, **dataloader_params))
1388
+
1389
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1390
+ mode = 'train' if self.model.training else 'eval'
1391
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
1392
+
1393
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1394
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1395
+ if mode == 'eval':
1396
+ metrics = {f'eval_{key}': val for key, val in metrics.items()}
1397
+
1398
+ logs = {**logs, **metrics}
1399
+ if version.parse(transformers.__version__) >= version.parse('4.47.0.dev0'):
1400
+ super().log(logs, start_time)
1401
+ else: # transformers<=4.46
1402
+ super().log(logs)
1403
+ self._metrics[mode].clear()
1404
+
1405
+ if self.accelerator.is_main_process and self.log_completions:
1406
+ table = {
1407
+ 'step': [str(self.state.global_step)] * len(self._textual_logs['prompt']),
1408
+ 'prompt': self._textual_logs['prompt'],
1409
+ 'completion': self._textual_logs['completion'],
1410
+ **self._textual_logs['rewards'],
1411
+ }
1412
+ self.jsonl_writer.append(table)
1413
+ if self.args.report_to and 'wandb' in self.args.report_to and wandb.run is not None:
1414
+ import pandas as pd
1415
+ df = pd.DataFrame(table)
1416
+ if self.wandb_log_unique_prompts:
1417
+ df = df.drop_duplicates(subset=['prompt'])
1418
+ wandb.log({'completions': wandb.Table(dataframe=df)})
1419
+
1420
+ def is_async_generate_eval_rollout_done(self):
1421
+ return not self.eval_flag or not self.eval_queue.empty()
1422
+
1423
+ def is_async_generate_train_rollout_done(self):
1424
+ return not self.train_queue.empty()
ms-swift/swift/trainers/sequence_parallel/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import os

# Choose the sequence-parallel backend via env var; 'ulysses' is the default.
_impl = os.environ.get('SEQUENCE_PARALLEL_IMPL', 'ulysses')
if _impl == 'xtuner':
    from .xtuner import XTuner
    sequence_parallel = XTuner()
else:
    from .ulysses import Ulysses
    sequence_parallel = Ulysses()
ms-swift/swift/tuners/__pycache__/base.cpython-310.pyc ADDED
Binary file (32.4 kB). View file
 
ms-swift/swift/tuners/base.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Copyright 2023-present the HuggingFace Inc. team.
3
+ import os
4
+ import re
5
+ import shutil
6
+ import tempfile
7
+ from contextlib import contextmanager
8
+ from copy import copy
9
+ from functools import partial
10
+ from inspect import Parameter, Signature, signature
11
+ from types import MethodType
12
+ from typing import Dict, List, Literal, Optional, Union
13
+
14
+ import json
15
+ import torch
16
+ from modelscope import snapshot_download
17
+ from peft.utils import CONFIG_NAME
18
+ from peft.utils.other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
19
+ from torch import nn
20
+ from transformers import Trainer
21
+
22
+ from swift.utils.constants import DEFAULT_ADAPTER, SWIFT_TYPE_KEY
23
+ from swift.utils.logger import get_logger
24
+ from ..utils.torch_utils import get_device_count
25
+ from .mapping import SwiftTuners
26
+ from .peft import PeftConfig, PeftModel, get_peft_model
27
+ from .utils import SwiftConfig, SwiftOutput
28
+
29
+ logger = get_logger()
30
+
31
+
32
+ class SwiftModel(nn.Module):
33
+ """The Swift wrapper model.
34
+
35
+ Args:
36
+ model (`Union[nn.Module, 'SwiftModel']`) A module to be tuned by Swift.
37
+ config (`Union[SwiftConfig, Dict[str, SwiftConfig]]`) A config or a dict of {adapter_name: SwiftConfig}.
38
+ If it's a config class, the adapter_name will be `default`
39
+ extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved.
40
+ inference_mode (bool, `optional`): Load model at inference mode, default False.
41
+ """
42
+
43
+ EXTRA_STATE_DIR = 'extra_states'
44
+
45
+ def __init__(self,
46
+ model: Union[nn.Module, 'SwiftModel'],
47
+ config: Union[SwiftConfig, Dict[str, SwiftConfig]],
48
+ extra_state_keys: List[str] = None,
49
+ inference_mode: bool = False,
50
+ **kwargs):
51
+ super().__init__()
52
+ self.adapters = {}
53
+ self.active_adapters = set()
54
+ if isinstance(model, SwiftModel):
55
+ self.adapters = model.adapters
56
+ extra_state_keys = extra_state_keys or []
57
+ extra_state_keys.extend(model.extra_state_keys)
58
+ self.active_adapters = model.active_adapters
59
+ model = model.base_model
60
+
61
+ self.base_model = model
62
+ new_adapters = []
63
+ if isinstance(config, SwiftConfig):
64
+ if DEFAULT_ADAPTER not in self.adapters:
65
+ all_parts = self._deactivate_all_parts()
66
+ self.adapters[DEFAULT_ADAPTER] = self._prepare_model(model, config, DEFAULT_ADAPTER)
67
+ for part in all_parts:
68
+ self.activate_adapter(part)
69
+ new_adapters.append(DEFAULT_ADAPTER)
70
+ if self.adapters[DEFAULT_ADAPTER].model is not None:
71
+ self.base_model = self.adapters[DEFAULT_ADAPTER].model
72
+ else:
73
+ logger.warn(f'Adapter {DEFAULT_ADAPTER} has been patched, skip.')
74
+ elif isinstance(config, dict):
75
+ assert (all(isinstance(c, SwiftConfig) for c in config.values()))
76
+ for adapter_name, _config in config.items():
77
+ if adapter_name not in self.adapters:
78
+ all_parts = self._deactivate_all_parts()
79
+ self.adapters[adapter_name] = self._prepare_model(model, _config, adapter_name)
80
+ for part in all_parts:
81
+ self.activate_adapter(part)
82
+ new_adapters.append(adapter_name)
83
+ if self.adapters[adapter_name].model is not None:
84
+ self.base_model = self.adapters[adapter_name].model
85
+ else:
86
+ logger.warn(f'Adapter {adapter_name} has been patched, skip.')
87
+
88
+ self.extra_state_keys = extra_state_keys or []
89
+ self.has_additional_modules = any([c.config.has_additional_modules for c in self.adapters.values()])
90
+
91
+ def forward(self, *args, **kwargs):
92
+ return self.base_model(*args, **kwargs)
93
+
94
+ _parameters = [Parameter('self', Parameter.POSITIONAL_ONLY)]
95
+ _parameters += list(signature(self.base_model.forward).parameters.values())
96
+ forward.__signature__ = Signature(_parameters)
97
+ self.forward = MethodType(forward, self)
98
+ for adapter_name in new_adapters:
99
+ self.activate_adapter(adapter_name)
100
+
101
+ if inference_mode:
102
+ self.eval()
103
+ else:
104
+ for key, output in self.adapters.items():
105
+ if key in new_adapters:
106
+ output.mark_trainable_callback(model)
107
+ if self.extra_state_keys:
108
+ for n, p in model.named_parameters():
109
+ if any(re.fullmatch(extra_key, n) for extra_key in self.extra_state_keys):
110
+ p.requires_grad = True
111
+
112
+ @property
113
+ def model(self):
114
+ return self.base_model
115
+
116
+ def _deactivate_all_parts(self):
117
+ deactivated = []
118
+ for adapter in self.active_adapters:
119
+ output = self.adapters[adapter]
120
+ if output.config.swift_type == SwiftTuners.PART:
121
+ deactivated.append(adapter)
122
+ self.deactivate_adapter(adapter)
123
+ return deactivated
124
+
125
+ def load_state_dict(self, state_dict, strict=True, adapter_name: str = None):
126
+ if adapter_name is not None:
127
+ output: SwiftOutput = self.adapters[adapter_name]
128
+ if getattr(output.config, 'modules_to_save', None):
129
+ for key, value in copy(state_dict).items():
130
+ for module_name in output.config.modules_to_save:
131
+ if module_name in key:
132
+ state_dict.pop(key)
133
+ key = key.replace(module_name, f'{module_name}.modules_to_save.{adapter_name}')
134
+ break
135
+ state_dict[key] = value
136
+
137
+ for key, value in copy(state_dict).items():
138
+ if key.startswith('base_model.model.'):
139
+ state_dict.pop(key, None)
140
+ key = key[len('base_model.model.'):]
141
+ if f'lora_A.{adapter_name}.' not in key and 'lora_A' in key:
142
+ state_dict.pop(key, None)
143
+ key = key.replace('lora_A.', f'lora_A.{adapter_name}.')
144
+ if f'lora_B.{adapter_name}.' not in key and 'lora_B' in key:
145
+ state_dict.pop(key, None)
146
+ key = key.replace('lora_B.', f'lora_B.{adapter_name}.')
147
+ if f'lora_embedding_A.{adapter_name}.' not in key and 'lora_embedding_A' in key:
148
+ state_dict.pop(key, None)
149
+ key = key.replace('lora_embedding_A.', f'lora_embedding_A.{adapter_name}.')
150
+ if f'lora_embedding_B.{adapter_name}.' not in key and 'lora_embedding_B' in key:
151
+ state_dict.pop(key, None)
152
+ key = key.replace('lora_embedding_B.', f'lora_embedding_B.{adapter_name}.')
153
+ state_dict[key] = value
154
+
155
+ if output.load_state_dict_callback:
156
+ state_dict = output.load_state_dict_callback(self.base_model, adapter_name, state_dict)
157
+
158
+ incompatible_keys = self.base_model.load_state_dict(state_dict, False)
159
+ if incompatible_keys and len(incompatible_keys[1]) > 0:
160
+ logger.error(f'Load state dict with unexpected keys: {incompatible_keys[1]}')
161
+
162
+ def state_dict(self,
163
+ *args,
164
+ destination=None,
165
+ prefix='',
166
+ keep_vars=False,
167
+ adapter_name: str = None,
168
+ peft_format: bool = False,
169
+ **kwargs):
170
+ """
171
+ Args:
172
+ destination (`dict`, `optional`): If provided, the state of module will
173
+ be updated into the dict and the same object is returned.
174
+ Otherwise, an ``OrderedDict`` will be created and returned.
175
+ Default: ``None``.
176
+ prefix (`str`, `optional`): a prefix added to parameter and buffer
177
+ names to compose the keys in state_dict. Default: ``''``.
178
+ keep_vars (`bool`, `optional`): by default the :class:`~torch.Tensor` s
179
+ returned in the state dict are detached from autograd. If it's
180
+ set to ``True``, detaching will not be performed.
181
+ Default: ``False``.
182
+ adapter_name (`str`, `optional`): The name of the adapter's parameters to be saved,
183
+ `None` input will save all adapters.
184
+ peft_format (`bool`, `optional`): Save with peft format (extra `base_model.model.` prefix)
185
+ **kwargs:
186
+ save_adapter(`bool`): Save adapters or not, default True
187
+ save_extra_states(`bool`): Save extra states or not, default True
188
+ Returns:
189
+ The state dict to be saved.
190
+ """
191
+ state_dict = kwargs.get('state_dict')
192
+ if state_dict is None:
193
+ state_dict = self.base_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
194
+ state_dict = {
195
+ key[len('base_model.'):] if key.startswith('base_model.') else key: value
196
+ for key, value in state_dict.items()
197
+ }
198
+ if not self.has_additional_modules:
199
+ return state_dict
200
+
201
+ state_dicts = {}
202
+ if kwargs.get('save_adapter', True):
203
+ for name, output in self.adapters.items():
204
+ if (adapter_name == name or adapter_name is None) and output.config.has_additional_modules: # noqa
205
+ state_dicts.update(output.state_dict_callback(state_dict, name))
206
+ modules_to_save_names = [
207
+ sub_name for sub_name, _ in self.base_model.named_parameters()
208
+ if f'modules_to_save.{name}' in sub_name
209
+ ]
210
+ for module_name in modules_to_save_names:
211
+ if f'modules_to_save.{name}' in module_name:
212
+ state_dicts[module_name.replace(f'modules_to_save.{name}.', '')] = state_dict[module_name]
213
+ if kwargs.get('save_extra_states', True):
214
+ state_dicts.update({
215
+ k: v
216
+ for k, v in state_dict.items() if any(
217
+ re.fullmatch(extra_key, k) for extra_key in self.extra_state_keys)
218
+ })
219
+ if peft_format:
220
+ new_state_dict = {}
221
+ for key, value in state_dicts.items():
222
+ if not key.startswith('base_model.model.'):
223
+ key = 'base_model.model.' + key
224
+ key = key.replace(f'lora_A.{adapter_name}.', 'lora_A.')
225
+ key = key.replace(f'lora_B.{adapter_name}.', 'lora_B.')
226
+ key = key.replace(f'lora_embedding_A.{adapter_name}.', 'lora_embedding_A.')
227
+ key = key.replace(f'lora_embedding_B.{adapter_name}.', 'lora_embedding_B.')
228
+ new_state_dict[key] = value
229
+ state_dicts = new_state_dict
230
+ return state_dicts
231
+
232
+ def __getattr__(self, key: str):
233
+ """Forward missing attributes to the wrapped module."""
234
+ try:
235
+ return super().__getattr__(key)
236
+ except AttributeError:
237
+ if 'base_model' in dir(self):
238
+ return getattr(self.base_model, key)
239
+ raise
240
+
241
+ @staticmethod
242
+ def load_state_file(path, device: Optional[str] = None):
243
+ """Load a state dict file by the input path.
244
+
245
+ Args:
246
+ path: The local dir to load the state file.
247
+
248
+ Returns:
249
+ The state dict.
250
+ """
251
+ if device is None:
252
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
253
+ if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
254
+ filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
255
+ from safetensors.torch import load_file as safe_load_file
256
+ return safe_load_file(filename, device=device)
257
+ elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
258
+ filename = os.path.join(path, WEIGHTS_NAME)
259
+ return torch.load(filename, map_location=device)
260
+ return None
261
+
262
+ def create_optimizer_param_groups(self, **defaults):
263
+ all_param_names = set()
264
+ param_groups = []
265
+ for output in self.adapters.values():
266
+ if output.optimizer_group_callback:
267
+ param_names, param_group = output.optimizer_group_callback(self.model, **defaults)
268
+ if param_names and all_param_names & param_names:
269
+ raise ValueError('Cannot set one parameter to different param groups')
270
+ if param_names and param_group:
271
+ all_param_names.update(param_names)
272
+ param_groups.extend(param_group)
273
+
274
+ decay_parameters = Trainer.get_decay_parameter_names(None, self.model)
275
+ param_groups.extend([
276
+ {
277
+ 'params': [
278
+ p for n, p in self.model.named_parameters()
279
+ if (n in decay_parameters and n not in all_param_names and p.requires_grad)
280
+ ],
281
+ 'weight_decay':
282
+ defaults['weight_decay'],
283
+ },
284
+ {
285
+ 'params': [
286
+ p for n, p in self.model.named_parameters()
287
+ if (n not in decay_parameters and n not in all_param_names and p.requires_grad)
288
+ ],
289
+ 'weight_decay':
290
+ 0.0,
291
+ },
292
+ ])
293
+
294
+ return param_groups
295
+
296
+ @classmethod
297
+ def from_pretrained(cls,
298
+ model: Union[nn.Module, 'SwiftModel'],
299
+ model_id: str = None,
300
+ adapter_name: Union[str, List[str], Dict[str, str]] = None,
301
+ inference_mode: bool = True,
302
+ revision: str = None,
303
+ **kwargs):
304
+ """Load a set of tuners and corresponding weights by a model_id.
305
+
306
+ Args:
307
+ model (`Union[torch.nn.Module, 'SwiftModel']`): The model to be tuned,
308
+ if the model is already a `SwiftModel` it will be un-wrapped and re-wrapped..
309
+ model_id (`str`): The model_id or a local model dir of tuners to use to tune the model.
310
+ adapter_name (`Union[str, List[str], Dict[str, str]]`): The adapter_names saved in the model repo to load.
311
+ Default `None`, means load all tuners saved in the model_id
312
+ inference_mode (`bool`): Use in the inference mode or not.
313
+ revision (`str`): The model revision to use.
314
+ **kwargs:
315
+ extra_state_keys (`List[str]`, `optional`) A list of regex to match the extra state keys to be saved.
316
+ Other parameters will be passed to the device_map.
317
+ Returns:
318
+ The `SwiftModel` instance.
319
+ """
320
+ adapters = {}
321
+ model_dir = model_id
322
+ if not os.path.exists(model_dir):
323
+ model_dir = snapshot_download(model_dir, revision=revision)
324
+ if os.path.isfile(model_dir):
325
+ raise ValueError(f'Please pass in a local dir or a model id, not a local file: {model_dir}')
326
+ extra_state_keys = kwargs.pop('extra_state_keys', None)
327
+ if extra_state_keys is None and os.path.isfile(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME)):
328
+ with open(os.path.join(model_dir, cls.EXTRA_STATE_DIR, CONFIG_NAME), 'r', encoding='utf-8') as file:
329
+ _json = json.load(file)
330
+ extra_state_keys = _json.get('extra_state_keys')
331
+ if adapter_name is None:
332
+ adapter_name = [
333
+ sub_dir for sub_dir in os.listdir(model_dir)
334
+ if os.path.isfile(os.path.join(model_dir, sub_dir, CONFIG_NAME)) and sub_dir != cls.EXTRA_STATE_DIR
335
+ ]
336
+ for _name in adapter_name if isinstance(adapter_name,
337
+ list) else [adapter_name] \
338
+ if isinstance(adapter_name, str) else adapter_name.keys():
339
+ sub_folder = os.path.join(model_dir, _name)
340
+ config_file = os.path.join(sub_folder, CONFIG_NAME)
341
+
342
+ if not os.path.isfile(config_file):
343
+ logger.warning(f'{_name} is not a valid tuner')
344
+ continue
345
+
346
+ with open(config_file, 'r', encoding='utf-8') as file:
347
+ json_object = json.load(file)
348
+
349
+ if SWIFT_TYPE_KEY not in json_object:
350
+ raise ValueError('Mixed using with peft is not allowed now.')
351
+ else:
352
+ key = _name if not isinstance(adapter_name, dict) else adapter_name[_name]
353
+ adapters[key] = SwiftConfig.from_pretrained(sub_folder)
354
+
355
+ self = SwiftModel(model, adapters, extra_state_keys, inference_mode, **kwargs)
356
+ for _name in adapter_name if isinstance(adapter_name,
357
+ list) else [adapter_name] \
358
+ if isinstance(adapter_name, str) else adapter_name.keys():
359
+ _adapter = _name if not isinstance(adapter_name, dict) else adapter_name[_name]
360
+ output: SwiftOutput = self.adapters[_adapter]
361
+ sub_folder = os.path.join(model_dir, _name)
362
+ if output.load_callback:
363
+ output.load_callback(self, sub_folder, _adapter)
364
+ continue
365
+ state_dict = cls.load_state_file(sub_folder)
366
+ if state_dict is not None:
367
+ if isinstance(adapter_name, dict):
368
+ # TODO this logic is fragile! replace `_name` may cause other parts replaced
369
+ state_dict = {key.replace(_name, adapter_name[_name]): value for key, value in state_dict.items()}
370
+ self.load_state_dict(state_dict, adapter_name=_adapter)
371
+ state_dict = cls.load_state_file(os.path.join(model_dir, self.EXTRA_STATE_DIR))
372
+ if state_dict is not None:
373
+ self.load_state_dict(state_dict)
374
+ return self
375
+
376
+ @classmethod
377
+ def _prepare_model(
378
+ cls,
379
+ model: nn.Module,
380
+ config: SwiftConfig,
381
+ adapter_name: str,
382
+ ):
383
+ assert (hasattr(config, SWIFT_TYPE_KEY))
384
+ from .mapping import SWIFT_MAPPING
385
+
386
+ adapter_cls = SWIFT_MAPPING[config.swift_type][1]
387
+ if adapter_cls.has_additional_modules() and not getattr(model, 'model_frozen', False):
388
+ for _, p in model.named_parameters():
389
+ p.requires_grad = False
390
+ model.model_frozen = True
391
+ config.has_additional_modules = adapter_cls.has_additional_modules()
392
+ return adapter_cls.prepare_model(model, config, adapter_name)
393
+
394
+ def create_or_update_model_card(self, output_dir: str):
395
+ """
396
+ Updates or create the model card.
397
+ """
398
+ if not os.path.exists(os.path.join(output_dir, 'README.md')):
399
+ lines = []
400
+ else:
401
+ with open(os.path.join(output_dir, 'README.md'), 'r', encoding='utf-8') as f:
402
+ lines = f.readlines()
403
+
404
+ quantization_config = None
405
+ if hasattr(self.base_model, 'config') and hasattr(self.base_model.config, 'quantization_config'):
406
+ if hasattr(self.base_model.config.quantization_config, 'to_dict'):
407
+ quantization_config = self.base_model.config.quantization_config.to_dict()
408
+ training_config_text = ''
409
+ # Adds quantization information if it was used
410
+ if quantization_config is not None:
411
+ training_config_text += '\nThe following `bitsandbytes` quantization config was used during training:\n'
412
+ training_config_text += '\n'.join([f'- {name}: {value}' for name, value in quantization_config.items()])
413
+ training_config_text += '\n'
414
+
415
+ training_procedure_heading = '## Training procedure\n'
416
+ if training_procedure_heading in lines:
417
+ lines.insert(lines.index(training_procedure_heading) + 2, training_config_text)
418
+ else:
419
+ lines.append(f'{training_procedure_heading}\n{training_config_text}')
420
+
421
+ framework_block_heading = '### Framework versions\n'
422
+ from swift.version import __version__
423
+ if framework_block_heading in lines:
424
+ lines.insert(lines.index(framework_block_heading) + 2, f'- SWIFT {__version__}\n')
425
+ else:
426
+ lines.append(f'{framework_block_heading}\n\n- SWIFT {__version__}\n')
427
+
428
+ base_model_heading = '### Base model information\n'
429
+ lines.append(f'{base_model_heading}\n\n- BaseModel Class {self.base_model.__class__.__name__}\n')
430
+
431
+ # write the lines back to README.md
432
+ with open(os.path.join(output_dir, 'README.md'), 'w', encoding='utf-8') as f:
433
+ f.writelines(lines)
434
+
435
+ def add_weighted_adapter(
436
+ self,
437
+ adapters,
438
+ weights,
439
+ adapter_name,
440
+ combination_type='svd',
441
+ svd_rank=None,
442
+ svd_clamp=None,
443
+ svd_full_matrices=True,
444
+ svd_driver=None,
445
+ density=None,
446
+ majority_sign_method: Literal['total', 'frequency'] = 'total',
447
+ ):
448
+ """
449
+ This method adds a new adapter by merging the given adapters with the given weights.
450
+
451
+ When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to
452
+ the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM
453
+ errors.
454
+
455
+ Args:
456
+ adapters (`list`):
457
+ List of adapter names to be merged.
458
+ weights (`list`):
459
+ List of weights for each adapter.
460
+ adapter_name (`str`):
461
+ Name of the new adapter.
462
+ combination_type (`str`):
463
+ The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
464
+ `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
465
+ combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
466
+ mixed adapter may be too big and result in OOM errors).
467
+ svd_rank (`int`, *optional*):
468
+ Rank of output adapter for svd. If None provided, will use max rank of merging adapters.
469
+ svd_clamp (`float`, *optional*):
470
+ A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform
471
+ clamping. Defaults to None.
472
+ svd_full_matrices (`bool`, *optional*):
473
+ Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned
474
+ tensors U and Vh. Defaults to True.
475
+ svd_driver (`str`, *optional*):
476
+ Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be
477
+ one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd`
478
+ documentation. Defaults to None.
479
+ density (`float`, *optional*):
480
+ Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used
481
+ with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`,
482
+ `magnintude_prune`, `magnitude_prune_svd`]
483
+ majority_sign_method (`str`):
484
+ The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values.
485
+ Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`]
486
+ """
487
+ from swift.tuners.lora import LoraModel
488
+ lora_model = LoraModel(self.model, None, '')
489
+ lora_model.peft_config = {key: value.config for key, value in self.adapters.items()}
490
+ from peft.tuners.lora import LoraLayer
491
+ lora_model.targeted_module_names = [
492
+ key for key, value in self.model.named_modules() if isinstance(value, LoraLayer)
493
+ ]
494
+ lora_model.active_adapter = self.active_adapters
495
+ lora_model.add_weighted_adapter(
496
+ adapters=adapters,
497
+ weights=weights,
498
+ adapter_name=adapter_name,
499
+ combination_type=combination_type,
500
+ svd_rank=svd_rank,
501
+ svd_clamp=svd_clamp,
502
+ svd_full_matrices=svd_full_matrices,
503
+ svd_driver=svd_driver,
504
+ density=density,
505
+ majority_sign_method=majority_sign_method,
506
+ )
507
+
508
+ def state_dict_callback(state_dict, adapter_name, cfg):
509
+ from swift.tuners.lora_layers import lora_state_dict
510
+ return lora_state_dict(state_dict, adapter_name, cfg.bias)
511
+
512
+ def mark_trainable_callback(model, cfg):
513
+ from swift.tuners.lora_layers import mark_lora_as_trainable
514
+ mark_lora_as_trainable(model, adapter_name, cfg.bias)
515
+
516
+ cfg = lora_model.peft_config[adapter_name]
517
+ cfg.has_additional_modules = True
518
+ self.adapters[adapter_name] = SwiftOutput(
519
+ config=cfg,
520
+ state_dict_callback=partial(state_dict_callback, cfg=cfg),
521
+ mark_trainable_callback=partial(mark_trainable_callback, cfg=cfg),
522
+ optimizer_group_callback=None,
523
+ )
524
+
525
+ self.set_active_adapters(adapter_name)
526
+
527
    def save_pretrained(self,
                        save_directory: str,
                        safe_serialization: bool = False,
                        adapter_name: Union[str, List[str]] = None,
                        **kwargs):
        """Save the adapters to a local directory.

        Full-parameter runs (no additional modules) save the whole base model;
        otherwise each selected adapter is written to its own sub-directory,
        optionally converted to peft format, and any extra (non-adapter) state
        is stored under ``EXTRA_STATE_DIR``.

        Args:
            save_directory (`str`): The directory to use.
            safe_serialization (`bool`): Use safe tensors to save the weights, default False.
            adapter_name(`Union[str, List[str]]`): The adapters to be saved, default is `None` to save all.

        Keyword Args:
            peft_format (`bool`): Convert LoRA adapters to peft layout when possible. Default False.
            state_dict (`dict`): Optional externally-provided state dict to save from.
        """
        # Pop so 'peft_format' is not forwarded to self.state_dict() below.
        peft_format = kwargs.pop('peft_format', False)
        if os.path.isfile(save_directory):
            raise ValueError(f'Provided path ({save_directory}) should be a directory, not a file')
        os.makedirs(save_directory, exist_ok=True)
        if not self.has_additional_modules:
            # Full-parameter training: persist the entire base model instead of adapter deltas.
            if hasattr(self.base_model, 'save_pretrained'):
                self.base_model.save_pretrained(save_directory, safe_serialization=safe_serialization)
            else:
                self._save_state_dict(self.base_model.state_dict(), save_directory, safe_serialization)
            self.create_or_update_model_card(save_directory)
        else:
            self.create_or_update_model_card(save_directory)

        # Normalize the filter to a list; None means "save every adapter".
        adapter_names = adapter_name if isinstance(adapter_name, list) or adapter_name is None else [adapter_name]

        state_dict_kwargs = {}
        state_dict = kwargs.get('state_dict')
        if state_dict is not None:
            state_dict_kwargs['state_dict'] = kwargs['state_dict']
        # NOTE: the loop variable shadows the `adapter_name` parameter; filtering
        # relies on the normalized `adapter_names` captured above.
        for adapter_name, output in self.adapters.items():
            if adapter_names is not None and adapter_name not in adapter_names:
                continue

            # peft layout is only possible for plain LoRA configs without swift-specific extras.
            save_to_peft = peft_format and output.config.swift_type == SwiftTuners.LORA
            save_to_peft = save_to_peft and output.config.can_be_saved_to_peft()
            if peft_format and not save_to_peft:
                logger.error('You are using additional lora parameters, which is not compatible with peft,'
                             'which is unable to save to peft format.')
            # The 'default' adapter in peft format is stored at the top level, not a sub-dir.
            output_dir = os.path.join(save_directory,
                                      adapter_name) if adapter_name != 'default' or not save_to_peft else save_directory

            if save_to_peft:
                config = output.config.to_peft_config()
                config.save_pretrained(output_dir)
            else:
                output.config.save_pretrained(output_dir)

            # Tuners with a custom save hook handle their own weights.
            if output.save_callback:
                output.save_callback(self, output_dir, adapter_name)
                continue

            # save only the trainable weights
            output_state_dict = self.state_dict(
                adapter_name=adapter_name, save_extra_states=False, peft_format=save_to_peft, **state_dict_kwargs)
            os.makedirs(output_dir, exist_ok=True)
            if output_state_dict and output.config.has_additional_modules:
                self._save_state_dict(output_state_dict, output_dir, safe_serialization)

        # Extra (non-adapter) trainable state, e.g. unfrozen base-model parameters.
        output_state_dict = self.state_dict(save_extra_states=True, save_adapter=False, **state_dict_kwargs)
        if len(output_state_dict) > 0:
            if self.has_additional_modules:
                os.makedirs(os.path.join(save_directory, self.EXTRA_STATE_DIR), exist_ok=True)
                self._save_state_dict(output_state_dict, os.path.join(save_directory, self.EXTRA_STATE_DIR),
                                      safe_serialization)
                with open(
                        os.path.join(save_directory, self.EXTRA_STATE_DIR, CONFIG_NAME), 'w', encoding='utf-8') as file:
                    json.dump({'extra_state_keys': self.extra_state_keys}, file)
            else:
                logger.error('Full parameter training, save_extra_states will be ignored')

        # Placeholder required by the ModelScope hub layout.
        if not os.path.exists(os.path.join(save_directory, 'configuration.json')):
            with open(os.path.join(save_directory, 'configuration.json'), 'w', encoding='utf-8') as f:
                f.write('{}')
603
+ @staticmethod
604
+ def _save_state_dict(output_state_dict, save_directory, safe_serialization):
605
+ if safe_serialization:
606
+ from safetensors.torch import save_file as safe_save_file
607
+ safe_save_file(
608
+ output_state_dict, os.path.join(save_directory, SAFETENSORS_WEIGHTS_NAME), metadata={'format': 'pt'})
609
+ else:
610
+ torch.save(output_state_dict, os.path.join(save_directory, WEIGHTS_NAME))
611
+
612
+ @contextmanager
613
+ def disable_adapter(self):
614
+ try:
615
+ self.set_active_adapters(adapter_names=[])
616
+ yield
617
+ finally:
618
+ self.set_active_adapters(adapter_names=self.adapters.keys())
619
+
620
+ def set_active_adapters(self, adapter_names: Union[List[str], str], offload: str = None):
621
+ """Set activated adapters
622
+
623
+ Args:
624
+ adapter_names(`Union[List[str], str]`): The adapters needed to be activated
625
+ offload(`str`): Whether to offload the deactivated ones to `cpu` or `meta` device
626
+ """
627
+ if not adapter_names:
628
+ adapter_names = []
629
+
630
+ if isinstance(adapter_names, str):
631
+ adapter_names = [adapter_names]
632
+
633
+ adapter_names = set(adapter_names)
634
+ for adapter_name in (adapter_names & set(self.adapters.keys())):
635
+ self.activate_adapter(adapter_name)
636
+
637
+ for adapter_name in (set(self.adapters.keys()) - adapter_names):
638
+ self.deactivate_adapter(adapter_name, offload)
639
+
640
+ self.active_adapters = (adapter_names & set(self.adapters.keys()))
641
+
642
+ def activate_adapter(self, adapter_name: str):
643
+ """Activate one adapter
644
+
645
+ Args:
646
+ adapter_name(`str`): The adapter needed to be activated
647
+ """
648
+ if adapter_name not in self.adapters:
649
+ logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}')
650
+ return
651
+
652
+ from .mapping import SWIFT_MAPPING
653
+ SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
654
+ .activate_adapter(self.base_model, adapter_name, True)
655
+ self.active_adapters = self.active_adapters | {adapter_name}
656
+
657
+ def deactivate_adapter(self, adapter_name: str, offload: str = None):
658
+ """Deactivate one adapter
659
+
660
+ Args:
661
+ adapter_name(`str`): The adapter needed to be activated
662
+ offload(`str`): Whether to offload to `cpu` or `meta` device
663
+ """
664
+ if adapter_name not in self.adapters:
665
+ logger.warning(f'{adapter_name} not in adapters: {self.adapters.keys()}')
666
+ return
667
+
668
+ from .mapping import SWIFT_MAPPING
669
+ SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
670
+ .activate_adapter(self.base_model, adapter_name, False, offload=offload)
671
+ self.active_adapters = self.active_adapters - {adapter_name}
672
+
673
+ def get_trainable_parameters(self):
674
+ """
675
+ Get the content of trainable parameters in the model.
676
+ """
677
+ trainable_params = 0
678
+ all_param = 0
679
+ for _, param in self.base_model.named_parameters():
680
+ num_params = param.numel()
681
+ # if using DS Zero 3 and the weights are initialized empty
682
+ if num_params == 0 and hasattr(param, 'ds_numel'):
683
+ num_params = param.ds_numel
684
+
685
+ all_param += num_params
686
+ if param.requires_grad:
687
+ trainable_params += num_params
688
+ return f'trainable params: {trainable_params:,d} || all params: {all_param:,d} ' \
689
+ f'|| trainable%: {100 * trainable_params / all_param:.4f}' \
690
+ '|| cuda memory: ' \
691
+ f'{sum([torch.cuda.memory_allocated(i) for i in range(get_device_count())])/1024/1024/1024:.2f}' \
692
+ 'GiB.'
693
+
694
+
695
class Swift:
    """The Wrapper to use both Peft and Swift tuners."""

    @staticmethod
    def prepare_model(model: Union[nn.Module, SwiftModel], config: Union[SwiftConfig, PeftConfig,
                                                                         Dict[str, SwiftConfig]], **kwargs):
        """Prepare a model by the input config.

        Args:
            model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned.
            config(`Union[SwiftConfig, PeftConfig, Dict[str, SwiftConfig]]`): The config or config dict, can be either
                SwiftConfigs or PeftConfigs
            **kwargs:
                Extra kwargs needed by SwiftModel or PeftModel.
        Returns:
            The model wrapped by SwiftModel or PeftModel.
        """

        # A plain dict is treated as {adapter_name: SwiftConfig}; anything else
        # is assumed to be a peft config and routed to get_peft_model.
        if isinstance(config, (SwiftConfig, dict)):
            return SwiftModel(model, config, **kwargs)
        else:
            return get_peft_model(model, config, **kwargs)

    @staticmethod
    def merge_and_unload(model: Union[PeftModel, SwiftModel], **kwargs):
        """Merge tuners into the base model and unload them.

        Args:
            model(`Union[PeftModel, SwiftModel]`): The model instance with tuners
            kwargs:
                adapter_name(`Union[str, List[str]]`): The adapter_name to unload, only supported in swift tuners.

        """
        from peft import PeftModel as _PeftModel
        if isinstance(model, _PeftModel):
            model.merge_and_unload()
        elif isinstance(model, SwiftModel):
            from swift import LoRAConfig
            from swift.tuners import LoRA
            adapter_name = kwargs.get('adapter_name', None)
            if isinstance(adapter_name, str):
                adapter_name = [adapter_name]
            # Only LoRA adapters support unpatching; others are left untouched.
            for adapter, output in model.adapters.items():
                if isinstance(output.config, LoRAConfig) and (adapter_name is None or adapter in adapter_name):
                    LoRA.unpatch_lora(model, output.config, adapter)

    @staticmethod
    @contextmanager
    def grpo_context(model: Union[SwiftModel, torch.nn.Module], processor):
        # Save the model and temporarily modify model.model_dir.
        # Only LLaMAPro-tuned SwiftModels need special handling: their base
        # model is dumped to a temp dir so GRPO workers can reload it, and
        # model_dir is restored (and the temp dir removed) on exit.
        if not isinstance(model, SwiftModel):
            yield
            return
        else:
            assert len(model.adapters) == 1
            adapter = list(model.adapters.values())[0]
            if adapter.config.swift_type == SwiftTuners.LLAMAPRO:
                from modelscope.hub.utils.utils import get_cache_dir
                temp_dir = tempfile.mkdtemp(dir=get_cache_dir())
                model_dir = model.model_dir
                from transformers.integrations import is_deepspeed_zero3_enabled
                if is_deepspeed_zero3_enabled():
                    raise ValueError('DeepSpeed ZeRO3 not supported for LLaMAPro&GRPO currently.')
                model.base_model.save_pretrained(temp_dir)
                processor.save_pretrained(temp_dir)
                model.model_dir = temp_dir
            yield
            # NOTE(review): cleanup is not inside try/finally, so an exception
            # in the with-body leaves model_dir pointing at the temp dir.
            if adapter.config.swift_type == SwiftTuners.LLAMAPRO:
                model.model_dir = model_dir
                shutil.rmtree(temp_dir)

    @staticmethod
    def merge(model: Union[PeftModel, SwiftModel], **kwargs):
        """Merge tuners into the base model, will not unload them.

        Args:
            model(`Union[PeftModel, SwiftModel]`): The model instance with tuners
        """
        from .lora_layers import LoraLayer, LoRALayer
        for sub_module in model.modules():
            if isinstance(sub_module, (LoraLayer, LoRALayer)):
                sub_module.merge(**kwargs)

    @staticmethod
    def unmerge(model: Union[PeftModel, SwiftModel], **kwargs):
        """Unmerge tuners from the base model

        Args:
            model(`Union[PeftModel, SwiftModel]`): The model instance with tuners
        """
        from .lora_layers import LoraLayer, LoRALayer
        for sub_module in model.modules():
            if isinstance(sub_module, (LoraLayer, LoRALayer)):
                sub_module.unmerge(**kwargs)

    @staticmethod
    def save_to_peft_format(ckpt_dir: str, output_dir: str) -> None:
        """Save swift format to peft format

        Only plain LoRA adapters without swift-specific extras can be
        converted; anything else raises.

        Args:
            ckpt_dir(`str`): Original swift output dir
            output_dir(`str`): Converted peft format dir
        """
        assert ckpt_dir and output_dir, 'Please pass in valid ckpt_dir and output_dir.'
        assert os.path.exists(ckpt_dir), f'ckpt_dir: {ckpt_dir} must exists in local disk.'
        if os.path.exists(os.path.join(ckpt_dir, SwiftModel.EXTRA_STATE_DIR)):
            raise AssertionError('Cannot transfer to peft format, because you are additional state dicts.')

        # Every sub-dir that holds a config file is treated as an adapter.
        adapter_names = [
            sub_dir for sub_dir in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, sub_dir, CONFIG_NAME))
        ]

        def has_custom_content(_json):
            # Non-LoRA adapters, or LoRA configs with swift-only fields,
            # cannot be represented in peft format.
            if _json.get('swift_type', _json.get('peft_type')) != SwiftTuners.LORA:
                logger.warn('Only LoRA can be converted to peft format')
                return True

            from swift import LoRAConfig
            return not LoRAConfig(**_json).can_be_saved_to_peft()

        for adapter in adapter_names:
            with open(os.path.join(ckpt_dir, adapter, CONFIG_NAME), encoding='utf-8') as f:
                _json = json.load(f)
                if has_custom_content(_json):
                    raise AssertionError('Cannot transfer to peft format, '
                                         'because you have special parameters or adapter types.')

        os.makedirs(output_dir, exist_ok=True)
        if ckpt_dir != output_dir:
            shutil.copytree(ckpt_dir, output_dir, dirs_exist_ok=True)

        for adapter in adapter_names:
            safe_serialization = os.path.isfile(os.path.join(output_dir, adapter, SAFETENSORS_WEIGHTS_NAME))
            state_dict = SwiftModel.load_state_file(os.path.join(output_dir, adapter))
            new_state_dict = {}
            # peft expects keys rooted at 'base_model.model.' and without the
            # per-adapter name segment inside lora parameter names.
            for key, value in state_dict.items():
                if not key.startswith('base_model.model.'):
                    key = 'base_model.model.' + key
                key = key.replace(f'lora_A.{adapter}.', 'lora_A.')
                key = key.replace(f'lora_B.{adapter}.', 'lora_B.')
                key = key.replace(f'lora_embedding_A.{adapter}.', 'lora_embedding_A.')
                key = key.replace(f'lora_embedding_B.{adapter}.', 'lora_embedding_B.')
                key = key.replace(f'lora_magnitude_vector.{adapter}', 'lora_magnitude_vector')
                new_state_dict[key] = value
            state_dict = new_state_dict
            SwiftModel._save_state_dict(state_dict, os.path.join(output_dir, adapter), safe_serialization)
            from swift import LoRAConfig
            with open(os.path.join(output_dir, adapter, CONFIG_NAME), encoding='utf-8') as f:
                _json = json.load(f)
            peft_config = LoRAConfig(**_json).to_peft_config()
            peft_config.save_pretrained(os.path.join(output_dir, adapter))

        # peft keeps the 'default' adapter at the top level rather than a sub-dir.
        if 'default' in adapter_names:
            shutil.move(os.path.join(output_dir, 'default', CONFIG_NAME), os.path.join(output_dir, CONFIG_NAME))
            state_dict = SwiftModel.load_state_file(os.path.join(output_dir, 'default'))
            safe_serialization = os.path.isfile(os.path.join(output_dir, 'default', SAFETENSORS_WEIGHTS_NAME))
            SwiftModel._save_state_dict(state_dict, output_dir, safe_serialization)
            shutil.rmtree(os.path.join(output_dir, 'default'))

    @staticmethod
    def from_pretrained(model: Union[nn.Module, SwiftModel, PeftModel],
                        model_id: str = None,
                        adapter_name: Union[str, List[str], Dict[str, str]] = None,
                        revision: str = None,
                        **kwargs):
        """Prepare a model by a model_id in the ModelScope hub or a local dir.

        Args:
            model(`Union[nn.Module, 'SwiftModel']`): The model to be tuned.
            model_id(`str`): The model id of the modelhub or a local dir containing the configs/weights.
            adapter_name(`str`, `optional`): The adapter_name to use.
            revision(`str`, `optional`): The model revision if the model_id is a model id of the modelhub.
            **kwargs:
                Extra kwargs needed by ``SwiftModel.from_pretrained`` or ``PeftModel.from_pretrained``.
        Returns:
            The model wrapped by SwiftModel or PeftModel.
        """
        # NOTE(review): model_id=None would make os.path.exists raise — the
        # default seems to assume callers always pass a value; confirm.
        if not os.path.exists(model_id):
            model_id = snapshot_download(model_id, revision=revision)
        # A checkpoint is peft-format when its config lacks the swift type key.
        is_peft_model = False
        if os.path.exists(os.path.join(model_id, CONFIG_NAME)):
            with open(os.path.join(model_id, CONFIG_NAME), 'r', encoding='utf-8') as f:
                _json = json.load(f)
            is_peft_model = SWIFT_TYPE_KEY not in _json

        # Pick one representative adapter name to probe the sub-dir config.
        _name = adapter_name if isinstance(
            adapter_name, str) or adapter_name is None else adapter_name[0] \
            if isinstance(adapter_name, list) else list(adapter_name.keys())[0]
        _name = _name or ''
        if os.path.exists(os.path.join(model_id, _name, CONFIG_NAME)):
            with open(os.path.join(model_id, _name, CONFIG_NAME), 'r', encoding='utf-8') as f:
                _json = json.load(f)
            is_peft_model = SWIFT_TYPE_KEY not in _json and 'extra_state_keys' not in _json
        if is_peft_model:

            def load_peft_model(_model, _adapter_name, _new_name=None):
                # First adapter wraps the model in a PeftModel; subsequent
                # adapters are attached via load_adapter.
                if not _new_name:
                    _new_name = _adapter_name
                import peft
                if not isinstance(_model, peft.PeftModel):
                    return PeftModel.from_pretrained(
                        _model,
                        os.path.join(model_id, _adapter_name) if _adapter_name != 'default'
                        and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id,
                        revision=revision,
                        adapter_name=_new_name,
                        **kwargs)
                else:
                    _model.load_adapter(
                        os.path.join(model_id, _adapter_name) if _adapter_name != 'default'
                        and os.path.exists(os.path.join(model_id, _adapter_name)) else model_id, _new_name)
                    return _model

            if not adapter_name:
                # No filter: load 'default' plus every adapter sub-dir found.
                peft_model = load_peft_model(model, 'default')
                for _dir in os.listdir(model_id):
                    if os.path.isdir(os.path.join(model_id, _dir)) and \
                            os.path.exists(os.path.join(model_id, _dir, CONFIG_NAME)):
                        peft_model = load_peft_model(peft_model, _dir)
            elif isinstance(adapter_name, str):
                return load_peft_model(model, adapter_name)
            elif isinstance(adapter_name, list):
                peft_model = model
                for name in adapter_name:
                    peft_model = load_peft_model(peft_model, name)
            else:
                # Dict form maps stored adapter name -> new adapter name.
                peft_model = model
                for key, value in adapter_name.items():
                    peft_model = load_peft_model(peft_model, key, value)
            return peft_model
        else:
            return SwiftModel.from_pretrained(model, model_id, revision=revision, adapter_name=adapter_name, **kwargs)
ms-swift/swift/tuners/longlora/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes). View file
 
ms-swift/swift/tuners/longlora/llama.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Part of the implementation is borrowed from dvlab-research/LongLoRA.
3
+
4
+ import math
5
+ from types import MethodType
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from transformers import Cache, StaticCache
12
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
13
+
14
+ from swift.utils import get_logger
15
+
16
+ logger = get_logger()
17
+
18
+
19
+ def _preprocess_qkv_fa2(attn_module, query_states, key_states, value_states, attention_mask):
20
+ if attn_module.training:
21
+ bsz, q_len = query_states.shape[:2]
22
+ group_size = int(q_len * attn_module.config.group_size_ratio)
23
+ if q_len % group_size != 0:
24
+ raise ValueError(f'The sequence length {q_len} should'
25
+ f'be able to be split by the group_ratio {attn_module.config.group_size_ratio}')
26
+
27
+ num_group = q_len // group_size
28
+
29
+ def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
30
+ qkv[:, :, num_heads // 2:] = qkv[:, :, num_heads // 2:].roll(-group_size // 2, dims=1)
31
+ qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim)
32
+ return qkv
33
+
34
+ query_states = shift(query_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim)
35
+ key_states = shift(key_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim)
36
+ value_states = shift(value_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim)
37
+ if attention_mask is not None:
38
+ attention_mask = attention_mask[:, :group_size].repeat(num_group, 1)
39
+
40
+ return query_states, key_states, value_states, attention_mask
41
+
42
+
43
+ def _preprocess_qkv(attn_module, query_states, key_states, value_states, attention_mask):
44
+ if attn_module.training:
45
+ bsz, _, q_len = query_states.shape[:3]
46
+ group_size = int(q_len * attn_module.config.group_size_ratio)
47
+ if q_len % group_size != 0:
48
+ raise ValueError(f'The sequence length {q_len} should'
49
+ f'be able to be split by the group_ratio {attn_module.config.group_size_ratio}')
50
+
51
+ num_group = q_len // group_size
52
+
53
+ def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
54
+ qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
55
+ qkv = qkv.transpose(1, 2)
56
+ qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim)
57
+ return qkv.transpose(1, 2)
58
+
59
+ query_states = shift(query_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim)
60
+ key_states = shift(key_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim)
61
+ value_states = shift(value_states, bsz, q_len, group_size, attn_module.num_heads, attn_module.head_dim)
62
+ if attention_mask is not None:
63
+ attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
64
+
65
+ return query_states, key_states, value_states, attention_mask
66
+
67
+
68
def _postprocess_qkv(attn_module, attn_output, q_len):
    # Invert the grouping of `_preprocess_qkv` for the attention output:
    # merge the per-group batch entries back into full-length sequences and
    # roll the upper half of the heads *forward* by half a group to undo the
    # training-time shift. Grouping is only applied while training.
    if attn_module.training:
        group_size = int(q_len * attn_module.config.group_size_ratio)
        # (bsz*num_group, heads, group, dim) -> (bsz, q_len, heads, dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(-1, q_len, attn_module.num_heads, attn_module.head_dim)
        # shift back
        attn_output_clone = attn_output.clone()
        attn_output_clone[:, :, attn_module.num_heads // 2:] = attn_output[:, :, attn_module.num_heads // 2:].roll(
            group_size // 2, dims=1)
        attn_output = attn_output_clone
    # NOTE(review): in eval mode the input reaches this transpose untouched,
    # so the returned layout differs between train and eval — confirm the
    # callers (eager_forward) expect this.
    return attn_output.transpose(1, 2)
79
+
80
+
81
+ def _postprocess_qkv_fa2(attn_module, attn_output, q_len):
82
+ if attn_module.training:
83
+ group_size = int(q_len * attn_module.config.group_size_ratio)
84
+ attn_output = attn_output.reshape(-1, q_len, attn_module.num_heads, attn_module.head_dim)
85
+ attn_output_clone = attn_output.clone()
86
+ # shift back
87
+ attn_output_clone[:, :, attn_module.num_heads // 2:] = attn_output[:, :, attn_module.num_heads // 2:].roll(
88
+ group_size // 2, dims=1)
89
+ attn_output = attn_output_clone
90
+ return attn_output
91
+
92
+
93
# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa
def eager_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """LongLoRA replacement for LlamaAttention's eager forward.

    Identical to the upstream implementation except that Q/K/V are routed
    through `_preprocess_qkv` / `_postprocess_qkv`, which (during training)
    apply the shifted-sparse-attention grouping.
    """
    bsz, q_len, _ = hidden_states.size()

    # Tensor-parallel pretraining slices the projection weights column-wise.
    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    # (bsz, q_len, hidden) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        logger.warning_once(
            'The attention layers in this model are transitioning from computing the RoPE embeddings internally '
            'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed '
            '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be '
            'removed and `position_embeddings` will be mandatory.')
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Expand grouped-query K/V to one head per query head.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # patch position rolling
    query_states, key_states, value_states, causal_mask = _preprocess_qkv(self, query_states, key_states, value_states,
                                                                          attention_mask)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        # NOTE(review): this slices the ORIGINAL attention_mask and overwrites
        # the `causal_mask` returned by `_preprocess_qkv`; during training the
        # grouped batch dimension (bsz * num_group) may not broadcast against
        # it — confirm this is the intended behavior.
        causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
                         f' {attn_output.size()}')

    # patch position unrolling
    attn_output = _postprocess_qkv(self, attn_output, q_len)

    attn_output = attn_output.transpose(1, 2).contiguous()

    attn_output = attn_output.reshape(bsz, q_len, -1)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
188
+
189
+
190
# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa
def fa2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """LongLoRA replacement for LlamaAttention's flash-attention-2 forward.

    Identical to the upstream implementation except that Q/K/V are routed
    through `_preprocess_qkv_fa2` / `_postprocess_qkv_fa2`, which (during
    training) apply the shifted-sparse-attention grouping.
    """
    if isinstance(past_key_value, StaticCache):
        raise ValueError(
            '`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` '
            'make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers'
        )

    # Flash attention cannot return attention weights.
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Flash attention requires the input to have the shape
    # batch_size x seq_length x head_dim x hidden_dim
    # therefore we just need to keep the original shape
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    if position_embeddings is None:
        logger.warning_once(
            'The attention layers in this model are transitioning from computing the RoPE embeddings internally '
            'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed '
            '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be '
            'removed and `position_embeddings` will be mandatory.')
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f'The input hidden states seems to be silently casted in float32, this might be related to'
            f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in'
            f' {target_dtype}.')

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # patch position rolling
    query_states, key_states, value_states, attention_mask = _preprocess_qkv_fa2(
        self, query_states, key_states, value_states, attention_mask)
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=getattr(self, 'sliding_window', None),
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        is_causal=self.is_causal,
    )

    # patch position unrolling
    attn_output = _postprocess_qkv_fa2(self, attn_output, q_len)

    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    # output_attentions is forced False above, so attn_weights is always None.
    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
300
+
301
+
302
# code borrowed from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py # noqa
def sdpa_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Replacement attention forward using torch SDPA, with the position
    rolling/unrolling hooks (`_preprocess_qkv` / `_postprocess_qkv`) spliced in
    around the attention call.

    Bound onto each layer's ``self_attn`` via ``MethodType`` (see
    ``replace_llama_attn``), so ``self`` is the attention module.

    Returns:
        Tuple of ``(attn_output, attn_weights, past_key_value)``;
        ``attn_weights`` is always ``None`` because SDPA does not expose the
        attention matrix.
    """
    # SDPA cannot return attention weights; delegate to the base implementation.
    # NOTE(review): zero-arg super() inside a function bound via MethodType has
    # no __class__ cell and would raise at runtime — confirm this branch is
    # never reached (output_attentions is expected to be False here).
    if output_attentions:
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

    bsz, q_len, _ = hidden_states.size()

    # Project hidden states into per-head query/key/value tensors.
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Reshape to [bsz, num_heads, q_len, head_dim] as SDPA expects.
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    # Compute RoPE cos/sin locally only on the legacy (pre-v4.46) path.
    if position_embeddings is None:
        logger.warning_once(
            'The attention layers in this model are transitioning from computing the RoPE embeddings internally '
            'through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed '
            '`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be '
            'removed and `position_embeddings` will be mandatory.')
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {'sin': sin, 'cos': cos, 'cache_position': cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Expand grouped-query K/V heads to match the number of query heads.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        # Trim the mask to the (possibly cache-extended) key length.
        causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]

    # Workaround for a CUDA SDPA issue with non-contiguous inputs when a
    # custom attn_mask is supplied.
    if query_states.device.type == 'cuda' and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # Use SDPA's built-in causal masking only when no explicit mask exists and
    # there is more than one query position.
    is_causal = True if causal_mask is None and q_len > 1 else False

    # patch position rolling
    query_states, key_states, value_states, causal_mask = _preprocess_qkv(self, query_states, key_states, value_states,
                                                                          causal_mask)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=is_causal,
    )

    # patch position unrolling
    attn_output = _postprocess_qkv(self, attn_output, q_len)

    # Back to [bsz, q_len, hidden_size] for the output projection.
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, -1)

    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value
389
+
390
+
391
def replace_llama_attn(model: nn.Module) -> None:
    """Patch the self-attention forward of every decoder layer in ``model``.

    The replacement forward (flash-attention-2 / eager / sdpa variant, chosen
    from ``model.config._attn_implementation``) adds position
    rolling/unrolling around the attention computation.

    Raises:
        AssertionError: if no ``nn.ModuleList`` (decoder layer list) is found.
    """
    # Locate the decoder layer list: the first nn.ModuleList in the module tree.
    layers = None
    for module in model.modules():
        if isinstance(module, torch.nn.ModuleList):
            layers = module
            break
    assert layers is not None

    # Resolve the patched forward once — the implementation choice and the GPU
    # capability check are loop-invariant (the original re-evaluated both for
    # every layer, repeating the warning per layer).
    attn_implementation = model.config._attn_implementation
    if attn_implementation == 'flash_attention_2':
        cuda_major, cuda_minor = torch.cuda.get_device_capability()
        if cuda_major < 8:
            # logger.warn is deprecated in favor of logger.warning.
            logger.warning(
                'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.' # noqa
                'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593')
        patched_forward = fa2_forward
    elif attn_implementation == 'eager':
        patched_forward = eager_forward
    elif attn_implementation == 'sdpa':
        patched_forward = sdpa_forward
    else:
        # Unknown implementation: leave the model untouched (original behavior).
        patched_forward = None

    if patched_forward is not None:
        for layer in layers:
            layer.self_attn.forward = MethodType(patched_forward, layer.self_attn)
ms-swift/swift/tuners/reft.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass
3
+ from types import MethodType
4
+ from typing import List, Literal, Optional
5
+
6
+ import json
7
+ import torch
8
+ from torch import nn
9
+
10
+ from swift.utils import get_logger, patch_getattr
11
+ from .utils import SwiftAdapter, SwiftConfig, SwiftOutput
12
+
13
+ logger = get_logger()
14
+
15
+
16
@dataclass
class ReftConfig(SwiftConfig):
    """Configuration for training a model with ReFT.

    Paper: https://arxiv.org/pdf/2404.03592

    Args:
        model_type(`Optional[str]`): The model_type to find down_proj/layers.
        layer_key(`Optional[str]`): Manually specify the layer key, for example `language_model.layers`.
        layers (`Optional[List[int]]`): The layer number to inject.
        r(`int`): The rank of Reft.
        intervention_type: One of the pyreft intervention class names,
            default `LoreftIntervention`.
        args (`Optional[str]`): Other reft_args in json-string format.
            Parsed to a dict in `__post_init__`.
    """

    model_type: Optional[str] = None
    layer_key: Optional[str] = None
    layers: Optional[List[int]] = None
    r: int = 4
    intervention_type: Literal['NoreftIntervention', 'LoreftIntervention', 'ConsreftIntervention',
                               'LobireftIntervention', 'DireftIntervention',
                               'NodireftIntervention'] = 'LoreftIntervention'
    args: Optional[str] = None

    def __post_init__(self):
        from .mapping import SwiftTuners
        self.swift_type = SwiftTuners.REFT
        # Decode the JSON string once so downstream code always receives a dict.
        self.args = json.loads(self.args) if self.args else {}
50
+
51
+
52
class Reft(SwiftAdapter):
    """SwiftAdapter implementation of ReFT (Representation Finetuning).

    Wraps the base model into a pyreft ``ReftModel``, injecting one
    intervention per selected decoder layer, and patches pyreft's
    forward/generate paths so intervention modules follow the input device
    and the model can be driven with plain ``input_ids``/``attention_mask``
    kwargs.
    """

    @staticmethod
    def prepare_model(model: nn.Module, config: ReftConfig, adapter_name: str):
        """Wrap ``model`` with pyreft interventions according to ``config``.

        Returns:
            SwiftOutput whose ``model`` is the pyreft-wrapped model and whose
            callbacks save/load only the intervention weights.

        Raises:
            ImportError: if pyreft is not installed.
        """
        from swift.utils.import_utils import is_pyreft_available
        if not is_pyreft_available():
            raise ImportError('Please install pyreft before using ReFT: ' '`pip install pyreft`')

        import pyreft
        from pyreft import ReftModel
        from pyreft.interventions import LowRankRotateLayer
        from pyreft import (
            NoreftIntervention,
            LoreftIntervention,
            ConsreftIntervention,
            LobireftIntervention,
            DireftIntervention,
            NodireftIntervention,
        )

        # Map the config's intervention_type string to the pyreft class.
        intervention_mapping = {
            'NoreftIntervention': NoreftIntervention,
            'LoreftIntervention': LoreftIntervention,
            'ConsreftIntervention': ConsreftIntervention,
            'LobireftIntervention': LobireftIntervention,
            'DireftIntervention': DireftIntervention,
            'NodireftIntervention': NodireftIntervention,
        }

        # Delegate unknown attribute lookups on the wrapper to the inner model.
        patch_getattr(ReftModel, 'model')

        # Device-following wrappers: move the module to the input's device
        # before delegating to the original forward.
        def forward(self, x):
            self.to(x.device)
            return self.forward_origin(x)

        def forward2(self, base, source=None, subspaces=None):
            self.to(base.device)
            return self.forward_origin(base, source, subspaces)

        # Patch the pyreft classes only once per process; the hasattr guard
        # prevents double-wrapping on repeated prepare_model calls.
        if not hasattr(LowRankRotateLayer, 'forward_origin'):
            LowRankRotateLayer.forward_origin = LowRankRotateLayer.forward
            LowRankRotateLayer.forward = forward
            NoreftIntervention.forward_origin = NoreftIntervention.forward
            NoreftIntervention.forward = forward2
            LoreftIntervention.forward_origin = LoreftIntervention.forward
            LoreftIntervention.forward = forward2
            ConsreftIntervention.forward_origin = ConsreftIntervention.forward
            ConsreftIntervention.forward = forward2
            LobireftIntervention.forward_origin = LobireftIntervention.forward
            LobireftIntervention.forward = forward2
            DireftIntervention.forward_origin = DireftIntervention.forward
            DireftIntervention.forward = forward2
            NodireftIntervention.forward_origin = NodireftIntervention.forward
            NodireftIntervention.forward = forward2

        # Resolve the decoder-layer ModuleList key, either from the explicit
        # config.layer_key or from the model-type key mapping.
        module_list_key = config.layer_key
        if module_list_key is None:
            model_key_mapping = Reft.get_model_key_mapping(config.model_type, config)
            module_list_key = model_key_mapping.module_list
        logger.info(f'Applying Reft to module: {module_list_key}')
        module_list: nn.ModuleList = model.get_submodule(module_list_key)
        # Build one intervention spec per selected layer.
        representations = []
        for idx, layer in enumerate(module_list):
            if config.layers and idx not in config.layers:
                continue
            intervention_config = {
                'layer':
                idx,
                'component':
                module_list_key + f'[{idx}].output',
                'low_rank_dimension':
                config.r,
                'intervention':
                intervention_mapping[config.intervention_type](
                    embed_dim=model.config.hidden_size, low_rank_dimension=config.r, **config.args)
            }
            representations.append(intervention_config)

        reft_config = pyreft.ReftConfig(representations=representations)
        reft_model = pyreft.get_reft_model(model, reft_config, set_device=False)
        # Keep pyreft's own config reachable, but expose the HF model config
        # as `.config` so the rest of swift treats the wrapper like the model.
        reft_model.reft_config = reft_model.config
        reft_model.config = reft_model.model.config

        def _pre_forward_hook(module, args, kwargs):
            # Translate plain HF-style kwargs into pyreft's intervenable
            # forward signature. Already-translated calls pass through.
            if 'base' in kwargs:
                return args, kwargs

            if 'input_ids' not in kwargs:
                raise ValueError('Input does not contain `input_ids`, maybe the model does not support ReFT.')
            # run intervened forward pass
            unit_locations = None
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    unit_locations = {
                        'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
                    }
                else:
                    # this is dummy for lora only baseline
                    unit_locations = {'sources->base': (None, 0)}
            kwargs = {
                'base': {
                    'input_ids': kwargs['input_ids'],
                    'attention_mask': kwargs['attention_mask']
                },
                'unit_locations': unit_locations,
                'labels': kwargs['labels'],
                'subspaces': kwargs['subspaces'].permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            return args, kwargs

        def _post_forward_hook(module, args, kwargs, outputs):
            # pyreft returns (intervened_hidden, model_output); keep only the
            # model output so callers see a standard HF return value.
            return outputs[1]

        def _generate(self, **kwargs):
            """Generate with the same kwargs translation as the forward hook."""
            # run intervened forward pass
            unit_locations = None
            if 'intervention_locations' in kwargs:
                if kwargs['intervention_locations'].dim() == 3:
                    unit_locations = {
                        'sources->base': (None, kwargs['intervention_locations'].permute(1, 0, 2).tolist())
                    }
                else:
                    # this is dummy for lora only baseline
                    unit_locations = {'sources->base': (None, 0)}

            _kwargs = {
                'base': {
                    'input_ids': kwargs.pop('input_ids'),
                    'attention_mask': kwargs.pop('attention_mask')
                },
                'unit_locations': unit_locations,
                'subspaces': kwargs.pop('subspaces').permute(1, 0, 2).tolist() if 'subspaces' in kwargs else None
            }
            _kwargs = {**_kwargs, **kwargs}
            return self.generate_origin(**_kwargs)[1]

        reft_model.generate_origin = reft_model.generate
        reft_model.generate = MethodType(_generate, reft_model)
        reft_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
        reft_model.register_forward_hook(_post_forward_hook, with_kwargs=True)

        def save_callback(swift_model, model_dir, adapter_name):
            # Persist only the intervention weights, not the base model.
            reft_model.save_intervention(save_directory=model_dir, include_model=False)

        def mark_trainable_callback(model):
            # pyreft already marks its own parameters trainable; nothing to do.
            return

        def load_callback(swift_model, model_dir, adapter_name):
            reft_model.load_intervention(model_dir, include_model=False)

        return SwiftOutput(
            model=reft_model,
            config=config,
            mark_trainable_callback=mark_trainable_callback,
            save_callback=save_callback,
            load_callback=load_callback)

    @staticmethod
    def has_additional_modules():
        # ReFT adds intervention modules on top of the base model.
        return True

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        # ReFT interventions cannot be toggled off once installed.
        assert activate, 'ReFT does not support deactivate'
ms-swift/swift/tuners/scetuning/scetuning.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import re
3
+ import types
4
+ from dataclasses import dataclass, field
5
+ from typing import List, Optional, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from swift.tuners.utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
11
+ from swift.utils import get_logger
12
+ from swift.utils.torch_utils import find_sub_module
13
+ from .scetuning_components import probe_output_hook
14
+
15
+ logger = get_logger()
16
+
17
+
18
@dataclass
class SCETuningConfig(SwiftConfig):
    """
    The configuration class for the SCEdit module.

    'SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing' by Jiang et al.(2023)
    See https://arxiv.org/abs/2312.11392

    Args:
        dims(`Union[List[int], int]`): The dimensions of the hidden states
        target_modules(`Union[List[str], str]`): The target module to be replaced, can a regex string
        hint_modules(`Union[List[str], str]`): The hint module to be replaced, can a regex string
        tuner_mode(`str`): Location of tuner operation.
        tuner_op(`str`): Tuner operation.
        down_ratio(`float`): The dim down ratio of tuner hidden state.
    """

    # One dim per matched target module, or a single int applied to all of them.
    dims: Optional[Union[List[int], int]] = field(
        default=None, metadata={'help': 'The dimensions of the hidden states'})

    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The target module to be replaced, can be a regex string or name list of full match format'})

    hint_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={'help': 'The hint modules to be replaced, can be a regex string or name list of full match format'})

    tuner_mode: str = field(
        default='decoder',
        metadata={'help': 'Location of tuner operation. The tuner mode choices: encoder, decoder, and identity'})

    tuner_op: str = field(default='SCEAdapter', metadata={'help': 'The tuner ops choices: SCEAdapter'})

    down_ratio: float = field(default=1.0, metadata={'help': 'The dim down ratio of tuner hidden state'})

    def __post_init__(self):
        # Tag this config with its tuner type so the swift mapping can route it.
        from swift.tuners.mapping import SwiftTuners
        self.swift_type = SwiftTuners.SCETUNING
57
+
58
+
59
class SCETuning(SwiftAdapter):
    """SwiftAdapter implementation of SCEdit skip-connection tuning.

    Injects an SCE tuner after (encoder mode) or before (decoder mode) each
    matched target module, optionally conditioned on the probed output of a
    paired hint module.
    """

    @staticmethod
    def prepare_model(model: nn.Module, config: SCETuningConfig, adapter_name: str) -> SwiftOutput:
        """Prepare a model with `SCETuningConfig`"""
        module_keys = [key for key, _ in model.named_modules()]
        # 1. Matching the hint module
        # Hint modules get a probe hook that records their forward output;
        # the paired tuner reads it later as conditioning input.
        hint_module_ins_list = []
        if config.hint_modules:
            if isinstance(config.hint_modules, list):
                # Explicit name list: every entry must be a full module key.
                for module_key in config.hint_modules:
                    assert module_key in module_keys
                    h_module = model.get_submodule(module_key)
                    logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}')
                    if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)):
                        logger.warning(
                            f'Type of {type(h_module)} may not be supported because of its customized forward')
                    h_module.register_forward_hook(probe_output_hook, with_kwargs=True)
                    hint_module_ins_list.append(h_module)
            else:
                # Regex: match against every module key (full match only).
                for module_key in module_keys:
                    if re.fullmatch(config.hint_modules, module_key):
                        h_module = model.get_submodule(module_key)
                        logger.info(f'Matching hint module [{module_key}] of type {type(h_module)}')
                        if isinstance(h_module, (nn.ModuleList, nn.ModuleDict)):
                            logger.warning(
                                f'Type of {type(h_module)} may not be supported because of its customized forward')
                        h_module.register_forward_hook(probe_output_hook, with_kwargs=True)
                        hint_module_ins_list.append(h_module)
            if len(hint_module_ins_list) == 0:
                logger.error('Cannot match hint modules')

        def _get_module(module):
            # Descend into nested ModuleLists and return the last leaf module.
            if isinstance(module, nn.ModuleList):
                module = module[-1]
                return _get_module(module)
            return module

        # 2. Matching the target module
        target_module_ins_list = []
        assert config.target_modules is not None
        if isinstance(config.target_modules, list):
            for module_key in config.target_modules:
                assert module_key in module_keys
                t_module = model.get_submodule(module_key)
                logger.info(f'Matching target module [{module_key}] of type {type(t_module)}')
                target_module_ins_list.append(_get_module(t_module))
        else:
            for module_key in module_keys:
                if re.fullmatch(config.target_modules, module_key):
                    t_module = model.get_submodule(module_key)
                    logger.info(f'Matching target module [{module_key}] of type {type(t_module)}')
                    target_module_ins_list.append(_get_module(t_module))
        if len(target_module_ins_list) == 0:
            logger.error('Cannot match target modules')
        # Hint and target modules are paired positionally, so the counts must agree.
        if len(hint_module_ins_list) > 0 and not len(hint_module_ins_list) == len(target_module_ins_list):
            logger.info("Target modules' length should be equal with hint modules.")
        assert len(hint_module_ins_list) == len(target_module_ins_list)
        # Broadcast a scalar dim to all targets, or require one dim per target.
        if isinstance(config.dims, int):
            dims = [config.dims for _ in target_module_ins_list]
        else:
            assert len(config.dims) == len(target_module_ins_list)
            dims = config.dims

        # refactor forward function
        def _forward_encoder_mode(self, *args, **kwargs):
            # Run the original forward first, then apply the tuner to its output.
            args = getattr(self, f'forward_origin_{adapter_name}')(*args, **kwargs)
            args_type = type(args)
            if args_type is tuple:
                args = args[0]
            if hasattr(self, 'hint'):
                # Use the probed hint output as conditioning for the tuner.
                hint_out = self.hint.probe_output_data
                args_main = getattr(self, f'scetuner_{adapter_name}')(args, hint_out)
            else:
                args_main = getattr(self, f'scetuner_{adapter_name}')(args)
            if args_type is tuple:
                args_main = (args_main, )
            return args_main

        def _forward_decoder_mode(self, *args, **kwargs):
            # Apply the tuner to the skip-connection slice of the input, then
            # run the original forward on the modified input.
            # NOTE(review): if args is not a tuple, args_main is never bound and
            # the final call raises NameError — confirm inputs are always tuples.
            args_type = type(args)
            if args_type is tuple:
                args_sub_tuner = args[0]
                args_sub_extra = args[1:]
                tuner_module = getattr(self, f'scetuner_{adapter_name}')
                # Split channel dim: leading hidden part vs trailing `dim`-sized
                # skip-connection part that the tuner edits.
                args_hidden, args_res = torch.split(args_sub_tuner, args_sub_tuner.shape[1] - tuner_module.dim, 1)
                if hasattr(self, 'hint'):
                    hint_out = self.hint.probe_output_data
                    args_res_new = tuner_module(args_res, hint_out)
                else:
                    args_res_new = tuner_module(args_res)
                args_sub_tuner_new = torch.cat([args_hidden, args_res_new], dim=1)
                if args_type is tuple:
                    args_main = (args_sub_tuner_new, *args_sub_extra)

            args_main = getattr(self, f'forward_origin_{adapter_name}')(*args_main, **kwargs)
            return args_main

        # 3. inject the tuners
        for tuner_id, t_module in enumerate(target_module_ins_list):
            # Keep the original forward reachable under an adapter-scoped name.
            setattr(t_module, f'forward_origin_{adapter_name}', getattr(t_module, 'forward'))
            if config.tuner_mode in ('encoder', 'identity'):
                _forward = _forward_encoder_mode
            elif config.tuner_mode == 'decoder':
                _forward = _forward_decoder_mode
            else:
                raise Exception(f'Error tuner_mode: {config.tuner_mode}')
            setattr(t_module, 'forward', types.MethodType(_forward, t_module))
            tuner_op = SCETunerModule(
                name=config.tuner_op,
                adapter_name=adapter_name,
                module_key=str(tuner_id),
                dim=dims[tuner_id],
                tuner_length=int(dims[tuner_id] * config.down_ratio))
            setattr(t_module, f'scetuner_{adapter_name}', tuner_op)
            if len(hint_module_ins_list) > 0:
                setattr(t_module, 'hint', hint_module_ins_list[tuner_id])

        def state_dict_callback(state_dict, adapter_name, **kwargs):
            # Persist only the tuner weights of this adapter.
            state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key}
            return state_dict_new

        def mark_trainable_callback(model):
            # Tuner parameters are trainable by construction; nothing extra to mark.
            return

        return SwiftOutput(
            config=config, state_dict_callback=state_dict_callback, mark_trainable_callback=mark_trainable_callback)

    @staticmethod
    def activate_adapter(module: torch.nn.Module, adapter_name: str, activate: bool, offload: str = None):
        """Toggle all SCE tuner modules belonging to `adapter_name`."""
        modules = find_sub_module(module, f'scetuner_{adapter_name}')
        for _module in modules:
            _module: ActivationMixin
            _module: nn.Module
            _module.set_activation(adapter_name, activate)
            # Optionally offload the deactivated tuner weights to save memory.
            SwiftAdapter.save_memory(_module, adapter_name, _module.module_key, activate, offload)
195
+
196
+
197
class SCETunerModule(nn.Module, ActivationMixin):
    """Wrapper holding one SCE tuner op for a single target module.

    Mixes in `ActivationMixin` so the tuner can be activated/deactivated per
    adapter; when inactive, `forward` is an identity mapping.
    """

    def __init__(self,
                 name,
                 adapter_name,
                 module_key,
                 dim,
                 tuner_length,
                 tuner_type=None,
                 tuner_weight=None,
                 act_layer=nn.GELU,
                 zero_init_last=True,
                 use_bias=True):
        # Initialize nn.Module first, then skip past nn.Module in the MRO to
        # reach ActivationMixin.__init__(module_key).
        super(SCETunerModule, self).__init__()
        super(nn.Module, self).__init__(module_key)
        # NOTE(review): zero_init_last and use_bias are accepted but never
        # used below — confirm whether SCEAdapter should receive them.
        self.name = name
        self.adapter_name = adapter_name
        self.dim = dim
        if name == 'SCEAdapter':
            from .scetuning_components import SCEAdapter
            self.tuner_op = SCEAdapter(
                dim=dim,
                adapter_length=tuner_length,
                adapter_type=tuner_type,
                adapter_weight=tuner_weight,
                act_layer=act_layer)
        else:
            raise Exception(f'Error tuner op {name}')
        # Tag sub-modules as plugin weights for state-dict filtering.
        self.mark_all_sub_modules_as_plugin()

    def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
        # Identity when this adapter is currently deactivated.
        if not self.is_activated(self.adapter_name):
            return x
        if self.name == 'SCEAdapter':
            # Follow the input's device in case the tuner was offloaded.
            self.tuner_op.to(x.device)
            out = self.tuner_op(x)
        else:
            raise Exception(f'Error tuner op {self.name}')
        return out
ms-swift/swift/ui/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from .app import webui_main
ms-swift/swift/ui/llm_eval/llm_eval.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import re
4
+ import sys
5
+ import time
6
+ from datetime import datetime
7
+ from functools import partial
8
+ from typing import Type
9
+
10
+ import gradio as gr
11
+ import json
12
+ import torch
13
+ from json import JSONDecodeError
14
+ from transformers.utils import is_torch_cuda_available, is_torch_npu_available
15
+
16
+ from swift.llm import EvalArguments
17
+ from swift.ui.base import BaseUI
18
+ from swift.ui.llm_eval.eval import Eval
19
+ from swift.ui.llm_eval.model import Model
20
+ from swift.ui.llm_eval.runtime import EvalRuntime
21
+ from swift.utils import get_device_count
22
+
23
+
24
class LLMEval(BaseUI):
    """Gradio tab that builds the LLM evaluation UI and launches `swift eval` jobs.

    Fixes over the previous revision:
    - ``except (JSONDecodeError or TypeError)`` only caught ``JSONDecodeError``
      (the ``or`` evaluates to its first operand); a ``TypeError`` from
      ``json.loads`` would crash the UI. Now both are caught.
    - The output directory timestamp was built without zero padding
      (``f'{now.year}{now.month}...'``), producing ambiguous names (e.g.
      ``2024111...``) that can collide; it now uses ``strftime('%Y%m%d%H%M%S')``.
    """

    group = 'llm_eval'

    sub_ui = [Model, Eval, EvalRuntime]

    cmd = 'eval'

    locale_dict = {
        'llm_eval': {
            'label': {
                'zh': 'LLM评测',
                'en': 'LLM evaluation',
            }
        },
        'more_params': {
            'label': {
                'zh': '更多参数',
                'en': 'More params'
            },
            'info': {
                'zh': '以json格式或--xxx xxx命令行格式填入',
                'en': 'Fill in with json format or --xxx xxx cmd format'
            }
        },
        'evaluate': {
            'value': {
                'zh': '开始评测',
                'en': 'Begin Evaluation'
            },
        },
        'gpu_id': {
            'label': {
                'zh': '选择可用GPU',
                'en': 'Choose GPU'
            },
            'info': {
                'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU',
                'en': 'Select GPU to train'
            }
        },
    }

    choice_dict = BaseUI.get_choices_from_dataclass(EvalArguments)
    default_dict = BaseUI.get_default_value_from_dataclass(EvalArguments)
    arguments = BaseUI.get_argument_names(EvalArguments)

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the evaluation tab and wire up the launch/refresh/kill events."""
        with gr.TabItem(elem_id='llm_eval', label=''):
            default_device = 'cpu'
            device_count = get_device_count()
            if device_count > 0:
                default_device = '0'
            with gr.Blocks():
                Model.build_ui(base_tab)
                Eval.build_ui(base_tab)
                EvalRuntime.build_ui(base_tab)
                with gr.Row():
                    gr.Textbox(elem_id='more_params', lines=4, scale=20)
                    gr.Button(elem_id='evaluate', scale=2, variant='primary')
                    gr.Dropdown(
                        elem_id='gpu_id',
                        multiselect=True,
                        choices=[str(i) for i in range(device_count)] + ['cpu'],
                        value=default_device,
                        scale=8)

                cls.element('evaluate').click(
                    cls.eval_model, list(base_tab.valid_elements().values()),
                    [cls.element('runtime_tab'), cls.element('running_tasks')])

                base_tab.element('running_tasks').change(
                    partial(EvalRuntime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')],
                    list(base_tab.valid_elements().values()) + [cls.element('log')])
                EvalRuntime.element('kill_task').click(
                    EvalRuntime.kill_task,
                    [EvalRuntime.element('running_tasks')],
                    [EvalRuntime.element('running_tasks')] + [EvalRuntime.element('log')],
                )

    @classmethod
    def eval(cls, *args):
        """Translate UI values into an `EvalArguments` instance and a shell command.

        Returns:
            Tuple of (run_command, eval_args, log_file).
        """
        eval_args = cls.get_default_value_from_dataclass(EvalArguments)
        kwargs = {}
        kwargs_is_list = {}
        other_kwargs = {}
        more_params = {}
        more_params_cmd = ''
        keys = cls.valid_element_keys()
        for key, value in zip(keys, args):
            # Only forward values the user actually changed from the defaults.
            compare_value = eval_args.get(key)
            compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value
            compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value
            if key in eval_args and compare_value_ui != compare_value_arg and value:
                # Textbox values arrive as strings; coerce obvious scalars back.
                if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
                    value = int(value)
                elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
                    value = float(value)
                elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
                    value = True if value.lower() == 'true' else False
                kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
                kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
            else:
                other_kwargs[key] = value
            if key == 'more_params' and value:
                # `more_params` may be a JSON dict or raw `--xxx xxx` text.
                # BUG FIX: `(JSONDecodeError or TypeError)` caught only
                # JSONDecodeError; both must be tolerated.
                try:
                    more_params = json.loads(value)
                except (JSONDecodeError, TypeError):
                    more_params_cmd = value

        kwargs.update(more_params)
        model = kwargs.get('model')
        # A local checkpoint dir (with args.json) is passed as ckpt_dir, not model.
        if model and os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')):
            kwargs['ckpt_dir'] = kwargs.pop('model')

        eval_args = EvalArguments(
            **{
                key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value
                for key, value in kwargs.items()
            })
        # Build the quoted command-line parameter string.
        params = ''
        sep = f'{cls.quote} {cls.quote}'
        for e in kwargs:
            if isinstance(kwargs[e], list):
                params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} '
            elif e in kwargs_is_list and kwargs_is_list[e]:
                all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()]
                params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} '
            else:
                params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} '
        params += more_params_cmd + ' '
        devices = other_kwargs['gpu_id']
        devices = [d for d in devices if d]
        # 'cpu' cannot be combined with GPU ids.
        assert (len(devices) == 1 or 'cpu' not in devices)
        gpus = ','.join(devices)
        cuda_param = ''
        if gpus != 'cpu':
            if is_torch_npu_available():
                cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}'
            elif is_torch_cuda_available():
                cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
            else:
                cuda_param = ''
        now = datetime.now()
        # Zero-padded timestamp keeps output dirs unambiguous and sortable.
        time_str = now.strftime('%Y%m%d%H%M%S')
        file_path = f'output/{eval_args.model_type}-{time_str}'
        if not os.path.exists(file_path):
            os.makedirs(file_path, exist_ok=True)
        log_file = os.path.join(os.getcwd(), f'{file_path}/run_eval.log')
        eval_args.log_file = log_file
        params += f'--log_file "{log_file}" '
        params += '--ignore_args_error true '
        if sys.platform == 'win32':
            if cuda_param:
                cuda_param = f'set {cuda_param} && '
            run_command = f'{cuda_param}start /b swift eval {params} > {log_file} 2>&1'
        else:
            run_command = f'{cuda_param} nohup swift eval {params} > {log_file} 2>&1 &'
        return run_command, eval_args, log_file

    @classmethod
    def eval_model(cls, *args):
        """Launch the evaluation subprocess and open the runtime panel."""
        run_command, eval_args, log_file = cls.eval(*args)
        os.system(run_command)
        # Give the subprocess a moment to create its log file before refreshing.
        time.sleep(2)
        return gr.update(open=True), EvalRuntime.refresh_tasks(log_file)
ms-swift/swift/ui/llm_eval/runtime.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+ from packaging import version
6
+
7
+ from swift.ui.base import BaseUI
8
+ from swift.ui.llm_infer.runtime import Runtime
9
+ from swift.utils import get_logger
10
+
11
+ logger = get_logger()
12
+
13
+
14
class EvalRuntime(Runtime):
    """Runtime tab for `swift eval`: lists running evaluation tasks and offers
    controls to show/stop-following logs, refresh the task list, and kill tasks.

    All task-management behavior is inherited from the infer `Runtime`; only the
    UI group, the tracked sub-command and the locale texts differ.
    """

    # UI elements built here register their ids under this group.
    group = 'llm_eval'

    # The swift sub-command whose processes this tab discovers and monitors.
    cmd = 'eval'

    # zh/en label/info/value texts for every elem_id used in do_build_ui().
    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示评测状态',
                'en': 'Show eval status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示',
                'en': 'Stop showing running status'
            },
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示日志内容"',
                'en': 'Please press "Show log" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中评测',
                'en': 'Running evaluation'
            },
            'info': {
                'zh': '所有的swift eval命令启动的任务',
                'en': 'All tasks started by swift eval'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回评测',
                'en': 'Find evaluation'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死评测',
                'en': 'Kill evaluation'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the collapsible runtime accordion and wire up its events."""
        with gr.Accordion(elem_id='runtime_tab', open=False, visible=True):
            with gr.Blocks():
                with gr.Row():
                    gr.Dropdown(elem_id='running_tasks', scale=10)
                    gr.Button(elem_id='refresh_tasks', scale=1, variant='primary')
                    gr.Button(elem_id='show_log', scale=1, variant='primary')
                    gr.Button(elem_id='stop_show_log', scale=1)
                    gr.Button(elem_id='kill_task', scale=1, size='lg')
                with gr.Row():
                    gr.Textbox(elem_id='log', lines=6, visible=False)

        # gradio>=4 supports capping concurrent log-streaming callbacks.
        concurrency_limit = {}
        if version.parse(gr.__version__) >= version.parse('4.0.0'):
            concurrency_limit = {'concurrency_limit': 5}
        # 'show_log' first renders the current log, then keeps streaming updates
        # for the selected task via cls.wait.
        cls.log_event = base_tab.element('show_log').click(cls.update_log, [], [cls.element('log')]).then(
            cls.wait, [base_tab.element('running_tasks')], [cls.element('log')], **concurrency_limit)

        base_tab.element('stop_show_log').click(cls.break_log_event, [cls.element('running_tasks')], [])

        base_tab.element('refresh_tasks').click(
            cls.refresh_tasks,
            [base_tab.element('running_tasks')],
            [base_tab.element('running_tasks')],
        )
ms-swift/swift/ui/llm_export/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
ms-swift/swift/ui/llm_export/export.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.llm.dataset.register import get_dataset_list
7
+ from swift.ui.base import BaseUI
8
+
9
+
10
class Export(BaseUI):
    """UI panel for `swift export`: merge-LoRA options plus quantization settings
    (method, bits, calibration dataset) and the output directory."""

    # UI elements built here register their ids under this group.
    group = 'llm_export'

    # zh/en label/info texts for every elem_id used in do_build_ui().
    locale_dict = {
        'merge_lora': {
            'label': {
                'zh': '合并lora',
                'en': 'Merge lora'
            },
            'info': {
                'zh':
                'lora合并的路径在填入的checkpoint同级目录,请查看运行时log获取更具体的信息',
                'en':
                'The output path is in the sibling directory as the input checkpoint. '
                'Please refer to the runtime log for more specific information.'
            },
        },
        'device_map': {
            'label': {
                'zh': '合并lora使用的device_map',
                'en': 'The device_map when merge-lora'
            },
            'info': {
                'zh': '如果显存不够请填入cpu',
                'en': 'If GPU memory is not enough, fill in cpu'
            },
        },
        'quant_bits': {
            'label': {
                'zh': '量化比特数',
                'en': 'Quantize bits'
            },
        },
        'quant_method': {
            'label': {
                'zh': '量化方法',
                'en': 'Quantize method'
            },
        },
        'quant_n_samples': {
            'label': {
                'zh': '量化集采样数',
                'en': 'Sampled rows from calibration dataset'
            },
        },
        'max_length': {
            'label': {
                'zh': '量化集的max-length',
                'en': 'The quantize sequence length'
            },
        },
        'output_dir': {
            'label': {
                'zh': '输出路径',
                'en': 'Output dir'
            },
        },
        'dataset': {
            'label': {
                'zh': '校准数据集',
                'en': 'Calibration datasets'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Lay out the merge-lora row, the quantization row and the output row."""
        with gr.Row():
            gr.Checkbox(elem_id='merge_lora', scale=10)
            gr.Textbox(elem_id='device_map', scale=20)
        with gr.Row():
            gr.Dropdown(elem_id='quant_bits', scale=20)
            gr.Dropdown(elem_id='quant_method', scale=20)
            gr.Textbox(elem_id='quant_n_samples', scale=20)
            gr.Textbox(elem_id='max_length', scale=20)
        with gr.Row():
            gr.Textbox(elem_id='output_dir', scale=20)
            # Registered datasets are offered as choices; custom values allowed.
            gr.Dropdown(
                elem_id='dataset', multiselect=True, allow_custom_value=True, choices=get_dataset_list(), scale=20)
ms-swift/swift/ui/llm_train/galore.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class Galore(BaseUI):
    """Collapsible settings panel for GaLore (gradient low-rank projection)
    training options inside the llm_train tab."""

    # UI elements built here register their ids under this group.
    group = 'llm_train'

    # zh/en label/info texts for every elem_id used in do_build_ui().
    locale_dict = {
        'galore_tab': {
            'label': {
                'zh': 'Galore参数设置',
                'en': 'Galore Settings'
            },
        },
        'use_galore': {
            'label': {
                'zh': '使用GaLore',
                'en': 'Use GaLore'
            },
            'info': {
                'zh': '使用Galore来减少全参数训练的显存消耗',
                'en': 'Use Galore to reduce GPU memory usage in full parameter training'
            }
        },
        'galore_rank': {
            'label': {
                'zh': 'Galore的秩',
                'en': 'The rank of Galore'
            },
        },
        'galore_update_proj_gap': {
            'label': {
                'zh': 'Galore project matrix更新频率',
                'en': 'The updating gap of the project matrix'
            },
        },
        'galore_optim_per_parameter': {
            'label': {
                'zh': '为每个Galore Parameter创建单独的optimizer',
                'en': 'Create unique optimizer for per Galore parameter'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build one accordion row: enable checkbox, rank/gap sliders, per-parameter optimizer checkbox."""
        with gr.Accordion(elem_id='galore_tab', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Checkbox(elem_id='use_galore', scale=4)
                    gr.Slider(elem_id='galore_rank', minimum=8, maximum=256, step=8, scale=4)
                    gr.Slider(elem_id='galore_update_proj_gap', minimum=10, maximum=1000, step=50, scale=4)
                    gr.Checkbox(elem_id='galore_optim_per_parameter', scale=4)
ms-swift/swift/ui/llm_train/lisa.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Type
3
+
4
+ import gradio as gr
5
+
6
+ from swift.ui.base import BaseUI
7
+
8
+
9
class Lisa(BaseUI):
    """Collapsible settings panel for LISA (layerwise importance sampling)
    training options inside the llm_train tab."""

    # UI elements built here register their ids under this group.
    group = 'llm_train'

    # zh/en label/info texts for every elem_id used in do_build_ui().
    locale_dict = {
        'lisa_tab': {
            'label': {
                'zh': 'LISA参数设置',
                'en': 'LISA settings'
            },
        },
        'lisa_activated_layers': {
            'label': {
                'zh': 'LISA激活层数',
                # Bug fix: this label previously read 'LoRA activated layers',
                # but the setting (and the zh label) is about LISA.
                'en': 'LISA activated layers'
            },
            'info': {
                'zh': 'LISA每次训练的模型层数,调整为正整数代表使用LISA',
                'en': 'Num of layers activated each time, a positive value means using lisa'
            }
        },
        'lisa_step_interval': {
            'label': {
                'zh': 'LISA切换layers间隔',
                'en': 'The interval of lisa layers switching'
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build one accordion row with the two LISA text inputs."""
        with gr.Accordion(elem_id='lisa_tab', open=False):
            with gr.Blocks():
                with gr.Row():
                    gr.Textbox(elem_id='lisa_activated_layers')
                    gr.Textbox(elem_id='lisa_step_interval')
ms-swift/swift/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.23 kB). View file
 
ms-swift/swift/utils/__pycache__/constants.cpython-310.pyc ADDED
Binary file (762 Bytes). View file
 
ms-swift/swift/utils/__pycache__/env.cpython-310.pyc ADDED
Binary file (3.36 kB). View file
 
ms-swift/swift/utils/__pycache__/import_utils.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
ms-swift/swift/utils/__pycache__/io_utils.cpython-310.pyc ADDED
Binary file (4.76 kB). View file
 
ms-swift/swift/utils/__pycache__/tb_utils.cpython-310.pyc ADDED
Binary file (2.64 kB). View file
 
ms-swift/swift/utils/__pycache__/torchacc_utils.cpython-310.pyc ADDED
Binary file (25.1 kB). View file
 
ms-swift/swift/utils/env.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ from transformers.utils import strtobool
8
+
9
+ from .logger import get_logger
10
+
11
+ logger = get_logger()
12
+
13
+
14
def use_hf_hub():
    """Whether to use the HuggingFace hub instead of ModelScope (env ``USE_HF``)."""
    flag = os.environ.get('USE_HF', '0')
    return strtobool(flag)


def is_deepspeed_enabled():
    """Whether accelerate launched this process with DeepSpeed enabled."""
    flag = os.environ.get('ACCELERATE_USE_DEEPSPEED', '0')
    return strtobool(flag)


def use_torchacc() -> bool:
    """Whether TorchAcc acceleration is requested (env ``USE_TORCHACC``)."""
    flag = os.getenv('USE_TORCHACC', '0')
    return strtobool(flag)
24
+
25
+
26
def get_dist_setting() -> Tuple[int, int, int, int]:
    """Return ``(rank, local_rank, world_size, local_world_size)`` read from env vars.

    rank/local_rank default to -1 (not distributed). An unset or empty
    WORLD_SIZE falls through to ``_PATCH_WORLD_SIZE`` and then to 1; the local
    world size also accepts ``LOCAL_SIZE`` for deepspeed-launch compatibility.
    """
    world = os.getenv('WORLD_SIZE') or os.getenv('_PATCH_WORLD_SIZE') or 1
    local_world = os.getenv('LOCAL_WORLD_SIZE', None) or os.getenv('LOCAL_SIZE', 1)
    return (int(os.getenv('RANK', -1)), int(os.getenv('LOCAL_RANK', -1)), int(world), int(local_world))


def get_node_setting():
    """Return ``(node_rank, nnodes)``; defaults to the single-node case (0, 1)."""
    return int(os.getenv('NODE_RANK', 0)), int(os.getenv('NNODES', 1))


def is_local_master():
    """True on the first process of each node (or when not distributed)."""
    return get_dist_setting()[1] in (-1, 0)


def is_master():
    """True on the globally first process (or when not distributed)."""
    return get_dist_setting()[0] in (-1, 0)
50
+
51
+
52
def torchacc_trim_graph():
    # Env flag TORCHACC_TRIM_GRAPH: whether TorchAcc should trim compiled graphs.
    return strtobool(os.getenv('TORCHACC_TRIM_GRAPH', '0'))


def is_dist():
    """Determine if the training is distributed"""
    if use_torchacc():
        # TorchAcc manages its own distribution; the regular DDP path is off.
        return False
    rank, local_rank, _, _ = get_dist_setting()
    return rank >= 0 and local_rank >= 0


def is_mp() -> bool:
    # True when model parallelism applies, i.e. each local process owns >= 2 devices.
    if use_torchacc():
        return False
    if strtobool(os.environ.get('USE_FAST_INFERENCE', 'false')):
        return False
    from swift.utils import get_device_count
    n_gpu = get_device_count()
    local_world_size = get_dist_setting()[3]
    # Devices must divide evenly across the local processes.
    assert n_gpu % local_world_size == 0, f'n_gpu: {n_gpu}, local_world_size: {local_world_size}'
    if n_gpu // local_world_size >= 2:
        return True
    return False


def is_mp_ddp() -> bool:
    # patch_mp_ddp will occur when `import swift`.
    if is_dist() and is_mp():
        logger.info('Using MP(device_map) + DDP')
        return True
    return False


def is_dist_ta() -> bool:
    """Determine if the TorchAcc training is distributed"""
    _, _, world_size, _ = get_dist_setting()
    if use_torchacc() and world_size > 1:
        if not dist.is_initialized():
            import torchacc as ta
            # Initialize in advance
            dist.init_process_group(backend=ta.dist.BACKEND_NAME)
        return True
    else:
        return False
97
+
98
+
99
def is_pai_training_job() -> bool:
    """True when running inside an Alibaba PAI training job (env marker)."""
    return os.environ.get('PAI_TRAINING_JOB_ID') is not None


def get_pai_tensorboard_dir() -> Optional[str]:
    """Tensorboard output dir provided by PAI, or None when not on PAI."""
    return os.environ.get('PAI_OUTPUT_TENSORBOARD')
ms-swift/swift/utils/import_utils.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Copyright 2023-present the HuggingFace Inc. team.
3
+
4
+ import importlib.util
5
+ import os
6
+ from itertools import chain
7
+ from types import ModuleType
8
+ from typing import Any
9
+
10
+ from .logger import get_logger
11
+
12
+ logger = get_logger() # pylint: disable=invalid-name
13
+
14
+
15
def _module_exists(name: str) -> bool:
    # Shared probe: a package counts as available when an import spec is found.
    return importlib.util.find_spec(name) is not None


def is_vllm_available():
    return _module_exists('vllm')


def is_vllm_ascend_available():
    return _module_exists('vllm_ascend')


def is_lmdeploy_available():
    return _module_exists('lmdeploy')


def is_liger_available():
    return _module_exists('liger_kernel')


def is_swanlab_available():
    return _module_exists('swanlab')


def is_xtuner_available():
    return _module_exists('xtuner')


def is_megatron_available():
    return _module_exists('megatron')


def is_unsloth_available() -> bool:
    return _module_exists('unsloth')


def is_pyreft_available() -> bool:
    return _module_exists('pyreft')


def is_wandb_available() -> bool:
    return _module_exists('wandb')
53
+
54
+
55
class _LazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.
    """

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        # name: dotted module name; module_file: path of the real __init__.py;
        # import_structure: {submodule: [exported names]}; extra_objects: values
        # served directly without any import.
        super().__init__(name)
        self._modules = set(import_structure.keys())
        self._class_to_module = {}
        # Invert the structure so an attribute name maps to its submodule.
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        self._import_structure = import_structure

    # Needed for autocompletion in an IDE
    def __dir__(self):
        result = super().__dir__()
        # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
        # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        # Resolution order: pre-supplied objects, submodules, then re-exported names.
        if name in self._objects:
            return self._objects[name]
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module.keys():
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f'module {self.__name__} has no attribute {name}')

        # Cache on the instance so __getattr__ only fires once per name.
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        # Relative import against this package's name.
        return importlib.import_module('.' + module_name, self.__name__)

    def __reduce__(self):
        # Pickle support: rebuild lazily from the import structure.
        return self.__class__, (self._name, self.__file__, self._import_structure)
ms-swift/swift/utils/io_utils.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from queue import Queue
4
+ from threading import Thread
5
+ from typing import Any, Dict, List, Literal, Union
6
+
7
+ import json
8
+ import requests
9
+ import torch.distributed as dist
10
+ from accelerate.utils import gather_object
11
+ from modelscope.hub.api import ModelScopeConfig
12
+ from tqdm import tqdm
13
+
14
+ from .env import is_master
15
+ from .logger import get_logger
16
+ from .utils import check_json_format
17
+
18
+ logger = get_logger()
19
+
20
+
21
def download_ms_file(url: str, local_path: str, cookies=None) -> None:
    """Download *url* to *local_path*, authenticating with ModelScope cookies.

    Args:
        url: File URL on ModelScope.
        local_path: Destination path on the local filesystem.
        cookies: Optional cookie jar; defaults to the cached ModelScope login.
    """
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    with requests.get(url, cookies=cookies, stream=True) as resp:
        resp.raise_for_status()  # fail fast instead of writing an error page to disk
        with open(local_path, 'wb') as f:
            # Bug fix: iter_lines() strips the newline bytes, corrupting the
            # downloaded payload; iter_content() streams the raw bytes unchanged.
            for chunk in tqdm(resp.iter_content(chunk_size=1 << 20)):
                f.write(chunk)
28
+
29
+
30
def read_from_jsonl(fpath: str, encoding: str = 'utf-8') -> List[Any]:
    """Parse a JSON-Lines file into a list of Python objects (one per line)."""
    with open(fpath, 'r', encoding=encoding) as f:
        return [json.loads(line) for line in f]
36
+
37
+
38
def write_to_jsonl(fpath: str, obj_list: List[Any], encoding: str = 'utf-8') -> None:
    """Serialize *obj_list* to *fpath* as JSON-Lines: one compact JSON document
    per line (non-ASCII kept as-is), with a trailing newline."""
    lines = [json.dumps(obj, ensure_ascii=False) for obj in obj_list]
    with open(fpath, 'w', encoding=encoding) as f:
        f.write('\n'.join(lines) + '\n')
45
+
46
+
47
class JsonlWriter:
    """Append-only JSON-Lines writer that only touches disk on the master rank.

    With ``enable_async`` the writing happens on a lazily started daemon thread
    fed through a queue, so ``append`` never blocks on disk I/O.
    """

    def __init__(self, fpath: str, *, encoding: str = 'utf-8', strict: bool = True, enable_async: bool = False):
        # Non-master ranks keep fpath as None; they may still call append()
        # so their objects can be gathered onto the master (gather_obj=True).
        self.fpath = os.path.abspath(os.path.expanduser(fpath)) if is_master() else None
        self.encoding = encoding
        # strict=True re-raises write failures; otherwise they are only logged.
        self.strict = strict
        self.enable_async = enable_async
        self._queue = Queue()
        self._thread = None  # lazily created by the first async append()

    def _append_worker(self):
        # Daemon-thread loop: drain the queue forever; process exit reaps it.
        while True:
            item = self._queue.get()
            self._append(**item)

    def _append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
        # Normalize the input to a list of dicts.
        if isinstance(obj, (list, tuple)) and all(isinstance(item, dict) for item in obj):
            obj_list = obj
        else:
            obj_list = [obj]
        # Optionally collect all ranks' objects; only the master writes below.
        if gather_obj and dist.is_initialized():
            obj_list = gather_object(obj_list)
        if not is_master():
            return
        # Coerce non-JSON-serializable values before dumping.
        obj_list = check_json_format(obj_list)
        for i, _obj in enumerate(obj_list):
            obj_list[i] = json.dumps(_obj, ensure_ascii=False) + '\n'
        self._write_buffer(''.join(obj_list))

    def append(self, obj: Union[Dict, List[Dict]], gather_obj: bool = False):
        """Append one dict (or a list of dicts) as JSONL line(s), sync or async."""
        if self.enable_async:
            if self._thread is None:
                self._thread = Thread(target=self._append_worker, daemon=True)
                self._thread.start()
            self._queue.put({'obj': obj, 'gather_obj': gather_obj})
        else:
            self._append(obj, gather_obj=gather_obj)

    def _write_buffer(self, text: str):
        # Actual disk write; only ever reached on the master rank.
        if not text:
            return
        assert is_master(), f'is_master(): {is_master()}'
        try:
            os.makedirs(os.path.dirname(self.fpath), exist_ok=True)
            with open(self.fpath, 'a', encoding=self.encoding) as f:
                f.write(text)
        except Exception:
            if self.strict:
                raise
            logger.error(f'Cannot write content to jsonl file. text: {text}')
97
+
98
+
99
def append_to_jsonl(fpath: str, obj: Union[Dict, List[Dict]], *, encoding: str = 'utf-8', strict: bool = True) -> None:
    """One-shot convenience wrapper: append *obj* to *fpath* via a throwaway JsonlWriter."""
    JsonlWriter(fpath, encoding=encoding, strict=strict).append(obj)
102
+
103
+
104
def get_file_mm_type(file_name: str) -> Literal['image', 'video', 'audio']:
    """Classify a media file as image/video/audio from its extension
    (case-insensitive); raise ValueError for anything unrecognized."""
    ext_to_type = {
        'video': {'.mp4', '.mkv', '.mov', '.avi', '.wmv', '.flv', '.webm'},
        'audio': {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a'},
        'image': {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'},
    }
    _, ext = os.path.splitext(file_name)
    lowered = ext.lower()
    for mm_type, extensions in ext_to_type.items():
        if lowered in extensions:
            return mm_type
    # Report the extension with its original casing, as given by the caller.
    raise ValueError(f'file_name: {file_name}, ext: {ext}')
ms-swift/swift/utils/np_utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+
8
def transform_jsonl_to_df(dict_list: List[Dict[str, Any]]) -> pd.DataFrame:
    """Relevant function: `io_utils.read_from_jsonl()`.

    Build a DataFrame from a list of dicts with possibly heterogeneous keys;
    missing cells are filled with None (NaN in the resulting frame). Column
    order follows first appearance.
    """
    columns: Dict[str, List[Any]] = {}
    for row_idx, row in enumerate(dict_list):
        for key, value in row.items():
            # A key seen for the first time is back-filled for all earlier rows.
            columns.setdefault(key, [None] * row_idx).append(value)
        # Keys absent from this row get a None placeholder.
        for key in set(columns) - set(row):
            columns[key].append(None)
    return pd.DataFrame.from_dict(columns)
19
+
20
+
21
def get_seed(random_state: Optional[np.random.RandomState] = None) -> int:
    """Draw a seed in ``[0, int32 max)`` from *random_state* (a fresh
    RandomState when omitted)."""
    rs = np.random.RandomState() if random_state is None else random_state
    upper_bound = np.iinfo(np.int32).max
    return rs.randint(0, upper_bound)
27
+
28
+
29
def stat_array(array: Union[np.ndarray, List[int], 'torch.Tensor']) -> Tuple[Dict[str, float], str]:
    """Summarize a 1-D array: return a stats dict and a human-readable string."""
    if isinstance(array, list):
        array = np.array(array)
    stats = {
        'mean': array.mean().item(),
        'std': array.std().item(),
        'min': array.min().item(),
        'max': array.max().item(),
        'size': array.shape[0],
    }
    desc = (f"{stats['mean']:.6f}±{stats['std']:.6f}, "
            f"min={stats['min']:.6f}, max={stats['max']:.6f}, size={stats['size']}")
    return stats, desc
ms-swift/swift/utils/torchacc_utils.py ADDED
@@ -0,0 +1,917 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import sys
4
+ import types
5
+ from typing import List, Optional, Tuple
6
+
7
+ import safetensors
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import transformers
11
+ from packaging import version
12
+ from peft import PeftModel
13
+ from torch.utils.data import DataLoader
14
+ from transformers import PreTrainedModel, trainer
15
+ from transformers.modeling_utils import unwrap_model
16
+
17
+ from swift.utils import get_logger, torchacc_trim_graph, use_torchacc
18
+
19
+ logger = get_logger()
20
+
21
+
22
+ # DataLoader
23
def get_bucket_sizes(max_length: int) -> List[int]:
    """Get the padding-bucket sizes for TorchAcc.

    When the environment variable ``TORCHACC_DATA_BUCKETS`` is set
    (comma-separated ints) those sizes are used directly, with ``max_length``
    appended. Otherwise buckets start at [16..128] and grow geometrically from
    256 by factor ``TORCHACC_PADDING_P_BASE`` (default 2, or 1.4 when a
    persistent cache dir ``TORCHACC_CACHE_PATH`` is set), each rounded up to a
    multiple of 128, finishing with ``max_length``.
    """
    padding_p_base = 2
    if os.getenv('TORCHACC_DATA_BUCKETS') is not None:
        bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')]
        bucket_sizes.append(max_length)
    else:
        if os.getenv('TORCHACC_CACHE_PATH') is not None:  # padding strategy when persistent cache is enabled
            padding_p_base = 1.4
        padding_p_base = os.getenv('TORCHACC_PADDING_P_BASE', padding_p_base)
        try:
            padding_p_base = float(padding_p_base)
        except ValueError as e:
            # Bug fix: the message previously misspelled the env var as
            # 'TORCHACC_PADDINF_P_BASE'.
            logger.error(f'Expect TORCHACC_PADDING_P_BASE to be a float number, but encountered {padding_p_base}')
            raise e
        bucket_sizes = [16, 32, 48, 64, 96, 128]
        base_size = 256
        while base_size < max_length:
            # Round the geometric step up to the next multiple of 128.
            bucket_sizes.append((int(base_size) + 127) // 128 * 128)
            base_size *= padding_p_base
        bucket_sizes.append(max_length)

    return bucket_sizes
50
+
51
+
52
+ def _get_closet_bucket(bucket_sizes, data_length):
53
+ """Select the one from bucket_sizes that is closest in distance to
54
+ data_length. This is required for TorchAcc.
55
+ """
56
+ closest_length = sys.maxsize
57
+ for b in bucket_sizes:
58
+ if b == data_length or ((b < closest_length) and (b > data_length)):
59
+ closest_length = b
60
+
61
+ if closest_length == sys.maxsize:
62
+ bucket_sizes.append(data_length)
63
+ closest_length = data_length
64
+
65
+ return closest_length
66
+
67
+
68
def pad_and_split_batch(padding_to, input_ids, attention_mask, labels, loss_scale, max_length, tokenizer, rank,
                        world_size, padding_right):
    """Pad a batch up to a bucketed length and slice out this DP rank's shard.

    When ``padding_to`` is None, the sequence dimension is padded (left or
    right per ``padding_right``) to the closest bucket size; afterwards the
    batch dimension is split evenly across ``world_size`` ranks (left intact
    when the batch is smaller than ``world_size``).
    """
    if padding_to is None:
        # Pick the smallest bucket >= the current longest sequence and pad to it.
        longest_len = input_ids.shape[-1]
        bucket_sizes = get_bucket_sizes(max_length)
        bucket_data_length = _get_closet_bucket(bucket_sizes, longest_len)
        padding_length = bucket_data_length - input_ids.shape[1]
        pad_tuple = (0, padding_length) if padding_right else (padding_length, 0)
        input_ids = F.pad(input_ids, pad_tuple, 'constant', tokenizer.pad_token_id)
        attention_mask = F.pad(attention_mask, pad_tuple, 'constant', 0)
        # NOTE(review): `if loss_scale:` on a multi-element tensor raises a
        # RuntimeError; this presumably relies on loss_scale being None (or a
        # 0/1-element tensor) here — confirm against the callers.
        if loss_scale:
            loss_scale = F.pad(loss_scale, pad_tuple, 'constant', 0.)
        labels = F.pad(labels, pad_tuple, 'constant', -100)

    # manually split the batch to different DP rank.
    batch_size = input_ids.shape[0] // world_size
    if batch_size > 0:
        start = rank * batch_size
        end = (rank + 1) * batch_size
        input_ids = input_ids[start:end, :]
        attention_mask = attention_mask[start:end, :]
        labels = labels[start:end, :]
        if loss_scale:
            loss_scale = loss_scale[start:end, :]
    return input_ids, attention_mask, labels, loss_scale
93
+
94
+
95
def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size):
    """Build the TorchAcc training dataloader (an ``ta.AsyncLoader``) and patch
    transformers' ``skip_first_batches`` so checkpoint resume works with it."""

    # patch skip_first_batches for customized dataloader.
    def acc_skip_first_batches(dataloader, num_batches=0):
        from accelerate.data_loader import SkipBatchSampler
        # The AsyncLoader wraps the real DataLoader as `_loader`.
        batch_sampler = SkipBatchSampler(dataloader._loader.batch_sampler, skip_batches=num_batches)
        try:
            dataset = dataloader.dataset
        except AttributeError:
            dataset = dataloader._loader.dataset
        dataloader_params = {
            'collate_fn': data_collator,
            'num_workers': args.dataloader_num_workers,
            'pin_memory': args.dataloader_pin_memory,
            'persistent_workers': args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['batch_sampler'] = batch_sampler
            dataloader_params['worker_init_fn'] = trainer.seed_worker

        # `ta` resolves at call time via the closure: the import below runs
        # before this patched function is ever invoked.
        return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device)

    trainer.skip_first_batches = acc_skip_first_batches

    # dataloader for TorchAcc.
    import torchacc as ta

    dataloader_params = {
        'batch_size': batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    # Iterable datasets manage their own ordering; samplers only apply otherwise.
    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last
        dataloader_params['worker_init_fn'] = trainer.seed_worker

    return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device)
136
+
137
+
138
def ta_eval_dataloader(eval_dataset, data_collator, sampler, args):
    """Build the TorchAcc evaluation dataloader (an ``ta.AsyncLoader``)."""
    import torchacc as ta

    dataloader_params = {
        'batch_size': args.eval_batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    # Iterable datasets manage their own ordering; samplers only apply otherwise.
    if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last

    return ta.AsyncLoader(DataLoader(eval_dataset, **dataloader_params), args.device)
154
+
155
+
156
def ta_test_dataloader(test_dataset, data_collator, sampler, args):
    """Build the TorchAcc test dataloader (an ``ta.AsyncLoader``)."""
    import torchacc as ta

    dataloader_params = {
        'batch_size': args.eval_batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    # Iterable datasets manage their own ordering; samplers only apply otherwise.
    if not isinstance(test_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last

    # We use the same batch_size as for eval.
    return ta.AsyncLoader(DataLoader(test_dataset, **dataloader_params), args.device)
173
+
174
+
175
+ # Save/load checkpoint
176
def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir):
    """Save per-ordinal optimizer/scheduler states under XLA; every rank writes
    its own file, synchronized with rendezvous barriers before and after."""
    import torch_xla.core.xla_model as xm
    xm.rendezvous('saving_optimizer_states')
    # master_only=False: each XLA ordinal persists its own shard.
    xm.save(optimizer.state_dict(), os.path.join(output_dir, f'optimizer_{xm.get_ordinal()}.pt'), master_only=False)
    xm.save(lr_scheduler.state_dict(), os.path.join(output_dir, f'scheduler_{xm.get_ordinal()}.pt'), master_only=False)
    xm.rendezvous('saving_optimizer_states_done')
182
+
183
+
184
def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device):
    """Load this ordinal's optimizer/scheduler states saved by
    ``ta_save_optimizer_and_scheduler`` and restore them."""
    import torch_xla.core.xla_model as xm
    optimizer_state = torch.load(os.path.join(checkpoint, f'optimizer_{xm.get_ordinal()}.pt'), map_location='cpu')
    lr_scheduler_state = torch.load(os.path.join(checkpoint, f'scheduler_{xm.get_ordinal()}.pt'), map_location='cpu')
    # NOTE(review): the return values of send_cpu_data_to_device are discarded;
    # presumably load_state_dict below moves tensors as needed — confirm these
    # calls are not meant to rebind optimizer_state/lr_scheduler_state.
    xm.send_cpu_data_to_device(optimizer_state, device)
    xm.send_cpu_data_to_device(lr_scheduler_state, device)

    optimizer.load_state_dict(optimizer_state)
    lr_scheduler.load_state_dict(lr_scheduler_state)
    return optimizer, lr_scheduler
194
+
195
+
196
def save_ta_ddp_checkpoint(self_model, tokenizer, args, output_dir: Optional[str] = None):
    """Save a TorchAcc DDP-trained model (and tokenizer) to *output_dir*,
    moving XLA tensors to CPU and using ``xm.save`` for rank-safe writes."""
    output_dir = output_dir if output_dir is not None else args.output_dir
    import torch_xla.core.xla_model as xm

    model = self_model

    # Only the global master creates the dir and stores the training args.
    if xm.is_master_ordinal(local=False):
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    # Flush pending XLA computations before touching the weights.
    xm.mark_step()
    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    supported_classes = (PreTrainedModel, PeftModel)
    if not isinstance(model, supported_classes):
        if isinstance(unwrap_model(model), supported_classes):
            # Wrapped (e.g. DDP) model whose inner module supports save_pretrained.
            unwrap_model(model).save_pretrained(
                output_dir,
                is_main_process=args.should_save,
                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
                save_function=xm.save,
                safe_serialization=args.save_safetensors,
            )
        else:
            logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
            state_dict = xm._maybe_convert_to_cpu(model.state_dict())
            if args.save_safetensors:
                safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
            else:
                torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
    else:
        model.save_pretrained(
            output_dir,
            is_main_process=args.should_save,
            save_function=xm.save,
            safe_serialization=args.save_safetensors,
            state_dict=xm._maybe_convert_to_cpu(model.state_dict()))
    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(output_dir)
235
+
236
+
237
def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir):
    """Save a TorchAcc FSDP-trained model: per-rank shards, then a consolidated copy.

    Every rank writes its own shard + shard metadata; rank 0 then consolidates
    the shards into one full state dict and saves it in HF format. Shard files
    are removed once consolidation is done.
    """
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

    # Flush pending lazy ops before reading state dicts.
    xm.mark_step()

    if xm.is_master_ordinal(local=False):
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    supported_classes = (PreTrainedModel, PeftModel)
    # NOTE(review): .module.module unwraps two wrapper layers of the TorchAcc
    # underlay model — confirm against the TorchAcc wrapping structure.
    model = self_model._get_underlay_model().module.module
    unwrapped_model = unwrap_model(model)

    xm.rendezvous('saving_checkpoint')
    ckpt = {
        'model': self_model._get_underlay_model().state_dict(),
        'shard_metadata': self_model._get_underlay_model().get_shard_metadata(),
    }
    # Shard file name encodes rank and world size; adapter vs. full weights
    # depends on whether this is a PEFT model.
    if isinstance(model, PeftModel):
        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-adapter_model.bin')
    else:
        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-pytorch_model.bin')
    xm.save(ckpt, ckpt_path, master_only=False)
    # Make sure all ranks have saved checkpoints
    xm.rendezvous('save_full_checkpoints')

    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save)

    # rank 0 consolidates and saves the whole checkpoint.
    if xm.is_master_ordinal(local=False):
        if isinstance(model, PeftModel):
            ckpt_suffix = 'rank*-of-*-adapter_model.bin'
        else:
            ckpt_suffix = 'rank*-of-*-pytorch_model.bin'
        full_state_dict, _ = consolidate_sharded_model_checkpoints(
            ckpt_prefix=os.path.join(output_dir, ''), ckpt_suffix=ckpt_suffix, save_model=False)

        if isinstance(unwrapped_model, supported_classes):
            unwrapped_model.save_pretrained(
                output_dir,
                state_dict=full_state_dict,
                save_function=xm.save,
                safe_serialization=args.save_safetensors,
            )
        else:
            logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
            if args.save_safetensors:
                safetensors.torch.save_file(full_state_dict, os.path.join(output_dir, 'model.safetensors'))
            else:
                torch.save(full_state_dict, os.path.join(output_dir, 'pytorch_model.bin'))

    # Barrier so no rank deletes its shard before consolidation finishes.
    xm.rendezvous('ckpt_consolidation')
    # delete the sharded checkpoint.
    os.remove(ckpt_path)
293
+
294
+
295
def ta_trim_graph():
    """Cut the accumulated lazy graph when TorchAcc graph trimming is enabled."""
    if not (use_torchacc() and torchacc_trim_graph()):
        return
    import torchacc as ta
    ta.mark_step()
299
+
300
+
301
+ # Model patch
302
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
307
+
308
+
309
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine part of the rotary embedding.
        sin: Sine part of the rotary embedding.
        position_ids: Optional position indices; when given, ``cos``/``sin`` are
            gathered at these positions first (e.g. for KV-cache offsets).
        unsqueeze_dim: Dimension along which ``cos``/``sin`` are unsqueezed so
            they broadcast against q/k. Use 1 for [batch, heads, seq, head_dim]
            layouts and 2 for [batch, seq, heads, head_dim] layouts.

    Returns:
        Tuple ``(q_embed, k_embed)`` of the rotated query and key tensors.
    """
    if position_ids is None:
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    else:
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
339
+
340
+
341
def patch_acc_model(args, model):
    """Apply TorchAcc flash-attention patches for the given model family.

    Dispatches on the ``args.model_type`` prefix; unrecognized types are
    returned unpatched.
    """
    if not args.use_flash_attn:
        # Fix: `Logger.warn` is deprecated in favor of `Logger.warning`.
        logger.warning('Currently use flash attn for torchacc.')
    # startswith accepts a tuple of prefixes — one call instead of an `or` chain.
    if args.model_type.startswith(('qwen1half', 'qwen2')):
        model = patch_qwen2_model(model)
    elif args.model_type.startswith('qwen'):
        import torchacc as ta
        model = ta.patch_qwen_model(model)
    elif args.model_type.startswith('baichuan'):
        model = patch_baichuan_model(model)
    elif args.model_type.startswith(('llama', 'yi')):
        model = patch_llama_model(model)
    elif args.model_type.startswith('chatglm'):
        model = patah_chatglm_model(model)
    return model
356
+
357
+
358
def patch_llama_model(model):
    """Patch a LLaMA-family model to use TorchAcc varlen flash attention.

    Replaces every layer's ``self_attn.forward`` and (on transformers>=4.38)
    disables causal-mask construction, since the mask is unsupported here.
    """

    def update_causal_mask(self, *args, **kwargs):
        # attention_mask is not supported in TorchAcc.
        return None

    def llama_attn_forward(self,
                           hidden_states: torch.Tensor,
                           attention_mask: Optional[torch.Tensor] = None,
                           position_ids: Optional[torch.Tensor] = None,
                           past_key_value: Optional[Tuple[torch.Tensor]] = None,
                           output_attentions: bool = False,
                           use_cache: bool = False,
                           cache_position: Optional[torch.LongTensor] = None,
                           **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        from torchacc.ops import flash_attn_varlen_xla
        import einops

        bsz, q_len, _ = hidden_states.size()

        # Project to q/k/v and reshape to [batch, heads, seq, head_dim].
        query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        key_states = (
            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))
        value_states = (
            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        assert past_key_value is None, 'past_key_value is not supported'

        # transformers 4.36 changed the rotary_emb call signature.
        if version.parse(transformers.__version__) >= version.parse('4.36'):
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        assert not output_attentions, 'output_attentions is not supported'

        # Unreachable given the assert above; kept for structural parity with HF.
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # See https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
        # if attention_mask is not None:
        #     value_states = value_states * attention_mask.unsqueeze(1).unsqueeze(-1)
        # Flatten batch and sequence into one varlen dimension for flash attn.
        q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...')
        k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...')
        v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...')
        max_s = q_len
        # Cumulative sequence lengths: every sample contributes exactly q_len tokens.
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True)
        output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)

        return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn)

    if version.parse(transformers.__version__) >= version.parse('4.38'):
        model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)

    return model
422
+
423
+
424
def patah_chatglm_model(model):
    """Patch a ChatGLM model to use TorchAcc qkv-packed flash attention.

    NOTE(review): function name is a typo for `patch_chatglm_model`; kept as-is
    because callers reference this spelling.
    """

    def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
        # x: [sq, b, np, hn]
        sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
        rot_dim = rope_cache.shape[-2] * 2
        # Only the first rot_dim channels are rotated; the rest pass through.
        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # truncate to support variable sizes
        rope_cache = rope_cache[:sq]
        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
        # Complex rotation expressed on interleaved (real, imag) pairs.
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        x_out2 = x_out2.flatten(3)
        return torch.cat((x_out2, x_pass), dim=-1)

    def chatglm_attn_forward(self,
                             hidden_states,
                             attention_mask,
                             rotary_pos_emb,
                             kv_cache=None,
                             use_cache=True,
                             **kwargs):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            # MQA: q has full head count, k/v share a smaller group count.
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                                                      self.hidden_size_per_attention_head))
            key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                                                self.hidden_size_per_attention_head))
            value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                                                      self.hidden_size_per_attention_head))
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                                            3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            # Expand grouped k/v heads to the full head count before attention.
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition,
                                                                            self.hidden_size_per_attention_head))
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            value_layer = value_layer.contiguous().view(value_layer.size()[:2]
                                                        + (self.num_attention_heads_per_partition,
                                                           self.hidden_size_per_attention_head))

        # ==================================
        # core attention computation
        # ==================================

        from torchacc.ops import flash_attn_varlen_qkvpacked_xla
        import einops

        # [sq, b, np, hn] -> [b, np, sq, hn], then pack qkv for flash attn.
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        bsz, _, q_len, _ = query_layer.size()
        qkv = torch.stack([query_layer, key_layer, value_layer], dim=2)
        qkv = qkv.transpose(1, 3)
        qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
        # Every sample contributes exactly q_len tokens (no padding handling).
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
        context_layer = flash_attn_varlen_qkvpacked_xla(
            qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True)
        context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz)
        # Back to ChatGLM's [sq, b, ...] layout.
        context_layer = context_layer.permute(1, 0, 2, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache

    def torchacc_swiglu(x):
        # SwiGLU: silu(first half) * second half.
        x = torch.chunk(x, 2, dim=-1)
        return F.silu(x[0]).to(x[0].dtype) * x[1]

    # patch attention
    for layer in model.transformer.encoder.layers:
        layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention)
        layer.mlp.activation_func = torchacc_swiglu

    return model
553
+
554
+
555
def patch_baichuan_model(model):
    """Patch a Baichuan model to use TorchAcc varlen flash attention."""

    def baichuan_attn_forward(self,
                              hidden_states: torch.Tensor,
                              attention_mask: Optional[torch.Tensor] = None,
                              past_key_value: Optional[Tuple[torch.Tensor]] = None,
                              output_attentions: bool = False,
                              use_cache: bool = False,
                              **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        import einops

        bsz, q_len, _ = hidden_states.size()

        # Baichuan fuses q/k/v into one W_pack projection; split it into three.
        proj = self.W_pack(hidden_states)
        proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2))
        query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        from torchacc.ops import flash_attn_varlen_xla
        # [b, h, s, d] -> [b, s, h, d], then flatten batch+seq for varlen attn.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
        # Every sample contributes exactly q_len tokens.
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True)
        output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)
        output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)'))
        # Flash attention cannot return attention weights, hence None.
        return output, None, past_key_value

    for layer in model.base_model.layers:
        layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn)

    return model
602
+
603
+
604
def patch_qwen2_model(model):
    """Patch a Qwen2-family model to use TorchAcc varlen flash attention.

    Replaces every layer's ``self_attn.forward``. On transformers>=4.43 the
    causal-mask builder is disabled; on older versions the whole model forward
    is replaced to skip mask construction.
    """

    def update_causal_mask(self, *args, **kwargs):
        # attention_mask is not supported in TorchAcc.
        return None

    def qwen2_attn_forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
                    'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
                    'with a layer index.')
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        rotary_seq_len = kv_seq_len + 1

        # transformers 4.45 moved rotary embeddings to precomputed position_embeddings.
        if version.parse(transformers.__version__) >= version.parse('4.45'):
            if position_embeddings is None:
                cos, sin = self.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, '_pre_quantization_dtype'):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        from torchacc.ops import flash_attn_varlen_xla
        import einops

        q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)

        attn_output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Fix: previously `attn_weights` was only bound when `output_attentions`
        # was falsy, raising NameError at the return below otherwise. Flash
        # attention never materializes the attention matrix, so weights are
        # always None regardless of `output_attentions`.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def qwen2_forward(self,
                      input_ids: torch.LongTensor = None,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_values: Optional[List[torch.FloatTensor]] = None,
                      inputs_embeds: Optional[torch.FloatTensor] = None,
                      use_cache: Optional[bool] = None,
                      output_attentions: Optional[bool] = None,
                      output_hidden_states: Optional[bool] = None,
                      return_dict: Optional[bool] = None,
                      **kwargs):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')

        # Caching is incompatible with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        past_key_values_length = 0

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        from transformers.modeling_outputs import BaseModelOutputWithPast
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn)

    if version.parse(transformers.__version__) >= version.parse('4.43'):
        model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)
    else:
        model.model.forward = types.MethodType(qwen2_forward, model.model)
    return model
815
+
816
+
817
def patch_clip_grad_norm(accelerator):
    """Patch `accelerator.clip_grad_norm_` with an XLA-aware implementation."""
    from accelerate.utils import DistributedType
    from accelerate.optimizer import AcceleratedOptimizer
    import torch_xla.core.xla_model as xm

    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        """
        Should be used in place of `torch.nn.utils.clip_grad_norm_`.

        Returns:
            `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).

        Example:

        ```python
        >>> from accelerate import Accelerator

        >>> accelerator = Accelerator(gradient_accumulation_steps=2)
        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)

        >>> for input, target in dataloader:
        ...     optimizer.zero_grad()
        ...     output = model(input)
        ...     loss = loss_func(output, target)
        ...     accelerator.backward(loss)
        ...     if accelerator.sync_gradients:
        ...         accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
        ...     optimizer.step()
        ```
        """
        if self.distributed_type == DistributedType.FSDP:
            self.unscale_gradients()
            parameters = [p for p in parameters]
            # FSDP must clip through the wrapped model so it can account for sharding.
            for model in self._models:
                if parameters == [p for p in model.parameters()]:
                    return model.clip_grad_norm_(max_norm, norm_type)
        elif self.distributed_type == DistributedType.DEEPSPEED:
            # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
            # We cannot return the gradient norm because DeepSpeed does it.
            return None
        elif self.distributed_type == DistributedType.XLA:
            # Reduce gradients first for XLA
            for acc_opt in self._optimizers:
                if not acc_opt.gradient_state.is_xla_gradients_synced:
                    opt = acc_opt
                    # Unwrap possibly-nested AcceleratedOptimizer to the raw optimizer.
                    while isinstance(opt, AcceleratedOptimizer):
                        opt = opt.optimizer
                    gradients = xm._fetch_gradients(opt)
                    # Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor
                    # one by one in self.reduce is non-inplace.
                    xm.all_reduce('sum', gradients, scale=1.0 / self.num_processes)
                    # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
                    acc_opt.gradient_state.is_xla_gradients_synced = True
            if os.environ.get('ACCELERATE_USE_FSDP', 'false') == 'true':
                self.unscale_gradients()
                parameters = [p for p in parameters]
                for model in self._models:
                    if parameters == [p for p in model.parameters()]:
                        return model._get_underlay_model().clip_grad_norm_(max_norm, norm_type)
        # Fallback: plain (non-FSDP) clipping on the given parameters.
        self.unscale_gradients()
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

    # TODO(baole): This should be removed once accelerate is updated.
    accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
    return accelerator
882
+
883
+
884
def ta_accelerate(model,
                  fsdp_num,
                  layer_cls_name,
                  bf16=True,
                  fp16=False,
                  gradient_checkpointing=True,
                  fsdp_flatten_parameters=False):
    """ accelerate LLM training using TorchAcc(only available internally).
    """
    import torchacc as ta
    assert layer_cls_name is not None

    config = ta.Config()
    config.compute.fp16 = fp16
    config.compute.bf16 = bf16

    # Gradient checkpointing is wrapped around the given transformer-layer class.
    config.memory.gc = gradient_checkpointing
    if config.memory.gc:
        config.memory.gc_cls = {layer_cls_name}

    # FSDP shards across `fsdp_num` devices; plain data parallelism is disabled.
    config.dist.fsdp.size = fsdp_num
    config.dist.fsdp.wrap_layer_cls = {layer_cls_name}
    config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters
    config.dist.dp.size = 1

    if fsdp_num > 1:
        os.environ['ACCELERATE_USE_FSDP'] = 'true'

    return ta.accelerate(model, config=config)
ms-swift/tests/eval/test_eval.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4
+
5
+ infer_backend = 'vllm'
6
+
7
+
8
def test_eval_native():
    """Smoke-test the Native eval backend on a small model with a tiny limit."""
    from swift.llm import EvalArguments, eval_main
    args = EvalArguments(
        model='Qwen/Qwen2.5-0.5B-Instruct',
        eval_dataset='arc',
        infer_backend=infer_backend,
        eval_backend='Native',
        eval_limit=10,
        eval_generation_config={'max_new_tokens': 128, 'temperature': 0.1},
        extra_eval_args={'stream': True, 'ignore_errors': True},
    )
    eval_main(args)
26
+
27
+
28
def test_eval_llm():
    """Smoke-test the OpenCompass eval backend on an LLM."""
    from swift.llm import EvalArguments, eval_main
    args = EvalArguments(
        model='Qwen/Qwen2-7B-Instruct',
        eval_dataset='arc_c',
        infer_backend=infer_backend,
        eval_backend='OpenCompass',
        eval_limit=10,
    )
    eval_main(args)
37
+
38
+
39
def test_eval_mllm():
    """Smoke-test the VLMEvalKit eval backend on a multimodal model."""
    from swift.llm import EvalArguments, eval_main
    args = EvalArguments(
        model='Qwen/Qwen2.5-VL-3B-Instruct',
        eval_dataset=['realWorldQA'],
        infer_backend='pt',
        eval_backend='VLMEvalKit',
        eval_limit=10,
        eval_generation_config={'max_new_tokens': 128, 'temperature': 0.1},
    )
    eval_main(args)
52
+
53
+
54
def test_eval_url():
    """Smoke-test evaluating against a locally deployed inference endpoint."""
    from swift.llm import EvalArguments, eval_main, DeployArguments, run_deploy
    deploy_args = DeployArguments(model='Qwen/Qwen2-VL-7B-Instruct', infer_backend=infer_backend, verbose=False)

    # run_deploy yields the endpoint URL and tears the server down on exit.
    with run_deploy(deploy_args, return_url=True) as url:
        eval_args = EvalArguments(model='Qwen2-VL-7B-Instruct', eval_url=url, eval_dataset=['arc_c'])
        eval_main(eval_args)
60
+
61
+
62
if __name__ == '__main__':
    # Manual smoke tests: only one case is enabled at a time; toggle the
    # comments to choose which scenario to run locally.
    # test_eval_llm()
    test_eval_mllm()
    # test_eval_url()
    # test_eval_native()
ms-swift/tests/export/test_quant.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Literal

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def test_llm_quant(quant_method: Literal['gptq', 'awq'] = 'awq'):
    """4-bit quantization of a text LLM with bilingual calibration data."""
    from swift.llm import ExportArguments, export_main
    args = ExportArguments(
        model='Qwen/Qwen2-7B-Instruct',
        quant_bits=4,
        dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000'],
        quant_method=quant_method)
    export_main(args)


def test_vlm_quant(quant_method: Literal['gptq', 'awq'] = 'awq'):
    """4-bit quantization of a vision-language model."""
    from swift.llm import ExportArguments, export_main
    args = ExportArguments(
        model='Qwen/Qwen2-VL-7B-Instruct',
        quant_bits=4,
        dataset=['modelscope/coco_2014_caption:validation#1000'],
        quant_method=quant_method)
    export_main(args)


def test_audio_quant(quant_method: Literal['gptq', 'awq'] = 'awq'):
    """4-bit quantization of an audio model with ASR calibration data."""
    from swift.llm import ExportArguments, export_main
    args = ExportArguments(
        model='Qwen/Qwen2-Audio-7B-Instruct',
        quant_bits=4,
        dataset=['speech_asr/speech_asr_aishell1_trainsets:validation#1000'],
        quant_method=quant_method)
    export_main(args)


def test_vlm_bnb_quant():
    """4-bit bitsandbytes quantization (needs no calibration dataset)."""
    from swift.llm import ExportArguments, InferArguments, export_main, infer_main
    export_main(ExportArguments(model='Qwen/Qwen2-VL-7B-Instruct', quant_bits=4, quant_method='bnb'))

    # infer_main(InferArguments(ckpt_dir='Qwen/Qwen2-VL-7B-Instruct-bnb-int4'))


def test_bert():
    """Merge a BERT LoRA adapter, then GPTQ-quantize the merged model."""
    from swift.llm import ExportArguments, export_main
    output_dir = 'output/swift_test_bert_merged'
    export_main(ExportArguments(adapters='swift/test_bert', merge_lora=True, output_dir=output_dir))
    export_main(
        ExportArguments(model=output_dir, load_data_args=True, quant_bits=4, quant_method='gptq', max_length=512))


def test_reward_model():
    """GPTQ-quantize a reward model with bilingual calibration data."""
    from swift.llm import ExportArguments, export_main

    args = ExportArguments(
        model='Shanghai_AI_Laboratory/internlm2-1_8b-reward',
        dataset=['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000'],
        quant_bits=4,
        quant_method='gptq')
    export_main(args)


if __name__ == '__main__':
    # test_llm_quant('gptq')
    # test_vlm_quant('gptq')
    # test_audio_quant('gptq')
    # test_vlm_bnb_quant()
    # test_bert()
    test_reward_model()
ms-swift/tests/general/test_arch.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def test_model_arch():
    """Cross-check each registered model's config against MODEL_MAPPING.

    For every model type, fetch only the config of one random model per
    model group and compare `config.architectures` with the architectures
    declared in the model meta.  Mismatches and download errors are appended
    to ``model_arch.jsonl`` for offline inspection.
    """
    import random

    from transformers import PretrainedConfig

    from swift.llm import MODEL_MAPPING, safe_snapshot_download
    from swift.utils import JsonlWriter

    jsonl_writer = JsonlWriter('model_arch.jsonl')
    for i, (model_type, model_meta) in enumerate(MODEL_MAPPING.items()):
        if i < 0:  # debug knob: raise the threshold to resume a partial run
            continue
        arch_list = model_meta.architectures
        for model_group in model_meta.model_groups:
            model = random.choice(model_group.models).ms_model_id
            config_dict = None
            try:
                model_dir = safe_snapshot_download(model, download_model=False)
                config_dict = PretrainedConfig.get_config_dict(model_dir)[0]
            except Exception:
                # download/parse failures are recorded below as an error entry
                pass
            finally:
                msg = None
                if config_dict:
                    arch = config_dict.get('architectures')
                    # Mismatch when the config declares an arch we don't know,
                    # or declares none while the meta expects some.  The two
                    # original branches built identical dicts, so they are
                    # merged; `or []` guards arch_list being None.
                    if (arch and arch[0] not in (arch_list or [])) or (not arch and arch_list):
                        msg = {
                            'model_type': model_type,
                            'model': model,
                            'config_arch': arch,
                            'architectures': arch_list
                        }
                else:
                    msg = {'msg': 'error', 'model_type': model_type, 'model': model, 'arch_list': arch_list}
                if msg:
                    jsonl_writer.append(msg)


if __name__ == '__main__':
    test_model_arch()
ms-swift/tests/general/test_dataset.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List

from swift.llm import load_dataset


def _test_dataset(datasets: List[str], num_proc: int = 1, strict: bool = False, **kwargs):
    """Load `datasets` and print the first two rows as a smoke check."""
    dataset = load_dataset(datasets, num_proc=num_proc, strict=strict, **kwargs)
    for idx in range(2):
        print(f'dataset[{idx}]: {dataset[idx]}')


def test_sft():
    """SFT-style datasets; commented entries are alternative smoke targets."""
    # swift/SlimOrca swift/cosmopedia-100k
    # _test_dataset(['lvjianjin/AdvertiseGen'])
    # _test_dataset(['AI-ModelScope/Duet-v0.5'])
    # _test_dataset(['swift/SlimOrca', 'swift/cosmopedia-100k'])
    # _test_dataset(['OmniData/Zhihu-KOL-More-Than-100-Upvotes'])
    # _test_dataset(['OmniData/Zhihu-KOL'])
    _test_dataset([
        'AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000',
        'AI-ModelScope/LongAlpaca-12k#1000'
    ])
    # _test_dataset(['swift/Infinity-Instruct:all'])
    # _test_dataset(['swift/sharegpt:all'])
    # _test_dataset(['AI-ModelScope/sharegpt_gpt4:all'])
    # _test_dataset(['iic/ms_bench'])
    # _test_dataset(['swift/tagengo-gpt4'])


def test_mllm():
    """Multimodal datasets (image / video / audio)."""
    # _test_dataset(['AI-ModelScope/ShareGPT4V:all'])
    # _test_dataset(['AI-ModelScope/LLaVA-Pretrain'])
    # _test_dataset(['swift/TextCaps'])
    # _test_dataset(['swift/RLAIF-V-Dataset:all'])
    # _test_dataset(['swift/OK-VQA_train'])
    # _test_dataset(['swift/OCR-VQA'])
    # _test_dataset(['swift/A-OKVQA'])
    # _test_dataset(['AI-ModelScope/MovieChat-1K-test'])
    _test_dataset([
        'AI-ModelScope/LaTeX_OCR:all', 'modelscope/coco_2014_caption:validation',
        'speech_asr/speech_asr_aishell1_trainsets:validation'
    ],
                  strict=False)
    # _test_dataset(['swift/VideoChatGPT:all'])
    # _test_dataset(['speech_asr/speech_asr_aishell1_trainsets:validation'])
    # _test_dataset(['AI-ModelScope/captcha-images'])
    # _test_dataset(['swift/gpt4v-dataset:all'])
    # _test_dataset(['modelscope/coco_2014_caption:validation'])
    # _test_dataset(['AI-ModelScope/LLaVA-Instruct-150K'], num_proc=16)


def test_agent():
    """Agent / tool-calling datasets."""
    _test_dataset(['swift/ToolBench'])
    # _test_dataset(['AI-ModelScope/ms_agent_for_agentfabric:all'])


def test_dpo():
    """Preference-pair (DPO/ORPO) datasets."""
    _test_dataset(['AI-ModelScope/orpo-dpo-mix-40k'])
    _test_dataset(['AI-ModelScope/hh-rlhf:all'])
    _test_dataset(['AI-ModelScope/hh_rlhf_cn:all'])
    _test_dataset(['hjh0119/shareAI-Llama3-DPO-zh-en-emoji:all'])


def test_kto():
    """KTO-format preference dataset."""
    _test_dataset(['AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto'])


def test_pretrain():
    """Plain-text pretraining dataset."""
    _test_dataset(['AI-ModelScope/ruozhiba:all'])


def test_dataset_info():
    """Self-cognition dataset with name/author placeholders filled in."""
    _test_dataset(['swift/self-cognition#500'], model_name='xiao huang', model_author='swift')
    # _test_dataset(['codefuse-ai/CodeExercise-Python-27k'])


def test_cls():
    """Classification-style subsets of the same source dataset."""
    _test_dataset(['simpleai/HC3-Chinese:baike'])
    _test_dataset(['simpleai/HC3-Chinese:baike_cls'])


if __name__ == '__main__':
    # test_sft()
    # test_agent()
    # test_dpo()
    # test_kto()
    test_mllm()
    # test_pretrain()
    # test_dataset_info()
    # test_cls()
ms-swift/tests/general/test_model.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import torch

from swift.utils import get_device

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'


def test_qwen2():
    """Load the Qwen2 tokenizer from both hubs, then the fp32 model itself.

    Fix: the redundant function-local ``import os`` (shadowing the
    module-level import, and unused here) has been removed.
    """
    from swift.llm import get_model_tokenizer
    model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False)
    print(f'model: {model}, tokenizer: {tokenizer}')
    # test hf
    model, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False, use_hf=True)

    # Full model load: fp32 weights, single-device placement, flash attention.
    model, tokenizer = get_model_tokenizer(
        'Qwen/Qwen2-7B-Instruct', torch.float32, device_map=get_device(), attn_impl='flash_attn')
    print(f'model: {model}, tokenizer: {tokenizer}')


def test_modelscope_hub():
    """Resolve a ModelScope cache-style path (note the `___` id mangling)."""
    from swift.llm import get_model_tokenizer
    model, tokenizer = get_model_tokenizer('Qwen/Qwen2___5-Math-1___5B-Instruct/', load_model=False)


if __name__ == '__main__':
    test_qwen2()
    # test_modelscope_hub()
ms-swift/tests/general/test_stream.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from swift.llm import load_dataset


def test_local_dataset():
    """Stream a dataset from a locally cloned repository."""
    # please use git clone
    from swift.llm import git_clone_github
    repo_dir = git_clone_github('https://www.modelscope.cn/datasets/swift/swift-sft-mixture.git')
    streamed = load_dataset(datasets=[f'{repo_dir}:firefly'], streaming=True)[0]
    print(next(iter(streamed)))


def test_hub_dataset():
    """Stream the same dataset directly from the hub."""
    dataset_id = 'swift/swift-sft-mixture:firefly'
    streamed = load_dataset(datasets=[dataset_id], streaming=True)[0]
    print(next(iter(streamed)))


if __name__ == '__main__':
    test_local_dataset()
    # test_hub_dataset()
ms-swift/tests/general/test_template.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from datasets import Dataset

from swift.llm import EncodePreprocessor, TemplateInputs, get_model_tokenizer, get_template, load_dataset


def test_template():
    """Encode a multi-turn text conversation and decode it for inspection."""
    _, tokenizer = get_model_tokenizer('Qwen/Qwen2-7B-Instruct', load_model=False)
    template = get_template(tokenizer.model_meta.template, tokenizer)
    messages = [
        {'role': 'system', 'content': 'AAA'},
        {'role': 'user', 'content': 'BBB'},
        {'role': 'assistant', 'content': 'CCC'},
        {'role': 'user', 'content': 'DDD'},
    ]
    inputs = template.encode(TemplateInputs(messages))
    print(f'inputs.keys(): {inputs.keys()}')
    print(tokenizer.decode(inputs['input_ids']))


def test_mllm():
    """Encode a conversation containing an <image> tag with a VL template."""
    _, tokenizer = get_model_tokenizer('Qwen/Qwen2-VL-7B-Instruct', load_model=False)
    template = get_template(tokenizer.model_meta.template, tokenizer)
    messages = [
        {'role': 'system', 'content': 'AAA'},
        {'role': 'user', 'content': '<image>BBB'},
        {'role': 'assistant', 'content': 'CCC'},
        {'role': 'user', 'content': 'DDD'},
    ]
    template_inputs = TemplateInputs(
        messages, images=['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'])
    inputs = template.encode(template_inputs)
    print(f'inputs.keys(): {inputs.keys()}')
    print(template.safe_decode(inputs['input_ids']))


def _test_dataset_map(model_id: str, dataset_id: str):
    """Map a dataset through the template encoder and decode two rows."""
    _, tokenizer = get_model_tokenizer(model_id, load_model=False)
    template = get_template(tokenizer.model_meta.template, tokenizer)
    dataset = load_dataset([dataset_id], num_proc=2)[0]

    # throughput notes from earlier runs:
    # 1: 1500
    # 16: 10766.36 examples/s
    new_dataset = EncodePreprocessor(template)(dataset, num_proc=4)
    print(f'new_dataset: {new_dataset}')
    for row_idx in (0, 1):
        print(template.safe_decode(new_dataset[row_idx]['input_ids']))


def test_llm_dataset_map():
    """Dataset-map smoke test with a text model/dataset pair."""
    _test_dataset_map('Qwen/Qwen2-7B-Instruct', 'AI-ModelScope/alpaca-gpt4-data-zh')


def test_mllm_dataset_map():
    """Dataset-map smoke test with a vision model/dataset pair."""
    _test_dataset_map('Qwen/Qwen2-VL-7B-Instruct', 'modelscope/coco_2014_caption:validation#100')


if __name__ == '__main__':
    # test_template()
    # test_mllm()
    # test_llm_dataset_map()
    test_mllm_dataset_map()
ms-swift/tests/hub/__init__.py ADDED
File without changes
ms-swift/tests/hub/test_check_model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import shutil
import tempfile
import unittest

from modelscope import Model, check_local_model_is_latest


class TestCheckModel(unittest.TestCase):
    """Checks that a pinned old model revision is reported as not latest."""

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        # Fix: the original used tempfile.TemporaryDirectory().name, which
        # discards the manager object — its finalizer may delete the directory
        # at any point during the test.  mkdtemp creates a persistent
        # directory atomically, so the exists/makedirs dance is unnecessary.
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        # NOTE(review): unused import preserved from the original — presumably
        # forces peft to load before cleanup; confirm before removing.
        import peft  # noqa: F401
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    def test_check_model(self):
        model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base', revision='v1.0.0')
        self.assertFalse(check_local_model_is_latest(model.model_dir))
ms-swift/tests/infer/test_infer.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Literal

import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']):
    """Build an engine for the chosen backend plus its template and requests."""
    from swift.llm import InferRequest, get_template
    if infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        engine = LmdeployEngine('OpenGVLab/InternVL2_5-2B', torch.float32)
    elif infer_backend == 'pt':
        from swift.llm import PtEngine
        engine = PtEngine('Qwen/Qwen2-7B-Instruct', max_batch_size=16)
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine('Qwen/Qwen2-7B-Instruct')
    template = get_template(engine.model_meta.template, engine.tokenizer)
    # InferRequest([{'role': 'user', 'content': '晚上睡不着觉怎么办'}]) for i in range(100)
    infer_requests = [
        InferRequest([{
            'role': 'user',
            'content': 'hello! who are you'
        }]) for _ in range(100)
    ]
    return engine, template, infer_requests


def test_infer(infer_backend):
    """Batch inference: print two sample completions and aggregate metrics."""
    from swift.llm import RequestConfig
    from swift.plugin import InferStats
    engine, template, infer_requests = _prepare(infer_backend=infer_backend)
    request_config = RequestConfig(temperature=0)
    infer_stats = InferStats()

    response_list = engine.infer(
        infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    for response in response_list[:2]:
        print(response.choices[0].message.content)
    print(infer_stats.compute())


def test_stream(infer_backend):
    """Streaming inference: print the first request's tokens as they arrive."""
    from swift.llm import RequestConfig
    from swift.plugin import InferStats
    engine, template, infer_requests = _prepare(infer_backend=infer_backend)
    infer_stats = InferStats()
    request_config = RequestConfig(temperature=0, stream=True, logprobs=True)

    gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    for response in gen_list[0]:
        if response is None:
            continue
        print(response.choices[0].delta.content, end='', flush=True)
    print()
    print(infer_stats.compute())

    # second pass: same requests, with a progress bar; results are discarded
    gen_list = engine.infer(
        infer_requests, template=template, request_config=request_config, use_tqdm=True, metrics=[infer_stats])

    for response in gen_list[0]:
        pass

    print(infer_stats.compute())


if __name__ == '__main__':
    test_infer('pt')
    # test_stream('pt')
ms-swift/tests/infer/test_logprobs.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Literal

import torch

if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']):
    """Create an engine for the chosen backend plus a template and requests."""
    from swift.llm import InferRequest, get_template

    if infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        engine = LmdeployEngine('Qwen/Qwen2-7B-Instruct', torch.float32)
    elif infer_backend == 'pt':
        from swift.llm import PtEngine
        engine = PtEngine('Qwen/Qwen2-7B-Instruct')
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine('Qwen/Qwen2-7B-Instruct')
    template = get_template(engine.model_meta.template, engine.tokenizer)
    infer_requests = [
        InferRequest([{
            'role': 'user',
            'content': '晚上睡不着觉怎么办'
        }]),
        InferRequest([{
            'role': 'user',
            'content': 'hello! who are you'
        }]),
    ]
    return engine, template, infer_requests


def test_infer(engine, template, infer_requests):
    """Non-streaming inference with per-token logprobs enabled."""
    from swift.llm import RequestConfig
    from swift.plugin import InferStats

    request_config = RequestConfig(temperature=0, logprobs=True, top_logprobs=2)
    infer_stats = InferStats()

    response_list = engine.infer(
        infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    for response in response_list[:2]:
        print(response.choices[0].message.content)
    print(infer_stats.compute())


def test_stream(engine, template, infer_requests):
    """Streaming inference with logprobs; prints the first request's tokens."""
    from swift.llm import RequestConfig
    from swift.plugin import InferStats

    infer_stats = InferStats()
    request_config = RequestConfig(temperature=0, stream=True, logprobs=True, top_logprobs=2)

    gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    for response in gen_list[0]:
        if response is None:
            continue
        print(response.choices[0].delta.content, end='', flush=True)

    print(infer_stats.compute())


if __name__ == '__main__':
    engine, template, infer_requests = _prepare(infer_backend='pt')
    test_infer(engine, template, infer_requests)
    test_stream(engine, template, infer_requests)