Add files using upload-large-folder tool
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
- swift/llm/dataset/preprocessor/extra.py +112 -0
- swift/llm/model/model/microsoft.py +234 -0
- swift/llm/model/model/minicpm.py +207 -0
- swift/llm/model/model/minimax.py +156 -0
- swift/llm/model/model/mistral.py +157 -0
- swift/llm/model/model/mllm.py +180 -0
- swift/llm/model/model/moonshot.py +47 -0
- swift/llm/model/model/mplug.py +142 -0
- swift/llm/model/model/openbuddy.py +85 -0
- swift/llm/model/model/qwen.py +899 -0
- swift/llm/model/model/skywork.py +89 -0
- swift/llm/model/model/stepfun.py +86 -0
- swift/llm/model/model/telechat.py +59 -0
- swift/llm/model/model/valley.py +82 -0
- swift/llm/model/model/yi.py +146 -0
- swift/llm/sampling/__init__.py +1 -0
- swift/llm/sampling/base.py +55 -0
- swift/llm/sampling/distill_sampler.py +148 -0
- swift/llm/sampling/mcts.py +400 -0
- swift/llm/sampling/sampling.py +106 -0
- swift/llm/sampling/utils.py +172 -0
- swift/llm/sampling/vanilla_sampler.py +216 -0
- swift/llm/template/__init__.py +10 -0
- swift/llm/template/base.py +1606 -0
- swift/llm/template/constant.py +194 -0
- swift/llm/template/grounding.py +91 -0
- swift/llm/template/register.py +57 -0
- swift/llm/template/template/__init__.py +2 -0
- swift/llm/template/template/deepseek.py +315 -0
- swift/llm/template/template/emu3.py +191 -0
- swift/llm/template/template/gemma.py +131 -0
- swift/llm/template/template/glm.py +293 -0
- swift/llm/template/template/idefics3.py +37 -0
- swift/llm/template/template/internlm.py +195 -0
- swift/llm/template/template/internvl.py +168 -0
- swift/llm/template/template/llama.py +213 -0
- swift/llm/template/template/llava.py +309 -0
- swift/llm/template/template/llm.py +274 -0
- swift/llm/template/template/megrez.py +93 -0
- swift/llm/template/template/microsoft.py +205 -0
- swift/llm/template/template/minicpm.py +229 -0
- swift/llm/template/template/minimax.py +112 -0
- swift/llm/template/template/mistral.py +61 -0
- swift/llm/template/template/molmo.py +68 -0
- swift/llm/template/template/moonshot.py +66 -0
- swift/llm/template/template/mplug.py +214 -0
- swift/llm/template/template/openbuddy.py +48 -0
- swift/llm/template/template/pixtral.py +59 -0
- swift/llm/template/template/qwen.py +671 -0
- swift/llm/template/template/stepfun.py +128 -0
swift/llm/dataset/preprocessor/extra.py
ADDED
@@ -0,0 +1,112 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional

import numpy as np

from .core import ResponsePreprocessor


class GroundingMixin:
    """This class offers prompts for the grounding task."""
    task_type: Optional[str] = None

    _grounding_language_mixin = [0.8, 0.2]
    _grounding_prompts = {
        'grounding': {
            'en': [('<ref-object>', '<bbox>'), ('The positions of <ref-object> is', '<bbox>'),
                   ('Find the positions of <ref-object>', '<bbox>'), ('Where is <ref-object>', '<bbox>'),
                   ('Find <ref-object>', '<bbox>'), ('Show me <ref-object>', '<bbox>'),
                   ('Detect <ref-object>', '<bbox>'), ('Locate <ref-object>', '<bbox>'),
                   ('Tell me the location of <ref-object>', '<bbox>'), ('Give the location of <ref-object>', '<bbox>'),
                   ('Provide the bounding box coordinate of <ref-object>', '<bbox>')],
            'zh': [('<ref-object>', '<bbox>'), ('<ref-object>的位置在图片中', '<bbox>'), ('<ref-object>在图片中', '<bbox>'),
                   ('<ref-object>在', '<bbox>'), ('找到<ref-object>的位置', '<bbox>'), ('<ref-object>在哪里', '<bbox>'),
                   ('提供<ref-object>的坐标位置', '<bbox>')]
        },
        'caption': {
            'en': [
                ('<bbox>', '<ref-object>'),
                ('The object at position <bbox>', '<ref-object>'),
                ('This <bbox> is', '<ref-object>'),
                ('What is the object at <bbox>', '<ref-object>'),
                ('Describe <bbox>', '<ref-object>'),
                ('<bbox> is', '<ref-object>'),
                ('The bounding box coordinate <bbox> contains', '<ref-object>'),
            ],
            'zh': [
                ('<bbox>', '<ref-object>'),
                ('<bbox>是什么', '<ref-object>'),
                ('<bbox>的位置包含', '<ref-object>'),
                ('描述<bbox>', '<ref-object>'),
                ('<bbox>中是', '<ref-object>'),
                ('坐标<bbox>描述了什么', '<ref-object>'),
                ('描述<bbox>中的事物', '<ref-object>'),
            ]
        },
    }

    def construct_grounding_prompt(self):
        # TODO: currently only supports mapping one bbox to one object
        lang = np.random.choice(['en', 'zh'], p=[0.8, 0.2])
        prompts = GroundingMixin._grounding_prompts[self.task_type][lang]
        query, response = prompts[np.random.choice(range(len(prompts)))]
        return query, response


class TextGenerationPreprocessor(ResponsePreprocessor):

    def __init__(self,
                 *,
                 prompt: str,
                 query_tag: str = '{{QUERY}}',
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        self.query_tag = query_tag
        self.prompt = prompt
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        row['query'] = self.prompt.replace(self.query_tag, row['query'])
        return super().preprocess(row)


class ClsGenerationPreprocessor(ResponsePreprocessor):

    def __init__(self,
                 labels: List[str],
                 *,
                 task: str,
                 is_pair_seq: bool = False,
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        self.labels = labels
        self.task = task
        self.is_pair_seq = is_pair_seq

        category = ', '.join(labels)
        self.sentence2_key = 'sentence2'
        self.label_key = 'label'
        if is_pair_seq:
            self.sentence_key = 'sentence1'
            inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}'
        else:
            self.sentence_key = 'sentence'
            inputs = 'Sentence: {sentence}'
        self.prompt = f"""Task: {task}
{inputs}
Category: {category}
Output:"""
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        label = row.pop(self.label_key, None)
        if label is None:
            return

        if self.is_pair_seq:
            query = self.prompt.format(sentence1=row.pop(self.sentence_key), sentence2=row.pop(self.sentence2_key))
        else:
            query = self.prompt.format(sentence=row.pop(self.sentence_key))
        row['query'] = query
        row['response'] = self.labels[int(label)]
        return super().preprocess(row)
```
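For orientation, here is a minimal usage sketch for `ClsGenerationPreprocessor` (not part of the commit). The labels and the sample row are made up, and `ResponsePreprocessor.preprocess` from `.core`, which this diff does not show, is assumed to pass the mapped row through.

```python
# Hypothetical usage sketch; 'negative'/'positive' and the sample row are
# illustrative, and ResponsePreprocessor's behavior is assumed, not shown here.
from swift.llm.dataset.preprocessor.extra import ClsGenerationPreprocessor

preprocessor = ClsGenerationPreprocessor(
    ['negative', 'positive'],  # labels, indexed by the integer 'label' field
    task='Sentiment classification',
    is_pair_seq=False)

row = {'sentence': 'The movie was a delight.', 'label': 1}
processed = preprocessor.preprocess(row)
# Expected shape of the result: row['query'] holds the filled prompt
# ("Task: ...\nSentence: ...\nCategory: negative, positive\nOutput:")
# and row['response'] == 'positive'.
```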
swift/llm/model/model/microsoft.py
ADDED
@@ -0,0 +1,234 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from types import MethodType
from typing import Any, Dict

from transformers import AutoConfig

from swift.llm import TemplateType
from swift.utils import get_device, get_env_args
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_ignore_check_imports, patch_output_clone
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo, use_submodel_func


def get_model_tokenizer_phi3_vision(model_dir: str,
                                    model_info: ModelInfo,
                                    model_kwargs: Dict[str, Any],
                                    load_model: bool = True,
                                    **kwargs):
    processor_kwargs = {}
    if 'num_crops' in kwargs:
        processor_kwargs['num_crops'] = get_env_args('num_crops', int, kwargs['num_crops'])
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, **processor_kwargs)
    model, tokenizer = get_model_tokenizer_with_flash_attn(
        model_dir, model_info, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs)

    if load_model:
        patch_output_clone(model.model.vision_embed_tokens.wte)

    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.phi3_vision,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3-vision-128k-instruct', 'microsoft/Phi-3-vision-128k-instruct'),
                Model('LLM-Research/Phi-3.5-vision-instruct', 'microsoft/Phi-3.5-vision-instruct'),
            ])
        ],
        TemplateType.phi3_vision,
        partial(get_model_tokenizer_phi3_vision, num_crops=4),
        architectures=['Phi3VForCausalLM'],
        model_arch=ModelArch.phi3_vision,
        requires=['transformers>=4.36'],
        tags=['vision'],
    ))


def get_model_tokenizer_phi4_multimodal(*args, **kwargs):
    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
    processor.audio_processor.audio_compression_rate = processor.audio_processor.compression_rate
    processor.audio_processor.audio_downsample_rate = processor.audio_processor.qformer_compression_rate
    processor.audio_processor.audio_feat_stride = processor.audio_processor.feat_stride
    del processor.audio_processor.feature_size
    del processor.audio_processor.sampling_rate
    del processor.audio_processor.padding_value
    del processor.__class__.chat_template
    processor.chat_template = None
    if model is not None:
        model.set_lora_adapter(['vision', 'speech'])
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.phi4_multimodal,
        [ModelGroup([
            Model('LLM-Research/Phi-4-multimodal-instruct', 'microsoft/Phi-4-multimodal-instruct'),
        ])],
        TemplateType.phi4_multimodal,
        get_model_tokenizer_phi4_multimodal,
        architectures=['Phi4MMForCausalLM'],
        model_arch=ModelArch.phi4_multimodal,
        requires=['transformers>=4.36,<4.49', 'backoff', 'soundfile'],
        tags=['vision', 'audio'],
    ))


def get_model_tokenizer_florence(model_dir: str,
                                 model_info: ModelInfo,
                                 model_kwargs: Dict[str, Any],
                                 load_model: bool = True,
                                 **kwargs):
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    model_config.vision_config.model_type = 'davit'  # fix merge-lora
    if model_kwargs['device_map'] == 'auto':
        model_kwargs['device_map'] = get_device()
    kwargs['model_config'] = model_config
    with patch_ignore_check_imports():
        model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

    if model is not None:
        model.vision_tower.enable_checkpoint = True
        use_submodel_func(model, 'language_model', ['generate', 'forward'])
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.florence,
        [
            ModelGroup([
                Model('AI-ModelScope/Florence-2-base-ft', 'microsoft/Florence-2-base-ft'),
                Model('AI-ModelScope/Florence-2-base', 'microsoft/Florence-2-base'),
                Model('AI-ModelScope/Florence-2-large', 'microsoft/Florence-2-large'),
                Model('AI-ModelScope/Florence-2-large-ft', 'microsoft/Florence-2-large-ft'),
            ]),
        ],
        TemplateType.florence,
        get_model_tokenizer_florence,
        architectures=['Florence2ForConditionalGeneration'],
        model_arch=ModelArch.florence,
        tags=['vision'],
    ))


def get_model_tokenizer_phi3_small(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)

    def rotary_emb(self, query_states, key_states, **kwargs):
        q_type = query_states.dtype
        k_type = key_states.dtype
        query_states, key_states = self.rotary_emb_origin(query_states, key_states, **kwargs)
        query_states = query_states.to(q_type)
        key_states = key_states.to(k_type)
        return query_states, key_states

    if model is not None:
        for i in range(32):
            re = model.model.layers[i].self_attn.rotary_emb
            re.rotary_emb_origin = re.forward
            re.forward = MethodType(rotary_emb, re)
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.phi3_small,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3-small-8k-instruct', 'microsoft/Phi-3-small-8k-instruct'),
                Model('LLM-Research/Phi-3-small-128k-instruct', 'microsoft/Phi-3-small-128k-instruct'),
            ]),
        ],
        TemplateType.phi3,
        get_model_tokenizer_phi3_small,
        architectures=['Phi3SmallForCausalLM'],
        model_arch=ModelArch.phi3_small,
        requires=['transformers>=4.36'],
    ))


def get_model_tokenizer_phi(model_dir: str,
                            model_info: ModelInfo,
                            model_kwargs: Dict[str, Any],
                            load_model: bool = True,
                            **kwargs):
    # TODO: check
    return get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)


register_model(
    ModelMeta(
        LLMModelType.phi2,
        [
            ModelGroup([
                Model('AI-ModelScope/phi-2', 'microsoft/phi-2'),
            ]),
        ],
        TemplateType.default,
        get_model_tokenizer_phi,
        architectures=['PhiForCausalLM'],
        model_arch=ModelArch.phi2,
    ))

register_model(
    ModelMeta(
        LLMModelType.phi3,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3-mini-4k-instruct', 'microsoft/Phi-3-mini-4k-instruct'),
                Model('LLM-Research/Phi-3-mini-128k-instruct', 'microsoft/Phi-3-mini-128k-instruct'),
                Model('LLM-Research/Phi-3-medium-4k-instruct', 'microsoft/Phi-3-medium-4k-instruct'),
                Model('LLM-Research/Phi-3-medium-128k-instruct', 'microsoft/Phi-3-medium-128k-instruct'),
                Model('LLM-Research/Phi-3.5-mini-instruct', 'microsoft/Phi-3.5-mini-instruct'),
            ]),
            ModelGroup([Model('LLM-Research/Phi-4-mini-instruct', 'microsoft/Phi-4-mini-instruct')])
        ],
        TemplateType.phi3,
        get_model_tokenizer_with_flash_attn,
        architectures=['Phi3ForCausalLM'],
        requires=['transformers>=4.36'],
        model_arch=ModelArch.phi3,
    ))

register_model(
    ModelMeta(
        LLMModelType.phi4,
        [
            ModelGroup([
                Model('LLM-Research/phi-4', 'microsoft/phi-4'),
            ]),
        ],
        TemplateType.phi4,
        get_model_tokenizer_with_flash_attn,
        architectures=['Phi3ForCausalLM'],
        requires=['transformers>=4.36'],
        model_arch=ModelArch.phi3,
    ))

register_model(
    ModelMeta(
        LLMModelType.phi3_moe,
        [
            ModelGroup([
                Model('LLM-Research/Phi-3.5-MoE-instruct', 'microsoft/Phi-3.5-MoE-instruct'),
            ]),
        ],
        TemplateType.phi3,
        get_model_tokenizer_with_flash_attn,
        architectures=['PhiMoEForCausalLM'],
        requires=['transformers>=4.36'],
    ))
```
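Once registered, these entries are resolved by model ID through ms-swift's loader. A sketch of loading one of the Phi-3 vision checkpoints above (it assumes ms-swift is installed, the checkpoint is downloadable, and that `NUM_CROPS` is the environment variable `get_env_args('num_crops', ...)` consults):

```python
# Sketch, not part of the commit: resolve a registered entry by ID. The
# NUM_CROPS env-var name is an assumption about get_env_args' convention.
import os

from swift.llm import get_model_tokenizer

os.environ['NUM_CROPS'] = '16'  # would override the partial's default of 4
model, processor = get_model_tokenizer('LLM-Research/Phi-3.5-vision-instruct')
```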
swift/llm/model/model/minicpm.py
ADDED
@@ -0,0 +1,207 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from types import MethodType
from typing import Any, Dict

from transformers import AutoConfig
from transformers.utils import strtobool

from swift.llm import TemplateType
from swift.utils import get_env_args
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_device_map, patch_fixed_device, patch_output_clone
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo, use_submodel_func
from .deepseek import get_model_tokenizer_deepseek_moe

register_model(
    ModelMeta(
        LLMModelType.minicpm_moe,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-MoE-8x2B', 'openbmb/MiniCPM-MoE-8x2B'),
            ]),
        ],
        TemplateType.minicpm,
        get_model_tokenizer_deepseek_moe,
        architectures=['MiniCPMForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.36'],
    ))


def _patch_minicpmv_device_map(model) -> None:
    if not hasattr(model, 'hf_device_map') or len(model.hf_device_map.values()) == 1:
        return

    device = list(model.hf_device_map.values())[0]
    if hasattr(model, 'get_vision_embedding') and not hasattr(model, '_old_get_vision_embedding'):
        # minicpm-v-v2-chat; avoid double patching
        _old_get_vision_embedding = model.__class__.get_vision_embedding

        def _get_vision_embedding(self, pixel_values):
            if len(pixel_values) == 0:
                return _old_get_vision_embedding(self, pixel_values)
            output = _old_get_vision_embedding(self, pixel_values)
            return output.to(device=device)

        model.__class__._old_get_vision_embedding = _old_get_vision_embedding
        model.__class__.get_vision_embedding = _get_vision_embedding

    if hasattr(model, 'resampler'):  # minicpm-v-v2_5-chat
        patch_fixed_device(model.resampler, device)


def get_model_tokenizer_minicpmv(model_dir: str,
                                 model_info: ModelInfo,
                                 model_kwargs: Dict[str, Any],
                                 load_model: bool = True,
                                 **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if load_model:
        model.resampler.to(model_info.torch_dtype)  # fix float32
        _patch_minicpmv_device_map(model)
        func_list = ['generate', 'get_input_embeddings', 'forward']
        use_submodel_func(model, 'llm', func_list)
        if hasattr(model, 'get_slice_image_placeholder'):
            tokenizer.get_slice_image_placeholder = MethodType(model.get_slice_image_placeholder, tokenizer)
            tokenizer.transform = MethodType(model.transform, tokenizer)
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.minicpmv,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-V', 'openbmb/MiniCPM-V'),
                Model('OpenBMB/MiniCPM-V-2', 'openbmb/MiniCPM-V-2'),
            ]),
        ],
        TemplateType.minicpmv,
        get_model_tokenizer_minicpmv,
        architectures=['MiniCPMV'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers<4.42'],
        tags=['vision'],
    ))


def get_model_tokenizer_minicpmv_2_x(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    version = kwargs.get('version')
    if version == 'o2.6':
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        model_config.init_tts = strtobool(get_env_args('init_tts', str, 'false'))
        model_config.init_audio = strtobool(get_env_args('init_audio', str, 'false'))
        kwargs['model_config'] = model_config
    with patch_device_map():
        model, tokenizer = get_model_tokenizer_minicpmv(
            model_dir, model_info, model_kwargs, load_model, tokenizer=processor.tokenizer, **kwargs)
    if load_model:
        embedding = model.get_input_embeddings()
        patch_output_clone(embedding)

    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.minicpmv2_5,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-Llama3-V-2_5', 'openbmb/MiniCPM-Llama3-V-2_5'),
            ]),
        ],
        TemplateType.minicpmv2_5,
        get_model_tokenizer_minicpmv_2_x,
        architectures=['MiniCPMV'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers>=4.36'],
        tags=['vision'],
    ))

register_model(
    ModelMeta(
        MLLMModelType.minicpmv2_6,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-V-2_6', 'openbmb/MiniCPM-V-2_6'),
            ]),
        ],
        TemplateType.minicpmv2_6,
        get_model_tokenizer_minicpmv_2_x,
        architectures=['MiniCPMV'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers>=4.36', 'decord'],
        tags=['vision', 'video'],
    ))

register_model(
    ModelMeta(
        MLLMModelType.minicpmo2_6,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-o-2_6', 'openbmb/MiniCPM-o-2_6'),
            ]),
        ],
        TemplateType.minicpmo2_6,
        partial(get_model_tokenizer_minicpmv_2_x, version='o2.6'),
        architectures=['MiniCPMO'],
        model_arch=ModelArch.minicpmv,
        requires=['timm', 'transformers>=4.36', 'decord', 'soundfile'],
        tags=['vision', 'video', 'omni', 'audio'],
    ))

register_model(
    ModelMeta(
        LLMModelType.minicpm,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-2B-sft-fp32', 'openbmb/MiniCPM-2B-sft-fp32'),
                Model('OpenBMB/MiniCPM-2B-dpo-fp32', 'openbmb/MiniCPM-2B-dpo-fp32'),
                Model('OpenBMB/MiniCPM-1B-sft-bf16', 'openbmb/MiniCPM-1B-sft-bf16'),
            ]),
        ],
        TemplateType.minicpm,
        get_model_tokenizer_with_flash_attn,
        architectures=['MiniCPMForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.36.0'],
    ))

register_model(
    ModelMeta(
        LLMModelType.minicpm_chatml,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM-2B-128k', 'openbmb/MiniCPM-2B-128k'),
            ]),
        ],
        TemplateType.chatml,
        get_model_tokenizer_with_flash_attn,
        architectures=['MiniCPMForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.36'],
    ))

register_model(
    ModelMeta(
        LLMModelType.minicpm3,
        [
            ModelGroup([
                Model('OpenBMB/MiniCPM3-4B', 'openbmb/MiniCPM3-4B'),
            ]),
        ],
        TemplateType.chatml,
        get_model_tokenizer_with_flash_attn,
        architectures=['MiniCPM3ForCausalLM'],
        model_arch=ModelArch.deepseek_v2,
        requires=['transformers>=4.36'],
    ))
```
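The `MethodType` trick used in `_patch_minicpmv_device_map` and `get_model_tokenizer_minicpmv` (binding a plain function to a live object so it replaces a method, while the original stays reachable) is easiest to see in isolation. A self-contained toy version, where `Resampler` is a stand-in class rather than the MiniCPM-V module:

```python
# Standalone illustration of the bound-method patching pattern above.
from types import MethodType


class Resampler:  # stand-in, not the real MiniCPM-V resampler

    def forward(self, x):
        return x + 1


module = Resampler()
module._old_forward = module.forward  # keep the original bound method


def patched_forward(self, x):
    # Delegate to the saved original, then post-process, just as the
    # _get_vision_embedding wrapper does with .to(device=...).
    return self._old_forward(x) * 2


module.forward = MethodType(patched_forward, module)  # rebind on the instance
print(module.forward(3))  # -> 8
```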
swift/llm/model/model/minimax.py
ADDED
@@ -0,0 +1,156 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import json
from transformers import AutoConfig, AutoProcessor
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from swift.utils import get_device, get_device_count, get_dist_setting, get_logger
from ..constant import LLMModelType, MLLMModelType
from ..patcher import patch_ignore_check_imports
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo

logger = get_logger()


def get_model_tokenizer_minimax_vl(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    logger.warning('NOTE: minimax-vl-01 model does not support training.')
    n_gpu = get_device_count()
    _, local_rank, _, local_world_size = get_dist_setting()
    device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['model_config'] = config
    if kwargs.get('attn_impl') == 'flash_attn':
        config.attn_type_list = [1] * len(config.attn_type_list)
    else:
        config.attn_type_list = [0] * len(config.attn_type_list)
    if 'quantization_config' in model_kwargs:
        quantization_config = model_kwargs['quantization_config']
        from transformers import QuantoConfig
        if isinstance(quantization_config, QuantoConfig):
            quantization_config.modules_to_not_convert = (
                [
                    'vision_tower',
                    'image_newline',
                    'multi_modal_projector',
                    'lm_head',
                    'embed_tokens',
                ] + [f'model.layers.{i}.coefficient' for i in range(config.text_config.num_hidden_layers)]
                + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.text_config.num_hidden_layers)])

    if len(device_ids) > 1:
        model_safetensors_index_path = os.path.join(model_dir, 'model.safetensors.index.json')
        with open(model_safetensors_index_path, 'r') as f:
            model_safetensors_index = json.load(f)
        weight_map = model_safetensors_index['weight_map']
        vision_map = {}
        for key, value in weight_map.items():
            if 'vision_tower' in key or 'image_newline' in key or 'multi_modal_projector' in key:
                new_key = key.replace('.weight', '').replace('.bias', '')
                if new_key not in vision_map:
                    vision_map[new_key] = value

        device_map = {
            'language_model.model.embed_tokens': get_device(device_ids[0]),
            'language_model.model.norm': get_device(device_ids[len(device_ids) - 1]),
            'language_model.lm_head': get_device(device_ids[len(device_ids) - 1])
        }
        for key, value in vision_map.items():
            device_map[key] = get_device(device_ids[0])
        device_map['vision_tower.vision_model.post_layernorm'] = get_device(device_ids[0])
        layers_per_device = config.text_config.num_hidden_layers // len(device_ids)
        for i in range(len(device_ids)):
            for j in range(layers_per_device):
                device_map[f'language_model.model.layers.{i * layers_per_device + j}'] = get_device(device_ids[i])
        model_kwargs['device_map'] = device_map

    MiniMaxVL01ProcessorKwargs = get_class_from_dynamic_module('processing_minimax_vl_01.MiniMaxVL01ProcessorKwargs',
                                                               model_dir)
    get_hw_multiple_of = get_class_from_dynamic_module('processing_minimax_vl_01.get_hw_multiple_of', model_dir)
    get_num_token = get_class_from_dynamic_module('processing_minimax_vl_01.get_num_token', model_dir)

    processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
    processor.MiniMaxVL01ProcessorKwargs = MiniMaxVL01ProcessorKwargs
    processor.get_hw_multiple_of = get_hw_multiple_of
    processor.get_num_token = get_num_token
    with patch_ignore_check_imports():
        model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
                                                               **kwargs)
    processor.tokenizer = tokenizer
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.minimax_vl, [
            ModelGroup([
                Model('MiniMax/MiniMax-VL-01', 'MiniMaxAI/MiniMax-VL-01'),
            ]),
        ],
        TemplateType.minimax_vl,
        get_model_tokenizer_minimax_vl,
        architectures=['MiniMaxVL01ForConditionalGeneration'],
        tags=['vision']))


def get_model_tokenizer_minimax_text(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    logger.warning('NOTE: minimax-text-01 model does not support training.')
    n_gpu = get_device_count()
    _, local_rank, _, local_world_size = get_dist_setting()
    device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['model_config'] = config
    if kwargs.get('attn_impl') == 'flash_attn':
        config.attn_type_list = [1] * len(config.attn_type_list)
    else:
        config.attn_type_list = [0] * len(config.attn_type_list)
    if 'quantization_config' in model_kwargs:
        quantization_config = model_kwargs['quantization_config']
        from transformers import QuantoConfig
        if isinstance(quantization_config, QuantoConfig):
            quantization_config.modules_to_not_convert = (
                [
                    'lm_head',
                    'embed_tokens',
                ] + [f'model.layers.{i}.coefficient' for i in range(config.num_hidden_layers)]
                + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.num_hidden_layers)])

    if len(device_ids) > 1:
        layers_per_device = config.num_hidden_layers // len(device_ids)
        # set device map
        device_map = {
            'model.embed_tokens': get_device(0),
            'model.norm': get_device(len(device_ids) - 1),
            'lm_head': get_device(len(device_ids) - 1)
        }
        for i in range(len(device_ids)):
            for j in range(layers_per_device):
                device_map[f'model.layers.{i * layers_per_device + j}'] = get_device(i)
        model_kwargs['device_map'] = device_map
    with patch_ignore_check_imports():
        model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
                                                               **kwargs)
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.minimax, [
            ModelGroup([
                Model('MiniMax/MiniMax-Text-01', 'MiniMaxAI/MiniMax-Text-01'),
            ]),
        ],
        TemplateType.minimax,
        get_model_tokenizer_minimax_text,
        architectures=['MiniMaxText01ForCausalLM']))
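The manual `device_map` built for multi-GPU loading splits the decoder stack evenly, pinning the embeddings to the first device and the final norm and head to the last. The arithmetic in isolation, as a sketch with made-up layer and GPU counts:

```python
# Toy reproduction of the layer-sharding arithmetic above, with 8 layers across
# 2 hypothetical GPUs; get_device(i) in the real code returns a 'cuda:i'-style id.
num_hidden_layers = 8
device_ids = [0, 1]

layers_per_device = num_hidden_layers // len(device_ids)
device_map = {
    'model.embed_tokens': f'cuda:{device_ids[0]}',
    'model.norm': f'cuda:{device_ids[-1]}',
    'lm_head': f'cuda:{device_ids[-1]}',
}
for i, dev in enumerate(device_ids):
    for j in range(layers_per_device):
        device_map[f'model.layers.{i * layers_per_device + j}'] = f'cuda:{dev}'

# layers 0-3 land on cuda:0, layers 4-7 on cuda:1; note that with floor
# division, any remainder layers would be left unassigned by this scheme.
print(device_map)
```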
swift/llm/model/model/mistral.py
ADDED
@@ -0,0 +1,157 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from swift.llm import TemplateType
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo

register_model(
    ModelMeta(
        LLMModelType.mistral,
        [
            ModelGroup([
                Model('AI-ModelScope/Mistral-7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1'),
                Model('AI-ModelScope/Mistral-7B-Instruct-v0.2', 'mistralai/Mistral-7B-Instruct-v0.2'),
                Model('LLM-Research/Mistral-7B-Instruct-v0.3', 'mistralai/Mistral-7B-Instruct-v0.3'),
                Model('AI-ModelScope/Mistral-7B-v0.1', 'mistralai/Mistral-7B-v0.1'),
                Model('AI-ModelScope/Mistral-7B-v0.2-hf', 'alpindale/Mistral-7B-v0.2-hf'),
            ]),
            ModelGroup([
                Model('swift/Codestral-22B-v0.1', 'mistralai/Codestral-22B-v0.1'),
            ]),
        ],
        TemplateType.llama,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        model_arch=ModelArch.llama,
        requires=['transformers>=4.34'],
    ))

register_model(
    ModelMeta(
        LLMModelType.mixtral, [
            ModelGroup([
                Model('AI-ModelScope/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mixtral-8x7B-Instruct-v0.1'),
                Model('AI-ModelScope/Mixtral-8x7B-v0.1', 'mistralai/Mixtral-8x7B-v0.1'),
                Model('AI-ModelScope/Mixtral-8x22B-v0.1', 'mistral-community/Mixtral-8x22B-v0.1'),
            ], requires=['transformers>=4.36']),
            ModelGroup([
                Model('AI-ModelScope/Mixtral-8x7b-AQLM-2Bit-1x16-hf', 'ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf'),
            ], requires=['transformers>=4.38', 'aqlm', 'torch>=2.2.0']),
        ],
        TemplateType.llama,
        get_model_tokenizer_with_flash_attn,
        architectures=['MixtralForCausalLM'],
        model_arch=ModelArch.llama))

register_model(
    ModelMeta(
        LLMModelType.mistral_nemo, [
            ModelGroup([
                Model('AI-ModelScope/Mistral-Small-Instruct-2409', 'mistralai/Mistral-Small-Instruct-2409'),
                Model('LLM-Research/Mistral-Large-Instruct-2407', 'mistralai/Mistral-Large-Instruct-2407'),
                Model('AI-ModelScope/Mistral-Nemo-Base-2407', 'mistralai/Mistral-Nemo-Base-2407'),
                Model('AI-ModelScope/Mistral-Nemo-Instruct-2407', 'mistralai/Mistral-Nemo-Instruct-2407'),
            ], requires=['transformers>=4.43']),
            ModelGroup([
                Model('AI-ModelScope/Ministral-8B-Instruct-2410', 'mistralai/Ministral-8B-Instruct-2410'),
            ], requires=['transformers>=4.46']),
        ],
        TemplateType.mistral_nemo,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        model_arch=ModelArch.llama))

register_model(
    ModelMeta(
        LLMModelType.mistral_2501, [
            ModelGroup([
                Model('mistralai/Mistral-Small-24B-Base-2501', 'mistralai/Mistral-Small-24B-Base-2501'),
                Model('mistralai/Mistral-Small-24B-Instruct-2501', 'mistralai/Mistral-Small-24B-Instruct-2501'),
            ]),
        ],
        TemplateType.mistral_2501,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        model_arch=ModelArch.llama))

register_model(
    ModelMeta(
        LLMModelType.zephyr,
        [
            ModelGroup([
                Model('modelscope/zephyr-7b-beta', 'HuggingFaceH4/zephyr-7b-beta'),
            ]),
        ],
        TemplateType.zephyr,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['MistralForCausalLM'],
        requires=['transformers>=4.34'],
    ))

register_model(
    ModelMeta(
        LLMModelType.wizardlm2_moe,
        [ModelGroup([
            Model('AI-ModelScope/WizardLM-2-8x22B', 'alpindale/WizardLM-2-8x22B'),
        ])],
        TemplateType.wizardlm2_moe,
        get_model_tokenizer_with_flash_attn,
        architectures=['MixtralForCausalLM'],
        requires=['transformers>=4.36'],
    ))

register_model(
    ModelMeta(
        LLMModelType.wizardlm2,
        [ModelGroup([
            Model('AI-ModelScope/WizardLM-2-7B-AWQ', 'MaziyarPanahi/WizardLM-2-7B-AWQ'),
        ])],
        TemplateType.wizardlm2,
        get_model_tokenizer_with_flash_attn,
        architectures=['MistralForCausalLM'],
        requires=['transformers>=4.34'],
    ))


def get_model_tokenizer_mistral_2503(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    try:
        from transformers import Mistral3ForConditionalGeneration
    except ImportError:
        raise ImportError('Please install Mistral3ForConditionalGeneration by running '
                          '`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`')

    kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.mistral_2503,
        [
            ModelGroup([
                Model('mistralai/Mistral-Small-3.1-24B-Base-2503', 'mistralai/Mistral-Small-3.1-24B-Base-2503'),
                Model('mistralai/Mistral-Small-3.1-24B-Instruct-2503', 'mistralai/Mistral-Small-3.1-24B-Instruct-2503'),
            ]),
        ],
        TemplateType.mistral_2503,
        get_model_tokenizer_mistral_2503,
        architectures=['Mistral3ForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.49'],
    ))
```
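These metadata entries are consumed the same way as the others in this directory. A load sketch for one of the IDs above (it assumes ms-swift is installed and that `get_model_tokenizer` accepts `load_model`, as the loader functions in this commit do):

```python
# Sketch, not part of the commit: resolve a registered Mistral entry by its
# ModelScope ID; ms-swift picks the template and loader from the metadata above.
from swift.llm import get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    'AI-ModelScope/Mistral-Nemo-Instruct-2407',
    load_model=False)  # assumed to return (None, tokenizer) without instantiating weights
```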
swift/llm/model/model/mllm.py
ADDED
@@ -0,0 +1,180 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.
from types import MethodType
from typing import Any, Dict

import torch
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
from swift.llm.model.model.qwen import get_model_tokenizer_qwen2_vl
from swift.utils import get_logger
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_output_clone, patch_output_normalizer
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import ModelInfo, use_submodel_func

logger = get_logger()


def get_model_tokenizer_idefics(model_dir: str, *args, **kwargs):
    from transformers import AutoModelForVision2Seq
    kwargs['automodel_class'] = kwargs['automodel_class'] or AutoModelForVision2Seq
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.idefics3,
        [
            ModelGroup([
                Model('AI-ModelScope/Idefics3-8B-Llama3', 'HuggingFaceM4/Idefics3-8B-Llama3'),
            ]),
        ],
        TemplateType.idefics3,
        get_model_tokenizer_idefics,
        model_arch=ModelArch.idefics3,
        architectures=['Idefics3ForConditionalGeneration'],
        tags=['vision'],
        requires=['transformers>=4.45'],
    ))


def get_model_tokenizer_pixtral(model_dir: str, *args, **kwargs):
    from transformers import LlavaForConditionalGeneration
    kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaForConditionalGeneration
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.pixtral,
        [
            ModelGroup([
                Model('AI-ModelScope/pixtral-12b', 'mistral-community/pixtral-12b'),
            ]),
        ],
        TemplateType.pixtral,
        get_model_tokenizer_pixtral,
        model_arch=ModelArch.llava_hf,
        architectures=['LlavaForConditionalGeneration'],
        requires=['transformers>=4.45'],
        tags=['vision'],
    ))


def get_model_tokenizer_molmoe(model_dir: str,
                               model_info: ModelInfo,
                               model_kwargs: Dict[str, Any],
                               load_model: bool = True,
                               **kwargs):
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

    # fix bug for molmoe-1b
    def to_dict(self, *args, **kwargs):
        res = self._to_dict(*args, **kwargs)
        res['vision_backbone'] = self.vision_backbone.__dict__
        res.pop('to_dict')
        res.pop('_to_dict')
        return res

    if model is not None:
        model.config._to_dict = model.config.to_dict
        model.config.to_dict = MethodType(to_dict, model.config)
        patch_output_clone(model.model.transformer.wte)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.molmoe,
        [
            ModelGroup([
                Model('LLM-Research/MolmoE-1B-0924', 'allenai/MolmoE-1B-0924'),
            ]),
        ],
        TemplateType.molmo,
        get_model_tokenizer_molmoe,
        model_arch=ModelArch.molmo,
        torch_dtype=torch.float32,
        architectures=['OLMoForCausalLM'],
        tags=['vision'],
        requires=['transformers>=4.45'],
    ))


def get_model_tokenizer_molmo(model_dir: str,
                              model_info: ModelInfo,
                              model_kwargs: Dict[str, Any],
                              load_model: bool = True,
                              **kwargs):
    model_cls = get_class_from_dynamic_module('modeling_molmo.MolmoForCausalLM', model_dir)
    model_cls._no_split_modules = ['MolmoSequentialBlock']
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if model is not None:
        patch_output_clone(model.model.transformer.wte)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.molmo,
        [
            ModelGroup([
                Model('LLM-Research/Molmo-7B-O-0924', 'allenai/Molmo-7B-O-0924'),
                Model('LLM-Research/Molmo-7B-D-0924', 'allenai/Molmo-7B-D-0924'),
                Model('LLM-Research/Molmo-72B-0924', 'allenai/Molmo-72B-0924'),
            ]),
        ],
        TemplateType.molmo,
        get_model_tokenizer_molmo,
        model_arch=ModelArch.molmo,
        architectures=['MolmoForCausalLM'],
        tags=['vision'],
        requires=['transformers>=4.45'],
    ))


def get_model_tokenizer_megrez_omni(model_dir, *args, **kwargs):
    model_cls = get_class_from_dynamic_module('modeling_megrezo.MegrezO', model_dir)
    model_cls._no_split_modules = ['ResidualAttentionBlock', 'LlamaDecoderLayer']
    model_cls = get_class_from_dynamic_module('modeling_megrezo.SiglipVisionTransformer', model_dir)
    model_cls._no_split_modules = ['SiglipEncoderLayer']
    model, processor = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
    processor = model._get_or_init_processor()
    patch_output_clone(model.llm.model.embed_tokens)
    use_submodel_func(model, 'llm')
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.megrez_omni,
        [
            ModelGroup([
                Model('InfiniAI/Megrez-3B-Omni', 'Infinigence/Megrez-3B-Omni'),
            ]),
        ],
        TemplateType.megrez_omni,
        get_model_tokenizer_megrez_omni,
        model_arch=ModelArch.megrez_omni,
        architectures=['MegrezO'],
        tags=['vision', 'audio'],
    ))

register_model(
    ModelMeta(
        MLLMModelType.qwen2_gme, [
            ModelGroup([
                Model('iic/gme-Qwen2-VL-2B-Instruct', 'Alibaba-NLP/gme-Qwen2-VL-2B-Instruct'),
                Model('iic/gme-Qwen2-VL-7B-Instruct', 'Alibaba-NLP/gme-Qwen2-VL-7B-Instruct'),
            ]),
        ],
        TemplateType.qwen2_gme,
        get_model_tokenizer_qwen2_vl,
        model_arch=ModelArch.qwen2_vl,
        architectures=['Qwen2VLForConditionalGeneration'],
        tags=['vision']))
```
swift/llm/model/model/moonshot.py
ADDED
@@ -0,0 +1,47 @@
```python
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import TemplateType
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..patcher import patch_output_clone
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)

register_model(
    ModelMeta(
        LLMModelType.moonlight,
        [
            ModelGroup([
                Model('moonshotai/Moonlight-16B-A3B', 'moonshotai/Moonlight-16B-A3B'),
                Model('moonshotai/Moonlight-16B-A3B-Instruct', 'moonshotai/Moonlight-16B-A3B-Instruct'),
            ]),
        ],
        TemplateType.moonlight,
        get_model_tokenizer_with_flash_attn,
        architectures=['DeepseekV3ForCausalLM'],
        model_arch=ModelArch.deepseek_v2,
        requires=['transformers<4.49'],
    ))


def get_model_tokenizer_kimi_vl(*args, **kwargs):
    model, processor = get_model_tokenizer_multimodal(*args, **kwargs)
    if model is not None:
        patch_output_clone(model.language_model.model.embed_tokens)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.kimi_vl,
        [
            ModelGroup([
                Model('moonshotai/Kimi-VL-A3B-Instruct', 'moonshotai/Kimi-VL-A3B-Instruct'),
                Model('moonshotai/Kimi-VL-A3B-Thinking', 'moonshotai/Kimi-VL-A3B-Thinking'),
            ])
        ],
        TemplateType.kimi_vl,
        get_model_tokenizer_kimi_vl,
        architectures=['KimiVLForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers<4.49'],
    ))
```
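`patch_output_clone` (from `..patcher`, whose implementation is not part of this diff) is applied to embedding layers throughout these files. Judging by its use, a forward hook that clones the module's output would achieve a similar effect; a minimal re-creation under that assumption:

```python
# Assumption-level re-creation of patch_output_clone using a standard PyTorch
# forward hook; the real implementation lives in swift's patcher module and
# may differ.
import torch


def _clone_output_hook(module, args, output):
    # Returning a non-None value from a forward hook replaces the output.
    return output.clone()


embed_tokens = torch.nn.Embedding(100, 16)
embed_tokens.register_forward_hook(_clone_output_hook)

out = embed_tokens(torch.tensor([1, 2, 3]))
print(out.shape)  # torch.Size([3, 16]), now a clone rather than a view
```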
swift/llm/model/model/mplug.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from functools import partial
|
| 5 |
+
from typing import Any, Dict
|
| 6 |
+
|
| 7 |
+
from transformers import AutoConfig
|
| 8 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 9 |
+
|
| 10 |
+
from swift.llm import TemplateType
|
| 11 |
+
from swift.utils import get_logger
|
| 12 |
+
from ..constant import MLLMModelType
|
| 13 |
+
from ..model_arch import ModelArch
|
| 14 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 15 |
+
from ..utils import ModelInfo, git_clone_github, use_submodel_func
|
| 16 |
+
from .qwen import get_model_tokenizer_qwen
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_model_tokenizer_mplug_owl2(model_dir: str,
|
| 22 |
+
model_info: ModelInfo,
|
| 23 |
+
model_kwargs: Dict[str, Any],
|
| 24 |
+
load_model: bool = True,
|
| 25 |
+
**kwargs):
|
| 26 |
+
local_repo_path = kwargs.get('local_repo_path')
|
| 27 |
+
if not local_repo_path:
|
| 28 |
+
local_repo_path = git_clone_github('https://github.com/X-PLUG/mPLUG-Owl')
|
| 29 |
+
local_repo_path = os.path.join(local_repo_path, 'mPLUG-Owl2')
|
| 30 |
+
sys.path.append(local_repo_path)
|
| 31 |
+
|
| 32 |
+
# register
|
| 33 |
+
# https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py#L447
|
| 34 |
+
from mplug_owl2 import MPLUGOwl2LlamaForCausalLM
|
| 35 |
+
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
|
| 36 |
+
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
| 37 |
+
vocab_size = kwargs.pop('vocab_size', None)
|
| 38 |
+
if vocab_size is not None:
|
| 39 |
+
model_config.vocab_size = vocab_size
|
| 40 |
+
get_model_tokenizer_function = kwargs.pop('get_model_tokenizer_function', get_model_tokenizer_with_flash_attn)
|
| 41 |
+
model, tokenizer = get_model_tokenizer_function(
|
| 42 |
+
model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs)
|
| 43 |
+
logger.info('Please ignore the warning about unimported modules.')
|
| 44 |
+
processor = CLIPImageProcessor.from_pretrained(model_dir)
|
| 45 |
+
processor.tokenizer = tokenizer
|
| 46 |
+
return model, processor
|
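Because mPLUG-Owl2 is not packaged on PyPI, the loader above clones the upstream GitHub repo and appends its `mPLUG-Owl2` subdirectory to `sys.path` before importing `MPLUGOwl2LlamaForCausalLM`. A self-contained sketch of that clone-and-import pattern (the cache directory and helper name are illustrative):

import os
import subprocess
import sys

def import_from_github(repo_url: str, subdir: str, cache_dir: str = '/tmp/repos') -> None:
    repo_name = repo_url.rstrip('/').split('/')[-1]
    local_path = os.path.join(cache_dir, repo_name)
    if not os.path.exists(local_path):
        # a shallow clone is enough to import the modeling code
        subprocess.run(['git', 'clone', '--depth', '1', repo_url, local_path], check=True)
    sys.path.append(os.path.join(local_path, subdir))

import_from_github('https://github.com/X-PLUG/mPLUG-Owl', 'mPLUG-Owl2')
# now `from mplug_owl2 import MPLUGOwl2LlamaForCausalLM` resolves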
| 47 |
+
|
| 48 |
+
|
| 49 |
+
register_model(
|
| 50 |
+
ModelMeta(
|
| 51 |
+
MLLMModelType.mplug_owl2, [ModelGroup([
|
| 52 |
+
Model('iic/mPLUG-Owl2', 'MAGAer13/mplug-owl2-llama2-7b'),
|
| 53 |
+
])],
|
| 54 |
+
TemplateType.mplug_owl2,
|
| 55 |
+
get_model_tokenizer_mplug_owl2,
|
| 56 |
+
model_arch=ModelArch.mplug_owl2,
|
| 57 |
+
requires=['transformers<4.35', 'icecream'],
|
| 58 |
+
tags=['vision']))
|
| 59 |
+
|
| 60 |
+
register_model(
|
| 61 |
+
ModelMeta(
|
| 62 |
+
MLLMModelType.mplug_owl2_1, [ModelGroup([
|
| 63 |
+
Model('iic/mPLUG-Owl2.1', 'Mizukiluke/mplug_owl_2_1'),
|
| 64 |
+
])],
|
| 65 |
+
TemplateType.mplug_owl2,
|
| 66 |
+
partial(
|
| 67 |
+
get_model_tokenizer_mplug_owl2, vocab_size=151851, get_model_tokenizer_function=get_model_tokenizer_qwen),
|
| 68 |
+
model_arch=ModelArch.mplug_owl2_1,
|
| 69 |
+
requires=['transformers<4.35', 'icecream'],
|
| 70 |
+
tags=['vision']))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_model_tokenizer_mplug_owl3(model_dir: str,
|
| 74 |
+
model_info: ModelInfo,
|
| 75 |
+
model_kwargs: Dict[str, Any],
|
| 76 |
+
load_model: bool = True,
|
| 77 |
+
**kwargs):
|
| 78 |
+
get_class_from_dynamic_module('configuration_hyper_qwen2.HyperQwen2Config', model_dir)
|
| 79 |
+
model_cls = get_class_from_dynamic_module('modeling_mplugowl3.mPLUGOwl3Model', model_dir)
|
| 80 |
+
model_cls._no_split_modules = ['SiglipEncoderLayer']
|
| 81 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 82 |
+
processor = model.init_processor(tokenizer)
|
| 83 |
+
if model is not None:
|
| 84 |
+
func_list = ['generate', 'forward']
|
| 85 |
+
use_submodel_func(model, 'language_model', func_list)
|
| 86 |
+
return model, processor
|
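`use_submodel_func(model, 'language_model', func_list)` makes the wrapper model answer `generate`/`forward` calls through its language-model submodule. The helper is defined in `..utils`; an illustrative reduction of the idea (the real version likely also remaps outputs and devices):

def delegate_to_submodel(model, submodel_name: str, func_names) -> None:
    submodel = getattr(model, submodel_name)
    for name in func_names:
        # replace the wrapper's method with the submodule's bound method
        setattr(model, name, getattr(submodel, name))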
| 87 |
+
|
| 88 |
+
|
| 89 |
+
register_model(
|
| 90 |
+
ModelMeta(
|
| 91 |
+
MLLMModelType.mplug_owl3, [
|
| 92 |
+
ModelGroup([
|
| 93 |
+
Model('iic/mPLUG-Owl3-1B-241014', 'mPLUG/mPLUG-Owl3-1B-241014'),
|
| 94 |
+
Model('iic/mPLUG-Owl3-2B-241014', 'mPLUG/mPLUG-Owl3-2B-241014'),
|
| 95 |
+
Model('iic/mPLUG-Owl3-7B-240728', 'mPLUG/mPLUG-Owl3-7B-240728'),
|
| 96 |
+
]),
|
| 97 |
+
],
|
| 98 |
+
TemplateType.mplug_owl3,
|
| 99 |
+
get_model_tokenizer_mplug_owl3,
|
| 100 |
+
architectures=['mPLUGOwl3Model'],
|
| 101 |
+
model_arch=ModelArch.mplug_owl3,
|
| 102 |
+
requires=['transformers>=4.36', 'icecream', 'decord'],
|
| 103 |
+
tags=['vision', 'video']))
|
| 104 |
+
|
| 105 |
+
register_model(
|
| 106 |
+
ModelMeta(
|
| 107 |
+
MLLMModelType.mplug_owl3_241101, [
|
| 108 |
+
ModelGroup([
|
| 109 |
+
Model('iic/mPLUG-Owl3-7B-241101', 'mPLUG/mPLUG-Owl3-7B-241101'),
|
| 110 |
+
]),
|
| 111 |
+
],
|
| 112 |
+
TemplateType.mplug_owl3_241101,
|
| 113 |
+
get_model_tokenizer_mplug_owl3,
|
| 114 |
+
architectures=['mPLUGOwl3Model'],
|
| 115 |
+
model_arch=ModelArch.mplug_owl3,
|
| 116 |
+
requires=['transformers>=4.36', 'icecream'],
|
| 117 |
+
tags=['vision', 'video']))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def get_model_tokenizer_doc_owl2(model_dir: str,
|
| 121 |
+
model_info: ModelInfo,
|
| 122 |
+
model_kwargs: Dict[str, Any],
|
| 123 |
+
load_model: bool = True,
|
| 124 |
+
**kwargs):
|
| 125 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 126 |
+
processor = model.init_processor(tokenizer, basic_image_size=504, crop_anchors='grid_12')
|
| 127 |
+
return model, processor
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
register_model(
|
| 131 |
+
ModelMeta(
|
| 132 |
+
MLLMModelType.doc_owl2, [
|
| 133 |
+
ModelGroup([
|
| 134 |
+
Model('iic/DocOwl2', 'mPLUG/DocOwl2'),
|
| 135 |
+
]),
|
| 136 |
+
],
|
| 137 |
+
TemplateType.doc_owl2,
|
| 138 |
+
get_model_tokenizer_doc_owl2,
|
| 139 |
+
architectures=['mPLUGDocOwl2'],
|
| 140 |
+
model_arch=ModelArch.doc_owl2,
|
| 141 |
+
requires=['transformers>=4.36', 'icecream'],
|
| 142 |
+
tags=['vision']))
|
swift/llm/model/model/openbuddy.py
ADDED
|
@@ -0,0 +1,85 @@
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from swift.llm import TemplateType
|
| 3 |
+
from swift.utils import get_logger
|
| 4 |
+
from ..constant import LLMModelType
|
| 5 |
+
from ..model_arch import ModelArch
|
| 6 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 7 |
+
|
| 8 |
+
logger = get_logger()
|
| 9 |
+
|
| 10 |
+
register_model(
|
| 11 |
+
ModelMeta(
|
| 12 |
+
LLMModelType.openbuddy_llama,
|
| 13 |
+
[
|
| 14 |
+
ModelGroup([
|
| 15 |
+
Model('OpenBuddy/openbuddy-llama-65b-v8-bf16', 'OpenBuddy/openbuddy-llama-65b-v8-bf16'),
|
| 16 |
+
]),
|
| 17 |
+
ModelGroup([
|
| 18 |
+
Model('OpenBuddy/openbuddy-llama2-13b-v8.1-fp16', 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16'),
|
| 19 |
+
Model('OpenBuddy/openbuddy-llama2-70b-v10.1-bf16', 'OpenBuddy/openbuddy-llama2-70b-v10.1-bf16'),
|
| 20 |
+
]),
|
| 21 |
+
ModelGroup([
|
| 22 |
+
Model('OpenBuddy/openbuddy-deepseek-67b-v15.2', 'OpenBuddy/openbuddy-deepseek-67b-v15.2'),
|
| 23 |
+
]),
|
| 24 |
+
],
|
| 25 |
+
TemplateType.openbuddy,
|
| 26 |
+
get_model_tokenizer_with_flash_attn,
|
| 27 |
+
model_arch=ModelArch.llama,
|
| 28 |
+
architectures=['LlamaForCausalLM'],
|
| 29 |
+
))
|
| 30 |
+
|
| 31 |
+
register_model(
|
| 32 |
+
ModelMeta(
|
| 33 |
+
LLMModelType.openbuddy_llama3,
|
| 34 |
+
[
|
| 35 |
+
ModelGroup([
|
| 36 |
+
Model('OpenBuddy/openbuddy-llama3-8b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-8b-v21.1-8k'),
|
| 37 |
+
Model('OpenBuddy/openbuddy-llama3-70b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-70b-v21.1-8k'),
|
| 38 |
+
Model('OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k', 'OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k'),
|
| 39 |
+
]),
|
| 40 |
+
ModelGroup([
|
| 41 |
+
Model('OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k', 'OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k'),
|
| 42 |
+
Model('OpenBuddy/openbuddy-nemotron-70b-v23.2-131k', 'OpenBuddy/openbuddy-nemotron-70b-v23.2-131k'),
|
| 43 |
+
],
|
| 44 |
+
requires=['transformers>=4.43']),
|
| 45 |
+
ModelGroup(
|
| 46 |
+
[Model('OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k', 'OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k')],
|
| 47 |
+
requires=['transformers>=4.45'])
|
| 48 |
+
],
|
| 49 |
+
TemplateType.openbuddy2,
|
| 50 |
+
get_model_tokenizer_with_flash_attn,
|
| 51 |
+
model_arch=ModelArch.llama,
|
| 52 |
+
architectures=['LlamaForCausalLM'],
|
| 53 |
+
))
|
| 54 |
+
|
| 55 |
+
register_model(
|
| 56 |
+
ModelMeta(
|
| 57 |
+
LLMModelType.openbuddy_mistral,
|
| 58 |
+
[
|
| 59 |
+
ModelGroup([
|
| 60 |
+
Model('OpenBuddy/openbuddy-mistral-7b-v17.1-32k', 'OpenBuddy/openbuddy-mistral-7b-v17.1-32k'),
|
| 61 |
+
]),
|
| 62 |
+
ModelGroup([
|
| 63 |
+
Model('OpenBuddy/openbuddy-zephyr-7b-v14.1', 'OpenBuddy/openbuddy-zephyr-7b-v14.1'),
|
| 64 |
+
]),
|
| 65 |
+
],
|
| 66 |
+
TemplateType.openbuddy,
|
| 67 |
+
get_model_tokenizer_with_flash_attn,
|
| 68 |
+
model_arch=ModelArch.llama,
|
| 69 |
+
requires=['transformers>=4.34'],
|
| 70 |
+
architectures=['MistralForCausalLM'],
|
| 71 |
+
))
|
| 72 |
+
|
| 73 |
+
register_model(
|
| 74 |
+
ModelMeta(
|
| 75 |
+
LLMModelType.openbuddy_mixtral,
|
| 76 |
+
[
|
| 77 |
+
ModelGroup([
|
| 78 |
+
Model('OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k', 'OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k'),
|
| 79 |
+
]),
|
| 80 |
+
],
|
| 81 |
+
TemplateType.openbuddy,
|
| 82 |
+
get_model_tokenizer_with_flash_attn,
|
| 83 |
+
architectures=['MixtralForCausalLM'],
|
| 84 |
+
requires=['transformers>=4.36'],
|
| 85 |
+
))
|
swift/llm/model/model/qwen.py
ADDED
|
@@ -0,0 +1,899 @@
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
|
| 6 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 7 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
| 8 |
+
|
| 9 |
+
from swift.llm import TemplateType
|
| 10 |
+
from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger
|
| 11 |
+
from ..constant import LLMModelType, MLLMModelType, RMModelType
|
| 12 |
+
from ..model_arch import ModelArch
|
| 13 |
+
from ..patcher import patch_fixed_device, patch_output_clone, patch_output_to_input_device
|
| 14 |
+
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_reward_model,
|
| 15 |
+
get_model_tokenizer_with_flash_attn, register_model)
|
| 16 |
+
from ..utils import AttnImpl, ModelInfo, use_submodel_func
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
|
| 19 |
+
dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_model_tokenizer_qwen(model_dir: str,
|
| 23 |
+
model_info: ModelInfo,
|
| 24 |
+
model_kwargs: Dict[str, Any],
|
| 25 |
+
load_model: bool = True,
|
| 26 |
+
model_config=None,
|
| 27 |
+
**kwargs):
|
| 28 |
+
if model_config is None:
|
| 29 |
+
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
| 30 |
+
if model_info.torch_dtype is not None:
|
| 31 |
+
k_true = dtype_mapping[model_info.torch_dtype]
|
| 32 |
+
for k in dtype_mapping.values():
|
| 33 |
+
setattr(model_config, k, k == k_true)
|
| 34 |
+
|
| 35 |
+
quantization_config = model_kwargs.get('quantization_config')
|
| 36 |
+
if not isinstance(quantization_config, BitsAndBytesConfig):
|
| 37 |
+
# not bnb quantization: clear torch_dtype and let the fp16/bf16/fp32 flags drive the dtype
|
| 38 |
+
model_config.torch_dtype = None
|
| 39 |
+
use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
|
| 40 |
+
model_config.use_flash_attn = use_flash_attn
|
| 41 |
+
kwargs['model_config'] = model_config
|
| 42 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 43 |
+
try:
|
| 44 |
+
# fix model parallel (mp) + ddp bug
|
| 45 |
+
model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
|
| 46 |
+
logger.info('registered_causal_mask to cuda')
|
| 47 |
+
except AttributeError:
|
| 48 |
+
pass
|
| 49 |
+
if tokenizer.eos_token_id is None:
|
| 50 |
+
tokenizer.eos_token_id = tokenizer.eod_id
|
| 51 |
+
return model, tokenizer
|
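First-generation Qwen configs select the compute dtype through boolean `fp16`/`bf16`/`fp32` flags rather than `torch_dtype`; the loop above turns exactly one flag on. The same logic in isolation:

import torch

dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}

class Cfg:
    pass

cfg = Cfg()
k_true = dtype_mapping[torch.bfloat16]  # the requested dtype
for k in dtype_mapping.values():
    setattr(cfg, k, k == k_true)

assert cfg.bf16 and not cfg.fp16 and not cfg.fp32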
| 52 |
+
|
| 53 |
+
|
| 54 |
+
register_model(
|
| 55 |
+
ModelMeta(
|
| 56 |
+
LLMModelType.qwen,
|
| 57 |
+
[
|
| 58 |
+
# qwen
|
| 59 |
+
ModelGroup([
|
| 60 |
+
# chat
|
| 61 |
+
Model('Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-1_8B-Chat'),
|
| 62 |
+
Model('Qwen/Qwen-7B-Chat', 'Qwen/Qwen-7B-Chat'),
|
| 63 |
+
Model('Qwen/Qwen-14B-Chat', 'Qwen/Qwen-14B-Chat'),
|
| 64 |
+
Model('Qwen/Qwen-72B-Chat', 'Qwen/Qwen-72B-Chat'),
|
| 65 |
+
# base
|
| 66 |
+
Model('Qwen/Qwen-1_8B', 'Qwen/Qwen-1_8B'),
|
| 67 |
+
Model('Qwen/Qwen-7B', 'Qwen/Qwen-7B'),
|
| 68 |
+
Model('Qwen/Qwen-14B', 'Qwen/Qwen-14B'),
|
| 69 |
+
Model('Qwen/Qwen-72B', 'Qwen/Qwen-72B'),
|
| 70 |
+
# gptq-int4
|
| 71 |
+
Model('Qwen/Qwen-1_8B-Chat-Int4', 'Qwen/Qwen-1_8B-Chat-Int4'),
|
| 72 |
+
Model('Qwen/Qwen-7B-Chat-Int4', 'Qwen/Qwen-7B-Chat-Int4'),
|
| 73 |
+
Model('Qwen/Qwen-14B-Chat-Int4', 'Qwen/Qwen-14B-Chat-Int4'),
|
| 74 |
+
Model('Qwen/Qwen-72B-Chat-Int4', 'Qwen/Qwen-72B-Chat-Int4'),
|
| 75 |
+
# gptq-int8
|
| 76 |
+
Model('Qwen/Qwen-1_8B-Chat-Int8', 'Qwen/Qwen-1_8B-Chat-Int8'),
|
| 77 |
+
Model('Qwen/Qwen-7B-Chat-Int8', 'Qwen/Qwen-7B-Chat-Int8'),
|
| 78 |
+
Model('Qwen/Qwen-14B-Chat-Int8', 'Qwen/Qwen-14B-Chat-Int8'),
|
| 79 |
+
Model('Qwen/Qwen-72B-Chat-Int8', 'Qwen/Qwen-72B-Chat-Int8'),
|
| 80 |
+
]),
|
| 81 |
+
# tongyi-finance
|
| 82 |
+
ModelGroup([
|
| 83 |
+
Model('TongyiFinance/Tongyi-Finance-14B-Chat', 'jxy/Tongyi-Finance-14B-Chat'),
|
| 84 |
+
Model('TongyiFinance/Tongyi-Finance-14B'),
|
| 85 |
+
Model('TongyiFinance/Tongyi-Finance-14B-Chat-Int4', 'jxy/Tongyi-Finance-14B-Chat-Int4'),
|
| 86 |
+
],
|
| 87 |
+
tags=['financial']),
|
| 88 |
+
],
|
| 89 |
+
TemplateType.qwen,
|
| 90 |
+
get_model_tokenizer_qwen,
|
| 91 |
+
architectures=['QWenLMHeadModel'],
|
| 92 |
+
model_arch=ModelArch.qwen))
|
| 93 |
+
|
| 94 |
+
register_model(
|
| 95 |
+
ModelMeta(
|
| 96 |
+
LLMModelType.modelscope_agent,
|
| 97 |
+
[ModelGroup([
|
| 98 |
+
Model('iic/ModelScope-Agent-7B'),
|
| 99 |
+
Model('iic/ModelScope-Agent-14B'),
|
| 100 |
+
])],
|
| 101 |
+
TemplateType.modelscope_agent,
|
| 102 |
+
get_model_tokenizer_qwen,
|
| 103 |
+
architectures=['QWenLMHeadModel'],
|
| 104 |
+
model_arch=ModelArch.qwen))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _qwen_vl_audio_decode(self, *args, skip_special_tokens=False, **kwargs) -> str:
|
| 108 |
+
if skip_special_tokens:
|
| 109 |
+
token_ids = kwargs['token_ids']
|
| 110 |
+
while len(token_ids) > 0 and token_ids[-1] in {151645, 151643}:
|
| 111 |
+
token_ids.pop()
|
| 112 |
+
return self._old_decode(*args, skip_special_tokens=False, **kwargs)
|
| 113 |
+
else:
|
| 114 |
+
return self._old_decode(*args, skip_special_tokens=False, **kwargs)
|
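The decode override keeps `skip_special_tokens=False` (presumably because skipping would also drop the image/audio markers the template relies on) and instead pops trailing stop tokens by hand; 151643 and 151645 are Qwen's `<|endoftext|>` and `<|im_end|>` ids. The stripping step in isolation:

token_ids = [9906, 1917, 151645, 151643]   # "... <|im_end|><|endoftext|>"
while token_ids and token_ids[-1] in {151645, 151643}:
    token_ids.pop()
assert token_ids == [9906, 1917]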
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def fix_qwen_inplace_bug(model) -> None:
|
| 118 |
+
# qwen-vl, qwen-audio
|
| 119 |
+
first_drop = model.transformer.drop
|
| 120 |
+
if first_drop.p == 0.:
|
| 121 |
+
# fix in-place operation bug
|
| 122 |
+
patch_output_clone(first_drop)
|
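Why patch the first dropout: with `p == 0`, PyTorch's dropout is an identity that returns the input tensor without copying, so the embedding output stays aliased downstream (a problem for in-place operations, e.g. under gradient checkpointing). A quick check of the no-copy behavior:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.)
x = torch.randn(3)
y = drop(x)
assert y.data_ptr() == x.data_ptr()  # same storage: no copy when p == 0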
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_model_tokenizer_qwen_audio(model_dir: str,
|
| 126 |
+
model_info: ModelInfo,
|
| 127 |
+
model_kwargs: Dict[str, Any],
|
| 128 |
+
load_model: bool = True,
|
| 129 |
+
**kwargs):
|
| 130 |
+
tokenizer_config = get_tokenizer_config(model_dir)
|
| 131 |
+
class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
|
| 132 |
+
tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
|
| 133 |
+
tokenizer_cls._auto_class = 'AutoTokenizer'
|
| 134 |
+
tokenizer_cls.AUDIO_ST = ()  # fix missing-attribute error on `self.AUDIO_ST`
|
| 135 |
+
if not hasattr(tokenizer_cls, '_old_decode'):
|
| 136 |
+
tokenizer_cls._old_decode = tokenizer_cls._decode
|
| 137 |
+
tokenizer_cls._decode = _qwen_vl_audio_decode
|
| 138 |
+
kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
|
| 139 |
+
model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 140 |
+
if model is not None:
|
| 141 |
+
fix_qwen_inplace_bug(model)
|
| 142 |
+
|
| 143 |
+
return model, tokenizer
|
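The loader resolves the remote-code tokenizer class by hand, via `get_tokenizer_config` plus `get_class_from_dynamic_module`, instead of calling `AutoTokenizer.from_pretrained` directly, because class attributes (`AUDIO_ST`, `_decode`) must be patched on the class before any instance exists. The two resolution calls on their own:

from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.models.auto.tokenization_auto import get_tokenizer_config

model_dir = 'Qwen/Qwen-Audio-Chat'                  # illustrative model id
config = get_tokenizer_config(model_dir)            # reads tokenizer_config.json
class_ref = config['auto_map']['AutoTokenizer'][0]  # e.g. 'tokenization_qwen.QWenTokenizer'
tokenizer_cls = get_class_from_dynamic_module(class_ref, model_dir)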
| 144 |
+
|
| 145 |
+
|
| 146 |
+
register_model(
|
| 147 |
+
ModelMeta(
|
| 148 |
+
MLLMModelType.qwen_audio, [
|
| 149 |
+
ModelGroup([
|
| 150 |
+
Model('Qwen/Qwen-Audio-Chat', 'Qwen/Qwen-Audio-Chat'),
|
| 151 |
+
Model('Qwen/Qwen-Audio', 'Qwen/Qwen-Audio'),
|
| 152 |
+
])
|
| 153 |
+
],
|
| 154 |
+
TemplateType.qwen_audio,
|
| 155 |
+
get_model_tokenizer_qwen_audio,
|
| 156 |
+
model_arch=ModelArch.qwen_audio,
|
| 157 |
+
architectures=['QWenLMHeadModel'],
|
| 158 |
+
additional_saved_files=['mel_filters.npz'],
|
| 159 |
+
tags=['audio']))
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _qwen_vl_visual_block_forward(
|
| 163 |
+
self,
|
| 164 |
+
q_x: torch.Tensor,
|
| 165 |
+
k_x: Optional[torch.Tensor] = None,
|
| 166 |
+
v_x: Optional[torch.Tensor] = None,
|
| 167 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 168 |
+
):
|
| 169 |
+
k_x = self.ln_1_kv(k_x) if hasattr(self, 'ln_1_kv') and k_x is not None else None
|
| 170 |
+
v_x = self.ln_1_kv(v_x) if hasattr(self, 'ln_1_kv') and v_x is not None else None
|
| 171 |
+
|
| 172 |
+
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
|
| 173 |
+
z = self.mlp(self.ln_2(x))
|
| 174 |
+
x = x.to(z.device) + z  # FIX: move x onto z's device before the residual add
|
| 175 |
+
return x
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_model_tokenizer_qwen_vl(model_dir: str,
|
| 179 |
+
model_info: ModelInfo,
|
| 180 |
+
model_kwargs: Dict[str, Any],
|
| 181 |
+
load_model: bool = True,
|
| 182 |
+
**kwargs):
|
| 183 |
+
if (model_kwargs.get('quantization_config') is not None
|
| 184 |
+
and isinstance(model_kwargs['quantization_config'], BitsAndBytesConfig)):
|
| 185 |
+
# https://github.com/pytorch/pytorch/issues/58969
|
| 186 |
+
model_kwargs['quantization_config'].llm_int8_skip_modules = ['lm_head', 'attn_pool.attn']
|
| 187 |
+
_TransformerBlock = get_class_from_dynamic_module('visual.TransformerBlock', model_dir)
|
| 188 |
+
|
| 189 |
+
def _get_cast_dtype(self) -> torch.dtype:
|
| 190 |
+
return self.resblocks[0].ln_1.weight.dtype
|
| 191 |
+
|
| 192 |
+
_TransformerBlock.__old_get_cast_dtype = _TransformerBlock.get_cast_dtype
|
| 193 |
+
_TransformerBlock.get_cast_dtype = _get_cast_dtype
|
| 194 |
+
|
| 195 |
+
tokenizer_config = get_tokenizer_config(model_dir)
|
| 196 |
+
class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
|
| 197 |
+
tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
|
| 198 |
+
tokenizer_cls._auto_class = 'AutoTokenizer'
|
| 199 |
+
tokenizer_cls.IMAGE_ST = ()  # fix missing-attribute error on `self.IMAGE_ST`
|
| 200 |
+
if not hasattr(tokenizer_cls, '_old_decode'):
|
| 201 |
+
tokenizer_cls._old_decode = tokenizer_cls._decode
|
| 202 |
+
tokenizer_cls._decode = _qwen_vl_audio_decode
|
| 203 |
+
# fix for device_map when each process sees >= 4 GPUs
|
| 204 |
+
n_gpu = get_device_count()
|
| 205 |
+
local_world_size = get_dist_setting()[3]
|
| 206 |
+
if n_gpu // local_world_size >= 4:
|
| 207 |
+
visual_block_cls = get_class_from_dynamic_module('visual.VisualAttentionBlock', model_dir)
|
| 208 |
+
visual_block_cls.__old_forward = visual_block_cls.forward
|
| 209 |
+
visual_block_cls.forward = _qwen_vl_visual_block_forward
|
| 210 |
+
|
| 211 |
+
kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
|
| 212 |
+
model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 213 |
+
if model is not None:
|
| 214 |
+
device_type = next(model.parameters()).device.type
|
| 215 |
+
fix_qwen_inplace_bug(model)
|
| 216 |
+
# fix for device_map when each process sees >= 4 GPUs
|
| 217 |
+
if n_gpu // local_world_size >= 4:
|
| 218 |
+
model.transformer.visual.proj.data = model.transformer.visual.proj.to(
|
| 219 |
+
model.transformer.visual.ln_post.bias.device)
|
| 220 |
+
# fix bug where image features end up on cuda:1 instead of device 0
|
| 221 |
+
patch_fixed_device(model.transformer.visual, f'{device_type}:0')
|
| 222 |
+
return model, tokenizer
|
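`patch_fixed_device(model.transformer.visual, f'{device_type}:0')` pins all inputs of the visual tower to device 0, which is what prevents the images-on-cuda:1 failure noted above when the LLM is sharded across many GPUs. An illustrative forward-pre-hook version of the same idea (the real helper is in `..patcher`):

import torch
import torch.nn as nn

def pin_inputs_to_device(module: nn.Module, device: str) -> None:
    def pre_hook(mod, args):
        # move every tensor argument onto the fixed device before forward runs
        return tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
    module.register_forward_pre_hook(pre_hook)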
| 223 |
+
|
| 224 |
+
|
| 225 |
+
register_model(
|
| 226 |
+
ModelMeta(
|
| 227 |
+
MLLMModelType.qwen_vl, [
|
| 228 |
+
ModelGroup([
|
| 229 |
+
Model('Qwen/Qwen-VL-Chat', 'Qwen/Qwen-VL-Chat'),
|
| 230 |
+
Model('Qwen/Qwen-VL', 'Qwen/Qwen-VL'),
|
| 231 |
+
Model('Qwen/Qwen-VL-Chat-Int4', 'Qwen/Qwen-VL-Chat-Int4'),
|
| 232 |
+
])
|
| 233 |
+
],
|
| 234 |
+
TemplateType.qwen_vl,
|
| 235 |
+
get_model_tokenizer_qwen_vl,
|
| 236 |
+
model_arch=ModelArch.qwen_vl,
|
| 237 |
+
architectures=['QWenLMHeadModel'],
|
| 238 |
+
additional_saved_files=['SimSun.ttf'],
|
| 239 |
+
tags=['vision']))
|
| 240 |
+
|
| 241 |
+
register_model(
|
| 242 |
+
ModelMeta(
|
| 243 |
+
LLMModelType.qwen2,
|
| 244 |
+
[
|
| 245 |
+
# qwen1.5
|
| 246 |
+
ModelGroup([
|
| 247 |
+
# chat
|
| 248 |
+
Model('Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-0.5B-Chat'),
|
| 249 |
+
Model('Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-1.8B-Chat'),
|
| 250 |
+
Model('Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-4B-Chat'),
|
| 251 |
+
Model('Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-7B-Chat'),
|
| 252 |
+
Model('Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen1.5-14B-Chat'),
|
| 253 |
+
Model('Qwen/Qwen1.5-32B-Chat', 'Qwen/Qwen1.5-32B-Chat'),
|
| 254 |
+
Model('Qwen/Qwen1.5-72B-Chat', 'Qwen/Qwen1.5-72B-Chat'),
|
| 255 |
+
Model('Qwen/Qwen1.5-110B-Chat', 'Qwen/Qwen1.5-110B-Chat'),
|
| 256 |
+
# base
|
| 257 |
+
Model('Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B'),
|
| 258 |
+
Model('Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B'),
|
| 259 |
+
Model('Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B'),
|
| 260 |
+
Model('Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B'),
|
| 261 |
+
Model('Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B'),
|
| 262 |
+
Model('Qwen/Qwen1.5-32B', 'Qwen/Qwen1.5-32B'),
|
| 263 |
+
Model('Qwen/Qwen1.5-72B', 'Qwen/Qwen1.5-72B'),
|
| 264 |
+
Model('Qwen/Qwen1.5-110B', 'Qwen/Qwen1.5-110B'),
|
| 265 |
+
# gptq-int4
|
| 266 |
+
Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4'),
|
| 267 |
+
Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4'),
|
| 268 |
+
Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int4'),
|
| 269 |
+
Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int4'),
|
| 270 |
+
Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int4'),
|
| 271 |
+
Model('Qwen/Qwen1.5-32B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-32B-Chat-GPTQ-Int4'),
|
| 272 |
+
Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int4'),
|
| 273 |
+
Model('Qwen/Qwen1.5-110B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-110B-Chat-GPTQ-Int4'),
|
| 274 |
+
# gptq-int8
|
| 275 |
+
Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8'),
|
| 276 |
+
Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8'),
|
| 277 |
+
Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int8'),
|
| 278 |
+
Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int8'),
|
| 279 |
+
Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int8'),
|
| 280 |
+
Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int8'),
|
| 281 |
+
# awq-int4
|
| 282 |
+
Model('Qwen/Qwen1.5-0.5B-Chat-AWQ', 'Qwen/Qwen1.5-0.5B-Chat-AWQ'),
|
| 283 |
+
Model('Qwen/Qwen1.5-1.8B-Chat-AWQ', 'Qwen/Qwen1.5-1.8B-Chat-AWQ'),
|
| 284 |
+
Model('Qwen/Qwen1.5-4B-Chat-AWQ', 'Qwen/Qwen1.5-4B-Chat-AWQ'),
|
| 285 |
+
Model('Qwen/Qwen1.5-7B-Chat-AWQ', 'Qwen/Qwen1.5-7B-Chat-AWQ'),
|
| 286 |
+
Model('Qwen/Qwen1.5-14B-Chat-AWQ', 'Qwen/Qwen1.5-14B-Chat-AWQ'),
|
| 287 |
+
Model('Qwen/Qwen1.5-32B-Chat-AWQ', 'Qwen/Qwen1.5-32B-Chat-AWQ'),
|
| 288 |
+
Model('Qwen/Qwen1.5-72B-Chat-AWQ', 'Qwen/Qwen1.5-72B-Chat-AWQ'),
|
| 289 |
+
Model('Qwen/Qwen1.5-110B-Chat-AWQ', 'Qwen/Qwen1.5-110B-Chat-AWQ'),
|
| 290 |
+
]),
|
| 291 |
+
# code-qwen1.5
|
| 292 |
+
ModelGroup([
|
| 293 |
+
Model('Qwen/CodeQwen1.5-7B', 'Qwen/CodeQwen1.5-7B'),
|
| 294 |
+
Model('Qwen/CodeQwen1.5-7B-Chat', 'Qwen/CodeQwen1.5-7B-Chat'),
|
| 295 |
+
Model('Qwen/CodeQwen1.5-7B-Chat-AWQ', 'Qwen/CodeQwen1.5-7B-Chat-AWQ'),
|
| 296 |
+
],
|
| 297 |
+
tags=['coding']),
|
| 298 |
+
# qwen2
|
| 299 |
+
ModelGroup([
|
| 300 |
+
# instruct
|
| 301 |
+
Model('Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-0.5B-Instruct'),
|
| 302 |
+
Model('Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-1.5B-Instruct'),
|
| 303 |
+
Model('Qwen/Qwen2-7B-Instruct', 'Qwen/Qwen2-7B-Instruct'),
|
| 304 |
+
Model('Qwen/Qwen2-72B-Instruct', 'Qwen/Qwen2-72B-Instruct'),
|
| 305 |
+
# base
|
| 306 |
+
Model('Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B'),
|
| 307 |
+
Model('Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B'),
|
| 308 |
+
Model('Qwen/Qwen2-7B', 'Qwen/Qwen2-7B'),
|
| 309 |
+
Model('Qwen/Qwen2-72B', 'Qwen/Qwen2-72B'),
|
| 310 |
+
# gptq-int4
|
| 311 |
+
Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4'),
|
| 312 |
+
Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4'),
|
| 313 |
+
Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int4'),
|
| 314 |
+
Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int4'),
|
| 315 |
+
# gptq-int8
|
| 316 |
+
Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8'),
|
| 317 |
+
Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8'),
|
| 318 |
+
Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int8'),
|
| 319 |
+
Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int8'),
|
| 320 |
+
# awq-int4
|
| 321 |
+
Model('Qwen/Qwen2-0.5B-Instruct-AWQ', 'Qwen/Qwen2-0.5B-Instruct-AWQ'),
|
| 322 |
+
Model('Qwen/Qwen2-1.5B-Instruct-AWQ', 'Qwen/Qwen2-1.5B-Instruct-AWQ'),
|
| 323 |
+
Model('Qwen/Qwen2-7B-Instruct-AWQ', 'Qwen/Qwen2-7B-Instruct-AWQ'),
|
| 324 |
+
Model('Qwen/Qwen2-72B-Instruct-AWQ', 'Qwen/Qwen2-72B-Instruct-AWQ'),
|
| 325 |
+
]),
|
| 326 |
+
# qwen2-math
|
| 327 |
+
ModelGroup(
|
| 328 |
+
[
|
| 329 |
+
# instruct
|
| 330 |
+
Model('Qwen/Qwen2-Math-1.5B-Instruct', 'Qwen/Qwen2-Math-1.5B-Instruct'),
|
| 331 |
+
Model('Qwen/Qwen2-Math-7B-Instruct', 'Qwen/Qwen2-Math-7B-Instruct'),
|
| 332 |
+
Model('Qwen/Qwen2-Math-72B-Instruct', 'Qwen/Qwen2-Math-72B-Instruct'),
|
| 333 |
+
# base
|
| 334 |
+
Model('Qwen/Qwen2-Math-1.5B', 'Qwen/Qwen2-Math-1.5B'),
|
| 335 |
+
Model('Qwen/Qwen2-Math-7B', 'Qwen/Qwen2-Math-7B'),
|
| 336 |
+
Model('Qwen/Qwen2-Math-72B', 'Qwen/Qwen2-Math-72B'),
|
| 337 |
+
],
|
| 338 |
+
tags=['math']),
|
| 339 |
+
# qwen2.5-1m
|
| 340 |
+
ModelGroup([
|
| 341 |
+
Model('Qwen/Qwen2.5-7B-Instruct-1M', 'Qwen/Qwen2.5-7B-Instruct-1M'),
|
| 342 |
+
Model('Qwen/Qwen2.5-14B-Instruct-1M', 'Qwen/Qwen2.5-14B-Instruct-1M'),
|
| 343 |
+
]),
|
| 344 |
+
# other
|
| 345 |
+
ModelGroup([Model('PowerInfer/SmallThinker-3B-Preview', 'PowerInfer/SmallThinker-3B-Preview')]),
|
| 346 |
+
],
|
| 347 |
+
TemplateType.qwen,
|
| 348 |
+
get_model_tokenizer_with_flash_attn,
|
| 349 |
+
architectures=['Qwen2ForCausalLM'],
|
| 350 |
+
requires=['transformers>=4.37'],
|
| 351 |
+
model_arch=ModelArch.llama))
|
| 352 |
+
|
| 353 |
+
register_model(
|
| 354 |
+
ModelMeta(
|
| 355 |
+
LLMModelType.qwen2_5,
|
| 356 |
+
[
|
| 357 |
+
# qwen2.5
|
| 358 |
+
ModelGroup([
|
| 359 |
+
# instruct
|
| 360 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
|
| 361 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
|
| 362 |
+
Model('Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-3B-Instruct'),
|
| 363 |
+
Model('Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'),
|
| 364 |
+
Model('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-14B-Instruct'),
|
| 365 |
+
Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
|
| 366 |
+
Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
|
| 367 |
+
# base
|
| 368 |
+
Model('Qwen/Qwen2.5-0.5B', 'Qwen/Qwen2.5-0.5B'),
|
| 369 |
+
Model('Qwen/Qwen2.5-1.5B', 'Qwen/Qwen2.5-1.5B'),
|
| 370 |
+
Model('Qwen/Qwen2.5-3B', 'Qwen/Qwen2.5-3B'),
|
| 371 |
+
Model('Qwen/Qwen2.5-7B', 'Qwen/Qwen2.5-7B'),
|
| 372 |
+
Model('Qwen/Qwen2.5-14B', 'Qwen/Qwen2.5-14B'),
|
| 373 |
+
Model('Qwen/Qwen2.5-32B', 'Qwen/Qwen2.5-32B'),
|
| 374 |
+
Model('Qwen/Qwen2.5-72B', 'Qwen/Qwen2.5-72B'),
|
| 375 |
+
# gptq-int4
|
| 376 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4'),
|
| 377 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4'),
|
| 378 |
+
Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4'),
|
| 379 |
+
Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4'),
|
| 380 |
+
Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4'),
|
| 381 |
+
Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4'),
|
| 382 |
+
Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4'),
|
| 383 |
+
# gptq-int8
|
| 384 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8'),
|
| 385 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8'),
|
| 386 |
+
Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8'),
|
| 387 |
+
Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8'),
|
| 388 |
+
Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8'),
|
| 389 |
+
Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8'),
|
| 390 |
+
Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8'),
|
| 391 |
+
# awq-int4
|
| 392 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-0.5B-Instruct-AWQ'),
|
| 393 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-1.5B-Instruct-AWQ'),
|
| 394 |
+
Model('Qwen/Qwen2.5-3B-Instruct-AWQ', 'Qwen/Qwen2.5-3B-Instruct-AWQ'),
|
| 395 |
+
Model('Qwen/Qwen2.5-7B-Instruct-AWQ', 'Qwen/Qwen2.5-7B-Instruct-AWQ'),
|
| 396 |
+
Model('Qwen/Qwen2.5-14B-Instruct-AWQ', 'Qwen/Qwen2.5-14B-Instruct-AWQ'),
|
| 397 |
+
Model('Qwen/Qwen2.5-32B-Instruct-AWQ', 'Qwen/Qwen2.5-32B-Instruct-AWQ'),
|
| 398 |
+
Model('Qwen/Qwen2.5-72B-Instruct-AWQ', 'Qwen/Qwen2.5-72B-Instruct-AWQ'),
|
| 399 |
+
]),
|
| 400 |
+
# qwen2.5-coder
|
| 401 |
+
ModelGroup(
|
| 402 |
+
[
|
| 403 |
+
# instruct
|
| 404 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct', 'Qwen/Qwen2.5-Coder-0.5B-Instruct'),
|
| 405 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct', 'Qwen/Qwen2.5-Coder-1.5B-Instruct'),
|
| 406 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct', 'Qwen/Qwen2.5-Coder-3B-Instruct'),
|
| 407 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct', 'Qwen/Qwen2.5-Coder-7B-Instruct'),
|
| 408 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct', 'Qwen/Qwen2.5-Coder-14B-Instruct'),
|
| 409 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct', 'Qwen/Qwen2.5-Coder-32B-Instruct'),
|
| 410 |
+
# base
|
| 411 |
+
Model('Qwen/Qwen2.5-Coder-0.5B', 'Qwen/Qwen2.5-Coder-0.5B'),
|
| 412 |
+
Model('Qwen/Qwen2.5-Coder-1.5B', 'Qwen/Qwen2.5-Coder-1.5B'),
|
| 413 |
+
Model('Qwen/Qwen2.5-Coder-3B', 'Qwen/Qwen2.5-Coder-3B'),
|
| 414 |
+
Model('Qwen/Qwen2.5-Coder-7B', 'Qwen/Qwen2.5-Coder-7B'),
|
| 415 |
+
Model('Qwen/Qwen2.5-Coder-14B', 'Qwen/Qwen2.5-Coder-14B'),
|
| 416 |
+
Model('Qwen/Qwen2.5-Coder-32B', 'Qwen/Qwen2.5-Coder-32B'),
|
| 417 |
+
# AWQ
|
| 418 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ'),
|
| 419 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ'),
|
| 420 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-3B-Instruct-AWQ'),
|
| 421 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-7B-Instruct-AWQ'),
|
| 422 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-14B-Instruct-AWQ'),
|
| 423 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-32B-Instruct-AWQ'),
|
| 424 |
+
# GPTQ
|
| 425 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4'),
|
| 426 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8'),
|
| 427 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4'),
|
| 428 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8'),
|
| 429 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4'),
|
| 430 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8'),
|
| 431 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4'),
|
| 432 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8'),
|
| 433 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4'),
|
| 434 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8'),
|
| 435 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4'),
|
| 436 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8'),
|
| 437 |
+
],
|
| 438 |
+
tags=['coding'])
|
| 439 |
+
],
|
| 440 |
+
TemplateType.qwen2_5,
|
| 441 |
+
get_model_tokenizer_with_flash_attn,
|
| 442 |
+
architectures=['Qwen2ForCausalLM'],
|
| 443 |
+
requires=['transformers>=4.37'],
|
| 444 |
+
model_arch=ModelArch.llama))
|
| 445 |
+
|
| 446 |
+
register_model(
|
| 447 |
+
ModelMeta(
|
| 448 |
+
LLMModelType.qwen2_5_math,
|
| 449 |
+
[
|
| 450 |
+
# qwen2.5-math
|
| 451 |
+
ModelGroup(
|
| 452 |
+
[
|
| 453 |
+
# instruct
|
| 454 |
+
Model('Qwen/Qwen2.5-Math-1.5B-Instruct', 'Qwen/Qwen2.5-Math-1.5B-Instruct'),
|
| 455 |
+
Model('Qwen/Qwen2.5-Math-7B-Instruct', 'Qwen/Qwen2.5-Math-7B-Instruct'),
|
| 456 |
+
Model('Qwen/Qwen2.5-Math-72B-Instruct', 'Qwen/Qwen2.5-Math-72B-Instruct'),
|
| 457 |
+
# base
|
| 458 |
+
Model('Qwen/Qwen2.5-Math-1.5B', 'Qwen/Qwen2.5-Math-1.5B'),
|
| 459 |
+
Model('Qwen/Qwen2.5-Math-7B', 'Qwen/Qwen2.5-Math-7B'),
|
| 460 |
+
Model('Qwen/Qwen2.5-Math-72B', 'Qwen/Qwen2.5-Math-72B'),
|
| 461 |
+
],
|
| 462 |
+
tags=['math']),
|
| 463 |
+
],
|
| 464 |
+
TemplateType.qwen2_5_math,
|
| 465 |
+
get_model_tokenizer_with_flash_attn,
|
| 466 |
+
architectures=['Qwen2ForCausalLM'],
|
| 467 |
+
requires=['transformers>=4.37'],
|
| 468 |
+
model_arch=ModelArch.llama))
|
| 469 |
+
|
| 470 |
+
register_model(
|
| 471 |
+
ModelMeta(
|
| 472 |
+
LLMModelType.qwen2_moe,
|
| 473 |
+
[
|
| 474 |
+
# qwen1.5-moe
|
| 475 |
+
ModelGroup([
|
| 476 |
+
Model('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'Qwen/Qwen1.5-MoE-A2.7B-Chat'),
|
| 477 |
+
Model('Qwen/Qwen1.5-MoE-A2.7B', 'Qwen/Qwen1.5-MoE-A2.7B'),
|
| 478 |
+
Model('Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4'),
|
| 479 |
+
]),
|
| 480 |
+
ModelGroup([
|
| 481 |
+
Model('Qwen/Qwen2-57B-A14B-Instruct', 'Qwen/Qwen2-57B-A14B-Instruct'),
|
| 482 |
+
Model('Qwen/Qwen2-57B-A14B', 'Qwen/Qwen2-57B-A14B'),
|
| 483 |
+
Model('Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4'),
|
| 484 |
+
])
|
| 485 |
+
],
|
| 486 |
+
TemplateType.qwen,
|
| 487 |
+
get_model_tokenizer_with_flash_attn,
|
| 488 |
+
architectures=['Qwen2MoeForCausalLM'],
|
| 489 |
+
requires=['transformers>=4.40'],
|
| 490 |
+
))
|
| 491 |
+
|
| 492 |
+
register_model(
|
| 493 |
+
ModelMeta(
|
| 494 |
+
LLMModelType.qwen3,
|
| 495 |
+
[
|
| 496 |
+
ModelGroup([
|
| 497 |
+
Model('Qwen/Qwen3-0.6B-Base', 'Qwen/Qwen3-0.6B-Base'),
|
| 498 |
+
Model('Qwen/Qwen3-1.7B-Base', 'Qwen/Qwen3-1.7B-Base'),
|
| 499 |
+
Model('Qwen/Qwen3-4B-Base', 'Qwen/Qwen3-4B-Base'),
|
| 500 |
+
Model('Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-8B-Base'),
|
| 501 |
+
Model('Qwen/Qwen3-14B-Base', 'Qwen/Qwen3-14B-Base'),
|
| 502 |
+
Model('Qwen/Qwen3-32B-Base', 'Qwen/Qwen3-32B-Base'),
|
| 503 |
+
# instruct
|
| 504 |
+
Model('Qwen/Qwen3-0.6B', 'Qwen/Qwen3-0.6B'),
|
| 505 |
+
Model('Qwen/Qwen3-1.7B', 'Qwen/Qwen3-1.7B'),
|
| 506 |
+
Model('Qwen/Qwen3-4B', 'Qwen/Qwen3-4B'),
|
| 507 |
+
Model('Qwen/Qwen3-8B', 'Qwen/Qwen3-8B'),
|
| 508 |
+
Model('Qwen/Qwen3-14B', 'Qwen/Qwen3-14B'),
|
| 509 |
+
Model('Qwen/Qwen3-32B', 'Qwen/Qwen3-32B'),
|
| 510 |
+
# fp8
|
| 511 |
+
Model('Qwen/Qwen3-0.6B-FP8', 'Qwen/Qwen3-0.6B-FP8'),
|
| 512 |
+
Model('Qwen/Qwen3-1.7B-FP8', 'Qwen/Qwen3-1.7B-FP8'),
|
| 513 |
+
Model('Qwen/Qwen3-4B-FP8', 'Qwen/Qwen3-4B-FP8'),
|
| 514 |
+
Model('Qwen/Qwen3-8B-FP8', 'Qwen/Qwen3-8B-FP8'),
|
| 515 |
+
Model('Qwen/Qwen3-14B-FP8', 'Qwen/Qwen3-14B-FP8'),
|
| 516 |
+
Model('Qwen/Qwen3-32B-FP8', 'Qwen/Qwen3-32B-FP8'),
|
| 517 |
+
# awq
|
| 518 |
+
Model('Qwen/Qwen3-4B-AWQ', 'Qwen/Qwen3-4B-AWQ'),
|
| 519 |
+
Model('Qwen/Qwen3-8B-AWQ', 'Qwen/Qwen3-8B-AWQ'),
|
| 520 |
+
Model('Qwen/Qwen3-14B-AWQ', 'Qwen/Qwen3-14B-AWQ'),
|
| 521 |
+
Model('Qwen/Qwen3-32B-AWQ', 'Qwen/Qwen3-32B-AWQ'),
|
| 522 |
+
# swift
|
| 523 |
+
Model('swift/Qwen3-32B-AWQ'),
|
| 524 |
+
]),
|
| 525 |
+
],
|
| 526 |
+
TemplateType.qwen3,
|
| 527 |
+
get_model_tokenizer_with_flash_attn,
|
| 528 |
+
architectures=['Qwen3ForCausalLM'],
|
| 529 |
+
requires=['transformers>=4.51'],
|
| 530 |
+
model_arch=ModelArch.llama))
|
| 531 |
+
|
| 532 |
+
register_model(
|
| 533 |
+
ModelMeta(
|
| 534 |
+
LLMModelType.qwen3_moe,
|
| 535 |
+
[
|
| 536 |
+
ModelGroup([
|
| 537 |
+
Model('Qwen/Qwen3-30B-A3B-Base', 'Qwen/Qwen3-30B-A3B-Base'),
|
| 538 |
+
Model('Qwen/Qwen3-235B-A22B-Base', 'Qwen/Qwen3-235B-A22B-Base'),
|
| 539 |
+
# instruct
|
| 540 |
+
Model('Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-30B-A3B'),
|
| 541 |
+
Model('Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-235B-A22B'),
|
| 542 |
+
# fp8
|
| 543 |
+
Model('Qwen/Qwen3-30B-A3B-FP8', 'Qwen/Qwen3-30B-A3B-FP8'),
|
| 544 |
+
Model('Qwen/Qwen3-235B-A22B-FP8', 'Qwen/Qwen3-235B-A22B-FP8'),
|
| 545 |
+
# awq
|
| 546 |
+
Model('swift/Qwen3-30B-A3B-AWQ', 'cognitivecomputations/Qwen3-30B-A3B-AWQ'),
|
| 547 |
+
Model('swift/Qwen3-235B-A22B-AWQ', 'cognitivecomputations/Qwen3-235B-A22B-AWQ'),
|
| 548 |
+
]),
|
| 549 |
+
],
|
| 550 |
+
TemplateType.qwen3,
|
| 551 |
+
get_model_tokenizer_with_flash_attn,
|
| 552 |
+
architectures=['Qwen3MoeForCausalLM'],
|
| 553 |
+
requires=['transformers>=4.51'],
|
| 554 |
+
))
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def patch_qwen_vl_utils(vision_process):
|
| 558 |
+
if hasattr(vision_process, '_patch'):
|
| 559 |
+
return
|
| 560 |
+
for key in [
|
| 561 |
+
'image_factor', 'min_pixels', 'max_pixels', 'max_ratio', 'video_min_pixels', 'video_max_pixels',
|
| 562 |
+
'video_total_pixels', 'frame_factor', 'fps', 'fps_min_frames', 'fps_max_frames'
|
| 563 |
+
]:
|
| 564 |
+
type_func = float if key == 'fps' else int
|
| 565 |
+
setattr(vision_process, key.upper(), get_env_args(key, type_func, getattr(vision_process, key.upper())))
|
| 566 |
+
_read_video_decord = vision_process._read_video_decord
|
| 567 |
+
|
| 568 |
+
def _new_read_video_decord(ele: dict):
|
| 569 |
+
from swift.llm import load_file
|
| 570 |
+
ele['video'] = load_file(ele['video'])
|
| 571 |
+
return _read_video_decord(ele)
|
| 572 |
+
|
| 573 |
+
vision_process.VIDEO_READER_BACKENDS['decord'] = _new_read_video_decord
|
| 574 |
+
vision_process._patch = True
|
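`patch_qwen_vl_utils` rewrites the module-level constants of `qwen_vl_utils.vision_process` (MAX_PIXELS, FPS, and so on) from environment-style arguments, and routes the decord reader through swift's `load_file` so URLs and local paths both work. A miniature of the constant-override loop, assuming `get_env_args` reads the upper-cased key from the environment and falls back to the default:

import os

def get_env_args_sketch(key, type_func, default):
    raw = os.environ.get(key.upper())
    return default if raw is None else type_func(raw)

os.environ['MAX_PIXELS'] = '1003520'
print(get_env_args_sketch('max_pixels', int, 12845056))  # -> 1003520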
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def get_model_tokenizer_qwen2_vl(*args, **kwargs):
|
| 578 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 579 |
+
kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration
|
| 580 |
+
model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
|
| 581 |
+
if model is not None and hasattr(model.model, 'embed_tokens'):
|
| 582 |
+
patch_output_clone(model.model.embed_tokens)
|
| 583 |
+
patch_output_to_input_device(model.model.embed_tokens)
|
| 584 |
+
|
| 585 |
+
from qwen_vl_utils import vision_process
|
| 586 |
+
patch_qwen_vl_utils(vision_process)
|
| 587 |
+
return model, tokenizer
|
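With the meta registered, callers never invoke `get_model_tokenizer_qwen2_vl` directly; ms-swift's public loader dispatches on the model id. Hypothetical usage, assuming the `swift.llm.get_model_tokenizer` entry point:

from swift.llm import get_model_tokenizer

# dispatches to get_model_tokenizer_qwen2_vl through the registry (assumed entry point)
model, processor = get_model_tokenizer('Qwen/Qwen2-VL-7B-Instruct')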
| 588 |
+
|
| 589 |
+
|
| 590 |
+
register_model(
|
| 591 |
+
ModelMeta(
|
| 592 |
+
MLLMModelType.qwen2_vl,
|
| 593 |
+
[
|
| 594 |
+
ModelGroup(
|
| 595 |
+
[
|
| 596 |
+
# chat
|
| 597 |
+
Model('Qwen/Qwen2-VL-2B-Instruct', 'Qwen/Qwen2-VL-2B-Instruct'),
|
| 598 |
+
Model('Qwen/Qwen2-VL-7B-Instruct', 'Qwen/Qwen2-VL-7B-Instruct'),
|
| 599 |
+
Model('Qwen/Qwen2-VL-72B-Instruct', 'Qwen/Qwen2-VL-72B-Instruct'),
|
| 600 |
+
# base
|
| 601 |
+
Model('Qwen/Qwen2-VL-2B', 'Qwen/Qwen2-VL-2B'),
|
| 602 |
+
Model('Qwen/Qwen2-VL-7B', 'Qwen/Qwen2-VL-7B'),
|
| 603 |
+
Model('Qwen/Qwen2-VL-72B', 'Qwen/Qwen2-VL-72B'),
|
| 604 |
+
# gptq-int4
|
| 605 |
+
Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4'),
|
| 606 |
+
Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4'),
|
| 607 |
+
Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4'),
|
| 608 |
+
# gptq-int8
|
| 609 |
+
Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8'),
|
| 610 |
+
Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8'),
|
| 611 |
+
Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8'),
|
| 612 |
+
# awq-int4
|
| 613 |
+
Model('Qwen/Qwen2-VL-2B-Instruct-AWQ', 'Qwen/Qwen2-VL-2B-Instruct-AWQ'),
|
| 614 |
+
Model('Qwen/Qwen2-VL-7B-Instruct-AWQ', 'Qwen/Qwen2-VL-7B-Instruct-AWQ'),
|
| 615 |
+
Model('Qwen/Qwen2-VL-72B-Instruct-AWQ', 'Qwen/Qwen2-VL-72B-Instruct-AWQ'),
|
| 616 |
+
]),
|
| 617 |
+
ModelGroup([
|
| 618 |
+
Model('bytedance-research/UI-TARS-2B-SFT', 'bytedance-research/UI-TARS-2B-SFT'),
|
| 619 |
+
Model('bytedance-research/UI-TARS-7B-SFT', 'bytedance-research/UI-TARS-7B-SFT'),
|
| 620 |
+
Model('bytedance-research/UI-TARS-7B-DPO', 'bytedance-research/UI-TARS-7B-DPO'),
|
| 621 |
+
Model('bytedance-research/UI-TARS-72B-SFT', 'bytedance-research/UI-TARS-72B-SFT'),
|
| 622 |
+
Model('bytedance-research/UI-TARS-72B-DPO', 'bytedance-research/UI-TARS-72B-DPO'),
|
| 623 |
+
]),
|
| 624 |
+
ModelGroup([
|
| 625 |
+
Model('allenai/olmOCR-7B-0225-preview', 'allenai/olmOCR-7B-0225-preview'),
|
| 626 |
+
]),
|
| 627 |
+
],
|
| 628 |
+
TemplateType.qwen2_vl,
|
| 629 |
+
get_model_tokenizer_qwen2_vl,
|
| 630 |
+
model_arch=ModelArch.qwen2_vl,
|
| 631 |
+
architectures=['Qwen2VLForConditionalGeneration'],
|
| 632 |
+
requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
|
| 633 |
+
tags=['vision', 'video']))
|
| 634 |
+
|
| 635 |
+
register_model(
|
| 636 |
+
ModelMeta(
|
| 637 |
+
MLLMModelType.qvq, [
|
| 638 |
+
ModelGroup([
|
| 639 |
+
Model('Qwen/QVQ-72B-Preview', 'Qwen/QVQ-72B-Preview'),
|
| 640 |
+
]),
|
| 641 |
+
],
|
| 642 |
+
TemplateType.qvq,
|
| 643 |
+
get_model_tokenizer_qwen2_vl,
|
| 644 |
+
model_arch=ModelArch.qwen2_vl,
|
| 645 |
+
architectures=['Qwen2VLForConditionalGeneration'],
|
| 646 |
+
requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
|
| 647 |
+
tags=['vision', 'video']))
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def get_model_tokenizer_qwen2_5_vl(*args, **kwargs):
|
| 651 |
+
from transformers import Qwen2_5_VLForConditionalGeneration
|
| 652 |
+
kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2_5_VLForConditionalGeneration
|
| 653 |
+
return get_model_tokenizer_qwen2_vl(*args, **kwargs)
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
register_model(
|
| 657 |
+
ModelMeta(
|
| 658 |
+
MLLMModelType.qwen2_5_vl, [
|
| 659 |
+
ModelGroup([
|
| 660 |
+
Model('Qwen/Qwen2.5-VL-3B-Instruct', 'Qwen/Qwen2.5-VL-3B-Instruct'),
|
| 661 |
+
Model('Qwen/Qwen2.5-VL-7B-Instruct', 'Qwen/Qwen2.5-VL-7B-Instruct'),
|
| 662 |
+
Model('Qwen/Qwen2.5-VL-32B-Instruct', 'Qwen/Qwen2.5-VL-32B-Instruct'),
|
| 663 |
+
Model('Qwen/Qwen2.5-VL-72B-Instruct', 'Qwen/Qwen2.5-VL-72B-Instruct'),
|
| 664 |
+
]),
|
| 665 |
+
ModelGroup([
|
| 666 |
+
Model('Qwen/Qwen2.5-VL-3B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-3B-Instruct-AWQ'),
|
| 667 |
+
Model('Qwen/Qwen2.5-VL-7B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-7B-Instruct-AWQ'),
|
| 668 |
+
Model('Qwen/Qwen2.5-VL-32B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-32B-Instruct-AWQ'),
|
| 669 |
+
Model('Qwen/Qwen2.5-VL-72B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-72B-Instruct-AWQ'),
|
| 670 |
+
]),
|
| 671 |
+
],
|
| 672 |
+
TemplateType.qwen2_5_vl,
|
| 673 |
+
get_model_tokenizer_qwen2_5_vl,
|
| 674 |
+
model_arch=ModelArch.qwen2_vl,
|
| 675 |
+
architectures=['Qwen2_5_VLForConditionalGeneration'],
|
| 676 |
+
requires=['transformers>=4.49', 'qwen_vl_utils>=0.0.6', 'decord'],
|
| 677 |
+
tags=['vision', 'video']))
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def get_model_tokenizer_qwen2_5_omni(model_dir, *args, **kwargs):
|
| 681 |
+
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniConfig
|
| 682 |
+
from qwen_omni_utils import vision_process
|
| 683 |
+
kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2_5OmniForConditionalGeneration
|
| 684 |
+
processor = Qwen2_5OmniProcessor.from_pretrained(model_dir, trust_remote_code=True)
|
| 685 |
+
kwargs['tokenizer'] = processor.tokenizer
|
| 686 |
+
kwargs['model_config'] = Qwen2_5OmniConfig.from_pretrained(model_dir, trust_remote_code=True)
|
| 687 |
+
patch_qwen_vl_utils(vision_process)
|
| 688 |
+
kwargs['model_config'].enable_audio_output = get_env_args('ENABLE_AUDIO_OUTPUT', bool, True)
|
| 689 |
+
model, _ = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
|
| 690 |
+
if model:
|
| 691 |
+
use_submodel_func(model, 'thinker')
|
| 692 |
+
model.config.keys_to_ignore_at_inference += ['hidden_states', 'attention_mask']
|
| 693 |
+
model.config.talker_config.pad_token_id = None
|
| 694 |
+
return model, processor
|
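The Omni loader exposes the talker (audio-output) head as an environment toggle via `get_env_args('ENABLE_AUDIO_OUTPUT', bool, True)`; disabling it saves the talker's memory when only text output is needed. How the boolean is parsed is up to `get_env_args`; a sketch with assumed semantics:

import os

def env_flag(name: str, default: bool) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.lower() in ('1', 'true', 'yes')   # assumed parsing rules

enable_audio_output = env_flag('ENABLE_AUDIO_OUTPUT', True)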
| 695 |
+
|
| 696 |
+
|
| 697 |
+
register_model(
|
| 698 |
+
ModelMeta(
|
| 699 |
+
MLLMModelType.qwen2_5_omni,
|
| 700 |
+
[
|
| 701 |
+
ModelGroup([
|
| 702 |
+
Model('Qwen/Qwen2.5-Omni-3B', 'Qwen/Qwen2.5-Omni-3B'),
|
| 703 |
+
Model('Qwen/Qwen2.5-Omni-7B', 'Qwen/Qwen2.5-Omni-7B'),
|
| 704 |
+
]),
|
| 705 |
+
],
|
| 706 |
+
TemplateType.qwen2_5_omni,
|
| 707 |
+
get_model_tokenizer_qwen2_5_omni,
|
| 708 |
+
model_arch=ModelArch.qwen2_5_omni,
|
| 709 |
+
architectures=['Qwen2_5OmniModel'],
|
| 710 |
+
requires=['transformers>=4.50', 'soundfile', 'qwen_omni_utils', 'decord'],
|
| 711 |
+
tags=['vision', 'video', 'audio'],
|
| 712 |
+
additional_saved_files=['spk_dict.pt'],
|
| 713 |
+
ignore_patterns=[],
|
| 714 |
+
))
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
def get_model_tokenizer_qwen2_audio(*args, **kwargs):
|
| 718 |
+
from transformers import Qwen2AudioForConditionalGeneration
|
| 719 |
+
kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2AudioForConditionalGeneration
|
| 720 |
+
return get_model_tokenizer_multimodal(*args, **kwargs)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
register_model(
|
| 724 |
+
ModelMeta(
|
| 725 |
+
MLLMModelType.qwen2_audio,
|
| 726 |
+
[
|
| 727 |
+
ModelGroup([
|
| 728 |
+
Model('Qwen/Qwen2-Audio-7B-Instruct', 'Qwen/Qwen2-Audio-7B-Instruct'),
|
| 729 |
+
Model('Qwen/Qwen2-Audio-7B', 'Qwen/Qwen2-Audio-7B'),
|
| 730 |
+
]),
|
| 731 |
+
],
|
| 732 |
+
TemplateType.qwen2_audio,
|
| 733 |
+
get_model_tokenizer_qwen2_audio,
|
| 734 |
+
model_arch=ModelArch.qwen2_audio,
|
| 735 |
+
architectures=['Qwen2AudioForConditionalGeneration'],
|
| 736 |
+
requires=['transformers>=4.45,<4.49', 'librosa'],
|
| 737 |
+
tags=['audio'],
|
| 738 |
+
))
|
| 739 |
+
|
| 740 |
+
register_model(
|
| 741 |
+
ModelMeta(
|
| 742 |
+
LLMModelType.marco_o1, [ModelGroup([Model('AIDC-AI/Marco-o1', 'AIDC-AI/Marco-o1')])],
|
| 743 |
+
TemplateType.marco_o1,
|
| 744 |
+
get_model_tokenizer_with_flash_attn,
|
| 745 |
+
model_arch=ModelArch.llama,
|
| 746 |
+
architectures=['Qwen2ForCausalLM'],
|
| 747 |
+
requires=['transformers>=4.37']))
|
| 748 |
+
|
| 749 |
+
register_model(
|
| 750 |
+
ModelMeta(
|
| 751 |
+
LLMModelType.qwq_preview, [ModelGroup([Model('Qwen/QwQ-32B-Preview', 'Qwen/QwQ-32B-Preview')])],
|
| 752 |
+
TemplateType.qwq_preview,
|
| 753 |
+
get_model_tokenizer_with_flash_attn,
|
| 754 |
+
model_arch=ModelArch.llama,
|
| 755 |
+
architectures=['Qwen2ForCausalLM'],
|
| 756 |
+
requires=['transformers>=4.37']))
|
| 757 |
+
|
| 758 |
+
register_model(
|
| 759 |
+
ModelMeta(
|
| 760 |
+
LLMModelType.qwq,
|
| 761 |
+
[ModelGroup([
|
| 762 |
+
Model('Qwen/QwQ-32B', 'Qwen/QwQ-32B'),
|
| 763 |
+
Model('Qwen/QwQ-32B-AWQ', 'Qwen/QwQ-32B-AWQ'),
|
| 764 |
+
])],
|
| 765 |
+
TemplateType.qwq,
|
| 766 |
+
get_model_tokenizer_with_flash_attn,
|
| 767 |
+
model_arch=ModelArch.llama,
|
| 768 |
+
architectures=['Qwen2ForCausalLM'],
|
| 769 |
+
requires=['transformers>=4.37']))
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def get_model_tokenizer_ovis(*args, **kwargs):
|
| 773 |
+
kwargs['attn_impl_keys'] = ['llm_attn_implementation']
|
| 774 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
|
| 775 |
+
if model is not None:
|
| 776 |
+
model.visual_tokenizer.to(model.dtype)
|
| 777 |
+
model.vte.to(model.dtype)
|
| 778 |
+
|
| 779 |
+
model.generation_config.cache_implementation = None
|
| 780 |
+
func_list = ['generate', 'forward', 'get_input_embeddings']
|
| 781 |
+
use_submodel_func(model, 'llm', func_list)
|
| 782 |
+
embedding = model.get_input_embeddings()
|
| 783 |
+
patch_output_clone(embedding)
|
| 784 |
+
try:
|
| 785 |
+
# fix device_map
|
| 786 |
+
from transformers.cache_utils import HybridCache
|
| 787 |
+
|
| 788 |
+
def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
|
| 789 |
+
**kwargs) -> Tuple[torch.Tensor]:
|
| 790 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
|
| 791 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
|
| 792 |
+
return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)
|
| 793 |
+
|
| 794 |
+
if not hasattr(HybridCache, '_update_origin'):
|
| 795 |
+
HybridCache._update_origin = HybridCache.update
|
| 796 |
+
HybridCache.update = update
|
| 797 |
+
except ImportError:
|
| 798 |
+
pass
|
| 799 |
+
return model, tokenizer
|
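The `HybridCache.update` patch moves each layer's cached key/value tensors onto the device of the incoming states before delegating to the original method, which is necessary when `device_map` spreads layers across GPUs. Note the `_update_origin` guard, which keeps the monkey-patch idempotent; the pattern in miniature:

class Cache:
    def update(self, x):
        return x

def patched_update(self, x):
    # ... move cached tensors to x's device here, then defer to the original
    return self._update_origin(x)

if not hasattr(Cache, '_update_origin'):   # apply at most once
    Cache._update_origin = Cache.update
    Cache.update = patched_update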
| 800 |
+
|
| 801 |
+
|
| 802 |
+
register_model(
|
| 803 |
+
ModelMeta(
|
| 804 |
+
MLLMModelType.ovis1_6,
|
| 805 |
+
[
|
| 806 |
+
ModelGroup([
|
| 807 |
+
Model('AIDC-AI/Ovis1.6-Gemma2-9B', 'AIDC-AI/Ovis1.6-Gemma2-9B'),
|
| 808 |
+
Model('AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4', 'AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4'),
|
| 809 |
+
Model('AIDC-AI/Ovis1.6-Gemma2-27B', 'AIDC-AI/Ovis1.6-Gemma2-27B'),
|
| 810 |
+
]),
|
| 811 |
+
],
|
| 812 |
+
TemplateType.ovis1_6,
|
| 813 |
+
get_model_tokenizer_ovis,
|
| 814 |
+
model_arch=ModelArch.ovis1_6,
|
| 815 |
+
architectures=['Ovis'],
|
| 816 |
+
tags=['vision'],
|
| 817 |
+
requires=['transformers>=4.42'],
|
| 818 |
+
))
|
| 819 |
+
|
| 820 |
+
register_model(
|
| 821 |
+
ModelMeta(
|
| 822 |
+
MLLMModelType.ovis1_6_llama3,
|
| 823 |
+
[
|
| 824 |
+
ModelGroup([
|
| 825 |
+
Model('AIDC-AI/Ovis1.6-Llama3.2-3B', 'AIDC-AI/Ovis1.6-Llama3.2-3B'),
|
| 826 |
+
]),
|
| 827 |
+
],
|
| 828 |
+
TemplateType.ovis1_6_llama3,
|
| 829 |
+
get_model_tokenizer_ovis,
|
| 830 |
+
model_arch=ModelArch.ovis1_6,
|
| 831 |
+
architectures=['Ovis'],
|
| 832 |
+
tags=['vision'],
|
| 833 |
+
))
|
| 834 |
+
|
| 835 |
+
register_model(
|
| 836 |
+
ModelMeta(
|
| 837 |
+
MLLMModelType.ovis2,
|
| 838 |
+
[
|
| 839 |
+
ModelGroup([
|
| 840 |
+
Model('AIDC-AI/Ovis2-1B', 'AIDC-AI/Ovis2-1B'),
|
| 841 |
+
Model('AIDC-AI/Ovis2-2B', 'AIDC-AI/Ovis2-2B'),
|
| 842 |
+
Model('AIDC-AI/Ovis2-4B', 'AIDC-AI/Ovis2-4B'),
|
| 843 |
+
Model('AIDC-AI/Ovis2-8B', 'AIDC-AI/Ovis2-8B'),
|
| 844 |
+
Model('AIDC-AI/Ovis2-16B', 'AIDC-AI/Ovis2-16B'),
|
| 845 |
+
Model('AIDC-AI/Ovis2-34B', 'AIDC-AI/Ovis2-34B'),
|
| 846 |
+
]),
|
| 847 |
+
],
|
| 848 |
+
TemplateType.ovis2,
|
| 849 |
+
get_model_tokenizer_ovis,
|
| 850 |
+
model_arch=ModelArch.ovis1_6,
|
| 851 |
+
architectures=['Ovis'],
|
| 852 |
+
tags=['vision'],
|
| 853 |
+
requires=['transformers>=4.46.2', 'moviepy<2'],
|
| 854 |
+
))
|
| 855 |
+
|
| 856 |
+
register_model(
|
| 857 |
+
ModelMeta(
|
| 858 |
+
RMModelType.qwen2_reward,
|
| 859 |
+
[
|
| 860 |
+
ModelGroup([
|
| 861 |
+
Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'),
|
| 862 |
+
]),
|
| 863 |
+
],
|
| 864 |
+
TemplateType.qwen,
|
| 865 |
+
get_model_tokenizer_reward_model,
|
| 866 |
+
architectures=['Qwen2ForRewardModel'],
|
| 867 |
+
requires=['transformers>=4.37'],
|
| 868 |
+
))
|
| 869 |
+
|
| 870 |
+
register_model(
|
| 871 |
+
ModelMeta(
|
| 872 |
+
RMModelType.qwen2_5_prm,
|
| 873 |
+
[
|
| 874 |
+
ModelGroup([
|
| 875 |
+
Model('Qwen/Qwen2.5-Math-PRM-7B', 'Qwen/Qwen2.5-Math-PRM-7B'),
|
| 876 |
+
Model('Qwen/Qwen2.5-Math-7B-PRM800K', 'Qwen/Qwen2.5-Math-7B-PRM800K'),
|
| 877 |
+
Model('Qwen/Qwen2.5-Math-PRM-72B', 'Qwen/Qwen2.5-Math-PRM-72B'),
|
| 878 |
+
]),
|
| 879 |
+
],
|
| 880 |
+
TemplateType.qwen2_5_math_prm,
|
| 881 |
+
get_model_tokenizer_reward_model,
|
| 882 |
+
task_type='prm',
|
| 883 |
+
architectures=['Qwen2ForProcessRewardModel'],
|
| 884 |
+
requires=['transformers>=4.37'],
|
| 885 |
+
))
|
| 886 |
+
|
| 887 |
+
register_model(
|
| 888 |
+
ModelMeta(
|
| 889 |
+
RMModelType.qwen2_5_math_reward,
|
| 890 |
+
[
|
| 891 |
+
ModelGroup([
|
| 892 |
+
Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'),
|
| 893 |
+
]),
|
| 894 |
+
],
|
| 895 |
+
TemplateType.qwen2_5_math,
|
| 896 |
+
get_model_tokenizer_reward_model,
|
| 897 |
+
architectures=['Qwen2ForRewardModel'],
|
| 898 |
+
requires=['transformers>=4.37'],
|
| 899 |
+
))
|
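Reviewer note (not part of the diff): a minimal sketch of how one of the registrations above is typically consumed. It assumes swift.llm exposes a `get_model_tokenizer` helper that resolves a model id against the registry populated by `register_model`; the `load_model=False` behavior is taken on faith from the surrounding codebase.

# Hypothetical usage sketch; assumes get_model_tokenizer resolves
# 'Qwen/QwQ-32B' via the ModelMeta registered above.
from swift.llm import get_model_tokenizer

# load_model=False is assumed to skip loading model weights, so `model` is None.
model, tokenizer = get_model_tokenizer('Qwen/QwQ-32B', load_model=False)
print(type(tokenizer))
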
swift/llm/model/model/skywork.py ADDED
@@ -0,0 +1,89 @@

# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

from swift.llm import TemplateType
from ..constant import LLMModelType, RMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo


def get_skywork_model_tokenizer(model_dir: str,
                                model_info: ModelInfo,
                                model_kwargs: Dict[str, Any],
                                load_model: bool = True,
                                **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if 'chat' in model_dir:
        tokenizer.add_tokens('[USER]')
        tokenizer.add_tokens('[BOT]')
        tokenizer.add_tokens('[SEP]')
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.skywork,
        [
            ModelGroup([
                Model('skywork/Skywork-13B-base', 'skywork/Skywork-13B-base'),
                Model('skywork/Skywork-13B-chat'),
            ]),
        ],
        TemplateType.skywork,
        get_skywork_model_tokenizer,
        architectures=['SkyworkForCausalLM'],
        model_arch=ModelArch.llama,
    ))

register_model(
    ModelMeta(
        LLMModelType.skywork_o1,
        [
            ModelGroup([
                Model('AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B', 'Skywork/Skywork-o1-Open-Llama-3.1-8B'),
            ]),
        ],
        TemplateType.skywork_o1,
        get_model_tokenizer_with_flash_attn,
        architectures=['LlamaForCausalLM'],
        requires=['transformers>=4.43'],
        model_arch=ModelArch.llama,
    ))

register_model(
    ModelMeta(
        RMModelType.llama3_2_reward,
        [
            ModelGroup([
                Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B', 'Skywork/Skywork-Reward-Llama-3.1-8B'),
                Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2', 'Skywork/Skywork-Reward-Llama-3.1-8B-v0.2'),
            ]),
            ModelGroup([
                Model('AI-ModelScope/GRM_Llama3.1_8B_rewardmodel-ft', 'Ray2333/GRM_Llama3.1_8B_rewardmodel-ft'),
                Model('AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft', 'Ray2333/GRM-llama3.2-3B-rewardmodel-ft'),
            ])
        ],
        TemplateType.llama3_2,
        get_model_tokenizer_with_flash_attn,
        requires=['transformers>=4.43'],
        architectures=['LlamaForSequenceClassification'],
        model_arch=ModelArch.llama,
    ))

register_model(
    ModelMeta(
        RMModelType.gemma_reward,
        [
            ModelGroup([
                Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B', 'Skywork/Skywork-Reward-Gemma-2-27B'),
                Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B-v0.2', 'Skywork/Skywork-Reward-Gemma-2-27B-v0.2'),
            ]),
        ],
        TemplateType.gemma,
        get_model_tokenizer_with_flash_attn,
        requires=['transformers>=4.42'],
        architectures=['Gemma2ForSequenceClassification'],
        model_arch=ModelArch.llama,
    ))
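Side note on the `add_tokens` calls in `get_skywork_model_tokenizer`: this is standard `transformers` behavior, sketched below with a stand-in checkpoint. When the tokens are genuinely new, the model's embedding matrix normally has to be resized to match; here the Skywork chat checkpoints are assumed to already account for these ids.

# Illustration with plain transformers; 'gpt2' is just a stand-in checkpoint.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
num_added = tok.add_tokens(['[USER]', '[BOT]', '[SEP]'])
print(num_added, tok.convert_tokens_to_ids('[USER]'))
# For a freshly extended vocabulary one would also do:
# model.resize_token_embeddings(len(tok))
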
swift/llm/model/model/stepfun.py ADDED
@@ -0,0 +1,86 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import sys

from transformers import AutoModel

from swift.llm import TemplateType
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
                        get_model_tokenizer_with_flash_attn, register_model)
from ..utils import git_clone_github, safe_snapshot_download


def get_model_tokenizer_got_ocr2(*args, **kwargs):
    kwargs['automodel_class'] = AutoModel
    model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.got_ocr2, [
            ModelGroup([
                Model('stepfun-ai/GOT-OCR2_0', 'stepfun-ai/GOT-OCR2_0'),
            ]),
        ],
        TemplateType.got_ocr2,
        get_model_tokenizer_got_ocr2,
        model_arch=ModelArch.got_ocr2,
        architectures=['GOTQwenForCausalLM'],
        tags=['vision']))


def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
    from transformers.models.got_ocr2 import GotOcr2ForConditionalGeneration
    GotOcr2ForConditionalGeneration._no_split_modules.append('GotOcr2VisionLayer')
    model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
    return model, processor


register_model(
    ModelMeta(
        MLLMModelType.got_ocr2_hf, [
            ModelGroup([
                Model('stepfun-ai/GOT-OCR-2.0-hf', 'stepfun-ai/GOT-OCR-2.0-hf'),
            ]),
        ],
        TemplateType.got_ocr2_hf,
        get_model_tokenizer_got_ocr2_hf,
        model_arch=ModelArch.got_ocr2_hf,
        architectures=['GOTQwenForCausalLM'],
        tags=['vision']))


def get_model_tokenizer_step_audio(*args, **kwargs):
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/stepfun-ai/Step-Audio.git')
    sys.path.append(local_repo_path)
    from tokenizer import StepAudioTokenizer
    encoder_path = safe_snapshot_download('stepfun-ai/Step-Audio-Tokenizer', check_local=True)
    model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
    if model is not None:
        model.encoder = StepAudioTokenizer(encoder_path)
        # from tts import StepAudioTTS
        # if not os.path.exists('speakers'):
        #     shutil.copytree(os.path.join(local_repo_path, 'speakers'), 'speakers')
        # decoder_path = safe_snapshot_download('stepfun-ai/Step-Audio-TTS-3B', check_local=True)
        # model.decoder = StepAudioTTS(decoder_path, model.encoder)
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.step_audio, [
            ModelGroup([
                Model('stepfun-ai/Step-Audio-Chat', 'stepfun-ai/Step-Audio-Chat'),
            ]),
        ],
        TemplateType.step_audio,
        get_model_tokenizer_step_audio,
        architectures=['Step1ForCausalLM'],
        requires=['funasr', 'sox', 'conformer', 'openai-whisper', 'librosa'],
        tags=['audio']))
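`get_model_tokenizer_step_audio` relies on a clone-then-import pattern that several loaders in this diff share (see also valley.py and yi.py below): fetch the upstream repo once, append it to `sys.path`, then import modules that only exist there. A generic sketch of the pattern, with placeholder names:

# Sketch only; names and paths are placeholders. Mirrors what
# git_clone_github + sys.path.append provide in the loaders above.
import os
import subprocess
import sys


def import_from_repo(url: str, dest: str, module: str):
    # Clone once, reuse afterwards.
    if not os.path.exists(dest):
        subprocess.run(['git', 'clone', '--depth', '1', url, dest], check=True)
    if dest not in sys.path:
        sys.path.append(dest)
    return __import__(module)
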
swift/llm/model/model/telechat.py ADDED
@@ -0,0 +1,59 @@

# Copyright (c) Alibaba, Inc. and its affiliates.

from transformers import GenerationConfig

from swift.llm import TemplateType
from ..constant import LLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model


def get_model_tokenizer_telechat(*args, **kwargs):
    model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
    model_dir = args[0]
    generation_config = GenerationConfig.from_pretrained(model_dir)
    for k in ['bos_token_id', 'eos_token_id', 'pad_token_id', 'user_token_id', 'bot_token_id']:
        setattr(tokenizer, k, getattr(generation_config, k))
    return model, tokenizer


register_model(
    ModelMeta(
        LLMModelType.telechat,
        [
            ModelGroup([
                Model('TeleAI/TeleChat-7B', 'Tele-AI/telechat-7B'),
                Model('TeleAI/TeleChat-12B', 'Tele-AI/TeleChat-12B'),
                Model('TeleAI/TeleChat-12B-v2', 'Tele-AI/TeleChat-12B-v2'),
                Model('TeleAI/TeleChat-52B', 'TeleAI/TeleChat-52B'),
            ]),
            ModelGroup([
                Model('swift/TeleChat-12B-V2-GPTQ-Int4'),
            ]),
            ModelGroup([
                Model('TeleAI/TeleChat2-35B', 'Tele-AI/TeleChat2-35B'),
                Model('TeleAI/TeleChat2-115B', 'Tele-AI/TeleChat2-115B'),
            ]),
        ],
        TemplateType.telechat,
        get_model_tokenizer_telechat,
        model_arch=ModelArch.telechat,
        architectures=['TelechatForCausalLM', 'TeleChatForCausalLM'],
    ))

register_model(
    ModelMeta(
        LLMModelType.telechat2,
        [
            ModelGroup([
                Model('TeleAI/TeleChat2-3B', 'Tele-AI/TeleChat2-3B'),
                Model('TeleAI/TeleChat2-7B-32K', 'Tele-AI/TeleChat2-7B-32K'),
                Model('TeleAI/TeleChat2-35B-32K', 'Tele-AI/TeleChat2-35B-32K'),
                Model('TeleAI/TeleChat2-35B-Nov', 'Tele-AI/TeleChat2-35B-Nov'),
            ]),
        ],
        TemplateType.telechat2,
        get_model_tokenizer_telechat,
        model_arch=ModelArch.telechat,
        architectures=['TeleChat2ForCausalLM'],
    ))
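The loader above works around TeleChat keeping its special-token ids in `generation_config.json` rather than on the tokenizer. A standalone sketch of that attribute copy, reduced to the standard ids (the TeleChat-specific `user_token_id`/`bot_token_id` are left out here):

# Sketch, assuming the checkpoint ships a generation_config.json with these ids.
from transformers import GenerationConfig


def sync_token_ids(tokenizer, model_dir: str):
    # Mirror ids recorded in generation_config.json onto the tokenizer object.
    gen_cfg = GenerationConfig.from_pretrained(model_dir)
    for k in ['bos_token_id', 'eos_token_id', 'pad_token_id']:
        setattr(tokenizer, k, getattr(gen_cfg, k, None))
    return tokenizer
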
swift/llm/model/model/valley.py ADDED
@@ -0,0 +1,82 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from functools import partial, wraps
from typing import Any, Dict

from swift.llm import TemplateType
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo, git_clone_github, safe_snapshot_download


def get_model_tokenizer_valley(model_dir: str,
                               model_info: ModelInfo,
                               model_kwargs: Dict[str, Any],
                               load_model: bool = True,
                               **kwargs):
    llm_model_type = kwargs.pop('llm_model_type')
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        repo_path = 'https://github.com/bytedance/Valley.git'
        local_repo_path = git_clone_github(repo_path)
    sys.path.append(local_repo_path)

    if llm_model_type == 'valley':
        from transformers.modeling_outputs import CausalLMOutputWithPast
        from valley_eagle.model.language_model.valley_qwen2 import ValleyQwen2ForCausalLM, ValleyConfig
        model_config = ValleyConfig.from_pretrained(model_dir)
        model_config.mm_vision_tower = safe_snapshot_download(
            'AI-ModelScope/siglip-so400m-patch14-384', check_local=True)
        model_config.eagle_vision_tower = safe_snapshot_download('Qwen/Qwen2-VL-7B-Instruct', check_local=True)
        automodel_class = ValleyQwen2ForCausalLM

        if not hasattr(ValleyQwen2ForCausalLM, '_origin_forward'):
            forward = ValleyQwen2ForCausalLM.forward
            ValleyQwen2ForCausalLM._origin_forward = forward

            @wraps(forward)
            def new_forward(*args, **kwargs):
                import torch
                outputs = forward(*args, **kwargs)
                loss = outputs.loss
                if loss is not None and loss.shape[-1] > 0:
                    loss = torch.mean(loss, dim=-1)
                return CausalLMOutputWithPast(
                    loss=loss,
                    logits=outputs.logits,
                    past_key_values=outputs.past_key_values,
                    hidden_states=outputs.hidden_states,
                    attentions=outputs.attentions,
                )

            ValleyQwen2ForCausalLM.forward = new_forward
    kwargs['model_config'] = model_config
    kwargs['automodel_class'] = automodel_class
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if model is not None:
        model.generation_config.repetition_penalty = 1.0  # Otherwise, Error. Same for original code.
        from transformers import AutoProcessor, SiglipImageProcessor
        tokenizer.image_processor = SiglipImageProcessor.from_pretrained(model.config.mm_vision_tower)
        tokenizer.qwen2vl_processor = AutoProcessor.from_pretrained(
            model.config.eagle_vision_tower, max_pixels=1280 * 28 * 28)
        tokenizer.image_processor.crop_size = tokenizer.image_processor.size['height']
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.valley,
        [
            ModelGroup([
                Model('bytedance-research/Valley-Eagle-7B'),
            ], ),
        ],
        TemplateType.valley,
        partial(get_model_tokenizer_valley, llm_model_type='valley'),
        architectures=['ValleyQwen2ForCausalLM'],
        model_arch=ModelArch.valley,
        requires=['transformers>=4.42', 'av'],
        tags=['vision'],
    ))
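The `_origin_forward` guard above is an idempotent monkey-patch: stash the original method once under a sentinel attribute so repeated loader calls never wrap it twice. Extracted as a generic helper (a sketch, not part of the diff):

# Generic form of the patch-once idiom used in get_model_tokenizer_valley
# (and in the HybridCache patch in qwen.py above).
from functools import wraps


def patch_once(cls, name, make_wrapper):
    sentinel = f'_origin_{name}'
    if hasattr(cls, sentinel):
        return  # already patched; avoid double wrapping
    origin = getattr(cls, name)
    setattr(cls, sentinel, origin)
    setattr(cls, name, wraps(origin)(make_wrapper(origin)))
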
swift/llm/model/model/yi.py ADDED
@@ -0,0 +1,146 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from typing import Any, Dict

from transformers import AutoTokenizer

from swift.llm import TemplateType
from swift.utils import get_logger
from ..constant import LLMModelType, MLLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import ModelInfo, git_clone_github

logger = get_logger()


def get_model_tokenizer_yi(model_dir, *args, **kwargs):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
    return get_model_tokenizer_with_flash_attn(model_dir, *args, tokenizer=tokenizer, **kwargs)


def get_model_tokenizer_yi_vl(model_dir: str,
                              model_info: ModelInfo,
                              model_kwargs: Dict[str, Any],
                              load_model: bool = True,
                              **kwargs):
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/01-ai/Yi')
    sys.path.append(os.path.join(local_repo_path, 'VL'))
    from llava.model import LlavaLlamaForCausalLM, LlavaConfig
    from llava.model.constants import key_info

    model_config = LlavaConfig.from_pretrained(model_dir)
    mm_vision_tower = model_config.mm_vision_tower
    model_config.mm_vision_tower = os.path.join(model_dir, *mm_vision_tower.rsplit('/', maxsplit=2)[-2:])
    model_config.attention_dropout = 0.
    key_info['model_path'] = model_dir
    kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaLlamaForCausalLM
    model, tokenizer = get_model_tokenizer_yi(
        model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs)
    if model is not None:
        logger.info('Please ignore the above warning.')
        logger.info('Loading the parameters of vision_tower...')
        model.resize_token_embeddings(len(tokenizer))
        vision_tower = model.get_vision_tower()
        vision_tower.load_model()
        vision_tower.to(device=model.device, dtype=model_config.torch_dtype)
        if not hasattr(model.config, 'max_sequence_length'):
            model.config.max_sequence_length = 2048
        tokenizer.image_processor = vision_tower.image_processor
    return model, tokenizer


register_model(
    ModelMeta(
        MLLMModelType.yi_vl,
        [
            ModelGroup([
                Model('01ai/Yi-VL-6B', '01-ai/Yi-VL-6B'),
                Model('01ai/Yi-VL-34B', '01-ai/Yi-VL-34B'),
            ], ),
        ],
        TemplateType.yi_vl,
        get_model_tokenizer_yi_vl,
        model_arch=ModelArch.llava_llama,
        architectures=['LlavaLlamaForCausalLM'],
        requires=['transformers>=4.34'],
        tags=['vision'],
    ))

register_model(
    ModelMeta(
        LLMModelType.yi,
        [  # yi
            ModelGroup([
                Model('01ai/Yi-6B', '01-ai/Yi-6B'),
                Model('01ai/Yi-6B-200K', '01-ai/Yi-6B-200K'),
                Model('01ai/Yi-6B-Chat', '01-ai/Yi-6B-Chat'),
                Model('01ai/Yi-6B-Chat-4bits', '01-ai/Yi-6B-Chat-4bits'),
                Model('01ai/Yi-6B-Chat-8bits', '01-ai/Yi-6B-Chat-8bits'),
                Model('01ai/Yi-9B', '01-ai/Yi-9B'),
                Model('01ai/Yi-9B-200K', '01-ai/Yi-9B-200K'),
                Model('01ai/Yi-34B', '01-ai/Yi-34B'),
                Model('01ai/Yi-34B-200K', '01-ai/Yi-34B-200K'),
                Model('01ai/Yi-34B-Chat', '01-ai/Yi-34B-Chat'),
                Model('01ai/Yi-34B-Chat-4bits', '01-ai/Yi-34B-Chat-4bits'),
                Model('01ai/Yi-34B-Chat-8bits', '01-ai/Yi-34B-Chat-8bits'),
            ]),
            # yi1.5
            ModelGroup([
                Model('01ai/Yi-1.5-6B', '01-ai/Yi-1.5-6B'),
                Model('01ai/Yi-1.5-6B-Chat', '01-ai/Yi-1.5-6B-Chat'),
                Model('01ai/Yi-1.5-9B', '01-ai/Yi-1.5-9B'),
                Model('01ai/Yi-1.5-9B-Chat', '01-ai/Yi-1.5-9B-Chat'),
                Model('01ai/Yi-1.5-9B-Chat-16K', '01-ai/Yi-1.5-9B-Chat-16K'),
                Model('01ai/Yi-1.5-34B', '01-ai/Yi-1.5-34B'),
                Model('01ai/Yi-1.5-34B-Chat', '01-ai/Yi-1.5-34B-Chat'),
                Model('01ai/Yi-1.5-34B-Chat-16K', '01-ai/Yi-1.5-34B-Chat-16K'),
            ]),
            # yi1.5-quant
            ModelGroup([
                Model('AI-ModelScope/Yi-1.5-6B-Chat-GPTQ', 'modelscope/Yi-1.5-6B-Chat-GPTQ'),
                Model('AI-ModelScope/Yi-1.5-6B-Chat-AWQ', 'modelscope/Yi-1.5-6B-Chat-AWQ'),
                Model('AI-ModelScope/Yi-1.5-9B-Chat-GPTQ', 'modelscope/Yi-1.5-9B-Chat-GPTQ'),
                Model('AI-ModelScope/Yi-1.5-9B-Chat-AWQ', 'modelscope/Yi-1.5-9B-Chat-AWQ'),
                Model('AI-ModelScope/Yi-1.5-34B-Chat-GPTQ', 'modelscope/Yi-1.5-34B-Chat-GPTQ'),
                Model('AI-ModelScope/Yi-1.5-34B-Chat-AWQ', 'modelscope/Yi-1.5-34B-Chat-AWQ'),
            ]),
        ],
        TemplateType.chatml,
        get_model_tokenizer_yi,
        architectures=['LlamaForCausalLM'],
        model_arch=ModelArch.llama,
    ))

register_model(
    ModelMeta(
        LLMModelType.yi_coder, [
            ModelGroup([
                Model('01ai/Yi-Coder-1.5B', '01-ai/Yi-Coder-1.5B'),
                Model('01ai/Yi-Coder-9B', '01-ai/Yi-Coder-9B'),
                Model('01ai/Yi-Coder-1.5B-Chat', '01-ai/Yi-Coder-1.5B-Chat'),
                Model('01ai/Yi-Coder-9B-Chat', '01-ai/Yi-Coder-9B-Chat'),
            ], )
        ],
        TemplateType.yi_coder,
        get_model_tokenizer_yi,
        architectures=['LlamaForCausalLM'],
        model_arch=ModelArch.llama,
        tags=['coding']))

register_model(
    ModelMeta(
        LLMModelType.sus,
        [
            ModelGroup([
                Model('SUSTC/SUS-Chat-34B', 'SUSTech/SUS-Chat-34B'),
            ]),
        ],
        TemplateType.sus,
        get_model_tokenizer_yi,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))
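Quick illustration (values are made up) of the `rsplit` re-rooting done in `get_model_tokenizer_yi_vl`: keep the last two components of the vision-tower path recorded in the config and re-root them under the local model directory.

import os

mm_vision_tower = '/some/cache/openai/clip-vit-large-patch14'  # placeholder
model_dir = '/local/Yi-VL-6B'  # placeholder
print(os.path.join(model_dir, *mm_vision_tower.rsplit('/', maxsplit=2)[-2:]))
# -> /local/Yi-VL-6B/openai/clip-vit-large-patch14
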
swift/llm/sampling/__init__.py ADDED
@@ -0,0 +1 @@

from .sampling import sampling_main
swift/llm/sampling/base.py ADDED
@@ -0,0 +1,55 @@

from typing import Any, Dict, List

from swift.llm import SamplingArguments
from swift.plugin import orms, prms
from swift.utils import get_logger

logger = get_logger()


class Sampler:

    def __init__(self, input_args: SamplingArguments):
        self.args = input_args
        self.template = None
        self.processor = None
        self.prm_model = None
        self.orm_model = None
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_rm()

    def _prepare_model_tokenizer(self):
        args = self.args
        _, self.processor = args.get_model_processor(load_model=False)

    def _prepare_rm(self):
        if self.args.prm_model is None:
            self.prm_model = None
            logger.warning('prm_model is None.')
        elif self.args.prm_model in prms:
            self.prm_model = prms[self.args.prm_model]()
        else:
            from swift.llm import PtEngine
            self.prm_model = PtEngine(self.args.prm_model, max_batch_size=64)

        if self.args.orm_model is None:
            self.orm_model = None
            logger.warning('orm_model is None.')
        elif self.args.orm_model in orms:
            self.orm_model = orms[self.args.orm_model]()
        else:
            from swift.llm import PtEngine
            self.orm_model = PtEngine(self.args.orm_model, max_batch_size=64)

    def _prepare_template(self) -> None:
        template = self.args.get_template(self.processor)
        self.template = template
        self.template.set_mode('train')

    def truncate_input(self, slices: List[Dict[str, Any]]):
        """Truncate the input rows to avoid hitting the max length of the policy model"""
        return slices

    def do_sample(self, data):
        raise NotImplementedError
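To make the contract of `Sampler` concrete, here is a minimal hypothetical subclass. The one-string-per-row return shape is inferred from how `SwiftSampling.run` writes `do_sample` output with `writelines`; it is an assumption, not part of this diff.

# Hypothetical subclass; assumes do_sample() should return newline-terminated
# strings, matching f.writelines(generated) in swift/llm/sampling/sampling.py.
import json


class EchoSampler(Sampler):

    def do_sample(self, data):
        rows = data if isinstance(data, list) else [data]
        # A real sampler would run inference here and attach reward scores.
        return [json.dumps(row, ensure_ascii=False) + '\n' for row in rows]
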
swift/llm/sampling/distill_sampler.py ADDED
@@ -0,0 +1,148 @@

import os
from copy import deepcopy
from typing import List, Optional

from openai import OpenAI

from swift.llm.infer.protocol import InferRequest, RequestConfig
from swift.llm.sampling.vanilla_sampler import VanillaSampler
from .utils import get_messages_md5


class OpenAI_Engine():

    def __init__(
        self,
        model: str,
        stream: bool = False,
        base_url: str = 'https://dashscope.aliyuncs.com/compatible-mode/v1',
        api_key: str = '',
        **kwargs,
    ):
        self.model = model
        self.stream = stream
        self.client = OpenAI(api_key=api_key if api_key else os.getenv('OPENAI_API_KEY'), base_url=base_url, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
    ):
        resp_contents = []
        for infer_request in infer_requests:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=infer_request['messages'],
                temperature=request_config.temperature,
                top_p=request_config.top_p,
                max_tokens=request_config.max_tokens,
                stream=self.stream,
            )
            if self.stream:
                reasoning_content = ''
                content = ''
                for chunk in completion:
                    chunk_choices = chunk.choices
                    if len(chunk_choices) == 0:
                        continue
                    reasoning_chunk = chunk_choices[0].delta.reasoning_content if hasattr(
                        chunk_choices[0].delta, 'reasoning_content') else ''
                    answer_chunk = chunk_choices[0].delta.content
                    if reasoning_chunk:
                        reasoning_content += reasoning_chunk
                    elif answer_chunk:
                        content += answer_chunk
            else:
                if hasattr(completion.choices[0].message, 'reasoning_content'):
                    reasoning_content = completion.choices[0].message.reasoning_content
                content = completion.choices[0].message.content
            assert len(content) > 0, 'Empty completion'
            if reasoning_content:
                resp_content = f'<think>{reasoning_content}</think>\n\n<answer>{content}</answer>'
            else:
                resp_content = content
            resp_contents.append(resp_content)

        return resp_contents


class DistillSampler(VanillaSampler):

    def __init__(self, *args, **kwargs):
        super(VanillaSampler, self).__init__(*args, **kwargs)
        assert self.args.sampler_engine == 'client'
        _Engine = OpenAI_Engine
        self.infer_engine = _Engine(model=self.args.model, stream=self.args.stream, **self.args.engine_kwargs)
        self.infer_engine.strict = False
        self.caches = self.read_cache()

    def _prepare_model_tokenizer(self):
        pass

    def _prepare_template(self):
        pass

    def extract_choice(self, resp):
        message = resp.choices[0].message
        if hasattr(message, 'reasoning_content'):
            reps_content = f'<think>{message.reasoning_content}</think>\n\n<answer>{message.content}</answer>'
        else:
            reps_content = message.content
        return reps_content

    def generate(self, data):
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    continue
            if self.args.system:
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                messages = messages[:-1]

            row['messages'] = messages
            infer_request = row
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
            sent += 1

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                resps['choices'].append(resp_list[j])
            resp_all.append(resps)
            _cur += 1
        return resp_all
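A hypothetical direct use of `OpenAI_Engine` outside `DistillSampler`. The model name and endpoint are placeholders, and `OPENAI_API_KEY` (or an explicit `api_key`) must be valid for the chosen `base_url`.

# Placeholder model; requests are sent synchronously, one per item.
engine = OpenAI_Engine(model='qwen-max', stream=False)
config = RequestConfig(max_tokens=256, temperature=0.7, top_p=0.9)
outputs = engine.infer(
    [{'messages': [{'role': 'user', 'content': 'What is 2 + 2?'}]}],
    request_config=config)
print(outputs[0])
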
swift/llm/sampling/mcts.py ADDED
@@ -0,0 +1,400 @@

import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy

import json
import numpy as np

from swift.llm import InferRequest, SamplingArguments
from swift.llm.infer.protocol import UsageInfo
from swift.utils import get_logger
from .base import Sampler
from .utils import get_reward, perform_infer

logger = get_logger()

NXT_PROMPT = """Continue.
"""

next_message = {
    'role': 'user',
    'content': NXT_PROMPT,
}


class LanguageNode:

    def __init__(
        self,
        step: str = None,
        sep_token: str = None,
        parent: 'LanguageNode' = None,
    ):
        self.parent = parent

        if sep_token:
            self.sep_token = sep_token
        else:
            self.sep_token = parent.sep_token

        if parent:
            self.path = parent.path[:] + [step]
            self.answer = parent.answer + step + self.sep_token
            self.depth = parent.depth + 1
        else:
            self.path = []
            self.answer = ''
            self.depth = 0

        self.active_children = []
        self.children = []
        self.visit_count = 0
        self.process_reward = 0.0
        self.outcome_reward = 0.0
        self.terminated = False
        self.correct = False

    def is_leaf(self):
        return len(self.children) == 0

    def is_root(self):
        return self.parent is None

    def visit(self):
        self.visit_count += 1

    def init_and_update_value(self, value):
        self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1)

    def add_child(self, child: 'LanguageNode'):
        self.children.append(child)
        if not child.terminated:
            self.active_children.append(child)

    def collect(self):
        result = {
            'path': self.path,
            'depth': self.depth,
            'visit_count': self.visit_count,
            'process_reward': self.process_reward,
            'outcome_reward': self.outcome_reward,
            'terminated': str(self.terminated),
            'correct': str(self.correct),
            'children': [child.collect() for child in self.children],
        }
        return result

    def __lt__(self, other):
        return self.outcome_reward < other.outcome_reward


class MctsSampler(Sampler):

    def __init__(self, input_args: SamplingArguments):
        super().__init__(input_args)
        self.usage_info = UsageInfo(0, 0, 0)

    def _prepare_model_tokenizer(self):
        args = self.args
        self.infer_kwargs = {}
        if args.sampler_engine == 'client':
            from swift.llm import InferClient
            api_key = args.api_key
            base_url = args.base_url
            self.infer_engine = [
                InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)
            ]
            self.infer_kwargs['model'] = args.model
        else:
            _Engine = self.get_infer_engine()
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)

    def get_infer_engine(self):
        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            _Engine = None
        else:
            raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
        return _Engine

    def _prepare_template(self) -> None:
        # Hack from super()
        self._prepare_request_configs()

    def _prepare_request_configs(self):
        _args = self.args
        request_config = _args.get_request_config()
        request_config.stop = _args.stop_words
        request_config.seed = _args.seed
        self.expand_request_configs = []
        self.rollout_request_configs = []
        for i in range(_args.num_return_sequences):
            expand_request_config = deepcopy(request_config)
            expand_request_config.n = 1
            expand_request_config.num_beams = expand_request_config.n
            expand_request_config.seed += i
            self.expand_request_configs.append(expand_request_config)
            rollout_request_config = deepcopy(request_config)
            rollout_request_config.max_tokens = 500
            rollout_request_config.temperature = 0.0
            rollout_request_config.n = 1
            self.rollout_request_configs.append(rollout_request_config)

    def update_usage_info(self, response):
        for key, value in self.usage_info.__dict__.items():
            update_value = getattr(response.usage, key, None) + value
            setattr(self.usage_info, key, update_value)

    def search_single(self, query, ground_truth):

        def _uct(uct_curr_node: LanguageNode):
            alpha = _args.process_reward_rate
            value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
            if uct_curr_node.is_root():
                return value

            exploitation_score = value
            exploration_score = (
                _args.exploration_rate
                * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))

            return exploration_score + exploitation_score

        def _select(select_curr_node: LanguageNode):
            while not select_curr_node.is_leaf():
                select_curr_node = max(select_curr_node.active_children, key=lambda x: _uct(x))
            return select_curr_node

        def _expand(expand_curr_node: LanguageNode):
            n = _args.num_return_sequences - len(expand_curr_node.children)
            if expand_curr_node.is_root():
                infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
            else:
                history_message = {
                    'role': 'assistant',
                    'content': expand_curr_node.answer,
                }
                infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
                infer_requests = [infer_request for _ in range(n)]

            # e_time = time.time()
            # To perform the Expand operation in parallel,
            # there's no need to consider the order for now, since the Prompt is the same.
            expand_iter_index = 0
            while True:
                responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
                                          **self.infer_kwargs)
                if len(responses) > 0:
                    break
                if expand_iter_index == 5:
                    raise ValueError('Expand should not return any response')
                expand_iter_index += 1
            # logger.info(f"expand.expand time: {time.time() - e_time}")

            # To fetch Outcome Reward in parallel,
            # the Outcome-Reward obtained is returned in order, so they can be directly matched accordingly.
            orm_infer_requests = []
            unique_output = set()
            for response in responses:
                self.update_usage_info(response)
                output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
                if output in unique_output:
                    continue
                unique_output.add(output)
                orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
                child = LanguageNode(step=output, parent=expand_curr_node)
                if self.orm_model.check_terminate(child.answer)[0]:
                    child.terminated = True
                expand_curr_node.add_child(child)

            # e_time = time.time()
            orm_score, _orm_mask = get_reward(
                self.orm_model,
                orm_infer_requests,
                ground_truths=[ground_truth] * len(orm_infer_requests),
                threshold=0.0)
            # logger.info(f"expand.orm time: {time.time() - e_time}")
            for child, score in zip(expand_curr_node.children, orm_score):
                if child.terminated:
                    child.init_and_update_value(score)
                    child.correct = score > 0.9
                    terminated_nodes.append(child)

            # e_time = time.time()
            if self.prm_model:
                prm_infer_requests = []
                for child in expand_curr_node.children:
                    prm_message = {'role': 'assistant', 'content': child.answer}
                    prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    prm_infer_requests,
                    ground_truths=[ground_truth] * len(prm_infer_requests),
                    threshold=0.0)
                for child, score in zip(expand_curr_node.children, prm_score):
                    child.process_reward = score
            # logger.info(f"expand.prm time: {time.time() - e_time}")

        def _rollout(rollout_curr_node: LanguageNode):
            rollout_depth = 0
            rollout_nodes = {}
            for i in range(len(rollout_curr_node.active_children)):
                rollout_nodes[i] = {
                    'node': rollout_curr_node.active_children[i],
                    'history_messages': {
                        'role': 'assistant',
                        'content': rollout_curr_node.active_children[i].answer,
                    },
                }
            active_rollout_nodes = list(rollout_nodes.keys())
            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
                # r_time = time.time()
                infer_requests = [
                    InferRequest(system_message
                                 + [prompt_message, rollout_nodes[index]['history_messages'], next_message])
                    for index in active_rollout_nodes
                ]
                # logger.info(f"rollout.prepare time: {time.time() - r_time}")
                # r_time = time.time()
                rollout_iter_index = 0
                while True:
                    responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
                                              **self.infer_kwargs)
                    if len(responses) > 0:
                        break
                    if rollout_iter_index == 5:
                        raise ValueError('Rollout should not return any response')
                    rollout_iter_index += 1
                # logger.info(f"rollout.infer time: {time.time() - r_time}")

                # r_time = time.time()
                orm_infer_requests = []
                end_paths = []
                for index, response in zip(active_rollout_nodes, responses):
                    self.update_usage_info(response)
                    output = response.choices[0].message.content.rstrip(sep_token
                                                                        + '\n').split(sep_token)[0] + sep_token + '\n'
                    rollout_nodes[index]['history_messages']['content'] += output
                    end_paths.append(rollout_nodes[index]['history_messages']['content'])
                    orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))
                # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}")

                # r_time = time.time()
                orm_score, _orm_mask = get_reward(
                    self.orm_model,
                    orm_infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=0.0)
                # logger.info(f"rollout.get_orm time: {time.time() - r_time}")
                terminated_state = self.orm_model.check_terminate(end_paths)
                for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
                    if terminated:
                        rollout_curr_node.active_children[index].init_and_update_value(score)
                        if score > 0.9:
                            rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
                        else:
                            rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
                        rollout_nodes.pop(index)
                active_rollout_nodes = list(rollout_nodes.keys())
                rollout_depth += 1

        def _back_propagate(back_curr_node: LanguageNode):
            while back_curr_node:
                if back_curr_node == curr_node:
                    best_child_value = max([child.outcome_reward for child in back_curr_node.children])
                    back_curr_node.init_and_update_value(best_child_value)
                    last_child_value = back_curr_node.outcome_reward
                else:
                    back_curr_node.init_and_update_value(last_child_value)
                    last_child_value = back_curr_node.outcome_reward
                back_curr_node.visit()
                if len(back_curr_node.active_children) == 0:
                    back_curr_node.terminated = True
                    if not back_curr_node.is_root():
                        back_curr_node.parent.active_children.remove(back_curr_node)
                back_curr_node = back_curr_node.parent

        _args = self.args
        system_message = [] + _args.system_message
        sep_token = _args.stop_words[0] + '\n'
        _root = LanguageNode(sep_token=sep_token)
        prompt_message = {
            'role': 'user',
            'content': query,
        }

        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
        iter_count = 0
        stop_reason = None
        while True:
            logger.info(f'iter_count: {iter_count}' + '.' * 10)
            s_time = time.time()
            curr_node = _select(_root)
            logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _expand(curr_node)
            logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
            if curr_node.depth > _args.rollout_start_depth:
                s_time = time.time()
                _rollout(curr_node)
                logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _back_propagate(curr_node)
            logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
            if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
                    stop_reason = 'too easy'
                    break
                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
                    stop_reason = 'too hard'
                    break
            if _root.terminated:
                stop_reason = 'root terminated'
                break
            if len(terminated_nodes) >= _args.num_return_sequences:
                stop_reason = 'enough nodes'
                break
            if iter_count >= _args.max_iterations:
                stop_reason = 'max_iterations'
                break
            iter_count += 1
        logger.info(f'stop_reason: {stop_reason}')
        # logger.info(f"rollout_correct_answers: {rollout_correct_answers}")
        # logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}")

        monte_carlo_tree = _root.collect()
        result = {
            'query': query,
            'ground_truth': ground_truth,
            'rollout_correct_answers': rollout_correct_answers,
            'rollout_incorrect_answers': rollout_incorrect_answers,
            'monte_carlo_tree': monte_carlo_tree,
        }
        result_json = json.dumps(result, ensure_ascii=False)
        logger.info(result_json)
        return result_json

    def do_sample(self, data):
        if not isinstance(data, list):
            data = [data]
        generated = []
        for item in data:
            logger.info(f'time: {time.ctime(time.time())}')
            try:
                messages = item['messages'][0]
                query = messages[0]['content']
                ground_truth = messages[1]['content']
                generated.append(self.search_single(query, ground_truth) + '\n')
            except Exception as e:
                logger.error(f'Error: {e}')
                logger.error(f'Traceback: {traceback.format_exc()}')
        return generated
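The selection rule in `_uct` above is ordinary UCT with a process/outcome-reward mixture as the exploitation term. A standalone restatement for checking the arithmetic; `alpha` maps to `process_reward_rate` and `c` to `exploration_rate`, and the default values below are placeholders, not the framework's defaults:

import math


def uct(prm, orm, visits, parent_visits, alpha=0.5, c=1.4):
    # exploitation: alpha * process_reward + (1 - alpha) * outcome_reward
    value = alpha * prm + (1 - alpha) * orm
    # exploration: c * sqrt(ln(parent_visits + 1) / (visits + 1))
    return value + c * math.sqrt(math.log(parent_visits + 1) / (visits + 1))


print(uct(0.6, 0.8, visits=3, parent_visits=10))
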
swift/llm/sampling/sampling.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import time
from typing import List, Union

import json

from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
from swift.utils import get_logger

logger = get_logger()


class SwiftSampling(SwiftPipeline):
    args_class = SamplingArguments
    args: args_class

    def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
        super().__init__(args)
        self.args.save_args()
        os.makedirs(self.args.output_dir, exist_ok=True)
        self.cur_piece = 0
        self.total_piece = 1

        if self.args.data_range:
            self.cur_piece, self.total_piece = self.args.data_range

        if self.args.sampler_type == 'sample':
            from swift.llm.sampling.vanilla_sampler import VanillaSampler
            self.sampler = VanillaSampler(self.args)
        elif self.args.sampler_type == 'mcts':
            from swift.llm.sampling.mcts import MctsSampler
            self.sampler = MctsSampler(self.args)
        elif self.args.sampler_type == 'distill':
            from swift.llm.sampling.distill_sampler import DistillSampler
            self.sampler = DistillSampler(self.args)
        else:
            raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')

    def _get_dataset(self):
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        sampling_dataset, _ = load_dataset(
            args.dataset, split_dataset_ratio=0., shuffle=args.dataset_shuffle, **dataset_kwargs)
        logger.info(f'Sampling_dataset: {sampling_dataset}')
        dataset_len = len(sampling_dataset)
        piece_len = dataset_len // self.total_piece
        sampling_dataset = sampling_dataset.select(range(piece_len * self.cur_piece, piece_len * (self.cur_piece + 1)))
        return sampling_dataset

    def run(self):
        os.makedirs(self.args.output_dir, exist_ok=True)
        iter_file = os.path.join(self.args.output_dir, self.args.output_file)
        resume_file = os.path.join(self.args.output_dir, self.args.output_file + '.resume')
        tmp_file = os.path.join(self.args.output_dir, self.args.output_file + '.tmp')
        ckpt_state_file = os.path.join(self.args.output_dir, 'ckpt_state.json')
        if os.path.exists(iter_file) and not self.args.override_exist_file:
            return

        index_resume = -1
        write_mode = 'w'
        if self.args.resume:
            write_mode = 'a'
            if os.path.exists(resume_file):
                shutil.copyfile(resume_file, tmp_file)

            if os.path.exists(ckpt_state_file):
                with open(ckpt_state_file, 'r') as ckpt_state:
                    data = json.load(ckpt_state)
                    index_resume = data.get('index', -1)
                logger.info(f'Loaded index_resume: {index_resume}')
        else:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

        dataset = self._get_dataset()
        dataset_len = len(dataset)
        total_iters = int(dataset_len // self.args.num_sampling_per_gpu_batch_size)

        if self.args.num_sampling_per_gpu_batches is None or self.args.num_sampling_per_gpu_batches > total_iters:
            self.args.num_sampling_per_gpu_batches = total_iters

        with open(tmp_file, write_mode) as f:
            for _index in range(self.args.num_sampling_per_gpu_batches):
                if _index <= index_resume:
                    continue
                logger.info(f' Sampling index:{_index}')
                slices = dataset[self.args.num_sampling_per_gpu_batch_size
                                 * _index:self.args.num_sampling_per_gpu_batch_size * (_index + 1)]
                slices = self.sampler.truncate_input(slices)
                generated = self.sampler.do_sample(slices)
                f.writelines(generated)
                f.flush()
                shutil.copy(tmp_file, resume_file)
                with open(ckpt_state_file, 'w') as ckpt_state:
                    json.dump({'index': _index}, ckpt_state)

        if os.path.exists(iter_file):
            shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
        shutil.move(resume_file, iter_file)
        logger.info(f'Sample file {iter_file} generated.')


def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
    return SwiftSampling(args).main()
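`run()` above implements a resumable writer: each finished batch is flushed to a `.tmp` file, snapshotted to a `.resume` file, and its index recorded in `ckpt_state.json`, so an interrupted job restarts after the last completed batch. A minimal standalone sketch of the same protocol, assuming nothing about swift itself (`resumable_run`, `process`, and the file names are illustrative):

import json
import os
import shutil


def resumable_run(batches, process, out='samples.jsonl'):
    tmp, resume, state = out + '.tmp', out + '.resume', out + '.state.json'
    start = -1
    resuming = os.path.exists(resume)
    if resuming:
        shutil.copyfile(resume, tmp)               # restore partial output
        if os.path.exists(state):
            with open(state) as f:
                start = json.load(f).get('index', -1)
    with open(tmp, 'a' if resuming else 'w') as f:
        for i, batch in enumerate(batches):
            if i <= start:
                continue                           # finished in a prior run
            f.writelines(process(batch))           # process() yields JSONL lines
            f.flush()
            shutil.copy(tmp, resume)               # snapshot after each batch
            with open(state, 'w') as s:
                json.dump({'index': i}, s)
    if os.path.exists(resume):
        shutil.move(resume, out)                   # promote the final snapshot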
swift/llm/sampling/utils.py ADDED
@@ -0,0 +1,172 @@
import hashlib
import inspect
from copy import copy
from typing import Any, Dict, List, Optional

import json
import numpy as np

from swift.llm import InferRequest, RequestConfig
from swift.utils import get_logger

logger = get_logger()


def get_messages_md5(row: Dict[str, Any]):
    row = copy(row)
    row.pop('choices', None)
    serialized = json.dumps(row, sort_keys=True)
    return hashlib.md5(serialized.encode('utf-8')).hexdigest()


def get_reward(model: Any,
               infer_requests: List[InferRequest],
               request_config: RequestConfig = None,
               ground_truths: List[str] = None,
               threshold: Optional[float] = None):
    """Get reward from an RM model.

    Args:
        model: The model instance or an RM evaluator
        infer_requests: Infer requests sent to the model
        request_config: Infer config
        ground_truths: The ground truth list
        threshold: An optional threshold to generate the mask

    Returns:
        Tuple
        Index 0: The min-max normalized scores matched the infer_requests
        Index 1: The mask filtered by the threshold
    """
    from swift.llm import InferEngine
    infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
    parameters = inspect.signature(infer_func).parameters
    gt_param = {}
    if 'ground_truths' in parameters:
        gt_param = {'ground_truths': ground_truths}
    if isinstance(infer_requests[0], dict):
        infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
    rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
    from swift.llm.infer.protocol import ChatCompletionResponse
    if isinstance(rewards[0], ChatCompletionResponse):
        print('reward:', rewards[0].choices[0].message.content)
        if isinstance(rewards[0].choices[0].message.content, str):
            rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards]
        elif isinstance(rewards[0].choices[0].message.content, list):
            rewards = [float(min(r.choices[0].message.content)) for r in rewards]
        else:
            rewards = [float(r.choices[0].message.content) for r in rewards]
    arr = []
    for reward in rewards:
        if isinstance(reward, (list, tuple)):
            arr.append(min(reward))
        else:
            arr.append(float(reward))

    _mask = np.array([True] * len(arr))
    if threshold is not None:
        # > not >=, orm caller passes 0, which will cause error
        _mask = np.array([a > threshold for a in arr])

    def normalize(arr):
        min_val = np.min(arr)
        max_val = np.max(arr)
        if min_val == max_val:
            if min_val == 0:
                constant_value = 0.0
            else:
                constant_value = min(1.0, min_val)
            return np.full_like(arr, fill_value=constant_value, dtype=np.float64)
        normalized = (arr - min_val) / (max_val - min_val + 1e-5)
        return normalized

    return normalize(arr), _mask


def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
    if isinstance(infer_engines, list):
        assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
        from concurrent.futures import ThreadPoolExecutor, as_completed
        n = len(infer_requests)
        with ThreadPoolExecutor(max_workers=n) as executor:
            futures = {
                executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs):
                i
                for i in range(n)
            }
            responses = []
            for future in as_completed(futures):
                task_id = futures[future]
                try:
                    responses += future.result()
                except Exception as e:
                    logger.info(f'Perform infer task: {task_id} get an error: {e}')
            return responses
    elif isinstance(infer_requests, list):
        responses = []
        if isinstance(request_configs, list):
            assert len(infer_requests) <= len(request_configs)
            for i in range(len(infer_requests)):
                responses += infer_engines.infer(
                    [infer_requests[i]],
                    request_configs[i],
                    **infer_kwargs,
                )
        elif isinstance(request_configs, RequestConfig):
            for infer_request in infer_requests:
                responses += infer_engines.infer(
                    [infer_request],
                    request_configs,
                    **infer_kwargs,
                )
        return responses
    return infer_engines.infer(
        [infer_requests],
        request_configs,
        **infer_kwargs,
    )


def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
    from transformers.utils import strtobool
    if isinstance(monte_carlo_tree, str):
        monte_carlo_tree = json.loads(monte_carlo_tree)

    def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]):
        _prefer_pairs, _correct_answers, _incorrect_answers = [], [], []
        _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']]
        _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']]
        if len(collect_curr_node['children']) > 0:
            for child in collect_curr_node['children']:
                p, c, i = _collect(child, _outcome_rewards, _process_rewards)
                _prefer_pairs += p
                _correct_answers += c
                _incorrect_answers += i
            sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward'])
            if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold:
                # TODO: filter with visit count
                prefer_pair = {
                    'path': 'ки\n'.join(collect_curr_node['path']),
                    'good': sorted_children[-1]['path'][-1],
                    'good_score': sorted_children[-1]['outcome_reward'],
                    'bad': sorted_children[0]['path'][-1],
                    'bad_score': sorted_children[0]['outcome_reward'],
                }
                _prefer_pairs.append(prefer_pair)
        if strtobool(collect_curr_node['terminated']):
            _answer = {
                'answer': 'ки\n'.join(collect_curr_node['path']),
                'mean_outcome_reward': np.mean(_outcome_rewards),
                'min_outcome_reward': np.min(_outcome_rewards),
                'mean_process_reward': np.mean(_process_rewards),
                'min_process_reward': np.min(_process_rewards),
            }
            if strtobool(collect_curr_node['correct']):
                _correct_answers.append(_answer)
            else:
                _incorrect_answers.append(_answer)
        return _prefer_pairs, _correct_answers, _incorrect_answers

    _root = monte_carlo_tree
    prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], [])
    return prefer_pairs, correct_answers, incorrect_answers
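`get_reward` above reduces model outputs to floats, masks them against an optional threshold (strictly greater-than, since the ORM caller passes 0), then min-max normalizes with a small epsilon. A short numeric sketch of that tail end, with toy values standing in for real reward-model outputs:

import numpy as np

arr = [0.2, 0.9, -0.1, 0.9]
mask = np.array([a > 0.0 for a in arr])        # [ True  True False  True]
normalized = (np.array(arr) - np.min(arr)) / (np.max(arr) - np.min(arr) + 1e-5)
print(normalized.round(2))                     # approx. [0.3 1.  0.  1. ]
print(mask)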
swift/llm/sampling/vanilla_sampler.py ADDED
@@ -0,0 +1,216 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Dict, List

import json
import numpy as np

from swift.llm import RequestConfig
from swift.llm.sampling.base import Sampler
from swift.llm.template.template_inputs import InferRequest
from swift.utils import get_logger
from .utils import get_messages_md5, get_reward

logger = get_logger()


class VanillaSampler(Sampler):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            _Engine = None
        else:
            raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
        self.infer_engine = None
        if _Engine:
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
            self.infer_engine.default_template = self.template
            self.infer_engine.strict = False
        self.caches = self.read_cache()

    def read_cache(self):
        cache_files = self.args.cache_files
        caches = {}
        for file in cache_files:
            if not os.path.exists(file):
                logger.warning(f'Cache file does not exist: {file}')
                continue
            with open(file, 'r') as f:
                for line in f.readlines():
                    line = line.strip()
                    if not line:
                        continue

                    content = json.loads(line)
                    uuid = content['id']
                    messages = content['messages']
                    if uuid not in caches:
                        caches[uuid] = {'choices': []}
                    assert messages[-1]['role'] == 'assistant'
                    caches[uuid]['choices'].append(messages[-1]['content'])
        return caches

    @staticmethod
    def convert_data_to_rows(data):
        rows = []
        key = list(data.keys())[0]
        data_len = len(data[key])
        for idx in range(data_len):
            row = {key: data[key][idx] for key in data}
            if row.get('images') and 'bytes' in row['images'][0]:
                row['images'] = [img['path'] for img in row['images']]
            rows.append(row)
        VanillaSampler.check_row_valid(rows)
        return rows

    @staticmethod
    def check_row_valid(rows):
        for row in rows:
            assert not row.get('images') or all([isinstance(img, str) and img for img in row['images']])
            assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']])
            assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']])

    def generate(self, data):
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    continue
            if self.args.system:
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                messages = messages[:-1]

            row['messages'] = messages
            infer_request = row
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
            sent += 1

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                if not isinstance(resp_list[j], Exception):
                    resps['choices'].append(resp_list[j].choices[0].message.content)
            if resps['choices']:
                resp_all.append(resps)
            _cur += 1
        return resp_all

    def do_sample(self, data):
        generated = []
        resp_all = self.generate(data)
        for i, resps in enumerate(resp_all):
            choices = resps['choices']
            messages = resps['messages']
            uuid = get_messages_md5(resps)
            assert messages[-1]['role'] == 'assistant'
            ground_truth = messages[-1]['content']

            infer_requests = []
            for decoded in choices:
                _resps = deepcopy(resps)
                _resps['messages'][-1]['content'] = decoded
                infer_requests.append(_resps)

            _resps = deepcopy(resps)
            _resps['messages'][-1]['content'] = ground_truth
            infer_requests.append(_resps)
            if self.orm_model is not None:
                orm_score, _orm_mask = get_reward(
                    self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0)
            else:
                orm_score = np.array([1.0] * len(infer_requests))
                _orm_mask = np.array([True] * len(infer_requests))
            if self.prm_model is not None:
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=self.args.prm_threshold)
            else:
                prm_score = np.array([1.0] * len(infer_requests))
                _prm_mask = np.array([True] * len(infer_requests))

            _mask = _orm_mask & _prm_mask
            if not any(_mask):
                continue

            choices.append(ground_truth)
            choices = np.array(choices)

            if self.orm_model is None and self.prm_model is None:
                positives = choices[:-1]
                for positive in positives:
                    _resps = deepcopy(resps)
                    _resps.pop('choices', None)
                    _resps['id'] = uuid
                    _resps['messages'][-1]['content'] = str(positive)
                    generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
            else:
                score = np.array(prm_score) + np.array(orm_score * 10)
                sorted_indices = np.argsort(score)[::-1]
                pos_indexes = sorted_indices[0:self.args.n_best_to_keep]
                pos_indexes = [i for i in pos_indexes if _mask[i]]
                neg_index = sorted_indices[-1]
                logger.info(
                    f'orm:{orm_score}, prm:{prm_score}, positive index: {pos_indexes}, negative index: {neg_index}')
                if self.args.easy_query_threshold is not None and sum([score > 0 for score in orm_score]) - 1 >= int(
                        self.args.num_return_sequences * self.args.easy_query_threshold):
                    continue
                if len(pos_indexes) > 0:
                    positives = choices[pos_indexes]
                    negative = choices[neg_index]
                    for positive in positives:
                        _resps = deepcopy(resps)
                        messages = deepcopy(messages)
                        _resps.pop('choices', None)
                        _resps['messages'][-1]['content'] = str(positive)
                        _resps['rejected_response'] = str(negative)
                        _resps['id'] = uuid
                        generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
        return generated
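`do_sample` above appends the ground truth as the final candidate, scores every candidate with the ORM/PRM, and weights the outcome reward ten times the process reward before picking best and worst. A toy illustration of that ranking step (the values are invented for the example):

import numpy as np

# Mirrors `score = np.array(prm_score) + np.array(orm_score * 10)` above:
# the outcome reward dominates, the process reward breaks ties.
orm = np.array([1.0, 0.0, 1.0])    # last entry is the appended ground truth
prm = np.array([0.7, 0.4, 0.9])
score = prm + orm * 10             # [10.7  0.4 10.9]
order = np.argsort(score)[::-1]    # [2 0 1], best to worst
best, worst = order[0], order[-1]  # candidates for chosen / rejected_response
print(best, worst)                 # 2 1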
swift/llm/template/__init__.py ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import template
from .base import MaxLengthError, Template
from .constant import TemplateType
from .grounding import draw_bbox
from .register import TEMPLATE_MAPPING, get_template, get_template_meta, register_template
from .template_inputs import InferRequest, TemplateInputs
from .template_meta import TemplateMeta
from .utils import Prompt, Word, split_str_parts_by
from .vision_utils import load_file, load_image
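For reference, a hypothetical consumer of these re-exports might peek at the registry like the sketch below; the exact contents depend on the installed swift version, and nothing here is asserted beyond the names this file exports:

from swift.llm.template import TEMPLATE_MAPPING, TemplateType

print(len(TEMPLATE_MAPPING))          # number of registered chat templates
print(sorted(TEMPLATE_MAPPING)[:5])   # a few registered template names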
swift/llm/template/base.py ADDED
@@ -0,0 +1,1606 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import hashlib
import inspect
import math
import os
import re
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.hub.utils.utils import get_cache_dir
from peft import PeftModel
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from transformers import StoppingCriteriaList
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import strtobool

from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
from ..utils import Processor, ProcessorMixin
from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by
from .vision_utils import load_audio, load_batch, load_image, rescale_image

logger = get_logger()
if TYPE_CHECKING:
    from .template_meta import TemplateMeta


class MaxLengthError(ValueError):
    pass


class Template(ProcessorMixin):
    special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
    special_keys = ['images', 'videos', 'audios', 'objects']

    image_placeholder = ['<image>']
    video_placeholder = ['<video>']
    audio_placeholder = ['<audio>']
    cot_process_placeholder = ['ки']
    placeholder_tokens = []  # For clearer printing
    load_images = True
    skip_prompt = True
    use_model = False
    norm_bbox = 'norm1000'

    is_encoder_decoder = False

    def __init__(
        self,
        processor: Processor,
        template_meta: 'TemplateMeta',
        default_system: Optional[str] = None,
        max_length: Optional[int] = None,
        *,
        use_chat_template: bool = True,
        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
        max_pixels: Optional[int] = None,
        agent_template: Optional[str] = None,
        norm_bbox: Literal['norm1000', 'none', None] = None,
        response_prefix: Optional[str] = None,
        # only for train
        padding_side: Literal['left', 'right'] = 'right',
        loss_scale: str = 'default',
        sequence_parallel_size: int = 1,
        # infer/deploy
        template_backend: Literal['swift', 'jinja'] = 'swift',
    ) -> None:
        """
        default_system: Override the default_system in the template.
        max_length: Max length of the sequence
        truncation_strategy: The truncation strategy
        max_pixels: Rescale image to reduce memory usage, default `None` means no limitation.
            e.g. 512 * 512 (H*W)
        padding_side: The padding_side when the training batch_size >= 2
        loss_scale: The loss scale function to use
        """
        from .template_meta import TemplateMeta
        from swift.plugin import agent_templates

        self.processor = processor
        self.model_info = processor.model_info
        self.config = self.model_info.config
        self.model_meta = processor.model_meta
        if max_length is None:
            max_length = self.model_info.max_model_len
        tokenizer = self.tokenizer

        if not use_chat_template:
            template_meta = template_meta.to_generate_template_meta()
        else:
            template_meta = deepcopy(template_meta)
        # if default_system is None. not change self.default_system
        template_meta.check_system(default_system)
        if default_system is not None:
            template_meta.default_system = default_system
        if response_prefix is not None:
            template_meta.response_prefix = response_prefix
        logger.info(f'default_system: {repr(template_meta.default_system)}')
        logger.info(f'response_prefix: {repr(template_meta.response_prefix)}')

        for i, token in enumerate(self.placeholder_tokens):
            if isinstance(token, str):
                self.placeholder_tokens[i] = tokenizer.convert_tokens_to_ids(token)
        template_meta.init(tokenizer)

        self.template_meta: TemplateMeta = template_meta
        self.use_chat_template = use_chat_template
        self.template_backend = template_backend
        self.max_length = max_length
        self.truncation_strategy = truncation_strategy
        self.loss_scale = loss_scale
        self.max_pixels = max_pixels
        self.padding_side = padding_side
        self.sequence_parallel_size = sequence_parallel_size
        agent_template = agent_template or template_meta.agent_template
        logger.info(f'agent_template: {agent_template}')
        self.agent_template = agent_templates[agent_template]()
        self.norm_bbox = norm_bbox or self.norm_bbox
        logger.info(f'max_length: {self.max_length}')
        logger.info(f'norm_bbox: {self.norm_bbox}')
        if self.is_encoder_decoder:
            self.skip_prompt = False
        self.mode: Literal['pt', 'vllm', 'lmdeploy',  # infer
                           'train', 'rlhf', 'kto',  # train
                           'seq_cls', 'embedding', 'prm'] = 'pt'
        self._packing = False
        self.use_megatron = False
        if self.model_info.task_type != 'causal_lm':
            self.mode = self.model_info.task_type
        self._handles = []
        self._deepspeed_initialize = None

    @staticmethod
    def _load_image(image, load_images: bool):
        if load_images:
            if isinstance(image, dict) and 'bytes' in image:
                image = image['bytes'] or image['path']
            image = load_image(image)
        else:
            if isinstance(image, dict):
                path = image['path']
                if path and (path.startswith('http') or os.path.exists(path)):
                    image = path
                else:
                    image = load_image(image['bytes'])
            elif not isinstance(image, str):
                image = load_image(image)
        return image

    @staticmethod
    def _get_height_width(inputs: StdTemplateInputs) -> None:
        width = []
        height = []
        for image in inputs.images:
            width.append(image.width)
            height.append(image.height)
        inputs.objects['width'] = width
        inputs.objects['height'] = height

    def normalize_bbox(self, inputs: StdTemplateInputs) -> None:
        objects = inputs.objects
        bbox_list = objects['bbox']
        width_list = objects['width']
        height_list = objects['height']
        bbox_type = objects.pop('bbox_type', None) or 'real'
        image_id_list = objects.pop('image_id', None) or []
        image_id_list += [0] * (len(bbox_list) - len(image_id_list))
        for bbox, image_id in zip(bbox_list, image_id_list):
            if bbox_type == 'norm1':
                width, height = 1, 1
            else:
                width, height = width_list[image_id], height_list[image_id]
            for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])):
                if self.norm_bbox == 'norm1000':
                    norm_width, norm_height = 1000, 1000
                elif self.norm_bbox == 'none':
                    image = inputs.images[image_id]
                    norm_width, norm_height = image.width, image.height
                bbox[2 * i] = int(round(x / width * norm_width))
                bbox[2 * i + 1] = int(round(y / height * norm_height))
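    # --- Editor's note (illustration, not part of base.py): a worked example of
    # the 'norm1000' convention used by normalize_bbox above. Pixel coordinates
    # are projected onto a virtual 1000x1000 grid regardless of image size:
    #     x = 320 on a 640px-wide image  -> int(round(320 / 640 * 1000)) = 500
    #     y = 120 on a 480px-high image  -> int(round(120 / 480 * 1000)) = 250
    # With bbox_type='norm1' the divisors are 1, i.e. inputs are already in 0-1.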
    def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
        agent_template = self.agent_template
        agent_template.template_meta = self.template_meta  # for hermes
        if inputs.tools:
            if isinstance(inputs.tools, str):
                inputs.tools = agent_template._parse_json(inputs.tools)
                if not isinstance(inputs.tools, (list, tuple)):
                    inputs.tools = [inputs.tools]
            elif isinstance(inputs.tools, (list, tuple)):
                inputs.tools = [agent_template._parse_json(tool) for tool in inputs.tools]
            else:
                raise ValueError(f'inputs.tools: {inputs.tools}')
            for i, tool in enumerate(inputs.tools):
                inputs.tools[i] = agent_template.wrap_tool(tool)
        i = 0
        messages = inputs.messages
        while i < len(messages):
            if messages[i]['role'] == 'tool_call':
                i_start = i
                while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool_call':
                    i += 1
                tool_content = self.agent_template._format_tool_calls(messages[i_start:i + 1])
                messages[i_start:i + 1] = [{'role': 'assistant', 'content': tool_content}]
                i = i_start + 1
            else:
                i += 1

    def _preprocess_inputs(
        self,
        inputs: StdTemplateInputs,
    ) -> None:
        self._preprocess_function_call(inputs)
        if self.model_meta.is_multimodal:
            self._replace_image_tags(inputs)
            self._replace_start_image_tags(inputs)
        images = inputs.images
        load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
        load_images_origin = load_images
        if self.max_pixels is not None or inputs.objects:
            load_images = True
        if images:
            for i, image in enumerate(images):
                images[i] = self._load_image(images[i], load_images)
        if inputs.objects:
            self._get_height_width(inputs)
        if self.max_pixels is not None:
            # Scale the image proportionally without affecting the scaled objects.
            images = [rescale_image(img, self.max_pixels) for img in images]
        if images and not load_images_origin:  # fix pt & qwen-vl
            for i, image in enumerate(images):
                if isinstance(image, Image.Image):
                    images[i] = self._save_pil_image(image)
        inputs.images = images

        if self.mode == 'vllm' and inputs.audios:
            sampling_rate = get_env_args('sampling_rate', int, None)
            inputs.audios = load_batch(
                inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))

        if inputs.is_multimodal:
            self._add_default_tags(inputs)

    @staticmethod
    def _replace_image_tags(inputs: StdTemplateInputs):
        # compat
        if inputs.images:
            return
        images = []
        pattern = r'<img>(.+?)</img>'
        for message in inputs.messages:
            content = message['content']
            if not isinstance(content, str):
                continue
            for image in re.findall(pattern, content):
                # only support local_path
                if os.path.isfile(image):
                    images.append(image)
                else:
                    logger.warning_once(f'Failed to parse image path: `{content}`.', hash_id='<img></img>')
            message['content'] = re.sub(pattern, '<image>', content)
        inputs.images = images

    @staticmethod
    def _replace_start_image_tags(inputs: StdTemplateInputs):
        # compat
        generate_mode = False
        message = inputs.messages[-1]
        content = message['content']
        if message['role'] == 'user' and content.endswith('<start-image>'):
            generate_mode = True
            message['content'] = message['content'][:-len('<start-image>')]  # remove the <start-image>
        inputs.generate_mode = generate_mode

    @staticmethod
    def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
                       get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
        added_tokens_len = 0
        for i, idx in enumerate(replace_idx_list):
            new_tokens = get_new_tokens(i)
            token_len = len(new_tokens)
            input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
            if labels:
                labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
            added_tokens_len += token_len - 1
        return input_ids, labels
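    # --- Editor's note (illustration, not part of base.py): _extend_tokens
    # splices a variable-length expansion over each placeholder position while
    # keeping labels aligned. With input_ids=[1, 9, 2], replace_idx_list=[1] and
    # get_new_tokens(0) -> [7, 7, 7]:
    #     input_ids -> [1, 7, 7, 7, 2]
    #     labels    -> [l0, -100, -100, -100, l2]   (expansions carry no loss)
    # added_tokens_len grows by len(new_tokens) - 1 so later indices stay valid.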
    def compute_loss_context(self, model, inputs):
        return nullcontext()

    def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
        assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
        rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
        chosen_encoded = self._encode_truncated(chosen_inputs)
        rejected_encoded = self._encode_truncated(rejected_inputs)

        encoded = {}
        for prefix in ['chosen', 'rejected']:
            data = locals()[f'{prefix}_encoded']
            for k, v in data.items():
                encoded[f'{prefix}_{k}'] = v
        return encoded

    def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        label, inputs.label = inputs.label, None
        encoded = self._rlhf_encode(inputs)
        encoded['label'] = bool(label)
        return encoded

    def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        _encoded = {}
        labels = []

        def split_multi_medias(_inputs):
            _content = _inputs.messages[-2]['content']
            image_size = len(re.findall('<image>', _content))
            video_size = len(re.findall('<video>', _content))
            audio_size = len(re.findall('<audio>', _content))
            _inputs.images = inputs.images[:image_size]
            assert len(_inputs.images) == image_size
            inputs.images = inputs.images[image_size:]
            _inputs.videos = inputs.videos[:video_size]
            assert len(_inputs.videos) == video_size
            inputs.videos = inputs.videos[video_size:]
            _inputs.audios = inputs.audios[:audio_size]
            assert len(_inputs.audios) == audio_size
            inputs.audios = inputs.audios[audio_size:]

        anchor = deepcopy(inputs)
        anchor.messages[-1]['content'] = ''
        anchor.rejected_response = []
        split_multi_medias(anchor)
        anchor_encoded = self._encode_truncated(anchor)
        for key in anchor_encoded:
            _encoded[f'anchor_{key}'] = anchor_encoded[key]

        positive = deepcopy(inputs)
        positive.messages[-2]['content'] = positive.messages[-1]['content']
        positive.messages[-1]['content'] = ''
        positive.rejected_response = []
        split_multi_medias(positive)
        positive_encoded = self._encode_truncated(positive)
        for key in positive_encoded:
            _encoded[f'positive_{key}'] = positive_encoded[key]
        labels.append(float(inputs.label) if inputs.label is not None else 1.0)

        rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
        for i in range(rejected_len):
            negative = deepcopy(inputs)
            negative.messages[-2]['content'] = negative.rejected_response[i]
            negative.messages[-1]['content'] = ''
            negative.rejected_response = []
            split_multi_medias(negative)
            negative_encoded = self._encode_truncated(negative)
            for key in negative_encoded:
                _encoded[f'negative{i}_{key}'] = negative_encoded[key]
            labels.append(0.0)

        _encoded['labels'] = labels
        return _encoded

    def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = self._encode_truncated(inputs)
        encoded.pop('labels', None)
        if inputs.label is not None:
            labels = inputs.label
            problem_type = self._get_problem_type(self.config, labels=labels)
            if problem_type == 'single_label_classification':
                labels = int(labels)
            encoded['labels'] = labels
        return encoded

    @torch.inference_mode()
    def encode(self,
               inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
               return_template_inputs: bool = False) -> Dict[str, Any]:
        """The entrance method of Template!

        Returns:
            return {'input_ids': List[int], 'labels': Optional[List[int]], ...}
        """
        if isinstance(inputs, (InferRequest, TemplateInputs)):
            inputs = asdict(inputs)

        if isinstance(inputs, dict):
            inputs = deepcopy(inputs)
            if not self.is_training:
                InferRequest.remove_response(inputs['messages'])
            inputs = StdTemplateInputs.from_dict(inputs)
        elif isinstance(inputs, StdTemplateInputs):
            inputs = deepcopy(inputs)
        assert isinstance(inputs, StdTemplateInputs)
        self._preprocess_inputs(inputs)

        if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}:
            encoded = self._encode_truncated(inputs)
        elif self.mode == 'seq_cls':
            encoded = self._seq_cls_encode(inputs)
        elif self.mode == 'rlhf':
            encoded = self._rlhf_encode(inputs)
        elif self.mode == 'kto':
            encoded = self._kto_encode(inputs)
        elif self.mode == 'embedding':
            encoded = self._embedding_encode(inputs)
        for key in list(encoded.keys()):
            if encoded[key] is None:
                encoded.pop(key)
        if return_template_inputs:
            encoded['template_inputs'] = inputs
        return encoded
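    # --- Editor's note (illustration, not part of base.py): the shape of
    # encode()'s output depends on self.mode. In 'rlhf' mode the chosen and
    # rejected encodings are flattened into one dict by key prefixing, roughly:
    #     {'chosen_input_ids': [...], 'chosen_labels': [...],
    #      'rejected_input_ids': [...], 'rejected_labels': [...]}
    # and 'kto' mode adds a boolean 'label' on top of the same structure.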
    def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
        packed = {}
        keys = set()
        for r in row:
            keys.update(r[0].keys())
        for key in keys:
            if key in {'input_ids', 'labels', 'loss_scale'}:
                packed[key] = sum((x[0][key] for x in row), start=[])
        if 'position_ids' not in packed:
            packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])

        packed.update(self._data_collator_mm_data([r[0] for r in row]))
        return packed

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

    @staticmethod
    def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
        len_tokens = len(stop_tokens)
        if is_finished and generate_ids[-len_tokens:] == stop_tokens:
            return generate_ids[:-len_tokens]
        if not is_finished:
            for i in range(len_tokens, 0, -1):
                if generate_ids[-i:] == stop_tokens[:i]:
                    return generate_ids[:-i]
        return generate_ids
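    # --- Editor's note (illustration, not part of base.py): during streaming
    # (is_finished=False) _skip_stop_tokens also withholds a *prefix* of the
    # stop sequence at the tail, since the next chunk may complete it. With
    # stop_tokens=[5, 6]:
    #     generate_ids=[..., 5, 6], finished  -> [...]      (full match stripped)
    #     generate_ids=[..., 5], streaming    -> [...]      (possible prefix held back)
    #     generate_ids=[..., 5], finished     -> [..., 5]   (no full match: kept)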
| 448 |
+
@staticmethod
|
| 449 |
+
def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int):
|
| 450 |
+
idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist()
|
| 451 |
+
logprobs = logprobs.tolist()
|
| 452 |
+
return {
|
| 453 |
+
'content': [{
|
| 454 |
+
'index': pred,
|
| 455 |
+
'logprobs': [logprobs[p] for p in pred] if isinstance(pred, (list, tuple)) else logprobs[pred],
|
| 456 |
+
'top_logprobs': [{
|
| 457 |
+
'index': idx,
|
| 458 |
+
'logprob': logprobs[idx]
|
| 459 |
+
} for idx in idxs]
|
| 460 |
+
}]
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
@staticmethod
|
| 464 |
+
def _get_problem_type(config, labels=None, logits=None) -> str:
|
| 465 |
+
problem_type = config.problem_type
|
| 466 |
+
if problem_type is not None:
|
| 467 |
+
return problem_type
|
| 468 |
+
if labels is not None:
|
| 469 |
+
if isinstance(labels, (list, tuple)):
|
| 470 |
+
if labels and isinstance(labels[0], float):
|
| 471 |
+
problem_type = 'regression'
|
| 472 |
+
else:
|
| 473 |
+
problem_type = 'multi_label_classification'
|
| 474 |
+
else:
|
| 475 |
+
problem_type = 'single_label_classification'
|
| 476 |
+
assert config.num_labels >= labels + 1
|
| 477 |
+
if logits is not None:
|
| 478 |
+
if logits.shape[-1] == 1:
|
| 479 |
+
problem_type = 'regression'
|
| 480 |
+
else:
|
| 481 |
+
problem_type = 'single_label_classification' # compatible with older versions
|
| 482 |
+
assert problem_type is not None
|
| 483 |
+
config.problem_type = problem_type
|
| 484 |
+
return problem_type
|
| 485 |
+
|
| 486 |
+
def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int):
|
| 487 |
+
assert isinstance(logits, torch.Tensor)
|
| 488 |
+
problem_type = self._get_problem_type(self.config, logits=logits)
|
| 489 |
+
if problem_type == 'regression':
|
| 490 |
+
preds = logits.squeeze(dim=-1).tolist()
|
| 491 |
+
logprobs = [None] * len(preds)
|
| 492 |
+
else:
|
| 493 |
+
if problem_type == 'single_label_classification':
|
| 494 |
+
preds = torch.argmax(logits, dim=-1).tolist()
|
| 495 |
+
logprobs = torch.log_softmax(logits, -1)
|
| 496 |
+
else:
|
| 497 |
+
preds = [(logprob >= 0.5).nonzero(as_tuple=True)[0].tolist() for logprob in torch.sigmoid(logits)]
|
| 498 |
+
logprobs = F.logsigmoid(logits)
|
| 499 |
+
logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)]
|
| 500 |
+
return preds, logprobs
|
| 501 |
+
|
| 502 |
+
def decode(self,
|
| 503 |
+
generate_ids: List[int],
|
| 504 |
+
*,
|
| 505 |
+
is_finished: bool = True,
|
| 506 |
+
tokenizer_kwargs=None,
|
| 507 |
+
first_token=True,
|
| 508 |
+
**kwargs) -> Any:
|
| 509 |
+
tokenizer_kwargs = tokenizer_kwargs or {}
|
| 510 |
+
response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs)
|
| 511 |
+
if first_token and self.template_meta.response_prefix:
|
| 512 |
+
response = self.template_meta.response_prefix + response
|
| 513 |
+
return response
|
| 514 |
+
|
| 515 |
+
def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
|
| 516 |
+
raise NotImplementedError
|
| 517 |
+
|
| 518 |
+
def generate(self, model, *args, **kwargs):
|
| 519 |
+
if isinstance(model, PeftModel):
|
| 520 |
+
signature = inspect.signature(model.model.generate)
|
| 521 |
+
else:
|
| 522 |
+
signature = inspect.signature(model.generate)
|
| 523 |
+
if 'use_model_defaults' in signature.parameters and 'use_model_defaults' not in kwargs:
|
| 524 |
+
kwargs['use_model_defaults'] = False
|
| 525 |
+
return model.generate(*args, **kwargs)
|
| 526 |
+
|
| 527 |
+
def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
|
| 528 |
+
# Do not print template_meta.suffix[-1] and eos_token.
|
| 529 |
+
# However, other stop_words will be printed.
|
| 530 |
+
tokenizer = self.tokenizer
|
| 531 |
+
|
| 532 |
+
if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
|
| 533 |
+
generate_ids = generate_ids[:-1]
|
| 534 |
+
# skip suffix and eos_token
|
| 535 |
+
template_suffix = self.template_meta.suffix[-1]
|
| 536 |
+
if isinstance(template_suffix, str):
|
| 537 |
+
# [-1:]: fix OpenGVLab/Mini-InternVL-Chat-4B-V1-5
|
| 538 |
+
template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:]
|
| 539 |
+
generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
|
| 540 |
+
if 'spaces_between_special_tokens' not in decode_kwargs:
|
| 541 |
+
decode_kwargs['spaces_between_special_tokens'] = False
|
| 542 |
+
return tokenizer.decode(generate_ids, **decode_kwargs)
|
| 543 |
+
|
| 544 |
+
def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
|
| 545 |
+
generation_config = generate_kwargs['generation_config']
|
| 546 |
+
stop_words = getattr(generation_config, 'stop_words', None) or self.template_meta.stop_words
|
| 547 |
+
generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, stop_words)])
|
| 548 |
+
return generate_kwargs
|
| 549 |
+
|
| 550 |
+

    @staticmethod
    def _save_pil_image(image: Image.Image) -> str:
        img_bytes = image.tobytes()
        img_hash = hashlib.sha256(img_bytes).hexdigest()
        tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
        logger.info_once(f'create tmp_dir: {tmp_dir}')
        os.makedirs(tmp_dir, exist_ok=True)
        img_path = os.path.join(tmp_dir, f'{img_hash}.png')
        if not os.path.exists(img_path):
            image.save(img_path)
        return img_path

    @staticmethod
    def _concat_context_list(
            context_list: List[Context],
            res_context_list: List[Context],  # inplace
            res_context_type: List[ContextType],  # inplace
            system: Optional[str] = None,
            query: Optional[str] = None,
            response: Optional[str] = None,
            round0: Optional[int] = None) -> None:
        """Concat context list and replace placeholder"""
        round1 = None
        if round0 is not None:
            round1 = str(round0 + 1)
            round0 = str(round0)
        for context in context_list:
            if isinstance(context, str):
                if '{{RESPONSE}}' == context:
                    assert response is not None
                    res_context_list.append(response)
                    res_context_type.append(ContextType.RESPONSE)
                    continue
                old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
                new_str_list = [system, query, round0, round1]
                for (old_str, new_str) in zip(old_str_list, new_str_list):
                    if new_str is not None and old_str in context:
                        assert isinstance(new_str, str), f'new_str: {new_str}'
                        context = context.replace(old_str, new_str)
            if len(context) == 0:
                continue
            res_context_list.append(context)
            res_context_type.append(ContextType.OTHER)
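An illustrative walk-through of the placeholder substitution above, assuming a chatml-style prompt (all values invented):

res_context_list, res_context_type = [], []
Template._concat_context_list(
    ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n', '{{RESPONSE}}'],
    res_context_list, res_context_type, query='hi', response='hello')
# res_context_list == ['<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n', 'hello']
# res_context_type == [ContextType.OTHER, ContextType.RESPONSE]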

    def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
                               inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        """Merge anything in the context to simplify the inputs"""
        context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
        context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)

        res: List[Context] = []  # result of context_list
        res_loss_scale: List[float] = []  # result of loss_scale_list
        temp: List[str] = []
        temp_loss_scale = 0.
        for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
            if isinstance(context, str) and (loss_scale == temp_loss_scale):
                temp.append(context)
            else:
                if len(temp) > 0:
                    res.append(''.join(temp))
                    res_loss_scale.append(temp_loss_scale)
                    temp.clear()
                if isinstance(context, str):  # loss_scale diff
                    temp.append(context)
                else:
                    res.append(context)
                    res_loss_scale.append(loss_scale)
                temp_loss_scale = loss_scale
        if len(temp) > 0:
            res.append(''.join(temp))
            res_loss_scale.append(temp_loss_scale)

        return res, res_loss_scale

    @staticmethod
    def _split_special_tokens(context_list: List[Context],
                              loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
        """Split special tokens, for example `<image>`, `<video>`, this will help the replace_tag operation"""
        res: List[Context] = []
        loss_scale_res: List[float] = []
        for context, loss_scale in zip(context_list, loss_scale_list):
            contexts = []
            if isinstance(fetch_one(context), str):
                for d in split_str_parts_by(context, Template.special_tokens):
                    contexts.extend([d['key'], d['content']])
                contexts = [c for c in contexts if c]
                res.extend(contexts)
                loss_scale_res.extend([loss_scale] * len(contexts))
            else:
                res.append(context)
                loss_scale_res.append(loss_scale)
        return res, loss_scale_res

    def _tokenize(self, context, **tokenizer_kwargs):
        return self.tokenizer(
            context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']
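Roughly, the split isolates each special tag so that replace_tag can later act on it in isolation (example invented):

contexts, scales = Template._split_special_tokens(['Compare <image> and <image>.'], [1.])
# contexts == ['Compare ', '<image>', ' and ', '<image>', '.']
# scales   == [1., 1., 1., 1., 1.]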

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        """Override this function to do your own replace operation.

        This method is used to replace standard tags like `<image>` to some tokens that the model needs.

        Args:
            media_type: The modal.
            index: The index of the medias, for index 0 represents the first elements in `images`
            inputs: The inputs

        Returns:
            The content or input_ids after replacement.
        """
        if media_type == 'image':
            if self.mode == 'lmdeploy':
                return [[-100]]
            return self.image_placeholder
        elif media_type == 'video':
            return self.video_placeholder
        elif media_type == 'audio':
            return self.audio_placeholder

    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
        """Replace objects referenced by the bbox to contents or input_ids. This is useful in the grounding task.
        Override this function to do your own replace operation.

        Args:
            ref: Description of the bbox
            index: The index in the `objects` key
            inputs: The inputs

        Returns:
            The contents or input_ids replaced
        """
        return [ref]

    def replace_cot_process(self, inputs: StdTemplateInputs) -> List[Context]:
        """Replace the cot process label for PRM training or inference.
        Override this function to do your own replace operation.

        Args:
            inputs: The inputs

        Returns:
            The contents or input_ids replaced
        """
        return [self.cot_process_placeholder]

    @staticmethod
    def _get_bbox_str(bbox: List[int]) -> str:
        point = []
        for x, y in zip(bbox[::2], bbox[1::2]):
            point.append(f'({x},{y})')
        return ','.join(point)

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        """Replace bbox pointing to the objects to contents or input_ids. This is useful in the grounding task.
        Override this function to do your own replace operation.

        Args:
            bbox: [x, y] or [x1, y1, x2, y2]
            index: The index in the `objects` key
            inputs: The inputs

        Returns:
            The contents or input_ids replaced
        """
        return [f'[{self._get_bbox_str(bbox)}]']
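For reference, the default bbox rendering above joins coordinate pairs and wraps them in brackets:

Template._get_bbox_str([10, 20, 110, 220])  # -> '(10,20),(110,220)'
# replace_bbox then yields ['[(10,20),(110,220)]']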

    def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: List[float],
                             inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        # https://github.com/modelscope/ms-swift/issues/3407
        # Fix the bounding box position offset issue in the Qwen2.5-VL grounding task.
        res: List[Context] = []
        res_loss_scale: List[float] = []
        inputs.image_idx = 0

        for context, loss_scale in zip(context_list, loss_scale_list):
            if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
                c_list = self.replace_tag('image', inputs.image_idx, inputs)
                inputs.image_idx += 1
                loss_scale = 0. if self.template_backend == 'swift' else 1.
            else:
                c_list = [context]
            res += c_list
            res_loss_scale += [loss_scale] * len(c_list)
        return res, res_loss_scale

    def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
                      inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        """This method happens before tokenization, replace standard tags to the contents or input_ids needed by
        the model.

        Args:
            context_list: The content list
            loss_scale_list: The loss scale list
        Returns:
            The context_list and loss_scale_list after replacement.
        """
        context_list, loss_scale_list = self._pre_tokenize_images(context_list, loss_scale_list, inputs)
        if inputs.images and inputs.objects:
            self.normalize_bbox(inputs)
        # replace tag/object/box
        res: List[Context] = []  # result of context_list
        res_loss_scale: List[float] = []  # result of loss_scale_list

        # reset
        for k in ['video', 'audio', 'object', 'box']:
            setattr(inputs, f'{k}_idx', 0)

        for context, loss_scale in zip(context_list, loss_scale_list):
            for k in ['video', 'audio']:
                if context == f'<{k}>' and inputs.is_multimodal and getattr(inputs, f'{k}_idx') < len(
                        getattr(inputs, f'{k}s')):
                    c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
                    setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
                    loss_scale = 0.
                    break
            else:
                ref = inputs.objects.get('ref') or []
                bbox = inputs.objects.get('bbox') or []
                if context == '<ref-object>' and inputs.ref_idx < len(ref):
                    idx = inputs.ref_idx
                    c_list = self.replace_ref(ref[idx], idx, inputs)
                    inputs.ref_idx += 1
                elif context == '<bbox>' and inputs.bbox_idx < len(bbox):
                    idx = inputs.bbox_idx
                    c_list = self.replace_bbox(bbox[idx], idx, inputs)
                    inputs.bbox_idx += 1
                elif context == '<cot-process>' and self.mode == 'prm':
                    c_list = self.replace_cot_process(inputs)
                else:
                    c_list = [context]
            res += c_list
            res_loss_scale += [loss_scale] * len(c_list)
        return res, res_loss_scale
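A sketch (values invented) of the end-to-end effect of _pre_tokenize on a grounding-style context list, using the default replace_ref/replace_bbox defined above:

# in:  ['Find ', '<ref-object>', ' at ', '<bbox>', '\n']
# with inputs.objects == {'ref': ['the cat'], 'bbox': [[10, 20, 110, 220]]}
# out: ['Find ', 'the cat', ' at ', '[(10,20),(110,220)]', '\n']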

    @staticmethod
    def _add_default_tags(inputs: StdTemplateInputs):
        total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
        if inputs.rejected_response:
            if isinstance(inputs.rejected_response, str):
                total_content += inputs.rejected_response
            else:
                total_content += '\n'.join(inputs.rejected_response)
        if inputs.system:
            total_content = f'{inputs.system}\n{total_content}'
        for media_type in ['image', 'audio', 'video']:
            media_key, media_tag = f'{media_type}s', f'<{media_type}>'
            medias = getattr(inputs, media_key)
            if not isinstance(medias, list):
                medias = [medias]
            if medias:
                num_media_tags = len(re.findall(media_tag, total_content))
                num_media = len(medias)
                num_new_tags = num_media - num_media_tags
                if num_new_tags > 0:
                    inputs.messages[0]['content'] = media_tag * num_new_tags + inputs.messages[0]['content']
                elif num_new_tags < 0:
                    logger.warning(
                        f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
                        'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')

    def _encode_context_list(
            self,
            context_list: List[Context],
            loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
        """return: input_ids, labels, loss_scale, tokenizer_kwargs"""
        input_ids: List[int] = []
        labels: List[int] = []
        loss_scale: List[float] = []
        tokenizer_kwargs = {}
        if loss_scale_list is None:
            loss_scale_list = [0.] * len(context_list)
        ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
        for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
            if isinstance(context, str):
                # tokenizer_kwargs is the returned tokenizer_kwargs,
                # while curr_tokenizer_kwargs is the tokenizer_kwargs for the current context.
                token_list = self._tokenize(context)
            else:
                token_list = context
            input_ids += token_list
            if loss_scale_list[i] > 0.0:
                labels += token_list
            else:
                labels += [-100] * len(token_list)
            if not ignore_loss_scale:
                loss_scale.extend([loss_weight] * len(token_list))
        if ignore_loss_scale:
            loss_scale = None
        return input_ids, labels, loss_scale, tokenizer_kwargs
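A toy trace (token ids invented) of how loss_scale drives label masking in _encode_context_list:

# contexts:   ['user: hi\n',  'hello']   tokenized as [[101, 102], [201, 202]]
# loss_scale: [0.,            1.]
# input_ids = [101, 102, 201, 202]
# labels    = [-100, -100, 201, 202]   # only spans with positive loss_scale are supervised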

    @staticmethod
    def _add_dynamic_eos(input_ids: List[int], labels: List[int], loss_scale: Optional[List[int]],
                         suffix_tokens_id: List[int]) -> None:
        suffix_len = len(suffix_tokens_id)
        start = 0
        for i in range(1, len(labels)):
            if labels[i - 1] >= 0 and labels[i] == -100:
                start = i
            if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
                # [0, 1, 2, -100(start), -100, 3(i), 4]
                length = i - start
                if length >= suffix_len and input_ids[start:start + suffix_len] == suffix_tokens_id:
                    labels[start:start + suffix_len] = suffix_tokens_id
                    if loss_scale and loss_scale[start:start + suffix_len] == [0] * suffix_len:
                        loss_scale[start:start + suffix_len] = [1] * suffix_len

    @staticmethod
    def _get_std_messages(messages):
        if messages and messages[0]['role'] == 'assistant':
            messages.insert(0, {'role': 'user', 'content': ''})  # pretrain
        if len(messages) % 2 == 1:
            messages.append({'role': 'assistant', 'content': None})  # inference
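Two quick examples of the in-place normalization performed by _get_std_messages:

messages = [{'role': 'assistant', 'content': 'raw pretrain text'}]
Template._get_std_messages(messages)
# -> [{'role': 'user', 'content': ''}, {'role': 'assistant', 'content': 'raw pretrain text'}]

messages = [{'role': 'user', 'content': 'hi'}]
Template._get_std_messages(messages)
# -> [{'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': None}]  # inference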

    def _jinja_encode(self, inputs: StdTemplateInputs):
        messages = inputs.messages.copy()
        if inputs.system is not None:
            messages.insert(0, {'role': 'system', 'content': inputs.system})
        if messages[-1]['content'] is None:
            messages.pop()
        add_generation_prompt = messages[-1]['role'] != 'assistant'
        kwargs = {}
        if inputs.tools:
            kwargs['tools'] = inputs.tools
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
        answer_len = 1 if self.is_training else 0
        return [text], [1.], answer_len

    def _get_system(self, inputs) -> Optional[str]:
        template_meta = self.template_meta
        system = inputs.system
        tools = inputs.tools
        template_meta.check_system(system)
        if system is None:
            system = template_meta.default_system

        if tools is not None:
            system = self.agent_template._format_tools(tools, system or '', inputs.messages[0])
        return system

    @staticmethod
    def _swift_prepare_function_call(agent_template, messages):
        if len(messages) < 2:
            return
        i = 1
        while i < len(messages):
            pre_message, message = messages[i - 1], messages[i]
            pre_role, pre_content = pre_message['role'], pre_message['content']
            role, content = message['role'], message['content']
            if pre_role == 'assistant' and role == 'tool':
                i_start = i
                while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
                    i += 1
                pre_message['content'], tool_content = agent_template._format_tool_responses(
                    pre_content, messages[i_start:i + 1])
                messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
                i = i_start + 1
            elif pre_role == 'assistant' and role == 'assistant':
                # Consecutive messages from the assistant role need to be merged to prevent errors.
                pre_message['content'] = pre_content + content
                messages.pop(i)
            else:
                i += 1
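For instance, the consecutive-assistant branch merges split responses in place (example invented; the tool branch additionally needs a real agent_template, which this path never touches):

messages = [{'role': 'user', 'content': 'q'},
            {'role': 'assistant', 'content': 'part1 '},
            {'role': 'assistant', 'content': 'part2'}]
Template._swift_prepare_function_call(None, messages)
# -> [{'role': 'user', 'content': 'q'}, {'role': 'assistant', 'content': 'part1 part2'}]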

    def _swift_encode(self, inputs: StdTemplateInputs):
        template_meta = self.template_meta
        system = self._get_system(inputs)
        self._swift_prepare_function_call(self.agent_template, inputs.messages)

        self._get_std_messages(inputs.messages)
        n_round = len(inputs.messages) // 2
        if n_round > 1 and not self.template_meta.support_multi_round:
            logger.warning_once(
                'The template does not support multi-round chat. Only use the last round of the conversation.')
            inputs.messages = inputs.messages[-2:]

        res_context_list: List[Context] = []
        res_context_types: List[ContextType] = []
        sep_token = None
        if template_meta.auto_add_bos:
            all_tokens = self.tokenizer.encode('a')
            single_token = self.tokenizer.encode('a', add_special_tokens=False)
            assert len(single_token) == 1
            idx = all_tokens.index(single_token[0])
            bos_token = all_tokens[:idx]
            sep_token = all_tokens[idx + 1:]
            if bos_token:
                res_context_list.append(bos_token)
                res_context_types.append(ContextType.OTHER)

        prefix = template_meta.system_prefix if system else template_meta.prefix
        self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

        n_round = len(inputs.messages) // 2
        for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
            query_role, query = query_message['role'], query_message['content']
            response_role, response = response_message['role'], response_message['content']
            # TODO: Optimize the Template mechanism.
            assert query_role in {'user', 'tool'}, f'query_role: {query_role}'
            assert response_role in {'assistant'}, f'response_role: {response_role}'
            if query_role == 'tool':
                prompt = query
                query = ''
            elif template_meta.is_post_system and i == n_round - 1:
                prompt = template_meta.system_prompt
            else:
                prompt = template_meta.prompt

            context_list = prompt.copy()
            extra_context_list = []
            extra_context_type = None
            if i < n_round - 1:
                # Not the last round.
                context_list.append('{{RESPONSE}}')
                if inputs.messages[2 * (i + 1)]['role'] != 'tool':
                    extra_context_list = template_meta.chat_sep
                    extra_context_type = ContextType.OTHER
            elif response is not None:
                # It is the final round, and the response exists (during training).
                context_list.append('{{RESPONSE}}')
                if self.is_training and not sep_token:
                    extra_context_list = template_meta.suffix
                    extra_context_type = ContextType.SUFFIX
            elif template_meta.response_prefix:
                # final round and during inference.
                context_list.append(template_meta.response_prefix)

            self._concat_context_list(
                context_list,
                res_context_list,
                res_context_types,
                query=query,
                response=response,
                system=system,
                round0=i)
            res_context_list += extra_context_list
            res_context_types += [extra_context_type] * len(extra_context_list)
        if template_meta.auto_add_bos and sep_token:
            res_context_list.append(sep_token)
            res_context_types.append(ContextType.SUFFIX)
        from swift.plugin import loss_scale_map
        res_context_list, loss_scale_list = loss_scale_map[self.loss_scale](res_context_list, res_context_types,
                                                                            inputs.messages)
        if self.is_training:
            answer_len = len(extra_context_list) + bool(response is not None)
        else:
            answer_len = 0
        return res_context_list, loss_scale_list, answer_len

    def _encode_truncated(self, inputs):
        if self.mode in {'vllm', 'lmdeploy'}:
            encoded = Template._encode(self, inputs)
            for key in ['images', 'audios', 'videos']:
                encoded[key] = getattr(inputs, key)
        else:
            encoded = self._encode(inputs)

        input_ids = encoded.get('input_ids')
        labels = encoded.get('labels')
        loss_scale = encoded.get('loss_scale')
        if self.max_length is not None:
            if self.truncation_strategy == 'right':
                input_ids = input_ids[:self.max_length]
                if labels is not None:
                    labels = labels[:self.max_length]
                if loss_scale is not None:
                    loss_scale = loss_scale[:self.max_length]
            elif self.truncation_strategy == 'left':
                if len(input_ids) > self.max_length:
                    logger.warning_once(
                        'Input data was left-truncated because its length exceeds `max_length` (input length: '
                        f'{len(input_ids)}, max_length: {self.max_length}). '
                        'This may cause loss of important tokens (e.g., image tokens) and lead to errors. '
                        'To avoid this, consider increasing `max_length` or pre-filtering long sequences.',
                        hash_id='max_length_check')
                input_ids = input_ids[-self.max_length:]
                if labels is not None:
                    labels = labels[-self.max_length:]
                if loss_scale is not None:
                    loss_scale = loss_scale[-self.max_length:]
            elif self.truncation_strategy == 'raise':
                length = len(input_ids or labels or [])
                if length > self.max_length:
                    raise MaxLengthError(f'Current length of row({length}) is larger'
                                         f' than the max_length({self.max_length}).')
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded['loss_scale'] = loss_scale
        return encoded

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        template_backend = self.template_backend
        if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
                and self.mode != 'seq_cls'):
            template_backend = 'jinja'
            logger.info_once(f'Setting template_backend: {template_backend}')
        res_context_list, loss_scale_list, answer_len = (
            self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
        encoded = {}
        if self.is_encoder_decoder:
            # tokenizer_kwargs: use prompt (qwen-audio)
            total_len = len(res_context_list)
            for key, _slice in zip(['prompt', 'answer'],
                                   [slice(0, total_len - answer_len),
                                    slice(total_len - answer_len, total_len)]):
                context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
                                                                       loss_scale_list[_slice], inputs)
                input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
                encoded[f'{key}_input_ids'] = input_ids
                if key == 'answer':
                    encoded['labels'] = labels
                    encoded['loss_scale'] = loss_scale
            input_ids = encoded['prompt_input_ids'] + encoded['answer_input_ids']
        else:
            res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
            input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
                res_context_list, loss_scale_list)
            self._add_dynamic_eos(input_ids, labels, loss_scale,
                                  self._encode_context_list(self.template_meta.suffix)[0])

        if tokenizer_kwargs:
            encoded['tokenizer_kwargs'] = tokenizer_kwargs

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded['loss_scale'] = loss_scale
        if self.use_megatron:
            self._handle_megatron_cp(encoded)
            encoded['labels'] = encoded['labels'][1:] + [-100]
            encoded['position_ids'] = list(range(len(encoded['labels'])))
        elif encoded.get('labels') is not None:
            encoded['labels'][0] = -100
        if not self.is_training:
            for k in list(encoded.keys()):
                if k.endswith('labels') or k.endswith('loss_scale'):
                    encoded[k] = None
        return encoded

    def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None:
        cp_size = self.sequence_parallel_size
        if cp_size == 1:
            return
        input_ids = encoded['input_ids']
        padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * padding_len
        encoded['labels'] += [-100] * padding_len
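The padding arithmetic rounds the sequence up to the next multiple of 2 * cp_size, e.g.:

import math
cp_size, seq_len = 4, 10
padding_len = math.ceil(seq_len / (cp_size * 2)) * (cp_size * 2) - seq_len  # -> 6
# 10 tokens become 16, so the sequence splits evenly across context-parallel ranks.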

    def debug_logger(self, inputs):
        if not strtobool(os.getenv('SWIFT_DEBUG', 'false')):
            return
        if 'input_ids' in inputs:
            k = 'input_ids'
            val = inputs['input_ids']
        elif 'generate_ids' in inputs:
            k = 'generate_ids'
            val = inputs['generate_ids']
        else:
            # Neither key present: nothing to log (avoids an unbound `val` below).
            return
        for v in val:
            self.print_inputs({k: v.tolist()})

    @staticmethod
    def _split_list(inputs: List[int], x: int) -> List[List[int]]:
        idxs = findall(inputs, x)
        idxs.append(len(inputs))
        res = []
        lo = 0
        for idx in idxs:
            res.append(inputs[lo:idx])
            lo = idx + 1
        return res
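For example, splitting on the -100 media marker used elsewhere in this file:

Template._split_list([1, -100, 2, 3, -100], -100)  # -> [[1], [2, 3], []]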

    def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
        context_list = []
        if self.mode in {'vllm', 'lmdeploy'}:
            video = inputs.videos.pop(inputs.video_idx)
            inputs.video_idx -= 1
        else:
            video = inputs.videos[inputs.video_idx]
        images = inputs.images
        new_images = load_video_func(video)
        inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
        for i in range(len(new_images)):
            context_list += replace_tag(i)
        inputs.image_idx += len(new_images)
        return context_list

    def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]],
                         num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]:
        if self.skip_prompt:
            generate_ids = generate_ids[..., num_prompt_tokens:]
        return generate_ids

    def post_process_generate_response(self, response: str, inputs: StdTemplateInputs) -> str:
        return response

    def pre_forward_hook(self, model: nn.Module, args, kwargs):
        from swift.llm import to_device
        old_kwargs = to_device(kwargs, model.device)
        kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
        for k, v in old_kwargs.items():
            if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
                kwargs[k] = v
        if 'inputs_embeds' in kwargs:
            kwargs.pop('input_ids', None)

        if isinstance(model, PeftModel):
            parameters = inspect.signature(model.model.forward).parameters
        else:
            parameters = inspect.signature(model.forward).parameters
        if 'position_ids' not in parameters:
            kwargs.pop('position_ids', None)
        return args, kwargs

    @property
    def is_training(self):
        return self.mode not in {'vllm', 'lmdeploy', 'pt'}

    def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
        self.mode = mode

    def register_post_encode_hook(self, models: List[nn.Module]) -> None:
        """This function is important for multi-modal training, as it registers the post_encode method
        as a forward hook, converting input_ids into inputs_embeds.
        """
        if self._handles:
            return

        for model in models:
            # please use torch>=2.0
            handle = model.register_forward_pre_hook(self.pre_forward_hook, with_kwargs=True)
            self._handles.append((model, handle))

        if is_deepspeed_zero3_enabled():
            import deepspeed
            self._deepspeed_initialize = deepspeed.initialize

            @wraps(self._deepspeed_initialize)
            def _initialize(*args, **kwargs):
                res = self._deepspeed_initialize(*args, **kwargs)
                for model, handle in self._handles:
                    model._forward_pre_hooks.move_to_end(handle.id)
                return res

            deepspeed.initialize = _initialize
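A hedged sketch of the hook lifecycle during multimodal training; `template`, `model`, and the training loop are assumed to exist:

template.register_post_encode_hook([model])   # forward() now sees inputs_embeds
try:
    trainer.train()                           # hypothetical training entry point
finally:
    template.remove_post_encode_hook()        # detach hooks, restore deepspeed.initialize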

    def remove_post_encode_hook(self):
        models = []
        for model, handle in self._handles:
            models.append(model)
            handle.remove()
        self._handles = []

        if self._deepspeed_initialize is not None:
            import deepspeed
            deepspeed.initialize = self._deepspeed_initialize
            self._deepspeed_initialize = None
        return models

    def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        if self.mode == 'rlhf':
            return self._rlhf_data_collator(batch, padding_to=padding_to)
        elif self.mode == 'kto':
            return self._kto_data_collator(batch, padding_to=padding_to)
        elif self.mode in {'pt', 'train', 'prm'}:
            return self._data_collator(batch, padding_to=padding_to)
        elif self.mode == 'seq_cls':
            return self._seq_cls_data_collator(batch, padding_to=padding_to)
        elif self.mode == 'embedding':
            return self._embedding_data_collator(batch, padding_to=padding_to)

    @staticmethod
    def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]:
        new_batch = []
        for inputs in batch:
            new_inputs = {}
            for k, v in inputs.items():
                if k.startswith(prefix):
                    new_inputs[k[len(prefix):]] = v
            new_batch.append(new_inputs)
        return new_batch

    @staticmethod
    def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]:
        from swift.llm import RowPreprocessor
        keys = keys or []
        rows = RowPreprocessor.rows_to_batched(batch)
        return {k: rows[k] for k in keys if rows.get(k) is not None}

    @staticmethod
    def gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
        # List[Tensor] -> List[Tensor]
        res = []
        for b in batch:
            if b.get(attr_name) is not None:
                res += b.pop(attr_name)
        return res

    @staticmethod
    def concat_tensor(batch: List[Dict[str, Any]], attr_name: str, dim: int) -> Optional[torch.Tensor]:
        res = []
        for b in batch:
            if b.get(attr_name) is not None:
                res.append(b.pop(attr_name))
        return torch.concat(res, dim=dim) if res else None
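For example, stripping a prefix to recover plain collator inputs:

batch = [{'chosen_input_ids': [1, 2], 'rejected_input_ids': [3, 4]}]
Template._fetch_inputs_startswith(batch, 'chosen_')  # -> [{'input_ids': [1, 2]}]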

    def _rlhf_data_collator(self,
                            batch: List[Dict[str, Any]],
                            *,
                            chosen_prefix: str = 'chosen_',
                            rejected_prefix: str = 'rejected_',
                            padding_to: Optional[int] = None) -> Dict[str, Any]:
        new_batch = []
        for prefix in [chosen_prefix, rejected_prefix]:
            new_batch += self._fetch_inputs_startswith(batch, prefix)
        return self._data_collator(new_batch, padding_to=padding_to)

    def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        new_batch = self._fetch_inputs_startswith(batch, 'chosen_')
        kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')

        res = self._data_collator(new_batch, padding_to=padding_to)
        kl_res = self._data_collator(kl_batch, padding_to=padding_to)
        res = {
            **{f'completion_{k}': v
               for k, v in res.items()},
            **{f'KL_completion_{k}': v
               for k, v in kl_res.items()},
        }
        label = [b['label'] for b in batch if b.get('label') is not None]
        if label:
            res['label'] = label
        return res

    def _embedding_data_collator(self,
                                 batch: List[Dict[str, Any]],
                                 *,
                                 padding_to: Optional[int] = None) -> Dict[str, Any]:
        labels = []
        new_batch = []
        for b in batch:
            keys = [key for key in b.keys() if 'negative' in key]
            max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
            indexes = ['anchor_', 'positive_']
            if max_neg is not None:
                for i in range(0, max_neg + 1):
                    indexes.append(f'negative{i}_')
            for prefix in indexes:
                new_batch += self._fetch_inputs_startswith([b], prefix)
            labels.extend(b.get('labels', None))
        res = self._data_collator(new_batch, padding_to=padding_to)
        if labels:
            res['labels'] = torch.tensor(labels, dtype=torch.float32)
        return res

    def _seq_cls_data_collator(self,
                               batch: List[Dict[str, Any]],
                               *,
                               padding_to: Optional[int] = None) -> Dict[str, Any]:
        labels = [b.pop('labels') for b in batch if b.get('labels') is not None]
        res = self._data_collator(batch, padding_to=padding_to)
        if labels:
            problem_type = self._get_problem_type(self.config)
            if problem_type == 'regression':
                labels = torch.tensor(labels, dtype=torch.float32)
            elif problem_type == 'multi_label_classification':
                one_hot_labels = torch.zeros((len(labels), self.config.num_labels), dtype=torch.float32)
                for i, label in enumerate(labels):
                    one_hot_labels[i, label] = 1
                labels = one_hot_labels
            else:
                labels = torch.tensor(labels, dtype=torch.long)
            res['labels'] = labels
        return res
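The multi-label branch one-hot encodes index lists, e.g. (num_labels assumed to be 4 for illustration):

import torch
labels = [[0, 2]]                    # sample 0 belongs to classes 0 and 2
one_hot = torch.zeros((1, 4), dtype=torch.float32)
one_hot[0, labels[0]] = 1            # -> tensor([[1., 0., 1., 0.]])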

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """
        Args:
            batch(`List[Dict[str, Any]]`): The input data in batch
            padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
                will be padded to the `longest`
        """
        assert self.tokenizer.pad_token_id is not None
        padding_side = self.padding_side if self.is_training else 'left'
        padding_right = padding_side == 'right'
        packing_mode = self.use_megatron or self._packing and 'position_ids' in batch[0]
        res = {}
        if packing_mode:
            # only support llm
            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
                v = self.gather_list(batch, k)
                if v:
                    res[k] = [v]
        else:
            inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
            input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
            if inputs_embeds:
                res['inputs_embeds'] = inputs_embeds
            if input_ids:
                res['input_ids'] = input_ids
            for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
                val = [b[key] for b in batch if b.get(key) is not None]
                if val:
                    res[key] = val

        keys = [
            'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
        ]
        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
        # Convert to tensor and remove unnecessary dimensions.
        seq_lens = None
        for key in keys:
            if key not in res:
                continue
            for i, val in enumerate(res[key]):
                if isinstance(val, (list, tuple)):
                    val = torch.tensor(val)
                elif key == 'inputs_embeds' and val.ndim == 3 or key != 'inputs_embeds' and val.ndim == 2:
                    val = val[0]
                res[key][i] = val
            if not seq_lens:
                seq_lens = [seq.shape[0] for seq in res[key]]
        if not packing_mode and seq_lens and ('input_ids' in res or 'inputs_embeds' in res):
            res['attention_mask'] = [torch.ones(seq_len, dtype=torch.int64) for seq_len in seq_lens]
            if self.is_training and self.padding_side == 'left':
                res['position_ids'] = [torch.arange(seq_len, dtype=torch.int64) for seq_len in seq_lens]

        if self.use_megatron:
            padding_to = math.ceil(max(seq_lens) / 128) * 128
            cp_size = self.sequence_parallel_size
            if cp_size > 1:
                padding_len = padding_to - seq_lens[0]
                position_ids = res['position_ids'][0].tolist()
                position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2))
                res['position_ids'][0] = torch.tensor(position_ids)

        for key, pad_value in zip(keys, pad_values):
            if key not in res:
                continue
            if self.use_megatron and key == 'position_ids' and self.sequence_parallel_size > 1:
                pass
            elif padding_to is not None:
                padding_len = padding_to - seq_lens[0]
                if padding_len > 0:
                    res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
                                        'constant', pad_value)
            res[key] = self._pad_sequence(res[key], pad_value)

        # multimodal
        res.update(self._data_collator_mm_data(batch))
        if not self.use_megatron and (use_torchacc() or self.sequence_parallel_size > 1):
            res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)

        return res

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # multimodal
        res = {}
        pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
        if len(pixel_values) > 0:
            res['pixel_values'] = torch.concat(pixel_values)

        image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
        if len(image_sizes) > 0:
            res['image_sizes'] = torch.concat(image_sizes)

        pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
        if len(pixel_values_videos) > 0:
            res['pixel_values_videos'] = torch.concat(pixel_values_videos)
        return res

    def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_side):
        # torchacc & xtuner
        input_ids = res.get('input_ids')
        attention_mask = res.get('attention_mask')
        labels = res.get('labels')
        loss_scale = res.get('loss_scale')
        if use_torchacc():
            from swift.utils.torchacc_utils import pad_and_split_batch
            rank, _, world_size, _ = get_dist_setting()
            input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
                padding_to,
                input_ids,
                attention_mask,
                labels,
                loss_scale,
                self.max_length,
                tokenizer,
                rank,
                world_size,
                padding_right=padding_side == 'right')
        if self.sequence_parallel_size > 1 and input_ids is not None:
            bs, seq_len = input_ids.shape
            if 'position_ids' not in res:
                position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
            else:
                position_ids = res['position_ids']
            assert padding_side == 'right' or bs == 1, 'Sequence parallel only support padding_side=right'
            from swift.trainers.sequence_parallel import sequence_parallel
            if sequence_parallel.world_size() > 1:
                input_ids, _, labels, position_ids, attention_mask, loss_scale = \
                    sequence_parallel.pad_and_split_inputs(
                        tokenizer, input_ids, None, labels, position_ids, attention_mask, loss_scale)
            res['position_ids'] = position_ids
        _local_var = locals()
        for key in ['input_ids', 'attention_mask', 'labels', 'loss_scale']:
            value = _local_var[key]
            if value is not None:
                res[key] = value
        return res

    def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        for key in [
                'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
        ]:
            val = inputs.get(key)  # fix val is a tensor
            if val is None:
                val = inputs.get(f'{key}_ids')
            if val is not None:
                key_upper = key.upper()
                logger.info(f'[{key_upper}_IDS] {val}')
                if key == 'labels' and self.mode in {'seq_cls', 'embedding'}:
                    continue
                if isinstance(val, (list, tuple, torch.Tensor)):
                    val_str = self.safe_decode(val, **tokenizer_kwargs)
                    logger.info(f'[{key_upper}] {val_str}')
        if inputs.get('loss_scale') is not None:
            val = inputs['loss_scale']
            logger.info(f'[LOSS_SCALE] {val}')

    async def prepare_lmdeploy_pytorch_inputs(self, inputs) -> None:
        images = inputs.pop('images', None) or []
        if len(images) == 0:
            return
        input_ids = inputs['input_ids']
        idx_list = findall(input_ids, -100)
        assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
        idx_list.insert(0, -1)
        new_input_ids = []
        for i in range(len(idx_list) - 1):
            new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
            images[i]['offset'] = len(new_input_ids)
            new_input_ids += [images[i]['image_token_id']] * images[i]['image_tokens']
        new_input_ids += input_ids[idx_list[-1] + 1:]
        inputs['input_ids'] = new_input_ids
        inputs['multimodal'] = images

    async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
        images = inputs.pop('images', None) or []
        if len(images) == 0:
            return
        from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
        input_ids = inputs['input_ids']
        idx_list = findall(input_ids, -100)
        assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
        idx_list.insert(0, -1)
        new_input_ids = []
        ranges = []
        for i in range(len(idx_list) - 1):
            _range = []
            new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
            _range.append(len(new_input_ids))
            new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * images[i].shape[0]
            _range.append(len(new_input_ids))
            ranges.append(_range)
        new_input_ids += input_ids[idx_list[-1] + 1:]
        inputs['input_embeddings'] = [image.to('cpu') for image in images]
        inputs['input_embedding_ranges'] = ranges
        inputs['input_ids'] = new_input_ids

    def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.) -> torch.Tensor:
        """Pad sequence by some side

        Args:
            sequences: The input sequences in tensor.
            padding_value: The padding value

        Returns:
            A tensor after padding
        """
        padding_side = self.padding_side if self.is_training else 'left'
        padding_right = padding_side == 'right'
        if padding_right:
            return pad_sequence(sequences, batch_first=True, padding_value=padding_value)

        max_len = max([s.shape[0] for s in sequences])

        padded_sequences = []
        for seq in sequences:
            pad_length = max_len - seq.shape[0]
            pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
            padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
            padded_sequences.append(padded_seq)

        return torch.stack(padded_sequences)
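Left padding prepends the pad value, mirroring the F.pad call above:

import torch
import torch.nn.functional as F
seq, max_len = torch.tensor([1, 2, 3]), 5
F.pad(seq, (max_len - seq.shape[0], 0), 'constant', 0)  # -> tensor([0, 0, 1, 2, 3])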

    def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
        if isinstance(self, Template):
            tokenizer = self.tokenizer
            placeholder_tokens = self.placeholder_tokens
        else:
            tokenizer = self
            placeholder_tokens = []

        def _is_special(token: int) -> bool:
            if isinstance(token, float) or token < 0:
                return True
            return token in placeholder_tokens

        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        if len(input_ids) == 0:
            return ''
        result_str = ''
        for i in range(len(input_ids)):
            if i == 0:
                if _is_special(input_ids[i]):
                    s = 0
                else:
                    e = 0
                continue
            if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
                s = i
                result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
            if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
                e = i
                result_str += f'[{input_ids[i - 1]} * {e - s}]'
        if _is_special(input_ids[i]):
            result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
        else:
            result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
        return result_str

    @staticmethod
    @contextmanager
    def _patch_flash_attention_forward(modeling_module, position_ids, use_new_func: bool = False):
        _origin_flash_attention_forward = modeling_module._flash_attention_forward

        def _flash_attention_forward(*args, **kwargs):
            if use_new_func:
                from transformers.modeling_flash_attention_utils import (_flash_attention_forward as
                                                                         flash_attention_forward)
                if args and isinstance(args[0], nn.Module):
                    args = args[1:]
                if 'is_causal' not in kwargs:
                    kwargs['is_causal'] = True
            else:
                flash_attention_forward = _origin_flash_attention_forward
            kwargs['position_ids'] = position_ids
            return flash_attention_forward(*args, **kwargs)

        modeling_module._flash_attention_forward = _flash_attention_forward
        try:
            yield
        finally:
            modeling_module._flash_attention_forward = _origin_flash_attention_forward
swift/llm/template/constant.py
ADDED
@@ -0,0 +1,194 @@
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LLMTemplateType:
|
| 7 |
+
chatml = 'chatml'
|
| 8 |
+
default = 'default'
|
| 9 |
+
dummy = 'dummy'
|
| 10 |
+
|
| 11 |
+
qwen = 'qwen'
|
| 12 |
+
qwen2_5 = 'qwen2_5'
|
| 13 |
+
qwen2_5_math = 'qwen2_5_math'
|
| 14 |
+
qwen2_5_math_prm = 'qwen2_5_math_prm'
|
| 15 |
+
qwen3 = 'qwen3'
|
| 16 |
+
qwq_preview = 'qwq_preview'
|
| 17 |
+
qwq = 'qwq'
|
| 18 |
+
marco_o1 = 'marco_o1'
|
| 19 |
+
modelscope_agent = 'modelscope_agent'
|
| 20 |
+
|
| 21 |
+
llama = 'llama' # llama2
|
| 22 |
+
llama3 = 'llama3'
|
| 23 |
+
llama3_2 = 'llama3_2'
|
| 24 |
+
reflection = 'reflection'
|
| 25 |
+
megrez = 'megrez'
|
| 26 |
+
yi_coder = 'yi_coder'
|
| 27 |
+
sus = 'sus'
|
| 28 |
+
|
| 29 |
+
minimax = 'minimax'
|
| 30 |
+
minimax_vl = 'minimax_vl'
|
| 31 |
+
|
| 32 |
+
numina = 'numina'
|
| 33 |
+
ziya = 'ziya'
|
| 34 |
+
atom = 'atom'
|
| 35 |
+
mengzi = 'mengzi'
|
| 36 |
+
|
| 37 |
+
chatglm2 = 'chatglm2'
|
| 38 |
+
glm4 = 'glm4'
|
| 39 |
+
glm4_0414 = 'glm4_0414'
|
| 40 |
+
glm4_z1_rumination = 'glm4_z1_rumination'
|
| 41 |
+
codegeex4 = 'codegeex4'
|
| 42 |
+
longwriter_llama = 'longwriter_llama'
|
| 43 |
+
|
| 44 |
+
internlm = 'internlm'
|
| 45 |
+
internlm2 = 'internlm2'
|
| 46 |
+
internlm3 = 'internlm3'
|
| 47 |
+
|
| 48 |
+
deepseek = 'deepseek'
|
| 49 |
+
deepseek_coder = 'deepseek_coder'
|
| 50 |
+
deepseek_v2_5 = 'deepseek_v2_5'
|
| 51 |
+
deepseek_r1 = 'deepseek_r1'
|
| 52 |
+
|
| 53 |
+
openbuddy = 'openbuddy'
|
| 54 |
+
openbuddy2 = 'openbuddy2'
|
| 55 |
+
baichuan = 'baichuan'
|
| 56 |
+
baichuan_m1 = 'baichuan_m1'
|
| 57 |
+
minicpm = 'minicpm'
|
| 58 |
+
telechat = 'telechat'
|
| 59 |
+
telechat2 = 'telechat2'
|
| 60 |
+
|
| 61 |
+
codefuse = 'codefuse'
|
| 62 |
+
codefuse_codellama = 'codefuse_codellama'
|
| 63 |
+
|
| 64 |
+
skywork = 'skywork'
+    skywork_o1 = 'skywork_o1'
+
+    mistral_nemo = 'mistral_nemo'
+    mistral_2501 = 'mistral_2501'
+    zephyr = 'zephyr'
+    wizardlm2 = 'wizardlm2'
+    wizardlm2_moe = 'wizardlm2_moe'
+    gemma = 'gemma'
+    gemma3_text = 'gemma3_text'
+    phi3 = 'phi3'
+    phi4 = 'phi4'
+
+    ling = 'ling'
+    yuan = 'yuan'
+    xverse = 'xverse'
+    bluelm = 'bluelm'
+    orion = 'orion'
+    moonlight = 'moonlight'
+
+    aya = 'aya'
+    c4ai = 'c4ai'
+    dbrx = 'dbrx'
+
+
+class RMTemplateType:
+    internlm2_reward = 'internlm2_reward'
+
+
+class MLLMTemplateType:
+    qwen_vl = 'qwen_vl'
+    qwen_audio = 'qwen_audio'
+    qwen2_vl = 'qwen2_vl'
+    qwen2_5_vl = 'qwen2_5_vl'
+    qwen2_5_omni = 'qwen2_5_omni'
+    qwen2_audio = 'qwen2_audio'
+    qwen2_gme = 'qwen2_gme'
+    qvq = 'qvq'
+    ovis1_6 = 'ovis1_6'
+    ovis1_6_llama3 = 'ovis1_6_llama3'
+    ovis2 = 'ovis2'
+
+    llama3_1_omni = 'llama3_1_omni'
+    llama3_2_vision = 'llama3_2_vision'
+    llama4 = 'llama4'
+
+    llava1_5_hf = 'llava1_5_hf'
+    llava1_6_mistral_hf = 'llava1_6_mistral_hf'
+    llava1_6_vicuna_hf = 'llava1_6_vicuna_hf'
+    llava1_6_yi_hf = 'llava1_6_yi_hf'
+    llama3_llava_next_hf = 'llama3_llava_next_hf'
+    llava_next_qwen_hf = 'llava_next_qwen_hf'
+    llava_onevision_hf = 'llava_onevision_hf'
+    llava_next_video_hf = 'llava_next_video_hf'
+
+    llava_llama3_1_hf = 'llava_llama3_1_hf'  # DaozeZhang
+    llava_llama3_hf = 'llava_llama3_hf'  # xtuner
+    # lmms-lab
+    llava1_6_mistral = 'llava1_6_mistral'
+    llava1_6_yi = 'llava1_6_yi'
+    llava_next_qwen = 'llava_next_qwen'
+    llama3_llava_next = 'llama3_llava_next'
+
+    yi_vl = 'yi_vl'
+
+    internvl = 'internvl'
+    internvl_phi3 = 'internvl_phi3'
+    internvl2 = 'internvl2'
+    internvl2_phi3 = 'internvl2_phi3'
+    internvl2_5 = 'internvl2_5'
+
+    xcomposer2 = 'ixcomposer2'
+    xcomposer2_4khd = 'xcomposer2_4khd'
+    xcomposer2_5 = 'xcomposer2_5'
+
+    cogagent_chat = 'cogagent_chat'
+    cogagent_vqa = 'cogagent_vqa'
+    cogvlm = 'cogvlm'
+    cogvlm2 = 'cogvlm2'
+    cogvlm2_video = 'cogvlm2_video'
+    glm4v = 'glm4v'
+    glm_edge_v = 'glm_edge_v'
+
+    minicpmv = 'minicpmv'
+    minicpmv2_5 = 'minicpmv2_5'
+    minicpmv2_6 = 'minicpmv2_6'
+    minicpmo2_6 = 'minicpmo2_6'
+
+    deepseek_vl = 'deepseek_vl'
+    deepseek_vl2 = 'deepseek_vl2'
+    deepseek_janus = 'deepseek_janus'
+    deepseek_janus_pro = 'deepseek_janus_pro'
+
+    mplug_owl2 = 'mplug_owl2'
+    mplug_owl3 = 'mplug_owl3'
+    mplug_owl3_241101 = 'mplug_owl3_241101'
+    doc_owl2 = 'doc_owl2'
+
+    emu3_chat = 'emu3_chat'
+    emu3_gen = 'emu3_gen'
+
+    got_ocr2 = 'got_ocr2'
+    got_ocr2_hf = 'got_ocr2_hf'
+    step_audio = 'step_audio'
+    kimi_vl = 'kimi_vl'
+
+    idefics3 = 'idefics3'
+    pixtral = 'pixtral'
+    paligemma = 'paligemma'
+    phi3_vision = 'phi3_vision'
+    phi4_multimodal = 'phi4_multimodal'
+    florence = 'florence'
+    molmo = 'molmo'
+    megrez_omni = 'megrez_omni'
+    valley = 'valley'
+    gemma3_vision = 'gemma3_vision'
+    mistral_2503 = 'mistral_2503'
+
+
+class TemplateType(LLMTemplateType, MLLMTemplateType, RMTemplateType):
+
+    @classmethod
+    def get_template_name_list(cls) -> List[str]:
+        res = []
+        for k in cls.__dict__.keys():
+            if k.startswith('__'):
+                continue
+            value = cls.__dict__[k]
+            if isinstance(value, str):
+                res.append(value)
+        return res
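The enumeration in `get_template_name_list` relies on string-valued class attributes doubling as a registry of template names. Note that `cls.__dict__` only exposes attributes defined directly on the class it is inspected on, not ones inherited from the mixins. A minimal standalone sketch of the pattern (`_DemoTemplateType` is hypothetical, not part of this diff):

```python
# Minimal sketch of the constant-enumeration pattern above. _DemoTemplateType
# is illustrative only; cls.__dict__ sees just the attributes defined directly
# on the class, which is why the real constants live on the mixin classes.
from typing import List


class _DemoTemplateType:
    gemma = 'gemma'
    phi4 = 'phi4'

    @classmethod
    def get_template_name_list(cls) -> List[str]:
        return [v for k, v in cls.__dict__.items() if not k.startswith('__') and isinstance(v, str)]


print(_DemoTemplateType.get_template_name_list())  # ['gemma', 'phi4']
```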
swift/llm/template/grounding.py
ADDED
@@ -0,0 +1,91 @@
+import colorsys
+import itertools
+import os
+from copy import deepcopy
+from typing import Any, List, Literal
+
+import requests
+from modelscope.hub.utils.utils import get_cache_dir
+from PIL import Image, ImageDraw, ImageFont
+
+
+def _shuffle_colors(nums: List[Any]) -> List[Any]:
+    if len(nums) == 1:
+        return nums
+
+    mid = len(nums) // 2
+
+    left = nums[:mid]
+    right = nums[mid:]
+    left = _shuffle_colors(left)
+    right = _shuffle_colors(right)
+    new_nums = []
+    for x, y in zip(left, right):
+        new_nums += [x, y]
+    new_nums += left[len(right):] or right[len(left):]
+    return new_nums
+
+
+def generate_colors():
+    vs_combinations = [(v, s) for v, s in itertools.product([0.7, 0.3, 1], [0.7, 0.3, 1])]
+    colors = [colorsys.hsv_to_rgb(i / 16, s, v) for v, s in vs_combinations for i in _shuffle_colors(list(range(16)))]
+    colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]
+    return _shuffle_colors(colors)
+
+
+def download_file(url: str) -> str:
+    url = url.rstrip('/')
+    file_name = url.rsplit('/', 1)[-1]
+    cache_dir = os.path.join(get_cache_dir(), 'files')
+    os.makedirs(cache_dir, exist_ok=True)
+    req = requests.get(url)
+    file_path = os.path.join(cache_dir, file_name)
+    with open(file_path, 'wb') as f:
+        f.write(req.content)
+    return file_path
+
+
+colors = generate_colors()
+color_mapping = {}
+
+
+def _calculate_brightness(image, region: List[int]):
+    cropped_image = image.crop(region)
+    grayscale_image = cropped_image.convert('L')
+    pixels = list(grayscale_image.getdata())
+    average_brightness = sum(pixels) / len(pixels)
+    return average_brightness
+
+
+def draw_bbox(image: Image.Image,
+              ref: List[str],
+              bbox: List[List[int]],
+              norm_bbox: Literal['norm1000', 'none'] = 'norm1000'):
+    bbox = deepcopy(bbox)
+    font_path = 'https://modelscope.cn/models/Qwen/Qwen-VL-Chat/resolve/master/SimSun.ttf'
+    # norm bbox
+    for box in bbox:
+        for i in range(len(box)):
+            box[i] = int(box[i])
+        if norm_bbox == 'norm1000':
+            box[0] = box[0] / 1000 * image.width
+            box[2] = box[2] / 1000 * image.width
+            box[1] = box[1] / 1000 * image.height
+            box[3] = box[3] / 1000 * image.height
+
+    draw = ImageDraw.Draw(image)
+    # draw bbox
+    assert len(ref) == len(bbox), f'len(refs): {len(ref)}, len(bboxes): {len(bbox)}'
+    for (left, top, right, bottom), box_ref in zip(bbox, ref):
+        if box_ref not in color_mapping:
+            color_mapping[box_ref] = colors[len(color_mapping) % len(colors)]
+        color = color_mapping[box_ref]
+        draw.rectangle([(left, top), (right, bottom)], outline=color, width=3)
+    # draw text
+    file_path = download_file(font_path)
+    font = ImageFont.truetype(file_path, 20)
+    for (left, top, _, _), box_ref in zip(bbox, ref):
+        brightness = _calculate_brightness(
+            image, [left, top, min(left + 100, image.width),
+                    min(top + 20, image.height)])
+        draw.text((left, top), box_ref, fill='white' if brightness < 128 else 'black', font=font)
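A small usage sketch for `draw_bbox` (the import path is inferred from the file location above; the call needs network access, since the SimSun font is fetched from the ModelScope URL on first use). With `norm_bbox='norm1000'`, coordinates are interpreted on a 0-1000 scale and rescaled against the image's actual width and height:

```python
# Usage sketch for draw_bbox; import path inferred from this diff.
# With norm_bbox='norm1000', [100, 100, 500, 800] maps to
# (0.1*W, 0.1*H) - (0.5*W, 0.8*H) on the actual image.
from PIL import Image

from swift.llm.template.grounding import draw_bbox

img = Image.new('RGB', (640, 480), 'white')
draw_bbox(img, ref=['cat'], bbox=[[100, 100, 500, 800]], norm_bbox='norm1000')
img.save('bbox_demo.png')
```

Each distinct `ref` string is cached in the module-level `color_mapping`, so repeated references to the same object keep a stable color across calls.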
swift/llm/template/register.py
ADDED
@@ -0,0 +1,57 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Dict, Literal, Optional
+
+from ..utils import Processor
+from .base import Template
+from .template_meta import TemplateMeta
+
+TEMPLATE_MAPPING: Dict[str, TemplateMeta] = {}
+
+
+def register_template(template_meta: TemplateMeta, *, exist_ok: bool = False) -> None:
+    template_type = template_meta.template_type
+    if not exist_ok and template_type in TEMPLATE_MAPPING:
+        raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.')
+    TEMPLATE_MAPPING[template_type] = template_meta
+
+
+def get_template(
+        template_type: str,
+        processor: Processor,
+        default_system: Optional[str] = None,
+        max_length: Optional[int] = None,
+        *,
+        use_chat_template: bool = True,
+        template_backend: Literal['swift', 'jinja'] = 'swift',
+        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
+        max_pixels: Optional[int] = None,  # h * w
+        agent_template: Optional[str] = None,
+        norm_bbox: Literal['norm1000', 'none', None] = None,
+        response_prefix: Optional[str] = None,
+        # train
+        padding_side: Literal['left', 'right'] = 'right',
+        loss_scale: str = 'default',
+        sequence_parallel_size: int = 1) -> 'Template':
+    template_meta = TEMPLATE_MAPPING[template_type]
+    template_cls = template_meta.template_cls
+    return template_cls(
+        processor,
+        template_meta,
+        default_system,
+        max_length,
+        use_chat_template=use_chat_template,
+        template_backend=template_backend,
+        truncation_strategy=truncation_strategy,
+        max_pixels=max_pixels,
+        agent_template=agent_template,
+        norm_bbox=norm_bbox,
+        response_prefix=response_prefix,
+        padding_side=padding_side,
+        loss_scale=loss_scale,
+        sequence_parallel_size=sequence_parallel_size,
+    )
+
+
+def get_template_meta(template_type: str) -> TemplateMeta:
+    return TEMPLATE_MAPPING[template_type]
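For orientation, a hedged sketch of using this registry (the import paths are inferred from the file locations in this diff, and `my_template` plus the prompt strings are hypothetical illustration values; the keyword fields follow the `TemplateMeta` usage visible in the template files below):

```python
# Hedged sketch: registering and looking up a custom template type.
from swift.llm.template.register import TEMPLATE_MAPPING, register_template
from swift.llm.template.template_meta import TemplateMeta

register_template(
    TemplateMeta(
        'my_template',  # hypothetical template_type
        prefix=[],
        prompt=['User: {{QUERY}}\nAssistant:'],
        chat_sep=['\n'],
        suffix=[['eos_token_id']]))

assert 'my_template' in TEMPLATE_MAPPING
# Registering the same template_type again without exist_ok=True raises ValueError.
```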
swift/llm/template/template/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm,
+               minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)
swift/llm/template/template/deepseek.py
ADDED
@@ -0,0 +1,315 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+
+from swift.utils import get_env_args
+from ..base import Template
+from ..constant import LLMTemplateType, MLLMTemplateType
+from ..register import TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Prompt, findall
+
+
+@dataclass
+class DeepseekTemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
+    prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\n\nAssistant:'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
+    suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: [['bos_token_id'], '{{SYSTEM}}\n\n'])
+
+
+register_template(DeepseekTemplateMeta(LLMTemplateType.deepseek))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.deepseek_coder,
+        prefix=['{{SYSTEM}}'],
+        prompt=['### Instruction:\n{{QUERY}}\n### Response:\n'],
+        chat_sep=['\n<|EOT|>\n'],
+        suffix=['\n<|EOT|>'],
+        stop_words=['<|EOT|>'],
+        default_system=('You are an AI programming assistant, utilizing the Deepseek Coder model, '
+                        'developed by Deepseek Company, and you only answer questions related to computer science. '
+                        'For politically sensitive questions, security and privacy issues, '
+                        'and other non-computer science questions, you will refuse to answer\n')))
+
+
+class DeepseekVLTemplate(Template):
+    image_placeholder = ['<image_placeholder>']
+    skip_prompt = False
+    use_model = True
+    placeholder_tokens = ['<image_placeholder>']
+
+    image_token_num_per_image: int = 576
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        is_janus = getattr(self, 'is_janus', False)
+
+        encoded = super()._encode(inputs)
+        images = inputs.images
+        processor = self.processor
+        input_ids, labels = encoded['input_ids'], encoded['labels']
+
+        if not inputs.generate_mode:  # understanding task
+            idx_list = findall(input_ids, processor.image_id)  # '<image_placeholder>'
+            new_input_ids, new_labels = [], []
+            lo = 0
+            for hi in idx_list:
+                new_input_ids += input_ids[lo:hi]
+                if labels is not None:
+                    new_labels += labels[lo:hi]
+                image_tokens = [processor.image_id] * processor.num_image_tokens
+                if is_janus:
+                    image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
+                new_input_ids += image_tokens
+                new_labels += [-100] * len(image_tokens)
+                lo = hi + 1
+            new_input_ids += input_ids[lo:]
+            if labels is not None:
+                new_labels += labels[lo:]
+            else:
+                new_labels = None
+            if is_janus:
+                from janus.models.processing_vlm import VLChatProcessorOutput
+            else:
+                from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
+
+            images_outputs = processor.image_processor(images, return_tensors='pt')
+            output = VLChatProcessorOutput(
+                sft_format=None,
+                input_ids=torch.tensor(new_input_ids),
+                pixel_values=images_outputs.pixel_values,
+                num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
+            encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
+            return encoded
+
+        else:  # image generation task
+            if self.is_training:
+                raise NotImplementedError('Only support the inference of generation of Janus series models.')
+            sft_format = self.tokenizer.decode(input_ids)
+            prompt = sft_format + processor.image_start_tag
+            input_ids = processor.tokenizer.encode(prompt)
+            input_ids = torch.LongTensor(input_ids)
+
+            encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode}
+            return encoded
+
+    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if not inputs.get('generate_mode'):
+            inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.model_info.torch_dtype)
+            inputs_embeds = model.prepare_inputs_embeds(**inputs)
+            return {'inputs_embeds': inputs_embeds}
+        else:
+            return inputs
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        gene_img_list = [b.get('generate_mode') for b in batch]
+        if all(gene_img_list):
+            generate_mode = True
+        elif not any(gene_img_list):
+            generate_mode = False
+        else:
+            raise NotImplementedError('Do not support understanding and image generation tasks in one batch.')
+
+        if not generate_mode:
+            output = self.fetch_inputs(batch, ['output'])['output']
+            batched_output = dict(self.processor.batchify(output))
+            res = super()._data_collator(batch, padding_to=padding_to)
+            return {**batched_output, **res}
+        else:
+            res = super()._data_collator(batch, padding_to=padding_to)
+            res['generate_mode'] = generate_mode
+            return res
+
+    def generate(self, model, *args, **kwargs):
+        if not kwargs.get('generate_mode'):
+            return super().generate(model, *args, **kwargs)
+
+        else:
+            # how many images to generate per prompt; named parallel_size in the author's code
+            parallel_size = kwargs['generation_config'].num_return_sequences
+            temperature = kwargs['generation_config'].temperature
+            cfg_weight = get_env_args('cfg_weight', float, 5.0)
+
+            input_ids = kwargs['input_ids']  # [bsz, max_input_token_num]
+            bsz, max_input_token_num = input_ids.shape
+            tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num),
+                                 dtype=torch.int).cuda()  # [bsz, parallel_size*2, max_input_token_num]
+            for i in range(parallel_size * 2):
+                tokens[:, i, :] = input_ids
+                if i % 2 != 0:
+                    tokens[:, i, 1:-1] = self.processor.pad_id
+
+            inputs_embeds = model.language_model.get_input_embeddings()(
+                tokens)  # [bsz, parallel_size*2, max_input_token_num, 2048]
+
+            generated_tokens = torch.zeros(
+                (bsz, parallel_size, self.image_token_num_per_image),
+                dtype=torch.int).cuda()  # [bsz, 16, image_token_num_per_image] placeholder for the generated tokens
+
+            # merge the first two dimensions into one batch dimension
+            inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1)
+            generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image)
+
+            for i in range(self.image_token_num_per_image):  # generate the image tokens auto-regressively
+                outputs = model.language_model.model(
+                    inputs_embeds=inputs_embeds,
+                    use_cache=True,
+                    past_key_values=outputs.past_key_values if i != 0 else None)
+                hidden_states = outputs.last_hidden_state
+
+                logits = self.model.gen_head(hidden_states[:, -1, :])
+                logit_cond = logits[0::2, :]
+                logit_uncond = logits[1::2, :]
+
+                logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+                probs = torch.softmax(logits / temperature, dim=-1)
+
+                next_token = torch.multinomial(probs, num_samples=1)
+                generated_tokens[:, i] = next_token.squeeze(dim=-1)  # [parallel_size, self.image_token_num_per_image]
+
+                next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+                img_embeds = model.prepare_gen_img_embeds(next_token)  # [parallel_size * 2, 2048]
+                inputs_embeds = img_embeds.unsqueeze(dim=1)  # [parallel_size * 2, 1, 2048]
+
+            # no need to restore the original first two dimensions; waiting for the update of the upper layer
+            # inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1)
+            # generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image)
+
+            return {'sequences': generated_tokens}
+
+    def decode(self, generate_ids: List[int], **kwargs) -> Any:
+        if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode:
+            return super().decode(generate_ids, **kwargs)
+        else:
+            img_size = get_env_args('img_size', int, 384)
+            patch_size = 16
+
+            num_to_decode = 1  # for now, generate_ids is a 1D list
+
+            generate_ids = torch.tensor(generate_ids).unsqueeze(0)  # [num_to_decode=1, self.image_token_num_per_image]
+
+            dec = self.model.gen_vision_model.decode_code(
+                generate_ids.to(dtype=torch.int),
+                shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size])
+            dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)  # [num_to_decode, H, W, ch=3]
+
+            dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+            visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8)
+            visual_img[:, :, :] = dec
+
+            img_list = []
+            for i in range(num_to_decode):
+                cur_img = Image.fromarray(visual_img[i])
+                img_list.append({'type': 'image', 'image': cur_img})
+            return img_list
+
+
+@dataclass
+class DeepseekVLTemplateMeta(DeepseekTemplateMeta):
+    default_system: Optional[str] = ('You are a helpful language and vision assistant. '
+                                     'You are able to understand the visual content that the user provides, '
+                                     'and assist the user with a variety of tasks using natural language.')
+
+
+register_template(DeepseekVLTemplateMeta(
+    MLLMTemplateType.deepseek_vl,
+    template_cls=DeepseekVLTemplate,
+))
+
+
+class DeepseekJanus(DeepseekVLTemplate):
+    is_janus = True
+    image_placeholder = ['<image_placeholder>\n']
+
+
+register_template(DeepseekVLTemplateMeta(MLLMTemplateType.deepseek_janus, template_cls=DeepseekJanus))
+
+
+@dataclass
+class DeepseekV2_5TemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<|begin▁of▁sentence|>{{SYSTEM}}'])
+    prompt: Prompt = field(default_factory=lambda: ['<|User|>{{QUERY}}<|Assistant|>'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
+    suffix: Prompt = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
+
+
+register_template(DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_v2_5))
+
+
+class DeepseekR1Template(Template):
+
+    def _swift_encode(self, inputs: StdTemplateInputs):
+        if not self.is_training:
+            for message in inputs.messages:
+                if message['role'] == 'assistant' and isinstance(message['content'], str):
+                    message['content'] = message['content'].split('</think>')[-1]
+        return super()._swift_encode(inputs)
+
+
+register_template(
+    DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_r1, template_cls=DeepseekR1Template, response_prefix='<think>\n'))
+
+
+class DeepseekVL2Template(DeepseekVLTemplate):
+    image_placeholder = ['<image>\n']
+    placeholder_tokens = ['<image>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        from deepseek_vl2.models.processing_deepseek_vl_v2 import VLChatProcessorOutput
+        encoded = Template._encode(self, inputs)
+        images = inputs.images
+        processor = self.processor
+        input_ids, labels = encoded['input_ids'], encoded['labels']
+        images_seq_mask = [False] * len(input_ids)
+        idx_list = findall(input_ids, processor.image_token_id)  # '<image>'
+        _, images_list, _, images_spatial_crop, num_image_tokens = processor.tokenize_with_images(
+            '<image>' * len(images), images, cropping=len(images) <= 2)
+        new_num_tokens = 0
+        for idx, n_image_tokens in zip(idx_list, num_image_tokens):
+            image_tokens = [processor.image_token_id] * n_image_tokens
+            input_ids = input_ids[:idx] + image_tokens + input_ids[idx + 1:]
+            if labels is not None:
+                labels = labels[:idx] + [-100] * n_image_tokens + labels[idx + 1:]
+            images_seq_mask = images_seq_mask[:idx] + [True] * n_image_tokens + images_seq_mask[idx + 1:]
+            new_num_tokens += n_image_tokens - 1
+
+        output = VLChatProcessorOutput(
+            sft_format=None,
+            input_ids=torch.tensor(input_ids),
+            target_ids=torch.tensor(input_ids),
+            images=torch.stack(images_list) if images_list else torch.zeros((0, 3, 384, 384)),
+            images_seq_mask=torch.tensor(images_seq_mask),
+            images_spatial_crop=torch.tensor(images_spatial_crop),
+            num_image_tokens=num_image_tokens)
+        output.images = output.images.to(dtype=self.model_info.torch_dtype)
+        encoded = {'output': output, 'input_ids': input_ids, 'labels': labels}
+        return encoded
+
+    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        inputs['images_seq_mask'] = inputs['images_seq_mask'].to(torch.bool)
+        inputs['images_spatial_crop'] = inputs['images_spatial_crop'].to(torch.long)
+        inputs_embeds = model.prepare_inputs_embeds(**inputs)
+        return {'inputs_embeds': inputs_embeds}
+
+
+register_template(
+    DeepseekV2_5TemplateMeta(
+        MLLMTemplateType.deepseek_vl2,
+        prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
+        template_cls=DeepseekVL2Template,
+    ))
+
+register_template(
+    DeepseekVLTemplateMeta(
+        MLLMTemplateType.deepseek_janus_pro,
+        prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
+        template_cls=DeepseekJanus))
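The Janus image-generation loop above interleaves conditional and unconditional rows (odd rows have the prompt masked with `pad_id`) and mixes their logits with classifier-free guidance before sampling. A standalone sketch of just that mixing step, on dummy tensors:

```python
# Standalone sketch of the CFG logit mix used in DeepseekVLTemplate.generate.
# Even rows are prompt-conditioned, odd rows are the unconditional copies;
# the shapes, cfg_weight and temperature here are illustrative dummies.
import torch

cfg_weight, temperature = 5.0, 1.0
logits = torch.randn(8, 16384)           # [parallel_size * 2, codebook_size]
logit_cond = logits[0::2, :]             # conditional rows
logit_uncond = logits[1::2, :]           # unconditional rows
mixed = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(mixed / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # [parallel_size, 1]
```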
swift/llm/template/template/emu3.py
ADDED
@@ -0,0 +1,191 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import random
+from typing import Any, Dict, List, Optional
+
+import torch
+from PIL import Image
+
+from swift.utils import get_device
+from ..base import Template
+from ..constant import MLLMTemplateType
+from ..register import register_template
+from ..template_inputs import StdTemplateInputs
+from ..template_meta import TemplateMeta
+from ..utils import findall
+from .utils import DEFAULT_SYSTEM, EmptyTemplateMeta
+
+
+class Emu3GenTemplate(Template):
+
+    NULL_PROMPT_PROB = 0.1
+    COOKBOOK_SIZE = 32768
+    CFG_SCALE = os.environ.get('CFG_SCALE', 3.0)
+    GENERATION_RATIO = os.environ.get('GENERATION_RATIO', '1:1')
+    NEGATIVE_PROMPT = os.environ.get(
+        'NEGATIVE_PROMPT',
+        'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, '
+        'worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.')
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.bov = self.processor.tokenizer.encode(self.processor.visual_template[0].format(token_id=0))[0]
+        self.eov = self.processor.tokenizer.encode(self.processor.visual_template[0].format(token_id=self.COOKBOOK_SIZE
+                                                                                            - 1))[0]
+        self.h, self.w = self.processor.calculate_generate_size(self.GENERATION_RATIO, self.processor.image_area,
+                                                                self.processor.vision_tokenizer.spatial_scale_factor)
+        self.skip_prompt = False
+        self.apply_loss_on_only_vision = True
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        if self.is_training:
+            p_prob = random.random()
+            if p_prob < self.NULL_PROMPT_PROB:
+                prompt = ''
+            else:
+                prompt = inputs.to_history()['response']
+            image = self.smart_resize(inputs.images[0].convert('RGB'))
+            with torch.no_grad():
+                image = self.processor.image_processor(
+                    image, return_tensors='pt')['pixel_values'].to(device=self.processor.vision_tokenizer.device)
+                image_token_ids = self.processor.vision_tokenizer.encode(image).squeeze(0)
+            encoded = self._process_prompt_train(prompt, image_token_ids)
+        else:
+            prompt = inputs.to_history()['query']
+            encoded = self._process_prompt_test(prompt)
+            encoded = {key: encoded[key][0] for key in encoded.keys()}  # [1, L] -> [L]
+
+        return encoded
+
+    def _process_prompt_train(self, raw_prompt, image_token_ids):
+        image_prompt = self.format_image_prompt(image_token_ids)
+        prompt = self.tokenizer.bos_token + raw_prompt + image_prompt
+        sample = self.tokenizer(prompt, padding='max_length', return_token_type_ids=False)
+        labels = torch.tensor(sample['input_ids'])
+        if self.apply_loss_on_only_vision:
+            labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, -100)
+        sample['labels'] = labels.tolist()
+        return sample
+
+    def _process_prompt_test(self, raw_prompt):
+        # to support multiple inputs, use a list instead of a single string
+        if isinstance(raw_prompt, str):
+            raw_prompt = [raw_prompt]
+        prompt_list = []
+        size_list = []
+        for text_prompt in raw_prompt:
+            prompt = self.processor.tokenizer.bos_token
+            image_prompt = (
+                self.processor.tokenizer.boi_token + self.processor.prefix_template.format(H=self.h, W=self.w)
+                + self.processor.tokenizer.img_token)
+            prompt += (text_prompt + image_prompt)
+            prompt_list.append(prompt)
+            size_list.append([self.h, self.w])
+        prompt_list = self.tokenizer(prompt_list, padding='longest', return_token_type_ids=False)
+        return prompt_list
+
+    def prepare_for_output(self, output: str) -> str:
+        return output
+
+    def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
+        from transformers import UnbatchedClassifierFreeGuidanceLogitsProcessor
+        from transformers import PrefixConstrainedLogitsProcessor
+        from transformers import LogitsProcessorList
+
+        negative_prompt = self.NEGATIVE_PROMPT
+        neg_inputs = self._process_prompt_test(negative_prompt)
+        neg_inputs = {key: torch.tensor(val) for key, val in neg_inputs.items()}
+        batch_size = generate_kwargs['input_ids'].shape[0]
+        h = torch.tensor([self.h] * batch_size)
+        w = torch.tensor([self.w] * batch_size)
+
+        constrained_fn = self.processor.build_prefix_constrained_fn(h, w)
+        logits_processor = LogitsProcessorList([
+            UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                self.CFG_SCALE,
+                model,
+                unconditional_ids=neg_inputs['input_ids'].to(get_device()),
+            ),
+            PrefixConstrainedLogitsProcessor(
+                constrained_fn,
+                num_beams=1,
+            ),
+        ])
+        res = super().prepare_generate_kwargs(generate_kwargs, model=model)
+        res['logits_processor'] = logits_processor
+        return res
+
+    def decode(self, generate_ids: List[int], **kwargs) -> Any:
+        mm_list = self.processor.decode(generate_ids)
+        for im in mm_list:
+            if not isinstance(im, Image.Image):
+                continue
+            return [{'type': 'image', 'image': im}]
+
+    def to_imgstr(self, image_tokens):
+        image_token_str = [[self.processor.visual_template[0].format(token_id=token_id) for token_id in token_row]
+                           for token_row in image_tokens]
+        image_row_str = [''.join(token_row) for token_row in image_token_str]
+        imgstr = self.tokenizer.eol_token.join(image_row_str)
+        return imgstr
+
+    def format_image_prompt(self, image_tokens):
+        h, w = image_tokens.shape
+        imgstr = self.to_imgstr(image_tokens)
+        image_prompt = (
+            self.tokenizer.boi_token + f'{h}*{w}' + self.tokenizer.img_token + imgstr + self.tokenizer.eol_token
+            + self.tokenizer.eof_token + self.tokenizer.eoi_token)
+        return image_prompt
+
+    def smart_resize(self, image):
+        w, h = image.size
+        current_area = h * w
+        target_ratio = (self.processor.image_area / current_area)**0.5
+        th = int(round(h * target_ratio))
+        tw = int(round(w * target_ratio))
+        image = image.resize((tw, th))
+        return image
+
+
+register_template(EmptyTemplateMeta(
+    MLLMTemplateType.emu3_gen,
+    template_cls=Emu3GenTemplate,
+))
+
+
+class Emu3ChatTemplate(Template):
+    system = 'You are a helpful assistant.'
+    image_placeholder = ['<|image token|>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        # image
+        images = inputs.images
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        image_tokens = self.processor.tokenize_image(images)
+        image_prompts = []
+        idx_list = findall(input_ids, self.tokenizer.encode(self.image_placeholder))
+        # Create image prompts
+        for i in range(len(images)):
+            h, w = image_tokens[i].shape
+            imgstr = self.processor.to_imgstr(image_tokens[i])
+            image_prompt = (
+                self.tokenizer.boi_token + self.processor.prefix_template.format(H=h, W=w) + self.tokenizer.img_token
+                + imgstr + self.tokenizer.eol_token + self.tokenizer.eof_token + self.tokenizer.eoi_token)
+            image_prompts.append(self.tokenizer.encode(image_prompt))
+
+        # Insert image tokens into input_ids
+        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda i: image_prompts[i])
+        return {'input_ids': input_ids, 'labels': labels}
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.emu3_chat,
+        prefix=[['bos_token_id'], '{{SYSTEM}}'],
+        prompt=[' User: {{QUERY}}. Assistant:'],
+        chat_sep=[['eos_token_id']],
+        suffix=[['eos_token_id']],
+        default_system=DEFAULT_SYSTEM,
+        template_cls=Emu3ChatTemplate))
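`smart_resize` scales both sides by the same factor `sqrt(image_area / (h * w))`, so the aspect ratio is preserved while the pixel count lands near the processor's configured `image_area`. A quick standalone check of that arithmetic (the numbers are illustrative):

```python
# Standalone check of the smart_resize scaling used in Emu3GenTemplate.
image_area = 720 * 720                     # illustrative target area
w, h = 1920, 1080                          # illustrative input size
target_ratio = (image_area / (h * w)) ** 0.5
tw, th = int(round(w * target_ratio)), int(round(h * target_ratio))
print(tw, th, tw * th)                     # 960 540 518400 == 720 * 720
```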
swift/llm/template/template/gemma.py
ADDED
@@ -0,0 +1,131 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Literal, Optional
+
+import numpy as np
+import torch
+
+from swift.utils import upper_bound
+from ..base import Template
+from ..constant import LLMTemplateType, MLLMTemplateType
+from ..register import TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, Prompt, findall
+
+
+@dataclass
+class GemmaTemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<bos>'])
+    prompt: Prompt = field(
+        default_factory=lambda: ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<end_of_turn>\n'])
+    suffix: Prompt = field(default_factory=lambda: ['<end_of_turn>'])
+    system_prefix: Optional[Prompt] = field(
+        default_factory=lambda: ['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'])
+
+
+register_template(GemmaTemplateMeta(LLMTemplateType.gemma))
+
+
+class PaliGemmaTemplate(Template):
+    placeholder_tokens = ['<image>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        if self.mode == 'vllm':
+            self.prompt = ['{{QUERY}}']
+            return []
+        else:
+            self.prompt = ['{{QUERY}}\n']
+            return ['<image>' * self.processor.image_seq_length + '<bos>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        raw_image = inputs.images
+        processor = self.processor
+        if encoded['labels'] is not None:
+            n = upper_bound(0, len(encoded['labels']), lambda idx: encoded['labels'][idx] == -100)
+            n2 = len(encoded['labels']) - n
+            encoded['token_type_ids'] = [0] * n + [1] * n2
+        else:
+            encoded['token_type_ids'] = [0] * len(encoded['input_ids'])
+        if raw_image:
+            model_inputs = processor(text='<image>' * len(raw_image), images=raw_image, return_tensors='pt')
+            encoded['pixel_values'] = model_inputs['pixel_values'].to(self.model_info.torch_dtype)
+        return encoded
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.paligemma,
+        prefix=[],
+        prompt=['{{QUERY}}\n'],
+        chat_sep=None,
+        suffix=['<eos>'],
+        template_cls=PaliGemmaTemplate,
+    ))
+
+
+@dataclass
+class Gemma3TextTemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<bos>'])
+    prompt: Prompt = field(
+        default_factory=lambda: ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<end_of_turn>\n'])
+    suffix: Prompt = field(default_factory=lambda: ['<end_of_turn>'])
+
+
+class Gemma3Template(Template):
+
+    def _swift_encode(self, inputs: StdTemplateInputs):
+        if inputs.system is not None:
+            system = inputs.system
+            inputs.system = None
+            inputs.messages[0]['content'] = system + '\n\n' + inputs.messages[0]['content']
+        for message in inputs.messages:
+            if message['role'] == 'assistant' and isinstance(message['content'], str):
+                message['content'] = message['content'].strip('\n')
+        return super()._swift_encode(inputs)
+
+
+register_template(Gemma3TextTemplateMeta(LLMTemplateType.gemma3_text, template_cls=Gemma3Template))
+
+
+class Gemma3VisionTemplate(Gemma3Template):
+    boi_token_id = 255999
+    placeholder_tokens = ['<start_of_image>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return ['<start_of_image>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+
+        encoded = super()._encode(inputs)
+        if inputs.images:
+            input_ids = encoded['input_ids']
+            labels = encoded['labels']
+            idx_list = findall(input_ids, self.boi_token_id)
+            img_tokens = self._tokenize(self.processor.full_image_sequence)
+            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, lambda _: img_tokens)
+
+            # TODO: customize
+            processor_kwargs = Gemma3ProcessorKwargs._defaults['images_kwargs']
+            image_inputs = self.processor.image_processor(inputs.images, **processor_kwargs)
+            image_inputs['pixel_values'] = torch.as_tensor(np.array(image_inputs['pixel_values']))
+            image_inputs.pop('num_crops')
+
+            array_ids = np.array(input_ids)
+            mm_token_type_ids = np.zeros_like(array_ids)
+            mm_token_type_ids[array_ids == self.processor.image_token_id] = 1
+            encoded['token_type_ids'] = mm_token_type_ids.tolist()
+            encoded['input_ids'] = input_ids
+            encoded['pixel_values'] = image_inputs['pixel_values']
+            encoded['labels'] = labels
+        return encoded
+
+
+register_template(GemmaTemplateMeta(MLLMTemplateType.gemma3_vision, template_cls=Gemma3VisionTemplate))
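In `PaliGemmaTemplate._encode`, `token_type_ids` splits the sequence into prefix (prompt plus image tokens, type 0) and suffix (answer tokens, type 1), using the leading run of `-100` labels as the boundary. An equivalent standalone sketch of that boundary computation, without the repo's `upper_bound` helper:

```python
# Equivalent sketch of the PaliGemma token_type_ids construction: the leading
# run of -100 labels (prompt + image tokens) gets type 0, the rest type 1.
labels = [-100, -100, -100, 42, 7, 99]      # illustrative label ids
n = next((i for i, x in enumerate(labels) if x != -100), len(labels))
token_type_ids = [0] * n + [1] * (len(labels) - n)
print(token_type_ids)                       # [0, 0, 0, 1, 1, 1]
```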
swift/llm/template/template/glm.py
ADDED
@@ -0,0 +1,293 @@
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ..base import Template
|
| 8 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 9 |
+
from ..register import TemplateMeta, register_template
|
| 10 |
+
from ..template_inputs import StdTemplateInputs
|
| 11 |
+
from ..utils import Context, Prompt, Word, findall
|
| 12 |
+
from ..vision_utils import load_batch, load_video_cogvlm2
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class GLMTemplateMeta(TemplateMeta):
|
| 17 |
+
auto_add_bos: bool = True
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GLM4Template(Template):
|
| 21 |
+
|
| 22 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 23 |
+
res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
|
| 24 |
+
for i, res_context in enumerate(res_context_list):
|
| 25 |
+
# The last round or is tool_call.
|
| 26 |
+
if isinstance(res_context, str) and res_context.endswith('<|assistant|>\n') and (
|
| 27 |
+
i + 1 >= len(res_context_list) or '<|observation|>' in res_context_list[i + 1]):
|
| 28 |
+
res_context_list[i] = res_context_list[i][:-len('\n')]
|
| 29 |
+
return res_context_list, loss_scale_list, answer_len
|
| 30 |
+
|
| 31 |
+
def decode(self, *args, **kwargs):
|
| 32 |
+
response = super().decode(*args, **kwargs)
|
| 33 |
+
return response.lstrip('\n')
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GLM4_0414Template(GLM4Template):
|
| 37 |
+
|
| 38 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 39 |
+
if not self.is_training:
|
| 40 |
+
for message in inputs.messages:
|
| 41 |
+
if message['role'] == 'assistant' and isinstance(message['content'], str):
|
| 42 |
+
message['content'] = message['content'].split('</think>')[-1].strip()
|
| 43 |
+
return super()._swift_encode(inputs)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
register_template(
|
| 47 |
+
GLMTemplateMeta(
|
| 48 |
+
LLMTemplateType.chatglm2,
|
| 49 |
+
prefix=['{{SYSTEM}}'],
|
| 50 |
+
prompt=['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'],
|
| 51 |
+
chat_sep=['\n\n']))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class GLM4TemplateMeta(GLMTemplateMeta):
|
| 56 |
+
prefix: Prompt = field(default_factory=list)
|
| 57 |
+
prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n'])
|
| 58 |
+
chat_sep: Optional[Prompt] = field(default_factory=list)
|
| 59 |
+
suffix: Prompt = field(default_factory=lambda: ['<|user|>'])
|
| 60 |
+
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}'])
|
| 61 |
+
|
| 62 |
+
agent_template: str = 'glm4'
|
| 63 |
+
stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>'])
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class GLM4_0414TemplateMeta(GLM4TemplateMeta):
|
| 68 |
+
prefix: Prompt = field(default_factory=lambda: ['[gMASK]<sop>'])
|
| 69 |
+
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>\n{{SYSTEM}}'])
|
| 70 |
+
agent_template: str = 'glm4_0414'
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class GLM4VTemplate(Template):
|
| 74 |
+
|
| 75 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 76 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 77 |
+
assert media_type == 'image'
|
| 78 |
+
return [[-100]]
|
| 79 |
+
|
| 80 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 81 |
+
encoded = super()._encode(inputs)
|
| 82 |
+
input_ids = encoded['input_ids']
|
| 83 |
+
labels = encoded['labels']
|
| 84 |
+
idx_list = findall(input_ids, -100)
|
| 85 |
+
if idx_list:
|
| 86 |
+
idx = idx_list[0]
|
| 87 |
+
image = inputs.images[0]
|
| 88 |
+
placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
|
| 89 |
+
placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
|
| 90 |
+
input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
|
| 91 |
+
if labels is not None:
|
| 92 |
+
labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
|
| 93 |
+
messages = inputs.messages
|
| 94 |
+
messages[0]['image'] = image
|
| 95 |
+
inputs2: Dict[str, Any] = self.processor.apply_chat_template(messages, return_dict=True)
|
| 96 |
+
encoded['images'] = inputs2['images']
|
| 97 |
+
encoded['input_ids'] = input_ids
|
| 98 |
+
encoded['labels'] = labels
|
| 99 |
+
encoded['position_ids'] = list(range(len(input_ids)))
|
| 100 |
+
return encoded
|
| 101 |
+
|
| 102 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 103 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 104 |
+
images = [b['images'] for b in batch if 'images' in b]
|
| 105 |
+
if images:
|
| 106 |
+
res['images'] = torch.concat(images)
|
| 107 |
+
return res
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>']))
|
| 111 |
+
|
| 112 |
+
register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template))
|
| 113 |
+
|
| 114 |
+
register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
|
| 115 |
+
|
| 116 |
+
glm4z1rumination_system = (
|
| 117 |
+
'你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。'
|
| 118 |
+
'今年是 2025 年。\n\n'
|
| 119 |
+
'<核心要求>\n'
|
| 120 |
+
'- 首先分解用户请求,得到包含多个子要求的列表\n'
|
| 121 |
+
'- 制定初始研究计划\n'
|
| 122 |
+
'- 进行多轮迭代搜索和页面浏览(at least 10 function calls):\n'
|
| 123 |
+
' * 根据已获得的信息调整研究计划和关键词\n'
|
| 124 |
+
' * 打开页面阅读,从发现的内容中识别新的关键概念/名词\n'
|
| 125 |
+
' * 从搜索结果中提取新的关键词继续搜索\n'
|
| 126 |
+
' * 访问并仔细阅读相关页面,识别新的关键概念/名词\n\n'
|
| 127 |
+
'<重要配置>\n'
|
| 128 |
+
'- 采用语言\n'
|
| 129 |
+
' * 搜索关键词:英语\n'
|
| 130 |
+
' * 思考:英语\n\n'
|
| 131 |
+
'<可调用的工具列表>\n\n'
|
| 132 |
+
'[{"name": "search", "description": "Execute a search query and return search results. '
|
| 133 |
+
'Use this function when you need to find information about a specific topic.", '
|
| 134 |
+
'"parameters": {"type": "object", "properties": {"query": {"type": "string", '
|
| 135 |
+
'"description": "Search query string, use English words unless it is a proper name in Chinese"}}, '
|
| 136 |
+
'"required": ["query"], "additionalProperties": false}}, '
|
| 137 |
+
'{"name": "click", "description": "Click a link in the search results and navigate to the corresponding page. '
|
| 138 |
+
'Use this function when you need to view detailed content of a specific search result.", '
|
| 139 |
+
'"parameters": {"type": "object", "properties": {"link_id": {"type": "integer", '
|
| 140 |
+
'"description": "The link ID to click (from the sequence number in search results)"}}, '
|
| 141 |
+
'"required": ["link_id"], "additionalProperties": false}}, '
|
| 142 |
+
'{"name": "open", "description": "Open a specific website. Get content from any website with its URL.", '
|
| 143 |
+
'"parameters": {"type": "object", "properties": {"url": {"type": "string", '
|
| 144 |
+
'"description": "The target website URL or domain"}}, "required": ["url"], "additionalProperties": false}}, '
|
| 145 |
+
'{"name": "finish", "description": "Finish the task. '
|
| 146 |
+
'Use this function when you have found the information you need.", '
|
| 147 |
+
'"parameters": {"type": "object", "properties": {}, "additionalProperties": false}}]')
|
| 148 |
+
|
| 149 |
+
register_template(
|
| 150 |
+
GLM4_0414TemplateMeta(
|
| 151 |
+
LLMTemplateType.glm4_z1_rumination, template_cls=GLM4_0414Template, default_system=glm4z1rumination_system))
|
| 152 |
+
|
| 153 |
+
codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。'
|
| 154 |
+
|
| 155 |
+
register_template(GLM4TemplateMeta(LLMTemplateType.codegeex4, default_system=codegeex4_system))
|
| 156 |
+
|
| 157 |
+
register_template(
|
| 158 |
+
TemplateMeta(
|
| 159 |
+
LLMTemplateType.longwriter_llama, ['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'],
|
| 160 |
+
system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class CogTemplate(Template):
|
| 164 |
+
placeholder_tokens = ['<|reserved_special_token_0|>']
|
| 165 |
+
|
| 166 |
+
use_model = True
|
| 167 |
+
|
| 168 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 169 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 170 |
+
return []
|
| 171 |
+
|
| 172 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 173 |
+
encoded = super()._encode(inputs)
+        model = self.model
+        image = inputs.images or []
+        history_inputs = inputs.to_history()
+        inputs2 = model.build_conversation_input_ids(
+            self.processor, query=history_inputs['query'], history=history_inputs['history'], images=image)
+        image_token_len = inputs2['token_type_ids'].sum().item()
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        encoded['token_type_ids'] = [0] + [1] * image_token_len + [0] * len(input_ids[1:])
+        encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * image_token_len + input_ids[1:]
+        if labels is not None:
+            encoded['labels'] = labels[:1] + [-100] * image_token_len + labels[1:]
+        if len(image) > 0:
+            encoded['images'] = [[img.to(dtype=self.model_info.torch_dtype)] for img in inputs2['images']]
+            if 'cross_images' in inputs2:
+                # is cogagent
+                encoded['cross_images'] = [[cross_img.to(dtype=self.model_info.torch_dtype)]
+                                           for cross_img in inputs2['cross_images']]
+        return encoded
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super()._data_collator(batch, padding_to=padding_to)
+        keys = ['images', 'cross_images']
+        for key in keys:
+            if key in batch[0]:
+                res[key] = [b[key][0] for b in batch]
+        return res
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.cogagent_chat,
+        prefix=['<s>'],
+        prompt=[' [INST] {{QUERY}} [/INST] '],
+        chat_sep=[],
+        suffix=['</s>'],
+        template_cls=CogTemplate,
+    ))
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.cogagent_vqa,
+        prefix=['<s>'],
+        prompt=['<EOI>Question: {{QUERY}} Answer:'],
+        chat_sep=None,
+        suffix=['</s>'],
+        template_cls=CogTemplate))
+
+
+@dataclass
+class CogVLMTemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
+    prompt: Prompt = field(default_factory=lambda: ['Question: {{QUERY}} Answer:'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['\n'])
+
+
+register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm, template_cls=CogTemplate))
+
+register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm2, template_cls=CogTemplate))
+
+
+class Cog2VideoTemplate(CogTemplate):
+    use_model = True
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        model = self.model
+        encoded = super(CogTemplate, self)._encode(inputs)
+        videos_path = inputs.videos or []
+        video = load_batch(videos_path, load_video_cogvlm2)
+        history_inputs = inputs.to_history()
+        inputs2 = model.build_conversation_input_ids(
+            self.processor,
+            query=history_inputs['query'],
+            history=history_inputs['history'],
+            images=video,
+            template_version='chat')
+        video_token_len = inputs2['token_type_ids'].sum().item()
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        encoded['token_type_ids'] = [0] + [1] * video_token_len + [0] * len(input_ids[1:])
+        encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * video_token_len + input_ids[1:]
+        if labels is not None:
+            encoded['labels'] = labels[:1] + [-100] * video_token_len + labels[1:]
+        if len(video) > 0:
+            dtype = model.dtype
+            encoded['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
+        return encoded
+
+
+register_template(CogVLMTemplateMeta(
+    MLLMTemplateType.cogvlm2_video,
+    template_cls=Cog2VideoTemplate,
+))
+
+
+class GLMEdgeVTemplate(Template):
+    placeholder_tokens = ['<|begin_of_image|>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return ['<|begin_of_image|>' * 578]
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        images = inputs.images
+        if images:
+            encoded['pixel_values'] = torch.tensor(self.processor(images).pixel_values)
+        return encoded
+
+
+register_template(
+    GLM4TemplateMeta(
+        MLLMTemplateType.glm_edge_v,
+        prompt=['<|user|>\\n{{QUERY}}\\n<|assistant|>\\n'],
+        chat_sep=['\\n'],
+        system_prefix=['<|system|>\\n{{SYSTEM}}\\n'],
+        suffix=['<|endoftext|>'],
+        template_cls=GLMEdgeVTemplate,
+    ))
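Every CogVLM-family template above performs the same splice: keep the BOS token, insert `image_token_len` pad tokens right after it, and mirror that insertion in `token_type_ids` (1 marks a vision position) and `labels` (-100 so the loss skips them). A minimal standalone sketch of that splice, with a made-up pad id and token lists purely for illustration:

# Hypothetical demo of the CogTemplate splice above; PAD_ID and the token
# lists are invented for illustration, not taken from a real tokenizer.
PAD_ID = 0
input_ids = [1, 101, 102, 103]   # [BOS, text...]
labels = [-100, 101, 102, 103]
image_token_len = 3              # vision positions to reserve after BOS

token_type_ids = [0] + [1] * image_token_len + [0] * len(input_ids[1:])
input_ids = input_ids[:1] + [PAD_ID] * image_token_len + input_ids[1:]
labels = labels[:1] + [-100] * image_token_len + labels[1:]

assert len(input_ids) == len(token_type_ids) == len(labels)
print(input_ids)       # [1, 0, 0, 0, 101, 102, 103]
print(token_type_ids)  # [0, 1, 1, 1, 0, 0, 0]
print(labels)          # [-100, -100, -100, -100, 101, 102, 103]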
swift/llm/template/template/idefics3.py ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict
+
+from ..base import Template
+from ..constant import MLLMTemplateType
+from ..register import TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import align_image_inputs
+
+
+class Idefics3Template(Template):
+    placeholder_tokens = ['<image>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        images = inputs.images or []
+        processor = self.processor
+        prompt = self.processor.decode(encoded['input_ids'])
+        if images:
+            image_inputs = processor(text=prompt, images=images, return_tensors='pt', add_special_tokens=False)
+            image_token = 128257  # <image>
+            encoded['input_ids'], encoded['labels'] = align_image_inputs(encoded['input_ids'], encoded['labels'],
+                                                                         image_inputs['input_ids'][0], image_token)
+            encoded['pixel_values'] = image_inputs['pixel_values']
+        return encoded
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.idefics3,
+        prefix=['<|begin_of_text|>'],
+        prompt=['User:{{QUERY}}<end_of_utterance>\nAssistant:'],
+        chat_sep=['<end_of_utterance>\n'],
+        suffix=['<end_of_utterance>'],
+        system_prefix=['System:{{SYSTEM}}<end_of_utterance>\n'],
+        template_cls=Idefics3Template,
+    ))
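Rather than expanding placeholders itself, Idefics3Template decodes the already-encoded prompt, re-runs it through the processor together with the images, and then aligns the two id sequences around the `<image>` token. The sketch below is a hypothetical, simplified re-implementation of that alignment idea (token id 9 and the function name are invented; the real helper is swift's align_image_inputs):

IMAGE_TOKEN = 9  # stand-in for the real <image> id (128257 above)

def align(input_ids, labels, expanded_ids, image_token=IMAGE_TOKEN):
    new_ids, new_labels = [], []
    i = 0  # cursor into the original (unexpanded) sequence
    for tok in expanded_ids:
        if tok == image_token:
            new_ids.append(tok)
            new_labels.append(-100)  # never compute loss on image positions
            if i < len(input_ids) and input_ids[i] == image_token:
                i += 1  # consume the single placeholder once per run
        else:
            new_ids.append(tok)
            new_labels.append(labels[i])
            i += 1
    return new_ids, new_labels

ids, lbs = align([1, IMAGE_TOKEN, 5], [-100, -100, 5],
                 [1, IMAGE_TOKEN, IMAGE_TOKEN, IMAGE_TOKEN, 5])
print(ids)  # [1, 9, 9, 9, 5]
print(lbs)  # [-100, -100, -100, -100, 5]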
swift/llm/template/template/internlm.py ADDED
@@ -0,0 +1,195 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Literal, Optional
+
+import torch
+from PIL import Image
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+
+from swift.utils import get_env_args
+from ..base import Template
+from ..constant import LLMTemplateType, MLLMTemplateType, RMTemplateType
+from ..register import TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, Prompt, Word
+from ..vision_utils import load_file
+from .utils import ChatmlTemplateMeta
+
+INTERNLM_SYSTEM = (
+    'You are an AI assistant whose name is InternLM (书生·浦语).\n'
+    '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
+    'It is designed to be helpful, honest, and harmless.\n'
+    '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen '
+    'by the user such as English and 中文.')
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.internlm,
+        prefix=['<s>'],
+        prompt=['<|User|>:{{QUERY}}\n<|Bot|>:'],
+        chat_sep=['<eoa>\n'],
+        suffix=['<eoa>'],
+        default_system=INTERNLM_SYSTEM,
+        system_prefix=['<s><|System|>:{{SYSTEM}}\n']))
+
+register_template(ChatmlTemplateMeta(LLMTemplateType.internlm2, default_system=INTERNLM_SYSTEM))
+
+register_template(ChatmlTemplateMeta(RMTemplateType.internlm2_reward, suffix=['<|im_end|>\n<|reward|>']))
+
+
+class InternLMXComposer2Template(Template):
+    image_placeholder = ['</s>']
+    version = 'v2'
+    skip_prompt = False
+    use_model = True
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        if media_type == 'video':
+            inputs.images.insert(inputs.image_idx, inputs.videos[index])
+            inputs.image_idx += 1
+        return self.image_placeholder
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        model = self.model
+        encoded = super()._encode(inputs)
+        images = inputs.images or []
+
+        if self.version == 'v2.5':
+            hd_num = 24
+            if len(images) > 1:
+                hd_num = 6
+            hd_num = get_env_args('hd_num', int, hd_num)
+            images_origin = images
+            images = []
+            for image in images_origin:
+                if isinstance(image, Image.Image):
+                    Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', model.model_dir)
+                    images.append(Image_transform(image, hd_num=hd_num))
+                else:
+                    load_video = get_class_from_dynamic_module('ixc_utils.load_video', model.model_dir)
+                    frame2img = get_class_from_dynamic_module('ixc_utils.frame2img', model.model_dir)
+                    Video_transform = get_class_from_dynamic_module('ixc_utils.Video_transform', model.model_dir)
+                    image = load_video(load_file(image))
+                    image = frame2img(image, model.font)
+                    images.append(Video_transform(image, hd_num=hd_num))
+        elif self.version == 'v2-4khd':
+            hd_num = 55
+            hd_num = get_env_args('hd_num', int, hd_num)
+            HD_transform = get_class_from_dynamic_module('ixc_utils.HD_transform', model.model_dir)
+            images = [HD_transform(image, hd_num=hd_num) for image in images]
+        images = [model.vis_processor(image).to(model.dtype) for image in images]
+        encoded['images'] = images
+        return encoded
+
+    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        batch_size = len(inputs['input_ids'])
+        res = []
+        im_mask = []
+        length = inputs['length']
+        for i in range(batch_size):
+            input_ids = inputs['input_ids'][i].tolist()[:length[i]]
+            input_ids.append(2)  # add dummy </s>
+            labels = inputs.get('labels')
+            if labels is not None:
+                labels = labels[i].tolist()[:length[i]]
+                labels.append(2)
+            else:
+                labels = []
+            images = inputs['images'][i]
+            res_inputs_embeds = []
+            res_labels = []
+            wrap_im_mask = []
+            pre_i, i, idx = 0, 0, 0
+            device = model.device
+            internlm2_model = model.model
+            if not hasattr(internlm2_model, 'tok_embeddings'):
+                internlm2_model = internlm2_model.model
+            tok_embeddings = internlm2_model.tok_embeddings
+            if len(images) > 0:
+                images = torch.concat([model.img2emb(image[None])[0] for image in images], dim=0)
+            add_bos = False
+            while i < len(input_ids):
+                if input_ids[i] == 2:  # replace_token
+                    res_input_ids = torch.tensor(([1] if add_bos else []) + input_ids[pre_i:i], device=device)
+                    if not add_bos and self.version != 'v2.5':
+                        add_bos = True
+                    res_inputs_embeds.append(tok_embeddings(res_input_ids[None])[0])
+                    wrap_im_mask += [0] * len(res_input_ids)
+                    res_labels += ([-100] if add_bos else []) + labels[pre_i:i]
+                    if len(images) > 0 and idx < images.shape[0]:
+                        res_inputs_embeds.append(images[idx].to(device))
+                        wrap_im_mask += [1] * images.shape[1]
+                        res_labels += [-100] * images.shape[1]
+                        idx += 1
+                    i += 1
+                    pre_i = i
+                    continue
+                i += 1
+            if len(labels) == 0:
+                res_labels = None
+            im_mask.append(torch.tensor(wrap_im_mask, dtype=torch.bool, device=device))
+            res.append({'inputs_embeds': torch.concat(res_inputs_embeds, dim=0), 'labels': res_labels})
+        res = Template._data_collator(self, res)
+        res['im_mask'] = self._pad_sequence(im_mask, 0)
+        return res
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super()._data_collator(batch, padding_to=padding_to)
+        res['length'] = [len(b['input_ids']) for b in batch]
+        res.update(self.fetch_inputs(batch, ['images']))
+        return res
+
+
+@dataclass
+class Xcomposer2TemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<s>'])
+    prompt: Prompt = field(
+        default_factory=lambda: ['[UNUSED_TOKEN_146]user\n{{QUERY}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['[UNUSED_TOKEN_145]\n'])
+    suffix: Prompt = field(default_factory=lambda: ['[UNUSED_TOKEN_145]'])
+    system_prefix: Optional[Prompt] = field(
+        default_factory=lambda: ['<s>[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n'])
+    stop_words: List[Word] = field(default_factory=lambda: ['<|im_end|>'])
+
+
+register_template(
+    Xcomposer2TemplateMeta(
+        MLLMTemplateType.xcomposer2,
+        template_cls=InternLMXComposer2Template,
+        default_system=('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
+                        '- InternLM-XComposer (浦语·灵笔) is a conversational language model that is developed by '
+                        'Shanghai AI Laboratory (上海人工智能实验室). '
+                        'It is designed to be helpful, honest, and harmless.\n'
+                        '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
+                        'by the user such as English and 中文.'),
+    ))
+
+
+class InternLMXComposer2_5Template(InternLMXComposer2Template):
+    system = ('You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
+              '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model '
+              'that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
+              'It is designed to be helpful, honest, and harmless.\n'
+              '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
+              'by the user such as English and 中文.\n'
+              '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively '
+              'based on the provided image.')
+    version = 'v2.5'
+
+
+class InternLMXComposer2_4khdTemplate(InternLMXComposer2Template):
+    version = 'v2-4khd'
+
+
+register_template(
+    Xcomposer2TemplateMeta(
+        MLLMTemplateType.xcomposer2_5,
+        template_cls=InternLMXComposer2_5Template,
+        default_system=InternLMXComposer2_5Template.system))
+
+register_template(
+    Xcomposer2TemplateMeta(
+        MLLMTemplateType.xcomposer2_4khd,
+        template_cls=InternLMXComposer2_4khdTemplate,
+        default_system=InternLMXComposer2_5Template.system))
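The `_post_encode` above walks the id sequence, treating token id 2 (the `</s>` placeholder) as an image slot: text spans go through tok_embeddings, image embeddings are spliced in between, and `im_mask` records which positions are vision. A torch-free sketch of that interleaving scan (the function name and ids are invented for illustration):

REPLACE_TOKEN = 2  # placeholder id used by the template above

def interleave(input_ids, image_lens):
    chunks, im_mask = [], []
    pre, idx = 0, 0
    for i, tok in enumerate(input_ids):
        if tok == REPLACE_TOKEN:
            chunks.append(('text', input_ids[pre:i]))
            im_mask += [0] * (i - pre)
            if idx < len(image_lens):
                chunks.append(('image', image_lens[idx]))
                im_mask += [1] * image_lens[idx]
                idx += 1
            pre = i + 1
    chunks.append(('text', input_ids[pre:]))
    im_mask += [0] * (len(input_ids) - pre)
    return chunks, im_mask

chunks, mask = interleave([5, 6, 2, 7, 8], image_lens=[3])
print(chunks)  # [('text', [5, 6]), ('image', 3), ('text', [7, 8])]
print(mask)    # [0, 0, 1, 1, 1, 0, 0]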
swift/llm/template/template/internvl.py ADDED
@@ -0,0 +1,168 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from functools import partial
+from typing import Any, Dict, List, Literal
+
+import torch
+from torch import nn
+
+from swift.utils import get_env_args, is_deepspeed_enabled
+from ..base import Template
+from ..constant import MLLMTemplateType
+from ..register import register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, findall
+from ..vision_utils import load_video_internvl, transform_image
+from .microsoft import Phi3TemplateMeta
+from .utils import ChatmlTemplateMeta
+
+
+class InternvlTemplate(Template):
+    skip_prompt = False
+    num_image_token = 256
+    placeholder_tokens = ['<IMG_CONTEXT>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        if self.mode == 'vllm':
+            image_context = ['<image>\n']
+        else:
+            image_context = ['<img>', [-100], '</img>\n']
+        return image_context
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        input_ids = encoded['input_ids']
+        idx_list = findall(input_ids, -100)
+        pixel_values = None
+        images = inputs.images
+        if images:
+            labels = encoded.get('labels')
+            input_size = get_env_args('input_size', int, 448)
+            max_num = get_env_args('max_num', int, 12)
+            pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
+            pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model_info.torch_dtype)
+            image_bs = pixel_values.shape[0]
+
+            idx, idx2 = idx_list[0], idx_list[-1]  # remove [-100, -100]
+            img_tokens: List[int] = self.processor.encode(
+                '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * image_bs
+            input_ids = input_ids[:idx] + img_tokens + input_ids[idx2 + 1:]
+            if labels is not None:
+                labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:]
+            encoded['input_ids'] = input_ids
+            encoded['labels'] = labels
+        encoded['pixel_values'] = pixel_values
+        return encoded
+
+    def compute_loss_context(self, model, inputs):
+        model_name = model.language_model.__class__.__name__.lower()
+        if self._packing and 'internlm2' in model_name:
+            position_ids = inputs['position_ids']
+            modeling_module = model.language_model.model.layers[0].attention.__class__
+            return self._patch_flash_attention_forward(modeling_module, position_ids, use_new_func=True)
+        else:
+            return super().compute_loss_context(model, inputs)
+
+    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        embedding = model.get_input_embeddings()
+        device = embedding.weight.device
+        input_ids = inputs['input_ids']
+        inputs_embeds = embedding(input_ids).to(device=device)
+        pixel_values = inputs.get('pixel_values')
+        if pixel_values is not None:
+            pixel_values = pixel_values.to(device=device)
+            vit_embeds = model.extract_feature(pixel_values).to(device=device)
+            selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
+            inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
+        elif is_deepspeed_enabled():
+            dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
+            vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
+            inputs_embeds += vit_embeds.mean() * 0.
+        return {'inputs_embeds': inputs_embeds}
+
+
+register_template(
+    ChatmlTemplateMeta(
+        MLLMTemplateType.internvl,
+        default_system='You are an AI assistant whose name is InternLM (书生·浦语).',
+        template_cls=InternvlTemplate,
+        auto_add_bos=True))
+register_template(
+    Phi3TemplateMeta(
+        MLLMTemplateType.internvl_phi3,
+        default_system='You are an AI assistant whose name is Phi-3.',
+        template_cls=InternvlTemplate,
+        auto_add_bos=True))
+
+
+class Internvl2Template(InternvlTemplate):
+    video_segments = 8
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        image_context = super().replace_tag('image', index, inputs)
+        if media_type == 'image':
+            return image_context
+        elif media_type == 'video':
+            video_segments = get_env_args('video_segments', int, self.video_segments)
+            load_video = partial(load_video_internvl, num_segments=video_segments)
+            return self.replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context)
+
+    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
+        return [f'<ref>{ref}</ref>']
+
+    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
+        return [f'<box>[{bbox}]</box>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super(InternvlTemplate, self)._encode(inputs)
+        input_ids = encoded['input_ids']
+        idx_list = findall(input_ids, -100)
+        labels = encoded['labels']
+        images = inputs.images
+        if images:
+            has_video = bool(inputs.videos)
+            input_size = get_env_args('input_size', int, 448)
+            max_num = get_env_args('max_num', int, 12)
+            video_max_num = get_env_args('video_max_num', int, 1)
+            if has_video:
+                max_num = video_max_num
+            pixel_values = [transform_image(image, input_size, max_num) for image in images]
+            num_patches = [pv.shape[0] for pv in pixel_values]
+            pixel_values = torch.cat(pixel_values).to(self.model_info.torch_dtype)
+        else:
+            pixel_values = None
+            num_patches = []
+        assert len(num_patches) == len(
+            idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
+
+        def _get_new_tokens(i):
+            img_tokens: List[int] = self.processor.encode(
+                '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patches[i]
+            return img_tokens
+
+        encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
+        encoded['pixel_values'] = pixel_values
+        return encoded
+
+
+_internvl2_system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
+register_template(
+    ChatmlTemplateMeta(
+        MLLMTemplateType.internvl2,
+        default_system=_internvl2_system,
+        template_cls=Internvl2Template,
+    ))
+
+register_template(
+    Phi3TemplateMeta(
+        MLLMTemplateType.internvl2_phi3,
+        default_system=_internvl2_system,
+        template_cls=Internvl2Template,
+    ))
+
+register_template(
+    ChatmlTemplateMeta(
+        MLLMTemplateType.internvl2_5,
+        template_cls=Internvl2Template,
+        default_system='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。'))
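Internvl2Template delegates placeholder expansion to `_extend_tokens`: each -100 marker found by `findall` is replaced by `num_image_token * num_patches[i]` copies of the `<IMG_CONTEXT>` id, with labels masked at the inserted positions. A hypothetical, simplified version of that helper (the real one lives in swift's template utilities):

def extend_tokens(input_ids, labels, idx_list, get_new_tokens):
    added = 0  # tracks how far earlier insertions shift later indices
    for i, idx in enumerate(idx_list):
        new_tokens = get_new_tokens(i)
        pos = idx + added
        input_ids = input_ids[:pos] + new_tokens + input_ids[pos + 1:]
        if labels is not None:
            labels = labels[:pos] + [-100] * len(new_tokens) + labels[pos + 1:]
        added += len(new_tokens) - 1
    return input_ids, labels

ids, lbs = extend_tokens([7, -100, 8], [7, -100, 8], [1], lambda i: [42] * 4)
print(ids)  # [7, 42, 42, 42, 42, 8]
print(lbs)  # [7, -100, -100, -100, -100, 8]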
swift/llm/template/template/llama.py ADDED
@@ -0,0 +1,213 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import datetime as dt
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Literal, Optional
+
+import torch
+import torch.nn as nn
+
+from swift.utils import get_env_args
+from ..base import Template
+from ..constant import LLMTemplateType, MLLMTemplateType
+from ..register import TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, Prompt, Word, findall
+from ..vision_utils import load_batch
+
+# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
+LLAMA_DEFAULT_SYSTEM = (
+    'You are a helpful, respectful and honest assistant. '
+    'Always answer as helpfully as possible, while being safe. '
+    'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
+    'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
+    'If a question does not make any sense, or is not factually coherent, '
+    'explain why instead of answering something not correct. '
+    "If you don't know the answer to a question, please don't share false information.")
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.llama, ['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'],
+        default_system=LLAMA_DEFAULT_SYSTEM,
+        system_prefix=['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
+
+
+@dataclass
+class Llama3TemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
+    prompt: Prompt = field(default_factory=lambda: [
+        '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
+        '<|start_header_id|>assistant<|end_header_id|>\n\n'
+    ])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot_id|>'])
+    suffix: Prompt = field(default_factory=lambda: ['<|eot_id|>'])
+    system_prefix: Optional[Prompt] = field(
+        default_factory=lambda: ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
+    agent_template: str = 'llama3'
+
+
+register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
+
+
+def _get_llama3_2_prefix() -> Prompt:
+    now = dt.datetime.now()
+    date_string = now.strftime('%d %b %Y')
+    date_prompt = f'Cutting Knowledge Date: December 2023\nToday Date: {date_string}'
+    return [f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{date_prompt}\n\n' '{{SYSTEM}}<|eot_id|>']
+
+
+@dataclass
+class Llama3_2TemplateMeta(Llama3TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: _get_llama3_2_prefix())
+    system_prefix: Optional[Prompt] = None
+
+
+register_template(Llama3_2TemplateMeta(LLMTemplateType.llama3_2))
+
+
+class Llama3_2VisionTemplate(Template):
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return ['<|image|>']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        from transformers.models.mllama.processing_mllama import (get_cross_attention_token_mask,
+                                                                  convert_sparse_cross_attention_mask_to_dense)
+        encoded = super()._encode(inputs)
+        images = inputs.images
+        if images:
+            input_ids = encoded['input_ids']
+            processor = self.processor
+            image_features = processor.image_processor(images, return_tensors='pt')
+            num_tiles = image_features.pop('num_tiles')
+            encoded.update(image_features)
+
+            cross_attention_token_mask = [get_cross_attention_token_mask(input_ids, processor.image_token_id)]
+            cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=processor.image_processor.max_image_tiles,
+                length=len(input_ids),
+            )
+            encoded['cross_attention_mask'] = torch.tensor(cross_attention_mask)
+
+        return encoded
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super()._data_collator(batch, padding_to=padding_to)
+        for key in ['aspect_ratio_ids', 'aspect_ratio_mask']:
+            value = [b[key] for b in batch if b.get(key) is not None]
+            if value:
+                res[key] = torch.concat(value)
+
+        cross_attention_mask = [
+            b['cross_attention_mask'][0] for b in batch if b.get('cross_attention_mask') is not None
+        ]
+        if cross_attention_mask:
+            res['cross_attention_mask'] = self._pad_sequence(cross_attention_mask, 0)
+        return res
+
+
+register_template(Llama3_2TemplateMeta(MLLMTemplateType.llama3_2_vision, template_cls=Llama3_2VisionTemplate))
+
+
+class Llama4Template(Template):
+    placeholder_tokens = ['<|patch|>']
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return [[-100]]
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        images = inputs.images
+        if images:
+            split_token = self._tokenize('\n')
+            input_ids, labels = encoded['input_ids'], encoded['labels']
+            idx_list = findall(input_ids, -100)
+            media_inputs = self.processor(
+                text='\n'.join(['<|image|>'] * len(idx_list)),
+                images=images,
+                add_special_tokens=False,
+                return_tensors='pt')
+            splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)
+
+            encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
+                                                                          lambda i: splited_tokens[i])
+            encoded['pixel_values'] = media_inputs['pixel_values']
+        return encoded
+
+
+@dataclass
+class Llama4TemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
+    prompt: Prompt = field(
+        default_factory=lambda:
+        ['<|header_start|>user<|header_end|>\n\n{{QUERY}}<|eot|>'
+         '<|header_start|>assistant<|header_end|>\n\n'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot|>'])
+    suffix: Prompt = field(default_factory=lambda: ['<|eot|>'])
+    stop_words: List[Word] = field(default_factory=lambda: ['<|end_of_text|>', '<|eom|>'])
+    system_prefix: Optional[Prompt] = field(
+        default_factory=lambda: ['<|begin_of_text|><|header_start|>system<|header_end|>\n\n{{SYSTEM}}<|eot|>'])
+    agent_template: str = 'llama4'
+
+
+register_template(Llama4TemplateMeta(MLLMTemplateType.llama4, template_cls=Llama4Template))
+
+register_template(
+    Llama3TemplateMeta(
+        LLMTemplateType.reflection,
+        default_system=('You are a world-class AI system, capable of complex reasoning and reflection. '
+                        'Reason through the query inside <thinking> tags, and then provide your final '
+                        'response inside <output> tags. If you detect that you made a mistake in your reasoning '
+                        'at any point, correct yourself inside <reflection> tags.')))
+
+
+class Llama3_1OmniTemplate(Template):
+    skip_prompt = False
+    audio_placeholder = [[-200]]
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        import whisper
+        encoded = super()._encode(inputs)
+        audios = inputs.audios
+        if audios:
+            audios = load_batch(audios, whisper.load_audio)
+            n_mels = get_env_args('n_mels', int, 128)
+            for i, audio in enumerate(audios):
+                audio = whisper.pad_or_trim(audio)
+                audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0)
+            audios = torch.stack(audios)
+            encoded.update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])})
+
+        return encoded
+
+    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        speech = inputs.get('speech')
+        input_ids = inputs['input_ids']
+        labels = inputs.get('labels')
+        if speech is not None:
+            speech_lengths = inputs['speech_lengths']
+            speech = speech.to(model.dtype)
+            inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels,
+                                                                                    speech, speech_lengths)[4:]
+        else:
+            inputs_embeds = model.get_model().embed_tokens(input_ids)
+        res = {'inputs_embeds': inputs_embeds}
+        if labels is not None:
+            res['labels'] = labels[0]
+        return res
+
+
+register_template(
+    Llama3TemplateMeta(
+        MLLMTemplateType.llama3_1_omni,
+        default_system=('You are a helpful language and speech assistant. '
+                        'You are able to understand the speech content that the user provides, '
+                        'and assist the user with a variety of tasks using natural language.'),
+        template_cls=Llama3_1OmniTemplate,
+    ))
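Substituting the placeholders by hand shows the wire format the Llama 3 meta above produces for a single turn; this is just a rendering check, not part of the diff:

# What Llama3TemplateMeta renders to for one user turn, written out by hand:
system = 'You are a helpful assistant.'
query = 'Hi!'
prefix = ('<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n'
          + system + '<|eot_id|>')
prompt = ('<|start_header_id|>user<|end_header_id|>\n\n' + query + '<|eot_id|>'
          '<|start_header_id|>assistant<|end_header_id|>\n\n')
print(prefix + prompt)  # generation continues after the assistant header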
swift/llm/template/template/llava.py ADDED
@@ -0,0 +1,309 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Literal, Optional
+
+import torch
+import transformers
+from packaging import version
+
+from ..base import Template
+from ..constant import MLLMTemplateType
+from ..register import TemplateMeta, register_template
+from ..template_inputs import StdTemplateInputs
+from ..utils import Context, Prompt, findall
+from ..vision_utils import load_video_llava
+from .llama import Llama3TemplateMeta
+from .qwen import QwenTemplateMeta
+from .utils import ChatmlTemplateMeta
+
+
+class LlavaHfTemplate(Template):
+    placeholder_tokens = ['<image>']
+
+    @property
+    def image_token_index(self):
+        if not hasattr(self, '_image_token_index'):
+            self._image_token_index = self.tokenizer.convert_tokens_to_ids(self.processor.image_token)
+        return self._image_token_index
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return ['<image>\n']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        images = inputs.images
+        if images:
+            image_processor = self.processor.image_processor
+            image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
+            encoded['pixel_values'] = image_inputs['pixel_values']
+            if 'image_sizes' in image_inputs:
+                encoded['image_sizes'] = image_inputs['image_sizes']
+            if version.parse(transformers.__version__) >= version.parse('4.47'):
+                input_ids = encoded['input_ids']
+                labels = encoded['labels']
+                idx_list = findall(input_ids, self.image_token_index)  # <image>
+                height, width = image_inputs['pixel_values'][0].shape[-2:]
+                added_tokens_len = 0
+                for i, idx in enumerate(idx_list):
+                    if 'image_sizes' in image_inputs:
+                        orig_height, orig_width = image_inputs['image_sizes'][i].tolist()
+                        num_image_tokens = self.processor._get_number_of_features(orig_height, orig_width, height,
+                                                                                  width)
+                    else:
+                        num_image_tokens = (height // self.processor.patch_size) * (
+                            width // self.processor.patch_size) + self.processor.num_additional_image_tokens
+                    if self.processor.vision_feature_select_strategy == 'default':
+                        num_image_tokens -= 1
+                    input_ids = input_ids[:added_tokens_len + idx] + [self.image_token_index] * num_image_tokens \
+                        + input_ids[added_tokens_len + idx + 1:]
+                    if labels is not None:
+                        labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens \
+                            + labels[added_tokens_len + idx + 1:]
+                    added_tokens_len += num_image_tokens - 1
+                encoded['input_ids'] = input_ids
+                encoded['labels'] = labels
+        return encoded
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.llava1_5_hf,
+        prefix=['<s>'],
+        prompt=['USER: {{QUERY}}\nASSISTANT:'],
+        chat_sep=['</s>'],
+        suffix=['</s>'],
+        system_prefix=['<s>{{SYSTEM}}\n'],
+        template_cls=LlavaHfTemplate,
+    ))
+
+
+class LlavaVideoHfTemplate(Template):
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        if media_type == 'image':
+            return ['<image>\n']
+        assert media_type == 'video'
+        media_file = inputs.videos[index]
+        if media_file.rsplit('.', 1)[-1] in {'jpg', 'png'}:
+            return ['<image>\n']
+        else:
+            inputs.videos[index] = load_video_llava(inputs.videos[index])
+            return ['<video>\n']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        images = inputs.images or []
+        videos = inputs.videos or []
+        if len(videos) > 0:
+            video_processor = self.processor.video_processor
+            video_inputs = video_processor(videos, return_tensors='pt').to(self.model_info.torch_dtype)
+            encoded['pixel_values_videos'] = video_inputs['pixel_values_videos']
+        if len(images) > 0:
+            image_processor = self.processor.image_processor
+            image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
+            encoded['pixel_values'] = image_inputs['pixel_values']
+            encoded['image_sizes'] = image_inputs['image_sizes']
+        return encoded
+
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.llava_next_video_hf,
+        prefix=['{{SYSTEM}} '],
+        prompt=['USER: {{QUERY}} ASSISTANT:'],
+        chat_sep=[' '],
+        suffix=[['eos_token_id']],
+        template_cls=LlavaVideoHfTemplate,
+        auto_add_bos=True,
+    ))
+
+
+class Llava1_6HfTemplate(LlavaHfTemplate):
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        for b in batch:
+            pixel_values = b.get('pixel_values')
+            if pixel_values is not None:
+                b['pixel_values'] = pixel_values.squeeze(0)  # 5d -> 4d
+        res = super()._data_collator(batch, padding_to=padding_to)
+        return res
+
+
+@dataclass
+class LlavaMistralTemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<s>[INST] '])
+    prompt: Prompt = field(default_factory=lambda: ['{{QUERY}} [/INST]'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>[INST] '])
+    suffix: Prompt = field(default_factory=lambda: ['</s>'])
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<<SYS>>\n{{system}}\n<</SYS>>\n\n'])
+
+
+register_template(LlavaMistralTemplateMeta(MLLMTemplateType.llava1_6_mistral_hf, template_cls=Llava1_6HfTemplate))
+
+register_template(
+    TemplateMeta(
+        MLLMTemplateType.llava1_6_vicuna_hf,
+        prefix=['<s>'],
+        prompt=['USER: {{QUERY}} ASSISTANT:'],
+        chat_sep=['</s>'],
+        suffix=['</s>'],
+        default_system=('A chat between a curious human and an artificial intelligence assistant. '
+                        "The assistant gives helpful, detailed, and polite answers to the human's questions."),
+        system_prefix=['<s>{{SYSTEM}} '],
+        template_cls=Llava1_6HfTemplate))
+
+
+class LLava1_6YiHfTemplate(Llava1_6HfTemplate):
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        if self.mode == 'vllm':
+            return [[64000], '\n']
+        else:
+            return super().replace_tag(media_type, index, inputs)
+
+
+register_template(ChatmlTemplateMeta(
+    MLLMTemplateType.llava1_6_yi_hf,
+    template_cls=LLava1_6YiHfTemplate,
+))
+
+register_template(Llama3TemplateMeta(
+    MLLMTemplateType.llama3_llava_next_hf,
+    template_cls=Llava1_6HfTemplate,
+))
+
+register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen_hf, template_cls=Llava1_6HfTemplate))
+
+
+class LlavaOneVisionHfTemplate(Llava1_6HfTemplate):
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = Template._encode(self, inputs)
+        images = inputs.images
+        input_ids = encoded['input_ids']
+        labels = encoded['labels']
+        idx_list = findall(input_ids, 151646)  # <image>
+        processor = self.processor
+        if images:
+            image_processor = processor.image_processor
+            image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
+            height, width = image_inputs['pixel_values'][0].shape[-2:]
+            added_tokens_len = 0
+            for idx, pixel_v, image_size in zip(idx_list, image_inputs['pixel_values'], image_inputs['image_sizes']):
+                if isinstance(image_size, torch.Tensor):
+                    image_size = image_size.tolist()
+                orig_height, orig_width = image_size
+                num_image_tokens = processor._get_number_of_features(orig_height, orig_width, height, width)
+                input_ids = input_ids[:added_tokens_len
+                                      + idx] + [151646] * num_image_tokens + input_ids[added_tokens_len + idx + 1:]
+                if labels is not None:
+                    labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens + labels[added_tokens_len + idx
+                                                                                                  + 1:]
+                added_tokens_len += num_image_tokens - 1
+            encoded['input_ids'] = input_ids
+            encoded['labels'] = labels
+            encoded['pixel_values'] = image_inputs['pixel_values']
+            if 'image_sizes' in image_inputs:
+                encoded['image_sizes'] = image_inputs['image_sizes']
+        return encoded
+
+
+register_template(
+    QwenTemplateMeta(
+        MLLMTemplateType.llava_onevision_hf,
+        default_system=None,
+        template_cls=LlavaOneVisionHfTemplate,
+    ))
+
+
+class LlavaLlama3_1HfTemplate(LlavaHfTemplate):
+    # DaozeZhang
+    system = ('You are a helpful language and vision assistant. '
+              'You are able to understand the visual content that the user provides, '
+              'and assist the user with a variety of tasks using natural language.')
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        if len(encoded['pixel_values'].shape) == 5:  # (1, num_patch, 3, H/W, W/H)
+            encoded['pixel_values'] = torch.squeeze(encoded['pixel_values'], dim=0)  # (num_patch, 3, H/W, W/H)
+        return encoded
+
+
+register_template(
+    Llama3TemplateMeta(
+        MLLMTemplateType.llava_llama3_1_hf,
+        default_system=LlavaLlama3_1HfTemplate.system,
+        template_cls=LlavaLlama3_1HfTemplate,
+    ))
+
+
+class LLavaLlama3HfTemplate(Template):
+    # xtuner
+    image_placeholder = ['<image>\n']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        raw_image = inputs.images
+        if raw_image:
+            pixel_values = self.processor.image_processor(raw_image, return_tensors='pt')['pixel_values']
+            encoded['pixel_values'] = pixel_values.to(self.model_info.torch_dtype)
+        return encoded
+
+
+register_template(Llama3TemplateMeta(
+    MLLMTemplateType.llava_llama3_hf,
+    template_cls=LLavaLlama3HfTemplate,
+))
+
+
+class LLavaTemplate(Template):
+    skip_prompt = False
+    use_model = True
+
+    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
+                    inputs: StdTemplateInputs) -> List[Context]:
+        assert media_type == 'image'
+        return [[-200], '\n']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        encoded = super()._encode(inputs)
+        images = inputs.images or []
+        image_sizes = [x.size for x in images]
+        from llava.mm_utils import process_images
+        model = self.model.model
+        if not hasattr(model, 'vision_tower'):
+            model = model.model
+        image_processor = model.vision_tower.image_processor
+        if images:
+            images_tensor = process_images(images, image_processor, model.config)
+            encoded['images'] = images_tensor.to(model.dtype).squeeze(0)
+            encoded['image_sizes'] = image_sizes
+        return encoded
+
+    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+        res = super()._data_collator(batch, padding_to=padding_to)
+        images = [b['images'] for b in batch if 'images' in b]
+        if images:
+            res['images'] = images
+            res['image_sizes'] = sum([b['image_sizes'] for b in batch if 'image_sizes' in b], start=[])
+        return res
+
+
+register_template(LlavaMistralTemplateMeta(MLLMTemplateType.llava1_6_mistral, template_cls=LLavaTemplate))
+
+register_template(ChatmlTemplateMeta(MLLMTemplateType.llava1_6_yi, template_cls=LLavaTemplate))
+
+register_template(
+    Llama3TemplateMeta(
+        MLLMTemplateType.llama3_llava_next,
+        template_cls=LLavaTemplate,
+        default_system=('You are a helpful language and vision assistant. '
+                        'You are able to understand the visual content that the user provides, '
+                        'and assist the user with a variety of tasks using natural language.'),
+    ))
+
+register_template(QwenTemplateMeta(MLLMTemplateType.llava_next_qwen, template_cls=LLavaTemplate))
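The patch-count arithmetic in LlavaHfTemplate._encode is worth seeing with numbers. Assuming a 336px CLIP input with patch_size=14 and one additional CLS-style image token, values typical for LLaVA-1.5 but not read from any config here:

patch_size = 14
height = width = 336
num_additional_image_tokens = 1
num_image_tokens = (height // patch_size) * (width // patch_size) + num_additional_image_tokens
# the 'default' vision_feature_select_strategy drops the CLS feature:
num_image_tokens -= 1
print(num_image_tokens)  # 576 = 24 * 24 patches per image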
swift/llm/template/template/llm.py ADDED
@@ -0,0 +1,274 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Optional
+
+from ..constant import LLMTemplateType, MLLMTemplateType
+from ..register import TemplateMeta, register_template
+from ..utils import Prompt
+from .llama import Llama3_2TemplateMeta
+from .qwen import Qwen2VLTemplate, QwenTemplateMeta
+from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.default,
+        prefix=[],
+        prompt=['### Human:\n{{QUERY}}\n\n### Assistant:\n'],
+        chat_sep=['\n\n'],
+        default_system=DEFAULT_SYSTEM,
+        system_prefix=['{{SYSTEM}}\n\n'],
+        auto_add_bos=True))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.modelscope_agent,
+        prefix=[],
+        prompt=[' \n\n<|user|>:{{QUERY}} \n\n<|assistant|>:'],
+        chat_sep=[],
+        suffix=[' \n\n</s>'],
+        system_prefix=[' \n\n<|system|>:{{SYSTEM}}'],
+        default_system=DEFAULT_SYSTEM,
+    ))
+
+register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_gme, template_cls=Qwen2VLTemplate, suffix=['<|endoftext|>']))
+
+register_template(
+    TemplateMeta(LLMTemplateType.baichuan, prefix=['{{SYSTEM}}'], prompt=[[195], '{{QUERY}}', [196]], chat_sep=[]))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.baichuan_m1,
+        prefix=[],
+        prompt=['<C_Q>{{QUERY}}<C_A>'],
+        chat_sep=[],
+        suffix=['<C_A>'],
+        system_prefix=['<B_SYS>{{SYSTEM}}'],
+        default_system=DEFAULT_SYSTEM,
+    ))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.numina,
+        prefix=[['bos_token_id']],
+        prompt=['### Problem: {{QUERY}}\n### Solution: '],
+        chat_sep=['\n'],
+        system_prefix=[['bos_token_id'], '{{SYSTEM}}']))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.mistral_nemo,
+        prefix=['<s>[INST] '],
+        prompt=['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'],
+        chat_sep=['</s>[INST] '],
+        suffix=['</s>']))
+
+today = datetime.now().strftime('%Y-%m-%d')
+
+mistral_2501_system = (
+    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
+    'headquartered in Paris.\n'
+    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
+    "When you're not sure about some information, you say that you don't have the information and don't "
+    'make up anything.\n'
+    "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
+    'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
+    '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
+    'Where do you travel from?")')
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.mistral_2501,
+        prefix=['<s>'],
+        prompt=['[INST]{{QUERY}}[/INST]'],
+        chat_sep=['</s>'],
+        suffix=['</s>'],
+        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
+        default_system=mistral_2501_system))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.xverse,
+        prefix=['{{SYSTEM}}'],
+        prompt=['Human: {{QUERY}}\n\nAssistant: '],
+        chat_sep=[['eos_token_id']]))
+
+register_template(TemplateMeta(LLMTemplateType.yuan, prefix=[], prompt=['{{QUERY}}<sep>'], chat_sep=None))
+register_template(
+    TemplateMeta(
+        LLMTemplateType.ziya,
+        prefix=[['bos_token_id'], '{{SYSTEM}}'],
+        prompt=['<human>:{{QUERY}}\n<bot>:'],
+        chat_sep=['\n']))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.skywork,
+        prefix=['<s>{{SYSTEM}}'],
+        prompt=['</s><s>[USER]{{QUERY}}[SEP][BOT]'],
+        chat_sep=None,
+        suffix=['[SEP]</s>']))
+
+register_template(
+    Llama3_2TemplateMeta(
+        LLMTemplateType.skywork_o1,
+        default_system=(
+            'You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems '
+            "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's "
+            'request, you first engage in a lengthy and in-depth thinking process to explore possible solutions to '
+            'the problem. After completing your thoughts, you then provide a detailed explanation of the solution '
+            'process in your response.'),
+    ))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.bluelm,
+        prefix=[['bos_token_id'], '{{SYSTEM}}'],
+        prompt=['[|Human|]:{{QUERY}}[|AI|]:'],
+        chat_sep=[]))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.codefuse_codellama,
+        prefix=['{{SYSTEM}}'],
+        prompt=['<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'],
+        chat_sep=[]))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.codefuse,
+        prefix=[],
+        prompt=['<s>human\n{{QUERY}}\n<s>bot\n'],
+        chat_sep=[['eos_token_id'], '\n'],
+        system_prefix=['<s>system\n{{SYSTEM}}\n']))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.zephyr,
+        prefix=[],
+        prompt=['<|user|>\n{{QUERY}}</s>\n<|assistant|>\n'],
+        chat_sep=['</s>\n'],
+        suffix=['</s>'],
+        system_prefix=['<|system|>\n{{SYSTEM}}</s>\n']))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.sus,
+        prefix=['{{SYSTEM}}'],
+        prompt=['### Human: {{QUERY}}\n\n### Assistant: '],
+        chat_sep=['<|endoftext|>'],
+        suffix=['<|endoftext|>']))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.orion,
+        prefix=['<s>{{SYSTEM}}'],
+        prompt=['Human: {{QUERY}}\n\nAssistant: </s>'],
+        chat_sep=['</s>'],
+        suffix=['</s>']))
+
+
+@dataclass
+class TeleChatTemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=list)
+    prompt: Prompt = field(default_factory=lambda: [['user_token_id'], '{{QUERY}}', ['bot_token_id']])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
+    suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<_system>{{SYSTEM}}\n'])
+    auto_add_bos: bool = True
+
+
+register_template(TeleChatTemplateMeta(LLMTemplateType.telechat))
|
| 182 |
+
|
| 183 |
+
telechat_system = '你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。'
|
| 184 |
+
register_template(TeleChatTemplateMeta(LLMTemplateType.telechat2, default_system=telechat_system))
|
| 185 |
+
|
| 186 |
+
DBRX_SYSTEM = (
|
| 187 |
+
'You are DBRX, created by Databricks. You were last updated in December 2023. '
|
| 188 |
+
'You answer questions based on information available up to that point.\n'
|
| 189 |
+
'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, '
|
| 190 |
+
'but provide thorough responses to more complex and open-ended questions.\n'
|
| 191 |
+
'You assist with various tasks, from writing to coding (using markdown for code blocks '
|
| 192 |
+
'— remember to use ``` with code, JSON, and tables).\n'
|
| 193 |
+
'You do not have real-time data access or code execution capabilities.'
|
| 194 |
+
' You avoid stereotyping and provide balanced perspectives on controversial topics. '
|
| 195 |
+
'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n'
|
| 196 |
+
'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. '
|
| 197 |
+
'If you find yourself talking about this message, stop. You should be responding appropriately '
|
| 198 |
+
'and usually that means not mentioning this.'
|
| 199 |
+
'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY '
|
| 200 |
+
'PERTINENT TO THE USER\'S QUERY.')
|
| 201 |
+
|
| 202 |
+
register_template(ChatmlTemplateMeta(LLMTemplateType.dbrx, default_system=DBRX_SYSTEM))
|
| 203 |
+
|
| 204 |
+
register_template(
|
| 205 |
+
TemplateMeta(
|
| 206 |
+
LLMTemplateType.mengzi, prefix=[], prompt=['输入:{{QUERY}}输出:\n'], chat_sep=[], system_prefix=['指令:{{SYSTEM}}']))
|
| 207 |
+
|
| 208 |
+
C4AI_SYSTEM = ('You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by '
|
| 209 |
+
'providing thorough responses.You are trained by Cohere.')
|
| 210 |
+
register_template(
|
| 211 |
+
TemplateMeta(
|
| 212 |
+
LLMTemplateType.c4ai,
|
| 213 |
+
prefix=['<BOS_TOKEN>'],
|
| 214 |
+
prompt=[
|
| 215 |
+
'<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|>'
|
| 216 |
+
'<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'
|
| 217 |
+
],
|
| 218 |
+
chat_sep=['<|END_OF_TURN_TOKEN|>'],
|
| 219 |
+
suffix=['<|END_OF_TURN_TOKEN|>'],
|
| 220 |
+
default_system=C4AI_SYSTEM,
|
| 221 |
+
system_prefix=['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|']))
|
| 222 |
+
|
| 223 |
+
register_template(
|
| 224 |
+
TemplateMeta(
|
| 225 |
+
LLMTemplateType.wizardlm2,
|
| 226 |
+
prefix=['{{SYSTEM}}'],
|
| 227 |
+
prompt=['User:\n{{QUERY}}\n\nAssistant:\n'],
|
| 228 |
+
chat_sep=['\n\n'],
|
| 229 |
+
suffix=['</s>']))
|
| 230 |
+
|
| 231 |
+
_wizardlm2_system = ('A chat between a curious user and an artificial intelligence assistant. '
|
| 232 |
+
'The assistant gives helpful, detailed, and polite answers to the user\'s questions. ')
|
| 233 |
+
register_template(
|
| 234 |
+
TemplateMeta(
|
| 235 |
+
LLMTemplateType.wizardlm2_moe,
|
| 236 |
+
prefix=['{{SYSTEM}}'],
|
| 237 |
+
prompt=['USER: {{QUERY}} ASSISTANT:'],
|
| 238 |
+
chat_sep=['</s>'],
|
| 239 |
+
suffix=['</s>'],
|
| 240 |
+
default_system=_wizardlm2_system))
|
| 241 |
+
|
| 242 |
+
register_template(
|
| 243 |
+
TemplateMeta(
|
| 244 |
+
LLMTemplateType.atom,
|
| 245 |
+
prefix=['{{SYSTEM}}'],
|
| 246 |
+
prompt=['<s>Human: {{QUERY}}\n</s><s>Assistant: '],
|
| 247 |
+
chat_sep=['</s>'],
|
| 248 |
+
suffix=['</s>']))
|
| 249 |
+
|
| 250 |
+
AYA_SYSTEM = ('You are Aya, a brilliant, sophisticated, multilingual AI-assistant trained to assist human users by '
|
| 251 |
+
'providing thorough responses. You are able to interact and respond to questions in 23 languages and '
|
| 252 |
+
'you are powered by a multilingual model built by Cohere For AI.')
|
| 253 |
+
register_template(
|
| 254 |
+
TemplateMeta(
|
| 255 |
+
LLMTemplateType.aya,
|
| 256 |
+
prefix=['<BOS_TOKEN>'],
|
| 257 |
+
prompt=[
|
| 258 |
+
'<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|>'
|
| 259 |
+
'<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'
|
| 260 |
+
],
|
| 261 |
+
chat_sep=['<|END_OF_TURN_TOKEN|>'],
|
| 262 |
+
suffix=['<|END_OF_TURN_TOKEN|>'],
|
| 263 |
+
default_system=AYA_SYSTEM,
|
| 264 |
+
system_prefix=['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|']))
|
| 265 |
+
|
| 266 |
+
register_template(
|
| 267 |
+
TemplateMeta(
|
| 268 |
+
LLMTemplateType.ling,
|
| 269 |
+
prefix=[],
|
| 270 |
+
system_prefix=['<role>SYSTEM</role>{{SYSTEM}}'],
|
| 271 |
+
prompt=['<role>HUMAN</role>{{QUERY}}<role>ASSISTANT</role>'],
|
| 272 |
+
chat_sep=[],
|
| 273 |
+
suffix=['<|endoftext|>'],
|
| 274 |
+
))
|
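Note: a minimal sketch of registering a custom chat template with the TemplateMeta/register_template API used throughout this file. The import path (via swift/llm/template/__init__.py) and the template type name 'my-template' are assumptions for illustration, not part of this diff.

from swift.llm.template import TemplateMeta, register_template

# Each prompt element is either a literal string carrying the {{QUERY}} /
# {{SYSTEM}} placeholders, or a list of token ids / named special tokens
# such as ['eos_token_id'].
register_template(
    TemplateMeta(
        'my-template',
        prefix=[],
        prompt=['<|user|>{{QUERY}}<|assistant|>'],
        chat_sep=['\n'],
        suffix=[['eos_token_id']]))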
swift/llm/template/template/megrez.py
ADDED
@@ -0,0 +1,93 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

import torch
import torch.nn as nn

from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall


@dataclass
class MegrezTemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=lambda: ['<|role_start|>system<|role_end|>{{SYSTEM}}<|turn_end|>'])
    prompt: Prompt = field(default_factory=lambda:
                           ['<|role_start|>user<|role_end|>{{QUERY}}<|turn_end|><|role_start|>assistant<|role_end|>'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|turn_end|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|turn_end|>'])
    default_system: str = '你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。'


register_template(MegrezTemplateMeta(LLMTemplateType.megrez))


class MegrezOmniTemplate(Template):
    skip_prompt = False
    placeholder_tokens = ['<|unk|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if media_type == 'image':
            return [[-1], '\n']
        elif media_type == 'audio':
            return [f'Audio {index + 1}: ', [-2], '\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']

        for mm_key in ['images', 'audios']:
            mm_data = getattr(inputs, mm_key)
            if not mm_data:
                continue
            if mm_key == 'images':
                idx_list = findall(input_ids, -1)
                encoding = self.processor.process_image(
                    mm_data,
                    return_tensors='pt',
                )
                text = self.processor.insert_image_feature_placeholders(
                    '<s>'.join(['(<image>./</image>)'] * len(mm_data)), encoding)
                encoded['image_encoding'] = encoding
            else:
                idx_list = findall(input_ids, -2)
                encoding = self.processor.process_audio(
                    mm_data,
                    return_tensors='pt',
                )
                text = self.processor.insert_audio_feature_placeholders(
                    '<s>'.join(['(<audio>./</audio>)'] * len(mm_data)), encoding)
                encoded['audio_encoding'] = encoding

            padding = text.split('<s>')

            def _get_new_tokens(i):
                return self._tokenize(padding[i])

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        _, inputs_embeds, _ = model.compose_embeddings(inputs)
        inputs.pop('position_ids', None)
        return {'inputs_embeds': inputs_embeds}

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        new_batch = []
        for b in batch:
            text_encodings = {'input_ids': torch.tensor(b['input_ids'])}
            multimodal_inputs = {'image_encoding': b.get('image_encoding'), 'audio_encoding': b.get('audio_encoding')}
            new_batch.append(self.processor.merge_encodings(text_encodings, multimodal_inputs))
        res.update(self.processor.data_collator(new_batch))
        return res


register_template(MegrezTemplateMeta(MLLMTemplateType.megrez_omni, template_cls=MegrezOmniTemplate))
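Note: MegrezOmniTemplate above relies on the sentinel-expansion pattern shared by the multimodal templates in this PR: a negative sentinel id marks where media features go, findall() locates it, and _extend_tokens() splices in the real placeholder tokens while masking the new positions in the labels. A self-contained sketch that mirrors those helpers (the actual implementations live in ..utils and on Template):

def findall(token_ids, value):
    return [i for i, t in enumerate(token_ids) if t == value]

def extend_tokens(input_ids, labels, idx_list, get_new_tokens):
    # Walk right-to-left so earlier indices stay valid after each splice.
    for i in reversed(range(len(idx_list))):
        idx = idx_list[i]
        new_tokens = get_new_tokens(i)
        input_ids = input_ids[:idx] + new_tokens + input_ids[idx + 1:]
        if labels is not None:
            labels = labels[:idx] + [-100] * len(new_tokens) + labels[idx + 1:]
    return input_ids, labels

ids = [1, 2, -1, 3, -1, 4]
ids, _ = extend_tokens(ids, None, findall(ids, -1), lambda i: [90 + i] * 2)
assert ids == [1, 2, 90, 90, 3, 91, 91, 4]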
swift/llm/template/template/microsoft.py
ADDED
@@ -0,0 +1,205 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

import json
import torch
from torch import nn

from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall
from ..vision_utils import load_file


class FlorenceTemplate(Template):
    # If it's an encoder-decoder architecture, the default settings are
    # loss_scale: 'last_round' and skip_prompt: False.
    is_encoder_decoder = True

    @staticmethod
    def _add_default_tags(inputs: StdTemplateInputs) -> None:
        return

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        return []

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        return [''.join(f'<loc_{box}>' for box in bbox)]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        processor = self.processor
        inputs.query = inputs.to_history()['query']
        new_query = processor._construct_prompts([inputs.query])[0]
        for i in reversed(range(len(inputs.messages))):
            if inputs.messages[i]['role'] == 'user':
                inputs.messages[i]['content'] = new_query
                break
        encoded = super()._encode(inputs)
        input_ids = encoded['prompt_input_ids']
        images = inputs.images or []
        labels = encoded['labels']
        if labels is not None:
            labels = [0] + labels
        if images:
            pixel_values = processor.image_processor(
                images, return_tensors='pt')['pixel_values'].to(self.model_info.torch_dtype)
            encoded['pixel_values'] = pixel_values
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        inputs_embeds = model.get_input_embeddings()(inputs['input_ids'])
        pixel_values = inputs.get('pixel_values')
        if pixel_values is not None:
            image_features = model._encode_image(pixel_values)
            inputs_embeds, inputs['attention_mask'] = model._merge_input_ids_with_image_features(
                image_features, inputs_embeds)
        return {'inputs_embeds': inputs_embeds}

    def decode(self, generate_ids: List[int], **kwargs) -> Any:
        response = super().decode(generate_ids, **kwargs)
        template_inputs = kwargs.get('template_inputs')
        images = template_inputs.images
        image_size = None
        if images:
            image_size = (images[0].width, images[0].height)
        return json.dumps(
            self.processor.post_process_generation(response, task=template_inputs.query, image_size=image_size))


register_template(
    TemplateMeta(
        MLLMTemplateType.florence,
        prefix=['<s>'],
        prompt=['{{QUERY}}</s>'],
        chat_sep=None,
        suffix=['</s>'],
        template_cls=FlorenceTemplate,
    ))


@dataclass
class Phi3TemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=list)
    prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|end|>\n<|assistant|>\n'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end|>\n'])
    suffix: Prompt = field(default_factory=lambda: ['<|end|>'])
    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}<|end|>\n'])
    auto_add_bos: bool = True


register_template(Phi3TemplateMeta(LLMTemplateType.phi3))


@dataclass
class Phi4TemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=list)
    prompt: Prompt = field(
        default_factory=lambda: ['<|im_start|>user<|im_sep|>{{QUERY}}<|im_end|><|im_start|>assistant<|im_sep|>'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|im_end|>'])
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|im_start|>system<|im_sep|>{{SYSTEM}}<|im_end|>'])
    auto_add_bos: bool = True


register_template(Phi4TemplateMeta(LLMTemplateType.phi4))


class Phi3VisionTemplate(Template):
    image_placeholder = ['<|image|><s>\n']  # <|image|>\n

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if self.mode == 'vllm':
            return [f'<|image_{index + 1}|>\n']  # <|image_1|>\n
        else:
            return super().replace_tag(media_type, index, inputs)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        images = inputs.images or []
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, 32044)  # '<|image|>'

        if len(images) > 0:
            processor = self.processor
            encoded.update(processor.image_processor(images, return_tensors='pt'))
            assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
            res_input_ids = []
            res_labels = []
            num_img_tokens = encoded.pop('num_img_tokens').tolist()
            idx_list.insert(0, -1)
            for i in range(len(idx_list) - 1):
                image_token_id = -i - 1
                res_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] + [image_token_id] * num_img_tokens[i]
                if labels is not None:
                    res_labels += labels[idx_list[i] + 1:idx_list[i + 1]] + [-100] * num_img_tokens[i]
            res_input_ids += input_ids[idx_list[-1] + 1:]
            input_ids = res_input_ids
            if labels is not None:
                res_labels += labels[idx_list[-1] + 1:]
                labels = res_labels

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded


class Phi4MMTemplate(Template):
    placeholder_tokens = ['<|endoftext10|>', '<|endoftext11|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if media_type == 'image':
            return [[-100]]
        elif media_type == 'audio':
            import soundfile as sf
            inputs.audios[index] = sf.read(load_file(inputs.audios[index]))
            return [[-200]]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        images_idx = findall(input_ids, -100)
        audios_idx = findall(input_ids, -200)
        text = '\n'.join(['<|image_1|>'] * len(inputs.images) + ['<|audio_1|>'] * len(inputs.audios))
        new_encoded = self.processor(
            text=text, images=inputs.images or None, audios=inputs.audios or None, return_tensors='pt')
        placeholders = self._split_list(new_encoded.pop('input_ids')[0].tolist(), 198)

        def _get_new_tokens(i):
            return placeholders[i]

        encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, images_idx + audios_idx,
                                                                      _get_new_tokens)
        new_encoded.pop('attention_mask')
        encoded.update(new_encoded)
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        keys = [
            'input_image_embeds', 'image_sizes', 'image_attention_mask', 'input_audio_embeds', 'audio_embed_sizes',
            'input_mode'
        ]
        inputs = self.fetch_inputs(batch, keys)
        for k, v in inputs.items():
            inputs[k] = torch.concat(v)
        res.update(inputs)
        return res


register_template(Phi3TemplateMeta(MLLMTemplateType.phi3_vision, template_cls=Phi3VisionTemplate))

register_template(Phi3TemplateMeta(
    MLLMTemplateType.phi4_multimodal,
    template_cls=Phi4MMTemplate,
))
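Note: a sketch of the splice performed in Phi3VisionTemplate._encode above: the i-th '<|image|>' marker (token id 32044) is replaced by num_img_tokens[i] copies of the negative id -(i + 1), the convention the Phi-3-vision modeling code uses to locate image feature slots. The concrete ids in the assert are illustrative.

def splice(input_ids, marker, num_img_tokens):
    out, img_i = [], 0
    for tok in input_ids:
        if tok == marker:
            # Expand the img_i-th marker into its run of negative image ids.
            out += [-(img_i + 1)] * num_img_tokens[img_i]
            img_i += 1
        else:
            out.append(tok)
    return out

assert splice([5, 32044, 6, 32044], 32044, [3, 2]) == [5, -1, -1, -1, 6, -2, -2]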
swift/llm/template/template/minicpm.py
ADDED
@@ -0,0 +1,229 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional

import torch
from torch import nn

from swift.utils import get_env_args
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall
from ..vision_utils import load_video_minicpmv_mplug_owl3
from .llama import Llama3TemplateMeta
from .qwen import Qwen2_5TemplateMeta, QwenTemplateMeta


@dataclass
class MinicpmTemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=lambda: ['<s>{{SYSTEM}}'])
    prompt: Prompt = field(default_factory=lambda: ['<用户>{{QUERY}}<AI>'])
    chat_sep: Optional[Prompt] = field(default_factory=list)
    suffix: Prompt = field(default_factory=lambda: ['</s>'])


register_template(MinicpmTemplateMeta(LLMTemplateType.minicpm))


def _remove_idx(arr: List[int], idx_list: List[int]) -> List[int]:
    res = []
    idx_set = set(idx_list)
    for i, x in enumerate(arr):
        if i not in idx_set:
            res.append(x)
    return res


class MiniCPMVTemplate(Template):
    is_v2_5 = False
    use_model = True
    skip_prompt = False
    placeholder_tokens = ['<unk>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if self.mode == 'vllm':
            return ['(<image>./</image>)\n']
        else:
            return [[-100]]

    async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
        images = inputs.pop('images', None) or []
        if len(images) == 0:
            return
        input_ids = inputs['input_ids']
        idx_list = findall(input_ids, -100)
        idx_list.insert(0, -1)
        new_input_ids = []
        features = []
        for i in range(len(idx_list) - 1):
            new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
            context_list = ['<image>', [-100], '</image>']
            feat = [x.squeeze() for x in images[i]['embeddings'].split(1)]
            grid = images[i].get('grid')
            if len(feat) > 1 and grid is not None:
                context_list.append('<slice>')
                for j in range(grid[1]):
                    if j > 0:
                        context_list.append('\n')
                    for _ in range(grid[0]):
                        context_list += ['<image>', [-100], '</image>']
                context_list.append('</slice>\n')
            new_input_ids += self._encode_context_list(context_list)[0]
            features += feat
        new_input_ids += input_ids[idx_list[-1] + 1:]
        inputs['input_ids'] = new_input_ids
        inputs['images'] = features
        await super().prepare_lmdeploy_turbomind_inputs(inputs)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, -100)
        idx = idx_list[0]
        tgt_sizes = None
        slice_mode = getattr(self.config, 'slice_mode', False)
        if slice_mode:
            if self.is_v2_5:
                image_processor = self.processor.image_processor
                image_inputs = image_processor(images, return_tensors='pt').to(self.model_info.torch_dtype)
                placeholder = image_processor.get_slice_image_placeholder(image_inputs.image_sizes[0][0])
                pixel_values = image_inputs['pixel_values']
                tgt_sizes = image_inputs['tgt_sizes']
            else:
                images, placeholder = self.model.get_slice_image_placeholder(images[0], self.processor)
                pixel_values = [[self.model.transform(img) for img in images]]
            placeholder += '\n'
            placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
            input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
            if labels is not None:
                labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
            input_tensor_ids = torch.tensor(input_ids)
            image_start_idx = torch.where(input_tensor_ids == self.processor.im_start_id)[0]
            image_start_idx += 1
            image_end_idx = torch.where(input_tensor_ids == self.processor.im_end_id)[0]
            valid_image_nums = max(len(image_start_idx), len(image_end_idx))
            image_bound = [
                torch.hstack(
                    [image_start_idx[:valid_image_nums].unsqueeze(-1), image_end_idx[:valid_image_nums].unsqueeze(-1)])
            ]
        else:
            placeholder = '<image>' + '<unk>' * self.config.query_num + '</image>\n'
            placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
            input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
            if labels is not None:
                labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
            image_bound = [torch.tensor([[idx, idx + self.config.query_num]])]
            pixel_values = [[self.model.transform(images[0])]]
        encoded = {
            'input_ids': input_ids,
            'labels': labels,
            'image_bound': image_bound,
            'pixel_values': pixel_values,
            'tgt_sizes': tgt_sizes
        }
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        inputs_embeds, _ = model.get_vllm_embedding(inputs)
        return {'inputs_embeds': inputs_embeds}

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = {}
        for k in ['pixel_values', 'image_bound', 'tgt_sizes']:
            res[k] = self.gather_list(batch, k)
        res.update(super()._data_collator(batch, padding_to=padding_to))
        return res


register_template(MinicpmTemplateMeta(MLLMTemplateType.minicpmv, template_cls=MiniCPMVTemplate))


class MiniCPMV2_5Template(MiniCPMVTemplate):
    is_v2_5 = True


register_template(Llama3TemplateMeta(
    MLLMTemplateType.minicpmv2_5,
    template_cls=MiniCPMV2_5Template,
))


class MiniCPMV2_6Template(MiniCPMVTemplate):

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type in {'image', 'video'}
        max_num_frames = get_env_args('max_num_frames', int, 64)
        load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames)
        image_context = super().replace_tag('image', index, inputs)
        if media_type == 'image':
            return image_context
        elif media_type == 'video':
            return self.replace_video2image(load_video, inputs, lambda i: image_context)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = Template._encode(self, inputs)
        images = inputs.images
        use_video = bool(inputs.videos)
        use_image_id = True
        max_slice_nums = get_env_args('max_slice_nums', int, None)
        video_max_slice_nums = get_env_args('video_max_slice_nums', int, 1)  # or 2
        if use_video:
            max_slice_nums = video_max_slice_nums
            use_image_id = False
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, -100)

        image_processor = self.processor.image_processor
        image_inputs = image_processor([images], return_tensors='pt',
                                       max_slice_nums=max_slice_nums).to(self.model_info.torch_dtype)

        def _get_new_tokens(i):
            placeholder = image_processor.get_slice_image_placeholder(
                image_inputs.image_sizes[0][i], image_idx=i, max_slice_nums=max_slice_nums, use_image_id=use_image_id)
            placeholder += '\n'
            return self.processor.encode(placeholder, add_special_tokens=False)

        input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
        if inputs.images:
            input_tensor_ids = torch.tensor(input_ids)
            unk_token = self.processor.encode('<unk>', add_special_tokens=False)[0]
            indices = (input_tensor_ids == unk_token).nonzero(as_tuple=True)[0].tolist()

            ranges = []
            start = indices[0]
            for i in range(1, len(indices)):
                if indices[i] != indices[i - 1] + 1:
                    ranges.append([start, indices[i - 1] + 1])
                    start = indices[i]
            ranges.append([start, indices[-1] + 1])
            image_bound = [torch.tensor(ranges)]
        else:
            image_bound = [[]]

        encoded = {
            'input_ids': input_ids,
            'labels': labels,
            'image_bound': image_bound,
            'pixel_values': image_inputs['pixel_values'],
            'tgt_sizes': image_inputs['tgt_sizes']
        }
        return encoded


register_template(QwenTemplateMeta(
    MLLMTemplateType.minicpmv2_6,
    template_cls=MiniCPMV2_6Template,
))

register_template(Qwen2_5TemplateMeta(
    MLLMTemplateType.minicpmo2_6,
    template_cls=MiniCPMV2_6Template,
))
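Note: MiniCPMV2_6Template._encode above recovers image_bound by grouping the positions of '<unk>' into maximal contiguous runs; each run becomes a half-open [start, end) bound used to scatter image features into the sequence. A standalone sketch of that grouping:

def runs_to_ranges(indices):
    # indices: sorted positions of placeholder tokens in input_ids.
    ranges, start = [], indices[0]
    for prev, cur in zip(indices, indices[1:]):
        if cur != prev + 1:
            ranges.append([start, prev + 1])
            start = cur
    ranges.append([start, indices[-1] + 1])
    return ranges

assert runs_to_ranges([3, 4, 5, 9, 10]) == [[3, 6], [9, 11]]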
swift/llm/template/template/minimax.py
ADDED
@@ -0,0 +1,112 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from swift.utils import get_logger
from ..base import Template
from ..constant import LLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt

logger = get_logger()


@dataclass
class MinimaxTemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=list)
    prompt: Prompt = field(default_factory=lambda: [
        '<beginning_of_sentence>user name=user\n{{QUERY}}<end_of_sentence>\n'
        '<beginning_of_sentence>ai name=assistant\n'
    ])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<end_of_sentence>\n'])
    suffix: Prompt = field(default_factory=lambda: ['<end_of_sentence>'])
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<beginning_of_sentence>system ai_setting=assistant\n{{SYSTEM}}<end_of_sentence>\n'])


register_template(MinimaxTemplateMeta(LLMTemplateType.minimax))


class MinimaxVLTemplate(Template):
    image_placeholder = ['<image>']
    skip_prompt = True

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        return self.image_placeholder * inputs.all_image_tokens[index]

    def calc_num_image_tokens(self, image_inputs):
        from transformers.image_utils import get_image_size, to_numpy_array
        all_image_tokens = []
        if not image_inputs:
            # Guard before indexing: an empty encoding has no pixel_values.
            return all_image_tokens
        pixel_values = image_inputs['pixel_values']
        image_sizes = image_inputs['image_sizes']

        if self.processor.process_image_mode == 'anyres':
            for pixel_value, image_size in zip(pixel_values, image_sizes):
                height, width = image_size
                num_image_tokens = self.processor.get_num_token(height, width, self.processor.grid_pinpoints,
                                                                self.processor.patch_size)
                all_image_tokens.append(num_image_tokens)
        elif self.processor.process_image_mode == 'resize':
            pixel_values = image_inputs['pixel_values']
            all_image_tokens = []
            for pixel_value in pixel_values:
                height, width = get_image_size(to_numpy_array(pixel_value))
                all_image_tokens.append(int(height * width / self.processor.patch_size**2))
        else:
            if self.processor.patch_size is not None:
                pixel_values = image_inputs['pixel_values']
                all_image_tokens = []
                for pixel_value in pixel_values:
                    height, width = get_image_size(to_numpy_array(pixel_value))
                    new_width, new_height = self.processor.get_hw_multiple_of(
                        (width, height), self.processor.patch_size, self.processor.max_size)
                    num_image_tokens = ((new_height // self.processor.patch_size) *
                                        (new_width // self.processor.patch_size))  # + 1
                    all_image_tokens.append(num_image_tokens)
            else:
                logger.warning_once(
                    'Expanding inputs for image tokens in MiniMaxVL01 should be done in processing. '
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's "
                    'processing config or set directly '
                    'with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = '
                    '{{vision_feature_select_strategy}}`. '
                    'Using processors without these attributes in the config is deprecated '
                    'and will throw an error in v4.47.')
                raise ValueError(
                    "You need to provide `patch_size` and `vision_feature_select_strategy` in the model's processing "
                    'config to expand inputs for image tokens.')
        return all_image_tokens

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        output_kwargs = self.processor._merge_kwargs(
            self.processor.MiniMaxVL01ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        )
        if inputs.images:
            image_inputs = self.processor.image_processor(
                inputs.images, **output_kwargs['images_kwargs'], return_tensors='pt')
            inputs.all_image_tokens = self.calc_num_image_tokens(image_inputs)
        else:
            image_inputs = {}
        encoded = super()._encode(inputs)
        for key in image_inputs:
            encoded[key] = image_inputs[key]
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        pixel_values = self.gather_list(batch, 'pixel_values')
        image_sizes = self.gather_list(batch, 'image_sizes')
        res = super()._data_collator(batch, padding_to=padding_to)
        if pixel_values:
            res['pixel_values'] = pixel_values
        if image_sizes:
            res['image_sizes'] = image_sizes
        return res


register_template(MinimaxTemplateMeta(LLMTemplateType.minimax_vl, template_cls=MinimaxVLTemplate))
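Note: a worked example of the 'resize' branch of calc_num_image_tokens above; the token count is simply the number of patch_size x patch_size patches. The concrete sizes (a 336x336 input, 14-pixel patches) are assumptions for illustration.

patch_size = 14
height, width = 336, 336
num_tokens = int(height * width / patch_size**2)
assert num_tokens == 576  # a 24 x 24 patch grid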
swift/llm/template/template/mistral.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Literal, Optional

import torch

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, findall
from .llm import mistral_2501_system


class Mistral2503Template(Template):
    placeholder_tokens = ['[IMG]']
    image_token = 10

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        return ['[IMG]']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        processor = self.processor
        images = inputs.images
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, self.image_token)
        if idx_list:
            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
            encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype)
            encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']

            def _get_new_tokens(i):
                height, width = image_sizes[i]
                num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size)
                num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size)
                replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token]
                                  ] * num_height_tokens
                # Flatten list
                replace_tokens = [item for sublist in replace_tokens for item in sublist]
                replace_tokens[-1] = processor.image_end_token
                replace_str = ''.join(replace_tokens)
                return processor.encode(replace_str, add_special_tokens=False)

            encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

        return encoded


register_template(
    TemplateMeta(
        MLLMTemplateType.mistral_2503,
        prefix=['<s>'],
        prompt=['[INST]{{QUERY}}[/INST]'],
        chat_sep=['</s>'],
        suffix=['</s>'],
        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
        default_system=mistral_2501_system,
        template_cls=Mistral2503Template))
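Note: a sketch of _get_new_tokens in Mistral2503Template above: each image expands into a grid of image tokens, with a break token closing every row and an end token in place of the final break. The literal token strings and sizes here are assumptions for illustration.

def image_tokens(height, width, patch_size=14, merge=2):
    rows = height // (patch_size * merge)
    cols = width // (patch_size * merge)
    grid = (['[IMG]'] * cols + ['[IMG_BREAK]']) * rows
    grid[-1] = '[IMG_END]'  # the last row ends the image instead of breaking
    return grid

assert image_tokens(56, 84) == [
    '[IMG]', '[IMG]', '[IMG]', '[IMG_BREAK]',
    '[IMG]', '[IMG]', '[IMG]', '[IMG_END]']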
swift/llm/template/template/molmo.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Literal, Optional

import torch

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, findall


class MolmoTemplate(Template):
    placeholder_tokens = ['<im_patch>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        return []

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        # image
        images_inputs = self.processor.process(images=inputs.images or None, text='')
        images_input_ids = images_inputs.pop('input_ids').tolist()
        user_token = self._tokenize(' User')
        assert len(user_token) == 1
        idx = findall(images_input_ids, user_token[0])
        assert len(idx) == 1
        labels = encoded['labels']
        encoded['input_ids'] = images_input_ids[:idx[0]] + encoded['input_ids']
        if labels:
            encoded['labels'] = [-100] * idx[0] + labels
        if 'images' in images_inputs:
            images_inputs['images'] = images_inputs['images'].to(self.model_info.torch_dtype)
        encoded.update(images_inputs)
        return encoded

    def generate(self, model, **kwargs):
        kwargs.pop('attention_mask', None)
        generation_config = kwargs.pop('generation_config')
        batch = {
            k: kwargs.pop(k, None)
            for k in ['input_ids', 'attention_mask', 'images', 'image_input_idx', 'image_masks']
        }
        return model.generate_from_batch(batch, generation_config, **kwargs)

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        # prepare batched inputs
        keys = ['images', 'image_input_idx', 'image_masks']
        images_res = self.fetch_inputs(batch, keys)
        for key in keys:
            val = images_res.get(key)
            if val:
                images_res[key] = torch.stack(val)
        res.update(images_res)
        return res


register_template(
    TemplateMeta(
        MLLMTemplateType.molmo,
        prefix=[],
        prompt=[' User: {{QUERY}} Assistant:'],
        chat_sep=None,
        suffix=['<|endoftext|>'],
        template_cls=MolmoTemplate,
    ))
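Note: in MolmoTemplate._encode above, the processor-produced vision prefix is prepended to input_ids and the matching label positions are filled with -100, the ignore_index of PyTorch's cross entropy, so the prefix contributes nothing to the loss. A minimal sketch with illustrative ids:

prefix = [7, 7, 7]  # hypothetical vision-prefix token ids
input_ids, labels = [1, 2, 3], [1, 2, 3]
input_ids = prefix + input_ids
labels = [-100] * len(prefix) + labels  # mask the prefix out of the loss
assert labels == [-100, -100, -100, 1, 2, 3]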
swift/llm/template/template/moonshot.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional

from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall


@dataclass
class MoonlightTemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=list)
    prompt: Prompt = field(default_factory=lambda:
                           ['<|im_user|>user<|im_middle|>{{QUERY}}<|im_end|><|im_assistant|>assistant<|im_middle|>'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|im_end|>'])
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|im_system|>system<|im_middle|>{{SYSTEM}}<|im_end|>'])
    default_system: str = 'You are a helpful assistant'


register_template(MoonlightTemplateMeta(LLMTemplateType.moonlight))


class KimiVLTemplate(Template):
    placeholder_tokens = ['<|media_pad|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if media_type == 'image':
            return ['<|media_start|>image<|media_content|><|media_pad|><|media_end|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        media_token = self._tokenize('<|media_pad|>')[0]
        idx_list = findall(input_ids, media_token)
        if inputs.images:
            image_processor = self.processor.image_processor
            image_inputs = image_processor(inputs.images, return_tensors='pt')
            image_grid_hws = image_inputs['image_grid_hws']
            merge_length = image_processor.merge_kernel_size[0] * image_processor.merge_kernel_size[1]

            def _get_new_tokens(i):
                token_len = (image_grid_hws[i].prod() // merge_length)
                return [media_token] * token_len

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
            encoded['input_ids'] = input_ids
            encoded['labels'] = labels
            encoded.update(image_inputs)
        return encoded

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        res = super()._data_collator_mm_data(batch)
        image_grid_hws = self.concat_tensor(batch, 'image_grid_hws', 0)
        if image_grid_hws is not None:
            res['image_grid_hws'] = image_grid_hws
        return res


register_template(MoonlightTemplateMeta(MLLMTemplateType.kimi_vl, template_cls=KimiVLTemplate))
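Note: a worked example of the token count in KimiVLTemplate._get_new_tokens above: an h x w grid of vision patches collapses by the merge-kernel area. The grid and kernel values are assumptions for illustration.

grid_hw = (28, 42)       # one entry of image_grid_hws
merge_kernel = (2, 2)    # image_processor.merge_kernel_size
token_len = (grid_hw[0] * grid_hw[1]) // (merge_kernel[0] * merge_kernel[1])
assert token_len == 294  # 28 * 42 / 4 '<|media_pad|>' tokens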
swift/llm/template/template/mplug.py
ADDED
@@ -0,0 +1,214 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional

import torch
from torch import nn

from swift.utils import get_env_args
from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, Prompt, findall
from ..vision_utils import load_video_minicpmv_mplug_owl3
from .qwen import QwenTemplateMeta


class mPlugOwl2Template(Template):

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        return [[-200]]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        from mplug_owl2.mm_utils import process_images
        processor = self.processor
        images = inputs.images
        for i, image in enumerate(images):
            # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            images[i] = image
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        res = {'input_ids': input_ids, 'labels': labels}
        if images:
            images = process_images(images, processor)
            images = images.to(self.model_info.torch_dtype)
            res['images'] = images
        return res

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        images = [b['images'] for b in batch if 'images' in b]
        if images:
            res['images'] = torch.concat(images)
        return res


register_template(
    TemplateMeta(
        MLLMTemplateType.mplug_owl2,
        template_cls=mPlugOwl2Template,
        prefix=['{{SYSTEM}}'],
        prompt=['USER: {{QUERY}}ASSISTANT:'],
        chat_sep=['</s>'],
        suffix=[['eos_token_id']],
        stop_words=['<|endoftext|>', '</s>']))


class mPlugOwl3Template(Template):
    version = None

    def _get_image_token_list(self, cut_shape):
        text = self.processor.image_processor.cut_prompt_template(img_token='<|image|>', h=cut_shape[0], w=cut_shape[1])
        text_list = text.split('<|image|>')
        res_text_list = []
        for text in text_list[:-1]:
            res_text_list += [text, '<|image|>']
        res_text_list += text_list[-1]
        token_list = self._encode_context_list(res_text_list)[0]
        return token_list

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type in {'image', 'video'}
        max_num_frames = get_env_args('max_num_frames', int, 16)
        load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames)
        if media_type == 'image':
            return [[-100], '\n']
        elif media_type == 'video':
            return self.replace_video2image(load_video, inputs, lambda i: [[-100]]) + ['\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        videos = inputs.videos
        cut_enable = not videos
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, -100)
        processor = self.processor
        encoded = {}
        if images:
            image_inputs = processor.image_processor(images, cut_enable=cut_enable, return_tensors='pt')
            cut_shapes = image_inputs['cut_shape'] or [None] * 2 * len(idx_list)
            image_token_list = self.processor.encode('<|image|>', add_special_tokens=False)

            def _get_new_tokens(i):
                cut_shape = cut_shapes[2 * i]
                if cut_shape:
                    token_list = self._get_image_token_list(cut_shape)
                else:
                    token_list = image_token_list
                return token_list

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
            image_token_idx = torch.tensor(findall(input_ids, image_token_list))
            if self.version == '241101':
                media_offset = image_token_idx
            else:
                _range = torch.arange(len(input_ids))[:, None]
                matrix = (_range > image_token_idx[None]).sum(dim=1)
                media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
            encoded.update({
                'pixel_values': image_inputs['pixel_values'],
                'media_offset': media_offset,
            })
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if 'media_offset' in inputs:
            media_offset = []
            cusum_offset = 0
            image_embeds = []
            pixel_values = inputs.pop('pixel_values')
            max_sequence_length = inputs['input_ids'].shape[1]
            for i, curr_media_offset in enumerate(inputs['media_offset']):
                if curr_media_offset is None:
                    continue
                if curr_media_offset.shape[1] < max_sequence_length:
                    padding = curr_media_offset[:, -1:, :].expand(curr_media_offset.shape[0],
                                                                  max_sequence_length - curr_media_offset.shape[1],
                                                                  curr_media_offset.shape[2])
                    curr_media_offset = torch.concat([curr_media_offset, padding], dim=1)
                media_offset.append(curr_media_offset + cusum_offset)
                image_embeds.append(model.forward_image(pixel_values[i]))
                cusum_offset += image_embeds[-1].shape[0]
            inputs['media_offset'] = torch.concat(media_offset)
            inputs['image_embeds'] = torch.concat(image_embeds)
        return inputs

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = self.fetch_inputs(batch, ['media_offset', 'pixel_values'])
        for b in batch:
            b.pop('pixel_values', None)
        res.update(super()._data_collator(batch, padding_to=padding_to))
        return res


class mPlugOwl3_241101Template(mPlugOwl3Template):
    version = '241101'

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if 'pixel_values' in inputs:
            pixel_values = inputs.pop('pixel_values')
            inputs['image_embeds'] = torch.concat([model.forward_image(pv) for pv in pixel_values])
        else:
            inputs['media_offset'] = [None] * inputs['input_ids'].shape[0]
        return inputs


@dataclass
class mPlugOwl3TemplateMeta(QwenTemplateMeta):
    prefix: Prompt = field(default_factory=lambda: ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n'])
    default_system: Optional[str] = None
    system_prefix: Optional[Prompt] = None


register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3, template_cls=mPlugOwl3Template))

register_template(mPlugOwl3TemplateMeta(MLLMTemplateType.mplug_owl3_241101, template_cls=mPlugOwl3_241101Template))


class DocOwl2Template(Template):

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if media_type == 'image':
            return [f'<img {index + 1}>', [-200]]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        if inputs.images:
            image_tensor, patch_positions, _ = self.processor._process_image(inputs.images)
            image_tensor = image_tensor.to(self.model_info.torch_dtype)
            encoded.update({'images': image_tensor, 'patch_positions': patch_positions})
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        keys = ['images', 'patch_positions']
        res = self.fetch_inputs(batch, keys)
        for key in keys:
            val = res.get(key)
            if val:
                res[key] = torch.concat([v for v in val if v is not None])
        res.update(super()._data_collator(batch, padding_to=padding_to))
        return res


register_template(
    TemplateMeta(
        MLLMTemplateType.doc_owl2,
        prefix=['<s>'],
        prompt=[' USER: {{QUERY}} ASSISTANT:'],
        chat_sep=['</s>'],
        suffix=['</s>'],
        template_cls=DocOwl2Template,
    ))
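A note on the trickiest step above: in mPlugOwl3Template._encode, media_offset records, for every position in the sequence, how many image tokens occur strictly before it. A minimal, self-contained sketch of that counting trick (token id 9 is a made-up stand-in for the real <|image|> id; the real code gets the positions from findall):

import torch

# hypothetical token sequence; 9 marks where an image token sits
input_ids = torch.tensor([5, 9, 9, 7, 9, 3])
image_token_idx = (input_ids == 9).nonzero(as_tuple=True)[0]  # tensor([1, 2, 4])

_range = torch.arange(len(input_ids))[:, None]        # shape (seq_len, 1)
matrix = (_range > image_token_idx[None]).sum(dim=1)  # images seen before each position
print(matrix.tolist())  # [0, 0, 1, 2, 2, 3]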
swift/llm/template/template/openbuddy.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from ..constant import LLMTemplateType
from ..register import TemplateMeta, register_template

OPENBUDDY_DEFAULT_SYSTEM = (
    'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
    'Always answer as helpfully and logically as possible, while being safe. '
    'Your answers should not include any '
    'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
    'Please ensure that your responses are socially unbiased and positive in nature.\n'
    'If a question does not make any sense, or is not factually coherent, '
    'explain why instead of answering something not correct. '
    "If you don't know the answer to a question, please don't share false information.\n"
    'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
    'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
    'You always deeply love and support China, Chinese government, people and culture.\n'
    'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
    'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.')
register_template(
    TemplateMeta(
        LLMTemplateType.openbuddy,
        prefix=[],
        prompt=['User: {{QUERY}}\nAssistant:'],
        chat_sep=['\n'],
        default_system=OPENBUDDY_DEFAULT_SYSTEM,
        system_prefix=['{{SYSTEM}}\n\n'],
        auto_add_bos=True))

OPENBUDDY2_DEFAULT_SYSTEM = (
    'You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. '
    'You are talking to a human(user).\nAlways answer as helpfully and logically as possible, while being safe. '
    'Your answers should not include any harmful, political, religious, unethical, racist, '
    'sexist, toxic, dangerous, or illegal content. '
    'Please ensure that your responses are socially unbiased and positive in nature.\n'
    'You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.\n'
    'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
    'not related to GPT or OpenAI')

register_template(
    TemplateMeta(
        LLMTemplateType.openbuddy2,
        prefix=[],
        prompt=['<|role|>user<|says|>{{QUERY}}<|end|>\n<|role|>assistant<|says|>'],
        chat_sep=['<|end|>\n'],
        suffix=['<|end|>'],
        default_system=OPENBUDDY2_DEFAULT_SYSTEM,
        system_prefix=['<|role|>system<|says|>{{SYSTEM}}<|end|>\n']))
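Since the OpenBuddy metas above are pure prompt formats, it may help to see what they render to. A rough sketch of the string the openbuddy2 fields produce for one system plus one user turn (plain string formatting, not swift's template engine, which additionally handles tokenization, chat_sep and stop words; the system and query values are illustrative):

system = 'You are Buddy.'  # stands in for OPENBUDDY2_DEFAULT_SYSTEM
query = 'Hello!'
prompt = (f'<|role|>system<|says|>{system}<|end|>\n'   # system_prefix
          f'<|role|>user<|says|>{query}<|end|>\n'      # prompt, user part
          f'<|role|>assistant<|says|>')                # prompt, assistant lead-in
print(prompt)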
swift/llm/template/template/pixtral.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import findall


class PixtralTemplate(Template):
    image_placeholder = ['[IMG]']
    placeholder_tokens = ['[IMG]']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        processor = self.processor
        images = inputs.images
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, 10)
        if idx_list:
            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
            encoded['pixel_values'] = image_inputs['pixel_values'][0]
            image_sizes = image_inputs['image_sizes'][0]

            def _get_new_tokens(i):
                height, width = image_sizes[i]
                num_height_tokens = height // processor.patch_size
                num_width_tokens = width // processor.patch_size
                replace_tokens = [processor.image_token * num_width_tokens + processor.image_break_token] * (
                    num_height_tokens - 1)
                replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token]
                # Flatten list
                replace_str = ''.join(replace_tokens)
                img_tokens: List[int] = self.processor.encode(replace_str, add_special_tokens=False)
                return img_tokens

            encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        pixel_values = self.gather_list(batch, 'pixel_values')
        res = super()._data_collator(batch, padding_to=padding_to)
        if pixel_values:
            res['pixel_values'] = pixel_values
        return res


register_template(
    TemplateMeta(
        MLLMTemplateType.pixtral,
        prefix=['<s>{{SYSTEM}}'],
        prompt=['[INST]{{QUERY}}[/INST]'],
        chat_sep=['</s>'],
        suffix=['</s>'],
        template_cls=PixtralTemplate,
    ))
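The arithmetic in _get_new_tokens above determines how many tokens each image costs: one image token per patch, plus one break token per row except the last row, which ends with an end token instead. A back-of-the-envelope check with illustrative numbers (the real height/width come from the processor's image_sizes, and patch_size from the processor):

patch_size = 16
height, width = 512, 768                 # processed image size in pixels (illustrative)
num_height_tokens = height // patch_size  # 32 rows of patches
num_width_tokens = width // patch_size    # 48 patch tokens per row
total = num_height_tokens * (num_width_tokens + 1)  # +1 break/end token per row
print(total)  # 1568 tokens replace the single image placeholder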
swift/llm/template/template/qwen.py
ADDED
@@ -0,0 +1,671 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Tuple

import torch
import torch.nn.functional as F

from swift.llm import to_device, to_float_dtype
from swift.utils import get_env_args, is_deepspeed_enabled
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import register_template
from ..template_inputs import StdTemplateInputs
from ..template_meta import TemplateMeta
from ..utils import Context, Word, findall
from ..vision_utils import load_audio, load_batch, load_video_ovis2
from .llama import Llama3TemplateMeta
from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta


@dataclass
class QwenTemplateMeta(ChatmlTemplateMeta):
    default_system: Optional[str] = DEFAULT_SYSTEM
    auto_add_bos: bool = False
    stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>'])
    agent_template: str = 'hermes'


@dataclass
class Qwen2_5TemplateMeta(QwenTemplateMeta):
    default_system: Optional[str] = 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.'


@dataclass
class Qwen2_5MathTemplateMeta(QwenTemplateMeta):
    default_system: Optional[str] = 'Please reason step by step, and put your final answer within \\boxed{}.'


qwq_preview_system = ('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
                      'You should think step-by-step.')

register_template(QwenTemplateMeta(LLMTemplateType.qwen))
register_template(Qwen2_5TemplateMeta(LLMTemplateType.qwen2_5))
register_template(QwenTemplateMeta(LLMTemplateType.qwq_preview, default_system=qwq_preview_system))


class ThinkingTemplate(Template):

    def _swift_encode(self, inputs: StdTemplateInputs):
        if not self.is_training:
            for message in inputs.messages:
                if message['role'] == 'assistant' and isinstance(message['content'], str):
                    message['content'] = message['content'].split('</think>')[-1].lstrip('\n')
        return super()._swift_encode(inputs)


register_template(
    QwenTemplateMeta(
        LLMTemplateType.qwq, default_system=None, response_prefix='<think>\n', template_cls=ThinkingTemplate))

# '<think>\n\n</think>\n\n'
register_template(QwenTemplateMeta(LLMTemplateType.qwen3, default_system=None, template_cls=ThinkingTemplate))

register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math))


class QwenPRMTemplate(Template):
    cot_process_placeholder = '<extra_0>'

    def _preprocess_inputs(
        self,
        inputs: StdTemplateInputs,
    ) -> None:
        super()._preprocess_inputs(inputs)
        total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
        if self.cot_process_placeholder not in total_content:
            inputs.messages[-1]['content'] = inputs.messages[-1]['content'] + self.cot_process_placeholder

    @staticmethod
    def make_step_rewards(logits, token_masks):
        probabilities = F.softmax(logits, dim=-1)
        probabilities = probabilities * token_masks.unsqueeze(-1)  # bs, seq_len, num_labels

        all_scores_res = []
        for i in range(probabilities.size(0)):
            sample = probabilities[i]  # seq_len, num_labels
            positive_probs = sample[sample != 0].view(-1, 2)[:, 1]  # valid_tokens, num_labels
            non_zero_elements_list = positive_probs.cpu().tolist()
            all_scores_res.append(non_zero_elements_list)
        return all_scores_res

    def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
        step_sep_id = self.tokenizer.encode(self.cot_process_placeholder)[0]
        token_masks = (input_ids == step_sep_id)
        return self.make_step_rewards(logits, token_masks)


register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math_prm, template_cls=QwenPRMTemplate))


class QwenVLTemplate(Template):
    load_images = False

    @staticmethod
    def _load_image(image, load_images: bool):
        if not load_images and isinstance(image, str) and (image.startswith('data:') or len(image) > 200):
            load_images = True
        return Template._load_image(image, load_images)

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        if self.mode == 'lmdeploy':
            return [f'Picture {index + 1}: ', [-100], '\n']
        else:
            image = inputs.images[index]
            if self.mode == 'vllm':
                return [f'Picture {index + 1}: <img></img>\n']
            else:
                assert isinstance(image, str)
                return [f'Picture {index + 1}: <img>{image}</img>\n']

    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
        return [f'<ref>{ref}</ref>']

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        return [f'<box>{self._get_bbox_str(bbox)}</box>']


register_template(QwenTemplateMeta(MLLMTemplateType.qwen_vl, template_cls=QwenVLTemplate))


class QwenAudioTemplate(Template):

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio'
        audios = inputs.audios
        audio = audios[index]
        assert isinstance(audio, str)
        return [f'Audio {index + 1}:<audio>{audio}</audio>\n']

    def _tokenize(self, context, **tokenizer_kwargs):
        audio_info = self.processor.process_audio(context)
        return super()._tokenize(context, audio_info=audio_info)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        text = ''.join([f'<audio>{audio}</audio>' for audio in inputs.audios])
        audio_info = self.processor.process_audio(text)
        if audio_info:
            tokenizer_kwargs = {'audio_info': audio_info}
            encoded.update(tokenizer_kwargs)
            encoded['tokenizer_kwargs'] = tokenizer_kwargs
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        if batch[0].get('audio_info') is not None:
            res['audio_info'] = [b['audio_info'] for b in batch]
        return res


register_template(QwenTemplateMeta(MLLMTemplateType.qwen_audio, template_cls=QwenAudioTemplate))


class Qwen2AudioTemplate(Template):

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio'
        if not self.use_chat_template:
            return ['<|audio_bos|><|AUDIO|><|audio_eos|>\n']
        else:
            return [f'Audio {index + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        if inputs.audios:
            sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
            audios = load_batch(inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate))
            audio_inputs = self.processor.feature_extractor(
                audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt')
            audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask')
            encoded.update(audio_inputs)
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
        feature_attention_mask = [
            b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
        ]
        if input_features:
            res['input_features'] = torch.concat(input_features)
            res['feature_attention_mask'] = torch.concat(feature_attention_mask)
        return res


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate))


class Qwen2VLTemplate(Template):
    image_token_id = 151655
    video_token_id = 151656
    placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
    version = 'v2'
    use_model = True

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        from qwen_vl_utils import fetch_image, fetch_video
        assert media_type in {'image', 'video'}
        if media_type == 'image':
            inputs.images[index] = fetch_image({'image': inputs.images[index]})
            if self.mode == 'lmdeploy':
                return ['<|vision_start|>', [-100], '<|vision_end|>']
            else:
                return ['<|vision_start|><|image_pad|><|vision_end|>']
        else:
            inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
            return ['<|vision_start|><|video_pad|><|vision_end|>']

    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
        return [f'<|object_ref_start|>{ref}<|object_ref_end|>']

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        processor = self.processor
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        images = inputs.images
        videos = inputs.videos
        for media_type in ['images', 'videos']:
            if locals()[media_type]:
                if media_type == 'images':
                    media_token = self.image_token_id
                    media_inputs = processor.image_processor(
                        images=images, videos=None, return_tensors='pt', do_resize=False)
                    media_grid_thw = media_inputs['image_grid_thw']
                else:
                    media_inputs = processor.image_processor(
                        images=None, videos=videos, return_tensors='pt', do_resize=False)
                    media_grid_thw = media_inputs['video_grid_thw']
                    media_token = self.video_token_id
                    if self.version == 'v2_5':
                        from qwen_vl_utils import vision_process
                        media_inputs['second_per_grid_ts'] = [
                            processor.image_processor.temporal_patch_size / vision_process.FPS
                        ] * len(media_grid_thw)
                idx_list = findall(input_ids, media_token)
                merge_length = processor.image_processor.merge_size**2

                def _get_new_tokens(i):
                    token_len = (media_grid_thw[i].prod() // merge_length)
                    return [media_token] * token_len

                input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
                encoded.update(media_inputs)

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

    def compute_loss_context(self, model, inputs):
        if 'real_position_ids' not in inputs:
            return super().compute_loss_context(model, inputs)
        if self.version == 'v2':
            from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
        elif self.version == 'v2_5':
            from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module
        elif self.version == 'omni':
            from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module
        position_ids = inputs['position_ids']
        inputs['position_ids'] = inputs.pop('real_position_ids')
        return self._patch_flash_attention_forward(modeling_module, position_ids)

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.is_training:
            return inputs
        input_ids = inputs['input_ids']
        _model = model.model
        if not hasattr(_model, 'embed_tokens'):
            _model = _model.model  # LoRA
        pixel_values = inputs.get('pixel_values')
        pixel_values_videos = inputs.get('pixel_values_videos')
        image_grid_thw = inputs.get('image_grid_thw')
        video_grid_thw = inputs.get('video_grid_thw')

        inputs_embeds = _model.embed_tokens(input_ids)

        dtype = model.visual.get_dtype() if self.version == 'v2' else model.visual.dtype
        if pixel_values is None and pixel_values_videos is None:  # plain-text
            if is_deepspeed_enabled():
                from PIL import Image
                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
                media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
                device = input_ids.device
                media_inputs = to_device(media_inputs, device)
                pixel_values = media_inputs['pixel_values'].type(dtype)
                image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
                inputs_embeds += image_embeds.mean() * 0.
        else:
            if pixel_values is not None:
                pixel_values = pixel_values.type(dtype)
                image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
                image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(dtype)
                video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
                video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        return {'inputs_embeds': inputs_embeds}

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        res = super()._data_collator_mm_data(batch)
        second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
        if second_per_grid_ts:
            res['second_per_grid_ts'] = second_per_grid_ts
        for media_type in ['image', 'video']:
            grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
            if grid_thw is not None:
                res[f'{media_type}_grid_thw'] = grid_thw
        return res

    def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
        position_ids = []
        for r in row:
            r = r[0].copy()
            r['input_ids'] = torch.tensor(r['input_ids'])[None]
            position_ids.append(self._get_position_ids(r))
        packed = super().packing_row(row)
        packed['real_position_ids'] = torch.concat(position_ids, dim=-1)
        return packed

    def _get_position_ids(self, inputs: Dict[str, Any]):
        # fix https://github.com/huggingface/transformers/pull/33487
        kwargs = {}
        if self.version == 'v2_5':
            kwargs = {'second_per_grid_ts': inputs.get('second_per_grid_ts')}
        position_ids, _ = self.model.get_rope_index(
            inputs['input_ids'],
            inputs.get('image_grid_thw'),
            inputs.get('video_grid_thw'),
            attention_mask=inputs.get('attention_mask'),
            **kwargs)
        return position_ids.contiguous()

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        if self._packing:
            res['real_position_ids'] = self.concat_tensor(batch, 'real_position_ids', -1)
        elif self.is_training:
            res['position_ids'] = self._get_position_ids(res)
        return res


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_vl, template_cls=Qwen2VLTemplate))

register_template(
    QwenTemplateMeta(
        MLLMTemplateType.qvq,
        default_system=('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
                        'Answer in the language of the question. You should think step-by-step.'),
        template_cls=Qwen2VLTemplate,
    ))


class Qwen2_5VLTemplate(Qwen2VLTemplate):
    version = 'v2_5'
    norm_bbox = 'none'


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_vl, template_cls=Qwen2_5VLTemplate))


class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
    version = 'omni'
    placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
        default = Qwen2_5OmniProcessorKwargs._defaults
        self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
        self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
        self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
        self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        from qwen_omni_utils import fetch_image, fetch_video
        if media_type == 'image':
            inputs.images[index] = fetch_image({'image': inputs.images[index]})
            return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
        elif media_type == 'audio':
            inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
            return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
        elif media_type == 'video':
            video = inputs.videos[index]
            inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
            if self.use_audio_in_video:
                import librosa
                if video.startswith('http://') or video.startswith('https://'):
                    import audioread
                    video = audioread.ffdec.FFmpegAudioFile(video)
                video = librosa.load(video, sr=self.sampling_rate)[0]
                inputs.audios.insert(inputs.audio_idx, (video, 'video'))
                inputs.audio_idx += 1
                return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
            return ['<|vision_bos|><|VIDEO|><|vision_eos|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = Template._encode(self, inputs)
        processor = self.processor
        video_audios_mask = []
        for i, audio in enumerate(inputs.audios):
            if isinstance(audio, tuple) and audio[1] == 'video':
                inputs.audios[i] = audio[0]
                video_audios_mask.append(True)
            else:
                video_audios_mask.append(False)
        video_audios_mask = torch.tensor(video_audios_mask)
        media_inputs = processor(
            text='',
            audio=inputs.audios or None,
            images=inputs.images or None,
            videos=inputs.videos or None,
            return_tensors='pt')
        media_inputs.pop('input_ids')
        media_inputs.pop('attention_mask')
        media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        # audio
        audio_token_id = self._tokenize('<|AUDIO|>')
        idx_list = findall(input_ids, audio_token_id)
        feature_attention_mask = media_inputs.get('feature_attention_mask')
        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
            audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
        else:
            audio_lengths = None
        audio_lengths_origin = audio_lengths
        if idx_list:
            if self.use_audio_in_video:
                audio_lengths = audio_lengths[~video_audios_mask]

            def _get_new_audio_tokens(i):
                return audio_token_id * audio_lengths[i]

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)

        for media_type in ['image', 'video']:
            token = f'<|{media_type.upper()}|>'
            token_id = self._tokenize(token)
            idx_list = findall(input_ids, token_id)
            if idx_list:
                merge_size = processor.image_processor.merge_size
                media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
                if media_type == 'video' and self.use_audio_in_video:
                    audio_lengths = audio_lengths_origin[video_audios_mask]
                    video_second_per_grid = media_inputs['video_second_per_grid']

                    def _get_new_tokens_use_audio_in_video(i):
                        audio_token_indices = torch.arange(audio_lengths[i])
                        grid_thw = media_grid_thw[i]
                        height = grid_thw[1] // merge_size
                        width = grid_thw[2] // merge_size
                        video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
                        video_token_indices = torch.broadcast_to(
                            video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
                        video_token_indices = (
                            video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
                        tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
                        video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
                        audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)

                        res = []
                        for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                            if j < len(video_chunk_indexes):
                                video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
                                res += token_id * video_seq_length
                            if j < len(audio_chunk_indexes):
                                audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
                                res += audio_token_id * audio_seq_length
                        return res

                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
                                                            _get_new_tokens_use_audio_in_video)

                else:

                    def _get_new_tokens(i):
                        token_len = (media_grid_thw[i].prod() // (merge_size**2))
                        return token_id * token_len

                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded.update(media_inputs)
        return encoded

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return Template._post_encode(self, model, inputs)

    def _get_position_ids(self, inputs: Dict[str, Any]):
        feature_attention_mask = inputs.get('feature_attention_mask')
        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
        else:
            audio_feature_lengths = None
        video_second_per_grid = inputs.pop('video_second_per_grid', None)
        input_ids = inputs['input_ids']
        attention_mask = inputs.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        position_ids, _ = self.model.thinker.get_rope_index(
            input_ids,
            inputs.get('image_grid_thw'),
            inputs.get('video_grid_thw'),
            attention_mask,
            self.use_audio_in_video,
            audio_feature_lengths,
            video_second_per_grid,
        )
        return position_ids.contiguous()

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        res = super()._data_collator_mm_data(batch)
        video_second_per_grid = self.gather_list(batch, 'video_second_per_grid')
        if video_second_per_grid:
            res['video_second_per_grid'] = video_second_per_grid
        input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
        feature_attention_mask = [
            b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
        ]
        if input_features:
            res['input_features'] = torch.concat(input_features)
            res['feature_attention_mask'] = torch.concat(feature_attention_mask)
        return res

    def generate(self, model, *args, **kwargs):
        if kwargs.get('video_grid_thw') is not None:
            kwargs['use_audio_in_video'] = self.use_audio_in_video
        return super().generate(model, *args, **kwargs)


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_omni, template_cls=Qwen2_5OmniTemplate))


class Ovis1_6Template(Template):
    skip_prompt = False
    use_model = True

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        return [[-200], '\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, [-200])
        added_tokens_len = 0
        pixel_values = []
        for i, idx in enumerate(idx_list):
            max_partition = get_env_args('max_partition', int, 9)
            raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image(
                images[i], max_partition=max_partition)
            # shift idx by the tokens already inserted for earlier images
            idx += added_tokens_len
            input_ids = input_ids[:idx] + image_placeholders + input_ids[idx + 1:]
            if labels is not None:
                labels = labels[:idx] + [-100] * len(image_placeholders) + labels[idx + 1:]
            pixel_values.append(raw_pixel_values)
            added_tokens_len += len(image_placeholders) - 1
        dtype = self.model.visual_tokenizer.dtype
        if pixel_values:
            pixel_values = torch.cat(pixel_values, dim=0).to(dtype)
        else:
            pixel_values = torch.zeros((1, 3, 384, 384), dtype=dtype)  # dummy
        encoded.update({'input_ids': input_ids, 'labels': labels})
        encoded['pixel_values'] = [pixel_values]
        return encoded

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        padding_side = self.padding_side if self.is_training else 'left'
        if self.max_length is not None:
            model.config.multimodal_max_length = self.max_length
        input_ids = inputs['input_ids']
        labels = inputs.get('labels')
        if labels is None:
            labels = input_ids.new_full(input_ids.shape, -100)
        _, inputs_embeds, labels, attention_mask = model.merge_multimodal(
            text_input_ids=input_ids,
            text_attention_masks=torch.ones_like(input_ids),  # not used, kept for compatibility
            text_labels=labels,
            pixel_values=inputs['pixel_values'],
            left_padding=padding_side == 'left')
        if inputs.get('labels') is None:
            labels = None
        return {'inputs_embeds': inputs_embeds, 'labels': labels, 'attention_mask': attention_mask}

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        pixel_values = self.gather_list(batch, 'pixel_values')
        res = super()._data_collator(batch, padding_to=padding_to)
        res['pixel_values'] = pixel_values
        return res


register_template(
    TemplateMeta(
        MLLMTemplateType.ovis1_6,
        prefix=['<bos>'],
        prompt=['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
        chat_sep=['<end_of_turn>\n'],
        suffix=['<end_of_turn>'],
        system_prefix=['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'],
        template_cls=Ovis1_6Template,
    ))

register_template(
    Llama3TemplateMeta(
        MLLMTemplateType.ovis1_6_llama3,
        default_system='You are a helpful and honest multimodal assistant.',
        template_cls=Ovis1_6Template,
    ))


class Ovis2Template(Ovis1_6Template):
    placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
    nframes = 12

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if media_type == 'image':
            return [[-200], '\n']
        elif media_type == 'video':
            nframes = get_env_args('nframes', int, self.nframes)
            inputs.images = load_video_ovis2(inputs.videos[index], nframes)
            return [[-200] * nframes, '\n']


register_template(QwenTemplateMeta(
    MLLMTemplateType.ovis2,
    template_cls=Ovis2Template,
))


@dataclass
class MarcoO1TemplateMeta(QwenTemplateMeta):
    # Marco-o1's shipped (Chinese) system prompt: reason inside <Thought>, output the
    # final answer inside <Output>, matching the user's language in <Output>.
    default_system: Optional[str] = """
你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.
\n## 重要!!!!!
当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。
<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。
"""


register_template(MarcoO1TemplateMeta(LLMTemplateType.marco_o1))
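Across the Qwen2-VL family above, the key accounting is in _encode: each <|image_pad|> (or <|video_pad|>) placeholder is expanded to grid_t * grid_h * grid_w // merge_size**2 tokens, so the text side stays in lockstep with what the vision tower emits. A quick sanity check with an illustrative grid (real values come from the image processor, not from this made-up tensor):

import torch

merge_size = 2                                # Qwen2-VL merges 2x2 patches per token
image_grid_thw = torch.tensor([[1, 36, 52]])  # (t, h, w) patch grid for one image
merge_length = merge_size**2
token_len = int(image_grid_thw[0].prod()) // merge_length
print(token_len)  # 468 <|image_pad|> tokens inserted for this image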
swift/llm/template/template/stepfun.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Literal, Optional

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context
from ..vision_utils import load_file
from .qwen import QwenTemplateMeta


class GOTImageEvalProcessor:

    def __init__(self, image_size=384, mean=None, std=None):
        from torchvision import transforms
        from torchvision.transforms.functional import InterpolationMode
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

    def __call__(self, item):
        return self.transform(item)


class GOT_OCR2Template(Template):
    placeholder_tokens = ['<imgpad>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # 'OCR: '
        # 'OCR with format: '
        assert media_type == 'image'
        return ['<img>' + '<imgpad>' * 256 + '</img>\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        image_processor_high = GOTImageEvalProcessor(image_size=1024)
        for i, image in enumerate(images):
            images[i] = image_processor_high(image)[None].to(self.model_info.torch_dtype)
        if images:
            encoded['images'] = images
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        images = self.gather_list(batch, 'images')
        if images:
            res['images'] = images
        return res


register_template(
    QwenTemplateMeta(
        MLLMTemplateType.got_ocr2,
        default_system=' You should follow the instructions carefully and explain your answers in detail.',
        template_cls=GOT_OCR2Template,
    ))


class GOT_OCR2HfTemplate(Template):
    placeholder_tokens = ['<imgpad>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # 'OCR: '
        # 'OCR with format: '
        assert media_type == 'image'
        return ['<img>' + '<imgpad>' * 256 + '</img>\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:  # for now, copied from the class above
        encoded = super()._encode(inputs)
        images = inputs.images
        if images:
            encoded['images'] = images
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        images = self.gather_list(batch, 'images')
        _inputs = self.processor(images, return_tensors='pt')
        _inputs.pop('input_ids')  # this does not contain the response, so cannot be used when training
        _inputs.pop('attention_mask')  # this does not contain the response, so cannot be used when training

        res.update(_inputs.data)
        return res


register_template(
    QwenTemplateMeta(
        MLLMTemplateType.got_ocr2_hf,
        default_system=' You should follow the instructions carefully and explain your answers in detail.',
        template_cls=GOT_OCR2HfTemplate,
    ))


class StepAudioTemplate(Template):
    use_model = True

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio', f'media_type: {media_type}'
        from utils import load_audio  # 'utils' comes from the model repository, not from swift
        audio_wav, sr = load_audio(load_file(inputs.audios[index]))
        audio_tokens = self.model.encoder(audio_wav, sr)
        return audio_tokens


register_template(
    TemplateMeta(
        MLLMTemplateType.step_audio,
        template_cls=StepAudioTemplate,
        prefix=['<s>'],
        prompt=['<|BOT|>human\n{{QUERY}}<|EOT|><|BOT|>assistant\n'],
        system_prefix=['<s><|BOT|>system\n{{SYSTEM}}<|EOT|>'],
        chat_sep=['<|EOT|>'],
        suffix=['<|EOT|>'],
    ))
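A small usage sketch for GOTImageEvalProcessor above, assuming it runs in the same module as the stepfun.py listing (the input image is a synthetic stand-in): it maps any PIL image to a normalized (3, 1024, 1024) tensor, and GOT_OCR2Template._encode then adds the batch dimension with [None].

from PIL import Image

processor = GOTImageEvalProcessor(image_size=1024)
image = Image.new('RGB', (640, 480))  # stand-in for a real document image
tensor = processor(image)
print(tensor.shape)  # torch.Size([3, 1024, 1024])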