Student0809 committed on
Commit
2742ed8
·
verified ·
1 Parent(s): a050167

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ms-swift/.ipynb_checkpoints/clean_transcripts-checkpoint.py +95 -0
  2. ms-swift/.ipynb_checkpoints/dataset_new-checkpoint.json +0 -0
  3. ms-swift/silence_overlaps/delete_transcript.json +0 -0
  4. ms-swift/swift/llm/train/__pycache__/__init__.cpython-310.pyc +0 -0
  5. ms-swift/swift/llm/train/__pycache__/kto.cpython-310.pyc +0 -0
  6. ms-swift/swift/megatron/__init__.py +35 -0
  7. ms-swift/swift/megatron/argument/megatron_args.py +253 -0
  8. ms-swift/swift/megatron/model/gpt/mcore2hf.py +70 -0
  9. ms-swift/swift/megatron/train/__init__.py +2 -0
  10. ms-swift/swift/megatron/train/pt.py +19 -0
  11. ms-swift/swift/megatron/train/sft.py +65 -0
  12. ms-swift/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc +0 -0
  13. ms-swift/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc +0 -0
  14. ms-swift/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc +0 -0
  15. ms-swift/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc +0 -0
  16. ms-swift/swift/plugin/agent_template/hermes.py +78 -0
  17. ms-swift/swift/plugin/agent_template/react.py +66 -0
  18. ms-swift/swift/plugin/loss_scale/__init__.py +1 -0
  19. ms-swift/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc +0 -0
  20. ms-swift/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc +0 -0
  21. ms-swift/swift/plugin/loss_scale/config/agentflan.json +22 -0
  22. ms-swift/swift/plugin/loss_scale/config/hermes.json +3 -0
  23. ms-swift/swift/plugin/prm.py +154 -0
  24. ms-swift/swift/plugin/rm_plugin.py +229 -0
  25. ms-swift/swift/trainers/__init__.py +49 -0
  26. ms-swift/swift/trainers/__pycache__/__init__.cpython-310.pyc +0 -0
  27. ms-swift/swift/trainers/__pycache__/callback.cpython-310.pyc +0 -0
  28. ms-swift/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc +0 -0
  29. ms-swift/swift/trainers/__pycache__/trainers.cpython-310.pyc +0 -0
  30. ms-swift/swift/trainers/callback.py +124 -0
  31. ms-swift/swift/trainers/mixin.py +516 -0
  32. ms-swift/swift/trainers/optimizers/__init__.py +1 -0
  33. ms-swift/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc +0 -0
  34. ms-swift/swift/trainers/optimizers/galore/__init__.py +28 -0
  35. ms-swift/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc +0 -0
  36. ms-swift/swift/trainers/optimizers/galore/adafactor.py +272 -0
  37. ms-swift/swift/trainers/optimizers/galore/galore_projector.py +109 -0
  38. ms-swift/swift/trainers/optimizers/galore/utils.py +214 -0
  39. ms-swift/swift/trainers/rlhf_arguments.py +63 -0
  40. ms-swift/swift/trainers/rlhf_trainer/kto_trainer.py +69 -0
  41. ms-swift/swift/trainers/rlhf_trainer/orpo_trainer.py +19 -0
  42. ms-swift/swift/trainers/rlhf_trainer/ppo_trainer.py +65 -0
  43. ms-swift/swift/trainers/rlhf_trainer/reward_trainer.py +78 -0
  44. ms-swift/swift/trainers/rlhf_trainer/rlhf_mixin.py +104 -0
  45. ms-swift/swift/trainers/rlhf_trainer/utils.py +132 -0
  46. ms-swift/swift/trainers/rlhf_trainer/vllm_client.py +212 -0
  47. ms-swift/swift/trainers/sequence_parallel/base.py +45 -0
  48. ms-swift/swift/trainers/sequence_parallel/ulysses.py +594 -0
  49. ms-swift/swift/trainers/sequence_parallel/xtuner.py +127 -0
  50. ms-swift/swift/trainers/torchacc_mixin.py +156 -0
ms-swift/.ipynb_checkpoints/clean_transcripts-checkpoint.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from typing import List, Dict, Tuple
4
+
5
+ def parse_timestamp(timestamp: str) -> Tuple[int, int]:
6
+ """Convert timestamp string like '00:15' to seconds."""
7
+ minutes, seconds = map(int, timestamp.split(':'))
8
+ return minutes * 60 + seconds
9
+
10
+ def extract_time_and_speaker(line: str) -> Tuple[Tuple[int, int], str]:
11
+ """Extract time range and speaker from a line."""
12
+ # Extract time range
13
+ time_match = re.match(r'\[(\d{2}:\d{2}) - (\d{2}:\d{2})\] (Speaker [A-Z]):', line)
14
+ if not time_match:
15
+ return None, None
16
+
17
+ start_time = parse_timestamp(time_match.group(1))
18
+ end_time = parse_timestamp(time_match.group(2))
19
+ speaker = time_match.group(3)
20
+
21
+ return (start_time, end_time), speaker
22
+
23
+ def has_overlap(range1: Tuple[int, int], range2: Tuple[int, int]) -> bool:
24
+ """Check if two time ranges overlap."""
25
+ start1, end1 = range1
26
+ start2, end2 = range2
27
+ return not (end1 <= start2 or end2 <= start1)
28
+
29
+ def has_same_speaker_overlap(transcript: str) -> bool:
30
+ """Check if a transcript contains overlapping timestamps for the same speaker."""
31
+ lines = transcript.split('\n')
32
+ # Dictionary to store time ranges for each speaker
33
+ speaker_ranges = {}
34
+
35
+ for line in lines:
36
+ if not line.strip():
37
+ continue
38
+
39
+ time_range, speaker = extract_time_and_speaker(line)
40
+ if time_range is None or speaker is None:
41
+ continue
42
+
43
+ # Check for overlaps with existing ranges of the same speaker
44
+ if speaker in speaker_ranges:
45
+ for existing_range in speaker_ranges[speaker]:
46
+ if has_overlap(time_range, existing_range):
47
+ return True
48
+
49
+ speaker_ranges[speaker].append(time_range)
50
+ else:
51
+ speaker_ranges[speaker] = [time_range]
52
+
53
+ return False
54
+
55
+ def process_file(input_file: str, output_file: str, delete_file: str):
56
+ """Process the JSON file and separate entries with same-speaker overlapping timestamps."""
57
+ with open(input_file, 'r', encoding='utf-8') as f:
58
+ data = json.load(f)
59
+
60
+ if isinstance(data, dict):
61
+ data = [data]
62
+
63
+ cleaned_data = []
64
+ deleted_data = []
65
+ removed_count = 0
66
+
67
+ for entry in data:
68
+ if 'model_output' in entry:
69
+ if not has_same_speaker_overlap(entry['model_output']):
70
+ cleaned_data.append(entry)
71
+ else:
72
+ deleted_data.append(entry)
73
+ removed_count += 1
74
+ print(f"Removing entry with key: {entry.get('key', 'unknown')}")
75
+
76
+ # Save cleaned data
77
+ with open(output_file, 'w', encoding='utf-8') as f:
78
+ json.dump(cleaned_data, f, ensure_ascii=False, indent=2)
79
+
80
+ # Save deleted data
81
+ with open(delete_file, 'w', encoding='utf-8') as f:
82
+ json.dump(deleted_data, f, ensure_ascii=False, indent=2)
83
+
84
+ print(f"\nProcessing Summary:")
85
+ print(f"Processed {len(data)} entries")
86
+ print(f"Removed {removed_count} entries with same-speaker overlapping timestamps")
87
+ print(f"Remaining entries: {len(cleaned_data)}")
88
+
89
+ if __name__ == '__main__':
90
+ input_file = 'silence_overlaps/transcriptions.json'
91
+ output_file = 'silence_overlaps/cleaned_transcriptions2.json'
92
+ delete_file = 'silence_overlaps/delete_transcript2.json'
93
+ process_file(input_file, output_file, delete_file)
94
+ print(f"\nCleaned transcriptions have been saved to {output_file}")
95
+ print(f"Deleted entries have been saved to {delete_file}")
ms-swift/.ipynb_checkpoints/dataset_new-checkpoint.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/delete_transcript.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/swift/llm/train/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (392 Bytes). View file
 
ms-swift/swift/llm/train/__pycache__/kto.cpython-310.pyc ADDED
Binary file (2.94 kB). View file
 
ms-swift/swift/megatron/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Alibaba, Inc. and its affiliates.

# Import-time side effect: set up the Megatron environment before any other
# name in this package is touched.
try:
    from .init import init_megatron_env
    init_megatron_env()
except Exception:
    # allows lint pass.
    raise

from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
    # Static imports for type checkers / IDEs only; at runtime the same names
    # are resolved lazily via _LazyModule below.
    from .train import megatron_sft_main, megatron_pt_main
    from .utils import convert_hf2mcore, convert_mcore2hf
    from .argument import MegatronTrainArguments
    from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
else:
    # Submodule name -> public names it exports; consumed by _LazyModule.
    _import_structure = {
        'train': ['megatron_sft_main', 'megatron_pt_main'],
        'utils': ['convert_hf2mcore', 'convert_mcore2hf'],
        'argument': ['MegatronTrainArguments'],
        'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model']
    }

    import sys

    # Replace this module object with a lazy proxy so heavy submodules are
    # only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
ms-swift/swift/megatron/argument/megatron_args.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import sys
4
+ from dataclasses import asdict, dataclass, field
5
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from transformers.utils.versions import require_version
9
+
10
+ from swift.llm.argument.base_args import to_abspath
11
+
12
+
13
@dataclass
class ExtraMegatronArguments:
    """Swift-specific arguments that are NOT native Megatron CLI flags.

    `MegatronArguments._args_to_argv` routes fields declared here into the
    `extra_args` dict instead of the generated Megatron argv.
    """
    # Vocab size after padding for tensor parallelism; filled in elsewhere.
    padded_vocab_size: Optional[int] = None
    # Either a dict or its string/JSON form; parsed in __post_init__.
    rope_scaling: Optional[Union[dict, str]] = None
    torch_dtype: Optional[torch.dtype] = None

    dataloader_persistent_workers: bool = True
    dataloader_prefetch_factor: int = 10

    model_type: Optional[str] = None
    max_epochs: Optional[int] = None
24
+
25
+
26
@dataclass
class MegatronArguments(ExtraMegatronArguments):
    """Dataclass mirror of Megatron-LM's command-line arguments.

    Fields declared directly on this class are converted to `--flag value`
    argv entries by `_args_to_argv`; fields inherited from
    ExtraMegatronArguments are passed separately as `extra_args`.
    """
    # training
    micro_batch_size: int = 1
    global_batch_size: int = 16
    recompute_granularity: Literal['selective', 'full'] = 'selective'
    # Annotation widened to Optional: the default is None (no recompute method).
    recompute_method: Optional[Literal['uniform', 'block']] = None
    recompute_num_layers: Optional[int] = None
    recompute_modules: List[str] = field(default_factory=lambda: ['core_attn'])
    use_cpu_initialization: bool = False
    deterministic_mode: bool = False
    train_iters: Optional[int] = None
    log_interval: int = 5
    tensorboard_dir: Optional[str] = None
    no_masked_softmax_fusion: bool = False
    no_bias_dropout_fusion: bool = False
    no_bias_swiglu_fusion: bool = False
    no_rope_fusion: bool = False
    no_gradient_accumulation_fusion: bool = False
    cross_entropy_loss_fusion: bool = False
    calculate_per_token_loss: bool = True
    use_flash_attn: bool = False
    attention_backend: str = 'auto'  # flash, fused, unfused, local, auto
    optimizer: Literal['adam', 'sgd'] = 'adam'
    dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic'
    manual_gc: bool = False
    manual_gc_interval: int = 0

    # learning rate
    lr: float = 1e-5
    lr_decay_style: Literal['cosine', 'linear', 'constant'] = 'cosine'
    # The default is None, which will be set to `train_iters`.
    lr_decay_iters: Optional[int] = None
    lr_warmup_iters: int = 0
    min_lr: float = 0

    # regularization
    weight_decay: float = 0.1
    clip_grad: float = 1.
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1e-8
    sgd_momentum: float = 0.9

    # checkpoint
    save: Optional[str] = None
    save_interval: int = 500
    no_save_optim: bool = False
    no_save_rng: bool = False
    load: Optional[str] = None
    no_load_optim: bool = False
    no_load_rng: bool = False
    finetune: bool = False
    ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist'
    no_initialization: bool = True
    auto_detect_ckpt_format: bool = True
    exit_on_missing_checkpoint: bool = True

    # dist
    distributed_backend: Literal['nccl', 'gloo'] = 'nccl'
    use_distributed_optimizer: bool = True
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    decoder_first_pipeline_num_layers: Optional[int] = None
    decoder_last_pipeline_num_layers: Optional[int] = None
    sequence_parallel: bool = False
    context_parallel_size: int = 1
    tp_comm_overlap: bool = False
    overlap_grad_reduce: bool = False
    overlap_param_gather: bool = False
    distributed_timeout_minutes: int = 60

    # model
    num_layers: Optional[int] = None
    hidden_size: Optional[int] = None
    ffn_hidden_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    group_query_attention: Optional[bool] = None
    num_query_groups: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    position_embedding_type: Literal['learned_absolute', 'rope', 'relative', 'none'] = 'rope'
    rotary_base: Optional[int] = None
    rotary_percent: float = 1.
    normalization: Literal['LayerNorm', 'RMSNorm'] = 'RMSNorm'
    norm_epsilon: Optional[float] = None
    swiglu: Optional[bool] = None
    untie_embeddings_and_output_weights: Optional[bool] = None
    disable_bias_linear: Optional[bool] = None
    add_qkv_bias: Optional[bool] = None
    attention_dropout: Optional[float] = None
    hidden_dropout: float = 0.
    kv_channels: Optional[int] = None
    qk_layernorm: Optional[bool] = None
    transformer_impl: Literal['local', 'transformer_engine'] = 'transformer_engine'

    # moe
    num_experts: Optional[int] = None
    moe_ffn_hidden_size: Optional[int] = None
    moe_shared_expert_intermediate_size: Optional[int] = None
    moe_router_topk: Optional[int] = None
    moe_router_pre_softmax: Optional[bool] = None
    moe_aux_loss_coeff: Optional[float] = None

    expert_model_parallel_size: int = 1
    moe_token_dispatcher_type: Literal['allgather', 'alltoall', 'alltoall_seq'] = 'alltoall'
    moe_grouped_gemm: bool = False
    moe_router_load_balancing_type: Literal['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'] = 'aux_loss'
    moe_z_loss_coeff: Optional[float] = None
    moe_expert_capacity_factor: Optional[float] = None
    moe_shared_expert_overlap: bool = False

    # mixed precision
    fp16: Optional[bool] = None
    bf16: Optional[bool] = None
    apply_query_key_layer_scaling: Optional[bool] = None
    attention_softmax_in_fp32: bool = True

    # logging
    log_params_norm: bool = False
    log_throughput: bool = True
    tensorboard_log_interval: int = 1
    tensorboard_queue_size: int = 50
    log_timers_to_tensorboard: bool = True
    no_log_learning_rate_to_tensorboard: bool = False
    log_validation_ppl_to_tensorboard: bool = True
    log_memory_to_tensorboard: bool = True
    logging_level: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_exp_name: Optional[str] = None
    wandb_save_dir: Optional[str] = None

    # evaluate
    eval_iters: int = 100
    eval_interval: Optional[int] = None

    # other
    seed: int = 42
    seq_length: Optional[int] = None
    num_workers: int = 4
    no_create_attention_mask_in_dataloader: bool = True

    def _set_default(self):
        """Fill architecture-related fields that were left as None with defaults."""
        if self.num_query_groups is None:
            self.num_query_groups = 1
        if self.norm_epsilon is None:
            self.norm_epsilon = 1e-5
        if self.rotary_base is None:
            self.rotary_base = 10000
        if self.attention_dropout is None:
            self.attention_dropout = 0.
        if self.untie_embeddings_and_output_weights is None:
            self.untie_embeddings_and_output_weights = True
        if self.swiglu is None:
            self.swiglu = True
        if self.add_qkv_bias is None:
            self.add_qkv_bias = True
        if self.disable_bias_linear is None:
            self.disable_bias_linear = True
        if self.moe_router_topk is None:
            self.moe_router_topk = 2
        if self.moe_router_pre_softmax is None:
            self.moe_router_pre_softmax = False
        if self.moe_aux_loss_coeff is None:
            self.moe_aux_loss_coeff = 0.
        if self.qk_layernorm is None:
            self.qk_layernorm = False

    def _init_mixed_precision(self):
        """Resolve fp16/bf16 flags and enable QK layer scaling for fp16.

        Delegates flag resolution to ModelArguments._init_mixed_precision
        (called unbound with `self`, so it mutates this instance).
        """
        from swift.llm.argument.base_args.model_args import ModelArguments
        ModelArguments._init_mixed_precision(self)
        if self.apply_query_key_layer_scaling is None:
            self.apply_query_key_layer_scaling = self.fp16
        if self.apply_query_key_layer_scaling:
            # Transformer Engine reads this env var at kernel-selection time.
            os.environ['NVTE_APPLY_QK_LAYER_SCALING'] = '1'

    def _init_moe(self):
        """Normalize MoE sizes.

        NOTE: when moe_ffn_hidden_size is explicitly set, it overwrites
        ffn_hidden_size (the two are kept equal in both directions).
        """
        if self.moe_shared_expert_intermediate_size == 0:
            # Treat 0 as "no shared expert".
            self.moe_shared_expert_intermediate_size = None
        if self.moe_ffn_hidden_size is None:
            self.moe_ffn_hidden_size = self.ffn_hidden_size
        else:
            self.ffn_hidden_size = self.moe_ffn_hidden_size

    def __post_init__(self):
        """Validate dependencies and derive all dependent defaults."""
        from swift.llm.argument.base_args.model_args import ModelArguments
        if self.use_flash_attn or self.attention_backend == 'flash':
            require_version('flash-attn')
        # Required by Megatron for correct comm/compute overlap behavior.
        os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
        self._set_default()
        self.group_query_attention = self.num_query_groups > 1
        if self.rope_scaling is not None:
            # Accept str/JSON input and normalize to a dict.
            self.rope_scaling = ModelArguments.parse_to_dict(self.rope_scaling)
        if self.eval_interval is None:
            self.eval_interval = self.save_interval
        if self.seq_length is None:
            self.seq_length = self.max_position_embeddings
        if self.tensorboard_dir is None and self.save is not None:
            self.tensorboard_dir = f'{self.save}/runs'
        self._init_moe()
        self._init_mixed_precision()

        self.tensorboard_dir = to_abspath(self.tensorboard_dir)

    def _args_to_argv(self) -> Tuple[List[Any], Dict[str, Any]]:
        """Convert fields into a Megatron-style argv list plus an extras dict.

        Returns:
            (new_args, extra_args) where new_args is a flat list like
            ['--micro-batch-size', '1', ...] and extra_args holds fields not
            declared on MegatronArguments itself (i.e. the inherited extras).
        """
        new_args = []
        args_dict = asdict(self)
        extra_args = {}
        for k, value in args_dict.items():
            if k not in MegatronArguments.__annotations__:
                # Inherited (Swift-specific) fields bypass the Megatron CLI.
                extra_args[k] = value
                continue
            if value is None or value is False:
                # Unset options and disabled store_true flags are omitted.
                continue
            new_args.append(f"--{k.replace('_', '-')}")
            if isinstance(value, list):
                new_args += [str(v) for v in value]
            elif value is not True:
                # True flags are bare; everything else carries a value.
                new_args.append(str(value))

        return new_args, extra_args

    def parse_to_megatron(self):
        """Rewrite sys.argv for Megatron's parser and return the extras dict.

        Side effect: the original argv is stashed in `sys._old_argv` so
        callers can restore it after Megatron initializes.
        """
        new_args, extra_args = self._args_to_argv()
        sys._old_argv = sys.argv
        sys.argv = sys.argv[:1] + new_args
        # parameter conflict
        extra_args.pop('loss_scale', None)
        return extra_args
ms-swift/swift/megatron/model/gpt/mcore2hf.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from megatron.training import get_args
3
+
4
+
5
+ def set_attn_state(args, mg_attn, hf_attn):
6
+ num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
7
+ # Copy weights
8
+ mg_attn_weight = mg_attn.linear_qkv.weight.reshape((num_query_groups, -1, args.hidden_size))
9
+ q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[
10
+ 0] // num_query_groups
11
+ hf_attn.q_proj.weight.data.copy_(mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size))
12
+ hf_attn.k_proj.weight.data.copy_(mg_attn_weight[:, q_dim:-kv_dim, :].reshape(-1, args.hidden_size))
13
+ hf_attn.v_proj.weight.data.copy_(mg_attn_weight[:, -kv_dim:, :].reshape(-1, args.hidden_size))
14
+ hf_attn.o_proj.weight.data.copy_(mg_attn.linear_proj.weight)
15
+
16
+ # Copy bias
17
+ if args.add_qkv_bias:
18
+ mg_attn_bias = mg_attn.linear_qkv.bias.reshape((num_query_groups, -1))
19
+ hf_attn.q_proj.bias.data.copy_(mg_attn_bias[:, :q_dim].reshape(-1))
20
+ hf_attn.k_proj.bias.data.copy_(mg_attn_bias[:, q_dim:-kv_dim].reshape(-1))
21
+ hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))
22
+
23
+ if args.qk_layernorm:
24
+ hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight)
25
+ hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight)
26
+
27
+
28
+ def _set_mlp_state(mg_mlp, hf_mlp):
29
+ ffn_hidden_size = hf_mlp.gate_proj.weight.shape[0]
30
+ hf_mlp.gate_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[:ffn_hidden_size])
31
+ hf_mlp.up_proj.weight.data.copy_(mg_mlp.linear_fc1.weight[ffn_hidden_size:])
32
+ hf_mlp.down_proj.weight.data.copy_(mg_mlp.linear_fc2.weight)
33
+
34
+
35
+ def set_mlp_state(args, mg_mlp, hf_mlp):
36
+ if args.num_experts:
37
+ hf_mlp.gate.weight.data.copy_(mg_mlp.router.weight)
38
+ if mg_mlp.shared_experts is not None:
39
+ hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight)
40
+ for expert_idx in range(args.num_experts):
41
+ _set_mlp_state(mg_mlp.experts.local_experts[expert_idx], hf_mlp.experts[expert_idx])
42
+
43
+ if mg_mlp.shared_experts is not None:
44
+ _set_mlp_state(mg_mlp.shared_experts, hf_mlp.shared_expert)
45
+ else:
46
+ _set_mlp_state(mg_mlp, hf_mlp)
47
+
48
+
49
+ def set_layer_state(args, mg_model, hf_model, layer_idx):
50
+ mg_layer = mg_model.decoder.layers[layer_idx]
51
+ hf_layer = hf_model.model.layers[layer_idx]
52
+ set_attn_state(args, mg_layer.self_attention, hf_layer.self_attn)
53
+ set_mlp_state(args, mg_layer.mlp, hf_layer.mlp)
54
+
55
+ post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight
56
+ if args.num_experts:
57
+ post_attention_layernorm_weight.data.copy_(mg_layer.pre_mlp_layernorm.weight)
58
+ else:
59
+ post_attention_layernorm_weight.data.copy_(mg_layer.mlp.linear_fc1.layer_norm_weight)
60
+ hf_layer.input_layernorm.weight.data.copy_(mg_layer.self_attention.linear_qkv.layer_norm_weight)
61
+
62
+
63
+ def convert_mcore2hf(hf_model, mg_model):
64
+ args = get_args()
65
+ hf_model.model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight)
66
+ if args.untie_embeddings_and_output_weights:
67
+ hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight)
68
+ hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight)
69
+ for layer_idx in range(args.num_layers):
70
+ set_layer_state(args, mg_model, hf_model, layer_idx)
ms-swift/swift/megatron/train/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .pt import megatron_pt_main
2
+ from .sft import megatron_sft_main
ms-swift/swift/megatron/train/pt.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import List, Union
3
+
4
+ from ..argument import MegatronTrainArguments
5
+ from .sft import MegatronSft
6
+
7
+
8
+ class MegatronPt(MegatronSft):
9
+ args_class = MegatronTrainArguments
10
+ args: args_class
11
+
12
+ def _prepare_template(self) -> None:
13
+ self.args.use_chat_template = False
14
+ super()._prepare_template()
15
+ self.template.loss_scale = 'all'
16
+
17
+
18
+ def megatron_pt_main(args: Union[List[str], MegatronTrainArguments, None] = None):
19
+ return MegatronPt(args).main()
ms-swift/swift/megatron/train/sft.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from typing import List, Union
4
+
5
+ from megatron.core.enums import ModelType
6
+ from megatron.training import pretrain
7
+
8
+ from swift.llm.train import SwiftSft
9
+ from swift.utils import get_logger, is_master, plot_images
10
+ from ..argument import MegatronTrainArguments
11
+ from ..utils import patch_megatron_tokenizer
12
+ from .patcher import patch_megatron_data_collator, patch_training_log
13
+ from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider
14
+
15
+ logger = get_logger()
16
+
17
+
18
+ class MegatronSft(SwiftSft):
19
+ args_class = MegatronTrainArguments
20
+ args: args_class
21
+
22
+ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None:
23
+ self.train_msg = {}
24
+ super(SwiftSft, self).__init__(args)
25
+ args = self.args
26
+ _, self.processor = args.get_model_processor(load_model=False)
27
+ patch_megatron_tokenizer(self.processor)
28
+ args.init_model_args(self.processor.model_info.config)
29
+ self._prepare_template()
30
+ self.template.use_megatron = True
31
+ args.save_args(args.save)
32
+
33
+ def run(self):
34
+ args = self.args
35
+
36
+ train_dataset, val_dataset = self._get_dataset()
37
+ train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
38
+ data_collator = self.template.data_collator
39
+ if args.streaming:
40
+ train_dataset = build_streaming_dataloader(args, train_dataset, data_collator)
41
+ if val_dataset is not None:
42
+ val_dataset = build_streaming_dataloader(args, val_dataset, data_collator)
43
+ datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset)
44
+ datasets_provider.is_distributed = True
45
+
46
+ logging_path = os.path.join(args.save, 'logging.jsonl')
47
+ logger.info(f'The logging file will be saved in: {logging_path}')
48
+ try:
49
+ with patch_training_log(), patch_megatron_data_collator(data_collator):
50
+ pretrain(
51
+ datasets_provider,
52
+ args.megatron_model_meta.model_provider,
53
+ ModelType.encoder_or_decoder,
54
+ forward_step,
55
+ args_defaults=args.extra_args)
56
+ finally:
57
+ # Visualization
58
+ if is_master():
59
+ images_dir = os.path.join(args.save, 'images')
60
+ logger.info(f'images_dir: {images_dir}')
61
+ plot_images(images_dir, args.tensorboard_dir)
62
+
63
+
64
+ def megatron_sft_main(args: Union[List[str], MegatronTrainArguments, None] = None):
65
+ return MegatronSft(args).main()
ms-swift/swift/plugin/agent_template/__pycache__/glm4.cpython-310.pyc ADDED
Binary file (3.32 kB). View file
 
ms-swift/swift/plugin/agent_template/__pycache__/llama.cpython-310.pyc ADDED
Binary file (3.39 kB). View file
 
ms-swift/swift/plugin/agent_template/__pycache__/qwen.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
ms-swift/swift/plugin/agent_template/__pycache__/react.cpython-310.pyc ADDED
Binary file (2.55 kB). View file
 
ms-swift/swift/plugin/agent_template/hermes.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import re
3
+ from typing import TYPE_CHECKING, List, Tuple, Union
4
+
5
+ import json
6
+
7
+ from .base import BaseAgentTemplate
8
+
9
+ if TYPE_CHECKING:
10
+ from swift.llm.infer import Function
11
+ from swift.llm.template import Prompt
12
+
13
+
14
+ class HermesAgentTemplate(BaseAgentTemplate):
15
+
16
+ def get_toolcall(self, response: str) -> List['Function']:
17
+ from swift.llm.infer import Function
18
+ res_list = re.findall(r'<tool_call>(.+?)</tool_call>', response, re.DOTALL)
19
+ functions = []
20
+ for res in res_list:
21
+ res = self._parse_json(res)
22
+ if isinstance(res, dict) and 'name' in res and 'arguments' in res:
23
+ functions.append(Function(name=res['name'], arguments=res['arguments']))
24
+ if len(functions) == 0:
25
+ # compat react_en
26
+ return super().get_toolcall(response)
27
+ return functions
28
+
29
+ def _format_tool_responses(
30
+ self,
31
+ assistant_content: str,
32
+ tool_messages,
33
+ ) -> Tuple[str, 'Prompt']:
34
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
35
+ if with_action:
36
+ return super()._format_tool_responses(assistant_content, tool_messages)
37
+ if hasattr(self, 'template_meta'):
38
+ prompt = self.template_meta.prompt
39
+ chat_sep = self.template_meta.chat_sep
40
+ else:
41
+ prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
42
+ chat_sep = ['<|im_end|>\n']
43
+ res = chat_sep.copy()
44
+ res_tool = []
45
+ for tool_message in tool_messages:
46
+ tool_content = tool_message['content']
47
+ res_tool.append(f'<tool_response>\n{tool_content}\n</tool_response>')
48
+ total_tool = '\n'.join(res_tool)
49
+ for context in prompt:
50
+ if isinstance(context, str):
51
+ context = context.replace('{{QUERY}}', total_tool)
52
+ res.append(context)
53
+ return assistant_content, res
54
+
55
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
56
+ tool_descs = [json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools]
57
+ return f"""{system}
58
+
59
+ # Tools
60
+
61
+ You may call one or more functions to assist with the user query.
62
+
63
+ You are provided with function signatures within <tools></tools> XML tags:
64
+ <tools>
65
+ """ + '\n'.join(tool_descs) + """
66
+ </tools>
67
+
68
+ For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
69
+ <tool_call>
70
+ {"name": <function-name>, "arguments": <args-json-object>}
71
+ </tool_call>"""
72
+
73
+ def _format_tool_calls(self, tool_call_messages):
74
+ tool_calls = []
75
+ for message in tool_call_messages:
76
+ tool_call = self._parse_tool_call(message['content'])
77
+ tool_calls.append(f'<tool_call>\n{json.dumps(tool_call, ensure_ascii=False)}\n</tool_call>')
78
+ return '\n'.join(tool_calls)
ms-swift/swift/plugin/agent_template/react.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import List, Union
3
+
4
+ from .base import BaseAgentTemplate
5
+
6
+
7
+ class ReactEnAgentTemplate(BaseAgentTemplate):
8
+
9
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
10
+ tool_names = []
11
+ tool_descs = []
12
+ for tool in tools:
13
+ tool_desc = self._parse_tool(tool, 'en')
14
+ tool_names.append(tool_desc.name_for_model)
15
+ tool_descs.append(
16
+ f'{tool_desc.name_for_model}: Call this tool to interact with the {tool_desc.name_for_human} API. '
17
+ f'What is the {tool_desc.name_for_human} API useful for? {tool_desc.description_for_model} '
18
+ f'Parameters: {tool_desc.parameters} {tool_desc.args_format}')
19
+
20
+ return """Answer the following questions as best you can. You have access to the following tools:
21
+
22
+ """ + '\n\n'.join(tool_descs) + f"""
23
+
24
+ Use the following format:
25
+
26
+ Question: the input question you must answer
27
+ Thought: you should always think about what to do
28
+ Action: the action to take, should be one of [{','.join(tool_names)}]
29
+ Action Input: the input to the action
30
+ Observation: the result of the action
31
+ ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
32
+ Thought: I now know the final answer
33
+ Final Answer: the final answer to the original input question
34
+
35
+ Begin!
36
+ """
37
+
38
+
39
+ class ReactZnAgentTemplate(BaseAgentTemplate):
40
+
41
+ def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
42
+ tool_names = []
43
+ tool_descs = []
44
+ for tool in tools:
45
+ tool_desc = self._parse_tool(tool, 'zh')
46
+ tool_names.append(tool_desc.name_for_model)
47
+ tool_descs.append(f'{tool_desc.name_for_model}: 调用此工具与 {tool_desc.name_for_human} API 进行交互。'
48
+ f'{tool_desc.name_for_human} 有什么用?{tool_desc.description_for_model} '
49
+ f'输入参数:{tool_desc.parameters} {tool_desc.args_format}')
50
+ return """尽可能地回答以下问题。你可以使用以下工具:
51
+
52
+ """ + '\n\n'.join(tool_descs) + f"""
53
+
54
+ 请按照以下格式进行:
55
+
56
+ Question: 需要你回答的输入问题
57
+ Thought: 你应该总是思考该做什么
58
+ Action: 需要使用的工具,应该是[{','.join(tool_names)}]中的一个
59
+ Action Input: 传入工具的内容
60
+ Observation: 行动的结果
61
+ ... (这个Thought/Action/Action Input/Observation可以重复N次)
62
+ Thought: 我现在知道最后的答案
63
+ Final Answer: 对原始输入问题的最终答案
64
+
65
+ 现在开始!
66
+ """
ms-swift/swift/plugin/loss_scale/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .loss_scale import loss_scale_map
ms-swift/swift/plugin/loss_scale/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (227 Bytes). View file
 
ms-swift/swift/plugin/loss_scale/__pycache__/loss_scale.cpython-310.pyc ADDED
Binary file (4.7 kB). View file
 
ms-swift/swift/plugin/loss_scale/config/agentflan.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "response":{
3
+ "Name:": [1.0, 3.0],
4
+ "Action:": [1.0, 3.0],
5
+ "ACTION:": [1.0,3.0],
6
+ "Tool:": [1.0, 3.0],
7
+ "Command": [1.0, 3.0],
8
+ "Arguments:": [1.0, 3.0],
9
+ "action input": [1.0, 3.0],
10
+ "ACTION_INPUT:":[1.0, 3.0],
11
+ "Action Input:": [1.0, 3.0],
12
+ "Thought:": [1.0, 1.0],
13
+ "Final Answer:": [1.0, 1.0],
14
+ "Observation:": [2.0, 0.0]
15
+ },
16
+ "query":{
17
+ "What is the tool you want to use": [3.0],
18
+ "What are the required parameter names": [3.0],
19
+ "What is the value of": [3.0],
20
+ "What are the required parameter names for this tool": [3.0]
21
+ }
22
+ }
ms-swift/swift/plugin/loss_scale/config/hermes.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<tool_call>.+?</tool_call>": [2.0]
3
+ }
ms-swift/swift/plugin/prm.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Dict, List, Union
3
+
4
+ import json
5
+
6
+ from swift.llm import InferRequest
7
+
8
+
9
class PRM:
    """Abstract base class for process reward models.

    Subclasses implement ``__call__`` and return one reward value per
    input request (see ``QwenMaxPRM`` / ``ClientPRM`` below).
    """

    def __call__(self, **kwargs) -> List[Any]:
        raise NotImplementedError
13
+
14
+
15
# System prompt instructing the judge LLM to act as a process reward model.
# The reward must appear at the end of the completion as `**Reward: value**`
# so it can be parsed back out by the PRM implementations below.
# Fix: "appare" -> "appear" (typo in the format instruction sent to the model).
SYSTEM = """
You are a process reward model, give the reward value of the answer, you must follow the instructions below:

1. Output a float reward value between -1.0 and 1.0, -1.0 means the worst answer, 1.0 means the best answer, please think step by step to give your reasons and thoughts, but the reward must appear at the end with this format: **Reward: your-reward-value**.

2. The answer may be incomplete, you must give the reward by the existing part of the answer, taking into account semantic coherence, logical correctness, and clarity.

3. A ground truth answer will be given to you, it may be not the best one, consider it as a reference example.

Begin!
"""  # noqa

# User prompt template. `#query#`, `#ground_truth#` and `#response#` are
# placeholders substituted via str.replace before the prompt is sent.
QUERY = """
The original question or the previous conversation:

#query#

Here is the ground truth as the reference:

#ground_truth#

Given the upper information, give your reward(-1.0~1.0) of the following answer:

#response#
"""
40
+
41
+
42
class QwenMaxPRM(PRM):
    """PRM that scores answers by querying DashScope's ``qwen-max`` model.

    Each (request, ground_truth) pair is rendered into the SYSTEM/QUERY judge
    prompt; the reward is parsed from the trailing ``**Reward: value**`` marker
    of the completion. Any parsing failure yields a reward of 0.0.
    """

    @staticmethod
    def _parse_reward(content) -> float:
        """Extract the float reward from a judge completion; 0.0 on failure.

        Fixes vs. original: `content` may be None (the old `'Reward:' in content`
        raised TypeError), and the *last* 'Reward:' occurrence is used because
        the prompt asks for the reward at the end — earlier chain-of-thought
        text may legitimately mention 'Reward:' too.
        """
        if not content or 'Reward:' not in content:
            return 0.
        try:
            return float(content.rsplit('Reward:', 1)[1].strip().replace('*', ''))
        except ValueError:
            return 0.

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        """Return one reward per request by calling qwen-max once per pair."""
        # TODO: check request_config
        rewards = []

        from openai import OpenAI

        client = OpenAI(
            api_key=os.getenv('DASHSCOPE_API_KEY'),
            base_url='https://dashscope.aliyuncs.com/compatible-mode/v1',
        )

        for request, ground_truth in zip(infer_requests, ground_truths):
            # Everything before the final assistant turn forms the "query" context.
            previous = request['messages'][:-1]
            # Guard against a single-message conversation (empty `previous`).
            if previous and previous[0]['role'] == 'system':
                previous = previous[1:]

            assert request['messages'][-1]['role'] == 'assistant'
            query = QUERY.replace('#query#', json.dumps(previous))
            query = query.replace('#ground_truth#', ground_truth)
            query = query.replace('#response#', request['messages'][-1]['content'])
            messages = [
                {
                    'role': 'system',
                    'content': SYSTEM
                },
                {
                    'role': 'user',
                    'content': query
                },
            ]
            completion = client.chat.completions.create(
                model='qwen-max',
                messages=messages,
            )

            rewards.append(self._parse_reward(completion.choices[0].message.content))

        return rewards
91
+
92
+
93
class ClientPRM(PRM):
    """PRM that scores answers through a swift ``InferClient`` endpoint.

    Defaults to DashScope's OpenAI-compatible endpoint with ``qwen-plus``;
    any OpenAI-compatible server can be targeted via ``base_url``/``model``.
    All judge requests are batched through a single ``infer`` call.
    """

    def __init__(self, api_key=None, base_url=None, model=None):
        from swift.llm import InferClient
        # Fix: removed the redundant function-level `import os` — the module
        # already imports `os` at top level.
        if api_key is None:
            api_key = os.getenv('DASHSCOPE_API_KEY')
        if base_url is None:
            base_url = 'https://dashscope.aliyuncs.com/compatible-mode/v1'
        if model is None:
            model = 'qwen-plus'
        self.infer_engine = InferClient(base_url=base_url, api_key=api_key)
        self.infer_engine.strict = False
        self.infer_kwargs = {
            'model': model,
        }

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        """Build one judge request per (request, ground_truth) pair, infer in
        batch, and parse a reward from each completion (0.0 on failure)."""
        prm_infer_requests = []
        request_config = kwargs.get('request_config')
        for request, ground_truth in zip(infer_requests, ground_truths):
            # Context = everything before the final assistant turn.
            previous = request['messages'][:-1]
            # Guard against a single-message conversation (empty `previous`).
            if previous and previous[0]['role'] == 'system':
                previous = previous[1:]

            assert request['messages'][-1]['role'] == 'assistant'
            query = QUERY.replace('#query#', json.dumps(previous))
            query = query.replace('#ground_truth#', ground_truth)
            query = query.replace('#response#', request['messages'][-1]['content'])
            messages = [
                {
                    'role': 'system',
                    'content': SYSTEM
                },
                {
                    'role': 'user',
                    'content': query
                },
            ]

            prm_infer_requests.append(InferRequest(messages=messages))

        responses = self.infer_engine.infer(prm_infer_requests, request_config=request_config, **self.infer_kwargs)
        rewards = []
        for response in responses:
            content = response.choices[0].message.content
            # Fix: `content` may be None; use the *last* 'Reward:' occurrence
            # since the prompt asks for the reward at the end of the output.
            if not content or 'Reward:' not in content:
                rewards.append(0.)
                continue
            try:
                rewards.append(float(content.rsplit('Reward:', 1)[1].strip().replace('*', '')))
            except ValueError:
                rewards.append(0.)
        return rewards
149
+
150
+
151
# Registry mapping PRM names (as referenced from CLI/config) to their classes.
prms = {
    'qwen_max': QwenMaxPRM,
    'client': ClientPRM,
}
ms-swift/swift/plugin/rm_plugin.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import textwrap
3
+ from copy import deepcopy
4
+ from typing import Dict, List
5
+
6
+ import torch
7
+
8
+ from swift.llm import PtEngine, RequestConfig, Template, to_device
9
+ from swift.llm.infer.protocol import ChatCompletionResponse
10
+ from swift.utils import get_logger
11
+
12
+ logger = get_logger()
13
+
14
+
15
class DefaultRMPlugin:
    """
    Default Reward Model Plugin

    This class implements the default processing logic for reward models.
    It assumes that `self.model` is a classification model with a value head (output dimension 1).
    The first logits value from the model's output is used as the reward score.
    """

    def __init__(self, model, template):
        self.model = model
        self.template: Template = template

    def __call__(self, inputs):
        # deepcopy so template.encode cannot mutate the caller's requests.
        batched_inputs = [self.template.encode(deepcopy(infer_request)) for infer_request in inputs]
        reward_inputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
        # The collator emits 'labels' but the RM forward pass does not accept them.
        reward_inputs.pop('labels')

        with torch.inference_mode():
            # First logit per sample is the scalar reward.
            return self.model(**reward_inputs).logits[:, 0]
35
+
36
+
37
class GenRMPlugin(DefaultRMPlugin):
    """Generative reward model plugin: prompts a judge LLM and parses 'Reward: x'."""

    def __init__(self, model, template):
        """
        Generative Reward Model Plugin Example.

        This method sets up the reward model plugin by initializing the PtEngine for efficient inference,
        configuring the request parameters, and defining the system prompt that guides the reward model in
        evaluating responses.

        Args:
            model (torch.nn.Module): The generative reward model.
            template (Template): The template used for encoding input data.
        """

        super().__init__(model, template)
        # initialize PtEngine to infer
        self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0)  # 0: no limit
        self.request_config = RequestConfig()  # customise your request config here
        self.system = textwrap.dedent("""
        Based on the dialogue history, analyze in detail whether the model's response is accurate, complete, and relevant.
        Assign a reward score between 0 and 1, where 0 indicates completely incorrect and 1 indicates fully correct.
        Before finishing your response, please assign a reward using the following format:

        Reward: {reward}

        For example:
        Reward: 0.85
        """)  # noqa

    def __call__(self, inputs):
        """
        Compute reward scores for the provided inputs.

        This method processes each input by converting dialogue messages into a query, sending the query to the
        reward model for inference, and extracting the reward scores from the model's responses. The final reward
        for each input is the average of all extracted scores.
        Args:
            inputs (List[Dict]): A list of input requests. Each input request is a dictionary containing:
                - 'messages' (List[Dict]): messages from the training model. Each message dictionary includes:
                    - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
                    - 'content' (str): The content of the message.
                - Additional dataset columns as key-value pairs (e.g., 'solutions', 'images').
        Returns:
            torch.Tensor: A tensor containing the average reward scores for each input. The tensor has a shape of (N,),
            where N is the number of input requests.
        """

        rm_inputs = self.prepare_rm_inputs(inputs)
        results = self.engine.infer(rm_inputs, self.request_config, use_tqdm=False)
        rewards = self.compute_rewards(results)
        return torch.tensor(rewards, dtype=torch.float32)

    def prepare_rm_inputs(self, inputs: List[Dict]) -> List[Dict]:
        """
        Prepare inputs for the reward model by converting messages into queries.

        Args:
            inputs (List[Dict]): A list of input requests.

        Returns:
            List[Dict]: Processed inputs for the reward model.
        """
        rm_inputs = []
        for idx, infer_request in enumerate(inputs):
            # Deep copy to prevent modification of original input
            rm_infer_request = deepcopy(infer_request)

            # Extract and convert messages to a single query string
            messages = rm_infer_request.get('messages')
            query = self.messages_to_query(messages)

            # Construct new messages tailored for the reward model
            rm_messages = [{'role': 'system', 'content': self.system}, {'role': 'user', 'content': query}]

            # Update the messages in the reward infer request
            rm_infer_request['messages'] = rm_messages
            rm_inputs.append(rm_infer_request)
        return rm_inputs

    @staticmethod
    def extract_reward(model_output: str) -> float:
        """
        Extract the reward score from the model's output.

        Args:
            model_output (str): The model's output string, expected to follow the format "Reward: {reward}".

        Returns:
            float: The extracted reward score, or None when no score could be
            parsed (the caller filters None values out before averaging).
        """
        match = re.search(r'Reward:\s*([0-1](?:\.\d+)?)', model_output)
        if match:
            return float(match.group(1))
        else:
            logger.warning("Unable to extract reward score from the model's output, set reward to 0")
            return None

    @staticmethod
    def messages_to_query(messages):
        """
        Compress a list of message dictionaries into a single query string.

        Args:
            messages (list[dict]): A list of message dictionaries, each containing:
                - 'role' (str): The role of the speaker (e.g., 'user', 'assistant').
                - 'content' (str): The content of the message.

        Returns:
            str: A single string that concatenates all messages in a formatted manner.

        Example:
            >>> messages = [
            ...     {'role': 'user', 'content': 'Hello, how are you?'},
            ...     {'role': 'assistant', 'content': 'I am fine, thank you! How can I assist you today?'},
            ...     {'role': 'user', 'content': 'Can you help me with my homework?'}
            ... ]
            >>> print(messages_to_query(messages))
            User: Hello, how are you?
            Assistant: I am fine, thank you! How can I assist you today?
            User: Can you help me with my homework?
        """
        # Initialize an empty list to hold formatted messages
        formatted_messages = []

        # Define a mapping for role capitalization if needed
        role_mapping = {
            'user': 'User',
            'assistant': 'Assistant',
            'system': 'System'
            # Add more roles here as needed
        }

        for idx, message in enumerate(messages):
            if not isinstance(message, dict):
                raise TypeError(f'Each message must be a dictionary. Found {type(message)} at index {idx}.')

            # Extract 'role' and 'content' from each message
            role = message.get('role')
            content = message.get('content')
            # Skip empty/absent content (e.g. tool placeholders).
            if not content:
                continue

            # Capitalize the role using the mapping, default to capitalized original role
            # NOTE(review): assumes 'role' is always a string here — `role.lower()`
            # would raise AttributeError on a missing role; confirm upstream schema.
            role_formatted = role_mapping.get(role.lower(), role.capitalize())

            # Append the formatted message to the list
            formatted_messages.append(f'{role_formatted}: {content}')

        # Join all formatted messages with newline characters
        query = '\n'.join(formatted_messages)

        return query

    def compute_rewards(self, results: List[ChatCompletionResponse]) -> List[float]:
        """
        Compute average reward scores from the reward model's outputs.

        Args:
            results (List[ChatCompletionResponse]): A list of results from the reward model.

        Returns:
            List[float]: A list of average reward scores.
        """
        rewards = []
        for idx, output in enumerate(results):
            try:
                cur_rewards = []
                for choice in output.choices:
                    response = choice.message.content
                    reward = self.extract_reward(response)
                    cur_rewards.append(reward)
                # Drop unparseable (None) scores before averaging.
                cur_rewards = [r for r in cur_rewards if r is not None]
                if cur_rewards:
                    average_reward = sum(cur_rewards) / len(cur_rewards)
                else:
                    average_reward = 0.0
                    logger.warning('No valid rewards extracted. Assigning reward score of 0.0.')

                rewards.append(average_reward)
            except Exception as e:
                logger.error(f'Error computing reward: {e}')
                rewards.append(0.0)  # Assign default reward score on failure
        return rewards
224
+
225
+
226
# Registry mapping reward-model plugin names to their classes.
rm_plugins = {
    'default': DefaultRMPlugin,
    'genrm': GenRMPlugin,
}
ms-swift/swift/trainers/__init__.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Lazy-loading package init for swift.trainers.

Submodules are imported on first attribute access via `_LazyModule`;
the `TYPE_CHECKING` branch mirrors `_import_structure` so static type
checkers see the same names.
"""
from typing import TYPE_CHECKING

from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy,
                                        SchedulerType)

from swift.utils.import_utils import _LazyModule
# Imported for its side effect: monkey-patches transformers' default callbacks.
from . import callback

try:
    # https://github.com/huggingface/transformers/pull/25702
    from transformers.trainer_utils import ShardedDDPOption
except ImportError:
    ShardedDDPOption = None

if TYPE_CHECKING:
    from .arguments import Seq2SeqTrainingArguments, TrainingArguments
    from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer,
                               RewardTrainer, GRPOTrainer)
    # Fix: GRPOConfig was listed in _import_structure but missing here,
    # so type checkers could not resolve `swift.trainers.GRPOConfig`.
    from .rlhf_arguments import DPOConfig, CPOConfig, KTOConfig, ORPOConfig, PPOConfig, RewardConfig, GRPOConfig
    from .trainer_factory import TrainerFactory
    from .trainers import Seq2SeqTrainer, Trainer, EmbeddingTrainer
    from .mixin import SwiftMixin

else:
    # Names already defined above (e.g. TrainerCallback) are re-exported as-is.
    _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
    _import_structure = {
        'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'],
        'rlhf_arguments':
        ['DPOConfig', 'CPOConfig', 'KTOConfig', 'ORPOConfig', 'PPOConfig', 'RewardConfig', 'GRPOConfig'],
        'rlhf_trainer': [
            'CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer', 'RewardTrainer',
            'GRPOTrainer'
        ],
        'trainer_factory': ['TrainerFactory'],
        'trainers': ['Seq2SeqTrainer', 'Trainer', 'EmbeddingTrainer'],
        'mixin': ['SwiftMixin'],
    }

    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects=_extra_objects,
    )
ms-swift/swift/trainers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.78 kB). View file
 
ms-swift/swift/trainers/__pycache__/callback.cpython-310.pyc ADDED
Binary file (4.96 kB). View file
 
ms-swift/swift/trainers/__pycache__/trainer_factory.cpython-310.pyc ADDED
Binary file (2.22 kB). View file
 
ms-swift/swift/trainers/__pycache__/trainers.cpython-310.pyc ADDED
Binary file (7.93 kB). View file
 
ms-swift/swift/trainers/callback.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import math
3
+ import os
4
+ import time
5
+
6
+ from tqdm import tqdm
7
+ from transformers import trainer
8
+ from transformers.trainer_callback import (DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerControl,
9
+ TrainerState)
10
+ from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics
11
+
12
+ from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
13
+ from ..utils.utils import format_time
14
+ from .arguments import TrainingArguments
15
+
16
+
17
def add_train_message(logs, state, start_time) -> None:
    """Augment a trainer log dict in place with progress and timing fields.

    Adds 'global_step/max_steps', 'percentage', 'elapsed_time' and (once any
    progress has been made) 'remaining_time'; rounds every float value to
    8 decimal places.
    """
    logs['global_step/max_steps'] = f'{state.global_step}/{state.max_steps}'
    progress = state.global_step / state.max_steps if state.max_steps else 0.
    logs['percentage'] = f'{progress * 100:.2f}%'
    elapsed = time.time() - start_time
    logs['elapsed_time'] = format_time(elapsed)
    if progress != 0:
        # Linear extrapolation of the remaining wall-clock time.
        logs['remaining_time'] = format_time(elapsed / progress - elapsed)
    for key, value in logs.items():
        if isinstance(value, float):
            logs[key] = round(value, 8)
+
29
+
30
class ProgressCallbackNew(ProgressCallback):
    """Progress callback with named 'Train'/'Val' tqdm bars, per-log jsonl
    output (written to `<output_dir>/logging.jsonl`) and optional torchacc
    warmup speed metrics."""

    def on_train_begin(self, args, state, control, **kwargs):
        if state.is_world_process_zero:
            self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True)
        self.current_step = 0
        self.start_time = time.time()
        if use_torchacc():
            self.warmup_start_time = 0
            self.warmup_metric = None
            # metric_warmup_step < 1 is interpreted as a fraction of max_steps.
            self.metric_warmup_step = int(args.metric_warmup_step
                                          * state.max_steps) if args.metric_warmup_step < 1 else args.metric_warmup_step

    def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs):
        if state.is_world_process_zero and has_length(eval_dataloader):
            if self.prediction_bar is None:
                if self.training_bar is not None:
                    # Move to a new line so the Val bar does not overwrite the Train bar.
                    self.training_bar.fp.write('\n')
                self.prediction_bar = tqdm(
                    desc='Val', total=len(eval_dataloader), leave=True, dynamic_ncols=True, position=0)
            self.prediction_bar.update()

    def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):

        if use_torchacc():
            # Start timing once the warmup step is reached (first crossing only).
            if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
                self.warmup_start_time = time.time()
                self.metric_warmup_step = state.global_step
            # At the final step, compute throughput over the post-warmup window.
            if state.max_steps == state.global_step and self.warmup_metric is None:
                num_steps = state.max_steps - self.metric_warmup_step
                num_total_samples = args.train_dataset_sample
                num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps)
                self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples,
                                                  num_steps)
                self.warmup_metric['num_total_samples'] = num_total_samples
                self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples
            # 'train_samples_per_second' only appears in the final train metrics,
            # by which point warmup_metric has been populated above.
            if 'train_samples_per_second' in logs:
                logs.update(self.warmup_metric)
                state.log_history[-1] = logs

        add_train_message(logs, state, self.start_time)
        if not is_pai_training_job() and state.is_world_process_zero:
            jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
            append_to_jsonl(jsonl_path, logs)
        super().on_log(args, state, control, logs, **kwargs)
        if state.is_world_process_zero and self.training_bar is not None:
            # Redraw after the log line was printed over the bar.
            self.training_bar.refresh()
+
78
+
79
class DefaultFlowCallbackNew(DefaultFlowCallback):
    """Flow callback that additionally forces an eval/save at the very last
    step and supports early stop after `args.max_epochs` epochs."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        control = super().on_step_end(args, state, control, **kwargs)
        # save the last ckpt
        # `eval_strategy` replaced `evaluation_strategy` in newer transformers;
        # support both attribute names.
        evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
        if state.global_step == state.max_steps:
            if evaluation_strategy != IntervalStrategy.NO:
                control.should_evaluate = True
            if args.save_strategy != IntervalStrategy.NO:
                control.should_save = True
        return control

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        control = super().on_epoch_end(args, state, control, **kwargs)
        evaluation_strategy = args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy
        # Stop (with a final eval/save) once max_epochs has been reached.
        if args.max_epochs is not None and args.max_epochs <= math.ceil(state.epoch):
            if evaluation_strategy != IntervalStrategy.NO:
                control.should_evaluate = True
            if args.save_strategy != IntervalStrategy.NO:
                control.should_save = True
            control.should_training_stop = True
        return control
102
+
103
+
104
class PrinterCallbackNew(PrinterCallback):
    """Printer callback that enriches logs with progress/timing info and
    appends each log record to `<output_dir>/logging.jsonl`."""

    def on_train_begin(self, args, state, control, **kwargs):
        # Remember when training started so on_log can report elapsed time.
        self.start_time = time.time()
        return super().on_train_begin(args, state, control, **kwargs)

    def on_log(self, args, state, control, logs=None, **kwargs):
        add_train_message(logs, state, self.start_time)
        if not is_pai_training_job() and state.is_world_process_zero:
            log_path = os.path.join(args.output_dir, 'logging.jsonl')
            append_to_jsonl(log_path, logs)

        logs.pop('total_flos', None)
        if state.is_world_process_zero:
            print(logs, flush=True)
+
120
+
121
# Monkey patching: replace transformers' module-level callback defaults so any
# Trainer constructed after this import uses the swift-enhanced callbacks.
trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
trainer.PrinterCallback = PrinterCallbackNew
ms-swift/swift/trainers/mixin.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ # Part of the implementation is borrowed from huggingface/transformers.
3
+ import inspect
4
+ import os
5
+ import shutil
6
+ import time
7
+ from contextlib import contextmanager
8
+ from copy import copy
9
+ from functools import partial
10
+ from types import MethodType
11
+ from typing import Callable, Dict, List, Optional, Tuple, Union
12
+
13
+ import safetensors
14
+ import torch
15
+ import torch.distributed as dist
16
+ import torch.nn as nn
17
+ import transformers
18
+ from datasets import Dataset as HfDataset
19
+ from modelscope import check_local_model_is_latest
20
+ from packaging import version
21
+ from peft import PeftModel
22
+ from torch.nn import Module
23
+ from torch.utils.data import DataLoader
24
+ from transformers import PreTrainedModel
25
+ from transformers.data.data_collator import DataCollator
26
+ from transformers.integrations import is_deepspeed_zero3_enabled
27
+ from transformers.modeling_utils import unwrap_model
28
+ from transformers.trainer import TrainerCallback
29
+ from transformers.trainer_utils import EvalPrediction, IntervalStrategy
30
+ from transformers.utils import is_torch_npu_available
31
+
32
+ from swift.hub import get_hub
33
+ from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
34
+ from swift.plugin import MeanMetric, compute_acc, extra_tuners
35
+ from swift.tuners import SwiftModel
36
+ from swift.utils import get_logger, is_mp_ddp, use_torchacc
37
+ from swift.utils.torchacc_utils import ta_trim_graph
38
+ from ..utils.torch_utils import get_device_count
39
+ from .arguments import TrainingArguments
40
+ from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model
41
+
42
+ try:
43
+ from trl import AutoModelForCausalLMWithValueHead
44
+ except (ImportError, RuntimeError):
45
+ AutoModelForCausalLMWithValueHead = None
46
+
47
+ logger = get_logger()
48
+
49
+
50
+ class SwiftMixin:
51
+
52
+ def __init__(self,
53
+ model: Union[PreTrainedModel, Module] = None,
54
+ args: TrainingArguments = None,
55
+ data_collator: Optional[DataCollator] = None,
56
+ train_dataset: Optional[HfDataset] = None,
57
+ eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
58
+ template: Optional[Template] = None,
59
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
60
+ compute_loss_func: Optional[Callable] = None,
61
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
62
+ callbacks: Optional[List[TrainerCallback]] = None,
63
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
64
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
65
+ **kwargs) -> None:
66
+ if not hasattr(train_dataset, '__len__') and args.dataloader_num_workers > 1:
67
+ args.dataloader_num_workers = 1
68
+ logger.warning('Using IterableDataset, setting args.dataloader_num_workers to 1.')
69
+
70
+ if args.check_model and hasattr(model, 'model_dir'):
71
+ from swift.utils.logger import ms_logger_ignore_error
72
+ with ms_logger_ignore_error():
73
+ check_local_model_is_latest(
74
+ model.model_dir, user_agent={
75
+ 'invoked_by': 'local_trainer',
76
+ 'third_party': 'swift',
77
+ })
78
+ if eval_dataset is None and args:
79
+ args.evaluation_strategy = IntervalStrategy.NO
80
+ args.eval_strategy = IntervalStrategy.NO
81
+
82
+ self._custom_metrics = {}
83
+ self.template = template
84
+ self.max_memory = 0
85
+ self.hub = get_hub()
86
+
87
+ self.model_meta = model.model_meta
88
+ with self.hub.patch_hub():
89
+ super().__init__(
90
+ model=model,
91
+ args=args,
92
+ data_collator=data_collator,
93
+ train_dataset=train_dataset,
94
+ eval_dataset=eval_dataset,
95
+ tokenizer=template.tokenizer,
96
+ model_init=model_init,
97
+ compute_metrics=compute_metrics,
98
+ callbacks=callbacks,
99
+ optimizers=optimizers,
100
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
101
+ **kwargs)
102
+
103
+ self.compute_loss_func = compute_loss_func
104
+ if get_function(model.__class__.forward) is not get_function(model.forward):
105
+ self.label_names = find_labels(model)
106
+ self.can_return_loss = can_return_loss(model)
107
+ self.label_names = self.label_names or ['labels']
108
+ self.start_time = time.time()
109
+ if self.template.sequence_parallel_size > 1:
110
+ from swift.trainers.sequence_parallel import sequence_parallel
111
+ sequence_parallel.prepare_trainer(self)
112
+
113
+ def _save_initial_model(self, output_dir):
114
+ # pissa/olora/lora-ga
115
+ model = unwrap_model(self.model)
116
+ if isinstance(model, PeftModel):
117
+ config = model.peft_config.get('default')
118
+ init_lora_weights = getattr(config, 'init_lora_weights', None)
119
+ if (isinstance(init_lora_weights, str)
120
+ and any(s in init_lora_weights for s in ('pissa', 'olora', 'lora-ga'))):
121
+ config.init_lora_weights = True
122
+ model.save_pretrained(os.path.join(output_dir, 'initial_model'))
123
+ config.init_lora_weights = init_lora_weights
124
+
125
+ def _save_converted_model(self, output_dir):
126
+ # pissa/olora/lora-ga
127
+ model = unwrap_model(self.model)
128
+ if isinstance(model, PeftModel):
129
+ config = model.peft_config.get('default')
130
+ init_lora_weights = getattr(config, 'init_lora_weights', None)
131
+ if isinstance(init_lora_weights, str):
132
+ config = copy(config)
133
+ os.makedirs(os.path.join(output_dir, 'converted'), exist_ok=True)
134
+ if 'lora-ga' in init_lora_weights:
135
+ try:
136
+ from lora_ga.entrypoint import LoraGAContext
137
+ with LoraGAContext(model):
138
+ model.save_pretrained(
139
+ os.path.join(output_dir, 'converted', 'default'),
140
+ path_initial_model_for_weight_conversion=os.path.join(
141
+ os.path.dirname(output_dir), 'initial_model'),
142
+ )
143
+ model.peft_config['default'] = config
144
+ except ImportError as e:
145
+ error_message = """
146
+ Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub.
147
+ Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.
148
+ """
149
+ logger.info(error_message)
150
+ raise RuntimeError(error_message) from e
151
+ elif 'pissa' in init_lora_weights or 'olora' in init_lora_weights:
152
+ model.save_pretrained(
153
+ os.path.join(output_dir, 'converted', 'default'),
154
+ path_initial_model_for_weight_conversion=os.path.join(
155
+ os.path.dirname(output_dir), 'initial_model'),
156
+ )
157
+ model.peft_config['default'] = config
158
+
159
+ def _load_optimizer_and_scheduler(self, *args, **kwargs):
160
+ super()._load_optimizer_and_scheduler(*args, **kwargs)
161
+ if is_mp_ddp():
162
+ # fix mp+ddp adamw
163
+ for v in self.optimizer.state.values():
164
+ if 'step' in v:
165
+ # not on the same device
166
+ device_set = set([t.device for t in v.values()]) - {v['step'].device, torch.device('cpu')}
167
+ if len(device_set) >= 1:
168
+ v['step'] = v['step'].to('cpu')
169
+
170
+ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
171
+ # model
172
+ supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
173
+ supported_names = ('SentenceTransformer')
174
+ if AutoModelForCausalLMWithValueHead is not None:
175
+ supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, )
176
+ save_safetensors = self.args.save_safetensors
177
+ if not isinstance(self.model, supported_classes) and self.model.__class__.__name__ not in supported_names:
178
+ if state_dict is None:
179
+ state_dict = self.model.state_dict()
180
+
181
+ _unwrap_model = unwrap_model(self.model)
182
+ if isinstance(_unwrap_model, supported_classes):
183
+ _unwrap_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
184
+ else:
185
+ logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
186
+ if save_safetensors:
187
+ safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
188
+ else:
189
+ torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
190
+ elif AutoModelForCausalLMWithValueHead and isinstance(self.model, AutoModelForCausalLMWithValueHead):
191
+ # save reward model
192
+ state_dict = self.model.state_dict()
193
+ decoder_state_dict, v_head_state_dict = {}, {}
194
+ for name, param in state_dict.items():
195
+ if name.startswith('v_head.'):
196
+ v_head_state_dict[name] = param
197
+ else:
198
+ decoder_state_dict[name.replace('pretrained_model.', '', 1)] = param
199
+ self.model.pretrained_model.save_pretrained(
200
+ output_dir, state_dict=decoder_state_dict or None, safe_serialization=save_safetensors)
201
+ if save_safetensors:
202
+ from safetensors.torch import save_file
203
+ save_file(
204
+ v_head_state_dict, os.path.join(output_dir, 'value_head.safetensors'), metadata={'format': 'pt'})
205
+ else:
206
+ torch.save(v_head_state_dict, os.path.join(output_dir, 'value_head.bin'))
207
+ elif is_instance_of_ms_model(self.model):
208
+ PreTrainedModel.save_pretrained(
209
+ self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
210
+ elif self.args.train_type in extra_tuners:
211
+ extra_tuners[self.args.train_type].save_pretrained(
212
+ self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
213
+ else:
214
+ if self.model.__class__.__name__ != 'SentenceTransformer':
215
+ self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
216
+ else:
217
+
218
+ @contextmanager
219
+ def save_context():
220
+ save_pretrained = self.model[0].auto_model.save_pretrained
221
+ _state_dict = {
222
+ key[len('0.auto_model.'):] if 'auto_model' in key else key: value
223
+ for key, value in state_dict.items()
224
+ }
225
+ self.model[0].auto_model.save_pretrained = partial(
226
+ self.model[0].auto_model.save_pretrained, state_dict=_state_dict)
227
+ yield
228
+ self.model[0].auto_model.save_pretrained = save_pretrained
229
+
230
+ with save_context():
231
+ self.model.save_pretrained(output_dir, safe_serialization=save_safetensors)
232
+ # copy sentencetransformers files
233
+ from swift.utils import copy_files_by_pattern
234
+ copy_files_by_pattern(self.model.model_dir, output_dir, '*.py')
235
+ copy_files_by_pattern(self.model.model_dir, output_dir, '*.json')
236
+
237
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
238
+ """Compatible with swift and peft"""
239
+ # If we are executing this function, we are the process zero, so we don't check for that.
240
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
241
+ os.makedirs(output_dir, exist_ok=True)
242
+ self._save_model(output_dir, state_dict)
243
+ # training_args.bin
244
+ torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
245
+ self._save_converted_model(output_dir)
246
+ # args.json
247
+ args_path = os.path.join(os.path.dirname(output_dir), 'args.json')
248
+ if os.path.exists(args_path):
249
+ shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
250
+ # predict.jsonl
251
+ predict_jsonl = os.path.join(os.path.dirname(output_dir), 'predict.jsonl')
252
+ if os.path.exists(predict_jsonl):
253
+ shutil.move(predict_jsonl, os.path.join(output_dir, 'predict.jsonl'))
254
+
255
+ is_adapter = isinstance(self.model, (SwiftModel, PeftModel))
256
+ # tokenizer
257
+ if not is_adapter:
258
+ from swift.llm import save_checkpoint
259
+ additional_saved_files = self.model_meta.additional_saved_files
260
+ save_checkpoint(
261
+ None,
262
+ self.template.processor,
263
+ output_dir,
264
+ model_dirs=[self.model.model_dir],
265
+ additional_saved_files=additional_saved_files)
266
+ if getattr(self.model, 'origin_generation_config', None):
267
+ self.model.origin_generation_config.save_pretrained(output_dir)
268
+
269
    def _fix_zero3_gather_all_parameters(self) -> None:
        """Patch DeepSpeed ZeRO-3 checkpoint consolidation for adapter trainings.

        Wraps the engine's `_zero3_consolidated_16bit_state_dict` so that frozen
        (base-model) parameters are excluded from the gathered state dict when the
        unwrapped model is a SwiftModel with additional modules or a PeftModel —
        only the trainable adapter weights need to be saved.
        The `hasattr` guard makes the patch idempotent across repeated saves.
        """
        if is_deepspeed_zero3_enabled() and not hasattr(self.deepspeed, '_zero3_consolidated_16bit_state_dict_origin'):
            parameters = inspect.signature(self.deepspeed._zero3_consolidated_16bit_state_dict).parameters
            # Older DeepSpeed releases do not support `exclude_frozen_parameters`; skip patching there.
            if 'exclude_frozen_parameters' in parameters:

                def _zero3_consolidated_16bit_state_dict(model, exclude_frozen_parameters=False):
                    # `model` is the DeepSpeed engine (bound via MethodType below);
                    # the caller-supplied flag is deliberately overridden.
                    unwrapped = unwrap_model(model)
                    exclude_frozen_parameters = False
                    if isinstance(unwrapped, SwiftModel) and unwrapped.has_additional_modules:
                        exclude_frozen_parameters = True
                    if isinstance(unwrapped, PeftModel):
                        exclude_frozen_parameters = True
                    return model._zero3_consolidated_16bit_state_dict_origin(exclude_frozen_parameters)

                # Keep the original both as a restore point and for the delegated call above.
                self.deepspeed._zero3_consolidated_16bit_state_dict_origin = (
                    self.deepspeed._zero3_consolidated_16bit_state_dict)
                self.deepspeed._zero3_consolidated_16bit_state_dict = MethodType(_zero3_consolidated_16bit_state_dict,
                                                                                self.deepspeed)
+
288
+ def _save_checkpoint(self, *args, **kwargs):
289
+ self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}')
290
+ self._fix_zero3_gather_all_parameters()
291
+ result = super()._save_checkpoint(*args, **kwargs)
292
+ logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}')
293
+ return result
294
+
295
+ @staticmethod
296
+ @contextmanager
297
+ def _fix_grad_norm_nan():
298
+ from accelerate import Accelerator
299
+ origin_clip_grad_norm_ = Accelerator.clip_grad_norm_
300
+
301
+ def clip_grad_norm_(self, parameters, *args, **kwargs):
302
+ # If NaN occurs, ignore weight updates.
303
+ parameters = list(parameters)
304
+ grad_norm = origin_clip_grad_norm_(self, parameters, *args, **kwargs)
305
+ if isinstance(grad_norm, torch.Tensor) and grad_norm.isnan().item():
306
+ for p in parameters:
307
+ p.grad = None
308
+ return grad_norm
309
+
310
+ Accelerator.clip_grad_norm_ = clip_grad_norm_
311
+ try:
312
+ yield
313
+ finally:
314
+ Accelerator.clip_grad_norm_ = origin_clip_grad_norm_
315
+
316
+ def train(self, *args, **kwargs):
317
+ if self.model_meta.is_multimodal:
318
+ models = []
319
+ for model_name in ['model', 'ref_model', 'value_model']:
320
+ model = getattr(self, model_name, None)
321
+ if isinstance(model, nn.Module):
322
+ models.append(model)
323
+
324
+ reward_model = getattr(self, 'reward_model', None)
325
+ if reward_model is not None:
326
+ if isinstance(reward_model, list):
327
+ models.extend([m for m in reward_model if isinstance(m, nn.Module)])
328
+ elif isinstance(reward_model, nn.Module):
329
+ models.append(reward_model)
330
+
331
+ models = list(set(models)) # Deduplicate
332
+ self.template.register_post_encode_hook(models)
333
+ logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}.')
334
+ self._save_initial_model(self.args.output_dir)
335
+ with self.hub.patch_hub(), self._fix_grad_norm_nan():
336
+ res = super().train(*args, **kwargs)
337
+ self.template.remove_post_encode_hook()
338
+ return res
339
+
340
+ def push_to_hub(self, *args, **kwargs):
341
+ with self.hub.patch_hub():
342
+ return super().push_to_hub(*args, **kwargs)
343
+
344
+ def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) -> float:
345
+ if device is None:
346
+ mems = [torch.cuda.max_memory_reserved(device=device) for device in range(get_device_count())]
347
+ else:
348
+ mems = [torch.cuda.max_memory_reserved(device=device)]
349
+ mem = sum(mems) / 1024**3
350
+ self.max_memory = max(self.max_memory, mem)
351
+ return mem
352
+
353
    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        """Extend the HF Trainer log step with custom metrics, memory usage and
        train speed, optionally run an EvalScope evaluation, then delegate the
        save/evaluate logic to the parent implementation.

        Note: after building `logs`, `should_log` has already been consumed here,
        so the parent call performs only the save/evaluate parts.
        """
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            self.control.should_log = False

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
            loss = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
            logs: Dict[str, float] = {'loss': loss}  # loss first

            # Flatten custom metrics: single-valued metrics keep their name,
            # multi-valued ones get one entry per sub-key (`{name}_{subkey}`).
            for k, metric in self._custom_metrics.items():
                value = metric.compute()
                if len(value) == 1:
                    val = list(value.values())[0]
                    logs[k] = val
                else:
                    for k_suffix, val in value.items():
                        new_k = f'{k}_{k_suffix}'
                        logs[new_k] = val
                metric.reset()

            # transformers>=4.38 passes grad_norm as the first positional argument.
            if version.parse(transformers.__version__) >= version.parse('4.38'):
                grad_norm = args[0]
                if grad_norm is not None:
                    logs['grad_norm'] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            logs['learning_rate'] = self._get_learning_rate()
            if not is_torch_npu_available():
                logs['memory(GiB)'] = round(self.get_max_cuda_memory(), 2)

            elapse_time = time.time() - self.start_time
            logs['train_speed(iter/s)'] = round(self.state.global_step / elapse_time, 6)
            # Drop None-valued entries (iterate over a copied key list while popping).
            for k in list(logs.keys()):
                if logs[k] is None:
                    logs.pop(k)
            # Zero the accumulated loss *in place* so the caller's tensor is reset too.
            tr_loss -= tr_loss
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs)

        if self.args.eval_use_evalscope and self.control.should_evaluate:
            self._evalscope_eval()
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)
+
396
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
397
+ if self.args.optimizer is not None:
398
+ from swift.plugin import optimizers_map
399
+ optimizer_callback = optimizers_map[self.args.optimizer]
400
+ self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
401
+ if self.optimizer is None:
402
+ self.create_optimizer()
403
+ if self.lr_scheduler is None:
404
+ self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
405
+ else:
406
+ super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)
407
+
408
+ def _compute_acc(self, outputs, labels) -> None:
409
+ args = self.args
410
+ acc_steps = args.acc_steps
411
+ preds = outputs.logits.argmax(dim=-1)
412
+ if self.state.global_step % acc_steps == 0:
413
+ if use_torchacc():
414
+ ta_trim_graph()
415
+ preds = preds.to('cpu')
416
+ labels = labels.to('cpu')
417
+ metrics = compute_acc(
418
+ preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
419
+ for k, v in metrics.items():
420
+ if k not in self._custom_metrics:
421
+ self._custom_metrics[k] = MeanMetric(nan_value=None)
422
+ self._custom_metrics[k].update(v)
423
+
424
+ @torch.no_grad()
425
+ def _evalscope_eval(self):
426
+ from ..llm.eval.utils import EvalModel
427
+ from evalscope import TaskConfig, run_task
428
+ from evalscope.constants import EvalType
429
+
430
+ self.model.eval()
431
+ max_batch_size = self.args.per_device_eval_batch_size
432
+ custom_model = EvalModel(
433
+ self.model, self.template, max_batch_size=max_batch_size, model_name=f'model-step{self.state.global_step}')
434
+ task_config = TaskConfig(
435
+ model=custom_model,
436
+ eval_type=EvalType.CUSTOM,
437
+ datasets=self.args.eval_datasets,
438
+ dataset_args=self.args.eval_datasets_args,
439
+ limit=self.args.eval_limit,
440
+ work_dir=os.path.join(self.args.output_dir, 'eval'),
441
+ eval_batch_size=max_batch_size,
442
+ generation_config=self.args.eval_generation_config or {'max_tokens': 512},
443
+ )
444
+ # start evaluation
445
+ eval_report = run_task(task_config)
446
+ # convert to dict
447
+ eval_dict = {f'test_{k}': v.score for k, v in eval_report.items()}
448
+ self.log(eval_dict)
449
+
450
+ self.model.train()
451
+ return eval_dict
452
+
453
+ def get_batch_samples(self, *args, **kwargs):
454
+ res = super().get_batch_samples(*args, **kwargs)
455
+ if self.template.sequence_parallel_size == 1:
456
+ return res
457
+
458
+ batch_samples, num_items_in_batch = res
459
+ if num_items_in_batch is None:
460
+ num_items_in_batch = torch.tensor(0).to(args[2])
461
+ from swift.trainers.sequence_parallel import sequence_parallel
462
+ dist.all_reduce(num_items_in_batch, dist.ReduceOp.SUM, sequence_parallel.sp_group)
463
+ return batch_samples, num_items_in_batch
464
+
465
+
466
class DataLoaderMixin:
    """Dataloader construction shared by swift trainers.

    When sequence parallelism is enabled, dataloader creation is delegated to the
    sequence-parallel implementation; otherwise custom shard/dispatch loaders are
    built for map-style and iterable datasets respectively.
    """

    def get_train_dataloader(self):
        """Build the training dataloader.

        Raises:
            ValueError: if no train_dataset is set.
        """
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            dataloader = sequence_parallel.get_dataloader(self, self.train_dataset, self._train_batch_size)
        if dataloader is None:
            # Higher efficiency
            if self.train_dataset is None:
                raise ValueError('Trainer: training requires a train_dataset.')
            args = self.args
            train_dataset = self.train_dataset

            dataloader_params = {
                'collate_fn': self.data_collator,
                'num_workers': args.dataloader_num_workers,
                'pin_memory': args.dataloader_pin_memory,
                'persistent_workers': args.dataloader_persistent_workers,
                'prefetch_factor': args.dataloader_prefetch_factor
            }
            batch_sampler_params = {
                'drop_last': args.dataloader_drop_last,
                'shuffle': args.train_dataloader_shuffle,
                'data_seed': args.data_seed,
            }

            if hasattr(train_dataset, '__len__'):
                # Map-style dataset: shard batches across ranks via the batch sampler.
                batch_sampler = BatchSamplerShard(
                    len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params)
                dataloader = DataLoaderShard(train_dataset, batch_sampler, **dataloader_params)
            else:
                # IterableDataset
                if dist.is_initialized() and dataloader_params['prefetch_factor']:
                    # Scale prefetch with world size since batches get dispatched across ranks.
                    dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
                dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params)
                dataloader = DataLoaderDispatcher(dataloader)

        return dataloader

    def get_eval_dataloader(self, eval_dataset=None):
        """Build the eval dataloader; defer to super() unless sequence parallelism is active.

        Raises:
            ValueError: if neither `eval_dataset` nor `self.eval_dataset` is set
                (sequence-parallel path only).
        """
        dataloader = None
        if self.template.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            if eval_dataset is None and self.eval_dataset is None:
                raise ValueError('Trainer: evaluation requires an eval_dataset.')
            eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
            dataloader = sequence_parallel.get_dataloader(self, eval_dataset, self.args.eval_batch_size)
        if dataloader is None:
            return super().get_eval_dataloader(eval_dataset=eval_dataset)
        return dataloader
ms-swift/swift/trainers/optimizers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
ms-swift/swift/trainers/optimizers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (180 Bytes). View file
 
ms-swift/swift/trainers/optimizers/galore/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
    # Static imports are only visible to type checkers / IDEs.
    from .utils import create_optimizer_and_scheduler, GaLoreConfig
    from .adafactor import GaLoreAdafactor
    from .adamw8bit import GaLoreAdamW8bit
    from .adamw import GaLoreAdamW
else:
    # Map submodule name -> exported symbols for lazy loading.
    _import_structure = {
        'utils': ['GaLoreConfig', 'create_optimizer_and_scheduler'],
        'adafactor': ['GaLoreAdafactor'],
        'adamw8bit': ['GaLoreAdamW8bit'],
        'adamw': ['GaLoreAdamW'],
    }

    import sys

    # Replace this module with a lazy proxy: submodules are imported on first
    # attribute access instead of at package import time.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
ms-swift/swift/trainers/optimizers/galore/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (765 Bytes). View file
 
ms-swift/swift/trainers/optimizers/galore/adafactor.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy dependencies from transformers/optimization.py
2
+ # code borrowed from https://github.com/jiaweizzhao/GaLore
3
+ import math
4
+
5
+ import torch
6
+ from torch.optim import Optimizer
7
+ from transformers.utils.versions import require_version
8
+
9
+ from .galore_projector import GaLoreProjector
10
+
11
+
12
class Adafactor(Optimizer):
    """
    AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
    this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
    `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.

    Arguments:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*):
            The external learning rate.
        eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
            Regularization constants for square gradient and parameter scale respectively
        clip_threshold (`float`, *optional*, defaults to 1.0):
            Threshold of root mean square of final gradient update
        decay_rate (`float`, *optional*, defaults to -0.8):
            Coefficient used to compute running averages of square
        beta1 (`float`, *optional*):
            Coefficient used for computing running averages of gradient
        weight_decay (`float`, *optional*, defaults to 0.0):
            Weight decay (L2 penalty)
        scale_parameter (`bool`, *optional*, defaults to `True`):
            If True, learning rate is scaled by root mean square
        relative_step (`bool`, *optional*, defaults to `True`):
            If True, time-dependent learning rate is computed instead of external learning rate
        warmup_init (`bool`, *optional*, defaults to `False`):
            Time-dependent learning rate computation depends on whether warm-up initialization is being used

    This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.

    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):

    - Training without LR warmup or clip_threshold is not recommended.

       - use scheduled LR warm-up to fixed LR
       - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
    - Disable relative updates
    - Use scale_parameter=False
    - Additional optimizer operations like gradient clipping should not be used alongside Adafactor

    Example:

    ```python
    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
    ```

    Others reported the following combination to work well:

    ```python
    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    ```

    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
    scheduler as following:

    ```python
    from transformers.optimization import Adafactor, AdafactorSchedule

    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
    lr_scheduler = AdafactorSchedule(optimizer)
    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
    ```

    Usage:

    ```python
    # replace AdamW with Adafactor
    optimizer = Adafactor(
        model.parameters(),
        lr=1e-3,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        relative_step=False,
        scale_parameter=False,
        warmup_init=False,
    )
    ```"""

    def __init__(
        self,
        params,
        lr=None,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.0,
        scale_parameter=True,
        relative_step=True,
        warmup_init=False,
    ):
        require_version('torch>=1.5.0')  # add_ with alpha
        if lr is not None and relative_step:
            raise ValueError('Cannot combine manual `lr` and `relative_step=True` options')
        if warmup_init and not relative_step:
            raise ValueError('`warmup_init=True` requires `relative_step=True`')

        defaults = {
            'lr': lr,
            'eps': eps,
            'clip_threshold': clip_threshold,
            'decay_rate': decay_rate,
            'beta1': beta1,
            'weight_decay': weight_decay,
            'scale_parameter': scale_parameter,
            'relative_step': relative_step,
            'warmup_init': warmup_init,
        }
        super().__init__(params, defaults)

    @staticmethod
    def _get_lr(param_group, param_state):
        # Effective LR: either the external LR or a time-dependent relative step,
        # optionally scaled by the parameter's RMS (`scale_parameter`).
        rel_step_sz = param_group['lr']
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
        param_scale = 1.0
        if param_group['scale_parameter']:
            param_scale = max(param_group['eps'][1], param_state['RMS'])
        return param_scale * rel_step_sz

    @staticmethod
    def _get_options(param_group, param_shape):
        # Factored second moment only applies to >=2-D parameters.
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        # Root mean square of a tensor.
        return tensor.norm(2) / (tensor.numel()**0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0

                # GaLore Projection: project the gradient into a low-rank subspace
                # before accumulating moments; 'rank' in the group enables GaLore.
                if 'rank' in group:
                    if 'projector' not in state:
                        state['projector'] = GaLoreProjector(
                            group['rank'],
                            update_proj_gap=group['update_proj_gap'],
                            scale=group['scale'],
                            proj_type=group['proj_type'])

                    grad = state['projector'].project(grad, state['step'])

                # NOTE: moments are allocated with the (possibly projected) grad shape.
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if 'RMS' not in state:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0
                else:
                    # Move persisted state onto the gradient's device/dtype.
                    if use_first_moment:
                        state['exp_avg'] = state['exp_avg'].to(grad)
                    if factored:
                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

                p_data_fp32 = p
                if p.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state['step'] += 1
                state['RMS'] = self._rms(p_data_fp32)
                lr = self._get_lr(group, state)

                # Time-dependent second-moment decay (beta2_t from the paper).
                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = (grad**2) + group['eps'][0]
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Clip by the RMS of the update, then scale by the effective LR.
                update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
                update.mul_(lr)

                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=(1 - group['beta1']))
                    update = exp_avg

                # GaLore Projection Back: lift the low-rank update to full rank.
                if 'rank' in group:
                    update = state['projector'].project_back(update)

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=(-group['weight_decay'] * lr))

                p_data_fp32.add_(-update)

                if p.dtype in {torch.float16, torch.bfloat16}:
                    # Copy the fp32 master update back into the low-precision parameter.
                    p.copy_(p_data_fp32)

        return loss


# Public alias used by the GaLore optimizer registry.
GaLoreAdafactor = Adafactor
ms-swift/swift/trainers/optimizers/galore/galore_projector.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code borrowed from https://github.com/jiaweizzhao/GaLore
2
+
3
+ import torch
4
+
5
+
6
class GaLoreProjector:
    """Projects full-rank gradients into a low-rank subspace (GaLore) and back.

    The orthogonal projector is recomputed from an SVD of the current gradient
    every ``update_proj_gap`` steps; between refreshes the cached matrix is reused.

    Args:
        rank: Rank of the low-rank subspace.
        verbose: Kept for API compatibility (not used in this implementation).
        update_proj_gap: Number of steps between SVD refreshes of the projector.
        scale: Multiplier applied to the gradient when projecting back.
        proj_type: One of ``'std'``, ``'reverse_std'``, ``'right'``, ``'left'``, ``'full'``.
    """

    def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type='std'):
        self.rank = rank
        self.verbose = verbose
        self.update_proj_gap = update_proj_gap
        self.scale = scale
        self.ortho_matrix = None  # lazily computed on the first project() call
        self.proj_type = proj_type

    def _needs_refresh(self, iter):
        # Refresh on first use and every `update_proj_gap` iterations.
        return self.ortho_matrix is None or iter % self.update_proj_gap == 0

    def project(self, full_rank_grad, iter):
        """Project `full_rank_grad` into the rank-`self.rank` subspace.

        Raises:
            ValueError: if `self.proj_type` is not a supported projection type.
        """
        if self.proj_type == 'std':
            # Project along the smaller matrix dimension to keep the projector small.
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
            else:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'reverse_std':
            if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
                low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
            else:
                if self._needs_refresh(iter):
                    self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
                low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'right':
            if self._needs_refresh(iter):
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='right')
            low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
        elif self.proj_type == 'left':
            if self._needs_refresh(iter):
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='left')
            low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
        elif self.proj_type == 'full':
            if self._needs_refresh(iter):
                self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type='full')
            low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
        else:
            # BUGFIX: previously an unknown proj_type fell through and raised a
            # confusing NameError on the return below; fail fast with a clear error.
            raise ValueError(f'proj_type should be std, reverse_std, right, left or full, got {self.proj_type!r}')

        return low_rank_grad

    def project_back(self, low_rank_grad):
        """Lift `low_rank_grad` back to full rank and apply `self.scale`.

        Raises:
            ValueError: if `self.proj_type` is not a supported projection type.
        """
        if self.proj_type == 'std':
            if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
            else:
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'reverse_std':
            if low_rank_grad.shape[0] <= low_rank_grad.shape[1]:  # note this is different from std
                full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
            else:
                full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'right':
            full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
        elif self.proj_type == 'left':
            full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
        elif self.proj_type == 'full':
            full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
        else:
            # BUGFIX: same NameError fall-through as in project(); raise explicitly.
            raise ValueError(f'proj_type should be std, reverse_std, right, left or full, got {self.proj_type!r}')

        return full_rank_grad * self.scale

    # svd decomposition
    def get_orthogonal_matrix(self, weights, rank, type):
        """Return the rank-truncated orthogonal factor(s) of `weights` via SVD.

        `type='right'` returns the top-`rank` right singular vectors, `'left'`
        the left ones, `'full'` both as a list `[A, B]`. The result is cast back
        to the input's original dtype/device when it was not float32.

        Raises:
            ValueError: if `type` is not 'left', 'right' or 'full'.
        """
        module_params = weights

        # SVD is computed in float32; remember the original dtype/device to cast back.
        if module_params.data.dtype != torch.float:
            float_data = False
            original_type = module_params.data.dtype
            original_device = module_params.data.device
            matrix = module_params.data.float()
        else:
            float_data = True
            matrix = module_params.data

        U, s, Vh = torch.linalg.svd(matrix, full_matrices=False)

        # make the smaller matrix always to be orthogonal matrix
        if type == 'right':
            B = Vh[:rank, :]
            if not float_data:
                B = B.to(original_device).type(original_type)
            return B
        elif type == 'left':
            A = U[:, :rank]
            if not float_data:
                A = A.to(original_device).type(original_type)
            return A
        elif type == 'full':
            A = U[:, :rank]
            B = Vh[:rank, :]
            if not float_data:
                A = A.to(original_device).type(original_type)
                B = B.to(original_device).type(original_type)
            return [A, B]
        else:
            raise ValueError('type should be left, right or full')
ms-swift/swift/trainers/optimizers/galore/utils.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import importlib
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Tuple, Union
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.optim import Optimizer
9
+ from transformers import Trainer, TrainingArguments, get_scheduler
10
+
11
+ from swift.utils import get_logger
12
+
13
+ try:
14
+ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
15
+ except ImportError:
16
+ from torch.optim.lr_scheduler import LRScheduler
17
+
18
+ logger = get_logger()
19
+
20
+
21
@dataclass
class GaLoreConfig:
    """
    The configuration class for the GaLore (Gradient Low-Rank Projection) module.

    See https://arxiv.org/abs/2403.03507

    Args:
        rank (`int`): The galore rank
        target_modules (`Union[str, List[str]]`): The target modules to use, if `None`,
            will use all attn and mlp linears
        update_proj_gap(`int`): The projection update interval for galore
        proj_type(`str`) The project type of Galore, valid values are `std`,
            `reverse_std`, `right`, `left`, `full`
        galore_scale(float): the scale of gradient
        optim_per_parameter(bool): Gives one optimizer per parameter
        quantize(bool): Use Q-GaLore (requires the `q_galore_torch` package)
        proj_quant(bool): Quantize the projection matrices (Q-GaLore only)
        proj_bits(int): Bit width for projection quantization (Q-GaLore only)
        proj_group_size(int): Group size for projection quantization (Q-GaLore only)
        cos_threshold(float): Q-GaLore cosine-similarity threshold -- presumably
            gates lazy projection updates; TODO confirm against q_galore_torch docs
        gamma_proj(int): Q-GaLore projection-interval multiplier -- TODO confirm
        queue_size(int): Q-GaLore similarity-history queue size -- TODO confirm
    """
    rank: int = 128
    # None is a valid value (meaning "all attn and mlp linears"), so include it
    # in the annotation.
    target_modules: Union[str, List[str], None] = None
    update_proj_gap: int = 50
    galore_scale: float = 1.0
    proj_type: str = 'std'
    optim_per_parameter: bool = False
    quantize: bool = False
    proj_quant: bool = False
    proj_bits: int = 4
    proj_group_size: int = 256
    cos_threshold: float = 0.4
    gamma_proj: int = 2
    queue_size: int = 5
52
+
53
+
54
class GaloreOptimizerWrapper(Optimizer):
    """Present a collection of per-parameter optimizers as one ``Optimizer``."""

    def __init__(self, optimizers: Dict[Any, Optimizer]):
        self.optimizers = optimizers
        # The base class refuses to construct without at least one parameter
        # group, so hand it a throwaway tensor and lr.
        super().__init__([torch.tensor([1., 2., 3.])], {'lr': 1.})

    def _dispatch(self, method_name: str, *args, **kwargs) -> None:
        # Forward the call to every wrapped optimizer in turn.
        for wrapped in self.optimizers.values():
            getattr(wrapped, method_name)(*args, **kwargs)

    def zero_grad(self, *args, **kwargs) -> None:
        self._dispatch('zero_grad', *args, **kwargs)

    def step(self, *args, **kwargs) -> None:
        self._dispatch('step', *args, **kwargs)
67
+
68
+
69
class GaloreSchedulerWrapper(LRScheduler):
    """Drive a collection of per-parameter LR schedulers in lockstep."""

    def __init__(self, lr_schedulers: Dict[Any, LRScheduler]):
        # Deliberately skip LRScheduler.__init__: there is no single optimizer
        # to bind; we only need the step()/get_last_lr() surface.
        self.lr_schedulers = lr_schedulers

    def step(self, *args, **kwargs) -> None:
        for sched in self.lr_schedulers.values():
            sched.step(*args, **kwargs)
            # Track the most recent scheduler's LR so get_last_lr() works.
            self._last_lr = sched.get_last_lr()
78
+
79
+
80
def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, config: GaLoreConfig, max_steps,
                                   **defaults):
    """Build the GaLore optimizer and LR scheduler for ``model``.

    Trainable ``nn.Linear``/``nn.Embedding`` weights whose module name matches
    ``config.target_modules`` go into a GaLore parameter group; all other
    trainable parameters are optimized conventionally (with/without weight
    decay split the same way HF ``Trainer`` does).

    Args:
        model: The model to train.
        args: HF ``TrainingArguments`` (optimizer type, lr scheduler, warmup).
        config: GaLore settings (rank, update interval, scale, quantization).
        max_steps: Total number of training steps for the scheduler.
        **defaults: Extra optimizer defaults; must contain 'weight_decay'.

    Returns:
        A ``(optimizer, lr_scheduler)`` tuple. When
        ``config.optim_per_parameter`` is set (and quantization is off) these
        are per-parameter wrappers; otherwise a single optimizer/scheduler.
    """
    galore_params = []
    for module_name, module in model.named_modules():
        # Only Linear/Embedding weights whose module path matches a target key
        # are projected by GaLore.
        if not isinstance(module, (nn.Linear, nn.Embedding)) or \
                not any(target_key in module_name for target_key in config.target_modules):
            continue

        if not module.weight.requires_grad:
            continue

        logger.info(f'Enable GaLore for weights in module: {module_name}')
        galore_params.append(module.weight)

    # Parameters are matched by id() below, since tensors are not hashable by value.
    id_galore_params = [id(p) for p in galore_params]
    galore_defaults = {
        'rank': config.rank,
        'update_proj_gap': config.update_proj_gap,
        'scale': config.galore_scale,
        'proj_type': config.proj_type,
        **defaults
    }
    if config.quantize:
        # Q-GaLore-specific group options, consumed by q_galore_torch optimizers.
        galore_defaults['quant'] = config.proj_quant
        galore_defaults['quant_n_bit'] = config.proj_bits
        galore_defaults['quant_group_size'] = config.proj_group_size
        galore_defaults['cos_threshold'] = config.cos_threshold
        galore_defaults['gamma_proj'] = config.gamma_proj
        galore_defaults['queue_size'] = config.queue_size
    optim_cls, optim_kwargs = get_optimizer(args, config)

    if config.optim_per_parameter and not config.quantize:
        # q-galore does not support optim_per_parameter
        optimizer_dict = {}
        # Each per-parameter optimizer sees a step per backward hook, so the
        # projection interval (and scheduler horizon below) is doubled.
        galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
        for p in model.parameters():
            if p.requires_grad:
                if id(p) in id_galore_params:
                    optimizer_dict[p] = optim_cls([{'params': [p], **galore_defaults}], **optim_kwargs)
                else:
                    optimizer_dict[p] = optim_cls([{'params': [p], **defaults}], **optim_kwargs)

        # get scheduler dict
        scheduler_dict = {}
        for p in model.parameters():
            if p.requires_grad:
                scheduler_dict[p] = get_scheduler(
                    optimizer=optimizer_dict[p],
                    name=args.lr_scheduler_type,
                    num_training_steps=max_steps * 2,
                    num_warmup_steps=args.warmup_steps * 2,
                    scheduler_specific_kwargs=args.lr_scheduler_kwargs,
                )

        return GaloreOptimizerWrapper(optimizer_dict), GaloreSchedulerWrapper(scheduler_dict)
    else:
        # Mirror HF Trainer's decay/no-decay split for the non-GaLore params.
        decay_parameters = Trainer.get_decay_parameter_names(Trainer, model)
        param_groups = [{
            'params': galore_params,
            **galore_defaults,
        }]
        param_groups.extend([
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay':
                defaults['weight_decay'],
            },
            {
                'params': [
                    p for n, p in model.named_parameters()
                    if (n not in decay_parameters and id(p) not in id_galore_params and p.requires_grad)
                ],
                'weight_decay':
                0.0,
            },
        ])
        optim = optim_cls(param_groups, **optim_kwargs)
        scheduler = get_scheduler(
            optimizer=optim,
            name=args.lr_scheduler_type,
            num_training_steps=max_steps,
            num_warmup_steps=args.warmup_steps,
            scheduler_specific_kwargs=args.lr_scheduler_kwargs,
        )
        return optim, scheduler
168
+
169
+
170
def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
    """Resolve the GaLore optimizer class and its constructor kwargs.

    Maps ``args.optim`` onto the matching GaLore optimizer implementation
    (Adafactor, AdamW, AdamW-8bit, or Q-GaLore AdamW-8bit when
    ``config.quantize`` is set).

    Args:
        args: HF ``TrainingArguments``; ``optim``, ``optim_args``, lr and Adam
            hyper-parameters are read.
        config: GaLore configuration; only ``quantize`` is consulted here.

    Returns:
        ``(optimizer_cls, optimizer_kwargs)`` ready for instantiation.

    Raises:
        ValueError: If ``args.optim`` has no GaLore counterpart, or the 8-bit
            backend (bitsandbytes) is missing.
    """
    # parse args.optim_args: a "k1=v1,k2=v2" string (values stay strings here).
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    optimizer_kwargs = {'lr': args.learning_rate}

    adam_kwargs = {
        'betas': (args.adam_beta1, args.adam_beta2),
        'eps': args.adam_epsilon,
    }
    if args.optim == 'adafactor':
        from .adafactor import GaLoreAdafactor
        optimizer_cls = GaLoreAdafactor
        optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
    elif args.optim in ('adamw_hf', 'adamw_torch'):
        if config.quantize:
            # Q-GaLore path: requires the external q_galore_torch package.
            assert importlib.util.find_spec('q_galore_torch') is not None, \
                'Please install q-galore by `pip install q_galore_torch`'
            logger.info('If you encounter `absmax2` error, please downgrade your bitsandbytes to 0.40.0')
            from swift.utils import get_dist_setting
            _, _, world_size, _ = get_dist_setting()
            # Both branches currently resolve to the same class; the simulated
            # multi-GPU variant is kept commented for reference.
            if world_size > 1:
                # from q_galore_torch import QGaLoreAdamW8bit_simulate as GaLoreAdamW
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
            else:
                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
        else:
            from .adamw import GaLoreAdamW
        optimizer_cls = GaLoreAdamW
        optimizer_kwargs.update(adam_kwargs)
    elif 'adamw' in args.optim and '8bit' in args.optim:
        try:
            from .adamw8bit import GaLoreAdamW8bit
            optimizer_cls = GaLoreAdamW8bit
            optimizer_kwargs.update(adam_kwargs)
            optimizer_kwargs.update({'optim_bits': 8, 'is_paged': 'paged' in args.optim})
        except ImportError:
            raise ValueError('Trainer tried to instantiate bnb optimizer but bnb is not installed!')
    else:
        raise ValueError(f'Galore not supported for optimizer type: {args.optim}')
    return optimizer_cls, optimizer_kwargs
ms-swift/swift/trainers/rlhf_arguments.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ from trl import CPOConfig as HfCPOConfig
5
+ from trl import DPOConfig as HfDPOConfig
6
+ from trl import GRPOConfig as HfGRPOConfig
7
+ from trl import KTOConfig as HfKTOConfig
8
+ from trl import ORPOConfig as HfORPOConfig
9
+ from trl import PPOConfig as HfPPOConfig
10
+ from trl import RewardConfig as HfRewardConfig
11
+
12
+ from .arguments import GRPOArgumentsMixin, SwiftArgumentsMixin
13
+
14
+
15
@dataclass
class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
    """trl ``DPOConfig`` extended with Swift's shared trainer arguments."""
    pass


@dataclass
class CPOConfig(SwiftArgumentsMixin, HfCPOConfig):
    """trl ``CPOConfig`` extended with Swift's shared trainer arguments."""
    pass


@dataclass
class ORPOConfig(SwiftArgumentsMixin, HfORPOConfig):
    """trl ``ORPOConfig`` extended with Swift's shared trainer arguments."""
    pass


@dataclass
class KTOConfig(SwiftArgumentsMixin, HfKTOConfig):
    """trl ``KTOConfig`` extended with Swift's shared trainer arguments."""
    pass


@dataclass
class RewardConfig(SwiftArgumentsMixin, HfRewardConfig):
    """trl ``RewardConfig`` extended with Swift's shared trainer arguments."""
    pass


@dataclass
class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
    """trl ``PPOConfig`` extended with Swift's shared trainer arguments."""
    pass
43
+
44
+
45
@dataclass
class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
    # Extra stop words for generation, on top of trl's GRPOConfig fields.
    stop_words: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Normalize derived fields and apply known-issue workarounds."""
        from swift.llm.argument.base_args.model_args import ModelArguments
        super().__post_init__()
        # Default the cosine reward's max length to the completion length.
        if self.cosine_max_len is None:
            self.cosine_max_len = self.max_completion_length
        # Accept either a dict or its string form for the vLLM mm limit.
        self.vllm_limit_mm_per_prompt = ModelArguments.parse_to_dict(self.vllm_limit_mm_per_prompt)

        if self.deepspeed and 'zero_optimization' in self.deepspeed and self.deepspeed['zero_optimization'][
                'stage'] == 3:
            # https://github.com/modelscope/ms-swift/issues/3237
            # Disable ZeRO-3 prefetching in both the raw config and the plugin copy.
            self.deepspeed['zero_optimization']['stage3_prefetch_bucket_size'] = 0
            self.deepspeed_plugin.hf_ds_config.config['zero_optimization']['stage3_prefetch_bucket_size'] = 0

        # https://github.com/modelscope/ms-swift/issues/3863
        self.dataloader_drop_last = True
ms-swift/swift/trainers/rlhf_trainer/kto_trainer.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from contextlib import contextmanager
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from peft import PeftModel
8
+ from transformers import PreTrainedModel
9
+ from trl import KTOTrainer as HFKTOTrainer
10
+
11
+ from swift.utils import get_logger
12
+ from ..mixin import SwiftMixin
13
+ from .rlhf_mixin import RLHFTrainerMixin
14
+
15
+ logger = get_logger()
16
+
17
+ del HFKTOTrainer.__init__
18
+
19
+
20
class KTOTrainer(RLHFTrainerMixin, SwiftMixin, HFKTOTrainer):
    """trl ``KTOTrainer`` wired into Swift's template/mixin stack."""

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
                 *_args,
                 **kwargs):
        args = kwargs['args']
        # KTO always disables dropout for both policy and reference model.
        args.disable_dropout = True
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight
        self.precompute_ref_log_probs = args.precompute_ref_log_probs
        self.is_peft_model = isinstance(model, PeftModel)
        if hasattr(args, 'loss_type'):
            self.loss_type = args.loss_type
        else:
            self.loss_type = 'kto'

        self.ref_adapter_name = None
        # Not all losses require a KL calculation
        self.calculate_KL = True
        if self.loss_type in ['apo_zero_unpaired']:
            self.calculate_KL = False
        super().__init__(model, ref_model, *_args, **kwargs)

    def forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the parent forward while swapping KL/completion inputs per call.

        A forward pre-hook rewrites the model kwargs wholesale from ``batch``:
        odd-numbered model invocations get the ``KL_completion_*`` tensors,
        even-numbered ones the ``completion_*`` tensors (prefixes stripped).
        NOTE(review): this assumes the superclass invokes the model for the KL
        batch first — confirm against the installed trl version.
        """
        is_kl = True

        def _add_data_hook(model, args, kwargs):
            nonlocal is_kl
            # Replace (not merge) the kwargs with the matching slice of `batch`.
            if is_kl:
                kwargs = {k[len('KL_completion_'):]: v for k, v in batch.items() if k.startswith('KL_completion_')}
            else:
                kwargs = {k[len('completion_'):]: v for k, v in batch.items() if k.startswith('completion_')}
            is_kl = not is_kl
            return (), kwargs

        @contextmanager
        def _patch_model_call():
            # prepend=True so our kwargs rewrite runs before any other hook.
            handle = model.register_forward_pre_hook(_add_data_hook, with_kwargs=True, prepend=True)

            try:
                yield
            finally:
                handle.remove()

        with _patch_model_call():
            return super().forward(model, batch)
ms-swift/swift/trainers/rlhf_trainer/orpo_trainer.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Optional, Union
3
+
4
+ import torch.nn as nn
5
+ from transformers import PreTrainedModel
6
+ from trl import ORPOTrainer as HFORPOTrainer
7
+
8
+ from ..mixin import SwiftMixin
9
+ from .rlhf_mixin import RLHFTrainerMixin
10
+
11
+ del HFORPOTrainer.__init__
12
+
13
+
14
class ORPOTrainer(RLHFTrainerMixin, SwiftMixin, HFORPOTrainer):
    """ORPO trainer: reference-free preference optimization."""

    def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, *_args, **kwargs):
        # ORPO folds the preference signal into the policy loss itself, so a
        # reference model must not be supplied.
        reference = kwargs.get('ref_model')
        assert reference is None, 'ORPO does not require a ref_model.'
        super().__init__(model, *_args, **kwargs)
ms-swift/swift/trainers/rlhf_trainer/ppo_trainer.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import inspect
3
+ from contextlib import contextmanager
4
+
5
+ import transformers
6
+ from packaging import version
7
+ from torch.utils.data import DataLoader
8
+ from transformers import PreTrainedModel
9
+ from trl import PPOTrainer as HFPPOTrainer
10
+
11
+ from swift.utils import patch_getattr
12
+ from ..mixin import SwiftMixin
13
+
14
+ ppo_trainer_init = HFPPOTrainer.__init__
15
+ del HFPPOTrainer.__init__
16
+
17
+
18
class PPOTrainer(SwiftMixin, HFPPOTrainer):
    """trl ``PPOTrainer`` adapted to Swift's trainer stack."""

    @staticmethod
    @contextmanager
    def _patch_dataloader(collate_fn):
        """Force every ``DataLoader`` constructed in scope to use ``collate_fn``.

        trl's PPOTrainer builds its dataloaders internally; this global patch
        injects Swift's data collator into them, then restores the original
        constructor.
        """
        __init__ = DataLoader.__init__

        def __new_init__(self, *args, **kwargs):
            kwargs['collate_fn'] = collate_fn
            __init__(self, *args, **kwargs)

        DataLoader.__init__ = __new_init__
        try:
            yield
        finally:
            DataLoader.__init__ = __init__

    def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs):
        # First run SwiftMixin's __init__ (minus PPO-only kwargs), then replay
        # the saved original trl __init__ with version-compatible kwarg names.
        super().__init__(model, *_args, **{k: v for k, v in kwargs.items() if k not in {'reward_model', 'value_model'}})
        with self._patch_dataloader(kwargs['data_collator']):
            new_kwargs = {
                k: v
                for k, v in kwargs.items()
                if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset']
            }
            # trl renamed `config`->`args` and `tokenizer`->`processing_class`
            # across versions; inspect the saved __init__ to support both.
            parameters = inspect.signature(ppo_trainer_init).parameters
            if 'config' in parameters:
                new_kwargs['config'] = kwargs['args']
            else:
                new_kwargs['args'] = kwargs['args']
            if 'processing_class' in parameters:
                new_kwargs['processing_class'] = self.tokenizer
            else:
                new_kwargs['tokenizer'] = self.tokenizer
            ppo_trainer_init(self, model=model, ref_model=ref_model, **new_kwargs)
        unwrap_model = self.accelerator.unwrap_model(self.model)
        # Delegate attribute lookups on the wrapper model to its `policy` submodule.
        patch_getattr(unwrap_model.__class__, 'policy')

    def train(self, *args, **kwargs):
        # remove args that are not needed for the HFPPOTrainer
        super().train()

    def _save_checkpoint(self, *args, **kwargs):
        # transformers>=4.47 expects best-metric bookkeeping before saving;
        # trl's loop does not do it, so handle it here.
        if version.parse(transformers.__version__) >= version.parse('4.47'):
            metrics = kwargs.pop('metrics', None)
            trial = kwargs.get('trial')
            self._determine_best_metric(metrics=metrics, trial=trial)
        return super()._save_checkpoint(*args, **kwargs)
ms-swift/swift/trainers/rlhf_trainer/reward_trainer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from collections import defaultdict
3
+ from typing import Any, Dict, Tuple, Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn as nn
8
+ from accelerate.utils import gather_object
9
+ from transformers import PreTrainedModel
10
+ from trl import RewardTrainer as HFRewardTrainer
11
+ from trl.trainer.utils import print_rich_table
12
+
13
+ from ..mixin import SwiftMixin
14
+ from .rlhf_mixin import RLHFTrainerMixin
15
+
16
+ del HFRewardTrainer.__init__
17
+
18
+
19
class RewardTrainer(RLHFTrainerMixin, SwiftMixin, HFRewardTrainer):
    """Pairwise (Bradley-Terry style) reward-model trainer."""

    def compute_loss(self,
                     model: Union[PreTrainedModel, nn.Module],
                     inputs: Dict[str, Union[torch.Tensor, Any]],
                     return_outputs=False,
                     num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Logistic pairwise loss; the batch is chosen rows then rejected rows."""
        inputs.pop('labels', None)  # not use
        attention_mask = inputs['attention_mask']
        # The collator concatenates chosen and rejected along the batch dim.
        batch_size = attention_mask.shape[0] // 2
        rewards = model(**inputs).logits
        rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0)
        if 'margin' in inputs:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs['margin']).mean()
        else:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if self.args.center_rewards_coefficient is not None:
            # Penalty keeping chosen+rejected rewards centered around zero.
            loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected)**2)
        # compat transformers>=4.46.*
        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss /= self.args.gradient_accumulation_steps
        if return_outputs:
            return loss, {
                'rewards_chosen': rewards_chosen,
                'rewards_rejected': rewards_rejected,
            }
        return loss

    def visualize_samples(self, num_print_samples: int):
        """
        Visualize the reward model logits prediction

        Args:
            num_print_samples (`int`, defaults to `4`):
                The number of samples to print. Set to `-1` to print all samples.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)
        for _, inputs in enumerate(eval_dataloader):
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            # Index of the last non-padding token per row (assumes right padding
            # -- TODO confirm against the collator).
            sequence_lengths = ((torch.eq(attention_mask, 0).int().argmax(-1) - 1) % attention_mask.shape[1]).tolist()
            text = [self.template.safe_decode(tokens[:sequence_lengths[i]]) for i, tokens in enumerate(input_ids)]
            batch_size = input_ids.shape[0] // 2
            chosen_text, rejected_text = text[:batch_size], text[batch_size:]
            table['chosen_text'].extend(gather_object(chosen_text))
            table['rejected_text'].extend(gather_object(rejected_text))
            table['logits'].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]))
            if 0 <= num_print_samples <= len(table['chosen_text']):
                break
        df = pd.DataFrame(table)
        # Only the main process prints/logs the table.
        if self.accelerator.process_index == 0:
            print_rich_table(df[:num_print_samples])
            if 'wandb' in self.args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({'completions': wandb.Table(dataframe=df)})
ms-swift/swift/trainers/rlhf_trainer/rlhf_mixin.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from collections import defaultdict
3
+ from contextlib import contextmanager, nullcontext
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import PreTrainedModel
9
+ from transformers.integrations import is_deepspeed_zero3_enabled
10
+
11
+ try:
12
+ from trl import AutoModelForCausalLMWithValueHead
13
+ except (ImportError, RuntimeError):
14
+ AutoModelForCausalLMWithValueHead = None
15
+
16
+
17
class RLHFTrainerMixin:
    """Shared glue between Swift's templates and trl's RLHF trainers.

    Sets the attributes trl trainers expect (beta, encoder-decoder flags,
    padding ids, ...) and patches ``concatenated_forward`` so the model is run
    exactly once per batch.
    """

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 *_args,
                 **kwargs):
        from trl.trainer import disable_dropout_in_model
        from swift.llm import HfConfigFactory
        self.ref_model = ref_model
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        args = kwargs['args']
        self.beta = getattr(args, 'beta', 0.0)
        if getattr(args, 'disable_dropout', False):
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.is_encoder_decoder = kwargs['template'].is_encoder_decoder
        self.aux_loss_enabled = getattr(model.config, 'output_router_logits', False)
        self._peft_has_been_casted_to_bf16 = False
        self.generate_during_eval = getattr(args, 'generate_during_eval', False)
        if self.is_encoder_decoder:
            self.decoder_start_token_id = HfConfigFactory.get_config_attr(model.config, 'decoder_start_token_id')
            self.pad_token_id = HfConfigFactory.get_config_attr(model.config, 'pad_token_id')
        # not use
        self.is_vision_model = False
        self.label_pad_token_id = -100
        self.use_dpo_data_collator = True
        super().__init__(model, *_args, **kwargs)
        if is_deepspeed_zero3_enabled() and ref_model is not None:
            try:
                from trl.models.utils import prepare_deepspeed
            except ImportError as e:
                raise ImportError('Please install trl>=0.14 via `pip install "trl>=0.14"`') from e
            prepare_deepspeed(self.ref_model, self.accelerator)  # Does not wrap DeepSpeedEngine
        self.padding_value = self.tokenizer.pad_token_id

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Run the model once, then replay its outputs through the parent.

        The parent's ``concatenated_forward`` would re-tokenize and re-run the
        model; here both ``concatenated_inputs`` and ``model.__call__`` are
        temporarily patched to return precomputed values instead.
        """
        model_kwargs = batch.copy()
        labels = model_kwargs.pop('labels', None)
        if self.is_encoder_decoder:
            model_kwargs['labels'] = labels

        if self.aux_loss_enabled:
            model_kwargs['output_router_logits'] = True
        outputs = model(**model_kwargs, use_cache=False)
        model_kwargs['labels'] = labels
        model_kwargs['chosen_labels'] = torch.zeros(model_kwargs['labels'].shape[0] // 2)  # just get shape
        if outputs.logits.shape[1] != labels.shape[1]:
            # for llava, the model returns logits for the entire sequence, including the image tokens
            # (placed before the text tokens)
            outputs.logits = outputs.logits[:, -labels.shape[1]:]
        # Rename keys to the `concatenated_*` names the parent expects.
        for key in ['input_ids', 'attention_mask', 'labels']:
            model_kwargs[f'concatenated_{key}'] = model_kwargs.pop(key, None)
        if self.__class__.__name__ == 'ORPOTrainer':  # Pass-through labels
            model_kwargs['concatenated_input_ids'] = model_kwargs['concatenated_labels']

        @contextmanager
        def _patch_concatenated_forward():
            _old_concatenated_inputs = self.concatenated_inputs
            _old_model_call = model.__class__.__call__
            self.concatenated_inputs = lambda *args, **kwargs: model_kwargs
            model.__class__.__call__ = lambda *args, **kwargs: outputs
            try:
                yield
            finally:
                self.concatenated_inputs = _old_concatenated_inputs
                model.__class__.__call__ = _old_model_call

        with _patch_concatenated_forward():
            return super().concatenated_forward(model, model_kwargs)

    def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *args, **kwargs):
        # The parent mutates `labels` in place for encoder-decoder models.
        if self.is_encoder_decoder:
            labels = labels.clone()  # fix trl bug
        return super().get_batch_logps(logits, labels, *args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        res = super().compute_loss(model, inputs, return_outputs=return_outputs)
        # compat transformers>=4.46.*
        if num_items_in_batch is not None and self.model_accepts_loss_kwargs:
            loss = res[0] if return_outputs else res
            loss /= self.args.gradient_accumulation_steps
            return (loss, res[1:]) if return_outputs else loss
        return res
ms-swift/swift/trainers/rlhf_trainer/utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ from types import MethodType
3
+ from typing import Any, List, Optional
4
+
5
+ import torch
6
+ from peft.tuners import lora
7
+ from peft.tuners.lora import LoraLayer
8
+
9
+
10
def round_robin(num_reqs, num_workers):
    """Distribute requests evenly across workers using round-robin algorithm.

    Args:
        num_reqs (int): Total number of requests to distribute
        num_workers (int): Number of available workers

    Returns:
        list: A list of lists where each sublist contains the request indices
            assigned to that particular worker, in increasing order
    """
    # Worker w receives indices w, w + num_workers, w + 2*num_workers, ...
    return [list(range(worker_id, num_reqs, num_workers)) for worker_id in range(num_workers)]
26
+
27
+
28
@contextmanager
def patch_lora_merge(model, parameter_group=None):
    """Patch LoraLayer's merge and get_delta_weight methods for controlled merging.

    Args:
        model: The PEFT model to patch
        parameter_group: Optional list of parameter names to restrict merging

    Yields:
        The patched model (context manager ensures cleanup)
    """
    from peft.tuners.tuners_utils import check_adapters_to_merge

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        # Skip layers outside the requested parameter group entirely.
        if parameter_group and all(self.name not in pg for pg in parameter_group):
            return  # Skip if not in target parameter group
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        # Move DoRA magnitude vectors onto the base layer's device before the
        # real merge runs.
        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if self.use_dora.get(active_adapter, False):
                    self.lora_magnitude_vector[active_adapter].weight.data = \
                        self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device)

        return self.merge_origin(safe_merge, adapter_names)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        # Ensure tensors are on correct device
        if isinstance(self, lora.Embedding):
            self.lora_embedding_A[adapter].data = self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device)
            self.lora_embedding_B[adapter].data = self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device)
        else:
            self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device)
            self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device)
        return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device)

    def _cache_pop(self, key: str) -> Any:
        # Cached tensors may live on another device; move on retrieval.
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    # Patch all LoraLayer instances
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.name = name
            # Only patch once; `*_origin` attributes double as the undo record.
            if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'):
                module.merge_origin = module.merge
                module.merge = MethodType(merge, module)
                module.get_delta_weight_origin = module.get_delta_weight
                module.get_delta_weight = MethodType(get_delta_weight, module)
                module._cache_pop_origin = module._cache_pop
                module._cache_pop = MethodType(_cache_pop, module)

    try:
        yield model
    finally:
        # Cleanup: restore original methods
        for module in model.modules():
            if isinstance(module, LoraLayer):
                if hasattr(module, 'merge_origin'):
                    module.merge = module.merge_origin
                    del module.merge_origin
                    module.get_delta_weight = module.get_delta_weight_origin
                    del module.get_delta_weight_origin
                    module._cache_pop = module._cache_pop_origin
                    del module._cache_pop_origin
96
+
97
+
98
@contextmanager
def patch_lora_unmerge(model):
    """Patch the unmerge method to ensure proper device handling."""

    def _cache_pop_patched(self, key: str) -> Any:
        # Cached tensors may live on another device; move on retrieval.
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    def unmerge_patched(self):
        if not self.merged:
            return
        # Move magnitude vectors to correct device first
        for adapter in list(self.merged_adapters):
            if self.use_dora.get(adapter, False):
                self.lora_magnitude_vector[adapter].weight.data = \
                    self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)

        return self.unmerge_origin()

    # Patch once per layer; `unmerge_origin` doubles as the undo record.
    for module in model.modules():
        if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'):
            module.unmerge_origin = module.unmerge
            module.unmerge = MethodType(unmerge_patched, module)
            module._cache_pop_origin = module._cache_pop
            module._cache_pop = MethodType(_cache_pop_patched, module)

    try:
        yield model
    finally:
        # Restore original methods on exit.
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'):
                module.unmerge = module.unmerge_origin
                del module.unmerge_origin
                module._cache_pop = module._cache_pop_origin
                del module._cache_pop_origin
ms-swift/swift/trainers/rlhf_trainer/vllm_client.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ # Code partially sourced from Hugging Face TRL
4
+
5
+ import atexit
6
+ import logging
7
+ import time
8
+ from typing import List, Optional
9
+
10
+ import requests
11
+ import torch
12
+ from dacite import from_dict
13
+ from requests import ConnectionError
14
+ from torch import nn
15
+
16
+ from swift.llm import AdapterRequest, InferRequest, Template
17
+ from swift.llm.infer.protocol import ChatCompletionResponse, RequestConfig
18
+ from swift.plugin import Metric
19
+ from swift.utils import is_vllm_ascend_available, is_vllm_available
20
+
21
+ if is_vllm_available():
22
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
23
+ from vllm.distributed.utils import StatelessProcessGroup
24
+
25
+ if is_vllm_ascend_available():
26
+ from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator # noqa
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class VLLMClient:
    """
    A client class to interact with a vLLM server.

    This class provides methods to infer completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            IP address of the vLLM server.
        server_port (`int`, *optional*, defaults to `8000`):
            Port number of the vLLM server.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
            timeout, a `ConnectionError` is raised.
    """

    def __init__(self,
                 host: str = '0.0.0.0',
                 server_port: int = 8000,
                 group_port: int = 51216,
                 connection_timeout: float = 0.0):
        if not is_vllm_available():
            raise ImportError('vLLM is not installed. Please install it with `pip install vllm`.')

        # Reusable HTTP session for the request endpoints (connection pooling).
        self.session = requests.Session()
        self.host = host
        self.server_port = server_port
        self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
        Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
        """
        url = f'http://{self.host}:{self.server_port}/health/'
        start_time = time.time()  # Record the start time

        while True:
            try:
                # NOTE(review): uses a fresh `requests.get` rather than `self.session`;
                # presumably intentional (the session may not exist yet during __init__'s
                # call path) — verify before changing.
                response = requests.get(url)
            except requests.exceptions.RequestException as exc:
                # Check if the total timeout duration has passed
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(
                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
                        'seconds. Make sure the server is running by running `swift deploy`.') from exc
            else:
                # Only reached when the GET itself did not raise.
                if response.status_code == 200:
                    logger.info('Server is up!')
                    return None

            # Retry logic: wait before trying again
            logger.info(f'Server is not up yet. Retrying in {retry_interval} seconds...')
            time.sleep(retry_interval)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
        adapter_request: Optional[AdapterRequest] = None,
    ):
        """Send inference requests to the server's `/infer/` endpoint.

        Returns a list of `ChatCompletionResponse`, one per request; raises on any
        non-200 response.
        """
        url = f'http://{self.host}:{self.server_port}/infer/'
        # NOTE(review): the payload objects are passed directly to `json=`; this
        # assumes they are JSON-serializable by `requests` — confirm against the
        # server-side protocol types.
        response = self.session.post(
            url,
            json={
                'infer_requests': infer_requests,
                'request_config': request_config,
                'metrics': metrics,
                'template': template,
                'use_tqdm': use_tqdm,
                'adapter_request': adapter_request,
            },
        )
        if response.status_code == 200:
            return [from_dict(data_class=ChatCompletionResponse, data=resp) for resp in response.json()]
        else:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

    def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
        """
        # Get the tensor parallel size from the server
        url = f'http://{self.host}:{self.server_port}/get_world_size/'
        response = requests.get(url)
        if response.status_code == 200:
            vllm_world_size = response.json()['world_size']
        else:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

        world_size = vllm_world_size + 1  # add the client to the world
        self.rank = vllm_world_size  # the client's rank is the last process

        # Initialize weight update group
        url = f'http://{self.host}:{self.server_port}/init_communicator/'
        # In the server side, the host is set to 0.0.0.0
        response = self.session.post(url, json={'host': '0.0.0.0', 'port': self.group_port, 'world_size': world_size})
        if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

        # Brief delay to allow server initialization. While not strictly required (client socket will retry on
        # connection failure), this prevents log warnings like:
        # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
        time.sleep(0.1)

        # Set up the communication group for weight broadcasting
        pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
        self.pynccl_comm = PyNcclCommunicator(pg, device=0)

        # When the client object is deleted, close the weight update group
        atexit.register(self.close_communicator)

    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
            name (`str`):
                Name of the layer whose weights are being updated.
            weights (`torch.Tensor`):
                Tensor containing the updated weights.
        """
        # The server only needs metadata over HTTP; the tensor itself travels via NCCL.
        dtype, shape = str(weights.dtype), tuple(weights.shape)
        url = f'http://{self.host}:{self.server_port}/update_named_param/'
        response = self.session.post(url, json={'name': name, 'dtype': dtype, 'shape': shape})
        if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

        # Broadcast the weights to the other processes
        self.pynccl_comm.broadcast(weights, src=self.rank)
        self.pynccl_comm.group.barrier()

    def update_model_params(self, model: nn.Module):
        """
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
        """
        for name, param in model.named_parameters():
            # Update each parameter individually
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
        """
        url = f'http://{self.host}:{self.server_port}/reset_prefix_cache/'
        response = self.session.post(url)
        if response.status_code != 200:
            raise Exception(f'Request failed: {response.status_code}, {response.text}')

    def close_communicator(self):
        """
        Closes the weight update group and cleans up the communication group.
        """
        url = f'http://{self.host}:{self.server_port}/close_communicator/'

        try:
            response = self.session.post(url)
        except ConnectionError:
            # The server might be already down, so we don't need to close the communicator
            pass
        else:
            if response.status_code != 200:
                raise Exception(f'Request failed: {response.status_code}, {response.text}')
ms-swift/swift/trainers/sequence_parallel/base.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from abc import abstractmethod
3
+
4
+
5
class SequenceParallel(abc.ABC):
    """Abstract interface for sequence-parallel training backends.

    Implementations (e.g. Ulysses, XTuner) split each sequence across a group of
    ranks, adapt the model/trainer/dataloader accordingly, and reduce losses back.
    """

    @abstractmethod
    def init_sequence_parallel(self, size):
        """Initialize the sequence-parallel group with `size` ranks per sequence."""
        pass

    @abstractmethod
    def prepare_model(self, model, tokenizer, split_in_forward):
        """Patch `model` so its inputs are split across the sequence-parallel group.

        `split_in_forward` controls whether splitting happens inside the model's
        forward (needed for multi-modal models) rather than in the collator path.
        """
        pass

    @abstractmethod
    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad each tensor to a multiple of the group size and return this rank's shard.

        Any of the tensor arguments may be None; returns the same 6-tuple layout
        (input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale).
        """
        pass

    @abstractmethod
    def reduce_outputs(self, loss, labels):
        """Reduce/combine the sharded loss across the sequence-parallel group."""
        pass

    @property
    def sp_group(self):
        # Default: no sequence-parallel process group. Concrete backends override.
        return None

    @abstractmethod
    def world_size(self):
        """Return the sequence-parallel world size (ranks per sequence)."""
        pass

    @abstractmethod
    def prepare_trainer(self, trainer):
        """Patch trainer loss/metric hooks to be sequence-parallel aware."""
        pass

    @abstractmethod
    def get_dataloader(self, trainer, dataset, batch_size):
        """Build a dataloader that shards batches across the data-parallel group."""
        pass
ms-swift/swift/trainers/sequence_parallel/ulysses.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from types import MethodType
4
+ from typing import Any, Dict, Iterator, List, Optional, Tuple
5
+
6
+ import datasets
7
+ import numpy as np
8
+ import torch
9
+ import torch.distributed as dist
10
+ from peft import PeftModel
11
+ from torch.distributed.device_mesh import init_device_mesh
12
+ from torch.nn import CrossEntropyLoss
13
+ from torch.utils.data import DataLoader, Sampler
14
+ from transformers.trainer_utils import seed_worker
15
+
16
+ from swift.llm import DataLoaderDispatcher, get_model_arch
17
+ from swift.tuners import SwiftModel
18
+ from swift.utils import get_current_device, get_device, get_dist_setting
19
+ from .base import SequenceParallel
20
+
21
+
22
class GatherLoss(torch.autograd.Function):
    """Gather loss from sequence group"""

    @staticmethod
    def forward(ctx, loss, labels, process_group, gather_idx=None):
        """
        Args:
            loss: loss tensor after splitting
            labels: labels tensor after splitting
            process_group: the sequence parallel group
            gather_idx: gather the tensors on this dim
        """
        ctx.process_group = process_group
        shape0 = labels.shape[0]
        # Remember this rank's chunk size so backward can slice the gradient back out.
        ctx.scatter_shape = labels.shape[gather_idx or 0]
        ctx.gather_idx = gather_idx or 0
        world_size = dist.get_world_size(group=process_group)  # the sp world size
        output = torch.empty((shape0 * world_size, *loss.shape[1:]), dtype=loss.dtype, device=loss.device)
        # gather all from sp group
        dist.all_gather_into_tensor(output, loss, group=process_group)
        if gather_idx is not None:
            # all_gather stacks per-rank chunks along dim 0; re-concatenate on gather_idx.
            output = torch.cat(output.split(shape0, dim=0), dim=gather_idx)
        labels_output = torch.empty((shape0 * world_size, *labels.shape[1:]), dtype=labels.dtype, device=labels.device)
        dist.all_gather_into_tensor(labels_output, labels, group=process_group)
        if gather_idx is not None:
            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=gather_idx)
        return output, labels_output

    @staticmethod
    def backward(ctx, *grad_output):
        # Scale by world size (the gathered loss is later averaged over all ranks'
        # tokens) and return only this rank's slice of the gradient.
        _grad = grad_output[0] * dist.get_world_size(group=ctx.process_group)
        return _grad.split(
            ctx.scatter_shape, dim=ctx.gather_idx)[dist.get_rank(ctx.process_group)].contiguous(), None, None, None
55
+
56
+
57
# For nll loss
def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None, process_group=None) -> torch.Tensor:
    """Token-level NLL loss under sequence parallelism.

    Computes per-token cross entropy on the local sequence shard, optionally
    weights it by `loss_scale`, gathers loss and labels across `process_group`,
    then normalizes over valid tokens (label != -100) or `num_items_in_batch`.
    """
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    else:
        logits = outputs
    device = logits.device
    logits = logits.view(-1, logits.shape[-1])
    labels = labels.flatten().to(device)
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(reduction='none')
    # flatten loss
    loss = loss_fct(logits, labels)

    if loss_scale is not None:
        loss_scale = loss_scale.flatten().to(loss.device)
        loss = (loss_scale * loss)
    # Gather the per-token losses (and labels, for masking) from the whole sp group.
    loss, labels = GatherLoss.apply(loss, labels, process_group)
    loss = loss[labels != -100].sum()
    if num_items_in_batch is None:
        loss = loss / (labels != -100).sum()
    else:
        loss = loss / num_items_in_batch
    return loss
81
+
82
+
83
# For DPO
def get_batch_logps(logits: torch.FloatTensor,
                    labels: torch.LongTensor,
                    label_pad_token_id: int = -100,
                    is_encoder_decoder: bool = False,
                    process_group=None) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """Sum of per-token log-probs and valid-token counts, gathered over the sp group.

    Returns one (logp_sum, token_count) pair per sequence; `is_encoder_decoder`
    is accepted for signature compatibility but unused here.
    """
    labels = labels.clone()  # No need to shift, pad and split has shifted the inputs.
    loss_mask = labels != label_pad_token_id
    # Replace pad labels with a valid index so gather() is safe; masked out below.
    labels[labels == label_pad_token_id] = 0
    labels = labels.to(logits.device)
    loss_mask = loss_mask.to(logits.device)
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    # Gather shards along the sequence dim (gather_idx=1) before reducing.
    total_per_token_logps, total_loss_mask = GatherLoss.apply(per_token_logps, loss_mask, process_group, 1)
    return (total_per_token_logps * total_loss_mask).sum(-1), total_loss_mask.sum(-1)
97
+
98
+
99
class UlyssesSampler(Sampler):
    """Distributed sampler over the data-parallel group of a Ulysses setup.

    Code borrowed from mmengine. Each data-parallel rank draws a strided slice
    of a (optionally shuffled, optionally padded) index permutation so every
    rank sees a disjoint, equally sized share per epoch.
    """

    def __init__(self, ulysses, dataset, shuffle: bool = True, seed=None, round_up: bool = True) -> None:
        self.ulysses = ulysses
        dp_group = ulysses.device_mesh['data'].get_group()
        self.rank = dist.get_rank(dp_group)
        self.world_size = ulysses.device_mesh['data'].size()

        self.dataset = dataset
        self.shuffle = shuffle
        assert seed is not None
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        total = len(self.dataset)
        if self.round_up:
            # Pad so every rank gets the same number of samples.
            self.num_samples = math.ceil(total / self.world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil((total - self.rank) / self.world_size)
            self.total_size = total

    def __iter__(self) -> Iterator[int]:
        # Deterministic order per (seed, epoch) so all ranks agree on it.
        if self.shuffle:
            gen = torch.Generator()
            gen.manual_seed(self.seed + self.epoch)
            order = torch.randperm(len(self.dataset), generator=gen).tolist()
        else:
            order = torch.arange(len(self.dataset)).tolist()

        if self.round_up:
            # Repeat the index list until it covers total_size, then truncate.
            repeats = int(self.total_size / len(order) + 1)
            order = (order * repeats)[:self.total_size]

        # Rank r takes indices r, r + world_size, r + 2 * world_size, ...
        return iter(order[self.rank:self.total_size:self.world_size])

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        # Advances the shuffle seed so each epoch uses a fresh permutation.
        self.epoch = epoch
143
+
144
+
145
class UlyssesDispatcher(DataLoaderDispatcher):
    """Round-robin dispatch of batches from the base dataloader to data-parallel ranks."""

    def __init__(self, base_dataloader, ulysses):
        super().__init__(base_dataloader)
        self.ulysses = ulysses

    def __iter__(self):
        base_iter = iter(self.base_dataloader)
        while True:
            data = None
            try:
                # Draw batches until reaching this rank's slot in the cycle.
                # NOTE(review): the break stops after dp_rank + 1 draws instead of
                # consuming a full dp_world_size per cycle, so different ranks advance
                # their iterators at different speeds — presumably the base
                # DataLoaderDispatcher coordinates data across ranks; verify.
                for i in range(self.ulysses.dp_world_size):
                    data = next(base_iter)
                    if i == self.ulysses.dp_rank:
                        break
            except StopIteration:
                # Exhausted mid-cycle; fall through and stop if nothing was drawn.
                pass
            if data is None:
                break
            yield data
165
+
166
+
167
+ # Code borrowed from deepspeed, here is why:
168
+ # 1. Reduce the dependency
169
+ # 2. The original code is complex
170
+ def _generate_layout_params(scatter_idx, seq_world_size, input):
171
+ if scatter_idx < 2:
172
+ bs, global_seq_len, num_local_head, head_dim = input.shape
173
+ pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
174
+ pre_all2all_permute_idx = (1, 0, 2, 3, 4)
175
+
176
+ post_all2all_permute_idx = (1, 2, 0, 3, 4)
177
+ post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
178
+ else:
179
+ bs, local_seq_len, num_total_head, head_dim = input.shape
180
+ assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible '
181
+ f'by the sequence parallel size ({seq_world_size})!')
182
+ pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
183
+ pre_all2all_permute_idx = (2, 0, 1, 3, 4)
184
+
185
+ post_all2all_permute_idx = (1, 0, 2, 3, 4)
186
+ post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
187
+
188
+ return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
189
+
190
+
191
def post_all2all(permute_idx, res_shape):
    """
    Post-processing function for `all2all` communication.

    Returns a closure that optionally permutes its input by `permute_idx` and
    reshapes it into `res_shape`, returning a contiguous tensor.
    """

    def _post(tensor):
        if permute_idx is not None:
            tensor = tensor.permute(permute_idx).contiguous()
        return tensor.reshape(res_shape).contiguous()

    return _post
204
+
205
+
206
def pre_all2all_fun(permute_idx, inp_shape, input):
    """
    Pre-processing function for `all2all` communication.

    Reshapes `input` into `inp_shape` and, when `permute_idx` is given, permutes
    the result; always returns a contiguous tensor.
    """
    reshaped = input.reshape(inp_shape).contiguous()
    if permute_idx is None:
        return reshaped
    return reshaped.permute(permute_idx).contiguous()
214
+
215
+
216
def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs):
    """One all-to-all exchange of a 4-D attention tensor across `group`.

    Reshapes/permutes per `_generate_layout_params`, runs `all_to_all_single`,
    then restores the target layout. `gather_idx` and extra kwargs are accepted
    for interface compatibility but the layout is driven by `scatter_idx`.
    """
    seq_world_size = dist.get_world_size(group)
    num_heads = input.shape[2]
    # Head scatter (scatter_idx >= 2) requires heads divisible by the group size.
    if num_heads % seq_world_size != 0 and not scatter_idx < 2:
        raise NotImplementedError
    pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = (
        _generate_layout_params(scatter_idx, seq_world_size, input))

    input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

    post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    res = post_all2all_fun(output)
    return res
232
+
233
+
234
class _SeqAllToAll(torch.autograd.Function):
    """Autograd wrapper around `single_all_to_all` (gradient is the inverse exchange)."""

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: torch.Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> torch.Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        res = single_all_to_all(input, scatter_idx, gather_idx, group)
        return res

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
        # The backward of an all-to-all is the all-to-all with scatter/gather swapped.
        return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None
253
+
254
+
255
class DistributedAttention(torch.nn.Module):
    """Ulysses-style attention: all-to-all q/k/v, run local attention, all-to-all back.

    Scatters on `scatter_idx` (heads) and gathers on `gather_idx` (sequence) so
    each rank attends over the full sequence with a subset of heads.
    """

    def __init__(
        self,
        local_attention,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:
        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor,
                *args: Any, **kwargs) -> torch.Tensor:
        # Exchange so each rank holds the full sequence for its head shard.
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)
        position_ids = kwargs.pop('position_ids', None)
        if position_ids is not None:
            # Gather position_ids along the sequence dim to match the gathered q/k/v.
            shape0 = position_ids.shape[0]
            position_ids_output = torch.empty((shape0 * dist.get_world_size(self.spg), position_ids.shape[1]),
                                              dtype=position_ids.dtype,
                                              device=position_ids.device)
            dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.spg)
            position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1)
        context_layer = self.local_attn(
            query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs)
        # Reverse exchange: back to full heads, local sequence shard.
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        return output
287
+
288
+
289
class Ulysses(SequenceParallel):
    """Ulysses sequence parallelism: shard sequences across ranks, exchange heads via all-to-all.

    Patches transformers' attention entry points, the model forward (for
    multi-modal models), the trainer's loss/metric hooks, and the dataloader.
    """

    def __init__(self):
        self.split_in_forward = None  # whether inputs are split inside model.forward
        self.dp_world_size = None
        self.sp_world_size = None
        self.model_dtype = None
        self.causal_mask_func = None  # the base model's _update_causal_mask, captured in prepare_model
        self.device_mesh = None
        self._inited = False

    def init_sequence_parallel(self, size):
        """Create the (data, sequence) device mesh and patch global attention functions.

        Idempotent: repeated calls are no-ops. Note the patches mutate
        transformers' module-level state, so they affect every model in process.
        """
        if self._inited:
            return
        self._inited = True
        self.sp_world_size = size
        rank, local_rank, world_size, local_world_size = get_dist_setting()
        self.dp_world_size = world_size // size
        self.device_mesh = init_device_mesh(
            get_device().split(':')[0], mesh_shape=(world_size // size, size), mesh_dim_names=['data', 'sequence'])

        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        # Keep the originals under *_origin keys so the local wrappers can delegate.
        ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'] = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
        ALL_ATTENTION_FUNCTIONS['sdpa_origin'] = ALL_ATTENTION_FUNCTIONS['sdpa']

        def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                             dist_attn, **kwargs):
            # Lazily bind the original flash-attn kernel as the local attention of
            # the DistributedAttention wrapper (transpose to/from its expected layout).
            if dist_attn.local_attn is None:

                def _attention(query, key, value, *args, **kwargs):
                    query = query.transpose(1, 2)
                    key = key.transpose(1, 2)
                    value = value.transpose(1, 2)
                    return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args,
                                                                               **kwargs)[0]

                dist_attn.local_attn = _attention

            return dist_attn(
                query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
                *args, **kwargs), None

        def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
                            dist_attn, **kwargs):
            if dist_attn.local_attn is None:

                def _attention(query, key, value, *args, **kwargs):
                    query = query.transpose(1, 2)
                    key = key.transpose(1, 2)
                    value = value.transpose(1, 2)
                    return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query, key, value, *args, **kwargs)[0]

                dist_attn.local_attn = _attention
            return dist_attn(
                query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
                *args, **kwargs), None

        ALL_ATTENTION_FUNCTIONS['flash_attention_2'] = partial(
            local_flash_attn, dist_attn=DistributedAttention(None, self.sp_group))
        ALL_ATTENTION_FUNCTIONS['sdpa'] = partial(local_sdpa_attn, dist_attn=DistributedAttention(None, self.sp_group))

        from transformers.modeling_flash_attention_utils import is_flash_attn_available
        if is_flash_attn_available():
            # TODO this works for multi-modal models like qwen2.5-vl
            # SDPA is not supported, because we need to copy the code to our project, which will bring
            # more works for maintaining.
            from transformers import modeling_flash_attention_utils
            from transformers.modeling_flash_attention_utils import _flash_attention_forward
            _distributed_flash_attention = DistributedAttention(_flash_attention_forward, self.sp_group)

            def flash_attention_forward(query_states: torch.Tensor, key_states: torch.Tensor,
                                        value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], q_len,
                                        *args, **kwargs):
                # q_len here is the local shard length; the gathered length is q_len * sp.
                return _distributed_flash_attention(query_states, key_states, value_states, attention_mask,
                                                    q_len * self.sp_world_size, *args, **kwargs)

            modeling_flash_attention_utils._flash_attention_forward = flash_attention_forward

    def prepare_model(self, model, tokenizer, split_in_forward):
        """Locate the base LM inside (possibly wrapped/multi-modal) `model` and patch it."""
        self.split_in_forward = split_in_forward

        def forward(_self, **kwargs):
            # Split embedding here for multi-modal
            inputs_embeds = kwargs['inputs_embeds']
            position_ids = kwargs['position_ids']
            attention_mask = kwargs['attention_mask']
            _, inputs_embeds, _, position_ids, attention_mask, _ = self.pad_and_split_inputs(
                tokenizer,
                None,
                inputs_embeds,
                None,
                position_ids,
                attention_mask,
                None,
                embed_tokens=_self.embed_tokens)
            kwargs['inputs_embeds'] = inputs_embeds
            kwargs['position_ids'] = position_ids
            kwargs['attention_mask'] = attention_mask
            return _self.forward_origin(**kwargs)

        # Unwrap tuner wrappers to reach the underlying transformers model.
        if isinstance(model, (SwiftModel, PeftModel)):
            model = model.model
        model_meta = model.model_meta
        llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
        if llm_prefix:
            llm_model = getattr(model, llm_prefix[0])
        else:
            llm_model = model

        if 'CausalLM' not in llm_model.__class__.__name__:
            llm_model = model

        base_model = llm_model.model
        self.causal_mask_func = base_model._update_causal_mask
        if self.split_in_forward:
            # for multi modal models
            base_model.forward_origin = base_model.forward
            base_model.forward = MethodType(forward, base_model)

        self.model_dtype = next(model.parameters()).dtype

    def _pad_sp(self, tensor, padding_value, dim=-1):
        """Pad `tensor` along `dim` to a multiple of the sp world size."""
        # code borrowed from xtuner
        length = tensor.shape[dim]
        if length % self.sp_world_size == 0:
            return tensor

        pad_num = self.sp_world_size - (length % self.sp_world_size)
        if not isinstance(padding_value, torch.Tensor):
            # ids
            pad_shape = ((*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) if dim != -1 else
                         (*tensor.shape[:dim], pad_num))
            pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
            tensor = torch.cat([tensor, pad], dim=dim)
        else:
            # For embeddings
            tensor = torch.cat([tensor, padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1)], dim=dim)
        return tensor

    def world_size(self):
        """Return the sequence-parallel world size."""
        return self.sp_world_size

    def _split_sp(self, input, dim: int, sp_group: dist.ProcessGroup):
        """Return this rank's contiguous shard of `input` along `dim`."""
        # code borrowed from xtuner
        if self.sp_world_size == 1:
            return input

        rank = dist.get_rank(sp_group)
        dim_size = input.size(dim)
        assert dim_size % self.sp_world_size == 0, (f'The dimension to split ({dim_size}) is not a multiple of '
                                                    f'world size ({self.sp_world_size}), cannot split tensor evenly')

        tensor_list = torch.split(input, dim_size // self.sp_world_size, dim=dim)
        output = tensor_list[rank].contiguous()

        return output

    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad inputs/labels to a multiple of sp size and return this rank's shards.

        For multi-modal models (split_in_forward) input_ids are left intact and the
        split happens later on embeddings. Labels/loss_scale are pre-shifted here
        (roll by -1), so downstream losses must NOT shift again.
        """
        sp_group = self.sp_group
        split_inputs = False
        if (input_ids is not None and not self.split_in_forward) or input_embeds is not None:
            # Whether split the model inputs
            # cannot split input_ids for multi-modal models
            split_inputs = True
        if input_ids is not None and split_inputs:
            input_ids = self._pad_sp(input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
        if input_embeds is not None:
            pad_emb = embed_tokens(torch.tensor(tokenizer.pad_token_id).to(embed_tokens.weight.device)).unsqueeze(0)
            input_embeds = self._pad_sp(input_embeds, padding_value=pad_emb, dim=1)
        if position_ids is not None and split_inputs:
            position_ids = self._pad_sp(position_ids, padding_value=0, dim=-1)
        if split_inputs:
            inputs = input_ids if input_ids is not None else input_embeds
            attn_shape = inputs.shape[1]  # The sequence length
            if attention_mask is None:
                attention_mask = torch.ones_like(position_ids)
            attention_mask = self._pad_sp(attention_mask, padding_value=0, dim=-1)
            cache_position = torch.arange(0, attn_shape, device=inputs.device)
            # pad attention mask to 4d to avoid calculation errors
            attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, None,
                                                   None)
        if input_ids is not None and split_inputs:
            input_ids = self._split_sp(input_ids, dim=1, sp_group=sp_group)
        if input_embeds is not None:
            input_embeds = self._split_sp(input_embeds, dim=1, sp_group=sp_group)
        if position_ids is not None and split_inputs:
            position_ids = self._split_sp(position_ids, dim=-1, sp_group=sp_group)
        if labels is not None:
            labels = self._pad_sp(labels, padding_value=-100, dim=-1)
            # Overwrite position 0, then roll left: the -100 lands on the last
            # position, so the loss of the last token needs no extra cut.
            labels[:, 0] = -100  # make the last invalid, so we do not need to cut the loss of last token
            labels = torch.roll(labels, shifts=-1, dims=1)
            labels = self._split_sp(labels, dim=1, sp_group=sp_group)

        if loss_scale is not None:
            loss_scale = self._pad_sp(loss_scale, padding_value=0., dim=-1)
            loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1)
            loss_scale = self._split_sp(loss_scale, dim=-1, sp_group=sp_group)

        return input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale

    def reduce_outputs(self, loss, labels):
        # Loss is already globally reduced inside loss_scale_sp_func / GatherLoss.
        return loss

    @property
    def sp_rank(self):
        return dist.get_rank(self.device_mesh['sequence'].get_group())

    @property
    def dp_rank(self):
        return dist.get_rank(self.device_mesh['data'].get_group())

    @property
    def sp_group(self):
        return self.device_mesh['sequence'].get_group()

    @property
    def dp_group(self):
        return self.device_mesh['data'].get_group()

    def get_dataloader(self, trainer, dataset, batch_size):
        """Build the training dataloader; sized datasets use UlyssesSampler,
        iterable datasets are wrapped in an UlyssesDispatcher."""
        data_collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            data_collator = trainer._get_collator_with_removed_columns(data_collator, description='training')
        if hasattr(dataset, '__len__'):
            sampler = UlyssesSampler(self, dataset, seed=42)
            dataloader_params = {
                'batch_size': batch_size,
                'collate_fn': data_collator,
                'num_workers': trainer.args.dataloader_num_workers,
                'pin_memory': trainer.args.dataloader_pin_memory,
                'persistent_workers': trainer.args.dataloader_persistent_workers,
            }

            if not isinstance(dataset, torch.utils.data.IterableDataset):
                dataloader_params['sampler'] = sampler
                dataloader_params['drop_last'] = trainer.args.dataloader_drop_last
                dataloader_params['worker_init_fn'] = seed_worker

            return DataLoader(dataset, **dataloader_params)
        else:
            dataloader_params = {
                'collate_fn': data_collator,
                'num_workers': trainer.args.dataloader_num_workers,
                'pin_memory': trainer.args.dataloader_pin_memory,
                'persistent_workers': trainer.args.dataloader_persistent_workers,
                'prefetch_factor': trainer.args.dataloader_prefetch_factor
            }
            if dist.is_initialized() and dataloader_params['prefetch_factor']:
                dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size()
            dataloader = DataLoader(dataset, batch_size=batch_size, **dataloader_params)
            dataloader = UlyssesDispatcher(dataloader, self)
            return dataloader

    def prepare_trainer(self, trainer):
        """Install sp-aware loss functions and an sp-aware accuracy metric on `trainer`."""
        if trainer.train_dataset is None:
            raise ValueError('Trainer: training requires a train_dataset.')

        trainer.compute_loss_func = partial(loss_scale_sp_func, process_group=self.sp_group)
        if hasattr(trainer, 'get_batch_logps'):
            trainer.get_batch_logps = partial(get_batch_logps, process_group=self.sp_group)
        if hasattr(trainer, 'get_nll_loss'):

            def rlhf_loss_scale_sp_func(_, *args, **kwargs):
                # Bound-method shim: drops `self` (the trainer) and delegates.
                return loss_scale_sp_func(*args, process_group=self.sp_group, **kwargs)

            trainer.get_nll_loss = MethodType(rlhf_loss_scale_sp_func, trainer)

        from swift.plugin import metric
        from swift.trainers import mixin
        compute_acc_origin = metric.compute_acc

        def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:

            # Gather preds and labels across the sp group
            if isinstance(preds, np.ndarray):
                preds = torch.from_numpy(preds).to(get_current_device())
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).to(get_current_device())
            shape0 = preds.shape[0]
            preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]),
                                       dtype=preds.dtype,
                                       device=preds.device)
            dist.all_gather_into_tensor(preds_output, preds, group=self.sp_group)
            preds_output = torch.cat(preds_output.split(shape0, dim=0), dim=1)
            shape0 = labels.shape[0]
            labels_output = torch.empty((shape0 * self.sp_world_size, labels.shape[1]),
                                        dtype=labels.dtype,
                                        device=labels.device)
            dist.all_gather_into_tensor(labels_output, labels, group=self.sp_group)
            labels_output = torch.cat(labels_output.split(shape0, dim=0), dim=1)
            # roll back to fit compute_acc
            labels_output = torch.roll(labels_output, shifts=1, dims=1)
            return compute_acc_origin(preds_output, labels_output, *args, **kwargs)

        metric.compute_acc = compute_acc
        mixin.compute_acc = compute_acc
ms-swift/swift/trainers/sequence_parallel/xtuner.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any
3
+
4
+ import datasets
5
+ import torch
6
+ import torch.distributed as dist
7
+ from datasets import Dataset
8
+ from torch.utils.data import DataLoader
9
+ from transformers.trainer_utils import seed_worker
10
+
11
+ from .base import SequenceParallel
12
+
13
+
14
class XTuner(SequenceParallel):
    """Sequence-parallel backend backed by the XTuner library.

    A thin adapter: dataset packing, input padding/splitting, loss
    reduction and the distributed sampler are all delegated to
    ``xtuner.parallel.sequence`` and friends.
    """

    @staticmethod
    def assert_xtuner_runtime_condition():
        """Fail fast unless XTuner is installed and torch.distributed is initialized."""
        from swift.utils import is_xtuner_available
        assert is_xtuner_available(), (
            'Please install XTuner first to pack dataset to `max_length`.'
            "`pip install -U 'xtuner[deepspeed]'`")
        assert dist.is_initialized(), 'pack_to_max_length is only available with distributed training.'

    def pack_dataset_xtuner(self, dataset: Dataset, args: Any) -> Any:
        """Pack ``dataset`` into fixed ``args.max_length`` chunks.

        Rank 0 does the packing and broadcasts the resulting ``Dataset``
        object to every other rank.
        """
        self.assert_xtuner_runtime_condition()
        payload = [None]
        if dist.get_rank() == 0:
            from xtuner.dataset.huggingface import pack_dataset
            samples = [row[0] for row in dataset.data]
            packed = pack_dataset(
                Dataset.from_list(samples),
                max_length=args.max_length,
                use_varlen_attn=False,
                shuffle_before_pack=True,
                map_num_proc=16)
            # NOTE(review): hard-coded dump path kept for behavior parity —
            # looks like a leftover debug artifact; confirm before removing.
            packed.save_to_disk('alpaca_pack')
            payload = [packed]
        dist.broadcast_object_list(payload, src=0)
        return payload[0]

    @property
    def sp_group(self):
        """Process group spanning the sequence-parallel ranks."""
        from xtuner.parallel.sequence import get_sequence_parallel_group
        return get_sequence_parallel_group()

    def init_sequence_parallel(self, size):
        """Create the sequence-parallel process groups of the given ``size``."""
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import init_sequence_parallel
        init_sequence_parallel(size)

    def prepare_model(self, model, tokenizer, split_in_forward):
        """Patch the model's attention modules for sequence parallelism.

        ``tokenizer`` and ``split_in_forward`` are accepted for interface
        compatibility but unused by the XTuner backend.
        """
        self.assert_xtuner_runtime_condition()
        from xtuner.model.modules.dispatch import dispatch_modules
        dispatch_modules(model)

    def pad_and_split_inputs(self,
                             tokenizer,
                             input_ids,
                             input_embeds,
                             labels,
                             position_ids,
                             attention_mask,
                             loss_scale,
                             embed_tokens=None):
        """Pad each tensor to a multiple of the sp world size, then slice out
        this rank's shard along the sequence dimension.

        ``input_embeds``/``embed_tokens`` are not supported by this backend;
        ``None`` is returned in the embeds slot.
        """
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import (pad_for_sequence_parallel, split_for_sequence_parallel,
                                              get_sequence_parallel_group)
        group = get_sequence_parallel_group()

        def _pad(tensor, value):
            # Padding is local (no communication), so per-tensor interleaving
            # of pad+split is equivalent to the pad-all-then-split-all order.
            return pad_for_sequence_parallel(tensor, padding_value=value, dim=-1)

        input_ids = split_for_sequence_parallel(_pad(input_ids, tokenizer.pad_token_id), dim=1, sp_group=group)
        labels = split_for_sequence_parallel(_pad(labels, -100), dim=1, sp_group=group)
        position_ids = split_for_sequence_parallel(_pad(position_ids, 0), dim=1, sp_group=group)
        if attention_mask is not None:
            attention_mask = split_for_sequence_parallel(_pad(attention_mask, 0), dim=-1, sp_group=group)
        if loss_scale is not None:
            loss_scale = split_for_sequence_parallel(_pad(loss_scale, 0.), dim=1, sp_group=group)

        return input_ids, None, labels, position_ids, attention_mask, loss_scale

    def reduce_outputs(self, loss, labels):
        """Reduce the per-shard loss across the sp group for correct logging,
        weighting by the number of valid (non -100) label tokens."""
        from xtuner.parallel.sequence import (reduce_sequence_parallel_loss, get_sequence_parallel_group)
        valid_tokens = (labels != -100).sum()
        return reduce_sequence_parallel_loss(loss, valid_tokens, get_sequence_parallel_group())

    def world_size(self):
        """Number of ranks in the sequence-parallel group."""
        self.assert_xtuner_runtime_condition()
        from xtuner.parallel.sequence import get_sequence_parallel_world_size
        return get_sequence_parallel_world_size()

    def prepare_trainer(self, trainer):
        # The XTuner backend needs no trainer-level patching.
        pass

    def get_dataloader(self, trainer, dataset, batch_size):
        """Build a train dataloader whose sampler shards by sp rank.

        Modified from ``transformers.Trainer.get_train_dataloader``:
        ``RandomSampler`` is replaced with ``SequenceParallelSampler``.
        """
        self.assert_xtuner_runtime_condition()
        collator = trainer.data_collator
        if isinstance(dataset, datasets.Dataset):
            dataset = trainer._remove_unused_columns(dataset, description='training')
        else:
            collator = trainer._get_collator_with_removed_columns(collator, description='training')

        loader_kwargs = {
            'batch_size': batch_size,
            'collate_fn': collator,
            'num_workers': trainer.args.dataloader_num_workers,
            'pin_memory': trainer.args.dataloader_pin_memory,
            'persistent_workers': trainer.args.dataloader_persistent_workers,
        }

        if not isinstance(dataset, torch.utils.data.IterableDataset):
            from xtuner.parallel import SequenceParallelSampler
            loader_kwargs['sampler'] = SequenceParallelSampler(dataset, seed=1024)
            loader_kwargs['drop_last'] = trainer.args.dataloader_drop_last
            loader_kwargs['worker_init_fn'] = seed_worker

        return DataLoader(dataset, **loader_kwargs)
ms-swift/swift/trainers/torchacc_mixin.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import shutil
4
+ from typing import Optional
5
+
6
+ from transformers import PreTrainedModel, is_datasets_available
7
+
8
+ from swift.utils import use_torchacc
9
+ from swift.utils.torchacc_utils import (patch_clip_grad_norm, save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint,
10
+ ta_eval_dataloader, ta_load_optimizer_and_scheduler,
11
+ ta_save_optimizer_and_scheduler, ta_test_dataloader, ta_train_dataloader,
12
+ ta_trim_graph)
13
+
14
+
15
class TorchAccMixin:
    """Trainer mixin that routes dataloading and checkpointing through TorchAcc.

    Every override first checks ``use_torchacc()`` and otherwise falls back to
    the plain ``super()`` implementation, so the mixin is inert in normal runs.
    """

    def __init__(self, *args, **kwargs):
        if use_torchacc():
            # Gradient clipping must go through the patched XLA-aware path.
            patch_clip_grad_norm(self.accelerator)
        super().__init__(*args, **kwargs)

    def get_train_dataloader(self):
        """Build the train dataloader via the TorchAcc loader when enabled."""
        if not use_torchacc():
            return super().get_train_dataloader()

        if is_datasets_available():
            import datasets

        if self.train_dataset is None:
            raise ValueError('Trainer: training requires a train_dataset.')

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description='training')
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description='training')

        return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args,
                                   self._train_batch_size)

    def get_eval_dataloader(self, eval_dataset=None):
        """Build the eval dataloader via the TorchAcc loader when enabled."""
        if not use_torchacc():
            return super().get_eval_dataloader(eval_dataset)

        if is_datasets_available():
            import datasets

        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError('Trainer: evaluation requires an eval_dataset.')
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation')
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation')

        return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args)

    def get_test_dataloader(self, test_dataset):
        """Build the test dataloader via the TorchAcc loader when enabled."""
        if not use_torchacc():
            return super().get_test_dataloader(test_dataset)

        if is_datasets_available():
            import datasets

        data_collator = self.data_collator

        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description='test')
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description='test')

        return ta_test_dataloader(test_dataset, data_collator, self._get_eval_sampler(test_dataset), self.args)

    def _save_tpu(self, output_dir: Optional[str] = None):
        """Save model/config/tokenizer from an XLA (TPU/TorchAcc) run.

        Only the global master rank writes config and auxiliary files; the
        weight checkpoint helpers are collective and run on every rank.
        """
        if not use_torchacc():
            return super()._save_tpu(output_dir)

        import torch_xla.core.xla_model as xm

        # Compatible with swift and peft
        output_dir = output_dir if output_dir is not None else self.args.output_dir

        # Hoisted out of the master-only branch so it is always defined where
        # the additional-files section below reads it (pure getattr, no I/O).
        model_dir = getattr(self.model, 'model_dir', None)
        if xm.is_master_ordinal(local=False):
            os.makedirs(output_dir, exist_ok=True)
            # configuration.json
            if model_dir is not None:
                src_path = os.path.join(model_dir, 'configuration.json')
                dst_path = os.path.join(output_dir, 'configuration.json')
                if os.path.exists(src_path):
                    shutil.copy(src_path, dst_path)
            else:
                self._create_configuration_file(self.model, output_dir)
            self._save_sft_args(output_dir)
            # generation_config
            generation_config = getattr(self.args, 'generation_config', None)
            if generation_config is not None:
                generation_config.save_pretrained(output_dir)

        # model (collective checkpoint: every rank must participate)
        if self.args.fsdp_num > 1:
            save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
        else:
            save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)

        # additional files
        if xm.is_master_ordinal(local=False):
            if self.args is not None and self.args.sft_type == 'full':
                # BUGFIX: `+` binds tighter than `or`, so the original
                # `getattr(...) or [] + ['preprocessor_config.json']` dropped
                # 'preprocessor_config.json' whenever additional_saved_files
                # was non-empty. Parenthesize so it is always appended.
                additional_files = (getattr(self.args, 'additional_saved_files', None)
                                    or []) + ['preprocessor_config.json']
                if model_dir is not None:
                    for file in additional_files:
                        src_path = os.path.join(model_dir, file)
                        dst_path = os.path.join(output_dir, file)
                        if os.path.isfile(src_path):
                            shutil.copy(src_path, dst_path)
                        elif os.path.isdir(src_path):
                            shutil.copytree(src_path, dst_path)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """Restore optimizer/scheduler; FSDP shards need the TorchAcc loader."""
        if not use_torchacc() or self.args.fsdp_num == 1:
            return super()._load_optimizer_and_scheduler(checkpoint)

        self.optimizer, self.lr_scheduler = ta_load_optimizer_and_scheduler(self.optimizer, self.lr_scheduler,
                                                                            checkpoint, self.args.device)

    def _save_optimizer_and_scheduler(self, output_dir):
        """Persist optimizer/scheduler; FSDP shards need the TorchAcc saver.

        BUGFIX: the condition previously read ``not self.args.fsdp_num == 1``,
        which sent the FSDP case to the vanilla HF saver while
        ``_load_optimizer_and_scheduler`` reads FSDP checkpoints back through
        the TorchAcc loader — an asymmetry that broke checkpoint round-trips.
        The condition now mirrors the load path exactly.
        """
        if not use_torchacc() or self.args.fsdp_num == 1:
            return super()._save_optimizer_and_scheduler(output_dir)

        return ta_save_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, output_dir)

    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        # Trim the lazily-built XLA graph before logging to bound memory use.
        if use_torchacc() and self.control.should_log:
            ta_trim_graph()
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

    def _load_from_checkpoint(self, resume_from_checkpoint: str, model=None) -> None:
        """Skip HF weight loading when TorchAcc has already restored the model."""
        if use_torchacc():
            if model is None:
                model = self.model
            # Loading checkpoint of TorchAcc has been done in tuner.py when
            # sft_type is 'full'.
            if self.args.fsdp_num > 1:
                model = model._get_underlay_model().module.module
            if isinstance(model, PreTrainedModel):
                return
        return super()._load_from_checkpoint(resume_from_checkpoint, model)