Student0809 commited on
Commit
35dfdd4
·
verified ·
1 Parent(s): 636c5b4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ms-swift/processed_data/processed_silence_speaker_segments.json +0 -0
  2. ms-swift/processed_data/processed_silence_transcriptions.json +0 -0
  3. ms-swift/silence_overlaps/700/original/overlap5s_segments.json +0 -0
  4. ms-swift/silence_overlaps/700/original/overlap5s_silence_segments.json +0 -0
  5. ms-swift/silence_overlaps/700/test/.ipynb_checkpoints/overlap5s_segments_test-checkpoint.json +27 -0
  6. ms-swift/silence_overlaps/700/train/silence_issilence_train.json +0 -0
  7. ms-swift/silence_overlaps/700/train/silence_speaker_segments_train.json +0 -0
  8. ms-swift/silence_overlaps/cleaned_transcriptions2.json +0 -0
  9. ms-swift/swift/llm/model/model/__pycache__/__init__.cpython-310.pyc +0 -0
  10. ms-swift/swift/llm/model/model/__pycache__/bert.cpython-310.pyc +0 -0
  11. ms-swift/swift/llm/model/model/__pycache__/codefuse.cpython-310.pyc +0 -0
  12. ms-swift/swift/llm/model/model/__pycache__/deepseek.cpython-310.pyc +0 -0
  13. ms-swift/swift/llm/model/model/__pycache__/llm.cpython-310.pyc +0 -0
  14. ms-swift/swift/llm/model/model/__pycache__/mamba.cpython-310.pyc +0 -0
  15. ms-swift/swift/llm/model/model/__pycache__/microsoft.cpython-310.pyc +0 -0
  16. ms-swift/swift/llm/model/model/__pycache__/minicpm.cpython-310.pyc +0 -0
  17. ms-swift/swift/llm/model/model/__pycache__/minimax.cpython-310.pyc +0 -0
  18. ms-swift/swift/llm/model/model/__pycache__/mllm.cpython-310.pyc +0 -0
  19. ms-swift/swift/llm/model/model/__pycache__/moonshot.cpython-310.pyc +0 -0
  20. ms-swift/swift/llm/model/model/__pycache__/mplug.cpython-310.pyc +0 -0
  21. ms-swift/swift/llm/model/model/__pycache__/openbuddy.cpython-310.pyc +0 -0
  22. ms-swift/swift/llm/model/model/__pycache__/qwen.cpython-310.pyc +0 -0
  23. ms-swift/swift/llm/model/model/__pycache__/skywork.cpython-310.pyc +0 -0
  24. ms-swift/swift/llm/model/model/__pycache__/stepfun.cpython-310.pyc +0 -0
  25. ms-swift/swift/llm/model/model/__pycache__/telechat.cpython-310.pyc +0 -0
  26. ms-swift/swift/llm/model/model/__pycache__/valley.cpython-310.pyc +0 -0
  27. ms-swift/swift/llm/model/model/__pycache__/yi.cpython-310.pyc +0 -0
  28. ms-swift/swift/llm/model/model/glm.py +371 -0
  29. ms-swift/swift/llm/model/model/minimax.py +156 -0
  30. ms-swift/swift/llm/model/model/openbuddy.py +85 -0
  31. ms-swift/swift/llm/model/model/qwen.py +899 -0
  32. ms-swift/swift/llm/model/model/skywork.py +89 -0
  33. ms-swift/swift/llm/model/model/yi.py +146 -0
  34. ms-swift/swift/llm/sampling/__init__.py +1 -0
  35. ms-swift/swift/llm/sampling/base.py +55 -0
  36. ms-swift/swift/llm/sampling/sampling.py +106 -0
  37. ms-swift/swift/llm/sampling/utils.py +172 -0
  38. ms-swift/swift/llm/sampling/vanilla_sampler.py +216 -0
  39. ms-swift/swift/llm/template/__init__.py +10 -0
  40. ms-swift/swift/llm/template/__pycache__/__init__.cpython-310.pyc +0 -0
  41. ms-swift/swift/llm/template/__pycache__/base.cpython-310.pyc +0 -0
  42. ms-swift/swift/llm/template/__pycache__/grounding.cpython-310.pyc +0 -0
  43. ms-swift/swift/llm/template/__pycache__/register.cpython-310.pyc +0 -0
  44. ms-swift/swift/llm/template/__pycache__/template_inputs.cpython-310.pyc +0 -0
  45. ms-swift/swift/llm/template/__pycache__/utils.cpython-310.pyc +0 -0
  46. ms-swift/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc +0 -0
  47. ms-swift/swift/llm/template/base.py +1606 -0
  48. ms-swift/swift/llm/template/grounding.py +91 -0
  49. ms-swift/swift/llm/template/register.py +57 -0
  50. ms-swift/swift/llm/template/template/__init__.py +2 -0
ms-swift/processed_data/processed_silence_speaker_segments.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/processed_data/processed_silence_transcriptions.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/700/original/overlap5s_segments.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/700/original/overlap5s_silence_segments.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/700/test/.ipynb_checkpoints/overlap5s_segments_test-checkpoint.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "key": "SODA_PROCESSED--train--123906",
4
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--123906.wav",
5
+ "model_output": "Multiple speakers talk simultaneously from 00:03-00:09"
6
+ },
7
+ {
8
+ "key": "SODA_PROCESSED--train--1112763",
9
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--1112763.wav",
10
+ "model_output": "Multiple speakers talk simultaneously from 00:09-00:15"
11
+ },
12
+ {
13
+ "key": "SODA_PROCESSED--train--790538",
14
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--790538.wav",
15
+ "model_output": "Multiple speakers talk simultaneously from 00:15-00:19"
16
+ },
17
+ {
18
+ "key": "SODA_PROCESSED--train--822773",
19
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--822773.wav",
20
+ "model_output": "Multiple speakers talk simultaneously from 00:14-00:19"
21
+ },
22
+ {
23
+ "key": "SODA_PROCESSED--train--424960",
24
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--424960.wav",
25
+ "model_output": "Multiple speakers talk simultaneously from 00:29-00:33"
26
+ }
27
+ ]
ms-swift/silence_overlaps/700/train/silence_issilence_train.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/700/train/silence_speaker_segments_train.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/cleaned_transcriptions2.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/swift/llm/model/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (632 Bytes). View file
 
ms-swift/swift/llm/model/model/__pycache__/bert.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/codefuse.cpython-310.pyc ADDED
Binary file (1.72 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/deepseek.cpython-310.pyc ADDED
Binary file (6.45 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/llm.cpython-310.pyc ADDED
Binary file (6.81 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/mamba.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/microsoft.cpython-310.pyc ADDED
Binary file (5.85 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/minicpm.cpython-310.pyc ADDED
Binary file (4.59 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/minimax.cpython-310.pyc ADDED
Binary file (5.05 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/mllm.cpython-310.pyc ADDED
Binary file (4.44 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/moonshot.cpython-310.pyc ADDED
Binary file (1.28 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/mplug.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/openbuddy.cpython-310.pyc ADDED
Binary file (1.83 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/qwen.cpython-310.pyc ADDED
Binary file (21.9 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/skywork.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/stepfun.cpython-310.pyc ADDED
Binary file (2.21 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/telechat.cpython-310.pyc ADDED
Binary file (1.74 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/valley.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
ms-swift/swift/llm/model/model/__pycache__/yi.cpython-310.pyc ADDED
Binary file (4.49 kB). View file
 
ms-swift/swift/llm/model/model/glm.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, Type
3
+
4
+ import torch
5
+ import transformers
6
+ from packaging import version
7
+ from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
8
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
9
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
10
+
11
+ from swift.llm import TemplateType
12
+ from swift.utils import get_device_count, get_dist_setting, get_logger
13
+ from ..constant import LLMModelType, MLLMModelType
14
+ from ..model_arch import ModelArch
15
+ from ..patcher import patch_output_to_input_device
16
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
17
+ from ..utils import AttnImpl, ModelInfo, safe_snapshot_download
18
+
19
+ logger = get_logger()
20
+
21
+
22
+ def remove_property(tokenizer_cls: Type[PreTrainedTokenizerBase], tokenizer_config: Dict[str, Any]) -> None:
23
+ for k, v in tokenizer_cls.__dict__.items():
24
+ if k.endswith('_token') and isinstance(v, property) and k in tokenizer_config:
25
+ setattr(tokenizer_cls, k, tokenizer_config[k])
26
+
27
+
28
+ def get_model_tokenizer_chatglm(model_dir: str,
29
+ model_info: ModelInfo,
30
+ model_kwargs: Dict[str, Any],
31
+ load_model: bool = True,
32
+ **kwargs):
33
+ if model_kwargs.get('quantization_config') is not None:
34
+ model_kwargs['quantization_config'].llm_int8_skip_modules = ['output_layer']
35
+ # fix transformers>=4.34 bug
36
+ if version.parse(transformers.__version__) >= version.parse('4.34'):
37
+ tokenizer_config = get_tokenizer_config(model_dir)
38
+ class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
39
+ tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
40
+ tokenizer_cls._auto_class = 'AutoTokenizer'
41
+ remove_property(tokenizer_cls, tokenizer_config)
42
+ kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
43
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
44
+ if model is not None:
45
+ from torch.nn import CrossEntropyLoss
46
+ __old_forward = CrossEntropyLoss.forward
47
+
48
+ def cross_entropy_forward(self, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
49
+ target = target.to(device=inputs.device)
50
+ return __old_forward(self, inputs, target)
51
+
52
+ CrossEntropyLoss.forward = cross_entropy_forward
53
+
54
+ return model, tokenizer
55
+
56
+
57
+ register_model(
58
+ ModelMeta(
59
+ LLMModelType.chatglm2, [
60
+ ModelGroup([
61
+ Model('ZhipuAI/chatglm2-6b', 'THUDM/chatglm2-6b'),
62
+ Model('ZhipuAI/chatglm2-6b-32k', 'THUDM/chatglm2-6b-32k')
63
+ ],
64
+ requires=['transformers<4.42']),
65
+ ModelGroup(
66
+ [Model('ZhipuAI/codegeex2-6b', 'THUDM/codegeex2-6b')],
67
+ requires=['transformers<4.34'],
68
+ tags=['coding'],
69
+ ),
70
+ ],
71
+ TemplateType.chatglm2,
72
+ get_model_tokenizer_chatglm,
73
+ architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
74
+ model_arch=ModelArch.chatglm))
75
+
76
+ register_model(
77
+ ModelMeta(
78
+ LLMModelType.chatglm3, [
79
+ ModelGroup([
80
+ Model('ZhipuAI/chatglm3-6b', 'THUDM/chatglm3-6b'),
81
+ Model('ZhipuAI/chatglm3-6b-base', 'THUDM/chatglm3-6b-base'),
82
+ Model('ZhipuAI/chatglm3-6b-32k', 'THUDM/chatglm3-6b-32k'),
83
+ Model('ZhipuAI/chatglm3-6b-128k', 'THUDM/chatglm3-6b-128k'),
84
+ ])
85
+ ],
86
+ TemplateType.glm4,
87
+ get_model_tokenizer_chatglm,
88
+ architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
89
+ requires=['transformers<4.42'],
90
+ model_arch=ModelArch.chatglm))
91
+
92
+
93
+ def get_model_tokenizer_glm4(model_dir: str,
94
+ model_info: ModelInfo,
95
+ model_kwargs: Dict[str, Any],
96
+ load_model: bool = True,
97
+ **kwargs):
98
+ model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
99
+ AttnImpl.update_attn_impl(model_config, kwargs.get('attn_impl'))
100
+ kwargs['model_config'] = model_config
101
+ model, tokenizer = get_model_tokenizer_chatglm(model_dir, model_info, model_kwargs, load_model, **kwargs)
102
+ if len(tokenizer.encode('<|user|>', add_special_tokens=False)) > 1:
103
+ for k in tokenizer.special_tokens.keys():
104
+ tokenizer.add_tokens(k)
105
+ return model, tokenizer
106
+
107
+
108
+ register_model(
109
+ ModelMeta(
110
+ LLMModelType.glm4,
111
+ [
112
+ ModelGroup([
113
+ Model('ZhipuAI/glm-4-9b-chat', 'THUDM/glm-4-9b-chat'),
114
+ Model('ZhipuAI/glm-4-9b', 'THUDM/glm-4-9b'),
115
+ Model('ZhipuAI/glm-4-9b-chat-1m', 'THUDM/glm-4-9b-chat-1m'),
116
+ ]),
117
+ ModelGroup([
118
+ Model('ZhipuAI/LongWriter-glm4-9b', 'THUDM/LongWriter-glm4-9b'),
119
+ ])
120
+ ],
121
+ TemplateType.glm4,
122
+ get_model_tokenizer_glm4,
123
+ architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
124
+ model_arch=ModelArch.chatglm,
125
+ requires=['transformers>=4.42'],
126
+ ))
127
+
128
+ register_model(
129
+ ModelMeta(
130
+ LLMModelType.glm4_0414,
131
+ [
132
+ ModelGroup([
133
+ Model('ZhipuAI/GLM-4-9B-0414', 'THUDM/GLM-4-9B-0414'),
134
+ Model('ZhipuAI/GLM-4-32B-0414', 'THUDM/GLM-4-32B-0414'),
135
+ Model('ZhipuAI/GLM-4-32B-Base-0414', 'THUDM/GLM-4-32B-Base-0414'),
136
+ Model('ZhipuAI/GLM-Z1-9B-0414', 'THUDM/GLM-Z1-9B-0414'),
137
+ Model('ZhipuAI/GLM-Z1-32B-0414', 'THUDM/GLM-Z1-32B-0414'),
138
+ ])
139
+ ],
140
+ TemplateType.glm4_0414,
141
+ get_model_tokenizer_with_flash_attn,
142
+ architectures=['Glm4ForCausalLM'],
143
+ model_arch=ModelArch.chatglm,
144
+ requires=['transformers>=4.51'],
145
+ ))
146
+
147
+ register_model(
148
+ ModelMeta(
149
+ LLMModelType.glm4_z1_rumination,
150
+ [ModelGroup([
151
+ Model('ZhipuAI/GLM-Z1-Rumination-32B-0414', 'THUDM/GLM-Z1-Rumination-32B-0414'),
152
+ ])],
153
+ TemplateType.glm4_z1_rumination,
154
+ get_model_tokenizer_with_flash_attn,
155
+ architectures=['Glm4ForCausalLM'],
156
+ model_arch=ModelArch.chatglm,
157
+ requires=['transformers>4.51'],
158
+ ))
159
+
160
+ register_model(
161
+ ModelMeta(
162
+ LLMModelType.longwriter_llama3_1,
163
+ [ModelGroup([
164
+ Model('ZhipuAI/LongWriter-llama3.1-8b', 'THUDM/LongWriter-llama3.1-8b'),
165
+ ])],
166
+ TemplateType.longwriter_llama,
167
+ get_model_tokenizer_with_flash_attn,
168
+ architectures=['LlamaForCausalLM'],
169
+ requires=['transformers>=4.43'],
170
+ model_arch=ModelArch.llama,
171
+ ))
172
+
173
+ register_model(
174
+ ModelMeta(
175
+ LLMModelType.codegeex4,
176
+ [ModelGroup([
177
+ Model('ZhipuAI/codegeex4-all-9b', 'THUDM/codegeex4-all-9b'),
178
+ ])],
179
+ TemplateType.codegeex4,
180
+ get_model_tokenizer_glm4,
181
+ requires=['transformers<4.42'],
182
+ architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
183
+ model_arch=ModelArch.chatglm,
184
+ tags=['coding'],
185
+ ))
186
+
187
+
188
+ def get_model_tokenizer_glm4v(model_dir: str,
189
+ model_info: ModelInfo,
190
+ model_kwargs: Dict[str, Any],
191
+ load_model: bool = True,
192
+ **kwargs):
193
+ model, tokenizer = get_model_tokenizer_glm4(model_dir, model_info, model_kwargs, load_model, **kwargs)
194
+ # fix merge-lora
195
+ tokenizer.init_kwargs['image_size'] = 1120
196
+ if load_model:
197
+ # fix device_map 4
198
+ n_gpu = get_device_count()
199
+ local_world_size = get_dist_setting()[3]
200
+ if n_gpu // local_world_size >= 4:
201
+ for layer in model.transformer.vision.transformer.layers:
202
+ patch_output_to_input_device(layer.mlp)
203
+ patch_output_to_input_device(layer.post_attention_layernorm)
204
+ device = next(model.transformer.vision.linear_proj.parameters()).device
205
+ model.transformer.vision.boi.data = model.transformer.vision.boi.to(device)
206
+ model.transformer.vision.eoi.data = model.transformer.vision.eoi.to(device)
207
+ return model, tokenizer
208
+
209
+
210
+ register_model(
211
+ ModelMeta(
212
+ MLLMModelType.glm4v,
213
+ [
214
+ ModelGroup(
215
+ [
216
+ Model('ZhipuAI/glm-4v-9b', 'THUDM/glm-4v-9b'),
217
+ ],
218
+ requires=['transformers>=4.42,<4.45'],
219
+ ),
220
+ ModelGroup(
221
+ [
222
+ Model('ZhipuAI/cogagent-9b-20241220', 'THUDM/cogagent-9b-20241220'),
223
+ ],
224
+ requires=['transformers>=4.42'],
225
+ )
226
+ ],
227
+ TemplateType.glm4v,
228
+ get_model_tokenizer_glm4v,
229
+ architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
230
+ model_arch=ModelArch.glm4v,
231
+ ))
232
+
233
+
234
+ def get_model_tokenizer_cogvlm(model_dir: str,
235
+ model_info: ModelInfo,
236
+ model_kwargs: Dict[str, Any],
237
+ load_model: bool = True,
238
+ **kwargs):
239
+ tokenizer_dir = safe_snapshot_download('AI-ModelScope/vicuna-7b-v1.5', download_model=False, check_local=True)
240
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=True)
241
+ if load_model:
242
+ logger.warning('CogAgent with FusedLayerNorm will cause an training loss of NAN, '
243
+ 'to avoid this, please uninstall apex.')
244
+ logger.info('Please ignore the unimported warning.')
245
+ model, tokenizer = get_model_tokenizer_with_flash_attn(
246
+ model_dir, model_info, model_kwargs, load_model, tokenizer=tokenizer, **kwargs)
247
+ return model, tokenizer
248
+
249
+
250
+ register_model(
251
+ ModelMeta(
252
+ MLLMModelType.cogvlm, [
253
+ ModelGroup([
254
+ Model('ZhipuAI/cogvlm-chat', 'THUDM/cogvlm-chat-hf'),
255
+ ]),
256
+ ],
257
+ TemplateType.cogvlm,
258
+ get_model_tokenizer_cogvlm,
259
+ architectures=['CogVLMForCausalLM'],
260
+ requires=['transformers<4.42'],
261
+ model_arch=ModelArch.cogvlm))
262
+
263
+ register_model(
264
+ ModelMeta(
265
+ MLLMModelType.cogagent_chat, [
266
+ ModelGroup([
267
+ Model('ZhipuAI/cogagent-chat', 'THUDM/cogagent-chat-hf'),
268
+ ]),
269
+ ],
270
+ TemplateType.cogagent_chat,
271
+ get_model_tokenizer_cogvlm,
272
+ architectures=['CogAgentForCausalLM'],
273
+ requires=['transformers<4.42', 'timm'],
274
+ model_arch=ModelArch.cogvlm))
275
+
276
+ register_model(
277
+ ModelMeta(
278
+ MLLMModelType.cogagent_vqa, [ModelGroup([
279
+ Model('ZhipuAI/cogagent-vqa', 'THUDM/cogagent-vqa-hf'),
280
+ ])],
281
+ TemplateType.cogagent_vqa,
282
+ get_model_tokenizer_cogvlm,
283
+ architectures=['CogAgentForCausalLM'],
284
+ requires=['transformers<4.42'],
285
+ model_arch=ModelArch.cogvlm))
286
+
287
+
288
+ def get_model_tokenizer_cogvlm2(*args, **kwargs):
289
+ model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
290
+ if model is not None:
291
+ # fix device map 4
292
+ for layer in model.model.vision.transformer.layers:
293
+ patch_output_to_input_device(layer.mlp)
294
+ patch_output_to_input_device(layer.post_attention_layernorm)
295
+
296
+ device = next(model.model.vision.linear_proj.parameters()).device
297
+ model.model.vision.boi.data = model.model.vision.boi.to(device)
298
+ model.model.vision.eoi.data = model.model.vision.eoi.to(device)
299
+ return model, tokenizer
300
+
301
+
302
+ register_model(
303
+ ModelMeta(
304
+ MLLMModelType.cogvlm2, [
305
+ ModelGroup([
306
+ Model('ZhipuAI/cogvlm2-llama3-chat-19B', 'THUDM/cogvlm2-llama3-chat-19B'),
307
+ Model('ZhipuAI/cogvlm2-llama3-chinese-chat-19B', 'THUDM/cogvlm2-llama3-chinese-chat-19B'),
308
+ ]),
309
+ ],
310
+ TemplateType.cogvlm2,
311
+ get_model_tokenizer_cogvlm2,
312
+ architectures=['CogVLMForCausalLM'],
313
+ requires=['transformers<4.42'],
314
+ model_arch=ModelArch.cogvlm))
315
+
316
+ register_model(
317
+ ModelMeta(
318
+ MLLMModelType.cogvlm2_video,
319
+ [
320
+ ModelGroup([
321
+ Model('ZhipuAI/cogvlm2-video-llama3-chat', 'THUDM/cogvlm2-video-llama3-chat'),
322
+ ]),
323
+ ],
324
+ TemplateType.cogvlm2_video,
325
+ get_model_tokenizer_cogvlm2,
326
+ architectures=['CogVLMVideoForCausalLM'],
327
+ requires=['decord', 'pytorchvideo', 'transformers>=4.42'],
328
+ model_arch=ModelArch.cogvlm,
329
+ tags=['video'],
330
+ ))
331
+
332
+ register_model(
333
+ ModelMeta(
334
+ LLMModelType.glm_edge,
335
+ [
336
+ ModelGroup([
337
+ Model('ZhipuAI/glm-edge-1.5b-chat', 'THUDM/glm-edge-1.5b-chat'),
338
+ Model('ZhipuAI/glm-edge-4b-chat', 'THUDM/glm-edge-4b-chat'),
339
+ ]),
340
+ ],
341
+ TemplateType.glm4,
342
+ get_model_tokenizer_with_flash_attn,
343
+ architectures=['GlmForCausalLM'],
344
+ requires=['transformers>=4.46'],
345
+ ))
346
+
347
+
348
+ def get_model_tokenizer_glm_edge_v(model_dir: str, *args, **kwargs):
349
+ from transformers import AutoImageProcessor
350
+ processor = AutoImageProcessor.from_pretrained(model_dir)
351
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
352
+ processor.tokenizer = tokenizer
353
+ return model, processor
354
+
355
+
356
+ register_model(
357
+ ModelMeta(
358
+ MLLMModelType.glm_edge_v,
359
+ [
360
+ ModelGroup([
361
+ Model('ZhipuAI/glm-edge-v-2b', 'THUDM/glm-edge-v-2b'),
362
+ Model('ZhipuAI/glm-edge-4b-chat', 'THUDM/glm-edge-4b-chat'),
363
+ ]),
364
+ ],
365
+ TemplateType.glm_edge_v,
366
+ get_model_tokenizer_glm_edge_v,
367
+ architectures=['GlmForCausalLM'],
368
+ requires=['transformers>=4.46'],
369
+ model_arch=ModelArch.glm_edge_v,
370
+ tags=['vision'],
371
+ ))
ms-swift/swift/llm/model/model/minimax.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from typing import Any, Dict
4
+
5
+ import json
6
+ from transformers import AutoConfig, AutoProcessor
7
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
8
+
9
+ from swift.llm import TemplateType
10
+ from swift.utils import get_device, get_device_count, get_dist_setting, get_logger
11
+ from ..constant import LLMModelType, MLLMModelType
12
+ from ..patcher import patch_ignore_check_imports
13
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
14
+ from ..utils import ModelInfo
15
+
16
+ logger = get_logger()
17
+
18
+
19
+ def get_model_tokenizer_minimax_vl(model_dir: str,
20
+ model_info: ModelInfo,
21
+ model_kwargs: Dict[str, Any],
22
+ load_model: bool = True,
23
+ **kwargs):
24
+ logger.warn('NOTE: minimax-vl-01 model does not support training.')
25
+ n_gpu = get_device_count()
26
+ _, local_rank, _, local_world_size = get_dist_setting()
27
+ device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
28
+ config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
29
+ kwargs['model_config'] = config
30
+ if kwargs.get('attn_impl') == 'flash_attn':
31
+ config.attn_type_list = [1] * len(config.attn_type_list)
32
+ else:
33
+ config.attn_type_list = [0] * len(config.attn_type_list)
34
+ if 'quantization_config' in model_kwargs:
35
+ quantization_config = model_kwargs['quantization_config']
36
+ from transformers import QuantoConfig
37
+ if isinstance(quantization_config, QuantoConfig):
38
+ quantization_config.modules_to_not_convert = (
39
+ [
40
+ 'vision_tower',
41
+ 'image_newline',
42
+ 'multi_modal_projector',
43
+ 'lm_head',
44
+ 'embed_tokens',
45
+ ] + [f'model.layers.{i}.coefficient' for i in range(config.text_config.num_hidden_layers)]
46
+ + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.text_config.num_hidden_layers)])
47
+
48
+ if len(device_ids) > 1:
49
+ model_safetensors_index_path = os.path.join(model_dir, 'model.safetensors.index.json')
50
+ with open(model_safetensors_index_path, 'r') as f:
51
+ model_safetensors_index = json.load(f)
52
+ weight_map = model_safetensors_index['weight_map']
53
+ vision_map = {}
54
+ for key, value in weight_map.items():
55
+ if 'vision_tower' in key or 'image_newline' in key or 'multi_modal_projector' in key:
56
+ new_key = key.replace('.weight', '').replace('.bias', '')
57
+ if new_key not in vision_map:
58
+ vision_map[new_key] = value
59
+
60
+ device_map = {
61
+ 'language_model.model.embed_tokens': get_device(device_ids[0]),
62
+ 'language_model.model.norm': get_device(device_ids[len(device_ids) - 1]),
63
+ 'language_model.lm_head': get_device(device_ids[len(device_ids) - 1])
64
+ }
65
+ for key, value in vision_map.items():
66
+ device_map[key] = get_device(device_ids[0])
67
+ device_map['vision_tower.vision_model.post_layernorm'] = get_device(device_ids[0])
68
+ layers_per_device = config.text_config.num_hidden_layers // len(device_ids)
69
+ for i in range(len(device_ids)):
70
+ for j in range(layers_per_device):
71
+ device_map[f'language_model.model.layers.{i * layers_per_device + j}'] = get_device(device_ids[i])
72
+ model_kwargs['device_map'] = device_map
73
+
74
+ MiniMaxVL01ProcessorKwargs = get_class_from_dynamic_module('processing_minimax_vl_01.MiniMaxVL01ProcessorKwargs',
75
+ model_dir)
76
+ get_hw_multiple_of = get_class_from_dynamic_module('processing_minimax_vl_01.get_hw_multiple_of', model_dir)
77
+ get_num_token = get_class_from_dynamic_module('processing_minimax_vl_01.get_num_token', model_dir)
78
+
79
+ processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
80
+ processor.MiniMaxVL01ProcessorKwargs = MiniMaxVL01ProcessorKwargs
81
+ processor.get_hw_multiple_of = get_hw_multiple_of
82
+ processor.get_num_token = get_num_token
83
+ with patch_ignore_check_imports():
84
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
85
+ **kwargs)
86
+ processor.tokenizer = tokenizer
87
+ return model, processor
88
+
89
+
90
+ register_model(
91
+ ModelMeta(
92
+ MLLMModelType.minimax_vl, [
93
+ ModelGroup([
94
+ Model('MiniMax/MiniMax-VL-01', 'MiniMaxAI/MiniMax-VL-01'),
95
+ ]),
96
+ ],
97
+ TemplateType.minimax_vl,
98
+ get_model_tokenizer_minimax_vl,
99
+ architectures=['MiniMaxVL01ForConditionalGeneration'],
100
+ tags=['vision']))
101
+
102
+
103
+ def get_model_tokenizer_minimax_text(model_dir: str,
104
+ model_info: ModelInfo,
105
+ model_kwargs: Dict[str, Any],
106
+ load_model: bool = True,
107
+ **kwargs):
108
+ logger.warn('NOTE: minimax-text-01 model does not support training.')
109
+ n_gpu = get_device_count()
110
+ _, local_rank, _, local_world_size = get_dist_setting()
111
+ device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
112
+ config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
113
+ kwargs['model_config'] = config
114
+ if kwargs.get('attn_impl') == 'flash_attn':
115
+ config.attn_type_list = [1] * len(config.attn_type_list)
116
+ else:
117
+ config.attn_type_list = [0] * len(config.attn_type_list)
118
+ if 'quantization_config' in model_kwargs:
119
+ quantization_config = model_kwargs['quantization_config']
120
+ from transformers import QuantoConfig
121
+ if isinstance(quantization_config, QuantoConfig):
122
+ quantization_config.modules_to_not_convert = (
123
+ [
124
+ 'lm_head',
125
+ 'embed_tokens',
126
+ ] + [f'model.layers.{i}.coefficient' for i in range(config.num_hidden_layers)]
127
+ + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.num_hidden_layers)])
128
+
129
+ if len(device_ids) > 1:
130
+ layers_per_device = config.num_hidden_layers // len(device_ids)
131
+ # set device map
132
+ device_map = {
133
+ 'model.embed_tokens': get_device(0),
134
+ 'model.norm': get_device(len(device_ids) - 1),
135
+ 'lm_head': get_device(len(device_ids) - 1)
136
+ }
137
+ for i in range(len(device_ids)):
138
+ for j in range(layers_per_device):
139
+ device_map[f'model.layers.{i * layers_per_device + j}'] = get_device(i)
140
+ model_kwargs['device_map'] = device_map
141
+ with patch_ignore_check_imports():
142
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
143
+ **kwargs)
144
+ return model, tokenizer
145
+
146
+
147
+ register_model(
148
+ ModelMeta(
149
+ LLMModelType.minimax, [
150
+ ModelGroup([
151
+ Model('MiniMax/MiniMax-Text-01', 'MiniMaxAI/MiniMax-Text-01'),
152
+ ]),
153
+ ],
154
+ TemplateType.minimax,
155
+ get_model_tokenizer_minimax_text,
156
+ architectures=['MiniMaxText01ForCausalLM']))
ms-swift/swift/llm/model/model/openbuddy.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from swift.llm import TemplateType
3
+ from swift.utils import get_logger
4
+ from ..constant import LLMModelType
5
+ from ..model_arch import ModelArch
6
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
7
+
8
+ logger = get_logger()
9
+
10
+ register_model(
11
+ ModelMeta(
12
+ LLMModelType.openbuddy_llama,
13
+ [
14
+ ModelGroup([
15
+ Model('OpenBuddy/openbuddy-llama-65b-v8-bf16', 'OpenBuddy/openbuddy-llama-65b-v8-bf16'),
16
+ ]),
17
+ ModelGroup([
18
+ Model('OpenBuddy/openbuddy-llama2-13b-v8.1-fp16', 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16'),
19
+ Model('OpenBuddy/openbuddy-llama2-70b-v10.1-bf16', 'OpenBuddy/openbuddy-llama2-70b-v10.1-bf16'),
20
+ ]),
21
+ ModelGroup([
22
+ Model('OpenBuddy/openbuddy-deepseek-67b-v15.2', 'OpenBuddy/openbuddy-deepseek-67b-v15.2'),
23
+ ]),
24
+ ],
25
+ TemplateType.openbuddy,
26
+ get_model_tokenizer_with_flash_attn,
27
+ model_arch=ModelArch.llama,
28
+ architectures=['LlamaForCausalLM'],
29
+ ))
30
+
31
+ register_model(
32
+ ModelMeta(
33
+ LLMModelType.openbuddy_llama3,
34
+ [
35
+ ModelGroup([
36
+ Model('OpenBuddy/openbuddy-llama3-8b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-8b-v21.1-8k'),
37
+ Model('OpenBuddy/openbuddy-llama3-70b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-70b-v21.1-8k'),
38
+ Model('OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k', 'OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k'),
39
+ ]),
40
+ ModelGroup([
41
+ Model('OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k', 'OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k'),
42
+ Model('OpenBuddy/openbuddy-nemotron-70b-v23.2-131k', 'OpenBuddy/openbuddy-nemotron-70b-v23.2-131k'),
43
+ ],
44
+ requires=['transformers>=4.43']),
45
+ ModelGroup(
46
+ [Model('OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k', 'OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k')],
47
+ requires=['transformers>=4.45'])
48
+ ],
49
+ TemplateType.openbuddy2,
50
+ get_model_tokenizer_with_flash_attn,
51
+ model_arch=ModelArch.llama,
52
+ architectures=['LlamaForCausalLM'],
53
+ ))
54
+
55
+ register_model(
56
+ ModelMeta(
57
+ LLMModelType.openbuddy_mistral,
58
+ [
59
+ ModelGroup([
60
+ Model('OpenBuddy/openbuddy-mistral-7b-v17.1-32k', 'OpenBuddy/openbuddy-mistral-7b-v17.1-32k'),
61
+ ]),
62
+ ModelGroup([
63
+ Model('OpenBuddy/openbuddy-zephyr-7b-v14.1', 'OpenBuddy/openbuddy-zephyr-7b-v14.1'),
64
+ ]),
65
+ ],
66
+ TemplateType.openbuddy,
67
+ get_model_tokenizer_with_flash_attn,
68
+ model_arch=ModelArch.llama,
69
+ requires=['transformers>=4.34'],
70
+ architectures=['MistralForCausalLM'],
71
+ ))
72
+
73
+ register_model(
74
+ ModelMeta(
75
+ LLMModelType.openbuddy_mixtral,
76
+ [
77
+ ModelGroup([
78
+ Model('OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k', 'OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k'),
79
+ ], ),
80
+ ],
81
+ TemplateType.openbuddy,
82
+ get_model_tokenizer_with_flash_attn,
83
+ architectures=['MixtralForCausalLM'],
84
+ requires=['transformers>=4.36'],
85
+ ))
ms-swift/swift/llm/model/model/qwen.py ADDED
@@ -0,0 +1,899 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, Optional, Tuple, Type
3
+
4
+ import torch
5
+ from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
6
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
7
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
8
+
9
+ from swift.llm import TemplateType
10
+ from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger
11
+ from ..constant import LLMModelType, MLLMModelType, RMModelType
12
+ from ..model_arch import ModelArch
13
+ from ..patcher import patch_fixed_device, patch_output_clone, patch_output_to_input_device
14
+ from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_reward_model,
15
+ get_model_tokenizer_with_flash_attn, register_model)
16
+ from ..utils import AttnImpl, ModelInfo, use_submodel_func
17
+
18
+ logger = get_logger()
19
+ dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}
20
+
21
+
22
+ def get_model_tokenizer_qwen(model_dir: str,
23
+ model_info: ModelInfo,
24
+ model_kwargs: Dict[str, Any],
25
+ load_model: bool = True,
26
+ model_config=None,
27
+ **kwargs):
28
+ if model_config is None:
29
+ model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
30
+ if model_info.torch_dtype is not None:
31
+ k_true = dtype_mapping[model_info.torch_dtype]
32
+ for k in dtype_mapping.values():
33
+ setattr(model_config, k, k == k_true)
34
+
35
+ quantization_config = model_kwargs.get('quantization_config')
36
+ if not isinstance(quantization_config, BitsAndBytesConfig):
37
+ # not bnb quant
38
+ model_config.torch_dtype = None
39
+ use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
40
+ model_config.use_flash_attn = use_flash_attn
41
+ kwargs['model_config'] = model_config
42
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
43
+ try:
44
+ # fix mp+ddp bug
45
+ model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
46
+ logger.info('registered_causal_mask to cuda')
47
+ except AttributeError:
48
+ pass
49
+ if tokenizer.eos_token_id is None:
50
+ tokenizer.eos_token_id = tokenizer.eod_id
51
+ return model, tokenizer
52
+
53
+
54
+ register_model(
55
+ ModelMeta(
56
+ LLMModelType.qwen,
57
+ [
58
+ # qwen
59
+ ModelGroup([
60
+ # chat
61
+ Model('Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-1_8B-Chat'),
62
+ Model('Qwen/Qwen-7B-Chat', 'Qwen/Qwen-7B-Chat'),
63
+ Model('Qwen/Qwen-14B-Chat', 'Qwen/Qwen-14B-Chat'),
64
+ Model('Qwen/Qwen-72B-Chat', 'Qwen/Qwen-72B-Chat'),
65
+ # base
66
+ Model('Qwen/Qwen-1_8B', 'Qwen/Qwen-1_8B'),
67
+ Model('Qwen/Qwen-7B', 'Qwen/Qwen-7B'),
68
+ Model('Qwen/Qwen-14B', 'Qwen/Qwen-14B'),
69
+ Model('Qwen/Qwen-72B', 'Qwen/Qwen-72B'),
70
+ # gptq-int4
71
+ Model('Qwen/Qwen-1_8B-Chat-Int4', 'Qwen/Qwen-1_8B-Chat-Int4'),
72
+ Model('Qwen/Qwen-7B-Chat-Int4', 'Qwen/Qwen-7B-Chat-Int4'),
73
+ Model('Qwen/Qwen-14B-Chat-Int4', 'Qwen/Qwen-14B-Chat-Int4'),
74
+ Model('Qwen/Qwen-72B-Chat-Int4', 'Qwen/Qwen-72B-Chat-Int4'),
75
+ # gptq-int8
76
+ Model('Qwen/Qwen-1_8B-Chat-Int8', 'Qwen/Qwen-1_8B-Chat-Int8'),
77
+ Model('Qwen/Qwen-7B-Chat-Int8', 'Qwen/Qwen-7B-Chat-Int8'),
78
+ Model('Qwen/Qwen-14B-Chat-Int8', 'Qwen/Qwen-14B-Chat-Int8'),
79
+ Model('Qwen/Qwen-72B-Chat-Int8', 'Qwen/Qwen-72B-Chat-Int8'),
80
+ ]),
81
+ # tongyi-finance
82
+ ModelGroup([
83
+ Model('TongyiFinance/Tongyi-Finance-14B-Chat', 'jxy/Tongyi-Finance-14B-Chat'),
84
+ Model('TongyiFinance/Tongyi-Finance-14B'),
85
+ Model('TongyiFinance/Tongyi-Finance-14B-Chat-Int4', 'jxy/Tongyi-Finance-14B-Chat-Int4'),
86
+ ],
87
+ tags=['financial']),
88
+ ],
89
+ TemplateType.qwen,
90
+ get_model_tokenizer_qwen,
91
+ architectures=['QWenLMHeadModel'],
92
+ model_arch=ModelArch.qwen))
93
+
94
+ register_model(
95
+ ModelMeta(
96
+ LLMModelType.modelscope_agent,
97
+ [ModelGroup([
98
+ Model('iic/ModelScope-Agent-7B'),
99
+ Model('iic/ModelScope-Agent-14B'),
100
+ ])],
101
+ TemplateType.modelscope_agent,
102
+ get_model_tokenizer_qwen,
103
+ architectures=['QWenLMHeadModel'],
104
+ model_arch=ModelArch.qwen))
105
+
106
+
107
+ def _qwen_vl_audio_decode(self, *args, skip_special_tokens=False, **kwargs) -> str:
108
+ if skip_special_tokens:
109
+ token_ids = kwargs['token_ids']
110
+ while len(token_ids) > 0 and token_ids[-1] in {151645, 151643}:
111
+ token_ids.pop()
112
+ return self._old_decode(*args, skip_special_tokens=False, **kwargs)
113
+ else:
114
+ return self._old_decode(*args, skip_special_tokens=False, **kwargs)
115
+
116
+
117
+ def fix_qwen_inplace_bug(model) -> None:
118
+ # qwen-vl, qwen-audio
119
+ first_drop = model.transformer.drop
120
+ if first_drop.p == 0.:
121
+ # fix in-place operation bug
122
+ patch_output_clone(first_drop)
123
+
124
+
125
+ def get_model_tokenizer_qwen_audio(model_dir: str,
126
+ model_info: ModelInfo,
127
+ model_kwargs: Dict[str, Any],
128
+ load_model: bool = True,
129
+ **kwargs):
130
+ tokenizer_config = get_tokenizer_config(model_dir)
131
+ class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
132
+ tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
133
+ tokenizer_cls._auto_class = 'AutoTokenizer'
134
+ tokenizer_cls.AUDIO_ST = () # fix no attr `self.AUDIO_ST` bug
135
+ if not hasattr(tokenizer_cls, '_old_decode'):
136
+ tokenizer_cls._old_decode = tokenizer_cls._decode
137
+ tokenizer_cls._decode = _qwen_vl_audio_decode
138
+ kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
139
+ model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
140
+ if model is not None:
141
+ fix_qwen_inplace_bug(model)
142
+
143
+ return model, tokenizer
144
+
145
+
146
+ register_model(
147
+ ModelMeta(
148
+ MLLMModelType.qwen_audio, [
149
+ ModelGroup([
150
+ Model('Qwen/Qwen-Audio-Chat', 'Qwen/Qwen-Audio-Chat'),
151
+ Model('Qwen/Qwen-Audio', 'Qwen/Qwen-Audio'),
152
+ ])
153
+ ],
154
+ TemplateType.qwen_audio,
155
+ get_model_tokenizer_qwen_audio,
156
+ model_arch=ModelArch.qwen_audio,
157
+ architectures=['QWenLMHeadModel'],
158
+ additional_saved_files=['mel_filters.npz'],
159
+ tags=['audio']))
160
+
161
+
162
+ def _qwen_vl_visual_block_forward(
163
+ self,
164
+ q_x: torch.Tensor,
165
+ k_x: Optional[torch.Tensor] = None,
166
+ v_x: Optional[torch.Tensor] = None,
167
+ attn_mask: Optional[torch.Tensor] = None,
168
+ ):
169
+ k_x = self.ln_1_kv(k_x) if hasattr(self, 'ln_1_kv') and k_x is not None else None
170
+ v_x = self.ln_1_kv(v_x) if hasattr(self, 'ln_1_kv') and v_x is not None else None
171
+
172
+ x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
173
+ z = self.mlp(self.ln_2(x))
174
+ x = x.to(z.device) + z # FIX
175
+ return x
176
+
177
+
178
+ def get_model_tokenizer_qwen_vl(model_dir: str,
179
+ model_info: ModelInfo,
180
+ model_kwargs: Dict[str, Any],
181
+ load_model: bool = True,
182
+ **kwargs):
183
+ if (model_kwargs.get('quantization_config') is not None
184
+ and isinstance(model_kwargs['quantization_config'], BitsAndBytesConfig)):
185
+ # https://github.com/pytorch/pytorch/issues/58969
186
+ model_kwargs['quantization_config'].llm_int8_skip_modules = ['lm_head', 'attn_pool.attn']
187
+ _TransformerBlock = get_class_from_dynamic_module('visual.TransformerBlock', model_dir)
188
+
189
+ def _get_cast_dtype(self) -> torch.dtype:
190
+ return self.resblocks[0].ln_1.weight.dtype
191
+
192
+ _TransformerBlock.__old_get_cast_dtype = _TransformerBlock.get_cast_dtype
193
+ _TransformerBlock.get_cast_dtype = _get_cast_dtype
194
+
195
+ tokenizer_config = get_tokenizer_config(model_dir)
196
+ class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
197
+ tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
198
+ tokenizer_cls._auto_class = 'AutoTokenizer'
199
+ tokenizer_cls.IMAGE_ST = () # fix no attr `self.IMAGE_ST` bug
200
+ if not hasattr(tokenizer_cls, '_old_decode'):
201
+ tokenizer_cls._old_decode = tokenizer_cls._decode
202
+ tokenizer_cls._decode = _qwen_vl_audio_decode
203
+ # fix device_map is 4
204
+ n_gpu = get_device_count()
205
+ local_world_size = get_dist_setting()[3]
206
+ if n_gpu // local_world_size >= 4:
207
+ visual_block_cls = get_class_from_dynamic_module('visual.VisualAttentionBlock', model_dir)
208
+ visual_block_cls.__old_forward = visual_block_cls.forward
209
+ visual_block_cls.forward = _qwen_vl_visual_block_forward
210
+
211
+ kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
212
+ model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
213
+ if model is not None:
214
+ device_type = next(model.parameters()).device.type
215
+ fix_qwen_inplace_bug(model)
216
+ # fix device_map is 4
217
+ if n_gpu // local_world_size >= 4:
218
+ model.transformer.visual.proj.data = model.transformer.visual.proj.to(
219
+ model.transformer.visual.ln_post.bias.device)
220
+ # fix images cuda:1 bug
221
+ patch_fixed_device(model.transformer.visual, f'{device_type}:0')
222
+ return model, tokenizer
223
+
224
+
225
+ register_model(
226
+ ModelMeta(
227
+ MLLMModelType.qwen_vl, [
228
+ ModelGroup([
229
+ Model('Qwen/Qwen-VL-Chat', 'Qwen/Qwen-VL-Chat'),
230
+ Model('Qwen/Qwen-VL', 'Qwen/Qwen-VL'),
231
+ Model('Qwen/Qwen-VL-Chat-Int4', 'Qwen/Qwen-VL-Chat-Int4'),
232
+ ])
233
+ ],
234
+ TemplateType.qwen_vl,
235
+ get_model_tokenizer_qwen_vl,
236
+ model_arch=ModelArch.qwen_vl,
237
+ architectures=['QWenLMHeadModel'],
238
+ additional_saved_files=['SimSun.ttf'],
239
+ tags=['vision']))
240
+
241
+ register_model(
242
+ ModelMeta(
243
+ LLMModelType.qwen2,
244
+ [
245
+ # qwen1.5
246
+ ModelGroup([
247
+ # chat
248
+ Model('Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-0.5B-Chat'),
249
+ Model('Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-1.8B-Chat'),
250
+ Model('Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-4B-Chat'),
251
+ Model('Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-7B-Chat'),
252
+ Model('Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen1.5-14B-Chat'),
253
+ Model('Qwen/Qwen1.5-32B-Chat', 'Qwen/Qwen1.5-32B-Chat'),
254
+ Model('Qwen/Qwen1.5-72B-Chat', 'Qwen/Qwen1.5-72B-Chat'),
255
+ Model('Qwen/Qwen1.5-110B-Chat', 'Qwen/Qwen1.5-110B-Chat'),
256
+ # base
257
+ Model('Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B'),
258
+ Model('Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B'),
259
+ Model('Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B'),
260
+ Model('Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B'),
261
+ Model('Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B'),
262
+ Model('Qwen/Qwen1.5-32B', 'Qwen/Qwen1.5-32B'),
263
+ Model('Qwen/Qwen1.5-72B', 'Qwen/Qwen1.5-72B'),
264
+ Model('Qwen/Qwen1.5-110B', 'Qwen/Qwen1.5-110B'),
265
+ # gptq-int4
266
+ Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4'),
267
+ Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4'),
268
+ Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int4'),
269
+ Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int4'),
270
+ Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int4'),
271
+ Model('Qwen/Qwen1.5-32B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-32B-Chat-GPTQ-Int4'),
272
+ Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int4'),
273
+ Model('Qwen/Qwen1.5-110B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-110B-Chat-GPTQ-Int4'),
274
+ # gptq-int8
275
+ Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8'),
276
+ Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8'),
277
+ Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int8'),
278
+ Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int8'),
279
+ Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int8'),
280
+ Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int8'),
281
+ # awq-int4
282
+ Model('Qwen/Qwen1.5-0.5B-Chat-AWQ', 'Qwen/Qwen1.5-0.5B-Chat-AWQ'),
283
+ Model('Qwen/Qwen1.5-1.8B-Chat-AWQ', 'Qwen/Qwen1.5-1.8B-Chat-AWQ'),
284
+ Model('Qwen/Qwen1.5-4B-Chat-AWQ', 'Qwen/Qwen1.5-4B-Chat-AWQ'),
285
+ Model('Qwen/Qwen1.5-7B-Chat-AWQ', 'Qwen/Qwen1.5-7B-Chat-AWQ'),
286
+ Model('Qwen/Qwen1.5-14B-Chat-AWQ', 'Qwen/Qwen1.5-14B-Chat-AWQ'),
287
+ Model('Qwen/Qwen1.5-32B-Chat-AWQ', 'Qwen/Qwen1.5-32B-Chat-AWQ'),
288
+ Model('Qwen/Qwen1.5-72B-Chat-AWQ', 'Qwen/Qwen1.5-72B-Chat-AWQ'),
289
+ Model('Qwen/Qwen1.5-110B-Chat-AWQ', 'Qwen/Qwen1.5-110B-Chat-AWQ'),
290
+ ]),
291
+ # code-qwen1.5
292
+ ModelGroup([
293
+ Model('Qwen/CodeQwen1.5-7B', 'Qwen/CodeQwen1.5-7B'),
294
+ Model('Qwen/CodeQwen1.5-7B-Chat', 'Qwen/CodeQwen1.5-7B-Chat'),
295
+ Model('Qwen/CodeQwen1.5-7B-Chat-AWQ', 'Qwen/CodeQwen1.5-7B-Chat-AWQ'),
296
+ ],
297
+ tags=['coding']),
298
+ # qwen2
299
+ ModelGroup([
300
+ # instruct
301
+ Model('Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-0.5B-Instruct'),
302
+ Model('Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-1.5B-Instruct'),
303
+ Model('Qwen/Qwen2-7B-Instruct', 'Qwen/Qwen2-7B-Instruct'),
304
+ Model('Qwen/Qwen2-72B-Instruct', 'Qwen/Qwen2-72B-Instruct'),
305
+ # base
306
+ Model('Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B'),
307
+ Model('Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B'),
308
+ Model('Qwen/Qwen2-7B', 'Qwen/Qwen2-7B'),
309
+ Model('Qwen/Qwen2-72B', 'Qwen/Qwen2-72B'),
310
+ # gptq-int4
311
+ Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4'),
312
+ Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4'),
313
+ Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int4'),
314
+ Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int4'),
315
+ # gptq-int8
316
+ Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8'),
317
+ Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8'),
318
+ Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int8'),
319
+ Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int8'),
320
+ # awq-int4
321
+ Model('Qwen/Qwen2-0.5B-Instruct-AWQ', 'Qwen/Qwen2-0.5B-Instruct-AWQ'),
322
+ Model('Qwen/Qwen2-1.5B-Instruct-AWQ', 'Qwen/Qwen2-1.5B-Instruct-AWQ'),
323
+ Model('Qwen/Qwen2-7B-Instruct-AWQ', 'Qwen/Qwen2-7B-Instruct-AWQ'),
324
+ Model('Qwen/Qwen2-72B-Instruct-AWQ', 'Qwen/Qwen2-72B-Instruct-AWQ'),
325
+ ]),
326
+ # qwen2-math
327
+ ModelGroup(
328
+ [
329
+ # instruct
330
+ Model('Qwen/Qwen2-Math-1.5B-Instruct', 'Qwen/Qwen2-Math-1.5B-Instruct'),
331
+ Model('Qwen/Qwen2-Math-7B-Instruct', 'Qwen/Qwen2-Math-7B-Instruct'),
332
+ Model('Qwen/Qwen2-Math-72B-Instruct', 'Qwen/Qwen2-Math-72B-Instruct'),
333
+ # base
334
+ Model('Qwen/Qwen2-Math-1.5B', 'Qwen/Qwen2-Math-1.5B'),
335
+ Model('Qwen/Qwen2-Math-7B', 'Qwen/Qwen2-Math-7B'),
336
+ Model('Qwen/Qwen2-Math-72B', 'Qwen/Qwen2-Math-72B'),
337
+ ],
338
+ tags=['math']),
339
+ # qwen2.5-1m
340
+ ModelGroup([
341
+ Model('Qwen/Qwen2.5-7B-Instruct-1M', 'Qwen/Qwen2.5-7B-Instruct-1M'),
342
+ Model('Qwen/Qwen2.5-14B-Instruct-1M', 'Qwen/Qwen2.5-14B-Instruct-1M'),
343
+ ]),
344
+ # other
345
+ ModelGroup([Model('PowerInfer/SmallThinker-3B-Preview', 'PowerInfer/SmallThinker-3B-Preview')]),
346
+ ],
347
+ TemplateType.qwen,
348
+ get_model_tokenizer_with_flash_attn,
349
+ architectures=['Qwen2ForCausalLM'],
350
+ requires=['transformers>=4.37'],
351
+ model_arch=ModelArch.llama))
352
+
353
+ register_model(
354
+ ModelMeta(
355
+ LLMModelType.qwen2_5,
356
+ [
357
+ # qwen2.5
358
+ ModelGroup([
359
+ # instruct
360
+ Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
361
+ Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
362
+ Model('Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-3B-Instruct'),
363
+ Model('Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'),
364
+ Model('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-14B-Instruct'),
365
+ Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
366
+ Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
367
+ # base
368
+ Model('Qwen/Qwen2.5-0.5B', 'Qwen/Qwen2.5-0.5B'),
369
+ Model('Qwen/Qwen2.5-1.5B', 'Qwen/Qwen2.5-1.5B'),
370
+ Model('Qwen/Qwen2.5-3B', 'Qwen/Qwen2.5-3B'),
371
+ Model('Qwen/Qwen2.5-7B', 'Qwen/Qwen2.5-7B'),
372
+ Model('Qwen/Qwen2.5-14B', 'Qwen/Qwen2.5-14B'),
373
+ Model('Qwen/Qwen2.5-32B', 'Qwen/Qwen2.5-32B'),
374
+ Model('Qwen/Qwen2.5-72B', 'Qwen/Qwen2.5-72B'),
375
+ # gptq-int4
376
+ Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4'),
377
+ Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4'),
378
+ Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4'),
379
+ Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4'),
380
+ Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4'),
381
+ Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4'),
382
+ Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4'),
383
+ # gptq-int8
384
+ Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8'),
385
+ Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8'),
386
+ Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8'),
387
+ Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8'),
388
+ Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8'),
389
+ Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8'),
390
+ Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8'),
391
+ # awq-int4
392
+ Model('Qwen/Qwen2.5-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-0.5B-Instruct-AWQ'),
393
+ Model('Qwen/Qwen2.5-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-1.5B-Instruct-AWQ'),
394
+ Model('Qwen/Qwen2.5-3B-Instruct-AWQ', 'Qwen/Qwen2.5-3B-Instruct-AWQ'),
395
+ Model('Qwen/Qwen2.5-7B-Instruct-AWQ', 'Qwen/Qwen2.5-7B-Instruct-AWQ'),
396
+ Model('Qwen/Qwen2.5-14B-Instruct-AWQ', 'Qwen/Qwen2.5-14B-Instruct-AWQ'),
397
+ Model('Qwen/Qwen2.5-32B-Instruct-AWQ', 'Qwen/Qwen2.5-32B-Instruct-AWQ'),
398
+ Model('Qwen/Qwen2.5-72B-Instruct-AWQ', 'Qwen/Qwen2.5-72B-Instruct-AWQ'),
399
+ ]),
400
+ # qwen2.5-coder
401
+ ModelGroup(
402
+ [
403
+ # instruct
404
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct', 'Qwen/Qwen2.5-Coder-0.5B-Instruct'),
405
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct', 'Qwen/Qwen2.5-Coder-1.5B-Instruct'),
406
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct', 'Qwen/Qwen2.5-Coder-3B-Instruct'),
407
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct', 'Qwen/Qwen2.5-Coder-7B-Instruct'),
408
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct', 'Qwen/Qwen2.5-Coder-14B-Instruct'),
409
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct', 'Qwen/Qwen2.5-Coder-32B-Instruct'),
410
+ # base
411
+ Model('Qwen/Qwen2.5-Coder-0.5B', 'Qwen/Qwen2.5-Coder-0.5B'),
412
+ Model('Qwen/Qwen2.5-Coder-1.5B', 'Qwen/Qwen2.5-Coder-1.5B'),
413
+ Model('Qwen/Qwen2.5-Coder-3B', 'Qwen/Qwen2.5-Coder-3B'),
414
+ Model('Qwen/Qwen2.5-Coder-7B', 'Qwen/Qwen2.5-Coder-7B'),
415
+ Model('Qwen/Qwen2.5-Coder-14B', 'Qwen/Qwen2.5-Coder-14B'),
416
+ Model('Qwen/Qwen2.5-Coder-32B', 'Qwen/Qwen2.5-Coder-32B'),
417
+ # AWQ
418
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ'),
419
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ'),
420
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-3B-Instruct-AWQ'),
421
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-7B-Instruct-AWQ'),
422
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-14B-Instruct-AWQ'),
423
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-32B-Instruct-AWQ'),
424
+ # GPTQ
425
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4'),
426
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8'),
427
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4'),
428
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8'),
429
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4'),
430
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8'),
431
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4'),
432
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8'),
433
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4'),
434
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8'),
435
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4'),
436
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8'),
437
+ ],
438
+ tags=['coding'])
439
+ ],
440
+ TemplateType.qwen2_5,
441
+ get_model_tokenizer_with_flash_attn,
442
+ architectures=['Qwen2ForCausalLM'],
443
+ requires=['transformers>=4.37'],
444
+ model_arch=ModelArch.llama))
445
+
446
+ register_model(
447
+ ModelMeta(
448
+ LLMModelType.qwen2_5_math,
449
+ [
450
+ # qwen2.5-math
451
+ ModelGroup(
452
+ [
453
+ # instruct
454
+ Model('Qwen/Qwen2.5-Math-1.5B-Instruct', 'Qwen/Qwen2.5-Math-1.5B-Instruct'),
455
+ Model('Qwen/Qwen2.5-Math-7B-Instruct', 'Qwen/Qwen2.5-Math-7B-Instruct'),
456
+ Model('Qwen/Qwen2.5-Math-72B-Instruct', 'Qwen/Qwen2.5-Math-72B-Instruct'),
457
+ # base
458
+ Model('Qwen/Qwen2.5-Math-1.5B', 'Qwen/Qwen2.5-Math-1.5B'),
459
+ Model('Qwen/Qwen2.5-Math-7B', 'Qwen/Qwen2.5-Math-7B'),
460
+ Model('Qwen/Qwen2.5-Math-72B', 'Qwen/Qwen2.5-Math-72B'),
461
+ ],
462
+ tags=['math']),
463
+ ],
464
+ TemplateType.qwen2_5_math,
465
+ get_model_tokenizer_with_flash_attn,
466
+ architectures=['Qwen2ForCausalLM'],
467
+ requires=['transformers>=4.37'],
468
+ model_arch=ModelArch.llama))
469
+
470
+ register_model(
471
+ ModelMeta(
472
+ LLMModelType.qwen2_moe,
473
+ [
474
+ # qwen1.5-moe
475
+ ModelGroup([
476
+ Model('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'Qwen/Qwen1.5-MoE-A2.7B-Chat'),
477
+ Model('Qwen/Qwen1.5-MoE-A2.7B', 'Qwen/Qwen1.5-MoE-A2.7B'),
478
+ Model('Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4'),
479
+ ]),
480
+ ModelGroup([
481
+ Model('Qwen/Qwen2-57B-A14B-Instruct', 'Qwen/Qwen2-57B-A14B-Instruct'),
482
+ Model('Qwen/Qwen2-57B-A14B', 'Qwen/Qwen2-57B-A14B'),
483
+ Model('Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4'),
484
+ ])
485
+ ],
486
+ TemplateType.qwen,
487
+ get_model_tokenizer_with_flash_attn,
488
+ architectures=['Qwen2MoeForCausalLM'],
489
+ requires=['transformers>=4.40'],
490
+ ))
491
+
492
+ register_model(
493
+ ModelMeta(
494
+ LLMModelType.qwen3,
495
+ [
496
+ ModelGroup([
497
+ Model('Qwen/Qwen3-0.6B-Base', 'Qwen/Qwen3-0.6B-Base'),
498
+ Model('Qwen/Qwen3-1.7B-Base', 'Qwen/Qwen3-1.7B-Base'),
499
+ Model('Qwen/Qwen3-4B-Base', 'Qwen/Qwen3-4B-Base'),
500
+ Model('Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-8B-Base'),
501
+ Model('Qwen/Qwen3-14B-Base', 'Qwen/Qwen3-14B-Base'),
502
+ Model('Qwen/Qwen3-32B-Base', 'Qwen/Qwen3-32B-Base'),
503
+ # instruct
504
+ Model('Qwen/Qwen3-0.6B', 'Qwen/Qwen3-0.6B'),
505
+ Model('Qwen/Qwen3-1.7B', 'Qwen/Qwen3-1.7B'),
506
+ Model('Qwen/Qwen3-4B', 'Qwen/Qwen3-4B'),
507
+ Model('Qwen/Qwen3-8B', 'Qwen/Qwen3-8B'),
508
+ Model('Qwen/Qwen3-14B', 'Qwen/Qwen3-14B'),
509
+ Model('Qwen/Qwen3-32B', 'Qwen/Qwen3-32B'),
510
+ # fp8
511
+ Model('Qwen/Qwen3-0.6B-FP8', 'Qwen/Qwen3-0.6B-FP8'),
512
+ Model('Qwen/Qwen3-1.7B-FP8', 'Qwen/Qwen3-1.7B-FP8'),
513
+ Model('Qwen/Qwen3-4B-FP8', 'Qwen/Qwen3-4B-FP8'),
514
+ Model('Qwen/Qwen3-8B-FP8', 'Qwen/Qwen3-8B-FP8'),
515
+ Model('Qwen/Qwen3-14B-FP8', 'Qwen/Qwen3-14B-FP8'),
516
+ Model('Qwen/Qwen3-32B-FP8', 'Qwen/Qwen3-32B-FP8'),
517
+ # awq
518
+ Model('Qwen/Qwen3-4B-AWQ', 'Qwen/Qwen3-4B-AWQ'),
519
+ Model('Qwen/Qwen3-8B-AWQ', 'Qwen/Qwen3-8B-AWQ'),
520
+ Model('Qwen/Qwen3-14B-AWQ', 'Qwen/Qwen3-14B-AWQ'),
521
+ Model('Qwen/Qwen3-32B-AWQ', 'Qwen/Qwen3-32B-AWQ'),
522
+ # swift
523
+ Model('swift/Qwen3-32B-AWQ'),
524
+ ]),
525
+ ],
526
+ TemplateType.qwen3,
527
+ get_model_tokenizer_with_flash_attn,
528
+ architectures=['Qwen3ForCausalLM'],
529
+ requires=['transformers>=4.51'],
530
+ model_arch=ModelArch.llama))
531
+
532
+ register_model(
533
+ ModelMeta(
534
+ LLMModelType.qwen3_moe,
535
+ [
536
+ ModelGroup([
537
+ Model('Qwen/Qwen3-30B-A3B-Base', 'Qwen/Qwen3-30B-A3B-Base'),
538
+ Model('Qwen/Qwen3-235B-A22B-Base', 'Qwen/Qwen3-235B-A22B-Base'),
539
+ # instruct
540
+ Model('Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-30B-A3B'),
541
+ Model('Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-235B-A22B'),
542
+ # fp8
543
+ Model('Qwen/Qwen3-30B-A3B-FP8', 'Qwen/Qwen3-30B-A3B-FP8'),
544
+ Model('Qwen/Qwen3-235B-A22B-FP8', 'Qwen/Qwen3-235B-A22B-FP8'),
545
+ # awq
546
+ Model('swift/Qwen3-30B-A3B-AWQ', 'cognitivecomputations/Qwen3-30B-A3B-AWQ'),
547
+ Model('swift/Qwen3-235B-A22B-AWQ', 'cognitivecomputations/Qwen3-235B-A22B-AWQ'),
548
+ ]),
549
+ ],
550
+ TemplateType.qwen3,
551
+ get_model_tokenizer_with_flash_attn,
552
+ architectures=['Qwen3MoeForCausalLM'],
553
+ requires=['transformers>=4.51'],
554
+ ))
555
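# Illustrative sketch: every registration above follows the same shape, so supporting a
# further checkpoint is one more register_model call. The model_type string and repo id
# below are hypothetical placeholders, not real entries.
register_model(
    ModelMeta(
        'my_qwen3_finetune',
        [ModelGroup([Model('my-org/My-Qwen3-Finetune')])],
        TemplateType.qwen3,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen3ForCausalLM'],
        requires=['transformers>=4.51'],
        model_arch=ModelArch.llama))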
+
556
+
557
+ def patch_qwen_vl_utils(vision_process):
558
+ if hasattr(vision_process, '_patch'):
559
+ return
560
+ for key in [
561
+ 'image_factor', 'min_pixels', 'max_pixels', 'max_ratio', 'video_min_pixels', 'video_max_pixels',
562
+ 'video_total_pixels', 'frame_factor', 'fps', 'fps_min_frames', 'fps_max_frames'
563
+ ]:
564
+ type_func = float if key == 'fps' else int
565
+ setattr(vision_process, key.upper(), get_env_args(key, type_func, getattr(vision_process, key.upper())))
566
+ _read_video_decord = vision_process._read_video_decord
567
+
568
+ def _new_read_video_decord(ele: dict):
569
+ from swift.llm import load_file
570
+ ele['video'] = load_file(ele['video'])
571
+ return _read_video_decord(ele)
572
+
573
+ vision_process.VIDEO_READER_BACKENDS['decord'] = _new_read_video_decord
574
+ vision_process._patch = True
575
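# Illustrative sketch of the guard-attribute idiom used by patch_qwen_vl_utils above:
# module-level constants are overridden once and a marker attribute prevents double
# patching. types.SimpleNamespace stands in for the real qwen_vl_utils.vision_process.
import types

dummy_module = types.SimpleNamespace(MAX_PIXELS=12845056)

def patch_once(module):
    if getattr(module, '_patch', False):
        return  # already patched; never patch twice
    module.MAX_PIXELS = 1003520  # override a module-level constant in place
    module._patch = True

patch_once(dummy_module)
patch_once(dummy_module)  # second call is a no-op thanks to the guard attribute
assert dummy_module.MAX_PIXELS == 1003520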
+
576
+
577
+ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
578
+ from transformers import Qwen2VLForConditionalGeneration
579
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
580
+ model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
581
+ if model is not None and hasattr(model.model, 'embed_tokens'):
582
+ patch_output_clone(model.model.embed_tokens)
583
+ patch_output_to_input_device(model.model.embed_tokens)
584
+
585
+ from qwen_vl_utils import vision_process
586
+ patch_qwen_vl_utils(vision_process)
587
+ return model, tokenizer
588
+
589
+
590
+ register_model(
591
+ ModelMeta(
592
+ MLLMModelType.qwen2_vl,
593
+ [
594
+ ModelGroup(
595
+ [
596
+ # chat
597
+ Model('Qwen/Qwen2-VL-2B-Instruct', 'Qwen/Qwen2-VL-2B-Instruct'),
598
+ Model('Qwen/Qwen2-VL-7B-Instruct', 'Qwen/Qwen2-VL-7B-Instruct'),
599
+ Model('Qwen/Qwen2-VL-72B-Instruct', 'Qwen/Qwen2-VL-72B-Instruct'),
600
+ # base
601
+ Model('Qwen/Qwen2-VL-2B', 'Qwen/Qwen2-VL-2B'),
602
+ Model('Qwen/Qwen2-VL-7B', 'Qwen/Qwen2-VL-7B'),
603
+ Model('Qwen/Qwen2-VL-72B', 'Qwen/Qwen2-VL-72B'),
604
+ # gptq-int4
605
+ Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4'),
606
+ Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4'),
607
+ Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4'),
608
+ # gptq-int8
609
+ Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8'),
610
+ Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8'),
611
+ Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8'),
612
+ # awq-int4
613
+ Model('Qwen/Qwen2-VL-2B-Instruct-AWQ', 'Qwen/Qwen2-VL-2B-Instruct-AWQ'),
614
+ Model('Qwen/Qwen2-VL-7B-Instruct-AWQ', 'Qwen/Qwen2-VL-7B-Instruct-AWQ'),
615
+ Model('Qwen/Qwen2-VL-72B-Instruct-AWQ', 'Qwen/Qwen2-VL-72B-Instruct-AWQ'),
616
+ ], ),
617
+ ModelGroup([
618
+ Model('bytedance-research/UI-TARS-2B-SFT', 'bytedance-research/UI-TARS-2B-SFT'),
619
+ Model('bytedance-research/UI-TARS-7B-SFT', 'bytedance-research/UI-TARS-7B-SFT'),
620
+ Model('bytedance-research/UI-TARS-7B-DPO', 'bytedance-research/UI-TARS-7B-DPO'),
621
+ Model('bytedance-research/UI-TARS-72B-SFT', 'bytedance-research/UI-TARS-72B-SFT'),
622
+ Model('bytedance-research/UI-TARS-72B-DPO', 'bytedance-research/UI-TARS-72B-DPO'),
623
+ ]),
624
+ ModelGroup([
625
+ Model('allenai/olmOCR-7B-0225-preview', 'allenai/olmOCR-7B-0225-preview'),
626
+ ]),
627
+ ],
628
+ TemplateType.qwen2_vl,
629
+ get_model_tokenizer_qwen2_vl,
630
+ model_arch=ModelArch.qwen2_vl,
631
+ architectures=['Qwen2VLForConditionalGeneration'],
632
+ requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
633
+ tags=['vision', 'video']))
634
+
635
+ register_model(
636
+ ModelMeta(
637
+ MLLMModelType.qvq, [
638
+ ModelGroup([
639
+ Model('Qwen/QVQ-72B-Preview', 'Qwen/QVQ-72B-Preview'),
640
+ ]),
641
+ ],
642
+ TemplateType.qvq,
643
+ get_model_tokenizer_qwen2_vl,
644
+ model_arch=ModelArch.qwen2_vl,
645
+ architectures=['Qwen2VLForConditionalGeneration'],
646
+ requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
647
+ tags=['vision', 'video']))
648
+
649
+
650
+ def get_model_tokenizer_qwen2_5_vl(*args, **kwargs):
651
+ from transformers import Qwen2_5_VLForConditionalGeneration
652
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2_5_VLForConditionalGeneration
653
+ return get_model_tokenizer_qwen2_vl(*args, **kwargs)
654
+
655
+
656
+ register_model(
657
+ ModelMeta(
658
+ MLLMModelType.qwen2_5_vl, [
659
+ ModelGroup([
660
+ Model('Qwen/Qwen2.5-VL-3B-Instruct', 'Qwen/Qwen2.5-VL-3B-Instruct'),
661
+ Model('Qwen/Qwen2.5-VL-7B-Instruct', 'Qwen/Qwen2.5-VL-7B-Instruct'),
662
+ Model('Qwen/Qwen2.5-VL-32B-Instruct', 'Qwen/Qwen2.5-VL-32B-Instruct'),
663
+ Model('Qwen/Qwen2.5-VL-72B-Instruct', 'Qwen/Qwen2.5-VL-72B-Instruct'),
664
+ ]),
665
+ ModelGroup([
666
+ Model('Qwen/Qwen2.5-VL-3B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-3B-Instruct-AWQ'),
667
+ Model('Qwen/Qwen2.5-VL-7B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-7B-Instruct-AWQ'),
668
+ Model('Qwen/Qwen2.5-VL-32B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-32B-Instruct-AWQ'),
669
+ Model('Qwen/Qwen2.5-VL-72B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-72B-Instruct-AWQ'),
670
+ ]),
671
+ ],
672
+ TemplateType.qwen2_5_vl,
673
+ get_model_tokenizer_qwen2_5_vl,
674
+ model_arch=ModelArch.qwen2_vl,
675
+ architectures=['Qwen2_5_VLForConditionalGeneration'],
676
+ requires=['transformers>=4.49', 'qwen_vl_utils>=0.0.6', 'decord'],
677
+ tags=['vision', 'video']))
678
+
679
+
680
+ def get_model_tokenizer_qwen2_5_omni(model_dir, *args, **kwargs):
681
+ from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniConfig
682
+ from qwen_omni_utils import vision_process
683
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2_5OmniForConditionalGeneration
684
+ processor = Qwen2_5OmniProcessor.from_pretrained(model_dir, trust_remote_code=True)
685
+ kwargs['tokenizer'] = processor.tokenizer
686
+ kwargs['model_config'] = Qwen2_5OmniConfig.from_pretrained(model_dir, trust_remote_code=True)
687
+ patch_qwen_vl_utils(vision_process)
688
+ kwargs['model_config'].enable_audio_output = get_env_args('ENABLE_AUDIO_OUTPUT', bool, True)
689
+ model, _ = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
690
+ if model:
691
+ use_submodel_func(model, 'thinker')
692
+ model.config.keys_to_ignore_at_inference += ['hidden_states', 'attention_mask']
693
+ model.config.talker_config.pad_token_id = None
694
+ return model, processor
695
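# Illustrative sketch of an environment-variable toggle in the spirit of the
# ENABLE_AUDIO_OUTPUT switch above; swift's get_env_args helper is not reproduced here,
# this is plain os.environ handling for illustration only.
import os

def env_flag(name: str, default: bool) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {'1', 'true', 'yes'}

enable_audio_output = env_flag('ENABLE_AUDIO_OUTPUT', True)
print(enable_audio_output)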
+
696
+
697
+ register_model(
698
+ ModelMeta(
699
+ MLLMModelType.qwen2_5_omni,
700
+ [
701
+ ModelGroup([
702
+ Model('Qwen/Qwen2.5-Omni-3B', 'Qwen/Qwen2.5-Omni-3B'),
703
+ Model('Qwen/Qwen2.5-Omni-7B', 'Qwen/Qwen2.5-Omni-7B'),
704
+ ]),
705
+ ],
706
+ TemplateType.qwen2_5_omni,
707
+ get_model_tokenizer_qwen2_5_omni,
708
+ model_arch=ModelArch.qwen2_5_omni,
709
+ architectures=['Qwen2_5OmniModel'],
710
+ requires=['transformers>=4.50', 'soundfile', 'qwen_omni_utils', 'decord'],
711
+ tags=['vision', 'video', 'audio'],
712
+ additional_saved_files=['spk_dict.pt'],
713
+ ignore_patterns=[],
714
+ ))
715
+
716
+
717
+ def get_model_tokenizer_qwen2_audio(*args, **kwargs):
718
+ from transformers import Qwen2AudioForConditionalGeneration
719
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2AudioForConditionalGeneration
720
+ return get_model_tokenizer_multimodal(*args, **kwargs)
721
+
722
+
723
+ register_model(
724
+ ModelMeta(
725
+ MLLMModelType.qwen2_audio,
726
+ [
727
+ ModelGroup([
728
+ Model('Qwen/Qwen2-Audio-7B-Instruct', 'Qwen/Qwen2-Audio-7B-Instruct'),
729
+ Model('Qwen/Qwen2-Audio-7B', 'Qwen/Qwen2-Audio-7B'),
730
+ ]),
731
+ ],
732
+ TemplateType.qwen2_audio,
733
+ get_model_tokenizer_qwen2_audio,
734
+ model_arch=ModelArch.qwen2_audio,
735
+ architectures=['Qwen2AudioForConditionalGeneration'],
736
+ requires=['transformers>=4.45,<4.49', 'librosa'],
737
+ tags=['audio'],
738
+ ))
739
+
740
+ register_model(
741
+ ModelMeta(
742
+ LLMModelType.marco_o1, [ModelGroup([Model('AIDC-AI/Marco-o1', 'AIDC-AI/Marco-o1')])],
743
+ TemplateType.marco_o1,
744
+ get_model_tokenizer_with_flash_attn,
745
+ model_arch=ModelArch.llama,
746
+ architectures=['Qwen2ForCausalLM'],
747
+ requires=['transformers>=4.37']))
748
+
749
+ register_model(
750
+ ModelMeta(
751
+ LLMModelType.qwq_preview, [ModelGroup([Model('Qwen/QwQ-32B-Preview', 'Qwen/QwQ-32B-Preview')])],
752
+ TemplateType.qwq_preview,
753
+ get_model_tokenizer_with_flash_attn,
754
+ model_arch=ModelArch.llama,
755
+ architectures=['Qwen2ForCausalLM'],
756
+ requires=['transformers>=4.37']))
757
+
758
+ register_model(
759
+ ModelMeta(
760
+ LLMModelType.qwq,
761
+ [ModelGroup([
762
+ Model('Qwen/QwQ-32B', 'Qwen/QwQ-32B'),
763
+ Model('Qwen/QwQ-32B-AWQ', 'Qwen/QwQ-32B-AWQ'),
764
+ ])],
765
+ TemplateType.qwq,
766
+ get_model_tokenizer_with_flash_attn,
767
+ model_arch=ModelArch.llama,
768
+ architectures=['Qwen2ForCausalLM'],
769
+ requires=['transformers>=4.37']))
770
+
771
+
772
+ def get_model_tokenizer_ovis(*args, **kwargs):
773
+ kwargs['attn_impl_keys'] = ['llm_attn_implementation']
774
+ model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
775
+ if model is not None:
776
+ model.visual_tokenizer.to(model.dtype)
777
+ model.vte.to(model.dtype)
778
+
779
+ model.generation_config.cache_implementation = None
780
+ func_list = ['generate', 'forward', 'get_input_embeddings']
781
+ use_submodel_func(model, 'llm', func_list)
782
+ embedding = model.get_input_embeddings()
783
+ patch_output_clone(embedding)
784
+ try:
785
+ # fix device_map
786
+ from transformers.cache_utils import HybridCache
787
+
788
+ def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
789
+ **kwargs) -> Tuple[torch.Tensor]:
790
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
791
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
792
+ return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)
793
+
794
+ if not hasattr(HybridCache, '_update_origin'):
795
+ HybridCache._update_origin = HybridCache.update
796
+ HybridCache.update = update
797
+ except ImportError:
798
+ pass
799
+ return model, tokenizer
800
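# Illustrative sketch of the "keep the original under a private name, then wrap it" idiom
# applied to HybridCache.update above, demonstrated on a toy class so it runs stand-alone.
class ToyCache:
    def update(self, value):
        return value

def patched_update(self, value):
    # move/normalize arguments here, then defer to the original implementation
    return self._update_origin(value * 2)

if not hasattr(ToyCache, '_update_origin'):
    ToyCache._update_origin = ToyCache.update
    ToyCache.update = patched_update

assert ToyCache().update(3) == 6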
+
801
+
802
+ register_model(
803
+ ModelMeta(
804
+ MLLMModelType.ovis1_6,
805
+ [
806
+ ModelGroup([
807
+ Model('AIDC-AI/Ovis1.6-Gemma2-9B', 'AIDC-AI/Ovis1.6-Gemma2-9B'),
808
+ Model('AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4', 'AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4'),
809
+ Model('AIDC-AI/Ovis1.6-Gemma2-27B', 'AIDC-AI/Ovis1.6-Gemma2-27B'),
810
+ ]),
811
+ ],
812
+ TemplateType.ovis1_6,
813
+ get_model_tokenizer_ovis,
814
+ model_arch=ModelArch.ovis1_6,
815
+ architectures=['Ovis'],
816
+ tags=['vision'],
817
+ requires=['transformers>=4.42'],
818
+ ))
819
+
820
+ register_model(
821
+ ModelMeta(
822
+ MLLMModelType.ovis1_6_llama3,
823
+ [
824
+ ModelGroup([
825
+ Model('AIDC-AI/Ovis1.6-Llama3.2-3B', 'AIDC-AI/Ovis1.6-Llama3.2-3B'),
826
+ ]),
827
+ ],
828
+ TemplateType.ovis1_6_llama3,
829
+ get_model_tokenizer_ovis,
830
+ model_arch=ModelArch.ovis1_6,
831
+ architectures=['Ovis'],
832
+ tags=['vision'],
833
+ ))
834
+
835
+ register_model(
836
+ ModelMeta(
837
+ MLLMModelType.ovis2,
838
+ [
839
+ ModelGroup([
840
+ Model('AIDC-AI/Ovis2-1B', 'AIDC-AI/Ovis2-1B'),
841
+ Model('AIDC-AI/Ovis2-2B', 'AIDC-AI/Ovis2-2B'),
842
+ Model('AIDC-AI/Ovis2-4B', 'AIDC-AI/Ovis2-4B'),
843
+ Model('AIDC-AI/Ovis2-8B', 'AIDC-AI/Ovis2-8B'),
844
+ Model('AIDC-AI/Ovis2-16B', 'AIDC-AI/Ovis2-16B'),
845
+ Model('AIDC-AI/Ovis2-34B', 'AIDC-AI/Ovis2-34B'),
846
+ ]),
847
+ ],
848
+ TemplateType.ovis2,
849
+ get_model_tokenizer_ovis,
850
+ model_arch=ModelArch.ovis1_6,
851
+ architectures=['Ovis'],
852
+ tags=['vision'],
853
+ requires=['transformers>=4.46.2', 'moviepy<2'],
854
+ ))
855
+
856
+ register_model(
857
+ ModelMeta(
858
+ RMModelType.qwen2_reward,
859
+ [
860
+ ModelGroup([
861
+ Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'),
862
+ ]),
863
+ ],
864
+ TemplateType.qwen,
865
+ get_model_tokenizer_reward_model,
866
+ architectures=['Qwen2ForRewardModel'],
867
+ requires=['transformers>=4.37'],
868
+ ))
869
+
870
+ register_model(
871
+ ModelMeta(
872
+ RMModelType.qwen2_5_prm,
873
+ [
874
+ ModelGroup([
875
+ Model('Qwen/Qwen2.5-Math-PRM-7B', 'Qwen/Qwen2.5-Math-PRM-7B'),
876
+ Model('Qwen/Qwen2.5-Math-7B-PRM800K', 'Qwen/Qwen2.5-Math-7B-PRM800K'),
877
+ Model('Qwen/Qwen2.5-Math-PRM-72B', 'Qwen/Qwen2.5-Math-PRM-72B'),
878
+ ]),
879
+ ],
880
+ TemplateType.qwen2_5_math_prm,
881
+ get_model_tokenizer_reward_model,
882
+ task_type='prm',
883
+ architectures=['Qwen2ForProcessRewardModel'],
884
+ requires=['transformers>=4.37'],
885
+ ))
886
+
887
+ register_model(
888
+ ModelMeta(
889
+ RMModelType.qwen2_5_math_reward,
890
+ [
891
+ ModelGroup([
892
+ Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'),
893
+ ]),
894
+ ],
895
+ TemplateType.qwen2_5_math,
896
+ get_model_tokenizer_reward_model,
897
+ architectures=['Qwen2ForRewardModel'],
898
+ requires=['transformers>=4.37'],
899
+ ))
ms-swift/swift/llm/model/model/skywork.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from typing import Any, Dict
4
+
5
+ from swift.llm import TemplateType
6
+ from ..constant import LLMModelType, RMModelType
7
+ from ..model_arch import ModelArch
8
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
9
+ from ..utils import ModelInfo
10
+
11
+
12
+ def get_skywork_model_tokenizer(model_dir: str,
13
+ model_info: ModelInfo,
14
+ model_kwargs: Dict[str, Any],
15
+ load_model: bool = True,
16
+ **kwargs):
17
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
18
+ if 'chat' in model_dir:
19
+ tokenizer.add_tokens('[USER]')
20
+ tokenizer.add_tokens('[BOT]')
21
+ tokenizer.add_tokens('[SEP]')
22
+ return model, tokenizer
23
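# Illustrative sketch: add_tokens above only extends the tokenizer vocabulary; depending on
# the checkpoint, the embedding matrix may also need resizing. 'gpt2' is used here purely
# as a small stand-in model, not a Skywork checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
num_added = tokenizer.add_tokens(['[USER]', '[BOT]', '[SEP]'])
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))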
+
24
+
25
+ register_model(
26
+ ModelMeta(
27
+ LLMModelType.skywork,
28
+ [
29
+ ModelGroup([
30
+ Model('skywork/Skywork-13B-base', 'skywork/Skywork-13B-base'),
31
+ Model('skywork/Skywork-13B-chat'),
32
+ ]),
33
+ ],
34
+ TemplateType.skywork,
35
+ get_skywork_model_tokenizer,
36
+ architectures=['SkyworkForCausalLM'],
37
+ model_arch=ModelArch.llama,
38
+ ))
39
+
40
+ register_model(
41
+ ModelMeta(
42
+ LLMModelType.skywork_o1,
43
+ [
44
+ ModelGroup([
45
+ Model('AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B', 'Skywork/Skywork-o1-Open-Llama-3.1-8B'),
46
+ ]),
47
+ ],
48
+ TemplateType.skywork_o1,
49
+ get_model_tokenizer_with_flash_attn,
50
+ architectures=['LlamaForCausalLM'],
51
+ requires=['transformers>=4.43'],
52
+ model_arch=ModelArch.llama,
53
+ ))
54
+
55
+ register_model(
56
+ ModelMeta(
57
+ RMModelType.llama3_2_reward,
58
+ [
59
+ ModelGroup([
60
+ Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B', 'Skywork/Skywork-Reward-Llama-3.1-8B'),
61
+ Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2', 'Skywork/Skywork-Reward-Llama-3.1-8B-v0.2'),
62
+ ]),
63
+ ModelGroup([
64
+ Model('AI-ModelScope/GRM_Llama3.1_8B_rewardmodel-ft', 'Ray2333/GRM_Llama3.1_8B_rewardmodel-ft'),
65
+ Model('AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft', 'Ray2333/GRM-llama3.2-3B-rewardmodel-ft'),
66
+ ])
67
+ ],
68
+ TemplateType.llama3_2,
69
+ get_model_tokenizer_with_flash_attn,
70
+ requires=['transformers>=4.43'],
71
+ architectures=['LlamaForSequenceClassification'],
72
+ model_arch=ModelArch.llama,
73
+ ))
74
+
75
+ register_model(
76
+ ModelMeta(
77
+ RMModelType.gemma_reward,
78
+ [
79
+ ModelGroup([
80
+ Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B', 'Skywork/Skywork-Reward-Gemma-2-27B'),
81
+ Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B-v0.2', 'Skywork/Skywork-Reward-Gemma-2-27B-v0.2'),
82
+ ]),
83
+ ],
84
+ TemplateType.gemma,
85
+ get_model_tokenizer_with_flash_attn,
86
+ requires=['transformers>=4.42'],
87
+ architectures=['Gemma2ForSequenceClassification'],
88
+ model_arch=ModelArch.llama,
89
+ ))
ms-swift/swift/llm/model/model/yi.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import sys
4
+ from typing import Any, Dict
5
+
6
+ from transformers import AutoTokenizer
7
+
8
+ from swift.llm import TemplateType
9
+ from swift.utils import get_logger
10
+ from ..constant import LLMModelType, MLLMModelType
11
+ from ..model_arch import ModelArch
12
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
13
+ from ..utils import ModelInfo, git_clone_github
14
+
15
+ logger = get_logger()
16
+
17
+
18
+ def get_model_tokenizer_yi(model_dir, *args, **kwargs):
19
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
20
+ return get_model_tokenizer_with_flash_attn(model_dir, *args, tokenizer=tokenizer, **kwargs)
21
+
22
+
23
+ def get_model_tokenizer_yi_vl(model_dir: str,
24
+ model_info: ModelInfo,
25
+ model_kwargs: Dict[str, Any],
26
+ load_model: bool = True,
27
+ **kwargs):
28
+ local_repo_path = kwargs.get('local_repo_path')
29
+ if not local_repo_path:
30
+ local_repo_path = git_clone_github('https://github.com/01-ai/Yi')
31
+ sys.path.append(os.path.join(local_repo_path, 'VL'))
32
+ from llava.model import LlavaLlamaForCausalLM, LlavaConfig
33
+ from llava.model.constants import key_info
34
+
35
+ model_config = LlavaConfig.from_pretrained(model_dir)
36
+ mm_vision_tower = model_config.mm_vision_tower
37
+ model_config.mm_vision_tower = os.path.join(model_dir, *mm_vision_tower.rsplit('/', maxsplit=2)[-2:])
38
+ model_config.attention_dropout = 0.
39
+ key_info['model_path'] = model_dir
40
+ kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaLlamaForCausalLM
41
+ model, tokenizer = get_model_tokenizer_yi(
42
+ model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs)
43
+ if model is not None:
44
+ logger.info('Please ignore the above warning.')
45
+ logger.info('Loading the parameters of vision_tower...')
46
+ model.resize_token_embeddings(len(tokenizer))
47
+ vision_tower = model.get_vision_tower()
48
+ vision_tower.load_model()
49
+ vision_tower.to(device=model.device, dtype=model_config.torch_dtype)
50
+ if not hasattr(model.config, 'max_sequence_length'):
51
+ model.config.max_sequence_length = 2048
52
+ tokenizer.image_processor = vision_tower.image_processor
53
+ return model, tokenizer
54
+
55
+
56
+ register_model(
57
+ ModelMeta(
58
+ MLLMModelType.yi_vl,
59
+ [
60
+ ModelGroup([
61
+ Model('01ai/Yi-VL-6B', '01-ai/Yi-VL-6B'),
62
+ Model('01ai/Yi-VL-34B', '01-ai/Yi-VL-34B'),
63
+ ], ),
64
+ ],
65
+ TemplateType.yi_vl,
66
+ get_model_tokenizer_yi_vl,
67
+ model_arch=ModelArch.llava_llama,
68
+ architectures=['LlavaLlamaForCausalLM'],
69
+ requires=['transformers>=4.34'],
70
+ tags=['vision'],
71
+ ))
72
+
73
+ register_model(
74
+ ModelMeta(
75
+ LLMModelType.yi,
76
+ [ # yi
77
+ ModelGroup([
78
+ Model('01ai/Yi-6B', '01-ai/Yi-6B'),
79
+ Model('01ai/Yi-6B-200K', '01-ai/Yi-6B-200K'),
80
+ Model('01ai/Yi-6B-Chat', '01-ai/Yi-6B-Chat'),
81
+ Model('01ai/Yi-6B-Chat-4bits', '01-ai/Yi-6B-Chat-4bits'),
82
+ Model('01ai/Yi-6B-Chat-8bits', '01-ai/Yi-6B-Chat-8bits'),
83
+ Model('01ai/Yi-9B', '01-ai/Yi-9B'),
84
+ Model('01ai/Yi-9B-200K', '01-ai/Yi-9B-200K'),
85
+ Model('01ai/Yi-34B', '01-ai/Yi-34B'),
86
+ Model('01ai/Yi-34B-200K', '01-ai/Yi-34B-200K'),
87
+ Model('01ai/Yi-34B-Chat', '01-ai/Yi-34B-Chat'),
88
+ Model('01ai/Yi-34B-Chat-4bits', '01-ai/Yi-34B-Chat-4bits'),
89
+ Model('01ai/Yi-34B-Chat-8bits', '01-ai/Yi-34B-Chat-8bits'),
90
+ ]),
91
+ # yi1.5
92
+ ModelGroup([
93
+ Model('01ai/Yi-1.5-6B', '01-ai/Yi-1.5-6B'),
94
+ Model('01ai/Yi-1.5-6B-Chat', '01-ai/Yi-1.5-6B-Chat'),
95
+ Model('01ai/Yi-1.5-9B', '01-ai/Yi-1.5-9B'),
96
+ Model('01ai/Yi-1.5-9B-Chat', '01-ai/Yi-1.5-9B-Chat'),
97
+ Model('01ai/Yi-1.5-9B-Chat-16K', '01-ai/Yi-1.5-9B-Chat-16K'),
98
+ Model('01ai/Yi-1.5-34B', '01-ai/Yi-1.5-34B'),
99
+ Model('01ai/Yi-1.5-34B-Chat', '01-ai/Yi-1.5-34B-Chat'),
100
+ Model('01ai/Yi-1.5-34B-Chat-16K', '01-ai/Yi-1.5-34B-Chat-16K'),
101
+ ]),
102
+ # yi1.5-quant
103
+ ModelGroup([
104
+ Model('AI-ModelScope/Yi-1.5-6B-Chat-GPTQ', 'modelscope/Yi-1.5-6B-Chat-GPTQ'),
105
+ Model('AI-ModelScope/Yi-1.5-6B-Chat-AWQ', 'modelscope/Yi-1.5-6B-Chat-AWQ'),
106
+ Model('AI-ModelScope/Yi-1.5-9B-Chat-GPTQ', 'modelscope/Yi-1.5-9B-Chat-GPTQ'),
107
+ Model('AI-ModelScope/Yi-1.5-9B-Chat-AWQ', 'modelscope/Yi-1.5-9B-Chat-AWQ'),
108
+ Model('AI-ModelScope/Yi-1.5-34B-Chat-GPTQ', 'modelscope/Yi-1.5-34B-Chat-GPTQ'),
109
+ Model('AI-ModelScope/Yi-1.5-34B-Chat-AWQ', 'modelscope/Yi-1.5-34B-Chat-AWQ'),
110
+ ]),
111
+ ],
112
+ TemplateType.chatml,
113
+ get_model_tokenizer_yi,
114
+ architectures=['LlamaForCausalLM'],
115
+ model_arch=ModelArch.llama,
116
+ ))
117
+
118
+ register_model(
119
+ ModelMeta(
120
+ LLMModelType.yi_coder, [
121
+ ModelGroup([
122
+ Model('01ai/Yi-Coder-1.5B', '01-ai/Yi-Coder-1.5B'),
123
+ Model('01ai/Yi-Coder-9B', '01-ai/Yi-Coder-9B'),
124
+ Model('01ai/Yi-Coder-1.5B-Chat', '01-ai/Yi-Coder-1.5B-Chat'),
125
+ Model('01ai/Yi-Coder-9B-Chat', '01-ai/Yi-Coder-9B-Chat'),
126
+ ], )
127
+ ],
128
+ TemplateType.yi_coder,
129
+ get_model_tokenizer_yi,
130
+ architectures=['LlamaForCausalLM'],
131
+ model_arch=ModelArch.llama,
132
+ tags=['coding']))
133
+
134
+ register_model(
135
+ ModelMeta(
136
+ LLMModelType.sus,
137
+ [
138
+ ModelGroup([
139
+ Model('SUSTC/SUS-Chat-34B', 'SUSTech/SUS-Chat-34B'),
140
+ ]),
141
+ ],
142
+ TemplateType.sus,
143
+ get_model_tokenizer_yi,
144
+ model_arch=ModelArch.llama,
145
+ architectures=['LlamaForCausalLM'],
146
+ ))
ms-swift/swift/llm/sampling/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampling import sampling_main
ms-swift/swift/llm/sampling/base.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Any, Dict, List
2
+
3
+ from swift.llm import SamplingArguments
4
+ from swift.plugin import orms, prms
5
+ from swift.utils import get_logger
6
+
7
+ logger = get_logger()
8
+
9
+
10
+ class Sampler:
11
+
12
+ def __init__(self, input_args: SamplingArguments):
13
+ self.args = input_args
14
+ self.template = None
15
+ self.processor = None
16
+ self.prm_model = None
17
+ self.orm_model = None
18
+ self._prepare_model_tokenizer()
19
+ self._prepare_template()
20
+ self._prepare_rm()
21
+
22
+ def _prepare_model_tokenizer(self):
23
+ args = self.args
24
+ _, self.processor = args.get_model_processor(load_model=False)
25
+
26
+ def _prepare_rm(self):
27
+ if self.args.prm_model is None:
28
+ self.prm_model = None
29
+ logger.warning('prm_model is None.')
30
+ elif self.args.prm_model in prms:
31
+ self.prm_model = prms[self.args.prm_model]()
32
+ else:
33
+ from swift.llm import PtEngine
34
+ self.prm_model = PtEngine(self.args.prm_model, max_batch_size=64)
35
+
36
+ if self.args.orm_model is None:
37
+ self.orm_model = None
38
+ logger.warning('orm_model is None.')
39
+ elif self.args.orm_model in orms:
40
+ self.orm_model = orms[self.args.orm_model]()
41
+ else:
42
+ from swift.llm import PtEngine
43
+ self.orm_model = PtEngine(self.args.orm_model, max_batch_size=64)
44
+
45
+ def _prepare_template(self) -> None:
46
+ template = self.args.get_template(self.processor)
47
+ self.template = template
48
+ self.template.set_mode('train')
49
+
50
+ def truncate_input(self, slices: List[Dict[str, Any]]):
51
+ """Truncate the input rows to avoid hitting the max length of the policy model"""
52
+ return slices
53
+
54
+ def do_sample(self, data):
55
+ raise NotImplementedError
ms-swift/swift/llm/sampling/sampling.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import shutil
4
+ import time
5
+ from typing import List, Union
6
+
7
+ import json
8
+
9
+ from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
10
+ from swift.utils import get_logger
11
+
12
+ logger = get_logger()
13
+
14
+
15
+ class SwiftSampling(SwiftPipeline):
16
+ args_class = SamplingArguments
17
+ args: args_class
18
+
19
+ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
20
+ super().__init__(args)
21
+ self.args.save_args()
22
+ os.makedirs(self.args.output_dir, exist_ok=True)
23
+ self.cur_piece = 0
24
+ self.total_piece = 1
25
+
26
+ if self.args.data_range:
27
+ self.cur_piece, self.total_piece = self.args.data_range
28
+
29
+ if self.args.sampler_type == 'sample':
30
+ from swift.llm.sampling.vanilla_sampler import VanillaSampler
31
+ self.sampler = VanillaSampler(self.args)
32
+ elif self.args.sampler_type == 'mcts':
33
+ from swift.llm.sampling.mcts import MctsSampler
34
+ self.sampler = MctsSampler(self.args)
35
+ elif self.args.sampler_type == 'distill':
36
+ from swift.llm.sampling.distill_sampler import DistillSampler
37
+ self.sampler = DistillSampler(self.args)
38
+ else:
39
+ raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')
40
+
41
+ def _get_dataset(self):
42
+ args = self.args
43
+ dataset_kwargs = args.get_dataset_kwargs()
44
+ sampling_dataset, _ = load_dataset(
45
+ args.dataset, split_dataset_ratio=0., shuffle=args.dataset_shuffle, **dataset_kwargs)
46
+ logger.info(f'Sampling_dataset: {sampling_dataset}')
47
+ dataset_len = len(sampling_dataset)
48
+ piece_len = dataset_len // self.total_piece
49
+ sampling_dataset = sampling_dataset.select(range(piece_len * self.cur_piece, piece_len * (self.cur_piece + 1)))
50
+ return sampling_dataset
51
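# Illustrative sketch of the data_range sharding above on made-up numbers: a dataset of
# length N is split into total_piece contiguous slices of N // total_piece rows each.
dataset_len, cur_piece, total_piece = 103, 2, 4
piece_len = dataset_len // total_piece                              # 25
start, stop = piece_len * cur_piece, piece_len * (cur_piece + 1)    # rows 50..74
print(start, stop)  # -> 50 75
# Note: the trailing dataset_len % total_piece rows (here 3) are not assigned to any piece.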
+
52
+ def run(self):
53
+ os.makedirs(self.args.output_dir, exist_ok=True)
54
+ iter_file = os.path.join(self.args.output_dir, self.args.output_file)
55
+ resume_file = os.path.join(self.args.output_dir, self.args.output_file + '.resume')
56
+ tmp_file = os.path.join(self.args.output_dir, self.args.output_file + '.tmp')
57
+ ckpt_state_file = os.path.join(self.args.output_dir, 'ckpt_state.json')
58
+ if os.path.exists(iter_file) and not self.args.override_exist_file:
59
+ return
60
+
61
+ index_resume = -1
62
+ write_mode = 'w'
63
+ if self.args.resume:
64
+ write_mode = 'a'
65
+ if os.path.exists(resume_file):
66
+ shutil.copyfile(resume_file, tmp_file)
67
+
68
+ if os.path.exists(ckpt_state_file):
69
+ with open(ckpt_state_file, 'r') as ckpt_state:
70
+ data = json.load(ckpt_state)
71
+ index_resume = data.get('index', -1)
72
+ logger.info(f'Loaded index_resume: {index_resume}')
73
+ else:
74
+ if os.path.exists(tmp_file):
75
+ os.remove(tmp_file)
76
+
77
+ dataset = self._get_dataset()
78
+ dataset_len = len(dataset)
79
+ total_iters = int(dataset_len // self.args.num_sampling_per_gpu_batch_size)
80
+
81
+ if self.args.num_sampling_per_gpu_batches is None or self.args.num_sampling_per_gpu_batches > total_iters:
82
+ self.args.num_sampling_per_gpu_batches = total_iters
83
+
84
+ with open(tmp_file, write_mode) as f:
85
+ for _index in range(self.args.num_sampling_per_gpu_batches):
86
+ if _index <= index_resume:
87
+ continue
88
+ logger.info(f' Sampling index:{_index}')
89
+ slices = dataset[self.args.num_sampling_per_gpu_batch_size
90
+ * _index:self.args.num_sampling_per_gpu_batch_size * (_index + 1)]
91
+ slices = self.sampler.truncate_input(slices)
92
+ generated = self.sampler.do_sample(slices)
93
+ f.writelines(generated)
94
+ f.flush()
95
+ shutil.copy(tmp_file, resume_file)
96
+ with open(ckpt_state_file, 'w') as ckpt_state:
97
+ json.dump({'index': _index}, ckpt_state)
98
+
99
+ if os.path.exists(iter_file):
100
+ shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
101
+ shutil.move(resume_file, iter_file)
102
+ logger.info(f'Sample file {iter_file} generated.')
103
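# Illustrative sketch of the resume bookkeeping used by run() above, reduced to its
# essentials; paths and batch contents are placeholders.
import json
import shutil

def save_checkpoint(tmp_file, resume_file, ckpt_state_file, batch_lines, index):
    with open(tmp_file, 'a') as f:            # append the freshly generated batch
        f.writelines(batch_lines)
    shutil.copy(tmp_file, resume_file)        # resume_file always holds a consistent copy
    with open(ckpt_state_file, 'w') as ckpt:  # remember the last completed batch index
        json.dump({'index': index}, ckpt)
# On restart, batches with _index <= the stored index are skipped.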
+
104
+
105
+ def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
106
+ return SwiftSampling(args).main()
ms-swift/swift/llm/sampling/utils.py ADDED
@@ -0,0 +1,172 @@
1
+ import hashlib
2
+ import inspect
3
+ from copy import copy
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import json
7
+ import numpy as np
8
+
9
+ from swift.llm import InferRequest, RequestConfig
10
+ from swift.utils import get_logger
11
+
12
+ logger = get_logger()
13
+
14
+
15
+ def get_messages_md5(row: Dict[str, Any]):
16
+ row = copy(row)
17
+ row.pop('choices', None)
18
+ serialized = json.dumps(row, sort_keys=True)
19
+ return hashlib.md5(serialized.encode('utf-8')).hexdigest()
20
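# Illustrative sketch of how get_messages_md5 above derives a stable row id: the 'choices'
# key is dropped so that previously sampled completions do not change the hash.
import hashlib
import json

row = {'messages': [{'role': 'user', 'content': 'hi'}], 'choices': ['ignored']}
row = {k: v for k, v in row.items() if k != 'choices'}
uuid = hashlib.md5(json.dumps(row, sort_keys=True).encode('utf-8')).hexdigest()
print(uuid)  # identical for identical messages, regardless of cached choices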
+
21
+
22
+ def get_reward(model: Any,
23
+ infer_requests: List[InferRequest],
24
+ request_config: RequestConfig = None,
25
+ ground_truths: List[str] = None,
26
+ threshold: Optional[float] = None):
27
+ """Get reward from an RM model.
28
+
29
+ Args:
30
+ model: The model instance or an RM evaluator
31
+ infer_requests: Infer requests sent to the model
32
+ request_config: Infer config
33
+ ground_truths: The ground truth list
34
+ threshold: An optional threshold to generate the mask
35
+
36
+ Returns:
37
+ Tuple
38
+            Index 0: The min-max normalized scores aligned with the infer_requests
39
+ Index 1: The mask filtered by the threshold
40
+ """
41
+ from swift.llm import InferEngine
42
+ infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
43
+ parameters = inspect.signature(infer_func).parameters
44
+ gt_param = {}
45
+ if 'ground_truths' in parameters:
46
+ gt_param = {'ground_truths': ground_truths}
47
+ if isinstance(infer_requests[0], dict):
48
+ infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
49
+ rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
50
+ from swift.llm.infer.protocol import ChatCompletionResponse
51
+ if isinstance(rewards[0], ChatCompletionResponse):
52
+        logger.debug(f'reward: {rewards[0].choices[0].message.content}')
53
+ if isinstance(rewards[0].choices[0].message.content, str):
54
+ rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards]
55
+ elif isinstance(rewards[0].choices[0].message.content, list):
56
+ rewards = [float(min(r.choices[0].message.content)) for r in rewards]
57
+ else:
58
+ rewards = [float(r.choices[0].message.content) for r in rewards]
59
+ arr = []
60
+ for reward in rewards:
61
+ if isinstance(reward, (list, tuple)):
62
+ arr.append(min(reward))
63
+ else:
64
+ arr.append(float(reward))
65
+
66
+ _mask = np.array([True] * len(arr))
67
+ if threshold is not None:
68
+        # Use > rather than >=: the ORM caller passes a threshold of 0, and >= would let zero-reward samples through.
69
+ _mask = np.array([a > threshold for a in arr])
70
+
71
+ def normalize(arr):
72
+ min_val = np.min(arr)
73
+ max_val = np.max(arr)
74
+ if min_val == max_val:
75
+ if min_val == 0:
76
+ constant_value = 0.0
77
+ else:
78
+ constant_value = min(1.0, min_val)
79
+ return np.full_like(arr, fill_value=constant_value, dtype=np.float64)
80
+ normalized = (arr - min_val) / (max_val - min_val + 1e-5)
81
+ return normalized
82
+
83
+ return normalize(arr), _mask
84
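# Illustrative sketch of the min-max normalization above, including the constant-array
# special case handled by normalize(); all numbers are made up.
import numpy as np

scores = np.array([0.2, 0.8, 0.5])
print(((scores - scores.min()) / (scores.max() - scores.min() + 1e-5)).round(3))  # [0. 1. 0.5]

constant = np.array([0.7, 0.7])
# when all scores are equal, normalize() returns min(1.0, value) everywhere (0.0 if all zero)
print(np.full_like(constant, fill_value=min(1.0, constant.min()), dtype=np.float64))  # [0.7 0.7]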
+
85
+
86
+ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
87
+ if isinstance(infer_engines, list):
88
+ assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
89
+ from concurrent.futures import ThreadPoolExecutor, as_completed
90
+ n = len(infer_requests)
91
+ with ThreadPoolExecutor(max_workers=n) as executor:
92
+ futures = {
93
+ executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs):
94
+ i
95
+ for i in range(n)
96
+ }
97
+ responses = []
98
+ for future in as_completed(futures):
99
+ task_id = futures[future]
100
+ try:
101
+ responses += future.result()
102
+ except Exception as e:
103
+ logger.info(f'Perform infer task: {task_id} get an error: {e}')
104
+ return responses
105
+ elif isinstance(infer_requests, list):
106
+ responses = []
107
+ if isinstance(request_configs, list):
108
+ assert len(infer_requests) <= len(request_configs)
109
+ for i in range(len(infer_requests)):
110
+ responses += infer_engines.infer(
111
+ [infer_requests[i]],
112
+ request_configs[i],
113
+ **infer_kwargs,
114
+ )
115
+ elif isinstance(request_configs, RequestConfig):
116
+ for infer_request in infer_requests:
117
+ responses += infer_engines.infer(
118
+ [infer_request],
119
+ request_configs,
120
+ **infer_kwargs,
121
+ )
122
+ return responses
123
+ return infer_engines.infer(
124
+ [infer_requests],
125
+ request_configs,
126
+ **infer_kwargs,
127
+ )
128
+
129
+
130
+ def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
131
+ from transformers.utils import strtobool
132
+ if isinstance(monte_carlo_tree, str):
133
+ monte_carlo_tree = json.loads(monte_carlo_tree)
134
+
135
+ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]):
136
+ _prefer_pairs, _correct_answers, _incorrect_answers = [], [], []
137
+ _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']]
138
+ _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']]
139
+ if len(collect_curr_node['children']) > 0:
140
+ for child in collect_curr_node['children']:
141
+ p, c, i = _collect(child, _outcome_rewards, _process_rewards)
142
+ _prefer_pairs += p
143
+ _correct_answers += c
144
+ _incorrect_answers += i
145
+ sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward'])
146
+ if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold:
147
+ # TODO: filter with visit count
148
+ prefer_pair = {
149
+ 'path': 'ки\n'.join(collect_curr_node['path']),
150
+ 'good': sorted_children[-1]['path'][-1],
151
+ 'good_score': sorted_children[-1]['outcome_reward'],
152
+ 'bad': sorted_children[0]['path'][-1],
153
+ 'bad_score': sorted_children[0]['outcome_reward'],
154
+ }
155
+ _prefer_pairs.append(prefer_pair)
156
+ if strtobool(collect_curr_node['terminated']):
157
+ _answer = {
158
+ 'answer': 'ки\n'.join(collect_curr_node['path']),
159
+ 'mean_outcome_reward': np.mean(_outcome_rewards),
160
+ 'min_outcome_reward': np.min(_outcome_rewards),
161
+ 'mean_process_reward': np.mean(_process_rewards),
162
+ 'min_process_reward': np.min(_process_rewards),
163
+ }
164
+ if strtobool(collect_curr_node['correct']):
165
+ _correct_answers.append(_answer)
166
+ else:
167
+ _incorrect_answers.append(_answer)
168
+ return _prefer_pairs, _correct_answers, _incorrect_answers
169
+
170
+ _root = monte_carlo_tree
171
+ prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], [])
172
+ return prefer_pairs, correct_answers, incorrect_answers
ms-swift/swift/llm/sampling/vanilla_sampler.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import Any, Dict, List
5
+
6
+ import json
7
+ import numpy as np
8
+
9
+ from swift.llm import RequestConfig
10
+ from swift.llm.sampling.base import Sampler
11
+ from swift.llm.template.template_inputs import InferRequest
12
+ from swift.utils import get_logger
13
+ from .utils import get_messages_md5, get_reward
14
+
15
+ logger = get_logger()
16
+
17
+
18
+ class VanillaSampler(Sampler):
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+
23
+ if self.args.sampler_engine == 'pt':
24
+ from swift.llm import PtEngine
25
+ _Engine = PtEngine
26
+ elif self.args.sampler_engine == 'vllm':
27
+ from swift.llm import VllmEngine
28
+ _Engine = VllmEngine
29
+ elif self.args.sampler_engine == 'lmdeploy':
30
+ from swift.llm import LmdeployEngine
31
+ _Engine = LmdeployEngine
32
+ elif self.args.sampler_engine == 'no':
33
+ _Engine = None
34
+ else:
35
+ raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
36
+ self.infer_engine = None
37
+ if _Engine:
38
+ self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
39
+ self.infer_engine.default_template = self.template
40
+ self.infer_engine.strict = False
41
+ self.caches = self.read_cache()
42
+
43
+ def read_cache(self):
44
+ cache_files = self.args.cache_files
45
+ caches = {}
46
+ for file in cache_files:
47
+ if not os.path.exists(file):
48
+ logger.warning(f'Cache file does not exist: {file}')
49
+ continue
50
+ with open(file, 'r') as f:
51
+ for line in f.readlines():
52
+ line = line.strip()
53
+ if not line:
54
+ continue
55
+
56
+ content = json.loads(line)
57
+ uuid = content['id']
58
+ messages = content['messages']
59
+ if uuid not in caches:
60
+ caches[uuid] = {'choices': []}
61
+ assert messages[-1]['role'] == 'assistant'
62
+ caches[uuid]['choices'].append(messages[-1]['content'])
63
+ return caches
64
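# Illustrative sketch: the cache files read above are JSONL, one sampled completion per
# line; the grouping logic is reproduced here on made-up data.
import json

lines = [
    json.dumps({'id': 'abc', 'messages': [{'role': 'user', 'content': '1+1=?'},
                                          {'role': 'assistant', 'content': '2'}]}),
    json.dumps({'id': 'abc', 'messages': [{'role': 'user', 'content': '1+1=?'},
                                          {'role': 'assistant', 'content': 'two'}]}),
]
caches = {}
for line in lines:
    content = json.loads(line)
    caches.setdefault(content['id'], {'choices': []})
    caches[content['id']]['choices'].append(content['messages'][-1]['content'])
print(caches)  # {'abc': {'choices': ['2', 'two']}}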
+
65
+ @staticmethod
66
+ def convert_data_to_rows(data):
67
+ rows = []
68
+ key = list(data.keys())[0]
69
+ data_len = len(data[key])
70
+ for idx in range(data_len):
71
+ row = {key: data[key][idx] for key in data}
72
+ if row.get('images') and 'bytes' in row['images'][0]:
73
+ row['images'] = [img['path'] for img in row['images']]
74
+ rows.append(row)
75
+ VanillaSampler.check_row_valid(rows)
76
+ return rows
77
+
78
+ @staticmethod
79
+ def check_row_valid(rows):
80
+ for row in rows:
81
+ assert not row.get('images') or all([isinstance(img, str) and img for img in row['images']])
82
+ assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']])
83
+ assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']])
84
+
85
+ def generate(self, data):
86
+ resp_all = []
87
+ infer_requests = []
88
+ sent = 0
89
+ rows = self.convert_data_to_rows(data)
90
+ for idx, row in enumerate(rows):
91
+ row = deepcopy(row)
92
+ messages = row['messages']
93
+ uuid = get_messages_md5(row)
94
+ if uuid in self.caches:
95
+ choices = self.caches[uuid]['choices']
96
+ if len(choices) == self.args.num_return_sequences:
97
+ continue
98
+ if self.args.system:
99
+ if messages[0]['role'] == 'system':
100
+ messages[0]['content'] = self.args.system
101
+ else:
102
+ messages.insert(0, {'role': 'system', 'content': self.args.system})
103
+ if messages[-1]['role'] == 'assistant':
104
+ messages = messages[:-1]
105
+
106
+ row['messages'] = messages
107
+ infer_request = row
108
+ for i in range(self.args.num_return_sequences):
109
+ infer_requests.append(deepcopy(infer_request))
110
+ sent += 1
111
+
112
+ request_config = RequestConfig(
113
+ max_tokens=self.args.max_new_tokens,
114
+ temperature=self.args.temperature,
115
+ top_k=self.args.top_k,
116
+ top_p=self.args.top_p,
117
+ )
118
+
119
+ resp_list = []
120
+ if len(infer_requests) > 0:
121
+ resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)
122
+
123
+ _cur = 0
124
+ for idx, row in enumerate(rows):
125
+ row = deepcopy(row)
126
+ uuid = get_messages_md5(row)
127
+ if uuid in self.caches:
128
+ choices = self.caches[uuid]['choices']
129
+ if len(choices) == self.args.num_return_sequences:
130
+ row['choices'] = choices
131
+ resp_all.append(row)
132
+ continue
133
+
134
+ resps = row
135
+ resps['choices'] = []
136
+ for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
137
+ if not isinstance(resp_list[j], Exception):
138
+ resps['choices'].append(resp_list[j].choices[0].message.content)
139
+ if resps['choices']:
140
+ resp_all.append(resps)
141
+ _cur += 1
142
+ return resp_all
143
+
144
+ def do_sample(self, data):
145
+ generated = []
146
+ resp_all = self.generate(data)
147
+ for i, resps in enumerate(resp_all):
148
+ choices = resps['choices']
149
+ messages = resps['messages']
150
+ uuid = get_messages_md5(resps)
151
+ assert messages[-1]['role'] == 'assistant'
152
+ ground_truth = messages[-1]['content']
153
+
154
+ infer_requests = []
155
+ for decoded in choices:
156
+ _resps = deepcopy(resps)
157
+ _resps['messages'][-1]['content'] = decoded
158
+ infer_requests.append(_resps)
159
+
160
+ _resps = deepcopy(resps)
161
+ _resps['messages'][-1]['content'] = ground_truth
162
+ infer_requests.append(_resps)
163
+ if self.orm_model is not None:
164
+ orm_score, _orm_mask = get_reward(
165
+ self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0)
166
+ else:
167
+ orm_score = np.array([1.0] * len(infer_requests))
168
+ _orm_mask = np.array([True] * len(infer_requests))
169
+ if self.prm_model is not None:
170
+ prm_score, _prm_mask = get_reward(
171
+ self.prm_model,
172
+ infer_requests,
173
+ ground_truths=[ground_truth] * len(infer_requests),
174
+ threshold=self.args.prm_threshold)
175
+ else:
176
+ prm_score = np.array([1.0] * len(infer_requests))
177
+ _prm_mask = np.array([True] * len(infer_requests))
178
+
179
+ _mask = _orm_mask & _prm_mask
180
+ if not any(_mask):
181
+ continue
182
+
183
+ choices.append(ground_truth)
184
+ choices = np.array(choices)
185
+
186
+ if self.orm_model is None and self.prm_model is None:
187
+ positives = choices[:-1]
188
+ for positive in positives:
189
+ _resps = deepcopy(resps)
190
+ _resps.pop('choices', None)
191
+ _resps['id'] = uuid
192
+ _resps['messages'][-1]['content'] = str(positive)
193
+ generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
194
+ else:
195
+ score = np.array(prm_score) + np.array(orm_score * 10)
196
+ sorted_indices = np.argsort(score)[::-1]
197
+ pos_indexes = sorted_indices[0:self.args.n_best_to_keep]
198
+ pos_indexes = [i for i in pos_indexes if _mask[i]]
199
+ neg_index = sorted_indices[-1]
200
+ logger.info(
201
+ f'orm:{orm_score}, prm:{prm_score}, positive index: {pos_indexes}, negative index: {neg_index}')
202
+ if self.args.easy_query_threshold is not None and sum([score > 0 for score in orm_score]) - 1 >= int(
203
+ self.args.num_return_sequences * self.args.easy_query_threshold):
204
+ continue
205
+ if len(pos_indexes) > 0:
206
+ positives = choices[pos_indexes]
207
+ negative = choices[neg_index]
208
+ for positive in positives:
209
+ _resps = deepcopy(resps)
210
+ messages = deepcopy(messages)
211
+ _resps.pop('choices', None)
212
+ _resps['messages'][-1]['content'] = str(positive)
213
+ _resps['rejected_response'] = str(negative)
214
+ _resps['id'] = uuid
215
+ generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
216
+ return generated
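# Illustrative sketch of the candidate ranking in do_sample above: process and outcome
# rewards are combined (outcome weighted 10x), the top n_best_to_keep become positives and
# the lowest-scored candidate becomes the rejected response. Numbers are made up.
import numpy as np

prm_score = np.array([0.2, 0.9, 0.6])
orm_score = np.array([0.0, 1.0, 1.0])
score = prm_score + orm_score * 10
sorted_indices = np.argsort(score)[::-1]   # best first
n_best_to_keep = 2
pos_indexes, neg_index = sorted_indices[:n_best_to_keep], sorted_indices[-1]
print(pos_indexes, neg_index)  # [1 2] 0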
ms-swift/swift/llm/template/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from . import template
3
+ from .base import MaxLengthError, Template
4
+ from .constant import TemplateType
5
+ from .grounding import draw_bbox
6
+ from .register import TEMPLATE_MAPPING, get_template, get_template_meta, register_template
7
+ from .template_inputs import InferRequest, TemplateInputs
8
+ from .template_meta import TemplateMeta
9
+ from .utils import Prompt, Word, split_str_parts_by
10
+ from .vision_utils import load_file, load_image
ms-swift/swift/llm/template/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (758 Bytes).
ms-swift/swift/llm/template/__pycache__/base.cpython-310.pyc ADDED
Binary file (50.8 kB).
ms-swift/swift/llm/template/__pycache__/grounding.cpython-310.pyc ADDED
Binary file (3.47 kB).
ms-swift/swift/llm/template/__pycache__/register.cpython-310.pyc ADDED
Binary file (1.77 kB).
ms-swift/swift/llm/template/__pycache__/template_inputs.cpython-310.pyc ADDED
Binary file (6.41 kB).
ms-swift/swift/llm/template/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.49 kB).
ms-swift/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc ADDED
Binary file (10.4 kB).
ms-swift/swift/llm/template/base.py ADDED
@@ -0,0 +1,1606 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import hashlib
3
+ import inspect
4
+ import math
5
+ import os
6
+ import re
7
+ from contextlib import contextmanager, nullcontext
8
+ from copy import deepcopy
9
+ from dataclasses import asdict
10
+ from functools import partial, wraps
11
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from modelscope.hub.utils.utils import get_cache_dir
17
+ from peft import PeftModel
18
+ from PIL import Image
19
+ from torch.nn.utils.rnn import pad_sequence
20
+ from transformers import StoppingCriteriaList
21
+ from transformers.integrations import is_deepspeed_zero3_enabled
22
+ from transformers.utils import strtobool
23
+
24
+ from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
25
+ from ..utils import Processor, ProcessorMixin
26
+ from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
27
+ from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by
28
+ from .vision_utils import load_audio, load_batch, load_image, rescale_image
29
+
30
+ logger = get_logger()
31
+ if TYPE_CHECKING:
32
+ from .template_meta import TemplateMeta
33
+
34
+
35
+ class MaxLengthError(ValueError):
36
+ pass
37
+
38
+
39
+ class Template(ProcessorMixin):
40
+ special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
41
+ special_keys = ['images', 'videos', 'audios', 'objects']
42
+
43
+ image_placeholder = ['<image>']
44
+ video_placeholder = ['<video>']
45
+ audio_placeholder = ['<audio>']
46
+ cot_process_placeholder = ['ки']
47
+ placeholder_tokens = [] # For clearer printing
48
+ load_images = True
49
+ skip_prompt = True
50
+ use_model = False
51
+ norm_bbox = 'norm1000'
52
+
53
+ is_encoder_decoder = False
54
+
55
+ def __init__(
56
+ self,
57
+ processor: Processor,
58
+ template_meta: 'TemplateMeta',
59
+ default_system: Optional[str] = None,
60
+ max_length: Optional[int] = None,
61
+ *,
62
+ use_chat_template: bool = True,
63
+ truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
64
+ max_pixels: Optional[int] = None,
65
+ agent_template: Optional[str] = None,
66
+ norm_bbox: Literal['norm1000', 'none', None] = None,
67
+ response_prefix: Optional[str] = None,
68
+ # only for train
69
+ padding_side: Literal['left', 'right'] = 'right',
70
+ loss_scale: str = 'default',
71
+ sequence_parallel_size: int = 1,
72
+ # infer/deploy
73
+ template_backend: Literal['swift', 'jinja'] = 'swift',
74
+ ) -> None:
75
+ """
76
+ default_system: Override the default_system in the template.
77
+ max_length: Max length of the sequence
78
+ truncation_strategy: The truncation strategy
79
+ max_pixels: Rescale image to reduce memory usage, default `None` means no limitation.
80
+ e.g. 512 * 512 (H*W)
81
+ padding_side: The padding_side when the training batch_size >= 2
82
+ loss_scale: The loss scale function to use
83
+ """
84
+ from .template_meta import TemplateMeta
85
+ from swift.plugin import agent_templates
86
+
87
+ self.processor = processor
88
+ self.model_info = processor.model_info
89
+ self.config = self.model_info.config
90
+ self.model_meta = processor.model_meta
91
+ if max_length is None:
92
+ max_length = self.model_info.max_model_len
93
+ tokenizer = self.tokenizer
94
+
95
+ if not use_chat_template:
96
+ template_meta = template_meta.to_generate_template_meta()
97
+ else:
98
+ template_meta = deepcopy(template_meta)
99
+            # If default_system is None, do not change self.default_system.
100
+ template_meta.check_system(default_system)
101
+ if default_system is not None:
102
+ template_meta.default_system = default_system
103
+ if response_prefix is not None:
104
+ template_meta.response_prefix = response_prefix
105
+ logger.info(f'default_system: {repr(template_meta.default_system)}')
106
+ logger.info(f'response_prefix: {repr(template_meta.response_prefix)}')
107
+
108
+ for i, token in enumerate(self.placeholder_tokens):
109
+ if isinstance(token, str):
110
+ self.placeholder_tokens[i] = tokenizer.convert_tokens_to_ids(token)
111
+ template_meta.init(tokenizer)
112
+
113
+ self.template_meta: TemplateMeta = template_meta
114
+ self.use_chat_template = use_chat_template
115
+ self.template_backend = template_backend
116
+ self.max_length = max_length
117
+ self.truncation_strategy = truncation_strategy
118
+ self.loss_scale = loss_scale
119
+ self.max_pixels = max_pixels
120
+ self.padding_side = padding_side
121
+ self.sequence_parallel_size = sequence_parallel_size
122
+ agent_template = agent_template or template_meta.agent_template
123
+ logger.info(f'agent_template: {agent_template}')
124
+ self.agent_template = agent_templates[agent_template]()
125
+ self.norm_bbox = norm_bbox or self.norm_bbox
126
+ logger.info(f'max_length: {self.max_length}')
127
+ logger.info(f'norm_bbox: {self.norm_bbox}')
128
+ if self.is_encoder_decoder:
129
+ self.skip_prompt = False
130
+ self.mode: Literal['pt', 'vllm', 'lmdeploy', # infer
131
+ 'train', 'rlhf', 'kto', # train
132
+ 'seq_cls', 'embedding', 'prm'] = 'pt'
133
+ self._packing = False
134
+ self.use_megatron = False
135
+ if self.model_info.task_type != 'causal_lm':
136
+ self.mode = self.model_info.task_type
137
+ self._handles = []
138
+ self._deepspeed_initialize = None
139
+
140
+ @staticmethod
141
+ def _load_image(image, load_images: bool):
142
+ if load_images:
143
+ if isinstance(image, dict) and 'bytes' in image:
144
+ image = image['bytes'] or image['path']
145
+ image = load_image(image)
146
+ else:
147
+ if isinstance(image, dict):
148
+ path = image['path']
149
+ if path and (path.startswith('http') or os.path.exists(path)):
150
+ image = path
151
+ else:
152
+ image = load_image(image['bytes'])
153
+ elif not isinstance(image, str):
154
+ image = load_image(image)
155
+ return image
156
+
157
+ @staticmethod
158
+ def _get_height_width(inputs: StdTemplateInputs) -> None:
159
+ width = []
160
+ height = []
161
+ for image in inputs.images:
162
+ width.append(image.width)
163
+ height.append(image.height)
164
+ inputs.objects['width'] = width
165
+ inputs.objects['height'] = height
166
+
167
+ def normalize_bbox(self, inputs: StdTemplateInputs) -> None:
168
+ objects = inputs.objects
169
+ bbox_list = objects['bbox']
170
+ width_list = objects['width']
171
+ height_list = objects['height']
172
+ bbox_type = objects.pop('bbox_type', None) or 'real'
173
+ image_id_list = objects.pop('image_id', None) or []
174
+ image_id_list += [0] * (len(bbox_list) - len(image_id_list))
175
+ for bbox, image_id in zip(bbox_list, image_id_list):
176
+ if bbox_type == 'norm1':
177
+ width, height = 1, 1
178
+ else:
179
+ width, height = width_list[image_id], height_list[image_id]
180
+ for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])):
181
+ if self.norm_bbox == 'norm1000':
182
+ norm_width, norm_height = 1000, 1000
183
+ elif self.norm_bbox == 'none':
184
+ image = inputs.images[image_id]
185
+ norm_width, norm_height = image.width, image.height
186
+ bbox[2 * i] = int(round(x / width * norm_width))
187
+ bbox[2 * i + 1] = int(round(y / height * norm_height))
188
+
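A worked numeric sketch of the 'norm1000' branch above, assuming a single 640x480 image (hypothetical values):
#   bbox given in real pixel coordinates: [100, 50, 300, 200]
#   x -> round(x / 640 * 1000):  100 -> 156,  300 -> 469
#   y -> round(y / 480 * 1000):   50 -> 104,  200 -> 417
#   bbox is rewritten in place to [156, 104, 469, 417]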
189
+ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
190
+ agent_template = self.agent_template
191
+ agent_template.template_meta = self.template_meta # for hermes
192
+ if inputs.tools:
193
+ if isinstance(inputs.tools, str):
194
+ inputs.tools = agent_template._parse_json(inputs.tools)
195
+ if not isinstance(inputs.tools, (list, tuple)):
196
+ inputs.tools = [inputs.tools]
197
+ elif isinstance(inputs.tools, (list, tuple)):
198
+ inputs.tools = [agent_template._parse_json(tool) for tool in inputs.tools]
199
+ else:
200
+ raise ValueError(f'inputs.tools: {inputs.tools}')
201
+ for i, tool in enumerate(inputs.tools):
202
+ inputs.tools[i] = agent_template.wrap_tool(tool)
203
+ i = 0
204
+ messages = inputs.messages
205
+ while i < len(messages):
206
+ if messages[i]['role'] == 'tool_call':
207
+ i_start = i
208
+ while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool_call':
209
+ i += 1
210
+ tool_content = self.agent_template._format_tool_calls(messages[i_start:i + 1])
211
+ messages[i_start:i + 1] = [{'role': 'assistant', 'content': tool_content}]
212
+ i = i_start + 1
213
+ else:
214
+ i += 1
215
+
216
+ def _preprocess_inputs(
217
+ self,
218
+ inputs: StdTemplateInputs,
219
+ ) -> None:
220
+ self._preprocess_function_call(inputs)
221
+ if self.model_meta.is_multimodal:
222
+ self._replace_image_tags(inputs)
223
+ self._replace_start_image_tags(inputs)
224
+ images = inputs.images
225
+ load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
226
+ load_images_origin = load_images
227
+ if self.max_pixels is not None or inputs.objects:
228
+ load_images = True
229
+ if images:
230
+ for i, image in enumerate(images):
231
+ images[i] = self._load_image(images[i], load_images)
232
+ if inputs.objects:
233
+ self._get_height_width(inputs)
234
+ if self.max_pixels is not None:
235
+ # Scale the image proportionally without affecting the scaled objects.
236
+ images = [rescale_image(img, self.max_pixels) for img in images]
237
+ if images and not load_images_origin: # fix pt & qwen-vl
238
+ for i, image in enumerate(images):
239
+ if isinstance(image, Image.Image):
240
+ images[i] = self._save_pil_image(image)
241
+ inputs.images = images
242
+
243
+ if self.mode == 'vllm' and inputs.audios:
244
+ sampling_rate = get_env_args('sampling_rate', int, None)
245
+ inputs.audios = load_batch(
246
+ inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))
247
+
248
+ if inputs.is_multimodal:
249
+ self._add_default_tags(inputs)
250
+
251
+ @staticmethod
252
+ def _replace_image_tags(inputs: StdTemplateInputs):
253
+ # compat
254
+ if inputs.images:
255
+ return
256
+ images = []
257
+ pattern = r'<img>(.+?)</img>'
258
+ for message in inputs.messages:
259
+ content = message['content']
260
+ if not isinstance(content, str):
261
+ continue
262
+ for image in re.findall(pattern, content):
263
+ # only support local_path
264
+ if os.path.isfile(image):
265
+ images.append(image)
266
+ else:
267
+ logger.warning_once(f'Failed to parse image path: `{content}`.', hash_id='<img></img>')
268
+ message['content'] = re.sub(pattern, '<image>', content)
269
+ inputs.images = images
270
+
271
+ @staticmethod
272
+ def _replace_start_image_tags(inputs: StdTemplateInputs):
273
+ # compat
274
+ generate_mode = False
275
+ message = inputs.messages[-1]
276
+ content = message['content']
277
+ if message['role'] == 'user' and content.endswith('<start-image>'):
278
+ generate_mode = True
279
+ message['content'] = message['content'][:-len('<start-image>')] # remove the <start-image>
280
+ inputs.generate_mode = generate_mode
281
+
282
+ @staticmethod
283
+ def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
284
+ get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
285
+ added_tokens_len = 0
286
+ for i, idx in enumerate(replace_idx_list):
287
+ new_tokens = get_new_tokens(i)
288
+ token_len = len(new_tokens)
289
+ input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
290
+ if labels:
291
+ labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
292
+ added_tokens_len += token_len - 1
293
+ return input_ids, labels
294
+
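A small sketch of what _extend_tokens does, using hypothetical token ids:
#   input_ids = [1, 9, 2], labels = [1, 9, 2], replace_idx_list = [1]
#   get_new_tokens(0) -> [7, 7, 7]
#   result: input_ids = [1, 7, 7, 7, 2], labels = [1, -100, -100, -100, 2]
#   (the expanded placeholder tokens are masked out of the loss)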
295
+ def compute_loss_context(self, model, inputs):
296
+ return nullcontext()
297
+
298
+ def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
299
+ chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
300
+ assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
301
+ rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
302
+ chosen_encoded = self._encode_truncated(chosen_inputs)
303
+ rejected_encoded = self._encode_truncated(rejected_inputs)
304
+
305
+ encoded = {}
306
+ for prefix in ['chosen', 'rejected']:
307
+ data = locals()[f'{prefix}_encoded']
308
+ for k, v in data.items():
309
+ encoded[f'{prefix}_{k}'] = v
310
+ return encoded
311
+
312
+ def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
313
+ label, inputs.label = inputs.label, None
314
+ encoded = self._rlhf_encode(inputs)
315
+ encoded['label'] = bool(label)
316
+ return encoded
317
+
318
+ def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
319
+ _encoded = {}
320
+ labels = []
321
+
322
+ def split_multi_medias(_inputs):
323
+ _content = _inputs.messages[-2]['content']
324
+ image_size = len(re.findall('<image>', _content))
325
+ video_size = len(re.findall('<video>', _content))
326
+ audio_size = len(re.findall('<audio>', _content))
327
+ _inputs.images = inputs.images[:image_size]
328
+ assert len(_inputs.images) == image_size
329
+ inputs.images = inputs.images[image_size:]
330
+ _inputs.videos = inputs.videos[:video_size]
331
+ assert len(_inputs.videos) == video_size
332
+ inputs.videos = inputs.videos[video_size:]
333
+ _inputs.audios = inputs.audios[:audio_size]
334
+ assert len(_inputs.audios) == audio_size
335
+ inputs.audios = inputs.audios[audio_size:]
336
+
337
+ anchor = deepcopy(inputs)
338
+ anchor.messages[-1]['content'] = ''
339
+ anchor.rejected_response = []
340
+ split_multi_medias(anchor)
341
+ anchor_encoded = self._encode_truncated(anchor)
342
+ for key in anchor_encoded:
343
+ _encoded[f'anchor_{key}'] = anchor_encoded[key]
344
+
345
+ positive = deepcopy(inputs)
346
+ positive.messages[-2]['content'] = positive.messages[-1]['content']
347
+ positive.messages[-1]['content'] = ''
348
+ positive.rejected_response = []
349
+ split_multi_medias(positive)
350
+ positive_encoded = self._encode_truncated(positive)
351
+ for key in positive_encoded:
352
+ _encoded[f'positive_{key}'] = positive_encoded[key]
353
+ labels.append(float(inputs.label) if inputs.label is not None else 1.0)
354
+
355
+ rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
356
+ for i in range(rejected_len):
357
+ negative = deepcopy(inputs)
358
+ negative.messages[-2]['content'] = negative.rejected_response[i]
359
+ negative.messages[-1]['content'] = ''
360
+ negative.rejected_response = []
361
+ split_multi_medias(negative)
362
+ negative_encoded = self._encode_truncated(negative)
363
+ for key in negative_encoded:
364
+ _encoded[f'negative{i}_{key}'] = negative_encoded[key]
365
+ labels.append(0.0)
366
+
367
+ _encoded['labels'] = labels
368
+ return _encoded
369
+
370
+ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
371
+ encoded = self._encode_truncated(inputs)
372
+ encoded.pop('labels', None)
373
+ if inputs.label is not None:
374
+ labels = inputs.label
375
+ problem_type = self._get_problem_type(self.config, labels=labels)
376
+ if problem_type == 'single_label_classification':
377
+ labels = int(labels)
378
+ encoded['labels'] = labels
379
+ return encoded
380
+
381
+ @torch.inference_mode()
382
+ def encode(self,
383
+ inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
384
+ return_template_inputs: bool = False) -> Dict[str, Any]:
385
+ """The entrance method of Template!
386
+
387
+ Returns:
388
+ return {'input_ids': List[int], 'labels': Optional[List[int]], ...}
389
+ """
390
+ if isinstance(inputs, (InferRequest, TemplateInputs)):
391
+ inputs = asdict(inputs)
392
+
393
+ if isinstance(inputs, dict):
394
+ inputs = deepcopy(inputs)
395
+ if not self.is_training:
396
+ InferRequest.remove_response(inputs['messages'])
397
+ inputs = StdTemplateInputs.from_dict(inputs)
398
+ elif isinstance(inputs, StdTemplateInputs):
399
+ inputs = deepcopy(inputs)
400
+ assert isinstance(inputs, StdTemplateInputs)
401
+ self._preprocess_inputs(inputs)
402
+
403
+ if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}:
404
+ encoded = self._encode_truncated(inputs)
405
+ elif self.mode == 'seq_cls':
406
+ encoded = self._seq_cls_encode(inputs)
407
+ elif self.mode == 'rlhf':
408
+ encoded = self._rlhf_encode(inputs)
409
+ elif self.mode == 'kto':
410
+ encoded = self._kto_encode(inputs)
411
+ elif self.mode == 'embedding':
412
+ encoded = self._embedding_encode(inputs)
413
+ for key in list(encoded.keys()):
414
+ if encoded[key] is None:
415
+ encoded.pop(key)
416
+ if return_template_inputs:
417
+ encoded['template_inputs'] = inputs
418
+ return encoded
419
+
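A minimal usage sketch of encode, assuming a ready `template` instance for a chat LLM (its construction is omitted here):
#   template.set_mode('train')
#   encoded = template.encode({'messages': [
#       {'role': 'user', 'content': 'What is 1 + 1?'},
#       {'role': 'assistant', 'content': '2'},
#   ]})
#   encoded['input_ids']  # List[int]
#   encoded['labels']     # List[int]: -100 on prompt tokens, real ids on response tokens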
420
+ def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
421
+ packed = {}
422
+ keys = set()
423
+ for r in row:
424
+ keys.update(r[0].keys())
425
+ for key in keys:
426
+ if key in {'input_ids', 'labels', 'loss_scale'}:
427
+ packed[key] = sum((x[0][key] for x in row), start=[])
428
+ if 'position_ids' not in packed:
429
+ packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])
430
+
431
+ packed.update(self._data_collator_mm_data([r[0] for r in row]))
432
+ return packed
433
+
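A sketch of packing two encoded rows (lengths 3 and 2) with hypothetical ids:
#   row = [({'input_ids': [5, 6, 7], 'labels': [-100, 6, 7]}, 3),
#          ({'input_ids': [8, 9], 'labels': [-100, 9]}, 2)]
#   packed['input_ids']    = [5, 6, 7, 8, 9]
#   packed['labels']       = [-100, 6, 7, -100, 9]
#   packed['position_ids'] = [0, 1, 2, 0, 1]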
434
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
435
+ return inputs
436
+
437
+ @staticmethod
438
+ def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
439
+ len_tokens = len(stop_tokens)
440
+ if is_finished and generate_ids[-len_tokens:] == stop_tokens:
441
+ return generate_ids[:-len_tokens]
442
+ if not is_finished:
443
+ for i in range(len_tokens, 0, -1):
444
+ if generate_ids[-i:] == stop_tokens[:i]:
445
+ return generate_ids[:-i]
446
+ return generate_ids
447
+
448
+ @staticmethod
449
+ def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int):
450
+ idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist()
451
+ logprobs = logprobs.tolist()
452
+ return {
453
+ 'content': [{
454
+ 'index': pred,
455
+ 'logprobs': [logprobs[p] for p in pred] if isinstance(pred, (list, tuple)) else logprobs[pred],
456
+ 'top_logprobs': [{
457
+ 'index': idx,
458
+ 'logprob': logprobs[idx]
459
+ } for idx in idxs]
460
+ }]
461
+ }
462
+
463
+ @staticmethod
464
+ def _get_problem_type(config, labels=None, logits=None) -> str:
465
+ problem_type = config.problem_type
466
+ if problem_type is not None:
467
+ return problem_type
468
+ if labels is not None:
469
+ if isinstance(labels, (list, tuple)):
470
+ if labels and isinstance(labels[0], float):
471
+ problem_type = 'regression'
472
+ else:
473
+ problem_type = 'multi_label_classification'
474
+ else:
475
+ problem_type = 'single_label_classification'
476
+ assert config.num_labels >= labels + 1
477
+ if logits is not None:
478
+ if logits.shape[-1] == 1:
479
+ problem_type = 'regression'
480
+ else:
481
+ problem_type = 'single_label_classification' # compatible with older versions
482
+ assert problem_type is not None
483
+ config.problem_type = problem_type
484
+ return problem_type
485
+
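How the problem type is inferred when config.problem_type is unset, for example:
#   labels = 3          -> 'single_label_classification' (requires config.num_labels >= 4)
#   labels = [0.2, 1.7] -> 'regression'
#   labels = [1, 4]     -> 'multi_label_classification'
#   logits of shape (..., 1) -> 'regression'; otherwise 'single_label_classification'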
486
+ def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int):
487
+ assert isinstance(logits, torch.Tensor)
488
+ problem_type = self._get_problem_type(self.config, logits=logits)
489
+ if problem_type == 'regression':
490
+ preds = logits.squeeze(dim=-1).tolist()
491
+ logprobs = [None] * len(preds)
492
+ else:
493
+ if problem_type == 'single_label_classification':
494
+ preds = torch.argmax(logits, dim=-1).tolist()
495
+ logprobs = torch.log_softmax(logits, -1)
496
+ else:
497
+ preds = [(logprob >= 0.5).nonzero(as_tuple=True)[0].tolist() for logprob in torch.sigmoid(logits)]
498
+ logprobs = F.logsigmoid(logits)
499
+ logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)]
500
+ return preds, logprobs
501
+
502
+ def decode(self,
503
+ generate_ids: List[int],
504
+ *,
505
+ is_finished: bool = True,
506
+ tokenizer_kwargs=None,
507
+ first_token=True,
508
+ **kwargs) -> Any:
509
+ tokenizer_kwargs = tokenizer_kwargs or {}
510
+ response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs)
511
+ if first_token and self.template_meta.response_prefix:
512
+ response = self.template_meta.response_prefix + response
513
+ return response
514
+
515
+ def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
516
+ raise NotImplementedError
517
+
518
+ def generate(self, model, *args, **kwargs):
519
+ if isinstance(model, PeftModel):
520
+ signature = inspect.signature(model.model.generate)
521
+ else:
522
+ signature = inspect.signature(model.generate)
523
+ if 'use_model_defaults' in signature.parameters and 'use_model_defaults' not in kwargs:
524
+ kwargs['use_model_defaults'] = False
525
+ return model.generate(*args, **kwargs)
526
+
527
+ def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
528
+ # Do not print template_meta.suffix[-1] and eos_token.
529
+ # However, other stop_words will be printed.
530
+ tokenizer = self.tokenizer
531
+
532
+ if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
533
+ generate_ids = generate_ids[:-1]
534
+ # skip suffix and eos_token
535
+ template_suffix = self.template_meta.suffix[-1]
536
+ if isinstance(template_suffix, str):
537
+ # [-1:]: fix OpenGVLab/Mini-InternVL-Chat-4B-V1-5
538
+ template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:]
539
+ generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
540
+ if 'spaces_between_special_tokens' not in decode_kwargs:
541
+ decode_kwargs['spaces_between_special_tokens'] = False
542
+ return tokenizer.decode(generate_ids, **decode_kwargs)
543
+
544
+ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
545
+ generation_config = generate_kwargs['generation_config']
546
+ stop_words = getattr(generation_config, 'stop_words', None) or self.template_meta.stop_words
547
+ generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, stop_words)])
548
+ return generate_kwargs
549
+
550
+ @staticmethod
551
+ def _save_pil_image(image: Image.Image) -> str:
552
+ img_bytes = image.tobytes()
553
+ img_hash = hashlib.sha256(img_bytes).hexdigest()
554
+ tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
555
+ logger.info_once(f'create tmp_dir: {tmp_dir}')
556
+ os.makedirs(tmp_dir, exist_ok=True)
557
+ img_path = os.path.join(tmp_dir, f'{img_hash}.png')
558
+ if not os.path.exists(img_path):
559
+ image.save(img_path)
560
+ return img_path
561
+
562
+ @staticmethod
563
+ def _concat_context_list(
564
+ context_list: List[Context],
565
+ res_context_list: List[Context], # inplace
566
+ res_context_type: List[ContextType], # inplace
567
+ system: Optional[str] = None,
568
+ query: Optional[str] = None,
569
+ response: Optional[str] = None,
570
+ round0: Optional[int] = None) -> None:
571
+ """Concat context list and replace placeholder"""
572
+ round1 = None
573
+ if round0 is not None:
574
+ round1 = str(round0 + 1)
575
+ round0 = str(round0)
576
+ for context in context_list:
577
+ if isinstance(context, str):
578
+ if '{{RESPONSE}}' == context:
579
+ assert response is not None
580
+ res_context_list.append(response)
581
+ res_context_type.append(ContextType.RESPONSE)
582
+ continue
583
+ old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
584
+ new_str_list = [system, query, round0, round1]
585
+ for (old_str, new_str) in zip(old_str_list, new_str_list):
586
+ if new_str is not None and old_str in context:
587
+ assert isinstance(new_str, str), f'new_str: {new_str}'
588
+ context = context.replace(old_str, new_str)
589
+ if len(context) == 0:
590
+ continue
591
+ res_context_list.append(context)
592
+ res_context_type.append(ContextType.OTHER)
593
+
594
+ def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
595
+ inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
596
+ """Merge anything in the context to simplify the inputs"""
597
+ context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
598
+ context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)
599
+
600
+ res: List[Context] = [] # result of context_list
601
+ res_loss_scale: List[float] = [] # result of loss_scale_list
602
+ temp: List[str] = []
603
+ temp_loss_scale = 0.
604
+ for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
605
+ if isinstance(context, str) and (loss_scale == temp_loss_scale):
606
+ temp.append(context)
607
+ else:
608
+ if len(temp) > 0:
609
+ res.append(''.join(temp))
610
+ res_loss_scale.append(temp_loss_scale)
611
+ temp.clear()
612
+ if isinstance(context, str): # loss_scale diff
613
+ temp.append(context)
614
+ else:
615
+ res.append(context)
616
+ res_loss_scale.append(loss_scale)
617
+ temp_loss_scale = loss_scale
618
+ if len(temp) > 0:
619
+ res.append(''.join(temp))
620
+ res_loss_scale.append(temp_loss_scale)
621
+
622
+ return res, res_loss_scale
623
+
624
+ @staticmethod
625
+ def _split_special_tokens(context_list: List[Context],
626
+ loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
627
+ """Split special tokens, for example `<image>`, `<video>`, this will help the replace_tag operation"""
628
+ res: List[Context] = []
629
+ loss_scale_res: List[float] = []
630
+ for context, loss_scale in zip(context_list, loss_scale_list):
631
+ contexts = []
632
+ if isinstance(fetch_one(context), str):
633
+ for d in split_str_parts_by(context, Template.special_tokens):
634
+ contexts.extend([d['key'], d['content']])
635
+ contexts = [c for c in contexts if c]
636
+ res.extend(contexts)
637
+ loss_scale_res.extend([loss_scale] * len(contexts))
638
+ else:
639
+ res.append(context)
640
+ loss_scale_res.append(loss_scale)
641
+ return res, loss_scale_res
642
+
643
+ def _tokenize(self, context, **tokenizer_kwargs):
644
+ return self.tokenizer(
645
+ context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']
646
+
647
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
648
+ inputs: StdTemplateInputs) -> List[Context]:
649
+ """Override this function to do your own replace operation.
650
+
651
+ This method is used to replace standard tags like `<image>` with the tokens that the model needs.
652
+
653
+ Args:
654
+ media_type: The media modality.
655
+ index: The index of the media item; index 0 refers to the first element in `images`.
656
+ inputs: The inputs
657
+
658
+ Returns:
659
+ The content or input_ids after replacement.
660
+ """
661
+ if media_type == 'image':
662
+ if self.mode == 'lmdeploy':
663
+ return [[-100]]
664
+ return self.image_placeholder
665
+ elif media_type == 'video':
666
+ return self.video_placeholder
667
+ elif media_type == 'audio':
668
+ return self.audio_placeholder
669
+
670
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
671
+ """Replace objects referenced by the bbox to contents or input_ids. This is useful in the grounding task.
672
+ Override this function to do your own replace operation.
673
+
674
+ Args:
675
+ ref: Description of the bbox
676
+ index: The index in the `objects` key
677
+ inputs: The inputs
678
+
679
+ Returns:
680
+ The contents or input_ids after replacement.
681
+ """
682
+ return [ref]
683
+
684
+ def replace_cot_process(self, inputs: StdTemplateInputs) -> List[Context]:
685
+ """Replace the cot process label for PRM training or inference.
686
+ Override this function to do your own replace operation.
687
+
688
+ Args:
689
+ inputs: The inputs
690
+
691
+ Returns:
692
+ The contents or input_ids after replacement.
693
+ """
694
+ return [self.cot_process_placeholder]
695
+
696
+ @staticmethod
697
+ def _get_bbox_str(bbox: List[int]) -> str:
698
+ point = []
699
+ for x, y in zip(bbox[::2], bbox[1::2]):
700
+ point.append(f'({x},{y})')
701
+ return ','.join(point)
702
+
703
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
704
+ """Replace bbox pointing to the objects to contents or input_ids. This is useful in the grounding task.
705
+ Override this function to do your own replace operation.
706
+
707
+ Args:
708
+ bbox: [x, y] or [x1, y1, x2, y2]
709
+ index: The index in the `objects` key
710
+ inputs: The inputs
711
+
712
+ Returns:
713
+ The contents or input_ids after replacement.
714
+ """
715
+ return [f'[{self._get_bbox_str(bbox)}]']
716
+
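For example, with the default implementation above (hypothetical coordinates):
#   self.replace_bbox([10, 20, 110, 220], 0, inputs) -> ['[(10,20),(110,220)]']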
717
+ def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: List[float],
718
+ inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
719
+ # https://github.com/modelscope/ms-swift/issues/3407
720
+ # Fix the bounding box position offset issue in the Qwen2.5-VL grounding task.
721
+ res: List[Context] = []
722
+ res_loss_scale: List[float] = []
723
+ inputs.image_idx = 0
724
+
725
+ for context, loss_scale in zip(context_list, loss_scale_list):
726
+ if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
727
+ c_list = self.replace_tag('image', inputs.image_idx, inputs)
728
+ inputs.image_idx += 1
729
+ loss_scale = 0. if self.template_backend == 'swift' else 1.
730
+ else:
731
+ c_list = [context]
732
+ res += c_list
733
+ res_loss_scale += [loss_scale] * len(c_list)
734
+ return res, res_loss_scale
735
+
736
+ def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
737
+ inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
738
+ """This method happens before tokenization, replace standard tags to the contents or input_ids needed by
739
+ the model.
740
+
741
+ Args:
742
+ context_list: The content list
743
+ loss_scale_list: The loss scale list
744
+ Returns:
745
+ The context_list and loss_scale_list after replacement.
746
+ """
747
+ context_list, loss_scale_list = self._pre_tokenize_images(context_list, loss_scale_list, inputs)
748
+ if inputs.images and inputs.objects:
749
+ self.normalize_bbox(inputs)
750
+ # replace tag/object/box
751
+ res: List[Context] = [] # result of context_list
752
+ res_loss_scale: List[float] = [] # result of loss_scale_list
753
+
754
+ # reset
755
+ for k in ['video', 'audio', 'object', 'box']:
756
+ setattr(inputs, f'{k}_idx', 0)
757
+
758
+ for context, loss_scale in zip(context_list, loss_scale_list):
759
+ for k in ['video', 'audio']:
760
+ if context == f'<{k}>' and inputs.is_multimodal and getattr(inputs, f'{k}_idx') < len(
761
+ getattr(inputs, f'{k}s')):
762
+ c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
763
+ setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
764
+ loss_scale = 0.
765
+ break
766
+ else:
767
+ ref = inputs.objects.get('ref') or []
768
+ bbox = inputs.objects.get('bbox') or []
769
+ if context == '<ref-object>' and inputs.ref_idx < len(ref):
770
+ idx = inputs.ref_idx
771
+ c_list = self.replace_ref(ref[idx], idx, inputs)
772
+ inputs.ref_idx += 1
773
+ elif context == '<bbox>' and inputs.bbox_idx < len(bbox):
774
+ idx = inputs.bbox_idx
775
+ c_list = self.replace_bbox(bbox[idx], idx, inputs)
776
+ inputs.bbox_idx += 1
777
+ elif context == '<cot-process>' and self.mode == 'prm':
778
+ c_list = self.replace_cot_process(inputs)
779
+ else:
780
+ c_list = [context]
781
+ res += c_list
782
+ res_loss_scale += [loss_scale] * len(c_list)
783
+ return res, res_loss_scale
784
+
785
+ @staticmethod
786
+ def _add_default_tags(inputs: StdTemplateInputs):
787
+ total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
788
+ if inputs.rejected_response:
789
+ if isinstance(inputs.rejected_response, str):
790
+ total_content += inputs.rejected_response
791
+ else:
792
+ total_content += '\n'.join(inputs.rejected_response)
793
+ if inputs.system:
794
+ total_content = f'{inputs.system}\n{total_content}'
795
+ for media_type in ['image', 'audio', 'video']:
796
+ media_key, media_tag = f'{media_type}s', f'<{media_type}>'
797
+ medias = getattr(inputs, media_key)
798
+ if not isinstance(medias, list):
799
+ medias = [medias]
800
+ if medias:
801
+ num_media_tags = len(re.findall(media_tag, total_content))
802
+ num_media = len(medias)
803
+ num_new_tags = num_media - num_media_tags
804
+ if num_new_tags > 0:
805
+ inputs.messages[0]['content'] = media_tag * num_new_tags + inputs.messages[0]['content']
806
+ elif num_new_tags < 0:
807
+ logger.warning(
808
+ f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
809
+ 'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')
810
+
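For example, when media items outnumber their tags (hypothetical content):
#   messages[0]['content'] = '<image>what is the difference?'   with len(inputs.images) == 2
#   -> messages[0]['content'] becomes '<image><image>what is the difference?'
#   (if tags outnumber media, only the warning below is emitted)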
811
+ def _encode_context_list(
812
+ self,
813
+ context_list: List[Context],
814
+ loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
815
+ """return: input_ids, labels, tokenizer_kwargs"""
816
+ input_ids: List[int] = []
817
+ labels: List[int] = []
818
+ loss_scale: List[float] = []
819
+ tokenizer_kwargs = {}
820
+ if loss_scale_list is None:
821
+ loss_scale_list = [0.] * len(context_list)
822
+ ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
823
+ for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
824
+ if isinstance(context, str):
825
+ # tokenizer_kwargs is the returned tokenizer_kwargs,
826
+ # while curr_tokenizer_kwargs is the tokenizer_kwargs for the current context.
827
+ token_list = self._tokenize(context)
828
+ else:
829
+ token_list = context
830
+ input_ids += token_list
831
+ if loss_scale_list[i] > 0.0:
832
+ labels += token_list
833
+ else:
834
+ labels += [-100] * len(token_list)
835
+ if not ignore_loss_scale:
836
+ loss_scale.extend([loss_weight] * len(token_list))
837
+ if ignore_loss_scale:
838
+ loss_scale = None
839
+ return input_ids, labels, loss_scale, tokenizer_kwargs
840
+
841
+ @staticmethod
842
+ def _add_dynamic_eos(input_ids: List[int], labels: List[int], loss_scale: Optional[List[int]],
843
+ suffix_tokens_id: List[int]) -> None:
844
+ suffix_len = len(suffix_tokens_id)
845
+ start = 0
846
+ for i in range(1, len(labels)):
847
+ if labels[i - 1] >= 0 and labels[i] == -100:
848
+ start = i
849
+ if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
850
+ # [0, 1, 2, -100(start), -100, 3(i), 4]
851
+ length = i - start
852
+ if length >= suffix_len and input_ids[start:start + suffix_len] == suffix_tokens_id:
853
+ labels[start:start + suffix_len] = suffix_tokens_id
854
+ if loss_scale and loss_scale[start:start + suffix_len] == [0] * suffix_len:
855
+ loss_scale[start:start + suffix_len] = [1] * suffix_len
856
+
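A small sketch with a hypothetical one-token suffix, suffix_tokens_id = [2]:
#   input_ids = [5, 6, 2, 8, 7]
#   labels    = [5, 6, -100, -100, 7]   # masked span between two answer rounds
#   after _add_dynamic_eos: labels = [5, 6, 2, -100, 7]
#   i.e. the suffix/eos token that closes the previous answer is restored as a training target.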
857
+ @staticmethod
858
+ def _get_std_messages(messages):
859
+ if messages and messages[0]['role'] == 'assistant':
860
+ messages.insert(0, {'role': 'user', 'content': ''}) # pretrain
861
+ if len(messages) % 2 == 1:
862
+ messages.append({'role': 'assistant', 'content': None}) # inference
863
+
864
+ def _jinja_encode(self, inputs: StdTemplateInputs):
865
+ messages = inputs.messages.copy()
866
+ if inputs.system is not None:
867
+ messages.insert(0, {'role': 'system', 'content': inputs.system})
868
+ if messages[-1]['content'] is None:
869
+ messages.pop()
870
+ add_generation_prompt = messages[-1]['role'] != 'assistant'
871
+ kwargs = {}
872
+ if inputs.tools:
873
+ kwargs['tools'] = inputs.tools
874
+ text = self.tokenizer.apply_chat_template(
875
+ messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
876
+ answer_len = 1 if self.is_training else 0
877
+ return [text], [1.], answer_len
878
+
879
+ def _get_system(self, inputs) -> Optional[str]:
880
+ template_meta = self.template_meta
881
+ system = inputs.system
882
+ tools = inputs.tools
883
+ template_meta.check_system(system)
884
+ if system is None:
885
+ system = template_meta.default_system
886
+
887
+ if tools is not None:
888
+ system = self.agent_template._format_tools(tools, system or '', inputs.messages[0])
889
+ return system
890
+
891
+ @staticmethod
892
+ def _swift_prepare_function_call(agent_template, messages):
893
+ if len(messages) < 2:
894
+ return
895
+ i = 1
896
+ while i < len(messages):
897
+ pre_message, message = messages[i - 1], messages[i]
898
+ pre_role, pre_content = pre_message['role'], pre_message['content']
899
+ role, content = message['role'], message['content']
900
+ if pre_role == 'assistant' and role == 'tool':
901
+ i_start = i
902
+ while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
903
+ i += 1
904
+ pre_message['content'], tool_content = agent_template._format_tool_responses(
905
+ pre_content, messages[i_start:i + 1])
906
+ messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
907
+ i = i_start + 1
908
+ elif pre_role == 'assistant' and role == 'assistant':
909
+ # Consecutive messages from the assistant role need to be merged to prevent errors.
910
+ pre_message['content'] = pre_content + content
911
+ messages.pop(i)
912
+ else:
913
+ i += 1
914
+
915
+ def _swift_encode(self, inputs: StdTemplateInputs):
916
+ template_meta = self.template_meta
917
+ system = self._get_system(inputs)
918
+ self._swift_prepare_function_call(self.agent_template, inputs.messages)
919
+
920
+ self._get_std_messages(inputs.messages)
921
+ n_round = len(inputs.messages) // 2
922
+ if n_round > 1 and not self.template_meta.support_multi_round:
923
+ logger.warning_once(
924
+ 'The template does not support multi-round chat. Only use the last round of the conversation.')
925
+ inputs.messages = inputs.messages[-2:]
926
+
927
+ res_context_list: List[Context] = []
928
+ res_context_types: List[ContextType] = []
929
+ sep_token = None
930
+ if template_meta.auto_add_bos:
931
+ all_tokens = self.tokenizer.encode('a')
932
+ single_token = self.tokenizer.encode('a', add_special_tokens=False)
933
+ assert len(single_token) == 1
934
+ idx = all_tokens.index(single_token[0])
935
+ bos_token = all_tokens[:idx]
936
+ sep_token = all_tokens[idx + 1:]
937
+ if bos_token:
938
+ res_context_list.append(bos_token)
939
+ res_context_types.append(ContextType.OTHER)
940
+
941
+ prefix = template_meta.system_prefix if system else template_meta.prefix
942
+ self._concat_context_list(prefix, res_context_list, res_context_types, system=system)
943
+
944
+ n_round = len(inputs.messages) // 2
945
+ for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
946
+ query_role, query = query_message['role'], query_message['content']
947
+ response_role, response = response_message['role'], response_message['content']
948
+ # TODO: Optimize the Template mechanism.
949
+ assert query_role in {'user', 'tool'}, f'query_role: {query_role}'
950
+ assert response_role in {'assistant'}, f'response_role: {response_role}'
951
+ if query_role == 'tool':
952
+ prompt = query
953
+ query = ''
954
+ elif template_meta.is_post_system and i == n_round - 1:
955
+ prompt = template_meta.system_prompt
956
+ else:
957
+ prompt = template_meta.prompt
958
+
959
+ context_list = prompt.copy()
960
+ extra_context_list = []
961
+ extra_context_type = None
962
+ if i < n_round - 1:
963
+ # Not the last round.
964
+ context_list.append('{{RESPONSE}}')
965
+ if inputs.messages[2 * (i + 1)]['role'] != 'tool':
966
+ extra_context_list = template_meta.chat_sep
967
+ extra_context_type = ContextType.OTHER
968
+ elif response is not None:
969
+ # It is the final round, and the response exists (during training).
970
+ context_list.append('{{RESPONSE}}')
971
+ if self.is_training and not sep_token:
972
+ extra_context_list = template_meta.suffix
973
+ extra_context_type = ContextType.SUFFIX
974
+ elif template_meta.response_prefix:
975
+ # final round and during inference.
976
+ context_list.append(template_meta.response_prefix)
977
+
978
+ self._concat_context_list(
979
+ context_list,
980
+ res_context_list,
981
+ res_context_types,
982
+ query=query,
983
+ response=response,
984
+ system=system,
985
+ round0=i)
986
+ res_context_list += extra_context_list
987
+ res_context_types += [extra_context_type] * len(extra_context_list)
988
+ if template_meta.auto_add_bos and sep_token:
989
+ res_context_list.append(sep_token)
990
+ res_context_types.append(ContextType.SUFFIX)
991
+ from swift.plugin import loss_scale_map
992
+ res_context_list, loss_scale_list = loss_scale_map[self.loss_scale](res_context_list, res_context_types,
993
+ inputs.messages)
994
+ if self.is_training:
995
+ answer_len = len(extra_context_list) + bool(response is not None)
996
+ else:
997
+ answer_len = 0
998
+ return res_context_list, loss_scale_list, answer_len
999
+
1000
+ def _encode_truncated(self, inputs):
1001
+ if self.mode in {'vllm', 'lmdeploy'}:
1002
+ encoded = Template._encode(self, inputs)
1003
+ for key in ['images', 'audios', 'videos']:
1004
+ encoded[key] = getattr(inputs, key)
1005
+ else:
1006
+ encoded = self._encode(inputs)
1007
+
1008
+ input_ids = encoded.get('input_ids')
1009
+ labels = encoded.get('labels')
1010
+ loss_scale = encoded.get('loss_scale')
1011
+ if self.max_length is not None:
1012
+ if self.truncation_strategy == 'right':
1013
+ input_ids = input_ids[:self.max_length]
1014
+ if labels is not None:
1015
+ labels = labels[:self.max_length]
1016
+ if loss_scale is not None:
1017
+ loss_scale = loss_scale[:self.max_length]
1018
+ elif self.truncation_strategy == 'left':
1019
+ if len(input_ids) > self.max_length:
1020
+ logger.warning_once(
1021
+ 'Input data was left-truncated because its length exceeds `max_length` (input length: '
1022
+ f'{len(input_ids)}, max_length: {self.max_length}). '
1023
+ 'This may cause loss of important tokens (e.g., image tokens) and lead to errors. '
1024
+ 'To avoid this, consider increasing `max_length` or pre-filtering long sequences.',
1025
+ hash_id='max_length_check')
1026
+ input_ids = input_ids[-self.max_length:]
1027
+ if labels is not None:
1028
+ labels = labels[-self.max_length:]
1029
+ if loss_scale is not None:
1030
+ loss_scale = loss_scale[-self.max_length:]
1031
+ elif self.truncation_strategy == 'raise':
1032
+ length = len(input_ids or labels or [])
1033
+ if length > self.max_length:
1034
+ raise MaxLengthError(f'Current length of row({length}) is larger'
1035
+ f' than the max_length({self.max_length}).')
1036
+ encoded['input_ids'] = input_ids
1037
+ encoded['labels'] = labels
1038
+ encoded['loss_scale'] = loss_scale
1039
+ return encoded
1040
+
1041
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
1042
+ template_backend = self.template_backend
1043
+ if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
1044
+ and self.mode != 'seq_cls'):
1045
+ template_backend = 'jinja'
1046
+ logger.info_once(f'Setting template_backend: {template_backend}')
1047
+ res_context_list, loss_scale_list, answer_len = (
1048
+ self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
1049
+ encoded = {}
1050
+ if self.is_encoder_decoder:
1051
+ # tokenizer_kwargs: use prompt (qwen-audio)
1052
+ total_len = len(res_context_list)
1053
+ for key, _slice in zip(['prompt', 'answer'],
1054
+ [slice(0, total_len - answer_len),
1055
+ slice(total_len - answer_len, total_len)]):
1056
+ context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
1057
+ loss_scale_list[_slice], inputs)
1058
+ input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
1059
+ encoded[f'{key}_input_ids'] = input_ids
1060
+ if key == 'answer':
1061
+ encoded['labels'] = labels
1062
+ encoded['loss_scale'] = loss_scale
1063
+ input_ids = encoded['prompt_input_ids'] + encoded['answer_input_ids']
1064
+ else:
1065
+ res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
1066
+ input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
1067
+ res_context_list, loss_scale_list)
1068
+ self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0])
1069
+
1070
+ if tokenizer_kwargs:
1071
+ encoded['tokenizer_kwargs'] = tokenizer_kwargs
1072
+
1073
+ encoded['input_ids'] = input_ids
1074
+ encoded['labels'] = labels
1075
+ encoded['loss_scale'] = loss_scale
1076
+ if self.use_megatron:
1077
+ self._handle_megatron_cp(encoded)
1078
+ encoded['labels'] = encoded['labels'][1:] + [-100]
1079
+ encoded['position_ids'] = list(range(len(encoded['labels'])))
1080
+ elif encoded.get('labels') is not None:
1081
+ encoded['labels'][0] = -100
1082
+ if not self.is_training:
1083
+ for k in list(encoded.keys()):
1084
+ if k.endswith('labels') or k.endswith('loss_scale'):
1085
+ encoded[k] = None
1086
+ return encoded
1087
+
1088
+ def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None:
1089
+ cp_size = self.sequence_parallel_size
1090
+ if cp_size == 1:
1091
+ return
1092
+ input_ids = encoded['input_ids']
1093
+ padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids)
1094
+ input_ids += [self.tokenizer.pad_token_id] * padding_len
1095
+ encoded['labels'] += [-100] * padding_len
1096
+
1097
+ def debug_logger(self, inputs):
1098
+ if not strtobool(os.getenv('SWIFT_DEBUG', 'false')):
1099
+ return
1100
+ if 'input_ids' in inputs:
1101
+ k = 'input_ids'
1102
+ val = inputs['input_ids']
1103
+ elif 'generate_ids' in inputs:
1104
+ k = 'generate_ids'
1105
+ val = inputs['generate_ids']
1106
+ for v in val:
1107
+ self.print_inputs({k: v.tolist()})
1108
+
1109
+ @staticmethod
1110
+ def _split_list(inputs: List[int], x: int) -> List[List[int]]:
1111
+ idxs = findall(inputs, x)
1112
+ idxs.append(len(inputs))
1113
+ res = []
1114
+ lo = 0
1115
+ for idx in idxs:
1116
+ res.append(inputs[lo:idx])
1117
+ lo = idx + 1
1118
+ return res
1119
+
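For example:
#   _split_list([1, 2, -100, 3, 4, -100, 5], -100) -> [[1, 2], [3, 4], [5]]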
1120
+ def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
1121
+ context_list = []
1122
+ if self.mode in {'vllm', 'lmdeploy'}:
1123
+ video = inputs.videos.pop(inputs.video_idx)
1124
+ inputs.video_idx -= 1
1125
+ else:
1126
+ video = inputs.videos[inputs.video_idx]
1127
+ images = inputs.images
1128
+ new_images = load_video_func(video)
1129
+ inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
1130
+ for i in range(len(new_images)):
1131
+ context_list += replace_tag(i)
1132
+ inputs.image_idx += len(new_images)
1133
+ return context_list
1134
+
1135
+ def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]],
1136
+ num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]:
1137
+ if self.skip_prompt:
1138
+ generate_ids = generate_ids[..., num_prompt_tokens:]
1139
+ return generate_ids
1140
+
1141
+ def post_process_generate_response(self, response: str, inputs: StdTemplateInputs) -> str:
1142
+ return response
1143
+
1144
+ def pre_forward_hook(self, model: nn.Module, args, kwargs):
1145
+ from swift.llm import to_device
1146
+ old_kwargs = to_device(kwargs, model.device)
1147
+ kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
1148
+ for k, v in old_kwargs.items():
1149
+ if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
1150
+ kwargs[k] = v
1151
+ if 'inputs_embeds' in kwargs:
1152
+ kwargs.pop('input_ids', None)
1153
+
1154
+ if isinstance(model, PeftModel):
1155
+ parameters = inspect.signature(model.model.forward).parameters
1156
+ else:
1157
+ parameters = inspect.signature(model.forward).parameters
1158
+ if 'position_ids' not in parameters:
1159
+ kwargs.pop('position_ids', None)
1160
+ return args, kwargs
1161
+
1162
+ @property
1163
+ def is_training(self):
1164
+ return self.mode not in {'vllm', 'lmdeploy', 'pt'}
1165
+
1166
+ def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
1167
+ self.mode = mode
1168
+
1169
+ def register_post_encode_hook(self, models: List[nn.Module]) -> None:
1170
+ """This function is important for multi-modal training, as it registers the post_encode method
1171
+ as a forward hook, converting input_ids into inputs_embeds.
1172
+ """
1173
+ if self._handles:
1174
+ return
1175
+
1176
+ for model in models:
1177
+ # please use torch>=2.0
1178
+ handle = model.register_forward_pre_hook(self.pre_forward_hook, with_kwargs=True)
1179
+ self._handles.append((model, handle))
1180
+
1181
+ if is_deepspeed_zero3_enabled():
1182
+ import deepspeed
1183
+ self._deepspeed_initialize = deepspeed.initialize
1184
+
1185
+ @wraps(self._deepspeed_initialize)
1186
+ def _initialize(*args, **kwargs):
1187
+ res = self._deepspeed_initialize(*args, **kwargs)
1188
+ for model, handle in self._handles:
1189
+ model._forward_pre_hooks.move_to_end(handle.id)
1190
+ return res
1191
+
1192
+ deepspeed.initialize = _initialize
1193
+
1194
+ def remove_post_encode_hook(self):
1195
+ models = []
1196
+ for model, handle in self._handles:
1197
+ models.append(model)
1198
+ handle.remove()
1199
+ self._handles = []
1200
+
1201
+ if self._deepspeed_initialize is not None:
1202
+ import deepspeed
1203
+ deepspeed.initialize = self._deepspeed_initialize
1204
+ self._deepspeed_initialize = None
1205
+ return models
1206
+
1207
+ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
1208
+ if self.mode == 'rlhf':
1209
+ return self._rlhf_data_collator(batch, padding_to=padding_to)
1210
+ elif self.mode == 'kto':
1211
+ return self._kto_data_collator(batch, padding_to=padding_to)
1212
+ elif self.mode in {'pt', 'train', 'prm'}:
1213
+ return self._data_collator(batch, padding_to=padding_to)
1214
+ elif self.mode == 'seq_cls':
1215
+ return self._seq_cls_data_collator(batch, padding_to=padding_to)
1216
+ elif self.mode == 'embedding':
1217
+ return self._embedding_data_collator(batch, padding_to=padding_to)
1218
+
1219
+ @staticmethod
1220
+ def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]:
1221
+ new_batch = []
1222
+ for inputs in batch:
1223
+ new_inputs = {}
1224
+ for k, v in inputs.items():
1225
+ if k.startswith(prefix):
1226
+ new_inputs[k[len(prefix):]] = v
1227
+ new_batch.append(new_inputs)
1228
+ return new_batch
1229
+
1230
+ @staticmethod
1231
+ def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]:
1232
+ from swift.llm import RowPreprocessor
1233
+ keys = keys or []
1234
+ rows = RowPreprocessor.rows_to_batched(batch)
1235
+ return {k: rows[k] for k in keys if rows.get(k) is not None}
1236
+
1237
+ @staticmethod
1238
+ def gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
1239
+ # List[Tensor] -> List[Tensor]
1240
+ res = []
1241
+ for b in batch:
1242
+ if b.get(attr_name) is not None:
1243
+ res += b.pop(attr_name)
1244
+ return res
1245
+
1246
+ @staticmethod
1247
+ def concat_tensor(batch: List[Dict[str, Any]], attr_name: str, dim: int) -> Optional[torch.Tensor]:
1248
+ res = []
1249
+ for b in batch:
1250
+ if b.get(attr_name) is not None:
1251
+ res.append(b.pop(attr_name))
1252
+ return torch.concat(res, dim=dim) if res else None
1253
+
1254
+ def _rlhf_data_collator(self,
1255
+ batch: List[Dict[str, Any]],
1256
+ *,
1257
+ chosen_prefix: str = 'chosen_',
1258
+ rejected_prefix: str = 'rejected_',
1259
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1260
+ new_batch = []
1261
+ for prefix in [chosen_prefix, rejected_prefix]:
1262
+ new_batch += self._fetch_inputs_startswith(batch, prefix)
1263
+ return self._data_collator(new_batch, padding_to=padding_to)
1264
+
1265
+ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
1266
+ new_batch = self._fetch_inputs_startswith(batch, 'chosen_')
1267
+ kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')
1268
+
1269
+ res = self._data_collator(new_batch, padding_to=padding_to)
1270
+ kl_res = self._data_collator(kl_batch, padding_to=padding_to)
1271
+ res = {
1272
+ **{f'completion_{k}': v
1273
+ for k, v in res.items()},
1274
+ **{f'KL_completion_{k}': v
1275
+ for k, v in kl_res.items()},
1276
+ }
1277
+ label = [b['label'] for b in batch if b.get('label') is not None]
1278
+ if label:
1279
+ res['label'] = label
1280
+ return res
1281
+
1282
+ def _embedding_data_collator(self,
1283
+ batch: List[Dict[str, Any]],
1284
+ *,
1285
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1286
+ labels = []
1287
+ new_batch = []
1288
+ for b in batch:
1289
+ keys = [key for key in b.keys() if 'negative' in key]
1290
+ max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
1291
+ indexes = ['anchor_', 'positive_']
1292
+ if max_neg is not None:
1293
+ for i in range(0, max_neg + 1):
1294
+ indexes.append(f'negative{i}_')
1295
+ for prefix in indexes:
1296
+ new_batch += self._fetch_inputs_startswith([b], prefix)
1297
+ labels.extend(b.get('labels', None))
1298
+ res = self._data_collator(new_batch, padding_to=padding_to)
1299
+ if labels:
1300
+ res['labels'] = torch.tensor(labels, dtype=torch.float32)
1301
+ return res
1302
+
1303
+ def _seq_cls_data_collator(self,
1304
+ batch: List[Dict[str, Any]],
1305
+ *,
1306
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1307
+ labels = [b.pop('labels') for b in batch if b.get('labels') is not None]
1308
+ res = self._data_collator(batch, padding_to=padding_to)
1309
+ if labels:
1310
+ problem_type = self._get_problem_type(self.config)
1311
+ if problem_type == 'regression':
1312
+ labels = torch.tensor(labels, dtype=torch.float32)
1313
+ elif problem_type == 'multi_label_classification':
1314
+ one_hot_labels = torch.zeros((len(labels), self.config.num_labels), dtype=torch.float32)
1315
+ for i, label in enumerate(labels):
1316
+ one_hot_labels[i, label] = 1
1317
+ labels = one_hot_labels
1318
+ else:
1319
+ labels = torch.tensor(labels, dtype=torch.long)
1320
+ res['labels'] = labels
1321
+ return res
1322
+
1323
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
1324
+ """
1325
+ Args:
1326
+ batch(`List[Dict[str, Any]]`): The input data in batch
1327
+ padding_to(`int`, optional): Pad the batch to this fixed length; if None, the batch
1328
+ will be padded to the longest sequence in the batch.
1329
+ """
1330
+ assert self.tokenizer.pad_token_id is not None
1331
+ padding_side = self.padding_side if self.is_training else 'left'
1332
+ padding_right = padding_side == 'right'
1333
+ packing_mode = self.use_megatron or self._packing and 'position_ids' in batch[0]
1334
+ res = {}
1335
+ if packing_mode:
1336
+ # only support llm
1337
+ for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
1338
+ v = self.gather_list(batch, k)
1339
+ if v:
1340
+ res[k] = [v]
1341
+ else:
1342
+ inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
1343
+ input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
1344
+ if inputs_embeds:
1345
+ res['inputs_embeds'] = inputs_embeds
1346
+ if input_ids:
1347
+ res['input_ids'] = input_ids
1348
+ for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
1349
+ val = [b[key] for b in batch if b.get(key) is not None]
1350
+ if val:
1351
+ res[key] = val
1352
+
1353
+ keys = [
1354
+ 'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
1355
+ ]
1356
+ pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
1357
+ # Convert to tensor and remove unnecessary dimensions.
1358
+ seq_lens = None
1359
+ for key in keys:
1360
+ if key not in res:
1361
+ continue
1362
+ for i, val in enumerate(res[key]):
1363
+ if isinstance(val, (list, tuple)):
1364
+ val = torch.tensor(val)
1365
+ elif key == 'inputs_embeds' and val.ndim == 3 or key != 'inputs_embeds' and val.ndim == 2:
1366
+ val = val[0]
1367
+ res[key][i] = val
1368
+ if not seq_lens:
1369
+ seq_lens = [seq.shape[0] for seq in res[key]]
1370
+ if not packing_mode and seq_lens and ('input_ids' in res or 'inputs_embeds' in res):
1371
+ res['attention_mask'] = [torch.ones(seq_len, dtype=torch.int64) for seq_len in seq_lens]
1372
+ if self.is_training and self.padding_side == 'left':
1373
+ res['position_ids'] = [torch.arange(seq_len, dtype=torch.int64) for seq_len in seq_lens]
1374
+
1375
+ if self.use_megatron:
1376
+ padding_to = math.ceil(max(seq_lens) / 128) * 128
1377
+ cp_size = self.sequence_parallel_size
1378
+ if cp_size > 1:
1379
+ padding_len = padding_to - seq_lens[0]
1380
+ position_ids = res['position_ids'][0].tolist()
1381
+ position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2))
1382
+ res['position_ids'][0] = torch.tensor(position_ids)
1383
+
1384
+ for key, pad_value in zip(keys, pad_values):
1385
+ if key not in res:
1386
+ continue
1387
+ if self.use_megatron and key == 'position_ids' and self.sequence_parallel_size > 1:
1388
+ pass
1389
+ elif padding_to is not None:
1390
+ padding_len = padding_to - seq_lens[0]
1391
+ if padding_len > 0:
1392
+ res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
1393
+ 'constant', pad_value)
1394
+ res[key] = self._pad_sequence(res[key], pad_value)
1395
+
1396
+ # multimodal
1397
+ res.update(self._data_collator_mm_data(batch))
1398
+ if not self.use_megatron and (use_torchacc() or self.sequence_parallel_size > 1):
1399
+ res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)
1400
+
1401
+ return res
1402
+
1403
+ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
1404
+ # multimodal
1405
+ res = {}
1406
+ pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
1407
+ if len(pixel_values) > 0:
1408
+ res['pixel_values'] = torch.concat(pixel_values)
1409
+
1410
+ image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
1411
+ if len(image_sizes) > 0:
1412
+ res['image_sizes'] = torch.concat(image_sizes)
1413
+
1414
+ pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
1415
+ if len(pixel_values_videos) > 0:
1416
+ res['pixel_values_videos'] = torch.concat(pixel_values_videos)
1417
+ return res
1418
+
1419
+ def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_side):
1420
+ # torchacc & xtuner
1421
+ input_ids = res.get('input_ids')
1422
+ attention_mask = res.get('attention_mask')
1423
+ labels = res.get('labels')
1424
+ loss_scale = res.get('loss_scale')
1425
+ if use_torchacc():
1426
+ from swift.utils.torchacc_utils import pad_and_split_batch
1427
+ rank, _, world_size, _ = get_dist_setting()
1428
+ input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
1429
+ padding_to,
1430
+ input_ids,
1431
+ attention_mask,
1432
+ labels,
1433
+ loss_scale,
1434
+ self.max_length,
1435
+ tokenizer,
1436
+ rank,
1437
+ world_size,
1438
+ padding_right=padding_side == 'right')
1439
+ if self.sequence_parallel_size > 1 and input_ids is not None:
1440
+ bs, seq_len = input_ids.shape
1441
+ if 'position_ids' not in res:
1442
+ position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
1443
+ else:
1444
+ position_ids = res['position_ids']
1445
+ assert padding_side == 'right' or bs == 1, 'Sequence parallel only support padding_side=right'
1446
+ from swift.trainers.sequence_parallel import sequence_parallel
1447
+ if sequence_parallel.world_size() > 1:
1448
+ from swift.trainers.sequence_parallel import sequence_parallel
1449
+ input_ids, _, labels, position_ids, attention_mask, loss_scale = \
1450
+ sequence_parallel.pad_and_split_inputs(
1451
+ tokenizer, input_ids, None, labels, position_ids, attention_mask, loss_scale)
1452
+ res['position_ids'] = position_ids
1453
+ _local_var = locals()
1454
+ for key in ['input_ids', 'attention_mask', 'labels', 'loss_scale']:
1455
+ value = _local_var[key]
1456
+ if value is not None:
1457
+ res[key] = value
1458
+ return res
1459
+
1460
+ def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
1461
+ if tokenizer_kwargs is None:
1462
+ tokenizer_kwargs = {}
1463
+ for key in [
1464
+ 'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
1465
+ ]:
1466
+ val = inputs.get(key) # fix val is a tensor
1467
+ if val is None:
1468
+ val = inputs.get(f'{key}_ids')
1469
+ if val is not None:
1470
+ key_upper = key.upper()
1471
+ logger.info(f'[{key_upper}_IDS] {val}')
1472
+ if key == 'labels' and self.mode in {'seq_cls', 'embedding'}:
1473
+ continue
1474
+ if isinstance(val, (list, tuple, torch.Tensor)):
1475
+ val_str = self.safe_decode(val, **tokenizer_kwargs)
1476
+ logger.info(f'[{key_upper}] {val_str}')
1477
+ if inputs.get('loss_scale') is not None:
1478
+ val = inputs['loss_scale']
1479
+ logger.info(f'[LOSS_SCALE] {val}')
1480
+
1481
+ async def prepare_lmdeploy_pytorch_inputs(self, inputs) -> None:
1482
+ images = inputs.pop('images', None) or []
1483
+ if len(images) == 0:
1484
+ return
1485
+ input_ids = inputs['input_ids']
1486
+ idx_list = findall(input_ids, -100)
1487
+ assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
1488
+ idx_list.insert(0, -1)
1489
+ new_input_ids = []
1490
+ for i in range(len(idx_list) - 1):
1491
+ new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
1492
+ images[i]['offset'] = len(new_input_ids)
1493
+ new_input_ids += [images[i]['image_token_id']] * images[i]['image_tokens']
1494
+ new_input_ids += input_ids[idx_list[-1] + 1:]
1495
+ inputs['input_ids'] = new_input_ids
1496
+ inputs['multimodal'] = images
1497
+
+    async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
+        images = inputs.pop('images', None) or []
+        if len(images) == 0:
+            return
+        from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
+        input_ids = inputs['input_ids']
+        idx_list = findall(input_ids, -100)
+        assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
+        idx_list.insert(0, -1)
+        new_input_ids = []
+        ranges = []
+        for i in range(len(idx_list) - 1):
+            _range = []
+            new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
+            _range.append(len(new_input_ids))
+            new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * images[i].shape[0]
+            _range.append(len(new_input_ids))
+            ranges.append(_range)
+        new_input_ids += input_ids[idx_list[-1] + 1:]
+        inputs['input_embeddings'] = [image.to('cpu') for image in images]
+        inputs['input_embedding_ranges'] = ranges
+        inputs['input_ids'] = new_input_ids
+
+    def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.) -> torch.Tensor:
+        """Pad sequences on the template's padding side.
+
+        Args:
+            sequences: The input sequences as tensors.
+            padding_value: The padding value.
+
+        Returns:
+            The padded tensor.
+        """
+        padding_side = self.padding_side if self.is_training else 'left'
+        padding_right = padding_side == 'right'
+        if padding_right:
+            return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
+
+        max_len = max([s.shape[0] for s in sequences])
+
+        padded_sequences = []
+        for seq in sequences:
+            pad_length = max_len - seq.shape[0]
+            pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
+            padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
+            padded_sequences.append(padded_seq)
+
+        return torch.stack(padded_sequences)
+
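The left-padding branch above pads each tensor's first dimension on the left, so shorter sequences line up at the right edge of the batch. A minimal standalone sketch of that logic (the helper name `left_pad` and the toy tensors are made up here):

```python
import torch
import torch.nn.functional as F


def left_pad(sequences, padding_value=0.0):
    # pad only the first (sequence) dimension, on the left, up to the longest sequence
    max_len = max(s.shape[0] for s in sequences)
    padded = []
    for seq in sequences:
        pad_length = max_len - seq.shape[0]
        pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
        padded.append(F.pad(seq, tuple(pad_tuple), 'constant', padding_value))
    return torch.stack(padded)


print(left_pad([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=-100))
# tensor([[   1,    2,    3],
#         [-100, -100,    4]])
```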
+    def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
+        if isinstance(self, Template):
+            tokenizer = self.tokenizer
+            placeholder_tokens = self.placeholder_tokens
+        else:
+            tokenizer = self
+            placeholder_tokens = []
+
+        def _is_special(token: int) -> bool:
+            if isinstance(token, float) or token < 0:
+                return True
+            return token in placeholder_tokens
+
+        if isinstance(input_ids, torch.Tensor):
+            input_ids = input_ids.tolist()
+        if len(input_ids) == 0:
+            return ''
+        result_str = ''
+        for i in range(len(input_ids)):
+            if i == 0:
+                if _is_special(input_ids[i]):
+                    s = 0
+                else:
+                    e = 0
+                continue
+            if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
+                s = i
+                result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
+            if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
+                e = i
+                result_str += f'[{input_ids[i - 1]} * {e - s}]'
+        if _is_special(input_ids[i]):
+            result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
+        else:
+            result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
+        return result_str
+
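`safe_decode` renders runs of negative or placeholder token ids as `[id * count]` markers instead of handing them to the tokenizer, which would otherwise choke on ids such as `-100`. A hedged sketch of calling it unbound with a plain tokenizer (the checkpoint name is only an example, and the exact decoded text depends on the tokenizer):

```python
from transformers import AutoTokenizer

from swift.llm.template.base import Template

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')  # example checkpoint
ids = tokenizer('Hello world')['input_ids'] + [-100] * 3

# Passing the tokenizer as `self` takes the non-Template branch above.
print(Template.safe_decode(tokenizer, ids))  # roughly: "Hello world[-100 * 3]"
```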
+    @staticmethod
+    @contextmanager
+    def _patch_flash_attention_forward(modeling_module, position_ids, use_new_func: bool = False):
+        _origin_flash_attention_forward = modeling_module._flash_attention_forward
+
+        def _flash_attention_forward(*args, **kwargs):
+            if use_new_func:
+                from transformers.modeling_flash_attention_utils import (_flash_attention_forward as
+                                                                         flash_attention_forward)
+                if args and isinstance(args[0], nn.Module):
+                    args = args[1:]
+                if 'is_causal' not in kwargs:
+                    kwargs['is_causal'] = True
+            else:
+                flash_attention_forward = _origin_flash_attention_forward
+            kwargs['position_ids'] = position_ids
+            return flash_attention_forward(*args, **kwargs)
+
+        modeling_module._flash_attention_forward = _flash_attention_forward
+        try:
+            yield
+        finally:
+            modeling_module._flash_attention_forward = _origin_flash_attention_forward
ms-swift/swift/llm/template/grounding.py ADDED
@@ -0,0 +1,91 @@
+import colorsys
+import itertools
+import os
+from copy import deepcopy
+from typing import Any, List, Literal
+
+import requests
+from modelscope.hub.utils.utils import get_cache_dir
+from PIL import Image, ImageDraw, ImageFont
+
+
+def _shuffle_colors(nums: List[Any]) -> List[Any]:
+    if len(nums) == 1:
+        return nums
+
+    mid = len(nums) // 2
+
+    left = nums[:mid]
+    right = nums[mid:]
+    left = _shuffle_colors(left)
+    right = _shuffle_colors(right)
+    new_nums = []
+    for x, y in zip(left, right):
+        new_nums += [x, y]
+    new_nums += left[len(right):] or right[len(left):]
+    return new_nums
+
+
+def generate_colors():
+    vs_combinations = [(v, s) for v, s in itertools.product([0.7, 0.3, 1], [0.7, 0.3, 1])]
+    colors = [colorsys.hsv_to_rgb(i / 16, s, v) for v, s in vs_combinations for i in _shuffle_colors(list(range(16)))]
+    colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]
+    return _shuffle_colors(colors)
+
+
+def download_file(url: str) -> str:
+    url = url.rstrip('/')
+    file_name = url.rsplit('/', 1)[-1]
+    cache_dir = os.path.join(get_cache_dir(), 'files')
+    os.makedirs(cache_dir, exist_ok=True)
+    req = requests.get(url)
+    file_path = os.path.join(cache_dir, file_name)
+    with open(file_path, 'wb') as f:
+        f.write(req.content)
+    return file_path
+
+
+colors = generate_colors()
+color_mapping = {}
+
+
+def _calculate_brightness(image, region: List[int]):
+    cropped_image = image.crop(region)
+    grayscale_image = cropped_image.convert('L')
+    pixels = list(grayscale_image.getdata())
+    average_brightness = sum(pixels) / len(pixels)
+    return average_brightness
+
+
+def draw_bbox(image: Image.Image,
+              ref: List[str],
+              bbox: List[List[int]],
+              norm_bbox: Literal['norm1000', 'none'] = 'norm1000'):
+    bbox = deepcopy(bbox)
+    font_path = 'https://modelscope.cn/models/Qwen/Qwen-VL-Chat/resolve/master/SimSun.ttf'
+    # norm bbox
+    for i, box in enumerate(bbox):
+        for i in range(len(box)):
+            box[i] = int(box[i])
+        if norm_bbox == 'norm1000':
+            box[0] = box[0] / 1000 * image.width
+            box[2] = box[2] / 1000 * image.width
+            box[1] = box[1] / 1000 * image.height
+            box[3] = box[3] / 1000 * image.height
+
+    draw = ImageDraw.Draw(image)
+    # draw bbox
+    assert len(ref) == len(bbox), f'len(refs): {len(ref)}, len(bboxes): {len(bbox)}'
+    for (left, top, right, bottom), box_ref in zip(bbox, ref):
+        if box_ref not in color_mapping:
+            color_mapping[box_ref] = colors[len(color_mapping) % len(colors)]
+        color = color_mapping[box_ref]
+        draw.rectangle([(left, top), (right, bottom)], outline=color, width=3)
+    # draw text
+    file_path = download_file(font_path)
+    font = ImageFont.truetype(file_path, 20)
+    for (left, top, _, _), box_ref in zip(bbox, ref):
+        brightness = _calculate_brightness(
+            image, [left, top, min(left + 100, image.width),
+                    min(top + 20, image.height)])
+        draw.text((left, top), box_ref, fill='white' if brightness < 128 else 'black', font=font)
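A usage sketch for `draw_bbox`; the image path and box coordinates below are made up, and `norm1000` means the boxes are given on a 0-1000 scale relative to the image size (the SimSun font is downloaded on first use):

```python
from PIL import Image

from swift.llm.template.grounding import draw_bbox

image = Image.open('example.jpg')  # hypothetical local image
draw_bbox(
    image,
    ref=['cat', 'dog'],                               # one label per box
    bbox=[[120, 80, 480, 620], [510, 90, 930, 700]],  # norm1000 coordinates
    norm_bbox='norm1000',
)
image.save('example_with_boxes.jpg')
```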
ms-swift/swift/llm/template/register.py ADDED
@@ -0,0 +1,57 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Dict, Literal, Optional
+
+from ..utils import Processor
+from .base import Template
+from .template_meta import TemplateMeta
+
+TEMPLATE_MAPPING: Dict[str, TemplateMeta] = {}
+
+
+def register_template(template_meta: TemplateMeta, *, exist_ok: bool = False) -> None:
+    template_type = template_meta.template_type
+    if not exist_ok and template_type in TEMPLATE_MAPPING:
+        raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.')
+    TEMPLATE_MAPPING[template_type] = template_meta
+
+
+def get_template(
+        template_type: str,
+        processor: Processor,
+        default_system: Optional[str] = None,
+        max_length: Optional[int] = None,
+        *,
+        use_chat_template: bool = True,
+        template_backend: Literal['swift', 'jinja'] = 'swift',
+        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
+        max_pixels: Optional[int] = None,  # h * w
+        agent_template: Optional[str] = None,
+        norm_bbox: Literal['norm1000', 'none', None] = None,
+        response_prefix: Optional[str] = None,
+        # train
+        padding_side: Literal['left', 'right'] = 'right',
+        loss_scale: str = 'default',
+        sequence_parallel_size: int = 1) -> 'Template':
+    template_meta = TEMPLATE_MAPPING[template_type]
+    template_cls = template_meta.template_cls
+    return template_cls(
+        processor,
+        template_meta,
+        default_system,
+        max_length,
+        use_chat_template=use_chat_template,
+        template_backend=template_backend,
+        truncation_strategy=truncation_strategy,
+        max_pixels=max_pixels,
+        agent_template=agent_template,
+        norm_bbox=norm_bbox,
+        response_prefix=response_prefix,
+        padding_side=padding_side,
+        loss_scale=loss_scale,
+        sequence_parallel_size=sequence_parallel_size,
+    )
+
+
+def get_template_meta(template_type: str) -> TemplateMeta:
+    return TEMPLATE_MAPPING[template_type]
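A minimal sketch of the lookup side of this registry. It assumes the package's template modules have already been imported, which is what registers the built-in templates, and the name 'qwen' is only an example:

```python
import swift.llm.template  # importing the package registers the built-in templates

from swift.llm.template.register import TEMPLATE_MAPPING, get_template_meta

print(sorted(TEMPLATE_MAPPING))   # every registered template_type
meta = get_template_meta('qwen')  # raises KeyError if the name is unknown
print(meta.template_cls)          # the Template subclass that get_template() instantiates
```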
ms-swift/swift/llm/template/template/__init__.py ADDED
@@ -0,0 +1,2 @@
+from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm,
+               minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)