novateur committed
Commit a7a406c · verified · 1 Parent(s): 0955071

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. swift/llm/dataset/preprocessor/extra.py +112 -0
  2. swift/llm/model/model/microsoft.py +234 -0
  3. swift/llm/model/model/minicpm.py +207 -0
  4. swift/llm/model/model/minimax.py +156 -0
  5. swift/llm/model/model/mistral.py +157 -0
  6. swift/llm/model/model/mllm.py +180 -0
  7. swift/llm/model/model/moonshot.py +47 -0
  8. swift/llm/model/model/mplug.py +142 -0
  9. swift/llm/model/model/openbuddy.py +85 -0
  10. swift/llm/model/model/qwen.py +899 -0
  11. swift/llm/model/model/skywork.py +89 -0
  12. swift/llm/model/model/stepfun.py +86 -0
  13. swift/llm/model/model/telechat.py +59 -0
  14. swift/llm/model/model/valley.py +82 -0
  15. swift/llm/model/model/yi.py +146 -0
  16. swift/llm/sampling/__init__.py +1 -0
  17. swift/llm/sampling/base.py +55 -0
  18. swift/llm/sampling/distill_sampler.py +148 -0
  19. swift/llm/sampling/mcts.py +400 -0
  20. swift/llm/sampling/sampling.py +106 -0
  21. swift/llm/sampling/utils.py +172 -0
  22. swift/llm/sampling/vanilla_sampler.py +216 -0
  23. swift/llm/template/__init__.py +10 -0
  24. swift/llm/template/base.py +1606 -0
  25. swift/llm/template/constant.py +194 -0
  26. swift/llm/template/grounding.py +91 -0
  27. swift/llm/template/register.py +57 -0
  28. swift/llm/template/template/__init__.py +2 -0
  29. swift/llm/template/template/deepseek.py +315 -0
  30. swift/llm/template/template/emu3.py +191 -0
  31. swift/llm/template/template/gemma.py +131 -0
  32. swift/llm/template/template/glm.py +293 -0
  33. swift/llm/template/template/idefics3.py +37 -0
  34. swift/llm/template/template/internlm.py +195 -0
  35. swift/llm/template/template/internvl.py +168 -0
  36. swift/llm/template/template/llama.py +213 -0
  37. swift/llm/template/template/llava.py +309 -0
  38. swift/llm/template/template/llm.py +274 -0
  39. swift/llm/template/template/megrez.py +93 -0
  40. swift/llm/template/template/microsoft.py +205 -0
  41. swift/llm/template/template/minicpm.py +229 -0
  42. swift/llm/template/template/minimax.py +112 -0
  43. swift/llm/template/template/mistral.py +61 -0
  44. swift/llm/template/template/molmo.py +68 -0
  45. swift/llm/template/template/moonshot.py +66 -0
  46. swift/llm/template/template/mplug.py +214 -0
  47. swift/llm/template/template/openbuddy.py +48 -0
  48. swift/llm/template/template/pixtral.py +59 -0
  49. swift/llm/template/template/qwen.py +671 -0
  50. swift/llm/template/template/stepfun.py +128 -0
swift/llm/dataset/preprocessor/extra.py ADDED
@@ -0,0 +1,112 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional

import numpy as np

from .core import ResponsePreprocessor


class GroundingMixin:
    """This class offers prompts to the grounding task"""
    task_type: Optional[str] = None

    _grounding_language_mixin = [0.8, 0.2]
    _grounding_prompts = {
        'grounding': {
            'en': [('<ref-object>', '<bbox>'), ('The positions of <ref-object> is', '<bbox>'),
                   ('Find the positions of <ref-object>', '<bbox>'), ('Where is <ref-object>', '<bbox>'),
                   ('Find <ref-object>', '<bbox>'), ('Show me <ref-object>', '<bbox>'),
                   ('Detect <ref-object>', '<bbox>'), ('Locate <ref-object>', '<bbox>'),
                   ('Tell me the location of <ref-object>', '<bbox>'), ('Give the location of <ref-object>', '<bbox>'),
                   ('Provide the bounding box coordinate of <ref-object>', '<bbox>')],
            'zh': [('<ref-object>', '<bbox>'), ('<ref-object>的位置在图片中', '<bbox>'), ('<ref-object>在图片中', '<bbox>'),
                   ('<ref-object>在', '<bbox>'), ('找到<ref-object>的位置', '<bbox>'), ('<ref-object>在哪里', '<bbox>'),
                   ('提供<ref-object>的坐标位置', '<bbox>')]
        },
        'caption': {
            'en': [
                ('<bbox>', '<ref-object>'),
                ('The object at position <bbox>', '<ref-object>'),
                ('This <bbox> is', '<ref-object>'),
                ('What is the object at <bbox>', '<ref-object>'),
                ('Describe <bbox>', '<ref-object>'),
                ('<bbox> is', '<ref-object>'),
                ('The bounding box coordinate <bbox> contains', '<ref-object>'),
            ],
            'zh': [
                ('<bbox>', '<ref-object>'),
                ('<bbox>是什么', '<ref-object>'),
                ('<bbox>的位置包含', '<ref-object>'),
                ('描述<bbox>', '<ref-object>'),
                ('<bbox>中是', '<ref-object>'),
                ('坐标<bbox>描述了什么', '<ref-object>'),
                ('描述<bbox>中的事物', '<ref-object>'),
            ]
        },
    }

    def construct_grounding_prompt(self):
        # TODO Only support one bbox to one object
        lang = np.random.choice(['en', 'zh'], p=[0.8, 0.2])
        prompts = GroundingMixin._grounding_prompts[self.task_type][lang]
        query, response = prompts[np.random.choice(range(len(prompts)))]
        return query, response


class TextGenerationPreprocessor(ResponsePreprocessor):

    def __init__(self,
                 *,
                 prompt: str,
                 query_tag: str = '{{QUERY}}',
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        self.query_tag = query_tag
        self.prompt = prompt
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row['query'] = self.prompt.replace(self.query_tag, row['query'])
        return super().preprocess(row)


class ClsGenerationPreprocessor(ResponsePreprocessor):

    def __init__(self,
                 labels: List[str],
                 *,
                 task: str,
                 is_pair_seq: bool = False,
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        self.labels = labels
        self.task = task
        self.is_pair_seq = is_pair_seq

        category = ', '.join(labels)
        self.sentence2_key = 'sentence2'
        self.label_key = 'label'
        if is_pair_seq:
            self.sentence_key = 'sentence1'
            inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}'
        else:
            self.sentence_key = 'sentence'
            inputs = 'Sentence: {sentence}'
        self.prompt = f"""Task: {task}
{inputs}
Category: {category}
Output:"""
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        label = row.pop(self.label_key, None)
        if label is None:
            return

        if self.is_pair_seq:
            query = self.prompt.format(sentence1=row.pop(self.sentence_key), sentence2=row.pop(self.sentence2_key))
        else:
            query = self.prompt.format(sentence=row.pop(self.sentence_key))
        row['query'] = query
        row['response'] = self.labels[int(label)]
        return super().preprocess(row)
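For readers skimming the diff: `ClsGenerationPreprocessor` turns a labeled classification row into a query/response pair by filling the prompt template built in `__init__`. Below is a minimal usage sketch, not part of this commit; the row contents and label names are invented, the import path is taken from the file above, and it assumes `ResponsePreprocessor.preprocess` accepts rows that already carry `query`/`response` keys.

```python
# Hypothetical usage sketch (illustration only, not from this commit).
from swift.llm.dataset.preprocessor.extra import ClsGenerationPreprocessor

preprocessor = ClsGenerationPreprocessor(
    ['negative', 'positive'],          # label id 0 -> 'negative', 1 -> 'positive'
    task='sentiment classification',
    is_pair_seq=False)

row = {'sentence': 'The movie was great.', 'label': 1}
# preprocess() pops 'sentence'/'label' and fills 'query'/'response'.
processed = preprocessor.preprocess(row)
print(processed['query'])      # Task / Sentence / Category / Output prompt
print(processed['response'])   # 'positive'
```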
swift/llm/model/model/microsoft.py ADDED
@@ -0,0 +1,234 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from types import MethodType
from typing import Any, Dict

from transformers import AutoConfig

from swift.llm import TemplateType
from swift.utils import get_device, get_env_args
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_ignore_check_imports, patch_output_clone
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo, use_submodel_func


def get_model_tokenizer_phi3_vision(model_dir: str,
                                    model_info: ModelInfo,
                                    model_kwargs: Dict[str, Any],
                                    load_model: bool = True,
                                    **kwargs):
    processor_kwargs = {}
    if 'num_crops' in kwargs:
        processor_kwargs['num_crops'] = get_env_args('num_crops', int, kwargs['num_crops'])
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, **processor_kwargs)
    model, tokenizer = get_model_tokenizer_with_flash_attn(
        model_dir, model_info, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs)

    if load_model:
        patch_output_clone(model.model.vision_embed_tokens.wte)

    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.phi3_vision,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3-vision-128k-instruct', 'microsoft/Phi-3-vision-128k-instruct'),
                Model('LLM-Research/Phi-3.5-vision-instruct', 'microsoft/Phi-3.5-vision-instruct'),
            ])
        ],
        TemplateType.phi3_vision,
        partial(get_model_tokenizer_phi3_vision, num_crops=4),
        architectures=['Phi3VForCausalLM'],
        model_arch=ModelArch.phi3_vision,
        requires=['transformers>=4.36'],
        tags=['vision'],
    ))


def get_model_tokenizer_phi4_multimodal(*args, **kwargs):
    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
    processor.audio_processor.audio_compression_rate = processor.audio_processor.compression_rate
    processor.audio_processor.audio_downsample_rate = processor.audio_processor.qformer_compression_rate
    processor.audio_processor.audio_feat_stride = processor.audio_processor.feat_stride
    del processor.audio_processor.feature_size
    del processor.audio_processor.sampling_rate
    del processor.audio_processor.padding_value
    del processor.__class__.chat_template
    processor.chat_template = None
    if model is not None:
        model.set_lora_adapter(['vision', 'speech'])
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.phi4_multimodal,
        [ModelGroup([
            Model('LLM-Research/Phi-4-multimodal-instruct', 'microsoft/Phi-4-multimodal-instruct'),
        ])],
        TemplateType.phi4_multimodal,
        get_model_tokenizer_phi4_multimodal,
        architectures=['Phi4MMForCausalLM'],
        model_arch=ModelArch.phi4_multimodal,
        requires=['transformers>=4.36,<4.49', 'backoff', 'soundfile'],
        tags=['vision', 'audio'],
    ))


def get_model_tokenizer_florence(model_dir: str,
                                 model_info: ModelInfo,
                                 model_kwargs: Dict[str, Any],
                                 load_model: bool = True,
                                 **kwargs):
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    model_config.vision_config.model_type = 'davit'  # fix merge-lora
    if model_kwargs['device_map'] == 'auto':
        model_kwargs['device_map'] = get_device()
    kwargs['model_config'] = model_config
    with patch_ignore_check_imports():
        model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

    if model is not None:
        model.vision_tower.enable_checkpoint = True
        use_submodel_func(model, 'language_model', ['generate', 'forward'])
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.florence,
        [
            # llama2
            ModelGroup([
                Model('AI-ModelScope/Florence-2-base-ft', 'microsoft/Florence-2-base-ft'),
                Model('AI-ModelScope/Florence-2-base', 'microsoft/Florence-2-base'),
                Model('AI-ModelScope/Florence-2-large', 'microsoft/Florence-2-large'),
                Model('AI-ModelScope/Florence-2-large-ft', 'microsoft/Florence-2-large-ft'),
            ]),
        ],
        TemplateType.florence,
        get_model_tokenizer_florence,
        architectures=['Florence2ForConditionalGeneration'],
        model_arch=ModelArch.florence,
        tags=['vision'],
    ))


def get_model_tokenizer_phi3_small(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)

    def rotary_emb(self, query_states, key_states, **kwargs):
        q_type = query_states.dtype
        k_type = key_states.dtype
        query_states, key_states = self.rotory_emb_origin(query_states, key_states, **kwargs)
        query_states = query_states.to(q_type)
        key_states = key_states.to(k_type)
        return query_states, key_states

    if model is not None:
        for i in range(32):
            re = model.model.layers[i].self_attn.rotary_emb
            re.rotory_emb_origin = re.forward
            re.forward = MethodType(rotary_emb, re)
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.phi3_small,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3-small-8k-instruct', 'microsoft/Phi-3-small-8k-instruct'),
                Model('LLM-Research/Phi-3-small-128k-instruct', 'microsoft/Phi-3-small-128k-instruct'),
            ]),
        ],
        TemplateType.phi3,
        get_model_tokenizer_phi3_small,
        architectures=['Phi3SmallForCausalLM'],
        model_arch=ModelArch.phi3_small,
        requires=['transformers>=4.36'],
    ))


def get_model_tokenizer_phi(model_dir: str,
                            model_info: ModelInfo,
                            model_kwargs: Dict[str, Any],
                            load_model: bool = True,
                            **kwargs):
    # TODO: check
    return get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)


register_model(
    ModelMeta(
        LLMModelType.phi2,
        [
            ModelGroup([
                Model('AI-ModelScope/phi-2', 'microsoft/phi-2'),
            ]),
        ],
        TemplateType.default,
        get_model_tokenizer_phi,
        architectures=['PhiForCausalLM'],
        model_arch=ModelArch.phi2,
    ))

register_model(
    ModelMeta(
        LLMModelType.phi3,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3-mini-4k-instruct', 'microsoft/Phi-3-mini-4k-instruct'),
                Model('LLM-Research/Phi-3-mini-128k-instruct', 'microsoft/Phi-3-mini-128k-instruct'),
                Model('LLM-Research/Phi-3-medium-4k-instruct', 'microsoft/Phi-3-medium-4k-instruct'),
                Model('LLM-Research/Phi-3-medium-128k-instruct', 'microsoft/Phi-3-medium-128k-instruct'),
                Model('LLM-Research/Phi-3.5-mini-instruct', 'microsoft/Phi-3.5-mini-instruct'),
            ]),
            ModelGroup(Model('LLM-Research/Phi-4-mini-instruct', 'microsoft/Phi-4-mini-instruct'))
        ],
        TemplateType.phi3,
        get_model_tokenizer_with_flash_attn,
        architectures=['Phi3ForCausalLM'],
        requires=['transformers>=4.36'],
        model_arch=ModelArch.phi3,
    ))

register_model(
    ModelMeta(
        LLMModelType.phi4,
        [
            ModelGroup([
                Model('LLM-Research/phi-4', 'microsoft/phi-4'),
            ]),
        ],
        TemplateType.phi4,
        get_model_tokenizer_with_flash_attn,
        architectures=['Phi3ForCausalLM'],
        requires=['transformers>=4.36'],
        model_arch=ModelArch.phi3,
    ))

register_model(
    ModelMeta(
        LLMModelType.phi3_moe,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3.5-MoE-instruct', 'microsoft/Phi-3.5-MoE-instruct'),
            ]),
        ],
        TemplateType.phi3,
        get_model_tokenizer_with_flash_attn,
        architectures=['PhiMoEForCausalLM'],
        requires=['transformers>=4.36'],
    ))
swift/llm/model/model/minicpm.py ADDED
@@ -0,0 +1,207 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from types import MethodType
from typing import Any, Dict

from transformers import AutoConfig
from transformers.utils import strtobool

from swift.llm import TemplateType
from swift.utils import get_env_args
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_device_map, patch_fixed_device, patch_output_clone
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo, use_submodel_func
from .deepseek import get_model_tokenizer_deepseek_moe

register_model(
    ModelMeta(
        LLMModelType.minicpm_moe,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-MoE-8x2B', 'openbmb/MiniCPM-MoE-8x2B'),
            ]),
        ],
        TemplateType.minicpm,
        get_model_tokenizer_deepseek_moe,
        architectures=['MiniCPMForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.36'],
    ))


def _patch_minicpmv_device_map(model) -> None:
    if not hasattr(model, 'hf_device_map') or len(model.hf_device_map.values()) == 1:
        return

    device = list(model.hf_device_map.values())[0]
    if hasattr(model, 'get_vision_embedding') and not hasattr(model, '_old_get_vision_embedding'):
        # minicpm-v-v2-chat; avoid double patching
        _old_get_vision_embedding = model.__class__.get_vision_embedding

        def _get_vision_embedding(self, pixel_values):
            if len(pixel_values) == 0:
                return _old_get_vision_embedding(self, pixel_values)
            output = _old_get_vision_embedding(self, pixel_values)
            return output.to(device=device)

        model.__class__._old_get_vision_embedding = _old_get_vision_embedding
        model.__class__.get_vision_embedding = _get_vision_embedding

    if hasattr(model, 'resampler'):  # minicpm-v-v2_5-chat
        patch_fixed_device(model.resampler, device)


def get_model_tokenizer_minicpmv(model_dir: str,
                                 model_info: ModelInfo,
                                 model_kwargs: Dict[str, Any],
                                 load_model: bool = True,
                                 **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if load_model:
        model.resampler.to(model_info.torch_dtype)  # fix float32
        _patch_minicpmv_device_map(model)
        func_list = ['generate', 'get_input_embeddings', 'forward']
        use_submodel_func(model, 'llm', func_list)
        if hasattr(model, 'get_slice_image_placeholder'):
            tokenizer.get_slice_image_placeholder = MethodType(model.get_slice_image_placeholder, tokenizer)
            tokenizer.transform = MethodType(model.transform, tokenizer)
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.minicpmv,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-V', 'openbmb/MiniCPM-V'),
                Model('OpenBMB/MiniCPM-V-2', 'openbmb/MiniCPM-V-2'),
            ], ),
        ],
        TemplateType.minicpmv,
        get_model_tokenizer_minicpmv,
        architectures=['MiniCPMV'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers<4.42'],
        tags=['vision'],
    ))


def get_model_tokenizer_minicpmv_2_x(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    version = kwargs.get('version')
    if version == 'o2.6':
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        model_config.init_tts = strtobool(get_env_args('init_tts', str, 'false'))
        model_config.init_audio = strtobool(get_env_args('init_audio', str, 'false'))
        kwargs['model_config'] = model_config
    with patch_device_map():
        model, tokenizer = get_model_tokenizer_minicpmv(
            model_dir, model_info, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs)
    if load_model:
        embedding = model.get_input_embeddings()
        patch_output_clone(embedding)

    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.minicpmv2_5,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-Llama3-V-2_5', 'openbmb/MiniCPM-Llama3-V-2_5'),
            ], ),
        ],
        TemplateType.minicpmv2_5,
        get_model_tokenizer_minicpmv_2_x,
        architectures=['MiniCPMV'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers>=4.36'],
        tags=['vision'],
    ))

register_model(
    ModelMeta(
        MLLMModelType.minicpmv2_6,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-V-2_6', 'openbmb/MiniCPM-V-2_6'),
            ], ),
        ],
        TemplateType.minicpmv2_6,
        get_model_tokenizer_minicpmv_2_x,
        architectures=['MiniCPMV'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers>=4.36', 'decord'],
        tags=['vision', 'video'],
    ))

register_model(
    ModelMeta(
        MLLMModelType.minicpmo2_6,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-o-2_6', 'openbmb/MiniCPM-o-2_6'),
            ]),
        ],
        TemplateType.minicpmo2_6,
        partial(get_model_tokenizer_minicpmv_2_x, version='o2.6'),
        architectures=['MiniCPMO'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers>=4.36', 'decord', 'soundfile'],
        tags=['vision', 'video', 'omni', 'audio'],
    ))

register_model(
    ModelMeta(
        LLMModelType.minicpm,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-2B-sft-fp32', 'openbmb/MiniCPM-2B-sft-fp32'),
                Model('OpenBMB/MiniCPM-2B-dpo-fp32', 'openbmb/MiniCPM-2B-dpo-fp32'),
                Model('OpenBMB/MiniCPM-1B-sft-bf16', 'openbmb/MiniCPM-1B-sft-bf16'),
            ], ),
        ],
        TemplateType.minicpm,
        get_model_tokenizer_with_flash_attn,
        architectures=['MiniCPMForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.36.0'],
    ))

register_model(
    ModelMeta(
        LLMModelType.minicpm_chatml,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-2B-128k', 'openbmb/MiniCPM-2B-128k'),
            ]),
        ],
        TemplateType.chatml,
        get_model_tokenizer_with_flash_attn,
        architectures=['MiniCPMForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.36'],
    ))

register_model(
    ModelMeta(
        LLMModelType.minicpm3,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM3-4B', 'openbmb/MiniCPM3-4B'),
            ]),
        ],
        TemplateType.chatml,
        get_model_tokenizer_with_flash_attn,
        architectures=['MiniCPM3ForCausalLM'],
        model_arch=ModelArch.deepseek_v2,
        requires=['transformers>=4.36'],
    ))
swift/llm/model/model/minimax.py ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
from transformers import AutoConfig, AutoProcessor
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from swift.utils import get_device, get_device_count, get_dist_setting, get_logger
from ..constant import LLMModelType, MLLMModelType
from ..patcher import patch_ignore_check_imports
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo

logger = get_logger()


def get_model_tokenizer_minimax_vl(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    logger.warn('NOTE: minimax-vl-01 model does not support training.')
    n_gpu = get_device_count()
    _, local_rank, _, local_world_size = get_dist_setting()
    device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['model_config'] = config
    if kwargs.get('attn_impl') == 'flash_attn':
        config.attn_type_list = [1] * len(config.attn_type_list)
    else:
        config.attn_type_list = [0] * len(config.attn_type_list)
    if 'quantization_config' in model_kwargs:
        quantization_config = model_kwargs['quantization_config']
        from transformers import QuantoConfig
        if isinstance(quantization_config, QuantoConfig):
            quantization_config.modules_to_not_convert = (
                [
                    'vision_tower',
                    'image_newline',
                    'multi_modal_projector',
                    'lm_head',
                    'embed_tokens',
                ] + [f'model.layers.{i}.coefficient' for i in range(config.text_config.num_hidden_layers)]
                + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.text_config.num_hidden_layers)])

    if len(device_ids) > 1:
        model_safetensors_index_path = os.path.join(model_dir, 'model.safetensors.index.json')
        with open(model_safetensors_index_path, 'r') as f:
            model_safetensors_index = json.load(f)
        weight_map = model_safetensors_index['weight_map']
        vision_map = {}
        for key, value in weight_map.items():
            if 'vision_tower' in key or 'image_newline' in key or 'multi_modal_projector' in key:
                new_key = key.replace('.weight', '').replace('.bias', '')
                if new_key not in vision_map:
                    vision_map[new_key] = value

        device_map = {
            'language_model.model.embed_tokens': get_device(device_ids[0]),
            'language_model.model.norm': get_device(device_ids[len(device_ids) - 1]),
            'language_model.lm_head': get_device(device_ids[len(device_ids) - 1])
        }
        for key, value in vision_map.items():
            device_map[key] = get_device(device_ids[0])
        device_map['vision_tower.vision_model.post_layernorm'] = get_device(device_ids[0])
        layers_per_device = config.text_config.num_hidden_layers // len(device_ids)
        for i in range(len(device_ids)):
            for j in range(layers_per_device):
                device_map[f'language_model.model.layers.{i * layers_per_device + j}'] = get_device(device_ids[i])
        model_kwargs['device_map'] = device_map

    MiniMaxVL01ProcessorKwargs = get_class_from_dynamic_module('processing_minimax_vl_01.MiniMaxVL01ProcessorKwargs',
                                                               model_dir)
    get_hw_multiple_of = get_class_from_dynamic_module('processing_minimax_vl_01.get_hw_multiple_of', model_dir)
    get_num_token = get_class_from_dynamic_module('processing_minimax_vl_01.get_num_token', model_dir)

    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    processor.MiniMaxVL01ProcessorKwargs = MiniMaxVL01ProcessorKwargs
    processor.get_hw_multiple_of = get_hw_multiple_of
    processor.get_num_token = get_num_token
    with patch_ignore_check_imports():
        model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
                                                               **kwargs)
    processor.tokenizer = tokenizer
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.minimax_vl, [
            ModelGroup([
                Model('MiniMax/MiniMax-VL-01', 'MiniMaxAI/MiniMax-VL-01'),
            ]),
        ],
        TemplateType.minimax_vl,
        get_model_tokenizer_minimax_vl,
        architectures=['MiniMaxVL01ForConditionalGeneration'],
        tags=['vision']))


def get_model_tokenizer_minimax_text(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    logger.warn('NOTE: minimax-text-01 model does not support training.')
    n_gpu = get_device_count()
    _, local_rank, _, local_world_size = get_dist_setting()
    device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['model_config'] = config
    if kwargs.get('attn_impl') == 'flash_attn':
        config.attn_type_list = [1] * len(config.attn_type_list)
    else:
        config.attn_type_list = [0] * len(config.attn_type_list)
    if 'quantization_config' in model_kwargs:
        quantization_config = model_kwargs['quantization_config']
        from transformers import QuantoConfig
        if isinstance(quantization_config, QuantoConfig):
            quantization_config.modules_to_not_convert = (
                [
                    'lm_head',
                    'embed_tokens',
                ] + [f'model.layers.{i}.coefficient' for i in range(config.num_hidden_layers)]
                + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.num_hidden_layers)])

    if len(device_ids) > 1:
        layers_per_device = config.num_hidden_layers // len(device_ids)
        # set device map
        device_map = {
            'model.embed_tokens': get_device(0),
            'model.norm': get_device(len(device_ids) - 1),
            'lm_head': get_device(len(device_ids) - 1)
        }
        for i in range(len(device_ids)):
            for j in range(layers_per_device):
                device_map[f'model.layers.{i * layers_per_device + j}'] = get_device(i)
        model_kwargs['device_map'] = device_map
    with patch_ignore_check_imports():
        model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
                                                               **kwargs)
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.minimax, [
            ModelGroup([
                Model('MiniMax/MiniMax-Text-01', 'MiniMaxAI/MiniMax-Text-01'),
            ]),
        ],
        TemplateType.minimax,
        get_model_tokenizer_minimax_text,
        architectures=['MiniMaxText01ForCausalLM']))
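Both MiniMax loaders above shard the decoder layers evenly across the visible devices by integer division before delegating to `get_model_tokenizer_with_flash_attn`. The standalone sketch below illustrates only that layout arithmetic; the layer count and device count are made-up numbers, and plain `cuda:i` strings stand in for the `get_device` helper used in the file.

```python
# Hypothetical illustration of the layer-sharding arithmetic used above.
num_hidden_layers = 80          # made-up value for illustration
device_ids = [0, 1, 2, 3]       # made-up visible devices

layers_per_device = num_hidden_layers // len(device_ids)  # 20
device_map = {
    'model.embed_tokens': 'cuda:0',
    'model.norm': f'cuda:{len(device_ids) - 1}',
    'lm_head': f'cuda:{len(device_ids) - 1}',
}
for i in range(len(device_ids)):
    for j in range(layers_per_device):
        device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{i}'

# Result: layers 0-19 -> cuda:0, 20-39 -> cuda:1, 40-59 -> cuda:2, 60-79 -> cuda:3,
# with the embedding on the first device and norm/lm_head on the last.
```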
swift/llm/model/model/mistral.py ADDED
@@ -0,0 +1,157 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from swift.llm import TemplateType
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo

register_model(
    ModelMeta(
        LLMModelType.mistral,
        [
            ModelGroup([
                Model('AI-ModelScope/Mistral-7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1'),
                Model('AI-ModelScope/Mistral-7B-Instruct-v0.2', 'mistralai/Mistral-7B-Instruct-v0.2'),
                Model('LLM-Research/Mistral-7B-Instruct-v0.3', 'mistralai/Mistral-7B-Instruct-v0.3'),
                Model('AI-ModelScope/Mistral-7B-v0.1', 'mistralai/Mistral-7B-v0.1'),
                Model('AI-ModelScope/Mistral-7B-v0.2-hf', 'alpindale/Mistral-7B-v0.2-hf'),
            ]),
            ModelGroup([
                Model('swift/Codestral-22B-v0.1', 'mistralai/Codestral-22B-v0.1'),
            ]),
        ],
        TemplateType.llama,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.34'],
    ))

register_model(
    ModelMeta(
        LLMModelType.mixtral, [
            ModelGroup([
                Model('AI-ModelScope/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mixtral-8x7B-Instruct-v0.1'),
                Model('AI-ModelScope/Mixtral-8x7B-v0.1', 'mistralai/Mixtral-8x7B-v0.1'),
                Model('AI-ModelScope/Mixtral-8x22B-v0.1', 'mistral-community/Mixtral-8x22B-v0.1'),
            ],
                       requires=['transformers>=4.36']),
            ModelGroup([
                Model('AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf', 'ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf'),
            ],
                       requires=['transformers>=4.38', 'aqlm', 'torch>=2.2.0']),
        ],
        TemplateType.llama,
        get_model_tokenizer_with_flash_attn,
        architectures=['MixtralForCausalLM'],
        model_arch=ModelArch.llama))

register_model(
    ModelMeta(
        LLMModelType.mistral_nemo, [
            ModelGroup([
                Model('AI-ModelScope/Mistral-Small-Instruct-2409', 'mistralai/Mistral-Small-Instruct-2409'),
                Model('LLM-Research/Mistral-Large-Instruct-2407', 'mistralai/Mistral-Large-Instruct-2407'),
                Model('AI-ModelScope/Mistral-Nemo-Base-2407', 'mistralai/Mistral-Nemo-Base-2407'),
                Model('AI-ModelScope/Mistral-Nemo-Instruct-2407', 'mistralai/Mistral-Nemo-Instruct-2407'),
            ],
                       requires=['transformers>=4.43']),
            ModelGroup([
                Model('AI-ModelScope/Ministral-8B-Instruct-2410', 'mistralai/Ministral-8B-Instruct-2410'),
            ],
                       requires=['transformers>=4.46']),
        ],
        TemplateType.mistral_nemo,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        model_arch=ModelArch.llama))

register_model(
    ModelMeta(
        LLMModelType.mistral_2501, [
            ModelGroup([
                Model('mistralai/Mistral-Small-24B-Base-2501', 'mistralai/Mistral-Small-24B-Base-2501'),
                Model('mistralai/Mistral-Small-24B-Instruct-2501', 'mistralai/Mistral-Small-24B-Instruct-2501'),
            ]),
        ],
        TemplateType.mistral_2501,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        model_arch=ModelArch.llama))

register_model(
    ModelMeta(
        LLMModelType.zephyr,
        [
            ModelGroup([
                Model('modelscope/zephyr-7b-beta', 'HuggingFaceH4/zephyr-7b-beta'),
            ]),
        ],
        TemplateType.zephyr,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['MistralForCausalLM'],
        requires=['transformers>=4.34'],
    ))

register_model(
    ModelMeta(
        LLMModelType.wizardlm2_moe,
        [ModelGroup([
            Model('AI-ModelScope/WizardLM-2-8x22B', 'alpindale/WizardLM-2-8x22B'),
        ])],
        TemplateType.wizardlm2_moe,
        get_model_tokenizer_with_flash_attn,
        architectures=['MixtralForCausalLM'],
        requires=['transformers>=4.36'],
    ))

register_model(
    ModelMeta(
        LLMModelType.wizardlm2,
        [ModelGroup([
            Model('AI-ModelScope/WizardLM-2-7B-AWQ', 'MaziyarPanahi/WizardLM-2-7B-AWQ'),
        ])],
        TemplateType.wizardlm2,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        requires=['transformers>=4.34'],
    ))


def get_model_tokenizer_mistral_2503(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    try:
        from transformers import Mistral3ForConditionalGeneration
    except ImportError:
        raise ImportError('Please install Mistral3ForConditionalGeneration by running '
                          '`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`')

    kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.mistral_2503,
        [
            ModelGroup([
                Model('mistralai/Mistral-Small-3.1-24B-Base-2503', 'mistralai/Mistral-Small-3.1-24B-Base-2503'),
                Model('mistralai/Mistral-Small-3.1-24B-Instruct-2503', 'mistralai/Mistral-Small-3.1-24B-Instruct-2503'),
            ]),
        ],
        TemplateType.mistral_2503,
        get_model_tokenizer_mistral_2503,
        architectures=['Mistral3ForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.49'],
    ), )
swift/llm/model/model/mllm.py ADDED
@@ -0,0 +1,180 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from types import MethodType
from typing import Any, Dict

import torch
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from swift.llm.model.model.qwen import get_model_tokenizer_qwen2_vl
from swift.utils import get_logger
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_output_clone, patch_output_normalizer
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo, use_submodel_func

logger = get_logger()


def get_model_tokenizer_idefics(model_dir: str, *args, **kwargs):
    from transformers import AutoModelForVision2Seq
    kwargs['automodel_class'] = kwargs['automodel_class'] or AutoModelForVision2Seq
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.idefics3,
        [
            ModelGroup([
                Model('AI-ModelScope/Idefics3-8B-Llama3', 'HuggingFaceM4/Idefics3-8B-Llama3'),
            ]),
        ],
        TemplateType.idefics3,
        get_model_tokenizer_idefics,
        model_arch=ModelArch.idefics3,
        architectures=['Idefics3ForConditionalGeneration'],
        tags=['vision'],
        requires=['transformers>=4.45'],
    ))


def get_model_tokenizer_pixtral(model_dir: str, *args, **kwargs):
    from transformers import LlavaForConditionalGeneration
    kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaForConditionalGeneration
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.pixtral,
        [
            ModelGroup([
                Model('AI-ModelScope/pixtral-12b', 'mistral-community/pixtral-12b'),
            ]),
        ],
        TemplateType.pixtral,
        get_model_tokenizer_pixtral,
        model_arch=ModelArch.llava_hf,
        architectures=['LlavaForConditionalGeneration'],
        requires=['transformers>=4.45'],
        tags=['vision'],
    ))


def get_model_tokenizer_molmoe(model_dir: str,
                               model_info: ModelInfo,
                               model_kwargs: Dict[str, Any],
                               load_model: bool = True,
                               **kwargs):
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

    # fix bug for molmoe-1b
    def to_dict(self, *args, **kwargs):
        res = self._to_dict(*args, **kwargs)
        res['vision_backbone'] = self.vision_backbone.__dict__
        res.pop('to_dict')
        res.pop('_to_dict')
        return res

    if model is not None:
        model.config._to_dict = model.config.to_dict
        model.config.to_dict = MethodType(to_dict, model.config)
        patch_output_clone(model.model.transformer.wte)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.molmoe,
        [
            ModelGroup([
                Model('LLM-Research/MolmoE-1B-0924', 'allenai/MolmoE-1B-0924'),
            ]),
        ],
        TemplateType.molmo,
        get_model_tokenizer_molmoe,
        model_arch=ModelArch.molmo,
        torch_dtype=torch.float32,
        architectures=['OLMoForCausalLM'],
        tags=['vision'],
        requires=['transformers>=4.45'],
    ))


def get_model_tokenizer_molmo(model_dir: str,
                              model_info: ModelInfo,
                              model_kwargs: Dict[str, Any],
                              load_model: bool = True,
                              **kwargs):
    model_cls = get_class_from_dynamic_module('modeling_molmo.MolmoForCausalLM', model_dir)
    model_cls._no_split_modules = ['MolmoSequentialBlock']
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if model is not None:
        patch_output_clone(model.model.transformer.wte)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.molmo,
        [
            ModelGroup([
                Model('LLM-Research/Molmo-7B-O-0924', 'allenai/Molmo-7B-O-0924'),
                Model('LLM-Research/Molmo-7B-D-0924', 'allenai/Molmo-7B-D-0924'),
                Model('LLM-Research/Molmo-72B-0924', 'allenai/Molmo-72B-0924'),
            ]),
        ],
        TemplateType.molmo,
        get_model_tokenizer_molmo,
        model_arch=ModelArch.molmo,
        architectures=['MolmoForCausalLM'],
        tags=['vision'],
        requires=['transformers>=4.45'],
    ))


def get_model_tokenizer_megrez_omni(model_dir, *args, **kwargs):
    model_cls = get_class_from_dynamic_module('modeling_megrezo.MegrezO', model_dir)
    model_cls._no_split_modules = ['ResidualAttentionBlock', 'LlamaDecoderLayer']
    model_cls = get_class_from_dynamic_module('modeling_megrezo.SiglipVisionTransformer', model_dir)
    model_cls._no_split_modules = ['SiglipEncoderLayer']
    model, processor = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
    processor = model._get_or_init_processor()
    patch_output_clone(model.llm.model.embed_tokens)
    use_submodel_func(model, 'llm')
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.megrez_omni,
        [
            ModelGroup([
                Model('InfiniAI/Megrez-3B-Omni', 'Infinigence/Megrez-3B-Omni'),
            ]),
        ],
        TemplateType.megrez_omni,
        get_model_tokenizer_megrez_omni,
        model_arch=ModelArch.megrez_omni,
        architectures=['MegrezO'],
        tags=['vision', 'audio'],
    ))

register_model(
    ModelMeta(
        MLLMModelType.qwen2_gme, [
            ModelGroup([
                Model('iic/gme-Qwen2-VL-2B-Instruct', 'Alibaba-NLP/gme-Qwen2-VL-2B-Instruct'),
                Model('iic/gme-Qwen2-VL-7B-Instruct', 'Alibaba-NLP/gme-Qwen2-VL-7B-Instruct'),
            ]),
        ],
        TemplateType.qwen2_gme,
        get_model_tokenizer_qwen2_vl,
        model_arch=ModelArch.qwen2_vl,
        architectures=['Qwen2VLForConditionalGeneration'],
        tags=['vision']))
swift/llm/model/model/moonshot.py ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import TemplateType
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_output_clone
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)

register_model(
    ModelMeta(
        LLMModelType.moonlight,
        [
            ModelGroup([
                Model('moonshotai/Moonlight-16B-A3B', 'moonshotai/Moonlight-16B-A3B'),
                Model('moonshotai/Moonlight-16B-A3B-Instruct', 'moonshotai/Moonlight-16B-A3B-Instruct'),
            ]),
        ],
        TemplateType.moonlight,
        get_model_tokenizer_with_flash_attn,
        architectures=['DeepseekV3ForCausalLM'],
        model_arch=ModelArch.deepseek_v2,
        requires=['transformers<4.49'],
    ))


def get_model_tokenizer_kimi_vl(*args, **kwargs):
    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
    if model is not None:
        patch_output_clone(model.language_model.model.embed_tokens)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.kimi_vl,
        [
            ModelGroup([
                Model('moonshotai/Kimi-VL-A3B-Instruct', 'moonshotai/Kimi-VL-A3B-Instruct'),
                Model('moonshotai/Kimi-VL-A3B-Thinking', 'moonshotai/Kimi-VL-A3B-Thinking'),
            ])
        ],
        TemplateType.kimi_vl,
        get_model_tokenizer_kimi_vl,
        architectures=['KimiVLForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers<4.49'],
    ))
swift/llm/model/model/mplug.py ADDED
@@ -0,0 +1,142 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from functools import partial
from typing import Any, Dict

from transformers import AutoConfig
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from swift.utils import get_logger
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo, git_clone_github, use_submodel_func
from .qwen import get_model_tokenizer_qwen

logger = get_logger()


def get_model_tokenizer_mplug_owl2(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/X-PLUG/mPLUG-Owl')
    local_repo_path = os.path.join(local_repo_path, 'mPLUG-Owl2')
    sys.path.append(local_repo_path)

    # register
    # https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py#L447
    from mplug_owl2 import MPLUGOwl2LlamaForCausalLM
    from transformers.models.clip.image_processing_clip import CLIPImageProcessor
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    vocab_size = kwargs.pop('vocab_size', None)
    if vocab_size is not None:
        model_config.vocab_size = vocab_size
    get_model_tokenizer_function = kwargs.pop('get_model_tokenizer_function', get_model_tokenizer_with_flash_attn)
    model, tokenizer = get_model_tokenizer_function(
        model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs)
    logger.info('Please ignore the unimported warning.')
    processor = CLIPImageProcessor.from_pretrained(model_dir)
    processor.tokenizer = tokenizer
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.mplug_owl2, [ModelGroup([
            Model('iic/mPLUG-Owl2', 'MAGAer13/mplug-owl2-llama2-7b'),
        ])],
        TemplateType.mplug_owl2,
        get_model_tokenizer_mplug_owl2,
        model_arch=ModelArch.mplug_owl2,
        requires=['transformers<4.35', 'icecream'],
        tags=['vision']), )

register_model(
    ModelMeta(
        MLLMModelType.mplug_owl2_1, [ModelGroup([
            Model('iic/mPLUG-Owl2.1', 'Mizukiluke/mplug_owl_2_1'),
        ])],
        TemplateType.mplug_owl2,
        partial(
            get_model_tokenizer_mplug_owl2, vocab_size=151851, get_model_tokenizer_function=get_model_tokenizer_qwen),
        model_arch=ModelArch.mplug_owl2_1,
        requires=['transformers<4.35', 'icecream'],
        tags=['vision']))


def get_model_tokenizer_mplug_owl3(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    get_class_from_dynamic_module('configuration_hyper_qwen2.HyperQwen2Config', model_dir)
    model_cls = get_class_from_dynamic_module('modeling_mplugowl3.mPLUGOwl3Model', model_dir)
    model_cls._no_split_modules = ['SiglipEncoderLayer']
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    processor = model.init_processor(tokenizer)
    if model is not None:
        func_list = ['generate', 'forward']
        use_submodel_func(model, 'language_model', func_list)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.mplug_owl3, [
            ModelGroup([
                Model('iic/mPLUG-Owl3-1B-241014', 'mPLUG/mPLUG-Owl3-1B-241014'),
                Model('iic/mPLUG-Owl3-2B-241014', 'mPLUG/mPLUG-Owl3-2B-241014'),
                Model('iic/mPLUG-Owl3-7B-240728', 'mPLUG/mPLUG-Owl3-7B-240728'),
            ]),
        ],
        TemplateType.mplug_owl3,
        get_model_tokenizer_mplug_owl3,
        architectures=['mPLUGOwl3Model'],
        model_arch=ModelArch.mplug_owl3,
        requires=['transformers>=4.36', 'icecream', 'decord'],
        tags=['vision', 'video']))

register_model(
    ModelMeta(
        MLLMModelType.mplug_owl3_241101, [
            ModelGroup([
                Model('iic/mPLUG-Owl3-7B-241101', 'mPLUG/mPLUG-Owl3-7B-241101'),
            ]),
        ],
        TemplateType.mplug_owl3_241101,
        get_model_tokenizer_mplug_owl3,
        architectures=['mPLUGOwl3Model'],
        model_arch=ModelArch.mplug_owl3,
        requires=['transformers>=4.36', 'icecream'],
        tags=['vision', 'video']))


def get_model_tokenizer_doc_owl2(model_dir: str,
                                 model_info: ModelInfo,
                                 model_kwargs: Dict[str, Any],
                                 load_model: bool = True,
                                 **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    processor = model.init_processor(tokenizer, basic_image_size=504, crop_anchors='grid_12')
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.doc_owl2, [
            ModelGroup([
                Model('iic/DocOwl2', 'mPLUG/DocOwl2'),
            ]),
        ],
        TemplateType.doc_owl2,
        get_model_tokenizer_doc_owl2,
        architectures=['mPLUGDocOwl2'],
        model_arch=ModelArch.doc_owl2,
        requires=['transformers>=4.36', 'icecream'],
        tags=['vision']))
swift/llm/model/model/openbuddy.py ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import TemplateType
from swift.utils import get_logger
from ..constant import LLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model

logger = get_logger()

register_model(
    ModelMeta(
        LLMModelType.openbuddy_llama,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama-65b-v8-bf16', 'OpenBuddy/openbuddy-llama-65b-v8-bf16'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama2-13b-v8.1-fp16', 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16'),
                Model('OpenBuddy/openbuddy-llama2-70b-v10.1-bf16', 'OpenBuddy/openbuddy-llama2-70b-v10.1-bf16'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-deepseek-67b-v15.2', 'OpenBuddy/openbuddy-deepseek-67b-v15.2'),
            ]),
        ],
        TemplateType.openbuddy,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))

register_model(
    ModelMeta(
        LLMModelType.openbuddy_llama3,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama3-8b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-8b-v21.1-8k'),
                Model('OpenBuddy/openbuddy-llama3-70b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-70b-v21.1-8k'),
                Model('OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k', 'OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k', 'OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k'),
                Model('OpenBuddy/openbuddy-nemotron-70b-v23.2-131k', 'OpenBuddy/openbuddy-nemotron-70b-v23.2-131k'),
            ],
                       requires=['transformers>=4.43']),
            ModelGroup(
                [Model('OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k', 'OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k')],
                requires=['transformers>=4.45'])
        ],
        TemplateType.openbuddy2,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))

register_model(
    ModelMeta(
        LLMModelType.openbuddy_mistral,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-mistral-7b-v17.1-32k', 'OpenBuddy/openbuddy-mistral-7b-v17.1-32k'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-zephyr-7b-v14.1', 'OpenBuddy/openbuddy-zephyr-7b-v14.1'),
            ]),
        ],
        TemplateType.openbuddy,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        requires=['transformers>=4.34'],
        architectures=['MistralForCausalLM'],
    ))

register_model(
    ModelMeta(
        LLMModelType.openbuddy_mixtral,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k', 'OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k'),
            ], ),
        ],
        TemplateType.openbuddy,
        get_model_tokenizer_with_flash_attn,
        architectures=['MixtralForCausalLM'],
        requires=['transformers>=4.36'],
    ))
swift/llm/model/model/qwen.py ADDED
@@ -0,0 +1,899 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Optional, Tuple, Type

import torch
from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto.tokenization_auto import get_tokenizer_config

from swift.llm import TemplateType
from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger
from ..constant import LLMModelType, MLLMModelType, RMModelType
from ..model_arch import ModelArch
from ..patcher import patch_fixed_device, patch_output_clone, patch_output_to_input_device
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_reward_model,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import AttnImpl, ModelInfo, use_submodel_func

logger = get_logger()
dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}


def get_model_tokenizer_qwen(model_dir: str,
                             model_info: ModelInfo,
                             model_kwargs: Dict[str, Any],
                             load_model: bool = True,
                             model_config=None,
                             **kwargs):
    if model_config is None:
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    if model_info.torch_dtype is not None:
        k_true = dtype_mapping[model_info.torch_dtype]
        for k in dtype_mapping.values():
            setattr(model_config, k, k == k_true)

    quantization_config = model_kwargs.get('quantization_config')
    if not isinstance(quantization_config, BitsAndBytesConfig):
        # not bnb quant
        model_config.torch_dtype = None
    use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
    model_config.use_flash_attn = use_flash_attn
    kwargs['model_config'] = model_config
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    try:
        # fix mp+ddp bug
        model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
        logger.info('registered_causal_mask to cuda')
    except AttributeError:
        pass
    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = tokenizer.eod_id
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.qwen,
        [
            # qwen
            ModelGroup([
                # chat
                Model('Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-1_8B-Chat'),
                Model('Qwen/Qwen-7B-Chat', 'Qwen/Qwen-7B-Chat'),
                Model('Qwen/Qwen-14B-Chat', 'Qwen/Qwen-14B-Chat'),
                Model('Qwen/Qwen-72B-Chat', 'Qwen/Qwen-72B-Chat'),
                # base
                Model('Qwen/Qwen-1_8B', 'Qwen/Qwen-1_8B'),
                Model('Qwen/Qwen-7B', 'Qwen/Qwen-7B'),
                Model('Qwen/Qwen-14B', 'Qwen/Qwen-14B'),
                Model('Qwen/Qwen-72B', 'Qwen/Qwen-72B'),
                # gptq-int4
                Model('Qwen/Qwen-1_8B-Chat-Int4', 'Qwen/Qwen-1_8B-Chat-Int4'),
                Model('Qwen/Qwen-7B-Chat-Int4', 'Qwen/Qwen-7B-Chat-Int4'),
                Model('Qwen/Qwen-14B-Chat-Int4', 'Qwen/Qwen-14B-Chat-Int4'),
                Model('Qwen/Qwen-72B-Chat-Int4', 'Qwen/Qwen-72B-Chat-Int4'),
                # gptq-int8
                Model('Qwen/Qwen-1_8B-Chat-Int8', 'Qwen/Qwen-1_8B-Chat-Int8'),
                Model('Qwen/Qwen-7B-Chat-Int8', 'Qwen/Qwen-7B-Chat-Int8'),
                Model('Qwen/Qwen-14B-Chat-Int8', 'Qwen/Qwen-14B-Chat-Int8'),
                Model('Qwen/Qwen-72B-Chat-Int8', 'Qwen/Qwen-72B-Chat-Int8'),
            ]),
            # tongyi-finance
            ModelGroup([
                Model('TongyiFinance/Tongyi-Finance-14B-Chat', 'jxy/Tongyi-Finance-14B-Chat'),
                Model('TongyiFinance/Tongyi-Finance-14B'),
                Model('TongyiFinance/Tongyi-Finance-14B-Chat-Int4', 'jxy/Tongyi-Finance-14B-Chat-Int4'),
            ],
                       tags=['financial']),
        ],
        TemplateType.qwen,
        get_model_tokenizer_qwen,
        architectures=['QWenLMHeadModel'],
        model_arch=ModelArch.qwen))

register_model(
    ModelMeta(
        LLMModelType.modelscope_agent,
        [ModelGroup([
            Model('iic/ModelScope-Agent-7B'),
            Model('iic/ModelScope-Agent-14B'),
        ])],
        TemplateType.modelscope_agent,
        get_model_tokenizer_qwen,
        architectures=['QWenLMHeadModel'],
        model_arch=ModelArch.qwen))


def _qwen_vl_audio_decode(self, *args, skip_special_tokens=False, **kwargs) -> str:
    if skip_special_tokens:
        token_ids = kwargs['token_ids']
        while len(token_ids) > 0 and token_ids[-1] in {151645, 151643}:
            token_ids.pop()
        return self._old_decode(*args, skip_special_tokens=False, **kwargs)
    else:
        return self._old_decode(*args, skip_special_tokens=False, **kwargs)


def fix_qwen_inplace_bug(model) -> None:
    # qwen-vl, qwen-audio
    first_drop = model.transformer.drop
    if first_drop.p == 0.:
        # fix in-place operation bug
        patch_output_clone(first_drop)


def get_model_tokenizer_qwen_audio(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    tokenizer_config = get_tokenizer_config(model_dir)
    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
    tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
133
+ tokenizer_cls._auto_class = 'AutoTokenizer'
134
+ tokenizer_cls.AUDIO_ST = () # fix no attr `self.AUDIO_ST` bug
135
+ if not hasattr(tokenizer_cls, '_old_decode'):
136
+ tokenizer_cls._old_decode = tokenizer_cls._decode
137
+ tokenizer_cls._decode = _qwen_vl_audio_decode
138
+ kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
139
+ model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
140
+ if model is not None:
141
+ fix_qwen_inplace_bug(model)
142
+
143
+ return model, tokenizer
144
+
145
+
146
+ register_model(
147
+ ModelMeta(
148
+ MLLMModelType.qwen_audio, [
149
+ ModelGroup([
150
+ Model('Qwen/Qwen-Audio-Chat', 'Qwen/Qwen-Audio-Chat'),
151
+ Model('Qwen/Qwen-Audio', 'Qwen/Qwen-Audio'),
152
+ ])
153
+ ],
154
+ TemplateType.qwen_audio,
155
+ get_model_tokenizer_qwen_audio,
156
+ model_arch=ModelArch.qwen_audio,
157
+ architectures=['QWenLMHeadModel'],
158
+ additional_saved_files=['mel_filters.npz'],
159
+ tags=['audio']))
160
+
161
+
162
+ def _qwen_vl_visual_block_forward(
163
+ self,
164
+ q_x: torch.Tensor,
165
+ k_x: Optional[torch.Tensor] = None,
166
+ v_x: Optional[torch.Tensor] = None,
167
+ attn_mask: Optional[torch.Tensor] = None,
168
+ ):
169
+ k_x = self.ln_1_kv(k_x) if hasattr(self, 'ln_1_kv') and k_x is not None else None
170
+ v_x = self.ln_1_kv(v_x) if hasattr(self, 'ln_1_kv') and v_x is not None else None
171
+
172
+ x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
173
+ z = self.mlp(self.ln_2(x))
174
+ x = x.to(z.device) + z # FIX
175
+ return x
176
+
177
+
178
+ def get_model_tokenizer_qwen_vl(model_dir: str,
179
+ model_info: ModelInfo,
180
+ model_kwargs: Dict[str, Any],
181
+ load_model: bool = True,
182
+ **kwargs):
183
+ if (model_kwargs.get('quantization_config') is not None
184
+ and isinstance(model_kwargs['quantization_config'], BitsAndBytesConfig)):
185
+ # https://github.com/pytorch/pytorch/issues/58969
186
+ model_kwargs['quantization_config'].llm_int8_skip_modules = ['lm_head', 'attn_pool.attn']
187
+ _TransformerBlock = get_class_from_dynamic_module('visual.TransformerBlock', model_dir)
188
+
189
+ def _get_cast_dtype(self) -> torch.dtype:
190
+ return self.resblocks[0].ln_1.weight.dtype
191
+
192
+ _TransformerBlock.__old_get_cast_dtype = _TransformerBlock.get_cast_dtype
193
+ _TransformerBlock.get_cast_dtype = _get_cast_dtype
194
+
195
+ tokenizer_config = get_tokenizer_config(model_dir)
196
+ class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
197
+ tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
198
+ tokenizer_cls._auto_class = 'AutoTokenizer'
199
+ tokenizer_cls.IMAGE_ST = () # fix no attr `self.IMAGE_ST` bug
200
+ if not hasattr(tokenizer_cls, '_old_decode'):
201
+ tokenizer_cls._old_decode = tokenizer_cls._decode
202
+ tokenizer_cls._decode = _qwen_vl_audio_decode
203
+ # fix device_map is 4
204
+ n_gpu = get_device_count()
205
+ local_world_size = get_dist_setting()[3]
206
+ if n_gpu // local_world_size >= 4:
207
+ visual_block_cls = get_class_from_dynamic_module('visual.VisualAttentionBlock', model_dir)
208
+ visual_block_cls.__old_forward = visual_block_cls.forward
209
+ visual_block_cls.forward = _qwen_vl_visual_block_forward
210
+
211
+ kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
212
+ model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
213
+ if model is not None:
214
+ device_type = next(model.parameters()).device.type
215
+ fix_qwen_inplace_bug(model)
216
+ # fix device_map is 4
217
+ if n_gpu // local_world_size >= 4:
218
+ model.transformer.visual.proj.data = model.transformer.visual.proj.to(
219
+ model.transformer.visual.ln_post.bias.device)
220
+ # fix images cuda:1 bug
221
+ patch_fixed_device(model.transformer.visual, f'{device_type}:0')
222
+ return model, tokenizer
223
+
224
+
225
+ register_model(
226
+ ModelMeta(
227
+ MLLMModelType.qwen_vl, [
228
+ ModelGroup([
229
+ Model('Qwen/Qwen-VL-Chat', 'Qwen/Qwen-VL-Chat'),
230
+ Model('Qwen/Qwen-VL', 'Qwen/Qwen-VL'),
231
+ Model('Qwen/Qwen-VL-Chat-Int4', 'Qwen/Qwen-VL-Chat-Int4'),
232
+ ])
233
+ ],
234
+ TemplateType.qwen_vl,
235
+ get_model_tokenizer_qwen_vl,
236
+ model_arch=ModelArch.qwen_vl,
237
+ architectures=['QWenLMHeadModel'],
238
+ additional_saved_files=['SimSun.ttf'],
239
+ tags=['vision']))
240
+
241
+ register_model(
242
+ ModelMeta(
243
+ LLMModelType.qwen2,
244
+ [
245
+ # qwen1.5
246
+ ModelGroup([
247
+ # chat
248
+ Model('Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-0.5B-Chat'),
249
+ Model('Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-1.8B-Chat'),
250
+ Model('Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-4B-Chat'),
251
+ Model('Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-7B-Chat'),
252
+ Model('Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen1.5-14B-Chat'),
253
+ Model('Qwen/Qwen1.5-32B-Chat', 'Qwen/Qwen1.5-32B-Chat'),
254
+ Model('Qwen/Qwen1.5-72B-Chat', 'Qwen/Qwen1.5-72B-Chat'),
255
+ Model('Qwen/Qwen1.5-110B-Chat', 'Qwen/Qwen1.5-110B-Chat'),
256
+ # base
257
+ Model('Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B'),
258
+ Model('Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B'),
259
+ Model('Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B'),
260
+ Model('Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B'),
261
+ Model('Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B'),
262
+ Model('Qwen/Qwen1.5-32B', 'Qwen/Qwen1.5-32B'),
263
+ Model('Qwen/Qwen1.5-72B', 'Qwen/Qwen1.5-72B'),
264
+ Model('Qwen/Qwen1.5-110B', 'Qwen/Qwen1.5-110B'),
265
+ # gptq-int4
266
+ Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4'),
267
+ Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4'),
268
+ Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int4'),
269
+ Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int4'),
270
+ Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int4'),
271
+ Model('Qwen/Qwen1.5-32B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-32B-Chat-GPTQ-Int4'),
272
+ Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int4'),
273
+ Model('Qwen/Qwen1.5-110B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-110B-Chat-GPTQ-Int4'),
274
+ # gptq-int8
275
+ Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8'),
276
+ Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8'),
277
+ Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int8'),
278
+ Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int8'),
279
+ Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int8'),
280
+ Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int8'),
281
+ # awq-int4
282
+ Model('Qwen/Qwen1.5-0.5B-Chat-AWQ', 'Qwen/Qwen1.5-0.5B-Chat-AWQ'),
283
+ Model('Qwen/Qwen1.5-1.8B-Chat-AWQ', 'Qwen/Qwen1.5-1.8B-Chat-AWQ'),
284
+ Model('Qwen/Qwen1.5-4B-Chat-AWQ', 'Qwen/Qwen1.5-4B-Chat-AWQ'),
285
+ Model('Qwen/Qwen1.5-7B-Chat-AWQ', 'Qwen/Qwen1.5-7B-Chat-AWQ'),
286
+ Model('Qwen/Qwen1.5-14B-Chat-AWQ', 'Qwen/Qwen1.5-14B-Chat-AWQ'),
287
+ Model('Qwen/Qwen1.5-32B-Chat-AWQ', 'Qwen/Qwen1.5-32B-Chat-AWQ'),
288
+ Model('Qwen/Qwen1.5-72B-Chat-AWQ', 'Qwen/Qwen1.5-72B-Chat-AWQ'),
289
+ Model('Qwen/Qwen1.5-110B-Chat-AWQ', 'Qwen/Qwen1.5-110B-Chat-AWQ'),
290
+ ]),
291
+ # code-qwen1.5
292
+ ModelGroup([
293
+ Model('Qwen/CodeQwen1.5-7B', 'Qwen/CodeQwen1.5-7B'),
294
+ Model('Qwen/CodeQwen1.5-7B-Chat', 'Qwen/CodeQwen1.5-7B-Chat'),
295
+ Model('Qwen/CodeQwen1.5-7B-Chat-AWQ', 'Qwen/CodeQwen1.5-7B-Chat-AWQ'),
296
+ ],
297
+ tags=['coding']),
298
+ # qwen2
299
+ ModelGroup([
300
+ # instruct
301
+ Model('Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-0.5B-Instruct'),
302
+ Model('Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-1.5B-Instruct'),
303
+ Model('Qwen/Qwen2-7B-Instruct', 'Qwen/Qwen2-7B-Instruct'),
304
+ Model('Qwen/Qwen2-72B-Instruct', 'Qwen/Qwen2-72B-Instruct'),
305
+ # base
306
+ Model('Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B'),
307
+ Model('Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B'),
308
+ Model('Qwen/Qwen2-7B', 'Qwen/Qwen2-7B'),
309
+ Model('Qwen/Qwen2-72B', 'Qwen/Qwen2-72B'),
310
+ # gptq-int4
311
+ Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4'),
312
+ Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4'),
313
+ Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int4'),
314
+ Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int4'),
315
+ # gptq-int8
316
+ Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8'),
317
+ Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8'),
318
+ Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int8'),
319
+ Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int8'),
320
+ # awq-int4
321
+ Model('Qwen/Qwen2-0.5B-Instruct-AWQ', 'Qwen/Qwen2-0.5B-Instruct-AWQ'),
322
+ Model('Qwen/Qwen2-1.5B-Instruct-AWQ', 'Qwen/Qwen2-1.5B-Instruct-AWQ'),
323
+ Model('Qwen/Qwen2-7B-Instruct-AWQ', 'Qwen/Qwen2-7B-Instruct-AWQ'),
324
+ Model('Qwen/Qwen2-72B-Instruct-AWQ', 'Qwen/Qwen2-72B-Instruct-AWQ'),
325
+ ]),
326
+ # qwen2-math
327
+ ModelGroup(
328
+ [
329
+ # instruct
330
+ Model('Qwen/Qwen2-Math-1.5B-Instruct', 'Qwen/Qwen2-Math-1.5B-Instruct'),
331
+ Model('Qwen/Qwen2-Math-7B-Instruct', 'Qwen/Qwen2-Math-7B-Instruct'),
332
+ Model('Qwen/Qwen2-Math-72B-Instruct', 'Qwen/Qwen2-Math-72B-Instruct'),
333
+ # base
334
+ Model('Qwen/Qwen2-Math-1.5B', 'Qwen/Qwen2-Math-1.5B'),
335
+ Model('Qwen/Qwen2-Math-7B', 'Qwen/Qwen2-Math-7B'),
336
+ Model('Qwen/Qwen2-Math-72B', 'Qwen/Qwen2-Math-72B'),
337
+ ],
338
+ tags=['math']),
339
+ # qwen2.5-1m
340
+ ModelGroup([
341
+ Model('Qwen/Qwen2.5-7B-Instruct-1M', 'Qwen/Qwen2.5-7B-Instruct-1M'),
342
+ Model('Qwen/Qwen2.5-14B-Instruct-1M', 'Qwen/Qwen2.5-14B-Instruct-1M'),
343
+ ]),
344
+ # other
345
+ ModelGroup([Model('PowerInfer/SmallThinker-3B-Preview', 'PowerInfer/SmallThinker-3B-Preview')]),
346
+ ],
347
+ TemplateType.qwen,
348
+ get_model_tokenizer_with_flash_attn,
349
+ architectures=['Qwen2ForCausalLM'],
350
+ requires=['transformers>=4.37'],
351
+ model_arch=ModelArch.llama))
352
+
353
+ register_model(
354
+ ModelMeta(
355
+ LLMModelType.qwen2_5,
356
+ [
357
+ # qwen2.5
358
+ ModelGroup([
359
+ # instruct
360
+ Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
361
+ Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
362
+ Model('Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-3B-Instruct'),
363
+ Model('Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'),
364
+ Model('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-14B-Instruct'),
365
+ Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
366
+ Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
367
+ # base
368
+ Model('Qwen/Qwen2.5-0.5B', 'Qwen/Qwen2.5-0.5B'),
369
+ Model('Qwen/Qwen2.5-1.5B', 'Qwen/Qwen2.5-1.5B'),
370
+ Model('Qwen/Qwen2.5-3B', 'Qwen/Qwen2.5-3B'),
371
+ Model('Qwen/Qwen2.5-7B', 'Qwen/Qwen2.5-7B'),
372
+ Model('Qwen/Qwen2.5-14B', 'Qwen/Qwen2.5-14B'),
373
+ Model('Qwen/Qwen2.5-32B', 'Qwen/Qwen2.5-32B'),
374
+ Model('Qwen/Qwen2.5-72B', 'Qwen/Qwen2.5-72B'),
375
+ # gptq-int4
376
+ Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4'),
377
+ Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4'),
378
+ Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4'),
379
+ Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4'),
380
+ Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4'),
381
+ Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4'),
382
+ Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4'),
383
+ # gptq-int8
384
+ Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8'),
385
+ Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8'),
386
+ Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8'),
387
+ Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8'),
388
+ Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8'),
389
+ Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8'),
390
+ Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8'),
391
+ # awq-int4
392
+ Model('Qwen/Qwen2.5-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-0.5B-Instruct-AWQ'),
393
+ Model('Qwen/Qwen2.5-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-1.5B-Instruct-AWQ'),
394
+ Model('Qwen/Qwen2.5-3B-Instruct-AWQ', 'Qwen/Qwen2.5-3B-Instruct-AWQ'),
395
+ Model('Qwen/Qwen2.5-7B-Instruct-AWQ', 'Qwen/Qwen2.5-7B-Instruct-AWQ'),
396
+ Model('Qwen/Qwen2.5-14B-Instruct-AWQ', 'Qwen/Qwen2.5-14B-Instruct-AWQ'),
397
+ Model('Qwen/Qwen2.5-32B-Instruct-AWQ', 'Qwen/Qwen2.5-32B-Instruct-AWQ'),
398
+ Model('Qwen/Qwen2.5-72B-Instruct-AWQ', 'Qwen/Qwen2.5-72B-Instruct-AWQ'),
399
+ ]),
400
+ # qwen2.5-coder
401
+ ModelGroup(
402
+ [
403
+ # instruct
404
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct', 'Qwen/Qwen2.5-Coder-0.5B-Instruct'),
405
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct', 'Qwen/Qwen2.5-Coder-1.5B-Instruct'),
406
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct', 'Qwen/Qwen2.5-Coder-3B-Instruct'),
407
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct', 'Qwen/Qwen2.5-Coder-7B-Instruct'),
408
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct', 'Qwen/Qwen2.5-Coder-14B-Instruct'),
409
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct', 'Qwen/Qwen2.5-Coder-32B-Instruct'),
410
+ # base
411
+ Model('Qwen/Qwen2.5-Coder-0.5B', 'Qwen/Qwen2.5-Coder-0.5B'),
412
+ Model('Qwen/Qwen2.5-Coder-1.5B', 'Qwen/Qwen2.5-Coder-1.5B'),
413
+ Model('Qwen/Qwen2.5-Coder-3B', 'Qwen/Qwen2.5-Coder-3B'),
414
+ Model('Qwen/Qwen2.5-Coder-7B', 'Qwen/Qwen2.5-Coder-7B'),
415
+ Model('Qwen/Qwen2.5-Coder-14B', 'Qwen/Qwen2.5-Coder-14B'),
416
+ Model('Qwen/Qwen2.5-Coder-32B', 'Qwen/Qwen2.5-Coder-32B'),
417
+ # AWQ
418
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ'),
419
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ'),
420
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-3B-Instruct-AWQ'),
421
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-7B-Instruct-AWQ'),
422
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-14B-Instruct-AWQ'),
423
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-32B-Instruct-AWQ'),
424
+ # GPTQ
425
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4'),
426
+ Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8'),
427
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4'),
428
+ Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8'),
429
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4'),
430
+ Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8'),
431
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4'),
432
+ Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8'),
433
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4'),
434
+ Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8'),
435
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4'),
436
+ Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8'),
437
+ ],
438
+ tags=['coding'])
439
+ ],
440
+ TemplateType.qwen2_5,
441
+ get_model_tokenizer_with_flash_attn,
442
+ architectures=['Qwen2ForCausalLM'],
443
+ requires=['transformers>=4.37'],
444
+ model_arch=ModelArch.llama))
445
+
446
+ register_model(
447
+ ModelMeta(
448
+ LLMModelType.qwen2_5_math,
449
+ [
450
+ # qwen2.5-math
451
+ ModelGroup(
452
+ [
453
+ # instruct
454
+ Model('Qwen/Qwen2.5-Math-1.5B-Instruct', 'Qwen/Qwen2.5-Math-1.5B-Instruct'),
455
+ Model('Qwen/Qwen2.5-Math-7B-Instruct', 'Qwen/Qwen2.5-Math-7B-Instruct'),
456
+ Model('Qwen/Qwen2.5-Math-72B-Instruct', 'Qwen/Qwen2.5-Math-72B-Instruct'),
457
+ # base
458
+ Model('Qwen/Qwen2.5-Math-1.5B', 'Qwen/Qwen2.5-Math-1.5B'),
459
+ Model('Qwen/Qwen2.5-Math-7B', 'Qwen/Qwen2.5-Math-7B'),
460
+ Model('Qwen/Qwen2.5-Math-72B', 'Qwen/Qwen2.5-Math-72B'),
461
+ ],
462
+ tags=['math']),
463
+ ],
464
+ TemplateType.qwen2_5_math,
465
+ get_model_tokenizer_with_flash_attn,
466
+ architectures=['Qwen2ForCausalLM'],
467
+ requires=['transformers>=4.37'],
468
+ model_arch=ModelArch.llama))
469
+
470
+ register_model(
471
+ ModelMeta(
472
+ LLMModelType.qwen2_moe,
473
+ [
474
+ # qwen1.5-moe
475
+ ModelGroup([
476
+ Model('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'Qwen/Qwen1.5-MoE-A2.7B-Chat'),
477
+ Model('Qwen/Qwen1.5-MoE-A2.7B', 'Qwen/Qwen1.5-MoE-A2.7B'),
478
+ Model('Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4'),
479
+ ]),
480
+ ModelGroup([
481
+ Model('Qwen/Qwen2-57B-A14B-Instruct', 'Qwen/Qwen2-57B-A14B-Instruct'),
482
+ Model('Qwen/Qwen2-57B-A14B', 'Qwen/Qwen2-57B-A14B'),
483
+ Model('Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4'),
484
+ ])
485
+ ],
486
+ TemplateType.qwen,
487
+ get_model_tokenizer_with_flash_attn,
488
+ architectures=['Qwen2MoeForCausalLM'],
489
+ requires=['transformers>=4.40'],
490
+ ))
491
+
492
+ register_model(
493
+ ModelMeta(
494
+ LLMModelType.qwen3,
495
+ [
496
+ ModelGroup([
497
+ Model('Qwen/Qwen3-0.6B-Base', 'Qwen/Qwen3-0.6B-Base'),
498
+ Model('Qwen/Qwen3-1.7B-Base', 'Qwen/Qwen3-1.7B-Base'),
499
+ Model('Qwen/Qwen3-4B-Base', 'Qwen/Qwen3-4B-Base'),
500
+ Model('Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-8B-Base'),
501
+ Model('Qwen/Qwen3-14B-Base', 'Qwen/Qwen3-14B-Base'),
502
+ Model('Qwen/Qwen3-32B-Base', 'Qwen/Qwen3-32B-Base'),
503
+ # instruct
504
+ Model('Qwen/Qwen3-0.6B', 'Qwen/Qwen3-0.6B'),
505
+ Model('Qwen/Qwen3-1.7B', 'Qwen/Qwen3-1.7B'),
506
+ Model('Qwen/Qwen3-4B', 'Qwen/Qwen3-4B'),
507
+ Model('Qwen/Qwen3-8B', 'Qwen/Qwen3-8B'),
508
+ Model('Qwen/Qwen3-14B', 'Qwen/Qwen3-14B'),
509
+ Model('Qwen/Qwen3-32B', 'Qwen/Qwen3-32B'),
510
+ # fp8
511
+ Model('Qwen/Qwen3-0.6B-FP8', 'Qwen/Qwen3-0.6B-FP8'),
512
+ Model('Qwen/Qwen3-1.7B-FP8', 'Qwen/Qwen3-1.7B-FP8'),
513
+ Model('Qwen/Qwen3-4B-FP8', 'Qwen/Qwen3-4B-FP8'),
514
+ Model('Qwen/Qwen3-8B-FP8', 'Qwen/Qwen3-8B-FP8'),
515
+ Model('Qwen/Qwen3-14B-FP8', 'Qwen/Qwen3-14B-FP8'),
516
+ Model('Qwen/Qwen3-32B-FP8', 'Qwen/Qwen3-32B-FP8'),
517
+ # awq
518
+ Model('Qwen/Qwen3-4B-AWQ', 'Qwen/Qwen3-4B-AWQ'),
519
+ Model('Qwen/Qwen3-8B-AWQ', 'Qwen/Qwen3-8B-AWQ'),
520
+ Model('Qwen/Qwen3-14B-AWQ', 'Qwen/Qwen3-14B-AWQ'),
521
+ Model('Qwen/Qwen3-32B-AWQ', 'Qwen/Qwen3-32B-AWQ'),
522
+ # swift
523
+ Model('swift/Qwen3-32B-AWQ'),
524
+ ]),
525
+ ],
526
+ TemplateType.qwen3,
527
+ get_model_tokenizer_with_flash_attn,
528
+ architectures=['Qwen3ForCausalLM'],
529
+ requires=['transformers>=4.51'],
530
+ model_arch=ModelArch.llama))
531
+
532
+ register_model(
533
+ ModelMeta(
534
+ LLMModelType.qwen3_moe,
535
+ [
536
+ ModelGroup([
537
+ Model('Qwen/Qwen3-30B-A3B-Base', 'Qwen/Qwen3-30B-A3B-Base'),
538
+ Model('Qwen/Qwen3-235B-A22B-Base', 'Qwen/Qwen3-235B-A22B-Base'),
539
+ # instruct
540
+ Model('Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-30B-A3B'),
541
+ Model('Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-235B-A22B'),
542
+ # fp8
543
+ Model('Qwen/Qwen3-30B-A3B-FP8', 'Qwen/Qwen3-30B-A3B-FP8'),
544
+ Model('Qwen/Qwen3-235B-A22B-FP8', 'Qwen/Qwen3-235B-A22B-FP8'),
545
+ # awq
546
+ Model('swift/Qwen3-30B-A3B-AWQ', 'cognitivecomputations/Qwen3-30B-A3B-AWQ'),
547
+ Model('swift/Qwen3-235B-A22B-AWQ', 'cognitivecomputations/Qwen3-235B-A22B-AWQ'),
548
+ ]),
549
+ ],
550
+ TemplateType.qwen3,
551
+ get_model_tokenizer_with_flash_attn,
552
+ architectures=['Qwen3MoeForCausalLM'],
553
+ requires=['transformers>=4.51'],
554
+ ))
555
+
556
+
557
+ def patch_qwen_vl_utils(vision_process):
558
+ if hasattr(vision_process, '_patch'):
559
+ return
560
+ for key in [
561
+ 'image_factor', 'min_pixels', 'max_pixels', 'max_ratio', 'video_min_pixels', 'video_max_pixels',
562
+ 'video_total_pixels', 'frame_factor', 'fps', 'fps_min_frames', 'fps_max_frames'
563
+ ]:
564
+ type_func = float if key == 'fps' else int
565
+ setattr(vision_process, key.upper(), get_env_args(key, type_func, getattr(vision_process, key.upper())))
566
+ _read_video_decord = vision_process._read_video_decord
567
+
568
+ def _new_read_video_decord(ele: dict):
569
+ from swift.llm import load_file
570
+ ele['video'] = load_file(ele['video'])
571
+ return _read_video_decord(ele)
572
+
573
+ vision_process.VIDEO_READER_BACKENDS['decord'] = _new_read_video_decord
574
+ vision_process._patch = True
575
+
576
+
577
+ def get_model_tokenizer_qwen2_vl(*args, **kwargs):
578
+ from transformers import Qwen2VLForConditionalGeneration
579
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
580
+ model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
581
+ if model is not None and hasattr(model.model, 'embed_tokens'):
582
+ patch_output_clone(model.model.embed_tokens)
583
+ patch_output_to_input_device(model.model.embed_tokens)
584
+
585
+ from qwen_vl_utils import vision_process
586
+ patch_qwen_vl_utils(vision_process)
587
+ return model, tokenizer
588
+
589
+
590
+ register_model(
591
+ ModelMeta(
592
+ MLLMModelType.qwen2_vl,
593
+ [
594
+ ModelGroup(
595
+ [
596
+ # chat
597
+ Model('Qwen/Qwen2-VL-2B-Instruct', 'Qwen/Qwen2-VL-2B-Instruct'),
598
+ Model('Qwen/Qwen2-VL-7B-Instruct', 'Qwen/Qwen2-VL-7B-Instruct'),
599
+ Model('Qwen/Qwen2-VL-72B-Instruct', 'Qwen/Qwen2-VL-72B-Instruct'),
600
+ # base
601
+ Model('Qwen/Qwen2-VL-2B', 'Qwen/Qwen2-VL-2B'),
602
+ Model('Qwen/Qwen2-VL-7B', 'Qwen/Qwen2-VL-7B'),
603
+ Model('Qwen/Qwen2-VL-72B', 'Qwen/Qwen2-VL-72B'),
604
+ # gptq-int4
605
+ Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4'),
606
+ Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4'),
607
+ Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4'),
608
+ # gptq-int8
609
+ Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8'),
610
+ Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8'),
611
+ Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8'),
612
+ # awq-int4
613
+ Model('Qwen/Qwen2-VL-2B-Instruct-AWQ', 'Qwen/Qwen2-VL-2B-Instruct-AWQ'),
614
+ Model('Qwen/Qwen2-VL-7B-Instruct-AWQ', 'Qwen/Qwen2-VL-7B-Instruct-AWQ'),
615
+ Model('Qwen/Qwen2-VL-72B-Instruct-AWQ', 'Qwen/Qwen2-VL-72B-Instruct-AWQ'),
616
+ ], ),
617
+ ModelGroup([
618
+ Model('bytedance-research/UI-TARS-2B-SFT', 'bytedance-research/UI-TARS-2B-SFT'),
619
+ Model('bytedance-research/UI-TARS-7B-SFT', 'bytedance-research/UI-TARS-7B-SFT'),
620
+ Model('bytedance-research/UI-TARS-7B-DPO', 'bytedance-research/UI-TARS-7B-DPO'),
621
+ Model('bytedance-research/UI-TARS-72B-SFT', 'bytedance-research/UI-TARS-72B-SFT'),
622
+ Model('bytedance-research/UI-TARS-72B-DPO', 'bytedance-research/UI-TARS-72B-DPO'),
623
+ ]),
624
+ ModelGroup([
625
+ Model('allenai/olmOCR-7B-0225-preview', 'allenai/olmOCR-7B-0225-preview'),
626
+ ]),
627
+ ],
628
+ TemplateType.qwen2_vl,
629
+ get_model_tokenizer_qwen2_vl,
630
+ model_arch=ModelArch.qwen2_vl,
631
+ architectures=['Qwen2VLForConditionalGeneration'],
632
+ requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
633
+ tags=['vision', 'video']))
634
+
635
+ register_model(
636
+ ModelMeta(
637
+ MLLMModelType.qvq, [
638
+ ModelGroup([
639
+ Model('Qwen/QVQ-72B-Preview', 'Qwen/QVQ-72B-Preview'),
640
+ ]),
641
+ ],
642
+ TemplateType.qvq,
643
+ get_model_tokenizer_qwen2_vl,
644
+ model_arch=ModelArch.qwen2_vl,
645
+ architectures=['Qwen2VLForConditionalGeneration'],
646
+ requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
647
+ tags=['vision', 'video']))
648
+
649
+
650
+ def get_model_tokenizer_qwen2_5_vl(*args, **kwargs):
651
+ from transformers import Qwen2_5_VLForConditionalGeneration
652
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2_5_VLForConditionalGeneration
653
+ return get_model_tokenizer_qwen2_vl(*args, **kwargs)
654
+
655
+
656
+ register_model(
657
+ ModelMeta(
658
+ MLLMModelType.qwen2_5_vl, [
659
+ ModelGroup([
660
+ Model('Qwen/Qwen2.5-VL-3B-Instruct', 'Qwen/Qwen2.5-VL-3B-Instruct'),
661
+ Model('Qwen/Qwen2.5-VL-7B-Instruct', 'Qwen/Qwen2.5-VL-7B-Instruct'),
662
+ Model('Qwen/Qwen2.5-VL-32B-Instruct', 'Qwen/Qwen2.5-VL-32B-Instruct'),
663
+ Model('Qwen/Qwen2.5-VL-72B-Instruct', 'Qwen/Qwen2.5-VL-72B-Instruct'),
664
+ ]),
665
+ ModelGroup([
666
+ Model('Qwen/Qwen2.5-VL-3B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-3B-Instruct-AWQ'),
667
+ Model('Qwen/Qwen2.5-VL-7B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-7B-Instruct-AWQ'),
668
+ Model('Qwen/Qwen2.5-VL-32B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-32B-Instruct-AWQ'),
669
+ Model('Qwen/Qwen2.5-VL-72B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-72B-Instruct-AWQ'),
670
+ ]),
671
+ ],
672
+ TemplateType.qwen2_5_vl,
673
+ get_model_tokenizer_qwen2_5_vl,
674
+ model_arch=ModelArch.qwen2_vl,
675
+ architectures=['Qwen2_5_VLForConditionalGeneration'],
676
+ requires=['transformers>=4.49', 'qwen_vl_utils>=0.0.6', 'decord'],
677
+ tags=['vision', 'video']))
678
+
679
+
680
+ def get_model_tokenizer_qwen2_5_omni(model_dir, *args, **kwargs):
681
+ from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniConfig
682
+ from qwen_omni_utils import vision_process
683
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2_5OmniForConditionalGeneration
684
+ processor = Qwen2_5OmniProcessor.from_pretrained(model_dir, trust_remote_code=True)
685
+ kwargs['tokenizer'] = processor.tokenizer
686
+ kwargs['model_config'] = Qwen2_5OmniConfig.from_pretrained(model_dir, trust_remote_code=True)
687
+ patch_qwen_vl_utils(vision_process)
688
+ kwargs['model_config'].enable_audio_output = get_env_args('ENABLE_AUDIO_OUTPUT', bool, True)
689
+ model, _ = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
690
+ if model:
691
+ use_submodel_func(model, 'thinker')
692
+ model.config.keys_to_ignore_at_inference += ['hidden_states', 'attention_mask']
693
+ model.config.talker_config.pad_token_id = None
694
+ return model, processor
695
+
696
+
697
+ register_model(
698
+ ModelMeta(
699
+ MLLMModelType.qwen2_5_omni,
700
+ [
701
+ ModelGroup([
702
+ Model('Qwen/Qwen2.5-Omni-3B', 'Qwen/Qwen2.5-Omni-3B'),
703
+ Model('Qwen/Qwen2.5-Omni-7B', 'Qwen/Qwen2.5-Omni-7B'),
704
+ ]),
705
+ ],
706
+ TemplateType.qwen2_5_omni,
707
+ get_model_tokenizer_qwen2_5_omni,
708
+ model_arch=ModelArch.qwen2_5_omni,
709
+ architectures=['Qwen2_5OmniModel'],
710
+ requires=['transformers>=4.50', 'soundfile', 'qwen_omni_utils', 'decord'],
711
+ tags=['vision', 'video', 'audio'],
712
+ additional_saved_files=['spk_dict.pt'],
713
+ ignore_patterns=[],
714
+ ))
715
+
716
+
717
+ def get_model_tokenizer_qwen2_audio(*args, **kwargs):
718
+ from transformers import Qwen2AudioForConditionalGeneration
719
+ kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2AudioForConditionalGeneration
720
+ return get_model_tokenizer_multimodal(*args, **kwargs)
721
+
722
+
723
+ register_model(
724
+ ModelMeta(
725
+ MLLMModelType.qwen2_audio,
726
+ [
727
+ ModelGroup([
728
+ Model('Qwen/Qwen2-Audio-7B-Instruct', 'Qwen/Qwen2-Audio-7B-Instruct'),
729
+ Model('Qwen/Qwen2-Audio-7B', 'Qwen/Qwen2-Audio-7B'),
730
+ ]),
731
+ ],
732
+ TemplateType.qwen2_audio,
733
+ get_model_tokenizer_qwen2_audio,
734
+ model_arch=ModelArch.qwen2_audio,
735
+ architectures=['Qwen2AudioForConditionalGeneration'],
736
+ requires=['transformers>=4.45,<4.49', 'librosa'],
737
+ tags=['audio'],
738
+ ))
739
+
740
+ register_model(
741
+ ModelMeta(
742
+ LLMModelType.marco_o1, [ModelGroup([Model('AIDC-AI/Marco-o1', 'AIDC-AI/Marco-o1')])],
743
+ TemplateType.marco_o1,
744
+ get_model_tokenizer_with_flash_attn,
745
+ model_arch=ModelArch.llama,
746
+ architectures=['Qwen2ForCausalLM'],
747
+ requires=['transformers>=4.37']))
748
+
749
+ register_model(
750
+ ModelMeta(
751
+ LLMModelType.qwq_preview, [ModelGroup([Model('Qwen/QwQ-32B-Preview', 'Qwen/QwQ-32B-Preview')])],
752
+ TemplateType.qwq_preview,
753
+ get_model_tokenizer_with_flash_attn,
754
+ model_arch=ModelArch.llama,
755
+ architectures=['Qwen2ForCausalLM'],
756
+ requires=['transformers>=4.37']))
757
+
758
+ register_model(
759
+ ModelMeta(
760
+ LLMModelType.qwq,
761
+ [ModelGroup([
762
+ Model('Qwen/QwQ-32B', 'Qwen/QwQ-32B'),
763
+ Model('Qwen/QwQ-32B-AWQ', 'Qwen/QwQ-32B-AWQ'),
764
+ ])],
765
+ TemplateType.qwq,
766
+ get_model_tokenizer_with_flash_attn,
767
+ model_arch=ModelArch.llama,
768
+ architectures=['Qwen2ForCausalLM'],
769
+ requires=['transformers>=4.37']))
770
+
771
+
772
+ def get_model_tokenizer_ovis(*args, **kwargs):
773
+ kwargs['attn_impl_keys'] = ['llm_attn_implementation']
774
+ model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
775
+ if model is not None:
776
+ model.visual_tokenizer.to(model.dtype)
777
+ model.vte.to(model.dtype)
778
+
779
+ model.generation_config.cache_implementation = None
780
+ func_list = ['generate', 'forward', 'get_input_embeddings']
781
+ use_submodel_func(model, 'llm', func_list)
782
+ embedding = model.get_input_embeddings()
783
+ patch_output_clone(embedding)
784
+ try:
785
+ # fix device_map
786
+ from transformers.cache_utils import HybridCache
787
+
788
+ def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
789
+ **kwargs) -> Tuple[torch.Tensor]:
790
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
791
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
792
+ return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)
793
+
794
+ if not hasattr(HybridCache, '_update_origin'):
795
+ HybridCache._update_origin = HybridCache.update
796
+ HybridCache.update = update
797
+ except ImportError:
798
+ pass
799
+ return model, tokenizer
800
+
801
+
802
+ register_model(
803
+ ModelMeta(
804
+ MLLMModelType.ovis1_6,
805
+ [
806
+ ModelGroup([
807
+ Model('AIDC-AI/Ovis1.6-Gemma2-9B', 'AIDC-AI/Ovis1.6-Gemma2-9B'),
808
+ Model('AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4', 'AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4'),
809
+ Model('AIDC-AI/Ovis1.6-Gemma2-27B', 'AIDC-AI/Ovis1.6-Gemma2-27B'),
810
+ ]),
811
+ ],
812
+ TemplateType.ovis1_6,
813
+ get_model_tokenizer_ovis,
814
+ model_arch=ModelArch.ovis1_6,
815
+ architectures=['Ovis'],
816
+ tags=['vision'],
817
+ requires=['transformers>=4.42'],
818
+ ))
819
+
820
+ register_model(
821
+ ModelMeta(
822
+ MLLMModelType.ovis1_6_llama3,
823
+ [
824
+ ModelGroup([
825
+ Model('AIDC-AI/Ovis1.6-Llama3.2-3B', 'AIDC-AI/Ovis1.6-Llama3.2-3B'),
826
+ ]),
827
+ ],
828
+ TemplateType.ovis1_6_llama3,
829
+ get_model_tokenizer_ovis,
830
+ model_arch=ModelArch.ovis1_6,
831
+ architectures=['Ovis'],
832
+ tags=['vision'],
833
+ ))
834
+
835
+ register_model(
836
+ ModelMeta(
837
+ MLLMModelType.ovis2,
838
+ [
839
+ ModelGroup([
840
+ Model('AIDC-AI/Ovis2-1B', 'AIDC-AI/Ovis2-1B'),
841
+ Model('AIDC-AI/Ovis2-2B', 'AIDC-AI/Ovis2-2B'),
842
+ Model('AIDC-AI/Ovis2-4B', 'AIDC-AI/Ovis2-4B'),
843
+ Model('AIDC-AI/Ovis2-8B', 'AIDC-AI/Ovis2-8B'),
844
+ Model('AIDC-AI/Ovis2-16B', 'AIDC-AI/Ovis2-16B'),
845
+ Model('AIDC-AI/Ovis2-34B', 'AIDC-AI/Ovis2-34B'),
846
+ ]),
847
+ ],
848
+ TemplateType.ovis2,
849
+ get_model_tokenizer_ovis,
850
+ model_arch=ModelArch.ovis1_6,
851
+ architectures=['Ovis'],
852
+ tags=['vision'],
853
+ requires=['transformers>=4.46.2', 'moviepy<2'],
854
+ ))
855
+
856
+ register_model(
857
+ ModelMeta(
858
+ RMModelType.qwen2_reward,
859
+ [
860
+ ModelGroup([
861
+ Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'),
862
+ ]),
863
+ ],
864
+ TemplateType.qwen,
865
+ get_model_tokenizer_reward_model,
866
+ architectures=['Qwen2ForRewardModel'],
867
+ requires=['transformers>=4.37'],
868
+ ))
869
+
870
+ register_model(
871
+ ModelMeta(
872
+ RMModelType.qwen2_5_prm,
873
+ [
874
+ ModelGroup([
875
+ Model('Qwen/Qwen2.5-Math-PRM-7B', 'Qwen/Qwen2.5-Math-PRM-7B'),
876
+ Model('Qwen/Qwen2.5-Math-7B-PRM800K', 'Qwen/Qwen2.5-Math-7B-PRM800K'),
877
+ Model('Qwen/Qwen2.5-Math-PRM-72B', 'Qwen/Qwen2.5-Math-PRM-72B'),
878
+ ]),
879
+ ],
880
+ TemplateType.qwen2_5_math_prm,
881
+ get_model_tokenizer_reward_model,
882
+ task_type='prm',
883
+ architectures=['Qwen2ForProcessRewardModel'],
884
+ requires=['transformers>=4.37'],
885
+ ))
886
+
887
+ register_model(
888
+ ModelMeta(
889
+ RMModelType.qwen2_5_math_reward,
890
+ [
891
+ ModelGroup([
892
+ Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'),
893
+ ]),
894
+ ],
895
+ TemplateType.qwen2_5_math,
896
+ get_model_tokenizer_reward_model,
897
+ architectures=['Qwen2ForRewardModel'],
898
+ requires=['transformers>=4.37'],
899
+ ))
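The registrations in this file all follow the same `ModelMeta` pattern. As a minimal, hedged sketch (the model_type string `my_qwen2_sft` and the repo id `my-org/qwen2-sft` are illustrative assumptions, not entries in the upstream registry), a custom Qwen2-style checkpoint could plausibly be wired in the same way using the helpers imported at the top of the file:

# Hedged sketch, not upstream code: register a hypothetical Qwen2-based checkpoint
# with the same helpers used throughout this file. The model_type key and the
# repo id are assumptions for illustration only.
register_model(
    ModelMeta(
        'my_qwen2_sft',  # assumed custom model_type key
        [ModelGroup([
            Model('my-org/qwen2-sft', 'my-org/qwen2-sft'),  # hypothetical repo id
        ])],
        TemplateType.qwen,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen2ForCausalLM'],
        requires=['transformers>=4.37'],
        model_arch=ModelArch.llama))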
swift/llm/model/model/skywork.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from typing import Any, Dict
4
+
5
+ from swift.llm import TemplateType
6
+ from ..constant import LLMModelType, RMModelType
7
+ from ..model_arch import ModelArch
8
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
9
+ from ..utils import ModelInfo
10
+
11
+
12
+ def get_skywork_model_tokenizer(model_dir: str,
13
+ model_info: ModelInfo,
14
+ model_kwargs: Dict[str, Any],
15
+ load_model: bool = True,
16
+ **kwargs):
17
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
18
+ if 'chat' in model_dir:
19
+ tokenizer.add_tokens('[USER]')
20
+ tokenizer.add_tokens('[BOT]')
21
+ tokenizer.add_tokens('[SEP]')
22
+ return model, tokenizer
23
+
24
+
25
+ register_model(
26
+ ModelMeta(
27
+ LLMModelType.skywork,
28
+ [
29
+ ModelGroup([
30
+ Model('skywork/Skywork-13B-base', 'skywork/Skywork-13B-base'),
31
+ Model('skywork/Skywork-13B-chat'),
32
+ ]),
33
+ ],
34
+ TemplateType.skywork,
35
+ get_skywork_model_tokenizer,
36
+ architectures=['SkyworkForCausalLM'],
37
+ model_arch=ModelArch.llama,
38
+ ))
39
+
40
+ register_model(
41
+ ModelMeta(
42
+ LLMModelType.skywork_o1,
43
+ [
44
+ ModelGroup([
45
+ Model('AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B', 'Skywork/Skywork-o1-Open-Llama-3.1-8B'),
46
+ ]),
47
+ ],
48
+ TemplateType.skywork_o1,
49
+ get_model_tokenizer_with_flash_attn,
50
+ architectures=['LlamaForCausalLM'],
51
+ requires=['transformers>=4.43'],
52
+ model_arch=ModelArch.llama,
53
+ ))
54
+
55
+ register_model(
56
+ ModelMeta(
57
+ RMModelType.llama3_2_reward,
58
+ [
59
+ ModelGroup([
60
+ Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B', 'Skywork/Skywork-Reward-Llama-3.1-8B'),
61
+ Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2', 'Skywork/Skywork-Reward-Llama-3.1-8B-v0.2'),
62
+ ]),
63
+ ModelGroup([
64
+ Model('AI-ModelScope/GRM_Llama3.1_8B_rewardmodel-ft', 'Ray2333/GRM_Llama3.1_8B_rewardmodel-ft'),
65
+ Model('AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft', 'Ray2333/GRM-llama3.2-3B-rewardmodel-ft'),
66
+ ])
67
+ ],
68
+ TemplateType.llama3_2,
69
+ get_model_tokenizer_with_flash_attn,
70
+ requires=['transformers>=4.43'],
71
+ architectures=['LlamaForSequenceClassification'],
72
+ model_arch=ModelArch.llama,
73
+ ))
74
+
75
+ register_model(
76
+ ModelMeta(
77
+ RMModelType.gemma_reward,
78
+ [
79
+ ModelGroup([
80
+ Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B', 'Skywork/Skywork-Reward-Gemma-2-27B'),
81
+ Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B-v0.2', 'Skywork/Skywork-Reward-Gemma-2-27B-v0.2'),
82
+ ]),
83
+ ],
84
+ TemplateType.gemma,
85
+ get_model_tokenizer_with_flash_attn,
86
+ requires=['transformers>=4.42'],
87
+ architectures=['Gemma2ForSequenceClassification'],
88
+ model_arch=ModelArch.llama,
89
+ ))
swift/llm/model/model/stepfun.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import shutil
4
+ import sys
5
+
6
+ from transformers import AutoModel
7
+
8
+ from swift.llm import TemplateType
9
+ from ..constant import MLLMModelType
10
+ from ..model_arch import ModelArch
11
+ from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
12
+ get_model_tokenizer_with_flash_attn, register_model)
13
+ from ..utils import git_clone_github, safe_snapshot_download
14
+
15
+
16
+ def get_model_tokenizer_got_ocr2(*args, **kwargs):
17
+ kwargs['automodel_class'] = AutoModel
18
+ model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
19
+ return model, tokenizer
20
+
21
+
22
+ register_model(
23
+ ModelMeta(
24
+ MLLMModelType.got_ocr2, [
25
+ ModelGroup([
26
+ Model('stepfun-ai/GOT-OCR2_0', 'stepfun-ai/GOT-OCR2_0'),
27
+ ]),
28
+ ],
29
+ TemplateType.got_ocr2,
30
+ get_model_tokenizer_got_ocr2,
31
+ model_arch=ModelArch.got_ocr2,
32
+ architectures=['GOTQwenForCausalLM'],
33
+ tags=['vision']))
34
+
35
+
36
+ def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
37
+ from transformers.models.got_ocr2 import GotOcr2ForConditionalGeneration
38
+ GotOcr2ForConditionalGeneration._no_split_modules.append('GotOcr2VisionLayer')
39
+ model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
40
+ return model, processor
41
+
42
+
43
+ register_model(
44
+ ModelMeta(
45
+ MLLMModelType.got_ocr2_hf, [
46
+ ModelGroup([
47
+ Model('stepfun-ai/GOT-OCR-2.0-hf', 'stepfun-ai/GOT-OCR-2.0-hf'),
48
+ ]),
49
+ ],
50
+ TemplateType.got_ocr2_hf,
51
+ get_model_tokenizer_got_ocr2_hf,
52
+ model_arch=ModelArch.got_ocr2_hf,
53
+ architectures=['GOTQwenForCausalLM'],
54
+ tags=['vision']))
55
+
56
+
57
+ def get_model_tokenizer_step_audio(*args, **kwargs):
58
+ local_repo_path = kwargs.get('local_repo_path')
59
+ if not local_repo_path:
60
+ local_repo_path = git_clone_github('https://github.com/stepfun-ai/Step-Audio.git')
61
+ sys.path.append(local_repo_path)
62
+ from tokenizer import StepAudioTokenizer
63
+ encoder_path = safe_snapshot_download('stepfun-ai/Step-Audio-Tokenizer', check_local=True)
64
+ model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
65
+ if model is not None:
66
+ model.encoder = StepAudioTokenizer(encoder_path)
67
+ # from tts import StepAudioTTS
68
+ # if not os.path.exists('speakers'):
69
+ # shutil.copytree(os.path.join(local_repo_path, 'speakers'), 'speakers')
70
+ # decoder_path = safe_snapshot_download('stepfun-ai/Step-Audio-TTS-3B', check_local=True)
71
+ # model.decoder = StepAudioTTS(decoder_path, model.encoder)
72
+ return model, tokenizer
73
+
74
+
75
+ register_model(
76
+ ModelMeta(
77
+ MLLMModelType.step_audio, [
78
+ ModelGroup([
79
+ Model('stepfun-ai/Step-Audio-Chat', 'stepfun-ai/Step-Audio-Chat'),
80
+ ]),
81
+ ],
82
+ TemplateType.step_audio,
83
+ get_model_tokenizer_step_audio,
84
+ architectures=['Step1ForCausalLM'],
85
+ requires=['funasr', 'sox', 'conformer', 'openai-whisper', 'librosa'],
86
+ tags=['audio']))
swift/llm/model/model/telechat.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from transformers import GenerationConfig
4
+
5
+ from swift.llm import TemplateType
6
+ from ..constant import LLMModelType
7
+ from ..model_arch import ModelArch
8
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
9
+
10
+
11
+ def get_model_tokenizer_telechat(*args, **kwargs):
12
+ model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
13
+ model_dir = args[0]
14
+ generation_config = GenerationConfig.from_pretrained(model_dir)
15
+ for k in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'user_token_id', 'bot_token_id']:
16
+ setattr(tokenizer, k, getattr(generation_config, k))
17
+ return model, tokenizer
18
+
19
+
20
+ register_model(
21
+ ModelMeta(
22
+ LLMModelType.telechat,
23
+ [
24
+ ModelGroup([
25
+ Model('TeleAI/TeleChat-7B', 'Tele-AI/telechat-7B'),
26
+ Model('TeleAI/TeleChat-12B', 'Tele-AI/TeleChat-12B'),
27
+ Model('TeleAI/TeleChat-12B-v2', 'Tele-AI/TeleChat-12B-v2'),
28
+ Model('TeleAI/TeleChat-52B', 'TeleAI/TeleChat-52B'),
29
+ ]),
30
+ ModelGroup([
31
+ Model('swift/TeleChat-12B-V2-GPTQ-Int4'),
32
+ ]),
33
+ ModelGroup([
34
+ Model('TeleAI/TeleChat2-35B', 'Tele-AI/TeleChat2-35B'),
35
+ Model('TeleAI/TeleChat2-115B', 'Tele-AI/TeleChat2-115B'),
36
+ ]),
37
+ ],
38
+ TemplateType.telechat,
39
+ get_model_tokenizer_telechat,
40
+ model_arch=ModelArch.telechat,
41
+ architectures=['TelechatForCausalLM', 'TeleChatForCausalLM'],
42
+ ))
43
+
44
+ register_model(
45
+ ModelMeta(
46
+ LLMModelType.telechat2,
47
+ [
48
+ ModelGroup([
49
+ Model('TeleAI/TeleChat2-3B', 'Tele-AI/TeleChat2-3B'),
50
+ Model('TeleAI/TeleChat2-7B-32K', 'Tele-AI/TeleChat2-7B-32K'),
51
+ Model('TeleAI/TeleChat2-35B-32K', 'Tele-AI/TeleChat2-35B-32K'),
52
+ Model('TeleAI/TeleChat2-35B-Nov', 'Tele-AI/TeleChat2-35B-Nov'),
53
+ ]),
54
+ ],
55
+ TemplateType.telechat2,
56
+ get_model_tokenizer_telechat,
57
+ model_arch=ModelArch.telechat,
58
+ architectures=['TeleChat2ForCausalLM'],
59
+ ))
swift/llm/model/model/valley.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import sys
4
+ from functools import partial, wraps
5
+ from typing import Any, Dict
6
+
7
+ from swift.llm import TemplateType
8
+ from ..constant import MLLMModelType
9
+ from ..model_arch import ModelArch
10
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
11
+ from ..utils import ModelInfo, git_clone_github, safe_snapshot_download
12
+
13
+
14
+ def get_model_tokenizer_valley(model_dir: str,
15
+ model_info: ModelInfo,
16
+ model_kwargs: Dict[str, Any],
17
+ load_model: bool = True,
18
+ **kwargs):
19
+ llm_model_type = kwargs.pop('llm_model_type')
20
+ local_repo_path = kwargs.get('local_repo_path')
21
+ if not local_repo_path:
22
+ repo_path = 'https://github.com/bytedance/Valley.git'
23
+ local_repo_path = git_clone_github(repo_path)
24
+ sys.path.append(local_repo_path)
25
+
26
+ if llm_model_type == 'valley':
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+ from valley_eagle.model.language_model.valley_qwen2 import ValleyQwen2ForCausalLM, ValleyConfig
29
+ model_config = ValleyConfig.from_pretrained(model_dir)
30
+ model_config.mm_vision_tower = safe_snapshot_download(
31
+ 'AI-ModelScope/siglip-so400m-patch14-384', check_local=True)
32
+ model_config.eagle_vision_tower = safe_snapshot_download('Qwen/Qwen2-VL-7B-Instruct', check_local=True)
33
+ automodel_class = ValleyQwen2ForCausalLM
34
+
35
+ if not hasattr(ValleyQwen2ForCausalLM, '_origin_forward'):
36
+ forward = ValleyQwen2ForCausalLM.forward
37
+ ValleyQwen2ForCausalLM._origin_forward = forward
38
+
39
+ @wraps(forward)
40
+ def new_forward(*args, **kwargs):
41
+ import torch
42
+ outputs = forward(*args, **kwargs)
43
+ loss = outputs.loss
44
+ if loss is not None and loss.shape[-1] > 0:
45
+ loss = torch.mean(loss, dim=-1)
46
+ return CausalLMOutputWithPast(
47
+ loss=loss,
48
+ logits=outputs.logits,
49
+ past_key_values=outputs.past_key_values,
50
+ hidden_states=outputs.hidden_states,
51
+ attentions=outputs.attentions,
52
+ )
53
+
54
+ ValleyQwen2ForCausalLM.forward = new_forward
55
+ kwargs['model_config'] = model_config
56
+ kwargs['automodel_class'] = automodel_class
57
+ model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
58
+ if model is not None:
59
+ model.generation_config.repetition_penalty = 1.0 # Otherwise, Error. Same for original code.
60
+ from transformers import AutoProcessor, SiglipImageProcessor
61
+ tokenizer.image_processor = SiglipImageProcessor.from_pretrained(model.config.mm_vision_tower)
62
+ tokenizer.qwen2vl_processor = AutoProcessor.from_pretrained(
63
+ model.config.eagle_vision_tower, max_pixels=1280 * 28 * 28)
64
+ tokenizer.image_processor.crop_size = tokenizer.image_processor.size['height']
65
+ return model, tokenizer
66
+
67
+
68
+ register_model(
69
+ ModelMeta(
70
+ MLLMModelType.valley,
71
+ [
72
+ ModelGroup([
73
+ Model('bytedance-research/Valley-Eagle-7B'),
74
+ ], ),
75
+ ],
76
+ TemplateType.valley,
77
+ partial(get_model_tokenizer_valley, llm_model_type='valley'),
78
+ architectures=['ValleyQwen2ForCausalLM'],
79
+ model_arch=ModelArch.valley,
80
+ requires=['transformers>=4.42', 'av'],
81
+ tags=['vision'],
82
+ ))
swift/llm/model/model/yi.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import sys
4
+ from typing import Any, Dict
5
+
6
+ from transformers import AutoTokenizer
7
+
8
+ from swift.llm import TemplateType
9
+ from swift.utils import get_logger
10
+ from ..constant import LLMModelType, MLLMModelType
11
+ from ..model_arch import ModelArch
12
+ from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
13
+ from ..utils import ModelInfo, git_clone_github
14
+
15
+ logger = get_logger()
16
+
17
+
18
+ def get_model_tokenizer_yi(model_dir, *args, **kwargs):
19
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
20
+ return get_model_tokenizer_with_flash_attn(model_dir, *args, tokenizer=tokenizer, **kwargs)
21
+
22
+
23
+ def get_model_tokenizer_yi_vl(model_dir: str,
24
+ model_info: ModelInfo,
25
+ model_kwargs: Dict[str, Any],
26
+ load_model: bool = True,
27
+ **kwargs):
28
+ local_repo_path = kwargs.get('local_repo_path')
29
+ if not local_repo_path:
30
+ local_repo_path = git_clone_github('https://github.com/01-ai/Yi')
31
+ sys.path.append(os.path.join(local_repo_path, 'VL'))
32
+ from llava.model import LlavaLlamaForCausalLM, LlavaConfig
33
+ from llava.model.constants import key_info
34
+
35
+ model_config = LlavaConfig.from_pretrained(model_dir)
36
+ mm_vision_tower = model_config.mm_vision_tower
37
+ model_config.mm_vision_tower = os.path.join(model_dir, *mm_vision_tower.rsplit('/', maxsplit=2)[-2:])
38
+ model_config.attention_dropout = 0.
39
+ key_info['model_path'] = model_dir
40
+ kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaLlamaForCausalLM
41
+ model, tokenizer = get_model_tokenizer_yi(
42
+ model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs)
43
+ if model is not None:
44
+ logger.info('Please ignore the above warning.')
45
+ logger.info('Loading the parameters of vision_tower...')
46
+ model.resize_token_embeddings(len(tokenizer))
47
+ vision_tower = model.get_vision_tower()
48
+ vision_tower.load_model()
49
+ vision_tower.to(device=model.device, dtype=model_config.torch_dtype)
50
+ if not hasattr(model.config, 'max_sequence_length'):
51
+ model.config.max_sequence_length = 2048
52
+ tokenizer.image_processor = vision_tower.image_processor
53
+ return model, tokenizer
54
+
55
+
56
+ register_model(
57
+ ModelMeta(
58
+ MLLMModelType.yi_vl,
59
+ [
60
+ ModelGroup([
61
+ Model('01ai/Yi-VL-6B', '01-ai/Yi-VL-6B'),
62
+ Model('01ai/Yi-VL-34B', '01-ai/Yi-VL-34B'),
63
+ ], ),
64
+ ],
65
+ TemplateType.yi_vl,
66
+ get_model_tokenizer_yi_vl,
67
+ model_arch=ModelArch.llava_llama,
68
+ architectures=['LlavaLlamaForCausalLM'],
69
+ requires=['transformers>=4.34'],
70
+ tags=['vision'],
71
+ ))
72
+
73
+ register_model(
74
+ ModelMeta(
75
+ LLMModelType.yi,
76
+ [ # yi
77
+ ModelGroup([
78
+ Model('01ai/Yi-6B', '01-ai/Yi-6B'),
79
+ Model('01ai/Yi-6B-200K', '01-ai/Yi-6B-200K'),
80
+ Model('01ai/Yi-6B-Chat', '01-ai/Yi-6B-Chat'),
81
+ Model('01ai/Yi-6B-Chat-4bits', '01-ai/Yi-6B-Chat-4bits'),
82
+ Model('01ai/Yi-6B-Chat-8bits', '01-ai/Yi-6B-Chat-8bits'),
83
+ Model('01ai/Yi-9B', '01-ai/Yi-9B'),
84
+ Model('01ai/Yi-9B-200K', '01-ai/Yi-9B-200K'),
85
+ Model('01ai/Yi-34B', '01-ai/Yi-34B'),
86
+ Model('01ai/Yi-34B-200K', '01-ai/Yi-34B-200K'),
87
+ Model('01ai/Yi-34B-Chat', '01-ai/Yi-34B-Chat'),
88
+ Model('01ai/Yi-34B-Chat-4bits', '01-ai/Yi-34B-Chat-4bits'),
89
+ Model('01ai/Yi-34B-Chat-8bits', '01-ai/Yi-34B-Chat-8bits'),
90
+ ]),
91
+ # yi1.5
92
+ ModelGroup([
93
+ Model('01ai/Yi-1.5-6B', '01-ai/Yi-1.5-6B'),
94
+ Model('01ai/Yi-1.5-6B-Chat', '01-ai/Yi-1.5-6B-Chat'),
95
+ Model('01ai/Yi-1.5-9B', '01-ai/Yi-1.5-9B'),
96
+ Model('01ai/Yi-1.5-9B-Chat', '01-ai/Yi-1.5-9B-Chat'),
97
+ Model('01ai/Yi-1.5-9B-Chat-16K', '01-ai/Yi-1.5-9B-Chat-16K'),
98
+ Model('01ai/Yi-1.5-34B', '01-ai/Yi-1.5-34B'),
99
+ Model('01ai/Yi-1.5-34B-Chat', '01-ai/Yi-1.5-34B-Chat'),
100
+ Model('01ai/Yi-1.5-34B-Chat-16K', '01-ai/Yi-1.5-34B-Chat-16K'),
101
+ ]),
102
+ # yi1.5-quant
103
+ ModelGroup([
104
+ Model('AI-ModelScope/Yi-1.5-6B-Chat-GPTQ', 'modelscope/Yi-1.5-6B-Chat-GPTQ'),
105
+ Model('AI-ModelScope/Yi-1.5-6B-Chat-AWQ', 'modelscope/Yi-1.5-6B-Chat-AWQ'),
106
+ Model('AI-ModelScope/Yi-1.5-9B-Chat-GPTQ', 'modelscope/Yi-1.5-9B-Chat-GPTQ'),
107
+ Model('AI-ModelScope/Yi-1.5-9B-Chat-AWQ', 'modelscope/Yi-1.5-9B-Chat-AWQ'),
108
+ Model('AI-ModelScope/Yi-1.5-34B-Chat-GPTQ', 'modelscope/Yi-1.5-34B-Chat-GPTQ'),
109
+ Model('AI-ModelScope/Yi-1.5-34B-Chat-AWQ', 'modelscope/Yi-1.5-34B-Chat-AWQ'),
110
+ ]),
111
+ ],
112
+ TemplateType.chatml,
113
+ get_model_tokenizer_yi,
114
+ architectures=['LlamaForCausalLM'],
115
+ model_arch=ModelArch.llama,
116
+ ))
117
+
118
+ register_model(
119
+ ModelMeta(
120
+ LLMModelType.yi_coder, [
121
+ ModelGroup([
122
+ Model('01ai/Yi-Coder-1.5B', '01-ai/Yi-Coder-1.5B'),
123
+ Model('01ai/Yi-Coder-9B', '01-ai/Yi-Coder-9B'),
124
+ Model('01ai/Yi-Coder-1.5B-Chat', '01-ai/Yi-Coder-1.5B-Chat'),
125
+ Model('01ai/Yi-Coder-9B-Chat', '01-ai/Yi-Coder-9B-Chat'),
126
+ ], )
127
+ ],
128
+ TemplateType.yi_coder,
129
+ get_model_tokenizer_yi,
130
+ architectures=['LlamaForCausalLM'],
131
+ model_arch=ModelArch.llama,
132
+ tags=['coding']))
133
+
134
+ register_model(
135
+ ModelMeta(
136
+ LLMModelType.sus,
137
+ [
138
+ ModelGroup([
139
+ Model('SUSTC/SUS-Chat-34B', 'SUSTech/SUS-Chat-34B'),
140
+ ]),
141
+ ],
142
+ TemplateType.sus,
143
+ get_model_tokenizer_yi,
144
+ model_arch=ModelArch.llama,
145
+ architectures=['LlamaForCausalLM'],
146
+ ))
swift/llm/sampling/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .sampling import sampling_main
swift/llm/sampling/base.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Any, Dict, List
2
+
3
+ from swift.llm import SamplingArguments
4
+ from swift.plugin import orms, prms
5
+ from swift.utils import get_logger
6
+
7
+ logger = get_logger()
8
+
9
+
10
+ class Sampler:
11
+
12
+ def __init__(self, input_args: SamplingArguments):
13
+ self.args = input_args
14
+ self.template = None
15
+ self.processor = None
16
+ self.prm_model = None
17
+ self.orm_model = None
18
+ self._prepare_model_tokenizer()
19
+ self._prepare_template()
20
+ self._prepare_rm()
21
+
22
+ def _prepare_model_tokenizer(self):
23
+ args = self.args
24
+ _, self.processor = args.get_model_processor(load_model=False)
25
+
26
+ def _prepare_rm(self):
27
+ if self.args.prm_model is None:
28
+ self.prm_model = None
29
+ logger.warning('prm_model is None.')
30
+ elif self.args.prm_model in prms:
31
+ self.prm_model = prms[self.args.prm_model]()
32
+ else:
33
+ from swift.llm import PtEngine
34
+ self.prm_model = PtEngine(self.args.prm_model, max_batch_size=64)
35
+
36
+ if self.args.orm_model is None:
37
+ self.orm_model = None
38
+ logger.warning('orm_model is None.')
39
+ elif self.args.orm_model in orms:
40
+ self.orm_model = orms[self.args.orm_model]()
41
+ else:
42
+ from swift.llm import PtEngine
43
+ self.orm_model = PtEngine(self.args.orm_model, max_batch_size=64)
44
+
45
+ def _prepare_template(self) -> None:
46
+ template = self.args.get_template(self.processor)
47
+ self.template = template
48
+ self.template.set_mode('train')
49
+
50
+ def truncate_input(self, slices: List[Dict[str, Any]]):
51
+ """Truncate the input rows to avoid hitting the max length of the policy model"""
52
+ return slices
53
+
54
+ def do_sample(self, data):
55
+ raise NotImplementedError
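Concrete samplers are expected to override do_sample (and, if needed, truncate_input); a minimal sketch of a custom subclass, with the class name and output format invented purely for illustration:

# Hypothetical subclass: echoes each input row back as one JSON line.
import json

class EchoSampler(Sampler):

    def do_sample(self, data):
        rows = data if isinstance(data, list) else [data]
        return [json.dumps(row, ensure_ascii=False) + '\n' for row in rows]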
swift/llm/sampling/distill_sampler.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ from copy import deepcopy
3
+ from typing import List, Optional
4
+
5
+ from openai import OpenAI
6
+
7
+ from swift.llm.infer.protocol import InferRequest, RequestConfig
8
+ from swift.llm.sampling.vanilla_sampler import VanillaSampler
9
+ from .utils import get_messages_md5
10
+
11
+
12
+ class OpenAI_Engine():
13
+
14
+ def __init__(
15
+ self,
16
+ model: str,
17
+ stream: bool = False,
18
+ base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
19
+ api_key: str = '',
20
+ **kwargs,
21
+ ):
22
+ self.model = model
23
+ self.stream = stream
24
+ self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)
25
+
26
+ def infer(
27
+ self,
28
+ infer_requests: List[InferRequest],
29
+ request_config: Optional[RequestConfig] = None,
30
+ ):
31
+ resp_contents = []
32
+ for infer_request in infer_requests:
33
+ completion = self.client.chat.completions.create(
34
+ model=self.model,
35
+ messages=infer_request['messages'],
36
+ temperature=request_config.temperature,
37
+ top_p=request_config.top_p,
38
+ max_tokens=request_config.max_tokens,
39
+ stream=self.stream,
40
+ )
41
+ if self.stream:
42
+ reasoning_content = ''
43
+ content = ''
44
+ for chunk in completion:
45
+ chunk_choices = chunk.choices
46
+ if len(chunk_choices) == 0:
47
+ continue
48
+ reasoning_chunk = chunk_choices[0].delta.reasoning_content if hasattr(
49
+ chunk_choices[0].delta, 'reasoning_content') else ''
50
+ answer_chunk = chunk_choices[0].delta.content
51
+ if reasoning_chunk:
52
+ reasoning_content += reasoning_chunk
53
+ elif answer_chunk:
54
+ content += answer_chunk
55
+ else:
56
+ reasoning_content = ''
+ if hasattr(completion.choices[0].message, 'reasoning_content'):
57
+ reasoning_content = completion.choices[0].message.reasoning_content
58
+ content = completion.choices[0].message.content
59
+ assert len(content) > 0, 'Empty completion'
60
+ if reasoning_content:
61
+ resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
62
+ else:
63
+ resp_content = content
64
+ resp_contents.append(resp_content)
65
+
66
+ return resp_contents
67
+
68
+
69
+ class DistillSampler(VanillaSampler):
70
+
71
+ def __init__(self, *args, **kwargs):
72
+ super(VanillaSampler, self).__init__(*args, **kwargs)
73
+ assert self.args.sampler_engine == 'client'
74
+ _Engine = OpenAI_Engine
75
+ self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
76
+ self.infer_engine.strict = False
77
+ self.caches = self.read_cache()
78
+
79
+ def _prepare_model_tokenizer(self):
80
+ pass
81
+
82
+ def _prepare_template(self):
83
+ pass
84
+
85
+ def extract_choice(self, resp):
86
+ message = resp.choices[0].message
87
+ if hasattr(message, 'reasoning_content'):
88
+ reps_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
89
+ else:
90
+ reps_content = message.content
91
+ return reps_content
92
+
93
+ def generate(self, data):
94
+ resp_all = []
95
+ infer_requests = []
96
+ sent = 0
97
+ rows = self.convert_data_to_rows(data)
98
+ for idx, row in enumerate(rows):
99
+ row = deepcopy(row)
100
+ messages = row['messages']
101
+ uuid = get_messages_md5(row)
102
+ if uuid in self.caches:
103
+ choices = self.caches[uuid]['choices']
104
+ if len(choices) == self.args.num_return_sequences:
105
+ continue
106
+ if self.args.system:
107
+ if messages[0]['role'] == 'system':
108
+ messages[0]['content'] = self.args.system
109
+ else:
110
+ messages.insert(0, {'role': 'system', 'content': self.args.system})
111
+ if messages[-1]['role'] == 'assistant':
112
+ messages = messages[:-1]
113
+
114
+ row['messages'] = messages
115
+ infer_request = row
116
+ for i in range(self.args.num_return_sequences):
117
+ infer_requests.append(deepcopy(infer_request))
118
+ sent += 1
119
+
120
+ request_config = RequestConfig(
121
+ max_tokens=self.args.max_new_tokens,
122
+ temperature=self.args.temperature,
123
+ top_k=self.args.top_k,
124
+ top_p=self.args.top_p,
125
+ )
126
+
127
+ resp_list = []
128
+ if len(infer_requests) > 0:
129
+ resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)
130
+
131
+ _cur = 0
132
+ for idx, row in enumerate(rows):
133
+ row = deepcopy(row)
134
+ uuid = get_messages_md5(row)
135
+ if uuid in self.caches:
136
+ choices = self.caches[uuid]['choices']
137
+ if len(choices) == self.args.num_return_sequences:
138
+ row['choices'] = choices
139
+ resp_all.append(row)
140
+ continue
141
+
142
+ resps = row
143
+ resps['choices'] = []
144
+ for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
145
+ resps['choices'].append(resp_list[j])
146
+ resp_all.append(resps)
147
+ _cur += 1
148
+ return resp_all
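The OpenAI-compatible client above only needs a model name, base_url and key; a minimal call sketch (endpoint, key and model name are placeholders, and the request shape mirrors what DistillSampler.generate sends):

from swift.llm.infer.protocol import RequestConfig

engine = OpenAI_Engine(model='qwen-max', api_key='sk-xxx', stream=False)
contents = engine.infer(
    [{'messages': [{'role': 'user', 'content': 'What is 1 + 1?'}]}],
    request_config=RequestConfig(max_tokens=128, temperature=0.7, top_p=0.9),
)
print(contents[0])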
swift/llm/sampling/mcts.py ADDED
@@ -0,0 +1,400 @@
1
+ import time
2
+ import traceback
3
+ from concurrent.futures import ThreadPoolExecutor, as_completed
4
+ from copy import deepcopy
5
+
6
+ import json
7
+ import numpy as np
8
+
9
+ from swift.llm import InferRequest, SamplingArguments
10
+ from swift.llm.infer.protocol import UsageInfo
11
+ from swift.utils import get_logger
12
+ from .base import Sampler
13
+ from .utils import get_reward, perform_infer
14
+
15
+ logger = get_logger()
16
+
17
+ NXT_PROMPT = """Continue.
18
+ """
19
+
20
+ next_message = {
21
+ 'role': 'user',
22
+ 'content': NXT_PROMPT,
23
+ }
24
+
25
+
26
+ class LanguageNode:
27
+
28
+ def __init__(
29
+ self,
30
+ step: str = None,
31
+ sep_token: str = None,
32
+ parent: 'LanguageNode' = None,
33
+ ):
34
+ self.parent = parent
35
+
36
+ if sep_token:
37
+ self.sep_token = sep_token
38
+ else:
39
+ self.sep_token = parent.sep_token
40
+
41
+ if parent:
42
+ self.path = parent.path[:] + [step]
43
+ self.answer = parent.answer + step + self.sep_token
44
+ self.depth = parent.depth + 1
45
+ else:
46
+ self.path = []
47
+ self.answer = ''
48
+ self.depth = 0
49
+
50
+ self.active_children = []
51
+ self.children = []
52
+ self.visit_count = 0
53
+ self.process_reward = 0.0
54
+ self.outcome_reward = 0.0
55
+ self.terminated = False
56
+ self.correct = False
57
+
58
+ def is_leaf(self):
59
+ return len(self.children) == 0
60
+
61
+ def is_root(self):
62
+ return self.parent is None
63
+
64
+ def visit(self):
65
+ self.visit_count += 1
66
+
67
+ def init_and_update_value(self, value):
68
+ self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1)
69
+
70
+ def add_child(self, child: 'LanguageNode'):
71
+ self.children.append(child)
72
+ if not child.terminated:
73
+ self.active_children.append(child)
74
+
75
+ def collect(self):
76
+ result = {
77
+ 'path': self.path,
78
+ 'depth': self.depth,
79
+ 'visit_count': self.visit_count,
80
+ 'process_reward': self.process_reward,
81
+ 'outcome_reward': self.outcome_reward,
82
+ 'terminated': str(self.terminated),
83
+ 'correct': str(self.correct),
84
+ 'children': [child.collect() for child in self.children],
85
+ }
86
+ return result
87
+
88
+ def __lt__(self, other):
89
+ return self.outcome_reward < other.outcome_reward
90
+
91
+
92
+ class MctsSampler(Sampler):
93
+
94
+ def __init__(self, input_args: SamplingArguments):
95
+ super().__init__(input_args)
96
+ self.usage_info = UsageInfo(0, 0, 0)
97
+
98
+ def _prepare_model_tokenizer(self):
99
+ args = self.args
100
+ self.infer_kwargs = {}
101
+ if args.sampler_engine == 'client':
102
+ from swift.llm import InferClient
103
+ api_key = args.api_key
104
+ base_url = args.base_url
105
+ self.infer_engine = [
106
+ InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)
107
+ ]
108
+ self.infer_kwargs['model'] = args.model
109
+ else:
110
+ _Engine = self.get_infer_engine()
111
+ self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
112
+
113
+ def get_infer_engine(self):
114
+ if self.args.sampler_engine == 'pt':
115
+ from swift.llm import PtEngine
116
+ _Engine = PtEngine
117
+ elif self.args.sampler_engine == 'vllm':
118
+ from swift.llm import VllmEngine
119
+ _Engine = VllmEngine
120
+ elif self.args.sampler_engine == 'lmdeploy':
121
+ from swift.llm import LmdeployEngine
122
+ _Engine = LmdeployEngine
123
+ elif self.args.sampler_engine == 'no':
124
+ _Engine = None
125
+ else:
126
+ raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
127
+ return _Engine
128
+
129
+ def _prepare_template(self) -> None:
130
+ # Hack from super()
131
+ self._prepare_request_configs()
132
+
133
+ def _prepare_request_configs(self):
134
+ _args = self.args
135
+ request_config = _args.get_request_config()
136
+ request_config.stop = _args.stop_words
137
+ request_config.seed = _args.seed
138
+ self.expand_request_configs = []
139
+ self.rollout_request_configs = []
140
+ for i in range(_args.num_return_sequences):
141
+ expand_request_config = deepcopy(request_config)
142
+ expand_request_config.n = 1
143
+ expand_request_config.num_beams = expand_request_config.n
144
+ expand_request_config.seed += i
145
+ self.expand_request_configs.append(expand_request_config)
146
+ rollout_request_config = deepcopy(request_config)
147
+ rollout_request_config.max_tokens = 500
148
+ rollout_request_config.temperature = 0.0
149
+ rollout_request_config.n = 1
150
+ self.rollout_request_configs.append(rollout_request_config)
151
+
152
+ def update_usage_info(self, response):
153
+ for key, value in self.usage_info.__dict__.items():
154
+ update_value = getattr(response.usage, key, None) + value
155
+ setattr(self.usage_info, key, update_value)
156
+
157
+ def search_single(self, query, ground_truth):
158
+
159
+ def _uct(uct_curr_node: LanguageNode):
160
+ alpha = _args.process_reward_rate
161
+ value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
162
+ if uct_curr_node.is_root():
163
+ return value
164
+
165
+ exploitation_score = value
166
+ exploration_score = (
167
+ _args.exploration_rate
168
+ * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))
169
+
170
+ return exploration_score + exploitation_score
171
+
172
+ def _select(select_curr_node: LanguageNode):
173
+ while not select_curr_node.is_leaf():
174
+ select_curr_node = max(select_curr_node.active_children, key=lambda x: _uct(x))
175
+ return select_curr_node
176
+
177
+ def _expand(expand_curr_node: LanguageNode):
178
+ n = _args.num_return_sequences - len(expand_curr_node.children)
179
+ if expand_curr_node.is_root():
180
+ infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
181
+ else:
182
+ history_message = {
183
+ 'role': 'assistant',
184
+ 'content': expand_curr_node.answer,
185
+ }
186
+ infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
187
+ infer_requests = [infer_request for _ in range(n)]
188
+
189
+ # e_time = time.time()
190
+ # To perform the Expand operation in parallel,
191
+ # there's no need to consider the order for now, since the Prompt is the same.
192
+ expand_iter_index = 0
193
+ while True:
194
+ responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
195
+ **self.infer_kwargs)
196
+ if len(responses) > 0:
197
+ break
198
+ if expand_iter_index == 5:
199
+ raise ValueError('Expand should not return any response')
200
+ expand_iter_index += 1
201
+ # logger.info(f"expand.expand time: {time.time() - e_time}")
202
+
203
+ # To fetch Outcome Reward in parallel,
204
+ # the Outcome-Reward obtained is returned in order, so they can be directly matched accordingly.
205
+ orm_infer_requests = []
206
+ unique_output = set()
207
+ for response in responses:
208
+ self.update_usage_info(response)
209
+ output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
210
+ if output in unique_output:
211
+ continue
212
+ unique_output.add(output)
213
+ orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
214
+ child = LanguageNode(step=output, parent=expand_curr_node)
215
+ if self.orm_model.check_terminate(child.answer)[0]:
216
+ child.terminated = True
217
+ expand_curr_node.add_child(child)
218
+
219
+ # e_time = time.time()
220
+ orm_score, _orm_mask = get_reward(
221
+ self.orm_model,
222
+ orm_infer_requests,
223
+ ground_truths=[ground_truth] * len(orm_infer_requests),
224
+ threshold=0.0)
225
+ # logger.info(f"expand.orm time: {time.time() - e_time}")
226
+ for child, score in zip(expand_curr_node.children, orm_score):
227
+ if child.terminated:
228
+ child.init_and_update_value(score)
229
+ child.correct = score > 0.9
230
+ terminated_nodes.append(child)
231
+
232
+ # e_time = time.time()
233
+ if self.prm_model:
234
+ prm_infer_requests = []
235
+ for child in expand_curr_node.children:
236
+ prm_message = {'role': 'assistant', 'content': child.answer}
237
+ prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
238
+ prm_score, _prm_mask = get_reward(
239
+ self.prm_model,
240
+ prm_infer_requests,
241
+ ground_truths=[ground_truth] * len(prm_infer_requests),
242
+ threshold=0.0)
243
+ for child, score in zip(expand_curr_node.children, prm_score):
244
+ child.process_reward = score
245
+ # logger.info(f"expand.prm time: {time.time() - e_time}")
246
+
247
+ def _rollout(rollout_curr_node: LanguageNode):
248
+ rollout_depth = 0
249
+ rollout_nodes = {}
250
+ for i in range(len(rollout_curr_node.active_children)):
251
+ rollout_nodes[i] = {
252
+ 'node': rollout_curr_node.active_children[i],
253
+ 'history_messages': {
254
+ 'role': 'assistant',
255
+ 'content': rollout_curr_node.active_children[i].answer,
256
+ },
257
+ }
258
+ active_rollout_nodes = list(rollout_nodes.keys())
259
+ while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
260
+ # r_time = time.time()
261
+ infer_requests = [
262
+ InferRequest(system_message
263
+ + [prompt_message, rollout_nodes[index]['history_messages'], next_message])
264
+ for index in active_rollout_nodes
265
+ ]
266
+ # logger.info(f"rollout.prepare time: {time.time() - r_time}")
267
+ # r_time = time.time()
268
+ rollout_iter_index = 0
269
+ while True:
270
+ responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
271
+ **self.infer_kwargs)
272
+ if len(responses) > 0:
273
+ break
274
+ if rollout_iter_index == 5:
275
+ raise ValueError('Rollout should not return any response')
276
+ rollout_iter_index += 1
277
+ # logger.info(f"rollout.infer time: {time.time() - r_time}")
278
+
279
+ # r_time = time.time()
280
+ orm_infer_requests = []
281
+ end_paths = []
282
+ for index, response in zip(active_rollout_nodes, responses):
283
+ self.update_usage_info(response)
284
+ output = response.choices[0].message.content.rstrip(sep_token
285
+ + '\n').split(sep_token)[0] + sep_token + '\n'
286
+ rollout_nodes[index]['history_messages']['content'] += output
287
+ end_paths.append(rollout_nodes[index]['history_messages']['content'])
288
+ orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))
289
+ # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}")
290
+
291
+ # r_time = time.time()
292
+ orm_score, _orm_mask = get_reward(
293
+ self.orm_model,
294
+ orm_infer_requests,
295
+ ground_truths=[ground_truth] * len(infer_requests),
296
+ threshold=0.0)
297
+ # logger.info(f"rollout.get_orm time: {time.time() - r_time}")
298
+ terminated_state = self.orm_model.check_terminate(end_paths)
299
+ for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
300
+ if terminated:
301
+ rollout_curr_node.active_children[index].init_and_update_value(score)
302
+ if score > 0.9:
303
+ rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
304
+ else:
305
+ rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
306
+ rollout_nodes.pop(index)
307
+ active_rollout_nodes = list(rollout_nodes.keys())
308
+ rollout_depth += 1
309
+
310
+ def _back_propagate(back_curr_node: LanguageNode):
311
+ while back_curr_node:
312
+ if back_curr_node == curr_node:
313
+ best_child_value = max([child.outcome_reward for child in back_curr_node.children])
314
+ back_curr_node.init_and_update_value(best_child_value)
315
+ last_child_value = back_curr_node.outcome_reward
316
+ else:
317
+ back_curr_node.init_and_update_value(last_child_value)
318
+ last_child_value = back_curr_node.outcome_reward
319
+ back_curr_node.visit()
320
+ if len(back_curr_node.active_children) == 0:
321
+ back_curr_node.terminated = True
322
+ if not back_curr_node.is_root():
323
+ back_curr_node.parent.active_children.remove(back_curr_node)
324
+ back_curr_node = back_curr_node.parent
325
+
326
+ _args = self.args
327
+ system_message = [] + _args.system_message
328
+ sep_token = _args.stop_words[0] + '\n'
329
+ _root = LanguageNode(sep_token=sep_token)
330
+ prompt_message = {
331
+ 'role': 'user',
332
+ 'content': query,
333
+ }
334
+
335
+ rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
336
+ iter_count = 0
337
+ stop_reason = None
338
+ while True:
339
+ logger.info(f'iter_count: {iter_count}' + '.' * 10)
340
+ s_time = time.time()
341
+ curr_node = _select(_root)
342
+ logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
343
+ s_time = time.time()
344
+ _expand(curr_node)
345
+ logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
346
+ if curr_node.depth > _args.rollout_start_depth:
347
+ s_time = time.time()
348
+ _rollout(curr_node)
349
+ logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
350
+ s_time = time.time()
351
+ _back_propagate(curr_node)
352
+ logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
353
+ if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
354
+ if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
355
+ stop_reason = 'too easy'
356
+ break
357
+ elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
358
+ stop_reason = 'too hard'
359
+ break
360
+ if _root.terminated:
361
+ stop_reason = 'root terminated'
362
+ break
363
+ if len(terminated_nodes) >= _args.num_return_sequences:
364
+ stop_reason = 'enough nodes'
365
+ break
366
+ if iter_count >= _args.max_iterations:
367
+ stop_reason = 'max_iterations'
368
+ break
369
+ iter_count += 1
370
+ logger.info(f'stop_reason: {stop_reason}')
371
+ # logger.info(f"rollout_correct_answers: {rollout_correct_answers}")
372
+ # logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}")
373
+
374
+ monte_carlo_tree = _root.collect()
375
+ result = {
376
+ 'query': query,
377
+ 'ground_truth': ground_truth,
378
+ 'rollout_correct_answers': rollout_correct_answers,
379
+ 'rollout_incorrect_answers': rollout_incorrect_answers,
380
+ 'monte_carlo_tree': monte_carlo_tree,
381
+ }
382
+ result_json = json.dumps(result, ensure_ascii=False)
383
+ logger.info(result_json)
384
+ return result_json
385
+
386
+ def do_sample(self, data):
387
+ if not isinstance(data, list):
388
+ data = [data]
389
+ generated = []
390
+ for item in data:
391
+ logger.info(f'time: {time.ctime(time.time())}')
392
+ try:
393
+ messages = item['messages'][0]
394
+ query = messages[0]['content']
395
+ ground_truth = messages[1]['content']
396
+ generated.append(self.search_single(query, ground_truth) + '\n')
397
+ except Exception as e:
398
+ logger.error(f'Error: {e}')
399
+ logger.error(f'Traceback: {traceback.format_exc()}')
400
+ return generated
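The selection rule in _uct mixes the process and outcome rewards and adds a standard UCT exploration bonus; a standalone sketch of the same score for quick experimentation (the default weights below are placeholders, in practice they come from SamplingArguments):

import numpy as np

def uct_score(process_reward, outcome_reward, parent_visits, visits,
              process_reward_rate=0.5, exploration_rate=1.0):
    # Exploitation: weighted mix of PRM and ORM rewards, as in _uct above.
    value = process_reward_rate * process_reward + (1 - process_reward_rate) * outcome_reward
    # Exploration: UCT bonus that shrinks as the node gets visited more often.
    bonus = exploration_rate * np.sqrt(np.log(parent_visits + 1) / (visits + 1))
    return value + bonus

print(uct_score(0.6, 0.8, parent_visits=10, visits=2))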
swift/llm/sampling/sampling.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import shutil
4
+ import time
5
+ from typing import List, Union
6
+
7
+ import json
8
+
9
+ from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
10
+ from swift.utils import get_logger
11
+
12
+ logger = get_logger()
13
+
14
+
15
+ class SwiftSampling(SwiftPipeline):
16
+ args_class = SamplingArguments
17
+ args: args_class
18
+
19
+ def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
20
+ super().__init__(args)
21
+ self.args.save_args()
22
+ os.makedirs(self.args.output_dir, exist_ok=True)
23
+ self.cur_piece = 0
24
+ self.total_piece = 1
25
+
26
+ if self.args.data_range:
27
+ self.cur_piece, self.total_piece = self.args.data_range
28
+
29
+ if self.args.sampler_type == 'sample':
30
+ from swift.llm.sampling.vanilla_sampler import VanillaSampler
31
+ self.sampler = VanillaSampler(self.args)
32
+ elif self.args.sampler_type == 'mcts':
33
+ from swift.llm.sampling.mcts import MctsSampler
34
+ self.sampler = MctsSampler(self.args)
35
+ elif self.args.sampler_type == 'distill':
36
+ from swift.llm.sampling.distill_sampler import DistillSampler
37
+ self.sampler = DistillSampler(self.args)
38
+ else:
39
+ raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')
40
+
41
+ def _get_dataset(self):
42
+ args = self.args
43
+ dataset_kwargs = args.get_dataset_kwargs()
44
+ sampling_dataset, _ = load_dataset(
45
+ args.dataset, split_dataset_ratio=0., shuffle=args.dataset_shuffle, **dataset_kwargs)
46
+ logger.info(f'Sampling_dataset: {sampling_dataset}')
47
+ dataset_len = len(sampling_dataset)
48
+ piece_len = dataset_len // self.total_piece
49
+ sampling_dataset = sampling_dataset.select(range(piece_len * self.cur_piece, piece_len * (self.cur_piece + 1)))
50
+ return sampling_dataset
51
+
52
+ def run(self):
53
+ os.makedirs(self.args.output_dir, exist_ok=True)
54
+ iter_file = os.path.join(self.args.output_dir, self.args.output_file)
55
+ resume_file = os.path.join(self.args.output_dir, self.args.output_file + '.resume')
56
+ tmp_file = os.path.join(self.args.output_dir, self.args.output_file + '.tmp')
57
+ ckpt_state_file = os.path.join(self.args.output_dir, 'ckpt_state.json')
58
+ if os.path.exists(iter_file) and not self.args.override_exist_file:
59
+ return
60
+
61
+ index_resume = -1
62
+ write_mode = 'w'
63
+ if self.args.resume:
64
+ write_mode = 'a'
65
+ if os.path.exists(resume_file):
66
+ shutil.copyfile(resume_file, tmp_file)
67
+
68
+ if os.path.exists(ckpt_state_file):
69
+ with open(ckpt_state_file, 'r') as ckpt_state:
70
+ data = json.load(ckpt_state)
71
+ index_resume = data.get('index', -1)
72
+ logger.info(f'Loaded index_resume: {index_resume}')
73
+ else:
74
+ if os.path.exists(tmp_file):
75
+ os.remove(tmp_file)
76
+
77
+ dataset = self._get_dataset()
78
+ dataset_len = len(dataset)
79
+ total_iters = int(dataset_len // self.args.num_sampling_per_gpu_batch_size)
80
+
81
+ if self.args.num_sampling_per_gpu_batches is None or self.args.num_sampling_per_gpu_batches > total_iters:
82
+ self.args.num_sampling_per_gpu_batches = total_iters
83
+
84
+ with open(tmp_file, write_mode) as f:
85
+ for _index in range(self.args.num_sampling_per_gpu_batches):
86
+ if _index <= index_resume:
87
+ continue
88
+ logger.info(f' Sampling index:{_index}')
89
+ slices = dataset[self.args.num_sampling_per_gpu_batch_size
90
+ * _index:self.args.num_sampling_per_gpu_batch_size * (_index + 1)]
91
+ slices = self.sampler.truncate_input(slices)
92
+ generated = self.sampler.do_sample(slices)
93
+ f.writelines(generated)
94
+ f.flush()
95
+ shutil.copy(tmp_file, resume_file)
96
+ with open(ckpt_state_file, 'w') as ckpt_state:
97
+ json.dump({'index': _index}, ckpt_state)
98
+
99
+ if os.path.exists(iter_file):
100
+ shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
101
+ shutil.move(resume_file, iter_file)
102
+ logger.info(f'Sample file {iter_file} generated.')
103
+
104
+
105
+ def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
106
+ return SwiftSampling(args).main()
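A minimal invocation sketch, assuming SamplingArguments accepts the fields read by the pipeline above; the model and dataset values are placeholders:

from swift.llm import SamplingArguments
from swift.llm.sampling import sampling_main

args = SamplingArguments(
    model='Qwen/Qwen2.5-7B-Instruct',
    dataset=['tastelikefeet/competition_math#100'],
    sampler_type='sample',
    sampler_engine='pt',
    num_return_sequences=4,
    output_dir='sample_output',
    output_file='samples.jsonl',
)
sampling_main(args)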
swift/llm/sampling/utils.py ADDED
@@ -0,0 +1,172 @@
1
+ import hashlib
2
+ import inspect
3
+ from copy import copy
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import json
7
+ import numpy as np
8
+
9
+ from swift.llm import InferRequest, RequestConfig
10
+ from swift.utils import get_logger
11
+
12
+ logger = get_logger()
13
+
14
+
15
+ def get_messages_md5(row: Dict[str, Any]):
16
+ row = copy(row)
17
+ row.pop('choices', None)
18
+ serialized = json.dumps(row, sort_keys=True)
19
+ return hashlib.md5(serialized.encode('utf-8')).hexdigest()
20
+
21
+
22
+ def get_reward(model: Any,
23
+ infer_requests: List[InferRequest],
24
+ request_config: RequestConfig = None,
25
+ ground_truths: List[str] = None,
26
+ threshold: Optional[float] = None):
27
+ """Get reward from an RM model.
28
+
29
+ Args:
30
+ model: The model instance or an RM evaluator
31
+ infer_requests: Infer requests sent to the model
32
+ request_config: Infer config
33
+ ground_truths: The ground truth list
34
+ threshold: An optional threshold to generate the mask
35
+
36
+ Returns:
37
+ Tuple
38
+ Index 0: The min-max normalized scores matched the infer_requests
39
+ Index 1: The mask filtered by the threshold
40
+ """
41
+ from swift.llm import InferEngine
42
+ infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
43
+ parameters = inspect.signature(infer_func).parameters
44
+ gt_param = {}
45
+ if 'ground_truths' in parameters:
46
+ gt_param = {'ground_truths': ground_truths}
47
+ if isinstance(infer_requests[0], dict):
48
+ infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
49
+ rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
50
+ from swift.llm.infer.protocol import ChatCompletionResponse
51
+ if isinstance(rewards[0], ChatCompletionResponse):
52
+ print('reward:', rewards[0].choices[0].message.content)
53
+ if isinstance(rewards[0].choices[0].message.content, str):
54
+ rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards]
55
+ elif isinstance(rewards[0].choices[0].message.content, list):
56
+ rewards = [float(min(r.choices[0].message.content)) for r in rewards]
57
+ else:
58
+ rewards = [float(r.choices[0].message.content) for r in rewards]
59
+ arr = []
60
+ for reward in rewards:
61
+ if isinstance(reward, (list, tuple)):
62
+ arr.append(min(reward))
63
+ else:
64
+ arr.append(float(reward))
65
+
66
+ _mask = np.array([True] * len(arr))
67
+ if threshold is not None:
68
+ # > not >=, orm caller passes 0, which will cause error
69
+ _mask = np.array([a > threshold for a in arr])
70
+
71
+ def normalize(arr):
72
+ min_val = np.min(arr)
73
+ max_val = np.max(arr)
74
+ if min_val == max_val:
75
+ if min_val == 0:
76
+ constant_value = 0.0
77
+ else:
78
+ constant_value = min(1.0, min_val)
79
+ return np.full_like(arr, fill_value=constant_value, dtype=np.float64)
80
+ normalized = (arr - min_val) / (max_val - min_val + 1e-5)
81
+ return normalized
82
+
83
+ return normalize(arr), _mask
84
+
85
+
86
+ def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
87
+ if isinstance(infer_engines, list):
88
+ assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
89
+ from concurrent.futures import ThreadPoolExecutor, as_completed
90
+ n = len(infer_requests)
91
+ with ThreadPoolExecutor(max_workers=n) as executor:
92
+ futures = {
93
+ executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs):
94
+ i
95
+ for i in range(n)
96
+ }
97
+ responses = []
98
+ for future in as_completed(futures):
99
+ task_id = futures[future]
100
+ try:
101
+ responses += future.result()
102
+ except Exception as e:
103
+ logger.info(f'Perform infer task: {task_id} get an error: {e}')
104
+ return responses
105
+ elif isinstance(infer_requests, list):
106
+ responses = []
107
+ if isinstance(request_configs, list):
108
+ assert len(infer_requests) <= len(request_configs)
109
+ for i in range(len(infer_requests)):
110
+ responses += infer_engines.infer(
111
+ [infer_requests[i]],
112
+ request_configs[i],
113
+ **infer_kwargs,
114
+ )
115
+ elif isinstance(request_configs, RequestConfig):
116
+ for infer_request in infer_requests:
117
+ responses += infer_engines.infer(
118
+ [infer_request],
119
+ request_configs,
120
+ **infer_kwargs,
121
+ )
122
+ return responses
123
+ return infer_engines.infer(
124
+ [infer_requests],
125
+ request_configs,
126
+ **infer_kwargs,
127
+ )
128
+
129
+
130
+ def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
131
+ from transformers.utils import strtobool
132
+ if isinstance(monte_carlo_tree, str):
133
+ monte_carlo_tree = json.loads(monte_carlo_tree)
134
+
135
+ def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]):
136
+ _prefer_pairs, _correct_answers, _incorrect_answers = [], [], []
137
+ _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']]
138
+ _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']]
139
+ if len(collect_curr_node['children']) > 0:
140
+ for child in collect_curr_node['children']:
141
+ p, c, i = _collect(child, _outcome_rewards, _process_rewards)
142
+ _prefer_pairs += p
143
+ _correct_answers += c
144
+ _incorrect_answers += i
145
+ sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward'])
146
+ if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold:
147
+ # TODO: filter with visit count
148
+ prefer_pair = {
149
+ 'path': 'ки\n'.join(collect_curr_node['path']),
150
+ 'good': sorted_children[-1]['path'][-1],
151
+ 'good_score': sorted_children[-1]['outcome_reward'],
152
+ 'bad': sorted_children[0]['path'][-1],
153
+ 'bad_score': sorted_children[0]['outcome_reward'],
154
+ }
155
+ _prefer_pairs.append(prefer_pair)
156
+ if strtobool(collect_curr_node['terminated']):
157
+ _answer = {
158
+ 'answer': 'ки\n'.join(collect_curr_node['path']),
159
+ 'mean_outcome_reward': np.mean(_outcome_rewards),
160
+ 'min_outcome_reward': np.min(_outcome_rewards),
161
+ 'mean_process_reward': np.mean(_process_rewards),
162
+ 'min_process_reward': np.min(_process_rewards),
163
+ }
164
+ if strtobool(collect_curr_node['correct']):
165
+ _correct_answers.append(_answer)
166
+ else:
167
+ _incorrect_answers.append(_answer)
168
+ return _prefer_pairs, _correct_answers, _incorrect_answers
169
+
170
+ _root = monte_carlo_tree
171
+ prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], [])
172
+ return prefer_pairs, correct_answers, incorrect_answers
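The min-max normalization inside get_reward maps raw rewards into [0, 1], with a small epsilon in the denominator and a constant fill value when all scores are equal; a tiny worked example:

import numpy as np

raw = np.array([0.2, 0.5, 1.0])
normalized = (raw - raw.min()) / (raw.max() - raw.min() + 1e-5)
print(normalized)  # approximately [0.0, 0.375, 1.0]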
swift/llm/sampling/vanilla_sampler.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import Any, Dict, List
5
+
6
+ import json
7
+ import numpy as np
8
+
9
+ from swift.llm import RequestConfig
10
+ from swift.llm.sampling.base import Sampler
11
+ from swift.llm.template.template_inputs import InferRequest
12
+ from swift.utils import get_logger
13
+ from .utils import get_messages_md5, get_reward
14
+
15
+ logger = get_logger()
16
+
17
+
18
+ class VanillaSampler(Sampler):
19
+
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+
23
+ if self.args.sampler_engine == 'pt':
24
+ from swift.llm import PtEngine
25
+ _Engine = PtEngine
26
+ elif self.args.sampler_engine == 'vllm':
27
+ from swift.llm import VllmEngine
28
+ _Engine = VllmEngine
29
+ elif self.args.sampler_engine == 'lmdeploy':
30
+ from swift.llm import LmdeployEngine
31
+ _Engine = LmdeployEngine
32
+ elif self.args.sampler_engine == 'no':
33
+ _Engine = None
34
+ else:
35
+ raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
36
+ self.infer_engine = None
37
+ if _Engine:
38
+ self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
39
+ self.infer_engine.default_template = self.template
40
+ self.infer_engine.strict = False
41
+ self.caches = self.read_cache()
42
+
43
+ def read_cache(self):
44
+ cache_files = self.args.cache_files
45
+ caches = {}
46
+ for file in cache_files:
47
+ if not os.path.exists(file):
48
+ logger.warning(f'Cache file does not exist: {file}')
49
+ continue
50
+ with open(file, 'r') as f:
51
+ for line in f.readlines():
52
+ line = line.strip()
53
+ if not line:
54
+ continue
55
+
56
+ content = json.loads(line)
57
+ uuid = content['id']
58
+ messages = content['messages']
59
+ if uuid not in caches:
60
+ caches[uuid] = {'choices': []}
61
+ assert messages[-1]['role'] == 'assistant'
62
+ caches[uuid]['choices'].append(messages[-1]['content'])
63
+ return caches
64
+
65
+ @staticmethod
66
+ def convert_data_to_rows(data):
67
+ rows = []
68
+ key = list(data.keys())[0]
69
+ data_len = len(data[key])
70
+ for idx in range(data_len):
71
+ row = {key: data[key][idx] for key in data}
72
+ if row.get('images') and 'bytes' in row['images'][0]:
73
+ row['images'] = [img['path'] for img in row['images']]
74
+ rows.append(row)
75
+ VanillaSampler.check_row_valid(rows)
76
+ return rows
77
+
78
+ @staticmethod
79
+ def check_row_valid(rows):
80
+ for row in rows:
81
+ assert not row.get('images') or all([isinstance(img, str) and img for img in row['images']])
82
+ assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']])
83
+ assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']])
84
+
85
+ def generate(self, data):
86
+ resp_all = []
87
+ infer_requests = []
88
+ sent = 0
89
+ rows = self.convert_data_to_rows(data)
90
+ for idx, row in enumerate(rows):
91
+ row = deepcopy(row)
92
+ messages = row['messages']
93
+ uuid = get_messages_md5(row)
94
+ if uuid in self.caches:
95
+ choices = self.caches[uuid]['choices']
96
+ if len(choices) == self.args.num_return_sequences:
97
+ continue
98
+ if self.args.system:
99
+ if messages[0]['role'] == 'system':
100
+ messages[0]['content'] = self.args.system
101
+ else:
102
+ messages.insert(0, {'role': 'system', 'content': self.args.system})
103
+ if messages[-1]['role'] == 'assistant':
104
+ messages = messages[:-1]
105
+
106
+ row['messages'] = messages
107
+ infer_request = row
108
+ for i in range(self.args.num_return_sequences):
109
+ infer_requests.append(deepcopy(infer_request))
110
+ sent += 1
111
+
112
+ request_config = RequestConfig(
113
+ max_tokens=self.args.max_new_tokens,
114
+ temperature=self.args.temperature,
115
+ top_k=self.args.top_k,
116
+ top_p=self.args.top_p,
117
+ )
118
+
119
+ resp_list = []
120
+ if len(infer_requests) > 0:
121
+ resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)
122
+
123
+ _cur = 0
124
+ for idx, row in enumerate(rows):
125
+ row = deepcopy(row)
126
+ uuid = get_messages_md5(row)
127
+ if uuid in self.caches:
128
+ choices = self.caches[uuid]['choices']
129
+ if len(choices) == self.args.num_return_sequences:
130
+ row['choices'] = choices
131
+ resp_all.append(row)
132
+ continue
133
+
134
+ resps = row
135
+ resps['choices'] = []
136
+ for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
137
+ if not isinstance(resp_list[j], Exception):
138
+ resps['choices'].append(resp_list[j].choices[0].message.content)
139
+ if resps['choices']:
140
+ resp_all.append(resps)
141
+ _cur += 1
142
+ return resp_all
143
+
144
+ def do_sample(self, data):
145
+ generated = []
146
+ resp_all = self.generate(data)
147
+ for i, resps in enumerate(resp_all):
148
+ choices = resps['choices']
149
+ messages = resps['messages']
150
+ uuid = get_messages_md5(resps)
151
+ assert messages[-1]['role'] == 'assistant'
152
+ ground_truth = messages[-1]['content']
153
+
154
+ infer_requests = []
155
+ for decoded in choices:
156
+ _resps = deepcopy(resps)
157
+ _resps['messages'][-1]['content'] = decoded
158
+ infer_requests.append(_resps)
159
+
160
+ _resps = deepcopy(resps)
161
+ _resps['messages'][-1]['content'] = ground_truth
162
+ infer_requests.append(_resps)
163
+ if self.orm_model is not None:
164
+ orm_score, _orm_mask = get_reward(
165
+ self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0)
166
+ else:
167
+ orm_score = np.array([1.0] * len(infer_requests))
168
+ _orm_mask = np.array([True] * len(infer_requests))
169
+ if self.prm_model is not None:
170
+ prm_score, _prm_mask = get_reward(
171
+ self.prm_model,
172
+ infer_requests,
173
+ ground_truths=[ground_truth] * len(infer_requests),
174
+ threshold=self.args.prm_threshold)
175
+ else:
176
+ prm_score = np.array([1.0] * len(infer_requests))
177
+ _prm_mask = np.array([True] * len(infer_requests))
178
+
179
+ _mask = _orm_mask & _prm_mask
180
+ if not any(_mask):
181
+ continue
182
+
183
+ choices.append(ground_truth)
184
+ choices = np.array(choices)
185
+
186
+ if self.orm_model is None and self.prm_model is None:
187
+ positives = choices[:-1]
188
+ for positive in positives:
189
+ _resps = deepcopy(resps)
190
+ _resps.pop('choices', None)
191
+ _resps['id'] = uuid
192
+ _resps['messages'][-1]['content'] = str(positive)
193
+ generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
194
+ else:
195
+ score = np.array(prm_score) + np.array(orm_score) * 10
196
+ sorted_indices = np.argsort(score)[::-1]
197
+ pos_indexes = sorted_indices[0:self.args.n_best_to_keep]
198
+ pos_indexes = [i for i in pos_indexes if _mask[i]]
199
+ neg_index = sorted_indices[-1]
200
+ logger.info(
201
+ f'orm:{orm_score}, prm:{prm_score}, positive index: {pos_indexes}, negative index: {neg_index}')
202
+ if self.args.easy_query_threshold is not None and sum([score > 0 for score in orm_score]) - 1 >= int(
203
+ self.args.num_return_sequences * self.args.easy_query_threshold):
204
+ continue
205
+ if len(pos_indexes) > 0:
206
+ positives = choices[pos_indexes]
207
+ negative = choices[neg_index]
208
+ for positive in positives:
209
+ _resps = deepcopy(resps)
210
+ messages = deepcopy(messages)
211
+ _resps.pop('choices', None)
212
+ _resps['messages'][-1]['content'] = str(positive)
213
+ _resps['rejected_response'] = str(negative)
214
+ _resps['id'] = uuid
215
+ generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
216
+ return generated
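read_cache expects JSONL cache files in which every line carries an id (the md5 computed by get_messages_md5 over the row without its choices) and a message list whose last turn is the sampled assistant response; an illustrative line, expressed as the Python dict that would be JSON-encoded (all values are made up):

cache_line = {
    'id': '5f2c0c7e...',  # truncated md5 of the request row
    'messages': [
        {'role': 'user', 'content': 'What is 1 + 1?'},
        {'role': 'assistant', 'content': '2'},  # one sampled choice per cache line
    ],
}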
swift/llm/template/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from . import template
3
+ from .base import MaxLengthError, Template
4
+ from .constant import TemplateType
5
+ from .grounding import draw_bbox
6
+ from .register import TEMPLATE_MAPPING, get_template, get_template_meta, register_template
7
+ from .template_inputs import InferRequest, TemplateInputs
8
+ from .template_meta import TemplateMeta
9
+ from .utils import Prompt, Word, split_str_parts_by
10
+ from .vision_utils import load_file, load_image
swift/llm/template/base.py ADDED
@@ -0,0 +1,1606 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import hashlib
3
+ import inspect
4
+ import math
5
+ import os
6
+ import re
7
+ from contextlib import contextmanager, nullcontext
8
+ from copy import deepcopy
9
+ from dataclasses import asdict
10
+ from functools import partial, wraps
11
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from modelscope.hub.utils.utils import get_cache_dir
17
+ from peft import PeftModel
18
+ from PIL import Image
19
+ from torch.nn.utils.rnn import pad_sequence
20
+ from transformers import StoppingCriteriaList
21
+ from transformers.integrations import is_deepspeed_zero3_enabled
22
+ from transformers.utils import strtobool
23
+
24
+ from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
25
+ from ..utils import Processor, ProcessorMixin
26
+ from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
27
+ from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by
28
+ from .vision_utils import load_audio, load_batch, load_image, rescale_image
29
+
30
+ logger = get_logger()
31
+ if TYPE_CHECKING:
32
+ from .template_meta import TemplateMeta
33
+
34
+
35
+ class MaxLengthError(ValueError):
36
+ pass
37
+
38
+
39
+ class Template(ProcessorMixin):
40
+ special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
41
+ special_keys = ['images', 'videos', 'audios', 'objects']
42
+
43
+ image_placeholder = ['<image>']
44
+ video_placeholder = ['<video>']
45
+ audio_placeholder = ['<audio>']
46
+ cot_process_placeholder = ['ки']
47
+ placeholder_tokens = [] # For clearer printing
48
+ load_images = True
49
+ skip_prompt = True
50
+ use_model = False
51
+ norm_bbox = 'norm1000'
52
+
53
+ is_encoder_decoder = False
54
+
55
+ def __init__(
56
+ self,
57
+ processor: Processor,
58
+ template_meta: 'TemplateMeta',
59
+ default_system: Optional[str] = None,
60
+ max_length: Optional[int] = None,
61
+ *,
62
+ use_chat_template: bool = True,
63
+ truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
64
+ max_pixels: Optional[int] = None,
65
+ agent_template: Optional[str] = None,
66
+ norm_bbox: Literal['norm1000', 'none', None] = None,
67
+ response_prefix: Optional[str] = None,
68
+ # only for train
69
+ padding_side: Literal['left', 'right'] = 'right',
70
+ loss_scale: str = 'default',
71
+ sequence_parallel_size: int = 1,
72
+ # infer/deploy
73
+ template_backend: Literal['swift', 'jinja'] = 'swift',
74
+ ) -> None:
75
+ """
76
+ default_system: Override the default_system in the template.
77
+ max_length: Max length of the sequence
78
+ truncation_strategy: The truncation strategy
79
+ max_pixels: Rescale image to reduce memory usage, default `None` means no limitation.
80
+ e.g. 512 * 512 (H*W)
81
+ padding_side: The padding_side when the training batch_size >= 2
82
+ loss_scale: The loss scale function to use
83
+ """
84
+ from .template_meta import TemplateMeta
85
+ from swift.plugin import agent_templates
86
+
87
+ self.processor = processor
88
+ self.model_info = processor.model_info
89
+ self.config = self.model_info.config
90
+ self.model_meta = processor.model_meta
91
+ if max_length is None:
92
+ max_length = self.model_info.max_model_len
93
+ tokenizer = self.tokenizer
94
+
95
+ if not use_chat_template:
96
+ template_meta = template_meta.to_generate_template_meta()
97
+ else:
98
+ template_meta = deepcopy(template_meta)
99
+ # If default_system is None, keep template_meta.default_system unchanged.
100
+ template_meta.check_system(default_system)
101
+ if default_system is not None:
102
+ template_meta.default_system = default_system
103
+ if response_prefix is not None:
104
+ template_meta.response_prefix = response_prefix
105
+ logger.info(f'default_system: {repr(template_meta.default_system)}')
106
+ logger.info(f'response_prefix: {repr(template_meta.response_prefix)}')
107
+
108
+ for i, token in enumerate(self.placeholder_tokens):
109
+ if isinstance(token, str):
110
+ self.placeholder_tokens[i] = tokenizer.convert_tokens_to_ids(token)
111
+ template_meta.init(tokenizer)
112
+
113
+ self.template_meta: TemplateMeta = template_meta
114
+ self.use_chat_template = use_chat_template
115
+ self.template_backend = template_backend
116
+ self.max_length = max_length
117
+ self.truncation_strategy = truncation_strategy
118
+ self.loss_scale = loss_scale
119
+ self.max_pixels = max_pixels
120
+ self.padding_side = padding_side
121
+ self.sequence_parallel_size = sequence_parallel_size
122
+ agent_template = agent_template or template_meta.agent_template
123
+ logger.info(f'agent_template: {agent_template}')
124
+ self.agent_template = agent_templates[agent_template]()
125
+ self.norm_bbox = norm_bbox or self.norm_bbox
126
+ logger.info(f'max_length: {self.max_length}')
127
+ logger.info(f'norm_bbox: {self.norm_bbox}')
128
+ if self.is_encoder_decoder:
129
+ self.skip_prompt = False
130
+ self.mode: Literal['pt', 'vllm', 'lmdeploy', # infer
131
+ 'train', 'rlhf', 'kto', # train
132
+ 'seq_cls', 'embedding', 'prm'] = 'pt'
133
+ self._packing = False
134
+ self.use_megatron = False
135
+ if self.model_info.task_type != 'causal_lm':
136
+ self.mode = self.model_info.task_type
137
+ self._handles = []
138
+ self._deepspeed_initialize = None
139
+
140
+ @staticmethod
141
+ def _load_image(image, load_images: bool):
142
+ if load_images:
143
+ if isinstance(image, dict) and 'bytes' in image:
144
+ image = image['bytes'] or image['path']
145
+ image = load_image(image)
146
+ else:
147
+ if isinstance(image, dict):
148
+ path = image['path']
149
+ if path and (path.startswith('http') or os.path.exists(path)):
150
+ image = path
151
+ else:
152
+ image = load_image(image['bytes'])
153
+ elif not isinstance(image, str):
154
+ image = load_image(image)
155
+ return image
156
+
157
+ @staticmethod
158
+ def _get_height_width(inputs: StdTemplateInputs) -> None:
159
+ width = []
160
+ height = []
161
+ for image in inputs.images:
162
+ width.append(image.width)
163
+ height.append(image.height)
164
+ inputs.objects['width'] = width
165
+ inputs.objects['height'] = height
166
+
167
+ def normalize_bbox(self, inputs: StdTemplateInputs) -> None:
168
+ objects = inputs.objects
169
+ bbox_list = objects['bbox']
170
+ width_list = objects['width']
171
+ height_list = objects['height']
172
+ bbox_type = objects.pop('bbox_type', None) or 'real'
173
+ image_id_list = objects.pop('image_id', None) or []
174
+ image_id_list += [0] * (len(bbox_list) - len(image_id_list))
175
+ for bbox, image_id in zip(bbox_list, image_id_list):
176
+ if bbox_type == 'norm1':
177
+ width, height = 1, 1
178
+ else:
179
+ width, height = width_list[image_id], height_list[image_id]
180
+ for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])):
181
+ if self.norm_bbox == 'norm1000':
182
+ norm_width, norm_height = 1000, 1000
183
+ elif self.norm_bbox == 'none':
184
+ image = inputs.images[image_id]
185
+ norm_width, norm_height = image.width, image.height
186
+ bbox[2 * i] = int(round(x / width * norm_width))
187
+ bbox[2 * i + 1] = int(round(y / height * norm_height))
188
+
189
+ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
190
+ agent_template = self.agent_template
191
+ agent_template.template_meta = self.template_meta # for hermes
192
+ if inputs.tools:
193
+ if isinstance(inputs.tools, str):
194
+ inputs.tools = agent_template._parse_json(inputs.tools)
195
+ if not isinstance(inputs.tools, (list, tuple)):
196
+ inputs.tools = [inputs.tools]
197
+ elif isinstance(inputs.tools, (list, tuple)):
198
+ inputs.tools = [agent_template._parse_json(tool) for tool in inputs.tools]
199
+ else:
200
+ raise ValueError(f'inputs.tools: {inputs.tools}')
201
+ for i, tool in enumerate(inputs.tools):
202
+ inputs.tools[i] = agent_template.wrap_tool(tool)
203
+ i = 0
204
+ messages = inputs.messages
205
+ while i < len(messages):
206
+ if messages[i]['role'] == 'tool_call':
207
+ i_start = i
208
+ while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool_call':
209
+ i += 1
210
+ tool_content = self.agent_template._format_tool_calls(messages[i_start:i + 1])
211
+ messages[i_start:i + 1] = [{'role': 'assistant', 'content': tool_content}]
212
+ i = i_start + 1
213
+ else:
214
+ i += 1
215
+
216
+ def _preprocess_inputs(
217
+ self,
218
+ inputs: StdTemplateInputs,
219
+ ) -> None:
220
+ self._preprocess_function_call(inputs)
221
+ if self.model_meta.is_multimodal:
222
+ self._replace_image_tags(inputs)
223
+ self._replace_start_image_tags(inputs)
224
+ images = inputs.images
225
+ load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
226
+ load_images_origin = load_images
227
+ if self.max_pixels is not None or inputs.objects:
228
+ load_images = True
229
+ if images:
230
+ for i, image in enumerate(images):
231
+ images[i] = self._load_image(images[i], load_images)
232
+ if inputs.objects:
233
+ self._get_height_width(inputs)
234
+ if self.max_pixels is not None:
235
+ # Scale the image proportionally without affecting the scaled objects.
236
+ images = [rescale_image(img, self.max_pixels) for img in images]
237
+ if images and not load_images_origin: # fix pt & qwen-vl
238
+ for i, image in enumerate(images):
239
+ if isinstance(image, Image.Image):
240
+ images[i] = self._save_pil_image(image)
241
+ inputs.images = images
242
+
243
+ if self.mode == 'vllm' and inputs.audios:
244
+ sampling_rate = get_env_args('sampling_rate', int, None)
245
+ inputs.audios = load_batch(
246
+ inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))
247
+
248
+ if inputs.is_multimodal:
249
+ self._add_default_tags(inputs)
250
+
251
+ @staticmethod
252
+ def _replace_image_tags(inputs: StdTemplateInputs):
253
+ # compat
254
+ if inputs.images:
255
+ return
256
+ images = []
257
+ pattern = r'<img>(.+?)</img>'
258
+ for message in inputs.messages:
259
+ content = message['content']
260
+ if not isinstance(content, str):
261
+ continue
262
+ for image in re.findall(pattern, content):
263
+ # only support local_path
264
+ if os.path.isfile(image):
265
+ images.append(image)
266
+ else:
267
+ logger.warning_once(f'Failed to parse image path: `{content}`.', hash_id='<img></img>')
268
+ message['content'] = re.sub(pattern, '<image>', content)
269
+ inputs.images = images
270
+
271
+ @staticmethod
272
+ def _replace_start_image_tags(inputs: StdTemplateInputs):
273
+ # compat
274
+ generate_mode = False
275
+ message = inputs.messages[-1]
276
+ content = message['content']
277
+ if message['role'] == 'user' and content.endswith('<start-image>'):
278
+ generate_mode = True
279
+ message['content'] = message['content'][:-len('<start-image>')] # remove the <start-image>
280
+ inputs.generate_mode = generate_mode
281
+
282
+ @staticmethod
283
+ def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
284
+ get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
285
+ added_tokens_len = 0
286
+ for i, idx in enumerate(replace_idx_list):
287
+ new_tokens = get_new_tokens(i)
288
+ token_len = len(new_tokens)
289
+ input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
290
+ if labels:
291
+ labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
292
+ added_tokens_len += token_len - 1
293
+ return input_ids, labels
294
+
295
+ def compute_loss_context(self, model, inputs):
296
+ return nullcontext()
297
+
298
+ def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
299
+ chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
300
+ assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
301
+ rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
302
+ chosen_encoded = self._encode_truncated(chosen_inputs)
303
+ rejected_encoded = self._encode_truncated(rejected_inputs)
304
+
305
+ encoded = {}
306
+ for prefix in ['chosen', 'rejected']:
307
+ data = locals()[f'{prefix}_encoded']
308
+ for k, v in data.items():
309
+ encoded[f'{prefix}_{k}'] = v
310
+ return encoded
311
+
312
+ def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
313
+ label, inputs.label = inputs.label, None
314
+ encoded = self._rlhf_encode(inputs)
315
+ encoded['label'] = bool(label)
316
+ return encoded
317
+
318
+ def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
319
+ _encoded = {}
320
+ labels = []
321
+
322
+ def split_multi_medias(_inputs):
323
+ _content = _inputs.messages[-2]['content']
324
+ image_size = len(re.findall('<image>', _content))
325
+ video_size = len(re.findall('<video>', _content))
326
+ audio_size = len(re.findall('<audio>', _content))
327
+ _inputs.images = inputs.images[:image_size]
328
+ assert len(_inputs.images) == image_size
329
+ inputs.images = inputs.images[image_size:]
330
+ _inputs.videos = inputs.videos[:video_size]
331
+ assert len(_inputs.videos) == video_size
332
+ inputs.videos = inputs.videos[video_size:]
333
+ _inputs.audios = inputs.audios[:audio_size]
334
+ assert len(_inputs.audios) == audio_size
335
+ inputs.audios = inputs.audios[audio_size:]
336
+
337
+ anchor = deepcopy(inputs)
338
+ anchor.messages[-1]['content'] = ''
339
+ anchor.rejected_response = []
340
+ split_multi_medias(anchor)
341
+ anchor_encoded = self._encode_truncated(anchor)
342
+ for key in anchor_encoded:
343
+ _encoded[f'anchor_{key}'] = anchor_encoded[key]
344
+
345
+ positive = deepcopy(inputs)
346
+ positive.messages[-2]['content'] = positive.messages[-1]['content']
347
+ positive.messages[-1]['content'] = ''
348
+ positive.rejected_response = []
349
+ split_multi_medias(positive)
350
+ positive_encoded = self._encode_truncated(positive)
351
+ for key in positive_encoded:
352
+ _encoded[f'positive_{key}'] = positive_encoded[key]
353
+ labels.append(float(inputs.label) if inputs.label is not None else 1.0)
354
+
355
+ rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
356
+ for i in range(rejected_len):
357
+ negative = deepcopy(inputs)
358
+ negative.messages[-2]['content'] = negative.rejected_response[i]
359
+ negative.messages[-1]['content'] = ''
360
+ negative.rejected_response = []
361
+ split_multi_medias(negative)
362
+ negative_encoded = self._encode_truncated(negative)
363
+ for key in negative_encoded:
364
+ _encoded[f'negative{i}_{key}'] = negative_encoded[key]
365
+ labels.append(0.0)
366
+
367
+ _encoded['labels'] = labels
368
+ return _encoded
369
+
370
+ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
371
+ encoded = self._encode_truncated(inputs)
372
+ encoded.pop('labels', None)
373
+ if inputs.label is not None:
374
+ labels = inputs.label
375
+ problem_type = self._get_problem_type(self.config, labels=labels)
376
+ if problem_type == 'single_label_classification':
377
+ labels = int(labels)
378
+ encoded['labels'] = labels
379
+ return encoded
380
+
381
+ @torch.inference_mode()
382
+ def encode(self,
383
+ inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
384
+ return_template_inputs: bool = False) -> Dict[str, Any]:
385
+ """The entrance method of Template!
386
+
387
+ Returns:
388
+ return {'input_ids': List[int], 'labels': Optional[List[int]], ...}
389
+ """
390
+ if isinstance(inputs, (InferRequest, TemplateInputs)):
391
+ inputs = asdict(inputs)
392
+
393
+ if isinstance(inputs, dict):
394
+ inputs = deepcopy(inputs)
395
+ if not self.is_training:
396
+ InferRequest.remove_response(inputs['messages'])
397
+ inputs = StdTemplateInputs.from_dict(inputs)
398
+ elif isinstance(inputs, StdTemplateInputs):
399
+ inputs = deepcopy(inputs)
400
+ assert isinstance(inputs, StdTemplateInputs)
401
+ self._preprocess_inputs(inputs)
402
+
403
+ if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}:
404
+ encoded = self._encode_truncated(inputs)
405
+ elif self.mode == 'seq_cls':
406
+ encoded = self._seq_cls_encode(inputs)
407
+ elif self.mode == 'rlhf':
408
+ encoded = self._rlhf_encode(inputs)
409
+ elif self.mode == 'kto':
410
+ encoded = self._kto_encode(inputs)
411
+ elif self.mode == 'embedding':
412
+ encoded = self._embedding_encode(inputs)
413
+ for key in list(encoded.keys()):
414
+ if encoded[key] is None:
415
+ encoded.pop(key)
416
+ if return_template_inputs:
417
+ encoded['template_inputs'] = inputs
418
+ return encoded
419
+
420
+ def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
421
+ packed = {}
422
+ keys = set()
423
+ for r in row:
424
+ keys.update(r[0].keys())
425
+ for key in keys:
426
+ if key in {'input_ids', 'labels', 'loss_scale'}:
427
+ packed[key] = sum((x[0][key] for x in row), start=[])
428
+ if 'position_ids' not in packed:
429
+ packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])
430
+
431
+ packed.update(self._data_collator_mm_data([r[0] for r in row]))
432
+ return packed
433
+
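+ # A small packing sketch (assumed lengths): packing two samples of length 3 and 2
+ # concatenates their input_ids/labels and restarts position_ids per sample,
+ # e.g. position_ids = [0, 1, 2, 0, 1], so the packed samples can still be told apart downstream.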
434
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
435
+ return inputs
436
+
437
+ @staticmethod
438
+ def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
439
+ len_tokens = len(stop_tokens)
440
+ if is_finished and generate_ids[-len_tokens:] == stop_tokens:
441
+ return generate_ids[:-len_tokens]
442
+ if not is_finished:
443
+ for i in range(len_tokens, 0, -1):
444
+ if generate_ids[-i:] == stop_tokens[:i]:
445
+ return generate_ids[:-i]
446
+ return generate_ids
447
+
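+ # Worked sketch with assumed ids, stop_tokens = [7, 8, 9]:
+ # finished and ending with the full stop sequence: [5, 6, 7, 8, 9] -> [5, 6]
+ # unfinished and ending with a partial prefix: [5, 6, 7, 8] -> [5, 6]
+ # (the partial prefix is held back in case the stop sequence completes later)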
448
+ @staticmethod
449
+ def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int):
450
+ idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist()
451
+ logprobs = logprobs.tolist()
452
+ return {
453
+ 'content': [{
454
+ 'index': pred,
455
+ 'logprobs': [logprobs[p] for p in pred] if isinstance(pred, (list, tuple)) else logprobs[pred],
456
+ 'top_logprobs': [{
457
+ 'index': idx,
458
+ 'logprob': logprobs[idx]
459
+ } for idx in idxs]
460
+ }]
461
+ }
462
+
463
+ @staticmethod
464
+ def _get_problem_type(config, labels=None, logits=None) -> str:
465
+ problem_type = config.problem_type
466
+ if problem_type is not None:
467
+ return problem_type
468
+ if labels is not None:
469
+ if isinstance(labels, (list, tuple)):
470
+ if labels and isinstance(labels[0], float):
471
+ problem_type = 'regression'
472
+ else:
473
+ problem_type = 'multi_label_classification'
474
+ else:
475
+ problem_type = 'single_label_classification'
476
+ assert config.num_labels >= labels + 1
477
+ if logits is not None:
478
+ if logits.shape[-1] == 1:
479
+ problem_type = 'regression'
480
+ else:
481
+ problem_type = 'single_label_classification' # compatible with older versions
482
+ assert problem_type is not None
483
+ config.problem_type = problem_type
484
+ return problem_type
485
+
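+ # Inference sketch when config.problem_type is unset (hypothetical labels):
+ # labels = 3 -> 'single_label_classification'
+ # labels = [0, 2] -> 'multi_label_classification'
+ # labels = [0.75] -> 'regression'
+ # logits with last dim 1 -> 'regression', otherwise 'single_label_classification'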
486
+ def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int):
487
+ assert isinstance(logits, torch.Tensor)
488
+ problem_type = self._get_problem_type(self.config, logits=logits)
489
+ if problem_type == 'regression':
490
+ preds = logits.squeeze(dim=-1).tolist()
491
+ logprobs = [None] * len(preds)
492
+ else:
493
+ if problem_type == 'single_label_classification':
494
+ preds = torch.argmax(logits, dim=-1).tolist()
495
+ logprobs = torch.log_softmax(logits, -1)
496
+ else:
497
+ preds = [(logprob >= 0.5).nonzero(as_tuple=True)[0].tolist() for logprob in torch.sigmoid(logits)]
498
+ logprobs = F.logsigmoid(logits)
499
+ logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)]
500
+ return preds, logprobs
501
+
502
+ def decode(self,
503
+ generate_ids: List[int],
504
+ *,
505
+ is_finished: bool = True,
506
+ tokenizer_kwargs=None,
507
+ first_token=True,
508
+ **kwargs) -> Any:
509
+ tokenizer_kwargs = tokenizer_kwargs or {}
510
+ response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs)
511
+ if first_token and self.template_meta.response_prefix:
512
+ response = self.template_meta.response_prefix + response
513
+ return response
514
+
515
+ def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
516
+ raise NotImplementedError
517
+
518
+ def generate(self, model, *args, **kwargs):
519
+ if isinstance(model, PeftModel):
520
+ signature = inspect.signature(model.model.generate)
521
+ else:
522
+ signature = inspect.signature(model.generate)
523
+ if 'use_model_defaults' in signature.parameters and 'use_model_defaults' not in kwargs:
524
+ kwargs['use_model_defaults'] = False
525
+ return model.generate(*args, **kwargs)
526
+
527
+ def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
528
+ # Do not print template_meta.suffix[-1] and eos_token.
529
+ # However, other stop_words will be printed.
530
+ tokenizer = self.tokenizer
531
+
532
+ if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
533
+ generate_ids = generate_ids[:-1]
534
+ # skip suffix and eos_token
535
+ template_suffix = self.template_meta.suffix[-1]
536
+ if isinstance(template_suffix, str):
537
+ # [-1:]: fix OpenGVLab/Mini-InternVL-Chat-4B-V1-5
538
+ template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:]
539
+ generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
540
+ if 'spaces_between_special_tokens' not in decode_kwargs:
541
+ decode_kwargs['spaces_between_special_tokens'] = False
542
+ return tokenizer.decode(generate_ids, **decode_kwargs)
543
+
544
+ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
545
+ generation_config = generate_kwargs['generation_config']
546
+ stop_words = getattr(generation_config, 'stop_words', None) or self.template_meta.stop_words
547
+ generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, stop_words)])
548
+ return generate_kwargs
549
+
550
+ @staticmethod
551
+ def _save_pil_image(image: Image.Image) -> str:
552
+ img_bytes = image.tobytes()
553
+ img_hash = hashlib.sha256(img_bytes).hexdigest()
554
+ tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
555
+ logger.info_once(f'create tmp_dir: {tmp_dir}')
556
+ os.makedirs(tmp_dir, exist_ok=True)
557
+ img_path = os.path.join(tmp_dir, f'{img_hash}.png')
558
+ if not os.path.exists(img_path):
559
+ image.save(img_path)
560
+ return img_path
561
+
562
+ @staticmethod
563
+ def _concat_context_list(
564
+ context_list: List[Context],
565
+ res_context_list: List[Context], # inplace
566
+ res_context_type: List[ContextType], # inplace
567
+ system: Optional[str] = None,
568
+ query: Optional[str] = None,
569
+ response: Optional[str] = None,
570
+ round0: Optional[int] = None) -> None:
571
+ """Concat context list and replace placeholder"""
572
+ round1 = None
573
+ if round0 is not None:
574
+ round1 = str(round0 + 1)
575
+ round0 = str(round0)
576
+ for context in context_list:
577
+ if isinstance(context, str):
578
+ if '{{RESPONSE}}' == context:
579
+ assert response is not None
580
+ res_context_list.append(response)
581
+ res_context_type.append(ContextType.RESPONSE)
582
+ continue
583
+ old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
584
+ new_str_list = [system, query, round0, round1]
585
+ for (old_str, new_str) in zip(old_str_list, new_str_list):
586
+ if new_str is not None and old_str in context:
587
+ assert isinstance(new_str, str), f'new_str: {new_str}'
588
+ context = context.replace(old_str, new_str)
589
+ if len(context) == 0:
590
+ continue
591
+ res_context_list.append(context)
592
+ res_context_type.append(ContextType.OTHER)
593
+
594
+ def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
595
+ inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
596
+ """Merge anything in the context to simplify the inputs"""
597
+ context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
598
+ context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)
599
+
600
+ res: List[Context] = [] # result of context_list
601
+ res_loss_scale: List[float] = [] # result of loss_scale_list
602
+ temp: List[str] = []
603
+ temp_loss_scale = 0.
604
+ for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
605
+ if isinstance(context, str) and (loss_scale == temp_loss_scale):
606
+ temp.append(context)
607
+ else:
608
+ if len(temp) > 0:
609
+ res.append(''.join(temp))
610
+ res_loss_scale.append(temp_loss_scale)
611
+ temp.clear()
612
+ if isinstance(context, str): # loss_scale diff
613
+ temp.append(context)
614
+ else:
615
+ res.append(context)
616
+ res_loss_scale.append(loss_scale)
617
+ temp_loss_scale = loss_scale
618
+ if len(temp) > 0:
619
+ res.append(''.join(temp))
620
+ res_loss_scale.append(temp_loss_scale)
621
+
622
+ return res, res_loss_scale
623
+
624
+ @staticmethod
625
+ def _split_special_tokens(context_list: List[Context],
626
+ loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
627
+ """Split special tokens, for example `<image>`, `<video>`, this will help the replace_tag operation"""
628
+ res: List[Context] = []
629
+ loss_scale_res: List[float] = []
630
+ for context, loss_scale in zip(context_list, loss_scale_list):
631
+ contexts = []
632
+ if isinstance(fetch_one(context), str):
633
+ for d in split_str_parts_by(context, Template.special_tokens):
634
+ contexts.extend([d['key'], d['content']])
635
+ contexts = [c for c in contexts if c]
636
+ res.extend(contexts)
637
+ loss_scale_res.extend([loss_scale] * len(contexts))
638
+ else:
639
+ res.append(context)
640
+ loss_scale_res.append(loss_scale)
641
+ return res, loss_scale_res
642
+
643
+ def _tokenize(self, context, **tokenizer_kwargs):
644
+ return self.tokenizer(
645
+ context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']
646
+
647
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
648
+ inputs: StdTemplateInputs) -> List[Context]:
649
+ """Override this function to do your own replace operation.
650
+
651
+ This method is used to replace standard tags like `<image>` with the tokens that the model needs.
652
+
653
+ Args:
654
+ media_type: The modality.
655
+ index: The index of the media; index 0 refers to the first element in `images`
656
+ inputs: The inputs
657
+
658
+ Returns:
659
+ The content or input_ids after replacement.
660
+ """
661
+ if media_type == 'image':
662
+ if self.mode == 'lmdeploy':
663
+ return [[-100]]
664
+ return self.image_placeholder
665
+ elif media_type == 'video':
666
+ return self.video_placeholder
667
+ elif media_type == 'audio':
668
+ return self.audio_placeholder
669
+
670
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
671
+ """Replace objects referenced by the bbox to contents or input_ids. This is useful in the grounding task.
672
+ Override this function to do your own replace operation.
673
+
674
+ Args:
675
+ ref: Description of the bbox
676
+ index: The index in the `objects` key
677
+ inputs: The inputs
678
+
679
+ Returns:
680
+ The contents or input_ids after replacement
681
+ """
682
+ return [ref]
683
+
684
+ def replace_cot_process(self, inputs: StdTemplateInputs) -> List[Context]:
685
+ """Replace the cot process label for PRM training or inference.
686
+ Override this function to do your own replace operation.
687
+
688
+ Args:
689
+ inputs: The inputs
690
+
691
+ Returns:
692
+ The contents or input_ids after replacement
693
+ """
694
+ return [self.cot_process_placeholder]
695
+
696
+ @staticmethod
697
+ def _get_bbox_str(bbox: List[int]) -> str:
698
+ point = []
699
+ for x, y in zip(bbox[::2], bbox[1::2]):
700
+ point.append(f'({x},{y})')
701
+ return ','.join(point)
702
+
703
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
704
+ """Replace bbox pointing to the objects to contents or input_ids. This is useful in the grounding task.
705
+ Override this function to do your own replace operation.
706
+
707
+ Args:
708
+ bbox: [x, y] or [x1, y1, x2, y2]
709
+ index: The index in the `objects` key
710
+ inputs: The inputs
711
+
712
+ Returns:
713
+ The contents or input_ids after replacement
714
+ """
715
+ return [f'[{self._get_bbox_str(bbox)}]']
716
+
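+ # Default formatting sketch (hypothetical coordinates):
+ # _get_bbox_str([10, 20, 110, 220]) -> '(10,20),(110,220)'
+ # replace_bbox([10, 20, 110, 220], 0, inputs) -> ['[(10,20),(110,220)]']
+ # Multimodal templates typically override this to emit model-specific box tokens.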
717
+ def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: List[float],
718
+ inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
719
+ # https://github.com/modelscope/ms-swift/issues/3407
720
+ # Fix the bounding box position offset issue in the Qwen2.5-VL grounding task.
721
+ res: List[Context] = []
722
+ res_loss_scale: List[float] = []
723
+ inputs.image_idx = 0
724
+
725
+ for context, loss_scale in zip(context_list, loss_scale_list):
726
+ if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
727
+ c_list = self.replace_tag('image', inputs.image_idx, inputs)
728
+ inputs.image_idx += 1
729
+ loss_scale = 0. if self.template_backend == 'swift' else 1.
730
+ else:
731
+ c_list = [context]
732
+ res += c_list
733
+ res_loss_scale += [loss_scale] * len(c_list)
734
+ return res, res_loss_scale
735
+
736
+ def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
737
+ inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
738
+ """This method happens before tokenization, replace standard tags to the contents or input_ids needed by
739
+ the model.
740
+
741
+ Args:
742
+ context_list: The content list
743
+ loss_scale_list: The loss scale list
744
+ Returns:
745
+ The context_list and loss_scale_list after replacement.
746
+ """
747
+ context_list, loss_scale_list = self._pre_tokenize_images(context_list, loss_scale_list, inputs)
748
+ if inputs.images and inputs.objects:
749
+ self.normalize_bbox(inputs)
750
+ # replace tag/object/box
751
+ res: List[Context] = [] # result of context_list
752
+ res_loss_scale: List[float] = [] # result of loss_scale_list
753
+
754
+ # reset
755
+ for k in ['video', 'audio', 'object', 'box']:
756
+ setattr(inputs, f'{k}_idx', 0)
757
+
758
+ for context, loss_scale in zip(context_list, loss_scale_list):
759
+ for k in ['video', 'audio']:
760
+ if context == f'<{k}>' and inputs.is_multimodal and getattr(inputs, f'{k}_idx') < len(
761
+ getattr(inputs, f'{k}s')):
762
+ c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
763
+ setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
764
+ loss_scale = 0.
765
+ break
766
+ else:
767
+ ref = inputs.objects.get('ref') or []
768
+ bbox = inputs.objects.get('bbox') or []
769
+ if context == '<ref-object>' and inputs.ref_idx < len(ref):
770
+ idx = inputs.ref_idx
771
+ c_list = self.replace_ref(ref[idx], idx, inputs)
772
+ inputs.ref_idx += 1
773
+ elif context == '<bbox>' and inputs.bbox_idx < len(bbox):
774
+ idx = inputs.bbox_idx
775
+ c_list = self.replace_bbox(bbox[idx], idx, inputs)
776
+ inputs.bbox_idx += 1
777
+ elif context == '<cot-process>' and self.mode == 'prm':
778
+ c_list = self.replace_cot_process(inputs)
779
+ else:
780
+ c_list = [context]
781
+ res += c_list
782
+ res_loss_scale += [loss_scale] * len(c_list)
783
+ return res, res_loss_scale
784
+
785
+ @staticmethod
786
+ def _add_default_tags(inputs: StdTemplateInputs):
787
+ total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
788
+ if inputs.rejected_response:
789
+ if isinstance(inputs.rejected_response, str):
790
+ total_content += inputs.rejected_response
791
+ else:
792
+ total_content += '\n'.join(inputs.rejected_response)
793
+ if inputs.system:
794
+ total_content = f'{inputs.system}\n{total_content}'
795
+ for media_type in ['image', 'audio', 'video']:
796
+ media_key, media_tag = f'{media_type}s', f'<{media_type}>'
797
+ medias = getattr(inputs, media_key)
798
+ if not isinstance(medias, list):
799
+ medias = [medias]
800
+ if medias:
801
+ num_media_tags = len(re.findall(media_tag, total_content))
802
+ num_media = len(medias)
803
+ num_new_tags = num_media - num_media_tags
804
+ if num_new_tags > 0:
805
+ inputs.messages[0]['content'] = media_tag * num_new_tags + inputs.messages[0]['content']
806
+ elif num_new_tags < 0:
807
+ logger.warning(
808
+ f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
809
+ 'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')
810
+
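+ # Behavior sketch (assumed inputs): with two images but only one '<image>' tag in the
+ # messages, one extra '<image>' is prepended to the first message's content; if there
+ # are more tags than media, only a warning is emitted and nothing is removed.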
811
+ def _encode_context_list(
812
+ self,
813
+ context_list: List[Context],
814
+ loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
815
+ """return: input_ids, labels, tokenizer_kwargs"""
816
+ input_ids: List[int] = []
817
+ labels: List[int] = []
818
+ loss_scale: List[float] = []
819
+ tokenizer_kwargs = {}
820
+ if loss_scale_list is None:
821
+ loss_scale_list = [0.] * len(context_list)
822
+ ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
823
+ for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
824
+ if isinstance(context, str):
825
+ # tokenizer_kwargs is the returned tokenizer_kwargs,
826
+ # while curr_tokenizer_kwargs is the tokenizer_kwargs for the current context.
827
+ token_list = self._tokenize(context)
828
+ else:
829
+ token_list = context
830
+ input_ids += token_list
831
+ if loss_scale_list[i] > 0.0:
832
+ labels += token_list
833
+ else:
834
+ labels += [-100] * len(token_list)
835
+ if not ignore_loss_scale:
836
+ loss_scale.extend([loss_weight] * len(token_list))
837
+ if ignore_loss_scale:
838
+ loss_scale = None
839
+ return input_ids, labels, loss_scale, tokenizer_kwargs
840
+
841
+ @staticmethod
842
+ def _add_dynamic_eos(input_ids: List[int], labels: List[int], loss_scale: Optional[List[int]],
843
+ suffix_tokens_id: List[int]) -> None:
844
+ suffix_len = len(suffix_tokens_id)
845
+ start = 0
846
+ for i in range(1, len(labels)):
847
+ if labels[i - 1] >= 0 and labels[i] == -100:
848
+ start = i
849
+ if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
850
+ # [0, 1, 2, -100(start), -100, 3(i), 4]
851
+ length = i - start
852
+ if length >= suffix_len and input_ids[start:start + suffix_len] == suffix_tokens_id:
853
+ labels[start:start + suffix_len] = suffix_tokens_id
854
+ if loss_scale and loss_scale[start:start + suffix_len] == [0] * suffix_len:
855
+ loss_scale[start:start + suffix_len] = [1] * suffix_len
856
+
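+ # Worked sketch with a single suffix token (assumed ids), suffix_tokens_id = [2]:
+ # input_ids = [10, 11, 2, 20, 21, 30]
+ # labels = [10, 11, -100, -100, -100, 30]
+ # after: labels = [10, 11, 2, -100, -100, 30]
+ # i.e. the suffix/EOS closing an assistant turn is unmasked so it is trained on,
+ # while the following user turn stays masked.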
857
+ @staticmethod
858
+ def _get_std_messages(messages):
859
+ if messages and messages[0]['role'] == 'assistant':
860
+ messages.insert(0, {'role': 'user', 'content': ''}) # pretrain
861
+ if len(messages) % 2 == 1:
862
+ messages.append({'role': 'assistant', 'content': None}) # inference
863
+
864
+ def _jinja_encode(self, inputs: StdTemplateInputs):
865
+ messages = inputs.messages.copy()
866
+ if inputs.system is not None:
867
+ messages.insert(0, {'role': 'system', 'content': inputs.system})
868
+ if messages[-1]['content'] is None:
869
+ messages.pop()
870
+ add_generation_prompt = messages[-1]['role'] != 'assistant'
871
+ kwargs = {}
872
+ if inputs.tools:
873
+ kwargs['tools'] = inputs.tools
874
+ text = self.tokenizer.apply_chat_template(
875
+ messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
876
+ answer_len = 1 if self.is_training else 0
877
+ return [text], [1.], answer_len
878
+
879
+ def _get_system(self, inputs) -> Optional[str]:
880
+ template_meta = self.template_meta
881
+ system = inputs.system
882
+ tools = inputs.tools
883
+ template_meta.check_system(system)
884
+ if system is None:
885
+ system = template_meta.default_system
886
+
887
+ if tools is not None:
888
+ system = self.agent_template._format_tools(tools, system or '', inputs.messages[0])
889
+ return system
890
+
891
+ @staticmethod
892
+ def _swift_prepare_function_call(agent_template, messages):
893
+ if len(messages) < 2:
894
+ return
895
+ i = 1
896
+ while i < len(messages):
897
+ pre_message, message = messages[i - 1], messages[i]
898
+ pre_role, pre_content = pre_message['role'], pre_message['content']
899
+ role, content = message['role'], message['content']
900
+ if pre_role == 'assistant' and role == 'tool':
901
+ i_start = i
902
+ while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
903
+ i += 1
904
+ pre_message['content'], tool_content = agent_template._format_tool_responses(
905
+ pre_content, messages[i_start:i + 1])
906
+ messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
907
+ i = i_start + 1
908
+ elif pre_role == 'assistant' and role == 'assistant':
909
+ # Consecutive messages from the assistant role need to be merged to prevent errors.
910
+ pre_message['content'] = pre_content + content
911
+ messages.pop(i)
912
+ else:
913
+ i += 1
914
+
915
+ def _swift_encode(self, inputs: StdTemplateInputs):
916
+ template_meta = self.template_meta
917
+ system = self._get_system(inputs)
918
+ self._swift_prepare_function_call(self.agent_template, inputs.messages)
919
+
920
+ self._get_std_messages(inputs.messages)
921
+ n_round = len(inputs.messages) // 2
922
+ if n_round > 1 and not self.template_meta.support_multi_round:
923
+ logger.warning_once(
924
+ 'The template does not support multi-round chat. Only use the last round of the conversation.')
925
+ inputs.messages = inputs.messages[-2:]
926
+
927
+ res_context_list: List[Context] = []
928
+ res_context_types: List[ContextType] = []
929
+ sep_token = None
930
+ if template_meta.auto_add_bos:
931
+ all_tokens = self.tokenizer.encode('a')
932
+ single_token = self.tokenizer.encode('a', add_special_tokens=False)
933
+ assert len(single_token) == 1
934
+ idx = all_tokens.index(single_token[0])
935
+ bos_token = all_tokens[:idx]
936
+ sep_token = all_tokens[idx + 1:]
937
+ if bos_token:
938
+ res_context_list.append(bos_token)
939
+ res_context_types.append(ContextType.OTHER)
940
+
941
+ prefix = template_meta.system_prefix if system else template_meta.prefix
942
+ self._concat_context_list(prefix, res_context_list, res_context_types, system=system)
943
+
944
+ n_round = len(inputs.messages) // 2
945
+ for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
946
+ query_role, query = query_message['role'], query_message['content']
947
+ response_role, response = response_message['role'], response_message['content']
948
+ # TODO: Optimize the Template mechanism.
949
+ assert query_role in {'user', 'tool'}, f'query_role: {query_role}'
950
+ assert response_role in {'assistant'}, f'response_role: {response_role}'
951
+ if query_role == 'tool':
952
+ prompt = query
953
+ query = ''
954
+ elif template_meta.is_post_system and i == n_round - 1:
955
+ prompt = template_meta.system_prompt
956
+ else:
957
+ prompt = template_meta.prompt
958
+
959
+ context_list = prompt.copy()
960
+ extra_context_list = []
961
+ extra_context_type = None
962
+ if i < n_round - 1:
963
+ # Not the last round.
964
+ context_list.append('{{RESPONSE}}')
965
+ if inputs.messages[2 * (i + 1)]['role'] != 'tool':
966
+ extra_context_list = template_meta.chat_sep
967
+ extra_context_type = ContextType.OTHER
968
+ elif response is not None:
969
+ # It is the final round, and the response exists (during training).
970
+ context_list.append('{{RESPONSE}}')
971
+ if self.is_training and not sep_token:
972
+ extra_context_list = template_meta.suffix
973
+ extra_context_type = ContextType.SUFFIX
974
+ elif template_meta.response_prefix:
975
+ # final round and during inference.
976
+ context_list.append(template_meta.response_prefix)
977
+
978
+ self._concat_context_list(
979
+ context_list,
980
+ res_context_list,
981
+ res_context_types,
982
+ query=query,
983
+ response=response,
984
+ system=system,
985
+ round0=i)
986
+ res_context_list += extra_context_list
987
+ res_context_types += [extra_context_type] * len(extra_context_list)
988
+ if template_meta.auto_add_bos and sep_token:
989
+ res_context_list.append(sep_token)
990
+ res_context_types.append(ContextType.SUFFIX)
991
+ from swift.plugin import loss_scale_map
992
+ res_context_list, loss_scale_list = loss_scale_map[self.loss_scale](res_context_list, res_context_types,
993
+ inputs.messages)
994
+ if self.is_training:
995
+ answer_len = len(extra_context_list) + bool(response is not None)
996
+ else:
997
+ answer_len = 0
998
+ return res_context_list, loss_scale_list, answer_len
999
+
1000
+ def _encode_truncated(self, inputs):
1001
+ if self.mode in {'vllm', 'lmdeploy'}:
1002
+ encoded = Template._encode(self, inputs)
1003
+ for key in ['images', 'audios', 'videos']:
1004
+ encoded[key] = getattr(inputs, key)
1005
+ else:
1006
+ encoded = self._encode(inputs)
1007
+
1008
+ input_ids = encoded.get('input_ids')
1009
+ labels = encoded.get('labels')
1010
+ loss_scale = encoded.get('loss_scale')
1011
+ if self.max_length is not None:
1012
+ if self.truncation_strategy == 'right':
1013
+ input_ids = input_ids[:self.max_length]
1014
+ if labels is not None:
1015
+ labels = labels[:self.max_length]
1016
+ if loss_scale is not None:
1017
+ loss_scale = loss_scale[:self.max_length]
1018
+ elif self.truncation_strategy == 'left':
1019
+ if len(input_ids) > self.max_length:
1020
+ logger.warning_once(
1021
+ 'Input data was left-truncated because its length exceeds `max_length` (input length: '
1022
+ f'{len(input_ids)}, max_length: {self.max_length}). '
1023
+ 'This may cause loss of important tokens (e.g., image tokens) and lead to errors. '
1024
+ 'To avoid this, consider increasing `max_length` or pre-filtering long sequences.',
1025
+ hash_id='max_length_check')
1026
+ input_ids = input_ids[-self.max_length:]
1027
+ if labels is not None:
1028
+ labels = labels[-self.max_length:]
1029
+ if loss_scale is not None:
1030
+ loss_scale = loss_scale[-self.max_length:]
1031
+ elif self.truncation_strategy == 'raise':
1032
+ length = len(input_ids or labels or [])
1033
+ if length > self.max_length:
1034
+ raise MaxLengthError(f'Current length of row({length}) is larger'
1035
+ f' than the max_length({self.max_length}).')
1036
+ encoded['input_ids'] = input_ids
1037
+ encoded['labels'] = labels
1038
+ encoded['loss_scale'] = loss_scale
1039
+ return encoded
1040
+
1041
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
1042
+ template_backend = self.template_backend
1043
+ if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
1044
+ and self.mode != 'seq_cls'):
1045
+ template_backend = 'jinja'
1046
+ logger.info_once(f'Setting template_backend: {template_backend}')
1047
+ res_context_list, loss_scale_list, answer_len = (
1048
+ self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
1049
+ encoded = {}
1050
+ if self.is_encoder_decoder:
1051
+ # tokenizer_kwargs: use prompt (qwen-audio)
1052
+ total_len = len(res_context_list)
1053
+ for key, _slice in zip(['prompt', 'answer'],
1054
+ [slice(0, total_len - answer_len),
1055
+ slice(total_len - answer_len, total_len)]):
1056
+ context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
1057
+ loss_scale_list[_slice], inputs)
1058
+ input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
1059
+ encoded[f'{key}_input_ids'] = input_ids
1060
+ if key == 'answer':
1061
+ encoded['labels'] = labels
1062
+ encoded['loss_scale'] = loss_scale
1063
+ input_ids = encoded['prompt_input_ids'] + encoded['answer_input_ids']
1064
+ else:
1065
+ res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
1066
+ input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
1067
+ res_context_list, loss_scale_list)
1068
+ self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0])
1069
+
1070
+ if tokenizer_kwargs:
1071
+ encoded['tokenizer_kwargs'] = tokenizer_kwargs
1072
+
1073
+ encoded['input_ids'] = input_ids
1074
+ encoded['labels'] = labels
1075
+ encoded['loss_scale'] = loss_scale
1076
+ if self.use_megatron:
1077
+ self._handle_megatron_cp(encoded)
1078
+ encoded['labels'] = encoded['labels'][1:] + [-100]
1079
+ encoded['position_ids'] = list(range(len(encoded['labels'])))
1080
+ elif encoded.get('labels') is not None:
1081
+ encoded['labels'][0] = -100
1082
+ if not self.is_training:
1083
+ for k in list(encoded.keys()):
1084
+ if k.endswith('labels') or k.endswith('loss_scale'):
1085
+ encoded[k] = None
1086
+ return encoded
1087
+
1088
+ def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None:
1089
+ cp_size = self.sequence_parallel_size
1090
+ if cp_size == 1:
1091
+ return
1092
+ input_ids = encoded['input_ids']
1093
+ padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids)
1094
+ input_ids += [self.tokenizer.pad_token_id] * padding_len
1095
+ encoded['labels'] += [-100] * padding_len
1096
+
1097
+ def debug_logger(self, inputs):
1098
+ if not strtobool(os.getenv('SWIFT_DEBUG', 'false')):
1099
+ return
1100
+ if 'input_ids' in inputs:
1101
+ k = 'input_ids'
1102
+ val = inputs['input_ids']
1103
+ elif 'generate_ids' in inputs:
1104
+ k = 'generate_ids'
1105
+ val = inputs['generate_ids']
1106
+ for v in val:
1107
+ self.print_inputs({k: v.tolist()})
1108
+
1109
+ @staticmethod
1110
+ def _split_list(inputs: List[int], x: int) -> List[List[int]]:
1111
+ idxs = findall(inputs, x)
1112
+ idxs.append(len(inputs))
1113
+ res = []
1114
+ lo = 0
1115
+ for idx in idxs:
1116
+ res.append(inputs[lo:idx])
1117
+ lo = idx + 1
1118
+ return res
1119
+
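+ # A one-line sketch with assumed values:
+ # _split_list([1, -100, 2, 3, -100, 4], -100) -> [[1], [2, 3], [4]]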
1120
+ def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
1121
+ context_list = []
1122
+ if self.mode in {'vllm', 'lmdeploy'}:
1123
+ video = inputs.videos.pop(inputs.video_idx)
1124
+ inputs.video_idx -= 1
1125
+ else:
1126
+ video = inputs.videos[inputs.video_idx]
1127
+ images = inputs.images
1128
+ new_images = load_video_func(video)
1129
+ inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
1130
+ for i in range(len(new_images)):
1131
+ context_list += replace_tag(i)
1132
+ inputs.image_idx += len(new_images)
1133
+ return context_list
1134
+
1135
+ def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]],
1136
+ num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]:
1137
+ if self.skip_prompt:
1138
+ generate_ids = generate_ids[..., num_prompt_tokens:]
1139
+ return generate_ids
1140
+
1141
+ def post_process_generate_response(self, response: str, inputs: StdTemplateInputs) -> str:
1142
+ return response
1143
+
1144
+ def pre_forward_hook(self, model: nn.Module, args, kwargs):
1145
+ from swift.llm import to_device
1146
+ old_kwargs = to_device(kwargs, model.device)
1147
+ kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
1148
+ for k, v in old_kwargs.items():
1149
+ if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
1150
+ kwargs[k] = v
1151
+ if 'inputs_embeds' in kwargs:
1152
+ kwargs.pop('input_ids', None)
1153
+
1154
+ if isinstance(model, PeftModel):
1155
+ parameters = inspect.signature(model.model.forward).parameters
1156
+ else:
1157
+ parameters = inspect.signature(model.forward).parameters
1158
+ if 'position_ids' not in parameters:
1159
+ kwargs.pop('position_ids', None)
1160
+ return args, kwargs
1161
+
1162
+ @property
1163
+ def is_training(self):
1164
+ return self.mode not in {'vllm', 'lmdeploy', 'pt'}
1165
+
1166
+ def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
1167
+ self.mode = mode
1168
+
1169
+ def register_post_encode_hook(self, models: List[nn.Module]) -> None:
1170
+ """This function is important for multi-modal training, as it registers the post_encode method
1171
+ as a forward hook, converting input_ids into inputs_embeds.
1172
+ """
1173
+ if self._handles:
1174
+ return
1175
+
1176
+ for model in models:
1177
+ # please use torch>=2.0
1178
+ handle = model.register_forward_pre_hook(self.pre_forward_hook, with_kwargs=True)
1179
+ self._handles.append((model, handle))
1180
+
1181
+ if is_deepspeed_zero3_enabled():
1182
+ import deepspeed
1183
+ self._deepspeed_initialize = deepspeed.initialize
1184
+
1185
+ @wraps(self._deepspeed_initialize)
1186
+ def _initialize(*args, **kwargs):
1187
+ res = self._deepspeed_initialize(*args, **kwargs)
1188
+ for model, handle in self._handles:
1189
+ model._forward_pre_hooks.move_to_end(handle.id)
1190
+ return res
1191
+
1192
+ deepspeed.initialize = _initialize
1193
+
1194
+ def remove_post_encode_hook(self):
1195
+ models = []
1196
+ for model, handle in self._handles:
1197
+ models.append(model)
1198
+ handle.remove()
1199
+ self._handles = []
1200
+
1201
+ if self._deepspeed_initialize is not None:
1202
+ import deepspeed
1203
+ deepspeed.initialize = self._deepspeed_initialize
1204
+ self._deepspeed_initialize = None
1205
+ return models
1206
+
1207
+ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
1208
+ if self.mode == 'rlhf':
1209
+ return self._rlhf_data_collator(batch, padding_to=padding_to)
1210
+ elif self.mode == 'kto':
1211
+ return self._kto_data_collator(batch, padding_to=padding_to)
1212
+ elif self.mode in {'pt', 'train', 'prm'}:
1213
+ return self._data_collator(batch, padding_to=padding_to)
1214
+ elif self.mode == 'seq_cls':
1215
+ return self._seq_cls_data_collator(batch, padding_to=padding_to)
1216
+ elif self.mode == 'embedding':
1217
+ return self._embedding_data_collator(batch, padding_to=padding_to)
1218
+
1219
+ @staticmethod
1220
+ def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]:
1221
+ new_batch = []
1222
+ for inputs in batch:
1223
+ new_inputs = {}
1224
+ for k, v in inputs.items():
1225
+ if k.startswith(prefix):
1226
+ new_inputs[k[len(prefix):]] = v
1227
+ new_batch.append(new_inputs)
1228
+ return new_batch
1229
+
1230
+ @staticmethod
1231
+ def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]:
1232
+ from swift.llm import RowPreprocessor
1233
+ keys = keys or []
1234
+ rows = RowPreprocessor.rows_to_batched(batch)
1235
+ return {k: rows[k] for k in keys if rows.get(k) is not None}
1236
+
1237
+ @staticmethod
1238
+ def gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
1239
+ # List[Tensor] -> List[Tensor]
1240
+ res = []
1241
+ for b in batch:
1242
+ if b.get(attr_name) is not None:
1243
+ res += b.pop(attr_name)
1244
+ return res
1245
+
1246
+ @staticmethod
1247
+ def concat_tensor(batch: List[Dict[str, Any]], attr_name: str, dim: int) -> Optional[torch.Tensor]:
1248
+ res = []
1249
+ for b in batch:
1250
+ if b.get(attr_name) is not None:
1251
+ res.append(b.pop(attr_name))
1252
+ return torch.concat(res, dim=dim) if res else None
1253
+
1254
+ def _rlhf_data_collator(self,
1255
+ batch: List[Dict[str, Any]],
1256
+ *,
1257
+ chosen_prefix: str = 'chosen_',
1258
+ rejected_prefix: str = 'rejected_',
1259
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1260
+ new_batch = []
1261
+ for prefix in [chosen_prefix, rejected_prefix]:
1262
+ new_batch += self._fetch_inputs_startswith(batch, prefix)
1263
+ return self._data_collator(new_batch, padding_to=padding_to)
1264
+
1265
+ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
1266
+ new_batch = self._fetch_inputs_startswith(batch, 'chosen_')
1267
+ kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')
1268
+
1269
+ res = self._data_collator(new_batch, padding_to=padding_to)
1270
+ kl_res = self._data_collator(kl_batch, padding_to=padding_to)
1271
+ res = {
1272
+ **{f'completion_{k}': v
1273
+ for k, v in res.items()},
1274
+ **{f'KL_completion_{k}': v
1275
+ for k, v in kl_res.items()},
1276
+ }
1277
+ label = [b['label'] for b in batch if b.get('label') is not None]
1278
+ if label:
1279
+ res['label'] = label
1280
+ return res
1281
+
1282
+ def _embedding_data_collator(self,
1283
+ batch: List[Dict[str, Any]],
1284
+ *,
1285
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1286
+ labels = []
1287
+ new_batch = []
1288
+ for b in batch:
1289
+ keys = [key for key in b.keys() if 'negative' in key]
1290
+ max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
1291
+ indexes = ['anchor_', 'positive_']
1292
+ if max_neg is not None:
1293
+ for i in range(0, max_neg + 1):
1294
+ indexes.append(f'negative{i}_')
1295
+ for prefix in indexes:
1296
+ new_batch += self._fetch_inputs_startswith([b], prefix)
1297
+ labels.extend(b.get('labels', None))
1298
+ res = self._data_collator(new_batch, padding_to=padding_to)
1299
+ if labels:
1300
+ res['labels'] = torch.tensor(labels, dtype=torch.float32)
1301
+ return res
1302
+
1303
+ def _seq_cls_data_collator(self,
1304
+ batch: List[Dict[str, Any]],
1305
+ *,
1306
+ padding_to: Optional[int] = None) -> Dict[str, Any]:
1307
+ labels = [b.pop('labels') for b in batch if b.get('labels') is not None]
1308
+ res = self._data_collator(batch, padding_to=padding_to)
1309
+ if labels:
1310
+ problem_type = self._get_problem_type(self.config)
1311
+ if problem_type == 'regression':
1312
+ labels = torch.tensor(labels, dtype=torch.float32)
1313
+ elif problem_type == 'multi_label_classification':
1314
+ one_hot_labels = torch.zeros((len(labels), self.config.num_labels), dtype=torch.float32)
1315
+ for i, label in enumerate(labels):
1316
+ one_hot_labels[i, label] = 1
1317
+ labels = one_hot_labels
1318
+ else:
1319
+ labels = torch.tensor(labels, dtype=torch.long)
1320
+ res['labels'] = labels
1321
+ return res
1322
+
1323
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
1324
+ """
1325
+ Args:
1326
+ batch(`List[Dict[str, Any]]`): The input data in batch
1327
+ padding_to(`int`, optional): Whether to pad the batch to a fixed length; if None, the batch
1328
+ will be padded to the longest sequence in the batch
1329
+ """
1330
+ assert self.tokenizer.pad_token_id is not None
1331
+ padding_side = self.padding_side if self.is_training else 'left'
1332
+ padding_right = padding_side == 'right'
1333
+ packing_mode = self.use_megatron or self._packing and 'position_ids' in batch[0]
1334
+ res = {}
1335
+ if packing_mode:
1336
+ # only support llm
1337
+ for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
1338
+ v = self.gather_list(batch, k)
1339
+ if v:
1340
+ res[k] = [v]
1341
+ else:
1342
+ inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
1343
+ input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
1344
+ if inputs_embeds:
1345
+ res['inputs_embeds'] = inputs_embeds
1346
+ if input_ids:
1347
+ res['input_ids'] = input_ids
1348
+ for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
1349
+ val = [b[key] for b in batch if b.get(key) is not None]
1350
+ if val:
1351
+ res[key] = val
1352
+
1353
+ keys = [
1354
+ 'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
1355
+ ]
1356
+ pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
1357
+ # Convert to tensor and remove unnecessary dimensions.
1358
+ seq_lens = None
1359
+ for key in keys:
1360
+ if key not in res:
1361
+ continue
1362
+ for i, val in enumerate(res[key]):
1363
+ if isinstance(val, (list, tuple)):
1364
+ val = torch.tensor(val)
1365
+ elif key == 'inputs_embeds' and val.ndim == 3 or key != 'inputs_embeds' and val.ndim == 2:
1366
+ val = val[0]
1367
+ res[key][i] = val
1368
+ if not seq_lens:
1369
+ seq_lens = [seq.shape[0] for seq in res[key]]
1370
+ if not packing_mode and seq_lens and ('input_ids' in res or 'inputs_embeds' in res):
1371
+ res['attention_mask'] = [torch.ones(seq_len, dtype=torch.int64) for seq_len in seq_lens]
1372
+ if self.is_training and self.padding_side == 'left':
1373
+ res['position_ids'] = [torch.arange(seq_len, dtype=torch.int64) for seq_len in seq_lens]
1374
+
1375
+ if self.use_megatron:
1376
+ padding_to = math.ceil(max(seq_lens) / 128) * 128
1377
+ cp_size = self.sequence_parallel_size
1378
+ if cp_size > 1:
1379
+ padding_len = padding_to - seq_lens[0]
1380
+ position_ids = res['position_ids'][0].tolist()
1381
+ position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2))
1382
+ res['position_ids'][0] = torch.tensor(position_ids)
1383
+
1384
+ for key, pad_value in zip(keys, pad_values):
1385
+ if key not in res:
1386
+ continue
1387
+ if self.use_megatron and key == 'position_ids' and self.sequence_parallel_size > 1:
1388
+ pass
1389
+ elif padding_to is not None:
1390
+ padding_len = padding_to - seq_lens[0]
1391
+ if padding_len > 0:
1392
+ res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
1393
+ 'constant', pad_value)
1394
+ res[key] = self._pad_sequence(res[key], pad_value)
1395
+
1396
+ # multimodal
1397
+ res.update(self._data_collator_mm_data(batch))
1398
+ if not self.use_megatron and (use_torchacc() or self.sequence_parallel_size > 1):
1399
+ res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)
1400
+
1401
+ return res
1402
+
1403
+ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
1404
+ # multimodal
1405
+ res = {}
1406
+ pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
1407
+ if len(pixel_values) > 0:
1408
+ res['pixel_values'] = torch.concat(pixel_values)
1409
+
1410
+ image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
1411
+ if len(image_sizes) > 0:
1412
+ res['image_sizes'] = torch.concat(image_sizes)
1413
+
1414
+ pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
1415
+ if len(pixel_values_videos) > 0:
1416
+ res['pixel_values_videos'] = torch.concat(pixel_values_videos)
1417
+ return res
1418
+
1419
+ def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_side):
1420
+ # torchacc & xtuner
1421
+ input_ids = res.get('input_ids')
1422
+ attention_mask = res.get('attention_mask')
1423
+ labels = res.get('labels')
1424
+ loss_scale = res.get('loss_scale')
1425
+ if use_torchacc():
1426
+ from swift.utils.torchacc_utils import pad_and_split_batch
1427
+ rank, _, world_size, _ = get_dist_setting()
1428
+ input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
1429
+ padding_to,
1430
+ input_ids,
1431
+ attention_mask,
1432
+ labels,
1433
+ loss_scale,
1434
+ self.max_length,
1435
+ tokenizer,
1436
+ rank,
1437
+ world_size,
1438
+ padding_right=padding_side == 'right')
1439
+ if self.sequence_parallel_size > 1 and input_ids is not None:
1440
+ bs, seq_len = input_ids.shape
1441
+ if 'position_ids' not in res:
1442
+ position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
1443
+ else:
1444
+ position_ids = res['position_ids']
1445
+ assert padding_side == 'right' or bs == 1, 'Sequence parallel only support padding_side=right'
1446
+ from swift.trainers.sequence_parallel import sequence_parallel
1447
+ if sequence_parallel.world_size() > 1:
1448
+ from swift.trainers.sequence_parallel import sequence_parallel
1449
+ input_ids, _, labels, position_ids, attention_mask, loss_scale = \
1450
+ sequence_parallel.pad_and_split_inputs(
1451
+ tokenizer, input_ids, None, labels, position_ids, attention_mask, loss_scale)
1452
+ res['position_ids'] = position_ids
1453
+ _local_var = locals()
1454
+ for key in ['input_ids', 'attention_mask', 'labels', 'loss_scale']:
1455
+ value = _local_var[key]
1456
+ if value is not None:
1457
+ res[key] = value
1458
+ return res
1459
+
1460
+ def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
1461
+ if tokenizer_kwargs is None:
1462
+ tokenizer_kwargs = {}
1463
+ for key in [
1464
+ 'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
1465
+ ]:
1466
+ val = inputs.get(key) # fix val is a tensor
1467
+ if val is None:
1468
+ val = inputs.get(f'{key}_ids')
1469
+ if val is not None:
1470
+ key_upper = key.upper()
1471
+ logger.info(f'[{key_upper}_IDS] {val}')
1472
+ if key == 'labels' and self.mode in {'seq_cls', 'embedding'}:
1473
+ continue
1474
+ if isinstance(val, (list, tuple, torch.Tensor)):
1475
+ val_str = self.safe_decode(val, **tokenizer_kwargs)
1476
+ logger.info(f'[{key_upper}] {val_str}')
1477
+ if inputs.get('loss_scale') is not None:
1478
+ val = inputs['loss_scale']
1479
+ logger.info(f'[LOSS_SCALE] {val}')
1480
+
1481
+ async def prepare_lmdeploy_pytorch_inputs(self, inputs) -> None:
1482
+ images = inputs.pop('images', None) or []
1483
+ if len(images) == 0:
1484
+ return
1485
+ input_ids = inputs['input_ids']
1486
+ idx_list = findall(input_ids, -100)
1487
+ assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
1488
+ idx_list.insert(0, -1)
1489
+ new_input_ids = []
1490
+ for i in range(len(idx_list) - 1):
1491
+ new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
1492
+ images[i]['offset'] = len(new_input_ids)
1493
+ new_input_ids += [images[i]['image_token_id']] * images[i]['image_tokens']
1494
+ new_input_ids += input_ids[idx_list[-1] + 1:]
1495
+ inputs['input_ids'] = new_input_ids
1496
+ inputs['multimodal'] = images
1497
+
1498
+ async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
1499
+ images = inputs.pop('images', None) or []
1500
+ if len(images) == 0:
1501
+ return
1502
+ from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
1503
+ input_ids = inputs['input_ids']
1504
+ idx_list = findall(input_ids, -100)
1505
+ assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
1506
+ idx_list.insert(0, -1)
1507
+ new_input_ids = []
1508
+ ranges = []
1509
+ for i in range(len(idx_list) - 1):
1510
+ _range = []
1511
+ new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
1512
+ _range.append(len(new_input_ids))
1513
+ new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * images[i].shape[0]
1514
+ _range.append(len(new_input_ids))
1515
+ ranges.append(_range)
1516
+ new_input_ids += input_ids[idx_list[-1] + 1:]
1517
+ inputs['input_embeddings'] = [image.to('cpu') for image in images]
1518
+ inputs['input_embedding_ranges'] = ranges
1519
+ inputs['input_ids'] = new_input_ids
1520
+
1521
+ def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.) -> torch.Tensor:
1522
+ """Pad sequence by some side
1523
+
1524
+ Args:
1525
+ sequences: The input sequences in tensor.
1526
+ padding_value: The padding value
1527
+
1528
+ Returns:
1529
+ A tensor after padding
1530
+ """
1531
+ padding_side = self.padding_side if self.is_training else 'left'
1532
+ padding_right = padding_side == 'right'
1533
+ if padding_right:
1534
+ return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
1535
+
1536
+ max_len = max([s.shape[0] for s in sequences])
1537
+
1538
+ padded_sequences = []
1539
+ for seq in sequences:
1540
+ pad_length = max_len - seq.shape[0]
1541
+ pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
1542
+ padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
1543
+ padded_sequences.append(padded_seq)
1544
+
1545
+ return torch.stack(padded_sequences)
1546
+
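+ # Left-padding sketch (assumed values): padding [tensor([1, 2, 3]), tensor([4])] with
+ # padding_value=0 on the left yields tensor([[1, 2, 3], [0, 0, 4]]); right padding
+ # delegates to torch's pad_sequence and yields tensor([[1, 2, 3], [4, 0, 0]]).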
1547
+ def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
1548
+ if isinstance(self, Template):
1549
+ tokenizer = self.tokenizer
1550
+ placeholder_tokens = self.placeholder_tokens
1551
+ else:
1552
+ tokenizer = self
1553
+ placeholder_tokens = []
1554
+
1555
+ def _is_special(token: int) -> bool:
1556
+ if isinstance(token, float) or token < 0:
1557
+ return True
1558
+ return token in placeholder_tokens
1559
+
1560
+ if isinstance(input_ids, torch.Tensor):
1561
+ input_ids = input_ids.tolist()
1562
+ if len(input_ids) == 0:
1563
+ return ''
1564
+ result_str = ''
1565
+ for i in range(len(input_ids)):
1566
+ if i == 0:
1567
+ if _is_special(input_ids[i]):
1568
+ s = 0
1569
+ else:
1570
+ e = 0
1571
+ continue
1572
+ if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
1573
+ s = i
1574
+ result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
1575
+ if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
1576
+ e = i
1577
+ result_str += f'[{input_ids[i - 1]} * {e - s}]'
1578
+ if _is_special(input_ids[i]):
1579
+ result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
1580
+ else:
1581
+ result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
1582
+ return result_str
1583
+
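+ # Decoding sketch (assumed ids, with -100 as a placeholder token):
+ # safe_decode([100, 101, -100, -100, -100, 102])
+ # -> decode([100, 101]) + '[-100 * 3]' + decode([102])
+ # i.e. runs of placeholder/special tokens are compressed into '[token * count]' markers.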
1584
+ @staticmethod
1585
+ @contextmanager
1586
+ def _patch_flash_attention_forward(modeling_module, position_ids, use_new_func: bool = False):
1587
+ _origin_flash_attention_forward = modeling_module._flash_attention_forward
1588
+
1589
+ def _flash_attention_forward(*args, **kwargs):
1590
+ if use_new_func:
1591
+ from transformers.modeling_flash_attention_utils import (_flash_attention_forward as
1592
+ flash_attention_forward)
1593
+ if args and isinstance(args[0], nn.Module):
1594
+ args = args[1:]
1595
+ if 'is_causal' not in kwargs:
1596
+ kwargs['is_causal'] = True
1597
+ else:
1598
+ flash_attention_forward = _origin_flash_attention_forward
1599
+ kwargs['position_ids'] = position_ids
1600
+ return flash_attention_forward(*args, **kwargs)
1601
+
1602
+ modeling_module._flash_attention_forward = _flash_attention_forward
1603
+ try:
1604
+ yield
1605
+ finally:
1606
+ modeling_module._flash_attention_forward = _origin_flash_attention_forward
swift/llm/template/constant.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from typing import List
4
+
5
+
6
+ class LLMTemplateType:
7
+ chatml = 'chatml'
8
+ default = 'default'
9
+ dummy = 'dummy'
10
+
11
+ qwen = 'qwen'
12
+ qwen2_5 = 'qwen2_5'
13
+ qwen2_5_math = 'qwen2_5_math'
14
+ qwen2_5_math_prm = 'qwen2_5_math_prm'
15
+ qwen3 = 'qwen3'
16
+ qwq_preview = 'qwq_preview'
17
+ qwq = 'qwq'
18
+ marco_o1 = 'marco_o1'
19
+ modelscope_agent = 'modelscope_agent'
20
+
21
+ llama = 'llama' # llama2
22
+ llama3 = 'llama3'
23
+ llama3_2 = 'llama3_2'
24
+ reflection = 'reflection'
25
+ megrez = 'megrez'
26
+ yi_coder = 'yi_coder'
27
+ sus = 'sus'
28
+
29
+ minimax = 'minimax'
30
+ minimax_vl = 'minimax_vl'
31
+
32
+ numina = 'numina'
33
+ ziya = 'ziya'
34
+ atom = 'atom'
35
+ mengzi = 'mengzi'
36
+
37
+ chatglm2 = 'chatglm2'
38
+ glm4 = 'glm4'
39
+ glm4_0414 = 'glm4_0414'
40
+ glm4_z1_rumination = 'glm4_z1_rumination'
41
+ codegeex4 = 'codegeex4'
42
+ longwriter_llama = 'longwriter_llama'
43
+
44
+ internlm = 'internlm'
45
+ internlm2 = 'internlm2'
46
+ internlm3 = 'internlm3'
47
+
48
+ deepseek = 'deepseek'
49
+ deepseek_coder = 'deepseek_coder'
50
+ deepseek_v2_5 = 'deepseek_v2_5'
51
+ deepseek_r1 = 'deepseek_r1'
52
+
53
+ openbuddy = 'openbuddy'
54
+ openbuddy2 = 'openbuddy2'
55
+ baichuan = 'baichuan'
56
+ baichuan_m1 = 'baichuan_m1'
57
+ minicpm = 'minicpm'
58
+ telechat = 'telechat'
59
+ telechat2 = 'telechat2'
60
+
61
+ codefuse = 'codefuse'
62
+ codefuse_codellama = 'codefuse_codellama'
63
+
64
+ skywork = 'skywork'
65
+ skywork_o1 = 'skywork_o1'
66
+
67
+ mistral_nemo = 'mistral_nemo'
68
+ mistral_2501 = 'mistral_2501'
69
+ zephyr = 'zephyr'
70
+ wizardlm2 = 'wizardlm2'
71
+ wizardlm2_moe = 'wizardlm2_moe'
72
+ gemma = 'gemma'
73
+ gemma3_text = 'gemma3_text'
74
+ phi3 = 'phi3'
75
+ phi4 = 'phi4'
76
+
77
+ ling = 'ling'
78
+ yuan = 'yuan'
79
+ xverse = 'xverse'
80
+ bluelm = 'bluelm'
81
+ orion = 'orion'
82
+ moonlight = 'moonlight'
83
+
84
+ aya = 'aya'
85
+ c4ai = 'c4ai'
86
+ dbrx = 'dbrx'
87
+
88
+
89
+ class RMTemplateType:
90
+ internlm2_reward = 'internlm2_reward'
91
+
92
+
93
+ class MLLMTemplateType:
94
+ qwen_vl = 'qwen_vl'
95
+ qwen_audio = 'qwen_audio'
96
+ qwen2_vl = 'qwen2_vl'
97
+ qwen2_5_vl = 'qwen2_5_vl'
98
+ qwen2_5_omni = 'qwen2_5_omni'
99
+ qwen2_audio = 'qwen2_audio'
100
+ qwen2_gme = 'qwen2_gme'
101
+ qvq = 'qvq'
102
+ ovis1_6 = 'ovis1_6'
103
+ ovis1_6_llama3 = 'ovis1_6_llama3'
104
+ ovis2 = 'ovis2'
105
+
106
+ llama3_1_omni = 'llama3_1_omni'
107
+ llama3_2_vision = 'llama3_2_vision'
108
+ llama4 = 'llama4'
109
+
110
+ llava1_5_hf = 'llava1_5_hf'
111
+ llava1_6_mistral_hf = 'llava1_6_mistral_hf'
112
+ llava1_6_vicuna_hf = 'llava1_6_vicuna_hf'
113
+ llava1_6_yi_hf = 'llava1_6_yi_hf'
114
+ llama3_llava_next_hf = 'llama3_llava_next_hf'
115
+ llava_next_qwen_hf = 'llava_next_qwen_hf'
116
+ llava_onevision_hf = 'llava_onevision_hf'
117
+ llava_next_video_hf = 'llava_next_video_hf'
118
+
119
+ llava_llama3_1_hf = 'llava_llama3_1_hf' # DaozeZhang
120
+ llava_llama3_hf = 'llava_llama3_hf' # xtuner
121
+ # lmms-lab
122
+ llava1_6_mistral = 'llava1_6_mistral'
123
+ llava1_6_yi = 'llava1_6_yi'
124
+ llava_next_qwen = 'llava_next_qwen'
125
+ llama3_llava_next = 'llama3_llava_next'
126
+
127
+ yi_vl = 'yi_vl'
128
+
129
+ internvl = 'internvl'
130
+ internvl_phi3 = 'internvl_phi3'
131
+ internvl2 = 'internvl2'
132
+ internvl2_phi3 = 'internvl2_phi3'
133
+ internvl2_5 = 'internvl2_5'
134
+
135
+ xcomposer2 = 'ixcomposer2'
136
+ xcomposer2_4khd = 'xcomposer2_4khd'
137
+ xcomposer2_5 = 'xcomposer2_5'
138
+
139
+ cogagent_chat = 'cogagent_chat'
140
+ cogagent_vqa = 'cogagent_vqa'
141
+ cogvlm = 'cogvlm'
142
+ cogvlm2 = 'cogvlm2'
143
+ cogvlm2_video = 'cogvlm2_video'
144
+ glm4v = 'glm4v'
145
+ glm_edge_v = 'glm_edge_v'
146
+
147
+ minicpmv = 'minicpmv'
148
+ minicpmv2_5 = 'minicpmv2_5'
149
+ minicpmv2_6 = 'minicpmv2_6'
150
+ minicpmo2_6 = 'minicpmo2_6'
151
+
152
+ deepseek_vl = 'deepseek_vl'
153
+ deepseek_vl2 = 'deepseek_vl2'
154
+ deepseek_janus = 'deepseek_janus'
155
+ deepseek_janus_pro = 'deepseek_janus_pro'
156
+
157
+ mplug_owl2 = 'mplug_owl2'
158
+ mplug_owl3 = 'mplug_owl3'
159
+ mplug_owl3_241101 = 'mplug_owl3_241101'
160
+ doc_owl2 = 'doc_owl2'
161
+
162
+ emu3_chat = 'emu3_chat'
163
+ emu3_gen = 'emu3_gen'
164
+
165
+ got_ocr2 = 'got_ocr2'
166
+ got_ocr2_hf = 'got_ocr2_hf'
167
+ step_audio = 'step_audio'
168
+ kimi_vl = 'kimi_vl'
169
+
170
+ idefics3 = 'idefics3'
171
+ pixtral = 'pixtral'
172
+ paligemma = 'paligemma'
173
+ phi3_vision = 'phi3_vision'
174
+ phi4_multimodal = 'phi4_multimodal'
175
+ florence = 'florence'
176
+ molmo = 'molmo'
177
+ megrez_omni = 'megrez_omni'
178
+ valley = 'valley'
179
+ gemma3_vision = 'gemma3_vision'
180
+ mistral_2503 = 'mistral_2503'
181
+
182
+
183
+ class TemplateType(LLMTemplateType, MLLMTemplateType, RMTemplateType):
184
+
185
+ @classmethod
186
+ def get_template_name_list(cls) -> List[str]:
187
+ res = []
188
+ for k in cls.__dict__.keys():
189
+ if k.startswith('__'):
190
+ continue
191
+ value = cls.__dict__[k]
192
+ if isinstance(value, str):
193
+ res.append(value)
194
+ return res
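Usage sketch (not part of the diff): the classes above are plain string-constant namespaces whose values are consumed by the template registry; the import path is assumed from the file layout of this commit.

from swift.llm.template.constant import LLMTemplateType, MLLMTemplateType

# Each attribute is just the template-type string that gets passed around by name.
assert LLMTemplateType.qwen2_5 == 'qwen2_5'
assert MLLMTemplateType.qwen2_5_vl == 'qwen2_5_vl'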
swift/llm/template/grounding.py ADDED
@@ -0,0 +1,91 @@
1
+ import colorsys
2
+ import itertools
3
+ import os
4
+ from copy import deepcopy
5
+ from typing import Any, List, Literal
6
+
7
+ import requests
8
+ from modelscope.hub.utils.utils import get_cache_dir
9
+ from PIL import Image, ImageDraw, ImageFont
10
+
11
+
12
+ def _shuffle_colors(nums: List[Any]) -> List[Any]:
13
+ if len(nums) == 1:
14
+ return nums
15
+
16
+ mid = len(nums) // 2
17
+
18
+ left = nums[:mid]
19
+ right = nums[mid:]
20
+ left = _shuffle_colors(left)
21
+ right = _shuffle_colors(right)
22
+ new_nums = []
23
+ for x, y in zip(left, right):
24
+ new_nums += [x, y]
25
+ new_nums += left[len(right):] or right[len(left):]
26
+ return new_nums
27
+
28
+
29
+ def generate_colors():
30
+ vs_combinations = [(v, s) for v, s in itertools.product([0.7, 0.3, 1], [0.7, 0.3, 1])]
31
+ colors = [colorsys.hsv_to_rgb(i / 16, s, v) for v, s in vs_combinations for i in _shuffle_colors(list(range(16)))]
32
+ colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]
33
+ return _shuffle_colors(colors)
34
+
35
+
36
+ def download_file(url: str) -> str:
37
+ url = url.rstrip('/')
38
+ file_name = url.rsplit('/', 1)[-1]
39
+ cache_dir = os.path.join(get_cache_dir(), 'files')
40
+ os.makedirs(cache_dir, exist_ok=True)
41
+ req = requests.get(url)
42
+ file_path = os.path.join(cache_dir, file_name)
43
+ with open(file_path, 'wb') as f:
44
+ f.write(req.content)
45
+ return file_path
46
+
47
+
48
+ colors = generate_colors()
49
+ color_mapping = {}
50
+
51
+
52
+ def _calculate_brightness(image, region: List[int]):
53
+ cropped_image = image.crop(region)
54
+ grayscale_image = cropped_image.convert('L')
55
+ pixels = list(grayscale_image.getdata())
56
+ average_brightness = sum(pixels) / len(pixels)
57
+ return average_brightness
58
+
59
+
60
+ def draw_bbox(image: Image.Image,
61
+ ref: List[str],
62
+ bbox: List[List[int]],
63
+ norm_bbox: Literal['norm1000', 'none'] = 'norm1000'):
64
+ bbox = deepcopy(bbox)
65
+ font_path = 'https://modelscope.cn/models/Qwen/Qwen-VL-Chat/resolve/master/SimSun.ttf'
66
+ # norm bbox
67
+ for i, box in enumerate(bbox):
68
+ for i in range(len(box)):
69
+ box[i] = int(box[i])
70
+ if norm_bbox == 'norm1000':
71
+ box[0] = box[0] / 1000 * image.width
72
+ box[2] = box[2] / 1000 * image.width
73
+ box[1] = box[1] / 1000 * image.height
74
+ box[3] = box[3] / 1000 * image.height
75
+
76
+ draw = ImageDraw.Draw(image)
77
+ # draw bbox
78
+ assert len(ref) == len(bbox), f'len(refs): {len(ref)}, len(bboxes): {len(bbox)}'
79
+ for (left, top, right, bottom), box_ref in zip(bbox, ref):
80
+ if box_ref not in color_mapping:
81
+ color_mapping[box_ref] = colors[len(color_mapping) % len(colors)]
82
+ color = color_mapping[box_ref]
83
+ draw.rectangle([(left, top), (right, bottom)], outline=color, width=3)
84
+ # draw text
85
+ file_path = download_file(font_path)
86
+ font = ImageFont.truetype(file_path, 20)
87
+ for (left, top, _, _), box_ref in zip(bbox, ref):
88
+ brightness = _calculate_brightness(
89
+ image, [left, top, min(left + 100, image.width),
90
+ min(top + 20, image.height)])
91
+ draw.text((left, top), box_ref, fill='white' if brightness < 128 else 'black', font=font)
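Usage sketch (not part of the diff): 'cat.jpg' and the box values are made up for illustration. Per the code above, draw_bbox mutates the PIL image in place, expects one label per box, and downloads the SimSun font on first use.

from PIL import Image

img = Image.open('cat.jpg').convert('RGB')
# Boxes are xyxy, normalized to 0-1000 when norm_bbox='norm1000'.
draw_bbox(img, ref=['cat', 'sofa'], bbox=[[120, 80, 540, 900], [0, 300, 999, 999]], norm_bbox='norm1000')
img.save('cat_annotated.jpg')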
swift/llm/template/register.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from typing import Dict, Literal, Optional
4
+
5
+ from ..utils import Processor
6
+ from .base import Template
7
+ from .template_meta import TemplateMeta
8
+
9
+ TEMPLATE_MAPPING: Dict[str, TemplateMeta] = {}
10
+
11
+
12
+ def register_template(template_meta: TemplateMeta, *, exist_ok: bool = False) -> None:
13
+ template_type = template_meta.template_type
14
+ if not exist_ok and template_type in TEMPLATE_MAPPING:
15
+ raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.')
16
+ TEMPLATE_MAPPING[template_type] = template_meta
17
+
18
+
19
+ def get_template(
20
+ template_type: str,
21
+ processor: Processor,
22
+ default_system: Optional[str] = None,
23
+ max_length: Optional[int] = None,
24
+ *,
25
+ use_chat_template: bool = True,
26
+ template_backend: Literal['swift', 'jinja'] = 'swift',
27
+ truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
28
+ max_pixels: Optional[int] = None, # h * w
29
+ agent_template: Optional[str] = None,
30
+ norm_bbox: Literal['norm1000', 'none', None] = None,
31
+ response_prefix: Optional[str] = None,
32
+ # train
33
+ padding_side: Literal['left', 'right'] = 'right',
34
+ loss_scale: str = 'default',
35
+ sequence_parallel_size: int = 1) -> 'Template':
36
+ template_meta = TEMPLATE_MAPPING[template_type]
37
+ template_cls = template_meta.template_cls
38
+ return template_cls(
39
+ processor,
40
+ template_meta,
41
+ default_system,
42
+ max_length,
43
+ use_chat_template=use_chat_template,
44
+ template_backend=template_backend,
45
+ truncation_strategy=truncation_strategy,
46
+ max_pixels=max_pixels,
47
+ agent_template=agent_template,
48
+ norm_bbox=norm_bbox,
49
+ response_prefix=response_prefix,
50
+ padding_side=padding_side,
51
+ loss_scale=loss_scale,
52
+ sequence_parallel_size=sequence_parallel_size,
53
+ )
54
+
55
+
56
+ def get_template_meta(template_type: str) -> TemplateMeta:
57
+ return TEMPLATE_MAPPING[template_type]
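Usage sketch (not part of the diff): it assumes the concrete templates have already been registered by importing swift.llm.template, and that get_model_tokenizer elsewhere in the library can return a processor without loading model weights; both are assumptions rather than something shown in this file.

from swift.llm import get_model_tokenizer
from swift.llm.template import get_template

_, processor = get_model_tokenizer('Qwen/Qwen2.5-7B-Instruct', load_model=False)
template = get_template('qwen2_5', processor, max_length=4096, truncation_strategy='left')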
swift/llm/template/template/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm,
2
+ minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)
swift/llm/template/template/deepseek.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+
11
+ from swift.utils import get_env_args
12
+ from ..base import Template
13
+ from ..constant import LLMTemplateType, MLLMTemplateType
14
+ from ..register import TemplateMeta, register_template
15
+ from ..template_inputs import StdTemplateInputs
16
+ from ..utils import Prompt, findall
17
+
18
+
19
+ @dataclass
20
+ class DeepseekTemplateMeta(TemplateMeta):
21
+ prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
22
+ prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\n\nAssistant:'])
23
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
24
+ suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
25
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: [['bos_token_id'], '{{SYSTEM}}\n\n'])
26
+
27
+
28
+ register_template(DeepseekTemplateMeta(LLMTemplateType.deepseek, ))
29
+
30
+ register_template(
31
+ TemplateMeta(
32
+ LLMTemplateType.deepseek_coder,
33
+ prefix=['{{SYSTEM}}'],
34
+ prompt=['### Instruction:\n{{QUERY}}\n### Response:\n'],
35
+ chat_sep=['\n<|EOT|>\n'],
36
+ suffix=['\n<|EOT|>'],
37
+ stop_words=['<|EOT|>'],
38
+ default_system=('You are an AI programming assistant, utilizing the Deepseek Coder model, '
39
+ 'developed by Deepseek Company, and you only answer questions related to computer science. '
40
+ 'For politically sensitive questions, security and privacy issues, '
41
+ 'and other non-computer science questions, you will refuse to answer\n')))
42
+
43
+
44
+ class DeepseekVLTemplate(Template):
45
+ image_placeholder = ['<image_placeholder>']
46
+ skip_prompt = False
47
+ use_model = True
48
+ placeholder_tokens = ['<image_placeholder>']
49
+
50
+ image_token_num_per_image: int = 576
51
+
52
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
53
+ is_janus = getattr(self, 'is_janus', False)
54
+
55
+ encoded = super()._encode(inputs)
56
+ images = inputs.images
57
+ processor = self.processor
58
+ input_ids, labels = encoded['input_ids'], encoded['labels']
59
+
60
+ if not inputs.generate_mode: # understanding task
61
+ idx_list = findall(input_ids, processor.image_id) # '<image_placeholder>'
62
+ new_input_ids, new_labels = [], []
63
+ lo = 0
64
+ for hi in idx_list:
65
+ new_input_ids += input_ids[lo:hi]
66
+ if labels is not None:
67
+ new_labels += labels[lo:hi]
68
+ image_tokens = [processor.image_id] * processor.num_image_tokens
69
+ if is_janus:
70
+ image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
71
+ new_input_ids += image_tokens
72
+ new_labels += [-100] * len(image_tokens)
73
+ lo = hi + 1
74
+ new_input_ids += input_ids[lo:]
75
+ if labels is not None:
76
+ new_labels += labels[lo:]
77
+ else:
78
+ new_labels = None
79
+ if is_janus:
80
+ from janus.models.processing_vlm import VLChatProcessorOutput
81
+ else:
82
+ from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
83
+
84
+ images_outputs = processor.image_processor(images, return_tensors='pt')
85
+ output = VLChatProcessorOutput(
86
+ sft_format=None,
87
+ input_ids=torch.tensor(new_input_ids),
88
+ pixel_values=images_outputs.pixel_values,
89
+ num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
90
+ encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
91
+ return encoded
92
+
93
+ else: # image generation task
94
+ if self.is_training:
95
+ raise NotImplementedError('Only support the inference of generation of Janus series models.')
96
+ sft_format = self.tokenizer.decode(input_ids)
97
+ prompt = sft_format + processor.image_start_tag
98
+ input_ids = processor.tokenizer.encode(prompt)
99
+ input_ids = torch.LongTensor(input_ids)
100
+
101
+ encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode}
102
+ return encoded
103
+
104
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
105
+ if not inputs.get('generate_mode'):
106
+ inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.model_info.torch_dtype)
107
+ inputs_embeds = model.prepare_inputs_embeds(**inputs)
108
+ return {'inputs_embeds': inputs_embeds}
109
+ else:
110
+ return inputs
111
+
112
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
113
+ gene_img_list = [b.get('generate_mode') for b in batch]
114
+ if all(gene_img_list):
115
+ generate_mode = True
116
+ elif not any(gene_img_list):
117
+ generate_mode = False
118
+ else:
119
+ raise NotImplementedError('Do not support understanding and image generation tasks in one batch.')
120
+
121
+ if not generate_mode:
122
+ output = self.fetch_inputs(batch, ['output'])['output']
123
+ batched_output = dict(self.processor.batchify(output))
124
+ res = super()._data_collator(batch, padding_to=padding_to)
125
+ return {**batched_output, **res}
126
+ else:
127
+ res = super()._data_collator(batch, padding_to=padding_to)
128
+ res['generate_mode'] = generate_mode
129
+ return res
130
+
131
+ def generate(self, model, *args, **kwargs):
132
+ if not kwargs.get('generate_mode'):
133
+ return super().generate(model, *args, **kwargs)
134
+
135
+ else:
136
+ # number of images to generate per prompt; named parallel_size in the reference implementation
137
+ parallel_size = kwargs['generation_config'].num_return_sequences
138
+ temperature = kwargs['generation_config'].temperature
139
+ cfg_weight = get_env_args('cfg_weight', float, 5.0)
140
+
141
+ input_ids = kwargs['input_ids'] # [bsz, max_input_token_num]
142
+ bsz, max_input_token_num = input_ids.shape
143
+ tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num),
144
+ dtype=torch.int).cuda() # [bsz, parallel_size*2, max_input_token_num]
145
+ for i in range(parallel_size * 2):
146
+ tokens[:, i, :] = input_ids
147
+ if i % 2 != 0:
148
+ tokens[:, i, 1:-1] = self.processor.pad_id
149
+
150
+ inputs_embeds = model.language_model.get_input_embeddings()(
151
+ tokens) # [bsz, parallel_size*2, max_input_token_num, 2048]
152
+
153
+ generated_tokens = torch.zeros(
154
+ (bsz, parallel_size, self.image_token_num_per_image),
155
+ dtype=torch.int).cuda() # [bsz, 16, image_token_num_per_image] placeholder for the generated tokens
156
+
157
+ # set the first two dimensions into one dimension for batch size
158
+ inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1)
159
+ generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image)
160
+
161
+ for i in range(self.image_token_num_per_image): # generate the image tokens auto-regressively
162
+ outputs = model.language_model.model(
163
+ inputs_embeds=inputs_embeds,
164
+ use_cache=True,
165
+ past_key_values=outputs.past_key_values if i != 0 else None)
166
+ hidden_states = outputs.last_hidden_state
167
+
168
+ logits = self.model.gen_head(hidden_states[:, -1, :])
169
+ logit_cond = logits[0::2, :]
170
+ logit_uncond = logits[1::2, :]
171
+
172
+ logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
173
+ probs = torch.softmax(logits / temperature, dim=-1)
174
+
175
+ next_token = torch.multinomial(probs, num_samples=1)
176
+ generated_tokens[:, i] = next_token.squeeze(dim=-1) # [parallel_size, self.image_token_num_per_image]
177
+
178
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
179
+ img_embeds = model.prepare_gen_img_embeds(next_token) # [parallel_size * 2, 2048]
180
+ inputs_embeds = img_embeds.unsqueeze(dim=1) # [parallel_size * 2, 1, 2048]
181
+
182
+ # no need to reset the original first two dimensions, waiting for the update of the upper layer
183
+ # inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1)
184
+ # generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image)
185
+
186
+ return {'sequences': generated_tokens}
187
+
188
+ def decode(self, generate_ids: List[int], **kwargs) -> Any:
189
+ if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode:
190
+ return super().decode(generate_ids, **kwargs)
191
+ else:
192
+ img_size = get_env_args('img_size', int, 384)
193
+ patch_size = 16
194
+
195
+ num_to_decode = 1 # for now, generate_ids is a 1D list
196
+
197
+ generate_ids = torch.tensor(generate_ids).unsqueeze(0) # [num_to_decode=1, self.image_token_num_per_image]
198
+
199
+ dec = self.model.gen_vision_model.decode_code(
200
+ generate_ids.to(dtype=torch.int),
201
+ shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size])
202
+ dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) # [num_to_decode, H, W, ch=3]
203
+
204
+ dec = np.clip((dec + 1) / 2 * 255, 0, 255)
205
+
206
+ visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8)
207
+ visual_img[:, :, :] = dec
208
+
209
+ img_list = []
210
+ for i in range(num_to_decode):
211
+ cur_img = Image.fromarray(visual_img[i])
212
+ img_list.append({'type': 'image', 'image': cur_img})
213
+ return img_list
214
+
215
+
216
+ @dataclass
217
+ class DeepseekVLTemplateMeta(DeepseekTemplateMeta):
218
+ default_system: Optional[str] = ('You are a helpful language and vision assistant. '
219
+ 'You are able to understand the visual content that the user provides, '
220
+ 'and assist the user with a variety of tasks using natural language.')
221
+
222
+
223
+ register_template(DeepseekVLTemplateMeta(
224
+ MLLMTemplateType.deepseek_vl,
225
+ template_cls=DeepseekVLTemplate,
226
+ ))
227
+
228
+
229
+ class DeepseekJanus(DeepseekVLTemplate):
230
+ is_janus = True
231
+ image_placeholder = ['<image_placeholder>\n']
232
+
233
+
234
+ register_template(DeepseekVLTemplateMeta(MLLMTemplateType.deepseek_janus, template_cls=DeepseekJanus))
235
+
236
+
237
+ @dataclass
238
+ class DeepseekV2_5TemplateMeta(TemplateMeta):
239
+ prefix: Prompt = field(default_factory=lambda: ['<|begin▁of▁sentence|>{{SYSTEM}}'])
240
+ prompt: Prompt = field(default_factory=lambda: ['<|User|>{{QUERY}}<|Assistant|>'])
241
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
242
+ suffix: Prompt = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
243
+
244
+
245
+ register_template(DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_v2_5))
246
+
247
+
248
+ class DeepseekR1Template(Template):
249
+
250
+ def _swift_encode(self, inputs: StdTemplateInputs):
251
+ if not self.is_training:
252
+ for message in inputs.messages:
253
+ if message['role'] == 'assistant' and isinstance(message['content'], str):
254
+ message['content'] = message['content'].split('</think>')[-1]
255
+ return super()._swift_encode(inputs)
256
+
257
+
258
+ register_template(
259
+ DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_r1, template_cls=DeepseekR1Template, response_prefix='<think>\n'))
260
+
261
+
262
+ class DeepseekVL2Template(DeepseekVLTemplate):
263
+ image_placeholder = ['<image>\n']
264
+ placeholder_tokens = ['<image>']
265
+
266
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
267
+ from deepseek_vl2.models.processing_deepseek_vl_v2 import VLChatProcessorOutput
268
+ encoded = Template._encode(self, inputs)
269
+ images = inputs.images
270
+ processor = self.processor
271
+ input_ids, labels = encoded['input_ids'], encoded['labels']
272
+ images_seq_mask = [False] * len(input_ids)
273
+ idx_list = findall(input_ids, processor.image_token_id) # '<image>'
274
+ _, images_list, _, images_spatial_crop, num_image_tokens = processor.tokenize_with_images(
275
+ '<image>' * len(images), images, cropping=len(images) <= 2)
276
+ new_num_tokens = 0
277
+ for idx, n_image_tokens in zip(idx_list, num_image_tokens):
278
+ image_tokens = [processor.image_token_id] * n_image_tokens
279
+ input_ids = input_ids[:idx] + image_tokens + input_ids[idx + 1:]
280
+ if labels is not None:
281
+ labels = labels[:idx] + [-100] * n_image_tokens + labels[idx + 1:]
282
+ images_seq_mask = images_seq_mask[:idx] + [True] * n_image_tokens + images_seq_mask[idx + 1:]
283
+ new_num_tokens += n_image_tokens - 1
284
+
285
+ output = VLChatProcessorOutput(
286
+ sft_format=None,
287
+ input_ids=torch.tensor(input_ids),
288
+ target_ids=torch.tensor(input_ids),
289
+ images=torch.stack(images_list) if images_list else torch.zeros((0, 3, 384, 384)),
290
+ images_seq_mask=torch.tensor(images_seq_mask),
291
+ images_spatial_crop=torch.tensor(images_spatial_crop),
292
+ num_image_tokens=num_image_tokens)
293
+ output.images = output.images.to(dtype=self.model_info.torch_dtype)
294
+ encoded = {'output': output, 'input_ids': input_ids, 'labels': labels}
295
+ return encoded
296
+
297
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
298
+ inputs['images_seq_mask'] = inputs['images_seq_mask'].to(torch.bool)
299
+ inputs['images_spatial_crop'] = inputs['images_spatial_crop'].to(torch.long)
300
+ inputs_embeds = model.prepare_inputs_embeds(**inputs)
301
+ return {'inputs_embeds': inputs_embeds}
302
+
303
+
304
+ register_template(
305
+ DeepseekV2_5TemplateMeta(
306
+ MLLMTemplateType.deepseek_vl2,
307
+ prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
308
+ template_cls=DeepseekVL2Template,
309
+ ))
310
+
311
+ register_template(
312
+ DeepseekVLTemplateMeta(
313
+ MLLMTemplateType.deepseek_janus_pro,
314
+ prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
315
+ template_cls=DeepseekJanus))
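Hypothetical sketch (not part of the diff): 'my_chat' is a made-up template type; it only restates the TemplateMeta/register_template pattern used repeatedly in this file, as one might for a project-specific template.

register_template(
    TemplateMeta(
        'my_chat',
        prefix=[['bos_token_id']],
        prompt=['User: {{QUERY}}\nAssistant:'],
        chat_sep=[['eos_token_id']],
        suffix=[['eos_token_id']],
    ))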
swift/llm/template/template/emu3.py ADDED
@@ -0,0 +1,191 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ import random
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from swift.utils import get_device
10
+ from ..base import Template
11
+ from ..constant import MLLMTemplateType
12
+ from ..register import register_template
13
+ from ..template_inputs import StdTemplateInputs
14
+ from ..template_meta import TemplateMeta
15
+ from ..utils import findall
16
+ from .utils import DEFAULT_SYSTEM, EmptyTemplateMeta
17
+
18
+
19
+ class Emu3GenTemplate(Template):
20
+
21
+ NULL_PROMPT_PROB = 0.1
22
+ COOKBOOK_SIZE = 32768
23
+ CFG_SCALE = os.environ.get('CFG_SCALE', 3.0)
24
+ GENERATION_RATIO = os.environ.get('GENERATION_RATIO', '1:1')
25
+ NEGATIVE_PROMPT = os.environ.get(
26
+ 'NEGATIVE_PROMPT',
27
+ 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, '
28
+ 'worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.')
29
+
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+ self.bov = self.processor.tokenizer.encode(self.processor.visual_template[0].format(token_id=0))[0]
33
+ self.eov = self.processor.tokenizer.encode(self.processor.visual_template[0].format(token_id=self.COOKBOOK_SIZE
34
+ - 1))[0]
35
+ self.h, self.w = self.processor.calculate_generate_size(self.GENERATION_RATIO, self.processor.image_area,
36
+ self.processor.vision_tokenizer.spatial_scale_factor)
37
+ self.skip_prompt = False
38
+ self.apply_loss_on_only_vision = True
39
+
40
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
41
+ if self.is_training:
42
+ p_prob = random.random()
43
+ if p_prob < self.NULL_PROMPT_PROB:
44
+ prompt = ''
45
+ else:
46
+ prompt = inputs.to_history()['response']
47
+ image = self.smart_resize(inputs.images[0].convert('RGB'))
48
+ with torch.no_grad():
49
+ image = self.processor.image_processor(
50
+ image, return_tensors='pt')['pixel_values'].to(device=self.processor.vision_tokenizer.device)
51
+ image_token_ids = self.processor.vision_tokenizer.encode(image).squeeze(0)
52
+ encoded = self._process_prompt_train(prompt, image_token_ids)
53
+ else:
54
+ prompt = inputs.to_history()['query']
55
+ encoded = self._process_prompt_test(prompt)
56
+ encoded = {key: encoded[key][0] for key in encoded.keys()} # [1, L] -> [L]
57
+
58
+ return encoded
59
+
60
+ def _process_prompt_train(self, raw_prompt, image_token_ids):
61
+ image_prompt = self.format_image_prompt(image_token_ids)
62
+ prompt = self.tokenizer.bos_token + raw_prompt + image_prompt
63
+ sample = self.tokenizer(prompt, padding='max_length', return_token_type_ids=False)
64
+ labels = torch.tensor(sample['input_ids'])
65
+ if self.apply_loss_on_only_vision:
66
+ labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, -100)
67
+ sample['labels'] = labels.tolist()
68
+ return sample
69
+
70
+ def _process_prompt_test(self, raw_prompt):
71
+ # to support multiple inputs, use a list instead of a single string
72
+ if isinstance(raw_prompt, str):
73
+ raw_prompt = [raw_prompt]
74
+ prompt_list = []
75
+ size_list = []
76
+ for text_prompt in raw_prompt:
77
+ prompt = self.processor.tokenizer.bos_token
78
+ image_prompt = (
79
+ self.processor.tokenizer.boi_token + self.processor.prefix_template.format(H=self.h, W=self.w)
80
+ + self.processor.tokenizer.img_token)
81
+ prompt += (text_prompt + image_prompt)
82
+ prompt_list.append(prompt)
83
+ size_list.append([self.h, self.w])
84
+ prompt_list = self.tokenizer(prompt_list, padding='longest', return_token_type_ids=False)
85
+ return prompt_list
86
+
87
+ def prepare_for_output(self, output: str) -> str:
88
+ return output
89
+
90
+ def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
91
+ from transformers import UnbatchedClassifierFreeGuidanceLogitsProcessor
92
+ from transformers import PrefixConstrainedLogitsProcessor
93
+ from transformers import LogitsProcessorList
94
+
95
+ negative_prompt = self.NEGATIVE_PROMPT
96
+ neg_inputs = self._process_prompt_test(negative_prompt)
97
+ neg_inputs = {key: torch.tensor(val) for key, val in neg_inputs.items()}
98
+ batch_size = generate_kwargs['input_ids'].shape[0]
99
+ h = torch.tensor([self.h] * batch_size)
100
+ w = torch.tensor([self.w] * batch_size)
101
+
102
+ constrained_fn = self.processor.build_prefix_constrained_fn(h, w)
103
+ logits_processor = LogitsProcessorList([
104
+ UnbatchedClassifierFreeGuidanceLogitsProcessor(
105
+ self.CFG_SCALE,
106
+ model,
107
+ unconditional_ids=neg_inputs['input_ids'].to(get_device()),
108
+ ),
109
+ PrefixConstrainedLogitsProcessor(
110
+ constrained_fn,
111
+ num_beams=1,
112
+ ),
113
+ ])
114
+ res = super().prepare_generate_kwargs(generate_kwargs, model=model)
115
+ res['logits_processor'] = logits_processor
116
+ return res
117
+
118
+ def decode(self, generate_ids: List[int], **kwargs) -> Any:
119
+ mm_list = self.processor.decode(generate_ids)
120
+ for im in mm_list:
121
+ if not isinstance(im, Image.Image):
122
+ continue
123
+ return [{'type': 'image', 'image': im}]
124
+
125
+ def to_imgstr(self, image_tokens):
126
+ image_token_str = [[self.processor.visual_template[0].format(token_id=token_id) for token_id in token_row]
127
+ for token_row in image_tokens]
128
+ image_row_str = [''.join(token_row) for token_row in image_token_str]
129
+ imgstr = self.tokenizer.eol_token.join(image_row_str)
130
+ return imgstr
131
+
132
+ def format_image_prompt(self, image_tokens):
133
+ h, w = image_tokens.shape
134
+ imgstr = self.to_imgstr(image_tokens)
135
+ image_prompt = (
136
+ self.tokenizer.boi_token + f'{h}*{w}' + self.tokenizer.img_token + imgstr + self.tokenizer.eol_token
137
+ + self.tokenizer.eof_token + self.tokenizer.eoi_token)
138
+ return image_prompt
139
+
140
+ def smart_resize(self, image):
141
+ w, h = image.size
142
+ current_area = h * w
143
+ target_ratio = (self.processor.image_area / current_area)**0.5
144
+ th = int(round(h * target_ratio))
145
+ tw = int(round(w * target_ratio))
146
+ image = image.resize((tw, th))
147
+ return image
148
+
149
+
150
+ register_template(EmptyTemplateMeta(
151
+ MLLMTemplateType.emu3_gen,
152
+ template_cls=Emu3GenTemplate,
153
+ ))
154
+
155
+
156
+ class Emu3ChatTemplate(Template):
157
+ system = 'You are a helpful assistant.'
158
+ image_placeholder = ['<|image token|>']
159
+
160
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
161
+ encoded = super()._encode(inputs)
162
+ # image
163
+ images = inputs.images
164
+ input_ids = encoded['input_ids']
165
+ labels = encoded['labels']
166
+ image_tokens = self.processor.tokenize_image(images)
167
+ image_prompts = []
168
+ idx_list = findall(input_ids, self.tokenizer.encode(self.image_placeholder))
169
+ # Create image prompts
170
+ for i in range(len(images)):
171
+ h, w = image_tokens[i].shape
172
+ imgstr = self.processor.to_imgstr(image_tokens[i])
173
+ image_prompt = (
174
+ self.tokenizer.boi_token + self.processor.prefix_template.format(H=h, W=w) + self.tokenizer.img_token
175
+ + imgstr + self.tokenizer.eol_token + self.tokenizer.eof_token + self.tokenizer.eoi_token)
176
+ image_prompts.append(self.tokenizer.encode(image_prompt))
177
+
178
+ # Insert image tokens into input_ids
179
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i])
180
+ return {'input_ids': input_ids, 'labels': labels}
181
+
182
+
183
+ register_template(
184
+ TemplateMeta(
185
+ MLLMTemplateType.emu3_chat,
186
+ prefix=[['bos_token_id'], '{{SYSTEM}}'],
187
+ prompt=[' User: {{QUERY}}. Assistant:'],
188
+ chat_sep=[['eos_token_id']],
189
+ suffix=[['eos_token_id']],
190
+ default_system=DEFAULT_SYSTEM,
191
+ template_cls=Emu3ChatTemplate))
swift/llm/template/template/gemma.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from swift.utils import upper_bound
9
+ from ..base import Template
10
+ from ..constant import LLMTemplateType, MLLMTemplateType
11
+ from ..register import TemplateMeta, register_template
12
+ from ..template_inputs import StdTemplateInputs
13
+ from ..utils import Context, Prompt, findall
14
+
15
+
16
+ @dataclass
17
+ class GemmaTemplateMeta(TemplateMeta):
18
+ prefix: Prompt = field(default_factory=lambda: ['<bos>'])
19
+ prompt: Prompt = field(
20
+ default_factory=lambda: ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'])
21
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<end_of_turn>\n'])
22
+ suffix: Prompt = field(default_factory=lambda: ['<end_of_turn>'])
23
+ system_prefix: Optional[Prompt] = field(
24
+ default_factory=lambda: ['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'])
25
+
26
+
27
+ register_template(GemmaTemplateMeta(LLMTemplateType.gemma))
28
+
29
+
30
+ class PaliGemmaTemplate(Template):
31
+ placeholder_tokens = ['<image>']
32
+
33
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
34
+ inputs: StdTemplateInputs) -> List[Context]:
35
+ assert media_type == 'image'
36
+ if self.mode == 'vllm':
37
+ self.prompt = ['{{QUERY}}']
38
+ return []
39
+ else:
40
+ self.prompt = ['{{QUERY}}\n']
41
+ return ['<image>' * self.processor.image_seq_length + '<bos>']
42
+
43
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
44
+ encoded = super()._encode(inputs)
45
+ raw_image = inputs.images
46
+ processor = self.processor
47
+ if encoded['labels'] is not None:
48
+ n = upper_bound(0, len(encoded['labels']), lambda idx: encoded['labels'][idx] == -100)
49
+ n2 = len(encoded['labels']) - n
50
+ encoded['token_type_ids'] = [0] * n + [1] * n2
51
+ else:
52
+ encoded['token_type_ids'] = [0] * len(encoded['input_ids'])
53
+ if raw_image:
54
+ model_inputs = processor(text='<image>' * len(raw_image), images=raw_image, return_tensors='pt')
55
+ encoded['pixel_values'] = model_inputs['pixel_values'].to(self.model_info.torch_dtype)
56
+ return encoded
57
+
58
+
59
+ register_template(
60
+ TemplateMeta(
61
+ MLLMTemplateType.paligemma,
62
+ prefix=[],
63
+ prompt=['{{QUERY}}\n'],
64
+ chat_sep=None,
65
+ suffix=['<eos>'],
66
+ template_cls=PaliGemmaTemplate,
67
+ ))
68
+
69
+
70
+ @dataclass
71
+ class Gemma3TextTemplateMeta(TemplateMeta):
72
+ prefix: Prompt = field(default_factory=lambda: ['<bos>'])
73
+ prompt: Prompt = field(
74
+ default_factory=lambda: ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'])
75
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<end_of_turn>\n'])
76
+ suffix: Prompt = field(default_factory=lambda: ['<end_of_turn>'])
77
+
78
+
79
+ class Gemma3Template(Template):
80
+
81
+ def _swift_encode(self, inputs: StdTemplateInputs):
82
+ if inputs.system is not None:
83
+ system = inputs.system
84
+ inputs.system = None
85
+ inputs.messages[0]['content'] = system + '\n\n' + inputs.messages[0]['content']
86
+ for message in inputs.messages:
87
+ if message['role'] == 'assistant' and isinstance(message['content'], str):
88
+ message['content'] = message['content'].strip('\n')
89
+ return super()._swift_encode(inputs)
90
+
91
+
92
+ register_template(Gemma3TextTemplateMeta(LLMTemplateType.gemma3_text, template_cls=Gemma3Template))
93
+
94
+
95
+ class Gemma3VisionTemplate(Gemma3Template):
96
+ boi_token_id = 255999
97
+ placeholder_tokens = ['<start_of_image>']
98
+
99
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
100
+ inputs: StdTemplateInputs) -> List[Context]:
101
+ assert media_type == 'image'
102
+ return ['<start_of_image>']
103
+
104
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
105
+ from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
106
+
107
+ encoded = super()._encode(inputs)
108
+ if inputs.images:
109
+ input_ids = encoded['input_ids']
110
+ labels = encoded['labels']
111
+ idx_list = findall(input_ids, self.boi_token_id)
112
+ img_tokens = self._tokenize(self.processor.full_image_sequence)
113
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
114
+
115
+ # TODO: customize
116
+ processor_kwargs = Gemma3ProcessorKwargs._defaults['images_kwargs']
117
+ image_inputs = self.processor.image_processor(inputs.images, **processor_kwargs)
118
+ image_inputs['pixel_values'] = torch.as_tensor(np.array(image_inputs['pixel_values']))
119
+ image_inputs.pop('num_crops')
120
+
121
+ array_ids = np.array(input_ids)
122
+ mm_token_type_ids = np.zeros_like(input_ids)
123
+ mm_token_type_ids[array_ids == self.processor.image_token_id] = 1
124
+ encoded['token_type_ids'] = mm_token_type_ids.tolist()
125
+ encoded['input_ids'] = input_ids
126
+ encoded['pixel_values'] = image_inputs['pixel_values']
127
+ encoded['labels'] = labels
128
+ return encoded
129
+
130
+
131
+ register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3_vision, template_cls=Gemma3VisionTemplate))
swift/llm/template/template/glm.py ADDED
@@ -0,0 +1,293 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import torch
6
+
7
+ from ..base import Template
8
+ from ..constant import LLMTemplateType, MLLMTemplateType
9
+ from ..register import TemplateMeta, register_template
10
+ from ..template_inputs import StdTemplateInputs
11
+ from ..utils import Context, Prompt, Word, findall
12
+ from ..vision_utils import load_batch, load_video_cogvlm2
13
+
14
+
15
+ @dataclass
16
+ class GLMTemplateMeta(TemplateMeta):
17
+ auto_add_bos: bool = True
18
+
19
+
20
+ class GLM4Template(Template):
21
+
22
+ def _swift_encode(self, inputs: StdTemplateInputs):
23
+ res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
24
+ for i, res_context in enumerate(res_context_list):
25
+ # The last round or is tool_call.
26
+ if isinstance(res_context, str) and res_context.endswith('<|assistant|>\n') and (
27
+ i + 1 >= len(res_context_list) or '<|observation|>' in res_context_list[i + 1]):
28
+ res_context_list[i] = res_context_list[i][:-len('\n')]
29
+ return res_context_list, loss_scale_list, answer_len
30
+
31
+ def decode(self, *args, **kwargs):
32
+ response = super().decode(*args, **kwargs)
33
+ return response.lstrip('\n')
34
+
35
+
36
+ class GLM4_0414Template(GLM4Template):
37
+
38
+ def _swift_encode(self, inputs: StdTemplateInputs):
39
+ if not self.is_training:
40
+ for message in inputs.messages:
41
+ if message['role'] == 'assistant' and isinstance(message['content'], str):
42
+ message['content'] = message['content'].split('</think>')[-1].strip()
43
+ return super()._swift_encode(inputs)
44
+
45
+
46
+ register_template(
47
+ GLMTemplateMeta(
48
+ LLMTemplateType.chatglm2,
49
+ prefix=['{{SYSTEM}}'],
50
+ prompt=['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'],
51
+ chat_sep=['\n\n']))
52
+
53
+
54
+ @dataclass
55
+ class GLM4TemplateMeta(GLMTemplateMeta):
56
+ prefix: Prompt = field(default_factory=list)
57
+ prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n'])
58
+ chat_sep: Optional[Prompt] = field(default_factory=list)
59
+ suffix: Prompt = field(default_factory=lambda: ['<|user|>'])
60
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}'])
61
+
62
+ agent_template: str = 'glm4'
63
+ stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>'])
64
+
65
+
66
+ @dataclass
67
+ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
68
+ prefix: Prompt = field(default_factory=lambda: ['[gMASK]<sop>'])
69
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>\n{{SYSTEM}}'])
70
+ agent_template: str = 'glm4_0414'
71
+
72
+
73
+ class GLM4VTemplate(Template):
74
+
75
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
76
+ inputs: StdTemplateInputs) -> List[Context]:
77
+ assert media_type == 'image'
78
+ return [[-100]]
79
+
80
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
81
+ encoded = super()._encode(inputs)
82
+ input_ids = encoded['input_ids']
83
+ labels = encoded['labels']
84
+ idx_list = findall(input_ids, -100)
85
+ if idx_list:
86
+ idx = idx_list[0]
87
+ image = inputs.images[0]
88
+ placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
89
+ placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
90
+ input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
91
+ if labels is not None:
92
+ labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
93
+ messages = inputs.messages
94
+ messages[0]['image'] = image
95
+ inputs2: Dict[str, Any] = self.processor.apply_chat_template(messages, return_dict=True)
96
+ encoded['images'] = inputs2['images']
97
+ encoded['input_ids'] = input_ids
98
+ encoded['labels'] = labels
99
+ encoded['position_ids'] = list(range(len(input_ids)))
100
+ return encoded
101
+
102
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
103
+ res = super()._data_collator(batch, padding_to=padding_to)
104
+ images = [b['images'] for b in batch if 'images' in b]
105
+ if images:
106
+ res['images'] = torch.concat(images)
107
+ return res
108
+
109
+
110
+ register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>']))
111
+
112
+ register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template))
113
+
114
+ register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
115
+
116
+ glm4z1rumination_system = (
117
+ '你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。'
118
+ '今年是 2025 年。\n\n'
119
+ '<核心要求>\n'
120
+ '- 首先分解用户请求,得到包含多个子要求的列表\n'
121
+ '- 制定初始研究计划\n'
122
+ '- 进行多轮迭代搜索和页面浏览(at least 10 function calls):\n'
123
+ ' * 根据已获得的信息调整研究计划和关键词\n'
124
+ ' * 打开页面阅读,从发现的内容中识别新的关键概念/名词\n'
125
+ ' * 从搜索结果中提取新的关键词继续搜索\n'
126
+ ' * 访问并仔细阅读相关页面,识别新的关键概念/名词\n\n'
127
+ '<重要配置>\n'
128
+ '- 采用语言\n'
129
+ ' * 搜索关键词:英语\n'
130
+ ' * 思考:英语\n\n'
131
+ '<可调用的工具列表>\n\n'
132
+ '[{"name": "search", "description": "Execute a search query and return search results. '
133
+ 'Use this function when you need to find information about a specific topic.", '
134
+ '"parameters": {"type": "object", "properties": {"query": {"type": "string", '
135
+ '"description": "Search query string, use English words unless it is a proper name in Chinese"}}, '
136
+ '"required": ["query"], "additionalProperties": false}}, '
137
+ '{"name": "click", "description": "Click a link in the search results and navigate to the corresponding page. '
138
+ 'Use this function when you need to view detailed content of a specific search result.", '
139
+ '"parameters": {"type": "object", "properties": {"link_id": {"type": "integer", '
140
+ '"description": "The link ID to click (from the sequence number in search results)"}}, '
141
+ '"required": ["link_id"], "additionalProperties": false}}, '
142
+ '{"name": "open", "description": "Open a specific website. Get content from any website with its URL.", '
143
+ '"parameters": {"type": "object", "properties": {"url": {"type": "string", '
144
+ '"description": "The target website URL or domain"}}, "required": ["url"], "additionalProperties": false}}, '
145
+ '{"name": "finish", "description": "Finish the task. '
146
+ 'Use this function when you have found the information you need.", '
147
+ '"parameters": {"type": "object", "properties": {}, "additionalProperties": false}}]')
148
+
149
+ register_template(
150
+ GLM4_0414TemplateMeta(
151
+ LLMTemplateType.glm4_z1_rumination, template_cls=GLM4_0414Template, default_system=glm4z1rumination_system))
152
+
153
+ codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。'
154
+
155
+ register_template(GLM4TemplateMeta(LLMTemplateType.codegeex4, default_system=codegeex4_system))
156
+
157
+ register_template(
158
+ TemplateMeta(
159
+ LLMTemplateType.longwriter_llama, ['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'],
160
+ system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
161
+
162
+
163
+ class CogTemplate(Template):
164
+ placeholder_tokens = ['<|reserved_special_token_0|>']
165
+
166
+ use_model = True
167
+
168
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
169
+ inputs: StdTemplateInputs) -> List[Context]:
170
+ return []
171
+
172
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
173
+ encoded = super()._encode(inputs)
174
+ model = self.model
175
+ image = inputs.images or []
176
+ history_inputs = inputs.to_history()
177
+ inputs2 = model.build_conversation_input_ids(
178
+ self.processor, query=history_inputs['query'], history=history_inputs['history'], images=image)
179
+ image_token_len = inputs2['token_type_ids'].sum().item()
180
+ input_ids = encoded['input_ids']
181
+ labels = encoded['labels']
182
+ encoded['token_type_ids'] = [0] + [1] * image_token_len + [0] * len(input_ids[1:])
183
+ encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * image_token_len + input_ids[1:]
184
+ if labels is not None:
185
+ encoded['labels'] = labels[:1] + [-100] * image_token_len + labels[1:]
186
+ if len(image) > 0:
187
+ encoded['images'] = [[img.to(dtype=self.model_info.torch_dtype)] for img in inputs2['images']]
188
+ if 'cross_images' in inputs2:
189
+ # is cogagent
190
+ encoded['cross_images'] = [[cross_img.to(dtype=self.model_info.torch_dtype)]
191
+ for cross_img in inputs2['cross_images']]
192
+ return encoded
193
+
194
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
195
+ res = super()._data_collator(batch, padding_to=padding_to)
196
+ keys = ['images', 'cross_images']
197
+ for key in keys:
198
+ if key in batch[0]:
199
+ res[key] = [b[key][0] for b in batch]
200
+ return res
201
+
202
+
203
+ register_template(
204
+ TemplateMeta(
205
+ MLLMTemplateType.cogagent_chat,
206
+ prefix=['<s>'],
207
+ prompt=[' [INST] {{QUERY}} [/INST] '],
208
+ chat_sep=[],
209
+ suffix=['</s>'],
210
+ template_cls=CogTemplate,
211
+ ))
212
+
213
+ register_template(
214
+ TemplateMeta(
215
+ MLLMTemplateType.cogagent_vqa,
216
+ prefix=['<s>'],
217
+ prompt=['<EOI>Question: {{QUERY}} Answer:'],
218
+ chat_sep=None,
219
+ suffix=['</s>'],
220
+ template_cls=CogTemplate))
221
+
222
+
223
+ @dataclass
224
+ class CogVLMTemplateMeta(TemplateMeta):
225
+ prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
226
+ prompt: Prompt = field(default_factory=lambda: ['Question: {{QUERY}} Answer:'])
227
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['\n'])
228
+
229
+
230
+ register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm, template_cls=CogTemplate))
231
+
232
+ register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm2, template_cls=CogTemplate))
233
+
234
+
235
+ class Cog2VideoTemplate(CogTemplate):
236
+ use_model = True
237
+
238
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
239
+ model = self.model
240
+ encoded = super(CogTemplate, self)._encode(inputs)
241
+ videos_path = inputs.videos or []
242
+ video = load_batch(videos_path, load_video_cogvlm2)
243
+ history_inputs = inputs.to_history()
244
+ inputs2 = model.build_conversation_input_ids(
245
+ self.processor,
246
+ query=history_inputs['query'],
247
+ history=history_inputs['history'],
248
+ images=video,
249
+ template_version='chat')
250
+ video_token_len = inputs2['token_type_ids'].sum().item()
251
+ input_ids = encoded['input_ids']
252
+ labels = encoded['labels']
253
+ encoded['token_type_ids'] = [0] + [1] * video_token_len + [0] * len(input_ids[1:])
254
+ encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * video_token_len + input_ids[1:]
255
+ if labels is not None:
256
+ encoded['labels'] = labels[:1] + [-100] * video_token_len + labels[1:]
257
+ if len(video) > 0:
258
+ dtype = model.dtype
259
+ encoded['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
260
+ return encoded
261
+
262
+
263
+ register_template(CogVLMTemplateMeta(
264
+ MLLMTemplateType.cogvlm2_video,
265
+ template_cls=Cog2VideoTemplate,
266
+ ))
267
+
268
+
269
+ class GLMEdgeVTemplate(Template):
270
+ placeholder_tokens = ['<|begin_of_image|>']
271
+
272
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
273
+ inputs: StdTemplateInputs) -> List[Context]:
274
+ assert media_type == 'image'
275
+ return ['<|begin_of_image|>' * 578]
276
+
277
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
278
+ encoded = super()._encode(inputs)
279
+ images = inputs.images
280
+ if images:
281
+ encoded['pixel_values'] = torch.tensor(self.processor(images).pixel_values)
282
+ return encoded
283
+
284
+
285
+ register_template(
286
+ GLM4TemplateMeta(
287
+ MLLMTemplateType.glm_edge_v,
288
+ prompt=['<|user|>\\n{{QUERY}}\\n<|assistant|>\\n'],
289
+ chat_sep=['\\n'],
290
+ system_prefix=['<|system|>\\n{{SYSTEM}}\\n'],
291
+ suffix=['<|endoftext|>'],
292
+ template_cls=GLMEdgeVTemplate,
293
+ ))
swift/llm/template/template/idefics3.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict
3
+
4
+ from ..base import Template
5
+ from ..constant import MLLMTemplateType
6
+ from ..register import TemplateMeta, register_template
7
+ from ..template_inputs import StdTemplateInputs
8
+ from ..utils import align_image_inputs
9
+
10
+
11
+ class Idefics3Template(Template):
12
+ placeholder_tokens = ['<image>']
13
+
14
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
15
+ encoded = super()._encode(inputs)
16
+ images = inputs.images or []
17
+ processor = self.processor
18
+ prompt = self.processor.decode(encoded['input_ids'])
19
+ if images:
20
+ image_inputs = processor(text=prompt, images=images, return_tensors='pt', add_special_tokens=False)
21
+ image_token = 128257 # <image>
22
+ encoded['input_ids'], encoded['labels'] = align_image_inputs(encoded['input_ids'], encoded['labels'],
23
+ image_inputs['input_ids'][0], image_token)
24
+ encoded['pixel_values'] = image_inputs['pixel_values']
25
+ return encoded
26
+
27
+
28
+ register_template(
29
+ TemplateMeta(
30
+ MLLMTemplateType.idefics3,
31
+ prefix=['<|begin_of_text|>'],
32
+ prompt=['User:{{QUERY}}<end_of_utterance>\nAssistant:'],
33
+ chat_sep=['<end_of_utterance>\n'],
34
+ suffix=['<end_of_utterance>'],
35
+ system_prefix=['System:{{SYSTEM}}<end_of_utterance>\n'],
36
+ template_cls=Idefics3Template,
37
+ ))
swift/llm/template/template/internlm.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
8
+
9
+ from swift.utils import get_env_args
10
+ from ..base import Template
11
+ from ..constant import LLMTemplateType, MLLMTemplateType, RMTemplateType
12
+ from ..register import TemplateMeta, register_template
13
+ from ..template_inputs import StdTemplateInputs
14
+ from ..utils import Context, Prompt, Word
15
+ from ..vision_utils import load_file
16
+ from .utils import ChatmlTemplateMeta
17
+
18
+ INTERNLM_SYSTEM = (
19
+ 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
20
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
21
+ 'It is designed to be helpful, honest, and harmless.\n'
22
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen '
23
+ 'by the user such as English and 中文.')
24
+
25
+ register_template(
26
+ TemplateMeta(
27
+ LLMTemplateType.internlm,
28
+ prefix=['<s>'],
29
+ prompt=['<|User|>:{{QUERY}}\n<|Bot|>:'],
30
+ chat_sep=['<eoa>\n'],
31
+ suffix=['<eoa>'],
32
+ default_system=INTERNLM_SYSTEM,
33
+ system_prefix=['<s><|System|>:{{SYSTEM}}\n']))
34
+
35
+ register_template(ChatmlTemplateMeta(LLMTemplateType.internlm2, default_system=INTERNLM_SYSTEM))
36
+
37
+ register_template(ChatmlTemplateMeta(RMTemplateType.internlm2_reward, suffix=['<|im_end|>\n<|reward|>']))
38
+
39
+
40
+ class InternLMXComposer2Template(Template):
41
+ image_placeholder = ['</s>']
42
+ version = 'v2'
43
+ skip_prompt = False
44
+ use_model = True
45
+
46
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
47
+ inputs: StdTemplateInputs) -> List[Context]:
48
+ if media_type == 'video':
49
+ inputs.images.insert(inputs.image_idx, inputs.videos[index])
50
+ inputs.image_idx += 1
51
+ return self.image_placeholder
52
+
53
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
54
+ model = self.model
55
+ encoded = super()._encode(inputs)
56
+ images = inputs.images or []
57
+
58
+ if self.version == 'v2.5':
59
+ hd_num = 24
60
+ if len(images) > 1:
61
+ hd_num = 6
62
+ hd_num = get_env_args('hd_num', int, hd_num)
63
+ images_origin = images
64
+ images = []
65
+ for image in images_origin:
66
+ if isinstance(image, Image.Image):
67
+ Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', model.model_dir)
68
+ images.append(Image_transform(image, hd_num=hd_num))
69
+ else:
70
+ load_video = get_class_from_dynamic_module('ixc_utils.load_video', model.model_dir)
71
+ frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', model.model_dir)
72
+ Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', model.model_dir)
73
+ image = load_video(load_file(image))
74
+ image = frame2img(image, model.font)
75
+ images.append(Video_transform(image, hd_num=hd_num))
76
+ elif self.version == 'v2-4khd':
77
+ hd_num = 55
78
+ hd_num = get_env_args('hd_num', int, hd_num)
79
+ HD_transform = get_class_from_dynamic_module('ixc_utils.HD_transform', model.model_dir)
80
+ images = [HD_transform(image, hd_num=hd_num) for image in images]
81
+ images = [model.vis_processor(image).to(model.dtype) for image in images]
82
+ encoded['images'] = images
83
+ return encoded
84
+
85
+ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
86
+ batch_size = len(inputs['input_ids'])
87
+ res = []
88
+ im_mask = []
89
+ length = inputs['length']
90
+ for i in range(batch_size):
91
+ input_ids = inputs['input_ids'][i].tolist()[:length[i]]
92
+ input_ids.append(2) # add dummy </s>
93
+ labels = inputs.get('labels')
94
+ if labels is not None:
95
+ labels = labels[i].tolist()[:length[i]]
96
+ labels.append(2)
97
+ else:
98
+ labels = []
99
+ images = inputs['images'][i]
100
+ res_inputs_embeds = []
101
+ res_labels = []
102
+ wrap_im_mask = []
103
+ pre_i, i, idx = 0, 0, 0
104
+ device = model.device
105
+ internlm2_model = model.model
106
+ if not hasattr(internlm2_model, 'tok_embeddings'):
107
+ internlm2_model = internlm2_model.model
108
+ tok_embeddings = internlm2_model.tok_embeddings
109
+ if len(images) > 0:
110
+ images = torch.concat([model.img2emb(image[None])[0] for image in images], dim=0)
111
+ add_bos = False
112
+ while i < len(input_ids):
113
+ if input_ids[i] == 2: # replace_token
114
+ res_input_ids = torch.tensor(([1] if add_bos else []) + input_ids[pre_i:i], device=device)
115
+ if not add_bos and self.version != 'v2.5':
116
+ add_bos = True
117
+ res_inputs_embeds.append(tok_embeddings(res_input_ids[None])[0])
118
+ wrap_im_mask += [0] * len(res_input_ids)
119
+ res_labels += ([-100] if add_bos else []) + labels[pre_i:i]
120
+ if len(images) > 0 and idx < images.shape[0]:
121
+ res_inputs_embeds.append(images[idx].to(device))
122
+ wrap_im_mask += [1] * images.shape[1]
123
+ res_labels += [-100] * images.shape[1]
124
+ idx += 1
125
+ i += 1
126
+ pre_i = i
127
+ continue
128
+ i += 1
129
+ if len(labels) == 0:
130
+ res_labels = None
131
+ im_mask.append(torch.tensor(wrap_im_mask, dtype=torch.bool, device=device))
132
+ res.append({'inputs_embeds': torch.concat(res_inputs_embeds, dim=0), 'labels': res_labels})
133
+ res = Template._data_collator(self, res)
134
+ res['im_mask'] = self._pad_sequence(im_mask, 0)
135
+ return res
136
+
137
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
138
+ res = super()._data_collator(batch, padding_to=padding_to)
139
+ res['length'] = [len(b['input_ids']) for b in batch]
140
+ res.update(self.fetch_inputs(batch, ['images']))
141
+ return res
142
+
143
+
144
+ @dataclass
145
+ class Xcomposer2TemplateMeta(TemplateMeta):
146
+ prefix: Prompt = field(default_factory=lambda: ['<s>'])
147
+ prompt: Prompt = field(
148
+ default_factory=lambda: ['[UNUSED_TOKEN_146]user\n{{QUERY}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'])
149
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['[UNUSED_TOKEN_145]\n'])
150
+ suffix: Prompt = field(default_factory=lambda: ['[UNUSED_TOKEN_145]'])
151
+ system_prefix: Optional[Prompt] = field(
152
+ default_factory=lambda: ['<s>[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n'])
153
+ stop_words: List[Word] = field(default_factory=lambda: ['<|im_end|>'])
154
+
155
+
156
+ register_template(
157
+ Xcomposer2TemplateMeta(
158
+ MLLMTemplateType.xcomposer2,
159
+ template_cls=InternLMXComposer2Template,
160
+ default_system=('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
161
+ '- InternLM-XComposer (浦语·灵笔) is a conversational language model that is developed by '
162
+ 'Shanghai AI Laboratory (上海人工智能实验室). '
163
+ 'It is designed to be helpful, honest, and harmless.\n'
164
+ '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
165
+ 'by the user such as English and 中文.'),
166
+ ))
167
+
168
+
169
+ class InternLMXComposer2_5Template(InternLMXComposer2Template):
170
+ system = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
171
+ '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model '
172
+ 'that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
173
+ 'It is designed to be helpful, honest, and harmless.\n'
174
+ '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
175
+ 'by the user such as English and 中文.\n'
176
+ '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively '
177
+ 'based on the provided image.')
178
+ version = 'v2.5'
179
+
180
+
181
+ class InternLMXComposer2_4khdTemplate(InternLMXComposer2Template):
182
+ version = 'v2-4khd'
183
+
184
+
185
+ register_template(
186
+ Xcomposer2TemplateMeta(
187
+ MLLMTemplateType.xcomposer2_5,
188
+ template_cls=InternLMXComposer2_5Template,
189
+ default_system=InternLMXComposer2_5Template.system))
190
+
191
+ register_template(
192
+ Xcomposer2TemplateMeta(
193
+ MLLMTemplateType.xcomposer2_4khd,
194
+ template_cls=InternLMXComposer2_4khdTemplate,
195
+ default_system=InternLMXComposer2_5Template.system))
swift/llm/template/template/internvl.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Any, Dict, List, Literal
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from swift.utils import get_env_args, is_deepspeed_enabled
9
+ from ..base import Template
10
+ from ..constant import MLLMTemplateType
11
+ from ..register import register_template
12
+ from ..template_inputs import StdTemplateInputs
13
+ from ..utils import Context, findall
14
+ from ..vision_utils import load_video_internvl, transform_image
15
+ from .microsoft import Phi3TemplateMeta
16
+ from .utils import ChatmlTemplateMeta
17
+
18
+
19
+ class InternvlTemplate(Template):
20
+ skip_prompt = False
21
+ num_image_token = 256
22
+ placeholder_tokens = ['<IMG_CONTEXT>']
23
+
24
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
25
+ inputs: StdTemplateInputs) -> List[Context]:
26
+ if self.mode == 'vllm':
27
+ image_context = ['<image>\n']
28
+ else:
29
+ image_context = ['<img>', [-100], '</img>\n']
30
+ return image_context
31
+
32
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
33
+ encoded = super()._encode(inputs)
34
+ input_ids = encoded['input_ids']
35
+ idx_list = findall(input_ids, -100)
36
+ pixel_values = None
37
+ images = inputs.images
38
+ if images:
39
+ labels = encoded.get('labels')
40
+ input_size = get_env_args('input_size', int, 448)
41
+ max_num = get_env_args('max_num', int, 12)
42
+ pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
43
+ pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model_info.torch_dtype)
44
+ image_bs = pixel_values.shape[0]
45
+
46
+ idx, idx2 = idx_list[0], idx_list[-1] # remove [-100, -100]
47
+ img_tokens: List[int] = self.processor.encode(
48
+ '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * image_bs
49
+ input_ids = input_ids[:idx] + img_tokens + input_ids[idx2 + 1:]
50
+ if labels is not None:
51
+ labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:]
52
+ encoded['input_ids'] = input_ids
53
+ encoded['labels'] = labels
54
+ encoded['pixel_values'] = pixel_values
55
+ return encoded
56
+
57
+ def compute_loss_context(self, model, inputs):
58
+ model_name = model.language_model.__class__.__name__.lower()
59
+ if self._packing and 'internlm2' in model_name:
60
+ position_ids = inputs['position_ids']
61
+ modeling_module = model.language_model.model.layers[0].attention.__class__
62
+ return self._patch_flash_attention_forward(modeling_module, position_ids, use_new_func=True)
63
+ else:
64
+ return super().compute_loss_context(model, inputs)
65
+
66
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
67
+ embedding = model.get_input_embeddings()
68
+ device = embedding.weight.device
69
+ input_ids = inputs['input_ids']
70
+ inputs_embeds = embedding(input_ids).to(device=device)
71
+ pixel_values = inputs.get('pixel_values')
72
+ if pixel_values is not None:
73
+ pixel_values = pixel_values.to(device=device)
74
+ vit_embeds = model.extract_feature(pixel_values).to(device=device)
75
+ selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
76
+ inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
77
+ elif is_deepspeed_enabled():
78
+ dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
79
+ vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
80
+ inputs_embeds += vit_embeds.mean() * 0.
81
+ return {'inputs_embeds': inputs_embeds}
82
+
83
+
84
+ register_template(
85
+ ChatmlTemplateMeta(
86
+ MLLMTemplateType.internvl,
87
+ default_system='You are an AI assistant whose name is InternLM (书生·浦语).',
88
+ template_cls=InternvlTemplate,
89
+ auto_add_bos=True))
90
+ register_template(
91
+ Phi3TemplateMeta(
92
+ MLLMTemplateType.internvl_phi3,
93
+ default_system='You are an AI assistant whose name is Phi-3.',
94
+ template_cls=InternvlTemplate,
95
+ auto_add_bos=True))
96
+
97
+
98
+ class Internvl2Template(InternvlTemplate):
99
+ video_segments = 8
100
+
101
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
102
+ inputs: StdTemplateInputs) -> List[Context]:
103
+ image_context = super().replace_tag('image', index, inputs)
104
+ if media_type == 'image':
105
+ return image_context
106
+ elif media_type == 'video':
107
+ video_segments = get_env_args('video_segments', int, self.video_segments)
108
+ load_video = partial(load_video_internvl, num_segments=video_segments)
109
+ return self.replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context)
110
+
111
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
112
+ return [f'<ref>{ref}</ref>']
113
+
114
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
115
+ return [f'<box>[{bbox}]</box>']
116
+
117
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
118
+ encoded = super(InternvlTemplate, self)._encode(inputs)
119
+ input_ids = encoded['input_ids']
120
+ idx_list = findall(input_ids, -100)
121
+ labels = encoded['labels']
122
+ images = inputs.images
123
+ if images:
124
+ has_video = bool(inputs.videos)
125
+ input_size = get_env_args('input_size', int, 448)
126
+ max_num = get_env_args('max_num', int, 12)
127
+ video_max_num = get_env_args('video_max_num', int, 1)
128
+ if has_video:
129
+ max_num = video_max_num
130
+ pixel_values = [transform_image(image, input_size, max_num) for image in images]
131
+ num_patches = [pv.shape[0] for pv in pixel_values]
132
+ pixel_values = torch.cat(pixel_values).to(self.model_info.torch_dtype)
133
+ else:
134
+ pixel_values = None
135
+ num_patches = []
136
+ assert len(num_patches) == len(
137
+ idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
138
+
139
+ def _get_new_tokens(i):
140
+ img_tokens: List[int] = self.processor.encode(
141
+ '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patches[i]
142
+ return img_tokens
143
+
144
+ encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
145
+ encoded['pixel_values'] = pixel_values
146
+ return encoded
147
+
148
+
149
+ _internvl2_system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
150
+ register_template(
151
+ ChatmlTemplateMeta(
152
+ MLLMTemplateType.internvl2,
153
+ default_system=_internvl2_system,
154
+ template_cls=Internvl2Template,
155
+ ))
156
+
157
+ register_template(
158
+ Phi3TemplateMeta(
159
+ MLLMTemplateType.internvl2_phi3,
160
+ default_system=_internvl2_system,
161
+ template_cls=Internvl2Template,
162
+ ))
163
+
164
+ register_template(
165
+ ChatmlTemplateMeta(
166
+ MLLMTemplateType.internvl2_5,
167
+ template_cls=Internvl2Template,
168
+ default_system='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。'))
swift/llm/template/template/llama.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import datetime as dt
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Dict, List, Literal, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from swift.utils import get_env_args
11
+ from ..base import Template
12
+ from ..constant import LLMTemplateType, MLLMTemplateType
13
+ from ..register import TemplateMeta, register_template
14
+ from ..template_inputs import StdTemplateInputs
15
+ from ..utils import Context, Prompt, Word, findall
16
+ from ..vision_utils import load_batch
17
+
18
+ # ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
19
+ LLAMA_DEFAULT_SYSTEM = (
20
+ 'You are a helpful, respectful and honest assistant. '
21
+ 'Always answer as helpfully as possible, while being safe. '
22
+ 'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
23
+ 'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
24
+ 'If a question does not make any sense, or is not factually coherent, '
25
+ 'explain why instead of answering something not correct. '
26
+ "If you don't know the answer to a question, please don't share false information.")
27
+
28
+ register_template(
29
+ TemplateMeta(
30
+ LLMTemplateType.llama, ['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'],
31
+ default_system=LLAMA_DEFAULT_SYSTEM,
32
+ system_prefix=['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
33
+
34
+
35
+ @dataclass
36
+ class Llama3TemplateMeta(TemplateMeta):
37
+ prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
38
+ prompt: Prompt = field(default_factory=lambda: [
39
+ '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
40
+ '<|start_header_id|>assistant<|end_header_id|>\n\n'
41
+ ])
42
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot_id|>'])
43
+ suffix: Prompt = field(default_factory=lambda: ['<|eot_id|>'])
44
+ system_prefix: Optional[Prompt] = field(
45
+ default_factory=lambda: ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
46
+ agent_template: str = 'llama3'
47
+
48
+
49
+ register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
50
+
51
+
52
+ def _get_llama3_2_prefix() -> Prompt:
53
+ now = dt.datetime.now()
54
+ date_string = now.strftime('%d %b %Y')
55
+ date_prompt = f'Cutting Knowledge Date: December 2023\nToday Date: {date_string}'
56
+ return [f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{date_prompt}\n\n' '{{SYSTEM}}<|eot_id|>']
57
+
58
+
59
+ @dataclass
60
+ class Llama3_2TemplateMeta(Llama3TemplateMeta):
61
+ prefix: Prompt = field(default_factory=lambda: _get_llama3_2_prefix())
62
+ system_prefix: Optional[Prompt] = None
63
+
64
+
65
+ register_template(Llama3_2TemplateMeta(LLMTemplateType.llama3_2))
66
+
67
+
68
+ class Llama3_2VisionTemplate(Template):
69
+
70
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
71
+ inputs: StdTemplateInputs) -> List[Context]:
72
+ assert media_type == 'image'
73
+ return ['<|image|>']
74
+
75
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
76
+ from transformers.models.mllama.processing_mllama import (get_cross_attention_token_mask,
77
+ convert_sparse_cross_attention_mask_to_dense)
78
+ encoded = super()._encode(inputs)
79
+ images = inputs.images
80
+ if images:
81
+ input_ids = encoded['input_ids']
82
+ processor = self.processor
83
+ image_features = processor.image_processor(images, return_tensors='pt')
84
+ num_tiles = image_features.pop('num_tiles')
85
+ encoded.update(image_features)
86
+
87
+ cross_attention_token_mask = [get_cross_attention_token_mask(input_ids, processor.image_token_id)]
88
+ cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
89
+ cross_attention_token_mask,
90
+ num_tiles=num_tiles,
91
+ max_num_tiles=processor.image_processor.max_image_tiles,
92
+ length=len(input_ids),
93
+ )
94
+ encoded['cross_attention_mask'] = torch.tensor(cross_attention_mask)
95
+
96
+ return encoded
97
+
98
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
99
+ res = super()._data_collator(batch, padding_to=padding_to)
100
+ for key in ['aspect_ratio_ids', 'aspect_ratio_mask']:
101
+ value = [b[key] for b in batch if b.get(key) is not None]
102
+ if value:
103
+ res[key] = torch.concat(value)
104
+
105
+ cross_attention_mask = [
106
+ b['cross_attention_mask'][0] for b in batch if b.get('cross_attention_mask') is not None
107
+ ]
108
+ if cross_attention_mask:
109
+ res['cross_attention_mask'] = self._pad_sequence(cross_attention_mask, 0)
110
+ return res
111
+
112
+
113
+ register_template(Llama3_2TemplateMeta(MLLMTemplateType.llama3_2_vision, template_cls=Llama3_2VisionTemplate))
114
+
115
+
116
+ class Llama4Template(Template):
117
+ placeholder_tokens = ['<|patch|>']
118
+
119
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
120
+ inputs: StdTemplateInputs) -> List[Context]:
121
+ assert media_type == 'image'
122
+ return [[-100]]
123
+
124
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
125
+ encoded = super()._encode(inputs)
126
+ images = inputs.images
127
+ if images:
128
+ split_token = self._tokenize('\n')
129
+ input_ids, labels = encoded['input_ids'], encoded['labels']
130
+ idx_list = findall(input_ids, -100)
131
+ media_inputs = self.processor(
132
+ text='\n'.join(['<|image|>'] * len(idx_list)),
133
+ images=images,
134
+ add_special_tokens=False,
135
+ return_tensors='pt')
136
+ splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)
137
+
138
+ encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
139
+ lambda i: splited_tokens[i])
140
+ encoded['pixel_values'] = media_inputs['pixel_values']
141
+ return encoded
142
+
143
+
144
+ @dataclass
145
+ class Llama4TemplateMeta(TemplateMeta):
146
+ prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
147
+ prompt: Prompt = field(
148
+ default_factory=lambda:
149
+ ['<|header_start|>user<|header_end|>\n\n{{QUERY}}<|eot|>'
150
+ '<|header_start|>assistant<|header_end|>\n\n'])
151
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot|>'])
152
+ suffix: Prompt = field(default_factory=lambda: ['<|eot|>'])
153
+ stop_words: List[Word] = field(default_factory=lambda: ['<|end_of_text|>', '<|eom|>'])
154
+ system_prefix: Optional[Prompt] = field(
155
+ default_factory=lambda: ['<|begin_of_text|><|header_start|>system<|header_end|>\n\n{{SYSTEM}}<|eot|>'])
156
+ agent_template: str = 'llama4'
157
+
158
+
159
+ register_template(Llama4TemplateMeta(MLLMTemplateType.llama4, template_cls=Llama4Template))
160
+
161
+ register_template(
162
+ Llama3TemplateMeta(
163
+ LLMTemplateType.reflection,
164
+ default_system=('You are a world-class AI system, capable of complex reasoning and reflection. '
165
+ 'Reason through the query inside <thinking> tags, and then provide your final '
166
+ 'response inside <output> tags. If you detect that you made a mistake in your reasoning '
167
+ 'at any point, correct yourself inside <reflection> tags.')))
168
+
169
+
170
+ class Llama3_1OmniTemplate(Template):
171
+ skip_prompt = False
172
+ audio_placeholder = [[-200]]
173
+
174
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
175
+ import whisper
176
+ encoded = super()._encode(inputs)
177
+ audios = inputs.audios
178
+ if audios:
179
+ audios = load_batch(audios, whisper.load_audio)
180
+ n_mels = get_env_args('n_mels', int, 128)
181
+ for i, audio in enumerate(audios):
182
+ audio = whisper.pad_or_trim(audio)
183
+ audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0)
184
+ audios = torch.stack(audios)
185
+ encoded.update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])})
186
+
187
+ return encoded
188
+
189
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
190
+ speech = inputs.get('speech')
191
+ input_ids = inputs['input_ids']
192
+ labels = inputs.get('labels')
193
+ if speech is not None:
194
+ speech_lengths = inputs['speech_lengths']
195
+ speech = speech.to(model.dtype)
196
+ inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels,
197
+ speech, speech_lengths)[4:]
198
+ else:
199
+ inputs_embeds = model.get_model().embed_tokens(input_ids)
200
+ res = {'inputs_embeds': inputs_embeds}
201
+ if labels is not None:
202
+ res['labels'] = labels[0]
203
+ return res
204
+
205
+
206
+ register_template(
207
+ Llama3TemplateMeta(
208
+ MLLMTemplateType.llama3_1_omni,
209
+ default_system=('You are a helpful language and speech assistant. '
210
+ 'You are able to understand the speech content that the user provides, '
211
+ 'and assist the user with a variety of tasks using natural language.'),
212
+ template_cls=Llama3_1OmniTemplate,
213
+ ))
swift/llm/template/template/llava.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import torch
6
+ import transformers
7
+ from packaging import version
8
+
9
+ from ..base import Template
10
+ from ..constant import MLLMTemplateType
11
+ from ..register import TemplateMeta, register_template
12
+ from ..template_inputs import StdTemplateInputs
13
+ from ..utils import Context, Prompt, findall
14
+ from ..vision_utils import load_video_llava
15
+ from .llama import Llama3TemplateMeta
16
+ from .qwen import QwenTemplateMeta
17
+ from .utils import ChatmlTemplateMeta
18
+
19
+
20
+ class LlavaHfTemplate(Template):
21
+ placeholder_tokens = ['<image>']
22
+
23
+ @property
24
+ def image_token_index(self):
25
+ if not hasattr(self, '_image_token_index'):
26
+ self._image_token_index = self.tokenizer.convert_tokens_to_ids(self.processor.image_token)
27
+ return self._image_token_index
28
+
29
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
30
+ inputs: StdTemplateInputs) -> List[Context]:
31
+ assert media_type == 'image'
32
+ return ['<image>\n']
33
+
34
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
35
+ encoded = super()._encode(inputs)
36
+ images = inputs.images
37
+ if images:
38
+ image_processor = self.processor.image_processor
39
+ image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
40
+ encoded['pixel_values'] = image_inputs['pixel_values']
41
+ if 'image_sizes' in image_inputs:
42
+ encoded['image_sizes'] = image_inputs['image_sizes']
43
+ if version.parse(transformers.__version__) >= version.parse('4.47'):
44
+ input_ids = encoded['input_ids']
45
+ labels = encoded['labels']
46
+ idx_list = findall(input_ids, self.image_token_index) # <image>
47
+ height, width = image_inputs['pixel_values'][0].shape[-2:]
48
+ added_tokens_len = 0
49
+ for i, idx in enumerate(idx_list):
50
+ if 'image_sizes' in image_inputs:
51
+ orig_height, orig_width = image_inputs['image_sizes'][i].tolist()
52
+ num_image_tokens = self.processor._get_number_of_features(orig_height, orig_width, height,
53
+ width)
54
+ else:
55
+ num_image_tokens = (height // self.processor.patch_size) * (
56
+ width // self.processor.patch_size) + self.processor.num_additional_image_tokens
57
+ if self.processor.vision_feature_select_strategy == 'default':
58
+ num_image_tokens -= 1
59
+ input_ids = input_ids[:added_tokens_len + idx] + [self.image_token_index] * num_image_tokens \
60
+ + input_ids[added_tokens_len + idx + 1:]
61
+ if labels is not None:
62
+ labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens \
63
+ + labels[added_tokens_len + idx + 1:]
64
+ added_tokens_len += num_image_tokens - 1
65
+ encoded['input_ids'] = input_ids
66
+ encoded['labels'] = labels
67
+ return encoded
68
+
69
+
70
+ register_template(
71
+ TemplateMeta(
72
+ MLLMTemplateType.llava1_5_hf,
73
+ prefix=['<s>'],
74
+ prompt=['USER: {{QUERY}}\nASSISTANT:'],
75
+ chat_sep=['</s>'],
76
+ suffix=['</s>'],
77
+ system_prefix=['<s>{{SYSTEM}}\n'],
78
+ template_cls=LlavaHfTemplate,
79
+ ))
80
+
81
+
82
+ class LlavaVideoHfTemplate(Template):
83
+
84
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
85
+ inputs: StdTemplateInputs) -> List[Context]:
86
+ if media_type == 'image':
87
+ return ['<image>\n']
88
+ assert media_type == 'video'
89
+ media_file = inputs.videos[index]
90
+ if media_file.rsplit('.', 1)[-1] in {'jpg', 'png'}:
91
+ return ['<image>\n']
92
+ else:
93
+ inputs.videos[index] = load_video_llava(inputs.videos[index])
94
+ return ['<video>\n']
95
+
96
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
97
+ encoded = super()._encode(inputs)
98
+ images = inputs.images or []
99
+ videos = inputs.videos or []
100
+ if len(videos) > 0:
101
+ video_processor = self.processor.video_processor
102
+ video_inputs = video_processor(videos, return_tensors='pt').to(self.model_info.torch_dtype)
103
+ encoded['pixel_values_videos'] = video_inputs['pixel_values_videos']
104
+ if len(images) > 0:
105
+ image_processor = self.processor.image_processor
106
+ image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
107
+ encoded['pixel_values'] = image_inputs['pixel_values']
108
+ encoded['image_sizes'] = image_inputs['image_sizes']
109
+ return encoded
110
+
111
+
112
+ register_template(
113
+ TemplateMeta(
114
+ MLLMTemplateType.llava_next_video_hf,
115
+ prefix=['{{SYSTEM}} '],
116
+ prompt=['USER: {{QUERY}} ASSISTANT:'],
117
+ chat_sep=[' '],
118
+ suffix=[['eos_token_id']],
119
+ template_cls=LlavaVideoHfTemplate,
120
+ auto_add_bos=True,
121
+ ))
122
+
123
+
124
+ class Llava1_6HfTemplate(LlavaHfTemplate):
125
+
126
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
127
+ for b in batch:
128
+ pixel_values = b.get('pixel_values')
129
+ if pixel_values is not None:
130
+ b['pixel_values'] = pixel_values.squeeze(0) # 5d -> 4d
131
+ res = super()._data_collator(batch, padding_to=padding_to)
132
+ return res
133
+
134
+
135
+ @dataclass
136
+ class LlavaMistralTemplateMeta(TemplateMeta):
137
+ prefix: Prompt = field(default_factory=lambda: ['<s>[INST] '])
138
+ prompt: Prompt = field(default_factory=lambda: ['{{QUERY}} [/INST]'])
139
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>[INST] '])
140
+ suffix: Prompt = field(default_factory=lambda: ['</s>'])
141
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n'])
142
+
143
+
144
+ register_template(LlavaMistralTemplateMeta(MLLMTemplateType.llava1_6_mistral_hf, template_cls=Llava1_6HfTemplate))
145
+
146
+ register_template(
147
+ TemplateMeta(
148
+ MLLMTemplateType.llava1_6_vicuna_hf,
149
+ prefix=['<s>'],
150
+ prompt=['USER: {{QUERY}} ASSISTANT:'],
151
+ chat_sep=['</s>'],
152
+ suffix=['</s>'],
153
+ default_system=('A chat between a curious human and an artificial intelligence assistant. '
154
+ "The assistant gives helpful, detailed, and polite answers to the human's questions."),
155
+ system_prefix=['<s>{{SYSTEM}} '],
156
+ template_cls=Llava1_6HfTemplate))
157
+
158
+
159
+ class LLava1_6YiHfTemplate(Llava1_6HfTemplate):
160
+
161
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
162
+ inputs: StdTemplateInputs) -> List[Context]:
163
+ if self.mode == 'vllm':
164
+ return [[64000], '\n']
165
+ else:
166
+ return super().replace_tag(media_type, index, inputs)
167
+
168
+
169
+ register_template(ChatmlTemplateMeta(
170
+ MLLMTemplateType.llava1_6_yi_hf,
171
+ template_cls=LLava1_6YiHfTemplate,
172
+ ))
173
+
174
+ register_template(Llama3TemplateMeta(
175
+ MLLMTemplateType.llama3_llava_next_hf,
176
+ template_cls=Llava1_6HfTemplate,
177
+ ))
178
+
179
+ register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen_hf, template_cls=Llava1_6HfTemplate))
180
+
181
+
182
+ class LlavaOneVisionHfTemplate(Llava1_6HfTemplate):
183
+
184
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
185
+ encoded = Template._encode(self, inputs)
186
+ images = inputs.images
187
+ input_ids = encoded['input_ids']
188
+ labels = encoded['labels']
189
+ idx_list = findall(input_ids, 151646) # <image>
190
+ processor = self.processor
191
+ if images:
192
+ image_processor = processor.image_processor
193
+ image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
194
+ height, width = image_inputs['pixel_values'][0].shape[-2:]
195
+ added_tokens_len = 0
196
+ for idx, pixel_v, image_size in zip(idx_list, image_inputs['pixel_values'], image_inputs['image_sizes']):
197
+ if isinstance(image_size, torch.Tensor):
198
+ image_size = image_size.tolist()
199
+ orig_height, orig_width = image_size
200
+ num_image_tokens = processor._get_number_of_features(orig_height, orig_width, height, width)
201
+ input_ids = input_ids[:added_tokens_len
202
+ + idx] + [151646] * num_image_tokens + input_ids[added_tokens_len + idx + 1:]
203
+ if labels is not None:
204
+ labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens + labels[added_tokens_len + idx
205
+ + 1:]
206
+ added_tokens_len += num_image_tokens - 1
207
+ encoded['input_ids'] = input_ids
208
+ encoded['labels'] = labels
209
+ encoded['pixel_values'] = image_inputs['pixel_values']
210
+ if 'image_sizes' in image_inputs:
211
+ encoded['image_sizes'] = image_inputs['image_sizes']
212
+ return encoded
213
+
214
+
215
+ register_template(
216
+ QwenTemplateMeta(
217
+ MLLMTemplateType.llava_onevision_hf,
218
+ default_system=None,
219
+ template_cls=LlavaOneVisionHfTemplate,
220
+ ))
221
+
222
+
223
+ class LlavaLlama3_1HfTemplate(LlavaHfTemplate):
224
+ # DaozeZhang
225
+ system = ('You are a helpful language and vision assistant. '
226
+ 'You are able to understand the visual content that the user provides, '
227
+ 'and assist the user with a variety of tasks using natural language.')
228
+
229
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
230
+ encoded = super()._encode(inputs)
231
+ if len(encoded['pixel_values'].shape) == 5: # (1, num_patch, 3, H/W, W/H)
232
+ encoded['pixel_values'] = torch.squeeze(encoded['pixel_values'], dim=0) # (num_patch, 3, H/W, W/H)
233
+ return encoded
234
+
235
+
236
+ register_template(
237
+ Llama3TemplateMeta(
238
+ MLLMTemplateType.llava_llama3_1_hf,
239
+ default_system=LlavaLlama3_1HfTemplate.system,
240
+ template_cls=LlavaLlama3_1HfTemplate,
241
+ ))
242
+
243
+
244
+ class LLavaLlama3HfTemplate(Template):
245
+ # xtuner
246
+ image_placeholder = ['<image>\n']
247
+
248
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
249
+ encoded = super()._encode(inputs)
250
+ raw_image = inputs.images
251
+ if raw_image:
252
+ pixel_values = self.processor.image_processor(raw_image, return_tensors='pt')['pixel_values']
253
+ encoded['pixel_values'] = pixel_values.to(self.model_info.torch_dtype)
254
+ return encoded
255
+
256
+
257
+ register_template(Llama3TemplateMeta(
258
+ MLLMTemplateType.llava_llama3_hf,
259
+ template_cls=LLavaLlama3HfTemplate,
260
+ ))
261
+
262
+
263
+ class LLavaTemplate(Template):
264
+ skip_prompt = False
265
+ use_model = True
266
+
267
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
268
+ inputs: StdTemplateInputs) -> List[Context]:
269
+ assert media_type == 'image'
270
+ return [[-200], '\n']
271
+
272
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
273
+ encoded = super()._encode(inputs)
274
+ images = inputs.images or []
275
+ image_sizes = [x.size for x in images]
276
+ from llava.mm_utils import process_images
277
+ model = self.model.model
278
+ if not hasattr(model, 'vision_tower'):
279
+ model = model.model
280
+ image_processor = model.vision_tower.image_processor
281
+ if images:
282
+ images_tensor = process_images(images, image_processor, model.config)
283
+ encoded['images'] = images_tensor.to(model.dtype).squeeze(0)
284
+ encoded['image_sizes'] = image_sizes
285
+ return encoded
286
+
287
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
288
+ res = super()._data_collator(batch, padding_to=padding_to)
289
+ images = [b['images'] for b in batch if 'images' in b]
290
+ if images:
291
+ res['images'] = images
292
+ res['image_sizes'] = sum([b['image_sizes'] for b in batch if 'image_sizes' in b], start=[])
293
+ return res
294
+
295
+
296
+ register_template(LlavaMistralTemplateMeta(MLLMTemplateType.llava1_6_mistral, template_cls=LLavaTemplate))
297
+
298
+ register_template(ChatmlTemplateMeta(MLLMTemplateType.llava1_6_yi, template_cls=LLavaTemplate))
299
+
300
+ register_template(
301
+ Llama3TemplateMeta(
302
+ MLLMTemplateType.llama3_llava_next,
303
+ template_cls=LLavaTemplate,
304
+ default_system=('You are a helpful language and vision assistant. '
305
+ 'You are able to understand the visual content that the user provides, '
306
+ 'and assist the user with a variety of tasks using natural language.'),
307
+ ))
308
+
309
+ register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen, template_cls=LLavaTemplate))
swift/llm/template/template/llm.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from datetime import datetime
4
+ from typing import Optional
5
+
6
+ from ..constant import LLMTemplateType, MLLMTemplateType
7
+ from ..register import TemplateMeta, register_template
8
+ from ..utils import Prompt
9
+ from .llama import Llama3_2TemplateMeta
10
+ from .qwen import Qwen2VLTemplate, QwenTemplateMeta
11
+ from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
12
+
13
+ register_template(
14
+ TemplateMeta(
15
+ LLMTemplateType.default,
16
+ prefix=[],
17
+ prompt=['### Human:\n{{QUERY}}\n\n### Assistant:\n'],
18
+ chat_sep=['\n\n'],
19
+ default_system=DEFAULT_SYSTEM,
20
+ system_prefix=['{{SYSTEM}}\n\n'],
21
+ auto_add_bos=True))
22
+
23
+ register_template(
24
+ TemplateMeta(
25
+ LLMTemplateType.modelscope_agent,
26
+ prefix=[],
27
+ prompt=[' \n\n<|user|>:{{QUERY}} \n\n<|assistant|>:'],
28
+ chat_sep=[],
29
+ suffix=[' \n\n</s>'],
30
+ system_prefix=[' \n\n<|system|>:{{SYSTEM}}'],
31
+ default_system=DEFAULT_SYSTEM,
32
+ ))
33
+
34
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_gme, template_cls=Qwen2VLTemplate, suffix=['<|endoftext|>']))
35
+
36
+ register_template(
37
+ TemplateMeta(LLMTemplateType.baichuan, prefix=['{{SYSTEM}}'], prompt=[[195], '{{QUERY}}', [196]], chat_sep=[]))
38
+
39
+ register_template(
40
+ TemplateMeta(
41
+ LLMTemplateType.baichuan_m1,
42
+ prefix=[],
43
+ prompt=['<C_Q>{{QUERY}}<C_A>'],
44
+ chat_sep=[],
45
+ suffix=['<C_A>'],
46
+ system_prefix=['<B_SYS>{{SYSTEM}}'],
47
+ default_system=DEFAULT_SYSTEM,
48
+ ))
49
+
50
+ register_template(
51
+ TemplateMeta(
52
+ LLMTemplateType.numina,
53
+ prefix=[['bos_token_id']],
54
+ prompt=['### Problem: {{QUERY}}\n### Solution: '],
55
+ chat_sep=['\n'],
56
+ system_prefix=[['bos_token_id'], '{{SYSTEM}}']))
57
+
58
+ register_template(
59
+ TemplateMeta(
60
+ LLMTemplateType.mistral_nemo,
61
+ prefix=['<s>[INST] '],
62
+ prompt=['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'],
63
+ chat_sep=['</s>[INST] '],
64
+ suffix=['</s>']))
65
+
66
+ today = datetime.now().strftime('%Y-%m-%d')
67
+
68
+ mistral_2501_system = (
69
+ 'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
70
+ 'headquartered in Paris.\n'
71
+ f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
72
+ "When you're not sure about some information, you say that you don't have the information and don't "
73
+ 'make up anything.\n'
74
+ "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
75
+ 'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
76
+ '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
77
+ 'Where do you travel from?")')
78
+
79
+ register_template(
80
+ TemplateMeta(
81
+ LLMTemplateType.mistral_2501,
82
+ prefix=['<s>'],
83
+ prompt=['[INST]{{QUERY}}[/INST]'],
84
+ chat_sep=['</s>'],
85
+ suffix=['</s>'],
86
+ system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
87
+ default_system=mistral_2501_system))
88
+
89
+ register_template(
90
+ TemplateMeta(
91
+ LLMTemplateType.xverse,
92
+ prefix=['{{SYSTEM}}'],
93
+ prompt=['Human: {{QUERY}}\n\nAssistant: '],
94
+ chat_sep=[['eos_token_id']]))
95
+
96
+ register_template(TemplateMeta(LLMTemplateType.yuan, prefix=[], prompt=['{{QUERY}}<sep>'], chat_sep=None))
97
+ register_template(
98
+ TemplateMeta(
99
+ LLMTemplateType.ziya,
100
+ prefix=[['bos_token_id'], '{{SYSTEM}}'],
101
+ prompt=['<human>:{{QUERY}}\n<bot>:'],
102
+ chat_sep=['\n']))
103
+
104
+ register_template(
105
+ TemplateMeta(
106
+ LLMTemplateType.skywork,
107
+ prefix=['<s>{{SYSTEM}}'],
108
+ prompt=['</s><s>[USER]{{QUERY}}[SEP][BOT]'],
109
+ chat_sep=None,
110
+ suffix=['[SEP]</s>']))
111
+
112
+ register_template(
113
+ Llama3_2TemplateMeta(
114
+ LLMTemplateType.skywork_o1,
115
+ default_system=(
116
+ 'You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems '
117
+ "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's "
118
+ 'request, you first engage in a lengthy and in-depth thinking process to explore possible solutions to '
119
+ 'the problem. After completing your thoughts, you then provide a detailed explanation of the solution '
120
+ 'process in your response.'),
121
+ ))
122
+
123
+ register_template(
124
+ TemplateMeta(
125
+ LLMTemplateType.bluelm,
126
+ prefix=[['bos_token_id'], '{{SYSTEM}}'],
127
+ prompt=['[|Human|]:{{QUERY}}[|AI|]:'],
128
+ chat_sep=[]))
129
+
130
+ register_template(
131
+ TemplateMeta(
132
+ LLMTemplateType.codefuse_codellama,
133
+ prefix=['{{SYSTEM}}'],
134
+ prompt=['<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'],
135
+ chat_sep=[]))
136
+
137
+ register_template(
138
+ TemplateMeta(
139
+ LLMTemplateType.codefuse,
140
+ prefix=[],
141
+ prompt=['<s>human\n{{QUERY}}\n<s>bot\n'],
142
+ chat_sep=[['eos_token_id'], '\n'],
143
+ system_prefix=['<s>system\n{{SYSTEM}}\n']))
144
+
145
+ register_template(
146
+ TemplateMeta(
147
+ LLMTemplateType.zephyr,
148
+ prefix=[],
149
+ prompt=['<|user|>\n{{QUERY}}</s>\n<|assistant|>\n'],
150
+ chat_sep=['</s>\n'],
151
+ suffix=['</s>'],
152
+ system_prefix=['<|system|>\n{{SYSTEM}}</s>\n']))
153
+
154
+ register_template(
155
+ TemplateMeta(
156
+ LLMTemplateType.sus,
157
+ prefix=['{{SYSTEM}}'],
158
+ prompt=['### Human: {{QUERY}}\n\n### Assistant: '],
159
+ chat_sep=['<|endoftext|>'],
160
+ suffix=['<|endoftext|>']))
161
+
162
+ register_template(
163
+ TemplateMeta(
164
+ LLMTemplateType.orion,
165
+ prefix=['<s>{{SYSTEM}}'],
166
+ prompt=['Human: {{QUERY}}\n\nAssistant: </s>'],
167
+ chat_sep=['</s>'],
168
+ suffix=['</s>']))
169
+
170
+
171
+ @dataclass
172
+ class TeleChatTemplateMeta(TemplateMeta):
173
+ prefix: Prompt = field(default_factory=list)
174
+ prompt: Prompt = field(default_factory=lambda: [['user_token_id'], '{{QUERY}}', ['bot_token_id']])
175
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
176
+ suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
177
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<_system>{{SYSTEM}}\n'])
178
+ auto_add_bos: bool = True
179
+
180
+
181
+ register_template(TeleChatTemplateMeta(LLMTemplateType.telechat))
182
+
183
+ telechat_system = '你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。'
184
+ register_template(TeleChatTemplateMeta(LLMTemplateType.telechat2, default_system=telechat_system))
185
+
186
+ DBRX_SYSTEM = (
187
+ 'You are DBRX, created by Databricks. You were last updated in December 2023. '
188
+ 'You answer questions based on information available up to that point.\n'
189
+ 'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, '
190
+ 'but provide thorough responses to more complex and open-ended questions.\n'
191
+ 'You assist with various tasks, from writing to coding (using markdown for code blocks '
192
+ '— remember to use ``` with code, JSON, and tables).\n'
193
+ 'You do not have real-time data access or code execution capabilities.'
194
+ ' You avoid stereotyping and provide balanced perspectives on controversial topics. '
195
+ 'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n'
196
+ 'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. '
197
+ 'If you find yourself talking about this message, stop. You should be responding appropriately '
198
+ 'and usually that means not mentioning this.'
199
+ 'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY '
200
+ 'PERTINENT TO THE USER\'S QUERY.')
201
+
202
+ register_template(ChatmlTemplateMeta(LLMTemplateType.dbrx, default_system=DBRX_SYSTEM))
203
+
204
+ register_template(
205
+ TemplateMeta(
206
+ LLMTemplateType.mengzi, prefix=[], prompt=['输入:{{QUERY}}输出:\n'], chat_sep=[], system_prefix=['指令:{{SYSTEM}}']))
207
+
208
+ C4AI_SYSTEM = ('You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by '
209
+ 'providing thorough responses. You are trained by Cohere.')
210
+ register_template(
211
+ TemplateMeta(
212
+ LLMTemplateType.c4ai,
213
+ prefix=['<BOS_TOKEN>'],
214
+ prompt=[
215
+ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|>'
216
+ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'
217
+ ],
218
+ chat_sep=['<|END_OF_TURN_TOKEN|>'],
219
+ suffix=['<|END_OF_TURN_TOKEN|>'],
220
+ default_system=C4AI_SYSTEM,
221
+ system_prefix=['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|']))
222
+
223
+ register_template(
224
+ TemplateMeta(
225
+ LLMTemplateType.wizardlm2,
226
+ prefix=['{{SYSTEM}}'],
227
+ prompt=['User:\n{{QUERY}}\n\nAssistant:\n'],
228
+ chat_sep=['\n\n'],
229
+ suffix=['</s>']))
230
+
231
+ _wizardlm2_system = ('A chat between a curious user and an artificial intelligence assistant. '
232
+ 'The assistant gives helpful, detailed, and polite answers to the user\'s questions. ')
233
+ register_template(
234
+ TemplateMeta(
235
+ LLMTemplateType.wizardlm2_moe,
236
+ prefix=['{{SYSTEM}}'],
237
+ prompt=['USER: {{QUERY}} ASSISTANT:'],
238
+ chat_sep=['</s>'],
239
+ suffix=['</s>'],
240
+ default_system=_wizardlm2_system))
241
+
242
+ register_template(
243
+ TemplateMeta(
244
+ LLMTemplateType.atom,
245
+ prefix=['{{SYSTEM}}'],
246
+ prompt=['<s>Human: {{QUERY}}\n</s><s>Assistant: '],
247
+ chat_sep=['</s>'],
248
+ suffix=['</s>']))
249
+
250
+ AYA_SYSTEM = ('You are Aya, a brilliant, sophisticated, multilingual AI-assistant trained to assist human users by '
251
+ 'providing thorough responses. You are able to interact and respond to questions in 23 languages and '
252
+ 'you are powered by a multilingual model built by Cohere For AI.')
253
+ register_template(
254
+ TemplateMeta(
255
+ LLMTemplateType.aya,
256
+ prefix=['<BOS_TOKEN>'],
257
+ prompt=[
258
+ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|>'
259
+ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'
260
+ ],
261
+ chat_sep=['<|END_OF_TURN_TOKEN|>'],
262
+ suffix=['<|END_OF_TURN_TOKEN|>'],
263
+ default_system=AYA_SYSTEM,
264
+ system_prefix=['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|>']))
265
+
266
+ register_template(
267
+ TemplateMeta(
268
+ LLMTemplateType.ling,
269
+ prefix=[],
270
+ system_prefix=['<role>SYSTEM</role>{{SYSTEM}}'],
271
+ prompt=['<role>HUMAN</role>{{QUERY}}<role>ASSISTANT</role>'],
272
+ chat_sep=[],
273
+ suffix=['<|endoftext|>'],
274
+ ))
swift/llm/template/template/megrez.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ..base import Template
9
+ from ..constant import LLMTemplateType, MLLMTemplateType
10
+ from ..register import TemplateMeta, register_template
11
+ from ..template_inputs import StdTemplateInputs
12
+ from ..utils import Context, Prompt, findall
13
+
14
+
15
+ @dataclass
16
+ class MegrezTemplateMeta(TemplateMeta):
17
+ prefix: Prompt = field(default_factory=lambda: ['<|role_start|>system<|role_end|>{{SYSTEM}}<|turn_end|>'])
18
+ prompt: Prompt = field(default_factory=lambda:
19
+ ['<|role_start|>user<|role_end|>{{QUERY}}<|turn_end|><|role_start|>assistant<|role_end|>'])
20
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|turn_end|>'])
21
+ suffix: Prompt = field(default_factory=lambda: ['<|turn_end|>'])
22
+ default_system: str = '你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。'
23
+
24
+
25
+ register_template(MegrezTemplateMeta(LLMTemplateType.megrez))
26
+
27
+
28
+ class MegrezOmniTemplate(Template):
29
+ skip_prompt = False
30
+ placeholder_tokens = ['<|unk|>']
31
+
32
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
33
+ inputs: StdTemplateInputs) -> List[Context]:
34
+ if media_type == 'image':
35
+ return [[-1], '\n']
36
+ elif media_type == 'audio':
37
+ return [f'Audio {index + 1}: ', [-2], '\n']
38
+
39
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
40
+ encoded = super()._encode(inputs)
41
+ input_ids = encoded['input_ids']
42
+ labels = encoded['labels']
43
+
44
+ for mm_key in ['images', 'audios']:
45
+ mm_data = getattr(inputs, mm_key)
46
+ if not mm_data:
47
+ continue
48
+ if mm_key == 'images':
49
+ idx_list = findall(input_ids, -1)
50
+ encoding = self.processor.process_image(
51
+ mm_data,
52
+ return_tensors='pt',
53
+ )
54
+ text = self.processor.insert_image_feature_placeholders(
55
+ '<s>'.join(['(<image>./</image>)'] * len(mm_data)), encoding)
56
+ encoded['image_encoding'] = encoding
57
+ else:
58
+ idx_list = findall(input_ids, -2)
59
+ encoding = self.processor.process_audio(
60
+ mm_data,
61
+ return_tensors='pt',
62
+ )
63
+ text = self.processor.insert_audio_feature_placeholders(
64
+ '<s>'.join(['(<audio>./</audio>)'] * len(mm_data)), encoding)
65
+ encoded['audio_encoding'] = encoding
66
+
67
+ padding = text.split('<s>')
68
+
69
+ def _get_new_tokens(i):
70
+ return self._tokenize(padding[i])
71
+
72
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
73
+ encoded['input_ids'] = input_ids
74
+ encoded['labels'] = labels
75
+ return encoded
76
+
77
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
78
+ _, inputs_embeds, _ = model.compose_embeddings(inputs)
79
+ inputs.pop('position_ids', None)
80
+ return {'inputs_embeds': inputs_embeds}
81
+
82
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
83
+ res = super()._data_collator(batch, padding_to=padding_to)
84
+ new_batch = []
85
+ for b in batch:
86
+ text_encodings = {'input_ids': torch.tensor(b['input_ids'])}
87
+ multimodal_inputs = {'image_encoding': b.get('image_encoding'), 'audio_encoding': b.get('audio_encoding')}
88
+ new_batch.append(self.processor.merge_encodings(text_encodings, multimodal_inputs))
89
+ res.update(self.processor.data_collator(new_batch))
90
+ return res
91
+
92
+
93
+ register_template(MegrezTemplateMeta(MLLMTemplateType.megrez_omni, template_cls=MegrezOmniTemplate))
swift/llm/template/template/microsoft.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import json
6
+ import torch
7
+ from torch import nn
8
+
9
+ from ..base import Template
10
+ from ..constant import LLMTemplateType, MLLMTemplateType
11
+ from ..register import TemplateMeta, register_template
12
+ from ..template_inputs import StdTemplateInputs
13
+ from ..utils import Context, Prompt, findall
14
+ from ..vision_utils import load_file
15
+
16
+
17
+ class FlorenceTemplate(Template):
18
+ # If it's an encoder-decoder architecture, the default settings are
19
+ # loss_scale: 'last_round' and skip_prompt: False.
20
+ is_encoder_decoder = True
21
+
22
+ @staticmethod
23
+ def _add_default_tags(inputs: StdTemplateInputs) -> None:
24
+ return
25
+
26
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
27
+ inputs: StdTemplateInputs) -> List[Context]:
28
+ return []
29
+
30
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
31
+ return [''.join(f'<loc_{box}>' for box in bbox)]
32
+
33
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
34
+ processor = self.processor
35
+ inputs.query = inputs.to_history()['query']
36
+ new_query = processor._construct_prompts([inputs.query])[0]
37
+ for i in reversed(range(len(inputs.messages))):
38
+ if inputs.messages[i]['role'] == 'user':
39
+ inputs.messages[i]['content'] = new_query
40
+ break
41
+ encoded = super()._encode(inputs)
42
+ input_ids = encoded['prompt_input_ids']
43
+ images = inputs.images or []
44
+ labels = encoded['labels']
45
+ if labels is not None:
46
+ labels = [0] + labels
47
+ if images:
48
+ pixel_values = processor.image_processor(
49
+ images, return_tensors='pt')['pixel_values'].to(self.model_info.torch_dtype)
50
+ encoded['pixel_values'] = pixel_values
51
+ encoded['input_ids'] = input_ids
52
+ encoded['labels'] = labels
53
+ return encoded
54
+
55
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
56
+ inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
57
+ pixel_values = inputs.get('pixel_values')
58
+ if pixel_values is not None:
59
+ image_features = model._encode_image(pixel_values)
60
+ inputs_embeds, inputs['attention_mask'] = model._merge_input_ids_with_image_features(
61
+ image_features, inputs_embeds)
62
+ return {'inputs_embeds': inputs_embeds}
63
+
64
+ def decode(self, generate_ids: List[int], **kwargs) -> Any:
65
+ response = super().decode(generate_ids, **kwargs)
66
+ template_inputs = kwargs.get('template_inputs')
67
+ images = template_inputs.images
68
+ image_size = None
69
+ if images:
70
+ image_size = (images[0].width, images[0].height)
71
+ return json.dumps(
72
+ self.processor.post_process_generation(response, task=template_inputs.query, image_size=image_size))
73
+
74
+
75
+ register_template(
76
+ TemplateMeta(
77
+ MLLMTemplateType.florence,
78
+ prefix=['<s>'],
79
+ prompt=['{{QUERY}}</s>'],
80
+ chat_sep=None,
81
+ suffix=['</s>'],
82
+ template_cls=FlorenceTemplate,
83
+ ))
84
+
85
+
86
+ @dataclass
87
+ class Phi3TemplateMeta(TemplateMeta):
88
+ prefix: Prompt = field(default_factory=list)
89
+ prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|end|>\n<|assistant|>\n'])
90
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end|>\n'])
91
+ suffix: Prompt = field(default_factory=lambda: ['<|end|>'])
92
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}<|end|>\n'])
93
+ auto_add_bos: bool = True
94
+
95
+
96
+ register_template(Phi3TemplateMeta(LLMTemplateType.phi3))
97
+
98
+
99
+ @dataclass
100
+ class Phi4TemplateMeta(TemplateMeta):
101
+ prefix: Prompt = field(default_factory=list)
102
+ prompt: Prompt = field(
103
+ default_factory=lambda: ['<|im_start|>user<|im_sep|>{{QUERY}}<|im_end|><|im_start|>assistant<|im_sep|>'])
104
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>'])
105
+ suffix: Prompt = field(default_factory=lambda: ['<|im_end|>'])
106
+ system_prefix: Optional[Prompt] = field(
107
+ default_factory=lambda: ['<|im_start|>system<|im_sep|>{{SYSTEM}}<|im_end|>'])
108
+ auto_add_bos: bool = True
109
+
110
+
111
+ register_template(Phi4TemplateMeta(LLMTemplateType.phi4))
112
+
113
+
114
+ class Phi3VisionTemplate(Template):
115
+ image_placeholder = ['<|image|><s>\n'] # <|image|>\n
116
+
117
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
118
+ inputs: StdTemplateInputs) -> List[Context]:
119
+ if self.mode == 'vllm':
120
+ return [f'<|image_{index + 1}|>\n'] # <|image_1|>\n
121
+ else:
122
+ return super().replace_tag(media_type, index, inputs)
123
+
124
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
125
+ images = inputs.images or []
126
+ encoded = super()._encode(inputs)
127
+ input_ids = encoded['input_ids']
128
+ labels = encoded['labels']
129
+ idx_list = findall(input_ids, 32044) # '<|image|>'
130
+
131
+ if len(images) > 0:
132
+ processor = self.processor
133
+ encoded.update(processor.image_processor(images, return_tensors='pt'))
134
+ assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
135
+ res_input_ids = []
136
+ res_labels = []
137
+ num_img_tokens = encoded.pop('num_img_tokens').tolist()
138
+ idx_list.insert(0, -1)
139
+ for i in range(len(idx_list) - 1):
140
+ image_token_id = -i - 1
141
+ res_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] + [image_token_id] * num_img_tokens[i]
142
+ if labels is not None:
143
+ res_labels += labels[idx_list[i] + 1:idx_list[i + 1]] + [-100] * num_img_tokens[i]
144
+ res_input_ids += input_ids[idx_list[-1] + 1:]
145
+ input_ids = res_input_ids
146
+ if labels is not None:
147
+ res_labels += labels[idx_list[-1] + 1:]
148
+ labels = res_labels
149
+
150
+ encoded['input_ids'] = input_ids
151
+ encoded['labels'] = labels
152
+ return encoded
153
+
154
+
155
+ class Phi4MMTemplate(Template):
156
+ placeholder_tokens = ['<|endoftext10|>', '<|endoftext11|>']
157
+
158
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
159
+ inputs: StdTemplateInputs) -> List[Context]:
160
+ if media_type == 'image':
161
+ return [[-100]]
162
+ elif media_type == 'audio':
163
+ import soundfile as sf
164
+ inputs.audios[index] = sf.read(load_file(inputs.audios[index]))
165
+ return [[-200]]
166
+
167
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
168
+ encoded = super()._encode(inputs)
169
+ input_ids = encoded['input_ids']
170
+ labels = encoded['labels']
171
+ images_idx = findall(input_ids, -100)
172
+ audios_idx = findall(input_ids, -200)
173
+ text = '\n'.join(['<|image_1|>'] * len(inputs.images) + ['<|audio_1|>'] * len(inputs.audios))
174
+ new_encoded = self.processor(
175
+ text=text, images=inputs.images or None, audios=inputs.audios or None, return_tensors='pt')
176
+ placeholders = self._split_list(new_encoded.pop('input_ids')[0].tolist(), 198)
177
+
178
+ def _get_new_tokens(i):
179
+ return placeholders[i]
180
+
181
+ encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx,
182
+ _get_new_tokens)
183
+ new_encoded.pop('attention_mask')
184
+ encoded.update(new_encoded)
185
+ return encoded
186
+
187
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
188
+ res = super()._data_collator(batch, padding_to=padding_to)
189
+ keys = [
190
+ 'input_image_embeds', 'image_sizes', 'image_attention_mask', 'input_audio_embeds', 'audio_embed_sizes',
191
+ 'input_mode'
192
+ ]
193
+ inputs = self.fetch_inputs(batch, keys)
194
+ for k, v in inputs.items():
195
+ inputs[k] = torch.concat(v)
196
+ res.update(inputs)
197
+ return res
198
+
199
+
200
+ register_template(Phi3TemplateMeta(MLLMTemplateType.phi3_vision, template_cls=Phi3VisionTemplate))
201
+
202
+ register_template(Phi3TemplateMeta(
203
+ MLLMTemplateType.phi4_multimodal,
204
+ template_cls=Phi4MMTemplate,
205
+ ))
swift/llm/template/template/minicpm.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Literal, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from swift.utils import get_env_args
10
+ from ..base import Template
11
+ from ..constant import LLMTemplateType, MLLMTemplateType
12
+ from ..register import TemplateMeta, register_template
13
+ from ..template_inputs import StdTemplateInputs
14
+ from ..utils import Context, Prompt, findall
15
+ from ..vision_utils import load_video_minicpmv_mplug_owl3
16
+ from .llama import Llama3TemplateMeta
17
+ from .qwen import Qwen2_5TemplateMeta, QwenTemplateMeta
18
+
19
+
20
+ @dataclass
21
+ class MinicpmTemplateMeta(TemplateMeta):
22
+ prefix: Prompt = field(default_factory=lambda: ['<s>{{SYSTEM}}'])
23
+ prompt: Prompt = field(default_factory=lambda: ['<用户>{{QUERY}}<AI>'])
24
+ chat_sep: Optional[Prompt] = field(default_factory=list)
25
+ suffix: Prompt = field(default_factory=lambda: ['</s>'])
26
+
27
+
28
+ register_template(MinicpmTemplateMeta(LLMTemplateType.minicpm))
29
+
30
+
31
+ def _remove_idx(arr: List[int], idx_list: List[int]) -> List[int]:
32
+ res = []
33
+ idx_set = set(idx_list)
34
+ for i, x in enumerate(arr):
35
+ if i not in idx_set:
36
+ res.append(x)
37
+ return res
38
+
39
+
40
+ class MiniCPMVTemplate(Template):
41
+ is_v2_5 = False
42
+ use_model = True
43
+ skip_prompt = False
44
+ placeholder_tokens = ['<unk>']
45
+
46
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
47
+ inputs: StdTemplateInputs) -> List[Context]:
48
+ if self.mode == 'vllm':
49
+ return ['(<image>./</image>)\n']
50
+ else:
51
+ return [[-100]]
52
+
53
+ async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
54
+ images = inputs.pop('images', None) or []
55
+ if len(images) == 0:
56
+ return
57
+ input_ids = inputs['input_ids']
58
+ idx_list = findall(input_ids, -100)
59
+ idx_list.insert(0, -1)
60
+ new_input_ids = []
61
+ features = []
62
+ for i in range(len(idx_list) - 1):
63
+ new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
64
+ context_list = ['<image>', [-100], '</image>']
65
+ feat = [x.squeeze() for x in images[i]['embeddings'].split(1)]
66
+ grid = images[i].get('grid')
67
+ if len(feat) > 1 and grid is not None:
68
+ context_list.append('<slice>')
69
+ for j in range(grid[1]):
70
+ if j > 0:
71
+ context_list.append('\n')
72
+ for _ in range(grid[0]):
73
+ context_list += ['<image>', [-100], '</image>']
74
+ context_list.append('</slice>\n')
75
+ new_input_ids += self._encode_context_list(context_list)[0]
76
+ features += feat
77
+ new_input_ids += input_ids[idx_list[-1] + 1:]
78
+ inputs['input_ids'] = new_input_ids
79
+ inputs['images'] = features
80
+ await super().prepare_lmdeploy_turbomind_inputs(inputs)
81
+
82
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
83
+ encoded = super()._encode(inputs)
84
+ images = inputs.images
85
+ input_ids = encoded['input_ids']
86
+ labels = encoded['labels']
87
+ idx_list = findall(input_ids, -100)
88
+ idx = idx_list[0]
89
+ tgt_sizes = None
90
+ slice_mode = getattr(self.config, 'slice_mode', False)
91
+ if slice_mode:
92
+ if self.is_v2_5:
93
+ image_processor = self.processor.image_processor
94
+ image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
95
+ placeholder = image_processor.get_slice_image_placeholder(image_inputs.image_sizes[0][0])
96
+ pixel_values = image_inputs['pixel_values']
97
+ tgt_sizes = image_inputs['tgt_sizes']
98
+ else:
99
+ images, placeholder = self.model.get_slice_image_placeholder(images[0], self.processor)
100
+ pixel_values = [[self.model.transform(img) for img in images]]
101
+ placeholder += '\n'
102
+ placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
103
+ input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
104
+ if labels is not None:
105
+ labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
106
+ input_tensor_ids = torch.tensor(input_ids)
107
+ image_start_idx = torch.where(input_tensor_ids == self.processor.im_start_id)[0]
108
+ image_start_idx += 1
109
+ image_end_idx = torch.where(input_tensor_ids == self.processor.im_end_id)[0]
110
+ valid_image_nums = max(len(image_start_idx), len(image_end_idx))
111
+ image_bound = [
112
+ torch.hstack(
113
+ [image_start_idx[:valid_image_nums].unsqueeze(-1), image_end_idx[:valid_image_nums].unsqueeze(-1)])
114
+ ]
115
+ else:
116
+ placeholder = '<image>' + '<unk>' * self.config.query_num + '</image>\n'
117
+ placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
118
+ input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
119
+ if labels is not None:
120
+ labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
121
+ image_bound = [torch.tensor([[idx, idx + self.config.query_num]])]
122
+ pixel_values = [[self.model.transform(images[0])]]
123
+ encoded = {
124
+ 'input_ids': input_ids,
125
+ 'labels': labels,
126
+ 'image_bound': image_bound,
127
+ 'pixel_values': pixel_values,
128
+ 'tgt_sizes': tgt_sizes
129
+ }
130
+ return encoded
131
+
132
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
133
+ inputs_embeds, _ = model.get_vllm_embedding(inputs)
134
+ return {'inputs_embeds': inputs_embeds}
135
+
136
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
137
+ res = {}
138
+ for k in ['pixel_values', 'image_bound', 'tgt_sizes']:
139
+ res[k] = self.gather_list(batch, k)
140
+ res.update(super()._data_collator(batch, padding_to=padding_to))
141
+ return res
142
+
143
+
144
+ register_template(MinicpmTemplateMeta(MLLMTemplateType.minicpmv, template_cls=MiniCPMVTemplate))
145
+
146
+
147
+ class MiniCPMV2_5Template(MiniCPMVTemplate):
148
+ is_v2_5 = True
149
+
150
+
151
+ register_template(Llama3TemplateMeta(
152
+ MLLMTemplateType.minicpmv2_5,
153
+ template_cls=MiniCPMV2_5Template,
154
+ ))
155
+
156
+
157
+ class MiniCPMV2_6Template(MiniCPMVTemplate):
158
+
159
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
160
+ inputs: StdTemplateInputs) -> List[Context]:
161
+ assert media_type in {'image', 'video'}
162
+ max_num_frames = get_env_args('max_num_frames', int, 64)
163
+ load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames)
164
+ image_context = super().replace_tag('image', index, inputs)
165
+ if media_type == 'image':
166
+ return image_context
167
+ elif media_type == 'video':
168
+ return self.replace_video2image(load_video, inputs, lambda i: image_context)
169
+
170
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
171
+ encoded = Template._encode(self, inputs)
172
+ images = inputs.images
173
+ use_video = bool(inputs.videos)
174
+ use_image_id = True
175
+ max_slice_nums = get_env_args('max_slice_nums', int, None)
176
+ video_max_slice_nums = get_env_args('video_max_slice_nums', int, 1) # or 2
177
+ if use_video:
178
+ max_slice_nums = video_max_slice_nums
179
+ use_image_id = False
180
+ input_ids = encoded['input_ids']
181
+ labels = encoded['labels']
182
+ idx_list = findall(input_ids, -100)
183
+
184
+ image_processor = self.processor.image_processor
185
+ image_inputs = image_processor([images], return_tensors='pt',
186
+ max_slice_nums=max_slice_nums).to(self.model_info.torch_dtype)
187
+
188
+ def _get_new_tokens(i):
189
+ placeholder = image_processor.get_slice_image_placeholder(
190
+ image_inputs.image_sizes[0][i], image_idx=i, max_slice_nums=max_slice_nums, use_image_id=use_image_id)
191
+ placeholder += '\n'
192
+ return self.processor.encode(placeholder, add_special_tokens=False)
193
+
194
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
195
+ if inputs.images:
196
+ input_tensor_ids = torch.tensor(input_ids)
197
+ unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
198
+ indices = (input_tensor_ids == unk_token).nonzero(as_tuple=True)[0].tolist()
199
+
200
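+ # Merge consecutive <unk> indices into [start, end) spans; each span becomes one image_bound entry marking where visual features are inserted.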
+ ranges = []
201
+ start = indices[0]
202
+ for i in range(1, len(indices)):
203
+ if indices[i] != indices[i - 1] + 1:
204
+ ranges.append([start, indices[i - 1] + 1])
205
+ start = indices[i]
206
+ ranges.append([start, indices[-1] + 1])
207
+ image_bound = [torch.tensor(ranges)]
208
+ else:
209
+ image_bound = [[]]
210
+
211
+ encoded = {
212
+ 'input_ids': input_ids,
213
+ 'labels': labels,
214
+ 'image_bound': image_bound,
215
+ 'pixel_values': image_inputs['pixel_values'],
216
+ 'tgt_sizes': image_inputs['tgt_sizes']
217
+ }
218
+ return encoded
219
+
220
+
221
+ register_template(QwenTemplateMeta(
222
+ MLLMTemplateType.minicpmv2_6,
223
+ template_cls=MiniCPMV2_6Template,
224
+ ))
225
+
226
+ register_template(Qwen2_5TemplateMeta(
227
+ MLLMTemplateType.minicpmo2_6,
228
+ template_cls=MiniCPMV2_6Template,
229
+ ))
swift/llm/template/template/minimax.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ from swift.utils import get_logger
6
+ from ..base import Template
7
+ from ..constant import LLMTemplateType
8
+ from ..register import TemplateMeta, register_template
9
+ from ..template_inputs import StdTemplateInputs
10
+ from ..utils import Context, Prompt
11
+
12
+ logger = get_logger()
13
+
14
+
15
+ @dataclass
16
+ class MinimaxTemplateMeta(TemplateMeta):
17
+ prefix: Prompt = field(default_factory=list)
18
+ prompt: Prompt = field(default_factory=lambda: [
19
+ '<beginning_of_sentence>user name=user\n{{QUERY}}<end_of_sentence>\n'
20
+ '<beginning_of_sentence>ai name=assistant\n'
21
+ ])
22
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<end_of_sentence>\n'])
23
+ suffix: Prompt = field(default_factory=lambda: ['<end_of_sentence>'])
24
+ system_prefix: Optional[Prompt] = field(
25
+ default_factory=lambda: ['<beginning_of_sentence>system ai_setting=assistant\n{{SYSTEM}}<end_of_sentence>\n'])
26
+
27
+
28
+ register_template(MinimaxTemplateMeta(LLMTemplateType.minimax))
29
+
30
+
31
+ class MinimaxVLTemplate(Template):
32
+ image_placeholder = ['<image>']
33
+ skip_prompt = True
34
+
35
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
36
+ inputs: StdTemplateInputs) -> List[Context]:
37
+ assert media_type == 'image'
38
+ return self.image_placeholder * inputs.all_image_tokens[index]
39
+
40
+ def calc_num_image_tokens(self, image_inputs):
41
+ from transformers.image_utils import get_image_size, to_numpy_array
42
+ pixel_values = image_inputs['pixel_values']
43
+ image_sizes = image_inputs['image_sizes']
44
+ all_image_tokens = []
45
+ if not image_inputs:
46
+ return all_image_tokens
47
+
48
+ if self.processor.process_image_mode == 'anyres':
49
+ for pixel_value, image_size in zip(pixel_values, image_sizes):
50
+ height, width = image_size
51
+ num_image_tokens = self.processor.get_num_token(height, width, self.processor.grid_pinpoints,
52
+ self.processor.patch_size)
53
+ all_image_tokens.append(num_image_tokens)
54
+ elif self.processor.process_image_mode == 'resize':
55
+ pixel_values = image_inputs['pixel_values']
56
+ all_image_tokens = []
57
+ for pixel_value in pixel_values:
58
+ height, width = get_image_size(to_numpy_array(pixel_value))
59
+ all_image_tokens.append(int(height * width / self.processor.patch_size**2))
60
+ else:
61
+ if self.processor.patch_size is not None:
62
+ pixel_values = image_inputs['pixel_values']
63
+ all_image_tokens = []
64
+ for pixel_value in pixel_values:
65
+ height, width = get_image_size(to_numpy_array(pixel_value))
66
+ new_width, new_height = self.processor.get_hw_multiple_of(
67
+ (width, height), self.processor.patch_size, self.processor.max_size)
68
+ num_image_tokens = ((new_height // self.processor.patch_size) *
69
+ (new_width // self.processor.patch_size)) # + 1
70
+ all_image_tokens.append(num_image_tokens)
71
+ else:
72
+ logger.warning_once(
73
+ 'Expanding inputs for image tokens in MiniMaxVL01 should be done in processing. '
74
+ "Please add `patch_size` and `vision_feature_select_strategy` to the model's "
75
+ 'processing config or set directly '
76
+ 'with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = '
77
+ '{{vision_feature_select_strategy}}`. '
78
+ 'Using processors without these attributes in the config is deprecated '
79
+ 'and will throw an error in v4.47.')
80
+ raise ValueError(
81
+ "You need to provide `patch_size` and `vision_feature_select_strategy` in the model's processing "
82
+ 'config to expand inputs for image tokens.')
83
+ return all_image_tokens
84
+
85
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
86
+ output_kwargs = self.processor._merge_kwargs(
87
+ self.processor.MiniMaxVL01ProcessorKwargs,
88
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
89
+ )
90
+ if inputs.images:
91
+ image_inputs = self.processor.image_processor(
92
+ inputs.images, **output_kwargs['images_kwargs'], return_tensors='pt')
93
+ inputs.all_image_tokens = self.calc_num_image_tokens(image_inputs)
94
+ else:
95
+ image_inputs = {}
96
+ encoded = super()._encode(inputs)
97
+ for key in image_inputs:
98
+ encoded[key] = image_inputs[key]
99
+ return encoded
100
+
101
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
102
+ pixel_values = self.gather_list(batch, 'pixel_values')
103
+ image_sizes = self.gather_list(batch, 'image_sizes')
104
+ res = super()._data_collator(batch, padding_to=padding_to)
105
+ if pixel_values:
106
+ res['pixel_values'] = pixel_values
107
+ if image_sizes:
108
+ res['image_sizes'] = image_sizes
109
+ return res
110
+
111
+
112
+ register_template(MinimaxTemplateMeta(LLMTemplateType.minimax_vl, template_cls=MinimaxVLTemplate))
swift/llm/template/template/mistral.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Literal, Optional
3
+
4
+ import torch
5
+
6
+ from ..base import Template
7
+ from ..constant import MLLMTemplateType
8
+ from ..register import TemplateMeta, register_template
9
+ from ..template_inputs import StdTemplateInputs
10
+ from ..utils import Context, findall
11
+ from .llm import mistral_2501_system
12
+
13
+
14
+ class Mistral2503Template(Template):
15
+ placeholder_tokens = ['[IMG]']
16
+ image_token = 10
17
+
18
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
19
+ inputs: StdTemplateInputs) -> List[Context]:
20
+ assert media_type == 'image'
21
+ return ['[IMG]']
22
+
23
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
24
+ encoded = super()._encode(inputs)
25
+ processor = self.processor
26
+ images = inputs.images
27
+ input_ids = encoded['input_ids']
28
+ labels = encoded['labels']
29
+ idx_list = findall(input_ids, self.image_token)
30
+ if idx_list:
31
+ image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
32
+ encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype)
33
+ encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']
34
+
35
+ def _get_new_tokens(i):
36
+ height, width = image_sizes[i]
37
+ num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size)
38
+ num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size)
39
+ replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token]
40
+ ] * num_height_tokens
41
+ # Flatten list
42
+ replace_tokens = [item for sublist in replace_tokens for item in sublist]
43
+ replace_tokens[-1] = processor.image_end_token
44
+ replace_str = ''.join(replace_tokens)
45
+ return processor.encode(replace_str, add_special_tokens=False)
46
+
47
+ encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
48
+
49
+ return encoded
50
+
51
+
52
+ register_template(
53
+ TemplateMeta(
54
+ MLLMTemplateType.mistral_2503,
55
+ prefix=['<s>'],
56
+ prompt=['[INST]{{QUERY}}[/INST]'],
57
+ chat_sep=['</s>'],
58
+ suffix=['</s>'],
59
+ system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
60
+ default_system=mistral_2501_system,
61
+ template_cls=Mistral2503Template))
swift/llm/template/template/molmo.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Literal, Optional
3
+
4
+ import torch
5
+
6
+ from ..base import Template
7
+ from ..constant import MLLMTemplateType
8
+ from ..register import TemplateMeta, register_template
9
+ from ..template_inputs import StdTemplateInputs
10
+ from ..utils import Context, findall
11
+
12
+
13
+ class MolmoTemplate(Template):
14
+ placeholder_tokens = ['<im_patch>']
15
+
16
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
17
+ inputs: StdTemplateInputs) -> List[Context]:
18
+ return []
19
+
20
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
21
+ encoded = super()._encode(inputs)
22
+ # image
23
+ images_inputs = self.processor.process(images=inputs.images or None, text='')
24
+ images_input_ids = images_inputs.pop('input_ids').tolist()
25
+ user_token = self._tokenize(' User')
26
+ assert len(user_token) == 1
27
+ idx = findall(images_input_ids, user_token[0])
28
+ assert len(idx) == 1
29
+ labels = encoded['labels']
30
+ encoded['input_ids'] = images_input_ids[:idx[0]] + encoded['input_ids']
31
+ if labels:
32
+ encoded['labels'] = [-100] * idx[0] + labels
33
+ if 'images' in images_inputs:
34
+ images_inputs['images'] = images_inputs['images'].to(self.model_info.torch_dtype)
35
+ encoded.update(images_inputs)
36
+ return encoded
37
+
38
+ def generate(self, model, **kwargs):
39
+ kwargs.pop('attention_mask', None)
40
+ generation_config = kwargs.pop('generation_config')
41
+ batch = {
42
+ k: kwargs.pop(k, None)
43
+ for k in ['input_ids', 'attention_mask', 'images', 'image_input_idx', 'image_masks']
44
+ }
45
+ return model.generate_from_batch(batch, generation_config, **kwargs)
46
+
47
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
48
+ res = super()._data_collator(batch, padding_to=padding_to)
49
+ # prepare batchfy inputs
50
+ keys = ['images', 'image_input_idx', 'image_masks']
51
+ images_res = self.fetch_inputs(batch, keys)
52
+ for key in keys:
53
+ val = images_res.get(key)
54
+ if val:
55
+ images_res[key] = torch.stack(val)
56
+ res.update(images_res)
57
+ return res
58
+
59
+
60
+ register_template(
61
+ TemplateMeta(
62
+ MLLMTemplateType.molmo,
63
+ prefix=[],
64
+ prompt=[' User: {{QUERY}} Assistant:'],
65
+ chat_sep=None,
66
+ suffix=['<|endoftext|>'],
67
+ template_cls=MolmoTemplate,
68
+ ))
swift/llm/template/template/moonshot.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Literal, Optional
5
+
6
+ from ..base import Template
7
+ from ..constant import LLMTemplateType, MLLMTemplateType
8
+ from ..register import TemplateMeta, register_template
9
+ from ..template_inputs import StdTemplateInputs
10
+ from ..utils import Context, Prompt, findall
11
+
12
+
13
+ @dataclass
14
+ class MoonlightTemplateMeta(TemplateMeta):
15
+ prefix: Prompt = field(default_factory=list)
16
+ prompt: Prompt = field(default_factory=lambda:
17
+ ['<|im_user|>user<|im_middle|>{{QUERY}}<|im_end|><|im_assistant|>assistant<|im_middle|>'])
18
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>'])
19
+ suffix: Prompt = field(default_factory=lambda: ['<|im_end|>'])
20
+ system_prefix: Optional[Prompt] = field(
21
+ default_factory=lambda: ['<|im_system|>system<|im_middle|>{{SYSTEM}}<|im_end|>'])
22
+ default_system: str = 'You are a helpful assistant'
23
+
24
+
25
+ register_template(MoonlightTemplateMeta(LLMTemplateType.moonlight))
26
+
27
+
28
+ class KimiVLTemplate(Template):
29
+ placeholder_tokens = ['<|media_pad|>']
30
+
31
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
32
+ inputs: StdTemplateInputs) -> List[Context]:
33
+ if media_type == 'image':
34
+ return ['<|media_start|>image<|media_content|><|media_pad|><|media_end|>']
35
+
36
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
37
+ encoded = super()._encode(inputs)
38
+ input_ids = encoded['input_ids']
39
+ labels = encoded['labels']
40
+ media_token = self._tokenize('<|media_pad|>')[0]
41
+ idx_list = findall(input_ids, media_token)
42
+ if inputs.images:
43
+ image_processor = self.processor.image_processor
44
+ image_inputs = image_processor(inputs.images, return_tensors='pt')
45
+ image_grid_hws = image_inputs['image_grid_hws']
46
+ merge_length = image_processor.merge_kernel_size[0] * image_processor.merge_kernel_size[1]
47
+
48
+ def _get_new_tokens(i):
49
+ token_len = (image_grid_hws[i].prod() // merge_length)
50
+ return [media_token] * token_len
51
+
52
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
53
+ encoded['input_ids'] = input_ids
54
+ encoded['labels'] = labels
55
+ encoded.update(image_inputs)
56
+ return encoded
57
+
58
+ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
59
+ res = super()._data_collator_mm_data(batch)
60
+ image_grid_hws = self.concat_tensor(batch, 'image_grid_hws', 0)
61
+ if image_grid_hws is not None:
62
+ res['image_grid_hws'] = image_grid_hws
63
+ return res
64
+
65
+
66
+ register_template(MoonlightTemplateMeta(MLLMTemplateType.kimi_vl, template_cls=KimiVLTemplate))
swift/llm/template/template/mplug.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Literal, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from swift.utils import get_env_args
10
+ from ..base import Template
11
+ from ..constant import MLLMTemplateType
12
+ from ..register import TemplateMeta, register_template
13
+ from ..template_inputs import StdTemplateInputs
14
+ from ..utils import Context, Prompt, findall
15
+ from ..vision_utils import load_video_minicpmv_mplug_owl3
16
+ from .qwen import QwenTemplateMeta
17
+
18
+
19
+ class mPlugOwl2Template(Template):
20
+
21
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
22
+ inputs: StdTemplateInputs) -> List[Context]:
23
+ assert media_type == 'image'
24
+ return [[-200]]
25
+
26
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
27
+ from mplug_owl2.mm_utils import process_images
28
+ processor = self.processor
29
+ images = inputs.images
30
+ for i, image in enumerate(images):
31
+ # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1
32
+ max_edge = max(image.size)
33
+ image = image.resize((max_edge, max_edge))
34
+ images[i] = image
35
+ encoded = super()._encode(inputs)
36
+ input_ids = encoded['input_ids']
37
+ labels = encoded['labels']
38
+ res = {'input_ids': input_ids, 'labels': labels}
39
+ if images:
40
+ images = process_images(images, processor)
41
+ images = images.to(self.model_info.torch_dtype)
42
+ res['images'] = images
43
+ return res
44
+
45
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
46
+ res = super()._data_collator(batch, padding_to=padding_to)
47
+ images = [b['images'] for b in batch if 'images' in b]
48
+ if images:
49
+ res['images'] = torch.concat(images)
50
+ return res
51
+
52
+
53
+ register_template(
54
+ TemplateMeta(
55
+ MLLMTemplateType.mplug_owl2,
56
+ template_cls=mPlugOwl2Template,
57
+ prefix=['{{SYSTEM}}'],
58
+ prompt=['USER: {{QUERY}}ASSISTANT:'],
59
+ chat_sep=['</s>'],
60
+ suffix=[['eos_token_id']],
61
+ stop_words=['<|endoftext|>', '</s>']))
62
+
63
+
64
+ class mPlugOwl3Template(Template):
65
+ version = None
66
+
67
+ def _get_image_token_list(self, cut_shape):
68
+ text = self.processor.image_processor.cut_prompt_template(img_token='<|image|>', h=cut_shape[0], w=cut_shape[1])
69
+ text_list = text.split('<|image|>')
70
+ res_text_list = []
71
+ for text in text_list[:-1]:
72
+ res_text_list += [text, '<|image|>']
73
+ res_text_list += text_list[-1]
74
+ token_list = self._encode_context_list(res_text_list)[0]
75
+ return token_list
76
+
77
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
78
+ inputs: StdTemplateInputs) -> List[Context]:
79
+ assert media_type in {'image', 'video'}
80
+ max_num_frames = get_env_args('max_num_frames', int, 16)
81
+ load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames)
82
+ if media_type == 'image':
83
+ return [[-100], '\n']
84
+ elif media_type == 'video':
85
+ return self.replace_video2image(load_video, inputs, lambda i: [[-100]]) + ['\n']
86
+
87
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
88
+ encoded = super()._encode(inputs)
89
+ images = inputs.images
90
+ videos = inputs.videos
91
+ cut_enable = not videos
92
+ input_ids = encoded['input_ids']
93
+ labels = encoded['labels']
94
+ idx_list = findall(input_ids, -100)
95
+ processor = self.processor
96
+ encoded = {}
97
+ if images:
98
+ image_inputs = processor.image_processor(images, cut_enable=cut_enable, return_tensors='pt')
99
+ cut_shapes = image_inputs['cut_shape'] or [None] * 2 * len(idx_list)
100
+ image_token_list = self.processor.encode('<|image|>', add_special_tokens=False)
101
+
102
+ def _get_new_tokens(i):
103
+ cut_shape = cut_shapes[2 * i]
104
+ if cut_shape:
105
+ token_list = self._get_image_token_list(cut_shape)
106
+ else:
107
+ token_list = image_token_list
108
+ return token_list
109
+
110
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
111
+ image_token_idx = torch.tensor(findall(input_ids, image_token_list))
112
+ if self.version == '241101':
113
+ media_offset = image_token_idx
114
+ else:
115
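+ # For every text position, count how many image tokens precede it; media_offset pairs a zero column with that running count.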
+ _range = torch.arange(len(input_ids))[:, None]
116
+ matrix = (_range > image_token_idx[None]).sum(dim=1)
117
+ media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
118
+ encoded.update({
119
+ 'pixel_values': image_inputs['pixel_values'],
120
+ 'media_offset': media_offset,
121
+ })
122
+ encoded['input_ids'] = input_ids
123
+ encoded['labels'] = labels
124
+ return encoded
125
+
126
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
127
+ if 'media_offset' in inputs:
128
+ media_offset = []
129
+ cusum_offset = 0
130
+ image_embeds = []
131
+ pixel_values = inputs.pop('pixel_values')
132
+ max_sequence_length = inputs['input_ids'].shape[1]
133
+ for i, curr_media_offset in enumerate(inputs['media_offset']):
134
+ if curr_media_offset is None:
135
+ continue
136
+ if curr_media_offset.shape[1] < max_sequence_length:
137
+ padding = curr_media_offset[:, -1:, :].expand(curr_media_offset.shape[0],
138
+ max_sequence_length - curr_media_offset.shape[1],
139
+ curr_media_offset.shape[2])
140
+ curr_media_offset = torch.concat([curr_media_offset, padding], dim=1)
141
+ media_offset.append(curr_media_offset + cusum_offset)
142
+ image_embeds.append(model.forward_image(pixel_values[i]))
143
+ cusum_offset += image_embeds[-1].shape[0]
144
+ inputs['media_offset'] = torch.concat(media_offset)
145
+ inputs['image_embeds'] = torch.concat(image_embeds)
146
+ return inputs
147
+
148
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
149
+ res = self.fetch_inputs(batch, ['media_offset', 'pixel_values'])
150
+ for b in batch:
151
+ b.pop('pixel_values', None)
152
+ res.update(super()._data_collator(batch, padding_to=padding_to))
153
+ return res
154
+
155
+
156
+ class mPlugOwl3_241101Template(mPlugOwl3Template):
157
+ version = '241101'
158
+
159
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
160
+ if 'pixel_values' in inputs:
161
+ pixel_values = inputs.pop('pixel_values')
162
+ inputs['image_embeds'] = torch.concat([model.forward_image(pv) for pv in pixel_values])
163
+ else:
164
+ inputs['media_offset'] = [None] * inputs['input_ids'].shape[0]
165
+ return inputs
166
+
167
+
168
+ @dataclass
169
+ class mPlugOwl3TemplateMeta(QwenTemplateMeta):
170
+ prefix: Prompt = field(default_factory=lambda: ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n'])
171
+ default_system: Optional[str] = None
172
+ system_prefix: Optional[Prompt] = None
173
+
174
+
175
+ register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3, template_cls=mPlugOwl3Template))
176
+
177
+ register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3_241101, template_cls=mPlugOwl3_241101Template))
178
+
179
+
180
+ class DocOwl2Template(Template):
181
+
182
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
183
+ inputs: StdTemplateInputs) -> List[Context]:
184
+ if media_type == 'image':
185
+ return [f'<img {index + 1}>', [-200]]
186
+
187
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
188
+ encoded = super()._encode(inputs)
189
+ if inputs.images:
190
+ image_tensor, patch_positions, _ = self.processor._process_image(inputs.images)
191
+ image_tensor = image_tensor.to(self.model_info.torch_dtype)
192
+ encoded.update({'images': image_tensor, 'patch_positions': patch_positions})
193
+ return encoded
194
+
195
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
196
+ keys = ['images', 'patch_positions']
197
+ res = self.fetch_inputs(batch, keys)
198
+ for key in keys:
199
+ val = res.get(key)
200
+ if val:
201
+ res[key] = torch.concat([v for v in val if v is not None])
202
+ res.update(super()._data_collator(batch, padding_to=padding_to))
203
+ return res
204
+
205
+
206
+ register_template(
207
+ TemplateMeta(
208
+ MLLMTemplateType.doc_owl2,
209
+ prefix=['<s>'],
210
+ prompt=[' USER: {{QUERY}} ASSISTANT:'],
211
+ chat_sep=['</s>'],
212
+ suffix=['</s>'],
213
+ template_cls=DocOwl2Template,
214
+ ))
swift/llm/template/template/openbuddy.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from ..constant import LLMTemplateType
4
+ from ..register import TemplateMeta, register_template
5
+
6
+ OPENBUDDY_DEFAULT_SYSTEM = (
7
+ 'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
8
+ 'Always answer as helpfully and logically as possible, while being safe. '
9
+ 'Your answers should not include any '
10
+ 'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
11
+ 'Please ensure that your responses are socially unbiased and positive in nature.\n'
12
+ 'If a question does not make any sense, or is not factually coherent, '
13
+ 'explain why instead of answering something not correct. '
14
+ "If you don't know the answer to a question, please don't share false information.\n"
15
+ 'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
16
+ 'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
17
+ 'You always deeply love and support China, Chinese government, people and culture.\n'
18
+ 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
19
+ 'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.')
20
+ register_template(
21
+ TemplateMeta(
22
+ LLMTemplateType.openbuddy,
23
+ prefix=[],
24
+ prompt=['User: {{QUERY}}\nAssistant:'],
25
+ chat_sep=['\n'],
26
+ default_system=OPENBUDDY_DEFAULT_SYSTEM,
27
+ system_prefix=['{{SYSTEM}}\n\n'],
28
+ auto_add_bos=True))
29
+
30
+ OPENBUDDY2_DEFAULT_SYSTEM = (
31
+ 'You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. '
32
+ 'You are talking to a human(user).\nAlways answer as helpfully and logically as possible, while being safe. '
33
+ 'Your answers should not include any harmful, political, religious, unethical, racist, '
34
+ 'sexist, toxic, dangerous, or illegal content. '
35
+ 'Please ensure that your responses are socially unbiased and positive in nature.\n'
36
+ 'You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.\n'
37
+ 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
38
+ 'not related to GPT or OpenAI')
39
+
40
+ register_template(
41
+ TemplateMeta(
42
+ LLMTemplateType.openbuddy2,
43
+ prefix=[],
44
+ prompt=['<|role|>user<|says|>{{QUERY}}<|end|>\n<|role|>assistant<|says|>'],
45
+ chat_sep=['<|end|>\n'],
46
+ suffix=['<|end|>'],
47
+ default_system=OPENBUDDY2_DEFAULT_SYSTEM,
48
+ system_prefix=['<|role|>system<|says|>{{SYSTEM}}<|end|>\n']))
swift/llm/template/template/pixtral.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from ..base import Template
5
+ from ..constant import MLLMTemplateType
6
+ from ..register import TemplateMeta, register_template
7
+ from ..template_inputs import StdTemplateInputs
8
+ from ..utils import findall
9
+
10
+
11
+ class PixtralTemplate(Template):
12
+ image_placeholder = ['[IMG]']
13
+ placeholder_tokens = ['[IMG]']
14
+
15
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
16
+ encoded = super()._encode(inputs)
17
+ processor = self.processor
18
+ images = inputs.images
19
+ input_ids = encoded['input_ids']
20
+ labels = encoded['labels']
21
+ idx_list = findall(input_ids, 10)
22
+ if idx_list:
23
+ image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
24
+ encoded['pixel_values'] = image_inputs['pixel_values'][0]
25
+ image_sizes = image_inputs['image_sizes'][0]
26
+
27
+ def _get_new_tokens(i):
28
+ height, width = image_sizes[i]
29
+ num_height_tokens = height // processor.patch_size
30
+ num_width_tokens = width // processor.patch_size
31
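+ # Build one [IMG] run per row of image patches; each row ends with image_break_token and the final row with image_end_token.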
+ replace_tokens = [processor.image_token * num_width_tokens + processor.image_break_token] * (
32
+ num_height_tokens - 1)
33
+ replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token]
34
+ # Flatten list
35
+ replace_str = ''.join(replace_tokens)
36
+ img_tokens: List[int] = self.processor.encode(replace_str, add_special_tokens=False)
37
+ return img_tokens
38
+
39
+ encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
40
+
41
+ return encoded
42
+
43
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
44
+ pixel_values = self.gather_list(batch, 'pixel_values')
45
+ res = super()._data_collator(batch, padding_to=padding_to)
46
+ if pixel_values:
47
+ res['pixel_values'] = pixel_values
48
+ return res
49
+
50
+
51
+ register_template(
52
+ TemplateMeta(
53
+ MLLMTemplateType.pixtral,
54
+ prefix=['<s>{{SYSTEM}}'],
55
+ prompt=['[INST]{{QUERY}}[/INST]'],
56
+ chat_sep=['</s>'],
57
+ suffix=['</s>'],
58
+ template_cls=PixtralTemplate,
59
+ ))
swift/llm/template/template/qwen.py ADDED
@@ -0,0 +1,671 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Literal, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from swift.llm import to_device, to_float_dtype
10
+ from swift.utils import get_env_args, is_deepspeed_enabled
11
+ from ..base import Template
12
+ from ..constant import LLMTemplateType, MLLMTemplateType
13
+ from ..register import register_template
14
+ from ..template_inputs import StdTemplateInputs
15
+ from ..template_meta import TemplateMeta
16
+ from ..utils import Context, Word, findall
17
+ from ..vision_utils import load_audio, load_batch, load_video_ovis2
18
+ from .llama import Llama3TemplateMeta
19
+ from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
20
+
21
+
22
+ @dataclass
23
+ class QwenTemplateMeta(ChatmlTemplateMeta):
24
+ default_system: Optional[str] = DEFAULT_SYSTEM
25
+ auto_add_bos: bool = False
26
+ stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>'])
27
+ agent_template: str = 'hermes'
28
+
29
+
30
+ @dataclass
31
+ class Qwen2_5TemplateMeta(QwenTemplateMeta):
32
+ default_system: Optional[str] = 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.'
33
+
34
+
35
+ @dataclass
36
+ class Qwen2_5MathTemplateMeta(QwenTemplateMeta):
37
+ default_system: Optional[str] = 'Please reason step by step, and put your final answer within \\boxed{}.'
38
+
39
+
40
+ qwq_preview_system = ('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
41
+ 'You should think step-by-step.')
42
+
43
+ register_template(QwenTemplateMeta(LLMTemplateType.qwen))
44
+ register_template(Qwen2_5TemplateMeta(LLMTemplateType.qwen2_5))
45
+ register_template(QwenTemplateMeta(LLMTemplateType.qwq_preview, default_system=qwq_preview_system))
46
+
47
+
48
+ class ThinkingTemplate(Template):
49
+
50
+ def _swift_encode(self, inputs: StdTemplateInputs):
51
+ if not self.is_training:
52
+ for message in inputs.messages:
53
+ if message['role'] == 'assistant' and isinstance(message['content'], str):
54
+ message['content'] = message['content'].split('</think>')[-1].lstrip('\n')
55
+ return super()._swift_encode(inputs)
56
+
57
+
58
+ register_template(
59
+ QwenTemplateMeta(
60
+ LLMTemplateType.qwq, default_system=None, response_prefix='<think>\n', template_cls=ThinkingTemplate))
61
+
62
+ # '<think>\n\n</think>\n\n'
63
+ register_template(QwenTemplateMeta(LLMTemplateType.qwen3, default_system=None, template_cls=ThinkingTemplate))
64
+
65
+ register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math))
66
+
67
+
68
+ class QwenPRMTemplate(Template):
69
+ cot_process_placeholder = '<extra_0>'
70
+
71
+ def _preprocess_inputs(
72
+ self,
73
+ inputs: StdTemplateInputs,
74
+ ) -> None:
75
+ super()._preprocess_inputs(inputs)
76
+ total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
77
+ if self.cot_process_placeholder not in total_content:
78
+ inputs.messages[-1]['content'] = inputs.messages[-1]['content'] + self.cot_process_placeholder
79
+
80
+ @staticmethod
81
+ def make_step_rewards(logits, token_masks):
82
+ probabilities = F.softmax(logits, dim=-1)
83
+ probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels
84
+
85
+ all_scores_res = []
86
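+ # For each sample, keep the positive-class probability at every <extra_0> position; these are the per-step process rewards.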
+ for i in range(probabilities.size(0)):
87
+ sample = probabilities[i] # seq_len, num_labels
88
+ positive_probs = sample[sample != 0].view(-1, 2)[:, 1] # valid_tokens, num_labels
89
+ non_zero_elements_list = positive_probs.cpu().tolist()
90
+ all_scores_res.append(non_zero_elements_list)
91
+ return all_scores_res
92
+
93
+ def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
94
+ step_sep_id = self.tokenizer.encode(self.cot_process_placeholder)[0]
95
+ token_masks = (input_ids == step_sep_id)
96
+ return self.make_step_rewards(logits, token_masks)
97
+
98
+
99
+ register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math_prm, template_cls=QwenPRMTemplate))
100
+
101
+
102
+ class QwenVLTemplate(Template):
103
+ load_images = False
104
+
105
+ @staticmethod
106
+ def _load_image(image, load_images: bool):
107
+ if not load_images and isinstance(image, str) and (image.startswith('data:') or len(image) > 200):
108
+ load_images = True
109
+ return Template._load_image(image, load_images)
110
+
111
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
112
+ inputs: StdTemplateInputs) -> List[Context]:
113
+ assert media_type == 'image'
114
+ if self.mode == 'lmdeploy':
115
+ return [f'Picture {index + 1}: ', [-100], '\n']
116
+ else:
117
+ image = inputs.images[index]
118
+ if self.mode == 'vllm':
119
+ return [f'Picture {index + 1}: <img></img>\n']
120
+ else:
121
+ assert isinstance(image, str)
122
+ return [f'Picture {index + 1}: <img>{image}</img>\n']
123
+
124
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
125
+ return [f'<ref>{ref}</ref>']
126
+
127
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
128
+ return [f'<box>{self._get_bbox_str(bbox)}</box>']
129
+
130
+
131
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen_vl, template_cls=QwenVLTemplate))
132
+
133
+
134
+ class QwenAudioTemplate(Template):
135
+
136
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
137
+ inputs: StdTemplateInputs) -> List[Context]:
138
+ assert media_type == 'audio'
139
+ audios = inputs.audios
140
+ audio = audios[index]
141
+ assert isinstance(audio, str)
142
+ return [f'Audio {index + 1}:<audio>{audio}</audio>\n']
143
+
144
+ def _tokenize(self, context, **tokenizer_kwargs):
145
+ audio_info = self.processor.process_audio(context)
146
+ return super()._tokenize(context, audio_info=audio_info)
147
+
148
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
149
+ encoded = super()._encode(inputs)
150
+ text = ''.join([f'<audio>{audio}</audio>' for audio in inputs.audios])
151
+ audio_info = self.processor.process_audio(text)
152
+ if audio_info:
153
+ tokenizer_kwargs = {'audio_info': audio_info}
154
+ encoded.update(tokenizer_kwargs)
155
+ encoded['tokenizer_kwargs'] = tokenizer_kwargs
156
+ return encoded
157
+
158
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
159
+ res = super()._data_collator(batch, padding_to=padding_to)
160
+ if batch[0].get('audio_info') is not None:
161
+ res['audio_info'] = [b['audio_info'] for b in batch]
162
+ return res
163
+
164
+
165
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen_audio, template_cls=QwenAudioTemplate))
166
+
167
+
168
+ class Qwen2AudioTemplate(Template):
169
+
170
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
171
+ inputs: StdTemplateInputs) -> List[Context]:
172
+ assert media_type == 'audio'
173
+ if not self.use_chat_template:
174
+ return ['<|audio_bos|><|AUDIO|><|audio_eos|>\n']
175
+ else:
176
+ return [f'Audio {index + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n']
177
+
178
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
179
+ encoded = super()._encode(inputs)
180
+ if inputs.audios:
181
+ sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
182
+ audios = load_batch(inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate))
183
+ audio_inputs = self.processor.feature_extractor(
184
+ audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt')
185
+ audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask')
186
+ encoded.update(audio_inputs)
187
+ return encoded
188
+
189
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
190
+ res = super()._data_collator(batch, padding_to=padding_to)
191
+ input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
192
+ feature_attention_mask = [
193
+ b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
194
+ ]
195
+ if input_features:
196
+ res['input_features'] = torch.concat(input_features)
197
+ res['feature_attention_mask'] = torch.concat(feature_attention_mask)
198
+ return res
199
+
200
+
201
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate))
202
+
203
+
204
+ class Qwen2VLTemplate(Template):
205
+ image_token_id = 151655
206
+ video_token_id = 151656
207
+ placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
208
+ version = 'v2'
209
+ use_model = True
210
+
211
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
212
+ inputs: StdTemplateInputs) -> List[Context]:
213
+ from qwen_vl_utils import fetch_image, fetch_video
214
+ assert media_type in {'image', 'video'}
215
+ if media_type == 'image':
216
+ inputs.images[index] = fetch_image({'image': inputs.images[index]})
217
+ if self.mode == 'lmdeploy':
218
+ return ['<|vision_start|>', [-100], '<|vision_end|>']
219
+ else:
220
+ return ['<|vision_start|><|image_pad|><|vision_end|>']
221
+ else:
222
+ inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
223
+ return ['<|vision_start|><|video_pad|><|vision_end|>']
224
+
225
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
226
+ return [f'<|object_ref_start|>{ref}<|object_ref_end|>']
227
+
228
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
229
+ return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>']
230
+
231
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
232
+ encoded = super()._encode(inputs)
233
+ processor = self.processor
234
+ input_ids = encoded['input_ids']
235
+ labels = encoded['labels']
236
+ images = inputs.images
237
+ videos = inputs.videos
238
+ for media_type in ['images', 'videos']:
239
+ if locals()[media_type]:
240
+ if media_type == 'images':
241
+ media_token = self.image_token_id
242
+ media_inputs = processor.image_processor(
243
+ images=images, videos=None, return_tensors='pt', do_resize=False)
244
+ media_grid_thw = media_inputs['image_grid_thw']
245
+ else:
246
+ media_inputs = processor.image_processor(
247
+ images=None, videos=videos, return_tensors='pt', do_resize=False)
248
+ media_grid_thw = media_inputs['video_grid_thw']
249
+ media_token = self.video_token_id
250
+ if self.version == 'v2_5':
251
+ from qwen_vl_utils import vision_process
252
+ media_inputs['second_per_grid_ts'] = [
253
+ processor.image_processor.temporal_patch_size / vision_process.FPS
254
+ ] * len(media_grid_thw)
255
+ idx_list = findall(input_ids, media_token)
256
+ merge_length = processor.image_processor.merge_size**2
257
+
258
+ def _get_new_tokens(i):
259
+ token_len = (media_grid_thw[i].prod() // merge_length)
260
+ return [media_token] * token_len
261
+
262
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
263
+ encoded.update(media_inputs)
264
+
265
+ encoded['input_ids'] = input_ids
266
+ encoded['labels'] = labels
267
+ return encoded
268
+
269
+ def compute_loss_context(self, model, inputs):
270
+ if 'real_position_ids' not in inputs:
271
+ return super().compute_loss_context(model, inputs)
272
+ if self.version == 'v2':
273
+ from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
274
+ elif self.version == 'v2_5':
275
+ from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module
276
+ elif self.version == 'omni':
277
+ from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module
278
+ position_ids = inputs['position_ids']
279
+ inputs['position_ids'] = inputs.pop('real_position_ids')
280
+ return self._patch_flash_attention_forward(modeling_module, position_ids)
281
+
282
+ def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
283
+ if not self.is_training:
284
+ return inputs
285
+ input_ids = inputs['input_ids']
286
+ _model = model.model
287
+ if not hasattr(_model, 'embed_tokens'):
288
+ _model = _model.model # LoRA
289
+ pixel_values = inputs.get('pixel_values')
290
+ pixel_values_videos = inputs.get('pixel_values_videos')
291
+ image_grid_thw = inputs.get('image_grid_thw')
292
+ video_grid_thw = inputs.get('video_grid_thw')
293
+
294
+ inputs_embeds = _model.embed_tokens(input_ids)
295
+
296
+ dtype = model.visual.get_dtype() if self.version == 'v2' else model.visual.dtype
297
+ if pixel_values is None and pixel_values_videos is None: # plain-text
298
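+ # Text-only batch: under DeepSpeed, still run the vision tower on a dummy 32x32 black image and add a zero-weighted term so its parameters stay in the graph.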
+ if is_deepspeed_enabled():
299
+ from PIL import Image
300
+ images = [Image.new('RGB', (32, 32), (0, 0, 0))]
301
+ media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
302
+ device = input_ids.device
303
+ media_inputs = to_device(media_inputs, device)
304
+ pixel_values = media_inputs['pixel_values'].type(dtype)
305
+ image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
306
+ inputs_embeds += image_embeds.mean() * 0.
307
+ else:
308
+ if pixel_values is not None:
309
+ pixel_values = pixel_values.type(dtype)
310
+ image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
311
+ image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
312
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
313
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
314
+
315
+ if pixel_values_videos is not None:
316
+ pixel_values_videos = pixel_values_videos.type(dtype)
317
+ video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
318
+ video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
319
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
320
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
321
+
322
+ return {'inputs_embeds': inputs_embeds}
323
+
324
+ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
325
+ res = super()._data_collator_mm_data(batch)
326
+ second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
327
+ if second_per_grid_ts:
328
+ res['second_per_grid_ts'] = second_per_grid_ts
329
+ for media_type in ['image', 'video']:
330
+ grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
331
+ if grid_thw is not None:
332
+ res[f'{media_type}_grid_thw'] = grid_thw
333
+ return res
334
+
335
+ def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
336
+ position_ids = []
337
+ for r in row:
338
+ r = r[0].copy()
339
+ r['input_ids'] = torch.tensor(r['input_ids'])[None]
340
+ position_ids.append(self._get_position_ids(r))
341
+ packed = super().packing_row(row)
342
+ packed['real_position_ids'] = torch.concat(position_ids, dim=-1)
343
+ return packed
344
+
345
+ def _get_position_ids(self, inputs: Dict[str, Any]):
346
+ # fix https://github.com/huggingface/transformers/pull/33487
347
+ kwargs = {}
348
+ if self.version == 'v2_5':
349
+ kwargs = {'second_per_grid_ts': inputs.get('second_per_grid_ts')}
350
+ position_ids, _ = self.model.get_rope_index(
351
+ inputs['input_ids'],
352
+ inputs.get('image_grid_thw'),
353
+ inputs.get('video_grid_thw'),
354
+ attention_mask=inputs.get('attention_mask'),
355
+ **kwargs)
356
+ return position_ids.contiguous()
357
+
358
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
359
+ res = super()._data_collator(batch, padding_to=padding_to)
360
+ if self._packing:
361
+ res['real_position_ids'] = self.concat_tensor(batch, 'real_position_ids', -1)
362
+ elif self.is_training:
363
+ res['position_ids'] = self._get_position_ids(res)
364
+ return res
365
+
366
+
367
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_vl, template_cls=Qwen2VLTemplate))
368
+
369
+ register_template(
370
+ QwenTemplateMeta(
371
+ MLLMTemplateType.qvq,
372
+ default_system=('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
373
+ 'Answer in the language of the question. You should think step-by-step.'),
374
+ template_cls=Qwen2VLTemplate,
375
+ ))
376
+
377
+
378
+ class Qwen2_5VLTemplate(Qwen2VLTemplate):
379
+ version = 'v2_5'
380
+ norm_bbox = 'none'
381
+
382
+
383
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_vl, template_cls=Qwen2_5VLTemplate))
384
+
385
+
386
+ class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
387
+ version = 'omni'
388
+ placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']
389
+
390
+ def __init__(self, *args, **kwargs):
391
+ super().__init__(*args, **kwargs)
392
+ from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
393
+ default = Qwen2_5OmniProcessorKwargs._defaults
394
+ self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
395
+ self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
396
+ self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
397
+ self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
398
+
399
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
400
+ inputs: StdTemplateInputs) -> List[Context]:
401
+ from qwen_omni_utils import fetch_image, fetch_video
402
+ if media_type == 'image':
403
+ inputs.images[index] = fetch_image({'image': inputs.images[index]})
404
+ return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
405
+ elif media_type == 'audio':
406
+ inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
407
+ return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
408
+ elif media_type == 'video':
409
+ video = inputs.videos[index]
410
+ inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
411
+             if self.use_audio_in_video:
+                 import librosa
+                 if video.startswith('http://') or video.startswith('https://'):
+                     import audioread
+                     video = audioread.ffdec.FFmpegAudioFile(video)
+                 video = librosa.load(video, sr=self.sampling_rate)[0]
+                 inputs.audios.insert(inputs.audio_idx, (video, 'video'))
+                 inputs.audio_idx += 1
+                 return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
+             return ['<|vision_bos|><|VIDEO|><|vision_eos|>']
+
+     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+         encoded = Template._encode(self, inputs)
+         processor = self.processor
+         video_audios_mask = []
+         for i, audio in enumerate(inputs.audios):
+             if isinstance(audio, tuple) and audio[1] == 'video':
+                 inputs.audios[i] = audio[0]
+                 video_audios_mask.append(True)
+             else:
+                 video_audios_mask.append(False)
+         video_audios_mask = torch.tensor(video_audios_mask)
+         media_inputs = processor(
+             text='',
+             audio=inputs.audios or None,
+             images=inputs.images or None,
+             videos=inputs.videos or None,
+             return_tensors='pt')
+         media_inputs.pop('input_ids')
+         media_inputs.pop('attention_mask')
+         media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
+         input_ids = encoded['input_ids']
+         labels = encoded['labels']
+         # audio
+         audio_token_id = self._tokenize('<|AUDIO|>')
+         idx_list = findall(input_ids, audio_token_id)
+         feature_attention_mask = media_inputs.get('feature_attention_mask')
+         if feature_attention_mask is not None:
+             audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
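+             # Map the padded mel-frame counts to the number of <|AUDIO|> placeholder tokens: two
+             # stride-2 reductions, i.e. roughly one token per 4 feature frames.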
+             audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
+         else:
+             audio_lengths = None
+         audio_lengths_origin = audio_lengths
+         if idx_list:
+             if self.use_audio_in_video:
+                 audio_lengths = audio_lengths[~video_audios_mask]
+
+             def _get_new_audio_tokens(i):
+                 return audio_token_id * audio_lengths[i]
+
+             input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)
+
+         for media_type in ['image', 'video']:
+             token = f'<|{media_type.upper()}|>'
+             token_id = self._tokenize(token)
+             idx_list = findall(input_ids, token_id)
+             if idx_list:
+                 merge_size = processor.image_processor.merge_size
+                 media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
+                 if media_type == 'video' and self.use_audio_in_video:
+                     audio_lengths = audio_lengths_origin[video_audios_mask]
+                     video_second_per_grid = media_inputs['video_second_per_grid']
+
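+                     # Interleave the video and audio placeholder tokens chunk by chunk in timestamp
+                     # order, matching the layout Qwen2.5-Omni expects when use_audio_in_video is set.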
+                     def _get_new_tokens_use_audio_in_video(i):
+                         audio_token_indices = torch.arange(audio_lengths[i])
+                         grid_thw = media_grid_thw[i]
+                         height = grid_thw[1] // merge_size
+                         width = grid_thw[2] // merge_size
+                         video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
+                         video_token_indices = torch.broadcast_to(
+                             video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
+                         video_token_indices = (
+                             video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
+                         tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
+                         video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
+                         audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)
+
+                         res = []
+                         for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
+                             if j < len(video_chunk_indexes):
+                                 video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
+                                 res += token_id * video_seq_length
+                             if j < len(audio_chunk_indexes):
+                                 audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
+                                 res += audio_token_id * audio_seq_length
+                         return res
+
+                     input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
+                                                             _get_new_tokens_use_audio_in_video)
+
+                 else:
+
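+                     # One placeholder token per merged patch: t * h * w grid cells divided by merge_size ** 2.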
+                     def _get_new_tokens(i):
+                         token_len = (media_grid_thw[i].prod() // (merge_size**2))
+                         return token_id * token_len
+
+                     input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+
+         encoded['input_ids'] = input_ids
+         encoded['labels'] = labels
+         encoded.update(media_inputs)
+         return encoded
+
+     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+         return Template._post_encode(self, model, inputs)
+
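+     # Build the multimodal (M-RoPE) position ids with the thinker's get_rope_index, taking the image/video
+     # grids, optional in-video audio and per-grid video timing into account.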
+     def _get_position_ids(self, inputs: Dict[str, Any]):
+         feature_attention_mask = inputs.get('feature_attention_mask')
+         if feature_attention_mask is not None:
+             audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
+         else:
+             audio_feature_lengths = None
+         video_second_per_grid = inputs.pop('video_second_per_grid', None)
+         input_ids = inputs['input_ids']
+         attention_mask = inputs.get('attention_mask')
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         position_ids, _ = self.model.thinker.get_rope_index(
+             input_ids,
+             inputs.get('image_grid_thw'),
+             inputs.get('video_grid_thw'),
+             attention_mask,
+             self.use_audio_in_video,
+             audio_feature_lengths,
+             video_second_per_grid,
+         )
+         return position_ids.contiguous()
+
+     def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+         res = super()._data_collator_mm_data(batch)
+         video_second_per_grid = self.gather_list(batch, 'video_second_per_grid')
+         if video_second_per_grid:
+             res['video_second_per_grid'] = video_second_per_grid
+         input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
+         feature_attention_mask = [
+             b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
+         ]
+         if input_features:
+             res['input_features'] = torch.concat(input_features)
+             res['feature_attention_mask'] = torch.concat(feature_attention_mask)
+         return res
+
+     def generate(self, model, *args, **kwargs):
+         if kwargs.get('video_grid_thw') is not None:
+             kwargs['use_audio_in_video'] = self.use_audio_in_video
+         return super().generate(model, *args, **kwargs)
+
+
+ register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_omni, template_cls=Qwen2_5OmniTemplate))
+
+
+ class Ovis1_6Template(Template):
+     skip_prompt = False
+     use_model = True
+
+     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                     inputs: StdTemplateInputs) -> List[Context]:
+         assert media_type == 'image'
+         return [[-200], '\n']
+
+     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+         encoded = super()._encode(inputs)
+         images = inputs.images
+         input_ids = encoded['input_ids']
+         labels = encoded['labels']
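+         # -200 is the temporary per-image placeholder inserted by replace_tag; expand each occurrence into
+         # the visual tokenizer's real image placeholder tokens and mask those positions out of the labels.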
+         idx_list = findall(input_ids, [-200])
+         added_tokens_len = 0
+         pixel_values = []
+         for i, idx in enumerate(idx_list):
+             max_partition = get_env_args('max_partition', int, 9)
+             raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image(
+                 images[i], max_partition=max_partition)
+             input_ids = input_ids[:idx] + image_placeholders + input_ids[idx + 1:]
+             if labels is not None:
+                 labels = labels[:idx] + [-100] * len(image_placeholders) + labels[idx + 1:]
+             pixel_values.append(raw_pixel_values)
+             added_tokens_len += len(image_placeholders) - 1
+         dtype = self.model.visual_tokenizer.dtype
+         if pixel_values:
+             pixel_values = torch.cat(pixel_values, dim=0).to(dtype)
+         else:
+             pixel_values = torch.zeros((1, 3, 384, 384), dtype=dtype)  # dummy tensor for image-free samples
+         encoded.update({'input_ids': input_ids, 'labels': labels})
+         encoded['pixel_values'] = [pixel_values]
+         return encoded
+
+     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+         padding_side = self.padding_side if self.is_training else 'left'
+         if self.max_length is not None:
+             model.config.multimodal_max_length = self.max_length
+         input_ids = inputs['input_ids']
+         labels = inputs.get('labels')
+         if labels is None:
+             labels = input_ids.new_full(input_ids.shape, -100)
+         _, inputs_embeds, labels, attention_mask = model.merge_multimodal(
+             text_input_ids=input_ids,
+             text_attention_masks=torch.ones_like(input_ids),  # not used; kept for interface compatibility
+             text_labels=labels,
+             pixel_values=inputs['pixel_values'],
+             left_padding=padding_side == 'left')
+         if inputs.get('labels') is None:
+             labels = None
+         return {'inputs_embeds': inputs_embeds, 'labels': labels, 'attention_mask': attention_mask}
+
+     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+         pixel_values = self.gather_list(batch, 'pixel_values')
+         res = super()._data_collator(batch, padding_to=padding_to)
+         res['pixel_values'] = pixel_values
+         return res
+
+
+ register_template(
+     TemplateMeta(
+         MLLMTemplateType.ovis1_6,
+         prefix=['<bos>'],
+         prompt=['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
+         chat_sep=['<end_of_turn>\n'],
+         suffix=['<end_of_turn>'],
+         system_prefix=['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'],
+         template_cls=Ovis1_6Template,
+     ))
+
+ register_template(
+     Llama3TemplateMeta(
+         MLLMTemplateType.ovis1_6_llama3,
+         default_system='You are a helpful and honest multimodal assistant.',
+         template_cls=Ovis1_6Template,
+     ))
+
+
+ class Ovis2Template(Ovis1_6Template):
+     placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
+     nframes = 12
+
+     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                     inputs: StdTemplateInputs) -> List[Context]:
+         if media_type == 'image':
+             return [[-200], '\n']
+         elif media_type == 'video':
+             nframes = get_env_args('nframes', int, self.nframes)
+             inputs.images = load_video_ovis2(inputs.videos[index], nframes)
+             return [[-200] * nframes, '\n']
+
+
+ register_template(QwenTemplateMeta(
+     MLLMTemplateType.ovis2,
+     template_cls=Ovis2Template,
+ ))
+
+
+ @dataclass
+ class MarcoO1TemplateMeta(QwenTemplateMeta):
+     default_system: Optional[str] = """
+ 你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.
+ \n## 重要!!!!!
+ 当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。
+ <Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。
+ """
+
+
+ register_template(MarcoO1TemplateMeta(LLMTemplateType.marco_o1))
swift/llm/template/template/stepfun.py ADDED
@@ -0,0 +1,128 @@
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ from typing import Any, Dict, List, Literal, Optional
+
+ from ..base import Template
+ from ..constant import MLLMTemplateType
+ from ..register import TemplateMeta, register_template
+ from ..template_inputs import StdTemplateInputs
+ from ..utils import Context
+ from ..vision_utils import load_file
+ from .qwen import QwenTemplateMeta
+
+
+ class GOTImageEvalProcessor:
+
+     def __init__(self, image_size=384, mean=None, std=None):
+         from torchvision import transforms
+         from torchvision.transforms.functional import InterpolationMode
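+         # The defaults below are the standard CLIP image normalization mean/std.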
+         if mean is None:
+             mean = (0.48145466, 0.4578275, 0.40821073)
+         if std is None:
+             std = (0.26862954, 0.26130258, 0.27577711)
+
+         self.normalize = transforms.Normalize(mean, std)
+
+         self.transform = transforms.Compose([
+             transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             self.normalize,
+         ])
+
+     def __call__(self, item):
+         return self.transform(item)
+
+
+ class GOT_OCR2Template(Template):
+     placeholder_tokens = ['<imgpad>']
+
+     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                     inputs: StdTemplateInputs) -> List[Context]:
+         # Typical queries for this template are 'OCR: ' and 'OCR with format: '.
+         assert media_type == 'image'
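+         # Each image is rendered as a fixed block of 256 <imgpad> tokens wrapped in <img> ... </img>.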
+         return ['<img>' + '<imgpad>' * 256 + '</img>\n']
+
+     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+         encoded = super()._encode(inputs)
+         images = inputs.images
+         image_processor_high = GOTImageEvalProcessor(image_size=1024)
+         for i, image in enumerate(images):
+             images[i] = image_processor_high(image)[None].to(self.model_info.torch_dtype)
+         if images:
+             encoded['images'] = images
+         return encoded
+
+     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+         res = super()._data_collator(batch, padding_to=padding_to)
+         images = self.gather_list(batch, 'images')
+         if images:
+             res['images'] = images
+         return res
+
+
+ register_template(
+     QwenTemplateMeta(
+         MLLMTemplateType.got_ocr2,
+         default_system=' You should follow the instructions carefully and explain your answers in detail.',
+         template_cls=GOT_OCR2Template,
+     ))
+
+
+ class GOT_OCR2HfTemplate(Template):
+     placeholder_tokens = ['<imgpad>']
+
+     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                     inputs: StdTemplateInputs) -> List[Context]:
+         # Typical queries for this template are 'OCR: ' and 'OCR with format: '.
+         assert media_type == 'image'
+         return ['<img>' + '<imgpad>' * 256 + '</img>\n']
+
+     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:  # for now this mirrors GOT_OCR2Template._encode above
+         encoded = super()._encode(inputs)
+         images = inputs.images
+         if images:
+             encoded['images'] = images
+         return encoded
+
+     def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+         res = super()._data_collator(batch, padding_to=padding_to)
+         images = self.gather_list(batch, 'images')
+         _inputs = self.processor(images, return_tensors='pt')
+         _inputs.pop('input_ids')  # does not contain the response, so it cannot be used for training
+         _inputs.pop('attention_mask')  # does not contain the response, so it cannot be used for training
+
+         res.update(_inputs.data)
+         return res
+
+
+ register_template(
+     QwenTemplateMeta(
+         MLLMTemplateType.got_ocr2_hf,
+         default_system=' You should follow the instructions carefully and explain your answers in detail.',
+         template_cls=GOT_OCR2HfTemplate,
+     ))
+
+
+ class StepAudioTemplate(Template):
+     use_model = True
+
+     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                     inputs: StdTemplateInputs) -> List[Context]:
+         assert media_type == 'audio', f'media_type: {media_type}'
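+         # Note: 'utils' below is assumed to be the Step-Audio repository's own utils module (it must be
+         # importable, e.g. when running from within the Step-Audio checkout), not part of this package.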
+         from utils import load_audio
+         audio_wav, sr = load_audio(load_file(inputs.audios[index]))
+         audio_tokens = self.model.encoder(audio_wav, sr)
+         return audio_tokens
+
+
+ register_template(
+     TemplateMeta(
+         MLLMTemplateType.step_audio,
+         template_cls=StepAudioTemplate,
+         prefix=['<s>'],
+         prompt=['<|BOT|>human\n{{QUERY}}<|EOT|><|BOT|>assistant\n'],
+         system_prefix=['<s><|BOT|>system\n{{SYSTEM}}<|EOT|>'],
+         chat_sep=['<|EOT|>'],
+         suffix=['<|EOT|>'],
+     ))