Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ms-swift/processed_data/processed_silence_speaker_segments.json +0 -0
- ms-swift/processed_data/processed_silence_transcriptions.json +0 -0
- ms-swift/silence_overlaps/700/original/overlap5s_segments.json +0 -0
- ms-swift/silence_overlaps/700/original/overlap5s_silence_segments.json +0 -0
- ms-swift/silence_overlaps/700/test/.ipynb_checkpoints/overlap5s_segments_test-checkpoint.json +27 -0
- ms-swift/silence_overlaps/700/train/silence_issilence_train.json +0 -0
- ms-swift/silence_overlaps/700/train/silence_speaker_segments_train.json +0 -0
- ms-swift/silence_overlaps/cleaned_transcriptions2.json +0 -0
- ms-swift/swift/llm/model/model/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/bert.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/codefuse.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/deepseek.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/llm.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/mamba.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/microsoft.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/minicpm.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/minimax.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/mllm.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/moonshot.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/mplug.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/openbuddy.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/qwen.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/skywork.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/stepfun.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/telechat.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/valley.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/__pycache__/yi.cpython-310.pyc +0 -0
- ms-swift/swift/llm/model/model/glm.py +371 -0
- ms-swift/swift/llm/model/model/minimax.py +156 -0
- ms-swift/swift/llm/model/model/openbuddy.py +85 -0
- ms-swift/swift/llm/model/model/qwen.py +899 -0
- ms-swift/swift/llm/model/model/skywork.py +89 -0
- ms-swift/swift/llm/model/model/yi.py +146 -0
- ms-swift/swift/llm/sampling/__init__.py +1 -0
- ms-swift/swift/llm/sampling/base.py +55 -0
- ms-swift/swift/llm/sampling/sampling.py +106 -0
- ms-swift/swift/llm/sampling/utils.py +172 -0
- ms-swift/swift/llm/sampling/vanilla_sampler.py +216 -0
- ms-swift/swift/llm/template/__init__.py +10 -0
- ms-swift/swift/llm/template/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/__pycache__/base.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/__pycache__/grounding.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/__pycache__/register.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/__pycache__/template_inputs.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/__pycache__/utils.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/base.py +1606 -0
- ms-swift/swift/llm/template/grounding.py +91 -0
- ms-swift/swift/llm/template/register.py +57 -0
- ms-swift/swift/llm/template/template/__init__.py +2 -0
ms-swift/processed_data/processed_silence_speaker_segments.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/processed_data/processed_silence_transcriptions.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/700/original/overlap5s_segments.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/700/original/overlap5s_silence_segments.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/700/test/.ipynb_checkpoints/overlap5s_segments_test-checkpoint.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"key": "SODA_PROCESSED--train--123906",
|
| 4 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--123906.wav",
|
| 5 |
+
"model_output": "Multiple speakers talk simultaneously from 00:03-00:09"
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"key": "SODA_PROCESSED--train--1112763",
|
| 9 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--1112763.wav",
|
| 10 |
+
"model_output": "Multiple speakers talk simultaneously from 00:09-00:15"
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"key": "SODA_PROCESSED--train--790538",
|
| 14 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--790538.wav",
|
| 15 |
+
"model_output": "Multiple speakers talk simultaneously from 00:15-00:19"
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"key": "SODA_PROCESSED--train--822773",
|
| 19 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--822773.wav",
|
| 20 |
+
"model_output": "Multiple speakers talk simultaneously from 00:14-00:19"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"key": "SODA_PROCESSED--train--424960",
|
| 24 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--424960.wav",
|
| 25 |
+
"model_output": "Multiple speakers talk simultaneously from 00:29-00:33"
|
| 26 |
+
}
|
| 27 |
+
]
|
ms-swift/silence_overlaps/700/train/silence_issilence_train.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/700/train/silence_speaker_segments_train.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/cleaned_transcriptions2.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/swift/llm/model/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (632 Bytes). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/bert.cpython-310.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/codefuse.cpython-310.pyc
ADDED
|
Binary file (1.72 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/deepseek.cpython-310.pyc
ADDED
|
Binary file (6.45 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/llm.cpython-310.pyc
ADDED
|
Binary file (6.81 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/mamba.cpython-310.pyc
ADDED
|
Binary file (1.54 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/microsoft.cpython-310.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/minicpm.cpython-310.pyc
ADDED
|
Binary file (4.59 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/minimax.cpython-310.pyc
ADDED
|
Binary file (5.05 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/mllm.cpython-310.pyc
ADDED
|
Binary file (4.44 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/moonshot.cpython-310.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/mplug.cpython-310.pyc
ADDED
|
Binary file (3.73 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/openbuddy.cpython-310.pyc
ADDED
|
Binary file (1.83 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/qwen.cpython-310.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/skywork.cpython-310.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/stepfun.cpython-310.pyc
ADDED
|
Binary file (2.21 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/telechat.cpython-310.pyc
ADDED
|
Binary file (1.74 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/valley.cpython-310.pyc
ADDED
|
Binary file (2.71 kB). View file
|
|
|
ms-swift/swift/llm/model/model/__pycache__/yi.cpython-310.pyc
ADDED
|
Binary file (4.49 kB). View file
|
|
|
ms-swift/swift/llm/model/model/glm.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import transformers
|
| 6 |
+
from packaging import version
|
| 7 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
|
| 8 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 9 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
| 10 |
+
|
| 11 |
+
from swift.llm import TemplateType
|
| 12 |
+
from swift.utils import get_device_count, get_dist_setting, get_logger
|
| 13 |
+
from ..constant import LLMModelType, MLLMModelType
|
| 14 |
+
from ..model_arch import ModelArch
|
| 15 |
+
from ..patcher import patch_output_to_input_device
|
| 16 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 17 |
+
from ..utils import AttnImpl, ModelInfo, safe_snapshot_download
|
| 18 |
+
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def remove_property(tokenizer_cls: Type[PreTrainedTokenizerBase], tokenizer_config: Dict[str, Any]) -> None:
|
| 23 |
+
for k, v in tokenizer_cls.__dict__.items():
|
| 24 |
+
if k.endswith('_token') and isinstance(v, property) and k in tokenizer_config:
|
| 25 |
+
setattr(tokenizer_cls, k, tokenizer_config[k])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_model_tokenizer_chatglm(model_dir: str,
|
| 29 |
+
model_info: ModelInfo,
|
| 30 |
+
model_kwargs: Dict[str, Any],
|
| 31 |
+
load_model: bool = True,
|
| 32 |
+
**kwargs):
|
| 33 |
+
if model_kwargs.get('quantization_config') is not None:
|
| 34 |
+
model_kwargs['quantization_config'].llm_int8_skip_modules = ['output_layer']
|
| 35 |
+
# fix transformers>=4.34 bug
|
| 36 |
+
if version.parse(transformers.__version__) >= version.parse('4.34'):
|
| 37 |
+
tokenizer_config = get_tokenizer_config(model_dir)
|
| 38 |
+
class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
|
| 39 |
+
tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
|
| 40 |
+
tokenizer_cls._auto_class = 'AutoTokenizer'
|
| 41 |
+
remove_property(tokenizer_cls, tokenizer_config)
|
| 42 |
+
kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
|
| 43 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 44 |
+
if model is not None:
|
| 45 |
+
from torch.nn import CrossEntropyLoss
|
| 46 |
+
__old_forward = CrossEntropyLoss.forward
|
| 47 |
+
|
| 48 |
+
def cross_entropy_forward(self, inputs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 49 |
+
target = target.to(device=inputs.device)
|
| 50 |
+
return __old_forward(self, inputs, target)
|
| 51 |
+
|
| 52 |
+
CrossEntropyLoss.forward = cross_entropy_forward
|
| 53 |
+
|
| 54 |
+
return model, tokenizer
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
register_model(
|
| 58 |
+
ModelMeta(
|
| 59 |
+
LLMModelType.chatglm2, [
|
| 60 |
+
ModelGroup([
|
| 61 |
+
Model('ZhipuAI/chatglm2-6b', 'THUDM/chatglm2-6b'),
|
| 62 |
+
Model('ZhipuAI/chatglm2-6b-32k', 'THUDM/chatglm2-6b-32k')
|
| 63 |
+
],
|
| 64 |
+
requires=['transformers<4.42']),
|
| 65 |
+
ModelGroup(
|
| 66 |
+
[Model('ZhipuAI/codegeex2-6b', 'THUDM/codegeex2-6b')],
|
| 67 |
+
requires=['transformers<4.34'],
|
| 68 |
+
tags=['coding'],
|
| 69 |
+
),
|
| 70 |
+
],
|
| 71 |
+
TemplateType.chatglm2,
|
| 72 |
+
get_model_tokenizer_chatglm,
|
| 73 |
+
architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
|
| 74 |
+
model_arch=ModelArch.chatglm))
|
| 75 |
+
|
| 76 |
+
register_model(
|
| 77 |
+
ModelMeta(
|
| 78 |
+
LLMModelType.chatglm3, [
|
| 79 |
+
ModelGroup([
|
| 80 |
+
Model('ZhipuAI/chatglm3-6b', 'THUDM/chatglm3-6b'),
|
| 81 |
+
Model('ZhipuAI/chatglm3-6b-base', 'THUDM/chatglm3-6b-base'),
|
| 82 |
+
Model('ZhipuAI/chatglm3-6b-32k', 'THUDM/chatglm3-6b-32k'),
|
| 83 |
+
Model('ZhipuAI/chatglm3-6b-128k', 'THUDM/chatglm3-6b-128k'),
|
| 84 |
+
])
|
| 85 |
+
],
|
| 86 |
+
TemplateType.glm4,
|
| 87 |
+
get_model_tokenizer_chatglm,
|
| 88 |
+
architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
|
| 89 |
+
requires=['transformers<4.42'],
|
| 90 |
+
model_arch=ModelArch.chatglm))
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_model_tokenizer_glm4(model_dir: str,
|
| 94 |
+
model_info: ModelInfo,
|
| 95 |
+
model_kwargs: Dict[str, Any],
|
| 96 |
+
load_model: bool = True,
|
| 97 |
+
**kwargs):
|
| 98 |
+
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
| 99 |
+
AttnImpl.update_attn_impl(model_config, kwargs.get('attn_impl'))
|
| 100 |
+
kwargs['model_config'] = model_config
|
| 101 |
+
model, tokenizer = get_model_tokenizer_chatglm(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 102 |
+
if len(tokenizer.encode('<|user|>', add_special_tokens=False)) > 1:
|
| 103 |
+
for k in tokenizer.special_tokens.keys():
|
| 104 |
+
tokenizer.add_tokens(k)
|
| 105 |
+
return model, tokenizer
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
register_model(
|
| 109 |
+
ModelMeta(
|
| 110 |
+
LLMModelType.glm4,
|
| 111 |
+
[
|
| 112 |
+
ModelGroup([
|
| 113 |
+
Model('ZhipuAI/glm-4-9b-chat', 'THUDM/glm-4-9b-chat'),
|
| 114 |
+
Model('ZhipuAI/glm-4-9b', 'THUDM/glm-4-9b'),
|
| 115 |
+
Model('ZhipuAI/glm-4-9b-chat-1m', 'THUDM/glm-4-9b-chat-1m'),
|
| 116 |
+
]),
|
| 117 |
+
ModelGroup([
|
| 118 |
+
Model('ZhipuAI/LongWriter-glm4-9b', 'THUDM/LongWriter-glm4-9b'),
|
| 119 |
+
])
|
| 120 |
+
],
|
| 121 |
+
TemplateType.glm4,
|
| 122 |
+
get_model_tokenizer_glm4,
|
| 123 |
+
architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
|
| 124 |
+
model_arch=ModelArch.chatglm,
|
| 125 |
+
requires=['transformers>=4.42'],
|
| 126 |
+
))
|
| 127 |
+
|
| 128 |
+
register_model(
|
| 129 |
+
ModelMeta(
|
| 130 |
+
LLMModelType.glm4_0414,
|
| 131 |
+
[
|
| 132 |
+
ModelGroup([
|
| 133 |
+
Model('ZhipuAI/GLM-4-9B-0414', 'THUDM/GLM-4-9B-0414'),
|
| 134 |
+
Model('ZhipuAI/GLM-4-32B-0414', 'THUDM/GLM-4-32B-0414'),
|
| 135 |
+
Model('ZhipuAI/GLM-4-32B-Base-0414', 'THUDM/GLM-4-32B-Base-0414'),
|
| 136 |
+
Model('ZhipuAI/GLM-Z1-9B-0414', 'THUDM/GLM-Z1-9B-0414'),
|
| 137 |
+
Model('ZhipuAI/GLM-Z1-32B-0414', 'THUDM/GLM-Z1-32B-0414'),
|
| 138 |
+
])
|
| 139 |
+
],
|
| 140 |
+
TemplateType.glm4_0414,
|
| 141 |
+
get_model_tokenizer_with_flash_attn,
|
| 142 |
+
architectures=['Glm4ForCausalLM'],
|
| 143 |
+
model_arch=ModelArch.chatglm,
|
| 144 |
+
requires=['transformers>=4.51'],
|
| 145 |
+
))
|
| 146 |
+
|
| 147 |
+
register_model(
|
| 148 |
+
ModelMeta(
|
| 149 |
+
LLMModelType.glm4_z1_rumination,
|
| 150 |
+
[ModelGroup([
|
| 151 |
+
Model('ZhipuAI/GLM-Z1-Rumination-32B-0414', 'THUDM/GLM-Z1-Rumination-32B-0414'),
|
| 152 |
+
])],
|
| 153 |
+
TemplateType.glm4_z1_rumination,
|
| 154 |
+
get_model_tokenizer_with_flash_attn,
|
| 155 |
+
architectures=['Glm4ForCausalLM'],
|
| 156 |
+
model_arch=ModelArch.chatglm,
|
| 157 |
+
requires=['transformers>4.51'],
|
| 158 |
+
))
|
| 159 |
+
|
| 160 |
+
register_model(
|
| 161 |
+
ModelMeta(
|
| 162 |
+
LLMModelType.longwriter_llama3_1,
|
| 163 |
+
[ModelGroup([
|
| 164 |
+
Model('ZhipuAI/LongWriter-llama3.1-8b', 'THUDM/LongWriter-llama3.1-8b'),
|
| 165 |
+
])],
|
| 166 |
+
TemplateType.longwriter_llama,
|
| 167 |
+
get_model_tokenizer_with_flash_attn,
|
| 168 |
+
architectures=['LlamaForCausalLM'],
|
| 169 |
+
requires=['transformers>=4.43'],
|
| 170 |
+
model_arch=ModelArch.llama,
|
| 171 |
+
))
|
| 172 |
+
|
| 173 |
+
register_model(
|
| 174 |
+
ModelMeta(
|
| 175 |
+
LLMModelType.codegeex4,
|
| 176 |
+
[ModelGroup([
|
| 177 |
+
Model('ZhipuAI/codegeex4-all-9b', 'THUDM/codegeex4-all-9b'),
|
| 178 |
+
])],
|
| 179 |
+
TemplateType.codegeex4,
|
| 180 |
+
get_model_tokenizer_glm4,
|
| 181 |
+
requires=['transformers<4.42'],
|
| 182 |
+
architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
|
| 183 |
+
model_arch=ModelArch.chatglm,
|
| 184 |
+
tags=['coding'],
|
| 185 |
+
))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_model_tokenizer_glm4v(model_dir: str,
|
| 189 |
+
model_info: ModelInfo,
|
| 190 |
+
model_kwargs: Dict[str, Any],
|
| 191 |
+
load_model: bool = True,
|
| 192 |
+
**kwargs):
|
| 193 |
+
model, tokenizer = get_model_tokenizer_glm4(model_dir, model_info, model_kwargs, load_model, **kwargs)
|
| 194 |
+
# fix merge-lora
|
| 195 |
+
tokenizer.init_kwargs['image_size'] = 1120
|
| 196 |
+
if load_model:
|
| 197 |
+
# fix device_map 4
|
| 198 |
+
n_gpu = get_device_count()
|
| 199 |
+
local_world_size = get_dist_setting()[3]
|
| 200 |
+
if n_gpu // local_world_size >= 4:
|
| 201 |
+
for layer in model.transformer.vision.transformer.layers:
|
| 202 |
+
patch_output_to_input_device(layer.mlp)
|
| 203 |
+
patch_output_to_input_device(layer.post_attention_layernorm)
|
| 204 |
+
device = next(model.transformer.vision.linear_proj.parameters()).device
|
| 205 |
+
model.transformer.vision.boi.data = model.transformer.vision.boi.to(device)
|
| 206 |
+
model.transformer.vision.eoi.data = model.transformer.vision.eoi.to(device)
|
| 207 |
+
return model, tokenizer
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
register_model(
|
| 211 |
+
ModelMeta(
|
| 212 |
+
MLLMModelType.glm4v,
|
| 213 |
+
[
|
| 214 |
+
ModelGroup(
|
| 215 |
+
[
|
| 216 |
+
Model('ZhipuAI/glm-4v-9b', 'THUDM/glm-4v-9b'),
|
| 217 |
+
],
|
| 218 |
+
requires=['transformers>=4.42,<4.45'],
|
| 219 |
+
),
|
| 220 |
+
ModelGroup(
|
| 221 |
+
[
|
| 222 |
+
Model('ZhipuAI/cogagent-9b-20241220', 'THUDM/cogagent-9b-20241220'),
|
| 223 |
+
],
|
| 224 |
+
requires=['transformers>=4.42'],
|
| 225 |
+
)
|
| 226 |
+
],
|
| 227 |
+
TemplateType.glm4v,
|
| 228 |
+
get_model_tokenizer_glm4v,
|
| 229 |
+
architectures=['ChatGLMModel', 'ChatGLMForConditionalGeneration'],
|
| 230 |
+
model_arch=ModelArch.glm4v,
|
| 231 |
+
))
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_model_tokenizer_cogvlm(model_dir: str,
|
| 235 |
+
model_info: ModelInfo,
|
| 236 |
+
model_kwargs: Dict[str, Any],
|
| 237 |
+
load_model: bool = True,
|
| 238 |
+
**kwargs):
|
| 239 |
+
tokenizer_dir = safe_snapshot_download('AI-ModelScope/vicuna-7b-v1.5', download_model=False, check_local=True)
|
| 240 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=True)
|
| 241 |
+
if load_model:
|
| 242 |
+
logger.warning('CogAgent with FusedLayerNorm will cause an training loss of NAN, '
|
| 243 |
+
'to avoid this, please uninstall apex.')
|
| 244 |
+
logger.info('Please ignore the unimported warning.')
|
| 245 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(
|
| 246 |
+
model_dir, model_info, model_kwargs, load_model, tokenizer=tokenizer, **kwargs)
|
| 247 |
+
return model, tokenizer
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
register_model(
|
| 251 |
+
ModelMeta(
|
| 252 |
+
MLLMModelType.cogvlm, [
|
| 253 |
+
ModelGroup([
|
| 254 |
+
Model('ZhipuAI/cogvlm-chat', 'THUDM/cogvlm-chat-hf'),
|
| 255 |
+
]),
|
| 256 |
+
],
|
| 257 |
+
TemplateType.cogvlm,
|
| 258 |
+
get_model_tokenizer_cogvlm,
|
| 259 |
+
architectures=['CogVLMForCausalLM'],
|
| 260 |
+
requires=['transformers<4.42'],
|
| 261 |
+
model_arch=ModelArch.cogvlm))
|
| 262 |
+
|
| 263 |
+
register_model(
|
| 264 |
+
ModelMeta(
|
| 265 |
+
MLLMModelType.cogagent_chat, [
|
| 266 |
+
ModelGroup([
|
| 267 |
+
Model('ZhipuAI/cogagent-chat', 'THUDM/cogagent-chat-hf'),
|
| 268 |
+
]),
|
| 269 |
+
],
|
| 270 |
+
TemplateType.cogagent_chat,
|
| 271 |
+
get_model_tokenizer_cogvlm,
|
| 272 |
+
architectures=['CogAgentForCausalLM'],
|
| 273 |
+
requires=['transformers<4.42', 'timm'],
|
| 274 |
+
model_arch=ModelArch.cogvlm))
|
| 275 |
+
|
| 276 |
+
register_model(
|
| 277 |
+
ModelMeta(
|
| 278 |
+
MLLMModelType.cogagent_vqa, [ModelGroup([
|
| 279 |
+
Model('ZhipuAI/cogagent-vqa', 'THUDM/cogagent-vqa-hf'),
|
| 280 |
+
])],
|
| 281 |
+
TemplateType.cogagent_vqa,
|
| 282 |
+
get_model_tokenizer_cogvlm,
|
| 283 |
+
architectures=['CogAgentForCausalLM'],
|
| 284 |
+
requires=['transformers<4.42'],
|
| 285 |
+
model_arch=ModelArch.cogvlm))
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_model_tokenizer_cogvlm2(*args, **kwargs):
|
| 289 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
|
| 290 |
+
if model is not None:
|
| 291 |
+
# fix device map 4
|
| 292 |
+
for layer in model.model.vision.transformer.layers:
|
| 293 |
+
patch_output_to_input_device(layer.mlp)
|
| 294 |
+
patch_output_to_input_device(layer.post_attention_layernorm)
|
| 295 |
+
|
| 296 |
+
device = next(model.model.vision.linear_proj.parameters()).device
|
| 297 |
+
model.model.vision.boi.data = model.model.vision.boi.to(device)
|
| 298 |
+
model.model.vision.eoi.data = model.model.vision.eoi.to(device)
|
| 299 |
+
return model, tokenizer
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
register_model(
|
| 303 |
+
ModelMeta(
|
| 304 |
+
MLLMModelType.cogvlm2, [
|
| 305 |
+
ModelGroup([
|
| 306 |
+
Model('ZhipuAI/cogvlm2-llama3-chat-19B', 'THUDM/cogvlm2-llama3-chat-19B'),
|
| 307 |
+
Model('ZhipuAI/cogvlm2-llama3-chinese-chat-19B', 'THUDM/cogvlm2-llama3-chinese-chat-19B'),
|
| 308 |
+
]),
|
| 309 |
+
],
|
| 310 |
+
TemplateType.cogvlm2,
|
| 311 |
+
get_model_tokenizer_cogvlm2,
|
| 312 |
+
architectures=['CogVLMForCausalLM'],
|
| 313 |
+
requires=['transformers<4.42'],
|
| 314 |
+
model_arch=ModelArch.cogvlm))
|
| 315 |
+
|
| 316 |
+
register_model(
|
| 317 |
+
ModelMeta(
|
| 318 |
+
MLLMModelType.cogvlm2_video,
|
| 319 |
+
[
|
| 320 |
+
ModelGroup([
|
| 321 |
+
Model('ZhipuAI/cogvlm2-video-llama3-chat', 'THUDM/cogvlm2-video-llama3-chat'),
|
| 322 |
+
]),
|
| 323 |
+
],
|
| 324 |
+
TemplateType.cogvlm2_video,
|
| 325 |
+
get_model_tokenizer_cogvlm2,
|
| 326 |
+
architectures=['CogVLMVideoForCausalLM'],
|
| 327 |
+
requires=['decord', 'pytorchvideo', 'transformers>=4.42'],
|
| 328 |
+
model_arch=ModelArch.cogvlm,
|
| 329 |
+
tags=['video'],
|
| 330 |
+
))
|
| 331 |
+
|
| 332 |
+
register_model(
|
| 333 |
+
ModelMeta(
|
| 334 |
+
LLMModelType.glm_edge,
|
| 335 |
+
[
|
| 336 |
+
ModelGroup([
|
| 337 |
+
Model('ZhipuAI/glm-edge-1.5b-chat', 'THUDM/glm-edge-1.5b-chat'),
|
| 338 |
+
Model('ZhipuAI/glm-edge-4b-chat', 'THUDM/glm-edge-4b-chat'),
|
| 339 |
+
]),
|
| 340 |
+
],
|
| 341 |
+
TemplateType.glm4,
|
| 342 |
+
get_model_tokenizer_with_flash_attn,
|
| 343 |
+
architectures=['GlmForCausalLM'],
|
| 344 |
+
requires=['transformers>=4.46'],
|
| 345 |
+
))
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def get_model_tokenizer_glm_edge_v(model_dir: str, *args, **kwargs):
|
| 349 |
+
from transformers import AutoImageProcessor
|
| 350 |
+
processor = AutoImageProcessor.from_pretrained(model_dir)
|
| 351 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
|
| 352 |
+
processor.tokenizer = tokenizer
|
| 353 |
+
return model, processor
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
register_model(
|
| 357 |
+
ModelMeta(
|
| 358 |
+
MLLMModelType.glm_edge_v,
|
| 359 |
+
[
|
| 360 |
+
ModelGroup([
|
| 361 |
+
Model('ZhipuAI/glm-edge-v-2b', 'THUDM/glm-edge-v-2b'),
|
| 362 |
+
Model('ZhipuAI/glm-edge-4b-chat', 'THUDM/glm-edge-4b-chat'),
|
| 363 |
+
]),
|
| 364 |
+
],
|
| 365 |
+
TemplateType.glm_edge_v,
|
| 366 |
+
get_model_tokenizer_glm_edge_v,
|
| 367 |
+
architectures=['GlmForCausalLM'],
|
| 368 |
+
requires=['transformers>=4.46'],
|
| 369 |
+
model_arch=ModelArch.glm_edge_v,
|
| 370 |
+
tags=['vision'],
|
| 371 |
+
))
|
ms-swift/swift/llm/model/model/minimax.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from transformers import AutoConfig, AutoProcessor
|
| 7 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 8 |
+
|
| 9 |
+
from swift.llm import TemplateType
|
| 10 |
+
from swift.utils import get_device, get_device_count, get_dist_setting, get_logger
|
| 11 |
+
from ..constant import LLMModelType, MLLMModelType
|
| 12 |
+
from ..patcher import patch_ignore_check_imports
|
| 13 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 14 |
+
from ..utils import ModelInfo
|
| 15 |
+
|
| 16 |
+
logger = get_logger()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_model_tokenizer_minimax_vl(model_dir: str,
|
| 20 |
+
model_info: ModelInfo,
|
| 21 |
+
model_kwargs: Dict[str, Any],
|
| 22 |
+
load_model: bool = True,
|
| 23 |
+
**kwargs):
|
| 24 |
+
logger.warn('NOTE: minimax-vl-01 model does not support training.')
|
| 25 |
+
n_gpu = get_device_count()
|
| 26 |
+
_, local_rank, _, local_world_size = get_dist_setting()
|
| 27 |
+
device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
|
| 28 |
+
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
| 29 |
+
kwargs['model_config'] = config
|
| 30 |
+
if kwargs.get('attn_impl') == 'flash_attn':
|
| 31 |
+
config.attn_type_list = [1] * len(config.attn_type_list)
|
| 32 |
+
else:
|
| 33 |
+
config.attn_type_list = [0] * len(config.attn_type_list)
|
| 34 |
+
if 'quantization_config' in model_kwargs:
|
| 35 |
+
quantization_config = model_kwargs['quantization_config']
|
| 36 |
+
from transformers import QuantoConfig
|
| 37 |
+
if isinstance(quantization_config, QuantoConfig):
|
| 38 |
+
quantization_config.modules_to_not_convert = (
|
| 39 |
+
[
|
| 40 |
+
'vision_tower',
|
| 41 |
+
'image_newline',
|
| 42 |
+
'multi_modal_projector',
|
| 43 |
+
'lm_head',
|
| 44 |
+
'embed_tokens',
|
| 45 |
+
] + [f'model.layers.{i}.coefficient' for i in range(config.text_config.num_hidden_layers)]
|
| 46 |
+
+ [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.text_config.num_hidden_layers)])
|
| 47 |
+
|
| 48 |
+
if len(device_ids) > 1:
|
| 49 |
+
model_safetensors_index_path = os.path.join(model_dir, 'model.safetensors.index.json')
|
| 50 |
+
with open(model_safetensors_index_path, 'r') as f:
|
| 51 |
+
model_safetensors_index = json.load(f)
|
| 52 |
+
weight_map = model_safetensors_index['weight_map']
|
| 53 |
+
vision_map = {}
|
| 54 |
+
for key, value in weight_map.items():
|
| 55 |
+
if 'vision_tower' in key or 'image_newline' in key or 'multi_modal_projector' in key:
|
| 56 |
+
new_key = key.replace('.weight', '').replace('.bias', '')
|
| 57 |
+
if new_key not in vision_map:
|
| 58 |
+
vision_map[new_key] = value
|
| 59 |
+
|
| 60 |
+
device_map = {
|
| 61 |
+
'language_model.model.embed_tokens': get_device(device_ids[0]),
|
| 62 |
+
'language_model.model.norm': get_device(device_ids[len(device_ids) - 1]),
|
| 63 |
+
'language_model.lm_head': get_device(device_ids[len(device_ids) - 1])
|
| 64 |
+
}
|
| 65 |
+
for key, value in vision_map.items():
|
| 66 |
+
device_map[key] = get_device(device_ids[0])
|
| 67 |
+
device_map['vision_tower.vision_model.post_layernorm'] = get_device(device_ids[0])
|
| 68 |
+
layers_per_device = config.text_config.num_hidden_layers // len(device_ids)
|
| 69 |
+
for i in range(len(device_ids)):
|
| 70 |
+
for j in range(layers_per_device):
|
| 71 |
+
device_map[f'language_model.model.layers.{i * layers_per_device + j}'] = get_device(device_ids[i])
|
| 72 |
+
model_kwargs['device_map'] = device_map
|
| 73 |
+
|
| 74 |
+
MiniMaxVL01ProcessorKwargs = get_class_from_dynamic_module('processing_minimax_vl_01.MiniMaxVL01ProcessorKwargs',
|
| 75 |
+
model_dir)
|
| 76 |
+
get_hw_multiple_of = get_class_from_dynamic_module('processing_minimax_vl_01.get_hw_multiple_of', model_dir)
|
| 77 |
+
get_num_token = get_class_from_dynamic_module('processing_minimax_vl_01.get_num_token', model_dir)
|
| 78 |
+
|
| 79 |
+
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
|
| 80 |
+
processor.MiniMaxVL01ProcessorKwargs = MiniMaxVL01ProcessorKwargs
|
| 81 |
+
processor.get_hw_multiple_of = get_hw_multiple_of
|
| 82 |
+
processor.get_num_token = get_num_token
|
| 83 |
+
with patch_ignore_check_imports():
|
| 84 |
+
model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
|
| 85 |
+
**kwargs)
|
| 86 |
+
processor.tokenizer = tokenizer
|
| 87 |
+
return model, processor
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
register_model(
|
| 91 |
+
ModelMeta(
|
| 92 |
+
MLLMModelType.minimax_vl, [
|
| 93 |
+
ModelGroup([
|
| 94 |
+
Model('MiniMax/MiniMax-VL-01', 'MiniMaxAI/MiniMax-VL-01'),
|
| 95 |
+
]),
|
| 96 |
+
],
|
| 97 |
+
TemplateType.minimax_vl,
|
| 98 |
+
get_model_tokenizer_minimax_vl,
|
| 99 |
+
architectures=['MiniMaxVL01ForConditionalGeneration'],
|
| 100 |
+
tags=['vision']))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_model_tokenizer_minimax_text(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    """Load MiniMax-Text-01 (inference only) with a hand-built layer device map.

    Args:
        model_dir: Local checkpoint directory.
        model_info: Model metadata (dtype, etc.) from the registry.
        model_kwargs: kwargs forwarded to ``from_pretrained`` (may contain
            ``quantization_config``; ``device_map`` is filled in here).
        load_model: Whether to actually instantiate the model weights.

    Returns:
        ``(model, tokenizer)``.
    """
    # `logger.warn` is a deprecated alias; use `logger.warning`.
    logger.warning('NOTE: minimax-text-01 model does not support training.')
    n_gpu = get_device_count()
    _, local_rank, _, local_world_size = get_dist_setting()
    # GPUs owned by this rank: every `local_world_size`-th device starting at `local_rank`.
    device_ids = list(range(max(local_rank, 0), n_gpu, local_world_size))
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['model_config'] = config
    # attn_type_list: per-layer flag; 1 selects flash attention, 0 the default impl.
    if kwargs.get('attn_impl') == 'flash_attn':
        config.attn_type_list = [1] * len(config.attn_type_list)
    else:
        config.attn_type_list = [0] * len(config.attn_type_list)
    if 'quantization_config' in model_kwargs:
        quantization_config = model_kwargs['quantization_config']
        from transformers import QuantoConfig
        if isinstance(quantization_config, QuantoConfig):
            # Keep the head, embeddings and MoE routing weights unquantized.
            quantization_config.modules_to_not_convert = (
                [
                    'lm_head',
                    'embed_tokens',
                ] + [f'model.layers.{i}.coefficient' for i in range(config.num_hidden_layers)]
                + [f'model.layers.{i}.block_sparse_moe.gate' for i in range(config.num_hidden_layers)])

    if len(device_ids) > 1:
        layers_per_device = config.num_hidden_layers // len(device_ids)
        # set device map
        # Fix: index into `device_ids` (as get_model_tokenizer_minimax_vl does) rather than
        # using the raw loop index, so non-zero local ranks map onto their own GPUs.
        device_map = {
            'model.embed_tokens': get_device(device_ids[0]),
            'model.norm': get_device(device_ids[-1]),
            'lm_head': get_device(device_ids[-1])
        }
        for i in range(len(device_ids)):
            for j in range(layers_per_device):
                device_map[f'model.layers.{i * layers_per_device + j}'] = get_device(device_ids[i])
        model_kwargs['device_map'] = device_map
    with patch_ignore_check_imports():
        model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model,
                                                               **kwargs)
    return model, tokenizer
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# Register MiniMax-Text-01 (text-only LLM). Args are (ModelScope id, Hugging Face id).
register_model(
    ModelMeta(
        LLMModelType.minimax, [
            ModelGroup([
                Model('MiniMax/MiniMax-Text-01', 'MiniMaxAI/MiniMax-Text-01'),
            ]),
        ],
        TemplateType.minimax,
        get_model_tokenizer_minimax_text,
        architectures=['MiniMaxText01ForCausalLM']))
|
ms-swift/swift/llm/model/model/openbuddy.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from swift.llm import TemplateType
|
| 3 |
+
from swift.utils import get_logger
|
| 4 |
+
from ..constant import LLMModelType
|
| 5 |
+
from ..model_arch import ModelArch
|
| 6 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 7 |
+
|
| 8 |
+
logger = get_logger()
|
| 9 |
+
|
| 10 |
+
# OpenBuddy fine-tunes of LLaMA / LLaMA-2 / DeepSeek bases; all use the
# LlamaForCausalLM architecture and the classic `openbuddy` chat template.
register_model(
    ModelMeta(
        LLMModelType.openbuddy_llama,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama-65b-v8-bf16', 'OpenBuddy/openbuddy-llama-65b-v8-bf16'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama2-13b-v8.1-fp16', 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16'),
                Model('OpenBuddy/openbuddy-llama2-70b-v10.1-bf16', 'OpenBuddy/openbuddy-llama2-70b-v10.1-bf16'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-deepseek-67b-v15.2', 'OpenBuddy/openbuddy-deepseek-67b-v15.2'),
            ]),
        ],
        TemplateType.openbuddy,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))
|
| 30 |
+
|
| 31 |
+
# OpenBuddy fine-tunes on LLaMA-3 era bases; newer groups pin a minimum
# transformers version required by their base models.
register_model(
    ModelMeta(
        LLMModelType.openbuddy_llama3,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama3-8b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-8b-v21.1-8k'),
                Model('OpenBuddy/openbuddy-llama3-70b-v21.1-8k', 'OpenBuddy/openbuddy-llama3-70b-v21.1-8k'),
                Model('OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k', 'OpenBuddy/openbuddy-yi1.5-34b-v21.3-32k'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k', 'OpenBuddy/openbuddy-llama3.1-8b-v22.1-131k'),
                Model('OpenBuddy/openbuddy-nemotron-70b-v23.2-131k', 'OpenBuddy/openbuddy-nemotron-70b-v23.2-131k'),
            ],
                       requires=['transformers>=4.43']),
            ModelGroup(
                [Model('OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k', 'OpenBuddy/openbuddy-llama3.3-70b-v24.3-131k')],
                requires=['transformers>=4.45'])
        ],
        TemplateType.openbuddy2,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))
|
| 54 |
+
|
| 55 |
+
# OpenBuddy fine-tunes of Mistral-7B bases (incl. the Zephyr variant).
register_model(
    ModelMeta(
        LLMModelType.openbuddy_mistral,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-mistral-7b-v17.1-32k', 'OpenBuddy/openbuddy-mistral-7b-v17.1-32k'),
            ]),
            ModelGroup([
                Model('OpenBuddy/openbuddy-zephyr-7b-v14.1', 'OpenBuddy/openbuddy-zephyr-7b-v14.1'),
            ]),
        ],
        TemplateType.openbuddy,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        requires=['transformers>=4.34'],
        architectures=['MistralForCausalLM'],
    ))
|
| 72 |
+
|
| 73 |
+
# OpenBuddy fine-tune of Mixtral-8x7B.
# NOTE(review): unlike the other OpenBuddy registrations in this file, no
# `model_arch` is set here — confirm whether that omission is intentional.
register_model(
    ModelMeta(
        LLMModelType.openbuddy_mixtral,
        [
            ModelGroup([
                Model('OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k', 'OpenBuddy/openbuddy-mixtral-7bx8-v18.1-32k'),
            ], ),
        ],
        TemplateType.openbuddy,
        get_model_tokenizer_with_flash_attn,
        architectures=['MixtralForCausalLM'],
        requires=['transformers>=4.36'],
    ))
|
ms-swift/swift/llm/model/model/qwen.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple, Type
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizerBase
|
| 6 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 7 |
+
from transformers.models.auto.tokenization_auto import get_tokenizer_config
|
| 8 |
+
|
| 9 |
+
from swift.llm import TemplateType
|
| 10 |
+
from swift.utils import get_device_count, get_dist_setting, get_env_args, get_logger
|
| 11 |
+
from ..constant import LLMModelType, MLLMModelType, RMModelType
|
| 12 |
+
from ..model_arch import ModelArch
|
| 13 |
+
from ..patcher import patch_fixed_device, patch_output_clone, patch_output_to_input_device
|
| 14 |
+
from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal, get_model_tokenizer_reward_model,
|
| 15 |
+
get_model_tokenizer_with_flash_attn, register_model)
|
| 16 |
+
from ..utils import AttnImpl, ModelInfo, use_submodel_func
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
# Map torch dtypes to the boolean flag names ('fp16'/'bf16'/'fp32') used by
# Qwen-1 remote-code config objects (see get_model_tokenizer_qwen).
dtype_mapping = {torch.float16: 'fp16', torch.bfloat16: 'bf16', torch.float32: 'fp32'}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_model_tokenizer_qwen(model_dir: str,
                             model_info: ModelInfo,
                             model_kwargs: Dict[str, Any],
                             load_model: bool = True,
                             model_config=None,
                             **kwargs):
    """Load a Qwen-1 series model and tokenizer.

    Translates the requested dtype and attention implementation into the
    flags the Qwen-1 remote config expects (fp16/bf16/fp32 booleans and
    ``use_flash_attn``) before delegating to the generic loader.
    """
    if model_config is None:
        model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    target_dtype = model_info.torch_dtype
    if target_dtype is not None:
        # Qwen-1 configs carry one boolean per dtype; switch on exactly the requested one.
        wanted_flag = dtype_mapping[target_dtype]
        for flag in dtype_mapping.values():
            setattr(model_config, flag, flag == wanted_flag)

    if not isinstance(model_kwargs.get('quantization_config'), BitsAndBytesConfig):
        # not bnb quant
        model_config.torch_dtype = None
    model_config.use_flash_attn = AttnImpl.to_use_flash_attn(kwargs.pop('attn_impl', None), 'auto')
    kwargs['model_config'] = model_config
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    try:
        # fix mp+ddp bug
        model.transformer.registered_causal_mask = model.transformer.registered_causal_mask.cuda()
        logger.info('registered_causal_mask to cuda')
    except AttributeError:
        # Newer checkpoints have no registered_causal_mask; nothing to move.
        pass
    if tokenizer.eos_token_id is None:
        # Qwen-1 tokenizers expose the end-of-document id as `eod_id`.
        tokenizer.eos_token_id = tokenizer.eod_id
    return model, tokenizer
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Qwen-1 family plus the Tongyi-Finance derivatives; all use the legacy
# QWenLMHeadModel remote-code architecture and the custom qwen loader above.
register_model(
    ModelMeta(
        LLMModelType.qwen,
        [
            # qwen
            ModelGroup([
                # chat
                Model('Qwen/Qwen-1_8B-Chat', 'Qwen/Qwen-1_8B-Chat'),
                Model('Qwen/Qwen-7B-Chat', 'Qwen/Qwen-7B-Chat'),
                Model('Qwen/Qwen-14B-Chat', 'Qwen/Qwen-14B-Chat'),
                Model('Qwen/Qwen-72B-Chat', 'Qwen/Qwen-72B-Chat'),
                # base
                Model('Qwen/Qwen-1_8B', 'Qwen/Qwen-1_8B'),
                Model('Qwen/Qwen-7B', 'Qwen/Qwen-7B'),
                Model('Qwen/Qwen-14B', 'Qwen/Qwen-14B'),
                Model('Qwen/Qwen-72B', 'Qwen/Qwen-72B'),
                # gptq-int4
                Model('Qwen/Qwen-1_8B-Chat-Int4', 'Qwen/Qwen-1_8B-Chat-Int4'),
                Model('Qwen/Qwen-7B-Chat-Int4', 'Qwen/Qwen-7B-Chat-Int4'),
                Model('Qwen/Qwen-14B-Chat-Int4', 'Qwen/Qwen-14B-Chat-Int4'),
                Model('Qwen/Qwen-72B-Chat-Int4', 'Qwen/Qwen-72B-Chat-Int4'),
                # gptq-int8
                Model('Qwen/Qwen-1_8B-Chat-Int8', 'Qwen/Qwen-1_8B-Chat-Int8'),
                Model('Qwen/Qwen-7B-Chat-Int8', 'Qwen/Qwen-7B-Chat-Int8'),
                Model('Qwen/Qwen-14B-Chat-Int8', 'Qwen/Qwen-14B-Chat-Int8'),
                Model('Qwen/Qwen-72B-Chat-Int8', 'Qwen/Qwen-72B-Chat-Int8'),
            ]),
            # tongyi-finance
            ModelGroup([
                Model('TongyiFinance/Tongyi-Finance-14B-Chat', 'jxy/Tongyi-Finance-14B-Chat'),
                Model('TongyiFinance/Tongyi-Finance-14B'),
                Model('TongyiFinance/Tongyi-Finance-14B-Chat-Int4', 'jxy/Tongyi-Finance-14B-Chat-Int4'),
            ],
                       tags=['financial']),
        ],
        TemplateType.qwen,
        get_model_tokenizer_qwen,
        architectures=['QWenLMHeadModel'],
        model_arch=ModelArch.qwen))
|
| 93 |
+
|
| 94 |
+
# ModelScope-Agent models: Qwen-1 architecture with an agent-specific template.
register_model(
    ModelMeta(
        LLMModelType.modelscope_agent,
        [ModelGroup([
            Model('iic/ModelScope-Agent-7B'),
            Model('iic/ModelScope-Agent-14B'),
        ])],
        TemplateType.modelscope_agent,
        get_model_tokenizer_qwen,
        architectures=['QWenLMHeadModel'],
        model_arch=ModelArch.qwen))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _qwen_vl_audio_decode(self, *args, skip_special_tokens=False, **kwargs) -> str:
|
| 108 |
+
if skip_special_tokens:
|
| 109 |
+
token_ids = kwargs['token_ids']
|
| 110 |
+
while len(token_ids) > 0 and token_ids[-1] in {151645, 151643}:
|
| 111 |
+
token_ids.pop()
|
| 112 |
+
return self._old_decode(*args, skip_special_tokens=False, **kwargs)
|
| 113 |
+
else:
|
| 114 |
+
return self._old_decode(*args, skip_special_tokens=False, **kwargs)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def fix_qwen_inplace_bug(model) -> None:
    """Work around an in-place tensor bug in qwen-vl / qwen-audio.

    When the transformer's first dropout is a no-op (p == 0) it returns its
    input tensor unchanged; patch it to clone its output so later in-place
    operations cannot corrupt the shared tensor.
    """
    dropout = model.transformer.drop
    if dropout.p == 0.:
        # fix in-place operation bug
        patch_output_clone(dropout)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_model_tokenizer_qwen_audio(model_dir: str,
                                   model_info: ModelInfo,
                                   model_kwargs: Dict[str, Any],
                                   load_model: bool = True,
                                   **kwargs):
    """Load Qwen-Audio: patch the remote tokenizer class, then delegate to get_model_tokenizer_qwen."""
    tokenizer_config = get_tokenizer_config(model_dir)
    # Resolve the checkpoint's own tokenizer class (trust_remote_code style auto_map).
    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
    tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
    tokenizer_cls._auto_class = 'AutoTokenizer'
    tokenizer_cls.AUDIO_ST = ()  # fix no attr `self.AUDIO_ST` bug
    # Install the special-token-stripping decoder once; the guard keeps a
    # second load from chaining the patch onto itself.
    if not hasattr(tokenizer_cls, '_old_decode'):
        tokenizer_cls._old_decode = tokenizer_cls._decode
        tokenizer_cls._decode = _qwen_vl_audio_decode
    kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
    model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if model is not None:
        fix_qwen_inplace_bug(model)

    return model, tokenizer
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Qwen-Audio models; `mel_filters.npz` must travel with saved checkpoints.
register_model(
    ModelMeta(
        MLLMModelType.qwen_audio, [
            ModelGroup([
                Model('Qwen/Qwen-Audio-Chat', 'Qwen/Qwen-Audio-Chat'),
                Model('Qwen/Qwen-Audio', 'Qwen/Qwen-Audio'),
            ])
        ],
        TemplateType.qwen_audio,
        get_model_tokenizer_qwen_audio,
        model_arch=ModelArch.qwen_audio,
        architectures=['QWenLMHeadModel'],
        additional_saved_files=['mel_filters.npz'],
        tags=['audio']))
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _qwen_vl_visual_block_forward(
|
| 163 |
+
self,
|
| 164 |
+
q_x: torch.Tensor,
|
| 165 |
+
k_x: Optional[torch.Tensor] = None,
|
| 166 |
+
v_x: Optional[torch.Tensor] = None,
|
| 167 |
+
attn_mask: Optional[torch.Tensor] = None,
|
| 168 |
+
):
|
| 169 |
+
k_x = self.ln_1_kv(k_x) if hasattr(self, 'ln_1_kv') and k_x is not None else None
|
| 170 |
+
v_x = self.ln_1_kv(v_x) if hasattr(self, 'ln_1_kv') and v_x is not None else None
|
| 171 |
+
|
| 172 |
+
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
|
| 173 |
+
z = self.mlp(self.ln_2(x))
|
| 174 |
+
x = x.to(z.device) + z # FIX
|
| 175 |
+
return x
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_model_tokenizer_qwen_vl(model_dir: str,
                                model_info: ModelInfo,
                                model_kwargs: Dict[str, Any],
                                load_model: bool = True,
                                **kwargs):
    """Load Qwen-VL: patch remote visual/tokenizer classes, then delegate to get_model_tokenizer_qwen."""
    if (model_kwargs.get('quantization_config') is not None
            and isinstance(model_kwargs['quantization_config'], BitsAndBytesConfig)):
        # Keep head and attention pooling out of int8 to avoid a known crash:
        # https://github.com/pytorch/pytorch/issues/58969
        model_kwargs['quantization_config'].llm_int8_skip_modules = ['lm_head', 'attn_pool.attn']
    _TransformerBlock = get_class_from_dynamic_module('visual.TransformerBlock', model_dir)

    def _get_cast_dtype(self) -> torch.dtype:
        # Derive the visual tower's cast dtype from its first layernorm weight.
        return self.resblocks[0].ln_1.weight.dtype

    _TransformerBlock.__old_get_cast_dtype = _TransformerBlock.get_cast_dtype
    _TransformerBlock.get_cast_dtype = _get_cast_dtype

    tokenizer_config = get_tokenizer_config(model_dir)
    # Resolve the checkpoint's own tokenizer class (trust_remote_code style auto_map).
    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
    tokenizer_cls: Type[PreTrainedTokenizerBase] = get_class_from_dynamic_module(class_ref, model_dir)
    tokenizer_cls._auto_class = 'AutoTokenizer'
    tokenizer_cls.IMAGE_ST = ()  # fix no attr `self.IMAGE_ST` bug
    # Install the special-token-stripping decoder once (shared with qwen-audio).
    if not hasattr(tokenizer_cls, '_old_decode'):
        tokenizer_cls._old_decode = tokenizer_cls._decode
        tokenizer_cls._decode = _qwen_vl_audio_decode
    # fix device_map is 4
    n_gpu = get_device_count()
    local_world_size = get_dist_setting()[3]
    if n_gpu // local_world_size >= 4:
        # With >=4 GPUs per rank the visual block may be split across devices;
        # swap in the device-safe forward defined above.
        visual_block_cls = get_class_from_dynamic_module('visual.VisualAttentionBlock', model_dir)
        visual_block_cls.__old_forward = visual_block_cls.forward
        visual_block_cls.forward = _qwen_vl_visual_block_forward

    kwargs['tokenizer'] = tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
    model, tokenizer = get_model_tokenizer_qwen(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if model is not None:
        device_type = next(model.parameters()).device.type
        fix_qwen_inplace_bug(model)
        # fix device_map is 4
        if n_gpu // local_world_size >= 4:
            # Move the projection next to ln_post so the visual tower stays on one device.
            model.transformer.visual.proj.data = model.transformer.visual.proj.to(
                model.transformer.visual.ln_post.bias.device)
            # fix images cuda:1 bug
            patch_fixed_device(model.transformer.visual, f'{device_type}:0')
    return model, tokenizer
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Qwen-VL models; `SimSun.ttf` (used for drawing boxes) must travel with saves.
register_model(
    ModelMeta(
        MLLMModelType.qwen_vl, [
            ModelGroup([
                Model('Qwen/Qwen-VL-Chat', 'Qwen/Qwen-VL-Chat'),
                Model('Qwen/Qwen-VL', 'Qwen/Qwen-VL'),
                Model('Qwen/Qwen-VL-Chat-Int4', 'Qwen/Qwen-VL-Chat-Int4'),
            ])
        ],
        TemplateType.qwen_vl,
        get_model_tokenizer_qwen_vl,
        model_arch=ModelArch.qwen_vl,
        architectures=['QWenLMHeadModel'],
        additional_saved_files=['SimSun.ttf'],
        tags=['vision']))
|
| 240 |
+
|
| 241 |
+
# Qwen1.5 / Qwen2 family (standard HF Qwen2ForCausalLM, no remote code);
# covers chat/base checkpoints plus GPTQ and AWQ quantized variants.
register_model(
    ModelMeta(
        LLMModelType.qwen2,
        [
            # qwen1.5
            ModelGroup([
                # chat
                Model('Qwen/Qwen1.5-0.5B-Chat', 'Qwen/Qwen1.5-0.5B-Chat'),
                Model('Qwen/Qwen1.5-1.8B-Chat', 'Qwen/Qwen1.5-1.8B-Chat'),
                Model('Qwen/Qwen1.5-4B-Chat', 'Qwen/Qwen1.5-4B-Chat'),
                Model('Qwen/Qwen1.5-7B-Chat', 'Qwen/Qwen1.5-7B-Chat'),
                Model('Qwen/Qwen1.5-14B-Chat', 'Qwen/Qwen1.5-14B-Chat'),
                Model('Qwen/Qwen1.5-32B-Chat', 'Qwen/Qwen1.5-32B-Chat'),
                Model('Qwen/Qwen1.5-72B-Chat', 'Qwen/Qwen1.5-72B-Chat'),
                Model('Qwen/Qwen1.5-110B-Chat', 'Qwen/Qwen1.5-110B-Chat'),
                # base
                Model('Qwen/Qwen1.5-0.5B', 'Qwen/Qwen1.5-0.5B'),
                Model('Qwen/Qwen1.5-1.8B', 'Qwen/Qwen1.5-1.8B'),
                Model('Qwen/Qwen1.5-4B', 'Qwen/Qwen1.5-4B'),
                Model('Qwen/Qwen1.5-7B', 'Qwen/Qwen1.5-7B'),
                Model('Qwen/Qwen1.5-14B', 'Qwen/Qwen1.5-14B'),
                Model('Qwen/Qwen1.5-32B', 'Qwen/Qwen1.5-32B'),
                Model('Qwen/Qwen1.5-72B', 'Qwen/Qwen1.5-72B'),
                Model('Qwen/Qwen1.5-110B', 'Qwen/Qwen1.5-110B'),
                # gptq-int4
                Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-32B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-32B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int4'),
                Model('Qwen/Qwen1.5-110B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-110B-Chat-GPTQ-Int4'),
                # gptq-int8
                Model('Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8'),
                Model('Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8'),
                Model('Qwen/Qwen1.5-4B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-4B-Chat-GPTQ-Int8'),
                Model('Qwen/Qwen1.5-7B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-7B-Chat-GPTQ-Int8'),
                Model('Qwen/Qwen1.5-14B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-14B-Chat-GPTQ-Int8'),
                Model('Qwen/Qwen1.5-72B-Chat-GPTQ-Int8', 'Qwen/Qwen1.5-72B-Chat-GPTQ-Int8'),
                # awq-int4
                Model('Qwen/Qwen1.5-0.5B-Chat-AWQ', 'Qwen/Qwen1.5-0.5B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-1.8B-Chat-AWQ', 'Qwen/Qwen1.5-1.8B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-4B-Chat-AWQ', 'Qwen/Qwen1.5-4B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-7B-Chat-AWQ', 'Qwen/Qwen1.5-7B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-14B-Chat-AWQ', 'Qwen/Qwen1.5-14B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-32B-Chat-AWQ', 'Qwen/Qwen1.5-32B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-72B-Chat-AWQ', 'Qwen/Qwen1.5-72B-Chat-AWQ'),
                Model('Qwen/Qwen1.5-110B-Chat-AWQ', 'Qwen/Qwen1.5-110B-Chat-AWQ'),
            ]),
            # code-qwen1.5
            ModelGroup([
                Model('Qwen/CodeQwen1.5-7B', 'Qwen/CodeQwen1.5-7B'),
                Model('Qwen/CodeQwen1.5-7B-Chat', 'Qwen/CodeQwen1.5-7B-Chat'),
                Model('Qwen/CodeQwen1.5-7B-Chat-AWQ', 'Qwen/CodeQwen1.5-7B-Chat-AWQ'),
            ],
                       tags=['coding']),
            # qwen2
            ModelGroup([
                # instruct
                Model('Qwen/Qwen2-0.5B-Instruct', 'Qwen/Qwen2-0.5B-Instruct'),
                Model('Qwen/Qwen2-1.5B-Instruct', 'Qwen/Qwen2-1.5B-Instruct'),
                Model('Qwen/Qwen2-7B-Instruct', 'Qwen/Qwen2-7B-Instruct'),
                Model('Qwen/Qwen2-72B-Instruct', 'Qwen/Qwen2-72B-Instruct'),
                # base
                Model('Qwen/Qwen2-0.5B', 'Qwen/Qwen2-0.5B'),
                Model('Qwen/Qwen2-1.5B', 'Qwen/Qwen2-1.5B'),
                Model('Qwen/Qwen2-7B', 'Qwen/Qwen2-7B'),
                Model('Qwen/Qwen2-72B', 'Qwen/Qwen2-72B'),
                # gptq-int4
                Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4'),
                Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4'),
                Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int4'),
                Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int4'),
                # gptq-int8
                Model('Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8'),
                Model('Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8'),
                Model('Qwen/Qwen2-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-7B-Instruct-GPTQ-Int8'),
                Model('Qwen/Qwen2-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-72B-Instruct-GPTQ-Int8'),
                # awq-int4
                Model('Qwen/Qwen2-0.5B-Instruct-AWQ', 'Qwen/Qwen2-0.5B-Instruct-AWQ'),
                Model('Qwen/Qwen2-1.5B-Instruct-AWQ', 'Qwen/Qwen2-1.5B-Instruct-AWQ'),
                Model('Qwen/Qwen2-7B-Instruct-AWQ', 'Qwen/Qwen2-7B-Instruct-AWQ'),
                Model('Qwen/Qwen2-72B-Instruct-AWQ', 'Qwen/Qwen2-72B-Instruct-AWQ'),
            ]),
            # qwen2-math
            ModelGroup(
                [
                    # instruct
                    Model('Qwen/Qwen2-Math-1.5B-Instruct', 'Qwen/Qwen2-Math-1.5B-Instruct'),
                    Model('Qwen/Qwen2-Math-7B-Instruct', 'Qwen/Qwen2-Math-7B-Instruct'),
                    Model('Qwen/Qwen2-Math-72B-Instruct', 'Qwen/Qwen2-Math-72B-Instruct'),
                    # base
                    Model('Qwen/Qwen2-Math-1.5B', 'Qwen/Qwen2-Math-1.5B'),
                    Model('Qwen/Qwen2-Math-7B', 'Qwen/Qwen2-Math-7B'),
                    Model('Qwen/Qwen2-Math-72B', 'Qwen/Qwen2-Math-72B'),
                ],
                tags=['math']),
            # qwen2.5-1m
            ModelGroup([
                Model('Qwen/Qwen2.5-7B-Instruct-1M', 'Qwen/Qwen2.5-7B-Instruct-1M'),
                Model('Qwen/Qwen2.5-14B-Instruct-1M', 'Qwen/Qwen2.5-14B-Instruct-1M'),
            ]),
            # other
            ModelGroup([Model('PowerInfer/SmallThinker-3B-Preview', 'PowerInfer/SmallThinker-3B-Preview')]),
        ],
        TemplateType.qwen,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen2ForCausalLM'],
        requires=['transformers>=4.37'],
        model_arch=ModelArch.llama))
|
| 352 |
+
|
| 353 |
+
register_model(
|
| 354 |
+
ModelMeta(
|
| 355 |
+
LLMModelType.qwen2_5,
|
| 356 |
+
[
|
| 357 |
+
# qwen2.5
|
| 358 |
+
ModelGroup([
|
| 359 |
+
# instruct
|
| 360 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct', 'Qwen/Qwen2.5-0.5B-Instruct'),
|
| 361 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct', 'Qwen/Qwen2.5-1.5B-Instruct'),
|
| 362 |
+
Model('Qwen/Qwen2.5-3B-Instruct', 'Qwen/Qwen2.5-3B-Instruct'),
|
| 363 |
+
Model('Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-7B-Instruct'),
|
| 364 |
+
Model('Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-14B-Instruct'),
|
| 365 |
+
Model('Qwen/Qwen2.5-32B-Instruct', 'Qwen/Qwen2.5-32B-Instruct'),
|
| 366 |
+
Model('Qwen/Qwen2.5-72B-Instruct', 'Qwen/Qwen2.5-72B-Instruct'),
|
| 367 |
+
# base
|
| 368 |
+
Model('Qwen/Qwen2.5-0.5B', 'Qwen/Qwen2.5-0.5B'),
|
| 369 |
+
Model('Qwen/Qwen2.5-1.5B', 'Qwen/Qwen2.5-1.5B'),
|
| 370 |
+
Model('Qwen/Qwen2.5-3B', 'Qwen/Qwen2.5-3B'),
|
| 371 |
+
Model('Qwen/Qwen2.5-7B', 'Qwen/Qwen2.5-7B'),
|
| 372 |
+
Model('Qwen/Qwen2.5-14B', 'Qwen/Qwen2.5-14B'),
|
| 373 |
+
Model('Qwen/Qwen2.5-32B', 'Qwen/Qwen2.5-32B'),
|
| 374 |
+
Model('Qwen/Qwen2.5-72B', 'Qwen/Qwen2.5-72B'),
|
| 375 |
+
# gptq-int4
|
| 376 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4'),
|
| 377 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4'),
|
| 378 |
+
Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4'),
|
| 379 |
+
Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4'),
|
| 380 |
+
Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4'),
|
| 381 |
+
Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4'),
|
| 382 |
+
Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4'),
|
| 383 |
+
# gptq-int8
|
| 384 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8'),
|
| 385 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8'),
|
| 386 |
+
Model('Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8'),
|
| 387 |
+
Model('Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8'),
|
| 388 |
+
Model('Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8'),
|
| 389 |
+
Model('Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8'),
|
| 390 |
+
Model('Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8'),
|
| 391 |
+
# awq-int4
|
| 392 |
+
Model('Qwen/Qwen2.5-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-0.5B-Instruct-AWQ'),
|
| 393 |
+
Model('Qwen/Qwen2.5-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-1.5B-Instruct-AWQ'),
|
| 394 |
+
Model('Qwen/Qwen2.5-3B-Instruct-AWQ', 'Qwen/Qwen2.5-3B-Instruct-AWQ'),
|
| 395 |
+
Model('Qwen/Qwen2.5-7B-Instruct-AWQ', 'Qwen/Qwen2.5-7B-Instruct-AWQ'),
|
| 396 |
+
Model('Qwen/Qwen2.5-14B-Instruct-AWQ', 'Qwen/Qwen2.5-14B-Instruct-AWQ'),
|
| 397 |
+
Model('Qwen/Qwen2.5-32B-Instruct-AWQ', 'Qwen/Qwen2.5-32B-Instruct-AWQ'),
|
| 398 |
+
Model('Qwen/Qwen2.5-72B-Instruct-AWQ', 'Qwen/Qwen2.5-72B-Instruct-AWQ'),
|
| 399 |
+
]),
|
| 400 |
+
# qwen2.5-coder
|
| 401 |
+
ModelGroup(
|
| 402 |
+
[
|
| 403 |
+
# instruct
|
| 404 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct', 'Qwen/Qwen2.5-Coder-0.5B-Instruct'),
|
| 405 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct', 'Qwen/Qwen2.5-Coder-1.5B-Instruct'),
|
| 406 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct', 'Qwen/Qwen2.5-Coder-3B-Instruct'),
|
| 407 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct', 'Qwen/Qwen2.5-Coder-7B-Instruct'),
|
| 408 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct', 'Qwen/Qwen2.5-Coder-14B-Instruct'),
|
| 409 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct', 'Qwen/Qwen2.5-Coder-32B-Instruct'),
|
| 410 |
+
# base
|
| 411 |
+
Model('Qwen/Qwen2.5-Coder-0.5B', 'Qwen/Qwen2.5-Coder-0.5B'),
|
| 412 |
+
Model('Qwen/Qwen2.5-Coder-1.5B', 'Qwen/Qwen2.5-Coder-1.5B'),
|
| 413 |
+
Model('Qwen/Qwen2.5-Coder-3B', 'Qwen/Qwen2.5-Coder-3B'),
|
| 414 |
+
Model('Qwen/Qwen2.5-Coder-7B', 'Qwen/Qwen2.5-Coder-7B'),
|
| 415 |
+
Model('Qwen/Qwen2.5-Coder-14B', 'Qwen/Qwen2.5-Coder-14B'),
|
| 416 |
+
Model('Qwen/Qwen2.5-Coder-32B', 'Qwen/Qwen2.5-Coder-32B'),
|
| 417 |
+
# AWQ
|
| 418 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ'),
|
| 419 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ'),
|
| 420 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-3B-Instruct-AWQ'),
|
| 421 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-7B-Instruct-AWQ'),
|
| 422 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-14B-Instruct-AWQ'),
|
| 423 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct-AWQ', 'Qwen/Qwen2.5-Coder-32B-Instruct-AWQ'),
|
| 424 |
+
# GPTQ
|
| 425 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4'),
|
| 426 |
+
Model('Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8'),
|
| 427 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4'),
|
| 428 |
+
Model('Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8'),
|
| 429 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4'),
|
| 430 |
+
Model('Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8'),
|
| 431 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4'),
|
| 432 |
+
Model('Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8'),
|
| 433 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4'),
|
| 434 |
+
Model('Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8'),
|
| 435 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4'),
|
| 436 |
+
Model('Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8', 'Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8'),
|
| 437 |
+
],
|
| 438 |
+
tags=['coding'])
|
| 439 |
+
],
|
| 440 |
+
TemplateType.qwen2_5,
|
| 441 |
+
get_model_tokenizer_with_flash_attn,
|
| 442 |
+
architectures=['Qwen2ForCausalLM'],
|
| 443 |
+
requires=['transformers>=4.37'],
|
| 444 |
+
model_arch=ModelArch.llama))
|
| 445 |
+
|
| 446 |
+
# Qwen2.5-Math: math-specialized dense LLMs (instruct + base variants).
register_model(
    ModelMeta(
        LLMModelType.qwen2_5_math,
        [
            # qwen2.5-math
            ModelGroup(
                [
                    # instruct
                    Model('Qwen/Qwen2.5-Math-1.5B-Instruct', 'Qwen/Qwen2.5-Math-1.5B-Instruct'),
                    Model('Qwen/Qwen2.5-Math-7B-Instruct', 'Qwen/Qwen2.5-Math-7B-Instruct'),
                    Model('Qwen/Qwen2.5-Math-72B-Instruct', 'Qwen/Qwen2.5-Math-72B-Instruct'),
                    # base
                    Model('Qwen/Qwen2.5-Math-1.5B', 'Qwen/Qwen2.5-Math-1.5B'),
                    Model('Qwen/Qwen2.5-Math-7B', 'Qwen/Qwen2.5-Math-7B'),
                    Model('Qwen/Qwen2.5-Math-72B', 'Qwen/Qwen2.5-Math-72B'),
                ],
                tags=['math']),
        ],
        TemplateType.qwen2_5_math,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen2ForCausalLM'],
        requires=['transformers>=4.37'],
        model_arch=ModelArch.llama))

# Qwen MoE models (Qwen1.5-MoE and Qwen2-57B-A14B) share the qwen template.
register_model(
    ModelMeta(
        LLMModelType.qwen2_moe,
        [
            # qwen1.5-moe
            ModelGroup([
                Model('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'Qwen/Qwen1.5-MoE-A2.7B-Chat'),
                Model('Qwen/Qwen1.5-MoE-A2.7B', 'Qwen/Qwen1.5-MoE-A2.7B'),
                Model('Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4', 'Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4'),
            ]),
            ModelGroup([
                Model('Qwen/Qwen2-57B-A14B-Instruct', 'Qwen/Qwen2-57B-A14B-Instruct'),
                Model('Qwen/Qwen2-57B-A14B', 'Qwen/Qwen2-57B-A14B'),
                Model('Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4'),
            ])
        ],
        TemplateType.qwen,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen2MoeForCausalLM'],
        requires=['transformers>=4.40'],
    ))

# Qwen3 dense models: base, instruct, FP8 and AWQ quantized checkpoints.
register_model(
    ModelMeta(
        LLMModelType.qwen3,
        [
            ModelGroup([
                Model('Qwen/Qwen3-0.6B-Base', 'Qwen/Qwen3-0.6B-Base'),
                Model('Qwen/Qwen3-1.7B-Base', 'Qwen/Qwen3-1.7B-Base'),
                Model('Qwen/Qwen3-4B-Base', 'Qwen/Qwen3-4B-Base'),
                Model('Qwen/Qwen3-8B-Base', 'Qwen/Qwen3-8B-Base'),
                Model('Qwen/Qwen3-14B-Base', 'Qwen/Qwen3-14B-Base'),
                Model('Qwen/Qwen3-32B-Base', 'Qwen/Qwen3-32B-Base'),
                # instruct
                Model('Qwen/Qwen3-0.6B', 'Qwen/Qwen3-0.6B'),
                Model('Qwen/Qwen3-1.7B', 'Qwen/Qwen3-1.7B'),
                Model('Qwen/Qwen3-4B', 'Qwen/Qwen3-4B'),
                Model('Qwen/Qwen3-8B', 'Qwen/Qwen3-8B'),
                Model('Qwen/Qwen3-14B', 'Qwen/Qwen3-14B'),
                Model('Qwen/Qwen3-32B', 'Qwen/Qwen3-32B'),
                # fp8
                Model('Qwen/Qwen3-0.6B-FP8', 'Qwen/Qwen3-0.6B-FP8'),
                Model('Qwen/Qwen3-1.7B-FP8', 'Qwen/Qwen3-1.7B-FP8'),
                Model('Qwen/Qwen3-4B-FP8', 'Qwen/Qwen3-4B-FP8'),
                Model('Qwen/Qwen3-8B-FP8', 'Qwen/Qwen3-8B-FP8'),
                Model('Qwen/Qwen3-14B-FP8', 'Qwen/Qwen3-14B-FP8'),
                Model('Qwen/Qwen3-32B-FP8', 'Qwen/Qwen3-32B-FP8'),
                # awq
                Model('Qwen/Qwen3-4B-AWQ', 'Qwen/Qwen3-4B-AWQ'),
                Model('Qwen/Qwen3-8B-AWQ', 'Qwen/Qwen3-8B-AWQ'),
                Model('Qwen/Qwen3-14B-AWQ', 'Qwen/Qwen3-14B-AWQ'),
                Model('Qwen/Qwen3-32B-AWQ', 'Qwen/Qwen3-32B-AWQ'),
                # swift
                Model('swift/Qwen3-32B-AWQ'),
            ]),
        ],
        TemplateType.qwen3,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen3ForCausalLM'],
        requires=['transformers>=4.51'],
        model_arch=ModelArch.llama))

# Qwen3 MoE models; AWQ checkpoints come from a community mirror (cognitivecomputations).
register_model(
    ModelMeta(
        LLMModelType.qwen3_moe,
        [
            ModelGroup([
                Model('Qwen/Qwen3-30B-A3B-Base', 'Qwen/Qwen3-30B-A3B-Base'),
                Model('Qwen/Qwen3-235B-A22B-Base', 'Qwen/Qwen3-235B-A22B-Base'),
                # instruct
                Model('Qwen/Qwen3-30B-A3B', 'Qwen/Qwen3-30B-A3B'),
                Model('Qwen/Qwen3-235B-A22B', 'Qwen/Qwen3-235B-A22B'),
                # fp8
                Model('Qwen/Qwen3-30B-A3B-FP8', 'Qwen/Qwen3-30B-A3B-FP8'),
                Model('Qwen/Qwen3-235B-A22B-FP8', 'Qwen/Qwen3-235B-A22B-FP8'),
                # awq
                Model('swift/Qwen3-30B-A3B-AWQ', 'cognitivecomputations/Qwen3-30B-A3B-AWQ'),
                Model('swift/Qwen3-235B-A22B-AWQ', 'cognitivecomputations/Qwen3-235B-A22B-AWQ'),
            ]),
        ],
        TemplateType.qwen3,
        get_model_tokenizer_with_flash_attn,
        architectures=['Qwen3MoeForCausalLM'],
        requires=['transformers>=4.51'],
    ))
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def patch_qwen_vl_utils(vision_process):
    """Patch the qwen_vl_utils ``vision_process`` module in place (idempotent).

    Two changes are applied:
    1. Every module-level pixel/frame constant is overridden from the
       environment via ``get_env_args`` (``fps`` is parsed as float, the
       rest as int).
    2. The 'decord' video reader backend is wrapped so that the video path
       is first resolved through ``swift.llm.load_file`` (e.g. to fetch
       remote files) before decoding.
    """
    if hasattr(vision_process, '_patch'):
        # Already patched once; never wrap the reader twice.
        return
    _env_keys = (
        'image_factor', 'min_pixels', 'max_pixels', 'max_ratio', 'video_min_pixels', 'video_max_pixels',
        'video_total_pixels', 'frame_factor', 'fps', 'fps_min_frames', 'fps_max_frames')
    for env_key in _env_keys:
        attr_name = env_key.upper()
        caster = float if env_key == 'fps' else int
        default = getattr(vision_process, attr_name)
        setattr(vision_process, attr_name, get_env_args(env_key, caster, default))

    original_reader = vision_process._read_video_decord

    def _read_video_decord_with_load(ele: dict):
        from swift.llm import load_file
        # Resolve the video reference (e.g. URL/base64) to a local file first.
        ele['video'] = load_file(ele['video'])
        return original_reader(ele)

    vision_process.VIDEO_READER_BACKENDS['decord'] = _read_video_decord_with_load
    vision_process._patch = True
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def get_model_tokenizer_qwen2_vl(*args, **kwargs):
    """Load a Qwen2-VL model + processor.

    Defaults ``automodel_class`` to ``Qwen2VLForConditionalGeneration`` when the
    caller did not supply one, patches the text embedding layer for gradient
    checkpointing / multi-device training, and applies the qwen_vl_utils
    environment patch.

    Returns:
        (model, tokenizer/processor) tuple; ``model`` is None when
        ``load_model`` is falsy upstream.
    """
    from transformers import Qwen2VLForConditionalGeneration
    # Fix: use .get() so a caller that omits 'automodel_class' does not raise KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or Qwen2VLForConditionalGeneration
    model, tokenizer = get_model_tokenizer_multimodal(*args, **kwargs)
    if model is not None and hasattr(model.model, 'embed_tokens'):
        # Clone embedding output (avoids in-place grad issues) and keep it on the input device.
        patch_output_clone(model.model.embed_tokens)
        patch_output_to_input_device(model.model.embed_tokens)

    from qwen_vl_utils import vision_process
    patch_qwen_vl_utils(vision_process)
    return model, tokenizer
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# Qwen2-VL vision-language models plus derivatives (UI-TARS GUI agents, olmOCR).
register_model(
    ModelMeta(
        MLLMModelType.qwen2_vl,
        [
            ModelGroup(
                [
                    # chat
                    Model('Qwen/Qwen2-VL-2B-Instruct', 'Qwen/Qwen2-VL-2B-Instruct'),
                    Model('Qwen/Qwen2-VL-7B-Instruct', 'Qwen/Qwen2-VL-7B-Instruct'),
                    Model('Qwen/Qwen2-VL-72B-Instruct', 'Qwen/Qwen2-VL-72B-Instruct'),
                    # base
                    Model('Qwen/Qwen2-VL-2B', 'Qwen/Qwen2-VL-2B'),
                    Model('Qwen/Qwen2-VL-7B', 'Qwen/Qwen2-VL-7B'),
                    Model('Qwen/Qwen2-VL-72B', 'Qwen/Qwen2-VL-72B'),
                    # gptq-int4
                    Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4'),
                    Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4'),
                    Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4'),
                    # gptq-int8
                    Model('Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8'),
                    Model('Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8'),
                    Model('Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8', 'Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8'),
                    # awq-int4
                    Model('Qwen/Qwen2-VL-2B-Instruct-AWQ', 'Qwen/Qwen2-VL-2B-Instruct-AWQ'),
                    Model('Qwen/Qwen2-VL-7B-Instruct-AWQ', 'Qwen/Qwen2-VL-7B-Instruct-AWQ'),
                    Model('Qwen/Qwen2-VL-72B-Instruct-AWQ', 'Qwen/Qwen2-VL-72B-Instruct-AWQ'),
                ], ),
            # UI-TARS: GUI-agent fine-tunes built on Qwen2-VL.
            ModelGroup([
                Model('bytedance-research/UI-TARS-2B-SFT', 'bytedance-research/UI-TARS-2B-SFT'),
                Model('bytedance-research/UI-TARS-7B-SFT', 'bytedance-research/UI-TARS-7B-SFT'),
                Model('bytedance-research/UI-TARS-7B-DPO', 'bytedance-research/UI-TARS-7B-DPO'),
                Model('bytedance-research/UI-TARS-72B-SFT', 'bytedance-research/UI-TARS-72B-SFT'),
                Model('bytedance-research/UI-TARS-72B-DPO', 'bytedance-research/UI-TARS-72B-DPO'),
            ]),
            # olmOCR: OCR fine-tune built on Qwen2-VL.
            ModelGroup([
                Model('allenai/olmOCR-7B-0225-preview', 'allenai/olmOCR-7B-0225-preview'),
            ]),
        ],
        TemplateType.qwen2_vl,
        get_model_tokenizer_qwen2_vl,
        model_arch=ModelArch.qwen2_vl,
        architectures=['Qwen2VLForConditionalGeneration'],
        requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
        tags=['vision', 'video']))

# QVQ: visual-reasoning preview model; reuses the Qwen2-VL loader/arch.
register_model(
    ModelMeta(
        MLLMModelType.qvq, [
            ModelGroup([
                Model('Qwen/QVQ-72B-Preview', 'Qwen/QVQ-72B-Preview'),
            ]),
        ],
        TemplateType.qvq,
        get_model_tokenizer_qwen2_vl,
        model_arch=ModelArch.qwen2_vl,
        architectures=['Qwen2VLForConditionalGeneration'],
        requires=['transformers>=4.45', 'qwen_vl_utils>=0.0.6', 'decord'],
        tags=['vision', 'video']))
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def get_model_tokenizer_qwen2_5_vl(*args, **kwargs):
    """Load a Qwen2.5-VL model by delegating to the Qwen2-VL loader with the
    Qwen2.5-specific auto-model class as the default."""
    from transformers import Qwen2_5_VLForConditionalGeneration
    # Fix: use .get() so a caller that omits 'automodel_class' does not raise KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or Qwen2_5_VLForConditionalGeneration
    return get_model_tokenizer_qwen2_vl(*args, **kwargs)
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
# Qwen2.5-VL vision-language models (bf16 + AWQ quantized variants).
register_model(
    ModelMeta(
        MLLMModelType.qwen2_5_vl, [
            ModelGroup([
                Model('Qwen/Qwen2.5-VL-3B-Instruct', 'Qwen/Qwen2.5-VL-3B-Instruct'),
                Model('Qwen/Qwen2.5-VL-7B-Instruct', 'Qwen/Qwen2.5-VL-7B-Instruct'),
                Model('Qwen/Qwen2.5-VL-32B-Instruct', 'Qwen/Qwen2.5-VL-32B-Instruct'),
                Model('Qwen/Qwen2.5-VL-72B-Instruct', 'Qwen/Qwen2.5-VL-72B-Instruct'),
            ]),
            ModelGroup([
                Model('Qwen/Qwen2.5-VL-3B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-3B-Instruct-AWQ'),
                Model('Qwen/Qwen2.5-VL-7B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-7B-Instruct-AWQ'),
                Model('Qwen/Qwen2.5-VL-32B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-32B-Instruct-AWQ'),
                Model('Qwen/Qwen2.5-VL-72B-Instruct-AWQ', 'Qwen/Qwen2.5-VL-72B-Instruct-AWQ'),
            ]),
        ],
        TemplateType.qwen2_5_vl,
        get_model_tokenizer_qwen2_5_vl,
        model_arch=ModelArch.qwen2_vl,
        architectures=['Qwen2_5_VLForConditionalGeneration'],
        requires=['transformers>=4.49', 'qwen_vl_utils>=0.0.6', 'decord'],
        tags=['vision', 'video']))
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def get_model_tokenizer_qwen2_5_omni(model_dir, *args, **kwargs):
    """Load a Qwen2.5-Omni (vision+audio) model and its processor.

    Builds the processor/config explicitly, patches qwen_omni_utils the same
    way as qwen_vl_utils, and routes generate/forward through the 'thinker'
    submodel. ``enable_audio_output`` can be toggled via the
    ``ENABLE_AUDIO_OUTPUT`` environment variable (default True).

    Returns:
        (model, processor); the processor replaces the tokenizer in the tuple.
    """
    from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniConfig
    from qwen_omni_utils import vision_process
    # Fix: use .get() so a caller that omits 'automodel_class' does not raise KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or Qwen2_5OmniForConditionalGeneration
    processor = Qwen2_5OmniProcessor.from_pretrained(model_dir, trust_remote_code=True)
    kwargs['tokenizer'] = processor.tokenizer
    kwargs['model_config'] = Qwen2_5OmniConfig.from_pretrained(model_dir, trust_remote_code=True)
    patch_qwen_vl_utils(vision_process)
    kwargs['model_config'].enable_audio_output = get_env_args('ENABLE_AUDIO_OUTPUT', bool, True)
    model, _ = get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
    if model:
        # Delegate generate/forward to the 'thinker' text submodel.
        use_submodel_func(model, 'thinker')
        model.config.keys_to_ignore_at_inference += ['hidden_states', 'attention_mask']
        # talker pad_token_id must be unset; presumably it conflicts with generation — confirm upstream.
        model.config.talker_config.pad_token_id = None
    return model, processor
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
# Qwen2.5-Omni: multimodal (vision/video/audio in, optional audio out).
# spk_dict.pt holds speaker voices for audio output and must be saved with checkpoints.
register_model(
    ModelMeta(
        MLLMModelType.qwen2_5_omni,
        [
            ModelGroup([
                Model('Qwen/Qwen2.5-Omni-3B', 'Qwen/Qwen2.5-Omni-3B'),
                Model('Qwen/Qwen2.5-Omni-7B', 'Qwen/Qwen2.5-Omni-7B'),
            ]),
        ],
        TemplateType.qwen2_5_omni,
        get_model_tokenizer_qwen2_5_omni,
        model_arch=ModelArch.qwen2_5_omni,
        architectures=['Qwen2_5OmniModel'],
        requires=['transformers>=4.50', 'soundfile', 'qwen_omni_utils', 'decord'],
        tags=['vision', 'video', 'audio'],
        additional_saved_files=['spk_dict.pt'],
        # Empty ignore_patterns: download every file (audio assets included).
        ignore_patterns=[],
    ))
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
def get_model_tokenizer_qwen2_audio(*args, **kwargs):
    """Load a Qwen2-Audio model, defaulting the auto-model class to
    ``Qwen2AudioForConditionalGeneration`` when the caller supplied none."""
    from transformers import Qwen2AudioForConditionalGeneration
    # Fix: use .get() so a caller that omits 'automodel_class' does not raise KeyError.
    kwargs['automodel_class'] = kwargs.get('automodel_class') or Qwen2AudioForConditionalGeneration
    return get_model_tokenizer_multimodal(*args, **kwargs)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
# Qwen2-Audio: audio-language models (transformers pinned <4.49 for compatibility).
register_model(
    ModelMeta(
        MLLMModelType.qwen2_audio,
        [
            ModelGroup([
                Model('Qwen/Qwen2-Audio-7B-Instruct', 'Qwen/Qwen2-Audio-7B-Instruct'),
                Model('Qwen/Qwen2-Audio-7B', 'Qwen/Qwen2-Audio-7B'),
            ]),
        ],
        TemplateType.qwen2_audio,
        get_model_tokenizer_qwen2_audio,
        model_arch=ModelArch.qwen2_audio,
        architectures=['Qwen2AudioForConditionalGeneration'],
        requires=['transformers>=4.45,<4.49', 'librosa'],
        tags=['audio'],
    ))

# Marco-o1: reasoning model on the Qwen2 architecture.
register_model(
    ModelMeta(
        LLMModelType.marco_o1, [ModelGroup([Model('AIDC-AI/Marco-o1', 'AIDC-AI/Marco-o1')])],
        TemplateType.marco_o1,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['Qwen2ForCausalLM'],
        requires=['transformers>=4.37']))

# QwQ-Preview: early reasoning preview, distinct template from the QwQ release.
register_model(
    ModelMeta(
        LLMModelType.qwq_preview, [ModelGroup([Model('Qwen/QwQ-32B-Preview', 'Qwen/QwQ-32B-Preview')])],
        TemplateType.qwq_preview,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['Qwen2ForCausalLM'],
        requires=['transformers>=4.37']))

# QwQ: released reasoning model (bf16 + AWQ).
register_model(
    ModelMeta(
        LLMModelType.qwq,
        [ModelGroup([
            Model('Qwen/QwQ-32B', 'Qwen/QwQ-32B'),
            Model('Qwen/QwQ-32B-AWQ', 'Qwen/QwQ-32B-AWQ'),
        ])],
        TemplateType.qwq,
        get_model_tokenizer_with_flash_attn,
        model_arch=ModelArch.llama,
        architectures=['Qwen2ForCausalLM'],
        requires=['transformers>=4.37']))
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def get_model_tokenizer_ovis(*args, **kwargs):
    """Load an Ovis multimodal model.

    Routes the attention-implementation kwarg to Ovis's LLM-specific key,
    aligns the visual tower / embedding table dtypes with the LLM, delegates
    generate/forward to the inner 'llm' submodel, and (when available) patches
    transformers' HybridCache so cached key/value tensors follow the devices
    of incoming states under multi-GPU device_map.
    """
    # Ovis exposes the attention impl under 'llm_attn_implementation', not the default key.
    kwargs['attn_impl_keys'] = ['llm_attn_implementation']
    model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
    if model is not None:
        # Visual tower and visual-token embedding must match the LLM dtype.
        model.visual_tokenizer.to(model.dtype)
        model.vte.to(model.dtype)

        # Disable static/hybrid cache selection; Ovis generation needs the default cache.
        model.generation_config.cache_implementation = None
        func_list = ['generate', 'forward', 'get_input_embeddings']
        use_submodel_func(model, 'llm', func_list)
        embedding = model.get_input_embeddings()
        # Clone embedding outputs to avoid in-place autograd issues during training.
        patch_output_clone(embedding)
    try:
        # fix device_map
        from transformers.cache_utils import HybridCache

        def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
                   **kwargs) -> Tuple[torch.Tensor]:
            # Move per-layer cache tensors to the device of the incoming states
            # before delegating to the original implementation.
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)

        # Patch only once per process: keep the original under _update_origin.
        if not hasattr(HybridCache, '_update_origin'):
            HybridCache._update_origin = HybridCache.update
            HybridCache.update = update
    except ImportError:
        # Older transformers without HybridCache: nothing to patch.
        pass
    return model, tokenizer
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
# Ovis1.6 (Gemma2 backbone) vision-language models.
register_model(
    ModelMeta(
        MLLMModelType.ovis1_6,
        [
            ModelGroup([
                Model('AIDC-AI/Ovis1.6-Gemma2-9B', 'AIDC-AI/Ovis1.6-Gemma2-9B'),
                Model('AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4', 'AIDC-AI/Ovis1.6-Gemma2-9B-GPTQ-Int4'),
                Model('AIDC-AI/Ovis1.6-Gemma2-27B', 'AIDC-AI/Ovis1.6-Gemma2-27B'),
            ]),
        ],
        TemplateType.ovis1_6,
        get_model_tokenizer_ovis,
        model_arch=ModelArch.ovis1_6,
        architectures=['Ovis'],
        tags=['vision'],
        requires=['transformers>=4.42'],
    ))

# Ovis1.6 (Llama3.2 backbone) variant; shares the Ovis loader/arch.
register_model(
    ModelMeta(
        MLLMModelType.ovis1_6_llama3,
        [
            ModelGroup([
                Model('AIDC-AI/Ovis1.6-Llama3.2-3B', 'AIDC-AI/Ovis1.6-Llama3.2-3B'),
            ]),
        ],
        TemplateType.ovis1_6_llama3,
        get_model_tokenizer_ovis,
        model_arch=ModelArch.ovis1_6,
        architectures=['Ovis'],
        tags=['vision'],
    ))

# Ovis2 family (1B-34B); needs moviepy<2 for video preprocessing.
register_model(
    ModelMeta(
        MLLMModelType.ovis2,
        [
            ModelGroup([
                Model('AIDC-AI/Ovis2-1B', 'AIDC-AI/Ovis2-1B'),
                Model('AIDC-AI/Ovis2-2B', 'AIDC-AI/Ovis2-2B'),
                Model('AIDC-AI/Ovis2-4B', 'AIDC-AI/Ovis2-4B'),
                Model('AIDC-AI/Ovis2-8B', 'AIDC-AI/Ovis2-8B'),
                Model('AIDC-AI/Ovis2-16B', 'AIDC-AI/Ovis2-16B'),
                Model('AIDC-AI/Ovis2-34B', 'AIDC-AI/Ovis2-34B'),
            ]),
        ],
        TemplateType.ovis2,
        get_model_tokenizer_ovis,
        model_arch=ModelArch.ovis1_6,
        architectures=['Ovis'],
        tags=['vision'],
        requires=['transformers>=4.46.2', 'moviepy<2'],
    ))

# Qwen2-Math outcome reward model.
register_model(
    ModelMeta(
        RMModelType.qwen2_reward,
        [
            ModelGroup([
                Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'),
            ]),
        ],
        TemplateType.qwen,
        get_model_tokenizer_reward_model,
        architectures=['Qwen2ForRewardModel'],
        requires=['transformers>=4.37'],
    ))

# Qwen2.5-Math process reward models (per-step scoring; task_type='prm').
register_model(
    ModelMeta(
        RMModelType.qwen2_5_prm,
        [
            ModelGroup([
                Model('Qwen/Qwen2.5-Math-PRM-7B', 'Qwen/Qwen2.5-Math-PRM-7B'),
                Model('Qwen/Qwen2.5-Math-7B-PRM800K', 'Qwen/Qwen2.5-Math-7B-PRM800K'),
                Model('Qwen/Qwen2.5-Math-PRM-72B', 'Qwen/Qwen2.5-Math-PRM-72B'),
            ]),
        ],
        TemplateType.qwen2_5_math_prm,
        get_model_tokenizer_reward_model,
        task_type='prm',
        architectures=['Qwen2ForProcessRewardModel'],
        requires=['transformers>=4.37'],
    ))

# Qwen2.5-Math outcome reward model.
register_model(
    ModelMeta(
        RMModelType.qwen2_5_math_reward,
        [
            ModelGroup([
                Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'),
            ]),
        ],
        TemplateType.qwen2_5_math,
        get_model_tokenizer_reward_model,
        architectures=['Qwen2ForRewardModel'],
        requires=['transformers>=4.37'],
    ))
|
ms-swift/swift/llm/model/model/skywork.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
from swift.llm import TemplateType
|
| 6 |
+
from ..constant import LLMModelType, RMModelType
|
| 7 |
+
from ..model_arch import ModelArch
|
| 8 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 9 |
+
from ..utils import ModelInfo
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_skywork_model_tokenizer(model_dir: str,
                                model_info: ModelInfo,
                                model_kwargs: Dict[str, Any],
                                load_model: bool = True,
                                **kwargs):
    """Load a Skywork model/tokenizer.

    For chat checkpoints (detected by 'chat' in the model path) the dialogue
    marker tokens used by the Skywork template are registered with the
    tokenizer.

    Returns:
        (model, tokenizer) as produced by get_model_tokenizer_with_flash_attn.
    """
    model, tokenizer = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    if 'chat' in model_dir:
        # Batch the three special tokens into one add_tokens call
        # (equivalent to three separate calls, but a single vocab update).
        tokenizer.add_tokens(['[USER]', '[BOT]', '[SEP]'])
    return model, tokenizer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Skywork-13B base/chat; the chat variant exists only on ModelScope (single-id Model).
register_model(
    ModelMeta(
        LLMModelType.skywork,
        [
            ModelGroup([
                Model('skywork/Skywork-13B-base', 'skywork/Skywork-13B-base'),
                Model('skywork/Skywork-13B-chat'),
            ]),
        ],
        TemplateType.skywork,
        get_skywork_model_tokenizer,
        architectures=['SkyworkForCausalLM'],
        model_arch=ModelArch.llama,
    ))

# Skywork-o1: reasoning model on a Llama-3.1 backbone.
register_model(
    ModelMeta(
        LLMModelType.skywork_o1,
        [
            ModelGroup([
                Model('AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B', 'Skywork/Skywork-o1-Open-Llama-3.1-8B'),
            ]),
        ],
        TemplateType.skywork_o1,
        get_model_tokenizer_with_flash_attn,
        architectures=['LlamaForCausalLM'],
        requires=['transformers>=4.43'],
        model_arch=ModelArch.llama,
    ))

# Llama-3.x sequence-classification reward models (Skywork-Reward and GRM).
register_model(
    ModelMeta(
        RMModelType.llama3_2_reward,
        [
            ModelGroup([
                Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B', 'Skywork/Skywork-Reward-Llama-3.1-8B'),
                Model('AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2', 'Skywork/Skywork-Reward-Llama-3.1-8B-v0.2'),
            ]),
            ModelGroup([
                Model('AI-ModelScope/GRM_Llama3.1_8B_rewardmodel-ft', 'Ray2333/GRM_Llama3.1_8B_rewardmodel-ft'),
                Model('AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft', 'Ray2333/GRM-llama3.2-3B-rewardmodel-ft'),
            ])
        ],
        TemplateType.llama3_2,
        get_model_tokenizer_with_flash_attn,
        requires=['transformers>=4.43'],
        architectures=['LlamaForSequenceClassification'],
        model_arch=ModelArch.llama,
    ))

# Gemma-2 sequence-classification reward models (Skywork-Reward-Gemma).
register_model(
    ModelMeta(
        RMModelType.gemma_reward,
        [
            ModelGroup([
                Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B', 'Skywork/Skywork-Reward-Gemma-2-27B'),
                Model('AI-ModelScope/Skywork-Reward-Gemma-2-27B-v0.2', 'Skywork/Skywork-Reward-Gemma-2-27B-v0.2'),
            ]),
        ],
        TemplateType.gemma,
        get_model_tokenizer_with_flash_attn,
        requires=['transformers>=4.42'],
        architectures=['Gemma2ForSequenceClassification'],
        model_arch=ModelArch.llama,
    ))
|
ms-swift/swift/llm/model/model/yi.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from typing import Any, Dict
|
| 5 |
+
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from swift.llm import TemplateType
|
| 9 |
+
from swift.utils import get_logger
|
| 10 |
+
from ..constant import LLMModelType, MLLMModelType
|
| 11 |
+
from ..model_arch import ModelArch
|
| 12 |
+
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
|
| 13 |
+
from ..utils import ModelInfo, git_clone_github
|
| 14 |
+
|
| 15 |
+
logger = get_logger()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_model_tokenizer_yi(model_dir, *args, **kwargs):
    """Load a Yi model, forcing the slow (non-fast) tokenizer implementation."""
    slow_tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)
    kwargs['tokenizer'] = slow_tokenizer
    return get_model_tokenizer_with_flash_attn(model_dir, *args, **kwargs)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_model_tokenizer_yi_vl(model_dir: str,
                              model_info: ModelInfo,
                              model_kwargs: Dict[str, Any],
                              load_model: bool = True,
                              **kwargs):
    """Load a Yi-VL model.

    Clones the 01-ai/Yi repository (unless ``local_repo_path`` is supplied)
    to import its llava package, rewrites the vision-tower path in the config
    to point inside the local checkpoint directory, then loads through the
    Yi tokenizer helper and finishes vision-tower initialization.
    """
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/01-ai/Yi')
    # The llava package lives in the repo's VL/ subdirectory.
    sys.path.append(os.path.join(local_repo_path, 'VL'))
    from llava.model import LlavaLlamaForCausalLM, LlavaConfig
    from llava.model.constants import key_info

    model_config = LlavaConfig.from_pretrained(model_dir)
    mm_vision_tower = model_config.mm_vision_tower
    # Remap the (possibly remote) vision-tower reference onto the local
    # checkpoint layout: keep only the last two path components.
    model_config.mm_vision_tower = os.path.join(model_dir, *mm_vision_tower.rsplit('/', maxsplit=2)[-2:])
    model_config.attention_dropout = 0.
    key_info['model_path'] = model_dir
    kwargs['automodel_class'] = kwargs['automodel_class'] or LlavaLlamaForCausalLM
    model, tokenizer = get_model_tokenizer_yi(
        model_dir, model_info, model_kwargs, load_model, model_config=model_config, **kwargs)
    if model is not None:
        logger.info('Please ignore the above warning.')
        logger.info('Loading the parameters of vision_tower...')
        # Vocab was extended with image tokens; resize embeddings to match.
        model.resize_token_embeddings(len(tokenizer))
        vision_tower = model.get_vision_tower()
        vision_tower.load_model()
        vision_tower.to(device=model.device, dtype=model_config.torch_dtype)
        if not hasattr(model.config, 'max_sequence_length'):
            model.config.max_sequence_length = 2048
        # Expose the image processor through the tokenizer for template use.
        tokenizer.image_processor = vision_tower.image_processor
    return model, tokenizer
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Register the Yi-VL multimodal models; they use the custom Yi-VL loader
# defined above and require transformers>=4.34.
register_model(
    ModelMeta(
        MLLMModelType.yi_vl,
        [
            ModelGroup([
                Model('01ai/Yi-VL-6B', '01-ai/Yi-VL-6B'),
                Model('01ai/Yi-VL-34B', '01-ai/Yi-VL-34B'),
            ], ),
        ],
        TemplateType.yi_vl,
        get_model_tokenizer_yi_vl,
        model_arch=ModelArch.llava_llama,
        architectures=['LlavaLlamaForCausalLM'],
        requires=['transformers>=4.34'],
        tags=['vision'],
    ))
|
| 72 |
+
|
| 73 |
+
# Register the text-only Yi family: base Yi, Yi-1.5 and quantized Yi-1.5
# variants. All are Llama-architecture models served with the chatml template.
register_model(
    ModelMeta(
        LLMModelType.yi,
        [  # yi
            ModelGroup([
                Model('01ai/Yi-6B', '01-ai/Yi-6B'),
                Model('01ai/Yi-6B-200K', '01-ai/Yi-6B-200K'),
                Model('01ai/Yi-6B-Chat', '01-ai/Yi-6B-Chat'),
                Model('01ai/Yi-6B-Chat-4bits', '01-ai/Yi-6B-Chat-4bits'),
                Model('01ai/Yi-6B-Chat-8bits', '01-ai/Yi-6B-Chat-8bits'),
                Model('01ai/Yi-9B', '01-ai/Yi-9B'),
                Model('01ai/Yi-9B-200K', '01-ai/Yi-9B-200K'),
                Model('01ai/Yi-34B', '01-ai/Yi-34B'),
                Model('01ai/Yi-34B-200K', '01-ai/Yi-34B-200K'),
                Model('01ai/Yi-34B-Chat', '01-ai/Yi-34B-Chat'),
                Model('01ai/Yi-34B-Chat-4bits', '01-ai/Yi-34B-Chat-4bits'),
                Model('01ai/Yi-34B-Chat-8bits', '01-ai/Yi-34B-Chat-8bits'),
            ]),
            # yi1.5
            ModelGroup([
                Model('01ai/Yi-1.5-6B', '01-ai/Yi-1.5-6B'),
                Model('01ai/Yi-1.5-6B-Chat', '01-ai/Yi-1.5-6B-Chat'),
                Model('01ai/Yi-1.5-9B', '01-ai/Yi-1.5-9B'),
                Model('01ai/Yi-1.5-9B-Chat', '01-ai/Yi-1.5-9B-Chat'),
                Model('01ai/Yi-1.5-9B-Chat-16K', '01-ai/Yi-1.5-9B-Chat-16K'),
                Model('01ai/Yi-1.5-34B', '01-ai/Yi-1.5-34B'),
                Model('01ai/Yi-1.5-34B-Chat', '01-ai/Yi-1.5-34B-Chat'),
                Model('01ai/Yi-1.5-34B-Chat-16K', '01-ai/Yi-1.5-34B-Chat-16K'),
            ]),
            # yi1.5-quant
            ModelGroup([
                Model('AI-ModelScope/Yi-1.5-6B-Chat-GPTQ', 'modelscope/Yi-1.5-6B-Chat-GPTQ'),
                Model('AI-ModelScope/Yi-1.5-6B-Chat-AWQ', 'modelscope/Yi-1.5-6B-Chat-AWQ'),
                Model('AI-ModelScope/Yi-1.5-9B-Chat-GPTQ', 'modelscope/Yi-1.5-9B-Chat-GPTQ'),
                Model('AI-ModelScope/Yi-1.5-9B-Chat-AWQ', 'modelscope/Yi-1.5-9B-Chat-AWQ'),
                Model('AI-ModelScope/Yi-1.5-34B-Chat-GPTQ', 'modelscope/Yi-1.5-34B-Chat-GPTQ'),
                Model('AI-ModelScope/Yi-1.5-34B-Chat-AWQ', 'modelscope/Yi-1.5-34B-Chat-AWQ'),
            ]),
        ],
        TemplateType.chatml,
        get_model_tokenizer_yi,
        architectures=['LlamaForCausalLM'],
        model_arch=ModelArch.llama,
    ))
|
| 117 |
+
|
| 118 |
+
# Register the Yi-Coder code models (Llama architecture, coding tag).
register_model(
    ModelMeta(
        LLMModelType.yi_coder, [
            ModelGroup([
                Model('01ai/Yi-Coder-1.5B', '01-ai/Yi-Coder-1.5B'),
                Model('01ai/Yi-Coder-9B', '01-ai/Yi-Coder-9B'),
                Model('01ai/Yi-Coder-1.5B-Chat', '01-ai/Yi-Coder-1.5B-Chat'),
                Model('01ai/Yi-Coder-9B-Chat', '01-ai/Yi-Coder-9B-Chat'),
            ], )
        ],
        TemplateType.yi_coder,
        get_model_tokenizer_yi,
        architectures=['LlamaForCausalLM'],
        model_arch=ModelArch.llama,
        tags=['coding']))
|
| 133 |
+
|
| 134 |
+
# Register SUSTech's SUS-Chat-34B (a Yi-34B derivative), reusing the Yi
# tokenizer loader with its own chat template.
register_model(
    ModelMeta(
        LLMModelType.sus,
        [
            ModelGroup([
                Model('SUSTC/SUS-Chat-34B', 'SUSTech/SUS-Chat-34B'),
            ]),
        ],
        TemplateType.sus,
        get_model_tokenizer_yi,
        model_arch=ModelArch.llama,
        architectures=['LlamaForCausalLM'],
    ))
|
ms-swift/swift/llm/sampling/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .sampling import sampling_main
|
ms-swift/swift/llm/sampling/base.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List
|
| 2 |
+
|
| 3 |
+
from swift.llm import SamplingArguments
|
| 4 |
+
from swift.plugin import orms, prms
|
| 5 |
+
from swift.utils import get_logger
|
| 6 |
+
|
| 7 |
+
logger = get_logger()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Sampler:
    """Base class for sampling strategies (vanilla / MCTS / distill).

    Prepares the processor (tokenizer), the chat template, and the optional
    process-reward (PRM) and outcome-reward (ORM) models that subclasses use
    to score generated candidates.
    """

    def __init__(self, input_args: SamplingArguments):
        self.args = input_args
        self.template = None
        self.processor = None
        self.prm_model = None
        self.orm_model = None
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_rm()

    def _prepare_model_tokenizer(self) -> None:
        # Only the processor/tokenizer is needed here; the policy model
        # itself is loaded by the concrete sampler subclass.
        _, self.processor = self.args.get_model_processor(load_model=False)

    @staticmethod
    def _load_rm(model_id, plugin_registry, arg_name: str):
        """Resolve a reward model from ``model_id``.

        ``None`` yields no reward model (with a warning); a registered plugin
        name yields its instance; anything else is treated as a checkpoint id
        and served through a batched ``PtEngine``.
        """
        if model_id is None:
            logger.warning(f'{arg_name} is None.')
            return None
        if model_id in plugin_registry:
            return plugin_registry[model_id]()
        from swift.llm import PtEngine
        return PtEngine(model_id, max_batch_size=64)

    def _prepare_rm(self) -> None:
        # PRM and ORM follow identical resolution rules; keep them in sync
        # through the shared helper instead of duplicating the branches.
        self.prm_model = self._load_rm(self.args.prm_model, prms, 'prm_model')
        self.orm_model = self._load_rm(self.args.orm_model, orms, 'orm_model')

    def _prepare_template(self) -> None:
        template = self.args.get_template(self.processor)
        self.template = template
        # 'train' mode so encoding keeps the assistant response tokens.
        self.template.set_mode('train')

    def truncate_input(self, slices: List[Dict[str, Any]]):
        """Truncate the input rows to avoid hitting the max length of the policy model"""
        return slices

    def do_sample(self, data):
        """Generate samples for a batch of rows; implemented by subclasses."""
        raise NotImplementedError
|
ms-swift/swift/llm/sampling/sampling.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import time
|
| 5 |
+
from typing import List, Union
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
from swift.llm import SamplingArguments, SwiftPipeline, load_dataset
|
| 10 |
+
from swift.utils import get_logger
|
| 11 |
+
|
| 12 |
+
logger = get_logger()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SwiftSampling(SwiftPipeline):
    """Pipeline that samples responses from a dataset with a chosen sampler.

    Supports sharding the dataset into pieces (``data_range``) and resumable
    generation via a ``.resume`` copy plus a ``ckpt_state.json`` index file.
    """
    args_class = SamplingArguments
    args: args_class

    def __init__(self, args: Union[List[str], SamplingArguments, None] = None) -> None:
        super().__init__(args)
        self.args.save_args()
        os.makedirs(self.args.output_dir, exist_ok=True)
        # This process handles piece `cur_piece` out of `total_piece` shards.
        self.cur_piece = 0
        self.total_piece = 1

        if self.args.data_range:
            self.cur_piece, self.total_piece = self.args.data_range

        # Samplers are imported lazily so unused engines are never loaded.
        if self.args.sampler_type == 'sample':
            from swift.llm.sampling.vanilla_sampler import VanillaSampler
            self.sampler = VanillaSampler(self.args)
        elif self.args.sampler_type == 'mcts':
            from swift.llm.sampling.mcts import MctsSampler
            self.sampler = MctsSampler(self.args)
        elif self.args.sampler_type == 'distill':
            from swift.llm.sampling.distill_sampler import DistillSampler
            self.sampler = DistillSampler(self.args)
        else:
            raise ValueError(f'Unsupported sampler type: {self.args.sampler_type}')

    def _get_dataset(self):
        """Load the sampling dataset and select this process's shard."""
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        sampling_dataset, _ = load_dataset(
            args.dataset, split_dataset_ratio=0., shuffle=args.dataset_shuffle, **dataset_kwargs)
        logger.info(f'Sampling_dataset: {sampling_dataset}')
        dataset_len = len(sampling_dataset)
        # Integer division: a trailing remainder of rows is dropped.
        piece_len = dataset_len // self.total_piece
        sampling_dataset = sampling_dataset.select(range(piece_len * self.cur_piece, piece_len * (self.cur_piece + 1)))
        return sampling_dataset

    def run(self):
        """Sample batch by batch, checkpointing after every batch.

        File protocol: results accumulate in ``<output>.tmp``; after each
        batch the tmp file is copied to ``<output>.resume`` and the batch
        index is persisted to ``ckpt_state.json``. On success the resume file
        is moved to the final output (archiving any pre-existing output).
        """
        os.makedirs(self.args.output_dir, exist_ok=True)
        iter_file = os.path.join(self.args.output_dir, self.args.output_file)
        resume_file = os.path.join(self.args.output_dir, self.args.output_file + '.resume')
        tmp_file = os.path.join(self.args.output_dir, self.args.output_file + '.tmp')
        ckpt_state_file = os.path.join(self.args.output_dir, 'ckpt_state.json')
        # An existing output is kept unless the user asked to override it.
        if os.path.exists(iter_file) and not self.args.override_exist_file:
            return

        index_resume = -1
        write_mode = 'w'
        if self.args.resume:
            # Append to previous progress and restore the last finished batch.
            write_mode = 'a'
            if os.path.exists(resume_file):
                shutil.copyfile(resume_file, tmp_file)

            if os.path.exists(ckpt_state_file):
                with open(ckpt_state_file, 'r') as ckpt_state:
                    data = json.load(ckpt_state)
                index_resume = data.get('index', -1)
                logger.info(f'Loaded index_resume: {index_resume}')
        else:
            # Fresh run: discard any stale partial results.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

        dataset = self._get_dataset()
        dataset_len = len(dataset)
        total_iters = int(dataset_len // self.args.num_sampling_per_gpu_batch_size)

        if self.args.num_sampling_per_gpu_batches is None or self.args.num_sampling_per_gpu_batches > total_iters:
            self.args.num_sampling_per_gpu_batches = total_iters

        with open(tmp_file, write_mode) as f:
            for _index in range(self.args.num_sampling_per_gpu_batches):
                # Skip batches completed in a previous run.
                if _index <= index_resume:
                    continue
                logger.info(f' Sampling index:{_index}')
                slices = dataset[self.args.num_sampling_per_gpu_batch_size
                                 * _index:self.args.num_sampling_per_gpu_batch_size * (_index + 1)]
                slices = self.sampler.truncate_input(slices)
                generated = self.sampler.do_sample(slices)
                f.writelines(generated)
                f.flush()
                # Checkpoint after every batch so a crash loses at most one.
                shutil.copy(tmp_file, resume_file)
                with open(ckpt_state_file, 'w') as ckpt_state:
                    json.dump({'index': _index}, ckpt_state)

        if os.path.exists(iter_file):
            # Archive a pre-existing output with a timestamp suffix.
            shutil.move(iter_file, iter_file + '.' + str(int(time.time())))
        shutil.move(resume_file, iter_file)
        logger.info(f'Sample file {iter_file} generated.')
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def sampling_main(args: Union[List[str], SamplingArguments, None] = None):
    """Entry point: build a SwiftSampling pipeline from `args` and run it."""
    pipeline = SwiftSampling(args)
    return pipeline.main()
|
ms-swift/swift/llm/sampling/utils.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import inspect
|
| 3 |
+
from copy import copy
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from swift.llm import InferRequest, RequestConfig
|
| 10 |
+
from swift.utils import get_logger
|
| 11 |
+
|
| 12 |
+
logger = get_logger()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_messages_md5(row: Dict[str, Any]) -> str:
    """Return a stable MD5 fingerprint of a sample row.

    The volatile 'choices' field (generated candidates) is excluded so a row
    keeps the same id before and after sampling; the input is not mutated.
    """
    fingerprint_src = {k: v for k, v in row.items() if k != 'choices'}
    payload = json.dumps(fingerprint_src, sort_keys=True).encode('utf-8')
    return hashlib.md5(payload).hexdigest()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_reward(model: Any,
               infer_requests: List[InferRequest],
               request_config: RequestConfig = None,
               ground_truths: List[str] = None,
               threshold: Optional[float] = None):
    """Get reward from an RM model.

    Args:
        model: The model instance (an ``InferEngine``) or a callable RM evaluator
        infer_requests: Infer requests sent to the model
        request_config: Infer config
        ground_truths: The ground truth list (forwarded only if the RM accepts it)
        threshold: An optional threshold to generate the mask

    Returns:
        Tuple
            Index 0: The min-max normalized scores matched the infer_requests
            Index 1: The mask filtered by the threshold
    """
    from swift.llm import InferEngine
    # Engines expose `.infer`; plugin RMs are plain callables.
    infer_func = model.infer if isinstance(model, InferEngine) else model.__call__
    parameters = inspect.signature(infer_func).parameters
    gt_param = {}
    if 'ground_truths' in parameters:
        gt_param = {'ground_truths': ground_truths}
    if isinstance(infer_requests[0], dict):
        infer_requests = [InferRequest(messages=req['messages']) for req in infer_requests]
    rewards = infer_func(infer_requests, request_config=request_config, **gt_param)
    from swift.llm.infer.protocol import ChatCompletionResponse
    if isinstance(rewards[0], ChatCompletionResponse):
        # Fixed: a leftover debug `print` polluted stdout; log lazily instead.
        logger.debug('reward: %s', rewards[0].choices[0].message.content)
        # RM responses may encode scores as '[0.5]' strings, lists, or scalars.
        if isinstance(rewards[0].choices[0].message.content, str):
            rewards = [float(r.choices[0].message.content.strip('[]')) for r in rewards]
        elif isinstance(rewards[0].choices[0].message.content, list):
            rewards = [float(min(r.choices[0].message.content)) for r in rewards]
        else:
            rewards = [float(r.choices[0].message.content) for r in rewards]
    arr = []
    for reward in rewards:
        # Step-wise (PRM) rewards collapse to the weakest step.
        if isinstance(reward, (list, tuple)):
            arr.append(min(reward))
        else:
            arr.append(float(reward))

    _mask = np.array([True] * len(arr))
    if threshold is not None:
        # > not >=, orm caller passes 0, which will cause error
        _mask = np.array([a > threshold for a in arr])

    def normalize(arr):
        """Min-max normalize; a constant array maps to 0.0 or min(1.0, value)."""
        min_val = np.min(arr)
        max_val = np.max(arr)
        if min_val == max_val:
            if min_val == 0:
                constant_value = 0.0
            else:
                constant_value = min(1.0, min_val)
            return np.full_like(arr, fill_value=constant_value, dtype=np.float64)
        # Epsilon keeps the division finite under float round-off.
        normalized = (arr - min_val) / (max_val - min_val + 1e-5)
        return normalized

    return normalize(arr), _mask
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def perform_infer(infer_engines, infer_requests, request_configs, **infer_kwargs):
    """Dispatch inference over one engine or a list of engines.

    Accepted shapes:
      * list of engines -> fan out across engines, one thread per request;
      * single engine + list of requests -> sequential per-request calls;
      * single engine + single request -> one direct call.
    Returns a flat list of responses; with multiple engines the order follows
    task completion, not submission order.
    """
    if isinstance(infer_engines, list):
        # NOTE(review): chained comparison requires
        # len(engines) >= len(configs) >= len(requests); confirm this is the
        # intended invariant for all callers.
        assert len(infer_engines) >= len(request_configs) >= len(infer_requests)
        from concurrent.futures import ThreadPoolExecutor, as_completed
        n = len(infer_requests)
        with ThreadPoolExecutor(max_workers=n) as executor:
            # Recurse with a scalar engine/request/config per worker thread.
            futures = {
                executor.submit(perform_infer, infer_engines[i], infer_requests[i], request_configs[i], **infer_kwargs):
                i
                for i in range(n)
            }
            responses = []
            for future in as_completed(futures):
                task_id = futures[future]
                try:
                    responses += future.result()
                except Exception as e:
                    # Best-effort: a failed shard is logged and its results dropped.
                    logger.info(f'Perform infer task: {task_id} get an error: {e}')
            return responses
    elif isinstance(infer_requests, list):
        responses = []
        if isinstance(request_configs, list):
            # One config per request; configs may be longer than requests.
            assert len(infer_requests) <= len(request_configs)
            for i in range(len(infer_requests)):
                responses += infer_engines.infer(
                    [infer_requests[i]],
                    request_configs[i],
                    **infer_kwargs,
                )
        elif isinstance(request_configs, RequestConfig):
            # One shared config for every request.
            for infer_request in infer_requests:
                responses += infer_engines.infer(
                    [infer_request],
                    request_configs,
                    **infer_kwargs,
                )
        return responses
    # Single engine, single request, single config.
    return infer_engines.infer(
        [infer_requests],
        request_configs,
        **infer_kwargs,
    )
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def collect_from_mct(monte_carlo_tree, collect_filter_threshold):
    """Walk a (possibly JSON-serialized) Monte-Carlo tree and collect data.

    Returns ``(prefer_pairs, correct_answers, incorrect_answers)``:
    preference pairs built from each node's best vs. worst child by outcome
    reward (kept only when the gap exceeds ``collect_filter_threshold``), and
    terminal answers split by correctness, each annotated with mean/min
    outcome and process rewards over its root-to-leaf path.
    """
    from transformers.utils import strtobool
    if isinstance(monte_carlo_tree, str):
        monte_carlo_tree = json.loads(monte_carlo_tree)

    def _collect(collect_curr_node, _outcome_rewards: list[float], _process_rewards: list[float]):
        # Depth-first traversal; reward lists are copied so sibling subtrees
        # do not observe each other's path rewards.
        _prefer_pairs, _correct_answers, _incorrect_answers = [], [], []
        _outcome_rewards = _outcome_rewards[:] + [collect_curr_node['outcome_reward']]
        _process_rewards = _process_rewards[:] + [collect_curr_node['process_reward']]
        if len(collect_curr_node['children']) > 0:
            for child in collect_curr_node['children']:
                p, c, i = _collect(child, _outcome_rewards, _process_rewards)
                _prefer_pairs += p
                _correct_answers += c
                _incorrect_answers += i
            sorted_children = sorted(collect_curr_node['children'], key=lambda x: x['outcome_reward'])
            if sorted_children[-1]['outcome_reward'] - sorted_children[0]['outcome_reward'] > collect_filter_threshold:
                # TODO: filter with visit count
                # 'ки' joins reasoning steps — presumably the Math-Shepherd
                # style PRM step separator; confirm against the PRM in use.
                prefer_pair = {
                    'path': 'ки\n'.join(collect_curr_node['path']),
                    'good': sorted_children[-1]['path'][-1],
                    'good_score': sorted_children[-1]['outcome_reward'],
                    'bad': sorted_children[0]['path'][-1],
                    'bad_score': sorted_children[0]['outcome_reward'],
                }
                _prefer_pairs.append(prefer_pair)
        # NOTE(review): strtobool expects a string; assumes 'terminated' and
        # 'correct' are serialized as strings in the tree — confirm.
        if strtobool(collect_curr_node['terminated']):
            _answer = {
                'answer': 'ки\n'.join(collect_curr_node['path']),
                'mean_outcome_reward': np.mean(_outcome_rewards),
                'min_outcome_reward': np.min(_outcome_rewards),
                'mean_process_reward': np.mean(_process_rewards),
                'min_process_reward': np.min(_process_rewards),
            }
            if strtobool(collect_curr_node['correct']):
                _correct_answers.append(_answer)
            else:
                _incorrect_answers.append(_answer)
        return _prefer_pairs, _correct_answers, _incorrect_answers

    _root = monte_carlo_tree
    prefer_pairs, correct_answers, incorrect_answers = _collect(_root, [], [])
    return prefer_pairs, correct_answers, incorrect_answers
|
ms-swift/swift/llm/sampling/vanilla_sampler.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from swift.llm import RequestConfig
|
| 10 |
+
from swift.llm.sampling.base import Sampler
|
| 11 |
+
from swift.llm.template.template_inputs import InferRequest
|
| 12 |
+
from swift.utils import get_logger
|
| 13 |
+
from .utils import get_messages_md5, get_reward
|
| 14 |
+
|
| 15 |
+
logger = get_logger()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class VanillaSampler(Sampler):
    """Sampler that draws N candidate responses per row with a single engine,
    scores them with optional ORM/PRM reward models, and emits JSONL lines of
    chosen (and rejected) responses. Previously sampled rows are reused from
    cache files keyed by the row's message hash."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Select the inference backend; 'no' means score-only (no generation).
        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            _Engine = None
        else:
            raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
        self.infer_engine = None
        if _Engine:
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
            self.infer_engine.default_template = self.template
            # Non-strict: tolerate rows the template cannot fully encode.
            self.infer_engine.strict = False
        self.caches = self.read_cache()

    def read_cache(self):
        """Load previously generated choices from JSONL cache files.

        Returns a dict: row id -> {'choices': [assistant contents]}.
        """
        cache_files = self.args.cache_files
        caches = {}
        for file in cache_files:
            if not os.path.exists(file):
                logger.warning(f'Cache file does not exist: {file}')
                continue
            with open(file, 'r') as f:
                for line in f.readlines():
                    line = line.strip()
                    if not line:
                        continue

                    content = json.loads(line)
                    uuid = content['id']
                    messages = content['messages']
                    if uuid not in caches:
                        caches[uuid] = {'choices': []}
                    # Cached lines must end with the sampled assistant turn.
                    assert messages[-1]['role'] == 'assistant'
                    caches[uuid]['choices'].append(messages[-1]['content'])
        return caches

    @staticmethod
    def convert_data_to_rows(data):
        """Convert a column-oriented dataset slice into a list of row dicts,
        replacing embedded image bytes with their file paths."""
        rows = []
        key = list(data.keys())[0]
        data_len = len(data[key])
        for idx in range(data_len):
            row = {key: data[key][idx] for key in data}
            if row.get('images') and 'bytes' in row['images'][0]:
                row['images'] = [img['path'] for img in row['images']]
            rows.append(row)
        VanillaSampler.check_row_valid(rows)
        return rows

    @staticmethod
    def check_row_valid(rows):
        """Assert that all media references are non-empty path strings."""
        for row in rows:
            assert not row.get('images') or all([isinstance(img, str) and img for img in row['images']])
            assert not row.get('videos') or all([isinstance(video, str) and video for video in row['videos']])
            assert not row.get('audios') or all([isinstance(audio, str) and audio for audio in row['audios']])

    def generate(self, data):
        """Sample `num_return_sequences` responses per uncached row.

        Returns a list of rows, each with a 'choices' list (cached choices are
        reused when a row already has the full number of sequences).
        """
        resp_all = []
        infer_requests = []
        sent = 0
        rows = self.convert_data_to_rows(data)
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            messages = row['messages']
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                # Fully cached rows need no new requests.
                if len(choices) == self.args.num_return_sequences:
                    continue
            if self.args.system:
                # Override or inject the system prompt.
                if messages[0]['role'] == 'system':
                    messages[0]['content'] = self.args.system
                else:
                    messages.insert(0, {'role': 'system', 'content': self.args.system})
            if messages[-1]['role'] == 'assistant':
                # Drop the ground-truth answer before generation.
                messages = messages[:-1]

            row['messages'] = messages
            infer_request = row
            # Fan out: one request copy per desired sequence.
            for i in range(self.args.num_return_sequences):
                infer_requests.append(deepcopy(infer_request))
            sent += 1

        request_config = RequestConfig(
            max_tokens=self.args.max_new_tokens,
            temperature=self.args.temperature,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
        )

        resp_list = []
        if len(infer_requests) > 0:
            resp_list = self.infer_engine.infer(infer_requests, request_config=request_config)

        # `_cur` indexes uncached rows only; it must mirror the fan-out above
        # so each row picks up its own contiguous block of responses.
        _cur = 0
        for idx, row in enumerate(rows):
            row = deepcopy(row)
            uuid = get_messages_md5(row)
            if uuid in self.caches:
                choices = self.caches[uuid]['choices']
                if len(choices) == self.args.num_return_sequences:
                    row['choices'] = choices
                    resp_all.append(row)
                    continue

            resps = row
            resps['choices'] = []
            for j in range(self.args.num_return_sequences * _cur, self.args.num_return_sequences * (_cur + 1)):
                # Failed generations come back as Exceptions; skip them.
                if not isinstance(resp_list[j], Exception):
                    resps['choices'].append(resp_list[j].choices[0].message.content)
            if resps['choices']:
                resp_all.append(resps)
            _cur += 1
        return resp_all

    def do_sample(self, data):
        """Score sampled choices with ORM/PRM and emit JSONL training lines.

        Without reward models every choice is kept as a positive; with reward
        models the top `n_best_to_keep` choices (by prm + 10*orm) become
        positives and the worst becomes the rejected response. Rows that too
        many samples already answer correctly ('easy queries') are dropped.
        """
        generated = []
        resp_all = self.generate(data)
        for i, resps in enumerate(resp_all):
            choices = resps['choices']
            messages = resps['messages']
            uuid = get_messages_md5(resps)
            assert messages[-1]['role'] == 'assistant'
            ground_truth = messages[-1]['content']

            # Build one scoring request per choice, plus one for the
            # ground truth (appended last).
            infer_requests = []
            for decoded in choices:
                _resps = deepcopy(resps)
                _resps['messages'][-1]['content'] = decoded
                infer_requests.append(_resps)

            _resps = deepcopy(resps)
            _resps['messages'][-1]['content'] = ground_truth
            infer_requests.append(_resps)
            if self.orm_model is not None:
                orm_score, _orm_mask = get_reward(
                    self.orm_model, infer_requests, ground_truths=[ground_truth] * len(infer_requests), threshold=0.0)
            else:
                orm_score = np.array([1.0] * len(infer_requests))
                _orm_mask = np.array([True] * len(infer_requests))
            if self.prm_model is not None:
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=self.args.prm_threshold)
            else:
                prm_score = np.array([1.0] * len(infer_requests))
                _prm_mask = np.array([True] * len(infer_requests))

            # A candidate survives only if both reward models accept it.
            _mask = _orm_mask & _prm_mask
            if not any(_mask):
                continue

            # Ground truth joins the candidate pool (matches the extra request).
            choices.append(ground_truth)
            choices = np.array(choices)

            if self.orm_model is None and self.prm_model is None:
                # No reward models: keep every sampled choice as a positive.
                positives = choices[:-1]
                for positive in positives:
                    _resps = deepcopy(resps)
                    _resps.pop('choices', None)
                    _resps['id'] = uuid
                    _resps['messages'][-1]['content'] = str(positive)
                    generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
            else:
                # ORM dominates PRM by a factor of 10 in the combined score.
                score = np.array(prm_score) + np.array(orm_score * 10)
                sorted_indices = np.argsort(score)[::-1]
                pos_indexes = sorted_indices[0:self.args.n_best_to_keep]
                pos_indexes = [i for i in pos_indexes if _mask[i]]
                neg_index = sorted_indices[-1]
                logger.info(
                    f'orm:{orm_score}, prm:{prm_score}, positive index: {pos_indexes}, negative index: {neg_index}')
                # Skip 'easy' queries: too many samples already score > 0
                # (minus 1 discounts the appended ground truth).
                if self.args.easy_query_threshold is not None and sum([score > 0 for score in orm_score]) - 1 >= int(
                        self.args.num_return_sequences * self.args.easy_query_threshold):
                    continue
                if len(pos_indexes) > 0:
                    positives = choices[pos_indexes]
                    negative = choices[neg_index]
                    for positive in positives:
                        _resps = deepcopy(resps)
                        messages = deepcopy(messages)
                        _resps.pop('choices', None)
                        _resps['messages'][-1]['content'] = str(positive)
                        _resps['rejected_response'] = str(negative)
                        _resps['id'] = uuid
                        generated.append(json.dumps(_resps, ensure_ascii=False) + '\n')
        return generated
|
ms-swift/swift/llm/template/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from . import template
|
| 3 |
+
from .base import MaxLengthError, Template
|
| 4 |
+
from .constant import TemplateType
|
| 5 |
+
from .grounding import draw_bbox
|
| 6 |
+
from .register import TEMPLATE_MAPPING, get_template, get_template_meta, register_template
|
| 7 |
+
from .template_inputs import InferRequest, TemplateInputs
|
| 8 |
+
from .template_meta import TemplateMeta
|
| 9 |
+
from .utils import Prompt, Word, split_str_parts_by
|
| 10 |
+
from .vision_utils import load_file, load_image
|
ms-swift/swift/llm/template/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (758 Bytes). View file
|
|
|
ms-swift/swift/llm/template/__pycache__/base.cpython-310.pyc
ADDED
|
Binary file (50.8 kB). View file
|
|
|
ms-swift/swift/llm/template/__pycache__/grounding.cpython-310.pyc
ADDED
|
Binary file (3.47 kB). View file
|
|
|
ms-swift/swift/llm/template/__pycache__/register.cpython-310.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
ms-swift/swift/llm/template/__pycache__/template_inputs.cpython-310.pyc
ADDED
|
Binary file (6.41 kB). View file
|
|
|
ms-swift/swift/llm/template/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (5.49 kB). View file
|
|
|
ms-swift/swift/llm/template/__pycache__/vision_utils.cpython-310.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
ms-swift/swift/llm/template/base.py
ADDED
|
@@ -0,0 +1,1606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import hashlib
|
| 3 |
+
import inspect
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from contextlib import contextmanager, nullcontext
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from dataclasses import asdict
|
| 10 |
+
from functools import partial, wraps
|
| 11 |
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from modelscope.hub.utils.utils import get_cache_dir
|
| 17 |
+
from peft import PeftModel
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 20 |
+
from transformers import StoppingCriteriaList
|
| 21 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
| 22 |
+
from transformers.utils import strtobool
|
| 23 |
+
|
| 24 |
+
from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
|
| 25 |
+
from ..utils import Processor, ProcessorMixin
|
| 26 |
+
from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
|
| 27 |
+
from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by
|
| 28 |
+
from .vision_utils import load_audio, load_batch, load_image, rescale_image
|
| 29 |
+
|
| 30 |
+
logger = get_logger()
|
| 31 |
+
if TYPE_CHECKING:
|
| 32 |
+
from .template_meta import TemplateMeta
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class MaxLengthError(ValueError):
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Template(ProcessorMixin):
|
| 40 |
+
special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
|
| 41 |
+
special_keys = ['images', 'videos', 'audios', 'objects']
|
| 42 |
+
|
| 43 |
+
image_placeholder = ['<image>']
|
| 44 |
+
video_placeholder = ['<video>']
|
| 45 |
+
audio_placeholder = ['<audio>']
|
| 46 |
+
cot_process_placeholder = ['ки']
|
| 47 |
+
placeholder_tokens = [] # For clearer printing
|
| 48 |
+
load_images = True
|
| 49 |
+
skip_prompt = True
|
| 50 |
+
use_model = False
|
| 51 |
+
norm_bbox = 'norm1000'
|
| 52 |
+
|
| 53 |
+
is_encoder_decoder = False
|
| 54 |
+
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
processor: Processor,
|
| 58 |
+
template_meta: 'TemplateMeta',
|
| 59 |
+
default_system: Optional[str] = None,
|
| 60 |
+
max_length: Optional[int] = None,
|
| 61 |
+
*,
|
| 62 |
+
use_chat_template: bool = True,
|
| 63 |
+
truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
|
| 64 |
+
max_pixels: Optional[int] = None,
|
| 65 |
+
agent_template: Optional[str] = None,
|
| 66 |
+
norm_bbox: Literal['norm1000', 'none', None] = None,
|
| 67 |
+
response_prefix: Optional[str] = None,
|
| 68 |
+
# only for train
|
| 69 |
+
padding_side: Literal['left', 'right'] = 'right',
|
| 70 |
+
loss_scale: str = 'default',
|
| 71 |
+
sequence_parallel_size: int = 1,
|
| 72 |
+
# infer/deploy
|
| 73 |
+
template_backend: Literal['swift', 'jinja'] = 'swift',
|
| 74 |
+
) -> None:
|
| 75 |
+
"""
|
| 76 |
+
default_system: Override the default_system in the template.
|
| 77 |
+
max_length: Max length of the sequence
|
| 78 |
+
truncation_strategy: The truncation strategy
|
| 79 |
+
max_pixels: Rescale image to reduce memory usage, default `None` means no limitation.
|
| 80 |
+
e.g. 512 * 512 (H*W)
|
| 81 |
+
padding_side: The padding_side when the training batch_size >= 2
|
| 82 |
+
loss_scale: The loss scale function to use
|
| 83 |
+
"""
|
| 84 |
+
from .template_meta import TemplateMeta
|
| 85 |
+
from swift.plugin import agent_templates
|
| 86 |
+
|
| 87 |
+
self.processor = processor
|
| 88 |
+
self.model_info = processor.model_info
|
| 89 |
+
self.config = self.model_info.config
|
| 90 |
+
self.model_meta = processor.model_meta
|
| 91 |
+
if max_length is None:
|
| 92 |
+
max_length = self.model_info.max_model_len
|
| 93 |
+
tokenizer = self.tokenizer
|
| 94 |
+
|
| 95 |
+
if not use_chat_template:
|
| 96 |
+
template_meta = template_meta.to_generate_template_meta()
|
| 97 |
+
else:
|
| 98 |
+
template_meta = deepcopy(template_meta)
|
| 99 |
+
# if default_system is None. not change self.default_system
|
| 100 |
+
template_meta.check_system(default_system)
|
| 101 |
+
if default_system is not None:
|
| 102 |
+
template_meta.default_system = default_system
|
| 103 |
+
if response_prefix is not None:
|
| 104 |
+
template_meta.response_prefix = response_prefix
|
| 105 |
+
logger.info(f'default_system: {repr(template_meta.default_system)}')
|
| 106 |
+
logger.info(f'response_prefix: {repr(template_meta.response_prefix)}')
|
| 107 |
+
|
| 108 |
+
for i, token in enumerate(self.placeholder_tokens):
|
| 109 |
+
if isinstance(token, str):
|
| 110 |
+
self.placeholder_tokens[i] = tokenizer.convert_tokens_to_ids(token)
|
| 111 |
+
template_meta.init(tokenizer)
|
| 112 |
+
|
| 113 |
+
self.template_meta: TemplateMeta = template_meta
|
| 114 |
+
self.use_chat_template = use_chat_template
|
| 115 |
+
self.template_backend = template_backend
|
| 116 |
+
self.max_length = max_length
|
| 117 |
+
self.truncation_strategy = truncation_strategy
|
| 118 |
+
self.loss_scale = loss_scale
|
| 119 |
+
self.max_pixels = max_pixels
|
| 120 |
+
self.padding_side = padding_side
|
| 121 |
+
self.sequence_parallel_size = sequence_parallel_size
|
| 122 |
+
agent_template = agent_template or template_meta.agent_template
|
| 123 |
+
logger.info(f'agent_template: {agent_template}')
|
| 124 |
+
self.agent_template = agent_templates[agent_template]()
|
| 125 |
+
self.norm_bbox = norm_bbox or self.norm_bbox
|
| 126 |
+
logger.info(f'max_length: {self.max_length}')
|
| 127 |
+
logger.info(f'norm_bbox: {self.norm_bbox}')
|
| 128 |
+
if self.is_encoder_decoder:
|
| 129 |
+
self.skip_prompt = False
|
| 130 |
+
self.mode: Literal['pt', 'vllm', 'lmdeploy', # infer
|
| 131 |
+
'train', 'rlhf', 'kto', # train
|
| 132 |
+
'seq_cls', 'embedding', 'prm'] = 'pt'
|
| 133 |
+
self._packing = False
|
| 134 |
+
self.use_megatron = False
|
| 135 |
+
if self.model_info.task_type != 'causal_lm':
|
| 136 |
+
self.mode = self.model_info.task_type
|
| 137 |
+
self._handles = []
|
| 138 |
+
self._deepspeed_initialize = None
|
| 139 |
+
|
| 140 |
+
@staticmethod
|
| 141 |
+
def _load_image(image, load_images: bool):
|
| 142 |
+
if load_images:
|
| 143 |
+
if isinstance(image, dict) and 'bytes' in image:
|
| 144 |
+
image = image['bytes'] or image['path']
|
| 145 |
+
image = load_image(image)
|
| 146 |
+
else:
|
| 147 |
+
if isinstance(image, dict):
|
| 148 |
+
path = image['path']
|
| 149 |
+
if path and (path.startswith('http') or os.path.exists(path)):
|
| 150 |
+
image = path
|
| 151 |
+
else:
|
| 152 |
+
image = load_image(image['bytes'])
|
| 153 |
+
elif not isinstance(image, str):
|
| 154 |
+
image = load_image(image)
|
| 155 |
+
return image
|
| 156 |
+
|
| 157 |
+
@staticmethod
|
| 158 |
+
def _get_height_width(inputs: StdTemplateInputs) -> None:
|
| 159 |
+
width = []
|
| 160 |
+
height = []
|
| 161 |
+
for image in inputs.images:
|
| 162 |
+
width.append(image.width)
|
| 163 |
+
height.append(image.height)
|
| 164 |
+
inputs.objects['width'] = width
|
| 165 |
+
inputs.objects['height'] = height
|
| 166 |
+
|
| 167 |
+
def normalize_bbox(self, inputs: StdTemplateInputs) -> None:
|
| 168 |
+
objects = inputs.objects
|
| 169 |
+
bbox_list = objects['bbox']
|
| 170 |
+
width_list = objects['width']
|
| 171 |
+
height_list = objects['height']
|
| 172 |
+
bbox_type = objects.pop('bbox_type', None) or 'real'
|
| 173 |
+
image_id_list = objects.pop('image_id', None) or []
|
| 174 |
+
image_id_list += [0] * (len(bbox_list) - len(image_id_list))
|
| 175 |
+
for bbox, image_id in zip(bbox_list, image_id_list):
|
| 176 |
+
if bbox_type == 'norm1':
|
| 177 |
+
width, height = 1, 1
|
| 178 |
+
else:
|
| 179 |
+
width, height = width_list[image_id], height_list[image_id]
|
| 180 |
+
for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])):
|
| 181 |
+
if self.norm_bbox == 'norm1000':
|
| 182 |
+
norm_width, norm_height = 1000, 1000
|
| 183 |
+
elif self.norm_bbox == 'none':
|
| 184 |
+
image = inputs.images[image_id]
|
| 185 |
+
norm_width, norm_height = image.width, image.height
|
| 186 |
+
bbox[2 * i] = int(round(x / width * norm_width))
|
| 187 |
+
bbox[2 * i + 1] = int(round(y / height * norm_height))
|
| 188 |
+
|
| 189 |
+
def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
|
| 190 |
+
agent_template = self.agent_template
|
| 191 |
+
agent_template.template_meta = self.template_meta # for hermes
|
| 192 |
+
if inputs.tools:
|
| 193 |
+
if isinstance(inputs.tools, str):
|
| 194 |
+
inputs.tools = agent_template._parse_json(inputs.tools)
|
| 195 |
+
if not isinstance(inputs.tools, (list, tuple)):
|
| 196 |
+
inputs.tools = [inputs.tools]
|
| 197 |
+
elif isinstance(inputs.tools, (list, tuple)):
|
| 198 |
+
inputs.tools = [agent_template._parse_json(tool) for tool in inputs.tools]
|
| 199 |
+
else:
|
| 200 |
+
raise ValueError(f'inputs.tools: {inputs.tools}')
|
| 201 |
+
for i, tool in enumerate(inputs.tools):
|
| 202 |
+
inputs.tools[i] = agent_template.wrap_tool(tool)
|
| 203 |
+
i = 0
|
| 204 |
+
messages = inputs.messages
|
| 205 |
+
while i < len(messages):
|
| 206 |
+
if messages[i]['role'] == 'tool_call':
|
| 207 |
+
i_start = i
|
| 208 |
+
while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool_call':
|
| 209 |
+
i += 1
|
| 210 |
+
tool_content = self.agent_template._format_tool_calls(messages[i_start:i + 1])
|
| 211 |
+
messages[i_start:i + 1] = [{'role': 'assistant', 'content': tool_content}]
|
| 212 |
+
i = i_start + 1
|
| 213 |
+
else:
|
| 214 |
+
i += 1
|
| 215 |
+
|
| 216 |
+
def _preprocess_inputs(
|
| 217 |
+
self,
|
| 218 |
+
inputs: StdTemplateInputs,
|
| 219 |
+
) -> None:
|
| 220 |
+
self._preprocess_function_call(inputs)
|
| 221 |
+
if self.model_meta.is_multimodal:
|
| 222 |
+
self._replace_image_tags(inputs)
|
| 223 |
+
self._replace_start_image_tags(inputs)
|
| 224 |
+
images = inputs.images
|
| 225 |
+
load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
|
| 226 |
+
load_images_origin = load_images
|
| 227 |
+
if self.max_pixels is not None or inputs.objects:
|
| 228 |
+
load_images = True
|
| 229 |
+
if images:
|
| 230 |
+
for i, image in enumerate(images):
|
| 231 |
+
images[i] = self._load_image(images[i], load_images)
|
| 232 |
+
if inputs.objects:
|
| 233 |
+
self._get_height_width(inputs)
|
| 234 |
+
if self.max_pixels is not None:
|
| 235 |
+
# Scale the image proportionally without affecting the scaled objects.
|
| 236 |
+
images = [rescale_image(img, self.max_pixels) for img in images]
|
| 237 |
+
if images and not load_images_origin: # fix pt & qwen-vl
|
| 238 |
+
for i, image in enumerate(images):
|
| 239 |
+
if isinstance(image, Image.Image):
|
| 240 |
+
images[i] = self._save_pil_image(image)
|
| 241 |
+
inputs.images = images
|
| 242 |
+
|
| 243 |
+
if self.mode == 'vllm' and inputs.audios:
|
| 244 |
+
sampling_rate = get_env_args('sampling_rate', int, None)
|
| 245 |
+
inputs.audios = load_batch(
|
| 246 |
+
inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))
|
| 247 |
+
|
| 248 |
+
if inputs.is_multimodal:
|
| 249 |
+
self._add_default_tags(inputs)
|
| 250 |
+
|
| 251 |
+
@staticmethod
|
| 252 |
+
def _replace_image_tags(inputs: StdTemplateInputs):
|
| 253 |
+
# compat
|
| 254 |
+
if inputs.images:
|
| 255 |
+
return
|
| 256 |
+
images = []
|
| 257 |
+
pattern = r'<img>(.+?)</img>'
|
| 258 |
+
for message in inputs.messages:
|
| 259 |
+
content = message['content']
|
| 260 |
+
if not isinstance(content, str):
|
| 261 |
+
continue
|
| 262 |
+
for image in re.findall(pattern, content):
|
| 263 |
+
# only support local_path
|
| 264 |
+
if os.path.isfile(image):
|
| 265 |
+
images.append(image)
|
| 266 |
+
else:
|
| 267 |
+
logger.warning_once(f'Failed to parse image path: `{content}`.', hash_id='<img></img>')
|
| 268 |
+
message['content'] = re.sub(pattern, '<image>', content)
|
| 269 |
+
inputs.images = images
|
| 270 |
+
|
| 271 |
+
@staticmethod
|
| 272 |
+
def _replace_start_image_tags(inputs: StdTemplateInputs):
|
| 273 |
+
# compat
|
| 274 |
+
generate_mode = False
|
| 275 |
+
message = inputs.messages[-1]
|
| 276 |
+
content = message['content']
|
| 277 |
+
if message['role'] == 'user' and content.endswith('<start-image>'):
|
| 278 |
+
generate_mode = True
|
| 279 |
+
message['content'] = message['content'][:-len('<start-image>')] # remove the <start-image>
|
| 280 |
+
inputs.generate_mode = generate_mode
|
| 281 |
+
|
| 282 |
+
@staticmethod
|
| 283 |
+
def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
|
| 284 |
+
get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
|
| 285 |
+
added_tokens_len = 0
|
| 286 |
+
for i, idx in enumerate(replace_idx_list):
|
| 287 |
+
new_tokens = get_new_tokens(i)
|
| 288 |
+
token_len = len(new_tokens)
|
| 289 |
+
input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
|
| 290 |
+
if labels:
|
| 291 |
+
labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
|
| 292 |
+
added_tokens_len += token_len - 1
|
| 293 |
+
return input_ids, labels
|
| 294 |
+
|
| 295 |
+
def compute_loss_context(self, model, inputs):
|
| 296 |
+
return nullcontext()
|
| 297 |
+
|
| 298 |
+
def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 299 |
+
chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
|
| 300 |
+
assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
|
| 301 |
+
rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
|
| 302 |
+
chosen_encoded = self._encode_truncated(chosen_inputs)
|
| 303 |
+
rejected_encoded = self._encode_truncated(rejected_inputs)
|
| 304 |
+
|
| 305 |
+
encoded = {}
|
| 306 |
+
for prefix in ['chosen', 'rejected']:
|
| 307 |
+
data = locals()[f'{prefix}_encoded']
|
| 308 |
+
for k, v in data.items():
|
| 309 |
+
encoded[f'{prefix}_{k}'] = v
|
| 310 |
+
return encoded
|
| 311 |
+
|
| 312 |
+
def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 313 |
+
label, inputs.label = inputs.label, None
|
| 314 |
+
encoded = self._rlhf_encode(inputs)
|
| 315 |
+
encoded['label'] = bool(label)
|
| 316 |
+
return encoded
|
| 317 |
+
|
| 318 |
+
def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 319 |
+
_encoded = {}
|
| 320 |
+
labels = []
|
| 321 |
+
|
| 322 |
+
def split_multi_medias(_inputs):
|
| 323 |
+
_content = _inputs.messages[-2]['content']
|
| 324 |
+
image_size = len(re.findall('<image>', _content))
|
| 325 |
+
video_size = len(re.findall('<video>', _content))
|
| 326 |
+
audio_size = len(re.findall('<audio>', _content))
|
| 327 |
+
_inputs.images = inputs.images[:image_size]
|
| 328 |
+
assert len(_inputs.images) == image_size
|
| 329 |
+
inputs.images = inputs.images[image_size:]
|
| 330 |
+
_inputs.videos = inputs.videos[:video_size]
|
| 331 |
+
assert len(_inputs.videos) == video_size
|
| 332 |
+
inputs.videos = inputs.videos[video_size:]
|
| 333 |
+
_inputs.audios = inputs.audios[:audio_size]
|
| 334 |
+
assert len(_inputs.audios) == audio_size
|
| 335 |
+
inputs.audios = inputs.audios[audio_size:]
|
| 336 |
+
|
| 337 |
+
anchor = deepcopy(inputs)
|
| 338 |
+
anchor.messages[-1]['content'] = ''
|
| 339 |
+
anchor.rejected_response = []
|
| 340 |
+
split_multi_medias(anchor)
|
| 341 |
+
anchor_encoded = self._encode_truncated(anchor)
|
| 342 |
+
for key in anchor_encoded:
|
| 343 |
+
_encoded[f'anchor_{key}'] = anchor_encoded[key]
|
| 344 |
+
|
| 345 |
+
positive = deepcopy(inputs)
|
| 346 |
+
positive.messages[-2]['content'] = positive.messages[-1]['content']
|
| 347 |
+
positive.messages[-1]['content'] = ''
|
| 348 |
+
positive.rejected_response = []
|
| 349 |
+
split_multi_medias(positive)
|
| 350 |
+
positive_encoded = self._encode_truncated(positive)
|
| 351 |
+
for key in positive_encoded:
|
| 352 |
+
_encoded[f'positive_{key}'] = positive_encoded[key]
|
| 353 |
+
labels.append(float(inputs.label) if inputs.label is not None else 1.0)
|
| 354 |
+
|
| 355 |
+
rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
|
| 356 |
+
for i in range(rejected_len):
|
| 357 |
+
negative = deepcopy(inputs)
|
| 358 |
+
negative.messages[-2]['content'] = negative.rejected_response[i]
|
| 359 |
+
negative.messages[-1]['content'] = ''
|
| 360 |
+
negative.rejected_response = []
|
| 361 |
+
split_multi_medias(negative)
|
| 362 |
+
negative_encoded = self._encode_truncated(negative)
|
| 363 |
+
for key in negative_encoded:
|
| 364 |
+
_encoded[f'negative{i}_{key}'] = negative_encoded[key]
|
| 365 |
+
labels.append(0.0)
|
| 366 |
+
|
| 367 |
+
_encoded['labels'] = labels
|
| 368 |
+
return _encoded
|
| 369 |
+
|
| 370 |
+
def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 371 |
+
encoded = self._encode_truncated(inputs)
|
| 372 |
+
encoded.pop('labels', None)
|
| 373 |
+
if inputs.label is not None:
|
| 374 |
+
labels = inputs.label
|
| 375 |
+
problem_type = self._get_problem_type(self.config, labels=labels)
|
| 376 |
+
if problem_type == 'single_label_classification':
|
| 377 |
+
labels = int(labels)
|
| 378 |
+
encoded['labels'] = labels
|
| 379 |
+
return encoded
|
| 380 |
+
|
| 381 |
+
@torch.inference_mode()
|
| 382 |
+
def encode(self,
|
| 383 |
+
inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
|
| 384 |
+
return_template_inputs: bool = False) -> Dict[str, Any]:
|
| 385 |
+
"""The entrance method of Template!
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
return {'input_ids': List[int], 'labels': Optional[List[int]], ...}
|
| 389 |
+
"""
|
| 390 |
+
if isinstance(inputs, (InferRequest, TemplateInputs)):
|
| 391 |
+
inputs = asdict(inputs)
|
| 392 |
+
|
| 393 |
+
if isinstance(inputs, dict):
|
| 394 |
+
inputs = deepcopy(inputs)
|
| 395 |
+
if not self.is_training:
|
| 396 |
+
InferRequest.remove_response(inputs['messages'])
|
| 397 |
+
inputs = StdTemplateInputs.from_dict(inputs)
|
| 398 |
+
elif isinstance(inputs, StdTemplateInputs):
|
| 399 |
+
inputs = deepcopy(inputs)
|
| 400 |
+
assert isinstance(inputs, StdTemplateInputs)
|
| 401 |
+
self._preprocess_inputs(inputs)
|
| 402 |
+
|
| 403 |
+
if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}:
|
| 404 |
+
encoded = self._encode_truncated(inputs)
|
| 405 |
+
elif self.mode == 'seq_cls':
|
| 406 |
+
encoded = self._seq_cls_encode(inputs)
|
| 407 |
+
elif self.mode == 'rlhf':
|
| 408 |
+
encoded = self._rlhf_encode(inputs)
|
| 409 |
+
elif self.mode == 'kto':
|
| 410 |
+
encoded = self._kto_encode(inputs)
|
| 411 |
+
elif self.mode == 'embedding':
|
| 412 |
+
encoded = self._embedding_encode(inputs)
|
| 413 |
+
for key in list(encoded.keys()):
|
| 414 |
+
if encoded[key] is None:
|
| 415 |
+
encoded.pop(key)
|
| 416 |
+
if return_template_inputs:
|
| 417 |
+
encoded['template_inputs'] = inputs
|
| 418 |
+
return encoded
|
| 419 |
+
|
| 420 |
+
def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
|
| 421 |
+
packed = {}
|
| 422 |
+
keys = set()
|
| 423 |
+
for r in row:
|
| 424 |
+
keys.update(r[0].keys())
|
| 425 |
+
for key in keys:
|
| 426 |
+
if key in {'input_ids', 'labels', 'loss_scale'}:
|
| 427 |
+
packed[key] = sum((x[0][key] for x in row), start=[])
|
| 428 |
+
if 'position_ids' not in packed:
|
| 429 |
+
packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])
|
| 430 |
+
|
| 431 |
+
packed.update(self._data_collator_mm_data([r[0] for r in row]))
|
| 432 |
+
return packed
|
| 433 |
+
|
| 434 |
+
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 435 |
+
return inputs
|
| 436 |
+
|
| 437 |
+
@staticmethod
|
| 438 |
+
def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
|
| 439 |
+
len_tokens = len(stop_tokens)
|
| 440 |
+
if is_finished and generate_ids[-len_tokens:] == stop_tokens:
|
| 441 |
+
return generate_ids[:-len_tokens]
|
| 442 |
+
if not is_finished:
|
| 443 |
+
for i in range(len_tokens, 0, -1):
|
| 444 |
+
if generate_ids[-i:] == stop_tokens[:i]:
|
| 445 |
+
return generate_ids[:-i]
|
| 446 |
+
return generate_ids
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int):
|
| 450 |
+
idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist()
|
| 451 |
+
logprobs = logprobs.tolist()
|
| 452 |
+
return {
|
| 453 |
+
'content': [{
|
| 454 |
+
'index': pred,
|
| 455 |
+
'logprobs': [logprobs[p] for p in pred] if isinstance(pred, (list, tuple)) else logprobs[pred],
|
| 456 |
+
'top_logprobs': [{
|
| 457 |
+
'index': idx,
|
| 458 |
+
'logprob': logprobs[idx]
|
| 459 |
+
} for idx in idxs]
|
| 460 |
+
}]
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
@staticmethod
|
| 464 |
+
def _get_problem_type(config, labels=None, logits=None) -> str:
|
| 465 |
+
problem_type = config.problem_type
|
| 466 |
+
if problem_type is not None:
|
| 467 |
+
return problem_type
|
| 468 |
+
if labels is not None:
|
| 469 |
+
if isinstance(labels, (list, tuple)):
|
| 470 |
+
if labels and isinstance(labels[0], float):
|
| 471 |
+
problem_type = 'regression'
|
| 472 |
+
else:
|
| 473 |
+
problem_type = 'multi_label_classification'
|
| 474 |
+
else:
|
| 475 |
+
problem_type = 'single_label_classification'
|
| 476 |
+
assert config.num_labels >= labels + 1
|
| 477 |
+
if logits is not None:
|
| 478 |
+
if logits.shape[-1] == 1:
|
| 479 |
+
problem_type = 'regression'
|
| 480 |
+
else:
|
| 481 |
+
problem_type = 'single_label_classification' # compatible with older versions
|
| 482 |
+
assert problem_type is not None
|
| 483 |
+
config.problem_type = problem_type
|
| 484 |
+
return problem_type
|
| 485 |
+
|
| 486 |
+
def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int):
|
| 487 |
+
assert isinstance(logits, torch.Tensor)
|
| 488 |
+
problem_type = self._get_problem_type(self.config, logits=logits)
|
| 489 |
+
if problem_type == 'regression':
|
| 490 |
+
preds = logits.squeeze(dim=-1).tolist()
|
| 491 |
+
logprobs = [None] * len(preds)
|
| 492 |
+
else:
|
| 493 |
+
if problem_type == 'single_label_classification':
|
| 494 |
+
preds = torch.argmax(logits, dim=-1).tolist()
|
| 495 |
+
logprobs = torch.log_softmax(logits, -1)
|
| 496 |
+
else:
|
| 497 |
+
preds = [(logprob >= 0.5).nonzero(as_tuple=True)[0].tolist() for logprob in torch.sigmoid(logits)]
|
| 498 |
+
logprobs = F.logsigmoid(logits)
|
| 499 |
+
logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)]
|
| 500 |
+
return preds, logprobs
|
| 501 |
+
|
| 502 |
+
def decode(self,
|
| 503 |
+
generate_ids: List[int],
|
| 504 |
+
*,
|
| 505 |
+
is_finished: bool = True,
|
| 506 |
+
tokenizer_kwargs=None,
|
| 507 |
+
first_token=True,
|
| 508 |
+
**kwargs) -> Any:
|
| 509 |
+
tokenizer_kwargs = tokenizer_kwargs or {}
|
| 510 |
+
response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs)
|
| 511 |
+
if first_token and self.template_meta.response_prefix:
|
| 512 |
+
response = self.template_meta.response_prefix + response
|
| 513 |
+
return response
|
| 514 |
+
|
| 515 |
+
def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
|
| 516 |
+
raise NotImplementedError
|
| 517 |
+
|
| 518 |
+
def generate(self, model, *args, **kwargs):
|
| 519 |
+
if isinstance(model, PeftModel):
|
| 520 |
+
signature = inspect.signature(model.model.generate)
|
| 521 |
+
else:
|
| 522 |
+
signature = inspect.signature(model.generate)
|
| 523 |
+
if 'use_model_defaults' in signature.parameters and 'use_model_defaults' not in kwargs:
|
| 524 |
+
kwargs['use_model_defaults'] = False
|
| 525 |
+
return model.generate(*args, **kwargs)
|
| 526 |
+
|
| 527 |
+
def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
|
| 528 |
+
# Do not print template_meta.suffix[-1] and eos_token.
|
| 529 |
+
# However, other stop_words will be printed.
|
| 530 |
+
tokenizer = self.tokenizer
|
| 531 |
+
|
| 532 |
+
if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
|
| 533 |
+
generate_ids = generate_ids[:-1]
|
| 534 |
+
# skip suffix and eos_token
|
| 535 |
+
template_suffix = self.template_meta.suffix[-1]
|
| 536 |
+
if isinstance(template_suffix, str):
|
| 537 |
+
# [-1:]: fix OpenGVLab/Mini-InternVL-Chat-4B-V1-5
|
| 538 |
+
template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:]
|
| 539 |
+
generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
|
| 540 |
+
if 'spaces_between_special_tokens' not in decode_kwargs:
|
| 541 |
+
decode_kwargs['spaces_between_special_tokens'] = False
|
| 542 |
+
return tokenizer.decode(generate_ids, **decode_kwargs)
|
| 543 |
+
|
| 544 |
+
def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
|
| 545 |
+
generation_config = generate_kwargs['generation_config']
|
| 546 |
+
stop_words = getattr(generation_config, 'stop_words', None) or self.template_meta.stop_words
|
| 547 |
+
generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, stop_words)])
|
| 548 |
+
return generate_kwargs
|
| 549 |
+
|
| 550 |
+
@staticmethod
|
| 551 |
+
def _save_pil_image(image: Image.Image) -> str:
|
| 552 |
+
img_bytes = image.tobytes()
|
| 553 |
+
img_hash = hashlib.sha256(img_bytes).hexdigest()
|
| 554 |
+
tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
|
| 555 |
+
logger.info_once(f'create tmp_dir: {tmp_dir}')
|
| 556 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 557 |
+
img_path = os.path.join(tmp_dir, f'{img_hash}.png')
|
| 558 |
+
if not os.path.exists(img_path):
|
| 559 |
+
image.save(img_path)
|
| 560 |
+
return img_path
|
| 561 |
+
|
| 562 |
+
@staticmethod
|
| 563 |
+
def _concat_context_list(
|
| 564 |
+
context_list: List[Context],
|
| 565 |
+
res_context_list: List[Context], # inplace
|
| 566 |
+
res_context_type: List[ContextType], # inplace
|
| 567 |
+
system: Optional[str] = None,
|
| 568 |
+
query: Optional[str] = None,
|
| 569 |
+
response: Optional[str] = None,
|
| 570 |
+
round0: Optional[int] = None) -> None:
|
| 571 |
+
"""Concat context list and replace placeholder"""
|
| 572 |
+
round1 = None
|
| 573 |
+
if round0 is not None:
|
| 574 |
+
round1 = str(round0 + 1)
|
| 575 |
+
round0 = str(round0)
|
| 576 |
+
for context in context_list:
|
| 577 |
+
if isinstance(context, str):
|
| 578 |
+
if '{{RESPONSE}}' == context:
|
| 579 |
+
assert response is not None
|
| 580 |
+
res_context_list.append(response)
|
| 581 |
+
res_context_type.append(ContextType.RESPONSE)
|
| 582 |
+
continue
|
| 583 |
+
old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
|
| 584 |
+
new_str_list = [system, query, round0, round1]
|
| 585 |
+
for (old_str, new_str) in zip(old_str_list, new_str_list):
|
| 586 |
+
if new_str is not None and old_str in context:
|
| 587 |
+
assert isinstance(new_str, str), f'new_str: {new_str}'
|
| 588 |
+
context = context.replace(old_str, new_str)
|
| 589 |
+
if len(context) == 0:
|
| 590 |
+
continue
|
| 591 |
+
res_context_list.append(context)
|
| 592 |
+
res_context_type.append(ContextType.OTHER)
|
| 593 |
+
|
| 594 |
+
def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
|
| 595 |
+
inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
|
| 596 |
+
"""Merge anything in the context to simplify the inputs"""
|
| 597 |
+
context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
|
| 598 |
+
context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)
|
| 599 |
+
|
| 600 |
+
res: List[Context] = [] # result of context_list
|
| 601 |
+
res_loss_scale: List[float] = [] # result of loss_scale_list
|
| 602 |
+
temp: List[str] = []
|
| 603 |
+
temp_loss_scale = 0.
|
| 604 |
+
for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
|
| 605 |
+
if isinstance(context, str) and (loss_scale == temp_loss_scale):
|
| 606 |
+
temp.append(context)
|
| 607 |
+
else:
|
| 608 |
+
if len(temp) > 0:
|
| 609 |
+
res.append(''.join(temp))
|
| 610 |
+
res_loss_scale.append(temp_loss_scale)
|
| 611 |
+
temp.clear()
|
| 612 |
+
if isinstance(context, str): # loss_scale diff
|
| 613 |
+
temp.append(context)
|
| 614 |
+
else:
|
| 615 |
+
res.append(context)
|
| 616 |
+
res_loss_scale.append(loss_scale)
|
| 617 |
+
temp_loss_scale = loss_scale
|
| 618 |
+
if len(temp) > 0:
|
| 619 |
+
res.append(''.join(temp))
|
| 620 |
+
res_loss_scale.append(temp_loss_scale)
|
| 621 |
+
|
| 622 |
+
return res, res_loss_scale
|
| 623 |
+
|
| 624 |
+
@staticmethod
|
| 625 |
+
def _split_special_tokens(context_list: List[Context],
|
| 626 |
+
loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
|
| 627 |
+
"""Split special tokens, for example `<image>`, `<video>`, this will help the replace_tag operation"""
|
| 628 |
+
res: List[Context] = []
|
| 629 |
+
loss_scale_res: List[float] = []
|
| 630 |
+
for context, loss_scale in zip(context_list, loss_scale_list):
|
| 631 |
+
contexts = []
|
| 632 |
+
if isinstance(fetch_one(context), str):
|
| 633 |
+
for d in split_str_parts_by(context, Template.special_tokens):
|
| 634 |
+
contexts.extend([d['key'], d['content']])
|
| 635 |
+
contexts = [c for c in contexts if c]
|
| 636 |
+
res.extend(contexts)
|
| 637 |
+
loss_scale_res.extend([loss_scale] * len(contexts))
|
| 638 |
+
else:
|
| 639 |
+
res.append(context)
|
| 640 |
+
loss_scale_res.append(loss_scale)
|
| 641 |
+
return res, loss_scale_res
|
| 642 |
+
|
| 643 |
+
def _tokenize(self, context, **tokenizer_kwargs):
|
| 644 |
+
return self.tokenizer(
|
| 645 |
+
context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']
|
| 646 |
+
|
| 647 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 648 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 649 |
+
"""Override this function to do your own replace operation.
|
| 650 |
+
|
| 651 |
+
This method is used to replace standard tags like `<image>` to some tokens that the model needs.
|
| 652 |
+
|
| 653 |
+
Args:
|
| 654 |
+
media_type: The modal.
|
| 655 |
+
index: The index of the medias, for index 0 represents the first elements in `images`
|
| 656 |
+
inputs: The inputs
|
| 657 |
+
|
| 658 |
+
Returns:
|
| 659 |
+
The content or input_ids after replacement.
|
| 660 |
+
"""
|
| 661 |
+
if media_type == 'image':
|
| 662 |
+
if self.mode == 'lmdeploy':
|
| 663 |
+
return [[-100]]
|
| 664 |
+
return self.image_placeholder
|
| 665 |
+
elif media_type == 'video':
|
| 666 |
+
return self.video_placeholder
|
| 667 |
+
elif media_type == 'audio':
|
| 668 |
+
return self.audio_placeholder
|
| 669 |
+
|
| 670 |
+
def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 671 |
+
"""Replace objects referenced by the bbox to contents or input_ids. This is useful in the grounding task.
|
| 672 |
+
Override this function to do your own replace operation.
|
| 673 |
+
|
| 674 |
+
Args:
|
| 675 |
+
ref: Description of the bbox
|
| 676 |
+
index: The index in the `objects` key
|
| 677 |
+
inputs: The inputs
|
| 678 |
+
|
| 679 |
+
Returns:
|
| 680 |
+
The contents or input_ids replaced
|
| 681 |
+
"""
|
| 682 |
+
return [ref]
|
| 683 |
+
|
| 684 |
+
def replace_cot_process(self, inputs: StdTemplateInputs) -> List[Context]:
|
| 685 |
+
"""Replace the cot process label for PRM training or inference.
|
| 686 |
+
Override this function to do your own replace operation.
|
| 687 |
+
|
| 688 |
+
Args:
|
| 689 |
+
inputs: The inputs
|
| 690 |
+
|
| 691 |
+
Returns:
|
| 692 |
+
The contents or input_ids replaced
|
| 693 |
+
"""
|
| 694 |
+
return [self.cot_process_placeholder]
|
| 695 |
+
|
| 696 |
+
@staticmethod
|
| 697 |
+
def _get_bbox_str(bbox: List[int]) -> str:
|
| 698 |
+
point = []
|
| 699 |
+
for x, y in zip(bbox[::2], bbox[1::2]):
|
| 700 |
+
point.append(f'({x},{y})')
|
| 701 |
+
return ','.join(point)
|
| 702 |
+
|
| 703 |
+
def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 704 |
+
"""Replace bbox pointing to the objects to contents or input_ids. This is useful in the grounding task.
|
| 705 |
+
Override this function to do your own replace operation.
|
| 706 |
+
|
| 707 |
+
Args:
|
| 708 |
+
bbox: [x, y] or [x1, y1, x2, y2]
|
| 709 |
+
index: The index in the `objects` key
|
| 710 |
+
inputs: The inputs
|
| 711 |
+
|
| 712 |
+
Returns:
|
| 713 |
+
The contents or input_ids replaced
|
| 714 |
+
"""
|
| 715 |
+
return [f'[{self._get_bbox_str(bbox)}]']
|
| 716 |
+
|
| 717 |
+
def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: List[float],
|
| 718 |
+
inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
|
| 719 |
+
# https://github.com/modelscope/ms-swift/issues/3407
|
| 720 |
+
# Fix the bounding box position offset issue in the Qwen2.5-VL grounding task.
|
| 721 |
+
res: List[Context] = []
|
| 722 |
+
res_loss_scale: List[float] = []
|
| 723 |
+
inputs.image_idx = 0
|
| 724 |
+
|
| 725 |
+
for context, loss_scale in zip(context_list, loss_scale_list):
|
| 726 |
+
if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
|
| 727 |
+
c_list = self.replace_tag('image', inputs.image_idx, inputs)
|
| 728 |
+
inputs.image_idx += 1
|
| 729 |
+
loss_scale = 0. if self.template_backend == 'swift' else 1.
|
| 730 |
+
else:
|
| 731 |
+
c_list = [context]
|
| 732 |
+
res += c_list
|
| 733 |
+
res_loss_scale += [loss_scale] * len(c_list)
|
| 734 |
+
return res, res_loss_scale
|
| 735 |
+
|
| 736 |
+
def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
|
| 737 |
+
inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
|
| 738 |
+
"""This method happens before tokenization, replace standard tags to the contents or input_ids needed by
|
| 739 |
+
the model.
|
| 740 |
+
|
| 741 |
+
Args:
|
| 742 |
+
context_list: The content list
|
| 743 |
+
loss_scale_list: The loss scale list
|
| 744 |
+
Returns:
|
| 745 |
+
The context_list and loss_scale_list after replacement.
|
| 746 |
+
"""
|
| 747 |
+
context_list, loss_scale_list = self._pre_tokenize_images(context_list, loss_scale_list, inputs)
|
| 748 |
+
if inputs.images and inputs.objects:
|
| 749 |
+
self.normalize_bbox(inputs)
|
| 750 |
+
# replace tag/object/box
|
| 751 |
+
res: List[Context] = [] # result of context_list
|
| 752 |
+
res_loss_scale: List[float] = [] # result of loss_scale_list
|
| 753 |
+
|
| 754 |
+
# reset
|
| 755 |
+
for k in ['video', 'audio', 'object', 'box']:
|
| 756 |
+
setattr(inputs, f'{k}_idx', 0)
|
| 757 |
+
|
| 758 |
+
for context, loss_scale in zip(context_list, loss_scale_list):
|
| 759 |
+
for k in ['video', 'audio']:
|
| 760 |
+
if context == f'<{k}>' and inputs.is_multimodal and getattr(inputs, f'{k}_idx') < len(
|
| 761 |
+
getattr(inputs, f'{k}s')):
|
| 762 |
+
c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
|
| 763 |
+
setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
|
| 764 |
+
loss_scale = 0.
|
| 765 |
+
break
|
| 766 |
+
else:
|
| 767 |
+
ref = inputs.objects.get('ref') or []
|
| 768 |
+
bbox = inputs.objects.get('bbox') or []
|
| 769 |
+
if context == '<ref-object>' and inputs.ref_idx < len(ref):
|
| 770 |
+
idx = inputs.ref_idx
|
| 771 |
+
c_list = self.replace_ref(ref[idx], idx, inputs)
|
| 772 |
+
inputs.ref_idx += 1
|
| 773 |
+
elif context == '<bbox>' and inputs.bbox_idx < len(bbox):
|
| 774 |
+
idx = inputs.bbox_idx
|
| 775 |
+
c_list = self.replace_bbox(bbox[idx], idx, inputs)
|
| 776 |
+
inputs.bbox_idx += 1
|
| 777 |
+
elif context == '<cot-process>' and self.mode == 'prm':
|
| 778 |
+
c_list = self.replace_cot_process(inputs)
|
| 779 |
+
else:
|
| 780 |
+
c_list = [context]
|
| 781 |
+
res += c_list
|
| 782 |
+
res_loss_scale += [loss_scale] * len(c_list)
|
| 783 |
+
return res, res_loss_scale
|
| 784 |
+
|
| 785 |
+
@staticmethod
|
| 786 |
+
def _add_default_tags(inputs: StdTemplateInputs):
|
| 787 |
+
total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
|
| 788 |
+
if inputs.rejected_response:
|
| 789 |
+
if isinstance(inputs.rejected_response, str):
|
| 790 |
+
total_content += inputs.rejected_response
|
| 791 |
+
else:
|
| 792 |
+
total_content += '\n'.join(inputs.rejected_response)
|
| 793 |
+
if inputs.system:
|
| 794 |
+
total_content = f'{inputs.system}\n{total_content}'
|
| 795 |
+
for media_type in ['image', 'audio', 'video']:
|
| 796 |
+
media_key, media_tag = f'{media_type}s', f'<{media_type}>'
|
| 797 |
+
medias = getattr(inputs, media_key)
|
| 798 |
+
if not isinstance(medias, list):
|
| 799 |
+
medias = [medias]
|
| 800 |
+
if medias:
|
| 801 |
+
num_media_tags = len(re.findall(media_tag, total_content))
|
| 802 |
+
num_media = len(medias)
|
| 803 |
+
num_new_tags = num_media - num_media_tags
|
| 804 |
+
if num_new_tags > 0:
|
| 805 |
+
inputs.messages[0]['content'] = media_tag * num_new_tags + inputs.messages[0]['content']
|
| 806 |
+
elif num_new_tags < 0:
|
| 807 |
+
logger.warning(
|
| 808 |
+
f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
|
| 809 |
+
'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')
|
| 810 |
+
|
| 811 |
+
def _encode_context_list(
|
| 812 |
+
self,
|
| 813 |
+
context_list: List[Context],
|
| 814 |
+
loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
|
| 815 |
+
"""return: input_ids, labels, tokenizer_kwargs"""
|
| 816 |
+
input_ids: List[int] = []
|
| 817 |
+
labels: List[int] = []
|
| 818 |
+
loss_scale: List[float] = []
|
| 819 |
+
tokenizer_kwargs = {}
|
| 820 |
+
if loss_scale_list is None:
|
| 821 |
+
loss_scale_list = [0.] * len(context_list)
|
| 822 |
+
ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
|
| 823 |
+
for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
|
| 824 |
+
if isinstance(context, str):
|
| 825 |
+
# tokenizer_kwargs is the returned tokenizer_kwargs,
|
| 826 |
+
# while curr_tokenizer_kwargs is the tokenizer_kwargs for the current context.
|
| 827 |
+
token_list = self._tokenize(context)
|
| 828 |
+
else:
|
| 829 |
+
token_list = context
|
| 830 |
+
input_ids += token_list
|
| 831 |
+
if loss_scale_list[i] > 0.0:
|
| 832 |
+
labels += token_list
|
| 833 |
+
else:
|
| 834 |
+
labels += [-100] * len(token_list)
|
| 835 |
+
if not ignore_loss_scale:
|
| 836 |
+
loss_scale.extend([loss_weight] * len(token_list))
|
| 837 |
+
if ignore_loss_scale:
|
| 838 |
+
loss_scale = None
|
| 839 |
+
return input_ids, labels, loss_scale, tokenizer_kwargs
|
| 840 |
+
|
| 841 |
+
@staticmethod
|
| 842 |
+
def _add_dynamic_eos(input_ids: List[int], labels: List[int], loss_scale: Optional[List[int]],
|
| 843 |
+
suffix_tokens_id: List[int]) -> None:
|
| 844 |
+
suffix_len = len(suffix_tokens_id)
|
| 845 |
+
start = 0
|
| 846 |
+
for i in range(1, len(labels)):
|
| 847 |
+
if labels[i - 1] >= 0 and labels[i] == -100:
|
| 848 |
+
start = i
|
| 849 |
+
if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
|
| 850 |
+
# [0, 1, 2, -100(start), -100, 3(i), 4]
|
| 851 |
+
length = i - start
|
| 852 |
+
if length >= suffix_len and input_ids[start:start + suffix_len] == suffix_tokens_id:
|
| 853 |
+
labels[start:start + suffix_len] = suffix_tokens_id
|
| 854 |
+
if loss_scale and loss_scale[start:start + suffix_len] == [0] * suffix_len:
|
| 855 |
+
loss_scale[start:start + suffix_len] = [1] * suffix_len
|
| 856 |
+
|
| 857 |
+
@staticmethod
|
| 858 |
+
def _get_std_messages(messages):
|
| 859 |
+
if messages and messages[0]['role'] == 'assistant':
|
| 860 |
+
messages.insert(0, {'role': 'user', 'content': ''}) # pretrain
|
| 861 |
+
if len(messages) % 2 == 1:
|
| 862 |
+
messages.append({'role': 'assistant', 'content': None}) # inference
|
| 863 |
+
|
| 864 |
+
def _jinja_encode(self, inputs: StdTemplateInputs):
|
| 865 |
+
messages = inputs.messages.copy()
|
| 866 |
+
if inputs.system is not None:
|
| 867 |
+
messages.insert(0, {'role': 'system', 'content': inputs.system})
|
| 868 |
+
if messages[-1]['content'] is None:
|
| 869 |
+
messages.pop()
|
| 870 |
+
add_generation_prompt = messages[-1]['role'] != 'assistant'
|
| 871 |
+
kwargs = {}
|
| 872 |
+
if inputs.tools:
|
| 873 |
+
kwargs['tools'] = inputs.tools
|
| 874 |
+
text = self.tokenizer.apply_chat_template(
|
| 875 |
+
messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
|
| 876 |
+
answer_len = 1 if self.is_training else 0
|
| 877 |
+
return [text], [1.], answer_len
|
| 878 |
+
|
| 879 |
+
def _get_system(self, inputs) -> Optional[str]:
|
| 880 |
+
template_meta = self.template_meta
|
| 881 |
+
system = inputs.system
|
| 882 |
+
tools = inputs.tools
|
| 883 |
+
template_meta.check_system(system)
|
| 884 |
+
if system is None:
|
| 885 |
+
system = template_meta.default_system
|
| 886 |
+
|
| 887 |
+
if tools is not None:
|
| 888 |
+
system = self.agent_template._format_tools(tools, system or '', inputs.messages[0])
|
| 889 |
+
return system
|
| 890 |
+
|
| 891 |
+
@staticmethod
|
| 892 |
+
def _swift_prepare_function_call(agent_template, messages):
|
| 893 |
+
if len(messages) < 2:
|
| 894 |
+
return
|
| 895 |
+
i = 1
|
| 896 |
+
while i < len(messages):
|
| 897 |
+
pre_message, message = messages[i - 1], messages[i]
|
| 898 |
+
pre_role, pre_content = pre_message['role'], pre_message['content']
|
| 899 |
+
role, content = message['role'], message['content']
|
| 900 |
+
if pre_role == 'assistant' and role == 'tool':
|
| 901 |
+
i_start = i
|
| 902 |
+
while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
|
| 903 |
+
i += 1
|
| 904 |
+
pre_message['content'], tool_content = agent_template._format_tool_responses(
|
| 905 |
+
pre_content, messages[i_start:i + 1])
|
| 906 |
+
messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
|
| 907 |
+
i = i_start + 1
|
| 908 |
+
elif pre_role == 'assistant' and role == 'assistant':
|
| 909 |
+
# Consecutive messages from the assistant role need to be merged to prevent errors.
|
| 910 |
+
pre_message['content'] = pre_content + content
|
| 911 |
+
messages.pop(i)
|
| 912 |
+
else:
|
| 913 |
+
i += 1
|
| 914 |
+
|
| 915 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 916 |
+
template_meta = self.template_meta
|
| 917 |
+
system = self._get_system(inputs)
|
| 918 |
+
self._swift_prepare_function_call(self.agent_template, inputs.messages)
|
| 919 |
+
|
| 920 |
+
self._get_std_messages(inputs.messages)
|
| 921 |
+
n_round = len(inputs.messages) // 2
|
| 922 |
+
if n_round > 1 and not self.template_meta.support_multi_round:
|
| 923 |
+
logger.warning_once(
|
| 924 |
+
'The template does not support multi-round chat. Only use the last round of the conversation.')
|
| 925 |
+
inputs.messages = inputs.messages[-2:]
|
| 926 |
+
|
| 927 |
+
res_context_list: List[Context] = []
|
| 928 |
+
res_context_types: List[ContextType] = []
|
| 929 |
+
sep_token = None
|
| 930 |
+
if template_meta.auto_add_bos:
|
| 931 |
+
all_tokens = self.tokenizer.encode('a')
|
| 932 |
+
single_token = self.tokenizer.encode('a', add_special_tokens=False)
|
| 933 |
+
assert len(single_token) == 1
|
| 934 |
+
idx = all_tokens.index(single_token[0])
|
| 935 |
+
bos_token = all_tokens[:idx]
|
| 936 |
+
sep_token = all_tokens[idx + 1:]
|
| 937 |
+
if bos_token:
|
| 938 |
+
res_context_list.append(bos_token)
|
| 939 |
+
res_context_types.append(ContextType.OTHER)
|
| 940 |
+
|
| 941 |
+
prefix = template_meta.system_prefix if system else template_meta.prefix
|
| 942 |
+
self._concat_context_list(prefix, res_context_list, res_context_types, system=system)
|
| 943 |
+
|
| 944 |
+
n_round = len(inputs.messages) // 2
|
| 945 |
+
for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
|
| 946 |
+
query_role, query = query_message['role'], query_message['content']
|
| 947 |
+
response_role, response = response_message['role'], response_message['content']
|
| 948 |
+
# TODO: Optimize the Template mechanism.
|
| 949 |
+
assert query_role in {'user', 'tool'}, f'query_role: {query_role}'
|
| 950 |
+
assert response_role in {'assistant'}, f'response_role: {response_role}'
|
| 951 |
+
if query_role == 'tool':
|
| 952 |
+
prompt = query
|
| 953 |
+
query = ''
|
| 954 |
+
elif template_meta.is_post_system and i == n_round - 1:
|
| 955 |
+
prompt = template_meta.system_prompt
|
| 956 |
+
else:
|
| 957 |
+
prompt = template_meta.prompt
|
| 958 |
+
|
| 959 |
+
context_list = prompt.copy()
|
| 960 |
+
extra_context_list = []
|
| 961 |
+
extra_context_type = None
|
| 962 |
+
if i < n_round - 1:
|
| 963 |
+
# Not the last round.
|
| 964 |
+
context_list.append('{{RESPONSE}}')
|
| 965 |
+
if inputs.messages[2 * (i + 1)]['role'] != 'tool':
|
| 966 |
+
extra_context_list = template_meta.chat_sep
|
| 967 |
+
extra_context_type = ContextType.OTHER
|
| 968 |
+
elif response is not None:
|
| 969 |
+
# It is the final round, and the response exists (during training).
|
| 970 |
+
context_list.append('{{RESPONSE}}')
|
| 971 |
+
if self.is_training and not sep_token:
|
| 972 |
+
extra_context_list = template_meta.suffix
|
| 973 |
+
extra_context_type = ContextType.SUFFIX
|
| 974 |
+
elif template_meta.response_prefix:
|
| 975 |
+
# final round and during inference.
|
| 976 |
+
context_list.append(template_meta.response_prefix)
|
| 977 |
+
|
| 978 |
+
self._concat_context_list(
|
| 979 |
+
context_list,
|
| 980 |
+
res_context_list,
|
| 981 |
+
res_context_types,
|
| 982 |
+
query=query,
|
| 983 |
+
response=response,
|
| 984 |
+
system=system,
|
| 985 |
+
round0=i)
|
| 986 |
+
res_context_list += extra_context_list
|
| 987 |
+
res_context_types += [extra_context_type] * len(extra_context_list)
|
| 988 |
+
if template_meta.auto_add_bos and sep_token:
|
| 989 |
+
res_context_list.append(sep_token)
|
| 990 |
+
res_context_types.append(ContextType.SUFFIX)
|
| 991 |
+
from swift.plugin import loss_scale_map
|
| 992 |
+
res_context_list, loss_scale_list = loss_scale_map[self.loss_scale](res_context_list, res_context_types,
|
| 993 |
+
inputs.messages)
|
| 994 |
+
if self.is_training:
|
| 995 |
+
answer_len = len(extra_context_list) + bool(response is not None)
|
| 996 |
+
else:
|
| 997 |
+
answer_len = 0
|
| 998 |
+
return res_context_list, loss_scale_list, answer_len
|
| 999 |
+
|
| 1000 |
+
def _encode_truncated(self, inputs):
|
| 1001 |
+
if self.mode in {'vllm', 'lmdeploy'}:
|
| 1002 |
+
encoded = Template._encode(self, inputs)
|
| 1003 |
+
for key in ['images', 'audios', 'videos']:
|
| 1004 |
+
encoded[key] = getattr(inputs, key)
|
| 1005 |
+
else:
|
| 1006 |
+
encoded = self._encode(inputs)
|
| 1007 |
+
|
| 1008 |
+
input_ids = encoded.get('input_ids')
|
| 1009 |
+
labels = encoded.get('labels')
|
| 1010 |
+
loss_scale = encoded.get('loss_scale')
|
| 1011 |
+
if self.max_length is not None:
|
| 1012 |
+
if self.truncation_strategy == 'right':
|
| 1013 |
+
input_ids = input_ids[:self.max_length]
|
| 1014 |
+
if labels is not None:
|
| 1015 |
+
labels = labels[:self.max_length]
|
| 1016 |
+
if loss_scale is not None:
|
| 1017 |
+
loss_scale = loss_scale[:self.max_length]
|
| 1018 |
+
elif self.truncation_strategy == 'left':
|
| 1019 |
+
if len(input_ids) > self.max_length:
|
| 1020 |
+
logger.warning_once(
|
| 1021 |
+
'Input data was left-truncated because its length exceeds `max_length` (input length: '
|
| 1022 |
+
f'{len(input_ids)}, max_length: {self.max_length}). '
|
| 1023 |
+
'This may cause loss of important tokens (e.g., image tokens) and lead to errors. '
|
| 1024 |
+
'To avoid this, consider increasing `max_length` or pre-filtering long sequences.',
|
| 1025 |
+
hash_id='max_length_check')
|
| 1026 |
+
input_ids = input_ids[-self.max_length:]
|
| 1027 |
+
if labels is not None:
|
| 1028 |
+
labels = labels[-self.max_length:]
|
| 1029 |
+
if loss_scale is not None:
|
| 1030 |
+
loss_scale = loss_scale[-self.max_length:]
|
| 1031 |
+
elif self.truncation_strategy == 'raise':
|
| 1032 |
+
length = len(input_ids or labels or [])
|
| 1033 |
+
if length > self.max_length:
|
| 1034 |
+
raise MaxLengthError(f'Current length of row({length}) is larger'
|
| 1035 |
+
f' than the max_length({self.max_length}).')
|
| 1036 |
+
encoded['input_ids'] = input_ids
|
| 1037 |
+
encoded['labels'] = labels
|
| 1038 |
+
encoded['loss_scale'] = loss_scale
|
| 1039 |
+
return encoded
|
| 1040 |
+
|
| 1041 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 1042 |
+
template_backend = self.template_backend
|
| 1043 |
+
if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
|
| 1044 |
+
and self.mode != 'seq_cls'):
|
| 1045 |
+
template_backend = 'jinja'
|
| 1046 |
+
logger.info_once(f'Setting template_backend: {template_backend}')
|
| 1047 |
+
res_context_list, loss_scale_list, answer_len = (
|
| 1048 |
+
self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
|
| 1049 |
+
encoded = {}
|
| 1050 |
+
if self.is_encoder_decoder:
|
| 1051 |
+
# tokenizer_kwargs: use prompt (qwen-audio)
|
| 1052 |
+
total_len = len(res_context_list)
|
| 1053 |
+
for key, _slice in zip(['prompt', 'answer'],
|
| 1054 |
+
[slice(0, total_len - answer_len),
|
| 1055 |
+
slice(total_len - answer_len, total_len)]):
|
| 1056 |
+
context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
|
| 1057 |
+
loss_scale_list[_slice], inputs)
|
| 1058 |
+
input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
|
| 1059 |
+
encoded[f'{key}_input_ids'] = input_ids
|
| 1060 |
+
if key == 'answer':
|
| 1061 |
+
encoded['labels'] = labels
|
| 1062 |
+
encoded['loss_scale'] = loss_scale
|
| 1063 |
+
input_ids = encoded['prompt_input_ids'] + encoded['answer_input_ids']
|
| 1064 |
+
else:
|
| 1065 |
+
res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
|
| 1066 |
+
input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
|
| 1067 |
+
res_context_list, loss_scale_list)
|
| 1068 |
+
self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0])
|
| 1069 |
+
|
| 1070 |
+
if tokenizer_kwargs:
|
| 1071 |
+
encoded['tokenizer_kwargs'] = tokenizer_kwargs
|
| 1072 |
+
|
| 1073 |
+
encoded['input_ids'] = input_ids
|
| 1074 |
+
encoded['labels'] = labels
|
| 1075 |
+
encoded['loss_scale'] = loss_scale
|
| 1076 |
+
if self.use_megatron:
|
| 1077 |
+
self._handle_megatron_cp(encoded)
|
| 1078 |
+
encoded['labels'] = encoded['labels'][1:] + [-100]
|
| 1079 |
+
encoded['position_ids'] = list(range(len(encoded['labels'])))
|
| 1080 |
+
elif encoded.get('labels') is not None:
|
| 1081 |
+
encoded['labels'][0] = -100
|
| 1082 |
+
if not self.is_training:
|
| 1083 |
+
for k in list(encoded.keys()):
|
| 1084 |
+
if k.endswith('labels') or k.endswith('loss_scale'):
|
| 1085 |
+
encoded[k] = None
|
| 1086 |
+
return encoded
|
| 1087 |
+
|
| 1088 |
+
def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None:
|
| 1089 |
+
cp_size = self.sequence_parallel_size
|
| 1090 |
+
if cp_size == 1:
|
| 1091 |
+
return
|
| 1092 |
+
input_ids = encoded['input_ids']
|
| 1093 |
+
padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids)
|
| 1094 |
+
input_ids += [self.tokenizer.pad_token_id] * padding_len
|
| 1095 |
+
encoded['labels'] += [-100] * padding_len
|
| 1096 |
+
|
| 1097 |
+
def debug_logger(self, inputs):
|
| 1098 |
+
if not strtobool(os.getenv('SWIFT_DEBUG', 'false')):
|
| 1099 |
+
return
|
| 1100 |
+
if 'input_ids' in inputs:
|
| 1101 |
+
k = 'input_ids'
|
| 1102 |
+
val = inputs['input_ids']
|
| 1103 |
+
elif 'generate_ids' in inputs:
|
| 1104 |
+
k = 'generate_ids'
|
| 1105 |
+
val = inputs['generate_ids']
|
| 1106 |
+
for v in val:
|
| 1107 |
+
self.print_inputs({k: v.tolist()})
|
| 1108 |
+
|
| 1109 |
+
@staticmethod
|
| 1110 |
+
def _split_list(inputs: List[int], x: int) -> List[List[int]]:
|
| 1111 |
+
idxs = findall(inputs, x)
|
| 1112 |
+
idxs.append(len(inputs))
|
| 1113 |
+
res = []
|
| 1114 |
+
lo = 0
|
| 1115 |
+
for idx in idxs:
|
| 1116 |
+
res.append(inputs[lo:idx])
|
| 1117 |
+
lo = idx + 1
|
| 1118 |
+
return res
|
| 1119 |
+
|
| 1120 |
+
def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
|
| 1121 |
+
context_list = []
|
| 1122 |
+
if self.mode in {'vllm', 'lmdeploy'}:
|
| 1123 |
+
video = inputs.videos.pop(inputs.video_idx)
|
| 1124 |
+
inputs.video_idx -= 1
|
| 1125 |
+
else:
|
| 1126 |
+
video = inputs.videos[inputs.video_idx]
|
| 1127 |
+
images = inputs.images
|
| 1128 |
+
new_images = load_video_func(video)
|
| 1129 |
+
inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
|
| 1130 |
+
for i in range(len(new_images)):
|
| 1131 |
+
context_list += replace_tag(i)
|
| 1132 |
+
inputs.image_idx += len(new_images)
|
| 1133 |
+
return context_list
|
| 1134 |
+
|
| 1135 |
+
def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]],
                     num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]:
    """Return the generated ids, stripping the prompt prefix along the last
    dimension when `skip_prompt` is set."""
    if not self.skip_prompt:
        return generate_ids
    return generate_ids[..., num_prompt_tokens:]
|
| 1140 |
+
|
| 1141 |
+
def post_process_generate_response(self, response: str, inputs: StdTemplateInputs) -> str:
    """Post-process hook for the decoded response; identity by default.

    Subclasses override this to rewrite model output before it is returned
    to the caller.
    """
    result = response
    return result
|
| 1143 |
+
|
| 1144 |
+
def pre_forward_hook(self, model: nn.Module, args, kwargs):
    """Forward pre-hook that runs `_post_encode` and merges its output into the
    model's kwargs (used to turn `input_ids` into `inputs_embeds` for
    multimodal models).

    Keeps essential keys from the original kwargs when `_post_encode` drops
    them, removes `input_ids` once `inputs_embeds` is present, and strips
    `position_ids` if the wrapped model's forward does not accept it.
    """
    from swift.llm import to_device
    old_kwargs = to_device(kwargs, model.device)
    kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
    # Carry over keys _post_encode did not re-emit.
    for k, v in old_kwargs.items():
        if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
            kwargs[k] = v
    if 'inputs_embeds' in kwargs:
        kwargs.pop('input_ids', None)  # embeds and ids are mutually exclusive

    # PeftModel wraps the real model; inspect the inner forward signature.
    if isinstance(model, PeftModel):
        parameters = inspect.signature(model.model.forward).parameters
    else:
        parameters = inspect.signature(model.forward).parameters
    if 'position_ids' not in parameters:
        kwargs.pop('position_ids', None)
    return args, kwargs
|
| 1161 |
+
|
| 1162 |
+
@property
def is_training(self):
    """True unless the template is in a pure-inference mode ('vllm', 'lmdeploy', 'pt')."""
    return self.mode not in {'vllm', 'lmdeploy', 'pt'}
|
| 1165 |
+
|
| 1166 |
+
def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
    """Set the operating mode, which controls encoding and collation behavior."""
    self.mode = mode
|
| 1168 |
+
|
| 1169 |
+
def register_post_encode_hook(self, models: List[nn.Module]) -> None:
    """This function is important for multi-modal training, as it registers the post_encode method
    as a forward hook, converting input_ids into inputs_embeds.
    """
    if self._handles:
        return  # already registered; do not double-hook

    for model in models:
        # please use torch>=2.0
        handle = model.register_forward_pre_hook(self.pre_forward_hook, with_kwargs=True)
        self._handles.append((model, handle))

    if is_deepspeed_zero3_enabled():
        import deepspeed
        # Patch deepspeed.initialize so that after DeepSpeed registers its own
        # pre-hooks, ours is moved to the end of the hook order.
        self._deepspeed_initialize = deepspeed.initialize

        @wraps(self._deepspeed_initialize)
        def _initialize(*args, **kwargs):
            res = self._deepspeed_initialize(*args, **kwargs)
            for model, handle in self._handles:
                model._forward_pre_hooks.move_to_end(handle.id)
            return res

        deepspeed.initialize = _initialize
|
| 1193 |
+
|
| 1194 |
+
def remove_post_encode_hook(self):
    """Remove every hook added by `register_post_encode_hook`, restore
    `deepspeed.initialize` if it was patched, and return the unhooked models."""
    models = []
    for model, handle in self._handles:
        models.append(model)
        handle.remove()
    self._handles = []

    if self._deepspeed_initialize is not None:
        import deepspeed
        deepspeed.initialize = self._deepspeed_initialize
        self._deepspeed_initialize = None
    return models
|
| 1206 |
+
|
| 1207 |
+
def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Dispatch to the collator matching the current `mode`.

    NOTE(review): a mode not listed here falls through and returns None
    implicitly — confirm callers never invoke this outside the listed modes.
    """
    if self.mode == 'rlhf':
        return self._rlhf_data_collator(batch, padding_to=padding_to)
    elif self.mode == 'kto':
        return self._kto_data_collator(batch, padding_to=padding_to)
    elif self.mode in {'pt', 'train', 'prm'}:
        return self._data_collator(batch, padding_to=padding_to)
    elif self.mode == 'seq_cls':
        return self._seq_cls_data_collator(batch, padding_to=padding_to)
    elif self.mode == 'embedding':
        return self._embedding_data_collator(batch, padding_to=padding_to)
|
| 1218 |
+
|
| 1219 |
+
@staticmethod
|
| 1220 |
+
def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]:
|
| 1221 |
+
new_batch = []
|
| 1222 |
+
for inputs in batch:
|
| 1223 |
+
new_inputs = {}
|
| 1224 |
+
for k, v in inputs.items():
|
| 1225 |
+
if k.startswith(prefix):
|
| 1226 |
+
new_inputs[k[len(prefix):]] = v
|
| 1227 |
+
new_batch.append(new_inputs)
|
| 1228 |
+
return new_batch
|
| 1229 |
+
|
| 1230 |
+
@staticmethod
def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]:
    """Collect the requested `keys` across the batch into column-oriented
    lists, skipping keys whose column is entirely absent."""
    from swift.llm import RowPreprocessor
    keys = keys or []
    rows = RowPreprocessor.rows_to_batched(batch)
    return {k: rows[k] for k in keys if rows.get(k) is not None}
|
| 1236 |
+
|
| 1237 |
+
@staticmethod
|
| 1238 |
+
def gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
|
| 1239 |
+
# List[Tensor] -> List[Tensor]
|
| 1240 |
+
res = []
|
| 1241 |
+
for b in batch:
|
| 1242 |
+
if b.get(attr_name) is not None:
|
| 1243 |
+
res += b.pop(attr_name)
|
| 1244 |
+
return res
|
| 1245 |
+
|
| 1246 |
+
@staticmethod
|
| 1247 |
+
def concat_tensor(batch: List[Dict[str, Any]], attr_name: str, dim: int) -> Optional[torch.Tensor]:
|
| 1248 |
+
res = []
|
| 1249 |
+
for b in batch:
|
| 1250 |
+
if b.get(attr_name) is not None:
|
| 1251 |
+
res.append(b.pop(attr_name))
|
| 1252 |
+
return torch.concat(res, dim=dim) if res else None
|
| 1253 |
+
|
| 1254 |
+
def _rlhf_data_collator(self,
                        batch: List[Dict[str, Any]],
                        *,
                        chosen_prefix: str = 'chosen_',
                        rejected_prefix: str = 'rejected_',
                        padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate RLHF pairs: all chosen samples followed by all rejected samples
    are stacked into one flat batch and run through the standard collator."""
    new_batch = []
    for prefix in [chosen_prefix, rejected_prefix]:
        new_batch += self._fetch_inputs_startswith(batch, prefix)
    return self._data_collator(new_batch, padding_to=padding_to)
|
| 1264 |
+
|
| 1265 |
+
def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate KTO samples: 'chosen_*' keys form the completion batch and
    'rejected_*' keys the KL (reference) batch; each is collated separately and
    merged under 'completion_' / 'KL_completion_' key prefixes."""
    new_batch = self._fetch_inputs_startswith(batch, 'chosen_')
    kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')

    res = self._data_collator(new_batch, padding_to=padding_to)
    kl_res = self._data_collator(kl_batch, padding_to=padding_to)
    res = {
        **{f'completion_{k}': v
           for k, v in res.items()},
        **{f'KL_completion_{k}': v
           for k, v in kl_res.items()},
    }
    # Per-sample desirable/undesirable flags, kept as a plain list.
    label = [b['label'] for b in batch if b.get('label') is not None]
    if label:
        res['label'] = label
    return res
|
| 1281 |
+
|
| 1282 |
+
def _embedding_data_collator(self,
                             batch: List[Dict[str, Any]],
                             *,
                             padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate embedding-training samples: the anchor, positive and every
    negative{i}_ view of each sample are flattened into one batch and collated
    together; labels (similarity targets) are gathered as float32."""
    labels = []
    new_batch = []
    for b in batch:
        keys = [key for key in b.keys() if 'negative' in key]
        # Highest negative index present, e.g. negative0_ .. negativeN_.
        max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
        indexes = ['anchor_', 'positive_']
        if max_neg is not None:
            for i in range(0, max_neg + 1):
                indexes.append(f'negative{i}_')
        for prefix in indexes:
            new_batch += self._fetch_inputs_startswith([b], prefix)
        # NOTE(review): raises TypeError (extend(None)) when a sample has no
        # 'labels' key — confirm labels are always present in embedding mode.
        labels.extend(b.get('labels', None))
    res = self._data_collator(new_batch, padding_to=padding_to)
    if labels:
        res['labels'] = torch.tensor(labels, dtype=torch.float32)
    return res
|
| 1302 |
+
|
| 1303 |
+
def _seq_cls_data_collator(self,
                           batch: List[Dict[str, Any]],
                           *,
                           padding_to: Optional[int] = None) -> Dict[str, Any]:
    """Collate sequence-classification samples, converting labels according to
    the config's problem_type: regression -> float32, multi-label -> one-hot
    float32, otherwise class indices as int64."""
    labels = [b.pop('labels') for b in batch if b.get('labels') is not None]
    res = self._data_collator(batch, padding_to=padding_to)
    if labels:
        problem_type = self._get_problem_type(self.config)
        if problem_type == 'regression':
            labels = torch.tensor(labels, dtype=torch.float32)
        elif problem_type == 'multi_label_classification':
            # Each sample's label is a list of active class indices.
            one_hot_labels = torch.zeros((len(labels), self.config.num_labels), dtype=torch.float32)
            for i, label in enumerate(labels):
                one_hot_labels[i, label] = 1
            labels = one_hot_labels
        else:
            labels = torch.tensor(labels, dtype=torch.long)
        res['labels'] = labels
    return res
|
| 1322 |
+
|
| 1323 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
    """
    Args:
        batch(`List[Dict[str, Any]]`): The input data in batch
        padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
            will be padded to the `longest`
    """
    assert self.tokenizer.pad_token_id is not None
    padding_side = self.padding_side if self.is_training else 'left'
    padding_right = padding_side == 'right'
    # Note precedence: use_megatron OR (self._packing AND position_ids present).
    packing_mode = self.use_megatron or self._packing and 'position_ids' in batch[0]
    res = {}
    if packing_mode:
        # only support llm
        for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
            v = self.gather_list(batch, k)
            if v:
                res[k] = [v]  # one concatenated "packed" row
    else:
        inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
        input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
        if inputs_embeds:
            res['inputs_embeds'] = inputs_embeds
        if input_ids:
            res['input_ids'] = input_ids
        for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
            val = [b[key] for b in batch if b.get(key) is not None]
            if val:
                res[key] = val

    keys = [
        'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
    ]
    # Per-key padding values, aligned with `keys` above.
    pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
    # Convert to tensor and remove unnecessary dimensions.
    seq_lens = None
    for key in keys:
        if key not in res:
            continue
        for i, val in enumerate(res[key]):
            if isinstance(val, (list, tuple)):
                val = torch.tensor(val)
            elif key == 'inputs_embeds' and val.ndim == 3 or key != 'inputs_embeds' and val.ndim == 2:
                val = val[0]  # squeeze a leading batch dim of 1
            res[key][i] = val
        # seq_lens is taken from the first key present (keys share lengths).
        if not seq_lens:
            seq_lens = [seq.shape[0] for seq in res[key]]
    if not packing_mode and seq_lens and ('input_ids' in res or 'inputs_embeds' in res):
        res['attention_mask'] = [torch.ones(seq_len, dtype=torch.int64) for seq_len in seq_lens]
        if self.is_training and self.padding_side == 'left':
            res['position_ids'] = [torch.arange(seq_len, dtype=torch.int64) for seq_len in seq_lens]

    if self.use_megatron:
        padding_to = math.ceil(max(seq_lens) / 128) * 128  # align length to 128
        cp_size = self.sequence_parallel_size
        if cp_size > 1:
            padding_len = padding_to - seq_lens[0]
            position_ids = res['position_ids'][0].tolist()
            # Extend position_ids with the repeating 0..2*cp-1 pattern that
            # context parallelism expects for the padded tail.
            position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2))
            res['position_ids'][0] = torch.tensor(position_ids)

    for key, pad_value in zip(keys, pad_values):
        if key not in res:
            continue
        if self.use_megatron and key == 'position_ids' and self.sequence_parallel_size > 1:
            pass  # already padded above with the CP pattern
        elif padding_to is not None:
            # Pad the first row to `padding_to`; _pad_sequence then brings the
            # rest of the batch up to the same length.
            padding_len = padding_to - seq_lens[0]
            if padding_len > 0:
                res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
                                    'constant', pad_value)
        res[key] = self._pad_sequence(res[key], pad_value)

    # multimodal
    res.update(self._data_collator_mm_data(batch))
    if not self.use_megatron and (use_torchacc() or self.sequence_parallel_size > 1):
        res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)

    return res
|
| 1402 |
+
|
| 1403 |
+
def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Concatenate the common multimodal tensors (pixel_values, image_sizes,
    pixel_values_videos) across the batch along dim 0; absent keys are skipped."""
    # multimodal
    res = {}
    pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
    if len(pixel_values) > 0:
        res['pixel_values'] = torch.concat(pixel_values)

    image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
    if len(image_sizes) > 0:
        res['image_sizes'] = torch.concat(image_sizes)

    pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
    if len(pixel_values_videos) > 0:
        res['pixel_values_videos'] = torch.concat(pixel_values_videos)
    return res
|
| 1418 |
+
|
| 1419 |
+
def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_side):
    """Extra collation for torchacc and sequence-parallel training: pads and
    splits the already-collated tensors across ranks and writes the adjusted
    tensors back into `res`."""
    # torchacc & xtuner
    input_ids = res.get('input_ids')
    attention_mask = res.get('attention_mask')
    labels = res.get('labels')
    loss_scale = res.get('loss_scale')
    if use_torchacc():
        from swift.utils.torchacc_utils import pad_and_split_batch
        rank, _, world_size, _ = get_dist_setting()
        input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
            padding_to,
            input_ids,
            attention_mask,
            labels,
            loss_scale,
            self.max_length,
            tokenizer,
            rank,
            world_size,
            padding_right=padding_side == 'right')
    if self.sequence_parallel_size > 1 and input_ids is not None:
        bs, seq_len = input_ids.shape
        if 'position_ids' not in res:
            position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
        else:
            position_ids = res['position_ids']
        assert padding_side == 'right' or bs == 1, 'Sequence parallel only support padding_side=right'
        from swift.trainers.sequence_parallel import sequence_parallel
        if sequence_parallel.world_size() > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            input_ids, _, labels, position_ids, attention_mask, loss_scale = \
                sequence_parallel.pad_and_split_inputs(
                    tokenizer, input_ids, None, labels, position_ids, attention_mask, loss_scale)
        res['position_ids'] = position_ids
    # Write back whichever of the four locals were (re)produced above; relies
    # on the local variable names matching the result keys exactly.
    _local_var = locals()
    for key in ['input_ids', 'attention_mask', 'labels', 'loss_scale']:
        value = _local_var[key]
        if value is not None:
            res[key] = value
    return res
|
| 1459 |
+
|
| 1460 |
+
def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
    """Log token ids and their decoded text for the well-known input/label keys
    (plain key or '<key>_ids'), plus loss_scale if present. Debugging helper."""
    if tokenizer_kwargs is None:
        tokenizer_kwargs = {}
    for key in [
            'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
    ]:
        val = inputs.get(key)  # fix val is a tensor
        if val is None:
            val = inputs.get(f'{key}_ids')
        if val is not None:
            key_upper = key.upper()
            logger.info(f'[{key_upper}_IDS] {val}')
            # seq_cls/embedding labels are class/score values, not token ids.
            if key == 'labels' and self.mode in {'seq_cls', 'embedding'}:
                continue
            if isinstance(val, (list, tuple, torch.Tensor)):
                val_str = self.safe_decode(val, **tokenizer_kwargs)
                logger.info(f'[{key_upper}] {val_str}')
    if inputs.get('loss_scale') is not None:
        val = inputs['loss_scale']
        logger.info(f'[LOSS_SCALE] {val}')
|
| 1480 |
+
|
| 1481 |
+
async def prepare_lmdeploy_pytorch_inputs(self, inputs) -> None:
    """Rewrite input_ids for the lmdeploy-pytorch backend: every -100
    placeholder is expanded into that image's run of image tokens, and each
    image dict gains an 'offset' into the new sequence. Mutates `inputs` in
    place ('input_ids', 'multimodal')."""
    images = inputs.pop('images', None) or []
    if len(images) == 0:
        return
    input_ids = inputs['input_ids']
    idx_list = findall(input_ids, -100)  # one placeholder per image
    assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
    idx_list.insert(0, -1)  # sentinel so the first slice starts at index 0
    new_input_ids = []
    for i in range(len(idx_list) - 1):
        new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
        images[i]['offset'] = len(new_input_ids)
        new_input_ids += [images[i]['image_token_id']] * images[i]['image_tokens']
    new_input_ids += input_ids[idx_list[-1] + 1:]
    inputs['input_ids'] = new_input_ids
    inputs['multimodal'] = images
|
| 1497 |
+
|
| 1498 |
+
async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
    """Rewrite input_ids for the lmdeploy-turbomind backend: every -100
    placeholder becomes a run of dummy image tokens whose [start, end) span is
    recorded in 'input_embedding_ranges', with the embeddings themselves in
    'input_embeddings'. Mutates `inputs` in place."""
    images = inputs.pop('images', None) or []
    if len(images) == 0:
        return
    from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
    input_ids = inputs['input_ids']
    idx_list = findall(input_ids, -100)  # one placeholder per image
    assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
    idx_list.insert(0, -1)  # sentinel so the first slice starts at index 0
    new_input_ids = []
    ranges = []
    for i in range(len(idx_list) - 1):
        _range = []
        new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
        _range.append(len(new_input_ids))
        # images[i] is the embedding tensor; its row count is the token count.
        new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * images[i].shape[0]
        _range.append(len(new_input_ids))
        ranges.append(_range)
    new_input_ids += input_ids[idx_list[-1] + 1:]
    inputs['input_embeddings'] = [image.to('cpu') for image in images]
    inputs['input_embedding_ranges'] = ranges
    inputs['input_ids'] = new_input_ids
|
| 1520 |
+
|
| 1521 |
+
def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.) -> torch.Tensor:
|
| 1522 |
+
"""Pad sequence by some side
|
| 1523 |
+
|
| 1524 |
+
Args:
|
| 1525 |
+
sequences: The input sequences in tensor.
|
| 1526 |
+
padding_value: The padding value
|
| 1527 |
+
|
| 1528 |
+
Returns:
|
| 1529 |
+
A tensor after padding
|
| 1530 |
+
"""
|
| 1531 |
+
padding_side = self.padding_side if self.is_training else 'left'
|
| 1532 |
+
padding_right = padding_side == 'right'
|
| 1533 |
+
if padding_right:
|
| 1534 |
+
return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
|
| 1535 |
+
|
| 1536 |
+
max_len = max([s.shape[0] for s in sequences])
|
| 1537 |
+
|
| 1538 |
+
padded_sequences = []
|
| 1539 |
+
for seq in sequences:
|
| 1540 |
+
pad_length = max_len - seq.shape[0]
|
| 1541 |
+
pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
|
| 1542 |
+
padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
|
| 1543 |
+
padded_sequences.append(padded_seq)
|
| 1544 |
+
|
| 1545 |
+
return torch.stack(padded_sequences)
|
| 1546 |
+
|
| 1547 |
+
def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
    """Decode `input_ids` to text, rendering runs of special tokens (negative
    ids, floats, or ids in `placeholder_tokens`) as '[token * count]' instead
    of decoding them. Also callable with a bare tokenizer as `self`."""
    if isinstance(self, Template):
        tokenizer = self.tokenizer
        placeholder_tokens = self.placeholder_tokens
    else:
        tokenizer = self
        placeholder_tokens = []

    def _is_special(token: int) -> bool:
        # Floats and negative ids mark image/placeholder positions.
        if isinstance(token, float) or token < 0:
            return True
        return token in placeholder_tokens

    if isinstance(input_ids, torch.Tensor):
        input_ids = input_ids.tolist()
    if len(input_ids) == 0:
        return ''
    result_str = ''
    # s: start of the current special run; e: start of the current normal run.
    for i in range(len(input_ids)):
        if i == 0:
            if _is_special(input_ids[i]):
                s = 0
            else:
                e = 0
            continue
        # Transition normal -> special: flush the decoded normal run.
        if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
            s = i
            result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
        # Transition special -> normal: flush the '[token * count]' marker.
        if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
            e = i
            result_str += f'[{input_ids[i - 1]} * {e - s}]'
    # Flush whichever run is still open at the end.
    if _is_special(input_ids[i]):
        result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
    else:
        result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
    return result_str
|
| 1583 |
+
|
| 1584 |
+
@staticmethod
@contextmanager
def _patch_flash_attention_forward(modeling_module, position_ids, use_new_func: bool = False):
    """Context manager that temporarily patches
    `modeling_module._flash_attention_forward` so every call receives
    `position_ids` (needed for packed / padding-free attention).

    With `use_new_func=True` the call is routed to transformers' standalone
    `_flash_attention_forward`, dropping a leading bound-module argument and
    defaulting `is_causal=True`; otherwise the original function is used.
    The original attribute is always restored on exit.
    """
    _origin_flash_attention_forward = modeling_module._flash_attention_forward

    def _flash_attention_forward(*args, **kwargs):
        if use_new_func:
            from transformers.modeling_flash_attention_utils import (_flash_attention_forward as
                                                                     flash_attention_forward)
            if args and isinstance(args[0], nn.Module):
                args = args[1:]  # drop `self`: the new API is a free function
            if 'is_causal' not in kwargs:
                kwargs['is_causal'] = True
        else:
            flash_attention_forward = _origin_flash_attention_forward
        kwargs['position_ids'] = position_ids
        return flash_attention_forward(*args, **kwargs)

    modeling_module._flash_attention_forward = _flash_attention_forward
    try:
        yield
    finally:
        modeling_module._flash_attention_forward = _origin_flash_attention_forward
|
ms-swift/swift/llm/template/grounding.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import colorsys
|
| 2 |
+
import itertools
|
| 3 |
+
import os
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from typing import Any, List, Literal
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
from modelscope.hub.utils.utils import get_cache_dir
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _shuffle_colors(nums: List[Any]) -> List[Any]:
|
| 13 |
+
if len(nums) == 1:
|
| 14 |
+
return nums
|
| 15 |
+
|
| 16 |
+
mid = len(nums) // 2
|
| 17 |
+
|
| 18 |
+
left = nums[:mid]
|
| 19 |
+
right = nums[mid:]
|
| 20 |
+
left = _shuffle_colors(left)
|
| 21 |
+
right = _shuffle_colors(right)
|
| 22 |
+
new_nums = []
|
| 23 |
+
for x, y in zip(left, right):
|
| 24 |
+
new_nums += [x, y]
|
| 25 |
+
new_nums += left[len(right):] or right[len(left):]
|
| 26 |
+
return new_nums
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def generate_colors():
    """Build a palette of 144 RGB tuples (16 hues x 9 value/saturation combos),
    recursively shuffled so neighboring entries are visually distinct."""
    vs_combinations = [(v, s) for v, s in itertools.product([0.7, 0.3, 1], [0.7, 0.3, 1])]
    colors = [colorsys.hsv_to_rgb(i / 16, s, v) for v, s in vs_combinations for i in _shuffle_colors(list(range(16)))]
    # Scale the 0-1 floats returned by colorsys to 8-bit channel values.
    colors = [(int(r * 255), int(g * 255), int(b * 255)) for r, g, b in colors]
    return _shuffle_colors(colors)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def download_file(url: str) -> str:
    """Download `url` into '<modelscope cache>/files' and return the local path.

    NOTE(review): there is no request timeout or HTTP status check, and the
    file is re-downloaded on every call (no existing-file check) — confirm
    this is acceptable for the font-download use case.
    """
    url = url.rstrip('/')
    file_name = url.rsplit('/', 1)[-1]
    cache_dir = os.path.join(get_cache_dir(), 'files')
    os.makedirs(cache_dir, exist_ok=True)
    req = requests.get(url)
    file_path = os.path.join(cache_dir, file_name)
    with open(file_path, 'wb') as f:
        f.write(req.content)
    return file_path
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Shared palette and the per-label color assignments; `color_mapping` is
# populated lazily by draw_bbox as new labels are seen.
colors = generate_colors()
color_mapping = {}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _calculate_brightness(image, region: List[int]):
    """Mean grayscale brightness (0-255) of `region` (left, top, right,
    bottom) within `image`."""
    gray_patch = image.crop(region).convert('L')
    values = list(gray_patch.getdata())
    return sum(values) / len(values)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def draw_bbox(image: Image.Image,
              ref: List[str],
              bbox: List[List[int]],
              norm_bbox: Literal['norm1000', 'none'] = 'norm1000'):
    """Draw labeled bounding boxes on `image` in place.

    Args:
        image: The target PIL image (modified in place).
        ref: One label per box; boxes sharing a label share a color.
        bbox: Boxes as [left, top, right, bottom]; deep-copied before mutation.
        norm_bbox: 'norm1000' if coordinates are normalized to 0-1000 and must
            be scaled to pixel space; 'none' if they are already pixels.
    """
    bbox = deepcopy(bbox)
    font_path = 'https://modelscope.cn/models/Qwen/Qwen-VL-Chat/resolve/master/SimSun.ttf'
    # norm bbox
    # NOTE(review): the inner loop reuses `i`, shadowing the outer box index.
    for i, box in enumerate(bbox):
        for i in range(len(box)):
            box[i] = int(box[i])
        if norm_bbox == 'norm1000':
            box[0] = box[0] / 1000 * image.width
            box[2] = box[2] / 1000 * image.width
            box[1] = box[1] / 1000 * image.height
            box[3] = box[3] / 1000 * image.height

    draw = ImageDraw.Draw(image)
    # draw bbox
    assert len(ref) == len(bbox), f'len(refs): {len(ref)}, len(bboxes): {len(bbox)}'
    for (left, top, right, bottom), box_ref in zip(bbox, ref):
        if box_ref not in color_mapping:
            # Assign the next palette color to a newly-seen label.
            color_mapping[box_ref] = colors[len(color_mapping) % len(colors)]
        color = color_mapping[box_ref]
        draw.rectangle([(left, top), (right, bottom)], outline=color, width=3)
    # draw text
    file_path = download_file(font_path)
    font = ImageFont.truetype(file_path, 20)
    for (left, top, _, _), box_ref in zip(bbox, ref):
        # Choose white or black text based on the local background brightness.
        brightness = _calculate_brightness(
            image, [left, top, min(left + 100, image.width),
                    min(top + 20, image.height)])
        draw.text((left, top), box_ref, fill='white' if brightness < 128 else 'black', font=font)
|
ms-swift/swift/llm/template/register.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Literal, Optional
|
| 4 |
+
|
| 5 |
+
from ..utils import Processor
|
| 6 |
+
from .base import Template
|
| 7 |
+
from .template_meta import TemplateMeta
|
| 8 |
+
|
| 9 |
+
TEMPLATE_MAPPING: Dict[str, TemplateMeta] = {}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def register_template(template_meta: TemplateMeta, *, exist_ok: bool = False) -> None:
    """Register `template_meta` in TEMPLATE_MAPPING under its template_type.

    Raises ValueError on a duplicate registration unless `exist_ok` is True.
    """
    template_type = template_meta.template_type
    already_registered = template_type in TEMPLATE_MAPPING
    if already_registered and not exist_ok:
        raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.')
    TEMPLATE_MAPPING[template_type] = template_meta
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_template(
        template_type: str,
        processor: Processor,
        default_system: Optional[str] = None,
        max_length: Optional[int] = None,
        *,
        use_chat_template: bool = True,
        template_backend: Literal['swift', 'jinja'] = 'swift',
        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
        max_pixels: Optional[int] = None,  # h * w
        agent_template: Optional[str] = None,
        norm_bbox: Literal['norm1000', 'none', None] = None,
        response_prefix: Optional[str] = None,
        # train
        padding_side: Literal['left', 'right'] = 'right',
        loss_scale: str = 'default',
        sequence_parallel_size: int = 1) -> 'Template':
    """Instantiate the Template registered under `template_type`.

    Looks up the TemplateMeta in TEMPLATE_MAPPING (KeyError if unregistered)
    and constructs its `template_cls` with the given options.
    """
    template_meta = TEMPLATE_MAPPING[template_type]
    template_cls = template_meta.template_cls
    return template_cls(
        processor,
        template_meta,
        default_system,
        max_length,
        use_chat_template=use_chat_template,
        template_backend=template_backend,
        truncation_strategy=truncation_strategy,
        max_pixels=max_pixels,
        agent_template=agent_template,
        norm_bbox=norm_bbox,
        response_prefix=response_prefix,
        padding_side=padding_side,
        loss_scale=loss_scale,
        sequence_parallel_size=sequence_parallel_size,
    )
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_template_meta(template_type: str) -> TemplateMeta:
    """Return the TemplateMeta registered for `template_type` (KeyError if absent)."""
    return TEMPLATE_MAPPING[template_type]
|
ms-swift/swift/llm/template/template/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm,
|
| 2 |
+
minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)
|